[mlir][memref] Introduce memref.distinct_objects op (#156913)

The `distinct_objects` operation takes a list of memrefs and returns a
list of memrefs of the same types, with the additional assumption that
accesses to these memrefs will never alias with each other. This means
that loads and stores to different memrefs in the list can be safely
reordered.

The discussion
https://discourse.llvm.org/t/rfc-introducing-memref-aliasing-attributes/88049
This commit is contained in:
Ivan Butygin 2025-10-01 15:01:37 +03:00 committed by GitHub
parent 93c830597c
commit a374017bbc
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 164 additions and 5 deletions

View File

@ -155,7 +155,7 @@ def AssumeAlignmentOp : MemRef_Op<"assume_alignment", [
The `assume_alignment` operation takes a memref and an integer alignment
value. It returns a new SSA value of the same memref type, but associated
with the assumption that the underlying buffer is aligned to the given
alignment.
alignment.
If the buffer isn't aligned to the given alignment, its result is poison.
This operation doesn't affect the semantics of a program where the
@ -170,7 +170,7 @@ def AssumeAlignmentOp : MemRef_Op<"assume_alignment", [
let assemblyFormat = "$memref `,` $alignment attr-dict `:` type($memref)";
let extraClassDeclaration = [{
MemRefType getType() { return ::llvm::cast<MemRefType>(getResult().getType()); }
Value getViewSource() { return getMemref(); }
}];
@ -178,6 +178,41 @@ def AssumeAlignmentOp : MemRef_Op<"assume_alignment", [
let hasFolder = 1;
}
//===----------------------------------------------------------------------===//
// DistinctObjectsOp
//===----------------------------------------------------------------------===//
def DistinctObjectsOp : MemRef_Op<"distinct_objects", [
Pure,
DeclareOpInterfaceMethods<InferTypeOpInterface>
// ViewLikeOpInterface TODO: ViewLikeOpInterface only supports a single argument
]> {
let summary = "assumption that acesses to specific memrefs will never alias";
let description = [{
The `distinct_objects` operation takes a list of memrefs and returns the same
memrefs, with the additional assumption that accesses to them will never
alias with each other. This means that loads and stores to different
memrefs in the list can be safely reordered.
If the memrefs do alias, the load/store behavior is undefined. This
operation doesn't affect the semantics of a valid program. It is
intended for optimization purposes, allowing the compiler to generate more
efficient code based on the non-aliasing assumption. The optimization is
best-effort.
Example:
```mlir
%1, %2 = memref.distinct_objects %a, %b : memref<?xf32>, memref<?xf32>
```
}];
let arguments = (ins Variadic<AnyMemRef>:$operands);
let results = (outs Variadic<AnyMemRef>:$results);
let assemblyFormat = "$operands attr-dict `:` type($operands)";
let hasVerifier = 1;
}
//===----------------------------------------------------------------------===//
// AllocOp
//===----------------------------------------------------------------------===//

View File

@ -465,6 +465,51 @@ struct AssumeAlignmentOpLowering
}
};
struct DistinctObjectsOpLowering
: public ConvertOpToLLVMPattern<memref::DistinctObjectsOp> {
using ConvertOpToLLVMPattern<
memref::DistinctObjectsOp>::ConvertOpToLLVMPattern;
explicit DistinctObjectsOpLowering(const LLVMTypeConverter &converter)
: ConvertOpToLLVMPattern<memref::DistinctObjectsOp>(converter) {}
LogicalResult
matchAndRewrite(memref::DistinctObjectsOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
ValueRange operands = adaptor.getOperands();
if (operands.size() <= 1) {
// Fast path.
rewriter.replaceOp(op, operands);
return success();
}
Location loc = op.getLoc();
SmallVector<Value> ptrs;
for (auto [origOperand, newOperand] :
llvm::zip_equal(op.getOperands(), operands)) {
auto memrefType = cast<MemRefType>(origOperand.getType());
MemRefDescriptor memRefDescriptor(newOperand);
Value ptr = memRefDescriptor.bufferPtr(rewriter, loc, *getTypeConverter(),
memrefType);
ptrs.push_back(ptr);
}
auto cond =
LLVM::ConstantOp::create(rewriter, loc, rewriter.getI1Type(), 1);
// Generate separate_storage assumptions for each pair of pointers.
for (auto i : llvm::seq<size_t>(ptrs.size() - 1)) {
for (auto j : llvm::seq<size_t>(i + 1, ptrs.size())) {
Value ptr1 = ptrs[i];
Value ptr2 = ptrs[j];
LLVM::AssumeOp::create(rewriter, loc, cond,
LLVM::AssumeSeparateStorageTag{}, ptr1, ptr2);
}
}
rewriter.replaceOp(op, operands);
return success();
}
};
// A `dealloc` is converted into a call to `free` on the underlying data buffer.
// The memref descriptor being an SSA value, there is no need to clean it up
// in any way.
@ -1997,22 +2042,23 @@ void mlir::populateFinalizeMemRefToLLVMConversionPatterns(
patterns.add<
AllocaOpLowering,
AllocaScopeOpLowering,
AtomicRMWOpLowering,
AssumeAlignmentOpLowering,
AtomicRMWOpLowering,
ConvertExtractAlignedPointerAsIndex,
DimOpLowering,
DistinctObjectsOpLowering,
ExtractStridedMetadataOpLowering,
GenericAtomicRMWOpLowering,
GetGlobalMemrefOpLowering,
LoadOpLowering,
MemRefCastOpLowering,
MemorySpaceCastOpLowering,
MemRefReinterpretCastOpLowering,
MemRefReshapeOpLowering,
MemorySpaceCastOpLowering,
PrefetchOpLowering,
RankOpLowering,
ReassociatingReshapeOpConversion<memref::ExpandShapeOp>,
ReassociatingReshapeOpConversion<memref::CollapseShapeOp>,
ReassociatingReshapeOpConversion<memref::ExpandShapeOp>,
StoreOpLowering,
SubViewOpLowering,
TransposeOpLowering,

View File

@ -606,6 +606,29 @@ AssumeAlignmentOp::bubbleDownCasts(OpBuilder &builder) {
return bubbleDownCastsPassthroughOpImpl(*this, builder, getMemrefMutable());
}
//===----------------------------------------------------------------------===//
// DistinctObjectsOp
//===----------------------------------------------------------------------===//
LogicalResult DistinctObjectsOp::verify() {
if (getOperandTypes() != getResultTypes())
return emitOpError("operand types and result types must match");
if (getOperandTypes().empty())
return emitOpError("expected at least one operand");
return success();
}
LogicalResult DistinctObjectsOp::inferReturnTypes(
MLIRContext * /*context*/, std::optional<Location> /*location*/,
ValueRange operands, DictionaryAttr /*attributes*/,
OpaqueProperties /*properties*/, RegionRange /*regions*/,
SmallVectorImpl<Type> &inferredReturnTypes) {
llvm::copy(operands.getTypes(), std::back_inserter(inferredReturnTypes));
return success();
}
//===----------------------------------------------------------------------===//
// CastOp
//===----------------------------------------------------------------------===//

View File

@ -195,6 +195,36 @@ func.func @assume_alignment(%0 : memref<4x4xf16>) {
// -----
// ALL-LABEL: func @distinct_objects
// ALL-SAME: (%[[ARG0:.*]]: memref<?xf16>, %[[ARG1:.*]]: memref<?xf32>, %[[ARG2:.*]]: memref<?xf64>)
func.func @distinct_objects(%arg0: memref<?xf16>, %arg1: memref<?xf32>, %arg2: memref<?xf64>) -> (memref<?xf16>, memref<?xf32>, memref<?xf64>) {
// ALL-DAG: %[[CAST_0:.*]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref<?xf16> to !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
// ALL-DAG: %[[CAST_1:.*]] = builtin.unrealized_conversion_cast %[[ARG1]] : memref<?xf32> to !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
// ALL-DAG: %[[CAST_2:.*]] = builtin.unrealized_conversion_cast %[[ARG2]] : memref<?xf64> to !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
// ALL: %[[PTR_0:.*]] = llvm.extractvalue %[[CAST_0]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
// ALL: %[[PTR_1:.*]] = llvm.extractvalue %[[CAST_1]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
// ALL: %[[PTR_2:.*]] = llvm.extractvalue %[[CAST_2]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
// ALL: %[[TRUE:.*]] = llvm.mlir.constant(true) : i1
// ALL: llvm.intr.assume %[[TRUE]] ["separate_storage"(%[[PTR_0]], %[[PTR_1]] : !llvm.ptr, !llvm.ptr)] : i1
// ALL: llvm.intr.assume %[[TRUE]] ["separate_storage"(%[[PTR_0]], %[[PTR_2]] : !llvm.ptr, !llvm.ptr)] : i1
// ALL: llvm.intr.assume %[[TRUE]] ["separate_storage"(%[[PTR_1]], %[[PTR_2]] : !llvm.ptr, !llvm.ptr)] : i1
%1, %2, %3 = memref.distinct_objects %arg0, %arg1, %arg2 : memref<?xf16>, memref<?xf32>, memref<?xf64>
return %1, %2, %3 : memref<?xf16>, memref<?xf32>, memref<?xf64>
}
// -----
// ALL-LABEL: func @distinct_objects_noop
// ALL-SAME: (%[[ARG0:.*]]: memref<?xf16>)
func.func @distinct_objects_noop(%arg0: memref<?xf16>) -> memref<?xf16> {
// 1-operand version is noop
// ALL-NEXT: return %[[ARG0]]
%1 = memref.distinct_objects %arg0 : memref<?xf16>
return %1 : memref<?xf16>
}
// -----
// CHECK-LABEL: func @assume_alignment_w_offset
// CHECK-INTERFACE-LABEL: func @assume_alignment_w_offset
func.func @assume_alignment_w_offset(%0 : memref<4x4xf16, strided<[?, ?], offset: ?>>) {

View File

@ -1169,3 +1169,19 @@ func.func @expand_shape_invalid_output_shape(
into memref<2x15x20xf32, strided<[60000, 4000, 2], offset: 100>>
return
}
// -----
func.func @distinct_objects_types_mismatch(%arg0: memref<?xf32>, %arg1: memref<?xi32>) -> (memref<?xi32>, memref<?xf32>) {
// expected-error @+1 {{operand types and result types must match}}
%0, %1 = "memref.distinct_objects"(%arg0, %arg1) : (memref<?xf32>, memref<?xi32>) -> (memref<?xi32>, memref<?xf32>)
return %0, %1 : memref<?xi32>, memref<?xf32>
}
// -----
func.func @distinct_objects_0_operands() {
// expected-error @+1 {{expected at least one operand}}
"memref.distinct_objects"() : () -> ()
return
}

View File

@ -302,6 +302,15 @@ func.func @assume_alignment(%0: memref<4x4xf16>) {
return
}
// CHECK-LABEL: func @distinct_objects
// CHECK-SAME: (%[[ARG0:.*]]: memref<?xf16>, %[[ARG1:.*]]: memref<?xf32>, %[[ARG2:.*]]: memref<?xf64>)
func.func @distinct_objects(%arg0: memref<?xf16>, %arg1: memref<?xf32>, %arg2: memref<?xf64>) -> (memref<?xf16>, memref<?xf32>, memref<?xf64>) {
// CHECK: %[[RES:.*]]:3 = memref.distinct_objects %[[ARG0]], %[[ARG1]], %[[ARG2]] : memref<?xf16>, memref<?xf32>, memref<?xf64>
%1, %2, %3 = memref.distinct_objects %arg0, %arg1, %arg2 : memref<?xf16>, memref<?xf32>, memref<?xf64>
// CHECK: return %[[RES]]#0, %[[RES]]#1, %[[RES]]#2 : memref<?xf16>, memref<?xf32>, memref<?xf64>
return %1, %2, %3 : memref<?xf16>, memref<?xf32>, memref<?xf64>
}
// CHECK-LABEL: func @expand_collapse_shape_static
func.func @expand_collapse_shape_static(
%arg0: memref<3x4x5xf32>,