
This PR replaces `dyn_cast<OpTy>(val.getDefiningOp())`, `dyn_cast_or_null<OpTy>(val.getDefiningOp())`, and `dyn_cast_if_present<OpTy>(val.getDefiningOp())` with `val.getDefiningOp<OpTy>()`.
1437 lines
58 KiB
C++
1437 lines
58 KiB
C++
//===- SparseBufferRewriting.cpp - Sparse buffer rewriting rules ----------===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
//
|
|
// This file implements rewriting rules that are specific to sparse tensor
|
|
// primitives with memref operands.
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "Utils/CodegenUtils.h"
|
|
|
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
|
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
|
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
|
#include "mlir/Dialect/Math/IR/Math.h"
|
|
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
|
#include "mlir/Dialect/SCF/IR/SCF.h"
|
|
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
|
|
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
|
|
#include "mlir/Support/LLVM.h"
|
|
|
|
using namespace mlir;
|
|
using namespace mlir::sparse_tensor;
|
|
|
|
//===---------------------------------------------------------------------===//
|
|
// Helper methods for the actual rewriting rules.
|
|
//===---------------------------------------------------------------------===//
|
|
|
|
// Positions of the leading operands shared by every generated sort helper,
// whose parameter list starts with (lo, hi, xs, ys...).
static constexpr uint64_t loIdx = 0;     // Lower bound of the range.
static constexpr uint64_t hiIdx = 1;     // Upper bound of the range.
static constexpr uint64_t xStartIdx = 2; // First buffer operand (xs).

// Symbol-name prefixes of the generated helper functions. These are
// presumably passed as the `namePrefix` of getMangledSortHelperFunc by the
// rewriting patterns (callers not visible here).
static constexpr char kPartitionFuncNamePrefix[] = "_sparse_partition_";
static constexpr char kBinarySearchFuncNamePrefix[] = "_sparse_binary_search_";
static constexpr char kHybridQuickSortFuncNamePrefix[] =
    "_sparse_hybrid_qsort_";
static constexpr char kSortStableFuncNamePrefix[] = "_sparse_sort_stable_";
static constexpr char kShiftDownFuncNamePrefix[] = "_sparse_shift_down_";
static constexpr char kHeapSortFuncNamePrefix[] = "_sparse_heap_sort_";
static constexpr char kQuickSortFuncNamePrefix[] = "_sparse_qsort_";
|
|
|
|
/// Callback that fills in the body of a freshly created (still empty) sort
/// helper function. The arguments are the builder, the enclosing module, the
/// function to populate, the x permutation, the number of y values, and the
/// number of trailing (non-buffer) parameters.
using FuncGeneratorType =
    function_ref<void(OpBuilder &builder, ModuleOp module, func::FuncOp func,
                      AffineMap xPerm, uint64_t ny, uint32_t nTrailingP)>;
