[MemRef] Add dim reification for AssumeAlignmentOp (#174477)
This commit is contained in:
parent
b919d62eae
commit
c85b8ff4d7
@ -149,7 +149,9 @@ def AssumeAlignmentOp : MemRef_Op<"assume_alignment", [
|
||||
Pure,
|
||||
ViewLikeOpInterface,
|
||||
SameOperandsAndResultType,
|
||||
DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>
|
||||
DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
|
||||
DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface,
|
||||
["reifyDimOfResult"]>
|
||||
]> {
|
||||
let summary =
|
||||
"assumption that gives alignment information to the input memref";
|
||||
|
||||
@ -606,6 +606,13 @@ AssumeAlignmentOp::bubbleDownCasts(OpBuilder &builder) {
|
||||
return bubbleDownCastsPassthroughOpImpl(*this, builder, getMemrefMutable());
|
||||
}
|
||||
|
||||
FailureOr<OpFoldResult> AssumeAlignmentOp::reifyDimOfResult(OpBuilder &builder,
|
||||
int resultIndex,
|
||||
int dim) {
|
||||
assert(resultIndex == 0 && "AssumeAlignmentOp has a single result");
|
||||
return getMixedSize(builder, getLoc(), getMemref(), dim);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// DistinctObjectsOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
@ -97,3 +97,38 @@ func.func @iter_to_init_arg_loop_like(
|
||||
}
|
||||
return %result : tensor<?x?xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Test case: Folding of memref.dim(memref.assume_alignment) with static dims
|
||||
// CHECK-LABEL: func @dim_of_assume_alignment_static(
|
||||
// CHECK-SAME: %[[MEM:[0-9a-z]+]]: memref<2x3xf32>
|
||||
// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
|
||||
// CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index
|
||||
// CHECK: return %[[C2]], %[[C3]] : index, index
|
||||
func.func @dim_of_assume_alignment_static(%arg0: memref<2x3xf32>) -> (index, index) {
|
||||
%c0 = arith.constant 0 : index
|
||||
%c1 = arith.constant 1 : index
|
||||
%0 = memref.assume_alignment %arg0, 64 : memref<2x3xf32>
|
||||
%d0 = memref.dim %0, %c0 : memref<2x3xf32>
|
||||
%d1 = memref.dim %0, %c1 : memref<2x3xf32>
|
||||
return %d0, %d1 : index, index
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Test case: Folding of memref.dim(memref.assume_alignment) with dynamic dims
|
||||
// CHECK-LABEL: func @dim_of_assume_alignment_dynamic(
|
||||
// CHECK-SAME: %[[MEM:[0-9a-z]+]]: memref<4x?xf32>
|
||||
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
|
||||
// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
|
||||
// CHECK: %[[D1:.*]] = memref.dim %[[MEM]], %[[C1]]
|
||||
// CHECK: return %[[C4]], %[[D1]] : index, index
|
||||
func.func @dim_of_assume_alignment_dynamic(%arg0: memref<4x?xf32>) -> (index, index) {
|
||||
%c0 = arith.constant 0 : index
|
||||
%c1 = arith.constant 1 : index
|
||||
%0 = memref.assume_alignment %arg0, 64 : memref<4x?xf32>
|
||||
%d0 = memref.dim %0, %c0 : memref<4x?xf32>
|
||||
%d1 = memref.dim %0, %c1 : memref<4x?xf32>
|
||||
return %d0, %d1 : index, index
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user