//===- RuntimeOpVerification.cpp - Op Verification ------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir/Dialect/Tensor/Transforms/RuntimeOpVerification.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Arith/Utils/Utils.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlow.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Interfaces/RuntimeVerifiableOpInterface.h" using namespace mlir; namespace mlir { namespace tensor { namespace { /// Generate a runtime check for lb <= value < ub. Value generateInBoundsCheck(OpBuilder &builder, Location loc, Value value, Value lb, Value ub) { Value inBounds1 = builder.createOrFold( loc, arith::CmpIPredicate::sge, value, lb); Value inBounds2 = builder.createOrFold( loc, arith::CmpIPredicate::slt, value, ub); Value inBounds = builder.createOrFold(loc, inBounds1, inBounds2); return inBounds; } struct CastOpInterface : public RuntimeVerifiableOpInterface::ExternalModel { void generateRuntimeVerification(Operation *op, OpBuilder &builder, Location loc, function_ref generateErrorMessage) const { auto castOp = cast(op); auto srcType = cast(castOp.getSource().getType()); // Nothing to check if the result is an unranked tensor. auto resultType = dyn_cast(castOp.getType()); if (!resultType) return; if (isa(srcType)) { // Check rank. Value srcRank = RankOp::create(builder, loc, castOp.getSource()); Value resultRank = arith::ConstantIndexOp::create(builder, loc, resultType.getRank()); Value isSameRank = arith::CmpIOp::create( builder, loc, arith::CmpIPredicate::eq, srcRank, resultRank); cf::AssertOp::create(builder, loc, isSameRank, generateErrorMessage(op, "rank mismatch")); } // Check dimension sizes. for (const auto &it : llvm::enumerate(resultType.getShape())) { // Static dim size -> static/dynamic dim size does not need verification. if (auto rankedSrcType = dyn_cast(srcType)) if (!rankedSrcType.isDynamicDim(it.index())) continue; // Static/dynamic dim size -> dynamic dim size does not need verification. if (resultType.isDynamicDim(it.index())) continue; Value srcDimSz = DimOp::create(builder, loc, castOp.getSource(), it.index()); Value resultDimSz = arith::ConstantIndexOp::create(builder, loc, it.value()); Value isSameSz = arith::CmpIOp::create( builder, loc, arith::CmpIPredicate::eq, srcDimSz, resultDimSz); cf::AssertOp::create( builder, loc, isSameSz, generateErrorMessage(op, "size mismatch of dim " + std::to_string(it.index()))); } } }; struct DimOpInterface : public RuntimeVerifiableOpInterface::ExternalModel { void generateRuntimeVerification(Operation *op, OpBuilder &builder, Location loc, function_ref generateErrorMessage) const { auto dimOp = cast(op); Value rank = RankOp::create(builder, loc, dimOp.getSource()); Value zero = arith::ConstantIndexOp::create(builder, loc, 0); cf::AssertOp::create( builder, loc, generateInBoundsCheck(builder, loc, dimOp.getIndex(), zero, rank), generateErrorMessage(op, "index is out of bounds")); } }; /// Verifies that the indices on extract/insert ops are in-bounds of the /// tensor's index space: 0 <= index#i < dim#i template struct ExtractInsertOpInterface : public RuntimeVerifiableOpInterface::ExternalModel< ExtractInsertOpInterface, OpTy> { void generateRuntimeVerification(Operation *op, OpBuilder &builder, Location loc, function_ref generateErrorMessage) const { auto extractInsertOp = cast(op); Value tensor; if constexpr (std::is_same_v) { tensor = extractInsertOp.getTensor(); } else if constexpr (std::is_same_v) { tensor = extractInsertOp.getDest(); } else { llvm_unreachable("invalid op"); } auto tensorType = cast(tensor.getType()); auto rank = tensorType.getRank(); if (rank == 0) { // Nothing to check for 0-d tensors. return; } auto indices = extractInsertOp.getIndices(); auto zero = arith::ConstantIndexOp::create(builder, loc, 0); Value assertCond; for (auto i : llvm::seq(0, rank)) { Value dimOp = builder.createOrFold(loc, tensor, i); Value inBounds = generateInBoundsCheck(builder, loc, indices[i], zero, dimOp); assertCond = i > 0 ? builder.createOrFold(loc, assertCond, inBounds) : inBounds; } cf::AssertOp::create(builder, loc, assertCond, generateErrorMessage(op, "out-of-bounds access")); } }; struct ExtractSliceOpInterface : public RuntimeVerifiableOpInterface::ExternalModel< ExtractSliceOpInterface, ExtractSliceOp> { void generateRuntimeVerification(Operation *op, OpBuilder &builder, Location loc, function_ref generateErrorMessage) const { auto extractSliceOp = cast(op); RankedTensorType sourceType = extractSliceOp.getSource().getType(); // For each dimension, assert that: // For empty slices (size == 0) : 0 <= offset <= dim_size // For non-empty slices (size > 0): 0 <= offset < dim_size // 0 <= offset + (size - 1) * stride < // dim_size Value zero = arith::ConstantIndexOp::create(builder, loc, 0); Value one = arith::ConstantIndexOp::create(builder, loc, 1); for (int64_t i : llvm::seq(0, sourceType.getRank())) { builder.setInsertionPoint(extractSliceOp); Value offset = getValueOrCreateConstantIndexOp( builder, loc, extractSliceOp.getMixedOffsets()[i]); Value size = getValueOrCreateConstantIndexOp( builder, loc, extractSliceOp.getMixedSizes()[i]); Value stride = getValueOrCreateConstantIndexOp( builder, loc, extractSliceOp.getMixedStrides()[i]); Value dimSize = builder.createOrFold( loc, extractSliceOp.getSource(), i); // Verify that offset is in-bounds (conditional on slice size). Value sizeIsZero = arith::CmpIOp::create( builder, loc, arith::CmpIPredicate::eq, size, zero); auto offsetCheckIf = scf::IfOp::create( builder, loc, sizeIsZero, [&](OpBuilder &b, Location loc) { // For empty slices, offset can be at the boundary: 0 <= offset <= // dimSize. Value offsetGEZero = arith::CmpIOp::create( b, loc, arith::CmpIPredicate::sge, offset, zero); Value offsetLEDimSize = arith::CmpIOp::create( b, loc, arith::CmpIPredicate::sle, offset, dimSize); Value emptyOffsetValid = arith::AndIOp::create(b, loc, offsetGEZero, offsetLEDimSize); scf::YieldOp::create(b, loc, emptyOffsetValid); }, [&](OpBuilder &b, Location loc) { // For non-empty slices, offset must be a valid index: 0 <= offset < // dimSize. Value offsetInBounds = generateInBoundsCheck(b, loc, offset, zero, dimSize); scf::YieldOp::create(b, loc, offsetInBounds); }); Value offsetCondition = offsetCheckIf.getResult(0); cf::AssertOp::create(builder, loc, offsetCondition, generateErrorMessage(op, "offset " + std::to_string(i) + " is out-of-bounds")); // Verify that the slice endpoint is in-bounds (only for non-empty // slices). Value sizeIsNonZero = arith::CmpIOp::create( builder, loc, arith::CmpIPredicate::sgt, size, zero); auto ifOp = scf::IfOp::create( builder, loc, sizeIsNonZero, [&](OpBuilder &b, Location loc) { // Verify that slice does not run out-of-bounds. Value sizeMinusOne = arith::SubIOp::create(b, loc, size, one); Value sizeMinusOneTimesStride = arith::MulIOp::create(b, loc, sizeMinusOne, stride); Value lastPos = arith::AddIOp::create(b, loc, offset, sizeMinusOneTimesStride); Value lastPosInBounds = generateInBoundsCheck(b, loc, lastPos, zero, dimSize); scf::YieldOp::create(b, loc, lastPosInBounds); }, [&](OpBuilder &b, Location loc) { Value trueVal = arith::ConstantOp::create(b, loc, b.getBoolAttr(true)); scf::YieldOp::create(b, loc, trueVal); }); Value finalCondition = ifOp.getResult(0); cf::AssertOp::create( builder, loc, finalCondition, generateErrorMessage( op, "extract_slice runs out-of-bounds along dimension " + std::to_string(i))); } } }; } // namespace } // namespace tensor } // namespace mlir void mlir::tensor::registerRuntimeVerifiableOpInterfaceExternalModels( DialectRegistry ®istry) { registry.addExtension(+[](MLIRContext *ctx, tensor::TensorDialect *dialect) { CastOp::attachInterface(*ctx); DimOp::attachInterface(*ctx); ExtractOp::attachInterface>(*ctx); ExtractSliceOp::attachInterface(*ctx); InsertOp::attachInterface>(*ctx); // Load additional dialects of which ops may get created. ctx->loadDialect(); }); }