|
|
|
|
/// Constructs a function name with this format to facilitate quick sort:
|
|
/// <namePrefix><xPerm>_<x type>_<y0 type>..._<yn type> for sort
|
|
/// <namePrefix><xPerm>_<x type>_coo_<ny>_<y0 type>..._<yn type> for sort_coo
|
|
static void getMangledSortHelperFuncName(llvm::raw_svector_ostream &nameOstream,
|
|
StringRef namePrefix, AffineMap xPerm,
|
|
uint64_t ny, ValueRange operands) {
|
|
nameOstream << namePrefix;
|
|
for (auto res : xPerm.getResults())
|
|
nameOstream << cast<AffineDimExpr>(res).getPosition() << "_";
|
|
|
|
nameOstream << getMemRefType(operands[xStartIdx]).getElementType();
|
|
nameOstream << "_coo_" << ny;
|
|
|
|
constexpr uint64_t yBufferOffset = 1;
|
|
for (Value v : operands.drop_front(xStartIdx + yBufferOffset))
|
|
nameOstream << "_" << getMemRefType(v).getElementType();
|
|
}
|
|
|
|
/// Looks up a function that is appropriate for the given operands being
|
|
/// sorted, and creates such a function if it doesn't exist yet. The
|
|
/// parameters `xPerm` and `ny` tell the number of x and y values provided
|
|
/// by the buffer in xStartIdx.
|
|
//
|
|
// All sorting function generators take (lo, hi, xs, ys) in `operands` as
|
|
// parameters for the sorting functions. Other parameters, such as the recursive
|
|
// call depth, are appended to the end of the parameter list as
|
|
// "trailing parameters".
|
|
static FlatSymbolRefAttr getMangledSortHelperFunc(
|
|
OpBuilder &builder, func::FuncOp insertPoint, TypeRange resultTypes,
|
|
StringRef namePrefix, AffineMap xPerm, uint64_t ny, ValueRange operands,
|
|
FuncGeneratorType createFunc, uint32_t nTrailingP = 0) {
|
|
SmallString<32> nameBuffer;
|
|
llvm::raw_svector_ostream nameOstream(nameBuffer);
|
|
getMangledSortHelperFuncName(nameOstream, namePrefix, xPerm, ny,
|
|
operands.drop_back(nTrailingP));
|
|
|
|
ModuleOp module = insertPoint->getParentOfType<ModuleOp>();
|
|
MLIRContext *context = module.getContext();
|
|
auto result = SymbolRefAttr::get(context, nameOstream.str());
|
|
auto func = module.lookupSymbol<func::FuncOp>(result.getAttr());
|
|
|
|
if (!func) {
|
|
// Create the function.
|
|
OpBuilder::InsertionGuard insertionGuard(builder);
|
|
builder.setInsertionPoint(insertPoint);
|
|
Location loc = insertPoint.getLoc();
|
|
func = func::FuncOp::create(
|
|
builder, loc, nameOstream.str(),
|
|
FunctionType::get(context, operands.getTypes(), resultTypes));
|
|
func.setPrivate();
|
|
createFunc(builder, module, func, xPerm, ny, nTrailingP);
|
|
}
|
|
|
|
return result;
|
|
}
|
|
|
|
/// Creates a code block to process each pair of (xs[i], xs[j]) for sorting.
|
|
/// The code to process the value pairs is generated by `bodyBuilder`.
|
|
static void forEachIJPairInXs(
|
|
OpBuilder &builder, Location loc, ValueRange args, AffineMap xPerm,
|
|
uint64_t ny,
|
|
function_ref<void(uint64_t, Value, Value, Value)> bodyBuilder) {
|
|
Value cstep = constantIndex(builder, loc, xPerm.getNumResults() + ny);
|
|
Value iOffset = arith::MulIOp::create(builder, loc, args[0], cstep);
|
|
Value jOffset = arith::MulIOp::create(builder, loc, args[1], cstep);
|
|
for (unsigned k = 0, e = xPerm.getNumResults(); k < e; k++) {
|
|
unsigned actualK = cast<AffineDimExpr>(xPerm.getResult(k)).getPosition();
|
|
Value ak = constantIndex(builder, loc, actualK);
|
|
Value i = arith::AddIOp::create(builder, loc, ak, iOffset);
|
|
Value j = arith::AddIOp::create(builder, loc, ak, jOffset);
|
|
Value buffer = args[xStartIdx];
|
|
|
|
bodyBuilder(k, i, j, buffer);
|
|
}
|
|
}
|
|
|
|
/// Creates a code block to process each pair of (xys[i], xys[j]) for sorting.
|
|
/// The code to process the value pairs is generated by `bodyBuilder`.
|
|
static void forEachIJPairInAllBuffers(
|
|
OpBuilder &builder, Location loc, ValueRange args, AffineMap xPerm,
|
|
uint64_t ny,
|
|
function_ref<void(uint64_t, Value, Value, Value)> bodyBuilder) {
|
|
|
|
// Create code for the first (xPerm + ny) buffers.
|
|
SmallVector<AffineExpr> exps(xPerm.getResults());
|
|
for (unsigned y = 0; y < ny; y++) {
|
|
exps.push_back(builder.getAffineDimExpr(y + xPerm.getNumResults()));
|
|
}
|
|
AffineMap xyPerm = AffineMap::get(exps.size(), 0, exps, builder.getContext());
|
|
assert(xyPerm.isPermutation());
|
|
|
|
forEachIJPairInXs(builder, loc, args, xyPerm, 0, bodyBuilder);
|
|
|
|
constexpr uint64_t numHandledBuffers = 1;
|
|
// Create code for the remaining buffers.
|
|
Value i = args[0];
|
|
Value j = args[1];
|
|
for (const auto &arg :
|
|
llvm::enumerate(args.drop_front(xStartIdx + numHandledBuffers))) {
|
|
bodyBuilder(arg.index() + xPerm.getNumResults() + ny, i, j, arg.value());
|
|
}
|
|
}
|
|
|
|
/// Creates a code block for swapping the values in index i and j for all the
|
|
/// buffers.
|
|
//
|
|
// The generated IR corresponds to this C like algorithm:
|
|
// swap(x0[i], x0[j]);
|
|
// swap(x1[i], x1[j]);
|
|
// ...
|
|
// swap(xn[i], xn[j]);
|
|
// swap(y0[i], y0[j]);
|
|
// ...
|
|
// swap(yn[i], yn[j]);
|
|
static void createSwap(OpBuilder &builder, Location loc, ValueRange args,
|
|
AffineMap xPerm, uint64_t ny) {
|
|
auto swapOnePair = [&](uint64_t unused, Value i, Value j, Value buffer) {
|
|
Value vi = memref::LoadOp::create(builder, loc, buffer, i);
|
|
Value vj = memref::LoadOp::create(builder, loc, buffer, j);
|
|
memref::StoreOp::create(builder, loc, vj, buffer, i);
|
|
memref::StoreOp::create(builder, loc, vi, buffer, j);
|
|
};
|
|
|
|
forEachIJPairInAllBuffers(builder, loc, args, xPerm, ny, swapOnePair);
|
|
}
|
|
|
|
/// Creates code to compare all the (xs[i], xs[j]) pairs. The method to compare
|
|
/// each pair is create via `compareBuilder`.
|
|
static Value createInlinedCompareImplementation(
|
|
OpBuilder &builder, Location loc, ValueRange args, AffineMap xPerm,
|
|
uint64_t ny,
|
|
function_ref<Value(OpBuilder &, Location, Value, Value, Value, bool, bool)>
|
|
compareBuilder) {
|
|
Value result;
|
|
auto bodyBuilder = [&](uint64_t k, Value i, Value j, Value buffer) {
|
|
bool isFirstDim = (k == 0);
|
|
bool isLastDim = (k == xPerm.getNumResults() - 1);
|
|
Value val =
|
|
compareBuilder(builder, loc, i, j, buffer, isFirstDim, isLastDim);
|
|
if (isFirstDim) {
|
|
result = val;
|
|
} else if (!isLastDim) {
|
|
OpBuilder::InsertionGuard insertionGuard(builder);
|
|
auto ifOp = cast<scf::IfOp>(val.getDefiningOp());
|
|
builder.setInsertionPointAfter(ifOp);
|
|
scf::YieldOp::create(builder, loc, ifOp.getResult(0));
|
|
}
|
|
};
|
|
|
|
forEachIJPairInXs(builder, loc, args, xPerm, ny, bodyBuilder);
|
|
|
|
builder.setInsertionPointAfterValue(result);
|
|
return result;
|
|
}
|
|
|
|
/// Generates code to compare whether x[i] is equal to x[j] and returns the
|
|
/// result of the comparison.
|
|
static Value createEqCompare(OpBuilder &builder, Location loc, Value i, Value j,
|
|
Value x, bool isFirstDim, bool isLastDim) {
|
|
Value vi = memref::LoadOp::create(builder, loc, x, i);
|
|
Value vj = memref::LoadOp::create(builder, loc, x, j);
|
|
|
|
Value res;
|
|
if (isLastDim) {
|
|
res = arith::CmpIOp::create(builder, loc, arith::CmpIPredicate::eq, vi, vj);
|
|
// For 1D, we create a compare without any control flow. Otherwise, we
|
|
// create YieldOp to return the result in the nested if-stmt.
|
|
if (!isFirstDim)
|
|
scf::YieldOp::create(builder, loc, res);
|
|
} else {
|
|
Value ne =
|
|
arith::CmpIOp::create(builder, loc, arith::CmpIPredicate::ne, vi, vj);
|
|
scf::IfOp ifOp = scf::IfOp::create(builder, loc, builder.getIntegerType(1),
|
|
ne, /*else=*/true);
|
|
// If (x[i] != x[j]).
|
|
builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
|
|
Value f = constantI1(builder, loc, false);
|
|
scf::YieldOp::create(builder, loc, f);
|
|
|
|
// If (x[i] == x[j]). Set up the insertion point for the nested if-stmt that
|
|
// checks the remaining dimensions.
|
|
builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
|
|
res = ifOp.getResult(0);
|
|
}
|
|
|
|
return res;
|
|
}
|
|
|
|
/// Creates code to compare whether xs[i] is equal to xs[j].
|
|
//
|
|
// The generate IR corresponds to this C like algorithm:
|
|
// if (x0[i] != x0[j])
|
|
// return false;
|
|
// else
|
|
// if (x1[i] != x1[j])
|
|
// return false;
|
|
// else if (x2[2] != x2[j]))
|
|
// and so on ...
|
|
static Value createInlinedEqCompare(OpBuilder &builder, Location loc,
|
|
ValueRange args, AffineMap xPerm,
|
|
uint64_t ny, uint32_t nTrailingP = 0) {
|
|
// Compare functions don't use trailing parameters.
|
|
(void)nTrailingP;
|
|
assert(nTrailingP == 0);
|
|
return createInlinedCompareImplementation(builder, loc, args, xPerm, ny,
|
|
createEqCompare);
|
|
}
|
|
|
|
/// Generates code to compare whether x[i] is less than x[j] and returns the
|
|
/// result of the comparison.
|
|
static Value createLessThanCompare(OpBuilder &builder, Location loc, Value i,
|
|
Value j, Value x, bool isFirstDim,
|
|
bool isLastDim) {
|
|
Value vi = memref::LoadOp::create(builder, loc, x, i);
|
|
Value vj = memref::LoadOp::create(builder, loc, x, j);
|
|
|
|
Value res;
|
|
if (isLastDim) {
|
|
res =
|
|
arith::CmpIOp::create(builder, loc, arith::CmpIPredicate::ult, vi, vj);
|
|
// For 1D, we create a compare without any control flow. Otherwise, we
|
|
// create YieldOp to return the result in the nested if-stmt.
|
|
if (!isFirstDim)
|
|
scf::YieldOp::create(builder, loc, res);
|
|
} else {
|
|
Value ne =
|
|
arith::CmpIOp::create(builder, loc, arith::CmpIPredicate::ne, vi, vj);
|
|
scf::IfOp ifOp = scf::IfOp::create(builder, loc, builder.getIntegerType(1),
|
|
ne, /*else=*/true);
|
|
// If (x[i] != x[j]).
|
|
builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
|
|
Value lt =
|
|
arith::CmpIOp::create(builder, loc, arith::CmpIPredicate::ult, vi, vj);
|
|
scf::YieldOp::create(builder, loc, lt);
|
|
|
|
// If (x[i] == x[j]). Set up the insertion point for the nested if-stmt that
|
|
// checks the remaining dimensions.
|
|
builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
|
|
res = ifOp.getResult(0);
|
|
}
|
|
|
|
return res;
|
|
}
|
|
|
|
/// Creates code to compare whether xs[i] is less than xs[j].
|
|
//
|
|
// The generate IR corresponds to this C like algorithm:
|
|
// if (x0[i] != x0[j])
|
|
// return x0[i] < x0[j];
|
|
// else if (x1[j] != x1[i])
|
|
// return x1[i] < x1[j];
|
|
// else
|
|
// and so on ...
|
|
static Value createInlinedLessThan(OpBuilder &builder, Location loc,
|
|
ValueRange args, AffineMap xPerm,
|
|
uint64_t ny, uint32_t nTrailingP = 0) {
|
|
// Compare functions don't use trailing parameters.
|
|
(void)nTrailingP;
|
|
assert(nTrailingP == 0);
|
|
return createInlinedCompareImplementation(builder, loc, args, xPerm, ny,
|
|
createLessThanCompare);
|
|
}
|
|
|
|
/// Creates a function to use a binary search to find the insertion point for
|
|
/// inserting xs[hi] to the sorted values xs[lo..hi).
|
|
//
|
|
// The generate IR corresponds to this C like algorithm:
|
|
// p = hi
|
|
// while (lo < hi)
|
|
// mid = (lo + hi) >> 1
|
|
// if (xs[p] < xs[mid])
|
|
// hi = mid
|
|
// else
|
|
// lo = mid - 1
|
|
// return lo;
|
|
//
|
|
static void createBinarySearchFunc(OpBuilder &builder, ModuleOp module,
|
|
func::FuncOp func, AffineMap xPerm,
|
|
uint64_t ny, uint32_t nTrailingP = 0) {
|
|
// Binary search doesn't use trailing parameters.
|
|
(void)nTrailingP;
|
|
assert(nTrailingP == 0);
|
|
OpBuilder::InsertionGuard insertionGuard(builder);
|
|
Block *entryBlock = func.addEntryBlock();
|
|
builder.setInsertionPointToStart(entryBlock);
|
|
|
|
Location loc = func.getLoc();
|
|
ValueRange args = entryBlock->getArguments();
|
|
Value p = args[hiIdx];
|
|
SmallVector<Type, 2> types(2, p.getType()); // Only two types.
|
|
scf::WhileOp whileOp = scf::WhileOp::create(
|
|
builder, loc, types, SmallVector<Value, 2>{args[loIdx], args[hiIdx]});
|
|
|
|
// The before-region of the WhileOp.
|
|
Block *before =
|
|
builder.createBlock(&whileOp.getBefore(), {}, types, {loc, loc});
|
|
builder.setInsertionPointToEnd(before);
|
|
Value cond1 =
|
|
arith::CmpIOp::create(builder, loc, arith::CmpIPredicate::ult,
|
|
before->getArgument(0), before->getArgument(1));
|
|
scf::ConditionOp::create(builder, loc, cond1, before->getArguments());
|
|
|
|
// The after-region of the WhileOp.
|
|
Block *after =
|
|
builder.createBlock(&whileOp.getAfter(), {}, types, {loc, loc});
|
|
builder.setInsertionPointToEnd(after);
|
|
Value lo = after->getArgument(0);
|
|
Value hi = after->getArgument(1);
|
|
// Compute mid = (lo + hi) >> 1.
|
|
Value c1 = constantIndex(builder, loc, 1);
|
|
Value mid = arith::ShRUIOp::create(
|
|
builder, loc, arith::AddIOp::create(builder, loc, lo, hi), c1);
|
|
Value midp1 = arith::AddIOp::create(builder, loc, mid, c1);
|
|
|
|
// Compare xs[p] < xs[mid].
|
|
SmallVector<Value> compareOperands{p, mid};
|
|
constexpr uint64_t numXBuffers = 1;
|
|
compareOperands.append(args.begin() + xStartIdx,
|
|
args.begin() + xStartIdx + numXBuffers);
|
|
Value cond2 = createInlinedLessThan(builder, loc, compareOperands, xPerm, ny);
|
|
// Update lo and hi for the WhileOp as follows:
|
|
// if (xs[p] < xs[mid]))
|
|
// hi = mid;
|
|
// else
|
|
// lo = mid + 1;
|
|
Value newLo = arith::SelectOp::create(builder, loc, cond2, lo, midp1);
|
|
Value newHi = arith::SelectOp::create(builder, loc, cond2, mid, hi);
|
|
scf::YieldOp::create(builder, loc, ValueRange{newLo, newHi});
|
|
|
|
builder.setInsertionPointAfter(whileOp);
|
|
func::ReturnOp::create(builder, loc, whileOp.getResult(0));
|
|
}
|
|
|
|
/// Creates code to advance i in a loop based on xs[p] as follows:
|
|
/// while (xs[i] < xs[p]) i += step (step > 0)
|
|
/// or
|
|
/// while (xs[i] > xs[p]) i += step (step < 0)
|
|
/// The routine returns i as well as a boolean value to indicate whether
|
|
/// xs[i] == xs[p].
|
|
static std::pair<Value, Value> createScanLoop(OpBuilder &builder,
|
|
ModuleOp module,
|
|
func::FuncOp func, ValueRange xs,
|
|
Value i, Value p, AffineMap xPerm,
|
|
uint64_t ny, int step) {
|
|
Location loc = func.getLoc();
|
|
scf::WhileOp whileOp =
|
|
scf::WhileOp::create(builder, loc, TypeRange{i.getType()}, ValueRange{i});
|
|
|
|
Block *before =
|
|
builder.createBlock(&whileOp.getBefore(), {}, {i.getType()}, {loc});
|
|
builder.setInsertionPointToEnd(before);
|
|
SmallVector<Value> compareOperands;
|
|
if (step > 0) {
|
|
compareOperands.push_back(before->getArgument(0));
|
|
compareOperands.push_back(p);
|
|
} else {
|
|
assert(step < 0);
|
|
compareOperands.push_back(p);
|
|
compareOperands.push_back(before->getArgument(0));
|
|
}
|
|
compareOperands.append(xs.begin(), xs.end());
|
|
Value cond = createInlinedLessThan(builder, loc, compareOperands, xPerm, ny);
|
|
scf::ConditionOp::create(builder, loc, cond, before->getArguments());
|
|
|
|
Block *after =
|
|
builder.createBlock(&whileOp.getAfter(), {}, {i.getType()}, {loc});
|
|
builder.setInsertionPointToEnd(after);
|
|
Value cs = constantIndex(builder, loc, step);
|
|
i = arith::AddIOp::create(builder, loc, after->getArgument(0), cs);
|
|
scf::YieldOp::create(builder, loc, ValueRange{i});
|
|
i = whileOp.getResult(0);
|
|
|
|
builder.setInsertionPointAfter(whileOp);
|
|
compareOperands[0] = i;
|
|
compareOperands[1] = p;
|
|
Value compareEq =
|
|
createInlinedEqCompare(builder, loc, compareOperands, xPerm, ny);
|
|
|
|
return std::make_pair(whileOp.getResult(0), compareEq);
|
|
}
|
|
|
|
/// Creates and returns an IfOp to compare two elements and swap the elements
|
|
/// if compareFunc(data[b], data[a]) returns true. The new insertion point is
|
|
/// right after the swap instructions.
|
|
static scf::IfOp createCompareThenSwap(OpBuilder &builder, Location loc,
|
|
AffineMap xPerm, uint64_t ny,
|
|
SmallVectorImpl<Value> &swapOperands,
|
|
SmallVectorImpl<Value> &compareOperands,
|
|
Value a, Value b) {
|
|
// Compare(data[b], data[a]).
|
|
compareOperands[0] = b;
|
|
compareOperands[1] = a;
|
|
Value cond = createInlinedLessThan(builder, loc, compareOperands, xPerm, ny);
|
|
scf::IfOp ifOp = scf::IfOp::create(builder, loc, cond, /*else=*/false);
|
|
builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
|
|
swapOperands[0] = b;
|
|
swapOperands[1] = a;
|
|
createSwap(builder, loc, swapOperands, xPerm, ny);
|
|
return ifOp;
|
|
}
|
|
|
|
/// Creates code to insert the 3rd element to a list of two sorted elements.
|
|
static void createInsert3rd(OpBuilder &builder, Location loc, AffineMap xPerm,
|
|
uint64_t ny, SmallVectorImpl<Value> &swapOperands,
|
|
SmallVectorImpl<Value> &compareOperands, Value v0,
|
|
Value v1, Value v2) {
|
|
scf::IfOp ifOp = createCompareThenSwap(builder, loc, xPerm, ny, swapOperands,
|
|
compareOperands, v1, v2);
|
|
createCompareThenSwap(builder, loc, xPerm, ny, swapOperands, compareOperands,
|
|
v0, v1);
|
|
builder.setInsertionPointAfter(ifOp);
|
|
}
|
|
|
|
/// Creates code to sort 3 elements.
|
|
static void createSort3(OpBuilder &builder, Location loc, AffineMap xPerm,
|
|
uint64_t ny, SmallVectorImpl<Value> &swapOperands,
|
|
SmallVectorImpl<Value> &compareOperands, Value v0,
|
|
Value v1, Value v2) {
|
|
// Sort the first 2 elements.
|
|
scf::IfOp ifOp1 = createCompareThenSwap(builder, loc, xPerm, ny, swapOperands,
|
|
compareOperands, v0, v1);
|
|
builder.setInsertionPointAfter(ifOp1);
|
|
|
|
// Insert the 3th element.
|
|
createInsert3rd(builder, loc, xPerm, ny, swapOperands, compareOperands, v0,
|
|
v1, v2);
|
|
}
|
|
|
|
/// Creates code to sort 5 elements.
|
|
static void createSort5(OpBuilder &builder, Location loc, AffineMap xPerm,
|
|
uint64_t ny, SmallVectorImpl<Value> &swapOperands,
|
|
SmallVectorImpl<Value> &compareOperands, Value v0,
|
|
Value v1, Value v2, Value v3, Value v4) {
|
|
// Sort the first 3 elements.
|
|
createSort3(builder, loc, xPerm, ny, swapOperands, compareOperands, v0, v1,
|
|
v2);
|
|
|
|
auto insert4th = [&]() {
|
|
scf::IfOp ifOp = createCompareThenSwap(
|
|
builder, loc, xPerm, ny, swapOperands, compareOperands, v2, v3);
|
|
createInsert3rd(builder, loc, xPerm, ny, swapOperands, compareOperands, v0,
|
|
v1, v2);
|
|
builder.setInsertionPointAfter(ifOp);
|
|
};
|
|
|
|
// Insert the 4th element.
|
|
insert4th();
|
|
|
|
// Insert the 5th element.
|
|
scf::IfOp ifOp = createCompareThenSwap(builder, loc, xPerm, ny, swapOperands,
|
|
compareOperands, v3, v4);
|
|
insert4th();
|
|
builder.setInsertionPointAfter(ifOp);
|
|
}
|
|
|
|
/// Creates a code block to swap the values in indices lo, mi, and hi so that
|
|
/// data[lo], data[mi] and data[hi] are sorted in non-decreasing values. When
|
|
/// the number of values in range [lo, hi) is more than a threshold, we also
|
|
/// include the middle of [lo, mi) and [mi, hi) and sort a total of five values.
|
|
static void createChoosePivot(OpBuilder &builder, ModuleOp module,
|
|
func::FuncOp func, AffineMap xPerm, uint64_t ny,
|
|
Value lo, Value hi, Value mi, ValueRange args) {
|
|
SmallVector<Value> compareOperands{mi, lo};
|
|
constexpr uint64_t numXBuffers = 1;
|
|
compareOperands.append(args.begin() + xStartIdx,
|
|
args.begin() + xStartIdx + numXBuffers);
|
|
SmallVector<Value> swapOperands{mi, lo};
|
|
swapOperands.append(args.begin() + xStartIdx, args.end());
|
|
Location loc = func.getLoc();
|
|
Value c1 = constantIndex(builder, loc, 1);
|
|
Value hiP1 = arith::AddIOp::create(builder, loc, hi, c1);
|
|
Value len = arith::SubIOp::create(builder, loc, hiP1, lo);
|
|
Value lenThreshold = constantIndex(builder, loc, 1000);
|
|
Value lenCond = arith::CmpIOp::create(builder, loc, arith::CmpIPredicate::ult,
|
|
len, lenThreshold);
|
|
scf::IfOp lenIf = scf::IfOp::create(builder, loc, lenCond, /*else=*/true);
|
|
|
|
// When len < 1000, choose pivot from median of 3 values.
|
|
builder.setInsertionPointToStart(&lenIf.getThenRegion().front());
|
|
createSort3(builder, loc, xPerm, ny, swapOperands, compareOperands, lo, mi,
|
|
hi);
|
|
|
|
// When len >= 1000, choose pivot from median of 5 values.
|
|
builder.setInsertionPointToStart(&lenIf.getElseRegion().front());
|
|
Value miP1 = arith::AddIOp::create(builder, loc, hi, c1);
|
|
Value a = arith::AddIOp::create(builder, loc, lo, miP1);
|
|
// Value a is the middle between [loc, mi].
|
|
a = arith::ShRUIOp::create(builder, loc, a, c1);
|
|
Value b = arith::AddIOp::create(builder, loc, mi, hiP1);
|
|
// Value b is the middle between [mi, hi].
|
|
b = arith::ShRUIOp::create(builder, loc, b, c1);
|
|
createSort5(builder, loc, xPerm, ny, swapOperands, compareOperands, lo, a, mi,
|
|
b, hi);
|
|
|
|
builder.setInsertionPointAfter(lenIf);
|
|
}
|
|
|
|
/// Creates a function to perform quick sort partition on the values in the
|
|
/// range of index [lo, hi), assuming lo < hi.
|
|
//
|
|
// The generated IR corresponds to this C like algorithm:
|
|
// int partition(lo, hi, xs) {
|
|
// p = (lo+hi)/2 // pivot index
|
|
// i = lo
|
|
// j = hi-1
|
|
// while (true) do {
|
|
// while (xs[i] < xs[p]) i ++;
|
|
// i_eq = (xs[i] == xs[p]);
|
|
// while (xs[j] > xs[p]) j --;
|
|
// j_eq = (xs[j] == xs[p]);
|
|
//
|
|
// if (i >= j) return j + 1;
|
|
//
|
|
// if (i < j) {
|
|
// swap(xs[i], xs[j])
|
|
// if (i == p) {
|
|
// p = j;
|
|
// } else if (j == p) {
|
|
// p = i;
|
|
// }
|
|
// if (i_eq && j_eq) {
|
|
// ++i;
|
|
// --j;
|
|
// }
|
|
// }
|
|
// }
|
|
// }
|
|
static void createPartitionFunc(OpBuilder &builder, ModuleOp module,
|
|
func::FuncOp func, AffineMap xPerm, uint64_t ny,
|
|
uint32_t nTrailingP = 0) {
|
|
// Quick sort partition doesn't use trailing parameters.
|
|
(void)nTrailingP;
|
|
assert(nTrailingP == 0);
|
|
OpBuilder::InsertionGuard insertionGuard(builder);
|
|
|
|
Block *entryBlock = func.addEntryBlock();
|
|
builder.setInsertionPointToStart(entryBlock);
|
|
|
|
Location loc = func.getLoc();
|
|
ValueRange args = entryBlock->getArguments();
|
|
Value lo = args[loIdx];
|
|
Value hi = args[hiIdx];
|
|
Value sum = arith::AddIOp::create(builder, loc, lo, hi);
|
|
Value c1 = constantIndex(builder, loc, 1);
|
|
Value p = arith::ShRUIOp::create(builder, loc, sum, c1);
|
|
|
|
Value i = lo;
|
|
Value j = arith::SubIOp::create(builder, loc, hi, c1);
|
|
createChoosePivot(builder, module, func, xPerm, ny, i, j, p, args);
|
|
Value trueVal = constantI1(builder, loc, true); // The value for while (true)
|
|
SmallVector<Value, 4> operands{i, j, p, trueVal}; // Exactly four values.
|
|
SmallVector<Type, 4> types{i.getType(), j.getType(), p.getType(),
|
|
trueVal.getType()};
|
|
scf::WhileOp whileOp = scf::WhileOp::create(builder, loc, types, operands);
|
|
|
|
// The before-region of the WhileOp.
|
|
Block *before = builder.createBlock(&whileOp.getBefore(), {}, types,
|
|
{loc, loc, loc, loc});
|
|
builder.setInsertionPointToEnd(before);
|
|
scf::ConditionOp::create(builder, loc, before->getArgument(3),
|
|
before->getArguments());
|
|
|
|
// The after-region of the WhileOp.
|
|
Block *after =
|
|
builder.createBlock(&whileOp.getAfter(), {}, types, {loc, loc, loc, loc});
|
|
builder.setInsertionPointToEnd(after);
|
|
i = after->getArgument(0);
|
|
j = after->getArgument(1);
|
|
p = after->getArgument(2);
|
|
|
|
constexpr uint64_t numXBuffers = 1;
|
|
auto [iresult, iCompareEq] =
|
|
createScanLoop(builder, module, func, args.slice(xStartIdx, numXBuffers),
|
|
i, p, xPerm, ny, 1);
|
|
i = iresult;
|
|
auto [jresult, jCompareEq] =
|
|
createScanLoop(builder, module, func, args.slice(xStartIdx, numXBuffers),
|
|
j, p, xPerm, ny, -1);
|
|
j = jresult;
|
|
|
|
// If i < j:
|
|
Value cond =
|
|
arith::CmpIOp::create(builder, loc, arith::CmpIPredicate::ult, i, j);
|
|
scf::IfOp ifOp = scf::IfOp::create(builder, loc, types, cond, /*else=*/true);
|
|
builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
|
|
SmallVector<Value> swapOperands{i, j};
|
|
swapOperands.append(args.begin() + xStartIdx, args.end());
|
|
createSwap(builder, loc, swapOperands, xPerm, ny);
|
|
// If the pivot is moved, update p with the new pivot.
|
|
Value icond =
|
|
arith::CmpIOp::create(builder, loc, arith::CmpIPredicate::eq, i, p);
|
|
scf::IfOp ifOpI = scf::IfOp::create(builder, loc, TypeRange{p.getType()},
|
|
icond, /*else=*/true);
|
|
builder.setInsertionPointToStart(&ifOpI.getThenRegion().front());
|
|
scf::YieldOp::create(builder, loc, ValueRange{j});
|
|
builder.setInsertionPointToStart(&ifOpI.getElseRegion().front());
|
|
Value jcond =
|
|
arith::CmpIOp::create(builder, loc, arith::CmpIPredicate::eq, j, p);
|
|
scf::IfOp ifOpJ = scf::IfOp::create(builder, loc, TypeRange{p.getType()},
|
|
jcond, /*else=*/true);
|
|
builder.setInsertionPointToStart(&ifOpJ.getThenRegion().front());
|
|
scf::YieldOp::create(builder, loc, ValueRange{i});
|
|
builder.setInsertionPointToStart(&ifOpJ.getElseRegion().front());
|
|
scf::YieldOp::create(builder, loc, ValueRange{p});
|
|
builder.setInsertionPointAfter(ifOpJ);
|
|
scf::YieldOp::create(builder, loc, ifOpJ.getResults());
|
|
builder.setInsertionPointAfter(ifOpI);
|
|
Value compareEqIJ =
|
|
arith::AndIOp::create(builder, loc, iCompareEq, jCompareEq);
|
|
scf::IfOp ifOp2 =
|
|
scf::IfOp::create(builder, loc, TypeRange{i.getType(), j.getType()},
|
|
compareEqIJ, /*else=*/true);
|
|
builder.setInsertionPointToStart(&ifOp2.getThenRegion().front());
|
|
Value i2 = arith::AddIOp::create(builder, loc, i, c1);
|
|
Value j2 = arith::SubIOp::create(builder, loc, j, c1);
|
|
scf::YieldOp::create(builder, loc, ValueRange{i2, j2});
|
|
builder.setInsertionPointToStart(&ifOp2.getElseRegion().front());
|
|
scf::YieldOp::create(builder, loc, ValueRange{i, j});
|
|
builder.setInsertionPointAfter(ifOp2);
|
|
scf::YieldOp::create(builder, loc,
|
|
ValueRange{ifOp2.getResult(0), ifOp2.getResult(1),
|
|
ifOpI.getResult(0),
|
|
/*cont=*/constantI1(builder, loc, true)});
|
|
|
|
// False branch for if i < j (i.e., i >= j):
|
|
builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
|
|
p = arith::AddIOp::create(builder, loc, j,
|
|
constantOne(builder, loc, j.getType()));
|
|
scf::YieldOp::create(
|
|
builder, loc,
|
|
ValueRange{i, j, p, /*cont=*/constantI1(builder, loc, false)});
|
|
|
|
// Return for the whileOp.
|
|
builder.setInsertionPointAfter(ifOp);
|
|
scf::YieldOp::create(builder, loc, ifOp.getResults());
|
|
|
|
// Return for the function.
|
|
builder.setInsertionPointAfter(whileOp);
|
|
func::ReturnOp::create(builder, loc, whileOp.getResult(2));
|
|
}
|
|
|
|
/// Computes (n-2)/n, assuming n has index type.
|
|
static Value createSubTwoDividedByTwo(OpBuilder &builder, Location loc,
|
|
Value n) {
|
|
Value i2 = constantIndex(builder, loc, 2);
|
|
Value res = arith::SubIOp::create(builder, loc, n, i2);
|
|
Value i1 = constantIndex(builder, loc, 1);
|
|
return arith::ShRUIOp::create(builder, loc, res, i1);
|
|
}
|
|
|
|
/// Creates a function to heapify the subtree with root `start` within the full
|
|
/// binary tree in the range of index [first, first + n).
|
|
//
|
|
// The generated IR corresponds to this C like algorithm:
|
|
// void shiftDown(first, start, n, data) {
|
|
// if (n >= 2) {
|
|
// child = start - first
|
|
// if ((n-2)/2 >= child) {
|
|
// // Left child exists.
|
|
// child = child * 2 + 1 // Initialize the bigger child to left child.
|
|
// childIndex = child + first
|
|
// if (child+1 < n && data[childIndex] < data[childIndex+1])
|
|
// // Right child exists and is bigger.
|
|
// childIndex++; child++;
|
|
// // Shift data[start] down to where it belongs in the subtree.
|
|
// while (data[start] < data[childIndex]) {
|
|
// swap(data[start], data[childIndex])
|
|
// start = childIndex
|
|
// if ((n - 2)/2 >= child) {
|
|
// // Left child exists.
|
|
// child = 2*child + 1
|
|
// childIndex = child + first
|
|
// if (child + 1) < n && data[childIndex] < data[childIndex+1]
|
|
// childIndex++; child++;
|
|
// }
|
|
// }
|
|
// }
|
|
// }
|
|
// }
|
|
//
|
|
static void createShiftDownFunc(OpBuilder &builder, ModuleOp module,
                                func::FuncOp func, AffineMap xPerm, uint64_t ny,
                                uint32_t nTrailingP) {
  // The only trailing parameter is the heap size n.
  assert(nTrailingP == 1);
  OpBuilder::InsertionGuard guard(builder);
  Block *entry = func.addEntryBlock();
  builder.setInsertionPointToStart(entry);

  Location loc = func.getLoc();
  Value heapSize = entry->getArguments().back();
  ValueRange params = entry->getArguments().drop_back();
  // heapBase is the index of heap element 0; node is the element to sift.
  Value heapBase = params[loIdx];
  Value node = params[hiIdx];

  // Nothing to do unless the heap holds at least two elements.
  Value two = constantIndex(builder, loc, 2);
  Value hasTwo = arith::CmpIOp::create(builder, loc, arith::CmpIPredicate::uge,
                                       heapSize, two);
  scf::IfOp sizeGuard = scf::IfOp::create(builder, loc, hasTwo, /*else=*/false);
  builder.setInsertionPointToStart(&sizeGuard.getThenRegion().front());
  // Position of the node relative to the heap base.
  Value child = arith::SubIOp::create(builder, loc, node, heapBase);

  // The node has at least one child only if (n-2)/2 >= child.
  Value lastParent = createSubTwoDividedByTwo(builder, loc, heapSize);
  Value hasChild = arith::CmpIOp::create(
      builder, loc, arith::CmpIPredicate::uge, lastParent, child);
  scf::IfOp childGuard =
      scf::IfOp::create(builder, loc, hasChild, /*else=*/false);

  builder.setInsertionPointToStart(&childGuard.getThenRegion().front());
  Value one = constantIndex(builder, loc, 1);
  // Comparison operands: two element indices followed by the x buffer(s).
  // Only the first buffer after the indices takes part in the comparison.
  constexpr uint64_t numXBuffers = 1;
  SmallVector<Value> cmpArgs{node, node};
  cmpArgs.append(params.begin() + xStartIdx,
                 params.begin() + xStartIdx + numXBuffers);

  // Emits IR that selects the larger child of relative position `parent`:
  //   left = parent * 2 + 1;  leftIdx = left + heapBase
  //   if (left + 1 < n && data[leftIdx] < data[leftIdx + 1])
  //     pick the right child instead.
  // Returns the (relative position, absolute index) pair of the chosen child.
  auto pickLargerChild = [&](Value parent) -> std::pair<Value, Value> {
    Value left = arith::ShLIOp::create(builder, loc, parent, one);
    left = arith::AddIOp::create(builder, loc, left, one);
    Value leftIdx = arith::AddIOp::create(builder, loc, left, heapBase);
    Value right = arith::AddIOp::create(builder, loc, left, one);
    Value rightInHeap = arith::CmpIOp::create(
        builder, loc, arith::CmpIPredicate::ult, right, heapSize);
    SmallVector<Type, 2> pairTypes(2, parent.getType());
    scf::IfOp outer =
        scf::IfOp::create(builder, loc, pairTypes, rightInHeap, /*else=*/true);

    // Right child exists: compare data[left] < data[right].
    builder.setInsertionPointToStart(&outer.getThenRegion().front());
    Value rightIdx = arith::AddIOp::create(builder, loc, right, heapBase);
    cmpArgs[0] = leftIdx;
    cmpArgs[1] = rightIdx;
    Value rightIsLarger =
        createInlinedLessThan(builder, loc, cmpArgs, xPerm, ny);
    scf::IfOp inner = scf::IfOp::create(builder, loc, pairTypes, rightIsLarger,
                                        /*else=*/true);
    builder.setInsertionPointToStart(&inner.getThenRegion().front());
    scf::YieldOp::create(builder, loc, ValueRange{right, rightIdx});
    builder.setInsertionPointToStart(&inner.getElseRegion().front());
    scf::YieldOp::create(builder, loc, ValueRange{left, leftIdx});
    builder.setInsertionPointAfter(inner);
    scf::YieldOp::create(builder, loc, inner.getResults());

    // Only the left child exists.
    builder.setInsertionPointToStart(&outer.getElseRegion().front());
    scf::YieldOp::create(builder, loc, ValueRange{left, leftIdx});
    builder.setInsertionPointAfter(outer);
    return {outer.getResult(0), outer.getResult(1)};
  };

  Value childIdx;
  std::tie(child, childIdx) = pickLargerChild(child);

  // Loop state is (node, child, childIdx); iterate while
  // data[node] < data[childIdx].
  SmallVector<Type, 3> stateTypes(3, child.getType());
  scf::WhileOp loop = scf::WhileOp::create(builder, loc, stateTypes,
                                           ValueRange{node, child, childIdx});

  // Loop condition.
  SmallVector<Location, 3> stateLocs(3, loc);
  Block *condBlock =
      builder.createBlock(&loop.getBefore(), {}, stateTypes, stateLocs);
  builder.setInsertionPointToEnd(condBlock);
  node = condBlock->getArgument(0);
  childIdx = condBlock->getArgument(2);
  cmpArgs[0] = node;
  cmpArgs[1] = childIdx;
  Value keepSifting = createInlinedLessThan(builder, loc, cmpArgs, xPerm, ny);
  scf::ConditionOp::create(builder, loc, keepSifting,
                           condBlock->getArguments());

  // Loop body: swap node with the larger child, then descend into it.
  Block *bodyBlock =
      builder.createBlock(&loop.getAfter(), {}, stateTypes, stateLocs);
  node = bodyBlock->getArgument(0);
  child = bodyBlock->getArgument(1);
  childIdx = bodyBlock->getArgument(2);
  SmallVector<Value> swapArgs{node, childIdx};
  swapArgs.append(params.begin() + xStartIdx, params.end());
  createSwap(builder, loc, swapArgs, xPerm, ny);
  node = childIdx;
  Value childHasChild = arith::CmpIOp::create(
      builder, loc, arith::CmpIPredicate::uge, lastParent, child);
  scf::IfOp descend = scf::IfOp::create(
      builder, loc, TypeRange{child.getType(), child.getType()}, childHasChild,
      /*else=*/true);
  builder.setInsertionPointToStart(&descend.getThenRegion().front());
  auto [nextChild, nextChildIdx] = pickLargerChild(child);
  scf::YieldOp::create(builder, loc, ValueRange{nextChild, nextChildIdx});
  // A leaf: keep the current child; the loop condition then fails because
  // data[node] == data[childIdx].
  builder.setInsertionPointToStart(&descend.getElseRegion().front());
  scf::YieldOp::create(builder, loc, ValueRange{child, childIdx});
  builder.setInsertionPointAfter(descend);
  scf::YieldOp::create(
      builder, loc,
      ValueRange{node, descend.getResult(0), descend.getResult(1)});

  builder.setInsertionPointAfter(sizeGuard);
  func::ReturnOp::create(builder, loc);
}
|
|
|
|
/// Creates a function to perform heap sort on the values in the range of index
|
|
/// [lo, hi) with the assumption hi - lo >= 2.
|
|
//
|
|
// The generate IR corresponds to this C like algorithm:
|
|
// void heapSort(lo, hi, data) {
|
|
// n = hi - lo
|
|
// for i = (n-2)/2 downto 0
|
|
// shiftDown(lo, lo+i, n)
|
|
//
|
|
// for l = n downto 2
|
|
// swap(lo, lo+l-1)
|
|
// shiftdown(lo, lo, l-1)
|
|
// }
|
|
static void createHeapSortFunc(OpBuilder &builder, ModuleOp module,
|
|
func::FuncOp func, AffineMap xPerm, uint64_t ny,
|
|
uint32_t nTrailingP) {
|
|
// Heap sort function doesn't have trailing parameters.
|
|
(void)nTrailingP;
|
|
assert(nTrailingP == 0);
|
|
OpBuilder::InsertionGuard insertionGuard(builder);
|
|
Block *entryBlock = func.addEntryBlock();
|
|
builder.setInsertionPointToStart(entryBlock);
|
|
|
|
Location loc = func.getLoc();
|
|
ValueRange args = entryBlock->getArguments();
|
|
Value lo = args[loIdx];
|
|
Value hi = args[hiIdx];
|
|
Value n = arith::SubIOp::create(builder, loc, hi, lo);
|
|
|
|
// For i = (n-2)/2 downto 0.
|
|
Value c0 = constantIndex(builder, loc, 0);
|
|
Value c1 = constantIndex(builder, loc, 1);
|
|
Value s = createSubTwoDividedByTwo(builder, loc, n);
|
|
Value up = arith::AddIOp::create(builder, loc, s, c1);
|
|
scf::ForOp forI = scf::ForOp::create(builder, loc, c0, up, c1);
|
|
builder.setInsertionPointToStart(forI.getBody());
|
|
Value i = arith::SubIOp::create(builder, loc, s, forI.getInductionVar());
|
|
Value lopi = arith::AddIOp::create(builder, loc, lo, i);
|
|
SmallVector<Value> shiftDownOperands = {lo, lopi};
|
|
shiftDownOperands.append(args.begin() + xStartIdx, args.end());
|
|
shiftDownOperands.push_back(n);
|
|
FlatSymbolRefAttr shiftDownFunc = getMangledSortHelperFunc(
|
|
builder, func, TypeRange(), kShiftDownFuncNamePrefix, xPerm, ny,
|
|
shiftDownOperands, createShiftDownFunc, /*nTrailingP=*/1);
|
|
func::CallOp::create(builder, loc, shiftDownFunc, TypeRange(),
|
|
shiftDownOperands);
|
|
|
|
builder.setInsertionPointAfter(forI);
|
|
// For l = n downto 2.
|
|
up = arith::SubIOp::create(builder, loc, n, c1);
|
|
scf::ForOp forL = scf::ForOp::create(builder, loc, c0, up, c1);
|
|
builder.setInsertionPointToStart(forL.getBody());
|
|
Value l = arith::SubIOp::create(builder, loc, n, forL.getInductionVar());
|
|
Value loplm1 = arith::AddIOp::create(builder, loc, lo, l);
|
|
loplm1 = arith::SubIOp::create(builder, loc, loplm1, c1);
|
|
SmallVector<Value> swapOperands{lo, loplm1};
|
|
swapOperands.append(args.begin() + xStartIdx, args.end());
|
|
createSwap(builder, loc, swapOperands, xPerm, ny);
|
|
shiftDownOperands[1] = lo;
|
|
shiftDownOperands[shiftDownOperands.size() - 1] =
|
|
arith::SubIOp::create(builder, loc, l, c1);
|
|
func::CallOp::create(builder, loc, shiftDownFunc, TypeRange(),
|
|
shiftDownOperands);
|
|
|
|
builder.setInsertionPointAfter(forL);
|
|
func::ReturnOp::create(builder, loc);
|
|
}
|
|
|
|
/// A helper for generating code to perform quick sort. It partitions [lo, hi),
|
|
/// recursively calls quick sort to process the smaller partition and returns
|
|
/// the bigger partition to be processed by the enclosed while-loop.
|
|
static std::pair<Value, Value>
|
|
createQuickSort(OpBuilder &builder, ModuleOp module, func::FuncOp func,
|
|
ValueRange args, AffineMap xPerm, uint64_t ny,
|
|
uint32_t nTrailingP) {
|
|
MLIRContext *context = module.getContext();
|
|
Location loc = func.getLoc();
|
|
Value lo = args[loIdx];
|
|
Value hi = args[hiIdx];
|
|
SmallVector<Type, 2> types(2, lo.getType()); // Only two types.
|
|
|
|
FlatSymbolRefAttr partitionFunc = getMangledSortHelperFunc(
|
|
builder, func, {IndexType::get(context)}, kPartitionFuncNamePrefix, xPerm,
|
|
ny, args.drop_back(nTrailingP), createPartitionFunc);
|
|
Value p = builder
|
|
.create<func::CallOp>(loc, partitionFunc,
|
|
TypeRange{IndexType::get(context)},
|
|
args.drop_back(nTrailingP))
|
|
.getResult(0);
|
|
|
|
Value lenLow = arith::SubIOp::create(builder, loc, p, lo);
|
|
Value lenHigh = arith::SubIOp::create(builder, loc, hi, p);
|
|
// Partition already sorts array with len <= 2
|
|
Value c2 = constantIndex(builder, loc, 2);
|
|
Value len = arith::SubIOp::create(builder, loc, hi, lo);
|
|
Value lenGtTwo =
|
|
arith::CmpIOp::create(builder, loc, arith::CmpIPredicate::ugt, len, c2);
|
|
scf::IfOp ifLenGtTwo =
|
|
scf::IfOp::create(builder, loc, types, lenGtTwo, /*else=*/true);
|
|
builder.setInsertionPointToStart(&ifLenGtTwo.getElseRegion().front());
|
|
// Returns an empty range to mark the entire region is fully sorted.
|
|
scf::YieldOp::create(builder, loc, ValueRange{lo, lo});
|
|
|
|
// Else len > 2, need recursion.
|
|
builder.setInsertionPointToStart(&ifLenGtTwo.getThenRegion().front());
|
|
Value cond = arith::CmpIOp::create(builder, loc, arith::CmpIPredicate::ule,
|
|
lenLow, lenHigh);
|
|
|
|
Value c0 = constantIndex(builder, loc, 0);
|
|
scf::IfOp ifOp = scf::IfOp::create(builder, loc, types, cond, /*else=*/true);
|
|
|
|
auto mayRecursion = [&](Value low, Value high, Value len) {
|
|
Value cond =
|
|
arith::CmpIOp::create(builder, loc, arith::CmpIPredicate::ne, len, c0);
|
|
scf::IfOp ifOp = scf::IfOp::create(builder, loc, cond, /*else=*/false);
|
|
builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
|
|
SmallVector<Value> operands{low, high};
|
|
operands.append(args.begin() + xStartIdx, args.end());
|
|
func::CallOp::create(builder, loc, func, operands);
|
|
builder.setInsertionPointAfter(ifOp);
|
|
};
|
|
|
|
// Recursively call quickSort to process the smaller partition and return
|
|
// the bigger partition to be processed by the enclosed while-loop.
|
|
builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
|
|
mayRecursion(lo, p, lenLow);
|
|
scf::YieldOp::create(builder, loc, ValueRange{p, hi});
|
|
|
|
builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
|
|
mayRecursion(p, hi, lenHigh);
|
|
scf::YieldOp::create(builder, loc, ValueRange{lo, p});
|
|
|
|
builder.setInsertionPointAfter(ifOp);
|
|
scf::YieldOp::create(builder, loc, ifOp.getResults());
|
|
|
|
builder.setInsertionPointAfter(ifLenGtTwo);
|
|
return std::make_pair(ifLenGtTwo.getResult(0), ifLenGtTwo.getResult(1));
|
|
}
|
|
|
|
/// Creates a function to perform insertion sort on the values in the range of
|
|
/// index [lo, hi).
|
|
//
|
|
// The generate IR corresponds to this C like algorithm:
|
|
// void insertionSort(lo, hi, data) {
|
|
// for (i = lo+1; i < hi; i++) {
|
|
// d = data[i];
|
|
// p = binarySearch(lo, i-1, data)
|
|
// for (j = 0; j > i - p; j++)
|
|
// data[i-j] = data[i-j-1]
|
|
// data[p] = d
|
|
// }
|
|
// }
|
|
static void createSortStableFunc(OpBuilder &builder, ModuleOp module,
|
|
func::FuncOp func, AffineMap xPerm,
|
|
uint64_t ny, uint32_t nTrailingP) {
|
|
// Stable sort function doesn't use trailing parameters.
|
|
(void)nTrailingP;
|
|
assert(nTrailingP == 0);
|
|
OpBuilder::InsertionGuard insertionGuard(builder);
|
|
Block *entryBlock = func.addEntryBlock();
|
|
builder.setInsertionPointToStart(entryBlock);
|
|
|
|
MLIRContext *context = module.getContext();
|
|
Location loc = func.getLoc();
|
|
ValueRange args = entryBlock->getArguments();
|
|
Value c1 = constantIndex(builder, loc, 1);
|
|
Value lo = args[loIdx];
|
|
Value hi = args[hiIdx];
|
|
Value lop1 = arith::AddIOp::create(builder, loc, lo, c1);
|
|
|
|
// Start the outer for-stmt with induction variable i.
|
|
scf::ForOp forOpI = scf::ForOp::create(builder, loc, lop1, hi, c1);
|
|
builder.setInsertionPointToStart(forOpI.getBody());
|
|
Value i = forOpI.getInductionVar();
|
|
|
|
// Binary search to find the insertion point p.
|
|
SmallVector<Value> operands{lo, i};
|
|
operands.append(args.begin() + xStartIdx, args.end());
|
|
FlatSymbolRefAttr searchFunc = getMangledSortHelperFunc(
|
|
builder, func, {IndexType::get(context)}, kBinarySearchFuncNamePrefix,
|
|
xPerm, ny, operands, createBinarySearchFunc);
|
|
Value p = builder
|
|
.create<func::CallOp>(loc, searchFunc, TypeRange{c1.getType()},
|
|
operands)
|
|
.getResult(0);
|
|
|
|
// Move the value at data[i] to a temporary location.
|
|
operands[0] = operands[1] = i;
|
|
SmallVector<Value> d;
|
|
forEachIJPairInAllBuffers(
|
|
builder, loc, operands, xPerm, ny,
|
|
[&](uint64_t unused, Value i, Value unused2, Value buffer) {
|
|
d.push_back(memref::LoadOp::create(builder, loc, buffer, i));
|
|
});
|
|
|
|
// Start the inner for-stmt with induction variable j, for moving data[p..i)
|
|
// to data[p+1..i+1).
|
|
Value imp = arith::SubIOp::create(builder, loc, i, p);
|
|
Value c0 = constantIndex(builder, loc, 0);
|
|
scf::ForOp forOpJ = scf::ForOp::create(builder, loc, c0, imp, c1);
|
|
builder.setInsertionPointToStart(forOpJ.getBody());
|
|
Value j = forOpJ.getInductionVar();
|
|
Value imj = arith::SubIOp::create(builder, loc, i, j);
|
|
operands[1] = imj;
|
|
operands[0] = arith::SubIOp::create(builder, loc, imj, c1);
|
|
forEachIJPairInAllBuffers(
|
|
builder, loc, operands, xPerm, ny,
|
|
[&](uint64_t unused, Value imjm1, Value imj, Value buffer) {
|
|
Value t = memref::LoadOp::create(builder, loc, buffer, imjm1);
|
|
memref::StoreOp::create(builder, loc, t, buffer, imj);
|
|
});
|
|
|
|
// Store the value at data[i] to data[p].
|
|
builder.setInsertionPointAfter(forOpJ);
|
|
operands[0] = operands[1] = p;
|
|
forEachIJPairInAllBuffers(
|
|
builder, loc, operands, xPerm, ny,
|
|
[&](uint64_t k, Value p, Value usused, Value buffer) {
|
|
memref::StoreOp::create(builder, loc, d[k], buffer, p);
|
|
});
|
|
|
|
builder.setInsertionPointAfter(forOpI);
|
|
func::ReturnOp::create(builder, loc);
|
|
}
|
|
|
|
/// Creates a function to perform quick sort or a hybrid quick sort on the
|
|
/// values in the range of index [lo, hi).
|
|
//
|
|
//
|
|
// When nTrailingP == 0, the generated IR corresponds to this C like algorithm:
|
|
// void quickSort(lo, hi, data) {
|
|
// while (lo + 1 < hi) {
|
|
// p = partition(low, high, data);
|
|
// if (len(lo, p) < len(p+1, hi)) {
|
|
// quickSort(lo, p, data);
|
|
// lo = p+1;
|
|
// } else {
|
|
// quickSort(p + 1, hi, data);
|
|
// hi = p;
|
|
// }
|
|
// }
|
|
// }
|
|
//
|
|
// When nTrailingP == 1, the generated IR corresponds to this C like algorithm:
|
|
// void hybridQuickSort(lo, hi, data, depthLimit) {
|
|
// while (lo + 1 < hi) {
|
|
// len = hi - lo;
|
|
// if (len <= limit) {
|
|
// insertionSort(lo, hi, data);
|
|
// } else {
|
|
// depthLimit --;
|
|
// if (depthLimit <= 0) {
|
|
// heapSort(lo, hi, data);
|
|
// } else {
|
|
// p = partition(low, high, data);
|
|
// if (len(lo, p) < len(p+1, hi)) {
|
|
// quickSort(lo, p, data, depthLimit);
|
|
// lo = p+1;
|
|
// } else {
|
|
// quickSort(p + 1, hi, data, depthLimit);
|
|
// hi = p;
|
|
// }
|
|
// }
|
|
// }
|
|
// }
|
|
// }
|
|
//
|
|
static void createQuickSortFunc(OpBuilder &builder, ModuleOp module,
|
|
func::FuncOp func, AffineMap xPerm, uint64_t ny,
|
|
uint32_t nTrailingP) {
|
|
assert(nTrailingP == 1 || nTrailingP == 0);
|
|
bool isHybrid = (nTrailingP == 1);
|
|
OpBuilder::InsertionGuard insertionGuard(builder);
|
|
Block *entryBlock = func.addEntryBlock();
|
|
builder.setInsertionPointToStart(entryBlock);
|
|
|
|
Location loc = func.getLoc();
|
|
SmallVector<Value> args;
|
|
args.append(entryBlock->getArguments().begin(),
|
|
entryBlock->getArguments().end());
|
|
Value lo = args[loIdx];
|
|
Value hi = args[hiIdx];
|
|
SmallVector<Type, 2> types(2, lo.getType()); // Only two types.
|
|
scf::WhileOp whileOp =
|
|
scf::WhileOp::create(builder, loc, types, SmallVector<Value, 2>{lo, hi});
|
|
|
|
// The before-region of the WhileOp.
|
|
Block *before =
|
|
builder.createBlock(&whileOp.getBefore(), {}, types, {loc, loc});
|
|
builder.setInsertionPointToEnd(before);
|
|
lo = before->getArgument(0);
|
|
hi = before->getArgument(1);
|
|
Value loP1 =
|
|
arith::AddIOp::create(builder, loc, lo, constantIndex(builder, loc, 1));
|
|
Value needSort =
|
|
arith::CmpIOp::create(builder, loc, arith::CmpIPredicate::ult, loP1, hi);
|
|
scf::ConditionOp::create(builder, loc, needSort, before->getArguments());
|
|
|
|
// The after-region of the WhileOp.
|
|
Block *after =
|
|
builder.createBlock(&whileOp.getAfter(), {}, types, {loc, loc});
|
|
builder.setInsertionPointToEnd(after);
|
|
lo = after->getArgument(0);
|
|
hi = after->getArgument(1);
|
|
args[0] = lo;
|
|
args[1] = hi;
|
|
|
|
if (isHybrid) {
|
|
Value len = arith::SubIOp::create(builder, loc, hi, lo);
|
|
Value lenLimit = constantIndex(builder, loc, 30);
|
|
Value lenCond = arith::CmpIOp::create(
|
|
builder, loc, arith::CmpIPredicate::ule, len, lenLimit);
|
|
scf::IfOp lenIf =
|
|
scf::IfOp::create(builder, loc, types, lenCond, /*else=*/true);
|
|
|
|
// When len <= limit.
|
|
builder.setInsertionPointToStart(&lenIf.getThenRegion().front());
|
|
FlatSymbolRefAttr insertionSortFunc = getMangledSortHelperFunc(
|
|
builder, func, TypeRange(), kSortStableFuncNamePrefix, xPerm, ny,
|
|
ValueRange(args).drop_back(nTrailingP), createSortStableFunc);
|
|
func::CallOp::create(builder, loc, insertionSortFunc, TypeRange(),
|
|
ValueRange(args).drop_back(nTrailingP));
|
|
scf::YieldOp::create(builder, loc, ValueRange{lo, lo});
|
|
|
|
// When len > limit.
|
|
builder.setInsertionPointToStart(&lenIf.getElseRegion().front());
|
|
Value depthLimit = args.back();
|
|
depthLimit = arith::SubIOp::create(builder, loc, depthLimit,
|
|
constantI64(builder, loc, 1));
|
|
Value depthCond =
|
|
arith::CmpIOp::create(builder, loc, arith::CmpIPredicate::ule,
|
|
depthLimit, constantI64(builder, loc, 0));
|
|
scf::IfOp depthIf =
|
|
scf::IfOp::create(builder, loc, types, depthCond, /*else=*/true);
|
|
|
|
// When depth exceeds limit.
|
|
builder.setInsertionPointToStart(&depthIf.getThenRegion().front());
|
|
FlatSymbolRefAttr heapSortFunc = getMangledSortHelperFunc(
|
|
builder, func, TypeRange(), kHeapSortFuncNamePrefix, xPerm, ny,
|
|
ValueRange(args).drop_back(nTrailingP), createHeapSortFunc);
|
|
func::CallOp::create(builder, loc, heapSortFunc, TypeRange(),
|
|
ValueRange(args).drop_back(nTrailingP));
|
|
scf::YieldOp::create(builder, loc, ValueRange{lo, lo});
|
|
|
|
// When depth doesn't exceed limit.
|
|
builder.setInsertionPointToStart(&depthIf.getElseRegion().front());
|
|
args.back() = depthLimit;
|
|
std::tie(lo, hi) =
|
|
createQuickSort(builder, module, func, args, xPerm, ny, nTrailingP);
|
|
scf::YieldOp::create(builder, loc, ValueRange{lo, hi});
|
|
|
|
builder.setInsertionPointAfter(depthIf);
|
|
lo = depthIf.getResult(0);
|
|
hi = depthIf.getResult(1);
|
|
scf::YieldOp::create(builder, loc, ValueRange{lo, hi});
|
|
|
|
builder.setInsertionPointAfter(lenIf);
|
|
lo = lenIf.getResult(0);
|
|
hi = lenIf.getResult(1);
|
|
} else {
|
|
std::tie(lo, hi) =
|
|
createQuickSort(builder, module, func, args, xPerm, ny, nTrailingP);
|
|
}
|
|
|
|
// New [lo, hi) for the next while-loop iteration.
|
|
scf::YieldOp::create(builder, loc, ValueRange{lo, hi});
|
|
|
|
// After the while-loop.
|
|
builder.setInsertionPointAfter(whileOp);
|
|
func::ReturnOp::create(builder, loc);
|
|
}
|
|
|
|
/// Implements the rewriting for operator sort and sort_coo.
|
|
template <typename OpTy>
|
|
LogicalResult matchAndRewriteSortOp(OpTy op, ValueRange xys, AffineMap xPerm,
|
|
uint64_t ny, PatternRewriter &rewriter) {
|
|
Location loc = op.getLoc();
|
|
SmallVector<Value> operands{constantIndex(rewriter, loc, 0), op.getN()};
|
|
|
|
// Convert `values` to have dynamic shape and append them to `operands`.
|
|
for (Value v : xys) {
|
|
auto mtp = getMemRefType(v);
|
|
if (!mtp.isDynamicDim(0)) {
|
|
auto newMtp =
|
|
MemRefType::get({ShapedType::kDynamic}, mtp.getElementType());
|
|
v = memref::CastOp::create(rewriter, loc, newMtp, v);
|
|
}
|
|
operands.push_back(v);
|
|
}
|
|
|
|
auto insertPoint = op->template getParentOfType<func::FuncOp>();
|
|
if (!insertPoint)
|
|
return failure();
|
|
|
|
SmallString<32> funcName;
|
|
FuncGeneratorType funcGenerator;
|
|
uint32_t nTrailingP = 0;
|
|
switch (op.getAlgorithm()) {
|
|
case SparseTensorSortKind::HybridQuickSort: {
|
|
funcName = kHybridQuickSortFuncNamePrefix;
|
|
funcGenerator = createQuickSortFunc;
|
|
nTrailingP = 1;
|
|
// As a heuristics, set depthLimit = 2 * log2(n).
|
|
Value lo = operands[loIdx];
|
|
Value hi = operands[hiIdx];
|
|
Value len = arith::IndexCastOp::create(
|
|
rewriter, loc, rewriter.getI64Type(),
|
|
arith::SubIOp::create(rewriter, loc, hi, lo));
|
|
Value depthLimit = arith::SubIOp::create(
|
|
rewriter, loc, constantI64(rewriter, loc, 64),
|
|
math::CountLeadingZerosOp::create(rewriter, loc, len));
|
|
operands.push_back(depthLimit);
|
|
break;
|
|
}
|
|
case SparseTensorSortKind::QuickSort:
|
|
funcName = kQuickSortFuncNamePrefix;
|
|
funcGenerator = createQuickSortFunc;
|
|
break;
|
|
case SparseTensorSortKind::InsertionSortStable:
|
|
funcName = kSortStableFuncNamePrefix;
|
|
funcGenerator = createSortStableFunc;
|
|
break;
|
|
case SparseTensorSortKind::HeapSort:
|
|
funcName = kHeapSortFuncNamePrefix;
|
|
funcGenerator = createHeapSortFunc;
|
|
break;
|
|
}
|
|
|
|
FlatSymbolRefAttr func =
|
|
getMangledSortHelperFunc(rewriter, insertPoint, TypeRange(), funcName,
|
|
xPerm, ny, operands, funcGenerator, nTrailingP);
|
|
rewriter.replaceOpWithNewOp<func::CallOp>(op, func, TypeRange(), operands);
|
|
return success();
|
|
}
|
|
|
|
//===---------------------------------------------------------------------===//
|
|
// The actual sparse buffer rewriting rules.
|
|
//===---------------------------------------------------------------------===//
|
|
|
|
namespace {
|
|
/// Sparse rewriting rule for the push_back operator.
|
|
struct PushBackRewriter : OpRewritePattern<PushBackOp> {
|
|
public:
|
|
using OpRewritePattern<PushBackOp>::OpRewritePattern;
|
|
PushBackRewriter(MLIRContext *context, bool enableInit)
|
|
: OpRewritePattern(context), enableBufferInitialization(enableInit) {}
|
|
LogicalResult matchAndRewrite(PushBackOp op,
|
|
PatternRewriter &rewriter) const override {
|
|
// Rewrite push_back(buffer, value, n) to:
|
|
// new_size = size(buffer) + n
|
|
// if (new_size > capacity(buffer))
|
|
// while new_size > new_capacity
|
|
// new_capacity = new_capacity*2
|
|
// new_buffer = realloc(buffer, new_capacity)
|
|
// buffer = new_buffer
|
|
// subBuffer = subviewof(buffer)
|
|
// linalg.fill subBuffer value
|
|
//
|
|
// size(buffer) += n
|
|
//
|
|
// The capacity check is skipped when the attribute inbounds is presented.
|
|
Location loc = op->getLoc();
|
|
Value c0 = constantIndex(rewriter, loc, 0);
|
|
Value buffer = op.getInBuffer();
|
|
Value capacity = memref::DimOp::create(rewriter, loc, buffer, c0);
|
|
Value size = op.getCurSize();
|
|
Value value = op.getValue();
|
|
|
|
Value n = op.getN() ? op.getN() : constantIndex(rewriter, loc, 1);
|
|
Value newSize = arith::AddIOp::create(rewriter, loc, size, n);
|
|
auto nValue = n.getDefiningOp<arith::ConstantIndexOp>();
|
|
bool nIsOne = (nValue && nValue.value() == 1);
|
|
|
|
if (!op.getInbounds()) {
|
|
Value cond = arith::CmpIOp::create(
|
|
rewriter, loc, arith::CmpIPredicate::ugt, newSize, capacity);
|
|
|
|
Value c2 = constantIndex(rewriter, loc, 2);
|
|
auto bufferType =
|
|
MemRefType::get({ShapedType::kDynamic}, value.getType());
|
|
scf::IfOp ifOp = scf::IfOp::create(rewriter, loc, bufferType, cond,
|
|
/*else=*/true);
|
|
// True branch.
|
|
rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());
|
|
if (nIsOne) {
|
|
capacity = arith::MulIOp::create(rewriter, loc, capacity, c2);
|
|
} else {
|
|
// Use a do-while loop to calculate the new capacity as follows:
|
|
// do { new_capacity *= 2 } while (size > new_capacity)
|
|
scf::WhileOp whileOp =
|
|
scf::WhileOp::create(rewriter, loc, capacity.getType(), capacity);
|
|
|
|
// The before-region of the WhileOp.
|
|
Block *before = rewriter.createBlock(&whileOp.getBefore(), {},
|
|
{capacity.getType()}, {loc});
|
|
rewriter.setInsertionPointToEnd(before);
|
|
|
|
capacity =
|
|
arith::MulIOp::create(rewriter, loc, before->getArgument(0), c2);
|
|
cond = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::ugt,
|
|
newSize, capacity);
|
|
scf::ConditionOp::create(rewriter, loc, cond, ValueRange{capacity});
|
|
// The after-region of the WhileOp.
|
|
Block *after = rewriter.createBlock(&whileOp.getAfter(), {},
|
|
{capacity.getType()}, {loc});
|
|
rewriter.setInsertionPointToEnd(after);
|
|
scf::YieldOp::create(rewriter, loc, after->getArguments());
|
|
|
|
rewriter.setInsertionPointAfter(whileOp);
|
|
capacity = whileOp.getResult(0);
|
|
}
|
|
|
|
Value newBuffer = memref::ReallocOp::create(rewriter, loc, bufferType,
|
|
buffer, capacity);
|
|
if (enableBufferInitialization) {
|
|
Value fillSize =
|
|
arith::SubIOp::create(rewriter, loc, capacity, newSize);
|
|
Value fillValue = constantZero(rewriter, loc, value.getType());
|
|
Value subBuffer = memref::SubViewOp::create(
|
|
rewriter, loc, newBuffer, /*offset=*/ValueRange{newSize},
|
|
/*size=*/ValueRange{fillSize},
|
|
/*step=*/ValueRange{constantIndex(rewriter, loc, 1)});
|
|
linalg::FillOp::create(rewriter, loc, fillValue, subBuffer);
|
|
}
|
|
scf::YieldOp::create(rewriter, loc, newBuffer);
|
|
|
|
// False branch.
|
|
rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front());
|
|
scf::YieldOp::create(rewriter, loc, buffer);
|
|
|
|
// Prepare for adding the value to the end of the buffer.
|
|
rewriter.setInsertionPointAfter(ifOp);
|
|
buffer = ifOp.getResult(0);
|
|
}
|
|
|
|
// Add the value to the end of the buffer.
|
|
if (nIsOne) {
|
|
memref::StoreOp::create(rewriter, loc, value, buffer, size);
|
|
} else {
|
|
Value subBuffer = memref::SubViewOp::create(
|
|
rewriter, loc, buffer, /*offset=*/ValueRange{size},
|
|
/*size=*/ValueRange{n},
|
|
/*step=*/ValueRange{constantIndex(rewriter, loc, 1)});
|
|
linalg::FillOp::create(rewriter, loc, value, subBuffer);
|
|
}
|
|
|
|
// Update the buffer size.
|
|
rewriter.replaceOp(op, {buffer, newSize});
|
|
return success();
|
|
}
|
|
|
|
private:
|
|
bool enableBufferInitialization;
|
|
};
|
|
|
|
/// Sparse rewriting rule for the sort_coo operator.
|
|
struct SortRewriter : public OpRewritePattern<SortOp> {
|
|
public:
|
|
using OpRewritePattern<SortOp>::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(SortOp op,
|
|
PatternRewriter &rewriter) const override {
|
|
SmallVector<Value> xys;
|
|
xys.push_back(op.getXy());
|
|
xys.append(op.getYs().begin(), op.getYs().end());
|
|
|
|
auto xPerm = op.getPermMap();
|
|
uint64_t ny = 0;
|
|
if (auto nyAttr = op.getNyAttr())
|
|
ny = nyAttr.getInt();
|
|
|
|
return matchAndRewriteSortOp(op, xys, xPerm, ny, rewriter);
|
|
}
|
|
};
|
|
|
|
} // namespace
|
|
|
|
//===---------------------------------------------------------------------===//
|
|
// Methods that add patterns described in this file to a pattern list.
|
|
//===---------------------------------------------------------------------===//
|
|
|
|
void mlir::populateSparseBufferRewriting(RewritePatternSet &patterns,
                                         bool enableBufferInitialization) {
  MLIRContext *ctx = patterns.getContext();
  // Only the push_back lowering depends on the buffer-initialization flag.
  patterns.add<PushBackRewriter>(ctx, enableBufferInitialization);
  patterns.add<SortRewriter>(ctx);
}
|