[mlir][Transforms][NFC] Rename MaterializationCallbackFn (#138814)

There are two kind of materialization callbacks: one for target
materializations and one for source materializations. The callback type
for target materializations is `TargetMaterializationCallbackFn`. This
commit renames the one for source materializations from
`MaterializationCallbackFn` to `SourceMaterializationCallbackFn`, for
consistency.

There used to be a single callback type for both kind of
materializations, but the materialization function signatures have
changed over time.

Also clean up a few places in the documentation that still referred to
argument materializations.
This commit is contained in:
Matthias Springer 2025-05-08 08:22:38 +02:00 committed by GitHub
parent df4eac2f8b
commit fc8484f0e3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 13 additions and 14 deletions

View File

@ -338,7 +338,7 @@ class TypeConverter {
typename T = typename llvm::function_traits<FnT>::template arg_t<1>> typename T = typename llvm::function_traits<FnT>::template arg_t<1>>
void addSourceMaterialization(FnT &&callback) { void addSourceMaterialization(FnT &&callback) {
sourceMaterializations.emplace_back( sourceMaterializations.emplace_back(
wrapMaterialization<T>(std::forward<FnT>(callback))); wrapSourceMaterialization<T>(std::forward<FnT>(callback)));
} }
/// This method registers a materialization that will be called when /// This method registers a materialization that will be called when
@ -362,7 +362,7 @@ class TypeConverter {
typename T = typename llvm::function_traits<FnT>::template arg_t<1>> typename T = typename llvm::function_traits<FnT>::template arg_t<1>>
void addTargetMaterialization(FnT &&callback) { void addTargetMaterialization(FnT &&callback) {
targetMaterializations.emplace_back( targetMaterializations.emplace_back(
wrapMaterialization<T>(std::forward<FnT>(callback))); wrapTargetMaterialization<T>(std::forward<FnT>(callback)));
} }
}; };
``` ```

View File

@ -186,7 +186,7 @@ public:
std::decay_t<FnT>>::template arg_t<1>> std::decay_t<FnT>>::template arg_t<1>>
void addSourceMaterialization(FnT &&callback) { void addSourceMaterialization(FnT &&callback) {
sourceMaterializations.emplace_back( sourceMaterializations.emplace_back(
wrapMaterialization<T>(std::forward<FnT>(callback))); wrapSourceMaterialization<T>(std::forward<FnT>(callback)));
} }
/// This method registers a materialization that will be called when /// This method registers a materialization that will be called when
@ -330,11 +330,10 @@ private:
using ConversionCallbackFn = std::function<std::optional<LogicalResult>( using ConversionCallbackFn = std::function<std::optional<LogicalResult>(
Type, SmallVectorImpl<Type> &)>; Type, SmallVectorImpl<Type> &)>;
/// The signature of the callback used to materialize a source/argument /// The signature of the callback used to materialize a source conversion.
/// conversion.
/// ///
/// Arguments: builder, result type, inputs, location /// Arguments: builder, result type, inputs, location
using MaterializationCallbackFn = using SourceMaterializationCallbackFn =
std::function<Value(OpBuilder &, Type, ValueRange, Location)>; std::function<Value(OpBuilder &, Type, ValueRange, Location)>;
/// The signature of the callback used to materialize a target conversion. /// The signature of the callback used to materialize a target conversion.
@ -387,12 +386,12 @@ private:
cachedMultiConversions.clear(); cachedMultiConversions.clear();
} }
/// Generate a wrapper for the given argument/source materialization /// Generate a wrapper for the given source materialization callback. The
/// callback. The callback may take any subclass of `Type` and the /// callback may take any subclass of `Type` and the wrapper will check for
/// wrapper will check for the target type to be of the expected class /// the target type to be of the expected class before calling the callback.
/// before calling the callback.
template <typename T, typename FnT> template <typename T, typename FnT>
MaterializationCallbackFn wrapMaterialization(FnT &&callback) const { SourceMaterializationCallbackFn
wrapSourceMaterialization(FnT &&callback) const {
return [callback = std::forward<FnT>(callback)]( return [callback = std::forward<FnT>(callback)](
OpBuilder &builder, Type resultType, ValueRange inputs, OpBuilder &builder, Type resultType, ValueRange inputs,
Location loc) -> Value { Location loc) -> Value {
@ -491,7 +490,7 @@ private:
SmallVector<ConversionCallbackFn, 4> conversions; SmallVector<ConversionCallbackFn, 4> conversions;
/// The list of registered materialization functions. /// The list of registered materialization functions.
SmallVector<MaterializationCallbackFn, 2> sourceMaterializations; SmallVector<SourceMaterializationCallbackFn, 2> sourceMaterializations;
SmallVector<TargetMaterializationCallbackFn, 2> targetMaterializations; SmallVector<TargetMaterializationCallbackFn, 2> targetMaterializations;
/// The list of registered type attribute conversion functions. /// The list of registered type attribute conversion functions.
@ -740,7 +739,7 @@ public:
/// ///
/// Optionally, a type converter can be provided to build materializations. /// Optionally, a type converter can be provided to build materializations.
/// Note: If no type converter was provided or the type converter does not /// Note: If no type converter was provided or the type converter does not
/// specify any suitable argument/target materialization rules, the dialect /// specify any suitable source/target materialization rules, the dialect
/// conversion may fail to legalize unresolved materializations. /// conversion may fail to legalize unresolved materializations.
Block * Block *
applySignatureConversion(Block *block, applySignatureConversion(Block *block,

View File

@ -2959,7 +2959,7 @@ TypeConverter::convertSignatureArgs(TypeRange types,
Value TypeConverter::materializeSourceConversion(OpBuilder &builder, Value TypeConverter::materializeSourceConversion(OpBuilder &builder,
Location loc, Type resultType, Location loc, Type resultType,
ValueRange inputs) const { ValueRange inputs) const {
for (const MaterializationCallbackFn &fn : for (const SourceMaterializationCallbackFn &fn :
llvm::reverse(sourceMaterializations)) llvm::reverse(sourceMaterializations))
if (Value result = fn(builder, resultType, inputs, loc)) if (Value result = fn(builder, resultType, inputs, loc))
return result; return result;