Rename interface functions as follows: * `hasTensorSemantics` -> `hasPureTensorSemantics` * `hasBufferSemantics` -> `hasPureBufferSemantics`. `hasPureTensorSemantics` returns "true" if the op has tensor operands but no buffer operands; `hasPureBufferSemantics` returns "true" if the op has buffer operands but no tensor operands. Also drop the "ranked" part from the interface, i.e., do not distinguish between ranked and unranked types. The new function names describe the functions more accurately. They also align the semantics with the bufferization framework's notion of "tensor semantics". (An op is supposed to be bufferized if it has tensor operands, and we don't care if it also has memref operands.) This change is in preparation for #75273, which adds `BufferizableOpInterface::hasTensorSemantics`. By renaming the functions in the `DestinationStyleOpInterface`, we avoid name clashes between the two interfaces.
765 lines
30 KiB
C++
765 lines
30 KiB
C++
//===- SparseReinterpretMap.cpp - reinterpret sparse tensor maps ----------===/
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "Utils/CodegenUtils.h"
|
|
#include "Utils/IterationGraphSorter.h"
|
|
|
|
#include "mlir/Dialect/Affine/IR/AffineOps.h"
|
|
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
|
|
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
|
#include "mlir/Dialect/Linalg/Utils/Utils.h"
|
|
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
|
|
#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
|
|
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
|
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
|
#include "mlir/IR/AffineExprVisitor.h"
|
|
#include "mlir/IR/AffineMap.h"
|
|
|
|
using namespace mlir;
|
|
using namespace mlir::sparse_tensor;
|
|
|
|
namespace {
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// File Local Helper classes.
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// CRTP to help implementing a rewriter that demaps all its inputs.
|
|
template <typename SubClass, typename SourceOp>
|
|
struct DemapInsRewriter : public OpRewritePattern<SourceOp> {
|
|
using OpRewritePattern<SourceOp>::OpRewritePattern;
|
|
using OpAdaptor = typename SourceOp::Adaptor;
|
|
|
|
LogicalResult matchAndRewrite(SourceOp op,
|
|
PatternRewriter &rewriter) const override {
|
|
Location loc = op.getLoc();
|
|
|
|
// Demaps non-trivial inputs.
|
|
bool changed = false;
|
|
SmallVector<Value> deMappedIns(op->getOperands());
|
|
for (Value &in : deMappedIns) {
|
|
if (auto stt = tryGetSparseTensorType(in); stt && !stt->isIdentity()) {
|
|
in = rewriter.create<ReinterpretMapOp>(loc, stt->getDemappedType(), in);
|
|
changed = true;
|
|
}
|
|
}
|
|
|
|
// CRTP call.
|
|
OpAdaptor adaptor(deMappedIns, op);
|
|
LogicalResult status =
|
|
static_cast<const SubClass *>(this)->rewriteOp(op, adaptor, rewriter);
|
|
return changed ? success() : status;
|
|
}
|
|
};
|
|
|
|
// Flattens an affine expression into a list of AffineDimExprs.
|
|
struct AffineDimCollector : public AffineExprVisitor<AffineDimCollector> {
|
|
explicit AffineDimCollector(unsigned dimNum) : dims(dimNum){};
|
|
void visitDimExpr(AffineDimExpr expr) { dims.set(expr.getPosition()); }
|
|
BitVector dims;
|
|
};
|
|
|
|
// Flattens an affine expression into a list of AffineDimExprs.
|
|
struct AffineExprAdmissibleVisitor
|
|
: public AffineExprVisitor<AffineExprAdmissibleVisitor> {
|
|
explicit AffineExprAdmissibleVisitor(bool isOutput)
|
|
: admissible(true), isOutput(isOutput){};
|
|
|
|
// We only allow AffineDimExpr on output.
|
|
void visitAddExpr(AffineBinaryOpExpr expr) {
|
|
if (isOutput)
|
|
admissible = false;
|
|
}
|
|
void visitMulExpr(AffineBinaryOpExpr expr) {
|
|
if (isOutput)
|
|
admissible = false;
|
|
}
|
|
|
|
// We disallow mod, floor div and ceil div on inputs.
|
|
void visitModExpr(AffineBinaryOpExpr expr) { admissible = false; }
|
|
void visitFloorDivExpr(AffineBinaryOpExpr expr) { admissible = false; }
|
|
void visitCeilDivExpr(AffineBinaryOpExpr expr) { admissible = false; }
|
|
operator bool() { return admissible; }
|
|
|
|
private:
|
|
bool admissible;
|
|
bool isOutput;
|
|
};
|
|
|
|
// The first BitVector stores the levels where inadmissible exprs are used.
// The second BitVector stores the AffineDimExprs that are used by the
// inadmissible expressions.
|
|
using InadmissInfo = std::pair<BitVector, BitVector>;
|
|
|
|
} // namespace
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// File Local Helper methods.
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// Collects the inadmissible affine expression imposed on levels.
|
|
static InadmissInfo collectInadmissInfo(AffineMap map, bool isOutput) {
|
|
auto ret = std::make_pair(BitVector(map.getNumResults()),
|
|
BitVector(map.getNumDims()));
|
|
AffineDimCollector collector(map.getNumDims());
|
|
for (unsigned lvl = 0, e = map.getNumResults(); lvl < e; lvl++) {
|
|
AffineExprAdmissibleVisitor admissible(isOutput);
|
|
admissible.walkPostOrder(map.getResult(lvl));
|
|
if (!admissible) {
|
|
// Record the inadmissible level.
|
|
ret.first.set(lvl);
|
|
// Record the AffineDimExpr that is used in the inadmissible expr.
|
|
collector.walkPostOrder(map.getResult(lvl));
|
|
}
|
|
}
|
|
ret.second = collector.dims;
|
|
return ret;
|
|
}
|
|
|
|
// Builds the AffineMap to replace the idx in idxMap to lvl such that all tht
|
|
// inadmissible affine expressions can be eliminated.
|
|
// For example, we can rewrite
|
|
// idxMap = (d0, d1) -> (d0 floordiv 2, d1 floordiv 3, d0 mod 2, d1 mod 3)
|
|
// to
|
|
// idxMap = (l0, l1, l2, l3) -> (l0, l1, l2, l3)
|
|
// by composing inverse(idxMap), that is
|
|
// inverse(idxMap) . idxMap = (l0, l1, l2, l3) -> (l0 * 2 + l2, l1 * 3 + l3)
|
|
// -> ((l0 * 2 + l2) floordiv 2,
|
|
// (l1 * 3 + l3) floordiv 3,
|
|
// (l0 * 2 + l2) mod 2,
|
|
// (l1 * 3 + l3) mod 3) = (l0, l1, l2, l3)
|
|
//
|
|
// This function builds the inverse(idxMap) that replace every dimensions used
|
|
// in `info` to levels, and updates the iterator type array `itTps` for the new
|
|
// index variable introduced.
|
|
//
|
|
// Note that the returned affine map does not retain the order of the input
|
|
// affine map. Instead, it always uses the first `info.inAdlvls.count()` for the
|
|
// replaced levels, and remaining ones for unused dimensions.
|
|
// For example, to handle
|
|
// idxMap = (d0, d1) -> (d0, d1 floordiv 4, d2 mod 4)
|
|
// which is a typical map for block_2to4. The function returns:
|
|
// inverse(idxMap) = (l0, l1, d0) -> (d0, l0 * 4 + l1)
|
|
// in which, (l0, l1) together replaces `d1`, yet they appear
|
|
// before `d0` in the resulting affine map.
|
|
// The index (loop) order can later be canonicalized by a topo sort.
|
|
static AffineMap
|
|
genReplaceDimToLvlMap(const InadmissInfo &info, AffineMap idxMap,
|
|
SmallVector<utils::IteratorType> &itTps) {
|
|
MLIRContext *ctx = idxMap.getContext();
|
|
auto [inAdLvls, usedDims] = info;
|
|
// Note that idxMap does not equal to dim2Lvl map, it is computed by
|
|
// composing idx2Dim(dim2Lvl). They are only equal when idx2Dim is an
|
|
// ID map.
|
|
// TODO: we might fail here, in those case we should really return
|
|
// failure instead of assertion error.
|
|
auto lvl2Idx = inferLvlToDim(idxMap, ctx);
|
|
|
|
assert(lvl2Idx.getNumResults() <= idxMap.getNumDims());
|
|
if (lvl2Idx.getNumResults() != idxMap.getNumDims()) {
|
|
// This could happen when some dimensions are projected.
|
|
// E.g., idx2Lvl = (*i*, j, k) -> (j, k)
|
|
// ==> lvl2Idx = (j, k) -> (j, k)
|
|
// In this case, we append the unused dimesion at the end.
|
|
// ==> lvl2Idx = (j, k, *i*) -> (*i*, j, k)
|
|
SmallVector<AffineExpr> results;
|
|
AffineDimCollector usedInLvl(idxMap.getNumDims());
|
|
for (auto e : idxMap.getResults())
|
|
usedInLvl.walkPostOrder(e);
|
|
|
|
unsigned curUsedDimID = 0;
|
|
unsigned curUnusedDimID = lvl2Idx.getNumDims();
|
|
|
|
BitVector unused = usedInLvl.dims.flip();
|
|
for (unsigned i = 0; i < idxMap.getNumDims(); i++) {
|
|
if (unused.test(i))
|
|
results.push_back(getAffineDimExpr(curUnusedDimID++, ctx));
|
|
else
|
|
results.push_back(lvl2Idx.getResult(curUsedDimID++));
|
|
}
|
|
lvl2Idx =
|
|
AffineMap::get(lvl2Idx.getNumDims() + unused.count(), 0, results, ctx);
|
|
}
|
|
assert(lvl2Idx.getNumResults() == idxMap.getNumDims());
|
|
|
|
// We do not need to replace the DimExpr that is not used in inadmissible
|
|
// level expressions. We use the first inAdLvl.count() dim to represent the
|
|
// replaced level, the remainings are reserved for unchanged ones.
|
|
// Note that results from the inverse map computed previously does not follow
|
|
// the convention we used, and we need to fix the mismatch below.
|
|
unsigned curRepID = 0;
|
|
unsigned curOriID = inAdLvls.count();
|
|
SmallVector<AffineExpr> results;
|
|
SmallVector<AffineExpr> dimRep(idxMap.getNumResults(), AffineExpr());
|
|
SmallVector<utils::IteratorType> transItTps;
|
|
|
|
for (unsigned l : inAdLvls.set_bits()) {
|
|
// By our convention, the inadmissible level `l` always appears in the
|
|
// leading part (accumulated by curRepID) of the affine map's parameter
|
|
// list. Record the mapping so that we can replace all the uses of `l` to
|
|
// the correct position after the translation.
|
|
dimRep[l] = getAffineDimExpr(curRepID++, ctx);
|
|
// A new index variable is introduced for the inadmissible level, inherit
|
|
// the iterator type. E.g., if l0 = d0 floordiv 2, the
|
|
// iterator type of l0 equals to the iterator type of d0.
|
|
AffineExpr lvlExp = idxMap.getResult(l);
|
|
AffineDimCollector collector(idxMap.getNumDims());
|
|
collector.walkPostOrder(lvlExp);
|
|
// We assumes a level can only be derived from one dimension.
|
|
assert(collector.dims.count() == 1);
|
|
transItTps.push_back(itTps[collector.dims.find_first()]);
|
|
}
|
|
|
|
for (unsigned d = 0, e = idxMap.getNumDims(); d < e; d++) {
|
|
if (usedDims.test(d)) {
|
|
// The dimension is used in some of the inadmissible levels, and it need
|
|
// to be inversed. Get the inversion from the inverse map, and fix the
|
|
// mismatch captured by the above loop.
|
|
results.push_back(lvl2Idx.getResult(d).replaceDims(dimRep));
|
|
} else {
|
|
// The dimension is not used in any of the inadmissible levels, and it
|
|
// does not need to be inversed. Fix the mismatch by mapping it to the
|
|
// trailing part of the affine map (accumulated by curOriID).
|
|
results.push_back(getAffineDimExpr(curOriID++, ctx));
|
|
transItTps.push_back(itTps[d]);
|
|
}
|
|
}
|
|
unsigned numDim = idxMap.getNumDims() - usedDims.count() + inAdLvls.count();
|
|
// Update iterator type.
|
|
itTps.assign(transItTps.begin(), transItTps.end());
|
|
return AffineMap::get(numDim, 0, results, ctx);
|
|
}
|
|
|
|
// Translates the index map in the linalg::GenericOp from idx->dim map to
|
|
// idx->lvl map. Returns failure if the index map can not be translated to an
|
|
// admissible form.
|
|
// Returns the translated index map array and the iterator type array.
|
|
static std::optional<std::pair<ArrayAttr, ArrayAttr>>
|
|
translateMap(linalg::GenericOp op, PatternRewriter &rewriter) {
|
|
// idxMap is a idx2dim map before reinterpretation.
|
|
MLIRContext *ctx = op.getContext();
|
|
SmallVector<AffineMap> idxMapArray = op.getIndexingMapsArray();
|
|
SmallVector<utils::IteratorType> itTps = op.getIteratorTypesArray();
|
|
for (unsigned i = 0, e = idxMapArray.size(); i < e; i++) {
|
|
Value tensor = op->getOpOperand(i).get();
|
|
auto stt = tryGetSparseTensorType(tensor);
|
|
if (stt && !stt->isIdentity()) {
|
|
AffineMap dim2Lvl = stt->getDimToLvl();
|
|
// By composing the idx2dim(dim2lvl), we got a idx2lvl Map
|
|
idxMapArray[i] = dim2Lvl.compose(idxMapArray[i]);
|
|
}
|
|
}
|
|
|
|
// A naive way to handle common constant expressions that arise during dim2lvl
|
|
// translation.
|
|
auto populateCstMapping = [ctx](DenseMap<AffineExpr, AffineExpr> &cstMapping,
|
|
unsigned pos, int64_t lvlSz) {
|
|
if (!ShapedType::isDynamic(lvlSz)) {
|
|
auto c0 = getAffineConstantExpr(0, ctx);
|
|
auto lvlExp = getAffineDimExpr(pos, ctx);
|
|
auto szExp = getAffineConstantExpr(lvlSz, ctx);
|
|
|
|
// lvl floordiv lvlSz = 0
|
|
auto divExp =
|
|
getAffineBinaryOpExpr(AffineExprKind::FloorDiv, lvlExp, szExp);
|
|
cstMapping.try_emplace(divExp, c0);
|
|
|
|
// lvl mod lvlSz = lvl
|
|
auto modExp = getAffineBinaryOpExpr(AffineExprKind::Mod, lvlExp, szExp);
|
|
cstMapping.try_emplace(modExp, lvlExp);
|
|
}
|
|
};
|
|
|
|
unsigned boundedNum = 0;
|
|
// A fixed-point algorithm.
|
|
bool changed = true;
|
|
while (changed) {
|
|
changed = false;
|
|
for (OpOperand &operand : op->getOpOperands()) {
|
|
auto stt = tryGetSparseTensorType(operand.get());
|
|
// Skip on dense operands.
|
|
if (!stt || !stt->getEncoding())
|
|
continue;
|
|
|
|
unsigned tid = operand.getOperandNumber();
|
|
bool isOutput = &operand == op.getDpsInitOperand(0);
|
|
AffineMap idxMap = idxMapArray[tid];
|
|
InadmissInfo inAdInfo = collectInadmissInfo(idxMap, isOutput);
|
|
auto [inAdLvls, dimExprs] = inAdInfo;
|
|
for (unsigned d : dimExprs.set_bits()) {
|
|
// The first `boundedNum` used in the AffineMap is introduced to
|
|
// resolve previous inadmissible expressions. We can not replace them
|
|
// as it might bring back the inadmissible expressions.
|
|
if (d < boundedNum)
|
|
return std::nullopt;
|
|
}
|
|
|
|
if (inAdLvls.count() != 0) {
|
|
// Naive constant progagation, should be sufficient to handle block
|
|
// sparsity in our cases.
|
|
SmallVector<int64_t> lvlShape = stt->getLvlShape();
|
|
DenseMap<AffineExpr, AffineExpr> cstMapping;
|
|
unsigned position = 0;
|
|
for (unsigned lvl : inAdLvls.set_bits()) {
|
|
int64_t lvlSz = lvlShape[lvl];
|
|
populateCstMapping(cstMapping, position, lvlSz);
|
|
position++;
|
|
}
|
|
|
|
AffineMap lvl2Idx = genReplaceDimToLvlMap(inAdInfo, idxMap, itTps);
|
|
// Compose the lvl2Idx Map to all AffineIdxMap to eliminate
|
|
// inadmissible expressions.
|
|
for (unsigned tid = 0, e = idxMapArray.size(); tid < e; tid++) {
|
|
AffineMap transMap = idxMapArray[tid].compose(lvl2Idx);
|
|
idxMapArray[tid] = transMap.replace(
|
|
cstMapping, /*numResultDims=*/transMap.getNumDims(),
|
|
/*numResultSyms=*/0);
|
|
}
|
|
changed = true;
|
|
boundedNum += inAdLvls.count();
|
|
}
|
|
}
|
|
};
|
|
|
|
SmallVector<Attribute> iterAttr =
|
|
llvm::map_to_vector(itTps, [ctx](auto itTp) -> Attribute {
|
|
return linalg::IteratorTypeAttr::get(ctx, itTp);
|
|
});
|
|
|
|
return std::make_pair(rewriter.getAffineMapArrayAttr(idxMapArray),
|
|
rewriter.getArrayAttr(iterAttr));
|
|
}
|
|
|
|
// Generates a "de"mapping reinterpretation of the map.
|
|
static Value genDemap(OpBuilder &builder, SparseTensorEncodingAttr enc,
|
|
Value val) {
|
|
return builder.create<ReinterpretMapOp>(val.getLoc(), enc.withoutDimToLvl(),
|
|
val);
|
|
}
|
|
|
|
// Generates a "re"mapping reinterpretation of the map.
|
|
static Value genRemap(OpBuilder &builder, SparseTensorEncodingAttr enc,
|
|
Value val) {
|
|
return builder.create<ReinterpretMapOp>(val.getLoc(), enc, val);
|
|
}
|
|
|
|
// Returns a copy of `outs` in which every value whose type differs from the
// type at the same position in `types` has been replaced by a
// ReinterpretMapOp producing that type. Values that already have the
// requested type are passed through unchanged. `outs` and `types` are
// expected to have the same length.
static SmallVector<Value> remapValueRange(OpBuilder &rewriter, TypeRange types,
                                          ValueRange outs) {
  assert(outs.size() == types.size());
  SmallVector<Value> remapped(outs);
  for (auto [value, wantedType] : llvm::zip(remapped, types)) {
    if (value.getType() == wantedType)
      continue;
    value = rewriter.create<ReinterpretMapOp>(value.getLoc(), wantedType, value);
  }
  return remapped;
}
|
|
|
|
namespace {
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Rewriting rules for linalg generic ops.
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Sparse rewriting rule for the generic `linalg` operation.
|
|
struct GenericOpReinterpretMap
|
|
: public DemapInsRewriter<GenericOpReinterpretMap, linalg::GenericOp> {
|
|
public:
|
|
using DemapInsRewriter::DemapInsRewriter;
|
|
LogicalResult rewriteOp(linalg::GenericOp linalgOp, OpAdaptor adaptor,
|
|
PatternRewriter &rewriter) const {
|
|
// Only rewrite single output operations with pure (sparse) tensor
|
|
// semantics.
|
|
if (linalgOp.getNumDpsInits() != 1 || !linalgOp.hasPureTensorSemantics() ||
|
|
!hasAnySparseOperandOrResult(linalgOp) ||
|
|
!hasAnyNonIdentityOperandsOrResults(linalgOp))
|
|
return failure();
|
|
|
|
// Try translating the index map.
|
|
auto transMap = translateMap(linalgOp, rewriter);
|
|
if (!transMap)
|
|
return rewriter.notifyMatchFailure(
|
|
linalgOp, "the sparse kernel can not be sparsified.");
|
|
|
|
// On success, replace update the linalg operands and maps in place.
|
|
Value res = linalgOp.getResult(0);
|
|
auto stt = tryGetSparseTensorType(res);
|
|
auto [idxMap, itTp] = *transMap;
|
|
|
|
rewriter.startRootUpdate(linalgOp);
|
|
linalgOp.setIndexingMapsAttr(idxMap);
|
|
linalgOp.setIteratorTypesAttr(itTp);
|
|
// Use demapped arguments.
|
|
linalgOp.getInputsMutable().assign(adaptor.getInputs());
|
|
linalgOp.getDpsInitsMutable().assign(adaptor.getOutputs());
|
|
res.setType(adaptor.getOutputs()[0].getType());
|
|
rewriter.finalizeRootUpdate(linalgOp);
|
|
|
|
rewriter.setInsertionPointAfter(linalgOp);
|
|
if (stt && stt->hasEncoding()) {
|
|
Value t = genRemap(rewriter, stt->getEncoding(), res);
|
|
rewriter.replaceAllUsesExcept(res, t, t.getDefiningOp());
|
|
}
|
|
return success();
|
|
}
|
|
};
|
|
|
|
/// Schedules the loops of a sparse `linalg.generic`: computes a loop order
/// from the iteration graph, permutes the indexing maps and iterator types
/// accordingly, and tags the op with a "sorted" attribute so it is not
/// matched again. If no acyclic order exists, tries to break the cycle by
/// converting one sparse input (see `resolveCycle`).
struct GenericOpScheduler : public OpRewritePattern<linalg::GenericOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(linalg::GenericOp linalgOp,
                                PatternRewriter &rewriter) const override {
    if (linalgOp.getNumDpsInits() != 1 || !linalgOp.hasPureTensorSemantics() ||
        hasAnyNonIdentityOperandsOrResults(linalgOp) || // need demap first
        !hasAnySparseOperandOrResult(linalgOp)) {
      return failure();
    }

    // Ops already tagged by this pattern are left alone.
    const StringRef sorted = "sorted";
    if (linalgOp->hasAttr(sorted))
      return failure();

    auto scheduler = IterationGraphSorter::fromGenericOp(linalgOp);
    bool isAdmissible = false;
    AffineMap order;
    // A const list of all masks that we used for iteration graph
    // computation. Must be ordered from more strict to less strict.
    // Ideally (though might not be guaranteed), the earlier a constraint mask
    // can be satisfied, the faster the generated kernel will be.
    const auto allMasks = {SortMask::kIncludeAll, SortMask::kIncludeDense,
                           SortMask::kIncludeDenseInput,
                           SortMask::kIncludeDenseOutput,
                           SortMask::kSparseOnly};
    for (const SortMask mask : allMasks) {
      order = scheduler.sort(mask);
      if (order) {
        if (isAdmissibleOrder(linalgOp, order)) {
          isAdmissible = true;
          break;
        }
        // else try a set of less strict constraints.
      }
    }

    if (!order) {
      // Cycles detected.
      if (failed(resolveCycle(scheduler, linalgOp, rewriter))) {
        return rewriter.notifyMatchFailure(
            linalgOp, "the sparse kernel can not be scheduled: loop detected.");
      }
      return success();
    }

    if (!isAdmissible) {
      return rewriter.notifyMatchFailure(
          linalgOp, "the sparse kernel can not be scheduled.");
    }

    // Marks the GenericOp to avoid recursive matching.
    rewriter.updateRootInPlace(linalgOp, [&]() {
      linalgOp->setAttr(sorted, rewriter.getBoolAttr(true));
    });

    // Already sorted.
    if (order.isIdentity())
      return success();

    assert(order.isPermutation());
    // `order` is original loop -> sorted loop map
    ArrayAttr preItTypes = linalgOp.getIteratorTypesAttr();
    SmallVector<Attribute> curItTypes;
    curItTypes.reserve(preItTypes.size());
    for (AffineExpr expr : order.getResults()) {
      unsigned loopID = llvm::cast<AffineDimExpr>(expr).getPosition();
      curItTypes.push_back(preItTypes[loopID]);
    }

    // Inverse `order` to get sorted loop -> original loop map
    order = inversePermutation(order);
    SmallVector<AffineMap> idxMaps = linalgOp.getIndexingMapsArray();
    for (AffineMap &idxMap : idxMaps)
      idxMap = idxMap.compose(order); // sorted loop -> lvl map

    rewriter.startRootUpdate(linalgOp);
    linalgOp.setIndexingMapsAttr(rewriter.getAffineMapArrayAttr(idxMaps));
    linalgOp.setIteratorTypesAttr(rewriter.getArrayAttr(curItTypes));
    rewriter.finalizeRootUpdate(linalgOp);

    return success();
  }

private:
  /// Whether the loop order is admissible by sparsification.
  /// Ops without a sparse result are always admissible. Otherwise, counts the
  /// loops in `order` that precede the first reduction loop and requires that
  /// count to be at least `rank(output) - 1`.
  static bool isAdmissibleOrder(linalg::GenericOp linalgOp, AffineMap order) {
    if (!hasAnySparseResult(linalgOp))
      return true;

    OpOperand *lhs = linalgOp.getDpsInitOperand(0);
    unsigned nest = 0;
    // Fetch the iterator types once instead of re-reading and casting the
    // attribute on every iteration.
    const auto iteratorTypes = linalgOp.getIteratorTypesArray();
    for (const AffineExpr l : order.getResults()) {
      unsigned loopId = llvm::cast<AffineDimExpr>(l).getPosition();
      if (linalg::isReductionIterator(iteratorTypes[loopId]))
        break; // terminate at first reduction
      nest++;
    }
    // Determine admissible dynamic insertion situations:
    // (1) fully injective, since there are no reductions,
    // (2) admissible 1-d expansion in innermost dimension.
    return static_cast<int64_t>(nest) >= linalgOp.getRank(lhs) - 1;
  }

