[MLIR] Enable caching of type conversion in the presence of context-aware conversion (#158072)
The current implementation is overly conservative and disable all possible caching as soon as a context-aware conversion is present. However the context-aware conversion only affects subsequent converters, we can cache the previous ones. This isn't NFC because if fixed a bug where we use to unconditionally cache when using the `convertType(Type t, ...` API, while now all APIs are aware of context-aware conversions.
This commit is contained in:
parent
4ce74bfb4d
commit
b22f94dcc5
@ -285,9 +285,13 @@ conversions. A context-unaware conversion function converts a `Type` into a
|
||||
`Type`. A context-aware conversion function converts a `Value` into a type. The
|
||||
latter allows users to customize type conversion rules based on the IR.
|
||||
|
||||
Note: When there is at least one context-aware type conversion function, the
|
||||
result of type conversions can no longer be cached, which can increase
|
||||
compilation time. Use this feature with caution!
|
||||
Note: context-aware type conversion functions impact the ability of the
|
||||
framework to cache the conversion result. In the absence of a context-aware
|
||||
conversion, all context-free type conversions can be cached. Otherwise only the
|
||||
context-free conversions added after a context-aware type conversion can be
|
||||
cached (conversions are applied in reverse order).
|
||||
As such it is advised to add context-aware conversions as early as possible in
|
||||
the sequence of `addConversion` calls (so that they apply last).
|
||||
|
||||
A `materialization` describes how a list of values should be converted to a
|
||||
list of values with specific types. An important distinction from a
|
||||
|
||||
@ -433,7 +433,7 @@ private:
|
||||
std::is_same_v<T, Value>,
|
||||
ConversionCallbackFn>
|
||||
wrapCallback(FnT &&callback) {
|
||||
hasContextAwareTypeConversions = true;
|
||||
contextAwareTypeConversionsIndex = conversions.size();
|
||||
return [callback = std::forward<FnT>(callback)](
|
||||
PointerUnion<Type, Value> typeOrValue,
|
||||
SmallVectorImpl<Type> &results) -> std::optional<LogicalResult> {
|
||||
@ -555,6 +555,10 @@ private:
|
||||
cachedMultiConversions.clear();
|
||||
}
|
||||
|
||||
/// Internal implementation of the type conversion.
|
||||
LogicalResult convertTypeImpl(PointerUnion<Type, Value> t,
|
||||
SmallVectorImpl<Type> &results) const;
|
||||
|
||||
/// The set of registered conversion functions.
|
||||
SmallVector<ConversionCallbackFn, 4> conversions;
|
||||
|
||||
@ -575,10 +579,13 @@ private:
|
||||
mutable llvm::sys::SmartRWMutex<true> cacheMutex;
|
||||
/// Whether the type converter has context-aware type conversions. I.e.,
|
||||
/// conversion rules that depend on the SSA value instead of just the type.
|
||||
/// Type conversion caching is deactivated when there are context-aware
|
||||
/// conversions because the type converter may return different results for
|
||||
/// the same input type.
|
||||
bool hasContextAwareTypeConversions = false;
|
||||
/// We store here the index in the `conversions` vector of the last added
|
||||
/// context-aware conversion, if any. This is useful because we can't cache
|
||||
/// the result of type conversion happening after context-aware conversions,
|
||||
/// because the type converter may return different results for the same input
|
||||
/// type. This is why it is recommened to add context-aware conversions first,
|
||||
/// any context-free conversions after will benefit from caching.
|
||||
int contextAwareTypeConversionsIndex = -1;
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
@ -3406,10 +3406,19 @@ void TypeConverter::SignatureConversion::remapInput(
|
||||
SmallVector<Value, 1>(replacements.begin(), replacements.end())};
|
||||
}
|
||||
|
||||
LogicalResult TypeConverter::convertType(Type t,
|
||||
SmallVectorImpl<Type> &results) const {
|
||||
assert(t && "expected non-null type");
|
||||
|
||||
/// Internal implementation of the type conversion.
|
||||
/// This is used with either a Type or a Value as the first argument.
|
||||
/// - we can cache the context-free conversions until the last registered
|
||||
/// context-aware conversion.
|
||||
/// - we can't cache the result of type conversion happening after context-aware
|
||||
/// conversions, because the type converter may return different results for the
|
||||
/// same input type.
|
||||
LogicalResult
|
||||
TypeConverter::convertTypeImpl(PointerUnion<Type, Value> typeOrValue,
|
||||
SmallVectorImpl<Type> &results) const {
|
||||
assert(typeOrValue && "expected non-null type");
|
||||
Type t = (isa<Value>(typeOrValue)) ? cast<Value>(typeOrValue).getType()
|
||||
: cast<Type>(typeOrValue);
|
||||
{
|
||||
std::shared_lock<decltype(cacheMutex)> cacheReadLock(cacheMutex,
|
||||
std::defer_lock);
|
||||
@ -3431,52 +3440,53 @@ LogicalResult TypeConverter::convertType(Type t,
|
||||
// registered first.
|
||||
size_t currentCount = results.size();
|
||||
|
||||
// We can cache the context-free conversions until the last registered
|
||||
// context-aware conversion. But only if we're processing a Value right now.
|
||||
auto isCacheable = [&](int index) {
|
||||
int numberOfConversionsUntilContextAware =
|
||||
conversions.size() - 1 - contextAwareTypeConversionsIndex;
|
||||
return index < numberOfConversionsUntilContextAware;
|
||||
};
|
||||
|
||||
std::unique_lock<decltype(cacheMutex)> cacheWriteLock(cacheMutex,
|
||||
std::defer_lock);
|
||||
|
||||
for (const ConversionCallbackFn &converter : llvm::reverse(conversions)) {
|
||||
if (std::optional<LogicalResult> result = converter(t, results)) {
|
||||
if (t.getContext()->isMultithreadingEnabled())
|
||||
cacheWriteLock.lock();
|
||||
if (!succeeded(*result)) {
|
||||
assert(results.size() == currentCount &&
|
||||
"failed type conversion should not change results");
|
||||
cachedDirectConversions.try_emplace(t, nullptr);
|
||||
return failure();
|
||||
}
|
||||
auto newTypes = ArrayRef<Type>(results).drop_front(currentCount);
|
||||
if (newTypes.size() == 1)
|
||||
cachedDirectConversions.try_emplace(t, newTypes.front());
|
||||
else
|
||||
cachedMultiConversions.try_emplace(t, llvm::to_vector<2>(newTypes));
|
||||
return success();
|
||||
} else {
|
||||
for (auto indexedConverter : llvm::enumerate(llvm::reverse(conversions))) {
|
||||
const ConversionCallbackFn &converter = indexedConverter.value();
|
||||
std::optional<LogicalResult> result = converter(typeOrValue, results);
|
||||
if (!result) {
|
||||
assert(results.size() == currentCount &&
|
||||
"failed type conversion should not change results");
|
||||
continue;
|
||||
}
|
||||
if (!isCacheable(indexedConverter.index()))
|
||||
return success();
|
||||
if (t.getContext()->isMultithreadingEnabled())
|
||||
cacheWriteLock.lock();
|
||||
if (!succeeded(*result)) {
|
||||
assert(results.size() == currentCount &&
|
||||
"failed type conversion should not change results");
|
||||
cachedDirectConversions.try_emplace(t, nullptr);
|
||||
return failure();
|
||||
}
|
||||
auto newTypes = ArrayRef<Type>(results).drop_front(currentCount);
|
||||
if (newTypes.size() == 1)
|
||||
cachedDirectConversions.try_emplace(t, newTypes.front());
|
||||
else
|
||||
cachedMultiConversions.try_emplace(t, llvm::to_vector<2>(newTypes));
|
||||
return success();
|
||||
}
|
||||
return failure();
|
||||
}
|
||||
|
||||
LogicalResult TypeConverter::convertType(Type t,
|
||||
SmallVectorImpl<Type> &results) const {
|
||||
return convertTypeImpl(t, results);
|
||||
}
|
||||
|
||||
LogicalResult TypeConverter::convertType(Value v,
|
||||
SmallVectorImpl<Type> &results) const {
|
||||
assert(v && "expected non-null value");
|
||||
|
||||
// If this type converter does not have context-aware type conversions, call
|
||||
// the type-based overload, which has caching.
|
||||
if (!hasContextAwareTypeConversions)
|
||||
return convertType(v.getType(), results);
|
||||
|
||||
// Walk the added converters in reverse order to apply the most recently
|
||||
// registered first.
|
||||
for (const ConversionCallbackFn &converter : llvm::reverse(conversions)) {
|
||||
if (std::optional<LogicalResult> result = converter(v, results)) {
|
||||
if (!succeeded(*result))
|
||||
return failure();
|
||||
return success();
|
||||
}
|
||||
}
|
||||
return failure();
|
||||
return convertTypeImpl(v, results);
|
||||
}
|
||||
|
||||
Type TypeConverter::convertType(Type t) const {
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user