From d5f7acdbc15fd15244bf6f3e4d4e3ea5a7bd2781 Mon Sep 17 00:00:00 2001 From: Davide Grohmann Date: Tue, 31 Mar 2026 14:31:04 +0200 Subject: [PATCH] [mlir][spirv] Add Cast/Rescale ops in TOSA Ext Inst Set (#189028) This patch introduces the following operators: spirv.Tosa.Cast spirv.Tosa.Rescale Also dialect and serialization round-trip tests have been added. Signed-off-by: Davide Grohmann --- .../mlir/Dialect/SPIRV/IR/SPIRVBase.td | 11 ++ .../mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td | 133 +++++++++++++ .../mlir/Dialect/SPIRV/IR/SPIRVTosaTypes.td | 21 ++ .../SPIRV/IR/tosa-ops-verification.mlir | 184 ++++++++++++++++++ mlir/test/Dialect/SPIRV/IR/tosa-ops.mlir | 37 ++++ mlir/test/Target/SPIRV/tosa-ops.mlir | 61 ++++++ mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp | 1 + 7 files changed, 448 insertions(+) diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td index 9f9e2f5f9a67..11a91958d748 100644 --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td @@ -4985,4 +4985,15 @@ def SPIRV_TosaExtNaNPropagationModeAttr : SPIRV_I32EnumAttr< I32EnumAttrCase<"Ignore", 2>, ]>; +// NOTE: This is an attribute in the SPIR-V *dialect* but a constant (<id>) in +// SPIR-V proper. 
+def SPIRV_TosaExtRoundingModeAttr : SPIRV_I32EnumAttr< + "TosaExtRoundingModeType", "Tosa Ext Rounding Mode Type", + "tosa_ext_rounding_mode_type", + [ + I32EnumAttrCase<"SingleRound", 1>, + I32EnumAttrCase<"InexactRound", 2>, + I32EnumAttrCase<"DoubleRound", 3>, + ]>; + #endif // MLIR_DIALECT_SPIRV_IR_BASE diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td index 7fc7f8647849..7fdd00ef2e03 100644 --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td @@ -2648,4 +2648,137 @@ def SPIRV_TosaResizeOp : SPIRV_TosaOpWithResult<"Resize", 63, [Pure, } +def SPIRV_TosaCastOp : SPIRV_TosaOpWithResult<"Cast", 64, [Pure, + AllShapesMatch<["input", "output"]>]> { + let summary = "Cast operation."; + + let description = [{ + Casts a tensor from one data type to another. + + References: + * https://github.khronos.org/SPIRV-Registry/extended/TOSA.001000.1.html#_cast + * https://www.mlplatform.org/tosa/tosa_spec_1_0_1.html#_cast + + #### Example: + ```mlir + %0 = spirv.Tosa.Cast %arg0 : !spirv.arm.tensor<1x65538x1x2xi8> -> !spirv.arm.tensor<1x65538x1x2xi32> + %0 = spirv.Tosa.Cast %arg0 : !spirv.arm.tensor<11x5x14x4xf32> -> !spirv.arm.tensor<11x5x14x4xf16> + ``` + }]; + + let arguments = (ins + SPIRV_TosaAny_TensorArm: $input + ); + + let results = (outs + SPIRV_TosaAny_TensorArm: $output + ); + + let assemblyFormat = [{ + $input + attr-dict `:` type(operands) `->` type(results) + }]; + + let extraClassDeclaration = extraBaseClassDeclaration#[{ + ::mlir::spirv::TensorArmType getInputType() { + return cast<::mlir::spirv::TensorArmType>(getInput().getType()); + } + }]; +} + + +def SPIRV_TosaRescaleOp : SPIRV_TosaOpWithResult<"Rescale", 65, [NoMemoryEffect, + AllShapesMatch<["input", "output"]>, + AllElementTypesMatch<["input", "input_zp"]>, + AllElementTypesMatch<["output", "output_zp"]>, + ElementTypeMatchesScale32<"multiplier">, + 
TensorLengthMatchesPerChannel<"multiplier">, + TensorLengthMatchesPerChannel<"shift">, + TypeConstraintImplicationOn<"input", I8, "output", [I8, I16, I32]>, + TypeConstraintImplicationOn<"input", I16, "output", [I8, I16, I32]>, + TypeConstraintImplicationOn<"input", I32, "output", [I8, I16, I32]>, + TypeConstraintImplicationOn<"input", I64, "output", [I8, I16, I32]>, + BoolAttrTypeConstraintImplicationOn<"input_unsigned", "input", [I8, I16]>, + BoolAttrTypeConstraintImplicationOn<"output_unsigned", "input", [I8, I16]>, + BoolAttrTypeConstraintImplicationOn<"input_unsigned", "output", [I8, I16]>, + BoolAttrTypeConstraintImplicationOn<"output_unsigned", "output", [I8, I16]>]> { + let summary = "Rescale operator."; + + let description = [{ + Rescale is defined using an integer multiply, add, and shift. + + Rescale supports two precisions of multiplier: 16-bit and 32-bit. The + 32-bit multiplier version supports two rounding modes to enable simpler + lowering of existing frameworks that use two stage rounding. All arithmetic + is designed so that it does not overflow a 64-bit accumulator and that the + result fits in 32 bits. In particular, a 48-bit value (represented as a + 64-bit value in SPIR-V) cannot be scaled with the 32-bit multiplier because + the accumulator would need to have 80 bits. + + The shift and value range are limited to allow a variety of implementations. + The limit of 62 on shift allows the shift to be decomposed as two right + shifts of 31. + + Unsigned 8- and 16-bit values are only allowed in the Rescale operation, + to allow for compatibility with networks which expect unsigned 8-bit or + 16-bit tensors for input and output. + + Undefined behaviour may occur if the calculated result underflows or overflows + their integer ranges. 
+ + References: + * https://github.khronos.org/SPIRV-Registry/extended/TOSA.001000.1.html#_rescale + * https://www.mlplatform.org/tosa/tosa_spec_1_0_1.html#_rescale + + #### Example: + ```mlir + %9 = spirv.Tosa.Rescale scale32 = true, rounding_mode = , per_channel = false, input_unsigned = false, output_unsigned = true, %arg0, %multiplier, %shift, %input_zp, %output_zp : !spirv.arm.tensor<17x29x19xi16>, !spirv.arm.tensor<1xi32>, !spirv.arm.tensor<1xi8>, !spirv.arm.tensor<1xi16>, !spirv.arm.tensor<1xi16> -> !spirv.arm.tensor<17x29x19xi16> + ``` + }]; + + let arguments = (ins + SPIRV_BoolConstAttr: $scale32, + SPIRV_TosaExtRoundingModeAttr: $rounding_mode, + SPIRV_BoolConstAttr: $per_channel, + SPIRV_BoolConstAttr: $input_unsigned, + SPIRV_BoolConstAttr: $output_unsigned, + SPIRV_TosaInteger_TensorArm: $input, + SPIRV_Int16OrInt32_TensorArm1D: $multiplier, + SPIRV_Int8_TensorArm1D: $shift, + SPIRV_TosaInteger_1DTensorArmOfLength1: $input_zp, + SPIRV_TosaInteger_1DTensorArmOfLength1: $output_zp + ); + + let results = (outs + SPIRV_TosaInteger_TensorArm: $output + ); + + let assemblyFormat = [{ + `scale32` `=` $scale32 `,` + `rounding_mode` `=` $rounding_mode `,` + `per_channel` `=` $per_channel `,` + `input_unsigned` `=` $input_unsigned `,` + `output_unsigned` `=` $output_unsigned `,` + $input `,` + $multiplier `,` + $shift `,` + $input_zp `,` + $output_zp + attr-dict `:` type(operands) `->` type(results) + }]; + + let extraClassDeclaration = extraBaseClassDeclaration#[{ + ::mlir::spirv::TensorArmType getInputType() { + return cast<::mlir::spirv::TensorArmType>(getInput().getType()); + } + ::mlir::spirv::TensorArmType getMultiplierType() { + return cast<::mlir::spirv::TensorArmType>(getMultiplier().getType()); + } + ::mlir::spirv::TensorArmType getShiftType() { + return cast<::mlir::spirv::TensorArmType>(getShift().getType()); + } + }]; +} + + #endif // MLIR_DIALECT_SPIRV_IR_TOSA_OPS diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaTypes.td 
b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaTypes.td index f116c4dcdd49..1fc3bdad48e7 100644 --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaTypes.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaTypes.td @@ -38,6 +38,8 @@ class TensorArmRankOf allowedTypes, list ranks> [HasAnyRankOfPred], !interleave(!foreach(rank, ranks, rank # "D"), "/") # " tensorArm">; +def SPIRV_Int8_TensorArm1D : TensorArmRankOf<[SPIRV_Int8], [1]>; +def SPIRV_Int16OrInt32_TensorArm1D : TensorArmRankOf<[SPIRV_Int16, SPIRV_Int32], [1]>; def SPIRV_Int32_TensorArm2D : TensorArmRankOf<[SPIRV_Int32], [2]>; def SPIRV_Float32_TensorArm3D: TensorArmRankOf<[SPIRV_Float32], [3]>; def SPIRV_TosaInteger_TensorArm1D : TensorArmRankOf<[SPIRV_TosaInteger], [1]>; @@ -92,6 +94,7 @@ def SPIRV_Int32_1DTensorArmOfLength1To6Attr : ConfinedAttr< I32ElementsAttr, [SPIRV_DenseElementAttrsWithTensorArmType, Is1DTensorArmAttrOfLength<[1, 2, 3, 4, 5, 6]>]>; def SPIRV_Int8_1DTensorArmOfLength1 : SPIRV_1DTensorArmOfLengthAndType<[1], [SPIRV_Int8]>; +def SPIRV_TosaInteger_1DTensorArmOfLength1 : SPIRV_1DTensorArmOfLengthAndType<[1], [SPIRV_TosaInteger]>; def SPIRV_TosaNumerical_1DTensorArmOfLength1 : SPIRV_1DTensorArmOfLengthAndType<[1], [SPIRV_TosaNumerical]>; def SPIRV_TosaAny_1DTensorArmOfLength1 : SPIRV_1DTensorArmOfLengthAndType<[1], [SPIRV_TosaAny]>; @@ -121,6 +124,13 @@ class TypeConstraintImplicationOn, !foreach(allowedType, allowedTypes, ElementTypeIsPred)>>; +class BoolAttrTypeConstraintImplicationOn allowedTypes>: + PredOpTrait<"if " # boolAttr # " is true then " # + other # " must have a type in [" # + !interleave(!foreach(type, allowedTypes, type.summary), ",") # "]", + Implies.ret # "()" >, + !foreach(allowedType, allowedTypes, ElementTypeIsPred)>>; + class AxisValueLessThanRankOf: PredOpTrait<"axis attribute value should be lower than rank(" # input # ")", Implies.result>, [CPred<"getAxis() < " # Rank.result>]>>; @@ -201,5 +211,16 @@ class VariadicInputAllSameRank: " && 
::llvm::cast<::mlir::ShapedType>(t).getRank() == " # Rank.result # "; })">>; +class ElementTypeMatchesScale32 : + PredOpTrait($" # tensor # ".getType()).getElementType()." + "isInteger(getScale32() ? 32 : 16)">>; + +class TensorLengthMatchesPerChannel : + PredOpTrait($" # tensor # ".getType()).getShape()[0] == " + "(getPerChannel() ? " + "::llvm::cast<::mlir::ShapedType>($input.getType()).getRank() - 1 : 1)">>; + #endif // MLIR_DIALECT_SPIRV_IR_TOSA_TYPES diff --git a/mlir/test/Dialect/SPIRV/IR/tosa-ops-verification.mlir b/mlir/test/Dialect/SPIRV/IR/tosa-ops-verification.mlir index f95fedba7430..f981fe785363 100644 --- a/mlir/test/Dialect/SPIRV/IR/tosa-ops-verification.mlir +++ b/mlir/test/Dialect/SPIRV/IR/tosa-ops-verification.mlir @@ -1940,3 +1940,187 @@ spirv.ARM.Graph @resize_bf16_input_output_element_type_must_be_bf16(%arg0: !spir %4 = spirv.Tosa.Resize mode = , %arg0, %1, %2, %3 : !spirv.arm.tensor<1x48x33x63xbf16>, !spirv.arm.tensor<4xi32>, !spirv.arm.tensor<2xi32>, !spirv.arm.tensor<2xi32> -> !spirv.arm.tensor<1x753x297x63xf32> spirv.ARM.GraphOutputs %4 : !spirv.arm.tensor<1x753x297x63xf32> } + +//===----------------------------------------------------------------------===// +// spirv.TOSA.Cast +//===----------------------------------------------------------------------===// + +spirv.ARM.Graph @cast_input_output_shapes_not_matching(%arg0: !spirv.arm.tensor<2x3x4xi8>) -> (!spirv.arm.tensor<2x3x5xi32>) { + // expected-error @+1 {{op failed to verify that all of {input, output} have same shape}} + %0 = spirv.Tosa.Cast %arg0 : !spirv.arm.tensor<2x3x4xi8> -> !spirv.arm.tensor<2x3x5xi32> + spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<2x3x5xi32> +} + +//===----------------------------------------------------------------------===// +// spirv.TOSA.Rescale +//===----------------------------------------------------------------------===// + +spirv.ARM.Graph @rescale_input_output_shapes_not_matching(%arg0: !spirv.arm.tensor<2x3x4xi16>) -> (!spirv.arm.tensor<2x3x5xi16>) 
{ + %1 = spirv.Constant dense<1> : !spirv.arm.tensor<1xi16> + %2 = spirv.Constant dense<0> : !spirv.arm.tensor<1xi8> + %3 = spirv.Constant dense<0> : !spirv.arm.tensor<1xi16> + %4 = spirv.Constant dense<0> : !spirv.arm.tensor<1xi16> + // expected-error @+1 {{op failed to verify that all of {input, output} have same shape}} + %5 = spirv.Tosa.Rescale scale32 = false, rounding_mode = , per_channel = false, input_unsigned = false, output_unsigned = false, %arg0, %1, %2, %3, %4 : !spirv.arm.tensor<2x3x4xi16>, !spirv.arm.tensor<1xi16>, !spirv.arm.tensor<1xi8>, !spirv.arm.tensor<1xi16>, !spirv.arm.tensor<1xi16> -> !spirv.arm.tensor<2x3x5xi16> + spirv.ARM.GraphOutputs %5 : !spirv.arm.tensor<2x3x5xi16> +} + +spirv.ARM.Graph @rescale_input_and_input_zp_element_types_not_matching(%arg0: !spirv.arm.tensor<2x3x4xi16>) -> (!spirv.arm.tensor<2x3x4xi16>) { + %1 = spirv.Constant dense<1> : !spirv.arm.tensor<1xi16> + %2 = spirv.Constant dense<0> : !spirv.arm.tensor<1xi8> + %3 = spirv.Constant dense<0> : !spirv.arm.tensor<1xi8> + %4 = spirv.Constant dense<0> : !spirv.arm.tensor<1xi16> + // expected-error @+1 {{op failed to verify that all of {input, input_zp} have same element type}} + %5 = spirv.Tosa.Rescale scale32 = false, rounding_mode = , per_channel = false, input_unsigned = false, output_unsigned = false, %arg0, %1, %2, %3, %4 : !spirv.arm.tensor<2x3x4xi16>, !spirv.arm.tensor<1xi16>, !spirv.arm.tensor<1xi8>, !spirv.arm.tensor<1xi8>, !spirv.arm.tensor<1xi16> -> !spirv.arm.tensor<2x3x4xi16> + spirv.ARM.GraphOutputs %5 : !spirv.arm.tensor<2x3x4xi16> +} + +spirv.ARM.Graph @rescale_output_and_output_zp_element_types_not_matching(%arg0: !spirv.arm.tensor<2x3x4xi16>) -> (!spirv.arm.tensor<2x3x4xi16>) { + %1 = spirv.Constant dense<1> : !spirv.arm.tensor<1xi16> + %2 = spirv.Constant dense<0> : !spirv.arm.tensor<1xi8> + %3 = spirv.Constant dense<0> : !spirv.arm.tensor<1xi16> + %4 = spirv.Constant dense<0> : !spirv.arm.tensor<1xi8> + // expected-error @+1 {{op failed to verify that all 
of {output, output_zp} have same element type}} + %5 = spirv.Tosa.Rescale scale32 = false, rounding_mode = , per_channel = false, input_unsigned = false, output_unsigned = false, %arg0, %1, %2, %3, %4 : !spirv.arm.tensor<2x3x4xi16>, !spirv.arm.tensor<1xi16>, !spirv.arm.tensor<1xi8>, !spirv.arm.tensor<1xi16>, !spirv.arm.tensor<1xi8> -> !spirv.arm.tensor<2x3x4xi16> + spirv.ARM.GraphOutputs %5 : !spirv.arm.tensor<2x3x4xi16> +} + +spirv.ARM.Graph @rescale_scale32_true_requires_i32_multiplier(%arg0: !spirv.arm.tensor<2x3x4xi16>) -> (!spirv.arm.tensor<2x3x4xi16>) { + %1 = spirv.Constant dense<1> : !spirv.arm.tensor<1xi16> + %2 = spirv.Constant dense<0> : !spirv.arm.tensor<1xi8> + %3 = spirv.Constant dense<0> : !spirv.arm.tensor<1xi16> + %4 = spirv.Constant dense<0> : !spirv.arm.tensor<1xi16> + // expected-error @+1 {{op failed to verify that multiplier must have element type i32 when scale32 is true, otherwise i16}} + %5 = spirv.Tosa.Rescale scale32 = true, rounding_mode = , per_channel = false, input_unsigned = false, output_unsigned = false, %arg0, %1, %2, %3, %4 : !spirv.arm.tensor<2x3x4xi16>, !spirv.arm.tensor<1xi16>, !spirv.arm.tensor<1xi8>, !spirv.arm.tensor<1xi16>, !spirv.arm.tensor<1xi16> -> !spirv.arm.tensor<2x3x4xi16> + spirv.ARM.GraphOutputs %5 : !spirv.arm.tensor<2x3x4xi16> +} + +spirv.ARM.Graph @rescale_scale32_false_requires_i16_multiplier(%arg0: !spirv.arm.tensor<2x3x4xi16>) -> (!spirv.arm.tensor<2x3x4xi16>) { + %1 = spirv.Constant dense<1> : !spirv.arm.tensor<1xi32> + %2 = spirv.Constant dense<0> : !spirv.arm.tensor<1xi8> + %3 = spirv.Constant dense<0> : !spirv.arm.tensor<1xi16> + %4 = spirv.Constant dense<0> : !spirv.arm.tensor<1xi16> + // expected-error @+1 {{op failed to verify that multiplier must have element type i32 when scale32 is true, otherwise i16}} + %5 = spirv.Tosa.Rescale scale32 = false, rounding_mode = , per_channel = false, input_unsigned = false, output_unsigned = false, %arg0, %1, %2, %3, %4 : !spirv.arm.tensor<2x3x4xi16>, 
!spirv.arm.tensor<1xi32>, !spirv.arm.tensor<1xi8>, !spirv.arm.tensor<1xi16>, !spirv.arm.tensor<1xi16> -> !spirv.arm.tensor<2x3x4xi16> + spirv.ARM.GraphOutputs %5 : !spirv.arm.tensor<2x3x4xi16> +} + +spirv.ARM.Graph @rescale_per_channel_true_requires_multiplier_length_rank_minus_one(%arg0: !spirv.arm.tensor<2x3x4xi16>) -> (!spirv.arm.tensor<2x3x4xi16>) { + %1 = spirv.Constant dense<[1]> : !spirv.arm.tensor<1xi16> + %2 = spirv.Constant dense<[0, 0]> : !spirv.arm.tensor<2xi8> + %3 = spirv.Constant dense<0> : !spirv.arm.tensor<1xi16> + %4 = spirv.Constant dense<0> : !spirv.arm.tensor<1xi16> + // expected-error @+1 {{op failed to verify that multiplier must have length rank(input) - 1 when per_channel is true, otherwise length 1}} + %5 = spirv.Tosa.Rescale scale32 = false, rounding_mode = , per_channel = true, input_unsigned = false, output_unsigned = false, %arg0, %1, %2, %3, %4 : !spirv.arm.tensor<2x3x4xi16>, !spirv.arm.tensor<1xi16>, !spirv.arm.tensor<2xi8>, !spirv.arm.tensor<1xi16>, !spirv.arm.tensor<1xi16> -> !spirv.arm.tensor<2x3x4xi16> + spirv.ARM.GraphOutputs %5 : !spirv.arm.tensor<2x3x4xi16> +} + +spirv.ARM.Graph @rescale_per_channel_true_requires_shift_length_rank_minus_one(%arg0: !spirv.arm.tensor<2x3x4xi16>) -> (!spirv.arm.tensor<2x3x4xi16>) { + %1 = spirv.Constant dense<[1, 1]> : !spirv.arm.tensor<2xi16> + %2 = spirv.Constant dense<[0]> : !spirv.arm.tensor<1xi8> + %3 = spirv.Constant dense<0> : !spirv.arm.tensor<1xi16> + %4 = spirv.Constant dense<0> : !spirv.arm.tensor<1xi16> + // expected-error @+1 {{op failed to verify that shift must have length rank(input) - 1 when per_channel is true, otherwise length 1}} + %5 = spirv.Tosa.Rescale scale32 = false, rounding_mode = , per_channel = true, input_unsigned = false, output_unsigned = false, %arg0, %1, %2, %3, %4 : !spirv.arm.tensor<2x3x4xi16>, !spirv.arm.tensor<2xi16>, !spirv.arm.tensor<1xi8>, !spirv.arm.tensor<1xi16>, !spirv.arm.tensor<1xi16> -> !spirv.arm.tensor<2x3x4xi16> + spirv.ARM.GraphOutputs %5 : 
!spirv.arm.tensor<2x3x4xi16> +} + +spirv.ARM.Graph @rescale_per_channel_false_requires_multiplier_length_one(%arg0: !spirv.arm.tensor<2x3x4xi16>) -> (!spirv.arm.tensor<2x3x4xi16>) { + %1 = spirv.Constant dense<[1, 1]> : !spirv.arm.tensor<2xi16> + %2 = spirv.Constant dense<[0]> : !spirv.arm.tensor<1xi8> + %3 = spirv.Constant dense<0> : !spirv.arm.tensor<1xi16> + %4 = spirv.Constant dense<0> : !spirv.arm.tensor<1xi16> + // expected-error @+1 {{op failed to verify that multiplier must have length rank(input) - 1 when per_channel is true, otherwise length 1}} + %5 = spirv.Tosa.Rescale scale32 = false, rounding_mode = , per_channel = false, input_unsigned = false, output_unsigned = false, %arg0, %1, %2, %3, %4 : !spirv.arm.tensor<2x3x4xi16>, !spirv.arm.tensor<2xi16>, !spirv.arm.tensor<1xi8>, !spirv.arm.tensor<1xi16>, !spirv.arm.tensor<1xi16> -> !spirv.arm.tensor<2x3x4xi16> + spirv.ARM.GraphOutputs %5 : !spirv.arm.tensor<2x3x4xi16> +} + +spirv.ARM.Graph @rescale_per_channel_false_requires_shift_length_one(%arg0: !spirv.arm.tensor<2x3x4xi16>) -> (!spirv.arm.tensor<2x3x4xi16>) { + %1 = spirv.Constant dense<[1]> : !spirv.arm.tensor<1xi16> + %2 = spirv.Constant dense<[0, 0]> : !spirv.arm.tensor<2xi8> + %3 = spirv.Constant dense<0> : !spirv.arm.tensor<1xi16> + %4 = spirv.Constant dense<0> : !spirv.arm.tensor<1xi16> + // expected-error @+1 {{op failed to verify that shift must have length rank(input) - 1 when per_channel is true, otherwise length 1}} + %5 = spirv.Tosa.Rescale scale32 = false, rounding_mode = , per_channel = false, input_unsigned = false, output_unsigned = false, %arg0, %1, %2, %3, %4 : !spirv.arm.tensor<2x3x4xi16>, !spirv.arm.tensor<1xi16>, !spirv.arm.tensor<2xi8>, !spirv.arm.tensor<1xi16>, !spirv.arm.tensor<1xi16> -> !spirv.arm.tensor<2x3x4xi16> + spirv.ARM.GraphOutputs %5 : !spirv.arm.tensor<2x3x4xi16> +} + +spirv.ARM.Graph @rescale_i8_input_requires_i8_i16_or_i32_output(%arg0: !spirv.arm.tensor<2x3x4xi8>) -> (!spirv.arm.tensor<2x3x4xi64>) { + %1 = 
spirv.Constant dense<1> : !spirv.arm.tensor<1xi16> + %2 = spirv.Constant dense<0> : !spirv.arm.tensor<1xi8> + %3 = spirv.Constant dense<0> : !spirv.arm.tensor<1xi8> + %4 = spirv.Constant dense<0> : !spirv.arm.tensor<1xi64> + // expected-error @+1 {{op failed to verify that if input has type 8-bit signless integer then output must have a type in [8-bit signless integer,16-bit signless integer,32-bit signless integer]}} + %5 = spirv.Tosa.Rescale scale32 = false, rounding_mode = , per_channel = false, input_unsigned = false, output_unsigned = false, %arg0, %1, %2, %3, %4 : !spirv.arm.tensor<2x3x4xi8>, !spirv.arm.tensor<1xi16>, !spirv.arm.tensor<1xi8>, !spirv.arm.tensor<1xi8>, !spirv.arm.tensor<1xi64> -> !spirv.arm.tensor<2x3x4xi64> + spirv.ARM.GraphOutputs %5 : !spirv.arm.tensor<2x3x4xi64> +} + +spirv.ARM.Graph @rescale_i16_input_requires_i8_i16_or_i32_output(%arg0: !spirv.arm.tensor<2x3x4xi16>) -> (!spirv.arm.tensor<2x3x4xi64>) { + %1 = spirv.Constant dense<1> : !spirv.arm.tensor<1xi16> + %2 = spirv.Constant dense<0> : !spirv.arm.tensor<1xi8> + %3 = spirv.Constant dense<0> : !spirv.arm.tensor<1xi16> + %4 = spirv.Constant dense<0> : !spirv.arm.tensor<1xi64> + // expected-error @+1 {{op failed to verify that if input has type 16-bit signless integer then output must have a type in [8-bit signless integer,16-bit signless integer,32-bit signless integer]}} + %5 = spirv.Tosa.Rescale scale32 = false, rounding_mode = , per_channel = false, input_unsigned = false, output_unsigned = false, %arg0, %1, %2, %3, %4 : !spirv.arm.tensor<2x3x4xi16>, !spirv.arm.tensor<1xi16>, !spirv.arm.tensor<1xi8>, !spirv.arm.tensor<1xi16>, !spirv.arm.tensor<1xi64> -> !spirv.arm.tensor<2x3x4xi64> + spirv.ARM.GraphOutputs %5 : !spirv.arm.tensor<2x3x4xi64> +} + +spirv.ARM.Graph @rescale_i32_input_requires_i8_i16_or_i32_output(%arg0: !spirv.arm.tensor<2x3x4xi32>) -> (!spirv.arm.tensor<2x3x4xi64>) { + %1 = spirv.Constant dense<1> : !spirv.arm.tensor<1xi16> + %2 = spirv.Constant dense<0> : 
!spirv.arm.tensor<1xi8> + %3 = spirv.Constant dense<0> : !spirv.arm.tensor<1xi32> + %4 = spirv.Constant dense<0> : !spirv.arm.tensor<1xi64> + // expected-error @+1 {{op failed to verify that if input has type 32-bit signless integer then output must have a type in [8-bit signless integer,16-bit signless integer,32-bit signless integer]}} + %5 = spirv.Tosa.Rescale scale32 = false, rounding_mode = , per_channel = false, input_unsigned = false, output_unsigned = false, %arg0, %1, %2, %3, %4 : !spirv.arm.tensor<2x3x4xi32>, !spirv.arm.tensor<1xi16>, !spirv.arm.tensor<1xi8>, !spirv.arm.tensor<1xi32>, !spirv.arm.tensor<1xi64> -> !spirv.arm.tensor<2x3x4xi64> + spirv.ARM.GraphOutputs %5 : !spirv.arm.tensor<2x3x4xi64> +} + +spirv.ARM.Graph @rescale_i64_input_requires_i8_i16_or_i32_output(%arg0: !spirv.arm.tensor<2x3x4xi64>) -> (!spirv.arm.tensor<2x3x4xi64>) { + %1 = spirv.Constant dense<1> : !spirv.arm.tensor<1xi16> + %2 = spirv.Constant dense<0> : !spirv.arm.tensor<1xi8> + %3 = spirv.Constant dense<0> : !spirv.arm.tensor<1xi64> + %4 = spirv.Constant dense<0> : !spirv.arm.tensor<1xi64> + // expected-error @+1 {{op failed to verify that if input has type 64-bit signless integer then output must have a type in [8-bit signless integer,16-bit signless integer,32-bit signless integer]}} + %5 = spirv.Tosa.Rescale scale32 = false, rounding_mode = , per_channel = false, input_unsigned = false, output_unsigned = false, %arg0, %1, %2, %3, %4 : !spirv.arm.tensor<2x3x4xi64>, !spirv.arm.tensor<1xi16>, !spirv.arm.tensor<1xi8>, !spirv.arm.tensor<1xi64>, !spirv.arm.tensor<1xi64> -> !spirv.arm.tensor<2x3x4xi64> + spirv.ARM.GraphOutputs %5 : !spirv.arm.tensor<2x3x4xi64> +} + +spirv.ARM.Graph @rescale_input_unsigned_true_requires_i8_or_i16_input(%arg0: !spirv.arm.tensor<2x3x4xi32>) -> (!spirv.arm.tensor<2x3x4xi16>) { + %1 = spirv.Constant dense<1> : !spirv.arm.tensor<1xi16> + %2 = spirv.Constant dense<0> : !spirv.arm.tensor<1xi8> + %3 = spirv.Constant dense<0> : !spirv.arm.tensor<1xi32> + %4 = 
spirv.Constant dense<0> : !spirv.arm.tensor<1xi16> + // expected-error @+1 {{op failed to verify that if input_unsigned is true then input must have a type in [8-bit signless integer,16-bit signless integer]}} + %5 = spirv.Tosa.Rescale scale32 = false, rounding_mode = , per_channel = false, input_unsigned = true, output_unsigned = false, %arg0, %1, %2, %3, %4 : !spirv.arm.tensor<2x3x4xi32>, !spirv.arm.tensor<1xi16>, !spirv.arm.tensor<1xi8>, !spirv.arm.tensor<1xi32>, !spirv.arm.tensor<1xi16> -> !spirv.arm.tensor<2x3x4xi16> + spirv.ARM.GraphOutputs %5 : !spirv.arm.tensor<2x3x4xi16> +} + +spirv.ARM.Graph @rescale_output_unsigned_true_requires_i8_or_i16_input(%arg0: !spirv.arm.tensor<2x3x4xi32>) -> (!spirv.arm.tensor<2x3x4xi16>) { + %1 = spirv.Constant dense<1> : !spirv.arm.tensor<1xi16> + %2 = spirv.Constant dense<0> : !spirv.arm.tensor<1xi8> + %3 = spirv.Constant dense<0> : !spirv.arm.tensor<1xi32> + %4 = spirv.Constant dense<0> : !spirv.arm.tensor<1xi16> + // expected-error @+1 {{op failed to verify that if output_unsigned is true then input must have a type in [8-bit signless integer,16-bit signless integer]}} + %5 = spirv.Tosa.Rescale scale32 = false, rounding_mode = , per_channel = false, input_unsigned = false, output_unsigned = true, %arg0, %1, %2, %3, %4 : !spirv.arm.tensor<2x3x4xi32>, !spirv.arm.tensor<1xi16>, !spirv.arm.tensor<1xi8>, !spirv.arm.tensor<1xi32>, !spirv.arm.tensor<1xi16> -> !spirv.arm.tensor<2x3x4xi16> + spirv.ARM.GraphOutputs %5 : !spirv.arm.tensor<2x3x4xi16> +} + +spirv.ARM.Graph @rescale_input_unsigned_true_requires_i8_or_i16_output(%arg0: !spirv.arm.tensor<2x3x4xi16>) -> (!spirv.arm.tensor<2x3x4xi32>) { + %1 = spirv.Constant dense<1> : !spirv.arm.tensor<1xi16> + %2 = spirv.Constant dense<0> : !spirv.arm.tensor<1xi8> + %3 = spirv.Constant dense<0> : !spirv.arm.tensor<1xi16> + %4 = spirv.Constant dense<0> : !spirv.arm.tensor<1xi32> + // expected-error @+1 {{op failed to verify that if input_unsigned is true then output must have a type in 
[8-bit signless integer,16-bit signless integer]}} + %5 = spirv.Tosa.Rescale scale32 = false, rounding_mode = , per_channel = false, input_unsigned = true, output_unsigned = false, %arg0, %1, %2, %3, %4 : !spirv.arm.tensor<2x3x4xi16>, !spirv.arm.tensor<1xi16>, !spirv.arm.tensor<1xi8>, !spirv.arm.tensor<1xi16>, !spirv.arm.tensor<1xi32> -> !spirv.arm.tensor<2x3x4xi32> + spirv.ARM.GraphOutputs %5 : !spirv.arm.tensor<2x3x4xi32> +} + +spirv.ARM.Graph @rescale_output_unsigned_true_requires_i8_or_i16_output(%arg0: !spirv.arm.tensor<2x3x4xi16>) -> (!spirv.arm.tensor<2x3x4xi32>) { + %1 = spirv.Constant dense<1> : !spirv.arm.tensor<1xi16> + %2 = spirv.Constant dense<0> : !spirv.arm.tensor<1xi8> + %3 = spirv.Constant dense<0> : !spirv.arm.tensor<1xi16> + %4 = spirv.Constant dense<0> : !spirv.arm.tensor<1xi32> + // expected-error @+1 {{op failed to verify that if output_unsigned is true then output must have a type in [8-bit signless integer,16-bit signless integer]}} + %5 = spirv.Tosa.Rescale scale32 = false, rounding_mode = , per_channel = false, input_unsigned = false, output_unsigned = true, %arg0, %1, %2, %3, %4 : !spirv.arm.tensor<2x3x4xi16>, !spirv.arm.tensor<1xi16>, !spirv.arm.tensor<1xi8>, !spirv.arm.tensor<1xi16>, !spirv.arm.tensor<1xi32> -> !spirv.arm.tensor<2x3x4xi32> + spirv.ARM.GraphOutputs %5 : !spirv.arm.tensor<2x3x4xi32> +} diff --git a/mlir/test/Dialect/SPIRV/IR/tosa-ops.mlir b/mlir/test/Dialect/SPIRV/IR/tosa-ops.mlir index b10724b16a84..e1fcf2e3ec05 100644 --- a/mlir/test/Dialect/SPIRV/IR/tosa-ops.mlir +++ b/mlir/test/Dialect/SPIRV/IR/tosa-ops.mlir @@ -1123,3 +1123,40 @@ spirv.ARM.Graph @resize_fp(%arg0: !spirv.arm.tensor<1x48x33x63xf32>) -> (!spirv. 
// CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<1x753x297x63xf32> spirv.ARM.GraphOutputs %4 : !spirv.arm.tensor<1x753x297x63xf32> } + +//===----------------------------------------------------------------------===// +// spirv.TOSA.Cast - PRO-INT +//===----------------------------------------------------------------------===// + +spirv.ARM.Graph @cast_int(%arg0: !spirv.arm.tensor<1x65538x1x2xi8>) -> (!spirv.arm.tensor<1x65538x1x2xi32>) { + // CHECK: {{%.*}} = spirv.Tosa.Cast %arg0 : !spirv.arm.tensor<1x65538x1x2xi8> -> !spirv.arm.tensor<1x65538x1x2xi32> + %0 = spirv.Tosa.Cast %arg0 : !spirv.arm.tensor<1x65538x1x2xi8> -> !spirv.arm.tensor<1x65538x1x2xi32> + // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<1x65538x1x2xi32> + spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<1x65538x1x2xi32> +} + +//===----------------------------------------------------------------------===// +// spirv.TOSA.Cast - PRO-FP +//===----------------------------------------------------------------------===// + +spirv.ARM.Graph @cast_fp(%arg0: !spirv.arm.tensor<11x5x14x4xf32>) -> (!spirv.arm.tensor<11x5x14x4xf16>) { + // CHECK: {{%.*}} = spirv.Tosa.Cast %arg0 : !spirv.arm.tensor<11x5x14x4xf32> -> !spirv.arm.tensor<11x5x14x4xf16> + %0 = spirv.Tosa.Cast %arg0 : !spirv.arm.tensor<11x5x14x4xf32> -> !spirv.arm.tensor<11x5x14x4xf16> + // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<11x5x14x4xf16> + spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<11x5x14x4xf16> +} + +//===----------------------------------------------------------------------===// +// spirv.TOSA.Rescale - PRO-INT +//===----------------------------------------------------------------------===// + +spirv.ARM.Graph @rescale_int(%arg0: !spirv.arm.tensor<17x29x19xi16>) -> (!spirv.arm.tensor<17x29x19xi16>) { + %5 = spirv.Constant dense<1866149760> : !spirv.arm.tensor<1xi32> + %6 = spirv.Constant dense<31> : !spirv.arm.tensor<1xi8> + %7 = spirv.Constant dense<0> : !spirv.arm.tensor<1xi16> + %8 = 
spirv.Constant dense<0> : !spirv.arm.tensor<1xi16> + // CHECK: {{%.*}} = spirv.Tosa.Rescale scale32 = true, rounding_mode = , per_channel = false, input_unsigned = false, output_unsigned = true, %arg0, {{%.*}}, {{%.*}}, {{%.*}}, {{%.*}} : !spirv.arm.tensor<17x29x19xi16>, !spirv.arm.tensor<1xi32>, !spirv.arm.tensor<1xi8>, !spirv.arm.tensor<1xi16>, !spirv.arm.tensor<1xi16> -> !spirv.arm.tensor<17x29x19xi16> + %9 = spirv.Tosa.Rescale scale32 = true, rounding_mode = , per_channel = false, input_unsigned = false, output_unsigned = true, %arg0, %5, %6, %7, %8 : !spirv.arm.tensor<17x29x19xi16>, !spirv.arm.tensor<1xi32>, !spirv.arm.tensor<1xi8>, !spirv.arm.tensor<1xi16>, !spirv.arm.tensor<1xi16> -> !spirv.arm.tensor<17x29x19xi16> + // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<17x29x19xi16> + spirv.ARM.GraphOutputs %9 : !spirv.arm.tensor<17x29x19xi16> +} diff --git a/mlir/test/Target/SPIRV/tosa-ops.mlir b/mlir/test/Target/SPIRV/tosa-ops.mlir index ebc6a290a9dc..47e819636e98 100644 --- a/mlir/test/Target/SPIRV/tosa-ops.mlir +++ b/mlir/test/Target/SPIRV/tosa-ops.mlir @@ -1964,3 +1964,64 @@ spirv.module Logical Vulkan requires #spirv.vce } } + +// ----- + +//===----------------------------------------------------------------------===// +// spirv.TOSA.Cast - PRO-INT +//===----------------------------------------------------------------------===// + +// CHECK: spirv.module Logical Vulkan requires #spirv.vce +spirv.module Logical Vulkan requires #spirv.vce { + spirv.GlobalVariable @cast_int_arg_0 bind(0, 0) : !spirv.ptr, UniformConstant> + spirv.GlobalVariable @cast_int_res_0 bind(1, 0) : !spirv.ptr, UniformConstant> + spirv.ARM.GraphEntryPoint @cast_int, @cast_int_arg_0, @cast_int_res_0 + spirv.ARM.Graph @cast_int(%arg0: !spirv.arm.tensor<1x65538x1x2xi8>) -> (!spirv.arm.tensor<1x65538x1x2xi32>) { + // CHECK: {{%.*}} = spirv.Tosa.Cast %arg0 : !spirv.arm.tensor<1x65538x1x2xi8> -> !spirv.arm.tensor<1x65538x1x2xi32> + %0 = spirv.Tosa.Cast %arg0 : 
!spirv.arm.tensor<1x65538x1x2xi8> -> !spirv.arm.tensor<1x65538x1x2xi32> + // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<1x65538x1x2xi32> + spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<1x65538x1x2xi32> + } +} + +// ----- + +//===----------------------------------------------------------------------===// +// spirv.TOSA.Cast - PRO-FP +//===----------------------------------------------------------------------===// + +// CHECK: spirv.module Logical Vulkan requires #spirv.vce +spirv.module Logical Vulkan requires #spirv.vce { + spirv.GlobalVariable @cast_fp_arg_0 bind(0, 0) : !spirv.ptr, UniformConstant> + spirv.GlobalVariable @cast_fp_res_0 bind(1, 0) : !spirv.ptr, UniformConstant> + spirv.ARM.GraphEntryPoint @cast_fp, @cast_fp_arg_0, @cast_fp_res_0 + spirv.ARM.Graph @cast_fp(%arg0: !spirv.arm.tensor<11x5x14x4xf32>) -> (!spirv.arm.tensor<11x5x14x4xf16>) { + // CHECK: {{%.*}} = spirv.Tosa.Cast %arg0 : !spirv.arm.tensor<11x5x14x4xf32> -> !spirv.arm.tensor<11x5x14x4xf16> + %0 = spirv.Tosa.Cast %arg0 : !spirv.arm.tensor<11x5x14x4xf32> -> !spirv.arm.tensor<11x5x14x4xf16> + // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<11x5x14x4xf16> + spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<11x5x14x4xf16> + } +} + +// ----- + +//===----------------------------------------------------------------------===// +// spirv.TOSA.Rescale - PRO-INT +//===----------------------------------------------------------------------===// + +// CHECK: spirv.module Logical Vulkan requires #spirv.vce +spirv.module Logical Vulkan requires #spirv.vce { + spirv.GlobalVariable @rescale_int_arg_0 bind(0, 0) : !spirv.ptr, UniformConstant> + spirv.GlobalVariable @rescale_int_res_0 bind(1, 0) : !spirv.ptr, UniformConstant> + spirv.ARM.GraphEntryPoint @rescale_int, @rescale_int_arg_0, @rescale_int_res_0 + spirv.ARM.Graph @rescale_int(%arg0: !spirv.arm.tensor<17x29x19xi16>) -> (!spirv.arm.tensor<17x29x19xi16>) { + %5 = spirv.Constant dense<1866149760> : !spirv.arm.tensor<1xi32> + %6 = 
spirv.Constant dense<31> : !spirv.arm.tensor<1xi8> + %7 = spirv.Constant dense<0> : !spirv.arm.tensor<1xi16> + %8 = spirv.Constant dense<0> : !spirv.arm.tensor<1xi16> + // CHECK: {{%.*}} = spirv.Tosa.Rescale scale32 = true, rounding_mode = , per_channel = false, input_unsigned = false, output_unsigned = true, %arg0, {{%.*}}, {{%.*}}, {{%.*}}, {{%.*}} : !spirv.arm.tensor<17x29x19xi16>, !spirv.arm.tensor<1xi32>, !spirv.arm.tensor<1xi8>, !spirv.arm.tensor<1xi16>, !spirv.arm.tensor<1xi16> -> !spirv.arm.tensor<17x29x19xi16> + %9 = spirv.Tosa.Rescale scale32 = true, rounding_mode = , per_channel = false, input_unsigned = false, output_unsigned = true, %arg0, %5, %6, %7, %8 : !spirv.arm.tensor<17x29x19xi16>, !spirv.arm.tensor<1xi32>, !spirv.arm.tensor<1xi8>, !spirv.arm.tensor<1xi16>, !spirv.arm.tensor<1xi16> -> !spirv.arm.tensor<17x29x19xi16> + // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<17x29x19xi16> + spirv.ARM.GraphOutputs %9 : !spirv.arm.tensor<17x29x19xi16> + } +} diff --git a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp index 2344465f7521..8caeb079aceb 100644 --- a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp +++ b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp @@ -504,6 +504,7 @@ constexpr llvm::StringLiteral constantIdEnumAttrs[] = { "SPIRV_TosaExtAccTypeAttr", "SPIRV_TosaExtResizeModeAttr", "SPIRV_TosaExtNaNPropagationModeAttr", + "SPIRV_TosaExtRoundingModeAttr", "SPIRV_QuadSwapDirectionAttr", };