//===----------- MultiBuffering.cpp ---------------------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // This file implements multi buffering transformation. // //===----------------------------------------------------------------------===// #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Arith/Utils/Utils.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/MemRef/Transforms/Transforms.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/Dominance.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/ValueRange.h" #include "mlir/Interfaces/LoopLikeInterface.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/TypeSwitch.h" #include "llvm/Support/Debug.h" using namespace mlir; #define DEBUG_TYPE "memref-transforms" #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") #define DBGSNL() (llvm::dbgs() << "\n") /// Return true if the op fully overwrite the given `buffer` value. static bool overrideBuffer(Operation *op, Value buffer) { auto copyOp = dyn_cast(op); if (!copyOp) return false; return copyOp.getTarget() == buffer; } /// Replace the uses of `oldOp` with the given `val` and for view-like uses /// propagate the type change. Changing the memref type may require propagating /// it through view-like ops (subview, expand_shape, collapse_shape, cast) so /// we need to propagate the type change and erase old view ops. /// /// Only view-like ops whose result type can be recomputed from the new source /// type and existing op attributes are handled here. Other ops fall back to /// operand replacement without type propagation. static LogicalResult replaceUsesAndPropagateType(RewriterBase &rewriter, Operation *oldOp, Value val) { SmallVector opsToErase; // Iterate with early_inc to erase current user inside the loop. for (OpOperand &use : llvm::make_early_inc_range(oldOp->getUses())) { Operation *user = use.getOwner(); OpBuilder::InsertionGuard g(rewriter); rewriter.setInsertionPoint(user); MemRefType srcType = cast(val.getType()); // Try to create a new view-like op with updated result type. // Each view-like op has its own method to compute the result type. bool typeInferenceFailed = false; Value replacement = llvm::TypeSwitch(user) .Case([&](memref::SubViewOp subview) -> Value { MemRefType newType = memref::SubViewOp::inferRankReducedResultType( subview.getType().getShape(), srcType, subview.getStaticOffsets(), subview.getStaticSizes(), subview.getStaticStrides()); return memref::SubViewOp::create( rewriter, subview->getLoc(), newType, val, subview.getMixedOffsets(), subview.getMixedSizes(), subview.getMixedStrides()); }) .Case([&](memref::ExpandShapeOp expand) -> Value { FailureOr newType = memref::ExpandShapeOp::computeExpandedType( srcType, expand.getResultType().getShape(), expand.getReassociationIndices()); if (failed(newType)) { typeInferenceFailed = true; return Value(); } return memref::ExpandShapeOp::create( rewriter, expand->getLoc(), *newType, val, expand.getReassociationIndices(), expand.getMixedOutputShape()); }) .Case([&](memref::CollapseShapeOp collapse) -> Value { FailureOr newType = memref::CollapseShapeOp::computeCollapsedType( srcType, collapse.getReassociationIndices()); if (failed(newType)) { typeInferenceFailed = true; return Value(); } return memref::CollapseShapeOp::create( rewriter, collapse->getLoc(), *newType, val, collapse.getReassociationIndices()); }) .Case([&](memref::CastOp cast) -> Value { if (!memref::CastOp::areCastCompatible(srcType, cast.getType())) { typeInferenceFailed = true; return Value(); } return memref::CastOp::create(rewriter, cast->getLoc(), cast.getType(), val); }) .Default([&](Operation *) -> Value { return Value(); }); if (typeInferenceFailed) { user->emitOpError( "failed to compute view-like result type after multi-buffering"); return failure(); } if (replacement) { // Recursively propagate through view-like ops and mark old op for // erasure. if (failed(replaceUsesAndPropagateType(rewriter, user, replacement))) return failure(); opsToErase.push_back(user); } else { // Not a view-like op: just replace operand. rewriter.startOpModification(user); use.set(val); rewriter.finalizeOpModification(user); } } for (Operation *op : opsToErase) { rewriter.eraseOp(op); } return success(); } // Transformation to do multi-buffering/array expansion to remove dependencies // on the temporary allocation between consecutive loop iterations. // Returns success if the transformation happened and failure otherwise. // This is not a pattern as it requires propagating the new memref type to its // uses and requires updating subview ops. FailureOr mlir::memref::multiBuffer(RewriterBase &rewriter, memref::AllocOp allocOp, unsigned multiBufferingFactor, bool skipOverrideAnalysis) { LLVM_DEBUG(DBGS() << "Start multibuffering: " << allocOp << "\n"); DominanceInfo dom(allocOp->getParentOp()); LoopLikeOpInterface candidateLoop; for (Operation *user : allocOp->getUsers()) { auto parentLoop = user->getParentOfType(); if (!parentLoop) { if (isa(user)) { // Allow dealloc outside of any loop. // TODO: The whole precondition function here is very brittle and will // need to rethought an isolated into a cleaner analysis. continue; } LLVM_DEBUG(DBGS() << "--no parent loop -> fail\n"); LLVM_DEBUG(DBGS() << "----due to user: " << *user << "\n"); return failure(); } if (!skipOverrideAnalysis) { /// Make sure there is no loop-carried dependency on the allocation. if (!overrideBuffer(user, allocOp.getResult())) { LLVM_DEBUG(DBGS() << "--Skip user: found loop-carried dependence\n"); continue; } // If this user doesn't dominate all the other users keep looking. if (llvm::any_of(allocOp->getUsers(), [&](Operation *otherUser) { return !dom.dominates(user, otherUser); })) { LLVM_DEBUG( DBGS() << "--Skip user: does not dominate all other users\n"); continue; } } else { if (llvm::any_of(allocOp->getUsers(), [&](Operation *otherUser) { return !isa(otherUser) && !parentLoop->isProperAncestor(otherUser); })) { LLVM_DEBUG( DBGS() << "--Skip user: not all other users are in the parent loop\n"); continue; } } candidateLoop = parentLoop; break; } if (!candidateLoop) { LLVM_DEBUG(DBGS() << "Skip alloc: no candidate loop\n"); return failure(); } std::optional inductionVar = candidateLoop.getSingleInductionVar(); std::optional lowerBound = candidateLoop.getSingleLowerBound(); std::optional singleStep = candidateLoop.getSingleStep(); if (!inductionVar || !lowerBound || !singleStep || !llvm::hasSingleElement(candidateLoop.getLoopRegions())) { LLVM_DEBUG(DBGS() << "Skip alloc: no single iv, lb, step or region\n"); return failure(); } if (!dom.dominates(allocOp.getOperation(), candidateLoop)) { LLVM_DEBUG(DBGS() << "Skip alloc: does not dominate candidate loop\n"); return failure(); } LLVM_DEBUG(DBGS() << "Start multibuffering loop: " << candidateLoop << "\n"); // 1. Construct the multi-buffered memref type. ArrayRef originalShape = allocOp.getType().getShape(); SmallVector multiBufferedShape{multiBufferingFactor}; llvm::append_range(multiBufferedShape, originalShape); LLVM_DEBUG(DBGS() << "--original type: " << allocOp.getType() << "\n"); MemRefType mbMemRefType = MemRefType::Builder(allocOp.getType()) .setShape(multiBufferedShape) .setLayout(MemRefLayoutAttrInterface()); LLVM_DEBUG(DBGS() << "--multi-buffered type: " << mbMemRefType << "\n"); // 2. Create the multi-buffered alloc. Location loc = allocOp->getLoc(); OpBuilder::InsertionGuard g(rewriter); rewriter.setInsertionPoint(allocOp); auto mbAlloc = memref::AllocOp::create(rewriter, loc, mbMemRefType, ValueRange{}, allocOp->getAttrs()); LLVM_DEBUG(DBGS() << "--multi-buffered alloc: " << mbAlloc << "\n"); // 3. Within the loop, build the modular leading index (i.e. each loop // iteration %iv accesses slice ((%iv - %lb) / %step) % %mb_factor). rewriter.setInsertionPointToStart( &candidateLoop.getLoopRegions().front()->front()); Value ivVal = *inductionVar; Value lbVal = getValueOrCreateConstantIndexOp(rewriter, loc, *lowerBound); Value stepVal = getValueOrCreateConstantIndexOp(rewriter, loc, *singleStep); AffineExpr iv, lb, step; bindDims(rewriter.getContext(), iv, lb, step); Value bufferIndex = affine::makeComposedAffineApply( rewriter, loc, ((iv - lb).floorDiv(step)) % multiBufferingFactor, {ivVal, lbVal, stepVal}); LLVM_DEBUG(DBGS() << "--multi-buffered indexing: " << bufferIndex << "\n"); // 4. Build the subview accessing the particular slice, taking modular // rotation into account. int64_t mbMemRefTypeRank = mbMemRefType.getRank(); IntegerAttr zero = rewriter.getIndexAttr(0); IntegerAttr one = rewriter.getIndexAttr(1); SmallVector offsets(mbMemRefTypeRank, zero); SmallVector sizes(mbMemRefTypeRank, one); SmallVector strides(mbMemRefTypeRank, one); // Offset is [bufferIndex, 0 ... 0 ]. offsets.front() = bufferIndex; // Sizes is [1, original_size_0 ... original_size_n ]. for (int64_t i = 0, e = originalShape.size(); i != e; ++i) sizes[1 + i] = rewriter.getIndexAttr(originalShape[i]); // Strides is [1, 1 ... 1 ]. MemRefType dstMemref = memref::SubViewOp::inferRankReducedResultType( originalShape, mbMemRefType, offsets, sizes, strides); Value subview = memref::SubViewOp::create(rewriter, loc, dstMemref, mbAlloc, offsets, sizes, strides); LLVM_DEBUG(DBGS() << "--multi-buffered slice: " << subview << "\n"); // 5. Due to the recursive nature of replaceUsesAndPropagateType , we need // to handle dealloc uses separately.. for (OpOperand &use : llvm::make_early_inc_range(allocOp->getUses())) { auto deallocOp = dyn_cast(use.getOwner()); if (!deallocOp) continue; OpBuilder::InsertionGuard g(rewriter); rewriter.setInsertionPoint(deallocOp); auto newDeallocOp = memref::DeallocOp::create(rewriter, deallocOp->getLoc(), mbAlloc); (void)newDeallocOp; LLVM_DEBUG(DBGS() << "----Created dealloc: " << newDeallocOp << "\n"); rewriter.eraseOp(deallocOp); } // 6. RAUW with the particular slice, taking modular rotation into account. if (failed(replaceUsesAndPropagateType(rewriter, allocOp, subview))) return failure(); // 7. Finally, erase the old allocOp. rewriter.eraseOp(allocOp); return mbAlloc; } FailureOr mlir::memref::multiBuffer(memref::AllocOp allocOp, unsigned multiBufferingFactor, bool skipOverrideAnalysis) { IRRewriter rewriter(allocOp->getContext()); return multiBuffer(rewriter, allocOp, multiBufferingFactor, skipOverrideAnalysis); }