//===- DropEquivalentBufferResults.cpp - Calling convention conversion ----===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // This pass drops return values from functions if they are equivalent to one of // their arguments. E.g.: // // ``` // func.func @foo(%m : memref) -> (memref) { // return %m : memref // } // ``` // // This functions is rewritten to: // // ``` // func.func @foo(%m : memref) { // return // } // ``` // // All call sites are updated accordingly. If a function returns a cast of a // function argument, it is also considered equivalent. A cast is inserted at // the call site in that case. #include "mlir/Dialect/Bufferization/Transforms/Passes.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" namespace mlir { namespace bufferization { #define GEN_PASS_DEF_DROPEQUIVALENTBUFFERRESULTSPASS #include "mlir/Dialect/Bufferization/Transforms/Passes.h.inc" } // namespace bufferization } // namespace mlir using namespace mlir; /// Get all the ReturnOp in the funcOp. static SmallVector getReturnOps(func::FuncOp funcOp) { SmallVector returnOps; for (Block &b : funcOp.getBody()) { if (auto candidateOp = dyn_cast(b.getTerminator())) { returnOps.push_back(candidateOp); } } return returnOps; } /// Get the operands at the specified position for all returnOps. static SmallVector getReturnOpsOperandInPos(ArrayRef returnOps, size_t pos) { return llvm::map_to_vector(returnOps, [&](func::ReturnOp returnOp) { return returnOp.getOperand(pos); }); } /// Check if all given values are the same buffer as the block argument (modulo /// cast ops). static bool operandsEqualFuncArgument(ArrayRef operands, BlockArgument argument) { for (Value val : operands) { while (auto castOp = val.getDefiningOp()) val = castOp.getSource(); if (val != argument) return false; } return true; } LogicalResult mlir::bufferization::dropEquivalentBufferResults( ModuleOp module, DropBufferResultsOpts options) { IRRewriter rewriter(module.getContext()); DenseMap> callerMap; // Collect the mapping of functions to their call sites. module.walk([&](func::CallOp callOp) { if (func::FuncOp calledFunc = dyn_cast_or_null(callOp.resolveCallable())) { if (calledFunc.isPublic() && !options.modifyPublicFunctions) return WalkResult::advance(); if (!calledFunc.isExternal()) callerMap[calledFunc].insert(callOp); } return WalkResult::advance(); }); for (auto funcOp : module.getOps()) { if (funcOp.isPublic() && !options.modifyPublicFunctions) continue; if (funcOp.isExternal()) continue; SmallVector returnOps = getReturnOps(funcOp); if (returnOps.empty()) continue; // Compute erased results. size_t numReturnOps = returnOps.size(); size_t numReturnValues = funcOp.getFunctionType().getNumResults(); SmallVector> newReturnValues(numReturnOps); BitVector erasedResultIndices(numReturnValues); DenseMap resultToArgs; for (size_t i = 0; i < numReturnValues; ++i) { bool erased = false; SmallVector returnOperands = getReturnOpsOperandInPos(returnOps, i); for (BlockArgument bbArg : funcOp.getArguments()) { if (operandsEqualFuncArgument(returnOperands, bbArg)) { resultToArgs[i] = bbArg.getArgNumber(); erased = true; break; } } if (erased) { erasedResultIndices.set(i); } else { for (auto [newReturnValue, operand] : llvm::zip(newReturnValues, returnOperands)) { newReturnValue.push_back(operand); } } } // Update function. if (failed(funcOp.eraseResults(erasedResultIndices))) return failure(); for (auto [returnOp, newReturnValue] : llvm::zip(returnOps, newReturnValues)) returnOp.getOperandsMutable().assign(newReturnValue); // Update function calls. for (func::CallOp callOp : callerMap[funcOp]) { rewriter.setInsertionPoint(callOp); auto newCallOp = func::CallOp::create(rewriter, callOp.getLoc(), funcOp, callOp.getOperands()); SmallVector newResults; int64_t nextResult = 0; for (int64_t i = 0; i < callOp.getNumResults(); ++i) { if (!resultToArgs.count(i)) { // This result was not erased. newResults.push_back(newCallOp.getResult(nextResult++)); continue; } // This result was erased. Value replacement = callOp.getOperand(resultToArgs[i]); Type expectedType = callOp.getResult(i).getType(); if (replacement.getType() != expectedType) { // A cast must be inserted at the call site. replacement = memref::CastOp::create(rewriter, callOp.getLoc(), expectedType, replacement); } newResults.push_back(replacement); } rewriter.replaceOp(callOp, newResults); } } return success(); } namespace { struct DropEquivalentBufferResultsPass : bufferization::impl::DropEquivalentBufferResultsPassBase< DropEquivalentBufferResultsPass> { using Base::Base; void runOnOperation() override { // Convert pass options. options.modifyPublicFunctions = modifyPublicFunctions; if (failed(bufferization::dropEquivalentBufferResults(getOperation(), options))) return signalPassFailure(); } private: bufferization::DropBufferResultsOpts options; }; } // namespace