[MLIR][Mem2Reg] Extract shared utilities for PromotableRegionOpInterface (#188514)

The `PromotableRegionOpInterface` implementations use two helpers that
are likely useful for other dialects implementing this interface as
well:
- `updateTerminator`: Appends the reaching definition as an operand to a
block's terminator, falling back to a default when the block has no
entry (e.g. dead code).
- `replaceWithNewResults`: Clones an operation with additional result
types while preserving its regions, then replaces the original.

This PR extracts them into a common utility header so that downstream
dialects can reuse them directly.
I'm open to discussion about the location of these utilities.
This commit is contained in:
Berke Ates 2026-03-30 22:20:39 +02:00 committed by GitHub
parent 06725d7ef5
commit b6e4d27c48
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 319 additions and 63 deletions

View File

@ -0,0 +1,36 @@
//===- MemorySlotUtils.h - Utilities for MemorySlot interfaces --*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file declares common utilities for implementing MemorySlot interfaces,
// in particular PromotableRegionOpInterface.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_INTERFACES_UTILS_MEMORYSLOTUTILS_H
#define MLIR_INTERFACES_UTILS_MEMORYSLOTUTILS_H
#include "mlir/IR/PatternMatch.h"
namespace mlir {
namespace memoryslot {
/// Appends the reaching definition for the given block as an operand to its
/// terminator. If the block has no entry in `reachingAtBlockEnd` (e.g. dead
/// code or the region does not use the slot), `defaultReachingDef` is used.
void updateTerminator(Block *block, Value defaultReachingDef,
const DenseMap<Block *, Value> &reachingAtBlockEnd);
/// Creates a shallow copy of an operation with new result types, moving the
/// regions out of the original operation and deleting the original operation.
Operation *replaceWithNewResults(RewriterBase &rewriter, Operation *op,
TypeRange resultTypes);
} // namespace memoryslot
} // namespace mlir
#endif // MLIR_INTERFACES_UTILS_MEMORYSLOTUTILS_H

View File

@ -17,6 +17,7 @@ add_mlir_dialect_library(MLIRSCFDialect
MLIRFunctionInterfaces
MLIRIR
MLIRLoopLikeInterface
MLIRMemorySlotUtils
MLIRSideEffectInterfaces
MLIRTensorDialect
MLIRValueBoundsOpInterface

View File

@ -7,54 +7,11 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Interfaces/Utils/MemorySlotUtils.h"
using namespace mlir;
using namespace mlir::scf;
//===----------------------------------------------------------------------===//
// Helper functions
//===----------------------------------------------------------------------===//
/// Adds the corresponding reaching definition to the terminator of the block if
/// the terminator is of the provided type.
template <typename TermTy>
static void
updateTerminator(Block *block, Value defaultReachingDef,
const llvm::DenseMap<Block *, Value> &reachingAtBlockEnd) {
Operation *terminator = block->getTerminator();
if (!isa<TermTy>(terminator))
return;
Value blockReachingDef = reachingAtBlockEnd.lookup(block);
if (!blockReachingDef) {
// Block is dead code or the region is not using the slot, so we use the
// default provided reaching definition.
blockReachingDef = defaultReachingDef;
}
terminator->insertOperands(terminator->getNumOperands(), {blockReachingDef});
}
/// Creates a shallow copy of an operation with new result types, moving the
/// regions out of the original operation and deleting the original operation.
static Operation *replaceWithNewResults(RewriterBase &rewriter, Operation *op,
TypeRange resultTypes) {
RewriterBase::InsertionGuard guard(rewriter);
rewriter.setInsertionPoint(op);
Operation *newOp =
mlir::cloneWithoutRegions(rewriter, op, resultTypes, op->getOperands());
rewriter.startOpModification(newOp);
rewriter.startOpModification(op);
for (unsigned int i : llvm::seq(op->getNumRegions()))
newOp->getRegion(i).takeBody(op->getRegion(i));
rewriter.finalizeOpModification(op);
rewriter.finalizeOpModification(newOp);
SmallVector<Value> replacementValues(newOp->getResults().drop_back());
rewriter.replaceAllOpUsesWith(op, replacementValues);
rewriter.eraseOp(op);
return newOp;
}
//===----------------------------------------------------------------------===//
// ExecuteRegionOp
//===----------------------------------------------------------------------===//
@ -80,14 +37,15 @@ Value ExecuteRegionOp::finalizePromotion(
// Update the yield terminators to return the newly defined reaching
// definition.
for (Block &block : getRegion().getBlocks())
updateTerminator<YieldOp>(&block, reachingDef, reachingAtBlockEnd);
if (isa<YieldOp>(block.getTerminator()))
memoryslot::updateTerminator(&block, reachingDef, reachingAtBlockEnd);
SmallVector<Type> resultTypes(getResultTypes());
resultTypes.push_back(slot.elemType);
IRRewriter rewriter(builder);
Operation *newOp =
replaceWithNewResults(rewriter, getOperation(), resultTypes);
memoryslot::replaceWithNewResults(rewriter, getOperation(), resultTypes);
return newOp->getResults().back();
}
@ -123,14 +81,14 @@ Value ForOp::finalizePromotion(
// Update the yield terminator to return the newly defined reaching
// definition.
updateTerminator<YieldOp>(getBody(), reachingDef, reachingAtBlockEnd);
memoryslot::updateTerminator(getBody(), reachingDef, reachingAtBlockEnd);
SmallVector<Type> resultTypes(getResultTypes());
resultTypes.push_back(slot.elemType);
IRRewriter rewriter(builder);
Operation *newOp =
replaceWithNewResults(rewriter, getOperation(), resultTypes);
memoryslot::replaceWithNewResults(rewriter, getOperation(), resultTypes);
return newOp->getResults().back();
}
@ -187,11 +145,11 @@ Value IfOp::finalizePromotion(
// Update the yield terminators to return the newly defined reaching
// definition.
updateTerminator<YieldOp>(&getThenRegion().back(), reachingDef,
reachingAtBlockEnd);
memoryslot::updateTerminator(&getThenRegion().back(), reachingDef,
reachingAtBlockEnd);
if (getElseRegion().hasOneBlock()) {
updateTerminator<YieldOp>(&getElseRegion().back(), reachingDef,
reachingAtBlockEnd);
memoryslot::updateTerminator(&getElseRegion().back(), reachingDef,
reachingAtBlockEnd);
} else {
OpBuilder::InsertionGuard guard(rewriter);
rewriter.createBlock(&getElseRegion());
@ -202,7 +160,7 @@ Value IfOp::finalizePromotion(
resultTypes.push_back(slot.elemType);
Operation *newOp =
replaceWithNewResults(rewriter, getOperation(), resultTypes);
memoryslot::replaceWithNewResults(rewriter, getOperation(), resultTypes);
return newOp->getResults().back();
}
@ -234,17 +192,17 @@ Value IndexSwitchOp::finalizePromotion(
// Update the yield terminators to return the newly defined reaching
// definition.
updateTerminator<YieldOp>(&getDefaultRegion().back(), reachingDef,
reachingAtBlockEnd);
memoryslot::updateTerminator(&getDefaultRegion().back(), reachingDef,
reachingAtBlockEnd);
for (Region &caseRegion : getCaseRegions())
updateTerminator<YieldOp>(&caseRegion.back(), reachingDef,
reachingAtBlockEnd);
memoryslot::updateTerminator(&caseRegion.back(), reachingDef,
reachingAtBlockEnd);
SmallVector<Type> resultTypes(getResultTypes());
resultTypes.push_back(slot.elemType);
Operation *newOp =
replaceWithNewResults(rewriter, getOperation(), resultTypes);
memoryslot::replaceWithNewResults(rewriter, getOperation(), resultTypes);
return newOp->getResults().back();
}
@ -339,10 +297,10 @@ Value WhileOp::finalizePromotion(
// Update the yield terminators to return the newly defined reaching
// definition.
updateTerminator<ConditionOp>(&getBefore().back(),
getBefore().getArguments().back(),
reachingAtBlockEnd);
updateTerminator<YieldOp>(
memoryslot::updateTerminator(&getBefore().back(),
getBefore().getArguments().back(),
reachingAtBlockEnd);
memoryslot::updateTerminator(
&getAfter().back(), getAfter().getArguments().back(), reachingAtBlockEnd);
SmallVector<Type> resultTypes(getResultTypes());
@ -350,6 +308,6 @@ Value WhileOp::finalizePromotion(
IRRewriter rewriter(builder);
Operation *newOp =
replaceWithNewResults(rewriter, getOperation(), resultTypes);
memoryslot::replaceWithNewResults(rewriter, getOperation(), resultTypes);
return newOp->getResults().back();
}

View File

@ -1,3 +1,8 @@
set(LLVM_OPTIONAL_SOURCES
InferIntRangeCommon.cpp
MemorySlotUtils.cpp
)
add_mlir_library(MLIRInferIntRangeCommon
InferIntRangeCommon.cpp
@ -12,3 +17,13 @@ add_mlir_library(MLIRInferIntRangeCommon
MLIRInferIntRangeInterface
MLIRIR
)
add_mlir_library(MLIRMemorySlotUtils
MemorySlotUtils.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Interfaces/Utils
LINK_LIBS PUBLIC
MLIRIR
)

View File

@ -0,0 +1,51 @@
//===- MemorySlotUtils.cpp - Utilities for MemorySlot interfaces ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements common utilities for implementing MemorySlot interfaces,
// in particular PromotableRegionOpInterface.
//
//===----------------------------------------------------------------------===//
#include "mlir/Interfaces/Utils/MemorySlotUtils.h"
using namespace mlir;
void mlir::memoryslot::updateTerminator(
Block *block, Value defaultReachingDef,
const DenseMap<Block *, Value> &reachingAtBlockEnd) {
Value blockReachingDef = reachingAtBlockEnd.lookup(block);
if (!blockReachingDef)
blockReachingDef = defaultReachingDef;
Operation *terminator = block->getTerminator();
terminator->insertOperands(terminator->getNumOperands(), {blockReachingDef});
}
Operation *mlir::memoryslot::replaceWithNewResults(RewriterBase &rewriter,
Operation *op,
TypeRange resultTypes) {
RewriterBase::InsertionGuard guard(rewriter);
rewriter.setInsertionPoint(op);
OperationState state(op->getLoc(), op->getName(), op->getOperands(),
resultTypes, op->getAttrs());
state.propertiesAttr = op->getPropertiesAsAttribute();
unsigned numRegions = op->getNumRegions();
for (unsigned i = 0; i < numRegions; ++i)
state.addRegion();
Operation *newOp = rewriter.create(state);
rewriter.startOpModification(newOp);
rewriter.startOpModification(op);
for (unsigned i = 0; i < numRegions; ++i)
newOp->getRegion(i).takeBody(op->getRegion(i));
rewriter.finalizeOpModification(op);
rewriter.finalizeOpModification(newOp);
rewriter.replaceAllOpUsesWith(
op, newOp->getResults().take_front(op->getNumResults()));
rewriter.eraseOp(op);
return newOp;
}

View File

@ -2,9 +2,14 @@ add_mlir_unittest(MLIRInterfacesTests
ControlFlowInterfacesTest.cpp
DataLayoutInterfacesTest.cpp
InferIntRangeInterfaceTest.cpp
MemorySlotUtilsTest.cpp
SideEffectInterfacesTest.cpp
InferTypeOpInterfaceTest.cpp
DEPENDS
MLIRTestInterfaceIncGen
)
target_include_directories(MLIRInterfacesTests PRIVATE "${MLIR_BINARY_DIR}/test/lib/Dialect/Test")
mlir_target_link_libraries(MLIRInterfacesTests
PRIVATE
@ -15,6 +20,8 @@ mlir_target_link_libraries(MLIRInterfacesTests
MLIRFuncDialect
MLIRInferIntRangeInterface
MLIRInferTypeOpInterface
MLIRMemorySlotUtils
MLIRParser
MLIRSideEffectInterfaces
)
target_link_libraries(MLIRInterfacesTests PRIVATE MLIRTestDialect)

View File

@ -0,0 +1,188 @@
//===- MemorySlotUtilsTest.cpp - MemorySlot utility tests -----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Interfaces/Utils/MemorySlotUtils.h"
#include "../../test/lib/Dialect/Test/TestDialect.h"
#include "../../test/lib/Dialect/Test/TestOps.h"
#include "mlir/Parser/Parser.h"
#include "gtest/gtest.h"
using namespace mlir;
static Operation *createOp(MLIRContext &ctx, StringRef name,
TypeRange resultTypes = {},
unsigned numRegions = 0) {
return Operation::create(UnknownLoc::get(&ctx), OperationName(name, &ctx),
resultTypes, /*operands=*/{}, NamedAttrList(),
/*properties=*/nullptr, /*successors=*/{},
numRegions);
}
//===----------------------------------------------------------------------===//
// updateTerminator
//===----------------------------------------------------------------------===//
TEST(MemorySlotUtilsTest, UpdateTerminatorAppendsReachingDef) {
MLIRContext context;
context.allowUnregisteredDialects();
auto i32Ty = IntegerType::get(&context, 32);
Operation *outerOp = createOp(context, "foo.outer", {}, /*numRegions=*/1);
Block *block = new Block();
outerOp->getRegion(0).push_back(block);
Operation *defOp = createOp(context, "foo.def", {i32Ty});
block->push_back(defOp);
Operation *otherDefOp = createOp(context, "foo.other_def", {i32Ty});
block->push_back(otherDefOp);
Operation *terminator = createOp(context, "foo.terminator");
block->push_back(terminator);
DenseMap<Block *, Value> reachingAtBlockEnd;
reachingAtBlockEnd[block] = defOp->getResult(0);
// Pass otherDefOp as default, which should not be used since the map has an
// entry for this block.
memoryslot::updateTerminator(block, otherDefOp->getResult(0),
reachingAtBlockEnd);
EXPECT_EQ(terminator->getNumOperands(), 1u);
EXPECT_EQ(terminator->getOperand(0), defOp->getResult(0));
outerOp->destroy();
}
TEST(MemorySlotUtilsTest, UpdateTerminatorUsesDefaultForMissingBlock) {
MLIRContext context;
context.allowUnregisteredDialects();
auto i32Ty = IntegerType::get(&context, 32);
Operation *outerOp = createOp(context, "foo.outer", {}, /*numRegions=*/1);
Block *block = new Block();
outerOp->getRegion(0).push_back(block);
Operation *defOp = createOp(context, "foo.def", {i32Ty});
block->push_back(defOp);
Operation *otherDefOp = createOp(context, "foo.other_def", {i32Ty});
block->push_back(otherDefOp);
Operation *terminator = createOp(context, "foo.terminator");
block->push_back(terminator);
// Empty map: the default (defOp) should be used.
DenseMap<Block *, Value> reachingAtBlockEnd;
memoryslot::updateTerminator(block, defOp->getResult(0), reachingAtBlockEnd);
EXPECT_EQ(terminator->getNumOperands(), 1u);
EXPECT_EQ(terminator->getOperand(0), defOp->getResult(0));
outerOp->destroy();
}
//===----------------------------------------------------------------------===//
// replaceWithNewResults
//===----------------------------------------------------------------------===//
TEST(MemorySlotUtilsTest, ReplaceWithNewResultsAddsResults) {
MLIRContext context;
context.allowUnregisteredDialects();
auto i32Ty = IntegerType::get(&context, 32);
auto i64Ty = IntegerType::get(&context, 64);
auto f32Ty = Float32Type::get(&context);
Operation *parent = createOp(context, "foo.parent", {}, /*numRegions=*/1);
Block *block = new Block();
parent->getRegion(0).push_back(block);
Operation *op = createOp(context, "foo.op", {i32Ty});
block->push_back(op);
Operation *terminator = createOp(context, "foo.terminator");
block->push_back(terminator);
// Add two new results (i64 and f32) on top of the original i32.
IRRewriter rewriter(&context);
Operation *newOp =
memoryslot::replaceWithNewResults(rewriter, op, {i32Ty, i64Ty, f32Ty});
EXPECT_EQ(newOp->getNumResults(), 3u);
EXPECT_EQ(newOp->getResult(0).getType(), i32Ty);
EXPECT_EQ(newOp->getResult(1).getType(), i64Ty);
EXPECT_EQ(newOp->getResult(2).getType(), f32Ty);
EXPECT_EQ(newOp->getName().getStringRef(), "foo.op");
parent->destroy();
}
TEST(MemorySlotUtilsTest, ReplaceWithNewResultsPreservesRegions) {
MLIRContext context;
context.allowUnregisteredDialects();
auto i32Ty = IntegerType::get(&context, 32);
Operation *parent = createOp(context, "foo.parent", {}, /*numRegions=*/1);
Block *block = new Block();
parent->getRegion(0).push_back(block);
Operation *op = createOp(context, "foo.region_op", {}, /*numRegions=*/1);
block->push_back(op);
Block *innerBlock = new Block();
op->getRegion(0).push_back(innerBlock);
Operation *innerOp = createOp(context, "foo.inner");
innerBlock->push_back(innerOp);
Operation *terminator = createOp(context, "foo.terminator");
block->push_back(terminator);
IRRewriter rewriter(&context);
Operation *newOp = memoryslot::replaceWithNewResults(rewriter, op, {i32Ty});
EXPECT_EQ(newOp->getNumRegions(), 1u);
EXPECT_FALSE(newOp->getRegion(0).empty());
Operation &movedInnerOp = newOp->getRegion(0).front().front();
EXPECT_EQ(movedInnerOp.getName().getStringRef(), "foo.inner");
parent->destroy();
}
TEST(MemorySlotUtilsTest, ReplaceWithNewResultsPreservesProperties) {
MLIRContext context;
context.loadDialect<test::TestDialect>();
const char *src = R"mlir(
"builtin.module"() ({
test.with_properties a = 42, b = "hello", c = "world",
flag = true, array = [1, 2, 3], array32 = [4, 5]
}) : () -> ()
)mlir";
auto module = parseSourceString<ModuleOp>(src, &context);
ASSERT_TRUE(!!module);
auto &opInModule = module->getBody()->front();
ASSERT_EQ(opInModule.getName().getStringRef(), "test.with_properties");
auto i32Ty = IntegerType::get(&context, 32);
IRRewriter rewriter(&context);
Operation *newOp =
memoryslot::replaceWithNewResults(rewriter, &opInModule, {i32Ty});
std::string newStr;
{
llvm::raw_string_ostream os(newStr);
newOp->print(os);
}
EXPECT_EQ(newOp->getNumResults(), 1u);
StringRef view(newStr);
EXPECT_TRUE(view.contains("a = 42"));
EXPECT_TRUE(view.contains("b = \"hello\""));
EXPECT_TRUE(view.contains("c = \"world\""));
EXPECT_TRUE(view.contains("flag = true"));
EXPECT_TRUE(view.contains("array<i64: 1, 2, 3>"));
EXPECT_TRUE(view.contains("array<i32: 4, 5>"));
}