diff --git a/mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRV.cpp b/mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRV.cpp index 0fb58623bdaf..468fffdd2df9 100644 --- a/mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRV.cpp +++ b/mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRV.cpp @@ -45,6 +45,8 @@ public: ConversionPatternRewriter &rewriter) const override { auto tensorType = cast(extractOp.getTensor().getType()); + if (!isa(tensorType.getElementType())) + return rewriter.notifyMatchFailure(extractOp, "unsupported type"); if (!tensorType.hasStaticShape()) return rewriter.notifyMatchFailure(extractOp, "non-static tensor"); diff --git a/mlir/test/Conversion/TensorToSPIRV/tensor-ops-to-spirv.mlir b/mlir/test/Conversion/TensorToSPIRV/tensor-ops-to-spirv.mlir index 32d0fbea65b1..b69c2d0408d1 100644 --- a/mlir/test/Conversion/TensorToSPIRV/tensor-ops-to-spirv.mlir +++ b/mlir/test/Conversion/TensorToSPIRV/tensor-ops-to-spirv.mlir @@ -29,6 +29,24 @@ func.func @tensor_extract_constant(%a : index, %b: index, %c: index) -> i32 { // ----- +// CHECK-LABEL: test_spirv_unsupported_type_index +func.func @test_spirv_unsupported_type_index(%a : index) { + %cst = arith.constant dense<[1, 2]> : tensor<2xindex> + // CHECK: tensor.extract + %extract = tensor.extract %cst[%a] : tensor<2xindex> + return +} + +// CHECK-LABEL: test_spirv_unsupported_type_i128 +func.func @test_spirv_unsupported_type_i128(%a : index) { + %cst = arith.constant dense<[1, 2]> : tensor<2xi128> + // CHECK: tensor.extract + %extract = tensor.extract %cst[%a] : tensor<2xi128> + return +} + +// ----- + //===----------------------------------------------------------------------===// // Type conversion //===----------------------------------------------------------------------===//