[mlir][memref] Introduce memref.distinct_objects op (#156913)
The `distinct_objects` operation takes a list of memrefs and returns a list of memrefs of the same types, with the additional assumption that accesses to these memrefs will never alias with each other. This means that loads and stores to different memrefs in the list can be safely reordered. See the RFC discussion: https://discourse.llvm.org/t/rfc-introducing-memref-aliasing-attributes/88049
This commit is contained in:
parent
93c830597c
commit
a374017bbc
@ -155,7 +155,7 @@ def AssumeAlignmentOp : MemRef_Op<"assume_alignment", [
|
||||
The `assume_alignment` operation takes a memref and an integer alignment
|
||||
value. It returns a new SSA value of the same memref type, but associated
|
||||
with the assumption that the underlying buffer is aligned to the given
|
||||
alignment.
|
||||
alignment.
|
||||
|
||||
If the buffer isn't aligned to the given alignment, its result is poison.
|
||||
This operation doesn't affect the semantics of a program where the
|
||||
@ -170,7 +170,7 @@ def AssumeAlignmentOp : MemRef_Op<"assume_alignment", [
|
||||
let assemblyFormat = "$memref `,` $alignment attr-dict `:` type($memref)";
|
||||
let extraClassDeclaration = [{
|
||||
MemRefType getType() { return ::llvm::cast<MemRefType>(getResult().getType()); }
|
||||
|
||||
|
||||
Value getViewSource() { return getMemref(); }
|
||||
}];
|
||||
|
||||
@ -178,6 +178,41 @@ def AssumeAlignmentOp : MemRef_Op<"assume_alignment", [
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// DistinctObjectsOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// Declares `memref.distinct_objects`: a variadic, side-effect-free op whose
// results correspond one-to-one to its memref operands. Only the operand types
// appear in the custom assembly; result types come from InferTypeOpInterface.
def DistinctObjectsOp : MemRef_Op<"distinct_objects", [
      Pure,
      DeclareOpInterfaceMethods<InferTypeOpInterface>
      // TODO: Add ViewLikeOpInterface once it supports more than one operand.
    ]> {
  let summary = "assumption that accesses to specific memrefs will never alias";
  let description = [{
    The `distinct_objects` operation takes a list of memrefs and returns the same
    memrefs, with the additional assumption that accesses to them will never
    alias with each other. This means that loads and stores to different
    memrefs in the list can be safely reordered.

    If the memrefs do alias, the load/store behavior is undefined. This
    operation doesn't affect the semantics of a valid program. It is
    intended for optimization purposes, allowing the compiler to generate more
    efficient code based on the non-aliasing assumption. The optimization is
    best-effort.

    Example:

    ```mlir
    %1, %2 = memref.distinct_objects %a, %b : memref<?xf32>, memref<?xf32>
    ```
  }];
  let arguments = (ins Variadic<AnyMemRef>:$operands);
  let results = (outs Variadic<AnyMemRef>:$results);

  let assemblyFormat = "$operands attr-dict `:` type($operands)";
  // The verifier rejects an empty operand list and operand/result type
  // mismatches (reachable through the generic assembly form).
  let hasVerifier = 1;
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AllocOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
@ -465,6 +465,51 @@ struct AssumeAlignmentOpLowering
|
||||
}
|
||||
};
|
||||
|
||||
struct DistinctObjectsOpLowering
|
||||
: public ConvertOpToLLVMPattern<memref::DistinctObjectsOp> {
|
||||
using ConvertOpToLLVMPattern<
|
||||
memref::DistinctObjectsOp>::ConvertOpToLLVMPattern;
|
||||
explicit DistinctObjectsOpLowering(const LLVMTypeConverter &converter)
|
||||
: ConvertOpToLLVMPattern<memref::DistinctObjectsOp>(converter) {}
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(memref::DistinctObjectsOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
ValueRange operands = adaptor.getOperands();
|
||||
if (operands.size() <= 1) {
|
||||
// Fast path.
|
||||
rewriter.replaceOp(op, operands);
|
||||
return success();
|
||||
}
|
||||
|
||||
Location loc = op.getLoc();
|
||||
SmallVector<Value> ptrs;
|
||||
for (auto [origOperand, newOperand] :
|
||||
llvm::zip_equal(op.getOperands(), operands)) {
|
||||
auto memrefType = cast<MemRefType>(origOperand.getType());
|
||||
MemRefDescriptor memRefDescriptor(newOperand);
|
||||
Value ptr = memRefDescriptor.bufferPtr(rewriter, loc, *getTypeConverter(),
|
||||
memrefType);
|
||||
ptrs.push_back(ptr);
|
||||
}
|
||||
|
||||
auto cond =
|
||||
LLVM::ConstantOp::create(rewriter, loc, rewriter.getI1Type(), 1);
|
||||
// Generate separate_storage assumptions for each pair of pointers.
|
||||
for (auto i : llvm::seq<size_t>(ptrs.size() - 1)) {
|
||||
for (auto j : llvm::seq<size_t>(i + 1, ptrs.size())) {
|
||||
Value ptr1 = ptrs[i];
|
||||
Value ptr2 = ptrs[j];
|
||||
LLVM::AssumeOp::create(rewriter, loc, cond,
|
||||
LLVM::AssumeSeparateStorageTag{}, ptr1, ptr2);
|
||||
}
|
||||
}
|
||||
|
||||
rewriter.replaceOp(op, operands);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
// A `dealloc` is converted into a call to `free` on the underlying data buffer.
|
||||
// The memref descriptor being an SSA value, there is no need to clean it up
|
||||
// in any way.
|
||||
@ -1997,22 +2042,23 @@ void mlir::populateFinalizeMemRefToLLVMConversionPatterns(
|
||||
patterns.add<
|
||||
AllocaOpLowering,
|
||||
AllocaScopeOpLowering,
|
||||
AtomicRMWOpLowering,
|
||||
AssumeAlignmentOpLowering,
|
||||
AtomicRMWOpLowering,
|
||||
ConvertExtractAlignedPointerAsIndex,
|
||||
DimOpLowering,
|
||||
DistinctObjectsOpLowering,
|
||||
ExtractStridedMetadataOpLowering,
|
||||
GenericAtomicRMWOpLowering,
|
||||
GetGlobalMemrefOpLowering,
|
||||
LoadOpLowering,
|
||||
MemRefCastOpLowering,
|
||||
MemorySpaceCastOpLowering,
|
||||
MemRefReinterpretCastOpLowering,
|
||||
MemRefReshapeOpLowering,
|
||||
MemorySpaceCastOpLowering,
|
||||
PrefetchOpLowering,
|
||||
RankOpLowering,
|
||||
ReassociatingReshapeOpConversion<memref::ExpandShapeOp>,
|
||||
ReassociatingReshapeOpConversion<memref::CollapseShapeOp>,
|
||||
ReassociatingReshapeOpConversion<memref::ExpandShapeOp>,
|
||||
StoreOpLowering,
|
||||
SubViewOpLowering,
|
||||
TransposeOpLowering,
|
||||
|
||||
@ -606,6 +606,29 @@ AssumeAlignmentOp::bubbleDownCasts(OpBuilder &builder) {
|
||||
return bubbleDownCastsPassthroughOpImpl(*this, builder, getMemrefMutable());
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// DistinctObjectsOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Checks the op's structural invariants: the result types must equal the
/// operand types, and there must be at least one operand. The type check runs
/// first, so an op with no operands and no results reports the missing-operand
/// error.
LogicalResult DistinctObjectsOp::verify() {
  const bool typesDiffer = getOperandTypes() != getResultTypes();
  if (typesDiffer)
    return emitOpError("operand types and result types must match");

  if (getNumOperands() == 0)
    return emitOpError("expected at least one operand");

  return success();
}
|
||||
|
||||
/// Infers the result types as an exact copy of the operand types, in operand
/// order. All other inputs are ignored and inference always succeeds.
LogicalResult DistinctObjectsOp::inferReturnTypes(
    MLIRContext * /*context*/, std::optional<Location> /*location*/,
    ValueRange operands, DictionaryAttr /*attributes*/,
    OpaqueProperties /*properties*/, RegionRange /*regions*/,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  for (Value operand : operands)
    inferredReturnTypes.push_back(operand.getType());
  return success();
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// CastOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
@ -195,6 +195,36 @@ func.func @assume_alignment(%0 : memref<4x4xf16>) {
|
||||
|
||||
// -----
|
||||
|
||||
// ALL-LABEL: func @distinct_objects
|
||||
// ALL-SAME: (%[[ARG0:.*]]: memref<?xf16>, %[[ARG1:.*]]: memref<?xf32>, %[[ARG2:.*]]: memref<?xf64>)
|
||||
// Three operands of differing element types. The expected lowering extracts
// one pointer per operand and emits one `separate_storage` assumption for each
// unordered pair: (0, 1), (0, 2) and (1, 2).
func.func @distinct_objects(%arg0: memref<?xf16>, %arg1: memref<?xf32>, %arg2: memref<?xf64>) -> (memref<?xf16>, memref<?xf32>, memref<?xf64>) {
|
||||
// ALL-DAG: %[[CAST_0:.*]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref<?xf16> to !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
|
||||
// ALL-DAG: %[[CAST_1:.*]] = builtin.unrealized_conversion_cast %[[ARG1]] : memref<?xf32> to !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
|
||||
// ALL-DAG: %[[CAST_2:.*]] = builtin.unrealized_conversion_cast %[[ARG2]] : memref<?xf64> to !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
|
||||
// ALL: %[[PTR_0:.*]] = llvm.extractvalue %[[CAST_0]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
|
||||
// ALL: %[[PTR_1:.*]] = llvm.extractvalue %[[CAST_1]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
|
||||
// ALL: %[[PTR_2:.*]] = llvm.extractvalue %[[CAST_2]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
|
||||
// ALL: %[[TRUE:.*]] = llvm.mlir.constant(true) : i1
|
||||
// ALL: llvm.intr.assume %[[TRUE]] ["separate_storage"(%[[PTR_0]], %[[PTR_1]] : !llvm.ptr, !llvm.ptr)] : i1
|
||||
// ALL: llvm.intr.assume %[[TRUE]] ["separate_storage"(%[[PTR_0]], %[[PTR_2]] : !llvm.ptr, !llvm.ptr)] : i1
|
||||
// ALL: llvm.intr.assume %[[TRUE]] ["separate_storage"(%[[PTR_1]], %[[PTR_2]] : !llvm.ptr, !llvm.ptr)] : i1
|
||||
%1, %2, %3 = memref.distinct_objects %arg0, %arg1, %arg2 : memref<?xf16>, memref<?xf32>, memref<?xf64>
|
||||
return %1, %2, %3 : memref<?xf16>, memref<?xf32>, memref<?xf64>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// ALL-LABEL: func @distinct_objects_noop
|
||||
// ALL-SAME: (%[[ARG0:.*]]: memref<?xf16>)
|
||||
// Single-operand case: there is no second memref to pair with, so the lowering
// is expected to forward the argument straight to the return with no
// assumption ops in between.
func.func @distinct_objects_noop(%arg0: memref<?xf16>) -> memref<?xf16> {
|
||||
// 1-operand version is noop
|
||||
// ALL-NEXT: return %[[ARG0]]
|
||||
%1 = memref.distinct_objects %arg0 : memref<?xf16>
|
||||
return %1 : memref<?xf16>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @assume_alignment_w_offset
|
||||
// CHECK-INTERFACE-LABEL: func @assume_alignment_w_offset
|
||||
func.func @assume_alignment_w_offset(%0 : memref<4x4xf16, strided<[?, ?], offset: ?>>) {
|
||||
|
||||
@ -1169,3 +1169,19 @@ func.func @expand_shape_invalid_output_shape(
|
||||
into memref<2x15x20xf32, strided<[60000, 4000, 2], offset: 100>>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Uses the generic op form so the result types can be spelled explicitly and
// swapped relative to the operand types; the verifier must reject this.
func.func @distinct_objects_types_mismatch(%arg0: memref<?xf32>, %arg1: memref<?xi32>) -> (memref<?xi32>, memref<?xf32>) {
|
||||
// expected-error @+1 {{operand types and result types must match}}
|
||||
%0, %1 = "memref.distinct_objects"(%arg0, %arg1) : (memref<?xf32>, memref<?xi32>) -> (memref<?xi32>, memref<?xf32>)
|
||||
return %0, %1 : memref<?xi32>, memref<?xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Uses the generic op form with no operands and no results. The (empty) type
// lists match, so the diagnostic comes from the at-least-one-operand check.
func.func @distinct_objects_0_operands() {
|
||||
// expected-error @+1 {{expected at least one operand}}
|
||||
"memref.distinct_objects"() : () -> ()
|
||||
return
|
||||
}
|
||||
|
||||
@ -302,6 +302,15 @@ func.func @assume_alignment(%0: memref<4x4xf16>) {
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @distinct_objects
|
||||
// CHECK-SAME: (%[[ARG0:.*]]: memref<?xf16>, %[[ARG1:.*]]: memref<?xf32>, %[[ARG2:.*]]: memref<?xf64>)
|
||||
// Checks that the custom assembly form with three operands is printed back
// as a single three-result op whose results feed the return in order.
func.func @distinct_objects(%arg0: memref<?xf16>, %arg1: memref<?xf32>, %arg2: memref<?xf64>) -> (memref<?xf16>, memref<?xf32>, memref<?xf64>) {
|
||||
// CHECK: %[[RES:.*]]:3 = memref.distinct_objects %[[ARG0]], %[[ARG1]], %[[ARG2]] : memref<?xf16>, memref<?xf32>, memref<?xf64>
|
||||
%1, %2, %3 = memref.distinct_objects %arg0, %arg1, %arg2 : memref<?xf16>, memref<?xf32>, memref<?xf64>
|
||||
// CHECK: return %[[RES]]#0, %[[RES]]#1, %[[RES]]#2 : memref<?xf16>, memref<?xf32>, memref<?xf64>
|
||||
return %1, %2, %3 : memref<?xf16>, memref<?xf32>, memref<?xf64>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @expand_collapse_shape_static
|
||||
func.func @expand_collapse_shape_static(
|
||||
%arg0: memref<3x4x5xf32>,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user