[mlir][NVVM] Add support for barrier0-reduction operation (#167036)

Add support for `nvvm.barrier0.[popc|and|or]` operation. It is added as
a separate operation since `Barrier0Op` has no result.

https://docs.nvidia.com/cuda/nvvm-ir-spec/#barrier-and-memory-fence

This will be used in CUDA Fortran lowering: 

49f55f4991/flang/lib/Optimizer/Builder/CUDAIntrinsicCall.cpp (L1081)

It could also be used later for CUDA C/C++ with CIR:

49f55f4991/clang/lib/Headers/__clang_cuda_device_functions.h (L524)

---------

Co-authored-by: Guray Ozen <guray.ozen@gmail.com>
This commit is contained in:
Valentin Clement (バレンタイン クレメン) 2025-11-12 14:56:10 -08:00 committed by GitHub
parent b6bcfdea40
commit bdf3f24ec0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 108 additions and 39 deletions

View File

@ -921,6 +921,23 @@ def NVVM_Barrier0Op : NVVM_Op<"barrier0"> {
}];
}
// Attrs describing the reduction operations for the barrier operation.
// Each case gives the C++ enumerator name, its integer value, and the keyword
// spelled in the textual IR.
def BarrierReductionPopc : I32EnumAttrCase<"POPC", 0, "popc">;
def BarrierReductionAnd : I32EnumAttrCase<"AND", 1, "and">;
def BarrierReductionOr : I32EnumAttrCase<"OR", 2, "or">;
// 32-bit enum grouping the three reduction kinds; the generated C++ enum lives
// in `::mlir::NVVM`. `genSpecializedAttr = 0` is the usual ODS setting when the
// attribute is provided by a separate `EnumAttr` definition, as done below.
def BarrierReduction
: I32EnumAttr<"BarrierReduction", "NVVM barrier reduction operation",
[BarrierReductionPopc, BarrierReductionAnd,
BarrierReductionOr]> {
let genSpecializedAttr = 0;
let cppNamespace = "::mlir::NVVM";
}
// NVVM dialect attribute wrapping `BarrierReduction` under the mnemonic
// `reduction`; with the format below it is written as e.g.
// `#nvvm.reduction<and>`.
def BarrierReductionAttr
: EnumAttr<NVVM_Dialect, BarrierReduction, "reduction"> {
let assemblyFormat = "`<` $value `>`";
}
def NVVM_BarrierOp : NVVM_Op<"barrier", [AttrSizedOperandSegments]> {
let summary = "CTA Barrier Synchronization Op";
let description = [{
@ -935,6 +952,9 @@ def NVVM_BarrierOp : NVVM_Op<"barrier", [AttrSizedOperandSegments]> {
- `numberOfThreads`: Specifies the number of threads participating in the barrier.
When specified, the value must be a multiple of the warp size. If not specified,
all threads in the CTA participate in the barrier.
- `reductionOp`: specifies the reduction operation (`popc`, `and`, `or`).
- `reductionPredicate`: specifies the predicate to be used with the
`reductionOp`.
The barrier operation guarantees that when the barrier completes, prior memory
accesses requested by participating threads are performed relative to all threads
@ -951,31 +971,37 @@ def NVVM_BarrierOp : NVVM_Op<"barrier", [AttrSizedOperandSegments]> {
[For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-bar)
}];
let arguments = (ins
Optional<I32>:$barrierId,
Optional<I32>:$numberOfThreads);
string llvmBuilder = [{
llvm::Value *id = $barrierId ? $barrierId : builder.getInt32(0);
if ($numberOfThreads)
createIntrinsicCall(
builder, llvm::Intrinsic::nvvm_barrier_cta_sync_aligned_count,
{id, $numberOfThreads});
else
createIntrinsicCall(
builder, llvm::Intrinsic::nvvm_barrier_cta_sync_aligned_all, {id});
let extraClassDeclaration = [{
static mlir::NVVM::IDArgPair
getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
llvm::IRBuilderBase& builder);
}];
let arguments = (ins Optional<I32>:$barrierId, Optional<I32>:$numberOfThreads,
OptionalAttr<BarrierReductionAttr>:$reductionOp,
Optional<I32>:$reductionPredicate);
string llvmBuilder = [{
auto [id, args] = NVVM::BarrierOp::getIntrinsicIDAndArgs(
*op, moduleTranslation, builder);
if ($reductionOp)
$res = createIntrinsicCall(builder, id, args);
else
createIntrinsicCall(builder, id, args);
}];
let results = (outs Optional<I32>:$res);
let hasVerifier = 1;
let assemblyFormat = "(`id` `=` $barrierId^)? (`number_of_threads` `=` $numberOfThreads^)? attr-dict";
let assemblyFormat =
"(`id` `=` $barrierId^)? (`number_of_threads` `=` $numberOfThreads^)? "
"($reductionOp^ $reductionPredicate)? (`->` type($res)^)? attr-dict";
let builders = [
OpBuilder<(ins), [{
return build($_builder, $_state, Value{}, Value{});
let builders = [OpBuilder<(ins), [{
return build($_builder, $_state, TypeRange{}, Value{}, Value{}, {}, Value{});
}]>,
OpBuilder<(ins "Value":$barrierId), [{
return build($_builder, $_state, barrierId, Value{});
}]>
];
OpBuilder<(ins "Value":$barrierId), [{
return build($_builder, $_state, TypeRange{}, barrierId, Value{}, {}, Value{});
}]>];
}
def NVVM_BarrierArriveOp : NVVM_PTXBuilder_Op<"barrier.arrive">

View File

@ -1517,6 +1517,15 @@ LogicalResult NVVM::BarrierOp::verify() {
if (getNumberOfThreads() && !getBarrierId())
return emitOpError(
"barrier id is missing, it should be set between 0 to 15");
if (getBarrierId() && (getReductionOp() || getReductionPredicate()))
return emitOpError("reduction are only available when id is 0");
if ((getReductionOp() && !getReductionPredicate()) ||
(!getReductionOp() && getReductionPredicate()))
return emitOpError("reduction predicate and reduction operation must be "
"specified together");
return success();
}
@ -1785,6 +1794,39 @@ std::string NVVM::MBarrierTryWaitParityOp::getPtx() {
// getIntrinsicID/getIntrinsicIDAndArgs methods
//===----------------------------------------------------------------------===//
/// Picks the LLVM intrinsic and the call operands used to translate an
/// `nvvm.barrier` op.
///
/// - With `numberOfThreads`: the aligned "count" barrier, called with the
///   barrier id and the thread count.
/// - Otherwise, with `reductionOp`: the matching `barrier0` reduction
///   intrinsic, called with the reduction predicate only.
/// - Otherwise: the aligned "all" barrier, called with the barrier id only.
///
/// When the op carries no barrier id, the constant `i32 0` is used instead.
mlir::NVVM::IDArgPair NVVM::BarrierOp::getIntrinsicIDAndArgs(
    Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
  auto barrier = cast<NVVM::BarrierOp>(op);

  // Resolve the barrier id operand, defaulting to barrier 0 when absent.
  llvm::Value *idValue = nullptr;
  if (auto idOperand = barrier.getBarrierId())
    idValue = mt.lookupValue(idOperand);
  else
    idValue = builder.getInt32(0);

  llvm::Intrinsic::ID intrinsic;
  llvm::SmallVector<llvm::Value *> callArgs;

  if (auto threadCount = barrier.getNumberOfThreads()) {
    // Only a subset of the CTA's threads participates.
    intrinsic = llvm::Intrinsic::nvvm_barrier_cta_sync_aligned_count;
    callArgs.push_back(idValue);
    callArgs.push_back(mt.lookupValue(threadCount));
  } else if (auto reduction = barrier.getReductionOp()) {
    // Reduction barriers take the predicate as their sole operand.
    switch (*reduction) {
    case NVVM::BarrierReduction::POPC:
      intrinsic = llvm::Intrinsic::nvvm_barrier0_popc;
      break;
    case NVVM::BarrierReduction::AND:
      intrinsic = llvm::Intrinsic::nvvm_barrier0_and;
      break;
    case NVVM::BarrierReduction::OR:
      intrinsic = llvm::Intrinsic::nvvm_barrier0_or;
      break;
    }
    callArgs.push_back(mt.lookupValue(barrier.getReductionPredicate()));
  } else {
    // Plain barrier across all threads.
    intrinsic = llvm::Intrinsic::nvvm_barrier_cta_sync_aligned_all;
    callArgs.push_back(idValue);
  }

  return {intrinsic, std::move(callArgs)};
}
mlir::NVVM::IDArgPair MBarrierInitOp::getIntrinsicIDAndArgs(
Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
auto thisOp = cast<NVVM::MBarrierInitOp>(op);

View File

@ -0,0 +1,20 @@
// RUN: mlir-translate -mlir-to-llvmir %s -split-input-file --verify-diagnostics | FileCheck %s
// Translation of `nvvm.barrier` to LLVM IR in each supported form: no operands,
// explicit id, id plus thread count, and the and/or/popc reductions.
// CHECK-LABEL: @llvm_nvvm_barrier(
// CHECK-SAME: i32 %[[barId:.*]], i32 %[[numThreads:.*]], i32 %[[redOperand:.*]])
llvm.func @llvm_nvvm_barrier(%barID : i32, %numberOfThreads : i32, %redOperand : i32) {
// CHECK: call void @llvm.nvvm.barrier.cta.sync.aligned.all(i32 0)
nvvm.barrier
// CHECK: call void @llvm.nvvm.barrier.cta.sync.aligned.all(i32 %[[barId]])
nvvm.barrier id = %barID
// CHECK: call void @llvm.nvvm.barrier.cta.sync.aligned.count(i32 %[[barId]], i32 %[[numThreads]])
nvvm.barrier id = %barID number_of_threads = %numberOfThreads
// CHECK: %{{.*}} = call i32 @llvm.nvvm.barrier0.and(i32 %[[redOperand]])
%0 = nvvm.barrier #nvvm.reduction<and> %redOperand -> i32
// CHECK: %{{.*}} = call i32 @llvm.nvvm.barrier0.or(i32 %[[redOperand]])
%1 = nvvm.barrier #nvvm.reduction<or> %redOperand -> i32
// CHECK: %{{.*}} = call i32 @llvm.nvvm.barrier0.popc(i32 %[[redOperand]])
%2 = nvvm.barrier #nvvm.reduction<popc> %redOperand -> i32
llvm.return
}

View File

@ -166,25 +166,6 @@ llvm.func @nvvm_rcp(%0: f32) -> f32 {
llvm.return %1 : f32
}
// CHECK-LABEL: @llvm_nvvm_barrier0
llvm.func @llvm_nvvm_barrier0() {
// CHECK: call void @llvm.nvvm.barrier.cta.sync.aligned.all(i32 0)
nvvm.barrier0
llvm.return
}
// CHECK-LABEL: @llvm_nvvm_barrier(
// CHECK-SAME: i32 %[[barId:.*]], i32 %[[numThreads:.*]])
llvm.func @llvm_nvvm_barrier(%barID : i32, %numberOfThreads : i32) {
// CHECK: call void @llvm.nvvm.barrier.cta.sync.aligned.all(i32 0)
nvvm.barrier
// CHECK: call void @llvm.nvvm.barrier.cta.sync.aligned.all(i32 %[[barId]])
nvvm.barrier id = %barID
// CHECK: call void @llvm.nvvm.barrier.cta.sync.aligned.count(i32 %[[barId]], i32 %[[numThreads]])
nvvm.barrier id = %barID number_of_threads = %numberOfThreads
llvm.return
}
// CHECK-LABEL: @llvm_nvvm_cluster_arrive
llvm.func @llvm_nvvm_cluster_arrive() {
// CHECK: call void @llvm.nvvm.barrier.cluster.arrive()