llvm-project/mlir/lib/Dialect/Tensor/Transforms/RuntimeOpVerification.cpp
Hanumanth a6788b5246
[mlir][tensor] Fix runtime verification for tensor.extract_slice when size dimension value is 0 (#164878)
Previously, the runtime verification pass would insert assertion
statements with conditions that always evaluate to false for
semantically valid `tensor.extract_slice` operations where one of the
dimensions had a size of 0.

The `tensor.extract_slice` runtime verification logic was
unconditionally generating checks for the position of the last element
(`offset + (size - 1) * stride`). When `size` is 0, this causes the
assertion condition to always be false, leading to runtime failures even
though the operation is semantically valid.

This patch fixes the issue by making the `lastPos` check conditional.
The offset is always verified, but the endpoint check is only performed
when `size > 0` to avoid generating spurious assert statements.

This issue was discovered through a LiteRT model, where a dynamic shape
calculation resulted in a zero-sized dimension being passed to
`tensor.extract_slice`.

The following is a simplified IR snippet from the model. After running
the runtime verification pass, an assertion that always fails is
generated because the SSA value `%3` becomes 0.

```mlir
func.func @simple_repro_from_liteRT_model(%arg0: tensor<10x4x1xf32>) -> tensor<?x?x?xf32> {
  %cst = arith.constant dense<0> : tensor<1xi32>
  %cst_0 = arith.constant dense<-1> : tensor<2xi32>
  %c-1 = arith.constant -1 : index
  %c0 = arith.constant 0 : index
  %c10 = arith.constant 10 : index
  %c1 = arith.constant 1 : index
  %c4 = arith.constant 4 : index
  %c2 = arith.constant 2 : index
  %0 = tensor.empty() : tensor<3xi32>
  %inserted_slice = tensor.insert_slice %cst into %0[0] [1] [1] : tensor<1xi32> into tensor<3xi32>
  %inserted_slice_1 = tensor.insert_slice %cst_0 into %inserted_slice[1] [2] [1] : tensor<2xi32> into tensor<3xi32>
  %extracted = tensor.extract %inserted_slice_1[%c0] : tensor<3xi32>
  %1 = index.casts %extracted : i32 to index
  %2 = arith.cmpi eq, %1, %c-1 : index
  %3 = arith.select %2, %c10, %1 : index
  %extracted_2 = tensor.extract %inserted_slice_1[%c1] : tensor<3xi32>
  %4 = index.casts %extracted_2 : i32 to index
  %5 = arith.cmpi eq, %4, %c-1 : index
  %6 = arith.select %5, %c4, %4 : index
  %extracted_3 = tensor.extract %inserted_slice_1[%c2] : tensor<3xi32>
  %7 = index.casts %extracted_3 : i32 to index
  %8 = arith.cmpi eq, %7, %c-1 : index
  %9 = arith.select %8, %c1, %7 : index
  %extracted_slice = tensor.extract_slice %arg0[0, 0, 0] [%3, %6, %9] [1, 1, 1] : tensor<10x4x1xf32> to tensor<?x?x?xf32>
  return %extracted_slice : tensor<?x?x?xf32>
}
```

The issue can be reproduced more simply with the following test case.
After the runtime verification pass is applied, calling
`@extract_slice_zero_size_dim` with `dim_0 = 0` triggers the generated
assertion, even though the zero-sized slice is valid.

```mlir
func.func @extract_slice_zero_size_dim(%arg0: tensor<10x4x1xf32>,
                                      %dim_0: index,
                                      %dim_1: index,
                                      %dim_2: index) {
  %slice = tensor.extract_slice %arg0[0, 0, 0] [%dim_0, %dim_1, %dim_2] [1, 1, 1]
    : tensor<10x4x1xf32> to tensor<?x?x?xf32>
  return
}

func.func @test_zero_size_extraction() {
  %input = arith.constant dense<1.0> : tensor<10x4x1xf32>
  // Define slice dimensions: 0x4x1 (zero-size in first dimension)
  %dim_0 = arith.constant 0 : index
  %dim_1 = arith.constant 4 : index
  %dim_2 = arith.constant 1 : index
  func.call @extract_slice_zero_size_dim(%input, %dim_0, %dim_1, %dim_2)
    : (tensor<10x4x1xf32>, index, index, index) -> ()
  return
}
```

P.S. We probably have a similar issue with `memref.subview`. I will
check this and send a separate PR for the issue.

---------

Co-authored-by: Hanumanth Hanumantharayappa <hhanuman@ah-hhanuman-l.dhcp.mathworks.com>
2025-10-27 11:43:18 -07:00

238 lines
9.7 KiB
C++

//===- RuntimeOpVerification.cpp - Op Verification ------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Tensor/Transforms/RuntimeOpVerification.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Interfaces/RuntimeVerifiableOpInterface.h"
using namespace mlir;
namespace mlir {
namespace tensor {
namespace {
/// Builds an i1 value that is true iff `lb <= value < ub`, using signed
/// comparisons. Both comparisons and their conjunction go through
/// `createOrFold`, so constant operands may fold instead of emitting ops.
Value generateInBoundsCheck(OpBuilder &builder, Location loc, Value value,
                            Value lb, Value ub) {
  Value atLeastLower = builder.createOrFold<arith::CmpIOp>(
      loc, arith::CmpIPredicate::sge, value, lb);
  Value belowUpper = builder.createOrFold<arith::CmpIOp>(
      loc, arith::CmpIPredicate::slt, value, ub);
  return builder.createOrFold<arith::AndIOp>(loc, atLeastLower, belowUpper);
}
/// Runtime verification for `tensor.cast`.
///
/// Nothing is checked when the result is unranked. Otherwise, for an unranked
/// source the runtime rank is compared against the result rank. Every result
/// dimension that is static while the corresponding source dimension is
/// dynamic (or unknown, for an unranked source) is compared against the
/// source's runtime size.
struct CastOpInterface
    : public RuntimeVerifiableOpInterface::ExternalModel<CastOpInterface,
                                                         CastOp> {
  void
  generateRuntimeVerification(Operation *op, OpBuilder &builder, Location loc,
                              function_ref<std::string(Operation *, StringRef)>
                                  generateErrorMessage) const {
    auto castOp = cast<CastOp>(op);
    Value source = castOp.getSource();
    auto srcType = cast<TensorType>(source.getType());

    // An unranked result accepts any source; there is nothing to verify.
    auto resultType = dyn_cast<RankedTensorType>(castOp.getType());
    if (!resultType)
      return;

    // The rank of an unranked source is only known at runtime.
    auto rankedSrcType = dyn_cast<RankedTensorType>(srcType);
    if (!rankedSrcType) {
      Value actualRank = RankOp::create(builder, loc, source);
      Value expectedRank =
          arith::ConstantIndexOp::create(builder, loc, resultType.getRank());
      Value rankMatches = arith::CmpIOp::create(
          builder, loc, arith::CmpIPredicate::eq, actualRank, expectedRank);
      cf::AssertOp::create(builder, loc, rankMatches,
                           generateErrorMessage(op, "rank mismatch"));
    }

    ArrayRef<int64_t> resultShape = resultType.getShape();
    for (size_t dim = 0, e = resultShape.size(); dim < e; ++dim) {
      // Only a dynamic/unknown -> static transition needs a runtime check;
      // static sources and dynamic results are already consistent.
      bool srcDimIsStatic = rankedSrcType && !rankedSrcType.isDynamicDim(dim);
      if (srcDimIsStatic || resultType.isDynamicDim(dim))
        continue;

      Value actualSize = DimOp::create(builder, loc, source, dim);
      Value expectedSize =
          arith::ConstantIndexOp::create(builder, loc, resultShape[dim]);
      Value sizeMatches = arith::CmpIOp::create(
          builder, loc, arith::CmpIPredicate::eq, actualSize, expectedSize);
      cf::AssertOp::create(builder, loc, sizeMatches,
                           generateErrorMessage(op, "size mismatch of dim " +
                                                        std::to_string(dim)));
    }
  }
};
/// Runtime verification for `tensor.dim`: the requested dimension index must
/// lie in `[0, rank(source))`.
struct DimOpInterface
    : public RuntimeVerifiableOpInterface::ExternalModel<DimOpInterface,
                                                         DimOp> {
  void
  generateRuntimeVerification(Operation *op, OpBuilder &builder, Location loc,
                              function_ref<std::string(Operation *, StringRef)>
                                  generateErrorMessage) const {
    auto dimOp = cast<DimOp>(op);
    Value sourceRank = RankOp::create(builder, loc, dimOp.getSource());
    Value lowerBound = arith::ConstantIndexOp::create(builder, loc, 0);
    Value indexIsValid = generateInBoundsCheck(builder, loc, dimOp.getIndex(),
                                               lowerBound, sourceRank);
    cf::AssertOp::create(builder, loc, indexIsValid,
                         generateErrorMessage(op, "index is out of bounds"));
  }
};
/// Runtime verification for `tensor.extract` / `tensor.insert`: every index
/// must address an element of the accessed tensor, i.e. `0 <= index#i < dim#i`
/// for each dimension `i`. The per-dimension conditions are and-ed together
/// and checked by a single assertion. 0-d tensors need no check.
template <typename OpTy>
struct ExtractInsertOpInterface
    : public RuntimeVerifiableOpInterface::ExternalModel<
          ExtractInsertOpInterface<OpTy>, OpTy> {
  void
  generateRuntimeVerification(Operation *op, OpBuilder &builder, Location loc,
                              function_ref<std::string(Operation *, StringRef)>
                                  generateErrorMessage) const {
    auto extractInsertOp = cast<OpTy>(op);

    // `tensor.extract` reads its tensor operand; `tensor.insert` writes into
    // its destination.
    Value target;
    if constexpr (std::is_same_v<OpTy, ExtractOp>)
      target = extractInsertOp.getTensor();
    else if constexpr (std::is_same_v<OpTy, InsertOp>)
      target = extractInsertOp.getDest();
    else
      llvm_unreachable("invalid op");

    int64_t rank = cast<RankedTensorType>(target.getType()).getRank();
    // A 0-d tensor takes no indices, so there is nothing to verify.
    if (rank == 0)
      return;

    auto indices = extractInsertOp.getIndices();
    Value zero = arith::ConstantIndexOp::create(builder, loc, 0);
    Value allInBounds;
    for (int64_t dim = 0; dim < rank; ++dim) {
      Value dimSize = builder.createOrFold<tensor::DimOp>(loc, target, dim);
      Value inBounds =
          generateInBoundsCheck(builder, loc, indices[dim], zero, dimSize);
      // The first dimension seeds the conjunction; later ones are and-ed in.
      allInBounds = allInBounds ? builder.createOrFold<arith::AndIOp>(
                                      loc, allInBounds, inBounds)
                                : inBounds;
    }
    cf::AssertOp::create(builder, loc, allInBounds,
                         generateErrorMessage(op, "out-of-bounds access"));
  }
};
/// Runtime verification for `tensor.extract_slice`.
///
/// Notation for each source dimension `d`:
///   - `o` is the slice's offset along `d`.
///   - `s` is the slice's size along `d`.
///   - `t` is the slice's stride along `d`.
///   - `n` is the runtime source size along `d`.
///
/// The generated IR asserts:
///   * non-empty slice (s != 0):
///       0 <= o < n                  ("offset d is out-of-bounds")
///       0 <= o + (s - 1) * t < n    ("extract_slice runs out-of-bounds ...")
///   * empty slice (s == 0), which reads no element:
///       0 <= o <= n                 ("offset d is out-of-bounds")
///
/// Accepting `o == n` for empty slices is required so that a zero-sized slice
/// of a zero-sized dimension (n == 0, o == 0) and an empty slice starting just
/// past the last element do not produce assertions that can never hold.
/// Negative sizes are not treated as empty and still go through the end check.
///
/// The conditions are combined with plain `arith` ops (no regions) and built
/// with `createOrFold`, so they can fold away when all operands are static.
struct ExtractSliceOpInterface
    : public RuntimeVerifiableOpInterface::ExternalModel<
          ExtractSliceOpInterface, ExtractSliceOp> {
  void
  generateRuntimeVerification(Operation *op, OpBuilder &builder, Location loc,
                              function_ref<std::string(Operation *, StringRef)>
                                  generateErrorMessage) const {
    auto extractSliceOp = cast<ExtractSliceOp>(op);
    Value source = extractSliceOp.getSource();
    RankedTensorType sourceType = extractSliceOp.getSource().getType();
    // Hoisted out of the loop: each accessor materializes a new vector.
    SmallVector<OpFoldResult> mixedOffsets = extractSliceOp.getMixedOffsets();
    SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes();
    SmallVector<OpFoldResult> mixedStrides = extractSliceOp.getMixedStrides();

    Value zero = arith::ConstantIndexOp::create(builder, loc, 0);
    Value one = arith::ConstantIndexOp::create(builder, loc, 1);
    for (int64_t i : llvm::seq<int64_t>(0, sourceType.getRank())) {
      // Emit this dimension's checks right before the verified op.
      builder.setInsertionPoint(extractSliceOp);
      Value offset =
          getValueOrCreateConstantIndexOp(builder, loc, mixedOffsets[i]);
      Value size = getValueOrCreateConstantIndexOp(builder, loc, mixedSizes[i]);
      Value stride =
          getValueOrCreateConstantIndexOp(builder, loc, mixedStrides[i]);
      Value dimSize = builder.createOrFold<tensor::DimOp>(loc, source, i);

      Value sizeIsZero = builder.createOrFold<arith::CmpIOp>(
          loc, arith::CmpIPredicate::eq, size, zero);

      // Offset check: 0 <= offset < dim_size, relaxed to
      // 0 <= offset <= dim_size for empty slices. `offset == dim_size` implies
      // `offset >= 0` because a dimension size is never negative.
      Value offsetInBounds =
          generateInBoundsCheck(builder, loc, offset, zero, dimSize);
      Value offsetIsDimSize = builder.createOrFold<arith::CmpIOp>(
          loc, arith::CmpIPredicate::eq, offset, dimSize);
      Value emptySliceAtEnd = builder.createOrFold<arith::AndIOp>(
          loc, sizeIsZero, offsetIsDimSize);
      offsetInBounds = builder.createOrFold<arith::OrIOp>(loc, offsetInBounds,
                                                          emptySliceAtEnd);
      cf::AssertOp::create(builder, loc, offsetInBounds,
                           generateErrorMessage(op, "offset " +
                                                        std::to_string(i) +
                                                        " is out-of-bounds"));

      // End check: 0 <= offset + (size - 1) * stride < dim_size, waived for
      // empty slices. For size == 0 the computed position is never accessed.
      // Plain arith index arithmetic wraps, so computing it anyway is harmless.
      Value sizeMinusOne = builder.createOrFold<arith::SubIOp>(loc, size, one);
      Value extent =
          builder.createOrFold<arith::MulIOp>(loc, sizeMinusOne, stride);
      Value lastPos = builder.createOrFold<arith::AddIOp>(loc, offset, extent);
      Value lastPosInBounds =
          generateInBoundsCheck(builder, loc, lastPos, zero, dimSize);
      Value sliceInBounds =
          builder.createOrFold<arith::OrIOp>(loc, sizeIsZero, lastPosInBounds);
      cf::AssertOp::create(
          builder, loc, sliceInBounds,
          generateErrorMessage(
              op, "extract_slice runs out-of-bounds along dimension " +
                      std::to_string(i)));
    }
  }
};
} // namespace
} // namespace tensor
} // namespace mlir
/// Attaches the runtime-verification external models defined in this file to
/// the tensor ops. This covers `tensor.cast`, `tensor.dim`, `tensor.extract`,
/// `tensor.extract_slice` and `tensor.insert`.
///
/// The extension runs when the tensor dialect is loaded. It also loads the
/// dialects whose ops the verification code may build, so that creating them
/// does not fail with "op isn't known in this MLIRContext".
void mlir::tensor::registerRuntimeVerifiableOpInterfaceExternalModels(
    DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, tensor::TensorDialect *dialect) {
    CastOp::attachInterface<CastOpInterface>(*ctx);
    DimOp::attachInterface<DimOpInterface>(*ctx);
    ExtractOp::attachInterface<ExtractInsertOpInterface<ExtractOp>>(*ctx);
    ExtractSliceOp::attachInterface<ExtractSliceOpInterface>(*ctx);
    InsertOp::attachInterface<ExtractInsertOpInterface<InsertOp>>(*ctx);

    // Load additional dialects of which ops may get created. SCF is included
    // because this file includes the SCF dialect for building `scf` ops.
    ctx->loadDialect<arith::ArithDialect, cf::ControlFlowDialect,
                     scf::SCFDialect>();
  });
}