[mlir][AMDGPU] Update gather_to_lds with explicit-async support (#181082)

This commit takes advantage of the new `load.async.to.lds` intrinsic in
order to add an `async` mode to `gather_to_lds`. In this mode,
completion of the load needs to be managed with `asyncmark` and
`wait.asyncmark` intrinsics instead of being implicitly derived by alias
analysis.

This commit adds the flag, a lowering for it, and updates tests.

Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
Krzysztof Drewniak 2026-02-16 12:52:35 -08:00 committed by GitHub
parent e17de2ec23
commit 149fa17adf
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 113 additions and 16 deletions

View File

@ -1079,7 +1079,8 @@ def AMDGPU_GatherToLDSOp :
Variadic<Index>:$srcIndices,
Arg<AnyMemRef, "buffer to write to", [MemWrite]>:$dst,
Variadic<Index>:$dstIndices,
TypeAttr:$transferType
TypeAttr:$transferType,
UnitAttr:$async
)>,
Results<(outs)> {
let summary = "MLIR wrapper for CDNA Gather to LDS instructions";
@ -1099,6 +1100,13 @@ def AMDGPU_GatherToLDSOp :
* `$transferType`: type of the data to be transferred by each thread. This is used to determine
the size of the data to be transferred and the number of threads in the subgroup.
The transfer type must be a scalar type or a vector type with a single element type.
* If `$async` is set, the compiler will not attempt to infer the
memory waits needed to ensure that the DMA operation has succeeded
before a load that might access the stored-to LDS is performed.
Instead, the `rocdl.asyncmark` and `rocdl.wait.asyncmark N`
operations must be used to explicitly indicate the desired completion
behavior. This enables more precise calculation of these waits at the
cost of requiring user management of asynchrony.
The `$dst`, along with its indices, points to the memory location the subgroup of this thread
will write to.
@ -1106,7 +1114,7 @@ def AMDGPU_GatherToLDSOp :
Note: only supported on gfx9 and gfx10.
}];
let assemblyFormat = [{
$src `[` $srcIndices `]` `,` $dst `[` $dstIndices `]` attr-dict `:` $transferType `,` type($src) `,` type($dst)
(`async` $async^)? $src `[` $srcIndices `]` `,` $dst `[` $dstIndices `]` attr-dict `:` $transferType `,` type($src) `,` type($dst)
}];
let hasVerifier = 1;
let hasCanonicalizer = 1;

View File

@ -1817,11 +1817,19 @@ struct GatherToLDSOpLowering : public ConvertOpToLLVMPattern<GatherToLDSOp> {
getStridedElementPtr(rewriter, loc, dstMemRefType, adaptor.getDst(),
(adaptor.getDstIndices()));
rewriter.replaceOpWithNewOp<ROCDL::LoadToLDSOp>(
op, srcPtr, dstPtr, rewriter.getI32IntegerAttr(loadWidth),
/*offset=*/rewriter.getI32IntegerAttr(0),
/*aux=*/rewriter.getI32IntegerAttr(0), ArrayAttr{}, ArrayAttr{},
ArrayAttr{});
if (op.getAsync()) {
rewriter.replaceOpWithNewOp<ROCDL::LoadAsyncToLDSOp>(
op, srcPtr, dstPtr, rewriter.getI32IntegerAttr(loadWidth),
/*offset=*/rewriter.getI32IntegerAttr(0),
/*aux=*/rewriter.getI32IntegerAttr(0), ArrayAttr{}, ArrayAttr{},
ArrayAttr{});
} else {
rewriter.replaceOpWithNewOp<ROCDL::LoadToLDSOp>(
op, srcPtr, dstPtr, rewriter.getI32IntegerAttr(loadWidth),
/*offset=*/rewriter.getI32IntegerAttr(0),
/*aux=*/rewriter.getI32IntegerAttr(0), ArrayAttr{}, ArrayAttr{},
ArrayAttr{});
}
return success();
}

View File

