diff --git a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp index 610ce1f13c56..78f0fe139296 100644 --- a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp +++ b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp @@ -449,8 +449,14 @@ struct RoundOpPattern final : public OpConversionPattern { return res; Location loc = roundOp.getLoc(); - Value operand = roundOp.getOperand(); - Type ty = operand.getType(); + auto ty = getTypeConverter()->convertType(adaptor.getOperand().getType()); + if (!ty) { + return rewriter.notifyMatchFailure( + roundOp->getLoc(), + llvm::formatv("failed to convert type {0} for SPIR-V", + roundOp.getType())); + } + Type ety = getElementTypeOrSelf(ty); auto zero = spirv::ConstantOp::getZero(ty, loc, rewriter); @@ -466,14 +472,15 @@ struct RoundOpPattern final : public OpConversionPattern { rewriter.getFloatAttr(ety, 0.5)); } - auto abs = spirv::GLFAbsOp::create(rewriter, loc, operand); + auto abs = spirv::GLFAbsOp::create(rewriter, loc, adaptor.getOperand()); auto floor = spirv::GLFloorOp::create(rewriter, loc, abs); auto sub = spirv::FSubOp::create(rewriter, loc, abs, floor); auto greater = spirv::FOrdGreaterThanEqualOp::create(rewriter, loc, sub, half); auto select = spirv::SelectOp::create(rewriter, loc, greater, one, zero); auto add = spirv::FAddOp::create(rewriter, loc, floor, select); - rewriter.replaceOpWithNewOp(roundOp, add, operand); + rewriter.replaceOpWithNewOp(roundOp, add, + adaptor.getOperand()); return success(); } }; diff --git a/mlir/test/Conversion/MathToSPIRV/math-to-gl-spirv.mlir b/mlir/test/Conversion/MathToSPIRV/math-to-gl-spirv.mlir index b8e001c9f695..608abffd8bd8 100644 --- a/mlir/test/Conversion/MathToSPIRV/math-to-gl-spirv.mlir +++ b/mlir/test/Conversion/MathToSPIRV/math-to-gl-spirv.mlir @@ -257,6 +257,27 @@ func.func @round_vector(%x: vector<4xf32>) -> vector<4xf32> { return %0: vector<4xf32> } +// Unit dimensional vectors are converted to scalars by inserting +// unrealized_conversion_cast's. +// +// CHECK-LABEL: @round_vector_unit_dim +// CHECK-SAME: (%[[ARG:.+]]: vector<1xf32>) -> vector<1xf32> +func.func @round_vector_unit_dim(%x: vector<1xf32>) -> vector<1xf32> { + // CHECK: %[[CAST:.+]] = builtin.unrealized_conversion_cast %[[ARG]] : vector<1xf32> to f32 + // CHECK: %[[ZERO:.+]] = spirv.Constant 0.000000e+00 + // CHECK: %[[ONE:.+]] = spirv.Constant 1.000000e+00 + // CHECK: %[[HALF:.+]] = spirv.Constant 5.000000e-01 + // CHECK: %[[ABS:.+]] = spirv.GL.FAbs %[[CAST]] : f32 + // CHECK: %[[FLOOR:.+]] = spirv.GL.Floor %[[ABS]] + // CHECK: %[[SUB:.+]] = spirv.FSub %[[ABS]], %[[FLOOR]] + // CHECK: %[[GE:.+]] = spirv.FOrdGreaterThanEqual %[[SUB]], %[[HALF]] + // CHECK: %[[SEL:.+]] = spirv.Select %[[GE]], %[[ONE]], %[[ZERO]] + // CHECK: %[[ADD:.+]] = spirv.FAdd %[[FLOOR]], %[[SEL]] + // CHECK: %[[BITCAST:.+]] = spirv.Bitcast %[[ADD]] : f32 to i32 + %0 = math.round %x : vector<1xf32> + return %0: vector<1xf32> +} + } // end module // -----