diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td index 971710fa3ee1..6789ca22c3d5 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td @@ -940,6 +940,7 @@ def LLVM_InsertValueOp : LLVM_Op< }]; let hasVerifier = 1; + let hasCanonicalizer = 1; string llvmInstName = "InsertValue"; string llvmBuilder = [{ diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp index f9162b35966c..6281f0d6e0b0 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -2027,6 +2027,104 @@ void ExtractValueOp::build(OpBuilder &builder, OperationState &state, // InsertValueOp //===----------------------------------------------------------------------===// +namespace { +/// Update any ExtractValueOps using a given InsertValueOp to instead read from +/// the closest InsertValueOp in the chain leading up to the current op that +/// writes to the same member. This traversal could be done entirely in +/// ExtractValueOp::fold, but doing it here significantly speeds things up +/// because we can handle several ExtractValueOps with a single traversal. +/// For instance, in this example: +/// %i0 = llvm.insertvalue %v0, %undef[0] +/// %i1 = llvm.insertvalue %v1, %0[1] +/// ... +/// %i999 = llvm.insertvalue %v999, %998[999] +/// %e0 = llvm.extractvalue %i999[0] +/// %e1 = llvm.extractvalue %i999[1] +/// ... +/// %e999 = llvm.extractvalue %i999[999] +/// Individually running the folder on each extractvalue would require +/// traversing the insertvalue chain 1000 times, but running this pattern on the +/// InsertValueOp would allow us to achieve the same result with a single +/// traversal. The resulting IR after this pattern will then be: +/// %i0 = llvm.insertvalue %v0, %undef[0] +/// %i1 = llvm.insertvalue %v1, %0[1] +/// ... +/// %i999 = llvm.insertvalue %v999, %998[999] +/// %e0 = llvm.extractvalue %i0[0] +/// %e1 = llvm.extractvalue %i1[1] +/// ... +/// %e999 = llvm.extractvalue %i999[999] +struct ResolveExtractValueSource : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(InsertValueOp insertOp, + PatternRewriter &rewriter) const override { + bool changed = false; + // Map each position in the top-level struct to the ExtractOps that read + // from it. For the example in the doc-comment above this map will be empty + // when we visit ops %i0 - %i998. For %i999, it will contain: + // 0 -> { %e0 }, 1 -> { %e1 }, ... 999-> { %e999 } + DenseMap> posToExtractOps; + auto insertBaseIdx = insertOp.getPosition()[0]; + for (auto &use : insertOp->getUses()) { + if (auto extractOp = dyn_cast(use.getOwner())) { + auto baseIdx = extractOp.getPosition()[0]; + // We can skip reads of the member that insertOp writes to since they + // will not be updated. + if (baseIdx == insertBaseIdx) + continue; + posToExtractOps[baseIdx].push_back(extractOp); + } + } + // Walk up the chain of insertions and try to resolve the remaining + // extractions that access the same member. + Value nextContainer = insertOp.getContainer(); + while (!posToExtractOps.empty()) { + auto curInsert = + dyn_cast_or_null(nextContainer.getDefiningOp()); + if (!curInsert) + break; + nextContainer = curInsert.getContainer(); + + // Check if any extractions read the member written by this insertion. + auto curInsertBaseIdx = curInsert.getPosition()[0]; + auto it = posToExtractOps.find(curInsertBaseIdx); + if (it == posToExtractOps.end()) + continue; + + // Update the ExtractOps to read from the current insertion. + for (auto &extractOp : it->second) { + rewriter.modifyOpInPlace(extractOp, [&] { + extractOp.getContainerMutable().assign(curInsert); + }); + } + // The entry should never be empty if it exists, so if we are at this + // point, set changed to true. + assert(!it->second.empty()); + changed |= true; + posToExtractOps.erase(it); + } + // There was no insertion along the chain that wrote the member accessed by + // these extracts. So we can update them to use the top of the chain. + for (auto &[baseIdx, extracts] : posToExtractOps) { + for (auto &extractOp : extracts) { + rewriter.modifyOpInPlace(extractOp, [&] { + extractOp.getContainerMutable().assign(nextContainer); + }); + } + assert(!extracts.empty() && "Empty list in map"); + changed = true; + } + return success(changed); + } +}; +} // namespace + +void InsertValueOp::getCanonicalizationPatterns(RewritePatternSet &patterns, + MLIRContext *context) { + patterns.add(context); +} + /// Infer the value type from the container type and position. static ParseResult parseInsertExtractValueElementType(AsmParser &parser, Type &valueType,