@ -233,3 +233,79 @@ func.func @fat_buffer_load_to_rocdl_f32(%global : memref<128x72xf32, #amdgpu.add
: f32, memref<128x72xf32, #amdgpu.address_space<fat_raw_buffer>>, memref<64x64xf32, #gpu.address_space<workgroup>>
func.return
}
// CHECK-LABEL: func @global_load_to_rocdl_async_f32
// CHECK-SAME: (%[[ARG0:.*]]: memref<128x72xf32, #gpu.address_space<global>>)
func.func @global_load_to_rocdl_async_f32(%global : memref<128x72xf32, #gpu.address_space<global>>) {
%c0 = arith.constant 0 : index
%c12 = arith.constant 12 : index
%c32 = arith.constant 32 : index
%alloc = memref.alloc() : memref<64x64xf32, #gpu.address_space<workgroup>>
// CHECK: %[[GLOBAL_DESC:.*]] = builtin.unrealized_conversion_cast %[[ARG0]]
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[IC0:.*]] = builtin.unrealized_conversion_cast %c0 : index to i64
// CHECK-DAG: %[[C12:.*]] = arith.constant 12 : index
// CHECK-DAG: %[[IC12:.*]] = builtin.unrealized_conversion_cast %[[C12]]
// CHECK-DAG: %[[C32:.*]] = arith.constant 32 : index
// CHECK-DAG: %[[IC32:.*]] = builtin.unrealized_conversion_cast %[[C32]]
// CHECK: %[[ALLOC:.*]] = memref.alloc()
// CHECK: %[[LDS_DESC:.*]] = builtin.unrealized_conversion_cast
// CHECK: %[[GLOBAL_BASE:.*]] = llvm.extractvalue %[[GLOBAL_DESC]][1]
// CHECK: %[[C72:.*]] = llvm.mlir.constant(72 : index) : i64
// CHECK: %[[MUL:.*]] = llvm.mul %[[IC12]], %[[C72]] : i64
// CHECK: %[[SRC_OFFSET:.*]] = llvm.add %[[MUL]], %[[IC0]] : i64
// CHECK: %[[GLOBAL_PTR:.*]] = llvm.getelementptr %[[GLOBAL_BASE]][%[[SRC_OFFSET]]]
// CHECK: %[[LDS_BASE:.*]] = llvm.extractvalue %[[LDS_DESC]][1]
// CHECK: %[[C64:.*]] = llvm.mlir.constant(64 : index) : i64
// CHECK: %[[MUL_2:.*]] = llvm.mul %[[IC32]], %[[C64]] : i64
// CHECK: %[[DST_OFFSET:.*]] = llvm.add %[[MUL_2]], %[[IC0]] : i64
// CHECK: %[[LDS_PTR:.*]] = llvm.getelementptr %[[LDS_BASE]][%[[DST_OFFSET]]]
// CHECK: rocdl.load.async.to.lds %[[GLOBAL_PTR]], %[[LDS_PTR]], 4
amdgpu.gather_to_lds async %global[%c12, %c0], %alloc[%c32, %c0]
: f32, memref<128x72xf32, #gpu.address_space<global>>, memref<64x64xf32, #gpu.address_space<workgroup>>
func.return
}
// CHECK-LABEL: func @global_load_to_rocdl_async_f32_fat_raw_buffer
// CHECK-SAME: (%[[ARG0:.*]]: memref<128x72xf32, #amdgpu.address_space<fat_raw_buffer>>)
func.func @global_load_to_rocdl_async_f32_fat_raw_buffer(%global : memref<128x72xf32, #amdgpu.address_space<fat_raw_buffer>>) {
%c0 = arith.constant 0 : index
%c12 = arith.constant 12 : index
%c32 = arith.constant 32 : index
%alloc = memref.alloc() : memref<64x64xf32, #gpu.address_space<workgroup>>
// CHECK: %[[GLOBAL_DESC:.*]] = builtin.unrealized_conversion_cast %[[ARG0]]
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[IC0:.*]] = builtin.unrealized_conversion_cast %c0 : index to i64
// CHECK-DAG: %[[C12:.*]] = arith.constant 12 : index
// CHECK-DAG: %[[IC12:.*]] = builtin.unrealized_conversion_cast %[[C12]]
// CHECK-DAG: %[[C32:.*]] = arith.constant 32 : index
// CHECK-DAG: %[[IC32:.*]] = builtin.unrealized_conversion_cast %[[C32]]
// CHECK: %[[ALLOC:.*]] = memref.alloc()
// CHECK: %[[LDS_DESC:.*]] = builtin.unrealized_conversion_cast
// CHECK: %[[GLOBAL_BASE:.*]] = llvm.extractvalue %[[GLOBAL_DESC]][1]
// CHECK: %[[C72:.*]] = llvm.mlir.constant(72 : index) : i64
// CHECK: %[[MUL:.*]] = llvm.mul %[[IC12]], %[[C72]] : i64
// CHECK: %[[SRC_OFFSET:.*]] = llvm.add %[[MUL]], %[[IC0]] : i64
// CHECK: %[[GLOBAL_PTR:.*]] = llvm.getelementptr %[[GLOBAL_BASE]][%[[SRC_OFFSET]]]
// CHECK: %[[LDS_BASE:.*]] = llvm.extractvalue %[[LDS_DESC]][1]
// CHECK: %[[C64:.*]] = llvm.mlir.constant(64 : index) : i64
// CHECK: %[[MUL_2:.*]] = llvm.mul %[[IC32]], %[[C64]] : i64
// CHECK: %[[DST_OFFSET:.*]] = llvm.add %[[MUL_2]], %[[IC0]] : i64
// CHECK: %[[LDS_PTR:.*]] = llvm.getelementptr %[[LDS_BASE]][%[[DST_OFFSET]]]
// CHECK: rocdl.load.async.to.lds %[[GLOBAL_PTR]], %[[LDS_PTR]], 4
amdgpu.gather_to_lds async %global[%c12, %c0], %alloc[%c32, %c0]
: f32, memref<128x72xf32, #amdgpu.address_space<fat_raw_buffer>>, memref<64x64xf32, #gpu.address_space<workgroup>>
func.return
}

