
The current implementation of getBackwardSlice will crash if a value in the dependency chain is a block argument whose parent operation has multiple regions or blocks. Crashing is bad: it prevents many analyses from using getBackwardSlice, and it causes existing users of getBackwardSlice to fail on IR with this property. This PR instead makes the analysis return a failure, rather than crash, in the cases where it cannot compute the full slice.
---------
Co-authored-by: Oleksandr "Alex" Zinenko <git@ozinenko.com>
297 lines
12 KiB
C++
297 lines
12 KiB
C++
//===- UseDefAnalysis.cpp - Analysis for Transitive UseDef chains ---------===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
//
|
|
// This file implements Analysis functions specific to slicing in Function.
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/Analysis/SliceAnalysis.h"
|
|
#include "mlir/Analysis/TopologicalSortUtils.h"
|
|
#include "mlir/IR/Block.h"
|
|
#include "mlir/IR/Operation.h"
|
|
#include "mlir/Interfaces/SideEffectInterfaces.h"
|
|
#include "mlir/Support/LLVM.h"
|
|
#include "llvm/ADT/STLExtras.h"
|
|
#include "llvm/ADT/SetVector.h"
|
|
#include "llvm/ADT/SmallPtrSet.h"
|
|
|
|
///
|
|
/// Implements Analysis functions specific to slicing in Function.
|
|
///
|
|
|
|
using namespace mlir;
|
|
|
|
static void
|
|
getForwardSliceImpl(Operation *op, SetVector<Operation *> *forwardSlice,
|
|
const SliceOptions::TransitiveFilter &filter = nullptr) {
|
|
if (!op)
|
|
return;
|
|
|
|
// Evaluate whether we should keep this use.
|
|
// This is useful in particular to implement scoping; i.e. return the
|
|
// transitive forwardSlice in the current scope.
|
|
if (filter && !filter(op))
|
|
return;
|
|
|
|
for (Region ®ion : op->getRegions())
|
|
for (Block &block : region)
|
|
for (Operation &blockOp : block)
|
|
if (forwardSlice->count(&blockOp) == 0)
|
|
getForwardSliceImpl(&blockOp, forwardSlice, filter);
|
|
for (Value result : op->getResults()) {
|
|
for (Operation *userOp : result.getUsers())
|
|
if (forwardSlice->count(userOp) == 0)
|
|
getForwardSliceImpl(userOp, forwardSlice, filter);
|
|
}
|
|
|
|
forwardSlice->insert(op);
|
|
}
|
|
|
|
/// Fills `forwardSlice` with the operations that transitively depend on `op`
/// (nested operations and users of its results), in topological order. The
/// root `op` is only kept when `options.inclusive` is set.
void mlir::getForwardSlice(Operation *op, SetVector<Operation *> *forwardSlice,
                           const ForwardSliceOptions &options) {
  getForwardSliceImpl(op, forwardSlice, options.filter);

  // An exclusive slice drops the root: it is the query, not part of the answer.
  if (!options.inclusive)
    forwardSlice->remove(op);

  // The traversal emitted operations in post-order, i.e. reversed topological
  // order. Drain the set and re-insert its contents back-to-front.
  SmallVector<Operation *, 0> postOrder(forwardSlice->takeVector());
  forwardSlice->insert(postOrder.rbegin(), postOrder.rend());
}
|
|
|
|
/// Fills `forwardSlice` with the operations that transitively depend on the
/// value `root`, starting from its direct users, in topological order.
void mlir::getForwardSlice(Value root, SetVector<Operation *> *forwardSlice,
                           const SliceOptions &options) {
  // Seed the traversal from each direct user of `root`.
  for (Operation *directUser : root.getUsers())
    getForwardSliceImpl(directUser, forwardSlice, options.filter);

  // The traversal emitted operations in post-order, i.e. reversed topological
  // order. Drain the set and re-insert its contents back-to-front.
  SmallVector<Operation *, 0> postOrder(forwardSlice->takeVector());
  forwardSlice->insert(postOrder.rbegin(), postOrder.rend());
}
|
|
|
|
/// Recursively collects into `backwardSlice` the operations that `op`
/// transitively depends on, in post-order (dependencies before their users),
/// and finally `op` itself. Null operations, operations isolated from above,
/// and operations rejected by `options.filter` are not added.
///
/// Dependencies are followed through:
///   - the defining operations of the operands of `op`;
///   - unless `options.omitUsesFromAbove`, the values used inside the regions
///     of `op` but defined outside of them;
///   - unless `options.omitBlockArguments`, the parent operation of a block
///     argument's owner block, when that parent has exactly one region with
///     exactly one block.
///
/// Returns failure if some value could not be traced (it is neither an
/// operation result nor a block argument). The traversal keeps going after a
/// failure, so `backwardSlice` still holds everything that could be computed.
static LogicalResult getBackwardSliceImpl(Operation *op,
                                          SetVector<Operation *> *backwardSlice,
                                          const BackwardSliceOptions &options) {
  if (!op || op->hasTrait<OpTrait::IsIsolatedFromAbove>())
    return success();

  // Evaluate whether we should keep this def.
  // This is useful in particular to implement scoping; i.e. return the
  // transitive backwardSlice in the current scope.
  if (options.filter && !options.filter(op))
    return success();

  // Follows the dependency represented by `value`. Producers that are already
  // in the slice, and block arguments that are deliberately not followed, are
  // successful no-ops; only untraceable values (or failures deeper in the
  // recursion) are reported as failure.
  auto processValue = [&](Value value) -> LogicalResult {
    if (Operation *definingOp = value.getDefiningOp()) {
      // Already collected, e.g. shared by several operands or reached through
      // a diamond in the use-def graph.
      if (backwardSlice->count(definingOp) != 0)
        return success();
      return getBackwardSliceImpl(definingOp, backwardSlice, options);
    }

    auto blockArg = dyn_cast<BlockArgument>(value);
    if (!blockArg)
      return failure();
    if (options.omitBlockArguments)
      return success();

    Operation *parentOp = blockArg.getOwner()->getParentOp();
    if (!parentOp || backwardSlice->count(parentOp) != 0)
      return success();
    // TODO: determine whether we want to recurse backward into the other
    // blocks of parentOp, which are not technically backward unless they flow
    // into us. For now, just bail: only single-region, single-block parents
    // are followed, and other parents end the traversal without an error.
    if (parentOp->getNumRegions() != 1 ||
        !llvm::hasSingleElement(parentOp->getRegion(0).getBlocks()))
      return success();
    return getBackwardSliceImpl(parentOp, backwardSlice, options);
  };

  bool sliceComplete = true;

  if (!options.omitUsesFromAbove) {
    for (Region &region : op->getRegions()) {
      // Walk this region recursively to collect the regions that descend from
      // this op's nested regions (inclusive).
      SmallPtrSet<Region *, 4> descendents;
      region.walk(
          [&](Region *childRegion) { descendents.insert(childRegion); });
      // Follow every value used inside the region but defined outside of it.
      region.walk([&](Operation *nestedOp) {
        for (OpOperand &operand : nestedOp->getOpOperands()) {
          Value used = operand.get();
          if (descendents.contains(used.getParentRegion()))
            continue;
          if (processValue(used).failed())
            sliceComplete = false;
        }
      });
    }
  }

  for (Value operand : op->getOperands())
    if (processValue(operand).failed())
      sliceComplete = false;

  // Inserted last: all of its dependencies precede it (post-order).
  backwardSlice->insert(op);
  return success(sliceComplete);
}
|
|
|
|
/// Fills `backwardSlice` with the operations that `op` transitively depends
/// on. The root `op` is only kept when `options.inclusive` is set. Returns
/// whatever the underlying traversal reports; the root is removed (for
/// exclusive slices) regardless of that outcome.
LogicalResult mlir::getBackwardSlice(Operation *op,
                                     SetVector<Operation *> *backwardSlice,
                                     const BackwardSliceOptions &options) {
  LogicalResult status = getBackwardSliceImpl(op, backwardSlice, options);

  // An exclusive slice drops the root: it is the query, not a dependency.
  if (!options.inclusive)
    backwardSlice->remove(op);

  return status;
}
|
|
|
|
/// Fills `backwardSlice` with the operations that the value `root`
/// transitively depends on. An operation result is sliced from its defining
/// operation; a block argument is sliced from the operation owning its block.
LogicalResult mlir::getBackwardSlice(Value root,
                                     SetVector<Operation *> *backwardSlice,
                                     const BackwardSliceOptions &options) {
  Operation *sliceRoot = root.getDefiningOp();
  if (!sliceRoot)
    sliceRoot = cast<BlockArgument>(root).getOwner()->getParentOp();
  return getBackwardSlice(sliceRoot, backwardSlice, options);
}
|
|
|
|
/// Returns the closure of `op` under backward and forward slicing, sorted
/// topologically. Starting from `op`, the backward and forward slices of every
/// operation already in the result are added until no new operation appears.
SetVector<Operation *>
mlir::getSlice(Operation *op, const BackwardSliceOptions &backwardSliceOptions,
               const ForwardSliceOptions &forwardSliceOptions) {
  SetVector<Operation *> slice;
  slice.insert(op);

  unsigned currentIndex = 0;
  SetVector<Operation *> backwardSlice;
  SetVector<Operation *> forwardSlice;
  // `slice` grows while it is being processed, so walk it by index.
  while (currentIndex != slice.size()) {
    Operation *currentOp = slice[currentIndex];
    // Compute and insert the backwardSlice starting from currentOp.
    backwardSlice.clear();
    LogicalResult result =
        getBackwardSlice(currentOp, &backwardSlice, backwardSliceOptions);
    assert(result.succeeded() && "expected to compute a full backward slice");
    // Keep `result` used when assertions are compiled out (NDEBUG).
    (void)result;
    slice.insert_range(backwardSlice);

    // Compute and insert the forwardSlice starting from currentOp.
    forwardSlice.clear();
    getForwardSlice(currentOp, &forwardSlice, forwardSliceOptions);
    slice.insert_range(forwardSlice);
    ++currentIndex;
  }
  return topologicalSort(slice);
}
|
|
|
|
/// Returns true if `value` (transitively) depends on iteration-carried values
|
|
/// of the given `ancestorOp`.
|
|
static bool dependsOnCarriedVals(Value value,
|
|
ArrayRef<BlockArgument> iterCarriedArgs,
|
|
Operation *ancestorOp) {
|
|
// Compute the backward slice of the value.
|
|
SetVector<Operation *> slice;
|
|
BackwardSliceOptions sliceOptions;
|
|
sliceOptions.filter = [&](Operation *op) {
|
|
return !ancestorOp->isAncestor(op);
|
|
};
|
|
LogicalResult result = getBackwardSlice(value, &slice, sliceOptions);
|
|
assert(result.succeeded());
|
|
|
|
// Check that none of the operands of the operations in the backward slice are
|
|
// loop iteration arguments, and neither is the value itself.
|
|
SmallPtrSet<Value, 8> iterCarriedValSet(llvm::from_range, iterCarriedArgs);
|
|
if (iterCarriedValSet.contains(value))
|
|
return true;
|
|
|
|
for (Operation *op : slice)
|
|
for (Value operand : op->getOperands())
|
|
if (iterCarriedValSet.contains(operand))
|
|
return true;
|
|
|
|
return false;
|
|
}
|
|
|
|
/// Utility to match a generic reduction given a list of iteration-carried
|
|
/// arguments, `iterCarriedArgs` and the position of the potential reduction
|
|
/// argument within the list, `redPos`. If a reduction is matched, returns the
|
|
/// reduced value and the topologically-sorted list of combiner operations
|
|
/// involved in the reduction. Otherwise, returns a null value.
|
|
///
|
|
/// The matching algorithm relies on the following invariants, which are subject
|
|
/// to change:
|
|
/// 1. The first combiner operation must be a binary operation with the
|
|
/// iteration-carried value and the reduced value as operands.
|
|
/// 2. The iteration-carried value and combiner operations must be side
|
|
/// effect-free, have single result and a single use.
|
|
/// 3. Combiner operations must be immediately nested in the region op
|
|
/// performing the reduction.
|
|
/// 4. Reduction def-use chain must end in a terminator op that yields the
|
|
/// next iteration/output values in the same order as the iteration-carried
|
|
/// values in `iterCarriedArgs`.
|
|
/// 5. `iterCarriedArgs` must contain all the iteration-carried/output values
|
|
/// of the region op performing the reduction.
|
|
///
|
|
/// This utility is generic enough to detect reductions involving multiple
|
|
/// combiner operations (disabled for now) across multiple dialects, including
|
|
/// Linalg, Affine and SCF. For the sake of genericity, it does not return
|
|
/// specific enum values for the combiner operations since its goal is also
|
|
/// matching reductions without pre-defined semantics in core MLIR. It's up to
|
|
/// each client to make sense out of the list of combiner operations. It's also
|
|
/// up to each client to check for additional invariants on the expected
|
|
/// reductions not covered by this generic matching.
|
|
Value mlir::matchReduction(ArrayRef<BlockArgument> iterCarriedArgs,
|
|
unsigned redPos,
|
|
SmallVectorImpl<Operation *> &combinerOps) {
|
|
assert(redPos < iterCarriedArgs.size() && "'redPos' is out of bounds");
|
|
|
|
BlockArgument redCarriedVal = iterCarriedArgs[redPos];
|
|
if (!redCarriedVal.hasOneUse())
|
|
return nullptr;
|
|
|
|
// For now, the first combiner op must be a binary op.
|
|
Operation *combinerOp = *redCarriedVal.getUsers().begin();
|
|
if (combinerOp->getNumOperands() != 2)
|
|
return nullptr;
|
|
Value reducedVal = combinerOp->getOperand(0) == redCarriedVal
|
|
? combinerOp->getOperand(1)
|
|
: combinerOp->getOperand(0);
|
|
|
|
Operation *redRegionOp =
|
|
iterCarriedArgs.front().getOwner()->getParent()->getParentOp();
|
|
if (dependsOnCarriedVals(reducedVal, iterCarriedArgs, redRegionOp))
|
|
return nullptr;
|
|
|
|
// Traverse the def-use chain starting from the first combiner op until a
|
|
// terminator is found. Gather all the combiner ops along the way in
|
|
// topological order.
|
|
while (!combinerOp->mightHaveTrait<OpTrait::IsTerminator>()) {
|
|
if (!isMemoryEffectFree(combinerOp) || combinerOp->getNumResults() != 1 ||
|
|
!combinerOp->hasOneUse() || combinerOp->getParentOp() != redRegionOp)
|
|
return nullptr;
|
|
|
|
combinerOps.push_back(combinerOp);
|
|
combinerOp = *combinerOp->getUsers().begin();
|
|
}
|
|
|
|
// Limit matching to single combiner op until we can properly test reductions
|
|
// involving multiple combiners.
|
|
if (combinerOps.size() != 1)
|
|
return nullptr;
|
|
|
|
// Check that the yielded value is in the same position as in
|
|
// `iterCarriedArgs`.
|
|
Operation *terminatorOp = combinerOp;
|
|
if (terminatorOp->getOperand(redPos) != combinerOps.back()->getResults()[0])
|
|
return nullptr;
|
|
|
|
return reducedVal;
|
|
}
|