diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index d11d196207b5..4c13c5ddb288 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -921,6 +921,23 @@ def NVVM_Barrier0Op : NVVM_Op<"barrier0"> {
   }];
 }
 
+// Attrs describing the reduction operations for the barrier operation.
+def BarrierReductionPopc : I32EnumAttrCase<"POPC", 0, "popc">;
+def BarrierReductionAnd : I32EnumAttrCase<"AND", 1, "and">;
+def BarrierReductionOr : I32EnumAttrCase<"OR", 2, "or">;
+
+def BarrierReduction
+    : I32EnumAttr<"BarrierReduction", "NVVM barrier reduction operation",
+                  [BarrierReductionPopc, BarrierReductionAnd,
+                   BarrierReductionOr]> {
+  let genSpecializedAttr = 0;
+  let cppNamespace = "::mlir::NVVM";
+}
+def BarrierReductionAttr
+    : EnumAttr<NVVM_Dialect, BarrierReduction, "reduction"> {
+  let assemblyFormat = "`<` $value `>`";
+}
+
 def NVVM_BarrierOp : NVVM_Op<"barrier", [AttrSizedOperandSegments]> {
   let summary = "CTA Barrier Synchronization Op";
   let description = [{
@@ -935,6 +952,9 @@ def NVVM_BarrierOp : NVVM_Op<"barrier", [AttrSizedOperandSegments]> {
     - `numberOfThreads`: Specifies the number of threads participating in the barrier.
     When specified, the value must be a multiple of the warp size. If not specified,
     all threads in the CTA participate in the barrier.
+    - `reductionOp`: specifies the reduction operation (`popc`, `and`, `or`).
+    - `reductionPredicate`: specifies the predicate to be used with the
+    `reductionOp`. The op produces a result only when a reduction is requested.
 
     The barrier operation guarantees that when the barrier completes, prior memory accesses
     requested by participating threads are performed relative to all threads
@@ -951,31 +971,37 @@ def NVVM_BarrierOp : NVVM_Op<"barrier", [AttrSizedOperandSegments]> {
 
     [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-bar)
   }];
-  let arguments = (ins
-    Optional<I32>:$barrierId,
-    Optional<I32>:$numberOfThreads);
-  string llvmBuilder = [{
-    llvm::Value *id = $barrierId ? $barrierId : builder.getInt32(0);
-    if ($numberOfThreads)
-      createIntrinsicCall(
-        builder, llvm::Intrinsic::nvvm_barrier_cta_sync_aligned_count,
-        {id, $numberOfThreads});
-    else
-      createIntrinsicCall(
-        builder, llvm::Intrinsic::nvvm_barrier_cta_sync_aligned_all, {id});
+  let extraClassDeclaration = [{
+    static mlir::NVVM::IDArgPair
+    getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
+                          llvm::IRBuilderBase& builder);
   }];
+
+  let arguments = (ins Optional<I32>:$barrierId, Optional<I32>:$numberOfThreads,
+      OptionalAttr<BarrierReductionAttr>:$reductionOp,
+      Optional<I32>:$reductionPredicate);
+  string llvmBuilder = [{
+    auto [id, args] = NVVM::BarrierOp::getIntrinsicIDAndArgs(
+        *op, moduleTranslation, builder);
+    if ($reductionOp)
+      $res = createIntrinsicCall(builder, id, args);
+    else
+      createIntrinsicCall(builder, id, args);
+  }];
+  let results = (outs Optional<I32>:$res);
+
   let hasVerifier = 1;
 
-  let assemblyFormat = "(`id` `=` $barrierId^)? (`number_of_threads` `=` $numberOfThreads^)? attr-dict";
+  let assemblyFormat =
+      "(`id` `=` $barrierId^)? (`number_of_threads` `=` $numberOfThreads^)? "
+      "($reductionOp^ $reductionPredicate)? (`->` type($res)^)? attr-dict";
 
-  let builders = [
-    OpBuilder<(ins), [{
-      return build($_builder, $_state, Value{}, Value{});
+  let builders = [OpBuilder<(ins), [{
+      return build($_builder, $_state, TypeRange{}, Value{}, Value{}, {}, Value{});
     }]>,
-    OpBuilder<(ins "Value":$barrierId), [{
-      return build($_builder, $_state, barrierId, Value{});
-    }]>
-  ];
+                  OpBuilder<(ins "Value":$barrierId), [{
+      return build($_builder, $_state, TypeRange{}, barrierId, Value{}, {}, Value{});
+    }]>];
 }
 
 def NVVM_BarrierArriveOp : NVVM_PTXBuilder_Op<"barrier.arrive">
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index e0c25ab6cdef..0f7b3638fb30 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -1517,6 +1517,21 @@ LogicalResult NVVM::BarrierOp::verify() {
   if (getNumberOfThreads() && !getBarrierId())
     return emitOpError(
         "barrier id is missing, it should be set between 0 to 15");
+
+  if (getBarrierId() && (getReductionOp() || getReductionPredicate()))
+    return emitOpError("reduction are only available when id is 0");
+
+  if ((getReductionOp() && !getReductionPredicate()) ||
+      (!getReductionOp() && getReductionPredicate()))
+    return emitOpError("reduction predicate and reduction operation must be "
+                       "specified together");
+
+  // The translation binds the op result only when a reduction is requested,
+  // so a result without a reduction (or the reverse) cannot be lowered.
+  const bool hasReduction = getReductionOp().has_value();
+  if (hasReduction != static_cast<bool>(getRes()))
+    return emitOpError("result must be present if and only if a reduction "
+                       "operation is specified");
+
   return success();
 }
 
@@ -1785,6 +1800,44 @@ std::string NVVM::MBarrierTryWaitParityOp::getPtx() {
 // getIntrinsicID/getIntrinsicIDAndArgs methods
 //===----------------------------------------------------------------------===//
 
+/// Selects the LLVM intrinsic and operand list used to lower an nvvm.barrier.
+/// Reduction variants map to the llvm.nvvm.barrier0.* intrinsics and take only
+/// the predicate; the verifier rejects them when an explicit id is given.
+mlir::NVVM::IDArgPair NVVM::BarrierOp::getIntrinsicIDAndArgs(
+    Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
+  auto thisOp = cast<NVVM::BarrierOp>(op);
+  llvm::Value *barrierId = thisOp.getBarrierId()
+                               ? mt.lookupValue(thisOp.getBarrierId())
+                               : builder.getInt32(0);
+  llvm::Intrinsic::ID id;
+  llvm::SmallVector<llvm::Value *> args;
+  if (thisOp.getNumberOfThreads()) {
+    // Thread-count form: pass the barrier id and the participating count.
+    id = llvm::Intrinsic::nvvm_barrier_cta_sync_aligned_count;
+    args.push_back(barrierId);
+    args.push_back(mt.lookupValue(thisOp.getNumberOfThreads()));
+  } else if (thisOp.getReductionOp()) {
+    // Reduction form: the barrier id is not an operand of these intrinsics.
+    switch (*thisOp.getReductionOp()) {
+    case NVVM::BarrierReduction::AND:
+      id = llvm::Intrinsic::nvvm_barrier0_and;
+      break;
+    case NVVM::BarrierReduction::OR:
+      id = llvm::Intrinsic::nvvm_barrier0_or;
+      break;
+    case NVVM::BarrierReduction::POPC:
+      id = llvm::Intrinsic::nvvm_barrier0_popc;
+      break;
+    }
+    args.push_back(mt.lookupValue(thisOp.getReductionPredicate()));
+  } else {
+    id = llvm::Intrinsic::nvvm_barrier_cta_sync_aligned_all;
+    args.push_back(barrierId);
+  }
+
+  return {id, std::move(args)};
+}
+
 mlir::NVVM::IDArgPair MBarrierInitOp::getIntrinsicIDAndArgs(
     Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
   auto thisOp = cast<NVVM::MBarrierInitOp>(op);
diff --git a/mlir/test/Target/LLVMIR/nvvm/barrier.mlir b/mlir/test/Target/LLVMIR/nvvm/barrier.mlir
new file mode 100644
index 000000000000..d89f93101c1f
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/nvvm/barrier.mlir
@@ -0,0 +1,27 @@
+// RUN: mlir-translate -mlir-to-llvmir %s -split-input-file --verify-diagnostics | FileCheck %s
+
+// CHECK-LABEL: @llvm_nvvm_barrier0
+llvm.func @llvm_nvvm_barrier0() {
+  // CHECK: call void @llvm.nvvm.barrier.cta.sync.aligned.all(i32 0)
+  nvvm.barrier0
+  llvm.return
+}
+
+// CHECK-LABEL: @llvm_nvvm_barrier(
+// CHECK-SAME: i32 %[[barId:.*]], i32 %[[numThreads:.*]], i32 %[[redOperand:.*]])
+llvm.func @llvm_nvvm_barrier(%barID : i32, %numberOfThreads : i32, %redOperand : i32) {
+  // CHECK: call void @llvm.nvvm.barrier.cta.sync.aligned.all(i32 0)
+  nvvm.barrier
+  // CHECK: call void @llvm.nvvm.barrier.cta.sync.aligned.all(i32 %[[barId]])
+  nvvm.barrier id = %barID
+  // CHECK: call void @llvm.nvvm.barrier.cta.sync.aligned.count(i32 %[[barId]], i32 %[[numThreads]])
+  nvvm.barrier id = %barID number_of_threads = %numberOfThreads
+  // CHECK: %{{.*}} = call i32 @llvm.nvvm.barrier0.and(i32 %[[redOperand]])
+  %0 = nvvm.barrier #nvvm.reduction<and> %redOperand -> i32
+  // CHECK: %{{.*}} = call i32 @llvm.nvvm.barrier0.or(i32 %[[redOperand]])
+  %1 = nvvm.barrier #nvvm.reduction<or> %redOperand -> i32
+  // CHECK: %{{.*}} = call i32 @llvm.nvvm.barrier0.popc(i32 %[[redOperand]])
+  %2 = nvvm.barrier #nvvm.reduction<popc> %redOperand -> i32
+
+  llvm.return
+}
diff --git a/mlir/test/Target/LLVMIR/nvvmir.mlir b/mlir/test/Target/LLVMIR/nvvmir.mlir
index fec54cbf5e3e..5cba5c4fceef 100644
--- a/mlir/test/Target/LLVMIR/nvvmir.mlir
+++ b/mlir/test/Target/LLVMIR/nvvmir.mlir
@@ -166,25 +166,6 @@ llvm.func @nvvm_rcp(%0: f32) -> f32 {
   llvm.return %1 : f32
 }
 
-// CHECK-LABEL: @llvm_nvvm_barrier0
-llvm.func @llvm_nvvm_barrier0() {
-  // CHECK: call void @llvm.nvvm.barrier.cta.sync.aligned.all(i32 0)
-  nvvm.barrier0
-  llvm.return
-}
-
-// CHECK-LABEL: @llvm_nvvm_barrier(
-// CHECK-SAME: i32 %[[barId:.*]], i32 %[[numThreads:.*]])
-llvm.func @llvm_nvvm_barrier(%barID : i32, %numberOfThreads : i32) {
-  // CHECK: call void @llvm.nvvm.barrier.cta.sync.aligned.all(i32 0)
-  nvvm.barrier
-  // CHECK: call void @llvm.nvvm.barrier.cta.sync.aligned.all(i32 %[[barId]])
-  nvvm.barrier id = %barID
-  // CHECK: call void @llvm.nvvm.barrier.cta.sync.aligned.count(i32 %[[barId]], i32 %[[numThreads]])
-  nvvm.barrier id = %barID number_of_threads = %numberOfThreads
-  llvm.return
-}
-
 // CHECK-LABEL: @llvm_nvvm_cluster_arrive
 llvm.func @llvm_nvvm_cluster_arrive() {
   // CHECK: call void @llvm.nvvm.barrier.cluster.arrive()