llvm-project/mlir/test/lib/Analysis/DataFlow/TestSparseBackwardDataFlowAnalysis.cpp
Jeremy Kun b533b0ec34
Define a DataFlowSolver helper that loads sensible default analyses (#143415)
Cf. https://discourse.llvm.org/t/mlir-dead-code-analysis/67568/10

Custom analysis passes will not work properly unless both
DeadCodeAnalysis and SparseConstantPropagation are loaded to the
DataFlowSolver. This is intended behavior, but surprising to many users
as shown in the thread. In lieu of a longer-term fix (which I am not
knowledgeable enough to implement myself, yet), this commit adds a
helper function that loads these two analyses, as well as providing
breadcrumbs for an explanation of the problem. The existing places in
the codebase where these two analyses are loaded for the purpose of
running other unrelated analyses are replaced by the use of the helper.

---------

Co-authored-by: Jeremy Kun <j2kun@users.noreply.github.com>
Co-authored-by: Oleksandr "Alex" Zinenko <azinenko@amd.com>
2025-06-20 08:16:52 -07:00

219 lines
7.8 KiB
C++

//===- TestSparseBackwardDataFlowAnalysis.cpp - Test backward dataflow ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Analysis/DataFlow/SparseAnalysis.h"
#include "mlir/Analysis/DataFlow/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Pass/Pass.h"
using namespace mlir;
using namespace mlir::dataflow;
namespace {
/// Lattice value storing the set of memory resources (identified by string
/// tags) that something is written to.
struct WrittenToLatticeValue {
  /// Two values are equal when they hold the same writes in the same order.
  /// Declared `const` so it can be called on const operands and does not
  /// trigger reversed-operator ambiguity under C++20.
  bool operator==(const WrittenToLatticeValue &other) const {
    return this->writes == other.writes;
  }

  /// Meet is set union: the result holds every write recorded in `lhs`
  /// followed by any new ones from `rhs`.
  static WrittenToLatticeValue meet(const WrittenToLatticeValue &lhs,
                                    const WrittenToLatticeValue &rhs) {
    WrittenToLatticeValue res = lhs;
    (void)res.addWrites(rhs.writes);
    return res;
  }

  /// Required by `Lattice<T>`, but never expected to be reached by this
  /// backward test analysis.
  static WrittenToLatticeValue join(const WrittenToLatticeValue &lhs,
                                    const WrittenToLatticeValue &rhs) {
    llvm_unreachable("Join should not be triggered by this test");
  }

  /// Inserts every resource in `writes`. Returns `Change` iff at least one
  /// resource was not already present.
  ChangeResult addWrites(const SetVector<StringAttr> &writes) {
    // Use the container's own size type; `int` would narrow.
    size_t sizeBefore = this->writes.size();
    this->writes.insert_range(writes);
    return this->writes.size() == sizeBefore ? ChangeResult::NoChange
                                             : ChangeResult::Change;
  }

  /// Prints the resources as a space-separated, bracketed list: `[a b]`.
  void print(raw_ostream &os) const {
    os << "[";
    // `getValue()` yields a StringRef, avoiding a std::string per element.
    llvm::interleave(
        writes, os, [&](const StringAttr &a) { os << a.getValue(); }, " ");
    os << "]";
  }

  /// Removes every recorded write.
  void clear() { writes.clear(); }

  /// Recorded resources, kept in first-insertion order.
  SetVector<StringAttr> writes;
};
/// Sparse lattice element attached to a value: the set of memory resources
/// that the value itself, or anything derived from it, may be written to.
struct WrittenTo : public Lattice<WrittenToLatticeValue> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(WrittenTo)

  // Inherit every constructor of the generic lattice wrapper unchanged.
  using Lattice<WrittenToLatticeValue>::Lattice;
};
/// Backward sparse dataflow analysis. Walking the dataflow graph from uses back
/// to definitions, it annotates every value with all the memory resources that
/// the value (or anything derived from it) is eventually written to.
class WrittenToAnalysis : public SparseBackwardDataFlowAnalysis<WrittenTo> {
public:
  /// `assumeFuncWrites` controls how calls to external functions are handled
  /// in `visitExternalCall`.
  WrittenToAnalysis(DataFlowSolver &solver, SymbolTableCollection &symbolTable,
                    bool assumeFuncWrites)
      : SparseBackwardDataFlowAnalysis(solver, symbolTable),
        assumeFuncWrites(assumeFuncWrites) {}

  /// Transfer function for an ordinary operation.
  LogicalResult visitOperation(Operation *op, ArrayRef<WrittenTo *> operands,
                               ArrayRef<const WrittenTo *> results) override;

  /// Branch/call operand hooks; the implementations tag the operand's value
  /// with `brancharg<N>` / `callarg<N>` respectively.
  void visitBranchOperand(OpOperand &operand) override;
  void visitCallOperand(OpOperand &operand) override;

  /// Handles calls whose callee body is not available to the analysis.
  void visitExternalCall(CallOpInterface call, ArrayRef<WrittenTo *> operands,
                         ArrayRef<const WrittenTo *> results) override;

  /// The exit state records no writes at all.
  void setToExitState(WrittenTo *lattice) override {
    WrittenToLatticeValue &value = lattice->getValue();
    value.clear();
  }

private:
  /// When true, external calls are modelled as writing every argument.
  bool assumeFuncWrites;
};
/// Transfer function for a single operation, applied backwards.
///
/// A `memref.store` records a write of operand #0 (expected to be the stored
/// value) to the resource named by its `tag_name` attribute. If the attribute
/// is absent, the operation name is used instead. Every other operation is
/// handled conservatively: each operand is met (set-union) with the lattice of
/// each result.
LogicalResult
WrittenToAnalysis::visitOperation(Operation *op, ArrayRef<WrittenTo *> operands,
                                  ArrayRef<const WrittenTo *> results) {
  if (isa<memref::StoreOp>(op)) {
    // Without this fallback a store lacking `tag_name` would insert a null
    // StringAttr, which crashes when the lattice is printed. The fallback
    // mirrors the naming used for external calls.
    StringAttr name = op->getAttrOfType<StringAttr>("tag_name");
    if (!name)
      name = StringAttr::get(op->getContext(), op->getName().getStringRef());
    SetVector<StringAttr> newWrites;
    newWrites.insert(name);
    propagateIfChanged(operands[0],
                       operands[0]->getValue().addWrites(newWrites));
    return success();
  }

  // By default, every result of an op depends on every operand.
  for (const WrittenTo *r : results) {
    for (WrittenTo *operand : operands)
      meet(operand, *r);
    // Register the program point after `op` as a dependent of this result's
    // lattice.
    addDependency(const_cast<WrittenTo *>(r), getProgramPointAfter(op));
  }
  return success();
}
/// Tags the value used by a branch operand with "brancharg<N>", where N is the
/// operand's index on its owning operation.
void WrittenToAnalysis::visitBranchOperand(OpOperand &operand) {
  MLIRContext *ctx = operand.getOwner()->getContext();
  StringAttr tag = StringAttr::get(
      ctx, Twine("brancharg") + Twine(operand.getOperandNumber()));

  SetVector<StringAttr> tags;
  tags.insert(tag);

  WrittenTo *state = getLatticeElement(operand.get());
  ChangeResult changed = state->getValue().addWrites(tags);
  propagateIfChanged(state, changed);
}
/// Tags the value used by a call operand with "callarg<N>", where N is the
/// operand's index on its owning operation.
void WrittenToAnalysis::visitCallOperand(OpOperand &operand) {
  MLIRContext *ctx = operand.getOwner()->getContext();
  StringAttr tag = StringAttr::get(
      ctx, Twine("callarg") + Twine(operand.getOperandNumber()));

  SetVector<StringAttr> tags;
  tags.insert(tag);

  WrittenTo *state = getLatticeElement(operand.get());
  ChangeResult changed = state->getValue().addWrites(tags);
  propagateIfChanged(state, changed);
}
/// Models a call to a function whose body is not available.
///
/// If `assumeFuncWrites` is false, the base-class behavior is used. Otherwise
/// every argument is assumed to be written to a single resource. That resource
/// is named by the call's `tag_name` attribute, or by the call operation's
/// name when the attribute is absent.
void WrittenToAnalysis::visitExternalCall(CallOpInterface call,
                                          ArrayRef<WrittenTo *> operands,
                                          ArrayRef<const WrittenTo *> results) {
  if (!assumeFuncWrites) {
    SparseBackwardDataFlowAnalysis::visitExternalCall(call, operands, results);
    return;
  }

  // The resource name depends only on the call, not on the operand, so build
  // it (and the write set) once instead of once per operand.
  StringAttr name = call->getAttrOfType<StringAttr>("tag_name");
  if (!name)
    name = StringAttr::get(call->getContext(),
                           call.getOperation()->getName().getStringRef());
  SetVector<StringAttr> newWrites;
  newWrites.insert(name);

  for (WrittenTo *lattice : operands)
    propagateIfChanged(lattice, lattice->getValue().addWrites(newWrites));
}
} // end anonymous namespace
namespace {
/// Test pass (`test-written-to`) that runs `WrittenToAnalysis` on the target
/// operation. For every nested operation carrying a `tag` string attribute, it
/// prints the lattice of each operand and result to stdout.
struct TestWrittenToPass
    : public PassWrapper<TestWrittenToPass, OperationPass<>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestWrittenToPass)

  TestWrittenToPass() = default;
  /// Copies the option values from `other` into this instance's options.
  TestWrittenToPass(const TestWrittenToPass &other) : PassWrapper(other) {
    interprocedural = other.interprocedural;
    assumeFuncWrites = other.assumeFuncWrites;
  }

  StringRef getArgument() const override { return "test-written-to"; }

  Option<bool> interprocedural{
      *this, "interprocedural", llvm::cl::init(true),
      llvm::cl::desc("perform interprocedural analysis")};
  Option<bool> assumeFuncWrites{
      *this, "assume-func-writes", llvm::cl::init(false),
      llvm::cl::desc(
          "assume external functions have write effect on all arguments")};

  void runOnOperation() override {
    Operation *root = getOperation();
    SymbolTableCollection symbolTable;

    DataFlowSolver solver(DataFlowConfig().setInterprocedural(interprocedural));
    // Custom analyses need dead code analysis and sparse constant propagation
    // loaded alongside them; this helper loads both.
    loadBaselineAnalyses(solver);
    solver.load<WrittenToAnalysis>(symbolTable, assumeFuncWrites);
    if (failed(solver.initializeAndRun(root)))
      return signalPassFailure();

    raw_ostream &os = llvm::outs();

    // Prints "<label><index>: [<writes>]\n" for one operand or result.
    auto printValue = [&](StringRef label, size_t index, auto value) {
      const WrittenTo *writtenTo = solver.lookupState<WrittenTo>(value);
      assert(writtenTo && "expected a sparse lattice");
      os << label << index << ": ";
      writtenTo->print(os);
      os << "\n";
    };

    // Named `tagged` to avoid shadowing `root` (the original shadowed `op`).
    root->walk([&](Operation *tagged) {
      auto tag = tagged->getAttrOfType<StringAttr>("tag");
      if (!tag)
        return;
      os << "test_tag: " << tag.getValue() << ":\n";
      for (auto [index, operand] : llvm::enumerate(tagged->getOperands()))
        printValue(" operand #", index, operand);
      for (auto [index, result] : llvm::enumerate(tagged->getResults()))
        printValue(" result #", index, result);
    });
  }
};
} // end anonymous namespace
namespace mlir::test {
/// Registers `TestWrittenToPass` (command-line argument `test-written-to`) by
/// constructing a `PassRegistration` for it.
void registerTestWrittenToPass() { PassRegistration<TestWrittenToPass>(); }
} // namespace mlir::test