[mlir][NFC] Mark type converter in populate... functions as const (#111250)

This commit marks the type converter in `populate...` functions as
`const`. This is useful for debugging.

Patterns already take a `const` type converter. However, some
`populate...` functions not only add new patterns but also register
additional type conversion rules, which makes it difficult to find the
place in the code base where a type conversion was added. With this
change, all `populate...` functions that only populate patterns take a
`const` type converter. Programmers can then conclude from the
function signature that these functions do not register any new type
conversion rules.
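
As a rough sketch of the resulting convention (the dialect, pattern, and
type names below are made up for illustration and are not part of this
change): a `populate...` function that only adds patterns can take the
converter by `const` reference, whereas a function that registers
conversion rules still needs a mutable converter.

void populateFooToLLVMConversionPatterns(const LLVMTypeConverter &converter,
                                         RewritePatternSet &patterns) {
  // Only instantiates patterns; cannot register new type conversion rules.
  patterns.add<FooOpLowering>(converter);
}

void populateFooToLLVMTypeConversions(LLVMTypeConverter &converter) {
  // Registers a type conversion rule, so a non-const converter is required.
  converter.addConversion(
      [](FooType type) { return IntegerType::get(type.getContext(), 64); });
}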

This commit also includes some minor cleanups around the 1:N dialect
conversion infrastructure, which did not always pass the type converter
as a `const` object internally.
Matthias Springer 2024-10-05 21:32:40 +02:00 committed by GitHub
parent 73683cc1ab
commit 206fad0e21
115 changed files with 293 additions and 282 deletions


@ -72,8 +72,8 @@ std::unique_ptr<mlir::Pass> createLLVMDialectToLLVMPass(
[](llvm::Module &m, llvm::raw_ostream &out) { m.print(out, nullptr); });
/// Populate the given list with patterns that convert from FIR to LLVM.
void populateFIRToLLVMConversionPatterns(fir::LLVMTypeConverter &converter,
mlir::RewritePatternSet &patterns,
void populateFIRToLLVMConversionPatterns(
const fir::LLVMTypeConverter &converter, mlir::RewritePatternSet &patterns,
fir::FIRToLLVMPassOptions &options);
/// Populate the pattern set with the PreCGRewrite patterns.


@ -19,7 +19,7 @@ class LLVMTypeConverter;
/// dialect, utilised in cases where the default OpenMP dialect handling cannot
/// handle all cases for intermingled fir types and operations.
void populateOpenMPFIRToLLVMConversionPatterns(
LLVMTypeConverter &converter, mlir::RewritePatternSet &patterns);
const LLVMTypeConverter &converter, mlir::RewritePatternSet &patterns);
} // namespace fir


@ -22,7 +22,7 @@ class DataLayout;
namespace cuf {
void populateCUFToFIRConversionPatterns(fir::LLVMTypeConverter &converter,
void populateCUFToFIRConversionPatterns(const fir::LLVMTypeConverter &converter,
mlir::DataLayout &dl,
mlir::RewritePatternSet &patterns);


@ -3823,7 +3823,7 @@ fir::createLLVMDialectToLLVMPass(llvm::raw_ostream &output,
}
void fir::populateFIRToLLVMConversionPatterns(
fir::LLVMTypeConverter &converter, mlir::RewritePatternSet &patterns,
const fir::LLVMTypeConverter &converter, mlir::RewritePatternSet &patterns,
fir::FIRToLLVMPassOptions &options) {
patterns.insert<
AbsentOpConversion, AddcOpConversion, AddrOfOpConversion,


@ -93,6 +93,6 @@ struct MapInfoOpConversion
} // namespace
void fir::populateOpenMPFIRToLLVMConversionPatterns(
LLVMTypeConverter &converter, mlir::RewritePatternSet &patterns) {
const LLVMTypeConverter &converter, mlir::RewritePatternSet &patterns) {
patterns.add<MapInfoOpConversion>(converter);
}


@ -222,7 +222,7 @@ struct CufAllocOpConversion : public mlir::OpRewritePattern<cuf::AllocOp> {
using OpRewritePattern::OpRewritePattern;
CufAllocOpConversion(mlir::MLIRContext *context, mlir::DataLayout *dl,
fir::LLVMTypeConverter *typeConverter)
const fir::LLVMTypeConverter *typeConverter)
: OpRewritePattern(context), dl{dl}, typeConverter{typeConverter} {}
mlir::LogicalResult
@ -311,7 +311,7 @@ struct CufAllocOpConversion : public mlir::OpRewritePattern<cuf::AllocOp> {
private:
mlir::DataLayout *dl;
fir::LLVMTypeConverter *typeConverter;
const fir::LLVMTypeConverter *typeConverter;
};
struct CufFreeOpConversion : public mlir::OpRewritePattern<cuf::FreeOp> {
@ -583,7 +583,7 @@ public:
} // namespace
void cuf::populateCUFToFIRConversionPatterns(
fir::LLVMTypeConverter &converter, mlir::DataLayout &dl,
const fir::LLVMTypeConverter &converter, mlir::DataLayout &dl,
mlir::RewritePatternSet &patterns) {
patterns.insert<CufAllocOpConversion>(patterns.getContext(), &dl, &converter);
patterns.insert<CufAllocateOpConversion, CufDeallocateOpConversion,


@ -651,7 +651,7 @@ is very small, and follows the basic pattern of any dialect conversion pass.
```
void mlir::populateTensorBufferizePatterns(
BufferizeTypeConverter &typeConverter, RewritePatternSet &patterns) {
const BufferizeTypeConverter &typeConverter, RewritePatternSet &patterns) {
patterns.add<BufferizeCastOp, BufferizeExtractOp>(typeConverter,
patterns.getContext());
}


@ -24,7 +24,7 @@ class Pass;
/// Note: The ROCDL target does not support the LLVM bfloat type at this time
/// and so this function will add conversions to change all `bfloat` uses
/// to `i16`.
void populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter,
void populateAMDGPUToROCDLConversionPatterns(const LLVMTypeConverter &converter,
RewritePatternSet &patterns,
amdgpu::Chipset chipset);


@ -22,7 +22,7 @@ class Pass;
#include "mlir/Conversion/Passes.h.inc"
namespace arith {
void populateArithToLLVMConversionPatterns(LLVMTypeConverter &converter,
void populateArithToLLVMConversionPatterns(const LLVMTypeConverter &converter,
RewritePatternSet &patterns);
void registerConvertArithToLLVMInterface(DialectRegistry &registry);


@ -22,7 +22,7 @@ class Pass;
#include "mlir/Conversion/Passes.h.inc"
namespace arith {
void populateArithToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
void populateArithToSPIRVPatterns(const SPIRVTypeConverter &typeConverter,
RewritePatternSet &patterns);
std::unique_ptr<OperationPass<>> createConvertArithToSPIRVPass();


@ -39,7 +39,7 @@ public:
};
/// Populate the given list with patterns that convert from Complex to LLVM.
void populateComplexToLLVMConversionPatterns(LLVMTypeConverter &converter,
void populateComplexToLLVMConversionPatterns(const LLVMTypeConverter &converter,
RewritePatternSet &patterns);
void registerConvertComplexToLLVMInterface(DialectRegistry &registry);


@ -20,7 +20,7 @@ class SPIRVTypeConverter;
/// Appends to a pattern list additional patterns for translating Complex ops
/// to SPIR-V ops.
void populateComplexToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
void populateComplexToSPIRVPatterns(const SPIRVTypeConverter &typeConverter,
RewritePatternSet &patterns);
} // namespace mlir


@ -29,13 +29,13 @@ namespace cf {
/// Collect the patterns to convert from the ControlFlow dialect to LLVM. The
/// conversion patterns capture the LLVMTypeConverter by reference meaning the
/// references have to remain alive during the entire pattern lifetime.
void populateControlFlowToLLVMConversionPatterns(LLVMTypeConverter &converter,
RewritePatternSet &patterns);
void populateControlFlowToLLVMConversionPatterns(
const LLVMTypeConverter &converter, RewritePatternSet &patterns);
/// Populate the cf.assert to LLVM conversion pattern. If `abortOnFailure` is
/// set to false, the program execution continues when a condition is
/// unsatisfied.
void populateAssertToLLVMConversionPattern(LLVMTypeConverter &converter,
void populateAssertToLLVMConversionPattern(const LLVMTypeConverter &converter,
RewritePatternSet &patterns,
bool abortOnFailure = true);


@ -20,7 +20,7 @@ class SPIRVTypeConverter;
namespace cf {
/// Appends to a pattern list additional patterns for translating ControlFLow
/// ops to SPIR-V ops.
void populateControlFlowToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
void populateControlFlowToSPIRVPatterns(const SPIRVTypeConverter &typeConverter,
RewritePatternSet &patterns);
} // namespace cf
} // namespace mlir


@ -39,8 +39,8 @@ convertFuncOpToLLVMFuncOp(FunctionOpInterface funcOp,
/// `emitCWrappers` is set, the pattern will also produce functions
/// that pass memref descriptors by pointer-to-structure in addition to the
/// default unpacked form.
void populateFuncToLLVMFuncOpConversionPattern(LLVMTypeConverter &converter,
RewritePatternSet &patterns);
void populateFuncToLLVMFuncOpConversionPattern(
const LLVMTypeConverter &converter, RewritePatternSet &patterns);
/// Collect the patterns to convert from the Func dialect to LLVM. The
/// conversion patterns capture the LLVMTypeConverter and the LowerToLLVMOptions
@ -56,7 +56,7 @@ void populateFuncToLLVMFuncOpConversionPattern(LLVMTypeConverter &converter,
/// needed if `converter.getOptions().useBarePtrCallConv` is `true`, but it's
/// not an error to provide it anyway.
void populateFuncToLLVMConversionPatterns(
LLVMTypeConverter &converter, RewritePatternSet &patterns,
const LLVMTypeConverter &converter, RewritePatternSet &patterns,
const SymbolTable *symbolTable = nullptr);
void registerConvertFuncToLLVMInterface(DialectRegistry &registry);


@ -21,7 +21,7 @@ class SPIRVTypeConverter;
/// Appends to a pattern list additional patterns for translating Func ops
/// to SPIR-V ops. Also adds the patterns to legalize ops not directly
/// translated to SPIR-V dialect.
void populateFuncToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
void populateFuncToSPIRVPatterns(const SPIRVTypeConverter &typeConverter,
RewritePatternSet &patterns);
} // namespace mlir


@ -21,7 +21,7 @@ class TypeConverter;
#define GEN_PASS_DECL_CONVERTGPUOPSTOLLVMSPVOPS
#include "mlir/Conversion/Passes.h.inc"
void populateGpuToLLVMSPVConversionPatterns(LLVMTypeConverter &converter,
void populateGpuToLLVMSPVConversionPatterns(const LLVMTypeConverter &converter,
RewritePatternSet &patterns);
/// Populates memory space attribute conversion rules for lowering


@ -32,16 +32,16 @@ LLVM::LLVMStructType convertMMAToLLVMType(gpu::MMAMatrixType type);
void configureGpuToNVVMConversionLegality(ConversionTarget &target);
/// Collect a set of patterns to convert from the GPU dialect to NVVM.
void populateGpuToNVVMConversionPatterns(LLVMTypeConverter &converter,
void populateGpuToNVVMConversionPatterns(const LLVMTypeConverter &converter,
RewritePatternSet &patterns);
/// Populate GpuSubgroupReduce pattern to NVVM. It generates a specific nvvm
/// op that is not available on every GPU.
void populateGpuSubgroupReduceOpLoweringPattern(LLVMTypeConverter &converter,
RewritePatternSet &patterns);
void populateGpuSubgroupReduceOpLoweringPattern(
const LLVMTypeConverter &converter, RewritePatternSet &patterns);
/// Collect a set of patterns to convert WMMA ops from GPU dialect to NVVM.
void populateGpuWMMAToNVVMConversionPatterns(LLVMTypeConverter &converter,
void populateGpuWMMAToNVVMConversionPatterns(const LLVMTypeConverter &converter,
RewritePatternSet &patterns);
} // namespace mlir


@ -30,7 +30,7 @@ class GPUModuleOp;
/// Collect a set of patterns to convert from the GPU dialect to ROCDL.
/// If `runtime` is Unknown, gpu.printf will not be lowered
/// The resulting pattern set should be run over a gpu.module op
void populateGpuToROCDLConversionPatterns(LLVMTypeConverter &converter,
void populateGpuToROCDLConversionPatterns(const LLVMTypeConverter &converter,
RewritePatternSet &patterns,
gpu::amd::Runtime runtime);


@ -23,13 +23,13 @@ class SPIRVTypeConverter;
/// Appends to a pattern list additional patterns for translating GPU Ops to
/// SPIR-V ops. For a gpu.func to be converted, it should have a
/// spirv.entry_point_abi attribute.
void populateGPUToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
void populateGPUToSPIRVPatterns(const SPIRVTypeConverter &typeConverter,
RewritePatternSet &patterns);
/// Collect a set of patterns to convert WMMA ops from GPU dialect to SPIRV,
/// using the KHR Cooperative Matrix extension.
void populateGpuWMMAToSPIRVCoopMatrixKHRConversionPatterns(
SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns);
const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns);
/// Adds `MMAMatrixType` conversions to SPIR-V cooperative matrix KHR type
/// conversion to the type converter.


@ -21,7 +21,7 @@ class Pass;
#include "mlir/Conversion/Passes.h.inc"
namespace index {
void populateIndexToLLVMConversionPatterns(LLVMTypeConverter &converter,
void populateIndexToLLVMConversionPatterns(const LLVMTypeConverter &converter,
RewritePatternSet &patterns);
void registerConvertIndexToLLVMInterface(DialectRegistry &registry);


@ -21,7 +21,7 @@ class Pass;
#include "mlir/Conversion/Passes.h.inc"
namespace index {
void populateIndexToSPIRVPatterns(SPIRVTypeConverter &converter,
void populateIndexToSPIRVPatterns(const SPIRVTypeConverter &converter,
RewritePatternSet &patterns);
std::unique_ptr<OperationPass<>> createConvertIndexToSPIRVPass();
} // namespace index


@ -21,7 +21,7 @@ class Pass;
#define GEN_PASS_DECL_CONVERTMATHTOLLVMPASS
#include "mlir/Conversion/Passes.h.inc"
void populateMathToLLVMConversionPatterns(LLVMTypeConverter &converter,
void populateMathToLLVMConversionPatterns(const LLVMTypeConverter &converter,
RewritePatternSet &patterns,
bool approximateLog1p = true);


@ -19,7 +19,7 @@ class Pass;
#include "mlir/Conversion/Passes.h.inc"
/// Populate the given list with patterns that convert from Math to ROCDL calls.
void populateMathToROCDLConversionPatterns(LLVMTypeConverter &converter,
void populateMathToROCDLConversionPatterns(const LLVMTypeConverter &converter,
RewritePatternSet &patterns);
} // namespace mlir


@ -20,7 +20,7 @@ class SPIRVTypeConverter;
/// Appends to a pattern list additional patterns for translating Math ops
/// to SPIR-V ops.
void populateMathToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
void populateMathToSPIRVPatterns(const SPIRVTypeConverter &typeConverter,
RewritePatternSet &patterns);
} // namespace mlir


@ -15,7 +15,7 @@ class TypeConverter;
void populateMemRefToEmitCTypeConversion(TypeConverter &typeConverter);
void populateMemRefToEmitCConversionPatterns(RewritePatternSet &patterns,
TypeConverter &converter);
const TypeConverter &converter);
} // namespace mlir
#endif // MLIR_CONVERSION_MEMREFTOEMITC_MEMREFTOEMITC_H


@ -23,7 +23,7 @@ class RewritePatternSet;
/// Collect a set of patterns to convert memory-related operations from the
/// MemRef dialect to the LLVM dialect.
void populateFinalizeMemRefToLLVMConversionPatterns(
LLVMTypeConverter &converter, RewritePatternSet &patterns);
const LLVMTypeConverter &converter, RewritePatternSet &patterns);
void registerConvertMemRefToLLVMInterface(DialectRegistry &registry);


@ -67,7 +67,7 @@ void convertMemRefTypesAndAttrs(
/// Appends to a pattern list additional patterns for translating MemRef ops
/// to SPIR-V ops.
void populateMemRefToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
void populateMemRefToSPIRVPatterns(const SPIRVTypeConverter &typeConverter,
RewritePatternSet &patterns);
} // namespace mlir


@ -34,7 +34,7 @@ MemRefType getMBarrierMemrefType(MLIRContext *context,
MBarrierGroupType barrierType);
} // namespace nvgpu
void populateNVGPUToNVVMConversionPatterns(LLVMTypeConverter &converter,
void populateNVGPUToNVVMConversionPatterns(const LLVMTypeConverter &converter,
RewritePatternSet &patterns);
} // namespace mlir


@ -22,8 +22,8 @@ class RewritePatternSet;
/// Configure dynamic conversion legality of regionless operations from OpenMP
/// to LLVM.
void configureOpenMPToLLVMConversionLegality(ConversionTarget &target,
LLVMTypeConverter &typeConverter);
void configureOpenMPToLLVMConversionLegality(
ConversionTarget &target, const LLVMTypeConverter &typeConverter);
/// Populate the given list with patterns that convert from OpenMP to LLVM.
void populateOpenMPToLLVMConversionPatterns(LLVMTypeConverter &converter,


@ -34,7 +34,7 @@ private:
/// Collects a set of patterns to lower from scf.for, scf.if, and
/// loop.terminator to CFG operations within the SPIR-V dialect.
void populateSCFToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
void populateSCFToSPIRVPatterns(const SPIRVTypeConverter &typeConverter,
ScfToSPIRVContext &scfToSPIRVContext,
RewritePatternSet &patterns);
} // namespace mlir


@ -25,13 +25,16 @@ class ModuleOp;
template <typename SPIRVOp>
class SPIRVToLLVMConversion : public OpConversionPattern<SPIRVOp> {
public:
SPIRVToLLVMConversion(MLIRContext *context, LLVMTypeConverter &typeConverter,
SPIRVToLLVMConversion(MLIRContext *context,
const LLVMTypeConverter &typeConverter,
PatternBenefit benefit = 1)
: OpConversionPattern<SPIRVOp>(typeConverter, context, benefit),
typeConverter(typeConverter) {}
: OpConversionPattern<SPIRVOp>(typeConverter, context, benefit) {}
protected:
LLVMTypeConverter &typeConverter;
const LLVMTypeConverter *getTypeConverter() const {
return static_cast<const LLVMTypeConverter *>(
ConversionPattern::getTypeConverter());
}
};
/// Encodes global variable's descriptor set and binding into its name if they
@ -46,18 +49,18 @@ void populateSPIRVToLLVMTypeConversion(
/// Populates the given list with patterns that convert from SPIR-V to LLVM.
void populateSPIRVToLLVMConversionPatterns(
LLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
const LLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
spirv::ClientAPI clientAPIForAddressSpaceMapping =
spirv::ClientAPI::Unknown);
/// Populates the given list with patterns for function conversion from SPIR-V
/// to LLVM.
void populateSPIRVToLLVMFunctionConversionPatterns(
LLVMTypeConverter &typeConverter, RewritePatternSet &patterns);
const LLVMTypeConverter &typeConverter, RewritePatternSet &patterns);
/// Populates the given patterns for module conversion from SPIR-V to LLVM.
void populateSPIRVToLLVMModuleConversionPatterns(
LLVMTypeConverter &typeConverter, RewritePatternSet &patterns);
const LLVMTypeConverter &typeConverter, RewritePatternSet &patterns);
} // namespace mlir


@ -30,7 +30,7 @@ class SPIRVTypeConverter;
/// variables. SPIR-V consumers in GPU drivers may or may not optimize that
/// away. So this has implications over register pressure. Therefore, a
/// threshold is used to control when the patterns should kick in.
void populateTensorToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
void populateTensorToSPIRVPatterns(const SPIRVTypeConverter &typeConverter,
int64_t byteCountThreshold,
RewritePatternSet &patterns);


@ -47,7 +47,7 @@ void addTosaToLinalgPasses(
void registerTosaToLinalgPipelines();
/// Populates conversion passes from TOSA dialect to Linalg dialect.
void populateTosaToLinalgConversionPatterns(TypeConverter &converter,
void populateTosaToLinalgConversionPatterns(const TypeConverter &converter,
RewritePatternSet *patterns);
/// Populates conversion passes from TOSA dialect to Linalg named operations.


@ -25,7 +25,7 @@ namespace tosa {
std::unique_ptr<Pass> createTosaToTensor();
void populateTosaToTensorConversionPatterns(TypeConverter &converter,
void populateTosaToTensorConversionPatterns(const TypeConverter &converter,
RewritePatternSet *patterns);
} // namespace tosa


@ -22,7 +22,7 @@ class Pass;
#include "mlir/Conversion/Passes.h.inc"
namespace ub {
void populateUBToLLVMConversionPatterns(LLVMTypeConverter &converter,
void populateUBToLLVMConversionPatterns(const LLVMTypeConverter &converter,
RewritePatternSet &patterns);
void registerConvertUBToLLVMInterface(DialectRegistry &registry);


@ -21,7 +21,7 @@ class Pass;
#include "mlir/Conversion/Passes.h.inc"
namespace ub {
void populateUBToSPIRVConversionPatterns(SPIRVTypeConverter &converter,
void populateUBToSPIRVConversionPatterns(const SPIRVTypeConverter &converter,
RewritePatternSet &patterns);
} // namespace ub
} // namespace mlir


@ -16,12 +16,12 @@ class LLVMTypeConverter;
/// Collect a set of patterns to convert from Vector contractions to LLVM Matrix
/// Intrinsics. To lower to assembly, the LLVM flag -lower-matrix-intrinsics
/// will be needed when invoking LLVM.
void populateVectorToLLVMMatrixConversionPatterns(LLVMTypeConverter &converter,
RewritePatternSet &patterns);
void populateVectorToLLVMMatrixConversionPatterns(
const LLVMTypeConverter &converter, RewritePatternSet &patterns);
/// Collect a set of patterns to convert from the Vector dialect to LLVM.
void populateVectorToLLVMConversionPatterns(
LLVMTypeConverter &converter, RewritePatternSet &patterns,
const LLVMTypeConverter &converter, RewritePatternSet &patterns,
bool reassociateFPReductions = false, bool force32BitVectorIndices = false);
} // namespace mlir


@ -20,7 +20,7 @@ class SPIRVTypeConverter;
/// Appends to a pattern list additional patterns for translating Vector Ops to
/// SPIR-V ops.
void populateVectorToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
void populateVectorToSPIRVPatterns(const SPIRVTypeConverter &typeConverter,
RewritePatternSet &patterns);
/// Appends patterns to convert vector reduction of the form:


@ -17,8 +17,8 @@ class RewritePatternSet;
/// Collect a set of patterns to lower AMX ops to ops that map to LLVM
/// intrinsics.
void populateAMXLegalizeForLLVMExportPatterns(LLVMTypeConverter &converter,
RewritePatternSet &patterns);
void populateAMXLegalizeForLLVMExportPatterns(
const LLVMTypeConverter &converter, RewritePatternSet &patterns);
/// Configure the target to support lowering AMX ops to ops that map to LLVM
/// intrinsics.


@ -28,13 +28,15 @@ class NarrowTypeEmulationConverter;
/// types into supported ones. This is done by splitting original power-of-two
/// i2N integer types into two iN halves.
void populateArithWideIntEmulationPatterns(
WideIntEmulationConverter &typeConverter, RewritePatternSet &patterns);
const WideIntEmulationConverter &typeConverter,
RewritePatternSet &patterns);
/// Adds patterns to emulate narrow Arith and Function ops into wide
/// supported types. Users need to add conversions about the computation
/// domain of narrow types.
void populateArithNarrowTypeEmulationPatterns(
NarrowTypeEmulationConverter &typeConverter, RewritePatternSet &patterns);
const NarrowTypeEmulationConverter &typeConverter,
RewritePatternSet &patterns);
/// Populate the type conversions needed to emulate the unsupported
/// `sourceTypes` with `destType`
@ -45,12 +47,12 @@ void populateEmulateUnsupportedFloatsConversions(TypeConverter &converter,
/// Add rewrite patterns for converting operations that use illegal float types
/// to ones that use legal ones.
void populateEmulateUnsupportedFloatsPatterns(RewritePatternSet &patterns,
TypeConverter &converter);
const TypeConverter &converter);
/// Set up a dialect conversion to reject arithmetic operations on unsupported
/// float types.
void populateEmulateUnsupportedFloatsLegality(ConversionTarget &target,
TypeConverter &converter);
const TypeConverter &converter);
/// Add patterns to expand Arith ceil/floor division ops.
void populateCeilFloorDivExpandOpsPatterns(RewritePatternSet &patterns);


@ -17,8 +17,8 @@ class RewritePatternSet;
/// Collect a set of patterns to lower ArmSVE ops to ops that map to LLVM
/// intrinsics.
void populateArmSVELegalizeForLLVMExportPatterns(LLVMTypeConverter &converter,
RewritePatternSet &patterns);
void populateArmSVELegalizeForLLVMExportPatterns(
const LLVMTypeConverter &converter, RewritePatternSet &patterns);
/// Configure the target to support lowering ArmSVE ops to ops that map to LLVM
/// intrinsics.


@ -60,7 +60,7 @@ void populateBufferizeMaterializationLegality(ConversionTarget &target);
///
/// In particular, these are the tensor_load/buffer_cast ops.
void populateEliminateBufferizeMaterializationsPatterns(
BufferizeTypeConverter &typeConverter, RewritePatternSet &patterns);
const BufferizeTypeConverter &typeConverter, RewritePatternSet &patterns);
/// Bufferize `op` and its nested ops that implement `BufferizableOpInterface`.
///


@ -85,7 +85,7 @@ private:
/// Populates the patterns needed to drive the conversion process for
/// decomposing call graph types with the given `ValueDecomposer`.
void populateDecomposeCallGraphTypesPatterns(MLIRContext *context,
TypeConverter &typeConverter,
const TypeConverter &typeConverter,
ValueDecomposer &decomposer,
RewritePatternSet &patterns);


@ -29,7 +29,7 @@ class RewritePatternSet;
/// Add a pattern to the given pattern list to convert the operand and result
/// types of a CallOp with the given type converter.
void populateCallOpTypeConversionPattern(RewritePatternSet &patterns,
TypeConverter &converter);
const TypeConverter &converter);
/// Add a pattern to the given pattern list to rewrite branch operations to use
/// operands that have been legalized by the conversion framework. This can only
@ -41,25 +41,25 @@ void populateCallOpTypeConversionPattern(RewritePatternSet &patterns,
/// shouldConvertBranchOperand. This callback should return true if branchOp's
/// operand at index idx should be converted.
void populateBranchOpInterfaceTypeConversionPattern(
RewritePatternSet &patterns, TypeConverter &converter,
RewritePatternSet &patterns, const TypeConverter &converter,
function_ref<bool(BranchOpInterface branchOp, int idx)>
shouldConvertBranchOperand = nullptr);
/// Return true if op is a BranchOpInterface op whose operands are all legal
/// according to converter.
bool isLegalForBranchOpInterfaceTypeConversionPattern(Operation *op,
TypeConverter &converter);
bool isLegalForBranchOpInterfaceTypeConversionPattern(
Operation *op, const TypeConverter &converter);
/// Add a pattern to the given pattern list to rewrite `return` ops to use
/// operands that have been legalized by the conversion framework.
void populateReturnOpTypeConversionPattern(RewritePatternSet &patterns,
TypeConverter &converter);
const TypeConverter &converter);
/// For ReturnLike ops (except `return`), return True. If op is a `return` &&
/// returnOpAlwaysLegal is false, legalize op according to converter. Otherwise,
/// return false.
bool isLegalForReturnOpTypeConversionPattern(Operation *op,
TypeConverter &converter,
const TypeConverter &converter,
bool returnOpAlwaysLegal = false);
/// Return true if op is neither BranchOpInterface nor ReturnLike.


@ -18,7 +18,7 @@ namespace mlir {
// Populates the provided pattern set with patterns that do 1:N type conversions
// on func ops. This is intended to be used with `applyPartialOneToNConversion`.
void populateFuncTypeConversionPatterns(TypeConverter &typeConverter,
void populateFuncTypeConversionPatterns(const TypeConverter &typeConverter,
RewritePatternSet &patterns);
} // namespace mlir


@ -62,7 +62,7 @@ void populateExtendToSupportedTypesTypeConverter(
void populateExtendToSupportedTypesConversionTarget(
ConversionTarget &target, TypeConverter &typeConverter);
void populateExtendToSupportedTypesPatterns(RewritePatternSet &patterns,
TypeConverter &typeConverter);
const TypeConverter &typeConverter);
} // namespace math
} // namespace mlir


@ -72,7 +72,7 @@ void populateExpandReallocPatterns(RewritePatternSet &patterns,
/// Appends patterns for emulating wide integer memref operations with ops over
/// narrower integer types.
void populateMemRefWideIntEmulationPatterns(
arith::WideIntEmulationConverter &typeConverter,
const arith::WideIntEmulationConverter &typeConverter,
RewritePatternSet &patterns);
/// Appends type conversions for emulating wide integer memref operations with
@ -83,7 +83,7 @@ void populateMemRefWideIntEmulationConversions(
/// Appends patterns for emulating memref operations over narrow types with ops
/// over wider types.
void populateMemRefNarrowTypeEmulationPatterns(
arith::NarrowTypeEmulationConverter &typeConverter,
const arith::NarrowTypeEmulationConverter &typeConverter,
RewritePatternSet &patterns);
/// Appends type conversions for emulating memref operations over narrow types


@ -50,12 +50,12 @@ protected:
/// corresponding scf.yield ops need to update their types accordingly to the
/// TypeConverter, but otherwise don't care what type conversions are happening.
void populateSCFStructuralTypeConversionsAndLegality(
TypeConverter &typeConverter, RewritePatternSet &patterns,
const TypeConverter &typeConverter, RewritePatternSet &patterns,
ConversionTarget &target);
/// Similar to `populateSCFStructuralTypeConversionsAndLegality` but does not
/// populate the conversion target.
void populateSCFStructuralTypeConversions(TypeConverter &typeConverter,
void populateSCFStructuralTypeConversions(const TypeConverter &typeConverter,
RewritePatternSet &patterns);
/// Updates the ConversionTarget with dynamic legality of SCF operations based
@ -66,8 +66,8 @@ void populateSCFStructuralTypeConversionTarget(
/// Populates the provided pattern set with patterns that do 1:N type
/// conversions on (some) SCF ops. This is intended to be used with
/// applyPartialOneToNConversion.
void populateSCFStructuralOneToNTypeConversions(TypeConverter &typeConverter,
RewritePatternSet &patterns);
void populateSCFStructuralOneToNTypeConversions(
const TypeConverter &typeConverter, RewritePatternSet &patterns);
/// Populate patterns for SCF software pipelining transformation. See the
/// ForLoopPipeliningPattern for the transformation details.


@ -135,7 +135,7 @@ private:
/// `func` op to the SPIR-V dialect. These patterns do not handle shader
/// interface/ABI; they convert function parameters to be of SPIR-V allowed
/// types.
void populateBuiltinFuncToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
void populateBuiltinFuncToSPIRVPatterns(const SPIRVTypeConverter &typeConverter,
RewritePatternSet &patterns);
void populateFuncOpVectorRewritePatterns(RewritePatternSet &patterns);


@ -154,7 +154,7 @@ struct SparseIterationTypeConverter : public OneToNTypeConverter {
SparseIterationTypeConverter();
};
void populateLowerSparseIterationToSCFPatterns(TypeConverter &converter,
void populateLowerSparseIterationToSCFPatterns(const TypeConverter &converter,
RewritePatternSet &patterns);
std::unique_ptr<Pass> createLowerSparseIterationToSCFPass();
@ -170,7 +170,7 @@ public:
};
/// Sets up sparse tensor conversion rules.
void populateSparseTensorConversionPatterns(TypeConverter &typeConverter,
void populateSparseTensorConversionPatterns(const TypeConverter &typeConverter,
RewritePatternSet &patterns);
std::unique_ptr<Pass> createSparseTensorConversionPass();
@ -186,7 +186,7 @@ public:
};
/// Sets up sparse tensor codegen rules.
void populateSparseTensorCodegenPatterns(TypeConverter &typeConverter,
void populateSparseTensorCodegenPatterns(const TypeConverter &typeConverter,
RewritePatternSet &patterns,
bool createSparseDeallocs,
bool enableBufferInitialization);
@ -244,7 +244,7 @@ public:
StorageSpecifierToLLVMTypeConverter();
};
void populateStorageSpecifierToLLVMPatterns(TypeConverter &converter,
void populateStorageSpecifierToLLVMPatterns(const TypeConverter &converter,
RewritePatternSet &patterns);
std::unique_ptr<Pass> createStorageSpecifierToLLVMPass();


@ -366,7 +366,7 @@ void populateVectorMaskMaterializationPatterns(RewritePatternSet &patterns,
/// Appends patterns for emulating vector operations over narrow types with ops
/// over wider types.
void populateVectorNarrowTypeEmulationPatterns(
arith::NarrowTypeEmulationConverter &typeConverter,
const arith::NarrowTypeEmulationConverter &typeConverter,
RewritePatternSet &patterns);
/// Rewrite a vector `bitcast(trunci)` to use a more efficient sequence of
@ -403,10 +403,9 @@ void populateVectorLinearizeTypeConversionsAndLegality(
/// Populates patterns for linearizing ND (N >= 2) vector operations to 1D
/// vector shuffle operations.
void populateVectorLinearizeShuffleLikeOpsPatterns(TypeConverter &typeConverter,
RewritePatternSet &patterns,
ConversionTarget &target,
unsigned targetBitWidth);
void populateVectorLinearizeShuffleLikeOpsPatterns(
const TypeConverter &typeConverter, RewritePatternSet &patterns,
ConversionTarget &target, unsigned targetBitWidth);
} // namespace vector
} // namespace mlir


@ -176,7 +176,7 @@ void populateSpecializedTransposeLoweringPatterns(
/// Collect a set of patterns to lower X86Vector ops to ops that map to LLVM
/// intrinsics.
void populateX86VectorLegalizeForLLVMExportPatterns(
LLVMTypeConverter &converter, RewritePatternSet &patterns);
const LLVMTypeConverter &converter, RewritePatternSet &patterns);
/// Configure the target to support lowering X86Vector ops to ops that map to
/// LLVM intrinsics.


@ -53,7 +53,7 @@ public:
/// `TypeConverter::convertSignatureArgs` and exists here with a different
/// name to reflect the broader semantic.
LogicalResult computeTypeMapping(TypeRange types,
SignatureConversion &result) {
SignatureConversion &result) const {
return convertSignatureArgs(types, result);
}
@ -126,24 +126,25 @@ public:
/// Construct a conversion pattern with the given converter, and forward the
/// remaining arguments to RewritePattern.
template <typename... Args>
RewritePatternWithConverter(TypeConverter &typeConverter, Args &&...args)
RewritePatternWithConverter(const TypeConverter &typeConverter,
Args &&...args)
: RewritePattern(std::forward<Args>(args)...),
typeConverter(&typeConverter) {}
/// Return the type converter held by this pattern, or nullptr if the pattern
/// does not require type conversion.
TypeConverter *getTypeConverter() const { return typeConverter; }
const TypeConverter *getTypeConverter() const { return typeConverter; }
template <typename ConverterTy>
std::enable_if_t<std::is_base_of<TypeConverter, ConverterTy>::value,
ConverterTy *>
const ConverterTy *>
getTypeConverter() const {
return static_cast<ConverterTy *>(typeConverter);
return static_cast<const ConverterTy *>(typeConverter);
}
protected:
/// A type converter for use by this pattern.
TypeConverter *const typeConverter;
const TypeConverter *const typeConverter;
};
/// Specialization of `PatternRewriter` that `OneToNConversionPattern`s use. The
@ -212,8 +213,8 @@ public:
template <typename SourceOp>
class OneToNOpConversionPattern : public OneToNConversionPattern {
public:
OneToNOpConversionPattern(TypeConverter &typeConverter, MLIRContext *context,
PatternBenefit benefit = 1,
OneToNOpConversionPattern(const TypeConverter &typeConverter,
MLIRContext *context, PatternBenefit benefit = 1,
ArrayRef<StringRef> generatedNames = {})
: OneToNConversionPattern(typeConverter, SourceOp::getOperationName(),
benefit, context, generatedNames) {}
@ -302,11 +303,11 @@ applyPartialOneToNConversion(Operation *op, OneToNTypeConverter &typeConverter,
/// ops which use FunctionType to represent their type. This is intended to be
/// used with the 1:N dialect conversion.
void populateOneToNFunctionOpInterfaceTypeConversionPattern(
StringRef functionLikeOpName, TypeConverter &converter,
StringRef functionLikeOpName, const TypeConverter &converter,
RewritePatternSet &patterns);
template <typename FuncOpT>
void populateOneToNFunctionOpInterfaceTypeConversionPattern(
TypeConverter &converter, RewritePatternSet &patterns) {
const TypeConverter &converter, RewritePatternSet &patterns) {
populateOneToNFunctionOpInterfaceTypeConversionPattern(
FuncOpT::getOperationName(), converter, patterns);
}


@ -277,7 +277,7 @@ struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
};
struct LDSBarrierOpLowering : public ConvertOpToLLVMPattern<LDSBarrierOp> {
LDSBarrierOpLowering(LLVMTypeConverter &converter, Chipset chipset)
LDSBarrierOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
: ConvertOpToLLVMPattern<LDSBarrierOp>(converter), chipset(chipset) {}
Chipset chipset;
@ -335,7 +335,7 @@ struct LDSBarrierOpLowering : public ConvertOpToLLVMPattern<LDSBarrierOp> {
};
struct SchedBarrierOpLowering : public ConvertOpToLLVMPattern<SchedBarrierOp> {
SchedBarrierOpLowering(LLVMTypeConverter &converter, Chipset chipset)
SchedBarrierOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
: ConvertOpToLLVMPattern<SchedBarrierOp>(converter), chipset(chipset) {}
Chipset chipset;
@ -725,7 +725,7 @@ struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
namespace {
struct ExtPackedFp8OpLowering final
: public ConvertOpToLLVMPattern<ExtPackedFp8Op> {
ExtPackedFp8OpLowering(LLVMTypeConverter &converter, Chipset chipset)
ExtPackedFp8OpLowering(const LLVMTypeConverter &converter, Chipset chipset)
: ConvertOpToLLVMPattern<amdgpu::ExtPackedFp8Op>(converter),
chipset(chipset) {}
Chipset chipset;
@ -737,7 +737,8 @@ struct ExtPackedFp8OpLowering final
struct PackedTrunc2xFp8OpLowering final
: public ConvertOpToLLVMPattern<PackedTrunc2xFp8Op> {
PackedTrunc2xFp8OpLowering(LLVMTypeConverter &converter, Chipset chipset)
PackedTrunc2xFp8OpLowering(const LLVMTypeConverter &converter,
Chipset chipset)
: ConvertOpToLLVMPattern<amdgpu::PackedTrunc2xFp8Op>(converter),
chipset(chipset) {}
Chipset chipset;
@ -749,7 +750,8 @@ struct PackedTrunc2xFp8OpLowering final
struct PackedStochRoundFp8OpLowering final
: public ConvertOpToLLVMPattern<PackedStochRoundFp8Op> {
PackedStochRoundFp8OpLowering(LLVMTypeConverter &converter, Chipset chipset)
PackedStochRoundFp8OpLowering(const LLVMTypeConverter &converter,
Chipset chipset)
: ConvertOpToLLVMPattern<amdgpu::PackedStochRoundFp8Op>(converter),
chipset(chipset) {}
Chipset chipset;
@ -880,7 +882,7 @@ LogicalResult PackedStochRoundFp8OpLowering::matchAndRewrite(
// Implement the AMDGPU_DPPLowering class that will convert the amdgpu.dpp
// operation into the corresponding ROCDL instructions.
struct AMDGPUDPPLowering : public ConvertOpToLLVMPattern<DPPOp> {
AMDGPUDPPLowering(LLVMTypeConverter &converter, Chipset chipset)
AMDGPUDPPLowering(const LLVMTypeConverter &converter, Chipset chipset)
: ConvertOpToLLVMPattern<DPPOp>(converter), chipset(chipset) {}
Chipset chipset;
@ -1052,8 +1054,8 @@ struct ConvertAMDGPUToROCDLPass
};
} // namespace
void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter,
RewritePatternSet &patterns,
void mlir::populateAMDGPUToROCDLConversionPatterns(
const LLVMTypeConverter &converter, RewritePatternSet &patterns,
Chipset chipset) {
patterns
.add<RawBufferOpLowering<RawBufferLoadOp, ROCDL::RawPtrBufferLoadOp>,


@ -520,7 +520,7 @@ void mlir::arith::registerConvertArithToLLVMInterface(
//===----------------------------------------------------------------------===//
void mlir::arith::populateArithToLLVMConversionPatterns(
LLVMTypeConverter &converter, RewritePatternSet &patterns) {
const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
// clang-format off
patterns.add<
AddFOpLowering,


@ -1262,7 +1262,7 @@ public:
//===----------------------------------------------------------------------===//
void mlir::arith::populateArithToSPIRVPatterns(
SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns) {
const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns) {
// clang-format off
patterns.add<
ConstantCompositeOpPattern,


@ -323,7 +323,7 @@ struct SubOpConversion : public ConvertOpToLLVMPattern<complex::SubOp> {
} // namespace
void mlir::populateComplexToLLVMConversionPatterns(
LLVMTypeConverter &converter, RewritePatternSet &patterns) {
const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
// clang-format off
patterns.add<
AbsOpConversion,


@ -102,8 +102,8 @@ struct ImOpPattern final : OpConversionPattern<complex::ImOp> {
// Pattern population
//===----------------------------------------------------------------------===//
void mlir::populateComplexToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
RewritePatternSet &patterns) {
void mlir::populateComplexToSPIRVPatterns(
const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns) {
MLIRContext *context = patterns.getContext();
patterns.add<ConstantOpPattern, CreateOpPattern, ReOpPattern, ImOpPattern>(


@ -43,7 +43,7 @@ namespace {
/// ignored by the default lowering but should be propagated by any custom
/// lowering.
struct AssertOpLowering : public ConvertOpToLLVMPattern<cf::AssertOp> {
explicit AssertOpLowering(LLVMTypeConverter &typeConverter,
explicit AssertOpLowering(const LLVMTypeConverter &typeConverter,
bool abortOnFailedAssert = true)
: ConvertOpToLLVMPattern<cf::AssertOp>(typeConverter, /*benefit=*/1),
abortOnFailedAssert(abortOnFailedAssert) {}
@ -201,7 +201,7 @@ struct SwitchOpLowering : public ConvertOpToLLVMPattern<cf::SwitchOp> {
} // namespace
void mlir::cf::populateControlFlowToLLVMConversionPatterns(
LLVMTypeConverter &converter, RewritePatternSet &patterns) {
const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
// clang-format off
patterns.add<
AssertOpLowering,
@ -212,7 +212,7 @@ void mlir::cf::populateControlFlowToLLVMConversionPatterns(
}
void mlir::cf::populateAssertToLLVMConversionPattern(
LLVMTypeConverter &converter, RewritePatternSet &patterns,
const LLVMTypeConverter &converter, RewritePatternSet &patterns,
bool abortOnFailure) {
patterns.add<AssertOpLowering>(converter, abortOnFailure);
}


@ -109,7 +109,7 @@ struct CondBranchOpPattern final : OpConversionPattern<cf::CondBranchOp> {
//===----------------------------------------------------------------------===//
void mlir::cf::populateControlFlowToSPIRVPatterns(
SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns) {
const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns) {
MLIRContext *context = patterns.getContext();
patterns.add<BranchOpPattern, CondBranchOpPattern>(typeConverter, context);


@ -55,7 +55,7 @@ void mapToMemRef(Operation *op, spirv::TargetEnvAttr &targetAttr) {
}
/// Populate patterns for each dialect.
void populateConvertToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
void populateConvertToSPIRVPatterns(const SPIRVTypeConverter &typeConverter,
ScfToSPIRVContext &scfToSPIRVContext,
RewritePatternSet &patterns) {
arith::populateCeilFloorDivExpandOpsPatterns(patterns);


@ -722,12 +722,12 @@ struct ReturnOpLowering : public ConvertOpToLLVMPattern<func::ReturnOp> {
} // namespace
void mlir::populateFuncToLLVMFuncOpConversionPattern(
LLVMTypeConverter &converter, RewritePatternSet &patterns) {
const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
patterns.add<FuncOpConversion>(converter);
}
void mlir::populateFuncToLLVMConversionPatterns(
LLVMTypeConverter &converter, RewritePatternSet &patterns,
const LLVMTypeConverter &converter, RewritePatternSet &patterns,
const SymbolTable *symbolTable) {
populateFuncToLLVMFuncOpConversionPattern(converter, patterns);
patterns.add<CallIndirectOpLowering>(converter);


@ -87,7 +87,7 @@ public:
// Pattern population
//===----------------------------------------------------------------------===//
void mlir::populateFuncToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
void mlir::populateFuncToSPIRVPatterns(const SPIRVTypeConverter &typeConverter,
RewritePatternSet &patterns) {
MLIRContext *context = patterns.getContext();


@ -36,13 +36,13 @@ private:
IntrType intrType;
public:
explicit OpLowering(LLVMTypeConverter &typeConverter)
explicit OpLowering(const LLVMTypeConverter &typeConverter)
: ConvertOpToLLVMPattern<Op>(typeConverter),
indexBitwidth(typeConverter.getIndexTypeBitwidth()),
indexKind(IndexKind::Other), intrType(IntrType::None) {}
explicit OpLowering(LLVMTypeConverter &typeConverter, IndexKind indexKind,
IntrType intrType)
explicit OpLowering(const LLVMTypeConverter &typeConverter,
IndexKind indexKind, IntrType intrType)
: ConvertOpToLLVMPattern<Op>(typeConverter),
indexBitwidth(typeConverter.getIndexTypeBitwidth()),
indexKind(indexKind), intrType(intrType) {}


@ -42,9 +42,9 @@ namespace mlir {
template <typename SourceOp>
struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> {
public:
explicit OpToFuncCallLowering(LLVMTypeConverter &lowering, StringRef f32Func,
StringRef f64Func, StringRef f32ApproxFunc,
StringRef f16Func)
explicit OpToFuncCallLowering(const LLVMTypeConverter &lowering,
StringRef f32Func, StringRef f64Func,
StringRef f32ApproxFunc, StringRef f16Func)
: ConvertOpToLLVMPattern<SourceOp>(lowering), f32Func(f32Func),
f64Func(f64Func), f32ApproxFunc(f32ApproxFunc), f16Func(f16Func) {}


@ -412,8 +412,8 @@ gpuAddressSpaceToOCLAddressSpace(gpu::AddressSpace addressSpace) {
}
} // namespace
void populateGpuToLLVMSPVConversionPatterns(LLVMTypeConverter &typeConverter,
RewritePatternSet &patterns) {
void populateGpuToLLVMSPVConversionPatterns(
const LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) {
patterns.add<GPUBarrierConversion, GPUReturnOpLowering, GPUShuffleConversion,
GPUSubgroupOpConversion<gpu::LaneIdOp>,
GPUSubgroupOpConversion<gpu::NumSubgroupsOp>,


@ -333,7 +333,7 @@ void mlir::configureGpuToNVVMConversionLegality(ConversionTarget &target) {
}
template <typename OpTy>
static void populateOpPatterns(LLVMTypeConverter &converter,
static void populateOpPatterns(const LLVMTypeConverter &converter,
RewritePatternSet &patterns, StringRef f32Func,
StringRef f64Func, StringRef f32ApproxFunc = "",
StringRef f16Func = "") {
@ -343,12 +343,12 @@ static void populateOpPatterns(LLVMTypeConverter &converter,
}
void mlir::populateGpuSubgroupReduceOpLoweringPattern(
LLVMTypeConverter &converter, RewritePatternSet &patterns) {
const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
patterns.add<GPUSubgroupReduceOpLowering>(converter);
}
void mlir::populateGpuToNVVMConversionPatterns(LLVMTypeConverter &converter,
RewritePatternSet &patterns) {
void mlir::populateGpuToNVVMConversionPatterns(
const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
using gpu::index_lowering::IndexKind;
using gpu::index_lowering::IntrType;
populateWithGenerated(patterns);


@ -388,7 +388,7 @@ LLVM::LLVMStructType mlir::convertMMAToLLVMType(gpu::MMAMatrixType type) {
}
void mlir::populateGpuWMMAToNVVMConversionPatterns(
LLVMTypeConverter &converter, RewritePatternSet &patterns) {
const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
patterns.add<WmmaLoadOpToNVVMLowering, WmmaMmaOpToNVVMLowering,
WmmaStoreOpToNVVMLowering, WmmaConstantOpToNVVMLowering,
WmmaElementwiseOpToNVVMLowering>(converter);


@ -343,7 +343,7 @@ void mlir::configureGpuToROCDLConversionLegality(ConversionTarget &target) {
}
template <typename OpTy>
static void populateOpPatterns(LLVMTypeConverter &converter,
static void populateOpPatterns(const LLVMTypeConverter &converter,
RewritePatternSet &patterns, StringRef f32Func,
StringRef f64Func, StringRef f32ApproxFunc,
StringRef f16Func) {
@ -353,7 +353,7 @@ static void populateOpPatterns(LLVMTypeConverter &converter,
}
void mlir::populateGpuToROCDLConversionPatterns(
LLVMTypeConverter &converter, RewritePatternSet &patterns,
const LLVMTypeConverter &converter, RewritePatternSet &patterns,
mlir::gpu::amd::Runtime runtime) {
using gpu::index_lowering::IndexKind;
using gpu::index_lowering::IntrType;


@ -59,7 +59,8 @@ public:
/// attribute on the surrounding FuncOp is used to replace the gpu::BlockDimOp.
class WorkGroupSizeConversion : public OpConversionPattern<gpu::BlockDimOp> {
public:
WorkGroupSizeConversion(TypeConverter &typeConverter, MLIRContext *context)
WorkGroupSizeConversion(const TypeConverter &typeConverter,
MLIRContext *context)
: OpConversionPattern(typeConverter, context, /*benefit*/ 10) {}
LogicalResult
@ -728,7 +729,7 @@ LogicalResult GPUPrintfConversion::matchAndRewrite(
// GPU To SPIRV Patterns.
//===----------------------------------------------------------------------===//
void mlir::populateGPUToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
void mlir::populateGPUToSPIRVPatterns(const SPIRVTypeConverter &typeConverter,
RewritePatternSet &patterns) {
patterns.add<
GPUBarrierConversion, GPUFuncOpConversion, GPUModuleConversion,


@ -293,7 +293,7 @@ struct WmmaMmaOpToSPIRVLowering final
} // namespace mlir
void mlir::populateGpuWMMAToSPIRVCoopMatrixKHRConversionPatterns(
SPIRVTypeConverter &converter, RewritePatternSet &patterns) {
const SPIRVTypeConverter &converter, RewritePatternSet &patterns) {
using namespace mlir;
MLIRContext *context = patterns.getContext();
patterns.add<khr::WmmaLoadOpToSPIRVLowering, khr::WmmaMmaOpToSPIRVLowering,


@ -292,7 +292,7 @@ using ConvertIndexBoolConstant =
//===----------------------------------------------------------------------===//
void index::populateIndexToLLVMConversionPatterns(
LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) {
const LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) {
patterns.insert<
// clang-format off
ConvertIndexAdd,


@ -338,8 +338,8 @@ struct ConvertIndexSizeOf final : OpConversionPattern<SizeOfOp> {
// Pattern Population
//===----------------------------------------------------------------------===//
void index::populateIndexToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
RewritePatternSet &patterns) {
void index::populateIndexToSPIRVPatterns(
const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns) {
patterns.add<
// clang-format off
ConvertIndexAdd,


@ -298,8 +298,8 @@ struct ConvertMathToLLVMPass
};
} // namespace
void mlir::populateMathToLLVMConversionPatterns(LLVMTypeConverter &converter,
RewritePatternSet &patterns,
void mlir::populateMathToLLVMConversionPatterns(
const LLVMTypeConverter &converter, RewritePatternSet &patterns,
bool approximateLog1p) {
if (approximateLog1p)
patterns.add<Log1pOpLowering>(converter);


@ -36,7 +36,7 @@ using namespace mlir;
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
template <typename OpTy>
static void populateOpPatterns(LLVMTypeConverter &converter,
static void populateOpPatterns(const LLVMTypeConverter &converter,
RewritePatternSet &patterns, StringRef f32Func,
StringRef f64Func, StringRef f16Func,
StringRef f32ApproxFunc = "") {
@ -45,8 +45,8 @@ static void populateOpPatterns(LLVMTypeConverter &converter,
f32ApproxFunc, f16Func);
}
void mlir::populateMathToROCDLConversionPatterns(LLVMTypeConverter &converter,
RewritePatternSet &patterns) {
void mlir::populateMathToROCDLConversionPatterns(
const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
// Handled by mathToLLVM: math::AbsIOp
// Handled by mathToLLVM: math::AbsFOp
// Handled by mathToLLVM: math::CopySignOp


@ -462,7 +462,7 @@ struct RoundOpPattern final : public OpConversionPattern<math::RoundOp> {
//===----------------------------------------------------------------------===//
namespace mlir {
void populateMathToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
void populateMathToSPIRVPatterns(const SPIRVTypeConverter &typeConverter,
RewritePatternSet &patterns) {
// Core patterns
patterns.add<CopySignPattern>(typeConverter, patterns.getContext());


@ -179,8 +179,8 @@ void mlir::populateMemRefToEmitCTypeConversion(TypeConverter &typeConverter) {
});
}
void mlir::populateMemRefToEmitCConversionPatterns(RewritePatternSet &patterns,
TypeConverter &converter) {
void mlir::populateMemRefToEmitCConversionPatterns(
RewritePatternSet &patterns, const TypeConverter &converter) {
patterns.add<ConvertAlloca, ConvertGlobal, ConvertGetGlobal, ConvertLoad,
ConvertStore>(converter, patterns.getContext());
}


@ -1667,7 +1667,7 @@ public:
} // namespace
void mlir::populateFinalizeMemRefToLLVMConversionPatterns(
LLVMTypeConverter &converter, RewritePatternSet &patterns) {
const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
// clang-format off
patterns.add<
AllocaOpLowering,


@ -926,7 +926,7 @@ LogicalResult ReinterpretCastPattern::matchAndRewrite(
//===----------------------------------------------------------------------===//
namespace mlir {
void populateMemRefToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
void populateMemRefToSPIRVPatterns(const SPIRVTypeConverter &typeConverter,
RewritePatternSet &patterns) {
patterns.add<AllocaOpPattern, AllocOpPattern, AtomicRMWOpPattern,
DeallocOpPattern, IntLoadOpPattern, IntStoreOpPattern,


@ -1701,8 +1701,8 @@ struct NVGPURcpOpLowering : public ConvertOpToLLVMPattern<nvgpu::RcpOp> {
};
} // namespace
void mlir::populateNVGPUToNVVMConversionPatterns(LLVMTypeConverter &converter,
RewritePatternSet &patterns) {
void mlir::populateNVGPUToNVVMConversionPatterns(
const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
patterns.add<
NVGPUMBarrierCreateLowering, // nvgpu.mbarrier.create
NVGPUMBarrierInitLowering, // nvgpu.mbarrier.init


@ -221,7 +221,7 @@ void MultiRegionOpConversion<omp::PrivateClauseOp>::forwardOpAttrs(
} // namespace
void mlir::configureOpenMPToLLVMConversionLegality(
ConversionTarget &target, LLVMTypeConverter &typeConverter) {
ConversionTarget &target, const LLVMTypeConverter &typeConverter) {
target.addDynamicallyLegalOp<
omp::AtomicReadOp, omp::AtomicWriteOp, omp::CancellationPointOp,
omp::CancelOp, omp::CriticalDeclareOp, omp::FlushOp, omp::MapBoundsOp,


@ -96,7 +96,7 @@ Region::iterator getBlockIt(Region &region, unsigned index) {
template <typename OpTy>
class SCFToSPIRVPattern : public OpConversionPattern<OpTy> {
public:
SCFToSPIRVPattern(MLIRContext *context, SPIRVTypeConverter &converter,
SCFToSPIRVPattern(MLIRContext *context, const SPIRVTypeConverter &converter,
ScfToSPIRVContextImpl *scfToSPIRVContext)
: OpConversionPattern<OpTy>::OpConversionPattern(converter, context),
scfToSPIRVContext(scfToSPIRVContext), typeConverter(converter) {}
@ -117,7 +117,7 @@ protected:
// conversion. There isn't a straightforward way to do that yet, as when
// converting types, ops aren't taken into consideration. Therefore, we just
// bypass the framework's type conversion for now.
SPIRVTypeConverter &typeConverter;
const SPIRVTypeConverter &typeConverter;
};
//===----------------------------------------------------------------------===//
@ -436,7 +436,7 @@ struct WhileOpConversion final : SCFToSPIRVPattern<scf::WhileOp> {
// Public API
//===----------------------------------------------------------------------===//
void mlir::populateSCFToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
void mlir::populateSCFToSPIRVPatterns(const SPIRVTypeConverter &typeConverter,
ScfToSPIRVContext &scfToSPIRVContext,
RewritePatternSet &patterns) {
patterns.add<ForOpConversion, IfOpConversion, TerminatorOpConversion,


@ -149,7 +149,7 @@ static Value optionallyTruncateOrExtend(Location loc, Value value,
/// Broadcasts the value to vector with `numElements` number of elements.
static Value broadcast(Location loc, Value toBroadcast, unsigned numElements,
LLVMTypeConverter &typeConverter,
const LLVMTypeConverter &typeConverter,
ConversionPatternRewriter &rewriter) {
auto vectorType = VectorType::get(numElements, toBroadcast.getType());
auto llvmVectorType = typeConverter.convertType(vectorType);
@ -166,7 +166,7 @@ static Value broadcast(Location loc, Value toBroadcast, unsigned numElements,
/// Broadcasts the value. If `srcType` is a scalar, the value remains unchanged.
static Value optionallyBroadcast(Location loc, Value value, Type srcType,
LLVMTypeConverter &typeConverter,
const LLVMTypeConverter &typeConverter,
ConversionPatternRewriter &rewriter) {
if (auto vectorType = dyn_cast<VectorType>(srcType)) {
unsigned numElements = vectorType.getNumElements();
@ -186,7 +186,8 @@ static Value optionallyBroadcast(Location loc, Value value, Type srcType,
/// Then cast `Offset` and `Count` if their bit width is different
/// from `Base` bit width.
static Value processCountOrOffset(Location loc, Value value, Type srcType,
Type dstType, LLVMTypeConverter &converter,
Type dstType,
const LLVMTypeConverter &converter,
ConversionPatternRewriter &rewriter) {
Value broadcasted =
optionallyBroadcast(loc, value, srcType, converter, rewriter);
@ -196,7 +197,7 @@ static Value processCountOrOffset(Location loc, Value value, Type srcType,
/// Converts SPIR-V struct with a regular (according to `VulkanLayoutUtils`)
/// offset to LLVM struct. Otherwise, the conversion is not supported.
static Type convertStructTypeWithOffset(spirv::StructType type,
LLVMTypeConverter &converter) {
const LLVMTypeConverter &converter) {
if (type != VulkanLayoutUtils::decorateType(type))
return nullptr;
@ -209,7 +210,7 @@ static Type convertStructTypeWithOffset(spirv::StructType type,
/// Converts SPIR-V struct with no offset to packed LLVM struct.
static Type convertStructTypePacked(spirv::StructType type,
LLVMTypeConverter &converter) {
const LLVMTypeConverter &converter) {
SmallVector<Type> elementsVector;
if (failed(converter.convertTypes(type.getElementTypes(), elementsVector)))
return nullptr;
@ -226,10 +227,9 @@ static Value createI32ConstantOf(Location loc, PatternRewriter &rewriter,
}
/// Utility for `spirv.Load` and `spirv.Store` conversion.
static LogicalResult replaceWithLoadOrStore(Operation *op, ValueRange operands,
ConversionPatternRewriter &rewriter,
LLVMTypeConverter &typeConverter,
unsigned alignment, bool isVolatile,
static LogicalResult replaceWithLoadOrStore(
Operation *op, ValueRange operands, ConversionPatternRewriter &rewriter,
const LLVMTypeConverter &typeConverter, unsigned alignment, bool isVolatile,
bool isNonTemporal) {
if (auto loadOp = dyn_cast<spirv::LoadOp>(op)) {
auto dstType = typeConverter.convertType(loadOp.getType());
@ -271,7 +271,7 @@ static std::optional<Type> convertArrayType(spirv::ArrayType type,
/// Converts SPIR-V pointer type to LLVM pointer. Pointer's storage class is not
/// modelled at the moment.
static Type convertPointerType(spirv::PointerType type,
LLVMTypeConverter &converter,
const LLVMTypeConverter &converter,
spirv::ClientAPI clientAPI) {
unsigned addressSpace =
storageClassToAddressSpace(clientAPI, type.getStorageClass());
@ -292,7 +292,7 @@ static std::optional<Type> convertRuntimeArrayType(spirv::RuntimeArrayType type,
/// Converts SPIR-V struct to LLVM struct. There is no support of structs with
/// member decorations. Also, only natural offset is supported.
static Type convertStructType(spirv::StructType type,
LLVMTypeConverter &converter) {
const LLVMTypeConverter &converter) {
SmallVector<spirv::StructType::MemberDecorationInfo, 4> memberDecorations;
type.getMemberDecorations(memberDecorations);
if (!memberDecorations.empty())
@ -315,20 +315,21 @@ public:
LogicalResult
matchAndRewrite(spirv::AccessChainOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto dstType = typeConverter.convertType(op.getComponentPtr().getType());
auto dstType =
getTypeConverter()->convertType(op.getComponentPtr().getType());
if (!dstType)
return rewriter.notifyMatchFailure(op, "type conversion failed");
// To use GEP we need to add a first 0 index to go through the pointer.
auto indices = llvm::to_vector<4>(adaptor.getIndices());
Type indexType = op.getIndices().front().getType();
auto llvmIndexType = typeConverter.convertType(indexType);
auto llvmIndexType = getTypeConverter()->convertType(indexType);
if (!llvmIndexType)
return rewriter.notifyMatchFailure(op, "type conversion failed");
Value zero = rewriter.create<LLVM::ConstantOp>(
op.getLoc(), llvmIndexType, rewriter.getIntegerAttr(indexType, 0));
indices.insert(indices.begin(), zero);
auto elementType = typeConverter.convertType(
auto elementType = getTypeConverter()->convertType(
cast<spirv::PointerType>(op.getBasePtr().getType()).getPointeeType());
if (!elementType)
return rewriter.notifyMatchFailure(op, "type conversion failed");
@ -345,7 +346,7 @@ public:
LogicalResult
matchAndRewrite(spirv::AddressOfOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto dstType = typeConverter.convertType(op.getPointer().getType());
auto dstType = getTypeConverter()->convertType(op.getPointer().getType());
if (!dstType)
return rewriter.notifyMatchFailure(op, "type conversion failed");
rewriter.replaceOpWithNewOp<LLVM::AddressOfOp>(op, dstType,
@ -363,16 +364,16 @@ public:
matchAndRewrite(spirv::BitFieldInsertOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto srcType = op.getType();
auto dstType = typeConverter.convertType(srcType);
auto dstType = getTypeConverter()->convertType(srcType);
if (!dstType)
return rewriter.notifyMatchFailure(op, "type conversion failed");
Location loc = op.getLoc();
// Process `Offset` and `Count`: broadcast and extend/truncate if needed.
Value offset = processCountOrOffset(loc, op.getOffset(), srcType, dstType,
typeConverter, rewriter);
*getTypeConverter(), rewriter);
Value count = processCountOrOffset(loc, op.getCount(), srcType, dstType,
typeConverter, rewriter);
*getTypeConverter(), rewriter);
// Create a mask with bits set outside [Offset, Offset + Count - 1].
Value minusOne = createConstantAllBitsSet(loc, srcType, dstType, rewriter);
@ -410,7 +411,7 @@ public:
if (!isa<VectorType>(srcType) && !srcType.isIntOrFloat())
return failure();
auto dstType = typeConverter.convertType(srcType);
auto dstType = getTypeConverter()->convertType(srcType);
if (!dstType)
return rewriter.notifyMatchFailure(constOp, "type conversion failed");
@ -451,16 +452,16 @@ public:
matchAndRewrite(spirv::BitFieldSExtractOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto srcType = op.getType();
auto dstType = typeConverter.convertType(srcType);
auto dstType = getTypeConverter()->convertType(srcType);
if (!dstType)
return rewriter.notifyMatchFailure(op, "type conversion failed");
Location loc = op.getLoc();
// Process `Offset` and `Count`: broadcast and extend/truncate if needed.
Value offset = processCountOrOffset(loc, op.getOffset(), srcType, dstType,
typeConverter, rewriter);
*getTypeConverter(), rewriter);
Value count = processCountOrOffset(loc, op.getCount(), srcType, dstType,
typeConverter, rewriter);
*getTypeConverter(), rewriter);
// Create a constant that holds the size of the `Base`.
IntegerType integerType;
@ -504,16 +505,16 @@ public:
matchAndRewrite(spirv::BitFieldUExtractOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto srcType = op.getType();
auto dstType = typeConverter.convertType(srcType);
auto dstType = getTypeConverter()->convertType(srcType);
if (!dstType)
return rewriter.notifyMatchFailure(op, "type conversion failed");
Location loc = op.getLoc();
// Process `Offset` and `Count`: broadcast and extend/truncate if needed.
Value offset = processCountOrOffset(loc, op.getOffset(), srcType, dstType,
typeConverter, rewriter);
*getTypeConverter(), rewriter);
Value count = processCountOrOffset(loc, op.getCount(), srcType, dstType,
typeConverter, rewriter);
*getTypeConverter(), rewriter);
// Create a mask with bits set at [0, Count - 1].
Value minusOne = createConstantAllBitsSet(loc, srcType, dstType, rewriter);
@ -580,7 +581,7 @@ public:
LogicalResult
matchAndRewrite(spirv::CompositeExtractOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto dstType = this->typeConverter.convertType(op.getType());
auto dstType = this->getTypeConverter()->convertType(op.getType());
if (!dstType)
return rewriter.notifyMatchFailure(op, "type conversion failed");
@ -612,7 +613,7 @@ public:
LogicalResult
matchAndRewrite(spirv::CompositeInsertOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto dstType = this->typeConverter.convertType(op.getType());
auto dstType = this->getTypeConverter()->convertType(op.getType());
if (!dstType)
return rewriter.notifyMatchFailure(op, "type conversion failed");
@ -643,7 +644,7 @@ public:
LogicalResult
matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto dstType = this->typeConverter.convertType(op.getType());
auto dstType = this->getTypeConverter()->convertType(op.getType());
if (!dstType)
return rewriter.notifyMatchFailure(op, "type conversion failed");
rewriter.template replaceOpWithNewOp<LLVMOp>(
@ -749,7 +750,7 @@ public:
return failure();
auto srcType = cast<spirv::PointerType>(op.getType());
auto dstType = typeConverter.convertType(srcType.getPointeeType());
auto dstType = getTypeConverter()->convertType(srcType.getPointeeType());
if (!dstType)
return rewriter.notifyMatchFailure(op, "type conversion failed");
@ -810,7 +811,7 @@ public:
Type fromType = op.getOperand().getType();
Type toType = op.getType();
auto dstType = this->typeConverter.convertType(toType);
auto dstType = this->getTypeConverter()->convertType(toType);
if (!dstType)
return rewriter.notifyMatchFailure(op, "type conversion failed");
@ -846,7 +847,7 @@ public:
}
// Function returns a single result.
auto dstType = typeConverter.convertType(callOp.getType(0));
auto dstType = getTypeConverter()->convertType(callOp.getType(0));
if (!dstType)
return rewriter.notifyMatchFailure(callOp, "type conversion failed");
auto newOp = rewriter.replaceOpWithNewOp<LLVM::CallOp>(
@ -868,7 +869,7 @@ public:
matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto dstType = this->typeConverter.convertType(op.getType());
auto dstType = this->getTypeConverter()->convertType(op.getType());
if (!dstType)
return rewriter.notifyMatchFailure(op, "type conversion failed");
@ -888,7 +889,7 @@ public:
matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto dstType = this->typeConverter.convertType(op.getType());
auto dstType = this->getTypeConverter()->convertType(op.getType());
if (!dstType)
return rewriter.notifyMatchFailure(op, "type conversion failed");
@ -907,7 +908,7 @@ public:
matchAndRewrite(spirv::GLInverseSqrtOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto srcType = op.getType();
auto dstType = typeConverter.convertType(srcType);
auto dstType = getTypeConverter()->convertType(srcType);
if (!dstType)
return rewriter.notifyMatchFailure(op, "type conversion failed");
@ -930,7 +931,7 @@ public:
ConversionPatternRewriter &rewriter) const override {
if (!op.getMemoryAccess()) {
return replaceWithLoadOrStore(op, adaptor.getOperands(), rewriter,
this->typeConverter, /*alignment=*/0,
*this->getTypeConverter(), /*alignment=*/0,
/*isVolatile=*/false,
/*isNonTemporal=*/false);
}
@ -945,8 +946,8 @@ public:
bool isNonTemporal = memoryAccess == spirv::MemoryAccess::Nontemporal;
bool isVolatile = memoryAccess == spirv::MemoryAccess::Volatile;
return replaceWithLoadOrStore(op, adaptor.getOperands(), rewriter,
this->typeConverter, alignment, isVolatile,
isNonTemporal);
*this->getTypeConverter(), alignment,
isVolatile, isNonTemporal);
}
default:
// There is no support of other memory access attributes.
@ -965,7 +966,7 @@ public:
matchAndRewrite(SPIRVOp notOp, typename SPIRVOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto srcType = notOp.getType();
auto dstType = this->typeConverter.convertType(srcType);
auto dstType = this->getTypeConverter()->convertType(srcType);
if (!dstType)
return rewriter.notifyMatchFailure(notOp, "type conversion failed");
@ -1196,7 +1197,7 @@ public:
matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto dstType = this->typeConverter.convertType(op.getType());
auto dstType = this->getTypeConverter()->convertType(op.getType());
if (!dstType)
return rewriter.notifyMatchFailure(op, "type conversion failed");
@ -1247,7 +1248,7 @@ public:
LogicalResult
matchAndRewrite(spirv::GLTanOp tanOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto dstType = typeConverter.convertType(tanOp.getType());
auto dstType = getTypeConverter()->convertType(tanOp.getType());
if (!dstType)
return rewriter.notifyMatchFailure(tanOp, "type conversion failed");
@ -1273,7 +1274,7 @@ public:
matchAndRewrite(spirv::GLTanhOp tanhOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto srcType = tanhOp.getType();
auto dstType = typeConverter.convertType(srcType);
auto dstType = getTypeConverter()->convertType(srcType);
if (!dstType)
return rewriter.notifyMatchFailure(tanhOp, "type conversion failed");
@ -1307,21 +1308,21 @@ public:
if (init && !pointerTo.isIntOrFloat() && !isa<VectorType>(pointerTo))
return failure();
auto dstType = typeConverter.convertType(srcType);
auto dstType = getTypeConverter()->convertType(srcType);
if (!dstType)
return rewriter.notifyMatchFailure(varOp, "type conversion failed");
Location loc = varOp.getLoc();
Value size = createI32ConstantOf(loc, rewriter, 1);
if (!init) {
auto elementType = typeConverter.convertType(pointerTo);
auto elementType = getTypeConverter()->convertType(pointerTo);
if (!elementType)
return rewriter.notifyMatchFailure(varOp, "type conversion failed");
rewriter.replaceOpWithNewOp<LLVM::AllocaOp>(varOp, dstType, elementType,
size);
return success();
}
auto elementType = typeConverter.convertType(pointerTo);
auto elementType = getTypeConverter()->convertType(pointerTo);
if (!elementType)
return rewriter.notifyMatchFailure(varOp, "type conversion failed");
Value allocated =
@ -1344,7 +1345,7 @@ public:
LogicalResult
matchAndRewrite(spirv::BitcastOp bitcastOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto dstType = typeConverter.convertType(bitcastOp.getType());
auto dstType = getTypeConverter()->convertType(bitcastOp.getType());
if (!dstType)
return rewriter.notifyMatchFailure(bitcastOp, "type conversion failed");
@ -1377,7 +1378,7 @@ public:
auto funcType = funcOp.getFunctionType();
TypeConverter::SignatureConversion signatureConverter(
funcType.getNumInputs());
auto llvmType = typeConverter.convertFunctionSignature(
auto llvmType = getTypeConverter()->convertFunctionSignature(
funcType, /*isVariadic=*/false, /*useBarePtrCallConv=*/false,
signatureConverter);
if (!llvmType)
@ -1418,8 +1419,8 @@ public:
rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
newFuncOp.end());
if (failed(rewriter.convertRegionTypes(&newFuncOp.getBody(), typeConverter,
&signatureConverter))) {
if (failed(rewriter.convertRegionTypes(
&newFuncOp.getBody(), *getTypeConverter(), &signatureConverter))) {
return failure();
}
rewriter.eraseOp(funcOp);
@ -1474,7 +1475,7 @@ public:
return success();
}
auto dstType = typeConverter.convertType(op.getType());
auto dstType = getTypeConverter()->convertType(op.getType());
if (!dstType)
return rewriter.notifyMatchFailure(op, "type conversion failed");
auto scalarType = cast<VectorType>(dstType).getElementType();
@ -1535,7 +1536,7 @@ void mlir::populateSPIRVToLLVMTypeConversion(LLVMTypeConverter &typeConverter,
}
void mlir::populateSPIRVToLLVMConversionPatterns(
LLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
const LLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
spirv::ClientAPI clientAPI) {
patterns.add<
// Arithmetic ops
@ -1653,12 +1654,12 @@ void mlir::populateSPIRVToLLVMConversionPatterns(
}
void mlir::populateSPIRVToLLVMFunctionConversionPatterns(
LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) {
const LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) {
patterns.add<FuncConversionPattern>(patterns.getContext(), typeConverter);
}
void mlir::populateSPIRVToLLVMModuleConversionPatterns(
LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) {
const LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) {
patterns.add<ModuleConversionPattern>(patterns.getContext(), typeConverter);
}

View File

@ -35,7 +35,7 @@ namespace {
class TensorExtractPattern final
: public OpConversionPattern<tensor::ExtractOp> {
public:
TensorExtractPattern(TypeConverter &typeConverter, MLIRContext *context,
TensorExtractPattern(const TypeConverter &typeConverter, MLIRContext *context,
int64_t threshold, PatternBenefit benefit = 1)
: OpConversionPattern(typeConverter, context, benefit),
byteCountThreshold(threshold) {}
@ -103,8 +103,8 @@ private:
// Pattern population
//===----------------------------------------------------------------------===//
void mlir::populateTensorToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
int64_t byteCountThreshold,
void mlir::populateTensorToSPIRVPatterns(
const SPIRVTypeConverter &typeConverter, int64_t byteCountThreshold,
RewritePatternSet &patterns) {
patterns.add<TensorExtractPattern>(typeConverter, patterns.getContext(),
byteCountThreshold);

View File

@ -2588,7 +2588,7 @@ struct FFT2dConverter final : OpRewritePattern<FFT2dOp> {
} // namespace
void mlir::tosa::populateTosaToLinalgConversionPatterns(
TypeConverter &converter, RewritePatternSet *patterns) {
const TypeConverter &converter, RewritePatternSet *patterns) {
// We have multiple resize converters to handle degenerate cases.
patterns->add<GenericResizeConverter>(patterns->getContext(),

View File

@ -438,7 +438,7 @@ struct ConcatConverter : public OpConversionPattern<tosa::ConcatOp> {
} // namespace
void mlir::tosa::populateTosaToTensorConversionPatterns(
TypeConverter &converter, RewritePatternSet *patterns) {
const TypeConverter &converter, RewritePatternSet *patterns) {
patterns
->add<ConcatConverter, PadConverter, ReshapeConverter, SliceConverter>(
converter, patterns->getContext());

View File

@ -91,8 +91,8 @@ struct UBToLLVMConversionPass
// Pattern Population
//===----------------------------------------------------------------------===//
void mlir::ub::populateUBToLLVMConversionPatterns(LLVMTypeConverter &converter,
RewritePatternSet &patterns) {
void mlir::ub::populateUBToLLVMConversionPatterns(
const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
patterns.add<PoisonOpLowering>(converter);
}

View File

@ -79,6 +79,6 @@ struct UBToSPIRVConversionPass final
//===----------------------------------------------------------------------===//
void mlir::ub::populateUBToSPIRVConversionPatterns(
SPIRVTypeConverter &converter, RewritePatternSet &patterns) {
const SPIRVTypeConverter &converter, RewritePatternSet &patterns) {
patterns.add<PoisonOpLowering>(converter, patterns.getContext());
}

View File

@ -1881,7 +1881,7 @@ struct VectorStepOpLowering : public ConvertOpToLLVMPattern<vector::StepOp> {
/// Populate the given list with patterns that convert from Vector to LLVM.
void mlir::populateVectorToLLVMConversionPatterns(
LLVMTypeConverter &converter, RewritePatternSet &patterns,
const LLVMTypeConverter &converter, RewritePatternSet &patterns,
bool reassociateFPReductions, bool force32BitVectorIndices) {
MLIRContext *ctx = converter.getDialect()->getContext();
patterns.add<VectorFMAOpNDRewritePattern>(ctx);
@ -1909,7 +1909,7 @@ void mlir::populateVectorToLLVMConversionPatterns(
}
void mlir::populateVectorToLLVMMatrixConversionPatterns(
LLVMTypeConverter &converter, RewritePatternSet &patterns) {
const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
patterns.add<VectorMatmulOpConversion>(converter);
patterns.add<VectorFlatTransposeOpConversion>(converter);
}

View File

@ -950,8 +950,8 @@ struct VectorStepOpConvert final : OpConversionPattern<vector::StepOp> {
#define CL_FLOAT_MAX_MIN_OPS spirv::CLFMaxOp, spirv::CLFMinOp
#define GL_FLOAT_MAX_MIN_OPS spirv::GLFMaxOp, spirv::GLFMinOp
void mlir::populateVectorToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
RewritePatternSet &patterns) {
void mlir::populateVectorToSPIRVPatterns(
const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns) {
patterns.add<
VectorBitcastConvert, VectorBroadcastConvert,
VectorExtractElementOpConvert, VectorExtractOpConvert,

View File

@ -203,7 +203,7 @@ struct TileMulIConversion : public ConvertOpToLLVMPattern<TileMulIOp> {
} // namespace
void mlir::populateAMXLegalizeForLLVMExportPatterns(
LLVMTypeConverter &converter, RewritePatternSet &patterns) {
const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
patterns.add<TileZeroConversion, TileLoadConversion, TileStoreConversion,
TileMulFConversion, TileMulIConversion>(converter);
}

View File

@ -51,7 +51,8 @@ arith::NarrowTypeEmulationConverter::NarrowTypeEmulationConverter(
}
void arith::populateArithNarrowTypeEmulationPatterns(
NarrowTypeEmulationConverter &typeConverter, RewritePatternSet &patterns) {
const NarrowTypeEmulationConverter &typeConverter,
RewritePatternSet &patterns) {
// Populate `func.*` conversion patterns.
populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
typeConverter);

View File

@ -41,7 +41,7 @@ struct EmulateUnsupportedFloatsPass
};
struct EmulateFloatPattern final : ConversionPattern {
EmulateFloatPattern(TypeConverter &converter, MLIRContext *ctx)
EmulateFloatPattern(const TypeConverter &converter, MLIRContext *ctx)
: ConversionPattern(converter, Pattern::MatchAnyOpTypeTag(), 1, ctx) {}
LogicalResult match(Operation *op) const override;
@ -106,12 +106,12 @@ void mlir::arith::populateEmulateUnsupportedFloatsConversions(
}
void mlir::arith::populateEmulateUnsupportedFloatsPatterns(
RewritePatternSet &patterns, TypeConverter &converter) {
RewritePatternSet &patterns, const TypeConverter &converter) {
patterns.add<EmulateFloatPattern>(converter, patterns.getContext());
}
void mlir::arith::populateEmulateUnsupportedFloatsLegality(
ConversionTarget &target, TypeConverter &converter) {
ConversionTarget &target, const TypeConverter &converter) {
// Don't try to legalize functions and other ops that don't need expansion.
target.markUnknownOpDynamicallyLegal([](Operation *op) { return true; });
target.addDynamicallyLegalDialect<arith::ArithDialect>(

View File

@ -1122,7 +1122,8 @@ arith::WideIntEmulationConverter::WideIntEmulationConverter(
}
void arith::populateArithWideIntEmulationPatterns(
WideIntEmulationConverter &typeConverter, RewritePatternSet &patterns) {
const WideIntEmulationConverter &typeConverter,
RewritePatternSet &patterns) {
// Populate `func.*` conversion patterns.
populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
typeConverter);

View File

@ -200,7 +200,7 @@ struct CreateMaskOpLowering
/// Populate the given list with patterns that convert from ArmSVE to LLVM.
void mlir::populateArmSVELegalizeForLLVMExportPatterns(
LLVMTypeConverter &converter, RewritePatternSet &patterns) {
const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
// Populate conversion patterns
// clang-format off

View File

@ -130,7 +130,7 @@ public:
} // namespace
void mlir::bufferization::populateEliminateBufferizeMaterializationsPatterns(
BufferizeTypeConverter &typeConverter, RewritePatternSet &patterns) {
const BufferizeTypeConverter &typeConverter, RewritePatternSet &patterns) {
patterns.add<BufferizeToTensorOp, BufferizeToMemrefOp>(typeConverter,
patterns.getContext());
}

View File

@ -37,7 +37,7 @@ template <typename SourceOp>
class DecomposeCallGraphTypesOpConversionPattern
: public OpConversionPattern<SourceOp> {
public:
DecomposeCallGraphTypesOpConversionPattern(TypeConverter &typeConverter,
DecomposeCallGraphTypesOpConversionPattern(const TypeConverter &typeConverter,
MLIRContext *context,
ValueDecomposer &decomposer,
PatternBenefit benefit = 1)
@ -188,7 +188,7 @@ struct DecomposeCallGraphTypesForCallOp
} // namespace
void mlir::populateDecomposeCallGraphTypesPatterns(
MLIRContext *context, TypeConverter &typeConverter,
MLIRContext *context, const TypeConverter &typeConverter,
ValueDecomposer &decomposer, RewritePatternSet &patterns) {
patterns
.add<DecomposeCallGraphTypesForCallOp, DecomposeCallGraphTypesForFuncArgs,

View File

@ -44,7 +44,7 @@ struct CallOpSignatureConversion : public OpConversionPattern<CallOp> {
} // namespace
void mlir::populateCallOpTypeConversionPattern(RewritePatternSet &patterns,
TypeConverter &converter) {
const TypeConverter &converter) {
patterns.add<CallOpSignatureConversion>(converter, patterns.getContext());
}
@ -59,7 +59,7 @@ public:
BranchOpInterface>::OpInterfaceConversionPattern;
BranchOpInterfaceTypeConversion(
TypeConverter &typeConverter, MLIRContext *ctx,
const TypeConverter &typeConverter, MLIRContext *ctx,
function_ref<bool(BranchOpInterface, int)> shouldConvertBranchOperand)
: OpInterfaceConversionPattern(typeConverter, ctx, /*benefit=*/1),
shouldConvertBranchOperand(shouldConvertBranchOperand) {}
@ -115,14 +115,14 @@ public:
} // namespace
void mlir::populateBranchOpInterfaceTypeConversionPattern(
RewritePatternSet &patterns, TypeConverter &typeConverter,
RewritePatternSet &patterns, const TypeConverter &typeConverter,
function_ref<bool(BranchOpInterface, int)> shouldConvertBranchOperand) {
patterns.add<BranchOpInterfaceTypeConversion>(
typeConverter, patterns.getContext(), shouldConvertBranchOperand);
}
bool mlir::isLegalForBranchOpInterfaceTypeConversionPattern(
Operation *op, TypeConverter &converter) {
Operation *op, const TypeConverter &converter) {
// All successor operands of branch like operations must be rewritten.
if (auto branchOp = dyn_cast<BranchOpInterface>(op)) {
for (int p = 0, e = op->getBlock()->getNumSuccessors(); p < e; ++p) {
@ -137,14 +137,13 @@ bool mlir::isLegalForBranchOpInterfaceTypeConversionPattern(
return false;
}
void mlir::populateReturnOpTypeConversionPattern(RewritePatternSet &patterns,
TypeConverter &typeConverter) {
void mlir::populateReturnOpTypeConversionPattern(
RewritePatternSet &patterns, const TypeConverter &typeConverter) {
patterns.add<ReturnOpTypeConversion>(typeConverter, patterns.getContext());
}
bool mlir::isLegalForReturnOpTypeConversionPattern(Operation *op,
TypeConverter &converter,
bool returnOpAlwaysLegal) {
bool mlir::isLegalForReturnOpTypeConversionPattern(
Operation *op, const TypeConverter &converter, bool returnOpAlwaysLegal) {
// If this is a `return` and the user pass wants to convert/transform across
// function boundaries, then `converter` is invoked to check whether the
// `return` op is legal.

View File

@ -72,7 +72,7 @@ public:
namespace mlir {
void populateFuncTypeConversionPatterns(TypeConverter &typeConverter,
void populateFuncTypeConversionPatterns(const TypeConverter &typeConverter,
RewritePatternSet &patterns) {
patterns.add<
// clang-format off

Some files were not shown because too many files have changed in this diff.