
The motivation is to avoid negating `isDynamic*` checks and the double negations that result, and to allow `ShapedType::isStaticDim` to be used in ADT functions without wrapping it in a lambda that performs the negation. The new functions are also added to the C and Python bindings.
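As a minimal sketch of the intent (the `allDimsStatic` helper and its `shape` parameter are hypothetical; only the `ShapedType` predicates come from this change):

```cpp
#include "mlir/IR/BuiltinTypes.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"

// Hypothetical call site: check that every dimension of a shape is static.
static bool allDimsStatic(llvm::ArrayRef<int64_t> shape) {
  // Before: negate the dynamic query, either directly or via a lambda.
  //   return llvm::none_of(shape, mlir::ShapedType::isDynamic);
  //   return llvm::all_of(shape, [](int64_t d) {
  //     return !mlir::ShapedType::isDynamic(d);
  //   });
  // After: the positive predicate composes directly with ADT helpers.
  return llvm::all_of(shape, mlir::ShapedType::isStatic);
}
```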
//===- SparseReinterpretMap.cpp - reinterpret sparse tensor maps ---------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "Utils/CodegenUtils.h"
#include "Utils/IterationGraphSorter.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/AffineExprVisitor.h"
#include "mlir/IR/AffineMap.h"

using namespace mlir;
using namespace mlir::sparse_tensor;

namespace {

//===----------------------------------------------------------------------===//
// File Local Helper classes.
//===----------------------------------------------------------------------===//

// CRTP helper to implement a rewriter that demaps all its inputs.
template <typename SubClass, typename SourceOp>
struct DemapInsRewriter : public OpRewritePattern<SourceOp> {
  using OpRewritePattern<SourceOp>::OpRewritePattern;
  using OpAdaptor = typename SourceOp::Adaptor;

  LogicalResult matchAndRewrite(SourceOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();

    // Demaps non-trivial inputs.
    bool changed = false;
    SmallVector<Value> deMappedIns(op->getOperands());
    for (Value &in : deMappedIns) {
      if (auto stt = tryGetSparseTensorType(in); stt && !stt->isIdentity()) {
        in = rewriter.create<ReinterpretMapOp>(loc, stt->getDemappedType(), in);
        changed = true;
      }
    }

    // CRTP call.
    OpAdaptor adaptor(deMappedIns, op);
    LogicalResult status =
        static_cast<const SubClass *>(this)->rewriteOp(op, adaptor, rewriter);
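    // Demapping the inputs already modified the IR, so report success in
    // that case even when the subclass rewrite did not apply.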
    return changed ? success() : status;
  }
};

// Collects the AffineDimExprs used in an affine expression.
struct AffineDimCollector : public AffineExprVisitor<AffineDimCollector> {
  explicit AffineDimCollector(unsigned dimNum) : dims(dimNum) {}
  void visitDimExpr(AffineDimExpr expr) { dims.set(expr.getPosition()); }
  BitVector dims;
};

// Determines whether an affine expression is admissible for sparsification.
struct AffineExprAdmissibleVisitor
    : public AffineExprVisitor<AffineExprAdmissibleVisitor> {
  explicit AffineExprAdmissibleVisitor(bool isOutput)
      : admissible(true), isOutput(isOutput) {}

  // We only allow AffineDimExpr on output.
  void visitAddExpr(AffineBinaryOpExpr expr) {
    if (isOutput)
      admissible = false;
  }
  void visitMulExpr(AffineBinaryOpExpr expr) {
    if (isOutput)
      admissible = false;
  }

  // We disallow mod, floor div and ceil div on inputs.
  void visitModExpr(AffineBinaryOpExpr expr) { admissible = false; }
  void visitFloorDivExpr(AffineBinaryOpExpr expr) { admissible = false; }
  void visitCeilDivExpr(AffineBinaryOpExpr expr) { admissible = false; }
  operator bool() { return admissible; }

private:
  bool admissible;
  bool isOutput;
};

// The first BitVector stores levels where inadmissible exprs are used.
// The second BitVector stores the AffineDimExprs that are used by the
// inadmissible expressions.
using InadmissInfo = std::pair<BitVector, BitVector>;

} // namespace

//===----------------------------------------------------------------------===//
// File Local Helper methods.
//===----------------------------------------------------------------------===//

// Collects the inadmissible affine expressions imposed on levels.
static InadmissInfo collectInadmissInfo(AffineMap map, bool isOutput) {
  auto ret = std::make_pair(BitVector(map.getNumResults()),
                            BitVector(map.getNumDims()));
  AffineDimCollector collector(map.getNumDims());
  for (unsigned lvl = 0, e = map.getNumResults(); lvl < e; lvl++) {
    AffineExprAdmissibleVisitor admissible(isOutput);
    admissible.walkPostOrder(map.getResult(lvl));
    if (!admissible) {
      // Record the inadmissible level.
      ret.first.set(lvl);
      // Record the AffineDimExprs that are used in the inadmissible expr.
      collector.walkPostOrder(map.getResult(lvl));
    }
  }
  ret.second = collector.dims;
  return ret;
}

// Builds the AffineMap to replace the dims in idxMap with lvls such that all
// the inadmissible affine expressions can be eliminated.
// For example, we can rewrite
//   idxMap = (d0, d1) -> (d0 floordiv 2, d1 floordiv 3, d0 mod 2, d1 mod 3)
// to
//   idxMap = (l0, l1, l2, l3) -> (l0, l1, l2, l3)
// by composing inverse(idxMap), that is
//   inverse(idxMap) . idxMap = (l0, l1, l2, l3) -> (l0 * 2 + l2, l1 * 3 + l3)
//                           -> ((l0 * 2 + l2) floordiv 2,
//                               (l1 * 3 + l3) floordiv 3,
//                               (l0 * 2 + l2) mod 2,
//                               (l1 * 3 + l3) mod 3) = (l0, l1, l2, l3)
//
// This function builds the inverse(idxMap) that replaces every dimension used
// in `info` with a level, and updates the iterator type array `itTps` for the
// new index variables introduced.
//
// Note that the returned affine map does not retain the order of the input
// affine map. Instead, it always uses the first `info.inAdLvls.count()` dims
// for the replaced levels, and the remaining ones for the unreplaced
// dimensions. For example, to handle
//   idxMap = (d0, d1) -> (d0, d1 floordiv 4, d1 mod 4)
// which is a typical map for block_2to4, the function returns:
//   inverse(idxMap) = (l0, l1, d0) -> (d0, l0 * 4 + l1)
// in which (l0, l1) together replace `d1`, yet they appear
// before `d0` in the resulting affine map.
// The index (loop) order can later be canonicalized by a topological sort.
static AffineMap
genReplaceDimToLvlMap(const InadmissInfo &info, AffineMap idxMap,
                      SmallVector<utils::IteratorType> &itTps) {
  MLIRContext *ctx = idxMap.getContext();
  auto [inAdLvls, usedDims] = info;
  // Note that idxMap is not equal to the dim2Lvl map: it is computed by
  // composing idx2Dim(dim2Lvl). They are equal only when idx2Dim is an
  // identity map.
  // TODO: we might fail here, in those cases we should really return
  // failure instead of an assertion error.
  auto lvl2Idx = inferLvlToDim(idxMap, ctx);

  assert(lvl2Idx.getNumResults() <= idxMap.getNumDims());
  if (lvl2Idx.getNumResults() != idxMap.getNumDims()) {
    // This could happen when some dimensions are projected.
    // E.g., idx2Lvl = (*i*, j, k) -> (j, k)
    //   ==> lvl2Idx = (j, k) -> (j, k)
    // In this case, we append the unused dimensions at the end.
    //   ==> lvl2Idx = (j, k, *i*) -> (*i*, j, k)
    SmallVector<AffineExpr> results;
    AffineDimCollector usedInLvl(idxMap.getNumDims());
    for (auto e : idxMap.getResults())
      usedInLvl.walkPostOrder(e);

    unsigned curUsedDimID = 0;
    unsigned curUnusedDimID = lvl2Idx.getNumDims();

    BitVector unused = usedInLvl.dims.flip();
    for (unsigned i = 0; i < idxMap.getNumDims(); i++) {
      if (unused.test(i))
        results.push_back(getAffineDimExpr(curUnusedDimID++, ctx));
      else
        results.push_back(lvl2Idx.getResult(curUsedDimID++));
    }
    lvl2Idx =
        AffineMap::get(lvl2Idx.getNumDims() + unused.count(), 0, results, ctx);
  }
  assert(lvl2Idx.getNumResults() == idxMap.getNumDims());

  // We do not need to replace the DimExprs that are not used in inadmissible
  // level expressions. We use the first inAdLvls.count() dims to represent
  // the replaced levels; the remaining ones are reserved for unchanged
  // dimensions. Note that the results from the inverse map computed
  // previously do not follow this convention, and we need to fix the
  // mismatch below.
  unsigned curRepID = 0;
  unsigned curOriID = inAdLvls.count();
  SmallVector<AffineExpr> results;
  SmallVector<AffineExpr> dimRep(idxMap.getNumResults(), AffineExpr());
  SmallVector<utils::IteratorType> transItTps;

  for (unsigned l : inAdLvls.set_bits()) {
    // By our convention, the inadmissible level `l` always appears in the
    // leading part (accumulated by curRepID) of the affine map's parameter
    // list. Record the mapping so that we can replace all the uses of `l`
    // with the correct position after the translation.
    dimRep[l] = getAffineDimExpr(curRepID++, ctx);
    // A new index variable is introduced for the inadmissible level,
    // inheriting the iterator type. E.g., if l0 = d0 floordiv 2, the
    // iterator type of l0 equals the iterator type of d0.
    AffineExpr lvlExp = idxMap.getResult(l);
    AffineDimCollector collector(idxMap.getNumDims());
    collector.walkPostOrder(lvlExp);
    // We assume a level can only be derived from one dimension.
    assert(collector.dims.count() == 1);
    transItTps.push_back(itTps[collector.dims.find_first()]);
  }

  for (unsigned d = 0, e = idxMap.getNumDims(); d < e; d++) {
    if (usedDims.test(d)) {
      // The dimension is used in some of the inadmissible levels, and it
      // needs to be inverted. Get the inversion from the inverse map, and
      // fix the mismatch captured by the above loop.
      results.push_back(lvl2Idx.getResult(d).replaceDims(dimRep));
    } else {
      // The dimension is not used in any of the inadmissible levels, and it
      // does not need to be inverted. Fix the mismatch by mapping it to the
      // trailing part of the affine map (accumulated by curOriID).
      results.push_back(getAffineDimExpr(curOriID++, ctx));
      transItTps.push_back(itTps[d]);
    }
  }
  unsigned numDim = idxMap.getNumDims() - usedDims.count() + inAdLvls.count();
  // Update iterator type.
  itTps.assign(transItTps.begin(), transItTps.end());
  return AffineMap::get(numDim, 0, results, ctx);
}

// Translates the index map in the linalg::GenericOp from an idx->dim map to
// an idx->lvl map. Returns failure if the index map cannot be translated to
// an admissible form.
// Returns the translated index map array and the iterator type array.
static std::optional<std::pair<ArrayAttr, ArrayAttr>>
translateMap(linalg::GenericOp op, PatternRewriter &rewriter) {
  // idxMap is an idx2dim map before reinterpretation.
  MLIRContext *ctx = op.getContext();
  SmallVector<AffineMap> idxMapArray = op.getIndexingMapsArray();
  SmallVector<utils::IteratorType> itTps = op.getIteratorTypesArray();
  for (unsigned i = 0, e = idxMapArray.size(); i < e; i++) {
    Value tensor = op->getOpOperand(i).get();
    auto stt = tryGetSparseTensorType(tensor);
    if (stt && !stt->isIdentity()) {
      AffineMap dim2Lvl = stt->getDimToLvl();
      // By composing the idx2dim map with dim2lvl, we get an idx2lvl map.
      idxMapArray[i] = dim2Lvl.compose(idxMapArray[i]);
    }
  }

  // A naive way to handle common constant expressions that arise during
  // dim2lvl translation.
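  // The two rewrites below are sound because a level coordinate always lies
  // in [0, lvlSz) when the level size is static, hence lvl floordiv lvlSz
  // is always 0 and lvl mod lvlSz is the coordinate itself.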
  auto populateCstMapping = [ctx](DenseMap<AffineExpr, AffineExpr> &cstMapping,
                                  unsigned pos, int64_t lvlSz) {
    if (ShapedType::isStatic(lvlSz)) {
      auto c0 = getAffineConstantExpr(0, ctx);
      auto lvlExp = getAffineDimExpr(pos, ctx);
      auto szExp = getAffineConstantExpr(lvlSz, ctx);

      // lvl floordiv lvlSz = 0
      auto divExp =
          getAffineBinaryOpExpr(AffineExprKind::FloorDiv, lvlExp, szExp);
      cstMapping.try_emplace(divExp, c0);

      // lvl mod lvlSz = lvl
      auto modExp = getAffineBinaryOpExpr(AffineExprKind::Mod, lvlExp, szExp);
      cstMapping.try_emplace(modExp, lvlExp);
    }
  };

  unsigned boundedNum = 0;
  // A fixed-point algorithm.
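  // Each successful iteration resolves at least one inadmissible level and
  // strictly increases `boundedNum`; since dims below `boundedNum` are never
  // replaced again, the loop terminates.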
  bool changed = true;
  while (changed) {
    changed = false;
    for (OpOperand &operand : op->getOpOperands()) {
      auto stt = tryGetSparseTensorType(operand.get());
      // Skip dense operands.
      if (!stt || !stt->getEncoding())
        continue;

      unsigned tid = operand.getOperandNumber();
      bool isOutput = &operand == op.getDpsInitOperand(0);
      AffineMap idxMap = idxMapArray[tid];
      InadmissInfo inAdInfo = collectInadmissInfo(idxMap, isOutput);
      auto [inAdLvls, dimExprs] = inAdInfo;
      for (unsigned d : dimExprs.set_bits()) {
        // The first `boundedNum` dims used in the AffineMap are introduced
        // to resolve previous inadmissible expressions. We cannot replace
        // them, as that might bring back the inadmissible expressions.
        if (d < boundedNum)
          return std::nullopt;
      }

      if (inAdLvls.count() != 0) {
        // Naive constant propagation, should be sufficient to handle block
        // sparsity in our cases.
        SmallVector<int64_t> lvlShape = stt->getLvlShape();
        DenseMap<AffineExpr, AffineExpr> cstMapping;
        unsigned position = 0;
        for (unsigned lvl : inAdLvls.set_bits()) {
          int64_t lvlSz = lvlShape[lvl];
          populateCstMapping(cstMapping, position, lvlSz);
          position++;
        }

        AffineMap lvl2Idx = genReplaceDimToLvlMap(inAdInfo, idxMap, itTps);
        // Compose the lvl2Idx map with all the index maps to eliminate
        // inadmissible expressions.
        for (unsigned tid = 0, e = idxMapArray.size(); tid < e; tid++) {
          AffineMap transMap = idxMapArray[tid].compose(lvl2Idx);
          idxMapArray[tid] = transMap.replace(
              cstMapping, /*numResultDims=*/transMap.getNumDims(),
              /*numResultSyms=*/0);
        }
        changed = true;
        boundedNum += inAdLvls.count();
      }
    }
  }

  SmallVector<Attribute> iterAttr =
      llvm::map_to_vector(itTps, [ctx](auto itTp) -> Attribute {
        return linalg::IteratorTypeAttr::get(ctx, itTp);
      });

  return std::make_pair(rewriter.getAffineMapArrayAttr(idxMapArray),
                        rewriter.getArrayAttr(iterAttr));
}

// Generates a "de"mapping reinterpretation of the map.
static Value genDemap(OpBuilder &builder, SparseTensorEncodingAttr enc,
                      Value val) {
  return builder.create<ReinterpretMapOp>(val.getLoc(), enc.withoutDimToLvl(),
                                          val);
}

// Generates a "re"mapping reinterpretation of the map.
static Value genRemap(OpBuilder &builder, SparseTensorEncodingAttr enc,
                      Value val) {
  return builder.create<ReinterpretMapOp>(val.getLoc(), enc, val);
}

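// Reinterprets each value in `outs` whose type differs from the
// corresponding entry in `types`, leaving the other values untouched.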
static SmallVector<Value> remapValueRange(OpBuilder &rewriter, TypeRange types,
                                          ValueRange outs) {
  SmallVector<Value> ret(outs);
  assert(outs.size() == types.size());
  for (auto [r, t] : llvm::zip(ret, types))
    if (r.getType() != t)
      r = rewriter.create<ReinterpretMapOp>(r.getLoc(), t, r);
  return ret;
}

namespace {

//===----------------------------------------------------------------------===//
// Rewriting rules for linalg generic ops.
//===----------------------------------------------------------------------===//

/// Sparse rewriting rule for the generic `linalg` operation.
struct GenericOpReinterpretMap
    : public DemapInsRewriter<GenericOpReinterpretMap, linalg::GenericOp> {
public:
  using DemapInsRewriter::DemapInsRewriter;
  LogicalResult rewriteOp(linalg::GenericOp linalgOp, OpAdaptor adaptor,
                          PatternRewriter &rewriter) const {
    // Only rewrite single output operations with pure (sparse) tensor
    // semantics.
    if (linalgOp.getNumDpsInits() != 1 || !linalgOp.hasPureTensorSemantics() ||
        !hasAnySparseOperandOrResult(linalgOp) ||
        !hasAnyNonIdentityOperandsOrResults(linalgOp))
      return failure();

    // Try translating the index map.
    auto transMap = translateMap(linalgOp, rewriter);
    if (!transMap)
      return rewriter.notifyMatchFailure(
          linalgOp, "the sparse kernel cannot be sparsified.");

    // On success, update the linalg operands and maps in place.
    Value res = linalgOp.getResult(0);
    auto stt = tryGetSparseTensorType(res);
    auto [idxMap, itTp] = *transMap;

    rewriter.startOpModification(linalgOp);
    linalgOp.setIndexingMapsAttr(idxMap);
    linalgOp.setIteratorTypesAttr(itTp);
    // Use demapped arguments.
    linalgOp.getInputsMutable().assign(adaptor.getInputs());
    linalgOp.getDpsInitsMutable().assign(adaptor.getOutputs());
    res.setType(adaptor.getOutputs()[0].getType());
    rewriter.finalizeOpModification(linalgOp);

    rewriter.setInsertionPointAfter(linalgOp);
    if (stt && stt->hasEncoding()) {
      Value t = genRemap(rewriter, stt->getEncoding(), res);
      rewriter.replaceAllUsesExcept(res, t, t.getDefiningOp());
    }
    return success();
  }
};

struct GenericOpScheduler : public OpRewritePattern<linalg::GenericOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(linalg::GenericOp linalgOp,
                                PatternRewriter &rewriter) const override {
    if (linalgOp.getNumDpsInits() != 1 || !linalgOp.hasPureTensorSemantics() ||
        hasAnyNonIdentityOperandsOrResults(linalgOp) || // need demap first
        !hasAnySparseOperandOrResult(linalgOp)) {
      return failure();
    }

    const StringRef sorted = "sorted";
    if (linalgOp->hasAttr(sorted))
      return failure();

    auto scheduler = IterationGraphSorter::fromGenericOp(linalgOp);
    bool isAdmissible = false;
    AffineMap order;
    // A const list of all the masks that we use for iteration graph
    // computation, ordered from most strict to least strict.
    // Ideally (though not guaranteed), the earlier a constraint mask
    // can be satisfied, the faster the generated kernel will be.
    const auto allMasks = {SortMask::kIncludeAll, SortMask::kIncludeDense,
                           SortMask::kIncludeDenseInput,
                           SortMask::kIncludeDenseOutput,
                           SortMask::kSparseOnly};
    for (const SortMask mask : allMasks) {
      order = scheduler.sort(mask);
      if (order) {
        if (isAdmissibleOrder(linalgOp, order)) {
          isAdmissible = true;
          break;
        }
        // Else, try a set of less strict constraints.
      }
    }

    if (!order) {
      // Cycles detected.
      if (failed(resolveCycle(scheduler, linalgOp, rewriter))) {
        return rewriter.notifyMatchFailure(
            linalgOp, "the sparse kernel cannot be scheduled: loop detected.");
      }
      return success();
    }

    if (!isAdmissible) {
      return rewriter.notifyMatchFailure(
          linalgOp, "the sparse kernel cannot be scheduled.");
    }

    // Marks the GenericOp to avoid recursive matching.
    rewriter.modifyOpInPlace(linalgOp, [&]() {
      linalgOp->setAttr(sorted, rewriter.getBoolAttr(true));
    });

    // Already sorted.
    if (order.isIdentity())
      return success();

    assert(order.isPermutation());
    // `order` is the original loop -> sorted loop map.
    ArrayAttr preItTypes = linalgOp.getIteratorTypesAttr();
    SmallVector<Attribute> curItTypes;
    curItTypes.reserve(preItTypes.size());
    for (AffineExpr expr : order.getResults()) {
      unsigned loopID = llvm::cast<AffineDimExpr>(expr).getPosition();
      curItTypes.push_back(preItTypes[loopID]);
    }

    // Invert `order` to get the sorted loop -> original loop map.
    order = inversePermutation(order);
    SmallVector<AffineMap> idxMaps = linalgOp.getIndexingMapsArray();
    for (AffineMap &idxMap : idxMaps)
      idxMap = idxMap.compose(order); // sorted loop -> lvl map

    rewriter.startOpModification(linalgOp);
    linalgOp.setIndexingMapsAttr(rewriter.getAffineMapArrayAttr(idxMaps));
    linalgOp.setIteratorTypesAttr(rewriter.getArrayAttr(curItTypes));
    rewriter.finalizeOpModification(linalgOp);

    return success();
  }

private:
  /// Whether the loop order is admissible for sparsification.
  static bool isAdmissibleOrder(linalg::GenericOp linalgOp, AffineMap order) {
    if (!hasAnySparseResult(linalgOp))
      return true;

    OpOperand *lhs = linalgOp.getDpsInitOperand(0);
    unsigned nest = 0;
    const auto iteratorTypes = linalgOp.getIteratorTypesArray();
    for (const AffineExpr l : order.getResults()) {
      unsigned loopId = llvm::cast<AffineDimExpr>(l).getPosition();
      auto itTp =
          cast<linalg::IteratorTypeAttr>(linalgOp.getIteratorTypes()[loopId]);
      if (linalg::isReductionIterator(itTp.getValue()))
        break; // terminate at first reduction
      nest++;
    }
    // Determine admissible dynamic insertion situations:
    // (1) fully injective, since there are no reductions,
    // (2) admissible 1-d expansion in innermost dimension.
    return static_cast<int64_t>(nest) >= linalgOp.getRank(lhs) - 1;
  }

  // Last resort cycle resolution.
  static LogicalResult resolveCycle(IterationGraphSorter &scheduler,
                                    linalg::LinalgOp linalgOp,
                                    PatternRewriter &rewriter) {
    // Compute a topological sort while leaving out every sparse input tensor
    // in succession until an acyclic iteration graph results.
    for (OpOperand *t : linalgOp.getDpsInputOperands()) {
      Value tval = t->get();
      auto srcEnc = getSparseTensorEncoding(tval.getType());
      // The constraints introduced by compound index expressions are
      // complicated. Skip them.
      AffineMap idxMap = linalgOp.getMatchingIndexingMap(t);
      bool hasCompExpr = llvm::any_of(idxMap.getResults(), [](AffineExpr exp) {
        return !llvm::isa<AffineDimExpr>(exp);
      });
      if (!srcEnc || hasCompExpr)
        continue;

      // Try scheduling the loops without the constraints from `tval`.
      AffineMap order = scheduler.sort(SortMask::kSparseOnly, tval);
      if (!order) // still cyclic
        continue;

      // Found an input tensor that resolves the cycle by inserting a
      // conversion into a sparse tensor that adheres to the iteration
      // graph order.
      auto stt = getSparseTensorType(tval);
      assert(stt.isIdentity());
      order = inversePermutation(order);
      // sorted loop -> lvl map.
      idxMap = idxMap.compose(order);

      // Found a permutation such that the results in `idxMap` are sorted.
      // For example,
      //   (d0, d1, d2, d3) -> (d2, d1, d0)
      // where loops are scheduled in the order d0->d1->d2->d3. To resolve
      // the cycle, we find a permutation, perm(d2, d1, d0) -> (d0, d1, d2),
      // such that the transposed tensor's levels are visited in the same
      // order as the loop scheduling order.
      SmallVector<std::pair<unsigned, unsigned>> lvlSeq;
      for (AffineExpr expr : idxMap.getResults()) {
        unsigned lvl = llvm::cast<AffineDimExpr>(expr).getPosition();
        lvlSeq.push_back(std::make_pair(lvl, lvlSeq.size()));
      }
      llvm::sort(lvlSeq, llvm::less_first());
      SmallVector<unsigned> perm =
          llvm::to_vector(llvm::make_second_range(lvlSeq));
      auto dimToLvl = AffineMap::getPermutationMap(perm, linalgOp.getContext());
      // The result of the idxMap must be unsorted.
      assert(!dimToLvl.isIdentity());

      // Insert the transpose.
      rewriter.setInsertionPoint(linalgOp);
      RankedTensorType dstTp = stt.withDimToLvl(dimToLvl).getRankedTensorType();
      Value dst = rewriter.create<ConvertOp>(tval.getLoc(), dstTp, tval);
      rewriter.modifyOpInPlace(linalgOp, [&]() {
        linalgOp->setOperand(t->getOperandNumber(), dst);
      });

      // Release the transposed form afterwards.
      // TODO: CSE when used in more than one following op?
      rewriter.setInsertionPointAfter(linalgOp);
      rewriter.create<bufferization::DeallocTensorOp>(dst.getLoc(), dst);

      return success();
    }
    // Cannot be resolved with a single conversion.
    // TODO: convert more than one?
    return failure();
  }
};

//===----------------------------------------------------------------------===//
// Reinterpret Map Rewriters for operations other than linalg.generics
//===----------------------------------------------------------------------===//

template <typename AllocOp>
struct TensorAllocDemapper : public OpRewritePattern<AllocOp> {
  using OpRewritePattern<AllocOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(AllocOp op,
                                PatternRewriter &rewriter) const override {
    if (!hasAnyNonIdentityOperandsOrResults(op))
      return failure();

    Location loc = op.getLoc();
    auto stt = getSparseTensorType(op.getResult());

    SmallVector<Value> maxDimCrds;
    maxDimCrds.reserve(stt.getDimRank());
    ValueRange dynSz = op.getDynamicSizes();
    for (int64_t dimSz : stt.getDimShape()) {
      if (ShapedType::isDynamic(dimSz)) {
        Value maxCrd = rewriter.create<arith::SubIOp>(
            loc, dynSz.front(), constantIndex(rewriter, loc, 1));
        maxDimCrds.push_back(maxCrd);
        dynSz = dynSz.drop_front();
      } else {
        maxDimCrds.push_back(constantIndex(rewriter, loc, dimSz - 1));
      }
    }

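    // Translate the maximum in-bounds coordinates (rather than the sizes
    // themselves) through dim2lvl, then recover each dynamic level size as
    // its maximum level coordinate plus one.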
    ValueRange maxLvlCrds = stt.translateCrds(rewriter, loc, maxDimCrds,
                                              CrdTransDirectionKind::dim2lvl);
    auto lvlShape = stt.getLvlShape();
    SmallVector<Value> dynLvlSzs;
    for (unsigned i = 0, e = lvlShape.size(); i < e; i++) {
      if (ShapedType::isDynamic(lvlShape[i])) {
        Value sz = rewriter.create<arith::AddIOp>(
            loc, maxLvlCrds[i], constantIndex(rewriter, loc, 1));
        dynLvlSzs.push_back(sz);
      }
    }

    assert(dynSz.empty()); // should have consumed all.
    rewriter.startOpModification(op);
    op->setOperands(dynLvlSzs);
    op.getResult().setType(stt.getDemappedType());
    rewriter.finalizeOpModification(op);
    rewriter.setInsertionPointAfter(op);

    Value t = genRemap(rewriter, stt.getEncoding(), op.getResult());
    rewriter.replaceAllUsesExcept(op.getResult(), t, t.getDefiningOp());
    return success();
  }
};

struct TensorInsertDemapper
    : public DemapInsRewriter<TensorInsertDemapper, tensor::InsertOp> {
  using DemapInsRewriter::DemapInsRewriter;
  LogicalResult rewriteOp(tensor::InsertOp op, OpAdaptor adaptor,
                          PatternRewriter &rewriter) const {
    if (!hasAnySparseResult(op) || !hasAnyNonIdentityOperandsOrResults(op))
      return failure();

    Location loc = op.getLoc();
    auto stt = getSparseTensorType(op.getResult());
    ValueRange lvlCrd = stt.translateCrds(rewriter, loc, op.getIndices(),
                                          CrdTransDirectionKind::dim2lvl);
    auto insertOp = rewriter.create<tensor::InsertOp>(
        loc, op.getScalar(), adaptor.getDest(), lvlCrd);

    Value out = genRemap(rewriter, stt.getEncoding(), insertOp.getResult());
    rewriter.replaceOp(op, out);
    return success();
  }
};

struct SparseAssembleDemapper : public OpRewritePattern<AssembleOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(AssembleOp op,
                                PatternRewriter &rewriter) const override {
    if (!hasAnyNonIdentityOperandsOrResults(op))
      return failure();

    assert(hasAnySparseResult(op));
    auto stt = getSparseTensorType(op.getResult());
    rewriter.modifyOpInPlace(
        op, [&op, &stt]() { op.getResult().setType(stt.getDemappedType()); });
    rewriter.setInsertionPointAfter(op);
    Value out = genRemap(rewriter, stt.getEncoding(), op.getResult());
    rewriter.replaceAllUsesExcept(op, out, out.getDefiningOp());
    return success();
  }
};

struct SparseDisassembleDemapper
    : public DemapInsRewriter<SparseDisassembleDemapper, DisassembleOp> {
  using DemapInsRewriter::DemapInsRewriter;
  LogicalResult rewriteOp(DisassembleOp op, OpAdaptor adaptor,
                          PatternRewriter &rewriter) const {
    if (!hasAnyNonIdentityOperandsOrResults(op))
      return failure();

    assert(hasAnySparseOperandOrResult(op));
    rewriter.modifyOpInPlace(op, [&op, &adaptor]() {
      op.getTensorMutable().assign(adaptor.getTensor());
    });
    return success();
  }
};

struct ForeachOpDemapper
    : public DemapInsRewriter<ForeachOpDemapper, ForeachOp> {
  using DemapInsRewriter::DemapInsRewriter;
  LogicalResult rewriteOp(ForeachOp op, OpAdaptor adaptor,
                          PatternRewriter &rewriter) const {
    // Only handle operations with sparse input/output with non-identity
    // dim2lvl maps.
    if (!hasAnyNonIdentityOperandsOrResults(op))
      return failure();

    // TODO: demap constants as well.
    if (auto constOp = op.getTensor().getDefiningOp<arith::ConstantOp>())
      if (auto attr = dyn_cast<SparseElementsAttr>(constOp.getValue()))
        return failure();

    Location loc = op.getLoc();
    // Cache the type information since we update the foreach op in-place.
    auto srcStt = getSparseTensorType(op.getTensor());
    SmallVector<Type> prevRetTps(op.getResultTypes());

    rewriter.startOpModification(op);
    op.getTensorMutable().assign(adaptor.getTensor());
    op.getInitArgsMutable().assign(adaptor.getInitArgs());
    // Update the results' types.
    for (auto r : op.getResults())
      if (auto stt = tryGetSparseTensorType(r); stt && !stt->isIdentity())
        r.setType(stt->getDemappedType());

    Level lvlRank = getSparseTensorType(adaptor.getTensor()).getLvlRank();
    // Update the foreach body.
    SmallVector<Type> blockArgTps(lvlRank, rewriter.getIndexType());
    blockArgTps.push_back(srcStt.getElementType());
    blockArgTps.append(adaptor.getInitArgs().getTypes().begin(),
                       adaptor.getInitArgs().getTypes().end());
    Block *body = op.getBody();
    // Block Args: [dimCrds, val, initArgs]
    unsigned preArgNum = body->getNumArguments();
    for (Type t : blockArgTps)
      body->addArgument(t, loc);

    // Block Args: [dimCrds, val, initArgs, lvlCrds, val, DemappedArgs]
    rewriter.setInsertionPointToStart(body);
    ValueRange lvlCrds = body->getArguments().slice(preArgNum, lvlRank);

    ValueRange dimCrds = srcStt.translateCrds(rewriter, loc, lvlCrds,
                                              CrdTransDirectionKind::lvl2dim);
    rewriter.replaceAllUsesWith(
        body->getArguments().take_front(srcStt.getDimRank()), dimCrds);
    body->eraseArguments(0, srcStt.getDimRank());
    // Block Args: [val, initArgs, lvlCrds, val, DemappedArgs]
    unsigned numInitArgs = op.getInitArgs().size();
    rewriter.replaceAllUsesWith(body->getArgument(0),
                                body->getArgument(lvlRank + numInitArgs + 1));
    body->eraseArgument(0);
    // Block Args: [initArgs, lvlCrds, val, DemappedArgs]
    ValueRange srcArgs = body->getArguments().take_front(numInitArgs);
    ValueRange dstArgs = body->getArguments().take_back(numInitArgs);
    // Remap back before replacement.
    SmallVector<Value> reMappedArgs =
        remapValueRange(rewriter, srcArgs.getTypes(), dstArgs);
    rewriter.replaceAllUsesWith(srcArgs, reMappedArgs);
    body->eraseArguments(0, numInitArgs);
    // Block Args: [lvlCrds, val, DemappedArgs] and we are done.

    // Update the yield operations.
    if (numInitArgs != 0) {
      rewriter.setInsertionPointToEnd(body);
      auto yield = llvm::cast<YieldOp>(body->getTerminator());
      if (auto stt = tryGetSparseTensorType(yield.getSingleResult());
          stt && !stt->isIdentity()) {
        Value y =
            genDemap(rewriter, stt->getEncoding(), yield.getSingleResult());
        rewriter.create<YieldOp>(loc, y);
        rewriter.eraseOp(yield);
      }
    }
    rewriter.finalizeOpModification(op);

    rewriter.setInsertionPointAfter(op);
    SmallVector<Value> outs =
        remapValueRange(rewriter, prevRetTps, op.getResults());

    // Replace all the uses of the foreach results, except the use in the
    // reinterpret_map used to remap the output.
    for (auto [from, to] : llvm::zip(op.getResults(), outs))
      rewriter.replaceAllUsesExcept(from, to, to.getDefiningOp());

    return success();
  }
};

} // namespace

void mlir::populateSparseReinterpretMap(RewritePatternSet &patterns,
                                        ReinterpretMapScope scope) {
  if (scope == ReinterpretMapScope::kAll ||
      scope == ReinterpretMapScope::kGenericOnly) {
    patterns.add<GenericOpReinterpretMap, GenericOpScheduler>(
        patterns.getContext());
  }
  if (scope == ReinterpretMapScope::kAll ||
      scope == ReinterpretMapScope::kExceptGeneric) {
    patterns.add<TensorAllocDemapper<bufferization::AllocTensorOp>,
                 TensorAllocDemapper<tensor::EmptyOp>, SparseAssembleDemapper,
                 SparseDisassembleDemapper, TensorInsertDemapper,
                 ForeachOpDemapper>(patterns.getContext());
  }
}