Refactor and extend the scf::ParalleLoopFusion pass: - Refactor code, rename functions and add comments to improve readability - Make the dependency analysis safer by checking for read-after-write dependencies also with vector.load/store & vector.transfer_read/write ops, in addition to memref.load/store, and bail out when other unsupported ops with memory effects are found. - Extend the cases when the fusion is applied: allow fusing also when one of the two loops reads/writes to memory through a full view/alias of the buffer (read/written by the dual operation in the other loop) that can be trivially resolved, including rank-reducing full subviews.
873 lines
36 KiB
C++
873 lines
36 KiB
C++
//===- ParallelLoopFusion.cpp - Code to perform loop fusion ---------------===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
//
|
|
// This file implements loop fusion on parallel loops.
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/Dialect/SCF/Transforms/Passes.h"
|
|
|
|
#include "mlir/Analysis/AliasAnalysis.h"
|
|
#include "mlir/Dialect/Affine/IR/AffineOps.h"
|
|
#include "mlir/Dialect/Index/IR/IndexOps.h"
|
|
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
|
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
|
|
#include "mlir/Dialect/SCF/IR/SCF.h"
|
|
#include "mlir/Dialect/SCF/Transforms/Transforms.h"
|
|
#include "mlir/Dialect/SCF/Utils/Utils.h"
|
|
#include "mlir/Dialect/Vector/IR/VectorOps.h"
|
|
#include "mlir/IR/Builders.h"
|
|
#include "mlir/IR/BuiltinTypes.h"
|
|
#include "mlir/IR/IRMapping.h"
|
|
#include "mlir/IR/Matchers.h"
|
|
#include "mlir/IR/OpDefinition.h"
|
|
#include "mlir/IR/OperationSupport.h"
|
|
#include "mlir/IR/PatternMatch.h"
|
|
#include "mlir/Interfaces/SideEffectInterfaces.h"
|
|
|
|
#include "llvm/ADT/STLExtras.h"
|
|
#include "llvm/ADT/SetVector.h"
|
|
#include "llvm/ADT/SmallBitVector.h"
|
|
#include "llvm/ADT/TypeSwitch.h"
|
|
|
|
#include <optional>
|
|
#include <tuple>
|
|
|
|
namespace mlir {
|
|
#define GEN_PASS_DEF_SCFPARALLELLOOPFUSION
|
|
#include "mlir/Dialect/SCF/Transforms/Passes.h.inc"
|
|
} // namespace mlir
|
|
|
|
using namespace mlir;
|
|
using namespace mlir::scf;
|
|
|
|
/// Verify there are no nested ParallelOps.
|
|
static bool hasNestedParallelOp(ParallelOp ploop) {
|
|
auto walkResult =
|
|
ploop.getBody()->walk([](ParallelOp) { return WalkResult::interrupt(); });
|
|
return walkResult.wasInterrupted();
|
|
}
|
|
|
|
/// Verify equal iteration spaces.
|
|
static bool equalIterationSpaces(ParallelOp firstPloop,
|
|
ParallelOp secondPloop) {
|
|
if (firstPloop.getNumLoops() != secondPloop.getNumLoops())
|
|
return false;
|
|
|
|
auto matchOperands = [&](const OperandRange &lhs,
|
|
const OperandRange &rhs) -> bool {
|
|
// TODO: Extend this to support aliases and equal constants.
|
|
return std::equal(lhs.begin(), lhs.end(), rhs.begin());
|
|
};
|
|
return matchOperands(firstPloop.getLowerBound(),
|
|
secondPloop.getLowerBound()) &&
|
|
matchOperands(firstPloop.getUpperBound(),
|
|
secondPloop.getUpperBound()) &&
|
|
matchOperands(firstPloop.getStep(), secondPloop.getStep());
|
|
}
|
|
|
|
/// Check if both operations are the same type of memory write op and
|
|
/// write to the same memory location (same buffer and same indices).
|
|
static bool opsWriteSameMemLocation(Operation *op1, Operation *op2) {
|
|
if (!op1 || !op2 || op1->getName() != op2->getName())
|
|
return false;
|
|
if (op1 == op2)
|
|
return true;
|
|
// support only these memory-writing ops for now
|
|
if (!isa<memref::StoreOp, vector::TransferWriteOp, vector::StoreOp>(op1))
|
|
return false;
|
|
bool opsAreIdentical =
|
|
llvm::TypeSwitch<Operation *, bool>(op1)
|
|
.Case([&](memref::StoreOp storeOp1) {
|
|
auto storeOp2 = cast<memref::StoreOp>(op2);
|
|
return (storeOp1.getMemRef() == storeOp2.getMemRef()) &&
|
|
(storeOp1.getIndices() == storeOp2.getIndices());
|
|
})
|
|
.Case([&](vector::TransferWriteOp writeOp1) {
|
|
auto writeOp2 = cast<vector::TransferWriteOp>(op2);
|
|
return (writeOp1.getBase() == writeOp2.getBase()) &&
|
|
(writeOp1.getIndices() == writeOp2.getIndices()) &&
|
|
(writeOp1.getMask() == writeOp2.getMask()) &&
|
|
(writeOp1.getValueToStore().getType() ==
|
|
writeOp2.getValueToStore().getType()) &&
|
|
(writeOp1.getInBounds() == writeOp2.getInBounds());
|
|
})
|
|
.Case([&](vector::StoreOp vecStoreOp1) {
|
|
auto vecStoreOp2 = cast<vector::StoreOp>(op2);
|
|
return (vecStoreOp1.getBase() == vecStoreOp2.getBase()) &&
|
|
(vecStoreOp1.getIndices() == vecStoreOp2.getIndices()) &&
|
|
(vecStoreOp1.getValueToStore().getType() ==
|
|
vecStoreOp2.getValueToStore().getType()) &&
|
|
(vecStoreOp1.getAlignment() == vecStoreOp2.getAlignment()) &&
|
|
(vecStoreOp1.getNontemporal() ==
|
|
vecStoreOp2.getNontemporal());
|
|
})
|
|
.Default([](Operation *) { return false; });
|
|
return opsAreIdentical;
|
|
}
|
|
|
|
/// Check if val1 (from the first parallel loop) and val2 (from the
|
|
/// second) are equivalent, considering the mapping of induction variables from
|
|
/// the first to the second parallel loop.
|
|
static bool valsAreEquivalent(Value val1, Value val2,
|
|
const IRMapping &loopsIVsMap) {
|
|
if (val1 == val2 || loopsIVsMap.lookupOrDefault(val1) == val2 ||
|
|
loopsIVsMap.lookupOrDefault(val2) == val1)
|
|
return true;
|
|
Operation *val1DefOp = val1.getDefiningOp();
|
|
Operation *val2DefOp = val2.getDefiningOp();
|
|
if (!val1DefOp || !val2DefOp)
|
|
return false;
|
|
if (!isMemoryEffectFree(val1DefOp) || !isMemoryEffectFree(val2DefOp))
|
|
return false;
|
|
return OperationEquivalence::isEquivalentTo(
|
|
val1DefOp, val2DefOp,
|
|
[&](Value v1, Value v2) {
|
|
return success(loopsIVsMap.lookupOrDefault(v1) == v2 ||
|
|
loopsIVsMap.lookupOrDefault(v2) == v1);
|
|
},
|
|
/*markEquivalent=*/nullptr, OperationEquivalence::Flags::IgnoreLocations);
|
|
}
|
|
|
|
/// If the `expr` value is the result of an integer addition of `base` and a
|
|
/// constant, return the constant.
|
|
static std::optional<int64_t> getAddConstant(Value expr, Value base,
|
|
const IRMapping &loopsIVsMap) {
|
|
if (auto addOp = expr.getDefiningOp<arith::AddIOp>()) {
|
|
if (auto constOp = getConstantIntValue(addOp.getLhs());
|
|
constOp && valsAreEquivalent(addOp.getRhs(), base, loopsIVsMap))
|
|
return constOp.value();
|
|
if (auto constOp = getConstantIntValue(addOp.getRhs());
|
|
constOp && valsAreEquivalent(addOp.getLhs(), base, loopsIVsMap))
|
|
return constOp.value();
|
|
return std::nullopt;
|
|
}
|
|
|
|
if (auto addOp = expr.getDefiningOp<index::AddOp>()) {
|
|
if (auto constOp = getConstantIntValue(addOp.getLhs());
|
|
constOp && valsAreEquivalent(addOp.getRhs(), base, loopsIVsMap))
|
|
return constOp.value();
|
|
if (auto constOp = getConstantIntValue(addOp.getRhs());
|
|
constOp && valsAreEquivalent(addOp.getLhs(), base, loopsIVsMap))
|
|
return constOp.value();
|
|
return std::nullopt;
|
|
}
|
|
|
|
if (auto applyOp = expr.getDefiningOp<affine::AffineApplyOp>()) {
|
|
AffineMap map = applyOp.getAffineMap();
|
|
if (map.getNumResults() != 1 || map.getNumDims() != 1 ||
|
|
map.getNumSymbols() != 0)
|
|
return std::nullopt;
|
|
if (!valsAreEquivalent(applyOp.getOperand(0), base, loopsIVsMap))
|
|
return std::nullopt;
|
|
AffineExpr result = map.getResult(0);
|
|
auto bin = dyn_cast<AffineBinaryOpExpr>(result);
|
|
if (!bin || bin.getKind() != AffineExprKind::Add)
|
|
return std::nullopt;
|
|
auto lhsDim = dyn_cast<AffineDimExpr>(bin.getLHS());
|
|
auto rhsDim = dyn_cast<AffineDimExpr>(bin.getRHS());
|
|
auto lhsConst = dyn_cast<AffineConstantExpr>(bin.getLHS());
|
|
auto rhsConst = dyn_cast<AffineConstantExpr>(bin.getRHS());
|
|
if (lhsConst && rhsDim)
|
|
return lhsConst.getValue();
|
|
if (rhsConst && lhsDim)
|
|
return rhsConst.getValue();
|
|
}
|
|
return std::nullopt;
|
|
}
|
|
|
|
// Return true if the scalar load index may hit any element covered by a
|
|
// vector.store/transfer_write along a single memref dimension. Supported cases:
|
|
//
|
|
// 1) Direct index match (with optional offset):
|
|
// vector.transfer_write %v, %A[%i] : vector<4xf32>, memref<...>
|
|
// %x = memref.load %A[%i] : memref<...>
|
|
//
|
|
// 2) Loop IV range intersects the write range:
|
|
// vector.transfer_write %v, %A[%c0] : vector<4xf32>, memref<...>
|
|
// scf.for %k = %c0 to %c4 step %c1 { %x = memref.load %A[%k] }
|
|
//
|
|
// 3) Constant index (or IV + constant) within the write range:
|
|
// vector.transfer_write %v, %A[%c0] : vector<4xf32>, memref<...>
|
|
// %x = memref.load %A[%c2] : memref<...>
|
|
// %y = memref.load %A[%i + %c1] : memref<...>
|
|
//
|
|
// Args:
|
|
// - loadIndex: index used by the scalar load for this dimension.
|
|
// - offset: subview offset for the base memref dimension (if any).
|
|
// - writeIndex: index used by the transfer_write for this dimension. Can be
|
|
// null if the dim was dropped by a rank reducing subview, whose result is
|
|
// written by the vector.write.
|
|
// - extent: vector size along this dimension (number of elements written).
|
|
// - loopsIVsMap: IV equivalence map between fused loops.
|
|
static bool loadIndexWithinWriteRange(Value loadIndex, OpFoldResult offset,
|
|
Value writeIndex, int64_t extent,
|
|
const IRMapping &loopsIVsMap) {
|
|
if (extent <= 0)
|
|
return false;
|
|
|
|
// Extract constant loop bounds for loop IVs (e.g. from scf.for).
|
|
auto getConstLoopBoundsForIV =
|
|
[](Value index) -> std::optional<std::tuple<int64_t, int64_t, int64_t>> {
|
|
auto blockArg = dyn_cast<BlockArgument>(index);
|
|
if (!blockArg)
|
|
return std::nullopt;
|
|
auto *parentOp = blockArg.getOwner()->getParentOp();
|
|
auto loopLike = dyn_cast<LoopLikeOpInterface>(parentOp);
|
|
if (!loopLike)
|
|
return std::nullopt;
|
|
auto ranges = getConstLoopBounds(loopLike);
|
|
if (ranges.empty())
|
|
return std::nullopt;
|
|
|
|
auto ivs = loopLike.getLoopInductionVars();
|
|
if (!ivs)
|
|
return std::nullopt;
|
|
auto it = llvm::find(*ivs, blockArg);
|
|
if (it == ivs->end())
|
|
return std::nullopt;
|
|
unsigned pos = std::distance(ivs->begin(), it);
|
|
if (pos >= ranges.size())
|
|
return std::nullopt;
|
|
auto [lb, ub, step] = ranges[pos];
|
|
return std::make_tuple(lb, ub, step);
|
|
};
|
|
|
|
std::optional<int64_t> offsetConst = getConstantIntValue(offset);
|
|
std::optional<int64_t> writeConst =
|
|
writeIndex ? getConstantIntValue(writeIndex) : std::optional<int64_t>(0);
|
|
if (!writeConst && writeIndex) {
|
|
// Treat single-iteration IVs as constants for matching.
|
|
if (auto bounds = getConstLoopBoundsForIV(writeIndex)) {
|
|
auto [lb, ub, step] = *bounds;
|
|
if (step > 0 && ub == lb + step)
|
|
writeConst = lb;
|
|
}
|
|
}
|
|
|
|
// Check whether a loop IV is fully contained in a constant write range.
|
|
auto loopIVWithinRange = [](int64_t lb, int64_t ub, int64_t step,
|
|
int64_t rangeStart, int64_t rangeExtent) -> bool {
|
|
if (rangeExtent <= 0 || step <= 0)
|
|
return false;
|
|
if (ub <= lb)
|
|
return false;
|
|
int64_t rangeEnd = rangeStart + rangeExtent;
|
|
return lb >= rangeStart && ub <= rangeEnd;
|
|
};
|
|
|
|
if (offsetConst && writeConst) {
|
|
// Constant start of the write range; check constant load or loop IV range.
|
|
int64_t start = *offsetConst + *writeConst;
|
|
if (auto loadConst = getConstantIntValue(loadIndex))
|
|
return (*loadConst >= start && *loadConst < start + extent);
|
|
if (auto bounds = getConstLoopBoundsForIV(loadIndex)) {
|
|
auto [lb, ub, step] = *bounds;
|
|
return loopIVWithinRange(lb, ub, step, start, extent);
|
|
}
|
|
}
|
|
|
|
if (writeIndex) {
|
|
// Direct IV match (or IV + constant) against the write index.
|
|
if (offsetConst && *offsetConst == 0 &&
|
|
valsAreEquivalent(loadIndex, writeIndex, loopsIVsMap))
|
|
return true;
|
|
if (auto addConst = getAddConstant(loadIndex, writeIndex, loopsIVsMap)) {
|
|
// Match load index of the form writeIndex + C within the write extent.
|
|
if (offsetConst) {
|
|
int64_t start = *offsetConst;
|
|
return (*addConst >= start && *addConst < start + extent);
|
|
}
|
|
}
|
|
return false;
|
|
}
|
|
|
|
if (auto offsetVal = dyn_cast<Value>(offset)) {
|
|
// Exact match when extent is 1 and the load hits the offset value.
|
|
if (extent == 1 && valsAreEquivalent(loadIndex, offsetVal, loopsIVsMap))
|
|
return true;
|
|
}
|
|
|
|
return false;
|
|
}
|
|
|
|
/// Return the base memref value used by the given memory op.
|
|
static Value getBaseMemref(Operation *op) {
|
|
// TODO: use the common interface for memory ops once available.
|
|
return llvm::TypeSwitch<Operation *, Value>(op)
|
|
.Case([&](memref::LoadOp load) { return load.getMemRef(); })
|
|
.Case([&](memref::StoreOp store) { return store.getMemRef(); })
|
|
.Case([&](vector::TransferReadOp read) { return read.getBase(); })
|
|
.Case([&](vector::TransferWriteOp write) { return write.getBase(); })
|
|
.Case([&](vector::LoadOp load) { return load.getBase(); })
|
|
.Case([&](vector::StoreOp store) { return store.getBase(); })
|
|
.Default([](Operation *) { return Value(); });
|
|
}
|
|
|
|
/// Recognize scalar memref.load of an element produced by a vector write
|
|
/// (vector.transfer_write or vector.store, optionally through a rank-reducing
|
|
/// unit-stride subview) of the same buffer. This covers the pattern where a
|
|
/// vector write stores a full lane pack and a subsequent scalar load reads an
|
|
/// element from that lane pack. EXAMPLE:
|
|
/// vector.transfer_write %V, %arg[%x, %y, ..., 0] {in_bounds = [true]} :
|
|
/// vector<4xf32>, memref<4xf32, strided<[1], offset: ?>>
|
|
/// scf.for %iter = %c0 to %c4 step %c1 iter_args(...) -> (f32) {
|
|
/// %0 = memref.load %arg[%x, %y, ..., %iter] : memref<1x128x16x4xf32>
|
|
/// ...
|
|
/// }
|
|
///
|
|
static bool isLoadOnWrittenVector(memref::LoadOp loadOp, Value writeBase,
|
|
ValueRange writeIndices, VectorType vecTy,
|
|
ArrayRef<int64_t> vectorDimForWriteDim,
|
|
const IRMapping &ivsMap) {
|
|
if (!vecTy)
|
|
return false;
|
|
|
|
Value base = writeBase;
|
|
// The write base if there is no subview, or the subview source otherwise.
|
|
MemrefValue baseMemref = nullptr;
|
|
SmallVector<OpFoldResult> offsets;
|
|
llvm::SmallBitVector droppedDims;
|
|
bool hasSubview = false;
|
|
auto *ctx = loadOp.getContext();
|
|
if (auto subView = base.getDefiningOp<memref::SubViewOp>()) {
|
|
if (!subView.hasUnitStride())
|
|
return false;
|
|
baseMemref = cast<MemrefValue>(subView.getSource());
|
|
offsets = llvm::to_vector(subView.getMixedOffsets());
|
|
droppedDims = subView.getDroppedDims();
|
|
hasSubview = true;
|
|
} else {
|
|
baseMemref = dyn_cast<MemrefValue>(base);
|
|
if (!baseMemref)
|
|
return false;
|
|
}
|
|
|
|
auto loadIndices = loadOp.getIndices();
|
|
unsigned baseRank = baseMemref.getType().getRank();
|
|
if ((loadOp.getMemref() != baseMemref) || (loadIndices.size() != baseRank))
|
|
return false;
|
|
|
|
unsigned writeRank = writeIndices.size();
|
|
if ((!hasSubview && writeRank != baseRank) ||
|
|
(hasSubview && offsets.size() != baseRank) ||
|
|
(vectorDimForWriteDim.size() != writeRank))
|
|
return false;
|
|
|
|
auto zeroAttr = IntegerAttr::get(IndexType::get(ctx), 0);
|
|
unsigned writeMemrefDim = 0;
|
|
for (unsigned baseDim : llvm::seq(baseRank)) {
|
|
bool wasDropped = (hasSubview && droppedDims.test(baseDim));
|
|
int64_t vectorDim = !wasDropped ? vectorDimForWriteDim[writeMemrefDim] : -1;
|
|
int64_t extent = 1;
|
|
if (vectorDim >= 0) {
|
|
int64_t dimSize = vecTy.getDimSize(vectorDim);
|
|
if (dimSize == ShapedType::kDynamic)
|
|
return false;
|
|
extent = dimSize;
|
|
}
|
|
Value writeIndex = !wasDropped ? writeIndices[writeMemrefDim] : Value();
|
|
OpFoldResult offset =
|
|
hasSubview ? offsets[baseDim] : OpFoldResult(zeroAttr);
|
|
if (!loadIndexWithinWriteRange(loadIndices[baseDim], offset, writeIndex,
|
|
extent, ivsMap))
|
|
return false;
|
|
if (!wasDropped)
|
|
++writeMemrefDim;
|
|
}
|
|
|
|
return true;
|
|
}
|
|
|
|
/// Recognize scalar memref.load of an element produced by a
|
|
/// vector.transfer_write
|
|
static bool loadMatchesVectorWrite(memref::LoadOp loadOp,
|
|
vector::TransferWriteOp writeOp,
|
|
const IRMapping &ivsMap) {
|
|
auto vecTy = dyn_cast<VectorType>(writeOp.getVector().getType());
|
|
if (!vecTy)
|
|
return false;
|
|
|
|
unsigned writeRank = writeOp.getIndices().size();
|
|
AffineMap permutationMap = writeOp.getPermutationMap();
|
|
if (!permutationMap.isProjectedPermutation() ||
|
|
permutationMap.getNumResults() != vecTy.getRank() ||
|
|
permutationMap.getNumDims() != writeRank)
|
|
return false;
|
|
|
|
SmallVector<int64_t> vectorDimForWriteDim(writeRank, -1);
|
|
for (unsigned vecDim = 0; vecDim < permutationMap.getNumResults(); ++vecDim) {
|
|
auto dimExpr = dyn_cast<AffineDimExpr>(permutationMap.getResult(vecDim));
|
|
if (!dimExpr)
|
|
return false;
|
|
unsigned writeDim = dimExpr.getPosition();
|
|
if (writeDim >= writeRank || vectorDimForWriteDim[writeDim] != -1)
|
|
return false;
|
|
vectorDimForWriteDim[writeDim] = vecDim;
|
|
}
|
|
|
|
return isLoadOnWrittenVector(loadOp, writeOp.getBase(), writeOp.getIndices(),
|
|
vecTy, vectorDimForWriteDim, ivsMap);
|
|
}
|
|
|
|
/// Recognize scalar memref.load of an element produced by a vector.store
|
|
static bool loadMatchesVectorStore(memref::LoadOp loadOp,
|
|
vector::StoreOp storeOp,
|
|
const IRMapping &ivsMap) {
|
|
auto vecTy = dyn_cast<VectorType>(storeOp.getValueToStore().getType());
|
|
if (!vecTy)
|
|
return false;
|
|
|
|
unsigned writeRank = storeOp.getIndices().size();
|
|
if (vecTy.getRank() > writeRank)
|
|
return false;
|
|
|
|
SmallVector<int64_t> vectorDimForWriteDim(writeRank, -1);
|
|
unsigned vecRank = vecTy.getRank();
|
|
for (unsigned i = 0; i < vecRank; ++i) {
|
|
unsigned writeDim = writeRank - vecRank + i;
|
|
vectorDimForWriteDim[writeDim] = i;
|
|
}
|
|
|
|
return isLoadOnWrittenVector(loadOp, storeOp.getBase(), storeOp.getIndices(),
|
|
vecTy, vectorDimForWriteDim, ivsMap);
|
|
}
|
|
|
|
/// Check if both operations access the same positions of the same
|
|
/// buffer, but one of the two does it through a rank-reducing full subview of
|
|
/// the buffer (the other's base). EXAMPLE:
|
|
/// memref.store %a, %buf[%c0, %i, %j] : memref<1x2x2xf32>
|
|
/// %alias = memref.subview %buf[0, 0, 0][1, 2, 2][1, 1, 1]: memref<1x2x2xf32>
|
|
/// to memref<2x2xf32>
|
|
/// %val = memref.load %alias[%i, %j] : memref<2x2xf32>
|
|
template <typename OpTy1, typename OpTy2>
|
|
static bool opsAccessSameIndicesViaRankReducingSubview(
|
|
OpTy1 op1, OpTy2 op2, const IRMapping &firstToSecondPloopIVsMap,
|
|
OpBuilder &b) {
|
|
auto base1 = cast<MemrefValue>(getBaseMemref(op1));
|
|
auto base2 = cast<MemrefValue>(getBaseMemref(op2));
|
|
if (!base1 || !base2)
|
|
return false;
|
|
|
|
auto accessThroughTrivialSubviewIsSame =
|
|
[&b](memref::SubViewOp subView, ValueRange subViewAccess,
|
|
ValueRange sourceAccess, const IRMapping &ivsMap) -> bool {
|
|
SmallVector<Value> resolvedSubviewAccess;
|
|
LogicalResult resolved = resolveSourceIndicesRankReducingSubview(
|
|
subView.getLoc(), b, subView, subViewAccess, resolvedSubviewAccess);
|
|
if (failed(resolved) ||
|
|
(resolvedSubviewAccess.size() != sourceAccess.size()))
|
|
return false;
|
|
for (auto [dimIdx, resolvedIndex] :
|
|
llvm::enumerate(resolvedSubviewAccess)) {
|
|
if (!matchPattern(resolvedIndex, m_Zero()) &&
|
|
!valsAreEquivalent(resolvedIndex, sourceAccess[dimIdx], ivsMap))
|
|
return false;
|
|
}
|
|
return true;
|
|
};
|
|
|
|
// Case 1: op1 uses a subview of op2's base.
|
|
if (auto subView = base1.template getDefiningOp<memref::SubViewOp>();
|
|
subView &&
|
|
memref::isSameViewOrTrivialAlias(
|
|
base2, cast<MemrefValue>(subView.getSource())) &&
|
|
accessThroughTrivialSubviewIsSame(subView, op1.getIndices(),
|
|
op2.getIndices(),
|
|
firstToSecondPloopIVsMap))
|
|
return true;
|
|
|
|
// Case 2: op2 uses a subview of op1's base.
|
|
if (auto subView = base2.template getDefiningOp<memref::SubViewOp>();
|
|
subView &&
|
|
memref::isSameViewOrTrivialAlias(
|
|
base1, cast<MemrefValue>(subView.getSource())) &&
|
|
accessThroughTrivialSubviewIsSame(subView, op2.getIndices(),
|
|
op1.getIndices(),
|
|
firstToSecondPloopIVsMap))
|
|
return true;
|
|
|
|
return false;
|
|
}
|
|
|
|
/// Check if both memory read/write operations access the same indices
|
|
/// (considering also the mapping of induction variables from the first to the
|
|
/// second parallel loop).
|
|
template <typename OpTy1, typename OpTy2>
|
|
static bool opsAccessSameIndices(OpTy1 op1, OpTy2 op2,
|
|
const IRMapping &loopsIVsMap, OpBuilder &b) {
|
|
auto indices1 = op1.getIndices();
|
|
auto indices2 = op2.getIndices();
|
|
if (indices1.size() != indices2.size())
|
|
return opsAccessSameIndicesViaRankReducingSubview(op1, op2, loopsIVsMap, b);
|
|
for (auto [idx1, idx2] : llvm::zip(indices1, indices2)) {
|
|
if (!valsAreEquivalent(idx1, idx2, loopsIVsMap))
|
|
return false;
|
|
}
|
|
return true;
|
|
}
|
|
|
|
/// Check if the loadOp reads from the same memory location (same buffer,
|
|
/// same indices and same properties) as written by the storeOp.
|
|
static bool
|
|
loadsFromSameMemoryLocationWrittenBy(Operation *loadOp, Operation *storeOp,
|
|
const IRMapping &firstToSecondPloopIVsMap,
|
|
OpBuilder &b) {
|
|
if (!loadOp || !storeOp)
|
|
return false;
|
|
// Support only these memory-reading ops for now
|
|
if (!isa<memref::LoadOp, vector::TransferReadOp, vector::LoadOp>(loadOp))
|
|
return false;
|
|
bool accessSameMemory =
|
|
llvm::TypeSwitch<Operation *, bool>(loadOp)
|
|
.Case([&](memref::LoadOp memLoadOp) {
|
|
if (auto memStoreOp = dyn_cast<memref::StoreOp>(storeOp))
|
|
return opsAccessSameIndices(memLoadOp, memStoreOp,
|
|
firstToSecondPloopIVsMap, b);
|
|
if (auto vecWriteOp = dyn_cast<vector::TransferWriteOp>(storeOp))
|
|
return loadMatchesVectorWrite(memLoadOp, vecWriteOp,
|
|
firstToSecondPloopIVsMap);
|
|
if (auto vecStoreOp = dyn_cast<vector::StoreOp>(storeOp))
|
|
return loadMatchesVectorStore(memLoadOp, vecStoreOp,
|
|
firstToSecondPloopIVsMap);
|
|
return false;
|
|
})
|
|
.Case([&](vector::TransferReadOp vecReadOp) {
|
|
auto vecWriteOp = dyn_cast<vector::TransferWriteOp>(storeOp);
|
|
if (!vecWriteOp)
|
|
return false;
|
|
return opsAccessSameIndices(vecReadOp, vecWriteOp,
|
|
firstToSecondPloopIVsMap, b) &&
|
|
(vecReadOp.getMask() == vecWriteOp.getMask()) &&
|
|
(vecReadOp.getInBounds() == vecWriteOp.getInBounds());
|
|
})
|
|
.Case([&](vector::LoadOp vecLoadOp) {
|
|
auto vecStoreOp = dyn_cast<vector::StoreOp>(storeOp);
|
|
if (!vecStoreOp)
|
|
return false;
|
|
return opsAccessSameIndices(vecLoadOp, vecStoreOp,
|
|
firstToSecondPloopIVsMap, b) &&
|
|
(vecLoadOp.getAlignment() == vecStoreOp.getAlignment());
|
|
})
|
|
.Default([](Operation *) { return false; });
|
|
return accessSameMemory;
|
|
}
|
|
|
|
static Value getStoreOpTargetBuffer(Operation *op) {
|
|
return llvm::TypeSwitch<Operation *, Value>(op)
|
|
.Case([&](memref::StoreOp storeOp) { return storeOp.getMemRef(); })
|
|
.Case([&](vector::TransferWriteOp writeOp) { return writeOp.getBase(); })
|
|
.Case([&](vector::StoreOp vecStoreOp) { return vecStoreOp.getBase(); })
|
|
.Default([](Operation *) { return Value(); });
|
|
}
|
|
|
|
/// To be called when `mayAlias(val1, val2)` is true. Check if the potential
|
|
/// aliasing between the loadOp and storeOp can be resolved by analyzing their
|
|
/// access patterns.
|
|
static bool canResolveAlias(Operation *loadOp, Operation *storeOp,
|
|
const IRMapping &loopsIVsMap) {
|
|
if (auto transfWriteOp = dyn_cast<vector::TransferWriteOp>(storeOp);
|
|
transfWriteOp && isa<memref::LoadOp>(loadOp))
|
|
return loadMatchesVectorWrite(cast<memref::LoadOp>(loadOp), transfWriteOp,
|
|
loopsIVsMap);
|
|
if (auto vecStoreOp = dyn_cast<vector::StoreOp>(storeOp);
|
|
vecStoreOp && isa<memref::LoadOp>(loadOp))
|
|
return loadMatchesVectorStore(cast<memref::LoadOp>(loadOp), vecStoreOp,
|
|
loopsIVsMap);
|
|
return false;
|
|
}
|
|
|
|
/// Check that the parallel loops have no mixed access to the same buffers.
|
|
/// Return `true` if the second parallel loop does not read or write the buffers
|
|
/// written by the first loop using different indices.
|
|
static bool haveNoDataDependenciesExceptSameIndex(
|
|
ParallelOp firstPloop, ParallelOp secondPloop,
|
|
const IRMapping &firstToSecondPloopIndices,
|
|
llvm::function_ref<bool(Value, Value)> mayAlias, OpBuilder &b) {
|
|
// Map buffers to their store/write ops in the firstPloop
|
|
DenseMap<Value, SmallVector<Operation *>> bufferStoresInFirstPloop;
|
|
// Record all the memory buffers used in store/write ops found in firstPloop
|
|
llvm::SmallSetVector<Value, 4> buffersWrittenInFirstPloop;
|
|
|
|
auto collectStoreOpsInWalk = [&](Operation *op) {
|
|
auto memOpInterf = dyn_cast_if_present<MemoryEffectOpInterface>(op);
|
|
// Ignore ops that don't write to memory
|
|
if (!memOpInterf || (!memOpInterf.hasEffect<MemoryEffects::Write>() &&
|
|
!memOpInterf.hasEffect<MemoryEffects::Free>()))
|
|
return WalkResult::advance();
|
|
|
|
// Only these memory-writing ops are supported for now:
|
|
// memref.store, vector.transfer_write, vector.store
|
|
Value storeOpBase = getStoreOpTargetBuffer(op);
|
|
if (!storeOpBase)
|
|
return WalkResult::interrupt();
|
|
|
|
// Expect the base operand to be a Memref
|
|
MemrefValue storeOpBaseMemref = dyn_cast<MemrefValue>(storeOpBase);
|
|
if (!storeOpBaseMemref)
|
|
return WalkResult::interrupt();
|
|
// Get the original memref buffer, skipping full view-like ops
|
|
Value buffer = memref::skipFullyAliasingOperations(storeOpBaseMemref);
|
|
bufferStoresInFirstPloop[buffer].push_back(op);
|
|
buffersWrittenInFirstPloop.insert(buffer);
|
|
return WalkResult::advance();
|
|
};
|
|
|
|
// Walk the first parallel loop to collect all store/write ops and their
|
|
// target buffers
|
|
if (firstPloop.getBody()->walk(collectStoreOpsInWalk).wasInterrupted())
|
|
return false;
|
|
|
|
// Check that this load/read op encountered while walking the second parallel
|
|
// loop does not have incompatible data dependencies with the store/write ops
|
|
// collected from the first parallel loop: the loops can be fused only if in
|
|
// the 2nd loop there are no loads/stores from/to the buffers written in the
|
|
// 1st loop, except when on the same exact memory location (same indices) as
|
|
// written in the 1st loop.
|
|
auto checkLoadInWalkHasNoIncompatibleDataDeps = [&](Operation *loadOp) {
|
|
auto memOpInterf = dyn_cast_if_present<MemoryEffectOpInterface>(loadOp);
|
|
// To be conservative, we should stop on ops that don't advertise their
|
|
// memory effects. However, many ops don't implement MemoryEffectOpInterface
|
|
// yet, so for now we just skip them.
|
|
// TODO: once more ops add MemoryEffectOpInterface, interrupt the walk here.
|
|
if (!memOpInterf &&
|
|
!loadOp->hasTrait<mlir::OpTrait::HasRecursiveMemoryEffects>())
|
|
return WalkResult::advance();
|
|
// Ignore ops that don't read from memory, and wrapping ops that have nested
|
|
// memory effects (e.g. loops, conditionals) as they will be analyzed when
|
|
// visiting their nested ops.
|
|
if ((!memOpInterf &&
|
|
loadOp->hasTrait<mlir::OpTrait::HasRecursiveMemoryEffects>()) ||
|
|
(memOpInterf && !memOpInterf.hasEffect<MemoryEffects::Read>()))
|
|
return WalkResult::advance();
|
|
// Support only these memory-reading ops for now
|
|
if (!isa<memref::LoadOp, vector::TransferReadOp, vector::LoadOp>(loadOp) ||
|
|
!isa<MemrefValue>(loadOp->getOperand(0)))
|
|
return WalkResult::interrupt();
|
|
|
|
MemrefValue loadOpBase = cast<MemrefValue>(loadOp->getOperand(0));
|
|
MemrefValue loadedOrigBuf = memref::skipFullyAliasingOperations(loadOpBase);
|
|
|
|
for (Value storedMem : buffersWrittenInFirstPloop)
|
|
if ((storedMem != loadedOrigBuf) && mayAlias(storedMem, loadedOrigBuf) &&
|
|
!llvm::all_of(bufferStoresInFirstPloop[storedMem],
|
|
[&](Operation *storeOp) {
|
|
return canResolveAlias(loadOp, storeOp,
|
|
firstToSecondPloopIndices);
|
|
})) {
|
|
return WalkResult::interrupt();
|
|
}
|
|
|
|
auto writeOpsIt = bufferStoresInFirstPloop.find(loadedOrigBuf);
|
|
if (writeOpsIt == bufferStoresInFirstPloop.end())
|
|
return WalkResult::advance();
|
|
// Store/write ops to this buffer in the firstPloop
|
|
SmallVector<mlir::Operation *> &writeOps = writeOpsIt->second;
|
|
|
|
// If the first loop has no writes to this buffer, continue
|
|
if (writeOps.empty())
|
|
return WalkResult::advance();
|
|
|
|
Operation *writeOp = writeOps.front();
|
|
|
|
// In the first parallel loop, multiple writes to the same memref are
|
|
// allowed only on the same memory location
|
|
if (!llvm::all_of(writeOps, [&](Operation *otherWriteOp) {
|
|
return opsWriteSameMemLocation(writeOp, otherWriteOp);
|
|
})) {
|
|
return WalkResult::interrupt();
|
|
}
|
|
|
|
// Check that the load in secondPloop reads from the same memory location as
|
|
// written by the corresponding store in firstPloop
|
|
if (!loadsFromSameMemoryLocationWrittenBy(loadOp, writeOp,
|
|
firstToSecondPloopIndices, b)) {
|
|
return WalkResult::interrupt();
|
|
}
|
|
|
|
return WalkResult::advance();
|
|
};
|
|
|
|
// Walk the second parallel loop to check load/read ops against the stores
|
|
// collected from the first parallel loop.
|
|
return !secondPloop.getBody()
|
|
->walk(checkLoadInWalkHasNoIncompatibleDataDeps)
|
|
.wasInterrupted();
|
|
}
|
|
|
|
/// Check that in each loop there are no read ops on the buffers written
|
|
/// by the other loop, except when reading from the same exact memory location
|
|
/// (same indices) as written in the other loop.
|
|
static bool
|
|
noIncompatibleDataDependencies(ParallelOp firstPloop, ParallelOp secondPloop,
|
|
const IRMapping &firstToSecondPloopIndices,
|
|
llvm::function_ref<bool(Value, Value)> mayAlias,
|
|
OpBuilder &b) {
|
|
if (!haveNoDataDependenciesExceptSameIndex(
|
|
firstPloop, secondPloop, firstToSecondPloopIndices, mayAlias, b))
|
|
return false;
|
|
|
|
IRMapping secondToFirstPloopIndices;
|
|
secondToFirstPloopIndices.map(secondPloop.getBody()->getArguments(),
|
|
firstPloop.getBody()->getArguments());
|
|
return haveNoDataDependenciesExceptSameIndex(
|
|
secondPloop, firstPloop, secondToFirstPloopIndices, mayAlias, b);
|
|
}
|
|
|
|
/// Check if fusion of the two parallel loops is legal:
|
|
/// i.e. no nested parallel loops, equal iteration spaces,
|
|
/// and no incompatible data dependencies between the loops.
|
|
static bool isFusionLegal(ParallelOp firstPloop, ParallelOp secondPloop,
|
|
const IRMapping &firstToSecondPloopIndices,
|
|
llvm::function_ref<bool(Value, Value)> mayAlias,
|
|
OpBuilder &b) {
|
|
return !hasNestedParallelOp(firstPloop) &&
|
|
!hasNestedParallelOp(secondPloop) &&
|
|
equalIterationSpaces(firstPloop, secondPloop) &&
|
|
noIncompatibleDataDependencies(firstPloop, secondPloop,
|
|
firstToSecondPloopIndices, mayAlias, b);
|
|
}
|
|
|
|
/// Prepend operations of firstPloop's body into secondPloop's body.
|
|
/// Update secondPloop with new loop.
|
|
static void fuseIfLegal(ParallelOp firstPloop, ParallelOp &secondPloop,
|
|
OpBuilder builder,
|
|
llvm::function_ref<bool(Value, Value)> mayAlias) {
|
|
Block *block1 = firstPloop.getBody();
|
|
Block *block2 = secondPloop.getBody();
|
|
IRMapping firstToSecondPloopIndices;
|
|
firstToSecondPloopIndices.map(block1->getArguments(), block2->getArguments());
|
|
|
|
if (!isFusionLegal(firstPloop, secondPloop, firstToSecondPloopIndices,
|
|
mayAlias, builder))
|
|
return;
|
|
|
|
DominanceInfo dom;
|
|
// We are fusing first loop into second, make sure there are no users of the
|
|
// first loop results between loops.
|
|
for (Operation *user : firstPloop->getUsers())
|
|
if (!dom.properlyDominates(secondPloop, user, /*enclosingOpOk*/ false))
|
|
return;
|
|
|
|
ValueRange inits1 = firstPloop.getInitVals();
|
|
ValueRange inits2 = secondPloop.getInitVals();
|
|
|
|
SmallVector<Value> newInitVars(inits1.begin(), inits1.end());
|
|
newInitVars.append(inits2.begin(), inits2.end());
|
|
|
|
IRRewriter b(builder);
|
|
b.setInsertionPoint(secondPloop);
|
|
auto newSecondPloop = ParallelOp::create(
|
|
b, secondPloop.getLoc(), secondPloop.getLowerBound(),
|
|
secondPloop.getUpperBound(), secondPloop.getStep(), newInitVars);
|
|
|
|
Block *newBlock = newSecondPloop.getBody();
|
|
auto term1 = cast<ReduceOp>(block1->getTerminator());
|
|
auto term2 = cast<ReduceOp>(block2->getTerminator());
|
|
|
|
b.inlineBlockBefore(block2, newBlock, newBlock->begin(),
|
|
newBlock->getArguments());
|
|
b.inlineBlockBefore(block1, newBlock, newBlock->begin(),
|
|
newBlock->getArguments());
|
|
|
|
ValueRange results = newSecondPloop.getResults();
|
|
if (!results.empty()) {
|
|
b.setInsertionPointToEnd(newBlock);
|
|
|
|
ValueRange reduceArgs1 = term1.getOperands();
|
|
ValueRange reduceArgs2 = term2.getOperands();
|
|
SmallVector<Value> newReduceArgs(reduceArgs1.begin(), reduceArgs1.end());
|
|
newReduceArgs.append(reduceArgs2.begin(), reduceArgs2.end());
|
|
|
|
auto newReduceOp = scf::ReduceOp::create(b, term2.getLoc(), newReduceArgs);
|
|
|
|
for (auto &&[i, reg] : llvm::enumerate(llvm::concat<Region>(
|
|
term1.getReductions(), term2.getReductions()))) {
|
|
Block &oldRedBlock = reg.front();
|
|
Block &newRedBlock = newReduceOp.getReductions()[i].front();
|
|
b.inlineBlockBefore(&oldRedBlock, &newRedBlock, newRedBlock.begin(),
|
|
newRedBlock.getArguments());
|
|
}
|
|
|
|
firstPloop.replaceAllUsesWith(results.take_front(inits1.size()));
|
|
secondPloop.replaceAllUsesWith(results.take_back(inits2.size()));
|
|
}
|
|
term1->erase();
|
|
term2->erase();
|
|
firstPloop.erase();
|
|
secondPloop.erase();
|
|
secondPloop = newSecondPloop;
|
|
}
|
|
|
|
void mlir::scf::naivelyFuseParallelOps(
|
|
Region ®ion, llvm::function_ref<bool(Value, Value)> mayAlias) {
|
|
OpBuilder b(region);
|
|
// Consider every single block and attempt to fuse adjacent loops.
|
|
SmallVector<SmallVector<ParallelOp>, 1> ploopChains;
|
|
for (auto &block : region) {
|
|
ploopChains.clear();
|
|
ploopChains.push_back({});
|
|
|
|
// Not using `walk()` to traverse only top-level parallel loops and also
|
|
// make sure that there are no side-effecting ops between the parallel
|
|
// loops.
|
|
bool noSideEffects = true;
|
|
for (auto &op : block) {
|
|
if (auto ploop = dyn_cast<ParallelOp>(op)) {
|
|
if (noSideEffects) {
|
|
ploopChains.back().push_back(ploop);
|
|
} else {
|
|
ploopChains.push_back({ploop});
|
|
noSideEffects = true;
|
|
}
|
|
continue;
|
|
}
|
|
// TODO: Handle region side effects properly.
|
|
noSideEffects &= isMemoryEffectFree(&op) && op.getNumRegions() == 0;
|
|
}
|
|
for (MutableArrayRef<ParallelOp> ploops : ploopChains) {
|
|
for (int i = 0, e = ploops.size(); i + 1 < e; ++i)
|
|
fuseIfLegal(ploops[i], ploops[i + 1], b, mayAlias);
|
|
}
|
|
}
|
|
}
|
|
|
|
namespace {
|
|
struct ParallelLoopFusion
|
|
: public impl::SCFParallelLoopFusionBase<ParallelLoopFusion> {
|
|
void runOnOperation() override {
|
|
auto &aa = getAnalysis<AliasAnalysis>();
|
|
|
|
auto mayAlias = [&](Value val1, Value val2) -> bool {
|
|
// If the memref is defined in one of the parallel loops body, careful
|
|
// alias analysis is needed.
|
|
// TODO: check if this is still needed as a separate check.
|
|
auto val1Def = val1.getDefiningOp();
|
|
auto val2Def = val2.getDefiningOp();
|
|
auto val1Loop =
|
|
val1Def ? val1Def->getParentOfType<ParallelOp>() : nullptr;
|
|
auto val2Loop =
|
|
val2Def ? val2Def->getParentOfType<ParallelOp>() : nullptr;
|
|
if (val1Loop != val2Loop)
|
|
return true;
|
|
|
|
return !aa.alias(val1, val2).isNo();
|
|
};
|
|
|
|
getOperation()->walk([&](Operation *child) {
|
|
for (Region ®ion : child->getRegions())
|
|
naivelyFuseParallelOps(region, mayAlias);
|
|
});
|
|
}
|
|
};
|
|
} // namespace
|
|
|
|
std::unique_ptr<Pass> mlir::createParallelLoopFusionPass() {
|
|
return std::make_unique<ParallelLoopFusion>();
|
|
}
|