198 lines
7.2 KiB
C++
198 lines
7.2 KiB
C++
|
|
#include "Utils/CodegenUtils.h"
|
|
#include "Utils/SparseTensorIterator.h"
|
|
|
|
#include "mlir/Dialect/SCF/IR/SCF.h"
|
|
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
|
|
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
|
|
#include "mlir/Transforms/OneToNTypeConversion.h"
|
|
|
|
using namespace mlir;
|
|
using namespace mlir::sparse_tensor;
|
|
|
|
/// Appends to `fields` the types that level `lvl` of a tensor with encoding
/// `enc` is flattened into:
///   - the position memref type, if the level type carries positions;
///   - the coordinate memref type, if the level type carries coordinates;
///   - one `index` for the level's shape bound (result from lvlOp).
///
/// This is a file-local helper (used by convertIterSpaceType), so it is given
/// internal linkage instead of exporting a global symbol.
static void convertLevelType(SparseTensorEncodingAttr enc, Level lvl,
                             SmallVectorImpl<Type> &fields) {
  auto lt = enc.getLvlType(lvl);
  // Position and coordinate buffer in the sparse structure.
  if (lt.isWithPosLT())
    fields.push_back(enc.getPosMemRefType());
  if (lt.isWithCrdLT())
    fields.push_back(enc.getCrdMemRefType());
  // One index for shape bound (result from lvlOp).
  fields.push_back(IndexType::get(enc.getContext()));
}
|
|
|
|
static std::optional<LogicalResult>
|
|
convertIterSpaceType(IterSpaceType itSp, SmallVectorImpl<Type> &fields) {
|
|
|
|
auto idxTp = IndexType::get(itSp.getContext());
|
|
for (Level l = itSp.getLoLvl(); l < itSp.getHiLvl(); l++)
|
|
convertLevelType(itSp.getEncoding(), l, fields);
|
|
|
|
// Two indices for lower and upper bound (we only need one pair for the last
|
|
// iteration space).
|
|
fields.append({idxTp, idxTp});
|
|
return success();
|
|
}
|
|
|
|
static std::optional<LogicalResult>
|
|
convertIteratorType(IteratorType itTp, SmallVectorImpl<Type> &fields) {
|
|
// The actually Iterator Values (that are updated every iteration).
|
|
auto idxTp = IndexType::get(itTp.getContext());
|
|
// TODO: handle batch dimension.
|
|
assert(itTp.getEncoding().getBatchLvlRank() == 0);
|
|
if (!itTp.isUnique()) {
|
|
// Segment high for non-unique iterator.
|
|
fields.push_back(idxTp);
|
|
}
|
|
fields.push_back(idxTp);
|
|
return success();
|
|
}
|
|
|
|
namespace {
|
|
|
|
/// Sparse codegen rule for number of entries operator.
|
|
class ExtractIterSpaceConverter
|
|
: public OneToNOpConversionPattern<ExtractIterSpaceOp> {
|
|
public:
|
|
using OneToNOpConversionPattern::OneToNOpConversionPattern;
|
|
LogicalResult
|
|
matchAndRewrite(ExtractIterSpaceOp op, OpAdaptor adaptor,
|
|
OneToNPatternRewriter &rewriter) const override {
|
|
Location loc = op.getLoc();
|
|
const OneToNTypeMapping &resultMapping = adaptor.getResultMapping();
|
|
|
|
// Construct the iteration space.
|
|
SparseIterationSpace space(loc, rewriter, op.getTensor(), 0,
|
|
op.getLvlRange(), adaptor.getParentIter());
|
|
|
|
SmallVector<Value> result = space.toValues();
|
|
rewriter.replaceOp(op, result, resultMapping);
|
|
return success();
|
|
}
|
|
};
|
|
|
|
/// Lowers IterateOp to SCF loops.
///
/// The iteration space operand is rebuilt from its converted values, and an
/// iterator is extracted from it. The lowering then depends on the iterator:
///   - If it reports `iteratableByFor()`, the op becomes an `scf.for` over the
///     bounds returned by `genForCond`, with unit step.
///   - Otherwise it becomes an `scf.while` whose loop-carried values are the
///     iterator cursor followed by the flattened (converted) init args.
///
/// In both cases the original body's block signature is converted, the body
/// is moved into the new loop, and the `sparse_tensor.yield` terminator is
/// replaced by an `scf.yield`.
///
/// Ops with a non-empty coordinate-use list are not handled yet.
class SparseIterateOpConverter : public OneToNOpConversionPattern<IterateOp> {
public:
  using OneToNOpConversionPattern::OneToNOpConversionPattern;

