[mlir][scf] Refactor and improve ParallelLoopFusion (#179284)
Refactor and extend the scf::ParalleLoopFusion pass: - Refactor code, rename functions and add comments to improve readability - Make the dependency analysis safer by checking for read-after-write dependencies also with vector.load/store & vector.transfer_read/write ops, in addition to memref.load/store, and bail out when other unsupported ops with memory effects are found. - Extend the cases when the fusion is applied: allow fusing also when one of the two loops reads/writes to memory through a full view/alias of the buffer (read/written by the dual operation in the other loop) that can be trivially resolved, including rank-reducing full subviews.
This commit is contained in:
parent
88872a7cf8
commit
c5ae550344
@ -157,6 +157,21 @@ void resolveSourceIndicesCollapseShape(Location loc, PatternRewriter &rewriter,
|
||||
ValueRange indices,
|
||||
SmallVectorImpl<Value> &sourceIndices);
|
||||
|
||||
/// Given the 'indices' of a load/store operation where the memref is a result
|
||||
/// of a rank-reducing full subview op, returns the indices w.r.t to the source
|
||||
/// memref of the memref.subview op. For example
|
||||
///
|
||||
/// %alias = memref.subview %src[0, 0, 0][1, 2, 2][1, 1, 1]: memref<1x2x2xf32>
|
||||
/// to memref<2x2xf32>
|
||||
/// %val = memref.load %alias[%i, %j] : memref<2x2xf32>
|
||||
///
|
||||
/// could be folded into
|
||||
///
|
||||
/// %val = memref.load %src[0, %i, %j] : memref<1x2x2xf32>
|
||||
LogicalResult resolveSourceIndicesRankReducingSubview(
|
||||
Location loc, OpBuilder &b, memref::SubViewOp subViewOp, ValueRange indices,
|
||||
SmallVectorImpl<Value> &sourceIndices);
|
||||
|
||||
} // namespace memref
|
||||
} // namespace mlir
|
||||
|
||||
|
||||
@ -18,6 +18,7 @@
|
||||
#include "mlir/Support/LLVM.h"
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include <optional>
|
||||
#include <tuple>
|
||||
|
||||
namespace mlir {
|
||||
class Location;
|
||||
@ -248,6 +249,12 @@ FailureOr<scf::ParallelOp> parallelLoopUnrollByFactors(
|
||||
function_ref<void(unsigned, Operation *, OpBuilder)> annotateFn = nullptr,
|
||||
IRMapping *clonedToSrcOpsMap = nullptr);
|
||||
|
||||
/// Get constant loop bounds and steps for each of the induction variables of
|
||||
/// the given loop operation, if all the loop's ranges are constant. Each entry
|
||||
/// in the returned vector is a tuple (lowerBound, upperBound, step).
|
||||
llvm::SmallVector<std::tuple<int64_t, int64_t, int64_t>>
|
||||
getConstLoopBounds(mlir::LoopLikeOpInterface loopOp);
|
||||
|
||||
/// Get constant trip counts for each of the induction variables of the given
|
||||
/// loop operation. If any of the loop's trip counts is not constant, return an
|
||||
/// empty vector.
|
||||
|
||||
@ -286,5 +286,46 @@ void resolveSourceIndicesCollapseShape(Location loc, PatternRewriter &rewriter,
|
||||
}
|
||||
}
|
||||
|
||||
LogicalResult resolveSourceIndicesRankReducingSubview(
|
||||
Location loc, OpBuilder &b, memref::SubViewOp subViewOp, ValueRange indices,
|
||||
SmallVectorImpl<Value> &sourceIndices) {
|
||||
if (!subViewOp.hasZeroOffset() || !subViewOp.hasUnitStride())
|
||||
return failure();
|
||||
|
||||
MemRefType srcType = subViewOp.getSourceType();
|
||||
MemRefType resType = subViewOp.getType();
|
||||
unsigned srcRank = srcType.getRank();
|
||||
unsigned resRank = resType.getRank();
|
||||
if (srcRank <= resRank || indices.size() != resRank)
|
||||
return failure();
|
||||
|
||||
auto droppedDims = subViewOp.getDroppedDims();
|
||||
if (droppedDims.none() || droppedDims.count() != srcRank - resRank)
|
||||
return failure();
|
||||
|
||||
auto mixedSizes = subViewOp.getMixedSizes();
|
||||
if (mixedSizes.size() != srcRank)
|
||||
return failure();
|
||||
|
||||
unsigned resultDim = 0;
|
||||
for (unsigned sourceDim = 0; sourceDim < srcRank; ++sourceDim) {
|
||||
if (droppedDims.test(sourceDim)) {
|
||||
auto sizeCst = getConstantIntValue(mixedSizes[sourceDim]);
|
||||
if (!sizeCst || *sizeCst != 1)
|
||||
return failure();
|
||||
sourceIndices.push_back(
|
||||
getValueOrCreateConstantIndexOp(b, loc, b.getIndexAttr(0)));
|
||||
continue;
|
||||
}
|
||||
if (resultDim >= indices.size())
|
||||
return failure();
|
||||
sourceIndices.push_back(indices[resultDim++]);
|
||||
}
|
||||
if (resultDim != indices.size())
|
||||
return failure();
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
} // namespace memref
|
||||
} // namespace mlir
|
||||
|
||||
@ -32,6 +32,7 @@ add_mlir_dialect_library(MLIRSCFTransforms
|
||||
MLIRBufferizationTransforms
|
||||
MLIRDestinationStyleOpInterface
|
||||
MLIRDialectUtils
|
||||
MLIRIndexDialect
|
||||
MLIRIR
|
||||
MLIRMemRefDialect
|
||||
MLIRPass
|
||||
|
||||
@ -13,15 +13,31 @@
|
||||
#include "mlir/Dialect/SCF/Transforms/Passes.h"
|
||||
|
||||
#include "mlir/Analysis/AliasAnalysis.h"
|
||||
#include "mlir/Dialect/Affine/IR/AffineOps.h"
|
||||
#include "mlir/Dialect/Index/IR/IndexOps.h"
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
|
||||
#include "mlir/Dialect/SCF/IR/SCF.h"
|
||||
#include "mlir/Dialect/SCF/Transforms/Transforms.h"
|
||||
#include "mlir/Dialect/SCF/Utils/Utils.h"
|
||||
#include "mlir/Dialect/Vector/IR/VectorOps.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/IRMapping.h"
|
||||
#include "mlir/IR/Matchers.h"
|
||||
#include "mlir/IR/OpDefinition.h"
|
||||
#include "mlir/IR/OperationSupport.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/Interfaces/SideEffectInterfaces.h"
|
||||
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/SetVector.h"
|
||||
#include "llvm/ADT/SmallBitVector.h"
|
||||
#include "llvm/ADT/TypeSwitch.h"
|
||||
|
||||
#include <optional>
|
||||
#include <tuple>
|
||||
|
||||
namespace mlir {
|
||||
#define GEN_PASS_DEF_SCFPARALLELLOOPFUSION
|
||||
#include "mlir/Dialect/SCF/Transforms/Passes.h.inc"
|
||||
@ -55,114 +71,670 @@ static bool equalIterationSpaces(ParallelOp firstPloop,
|
||||
matchOperands(firstPloop.getStep(), secondPloop.getStep());
|
||||
}
|
||||
|
||||
/// Checks if the parallel loops have mixed access to the same buffers. Returns
|
||||
/// `true` if the first parallel loop writes to the same indices that the second
|
||||
/// loop reads.
|
||||
static bool haveNoReadsAfterWriteExceptSameIndex(
|
||||
ParallelOp firstPloop, ParallelOp secondPloop,
|
||||
const IRMapping &firstToSecondPloopIndices,
|
||||
llvm::function_ref<bool(Value, Value)> mayAlias) {
|
||||
DenseMap<Value, SmallVector<ValueRange, 1>> bufferStores;
|
||||
SmallVector<Value> bufferStoresVec;
|
||||
firstPloop.getBody()->walk([&](memref::StoreOp store) {
|
||||
bufferStores[store.getMemRef()].push_back(store.getIndices());
|
||||
bufferStoresVec.emplace_back(store.getMemRef());
|
||||
});
|
||||
auto walkResult = secondPloop.getBody()->walk([&](memref::LoadOp load) {
|
||||
Value loadMem = load.getMemRef();
|
||||
// Stop if the memref is defined in secondPloop body. Careful alias analysis
|
||||
// is needed.
|
||||
auto *memrefDef = loadMem.getDefiningOp();
|
||||
if (memrefDef && memrefDef->getBlock() == load->getBlock())
|
||||
return WalkResult::interrupt();
|
||||
|
||||
for (Value store : bufferStoresVec)
|
||||
if (store != loadMem && mayAlias(store, loadMem))
|
||||
return WalkResult::interrupt();
|
||||
|
||||
auto write = bufferStores.find(loadMem);
|
||||
if (write == bufferStores.end())
|
||||
return WalkResult::advance();
|
||||
|
||||
// Check that at last one store was retrieved
|
||||
if (write->second.empty())
|
||||
return WalkResult::interrupt();
|
||||
|
||||
auto storeIndices = write->second.front();
|
||||
|
||||
// Multiple writes to the same memref are allowed only on the same indices
|
||||
for (const auto &othStoreIndices : write->second) {
|
||||
if (othStoreIndices != storeIndices)
|
||||
return WalkResult::interrupt();
|
||||
}
|
||||
|
||||
// Check that the load indices of secondPloop coincide with store indices of
|
||||
// firstPloop for the same memrefs.
|
||||
auto loadIndices = load.getIndices();
|
||||
if (storeIndices.size() != loadIndices.size())
|
||||
return WalkResult::interrupt();
|
||||
for (int i = 0, e = storeIndices.size(); i < e; ++i) {
|
||||
if (firstToSecondPloopIndices.lookupOrDefault(storeIndices[i]) !=
|
||||
loadIndices[i]) {
|
||||
auto *storeIndexDefOp = storeIndices[i].getDefiningOp();
|
||||
auto *loadIndexDefOp = loadIndices[i].getDefiningOp();
|
||||
if (storeIndexDefOp && loadIndexDefOp) {
|
||||
if (!isMemoryEffectFree(storeIndexDefOp))
|
||||
return WalkResult::interrupt();
|
||||
if (!isMemoryEffectFree(loadIndexDefOp))
|
||||
return WalkResult::interrupt();
|
||||
if (!OperationEquivalence::isEquivalentTo(
|
||||
storeIndexDefOp, loadIndexDefOp,
|
||||
[&](Value storeIndex, Value loadIndex) {
|
||||
if (firstToSecondPloopIndices.lookupOrDefault(storeIndex) !=
|
||||
firstToSecondPloopIndices.lookupOrDefault(loadIndex))
|
||||
return failure();
|
||||
else
|
||||
return success();
|
||||
},
|
||||
/*markEquivalent=*/nullptr,
|
||||
OperationEquivalence::Flags::IgnoreLocations)) {
|
||||
return WalkResult::interrupt();
|
||||
}
|
||||
} else {
|
||||
return WalkResult::interrupt();
|
||||
}
|
||||
}
|
||||
}
|
||||
return WalkResult::advance();
|
||||
});
|
||||
return !walkResult.wasInterrupted();
|
||||
/// Check if both operations are the same type of memory write op and
|
||||
/// write to the same memory location (same buffer and same indices).
|
||||
static bool opsWriteSameMemLocation(Operation *op1, Operation *op2) {
|
||||
if (!op1 || !op2 || op1->getName() != op2->getName())
|
||||
return false;
|
||||
if (op1 == op2)
|
||||
return true;
|
||||
// support only these memory-writing ops for now
|
||||
if (!isa<memref::StoreOp, vector::TransferWriteOp, vector::StoreOp>(op1))
|
||||
return false;
|
||||
bool opsAreIdentical =
|
||||
llvm::TypeSwitch<Operation *, bool>(op1)
|
||||
.Case([&](memref::StoreOp storeOp1) {
|
||||
auto storeOp2 = cast<memref::StoreOp>(op2);
|
||||
return (storeOp1.getMemRef() == storeOp2.getMemRef()) &&
|
||||
(storeOp1.getIndices() == storeOp2.getIndices());
|
||||
})
|
||||
.Case([&](vector::TransferWriteOp writeOp1) {
|
||||
auto writeOp2 = cast<vector::TransferWriteOp>(op2);
|
||||
return (writeOp1.getBase() == writeOp2.getBase()) &&
|
||||
(writeOp1.getIndices() == writeOp2.getIndices()) &&
|
||||
(writeOp1.getMask() == writeOp2.getMask()) &&
|
||||
(writeOp1.getValueToStore().getType() ==
|
||||
writeOp2.getValueToStore().getType()) &&
|
||||
(writeOp1.getInBounds() == writeOp2.getInBounds());
|
||||
})
|
||||
.Case([&](vector::StoreOp vecStoreOp1) {
|
||||
auto vecStoreOp2 = cast<vector::StoreOp>(op2);
|
||||
return (vecStoreOp1.getBase() == vecStoreOp2.getBase()) &&
|
||||
(vecStoreOp1.getIndices() == vecStoreOp2.getIndices()) &&
|
||||
(vecStoreOp1.getValueToStore().getType() ==
|
||||
vecStoreOp2.getValueToStore().getType()) &&
|
||||
(vecStoreOp1.getAlignment() == vecStoreOp2.getAlignment()) &&
|
||||
(vecStoreOp1.getNontemporal() ==
|
||||
vecStoreOp2.getNontemporal());
|
||||
})
|
||||
.Default([](Operation *) { return false; });
|
||||
return opsAreIdentical;
|
||||
}
|
||||
|
||||
/// Analyzes dependencies in the most primitive way by checking simple read and
|
||||
/// write patterns.
|
||||
static LogicalResult
|
||||
verifyDependencies(ParallelOp firstPloop, ParallelOp secondPloop,
|
||||
const IRMapping &firstToSecondPloopIndices,
|
||||
llvm::function_ref<bool(Value, Value)> mayAlias) {
|
||||
if (!haveNoReadsAfterWriteExceptSameIndex(
|
||||
firstPloop, secondPloop, firstToSecondPloopIndices, mayAlias))
|
||||
return failure();
|
||||
/// Check if val1 (from the first parallel loop) and val2 (from the
|
||||
/// second) are equivalent, considering the mapping of induction variables from
|
||||
/// the first to the second parallel loop.
|
||||
static bool valsAreEquivalent(Value val1, Value val2,
|
||||
const IRMapping &loopsIVsMap) {
|
||||
if (val1 == val2 || loopsIVsMap.lookupOrDefault(val1) == val2 ||
|
||||
loopsIVsMap.lookupOrDefault(val2) == val1)
|
||||
return true;
|
||||
Operation *val1DefOp = val1.getDefiningOp();
|
||||
Operation *val2DefOp = val2.getDefiningOp();
|
||||
if (!val1DefOp || !val2DefOp)
|
||||
return false;
|
||||
if (!isMemoryEffectFree(val1DefOp) || !isMemoryEffectFree(val2DefOp))
|
||||
return false;
|
||||
return OperationEquivalence::isEquivalentTo(
|
||||
val1DefOp, val2DefOp,
|
||||
[&](Value v1, Value v2) {
|
||||
return success(loopsIVsMap.lookupOrDefault(v1) == v2 ||
|
||||
loopsIVsMap.lookupOrDefault(v2) == v1);
|
||||
},
|
||||
/*markEquivalent=*/nullptr, OperationEquivalence::Flags::IgnoreLocations);
|
||||
}
|
||||
|
||||
/// If the `expr` value is the result of an integer addition of `base` and a
|
||||
/// constant, return the constant.
|
||||
static std::optional<int64_t> getAddConstant(Value expr, Value base,
|
||||
const IRMapping &loopsIVsMap) {
|
||||
if (auto addOp = expr.getDefiningOp<arith::AddIOp>()) {
|
||||
if (auto constOp = getConstantIntValue(addOp.getLhs());
|
||||
constOp && valsAreEquivalent(addOp.getRhs(), base, loopsIVsMap))
|
||||
return constOp.value();
|
||||
if (auto constOp = getConstantIntValue(addOp.getRhs());
|
||||
constOp && valsAreEquivalent(addOp.getLhs(), base, loopsIVsMap))
|
||||
return constOp.value();
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
if (auto addOp = expr.getDefiningOp<index::AddOp>()) {
|
||||
if (auto constOp = getConstantIntValue(addOp.getLhs());
|
||||
constOp && valsAreEquivalent(addOp.getRhs(), base, loopsIVsMap))
|
||||
return constOp.value();
|
||||
if (auto constOp = getConstantIntValue(addOp.getRhs());
|
||||
constOp && valsAreEquivalent(addOp.getLhs(), base, loopsIVsMap))
|
||||
return constOp.value();
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
if (auto applyOp = expr.getDefiningOp<affine::AffineApplyOp>()) {
|
||||
AffineMap map = applyOp.getAffineMap();
|
||||
if (map.getNumResults() != 1 || map.getNumDims() != 1 ||
|
||||
map.getNumSymbols() != 0)
|
||||
return std::nullopt;
|
||||
if (!valsAreEquivalent(applyOp.getOperand(0), base, loopsIVsMap))
|
||||
return std::nullopt;
|
||||
AffineExpr result = map.getResult(0);
|
||||
auto bin = dyn_cast<AffineBinaryOpExpr>(result);
|
||||
if (!bin || bin.getKind() != AffineExprKind::Add)
|
||||
return std::nullopt;
|
||||
auto lhsDim = dyn_cast<AffineDimExpr>(bin.getLHS());
|
||||
auto rhsDim = dyn_cast<AffineDimExpr>(bin.getRHS());
|
||||
auto lhsConst = dyn_cast<AffineConstantExpr>(bin.getLHS());
|
||||
auto rhsConst = dyn_cast<AffineConstantExpr>(bin.getRHS());
|
||||
if (lhsConst && rhsDim)
|
||||
return lhsConst.getValue();
|
||||
if (rhsConst && lhsDim)
|
||||
return rhsConst.getValue();
|
||||
}
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
// Return true if the scalar load index may hit any element covered by a
|
||||
// vector.store/transfer_write along a single memref dimension. Supported cases:
|
||||
//
|
||||
// 1) Direct index match (with optional offset):
|
||||
// vector.transfer_write %v, %A[%i] : vector<4xf32>, memref<...>
|
||||
// %x = memref.load %A[%i] : memref<...>
|
||||
//
|
||||
// 2) Loop IV range intersects the write range:
|
||||
// vector.transfer_write %v, %A[%c0] : vector<4xf32>, memref<...>
|
||||
// scf.for %k = %c0 to %c4 step %c1 { %x = memref.load %A[%k] }
|
||||
//
|
||||
// 3) Constant index (or IV + constant) within the write range:
|
||||
// vector.transfer_write %v, %A[%c0] : vector<4xf32>, memref<...>
|
||||
// %x = memref.load %A[%c2] : memref<...>
|
||||
// %y = memref.load %A[%i + %c1] : memref<...>
|
||||
//
|
||||
// Args:
|
||||
// - loadIndex: index used by the scalar load for this dimension.
|
||||
// - offset: subview offset for the base memref dimension (if any).
|
||||
// - writeIndex: index used by the transfer_write for this dimension. Can be
|
||||
// null if the dim was dropped by a rank reducing subview, whose result is
|
||||
// written by the vector.write.
|
||||
// - extent: vector size along this dimension (number of elements written).
|
||||
// - loopsIVsMap: IV equivalence map between fused loops.
|
||||
static bool loadIndexWithinWriteRange(Value loadIndex, OpFoldResult offset,
|
||||
Value writeIndex, int64_t extent,
|
||||
const IRMapping &loopsIVsMap) {
|
||||
if (extent <= 0)
|
||||
return false;
|
||||
|
||||
// Extract constant loop bounds for loop IVs (e.g. from scf.for).
|
||||
auto getConstLoopBoundsForIV =
|
||||
[](Value index) -> std::optional<std::tuple<int64_t, int64_t, int64_t>> {
|
||||
auto blockArg = dyn_cast<BlockArgument>(index);
|
||||
if (!blockArg)
|
||||
return std::nullopt;
|
||||
auto *parentOp = blockArg.getOwner()->getParentOp();
|
||||
auto loopLike = dyn_cast<LoopLikeOpInterface>(parentOp);
|
||||
if (!loopLike)
|
||||
return std::nullopt;
|
||||
auto ranges = getConstLoopBounds(loopLike);
|
||||
if (ranges.empty())
|
||||
return std::nullopt;
|
||||
|
||||
auto ivs = loopLike.getLoopInductionVars();
|
||||
if (!ivs)
|
||||
return std::nullopt;
|
||||
auto it = llvm::find(*ivs, blockArg);
|
||||
if (it == ivs->end())
|
||||
return std::nullopt;
|
||||
unsigned pos = std::distance(ivs->begin(), it);
|
||||
if (pos >= ranges.size())
|
||||
return std::nullopt;
|
||||
auto [lb, ub, step] = ranges[pos];
|
||||
return std::make_tuple(lb, ub, step);
|
||||
};
|
||||
|
||||
std::optional<int64_t> offsetConst = getConstantIntValue(offset);
|
||||
std::optional<int64_t> writeConst =
|
||||
writeIndex ? getConstantIntValue(writeIndex) : std::optional<int64_t>(0);
|
||||
if (!writeConst && writeIndex) {
|
||||
// Treat single-iteration IVs as constants for matching.
|
||||
if (auto bounds = getConstLoopBoundsForIV(writeIndex)) {
|
||||
auto [lb, ub, step] = *bounds;
|
||||
if (step > 0 && ub == lb + step)
|
||||
writeConst = lb;
|
||||
}
|
||||
}
|
||||
|
||||
// Check whether a loop IV is fully contained in a constant write range.
|
||||
auto loopIVWithinRange = [](int64_t lb, int64_t ub, int64_t step,
|
||||
int64_t rangeStart, int64_t rangeExtent) -> bool {
|
||||
if (rangeExtent <= 0 || step <= 0)
|
||||
return false;
|
||||
if (ub <= lb)
|
||||
return false;
|
||||
int64_t rangeEnd = rangeStart + rangeExtent;
|
||||
return lb >= rangeStart && ub <= rangeEnd;
|
||||
};
|
||||
|
||||
if (offsetConst && writeConst) {
|
||||
// Constant start of the write range; check constant load or loop IV range.
|
||||
int64_t start = *offsetConst + *writeConst;
|
||||
if (auto loadConst = getConstantIntValue(loadIndex))
|
||||
return (*loadConst >= start && *loadConst < start + extent);
|
||||
if (auto bounds = getConstLoopBoundsForIV(loadIndex)) {
|
||||
auto [lb, ub, step] = *bounds;
|
||||
return loopIVWithinRange(lb, ub, step, start, extent);
|
||||
}
|
||||
}
|
||||
|
||||
if (writeIndex) {
|
||||
// Direct IV match (or IV + constant) against the write index.
|
||||
if (offsetConst && *offsetConst == 0 &&
|
||||
valsAreEquivalent(loadIndex, writeIndex, loopsIVsMap))
|
||||
return true;
|
||||
if (auto addConst = getAddConstant(loadIndex, writeIndex, loopsIVsMap)) {
|
||||
// Match load index of the form writeIndex + C within the write extent.
|
||||
if (offsetConst) {
|
||||
int64_t start = *offsetConst;
|
||||
return (*addConst >= start && *addConst < start + extent);
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
if (auto offsetVal = dyn_cast<Value>(offset)) {
|
||||
// Exact match when extent is 1 and the load hits the offset value.
|
||||
if (extent == 1 && valsAreEquivalent(loadIndex, offsetVal, loopsIVsMap))
|
||||
return true;
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
/// Return the base memref value used by the given memory op.
|
||||
static Value getBaseMemref(Operation *op) {
|
||||
// TODO: use the common interface for memory ops once available.
|
||||
return llvm::TypeSwitch<Operation *, Value>(op)
|
||||
.Case([&](memref::LoadOp load) { return load.getMemRef(); })
|
||||
.Case([&](memref::StoreOp store) { return store.getMemRef(); })
|
||||
.Case([&](vector::TransferReadOp read) { return read.getBase(); })
|
||||
.Case([&](vector::TransferWriteOp write) { return write.getBase(); })
|
||||
.Case([&](vector::LoadOp load) { return load.getBase(); })
|
||||
.Case([&](vector::StoreOp store) { return store.getBase(); })
|
||||
.Default([](Operation *) { return Value(); });
|
||||
}
|
||||
|
||||
/// Recognize scalar memref.load of an element produced by a vector write
|
||||
/// (vector.transfer_write or vector.store, optionally through a rank-reducing
|
||||
/// unit-stride subview) of the same buffer. This covers the pattern where a
|
||||
/// vector write stores a full lane pack and a subsequent scalar load reads an
|
||||
/// element from that lane pack. EXAMPLE:
|
||||
/// vector.transfer_write %V, %arg[%x, %y, ..., 0] {in_bounds = [true]} :
|
||||
/// vector<4xf32>, memref<4xf32, strided<[1], offset: ?>>
|
||||
/// scf.for %iter = %c0 to %c4 step %c1 iter_args(...) -> (f32) {
|
||||
/// %0 = memref.load %arg[%x, %y, ..., %iter] : memref<1x128x16x4xf32>
|
||||
/// ...
|
||||
/// }
|
||||
///
|
||||
static bool isLoadOnWrittenVector(memref::LoadOp loadOp, Value writeBase,
|
||||
ValueRange writeIndices, VectorType vecTy,
|
||||
ArrayRef<int64_t> vectorDimForWriteDim,
|
||||
const IRMapping &ivsMap) {
|
||||
if (!vecTy)
|
||||
return false;
|
||||
|
||||
Value base = writeBase;
|
||||
// The write base if there is no subview, or the subview source otherwise.
|
||||
MemrefValue baseMemref = nullptr;
|
||||
SmallVector<OpFoldResult> offsets;
|
||||
llvm::SmallBitVector droppedDims;
|
||||
bool hasSubview = false;
|
||||
auto *ctx = loadOp.getContext();
|
||||
if (auto subView = base.getDefiningOp<memref::SubViewOp>()) {
|
||||
if (!subView.hasUnitStride())
|
||||
return false;
|
||||
baseMemref = cast<MemrefValue>(subView.getSource());
|
||||
offsets = llvm::to_vector(subView.getMixedOffsets());
|
||||
droppedDims = subView.getDroppedDims();
|
||||
hasSubview = true;
|
||||
} else {
|
||||
baseMemref = dyn_cast<MemrefValue>(base);
|
||||
if (!baseMemref)
|
||||
return false;
|
||||
}
|
||||
|
||||
auto loadIndices = loadOp.getIndices();
|
||||
unsigned baseRank = baseMemref.getType().getRank();
|
||||
if ((loadOp.getMemref() != baseMemref) || (loadIndices.size() != baseRank))
|
||||
return false;
|
||||
|
||||
unsigned writeRank = writeIndices.size();
|
||||
if ((!hasSubview && writeRank != baseRank) ||
|
||||
(hasSubview && offsets.size() != baseRank) ||
|
||||
(vectorDimForWriteDim.size() != writeRank))
|
||||
return false;
|
||||
|
||||
auto zeroAttr = IntegerAttr::get(IndexType::get(ctx), 0);
|
||||
unsigned writeMemrefDim = 0;
|
||||
for (unsigned baseDim : llvm::seq(baseRank)) {
|
||||
bool wasDropped = (hasSubview && droppedDims.test(baseDim));
|
||||
int64_t vectorDim = !wasDropped ? vectorDimForWriteDim[writeMemrefDim] : -1;
|
||||
int64_t extent = 1;
|
||||
if (vectorDim >= 0) {
|
||||
int64_t dimSize = vecTy.getDimSize(vectorDim);
|
||||
if (dimSize == ShapedType::kDynamic)
|
||||
return false;
|
||||
extent = dimSize;
|
||||
}
|
||||
Value writeIndex = !wasDropped ? writeIndices[writeMemrefDim] : Value();
|
||||
OpFoldResult offset =
|
||||
hasSubview ? offsets[baseDim] : OpFoldResult(zeroAttr);
|
||||
if (!loadIndexWithinWriteRange(loadIndices[baseDim], offset, writeIndex,
|
||||
extent, ivsMap))
|
||||
return false;
|
||||
if (!wasDropped)
|
||||
++writeMemrefDim;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
/// Recognize scalar memref.load of an element produced by a
|
||||
/// vector.transfer_write
|
||||
static bool loadMatchesVectorWrite(memref::LoadOp loadOp,
|
||||
vector::TransferWriteOp writeOp,
|
||||
const IRMapping &ivsMap) {
|
||||
auto vecTy = dyn_cast<VectorType>(writeOp.getVector().getType());
|
||||
if (!vecTy)
|
||||
return false;
|
||||
|
||||
unsigned writeRank = writeOp.getIndices().size();
|
||||
AffineMap permutationMap = writeOp.getPermutationMap();
|
||||
if (!permutationMap.isProjectedPermutation() ||
|
||||
permutationMap.getNumResults() != vecTy.getRank() ||
|
||||
permutationMap.getNumDims() != writeRank)
|
||||
return false;
|
||||
|
||||
SmallVector<int64_t> vectorDimForWriteDim(writeRank, -1);
|
||||
for (unsigned vecDim = 0; vecDim < permutationMap.getNumResults(); ++vecDim) {
|
||||
auto dimExpr = dyn_cast<AffineDimExpr>(permutationMap.getResult(vecDim));
|
||||
if (!dimExpr)
|
||||
return false;
|
||||
unsigned writeDim = dimExpr.getPosition();
|
||||
if (writeDim >= writeRank || vectorDimForWriteDim[writeDim] != -1)
|
||||
return false;
|
||||
vectorDimForWriteDim[writeDim] = vecDim;
|
||||
}
|
||||
|
||||
return isLoadOnWrittenVector(loadOp, writeOp.getBase(), writeOp.getIndices(),
|
||||
vecTy, vectorDimForWriteDim, ivsMap);
|
||||
}
|
||||
|
||||
/// Recognize scalar memref.load of an element produced by a vector.store
|
||||
static bool loadMatchesVectorStore(memref::LoadOp loadOp,
|
||||
vector::StoreOp storeOp,
|
||||
const IRMapping &ivsMap) {
|
||||
auto vecTy = dyn_cast<VectorType>(storeOp.getValueToStore().getType());
|
||||
if (!vecTy)
|
||||
return false;
|
||||
|
||||
unsigned writeRank = storeOp.getIndices().size();
|
||||
if (vecTy.getRank() > writeRank)
|
||||
return false;
|
||||
|
||||
SmallVector<int64_t> vectorDimForWriteDim(writeRank, -1);
|
||||
unsigned vecRank = vecTy.getRank();
|
||||
for (unsigned i = 0; i < vecRank; ++i) {
|
||||
unsigned writeDim = writeRank - vecRank + i;
|
||||
vectorDimForWriteDim[writeDim] = i;
|
||||
}
|
||||
|
||||
return isLoadOnWrittenVector(loadOp, storeOp.getBase(), storeOp.getIndices(),
|
||||
vecTy, vectorDimForWriteDim, ivsMap);
|
||||
}
|
||||
|
||||
/// Check if both operations access the same positions of the same
|
||||
/// buffer, but one of the two does it through a rank-reducing full subview of
|
||||
/// the buffer (the other's base). EXAMPLE:
|
||||
/// memref.store %a, %buf[%c0, %i, %j] : memref<1x2x2xf32>
|
||||
/// %alias = memref.subview %buf[0, 0, 0][1, 2, 2][1, 1, 1]: memref<1x2x2xf32>
|
||||
/// to memref<2x2xf32>
|
||||
/// %val = memref.load %alias[%i, %j] : memref<2x2xf32>
|
||||
template <typename OpTy1, typename OpTy2>
|
||||
static bool opsAccessSameIndicesViaRankReducingSubview(
|
||||
OpTy1 op1, OpTy2 op2, const IRMapping &firstToSecondPloopIVsMap,
|
||||
OpBuilder &b) {
|
||||
auto base1 = cast<MemrefValue>(getBaseMemref(op1));
|
||||
auto base2 = cast<MemrefValue>(getBaseMemref(op2));
|
||||
if (!base1 || !base2)
|
||||
return false;
|
||||
|
||||
auto accessThroughTrivialSubviewIsSame =
|
||||
[&b](memref::SubViewOp subView, ValueRange subViewAccess,
|
||||
ValueRange sourceAccess, const IRMapping &ivsMap) -> bool {
|
||||
SmallVector<Value> resolvedSubviewAccess;
|
||||
LogicalResult resolved = resolveSourceIndicesRankReducingSubview(
|
||||
subView.getLoc(), b, subView, subViewAccess, resolvedSubviewAccess);
|
||||
if (failed(resolved) ||
|
||||
(resolvedSubviewAccess.size() != sourceAccess.size()))
|
||||
return false;
|
||||
for (auto [dimIdx, resolvedIndex] :
|
||||
llvm::enumerate(resolvedSubviewAccess)) {
|
||||
if (!matchPattern(resolvedIndex, m_Zero()) &&
|
||||
!valsAreEquivalent(resolvedIndex, sourceAccess[dimIdx], ivsMap))
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
};
|
||||
|
||||
// Case 1: op1 uses a subview of op2's base.
|
||||
if (auto subView = base1.template getDefiningOp<memref::SubViewOp>();
|
||||
subView &&
|
||||
memref::isSameViewOrTrivialAlias(
|
||||
base2, cast<MemrefValue>(subView.getSource())) &&
|
||||
accessThroughTrivialSubviewIsSame(subView, op1.getIndices(),
|
||||
op2.getIndices(),
|
||||
firstToSecondPloopIVsMap))
|
||||
return true;
|
||||
|
||||
// Case 2: op2 uses a subview of op1's base.
|
||||
if (auto subView = base2.template getDefiningOp<memref::SubViewOp>();
|
||||
subView &&
|
||||
memref::isSameViewOrTrivialAlias(
|
||||
base1, cast<MemrefValue>(subView.getSource())) &&
|
||||
accessThroughTrivialSubviewIsSame(subView, op2.getIndices(),
|
||||
op1.getIndices(),
|
||||
firstToSecondPloopIVsMap))
|
||||
return true;
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
/// Check if both memory read/write operations access the same indices
|
||||
/// (considering also the mapping of induction variables from the first to the
|
||||
/// second parallel loop).
|
||||
template <typename OpTy1, typename OpTy2>
|
||||
static bool opsAccessSameIndices(OpTy1 op1, OpTy2 op2,
|
||||
const IRMapping &loopsIVsMap, OpBuilder &b) {
|
||||
auto indices1 = op1.getIndices();
|
||||
auto indices2 = op2.getIndices();
|
||||
if (indices1.size() != indices2.size())
|
||||
return opsAccessSameIndicesViaRankReducingSubview(op1, op2, loopsIVsMap, b);
|
||||
for (auto [idx1, idx2] : llvm::zip(indices1, indices2)) {
|
||||
if (!valsAreEquivalent(idx1, idx2, loopsIVsMap))
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
/// Check if the loadOp reads from the same memory location (same buffer,
|
||||
/// same indices and same properties) as written by the storeOp.
|
||||
static bool
|
||||
loadsFromSameMemoryLocationWrittenBy(Operation *loadOp, Operation *storeOp,
|
||||
const IRMapping &firstToSecondPloopIVsMap,
|
||||
OpBuilder &b) {
|
||||
if (!loadOp || !storeOp)
|
||||
return false;
|
||||
// Support only these memory-reading ops for now
|
||||
if (!isa<memref::LoadOp, vector::TransferReadOp, vector::LoadOp>(loadOp))
|
||||
return false;
|
||||
bool accessSameMemory =
|
||||
llvm::TypeSwitch<Operation *, bool>(loadOp)
|
||||
.Case([&](memref::LoadOp memLoadOp) {
|
||||
if (auto memStoreOp = dyn_cast<memref::StoreOp>(storeOp))
|
||||
return opsAccessSameIndices(memLoadOp, memStoreOp,
|
||||
firstToSecondPloopIVsMap, b);
|
||||
if (auto vecWriteOp = dyn_cast<vector::TransferWriteOp>(storeOp))
|
||||
return loadMatchesVectorWrite(memLoadOp, vecWriteOp,
|
||||
firstToSecondPloopIVsMap);
|
||||
if (auto vecStoreOp = dyn_cast<vector::StoreOp>(storeOp))
|
||||
return loadMatchesVectorStore(memLoadOp, vecStoreOp,
|
||||
firstToSecondPloopIVsMap);
|
||||
return false;
|
||||
})
|
||||
.Case([&](vector::TransferReadOp vecReadOp) {
|
||||
auto vecWriteOp = dyn_cast<vector::TransferWriteOp>(storeOp);
|
||||
if (!vecWriteOp)
|
||||
return false;
|
||||
return opsAccessSameIndices(vecReadOp, vecWriteOp,
|
||||
firstToSecondPloopIVsMap, b) &&
|
||||
(vecReadOp.getMask() == vecWriteOp.getMask()) &&
|
||||
(vecReadOp.getInBounds() == vecWriteOp.getInBounds());
|
||||
})
|
||||
.Case([&](vector::LoadOp vecLoadOp) {
|
||||
auto vecStoreOp = dyn_cast<vector::StoreOp>(storeOp);
|
||||
if (!vecStoreOp)
|
||||
return false;
|
||||
return opsAccessSameIndices(vecLoadOp, vecStoreOp,
|
||||
firstToSecondPloopIVsMap, b) &&
|
||||
(vecLoadOp.getAlignment() == vecStoreOp.getAlignment());
|
||||
})
|
||||
.Default([](Operation *) { return false; });
|
||||
return accessSameMemory;
|
||||
}
|
||||
|
||||
static Value getStoreOpTargetBuffer(Operation *op) {
|
||||
return llvm::TypeSwitch<Operation *, Value>(op)
|
||||
.Case([&](memref::StoreOp storeOp) { return storeOp.getMemRef(); })
|
||||
.Case([&](vector::TransferWriteOp writeOp) { return writeOp.getBase(); })
|
||||
.Case([&](vector::StoreOp vecStoreOp) { return vecStoreOp.getBase(); })
|
||||
.Default([](Operation *) { return Value(); });
|
||||
}
|
||||
|
||||
/// To be called when `mayAlias(val1, val2)` is true. Check if the potential
|
||||
/// aliasing between the loadOp and storeOp can be resolved by analyzing their
|
||||
/// access patterns.
|
||||
static bool canResolveAlias(Operation *loadOp, Operation *storeOp,
|
||||
const IRMapping &loopsIVsMap) {
|
||||
if (auto transfWriteOp = dyn_cast<vector::TransferWriteOp>(storeOp);
|
||||
transfWriteOp && isa<memref::LoadOp>(loadOp))
|
||||
return loadMatchesVectorWrite(cast<memref::LoadOp>(loadOp), transfWriteOp,
|
||||
loopsIVsMap);
|
||||
if (auto vecStoreOp = dyn_cast<vector::StoreOp>(storeOp);
|
||||
vecStoreOp && isa<memref::LoadOp>(loadOp))
|
||||
return loadMatchesVectorStore(cast<memref::LoadOp>(loadOp), vecStoreOp,
|
||||
loopsIVsMap);
|
||||
return false;
|
||||
}
|
||||
|
||||
/// Check that the parallel loops have no mixed access to the same buffers.
|
||||
/// Return `true` if the second parallel loop does not read or write the buffers
|
||||
/// written by the first loop using different indices.
|
||||
static bool haveNoDataDependenciesExceptSameIndex(
|
||||
ParallelOp firstPloop, ParallelOp secondPloop,
|
||||
const IRMapping &firstToSecondPloopIndices,
|
||||
llvm::function_ref<bool(Value, Value)> mayAlias, OpBuilder &b) {
|
||||
// Map buffers to their store/write ops in the firstPloop
|
||||
DenseMap<Value, SmallVector<Operation *>> bufferStoresInFirstPloop;
|
||||
// Record all the memory buffers used in store/write ops found in firstPloop
|
||||
llvm::SmallSetVector<Value, 4> buffersWrittenInFirstPloop;
|
||||
|
||||
auto collectStoreOpsInWalk = [&](Operation *op) {
|
||||
auto memOpInterf = dyn_cast_if_present<MemoryEffectOpInterface>(op);
|
||||
// Ignore ops that don't write to memory
|
||||
if (!memOpInterf || (!memOpInterf.hasEffect<MemoryEffects::Write>() &&
|
||||
!memOpInterf.hasEffect<MemoryEffects::Free>()))
|
||||
return WalkResult::advance();
|
||||
|
||||
// Only these memory-writing ops are supported for now:
|
||||
// memref.store, vector.transfer_write, vector.store
|
||||
Value storeOpBase = getStoreOpTargetBuffer(op);
|
||||
if (!storeOpBase)
|
||||
return WalkResult::interrupt();
|
||||
|
||||
// Expect the base operand to be a Memref
|
||||
MemrefValue storeOpBaseMemref = dyn_cast<MemrefValue>(storeOpBase);
|
||||
if (!storeOpBaseMemref)
|
||||
return WalkResult::interrupt();
|
||||
// Get the original memref buffer, skipping full view-like ops
|
||||
Value buffer = memref::skipFullyAliasingOperations(storeOpBaseMemref);
|
||||
bufferStoresInFirstPloop[buffer].push_back(op);
|
||||
buffersWrittenInFirstPloop.insert(buffer);
|
||||
return WalkResult::advance();
|
||||
};
|
||||
|
||||
// Walk the first parallel loop to collect all store/write ops and their
|
||||
// target buffers
|
||||
if (firstPloop.getBody()->walk(collectStoreOpsInWalk).wasInterrupted())
|
||||
return false;
|
||||
|
||||
// Check that this load/read op encountered while walking the second parallel
|
||||
// loop does not have incompatible data dependencies with the store/write ops
|
||||
// collected from the first parallel loop: the loops can be fused only if in
|
||||
// the 2nd loop there are no loads/stores from/to the buffers written in the
|
||||
// 1st loop, except when on the same exact memory location (same indices) as
|
||||
// written in the 1st loop.
|
||||
auto checkLoadInWalkHasNoIncompatibleDataDeps = [&](Operation *loadOp) {
|
||||
auto memOpInterf = dyn_cast_if_present<MemoryEffectOpInterface>(loadOp);
|
||||
// To be conservative, we should stop on ops that don't advertise their
|
||||
// memory effects. However, many ops don't implement MemoryEffectOpInterface
|
||||
// yet, so for now we just skip them.
|
||||
// TODO: once more ops add MemoryEffectOpInterface, interrupt the walk here.
|
||||
if (!memOpInterf &&
|
||||
!loadOp->hasTrait<mlir::OpTrait::HasRecursiveMemoryEffects>())
|
||||
return WalkResult::advance();
|
||||
// Ignore ops that don't read from memory, and wrapping ops that have nested
|
||||
// memory effects (e.g. loops, conditionals) as they will be analyzed when
|
||||
// visiting their nested ops.
|
||||
if ((!memOpInterf &&
|
||||
loadOp->hasTrait<mlir::OpTrait::HasRecursiveMemoryEffects>()) ||
|
||||
(memOpInterf && !memOpInterf.hasEffect<MemoryEffects::Read>()))
|
||||
return WalkResult::advance();
|
||||
// Support only these memory-reading ops for now
|
||||
if (!isa<memref::LoadOp, vector::TransferReadOp, vector::LoadOp>(loadOp) ||
|
||||
!isa<MemrefValue>(loadOp->getOperand(0)))
|
||||
return WalkResult::interrupt();
|
||||
|
||||
MemrefValue loadOpBase = cast<MemrefValue>(loadOp->getOperand(0));
|
||||
MemrefValue loadedOrigBuf = memref::skipFullyAliasingOperations(loadOpBase);
|
||||
|
||||
for (Value storedMem : buffersWrittenInFirstPloop)
|
||||
if ((storedMem != loadedOrigBuf) && mayAlias(storedMem, loadedOrigBuf) &&
|
||||
!llvm::all_of(bufferStoresInFirstPloop[storedMem],
|
||||
[&](Operation *storeOp) {
|
||||
return canResolveAlias(loadOp, storeOp,
|
||||
firstToSecondPloopIndices);
|
||||
})) {
|
||||
return WalkResult::interrupt();
|
||||
}
|
||||
|
||||
auto writeOpsIt = bufferStoresInFirstPloop.find(loadedOrigBuf);
|
||||
if (writeOpsIt == bufferStoresInFirstPloop.end())
|
||||
return WalkResult::advance();
|
||||
// Store/write ops to this buffer in the firstPloop
|
||||
SmallVector<mlir::Operation *> &writeOps = writeOpsIt->second;
|
||||
|
||||
// If the first loop has no writes to this buffer, continue
|
||||
if (writeOps.empty())
|
||||
return WalkResult::advance();
|
||||
|
||||
Operation *writeOp = writeOps.front();
|
||||
|
||||
// In the first parallel loop, multiple writes to the same memref are
|
||||
// allowed only on the same memory location
|
||||
if (!llvm::all_of(writeOps, [&](Operation *otherWriteOp) {
|
||||
return opsWriteSameMemLocation(writeOp, otherWriteOp);
|
||||
})) {
|
||||
return WalkResult::interrupt();
|
||||
}
|
||||
|
||||
// Check that the load in secondPloop reads from the same memory location as
|
||||
// written by the corresponding store in firstPloop
|
||||
if (!loadsFromSameMemoryLocationWrittenBy(loadOp, writeOp,
|
||||
firstToSecondPloopIndices, b)) {
|
||||
return WalkResult::interrupt();
|
||||
}
|
||||
|
||||
return WalkResult::advance();
|
||||
};
|
||||
|
||||
// Walk the second parallel loop to check load/read ops against the stores
|
||||
// collected from the first parallel loop.
|
||||
return !secondPloop.getBody()
|
||||
->walk(checkLoadInWalkHasNoIncompatibleDataDeps)
|
||||
.wasInterrupted();
|
||||
}
|
||||
|
||||
/// Check that in each loop there are no read ops on the buffers written
|
||||
/// by the other loop, except when reading from the same exact memory location
|
||||
/// (same indices) as written in the other loop.
|
||||
static bool
|
||||
noIncompatibleDataDependencies(ParallelOp firstPloop, ParallelOp secondPloop,
|
||||
const IRMapping &firstToSecondPloopIndices,
|
||||
llvm::function_ref<bool(Value, Value)> mayAlias,
|
||||
OpBuilder &b) {
|
||||
if (!haveNoDataDependenciesExceptSameIndex(
|
||||
firstPloop, secondPloop, firstToSecondPloopIndices, mayAlias, b))
|
||||
return false;
|
||||
|
||||
IRMapping secondToFirstPloopIndices;
|
||||
secondToFirstPloopIndices.map(secondPloop.getBody()->getArguments(),
|
||||
firstPloop.getBody()->getArguments());
|
||||
return success(haveNoReadsAfterWriteExceptSameIndex(
|
||||
secondPloop, firstPloop, secondToFirstPloopIndices, mayAlias));
|
||||
return haveNoDataDependenciesExceptSameIndex(
|
||||
secondPloop, firstPloop, secondToFirstPloopIndices, mayAlias, b);
|
||||
}
|
||||
|
||||
/// Check if fusion of the two parallel loops is legal:
|
||||
/// i.e. no nested parallel loops, equal iteration spaces,
|
||||
/// and no incompatible data dependencies between the loops.
|
||||
static bool isFusionLegal(ParallelOp firstPloop, ParallelOp secondPloop,
|
||||
const IRMapping &firstToSecondPloopIndices,
|
||||
llvm::function_ref<bool(Value, Value)> mayAlias) {
|
||||
llvm::function_ref<bool(Value, Value)> mayAlias,
|
||||
OpBuilder &b) {
|
||||
return !hasNestedParallelOp(firstPloop) &&
|
||||
!hasNestedParallelOp(secondPloop) &&
|
||||
equalIterationSpaces(firstPloop, secondPloop) &&
|
||||
succeeded(verifyDependencies(firstPloop, secondPloop,
|
||||
firstToSecondPloopIndices, mayAlias));
|
||||
noIncompatibleDataDependencies(firstPloop, secondPloop,
|
||||
firstToSecondPloopIndices, mayAlias, b);
|
||||
}
|
||||
|
||||
/// Prepends operations of firstPloop's body into secondPloop's body.
|
||||
/// Updates secondPloop with new loop.
|
||||
/// Prepend operations of firstPloop's body into secondPloop's body.
|
||||
/// Update secondPloop with new loop.
|
||||
static void fuseIfLegal(ParallelOp firstPloop, ParallelOp &secondPloop,
|
||||
OpBuilder builder,
|
||||
llvm::function_ref<bool(Value, Value)> mayAlias) {
|
||||
@ -172,7 +744,7 @@ static void fuseIfLegal(ParallelOp firstPloop, ParallelOp &secondPloop,
|
||||
firstToSecondPloopIndices.map(block1->getArguments(), block2->getArguments());
|
||||
|
||||
if (!isFusionLegal(firstPloop, secondPloop, firstToSecondPloopIndices,
|
||||
mayAlias))
|
||||
mayAlias, builder))
|
||||
return;
|
||||
|
||||
DominanceInfo dom;
|
||||
@ -272,6 +844,18 @@ struct ParallelLoopFusion
|
||||
auto &aa = getAnalysis<AliasAnalysis>();
|
||||
|
||||
auto mayAlias = [&](Value val1, Value val2) -> bool {
|
||||
// If the memref is defined in one of the parallel loops body, careful
|
||||
// alias analysis is needed.
|
||||
// TODO: check if this is still needed as a separate check.
|
||||
auto val1Def = val1.getDefiningOp();
|
||||
auto val2Def = val2.getDefiningOp();
|
||||
auto val1Loop =
|
||||
val1Def ? val1Def->getParentOfType<ParallelOp>() : nullptr;
|
||||
auto val2Loop =
|
||||
val2Def ? val2Def->getParentOfType<ParallelOp>() : nullptr;
|
||||
if (val1Loop != val2Loop)
|
||||
return true;
|
||||
|
||||
return !aa.alias(val1, val2).isNo();
|
||||
};
|
||||
|
||||
|
||||
@ -1560,6 +1560,25 @@ bool mlir::isPerfectlyNestedForLoops(
|
||||
return true;
|
||||
}
|
||||
|
||||
llvm::SmallVector<std::tuple<int64_t, int64_t, int64_t>>
|
||||
mlir::getConstLoopBounds(mlir::LoopLikeOpInterface loopOp) {
|
||||
std::optional<SmallVector<OpFoldResult>> loBnds = loopOp.getLoopLowerBounds();
|
||||
std::optional<SmallVector<OpFoldResult>> upBnds = loopOp.getLoopUpperBounds();
|
||||
std::optional<SmallVector<OpFoldResult>> steps = loopOp.getLoopSteps();
|
||||
if (!loBnds || !upBnds || !steps)
|
||||
return {};
|
||||
llvm::SmallVector<std::tuple<int64_t, int64_t, int64_t>> loopRanges;
|
||||
for (auto [lb, ub, step] : llvm::zip(*loBnds, *upBnds, *steps)) {
|
||||
auto lbCst = getConstantIntValue(lb);
|
||||
auto ubCst = getConstantIntValue(ub);
|
||||
auto stepCst = getConstantIntValue(step);
|
||||
if (!lbCst || !ubCst || !stepCst)
|
||||
return {};
|
||||
loopRanges.emplace_back(*lbCst, *ubCst, *stepCst);
|
||||
}
|
||||
return loopRanges;
|
||||
}
|
||||
|
||||
llvm::SmallVector<llvm::APInt>
|
||||
mlir::getConstLoopTripCounts(mlir::LoopLikeOpInterface loopOp) {
|
||||
std::optional<SmallVector<OpFoldResult>> loBnds = loopOp.getLoopLowerBounds();
|
||||
|
||||
@ -314,23 +314,24 @@ func.func @do_not_fuse_unmatching_read_write_patterns(
|
||||
|
||||
// -----
|
||||
|
||||
func.func @do_not_fuse_loops_with_memref_defined_in_loop_bodies() {
|
||||
func.func @do_not_fuse_loops_with_nonfull_alias_defined_in_loop_bodies() {
|
||||
%c2 = arith.constant 2 : index
|
||||
%c0 = arith.constant 0 : index
|
||||
%c1 = arith.constant 1 : index
|
||||
%c1fp = arith.constant 1.0 : f32
|
||||
%buffer = memref.alloc() : memref<2x2xf32>
|
||||
scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
|
||||
scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c1) step (%c1, %c1) {
|
||||
memref.store %c1fp, %buffer[%i, %j] : memref<2x2xf32>
|
||||
scf.reduce
|
||||
}
|
||||
scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
|
||||
%A = memref.subview %buffer[%c0, %c0][%c2, %c2][%c1, %c1]
|
||||
: memref<2x2xf32> to memref<?x?xf32, strided<[?, ?], offset: ?>>
|
||||
%A_elem = memref.load %A[%i, %j] : memref<?x?xf32, strided<[?, ?], offset: ?>>
|
||||
scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c1) step (%c1, %c1) {
|
||||
%A = memref.subview %buffer[%i, %c0][2, 1][1, 1] : memref<2x2xf32> to memref<2x1xf32, strided<[2, 1], offset: ?>>
|
||||
%A_elem = memref.load %A[%i, %j] : memref<2x1xf32, strided<[2, 1], offset: ?>>
|
||||
scf.reduce
|
||||
}
|
||||
return
|
||||
}
|
||||
// CHECK-LABEL: func @do_not_fuse_loops_with_memref_defined_in_loop_bodies
|
||||
// CHECK-LABEL: func @do_not_fuse_loops_with_nonfull_alias_defined_in_loop_bodies
|
||||
// CHECK: scf.parallel
|
||||
// CHECK: scf.parallel
|
||||
|
||||
@ -604,6 +605,415 @@ func.func @do_not_fuse_affine_apply_to_non_ind_var(
|
||||
|
||||
// -----
|
||||
|
||||
func.func @fuse_trivial_rank_reducing_subview() {
|
||||
%c0 = arith.constant 0 : index
|
||||
%c1 = arith.constant 1 : index
|
||||
%c2 = arith.constant 2 : index
|
||||
%c1fp = arith.constant 1.0 : f32
|
||||
%buf = memref.alloc() : memref<1x2x2xf32>
|
||||
scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
|
||||
memref.store %c1fp, %buf[%c0, %i, %j] : memref<1x2x2xf32>
|
||||
scf.reduce
|
||||
}
|
||||
%sub = memref.subview %buf[0, 0, 0][1, 2, 2][1, 1, 1]
|
||||
: memref<1x2x2xf32> to memref<2x2xf32>
|
||||
scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
|
||||
%v = memref.load %sub[%i, %j] : memref<2x2xf32>
|
||||
memref.store %v, %buf[%c0, %i, %j] : memref<1x2x2xf32>
|
||||
scf.reduce
|
||||
}
|
||||
memref.dealloc %buf : memref<1x2x2xf32>
|
||||
return
|
||||
}
|
||||
// CHECK-LABEL: func @fuse_trivial_rank_reducing_subview
|
||||
// CHECK: %[[BUF:.*]] = memref.alloc() : memref<1x2x2xf32>
|
||||
// CHECK: %[[SUB:.*]] = memref.subview %[[BUF]]
|
||||
// CHECK: scf.parallel
|
||||
// CHECK: memref.store {{.*}}, %[[BUF]]
|
||||
// CHECK: %[[L:.*]] = memref.load %[[SUB]]
|
||||
// CHECK: memref.store %[[L]], %[[BUF]]
|
||||
// CHECK-NOT: scf.parallel
|
||||
// CHECK: memref.dealloc %[[BUF]] : memref<1x2x2xf32>
|
||||
|
||||
// -----
|
||||
|
||||
func.func @do_not_fuse_nontrivial_subview_offset() {
|
||||
%c0 = arith.constant 0 : index
|
||||
%c1 = arith.constant 1 : index
|
||||
%c2 = arith.constant 2 : index
|
||||
%c1fp = arith.constant 1.0 : f32
|
||||
%buf = memref.alloc() : memref<2x2x2xf32>
|
||||
scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
|
||||
memref.store %c1fp, %buf[%c0, %i, %j] : memref<2x2x2xf32>
|
||||
scf.reduce
|
||||
}
|
||||
%sub = memref.subview %buf[1, 0, 0][1, 2, 2][1, 1, 1]
|
||||
: memref<2x2x2xf32> to memref<2x2xf32, strided<[2, 1], offset: 4>>
|
||||
scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
|
||||
%v = memref.load %sub[%i, %j]
|
||||
: memref<2x2xf32, strided<[2, 1], offset: 4>>
|
||||
memref.store %v, %buf[%c0, %i, %j] : memref<2x2x2xf32>
|
||||
scf.reduce
|
||||
}
|
||||
memref.dealloc %buf : memref<2x2x2xf32>
|
||||
return
|
||||
}
|
||||
// CHECK-LABEL: func @do_not_fuse_nontrivial_subview_offset
|
||||
// CHECK: scf.parallel
|
||||
// CHECK: scf.parallel
|
||||
|
||||
// -----
|
||||
|
||||
func.func @fuse_vector_load_store(%A: memref<4x4xf32>) {
|
||||
%c0 = arith.constant 0 : index
|
||||
%c1 = arith.constant 1 : index
|
||||
%c4 = arith.constant 4 : index
|
||||
%vec0 = arith.constant dense<0.0> : vector<4xf32>
|
||||
scf.parallel (%i) = (%c0) to (%c4) step (%c1) {
|
||||
vector.store %vec0, %A[%i, %c0] : memref<4x4xf32>, vector<4xf32>
|
||||
scf.reduce
|
||||
}
|
||||
scf.parallel (%i) = (%c0) to (%c4) step (%c1) {
|
||||
%v = vector.load %A[%i, %c0] : memref<4x4xf32>, vector<4xf32>
|
||||
vector.store %v, %A[%i, %c0] : memref<4x4xf32>, vector<4xf32>
|
||||
scf.reduce
|
||||
}
|
||||
return
|
||||
}
|
||||
// CHECK-LABEL: func @fuse_vector_load_store
|
||||
// CHECK: scf.parallel (%[[I:.*]]) = (%{{.*}}) to (%{{.*}}) step (%{{.*}}) {
|
||||
// CHECK: vector.store
|
||||
// CHECK: %[[V:.*]] = vector.load
|
||||
// CHECK: vector.store %[[V]]
|
||||
// CHECK-NOT: scf.parallel
|
||||
|
||||
// -----
|
||||
|
||||
func.func @do_not_fuse_vector_different_indices(%A: memref<4x4xf32>) {
|
||||
%c0 = arith.constant 0 : index
|
||||
%c1 = arith.constant 1 : index
|
||||
%c4 = arith.constant 4 : index
|
||||
%vec0 = arith.constant dense<0.0> : vector<4xf32>
|
||||
scf.parallel (%i) = (%c0) to (%c4) step (%c1) {
|
||||
vector.store %vec0, %A[%i, %c0] : memref<4x4xf32>, vector<4xf32>
|
||||
scf.reduce
|
||||
}
|
||||
scf.parallel (%i) = (%c0) to (%c4) step (%c1) {
|
||||
%j = affine.apply affine_map<(d0) -> (d0 + 1)>(%i)
|
||||
%v = vector.load %A[%j, %c0] : memref<4x4xf32>, vector<4xf32>
|
||||
vector.store %v, %A[%i, %c0] : memref<4x4xf32>, vector<4xf32>
|
||||
scf.reduce
|
||||
}
|
||||
return
|
||||
}
|
||||
// CHECK-LABEL: func @do_not_fuse_vector_different_indices
|
||||
// CHECK: scf.parallel
|
||||
// CHECK: scf.parallel
|
||||
|
||||
// -----
|
||||
|
||||
func.func @fuse_vector_transfer_same_indices(%A: memref<4x4xf32>) {
|
||||
%c0 = arith.constant 0 : index
|
||||
%c1 = arith.constant 1 : index
|
||||
%c4 = arith.constant 4 : index
|
||||
%zero = arith.constant 0.0 : f32
|
||||
scf.parallel (%i) = (%c0) to (%c4) step (%c1) {
|
||||
%v = vector.transfer_read %A[%i, %c0], %zero {permutation_map = affine_map<(d0, d1) -> (d1)>, in_bounds = [true]} : memref<4x4xf32>, vector<4xf32>
|
||||
vector.transfer_write %v, %A[%i, %c0] {permutation_map = affine_map<(d0, d1) -> (d1)>, in_bounds = [true]} : vector<4xf32>, memref<4x4xf32>
|
||||
scf.reduce
|
||||
}
|
||||
scf.parallel (%i) = (%c0) to (%c4) step (%c1) {
|
||||
%v = vector.transfer_read %A[%i, %c0], %zero {permutation_map = affine_map<(d0, d1) -> (d1)>, in_bounds = [true]} : memref<4x4xf32>, vector<4xf32>
|
||||
vector.transfer_write %v, %A[%i, %c0] {permutation_map = affine_map<(d0, d1) -> (d1)>, in_bounds = [true]} : vector<4xf32>, memref<4x4xf32>
|
||||
scf.reduce
|
||||
}
|
||||
return
|
||||
}
|
||||
// CHECK-LABEL: func @fuse_vector_transfer_same_indices
|
||||
// CHECK: scf.parallel
|
||||
// CHECK: vector.transfer_read %{{.*}}[%{{.*}}, %{{.*}}]
|
||||
// CHECK: vector.transfer_write %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}]
|
||||
// CHECK: vector.transfer_read %{{.*}}[%{{.*}}, %{{.*}}]
|
||||
// CHECK: vector.transfer_write %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}]
|
||||
// CHECK-NOT: scf.parallel
|
||||
|
||||
// -----
|
||||
|
||||
func.func @do_not_fuse_vector_transfer_different_indices(%A: memref<4x4xf32>) {
|
||||
%c0 = arith.constant 0 : index
|
||||
%c1 = arith.constant 1 : index
|
||||
%c4 = arith.constant 4 : index
|
||||
%zero = arith.constant 0.0 : f32
|
||||
scf.parallel (%i) = (%c0) to (%c4) step (%c1) {
|
||||
%v = vector.transfer_read %A[%i, %c0], %zero {permutation_map = affine_map<(d0, d1) -> (d1)>, in_bounds = [true]} : memref<4x4xf32>, vector<4xf32>
|
||||
vector.transfer_write %v, %A[%i, %c0] {permutation_map = affine_map<(d0, d1) -> (d1)>, in_bounds = [true]} : vector<4xf32>, memref<4x4xf32>
|
||||
scf.reduce
|
||||
}
|
||||
scf.parallel (%i) = (%c0) to (%c4) step (%c1) {
|
||||
%j = affine.apply affine_map<(d0) -> (d0 + 1)>(%i)
|
||||
%v = vector.transfer_read %A[%j, %c0], %zero {permutation_map = affine_map<(d0, d1) -> (d1)>, in_bounds = [true]} : memref<4x4xf32>, vector<4xf32>
|
||||
vector.transfer_write %v, %A[%i, %c0] {permutation_map = affine_map<(d0, d1) -> (d1)>, in_bounds = [true]} : vector<4xf32>, memref<4x4xf32>
|
||||
scf.reduce
|
||||
}
|
||||
return
|
||||
}
|
||||
// CHECK-LABEL: func @do_not_fuse_vector_transfer_different_indices
|
||||
// CHECK: scf.parallel
|
||||
// CHECK: scf.parallel
|
||||
|
||||
// -----
|
||||
|
||||
func.func @fuse_vector_transfer_with_subview(%A: memref<1x4xf32>) {
|
||||
%c0 = arith.constant 0 : index
|
||||
%c1 = arith.constant 1 : index
|
||||
%c4 = arith.constant 4 : index
|
||||
%zero = arith.constant 0.0 : f32
|
||||
%vec = arith.constant dense<1.0> : vector<4xf32>
|
||||
scf.parallel (%i) = (%c0) to (%c1) step (%c1) {
|
||||
%sub = memref.subview %A[0, 0][1, 4][1, 1] : memref<1x4xf32> to memref<4xf32>
|
||||
vector.transfer_write %vec, %sub[%c0] {permutation_map = affine_map<(d0) -> (d0)>, in_bounds = [true]} : vector<4xf32>, memref<4xf32>
|
||||
scf.reduce
|
||||
}
|
||||
scf.parallel (%i) = (%c0) to (%c1) step (%c1) {
|
||||
%sum = scf.for %k = %c0 to %c4 step %c1 iter_args(%acc = %zero) -> f32 {
|
||||
%v = memref.load %A[%c0, %k] : memref<1x4xf32>
|
||||
%n = arith.addf %v, %acc : f32
|
||||
scf.yield %n : f32
|
||||
}
|
||||
memref.store %sum, %A[%c0, %c0] : memref<1x4xf32>
|
||||
scf.reduce
|
||||
}
|
||||
return
|
||||
}
|
||||
// CHECK-LABEL: func @fuse_vector_transfer_with_subview
|
||||
// CHECK: scf.parallel
|
||||
// CHECK: vector.transfer_write
|
||||
// CHECK: scf.for
|
||||
// CHECK-NOT: scf.parallel
|
||||
|
||||
// -----
|
||||
|
||||
func.func @do_not_fuse_vector_transfer_nontrivial_subview(%A: memref<2x4xf32>) {
|
||||
%c0 = arith.constant 0 : index
|
||||
%c1 = arith.constant 1 : index
|
||||
%zero = arith.constant 0.0 : f32
|
||||
scf.parallel (%i) = (%c0) to (%c1) step (%c1) {
|
||||
%v = vector.transfer_read %A[%c0, %i], %zero {permutation_map = affine_map<(d0, d1) -> (d1)>, in_bounds = [true]} : memref<2x4xf32>, vector<1xf32>
|
||||
vector.transfer_write %v, %A[%c0, %i] {permutation_map = affine_map<(d0, d1) -> (d1)>, in_bounds = [true]} : vector<1xf32>, memref<2x4xf32>
|
||||
scf.reduce
|
||||
}
|
||||
%sub = memref.subview %A[1, 0][1, 4][1, 1] : memref<2x4xf32> to memref<4xf32, strided<[1], offset: 4>>
|
||||
scf.parallel (%i) = (%c0) to (%c1) step (%c1) {
|
||||
%v = vector.transfer_read %sub[%i], %zero {in_bounds = [true]} : memref<4xf32, strided<[1], offset: 4>>, vector<1xf32>
|
||||
vector.transfer_write %v, %sub[%i] {in_bounds = [true]} : vector<1xf32>, memref<4xf32, strided<[1], offset: 4>>
|
||||
scf.reduce
|
||||
}
|
||||
return
|
||||
}
|
||||
// CHECK-LABEL: func @do_not_fuse_vector_transfer_nontrivial_subview
|
||||
// CHECK: scf.parallel
|
||||
// CHECK: scf.parallel
|
||||
|
||||
// -----
|
||||
|
||||
func.func @do_not_fuse_vector_transfer_different_masks(%A: memref<1x4xf32>) {
|
||||
%c0 = arith.constant 0 : index
|
||||
%c1 = arith.constant 1 : index
|
||||
%zero = arith.constant 0.0 : f32
|
||||
%mask_true = vector.create_mask %c1 : vector<1xi1>
|
||||
%mask_false = vector.create_mask %c0 : vector<1xi1>
|
||||
scf.parallel (%i) = (%c0) to (%c1) step (%c1) {
|
||||
%v = vector.transfer_read %A[%c0, %i], %zero, %mask_true {permutation_map = affine_map<(d0, d1) -> (d1)>, in_bounds = [true]} : memref<1x4xf32>, vector<1xf32>
|
||||
vector.transfer_write %v, %A[%c0, %i], %mask_true {permutation_map = affine_map<(d0, d1) -> (d1)>, in_bounds = [true]} : vector<1xf32>, memref<1x4xf32>
|
||||
scf.reduce
|
||||
}
|
||||
scf.parallel (%i) = (%c0) to (%c1) step (%c1) {
|
||||
%v = vector.transfer_read %A[%c0, %i], %zero, %mask_false {permutation_map = affine_map<(d0, d1) -> (d1)>, in_bounds = [true]} : memref<1x4xf32>, vector<1xf32>
|
||||
vector.transfer_write %v, %A[%c0, %i], %mask_false {permutation_map = affine_map<(d0, d1) -> (d1)>, in_bounds = [true]} : vector<1xf32>, memref<1x4xf32>
|
||||
scf.reduce
|
||||
}
|
||||
return
|
||||
}
|
||||
// CHECK-LABEL: func @do_not_fuse_vector_transfer_different_masks
|
||||
// CHECK: scf.parallel
|
||||
// CHECK: scf.parallel
|
||||
|
||||
// -----
|
||||
|
||||
func.func @fuse_vector_transfer_subview_rank_reducing(%A: memref<1x4xf32>, %B: memref<1x4xf32>) {
|
||||
%c0 = arith.constant 0 : index
|
||||
%c1 = arith.constant 1 : index
|
||||
%c4 = arith.constant 4 : index
|
||||
%zero = arith.constant 0.0 : f32
|
||||
%vec = arith.constant dense<1.0> : vector<4xf32>
|
||||
scf.parallel (%i) = (%c0) to (%c1) step (%c1) {
|
||||
%sub = memref.subview %A[%i, %c0][1, 4][1, 1] : memref<1x4xf32> to memref<4xf32, strided<[1], offset: ?>>
|
||||
vector.transfer_write %vec, %sub[%c0] {permutation_map = affine_map<(d0) -> (d0)>, in_bounds = [true]} : vector<4xf32>, memref<4xf32, strided<[1], offset: ?>>
|
||||
scf.reduce
|
||||
}
|
||||
scf.parallel (%i) = (%c0) to (%c1) step (%c1) {
|
||||
%sum = scf.for %k = %c0 to %c4 step %c1 iter_args(%acc = %zero) -> f32 {
|
||||
%v = memref.load %A[%i, %k] : memref<1x4xf32>
|
||||
%n = arith.addf %v, %acc : f32
|
||||
scf.yield %n : f32
|
||||
}
|
||||
memref.store %sum, %B[%i, %c0] : memref<1x4xf32>
|
||||
scf.reduce
|
||||
}
|
||||
return
|
||||
}
|
||||
// CHECK-LABEL: func @fuse_vector_transfer_subview_rank_reducing
|
||||
// CHECK: scf.parallel
|
||||
// CHECK: vector.transfer_write
|
||||
// CHECK: scf.for
|
||||
// CHECK-NOT: scf.parallel
|
||||
|
||||
// -----
|
||||
|
||||
func.func @do_not_fuse_vector_transfer_subview_offset(%A: memref<1x4xf32>, %B: memref<1x4xf32>) {
|
||||
%c0 = arith.constant 0 : index
|
||||
%c1 = arith.constant 1 : index
|
||||
%c4 = arith.constant 4 : index
|
||||
%zero = arith.constant 0.0 : f32
|
||||
%vec = arith.constant dense<1.0> : vector<4xf32>
|
||||
scf.parallel (%i) = (%c0) to (%c1) step (%c1) {
|
||||
%sub = memref.subview %A[%i, %c0][1, 4][1, 1] : memref<1x4xf32> to memref<4xf32, strided<[1], offset: ?>>
|
||||
vector.transfer_write %vec, %sub[%c0] {permutation_map = affine_map<(d0) -> (d0)>, in_bounds = [true]} : vector<4xf32>, memref<4xf32, strided<[1], offset: ?>>
|
||||
scf.reduce
|
||||
}
|
||||
scf.parallel (%i) = (%c0) to (%c1) step (%c1) {
|
||||
%sum = scf.for %k = %c0 to %c4 step %c1 iter_args(%acc = %zero) -> f32 {
|
||||
%v = memref.load %A[%i, %k] : memref<1x4xf32>
|
||||
%n = arith.addf %v, %acc : f32
|
||||
scf.yield %n : f32
|
||||
}
|
||||
// Read from an offset alias to prevent fusion.
|
||||
%off = memref.subview %A[%i, %c1][1, 3][1, 1] : memref<1x4xf32> to memref<3xf32, strided<[1], offset: ?>>
|
||||
%v0 = memref.load %off[%c0] : memref<3xf32, strided<[1], offset: ?>>
|
||||
%res = arith.addf %sum, %v0 : f32
|
||||
memref.store %res, %B[%i, %c0] : memref<1x4xf32>
|
||||
scf.reduce
|
||||
}
|
||||
return
|
||||
}
|
||||
// CHECK-LABEL: func @do_not_fuse_vector_transfer_subview_offset
|
||||
// CHECK: scf.parallel
|
||||
// CHECK: scf.parallel
|
||||
|
||||
// -----
|
||||
|
||||
func.func @fuse_vector_transfer_no_subview(%A: memref<1x4xf32>, %B: memref<1x4xf32>) {
|
||||
%c0 = arith.constant 0 : index
|
||||
%c1 = arith.constant 1 : index
|
||||
%c4 = arith.constant 4 : index
|
||||
%zero = arith.constant 0.0 : f32
|
||||
%vec = arith.constant dense<2.0> : vector<4xf32>
|
||||
scf.parallel (%i) = (%c0) to (%c1) step (%c1) {
|
||||
vector.transfer_write %vec, %A[%c0, %i] {permutation_map = affine_map<(d0, d1) -> (d1)>, in_bounds = [true]} : vector<4xf32>, memref<1x4xf32>
|
||||
scf.reduce
|
||||
}
|
||||
scf.parallel (%i) = (%c0) to (%c1) step (%c1) {
|
||||
%sum = scf.for %k = %c0 to %c4 step %c1 iter_args(%acc = %zero) -> f32 {
|
||||
%v = memref.load %A[%c0, %k] : memref<1x4xf32>
|
||||
%n = arith.addf %v, %acc : f32
|
||||
scf.yield %n : f32
|
||||
}
|
||||
memref.store %sum, %B[%c0, %c0] : memref<1x4xf32>
|
||||
scf.reduce
|
||||
}
|
||||
return
|
||||
}
|
||||
// CHECK-LABEL: func @fuse_vector_transfer_no_subview
|
||||
// CHECK: vector.transfer_write
|
||||
// CHECK: scf.for
|
||||
// CHECK-NOT: scf.parallel
|
||||
|
||||
// -----
|
||||
|
||||
func.func @fuse_vector_transfer_scalar_load_rank2(%A: memref<2x4xf32>, %B: memref<2x4xf32>) {
|
||||
%c0 = arith.constant 0 : index
|
||||
%c1 = arith.constant 1 : index
|
||||
%c2 = arith.constant 2 : index
|
||||
%vec = arith.constant dense<1.0> : vector<2x4xf32>
|
||||
scf.parallel (%i) = (%c0) to (%c1) step (%c1) {
|
||||
vector.transfer_write %vec, %A[%c0, %c0] {permutation_map = affine_map<(d0, d1) -> (d0, d1)>, in_bounds = [true, true]} : vector<2x4xf32>, memref<2x4xf32>
|
||||
scf.reduce
|
||||
}
|
||||
scf.parallel (%i) = (%c0) to (%c1) step (%c1) {
|
||||
%v0 = memref.load %A[%c0, %c1] : memref<2x4xf32>
|
||||
%v1 = memref.load %A[%c1, %c2] : memref<2x4xf32>
|
||||
%sum = arith.addf %v0, %v1 : f32
|
||||
memref.store %sum, %B[%c0, %c0] : memref<2x4xf32>
|
||||
scf.reduce
|
||||
}
|
||||
return
|
||||
}
|
||||
// CHECK-LABEL: func @fuse_vector_transfer_scalar_load_rank2
|
||||
// CHECK: scf.parallel
|
||||
// CHECK: vector.transfer_write
|
||||
// CHECK: memref.load
|
||||
// CHECK: memref.load
|
||||
// CHECK-NOT: scf.parallel
|
||||
|
||||
// -----
|
||||
|
||||
func.func @fuse_vector_transfer_scalar_load_loop_rank2(%A: memref<2x4xf32>, %B: memref<2x4xf32>) {
|
||||
%c0 = arith.constant 0 : index
|
||||
%c1 = arith.constant 1 : index
|
||||
%c4 = arith.constant 4 : index
|
||||
%zero = arith.constant 0.0 : f32
|
||||
%vec = arith.constant dense<2.0> : vector<2x4xf32>
|
||||
scf.parallel (%i) = (%c0) to (%c1) step (%c1) {
|
||||
vector.transfer_write %vec, %A[%c0, %c0] {permutation_map = affine_map<(d0, d1) -> (d0, d1)>, in_bounds = [true, true]} : vector<2x4xf32>, memref<2x4xf32>
|
||||
scf.reduce
|
||||
}
|
||||
scf.parallel (%i) = (%c0) to (%c1) step (%c1) {
|
||||
%sum = scf.for %k = %c0 to %c4 step %c1 iter_args(%acc = %zero) -> f32 {
|
||||
%v = memref.load %A[%c1, %k] : memref<2x4xf32>
|
||||
%n = arith.addf %v, %acc : f32
|
||||
scf.yield %n : f32
|
||||
}
|
||||
memref.store %sum, %B[%c0, %c0] : memref<2x4xf32>
|
||||
scf.reduce
|
||||
}
|
||||
return
|
||||
}
|
||||
// CHECK-LABEL: func @fuse_vector_transfer_scalar_load_loop_rank2
|
||||
// CHECK: scf.parallel
|
||||
// CHECK: vector.transfer_write
|
||||
// CHECK: scf.for
|
||||
// CHECK-NOT: scf.parallel
|
||||
|
||||
// -----
|
||||
|
||||
func.func @fuse_vector_store_scalar_load_rank2(%A: memref<2x4xf32>, %B: memref<2x4xf32>) {
|
||||
%c0 = arith.constant 0 : index
|
||||
%c1 = arith.constant 1 : index
|
||||
%c2 = arith.constant 2 : index
|
||||
%c3 = arith.constant 3 : index
|
||||
%vec = arith.constant dense<3.0> : vector<2x4xf32>
|
||||
scf.parallel (%i) = (%c0) to (%c1) step (%c1) {
|
||||
vector.store %vec, %A[%c0, %c0] : memref<2x4xf32>, vector<2x4xf32>
|
||||
scf.reduce
|
||||
}
|
||||
scf.parallel (%i) = (%c0) to (%c1) step (%c1) {
|
||||
%v0 = memref.load %A[%c1, %c2] : memref<2x4xf32>
|
||||
%v1 = memref.load %A[%c0, %c3] : memref<2x4xf32>
|
||||
%sum = arith.addf %v0, %v1 : f32
|
||||
memref.store %sum, %B[%c0, %c0] : memref<2x4xf32>
|
||||
scf.reduce
|
||||
}
|
||||
return
|
||||
}
|
||||
// CHECK-LABEL: func @fuse_vector_store_scalar_load_rank2
|
||||
// CHECK: scf.parallel
|
||||
// CHECK: vector.store
|
||||
// CHECK: memref.load
|
||||
// CHECK: memref.load
|
||||
// CHECK-NOT: scf.parallel
|
||||
|
||||
// -----
|
||||
|
||||
func.func @fuse_reductions_two(%A: memref<2x2xf32>, %B: memref<2x2xf32>) -> (f32, f32) {
|
||||
%c2 = arith.constant 2 : index
|
||||
%c0 = arith.constant 0 : index
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user