[mlir][Math][SPIRV] fix math.round conversion for unit-dim vectors (#182067)

In SPIR-V, unit dimensional vectors, e.g. `vector<1xf32` are legalized
as scalars (vectors are 2, 3, 4, and possibly 8 and 16 dimensional).
This PR fixes the `math.round` conversion pattern to legalize these
vectors during conversion.

Co-authored-by: Ege Beysel <beysel@roofline.ai>

---------

Signed-off-by: Artem Gindinson <gindinson@roofline.ai>
Co-authored-by: Ege Beysel <beysel@roofline.ai>
This commit is contained in:
Artem Gindinson 2026-02-18 18:36:06 +01:00 committed by GitHub
parent fb40f2be55
commit 7be77495ca
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 32 additions and 4 deletions

View File

@ -449,8 +449,14 @@ struct RoundOpPattern final : public OpConversionPattern<math::RoundOp> {
return res;
Location loc = roundOp.getLoc();
Value operand = roundOp.getOperand();
Type ty = operand.getType();
auto ty = getTypeConverter()->convertType(adaptor.getOperand().getType());
if (!ty) {
return rewriter.notifyMatchFailure(
roundOp->getLoc(),
llvm::formatv("failed to convert type {0} for SPIR-V",
roundOp.getType()));
}
Type ety = getElementTypeOrSelf(ty);
auto zero = spirv::ConstantOp::getZero(ty, loc, rewriter);
@ -466,14 +472,15 @@ struct RoundOpPattern final : public OpConversionPattern<math::RoundOp> {
rewriter.getFloatAttr(ety, 0.5));
}
auto abs = spirv::GLFAbsOp::create(rewriter, loc, operand);
auto abs = spirv::GLFAbsOp::create(rewriter, loc, adaptor.getOperand());
auto floor = spirv::GLFloorOp::create(rewriter, loc, abs);
auto sub = spirv::FSubOp::create(rewriter, loc, abs, floor);
auto greater =
spirv::FOrdGreaterThanEqualOp::create(rewriter, loc, sub, half);
auto select = spirv::SelectOp::create(rewriter, loc, greater, one, zero);
auto add = spirv::FAddOp::create(rewriter, loc, floor, select);
rewriter.replaceOpWithNewOp<math::CopySignOp>(roundOp, add, operand);
rewriter.replaceOpWithNewOp<math::CopySignOp>(roundOp, add,
adaptor.getOperand());
return success();
}
};

View File

@ -257,6 +257,27 @@ func.func @round_vector(%x: vector<4xf32>) -> vector<4xf32> {
return %0: vector<4xf32>
}
// Unit dimensional vectors are converted to scalars by inserting
// unrealized_conversion_cast's.
//
// CHECK-LABEL: @round_vector_unit_dim
// CHECK-SAME: (%[[ARG:.+]]: vector<1xf32>) -> vector<1xf32>
func.func @round_vector_unit_dim(%x: vector<1xf32>) -> vector<1xf32> {
// CHECK: %[[CAST:.+]] = builtin.unrealized_conversion_cast %[[ARG]] : vector<1xf32> to f32
// CHECK: %[[ZERO:.+]] = spirv.Constant 0.000000e+00
// CHECK: %[[ONE:.+]] = spirv.Constant 1.000000e+00
// CHECK: %[[HALF:.+]] = spirv.Constant 5.000000e-01
// CHECK: %[[ABS:.+]] = spirv.GL.FAbs %[[CAST]] : f32
// CHECK: %[[FLOOR:.+]] = spirv.GL.Floor %[[ABS]]
// CHECK: %[[SUB:.+]] = spirv.FSub %[[ABS]], %[[FLOOR]]
// CHECK: %[[GE:.+]] = spirv.FOrdGreaterThanEqual %[[SUB]], %[[HALF]]
// CHECK: %[[SEL:.+]] = spirv.Select %[[GE]], %[[ONE]], %[[ZERO]]
// CHECK: %[[ADD:.+]] = spirv.FAdd %[[FLOOR]], %[[SEL]]
// CHECK: %[[BITCAST:.+]] = spirv.Bitcast %[[ADD]] : f32 to i32
%0 = math.round %x : vector<1xf32>
return %0: vector<1xf32>
}
} // end module
// -----