[MemRef] Add dim reification for AssumeAlignmentOp (#174477)

This commit is contained in:
Jorn Tuyls 2026-01-07 09:30:42 +01:00 committed by GitHub
parent b919d62eae
commit c85b8ff4d7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 45 additions and 1 deletions

View File

@ -149,7 +149,9 @@ def AssumeAlignmentOp : MemRef_Op<"assume_alignment", [
Pure,
ViewLikeOpInterface,
SameOperandsAndResultType,
DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>
DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface,
["reifyDimOfResult"]>
]> {
let summary =
"assumption that gives alignment information to the input memref";

View File

@ -606,6 +606,13 @@ AssumeAlignmentOp::bubbleDownCasts(OpBuilder &builder) {
return bubbleDownCastsPassthroughOpImpl(*this, builder, getMemrefMutable());
}
FailureOr<OpFoldResult> AssumeAlignmentOp::reifyDimOfResult(OpBuilder &builder,
int resultIndex,
int dim) {
assert(resultIndex == 0 && "AssumeAlignmentOp has a single result");
return getMixedSize(builder, getLoc(), getMemref(), dim);
}
//===----------------------------------------------------------------------===//
// DistinctObjectsOp
//===----------------------------------------------------------------------===//

View File

@ -97,3 +97,38 @@ func.func @iter_to_init_arg_loop_like(
}
return %result : tensor<?x?xf32>
}
// -----
// Test case: Folding of memref.dim(memref.assume_alignment) with static dims
// CHECK-LABEL: func @dim_of_assume_alignment_static(
// CHECK-SAME: %[[MEM:[0-9a-z]+]]: memref<2x3xf32>
// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
// CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index
// CHECK: return %[[C2]], %[[C3]] : index, index
func.func @dim_of_assume_alignment_static(%arg0: memref<2x3xf32>) -> (index, index) {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%0 = memref.assume_alignment %arg0, 64 : memref<2x3xf32>
%d0 = memref.dim %0, %c0 : memref<2x3xf32>
%d1 = memref.dim %0, %c1 : memref<2x3xf32>
return %d0, %d1 : index, index
}
// -----
// Test case: Folding of memref.dim(memref.assume_alignment) with dynamic dims
// CHECK-LABEL: func @dim_of_assume_alignment_dynamic(
// CHECK-SAME: %[[MEM:[0-9a-z]+]]: memref<4x?xf32>
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
// CHECK: %[[D1:.*]] = memref.dim %[[MEM]], %[[C1]]
// CHECK: return %[[C4]], %[[D1]] : index, index
func.func @dim_of_assume_alignment_dynamic(%arg0: memref<4x?xf32>) -> (index, index) {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%0 = memref.assume_alignment %arg0, 64 : memref<4x?xf32>
%d0 = memref.dim %0, %c0 : memref<4x?xf32>
%d1 = memref.dim %0, %c1 : memref<4x?xf32>
return %d0, %d1 : index, index
}