llvm-project/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
2022-10-31 08:19:12 -07:00

645 lines
25 KiB
C++
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

//===- SparseBufferRewriting.cpp - Sparse buffer rewriting rules ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements rewriting rules that are specific to sparse tensor
// primitives with memref operands.
//
//===----------------------------------------------------------------------===//
#include "CodegenUtils.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Support/LLVM.h"
using namespace mlir;
using namespace mlir::sparse_tensor;
//===---------------------------------------------------------------------===//
// Helper methods for the actual rewriting rules.
//===---------------------------------------------------------------------===//
// Operand positions used by the generated helper functions: the `lo` index,
// the `hi` index, and the position of the first x buffer.
static constexpr uint64_t loIdx{0};
static constexpr uint64_t hiIdx{1};
static constexpr uint64_t xStartIdx{2};

// Name prefixes of the generated helper functions. The mangled suffix (number
// of x buffers and element types) is appended when a helper is looked up.
static constexpr char kMaySwapFuncNamePrefix[] = "_sparse_may_swap_";
static constexpr char kLessThanFuncNamePrefix[] = "_sparse_less_than_";
static constexpr char kPartitionFuncNamePrefix[] = "_sparse_partition_";
static constexpr char kBinarySearchFuncNamePrefix[] = "_sparse_binary_search_";
static constexpr char kSortNonstableFuncNamePrefix[] =
    "_sparse_sort_nonstable_";
