//===- RuntimeOpVerification.cpp - Op Verification ------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir/Dialect/Linalg/Transforms/RuntimeOpVerification.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Arith/Utils/Utils.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/Dialect/Index/IR/IndexAttrs.h" #include "mlir/Dialect/Index/IR/IndexDialect.h" #include "mlir/Dialect/Index/IR/IndexOps.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Interfaces/RuntimeVerifiableOpInterface.h" namespace mlir { namespace linalg { namespace { /// Verify that the runtime sizes of the operands to linalg structured ops are /// compatible with the runtime sizes inferred by composing the loop ranges with /// the linalg op's indexing maps. This is similar to the verifier except that /// here we insert IR to perform the verification at runtime. template struct StructuredOpInterface : public RuntimeVerifiableOpInterface::ExternalModel< StructuredOpInterface, T> { void generateRuntimeVerification(Operation *op, OpBuilder &builder, Location loc) const { auto linalgOp = llvm::cast(op); SmallVector loopRanges = linalgOp.createLoopRanges(builder, loc); auto [starts, ends, _] = getOffsetsSizesAndStrides(loopRanges); auto zero = arith::ConstantIndexOp::create(builder, loc, 0); auto one = arith::ConstantIndexOp::create(builder, loc, 1); // Subtract one from the loop ends before composing with the indexing map transform(ends, ends.begin(), [&](OpFoldResult end) { auto endValue = getValueOrCreateConstantIndexOp(builder, loc, end); return builder.createOrFold(loc, endValue, one); }); for (OpOperand &opOperand : linalgOp->getOpOperands()) { AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand); auto startIndices = affine::makeComposedFoldedMultiResultAffineApply( builder, loc, indexingMap, starts); auto endIndices = affine::makeComposedFoldedMultiResultAffineApply( builder, loc, indexingMap, ends); for (auto dim : llvm::seq(linalgOp.getRank(&opOperand))) { auto startIndex = getValueOrCreateConstantIndexOp(builder, loc, startIndices[dim]); auto endIndex = getValueOrCreateConstantIndexOp(builder, loc, endIndices[dim]); // Generate: // minIndex = min(startIndex, endIndex) // assert(minIndex >= 0) // To ensure we do not generate a negative index. We take the minimum of // the start and end indices in order to handle reverse loops such as // `affine_map<(i) -> (3 - i)>` auto min = builder.createOrFold(loc, startIndex, endIndex); auto cmpOp = builder.createOrFold( loc, index::IndexCmpPredicate::SGE, min, zero); auto msg = RuntimeVerifiableOpInterface::generateErrorMessage( linalgOp, "unexpected negative result on dimension #" + std::to_string(dim) + " of input/output operand #" + std::to_string(opOperand.getOperandNumber())); builder.createOrFold(loc, cmpOp, msg); // Generate: // inferredDimSize = max(startIndex, endIndex) + 1 // actualDimSize = dim(operand) // assert(inferredDimSize <= actualDimSize) // To ensure that we do not index past the bounds of the operands. auto max = builder.createOrFold(loc, startIndex, endIndex); auto inferredDimSize = builder.createOrFold(loc, max, one); auto actualDimSize = createOrFoldDimOp(builder, loc, opOperand.get(), dim); // Similar to the verifier, when the affine expression in the indexing // map is complicated, we just check that the inferred dimension sizes // are in the boundary of the operands' size. Being more precise than // that is difficult. auto predicate = isa(indexingMap.getResult(dim)) ? index::IndexCmpPredicate::EQ : index::IndexCmpPredicate::SLE; cmpOp = builder.createOrFold( loc, predicate, inferredDimSize, actualDimSize); msg = RuntimeVerifiableOpInterface::generateErrorMessage( linalgOp, "dimension #" + std::to_string(dim) + " of input/output operand #" + std::to_string(opOperand.getOperandNumber()) + " is incompatible with inferred dimension size"); builder.createOrFold(loc, cmpOp, msg); } } } }; template void attachInterface(MLIRContext *ctx) { (OpTs::template attachInterface>(*ctx), ...); } } // namespace } // namespace linalg } // namespace mlir void mlir::linalg::registerRuntimeVerifiableOpInterfaceExternalModels( DialectRegistry ®istry) { registry.addExtension(+[](MLIRContext *ctx, LinalgDialect *) { attachInterface< #define GET_OP_LIST #include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc" >(ctx); // Load additional dialects of which ops may get created. ctx->loadDialect(); }); }