[mlir][bufferization] Unranked memref support for clone (#94757)

bufferization.clone does not currently support lowering to memref for
unranked memrefs. This interferes with bufferizing unranked tensors at
boundaries where a clone operation is needed.
```
func.func @foo(%input: memref<*xf32>, %shape: memref<?xindex>) -> memref<*xf32>
{
  %reshape = memref.reshape %input(%shape) : (memref<*xf32>, memref<?xindex>) -> memref<*xf32>
  %copy = bufferization.clone %reshape : memref<*xf32> to memref<*xf32>
  return %copy : memref<*xf32>
}
```
Patterns such as that are possibly when bufferizing functions with input
and output unranked tensors. The clone operation currently fails to
legalize during the bufferization-to-memref conversion with unranked
memrefs.

This change modifies the conversion of bufferization.clone to memref to
generate the runtime calculations and allocation to allow for cloning an
unranked memref.
This commit is contained in:
ryankima 2024-06-13 09:58:00 -04:00 committed by GitHub
parent 0aeaa2d93d
commit 28d6aa90b0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 82 additions and 32 deletions

View File

@ -42,25 +42,60 @@ struct CloneOpConversion : public OpConversionPattern<bufferization::CloneOp> {
LogicalResult
matchAndRewrite(bufferization::CloneOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// Check for unranked memref types which are currently not supported.
Location loc = op->getLoc();
Type type = op.getType();
if (isa<UnrankedMemRefType>(type)) {
return rewriter.notifyMatchFailure(
op, "UnrankedMemRefType is not supported.");
}
Value alloc;
if (auto unrankedType = dyn_cast<UnrankedMemRefType>(type)) {
// Constants
Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
// Dynamically evaluate the size and shape of the unranked memref
Value rank = rewriter.create<memref::RankOp>(loc, op.getInput());
MemRefType allocType =
MemRefType::get({ShapedType::kDynamic}, rewriter.getIndexType());
Value shape = rewriter.create<memref::AllocaOp>(loc, allocType, rank);
// Create a loop to query dimension sizes, store them as a shape, and
// compute the total size of the memref
auto loopBody = [&](OpBuilder &builder, Location loc, Value i,
ValueRange args) {
auto acc = args.front();
auto dim = rewriter.create<memref::DimOp>(loc, op.getInput(), i);
rewriter.create<memref::StoreOp>(loc, dim, shape, i);
acc = rewriter.create<arith::MulIOp>(loc, acc, dim);
rewriter.create<scf::YieldOp>(loc, acc);
};
auto size = rewriter
.create<scf::ForOp>(loc, zero, rank, one, ValueRange(one),
loopBody)
.getResult(0);
MemRefType memrefType = MemRefType::get({ShapedType::kDynamic},
unrankedType.getElementType());
// Allocate new memref with 1D dynamic shape, then reshape into the
// shape of the original unranked memref
alloc = rewriter.create<memref::AllocOp>(loc, memrefType, size);
alloc =
rewriter.create<memref::ReshapeOp>(loc, unrankedType, alloc, shape);
} else {
MemRefType memrefType = cast<MemRefType>(type);
MemRefLayoutAttrInterface layout;
auto allocType =
MemRefType::get(memrefType.getShape(), memrefType.getElementType(),
layout, memrefType.getMemorySpace());
// Since this implementation always allocates, certain result types of the
// clone op cannot be lowered.
// Since this implementation always allocates, certain result types of
// the clone op cannot be lowered.
if (!memref::CastOp::areCastCompatible({allocType}, {memrefType}))
return failure();
// Transform a clone operation into alloc + copy operation and pay
// attention to the shape dimensions.
Location loc = op->getLoc();
SmallVector<Value, 4> dynamicOperands;
for (int i = 0; i < memrefType.getRank(); ++i) {
if (!memrefType.isDynamicDim(i))
@ -70,11 +105,13 @@ struct CloneOpConversion : public OpConversionPattern<bufferization::CloneOp> {
}
// Allocate a memref with identity layout.
Value alloc = rewriter.create<memref::AllocOp>(op->getLoc(), allocType,
dynamicOperands);
alloc = rewriter.create<memref::AllocOp>(loc, allocType, dynamicOperands);
// Cast the allocation to the specified type if needed.
if (memrefType != allocType)
alloc = rewriter.create<memref::CastOp>(op->getLoc(), memrefType, alloc);
alloc =
rewriter.create<memref::CastOp>(op->getLoc(), memrefType, alloc);
}
rewriter.replaceOp(op, alloc);
rewriter.create<memref::CopyOp>(loc, op.getInput(), alloc);
return success();

View File

@ -22,7 +22,7 @@ func.func @conversion_dynamic(%arg0 : memref<?xf32>) -> memref<?xf32> {
}
// CHECK: %[[CONST:.*]] = arith.constant
// CHECK-NEXT: %[[DIM:.*]] = memref.dim %[[ARG:.*]], %[[CONST]]
// CHECK: %[[DIM:.*]] = memref.dim %[[ARG:.*]], %[[CONST]]
// CHECK-NEXT: %[[ALLOC:.*]] = memref.alloc(%[[DIM]])
// CHECK-NEXT: memref.copy %[[ARG]], %[[ALLOC]]
// CHECK-NEXT: memref.dealloc %[[ARG]]
@ -30,13 +30,26 @@ func.func @conversion_dynamic(%arg0 : memref<?xf32>) -> memref<?xf32> {
// -----
// CHECK-LABEL: @conversion_unknown
func.func @conversion_unknown(%arg0 : memref<*xf32>) -> memref<*xf32> {
// expected-error@+1 {{failed to legalize operation 'bufferization.clone' that was explicitly marked illegal}}
%1 = bufferization.clone %arg0 : memref<*xf32> to memref<*xf32>
memref.dealloc %arg0 : memref<*xf32>
return %1 : memref<*xf32>
}
// CHECK: %[[RANK:.*]] = memref.rank %[[ARG:.*]]
// CHECK-NEXT: %[[ALLOCA:.*]] = memref.alloca(%[[RANK]])
// CHECK-NEXT: %[[FOR:.*]] = scf.for
// CHECK-NEXT: %[[DIM:.*]] = memref.dim %[[ARG:.*]] %[[ARG:.*]]
// CHECK-NEXT: memref.store %[[DIM:.*]], %[[ALLOCA:.*]][%[[ARG:.*]]]
// CHECK-NEXT: %[[MUL:.*]] = arith.muli %[[ARG:.*]], %[[DIM:.*]]
// CHECK-NEXT: scf.yield %[[MUL:.*]]
// CHECK: %[[ALLOC:.*]] = memref.alloc(%[[FOR:.*]])
// CHECK-NEXT: %[[RESHAPE:.*]] = memref.reshape %[[ALLOC:.*]]
// CHECK-NEXT: memref.copy
// CHECK-NEXT: memref.dealloc
// CHECK-NEXT: return %[[RESHAPE:.*]]
// -----
// CHECK-LABEL: func @conversion_with_layout_map(