static constexpr char kSortStableFuncNamePrefix[] = "_sparse_sort_stable_";
/// Signature of the callbacks that populate the body of a freshly created
/// helper function. The callback receives the builder, the enclosing module,
/// the function whose body is to be generated, and the number of x buffers
/// (`dim`).
using FuncGeneratorType =
function_ref<void(OpBuilder &, ModuleOp, func::FuncOp, size_t)>;
/// Writes the mangled name of a sort helper function to `nameOstream`. The
/// name has the following format:
///   <namePrefix><dim>_<x type>_<y0 type>..._<yn type>
/// where <x type> is the element type of the first x buffer and the y types
/// are the element types of the operands that follow the `dim` x buffers.
static void getMangledSortHelperFuncName(llvm::raw_svector_ostream &nameOstream,
                                         StringRef namePrefix, size_t dim,
                                         ValueRange operands) {
  // Only the element type of each (memref) buffer takes part in the mangling.
  auto elementTypeOf = [](Value buffer) {
    return buffer.getType().cast<MemRefType>().getElementType();
  };

  nameOstream << namePrefix << dim << "_"
              << elementTypeOf(operands[xStartIdx]);
  ValueRange ys = operands.drop_front(xStartIdx + dim);
  for (Value y : ys)
    nameOstream << "_" << elementTypeOf(y);
}
/// Returns a symbol reference to the helper function that matches the given
/// name prefix, `dim`, and operand types. If the enclosing module does not
/// contain such a function yet, a private one is created right before
/// `insertPoint` and its body is generated by `createFunc`.
static FlatSymbolRefAttr
getMangledSortHelperFunc(OpBuilder &builder, func::FuncOp insertPoint,
                         TypeRange resultTypes, StringRef namePrefix,
                         size_t dim, ValueRange operands,
                         FuncGeneratorType createFunc) {
  SmallString<32> mangledName;
  llvm::raw_svector_ostream os(mangledName);
  getMangledSortHelperFuncName(os, namePrefix, dim, operands);

  ModuleOp module = insertPoint->getParentOfType<ModuleOp>();
  MLIRContext *context = module.getContext();
  auto symbol = SymbolRefAttr::get(context, os.str());

  // Reuse the helper when it has already been generated.
  if (module.lookupSymbol<func::FuncOp>(symbol.getAttr()))
    return symbol;

  // Otherwise declare the function and let the generator fill in its body.
  // The guard restores the caller's insertion point on return.
  OpBuilder::InsertionGuard insertionGuard(builder);
  builder.setInsertionPoint(insertPoint);
  FunctionType funcType =
      FunctionType::get(context, operands.getTypes(), resultTypes);
  auto helper =
      builder.create<func::FuncOp>(insertPoint.getLoc(), os.str(), funcType);
  helper.setPrivate();
  createFunc(builder, module, helper, dim);
  return symbol;
}
/// Generates the body of a function that exchanges the elements at index i and
/// index j in every buffer, unless both indices are the same.
//
// The generated IR corresponds to this C like algorithm:
//   if (i != j) {
//     swap(x0[i], x0[j]);
//     swap(x1[i], x1[j]);
//     ...
//     swap(xn[i], xn[j]);
//     swap(y0[i], y0[j]);
//     ...
//     swap(yn[i], yn[j]);
//   }
static void createMaySwapFunc(OpBuilder &builder, ModuleOp unused,
                              func::FuncOp func, size_t dim) {
  OpBuilder::InsertionGuard insertionGuard(builder);

  Block *body = func.addEntryBlock();
  builder.setInsertionPointToStart(body);
  Location loc = func.getLoc();
  ValueRange funcArgs = body->getArguments();
  Value lhsIdx = funcArgs[0];
  Value rhsIdx = funcArgs[1];

  // Only touch the buffers when the two indices differ.
  Value indicesDiffer = builder.create<arith::CmpIOp>(
      loc, arith::CmpIPredicate::ne, lhsIdx, rhsIdx);
  scf::IfOp swapIf =
      builder.create<scf::IfOp>(loc, indicesDiffer, /*else=*/false);

  // Then-branch: load both elements of each buffer and store them crosswise.
  builder.setInsertionPointToStart(&swapIf.getThenRegion().front());
  for (Value buffer : funcArgs.drop_front(xStartIdx)) {
    Value lhsVal = builder.create<memref::LoadOp>(loc, buffer, lhsIdx);
    Value rhsVal = builder.create<memref::LoadOp>(loc, buffer, rhsIdx);
    builder.create<memref::StoreOp>(loc, rhsVal, buffer, lhsIdx);
    builder.create<memref::StoreOp>(loc, lhsVal, buffer, rhsIdx);
  }

  builder.setInsertionPointAfter(swapIf);
  builder.create<func::ReturnOp>(loc);
}
/// Emits the if-statement that compares x[i] with x[j] and returns it. The
/// returned op yields true when x[i] < x[j].
///
/// For the last dimension the else-branch yields false. Otherwise the
/// else-branch contains a second if-statement that yields false when
/// x[j] < x[i]; on return the builder is positioned at the start of that
/// second statement's else-region, which the caller must fill with the
/// comparison of the remaining dimensions.
static scf::IfOp createLessThanCompare(OpBuilder &builder, Location loc,
                                       Value i, Value j, Value x,
                                       bool isLastDim) {
  Value falseVal = constantI1(builder, loc, false);
  Value trueVal = constantI1(builder, loc, true);
  Value xi = builder.create<memref::LoadOp>(loc, x, i);
  Value xj = builder.create<memref::LoadOp>(loc, x, j);
  Type boolType = falseVal.getType();

  Value isLess =
      builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, xi, xj);
  scf::IfOp lessIf =
      builder.create<scf::IfOp>(loc, boolType, isLess, /*else=*/true);

  // x[i] < x[j]: the answer is true.
  builder.setInsertionPointToStart(&lessIf.getThenRegion().front());
  builder.create<scf::YieldOp>(loc, trueVal);

  builder.setInsertionPointToStart(&lessIf.getElseRegion().front());
  if (isLastDim) {
    // No dimension left to break the tie: the answer is false.
    builder.create<scf::YieldOp>(loc, falseVal);
    return lessIf;
  }

  Value isGreater =
      builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, xj, xi);
  scf::IfOp greaterIf =
      builder.create<scf::IfOp>(loc, boolType, isGreater, /*else=*/true);

  // x[j] < x[i]: the answer is false.
  builder.setInsertionPointToStart(&greaterIf.getThenRegion().front());
  builder.create<scf::YieldOp>(loc, falseVal);

  // The outer else-branch forwards whatever the nested if-statement yields.
  builder.setInsertionPointAfter(greaterIf);
  builder.create<scf::YieldOp>(loc, greaterIf.getResult(0));

  // Leave the builder where the comparison of the remaining dimensions has to
  // be generated.
  builder.setInsertionPointToStart(&greaterIf.getElseRegion().front());
  return lessIf;
}
/// Creates a function to compare the xs values in index i and j for all the
/// dimensions. The function returns true iff xs[i] < xs[j].
///
/// The function arguments are (i, j, x0, ..., x<dim-1>).
//
// The generated IR corresponds to this C like algorithm:
//   if (x0[i] < x0[j])
//     return true;
//   else if (x0[j] < x0[i])
//     return false;
//   else
//     if (x1[i] < x1[j])
//       return true;
//     else if (x1[j] < x1[i])
//       and so on ...
//
// NOTE(review): assumes dim >= 1; with dim == 0 the loop below never runs and
// `topIfOp` stays null.
static void createLessThanFunc(OpBuilder &builder, ModuleOp unused,
func::FuncOp func, size_t dim) {
OpBuilder::InsertionGuard insertionGuard(builder);
Block *entryBlock = func.addEntryBlock();
builder.setInsertionPointToStart(entryBlock);
Location loc = func.getLoc();
ValueRange args = entryBlock->getArguments();
// The if-statement for the first dimension; its result is the return value.
scf::IfOp topIfOp;
for (const auto &item : llvm::enumerate(args.slice(xStartIdx, dim))) {
// For every dimension but the last, createLessThanCompare leaves the builder
// inside the innermost else-region, so the comparison for the next dimension
// is generated nested in there.
scf::IfOp ifOp =
createLessThanCompare(builder, loc, args[0], args[1], item.value(),
(item.index() == dim - 1));
if (item.index() == 0) {
topIfOp = ifOp;
} else {
// This if-statement lives in the else-region left open by the previous
// dimension; terminate that region by yielding this statement's result. The
// guard restores the insertion point set up by createLessThanCompare.
OpBuilder::InsertionGuard insertionGuard(builder);
builder.setInsertionPointAfter(ifOp);
builder.create<scf::YieldOp>(loc, ifOp.getResult(0));
}
}
builder.setInsertionPointAfter(topIfOp);
builder.create<func::ReturnOp>(loc, topIfOp.getResult(0));
}
/// Creates a function to use a binary search to find the insertion point for
/// inserting xs[hi] to the sorted values xs[lo..hi).
///
/// The function arguments are (lo, hi, x0, ..., x<dim-1>) and the function
/// returns the final value of lo.
//
// The generated IR corresponds to this C like algorithm:
//   p = hi
//   while (lo < hi)
//     mid = (lo + hi) >> 1
//     if (xs[p] < xs[mid])
//       hi = mid
//     else
//       lo = mid + 1
//   return lo;
//
static void createBinarySearchFunc(OpBuilder &builder, ModuleOp module,
func::FuncOp func, size_t dim) {
OpBuilder::InsertionGuard insertionGuard(builder);
Block *entryBlock = func.addEntryBlock();
builder.setInsertionPointToStart(entryBlock);
Location loc = func.getLoc();
ValueRange args = entryBlock->getArguments();
// p is the index of the element to insert; it stays fixed during the search.
Value p = args[hiIdx];
// The loop carries the pair (lo, hi).
SmallVector<Type, 2> types(2, p.getType());
scf::WhileOp whileOp = builder.create<scf::WhileOp>(
loc, types, SmallVector<Value, 2>{args[loIdx], args[hiIdx]});
// The before-region of the WhileOp: continue while lo < hi and forward
// (lo, hi) unchanged to the after-region.
Block *before =
builder.createBlock(&whileOp.getBefore(), {}, types, {loc, loc});
builder.setInsertionPointToEnd(before);
Value cond1 = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult,
before->getArgument(0),
before->getArgument(1));
builder.create<scf::ConditionOp>(loc, cond1, before->getArguments());
// The after-region of the WhileOp.
Block *after =
builder.createBlock(&whileOp.getAfter(), {}, types, {loc, loc});
builder.setInsertionPointToEnd(after);
Value lo = after->getArgument(0);
Value hi = after->getArgument(1);
// Compute mid = (lo + hi) >> 1.
Value c1 = constantIndex(builder, loc, 1);
Value mid = builder.create<arith::ShRUIOp>(
loc, builder.create<arith::AddIOp>(loc, lo, hi), c1);
Value midp1 = builder.create<arith::AddIOp>(loc, mid, c1);
// Compare xs[p] < xs[mid], using the less-than helper for these x buffers
// (created on first use).
SmallVector<Value, 6> compareOperands{p, mid};
compareOperands.append(args.begin() + xStartIdx,
args.begin() + xStartIdx + dim);
Type i1Type = IntegerType::get(module.getContext(), 1, IntegerType::Signless);
FlatSymbolRefAttr lessThanFunc =
getMangledSortHelperFunc(builder, func, {i1Type}, kLessThanFuncNamePrefix,
dim, compareOperands, createLessThanFunc);
Value cond2 = builder
.create<func::CallOp>(loc, lessThanFunc, TypeRange{i1Type},
compareOperands)
.getResult(0);
// Update lo and hi for the WhileOp as follows:
//   if (xs[p] < xs[mid])
//     hi = mid;
//   else
//     lo = mid + 1;
Value newLo = builder.create<arith::SelectOp>(loc, cond2, lo, midp1);
Value newHi = builder.create<arith::SelectOp>(loc, cond2, mid, hi);
builder.create<scf::YieldOp>(loc, ValueRange{newLo, newHi});
// The first loop result is the final lo, i.e. the insertion point.
builder.setInsertionPointAfter(whileOp);
builder.create<func::ReturnOp>(loc, whileOp.getResult(0));
}
/// Creates a function to perform quick sort partition on the values in the
/// range of index [lo, hi), assuming lo < hi.
///
/// The function arguments are (lo, hi, xs..., ys...) and the function returns
/// the final index of the pivot.
//
// The generated IR corresponds to this C like algorithm:
//   int partition(lo, hi, data) {
//     pivot = data[hi - 1];
//     i = lo - 1;  // Last index of the values found so far that are less
//                  // than the pivot.
//     for (j = lo; j < hi - 1; j++) {
//       if (data[j] < pivot) {
//         i++;
//         swap data[i] and data[j]
//       }
//     }
//     i++
//     swap data[i] and data[hi-1]
//     return i
//   }
static void createPartitionFunc(OpBuilder &builder, ModuleOp module,
func::FuncOp func, size_t dim) {
OpBuilder::InsertionGuard insertionGuard(builder);
Block *entryBlock = func.addEntryBlock();
builder.setInsertionPointToStart(entryBlock);
MLIRContext *context = module.getContext();
Location loc = func.getLoc();
ValueRange args = entryBlock->getArguments();
Value lo = args[loIdx];
Value c1 = constantIndex(builder, loc, 1);
// i = lo - 1. This value is never used as an index itself; it is always
// incremented first. NOTE(review): for lo == 0 this presumably relies on
// wrap-around of the index subtraction - confirm.
Value i = builder.create<arith::SubIOp>(loc, lo, c1);
Value him1 = builder.create<arith::SubIOp>(loc, args[hiIdx], c1);
// The for-stmt iterates j over [lo, hi - 1) and carries i as its only
// loop-carried value.
scf::ForOp forOp =
builder.create<scf::ForOp>(loc, lo, him1, c1, ValueRange{i});
// Start the for-stmt body.
builder.setInsertionPointToStart(forOp.getBody());
Value j = forOp.getInductionVar();
// The pivot is not copied out: data[j] is compared against data[hi - 1] in
// place, using the x buffers only.
SmallVector<Value, 6> compareOperands{j, him1};
ValueRange xs = args.slice(xStartIdx, dim);
compareOperands.append(xs.begin(), xs.end());
Type i1Type = IntegerType::get(context, 1, IntegerType::Signless);
FlatSymbolRefAttr lessThanFunc =
getMangledSortHelperFunc(builder, func, {i1Type}, kLessThanFuncNamePrefix,
dim, compareOperands, createLessThanFunc);
Value cond = builder
.create<func::CallOp>(loc, lessThanFunc, TypeRange{i1Type},
compareOperands)
.getResult(0);
scf::IfOp ifOp =
builder.create<scf::IfOp>(loc, i.getType(), cond, /*else=*/true);
// The if-stmt true branch: i++; swap(data[i], data[j]); yield i.
builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
Value i1 =
builder.create<arith::AddIOp>(loc, forOp.getRegionIterArgs().front(), c1);
// The swap is applied to all the buffers (xs and ys).
SmallVector<Value, 6> swapOperands{i1, j};
swapOperands.append(args.begin() + xStartIdx, args.end());
FlatSymbolRefAttr swapFunc = getMangledSortHelperFunc(
builder, func, TypeRange(), kMaySwapFuncNamePrefix, dim, swapOperands,
createMaySwapFunc);
builder.create<func::CallOp>(loc, swapFunc, TypeRange(), swapOperands);
builder.create<scf::YieldOp>(loc, i1);
// The if-stmt false branch: yield i.
builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
builder.create<scf::YieldOp>(loc, forOp.getRegionIterArgs().front());
// After the if-stmt, yield the updated i value to end the for-stmt body.
builder.setInsertionPointAfter(ifOp);
builder.create<scf::YieldOp>(loc, ifOp.getResult(0));
// After the for-stmt: i++; swap(data[i], data[him1]); return i.
builder.setInsertionPointAfter(forOp);
i1 = builder.create<arith::AddIOp>(loc, forOp.getResult(0), c1);
// Reuse the operand list of the swap call, replacing only the two indices.
swapOperands[0] = i1;
swapOperands[1] = him1;
builder.create<func::CallOp>(loc, swapFunc, TypeRange(), swapOperands);
builder.create<func::ReturnOp>(loc, i1);
}
/// Generates the body of a recursive quick sort function that sorts the values
/// in the index range [lo, hi). The function arguments are
/// (lo, hi, xs..., ys...).
//
// The generated IR corresponds to this C like algorithm:
//   void quickSort(lo, hi, data) {
//     if (lo < hi) {
//       p = partition(lo, hi, data);
//       quickSort(lo, p, data);
//       quickSort(p + 1, hi, data);
//     }
//   }
static void createSortNonstableFunc(OpBuilder &builder, ModuleOp module,
                                    func::FuncOp func, size_t dim) {
  OpBuilder::InsertionGuard insertionGuard(builder);

  Block *body = func.addEntryBlock();
  builder.setInsertionPointToStart(body);
  MLIRContext *context = module.getContext();
  Location loc = func.getLoc();
  ValueRange funcArgs = body->getArguments();
  Value lo = funcArgs[loIdx];
  Value hi = funcArgs[hiIdx];
  // All the buffers (xs followed by ys) that are passed along unchanged.
  ValueRange buffers = funcArgs.drop_front(xStartIdx);

  // Nothing to do unless lo < hi.
  Value nonEmpty =
      builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, lo, hi);
  scf::IfOp sortIf = builder.create<scf::IfOp>(loc, nonEmpty, /*else=*/false);
  builder.setInsertionPointToStart(&sortIf.getThenRegion().front());

  // p = partition(lo, hi, data), creating the partition helper on first use.
  Type indexType = IndexType::get(context);
  FlatSymbolRefAttr partitionFunc = getMangledSortHelperFunc(
      builder, func, {indexType}, kPartitionFuncNamePrefix, dim, funcArgs,
      createPartitionFunc);
  Value pivot = builder
                    .create<func::CallOp>(loc, partitionFunc,
                                          TypeRange{indexType}, funcArgs)
                    .getResult(0);

  // quickSort(lo, p, data).
  SmallVector<Value, 6> lowerOperands{lo, pivot};
  lowerOperands.append(buffers.begin(), buffers.end());
  builder.create<func::CallOp>(loc, func, lowerOperands);

  // quickSort(p + 1, hi, data).
  Value one = constantIndex(builder, loc, 1);
  Value pivotPlusOne = builder.create<arith::AddIOp>(loc, pivot, one);
  SmallVector<Value, 6> upperOperands{pivotPlusOne, hi};
  upperOperands.append(buffers.begin(), buffers.end());
  builder.create<func::CallOp>(loc, func, upperOperands);

  // After the if-stmt.
  builder.setInsertionPointAfter(sortIf);
  builder.create<func::ReturnOp>(loc);
}
/// Creates a function to perform insertion sort on the values in the range of
/// index [lo, hi).
///
/// The function arguments are (lo, hi, xs..., ys...). The insertion point of
/// each value is found by a binary search over the x buffers.
//
// The generated IR corresponds to this C like algorithm:
//   void insertionSort(lo, hi, data) {
//     for (i = lo+1; i < hi; i++) {
//       d = data[i];
//       p = binarySearch(lo, i, data)
//       for (j = 0; j < i - p; j++)
//         data[i-j] = data[i-j-1]
//       data[p] = d
//     }
//   }
static void createSortStableFunc(OpBuilder &builder, ModuleOp module,
func::FuncOp func, size_t dim) {
OpBuilder::InsertionGuard insertionGuard(builder);
Block *entryBlock = func.addEntryBlock();
builder.setInsertionPointToStart(entryBlock);
MLIRContext *context = module.getContext();
Location loc = func.getLoc();
ValueRange args = entryBlock->getArguments();
Value c1 = constantIndex(builder, loc, 1);
Value lo = args[loIdx];
Value hi = args[hiIdx];
Value lop1 = builder.create<arith::AddIOp>(loc, lo, c1);
// Start the outer for-stmt with induction variable i.
scf::ForOp forOpI = builder.create<scf::ForOp>(loc, lop1, hi, c1);
builder.setInsertionPointToStart(forOpI.getBody());
Value i = forOpI.getInductionVar();
// Binary search to find the insertion point p for data[i] within the
// already sorted range [lo, i). Only the x buffers take part in the search.
SmallVector<Value, 6> operands{lo, i};
operands.append(args.begin() + xStartIdx, args.begin() + xStartIdx + dim);
FlatSymbolRefAttr searchFunc = getMangledSortHelperFunc(
builder, func, {IndexType::get(context)}, kBinarySearchFuncNamePrefix,
dim, operands, createBinarySearchFunc);
Value p = builder
.create<func::CallOp>(loc, searchFunc, TypeRange{c1.getType()},
operands)
.getResult(0);
// Move the value at data[i] to a temporary location (one value per buffer,
// for all xs and ys).
ValueRange data = args.drop_front(xStartIdx);
SmallVector<Value, 6> d;
for (Value v : data)
d.push_back(builder.create<memref::LoadOp>(loc, v, i));
// Start the inner for-stmt with induction variable j, for moving data[p..i)
// to data[p+1..i+1). The loop runs j over [0, i - p) and copies from the
// high end downwards: data[i-j] = data[i-j-1].
Value imp = builder.create<arith::SubIOp>(loc, i, p);
Value c0 = constantIndex(builder, loc, 0);
scf::ForOp forOpJ = builder.create<scf::ForOp>(loc, c0, imp, c1);
builder.setInsertionPointToStart(forOpJ.getBody());
Value j = forOpJ.getInductionVar();
Value imj = builder.create<arith::SubIOp>(loc, i, j);
Value imjm1 = builder.create<arith::SubIOp>(loc, imj, c1);
for (Value v : data) {
Value t = builder.create<memref::LoadOp>(loc, v, imjm1);
builder.create<memref::StoreOp>(loc, t, v, imj);
}
// Store the value at data[i] to data[p].
builder.setInsertionPointAfter(forOpJ);
for (auto it : llvm::zip(d, data))
builder.create<memref::StoreOp>(loc, std::get<0>(it), std::get<1>(it), p);
builder.setInsertionPointAfter(forOpI);
builder.create<func::ReturnOp>(loc);
}
//===---------------------------------------------------------------------===//
// The actual sparse buffer rewriting rules.
//===---------------------------------------------------------------------===//
namespace {
/// Sparse rewriting rule for the push_back operator.
struct PushBackRewriter : OpRewritePattern<PushBackOp> {
public:
using OpRewritePattern<PushBackOp>::OpRewritePattern;
/// Replaces the push_back op with explicit size bookkeeping, an optional
/// reallocation of the buffer, and a store (or fill) of the pushed value.
/// The op is replaced by the resulting (possibly reallocated) buffer.
LogicalResult matchAndRewrite(PushBackOp op,
PatternRewriter &rewriter) const override {
// Rewrite push_back(buffer, value, n) to:
// new_size = size(buffer) + n
// if (new_size > capacity(buffer))
//   while new_size > new_capacity
//     new_capacity = new_capacity*2
//   new_buffer = realloc(buffer, new_capacity)
//   buffer = new_buffer
// subBuffer = subviewof(buffer)
// linalg.fill subBuffer value
//
// size(buffer) += n
//
// The capacity check is skipped when the attribute inbounds is present.
Location loc = op->getLoc();
Value c0 = constantIndex(rewriter, loc, 0);
Value buffer = op.getInBuffer();
// The capacity is the current dimension 0 of the buffer; the number of
// elements in use is kept separately in bufferSizes[idx].
Value capacity = rewriter.create<memref::DimOp>(loc, buffer, c0);
Value idx = constantIndex(rewriter, loc, op.getIdx().getZExtValue());
Value bufferSizes = op.getBufferSizes();
Value size = rewriter.create<memref::LoadOp>(loc, bufferSizes, idx);
Value value = op.getValue();
// n defaults to 1 when the op has no n operand.
Value n = op.getN() ? op.getN() : constantIndex(rewriter, loc, 1);
Value newSize = rewriter.create<arith::AddIOp>(loc, size, n);
// Pushing exactly one value (n is the constant 1) allows simpler code below:
// a single doubling of the capacity and a plain store.
auto nValue = dyn_cast_or_null<arith::ConstantIndexOp>(n.getDefiningOp());
bool nIsOne = (nValue && nValue.value() == 1);
if (!op.getInbounds()) {
Value cond = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::ugt, newSize, capacity);
Value c2 = constantIndex(rewriter, loc, 2);
auto bufferType =
MemRefType::get({ShapedType::kDynamicSize}, value.getType());
scf::IfOp ifOp = rewriter.create<scf::IfOp>(loc, bufferType, cond,
/*else=*/true);
// True branch: grow the capacity and reallocate the buffer.
// NOTE(review): doubling assumes a non-zero capacity; a capacity of 0 would
// never reach new_size - confirm that callers guarantee this.
rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());
if (nIsOne) {
capacity = rewriter.create<arith::MulIOp>(loc, capacity, c2);
} else {
// Use a do-while loop to calculate the new capacity as follows:
// do { new_capacity *= 2 } while (new_size > new_capacity)
scf::WhileOp whileOp =
rewriter.create<scf::WhileOp>(loc, capacity.getType(), capacity);
// The before-region of the WhileOp: doubles the capacity and decides whether
// another iteration is needed.
Block *before = rewriter.createBlock(&whileOp.getBefore(), {},
{capacity.getType()}, {loc});
rewriter.setInsertionPointToEnd(before);
capacity =
rewriter.create<arith::MulIOp>(loc, before->getArgument(0), c2);
cond = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ugt,
newSize, capacity);
rewriter.create<scf::ConditionOp>(loc, cond, ValueRange{capacity});
// The after-region of the WhileOp: only forwards the capacity.
Block *after = rewriter.createBlock(&whileOp.getAfter(), {},
{capacity.getType()}, {loc});
rewriter.setInsertionPointToEnd(after);
rewriter.create<scf::YieldOp>(loc, after->getArguments());
rewriter.setInsertionPointAfter(whileOp);
capacity = whileOp.getResult(0);
}
Value newBuffer =
rewriter.create<memref::ReallocOp>(loc, bufferType, buffer, capacity);
rewriter.create<scf::YieldOp>(loc, newBuffer);
// False branch: the buffer is large enough and is used as is.
rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front());
rewriter.create<scf::YieldOp>(loc, buffer);
// Prepare for adding the value to the end of the buffer.
rewriter.setInsertionPointAfter(ifOp);
buffer = ifOp.getResult(0);
}
// Add the value to the end of the buffer: a single store for n == 1,
// otherwise fill the subview [size, size + n) with the value.
if (nIsOne) {
rewriter.create<memref::StoreOp>(loc, value, buffer, size);
} else {
Value subBuffer = rewriter.create<memref::SubViewOp>(
loc, buffer, /*offset=*/ValueRange{size}, /*size=*/ValueRange{n},
/*step=*/ValueRange{constantIndex(rewriter, loc, 1)});
rewriter.create<linalg::FillOp>(loc, value, subBuffer);
}
// Update the buffer size.
rewriter.create<memref::StoreOp>(loc, newSize, bufferSizes, idx);
rewriter.replaceOp(op, buffer);
return success();
}
};
/// Sparse rewriting rule for the sort operator.
struct SortRewriter : public OpRewritePattern<SortOp> {
public:
  using OpRewritePattern<SortOp>::OpRewritePattern;

