[mlir][MemRef] Canonicalize reinterpret_cast(extract_strided_metadata)
Add a canonicalizetion step for reinterpret_cast(extract_strided_metadata). This step replaces this sequence of operations by either: - A noop, i.e., the original memref is directly used, or - A plain cast of the original memref The choice is ultimately made based on whether the original memref type is equal to what the reinterpret_cast iss producing. For instance, the reinterpret_cast could be changing some dimensions from static to dynamic and in such case, we need to keep a cast. The transformation is currently only performed when the reinterpret_cast uses exactly the same arguments as what the extract_strided_metadata produces. It may be possible to be more aggressive here but I wanted to start with a relatively simple MLIR patch for my first one! Differential Revision: https://reviews.llvm.org/D132776
This commit is contained in:
parent
9af0a142e4
commit
ba916c0cf6
@ -1142,6 +1142,7 @@ def MemRef_ReinterpretCastOp
|
||||
}];
|
||||
|
||||
let hasFolder = 1;
|
||||
let hasCanonicalizer = 1;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -1600,6 +1600,65 @@ OpFoldResult ReinterpretCastOp::fold(ArrayRef<Attribute> /*operands*/) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
namespace {
|
||||
/// Replace reinterpret_cast(extract_strided_metadata memref) -> memref.
|
||||
struct ReinterpretCastOpExtractStridedMetadataFolder
|
||||
: public OpRewritePattern<ReinterpretCastOp> {
|
||||
public:
|
||||
using OpRewritePattern<ReinterpretCastOp>::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(ReinterpretCastOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto extractStridedMetadata =
|
||||
op.getSource().getDefiningOp<ExtractStridedMetadataOp>();
|
||||
if (!extractStridedMetadata)
|
||||
return failure();
|
||||
// Check if the reinterpret cast reconstructs a memref with the exact same
|
||||
// properties as the extract strided metadata.
|
||||
|
||||
// First, check that the strides are the same.
|
||||
if (extractStridedMetadata.getStrides().size() != op.getStrides().size())
|
||||
return failure();
|
||||
for (auto [extractStride, reinterpretStride] :
|
||||
llvm::zip(extractStridedMetadata.getStrides(), op.getStrides()))
|
||||
if (extractStride != reinterpretStride)
|
||||
return failure();
|
||||
|
||||
// Second, check the sizes.
|
||||
if (extractStridedMetadata.getSizes().size() != op.getSizes().size())
|
||||
return failure();
|
||||
for (auto [extractSize, reinterpretSize] :
|
||||
llvm::zip(extractStridedMetadata.getSizes(), op.getSizes()))
|
||||
if (extractSize != reinterpretSize)
|
||||
return failure();
|
||||
|
||||
// Finally, check the offset.
|
||||
if (op.getOffsets().size() != 1 &&
|
||||
extractStridedMetadata.getOffset() != *op.getOffsets().begin())
|
||||
return failure();
|
||||
|
||||
// At this point, we know that the back and forth between extract strided
|
||||
// metadata and reinterpret cast is a noop. However, the final type of the
|
||||
// reinterpret cast may not be exactly the same as the original memref.
|
||||
// E.g., it could be changing a dimension from static to dynamic. Check that
|
||||
// here and add a cast if necessary.
|
||||
Type srcTy = extractStridedMetadata.getSource().getType();
|
||||
if (srcTy == op.getResult().getType())
|
||||
rewriter.replaceOp(op, extractStridedMetadata.getSource());
|
||||
else
|
||||
rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(),
|
||||
extractStridedMetadata.getSource());
|
||||
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
void ReinterpretCastOp::getCanonicalizationPatterns(RewritePatternSet &results,
|
||||
MLIRContext *context) {
|
||||
results.add<ReinterpretCastOpExtractStridedMetadataFolder>(context);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Reassociative reshape ops
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -740,6 +740,63 @@ func.func @reinterpret_of_subview(%arg : memref<?xi8>, %size1: index, %size2: in
|
||||
|
||||
// -----
|
||||
|
||||
// Check that a reinterpret cast of an equivalent extract strided metadata
|
||||
// is canonicalized to a plain cast when the destination type is different
|
||||
// than the type of the original memref.
|
||||
// CHECK-LABEL: func @reinterpret_of_extract_strided_metadata_w_type_mistach
|
||||
// CHECK-SAME: (%[[ARG:.*]]: memref<8x2xf32>)
|
||||
// CHECK: %[[CAST:.*]] = memref.cast %[[ARG]] : memref<8x2xf32> to memref<?x?xf32,
|
||||
// CHECK: return %[[CAST]]
|
||||
func.func @reinterpret_of_extract_strided_metadata_w_type_mistach(%arg0 : memref<8x2xf32>) -> memref<?x?xf32, offset : ?, strides : [?, ?]> {
|
||||
%base, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %arg0 : memref<8x2xf32> -> memref<f32>, index, index, index, index, index
|
||||
%m2 = memref.reinterpret_cast %base to offset: [%offset], sizes: [%sizes#0, %sizes#1], strides: [%strides#0, %strides#1] : memref<f32> to memref<?x?xf32, offset: ?, strides: [?, ?]>
|
||||
return %m2 : memref<?x?xf32, offset: ?, strides: [?, ?]>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Check that a reinterpret cast of an equivalent extract strided metadata
|
||||
// is completely removed when the original memref has the same type.
|
||||
// CHECK-LABEL: func @reinterpret_of_extract_strided_metadata_same_type
|
||||
// CHECK-SAME: (%[[ARG:.*]]: memref<8x2xf32>)
|
||||
// CHECK: return %[[ARG]]
|
||||
func.func @reinterpret_of_extract_strided_metadata_same_type(%arg0 : memref<8x2xf32>) -> memref<8x2xf32> {
|
||||
%base, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %arg0 : memref<8x2xf32> -> memref<f32>, index, index, index, index, index
|
||||
%m2 = memref.reinterpret_cast %base to offset: [%offset], sizes: [%sizes#0, %sizes#1], strides: [%strides#0, %strides#1] : memref<f32> to memref<8x2xf32>
|
||||
return %m2 : memref<8x2xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Check that we don't simplify reinterpret cast of extract strided metadata
|
||||
// when the strides don't match.
|
||||
// CHECK-LABEL: func @reinterpret_of_extract_strided_metadata_w_different_stride
|
||||
// CHECK-SAME: (%[[ARG:.*]]: memref<8x2xf32>)
|
||||
// CHECK: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG]]
|
||||
// CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[OFFSET]]], sizes: [4, 2, 2], strides: [1, 1, %[[STRIDES]]#1]
|
||||
// CHECK: return %[[RES]]
|
||||
func.func @reinterpret_of_extract_strided_metadata_w_different_stride(%arg0 : memref<8x2xf32>) -> memref<?x?x?xf32, offset : ?, strides : [?, ?, ?]> {
|
||||
%base, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %arg0 : memref<8x2xf32> -> memref<f32>, index, index, index, index, index
|
||||
%m2 = memref.reinterpret_cast %base to offset: [%offset], sizes: [4, 2, 2], strides: [1, 1, %strides#1] : memref<f32> to memref<?x?x?xf32, offset: ?, strides: [?, ?, ?]>
|
||||
return %m2 : memref<?x?x?xf32, offset: ?, strides: [?, ?, ?]>
|
||||
}
|
||||
// -----
|
||||
|
||||
// Check that we don't simplify reinterpret cast of extract strided metadata
|
||||
// when the offset doesn't match.
|
||||
// CHECK-LABEL: func @reinterpret_of_extract_strided_metadata_w_different_offset
|
||||
// CHECK-SAME: (%[[ARG:.*]]: memref<8x2xf32>)
|
||||
// CHECK: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG]]
|
||||
// CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [1], sizes: [%[[SIZES]]#0, %[[SIZES]]#1], strides: [%[[STRIDES]]#0, %[[STRIDES]]#1]
|
||||
// CHECK: return %[[RES]]
|
||||
func.func @reinterpret_of_extract_strided_metadata_w_different_offset(%arg0 : memref<8x2xf32>) -> memref<?x?xf32, offset : ?, strides : [?, ?]> {
|
||||
%base, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %arg0 : memref<8x2xf32> -> memref<f32>, index, index, index, index, index
|
||||
%m2 = memref.reinterpret_cast %base to offset: [1], sizes: [%sizes#0, %sizes#1], strides: [%strides#0, %strides#1] : memref<f32> to memref<?x?xf32, offset: ?, strides: [?, ?]>
|
||||
return %m2 : memref<?x?xf32, offset: ?, strides: [?, ?]>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func.func @canonicalize_rank_reduced_subview(%arg0 : memref<8x?xf32>,
|
||||
%arg1 : index) -> memref<?xf32, offset : ?, strides : [?]> {
|
||||
%c0 = arith.constant 0 : index
|
||||
|
Loading…
x
Reference in New Issue
Block a user