  LogicalResult
  matchAndRewrite(IterateOp op, OpAdaptor adaptor,
                  OneToNPatternRewriter &rewriter) const override {
    if (!op.getCrdUsedLvls().empty())
      return rewriter.notifyMatchFailure(
          op, "non-empty coordinates list not implemented.");

    // Compute the converted body signature before creating any IR. A pattern
    // that returns failure() must leave the IR unchanged, so this fallible
    // step has to precede the creation of loops, constants and bound values.
    Block *loopBody = op.getBody();
    OneToNTypeMapping bodyTypeMapping(loopBody->getArgumentTypes());
    if (failed(typeConverter->convertSignatureArgs(
            loopBody->getArgumentTypes(), bodyTypeMapping)))
      return failure();

    Location loc = op.getLoc();
    auto iterSpace = SparseIterationSpace::fromValues(
        op.getIterSpace().getType(), adaptor.getIterSpace(), 0);
    std::unique_ptr<SparseIterator> it =
        iterSpace.extractIterator(rewriter, loc);

    if (it->iteratableByFor())
      rewriteToFor(op, adaptor, rewriter, *it, bodyTypeMapping);
    else
      rewriteToWhile(op, adaptor, rewriter, *it, bodyTypeMapping);
    return success();
  }

private:
  /// Replaces `op` with an `scf.for` from `lo` to `hi` (from `it.genForCond`)
  /// with step 1. The loop carries the flattened init args.
  /// NOTE(review): this relies on the converted iterator block argument lining
  /// up with the scf.for induction variable. Confirm this for every iterator
  /// kind that reports iteratableByFor().
  static void rewriteToFor(IterateOp op, OpAdaptor &adaptor,
                           OneToNPatternRewriter &rewriter, SparseIterator &it,
                           OneToNTypeMapping &bodyTypeMapping) {
    Location loc = op.getLoc();
    auto [lo, hi] = it.genForCond(rewriter, loc);
    Value step = constantIndex(rewriter, loc, 1);
    SmallVector<Value> ivs;
    for (ValueRange inits : adaptor.getInitArgs())
      llvm::append_range(ivs, inits);
    scf::ForOp forOp = rewriter.create<scf::ForOp>(loc, lo, hi, step, ivs);

    rewriter.applySignatureConversion(op.getBody(), bodyTypeMapping);

    // Replace the scf.for's default body with the converted body of `op`.
    rewriter.eraseBlock(forOp.getBody());
    Region &dstRegion = forOp.getRegion();
    rewriter.inlineRegionBefore(op.getRegion(), dstRegion, dstRegion.end());

    auto yieldOp =
        llvm::cast<sparse_tensor::YieldOp>(forOp.getBody()->getTerminator());

    rewriter.setInsertionPointToEnd(forOp.getBody());
    // replace sparse_tensor.yield with scf.yield.
    rewriter.create<scf::YieldOp>(loc, yieldOp.getResults());
    rewriter.eraseOp(yieldOp);

    rewriter.replaceOp(op, forOp.getResults(), adaptor.getResultMapping());
  }

