//===- LoopFusion.cpp - Code to perform loop fusion -----------------------===// // // Copyright 2019 The MLIR Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // ============================================================================= // // This file implements loop fusion. // //===----------------------------------------------------------------------===// #include "mlir/Analysis/AffineAnalysis.h" #include "mlir/Analysis/LoopAnalysis.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineMap.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/StmtVisitor.h" #include "mlir/Pass.h" #include "mlir/StandardOps/StandardOps.h" #include "mlir/Transforms/LoopUtils.h" #include "mlir/Transforms/Passes.h" #include "llvm/ADT/DenseMap.h" using namespace mlir; namespace { /// Loop fusion pass. This pass fuses adjacent loops in MLFunctions which /// access the same memref with no dependences. // See MatchTestPattern for details on candidate loop selection. // TODO(andydavis) Extend this pass to check for fusion preventing dependences, // and add support for more general loop fusion algorithms. struct LoopFusion : public FunctionPass { LoopFusion() {} PassResult runOnMLFunction(MLFunction *f) override; static char passID; }; // LoopCollector walks the statements in an MLFunction and builds a map from // StmtBlocks to a list of loops within the StmtBlock, and a map from ForStmts // to the list of loads and stores with its StmtBlock. class LoopCollector : public StmtWalker { public: DenseMap> loopMap; DenseMap> loadsAndStoresMap; bool hasIfStmt = false; void visitForStmt(ForStmt *forStmt) { loopMap[forStmt->getBlock()].push_back(forStmt); } void visitIfStmt(IfStmt *ifStmt) { hasIfStmt = true; } void visitOperationStmt(OperationStmt *opStmt) { if (auto *parentStmt = opStmt->getParentStmt()) { if (auto *parentForStmt = dyn_cast(parentStmt)) { if (opStmt->isa() || opStmt->isa()) { loadsAndStoresMap[parentForStmt].push_back(opStmt); } } } } }; } // end anonymous namespace char LoopFusion::passID = 0; FunctionPass *mlir::createLoopFusionPass() { return new LoopFusion; } // TODO(andydavis) Remove the following test code when more general loop // fusion is supported. struct FusionCandidate { // Loop nest of ForStmts with 'accessA' in the inner-most loop. SmallVector forStmtsA; // Load or store operation within loop nest 'forStmtsA'. MemRefAccess accessA; // Loop nest of ForStmts with 'accessB' in the inner-most loop. SmallVector forStmtsB; // Load or store operation within loop nest 'forStmtsB'. MemRefAccess accessB; }; static void getSingleMemRefAccess(OperationStmt *loadOrStoreOpStmt, MemRefAccess *access) { if (auto loadOp = loadOrStoreOpStmt->dyn_cast()) { access->memref = cast(loadOp->getMemRef()); access->opStmt = loadOrStoreOpStmt; auto loadMemrefType = loadOp->getMemRefType(); access->indices.reserve(loadMemrefType.getRank()); for (auto *index : loadOp->getIndices()) { access->indices.push_back(cast(index)); } } else { assert(loadOrStoreOpStmt->isa()); auto storeOp = loadOrStoreOpStmt->dyn_cast(); access->opStmt = loadOrStoreOpStmt; access->memref = cast(storeOp->getMemRef()); auto storeMemrefType = storeOp->getMemRefType(); access->indices.reserve(storeMemrefType.getRank()); for (auto *index : storeOp->getIndices()) { access->indices.push_back(cast(index)); } } } // Checks if 'forStmtA' and 'forStmtB' match specific test criterion: // constant loop bounds, no nested loops, single StoreOp in 'forStmtA' and // a single LoadOp in 'forStmtB'. // Returns true if the test pattern matches, false otherwise. static bool MatchTestPatternLoopPair(LoopCollector *lc, FusionCandidate *candidate, ForStmt *forStmtA, ForStmt *forStmtB) { if (forStmtA == nullptr || forStmtB == nullptr) return false; // Return if 'forStmtA' and 'forStmtB' do not have matching constant // bounds and step. if (!forStmtA->hasConstantBounds() || !forStmtB->hasConstantBounds() || forStmtA->getConstantLowerBound() != forStmtB->getConstantLowerBound() || forStmtA->getConstantUpperBound() != forStmtB->getConstantUpperBound() || forStmtA->getStep() != forStmtB->getStep()) return false; // Return if 'forStmtA' or 'forStmtB' have nested loops. if (lc->loopMap.count(forStmtA) > 0 || lc->loopMap.count(forStmtB)) return false; // Return if 'forStmtA' or 'forStmtB' do not have exactly one load or store. if (lc->loadsAndStoresMap[forStmtA].size() != 1 || lc->loadsAndStoresMap[forStmtB].size() != 1) return false; // Get load/store access for forStmtA. getSingleMemRefAccess(lc->loadsAndStoresMap[forStmtA][0], &candidate->accessA); // Return if 'accessA' is not a store. if (!candidate->accessA.opStmt->isa()) return false; // Get load/store access for forStmtB. getSingleMemRefAccess(lc->loadsAndStoresMap[forStmtB][0], &candidate->accessB); // Return if accesses do not access the same memref. if (candidate->accessA.memref != candidate->accessB.memref) return false; candidate->forStmtsA.push_back(forStmtA); candidate->forStmtsB.push_back(forStmtB); return true; } // Returns the child ForStmt of 'parent' if unique, returns false otherwise. ForStmt *getSingleForStmtChild(ForStmt *parent) { if (parent->getStatements().size() == 1 && isa(parent->front())) return dyn_cast(&parent->front()); return nullptr; } // Checks for a specific ForStmt/OpStatment test pattern in 'f', returns true // on success and resturns fusion candidate in 'candidate'. Returns false // otherwise. // Currently supported test patterns: // *) Adjacent loops with a StoreOp the only op in first loop, and a LoadOp the // only op in the second loop (both load/store accessing the same memref). // *) As above, but with one level of perfect loop nesting. // // TODO(andydavis) Look into using ntv@ pattern matcher here. static bool MatchTestPattern(MLFunction *f, FusionCandidate *candidate) { LoopCollector lc; lc.walk(f); // Return if an IfStmt was found or if less than two ForStmts were found. if (lc.hasIfStmt || lc.loopMap.count(f) == 0 || lc.loopMap[f].size() < 2) return false; auto *forStmtA = lc.loopMap[f][0]; auto *forStmtB = lc.loopMap[f][1]; if (!MatchTestPatternLoopPair(&lc, candidate, forStmtA, forStmtB)) { // Check for one level of loop nesting. candidate->forStmtsA.push_back(forStmtA); candidate->forStmtsB.push_back(forStmtB); return MatchTestPatternLoopPair(&lc, candidate, getSingleForStmtChild(forStmtA), getSingleForStmtChild(forStmtB)); } return true; } // FuseLoops implements the code generation mechanics of loop fusion. // Fuses the operations statments from the inner-most loop in 'c.forStmtsB', // by cloning them into the inner-most loop in 'c.forStmtsA', then erasing // old statements and loops. static void fuseLoops(const FusionCandidate &c) { MLFuncBuilder builder(c.forStmtsA.back(), StmtBlock::iterator(c.forStmtsA.back()->end())); DenseMap operandMap; assert(c.forStmtsA.size() == c.forStmtsB.size()); for (unsigned i = 0, e = c.forStmtsA.size(); i < e; i++) { // Map loop IVs to 'forStmtB[i]' to loop IV for 'forStmtA[i]'. operandMap[c.forStmtsB[i]] = c.forStmtsA[i]; } // Clone the body of inner-most loop in 'forStmtsB', into the body of // inner-most loop in 'forStmtsA'. SmallVector stmtsToErase; auto *innerForStmtB = c.forStmtsB.back(); for (auto &stmt : *innerForStmtB) { builder.clone(stmt, operandMap); stmtsToErase.push_back(&stmt); } // Erase 'forStmtB' and its statement list. for (auto it = stmtsToErase.rbegin(); it != stmtsToErase.rend(); ++it) (*it)->erase(); // Erase 'forStmtsB' loop nest. for (int i = static_cast(c.forStmtsB.size()) - 1; i >= 0; --i) c.forStmtsB[i]->erase(); } PassResult LoopFusion::runOnMLFunction(MLFunction *f) { FusionCandidate candidate; if (!MatchTestPattern(f, &candidate)) return failure(); // TODO(andydavis) Add checks for fusion-preventing dependences and ordering // constraints which would prevent fusion. // TODO(andydavis) This check if overly conservative for now. Support fusing // statements with compatible dependences (i.e. statements where the // dependence between the statements does not reverse direction when the // statements are fused into the same loop). if (!checkMemrefAccessDependence(candidate.accessA, candidate.accessB)) { // Current conservatinve test policy: No dependence exists between accesses // in different loop nests -> fuse loops. fuseLoops(candidate); } return success(); } static PassRegistration pass("loop-fusion", "Fuse loop nests");