From 1087c1079f870518b6bf6e2f6ed764d3e90611ae Mon Sep 17 00:00:00 2001 From: lonely eagle <2020382038@qq.com> Date: Mon, 6 Oct 2025 17:05:21 +0800 Subject: [PATCH] [mlir][bufferize] Add hoist-dynamic-allocs-option to buffer-results-to-out-params (#160985) Add hoist-dynamic-allocs-option to buffer-results-to-out-params. This PR supported that obtain the size of the dynamic shape memref through the caller-callee relationship. --- .../Dialect/Bufferization/Transforms/Passes.h | 13 ++- .../Bufferization/Transforms/Passes.td | 2 + .../Transforms/BufferResultsToOutParams.cpp | 96 +++++++++++++++++-- ...ts-to-out-params-hosit-dynamic-allocs.mlir | 79 +++++++++++++++ ...ts-to-out-params-hosit-static-allocs.mlir} | 0 5 files changed, 178 insertions(+), 12 deletions(-) create mode 100644 mlir/test/Transforms/buffer-results-to-out-params-hosit-dynamic-allocs.mlir rename mlir/test/Transforms/{buffer-results-to-out-params-elim.mlir => buffer-results-to-out-params-hosit-static-allocs.mlir} (100%) diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h index a2409f2796b9..67ac487d8226 100644 --- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h +++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h @@ -131,8 +131,8 @@ struct BufferResultsToOutParamsOpts { /// Allocator function: Generate a memref allocation with the given type. /// Since `promoteBufferResultsToOutParams` doesn't allow dynamically shaped /// results, we don't allow passing a range of values for dynamic dims. - using AllocationFn = - std::function(OpBuilder &, Location, MemRefType)>; + using AllocationFn = std::function(OpBuilder &, Location, + MemRefType, ValueRange)>; /// Memcpy function: Generate a memcpy between two memrefs. using MemCpyFn = @@ -147,8 +147,9 @@ struct BufferResultsToOutParamsOpts { /// Allocation function; used to allocate a memref. /// Default memref.alloc is used AllocationFn allocationFn = [](OpBuilder &builder, Location loc, - MemRefType type) { - return memref::AllocOp::create(builder, loc, type).getResult(); + MemRefType type, ValueRange dynamicSizes) { + return memref::AllocOp::create(builder, loc, type, dynamicSizes) + .getResult(); }; /// Memcpy function; used to create a copy between two memrefs. @@ -166,6 +167,10 @@ struct BufferResultsToOutParamsOpts { /// If true, the pass eliminates the memref.alloc and memcpy if the returned /// memref is allocated in the current function. bool hoistStaticAllocs = false; + + /// If true, the pass eliminates the memref.alloc and memcpy if the returned + /// memref is allocated in the current function and has dynamic shape. + bool hoistDynamicAllocs = false; }; /// Replace buffers that are returned from a function with an out parameter. diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td index a0d113c150c5..cad44cb15f47 100644 --- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td @@ -256,6 +256,8 @@ def BufferResultsToOutParamsPass "Add the attribute 'bufferize.result' to all output parameters.">, Option<"hoistStaticAllocs", "hoist-static-allocs", "bool", /*default=*/"false", "Hoist static allocations to call sites.">, + Option<"hoistDynamicAllocs", "hoist-dynamic-allocs", "bool", + /*default=*/"false", "Hoist dynamic allocations to call sites.">, ]; let dependentDialects = ["memref::MemRefDialect"]; } diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp index e30e094c2846..25f941dc1651 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp @@ -23,6 +23,8 @@ namespace bufferization { using namespace mlir; using AllocationFn = bufferization::BufferResultsToOutParamsOpts::AllocationFn; using MemCpyFn = bufferization::BufferResultsToOutParamsOpts::MemCpyFn; +using AllocDynamicSizesMap = + llvm::DenseMap>>; /// Return `true` if the given MemRef type has a fully dynamic layout. static bool hasFullyDynamicLayoutMap(MemRefType type) { @@ -43,6 +45,50 @@ static bool hasStaticIdentityLayout(MemRefType type) { return type.getLayout().isIdentity(); } +/// Return the dynamic shapes of the `memref` based on the defining op. If the +/// complete dynamic shape fails to be captured, return an empty value. +/// Currently, only function block arguments are supported for capturing. +static SmallVector getDynamicSize(Value memref, func::FuncOp funcOp) { + Operation *defOp = memref.getDefiningOp(); + if (!defOp) + return {}; + auto operands = defOp->getOperands(); + SmallVector dynamicSizes; + for (Value size : operands) { + if (!isa(size.getType())) + continue; + + BlockArgument sizeSrc = dyn_cast(size); + if (!sizeSrc) + return {}; + auto arguments = funcOp.getArguments(); + auto iter = llvm::find(arguments, sizeSrc); + if (iter == arguments.end()) + return {}; + dynamicSizes.push_back(*iter); + } + return dynamicSizes; +} + +/// Returns the dynamic sizes at the callee, through the call relationship +/// between the caller and callee. +static SmallVector mapDynamicSizeAtCaller(func::CallOp call, + func::FuncOp callee, + ValueRange dynamicSizes) { + SmallVector mappedDynamicSizes; + for (Value size : dynamicSizes) { + for (auto [src, dst] : + llvm::zip_first(call.getOperands(), callee.getArguments())) { + if (size != dst) + continue; + mappedDynamicSizes.push_back(src); + } + } + assert(mappedDynamicSizes.size() == dynamicSizes.size() && + "could not find all dynamic sizes"); + return mappedDynamicSizes; +} + // Updates the func op and entry block. // // Any args appended to the entry block are added to `appendedEntryArgs`. @@ -109,6 +155,7 @@ updateFuncOp(func::FuncOp func, // the given out-params. static LogicalResult updateReturnOps(func::FuncOp func, ArrayRef appendedEntryArgs, + AllocDynamicSizesMap &map, const bufferization::BufferResultsToOutParamsOpts &options) { auto res = func.walk([&](func::ReturnOp op) { SmallVector copyIntoOutParams; @@ -120,12 +167,22 @@ updateReturnOps(func::FuncOp func, ArrayRef appendedEntryArgs, keepAsReturnOperands.push_back(operand); } OpBuilder builder(op); + SmallVector> dynamicSizes; for (auto [orig, arg] : llvm::zip(copyIntoOutParams, appendedEntryArgs)) { - if (options.hoistStaticAllocs && + bool hoistStaticAllocs = + options.hoistStaticAllocs && + cast(orig.getType()).hasStaticShape(); + bool hoistDynamicAllocs = + options.hoistDynamicAllocs && + !cast(orig.getType()).hasStaticShape(); + if ((hoistStaticAllocs || hoistDynamicAllocs) && isa_and_nonnull( - orig.getDefiningOp()) && - mlir::cast(orig.getType()).hasStaticShape()) { + orig.getDefiningOp())) { orig.replaceAllUsesWith(arg); + if (hoistDynamicAllocs) { + SmallVector dynamicSize = getDynamicSize(orig, func); + dynamicSizes.push_back(dynamicSize); + } orig.getDefiningOp()->erase(); } else { if (failed(options.memCpyFn(builder, op.getLoc(), orig, arg))) @@ -134,6 +191,10 @@ updateReturnOps(func::FuncOp func, ArrayRef appendedEntryArgs, } func::ReturnOp::create(builder, op.getLoc(), keepAsReturnOperands); op.erase(); + auto dynamicSizePair = + std::pair>>(func, + dynamicSizes); + map.insert(dynamicSizePair); return WalkResult::advance(); }); return failure(res.wasInterrupted()); @@ -142,7 +203,7 @@ updateReturnOps(func::FuncOp func, ArrayRef appendedEntryArgs, // Updates all CallOps in the scope of the given ModuleOp by allocating // temporary buffers for newly introduced out params. static LogicalResult -updateCalls(ModuleOp module, +updateCalls(ModuleOp module, const AllocDynamicSizesMap &map, const bufferization::BufferResultsToOutParamsOpts &options) { bool didFail = false; SymbolTable symtab(module); @@ -166,8 +227,15 @@ updateCalls(ModuleOp module, } SmallVector outParams; OpBuilder builder(op); + SmallVector> dynamicSizes = map.lookup(callee); + size_t dynamicSizesIndex = 0; for (Value memref : replaceWithOutParams) { - if (!cast(memref.getType()).hasStaticShape()) { + SmallVector dynamicSize = dynamicSizes.size() > dynamicSizesIndex + ? dynamicSizes[dynamicSizesIndex] + : SmallVector(); + bool memrefStaticShape = + cast(memref.getType()).hasStaticShape(); + if (!memrefStaticShape && dynamicSize.empty()) { op.emitError() << "cannot create out param for dynamically shaped result"; didFail = true; @@ -177,8 +245,15 @@ updateCalls(ModuleOp module, auto allocType = MemRefType::get(memrefType.getShape(), memrefType.getElementType(), AffineMap(), memrefType.getMemorySpace()); + + if (memrefStaticShape) { + dynamicSize = {}; + } else { + ++dynamicSizesIndex; + dynamicSize = mapDynamicSizeAtCaller(op, callee, dynamicSize); + } auto maybeOutParam = - options.allocationFn(builder, op.getLoc(), allocType); + options.allocationFn(builder, op.getLoc(), allocType, dynamicSize); if (failed(maybeOutParam)) { op.emitError() << "failed to create allocation op"; didFail = true; @@ -213,6 +288,9 @@ updateCalls(ModuleOp module, LogicalResult mlir::bufferization::promoteBufferResultsToOutParams( ModuleOp module, const bufferization::BufferResultsToOutParamsOpts &options) { + // It maps the shape source of the dynamic shape memref returned by each + // function. + AllocDynamicSizesMap map; for (auto func : module.getOps()) { if (!options.filterFn(&func)) continue; @@ -222,11 +300,11 @@ LogicalResult mlir::bufferization::promoteBufferResultsToOutParams( return failure(); if (func.isExternal()) continue; - if (failed(updateReturnOps(func, appendedEntryArgs, options))) { + if (failed(updateReturnOps(func, appendedEntryArgs, map, options))) { return failure(); } } - if (failed(updateCalls(module, options))) + if (failed(updateCalls(module, map, options))) return failure(); return success(); } @@ -243,6 +321,8 @@ struct BufferResultsToOutParamsPass options.addResultAttribute = true; if (hoistStaticAllocs) options.hoistStaticAllocs = true; + if (hoistDynamicAllocs) + options.hoistDynamicAllocs = true; if (failed(bufferization::promoteBufferResultsToOutParams(getOperation(), options))) diff --git a/mlir/test/Transforms/buffer-results-to-out-params-hosit-dynamic-allocs.mlir b/mlir/test/Transforms/buffer-results-to-out-params-hosit-dynamic-allocs.mlir new file mode 100644 index 000000000000..f33eb8e26fbc --- /dev/null +++ b/mlir/test/Transforms/buffer-results-to-out-params-hosit-dynamic-allocs.mlir @@ -0,0 +1,79 @@ +// RUN: mlir-opt -allow-unregistered-dialect -p 'builtin.module(buffer-results-to-out-params{hoist-dynamic-allocs})' %s -split-input-file | FileCheck %s + +func.func private @single_alloc(%size : index) -> (memref) { + %alloc = memref.alloc(%size) : memref + return %alloc : memref +} + +func.func @single_alloc_test(%size : index) { + %alloc = call @single_alloc(%size) : (index) -> (memref) + "test.sink"(%alloc) : (memref) -> () +} + +// CHECK-LABEL: func.func private @single_alloc( +// CHECK-SAME: %{{.*}}: index, +// CHECK-SAME: %{{.*}}: memref) { + +// CHECK-LABEL: func.func @single_alloc_test( +// CHECK-SAME: %[[size:.*]]: index) { +// CHECK: %[[alloc:.*]] = memref.alloc(%[[size]]) : memref +// CHECK: call @single_alloc(%[[size]], %[[alloc]]) : (index, memref) -> () +// CHECK: "test.sink"(%[[alloc]]) : (memref) -> () +// CHECK: } + +// ----- + +func.func private @mult_alloc(%size0 : index, %size1 : index) -> (memref, memref) { + %alloc0 = memref.alloc(%size0, %size1) : memref + %alloc1 = memref.alloc(%size1) : memref + return %alloc0, %alloc1 : memref, memref +} + +func.func @mult_alloc_test(%size0 : index, %size1: index) { + %alloc0, %alloc1 = call @mult_alloc(%size0, %size1) : (index, index) -> (memref, memref) + "test.sink"(%alloc0, %alloc1) : (memref, memref) -> () +} + +// CHECK-LABEL: func private @mult_alloc( +// CHECK-SAME: %{{.*}}: index, %{{.*}}: index, +// CHECK-SAME: %{{.*}}: memref, %{{.*}}: memref) { + +// CHECK-LABEL: func @mult_alloc_test( +// CHECK-SAME: %[[size0:.*]]: index, +// CHECK-SAME: %[[size1:.*]]: index) { +// CHECK: %[[alloc0:.*]] = memref.alloc(%[[size0]], %[[size1]]) : memref +// CHECK: %[[alloc1:.*]] = memref.alloc(%[[size1]]) : memref +// CHECK: call @mult_alloc(%[[size0]], %[[size1]], %[[alloc0]], %[[alloc1]]) : (index, index, memref, memref) -> () +// CHECK: "test.sink"(%[[alloc0]], %[[alloc1]]) : (memref, memref) -> () +// CHECK: } + + +// ----- + +func.func private @complex_alloc(%size0 : index, %size1 : index) -> (memref, memref<4xf32>, memref) { + %alloc0 = memref.alloc(%size0, %size1) : memref + %alloc1 = memref.alloc() : memref<4xf32> + %alloc2 = memref.alloc(%size1) : memref + return %alloc0, %alloc1, %alloc2 : memref, memref<4xf32>, memref +} + +func.func @complex_alloc_test(%size0 : index, %size1: index) { + %alloc0, %alloc1, %alloc2 = call @complex_alloc(%size0, %size1) : (index, index) -> (memref, memref<4xf32>, memref) + "test.sink"(%alloc0, %alloc1, %alloc2) : (memref, memref<4xf32>, memref) -> () +} + +// CHECK-LABEL: func private @complex_alloc( +// CHECK-SAME: %{{.*}}: index, %{{.*}}: index, +// CHECK-SAME: %{{.*}}: memref, +// CHECK-SAME: %{{.*}}: memref<4xf32>, +// CHECK-SAME: %{{.*}}: memref) { + +// CHECK-LABEL: func @complex_alloc_test( +// CHECK-SAME: %[[size0:.*]]: index, +// CHECK-SAME: %[[size1:.*]]: index) { +// CHECK: %[[alloc0:.*]] = memref.alloc(%[[size0]], %[[size1]]) : memref +// CHECK: %[[alloc1:.*]] = memref.alloc() : memref<4xf32> +// CHECK: %[[alloc2:.*]] = memref.alloc(%[[size1]]) : memref +// CHECK: call @complex_alloc(%[[size0]], %[[size1]], %[[alloc0]], %[[alloc1]], %[[alloc2]]) : (index, index, memref, memref<4xf32>, memref) -> () +// CHECK: "test.sink"(%[[alloc0]], %[[alloc1]], %[[alloc2]]) : (memref, memref<4xf32>, memref) -> () +// CHECK: } diff --git a/mlir/test/Transforms/buffer-results-to-out-params-elim.mlir b/mlir/test/Transforms/buffer-results-to-out-params-hosit-static-allocs.mlir similarity index 100% rename from mlir/test/Transforms/buffer-results-to-out-params-elim.mlir rename to mlir/test/Transforms/buffer-results-to-out-params-hosit-static-allocs.mlir