//===- ConvertLaunchFuncToGpuRuntimeCalls.cpp - MLIR GPU lowering passes --===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a pass to convert gpu.launch_func ops into a sequence
// of GPU runtime calls. As most GPU runtimes do not have a stable published
// ABI, this pass uses a slim runtime layer that builds on top of the public
// API from GPU runtime headers.
//
//===----------------------------------------------------------------------===//
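// For example (illustrative only; the exact IR depends on the patterns below),
// an asynchronous `gpu.wait` roughly lowers to a call into the wrapper layer:
//
//   %stream = llvm.call @mgpuStreamCreate() : () -> !llvm.ptr
//
// and synchronous waits, allocations, copies, etc. become the corresponding
// mgpu* calls operating on such a stream.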
#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
#include "mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h"
#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
#include "mlir/Conversion/ConvertToLLVM/ToLLVMPass.h"
#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
#include "mlir/Conversion/GPUCommon/GPUToLLVM.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
#include "mlir/Dialect/Async/IR/Async.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/GPU/Transforms/Passes.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/STLExtras.h"
#define DEBUG_TYPE "gpu-to-llvm"
namespace mlir {
#define GEN_PASS_DEF_GPUTOLLVMCONVERSIONPASS
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir
using namespace mlir;
namespace {
class GpuToLLVMConversionPass
: public impl::GpuToLLVMConversionPassBase<GpuToLLVMConversionPass> {
public:
using Base::Base;
void getDependentDialects(DialectRegistry &registry) const final {
Base::getDependentDialects(registry);
registerConvertToLLVMDependentDialectLoading(registry);
}
// Run the dialect converter on the module.
void runOnOperation() override;
};
template <typename OpTy>
class ConvertOpToGpuRuntimeCallPattern : public ConvertOpToLLVMPattern<OpTy> {
public:
explicit ConvertOpToGpuRuntimeCallPattern(
const LLVMTypeConverter &typeConverter)
: ConvertOpToLLVMPattern<OpTy>(typeConverter) {}
protected:
Value getNumElements(ConversionPatternRewriter &rewriter, Location loc,
MemRefType type, MemRefDescriptor desc) const {
Type indexType = ConvertToLLVMPattern::getIndexType();
if (type.hasStaticShape())
return ConvertToLLVMPattern::createIndexAttrConstant(
rewriter, loc, indexType, type.getNumElements());
// Compute the number of elements by multiplying all the dim sizes.
uint64_t rank = type.getRank();
Value numElements = desc.size(rewriter, loc, /*pos=*/0);
for (unsigned i = 1; i < rank; i++)
numElements = rewriter.create<LLVM::MulOp>(
loc, numElements, desc.size(rewriter, loc, /*pos=*/i));
return numElements;
}
MLIRContext *context = &this->getTypeConverter()->getContext();
Type llvmVoidType = LLVM::LLVMVoidType::get(context);
LLVM::LLVMPointerType llvmPointerType = LLVM::LLVMPointerType::get(context);
Type llvmInt8Type = IntegerType::get(context, 8);
Type llvmInt16Type = IntegerType::get(context, 16);
Type llvmInt32Type = IntegerType::get(context, 32);
Type llvmInt64Type = IntegerType::get(context, 64);
Type llvmFloat32Type = Float32Type::get(context);
Type llvmIntPtrType = IntegerType::get(
context, this->getTypeConverter()->getPointerBitwidth(0));
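// Each FunctionCallBuilder below describes one entry point of the slim mgpu*
// wrapper runtime: its symbol name, result type, and argument types. The
// corresponding `llvm.func` declaration is created lazily on first use; see
// FunctionCallBuilder::create() further down in this file.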
FunctionCallBuilder streamCreateCallBuilder = {
"mgpuStreamCreate", llvmPointerType /* void *stream */, {}};
FunctionCallBuilder streamDestroyCallBuilder = {
"mgpuStreamDestroy", llvmVoidType, {llvmPointerType /* void *stream */}};
FunctionCallBuilder streamSynchronizeCallBuilder = {
"mgpuStreamSynchronize",
llvmVoidType,
{llvmPointerType /* void *stream */}};
FunctionCallBuilder streamWaitEventCallBuilder = {
"mgpuStreamWaitEvent",
llvmVoidType,
{llvmPointerType /* void *stream */, llvmPointerType /* void *event */}};
FunctionCallBuilder eventCreateCallBuilder = {
"mgpuEventCreate", llvmPointerType /* void *event */, {}};
FunctionCallBuilder eventDestroyCallBuilder = {
"mgpuEventDestroy", llvmVoidType, {llvmPointerType /* void *event */}};
FunctionCallBuilder eventSynchronizeCallBuilder = {
"mgpuEventSynchronize",
llvmVoidType,
{llvmPointerType /* void *event */}};
FunctionCallBuilder eventRecordCallBuilder = {
"mgpuEventRecord",
llvmVoidType,
{llvmPointerType /* void *event */, llvmPointerType /* void *stream */}};
FunctionCallBuilder hostRegisterCallBuilder = {
"mgpuMemHostRegisterMemRef",
llvmVoidType,
{llvmIntPtrType /* intptr_t rank */,
llvmPointerType /* void *memrefDesc */,
llvmIntPtrType /* intptr_t elementSizeBytes */}};
FunctionCallBuilder hostUnregisterCallBuilder = {
"mgpuMemHostUnregisterMemRef",
llvmVoidType,
{llvmIntPtrType /* intptr_t rank */,
llvmPointerType /* void *memrefDesc */,
llvmIntPtrType /* intptr_t elementSizeBytes */}};
FunctionCallBuilder allocCallBuilder = {
"mgpuMemAlloc",
llvmPointerType /* void * */,
{llvmIntPtrType /* intptr_t sizeBytes */,
llvmPointerType /* void *stream */,
llvmInt8Type /* bool isHostShared */}};
FunctionCallBuilder deallocCallBuilder = {
"mgpuMemFree",
llvmVoidType,
{llvmPointerType /* void *ptr */, llvmPointerType /* void *stream */}};
FunctionCallBuilder memcpyCallBuilder = {
"mgpuMemcpy",
llvmVoidType,
{llvmPointerType /* void *dst */, llvmPointerType /* void *src */,
llvmIntPtrType /* intptr_t sizeBytes */,
llvmPointerType /* void *stream */}};
FunctionCallBuilder memset16CallBuilder = {
"mgpuMemset16",
llvmVoidType,
{llvmPointerType /* void *dst */,
llvmInt16Type /* unsigned short value */,
llvmIntPtrType /* intptr_t sizeBytes */,
llvmPointerType /* void *stream */}};
FunctionCallBuilder memset32CallBuilder = {
"mgpuMemset32",
llvmVoidType,
{llvmPointerType /* void *dst */, llvmInt32Type /* unsigned int value */,
llvmIntPtrType /* intptr_t sizeBytes */,
llvmPointerType /* void *stream */}};
FunctionCallBuilder setDefaultDeviceCallBuilder = {
"mgpuSetDefaultDevice",
llvmVoidType,
{llvmInt32Type /* uint32_t devIndex */}};
FunctionCallBuilder createDnVecCallBuilder = {
"mgpuCreateDnVec",
llvmPointerType,
{llvmIntPtrType, llvmPointerType, llvmInt32Type,
llvmPointerType /* void *stream */}};
FunctionCallBuilder destroyDnVecCallBuilder = {
"mgpuDestroyDnVec",
llvmVoidType,
{llvmPointerType, llvmPointerType /* void *stream */}};
FunctionCallBuilder createDnMatCallBuilder = {
"mgpuCreateDnMat",
llvmPointerType,
{llvmIntPtrType, llvmIntPtrType, llvmPointerType, llvmInt32Type,
llvmPointerType /* void *stream */}};
FunctionCallBuilder destroyDnMatCallBuilder = {
"mgpuDestroyDnMat",
llvmVoidType,
{llvmPointerType, llvmPointerType /* void *stream */}};
FunctionCallBuilder createCooCallBuilder = {
"mgpuCreateCoo",
llvmPointerType,
{llvmIntPtrType, llvmIntPtrType, llvmIntPtrType, llvmPointerType,
llvmPointerType, llvmPointerType, llvmInt32Type, llvmInt32Type,
llvmPointerType /* void *stream */}};
FunctionCallBuilder createCooAoSCallBuilder = {
"mgpuCreateCooAoS", // deprecated in cuSPARSE 11.2
llvmPointerType,
{llvmIntPtrType, llvmIntPtrType, llvmIntPtrType, llvmPointerType,
llvmPointerType, llvmInt32Type, llvmInt32Type,
llvmPointerType /* void *stream */}};
FunctionCallBuilder createCsrCallBuilder = {
"mgpuCreateCsr",
llvmPointerType,
{llvmIntPtrType, llvmIntPtrType, llvmIntPtrType, llvmPointerType,
llvmPointerType, llvmPointerType, llvmInt32Type, llvmInt32Type,
llvmInt32Type, llvmPointerType /* void *stream */}};
FunctionCallBuilder createCscCallBuilder = {
"mgpuCreateCsc",
llvmPointerType,
{llvmIntPtrType, llvmIntPtrType, llvmIntPtrType, llvmPointerType,
llvmPointerType, llvmPointerType, llvmInt32Type, llvmInt32Type,
llvmInt32Type, llvmPointerType /* void *stream */}};
FunctionCallBuilder createBsrCallBuilder = {
"mgpuCreateBsr",
llvmPointerType,
{llvmIntPtrType, llvmIntPtrType, llvmIntPtrType, llvmIntPtrType,
llvmIntPtrType, llvmPointerType, llvmPointerType, llvmPointerType,
llvmInt32Type, llvmInt32Type, llvmInt32Type,
llvmPointerType /* void *stream */}};
FunctionCallBuilder destroySpMatCallBuilder = {
"mgpuDestroySpMat",
llvmVoidType,
{llvmPointerType, llvmPointerType /* void *stream */}};
FunctionCallBuilder spMVBufferSizeCallBuilder = {
"mgpuSpMVBufferSize",
llvmIntPtrType,
{llvmInt32Type, llvmPointerType, llvmPointerType, llvmPointerType,
llvmInt32Type, llvmPointerType /* void *stream */}};
FunctionCallBuilder spMVCallBuilder = {
"mgpuSpMV",
llvmVoidType,
{llvmInt32Type, llvmPointerType, llvmPointerType, llvmPointerType,
llvmInt32Type, llvmPointerType, llvmPointerType /* void *stream */}};
FunctionCallBuilder createSpMMBufferSizeCallBuilder = {
"mgpuSpMMBufferSize",
llvmIntPtrType,
{llvmInt32Type, llvmInt32Type, llvmPointerType, llvmPointerType,
llvmPointerType, llvmInt32Type, llvmPointerType /* void *stream */}};
FunctionCallBuilder createSpMMCallBuilder = {
"mgpuSpMM",
llvmVoidType,
{llvmInt32Type, llvmInt32Type, llvmPointerType, llvmPointerType,
llvmPointerType, llvmInt32Type, llvmPointerType,
llvmPointerType /* void *stream */}};
FunctionCallBuilder createSDDMMBufferSizeCallBuilder = {
"mgpuSDDMMBufferSize",
llvmIntPtrType,
{llvmInt32Type, llvmInt32Type, llvmPointerType, llvmPointerType,
llvmPointerType, llvmInt32Type, llvmPointerType /* void *stream */}};
FunctionCallBuilder createSDDMMCallBuilder = {
"mgpuSDDMM",
llvmVoidType,
{llvmInt32Type, llvmInt32Type, llvmPointerType, llvmPointerType,
llvmPointerType, llvmInt32Type, llvmPointerType,
llvmPointerType /* void *stream */}};
FunctionCallBuilder createLtDnMatCallBuilder = {
"mgpuCreateCuSparseLtDnMat",
llvmVoidType,
{llvmPointerType, llvmIntPtrType, llvmIntPtrType, llvmPointerType,
llvmInt32Type, llvmPointerType /* void *stream */}};
FunctionCallBuilder destroyCuSparseLtSpMatBuilder = {
"mgpuDestroyCuSparseLtSpMat",
llvmVoidType,
{llvmPointerType, llvmPointerType /* void *stream */}};
FunctionCallBuilder destroyCuSparseLtDnMatBuilder = {
"mgpuDestroyCuSparseLtDnMat",
llvmVoidType,
{llvmPointerType, llvmPointerType /* void *stream */}};
FunctionCallBuilder create2To4SpMatCallBuilder = {
"mgpuCusparseLtCreate2To4SpMat",
llvmVoidType,
{llvmPointerType, llvmIntPtrType, llvmIntPtrType, llvmPointerType,
llvmInt32Type, llvmPointerType /* void *stream */}};
FunctionCallBuilder createCuSparseLtSpMMBufferSizeBuilder = {
"mgpuCuSparseLtSpMMBufferSize",
llvmVoidType,
{llvmPointerType, llvmInt32Type, llvmInt32Type, llvmPointerType,
llvmPointerType, llvmPointerType, llvmInt32Type, llvmInt32Type,
llvmPointerType /*void *stream*/}};
FunctionCallBuilder createCuSparseLtSpMMBuilder = {
"mgpuCuSparseLtSpMM",
llvmVoidType,
{llvmPointerType, llvmPointerType, llvmPointerType, llvmPointerType,
llvmPointerType, llvmPointerType, llvmPointerType /*void *stream*/}};
FunctionCallBuilder createSpGEMMCreateDescrBuilder = {
"mgpuSpGEMMCreateDescr",
llvmPointerType,
{llvmPointerType /*void *stream*/}};
FunctionCallBuilder createSpGEMMDestroyDescrBuilder = {
"mgpuSpGEMMDestroyDescr",
llvmVoidType,
{llvmPointerType /*s*/, llvmPointerType /*void *stream*/}};
FunctionCallBuilder createSpGEMMWorkEstimationBuilder = {
"mgpuSpGEMMWorkEstimation",
llvmIntPtrType,
{llvmPointerType /*s*/, llvmInt32Type /*ma*/, llvmInt32Type /*mb*/,
llvmPointerType /*a*/, llvmPointerType /*b*/, llvmPointerType /*c*/,
llvmInt32Type /*ctp*/, llvmIntPtrType /*bs*/, llvmPointerType /*buf*/,
llvmPointerType /*void *stream*/}};
FunctionCallBuilder createSpGEMMComputeBuilder = {
"mgpuSpGEMMCompute",
llvmIntPtrType,
{llvmPointerType /*s*/, llvmInt32Type /*ma*/, llvmInt32Type /*mb*/,
llvmPointerType /*a*/, llvmPointerType /*b*/, llvmPointerType /*c*/,
llvmInt32Type /*ctp*/, llvmIntPtrType /*bs*/, llvmPointerType /*buf*/,
llvmPointerType /*void *stream*/}};
FunctionCallBuilder createSpGEMMCopyBuilder = {
"mgpuSpGEMMCopy",
llvmVoidType,
{llvmPointerType /*s*/, llvmInt32Type /*ma*/, llvmInt32Type /*mb*/,
llvmPointerType /*a*/, llvmPointerType /*b*/, llvmPointerType /*c*/,
llvmInt32Type /*ctp*/, llvmPointerType /*void *stream*/}};
FunctionCallBuilder createSpMatGetSizeBuilder = {
"mgpuSpMatGetSize",
llvmVoidType,
{llvmPointerType /*mc*/, llvmPointerType /*rc*/, llvmPointerType /*cc*/,
llvmPointerType /*nc*/, llvmPointerType /*void *stream*/}};
FunctionCallBuilder createSetCsrPointersBuilder = {
"mgpuSetCsrPointers",
llvmVoidType,
{llvmPointerType /*spmat*/, llvmPointerType /*pos*/,
llvmPointerType /*crd*/, llvmPointerType /*val*/,
llvmPointerType /*void *stream*/}};
};
/// A rewrite pattern to convert gpu.host_register operations into a GPU runtime
/// call. Currently it supports CUDA and ROCm (HIP).
class ConvertHostRegisterOpToGpuRuntimeCallPattern
: public ConvertOpToGpuRuntimeCallPattern<gpu::HostRegisterOp> {
public:
ConvertHostRegisterOpToGpuRuntimeCallPattern(
const LLVMTypeConverter &typeConverter)
: ConvertOpToGpuRuntimeCallPattern<gpu::HostRegisterOp>(typeConverter) {}
private:
LogicalResult
matchAndRewrite(gpu::HostRegisterOp hostRegisterOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override;
};
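/// A rewrite pattern to convert gpu.host_unregister operations into a GPU
/// runtime call. Currently it supports CUDA and ROCm (HIP).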
class ConvertHostUnregisterOpToGpuRuntimeCallPattern
: public ConvertOpToGpuRuntimeCallPattern<gpu::HostUnregisterOp> {
public:
ConvertHostUnregisterOpToGpuRuntimeCallPattern(
const LLVMTypeConverter &typeConverter)
: ConvertOpToGpuRuntimeCallPattern<gpu::HostUnregisterOp>(typeConverter) {
}
private:
LogicalResult
matchAndRewrite(gpu::HostUnregisterOp hostUnregisterOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override;
};
/// A rewrite pattern to convert gpu.alloc operations into a GPU runtime
/// call. Currently it supports CUDA and ROCm (HIP).
class ConvertAllocOpToGpuRuntimeCallPattern
: public ConvertOpToGpuRuntimeCallPattern<gpu::AllocOp> {
public:
ConvertAllocOpToGpuRuntimeCallPattern(const LLVMTypeConverter &typeConverter)
: ConvertOpToGpuRuntimeCallPattern<gpu::AllocOp>(typeConverter) {}
private:
LogicalResult
matchAndRewrite(gpu::AllocOp allocOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override;
};
/// A rewrite pattern to convert gpu.dealloc operations into a GPU runtime
/// call. Currently it supports CUDA and ROCm (HIP).
class ConvertDeallocOpToGpuRuntimeCallPattern
: public ConvertOpToGpuRuntimeCallPattern<gpu::DeallocOp> {
public:
ConvertDeallocOpToGpuRuntimeCallPattern(
const LLVMTypeConverter &typeConverter)
: ConvertOpToGpuRuntimeCallPattern<gpu::DeallocOp>(typeConverter) {}
private:
LogicalResult
matchAndRewrite(gpu::DeallocOp deallocOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override;
};
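/// A rewrite pattern to convert the !gpu.async.token operands of `async.yield`
/// ops into GPU runtime calls; see the comment on matchAndRewrite below.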
class ConvertAsyncYieldToGpuRuntimeCallPattern
: public ConvertOpToGpuRuntimeCallPattern<async::YieldOp> {
public:
ConvertAsyncYieldToGpuRuntimeCallPattern(
const LLVMTypeConverter &typeConverter)
: ConvertOpToGpuRuntimeCallPattern<async::YieldOp>(typeConverter) {}
private:
LogicalResult
matchAndRewrite(async::YieldOp yieldOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override;
};
/// A rewrite pattern to convert gpu.wait operations into a GPU runtime
/// call. Currently it supports CUDA and ROCm (HIP).
class ConvertWaitOpToGpuRuntimeCallPattern
: public ConvertOpToGpuRuntimeCallPattern<gpu::WaitOp> {
public:
ConvertWaitOpToGpuRuntimeCallPattern(const LLVMTypeConverter &typeConverter)
: ConvertOpToGpuRuntimeCallPattern<gpu::WaitOp>(typeConverter) {}
private:
LogicalResult
matchAndRewrite(gpu::WaitOp waitOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override;
};
/// A rewrite pattern to convert gpu.wait async operations into a GPU runtime
/// call. Currently it supports CUDA and ROCm (HIP).
class ConvertWaitAsyncOpToGpuRuntimeCallPattern
: public ConvertOpToGpuRuntimeCallPattern<gpu::WaitOp> {
public:
ConvertWaitAsyncOpToGpuRuntimeCallPattern(
const LLVMTypeConverter &typeConverter)
: ConvertOpToGpuRuntimeCallPattern<gpu::WaitOp>(typeConverter) {}
private:
LogicalResult
matchAndRewrite(gpu::WaitOp waitOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override;
};
/// A rewrite pattern to legalize gpu.launch_func with LLVM types.
class LegalizeLaunchFuncOpPattern
: public ConvertOpToGpuRuntimeCallPattern<gpu::LaunchFuncOp> {
public:
LegalizeLaunchFuncOpPattern(const LLVMTypeConverter &typeConverter,
bool kernelBarePtrCallConv,
bool kernelIntersperseSizeCallConv)
: ConvertOpToGpuRuntimeCallPattern<gpu::LaunchFuncOp>(typeConverter),
kernelBarePtrCallConv(kernelBarePtrCallConv),
kernelIntersperseSizeCallConv(kernelIntersperseSizeCallConv) {}
private:
LogicalResult
matchAndRewrite(gpu::LaunchFuncOp launchOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override;
bool kernelBarePtrCallConv;
bool kernelIntersperseSizeCallConv;
};
/// A rewrite pattern to convert gpu.memcpy operations into a GPU runtime
/// call. Currently it supports CUDA and ROCm (HIP).
class ConvertMemcpyOpToGpuRuntimeCallPattern
: public ConvertOpToGpuRuntimeCallPattern<gpu::MemcpyOp> {
public:
ConvertMemcpyOpToGpuRuntimeCallPattern(const LLVMTypeConverter &typeConverter)
: ConvertOpToGpuRuntimeCallPattern<gpu::MemcpyOp>(typeConverter) {}
private:
LogicalResult
matchAndRewrite(gpu::MemcpyOp memcpyOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override;
};
/// A rewrite pattern to convert gpu.memset operations into a GPU runtime
/// call. Currently it supports CUDA and ROCm (HIP).
class ConvertMemsetOpToGpuRuntimeCallPattern
: public ConvertOpToGpuRuntimeCallPattern<gpu::MemsetOp> {
public:
ConvertMemsetOpToGpuRuntimeCallPattern(const LLVMTypeConverter &typeConverter)
: ConvertOpToGpuRuntimeCallPattern<gpu::MemsetOp>(typeConverter) {}
private:
LogicalResult
matchAndRewrite(gpu::MemsetOp memsetOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override;
};
/// A rewrite pattern to convert gpu.set_default_device to a GPU runtime call.
/// Currently supports CUDA and ROCm (HIP)
class ConvertSetDefaultDeviceOpToGpuRuntimeCallPattern
: public ConvertOpToGpuRuntimeCallPattern<gpu::SetDefaultDeviceOp> {
public:
ConvertSetDefaultDeviceOpToGpuRuntimeCallPattern(
const LLVMTypeConverter &typeConverter)
: ConvertOpToGpuRuntimeCallPattern<gpu::SetDefaultDeviceOp>(
typeConverter) {}
LogicalResult
matchAndRewrite(gpu::SetDefaultDeviceOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override;
};
/// Generic rewriting rule for operations on sparse matrices.
/// Currently supports CUDA (by means of cuSparse and cuSparseLt).
#define DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(op_name) \
class Convert##op_name##ToGpuRuntimeCallPattern \
: public ConvertOpToGpuRuntimeCallPattern<gpu::op_name> { \
public: \
Convert##op_name##ToGpuRuntimeCallPattern( \
const LLVMTypeConverter &typeConverter) \
: ConvertOpToGpuRuntimeCallPattern<gpu::op_name>(typeConverter) {} \
\
private: \
LogicalResult \
matchAndRewrite(gpu::op_name op, OpAdaptor adaptor, \
ConversionPatternRewriter &rewriter) const override; \
};
DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(CreateDnTensorOp)
DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(DestroyDnTensorOp)
DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(CreateCooOp)
DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(CreateCooAoSOp)
DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(CreateCsrOp)
DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(CreateCscOp)
DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(CreateBsrOp)
DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(Create2To4SpMatOp)
DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(DestroySpMatOp)
DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SpMVBufferSizeOp)
DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SpMVOp)
DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SpMMBufferSizeOp)
DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SDDMMBufferSizeOp)
DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SpMMOp)
DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SDDMMOp)
DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SpGEMMCreateDescrOp)
DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SpGEMMDestroyDescrOp)
DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SpGEMMWorkEstimationOrComputeOp)
DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SpGEMMCopyOp)
DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SpMatGetSizeOp)
DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SetCsrPointersOp)
} // namespace
void GpuToLLVMConversionPass::runOnOperation() {
MLIRContext *context = &getContext();
// Perform progressive lowering of vector transfer operations.
{
RewritePatternSet patterns(&getContext());
// Vector transfer ops with rank > 1 should be lowered with VectorToSCF.
vector::populateVectorTransferLoweringPatterns(patterns,
/*maxTransferRank=*/1);
if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
return signalPassFailure();
}
LowerToLLVMOptions options(context);
options.useBarePtrCallConv = hostBarePtrCallConv;
RewritePatternSet patterns(context);
ConversionTarget target(*context);
target.addLegalDialect<LLVM::LLVMDialect>();
LLVMTypeConverter converter(context, options);
// Populate all patterns from all dialects that implement the
// `ConvertToLLVMPatternInterface` interface.
for (Dialect *dialect : context->getLoadedDialects()) {
auto iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
if (!iface)
continue;
iface->populateConvertToLLVMConversionPatterns(target, converter, patterns);
}
// Preserve GPU modules and binaries. Modules are preserved as they can be
// converted later by `gpu-module-to-binary`.
target.addLegalOp<gpu::GPUModuleOp, gpu::BinaryOp>();
// Accept LaunchFuncOps as legal if their operands have already been lowered.
target.addDynamicallyLegalOp<gpu::LaunchFuncOp>(
[&](gpu::LaunchFuncOp op) -> bool { return converter.isLegal(op); });
// These aren't covered by the ConvertToLLVMPatternInterface right now.
populateVectorToLLVMConversionPatterns(converter, patterns);
populateFinalizeMemRefToLLVMConversionPatterns(converter, patterns);
populateAsyncStructuralTypeConversionsAndLegality(converter, patterns,
target);
populateGpuToLLVMConversionPatterns(converter, patterns,
kernelBarePtrCallConv,
kernelIntersperseSizeCallConv);
if (failed(
applyPartialConversion(getOperation(), target, std::move(patterns))))
signalPassFailure();
}
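// Note: this pass is usually run as part of a larger lowering pipeline; with
// mlir-opt it is exposed through the generated pass registration (typically as
// `--gpu-to-llvm`), e.g.:
//   mlir-opt input.mlir --gpu-to-llvm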
LLVM::CallOp FunctionCallBuilder::create(Location loc, OpBuilder &builder,
ArrayRef<Value> arguments) const {
auto module = builder.getBlock()->getParent()->getParentOfType<ModuleOp>();
auto function = [&] {
if (auto function = module.lookupSymbol<LLVM::LLVMFuncOp>(functionName))
return function;
return OpBuilder::atBlockEnd(module.getBody())
.create<LLVM::LLVMFuncOp>(loc, functionName, functionType);
}();
return builder.create<LLVM::CallOp>(loc, function, arguments);
}
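// For instance, `streamCreateCallBuilder.create(loc, builder, {})` declares
//   llvm.func @mgpuStreamCreate() -> !llvm.ptr
// at module scope (if not already present) and emits
//   llvm.call @mgpuStreamCreate() : () -> !llvm.ptr
// at the current insertion point (illustrative; see the builders above).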
// Corresponding to cusparseIndexType_t defined in cusparse.h.
static int32_t getCuSparseIndexTypeFrom(Type type) {
if (type.isInteger(16))
return 1; // CUSPARSE_INDEX_16U
if (type.isInteger(32))
return 2; // CUSPARSE_INDEX_32I
return 3; // CUSPARSE_INDEX_64I
}
static int32_t getCuSparseLtDataTypeFrom(Type type) {
if (type.isF16())
return 0; // CUSPARSE_COMPUTE_16F,
if (type.isInteger(32))
return 1; // CUSPARSE_COMPUTE_32I
llvm_unreachable("unsupported type");
// TODO: add support for TF32
}
// Corresponding to cudaDataType_t defined in CUDA library_types.h.
static int32_t getCuSparseDataTypeFrom(Type type) {
if (llvm::isa<ComplexType>(type)) {
// get the element type
auto elementType = cast<ComplexType>(type).getElementType();
if (elementType.isBF16())
return 15; // CUDA_C_16BF
if (elementType.isF16())
return 6; // CUDA_C_16F
if (elementType.isF32())
return 4; // CUDA_C_32F
if (elementType.isF64())
return 5; // CUDA_C_64F
if (elementType.isInteger(8))
return 7; // CUDA_C_8I
if (elementType.isInteger(16))
return 21; // CUDA_C_16I
if (elementType.isInteger(32))
return 11; // CUDA_C_32I
}
if (type.isBF16())
return 14; // CUDA_R_16BF
if (type.isF16())
return 2; // CUDA_R_16F
if (type.isF32())
return 0; // CUDA_R_32F
if (type.isF64())
return 1; // CUDA_R_64F
if (type.isInteger(8))
return 3; // CUDA_R_8I
if (type.isInteger(16))
return 20; // CUDA_R_16I
if (type.isInteger(32))
return 10; // CUDA_R_32I
llvm_unreachable("unsupported element type");
}
static gpu::Prune2To4SpMatFlag get2To4PruneFlag(Value spMat) {
return spMat.getDefiningOp<gpu::Create2To4SpMatOp>().getPruneFlag();
}
// TODO: We may want a disablement/warning at MLIR compile time (i.e., while
// this pass runs): cusparseLt currently does not work for CUDA architectures
// < 8.0 and triggers an error at CUDA program run time. It would be good to at
// least emit a warning when the target architecture is < 8.0 and the user
// still requests cusparseLt, and to disable the cusparseLt calls when lowering
// the GPU sparse dialect to LLVM calls for such architectures.
static bool is2To4Sparsity(Value spMat) {
if (auto op = spMat.getDefiningOp<gpu::Create2To4SpMatOp>())
return true;
if (auto op = spMat.getDefiningOp<gpu::CreateCooOp>())
return false;
if (auto op = spMat.getDefiningOp<gpu::CreateCooAoSOp>())
return false;
if (auto op = spMat.getDefiningOp<gpu::CreateCsrOp>())
return false;
if (auto op = spMat.getDefiningOp<gpu::CreateCscOp>())
return false;
if (auto op = spMat.getDefiningOp<gpu::CreateBsrOp>())
return false;
// Print the spMat defining op
spMat.getDefiningOp()->print(llvm::errs());
llvm_unreachable("cannot find spmat def");
}
static bool isSpMMCusparseLtOp(Value op) {
for (Operation *user : op.getUsers()) {
auto spmmOp = dyn_cast<gpu::SpMMOp>(user);
// Use cusparseLt if the SpMM's A operand has 2:4 (50%) sparsity.
if (!spmmOp)
continue;
if (is2To4Sparsity(spmmOp.getSpmatA()))
return true;
}
return false;
}
// Returns whether all operands are of LLVM type.
static LogicalResult areAllLLVMTypes(Operation *op, ValueRange operands,
ConversionPatternRewriter &rewriter) {
if (!llvm::all_of(operands, [](Value value) {
return LLVM::isCompatibleType(value.getType());
}))
return rewriter.notifyMatchFailure(
op, "Cannot convert if operands aren't of LLVM type.");
return success();
}
static LogicalResult
isAsyncWithOneDependency(ConversionPatternRewriter &rewriter,
gpu::AsyncOpInterface op) {
if (op.getAsyncDependencies().size() != 1)
return rewriter.notifyMatchFailure(
op, "Can only convert with exactly one async dependency.");
if (!op.getAsyncToken())
return rewriter.notifyMatchFailure(op, "Can convert only async version.");
return success();
}
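// In other words, only ops of the (rough) form
//   %result, %token = gpu.<op> async [%dep] ...
// with exactly one async dependency and an async token result are converted;
// the single dependency is then reused as the stream for the runtime call.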
LogicalResult ConvertHostRegisterOpToGpuRuntimeCallPattern::matchAndRewrite(
gpu::HostRegisterOp hostRegisterOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
auto *op = hostRegisterOp.getOperation();
if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)))
return failure();
Location loc = op->getLoc();
auto memRefType = hostRegisterOp.getValue().getType();
auto elementType = cast<UnrankedMemRefType>(memRefType).getElementType();
auto elementSize = getSizeInBytes(loc, elementType, rewriter);
auto arguments = getTypeConverter()->promoteOperands(
loc, op->getOperands(), adaptor.getOperands(), rewriter);
arguments.push_back(elementSize);
hostRegisterCallBuilder.create(loc, rewriter, arguments);
rewriter.eraseOp(op);
return success();
}
LogicalResult ConvertHostUnregisterOpToGpuRuntimeCallPattern::matchAndRewrite(
gpu::HostUnregisterOp hostUnregisterOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
Operation *op = hostUnregisterOp.getOperation();
if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)))
return failure();
Location loc = op->getLoc();
auto memRefType = hostUnregisterOp.getValue().getType();
auto elementType = cast<UnrankedMemRefType>(memRefType).getElementType();
auto elementSize = getSizeInBytes(loc, elementType, rewriter);
auto arguments = getTypeConverter()->promoteOperands(
loc, op->getOperands(), adaptor.getOperands(), rewriter);
arguments.push_back(elementSize);
hostUnregisterCallBuilder.create(loc, rewriter, arguments);
rewriter.eraseOp(op);
return success();
}
LogicalResult ConvertAllocOpToGpuRuntimeCallPattern::matchAndRewrite(
gpu::AllocOp allocOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
MemRefType memRefType = allocOp.getType();
if (failed(areAllLLVMTypes(allocOp, adaptor.getOperands(), rewriter)) ||
!isConvertibleAndHasIdentityMaps(memRefType))
return failure();
auto loc = allocOp.getLoc();
bool isShared = allocOp.getHostShared();
if (isShared && allocOp.getAsyncToken())
return rewriter.notifyMatchFailure(
allocOp, "Host Shared allocation cannot be done async");
if (!isShared && failed(isAsyncWithOneDependency(rewriter, allocOp)))
return failure();
// Get shape of the memref as values: static sizes are constant
// values and dynamic sizes are passed to 'alloc' as operands.
SmallVector<Value, 4> shape;
SmallVector<Value, 4> strides;
Value sizeBytes;
getMemRefDescriptorSizes(loc, memRefType, adaptor.getDynamicSizes(), rewriter,
shape, strides, sizeBytes);
// Allocate the underlying buffer and store a pointer to it in the MemRef
// descriptor.
auto nullPtr = rewriter.create<mlir::LLVM::ZeroOp>(loc, llvmPointerType);
Value stream = adaptor.getAsyncDependencies().empty()
? nullPtr
: adaptor.getAsyncDependencies().front();
auto isHostShared = rewriter.create<mlir::LLVM::ConstantOp>(
loc, llvmInt8Type, rewriter.getI8IntegerAttr(isShared));
Value allocatedPtr =
allocCallBuilder.create(loc, rewriter, {sizeBytes, stream, isHostShared})
.getResult();
// No alignment.
Value alignedPtr = allocatedPtr;
// Create the MemRef descriptor.
auto memRefDescriptor = this->createMemRefDescriptor(
loc, memRefType, allocatedPtr, alignedPtr, shape, strides, rewriter);
if (allocOp.getAsyncToken()) {
// Async alloc: make dependent ops use the same stream.
rewriter.replaceOp(allocOp, {memRefDescriptor, stream});
} else {
rewriter.replaceOp(allocOp, {memRefDescriptor});
}
return success();
}
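// Sketch of the IR generated by the pattern above (illustrative, assuming a
// 64-bit index type):
//   %ptr = llvm.call @mgpuMemAlloc(%sizeBytes, %stream, %isHostShared)
//            : (i64, !llvm.ptr, i8) -> !llvm.ptr
// where %stream is the single async dependency (or a null pointer in the
// synchronous, host-shared case), followed by construction of the memref
// descriptor around %ptr.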
LogicalResult ConvertDeallocOpToGpuRuntimeCallPattern::matchAndRewrite(
gpu::DeallocOp deallocOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
if (failed(areAllLLVMTypes(deallocOp, adaptor.getOperands(), rewriter)) ||
failed(isAsyncWithOneDependency(rewriter, deallocOp)))
return failure();
Location loc = deallocOp.getLoc();
Value pointer =
MemRefDescriptor(adaptor.getMemref()).allocatedPtr(rewriter, loc);
Value stream = adaptor.getAsyncDependencies().front();
deallocCallBuilder.create(loc, rewriter, {pointer, stream});
rewriter.replaceOp(deallocOp, {stream});
return success();
}
static bool isGpuAsyncTokenType(Value value) {
return isa<gpu::AsyncTokenType>(value.getType());
}
// Converts !gpu.async.token operands of `async.yield` to runtime calls. The
// !gpu.async.token values are lowered to streams within the async.execute
// region, but are passed between regions as events. For each !gpu.async.token
// operand, we create an event and record it on the stream.
LogicalResult ConvertAsyncYieldToGpuRuntimeCallPattern::matchAndRewrite(
async::YieldOp yieldOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
if (llvm::none_of(yieldOp.getOperands(), isGpuAsyncTokenType))
return rewriter.notifyMatchFailure(yieldOp, "no gpu async token operand");
Location loc = yieldOp.getLoc();
SmallVector<Value, 4> newOperands(adaptor.getOperands());
llvm::SmallDenseSet<Value> streams;
for (auto &operand : yieldOp->getOpOperands()) {
if (!isGpuAsyncTokenType(operand.get()))
continue;
auto idx = operand.getOperandNumber();
auto stream = adaptor.getOperands()[idx];
auto event = eventCreateCallBuilder.create(loc, rewriter, {}).getResult();
eventRecordCallBuilder.create(loc, rewriter, {event, stream});
newOperands[idx] = event;
streams.insert(stream);
}
for (auto stream : streams)
streamDestroyCallBuilder.create(loc, rewriter, {stream});
rewriter.modifyOpInPlace(yieldOp, [&] { yieldOp->setOperands(newOperands); });
return success();
}
// Returns whether `value` is the result of an LLVM::CallOp to `functionName`.
static bool isDefinedByCallTo(Value value, StringRef functionName) {
assert(isa<LLVM::LLVMPointerType>(value.getType()));
if (auto defOp = value.getDefiningOp<LLVM::CallOp>())
return *defOp.getCallee() == functionName;
return false;
}
// Converts `gpu.wait` to runtime calls. The converted op synchronizes the host
// with the stream/event operands. The operands are destroyed; that is, the
// pattern assumes they are not used afterwards or elsewhere, otherwise we will
// get a runtime error. Eventually, we should guarantee this property.
LogicalResult ConvertWaitOpToGpuRuntimeCallPattern::matchAndRewrite(
gpu::WaitOp waitOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
if (waitOp.getAsyncToken())
return rewriter.notifyMatchFailure(waitOp, "Cannot convert async op.");
Location loc = waitOp.getLoc();
for (auto operand : adaptor.getOperands()) {
if (isDefinedByCallTo(operand, streamCreateCallBuilder.functionName)) {
// The converted operand's definition created a stream.
streamSynchronizeCallBuilder.create(loc, rewriter, {operand});
streamDestroyCallBuilder.create(loc, rewriter, {operand});
} else {
// Otherwise the converted operand is an event. This assumes that we use
// events in control flow code as well.
eventSynchronizeCallBuilder.create(loc, rewriter, {operand});
eventDestroyCallBuilder.create(loc, rewriter, {operand});
}
}
rewriter.eraseOp(waitOp);
return success();
}
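// Sketch of the IR generated by the pattern above (illustrative): a
// synchronous `gpu.wait` whose operand was produced by mgpuStreamCreate
// becomes
//   llvm.call @mgpuStreamSynchronize(%stream) : (!llvm.ptr) -> ()
//   llvm.call @mgpuStreamDestroy(%stream) : (!llvm.ptr) -> ()
// while event operands map to mgpuEventSynchronize/mgpuEventDestroy instead.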
// Converts `gpu.wait async` to runtime calls. The converted op creates a new
// stream that is synchronized with the stream/event operands. The operands are
// destroyed; that is, the pattern assumes they are not used afterwards or
// elsewhere, otherwise we will get a runtime error. Eventually, we should
// guarantee this property.
LogicalResult ConvertWaitAsyncOpToGpuRuntimeCallPattern::matchAndRewrite(
gpu::WaitOp waitOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
if (!waitOp.getAsyncToken())
return rewriter.notifyMatchFailure(waitOp, "Can only convert async op.");
Location loc = waitOp.getLoc();
auto insertionPoint = rewriter.saveInsertionPoint();
SmallVector<Value, 1> events;
for (auto pair :
llvm::zip(waitOp.getAsyncDependencies(), adaptor.getOperands())) {
auto operand = std::get<1>(pair);
if (isDefinedByCallTo(operand, streamCreateCallBuilder.functionName)) {
// The converted operand's definition created a stream. Insert an event
// into the stream just after the last use of the original token operand.
auto *defOp = std::get<0>(pair).getDefiningOp();
rewriter.setInsertionPointAfter(defOp);
auto event = eventCreateCallBuilder.create(loc, rewriter, {}).getResult();
eventRecordCallBuilder.create(loc, rewriter, {event, operand});
events.push_back(event);
} else {
// Otherwise the converted operand is an event. This assumes that we use
// events in control flow code as well.
events.push_back(operand);
}
}
rewriter.restoreInsertionPoint(insertionPoint);
auto stream = streamCreateCallBuilder.create(loc, rewriter, {}).getResult();
for (auto event : events)
streamWaitEventCallBuilder.create(loc, rewriter, {stream, event});
for (auto event : events)
eventDestroyCallBuilder.create(loc, rewriter, {event});
rewriter.replaceOp(waitOp, {stream});
return success();
}
// Legalize the op's operands.
LogicalResult LegalizeLaunchFuncOpPattern::matchAndRewrite(
gpu::LaunchFuncOp launchOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
if (failed(areAllLLVMTypes(launchOp, adaptor.getOperands(), rewriter)))
return failure();
if (launchOp.getAsyncDependencies().size() > 1)
return rewriter.notifyMatchFailure(
launchOp, "Cannot convert with more than one async dependency.");
// Fail when the synchronous version of the op has async dependencies. The
// lowering destroys the stream, and we do not want to check that there is no
// use of the stream after this op.
if (!launchOp.getAsyncToken() && !launchOp.getAsyncDependencies().empty())
return rewriter.notifyMatchFailure(
launchOp, "Cannot convert non-async op with async dependencies.");
Location loc = launchOp.getLoc();
Value stream = Value();
if (!adaptor.getAsyncDependencies().empty())
stream = adaptor.getAsyncDependencies().front();
// If the async keyword is present and there are no dependencies, then a
// stream must be created to pass to subsequent operations.
else if (launchOp.getAsyncToken())
stream = streamCreateCallBuilder.create(loc, rewriter, {}).getResult();
// Lower the kernel operands to match kernel parameters.
// Note: If `useBarePtrCallConv` is set in the type converter's options,
// the value of `kernelBarePtrCallConv` will be ignored.
OperandRange origArguments = launchOp.getKernelOperands();
SmallVector<Value, 8> llvmArguments = getTypeConverter()->promoteOperands(
loc, origArguments, adaptor.getKernelOperands(), rewriter,
/*useBarePtrCallConv=*/kernelBarePtrCallConv);
SmallVector<Value, 8> llvmArgumentsWithSizes;
// Intersperse size information if requested.
if (kernelIntersperseSizeCallConv) {
if (origArguments.size() != llvmArguments.size()) {
// This shouldn't happen if the bare-pointer calling convention is used.
return rewriter.notifyMatchFailure(
launchOp,
"Cannot add sizes to arguments with one-to-many LLVM IR expansion.");
}
llvmArgumentsWithSizes.reserve(llvmArguments.size() * 2);
for (auto [llvmArg, origArg] : zip_equal(llvmArguments, origArguments)) {
auto memrefTy = dyn_cast<MemRefType>(origArg.getType());
if (!memrefTy) {
return rewriter.notifyMatchFailure(
launchOp, "Operand to launch op is not a memref.");
}
if (!memrefTy.hasStaticShape() ||
!memrefTy.getElementType().isIntOrFloat()) {
return rewriter.notifyMatchFailure(
launchOp, "Operand to launch op is not a memref with a static "
"shape and an integer or float element type.");
}
unsigned bitwidth = memrefTy.getElementTypeBitWidth();
if (bitwidth % 8 != 0) {
return rewriter.notifyMatchFailure(
launchOp, "Operand to launch op is not a memref with a "
"byte-aligned element type.");
}
uint64_t staticSize = static_cast<uint64_t>(bitwidth / 8) *
static_cast<uint64_t>(memrefTy.getNumElements());
Value sizeArg = rewriter.create<LLVM::ConstantOp>(
loc, getIndexType(), rewriter.getIndexAttr(staticSize));
llvmArgumentsWithSizes.push_back(llvmArg); // Presumably a bare pointer.
llvmArgumentsWithSizes.push_back(sizeArg);
}
}
std::optional<gpu::KernelDim3> clusterSize = std::nullopt;
if (launchOp.hasClusterSize()) {
clusterSize =
gpu::KernelDim3{adaptor.getClusterSizeX(), adaptor.getClusterSizeY(),
adaptor.getClusterSizeZ()};
}
rewriter.create<gpu::LaunchFuncOp>(
launchOp.getLoc(), launchOp.getKernelAttr(),
gpu::KernelDim3{adaptor.getGridSizeX(), adaptor.getGridSizeY(),
adaptor.getGridSizeZ()},
gpu::KernelDim3{adaptor.getBlockSizeX(), adaptor.getBlockSizeY(),
adaptor.getBlockSizeZ()},
adaptor.getDynamicSharedMemorySize(),
llvmArgumentsWithSizes.empty() ? llvmArguments : llvmArgumentsWithSizes,
stream, clusterSize);
if (launchOp.getAsyncToken())
rewriter.replaceOp(launchOp, {stream});
else
rewriter.eraseOp(launchOp);
return success();
}
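// Note that the pattern above does not itself emit runtime calls for the
// launch: it only rebuilds the gpu.launch_func with LLVM-compatible operands
// (optionally interleaved with byte sizes) and an optional stream, after which
// the op satisfies the dynamic legality rule installed in runOnOperation.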
static Value bitAndAddrspaceCast(Location loc,
ConversionPatternRewriter &rewriter,
LLVM::LLVMPointerType destinationType,
Value sourcePtr,
const LLVMTypeConverter &typeConverter) {
auto sourceTy = cast<LLVM::LLVMPointerType>(sourcePtr.getType());
if (destinationType.getAddressSpace() != sourceTy.getAddressSpace())
sourcePtr = rewriter.create<LLVM::AddrSpaceCastOp>(
loc,
LLVM::LLVMPointerType::get(rewriter.getContext(),
destinationType.getAddressSpace()),
sourcePtr);
return sourcePtr;
}
LogicalResult ConvertMemcpyOpToGpuRuntimeCallPattern::matchAndRewrite(
gpu::MemcpyOp memcpyOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
auto memRefType = cast<MemRefType>(memcpyOp.getSrc().getType());
if (failed(areAllLLVMTypes(memcpyOp, adaptor.getOperands(), rewriter)) ||
!isConvertibleAndHasIdentityMaps(memRefType) ||
failed(isAsyncWithOneDependency(rewriter, memcpyOp)))
return failure();
auto loc = memcpyOp.getLoc();
MemRefDescriptor srcDesc(adaptor.getSrc());
Value numElements = getNumElements(rewriter, loc, memRefType, srcDesc);
Type elementPtrType = getElementPtrType(memRefType);
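// Compute the copy size in bytes with the usual "sizeof via GEP" idiom: index
// a null pointer of the element type by `numElements` and convert the
// resulting pointer to an integer.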
Value nullPtr = rewriter.create<LLVM::ZeroOp>(loc, elementPtrType);
Value gepPtr = rewriter.create<LLVM::GEPOp>(
loc, elementPtrType,
typeConverter->convertType(memRefType.getElementType()), nullPtr,
numElements);
auto sizeBytes =
rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), gepPtr);
auto src = bitAndAddrspaceCast(loc, rewriter, llvmPointerType,
srcDesc.alignedPtr(rewriter, loc),
*getTypeConverter());
auto dst = bitAndAddrspaceCast(
loc, rewriter, llvmPointerType,
MemRefDescriptor(adaptor.getDst()).alignedPtr(rewriter, loc),
*getTypeConverter());
auto stream = adaptor.getAsyncDependencies().front();
memcpyCallBuilder.create(loc, rewriter, {dst, src, sizeBytes, stream});
rewriter.replaceOp(memcpyOp, {stream});
return success();
}
LogicalResult ConvertMemsetOpToGpuRuntimeCallPattern::matchAndRewrite(
gpu::MemsetOp memsetOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
auto memRefType = cast<MemRefType>(memsetOp.getDst().getType());
if (failed(areAllLLVMTypes(memsetOp, adaptor.getOperands(), rewriter)) ||
!isConvertibleAndHasIdentityMaps(memRefType) ||
failed(isAsyncWithOneDependency(rewriter, memsetOp)))
return failure();
auto loc = memsetOp.getLoc();
Type valueType = adaptor.getValue().getType();
unsigned bitWidth = valueType.getIntOrFloatBitWidth();
// Ints and floats of 16 or 32 bit width are allowed.
if (!valueType.isIntOrFloat() || (bitWidth != 16 && bitWidth != 32)) {
return rewriter.notifyMatchFailure(
memsetOp, "value must be a 16 or 32 bit int or float");
}
unsigned valueTypeWidth = valueType.getIntOrFloatBitWidth();
Type bitCastType = valueTypeWidth == 32 ? llvmInt32Type : llvmInt16Type;
MemRefDescriptor dstDesc(adaptor.getDst());
Value numElements = getNumElements(rewriter, loc, memRefType, dstDesc);
auto value =
rewriter.create<LLVM::BitcastOp>(loc, bitCastType, adaptor.getValue());
auto dst = bitAndAddrspaceCast(loc, rewriter, llvmPointerType,
dstDesc.alignedPtr(rewriter, loc),
*getTypeConverter());
auto stream = adaptor.getAsyncDependencies().front();
FunctionCallBuilder builder =
valueTypeWidth == 32 ? memset32CallBuilder : memset16CallBuilder;
builder.create(loc, rewriter, {dst, value, numElements, stream});
rewriter.replaceOp(memsetOp, {stream});
return success();
}
LogicalResult ConvertSetDefaultDeviceOpToGpuRuntimeCallPattern::matchAndRewrite(
gpu::SetDefaultDeviceOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
Location loc = op.getLoc();
auto call = setDefaultDeviceCallBuilder.create(loc, rewriter,
{adaptor.getDevIndex()});
rewriter.replaceOp(op, call);
return success();
}
template <typename T>
static Value genConstInt32From(OpBuilder &builder, Location loc, T tValue) {
Type llvmInt32Type = builder.getIntegerType(32);
return builder.create<LLVM::ConstantOp>(loc, llvmInt32Type,
static_cast<int32_t>(tValue));
}
template <typename T>
static Value genConstFloat32From(OpBuilder &builder, Location loc, T tValue) {
Type llvmFloat32Type = builder.getF32Type();
return builder.create<LLVM::ConstantOp>(
loc, llvmFloat32Type,
builder.getF32FloatAttr(static_cast<float>(tValue)));
}
LogicalResult ConvertCreateDnTensorOpToGpuRuntimeCallPattern::matchAndRewrite(
gpu::CreateDnTensorOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
failed(isAsyncWithOneDependency(rewriter, op)))
return failure();
Location loc = op.getLoc();
auto stream = adaptor.getAsyncDependencies().front();
Value pTensor =
MemRefDescriptor(adaptor.getMemref()).allocatedPtr(rewriter, loc);
Type dType = op.getMemref().getType().getElementType();
auto dtp = genConstInt32From(rewriter, loc, getCuSparseDataTypeFrom(dType));
SmallVector<Value, 4> dims;
for (Value dim : adaptor.getDims()) {
dims.push_back(dim);
}
Value handle;
// TODO: For now, we track the uses of the handle and lower it to cusparse or
// cusparseLt accordingly. If both cusparse and cusparseLt are used within the
// same block, two separate creation ops are required for the logic to be
// correct. In the future, we may support using one handle with both cusparse
// and cusparseLt in the sparse tensor / GPU dialects. Use the cusparseLt
// create call if the dnmat is used with a spmat that has 2:4 sparsity.
if (dims.size() == 2) {
if (isSpMMCusparseLtOp(op.getDnTensor())) {
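// The cusparseLt dnmat handle is allocated on the stack; the 11032 bytes here
// must match the dense-matrix handle size expected by the mgpu* cusparseLt
// wrappers (cf. the 44104-byte spmat handle allocated further below).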
auto handleSz = rewriter.create<LLVM::ConstantOp>(
loc, getIndexType(), rewriter.getIndexAttr(11032));
handle = rewriter.create<LLVM::AllocaOp>(
loc, llvmPointerType, llvmInt8Type, handleSz, /*alignment=*/16);
handle = rewriter.create<LLVM::BitcastOp>(loc, llvmPointerType, handle);
createLtDnMatCallBuilder
.create(loc, rewriter,
{handle, dims[0], dims[1], pTensor, dtp, stream})
.getResult();
} else {
handle =
createDnMatCallBuilder
.create(loc, rewriter, {dims[0], dims[1], pTensor, dtp, stream})
.getResult();
}
} else {
assert(dims.size() == 1 && "Only 1D and 2D tensors are supported");
handle = createDnVecCallBuilder
.create(loc, rewriter, {dims[0], pTensor, dtp, stream})
.getResult();
}
rewriter.replaceOp(op, {handle, stream});
return success();
}
LogicalResult ConvertDestroyDnTensorOpToGpuRuntimeCallPattern::matchAndRewrite(
gpu::DestroyDnTensorOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
failed(isAsyncWithOneDependency(rewriter, op)))
return failure();
Location loc = op.getLoc();
auto stream = adaptor.getAsyncDependencies().front();
auto definingOp = op.getDnTensor().getDefiningOp<gpu::CreateDnTensorOp>();
SmallVector<Value, 4> dims;
for (Value dim : definingOp.getDims()) {
dims.push_back(dim);
}
if (dims.size() == 2) {
// Use the cusparseLt destroy call if the dnmat is used with a spmat that
// has 2:4 sparsity.
if (isSpMMCusparseLtOp(op.getDnTensor())) {
destroyCuSparseLtDnMatBuilder.create(loc, rewriter,
{adaptor.getDnTensor(), stream});
} else {
destroyDnMatCallBuilder.create(loc, rewriter,
{adaptor.getDnTensor(), stream});
}
} else {
assert(dims.size() == 1 && "Only 1D and 2D tensors are supported");
destroyDnVecCallBuilder.create(loc, rewriter,
{adaptor.getDnTensor(), stream});
}
rewriter.replaceOp(op, {stream});
return success();
}
LogicalResult ConvertCreateCooOpToGpuRuntimeCallPattern::matchAndRewrite(
gpu::CreateCooOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
failed(isAsyncWithOneDependency(rewriter, op)))
return failure();
Location loc = op.getLoc();
auto stream = adaptor.getAsyncDependencies().front();
Value pRowIdxs =
MemRefDescriptor(adaptor.getRowIdxs()).allocatedPtr(rewriter, loc);
Value pColIdxs =
MemRefDescriptor(adaptor.getColIdxs()).allocatedPtr(rewriter, loc);
Value pValues =
MemRefDescriptor(adaptor.getValues()).allocatedPtr(rewriter, loc);
Type iType =
llvm::cast<MemRefType>(op.getColIdxs().getType()).getElementType();
Type dType =
llvm::cast<MemRefType>(op.getValues().getType()).getElementType();
auto itp = genConstInt32From(rewriter, loc, getCuSparseIndexTypeFrom(iType));
auto dtp = genConstInt32From(rewriter, loc, getCuSparseDataTypeFrom(dType));
auto handle =
createCooCallBuilder
.create(loc, rewriter,
{adaptor.getRows(), adaptor.getCols(), adaptor.getNnz(),
pRowIdxs, pColIdxs, pValues, itp, dtp, stream})
.getResult();
rewriter.replaceOp(op, {handle, stream});
return success();
}
LogicalResult ConvertCreateCooAoSOpToGpuRuntimeCallPattern::matchAndRewrite(
gpu::CreateCooAoSOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
failed(isAsyncWithOneDependency(rewriter, op)))
return failure();
Location loc = op.getLoc();
auto stream = adaptor.getAsyncDependencies().front();
Value pIdxs = MemRefDescriptor(adaptor.getIdxs()).allocatedPtr(rewriter, loc);
Value pValues =
MemRefDescriptor(adaptor.getValues()).allocatedPtr(rewriter, loc);
Type iType = llvm::cast<MemRefType>(op.getIdxs().getType()).getElementType();
Type dType =
llvm::cast<MemRefType>(op.getValues().getType()).getElementType();
auto itp = genConstInt32From(rewriter, loc, getCuSparseIndexTypeFrom(iType));
auto dtp = genConstInt32From(rewriter, loc, getCuSparseDataTypeFrom(dType));
auto handle =
createCooAoSCallBuilder
.create(loc, rewriter,
{adaptor.getRows(), adaptor.getCols(), adaptor.getNnz(),
pIdxs, pValues, itp, dtp, stream})
.getResult();
rewriter.replaceOp(op, {handle, stream});
return success();
}
LogicalResult ConvertCreateCsrOpToGpuRuntimeCallPattern::matchAndRewrite(
gpu::CreateCsrOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
failed(isAsyncWithOneDependency(rewriter, op)))
return failure();
Location loc = op.getLoc();
auto stream = adaptor.getAsyncDependencies().front();
Value pRowPos =
MemRefDescriptor(adaptor.getRowPos()).allocatedPtr(rewriter, loc);
Value pColIdxs =
MemRefDescriptor(adaptor.getColIdxs()).allocatedPtr(rewriter, loc);
Value pValues =
MemRefDescriptor(adaptor.getValues()).allocatedPtr(rewriter, loc);
Type pType =
llvm::cast<MemRefType>(op.getRowPos().getType()).getElementType();
Type iType =
llvm::cast<MemRefType>(op.getColIdxs().getType()).getElementType();
Type dType =
llvm::cast<MemRefType>(op.getValues().getType()).getElementType();
auto ptp = genConstInt32From(rewriter, loc, getCuSparseIndexTypeFrom(pType));
auto itp = genConstInt32From(rewriter, loc, getCuSparseIndexTypeFrom(iType));
auto dtp = genConstInt32From(rewriter, loc, getCuSparseDataTypeFrom(dType));
auto handle =
createCsrCallBuilder
.create(loc, rewriter,
{adaptor.getRows(), adaptor.getCols(), adaptor.getNnz(),
pRowPos, pColIdxs, pValues, ptp, itp, dtp, stream})
.getResult();
rewriter.replaceOp(op, {handle, stream});
return success();
}
LogicalResult ConvertCreate2To4SpMatOpToGpuRuntimeCallPattern::matchAndRewrite(
gpu::Create2To4SpMatOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
failed(isAsyncWithOneDependency(rewriter, op)))
return failure();
Location loc = op.getLoc();
auto stream = adaptor.getAsyncDependencies().front();
Value pMat =
MemRefDescriptor(adaptor.getMemref()).allocatedPtr(rewriter, loc);
Type dType =
llvm::cast<MemRefType>(op.getMemref().getType()).getElementType();
auto dtp = genConstInt32From(rewriter, loc, getCuSparseDataTypeFrom(dType));
// CUDA runner asserts the size is 44104 bytes.
auto handleSz = rewriter.create<LLVM::ConstantOp>(
loc, getIndexType(), rewriter.getIndexAttr(44104));
Value handle = rewriter.create<LLVM::AllocaOp>(
loc, llvmPointerType, llvmInt8Type, handleSz, /*alignment=*/16);
handle = rewriter.create<LLVM::BitcastOp>(loc, llvmPointerType, handle);
create2To4SpMatCallBuilder
.create(loc, rewriter,
{handle, adaptor.getRows(), adaptor.getCols(), pMat, dtp, stream})
.getResult();
rewriter.replaceOp(op, {handle, stream});
return success();
}
LogicalResult ConvertDestroySpMatOpToGpuRuntimeCallPattern::matchAndRewrite(
gpu::DestroySpMatOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
failed(isAsyncWithOneDependency(rewriter, op)))
return failure();
Location loc = op.getLoc();
auto stream = adaptor.getAsyncDependencies().front();
// Use the cusparseLt destroy call if the spmat has 2:4 sparsity.
if (is2To4Sparsity(op.getSpmat())) {
destroyCuSparseLtSpMatBuilder.create(loc, rewriter,
{adaptor.getSpmat(), stream});
} else {
destroySpMatCallBuilder.create(loc, rewriter, {adaptor.getSpmat(), stream});
}
rewriter.replaceOp(op, {stream});
return success();
}
LogicalResult ConvertSpMVBufferSizeOpToGpuRuntimeCallPattern::matchAndRewrite(
gpu::SpMVBufferSizeOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
failed(isAsyncWithOneDependency(rewriter, op)))
return failure();
Location loc = op.getLoc();
auto modeA = genConstInt32From(rewriter, loc, op.getModeA());
auto computeType = genConstInt32From(
rewriter, loc, getCuSparseDataTypeFrom(adaptor.getComputeType()));
auto stream = adaptor.getAsyncDependencies().front();
auto bufferSize = spMVBufferSizeCallBuilder
.create(loc, rewriter,
{modeA, adaptor.getSpmatA(), adaptor.getDnX(),
adaptor.getDnY(), computeType, stream})
.getResult();
rewriter.replaceOp(op, {bufferSize, stream});
return success();
}
LogicalResult ConvertSpMVOpToGpuRuntimeCallPattern::matchAndRewrite(
gpu::SpMVOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
failed(isAsyncWithOneDependency(rewriter, op)))
return failure();
Location loc = op.getLoc();
auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA());
auto computeType = genConstInt32From(
rewriter, loc, getCuSparseDataTypeFrom(adaptor.getComputeType()));
auto stream = adaptor.getAsyncDependencies().front();
Value pBuf =
MemRefDescriptor(adaptor.getBuffer()).allocatedPtr(rewriter, loc);
spMVCallBuilder.create(loc, rewriter,
{modeA, adaptor.getSpmatA(), adaptor.getDnX(),
adaptor.getDnY(), computeType, pBuf, stream});
rewriter.replaceOp(op, {stream});
return success();
}
LogicalResult ConvertSpMMBufferSizeOpToGpuRuntimeCallPattern::matchAndRewrite(
gpu::SpMMBufferSizeOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
failed(isAsyncWithOneDependency(rewriter, op)))
return failure();
Location loc = op.getLoc();
auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA());
auto modeB = genConstInt32From(rewriter, loc, adaptor.getModeB());
auto stream = adaptor.getAsyncDependencies().front();
Value bufferSize;
if (is2To4Sparsity(op.getSpmatA())) {
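// The cusparseLt buffer-size query reports three sizes: reserve three
// pointer-sized slots on the stack, let the wrapper fill them in, and load the
// three values back as i64 results.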
auto pruneFlag =
genConstInt32From(rewriter, loc, get2To4PruneFlag(op.getSpmatA()));
auto computeType = genConstInt32From(
rewriter, loc, getCuSparseLtDataTypeFrom(adaptor.getComputeType()));
auto three = rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
rewriter.getIndexAttr(3));
auto bufferSize = rewriter.create<LLVM::AllocaOp>(
loc, llvmPointerType, llvmPointerType, three, /*alignment=*/16);
createCuSparseLtSpMMBufferSizeBuilder
.create(loc, rewriter,
{bufferSize, modeA, modeB, adaptor.getSpmatA(),
adaptor.getDnmatB(), adaptor.getDnmatC(), computeType,
pruneFlag, stream})
.getResult();
auto bufferSizePtr1 = rewriter.create<LLVM::GEPOp>(
loc, llvmPointerType, llvmPointerType, bufferSize,
ValueRange{rewriter.create<LLVM::ConstantOp>(
loc, getIndexType(), rewriter.getIndexAttr(1))});
auto bufferSizePtr2 = rewriter.create<LLVM::GEPOp>(
loc, llvmPointerType, llvmPointerType, bufferSize,
ValueRange{rewriter.create<LLVM::ConstantOp>(
loc, getIndexType(), rewriter.getIndexAttr(2))});
auto bufferSize0 =
rewriter.create<LLVM::LoadOp>(loc, llvmInt64Type, bufferSize);
auto bufferSize1 =
rewriter.create<LLVM::LoadOp>(loc, llvmInt64Type, bufferSizePtr1);
auto bufferSize2 =
rewriter.create<LLVM::LoadOp>(loc, llvmInt64Type, bufferSizePtr2);
rewriter.replaceOp(op, {bufferSize0, bufferSize1, bufferSize2, stream});
} else {
auto computeType = genConstInt32From(
rewriter, loc, getCuSparseDataTypeFrom(adaptor.getComputeType()));
bufferSize =
createSpMMBufferSizeCallBuilder
.create(loc, rewriter,
{modeA, modeB, adaptor.getSpmatA(), adaptor.getDnmatB(),
adaptor.getDnmatC(), computeType, stream})
.getResult();
rewriter.replaceOp(op, {bufferSize, stream});
}
return success();
}
LogicalResult ConvertSDDMMBufferSizeOpToGpuRuntimeCallPattern::matchAndRewrite(
gpu::SDDMMBufferSizeOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
failed(isAsyncWithOneDependency(rewriter, op)))
return failure();
Location loc = op.getLoc();
auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA());
auto modeB = genConstInt32From(rewriter, loc, adaptor.getModeB());
auto computeType = genConstInt32From(
rewriter, loc, getCuSparseDataTypeFrom(adaptor.getComputeType()));
auto stream = adaptor.getAsyncDependencies().front();
auto bufferSize =
createSDDMMBufferSizeCallBuilder
.create(loc, rewriter,
{modeA, modeB, adaptor.getDnmatA(), adaptor.getDnmatB(),
adaptor.getSpmatC(), computeType, stream})
.getResult();
rewriter.replaceOp(op, {bufferSize, stream});
return success();
}
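
// Lowers gpu::SpMMOp to a runtime call, dispatching to the cuSPARSELt entry
// point (which consumes three workspace buffers) when the A operand uses 2:4
// structured sparsity, and to the generic entry point otherwise.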
LogicalResult ConvertSpMMOpToGpuRuntimeCallPattern::matchAndRewrite(
gpu::SpMMOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
failed(isAsyncWithOneDependency(rewriter, op)))
return failure();
Location loc = op.getLoc();
auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA());
auto modeB = genConstInt32From(rewriter, loc, adaptor.getModeB());
auto computeType = genConstInt32From(
rewriter, loc, getCuSparseDataTypeFrom(adaptor.getComputeType()));
auto stream = adaptor.getAsyncDependencies().front();
  // Lower to cuSPARSELt when the A operand uses 2:4 structured sparsity.
if (is2To4Sparsity(op.getSpmatA())) {
SmallVector<Value> pBufs;
for (Value buffer : adaptor.getBuffers()) {
Value pBuf = MemRefDescriptor(buffer).allocatedPtr(rewriter, loc);
pBufs.push_back(pBuf);
}
createCuSparseLtSpMMBuilder.create(
loc, rewriter,
{adaptor.getSpmatA(), adaptor.getDnmatB(), adaptor.getDnmatC(),
pBufs[0], pBufs[1], pBufs[2], stream});
} else {
Value pBuf = MemRefDescriptor(adaptor.getBuffers().front())
.allocatedPtr(rewriter, loc);
createSpMMCallBuilder.create(loc, rewriter,
{modeA, modeB, adaptor.getSpmatA(),
adaptor.getDnmatB(), adaptor.getDnmatC(),
computeType, pBuf, stream});
}
rewriter.replaceOp(op, {stream});
return success();
}
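
// Adds a type conversion that maps the given GPU dialect handle type to an
// opaque LLVM pointer.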
template <typename T>
static void addOpaquePointerConversion(LLVMTypeConverter &converter) {
converter.addConversion([&converter](T) -> Type {
return LLVM::LLVMPointerType::get(&converter.getContext());
});
}
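
// Lowers gpu::SDDMMOp to a single runtime call, passing the workspace buffer's
// underlying pointer extracted from its memref descriptor.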
LogicalResult ConvertSDDMMOpToGpuRuntimeCallPattern::matchAndRewrite(
gpu::SDDMMOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
failed(isAsyncWithOneDependency(rewriter, op)))
return failure();
Location loc = op.getLoc();
auto computeType = genConstInt32From(
rewriter, loc, getCuSparseDataTypeFrom(adaptor.getComputeType()));
auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA());
auto modeB = genConstInt32From(rewriter, loc, adaptor.getModeB());
auto stream = adaptor.getAsyncDependencies().front();
Value pBuf =
MemRefDescriptor(adaptor.getBuffer()).allocatedPtr(rewriter, loc);
createSDDMMCallBuilder.create(loc, rewriter,
{modeA, modeB, adaptor.getDnmatA(),
adaptor.getDnmatB(), adaptor.getSpmatC(),
computeType, pBuf, stream});
rewriter.replaceOp(op, {stream});
return success();
}
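
// Lowers gpu::SpGEMMCreateDescrOp to a runtime call that creates a SpGEMM
// descriptor and returns its handle.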
LogicalResult
ConvertSpGEMMCreateDescrOpToGpuRuntimeCallPattern::matchAndRewrite(
gpu::SpGEMMCreateDescrOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
failed(isAsyncWithOneDependency(rewriter, op)))
return failure();
Location loc = op.getLoc();
auto stream = adaptor.getAsyncDependencies().front();
Value descr = createSpGEMMCreateDescrBuilder.create(loc, rewriter, {stream})
.getResult();
rewriter.replaceOp(op, {descr, stream});
return success();
}
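
// Lowers gpu::SpGEMMDestroyDescrOp to the runtime call that releases a SpGEMM
// descriptor.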
LogicalResult
ConvertSpGEMMDestroyDescrOpToGpuRuntimeCallPattern::matchAndRewrite(
gpu::SpGEMMDestroyDescrOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
failed(isAsyncWithOneDependency(rewriter, op)))
return failure();
Location loc = op.getLoc();
auto stream = adaptor.getAsyncDependencies().front();
createSpGEMMDestroyDescrBuilder.create(loc, rewriter,
{adaptor.getDesc(), stream});
rewriter.replaceOp(op, {stream});
return success();
}
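
// Lowers gpu::SpGEMMWorkEstimationOrComputeOp to either the work-estimation or
// the compute runtime call, selected by the op's kind; both variants return an
// updated buffer size.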
LogicalResult
ConvertSpGEMMWorkEstimationOrComputeOpToGpuRuntimeCallPattern::matchAndRewrite(
gpu::SpGEMMWorkEstimationOrComputeOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
failed(isAsyncWithOneDependency(rewriter, op)))
return failure();
Location loc = op.getLoc();
auto computeType = genConstInt32From(
rewriter, loc, getCuSparseDataTypeFrom(adaptor.getComputeType()));
auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA());
auto modeB = genConstInt32From(rewriter, loc, adaptor.getModeB());
auto stream = adaptor.getAsyncDependencies().front();
Value pBuf =
MemRefDescriptor(adaptor.getBuffer()).allocatedPtr(rewriter, loc);
Value bufferSizeNew;
if (adaptor.getKind() ==
gpu::SpGEMMWorkEstimationOrComputeKind::WORK_ESTIMATION) {
bufferSizeNew =
createSpGEMMWorkEstimationBuilder
.create(loc, rewriter,
{adaptor.getDesc(), modeA, modeB, adaptor.getSpmatA(),
adaptor.getSpmatB(), adaptor.getSpmatC(), computeType,
adaptor.getBufferSz(), pBuf, stream})
.getResult();
} else {
bufferSizeNew =
createSpGEMMComputeBuilder
.create(loc, rewriter,
{adaptor.getDesc(), modeA, modeB, adaptor.getSpmatA(),
adaptor.getSpmatB(), adaptor.getSpmatC(), computeType,
adaptor.getBufferSz(), pBuf, stream})
.getResult();
}
rewriter.replaceOp(op, {bufferSizeNew, stream});
return success();
}
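
// Lowers gpu::SpGEMMCopyOp to the runtime call that copies the SpGEMM result
// into the output sparse matrix.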
LogicalResult ConvertSpGEMMCopyOpToGpuRuntimeCallPattern::matchAndRewrite(
gpu::SpGEMMCopyOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
failed(isAsyncWithOneDependency(rewriter, op)))
return failure();
Location loc = op.getLoc();
auto computeType = genConstInt32From(
rewriter, loc, getCuSparseDataTypeFrom(adaptor.getComputeType()));
auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA());
auto modeB = genConstInt32From(rewriter, loc, adaptor.getModeB());
auto stream = adaptor.getAsyncDependencies().front();
createSpGEMMCopyBuilder.create(loc, rewriter,
{adaptor.getDesc(), modeA, modeB,
adaptor.getSpmatA(), adaptor.getSpmatB(),
adaptor.getSpmatC(), computeType, stream});
rewriter.replaceOp(op, {stream});
return success();
}
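
// Lowers gpu::SpMatGetSizeOp by allocating three i64 stack slots, letting the
// runtime fill in rows, cols, and nnz, and loading the values back.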
LogicalResult ConvertSpMatGetSizeOpToGpuRuntimeCallPattern::matchAndRewrite(
gpu::SpMatGetSizeOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
failed(isAsyncWithOneDependency(rewriter, op)))
return failure();
Location loc = op.getLoc();
auto stream = adaptor.getAsyncDependencies().front();
auto three = rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
rewriter.getIndexAttr(3));
auto buffer = rewriter.create<LLVM::AllocaOp>(
loc, llvmPointerType, llvmInt64Type, three, /*alignment=*/16);
auto rowsPtr = rewriter.create<LLVM::GEPOp>(
loc, llvmPointerType, llvmPointerType, buffer,
ValueRange{rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
rewriter.getIndexAttr(0))});
auto colsPtr = rewriter.create<LLVM::GEPOp>(
loc, llvmPointerType, llvmPointerType, buffer,
ValueRange{rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
rewriter.getIndexAttr(1))});
auto nnzsPtr = rewriter.create<LLVM::GEPOp>(
loc, llvmPointerType, llvmPointerType, buffer,
ValueRange{rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
rewriter.getIndexAttr(2))});
createSpMatGetSizeBuilder.create(
loc, rewriter, {adaptor.getSpmat(), rowsPtr, colsPtr, nnzsPtr, stream});
auto rows = rewriter.create<LLVM::LoadOp>(loc, llvmInt64Type, rowsPtr);
auto cols = rewriter.create<LLVM::LoadOp>(loc, llvmInt64Type, colsPtr);
auto nnzs = rewriter.create<LLVM::LoadOp>(loc, llvmInt64Type, nnzsPtr);
rewriter.replaceOp(op, {rows, cols, nnzs, stream});
return success();
}
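
// Lowers gpu::SetCsrPointersOp to a runtime call that sets the positions,
// coordinates, and values buffers of an existing sparse matrix handle.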
LogicalResult ConvertSetCsrPointersOpToGpuRuntimeCallPattern::matchAndRewrite(
gpu::SetCsrPointersOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
failed(isAsyncWithOneDependency(rewriter, op)))
return failure();
Location loc = op.getLoc();
auto stream = adaptor.getAsyncDependencies().front();
Value pPos =
MemRefDescriptor(adaptor.getPositions()).allocatedPtr(rewriter, loc);
Value pCrd =
MemRefDescriptor(adaptor.getCoordinates()).allocatedPtr(rewriter, loc);
Value pVal =
MemRefDescriptor(adaptor.getValues()).allocatedPtr(rewriter, loc);
createSetCsrPointersBuilder.create(
loc, rewriter, {adaptor.getSpmat(), pPos, pCrd, pVal, stream});
rewriter.replaceOp(op, {stream});
return success();
}
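
// Lowers gpu::CreateCscOp to a runtime call that creates a CSC matrix handle,
// translating the memref element types into cuSPARSE index/data type enums.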
LogicalResult ConvertCreateCscOpToGpuRuntimeCallPattern::matchAndRewrite(
gpu::CreateCscOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
failed(isAsyncWithOneDependency(rewriter, op)))
return failure();
Location loc = op.getLoc();
auto stream = adaptor.getAsyncDependencies().front();
Value pColPos =
MemRefDescriptor(adaptor.getColPos()).allocatedPtr(rewriter, loc);
Value pRowIdxs =
MemRefDescriptor(adaptor.getRowIdxs()).allocatedPtr(rewriter, loc);
Value pValues =
MemRefDescriptor(adaptor.getValues()).allocatedPtr(rewriter, loc);
Type pType =
llvm::cast<MemRefType>(op.getColPos().getType()).getElementType();
Type iType =
llvm::cast<MemRefType>(op.getRowIdxs().getType()).getElementType();
Type dType =
llvm::cast<MemRefType>(op.getValues().getType()).getElementType();
auto ptp = genConstInt32From(rewriter, loc, getCuSparseIndexTypeFrom(pType));
auto itp = genConstInt32From(rewriter, loc, getCuSparseIndexTypeFrom(iType));
auto dtp = genConstInt32From(rewriter, loc, getCuSparseDataTypeFrom(dType));
auto handle =
createCscCallBuilder
.create(loc, rewriter,
{adaptor.getRows(), adaptor.getCols(), adaptor.getNnz(),
pColPos, pRowIdxs, pValues, ptp, itp, dtp, stream})
.getResult();
rewriter.replaceOp(op, {handle, stream});
return success();
}
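
// Lowers gpu::CreateBsrOp to a runtime call that creates a BSR matrix handle
// from the block dimensions and the position, index, and value buffers.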
LogicalResult ConvertCreateBsrOpToGpuRuntimeCallPattern::matchAndRewrite(
gpu::CreateBsrOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
failed(isAsyncWithOneDependency(rewriter, op)))
return failure();
Location loc = op.getLoc();
auto stream = adaptor.getAsyncDependencies().front();
Value pRowPos =
MemRefDescriptor(adaptor.getBRowPos()).allocatedPtr(rewriter, loc);
Value pColIdxs =
MemRefDescriptor(adaptor.getBColIdxs()).allocatedPtr(rewriter, loc);
Value pValues =
MemRefDescriptor(adaptor.getValues()).allocatedPtr(rewriter, loc);
Type pType =
llvm::cast<MemRefType>(op.getBRowPos().getType()).getElementType();
Type iType =
llvm::cast<MemRefType>(op.getBColIdxs().getType()).getElementType();
Type dType =
llvm::cast<MemRefType>(op.getValues().getType()).getElementType();
auto ptp = genConstInt32From(rewriter, loc, getCuSparseIndexTypeFrom(pType));
auto itp = genConstInt32From(rewriter, loc, getCuSparseIndexTypeFrom(iType));
auto dtp = genConstInt32From(rewriter, loc, getCuSparseDataTypeFrom(dType));
auto handle =
createBsrCallBuilder
.create(loc, rewriter,
{adaptor.getBrows(), adaptor.getBcols(), adaptor.getBnnz(),
adaptor.getRBlockSize(), adaptor.getCBlockSize(), pRowPos,
pColIdxs, pValues, ptp, itp, dtp, stream})
.getResult();
rewriter.replaceOp(op, {handle, stream});
return success();
}
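
// Registers the opaque-pointer conversions for the GPU handle types together
// with all GPU-to-runtime-call patterns and the launch-func legalization
// pattern.
//
// Illustrative usage (a sketch, not invoked from this file): a conversion pass
// would typically combine these patterns with its own converter and pattern
// set, e.g.:
//   LLVMTypeConverter converter(&getContext());
//   RewritePatternSet patterns(&getContext());
//   populateGpuToLLVMConversionPatterns(converter, patterns,
//                                       /*kernelBarePtrCallConv=*/false,
//                                       /*kernelIntersperseSizeCallConv=*/false);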
void mlir::populateGpuToLLVMConversionPatterns(
LLVMTypeConverter &converter, RewritePatternSet &patterns,
bool kernelBarePtrCallConv, bool kernelIntersperseSizeCallConv) {
addOpaquePointerConversion<gpu::AsyncTokenType>(converter);
addOpaquePointerConversion<gpu::SparseDnTensorHandleType>(converter);
addOpaquePointerConversion<gpu::SparseSpMatHandleType>(converter);
addOpaquePointerConversion<gpu::SparseSpGEMMOpHandleType>(converter);
patterns.add<ConvertAllocOpToGpuRuntimeCallPattern,
ConvertDeallocOpToGpuRuntimeCallPattern,
ConvertHostRegisterOpToGpuRuntimeCallPattern,
ConvertHostUnregisterOpToGpuRuntimeCallPattern,
ConvertMemcpyOpToGpuRuntimeCallPattern,
ConvertMemsetOpToGpuRuntimeCallPattern,
ConvertSetDefaultDeviceOpToGpuRuntimeCallPattern,
ConvertWaitAsyncOpToGpuRuntimeCallPattern,
ConvertWaitOpToGpuRuntimeCallPattern,
ConvertAsyncYieldToGpuRuntimeCallPattern,
ConvertCreateDnTensorOpToGpuRuntimeCallPattern,
ConvertDestroyDnTensorOpToGpuRuntimeCallPattern,
ConvertCreateCooOpToGpuRuntimeCallPattern,
ConvertCreateCooAoSOpToGpuRuntimeCallPattern,
ConvertCreateCsrOpToGpuRuntimeCallPattern,
ConvertCreateCscOpToGpuRuntimeCallPattern,
ConvertCreateBsrOpToGpuRuntimeCallPattern,
ConvertCreate2To4SpMatOpToGpuRuntimeCallPattern,
ConvertDestroySpMatOpToGpuRuntimeCallPattern,
ConvertSpMVBufferSizeOpToGpuRuntimeCallPattern,
ConvertSpMVOpToGpuRuntimeCallPattern,
ConvertSpMMBufferSizeOpToGpuRuntimeCallPattern,
ConvertSDDMMBufferSizeOpToGpuRuntimeCallPattern,
ConvertSpMMOpToGpuRuntimeCallPattern,
ConvertSDDMMOpToGpuRuntimeCallPattern,
ConvertSpGEMMCreateDescrOpToGpuRuntimeCallPattern,
ConvertSpGEMMDestroyDescrOpToGpuRuntimeCallPattern,
ConvertSpGEMMWorkEstimationOrComputeOpToGpuRuntimeCallPattern,
ConvertSpGEMMCopyOpToGpuRuntimeCallPattern,
ConvertSpMatGetSizeOpToGpuRuntimeCallPattern,
ConvertSetCsrPointersOpToGpuRuntimeCallPattern>(converter);
patterns.add<LegalizeLaunchFuncOpPattern>(converter, kernelBarePtrCallConv,
kernelIntersperseSizeCallConv);
}
//===----------------------------------------------------------------------===//
// ConvertToLLVMOpInterface implementation for GPUModuleOp
//===----------------------------------------------------------------------===//
namespace {
struct GPUModuleOpConvertToLLVMInterface
: public ConvertToLLVMOpInterface::ExternalModel<
GPUModuleOpConvertToLLVMInterface, gpu::GPUModuleOp> {
  /// Collect the ConvertToLLVM attribute interfaces from the module's targets
  /// attribute.
void getConvertToLLVMConversionAttrs(
Operation *op, SmallVectorImpl<ConvertToLLVMAttrInterface> &attrs) const;
};
} // namespace
void GPUModuleOpConvertToLLVMInterface::getConvertToLLVMConversionAttrs(
Operation *op, SmallVectorImpl<ConvertToLLVMAttrInterface> &attrs) const {
auto module = cast<gpu::GPUModuleOp>(op);
ArrayAttr targetsAttr = module.getTargetsAttr();
  // Bail out if there is no targets attribute or if there is more than one
  // target.
if (!targetsAttr || targetsAttr.size() != 1)
return;
if (auto patternAttr = dyn_cast<ConvertToLLVMAttrInterface>(targetsAttr[0]))
attrs.push_back(patternAttr);
}
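
// Attaches the GPUModuleOp ConvertToLLVM interface model to the GPU dialect.
//
// Illustrative registration (a sketch, not invoked from this file): clients
// typically attach the extension to a DialectRegistry before creating the
// context, e.g.:
//   DialectRegistry registry;
//   mlir::gpu::registerConvertGpuToLLVMInterface(registry);
//   MLIRContext context(registry);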
void mlir::gpu::registerConvertGpuToLLVMInterface(DialectRegistry &registry) {
registry.addExtension(+[](MLIRContext *ctx, gpu::GPUDialect *dialect) {
gpu::GPUModuleOp::attachInterface<GPUModuleOpConvertToLLVMInterface>(*ctx);
});
}