[mlir][LLVM] Speed up extractvalue(insertvalue) canonicalization (#176478)

The current `ExtractValueOp::fold` implementation traverses the entire
chain of `InsertValueOp`s leading up to it. This can be extremely slow
if there are a huge number of `ExtractValueOp`s using values from the
same chain.

This PR improves this significantly in cases where a large number of the
`ExtractValueOp`s are actually reading from the same `InsertValueOp`.
That is, for patterns like:

```
%i0 = llvm.insertvalue %v0, %undef[0]
%i1 = llvm.insertvalue %v1, %0[1]
...
%i999 = llvm.insertvalue %v999, %998[999]
%e0 = llvm.extractvalue %i999[0]
%e1 = llvm.extractvalue %i999[1]
...
%e999 = llvm.extractvalue %i999[999]
```

In such cases, the resolution can be performed much faster using a
canonicalisation pattern on `InsertValueOp` that is applied to `%i999`,
because we can collect all of the `ExtractValueOp`s that use it, and
then do a single traversal of the chain to resolve them.

With this change, most of the resolution happens as part of the
`InsertValueOp` canonicalisation step, and there is much less work to do
when `ExtractValueOp::fold` is run.

Note that for now, this leaves the implementation of `ExtractValueOp`
as-is so the order in which patterns are applied affects whether we see
the speedup. This requires patterns to be applied in top-down order,
which is the default for the canonicaliser pass. I am separately working
on simplifying `ExtractValueOp::fold` to do less traversal, but that
requires some care to ensure existing cases are not pessimised.
This commit is contained in:
neildhar 2026-01-19 08:51:54 -08:00 committed by GitHub
parent 8168577795
commit fc9feee1ef
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 99 additions and 0 deletions

View File

@ -940,6 +940,7 @@ def LLVM_InsertValueOp : LLVM_Op<
}];
let hasVerifier = 1;
let hasCanonicalizer = 1;
string llvmInstName = "InsertValue";
string llvmBuilder = [{

View File

@ -2027,6 +2027,104 @@ void ExtractValueOp::build(OpBuilder &builder, OperationState &state,
// InsertValueOp
//===----------------------------------------------------------------------===//
namespace {
/// Update any ExtractValueOps using a given InsertValueOp to instead read from
/// the closest InsertValueOp in the chain leading up to the current op that
/// writes to the same member. This traversal could be done entirely in
/// ExtractValueOp::fold, but doing it here significantly speeds things up
/// because we can handle several ExtractValueOps with a single traversal.
/// For instance, in this example:
/// %i0 = llvm.insertvalue %v0, %undef[0]
/// %i1 = llvm.insertvalue %v1, %0[1]
/// ...
/// %i999 = llvm.insertvalue %v999, %998[999]
/// %e0 = llvm.extractvalue %i999[0]
/// %e1 = llvm.extractvalue %i999[1]
/// ...
/// %e999 = llvm.extractvalue %i999[999]
/// Individually running the folder on each extractvalue would require
/// traversing the insertvalue chain 1000 times, but running this pattern on the
/// InsertValueOp would allow us to achieve the same result with a single
/// traversal. The resulting IR after this pattern will then be:
/// %i0 = llvm.insertvalue %v0, %undef[0]
/// %i1 = llvm.insertvalue %v1, %0[1]
/// ...
/// %i999 = llvm.insertvalue %v999, %998[999]
/// %e0 = llvm.extractvalue %i0[0]
/// %e1 = llvm.extractvalue %i1[1]
/// ...
/// %e999 = llvm.extractvalue %i999[999]
struct ResolveExtractValueSource : public OpRewritePattern<InsertValueOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(InsertValueOp insertOp,
PatternRewriter &rewriter) const override {
bool changed = false;
// Map each position in the top-level struct to the ExtractOps that read
// from it. For the example in the doc-comment above this map will be empty
// when we visit ops %i0 - %i998. For %i999, it will contain:
// 0 -> { %e0 }, 1 -> { %e1 }, ... 999-> { %e999 }
DenseMap<int64_t, SmallVector<ExtractValueOp, 4>> posToExtractOps;
auto insertBaseIdx = insertOp.getPosition()[0];
for (auto &use : insertOp->getUses()) {
if (auto extractOp = dyn_cast<ExtractValueOp>(use.getOwner())) {
auto baseIdx = extractOp.getPosition()[0];
// We can skip reads of the member that insertOp writes to since they
// will not be updated.
if (baseIdx == insertBaseIdx)
continue;
posToExtractOps[baseIdx].push_back(extractOp);
}
}
// Walk up the chain of insertions and try to resolve the remaining
// extractions that access the same member.
Value nextContainer = insertOp.getContainer();
while (!posToExtractOps.empty()) {
auto curInsert =
dyn_cast_or_null<InsertValueOp>(nextContainer.getDefiningOp());
if (!curInsert)
break;
nextContainer = curInsert.getContainer();
// Check if any extractions read the member written by this insertion.
auto curInsertBaseIdx = curInsert.getPosition()[0];
auto it = posToExtractOps.find(curInsertBaseIdx);
if (it == posToExtractOps.end())
continue;
// Update the ExtractOps to read from the current insertion.
for (auto &extractOp : it->second) {
rewriter.modifyOpInPlace(extractOp, [&] {
extractOp.getContainerMutable().assign(curInsert);
});
}
// The entry should never be empty if it exists, so if we are at this
// point, set changed to true.
assert(!it->second.empty());
changed |= true;
posToExtractOps.erase(it);
}
// There was no insertion along the chain that wrote the member accessed by
// these extracts. So we can update them to use the top of the chain.
for (auto &[baseIdx, extracts] : posToExtractOps) {
for (auto &extractOp : extracts) {
rewriter.modifyOpInPlace(extractOp, [&] {
extractOp.getContainerMutable().assign(nextContainer);
});
}
assert(!extracts.empty() && "Empty list in map");
changed = true;
}
return success(changed);
}
};
} // namespace
void InsertValueOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
MLIRContext *context) {
patterns.add<ResolveExtractValueSource>(context);
}
/// Infer the value type from the container type and position.
static ParseResult
parseInsertExtractValueElementType(AsmParser &parser, Type &valueType,