From d03f30fb522eda7dedd15fd83ba1077d14826e0f Mon Sep 17 00:00:00 2001 From: Artem Gindinson Date: Mon, 26 May 2025 13:52:19 +0200 Subject: [PATCH] [mlir][TOSA] restore unrealized casts when lowering rescale ops (#141096) Along with the changes to rescale op attributes, commit 7208649 dropped the builtin casts between signed and signless types. However, explicitly unsigned types are still legal input and output values from the TOSA IR perspective. The change adds back the casts when the unsigned<->signless semantics are explicit in the underlying tensor types. This prevents the conversion routine from trying to generate illegal `arith` casts that are constrained to signless types. Whether the `arith` casts themselves are signed or unsigned should still depend on the rescale's `*_unsigned` attribute values. --------- Signed-off-by: Artem Gindinson --- .../Conversion/TosaToLinalg/TosaToLinalg.cpp | 15 ++ .../TosaToLinalg/tosa-to-linalg.mlir | 165 ++++++++++++++---- 2 files changed, 149 insertions(+), 31 deletions(-) diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp index 0b69cd2814fb..6d73f23e2aae 100644 --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp @@ -1492,6 +1492,15 @@ public: : blockArgs[multiplierArg]; Value shift = shiftConstant ? shiftConstant : blockArgs[shiftArg]; + if (valueTy.isUnsignedInteger()) { + value = nestedBuilder + .create( + nestedLoc, + nestedBuilder.getIntegerType( + valueTy.getIntOrFloatBitWidth()), + value) + .getResult(0); + } if (valueTy.getIntOrFloatBitWidth() < 32) { if (op.getInputUnsigned()) { value = nestedBuilder.create( @@ -1537,6 +1546,12 @@ public: value); } + if (outIntType.isUnsignedInteger()) { + value = nestedBuilder + .create(nestedLoc, + outIntType, value) + .getResult(0); + } nestedBuilder.create(loc, value); }); diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir index 185f1973ecdc..fb912e49ff92 100644 --- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir +++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir @@ -1152,16 +1152,50 @@ func.func @rescale_i8(%arg0 : tensor<2xi8>) -> () { // ----- // CHECK: #[[$MAP0:.*]] = affine_map<(d0) -> (d0)> -// CHECK-LABEL: @rescale_i8_unsigned_output +// CHECK-LABEL: @rescale_i8_unsigned_output_explicit // CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]: -func.func @rescale_i8_unsigned_output(%arg0 : tensor<2xi8>) -> () { +func.func @rescale_i8_unsigned_output_explicit(%arg0 : tensor<2xi8>) -> () { + // CHECK: [[C0:%.+]] = arith.constant 19689 + // CHECK: [[C1:%.+]] = arith.constant 15 + // CHECK: [[INIT:%.+]] = tensor.empty() + // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel"]} ins(%[[ARG0]] : tensor<2xi8>) outs([[INIT]] : tensor<2xui8>) + // CHECK: ^bb0([[IN:%.+]]: i8, [[UNUSED:%.+]]: ui8): + // CHECK-DAG: [[C17:%.+]] = arith.constant 17 + // CHECK-DAG: [[C234:%.+]] = arith.constant 234 + // CHECK-DAG: [[IN32:%.+]] = arith.extsi [[IN]] + // CHECK-DAG: [[IN_ZEROED:%.+]] = arith.subi [[IN32]], [[C17]] + // CHECK-DAG: [[SCALED:%.+]] = tosa.apply_scale [[IN_ZEROED]], [[C0]], [[C1]] {rounding_mode = "SINGLE_ROUND"} + // CHECK-DAG: [[SCALED_ZEROED:%.+]] = arith.addi [[SCALED]], [[C234]] + // CHECK-DAG: [[CMIN:%.+]] = arith.constant 0 + // CHECK-DAG: [[CMAX:%.+]] = arith.constant 255 + // CHECK-DAG: [[LOWER:%.+]] = arith.maxsi [[CMIN]], [[SCALED_ZEROED]] + // CHECK: [[BOUNDED:%.+]] = arith.minsi [[CMAX]], [[LOWER]] + // CHECK: [[TRUNC:%.+]] = arith.trunci [[BOUNDED]] + // CHECK: [[TRUNC_ITOU:%.+]] = builtin.unrealized_conversion_cast [[TRUNC]] : i8 to ui8 + // CHECK: linalg.yield [[TRUNC_ITOU]] + %multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi16> } : () -> tensor<1xi16> + %shift = "tosa.const"() {values = dense<15> : tensor<1xi8> } : () -> tensor<1xi8> + %input_zp = "tosa.const"() {values = dense<17> : tensor<1xi8>} : () -> tensor<1xi8> + %output_zp = "tosa.const"() {values = dense<-22> : tensor<1xi8>} : () -> tensor<1xi8> + %1 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, rounding_mode = "SINGLE_ROUND", per_channel = false, input_unsigned = false, output_unsigned = true} : (tensor<2xi8>, tensor<1xi16>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<2xui8> + + // CHECK: return + return +} + +// ----- +// CHECK: #[[$MAP0:.*]] = affine_map<(d0) -> (d0)> + +// CHECK-LABEL: @rescale_i8_unsigned_output_implicit +// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]: +func.func @rescale_i8_unsigned_output_implicit(%arg0 : tensor<2xi8>) -> () { // CHECK: [[C0:%.+]] = arith.constant 19689 // CHECK: [[C1:%.+]] = arith.constant 15 // CHECK: [[INIT:%.+]] = tensor.empty() // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel"]} ins(%[[ARG0]] : tensor<2xi8>) outs([[INIT]] : tensor<2xi8>) // CHECK: ^bb0([[IN:%.+]]: i8, [[UNUSED:%.+]]: i8): - // CHECK: [[C17:%.+]] = arith.constant 17 - // CHECK: [[C234:%.+]] = arith.constant 234 + // CHECK-DAG: [[C17:%.+]] = arith.constant 17 + // CHECK-DAG: [[C234:%.+]] = arith.constant 234 // CHECK-DAG: [[IN32:%.+]] = arith.extsi [[IN]] // CHECK-DAG: [[IN_ZEROED:%.+]] = arith.subi [[IN32]], [[C17]] // CHECK-DAG: [[SCALED:%.+]] = tosa.apply_scale [[IN_ZEROED]], [[C0]], [[C1]] {rounding_mode = "SINGLE_ROUND"} @@ -1169,14 +1203,48 @@ func.func @rescale_i8_unsigned_output(%arg0 : tensor<2xi8>) -> () { // CHECK-DAG: [[CMIN:%.+]] = arith.constant 0 // CHECK-DAG: [[CMAX:%.+]] = arith.constant 255 // CHECK-DAG: [[LOWER:%.+]] = arith.maxsi [[CMIN]], [[SCALED_ZEROED]] + // CHECK: [[BOUNDED:%.+]] = arith.minsi [[CMAX]], [[LOWER]] + // CHECK: [[TRUNC:%.+]] = arith.trunci [[BOUNDED]] + // CHECK-NOT: builtin.unrealized_conversion_cast [[TRUNC]] + // CHECK: linalg.yield [[TRUNC]] + %multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi16> } : () -> tensor<1xi16> + %shift = "tosa.const"() {values = dense<15> : tensor<1xi8> } : () -> tensor<1xi8> + %input_zp = "tosa.const"() {values = dense<17> : tensor<1xi8>} : () -> tensor<1xi8> + %output_zp = "tosa.const"() {values = dense<-22> : tensor<1xi8>} : () -> tensor<1xi8> + %1 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, rounding_mode = "SINGLE_ROUND", per_channel = false, input_unsigned = false, output_unsigned = true} : (tensor<2xi8>, tensor<1xi16>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<2xi8> + + // CHECK: return + return +} + +// ----- +// CHECK: #[[$MAP0:.*]] = affine_map<(d0) -> (d0)> + +// CHECK-LABEL: @rescale_i48_unsigned_output_implicit +// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]: +func.func @rescale_i48_unsigned_output_implicit(%arg0 : tensor<2xi48>) -> () { + // CHECK: [[C19689:%.+]] = arith.constant 19689 + // CHECK: [[C15:%.+]] = arith.constant 15 + // CHECK: [[INIT:%.+]] = tensor.empty() + // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel"]} ins(%[[ARG0]] : tensor<2xi48>) outs([[INIT]] : tensor<2xi8>) + // CHECK: ^bb0([[IN:%.+]]: i48, [[UNUSED:%.+]]: i8): + // CHECK-NOT: builtin.unrealized_conversion_cast [[IN]] + // CHECK-DAG: [[C0:%.+]] = arith.constant 0 + // CHECK-DAG: [[C234:%.+]] = arith.constant 234 + // CHECK-DAG: [[IN_ZEROED:%.+]] = arith.subi [[IN]], [[C0]] + // CHECK-DAG: [[SCALED:%.+]] = tosa.apply_scale [[IN_ZEROED]], [[C19689]], [[C15]] {rounding_mode = "SINGLE_ROUND"} + // CHECK-DAG: [[SCALED_ZEROED:%.+]] = arith.addi [[SCALED]], [[C234]] + // CHECK-DAG: [[CMIN:%.+]] = arith.constant 0 + // CHECK-DAG: [[CMAX:%.+]] = arith.constant 255 + // CHECK-DAG: [[LOWER:%.+]] = arith.maxsi [[CMIN]], [[SCALED_ZEROED]] // CHECK-DAG: [[BOUNDED:%.+]] = arith.minsi [[CMAX]], [[LOWER]] // CHECK-DAG: [[TRUNC:%.+]] = arith.trunci [[BOUNDED]] // CHECK: linalg.yield [[TRUNC]] %multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi16> } : () -> tensor<1xi16> %shift = "tosa.const"() {values = dense<15> : tensor<1xi8> } : () -> tensor<1xi8> - %input_zp = "tosa.const"() {values = dense<17> : tensor<1xi8>} : () -> tensor<1xi8> + %input_zp = "tosa.const"() {values = dense<0> : tensor<1xi48>} : () -> tensor<1xi48> %output_zp = "tosa.const"() {values = dense<-22> : tensor<1xi8>} : () -> tensor<1xi8> - %1 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, rounding_mode = "SINGLE_ROUND", per_channel = false, input_unsigned = false, output_unsigned = true} : (tensor<2xi8>, tensor<1xi16>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<2xi8> + %1 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, rounding_mode = "SINGLE_ROUND", per_channel = false, input_unsigned = false, output_unsigned = true} : (tensor<2xi48>, tensor<1xi16>, tensor<1xi8>, tensor<1xi48>, tensor<1xi8>) -> tensor<2xi8> // CHECK: return return @@ -1230,19 +1298,52 @@ func.func @rescale_dyn(%arg0 : tensor<1x?x?x32xi32>) -> () { } // ----- - // CHECK: #[[$MAP0:.*]] = affine_map<(d0) -> (d0)> -// CHECK-LABEL: @rescale_i8_unsigned_input +// CHECK-LABEL: @rescale_i8_unsigned_input_explicit // CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]: -func.func @rescale_i8_unsigned_input(%arg0 : tensor<2xi8>) -> () { +func.func @rescale_i8_unsigned_input_explicit(%arg0 : tensor<2xui8>) -> () { + // CHECK: [[C0:%.+]] = arith.constant 19689 + // CHECK: [[C1:%.+]] = arith.constant 15 + // CHECK: [[INIT:%.+]] = tensor.empty() + // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel"]} ins(%[[ARG0]] : tensor<2xui8>) outs([[INIT]] : tensor<2xi8>) + // CHECK: ^bb0([[IN:%.+]]: ui8, [[UNUSED:%.+]]: i8): + // CHECK-DAG: [[C17:%.+]] = arith.constant 17 + // CHECK-DAG: [[C22:%.+]] = arith.constant 22 + // CHECK-DAG: [[IN_UTOI:%.+]] = builtin.unrealized_conversion_cast [[IN]] : ui8 to i8 + // CHECK-DAG: [[IN32:%.+]] = arith.extui [[IN_UTOI]] + // CHECK-DAG: [[IN_ZEROED:%.+]] = arith.subi [[IN32]], [[C17]] + // CHECK-DAG: [[SCALED:%.+]] = tosa.apply_scale [[IN_ZEROED]], [[C0]], [[C1]] {rounding_mode = "SINGLE_ROUND"} + // CHECK-DAG: [[SCALED_ZEROED:%.+]] = arith.addi [[SCALED]], [[C22]] + // CHECK-DAG: [[CMIN:%.+]] = arith.constant -128 + // CHECK-DAG: [[CMAX:%.+]] = arith.constant 127 + // CHECK-DAG: [[LOWER:%.+]] = arith.maxsi [[CMIN]], [[SCALED_ZEROED]] + // CHECK: [[BOUNDED:%.+]] = arith.minsi [[CMAX]], [[LOWER]] + // CHECK: [[TRUNC:%.+]] = arith.trunci [[BOUNDED]] + // CHECK: linalg.yield [[TRUNC]] + %multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi16> } : () -> tensor<1xi16> + %shift = "tosa.const"() {values = dense<15> : tensor<1xi8> } : () -> tensor<1xi8> + %input_zp = "tosa.const"() {values = dense<17> : tensor<1xi8>} : () -> tensor<1xi8> + %output_zp = "tosa.const"() {values = dense<22> : tensor<1xi8>} : () -> tensor<1xi8> + %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, rounding_mode = "SINGLE_ROUND", per_channel = false, input_unsigned = true, output_unsigned = false} : (tensor<2xui8>, tensor<1xi16>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<2xi8> + + return +} + +// ----- +// CHECK: #[[$MAP0:.*]] = affine_map<(d0) -> (d0)> + +// CHECK-LABEL: @rescale_i8_unsigned_input_implicit +// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]: +func.func @rescale_i8_unsigned_input_implicit(%arg0 : tensor<2xi8>) -> () { // CHECK: [[C0:%.+]] = arith.constant 19689 // CHECK: [[C1:%.+]] = arith.constant 15 // CHECK: [[INIT:%.+]] = tensor.empty() // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel"]} ins(%[[ARG0]] : tensor<2xi8>) outs([[INIT]] : tensor<2xi8>) // CHECK: ^bb0([[IN:%.+]]: i8, [[UNUSED:%.+]]: i8): - // CHECK: [[C128:%.+]] = arith.constant 128 - // CHECK: [[C22:%.+]] = arith.constant 22 + // CHECK-NOT: builtin.unrealized_conversion_cast [[IN]] + // CHECK-DAG: [[C128:%.+]] = arith.constant 128 + // CHECK-DAG: [[C22:%.+]] = arith.constant 22 // CHECK-DAG: [[IN32:%.+]] = arith.extui [[IN]] // CHECK-DAG: [[IN_ZEROED:%.+]] = arith.subi [[IN32]], [[C128]] // CHECK-DAG: [[SCALED:%.+]] = tosa.apply_scale [[IN_ZEROED]], [[C0]], [[C1]] {rounding_mode = "SINGLE_ROUND"} @@ -1265,32 +1366,34 @@ func.func @rescale_i8_unsigned_input(%arg0 : tensor<2xi8>) -> () { // ----- // CHECK: #[[$MAP0:.*]] = affine_map<(d0) -> (d0)> -// CHECK-LABEL: @rescale_i48_unsigned_output +// CHECK-LABEL: @rescale_i8_unsigned_input_output_explicit // CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]: -func.func @rescale_i48_unsigned_output(%arg0 : tensor<2xi48>) -> () { - // CHECK: [[C19689:%.+]] = arith.constant 19689 - // CHECK: [[C15:%.+]] = arith.constant 15 +func.func @rescale_i8_unsigned_input_output_explicit(%arg0 : tensor<2xui8>) -> () { + // CHECK: [[C0:%.+]] = arith.constant 19689 + // CHECK: [[C1:%.+]] = arith.constant 15 // CHECK: [[INIT:%.+]] = tensor.empty() - // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel"]} ins(%[[ARG0]] : tensor<2xi48>) outs([[INIT]] : tensor<2xi8>) - // CHECK: ^bb0([[IN:%.+]]: i48, [[UNUSED:%.+]]: i8): - // CHECK: [[C0:%.+]] = arith.constant 0 - // CHECK: [[C234:%.+]] = arith.constant 234 - // CHECK-DAG: [[IN_ZEROED:%.+]] = arith.subi [[IN]], [[C0]] - // CHECK-DAG: [[SCALED:%.+]] = tosa.apply_scale [[IN_ZEROED]], [[C19689]], [[C15]] {rounding_mode = "SINGLE_ROUND"} - // CHECK-DAG: [[SCALED_ZEROED:%.+]] = arith.addi [[SCALED]], [[C234]] - // CHECK-DAG: [[CMIN:%.+]] = arith.constant 0 - // CHECK-DAG: [[CMAX:%.+]] = arith.constant 255 + // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel"]} ins(%[[ARG0]] : tensor<2xui8>) outs([[INIT]] : tensor<2xui8>) + // CHECK: ^bb0([[IN:%.+]]: ui8, [[UNUSED:%.+]]: ui8): + // CHECK-DAG: [[C17:%.+]] = arith.constant 17 + // CHECK-DAG: [[C22:%.+]] = arith.constant 22 + // CHECK-DAG: [[IN_UTOI:%.+]] = builtin.unrealized_conversion_cast [[IN]] : ui8 to i8 + // CHECK-DAG: [[IN32:%.+]] = arith.extui [[IN_UTOI]] + // CHECK-DAG: [[IN_ZEROED:%.+]] = arith.subi [[IN32]], [[C17]] + // CHECK-DAG: [[SCALED:%.+]] = tosa.apply_scale [[IN_ZEROED]], [[C0]], [[C1]] {rounding_mode = "SINGLE_ROUND"} + // CHECK-DAG: [[SCALED_ZEROED:%.+]] = arith.addi [[SCALED]], [[C22]] + // CHECK-DAG: [[CMIN:%.+]] = arith.constant -128 + // CHECK-DAG: [[CMAX:%.+]] = arith.constant 127 // CHECK-DAG: [[LOWER:%.+]] = arith.maxsi [[CMIN]], [[SCALED_ZEROED]] - // CHECK-DAG: [[BOUNDED:%.+]] = arith.minsi [[CMAX]], [[LOWER]] - // CHECK-DAG: [[TRUNC:%.+]] = arith.trunci [[BOUNDED]] - // CHECK: linalg.yield [[TRUNC]] + // CHECK: [[BOUNDED:%.+]] = arith.minsi [[CMAX]], [[LOWER]] + // CHECK: [[TRUNC:%.+]] = arith.trunci [[BOUNDED]] + // CHECK: [[TRUNC_ITOU:%.+]] = builtin.unrealized_conversion_cast [[TRUNC]] : i8 to ui8 + // CHECK: linalg.yield [[TRUNC_ITOU]] %multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi16> } : () -> tensor<1xi16> %shift = "tosa.const"() {values = dense<15> : tensor<1xi8> } : () -> tensor<1xi8> - %input_zp = "tosa.const"() {values = dense<0> : tensor<1xi48>} : () -> tensor<1xi48> - %output_zp = "tosa.const"() {values = dense<-22> : tensor<1xi8>} : () -> tensor<1xi8> - %1 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, rounding_mode = "SINGLE_ROUND", per_channel = false, input_unsigned = false, output_unsigned = true} : (tensor<2xi48>, tensor<1xi16>, tensor<1xi8>, tensor<1xi48>, tensor<1xi8>) -> tensor<2xi8> + %input_zp = "tosa.const"() {values = dense<17> : tensor<1xi8>} : () -> tensor<1xi8> + %output_zp = "tosa.const"() {values = dense<22> : tensor<1xi8>} : () -> tensor<1xi8> + %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, rounding_mode = "SINGLE_ROUND", per_channel = false, input_unsigned = true, output_unsigned = false} : (tensor<2xui8>, tensor<1xi16>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<2xui8> - // CHECK: return return }