llvm-project/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp
fabrizio-indirli c5ae550344
[mlir][scf] Refactor and improve ParallelLoopFusion (#179284)
Refactor and extend the scf::ParallelLoopFusion pass:
- Refactor code, rename functions and add comments to improve
readability
- Make the dependency analysis safer by checking for read-after-write
dependencies also with vector.load/store & vector.transfer_read/write
ops, in addition to memref.load/store, and bail out when other
unsupported ops with memory effects are found.
- Extend the cases when the fusion is applied: allow fusing also when
one of the two loops reads/writes to memory through a full view/alias of
the buffer (read/written by the dual operation in the other loop) that
can be trivially resolved, including rank-reducing full subviews.
2026-02-18 14:54:12 +00:00

873 lines
36 KiB
C++

//===- ParallelLoopFusion.cpp - Code to perform loop fusion ---------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements loop fusion on parallel loops.
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/SCF/Transforms/Passes.h"
#include "mlir/Analysis/AliasAnalysis.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Index/IR/IndexOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SCF/Transforms/Transforms.h"
#include "mlir/Dialect/SCF/Utils/Utils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include <optional>
#include <tuple>
namespace mlir {
#define GEN_PASS_DEF_SCFPARALLELLOOPFUSION
#include "mlir/Dialect/SCF/Transforms/Passes.h.inc"
} // namespace mlir
using namespace mlir;
using namespace mlir::scf;
/// Return true if the body of `ploop` contains at least one nested
/// scf.parallel op (at any nesting depth).
static bool hasNestedParallelOp(ParallelOp ploop) {
  // The walk is interrupted as soon as the first nested ParallelOp is visited.
  return ploop.getBody()
      ->walk([](ParallelOp) { return WalkResult::interrupt(); })
      .wasInterrupted();
}
/// Return true if both parallel loops have the same number of induction
/// variables and use the very same SSA values as lower bounds, upper bounds
/// and steps.
static bool equalIterationSpaces(ParallelOp firstPloop,
                                 ParallelOp secondPloop) {
  if (firstPloop.getNumLoops() != secondPloop.getNumLoops())
    return false;
  // TODO: Extend this to support aliases and equal constants.
  auto sameValues = [](const OperandRange &lhs, const OperandRange &rhs) {
    return std::equal(lhs.begin(), lhs.end(), rhs.begin());
  };
  if (!sameValues(firstPloop.getLowerBound(), secondPloop.getLowerBound()))
    return false;
  if (!sameValues(firstPloop.getUpperBound(), secondPloop.getUpperBound()))
    return false;
  return sameValues(firstPloop.getStep(), secondPloop.getStep());
}
/// Check if both operations are the same type of memory write op and
/// write to the same memory location (same buffer and same indices).
///
/// Only memref.store, vector.transfer_write and vector.store are supported;
/// any other pair of distinct ops yields `false`. For the vector ops, the
/// properties that determine which elements are written (vector type, mask,
/// in_bounds, permutation map) must match as well.
static bool opsWriteSameMemLocation(Operation *op1, Operation *op2) {
  if (!op1 || !op2 || op1->getName() != op2->getName())
    return false;
  if (op1 == op2)
    return true;
  // Both ops have the same name here, so casting `op2` to the type of `op1`
  // inside each case is safe.
  return llvm::TypeSwitch<Operation *, bool>(op1)
      .Case([&](memref::StoreOp storeOp1) {
        auto storeOp2 = cast<memref::StoreOp>(op2);
        return (storeOp1.getMemRef() == storeOp2.getMemRef()) &&
               (storeOp1.getIndices() == storeOp2.getIndices());
      })
      .Case([&](vector::TransferWriteOp writeOp1) {
        auto writeOp2 = cast<vector::TransferWriteOp>(op2);
        // The permutation map decides which memory dimensions the vector
        // dimensions are written to: two writes with identical base, indices
        // and vector type but different maps touch different elements.
        return (writeOp1.getBase() == writeOp2.getBase()) &&
               (writeOp1.getIndices() == writeOp2.getIndices()) &&
               (writeOp1.getPermutationMap() ==
                writeOp2.getPermutationMap()) &&
               (writeOp1.getMask() == writeOp2.getMask()) &&
               (writeOp1.getValueToStore().getType() ==
                writeOp2.getValueToStore().getType()) &&
               (writeOp1.getInBounds() == writeOp2.getInBounds());
      })
      .Case([&](vector::StoreOp vecStoreOp1) {
        auto vecStoreOp2 = cast<vector::StoreOp>(op2);
        return (vecStoreOp1.getBase() == vecStoreOp2.getBase()) &&
               (vecStoreOp1.getIndices() == vecStoreOp2.getIndices()) &&
               (vecStoreOp1.getValueToStore().getType() ==
                vecStoreOp2.getValueToStore().getType()) &&
               (vecStoreOp1.getAlignment() == vecStoreOp2.getAlignment()) &&
               (vecStoreOp1.getNontemporal() ==
                vecStoreOp2.getNontemporal());
      })
      // Unsupported memory-writing ops are never considered identical.
      .Default([](Operation *) { return false; });
}
/// Return true if `val1` and `val2` can be considered the same value once the
/// induction variables of the two parallel loops are identified through
/// `loopsIVsMap`. Values match when they are identical, when one maps onto
/// the other, or when they are produced by memory-effect-free operations that
/// are structurally equivalent (ignoring locations) with operands matching
/// through the map.
static bool valsAreEquivalent(Value val1, Value val2,
                              const IRMapping &loopsIVsMap) {
  // True if the map sends one of the two values onto the other.
  auto mappedOntoEachOther = [&](Value lhs, Value rhs) {
    return loopsIVsMap.lookupOrDefault(lhs) == rhs ||
           loopsIVsMap.lookupOrDefault(rhs) == lhs;
  };
  if (val1 == val2 || mappedOntoEachOther(val1, val2))
    return true;

  // Otherwise compare the defining operations; block arguments cannot match.
  Operation *defOp1 = val1.getDefiningOp();
  Operation *defOp2 = val2.getDefiningOp();
  if (!defOp1 || !defOp2)
    return false;
  if (!(isMemoryEffectFree(defOp1) && isMemoryEffectFree(defOp2)))
    return false;
  return OperationEquivalence::isEquivalentTo(
      defOp1, defOp2,
      [&](Value v1, Value v2) { return success(mappedOntoEachOther(v1, v2)); },
      /*markEquivalent=*/nullptr, OperationEquivalence::Flags::IgnoreLocations);
}
/// If `expr` is defined as `base + C` or `C + base` for a constant integer C,
/// return C; otherwise return std::nullopt. The addition may be an
/// arith.addi, an index.add, or a single-result affine.apply over one dim and
/// no symbols whose expression is the sum of that dim and a constant.
/// Equivalence with `base` is evaluated through `loopsIVsMap`.
static std::optional<int64_t> getAddConstant(Value expr, Value base,
                                             const IRMapping &loopsIVsMap) {
  // Shared matcher for the binary add ops: `C + base` is tried first, then
  // `base + C`.
  auto matchBinaryAdd = [&](Value lhs, Value rhs) -> std::optional<int64_t> {
    if (std::optional<int64_t> lhsCst = getConstantIntValue(lhs);
        lhsCst && valsAreEquivalent(rhs, base, loopsIVsMap))
      return lhsCst;
    if (std::optional<int64_t> rhsCst = getConstantIntValue(rhs);
        rhsCst && valsAreEquivalent(lhs, base, loopsIVsMap))
      return rhsCst;
    return std::nullopt;
  };
  if (auto arithAdd = expr.getDefiningOp<arith::AddIOp>())
    return matchBinaryAdd(arithAdd.getLhs(), arithAdd.getRhs());
  if (auto indexAdd = expr.getDefiningOp<index::AddOp>())
    return matchBinaryAdd(indexAdd.getLhs(), indexAdd.getRhs());

  auto applyOp = expr.getDefiningOp<affine::AffineApplyOp>();
  if (!applyOp)
    return std::nullopt;

  // Only maps of the form (d0) -> (single result) are handled.
  AffineMap map = applyOp.getAffineMap();
  const bool isUnaryMap = map.getNumResults() == 1 && map.getNumDims() == 1 &&
                          map.getNumSymbols() == 0;
  if (!isUnaryMap)
    return std::nullopt;
  if (!valsAreEquivalent(applyOp.getOperand(0), base, loopsIVsMap))
    return std::nullopt;

  auto sum = dyn_cast<AffineBinaryOpExpr>(map.getResult(0));
  if (!sum || sum.getKind() != AffineExprKind::Add)
    return std::nullopt;
  AffineExpr lhs = sum.getLHS();
  AffineExpr rhs = sum.getRHS();
  if (auto cst = dyn_cast<AffineConstantExpr>(lhs);
      cst && isa<AffineDimExpr>(rhs))
    return cst.getValue();
  if (auto cst = dyn_cast<AffineConstantExpr>(rhs);
      cst && isa<AffineDimExpr>(lhs))
    return cst.getValue();
  return std::nullopt;
}
// Return true if the scalar load index can be shown to fall inside the range
// of elements covered by a vector.store/transfer_write along a single memref
// dimension; return false whenever this cannot be established.
// Supported cases:
//
// 1) Direct index match (with optional offset):
// vector.transfer_write %v, %A[%i] : vector<4xf32>, memref<...>
// %x = memref.load %A[%i] : memref<...>
//
// 2) Loop IV range fully contained in the write range:
// vector.transfer_write %v, %A[%c0] : vector<4xf32>, memref<...>
// scf.for %k = %c0 to %c4 step %c1 { %x = memref.load %A[%k] }
//
// 3) Constant index (or IV + constant) within the write range:
// vector.transfer_write %v, %A[%c0] : vector<4xf32>, memref<...>
// %x = memref.load %A[%c2] : memref<...>
// %y = memref.load %A[%i + %c1] : memref<...>
//
// Args:
// - loadIndex: index used by the scalar load for this dimension.
// - offset: subview offset for the base memref dimension (if any).
// - writeIndex: index used by the transfer_write for this dimension. Can be
// null if the dim was dropped by a rank reducing subview, whose result is
// written by the vector.write. A null write index is treated as constant 0.
// - extent: vector size along this dimension (number of elements written).
// Non-positive extents never match.
// - loopsIVsMap: IV equivalence map between fused loops.
static bool loadIndexWithinWriteRange(Value loadIndex, OpFoldResult offset,
Value writeIndex, int64_t extent,
const IRMapping &loopsIVsMap) {
if (extent <= 0)
return false;
// Extract constant loop bounds for loop IVs (e.g. from scf.for).
// Returns (lb, ub, step) for the IV, or std::nullopt if `index` is not an
// induction variable of a LoopLikeOpInterface op with known constant bounds.
auto getConstLoopBoundsForIV =
[](Value index) -> std::optional<std::tuple<int64_t, int64_t, int64_t>> {
auto blockArg = dyn_cast<BlockArgument>(index);
if (!blockArg)
return std::nullopt;
auto *parentOp = blockArg.getOwner()->getParentOp();
auto loopLike = dyn_cast<LoopLikeOpInterface>(parentOp);
if (!loopLike)
return std::nullopt;
auto ranges = getConstLoopBounds(loopLike);
if (ranges.empty())
return std::nullopt;
auto ivs = loopLike.getLoopInductionVars();
if (!ivs)
return std::nullopt;
// Locate the position of this IV to pick the matching bounds triple.
auto it = llvm::find(*ivs, blockArg);
if (it == ivs->end())
return std::nullopt;
unsigned pos = std::distance(ivs->begin(), it);
if (pos >= ranges.size())
return std::nullopt;
auto [lb, ub, step] = ranges[pos];
return std::make_tuple(lb, ub, step);
};
std::optional<int64_t> offsetConst = getConstantIntValue(offset);
// A missing write index (dropped dimension) is treated as the constant 0.
std::optional<int64_t> writeConst =
writeIndex ? getConstantIntValue(writeIndex) : std::optional<int64_t>(0);
if (!writeConst && writeIndex) {
// Treat single-iteration IVs as constants for matching.
if (auto bounds = getConstLoopBoundsForIV(writeIndex)) {
auto [lb, ub, step] = *bounds;
if (step > 0 && ub == lb + step)
writeConst = lb;
}
}
// Check whether a loop IV is fully contained in a constant write range.
// The range is the half-open interval [rangeStart, rangeStart + rangeExtent);
// empty loops (ub <= lb) and non-positive steps are rejected.
auto loopIVWithinRange = [](int64_t lb, int64_t ub, int64_t step,
int64_t rangeStart, int64_t rangeExtent) -> bool {
if (rangeExtent <= 0 || step <= 0)
return false;
if (ub <= lb)
return false;
int64_t rangeEnd = rangeStart + rangeExtent;
return lb >= rangeStart && ub <= rangeEnd;
};
if (offsetConst && writeConst) {
// Constant start of the write range; check constant load or loop IV range.
int64_t start = *offsetConst + *writeConst;
if (auto loadConst = getConstantIntValue(loadIndex))
return (*loadConst >= start && *loadConst < start + extent);
if (auto bounds = getConstLoopBoundsForIV(loadIndex)) {
auto [lb, ub, step] = *bounds;
return loopIVWithinRange(lb, ub, step, start, extent);
}
// Otherwise fall through to the symbolic matching below.
}
if (writeIndex) {
// Direct IV match (or IV + constant) against the write index.
if (offsetConst && *offsetConst == 0 &&
valsAreEquivalent(loadIndex, writeIndex, loopsIVsMap))
return true;
if (auto addConst = getAddConstant(loadIndex, writeIndex, loopsIVsMap)) {
// Match load index of the form writeIndex + C within the write extent.
// In base coordinates the write covers
// [offset + writeIndex, offset + writeIndex + extent), so C must lie in
// [offset, offset + extent).
if (offsetConst) {
int64_t start = *offsetConst;
return (*addConst >= start && *addConst < start + extent);
}
}
return false;
}
// Only reached when there is no write index (dropped dimension).
if (auto offsetVal = dyn_cast<Value>(offset)) {
// Exact match when extent is 1 and the load hits the offset value.
if (extent == 1 && valsAreEquivalent(loadIndex, offsetVal, loopsIVsMap))
return true;
}
return false;
}
/// Return the buffer operand accessed by a supported memory op (memref.load,
/// memref.store, vector.transfer_read, vector.transfer_write, vector.load,
/// vector.store), or a null Value for any other operation.
static Value getBaseMemref(Operation *op) {
  // TODO: use the common interface for memory ops once available.
  return llvm::TypeSwitch<Operation *, Value>(op)
      .Case<memref::LoadOp, memref::StoreOp>(
          [](auto memrefOp) { return memrefOp.getMemRef(); })
      .Case<vector::TransferReadOp, vector::TransferWriteOp, vector::LoadOp,
            vector::StoreOp>([](auto vectorOp) { return vectorOp.getBase(); })
      .Default([](Operation *) { return Value(); });
}
/// Return true if the scalar `loadOp` reads an element of the region written
/// by a vector write (vector.transfer_write or vector.store) described by
/// `writeBase`, `writeIndices` and `vecTy`. The write may target the loaded
/// buffer directly or through a unit-stride (possibly rank-reducing)
/// memref.subview of it. EXAMPLE:
///   vector.transfer_write %V, %arg[%x, %y, ..., 0] {in_bounds = [true]} :
///     vector<4xf32>, memref<4xf32, strided<[1], offset: ?>>
///   scf.for %iter = %c0 to %c4 step %c1 iter_args(...) -> (f32) {
///     %0 = memref.load %arg[%x, %y, ..., %iter] : memref<1x128x16x4xf32>
///     ...
///   }
///
/// `vectorDimForWriteDim[d]` holds the vector dimension written along write
/// index `d`, or -1 if a single element is written along that dimension.
static bool isLoadOnWrittenVector(memref::LoadOp loadOp, Value writeBase,
                                  ValueRange writeIndices, VectorType vecTy,
                                  ArrayRef<int64_t> vectorDimForWriteDim,
                                  const IRMapping &ivsMap) {
  if (!vecTy)
    return false;

  // Resolve the buffer actually addressed by the load: the subview source if
  // the write goes through a subview, the write base itself otherwise.
  MemrefValue sourceMemref = nullptr;
  SmallVector<OpFoldResult> subviewOffsets;
  llvm::SmallBitVector droppedDims;
  auto subView = writeBase.getDefiningOp<memref::SubViewOp>();
  const bool viaSubview = static_cast<bool>(subView);
  if (viaSubview) {
    if (!subView.hasUnitStride())
      return false;
    sourceMemref = cast<MemrefValue>(subView.getSource());
    subviewOffsets = llvm::to_vector(subView.getMixedOffsets());
    droppedDims = subView.getDroppedDims();
  } else {
    sourceMemref = dyn_cast<MemrefValue>(writeBase);
    if (!sourceMemref)
      return false;
  }

  // The load must address that same buffer with one index per dimension.
  auto loadIndices = loadOp.getIndices();
  const unsigned baseRank = sourceMemref.getType().getRank();
  if (loadOp.getMemref() != sourceMemref || loadIndices.size() != baseRank)
    return false;

  const unsigned writeRank = writeIndices.size();
  if (!viaSubview && writeRank != baseRank)
    return false;
  if (viaSubview && subviewOffsets.size() != baseRank)
    return false;
  if (vectorDimForWriteDim.size() != writeRank)
    return false;

  const OpFoldResult zeroOffset(
      IntegerAttr::get(IndexType::get(loadOp.getContext()), 0));
  // Position in `writeIndices` matching the current base dimension; it only
  // advances on dimensions that were not dropped by the subview.
  unsigned nextWriteDim = 0;
  for (unsigned baseDim = 0; baseDim < baseRank; ++baseDim) {
    const bool dropped = viaSubview && droppedDims.test(baseDim);
    Value writeIndex;
    int64_t extent = 1;
    if (!dropped) {
      writeIndex = writeIndices[nextWriteDim];
      int64_t vectorDim = vectorDimForWriteDim[nextWriteDim];
      if (vectorDim >= 0) {
        int64_t dimSize = vecTy.getDimSize(vectorDim);
        if (dimSize == ShapedType::kDynamic)
          return false;
        extent = dimSize;
      }
      ++nextWriteDim;
    }
    OpFoldResult offset = viaSubview ? subviewOffsets[baseDim] : zeroOffset;
    if (!loadIndexWithinWriteRange(loadIndices[baseDim], offset, writeIndex,
                                   extent, ivsMap))
      return false;
  }
  return true;
}
/// Return true if the scalar `loadOp` reads an element written by the given
/// vector.transfer_write. Only projected-permutation maps whose results are
/// plain, distinct dimensions are supported.
static bool loadMatchesVectorWrite(memref::LoadOp loadOp,
                                   vector::TransferWriteOp writeOp,
                                   const IRMapping &ivsMap) {
  auto vecTy = dyn_cast<VectorType>(writeOp.getVector().getType());
  if (!vecTy)
    return false;

  const unsigned numWriteIndices = writeOp.getIndices().size();
  AffineMap permMap = writeOp.getPermutationMap();
  const bool isSupportedMap = permMap.isProjectedPermutation() &&
                              permMap.getNumResults() == vecTy.getRank() &&
                              permMap.getNumDims() == numWriteIndices;
  if (!isSupportedMap)
    return false;

  // Invert the permutation map: for each write index, record which vector
  // dimension (if any) is written along it.
  SmallVector<int64_t> vectorDimForWriteDim(numWriteIndices, -1);
  for (auto [vecDim, resultExpr] : llvm::enumerate(permMap.getResults())) {
    auto dimExpr = dyn_cast<AffineDimExpr>(resultExpr);
    if (!dimExpr)
      return false;
    const unsigned writeDim = dimExpr.getPosition();
    if (writeDim >= numWriteIndices || vectorDimForWriteDim[writeDim] != -1)
      return false;
    vectorDimForWriteDim[writeDim] = vecDim;
  }
  return isLoadOnWrittenVector(loadOp, writeOp.getBase(), writeOp.getIndices(),
                               vecTy, vectorDimForWriteDim, ivsMap);
}
/// Return true if the scalar `loadOp` reads an element written by the given
/// vector.store. The vector dimensions are matched with the trailing store
/// indices.
static bool loadMatchesVectorStore(memref::LoadOp loadOp,
                                   vector::StoreOp storeOp,
                                   const IRMapping &ivsMap) {
  auto vecTy = dyn_cast<VectorType>(storeOp.getValueToStore().getType());
  if (!vecTy)
    return false;
  const unsigned numIndices = storeOp.getIndices().size();
  if (vecTy.getRank() > numIndices)
    return false;

  // Leading indices address a single element (-1); the last `vecRank` indices
  // are covered by vector dimensions 0..vecRank-1, in order.
  const unsigned vecRank = vecTy.getRank();
  const unsigned firstVectorizedDim = numIndices - vecRank;
  SmallVector<int64_t> vectorDimForWriteDim(numIndices, -1);
  for (unsigned vecDim = 0; vecDim < vecRank; ++vecDim)
    vectorDimForWriteDim[firstVectorizedDim + vecDim] = vecDim;

  return isLoadOnWrittenVector(loadOp, storeOp.getBase(), storeOp.getIndices(),
                               vecTy, vectorDimForWriteDim, ivsMap);
}
/// Check if both operations access the same positions of the same
/// buffer, but one of the two does it through a rank-reducing full subview of
/// the buffer (the other's base). EXAMPLE:
/// memref.store %a, %buf[%c0, %i, %j] : memref<1x2x2xf32>
/// %alias = memref.subview %buf[0, 0, 0][1, 2, 2][1, 1, 1]: memref<1x2x2xf32>
/// to memref<2x2xf32>
/// %val = memref.load %alias[%i, %j] : memref<2x2xf32>
template <typename OpTy1, typename OpTy2>
static bool opsAccessSameIndicesViaRankReducingSubview(
OpTy1 op1, OpTy2 op2, const IRMapping &firstToSecondPloopIVsMap,
OpBuilder &b) {
auto base1 = cast<MemrefValue>(getBaseMemref(op1));
auto base2 = cast<MemrefValue>(getBaseMemref(op2));
if (!base1 || !base2)
return false;
auto accessThroughTrivialSubviewIsSame =
[&b](memref::SubViewOp subView, ValueRange subViewAccess,
ValueRange sourceAccess, const IRMapping &ivsMap) -> bool {
SmallVector<Value> resolvedSubviewAccess;
LogicalResult resolved = resolveSourceIndicesRankReducingSubview(
subView.getLoc(), b, subView, subViewAccess, resolvedSubviewAccess);
if (failed(resolved) ||
(resolvedSubviewAccess.size() != sourceAccess.size()))
return false;
for (auto [dimIdx, resolvedIndex] :
llvm::enumerate(resolvedSubviewAccess)) {
if (!matchPattern(resolvedIndex, m_Zero()) &&
!valsAreEquivalent(resolvedIndex, sourceAccess[dimIdx], ivsMap))
return false;
}
return true;
};
// Case 1: op1 uses a subview of op2's base.
if (auto subView = base1.template getDefiningOp<memref::SubViewOp>();
subView &&
memref::isSameViewOrTrivialAlias(
base2, cast<MemrefValue>(subView.getSource())) &&
accessThroughTrivialSubviewIsSame(subView, op1.getIndices(),
op2.getIndices(),
firstToSecondPloopIVsMap))
return true;
// Case 2: op2 uses a subview of op1's base.
if (auto subView = base2.template getDefiningOp<memref::SubViewOp>();
subView &&
memref::isSameViewOrTrivialAlias(
base1, cast<MemrefValue>(subView.getSource())) &&
accessThroughTrivialSubviewIsSame(subView, op2.getIndices(),
op1.getIndices(),
firstToSecondPloopIVsMap))
return true;
return false;
}
/// Return true if the two memory ops use equivalent indices, where values are
/// compared through `loopsIVsMap` (the induction-variable mapping between the
/// two parallel loops). When the ops have a different number of indices, the
/// comparison is delegated to the rank-reducing-subview check.
template <typename OpTy1, typename OpTy2>
static bool opsAccessSameIndices(OpTy1 op1, OpTy2 op2,
                                 const IRMapping &loopsIVsMap, OpBuilder &b) {
  auto indices1 = op1.getIndices();
  auto indices2 = op2.getIndices();
  const unsigned numIndices = indices1.size();
  if (numIndices != indices2.size())
    return opsAccessSameIndicesViaRankReducingSubview(op1, op2, loopsIVsMap, b);
  // Same rank: every pair of corresponding indices must be equivalent.
  for (unsigned pos = 0; pos != numIndices; ++pos)
    if (!valsAreEquivalent(indices1[pos], indices2[pos], loopsIVsMap))
      return false;
  return true;
}
/// Check if the loadOp reads from the same memory location (same buffer,
/// same indices and same properties) as written by the storeOp.
///
/// Supported pairs:
///  - memref.load vs memref.store / vector.transfer_write / vector.store;
///  - vector.transfer_read vs vector.transfer_write;
///  - vector.load vs vector.store.
/// For vector/vector pairs the transferred vector types must be identical:
/// with equal indices, a read of a different shape touches elements that the
/// matching write did not produce. Any other combination returns false.
static bool
loadsFromSameMemoryLocationWrittenBy(Operation *loadOp, Operation *storeOp,
                                     const IRMapping &firstToSecondPloopIVsMap,
                                     OpBuilder &b) {
  if (!loadOp || !storeOp)
    return false;
  // Support only these memory-reading ops for now
  if (!isa<memref::LoadOp, vector::TransferReadOp, vector::LoadOp>(loadOp))
    return false;
  return llvm::TypeSwitch<Operation *, bool>(loadOp)
      .Case([&](memref::LoadOp memLoadOp) {
        if (auto memStoreOp = dyn_cast<memref::StoreOp>(storeOp))
          return opsAccessSameIndices(memLoadOp, memStoreOp,
                                      firstToSecondPloopIVsMap, b);
        if (auto vecWriteOp = dyn_cast<vector::TransferWriteOp>(storeOp))
          return loadMatchesVectorWrite(memLoadOp, vecWriteOp,
                                        firstToSecondPloopIVsMap);
        if (auto vecStoreOp = dyn_cast<vector::StoreOp>(storeOp))
          return loadMatchesVectorStore(memLoadOp, vecStoreOp,
                                        firstToSecondPloopIVsMap);
        return false;
      })
      .Case([&](vector::TransferReadOp vecReadOp) {
        auto vecWriteOp = dyn_cast<vector::TransferWriteOp>(storeOp);
        if (!vecWriteOp)
          return false;
        // The read must cover exactly the elements that were written.
        if (vecReadOp.getVector().getType() !=
            vecWriteOp.getVector().getType())
          return false;
        // Permutation maps are only directly comparable when both ops index
        // a buffer of the same rank (i.e. outside of the rank-reducing
        // subview case handled by opsAccessSameIndices).
        AffineMap readMap = vecReadOp.getPermutationMap();
        AffineMap writeMap = vecWriteOp.getPermutationMap();
        if (readMap.getNumDims() == writeMap.getNumDims() &&
            readMap != writeMap)
          return false;
        return opsAccessSameIndices(vecReadOp, vecWriteOp,
                                    firstToSecondPloopIVsMap, b) &&
               (vecReadOp.getMask() == vecWriteOp.getMask()) &&
               (vecReadOp.getInBounds() == vecWriteOp.getInBounds());
      })
      .Case([&](vector::LoadOp vecLoadOp) {
        auto vecStoreOp = dyn_cast<vector::StoreOp>(storeOp);
        if (!vecStoreOp)
          return false;
        // The load must cover exactly the elements that were stored.
        if (vecLoadOp.getResult().getType() !=
            vecStoreOp.getValueToStore().getType())
          return false;
        return opsAccessSameIndices(vecLoadOp, vecStoreOp,
                                    firstToSecondPloopIVsMap, b) &&
               (vecLoadOp.getAlignment() == vecStoreOp.getAlignment());
      })
      .Default([](Operation *) { return false; });
}
/// Return the buffer written by a supported store op (memref.store,
/// vector.transfer_write or vector.store), or a null Value for any other op.
static Value getStoreOpTargetBuffer(Operation *op) {
  if (auto storeOp = dyn_cast<memref::StoreOp>(op))
    return storeOp.getMemRef();
  if (auto writeOp = dyn_cast<vector::TransferWriteOp>(op))
    return writeOp.getBase();
  if (auto vecStoreOp = dyn_cast<vector::StoreOp>(op))
    return vecStoreOp.getBase();
  return Value();
}
/// To be called when `mayAlias(val1, val2)` is true. Return true if the
/// potential aliasing between `loadOp` and `storeOp` can be resolved from
/// their access patterns: this is only attempted for a scalar memref.load
/// against a vector.transfer_write or a vector.store.
static bool canResolveAlias(Operation *loadOp, Operation *storeOp,
                            const IRMapping &loopsIVsMap) {
  // Only scalar loads can be matched against vector writes.
  auto scalarLoad = dyn_cast<memref::LoadOp>(loadOp);
  if (!scalarLoad)
    return false;
  if (auto transfWriteOp = dyn_cast<vector::TransferWriteOp>(storeOp))
    return loadMatchesVectorWrite(scalarLoad, transfWriteOp, loopsIVsMap);
  if (auto vecStoreOp = dyn_cast<vector::StoreOp>(storeOp))
    return loadMatchesVectorStore(scalarLoad, vecStoreOp, loopsIVsMap);
  return false;
}
/// Check that the parallel loops have no mixed access to the same buffers.
/// Return `true` if the second parallel loop does not read the buffers
/// written by the first loop using different indices.
///
/// The check is directional: it collects the writes of `firstPloop` and
/// validates the reads of `secondPloop` against them.
/// - firstToSecondPloopIndices: maps the induction variables of `firstPloop`
///   to the ones of `secondPloop`.
/// - mayAlias: callback telling whether two buffers may alias.
/// - b: builder forwarded to the index-comparison helpers.
/// Returns `false` as soon as an unsupported memory-writing or memory-reading
/// op is found.
static bool haveNoDataDependenciesExceptSameIndex(
ParallelOp firstPloop, ParallelOp secondPloop,
const IRMapping &firstToSecondPloopIndices,
llvm::function_ref<bool(Value, Value)> mayAlias, OpBuilder &b) {
// Map buffers to their store/write ops in the firstPloop
DenseMap<Value, SmallVector<Operation *>> bufferStoresInFirstPloop;
// Record all the memory buffers used in store/write ops found in firstPloop
llvm::SmallSetVector<Value, 4> buffersWrittenInFirstPloop;
auto collectStoreOpsInWalk = [&](Operation *op) {
auto memOpInterf = dyn_cast_if_present<MemoryEffectOpInterface>(op);
// Ignore ops that don't write to memory. Ops that do not implement
// MemoryEffectOpInterface are skipped here as well.
if (!memOpInterf || (!memOpInterf.hasEffect<MemoryEffects::Write>() &&
!memOpInterf.hasEffect<MemoryEffects::Free>()))
return WalkResult::advance();
// Only these memory-writing ops are supported for now:
// memref.store, vector.transfer_write, vector.store
// Any other writing/freeing op interrupts the walk (fusion is rejected).
Value storeOpBase = getStoreOpTargetBuffer(op);
if (!storeOpBase)
return WalkResult::interrupt();
// Expect the base operand to be a Memref
MemrefValue storeOpBaseMemref = dyn_cast<MemrefValue>(storeOpBase);
if (!storeOpBaseMemref)
return WalkResult::interrupt();
// Get the original memref buffer, skipping full view-like ops
Value buffer = memref::skipFullyAliasingOperations(storeOpBaseMemref);
bufferStoresInFirstPloop[buffer].push_back(op);
buffersWrittenInFirstPloop.insert(buffer);
return WalkResult::advance();
};
// Walk the first parallel loop to collect all store/write ops and their
// target buffers
if (firstPloop.getBody()->walk(collectStoreOpsInWalk).wasInterrupted())
return false;
// Check that this load/read op encountered while walking the second parallel
// loop does not have incompatible data dependencies with the store/write ops
// collected from the first parallel loop: the loops can be fused only if in
// the 2nd loop there are no loads/stores from/to the buffers written in the
// 1st loop, except when on the same exact memory location (same indices) as
// written in the 1st loop.
auto checkLoadInWalkHasNoIncompatibleDataDeps = [&](Operation *loadOp) {
auto memOpInterf = dyn_cast_if_present<MemoryEffectOpInterface>(loadOp);
// To be conservative, we should stop on ops that don't advertise their
// memory effects. However, many ops don't implement MemoryEffectOpInterface
// yet, so for now we just skip them.
// TODO: once more ops add MemoryEffectOpInterface, interrupt the walk here.
if (!memOpInterf &&
!loadOp->hasTrait<mlir::OpTrait::HasRecursiveMemoryEffects>())
return WalkResult::advance();
// Ignore ops that don't read from memory, and wrapping ops that have nested
// memory effects (e.g. loops, conditionals) as they will be analyzed when
// visiting their nested ops.
if ((!memOpInterf &&
loadOp->hasTrait<mlir::OpTrait::HasRecursiveMemoryEffects>()) ||
(memOpInterf && !memOpInterf.hasEffect<MemoryEffects::Read>()))
return WalkResult::advance();
// Support only these memory-reading ops for now
// (the buffer is expected to be operand 0 and to be a memref).
if (!isa<memref::LoadOp, vector::TransferReadOp, vector::LoadOp>(loadOp) ||
!isa<MemrefValue>(loadOp->getOperand(0)))
return WalkResult::interrupt();
MemrefValue loadOpBase = cast<MemrefValue>(loadOp->getOperand(0));
MemrefValue loadedOrigBuf = memref::skipFullyAliasingOperations(loadOpBase);
// Reject the fusion if the loaded buffer may alias a *different* buffer
// written in the first loop, unless the aliasing can be resolved for every
// store to that buffer.
for (Value storedMem : buffersWrittenInFirstPloop)
if ((storedMem != loadedOrigBuf) && mayAlias(storedMem, loadedOrigBuf) &&
!llvm::all_of(bufferStoresInFirstPloop[storedMem],
[&](Operation *storeOp) {
return canResolveAlias(loadOp, storeOp,
firstToSecondPloopIndices);
})) {
return WalkResult::interrupt();
}
// Nothing more to check if the first loop never writes the loaded buffer.
auto writeOpsIt = bufferStoresInFirstPloop.find(loadedOrigBuf);
if (writeOpsIt == bufferStoresInFirstPloop.end())
return WalkResult::advance();
// Store/write ops to this buffer in the firstPloop
SmallVector<mlir::Operation *> &writeOps = writeOpsIt->second;
// If the first loop has no writes to this buffer, continue
if (writeOps.empty())
return WalkResult::advance();
// The first write is used as the representative of all the writes to this
// buffer; the check below ensures the others target the same location.
Operation *writeOp = writeOps.front();
// In the first parallel loop, multiple writes to the same memref are
// allowed only on the same memory location
if (!llvm::all_of(writeOps, [&](Operation *otherWriteOp) {
return opsWriteSameMemLocation(writeOp, otherWriteOp);
})) {
return WalkResult::interrupt();
}
// Check that the load in secondPloop reads from the same memory location as
// written by the corresponding store in firstPloop
if (!loadsFromSameMemoryLocationWrittenBy(loadOp, writeOp,
firstToSecondPloopIndices, b)) {
return WalkResult::interrupt();
}
return WalkResult::advance();
};
// Walk the second parallel loop to check load/read ops against the stores
// collected from the first parallel loop.
return !secondPloop.getBody()
->walk(checkLoadInWalkHasNoIncompatibleDataDeps)
.wasInterrupted();
}
/// Return true if neither loop reads a buffer written by the other loop,
/// except when reading exactly the memory location (same indices) written by
/// the other loop. The directional check is run first from the first loop to
/// the second and then, only if that succeeds, in the opposite direction.
static bool
noIncompatibleDataDependencies(ParallelOp firstPloop, ParallelOp secondPloop,
                               const IRMapping &firstToSecondPloopIndices,
                               llvm::function_ref<bool(Value, Value)> mayAlias,
                               OpBuilder &b) {
  // Inverse induction-variable mapping, used for the reverse direction.
  IRMapping secondToFirstPloopIndices;
  secondToFirstPloopIndices.map(secondPloop.getBody()->getArguments(),
                                firstPloop.getBody()->getArguments());

  return haveNoDataDependenciesExceptSameIndex(
             firstPloop, secondPloop, firstToSecondPloopIndices, mayAlias,
             b) &&
         haveNoDataDependenciesExceptSameIndex(
             secondPloop, firstPloop, secondToFirstPloopIndices, mayAlias, b);
}
/// Return true if the two parallel loops can be legally fused, i.e. neither
/// of them contains a nested parallel loop, they iterate over the same space
/// and there are no incompatible data dependencies between them.
static bool isFusionLegal(ParallelOp firstPloop, ParallelOp secondPloop,
                          const IRMapping &firstToSecondPloopIndices,
                          llvm::function_ref<bool(Value, Value)> mayAlias,
                          OpBuilder &b) {
  if (hasNestedParallelOp(firstPloop) || hasNestedParallelOp(secondPloop))
    return false;
  if (!equalIterationSpaces(firstPloop, secondPloop))
    return false;
  return noIncompatibleDataDependencies(firstPloop, secondPloop,
                                        firstToSecondPloopIndices, mayAlias, b);
}
/// Prepend operations of firstPloop's body into secondPloop's body.
/// Update secondPloop with new loop.
///
/// Nothing is changed if the fusion is not legal (see `isFusionLegal`) or if
/// a user of the results of `firstPloop` is not properly dominated by
/// `secondPloop`. On success both original loops are erased and `secondPloop`
/// is reassigned to the newly created fused loop.
/// - builder: builder used for the legality analysis and the rewrite.
/// - mayAlias: callback telling whether two buffers may alias.
static void fuseIfLegal(ParallelOp firstPloop, ParallelOp &secondPloop,
OpBuilder builder,
llvm::function_ref<bool(Value, Value)> mayAlias) {
Block *block1 = firstPloop.getBody();
Block *block2 = secondPloop.getBody();
// Identify the induction variables of the two loops position by position.
IRMapping firstToSecondPloopIndices;
firstToSecondPloopIndices.map(block1->getArguments(), block2->getArguments());
if (!isFusionLegal(firstPloop, secondPloop, firstToSecondPloopIndices,
mayAlias, builder))
return;
DominanceInfo dom;
// We are fusing first loop into second, make sure there are no users of the
// first loop results between loops.
for (Operation *user : firstPloop->getUsers())
if (!dom.properlyDominates(secondPloop, user, /*enclosingOpOk*/ false))
return;
// The fused loop carries the init values of the first loop followed by the
// ones of the second loop.
ValueRange inits1 = firstPloop.getInitVals();
ValueRange inits2 = secondPloop.getInitVals();
SmallVector<Value> newInitVars(inits1.begin(), inits1.end());
newInitVars.append(inits2.begin(), inits2.end());
IRRewriter b(builder);
b.setInsertionPoint(secondPloop);
// Create the fused loop right before the second loop, reusing its bounds.
auto newSecondPloop = ParallelOp::create(
b, secondPloop.getLoc(), secondPloop.getLowerBound(),
secondPloop.getUpperBound(), secondPloop.getStep(), newInitVars);
Block *newBlock = newSecondPloop.getBody();
auto term1 = cast<ReduceOp>(block1->getTerminator());
auto term2 = cast<ReduceOp>(block2->getTerminator());
// Both bodies are inlined at the beginning of the new block: inlining the
// second body first and the first body afterwards leaves the ops of the first
// loop ahead of the ops of the second loop.
b.inlineBlockBefore(block2, newBlock, newBlock->begin(),
newBlock->getArguments());
b.inlineBlockBefore(block1, newBlock, newBlock->begin(),
newBlock->getArguments());
ValueRange results = newSecondPloop.getResults();
if (!results.empty()) {
// Build a single scf.reduce combining the reductions of both loops, in the
// same order as the init values.
b.setInsertionPointToEnd(newBlock);
ValueRange reduceArgs1 = term1.getOperands();
ValueRange reduceArgs2 = term2.getOperands();
SmallVector<Value> newReduceArgs(reduceArgs1.begin(), reduceArgs1.end());
newReduceArgs.append(reduceArgs2.begin(), reduceArgs2.end());
auto newReduceOp = scf::ReduceOp::create(b, term2.getLoc(), newReduceArgs);
// Move the body of each original reduction region into the matching region
// of the new reduce op.
for (auto &&[i, reg] : llvm::enumerate(llvm::concat<Region>(
term1.getReductions(), term2.getReductions()))) {
Block &oldRedBlock = reg.front();
Block &newRedBlock = newReduceOp.getReductions()[i].front();
b.inlineBlockBefore(&oldRedBlock, &newRedBlock, newRedBlock.begin(),
newRedBlock.getArguments());
}
// The leading results replace the first loop, the trailing ones the second.
firstPloop.replaceAllUsesWith(results.take_front(inits1.size()));
secondPloop.replaceAllUsesWith(results.take_back(inits2.size()));
}
// Drop the original terminators (moved into the new block by the inlining)
// and the now-empty original loops.
term1->erase();
term2->erase();
firstPloop.erase();
secondPloop.erase();
secondPloop = newSecondPloop;
}
/// Attempt to fuse adjacent top-level scf.parallel loops in every block of
/// `region`. Loops are grouped into chains: a chain is broken by any op
/// between two loops that has memory effects or holds regions. Within a
/// chain, each loop is tentatively fused into the one that follows it.
void mlir::scf::naivelyFuseParallelOps(
    Region &region, llvm::function_ref<bool(Value, Value)> mayAlias) {
  OpBuilder b(region);
  // Consider every single block and attempt to fuse adjacent loops.
  SmallVector<SmallVector<ParallelOp>, 1> ploopChains;
  for (Block &block : region) {
    ploopChains.clear();
    ploopChains.emplace_back();
    // Not using `walk()` to traverse only top-level parallel loops and also
    // make sure that there are no side-effecting ops between the parallel
    // loops.
    bool chainIsUnbroken = true;
    for (Operation &op : block) {
      auto ploop = dyn_cast<ParallelOp>(op);
      if (!ploop) {
        // TODO: Handle region side effects properly.
        if (!isMemoryEffectFree(&op) || op.getNumRegions() != 0)
          chainIsUnbroken = false;
        continue;
      }
      // Start a new chain if something with side effects was seen since the
      // previous loop.
      if (!chainIsUnbroken) {
        ploopChains.emplace_back();
        chainIsUnbroken = true;
      }
      ploopChains.back().push_back(ploop);
    }
    // Fuse each loop into its successor; `fuseIfLegal` updates the successor
    // entry in place so that fusions can cascade along the chain.
    for (MutableArrayRef<ParallelOp> chain : ploopChains) {
      for (int next = 1, e = chain.size(); next < e; ++next)
        fuseIfLegal(chain[next - 1], chain[next], b, mayAlias);
    }
  }
}
namespace {
/// Pass that fuses adjacent scf.parallel loops in every region nested under
/// the operation it runs on.
struct ParallelLoopFusion
    : public impl::SCFParallelLoopFusionBase<ParallelLoopFusion> {
  void runOnOperation() override {
    AliasAnalysis &aliasAnalysis = getAnalysis<AliasAnalysis>();

    // Parallel loop enclosing the op defining `val`; null for values without
    // a defining op or defined outside of any parallel loop.
    auto enclosingPloop = [](Value val) -> ParallelOp {
      Operation *defOp = val.getDefiningOp();
      return defOp ? defOp->getParentOfType<ParallelOp>() : nullptr;
    };

    auto mayAlias = [&](Value val1, Value val2) -> bool {
      // If the memref is defined in one of the parallel loops body, careful
      // alias analysis is needed.
      // TODO: check if this is still needed as a separate check.
      if (enclosingPloop(val1) != enclosingPloop(val2))
        return true;
      return !aliasAnalysis.alias(val1, val2).isNo();
    };

    getOperation()->walk([&](Operation *child) {
      for (Region &region : child->getRegions())
        naivelyFuseParallelOps(region, mayAlias);
    });
  }
};
} // namespace
/// Create a new instance of the parallel loop fusion pass.
std::unique_ptr<Pass> mlir::createParallelLoopFusionPass() {
  auto fusionPass = std::make_unique<ParallelLoopFusion>();
  return fusionPass;
}