View File

@ -658,21 +658,26 @@ func.func @transpose_load(%idx1 : index, %idx2 : index, %mem : memref<128x32xf16
// CHECK-LABEL: func @gather_to_lds
func.func @gather_to_lds(%idx1 : index, %idx2 : index, %mem1 : memref<32xf16>, %mem2 : memref<32x32xf16>, %smem1 : memref<32xf16, #gpu.address_space<workgroup>>, %smem2 : memref<32x32xf16, #gpu.address_space<workgroup>>, %smem3 : memref<?x?xf16, strided<[?, 1]>, #gpu.address_space<workgroup>>) {
// CHECK: amdgpu.gather_to_lds %{{.*}}[%{{.*}}, %{{.*}}], %{{.*}}[%{{.*}}, %{{.*}}]
// CHECK: amdgpu.gather_to_lds async %{{.*}}[%{{.*}}, %{{.*}}], %{{.*}}[%{{.*}}, %{{.*}}]
// CHECK: amdgpu.gather_to_lds async %{{.*}}[%{{.*}}], %{{.*}}[%{{.*}}, %{{.*}}]
// CHECK: rocdl.asyncmark
// CHECK: amdgpu.gather_to_lds %{{.*}}[%{{.*}}, %{{.*}}], %{{.*}}[%{{.*}}]
// CHECK: amdgpu.gather_to_lds %{{.*}}[%{{.*}}], %{{.*}}[%{{.*}}, %{{.*}}]
// CHECK: amdgpu.gather_to_lds %{{.*}}[%{{.*}}], %{{.*}}[%{{.*}}, %{{.*}}]
amdgpu.gather_to_lds %mem2[%idx1, %idx2], %smem2[%idx1, %idx2] : vector<2xf16>, memref<32x32xf16>, memref<32x32xf16, #gpu.address_space<workgroup>>
amdgpu.gather_to_lds %mem2[%idx1, %idx2], %smem1[%idx1] : vector<2xf16>, memref<32x32xf16>, memref<32xf16, #gpu.address_space<workgroup>>
amdgpu.gather_to_lds %mem1[%idx1], %smem2[%idx1, %idx2] : vector<2xf16>, memref<32xf16>, memref<32x32xf16, #gpu.address_space<workgroup>>
amdgpu.gather_to_lds %mem1[%idx1], %smem3[%idx1, %idx2] : vector<2xf16>, memref<32xf16>, memref<?x?xf16, strided<[?, 1]>, #gpu.address_space<workgroup>>
// CHECK: amdgpu.gather_to_lds %{{.*}}[%{{.*}}], %{{.*}}[%{{.*}}, %{{.*}}]
amdgpu.gather_to_lds async %mem2[%idx1, %idx2], %smem2[%idx1, %idx2] : vector<2xf16>, memref<32x32xf16>, memref<32x32xf16, #gpu.address_space<workgroup>>
amdgpu.gather_to_lds async %mem1[%idx1], %smem2[%idx1, %idx2] : vector<2xf16>, memref<32xf16>, memref<32x32xf16, #gpu.address_space<workgroup>>
rocdl.asyncmark
amdgpu.gather_to_lds %mem2[%idx1, %idx2], %smem1[%idx1] : vector<2xf16>, memref<32x32xf16>, memref<32xf16, #gpu.address_space<workgroup>>
amdgpu.gather_to_lds %mem1[%idx1], %smem3[%idx1, %idx2] : vector<2xf16>, memref<32xf16>, memref<?x?xf16, strided<[?, 1]>, #gpu.address_space<workgroup>>
rocdl.wait.asyncmark 0
func.return
}
// CHECK-LABEL: func @gather_to_lds_0d
func.func @gather_to_lds_0d(%mem1 : memref<f16>, %smem1 : memref<f16, #gpu.address_space<workgroup>>) {
// CHECK: amdgpu.gather_to_lds %{{.*}}[], %{{.*}}[]
amdgpu.gather_to_lds %mem1[], %smem1[] : vector<2xf16>, memref<f16>, memref<f16, #gpu.address_space<workgroup>>
// CHECK: amdgpu.gather_to_lds async %{{.*}}[], %{{.*}}[]
amdgpu.gather_to_lds async %mem1[], %smem1[] : vector<2xf16>, memref<f16>, memref<f16, #gpu.address_space<workgroup>>
rocdl.asyncmark
rocdl.wait.asyncmark 0
func.return
}