[mlir][bufferize] Add hoist-dynamic-allocs-option to buffer-results-to-out-params (#160985)
Add hoist-dynamic-allocs-option to buffer-results-to-out-params. This PR supported that obtain the size of the dynamic shape memref through the caller-callee relationship.
This commit is contained in:
parent
1bd9c1bde3
commit
1087c1079f
@ -131,8 +131,8 @@ struct BufferResultsToOutParamsOpts {
|
||||
/// Allocator function: Generate a memref allocation with the given type.
|
||||
/// Since `promoteBufferResultsToOutParams` doesn't allow dynamically shaped
|
||||
/// results, we don't allow passing a range of values for dynamic dims.
|
||||
using AllocationFn =
|
||||
std::function<FailureOr<Value>(OpBuilder &, Location, MemRefType)>;
|
||||
using AllocationFn = std::function<FailureOr<Value>(OpBuilder &, Location,
|
||||
MemRefType, ValueRange)>;
|
||||
|
||||
/// Memcpy function: Generate a memcpy between two memrefs.
|
||||
using MemCpyFn =
|
||||
@ -147,8 +147,9 @@ struct BufferResultsToOutParamsOpts {
|
||||
/// Allocation function; used to allocate a memref.
|
||||
/// Default memref.alloc is used
|
||||
AllocationFn allocationFn = [](OpBuilder &builder, Location loc,
|
||||
MemRefType type) {
|
||||
return memref::AllocOp::create(builder, loc, type).getResult();
|
||||
MemRefType type, ValueRange dynamicSizes) {
|
||||
return memref::AllocOp::create(builder, loc, type, dynamicSizes)
|
||||
.getResult();
|
||||
};
|
||||
|
||||
/// Memcpy function; used to create a copy between two memrefs.
|
||||
@ -166,6 +167,10 @@ struct BufferResultsToOutParamsOpts {
|
||||
/// If true, the pass eliminates the memref.alloc and memcpy if the returned
|
||||
/// memref is allocated in the current function.
|
||||
bool hoistStaticAllocs = false;
|
||||
|
||||
/// If true, the pass eliminates the memref.alloc and memcpy if the returned
|
||||
/// memref is allocated in the current function and has dynamic shape.
|
||||
bool hoistDynamicAllocs = false;
|
||||
};
|
||||
|
||||
/// Replace buffers that are returned from a function with an out parameter.
|
||||
|
||||
@ -256,6 +256,8 @@ def BufferResultsToOutParamsPass
|
||||
"Add the attribute 'bufferize.result' to all output parameters.">,
|
||||
Option<"hoistStaticAllocs", "hoist-static-allocs", "bool",
|
||||
/*default=*/"false", "Hoist static allocations to call sites.">,
|
||||
Option<"hoistDynamicAllocs", "hoist-dynamic-allocs", "bool",
|
||||
/*default=*/"false", "Hoist dynamic allocations to call sites.">,
|
||||
];
|
||||
let dependentDialects = ["memref::MemRefDialect"];
|
||||
}
|
||||
|
||||
@ -23,6 +23,8 @@ namespace bufferization {
|
||||
using namespace mlir;
|
||||
using AllocationFn = bufferization::BufferResultsToOutParamsOpts::AllocationFn;
|
||||
using MemCpyFn = bufferization::BufferResultsToOutParamsOpts::MemCpyFn;
|
||||
using AllocDynamicSizesMap =
|
||||
llvm::DenseMap<func::FuncOp, SmallVector<SmallVector<Value>>>;
|
||||
|
||||
/// Return `true` if the given MemRef type has a fully dynamic layout.
|
||||
static bool hasFullyDynamicLayoutMap(MemRefType type) {
|
||||
@ -43,6 +45,50 @@ static bool hasStaticIdentityLayout(MemRefType type) {
|
||||
return type.getLayout().isIdentity();
|
||||
}
|
||||
|
||||
/// Return the dynamic shapes of the `memref` based on the defining op. If the
|
||||
/// complete dynamic shape fails to be captured, return an empty value.
|
||||
/// Currently, only function block arguments are supported for capturing.
|
||||
static SmallVector<Value> getDynamicSize(Value memref, func::FuncOp funcOp) {
|
||||
Operation *defOp = memref.getDefiningOp();
|
||||
if (!defOp)
|
||||
return {};
|
||||
auto operands = defOp->getOperands();
|
||||
SmallVector<Value> dynamicSizes;
|
||||
for (Value size : operands) {
|
||||
if (!isa<IndexType>(size.getType()))
|
||||
continue;
|
||||
|
||||
BlockArgument sizeSrc = dyn_cast<BlockArgument>(size);
|
||||
if (!sizeSrc)
|
||||
return {};
|
||||
auto arguments = funcOp.getArguments();
|
||||
auto iter = llvm::find(arguments, sizeSrc);
|
||||
if (iter == arguments.end())
|
||||
return {};
|
||||
dynamicSizes.push_back(*iter);
|
||||
}
|
||||
return dynamicSizes;
|
||||
}
|
||||
|
||||
/// Returns the dynamic sizes at the callee, through the call relationship
|
||||
/// between the caller and callee.
|
||||
static SmallVector<Value> mapDynamicSizeAtCaller(func::CallOp call,
|
||||
func::FuncOp callee,
|
||||
ValueRange dynamicSizes) {
|
||||
SmallVector<Value> mappedDynamicSizes;
|
||||
for (Value size : dynamicSizes) {
|
||||
for (auto [src, dst] :
|
||||
llvm::zip_first(call.getOperands(), callee.getArguments())) {
|
||||
if (size != dst)
|
||||
continue;
|
||||
mappedDynamicSizes.push_back(src);
|
||||
}
|
||||
}
|
||||
assert(mappedDynamicSizes.size() == dynamicSizes.size() &&
|
||||
"could not find all dynamic sizes");
|
||||
return mappedDynamicSizes;
|
||||
}
|
||||
|
||||
// Updates the func op and entry block.
|
||||
//
|
||||
// Any args appended to the entry block are added to `appendedEntryArgs`.
|
||||
@ -109,6 +155,7 @@ updateFuncOp(func::FuncOp func,
|
||||
// the given out-params.
|
||||
static LogicalResult
|
||||
updateReturnOps(func::FuncOp func, ArrayRef<BlockArgument> appendedEntryArgs,
|
||||
AllocDynamicSizesMap &map,
|
||||
const bufferization::BufferResultsToOutParamsOpts &options) {
|
||||
auto res = func.walk([&](func::ReturnOp op) {
|
||||
SmallVector<Value, 6> copyIntoOutParams;
|
||||
@ -120,12 +167,22 @@ updateReturnOps(func::FuncOp func, ArrayRef<BlockArgument> appendedEntryArgs,
|
||||
keepAsReturnOperands.push_back(operand);
|
||||
}
|
||||
OpBuilder builder(op);
|
||||
SmallVector<SmallVector<Value>> dynamicSizes;
|
||||
for (auto [orig, arg] : llvm::zip(copyIntoOutParams, appendedEntryArgs)) {
|
||||
if (options.hoistStaticAllocs &&
|
||||
bool hoistStaticAllocs =
|
||||
options.hoistStaticAllocs &&
|
||||
cast<MemRefType>(orig.getType()).hasStaticShape();
|
||||
bool hoistDynamicAllocs =
|
||||
options.hoistDynamicAllocs &&
|
||||
!cast<MemRefType>(orig.getType()).hasStaticShape();
|
||||
if ((hoistStaticAllocs || hoistDynamicAllocs) &&
|
||||
isa_and_nonnull<bufferization::AllocationOpInterface>(
|
||||
orig.getDefiningOp()) &&
|
||||
mlir::cast<MemRefType>(orig.getType()).hasStaticShape()) {
|
||||
orig.getDefiningOp())) {
|
||||
orig.replaceAllUsesWith(arg);
|
||||
if (hoistDynamicAllocs) {
|
||||
SmallVector<Value> dynamicSize = getDynamicSize(orig, func);
|
||||
dynamicSizes.push_back(dynamicSize);
|
||||
}
|
||||
orig.getDefiningOp()->erase();
|
||||
} else {
|
||||
if (failed(options.memCpyFn(builder, op.getLoc(), orig, arg)))
|
||||
@ -134,6 +191,10 @@ updateReturnOps(func::FuncOp func, ArrayRef<BlockArgument> appendedEntryArgs,
|
||||
}
|
||||
func::ReturnOp::create(builder, op.getLoc(), keepAsReturnOperands);
|
||||
op.erase();
|
||||
auto dynamicSizePair =
|
||||
std::pair<func::FuncOp, SmallVector<SmallVector<Value>>>(func,
|
||||
dynamicSizes);
|
||||
map.insert(dynamicSizePair);
|
||||
return WalkResult::advance();
|
||||
});
|
||||
return failure(res.wasInterrupted());
|
||||
@ -142,7 +203,7 @@ updateReturnOps(func::FuncOp func, ArrayRef<BlockArgument> appendedEntryArgs,
|
||||
// Updates all CallOps in the scope of the given ModuleOp by allocating
|
||||
// temporary buffers for newly introduced out params.
|
||||
static LogicalResult
|
||||
updateCalls(ModuleOp module,
|
||||
updateCalls(ModuleOp module, const AllocDynamicSizesMap &map,
|
||||
const bufferization::BufferResultsToOutParamsOpts &options) {
|
||||
bool didFail = false;
|
||||
SymbolTable symtab(module);
|
||||
@ -166,8 +227,15 @@ updateCalls(ModuleOp module,
|
||||
}
|
||||
SmallVector<Value, 6> outParams;
|
||||
OpBuilder builder(op);
|
||||
SmallVector<SmallVector<Value>> dynamicSizes = map.lookup(callee);
|
||||
size_t dynamicSizesIndex = 0;
|
||||
for (Value memref : replaceWithOutParams) {
|
||||
if (!cast<MemRefType>(memref.getType()).hasStaticShape()) {
|
||||
SmallVector<Value> dynamicSize = dynamicSizes.size() > dynamicSizesIndex
|
||||
? dynamicSizes[dynamicSizesIndex]
|
||||
: SmallVector<Value>();
|
||||
bool memrefStaticShape =
|
||||
cast<MemRefType>(memref.getType()).hasStaticShape();
|
||||
if (!memrefStaticShape && dynamicSize.empty()) {
|
||||
op.emitError()
|
||||
<< "cannot create out param for dynamically shaped result";
|
||||
didFail = true;
|
||||
@ -177,8 +245,15 @@ updateCalls(ModuleOp module,
|
||||
auto allocType =
|
||||
MemRefType::get(memrefType.getShape(), memrefType.getElementType(),
|
||||
AffineMap(), memrefType.getMemorySpace());
|
||||
|
||||
if (memrefStaticShape) {
|
||||
dynamicSize = {};
|
||||
} else {
|
||||
++dynamicSizesIndex;
|
||||
dynamicSize = mapDynamicSizeAtCaller(op, callee, dynamicSize);
|
||||
}
|
||||
auto maybeOutParam =
|
||||
options.allocationFn(builder, op.getLoc(), allocType);
|
||||
options.allocationFn(builder, op.getLoc(), allocType, dynamicSize);
|
||||
if (failed(maybeOutParam)) {
|
||||
op.emitError() << "failed to create allocation op";
|
||||
didFail = true;
|
||||
@ -213,6 +288,9 @@ updateCalls(ModuleOp module,
|
||||
LogicalResult mlir::bufferization::promoteBufferResultsToOutParams(
|
||||
ModuleOp module,
|
||||
const bufferization::BufferResultsToOutParamsOpts &options) {
|
||||
// It maps the shape source of the dynamic shape memref returned by each
|
||||
// function.
|
||||
AllocDynamicSizesMap map;
|
||||
for (auto func : module.getOps<func::FuncOp>()) {
|
||||
if (!options.filterFn(&func))
|
||||
continue;
|
||||
@ -222,11 +300,11 @@ LogicalResult mlir::bufferization::promoteBufferResultsToOutParams(
|
||||
return failure();
|
||||
if (func.isExternal())
|
||||
continue;
|
||||
if (failed(updateReturnOps(func, appendedEntryArgs, options))) {
|
||||
if (failed(updateReturnOps(func, appendedEntryArgs, map, options))) {
|
||||
return failure();
|
||||
}
|
||||
}
|
||||
if (failed(updateCalls(module, options)))
|
||||
if (failed(updateCalls(module, map, options)))
|
||||
return failure();
|
||||
return success();
|
||||
}
|
||||
@ -243,6 +321,8 @@ struct BufferResultsToOutParamsPass
|
||||
options.addResultAttribute = true;
|
||||
if (hoistStaticAllocs)
|
||||
options.hoistStaticAllocs = true;
|
||||
if (hoistDynamicAllocs)
|
||||
options.hoistDynamicAllocs = true;
|
||||
|
||||
if (failed(bufferization::promoteBufferResultsToOutParams(getOperation(),
|
||||
options)))
|
||||
|
||||
@ -0,0 +1,79 @@
|
||||
// RUN: mlir-opt -allow-unregistered-dialect -p 'builtin.module(buffer-results-to-out-params{hoist-dynamic-allocs})' %s -split-input-file | FileCheck %s
|
||||
|
||||
func.func private @single_alloc(%size : index) -> (memref<?xf32>) {
|
||||
%alloc = memref.alloc(%size) : memref<?xf32>
|
||||
return %alloc : memref<?xf32>
|
||||
}
|
||||
|
||||
func.func @single_alloc_test(%size : index) {
|
||||
%alloc = call @single_alloc(%size) : (index) -> (memref<?xf32>)
|
||||
"test.sink"(%alloc) : (memref<?xf32>) -> ()
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func.func private @single_alloc(
|
||||
// CHECK-SAME: %{{.*}}: index,
|
||||
// CHECK-SAME: %{{.*}}: memref<?xf32>) {
|
||||
|
||||
// CHECK-LABEL: func.func @single_alloc_test(
|
||||
// CHECK-SAME: %[[size:.*]]: index) {
|
||||
// CHECK: %[[alloc:.*]] = memref.alloc(%[[size]]) : memref<?xf32>
|
||||
// CHECK: call @single_alloc(%[[size]], %[[alloc]]) : (index, memref<?xf32>) -> ()
|
||||
// CHECK: "test.sink"(%[[alloc]]) : (memref<?xf32>) -> ()
|
||||
// CHECK: }
|
||||
|
||||
// -----
|
||||
|
||||
func.func private @mult_alloc(%size0 : index, %size1 : index) -> (memref<?x?xf32>, memref<?xf32>) {
|
||||
%alloc0 = memref.alloc(%size0, %size1) : memref<?x?xf32>
|
||||
%alloc1 = memref.alloc(%size1) : memref<?xf32>
|
||||
return %alloc0, %alloc1 : memref<?x?xf32>, memref<?xf32>
|
||||
}
|
||||
|
||||
func.func @mult_alloc_test(%size0 : index, %size1: index) {
|
||||
%alloc0, %alloc1 = call @mult_alloc(%size0, %size1) : (index, index) -> (memref<?x?xf32>, memref<?xf32>)
|
||||
"test.sink"(%alloc0, %alloc1) : (memref<?x?xf32>, memref<?xf32>) -> ()
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func private @mult_alloc(
|
||||
// CHECK-SAME: %{{.*}}: index, %{{.*}}: index,
|
||||
// CHECK-SAME: %{{.*}}: memref<?x?xf32>, %{{.*}}: memref<?xf32>) {
|
||||
|
||||
// CHECK-LABEL: func @mult_alloc_test(
|
||||
// CHECK-SAME: %[[size0:.*]]: index,
|
||||
// CHECK-SAME: %[[size1:.*]]: index) {
|
||||
// CHECK: %[[alloc0:.*]] = memref.alloc(%[[size0]], %[[size1]]) : memref<?x?xf32>
|
||||
// CHECK: %[[alloc1:.*]] = memref.alloc(%[[size1]]) : memref<?xf32>
|
||||
// CHECK: call @mult_alloc(%[[size0]], %[[size1]], %[[alloc0]], %[[alloc1]]) : (index, index, memref<?x?xf32>, memref<?xf32>) -> ()
|
||||
// CHECK: "test.sink"(%[[alloc0]], %[[alloc1]]) : (memref<?x?xf32>, memref<?xf32>) -> ()
|
||||
// CHECK: }
|
||||
|
||||
|
||||
// -----
|
||||
|
||||
func.func private @complex_alloc(%size0 : index, %size1 : index) -> (memref<?x?xf32>, memref<4xf32>, memref<?xf32>) {
|
||||
%alloc0 = memref.alloc(%size0, %size1) : memref<?x?xf32>
|
||||
%alloc1 = memref.alloc() : memref<4xf32>
|
||||
%alloc2 = memref.alloc(%size1) : memref<?xf32>
|
||||
return %alloc0, %alloc1, %alloc2 : memref<?x?xf32>, memref<4xf32>, memref<?xf32>
|
||||
}
|
||||
|
||||
func.func @complex_alloc_test(%size0 : index, %size1: index) {
|
||||
%alloc0, %alloc1, %alloc2 = call @complex_alloc(%size0, %size1) : (index, index) -> (memref<?x?xf32>, memref<4xf32>, memref<?xf32>)
|
||||
"test.sink"(%alloc0, %alloc1, %alloc2) : (memref<?x?xf32>, memref<4xf32>, memref<?xf32>) -> ()
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func private @complex_alloc(
|
||||
// CHECK-SAME: %{{.*}}: index, %{{.*}}: index,
|
||||
// CHECK-SAME: %{{.*}}: memref<?x?xf32>,
|
||||
// CHECK-SAME: %{{.*}}: memref<4xf32>,
|
||||
// CHECK-SAME: %{{.*}}: memref<?xf32>) {
|
||||
|
||||
// CHECK-LABEL: func @complex_alloc_test(
|
||||
// CHECK-SAME: %[[size0:.*]]: index,
|
||||
// CHECK-SAME: %[[size1:.*]]: index) {
|
||||
// CHECK: %[[alloc0:.*]] = memref.alloc(%[[size0]], %[[size1]]) : memref<?x?xf32>
|
||||
// CHECK: %[[alloc1:.*]] = memref.alloc() : memref<4xf32>
|
||||
// CHECK: %[[alloc2:.*]] = memref.alloc(%[[size1]]) : memref<?xf32>
|
||||
// CHECK: call @complex_alloc(%[[size0]], %[[size1]], %[[alloc0]], %[[alloc1]], %[[alloc2]]) : (index, index, memref<?x?xf32>, memref<4xf32>, memref<?xf32>) -> ()
|
||||
// CHECK: "test.sink"(%[[alloc0]], %[[alloc1]], %[[alloc2]]) : (memref<?x?xf32>, memref<4xf32>, memref<?xf32>) -> ()
|
||||
// CHECK: }
|
||||
Loading…
x
Reference in New Issue
Block a user