  /// Replaces `op` with an `scf.while`. The loop-carried values are the
  /// iterator cursor followed by the flattened init args:
  ///   - The "before" region evaluates `it.genWhileCond`.
  ///   - The "after" region holds the converted body and advances the
  ///     iterator with `it.forward`.
  /// The cursor results are dropped when `op` is replaced.
  static void rewriteToWhile(IterateOp op, OpAdaptor &adaptor,
                             OneToNPatternRewriter &rewriter,
                             SparseIterator &it,
                             OneToNTypeMapping &bodyTypeMapping) {
    Location loc = op.getLoc();

    SmallVector<Value> ivs;
    llvm::append_range(ivs, it.getCursor());
    const size_t numCursorVals = ivs.size();
    for (ValueRange inits : adaptor.getInitArgs())
      llvm::append_range(ivs, inits);

    assert(llvm::all_of(ivs, [](Value v) { return v != nullptr; }));

    TypeRange types = ValueRange(ivs).getTypes();
    auto whileOp = rewriter.create<scf::WhileOp>(loc, types, ivs);
    SmallVector<Location> l(types.size(), op.getIterator().getLoc());

    // Generates loop conditions.
    Block *before = rewriter.createBlock(&whileOp.getBefore(), {}, types, l);
    rewriter.setInsertionPointToStart(before);
    ValueRange bArgs = before->getArguments();
    auto [whileCond, remArgs] = it.genWhileCond(rewriter, loc, bArgs);
    // Everything after the cursor is the flattened init args. Compare with the
    // flattened count rather than the number of original init operands, so the
    // check still holds when an init arg converts 1:N.
    assert(remArgs.size() == ivs.size() - numCursorVals);
    rewriter.create<scf::ConditionOp>(loc, whileCond, before->getArguments());

    // Generates loop body.
    rewriter.applySignatureConversion(op.getBody(), bodyTypeMapping);
    Region &dstRegion = whileOp.getAfter();
    // TODO: handle uses of coordinate!
    rewriter.inlineRegionBefore(op.getRegion(), dstRegion, dstRegion.end());
    auto yieldOp = llvm::cast<sparse_tensor::YieldOp>(
        whileOp.getAfterBody()->getTerminator());

    rewriter.setInsertionPointToEnd(whileOp.getAfterBody());

    // Rebind the iterator cursor to the after-region arguments before emitting
    // the advance. The remaining (init-arg) values it returns are not needed.
    it.linkNewScope(whileOp.getAfterArguments());
    ValueRange nx = it.forward(rewriter, loc);
    SmallVector<Value> yields;
    llvm::append_range(yields, nx);
    llvm::append_range(yields, yieldOp.getResults());

    // replace sparse_tensor.yield with scf.yield.
    rewriter.eraseOp(yieldOp);
    rewriter.create<scf::YieldOp>(loc, yields);

    rewriter.replaceOp(op, whileOp.getResults().drop_front(numCursorVals),
                       adaptor.getResultMapping());
  }
};
|
|
|
|
} // namespace
|
|
|
|
mlir::SparseIterationTypeConverter::SparseIterationTypeConverter() {
  // Registration order matters: MLIR's TypeConverter tries conversions
  // starting from the most recently added one. The identity rule therefore
  // only applies to types that the iterator / iteration-space rules decline.
  addConversion([](Type type) { return type; });
  addConversion(convertIteratorType);
  addConversion(convertIterSpaceType);

  // Reassemble a single IterSpaceType value from its flattened pieces by
  // emitting an unrealized_conversion_cast.
  addSourceMaterialization([](OpBuilder &builder, IterSpaceType spTp,
                              ValueRange inputs,
                              Location loc) -> std::optional<Value> {
    auto castOp = builder.create<UnrealizedConversionCastOp>(
        loc, TypeRange(spTp), inputs);
    return castOp.getResult(0);
  });
}
|
|
|
|
/// Registers the 1:N conversion patterns that lower ExtractIterSpaceOp and
/// IterateOp to SCF. Each pattern is constructed with `converter` and the
/// context of `patterns`.
void mlir::populateLowerSparseIterationToSCFPatterns(
    TypeConverter &converter, RewritePatternSet &patterns) {
  MLIRContext *ctx = patterns.getContext();
  patterns.add<ExtractIterSpaceConverter, SparseIterateOpConverter>(converter,
                                                                     ctx);
}
|