From fc9feee1efa870cbcf17cfee9bc15e57ecbccfd7 Mon Sep 17 00:00:00 2001 From: neildhar Date: Mon, 19 Jan 2026 08:51:54 -0800 Subject: [PATCH] [mlir][LLVM] Speed up `extractvalue(insertvalue)` canonicalization (#176478) The current `ExtractValueOp::fold` implementation traverses the entire chain of `InsertValueOp`s leading up to it. This can be extremely slow if there are a huge number of `ExtractValueOp`s using values from the same chain. This PR improves this significantly in cases where a large number of the `ExtractValueOp`s are actually reading from the same `InsertValueOp`. That is, for patterns like: ``` %i0 = llvm.insertvalue %v0, %undef[0] %i1 = llvm.insertvalue %v1, %0[1] ... %i999 = llvm.insertvalue %v999, %998[999] %e0 = llvm.extractvalue %i999[0] %e1 = llvm.extractvalue %i999[1] ... %e999 = llvm.extractvalue %i999[999] ``` In such cases, the resolution can be performed much faster using a canonicalisation pattern on `InsertValueOp` that is applied to `%i999`, because we can collect all of the `ExtractValueOp`s that use it, and then do a single traversal of the chain to resolve them. With this change, most of the resolution happens as part of the `InsertValueOp` canonicalisation step, and there is much less work to do when `ExtractValueOp::fold` is run. Note that for now, this leaves the implementation of `ExtractValueOp` as-is so the order in which patterns are applied affects whether we see the speedup. This requires patterns to be applied in top-down order, which is the default for the canonicaliser pass. I am separately working on simplifying `ExtractValueOp::fold` to do less traversal, but that requires some care to ensure existing cases are not pessimised. --- mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td | 1 + mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp | 98 +++++++++++++++++++++ 2 files changed, 99 insertions(+) diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td index 971710fa3ee1..6789ca22c3d5 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td @@ -940,6 +940,7 @@ def LLVM_InsertValueOp : LLVM_Op< }]; let hasVerifier = 1; + let hasCanonicalizer = 1; string llvmInstName = "InsertValue"; string llvmBuilder = [{ diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp index f9162b35966c..6281f0d6e0b0 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -2027,6 +2027,104 @@ void ExtractValueOp::build(OpBuilder &builder, OperationState &state, // InsertValueOp //===----------------------------------------------------------------------===// +namespace { +/// Update any ExtractValueOps using a given InsertValueOp to instead read from +/// the closest InsertValueOp in the chain leading up to the current op that +/// writes to the same member. This traversal could be done entirely in +/// ExtractValueOp::fold, but doing it here significantly speeds things up +/// because we can handle several ExtractValueOps with a single traversal. +/// For instance, in this example: +/// %i0 = llvm.insertvalue %v0, %undef[0] +/// %i1 = llvm.insertvalue %v1, %0[1] +/// ... +/// %i999 = llvm.insertvalue %v999, %998[999] +/// %e0 = llvm.extractvalue %i999[0] +/// %e1 = llvm.extractvalue %i999[1] +/// ... +/// %e999 = llvm.extractvalue %i999[999] +/// Individually running the folder on each extractvalue would require +/// traversing the insertvalue chain 1000 times, but running this pattern on the +/// InsertValueOp would allow us to achieve the same result with a single +/// traversal. The resulting IR after this pattern will then be: +/// %i0 = llvm.insertvalue %v0, %undef[0] +/// %i1 = llvm.insertvalue %v1, %0[1] +/// ... +/// %i999 = llvm.insertvalue %v999, %998[999] +/// %e0 = llvm.extractvalue %i0[0] +/// %e1 = llvm.extractvalue %i1[1] +/// ... +/// %e999 = llvm.extractvalue %i999[999] +struct ResolveExtractValueSource : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(InsertValueOp insertOp, + PatternRewriter &rewriter) const override { + bool changed = false; + // Map each position in the top-level struct to the ExtractOps that read + // from it. For the example in the doc-comment above this map will be empty + // when we visit ops %i0 - %i998. For %i999, it will contain: + // 0 -> { %e0 }, 1 -> { %e1 }, ... 999-> { %e999 } + DenseMap> posToExtractOps; + auto insertBaseIdx = insertOp.getPosition()[0]; + for (auto &use : insertOp->getUses()) { + if (auto extractOp = dyn_cast(use.getOwner())) { + auto baseIdx = extractOp.getPosition()[0]; + // We can skip reads of the member that insertOp writes to since they + // will not be updated. + if (baseIdx == insertBaseIdx) + continue; + posToExtractOps[baseIdx].push_back(extractOp); + } + } + // Walk up the chain of insertions and try to resolve the remaining + // extractions that access the same member. + Value nextContainer = insertOp.getContainer(); + while (!posToExtractOps.empty()) { + auto curInsert = + dyn_cast_or_null(nextContainer.getDefiningOp()); + if (!curInsert) + break; + nextContainer = curInsert.getContainer(); + + // Check if any extractions read the member written by this insertion. + auto curInsertBaseIdx = curInsert.getPosition()[0]; + auto it = posToExtractOps.find(curInsertBaseIdx); + if (it == posToExtractOps.end()) + continue; + + // Update the ExtractOps to read from the current insertion. + for (auto &extractOp : it->second) { + rewriter.modifyOpInPlace(extractOp, [&] { + extractOp.getContainerMutable().assign(curInsert); + }); + } + // The entry should never be empty if it exists, so if we are at this + // point, set changed to true. + assert(!it->second.empty()); + changed |= true; + posToExtractOps.erase(it); + } + // There was no insertion along the chain that wrote the member accessed by + // these extracts. So we can update them to use the top of the chain. + for (auto &[baseIdx, extracts] : posToExtractOps) { + for (auto &extractOp : extracts) { + rewriter.modifyOpInPlace(extractOp, [&] { + extractOp.getContainerMutable().assign(nextContainer); + }); + } + assert(!extracts.empty() && "Empty list in map"); + changed = true; + } + return success(changed); + } +}; +} // namespace + +void InsertValueOp::getCanonicalizationPatterns(RewritePatternSet &patterns, + MLIRContext *context) { + patterns.add(context); +} + /// Infer the value type from the container type and position. static ParseResult parseInsertExtractValueElementType(AsmParser &parser, Type &valueType,