diff --git a/mlir/include/mlir/Conversion/GPUToLLVMSPV/GPUToLLVMSPVPass.h b/mlir/include/mlir/Conversion/GPUToLLVMSPV/GPUToLLVMSPVPass.h
new file mode 100644
index 000000000000..e156c3093e21
--- /dev/null
+++ b/mlir/include/mlir/Conversion/GPUToLLVMSPV/GPUToLLVMSPVPass.h
@@ -0,0 +1,27 @@
+//===- GPUToLLVMSPVPass.h - Convert GPU kernel to LLVM operations *- C++ -*-==//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_GPUTOLLVMSPV_GPUTOLLVMSPVPASS_H_
+#define MLIR_CONVERSION_GPUTOLLVMSPV_GPUTOLLVMSPVPASS_H_
+
+#include <memory>
+
+namespace mlir {
+class DialectRegistry;
+class LLVMTypeConverter;
+class RewritePatternSet;
+class Pass;
+
+#define GEN_PASS_DECL_CONVERTGPUOPSTOLLVMSPVOPS
+#include "mlir/Conversion/Passes.h.inc"
+
+void populateGpuToLLVMSPVConversionPatterns(LLVMTypeConverter &converter,
+                                            RewritePatternSet &patterns);
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_GPUTOLLVMSPV_GPUTOLLVMSPVPASS_H_
diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
index 2179ae18ac07..7700299b3a4f 100644
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -34,6 +34,7 @@
 #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h"
 #include "mlir/Conversion/FuncToSPIRV/FuncToSPIRVPass.h"
 #include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
+#include "mlir/Conversion/GPUToLLVMSPV/GPUToLLVMSPVPass.h"
 #include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h"
 #include "mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h"
 #include "mlir/Conversion/GPUToSPIRV/GPUToSPIRVPass.h"
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index e6d678dc1b12..eb58f4adc31d 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -508,6 +508,24 @@ def LowerHostCodeToLLVMPass : Pass<"lower-host-to-llvm", "ModuleOp"> {
   let dependentDialects = ["LLVM::LLVMDialect"];
 }
 
+//===----------------------------------------------------------------------===//
+// GPUToLLVMSPV
+//===----------------------------------------------------------------------===//
+
+def ConvertGpuOpsToLLVMSPVOps : Pass<"convert-gpu-to-llvm-spv", "gpu::GPUModuleOp"> {
+  let summary =
+      "Generate LLVM operations to be ingested by a SPIR-V backend for gpu operations";
+  let dependentDialects = [
+    "LLVM::LLVMDialect",
+    "spirv::SPIRVDialect",
+  ];
+  let options = [
+    Option<"indexBitwidth", "index-bitwidth", "unsigned",
+           /*default=kDeriveIndexBitwidthFromDataLayout*/"0",
+           "Bitwidth of the index type, 0 to use size of machine word">,
+  ];
+}
+
 //===----------------------------------------------------------------------===//
 // GPUToNVVM
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index 41ab7046b91c..0a03a2e133db 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -23,6 +23,7 @@ add_subdirectory(FuncToEmitC)
 add_subdirectory(FuncToLLVM)
 add_subdirectory(FuncToSPIRV)
 add_subdirectory(GPUCommon)
+add_subdirectory(GPUToLLVMSPV)
 add_subdirectory(GPUToNVVM)
 add_subdirectory(GPUToROCDL)
 add_subdirectory(GPUToSPIRV)
diff --git a/mlir/lib/Conversion/GPUToLLVMSPV/CMakeLists.txt b/mlir/lib/Conversion/GPUToLLVMSPV/CMakeLists.txt
new file mode 100644
index 000000000000..da5650b2b68d
--- /dev/null
+++ b/mlir/lib/Conversion/GPUToLLVMSPV/CMakeLists.txt
@@ -0,0 +1,12 @@
+add_mlir_conversion_library(MLIRGPUToLLVMSPV
+  GPUToLLVMSPV.cpp
+
+  DEPENDS
+  MLIRConversionPassIncGen
+
+  LINK_LIBS PUBLIC
+  MLIRGPUDialect
+  MLIRLLVMCommonConversion
+  MLIRLLVMDialect
+  MLIRSPIRVDialect
+)
diff --git a/mlir/lib/Conversion/GPUToLLVMSPV/GPUToLLVMSPV.cpp b/mlir/lib/Conversion/GPUToLLVMSPV/GPUToLLVMSPV.cpp
new file mode 100644
index 000000000000..30812a330ef1
--- /dev/null
+++ b/mlir/lib/Conversion/GPUToLLVMSPV/GPUToLLVMSPV.cpp
@@ -0,0 +1,329 @@
+//===- GPUToLLVMSPV.cpp - Convert GPU operations to LLVM dialect ----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/GPUToLLVMSPV/GPUToLLVMSPVPass.h"
+
+#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
+#include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
+#include "mlir/Conversion/LLVMCommon/Pattern.h"
+#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/LLVMIR/LLVMAttrs.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
+#include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/SymbolTable.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/LLVM.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+#include "llvm/ADT/TypeSwitch.h"
+#include "llvm/Support/FormatVariadic.h"
+
+using namespace mlir;
+
+namespace mlir {
+#define GEN_PASS_DEF_CONVERTGPUOPSTOLLVMSPVOPS
+#include "mlir/Conversion/Passes.h.inc"
+} // namespace mlir
+
+//===----------------------------------------------------------------------===//
+// Helper Functions
+//===----------------------------------------------------------------------===//
+
+static LLVM::LLVMFuncOp lookupOrCreateSPIRVFn(Operation *symbolTable,
+                                              StringRef name,
+                                              ArrayRef<Type> paramTypes,
+                                              Type resultType) {
+  auto func = dyn_cast_or_null<LLVM::LLVMFuncOp>(
+      SymbolTable::lookupSymbolIn(symbolTable, name));
+  if (!func) {
+    OpBuilder b(symbolTable->getRegion(0));
+    func = b.create<LLVM::LLVMFuncOp>(
+        symbolTable->getLoc(), name,
+        LLVM::LLVMFunctionType::get(resultType, paramTypes));
+    func.setCConv(LLVM::cconv::CConv::SPIR_FUNC);
+  }
+  return func;
+}
+
+static LLVM::CallOp createSPIRVBuiltinCall(Location loc,
+                                           ConversionPatternRewriter &rewriter,
+                                           LLVM::LLVMFuncOp func,
+                                           ValueRange args) {
+  auto call = rewriter.create<LLVM::CallOp>(loc, func, args);
+  call.setCConv(func.getCConv());
+  return call;
+}
+
+namespace {
+//===----------------------------------------------------------------------===//
+// Barriers
+//===----------------------------------------------------------------------===//
+
+/// Replace `gpu.barrier` with an `llvm.call` to `barrier` with
+/// `CLK_LOCAL_MEM_FENCE` argument, indicating work-group memory scope:
+/// ```
+/// // gpu.barrier
+/// %c1 = llvm.mlir.constant(1 : i32) : i32
+/// llvm.call spir_funccc @_Z7barrierj(%c1) : (i32) -> ()
+/// ```
+struct GPUBarrierConversion final : ConvertOpToLLVMPattern<gpu::BarrierOp> {
+  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(gpu::BarrierOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const final {
+    constexpr StringLiteral funcName = "_Z7barrierj";
+
+    Operation *moduleOp = op->getParentWithTrait<OpTrait::SymbolTable>();
+    assert(moduleOp && "Expecting module");
+    Type flagTy = rewriter.getI32Type();
+    Type voidTy = rewriter.getType<LLVM::LLVMVoidType>();
+    LLVM::LLVMFuncOp func =
+        lookupOrCreateSPIRVFn(moduleOp, funcName, flagTy, voidTy);
+
+    // Value used by SPIR-V backend to represent `CLK_LOCAL_MEM_FENCE`.
+    // See `llvm/lib/Target/SPIRV/SPIRVBuiltins.td`.
+    constexpr int64_t localMemFenceFlag = 1;
+    Location loc = op->getLoc();
+    Value flag =
+        rewriter.create<LLVM::ConstantOp>(loc, flagTy, localMemFenceFlag);
+    rewriter.replaceOp(op, createSPIRVBuiltinCall(loc, rewriter, func, flag));
+    return success();
+  }
+};
+
+//===----------------------------------------------------------------------===//
+// SPIR-V Builtins
+//===----------------------------------------------------------------------===//
+
+/// Replace `gpu.*` with an `llvm.call` to the corresponding SPIR-V builtin
+/// with a constant argument for the `dimension` attribute. The return type
+/// depends on the index width option:
+/// ```
+/// // %thread_id_y = gpu.thread_id y
+/// %c1 = llvm.mlir.constant(1 : i32) : i32
+/// %0 = llvm.call spir_funccc @_Z12get_local_idj(%c1) : (i32) -> i64
+/// ```
+struct LaunchConfigConversion : ConvertToLLVMPattern {
+  LaunchConfigConversion(StringRef funcName, StringRef rootOpName,
+                         MLIRContext *context,
+                         const LLVMTypeConverter &typeConverter,
+                         PatternBenefit benefit)
+      : ConvertToLLVMPattern(rootOpName, context, typeConverter, benefit),
+        funcName(funcName) {}
+
+  virtual gpu::Dimension getDimension(Operation *op) const = 0;
+
+  LogicalResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const final {
+    Operation *moduleOp = op->getParentWithTrait<OpTrait::SymbolTable>();
+    assert(moduleOp && "Expecting module");
+    Type dimTy = rewriter.getI32Type();
+    Type indexTy = getTypeConverter()->getIndexType();
+    LLVM::LLVMFuncOp func =
+        lookupOrCreateSPIRVFn(moduleOp, funcName, dimTy, indexTy);
+
+    Location loc = op->getLoc();
+    gpu::Dimension dim = getDimension(op);
+    Value dimVal = rewriter.create<LLVM::ConstantOp>(
+        loc, dimTy, static_cast<int64_t>(dim));
+    rewriter.replaceOp(op, createSPIRVBuiltinCall(loc, rewriter, func, dimVal));
+    return success();
+  }
+
+  StringRef funcName;
+};
+
+template <typename SourceOp>
+struct LaunchConfigOpConversion final : LaunchConfigConversion {
+  static StringRef getFuncName();
+
+  explicit LaunchConfigOpConversion(const LLVMTypeConverter &typeConverter,
+                                    PatternBenefit benefit = 1)
+      : LaunchConfigConversion(getFuncName(), SourceOp::getOperationName(),
+                               &typeConverter.getContext(), typeConverter,
+                               benefit) {}
+
+  gpu::Dimension getDimension(Operation *op) const final {
+    return cast<SourceOp>(op).getDimension();
+  }
+};
+
+template <>
+StringRef LaunchConfigOpConversion<gpu::BlockIdOp>::getFuncName() {
+  return "_Z12get_group_idj";
+}
+
+template <>
+StringRef LaunchConfigOpConversion<gpu::GridDimOp>::getFuncName() {
+  return "_Z14get_num_groupsj";
+}
+
+template <>
+StringRef LaunchConfigOpConversion<gpu::BlockDimOp>::getFuncName() {
+  return "_Z14get_local_sizej";
+}
+
+template <>
+StringRef LaunchConfigOpConversion<gpu::ThreadIdOp>::getFuncName() {
+  return "_Z12get_local_idj";
+}
+
+template <>
+StringRef LaunchConfigOpConversion<gpu::GlobalIdOp>::getFuncName() {
+  return "_Z13get_global_idj";
+}
+
+//===----------------------------------------------------------------------===//
+// Shuffles
+//===----------------------------------------------------------------------===//
+
+/// Replace `gpu.shuffle` with an `llvm.call` to the corresponding SPIR-V
+/// builtin for `shuffleResult`, keeping `value` and `offset` arguments, and a
+/// `true` constant for the `valid` result. Conversion will only take place
+/// if `width` is constant and equal to the target subgroup size:
+/// ```
+/// // %0 = gpu.shuffle idx %value, %offset, %width : f64
+/// %0 = llvm.call spir_funccc @_Z17sub_group_shuffledj(%value, %offset)
+///     : (f64, i32) -> f64
+/// ```
+struct GPUShuffleConversion final : ConvertOpToLLVMPattern<gpu::ShuffleOp> {
+  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
+
+  static StringRef getBaseName(gpu::ShuffleMode mode) {
+    switch (mode) {
+    case gpu::ShuffleMode::IDX:
+      return "sub_group_shuffle";
+    case gpu::ShuffleMode::XOR:
+      return "sub_group_shuffle_xor";
+    case gpu::ShuffleMode::UP:
+      return "sub_group_shuffle_up";
+    case gpu::ShuffleMode::DOWN:
+      return "sub_group_shuffle_down";
+    }
+    llvm_unreachable("Unhandled shuffle mode");
+  }
+
+  static StringRef getTypeMangling(Type type) {
+    return TypeSwitch<Type, StringRef>(type)
+        .Case<Float32Type>([](auto) { return "fj"; })
+        .Case<Float64Type>([](auto) { return "dj"; })
+        .Case<IntegerType>([](auto intTy) {
+          switch (intTy.getWidth()) {
+          case 32:
+            return "ij";
+          case 64:
+            return "lj";
+          }
+          llvm_unreachable("Invalid integer width");
+        });
+  }
+
+  static std::string getFuncName(gpu::ShuffleOp op) {
+    StringRef baseName = getBaseName(op.getMode());
+    StringRef typeMangling = getTypeMangling(op.getType(0));
+    return llvm::formatv("_Z{0}{1}{2}", baseName.size(), baseName,
+                         typeMangling);
+  }
+
+  /// Get the subgroup size from the target or return a default.
+  static int getSubgroupSize(Operation *op) {
+    return spirv::lookupTargetEnvOrDefault(op)
+        .getResourceLimits()
+        .getSubgroupSize();
+  }
+
+  static bool hasValidWidth(gpu::ShuffleOp op) {
+    llvm::APInt val;
+    Value width = op.getWidth();
+    return matchPattern(width, m_ConstantInt(&val)) &&
+           val == getSubgroupSize(op);
+  }
+
+  LogicalResult
+  matchAndRewrite(gpu::ShuffleOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const final {
+    if (!hasValidWidth(op))
+      return rewriter.notifyMatchFailure(
+          op, "shuffle width and subgroup size mismatch");
+
+    std::string funcName = getFuncName(op);
+
+    Operation *moduleOp = op->getParentWithTrait<OpTrait::SymbolTable>();
+    assert(moduleOp && "Expecting module");
+    Type valueType = adaptor.getValue().getType();
+    Type offsetType = adaptor.getOffset().getType();
+    Type resultType = valueType;
+    LLVM::LLVMFuncOp func = lookupOrCreateSPIRVFn(
+        moduleOp, funcName, {valueType, offsetType}, resultType);
+
+    Location loc = op->getLoc();
+    std::array<Value, 2> args{adaptor.getValue(), adaptor.getOffset()};
+    Value result =
+        createSPIRVBuiltinCall(loc, rewriter, func, args).getResult();
+    Value trueVal =
+        rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI1Type(), true);
+    rewriter.replaceOp(op, {result, trueVal});
+    return success();
+  }
+};
+
+//===----------------------------------------------------------------------===//
+// GPU To LLVM-SPV Pass.
+//===----------------------------------------------------------------------===//
+
+struct GPUToLLVMSPVConversionPass final
+    : impl::ConvertGpuOpsToLLVMSPVOpsBase<GPUToLLVMSPVConversionPass> {
+  using Base::Base;
+
+  void runOnOperation() final {
+    MLIRContext *context = &getContext();
+    RewritePatternSet patterns(context);
+
+    LowerToLLVMOptions options(context);
+    if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
+      options.overrideIndexBitwidth(indexBitwidth);
+
+    LLVMTypeConverter converter(context, options);
+    LLVMConversionTarget target(*context);
+
+    target.addIllegalOp<gpu::BarrierOp, gpu::BlockDimOp, gpu::BlockIdOp,
+                        gpu::GlobalIdOp, gpu::GridDimOp, gpu::ShuffleOp,
+                        gpu::ThreadIdOp>();
+
+    populateGpuToLLVMSPVConversionPatterns(converter, patterns);
+
+    if (failed(applyPartialConversion(getOperation(), target,
+                                      std::move(patterns))))
+      signalPassFailure();
+  }
+};
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// GPU To LLVM-SPV Patterns.
+//===----------------------------------------------------------------------===//
+
+namespace mlir {
+void populateGpuToLLVMSPVConversionPatterns(LLVMTypeConverter &typeConverter,
+                                            RewritePatternSet &patterns) {
+  patterns.add<GPUBarrierConversion, GPUShuffleConversion,
+               LaunchConfigOpConversion<gpu::BlockIdOp>,
+               LaunchConfigOpConversion<gpu::GridDimOp>,
+               LaunchConfigOpConversion<gpu::BlockDimOp>,
+               LaunchConfigOpConversion<gpu::ThreadIdOp>,
+               LaunchConfigOpConversion<gpu::GlobalIdOp>>(typeConverter);
+}
+} // namespace mlir
diff --git a/mlir/test/Conversion/GPUToLLVMSPV/gpu-to-llvm-spv.mlir b/mlir/test/Conversion/GPUToLLVMSPV/gpu-to-llvm-spv.mlir
new file mode 100644
index 000000000000..654041b8e9aa
--- /dev/null
+++ b/mlir/test/Conversion/GPUToLLVMSPV/gpu-to-llvm-spv.mlir
@@ -0,0 +1,209 @@
+// RUN: mlir-opt -pass-pipeline="builtin.module(gpu.module(convert-gpu-to-llvm-spv))" -split-input-file -verify-diagnostics %s \
+// RUN: | FileCheck --check-prefixes=CHECK-64,CHECK %s
+// RUN: mlir-opt -pass-pipeline="builtin.module(gpu.module(convert-gpu-to-llvm-spv{index-bitwidth=32}))" -split-input-file -verify-diagnostics %s \
+// RUN: | FileCheck --check-prefixes=CHECK-32,CHECK %s
+
+gpu.module @builtins {
+  // CHECK-64: llvm.func spir_funccc @_Z14get_num_groupsj(i32) -> i64
+  // CHECK-64: llvm.func spir_funccc @_Z12get_local_idj(i32) -> i64
+  // CHECK-64: llvm.func spir_funccc @_Z14get_local_sizej(i32) -> i64
+  // CHECK-64: llvm.func spir_funccc @_Z13get_global_idj(i32) -> i64
+  // CHECK-64: llvm.func spir_funccc @_Z12get_group_idj(i32) -> i64
+  // CHECK-32: llvm.func spir_funccc @_Z14get_num_groupsj(i32) -> i32
+  // CHECK-32: llvm.func spir_funccc @_Z12get_local_idj(i32) -> i32
+  // CHECK-32: llvm.func spir_funccc @_Z14get_local_sizej(i32) -> i32
+  // CHECK-32: llvm.func spir_funccc @_Z13get_global_idj(i32) -> i32
+  // CHECK-32: llvm.func spir_funccc @_Z12get_group_idj(i32) -> i32
+
+  // CHECK-LABEL: gpu_block_id
+  func.func @gpu_block_id() -> (index, index, index) {
+    // CHECK: [[C0:%.*]] = llvm.mlir.constant(0 : i32) : i32
+    // CHECK-64: llvm.call spir_funccc @_Z12get_group_idj([[C0]]) : (i32) -> i64
+    // CHECK-32: llvm.call spir_funccc @_Z12get_group_idj([[C0]]) : (i32) -> i32
+    %block_id_x = gpu.block_id x
+    // CHECK: [[C1:%.*]] = llvm.mlir.constant(1 : i32) : i32
+    // CHECK-64: llvm.call spir_funccc @_Z12get_group_idj([[C1]]) : (i32) -> i64
+    // CHECK-32: llvm.call spir_funccc @_Z12get_group_idj([[C1]]) : (i32) -> i32
+    %block_id_y = gpu.block_id y
+    // CHECK: [[C2:%.*]] = llvm.mlir.constant(2 : i32) : i32
+    // CHECK-64: llvm.call spir_funccc @_Z12get_group_idj([[C2]]) : (i32) -> i64
+    // CHECK-32: llvm.call spir_funccc @_Z12get_group_idj([[C2]]) : (i32) -> i32
+    %block_id_z = gpu.block_id z
+    return %block_id_x, %block_id_y, %block_id_z : index, index, index
+  }
+
+  // CHECK-LABEL: gpu_global_id
+  func.func @gpu_global_id() -> (index, index, index) {
+    // CHECK: [[C0:%.*]] = llvm.mlir.constant(0 : i32) : i32
+    // CHECK-64: llvm.call spir_funccc @_Z13get_global_idj([[C0]]) : (i32) -> i64
+    // CHECK-32: llvm.call spir_funccc @_Z13get_global_idj([[C0]]) : (i32) -> i32
+    %global_id_x = gpu.global_id x
+    // CHECK: [[C1:%.*]] = llvm.mlir.constant(1 : i32) : i32
+    // CHECK-64: llvm.call spir_funccc @_Z13get_global_idj([[C1]]) : (i32) -> i64
+    // CHECK-32: llvm.call spir_funccc @_Z13get_global_idj([[C1]]) : (i32) -> i32
+    %global_id_y = gpu.global_id y
+    // CHECK: [[C2:%.*]] = llvm.mlir.constant(2 : i32) : i32
+    // CHECK-64: llvm.call spir_funccc @_Z13get_global_idj([[C2]]) : (i32) -> i64
+    // CHECK-32: llvm.call spir_funccc @_Z13get_global_idj([[C2]]) : (i32) -> i32
+    %global_id_z = gpu.global_id z
+    return %global_id_x, %global_id_y, %global_id_z : index, index, index
+  }
+
+  // CHECK-LABEL: gpu_block_dim
+  func.func @gpu_block_dim() -> (index, index, index) {
+    // CHECK: [[C0:%.*]] = llvm.mlir.constant(0 : i32) : i32
+    // CHECK-64: llvm.call spir_funccc @_Z14get_local_sizej([[C0]]) : (i32) -> i64
+    // CHECK-32: llvm.call spir_funccc @_Z14get_local_sizej([[C0]]) : (i32) -> i32
+    %block_dim_x = gpu.block_dim x
+    // CHECK: [[C1:%.*]] = llvm.mlir.constant(1 : i32) : i32
+    // CHECK-64: llvm.call spir_funccc @_Z14get_local_sizej([[C1]]) : (i32) -> i64
+    // CHECK-32: llvm.call spir_funccc @_Z14get_local_sizej([[C1]]) : (i32) -> i32
+    %block_dim_y = gpu.block_dim y
+    // CHECK: [[C2:%.*]] = llvm.mlir.constant(2 : i32) : i32
+    // CHECK-64: llvm.call spir_funccc @_Z14get_local_sizej([[C2]]) : (i32) -> i64
+    // CHECK-32: llvm.call spir_funccc @_Z14get_local_sizej([[C2]]) : (i32) -> i32
+    %block_dim_z = gpu.block_dim z
+    return %block_dim_x, %block_dim_y, %block_dim_z : index, index, index
+  }
+
+  // CHECK-LABEL: gpu_thread_id
+  func.func @gpu_thread_id() -> (index, index, index) {
+    // CHECK: [[C0:%.*]] = llvm.mlir.constant(0 : i32) : i32
+    // CHECK-64: llvm.call spir_funccc @_Z12get_local_idj([[C0]]) : (i32) -> i64
+    // CHECK-32: llvm.call spir_funccc @_Z12get_local_idj([[C0]]) : (i32) -> i32
+    %thread_id_x = gpu.thread_id x
+    // CHECK: [[C1:%.*]] = llvm.mlir.constant(1 : i32) : i32
+    // CHECK-64: llvm.call spir_funccc @_Z12get_local_idj([[C1]]) : (i32) -> i64
+    // CHECK-32: llvm.call spir_funccc @_Z12get_local_idj([[C1]]) : (i32) -> i32
+    %thread_id_y = gpu.thread_id y
+    // CHECK: [[C2:%.*]] = llvm.mlir.constant(2 : i32) : i32
+    // CHECK-64: llvm.call spir_funccc @_Z12get_local_idj([[C2]]) : (i32) -> i64
+    // CHECK-32: llvm.call spir_funccc @_Z12get_local_idj([[C2]]) : (i32) -> i32
+    %thread_id_z = gpu.thread_id z
+    return %thread_id_x, %thread_id_y, %thread_id_z : index, index, index
+  }
+
+  // CHECK-LABEL: gpu_grid_dim
+  func.func @gpu_grid_dim() -> (index, index, index) {
+    // CHECK: [[C0:%.*]] = llvm.mlir.constant(0 : i32) : i32
+    // CHECK-64: llvm.call spir_funccc @_Z14get_num_groupsj([[C0]]) : (i32) -> i64
+    // CHECK-32: llvm.call spir_funccc @_Z14get_num_groupsj([[C0]]) : (i32) -> i32
+    %grid_dim_x = gpu.grid_dim x
+    // CHECK: [[C1:%.*]] = llvm.mlir.constant(1 : i32) : i32
+    // CHECK-64: llvm.call spir_funccc @_Z14get_num_groupsj([[C1]]) : (i32) -> i64
+    // CHECK-32: llvm.call spir_funccc @_Z14get_num_groupsj([[C1]]) : (i32) -> i32
+    %grid_dim_y = gpu.grid_dim y
+    // CHECK: [[C2:%.*]] = llvm.mlir.constant(2 : i32) : i32
+    // CHECK-64: llvm.call spir_funccc @_Z14get_num_groupsj([[C2]]) : (i32) -> i64
+    // CHECK-32: llvm.call spir_funccc @_Z14get_num_groupsj([[C2]]) : (i32) -> i32
+    %grid_dim_z = gpu.grid_dim z
+    return %grid_dim_x, %grid_dim_y, %grid_dim_z : index, index, index
+  }
+}
+
+// -----
+
+gpu.module @barriers {
+  // CHECK: llvm.func spir_funccc @_Z7barrierj(i32)
+
+  // CHECK-LABEL: gpu_barrier
+  func.func @gpu_barrier() {
+    // CHECK: [[FLAGS:%.*]] = llvm.mlir.constant(1 : i32) : i32
+    // CHECK: llvm.call spir_funccc @_Z7barrierj([[FLAGS]]) : (i32) -> ()
+    gpu.barrier
+    return
+  }
+}
+
+// -----
+
+// Check `gpu.shuffle` conversion with default subgroup size.
+
+gpu.module @shuffles {
+  // CHECK: llvm.func spir_funccc @_Z22sub_group_shuffle_downdj(f64, i32) -> f64
+  // CHECK: llvm.func spir_funccc @_Z20sub_group_shuffle_upfj(f32, i32) -> f32
+  // CHECK: llvm.func spir_funccc @_Z21sub_group_shuffle_xorlj(i64, i32) -> i64
+  // CHECK: llvm.func spir_funccc @_Z17sub_group_shuffleij(i32, i32) -> i32
+
+  // CHECK-LABEL: gpu_shuffles
+  // CHECK-SAME: (%[[VAL_0:.*]]: i32, %[[VAL_1:.*]]: i32, %[[VAL_2:.*]]: i64, %[[VAL_3:.*]]: i32, %[[VAL_4:.*]]: f32, %[[VAL_5:.*]]: i32, %[[VAL_6:.*]]: f64, %[[VAL_7:.*]]: i32)
+  func.func @gpu_shuffles(%val0: i32, %id: i32,
+                          %val1: i64, %mask: i32,
+                          %val2: f32, %delta_up: i32,
+                          %val3: f64, %delta_down: i32) {
+    %width = arith.constant 32 : i32
+    // CHECK: llvm.call spir_funccc @_Z17sub_group_shuffleij(%[[VAL_0]], %[[VAL_1]]) : (i32, i32) -> i32
+    // CHECK: llvm.mlir.constant(true) : i1
+    // CHECK: llvm.call spir_funccc @_Z21sub_group_shuffle_xorlj(%[[VAL_2]], %[[VAL_3]]) : (i64, i32) -> i64
+    // CHECK: llvm.mlir.constant(true) : i1
+    // CHECK: llvm.call spir_funccc @_Z20sub_group_shuffle_upfj(%[[VAL_4]], %[[VAL_5]]) : (f32, i32) -> f32
+    // CHECK: llvm.mlir.constant(true) : i1
+    // CHECK: llvm.call spir_funccc @_Z22sub_group_shuffle_downdj(%[[VAL_6]], %[[VAL_7]]) : (f64, i32) -> f64
+    // CHECK: llvm.mlir.constant(true) : i1
+    %shuffleResult0, %valid0 = gpu.shuffle idx %val0, %id, %width : i32
+    %shuffleResult1, %valid1 = gpu.shuffle xor %val1, %mask, %width : i64
+    %shuffleResult2, %valid2 = gpu.shuffle up %val2, %delta_up, %width : f32
+    %shuffleResult3, %valid3 = gpu.shuffle down %val3, %delta_down, %width : f64
+    return
+  }
+}
+
+// -----
+
+// Check `gpu.shuffle` conversion with explicit subgroup size.
+
+gpu.module @shuffles attributes {
+  spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [Kernel], []>,
+    #spirv.resource_limits<subgroup_size = 16>>
+} {
+  // CHECK: llvm.func spir_funccc @_Z22sub_group_shuffle_downdj(f64, i32) -> f64
+  // CHECK: llvm.func spir_funccc @_Z20sub_group_shuffle_upfj(f32, i32) -> f32
+  // CHECK: llvm.func spir_funccc @_Z21sub_group_shuffle_xorlj(i64, i32) -> i64
+  // CHECK: llvm.func spir_funccc @_Z17sub_group_shuffleij(i32, i32) -> i32
+
+  // CHECK-LABEL: gpu_shuffles
+  // CHECK-SAME: (%[[VAL_0:.*]]: i32, %[[VAL_1:.*]]: i32, %[[VAL_2:.*]]: i64, %[[VAL_3:.*]]: i32, %[[VAL_4:.*]]: f32, %[[VAL_5:.*]]: i32, %[[VAL_6:.*]]: f64, %[[VAL_7:.*]]: i32)
+  func.func @gpu_shuffles(%val0: i32, %id: i32,
+                          %val1: i64, %mask: i32,
+                          %val2: f32, %delta_up: i32,
+                          %val3: f64, %delta_down: i32) {
+    %width = arith.constant 16 : i32
+    // CHECK: llvm.call spir_funccc @_Z17sub_group_shuffleij(%[[VAL_0]], %[[VAL_1]]) : (i32, i32) -> i32
+    // CHECK: llvm.mlir.constant(true) : i1
+    // CHECK: llvm.call spir_funccc @_Z21sub_group_shuffle_xorlj(%[[VAL_2]], %[[VAL_3]]) : (i64, i32) -> i64
+    // CHECK: llvm.mlir.constant(true) : i1
+    // CHECK: llvm.call spir_funccc @_Z20sub_group_shuffle_upfj(%[[VAL_4]], %[[VAL_5]]) : (f32, i32) -> f32
+    // CHECK: llvm.mlir.constant(true) : i1
+    // CHECK: llvm.call spir_funccc @_Z22sub_group_shuffle_downdj(%[[VAL_6]], %[[VAL_7]]) : (f64, i32) -> f64
+    // CHECK: llvm.mlir.constant(true) : i1
+    %shuffleResult0, %valid0 = gpu.shuffle idx %val0, %id, %width : i32
+    %shuffleResult1, %valid1 = gpu.shuffle xor %val1, %mask, %width : i64
+    %shuffleResult2, %valid2 = gpu.shuffle up %val2, %delta_up, %width : f32
+    %shuffleResult3, %valid3 = gpu.shuffle down %val3, %delta_down, %width : f64
+    return
+  }
+}
+
+// -----
+
+// Cannot convert due to shuffle width and target subgroup size mismatch
+
+gpu.module @shuffles_mismatch {
+  func.func @gpu_shuffles(%val: i32, %id: i32) {
+    %width = arith.constant 16 : i32
+    // expected-error@below {{failed to legalize operation 'gpu.shuffle' that was explicitly marked illegal}}
+    %shuffleResult, %valid = gpu.shuffle idx %val, %id, %width : i32
+    return
+  }
+}
+
+// -----
+
+// Cannot convert due to variable shuffle width
+
+gpu.module @shuffles_mismatch {
+  func.func @gpu_shuffles(%val: i32, %id: i32, %width: i32) {
+    // expected-error@below {{failed to legalize operation 'gpu.shuffle' that was explicitly marked illegal}}
+    %shuffleResult, %valid = gpu.shuffle idx %val, %id, %width : i32
+    return
+  }
+}
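
For a quick end-to-end picture, below is a minimal sketch of the lowering, assembled from the RUN lines and CHECK patterns in the test above. The `@example` and `@tid` names are illustrative, and the materialization casts the type converter inserts to bridge the `index` results back to `func.return` are omitted:

```mlir
// Run as in the test's RUN line:
//   mlir-opt -pass-pipeline="builtin.module(gpu.module(convert-gpu-to-llvm-spv))"
gpu.module @example {
  func.func @tid() -> index {
    // With the default 64-bit index width, this op is rewritten to:
    //   %c0 = llvm.mlir.constant(0 : i32) : i32
    //   %0 = llvm.call spir_funccc @_Z12get_local_idj(%c0) : (i32) -> i64
    // and `llvm.func spir_funccc @_Z12get_local_idj(i32) -> i64` is declared
    // once at the top of @example by lookupOrCreateSPIRVFn.
    %tid_x = gpu.thread_id x
    return %tid_x : index
  }
}
```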