[mlir][NFC] Mark type converter in populate...
functions as const
(#111250)
This commit marks the type converter in `populate...` functions as `const`. This is useful for debugging. Patterns already take a `const` type converter. However, some `populate...` functions do not only add new patterns, but also add additional type conversion rules. That makes it difficult to find the place where a type conversion was added in the code base. With this change, all `populate...` functions that only populate pattern now have a `const` type converter. Programmers can then conclude from the function signature that these functions do not register any new type conversion rules. Also some minor cleanups around the 1:N dialect conversion infrastructure, which did not always pass the type converter as a `const` object internally.
This commit is contained in:
parent
73683cc1ab
commit
206fad0e21
@ -72,9 +72,9 @@ std::unique_ptr<mlir::Pass> createLLVMDialectToLLVMPass(
|
||||
[](llvm::Module &m, llvm::raw_ostream &out) { m.print(out, nullptr); });
|
||||
|
||||
/// Populate the given list with patterns that convert from FIR to LLVM.
|
||||
void populateFIRToLLVMConversionPatterns(fir::LLVMTypeConverter &converter,
|
||||
mlir::RewritePatternSet &patterns,
|
||||
fir::FIRToLLVMPassOptions &options);
|
||||
void populateFIRToLLVMConversionPatterns(
|
||||
const fir::LLVMTypeConverter &converter, mlir::RewritePatternSet &patterns,
|
||||
fir::FIRToLLVMPassOptions &options);
|
||||
|
||||
/// Populate the pattern set with the PreCGRewrite patterns.
|
||||
void populatePreCGRewritePatterns(mlir::RewritePatternSet &patterns,
|
||||
|
@ -19,7 +19,7 @@ class LLVMTypeConverter;
|
||||
/// dialect, utilised in cases where the default OpenMP dialect handling cannot
|
||||
/// handle all cases for intermingled fir types and operations.
|
||||
void populateOpenMPFIRToLLVMConversionPatterns(
|
||||
LLVMTypeConverter &converter, mlir::RewritePatternSet &patterns);
|
||||
const LLVMTypeConverter &converter, mlir::RewritePatternSet &patterns);
|
||||
|
||||
} // namespace fir
|
||||
|
||||
|
@ -22,7 +22,7 @@ class DataLayout;
|
||||
|
||||
namespace cuf {
|
||||
|
||||
void populateCUFToFIRConversionPatterns(fir::LLVMTypeConverter &converter,
|
||||
void populateCUFToFIRConversionPatterns(const fir::LLVMTypeConverter &converter,
|
||||
mlir::DataLayout &dl,
|
||||
mlir::RewritePatternSet &patterns);
|
||||
|
||||
|
@ -3823,7 +3823,7 @@ fir::createLLVMDialectToLLVMPass(llvm::raw_ostream &output,
|
||||
}
|
||||
|
||||
void fir::populateFIRToLLVMConversionPatterns(
|
||||
fir::LLVMTypeConverter &converter, mlir::RewritePatternSet &patterns,
|
||||
const fir::LLVMTypeConverter &converter, mlir::RewritePatternSet &patterns,
|
||||
fir::FIRToLLVMPassOptions &options) {
|
||||
patterns.insert<
|
||||
AbsentOpConversion, AddcOpConversion, AddrOfOpConversion,
|
||||
|
@ -93,6 +93,6 @@ struct MapInfoOpConversion
|
||||
} // namespace
|
||||
|
||||
void fir::populateOpenMPFIRToLLVMConversionPatterns(
|
||||
LLVMTypeConverter &converter, mlir::RewritePatternSet &patterns) {
|
||||
const LLVMTypeConverter &converter, mlir::RewritePatternSet &patterns) {
|
||||
patterns.add<MapInfoOpConversion>(converter);
|
||||
}
|
||||
|
@ -222,7 +222,7 @@ struct CufAllocOpConversion : public mlir::OpRewritePattern<cuf::AllocOp> {
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
|
||||
CufAllocOpConversion(mlir::MLIRContext *context, mlir::DataLayout *dl,
|
||||
fir::LLVMTypeConverter *typeConverter)
|
||||
const fir::LLVMTypeConverter *typeConverter)
|
||||
: OpRewritePattern(context), dl{dl}, typeConverter{typeConverter} {}
|
||||
|
||||
mlir::LogicalResult
|
||||
@ -311,7 +311,7 @@ struct CufAllocOpConversion : public mlir::OpRewritePattern<cuf::AllocOp> {
|
||||
|
||||
private:
|
||||
mlir::DataLayout *dl;
|
||||
fir::LLVMTypeConverter *typeConverter;
|
||||
const fir::LLVMTypeConverter *typeConverter;
|
||||
};
|
||||
|
||||
struct CufFreeOpConversion : public mlir::OpRewritePattern<cuf::FreeOp> {
|
||||
@ -583,7 +583,7 @@ public:
|
||||
} // namespace
|
||||
|
||||
void cuf::populateCUFToFIRConversionPatterns(
|
||||
fir::LLVMTypeConverter &converter, mlir::DataLayout &dl,
|
||||
const fir::LLVMTypeConverter &converter, mlir::DataLayout &dl,
|
||||
mlir::RewritePatternSet &patterns) {
|
||||
patterns.insert<CufAllocOpConversion>(patterns.getContext(), &dl, &converter);
|
||||
patterns.insert<CufAllocateOpConversion, CufDeallocateOpConversion,
|
||||
|
@ -651,7 +651,7 @@ is very small, and follows the basic pattern of any dialect conversion pass.
|
||||
|
||||
```
|
||||
void mlir::populateTensorBufferizePatterns(
|
||||
BufferizeTypeConverter &typeConverter, RewritePatternSet &patterns) {
|
||||
const BufferizeTypeConverter &typeConverter, RewritePatternSet &patterns) {
|
||||
patterns.add<BufferizeCastOp, BufferizeExtractOp>(typeConverter,
|
||||
patterns.getContext());
|
||||
}
|
||||
|
@ -24,7 +24,7 @@ class Pass;
|
||||
/// Note: The ROCDL target does not support the LLVM bfloat type at this time
|
||||
/// and so this function will add conversions to change all `bfloat` uses
|
||||
/// to `i16`.
|
||||
void populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter,
|
||||
void populateAMDGPUToROCDLConversionPatterns(const LLVMTypeConverter &converter,
|
||||
RewritePatternSet &patterns,
|
||||
amdgpu::Chipset chipset);
|
||||
|
||||
|
@ -22,7 +22,7 @@ class Pass;
|
||||
#include "mlir/Conversion/Passes.h.inc"
|
||||
|
||||
namespace arith {
|
||||
void populateArithToLLVMConversionPatterns(LLVMTypeConverter &converter,
|
||||
void populateArithToLLVMConversionPatterns(const LLVMTypeConverter &converter,
|
||||
RewritePatternSet &patterns);
|
||||
|
||||
void registerConvertArithToLLVMInterface(DialectRegistry ®istry);
|
||||
|
@ -22,7 +22,7 @@ class Pass;
|
||||
#include "mlir/Conversion/Passes.h.inc"
|
||||
|
||||
namespace arith {
|
||||
void populateArithToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
|
||||
void populateArithToSPIRVPatterns(const SPIRVTypeConverter &typeConverter,
|
||||
RewritePatternSet &patterns);
|
||||
|
||||
std::unique_ptr<OperationPass<>> createConvertArithToSPIRVPass();
|
||||
|
@ -39,7 +39,7 @@ public:
|
||||
};
|
||||
|
||||
/// Populate the given list with patterns that convert from Complex to LLVM.
|
||||
void populateComplexToLLVMConversionPatterns(LLVMTypeConverter &converter,
|
||||
void populateComplexToLLVMConversionPatterns(const LLVMTypeConverter &converter,
|
||||
RewritePatternSet &patterns);
|
||||
|
||||
void registerConvertComplexToLLVMInterface(DialectRegistry ®istry);
|
||||
|
@ -20,7 +20,7 @@ class SPIRVTypeConverter;
|
||||
|
||||
/// Appends to a pattern list additional patterns for translating Complex ops
|
||||
/// to SPIR-V ops.
|
||||
void populateComplexToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
|
||||
void populateComplexToSPIRVPatterns(const SPIRVTypeConverter &typeConverter,
|
||||
RewritePatternSet &patterns);
|
||||
|
||||
} // namespace mlir
|
||||
|
@ -29,13 +29,13 @@ namespace cf {
|
||||
/// Collect the patterns to convert from the ControlFlow dialect to LLVM. The
|
||||
/// conversion patterns capture the LLVMTypeConverter by reference meaning the
|
||||
/// references have to remain alive during the entire pattern lifetime.
|
||||
void populateControlFlowToLLVMConversionPatterns(LLVMTypeConverter &converter,
|
||||
RewritePatternSet &patterns);
|
||||
void populateControlFlowToLLVMConversionPatterns(
|
||||
const LLVMTypeConverter &converter, RewritePatternSet &patterns);
|
||||
|
||||
/// Populate the cf.assert to LLVM conversion pattern. If `abortOnFailure` is
|
||||
/// set to false, the program execution continues when a condition is
|
||||
/// unsatisfied.
|
||||
void populateAssertToLLVMConversionPattern(LLVMTypeConverter &converter,
|
||||
void populateAssertToLLVMConversionPattern(const LLVMTypeConverter &converter,
|
||||
RewritePatternSet &patterns,
|
||||
bool abortOnFailure = true);
|
||||
|
||||
|
@ -20,7 +20,7 @@ class SPIRVTypeConverter;
|
||||
namespace cf {
|
||||
/// Appends to a pattern list additional patterns for translating ControlFLow
|
||||
/// ops to SPIR-V ops.
|
||||
void populateControlFlowToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
|
||||
void populateControlFlowToSPIRVPatterns(const SPIRVTypeConverter &typeConverter,
|
||||
RewritePatternSet &patterns);
|
||||
} // namespace cf
|
||||
} // namespace mlir
|
||||
|
@ -39,8 +39,8 @@ convertFuncOpToLLVMFuncOp(FunctionOpInterface funcOp,
|
||||
/// `emitCWrappers` is set, the pattern will also produce functions
|
||||
/// that pass memref descriptors by pointer-to-structure in addition to the
|
||||
/// default unpacked form.
|
||||
void populateFuncToLLVMFuncOpConversionPattern(LLVMTypeConverter &converter,
|
||||
RewritePatternSet &patterns);
|
||||
void populateFuncToLLVMFuncOpConversionPattern(
|
||||
const LLVMTypeConverter &converter, RewritePatternSet &patterns);
|
||||
|
||||
/// Collect the patterns to convert from the Func dialect to LLVM. The
|
||||
/// conversion patterns capture the LLVMTypeConverter and the LowerToLLVMOptions
|
||||
@ -56,7 +56,7 @@ void populateFuncToLLVMFuncOpConversionPattern(LLVMTypeConverter &converter,
|
||||
/// needed if `converter.getOptions().useBarePtrCallConv` is `true`, but it's
|
||||
/// not an error to provide it anyway.
|
||||
void populateFuncToLLVMConversionPatterns(
|
||||
LLVMTypeConverter &converter, RewritePatternSet &patterns,
|
||||
const LLVMTypeConverter &converter, RewritePatternSet &patterns,
|
||||
const SymbolTable *symbolTable = nullptr);
|
||||
|
||||
void registerConvertFuncToLLVMInterface(DialectRegistry ®istry);
|
||||
|
@ -21,7 +21,7 @@ class SPIRVTypeConverter;
|
||||
/// Appends to a pattern list additional patterns for translating Func ops
|
||||
/// to SPIR-V ops. Also adds the patterns to legalize ops not directly
|
||||
/// translated to SPIR-V dialect.
|
||||
void populateFuncToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
|
||||
void populateFuncToSPIRVPatterns(const SPIRVTypeConverter &typeConverter,
|
||||
RewritePatternSet &patterns);
|
||||
|
||||
} // namespace mlir
|
||||
|
@ -21,7 +21,7 @@ class TypeConverter;
|
||||
#define GEN_PASS_DECL_CONVERTGPUOPSTOLLVMSPVOPS
|
||||
#include "mlir/Conversion/Passes.h.inc"
|
||||
|
||||
void populateGpuToLLVMSPVConversionPatterns(LLVMTypeConverter &converter,
|
||||
void populateGpuToLLVMSPVConversionPatterns(const LLVMTypeConverter &converter,
|
||||
RewritePatternSet &patterns);
|
||||
|
||||
/// Populates memory space attribute conversion rules for lowering
|
||||
|
@ -32,16 +32,16 @@ LLVM::LLVMStructType convertMMAToLLVMType(gpu::MMAMatrixType type);
|
||||
void configureGpuToNVVMConversionLegality(ConversionTarget &target);
|
||||
|
||||
/// Collect a set of patterns to convert from the GPU dialect to NVVM.
|
||||
void populateGpuToNVVMConversionPatterns(LLVMTypeConverter &converter,
|
||||
void populateGpuToNVVMConversionPatterns(const LLVMTypeConverter &converter,
|
||||
RewritePatternSet &patterns);
|
||||
|
||||
/// Populate GpuSubgroupReduce pattern to NVVM. It generates a specific nvvm
|
||||
/// op that is not available on every GPU.
|
||||
void populateGpuSubgroupReduceOpLoweringPattern(LLVMTypeConverter &converter,
|
||||
RewritePatternSet &patterns);
|
||||
void populateGpuSubgroupReduceOpLoweringPattern(
|
||||
const LLVMTypeConverter &converter, RewritePatternSet &patterns);
|
||||
|
||||
/// Collect a set of patterns to convert WMMA ops from GPU dialect to NVVM.
|
||||
void populateGpuWMMAToNVVMConversionPatterns(LLVMTypeConverter &converter,
|
||||
void populateGpuWMMAToNVVMConversionPatterns(const LLVMTypeConverter &converter,
|
||||
RewritePatternSet &patterns);
|
||||
} // namespace mlir
|
||||
|
||||
|
@ -30,7 +30,7 @@ class GPUModuleOp;
|
||||
/// Collect a set of patterns to convert from the GPU dialect to ROCDL.
|
||||
/// If `runtime` is Unknown, gpu.printf will not be lowered
|
||||
/// The resulting pattern set should be run over a gpu.module op
|
||||
void populateGpuToROCDLConversionPatterns(LLVMTypeConverter &converter,
|
||||
void populateGpuToROCDLConversionPatterns(const LLVMTypeConverter &converter,
|
||||
RewritePatternSet &patterns,
|
||||
gpu::amd::Runtime runtime);
|
||||
|
||||
|
@ -23,13 +23,13 @@ class SPIRVTypeConverter;
|
||||
/// Appends to a pattern list additional patterns for translating GPU Ops to
|
||||
/// SPIR-V ops. For a gpu.func to be converted, it should have a
|
||||
/// spirv.entry_point_abi attribute.
|
||||
void populateGPUToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
|
||||
void populateGPUToSPIRVPatterns(const SPIRVTypeConverter &typeConverter,
|
||||
RewritePatternSet &patterns);
|
||||
|
||||
/// Collect a set of patterns to convert WMMA ops from GPU dialect to SPIRV,
|
||||
/// using the KHR Cooperative Matrix extension.
|
||||
void populateGpuWMMAToSPIRVCoopMatrixKHRConversionPatterns(
|
||||
SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns);
|
||||
const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns);
|
||||
|
||||
/// Adds `MMAMatrixType` conversions to SPIR-V cooperative matrix KHR type
|
||||
/// conversion to the type converter.
|
||||
|
@ -21,7 +21,7 @@ class Pass;
|
||||
#include "mlir/Conversion/Passes.h.inc"
|
||||
|
||||
namespace index {
|
||||
void populateIndexToLLVMConversionPatterns(LLVMTypeConverter &converter,
|
||||
void populateIndexToLLVMConversionPatterns(const LLVMTypeConverter &converter,
|
||||
RewritePatternSet &patterns);
|
||||
|
||||
void registerConvertIndexToLLVMInterface(DialectRegistry ®istry);
|
||||
|
@ -21,7 +21,7 @@ class Pass;
|
||||
#include "mlir/Conversion/Passes.h.inc"
|
||||
|
||||
namespace index {
|
||||
void populateIndexToSPIRVPatterns(SPIRVTypeConverter &converter,
|
||||
void populateIndexToSPIRVPatterns(const SPIRVTypeConverter &converter,
|
||||
RewritePatternSet &patterns);
|
||||
std::unique_ptr<OperationPass<>> createConvertIndexToSPIRVPass();
|
||||
} // namespace index
|
||||
|
@ -21,7 +21,7 @@ class Pass;
|
||||
#define GEN_PASS_DECL_CONVERTMATHTOLLVMPASS
|
||||
#include "mlir/Conversion/Passes.h.inc"
|
||||
|
||||
void populateMathToLLVMConversionPatterns(LLVMTypeConverter &converter,
|
||||
void populateMathToLLVMConversionPatterns(const LLVMTypeConverter &converter,
|
||||
RewritePatternSet &patterns,
|
||||
bool approximateLog1p = true);
|
||||
|
||||
|
@ -19,7 +19,7 @@ class Pass;
|
||||
#include "mlir/Conversion/Passes.h.inc"
|
||||
|
||||
/// Populate the given list with patterns that convert from Math to ROCDL calls.
|
||||
void populateMathToROCDLConversionPatterns(LLVMTypeConverter &converter,
|
||||
void populateMathToROCDLConversionPatterns(const LLVMTypeConverter &converter,
|
||||
RewritePatternSet &patterns);
|
||||
} // namespace mlir
|
||||
|
||||
|
@ -20,7 +20,7 @@ class SPIRVTypeConverter;
|
||||
|
||||
/// Appends to a pattern list additional patterns for translating Math ops
|
||||
/// to SPIR-V ops.
|
||||
void populateMathToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
|
||||
void populateMathToSPIRVPatterns(const SPIRVTypeConverter &typeConverter,
|
||||
RewritePatternSet &patterns);
|
||||
|
||||
} // namespace mlir
|
||||
|
@ -15,7 +15,7 @@ class TypeConverter;
|
||||
void populateMemRefToEmitCTypeConversion(TypeConverter &typeConverter);
|
||||
|
||||
void populateMemRefToEmitCConversionPatterns(RewritePatternSet &patterns,
|
||||
TypeConverter &converter);
|
||||
const TypeConverter &converter);
|
||||
} // namespace mlir
|
||||
|
||||
#endif // MLIR_CONVERSION_MEMREFTOEMITC_MEMREFTOEMITC_H
|
||||
|
@ -23,7 +23,7 @@ class RewritePatternSet;
|
||||
/// Collect a set of patterns to convert memory-related operations from the
|
||||
/// MemRef dialect to the LLVM dialect.
|
||||
void populateFinalizeMemRefToLLVMConversionPatterns(
|
||||
LLVMTypeConverter &converter, RewritePatternSet &patterns);
|
||||
const LLVMTypeConverter &converter, RewritePatternSet &patterns);
|
||||
|
||||
void registerConvertMemRefToLLVMInterface(DialectRegistry ®istry);
|
||||
|
||||
|
@ -67,7 +67,7 @@ void convertMemRefTypesAndAttrs(
|
||||
|
||||
/// Appends to a pattern list additional patterns for translating MemRef ops
|
||||
/// to SPIR-V ops.
|
||||
void populateMemRefToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
|
||||
void populateMemRefToSPIRVPatterns(const SPIRVTypeConverter &typeConverter,
|
||||
RewritePatternSet &patterns);
|
||||
|
||||
} // namespace mlir
|
||||
|
@ -34,7 +34,7 @@ MemRefType getMBarrierMemrefType(MLIRContext *context,
|
||||
MBarrierGroupType barrierType);
|
||||
} // namespace nvgpu
|
||||
|
||||
void populateNVGPUToNVVMConversionPatterns(LLVMTypeConverter &converter,
|
||||
void populateNVGPUToNVVMConversionPatterns(const LLVMTypeConverter &converter,
|
||||
RewritePatternSet &patterns);
|
||||
} // namespace mlir
|
||||
|
||||
|
@ -22,8 +22,8 @@ class RewritePatternSet;
|
||||
|
||||
/// Configure dynamic conversion legality of regionless operations from OpenMP
|
||||
/// to LLVM.
|
||||
void configureOpenMPToLLVMConversionLegality(ConversionTarget &target,
|
||||
LLVMTypeConverter &typeConverter);
|
||||
void configureOpenMPToLLVMConversionLegality(
|
||||
ConversionTarget &target, const LLVMTypeConverter &typeConverter);
|
||||
|
||||
/// Populate the given list with patterns that convert from OpenMP to LLVM.
|
||||
void populateOpenMPToLLVMConversionPatterns(LLVMTypeConverter &converter,
|
||||
|
@ -34,7 +34,7 @@ private:
|
||||
|
||||
/// Collects a set of patterns to lower from scf.for, scf.if, and
|
||||
/// loop.terminator to CFG operations within the SPIR-V dialect.
|
||||
void populateSCFToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
|
||||
void populateSCFToSPIRVPatterns(const SPIRVTypeConverter &typeConverter,
|
||||
ScfToSPIRVContext &scfToSPIRVContext,
|
||||
RewritePatternSet &patterns);
|
||||
} // namespace mlir
|
||||
|
@ -25,13 +25,16 @@ class ModuleOp;
|
||||
template <typename SPIRVOp>
|
||||
class SPIRVToLLVMConversion : public OpConversionPattern<SPIRVOp> {
|
||||
public:
|
||||
SPIRVToLLVMConversion(MLIRContext *context, LLVMTypeConverter &typeConverter,
|
||||
SPIRVToLLVMConversion(MLIRContext *context,
|
||||
const LLVMTypeConverter &typeConverter,
|
||||
PatternBenefit benefit = 1)
|
||||
: OpConversionPattern<SPIRVOp>(typeConverter, context, benefit),
|
||||
typeConverter(typeConverter) {}
|
||||
: OpConversionPattern<SPIRVOp>(typeConverter, context, benefit) {}
|
||||
|
||||
protected:
|
||||
LLVMTypeConverter &typeConverter;
|
||||
const LLVMTypeConverter *getTypeConverter() const {
|
||||
return static_cast<const LLVMTypeConverter *>(
|
||||
ConversionPattern::getTypeConverter());
|
||||
}
|
||||
};
|
||||
|
||||
/// Encodes global variable's descriptor set and binding into its name if they
|
||||
@ -46,18 +49,18 @@ void populateSPIRVToLLVMTypeConversion(
|
||||
|
||||
/// Populates the given list with patterns that convert from SPIR-V to LLVM.
|
||||
void populateSPIRVToLLVMConversionPatterns(
|
||||
LLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
|
||||
const LLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
|
||||
spirv::ClientAPI clientAPIForAddressSpaceMapping =
|
||||
spirv::ClientAPI::Unknown);
|
||||
|
||||
/// Populates the given list with patterns for function conversion from SPIR-V
|
||||
/// to LLVM.
|
||||
void populateSPIRVToLLVMFunctionConversionPatterns(
|
||||
LLVMTypeConverter &typeConverter, RewritePatternSet &patterns);
|
||||
const LLVMTypeConverter &typeConverter, RewritePatternSet &patterns);
|
||||
|
||||
/// Populates the given patterns for module conversion from SPIR-V to LLVM.
|
||||
void populateSPIRVToLLVMModuleConversionPatterns(
|
||||
LLVMTypeConverter &typeConverter, RewritePatternSet &patterns);
|
||||
const LLVMTypeConverter &typeConverter, RewritePatternSet &patterns);
|
||||
|
||||
} // namespace mlir
|
||||
|
||||
|
@ -30,7 +30,7 @@ class SPIRVTypeConverter;
|
||||
/// variables. SPIR-V consumers in GPU drivers may or may not optimize that
|
||||
/// away. So this has implications over register pressure. Therefore, a
|
||||
/// threshold is used to control when the patterns should kick in.
|
||||
void populateTensorToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
|
||||
void populateTensorToSPIRVPatterns(const SPIRVTypeConverter &typeConverter,
|
||||
int64_t byteCountThreshold,
|
||||
RewritePatternSet &patterns);
|
||||
|
||||
|
@ -47,7 +47,7 @@ void addTosaToLinalgPasses(
|
||||
void registerTosaToLinalgPipelines();
|
||||
|
||||
/// Populates conversion passes from TOSA dialect to Linalg dialect.
|
||||
void populateTosaToLinalgConversionPatterns(TypeConverter &converter,
|
||||
void populateTosaToLinalgConversionPatterns(const TypeConverter &converter,
|
||||
RewritePatternSet *patterns);
|
||||
|
||||
/// Populates conversion passes from TOSA dialect to Linalg named operations.
|
||||
|
@ -25,7 +25,7 @@ namespace tosa {
|
||||
|
||||
std::unique_ptr<Pass> createTosaToTensor();
|
||||
|
||||
void populateTosaToTensorConversionPatterns(TypeConverter &converter,
|
||||
void populateTosaToTensorConversionPatterns(const TypeConverter &converter,
|
||||
RewritePatternSet *patterns);
|
||||
|
||||
} // namespace tosa
|
||||
|
@ -22,7 +22,7 @@ class Pass;
|
||||
#include "mlir/Conversion/Passes.h.inc"
|
||||
|
||||
namespace ub {
|
||||
void populateUBToLLVMConversionPatterns(LLVMTypeConverter &converter,
|
||||
void populateUBToLLVMConversionPatterns(const LLVMTypeConverter &converter,
|
||||
RewritePatternSet &patterns);
|
||||
|
||||
void registerConvertUBToLLVMInterface(DialectRegistry ®istry);
|
||||
|
@ -21,7 +21,7 @@ class Pass;
|
||||
#include "mlir/Conversion/Passes.h.inc"
|
||||
|
||||
namespace ub {
|
||||
void populateUBToSPIRVConversionPatterns(SPIRVTypeConverter &converter,
|
||||
void populateUBToSPIRVConversionPatterns(const SPIRVTypeConverter &converter,
|
||||
RewritePatternSet &patterns);
|
||||
} // namespace ub
|
||||
} // namespace mlir
|
||||
|
@ -16,12 +16,12 @@ class LLVMTypeConverter;
|
||||
/// Collect a set of patterns to convert from Vector contractions to LLVM Matrix
|
||||
/// Intrinsics. To lower to assembly, the LLVM flag -lower-matrix-intrinsics
|
||||
/// will be needed when invoking LLVM.
|
||||
void populateVectorToLLVMMatrixConversionPatterns(LLVMTypeConverter &converter,
|
||||
RewritePatternSet &patterns);
|
||||
void populateVectorToLLVMMatrixConversionPatterns(
|
||||
const LLVMTypeConverter &converter, RewritePatternSet &patterns);
|
||||
|
||||
/// Collect a set of patterns to convert from the Vector dialect to LLVM.
|
||||
void populateVectorToLLVMConversionPatterns(
|
||||
LLVMTypeConverter &converter, RewritePatternSet &patterns,
|
||||
const LLVMTypeConverter &converter, RewritePatternSet &patterns,
|
||||
bool reassociateFPReductions = false, bool force32BitVectorIndices = false);
|
||||
|
||||
} // namespace mlir
|
||||
|
@ -20,7 +20,7 @@ class SPIRVTypeConverter;
|
||||
|
||||
/// Appends to a pattern list additional patterns for translating Vector Ops to
|
||||
/// SPIR-V ops.
|
||||
void populateVectorToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
|
||||
void populateVectorToSPIRVPatterns(const SPIRVTypeConverter &typeConverter,
|
||||
RewritePatternSet &patterns);
|
||||
|
||||
/// Appends patterns to convert vector reduction of the form:
|
||||
|
@ -17,8 +17,8 @@ class RewritePatternSet;
|
||||
|
||||
/// Collect a set of patterns to lower AMX ops to ops that map to LLVM
|
||||
/// intrinsics.
|
||||
void populateAMXLegalizeForLLVMExportPatterns(LLVMTypeConverter &converter,
|
||||
RewritePatternSet &patterns);
|
||||
void populateAMXLegalizeForLLVMExportPatterns(
|
||||
const LLVMTypeConverter &converter, RewritePatternSet &patterns);
|
||||
|
||||
/// Configure the target to support lowering AMX ops to ops that map to LLVM
|
||||
/// intrinsics.
|
||||
|
@ -28,13 +28,15 @@ class NarrowTypeEmulationConverter;
|
||||
/// types into supported ones. This is done by splitting original power-of-two
|
||||
/// i2N integer types into two iN halves.
|
||||
void populateArithWideIntEmulationPatterns(
|
||||
WideIntEmulationConverter &typeConverter, RewritePatternSet &patterns);
|
||||
const WideIntEmulationConverter &typeConverter,
|
||||
RewritePatternSet &patterns);
|
||||
|
||||
/// Adds patterns to emulate narrow Arith and Function ops into wide
|
||||
/// supported types. Users need to add conversions about the computation
|
||||
/// domain of narrow types.
|
||||
void populateArithNarrowTypeEmulationPatterns(
|
||||
NarrowTypeEmulationConverter &typeConverter, RewritePatternSet &patterns);
|
||||
const NarrowTypeEmulationConverter &typeConverter,
|
||||
RewritePatternSet &patterns);
|
||||
|
||||
/// Populate the type conversions needed to emulate the unsupported
|
||||
/// `sourceTypes` with `destType`
|
||||
@ -45,12 +47,12 @@ void populateEmulateUnsupportedFloatsConversions(TypeConverter &converter,
|
||||
/// Add rewrite patterns for converting operations that use illegal float types
|
||||
/// to ones that use legal ones.
|
||||
void populateEmulateUnsupportedFloatsPatterns(RewritePatternSet &patterns,
|
||||
TypeConverter &converter);
|
||||
const TypeConverter &converter);
|
||||
|
||||
/// Set up a dialect conversion to reject arithmetic operations on unsupported
|
||||
/// float types.
|
||||
void populateEmulateUnsupportedFloatsLegality(ConversionTarget &target,
|
||||
TypeConverter &converter);
|
||||
const TypeConverter &converter);
|
||||
/// Add patterns to expand Arith ceil/floor division ops.
|
||||
void populateCeilFloorDivExpandOpsPatterns(RewritePatternSet &patterns);
|
||||
|
||||
|
@ -17,8 +17,8 @@ class RewritePatternSet;
|
||||
|
||||
/// Collect a set of patterns to lower ArmSVE ops to ops that map to LLVM
|
||||
/// intrinsics.
|
||||
void populateArmSVELegalizeForLLVMExportPatterns(LLVMTypeConverter &converter,
|
||||
RewritePatternSet &patterns);
|
||||
void populateArmSVELegalizeForLLVMExportPatterns(
|
||||
const LLVMTypeConverter &converter, RewritePatternSet &patterns);
|
||||
|
||||
/// Configure the target to support lowering ArmSVE ops to ops that map to LLVM
|
||||
/// intrinsics.
|
||||
|
@ -60,7 +60,7 @@ void populateBufferizeMaterializationLegality(ConversionTarget &target);
|
||||
///
|
||||
/// In particular, these are the tensor_load/buffer_cast ops.
|
||||
void populateEliminateBufferizeMaterializationsPatterns(
|
||||
BufferizeTypeConverter &typeConverter, RewritePatternSet &patterns);
|
||||
const BufferizeTypeConverter &typeConverter, RewritePatternSet &patterns);
|
||||
|
||||
/// Bufferize `op` and its nested ops that implement `BufferizableOpInterface`.
|
||||
///
|
||||
|
@ -85,7 +85,7 @@ private:
|
||||
/// Populates the patterns needed to drive the conversion process for
|
||||
/// decomposing call graph types with the given `ValueDecomposer`.
|
||||
void populateDecomposeCallGraphTypesPatterns(MLIRContext *context,
|
||||
TypeConverter &typeConverter,
|
||||
const TypeConverter &typeConverter,
|
||||
ValueDecomposer &decomposer,
|
||||
RewritePatternSet &patterns);
|
||||
|
||||
|
@ -29,7 +29,7 @@ class RewritePatternSet;
|
||||
/// Add a pattern to the given pattern list to convert the operand and result
|
||||
/// types of a CallOp with the given type converter.
|
||||
void populateCallOpTypeConversionPattern(RewritePatternSet &patterns,
|
||||
TypeConverter &converter);
|
||||
const TypeConverter &converter);
|
||||
|
||||
/// Add a pattern to the given pattern list to rewrite branch operations to use
|
||||
/// operands that have been legalized by the conversion framework. This can only
|
||||
@ -41,25 +41,25 @@ void populateCallOpTypeConversionPattern(RewritePatternSet &patterns,
|
||||
/// shouldConvertBranchOperand. This callback should return true if branchOp's
|
||||
/// operand at index idx should be converted.
|
||||
void populateBranchOpInterfaceTypeConversionPattern(
|
||||
RewritePatternSet &patterns, TypeConverter &converter,
|
||||
RewritePatternSet &patterns, const TypeConverter &converter,
|
||||
function_ref<bool(BranchOpInterface branchOp, int idx)>
|
||||
shouldConvertBranchOperand = nullptr);
|
||||
|
||||
/// Return true if op is a BranchOpInterface op whose operands are all legal
|
||||
/// according to converter.
|
||||
bool isLegalForBranchOpInterfaceTypeConversionPattern(Operation *op,
|
||||
TypeConverter &converter);
|
||||
bool isLegalForBranchOpInterfaceTypeConversionPattern(
|
||||
Operation *op, const TypeConverter &converter);
|
||||
|
||||
/// Add a pattern to the given pattern list to rewrite `return` ops to use
|
||||
/// operands that have been legalized by the conversion framework.
|
||||
void populateReturnOpTypeConversionPattern(RewritePatternSet &patterns,
|
||||
TypeConverter &converter);
|
||||
const TypeConverter &converter);
|
||||
|
||||
/// For ReturnLike ops (except `return`), return True. If op is a `return` &&
|
||||
/// returnOpAlwaysLegal is false, legalize op according to converter. Otherwise,
|
||||
/// return false.
|
||||
bool isLegalForReturnOpTypeConversionPattern(Operation *op,
|
||||
TypeConverter &converter,
|
||||
const TypeConverter &converter,
|
||||
bool returnOpAlwaysLegal = false);
|
||||
|
||||
/// Return true if op is neither BranchOpInterface nor ReturnLike.
|
||||
|
@ -18,7 +18,7 @@ namespace mlir {
|
||||
|
||||
// Populates the provided pattern set with patterns that do 1:N type conversions
|
||||
// on func ops. This is intended to be used with `applyPartialOneToNConversion`.
|
||||
void populateFuncTypeConversionPatterns(TypeConverter &typeConverter,
|
||||
void populateFuncTypeConversionPatterns(const TypeConverter &typeConverter,
|
||||
RewritePatternSet &patterns);
|
||||
|
||||
} // namespace mlir
|
||||
|
@ -62,7 +62,7 @@ void populateExtendToSupportedTypesTypeConverter(
|
||||
void populateExtendToSupportedTypesConversionTarget(
|
||||
ConversionTarget &target, TypeConverter &typeConverter);
|
||||
void populateExtendToSupportedTypesPatterns(RewritePatternSet &patterns,
|
||||
TypeConverter &typeConverter);
|
||||
const TypeConverter &typeConverter);
|
||||
} // namespace math
|
||||
} // namespace mlir
|
||||
|
||||
|
@ -72,7 +72,7 @@ void populateExpandReallocPatterns(RewritePatternSet &patterns,
|
||||
/// Appends patterns for emulating wide integer memref operations with ops over
|
||||
/// narrower integer types.
|
||||
void populateMemRefWideIntEmulationPatterns(
|
||||
arith::WideIntEmulationConverter &typeConverter,
|
||||
const arith::WideIntEmulationConverter &typeConverter,
|
||||
RewritePatternSet &patterns);
|
||||
|
||||
/// Appends type conversions for emulating wide integer memref operations with
|
||||
@ -83,7 +83,7 @@ void populateMemRefWideIntEmulationConversions(
|
||||
/// Appends patterns for emulating memref operations over narrow types with ops
|
||||
/// over wider types.
|
||||
void populateMemRefNarrowTypeEmulationPatterns(
|
||||
arith::NarrowTypeEmulationConverter &typeConverter,
|
||||
const arith::NarrowTypeEmulationConverter &typeConverter,
|
||||
RewritePatternSet &patterns);
|
||||
|
||||
/// Appends type conversions for emulating memref operations over narrow types
|
||||
|
@ -50,12 +50,12 @@ protected:
|
||||
/// corresponding scf.yield ops need to update their types accordingly to the
|
||||
/// TypeConverter, but otherwise don't care what type conversions are happening.
|
||||
void populateSCFStructuralTypeConversionsAndLegality(
|
||||
TypeConverter &typeConverter, RewritePatternSet &patterns,
|
||||
const TypeConverter &typeConverter, RewritePatternSet &patterns,
|
||||
ConversionTarget &target);
|
||||
|
||||
/// Similar to `populateSCFStructuralTypeConversionsAndLegality` but does not
|
||||
/// populate the conversion target.
|
||||
void populateSCFStructuralTypeConversions(TypeConverter &typeConverter,
|
||||
void populateSCFStructuralTypeConversions(const TypeConverter &typeConverter,
|
||||
RewritePatternSet &patterns);
|
||||
|
||||
/// Updates the ConversionTarget with dynamic legality of SCF operations based
|
||||
@ -66,8 +66,8 @@ void populateSCFStructuralTypeConversionTarget(
|
||||
/// Populates the provided pattern set with patterns that do 1:N type
|
||||
/// conversions on (some) SCF ops. This is intended to be used with
|
||||
/// applyPartialOneToNConversion.
|
||||
void populateSCFStructuralOneToNTypeConversions(TypeConverter &typeConverter,
|
||||
RewritePatternSet &patterns);
|
||||
void populateSCFStructuralOneToNTypeConversions(
|
||||
const TypeConverter &typeConverter, RewritePatternSet &patterns);
|
||||
|
||||
/// Populate patterns for SCF software pipelining transformation. See the
|
||||
/// ForLoopPipeliningPattern for the transformation details.
|
||||
|
@ -135,7 +135,7 @@ private:
|
||||
/// `func` op to the SPIR-V dialect. These patterns do not handle shader
|
||||
/// interface/ABI; they convert function parameters to be of SPIR-V allowed
|
||||
/// types.
|
||||
void populateBuiltinFuncToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
|
||||
void populateBuiltinFuncToSPIRVPatterns(const SPIRVTypeConverter &typeConverter,
|
||||
RewritePatternSet &patterns);
|
||||
|
||||
void populateFuncOpVectorRewritePatterns(RewritePatternSet &patterns);
|
||||
|
@ -154,7 +154,7 @@ struct SparseIterationTypeConverter : public OneToNTypeConverter {
|
||||
SparseIterationTypeConverter();
|
||||
};
|
||||
|
||||
void populateLowerSparseIterationToSCFPatterns(TypeConverter &converter,
|
||||
void populateLowerSparseIterationToSCFPatterns(const TypeConverter &converter,
|
||||
RewritePatternSet &patterns);
|
||||
|
||||
std::unique_ptr<Pass> createLowerSparseIterationToSCFPass();
|
||||
@ -170,7 +170,7 @@ public:
|
||||
};
|
||||
|
||||
/// Sets up sparse tensor conversion rules.
|
||||
void populateSparseTensorConversionPatterns(TypeConverter &typeConverter,
|
||||
void populateSparseTensorConversionPatterns(const TypeConverter &typeConverter,
|
||||
RewritePatternSet &patterns);
|
||||
|
||||
std::unique_ptr<Pass> createSparseTensorConversionPass();
|
||||
@ -186,7 +186,7 @@ public:
|
||||
};
|
||||
|
||||
/// Sets up sparse tensor codegen rules.
|
||||
void populateSparseTensorCodegenPatterns(TypeConverter &typeConverter,
|
||||
void populateSparseTensorCodegenPatterns(const TypeConverter &typeConverter,
|
||||
RewritePatternSet &patterns,
|
||||
bool createSparseDeallocs,
|
||||
bool enableBufferInitialization);
|
||||
@ -244,7 +244,7 @@ public:
|
||||
StorageSpecifierToLLVMTypeConverter();
|
||||
};
|
||||
|
||||
void populateStorageSpecifierToLLVMPatterns(TypeConverter &converter,
|
||||
void populateStorageSpecifierToLLVMPatterns(const TypeConverter &converter,
|
||||
RewritePatternSet &patterns);
|
||||
std::unique_ptr<Pass> createStorageSpecifierToLLVMPass();
|
||||
|
||||
|
@ -366,7 +366,7 @@ void populateVectorMaskMaterializationPatterns(RewritePatternSet &patterns,
|
||||
/// Appends patterns for emulating vector operations over narrow types with ops
|
||||
/// over wider types.
|
||||
void populateVectorNarrowTypeEmulationPatterns(
|
||||
arith::NarrowTypeEmulationConverter &typeConverter,
|
||||
const arith::NarrowTypeEmulationConverter &typeConverter,
|
||||
RewritePatternSet &patterns);
|
||||
|
||||
/// Rewrite a vector `bitcast(trunci)` to use a more efficient sequence of
|
||||
@ -403,10 +403,9 @@ void populateVectorLinearizeTypeConversionsAndLegality(
|
||||
|
||||
/// Populates patterns for linearizing ND (N >= 2) vector operations to 1D
|
||||
/// vector shuffle operations.
|
||||
void populateVectorLinearizeShuffleLikeOpsPatterns(TypeConverter &typeConverter,
|
||||
RewritePatternSet &patterns,
|
||||
ConversionTarget &target,
|
||||
unsigned targetBitWidth);
|
||||
void populateVectorLinearizeShuffleLikeOpsPatterns(
|
||||
const TypeConverter &typeConverter, RewritePatternSet &patterns,
|
||||
ConversionTarget &target, unsigned targetBitWidth);
|
||||
|
||||
} // namespace vector
|
||||
} // namespace mlir
|
||||
|
@ -176,7 +176,7 @@ void populateSpecializedTransposeLoweringPatterns(
|
||||
/// Collect a set of patterns to lower X86Vector ops to ops that map to LLVM
|
||||
/// intrinsics.
|
||||
void populateX86VectorLegalizeForLLVMExportPatterns(
|
||||
LLVMTypeConverter &converter, RewritePatternSet &patterns);
|
||||
const LLVMTypeConverter &converter, RewritePatternSet &patterns);
|
||||
|
||||
/// Configure the target to support lowering X86Vector ops to ops that map to
|
||||
/// LLVM intrinsics.
|
||||
|
@ -53,7 +53,7 @@ public:
|
||||
/// `TypeConverter::convertSignatureArgs` and exists here with a different
|
||||
/// name to reflect the broader semantic.
|
||||
LogicalResult computeTypeMapping(TypeRange types,
|
||||
SignatureConversion &result) {
|
||||
SignatureConversion &result) const {
|
||||
return convertSignatureArgs(types, result);
|
||||
}
|
||||
|
||||
@ -126,24 +126,25 @@ public:
|
||||
/// Construct a conversion pattern with the given converter, and forward the
|
||||
/// remaining arguments to RewritePattern.
|
||||
template <typename... Args>
|
||||
RewritePatternWithConverter(TypeConverter &typeConverter, Args &&...args)
|
||||
RewritePatternWithConverter(const TypeConverter &typeConverter,
|
||||
Args &&...args)
|
||||
: RewritePattern(std::forward<Args>(args)...),
|
||||
typeConverter(&typeConverter) {}
|
||||
|
||||
/// Return the type converter held by this pattern, or nullptr if the pattern
|
||||
/// does not require type conversion.
|
||||
TypeConverter *getTypeConverter() const { return typeConverter; }
|
||||
const TypeConverter *getTypeConverter() const { return typeConverter; }
|
||||
|
||||
template <typename ConverterTy>
|
||||
std::enable_if_t<std::is_base_of<TypeConverter, ConverterTy>::value,
|
||||
ConverterTy *>
|
||||
const ConverterTy *>
|
||||
getTypeConverter() const {
|
||||
return static_cast<ConverterTy *>(typeConverter);
|
||||
return static_cast<const ConverterTy *>(typeConverter);
|
||||
}
|
||||
|
||||
protected:
|
||||
/// A type converter for use by this pattern.
|
||||
TypeConverter *const typeConverter;
|
||||
const TypeConverter *const typeConverter;
|
||||
};
|
||||
|
||||
/// Specialization of `PatternRewriter` that `OneToNConversionPattern`s use. The
|
||||
@ -212,8 +213,8 @@ public:
|
||||
template <typename SourceOp>
|
||||
class OneToNOpConversionPattern : public OneToNConversionPattern {
|
||||
public:
|
||||
OneToNOpConversionPattern(TypeConverter &typeConverter, MLIRContext *context,
|
||||
PatternBenefit benefit = 1,
|
||||
OneToNOpConversionPattern(const TypeConverter &typeConverter,
|
||||
MLIRContext *context, PatternBenefit benefit = 1,
|
||||
ArrayRef<StringRef> generatedNames = {})
|
||||
: OneToNConversionPattern(typeConverter, SourceOp::getOperationName(),
|
||||
benefit, context, generatedNames) {}
|
||||
@ -302,11 +303,11 @@ applyPartialOneToNConversion(Operation *op, OneToNTypeConverter &typeConverter,
|
||||
/// ops which use FunctionType to represent their type. This is intended to be
|
||||
/// used with the 1:N dialect conversion.
|
||||
void populateOneToNFunctionOpInterfaceTypeConversionPattern(
|
||||
StringRef functionLikeOpName, TypeConverter &converter,
|
||||
StringRef functionLikeOpName, const TypeConverter &converter,
|
||||
RewritePatternSet &patterns);
|
||||
template <typename FuncOpT>
|
||||
void populateOneToNFunctionOpInterfaceTypeConversionPattern(
|
||||
TypeConverter &converter, RewritePatternSet &patterns) {
|
||||
const TypeConverter &converter, RewritePatternSet &patterns) {
|
||||
populateOneToNFunctionOpInterfaceTypeConversionPattern(
|
||||
FuncOpT::getOperationName(), converter, patterns);
|
||||
}
|
||||
|
@ -277,7 +277,7 @@ struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
|
||||
};
|
||||
|
||||
struct LDSBarrierOpLowering : public ConvertOpToLLVMPattern<LDSBarrierOp> {
|
||||
LDSBarrierOpLowering(LLVMTypeConverter &converter, Chipset chipset)
|
||||
LDSBarrierOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
|
||||
: ConvertOpToLLVMPattern<LDSBarrierOp>(converter), chipset(chipset) {}
|
||||
|
||||
Chipset chipset;
|
||||
@ -335,7 +335,7 @@ struct LDSBarrierOpLowering : public ConvertOpToLLVMPattern<LDSBarrierOp> {
|
||||
};
|
||||
|
||||
struct SchedBarrierOpLowering : public ConvertOpToLLVMPattern<SchedBarrierOp> {
|
||||
SchedBarrierOpLowering(LLVMTypeConverter &converter, Chipset chipset)
|
||||
SchedBarrierOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
|
||||
: ConvertOpToLLVMPattern<SchedBarrierOp>(converter), chipset(chipset) {}
|
||||
|
||||
Chipset chipset;
|
||||
@ -725,7 +725,7 @@ struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
|
||||
namespace {
|
||||
struct ExtPackedFp8OpLowering final
|
||||
: public ConvertOpToLLVMPattern<ExtPackedFp8Op> {
|
||||
ExtPackedFp8OpLowering(LLVMTypeConverter &converter, Chipset chipset)
|
||||
ExtPackedFp8OpLowering(const LLVMTypeConverter &converter, Chipset chipset)
|
||||
: ConvertOpToLLVMPattern<amdgpu::ExtPackedFp8Op>(converter),
|
||||
chipset(chipset) {}
|
||||
Chipset chipset;
|
||||
@ -737,7 +737,8 @@ struct ExtPackedFp8OpLowering final
|
||||
|
||||
struct PackedTrunc2xFp8OpLowering final
|
||||
: public ConvertOpToLLVMPattern<PackedTrunc2xFp8Op> {
|
||||
PackedTrunc2xFp8OpLowering(LLVMTypeConverter &converter, Chipset chipset)
|
||||
PackedTrunc2xFp8OpLowering(const LLVMTypeConverter &converter,
|
||||
Chipset chipset)
|
||||
: ConvertOpToLLVMPattern<amdgpu::PackedTrunc2xFp8Op>(converter),
|
||||
chipset(chipset) {}
|
||||
Chipset chipset;
|
||||
@ -749,7 +750,8 @@ struct PackedTrunc2xFp8OpLowering final
|
||||
|
||||
struct PackedStochRoundFp8OpLowering final
|
||||
: public ConvertOpToLLVMPattern<PackedStochRoundFp8Op> {
|
||||
PackedStochRoundFp8OpLowering(LLVMTypeConverter &converter, Chipset chipset)
|
||||
PackedStochRoundFp8OpLowering(const LLVMTypeConverter &converter,
|
||||
Chipset chipset)
|
||||
: ConvertOpToLLVMPattern<amdgpu::PackedStochRoundFp8Op>(converter),
|
||||
chipset(chipset) {}
|
||||
Chipset chipset;
|
||||
@ -880,7 +882,7 @@ LogicalResult PackedStochRoundFp8OpLowering::matchAndRewrite(
|
||||
// Implement the AMDGPU_DPPLowering class that will convert the amdgpu.dpp
|
||||
// operation into the corresponding ROCDL instructions.
|
||||
struct AMDGPUDPPLowering : public ConvertOpToLLVMPattern<DPPOp> {
|
||||
AMDGPUDPPLowering(LLVMTypeConverter &converter, Chipset chipset)
|
||||
AMDGPUDPPLowering(const LLVMTypeConverter &converter, Chipset chipset)
|
||||
: ConvertOpToLLVMPattern<DPPOp>(converter), chipset(chipset) {}
|
||||
Chipset chipset;
|
||||
|
||||
@ -1052,9 +1054,9 @@ struct ConvertAMDGPUToROCDLPass
|
||||
};
|
||||
} // namespace
|
||||
|
||||
void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter,
|
||||
RewritePatternSet &patterns,
|
||||
Chipset chipset) {
|
||||
void mlir::populateAMDGPUToROCDLConversionPatterns(
|
||||
const LLVMTypeConverter &converter, RewritePatternSet &patterns,
|
||||
Chipset chipset) {
|
||||
patterns
|
||||
.add<RawBufferOpLowering<RawBufferLoadOp, ROCDL::RawPtrBufferLoadOp>,
|
||||
RawBufferOpLowering<RawBufferStoreOp, ROCDL::RawPtrBufferStoreOp>,
|
||||
|
@ -520,7 +520,7 @@ void mlir::arith::registerConvertArithToLLVMInterface(
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void mlir::arith::populateArithToLLVMConversionPatterns(
|
||||
LLVMTypeConverter &converter, RewritePatternSet &patterns) {
|
||||
const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
|
||||
// clang-format off
|
||||
patterns.add<
|
||||
AddFOpLowering,
|
||||
|
@ -1262,7 +1262,7 @@ public:
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void mlir::arith::populateArithToSPIRVPatterns(
|
||||
SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns) {
|
||||
const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns) {
|
||||
// clang-format off
|
||||
patterns.add<
|
||||
ConstantCompositeOpPattern,
|
||||
|
@ -323,7 +323,7 @@ struct SubOpConversion : public ConvertOpToLLVMPattern<complex::SubOp> {
|
||||
} // namespace
|
||||
|
||||
void mlir::populateComplexToLLVMConversionPatterns(
|
||||
LLVMTypeConverter &converter, RewritePatternSet &patterns) {
|
||||
const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
|
||||
// clang-format off
|
||||
patterns.add<
|
||||
AbsOpConversion,
|
||||
|
@ -102,8 +102,8 @@ struct ImOpPattern final : OpConversionPattern<complex::ImOp> {
|
||||
// Pattern population
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void mlir::populateComplexToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
|
||||
RewritePatternSet &patterns) {
|
||||
void mlir::populateComplexToSPIRVPatterns(
|
||||
const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns) {
|
||||
MLIRContext *context = patterns.getContext();
|
||||
|
||||
patterns.add<ConstantOpPattern, CreateOpPattern, ReOpPattern, ImOpPattern>(
|
||||
|
@ -43,7 +43,7 @@ namespace {
|
||||
/// ignored by the default lowering but should be propagated by any custom
|
||||
/// lowering.
|
||||
struct AssertOpLowering : public ConvertOpToLLVMPattern<cf::AssertOp> {
|
||||
explicit AssertOpLowering(LLVMTypeConverter &typeConverter,
|
||||
explicit AssertOpLowering(const LLVMTypeConverter &typeConverter,
|
||||
bool abortOnFailedAssert = true)
|
||||
: ConvertOpToLLVMPattern<cf::AssertOp>(typeConverter, /*benefit=*/1),
|
||||
abortOnFailedAssert(abortOnFailedAssert) {}
|
||||
@ -201,7 +201,7 @@ struct SwitchOpLowering : public ConvertOpToLLVMPattern<cf::SwitchOp> {
|
||||
} // namespace
|
||||
|
||||
void mlir::cf::populateControlFlowToLLVMConversionPatterns(
|
||||
LLVMTypeConverter &converter, RewritePatternSet &patterns) {
|
||||
const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
|
||||
// clang-format off
|
||||
patterns.add<
|
||||
AssertOpLowering,
|
||||
@ -212,7 +212,7 @@ void mlir::cf::populateControlFlowToLLVMConversionPatterns(
|
||||
}
|
||||
|
||||
void mlir::cf::populateAssertToLLVMConversionPattern(
|
||||
LLVMTypeConverter &converter, RewritePatternSet &patterns,
|
||||
const LLVMTypeConverter &converter, RewritePatternSet &patterns,
|
||||
bool abortOnFailure) {
|
||||
patterns.add<AssertOpLowering>(converter, abortOnFailure);
|
||||
}
|
||||
|
@ -109,7 +109,7 @@ struct CondBranchOpPattern final : OpConversionPattern<cf::CondBranchOp> {
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void mlir::cf::populateControlFlowToSPIRVPatterns(
|
||||
SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns) {
|
||||
const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns) {
|
||||
MLIRContext *context = patterns.getContext();
|
||||
|
||||
patterns.add<BranchOpPattern, CondBranchOpPattern>(typeConverter, context);
|
||||
|
@ -55,7 +55,7 @@ void mapToMemRef(Operation *op, spirv::TargetEnvAttr &targetAttr) {
|
||||
}
|
||||
|
||||
/// Populate patterns for each dialect.
|
||||
void populateConvertToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
|
||||
void populateConvertToSPIRVPatterns(const SPIRVTypeConverter &typeConverter,
|
||||
ScfToSPIRVContext &scfToSPIRVContext,
|
||||
RewritePatternSet &patterns) {
|
||||
arith::populateCeilFloorDivExpandOpsPatterns(patterns);
|
||||
|
@ -722,12 +722,12 @@ struct ReturnOpLowering : public ConvertOpToLLVMPattern<func::ReturnOp> {
|
||||
} // namespace
|
||||
|
||||
void mlir::populateFuncToLLVMFuncOpConversionPattern(
|
||||
LLVMTypeConverter &converter, RewritePatternSet &patterns) {
|
||||
const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
|
||||
patterns.add<FuncOpConversion>(converter);
|
||||
}
|
||||
|
||||
void mlir::populateFuncToLLVMConversionPatterns(
|
||||
LLVMTypeConverter &converter, RewritePatternSet &patterns,
|
||||
const LLVMTypeConverter &converter, RewritePatternSet &patterns,
|
||||
const SymbolTable *symbolTable) {
|
||||
populateFuncToLLVMFuncOpConversionPattern(converter, patterns);
|
||||
patterns.add<CallIndirectOpLowering>(converter);
|
||||
|
@ -87,7 +87,7 @@ public:
|
||||
// Pattern population
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void mlir::populateFuncToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
|
||||
void mlir::populateFuncToSPIRVPatterns(const SPIRVTypeConverter &typeConverter,
|
||||
RewritePatternSet &patterns) {
|
||||
MLIRContext *context = patterns.getContext();
|
||||
|
||||
|
@ -36,13 +36,13 @@ private:
|
||||
IntrType intrType;
|
||||
|
||||
public:
|
||||
explicit OpLowering(LLVMTypeConverter &typeConverter)
|
||||
explicit OpLowering(const LLVMTypeConverter &typeConverter)
|
||||
: ConvertOpToLLVMPattern<Op>(typeConverter),
|
||||
indexBitwidth(typeConverter.getIndexTypeBitwidth()),
|
||||
indexKind(IndexKind::Other), intrType(IntrType::None) {}
|
||||
|
||||
explicit OpLowering(LLVMTypeConverter &typeConverter, IndexKind indexKind,
|
||||
IntrType intrType)
|
||||
explicit OpLowering(const LLVMTypeConverter &typeConverter,
|
||||
IndexKind indexKind, IntrType intrType)
|
||||
: ConvertOpToLLVMPattern<Op>(typeConverter),
|
||||
indexBitwidth(typeConverter.getIndexTypeBitwidth()),
|
||||
indexKind(indexKind), intrType(intrType) {}
|
||||
|
@ -42,9 +42,9 @@ namespace mlir {
|
||||
template <typename SourceOp>
|
||||
struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> {
|
||||
public:
|
||||
explicit OpToFuncCallLowering(LLVMTypeConverter &lowering, StringRef f32Func,
|
||||
StringRef f64Func, StringRef f32ApproxFunc,
|
||||
StringRef f16Func)
|
||||
explicit OpToFuncCallLowering(const LLVMTypeConverter &lowering,
|
||||
StringRef f32Func, StringRef f64Func,
|
||||
StringRef f32ApproxFunc, StringRef f16Func)
|
||||
: ConvertOpToLLVMPattern<SourceOp>(lowering), f32Func(f32Func),
|
||||
f64Func(f64Func), f32ApproxFunc(f32ApproxFunc), f16Func(f16Func) {}
|
||||
|
||||
|
@ -412,8 +412,8 @@ gpuAddressSpaceToOCLAddressSpace(gpu::AddressSpace addressSpace) {
|
||||
}
|
||||
} // namespace
|
||||
|
||||
void populateGpuToLLVMSPVConversionPatterns(LLVMTypeConverter &typeConverter,
|
||||
RewritePatternSet &patterns) {
|
||||
void populateGpuToLLVMSPVConversionPatterns(
|
||||
const LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) {
|
||||
patterns.add<GPUBarrierConversion, GPUReturnOpLowering, GPUShuffleConversion,
|
||||
GPUSubgroupOpConversion<gpu::LaneIdOp>,
|
||||
GPUSubgroupOpConversion<gpu::NumSubgroupsOp>,
|
||||
|
@ -333,7 +333,7 @@ void mlir::configureGpuToNVVMConversionLegality(ConversionTarget &target) {
|
||||
}
|
||||
|
||||
template <typename OpTy>
|
||||
static void populateOpPatterns(LLVMTypeConverter &converter,
|
||||
static void populateOpPatterns(const LLVMTypeConverter &converter,
|
||||
RewritePatternSet &patterns, StringRef f32Func,
|
||||
StringRef f64Func, StringRef f32ApproxFunc = "",
|
||||
StringRef f16Func = "") {
|
||||
@ -343,12 +343,12 @@ static void populateOpPatterns(LLVMTypeConverter &converter,
|
||||
}
|
||||
|
||||
void mlir::populateGpuSubgroupReduceOpLoweringPattern(
|
||||
LLVMTypeConverter &converter, RewritePatternSet &patterns) {
|
||||
const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
|
||||
patterns.add<GPUSubgroupReduceOpLowering>(converter);
|
||||
}
|
||||
|
||||
void mlir::populateGpuToNVVMConversionPatterns(LLVMTypeConverter &converter,
|
||||
RewritePatternSet &patterns) {
|
||||
void mlir::populateGpuToNVVMConversionPatterns(
|
||||
const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
|
||||
using gpu::index_lowering::IndexKind;
|
||||
using gpu::index_lowering::IntrType;
|
||||
populateWithGenerated(patterns);
|
||||
|
@ -388,7 +388,7 @@ LLVM::LLVMStructType mlir::convertMMAToLLVMType(gpu::MMAMatrixType type) {
|
||||
}
|
||||
|
||||
void mlir::populateGpuWMMAToNVVMConversionPatterns(
|
||||
LLVMTypeConverter &converter, RewritePatternSet &patterns) {
|
||||
const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
|
||||
patterns.add<WmmaLoadOpToNVVMLowering, WmmaMmaOpToNVVMLowering,
|
||||
WmmaStoreOpToNVVMLowering, WmmaConstantOpToNVVMLowering,
|
||||
WmmaElementwiseOpToNVVMLowering>(converter);
|
||||
|
@ -343,7 +343,7 @@ void mlir::configureGpuToROCDLConversionLegality(ConversionTarget &target) {
|
||||
}
|
||||
|
||||
template <typename OpTy>
|
||||
static void populateOpPatterns(LLVMTypeConverter &converter,
|
||||
static void populateOpPatterns(const LLVMTypeConverter &converter,
|
||||
RewritePatternSet &patterns, StringRef f32Func,
|
||||
StringRef f64Func, StringRef f32ApproxFunc,
|
||||
StringRef f16Func) {
|
||||
@ -353,7 +353,7 @@ static void populateOpPatterns(LLVMTypeConverter &converter,
|
||||
}
|
||||
|
||||
void mlir::populateGpuToROCDLConversionPatterns(
|
||||
LLVMTypeConverter &converter, RewritePatternSet &patterns,
|
||||
const LLVMTypeConverter &converter, RewritePatternSet &patterns,
|
||||
mlir::gpu::amd::Runtime runtime) {
|
||||
using gpu::index_lowering::IndexKind;
|
||||
using gpu::index_lowering::IntrType;
|
||||
|
@ -59,7 +59,8 @@ public:
|
||||
/// attribute on the surrounding FuncOp is used to replace the gpu::BlockDimOp.
|
||||
class WorkGroupSizeConversion : public OpConversionPattern<gpu::BlockDimOp> {
|
||||
public:
|
||||
WorkGroupSizeConversion(TypeConverter &typeConverter, MLIRContext *context)
|
||||
WorkGroupSizeConversion(const TypeConverter &typeConverter,
|
||||
MLIRContext *context)
|
||||
: OpConversionPattern(typeConverter, context, /*benefit*/ 10) {}
|
||||
|
||||
LogicalResult
|
||||
@ -728,7 +729,7 @@ LogicalResult GPUPrintfConversion::matchAndRewrite(
|
||||
// GPU To SPIRV Patterns.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void mlir::populateGPUToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
|
||||
void mlir::populateGPUToSPIRVPatterns(const SPIRVTypeConverter &typeConverter,
|
||||
RewritePatternSet &patterns) {
|
||||
patterns.add<
|
||||
GPUBarrierConversion, GPUFuncOpConversion, GPUModuleConversion,
|
||||
|
@ -293,7 +293,7 @@ struct WmmaMmaOpToSPIRVLowering final
|
||||
} // namespace mlir
|
||||
|
||||
void mlir::populateGpuWMMAToSPIRVCoopMatrixKHRConversionPatterns(
|
||||
SPIRVTypeConverter &converter, RewritePatternSet &patterns) {
|
||||
const SPIRVTypeConverter &converter, RewritePatternSet &patterns) {
|
||||
using namespace mlir;
|
||||
MLIRContext *context = patterns.getContext();
|
||||
patterns.add<khr::WmmaLoadOpToSPIRVLowering, khr::WmmaMmaOpToSPIRVLowering,
|
||||
|
@ -292,7 +292,7 @@ using ConvertIndexBoolConstant =
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void index::populateIndexToLLVMConversionPatterns(
|
||||
LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) {
|
||||
const LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) {
|
||||
patterns.insert<
|
||||
// clang-format off
|
||||
ConvertIndexAdd,
|
||||
|
@ -338,8 +338,8 @@ struct ConvertIndexSizeOf final : OpConversionPattern<SizeOfOp> {
|
||||
// Pattern Population
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void index::populateIndexToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
|
||||
RewritePatternSet &patterns) {
|
||||
void index::populateIndexToSPIRVPatterns(
|
||||
const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns) {
|
||||
patterns.add<
|
||||
// clang-format off
|
||||
ConvertIndexAdd,
|
||||
|
@ -298,9 +298,9 @@ struct ConvertMathToLLVMPass
|
||||
};
|
||||
} // namespace
|
||||
|
||||
void mlir::populateMathToLLVMConversionPatterns(LLVMTypeConverter &converter,
|
||||
RewritePatternSet &patterns,
|
||||
bool approximateLog1p) {
|
||||
void mlir::populateMathToLLVMConversionPatterns(
|
||||
const LLVMTypeConverter &converter, RewritePatternSet &patterns,
|
||||
bool approximateLog1p) {
|
||||
if (approximateLog1p)
|
||||
patterns.add<Log1pOpLowering>(converter);
|
||||
// clang-format off
|
||||
|
@ -36,7 +36,7 @@ using namespace mlir;
|
||||
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
|
||||
|
||||
template <typename OpTy>
|
||||
static void populateOpPatterns(LLVMTypeConverter &converter,
|
||||
static void populateOpPatterns(const LLVMTypeConverter &converter,
|
||||
RewritePatternSet &patterns, StringRef f32Func,
|
||||
StringRef f64Func, StringRef f16Func,
|
||||
StringRef f32ApproxFunc = "") {
|
||||
@ -45,8 +45,8 @@ static void populateOpPatterns(LLVMTypeConverter &converter,
|
||||
f32ApproxFunc, f16Func);
|
||||
}
|
||||
|
||||
void mlir::populateMathToROCDLConversionPatterns(LLVMTypeConverter &converter,
|
||||
RewritePatternSet &patterns) {
|
||||
void mlir::populateMathToROCDLConversionPatterns(
|
||||
const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
|
||||
// Handled by mathToLLVM: math::AbsIOp
|
||||
// Handled by mathToLLVM: math::AbsFOp
|
||||
// Handled by mathToLLVM: math::CopySignOp
|
||||
|
@ -462,7 +462,7 @@ struct RoundOpPattern final : public OpConversionPattern<math::RoundOp> {
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
namespace mlir {
|
||||
void populateMathToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
|
||||
void populateMathToSPIRVPatterns(const SPIRVTypeConverter &typeConverter,
|
||||
RewritePatternSet &patterns) {
|
||||
// Core patterns
|
||||
patterns.add<CopySignPattern>(typeConverter, patterns.getContext());
|
||||
|
@ -179,8 +179,8 @@ void mlir::populateMemRefToEmitCTypeConversion(TypeConverter &typeConverter) {
|
||||
});
|
||||
}
|
||||
|
||||
void mlir::populateMemRefToEmitCConversionPatterns(RewritePatternSet &patterns,
|
||||
TypeConverter &converter) {
|
||||
void mlir::populateMemRefToEmitCConversionPatterns(
|
||||
RewritePatternSet &patterns, const TypeConverter &converter) {
|
||||
patterns.add<ConvertAlloca, ConvertGlobal, ConvertGetGlobal, ConvertLoad,
|
||||
ConvertStore>(converter, patterns.getContext());
|
||||
}
|
||||
|
@ -1667,7 +1667,7 @@ public:
|
||||
} // namespace
|
||||
|
||||
void mlir::populateFinalizeMemRefToLLVMConversionPatterns(
|
||||
LLVMTypeConverter &converter, RewritePatternSet &patterns) {
|
||||
const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
|
||||
// clang-format off
|
||||
patterns.add<
|
||||
AllocaOpLowering,
|
||||
|
@ -926,7 +926,7 @@ LogicalResult ReinterpretCastPattern::matchAndRewrite(
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
namespace mlir {
|
||||
void populateMemRefToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
|
||||
void populateMemRefToSPIRVPatterns(const SPIRVTypeConverter &typeConverter,
|
||||
RewritePatternSet &patterns) {
|
||||
patterns.add<AllocaOpPattern, AllocOpPattern, AtomicRMWOpPattern,
|
||||
DeallocOpPattern, IntLoadOpPattern, IntStoreOpPattern,
|
||||
|
@ -1701,8 +1701,8 @@ struct NVGPURcpOpLowering : public ConvertOpToLLVMPattern<nvgpu::RcpOp> {
|
||||
};
|
||||
} // namespace
|
||||
|
||||
void mlir::populateNVGPUToNVVMConversionPatterns(LLVMTypeConverter &converter,
|
||||
RewritePatternSet &patterns) {
|
||||
void mlir::populateNVGPUToNVVMConversionPatterns(
|
||||
const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
|
||||
patterns.add<
|
||||
NVGPUMBarrierCreateLowering, // nvgpu.mbarrier.create
|
||||
NVGPUMBarrierInitLowering, // nvgpu.mbarrier.init
|
||||
|
@ -221,7 +221,7 @@ void MultiRegionOpConversion<omp::PrivateClauseOp>::forwardOpAttrs(
|
||||
} // namespace
|
||||
|
||||
void mlir::configureOpenMPToLLVMConversionLegality(
|
||||
ConversionTarget &target, LLVMTypeConverter &typeConverter) {
|
||||
ConversionTarget &target, const LLVMTypeConverter &typeConverter) {
|
||||
target.addDynamicallyLegalOp<
|
||||
omp::AtomicReadOp, omp::AtomicWriteOp, omp::CancellationPointOp,
|
||||
omp::CancelOp, omp::CriticalDeclareOp, omp::FlushOp, omp::MapBoundsOp,
|
||||
|
@ -96,7 +96,7 @@ Region::iterator getBlockIt(Region ®ion, unsigned index) {
|
||||
template <typename OpTy>
|
||||
class SCFToSPIRVPattern : public OpConversionPattern<OpTy> {
|
||||
public:
|
||||
SCFToSPIRVPattern(MLIRContext *context, SPIRVTypeConverter &converter,
|
||||
SCFToSPIRVPattern(MLIRContext *context, const SPIRVTypeConverter &converter,
|
||||
ScfToSPIRVContextImpl *scfToSPIRVContext)
|
||||
: OpConversionPattern<OpTy>::OpConversionPattern(converter, context),
|
||||
scfToSPIRVContext(scfToSPIRVContext), typeConverter(converter) {}
|
||||
@ -117,7 +117,7 @@ protected:
|
||||
// conversion. There isn't a straightforward way to do that yet, as when
|
||||
// converting types, ops aren't taken into consideration. Therefore, we just
|
||||
// bypass the framework's type conversion for now.
|
||||
SPIRVTypeConverter &typeConverter;
|
||||
const SPIRVTypeConverter &typeConverter;
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
@ -436,7 +436,7 @@ struct WhileOpConversion final : SCFToSPIRVPattern<scf::WhileOp> {
|
||||
// Public API
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void mlir::populateSCFToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
|
||||
void mlir::populateSCFToSPIRVPatterns(const SPIRVTypeConverter &typeConverter,
|
||||
ScfToSPIRVContext &scfToSPIRVContext,
|
||||
RewritePatternSet &patterns) {
|
||||
patterns.add<ForOpConversion, IfOpConversion, TerminatorOpConversion,
|
||||
|
@ -149,7 +149,7 @@ static Value optionallyTruncateOrExtend(Location loc, Value value,
|
||||
|
||||
/// Broadcasts the value to vector with `numElements` number of elements.
|
||||
static Value broadcast(Location loc, Value toBroadcast, unsigned numElements,
|
||||
LLVMTypeConverter &typeConverter,
|
||||
const LLVMTypeConverter &typeConverter,
|
||||
ConversionPatternRewriter &rewriter) {
|
||||
auto vectorType = VectorType::get(numElements, toBroadcast.getType());
|
||||
auto llvmVectorType = typeConverter.convertType(vectorType);
|
||||
@ -166,7 +166,7 @@ static Value broadcast(Location loc, Value toBroadcast, unsigned numElements,
|
||||
|
||||
/// Broadcasts the value. If `srcType` is a scalar, the value remains unchanged.
|
||||
static Value optionallyBroadcast(Location loc, Value value, Type srcType,
|
||||
LLVMTypeConverter &typeConverter,
|
||||
const LLVMTypeConverter &typeConverter,
|
||||
ConversionPatternRewriter &rewriter) {
|
||||
if (auto vectorType = dyn_cast<VectorType>(srcType)) {
|
||||
unsigned numElements = vectorType.getNumElements();
|
||||
@ -186,7 +186,8 @@ static Value optionallyBroadcast(Location loc, Value value, Type srcType,
|
||||
/// Then cast `Offset` and `Count` if their bit width is different
|
||||
/// from `Base` bit width.
|
||||
static Value processCountOrOffset(Location loc, Value value, Type srcType,
|
||||
Type dstType, LLVMTypeConverter &converter,
|
||||
Type dstType,
|
||||
const LLVMTypeConverter &converter,
|
||||
ConversionPatternRewriter &rewriter) {
|
||||
Value broadcasted =
|
||||
optionallyBroadcast(loc, value, srcType, converter, rewriter);
|
||||
@ -196,7 +197,7 @@ static Value processCountOrOffset(Location loc, Value value, Type srcType,
|
||||
/// Converts SPIR-V struct with a regular (according to `VulkanLayoutUtils`)
|
||||
/// offset to LLVM struct. Otherwise, the conversion is not supported.
|
||||
static Type convertStructTypeWithOffset(spirv::StructType type,
|
||||
LLVMTypeConverter &converter) {
|
||||
const LLVMTypeConverter &converter) {
|
||||
if (type != VulkanLayoutUtils::decorateType(type))
|
||||
return nullptr;
|
||||
|
||||
@ -209,7 +210,7 @@ static Type convertStructTypeWithOffset(spirv::StructType type,
|
||||
|
||||
/// Converts SPIR-V struct with no offset to packed LLVM struct.
|
||||
static Type convertStructTypePacked(spirv::StructType type,
|
||||
LLVMTypeConverter &converter) {
|
||||
const LLVMTypeConverter &converter) {
|
||||
SmallVector<Type> elementsVector;
|
||||
if (failed(converter.convertTypes(type.getElementTypes(), elementsVector)))
|
||||
return nullptr;
|
||||
@ -226,11 +227,10 @@ static Value createI32ConstantOf(Location loc, PatternRewriter &rewriter,
|
||||
}
|
||||
|
||||
/// Utility for `spirv.Load` and `spirv.Store` conversion.
|
||||
static LogicalResult replaceWithLoadOrStore(Operation *op, ValueRange operands,
|
||||
ConversionPatternRewriter &rewriter,
|
||||
LLVMTypeConverter &typeConverter,
|
||||
unsigned alignment, bool isVolatile,
|
||||
bool isNonTemporal) {
|
||||
static LogicalResult replaceWithLoadOrStore(
|
||||
Operation *op, ValueRange operands, ConversionPatternRewriter &rewriter,
|
||||
const LLVMTypeConverter &typeConverter, unsigned alignment, bool isVolatile,
|
||||
bool isNonTemporal) {
|
||||
if (auto loadOp = dyn_cast<spirv::LoadOp>(op)) {
|
||||
auto dstType = typeConverter.convertType(loadOp.getType());
|
||||
if (!dstType)
|
||||
@ -271,7 +271,7 @@ static std::optional<Type> convertArrayType(spirv::ArrayType type,
|
||||
/// Converts SPIR-V pointer type to LLVM pointer. Pointer's storage class is not
|
||||
/// modelled at the moment.
|
||||
static Type convertPointerType(spirv::PointerType type,
|
||||
LLVMTypeConverter &converter,
|
||||
const LLVMTypeConverter &converter,
|
||||
spirv::ClientAPI clientAPI) {
|
||||
unsigned addressSpace =
|
||||
storageClassToAddressSpace(clientAPI, type.getStorageClass());
|
||||
@ -292,7 +292,7 @@ static std::optional<Type> convertRuntimeArrayType(spirv::RuntimeArrayType type,
|
||||
/// Converts SPIR-V struct to LLVM struct. There is no support of structs with
|
||||
/// member decorations. Also, only natural offset is supported.
|
||||
static Type convertStructType(spirv::StructType type,
|
||||
LLVMTypeConverter &converter) {
|
||||
const LLVMTypeConverter &converter) {
|
||||
SmallVector<spirv::StructType::MemberDecorationInfo, 4> memberDecorations;
|
||||
type.getMemberDecorations(memberDecorations);
|
||||
if (!memberDecorations.empty())
|
||||
@ -315,20 +315,21 @@ public:
|
||||
LogicalResult
|
||||
matchAndRewrite(spirv::AccessChainOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto dstType = typeConverter.convertType(op.getComponentPtr().getType());
|
||||
auto dstType =
|
||||
getTypeConverter()->convertType(op.getComponentPtr().getType());
|
||||
if (!dstType)
|
||||
return rewriter.notifyMatchFailure(op, "type conversion failed");
|
||||
// To use GEP we need to add a first 0 index to go through the pointer.
|
||||
auto indices = llvm::to_vector<4>(adaptor.getIndices());
|
||||
Type indexType = op.getIndices().front().getType();
|
||||
auto llvmIndexType = typeConverter.convertType(indexType);
|
||||
auto llvmIndexType = getTypeConverter()->convertType(indexType);
|
||||
if (!llvmIndexType)
|
||||
return rewriter.notifyMatchFailure(op, "type conversion failed");
|
||||
Value zero = rewriter.create<LLVM::ConstantOp>(
|
||||
op.getLoc(), llvmIndexType, rewriter.getIntegerAttr(indexType, 0));
|
||||
indices.insert(indices.begin(), zero);
|
||||
|
||||
auto elementType = typeConverter.convertType(
|
||||
auto elementType = getTypeConverter()->convertType(
|
||||
cast<spirv::PointerType>(op.getBasePtr().getType()).getPointeeType());
|
||||
if (!elementType)
|
||||
return rewriter.notifyMatchFailure(op, "type conversion failed");
|
||||
@ -345,7 +346,7 @@ public:
|
||||
LogicalResult
|
||||
matchAndRewrite(spirv::AddressOfOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto dstType = typeConverter.convertType(op.getPointer().getType());
|
||||
auto dstType = getTypeConverter()->convertType(op.getPointer().getType());
|
||||
if (!dstType)
|
||||
return rewriter.notifyMatchFailure(op, "type conversion failed");
|
||||
rewriter.replaceOpWithNewOp<LLVM::AddressOfOp>(op, dstType,
|
||||
@ -363,16 +364,16 @@ public:
|
||||
matchAndRewrite(spirv::BitFieldInsertOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto srcType = op.getType();
|
||||
auto dstType = typeConverter.convertType(srcType);
|
||||
auto dstType = getTypeConverter()->convertType(srcType);
|
||||
if (!dstType)
|
||||
return rewriter.notifyMatchFailure(op, "type conversion failed");
|
||||
Location loc = op.getLoc();
|
||||
|
||||
// Process `Offset` and `Count`: broadcast and extend/truncate if needed.
|
||||
Value offset = processCountOrOffset(loc, op.getOffset(), srcType, dstType,
|
||||
typeConverter, rewriter);
|
||||
*getTypeConverter(), rewriter);
|
||||
Value count = processCountOrOffset(loc, op.getCount(), srcType, dstType,
|
||||
typeConverter, rewriter);
|
||||
*getTypeConverter(), rewriter);
|
||||
|
||||
// Create a mask with bits set outside [Offset, Offset + Count - 1].
|
||||
Value minusOne = createConstantAllBitsSet(loc, srcType, dstType, rewriter);
|
||||
@ -410,7 +411,7 @@ public:
|
||||
if (!isa<VectorType>(srcType) && !srcType.isIntOrFloat())
|
||||
return failure();
|
||||
|
||||
auto dstType = typeConverter.convertType(srcType);
|
||||
auto dstType = getTypeConverter()->convertType(srcType);
|
||||
if (!dstType)
|
||||
return rewriter.notifyMatchFailure(constOp, "type conversion failed");
|
||||
|
||||
@ -451,16 +452,16 @@ public:
|
||||
matchAndRewrite(spirv::BitFieldSExtractOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto srcType = op.getType();
|
||||
auto dstType = typeConverter.convertType(srcType);
|
||||
auto dstType = getTypeConverter()->convertType(srcType);
|
||||
if (!dstType)
|
||||
return rewriter.notifyMatchFailure(op, "type conversion failed");
|
||||
Location loc = op.getLoc();
|
||||
|
||||
// Process `Offset` and `Count`: broadcast and extend/truncate if needed.
|
||||
Value offset = processCountOrOffset(loc, op.getOffset(), srcType, dstType,
|
||||
typeConverter, rewriter);
|
||||
*getTypeConverter(), rewriter);
|
||||
Value count = processCountOrOffset(loc, op.getCount(), srcType, dstType,
|
||||
typeConverter, rewriter);
|
||||
*getTypeConverter(), rewriter);
|
||||
|
||||
// Create a constant that holds the size of the `Base`.
|
||||
IntegerType integerType;
|
||||
@ -504,16 +505,16 @@ public:
|
||||
matchAndRewrite(spirv::BitFieldUExtractOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto srcType = op.getType();
|
||||
auto dstType = typeConverter.convertType(srcType);
|
||||
auto dstType = getTypeConverter()->convertType(srcType);
|
||||
if (!dstType)
|
||||
return rewriter.notifyMatchFailure(op, "type conversion failed");
|
||||
Location loc = op.getLoc();
|
||||
|
||||
// Process `Offset` and `Count`: broadcast and extend/truncate if needed.
|
||||
Value offset = processCountOrOffset(loc, op.getOffset(), srcType, dstType,
|
||||
typeConverter, rewriter);
|
||||
*getTypeConverter(), rewriter);
|
||||
Value count = processCountOrOffset(loc, op.getCount(), srcType, dstType,
|
||||
typeConverter, rewriter);
|
||||
*getTypeConverter(), rewriter);
|
||||
|
||||
// Create a mask with bits set at [0, Count - 1].
|
||||
Value minusOne = createConstantAllBitsSet(loc, srcType, dstType, rewriter);
|
||||
@ -580,7 +581,7 @@ public:
|
||||
LogicalResult
|
||||
matchAndRewrite(spirv::CompositeExtractOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto dstType = this->typeConverter.convertType(op.getType());
|
||||
auto dstType = this->getTypeConverter()->convertType(op.getType());
|
||||
if (!dstType)
|
||||
return rewriter.notifyMatchFailure(op, "type conversion failed");
|
||||
|
||||
@ -612,7 +613,7 @@ public:
|
||||
LogicalResult
|
||||
matchAndRewrite(spirv::CompositeInsertOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto dstType = this->typeConverter.convertType(op.getType());
|
||||
auto dstType = this->getTypeConverter()->convertType(op.getType());
|
||||
if (!dstType)
|
||||
return rewriter.notifyMatchFailure(op, "type conversion failed");
|
||||
|
||||
@ -643,7 +644,7 @@ public:
|
||||
LogicalResult
|
||||
matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto dstType = this->typeConverter.convertType(op.getType());
|
||||
auto dstType = this->getTypeConverter()->convertType(op.getType());
|
||||
if (!dstType)
|
||||
return rewriter.notifyMatchFailure(op, "type conversion failed");
|
||||
rewriter.template replaceOpWithNewOp<LLVMOp>(
|
||||
@ -749,7 +750,7 @@ public:
|
||||
return failure();
|
||||
|
||||
auto srcType = cast<spirv::PointerType>(op.getType());
|
||||
auto dstType = typeConverter.convertType(srcType.getPointeeType());
|
||||
auto dstType = getTypeConverter()->convertType(srcType.getPointeeType());
|
||||
if (!dstType)
|
||||
return rewriter.notifyMatchFailure(op, "type conversion failed");
|
||||
|
||||
@ -810,7 +811,7 @@ public:
|
||||
Type fromType = op.getOperand().getType();
|
||||
Type toType = op.getType();
|
||||
|
||||
auto dstType = this->typeConverter.convertType(toType);
|
||||
auto dstType = this->getTypeConverter()->convertType(toType);
|
||||
if (!dstType)
|
||||
return rewriter.notifyMatchFailure(op, "type conversion failed");
|
||||
|
||||
@ -846,7 +847,7 @@ public:
|
||||
}
|
||||
|
||||
// Function returns a single result.
|
||||
auto dstType = typeConverter.convertType(callOp.getType(0));
|
||||
auto dstType = getTypeConverter()->convertType(callOp.getType(0));
|
||||
if (!dstType)
|
||||
return rewriter.notifyMatchFailure(callOp, "type conversion failed");
|
||||
auto newOp = rewriter.replaceOpWithNewOp<LLVM::CallOp>(
|
||||
@ -868,7 +869,7 @@ public:
|
||||
matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
|
||||
auto dstType = this->typeConverter.convertType(op.getType());
|
||||
auto dstType = this->getTypeConverter()->convertType(op.getType());
|
||||
if (!dstType)
|
||||
return rewriter.notifyMatchFailure(op, "type conversion failed");
|
||||
|
||||
@ -888,7 +889,7 @@ public:
|
||||
matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
|
||||
auto dstType = this->typeConverter.convertType(op.getType());
|
||||
auto dstType = this->getTypeConverter()->convertType(op.getType());
|
||||
if (!dstType)
|
||||
return rewriter.notifyMatchFailure(op, "type conversion failed");
|
||||
|
||||
@ -907,7 +908,7 @@ public:
|
||||
matchAndRewrite(spirv::GLInverseSqrtOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto srcType = op.getType();
|
||||
auto dstType = typeConverter.convertType(srcType);
|
||||
auto dstType = getTypeConverter()->convertType(srcType);
|
||||
if (!dstType)
|
||||
return rewriter.notifyMatchFailure(op, "type conversion failed");
|
||||
|
||||
@ -930,7 +931,7 @@ public:
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
if (!op.getMemoryAccess()) {
|
||||
return replaceWithLoadOrStore(op, adaptor.getOperands(), rewriter,
|
||||
this->typeConverter, /*alignment=*/0,
|
||||
*this->getTypeConverter(), /*alignment=*/0,
|
||||
/*isVolatile=*/false,
|
||||
/*isNonTemporal=*/false);
|
||||
}
|
||||
@ -945,8 +946,8 @@ public:
|
||||
bool isNonTemporal = memoryAccess == spirv::MemoryAccess::Nontemporal;
|
||||
bool isVolatile = memoryAccess == spirv::MemoryAccess::Volatile;
|
||||
return replaceWithLoadOrStore(op, adaptor.getOperands(), rewriter,
|
||||
this->typeConverter, alignment, isVolatile,
|
||||
isNonTemporal);
|
||||
*this->getTypeConverter(), alignment,
|
||||
isVolatile, isNonTemporal);
|
||||
}
|
||||
default:
|
||||
// There is no support of other memory access attributes.
|
||||
@ -965,7 +966,7 @@ public:
|
||||
matchAndRewrite(SPIRVOp notOp, typename SPIRVOp::Adaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto srcType = notOp.getType();
|
||||
auto dstType = this->typeConverter.convertType(srcType);
|
||||
auto dstType = this->getTypeConverter()->convertType(srcType);
|
||||
if (!dstType)
|
||||
return rewriter.notifyMatchFailure(notOp, "type conversion failed");
|
||||
|
||||
@ -1196,7 +1197,7 @@ public:
|
||||
matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
|
||||
auto dstType = this->typeConverter.convertType(op.getType());
|
||||
auto dstType = this->getTypeConverter()->convertType(op.getType());
|
||||
if (!dstType)
|
||||
return rewriter.notifyMatchFailure(op, "type conversion failed");
|
||||
|
||||
@ -1247,7 +1248,7 @@ public:
|
||||
LogicalResult
|
||||
matchAndRewrite(spirv::GLTanOp tanOp, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto dstType = typeConverter.convertType(tanOp.getType());
|
||||
auto dstType = getTypeConverter()->convertType(tanOp.getType());
|
||||
if (!dstType)
|
||||
return rewriter.notifyMatchFailure(tanOp, "type conversion failed");
|
||||
|
||||
@ -1273,7 +1274,7 @@ public:
|
||||
matchAndRewrite(spirv::GLTanhOp tanhOp, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto srcType = tanhOp.getType();
|
||||
auto dstType = typeConverter.convertType(srcType);
|
||||
auto dstType = getTypeConverter()->convertType(srcType);
|
||||
if (!dstType)
|
||||
return rewriter.notifyMatchFailure(tanhOp, "type conversion failed");
|
||||
|
||||
@ -1307,21 +1308,21 @@ public:
|
||||
if (init && !pointerTo.isIntOrFloat() && !isa<VectorType>(pointerTo))
|
||||
return failure();
|
||||
|
||||
auto dstType = typeConverter.convertType(srcType);
|
||||
auto dstType = getTypeConverter()->convertType(srcType);
|
||||
if (!dstType)
|
||||
return rewriter.notifyMatchFailure(varOp, "type conversion failed");
|
||||
|
||||
Location loc = varOp.getLoc();
|
||||
Value size = createI32ConstantOf(loc, rewriter, 1);
|
||||
if (!init) {
|
||||
auto elementType = typeConverter.convertType(pointerTo);
|
||||
auto elementType = getTypeConverter()->convertType(pointerTo);
|
||||
if (!elementType)
|
||||
return rewriter.notifyMatchFailure(varOp, "type conversion failed");
|
||||
rewriter.replaceOpWithNewOp<LLVM::AllocaOp>(varOp, dstType, elementType,
|
||||
size);
|
||||
return success();
|
||||
}
|
||||
auto elementType = typeConverter.convertType(pointerTo);
|
||||
auto elementType = getTypeConverter()->convertType(pointerTo);
|
||||
if (!elementType)
|
||||
return rewriter.notifyMatchFailure(varOp, "type conversion failed");
|
||||
Value allocated =
|
||||
@ -1344,7 +1345,7 @@ public:
|
||||
LogicalResult
|
||||
matchAndRewrite(spirv::BitcastOp bitcastOp, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto dstType = typeConverter.convertType(bitcastOp.getType());
|
||||
auto dstType = getTypeConverter()->convertType(bitcastOp.getType());
|
||||
if (!dstType)
|
||||
return rewriter.notifyMatchFailure(bitcastOp, "type conversion failed");
|
||||
|
||||
@ -1377,7 +1378,7 @@ public:
|
||||
auto funcType = funcOp.getFunctionType();
|
||||
TypeConverter::SignatureConversion signatureConverter(
|
||||
funcType.getNumInputs());
|
||||
auto llvmType = typeConverter.convertFunctionSignature(
|
||||
auto llvmType = getTypeConverter()->convertFunctionSignature(
|
||||
funcType, /*isVariadic=*/false, /*useBarePtrCallConv=*/false,
|
||||
signatureConverter);
|
||||
if (!llvmType)
|
||||
@ -1418,8 +1419,8 @@ public:
|
||||
|
||||
rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
|
||||
newFuncOp.end());
|
||||
if (failed(rewriter.convertRegionTypes(&newFuncOp.getBody(), typeConverter,
|
||||
&signatureConverter))) {
|
||||
if (failed(rewriter.convertRegionTypes(
|
||||
&newFuncOp.getBody(), *getTypeConverter(), &signatureConverter))) {
|
||||
return failure();
|
||||
}
|
||||
rewriter.eraseOp(funcOp);
|
||||
@ -1474,7 +1475,7 @@ public:
|
||||
return success();
|
||||
}
|
||||
|
||||
auto dstType = typeConverter.convertType(op.getType());
|
||||
auto dstType = getTypeConverter()->convertType(op.getType());
|
||||
if (!dstType)
|
||||
return rewriter.notifyMatchFailure(op, "type conversion failed");
|
||||
auto scalarType = cast<VectorType>(dstType).getElementType();
|
||||
@ -1535,7 +1536,7 @@ void mlir::populateSPIRVToLLVMTypeConversion(LLVMTypeConverter &typeConverter,
|
||||
}
|
||||
|
||||
void mlir::populateSPIRVToLLVMConversionPatterns(
|
||||
LLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
|
||||
const LLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
|
||||
spirv::ClientAPI clientAPI) {
|
||||
patterns.add<
|
||||
// Arithmetic ops
|
||||
@ -1653,12 +1654,12 @@ void mlir::populateSPIRVToLLVMConversionPatterns(
|
||||
}
|
||||
|
||||
void mlir::populateSPIRVToLLVMFunctionConversionPatterns(
|
||||
LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) {
|
||||
const LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) {
|
||||
patterns.add<FuncConversionPattern>(patterns.getContext(), typeConverter);
|
||||
}
|
||||
|
||||
void mlir::populateSPIRVToLLVMModuleConversionPatterns(
|
||||
LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) {
|
||||
const LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) {
|
||||
patterns.add<ModuleConversionPattern>(patterns.getContext(), typeConverter);
|
||||
}
|
||||
|
||||
|
@ -35,7 +35,7 @@ namespace {
|
||||
class TensorExtractPattern final
|
||||
: public OpConversionPattern<tensor::ExtractOp> {
|
||||
public:
|
||||
TensorExtractPattern(TypeConverter &typeConverter, MLIRContext *context,
|
||||
TensorExtractPattern(const TypeConverter &typeConverter, MLIRContext *context,
|
||||
int64_t threshold, PatternBenefit benefit = 1)
|
||||
: OpConversionPattern(typeConverter, context, benefit),
|
||||
byteCountThreshold(threshold) {}
|
||||
@ -103,9 +103,9 @@ private:
|
||||
// Pattern population
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void mlir::populateTensorToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
|
||||
int64_t byteCountThreshold,
|
||||
RewritePatternSet &patterns) {
|
||||
void mlir::populateTensorToSPIRVPatterns(
|
||||
const SPIRVTypeConverter &typeConverter, int64_t byteCountThreshold,
|
||||
RewritePatternSet &patterns) {
|
||||
patterns.add<TensorExtractPattern>(typeConverter, patterns.getContext(),
|
||||
byteCountThreshold);
|
||||
}
|
||||
|
@ -2588,7 +2588,7 @@ struct FFT2dConverter final : OpRewritePattern<FFT2dOp> {
|
||||
} // namespace
|
||||
|
||||
void mlir::tosa::populateTosaToLinalgConversionPatterns(
|
||||
TypeConverter &converter, RewritePatternSet *patterns) {
|
||||
const TypeConverter &converter, RewritePatternSet *patterns) {
|
||||
|
||||
// We have multiple resize coverters to handle degenerate cases.
|
||||
patterns->add<GenericResizeConverter>(patterns->getContext(),
|
||||
|
@ -438,7 +438,7 @@ struct ConcatConverter : public OpConversionPattern<tosa::ConcatOp> {
|
||||
} // namespace
|
||||
|
||||
void mlir::tosa::populateTosaToTensorConversionPatterns(
|
||||
TypeConverter &converter, RewritePatternSet *patterns) {
|
||||
const TypeConverter &converter, RewritePatternSet *patterns) {
|
||||
patterns
|
||||
->add<ConcatConverter, PadConverter, ReshapeConverter, SliceConverter>(
|
||||
converter, patterns->getContext());
|
||||
|
@ -91,8 +91,8 @@ struct UBToLLVMConversionPass
|
||||
// Pattern Population
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void mlir::ub::populateUBToLLVMConversionPatterns(LLVMTypeConverter &converter,
|
||||
RewritePatternSet &patterns) {
|
||||
void mlir::ub::populateUBToLLVMConversionPatterns(
|
||||
const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
|
||||
patterns.add<PoisonOpLowering>(converter);
|
||||
}
|
||||
|
||||
|
@ -79,6 +79,6 @@ struct UBToSPIRVConversionPass final
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void mlir::ub::populateUBToSPIRVConversionPatterns(
|
||||
SPIRVTypeConverter &converter, RewritePatternSet &patterns) {
|
||||
const SPIRVTypeConverter &converter, RewritePatternSet &patterns) {
|
||||
patterns.add<PoisonOpLowering>(converter, patterns.getContext());
|
||||
}
|
||||
|
@ -1881,7 +1881,7 @@ struct VectorStepOpLowering : public ConvertOpToLLVMPattern<vector::StepOp> {
|
||||
|
||||
/// Populate the given list with patterns that convert from Vector to LLVM.
|
||||
void mlir::populateVectorToLLVMConversionPatterns(
|
||||
LLVMTypeConverter &converter, RewritePatternSet &patterns,
|
||||
const LLVMTypeConverter &converter, RewritePatternSet &patterns,
|
||||
bool reassociateFPReductions, bool force32BitVectorIndices) {
|
||||
MLIRContext *ctx = converter.getDialect()->getContext();
|
||||
patterns.add<VectorFMAOpNDRewritePattern>(ctx);
|
||||
@ -1909,7 +1909,7 @@ void mlir::populateVectorToLLVMConversionPatterns(
|
||||
}
|
||||
|
||||
void mlir::populateVectorToLLVMMatrixConversionPatterns(
|
||||
LLVMTypeConverter &converter, RewritePatternSet &patterns) {
|
||||
const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
|
||||
patterns.add<VectorMatmulOpConversion>(converter);
|
||||
patterns.add<VectorFlatTransposeOpConversion>(converter);
|
||||
}
|
||||
|
@ -950,8 +950,8 @@ struct VectorStepOpConvert final : OpConversionPattern<vector::StepOp> {
|
||||
#define CL_FLOAT_MAX_MIN_OPS spirv::CLFMaxOp, spirv::CLFMinOp
|
||||
#define GL_FLOAT_MAX_MIN_OPS spirv::GLFMaxOp, spirv::GLFMinOp
|
||||
|
||||
void mlir::populateVectorToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
|
||||
RewritePatternSet &patterns) {
|
||||
void mlir::populateVectorToSPIRVPatterns(
|
||||
const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns) {
|
||||
patterns.add<
|
||||
VectorBitcastConvert, VectorBroadcastConvert,
|
||||
VectorExtractElementOpConvert, VectorExtractOpConvert,
|
||||
|
@ -203,7 +203,7 @@ struct TileMulIConversion : public ConvertOpToLLVMPattern<TileMulIOp> {
|
||||
} // namespace
|
||||
|
||||
void mlir::populateAMXLegalizeForLLVMExportPatterns(
|
||||
LLVMTypeConverter &converter, RewritePatternSet &patterns) {
|
||||
const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
|
||||
patterns.add<TileZeroConversion, TileLoadConversion, TileStoreConversion,
|
||||
TileMulFConversion, TileMulIConversion>(converter);
|
||||
}
|
||||
|
@ -51,7 +51,8 @@ arith::NarrowTypeEmulationConverter::NarrowTypeEmulationConverter(
|
||||
}
|
||||
|
||||
void arith::populateArithNarrowTypeEmulationPatterns(
|
||||
NarrowTypeEmulationConverter &typeConverter, RewritePatternSet &patterns) {
|
||||
const NarrowTypeEmulationConverter &typeConverter,
|
||||
RewritePatternSet &patterns) {
|
||||
// Populate `func.*` conversion patterns.
|
||||
populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
|
||||
typeConverter);
|
||||
|
@ -41,7 +41,7 @@ struct EmulateUnsupportedFloatsPass
|
||||
};
|
||||
|
||||
struct EmulateFloatPattern final : ConversionPattern {
|
||||
EmulateFloatPattern(TypeConverter &converter, MLIRContext *ctx)
|
||||
EmulateFloatPattern(const TypeConverter &converter, MLIRContext *ctx)
|
||||
: ConversionPattern(converter, Pattern::MatchAnyOpTypeTag(), 1, ctx) {}
|
||||
|
||||
LogicalResult match(Operation *op) const override;
|
||||
@ -106,12 +106,12 @@ void mlir::arith::populateEmulateUnsupportedFloatsConversions(
|
||||
}
|
||||
|
||||
void mlir::arith::populateEmulateUnsupportedFloatsPatterns(
|
||||
RewritePatternSet &patterns, TypeConverter &converter) {
|
||||
RewritePatternSet &patterns, const TypeConverter &converter) {
|
||||
patterns.add<EmulateFloatPattern>(converter, patterns.getContext());
|
||||
}
|
||||
|
||||
void mlir::arith::populateEmulateUnsupportedFloatsLegality(
|
||||
ConversionTarget &target, TypeConverter &converter) {
|
||||
ConversionTarget &target, const TypeConverter &converter) {
|
||||
// Don't try to legalize functions and other ops that don't need expansion.
|
||||
target.markUnknownOpDynamicallyLegal([](Operation *op) { return true; });
|
||||
target.addDynamicallyLegalDialect<arith::ArithDialect>(
|
||||
|
@ -1122,7 +1122,8 @@ arith::WideIntEmulationConverter::WideIntEmulationConverter(
|
||||
}
|
||||
|
||||
void arith::populateArithWideIntEmulationPatterns(
|
||||
WideIntEmulationConverter &typeConverter, RewritePatternSet &patterns) {
|
||||
const WideIntEmulationConverter &typeConverter,
|
||||
RewritePatternSet &patterns) {
|
||||
// Populate `func.*` conversion patterns.
|
||||
populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
|
||||
typeConverter);
|
||||
|
@ -200,7 +200,7 @@ struct CreateMaskOpLowering
|
||||
|
||||
/// Populate the given list with patterns that convert from ArmSVE to LLVM.
|
||||
void mlir::populateArmSVELegalizeForLLVMExportPatterns(
|
||||
LLVMTypeConverter &converter, RewritePatternSet &patterns) {
|
||||
const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
|
||||
// Populate conversion patterns
|
||||
|
||||
// clang-format off
|
||||
|
@ -130,7 +130,7 @@ public:
|
||||
} // namespace
|
||||
|
||||
void mlir::bufferization::populateEliminateBufferizeMaterializationsPatterns(
|
||||
BufferizeTypeConverter &typeConverter, RewritePatternSet &patterns) {
|
||||
const BufferizeTypeConverter &typeConverter, RewritePatternSet &patterns) {
|
||||
patterns.add<BufferizeToTensorOp, BufferizeToMemrefOp>(typeConverter,
|
||||
patterns.getContext());
|
||||
}
|
||||
|
@ -37,7 +37,7 @@ template <typename SourceOp>
|
||||
class DecomposeCallGraphTypesOpConversionPattern
|
||||
: public OpConversionPattern<SourceOp> {
|
||||
public:
|
||||
DecomposeCallGraphTypesOpConversionPattern(TypeConverter &typeConverter,
|
||||
DecomposeCallGraphTypesOpConversionPattern(const TypeConverter &typeConverter,
|
||||
MLIRContext *context,
|
||||
ValueDecomposer &decomposer,
|
||||
PatternBenefit benefit = 1)
|
||||
@ -188,7 +188,7 @@ struct DecomposeCallGraphTypesForCallOp
|
||||
} // namespace
|
||||
|
||||
void mlir::populateDecomposeCallGraphTypesPatterns(
|
||||
MLIRContext *context, TypeConverter &typeConverter,
|
||||
MLIRContext *context, const TypeConverter &typeConverter,
|
||||
ValueDecomposer &decomposer, RewritePatternSet &patterns) {
|
||||
patterns
|
||||
.add<DecomposeCallGraphTypesForCallOp, DecomposeCallGraphTypesForFuncArgs,
|
||||
|
@ -44,7 +44,7 @@ struct CallOpSignatureConversion : public OpConversionPattern<CallOp> {
|
||||
} // namespace
|
||||
|
||||
void mlir::populateCallOpTypeConversionPattern(RewritePatternSet &patterns,
|
||||
TypeConverter &converter) {
|
||||
const TypeConverter &converter) {
|
||||
patterns.add<CallOpSignatureConversion>(converter, patterns.getContext());
|
||||
}
|
||||
|
||||
@ -59,7 +59,7 @@ public:
|
||||
BranchOpInterface>::OpInterfaceConversionPattern;
|
||||
|
||||
BranchOpInterfaceTypeConversion(
|
||||
TypeConverter &typeConverter, MLIRContext *ctx,
|
||||
const TypeConverter &typeConverter, MLIRContext *ctx,
|
||||
function_ref<bool(BranchOpInterface, int)> shouldConvertBranchOperand)
|
||||
: OpInterfaceConversionPattern(typeConverter, ctx, /*benefit=*/1),
|
||||
shouldConvertBranchOperand(shouldConvertBranchOperand) {}
|
||||
@ -115,14 +115,14 @@ public:
|
||||
} // namespace
|
||||
|
||||
void mlir::populateBranchOpInterfaceTypeConversionPattern(
|
||||
RewritePatternSet &patterns, TypeConverter &typeConverter,
|
||||
RewritePatternSet &patterns, const TypeConverter &typeConverter,
|
||||
function_ref<bool(BranchOpInterface, int)> shouldConvertBranchOperand) {
|
||||
patterns.add<BranchOpInterfaceTypeConversion>(
|
||||
typeConverter, patterns.getContext(), shouldConvertBranchOperand);
|
||||
}
|
||||
|
||||
bool mlir::isLegalForBranchOpInterfaceTypeConversionPattern(
|
||||
Operation *op, TypeConverter &converter) {
|
||||
Operation *op, const TypeConverter &converter) {
|
||||
// All successor operands of branch like operations must be rewritten.
|
||||
if (auto branchOp = dyn_cast<BranchOpInterface>(op)) {
|
||||
for (int p = 0, e = op->getBlock()->getNumSuccessors(); p < e; ++p) {
|
||||
@ -137,14 +137,13 @@ bool mlir::isLegalForBranchOpInterfaceTypeConversionPattern(
|
||||
return false;
|
||||
}
|
||||
|
||||
void mlir::populateReturnOpTypeConversionPattern(RewritePatternSet &patterns,
|
||||
TypeConverter &typeConverter) {
|
||||
void mlir::populateReturnOpTypeConversionPattern(
|
||||
RewritePatternSet &patterns, const TypeConverter &typeConverter) {
|
||||
patterns.add<ReturnOpTypeConversion>(typeConverter, patterns.getContext());
|
||||
}
|
||||
|
||||
bool mlir::isLegalForReturnOpTypeConversionPattern(Operation *op,
|
||||
TypeConverter &converter,
|
||||
bool returnOpAlwaysLegal) {
|
||||
bool mlir::isLegalForReturnOpTypeConversionPattern(
|
||||
Operation *op, const TypeConverter &converter, bool returnOpAlwaysLegal) {
|
||||
// If this is a `return` and the user pass wants to convert/transform across
|
||||
// function boundaries, then `converter` is invoked to check whether the
|
||||
// `return` op is legal.
|
||||
|
@ -72,7 +72,7 @@ public:
|
||||
|
||||
namespace mlir {
|
||||
|
||||
void populateFuncTypeConversionPatterns(TypeConverter &typeConverter,
|
||||
void populateFuncTypeConversionPatterns(const TypeConverter &typeConverter,
|
||||
RewritePatternSet &patterns) {
|
||||
patterns.add<
|
||||
// clang-format off
|
||||
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
x
Reference in New Issue
Block a user