  /// Replaces the sort op with a call to a generated sort function (stable or
  /// non-stable, as requested by the op). The call operands are
  /// (0, n, xs..., ys...), with all buffers cast to a dynamic shape.
  LogicalResult matchAndRewrite(SortOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();

    // The sorted range is [0, n).
    SmallVector<Value, 6> operands;
    operands.push_back(constantIndex(rewriter, loc, 0));
    operands.push_back(op.getN());

    // Returns `buffer` itself when its dimension 0 is dynamic, otherwise a
    // cast of it to a memref with a dynamic shape.
    auto toDynamicShape = [&](Value buffer) -> Value {
      auto bufferType = buffer.getType().cast<MemRefType>();
      if (bufferType.isDynamicDim(0))
        return buffer;
      auto dynamicType = MemRefType::get({ShapedType::kDynamicSize},
                                         bufferType.getElementType());
      Value casted = rewriter.create<memref::CastOp>(loc, dynamicType, buffer);
      return casted;
    };

    ValueRange xs = op.getXs();
    for (Value x : xs)
      operands.push_back(toDynamicShape(x));
    for (Value y : op.getYs())
      operands.push_back(toDynamicShape(y));

    // Look up (or create) the sort function right before the function that
    // contains the op.
    auto insertPoint = op->getParentOfType<func::FuncOp>();
    SmallString<32> funcName(op.getStable() ? kSortStableFuncNamePrefix
                                            : kSortNonstableFuncNamePrefix);
    FuncGeneratorType funcGenerator =
        op.getStable() ? createSortStableFunc : createSortNonstableFunc;
    FlatSymbolRefAttr sortFunc =
        getMangledSortHelperFunc(rewriter, insertPoint, TypeRange(), funcName,
                                 xs.size(), operands, funcGenerator);

    rewriter.replaceOpWithNewOp<func::CallOp>(op, sortFunc, TypeRange(),
                                              operands);
    return success();
  }
};
} // namespace
//===---------------------------------------------------------------------===//
// Methods that add patterns described in this file to a pattern list.
//===---------------------------------------------------------------------===//
/// Adds the push_back and sort rewriting patterns defined in this file to
/// `patterns`.
void mlir::populateSparseBufferRewriting(RewritePatternSet &patterns) {
  MLIRContext *context = patterns.getContext();
  patterns.add<PushBackRewriter>(context);
  patterns.add<SortRewriter>(context);
}