[mlir][bufferization] Unranked memref support for clone (#94757)
bufferization.clone does not currently support lowering to memref for unranked memrefs. This interferes with bufferizing unranked tensors at boundaries where a clone operation is needed. ``` func.func @foo(%input: memref<*xf32>, %shape: memref<?xindex>) -> memref<*xf32> { %reshape = memref.reshape %input(%shape) : (memref<*xf32>, memref<?xindex>) -> memref<*xf32> %copy = bufferization.clone %reshape : memref<*xf32> to memref<*xf32> return %copy : memref<*xf32> } ``` Patterns such as this are possible when bufferizing functions with unranked tensor inputs and outputs. The clone operation currently fails to legalize during the bufferization-to-memref conversion when unranked memrefs are involved. This change modifies the conversion of bufferization.clone to memref so that it emits the runtime size/shape calculations and the allocation required to clone an unranked memref.
This commit is contained in:
parent
0aeaa2d93d
commit
28d6aa90b0
@ -42,25 +42,60 @@ struct CloneOpConversion : public OpConversionPattern<bufferization::CloneOp> {
|
||||
LogicalResult
|
||||
matchAndRewrite(bufferization::CloneOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
// Check for unranked memref types which are currently not supported.
|
||||
Location loc = op->getLoc();
|
||||
|
||||
Type type = op.getType();
|
||||
if (isa<UnrankedMemRefType>(type)) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "UnrankedMemRefType is not supported.");
|
||||
}
|
||||
Value alloc;
|
||||
|
||||
if (auto unrankedType = dyn_cast<UnrankedMemRefType>(type)) {
|
||||
// Constants
|
||||
Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
|
||||
Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
|
||||
|
||||
// Dynamically evaluate the size and shape of the unranked memref
|
||||
Value rank = rewriter.create<memref::RankOp>(loc, op.getInput());
|
||||
MemRefType allocType =
|
||||
MemRefType::get({ShapedType::kDynamic}, rewriter.getIndexType());
|
||||
Value shape = rewriter.create<memref::AllocaOp>(loc, allocType, rank);
|
||||
|
||||
// Create a loop to query dimension sizes, store them as a shape, and
|
||||
// compute the total size of the memref
|
||||
auto loopBody = [&](OpBuilder &builder, Location loc, Value i,
|
||||
ValueRange args) {
|
||||
auto acc = args.front();
|
||||
auto dim = rewriter.create<memref::DimOp>(loc, op.getInput(), i);
|
||||
|
||||
rewriter.create<memref::StoreOp>(loc, dim, shape, i);
|
||||
acc = rewriter.create<arith::MulIOp>(loc, acc, dim);
|
||||
|
||||
rewriter.create<scf::YieldOp>(loc, acc);
|
||||
};
|
||||
auto size = rewriter
|
||||
.create<scf::ForOp>(loc, zero, rank, one, ValueRange(one),
|
||||
loopBody)
|
||||
.getResult(0);
|
||||
|
||||
MemRefType memrefType = MemRefType::get({ShapedType::kDynamic},
|
||||
unrankedType.getElementType());
|
||||
|
||||
// Allocate new memref with 1D dynamic shape, then reshape into the
|
||||
// shape of the original unranked memref
|
||||
alloc = rewriter.create<memref::AllocOp>(loc, memrefType, size);
|
||||
alloc =
|
||||
rewriter.create<memref::ReshapeOp>(loc, unrankedType, alloc, shape);
|
||||
} else {
|
||||
MemRefType memrefType = cast<MemRefType>(type);
|
||||
MemRefLayoutAttrInterface layout;
|
||||
auto allocType =
|
||||
MemRefType::get(memrefType.getShape(), memrefType.getElementType(),
|
||||
layout, memrefType.getMemorySpace());
|
||||
// Since this implementation always allocates, certain result types of the
|
||||
// clone op cannot be lowered.
|
||||
// Since this implementation always allocates, certain result types of
|
||||
// the clone op cannot be lowered.
|
||||
if (!memref::CastOp::areCastCompatible({allocType}, {memrefType}))
|
||||
return failure();
|
||||
|
||||
// Transform a clone operation into alloc + copy operation and pay
|
||||
// attention to the shape dimensions.
|
||||
Location loc = op->getLoc();
|
||||
SmallVector<Value, 4> dynamicOperands;
|
||||
for (int i = 0; i < memrefType.getRank(); ++i) {
|
||||
if (!memrefType.isDynamicDim(i))
|
||||
@ -70,11 +105,13 @@ struct CloneOpConversion : public OpConversionPattern<bufferization::CloneOp> {
|
||||
}
|
||||
|
||||
// Allocate a memref with identity layout.
|
||||
Value alloc = rewriter.create<memref::AllocOp>(op->getLoc(), allocType,
|
||||
dynamicOperands);
|
||||
alloc = rewriter.create<memref::AllocOp>(loc, allocType, dynamicOperands);
|
||||
// Cast the allocation to the specified type if needed.
|
||||
if (memrefType != allocType)
|
||||
alloc = rewriter.create<memref::CastOp>(op->getLoc(), memrefType, alloc);
|
||||
alloc =
|
||||
rewriter.create<memref::CastOp>(op->getLoc(), memrefType, alloc);
|
||||
}
|
||||
|
||||
rewriter.replaceOp(op, alloc);
|
||||
rewriter.create<memref::CopyOp>(loc, op.getInput(), alloc);
|
||||
return success();
|
||||
|
@ -22,7 +22,7 @@ func.func @conversion_dynamic(%arg0 : memref<?xf32>) -> memref<?xf32> {
|
||||
}
|
||||
|
||||
// CHECK: %[[CONST:.*]] = arith.constant
|
||||
// CHECK-NEXT: %[[DIM:.*]] = memref.dim %[[ARG:.*]], %[[CONST]]
|
||||
// CHECK: %[[DIM:.*]] = memref.dim %[[ARG:.*]], %[[CONST]]
|
||||
// CHECK-NEXT: %[[ALLOC:.*]] = memref.alloc(%[[DIM]])
|
||||
// CHECK-NEXT: memref.copy %[[ARG]], %[[ALLOC]]
|
||||
// CHECK-NEXT: memref.dealloc %[[ARG]]
|
||||
@ -30,13 +30,26 @@ func.func @conversion_dynamic(%arg0 : memref<?xf32>) -> memref<?xf32> {
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @conversion_unknown
|
||||
func.func @conversion_unknown(%arg0 : memref<*xf32>) -> memref<*xf32> {
|
||||
// expected-error@+1 {{failed to legalize operation 'bufferization.clone' that was explicitly marked illegal}}
|
||||
%1 = bufferization.clone %arg0 : memref<*xf32> to memref<*xf32>
|
||||
memref.dealloc %arg0 : memref<*xf32>
|
||||
return %1 : memref<*xf32>
|
||||
}
|
||||
|
||||
// CHECK: %[[RANK:.*]] = memref.rank %[[ARG:.*]]
|
||||
// CHECK-NEXT: %[[ALLOCA:.*]] = memref.alloca(%[[RANK]])
|
||||
// CHECK-NEXT: %[[FOR:.*]] = scf.for
|
||||
// CHECK-NEXT: %[[DIM:.*]] = memref.dim %[[ARG:.*]] %[[ARG:.*]]
|
||||
// CHECK-NEXT: memref.store %[[DIM:.*]], %[[ALLOCA:.*]][%[[ARG:.*]]]
|
||||
// CHECK-NEXT: %[[MUL:.*]] = arith.muli %[[ARG:.*]], %[[DIM:.*]]
|
||||
// CHECK-NEXT: scf.yield %[[MUL:.*]]
|
||||
// CHECK: %[[ALLOC:.*]] = memref.alloc(%[[FOR:.*]])
|
||||
// CHECK-NEXT: %[[RESHAPE:.*]] = memref.reshape %[[ALLOC:.*]]
|
||||
// CHECK-NEXT: memref.copy
|
||||
// CHECK-NEXT: memref.dealloc
|
||||
// CHECK-NEXT: return %[[RESHAPE:.*]]
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @conversion_with_layout_map(
|
||||
|
Loading…
x
Reference in New Issue
Block a user