[mlir][Math][SPIRV] fix math.round conversion for unit-dim vectors (#182067)
In SPIR-V, unit dimensional vectors, e.g. `vector<1xf32` are legalized as scalars (vectors are 2, 3, 4, and possibly 8 and 16 dimensional). This PR fixes the `math.round` conversion pattern to legalize these vectors during conversion. Co-authored-by: Ege Beysel <beysel@roofline.ai> --------- Signed-off-by: Artem Gindinson <gindinson@roofline.ai> Co-authored-by: Ege Beysel <beysel@roofline.ai>
This commit is contained in:
parent
fb40f2be55
commit
7be77495ca
@ -449,8 +449,14 @@ struct RoundOpPattern final : public OpConversionPattern<math::RoundOp> {
|
||||
return res;
|
||||
|
||||
Location loc = roundOp.getLoc();
|
||||
Value operand = roundOp.getOperand();
|
||||
Type ty = operand.getType();
|
||||
auto ty = getTypeConverter()->convertType(adaptor.getOperand().getType());
|
||||
if (!ty) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
roundOp->getLoc(),
|
||||
llvm::formatv("failed to convert type {0} for SPIR-V",
|
||||
roundOp.getType()));
|
||||
}
|
||||
|
||||
Type ety = getElementTypeOrSelf(ty);
|
||||
|
||||
auto zero = spirv::ConstantOp::getZero(ty, loc, rewriter);
|
||||
@ -466,14 +472,15 @@ struct RoundOpPattern final : public OpConversionPattern<math::RoundOp> {
|
||||
rewriter.getFloatAttr(ety, 0.5));
|
||||
}
|
||||
|
||||
auto abs = spirv::GLFAbsOp::create(rewriter, loc, operand);
|
||||
auto abs = spirv::GLFAbsOp::create(rewriter, loc, adaptor.getOperand());
|
||||
auto floor = spirv::GLFloorOp::create(rewriter, loc, abs);
|
||||
auto sub = spirv::FSubOp::create(rewriter, loc, abs, floor);
|
||||
auto greater =
|
||||
spirv::FOrdGreaterThanEqualOp::create(rewriter, loc, sub, half);
|
||||
auto select = spirv::SelectOp::create(rewriter, loc, greater, one, zero);
|
||||
auto add = spirv::FAddOp::create(rewriter, loc, floor, select);
|
||||
rewriter.replaceOpWithNewOp<math::CopySignOp>(roundOp, add, operand);
|
||||
rewriter.replaceOpWithNewOp<math::CopySignOp>(roundOp, add,
|
||||
adaptor.getOperand());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
@ -257,6 +257,27 @@ func.func @round_vector(%x: vector<4xf32>) -> vector<4xf32> {
|
||||
return %0: vector<4xf32>
|
||||
}
|
||||
|
||||
// Unit dimensional vectors are converted to scalars by inserting
|
||||
// unrealized_conversion_cast's.
|
||||
//
|
||||
// CHECK-LABEL: @round_vector_unit_dim
|
||||
// CHECK-SAME: (%[[ARG:.+]]: vector<1xf32>) -> vector<1xf32>
|
||||
func.func @round_vector_unit_dim(%x: vector<1xf32>) -> vector<1xf32> {
|
||||
// CHECK: %[[CAST:.+]] = builtin.unrealized_conversion_cast %[[ARG]] : vector<1xf32> to f32
|
||||
// CHECK: %[[ZERO:.+]] = spirv.Constant 0.000000e+00
|
||||
// CHECK: %[[ONE:.+]] = spirv.Constant 1.000000e+00
|
||||
// CHECK: %[[HALF:.+]] = spirv.Constant 5.000000e-01
|
||||
// CHECK: %[[ABS:.+]] = spirv.GL.FAbs %[[CAST]] : f32
|
||||
// CHECK: %[[FLOOR:.+]] = spirv.GL.Floor %[[ABS]]
|
||||
// CHECK: %[[SUB:.+]] = spirv.FSub %[[ABS]], %[[FLOOR]]
|
||||
// CHECK: %[[GE:.+]] = spirv.FOrdGreaterThanEqual %[[SUB]], %[[HALF]]
|
||||
// CHECK: %[[SEL:.+]] = spirv.Select %[[GE]], %[[ONE]], %[[ZERO]]
|
||||
// CHECK: %[[ADD:.+]] = spirv.FAdd %[[FLOOR]], %[[SEL]]
|
||||
// CHECK: %[[BITCAST:.+]] = spirv.Bitcast %[[ADD]] : f32 to i32
|
||||
%0 = math.round %x : vector<1xf32>
|
||||
return %0: vector<1xf32>
|
||||
}
|
||||
|
||||
} // end module
|
||||
|
||||
// -----
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user