llvm-project/mlir/lib/Interfaces/Utils/MemorySlotUtils.cpp
Berke Ates b6e4d27c48
[MLIR][Mem2Reg] Extract shared utilities for PromotableRegionOpInterface (#188514)
The `PromotableRegionOpInterface` implementations use two helpers that
are likely useful for other dialects implementing this interface as
well:
- `updateTerminator`: Appends the reaching definition as an operand to a
block's terminator, falling back to a default when the block has no
entry (e.g. dead code).
- `replaceWithNewResults`: Clones an operation with additional result
types while preserving its regions, then replaces the original.

This PR extracts them into a common utility header so that downstream
dialects can reuse them directly.
I'm open to discussion about the location of these utilities.
2026-03-30 22:20:39 +02:00

52 lines
2.0 KiB
C++

//===- MemorySlotUtils.cpp - Utilities for MemorySlot interfaces ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements common utilities for implementing MemorySlot interfaces,
// in particular PromotableRegionOpInterface.
//
//===----------------------------------------------------------------------===//
#include "mlir/Interfaces/Utils/MemorySlotUtils.h"
using namespace mlir;
/// Appends one extra operand to the terminator of `block`: the value recorded
/// for `block` in `reachingAtBlockEnd`, or `defaultReachingDef` when the lookup
/// yields a null value (e.g. the block has no entry in the map).
void mlir::memoryslot::updateTerminator(
    Block *block, Value defaultReachingDef,
    const DenseMap<Block *, Value> &reachingAtBlockEnd) {
  // Start from the fallback and override it only when the map provides a
  // non-null value for this block.
  Value reachingDef = defaultReachingDef;
  if (Value recorded = reachingAtBlockEnd.lookup(block))
    reachingDef = recorded;

  // Inserting at index == current operand count places the value last.
  Operation *terminator = block->getTerminator();
  const unsigned appendIndex = terminator->getNumOperands();
  terminator->insertOperands(appendIndex, {reachingDef});
}
/// Builds a new operation with the same location, name, operands, attributes
/// and properties as `op`, but with `resultTypes` as its result types, moves
/// the bodies of all regions of `op` into it, redirects the uses of `op`'s
/// results to the leading results of the new operation, and erases `op`.
///
/// NOTE(review): `resultTypes` is presumably the original result types
/// followed by the newly added ones; only the first `op->getNumResults()`
/// results of the new operation are used as replacements. Confirm with callers.
///
/// Returns the newly created operation.
Operation *mlir::memoryslot::replaceWithNewResults(RewriterBase &rewriter,
                                                   Operation *op,
                                                   TypeRange resultTypes) {
  // Create the replacement right before `op`; the guard is scoped to this
  // function so the rewriter's insertion point change does not leak out.
  RewriterBase::InsertionGuard guard(rewriter);
  rewriter.setInsertionPoint(op);

  // Describe the replacement: identical to `op` apart from its result types.
  OperationState state(op->getLoc(), op->getName(), op->getOperands(),
                       resultTypes, op->getAttrs());
  state.propertiesAttr = op->getPropertiesAsAttribute();

  // Reserve one (initially empty) region per region of the original op.
  for (unsigned remaining = op->getNumRegions(); remaining != 0; --remaining)
    state.addRegion();

  Operation *replacement = rewriter.create(state);

  // Moving region bodies mutates both operations in place, so bracket the
  // transfer with modification notifications on each of them.
  rewriter.startOpModification(replacement);
  rewriter.startOpModification(op);
  unsigned regionIdx = 0;
  for (Region &sourceRegion : op->getRegions())
    replacement->getRegion(regionIdx++).takeBody(sourceRegion);
  rewriter.finalizeOpModification(op);
  rewriter.finalizeOpModification(replacement);

  // The old results map one-to-one onto the leading results of the new op.
  auto leadingResults =
      replacement->getResults().take_front(op->getNumResults());
  rewriter.replaceAllOpUsesWith(op, leadingResults);
  rewriter.eraseOp(op);

  return replacement;
}