[mlir][TensorToSPIRV] Add type check for tensor.extract in TensorToSPIRV (#107110)

This patch add a type check for `tensor.extract` in TensorToSPIRV.
Only convert `tensor.extract` with supported element type. Fix #74466.
This commit is contained in:
Longsheng Mou 2024-09-04 10:21:27 +08:00 committed by GitHub
parent 812c96e8b9
commit f4b9839d6f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 20 additions and 0 deletions

View File

@ -45,6 +45,8 @@ public:
ConversionPatternRewriter &rewriter) const override {
auto tensorType = cast<RankedTensorType>(extractOp.getTensor().getType());
if (!isa<spirv::ScalarType>(tensorType.getElementType()))
return rewriter.notifyMatchFailure(extractOp, "unsupported type");
if (!tensorType.hasStaticShape())
return rewriter.notifyMatchFailure(extractOp, "non-static tensor");

View File

@ -29,6 +29,24 @@ func.func @tensor_extract_constant(%a : index, %b: index, %c: index) -> i32 {
// -----
// CHECK-LABEL: test_spirv_unsupported_type_index
func.func @test_spirv_unsupported_type_index(%a : index) {
%cst = arith.constant dense<[1, 2]> : tensor<2xindex>
// CHECK: tensor.extract
%extract = tensor.extract %cst[%a] : tensor<2xindex>
return
}
// CHECK-LABEL: test_spirv_unsupported_type_i128
func.func @test_spirv_unsupported_type_i128(%a : index) {
%cst = arith.constant dense<[1, 2]> : tensor<2xi128>
// CHECK: tensor.extract
%extract = tensor.extract %cst[%a] : tensor<2xi128>
return
}
// -----
//===----------------------------------------------------------------------===//
// Type conversion
//===----------------------------------------------------------------------===//