[mlir][AMDGPU] Update gather_to_lds with explicit-async support (#181082)
This commit takes advantage of the new `load.async.to.lds` intrinsic to add an `async` mode to `gather_to_lds`. In this mode, completion of the load must be managed explicitly with the `asyncmark` and `wait.asyncmark` intrinsics instead of being implicitly derived by alias analysis. This commit adds the flag and a lowering for it, and updates the tests. Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
parent
e17de2ec23
commit
149fa17adf
@ -1079,7 +1079,8 @@ def AMDGPU_GatherToLDSOp :
|
||||
Variadic<Index>:$srcIndices,
|
||||
Arg<AnyMemRef, "buffer to write to", [MemWrite]>:$dst,
|
||||
Variadic<Index>:$dstIndices,
|
||||
TypeAttr:$transferType
|
||||
TypeAttr:$transferType,
|
||||
UnitAttr:$async
|
||||
)>,
|
||||
Results<(outs)> {
|
||||
let summary = "MLIR wrapper for CDNA Gather to LDS instructions";
|
||||
@ -1099,6 +1100,13 @@ def AMDGPU_GatherToLDSOp :
|
||||
* `$transferType`: type of the data to be transferred by each thread. This is used to determine
|
||||
the size of the data to be transferred and the number of threads in the subgroup.
|
||||
The transfer type must be a scalar type or a vector type with a single element type.
|
||||
* If `$async` is set, the compiler will not attempt to infer the
|
||||
memory waits needed to ensure that the DMA operation has completed
|
||||
before a load that might access the stored-to LDS is performed.
|
||||
Instead, the `rocdl.asyncmark` and `rocdl.wait.asyncmark N`
|
||||
operations must be used to explicitly indicate the desired completion
|
||||
behavior. This enables more precise calculation of these waits at the
|
||||
cost of requiring user management of asynchrony.
|
||||
|
||||
The `$dst`, along with its indices, points to the memory location the subgroup of this thread
|
||||
will write to.
|
||||
@ -1106,7 +1114,7 @@ def AMDGPU_GatherToLDSOp :
|
||||
Note: only supported on gfx9 and gfx10.
|
||||
}];
|
||||
let assemblyFormat = [{
|
||||
$src `[` $srcIndices `]` `,` $dst `[` $dstIndices `]` attr-dict `:` $transferType `,` type($src) `,` type($dst)
|
||||
(`async` $async^)? $src `[` $srcIndices `]` `,` $dst `[` $dstIndices `]` attr-dict `:` $transferType `,` type($src) `,` type($dst)
|
||||
}];
|
||||
let hasVerifier = 1;
|
||||
let hasCanonicalizer = 1;
|
||||
|
||||
@ -1817,11 +1817,19 @@ struct GatherToLDSOpLowering : public ConvertOpToLLVMPattern<GatherToLDSOp> {
|
||||
getStridedElementPtr(rewriter, loc, dstMemRefType, adaptor.getDst(),
|
||||
(adaptor.getDstIndices()));
|
||||
|
||||
rewriter.replaceOpWithNewOp<ROCDL::LoadToLDSOp>(
|
||||
op, srcPtr, dstPtr, rewriter.getI32IntegerAttr(loadWidth),
|
||||
/*offset=*/rewriter.getI32IntegerAttr(0),
|
||||
/*aux=*/rewriter.getI32IntegerAttr(0), ArrayAttr{}, ArrayAttr{},
|
||||
ArrayAttr{});
|
||||
if (op.getAsync()) {
|
||||
rewriter.replaceOpWithNewOp<ROCDL::LoadAsyncToLDSOp>(
|
||||
op, srcPtr, dstPtr, rewriter.getI32IntegerAttr(loadWidth),
|
||||
/*offset=*/rewriter.getI32IntegerAttr(0),
|
||||
/*aux=*/rewriter.getI32IntegerAttr(0), ArrayAttr{}, ArrayAttr{},
|
||||
ArrayAttr{});
|
||||
} else {
|
||||
rewriter.replaceOpWithNewOp<ROCDL::LoadToLDSOp>(
|
||||
op, srcPtr, dstPtr, rewriter.getI32IntegerAttr(loadWidth),
|
||||
/*offset=*/rewriter.getI32IntegerAttr(0),
|
||||
/*aux=*/rewriter.getI32IntegerAttr(0), ArrayAttr{}, ArrayAttr{},
|
||||
ArrayAttr{});
|
||||
}
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
@ -233,3 +233,79 @@ func.func @fat_buffer_load_to_rocdl_f32(%global : memref<128x72xf32, #amdgpu.add
|
||||
: f32, memref<128x72xf32, #amdgpu.address_space<fat_raw_buffer>>, memref<64x64xf32, #gpu.address_space<workgroup>>
|
||||
func.return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @global_load_to_rocdl_async_f32
|
||||
// CHECK-SAME: (%[[ARG0:.*]]: memref<128x72xf32, #gpu.address_space<global>>)
|
||||
func.func @global_load_to_rocdl_async_f32(%global : memref<128x72xf32, #gpu.address_space<global>>) {
|
||||
%c0 = arith.constant 0 : index
|
||||
%c12 = arith.constant 12 : index
|
||||
%c32 = arith.constant 32 : index
|
||||
%alloc = memref.alloc() : memref<64x64xf32, #gpu.address_space<workgroup>>
|
||||
// CHECK: %[[GLOBAL_DESC:.*]] = builtin.unrealized_conversion_cast %[[ARG0]]
|
||||
|
||||
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
|
||||
// CHECK-DAG: %[[IC0:.*]] = builtin.unrealized_conversion_cast %c0 : index to i64
|
||||
// CHECK-DAG: %[[C12:.*]] = arith.constant 12 : index
|
||||
// CHECK-DAG: %[[IC12:.*]] = builtin.unrealized_conversion_cast %[[C12]]
|
||||
// CHECK-DAG: %[[C32:.*]] = arith.constant 32 : index
|
||||
// CHECK-DAG: %[[IC32:.*]] = builtin.unrealized_conversion_cast %[[C32]]
|
||||
|
||||
// CHECK: %[[ALLOC:.*]] = memref.alloc()
|
||||
// CHECK: %[[LDS_DESC:.*]] = builtin.unrealized_conversion_cast
|
||||
// CHECK: %[[GLOBAL_BASE:.*]] = llvm.extractvalue %[[GLOBAL_DESC]][1]
|
||||
|
||||
// CHECK: %[[C72:.*]] = llvm.mlir.constant(72 : index) : i64
|
||||
// CHECK: %[[MUL:.*]] = llvm.mul %[[IC12]], %[[C72]] : i64
|
||||
// CHECK: %[[SRC_OFFSET:.*]] = llvm.add %[[MUL]], %[[IC0]] : i64
|
||||
|
||||
// CHECK: %[[GLOBAL_PTR:.*]] = llvm.getelementptr %[[GLOBAL_BASE]][%[[SRC_OFFSET]]]
|
||||
// CHECK: %[[LDS_BASE:.*]] = llvm.extractvalue %[[LDS_DESC]][1]
|
||||
|
||||
// CHECK: %[[C64:.*]] = llvm.mlir.constant(64 : index) : i64
|
||||
// CHECK: %[[MUL_2:.*]] = llvm.mul %[[IC32]], %[[C64]] : i64
|
||||
// CHECK: %[[DST_OFFSET:.*]] = llvm.add %[[MUL_2]], %[[IC0]] : i64
|
||||
|
||||
// CHECK: %[[LDS_PTR:.*]] = llvm.getelementptr %[[LDS_BASE]][%[[DST_OFFSET]]]
|
||||
// CHECK: rocdl.load.async.to.lds %[[GLOBAL_PTR]], %[[LDS_PTR]], 4
|
||||
amdgpu.gather_to_lds async %global[%c12, %c0], %alloc[%c32, %c0]
|
||||
: f32, memref<128x72xf32, #gpu.address_space<global>>, memref<64x64xf32, #gpu.address_space<workgroup>>
|
||||
func.return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @global_load_to_rocdl_async_f32_fat_raw_buffer
|
||||
// CHECK-SAME: (%[[ARG0:.*]]: memref<128x72xf32, #amdgpu.address_space<fat_raw_buffer>>)
|
||||
func.func @global_load_to_rocdl_async_f32_fat_raw_buffer(%global : memref<128x72xf32, #amdgpu.address_space<fat_raw_buffer>>) {
|
||||
%c0 = arith.constant 0 : index
|
||||
%c12 = arith.constant 12 : index
|
||||
%c32 = arith.constant 32 : index
|
||||
%alloc = memref.alloc() : memref<64x64xf32, #gpu.address_space<workgroup>>
|
||||
// CHECK: %[[GLOBAL_DESC:.*]] = builtin.unrealized_conversion_cast %[[ARG0]]
|
||||
|
||||
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
|
||||
// CHECK-DAG: %[[IC0:.*]] = builtin.unrealized_conversion_cast %c0 : index to i64
|
||||
// CHECK-DAG: %[[C12:.*]] = arith.constant 12 : index
|
||||
// CHECK-DAG: %[[IC12:.*]] = builtin.unrealized_conversion_cast %[[C12]]
|
||||
// CHECK-DAG: %[[C32:.*]] = arith.constant 32 : index
|
||||
// CHECK-DAG: %[[IC32:.*]] = builtin.unrealized_conversion_cast %[[C32]]
|
||||
|
||||
// CHECK: %[[ALLOC:.*]] = memref.alloc()
|
||||
// CHECK: %[[LDS_DESC:.*]] = builtin.unrealized_conversion_cast
|
||||
// CHECK: %[[GLOBAL_BASE:.*]] = llvm.extractvalue %[[GLOBAL_DESC]][1]
|
||||
|
||||
// CHECK: %[[C72:.*]] = llvm.mlir.constant(72 : index) : i64
|
||||
// CHECK: %[[MUL:.*]] = llvm.mul %[[IC12]], %[[C72]] : i64
|
||||
// CHECK: %[[SRC_OFFSET:.*]] = llvm.add %[[MUL]], %[[IC0]] : i64
|
||||
|
||||
// CHECK: %[[GLOBAL_PTR:.*]] = llvm.getelementptr %[[GLOBAL_BASE]][%[[SRC_OFFSET]]]
|
||||
// CHECK: %[[LDS_BASE:.*]] = llvm.extractvalue %[[LDS_DESC]][1]
|
||||
|
||||
// CHECK: %[[C64:.*]] = llvm.mlir.constant(64 : index) : i64
|
||||
// CHECK: %[[MUL_2:.*]] = llvm.mul %[[IC32]], %[[C64]] : i64
|
||||
// CHECK: %[[DST_OFFSET:.*]] = llvm.add %[[MUL_2]], %[[IC0]] : i64
|
||||
|
||||
// CHECK: %[[LDS_PTR:.*]] = llvm.getelementptr %[[LDS_BASE]][%[[DST_OFFSET]]]
|
||||
// CHECK: rocdl.load.async.to.lds %[[GLOBAL_PTR]], %[[LDS_PTR]], 4
|
||||
amdgpu.gather_to_lds async %global[%c12, %c0], %alloc[%c32, %c0]
|
||||
: f32, memref<128x72xf32, #amdgpu.address_space<fat_raw_buffer>>, memref<64x64xf32, #gpu.address_space<workgroup>>
|
||||
func.return
|
||||
}
|
||||
|
||||
@ -658,21 +658,26 @@ func.func @transpose_load(%idx1 : index, %idx2 : index, %mem : memref<128x32xf16
|
||||
|
||||
// CHECK-LABEL: func @gather_to_lds
|
||||
func.func @gather_to_lds(%idx1 : index, %idx2 : index, %mem1 : memref<32xf16>, %mem2 : memref<32x32xf16>, %smem1 : memref<32xf16, #gpu.address_space<workgroup>>, %smem2 : memref<32x32xf16, #gpu.address_space<workgroup>>, %smem3 : memref<?x?xf16, strided<[?, 1]>, #gpu.address_space<workgroup>>) {
|
||||
// CHECK: amdgpu.gather_to_lds %{{.*}}[%{{.*}}, %{{.*}}], %{{.*}}[%{{.*}}, %{{.*}}]
|
||||
// CHECK: amdgpu.gather_to_lds async %{{.*}}[%{{.*}}, %{{.*}}], %{{.*}}[%{{.*}}, %{{.*}}]
|
||||
// CHECK: amdgpu.gather_to_lds async %{{.*}}[%{{.*}}], %{{.*}}[%{{.*}}, %{{.*}}]
|
||||
// CHECK: rocdl.asyncmark
|
||||
// CHECK: amdgpu.gather_to_lds %{{.*}}[%{{.*}}, %{{.*}}], %{{.*}}[%{{.*}}]
|
||||
// CHECK: amdgpu.gather_to_lds %{{.*}}[%{{.*}}], %{{.*}}[%{{.*}}, %{{.*}}]
|
||||
// CHECK: amdgpu.gather_to_lds %{{.*}}[%{{.*}}], %{{.*}}[%{{.*}}, %{{.*}}]
|
||||
amdgpu.gather_to_lds %mem2[%idx1, %idx2], %smem2[%idx1, %idx2] : vector<2xf16>, memref<32x32xf16>, memref<32x32xf16, #gpu.address_space<workgroup>>
|
||||
amdgpu.gather_to_lds %mem2[%idx1, %idx2], %smem1[%idx1] : vector<2xf16>, memref<32x32xf16>, memref<32xf16, #gpu.address_space<workgroup>>
|
||||
amdgpu.gather_to_lds %mem1[%idx1], %smem2[%idx1, %idx2] : vector<2xf16>, memref<32xf16>, memref<32x32xf16, #gpu.address_space<workgroup>>
|
||||
amdgpu.gather_to_lds %mem1[%idx1], %smem3[%idx1, %idx2] : vector<2xf16>, memref<32xf16>, memref<?x?xf16, strided<[?, 1]>, #gpu.address_space<workgroup>>
|
||||
// CHECK: amdgpu.gather_to_lds %{{.*}}[%{{.*}}], %{{.*}}[%{{.*}}, %{{.*}}]
|
||||
amdgpu.gather_to_lds async %mem2[%idx1, %idx2], %smem2[%idx1, %idx2] : vector<2xf16>, memref<32x32xf16>, memref<32x32xf16, #gpu.address_space<workgroup>>
|
||||
amdgpu.gather_to_lds async %mem1[%idx1], %smem2[%idx1, %idx2] : vector<2xf16>, memref<32xf16>, memref<32x32xf16, #gpu.address_space<workgroup>>
|
||||
rocdl.asyncmark
|
||||
amdgpu.gather_to_lds %mem2[%idx1, %idx2], %smem1[%idx1] : vector<2xf16>, memref<32x32xf16>, memref<32xf16, #gpu.address_space<workgroup>>
|
||||
amdgpu.gather_to_lds %mem1[%idx1], %smem3[%idx1, %idx2] : vector<2xf16>, memref<32xf16>, memref<?x?xf16, strided<[?, 1]>, #gpu.address_space<workgroup>>
|
||||
rocdl.wait.asyncmark 0
|
||||
func.return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @gather_to_lds_0d
|
||||
func.func @gather_to_lds_0d(%mem1 : memref<f16>, %smem1 : memref<f16, #gpu.address_space<workgroup>>) {
|
||||
// CHECK: amdgpu.gather_to_lds %{{.*}}[], %{{.*}}[]
|
||||
amdgpu.gather_to_lds %mem1[], %smem1[] : vector<2xf16>, memref<f16>, memref<f16, #gpu.address_space<workgroup>>
|
||||
// CHECK: amdgpu.gather_to_lds async %{{.*}}[], %{{.*}}[]
|
||||
amdgpu.gather_to_lds async %mem1[], %smem1[] : vector<2xf16>, memref<f16>, memref<f16, #gpu.address_space<workgroup>>
|
||||
rocdl.asyncmark
|
||||
rocdl.wait.asyncmark 0
|
||||
func.return
|
||||
}
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user