llvm-project/mlir/lib/Dialect/Tensor/Transforms/RuntimeOpVerification.cpp
Hanumanth 81964597f9
[mlir][tensor] Fix runtime verification for tensor.extract_slice for empty tensor slices (#166569)
I hit another runtime verification issue (similar to
https://github.com/llvm/llvm-project/pull/164878) while working with
TFLite models. The verifier is incorrectly rejecting
`tensor.extract_slice` operations when extracting an empty slice
(size=0) that starts exactly at the tensor boundary.

The current runtime verification unconditionally enforces `offset <
dim_size`. This makes sense for non-empty slices, but it's too strict
for empty slices, causing false positives that lead to spurious runtime
assertions.

**Simple example that demonstrates the issue:**

```mlir
func.func @extract_empty_slice(%tensor: tensor<?xf32>, %offset: index, %size: index) {
  // When called with: tensor size=10, offset=10, size=0
  // Runtime verification fails: "offset 0 is out-of-bounds"
  %slice = tensor.extract_slice %tensor[%offset] [%size] [1] 
    : tensor<?xf32> to tensor<?xf32>
  return
}
```

For the above example, the check evaluates `10 < 10` which is false, so
verification fails. However, I believe this operation should be valid -
we're extracting zero elements, so there's no actual out-of-bounds
access.
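To make the false positive concrete, here is a minimal plain-C++ model
of the original check. It is only a sketch: `oldOffsetCheck` is a
hypothetical name, and the function mirrors the predicate described
above rather than calling any MLIR API.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical model of the original, unconditional offset check.
// It ignores the slice size entirely: 0 <= offset < dimSize.
bool oldOffsetCheck(int64_t offset, int64_t dimSize) {
  assert(dimSize >= 0 && "dimension sizes are assumed non-negative");
  return 0 <= offset && offset < dimSize;
}
```

With `offset = 10` and `dimSize = 10` this returns `false`, even though
a slice of size 0 at that offset reads no elements.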

**Real-world repro from the TensorFlow Lite models:**

This issue shows up when lowering TFLite models, and many of our system
tests fail because of it. The simplified version below shows the
problematic pattern.

In this code, `%extracted_slice_0` becomes an empty tensor when the SSA
value `%15` reaches 10 (on the final loop iteration), which makes
`%16 = 0`. The operation extracts zero elements along dimension 0, which
is semantically valid but fails runtime verification.

```mlir
func.func @simplified_repro_from_tensorflowlite_model(%arg0: tensor<10x4x1xf32>) -> tensor<10x4x1xf32> {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c2 = arith.constant 2 : index
  %c10 = arith.constant 10 : index
  %c-1 = arith.constant -1 : index
  
  %0 = "tosa.const"() <{values = dense<0> : tensor<i32>}> : () -> tensor<i32>
  %1 = "tosa.const"() <{values = dense<1> : tensor<i32>}> : () -> tensor<i32>
  %2 = "tosa.const"() <{values = dense<10> : tensor<i32>}> : () -> tensor<i32>
  %3 = "tosa.const"() <{values = dense<-1> : tensor<2xi32>}> : () -> tensor<2xi32>
  %4 = "tosa.const"() <{values = dense<0> : tensor<2xi32>}> : () -> tensor<2xi32>
  %5 = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1x4x1xf32>}> : () -> tensor<1x4x1xf32>
  %c4_1 = tosa.const_shape  {values = dense<1> : tensor<1xindex>} : () -> !tosa.shape<1>
  
  %6:2 = scf.while (%arg1 = %0, %arg2 = %arg0) 
    : (tensor<i32>, tensor<10x4x1xf32>) -> (tensor<i32>, tensor<10x4x1xf32>) {
    %7 = tosa.greater %2, %arg1 : (tensor<i32>, tensor<i32>) -> tensor<i1>
    %extracted = tensor.extract %7[] : tensor<i1>
    scf.condition(%extracted) %arg1, %arg2 : tensor<i32>, tensor<10x4x1xf32>
  } do {
  ^bb0(%arg1: tensor<i32>, %arg2: tensor<10x4x1xf32>):
    %7 = tosa.add %arg1, %1 : (tensor<i32>, tensor<i32>) -> tensor<i32>
    
    // First slice
    %8 = tosa.reshape %arg1, %c4_1 : (tensor<i32>, !tosa.shape<1>) -> tensor<1xi32>
    %9 = tosa.concat %8, %3 {axis = 0 : i32} : (tensor<1xi32>, tensor<2xi32>) -> tensor<3xi32>
    
    %extracted_0 = tensor.extract %9[%c0] : tensor<3xi32>
    %10 = index.casts %extracted_0 : i32 to index
    %11 = arith.cmpi eq, %10, %c-1 : index
    %12 = arith.select %11, %c10, %10 : index
    
    %extracted_slice = tensor.extract_slice %arg2[0, 0, 0] [%12, 4, 1] [1, 1, 1] 
      : tensor<10x4x1xf32> to tensor<?x4x1xf32>
    
    // Second slice - this is where the failure occurs
    %13 = tosa.reshape %7, %c4_1 : (tensor<i32>, !tosa.shape<1>) -> tensor<1xi32>
    %14 = tosa.concat %13, %4 {axis = 0 : i32} : (tensor<1xi32>, tensor<2xi32>) -> tensor<3xi32>
    
    %extracted_1 = tensor.extract %14[%c0] : tensor<3xi32>
    %15 = index.castu %extracted_1 : i32 to index
    %16 = arith.subi %c10, %15 : index  // size = 10 - offset
    
    %extracted_2 = tensor.extract %14[%c1] : tensor<3xi32>
    %17 = index.castu %extracted_2 : i32 to index
    
    %extracted_3 = tensor.extract %14[%c2] : tensor<3xi32>
    %18 = index.castu %extracted_3 : i32 to index
    
    // On the last loop iteration: %15=10, %16=0
    // %extracted_slice_0 becomes an empty tensor
    // Runtime verification fails: "offset 0 is out-of-bounds"
    %extracted_slice_0 = tensor.extract_slice %arg2[%15, %17, %18] [%16, 4, 1] [1, 1, 1] 
      : tensor<10x4x1xf32> to tensor<?x4x1xf32>
    
    %19 = tosa.concat %extracted_slice, %5, %extracted_slice_0 {axis = 0 : i32} 
      : (tensor<?x4x1xf32>, tensor<1x4x1xf32>, tensor<?x4x1xf32>) -> tensor<10x4x1xf32>
    
    scf.yield %7, %19 : tensor<i32>, tensor<10x4x1xf32>
  }
  
  return %6#1 : tensor<10x4x1xf32>
}
```
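The loop arithmetic in the repro can be traced with a small plain-C++
model. This is a sketch under the assumption that the loop counter runs
from 0 to 9, as the `tosa.greater` condition implies; `SliceSizes` and
`simulateLoop` are hypothetical names, not part of the patch.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Sizes and offsets along dimension 0 for one loop iteration.
struct SliceSizes {
  int64_t firstSize;    // models %12, size of %extracted_slice
  int64_t secondOffset; // models %15, offset of %extracted_slice_0
  int64_t secondSize;   // models %16, size of %extracted_slice_0
};

// Hypothetical trace of the scf.while body for a tensor with
// `dimSize` rows along dimension 0.
std::vector<SliceSizes> simulateLoop(int64_t dimSize = 10) {
  std::vector<SliceSizes> trace;
  for (int64_t counter = 0; counter < dimSize; ++counter) {
    int64_t next = counter + 1;                            // %7
    int64_t firstSize = counter == -1 ? dimSize : counter; // %12
    int64_t secondOffset = next;                           // %15
    int64_t secondSize = dimSize - secondOffset;           // %16
    // The concat of first slice, one constant row, and second slice
    // must rebuild all rows of the tensor.
    assert(firstSize + 1 + secondSize == dimSize);
    trace.push_back({firstSize, secondOffset, secondSize});
  }
  return trace;
}
```

In this model the last entry has `secondOffset == 10` and
`secondSize == 0`, which is the boundary case the verifier rejects. The
first iteration also produces an empty slice (`firstSize == 0`), but at
offset 0, so the original check accepts it.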
**The fix:**

Make the offset check conditional on slice size:
- Empty slice (size == 0): allow `0 <= offset <= dim_size`
- Non-empty slice (size > 0): require `0 <= offset < dim_size`
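The rule above, together with the existing last-element check, can be
sketched as plain C++ predicates. This is a model of the conditions the
generated IR asserts, not the MLIR implementation itself; the function
names are hypothetical.

```cpp
#include <cassert>
#include <cstdint>

// Offset check, conditional on slice size.
bool offsetInBounds(int64_t offset, int64_t size, int64_t dimSize) {
  assert(dimSize >= 0 && "dimension sizes are assumed non-negative");
  if (size == 0)
    return 0 <= offset && offset <= dimSize; // boundary is allowed
  return 0 <= offset && offset < dimSize;    // must be a valid index
}

// Last-element check, applied only to non-empty slices:
// 0 <= offset + (size - 1) * stride < dimSize.
bool lastElementInBounds(int64_t offset, int64_t size, int64_t stride,
                         int64_t dimSize) {
  if (size > 0) {
    int64_t lastPos = offset + (size - 1) * stride;
    return 0 <= lastPos && lastPos < dimSize;
  }
  return true;
}

// Both conditions must hold for one dimension of the slice.
bool sliceDimIsValid(int64_t offset, int64_t size, int64_t stride,
                     int64_t dimSize) {
  return offsetInBounds(offset, size, dimSize) &&
         lastElementInBounds(offset, size, stride, dimSize);
}
```

For the failing case (`offset = 10`, `size = 0`, `dimSize = 10`) the
model accepts the slice, while `offset = 11` with `size = 0` and
`offset = 10` with `size = 1` are still rejected.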


**Question for reviewers:**
Should we also relax the static verifier to allow this edge case?
Currently, the static verifier rejects the following IR:

```mlir
%tensor = arith.constant dense<1.0> : tensor<10xf32>
%slice = tensor.extract_slice %tensor[10] [0] [1] : tensor<10xf32> to tensor<0xf32>
```
Since we're allowing it at runtime for dynamic shapes, it seems
inconsistent to reject it statically. However, I wanted to get feedback
before making that change - this PR focuses only on the runtime
verification fix for dynamic shapes.

P.S. We have a similar issue with `memref.subview`. I will send a
separate patch for the issue.

Co-authored-by: Hanumanth Hanumantharayappa <hhanuman@ah-hhanuman-l.dhcp.mathworks.com>
2025-11-12 08:37:15 +09:00

//===- RuntimeOpVerification.cpp - Op Verification ------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Tensor/Transforms/RuntimeOpVerification.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Interfaces/RuntimeVerifiableOpInterface.h"
using namespace mlir;
namespace mlir {
namespace tensor {
namespace {
/// Generate a runtime check for lb <= value < ub.
Value generateInBoundsCheck(OpBuilder &builder, Location loc, Value value,
                            Value lb, Value ub) {
  Value inBounds1 = builder.createOrFold<arith::CmpIOp>(
      loc, arith::CmpIPredicate::sge, value, lb);
  Value inBounds2 = builder.createOrFold<arith::CmpIOp>(
      loc, arith::CmpIPredicate::slt, value, ub);
  Value inBounds =
      builder.createOrFold<arith::AndIOp>(loc, inBounds1, inBounds2);
  return inBounds;
}

struct CastOpInterface
    : public RuntimeVerifiableOpInterface::ExternalModel<CastOpInterface,
                                                         CastOp> {
  void
  generateRuntimeVerification(Operation *op, OpBuilder &builder, Location loc,
                              function_ref<std::string(Operation *, StringRef)>
                                  generateErrorMessage) const {
    auto castOp = cast<CastOp>(op);
    auto srcType = cast<TensorType>(castOp.getSource().getType());

    // Nothing to check if the result is an unranked tensor.
    auto resultType = dyn_cast<RankedTensorType>(castOp.getType());
    if (!resultType)
      return;

    if (isa<UnrankedTensorType>(srcType)) {
      // Check rank.
      Value srcRank = RankOp::create(builder, loc, castOp.getSource());
      Value resultRank =
          arith::ConstantIndexOp::create(builder, loc, resultType.getRank());
      Value isSameRank = arith::CmpIOp::create(
          builder, loc, arith::CmpIPredicate::eq, srcRank, resultRank);
      cf::AssertOp::create(builder, loc, isSameRank,
                           generateErrorMessage(op, "rank mismatch"));
    }

    // Check dimension sizes.
    for (const auto &it : llvm::enumerate(resultType.getShape())) {
      // Static dim size -> static/dynamic dim size does not need verification.
      if (auto rankedSrcType = dyn_cast<RankedTensorType>(srcType))
        if (!rankedSrcType.isDynamicDim(it.index()))
          continue;

      // Static/dynamic dim size -> dynamic dim size does not need verification.
      if (resultType.isDynamicDim(it.index()))
        continue;

      Value srcDimSz =
          DimOp::create(builder, loc, castOp.getSource(), it.index());
      Value resultDimSz =
          arith::ConstantIndexOp::create(builder, loc, it.value());
      Value isSameSz = arith::CmpIOp::create(
          builder, loc, arith::CmpIPredicate::eq, srcDimSz, resultDimSz);
      cf::AssertOp::create(
          builder, loc, isSameSz,
          generateErrorMessage(op, "size mismatch of dim " +
                                       std::to_string(it.index())));
    }
  }
};

struct DimOpInterface
    : public RuntimeVerifiableOpInterface::ExternalModel<DimOpInterface,
                                                         DimOp> {
  void
  generateRuntimeVerification(Operation *op, OpBuilder &builder, Location loc,
                              function_ref<std::string(Operation *, StringRef)>
                                  generateErrorMessage) const {
    auto dimOp = cast<DimOp>(op);
    Value rank = RankOp::create(builder, loc, dimOp.getSource());
    Value zero = arith::ConstantIndexOp::create(builder, loc, 0);
    cf::AssertOp::create(
        builder, loc,
        generateInBoundsCheck(builder, loc, dimOp.getIndex(), zero, rank),
        generateErrorMessage(op, "index is out of bounds"));
  }
};

/// Verifies that the indices on extract/insert ops are in-bounds of the
/// tensor's index space: 0 <= index#i < dim#i
template <typename OpTy>
struct ExtractInsertOpInterface
    : public RuntimeVerifiableOpInterface::ExternalModel<
          ExtractInsertOpInterface<OpTy>, OpTy> {
  void
  generateRuntimeVerification(Operation *op, OpBuilder &builder, Location loc,
                              function_ref<std::string(Operation *, StringRef)>
                                  generateErrorMessage) const {
    auto extractInsertOp = cast<OpTy>(op);

    Value tensor;
    if constexpr (std::is_same_v<OpTy, ExtractOp>) {
      tensor = extractInsertOp.getTensor();
    } else if constexpr (std::is_same_v<OpTy, InsertOp>) {
      tensor = extractInsertOp.getDest();
    } else {
      llvm_unreachable("invalid op");
    }
    auto tensorType = cast<RankedTensorType>(tensor.getType());
    auto rank = tensorType.getRank();
    if (rank == 0) {
      // Nothing to check for 0-d tensors.
      return;
    }

    auto indices = extractInsertOp.getIndices();
    auto zero = arith::ConstantIndexOp::create(builder, loc, 0);
    Value assertCond;
    for (auto i : llvm::seq<int64_t>(0, rank)) {
      Value dimOp = builder.createOrFold<tensor::DimOp>(loc, tensor, i);
      Value inBounds =
          generateInBoundsCheck(builder, loc, indices[i], zero, dimOp);
      assertCond =
          i > 0 ? builder.createOrFold<arith::AndIOp>(loc, assertCond, inBounds)
                : inBounds;
    }
    cf::AssertOp::create(builder, loc, assertCond,
                         generateErrorMessage(op, "out-of-bounds access"));
  }
};

struct ExtractSliceOpInterface
    : public RuntimeVerifiableOpInterface::ExternalModel<
          ExtractSliceOpInterface, ExtractSliceOp> {
  void
  generateRuntimeVerification(Operation *op, OpBuilder &builder, Location loc,
                              function_ref<std::string(Operation *, StringRef)>
                                  generateErrorMessage) const {
    auto extractSliceOp = cast<ExtractSliceOp>(op);
    RankedTensorType sourceType = extractSliceOp.getSource().getType();

    // For each dimension, assert that:
    //   For empty slices (size == 0):    0 <= offset <= dim_size
    //   For non-empty slices (size > 0): 0 <= offset < dim_size
    //                                    0 <= offset + (size - 1) * stride
    //                                         < dim_size
    Value zero = arith::ConstantIndexOp::create(builder, loc, 0);
    Value one = arith::ConstantIndexOp::create(builder, loc, 1);

    for (int64_t i : llvm::seq<int64_t>(0, sourceType.getRank())) {
      builder.setInsertionPoint(extractSliceOp);

      Value offset = getValueOrCreateConstantIndexOp(
          builder, loc, extractSliceOp.getMixedOffsets()[i]);
      Value size = getValueOrCreateConstantIndexOp(
          builder, loc, extractSliceOp.getMixedSizes()[i]);
      Value stride = getValueOrCreateConstantIndexOp(
          builder, loc, extractSliceOp.getMixedStrides()[i]);
      Value dimSize = builder.createOrFold<tensor::DimOp>(
          loc, extractSliceOp.getSource(), i);

      // Verify that offset is in-bounds (conditional on slice size).
      Value sizeIsZero = arith::CmpIOp::create(
          builder, loc, arith::CmpIPredicate::eq, size, zero);
      auto offsetCheckIf = scf::IfOp::create(
          builder, loc, sizeIsZero,
          [&](OpBuilder &b, Location loc) {
            // For empty slices, offset can be at the boundary:
            // 0 <= offset <= dimSize.
            Value offsetGEZero = arith::CmpIOp::create(
                b, loc, arith::CmpIPredicate::sge, offset, zero);
            Value offsetLEDimSize = arith::CmpIOp::create(
                b, loc, arith::CmpIPredicate::sle, offset, dimSize);
            Value emptyOffsetValid =
                arith::AndIOp::create(b, loc, offsetGEZero, offsetLEDimSize);
            scf::YieldOp::create(b, loc, emptyOffsetValid);
          },
          [&](OpBuilder &b, Location loc) {
            // For non-empty slices, offset must be a valid index:
            // 0 <= offset < dimSize.
            Value offsetInBounds =
                generateInBoundsCheck(b, loc, offset, zero, dimSize);
            scf::YieldOp::create(b, loc, offsetInBounds);
          });

      Value offsetCondition = offsetCheckIf.getResult(0);
      cf::AssertOp::create(builder, loc, offsetCondition,
                           generateErrorMessage(op, "offset " +
                                                        std::to_string(i) +
                                                        " is out-of-bounds"));

      // Verify that the slice endpoint is in-bounds (only for non-empty
      // slices).
      Value sizeIsNonZero = arith::CmpIOp::create(
          builder, loc, arith::CmpIPredicate::sgt, size, zero);
      auto ifOp = scf::IfOp::create(
          builder, loc, sizeIsNonZero,
          [&](OpBuilder &b, Location loc) {
            // Verify that slice does not run out-of-bounds.
            Value sizeMinusOne = arith::SubIOp::create(b, loc, size, one);
            Value sizeMinusOneTimesStride =
                arith::MulIOp::create(b, loc, sizeMinusOne, stride);
            Value lastPos =
                arith::AddIOp::create(b, loc, offset, sizeMinusOneTimesStride);
            Value lastPosInBounds =
                generateInBoundsCheck(b, loc, lastPos, zero, dimSize);
            scf::YieldOp::create(b, loc, lastPosInBounds);
          },
          [&](OpBuilder &b, Location loc) {
            // Empty slices have no last element, so the check trivially
            // passes.
            Value trueVal =
                arith::ConstantOp::create(b, loc, b.getBoolAttr(true));
            scf::YieldOp::create(b, loc, trueVal);
          });

      Value finalCondition = ifOp.getResult(0);
      cf::AssertOp::create(
          builder, loc, finalCondition,
          generateErrorMessage(
              op, "extract_slice runs out-of-bounds along dimension " +
                      std::to_string(i)));
    }
  }
};
} // namespace
} // namespace tensor
} // namespace mlir

void mlir::tensor::registerRuntimeVerifiableOpInterfaceExternalModels(
    DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, tensor::TensorDialect *dialect) {
    CastOp::attachInterface<CastOpInterface>(*ctx);
    DimOp::attachInterface<DimOpInterface>(*ctx);
    ExtractOp::attachInterface<ExtractInsertOpInterface<ExtractOp>>(*ctx);
    ExtractSliceOp::attachInterface<ExtractSliceOpInterface>(*ctx);
    InsertOp::attachInterface<ExtractInsertOpInterface<InsertOp>>(*ctx);
    // Load additional dialects of which ops may get created.
    ctx->loadDialect<arith::ArithDialect, cf::ControlFlowDialect>();
  });
}