llvm-project/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
wren romano 2af2e4dbb7 [mlir][sparse] Breaking up openSparseTensor to better support non-permutations
This commit updates how the `SparseTensorConversion` pass handles `NewOp`.  It breaks up the underlying `openSparseTensor` function into two parts (`SparseTensorReader::create` and `SparseTensorReader::readSparseTensor`) so that the pass can inject code for constructing `lvlSizes` between those two parts.  Migrating the construction of `lvlSizes` out of the runtime and into the pass is a necessary first step toward fully supporting non-permutations.  (The alternative would be for the pass to generate a `FuncOp` for performing the construction and then pass that to the runtime, which doesn't seem to have any benefits over the design of this commit.)  And since the pass now generates the code to call these two functions, this change also removes the `Action::kFromFile` value from the enum used by `_mlir_ciface_newSparseTensor`.
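
For context, the `NewRewriter` pattern in this file emits roughly the following sequence of runtime calls (a sketch based on the `createFuncCall` sites below; the `Next` call is suffixed by element type, e.g. `getSparseTensorReaderNextF64`):

    %reader = createSparseTensorReader(%fileName)
    copySparseTensorReaderDimSizes(%reader, %dimSizes)   // dynamic shapes only
    %nnz = getSparseTensorReaderNNZ(%reader)
    for %i = 0 to %nnz:
      getSparseTensorReaderNext<Type>(%reader, %indices, %value)
      <insert the element into an unordered COO buffer>
    delSparseTensorReader(%reader)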

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D138363
2022-12-02 11:10:57 -08:00

//===- SparseTensorRewriting.cpp - Sparse tensor rewriting rules ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements rewriting rules that are specific to sparse tensors.
//
//===----------------------------------------------------------------------===//
#include "CodegenUtils.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Support/LLVM.h"
using namespace mlir;
using namespace mlir::bufferization;
using namespace mlir::linalg;
using namespace mlir::sparse_tensor;
//===---------------------------------------------------------------------===//
// Helper methods for the actual rewriting rules.
//===---------------------------------------------------------------------===//
// Helper method to match any typed zero.
static bool isZeroValue(Value val) {
return matchPattern(val, m_Zero()) || matchPattern(val, m_AnyZeroFloat());
}
// Helper to detect a sparse tensor type operand with at least one compressed
// dimension.
static bool isSparseTensor(OpOperand *op) {
if (auto enc = getSparseTensorEncoding(op->get().getType())) {
if (llvm::is_contained(enc.getDimLevelType(), DimLevelType::Compressed))
return true;
}
return false;
}
// Helper method to find zero/uninitialized allocation.
static bool isAlloc(OpOperand *op, bool isZero) {
Value val = op->get();
if (auto alloc = val.getDefiningOp<AllocTensorOp>()) {
Value copy = alloc.getCopy();
if (isZero)
return copy && isZeroValue(copy);
return !copy;
}
return false;
}
// Helper to detect sampling operation.
static bool isSampling(GenericOp op) {
auto yieldOp = cast<linalg::YieldOp>(op.getRegion().front().getTerminator());
if (auto *def = yieldOp.getOperand(0).getDefiningOp()) {
if (isa<arith::MulFOp>(def) || isa<arith::MulIOp>(def)) {
// Both scalar input arguments are used exactly once.
Value s1 = op.getBlock()->getArgument(0);
Value s2 = op.getBlock()->getArgument(1);
return (def->getOperand(0) == s1 && def->getOperand(1) == s2) ||
(def->getOperand(1) == s1 && def->getOperand(0) == s2);
}
}
return false;
}
// Helper to detect chain of multiplications that do not involve x.
static bool isMulChain(Value val, Value x) {
if (auto arg = val.dyn_cast<BlockArgument>())
return arg != x;
if (auto *def = val.getDefiningOp()) {
if (isa<arith::MulFOp>(def) || isa<arith::MulIOp>(def))
return isMulChain(def->getOperand(0), x) &&
isMulChain(def->getOperand(1), x);
}
return false;
}
// Helper to detect x = x + <multiplications>.
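// For example (a minimal sketch), this matches a reduction body such as:
//   %0 = arith.mulf %a, %b
//   %1 = arith.addf %x, %0
//   linalg.yield %1
// where %x is the last block argument (the output/accumulator).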
static bool isSumOfMul(GenericOp op) {
auto yieldOp = cast<linalg::YieldOp>(op.getRegion().front().getTerminator());
if (auto *def = yieldOp.getOperand(0).getDefiningOp()) {
if (isa<arith::AddFOp>(def) || isa<arith::AddIOp>(def)) {
Value x = op.getBlock()->getArguments().back();
return (def->getOperand(0) == x && isMulChain(def->getOperand(1), x)) ||
(def->getOperand(1) == x && isMulChain(def->getOperand(0), x));
}
}
return false;
}
// Helper to detect direct yield of a zero value.
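// For example (a minimal sketch), this matches a region of the form:
//   ^bb0(%in: f64, %out: f64):
//     %cst = arith.constant 0.0 : f64
//     linalg.yield %cst : f64
// as well as a direct yield of a block argument whose corresponding operand
// is a zero value.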
static bool isZeroYield(GenericOp op) {
auto yieldOp = cast<linalg::YieldOp>(op.getRegion().front().getTerminator());
if (auto arg = yieldOp.getOperand(0).dyn_cast<BlockArgument>()) {
if (arg.getOwner()->getParentOp() == op) {
return isZeroValue(op->getOperand(arg.getArgNumber()));
}
}
return isZeroValue(yieldOp.getOperand(0));
}
/// Populates the given sizes array from the type (for static sizes) and from
/// the tensor (for dynamic sizes).
static void sizesForTensor(OpBuilder &builder, SmallVectorImpl<Value> &sizes,
Location loc, ShapedType stp, Value tensor) {
for (const auto &d : enumerate(stp.getShape())) {
Value dim;
if (d.value() == ShapedType::kDynamic)
dim = builder.create<tensor::DimOp>(loc, tensor, d.index());
else
dim = constantIndex(builder, loc, d.value());
sizes.push_back(dim);
}
}
// Helper to set up a COO type.
// TODO: The dim level property of the COO type relies on the input tensors;
// the shape relies on the output tensor.
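// For example, for any rank-3 source type this produces the level types, in
// storage order (a sketch; the printed attribute syntax may differ):
//   [ "compressed-nu-no", "singleton-nu-no", "singleton-no" ]
// together with an identity dimension ordering, while keeping the shape and
// element type of the source.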
static RankedTensorType getUnorderedCOOFromType(RankedTensorType src) {
auto *ctx = src.getContext();
auto rank = src.getRank();
SmallVector<DimLevelType> dims;
// An unordered and non-unique compressed dim at the beginning.
dims.push_back(DimLevelType::CompressedNuNo);
if (rank > 1) {
// TODO: it is actually ordered at this level for ordered inputs.
// Followed by n-2 unordered non-unique singleton levels.
std::fill_n(std::back_inserter(dims), rank - 2,
DimLevelType::SingletonNuNo);
// TODO: the COO should have a unique level at the end only if all the
// inputs (for concatenate) are unique at the last level. End with an
// unordered unique singleton level unless the tensor rank is 1.
dims.push_back(DimLevelType::SingletonNo);
}
SparseTensorEncodingAttr encSrc = getSparseTensorEncoding(src);
// TODO: Maybe pick the bitwidth based on input/output tensors (probably the
// largest one among them) in the original operation instead of using the
// default value.
unsigned pointerBitWidth = encSrc ? encSrc.getPointerBitWidth() : 0;
unsigned indexBitWidth = encSrc ? encSrc.getIndexBitWidth() : 0;
auto enc = SparseTensorEncodingAttr::get(
ctx, dims, AffineMap::getMultiDimIdentityMap(rank, ctx), AffineMap(),
pointerBitWidth, indexBitWidth);
return RankedTensorType::get(src.getShape(), src.getElementType(), enc);
}
/// Collects the dynamic dimension sizes for `tp` with the assumption that
/// `sizes` are the dimension sizes for the type. Stores the dynamic dimension
/// sizes in `dynSizes`.
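/// For example (illustrative values), for tp = tensor<?x3x?xf64> and
/// sizes = [%s0, %c3, %s2], this appends %s0 and %s2 to dynSizes.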
static void getDynamicSizes(RankedTensorType tp,
const SmallVectorImpl<Value> &sizes,
SmallVectorImpl<Value> &dynSizes) {
for (const auto &d : enumerate(tp.getShape())) {
if (d.value() == ShapedType::kDynamic)
dynSizes.push_back(sizes[d.index()]);
}
}
static LogicalResult genForeachOnSparseConstant(ForeachOp op,
RewriterBase &rewriter,
SparseElementsAttr attr) {
auto loc = op.getLoc();
SmallVector<Value> reduc = op.getInitArgs();
// Foreach on constant.
foreachInSparseConstant(
loc, rewriter, attr,
[&reduc, &rewriter, op](ArrayRef<Value> coords, Value v) mutable {
SmallVector<Value> args;
args.append(coords.begin(), coords.end());
args.push_back(v);
args.append(reduc);
// Clones the foreach op to get a copy of the loop body.
auto cloned = cast<ForeachOp>(rewriter.clone(*op.getOperation()));
assert(args.size() == cloned.getBody()->getNumArguments());
Operation *yield = cloned.getBody()->getTerminator();
rewriter.mergeBlockBefore(cloned.getBody(), op, args);
// Clean up.
rewriter.eraseOp(cloned);
reduc = yield->getOperands();
rewriter.eraseOp(yield);
});
rewriter.replaceOp(op, reduc);
return success();
}
/// Populates the given sizes array for concatenation from types (for static
/// sizes) and from the source tensors (for dynamic sizes).
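/// For example (illustrative shapes), concatenating tensor<4x2xf64>,
/// tensor<4x3xf64>, and tensor<4x5xf64> along dim = 1 into tensor<4x?xf64>
/// sets sizes[1] to 2 + 3 + 5, computed with arith.addi ops because the
/// destination dimension is dynamic.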
static void concatSizesFromInputs(OpBuilder &builder,
SmallVectorImpl<Value> &sizes, Location loc,
ShapedType dstTp, ValueRange srcs,
unsigned dim) {
auto dstShape = dstTp.getShape();
sizesFromSrc(builder, sizes, loc, srcs[0]);
// Sum up on the `dim` if the dimension is dynamic.
if (dstShape[dim] != ShapedType::kDynamic) {
// Faithfully take the static size.
sizes[dim] = constantIndex(builder, loc, dstShape[dim]);
} else {
// Else, compute the shape dynamically.
for (const auto &src : srcs.drop_front()) {
Value srcSz = linalg::createOrFoldDimOp(builder, loc, src, dim);
// Sum up all the sizes.
sizes[dim] = builder.create<arith::AddIOp>(loc, sizes[dim], srcSz);
}
}
}
//===---------------------------------------------------------------------===//
// The actual sparse tensor rewriting rules.
//===---------------------------------------------------------------------===//
namespace {
/// Rewriting rule that replaces a direct yield of zero with the initial
/// allocation.
struct FoldInvariantYield : public OpRewritePattern<GenericOp> {
public:
using OpRewritePattern<GenericOp>::OpRewritePattern;
LogicalResult matchAndRewrite(GenericOp op,
PatternRewriter &rewriter) const override {
if (!op.hasTensorSemantics() || op.getNumResults() != 1 ||
!isAlloc(op.getDpsInitOperand(0), /*isZero=*/false) ||
!isZeroYield(op) || !op.getDpsInitOperand(0)->get().hasOneUse())
return failure();
auto outputType = op.getResult(0).getType().cast<RankedTensorType>();
// Yielding zero on newly allocated (all-zero) sparse tensors can be
// optimized out directly (regardless of dynamic or static size).
if (getSparseTensorEncoding(outputType)) {
rewriter.replaceOp(op, op.getDpsInitOperand(0)->get());
return success();
}
// Incorporate zero value into allocation copy.
if (!outputType.hasStaticShape())
return failure();
Value zero = constantZero(rewriter, op.getLoc(), op.getResult(0).getType());
AllocTensorOp a =
op.getDpsInitOperand(0)->get().getDefiningOp<AllocTensorOp>();
rewriter.updateRootInPlace(a, [&]() { a.getCopyMutable().assign(zero); });
rewriter.replaceOp(op, op.getDpsInitOperand(0)->get());
return success();
}
};
/// Rewriting rule that converts two kernels:
///
/// T(i,j) = SUM(k, A(i,j,k) * B(i,j,k) * ... )
/// X(i,j) = S(i,j) * T(i,j)
///
/// into a single kernel, using distributive law:
///
/// X(i,j) = SUM(k, S(i,j) * A(i,j,k) * B(i,j,k) * ... )
///
/// This kind of fusion (merging two ops into one but using arithmetic
/// equalities that may not hold for floating-point computations) would
/// be undesirable in the dense case, since we distribute the multiplication
/// into the reduction loop. However, for sparse sampling tensor S, such
/// a fusion may actually reduce the asymptotic complexity of the kernel,
/// since intermediate results may be nullified.
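///
/// A sketch of the rewrite on the kernels above (operand names are
/// illustrative; the actual pattern works on the linalg.generic regions):
///
///   %t = linalg.generic ... {     // producer: T(i,j) = SUM(k, A*B)
///     %0 = arith.mulf %a, %b
///     %1 = arith.addf %acc, %0
///   }
///   %x = linalg.generic ... {     // consumer: X(i,j) = S(i,j) * T(i,j)
///     %2 = arith.mulf %s, %t
///   }
/// ==>
///   %x = linalg.generic ... {     // fused: X(i,j) = SUM(k, S*A*B)
///     %0 = arith.mulf %a, %b
///     %1 = arith.mulf %s, %0
///     %2 = arith.addf %acc, %1
///   }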
struct FuseSparseMultiplyOverAdd : public OpRewritePattern<GenericOp> {
public:
using OpRewritePattern<GenericOp>::OpRewritePattern;
LogicalResult matchAndRewrite(GenericOp op,
PatternRewriter &rewriter) const override {
// Check consumer.
if (!op.hasTensorSemantics() || op.getNumDpsInputs() != 2 ||
op.getNumResults() != 1 ||
op.getNumParallelLoops() != op.getNumLoops() ||
!op.getMatchingIndexingMap(op.getDpsInitOperand(0)).isIdentity() ||
!op.getMatchingIndexingMap(op.getDpsInputOperand(0)).isIdentity() ||
!op.getMatchingIndexingMap(op.getDpsInputOperand(1)).isIdentity())
return failure();
// Find consuming OP2(sparse, other) or OP2(other, sparse). The other
// operand can be sparse or dense, since the point of this rewriting rule
// is detecting a situation in which *more* sparsity is introduced into
// a computation, be it already sparse or still dense.
unsigned other = 0;
if (isSparseTensor(op.getDpsInputOperand(0)))
other = 1;
else if (!isSparseTensor(op.getDpsInputOperand(1)))
return failure();
// Check producer.
auto prod = dyn_cast_or_null<GenericOp>(
op.getDpsInputOperand(other)->get().getDefiningOp());
if (!prod || !prod.hasTensorSemantics() || prod.getNumResults() != 1 ||
!prod.getResult(0).hasOneUse())
return failure();
// Sampling consumer and sum of multiplication chain producer.
if (!isAlloc(op.getDpsInitOperand(0), /*isZero=*/false) ||
!isAlloc(prod.getDpsInitOperand(0), /*isZero=*/true) ||
!isSampling(op) || !isSumOfMul(prod))
return failure();
// Modify operand structure of producer and consumer.
Location loc = prod.getLoc();
SmallVector<Value> inputOps = prod.getInputs();
SmallVector<Value> outputOps = op.getOutputs();
SmallVector<AffineMap> fusedIndexMaps = prod.getIndexingMapsArray();
inputOps.push_back(op.getDpsInputOperand(1 - other)->get());
fusedIndexMaps.push_back(fusedIndexMaps.back()); // mimic other
// Fuse producer and consumer into a new generic op.
auto fusedOp = rewriter.create<GenericOp>(
loc, op.getResult(0).getType(), inputOps, outputOps,
rewriter.getAffineMapArrayAttr(fusedIndexMaps), prod.getIteratorTypes(),
/*doc=*/nullptr, /*library_call=*/nullptr);
Block &prodBlock = prod.getRegion().front();
Block &consBlock = op.getRegion().front();
BlockAndValueMapping mapper;
Block *fusedBlock = new Block();
fusedOp.getRegion().push_back(fusedBlock);
unsigned num = prodBlock.getNumArguments();
for (unsigned i = 0; i < num - 1; i++)
addArg(mapper, fusedBlock, prodBlock.getArgument(i));
addArg(mapper, fusedBlock, consBlock.getArgument(1 - other));
addArg(mapper, fusedBlock, prodBlock.getArgument(num - 1));
// Clone bodies of the producer and consumer in new evaluation order.
auto *acc = prodBlock.getTerminator()->getOperand(0).getDefiningOp();
auto *sampler = consBlock.getTerminator()->getOperand(0).getDefiningOp();
rewriter.setInsertionPointToStart(fusedBlock);
Value last;
for (auto &op : prodBlock.without_terminator())
if (&op != acc) {
last = op.getResult(0);
rewriter.clone(op, mapper);
}
mapper.map(consBlock.getArgument(other), fusedBlock->back().getResult(0));
mapper.map(last, rewriter.clone(*sampler, mapper)->getResult(0));
last = rewriter.clone(*acc, mapper)->getResult(0);
rewriter.create<linalg::YieldOp>(loc, last);
// Force initial value on merged allocation for dense outputs.
if (!getSparseTensorEncoding(op.getResult(0).getType())) {
Value init = prod.getDpsInitOperand(0)
->get()
.getDefiningOp<AllocTensorOp>()
.getCopy();
AllocTensorOp a =
op.getDpsInitOperand(0)->get().getDefiningOp<AllocTensorOp>();
rewriter.updateRootInPlace(a, [&]() { a.getCopyMutable().assign(init); });
}
// Replace consumer with fused operation. Old producer
// and consumer ops will be removed by DCE.
rewriter.replaceOp(op, fusedOp->getResults());
return success();
}
private:
// Helper to add argument and record the mapping.
static void addArg(BlockAndValueMapping &mapper, Block *b, BlockArgument a) {
mapper.map(a, b->addArgument(a.getType(), a.getLoc()));
}
};
/// Sparse rewriting rule for sparse-to-sparse reshape operator.
template <typename ReshapeOp>
struct Sparse2SparseReshapeRewriter : public OpRewritePattern<ReshapeOp> {
public:
using OpRewritePattern<ReshapeOp>::OpRewritePattern;
LogicalResult matchAndRewrite(ReshapeOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
Value srcTensor = op.getSrc();
auto srcTp = srcTensor.getType().template cast<RankedTensorType>();
auto dstTp = op.getResult().getType().template cast<RankedTensorType>();
SparseTensorEncodingAttr encSrc = getSparseTensorEncoding(srcTp);
SparseTensorEncodingAttr encDst = getSparseTensorEncoding(dstTp);
if (!encDst || !encSrc) {
return failure();
}
// Generate code to represent the static dimension constants or compute
// the dynamic dimension values.
SmallVector<Value> srcSizes;
sizesForTensor(rewriter, srcSizes, loc, srcTp, srcTensor);
SmallVector<Value> dstSizes;
SmallVector<Value> dstDynSizes;
if (dstTp.hasStaticShape()) {
for (auto d : dstTp.getShape())
dstSizes.push_back(constantIndex(rewriter, loc, d));
} else {
ArrayRef<int64_t> dstShape = dstTp.getShape();
genReshapeDstShape(loc, rewriter, dstSizes, srcSizes, dstShape,
op.getReassociationIndices());
for (auto &d : llvm::enumerate(dstShape)) {
if (d.value() == ShapedType::kDynamic)
dstDynSizes.push_back(dstSizes[d.index()]);
}
}
// Implement the sparse2sparse reshape as follows:
// %tmp = bufferization.alloc_tensor : unordered COO
// foreach srcCoords %srcTensor
// insert translateIndicesArray(srcCoords), %tmp
// %t = sparse_tensor.cast %tmp
RankedTensorType cooTp = getUnorderedCOOFromType(dstTp);
auto cooBuffer =
rewriter.create<AllocTensorOp>(loc, cooTp, dstDynSizes).getResult();
ForeachOp foreachOp = rewriter.create<ForeachOp>(
loc, srcTensor, cooBuffer,
[&](OpBuilder &builder, Location loc, ValueRange args, Value v,
ValueRange reduc) {
SmallVector<Value> srcIndices;
SmallVector<Value> dstIndices;
for (int64_t i = 0, e = srcTp.getRank(); i < e; i++) {
uint64_t dim = toStoredDim(encSrc, i);
srcIndices.push_back(args[dim]);
}
translateIndicesArray(builder, loc, op.getReassociationIndices(),
srcIndices, srcSizes, dstSizes, dstIndices);
auto t = builder.create<InsertOp>(loc, v, reduc.front(), dstIndices);
builder.create<sparse_tensor::YieldOp>(loc, t);
});
auto t = rewriter.create<LoadOp>(loc, foreachOp.getResult(0), true);
auto converted = rewriter.create<ConvertOp>(loc, dstTp, t).getResult();
rewriter.create<DeallocTensorOp>(loc, t);
rewriter.replaceOp(op, converted);
return success();
}
};
/// Sparse rewriting rule for sparse-to-dense and dense-to-sparse reshape
/// operators.
template <typename ReshapeOp>
struct ReshapeRewriter : public OpRewritePattern<ReshapeOp> {
public:
using OpRewritePattern<ReshapeOp>::OpRewritePattern;
LogicalResult matchAndRewrite(ReshapeOp op,
PatternRewriter &rewriter) const override {
Location loc = op->getLoc();
auto encDst = getSparseTensorEncoding(op.getResult().getType());
auto encSrc = getSparseTensorEncoding(op.getSrc().getType());
// Since a pure dense expansion is very cheap (change of view), for
// a sparse2dense or dense2sparse, we can simply unfuse a sparse
// conversion from the reshape operation itself.
// All other cases are handled elsewhere.
if (encDst && encSrc) {
return failure();
}
if (encSrc) {
RankedTensorType rtp =
op.getSrc().getType().template cast<RankedTensorType>();
auto denseTp =
RankedTensorType::get(rtp.getShape(), rtp.getElementType());
auto convert = rewriter.create<ConvertOp>(loc, denseTp, op.getSrc());
op->setOperand(0, convert);
return success();
}
if (encDst) {
RankedTensorType rtp =
op.getResult().getType().template cast<RankedTensorType>();
auto denseTp =
RankedTensorType::get(rtp.getShape(), rtp.getElementType());
auto reshape = rewriter.create<ReshapeOp>(loc, denseTp, op.getSrc(),
op.getReassociation());
Value convert = rewriter.create<ConvertOp>(loc, rtp, reshape);
rewriter.replaceOp(op, convert);
return success();
}
return failure();
}
};
struct ConcatenateRewriter : public OpRewritePattern<ConcatenateOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(ConcatenateOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
auto dstTp = op.getType().cast<RankedTensorType>();
uint64_t conDim = op.getDimension().getZExtValue();
SmallVector<Value> sizes;
concatSizesFromInputs(rewriter, sizes, loc, dstTp, op.getInputs(), conDim);
// %t = concatenate %s1, %s2, %s3 {dim = 1}
// ==>
// if (isSparseDst)
// if (allDense)
// %tmp = bufferization.alloc_tensor dstTp
// else
// %tmp = bufferization.alloc_tensor : unordered COO
// else
// %tmp = memref.alloc : dense tensor
// foreach in %s1 : insert d0, d1, %tmp
// foreach in %s2 : insert d0, d1 + size(s1), %tmp
// foreach in %s3 : insert d0, d1 + size(s1) + size(s2), %tmp
// %t = convert_to_dest_tensor(%tmp)
SparseTensorEncodingAttr encDst = getSparseTensorEncoding(dstTp);
Value dst; // Destination tensor for inserting source tensor values.
bool allDense = false;
if (encDst) {
allDense = llvm::all_of(encDst.getDimLevelType(),
[](DimLevelType dlt) { return isDenseDLT(dlt); });
SmallVector<Value> dynSizes;
getDynamicSizes(dstTp, sizes, dynSizes);
RankedTensorType tp = dstTp;
if (!allDense) {
tp = getUnorderedCOOFromType(dstTp);
encDst = getSparseTensorEncoding(tp);
}
dst = rewriter.create<AllocTensorOp>(loc, tp, dynSizes).getResult();
} else {
// TODO: Dense buffers should be allocated/deallocated via the callback
// in BufferizationOptions.
dst = allocDenseTensor(rewriter, loc, dstTp, sizes);
}
int64_t rank = dstTp.getRank();
Value offset = constantIndex(rewriter, loc, 0);
SmallVector<Value> initArgs;
if (encDst)
initArgs.push_back(dst);
ForeachOp foreachOp;
for (Value input : op.getInputs()) {
// Build a for op for each input tensor to append new values into the
// output tensor.
foreachOp = rewriter.create<ForeachOp>(
loc, input, initArgs,
[&](OpBuilder &builder, Location loc, ValueRange args, Value v,
ValueRange reduc) {
SmallVector<Value> indices(rank, Value());
for (int64_t i = 0; i < rank; i++) {
Value idx = args[i];
if (i == static_cast<int64_t>(conDim))
// Transform coordinates for the concatenating dim.
idx = builder.create<arith::AddIOp>(loc, idx, offset);
indices[toStoredDim(encDst, i)] = idx;
}
if (encDst) {
Value cond = genIsNonzero(rewriter, loc, v);
scf::IfOp ifOp = builder.create<scf::IfOp>(
loc, TypeRange(reduc.front().getType()), cond, /*else*/ true);
builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
Value t =
builder.create<InsertOp>(loc, v, reduc.front(), indices);
rewriter.create<scf::YieldOp>(loc, t);
rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front());
rewriter.create<scf::YieldOp>(loc, reduc.front());
rewriter.setInsertionPointAfter(ifOp);
rewriter.create<sparse_tensor::YieldOp>(loc, ifOp.getResult(0));
} else {
builder.create<memref::StoreOp>(loc, v, dst, indices);
builder.create<sparse_tensor::YieldOp>(loc);
}
});
// Accumulates the offset. Note that only static-shaped inputs are allowed
// by the concatenate op verifier, which saves us from computing the offset
// dynamically.
int64_t d = input.getType().cast<RankedTensorType>().getShape()[conDim];
assert(!ShapedType::isDynamic(d));
offset = rewriter.create<arith::AddIOp>(loc, offset,
constantIndex(rewriter, loc, d));
if (encDst) {
dst = foreachOp.getResult(0);
initArgs[0] = dst;
}
}
if (encDst) {
dst = rewriter.create<LoadOp>(loc, dst, true);
if (!allDense) {
Value tmpCoo = dst;
dst = rewriter.create<ConvertOp>(loc, dstTp, tmpCoo).getResult();
rewriter.create<DeallocTensorOp>(loc, tmpCoo);
}
rewriter.replaceOp(op, dst);
} else {
rewriter.replaceOpWithNewOp<bufferization::ToTensorOp>(op, dstTp, dst);
}
return success();
}
};
/// Sparse rewriting rule for the convert operator.
struct ConvertRewriter : public OpRewritePattern<ConvertOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(ConvertOp op,
PatternRewriter &rewriter) const override {
auto encDst = getSparseTensorEncoding(op.getType());
auto encSrc = getSparseTensorEncoding(op.getSource().getType());
if (encDst && encSrc) {
// Trivial tensor conversion is handled in codegen.
if (encSrc == encDst)
return failure();
return sparse2SparseRewrite(op, rewriter);
}
if (encSrc && !encDst)
return sparse2DenseRewrite(op, rewriter);
if (!encSrc && encDst)
return dense2SparseRewrite(op, rewriter);
// Dense-to-dense convert is a nop and handled by canonicalization.
return failure();
}
private:
// Handles sparse constant to sparse tensor or dense tensor to sparse tensor
// conversion as follows:
// t = new sparse COO tensor
// fill t using src
// dst = convert t
//
// To fill the COO tensor from a dense tensor:
// for i1 in dim1
// ..
// for ik in dimk
// val = a[i1,..,ik]
// if val != 0
// t->add(val, [i1,..,ik], [p1,..,pk])
//
// To fill the COO tensor from a sparse constant in COO format:
// for i in range(NNZ)
// val = values[i]
// [i1,..,ik] = indices[i]
// t->add(val, [i1,..,ik], [p1,..,pk])
LogicalResult dense2SparseRewrite(ConvertOp op,
PatternRewriter &rewriter) const {
Location loc = op.getLoc();
Value src = op.getSource();
RankedTensorType dstTp = op.getType().cast<RankedTensorType>();
SmallVector<Value> sizes;
sizesFromSrc(rewriter, sizes, loc, src);
SmallVector<Value> dynSizes;
getDynamicSizes(dstTp, sizes, dynSizes);
bool fromSparseConst = false;
if (auto constOp = op.getSource().getDefiningOp<arith::ConstantOp>()) {
if (constOp.getValue().dyn_cast<SparseElementsAttr>()) {
fromSparseConst = true;
}
}
RankedTensorType cooTp = getUnorderedCOOFromType(dstTp);
auto cooBuffer =
rewriter.create<AllocTensorOp>(loc, cooTp, dynSizes).getResult();
auto foreachOp = rewriter.create<ForeachOp>(
loc, src, cooBuffer,
[&](OpBuilder &builder, Location loc, ValueRange indices, Value v,
ValueRange reduc) {
Value input = reduc.front();
if (fromSparseConst) {
input = builder.create<InsertOp>(loc, v, input, indices);
} else {
Value cond = genIsNonzero(builder, loc, v);
auto ifOp = builder.create<scf::IfOp>(
loc, TypeRange(input.getType()), cond, /*else*/ true);
builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
Value insert = builder.create<InsertOp>(loc, v, input, indices);
builder.create<scf::YieldOp>(loc, insert);
builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
builder.create<scf::YieldOp>(loc, input);
builder.setInsertionPointAfter(ifOp);
input = ifOp.getResult(0);
}
builder.create<sparse_tensor::YieldOp>(loc, input);
});
rewriter.setInsertionPointAfter(op);
src = rewriter.create<LoadOp>(loc, foreachOp.getResult(0), true);
rewriter.replaceOpWithNewOp<ConvertOp>(op, dstTp, src);
rewriter.create<DeallocTensorOp>(loc, src);
return success();
}
// Handles sparse tensor to dense tensor conversion as follows:
// dst = new dense tensor;
// foreach element in src
// dst[element.indices] = element.value
LogicalResult sparse2DenseRewrite(ConvertOp op,
PatternRewriter &rewriter) const {
Location loc = op->getLoc();
RankedTensorType dstTp = op.getType().cast<RankedTensorType>();
Value src = op.getSource();
RankedTensorType srcTp = src.getType().cast<RankedTensorType>();
SmallVector<Value> sizes;
sizesForTensor(rewriter, sizes, loc, srcTp, src);
Value dst = allocDenseTensor(rewriter, loc, dstTp, sizes);
Block *insertionBlock = rewriter.getInsertionBlock();
bool noEscape = bufferization::allocationDoesNotEscape(op->getOpResult(0));
rewriter.create<ForeachOp>(loc, src, llvm::None,
[&](OpBuilder &builder, Location loc,
ValueRange args, Value v, ValueRange reduc) {
builder.create<memref::StoreOp>(loc, v, dst,
args);
builder.create<sparse_tensor::YieldOp>(loc);
});
rewriter.replaceOpWithNewOp<bufferization::ToTensorOp>(op, dstTp, dst);
// Deallocate the buffer.
if (noEscape) {
rewriter.setInsertionPoint(insertionBlock->getTerminator());
deallocDenseTensor(rewriter, loc, dst);
}
return success();
}
// Handles sparse tensor to sparse tensor conversion as follows:
// if src is not COO
// construct a COO to represent the src
// sort the src COO
// foreach element in the sorted src COO
// insert element to dst
LogicalResult sparse2SparseRewrite(ConvertOp op,
PatternRewriter &rewriter) const {
Location loc = op->getLoc();
Value src = op.getSource();
RankedTensorType srcTp = src.getType().cast<RankedTensorType>();
RankedTensorType dstTp = op.getType().cast<RankedTensorType>();
SparseTensorEncodingAttr encDst = getSparseTensorEncoding(dstTp);
SmallVector<Value> srcSizes;
sizesForTensor(rewriter, srcSizes, loc, srcTp, src);
Value tmpCoo = Value();
if (!isUniqueCOOType(srcTp)) {
// Construct a COO tensor from the src tensor.
// TODO: there may be cases that can be handled more efficiently without
// going through an intermediate COO, such as cases that only change
// the overhead types.
SmallVector<Value> dynSrcSizes;
getDynamicSizes(srcTp, srcSizes, dynSrcSizes);
srcTp = getUnorderedCOOFromType(srcTp);
tmpCoo =
rewriter.create<AllocTensorOp>(loc, srcTp, dynSrcSizes).getResult();
auto foreachOp = rewriter.create<ForeachOp>(
loc, src, tmpCoo,
[&](OpBuilder &builder, Location loc, ValueRange args, Value v,
ValueRange reduc) {
// The resulting COO tensor has identity ordering.
auto t = builder.create<InsertOp>(loc, v, reduc.front(),
args.slice(0, srcTp.getRank()));
builder.create<sparse_tensor::YieldOp>(loc, t);
});
src = rewriter.create<LoadOp>(loc, foreachOp.getResult(0), true);
}
// Sort the COO tensor so that its elements are ordered via increasing
// indices for the storage ordering of the dst tensor.
SparseTensorEncodingAttr encSrc = getSparseTensorEncoding(srcTp);
auto dynShape = {ShapedType::kDynamic};
auto indTp =
MemRefType::get(dynShape, getIndexOverheadType(rewriter, encSrc));
uint64_t rank = dstTp.getRank();
// Gather the indices-arrays in the dst tensor storage order.
SmallVector<Value> xs(rank, Value());
for (uint64_t i = 0; i < rank; i++) {
uint64_t orgDim = toOrigDim(encSrc, i);
xs[toStoredDim(encDst, orgDim)] = rewriter.create<ToIndicesOp>(
loc, indTp, src, rewriter.getIndexAttr(i));
}
// Retrieve NNZ.
Value nnz = rewriter.create<NumberOfEntriesOp>(loc, src);
nnz =
rewriter.create<arith::IndexCastOp>(loc, rewriter.getIndexType(), nnz);
// Retrieve the values-array.
auto valTp = MemRefType::get(dynShape, srcTp.getElementType());
Value y = rewriter.create<ToValuesOp>(loc, valTp, src);
// Sort the COO tensor.
rewriter.create<SortOp>(loc, nnz, xs, ValueRange{y});
// For each element in the COO tensor, insert the element to the dst tensor.
SmallVector<Value> dynDstSizes;
getDynamicSizes(dstTp, srcSizes, dynDstSizes);
Value dst =
rewriter.create<AllocTensorOp>(loc, dstTp, dynDstSizes).getResult();
SmallVector<Value> indices(srcTp.getRank(), Value());
auto foreachOp = rewriter.create<ForeachOp>(
loc, src, dst,
[&](OpBuilder &builder, Location loc, ValueRange args, Value v,
ValueRange reduc) {
for (int64_t i = 0, e = srcTp.getRank(); i < e; i++) {
uint64_t dim = toStoredDim(encDst, i);
indices[dim] = args[i];
}
auto t = builder.create<InsertOp>(loc, v, reduc.front(), indices);
builder.create<sparse_tensor::YieldOp>(loc, t);
});
// Release the temporary COO if it was created. Note that tmpCoo is
// invalidated by the foreach loop; the updated tensor is held in src.
if (tmpCoo)
rewriter.create<DeallocTensorOp>(loc, src);
// Directly replacing op with dst would trigger the bufferization error
// message "sparse tensor allocation should not escape function".
// As such, we insert a trivial tensor convert, which will be removed by
// codegen.
rewriter.setInsertionPointAfter(op);
auto t = rewriter.create<LoadOp>(loc, foreachOp.getResult(0), true);
rewriter.replaceOpWithNewOp<ConvertOp>(op, dstTp, t);
return success();
}
};
/// Sparse rewriting rule for the foreach operator.
struct ForeachRewriter : public OpRewritePattern<ForeachOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(ForeachOp op,
PatternRewriter &rewriter) const override {
auto loc = op.getLoc();
Value input = op.getTensor();
SmallVector<Value> reduc = op.getInitArgs();
auto rtp = input.getType().cast<RankedTensorType>();
int64_t rank = rtp.getRank();
// Special case: a foreach over a sparse constant uses its own rewriting
// rule.
if (auto constOp = input.getDefiningOp<arith::ConstantOp>()) {
if (auto attr = constOp.getValue().dyn_cast<SparseElementsAttr>()) {
return genForeachOnSparseConstant(op, rewriter, attr);
}
}
// Otherwise, use loop emitter to generate loops.
auto enc = getSparseTensorEncoding(rtp);
// 1. Generates loop for the sparse input.
SparseTensorLoopEmitter loopEmitter(
ValueRange{input},
StringAttr::get(getContext(), ForeachOp::getOperationName()));
loopEmitter.initializeLoopEmit(rewriter, loc);
for (int64_t i = 0; i < rank; i++) {
// TODO: provide a utility function for loop sequences that only contain
// one for loop?
loopEmitter.enterNewLoopSeq(rewriter, loc, 0, static_cast<size_t>(i));
// Note that reduc will be taken care of by the loop emitter and gets
// updated in place.
loopEmitter.enterLoopOverTensorAtDim(rewriter, loc, 0, i, reduc);
}
SmallVector<Value> coords;
coords.reserve(rank);
loopEmitter.getCoordinateArray(coords);
Value vals = loopEmitter.getValBuffer()[0];
Value pidx = loopEmitter.getPidxs()[0].back();
// Loads the value from a sparse tensor using the pointer index;
// loads the value from a dense tensor using the coordinate array.
Value val = enc ? rewriter.create<memref::LoadOp>(loc, vals, pidx)
: rewriter.create<memref::LoadOp>(loc, vals, coords);
// 2. Inline the block in the foreach operator.
Block *srcBlock = op.getBody();
// Remap coordinates.
SmallVector<Value> args;
for (int64_t i = 0; i < rank; i++) {
Value actual = coords[toStoredDim(enc, i)];
args.push_back(actual);
}
// Remap value.
args.push_back(val);
// Remap reduction variables.
args.append(reduc);
// Remove sparse_tensor.yield.
SmallVector<Value> reducValue = srcBlock->getTerminator()->getOperands();
rewriter.eraseOp(srcBlock->getTerminator());
// Inline body.
if (!reducValue.empty()) {
rewriter.mergeBlocks(srcBlock, rewriter.getBlock(), args);
} else {
// This is annoying, since scf.for inserts an implicit yield op when
// there is no reduction variable upon creation; in this case we need to
// merge the block *before* the yield op.
}
for (int64_t i = 0; i < rank; i++) {
// Link the reduction chain. Note that the loop emitter updates reducValue
// in place.
loopEmitter.exitCurrentLoop(rewriter, loc, reducValue);
loopEmitter.exitCurrentLoopSeq();
}
// Replace the foreach operator with the value returned by the outermost
// for loop.
rewriter.replaceOp(op, reducValue);
return success();
}
};
/// Sparse rewriting rule for the new operator.
struct NewRewriter : public OpRewritePattern<NewOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(NewOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
auto dstTp = op.getResult().getType().template cast<RankedTensorType>();
SparseTensorEncodingAttr encDst = getSparseTensorEncoding(dstTp);
if (!encDst)
return failure();
// Create a sparse tensor reader.
Value fileName = op.getSource();
Type opaqueTp = getOpaquePointerType(rewriter);
Value reader = createFuncCall(rewriter, loc, "createSparseTensorReader",
{opaqueTp}, {fileName}, EmitCInterface::Off)
.getResult(0);
// Allocate a temporary buffer for storing dimension sizes and indices.
Type indexTp = rewriter.getIndexType();
uint64_t rank = dstTp.getRank();
Value dimSizes = genAlloca(rewriter, loc, rank, indexTp);
// If the result tensor has dynamic dimensions, get the dynamic sizes from
// the sparse tensor reader.
SmallVector<Value> dynSizesArray;
if (!dstTp.hasStaticShape()) {
createFuncCall(rewriter, loc, "copySparseTensorReaderDimSizes", {},
{reader, dimSizes}, EmitCInterface::On);
ArrayRef<int64_t> dstShape = dstTp.getShape();
for (auto &d : llvm::enumerate(dstShape)) {
if (d.value() == ShapedType::kDynamic) {
dynSizesArray.push_back(rewriter.create<memref::LoadOp>(
loc, dimSizes, constantIndex(rewriter, loc, d.index())));
}
}
}
// Implement the NewOp as follows:
// %tmp = bufferization.alloc_tensor : an unordered COO with identity
// storage ordering
// for i = 0 to nnz
// get the next element from the input file
// insert the element to %tmp
// %t = sparse_tensor.ConvertOp %tmp
RankedTensorType cooTp = getUnorderedCOOFromType(dstTp);
auto cooBuffer =
rewriter.create<AllocTensorOp>(loc, cooTp, dynSizesArray).getResult();
Value c0 = constantIndex(rewriter, loc, 0);
Value c1 = constantIndex(rewriter, loc, 1);
Value nnz = createFuncCall(rewriter, loc, "getSparseTensorReaderNNZ",
{indexTp}, {reader}, EmitCInterface::Off)
.getResult(0);
Value symmetric;
// The verifier ensures only 2D tensors can have the expandSymmetry flag.
if (rank == 2 && op.getExpandSymmetry()) {
symmetric =
createFuncCall(rewriter, loc, "getSparseTensorReaderIsSymmetric",
{rewriter.getI1Type()}, {reader}, EmitCInterface::Off)
.getResult(0);
} else {
symmetric = Value();
}
Type eltTp = dstTp.getElementType();
Value value = genAllocaScalar(rewriter, loc, eltTp);
scf::ForOp forOp = rewriter.create<scf::ForOp>(loc, c0, nnz, c1,
ArrayRef<Value>(cooBuffer));
rewriter.setInsertionPointToStart(forOp.getBody());
SmallString<29> getNextFuncName{"getSparseTensorReaderNext",
primaryTypeFunctionSuffix(eltTp)};
Value indices = dimSizes; // Reuse the dimSizes memref to store indices.
createFuncCall(rewriter, loc, getNextFuncName, {}, {reader, indices, value},
EmitCInterface::On);
SmallVector<Value> indicesArray;
for (uint64_t i = 0; i < rank; i++) {
indicesArray.push_back(rewriter.create<memref::LoadOp>(
loc, indices, constantIndex(rewriter, loc, i)));
}
Value v = rewriter.create<memref::LoadOp>(loc, value);
Value t = rewriter.create<InsertOp>(loc, v, forOp.getRegionIterArg(0),
indicesArray);
if (symmetric) {
// Only expand the symmetric counterpart for strictly off-diagonal entries.
Value offDiagonal = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::ne, indicesArray[0], indicesArray[1]);
Value cond = rewriter.create<arith::AndIOp>(loc, symmetric, offDiagonal);
scf::IfOp ifOp =
rewriter.create<scf::IfOp>(loc, t.getType(), cond, /*else*/ true);
rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());
rewriter.create<scf::YieldOp>(
loc, Value(rewriter.create<InsertOp>(
loc, v, t, ValueRange{indicesArray[1], indicesArray[0]})));
rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front());
rewriter.create<scf::YieldOp>(loc, t);
t = ifOp.getResult(0);
rewriter.setInsertionPointAfter(ifOp);
}
rewriter.create<scf::YieldOp>(loc, ArrayRef<Value>(t));
rewriter.setInsertionPointAfter(forOp);
// Link SSA chain.
cooBuffer = forOp.getResult(0);
// Release the sparse tensor reader.
createFuncCall(rewriter, loc, "delSparseTensorReader", {}, {reader},
EmitCInterface::Off);
cooBuffer = rewriter.create<LoadOp>(loc, cooBuffer, true);
Value newOp = rewriter.replaceOpWithNewOp<ConvertOp>(op, dstTp, cooBuffer);
// Release the unordered COO tensor buffer.
rewriter.setInsertionPointAfterValue(newOp);
rewriter.create<DeallocTensorOp>(loc, cooBuffer);
return success();
}
};
struct OutRewriter : public OpRewritePattern<OutOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(OutOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
// Calculate NNZ.
Value src = op.getTensor();
Value nnz = rewriter.create<NumberOfEntriesOp>(loc, src);
// Allocate a temporary buffer for storing dimension sizes and indices.
auto srcTp = src.getType().template cast<RankedTensorType>();
uint64_t rank = srcTp.getRank();
Type indexTp = rewriter.getIndexType();
Value dimSizes = genAlloca(rewriter, loc, rank, indexTp);
// Generate code to calculate dimension size values and store the values to
// the buffer.
SmallVector<Value> dims;
sizesForTensor(rewriter, dims, loc, srcTp, src);
for (uint64_t i = 0; i < rank; i++) {
rewriter.create<memref::StoreOp>(loc, dims[i], dimSizes,
constantIndex(rewriter, loc, i));
}
// Create a sparse tensor writer and output metadata.
Type opaqueTp = getOpaquePointerType(rewriter);
Value writer =
createFuncCall(rewriter, loc, "createSparseTensorWriter", {opaqueTp},
{op.getDest()}, EmitCInterface::Off)
.getResult(0);
Value rankValue = constantIndex(rewriter, loc, rank);
createFuncCall(rewriter, loc, "outSparseTensorWriterMetaData", {},
{writer, rankValue, nnz, dimSizes}, EmitCInterface::On);
Value indices = dimSizes; // Reuse the dimSizes buffer for indices.
Type eltTp = srcTp.getElementType();
SmallString<29> outNextFuncName{"outSparseTensorWriterNext",
primaryTypeFunctionSuffix(eltTp)};
Value value = genAllocaScalar(rewriter, loc, eltTp);
ModuleOp module = op->getParentOfType<ModuleOp>();
// For each element in the source tensor, output the element.
rewriter.create<ForeachOp>(
loc, src, llvm::None,
[&](OpBuilder &builder, Location loc, ValueRange args, Value v,
ValueRange reduc) {
for (uint64_t i = 0; i < rank; i++) {
rewriter.create<memref::StoreOp>(loc, args[i], indices,
constantIndex(builder, loc, i));
}
rewriter.create<memref::StoreOp>(loc, v, value);
SmallVector<Value> operands{writer, rankValue, indices, value};
FlatSymbolRefAttr fn = getFunc(module, outNextFuncName, {}, operands,
EmitCInterface::On);
builder.create<func::CallOp>(loc, TypeRange(), fn, operands);
builder.create<sparse_tensor::YieldOp>(loc);
});
// Release the writer.
createFuncCall(rewriter, loc, "delSparseTensorWriter", {}, {writer},
EmitCInterface::Off);
rewriter.eraseOp(op);
return success();
}
};
} // namespace
//===---------------------------------------------------------------------===//
// Methods that add patterns described in this file to a pattern list.
//===---------------------------------------------------------------------===//
void mlir::populatePreSparsificationRewriting(RewritePatternSet &patterns) {
patterns.add<FoldInvariantYield, FuseSparseMultiplyOverAdd>(
patterns.getContext());
}
void mlir::populatePostSparsificationRewriting(RewritePatternSet &patterns,
bool enableRT,
bool enableForeach,
bool enableConvert) {
patterns.add<ReshapeRewriter<tensor::ExpandShapeOp>,
ReshapeRewriter<tensor::CollapseShapeOp>>(patterns.getContext());
if (enableForeach)
patterns.add<ForeachRewriter>(patterns.getContext());
// TODO: If RT is not enabled, rewrite concatenate ops, etc., here.
if (!enableRT) {
patterns.add<ConcatenateRewriter, NewRewriter, OutRewriter,
Sparse2SparseReshapeRewriter<tensor::ExpandShapeOp>,
Sparse2SparseReshapeRewriter<tensor::CollapseShapeOp>>(
patterns.getContext());
if (enableConvert)
patterns.add<ConvertRewriter>(patterns.getContext());
}
}