[mlir][TensorToSPIRV] Add type check for tensor.extract
in TensorToSPIRV (#107110)
This patch add a type check for `tensor.extract` in TensorToSPIRV. Only convert `tensor.extract` with supported element type. Fix #74466.
This commit is contained in:
parent
812c96e8b9
commit
f4b9839d6f
@ -45,6 +45,8 @@ public:
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto tensorType = cast<RankedTensorType>(extractOp.getTensor().getType());
|
||||
|
||||
if (!isa<spirv::ScalarType>(tensorType.getElementType()))
|
||||
return rewriter.notifyMatchFailure(extractOp, "unsupported type");
|
||||
if (!tensorType.hasStaticShape())
|
||||
return rewriter.notifyMatchFailure(extractOp, "non-static tensor");
|
||||
|
||||
|
@ -29,6 +29,24 @@ func.func @tensor_extract_constant(%a : index, %b: index, %c: index) -> i32 {
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: test_spirv_unsupported_type_index
|
||||
func.func @test_spirv_unsupported_type_index(%a : index) {
|
||||
%cst = arith.constant dense<[1, 2]> : tensor<2xindex>
|
||||
// CHECK: tensor.extract
|
||||
%extract = tensor.extract %cst[%a] : tensor<2xindex>
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: test_spirv_unsupported_type_i128
|
||||
func.func @test_spirv_unsupported_type_i128(%a : index) {
|
||||
%cst = arith.constant dense<[1, 2]> : tensor<2xi128>
|
||||
// CHECK: tensor.extract
|
||||
%extract = tensor.extract %cst[%a] : tensor<2xi128>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Type conversion
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
Loading…
x
Reference in New Issue
Block a user