[mlir][LLVM] Speed up extractvalue(insertvalue) canonicalization (#176478)
The current `ExtractValueOp::fold` implementation traverses the entire chain of `InsertValueOp`s leading up to it. This can be extremely slow if there are a huge number of `ExtractValueOp`s using values from the same chain.

This PR improves this significantly in cases where a large number of the `ExtractValueOp`s are actually reading from the same `InsertValueOp`. That is, for patterns like:

```
%i0 = llvm.insertvalue %v0, %undef[0]
%i1 = llvm.insertvalue %v1, %i0[1]
...
%i999 = llvm.insertvalue %v999, %i998[999]
%e0 = llvm.extractvalue %i999[0]
%e1 = llvm.extractvalue %i999[1]
...
%e999 = llvm.extractvalue %i999[999]
```

In such cases, the resolution can be performed much faster using a canonicalisation pattern on `InsertValueOp` that is applied to `%i999`, because we can collect all of the `ExtractValueOp`s that use it, and then do a single traversal of the chain to resolve them. With this change, most of the resolution happens as part of the `InsertValueOp` canonicalisation step, and there is much less work to do when `ExtractValueOp::fold` is run.

Note that for now, this leaves the implementation of `ExtractValueOp::fold` as-is, so the order in which patterns are applied affects whether we see the speedup: patterns must be applied in top-down order, which is the default for the canonicaliser pass. I am separately working on simplifying `ExtractValueOp::fold` to do less traversal, but that requires some care to ensure existing cases are not pessimised.
This commit is contained in:
parent
8168577795
commit
fc9feee1ef
@ -940,6 +940,7 @@ def LLVM_InsertValueOp : LLVM_Op<
|
||||
}];
|
||||
|
||||
let hasVerifier = 1;
|
||||
let hasCanonicalizer = 1;
|
||||
|
||||
string llvmInstName = "InsertValue";
|
||||
string llvmBuilder = [{
|
||||
|
||||
@ -2027,6 +2027,104 @@ void ExtractValueOp::build(OpBuilder &builder, OperationState &state,
|
||||
// InsertValueOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
namespace {
|
||||
/// Update any ExtractValueOps using a given InsertValueOp to instead read from
|
||||
/// the closest InsertValueOp in the chain leading up to the current op that
|
||||
/// writes to the same member. This traversal could be done entirely in
|
||||
/// ExtractValueOp::fold, but doing it here significantly speeds things up
|
||||
/// because we can handle several ExtractValueOps with a single traversal.
|
||||
/// For instance, in this example:
|
||||
/// %i0 = llvm.insertvalue %v0, %undef[0]
|
||||
/// %i1 = llvm.insertvalue %v1, %0[1]
|
||||
/// ...
|
||||
/// %i999 = llvm.insertvalue %v999, %998[999]
|
||||
/// %e0 = llvm.extractvalue %i999[0]
|
||||
/// %e1 = llvm.extractvalue %i999[1]
|
||||
/// ...
|
||||
/// %e999 = llvm.extractvalue %i999[999]
|
||||
/// Individually running the folder on each extractvalue would require
|
||||
/// traversing the insertvalue chain 1000 times, but running this pattern on the
|
||||
/// InsertValueOp would allow us to achieve the same result with a single
|
||||
/// traversal. The resulting IR after this pattern will then be:
|
||||
/// %i0 = llvm.insertvalue %v0, %undef[0]
|
||||
/// %i1 = llvm.insertvalue %v1, %0[1]
|
||||
/// ...
|
||||
/// %i999 = llvm.insertvalue %v999, %998[999]
|
||||
/// %e0 = llvm.extractvalue %i0[0]
|
||||
/// %e1 = llvm.extractvalue %i1[1]
|
||||
/// ...
|
||||
/// %e999 = llvm.extractvalue %i999[999]
|
||||
struct ResolveExtractValueSource : public OpRewritePattern<InsertValueOp> {
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(InsertValueOp insertOp,
|
||||
PatternRewriter &rewriter) const override {
|
||||
bool changed = false;
|
||||
// Map each position in the top-level struct to the ExtractOps that read
|
||||
// from it. For the example in the doc-comment above this map will be empty
|
||||
// when we visit ops %i0 - %i998. For %i999, it will contain:
|
||||
// 0 -> { %e0 }, 1 -> { %e1 }, ... 999-> { %e999 }
|
||||
DenseMap<int64_t, SmallVector<ExtractValueOp, 4>> posToExtractOps;
|
||||
auto insertBaseIdx = insertOp.getPosition()[0];
|
||||
for (auto &use : insertOp->getUses()) {
|
||||
if (auto extractOp = dyn_cast<ExtractValueOp>(use.getOwner())) {
|
||||
auto baseIdx = extractOp.getPosition()[0];
|
||||
// We can skip reads of the member that insertOp writes to since they
|
||||
// will not be updated.
|
||||
if (baseIdx == insertBaseIdx)
|
||||
continue;
|
||||
posToExtractOps[baseIdx].push_back(extractOp);
|
||||
}
|
||||
}
|
||||
// Walk up the chain of insertions and try to resolve the remaining
|
||||
// extractions that access the same member.
|
||||
Value nextContainer = insertOp.getContainer();
|
||||
while (!posToExtractOps.empty()) {
|
||||
auto curInsert =
|
||||
dyn_cast_or_null<InsertValueOp>(nextContainer.getDefiningOp());
|
||||
if (!curInsert)
|
||||
break;
|
||||
nextContainer = curInsert.getContainer();
|
||||
|
||||
// Check if any extractions read the member written by this insertion.
|
||||
auto curInsertBaseIdx = curInsert.getPosition()[0];
|
||||
auto it = posToExtractOps.find(curInsertBaseIdx);
|
||||
if (it == posToExtractOps.end())
|
||||
continue;
|
||||
|
||||
// Update the ExtractOps to read from the current insertion.
|
||||
for (auto &extractOp : it->second) {
|
||||
rewriter.modifyOpInPlace(extractOp, [&] {
|
||||
extractOp.getContainerMutable().assign(curInsert);
|
||||
});
|
||||
}
|
||||
// The entry should never be empty if it exists, so if we are at this
|
||||
// point, set changed to true.
|
||||
assert(!it->second.empty());
|
||||
changed |= true;
|
||||
posToExtractOps.erase(it);
|
||||
}
|
||||
// There was no insertion along the chain that wrote the member accessed by
|
||||
// these extracts. So we can update them to use the top of the chain.
|
||||
for (auto &[baseIdx, extracts] : posToExtractOps) {
|
||||
for (auto &extractOp : extracts) {
|
||||
rewriter.modifyOpInPlace(extractOp, [&] {
|
||||
extractOp.getContainerMutable().assign(nextContainer);
|
||||
});
|
||||
}
|
||||
assert(!extracts.empty() && "Empty list in map");
|
||||
changed = true;
|
||||
}
|
||||
return success(changed);
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
/// Adds the canonicalization patterns of InsertValueOp to `patterns`. The only
/// pattern registered here is ResolveExtractValueSource, constructed with
/// `context`.
void InsertValueOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                MLIRContext *context) {
  patterns.add<ResolveExtractValueSource>(context);
}
|
||||
|
||||
/// Infer the value type from the container type and position.
|
||||
static ParseResult
|
||||
parseInsertExtractValueElementType(AsmParser &parser, Type &valueType,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user