//===- SparseTensorRewriting.cpp - Sparse tensor rewriting rules ----------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // This file implements rewriting rules that are specific to sparse tensors. // //===----------------------------------------------------------------------===// #include "CodegenUtils.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Bufferization/IR/Bufferization.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" #include "mlir/Dialect/SparseTensor/Transforms/Passes.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/AffineMap.h" #include "mlir/IR/Matchers.h" #include "mlir/Support/LLVM.h" using namespace mlir; using namespace mlir::bufferization; using namespace mlir::linalg; using namespace mlir::sparse_tensor; //===---------------------------------------------------------------------===// // Helper methods for the actual rewriting rules. //===---------------------------------------------------------------------===// // Helper method to match any typed zero. static bool isZeroValue(Value val) { return matchPattern(val, m_Zero()) || matchPattern(val, m_AnyZeroFloat()); } // Helper to detect a sparse tensor type operand. static bool isSparseTensor(OpOperand *op) { if (auto enc = getSparseTensorEncoding(op->get().getType())) { if (llvm::is_contained(enc.getDimLevelType(), DimLevelType::Compressed)) return true; } return false; } // Helper method to find zero/uninitialized allocation. static bool isAlloc(OpOperand *op, bool isZero) { Value val = op->get(); if (auto alloc = val.getDefiningOp()) { Value copy = alloc.getCopy(); if (isZero) return copy && isZeroValue(copy); return !copy; } return false; } // Helper to detect sampling operation. static bool isSampling(GenericOp op) { auto yieldOp = cast(op.getRegion().front().getTerminator()); if (auto *def = yieldOp.getOperand(0).getDefiningOp()) { if (isa(def) || isa(def)) { // Both scalar input arguments used exactly once. Value s1 = op.getBlock()->getArgument(0); Value s2 = op.getBlock()->getArgument(1); return (def->getOperand(0) == s1 && def->getOperand(1) == s2) || (def->getOperand(1) == s1 && def->getOperand(0) == s2); } } return false; } // Helper to detect chain of multiplications that do not involve x. static bool isMulChain(Value val, Value x) { if (auto arg = val.dyn_cast()) return arg != x; if (auto *def = val.getDefiningOp()) { if (isa(def) || isa(def)) return isMulChain(def->getOperand(0), x) && isMulChain(def->getOperand(1), x); } return false; } // Helper to detect x = x + . static bool isSumOfMul(GenericOp op) { auto yieldOp = cast(op.getRegion().front().getTerminator()); if (auto *def = yieldOp.getOperand(0).getDefiningOp()) { if (isa(def) || isa(def)) { Value x = op.getBlock()->getArguments().back(); return (def->getOperand(0) == x && isMulChain(def->getOperand(1), x)) || (def->getOperand(1) == x && isMulChain(def->getOperand(0), x)); } } return false; } // Helper to detect direct yield of a zero value. static bool isZeroYield(GenericOp op) { auto yieldOp = cast(op.getRegion().front().getTerminator()); if (auto arg = yieldOp.getOperand(0).dyn_cast()) { if (arg.getOwner()->getParentOp() == op) { return isZeroValue(op->getOperand(arg.getArgNumber())); } } return isZeroValue(yieldOp.getOperand(0)); } /// Populates given sizes array from type (for static sizes) and from /// the tensor (for dynamic sizes). static void sizesForTensor(OpBuilder &builder, SmallVector &sizes, Location loc, ShapedType stp, Value tensor) { for (const auto &d : enumerate(stp.getShape())) { Value dim; if (d.value() == ShapedType::kDynamicSize) dim = builder.create(loc, tensor, d.index()); else dim = constantIndex(builder, loc, d.value()); sizes.push_back(dim); } } // TODO: The dim level property of the COO type relies on input tensors, the // shape relies on the output tensor // Helpers to setup a COO type. static RankedTensorType getUnorderedCOOFromType(RankedTensorType src) { auto *ctx = src.getContext(); auto rank = src.getRank(); SmallVector dims; // An unordered and non-unique compressed dim at beginning. dims.push_back(DimLevelType::CompressedNuNo); if (rank > 1) { // TODO: it is actually ordered at the level for ordered input. // Followed by unordered non-unique n-2 singleton levels. std::fill_n(std::back_inserter(dims), rank - 2, DimLevelType::SingletonNuNo); // TODO: only if all the inputs (for concatentate) are unique at the last // level should the COO has a unique level at the end. Ends by a unordered // unique singleton level unless the tensor rank is 1. dims.push_back(DimLevelType::SingletonNo); } SparseTensorEncodingAttr encSrc = getSparseTensorEncoding(src); // TODO: Maybe pick the bitwidth based on input/output tensors (probably the // largest one among them) in the original operation instead of using the // default value. auto enc = SparseTensorEncodingAttr::get( ctx, dims, AffineMap::getMultiDimIdentityMap(rank, ctx), AffineMap(), encSrc.getPointerBitWidth(), encSrc.getIndexBitWidth()); return RankedTensorType::get(src.getShape(), src.getElementType(), enc); } //===---------------------------------------------------------------------===// // The actual sparse tensor rewriting rules. //===---------------------------------------------------------------------===// namespace { /// Rewriting rule that converts direct yield of zero with initial allocation. struct FoldInvariantYield : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(GenericOp op, PatternRewriter &rewriter) const override { if (!op.hasTensorSemantics() || op.getNumResults() != 1 || !isAlloc(op.getOutputOperand(0), /*isZero=*/false) || !isZeroYield(op)) return failure(); auto outputType = op.getResult(0).getType().cast(); // Yielding zero on newly allocated (all-zero) sparse tensors can be // optimized out directly (regardless of dynamic or static size). if (getSparseTensorEncoding(outputType)) { rewriter.replaceOp(op, op.getOutputOperand(0)->get()); return success(); } // Incorporate zero value into allocation copy. if (!outputType.hasStaticShape()) return failure(); Value zero = constantZero(rewriter, op.getLoc(), op.getResult(0).getType()); AllocTensorOp a = op.getOutputOperand(0)->get().getDefiningOp(); rewriter.updateRootInPlace(a, [&]() { a.getCopyMutable().assign(zero); }); rewriter.replaceOp(op, op.getOutputOperand(0)->get()); return success(); } }; /// Rewriting rule that converts two kernels: /// /// T(i,j) = SUM(k, A(i,j,k) * B(i,j,k) * ... ) /// X(i,j) = S(i,j) * T(i,j) /// /// into a single kernel, using distributive law: /// /// X(i,j) = SUM(k, S(i,j) * A(i,j,k) * B(i,j,k) * ... ) /// /// This kind of fusion (merging two ops into one but using arithmetic /// equalities that may not hold for floating-point computations) would /// be undesirable in the dense case, since we distribute the multiplication /// into the reduction loop. However, for sparse sampling tensor S, such /// a fusion may actually reduce the asymptotic complexity of the kernel, /// since intermediate results may be nullified. struct FuseSparseMultiplyOverAdd : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(GenericOp op, PatternRewriter &rewriter) const override { // Check consumer. if (!op.hasTensorSemantics() || op.getNumInputs() != 2 || op.getNumResults() != 1 || op.getNumParallelLoops() != op.getNumLoops() || !op.getMatchingIndexingMap(op.getOutputOperand(0)).isIdentity() || !op.getMatchingIndexingMap(op.getInputOperand(0)).isIdentity() || !op.getMatchingIndexingMap(op.getInputOperand(1)).isIdentity()) return failure(); // Find consuming OP2(sparse, other) or OP2(other, sparse). The other // operand can be sparse or dense, since the point of this rewriting rule // is detecting a situation in which *more* sparsity is introduced into // a computation, be it already sparse or still dense. unsigned other = 0; if (isSparseTensor(op.getInputOperand(0))) other = 1; else if (!isSparseTensor(op.getInputOperand(1))) return failure(); // Check producer. auto prod = dyn_cast_or_null( op.getInputOperand(other)->get().getDefiningOp()); if (!prod || !prod.hasTensorSemantics() || prod.getNumResults() != 1 || !prod.getResult(0).hasOneUse()) return failure(); // Sampling consumer and sum of multiplication chain producer. if (!isAlloc(op.getOutputOperand(0), /*isZero=*/false) || !isAlloc(prod.getOutputOperand(0), /*isZero=*/true) || !isSampling(op) || !isSumOfMul(prod)) return failure(); // Modify operand structure of producer and consumer. Location loc = prod.getLoc(); SmallVector inputOps = prod.getInputs(); SmallVector outputOps = op.getOutputs(); SmallVector fusedIndexMaps = prod.getIndexingMapsArray(); inputOps.push_back(op.getInputOperand(1 - other)->get()); fusedIndexMaps.push_back(fusedIndexMaps.back()); // mimic other // Fuse producer and consumer into a new generic op. auto fusedOp = rewriter.create( loc, op.getResult(0).getType(), inputOps, outputOps, rewriter.getAffineMapArrayAttr(fusedIndexMaps), prod.getIteratorTypes(), /*doc=*/nullptr, /*library_call=*/nullptr); Block &prodBlock = prod.getRegion().front(); Block &consBlock = op.getRegion().front(); BlockAndValueMapping mapper; Block *fusedBlock = new Block(); fusedOp.getRegion().push_back(fusedBlock); unsigned num = prodBlock.getNumArguments(); for (unsigned i = 0; i < num - 1; i++) addArg(mapper, fusedBlock, prodBlock.getArgument(i)); addArg(mapper, fusedBlock, consBlock.getArgument(1 - other)); addArg(mapper, fusedBlock, prodBlock.getArgument(num - 1)); // Clone bodies of the producer and consumer in new evaluation order. auto *acc = prodBlock.getTerminator()->getOperand(0).getDefiningOp(); auto *sampler = consBlock.getTerminator()->getOperand(0).getDefiningOp(); rewriter.setInsertionPointToStart(fusedBlock); Value last; for (auto &op : prodBlock.without_terminator()) if (&op != acc) { last = op.getResult(0); rewriter.clone(op, mapper); } mapper.map(consBlock.getArgument(other), fusedBlock->back().getResult(0)); mapper.map(last, rewriter.clone(*sampler, mapper)->getResult(0)); last = rewriter.clone(*acc, mapper)->getResult(0); rewriter.create(loc, last); // Force initial value on merged allocation for dense outputs. if (!getSparseTensorEncoding(op.getResult(0).getType())) { Value init = prod.getOutputOperand(0) ->get() .getDefiningOp() .getCopy(); AllocTensorOp a = op.getOutputOperand(0)->get().getDefiningOp(); rewriter.updateRootInPlace(a, [&]() { a.getCopyMutable().assign(init); }); } // Replace consumer with fused operation. Old producer // and consumer ops will be removed by DCE. rewriter.replaceOp(op, fusedOp->getResults()); return success(); } private: // Helper to add argument and record the mapping. static void addArg(BlockAndValueMapping &mapper, Block *b, BlockArgument a) { mapper.map(a, b->addArgument(a.getType(), a.getLoc())); } }; /// Sparse rewriting rule for sparse-to-sparse reshape operator. template struct Sparse2SparseReshapeRewriter : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(ReshapeOp op, PatternRewriter &rewriter) const override { Location loc = op.getLoc(); Value srcTensor = op.getSrc(); auto srcTp = srcTensor.getType().template cast(); auto dstTp = op.getResult().getType().template cast(); SparseTensorEncodingAttr encSrc = getSparseTensorEncoding(srcTp); SparseTensorEncodingAttr encDst = getSparseTensorEncoding(dstTp); if (!encDst || !encSrc) { return failure(); } // Generate code to represent the static dimension constants or compute // the dynamic dimension values. SmallVector srcSizes; sizesForTensor(rewriter, srcSizes, loc, srcTp, srcTensor); SmallVector dstSizes; SmallVector dstDynSizes; if (dstTp.hasStaticShape()) { for (auto d : dstTp.getShape()) dstSizes.push_back(constantIndex(rewriter, loc, d)); } else { ArrayRef dstShape = dstTp.getShape(); genReshapeDstShape(loc, rewriter, dstSizes, srcSizes, dstShape, op.getReassociationIndices()); for (auto &d : llvm::enumerate(dstShape)) { if (d.value() == ShapedType::kDynamicSize) dstDynSizes.push_back(dstSizes[d.index()]); } } // Implement the sparse2sparse reshape as follows: // %tmp = bufferization.alloc_tensor : unordered COO // foreach srcCoords %srcTensor // insert translateIndicesArray(srcCoords), %tmp // %t = sparse_tensor.cast %tmp RankedTensorType cooTp = getUnorderedCOOFromType(dstTp); auto cooBuffer = rewriter.create(loc, cooTp, dstDynSizes).getResult(); rewriter.create( loc, srcTensor, [&](OpBuilder &builder, Location loc, ValueRange args) { SmallVector srcIndices; SmallVector dstIndices; for (int64_t i = 0, e = srcTp.getRank(); i < e; i++) { uint64_t dim = toStoredDim(encSrc, i); srcIndices.push_back(args[dim]); } translateIndicesArray(builder, loc, op.getReassociationIndices(), srcIndices, srcSizes, dstSizes, dstIndices); builder.create(loc, args.back(), cooBuffer, dstIndices); builder.create(loc); }); rewriter.replaceOpWithNewOp(op, dstTp, cooBuffer); return success(); } }; /// Sparse rewriting rule for sparse-to-dense and dense-to-sparse reshape /// operator. template struct ReshapeRewriter : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(ReshapeOp op, PatternRewriter &rewriter) const override { Location loc = op->getLoc(); auto encDst = getSparseTensorEncoding(op.getResult().getType()); auto encSrc = getSparseTensorEncoding(op.getSrc().getType()); // Since a pure dense expansion is very cheap (change of view), for // a sparse2dense or dense2sparse, we can simply unfuse a sparse // conversion from the reshape operation itself. // All other cases are handled elsewhere. if (encDst && encSrc) { return failure(); } if (encSrc) { RankedTensorType rtp = op.getSrc().getType().template cast(); auto denseTp = RankedTensorType::get(rtp.getShape(), rtp.getElementType()); auto convert = rewriter.create(loc, denseTp, op.getSrc()); op->setOperand(0, convert); return success(); } if (encDst) { RankedTensorType rtp = op.getResult().getType().template cast(); auto denseTp = RankedTensorType::get(rtp.getShape(), rtp.getElementType()); auto reshape = rewriter.create(loc, denseTp, op.getSrc(), op.getReassociation()); Value convert = rewriter.create(loc, rtp, reshape); rewriter.replaceOp(op, convert); return success(); } return failure(); } }; struct ConcatenateRewriter : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(ConcatenateOp op, PatternRewriter &rewriter) const override { auto loc = op.getLoc(); auto rtp = op.getType().cast(); // TODO: Build the output shape if needed. assert(rtp.hasStaticShape()); auto rank = rtp.getRank(); size_t conDim = op.getDimension().getZExtValue(); // %t = concatenate %s1, %s2, %s3 {dim = 1} // ==> // %tmp = bufferization.alloc_tensor : unordered COO // foreach in %s1 : insert d0, d1, %tmp // foreach in %s2 : insert d0, d1 + size(s1), %tmp // foreach in %s3 : insert d0, d1 + size(s1) + size(s2), %tmp // %t = sparse_tensor.cast %tmp auto cooTp = getUnorderedCOOFromType(rtp); auto cooBuffer = rewriter.create(loc, cooTp, ValueRange()).getResult(); Value offset = constantIndex(rewriter, loc, 0); for (Value input : op.getInputs()) { // Builds the indexing map. // Build a for op for each input tensor to append new values into the // output tensor. rewriter.create( loc, input, [&](OpBuilder &builder, Location loc, ValueRange args) { SmallVector indices; for (int64_t i = 0; i < rank; i++) { uint64_t dim = toStoredDim(getSparseTensorEncoding(input.getType()), i); Value idx = args[dim]; if (i == static_cast(conDim)) // transform coordinates on matching dim idx = builder.create(loc, idx, offset); indices.push_back(idx); } builder.create(loc, args.back(), cooBuffer, indices); builder.create(loc); }); // Accumulates the offset. Note that only static-shaped inputs are allowed // by concatenate op verifier, which saves us from computing the offset // dynamically. auto d = input.getType().cast().getShape()[conDim]; assert(!ShapedType::isDynamic(d)); offset = rewriter.create(loc, offset, constantIndex(rewriter, loc, d)); } rewriter.replaceOpWithNewOp(op, rtp, cooBuffer); return success(); } }; /// Sparse rewriting rule for the foreach operator. struct ForeachRewriter : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(ForeachOp op, PatternRewriter &rewriter) const override { auto loc = op.getLoc(); Value input = op.getTensor(); auto rtp = input.getType().cast(); int64_t rank = rtp.getRank(); auto enc = getSparseTensorEncoding(rtp); // 1. Generates loop for the sparse input. SparseTensorLoopEmitter loopEmitter(ValueRange{input}); loopEmitter.initializeLoopEmit(rewriter, loc); for (int64_t i = 0; i < rank; i++) loopEmitter.enterLoopOverTensorAtDim(rewriter, loc, 0, i); SmallVector coords; coords.reserve(rank); loopEmitter.getCoordinateArray(coords); Value vals = loopEmitter.getTensorValueBuffer(0); Value pidx = loopEmitter.getLastLevelTensorPointerIndex(0); // Loads the value from sparse tensor using pointer index; // loads the value from dense tensor using coordinate array. Value val = enc ? rewriter.create(loc, vals, pidx) : rewriter.create(loc, vals, coords); for (int64_t i = 0; i < rank; i++) loopEmitter.exitCurrentLoop(); // 2. Inline the block in the foreach operator. Block::iterator inlinePos = rewriter.getInsertionPoint(); Block *srcBlock = op.getBody(); // Remove sparse_tensor.yield. rewriter.eraseOp(srcBlock->getTerminator()); SmallVector args; // Remap coordinates. for (int64_t i = 0; i < rank; i++) { Value actual = coords[toOrigDim(enc, i)]; args.push_back(actual); } // Remap value. args.push_back(val); // Inline body. rewriter.mergeBlockBefore(srcBlock, &*inlinePos, args); // delete the foreach operator. rewriter.eraseOp(op); return success(); } }; /// Sparse rewriting rule for the new operator. struct NewRewriter : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(NewOp op, PatternRewriter &rewriter) const override { Location loc = op.getLoc(); auto dstTp = op.getResult().getType().template cast(); SparseTensorEncodingAttr encDst = getSparseTensorEncoding(dstTp); if (!encDst) { return failure(); } // Create a sparse tensor reader. Value fileName = op.getSource(); Type opaqueTp = getOpaquePointerType(rewriter); Value reader = createFuncCall(rewriter, loc, "createSparseTensorReader", {opaqueTp}, {fileName}, EmitCInterface::Off) .getResult(0); // Allocate a buffer for storing dimension sizes and indices. Type indexTp = rewriter.getIndexType(); auto memTp = MemRefType::get({ShapedType::kDynamicSize}, indexTp); uint64_t rank = dstTp.getRank(); Value dimSizes = rewriter.create( loc, memTp, ValueRange{constantIndex(rewriter, loc, rank)}); // If the result tensor has dynamic dimensions, get the dynamic sizes from // the sparse tensor reader. SmallVector dynSizesArray; if (!dstTp.hasStaticShape()) { createFuncCall(rewriter, loc, "getSparseTensorReaderDimSizes", {}, {reader, dimSizes}, EmitCInterface::On) .getResult(0); ArrayRef dstShape = dstTp.getShape(); for (auto &d : llvm::enumerate(dstShape)) { if (d.value() == ShapedType::kDynamicSize) { dynSizesArray.push_back(rewriter.create( loc, dimSizes, constantIndex(rewriter, loc, d.index()))); } } } // Implement the NewOp as follows: // %tmp = bufferization.alloc_tensor : an unordered COO with identity // storage ordering // for i = 0 to nnz // get the next element from the input file // insert the element to %tmp // %t = sparse_tensor.ConvertOp %tmp RankedTensorType cooTp = getUnorderedCOOFromType(dstTp); auto cooBuffer = rewriter.create(loc, cooTp, dynSizesArray).getResult(); Value c0 = constantIndex(rewriter, loc, 0); Value c1 = constantIndex(rewriter, loc, 1); Value nnz = createFuncCall(rewriter, loc, "getSparseTensorReaderNNZ", {indexTp}, {reader}, EmitCInterface::Off) .getResult(0); scf::ForOp forOp = rewriter.create(loc, c0, nnz, c1); rewriter.setInsertionPointToStart(forOp.getBody()); Type eltTp = dstTp.getElementType(); SmallString<18> getNextFuncName{"getSparseTensorReaderNext", primaryTypeFunctionSuffix(eltTp)}; Value indices = dimSizes; // Reuse the indices memref to store indices. Value value = createFuncCall(rewriter, loc, getNextFuncName, {eltTp}, {reader, indices}, EmitCInterface::On) .getResult(0); SmallVector indicesArray; for (uint64_t i = 0; i < rank; i++) { indicesArray.push_back(rewriter.create( loc, indices, constantIndex(rewriter, loc, i))); } rewriter.create(loc, value, cooBuffer, indicesArray); rewriter.setInsertionPointAfter(forOp); // Release the indices buffer and the sparse tensor reader. rewriter.create(loc, indices); createFuncCall(rewriter, loc, "delSparseTensorReader", {}, {reader}, EmitCInterface::Off); Value newOp = rewriter.replaceOpWithNewOp(op, dstTp, cooBuffer); // Release the unordered COO tensor buffer. rewriter.setInsertionPointAfterValue(newOp); rewriter.create(loc, cooBuffer); return success(); } }; } // namespace //===---------------------------------------------------------------------===// // Methods that add patterns described in this file to a pattern list. //===---------------------------------------------------------------------===// void mlir::populateSparseTensorRewriting(RewritePatternSet &patterns, bool enableRT, bool enableForeach, bool /*enableConvert*/) { patterns.add, ReshapeRewriter>(patterns.getContext()); if (enableForeach) patterns.add(patterns.getContext()); // TODO: If RT not enabled, rewrite concatenate ops, etc here. if (!enableRT) patterns.add, Sparse2SparseReshapeRewriter>( patterns.getContext()); }