Allowing RDV to call getArgOperandsMutable() (#160415)
## Problem `RemoveDeadValues` can legally drop dead function arguments on private `func.func` callees. But call-sites to such functions aren't fixed if the call operation keeps its call arguments in a **segmented operand group** (i.e., uses `AttrSizedOperandSegments`), unless the call op implements `getArgOperandsMutable` and the RDV pass actually uses it. ## Fix When RDV decides to drop callee function args, it should, for each call-site that implements `CallOpInterface`, **shrink the call's argument segment** via `getArgOperandsMutable()` using the same dead-arg indices. This keeps both the flat operand list and the `operandSegmentSizes` attribute in sync (that's what `MutableOperandRange` does when bound to the segment). ## Note This change is a no-op for: * call ops without segment operands (they still get their flat operands erased via the generic path) * call ops whose callee args weren't dropped (public, external, non-`func.func`, unresolved symbol, etc.) * `llvm.call`/`llvm.invoke` (RDV doesn't drop `llvm.func` args) --------- Co-authored-by: Mehdi Amini <joker.eph@gmail.com>
This commit is contained in:
parent
acb826e64e
commit
3e746bd8fb
@ -88,6 +88,8 @@ struct FunctionToCleanUp {
|
||||
// One pending operand-cleanup entry: an operation together with the set of
// its operands that were found to be non-live.
struct OperationToCleanup {
|
||||
// The operation whose operands are to be cleaned up.
Operation *op;
|
||||
// One bit per operand of `op`; the set bits are handed to
// Operation::eraseOperands by the generic operand cleanup.
BitVector nonLive;
|
||||
// Callee function cached for call-site entries so the cleanup phase can look
// up the erased callee arguments without a symbol lookup; null otherwise.
Operation *callee =
|
||||
nullptr; // Optional: For CallOpInterface ops, stores the callee function
|
||||
};
|
||||
|
||||
struct BlockArgsToCleanup {
|
||||
@ -306,19 +308,19 @@ static void processFuncOp(FunctionOpInterface funcOp, Operation *module,
|
||||
nonLiveSet.insert(arg);
|
||||
}
|
||||
|
||||
// Do (2).
|
||||
// Do (2). (Skip creating generic operand cleanup entries for call ops.
|
||||
// Call arguments will be removed in the call-site specific segment-aware
|
||||
// cleanup, avoiding generic eraseOperands bitvector mechanics.)
|
||||
SymbolTable::UseRange uses = *funcOp.getSymbolUses(module);
|
||||
for (SymbolTable::SymbolUse use : uses) {
|
||||
Operation *callOp = use.getUser();
|
||||
assert(isa<CallOpInterface>(callOp) && "expected a call-like user");
|
||||
// The number of operands in the call op may not match the number of
|
||||
// arguments in the func op.
|
||||
BitVector nonLiveCallOperands(callOp->getNumOperands(), false);
|
||||
SmallVector<OpOperand *> callOpOperands =
|
||||
operandsToOpOperands(cast<CallOpInterface>(callOp).getArgOperands());
|
||||
for (int index : nonLiveArgs.set_bits())
|
||||
nonLiveCallOperands.set(callOpOperands[index]->getOperandNumber());
|
||||
cl.operands.push_back({callOp, nonLiveCallOperands});
|
||||
// Push an empty operand cleanup entry so that call-site specific logic in
|
||||
// cleanUpDeadVals runs (it keys off CallOpInterface). The BitVector is
|
||||
// intentionally all false to avoid generic erasure.
|
||||
// Store the funcOp as the callee to avoid expensive symbol lookup later.
|
||||
cl.operands.push_back({callOp, BitVector(callOp->getNumOperands(), false),
|
||||
funcOp.getOperation()});
|
||||
}
|
||||
|
||||
// Do (3).
|
||||
@ -746,6 +748,10 @@ static void cleanUpDeadVals(RDVFinalCleanupList &list) {
|
||||
|
||||
// 3. Functions
|
||||
LDBG() << "Cleaning up " << list.functions.size() << " functions";
|
||||
// Record which function arguments were erased so we can shrink call-site
|
||||
// argument segments for CallOpInterface operations (e.g. ops using
|
||||
// AttrSizedOperandSegments) in the next phase.
|
||||
DenseMap<Operation *, BitVector> erasedFuncArgs;
|
||||
for (auto &f : list.functions) {
|
||||
LDBG() << "Cleaning up function: " << f.funcOp.getOperation()->getName();
|
||||
LDBG() << " Erasing " << f.nonLiveArgs.count() << " non-live arguments";
|
||||
@ -754,17 +760,52 @@ static void cleanUpDeadVals(RDVFinalCleanupList &list) {
|
||||
// Some functions may not allow erasing arguments or results. These calls
|
||||
// return failure in such cases without modifying the function, so it's okay
|
||||
// to proceed.
|
||||
(void)f.funcOp.eraseArguments(f.nonLiveArgs);
|
||||
if (succeeded(f.funcOp.eraseArguments(f.nonLiveArgs))) {
|
||||
// Record only if we actually erased something.
|
||||
if (f.nonLiveArgs.any())
|
||||
erasedFuncArgs.try_emplace(f.funcOp.getOperation(), f.nonLiveArgs);
|
||||
}
|
||||
(void)f.funcOp.eraseResults(f.nonLiveRets);
|
||||
}
|
||||
|
||||
// 4. Operands
|
||||
LDBG() << "Cleaning up " << list.operands.size() << " operand lists";
|
||||
for (OperationToCleanup &o : list.operands) {
|
||||
if (o.op->getNumOperands() > 0) {
|
||||
LDBG() << "Erasing " << o.nonLive.count()
|
||||
<< " non-live operands from operation: "
|
||||
<< OpWithFlags(o.op, OpPrintingFlags().skipRegions());
|
||||
// Handle call-specific cleanup only when we have a cached callee reference.
|
||||
// This avoids expensive symbol lookup and is defensive against future
|
||||
// changes.
|
||||
bool handledAsCall = false;
|
||||
if (o.callee && isa<CallOpInterface>(o.op)) {
|
||||
auto call = cast<CallOpInterface>(o.op);
|
||||
auto it = erasedFuncArgs.find(o.callee);
|
||||
if (it != erasedFuncArgs.end()) {
|
||||
const BitVector &deadArgIdxs = it->second;
|
||||
MutableOperandRange args = call.getArgOperandsMutable();
|
||||
// First, erase the call arguments corresponding to erased callee
|
||||
// args. We iterate backwards to preserve indices.
|
||||
for (unsigned argIdx : llvm::reverse(deadArgIdxs.set_bits()))
|
||||
args.erase(argIdx);
|
||||
// If this operand cleanup entry also has a generic nonLive bitvector,
|
||||
// clear bits for call arguments we already erased above to avoid
|
||||
// double-erasing (which could impact other segments of ops with
|
||||
// AttrSizedOperandSegments).
|
||||
if (o.nonLive.any()) {
|
||||
// Map the argument logical index to the operand number(s) recorded.
|
||||
int operandOffset = call.getArgOperands().getBeginOperandIndex();
|
||||
for (int argIdx : deadArgIdxs.set_bits()) {
|
||||
int operandNumber = operandOffset + argIdx;
|
||||
if (operandNumber < static_cast<int>(o.nonLive.size()))
|
||||
o.nonLive.reset(operandNumber);
|
||||
}
|
||||
}
|
||||
handledAsCall = true;
|
||||
}
|
||||
}
|
||||
// Perform generic operand erasure for:
|
||||
// - Non-call operations
|
||||
// - Call operations without cached callee (where handledAsCall is false)
|
||||
// But skip call operations that were already handled via segment-aware path
|
||||
if (!handledAsCall && o.nonLive.any()) {
|
||||
o.op->eraseOperands(o.nonLive);
|
||||
}
|
||||
}
|
||||
|
||||
23
mlir/test/Transforms/remove-dead-values-call-segments.mlir
Normal file
23
mlir/test/Transforms/remove-dead-values-call-segments.mlir
Normal file
@ -0,0 +1,23 @@
|
||||
// RUN: mlir-opt --split-input-file --remove-dead-values --mlir-print-op-generic %s | FileCheck %s --check-prefix=GEN
|
||||
|
||||
// -----
|
||||
// Private callee: both args become dead after internal DCE; RDV drops callee
|
||||
// args and shrinks the *args* segment on the call-site to zero; sizes kept in
|
||||
// sync.
|
||||
|
||||
module {
|
||||
func.func private @callee(%x: i32, %y: i32) {
|
||||
// %y is never used; %u is itself unused, so %x has no live use either.
%u = arith.addi %x, %x : i32 // %y is dead
|
||||
return
|
||||
}
|
||||
|
||||
func.func @caller(%a: i32, %b: i32) {
|
||||
// args segment initially has 2 operands (segment sizes are
// [prefix, args, suffix] = [0, 2, 0]).
"test.call_with_segments"(%a, %b) { callee = @callee,
|
||||
operandSegmentSizes = array<i32: 0, 2, 0> } : (i32, i32) -> ()
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// GEN: "test.call_with_segments"() <{callee = @callee, operandSegmentSizes = array<i32: 0, 0, 0>}> : () -> ()
|
||||
// ^ args shrank from 2 -> 0
|
||||
@ -431,3 +431,47 @@ void TestDialect::getCanonicalizationPatterns(
|
||||
RewritePatternSet &results) const {
|
||||
results.add(&dialectCanonicalizationPattern);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TestCallWithSegmentsOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
// The op `test.call_with_segments` models a call-like operation whose operands
|
||||
// are divided into 3 variadic segments: `prefix`, `args`, and `suffix`.
|
||||
// Only the middle segment represents the actual call arguments. The op uses
|
||||
// the AttrSizedOperandSegments trait, so we can derive segment boundaries from
|
||||
// the generated `operandSegmentSizes` attribute. We provide custom helpers to
|
||||
// expose the logical call arguments as both a read-only range and a mutable
|
||||
// range bound to the proper segment so that insertion/erasure updates the
|
||||
// attribute automatically.
|
||||
|
||||
// Segment layout indices in the DenseI32ArrayAttr: [prefix, args, suffix].
|
||||
static constexpr unsigned kTestCallWithSegmentsArgsSegIndex = 1;
|
||||
|
||||
// Returns a read-only view of the operands in the `args` segment, i.e. the
// operands that follow the `prefix` segment in the flat operand list.
Operation::operand_range CallWithSegmentsOp::getArgOperands() {
|
||||
// Slice the flat operand list: skip the `prefix` operands, then take as many
|
||||
// operands as the `args` segment currently holds.
|
||||
return getOperation()->getOperands().slice(getPrefix().size(),
|
||||
getArgs().size());
|
||||
}
|
||||
|
||||
// Returns a mutable view of the `args` segment. The range is constructed with
// an operand-segment binding that names the segment-sizes attribute and the
// index of the `args` entry within it (kTestCallWithSegmentsArgsSegIndex).
MutableOperandRange CallWithSegmentsOp::getArgOperandsMutable() {
|
||||
Operation *op = getOperation();
|
||||
|
||||
// Obtain the canonical segment size attribute name for this op.
|
||||
auto segName =
|
||||
CallWithSegmentsOp::getOperandSegmentSizesAttrName(op->getName());
|
||||
auto sizesAttr = op->getAttrOfType<DenseI32ArrayAttr>(segName);
|
||||
assert(sizesAttr && "missing operandSegmentSizes attribute on op");
|
||||
|
||||
// Compute the start and length of the args segment from the prefix size and
|
||||
// args size stored in the attribute (layout: [prefix, args, suffix]).
|
||||
auto sizes = sizesAttr.asArrayRef();
|
||||
unsigned start = static_cast<unsigned>(sizes[0]); // prefix size
|
||||
unsigned len = static_cast<unsigned>(sizes[1]); // args size
|
||||
|
||||
// Pair the attribute (name + current value) with the index of the `args`
// entry so the returned range knows which segment it covers.
NamedAttribute segNamed(segName, sizesAttr);
|
||||
MutableOperandRange::OperandSegment binding{kTestCallWithSegmentsArgsSegIndex,
|
||||
segNamed};
|
||||
|
||||
return MutableOperandRange(op, start, len, {binding});
|
||||
}
|
||||
|
||||
@ -3746,4 +3746,47 @@ def TestOpWithSuccessorRef : TEST_Op<"dummy_op_with_successor_ref"> {
|
||||
}];
|
||||
}
|
||||
|
||||
// Test-only call-like op whose operands are split into three variadic
// segments (`prefix`, `args`, `suffix`); only `args` holds the call arguments.
def CallWithSegmentsOp : TEST_Op<"call_with_segments",
|
||||
[AttrSizedOperandSegments,
|
||||
DeclareOpInterfaceMethods<CallOpInterface>]> {
|
||||
let summary = "test call op with segmented args";
|
||||
let arguments = (ins
|
||||
FlatSymbolRefAttr:$callee,
|
||||
Variadic<AnyType>:$prefix, // non-arg segment (e.g., 'in')
|
||||
Variadic<AnyType>:$args, // <-- the call *arguments* segment
|
||||
Variadic<AnyType>:$suffix // non-arg segment (e.g., 'out')
|
||||
);
|
||||
let results = (outs);
|
||||
let assemblyFormat = [{
|
||||
$callee `(` $prefix `:` type($prefix) `)`
|
||||
`(` $args `:` type($args) `)`
|
||||
`(` $suffix `:` type($suffix) `)` attr-dict
|
||||
}];
|
||||
|
||||
// Provide stub implementations for the ArgAndResultAttrsOpInterface: the
// getters and removers return null attributes and the setters do nothing.
|
||||
let extraClassDeclaration = [{
|
||||
::mlir::ArrayAttr getArgAttrsAttr() { return {}; }
|
||||
::mlir::ArrayAttr getResAttrsAttr() { return {}; }
|
||||
void setArgAttrsAttr(::mlir::ArrayAttr) {}
|
||||
void setResAttrsAttr(::mlir::ArrayAttr) {}
|
||||
::mlir::Attribute removeArgAttrsAttr() { return {}; }
|
||||
::mlir::Attribute removeResAttrsAttr() { return {}; }
|
||||
}];
|
||||
|
||||
// Callee accessors: the callable is read from the `callee` attribute; setting
// a callable that is not a symbol reference removes that attribute.
let extraClassDefinition = [{
|
||||
::mlir::CallInterfaceCallable $cppClass::getCallableForCallee() {
|
||||
if (auto sym = (*this)->getAttrOfType<::mlir::SymbolRefAttr>("callee"))
|
||||
return ::mlir::CallInterfaceCallable(sym);
|
||||
return ::mlir::CallInterfaceCallable();
|
||||
}
|
||||
void $cppClass::setCalleeFromCallable(::mlir::CallInterfaceCallable callee) {
|
||||
if (auto sym = callee.dyn_cast<::mlir::SymbolRefAttr>())
|
||||
(*this)->setAttr("callee", sym);
|
||||
else
|
||||
(*this)->removeAttr("callee");
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
|
||||
#endif // TEST_OPS
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user