  /// Last resort cycle resolution.
  /// Tries, for each sparse input with a plain (dim-only) indexing map, to
  /// schedule the loops while ignoring that input. On the first input for
  /// which this succeeds, inserts a ConvertOp to a permuted sparse type in
  /// front of the op, rewires the operand, and returns success. Returns
  /// failure if no single input resolves the cycle.
  static LogicalResult resolveCycle(IterationGraphSorter &scheduler,
                                    linalg::LinalgOp linalgOp,
                                    PatternRewriter &rewriter) {
    // Compute topological sort while leaving out every sparse input tensor in
    // succession until an acyclic iteration graph results.
    for (OpOperand *t : linalgOp.getDpsInputOperands()) {
      Value tval = t->get();
      auto srcEnc = getSparseTensorEncoding(tval.getType());
      // The constraints introduced by compound index expression are
      // complicated. Skip them.
      AffineMap idxMap = linalgOp.getMatchingIndexingMap(t);
      bool hasCompExpr = llvm::any_of(idxMap.getResults(), [](AffineExpr exp) {
        return !llvm::isa<AffineDimExpr>(exp);
      });
      if (!srcEnc || hasCompExpr)
        continue;

      // Try scheduling loop without constraints from `tval`.
      AffineMap order = scheduler.sort(SortMask::kSparseOnly, tval);
      if (!order) // still cyclic
        continue;

      // Found an input tensor that resolves the cycle by inserting a
      // conversion into a sparse tensor that adheres to the iteration
      // graph order.
      auto stt = getSparseTensorType(tval);
      assert(stt.isIdentity());
      order = inversePermutation(order);
      // sorted loop -> lvl map.
      idxMap = idxMap.compose(order);

      // Found a permutation such that the results in `idxMap` is sorted.
      // For example,
      //  (d0, d1, d2, d3) -> (d2, d1, d0)
      // loops are scheduled in order of d0->d1->d2->d3, to resolve the cycle,
      // we find a permutation, perm(d2, d1, d0) -> (d0, d1, d2), such that the
      // transposed tensor's levels are visited in the same order as the loop
      // scheduling order.
      SmallVector<std::pair<unsigned, unsigned>> lvlSeq;
      for (AffineExpr expr : idxMap.getResults()) {
        unsigned lvl = llvm::cast<AffineDimExpr>(expr).getPosition();
        lvlSeq.push_back(std::make_pair(lvl, lvlSeq.size()));
      }
      std::sort(lvlSeq.begin(), lvlSeq.end(), [](auto &lhs, auto &rhs) -> bool {
        return lhs.first < rhs.first;
      });
      SmallVector<unsigned> perm =
          llvm::to_vector(llvm::make_second_range(lvlSeq));
      auto dimToLvl = AffineMap::getPermutationMap(perm, linalgOp.getContext());
      // The result of the idxMap must be unsorted.
      assert(!dimToLvl.isIdentity());

      // Inserting the transpose
      rewriter.setInsertionPoint(linalgOp);
      RankedTensorType dstTp = stt.withDimToLvl(dimToLvl).getRankedTensorType();
      Value dst = rewriter.create<ConvertOp>(tval.getLoc(), dstTp, tval);
      rewriter.updateRootInPlace(linalgOp, [&]() {
        linalgOp->setOperand(t->getOperandNumber(), dst);
      });
      return success();
    }
    // Cannot be resolved with a single conversion.
    // TODO: convert more than one?
    return failure();
  }
};
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Reinterpret Map Rewriters for operations other than linalg.generics
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
template <typename AllocOp>
|
|
struct TensorAllocDemapper : public OpRewritePattern<AllocOp> {
|
|
using OpRewritePattern<AllocOp>::OpRewritePattern;
|
|
LogicalResult matchAndRewrite(AllocOp op,
|
|
PatternRewriter &rewriter) const override {
|
|
if (!hasAnyNonIdentityOperandsOrResults(op))
|
|
return failure();
|
|
|
|
Location loc = op.getLoc();
|
|
auto stt = getSparseTensorType(op.getResult());
|
|
|
|
SmallVector<Value> maxDimCrds;
|
|
maxDimCrds.reserve(stt.getDimRank());
|
|
ValueRange dynSz = op.getDynamicSizes();
|
|
for (int64_t dimSz : stt.getDimShape()) {
|
|
if (ShapedType::isDynamic(dimSz)) {
|
|
Value maxCrd = rewriter.create<arith::SubIOp>(
|
|
loc, dynSz.front(), constantIndex(rewriter, loc, 1));
|
|
maxDimCrds.push_back(maxCrd);
|
|
dynSz = dynSz.drop_front();
|
|
} else {
|
|
maxDimCrds.push_back(constantIndex(rewriter, loc, dimSz - 1));
|
|
}
|
|
}
|
|
|
|
ValueRange maxLvlCrds = stt.translateCrds(rewriter, loc, maxDimCrds,
|
|
CrdTransDirectionKind::dim2lvl);
|
|
auto lvlShape = stt.getLvlShape();
|
|
SmallVector<Value> dynLvlSzs;
|
|
for (unsigned i = 0, e = lvlShape.size(); i < e; i++) {
|
|
if (ShapedType::isDynamic(lvlShape[i])) {
|
|
Value sz = rewriter.create<arith::AddIOp>(
|
|
loc, maxLvlCrds[i], constantIndex(rewriter, loc, 1));
|
|
dynLvlSzs.push_back(sz);
|
|
}
|
|
}
|
|
|
|
assert(dynSz.empty()); // should have consumed all.
|
|
rewriter.startRootUpdate(op);
|
|
op->setOperands(dynLvlSzs);
|
|
op.getResult().setType(stt.getDemappedType());
|
|
rewriter.finalizeRootUpdate(op);
|
|
rewriter.setInsertionPointAfter(op);
|
|
|
|
Value t = genRemap(rewriter, stt.getEncoding(), op.getResult());
|
|
rewriter.replaceAllUsesExcept(op.getResult(), t, t.getDefiningOp());
|
|
return success();
|
|
}
|
|
};
|
|
|
|
struct TensorInsertDemapper
|
|
: public DemapInsRewriter<TensorInsertDemapper, tensor::InsertOp> {
|
|
using DemapInsRewriter::DemapInsRewriter;
|
|
LogicalResult rewriteOp(tensor::InsertOp op, OpAdaptor adaptor,
|
|
PatternRewriter &rewriter) const {
|
|
if (!hasAnySparseResult(op))
|
|
return failure();
|
|
|
|
Location loc = op.getLoc();
|
|
auto stt = getSparseTensorType(op.getResult());
|
|
ValueRange lvlCrd = stt.translateCrds(rewriter, loc, op.getIndices(),
|
|
CrdTransDirectionKind::dim2lvl);
|
|
auto insertOp = rewriter.create<sparse_tensor::InsertOp>(
|
|
loc, op.getScalar(), adaptor.getDest(), lvlCrd);
|
|
|
|
Value out = genRemap(rewriter, stt.getEncoding(), insertOp.getResult());
|
|
rewriter.replaceOp(op, out);
|
|
return success();
|
|
}
|
|
};
|
|
|
|
struct ForeachOpDemapper
|
|
: public DemapInsRewriter<ForeachOpDemapper, ForeachOp> {
|
|
using DemapInsRewriter::DemapInsRewriter;
|
|
LogicalResult rewriteOp(ForeachOp op, OpAdaptor adaptor,
|
|
PatternRewriter &rewriter) const {
|
|
// Only handle operations with sparse input/output with non-identity dim2lvl
|
|
// maps.
|
|
if (!hasAnyNonIdentityOperandsOrResults(op))
|
|
return failure();
|
|
|
|
// TODO: demap constant as well.
|
|
if (auto constOp = op.getTensor().getDefiningOp<arith::ConstantOp>())
|
|
if (auto attr = dyn_cast<SparseElementsAttr>(constOp.getValue()))
|
|
return failure();
|
|
|
|
Location loc = op.getLoc();
|
|
// Cache the type information since we update the foreach op in-place.
|
|
auto srcStt = getSparseTensorType(op.getTensor());
|
|
SmallVector<Type> prevRetTps(op.getResultTypes());
|
|
|
|
rewriter.startRootUpdate(op);
|
|
op.getTensorMutable().assign(adaptor.getTensor());
|
|
op.getInitArgsMutable().assign(adaptor.getInitArgs());
|
|
// Update results' types.
|
|
for (auto r : op.getResults())
|
|
if (auto stt = tryGetSparseTensorType(r); stt && !stt->isIdentity())
|
|
r.setType(stt->getDemappedType());
|
|
|
|
Level lvlRank = getSparseTensorType(adaptor.getTensor()).getLvlRank();
|
|
// Update the foreach body.
|
|
SmallVector<Type> blockArgTps(lvlRank, rewriter.getIndexType());
|
|
blockArgTps.push_back(srcStt.getElementType());
|
|
blockArgTps.append(adaptor.getInitArgs().getTypes().begin(),
|
|
adaptor.getInitArgs().getTypes().end());
|
|
Block *body = op.getBody();
|
|
// Block Args: [dimCrd, val, initArgs]
|
|
unsigned preArgNum = body->getNumArguments();
|
|
for (Type t : blockArgTps)
|
|
body->addArgument(t, loc);
|
|
|
|
// Block Args: [dimCrd, val, initArgs, lvlCrds, val, DemappedArgs]
|
|
rewriter.setInsertionPointToStart(body);
|
|
ValueRange lvlCrds = body->getArguments().slice(preArgNum, lvlRank);
|
|
|
|
ValueRange dimCrds = srcStt.translateCrds(rewriter, loc, lvlCrds,
|
|
CrdTransDirectionKind::lvl2dim);
|
|
rewriter.replaceAllUsesWith(
|
|
body->getArguments().take_front(srcStt.getDimRank()), dimCrds);
|
|
body->eraseArguments(0, srcStt.getDimRank());
|
|
// Block Args: [val, initArgs, lvlCrds, val, DemappedArgs]
|
|
unsigned numInitArgs = op.getInitArgs().size();
|
|
rewriter.replaceAllUsesWith(body->getArgument(0),
|
|
body->getArgument(lvlRank + numInitArgs + 1));
|
|
body->eraseArgument(0);
|
|
// Block Args: [initArgs, lvlCrds, val, DemappedArgs]
|
|
ValueRange srcArgs = body->getArguments().take_front(numInitArgs);
|
|
ValueRange dstArgs = body->getArguments().take_back(numInitArgs);
|
|
// Remap back before replacement.
|
|
SmallVector<Value> reMappedArgs =
|
|
remapValueRange(rewriter, srcArgs.getTypes(), dstArgs);
|
|
rewriter.replaceAllUsesWith(srcArgs, reMappedArgs);
|
|
body->eraseArguments(0, numInitArgs);
|
|
// Block Args: [lvlCrds, DemappedArgs] and we are done.
|
|
|
|
// Update yield operations.
|
|
if (numInitArgs != 0) {
|
|
rewriter.setInsertionPointToEnd(body);
|
|
auto yield = llvm::cast<YieldOp>(body->getTerminator());
|
|
if (auto stt = tryGetSparseTensorType(yield.getResult());
|
|
stt && !stt->isIdentity()) {
|
|
Value y = genDemap(rewriter, stt->getEncoding(), yield.getResult());
|
|
rewriter.create<YieldOp>(loc, y);
|
|
rewriter.eraseOp(yield);
|
|
}
|
|
}
|
|
rewriter.finalizeRootUpdate(op);
|
|
|
|
rewriter.setInsertionPointAfter(op);
|
|
SmallVector<Value> outs =
|
|
remapValueRange(rewriter, prevRetTps, op.getResults());
|
|
|
|
// Replace all the uses of the foreach results, expect the use in
|
|
// reinterpret_map used to remap the output.
|
|
for (auto [from, to] : llvm::zip(op.getResults(), outs))
|
|
rewriter.replaceAllUsesExcept(from, to, to.getDefiningOp());
|
|
|
|
return success();
|
|
}
|
|
};
|
|
|
|
} // namespace
|
|
|
|
// Registers the reinterpret-map rewrite patterns selected by `scope`:
//   - kGenericOnly:   only the linalg.generic patterns;
//   - kExceptGeneric: only the non-generic (alloc/insert/foreach) patterns;
//   - kAll:           both groups.
void mlir::populateSparseReinterpretMap(RewritePatternSet &patterns,
                                        ReinterpretMapScope scope) {
  MLIRContext *ctx = patterns.getContext();
  const bool everything = scope == ReinterpretMapScope::kAll;

  if (everything || scope == ReinterpretMapScope::kGenericOnly)
    patterns.add<GenericOpReinterpretMap, GenericOpScheduler>(ctx);

  if (everything || scope == ReinterpretMapScope::kExceptGeneric)
    patterns.add<TensorAllocDemapper<bufferization::AllocTensorOp>,
                 TensorAllocDemapper<tensor::EmptyOp>, TensorInsertDemapper,
                 ForeachOpDemapper>(ctx);
}
|