179 lines
6.4 KiB
C++
179 lines
6.4 KiB
C++
//===- TosaToSCF.cpp - Lowering Tosa to SCF Dialect -----------------------===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
//
|
|
// These rewriters lower from the Tosa to the SCF dialect.
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/Conversion/TosaToSCF/TosaToSCF.h"
|
|
#include "mlir/Dialect/SCF/IR/SCF.h"
|
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
|
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
|
|
#include "mlir/IR/PatternMatch.h"
|
|
|
|
using namespace mlir;
|
|
using namespace tosa;
|
|
|
|
static void inlineIfCase(Region &srcRegion, Region &dstRegion,
|
|
OperandRange operands, PatternRewriter &rewriter) {
|
|
rewriter.cloneRegionBefore(srcRegion, &dstRegion.front());
|
|
rewriter.eraseBlock(&dstRegion.back());
|
|
|
|
Block *headBlock = &dstRegion.front();
|
|
for (auto it : llvm::zip(headBlock->getArguments(), operands))
|
|
std::get<0>(it).replaceAllUsesWith(std::get<1>(it));
|
|
|
|
auto yield = cast<YieldOp>(headBlock->getTerminator());
|
|
rewriter.setInsertionPoint(yield);
|
|
scf::YieldOp::create(rewriter, yield.getLoc(), yield.getInputs());
|
|
rewriter.eraseOp(yield);
|
|
|
|
headBlock->eraseArguments(0, headBlock->getNumArguments());
|
|
}
|
|
|
|
static void inlineWhileCase(Region &srcRegion, Region &dstRegion,
|
|
PatternRewriter &rewriter, bool isCond) {
|
|
rewriter.cloneRegionBefore(srcRegion, &dstRegion.back());
|
|
rewriter.eraseBlock(&dstRegion.back());
|
|
|
|
Block *headBlock = &dstRegion.front();
|
|
|
|
auto yield = cast<YieldOp>(headBlock->getTerminator());
|
|
rewriter.setInsertionPoint(yield);
|
|
if (isCond) {
|
|
auto condition = tensor::ExtractOp::create(rewriter, yield.getLoc(),
|
|
yield.getOperand(0));
|
|
scf::ConditionOp::create(rewriter, yield.getLoc(), condition,
|
|
headBlock->getArguments());
|
|
} else {
|
|
rewriter.setInsertionPoint(yield);
|
|
scf::YieldOp::create(rewriter, yield.getLoc(), yield.getInputs());
|
|
}
|
|
rewriter.eraseOp(yield);
|
|
}
|
|
|
|
namespace {
|
|
|
|
class IfOpConverter : public OpRewritePattern<tosa::IfOp> {
|
|
public:
|
|
using OpRewritePattern<tosa::IfOp>::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(tosa::IfOp op,
|
|
PatternRewriter &rewriter) const final {
|
|
auto condition =
|
|
tensor::ExtractOp::create(rewriter, op.getLoc(), op.getCondition());
|
|
auto newIf = scf::IfOp::create(rewriter, op.getLoc(), op.getResultTypes(),
|
|
condition, true);
|
|
|
|
inlineIfCase(op.getThenGraph(), newIf.getThenRegion(), op.getInputList(),
|
|
rewriter);
|
|
inlineIfCase(op.getElseGraph(), newIf.getElseRegion(), op.getInputList(),
|
|
rewriter);
|
|
|
|
rewriter.replaceOp(op, newIf.getResults());
|
|
return success();
|
|
}
|
|
};
|
|
|
|
class ScatterOpConverter : public OpRewritePattern<tosa::ScatterOp> {
|
|
static Value createTensorDim(OpBuilder &builder, Location loc, Value tensor,
|
|
int64_t dim) {
|
|
return builder.createOrFold<tensor::DimOp>(loc, tensor, dim);
|
|
}
|
|
|
|
static Value createIndexConst(OpBuilder &builder, Location loc,
|
|
int64_t value) {
|
|
return arith::ConstantIndexOp::create(builder, loc, value);
|
|
}
|
|
|
|
public:
|
|
using OpRewritePattern<tosa::ScatterOp>::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(tosa::ScatterOp scatter,
|
|
PatternRewriter &rewriter) const final {
|
|
auto valuesIn = scatter.getValuesIn();
|
|
auto indices = scatter.getIndices();
|
|
auto input = scatter.getInput();
|
|
auto loc = scatter.getLoc();
|
|
|
|
// N, W, C are chosen to match the TOSA spec
|
|
auto dimN = createTensorDim(rewriter, loc, input, 0);
|
|
auto dimW = createTensorDim(rewriter, loc, input, 1);
|
|
auto dimC = createTensorDim(rewriter, loc, input, 2);
|
|
|
|
auto zero = createIndexConst(rewriter, loc, 0);
|
|
auto one = createIndexConst(rewriter, loc, 1);
|
|
|
|
// Loop bounds
|
|
auto lbs = llvm::SmallVector<Value>(2, zero);
|
|
auto steps = llvm::SmallVector<Value>(2, one);
|
|
auto ubs = llvm::SmallVector<Value>{{dimN, dimW}};
|
|
|
|
auto buildBody = [&](OpBuilder &builder, Location loc, ValueRange ivs,
|
|
ValueRange args) -> scf::ValueVector {
|
|
auto n = ivs[0];
|
|
|
|
// Read the index and cast it to index type
|
|
auto index = tensor::ExtractOp::create(builder, loc, indices, ivs);
|
|
auto castIndex = arith::IndexCastOp::create(
|
|
builder, loc, builder.getIndexType(), index);
|
|
|
|
// Offset, sizes, and strides for the input tensor
|
|
auto inputOffset = llvm::to_vector(ivs);
|
|
inputOffset.push_back(zero);
|
|
|
|
llvm::SmallVector<Value> sizes = {one, one, dimC};
|
|
llvm::SmallVector<Value> strides = {one, one, one};
|
|
|
|
auto slice = tensor::ExtractSliceOp::create(builder, loc, input,
|
|
inputOffset, sizes, strides);
|
|
|
|
// Insert the slice into the output accumulator tensor.
|
|
llvm::SmallVector<Value> outputOffset = {n, castIndex, zero};
|
|
auto updated = tensor::InsertSliceOp::create(
|
|
builder, loc, slice, args[0], outputOffset, sizes, strides);
|
|
|
|
return {updated};
|
|
};
|
|
|
|
auto loops = scf::buildLoopNest(rewriter, loc, lbs, ubs, steps,
|
|
ValueRange{valuesIn}, buildBody);
|
|
rewriter.replaceOp(scatter, loops.results);
|
|
|
|
return success();
|
|
}
|
|
};
|
|
|
|
class WhileOpConverter : public OpRewritePattern<tosa::WhileOp> {
|
|
public:
|
|
using OpRewritePattern<tosa::WhileOp>::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(tosa::WhileOp op,
|
|
PatternRewriter &rewriter) const final {
|
|
auto newWhile = scf::WhileOp::create(
|
|
rewriter, op.getLoc(), op.getResultTypes(), op.getInputList());
|
|
rewriter.createBlock(&newWhile.getBefore());
|
|
rewriter.createBlock(&newWhile.getAfter());
|
|
|
|
inlineWhileCase(op.getCondGraph(), newWhile.getBefore(), rewriter, true);
|
|
inlineWhileCase(op.getBodyGraph(), newWhile.getAfter(), rewriter, false);
|
|
|
|
rewriter.replaceOp(op, newWhile.getResults());
|
|
|
|
return success();
|
|
}
|
|
};
|
|
|
|
} // namespace
|
|
|
|
void mlir::tosa::populateTosaToSCFConversionPatterns(
|
|
RewritePatternSet *patterns) {
|
|
patterns->add<IfOpConverter, ScatterOpConverter, WhileOpConverter>(
|
|
patterns->getContext());
|
|
}
|