Artem Gindinson d03f30fb52
[mlir][TOSA] restore unrealized casts when lowering rescale ops (#141096)
Along with the changes to rescale op attributes, commit 7208649 dropped
the builtin casts between signed and signless types. However, explicitly
unsigned types are still legal input and output values from the TOSA IR
perspective.

The change adds back the casts when the unsigned<->signless semantics
are explicit in the underlying tensor types. This prevents the
conversion routine from trying to generate illegal `arith` casts that
are constrained to signless types. Whether the `arith` casts themselves
are signed or unsigned should still depend on the rescale's `*_unsigned`
attribute values.

---------

Signed-off-by: Artem Gindinson <gindinson@roofline.ai>
2025-05-26 12:52:19 +01:00

2912 lines
120 KiB
C++

//===- TosaToLinalg.cpp - Lowering Tosa to Linalg Dialect -----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// These rewriters lower from the Tosa to the Linalg dialect.
//
//===----------------------------------------------------------------------===//
#include "mlir/Conversion/TosaToLinalg/TosaToLinalg.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Index/IR/IndexOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/Sequence.h"
#include <numeric>
#include <type_traits>
using namespace mlir;
using namespace mlir::tosa;
// Helper function to materialize the semantically correct compare and select
// operations given a binary operation with a specific NaN propagation mode.
//
// In the case of "PROPAGATE" semantics no compare and selection is required and
// this function does nothing.
//
// In the case of "IGNORE" semantics this function materializes a comparison of
// the current operands to the op which will return true for any NaN
// argument and then selects between the non-NaN operation argument and the
// calculated result based on whether the lhs or rhs is NaN or not. In pseudo
// code:
//
// In the case that the op is operating on non floating point types we ignore
// the attribute completely, this is consistent with the TOSA spec which has
// the following wording: "This attribute is ignored by non floating-point
// types."
//
// binary<op>(lhs, rhs):
// result = op(lhs, rhs)
// if lhs == NaN return rhs
// if rhs == NaN return lhs
// return result
template <typename OpTy>
static Value
materializeBinaryNanCheckIfRequired(OpTy op, PatternRewriter &rewriter,
                                    Value lhs, Value rhs, Value result) {
  // Non floating-point element types ignore the NaN-mode attribute entirely.
  if (!isa<FloatType>(getElementTypeOrSelf(lhs)))
    return result;

  // Under "PROPAGATE" semantics the raw op result is already correct.
  if (op.getNanMode() == "PROPAGATE")
    return result;

  const Location loc = op.getLoc();
  // A NaN value compares unordered against itself, so UNO(x, x) detects NaN.
  Value lhsIsNaN = rewriter.create<arith::CmpFOp>(
      loc, arith::CmpFPredicate::UNO, lhs, lhs);
  Value rhsIsNaN = rewriter.create<arith::CmpFOp>(
      loc, arith::CmpFPredicate::UNO, rhs, rhs);
  // If lhs is NaN, yield rhs; then, if rhs is NaN, yield lhs instead.
  Value pickRhsOrResult =
      rewriter.create<arith::SelectOp>(loc, lhsIsNaN, rhs, result);
  return rewriter.create<arith::SelectOp>(loc, rhsIsNaN, lhs, pickRhsOrResult);
}
// Emit the scalar computation implementing the given TOSA elementwise op
// inside a 'linalg.generic' body.
//
// `args` holds the scalar block arguments corresponding to the op's operands
// and `resultTypes` the expected scalar result type(s). Returns the computed
// scalar Value on success; on failure, records a match-failure on `rewriter`
// and returns nullptr.
static Value createLinalgBodyCalculationForElementwiseOp(
    Operation *op, ValueRange args, ArrayRef<Type> resultTypes,
    ConversionPatternRewriter &rewriter) {
  Location loc = op->getLoc();
  auto elementTy =
      cast<ShapedType>(op->getOperand(0).getType()).getElementType();

  // tosa::AbsOp
  if (isa<tosa::AbsOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<math::AbsFOp>(loc, resultTypes, args);

  if (isa<tosa::AbsOp>(op) && isa<IntegerType>(elementTy)) {
    // |x| computed as max(x, 0 - x).
    auto zero = rewriter.create<arith::ConstantOp>(
        loc, rewriter.getZeroAttr(elementTy));
    auto neg = rewriter.create<arith::SubIOp>(loc, zero, args[0]);
    return rewriter.create<arith::MaxSIOp>(loc, args[0], neg);
  }

  // tosa::AddOp
  if (isa<tosa::AddOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<arith::AddFOp>(loc, resultTypes, args);

  if (isa<tosa::AddOp>(op) && isa<IntegerType>(elementTy))
    return rewriter.create<arith::AddIOp>(loc, resultTypes, args);

  // tosa::SubOp
  if (isa<tosa::SubOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<arith::SubFOp>(loc, resultTypes, args);

  if (isa<tosa::SubOp>(op) && isa<IntegerType>(elementTy))
    return rewriter.create<arith::SubIOp>(loc, resultTypes, args);

  // tosa::IntDivOp
  if (isa<tosa::IntDivOp>(op) && isa<IntegerType>(elementTy))
    return rewriter.create<arith::DivSIOp>(loc, resultTypes, args);

  // tosa::ReciprocalOp
  if (isa<tosa::ReciprocalOp>(op) && isa<FloatType>(elementTy)) {
    auto one =
        rewriter.create<arith::ConstantOp>(loc, FloatAttr::get(elementTy, 1));
    return rewriter.create<arith::DivFOp>(loc, resultTypes, one, args[0]);
  }

  // tosa::MulOp
  if (isa<tosa::MulOp>(op)) {
    // The shift operand must be a compile-time constant.
    auto shiftVal = cast<tosa::MulOp>(op).getShift();
    DenseElementsAttr shiftElem;
    if (!matchPattern(shiftVal, m_Constant(&shiftElem))) {
      (void)rewriter.notifyMatchFailure(op, "shift value of mul not found");
      return nullptr;
    }

    int32_t shift = shiftElem.getValues<IntegerAttr>()[0].getInt();

    if (isa<FloatType>(elementTy)) {
      if (shift != 0) {
        (void)rewriter.notifyMatchFailure(op,
                                          "Cannot have shift value for float");
        return nullptr;
      }
      return rewriter.create<arith::MulFOp>(loc, resultTypes, args[0], args[1]);
    }

    if (isa<IntegerType>(elementTy)) {
      Value a = args[0];
      Value b = args[1];

      if (shift > 0) {
        // Shifted integer multiplies are evaluated in 32 bits through
        // tosa.apply_scale and truncated back afterwards.
        auto shiftConst =
            rewriter.create<arith::ConstantIntOp>(loc, shift, /*bitwidth=*/8);
        if (!a.getType().isInteger(32))
          a = rewriter.create<arith::ExtSIOp>(loc, rewriter.getI32Type(), a);

        if (!b.getType().isInteger(32))
          b = rewriter.create<arith::ExtSIOp>(loc, rewriter.getI32Type(), b);

        auto result = rewriter.create<tosa::ApplyScaleOp>(
            loc, rewriter.getI32Type(), a, b, shiftConst,
            rewriter.getStringAttr("SINGLE_ROUND"));

        if (elementTy.isInteger(32))
          return result;

        return rewriter.create<arith::TruncIOp>(loc, elementTy, result);
      }

      // No shift: widen the narrower operand to the result width and multiply.
      int aWidth = a.getType().getIntOrFloatBitWidth();
      int bWidth = b.getType().getIntOrFloatBitWidth();
      int cWidth = resultTypes[0].getIntOrFloatBitWidth();

      if (aWidth < cWidth)
        a = rewriter.create<arith::ExtSIOp>(loc, resultTypes[0], a);
      if (bWidth < cWidth)
        b = rewriter.create<arith::ExtSIOp>(loc, resultTypes[0], b);

      return rewriter.create<arith::MulIOp>(loc, resultTypes, a, b);
    }
  }

  // tosa::NegateOp
  if (isa<tosa::NegateOp>(op)) {
    auto negate = cast<tosa::NegateOp>(op);

    FailureOr<int64_t> maybeInZp = negate.getInput1ZeroPoint();
    if (failed(maybeInZp)) {
      (void)rewriter.notifyMatchFailure(
          op, "input1 zero point cannot be statically determined");
      return nullptr;
    }

    FailureOr<int64_t> maybeOutZp = negate.getOutputZeroPoint();
    if (failed(maybeOutZp)) {
      (void)rewriter.notifyMatchFailure(
          op, "output zero point cannot be statically determined");
      return nullptr;
    }

    int64_t inZp = *maybeInZp;
    int64_t outZp = *maybeOutZp;

    if (isa<FloatType>(elementTy))
      return rewriter.create<arith::NegFOp>(loc, resultTypes, args[0]);
    if (isa<IntegerType>(elementTy)) {
      if (!inZp && !outZp) {
        // Both zero points are zero: plain 0 - x suffices.
        auto constant = rewriter.create<arith::ConstantOp>(
            loc, IntegerAttr::get(elementTy, 0));
        return rewriter.create<arith::SubIOp>(loc, resultTypes, constant,
                                              args[0]);
      }

      // Compute the maximum value that can occur in the intermediate buffer.
      const int32_t inputBitWidth = elementTy.getIntOrFloatBitWidth();
      const int64_t zpAdd = inZp + outZp;
      const int64_t maxValue =
          APInt::getSignedMaxValue(inputBitWidth).getSExtValue() +
          std::abs(zpAdd) + 1;

      // Convert that maximum value into the maximum bitwidth needed to
      // represent it. We assume 48-bit numbers may be supported further in
      // the pipeline.
      int intermediateBitWidth = 64;
      if (maxValue <= APInt::getSignedMaxValue(16).getSExtValue()) {
        intermediateBitWidth = 16;
      } else if (maxValue <= APInt::getSignedMaxValue(32).getSExtValue()) {
        intermediateBitWidth = 32;
      } else if (maxValue <= APInt::getSignedMaxValue(48).getSExtValue()) {
        intermediateBitWidth = 48;
      }

      Type intermediateType = rewriter.getIntegerType(intermediateBitWidth);
      Value zpAddValue = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getIntegerAttr(intermediateType, zpAdd));

      // The negation can be applied by doing:
      //  outputValue = inZp + outZp - inputValue
      auto ext =
          rewriter.create<arith::ExtSIOp>(loc, intermediateType, args[0]);
      auto sub = rewriter.create<arith::SubIOp>(loc, zpAddValue, ext);

      // Clamp to the negation range.
      Value min = rewriter.create<arith::ConstantIntOp>(
          loc, APInt::getSignedMinValue(inputBitWidth).getSExtValue(),
          intermediateType);
      Value max = rewriter.create<arith::ConstantIntOp>(
          loc, APInt::getSignedMaxValue(inputBitWidth).getSExtValue(),
          intermediateType);
      auto clamp = clampIntHelper(loc, sub, min, max, rewriter, false);

      // Truncate to the final value.
      return rewriter.create<arith::TruncIOp>(loc, elementTy, clamp);
    }
  }

  // tosa::BitwiseAndOp
  if (isa<tosa::BitwiseAndOp>(op) && isa<IntegerType>(elementTy))
    return rewriter.create<arith::AndIOp>(loc, resultTypes, args);

  // tosa::BitwiseOrOp
  if (isa<tosa::BitwiseOrOp>(op) && isa<IntegerType>(elementTy))
    return rewriter.create<arith::OrIOp>(loc, resultTypes, args);

  // tosa::BitwiseNotOp
  if (isa<tosa::BitwiseNotOp>(op) && isa<IntegerType>(elementTy)) {
    // ~x implemented as x xor all-ones.
    auto allOnesAttr = rewriter.getIntegerAttr(
        elementTy, APInt::getAllOnes(elementTy.getIntOrFloatBitWidth()));
    auto allOnes = rewriter.create<arith::ConstantOp>(loc, allOnesAttr);
    return rewriter.create<arith::XOrIOp>(loc, resultTypes, args[0], allOnes);
  }

  // tosa::BitwiseXOrOp
  if (isa<tosa::BitwiseXorOp>(op) && isa<IntegerType>(elementTy))
    return rewriter.create<arith::XOrIOp>(loc, resultTypes, args);

  // tosa::LogicalLeftShiftOp
  if (isa<tosa::LogicalLeftShiftOp>(op) && isa<IntegerType>(elementTy))
    return rewriter.create<arith::ShLIOp>(loc, resultTypes, args);

  // tosa::LogicalRightShiftOp
  if (isa<tosa::LogicalRightShiftOp>(op) && isa<IntegerType>(elementTy))
    return rewriter.create<arith::ShRUIOp>(loc, resultTypes, args);

  // tosa::ArithmeticRightShiftOp
  if (isa<tosa::ArithmeticRightShiftOp>(op) && isa<IntegerType>(elementTy)) {
    auto result = rewriter.create<arith::ShRSIOp>(loc, resultTypes, args);
    auto round = cast<BoolAttr>(op->getAttr("round")).getValue();
    if (!round) {
      return result;
    }

    // With rounding enabled, add 1 to the shifted result when the shift
    // amount is positive and the last bit shifted out of input1 was set.
    Type i1Ty = IntegerType::get(rewriter.getContext(), /*width=*/1);
    auto one =
        rewriter.create<arith::ConstantOp>(loc, IntegerAttr::get(elementTy, 1));
    auto zero =
        rewriter.create<arith::ConstantOp>(loc, IntegerAttr::get(elementTy, 0));
    auto i1one =
        rewriter.create<arith::ConstantOp>(loc, IntegerAttr::get(i1Ty, 1));

    // Checking that the shift amount (input2) is greater than zero.
    auto shiftValueGreaterThanZero = rewriter.create<arith::CmpIOp>(
        loc, arith::CmpIPredicate::sgt, args[1], zero);

    // Extract the last bit that gets shifted out of input1.
    auto subtract =
        rewriter.create<arith::SubIOp>(loc, resultTypes, args[1], one);
    auto shifted =
        rewriter.create<arith::ShRSIOp>(loc, resultTypes, args[0], subtract)
            ->getResults();
    auto truncated =
        rewriter.create<arith::TruncIOp>(loc, i1Ty, shifted, std::nullopt);
    auto isInputOdd =
        rewriter.create<arith::AndIOp>(loc, i1Ty, truncated, i1one);

    auto shouldRound = rewriter.create<arith::AndIOp>(
        loc, i1Ty, shiftValueGreaterThanZero, isInputOdd);
    auto extended =
        rewriter.create<arith::ExtUIOp>(loc, resultTypes, shouldRound);
    return rewriter.create<arith::AddIOp>(loc, resultTypes, result, extended);
  }

  // tosa::ClzOp
  if (isa<tosa::ClzOp>(op) && isa<IntegerType>(elementTy)) {
    return rewriter.create<math::CountLeadingZerosOp>(loc, elementTy, args[0]);
  }

  // tosa::LogicalAnd
  if (isa<tosa::LogicalAndOp>(op) && elementTy.isInteger(1))
    return rewriter.create<arith::AndIOp>(loc, resultTypes, args);

  // tosa::LogicalNot
  if (isa<tosa::LogicalNotOp>(op) && elementTy.isInteger(1)) {
    // not(x) == x xor 1 for i1.
    auto one = rewriter.create<arith::ConstantOp>(
        loc, rewriter.getIntegerAttr(elementTy, 1));
    return rewriter.create<arith::XOrIOp>(loc, resultTypes, args[0], one);
  }

  // tosa::LogicalOr
  if (isa<tosa::LogicalOrOp>(op) && elementTy.isInteger(1))
    return rewriter.create<arith::OrIOp>(loc, resultTypes, args);

  // tosa::LogicalXor
  if (isa<tosa::LogicalXorOp>(op) && elementTy.isInteger(1))
    return rewriter.create<arith::XOrIOp>(loc, resultTypes, args);

  // tosa::PowOp
  if (isa<tosa::PowOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<mlir::math::PowFOp>(loc, resultTypes, args);

  // tosa::RsqrtOp
  if (isa<tosa::RsqrtOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<mlir::math::RsqrtOp>(loc, resultTypes, args);

  // tosa::LogOp
  if (isa<tosa::LogOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<mlir::math::LogOp>(loc, resultTypes, args);

  // tosa::ExpOp
  if (isa<tosa::ExpOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<mlir::math::ExpOp>(loc, resultTypes, args);

  // tosa::SinOp
  if (isa<tosa::SinOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<mlir::math::SinOp>(loc, resultTypes, args);

  // tosa::CosOp
  if (isa<tosa::CosOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<mlir::math::CosOp>(loc, resultTypes, args);

  // tosa::TanhOp
  if (isa<tosa::TanhOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<mlir::math::TanhOp>(loc, resultTypes, args);

  // tosa::ErfOp
  if (isa<tosa::ErfOp>(op) && llvm::isa<FloatType>(elementTy))
    return rewriter.create<mlir::math::ErfOp>(loc, resultTypes, args);

  // tosa::GreaterOp
  if (isa<tosa::GreaterOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OGT,
                                          args[0], args[1]);

  if (isa<tosa::GreaterOp>(op) && elementTy.isSignlessInteger())
    return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt,
                                          args[0], args[1]);

  // tosa::GreaterEqualOp
  if (isa<tosa::GreaterEqualOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OGE,
                                          args[0], args[1]);

  if (isa<tosa::GreaterEqualOp>(op) && elementTy.isSignlessInteger())
    return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sge,
                                          args[0], args[1]);

  // tosa::EqualOp
  if (isa<tosa::EqualOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OEQ,
                                          args[0], args[1]);

  if (isa<tosa::EqualOp>(op) && elementTy.isSignlessInteger())
    return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
                                          args[0], args[1]);

  // tosa::SelectOp
  if (isa<tosa::SelectOp>(op)) {
    // The element type of interest is that of the selected values (operand 1),
    // not the i1 predicate (operand 0).
    elementTy = cast<ShapedType>(op->getOperand(1).getType()).getElementType();
    if (isa<FloatType>(elementTy) || isa<IntegerType>(elementTy))
      return rewriter.create<arith::SelectOp>(loc, args[0], args[1], args[2]);
  }

  // tosa::MaximumOp
  if (isa<tosa::MaximumOp>(op) && isa<FloatType>(elementTy)) {
    auto max = rewriter.create<arith::MaximumFOp>(loc, args[0], args[1]);
    return materializeBinaryNanCheckIfRequired(llvm::cast<tosa::MaximumOp>(op),
                                               rewriter, args[0], args[1], max);
  }

  if (isa<tosa::MaximumOp>(op) && elementTy.isSignlessInteger()) {
    return rewriter.create<arith::MaxSIOp>(loc, args[0], args[1]);
  }

  // tosa::MinimumOp
  if (isa<tosa::MinimumOp>(op) && isa<FloatType>(elementTy)) {
    auto min = rewriter.create<arith::MinimumFOp>(loc, args[0], args[1]);
    return materializeBinaryNanCheckIfRequired(llvm::cast<tosa::MinimumOp>(op),
                                               rewriter, args[0], args[1], min);
  }

  if (isa<tosa::MinimumOp>(op) && elementTy.isSignlessInteger()) {
    return rewriter.create<arith::MinSIOp>(loc, args[0], args[1]);
  }

  // tosa::CeilOp
  if (isa<tosa::CeilOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<math::CeilOp>(loc, resultTypes, args);

  // tosa::FloorOp
  if (isa<tosa::FloorOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<math::FloorOp>(loc, resultTypes, args);

  // tosa::ClampOp
  if (isa<tosa::ClampOp>(op) && isa<FloatType>(elementTy)) {
    bool losesInfo = false;
    APFloat minApf = cast<FloatAttr>(op->getAttr("min_val")).getValue();
    APFloat maxApf = cast<FloatAttr>(op->getAttr("max_val")).getValue();
    // Round the attribute bounds into the element type's float semantics.
    minApf.convert(cast<FloatType>(elementTy).getFloatSemantics(),
                   APFloat::rmNearestTiesToEven, &losesInfo);
    maxApf.convert(cast<FloatType>(elementTy).getFloatSemantics(),
                   APFloat::rmNearestTiesToEven, &losesInfo);
    auto min = rewriter.create<arith::ConstantOp>(
        loc, elementTy, rewriter.getFloatAttr(elementTy, minApf));
    auto max = rewriter.create<arith::ConstantOp>(
        loc, elementTy, rewriter.getFloatAttr(elementTy, maxApf));
    auto result = clampFloatHelper(loc, args[0], min, max, rewriter);

    auto clampOp = llvm::cast<tosa::ClampOp>(op);
    const auto nanMode = clampOp.getNanMode();
    // NOTE: no non-float early-out is needed here -- this branch is already
    // guarded by isa<FloatType>(elementTy) above.

    // In the case of "PROPAGATE" semantics no compare and selection is
    // required.
    if (nanMode == "PROPAGATE")
      return result;

    // In the case of "IGNORE" semantics materialize a comparison
    // of the current operand to the reduction which will return true for a NaN
    // argument and then selects between the initial reduction value and the
    // calculated result based on whether the argument is NaN or not. In pseudo
    // code:
    //
    // reduce<op>(x, init):
    //   result = op(init, x)
    //   return init if x == NaN else result

    // Unordered comparison of NaN against itself will always return true.
    Value isNaN = rewriter.create<arith::CmpFOp>(
        op->getLoc(), arith::CmpFPredicate::UNO, args[0], args[0]);
    // TOSA specifies that in "ignore" NaN mode the result is "min" if the input
    // is NaN.
    return rewriter.create<arith::SelectOp>(op->getLoc(), isNaN, min, result);
  }

  if (isa<tosa::ClampOp>(op) && isa<IntegerType>(elementTy)) {
    auto intTy = cast<IntegerType>(elementTy);
    int64_t min =
        cast<IntegerAttr>(op->getAttr("min_val")).getValue().getSExtValue();
    int64_t max =
        cast<IntegerAttr>(op->getAttr("max_val")).getValue().getSExtValue();

    int64_t minRepresentable = std::numeric_limits<int64_t>::min();
    int64_t maxRepresentable = std::numeric_limits<int64_t>::max();
    if (intTy.isUnsignedInteger()) {
      minRepresentable = 0;
      if (intTy.getIntOrFloatBitWidth() <= 63) {
        maxRepresentable =
            (int64_t)APInt::getMaxValue(intTy.getIntOrFloatBitWidth())
                .getZExtValue();
      }
    } else if (intTy.getIntOrFloatBitWidth() <= 64) {
      // Ensure that min & max fit into signed n-bit constants.
      minRepresentable = APInt::getSignedMinValue(intTy.getIntOrFloatBitWidth())
                             .getSExtValue();
      maxRepresentable = APInt::getSignedMaxValue(intTy.getIntOrFloatBitWidth())
                             .getSExtValue();
    }
    // Ensure that the bounds are representable as n-bit signed/unsigned
    // integers.
    min = std::max(min, minRepresentable);
    max = std::max(max, minRepresentable);
    min = std::min(min, maxRepresentable);
    max = std::min(max, maxRepresentable);

    auto minVal = rewriter.create<arith::ConstantIntOp>(
        loc, min, intTy.getIntOrFloatBitWidth());
    auto maxVal = rewriter.create<arith::ConstantIntOp>(
        loc, max, intTy.getIntOrFloatBitWidth());
    return clampIntHelper(loc, args[0], minVal, maxVal, rewriter,
                          intTy.isUnsignedInteger());
  }

  // tosa::SigmoidOp
  if (isa<tosa::SigmoidOp>(op) && isa<FloatType>(elementTy)) {
    // sigmoid(x) = 1 / (1 + exp(-x))
    auto one =
        rewriter.create<arith::ConstantOp>(loc, FloatAttr::get(elementTy, 1));
    auto negate = rewriter.create<arith::NegFOp>(loc, resultTypes, args[0]);
    auto exp = rewriter.create<mlir::math::ExpOp>(loc, resultTypes, negate);
    auto added = rewriter.create<arith::AddFOp>(loc, resultTypes, exp, one);
    return rewriter.create<arith::DivFOp>(loc, resultTypes, one, added);
  }

  // tosa::CastOp
  if (isa<tosa::CastOp>(op)) {
    Type srcTy = elementTy;
    Type dstTy = resultTypes.front();
    if (!srcTy.isIntOrFloat() || !dstTy.isIntOrFloat()) {
      (void)rewriter.notifyMatchFailure(op, "unsupported type");
      return nullptr;
    }

    bool bitExtend =
        srcTy.getIntOrFloatBitWidth() < dstTy.getIntOrFloatBitWidth();

    if (srcTy == dstTy)
      return args.front();

    if (isa<FloatType>(srcTy) && isa<FloatType>(dstTy) && bitExtend)
      return rewriter.create<arith::ExtFOp>(loc, resultTypes, args,
                                            std::nullopt);

    if (isa<FloatType>(srcTy) && isa<FloatType>(dstTy) && !bitExtend)
      return rewriter.create<arith::TruncFOp>(loc, resultTypes, args,
                                              std::nullopt);

    // 1-bit integers need to be treated as signless.
    if (srcTy.isInteger(1) && arith::UIToFPOp::areCastCompatible(srcTy, dstTy))
      return rewriter.create<arith::UIToFPOp>(loc, resultTypes, args,
                                              std::nullopt);

    if (srcTy.isInteger(1) && isa<IntegerType>(dstTy) && bitExtend)
      return rewriter.create<arith::ExtUIOp>(loc, resultTypes, args,
                                             std::nullopt);

    // Unsigned integers need an unrealized cast so that they can be passed
    // to UIToFP.
    if (srcTy.isUnsignedInteger() && isa<FloatType>(dstTy)) {
      auto unrealizedCast =
          rewriter
              .create<UnrealizedConversionCastOp>(
                  loc, rewriter.getIntegerType(srcTy.getIntOrFloatBitWidth()),
                  args[0])
              .getResult(0);
      return rewriter.create<arith::UIToFPOp>(loc, resultTypes[0],
                                              unrealizedCast);
    }

    // All other si-to-fp conversions should be handled by SIToFP.
    if (arith::SIToFPOp::areCastCompatible(srcTy, dstTy))
      return rewriter.create<arith::SIToFPOp>(loc, resultTypes, args,
                                              std::nullopt);

    // Casting to boolean, floats need to only be checked as not-equal to zero.
    if (isa<FloatType>(srcTy) && dstTy.isInteger(1)) {
      Value zero = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getFloatAttr(srcTy, 0.0));
      return rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UNE,
                                            args.front(), zero);
    }

    if (arith::FPToSIOp::areCastCompatible(srcTy, dstTy)) {
      auto rounded = rewriter.create<math::RoundEvenOp>(loc, args[0]);

      const auto &fltSemantics = cast<FloatType>(srcTy).getFloatSemantics();
      // Check whether neither int min nor int max can be represented in the
      // input floating-point type due to too short exponent range.
      if (static_cast<int>(dstTy.getIntOrFloatBitWidth()) - 1 >
          APFloat::semanticsMaxExponent(fltSemantics)) {
        // Use cmp + select to replace infinites by int min / int max. Other
        // integral values can be represented in the integer space.
        auto conv = rewriter.create<arith::FPToSIOp>(loc, dstTy, rounded);
        auto posInf = rewriter.create<arith::ConstantOp>(
            loc, rewriter.getFloatAttr(getElementTypeOrSelf(srcTy),
                                       APFloat::getInf(fltSemantics)));
        auto negInf = rewriter.create<arith::ConstantOp>(
            loc, rewriter.getFloatAttr(
                     getElementTypeOrSelf(srcTy),
                     APFloat::getInf(fltSemantics, /*Negative=*/true)));
        auto overflow = rewriter.create<arith::CmpFOp>(
            loc, arith::CmpFPredicate::UEQ, rounded, posInf);
        auto underflow = rewriter.create<arith::CmpFOp>(
            loc, arith::CmpFPredicate::UEQ, rounded, negInf);
        auto intMin = rewriter.create<arith::ConstantOp>(
            loc, rewriter.getIntegerAttr(
                     getElementTypeOrSelf(dstTy),
                     APInt::getSignedMinValue(dstTy.getIntOrFloatBitWidth())));
        auto intMax = rewriter.create<arith::ConstantOp>(
            loc, rewriter.getIntegerAttr(
                     getElementTypeOrSelf(dstTy),
                     APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())));
        auto maxClamped =
            rewriter.create<arith::SelectOp>(loc, overflow, intMax, conv);
        return rewriter.create<arith::SelectOp>(loc, underflow, intMin,
                                                maxClamped);
      }

      auto intMinFP = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getFloatAttr(
                   getElementTypeOrSelf(srcTy),
                   APInt::getSignedMinValue(dstTy.getIntOrFloatBitWidth())
                       .getSExtValue()));

      // Check whether the mantissa has enough bits to represent int max.
      if (cast<FloatType>(srcTy).getFPMantissaWidth() >=
          dstTy.getIntOrFloatBitWidth() - 1) {
        // Int min can also be represented since it is a power of two and thus
        // consists of a single leading bit. Therefore we can clamp the input
        // in the floating-point domain.
        auto intMaxFP = rewriter.create<arith::ConstantOp>(
            loc, rewriter.getFloatAttr(
                     getElementTypeOrSelf(srcTy),
                     APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())
                         .getSExtValue()));

        Value clamped =
            clampFloatHelper(loc, rounded, intMinFP, intMaxFP, rewriter);
        return rewriter.create<arith::FPToSIOp>(loc, dstTy, clamped);
      }

      // Due to earlier check we know exponent range is big enough to represent
      // int min. We can therefore rely on int max + 1 being representable as
      // well because it's just int min with a positive sign. So clamp the min
      // value and compare against that to select the max int value if needed.
      auto intMaxPlusOneFP = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getFloatAttr(
                   getElementTypeOrSelf(srcTy),
                   static_cast<double>(
                       APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())
                           .getSExtValue()) +
                       1.0f));

      auto intMax = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getIntegerAttr(
                   getElementTypeOrSelf(dstTy),
                   APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())));
      auto minClampedFP =
          rewriter.create<arith::MaximumFOp>(loc, rounded, intMinFP);
      auto minClamped =
          rewriter.create<arith::FPToSIOp>(loc, dstTy, minClampedFP);
      auto overflow = rewriter.create<arith::CmpFOp>(
          loc, arith::CmpFPredicate::UGE, rounded, intMaxPlusOneFP);
      return rewriter.create<arith::SelectOp>(loc, overflow, intMax,
                                              minClamped);
    }

    // Casting to boolean, integers need to only be checked as not-equal to
    // zero.
    if (isa<IntegerType>(srcTy) && dstTy.isInteger(1)) {
      Value zero = rewriter.create<arith::ConstantIntOp>(
          loc, 0, srcTy.getIntOrFloatBitWidth());
      return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne,
                                            args.front(), zero);
    }

    if (isa<IntegerType>(srcTy) && isa<IntegerType>(dstTy) && bitExtend)
      return rewriter.create<arith::ExtSIOp>(loc, resultTypes, args,
                                             std::nullopt);

    if (isa<IntegerType>(srcTy) && isa<IntegerType>(dstTy) && !bitExtend) {
      return rewriter.create<arith::TruncIOp>(loc, dstTy, args[0]);
    }
  }

  (void)rewriter.notifyMatchFailure(
      op, "unhandled op for linalg body calculation for elementwise op");
  return nullptr;
}
using IndexPool = DenseMap<int64_t, Value>;

// Return a cached 'arith.constant' of index type for `index`, creating it on
// first use. Pooling the constants prevents emitting duplicates, which keeps
// the generated IR compact and readable in unit tests.
static Value createIndex(PatternRewriter &rewriter, Location loc,
                         IndexPool &indexPool, int64_t index) {
  auto [iter, wasInserted] = indexPool.try_emplace(index);
  if (wasInserted) {
    iter->second =
        rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(index));
  }
  return iter->second;
}
// Emit a 'tensor.dim' op querying the runtime extent of `tensor` along
// dimension `index`, using a pooled index constant for the dimension.
static Value getTensorDim(PatternRewriter &rewriter, Location loc,
                          IndexPool &indexPool, Value tensor, int64_t index) {
  Value dimIndex = createIndex(rewriter, loc, indexPool, index);
  auto dimOp = rewriter.create<tensor::DimOp>(loc, tensor, dimIndex);
  return dimOp.getResult();
}
// Return the extent of `tensor` along dimension `index`: a folded index
// attribute when the dimension is static, otherwise a runtime 'tensor.dim'
// value.
static OpFoldResult getOrFoldTensorDim(PatternRewriter &rewriter, Location loc,
                                       IndexPool &indexPool, Value tensor,
                                       int64_t index) {
  auto shapedType = dyn_cast<ShapedType>(tensor.getType());
  assert(shapedType && shapedType.hasRank() && "expected a ranked shaped type");
  assert(index >= 0 && index < shapedType.getRank() && "index out of bounds");
  if (!shapedType.isDynamicDim(index))
    return rewriter.getIndexAttr(shapedType.getDimSize(index));
  return getTensorDim(rewriter, loc, indexPool, tensor, index);
}
// True iff every operand and every result of `operation` is a ranked tensor.
static bool operandsAndResultsRanked(Operation *operation) {
  auto hasRankedTensorType = [](Value value) {
    return isa<RankedTensorType>(value.getType());
  };
  if (!llvm::all_of(operation->getOperands(), hasRankedTensorType))
    return false;
  return llvm::all_of(operation->getResults(), hasRankedTensorType);
}
// Compute the runtime size of output dimension 'dim' from the input
// 'operands', all of which are expected to share the same rank. Returns a
// pair {targetSize, masterOperand}.
//
// The size is reported either as a statically computed attribute or as a
// runtime SSA value. When it was derived from exactly one dominating operand,
// that operand is returned as 'masterOperand'; when several operands
// contribute, 'masterOperand' is nullptr.
static std::pair<OpFoldResult, Value>
computeTargetSize(PatternRewriter &rewriter, Location loc, IndexPool &indexPool,
                  ValueRange operands, int64_t dim) {
  // A static extent greater than 1 immediately pins the target size. Multiple
  // conflicting static extents > 1 would be undefined behavior.
  for (Value operand : operands) {
    int64_t staticSize =
        cast<RankedTensorType>(operand.getType()).getDimSize(dim);
    if (!ShapedType::isDynamic(staticSize) && staticSize > 1)
      return {rewriter.getIndexAttr(staticSize), operand};
  }

  // Gather the operands whose extent along 'dim' is dynamic.
  auto dynamicOperands = llvm::filter_to_vector(operands, [&](Value operand) {
    return cast<RankedTensorType>(operand.getType()).isDynamicDim(dim);
  });

  // All extents static and <= 1: the output size is exactly 1.
  if (dynamicOperands.empty())
    return {rewriter.getIndexAttr(1), operands.front()};

  // A single dynamic operand fully determines the runtime size and becomes
  // the master operand.
  Value targetSize =
      getTensorDim(rewriter, loc, indexPool, dynamicOperands[0], dim);
  if (dynamicOperands.size() == 1)
    return {targetSize, dynamicOperands[0]};

  // Otherwise, emit the runtime maximum over all dynamic extents.
  for (size_t i = 1, e = dynamicOperands.size(); i < e; ++i) {
    Value nextSize =
        getTensorDim(rewriter, loc, indexPool, dynamicOperands[i], dim);
    targetSize = rewriter.create<arith::MaxUIOp>(loc, targetSize, nextSize);
  }
  return {targetSize, nullptr};
}
// Compute the runtime output extent for every dimension. Returns a pair
// {targetShape, masterOperands}, with one entry per dimension (see
// computeTargetSize for the meaning of a master operand).
static std::pair<SmallVector<OpFoldResult>, SmallVector<Value>>
computeTargetShape(PatternRewriter &rewriter, Location loc,
                   IndexPool &indexPool, ValueRange operands) {
  assert(!operands.empty());
  int64_t rank = cast<RankedTensorType>(operands.front().getType()).getRank();
  SmallVector<OpFoldResult> targetShape;
  SmallVector<Value> masterOperands;
  targetShape.reserve(rank);
  masterOperands.reserve(rank);
  for (int64_t dim = 0; dim < rank; ++dim) {
    auto [targetSize, masterOperand] =
        computeTargetSize(rewriter, loc, indexPool, operands, dim);
    targetShape.push_back(targetSize);
    masterOperands.push_back(masterOperand);
  }
  return {targetShape, masterOperands};
}
// Broadcast `operand` along dynamic dimension `dim` to `targetSize` when
// needed. Static dimensions and operands that themselves determined the
// target size are returned unchanged. Otherwise an 'scf.if' is emitted that
// at runtime either yields the operand as-is or materializes the broadcast
// through a 'linalg.generic' copy whose input map pins `dim` to 0.
static Value broadcastDynamicDimension(PatternRewriter &rewriter, Location loc,
                                       IndexPool &indexPool, Value operand,
                                       int64_t dim, OpFoldResult targetSize,
                                       Value masterOperand) {
  // Nothing to do if this is a static dimension
  auto rankedTensorType = cast<RankedTensorType>(operand.getType());
  if (!rankedTensorType.isDynamicDim(dim))
    return operand;

  // If the target size for this dimension was directly inferred by only taking
  // this operand into account, there is no need to broadcast. This is an
  // optimization that will prevent redundant control flow, and constitutes the
  // main motivation for tracking "master operands".
  if (operand == masterOperand)
    return operand;

  // Affine maps for 'linalg.generic' op
  // The input map fixes dimension `dim` to constant 0, so the size-1 slice is
  // read for every output position along that dimension.
  auto rank = rankedTensorType.getRank();
  SmallVector<AffineExpr> affineExprs;
  for (auto index : llvm::seq<int64_t>(0, rank)) {
    auto affineExpr = index == dim ? rewriter.getAffineConstantExpr(0)
                                   : rewriter.getAffineDimExpr(index);
    affineExprs.push_back(affineExpr);
  }
  auto broadcastAffineMap =
      AffineMap::get(rank, 0, affineExprs, rewriter.getContext());
  auto identityAffineMap = rewriter.getMultiDimIdentityMap(rank);
  SmallVector<AffineMap> affineMaps = {broadcastAffineMap, identityAffineMap};

  // Check if broadcast is necessary
  // Broadcasting is required exactly when the runtime extent equals 1.
  auto one = createIndex(rewriter, loc, indexPool, 1);
  auto runtimeSize = getTensorDim(rewriter, loc, indexPool, operand, dim);
  auto broadcastNecessary = rewriter.create<arith::CmpIOp>(
      loc, arith::CmpIPredicate::eq, runtimeSize, one);

  // Emit 'then' region of 'scf.if'
  auto emitThenRegion = [&](OpBuilder &opBuilder, Location loc) {
    // It is not safe to cache constants across regions.
    // New constants could potentially violate dominance requirements.
    IndexPool localPool;

    // Emit 'tensor.empty' op
    // NOTE(review): `rewriter` (not `opBuilder`) is used for the dim queries
    // below — presumably its insertion point coincides with the region being
    // built here; confirm before refactoring.
    SmallVector<OpFoldResult> outputTensorShape;
    for (auto index : llvm::seq<int64_t>(0, rank)) {
      auto size = index == dim ? targetSize
                               : getOrFoldTensorDim(rewriter, loc, localPool,
                                                    operand, index);
      outputTensorShape.push_back(size);
    }
    Value outputTensor = opBuilder.create<tensor::EmptyOp>(
        loc, outputTensorShape, rankedTensorType.getElementType());

    // Emit 'linalg.generic' op
    auto resultTensor =
        opBuilder
            .create<linalg::GenericOp>(
                loc, outputTensor.getType(), operand, outputTensor, affineMaps,
                getNParallelLoopsAttrs(rank),
                [&](OpBuilder &opBuilder, Location loc, ValueRange blockArgs) {
                  // Emit 'linalg.yield' op
                  opBuilder.create<linalg::YieldOp>(loc, blockArgs.front());
                })
            .getResult(0);

    // Cast to original operand type if necessary
    auto castResultTensor = rewriter.createOrFold<tensor::CastOp>(
        loc, operand.getType(), resultTensor);

    // Emit 'scf.yield' op
    opBuilder.create<scf::YieldOp>(loc, castResultTensor);
  };

  // Emit 'else' region of 'scf.if'
  // No broadcast needed: the operand already has the target extent.
  auto emitElseRegion = [&](OpBuilder &opBuilder, Location loc) {
    opBuilder.create<scf::YieldOp>(loc, operand);
  };

  // Emit 'scf.if' op
  auto ifOp = rewriter.create<scf::IfOp>(loc, broadcastNecessary,
                                         emitThenRegion, emitElseRegion);
  return ifOp.getResult(0);
}
/// Broadcasts every dynamic dimension of 'operand' to the corresponding entry
/// of 'targetShape', using the per-dimension "master operands" to skip
/// redundant broadcasts. Returns the (possibly rewritten) operand value.
static Value broadcastDynamicDimensions(PatternRewriter &rewriter, Location loc,
                                        IndexPool &indexPool, Value operand,
                                        ArrayRef<OpFoldResult> targetShape,
                                        ArrayRef<Value> masterOperands) {
  const int64_t rank = cast<RankedTensorType>(operand.getType()).getRank();
  assert((int64_t)targetShape.size() == rank);
  assert((int64_t)masterOperands.size() == rank);

  // Handle one dimension at a time; each step threads the result of the
  // previous one.
  Value result = operand;
  for (int64_t dim = 0; dim < rank; ++dim)
    result = broadcastDynamicDimension(rewriter, loc, indexPool, result, dim,
                                       targetShape[dim], masterOperands[dim]);
  return result;
}
/// Broadcasts the dynamic dimensions of every operand to 'targetShape'.
/// Unary operations are returned as-is, since a single operand never needs
/// broadcasting against itself.
static SmallVector<Value>
broadcastDynamicDimensions(PatternRewriter &rewriter, Location loc,
                           IndexPool &indexPool, ValueRange operands,
                           ArrayRef<OpFoldResult> targetShape,
                           ArrayRef<Value> masterOperands) {
  // No need to broadcast for unary operations
  if (operands.size() == 1)
    return operands;

  // Broadcast dynamic dimensions operand by operand
  SmallVector<Value> broadcasted;
  broadcasted.reserve(operands.size());
  for (Value operand : operands)
    broadcasted.push_back(broadcastDynamicDimensions(
        rewriter, loc, indexPool, operand, targetShape, masterOperands));
  return broadcasted;
}
/// Emits the 'tensor.empty' + 'linalg.generic' pair implementing an
/// elementwise TOSA operation over the already-broadcast 'operands', with
/// result extents 'targetShape'. The per-element computation is produced by
/// 'createLinalgBodyCalculationForElementwiseOp'; the op is replaced with the
/// generic's result (cast back to the converted result type if needed).
static LogicalResult
emitElementwiseComputation(ConversionPatternRewriter &rewriter, Location loc,
                           Operation *operation, ValueRange operands,
                           ArrayRef<OpFoldResult> targetShape,
                           const TypeConverter &converter) {
  // Generate output tensor
  auto resultType = cast_or_null<RankedTensorType>(
      converter.convertType(operation->getResultTypes().front()));
  if (!resultType) {
    return rewriter.notifyMatchFailure(operation, "failed to convert type");
  }
  Value outputTensor = rewriter.create<tensor::EmptyOp>(
      loc, targetShape, resultType.getElementType());

  // Create affine maps. Input affine maps broadcast static dimensions of size
  // 1. The output affine map is an identity map.
  //
  auto rank = resultType.getRank();
  auto affineMaps = llvm::map_to_vector(operands, [&](Value operand) {
    auto shape = cast<ShapedType>(operand.getType()).getShape();
    SmallVector<AffineExpr> affineExprs;
    for (auto it : llvm::enumerate(shape)) {
      // Prefer producing identity maps whenever possible (i.e. no broadcasting
      // needed) because some transforms (like reshape folding)
      // do not support affine constant exprs.
      bool requiresBroadcast =
          (it.value() == 1 && resultType.getDimSize(it.index()) != 1);
      auto affineExpr = requiresBroadcast
                            ? rewriter.getAffineConstantExpr(0)
                            : rewriter.getAffineDimExpr(it.index());
      affineExprs.push_back(affineExpr);
    }
    return AffineMap::get(rank, 0, affineExprs, rewriter.getContext());
  });
  affineMaps.push_back(rewriter.getMultiDimIdentityMap(rank));

  // Emit 'linalg.generic' op
  bool encounteredError = false;
  auto linalgOp = rewriter.create<linalg::GenericOp>(
      loc, outputTensor.getType(), operands, outputTensor, affineMaps,
      getNParallelLoopsAttrs(rank),
      [&](OpBuilder &opBuilder, Location loc, ValueRange blockArgs) {
        Value opResult = createLinalgBodyCalculationForElementwiseOp(
            operation, blockArgs.take_front(operation->getNumOperands()),
            {resultType.getElementType()}, rewriter);
        if (!opResult) {
          // Body creation failed; flag it so the pattern can bail out after
          // the builder callback returns.
          encounteredError = true;
          return;
        }
        opBuilder.create<linalg::YieldOp>(loc, opResult);
      });
  if (encounteredError)
    return rewriter.notifyMatchFailure(
        operation, "unable to create linalg.generic body for elementwise op");

  // Cast 'linalg.generic' result into original result type if needed
  auto castResult = rewriter.createOrFold<tensor::CastOp>(
      loc, resultType, linalgOp->getResult(0));
  rewriter.replaceOp(operation, castResult);
  return success();
}
/// Returns the prefix of 'operands' that participates in broadcasting. Some
/// TOSA ops carry trailing operands (shift, zero points) that must never be
/// broadcast; those are stripped here.
static ValueRange getBroadcastableOperands(Operation *operation,
                                           ValueRange operands) {
  size_t numBroadcastable = operands.size();
  if (isa<tosa::MulOp>(operation))
    numBroadcastable = 2; // input1, input2 — shift cannot broadcast.
  else if (isa<tosa::NegateOp>(operation))
    numBroadcastable = 1; // input1 — input1_zp/output_zp cannot broadcast.
  return operands.take_front(numBroadcastable);
}
/// Shared driver for lowering elementwise TOSA ops: computes the common
/// target shape, broadcasts all dynamic dimensions to it, then emits the
/// 'linalg.generic' computation and replaces 'operation'.
static LogicalResult
elementwiseMatchAndRewriteHelper(Operation *operation, ValueRange operands,
                                 ConversionPatternRewriter &rewriter,
                                 const TypeConverter &converter) {
  // Sanity-check the op shape this helper relies on.
  assert(operation->getNumResults() == 1 && "elementwise op expects 1 result");
  assert(operation->getNumOperands() >= 1 &&
         "elementwise op expects at least 1 operand");
  if (!operandsAndResultsRanked(operation))
    return rewriter.notifyMatchFailure(operation,
                                       "Unranked tensors not supported");

  // Lower operation
  IndexPool indexPool;
  Location loc = operation->getLoc();
  ValueRange broadcastable = getBroadcastableOperands(operation, operands);
  auto [shape, masters] =
      computeTargetShape(rewriter, loc, indexPool, broadcastable);
  SmallVector<Value> broadcasted = broadcastDynamicDimensions(
      rewriter, loc, indexPool, broadcastable, shape, masters);
  return emitElementwiseComputation(rewriter, loc, operation, broadcasted,
                                    shape, converter);
}
// Returns the constant initial (identity) value for a given reduction
// operation, grouped by element type. Returns a null attribute for
// unsupported op/type combinations.
static TypedAttr createInitialValueForReduceOp(Operation *op, Type elementTy,
                                               PatternRewriter &rewriter) {
  if (auto floatTy = dyn_cast<FloatType>(elementTy)) {
    if (isa<tosa::ReduceSumOp>(op))
      return rewriter.getFloatAttr(elementTy, 0.0);
    if (isa<tosa::ReduceProductOp>(op))
      return rewriter.getFloatAttr(elementTy, 1.0);
    // Min reduces start at +largest, max/argmax at -largest.
    if (isa<tosa::ReduceMinOp>(op))
      return rewriter.getFloatAttr(
          elementTy, APFloat::getLargest(floatTy.getFloatSemantics(), false));
    if (isa<tosa::ReduceMaxOp>(op) || isa<tosa::ArgMaxOp>(op))
      return rewriter.getFloatAttr(
          elementTy, APFloat::getLargest(floatTy.getFloatSemantics(), true));
    return {};
  }

  if (isa<IntegerType>(elementTy)) {
    if (isa<tosa::ReduceSumOp>(op))
      return rewriter.getIntegerAttr(elementTy, 0);
    if (isa<tosa::ReduceProductOp>(op))
      return rewriter.getIntegerAttr(elementTy, 1);
    // Min reduces start at signed max, max/argmax at signed min.
    if (isa<tosa::ReduceMinOp>(op))
      return rewriter.getIntegerAttr(
          elementTy, APInt::getSignedMaxValue(elementTy.getIntOrFloatBitWidth()));
    if (isa<tosa::ReduceMaxOp>(op) || isa<tosa::ArgMaxOp>(op))
      return rewriter.getIntegerAttr(
          elementTy, APInt::getSignedMinValue(elementTy.getIntOrFloatBitWidth()));
    // Boolean reductions: AND starts at true, OR starts at false.
    if (isa<tosa::ReduceAllOp>(op) && elementTy.isInteger(1))
      return rewriter.getIntegerAttr(elementTy, APInt::getAllOnes(1));
    if (isa<tosa::ReduceAnyOp>(op) && elementTy.isInteger(1))
      return rewriter.getIntegerAttr(elementTy, APInt::getZero(1));
    return {};
  }

  return {};
}
// Creates the body calculation for a reduction: one binary 'arith' op
// combining the incoming element with the accumulator. The operation chosen
// depends on both the reduce op and the element type; returns a null Value
// for unsupported combinations.
static Value createLinalgBodyCalculationForReduceOp(Operation *op,
                                                    ValueRange args,
                                                    Type elementTy,
                                                    PatternRewriter &rewriter) {
  Location loc = op->getLoc();

  if (isa<FloatType>(elementTy)) {
    if (isa<tosa::ReduceSumOp>(op))
      return rewriter.create<arith::AddFOp>(loc, args);
    if (isa<tosa::ReduceProductOp>(op))
      return rewriter.create<arith::MulFOp>(loc, args);
    if (isa<tosa::ReduceMinOp>(op))
      return rewriter.create<arith::MinimumFOp>(loc, args[0], args[1]);
    if (isa<tosa::ReduceMaxOp>(op))
      return rewriter.create<arith::MaximumFOp>(loc, args[0], args[1]);
    return {};
  }

  if (isa<IntegerType>(elementTy)) {
    if (isa<tosa::ReduceSumOp>(op))
      return rewriter.create<arith::AddIOp>(loc, args);
    if (isa<tosa::ReduceProductOp>(op))
      return rewriter.create<arith::MulIOp>(loc, args);
    if (isa<tosa::ReduceMinOp>(op))
      return rewriter.create<arith::MinSIOp>(loc, args[0], args[1]);
    if (isa<tosa::ReduceMaxOp>(op))
      return rewriter.create<arith::MaxSIOp>(loc, args[0], args[1]);
    // Boolean reductions map to bitwise and/or on i1.
    if (elementTy.isInteger(1)) {
      if (isa<tosa::ReduceAllOp>(op))
        return rewriter.create<arith::AndIOp>(loc, args);
      if (isa<tosa::ReduceAnyOp>(op))
        return rewriter.create<arith::OrIOp>(loc, args);
    }
  }

  return {};
}
// Performs the match and rewrite for reduction operations. This includes
// declaring a correctly sized initial value, and the linalg.generic operation
// that reduces across the specified axis.
//
// For float min/max reductions with nan_mode == "IGNORE" a parallel boolean
// "all inputs were NaN" flag is reduced alongside the values, and a final
// linalg.select materializes a NaN result when the flag survived. The reduced
// rank-(N-1) result is expanded back to the original rank via
// tensor.expand_shape, reinserting the reduced axis with unit size.
template <typename OpTy>
static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis,
                                                 PatternRewriter &rewriter) {
  auto loc = op->getLoc();
  auto inputTy = dyn_cast<RankedTensorType>(op->getOperand(0).getType());
  auto resultTy = dyn_cast<RankedTensorType>(op->getResult(0).getType());
  if (!inputTy || !resultTy)
    return rewriter.notifyMatchFailure(op, "unranked tensors not supported");

  auto elementTy = resultTy.getElementType();
  Value input = op->getOperand(0);

  // The reduced shape drops the reduction axis entirely; dynamic extents of
  // the remaining dimensions are materialized as tensor.dim values.
  SmallVector<int64_t> reduceShape;
  SmallVector<Value> dynDims;
  for (unsigned i = 0; i < inputTy.getRank(); i++) {
    if (axis != i) {
      reduceShape.push_back(inputTy.getDimSize(i));
      if (inputTy.isDynamicDim(i))
        dynDims.push_back(rewriter.create<tensor::DimOp>(loc, input, i));
    }
  }

  SmallVector<Value> inputs, outputs;
  inputs.push_back(input);

  // First fill the output buffer with the init value.
  auto emptyTensor =
      rewriter
          .create<tensor::EmptyOp>(loc, reduceShape, resultTy.getElementType(),
                                   dynDims)
          .getResult();

  auto fillValueAttr = createInitialValueForReduceOp(op, elementTy, rewriter);
  if (!fillValueAttr)
    return rewriter.notifyMatchFailure(
        op, "No initial value found for reduction operation");

  auto fillValue = rewriter.create<arith::ConstantOp>(loc, fillValueAttr);
  auto filledTensor = rewriter
                          .create<linalg::FillOp>(loc, ValueRange{fillValue},
                                                  ValueRange{emptyTensor})
                          .result();
  outputs.push_back(filledTensor);

  bool isNanIgnoreMode = false;
  if constexpr (std::is_same_v<OpTy, tosa::ReduceMinOp> ||
                std::is_same_v<OpTy, tosa::ReduceMaxOp>) {
    // NaN propagation has no meaning for non floating point types.
    if (isa<FloatType>(elementTy) && op.getNanMode() == "IGNORE") {
      isNanIgnoreMode = true;
      // Because the TOSA spec requires the result be NaN iff all elements in
      // the reduction are NaN we can't simply perform a compare and select.
      // Additionally we have to keep track of whether we've seen any non-NaN
      // values and then do a final select based on this predicate.
      auto trueAttr = rewriter.getBoolAttr(true);
      auto trueValue = rewriter.create<arith::ConstantOp>(loc, trueAttr);
      auto emptyBoolTensor =
          rewriter
              .create<tensor::EmptyOp>(loc, reduceShape, trueValue.getType(),
                                       dynDims)
              .getResult();
      auto allResultsNaNTensor =
          rewriter
              .create<linalg::FillOp>(loc, ValueRange{trueValue},
                                      ValueRange{emptyBoolTensor})
              .result();
      // Note that because the linalg::ReduceOp has two variadic arguments
      // (inputs and outputs) and it has the SameVariadicOperandSize trait we
      // need to have the same number of inputs and outputs.
      //
      // The second input isn't actually used anywhere since the value used to
      // update the NaN flag is calculated inside the body of the reduction and
      // then used to update an out value.
      // In order to satisfy type constraints we just pass another copy of the
      // input here.
      inputs.push_back(input);
      outputs.push_back(allResultsNaNTensor);
    }
  }

  bool didEncounterError = false;
  linalg::LinalgOp linalgOp = rewriter.create<linalg::ReduceOp>(
      loc, inputs, outputs, axis,
      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange blockArgs) {
        // In NaN-ignore mode the accumulator is blockArgs[2] (the extra input
        // occupies blockArgs[1]); otherwise it is blockArgs[1].
        std::array<Value, 2> binaryArgs{
            blockArgs[0], isNanIgnoreMode ? blockArgs[2] : blockArgs[1]};
        auto result = createLinalgBodyCalculationForReduceOp(
            op, binaryArgs, elementTy, rewriter);
        // NOTE(review): despite its name, this flag is set when body creation
        // *succeeds*; the '!didEncounterError' check below is consistent with
        // that inverted meaning.
        if (result)
          didEncounterError = true;

        SmallVector<Value> resultsToYield;
        if (isNanIgnoreMode) {
          auto inputValue = blockArgs[0];
          auto initialValue = blockArgs[2];
          auto oldAllResultsNanFlagValue = blockArgs[3];

          // Unordered comparison of NaN against itself will always return true.
          Value isNaN = nestedBuilder.create<arith::CmpFOp>(
              op->getLoc(), arith::CmpFPredicate::UNO, inputValue, inputValue);
          // If we've encountered a NaN, take the non-NaN value.
          auto selectOp = nestedBuilder.create<arith::SelectOp>(
              op->getLoc(), isNaN, initialValue, result);
          // Update the flag which keeps track of whether we have seen a non-NaN
          // value.
          auto newAllResultsNanFlagValue = nestedBuilder.create<arith::AndIOp>(
              op->getLoc(), oldAllResultsNanFlagValue, isNaN);
          resultsToYield.push_back(selectOp);
          resultsToYield.push_back(newAllResultsNanFlagValue);
        } else {
          resultsToYield.push_back(result);
        }
        nestedBuilder.create<linalg::YieldOp>(loc, resultsToYield);
      });

  if (!didEncounterError)
    return rewriter.notifyMatchFailure(
        op, "unable to create linalg.generic body for reduce op");

  if (isNanIgnoreMode) {
    // Materialize a check to see whether we encountered any non-NaN values, if
    // we didn't we need to select a tensor of NaNs since the result will just
    // be the initial identity value propagated through all the compares and
    // selects inside the reduction.

    // Create a tensor full of NaNs.
    auto nanValueAttr = rewriter.getFloatAttr(
        elementTy,
        APFloat::getNaN(cast<FloatType>(elementTy).getFloatSemantics(), false));
    auto nanValue = rewriter.create<arith::ConstantOp>(loc, nanValueAttr);
    auto emptyNanTensor =
        rewriter
            .create<tensor::EmptyOp>(loc, reduceShape,
                                     resultTy.getElementType(), dynDims)
            .getResult();
    auto nanFilledTensor =
        rewriter
            .create<linalg::FillOp>(loc, ValueRange{nanValue},
                                    ValueRange{emptyNanTensor})
            .result();

    // Create an empty tensor, no need to fill this since it will be
    // overwritten by the select.
    auto finalEmptyTensor =
        rewriter
            .create<tensor::EmptyOp>(loc, reduceShape,
                                     resultTy.getElementType(), dynDims)
            .getResult();

    // Do a selection between the tensors akin to:
    // result = NaN if "all results NaN" else result.
    SmallVector<Value> ins, outs;
    ins.push_back(linalgOp->getOpResult(1));
    ins.push_back(nanFilledTensor);
    ins.push_back(linalgOp->getResult(0));
    outs.push_back(finalEmptyTensor);
    auto linalgSelect =
        rewriter.create<linalg::SelectOp>(op->getLoc(), ins, outs);
    linalgOp = linalgSelect;
  }

  // Build the reassociation that reinserts the reduced axis as a unit dim:
  // every result dim maps to itself, and the dim at the reduction axis (or
  // the last dim if the axis was the trailing one) additionally absorbs the
  // following expanded dimension.
  SmallVector<ReassociationExprs, 4> reassociationMap;
  uint64_t expandInputRank =
      cast<ShapedType>(linalgOp->getResults()[0].getType()).getRank();
  reassociationMap.resize(expandInputRank);

  for (uint64_t i = 0; i < expandInputRank; i++) {
    int32_t dimToPush = i > axis ? i + 1 : i;
    reassociationMap[i].push_back(rewriter.getAffineDimExpr(dimToPush));
  }

  if (expandInputRank != 0) {
    int32_t expandedDim = axis < expandInputRank ? axis : expandInputRank - 1;
    reassociationMap[expandedDim].push_back(
        rewriter.getAffineDimExpr(expandedDim + 1));
  }

  // Lower directly to `tensor::ExpandShapeOp` instead of `tosa::ReshapeOp`,
  // since here we know which dimension to expand, and `tosa::ReshapeOp` would
  // not have access to such information. This matters when handling dynamically
  // sized tensors.
  rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
      op, resultTy, linalgOp->getResults()[0], reassociationMap);
  return success();
}
namespace {
// Generic conversion pattern for elementwise (pointwise) TOSA ops: forwards
// the adapted operands to the shared elementwise lowering helper, which
// handles broadcasting and emits the linalg.generic computation.
template <typename SrcOp>
class PointwiseConverter : public OpConversionPattern<SrcOp> {
public:
  using OpConversionPattern<SrcOp>::OpConversionPattern;
  using typename OpConversionPattern<SrcOp>::OpAdaptor;

  LogicalResult
  matchAndRewrite(SrcOp op, OpAdaptor operands,
                  ConversionPatternRewriter &rewriter) const final {
    return elementwiseMatchAndRewriteHelper(
        op, operands.getOperands(), rewriter, *this->getTypeConverter());
  }
};
// Lowers tosa.rescale to a linalg.generic whose body subtracts the input
// zero point, applies the (possibly per-channel) multiplier/shift via
// tosa.apply_scale, re-adds the output zero point, and clamps to the output
// type's range. Explicitly unsigned tensor element types are bridged to/from
// signless integers with unrealized_conversion_cast, since 'arith' ops only
// accept signless types; whether the arithmetic itself is signed or unsigned
// is driven by the op's input_unsigned/output_unsigned attributes.
class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
public:
  using OpRewritePattern<tosa::RescaleOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::RescaleOp op,
                                PatternRewriter &rewriter) const final {
    auto loc = op.getLoc();
    auto input = op.getInput();
    auto inputTy = cast<ShapedType>(op.getInput().getType());
    auto outputTy = cast<ShapedType>(op.getOutput().getType());
    unsigned rank = inputTy.getRank();

    // This is an illegal configuration. terminate and log an error
    if (op.getRoundingMode() == "INEXACT_ROUND")
      return rewriter.notifyMatchFailure(
          op, "tosa.rescale with rounding mode = 'INEXACT_ROUND' is not "
              "currently supported");
    if (op.getRoundingMode() == "DOUBLE_ROUND" && !op.getScale32())
      return rewriter.notifyMatchFailure(
          op, "tosa.rescale requires scale32 for double_round to be true");

    if (!isa<IntegerType>(inputTy.getElementType()))
      return rewriter.notifyMatchFailure(op, "only support integer type");

    // Collect runtime extents for any dynamic output dimensions.
    SmallVector<Value> dynDims;
    for (int i = 0; i < outputTy.getRank(); i++) {
      if (outputTy.isDynamicDim(i)) {
        dynDims.push_back(rewriter.create<tensor::DimOp>(loc, input, i));
      }
    }

    // The shift and multiplier values. Both must be compile-time constants.
    DenseElementsAttr shiftElems;
    if (!matchPattern(op.getShift(), m_Constant(&shiftElems)))
      return rewriter.notifyMatchFailure(
          op, "tosa.rescale requires constant shift input values");

    DenseElementsAttr multiplierElems;
    if (!matchPattern(op.getMultiplier(), m_Constant(&multiplierElems)))
      return rewriter.notifyMatchFailure(
          op, "tosa.rescale requires constant multiplier input values");

    llvm::SmallVector<int8_t> shiftValues =
        llvm::to_vector(shiftElems.getValues<int8_t>());
    // explicit cast is required here
    llvm::SmallVector<int32_t> multiplierValues = llvm::to_vector(
        llvm::map_range(multiplierElems.getValues<IntegerAttr>(),
                        [](IntegerAttr attr) -> int32_t {
                          return static_cast<int32_t>(attr.getInt());
                        }));

    // If we shift by more than the bitwidth, this just sets to 0.
    for (int i = 0, s = multiplierValues.size(); i < s; i++) {
      if (shiftValues[i] > 63) {
        shiftValues[i] = 0;
        multiplierValues[i] = 0;
      }
    }

    // Double round only occurs if shift is greater than 31, check that this
    // is ever true.
    bool doubleRound =
        op.getRoundingMode() == "DOUBLE_ROUND" &&
        llvm::any_of(shiftValues, [](int32_t v) { return v > 31; });
    StringAttr roundingMode = doubleRound
                                  ? rewriter.getStringAttr("DOUBLE_ROUND")
                                  : rewriter.getStringAttr("SINGLE_ROUND");

    SmallVector<AffineMap> indexingMaps = {
        rewriter.getMultiDimIdentityMap(rank)};
    SmallVector<Value, 4> genericInputs = {input};

    // If we are rescaling per-channel then we need to store the multiplier
    // values in a buffer.
    Value multiplierConstant;
    int64_t multiplierArg = 0;
    if (multiplierValues.size() == 1) {
      multiplierConstant = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getI32IntegerAttr(multiplierValues.front()));
    } else {
      // Per-channel: indexed by the innermost (channel) dimension only.
      SmallVector<AffineExpr, 2> multiplierExprs{
          rewriter.getAffineDimExpr(rank - 1)};
      auto multiplierType =
          RankedTensorType::get({static_cast<int64_t>(multiplierValues.size())},
                                rewriter.getI32Type());
      genericInputs.push_back(rewriter.create<arith::ConstantOp>(
          loc, DenseIntElementsAttr::get(multiplierType, multiplierValues)));

      indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
                                            /*symbolCount=*/0, multiplierExprs,
                                            rewriter.getContext()));

      multiplierArg = indexingMaps.size() - 1;
    }

    // If we are rescaling per-channel then we need to store the shift
    // values in a buffer.
    Value shiftConstant;
    int64_t shiftArg = 0;
    if (shiftValues.size() == 1) {
      shiftConstant = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getI8IntegerAttr(shiftValues.front()));
    } else {
      // Per-channel: indexed by the innermost (channel) dimension only.
      SmallVector<AffineExpr, 2> shiftExprs = {
          rewriter.getAffineDimExpr(rank - 1)};
      auto shiftType =
          RankedTensorType::get({static_cast<int64_t>(shiftValues.size())},
                                rewriter.getIntegerType(8));
      genericInputs.push_back(rewriter.create<arith::ConstantOp>(
          loc, DenseIntElementsAttr::get(shiftType, shiftValues)));
      indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
                                            /*symbolCount=*/0, shiftExprs,
                                            rewriter.getContext()));
      shiftArg = indexingMaps.size() - 1;
    }

    // Indexing maps for output values.
    indexingMaps.push_back(rewriter.getMultiDimIdentityMap(rank));

    // Construct the output tensor for the linalg.generic op.
    Value emptyTensor = rewriter.create<tensor::EmptyOp>(
        loc, outputTy.getShape(), outputTy.getElementType(),
        ArrayRef<Value>({dynDims}));

    auto linalgOp = rewriter.create<linalg::GenericOp>(
        loc, outputTy, genericInputs, ValueRange{emptyTensor}, indexingMaps,
        getNParallelLoopsAttrs(rank),
        [&](OpBuilder &nestedBuilder, Location nestedLoc,
            ValueRange blockArgs) {
          Value value = blockArgs[0];
          Type valueTy = value.getType();

          FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
          if (failed(maybeIZp)) {
            (void)rewriter.notifyMatchFailure(
                op, "input zero point cannot be statically determined");
            return;
          }

          const int32_t inBitwidth = valueTy.getIntOrFloatBitWidth();
          // Extend zeropoint for sub-32bits widths.
          const int32_t inAttrBitwidth = inBitwidth > 32 ? inBitwidth : 32;
          auto inputZp = nestedBuilder.create<arith::ConstantOp>(
              loc, IntegerAttr::get(rewriter.getIntegerType(inAttrBitwidth),
                                    *maybeIZp));

          FailureOr<int64_t> maybeOZp = op.getOutputZeroPoint();
          if (failed(maybeOZp)) {
            (void)rewriter.notifyMatchFailure(
                op, "output zero point cannot be statically determined");
            return;
          };

          IntegerType outIntType =
              cast<IntegerType>(blockArgs.back().getType());
          unsigned outBitWidth = outIntType.getWidth();
          const int32_t outAttrBitwidth = 32;
          assert(outBitWidth <= 32 && "Unexpected output zeropoint bitwidth");
          auto outputZp = nestedBuilder.create<arith::ConstantOp>(
              loc, IntegerAttr::get(rewriter.getIntegerType(outAttrBitwidth),
                                    *maybeOZp));

          Value multiplier = multiplierConstant ? multiplierConstant
                                                : blockArgs[multiplierArg];
          Value shift = shiftConstant ? shiftConstant : blockArgs[shiftArg];

          // Bridge explicitly unsigned element types into signless integers,
          // since the 'arith' ops below only accept signless operands.
          if (valueTy.isUnsignedInteger()) {
            value = nestedBuilder
                        .create<UnrealizedConversionCastOp>(
                            nestedLoc,
                            nestedBuilder.getIntegerType(
                                valueTy.getIntOrFloatBitWidth()),
                            value)
                        .getResult(0);
          }
          // Widen sub-32-bit values to i32, extending per the op's
          // input_unsigned attribute.
          if (valueTy.getIntOrFloatBitWidth() < 32) {
            if (op.getInputUnsigned()) {
              value = nestedBuilder.create<arith::ExtUIOp>(
                  nestedLoc, nestedBuilder.getI32Type(), value);
            } else {
              value = nestedBuilder.create<arith::ExtSIOp>(
                  nestedLoc, nestedBuilder.getI32Type(), value);
            }
          }

          value =
              nestedBuilder.create<arith::SubIOp>(nestedLoc, value, inputZp);

          value = nestedBuilder.create<tosa::ApplyScaleOp>(
              loc, nestedBuilder.getI32Type(), value, multiplier, shift,
              roundingMode);

          // Move to the new zero-point.
          value =
              nestedBuilder.create<arith::AddIOp>(nestedLoc, value, outputZp);

          // Saturate to the output size.
          int32_t intMin = APInt::getSignedMinValue(outBitWidth).getSExtValue();
          int32_t intMax = APInt::getSignedMaxValue(outBitWidth).getSExtValue();

          // Unsigned integers have a different output value range.
          if (op.getOutputUnsigned()) {
            intMin = 0;
            intMax = APInt::getMaxValue(outBitWidth).getZExtValue();
          }

          auto intMinVal = nestedBuilder.create<arith::ConstantOp>(
              loc, nestedBuilder.getI32IntegerAttr(intMin));
          auto intMaxVal = nestedBuilder.create<arith::ConstantOp>(
              loc, nestedBuilder.getI32IntegerAttr(intMax));
          value = clampIntHelper(nestedLoc, value, intMinVal, intMaxVal,
                                 nestedBuilder, /*isUnsigned=*/false);

          if (outIntType.getWidth() < 32) {
            value = nestedBuilder.create<arith::TruncIOp>(
                nestedLoc, rewriter.getIntegerType(outIntType.getWidth()),
                value);
          }

          // Bridge back to the explicitly unsigned output element type when
          // the result tensor requires it.
          if (outIntType.isUnsignedInteger()) {
            value = nestedBuilder
                        .create<UnrealizedConversionCastOp>(nestedLoc,
                                                            outIntType, value)
                        .getResult(0);
          }
          nestedBuilder.create<linalg::YieldOp>(loc, value);
        });

    rewriter.replaceOp(op, linalgOp->getResults());
    return success();
  }
};
// Handle the resize case where both the input and output are 1x1 images.
// This special case lets us avoid the tensor.extract operations a general
// resize requires, which are much more difficult to optimize away.
class ResizeUnaryConverter : public OpRewritePattern<tosa::ResizeOp> {
public:
  using OpRewritePattern<tosa::ResizeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::ResizeOp op,
                                PatternRewriter &rewriter) const final {
    Location loc = op.getLoc();
    ImplicitLocOpBuilder builder(loc, rewriter);
    auto input = op.getInput();
    auto inputTy = cast<RankedTensorType>(input.getType());
    auto resultTy = cast<RankedTensorType>(op.getType());
    const bool isBilinear = op.getMode() == "BILINEAR";

    auto inputH = inputTy.getDimSize(1);
    auto inputW = inputTy.getDimSize(2);
    auto outputH = resultTy.getDimSize(1);
    auto outputW = resultTy.getDimSize(2);

    if (inputH != 1 || inputW != 1 || outputH != 1 || outputW != 1)
      return rewriter.notifyMatchFailure(
          op, "tosa.resize is not a pure 1x1->1x1 image operation");

    // TODO(suderman): These string values should be declared the TOSA dialect.
    if (op.getMode() != "NEAREST_NEIGHBOR" && op.getMode() != "BILINEAR")
      return rewriter.notifyMatchFailure(
          op, "tosa.resize mode should be NEAREST_NEIGHBOR or BILINEAR");

    // Same element type and identical 1x1 spatial dims: the op is a no-op.
    if (inputTy == resultTy) {
      rewriter.replaceOp(op, input);
      return success();
    }

    SmallVector<int64_t> scale;
    if (!tosa::getConstShapeValues(op.getScale().getDefiningOp(), scale)) {
      return failure();
    }

    // Collapse the unit width and height away: (N, 1, 1, C) -> (N, C).
    SmallVector<ReassociationExprs, 4> reassociationMap(2);
    reassociationMap[0].push_back(builder.getAffineDimExpr(0));
    reassociationMap[1].push_back(builder.getAffineDimExpr(1));
    reassociationMap[1].push_back(builder.getAffineDimExpr(2));
    reassociationMap[1].push_back(builder.getAffineDimExpr(3));
    auto collapseTy =
        RankedTensorType::get({inputTy.getDimSize(0), inputTy.getDimSize(3)},
                              inputTy.getElementType());
    Value collapse = builder.create<tensor::CollapseShapeOp>(collapseTy, input,
                                                             reassociationMap);

    // Get any dynamic shapes that appear in the input format.
    llvm::SmallVector<Value> outputDynSize;
    if (inputTy.isDynamicDim(0))
      outputDynSize.push_back(builder.create<tensor::DimOp>(input, 0));
    if (inputTy.isDynamicDim(3))
      outputDynSize.push_back(builder.create<tensor::DimOp>(input, 3));

    // Generate the elementwise operation for casting scaling the input value.
    auto genericTy = collapseTy.clone(resultTy.getElementType());
    Value empty = builder.create<tensor::EmptyOp>(
        genericTy.getShape(), resultTy.getElementType(), outputDynSize);
    auto genericMap = rewriter.getMultiDimIdentityMap(genericTy.getRank());
    SmallVector<utils::IteratorType> iterators(genericTy.getRank(),
                                               utils::IteratorType::parallel);

    auto generic = builder.create<linalg::GenericOp>(
        genericTy, ValueRange{collapse}, ValueRange{empty},
        ArrayRef<AffineMap>{genericMap, genericMap}, iterators,
        [=](OpBuilder &b, Location loc, ValueRange args) {
          Value value = args[0];
          // This is the quantized case: widen, then pre-multiply by the
          // bilinear scale factors where applicable.
          if (inputTy.getElementType() != resultTy.getElementType()) {
            value =
                b.create<arith::ExtSIOp>(loc, resultTy.getElementType(), value);

            if (isBilinear && scale[0] != 0) {
              Value scaleY = b.create<arith::ConstantOp>(
                  loc, b.getI32IntegerAttr(scale[0]));
              value = b.create<arith::MulIOp>(loc, value, scaleY);
            }

            if (isBilinear && scale[2] != 0) {
              Value scaleX = b.create<arith::ConstantOp>(
                  loc, b.getI32IntegerAttr(scale[2]));
              value = b.create<arith::MulIOp>(loc, value, scaleX);
            }
          }

          b.create<linalg::YieldOp>(loc, value);
        });

    // Expand (N, C) back to the (N, 1, 1, C) result shape.
    rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
        op, resultTy, generic.getResults()[0], reassociationMap);
    return success();
  }
};
// TOSA resize with width or height of 1 may be broadcasted to a wider
// dimension. This is done by materializing a new tosa.resize without
// the broadcasting behavior, and an explicit broadcast afterwards.
class MaterializeResizeBroadcast : public OpRewritePattern<tosa::ResizeOp> {
public:
  using OpRewritePattern<tosa::ResizeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::ResizeOp op,
                                PatternRewriter &rewriter) const final {
    Location loc = op.getLoc();
    ImplicitLocOpBuilder builder(loc, rewriter);
    auto input = op.getInput();
    auto inputTy = dyn_cast<RankedTensorType>(input.getType());
    auto resultTy = dyn_cast<RankedTensorType>(op.getType());

    if (!inputTy || !resultTy)
      return rewriter.notifyMatchFailure(op,
                                         "requires ranked input/output types");

    auto batch = inputTy.getDimSize(0);
    auto channels = inputTy.getDimSize(3);
    auto inputH = inputTy.getDimSize(1);
    auto inputW = inputTy.getDimSize(2);
    auto outputH = resultTy.getDimSize(1);
    auto outputW = resultTy.getDimSize(2);

    // Broadcasting occurs when a unit input dim maps to a larger output dim.
    if ((inputH != 1 || outputH == 1) && (inputW != 1 || outputW == 1))
      return rewriter.notifyMatchFailure(
          op, "tosa.resize has no broadcasting behavior");

    // For any dimension that is broadcastable we generate a width of 1
    // on the output.
    llvm::SmallVector<int64_t> resizeShape;
    resizeShape.push_back(batch);
    resizeShape.push_back(inputH == 1 ? 1 : outputH);
    resizeShape.push_back(inputW == 1 ? 1 : outputW);
    resizeShape.push_back(channels);

    auto resizeTy = resultTy.clone(resizeShape);
    auto resize = builder.create<tosa::ResizeOp>(resizeTy, input, op.getScale(),
                                                 op.getOffset(), op.getBorder(),
                                                 op.getMode());

    // Collapse any unit result dims: a dim that was broadcast (unit-sized in
    // the resize output) is folded into the preceding group.
    SmallVector<ReassociationExprs, 4> reassociationMap(2);
    reassociationMap[0].push_back(builder.getAffineDimExpr(0));
    reassociationMap.back().push_back(builder.getAffineDimExpr(1));
    if (inputH != 1)
      reassociationMap.push_back({});
    reassociationMap.back().push_back(builder.getAffineDimExpr(2));
    if (inputW != 1)
      reassociationMap.push_back({});
    reassociationMap.back().push_back(builder.getAffineDimExpr(3));

    llvm::SmallVector<int64_t> collapseShape = {batch};
    if (inputH != 1)
      collapseShape.push_back(outputH);
    if (inputW != 1)
      collapseShape.push_back(outputW);
    collapseShape.push_back(channels);

    auto collapseTy = resultTy.clone(collapseShape);
    Value collapse = builder.create<tensor::CollapseShapeOp>(collapseTy, resize,
                                                             reassociationMap);

    // Broadcast the collapsed shape to the output result.
    llvm::SmallVector<Value> outputDynSize;
    if (inputTy.isDynamicDim(0))
      outputDynSize.push_back(builder.create<tensor::DimOp>(input, 0));
    if (inputTy.isDynamicDim(3))
      outputDynSize.push_back(builder.create<tensor::DimOp>(input, 3));

    SmallVector<utils::IteratorType> iterators(resultTy.getRank(),
                                               utils::IteratorType::parallel);
    Value empty = builder.create<tensor::EmptyOp>(
        resultTy.getShape(), resultTy.getElementType(), outputDynSize);

    // The input map skips the broadcast dims so a single source row/column
    // is read for every output position along them.
    SmallVector<AffineExpr, 4> inputExprs{rewriter.getAffineDimExpr(0)};
    if (inputH != 1)
      inputExprs.push_back(rewriter.getAffineDimExpr(1));
    if (inputW != 1)
      inputExprs.push_back(rewriter.getAffineDimExpr(2));
    inputExprs.push_back(rewriter.getAffineDimExpr(3));

    auto inputMap = AffineMap::get(resultTy.getRank(), /*symbolCount=*/0,
                                   inputExprs, rewriter.getContext());

    auto outputMap = rewriter.getMultiDimIdentityMap(resultTy.getRank());
    rewriter.replaceOpWithNewOp<linalg::GenericOp>(
        op, resultTy, ValueRange{collapse}, ValueRange{empty},
        ArrayRef<AffineMap>{inputMap, outputMap}, iterators,
        [=](OpBuilder &b, Location loc, ValueRange args) {
          Value value = args[0];
          b.create<linalg::YieldOp>(loc, value);
        });

    return success();
  }
};
// Lowers tosa.resize to a linalg.generic that, for each output element,
// computes the corresponding (possibly fractional) source coordinates and
// reads the input tensor via tensor.extract. Handles both NEAREST_NEIGHBOR
// and BILINEAR modes, in floating-point or quantized-integer arithmetic
// depending on the result element type.
class GenericResizeConverter : public OpRewritePattern<tosa::ResizeOp> {
public:
  using OpRewritePattern<tosa::ResizeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::ResizeOp op,
                                PatternRewriter &rewriter) const final {
    Location loc = op.getLoc();
    ImplicitLocOpBuilder b(loc, rewriter);
    auto input = op.getInput();
    auto inputTy = cast<ShapedType>(input.getType());
    auto resultTy = cast<ShapedType>(op.getType());
    auto resultETy = resultTy.getElementType();

    // f16/f32 results interpolate in floating point; all other element types
    // take the fixed-point integer path below.
    bool floatingPointMode = resultETy.isF16() || resultETy.isF32();
    auto floatTy = resultETy.isF16() ? b.getF16Type() : b.getF32Type();

    // Input layout is NHWC; dims 1 and 2 are the spatial extents.
    auto imageH = inputTy.getShape()[1];
    auto imageW = inputTy.getShape()[2];

    auto dynamicDimsOr =
        checkHasDynamicBatchDims(rewriter, op, {input, op.getOutput()});
    if (!dynamicDimsOr.has_value())
      return rewriter.notifyMatchFailure(
          op, "unable to get dynamic dimensions of tosa.resize");

    if (op.getMode() != "NEAREST_NEIGHBOR" && op.getMode() != "BILINEAR")
      return rewriter.notifyMatchFailure(
          op, "tosa.resize mode should be NEAREST_NEIGHBOR or BILINEAR");

    // The generic op takes no tensor inputs; the body gathers from `input`
    // directly using indices computed from the output position.
    SmallVector<AffineMap, 2> affineMaps = {
        rewriter.getMultiDimIdentityMap(resultTy.getRank())};
    auto emptyTensor = b.create<tensor::EmptyOp>(resultTy.getShape(), resultETy,
                                                 *dynamicDimsOr);
    auto genericOp = b.create<linalg::GenericOp>(
        resultTy, ValueRange({}), ValueRange{emptyTensor}, affineMaps,
        getNParallelLoopsAttrs(resultTy.getRank()));
    Value resize = genericOp.getResult(0);

    // Populate the region of the generic op.
    {
      OpBuilder::InsertionGuard regionGuard(b);
      b.createBlock(&genericOp.getRegion(), genericOp.getRegion().end(),
                    TypeRange({resultETy}), loc);
      Value batch = b.create<linalg::IndexOp>(0);
      Value y = b.create<linalg::IndexOp>(1);
      Value x = b.create<linalg::IndexOp>(2);
      Value channel = b.create<linalg::IndexOp>(3);

      Value zeroI32 =
          b.create<arith::ConstantOp>(b.getZeroAttr(b.getI32Type()));
      Value zeroFp = b.create<arith::ConstantOp>(b.getZeroAttr(floatTy));
      Value hMax = b.create<arith::ConstantOp>(b.getI32IntegerAttr(imageH - 1));
      Value wMax = b.create<arith::ConstantOp>(b.getI32IntegerAttr(imageW - 1));

      Value inY = b.create<arith::IndexCastOp>(b.getI32Type(), y);
      Value inX = b.create<arith::IndexCastOp>(b.getI32Type(), x);

      // scale/offset/border must be compile-time constants so the index
      // arithmetic below can be emitted with constant operands.
      SmallVector<int64_t> scale, offset, border;
      if (!tosa::getConstShapeValues(op.getScale().getDefiningOp(), scale) ||
          !tosa::getConstShapeValues(op.getOffset().getDefiningOp(), offset) ||
          !tosa::getConstShapeValues(op.getBorder().getDefiningOp(), border)) {
        return rewriter.notifyMatchFailure(
            op, "tosa.resize scale/offset/border should have compile time "
                "constant values.");
      }

      Value yScaleN, yScaleD, xScaleN, xScaleD;
      yScaleN = b.create<arith::ConstantOp>(b.getI32IntegerAttr(scale[0]));
      yScaleD = b.create<arith::ConstantOp>(b.getI32IntegerAttr(scale[1]));
      xScaleN = b.create<arith::ConstantOp>(b.getI32IntegerAttr(scale[2]));
      xScaleD = b.create<arith::ConstantOp>(b.getI32IntegerAttr(scale[3]));
      Value yOffset, xOffset, yBorder, xBorder;
      yOffset = b.create<arith::ConstantOp>(b.getI32IntegerAttr(offset[0]));
      xOffset = b.create<arith::ConstantOp>(b.getI32IntegerAttr(offset[1]));
      yBorder = b.create<arith::ConstantOp>(b.getI32IntegerAttr(border[0]));
      xBorder = b.create<arith::ConstantOp>(b.getI32IntegerAttr(border[1]));

      // Compute the ix and dx values for both the X and Y dimensions.
      // Floating-point flavor: the index is the floored quotient and the
      // delta is the fractional remainder divided by scale_n.
      auto getIndexAndDeltaFp = [&](Value &index, Value &delta, Value in,
                                    Value scaleN, Value scaleD, Value offset,
                                    int size, ImplicitLocOpBuilder &b) {
        if (size == 1) {
          // Degenerate axis: always sample index 0, no interpolation.
          index = zeroI32;
          delta = zeroFp;
          return;
        }
        // x = x * scale_d + offset;
        // ix = floor(x / scale_n)
        Value val = b.create<arith::MulIOp>(in, scaleD);
        val = b.create<arith::AddIOp>(val, offset);
        index = b.create<arith::FloorDivSIOp>(val, scaleN);
        // rx = x % scale_n
        // dx = rx / scale_n
        Value r = b.create<arith::RemSIOp>(val, scaleN);
        Value rFp = b.create<arith::SIToFPOp>(floatTy, r);
        Value scaleNfp = b.create<arith::UIToFPOp>(floatTy, scaleN);
        delta = b.create<arith::DivFOp>(rFp, scaleNfp);
      };

      // Compute the ix and dx values for the X and Y dimensions - int case.
      // Integer flavor: delta stays scaled by scale_n so interpolation can
      // be done entirely in fixed point.
      auto getIndexAndDeltaInt = [&](Value &index, Value &delta, Value in,
                                     Value scaleN, Value scaleD, Value offset,
                                     int size, ImplicitLocOpBuilder &b) {
        if (size == 1) {
          index = zeroI32;
          delta = zeroI32;
          return;
        }
        // x = x * scale_d + offset;
        // ix = floor(x / scale_n)
        // dx = x - ix * scale_n;
        Value val = b.create<arith::MulIOp>(in, scaleD);
        val = b.create<arith::AddIOp>(val, offset);
        index = b.create<arith::DivSIOp>(val, scaleN);
        delta = b.create<arith::MulIOp>(index, scaleN);
        delta = b.create<arith::SubIOp>(val, delta);
      };

      Value ix, iy, dx, dy;
      if (floatingPointMode) {
        getIndexAndDeltaFp(iy, dy, inY, yScaleN, yScaleD, yOffset, imageH, b);
        getIndexAndDeltaFp(ix, dx, inX, xScaleN, xScaleD, xOffset, imageW, b);
      } else {
        getIndexAndDeltaInt(iy, dy, inY, yScaleN, yScaleD, yOffset, imageH, b);
        getIndexAndDeltaInt(ix, dx, inX, xScaleN, xScaleD, xOffset, imageW, b);
      }

      if (op.getMode() == "NEAREST_NEIGHBOR") {
        auto one = b.create<arith::ConstantOp>(b.getI32IntegerAttr(1));

        // Round to the nearest source index (delta >= 0.5 rounds up) and
        // clamp into [0, size - 1] before casting back to index type.
        auto getNearestIndexAndClamp = [&](Value val, Value dval, Value scale,
                                           Value max, int size,
                                           ImplicitLocOpBuilder &b) -> Value {
          if (size == 1) {
            return b.create<arith::ConstantIndexOp>(0);
          }

          Value pred;
          if (floatingPointMode) {
            auto h = b.create<arith::ConstantOp>(b.getFloatAttr(floatTy, 0.5f));
            pred = b.create<arith::CmpFOp>(arith::CmpFPredicate::OGE, dval, h);
          } else {
            // Fixed point: 2 * delta >= scale_n is equivalent to
            // delta / scale_n >= 0.5 without a division.
            Value dvalDouble = b.create<arith::ShLIOp>(dval, one);
            pred = b.create<arith::CmpIOp>(arith::CmpIPredicate::sge,
                                           dvalDouble, scale);
          }

          auto offset = b.create<arith::SelectOp>(pred, one, zeroI32);
          val = b.create<arith::AddIOp>(val, offset);
          val = clampIntHelper(loc, val, zeroI32, max, b, /*isUnsigned=*/false);
          return b.create<arith::IndexCastOp>(b.getIndexType(), val);
        };

        iy = getNearestIndexAndClamp(iy, dy, yScaleN, hMax, imageH, b);
        ix = getNearestIndexAndClamp(ix, dx, xScaleN, wMax, imageW, b);

        Value result = b.create<tensor::ExtractOp>(
            input, ValueRange{batch, iy, ix, channel});

        b.create<linalg::YieldOp>(result);
      } else {
        // The mode here must be BILINEAR.
        assert(op.getMode() == "BILINEAR");

        auto oneVal = b.create<arith::ConstantOp>(b.getI32IntegerAttr(1));

        // Clamp both the base index and its +1 neighbor into [0, max], then
        // cast to index type for use with tensor.extract.
        auto getClampedIdxs = [&](Value &val0, Value &val1, int size, Value in,
                                  Value max, ImplicitLocOpBuilder &b) {
          val0 = in;
          val1 = b.create<arith::AddIOp>(val0, oneVal);
          val0 =
              clampIntHelper(loc, val0, zeroI32, max, b, /*isUnsigned=*/false);
          val1 =
              clampIntHelper(loc, val1, zeroI32, max, b, /*isUnsigned=*/false);
          val0 = b.create<arith::IndexCastOp>(b.getIndexType(), val0);
          val1 = b.create<arith::IndexCastOp>(b.getIndexType(), val1);
        };

        // Linalg equivalent to the section below:
        //   int16_t iy0 = apply_max(iy, 0);
        //   int16_t iy1 = apply_min(iy + 1, IH - 1);
        //   int16_t ix0 = apply_max(ix, 0);
        //   int16_t ix1 = apply_min(ix + 1, IW - 1);
        Value x0, x1, y0, y1;
        getClampedIdxs(y0, y1, imageH, iy, hMax, b);
        getClampedIdxs(x0, x1, imageW, ix, wMax, b);

        // Load the four neighbors that contribute to the bilinear blend.
        Value y0x0 = b.create<tensor::ExtractOp>(
            input, ValueRange{batch, y0, x0, channel});
        Value y0x1 = b.create<tensor::ExtractOp>(
            input, ValueRange{batch, y0, x1, channel});
        Value y1x0 = b.create<tensor::ExtractOp>(
            input, ValueRange{batch, y1, x0, channel});
        Value y1x1 = b.create<tensor::ExtractOp>(
            input, ValueRange{batch, y1, x1, channel});

        if (floatingPointMode) {
          auto oneVal =
              b.create<arith::ConstantOp>(b.getFloatAttr(floatTy, 1.0f));
          // Linear blend of val0/val1 by delta; degenerate axes pass val0
          // through untouched.
          auto interpolate = [&](Value val0, Value val1, Value delta,
                                 int inputSize,
                                 ImplicitLocOpBuilder &b) -> Value {
            if (inputSize == 1)
              return val0;
            Value oneMinusDelta = b.create<arith::SubFOp>(oneVal, delta);
            Value mul0 = b.create<arith::MulFOp>(val0, oneMinusDelta);
            Value mul1 = b.create<arith::MulFOp>(val1, delta);
            return b.create<arith::AddFOp>(mul0, mul1);
          };

          // Linalg equivalent to the section below:
          //   topAcc = v00 * (unit_x - dx);
          //   topAcc += v01 * dx;
          Value topAcc = interpolate(y0x0, y0x1, dx, imageW, b);

          // Linalg equivalent to the section below:
          //   bottomAcc = v10 * (unit_x - dx);
          //   bottomAcc += v11 * dx;
          Value bottomAcc = interpolate(y1x0, y1x1, dx, imageW, b);

          // Linalg equivalent to the section below:
          //   result = topAcc * (unit_y - dy) + bottomAcc * dy
          Value result = interpolate(topAcc, bottomAcc, dy, imageH, b);
          b.create<linalg::YieldOp>(result);
        } else {
          // Perform in quantized space.
          y0x0 = b.create<arith::ExtSIOp>(resultETy, y0x0);
          y0x1 = b.create<arith::ExtSIOp>(resultETy, y0x1);
          y1x0 = b.create<arith::ExtSIOp>(resultETy, y1x0);
          y1x1 = b.create<arith::ExtSIOp>(resultETy, y1x1);

          // Widen the deltas and scales to the result width when needed so
          // the fixed-point multiplies below operate on matching types.
          const int64_t deltaBitwidth = dx.getType().getIntOrFloatBitWidth();
          if (resultETy.getIntOrFloatBitWidth() > deltaBitwidth) {
            dx = b.create<arith::ExtSIOp>(resultETy, dx);
            dy = b.create<arith::ExtSIOp>(resultETy, dy);
          }

          Value yScaleNExt = yScaleN;
          Value xScaleNExt = xScaleN;

          const int64_t scaleBitwidth =
              xScaleN.getType().getIntOrFloatBitWidth();
          if (resultETy.getIntOrFloatBitWidth() > scaleBitwidth) {
            yScaleNExt = b.create<arith::ExtSIOp>(resultETy, yScaleN);
            xScaleNExt = b.create<arith::ExtSIOp>(resultETy, xScaleN);
          }

          // Fixed-point blend: weights sum to `scale`, so each interpolation
          // multiplies the accumulated result by scale_n once.
          auto interpolate = [](Value val0, Value val1, Value weight1,
                                Value scale, int inputSize,
                                ImplicitLocOpBuilder &b) -> Value {
            if (inputSize == 1)
              return b.create<arith::MulIOp>(val0, scale);
            Value weight0 = b.create<arith::SubIOp>(scale, weight1);
            Value mul0 = b.create<arith::MulIOp>(val0, weight0);
            Value mul1 = b.create<arith::MulIOp>(val1, weight1);
            return b.create<arith::AddIOp>(mul0, mul1);
          };

          Value topAcc = interpolate(y0x0, y0x1, dx, xScaleNExt, imageW, b);
          Value bottomAcc = interpolate(y1x0, y1x1, dx, xScaleNExt, imageW, b);
          Value result =
              interpolate(topAcc, bottomAcc, dy, yScaleNExt, imageH, b);
          b.create<linalg::YieldOp>(result);
        }
      }
    }

    rewriter.replaceOp(op, resize);
    return success();
  }
};
// At the codegen level any identity operations should be removed. Any cases
// where identity is load-bearing (e.g. cross device computation) should be
// handled before lowering to codegen.
template <typename SrcOp>
class IdentityNConverter : public OpRewritePattern<SrcOp> {
public:
  using OpRewritePattern<SrcOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(SrcOp op,
                                PatternRewriter &rewriter) const final {
    // Identity performs no computation at this level: forward the operands
    // as the replacement values and erase the op.
    ValueRange forwardedOperands = op.getOperation()->getOperands();
    rewriter.replaceOp(op, forwardedOperands);
    return success();
  }
};
template <typename SrcOp>
class ReduceConverter : public OpRewritePattern<SrcOp> {
public:
  using OpRewritePattern<SrcOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(SrcOp reduceOp,
                                PatternRewriter &rewriter) const final {
    // All tosa reduce-* ops share one lowering path; dispatch to the common
    // helper with the op's reduction axis.
    const auto reductionAxis = reduceOp.getAxis();
    return reduceMatchAndRewriteHelper(reduceOp, reductionAxis, rewriter);
  }
};
// Lowers tosa.reverse to a linalg.generic that reads the input with a
// flipped index along the requested axis: out[..., i, ...] is read from
// in[..., size - 1 - i, ...].
class ReverseConverter : public OpRewritePattern<tosa::ReverseOp> {
public:
  using OpRewritePattern<tosa::ReverseOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::ReverseOp op,
                                PatternRewriter &rewriter) const final {
    auto loc = op.getLoc();
    Value input = op.getInput1();
    auto inputTy = cast<ShapedType>(input.getType());
    auto resultTy = cast<ShapedType>(op.getType());
    auto axis = op.getAxis();

    // Collect dynamic extents so the destination tensor can be allocated.
    SmallVector<Value> dynDims;
    for (int i = 0; i < inputTy.getRank(); i++) {
      if (inputTy.isDynamicDim(i)) {
        dynDims.push_back(rewriter.create<tensor::DimOp>(loc, input, i));
      }
    }

    // Runtime size of the reversed axis, needed to compute size - 1 - index.
    Value axisDimSize = rewriter.create<tensor::DimOp>(loc, input, axis);

    // First fill the output buffer with the init value.
    auto emptyTensor = rewriter
                           .create<tensor::EmptyOp>(loc, inputTy.getShape(),
                                                    inputTy.getElementType(),
                                                    ArrayRef<Value>({dynDims}))
                           .getResult();
    SmallVector<AffineMap, 2> affineMaps = {
        rewriter.getMultiDimIdentityMap(resultTy.getRank())};

    rewriter.replaceOpWithNewOp<linalg::GenericOp>(
        op, resultTy, ArrayRef<Value>({}), ValueRange{emptyTensor}, affineMaps,
        getNParallelLoopsAttrs(resultTy.getRank()),
        [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
          // Build the gather coordinates, flipping only the reversed axis.
          llvm::SmallVector<Value> indices;
          for (unsigned int i = 0; i < inputTy.getRank(); i++) {
            Value index =
                rewriter.create<linalg::IndexOp>(nestedLoc, i).getResult();
            if (i == axis) {
              // index := axisDimSize - 1 - index
              auto one = rewriter.create<arith::ConstantIndexOp>(nestedLoc, 1);
              auto sizeMinusOne =
                  rewriter.create<arith::SubIOp>(nestedLoc, axisDimSize, one);
              index = rewriter.create<arith::SubIOp>(nestedLoc, sizeMinusOne,
                                                     index);
            }

            indices.push_back(index);
          }

          auto extract = nestedBuilder.create<tensor::ExtractOp>(
              nestedLoc, input, indices);
          nestedBuilder.create<linalg::YieldOp>(op.getLoc(),
                                                extract.getResult());
        });
    return success();
  }
};
// This converter translates a tile operation to a reshape, broadcast, reshape.
// The first reshape minimally expands each tiled dimension to include a
// preceding size-1 dim. This dim is then broadcasted to the appropriate
// multiple.
struct TileConverter : public OpConversionPattern<tosa::TileOp> {
  using OpConversionPattern<tosa::TileOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(tosa::TileOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    auto input = op.getInput1();
    auto inputTy = cast<ShapedType>(input.getType());
    auto inputShape = inputTy.getShape();
    auto resultTy = cast<ShapedType>(op.getType());
    auto elementTy = inputTy.getElementType();
    int64_t rank = inputTy.getRank();

    // Tile multiples must be compile-time constants (-1 marks a dynamic
    // multiple below).
    SmallVector<int64_t> multiples;
    if (failed(op.getConstantMultiples(multiples)))
      return failure();

    // Broadcast the newly added dimensions to their appropriate multiple.
    // Interleave each input dim with a leading dim carrying its multiple:
    // [m0, s0, m1, s1, ...].
    SmallVector<int64_t, 2> genericShape;
    for (int i = 0; i < rank; i++) {
      int64_t dim = multiples[i];
      genericShape.push_back(dim == -1 ? ShapedType::kDynamic : dim);
      genericShape.push_back(inputShape[i]);
    }

    SmallVector<Value> dynDims;
    for (int i = 0; i < inputTy.getRank(); i++) {
      if (inputTy.isDynamicDim(i) || multiples[i] == -1) {
        dynDims.push_back(rewriter.create<tensor::DimOp>(loc, input, i));
      }
    }

    auto emptyTensor = rewriter.create<tensor::EmptyOp>(
        op.getLoc(), genericShape, elementTy, dynDims);

    // We need to map the input shape to the non-broadcasted dimensions,
    // i.e. the odd positions of the interleaved shape.
    SmallVector<AffineExpr, 4> dimExprs;
    dimExprs.reserve(rank);
    for (unsigned i = 0; i < rank; ++i)
      dimExprs.push_back(rewriter.getAffineDimExpr(i * 2 + 1));

    auto readAffineMap =
        AffineMap::get(/*dimCount=*/rank * 2, /*symbolCount=*/0, dimExprs,
                       rewriter.getContext());

    SmallVector<AffineMap, 2> affineMaps = {
        readAffineMap, rewriter.getMultiDimIdentityMap(genericShape.size())};

    // The generic op realizes the broadcast: each output position simply
    // yields the value read through the non-identity input map.
    auto genericOp = rewriter.create<linalg::GenericOp>(
        loc, RankedTensorType::get(genericShape, elementTy), input,
        ValueRange{emptyTensor}, affineMaps,
        getNParallelLoopsAttrs(genericShape.size()),
        [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
          nestedBuilder.create<linalg::YieldOp>(op.getLoc(), *args.begin());
        });

    // Collapse the interleaved shape down to the final tiled result shape.
    auto shapeValue = getTosaConstShape(
        rewriter, loc, mlir::tosa::convertFromMlirShape(resultTy.getShape()));
    rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
        op, resultTy, genericOp.getResult(0), shapeValue);
    return success();
  }
};
// Tosa argmax lowering represents the ArgMax op as a linalg.indexed_generic
// op, producing two output buffers.
//
// The first output buffer contains the index of the found maximum value. It is
// initialized to 0 and is of the resulting integer type.
//
// The second output buffer contains the maximum value found. It is initialized
// to the minimum representable value of the input element type. After being
// populated by indexed_generic, this buffer is discarded as only the index is
// requested.
//
// The indexed_generic op updates both the maximum value and index if the
// current value exceeds the running max.
class ArgMaxConverter : public OpRewritePattern<tosa::ArgMaxOp> {
public:
  using OpRewritePattern<tosa::ArgMaxOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::ArgMaxOp argmaxOp,
                                PatternRewriter &rewriter) const final {
    auto loc = argmaxOp.getLoc();
    Value input = argmaxOp.getInput();
    auto inputTy = cast<ShapedType>(input.getType());
    auto resultTy = cast<ShapedType>(argmaxOp.getOutput().getType());
    auto inElementTy = inputTy.getElementType();
    auto outElementTy = resultTy.getElementType();
    int axis = argmaxOp.getAxis();
    // Same shape as the index result, but holding the running max values.
    auto resultMaxTy = RankedTensorType::get(resultTy.getShape(), inElementTy);

    if (!isa<IntegerType>(outElementTy))
      return rewriter.notifyMatchFailure(
          argmaxOp,
          "tosa.arg_max to linalg.* requires integer-like result type");

    // Dynamic extents of all non-reduced dims, needed for the output tensors.
    SmallVector<Value> dynDims;
    for (int i = 0; i < inputTy.getRank(); i++) {
      if (inputTy.isDynamicDim(i) && i != axis) {
        dynDims.push_back(rewriter.create<tensor::DimOp>(loc, input, i));
      }
    }

    // First fill the output buffer for the index.
    auto emptyTensorIdx = rewriter
                              .create<tensor::EmptyOp>(loc, resultTy.getShape(),
                                                       outElementTy, dynDims)
                              .getResult();
    auto fillValueIdx = rewriter.create<arith::ConstantOp>(
        loc, rewriter.getIntegerAttr(outElementTy, 0));
    auto filledTensorIdx =
        rewriter
            .create<linalg::FillOp>(loc, ValueRange{fillValueIdx},
                                    ValueRange{emptyTensorIdx})
            .result();

    // Second fill the output buffer for the running max.
    auto emptyTensorMax = rewriter
                              .create<tensor::EmptyOp>(loc, resultTy.getShape(),
                                                       inElementTy, dynDims)
                              .getResult();
    auto fillValueMaxAttr =
        createInitialValueForReduceOp(argmaxOp, inElementTy, rewriter);

    if (!fillValueMaxAttr)
      return rewriter.notifyMatchFailure(
          argmaxOp, "unsupported tosa.argmax element type");

    auto fillValueMax =
        rewriter.create<arith::ConstantOp>(loc, fillValueMaxAttr);
    auto filledTensorMax =
        rewriter
            .create<linalg::FillOp>(loc, ValueRange{fillValueMax},
                                    ValueRange{emptyTensorMax})
            .result();

    // We need to reduce along the arg-max axis, with parallel operations along
    // the rest.
    SmallVector<utils::IteratorType, 4> iteratorTypes;
    iteratorTypes.resize(inputTy.getRank(), utils::IteratorType::parallel);
    iteratorTypes[axis] = utils::IteratorType::reduction;

    SmallVector<AffineExpr, 2> srcExprs;
    SmallVector<AffineExpr, 2> dstExprs;
    for (int i = 0, rank = inputTy.getRank(); i != rank; ++i) {
      srcExprs.push_back(mlir::getAffineDimExpr(i, rewriter.getContext()));
      if (axis != i)
        dstExprs.push_back(mlir::getAffineDimExpr(i, rewriter.getContext()));
    }

    bool didEncounterError = false;
    auto maps = AffineMap::inferFromExprList({srcExprs, dstExprs, dstExprs},
                                             rewriter.getContext());
    auto linalgOp = rewriter.create<linalg::GenericOp>(
        loc, ArrayRef<Type>({resultTy, resultMaxTy}), input,
        ValueRange({filledTensorIdx, filledTensorMax}), maps, iteratorTypes,
        [&](OpBuilder &nestedBuilder, Location nestedLoc,
            ValueRange blockArgs) {
          auto newValue = blockArgs[0];
          auto oldIndex = blockArgs[1];
          auto oldValue = blockArgs[2];

          // Current position along the reduced axis, cast to the index
          // element type of the result.
          Value newIndex = rewriter.create<arith::IndexCastOp>(
              nestedLoc, oldIndex.getType(),
              rewriter.create<linalg::IndexOp>(loc, axis));

          Value predicate;
          if (isa<FloatType>(inElementTy)) {
            if (argmaxOp.getNanMode() == "IGNORE") {
              // Only update index & max value for non NaN values. If all
              // values are NaNs, the initial index will be returned, which
              // is 0.
              predicate = rewriter.create<arith::CmpFOp>(
                  nestedLoc, arith::CmpFPredicate::OGT, newValue, oldValue);
            } else {
              // Update max value if either of the following is true:
              // - new value is bigger
              // - cur max is not NaN and new value is NaN
              Value gt = rewriter.create<arith::CmpFOp>(
                  nestedLoc, arith::CmpFPredicate::UGT, newValue, oldValue);
              Value oldNonNaN = rewriter.create<arith::CmpFOp>(
                  nestedLoc, arith::CmpFPredicate::ORD, oldValue, oldValue);
              predicate = rewriter.create<arith::AndIOp>(
                  nestedLoc, rewriter.getI1Type(), gt, oldNonNaN);
            }
          } else if (isa<IntegerType>(inElementTy)) {
            predicate = rewriter.create<arith::CmpIOp>(
                nestedLoc, arith::CmpIPredicate::sgt, newValue, oldValue);
          } else {
            // Cannot fail from inside the region builder; record the error
            // and report it after the op is constructed.
            didEncounterError = true;
            return;
          }

          auto resultMax = rewriter.create<arith::SelectOp>(
              nestedLoc, predicate, newValue, oldValue);
          auto resultIndex = rewriter.create<arith::SelectOp>(
              nestedLoc, predicate, newIndex, oldIndex);
          nestedBuilder.create<linalg::YieldOp>(
              nestedLoc, ValueRange({resultIndex, resultMax}));
        });

    if (didEncounterError)
      return rewriter.notifyMatchFailure(
          argmaxOp, "unsupported tosa.argmax element type");

    // Only the index output is requested; the max-value buffer is dropped.
    rewriter.replaceOp(argmaxOp, linalgOp.getResult(0));
    return success();
  }
};
// Lowers tosa.gather to a linalg.generic that extracts
// values[batch, indices[batch, i], channel] for every output element.
class GatherConverter : public OpConversionPattern<tosa::GatherOp> {
public:
  using OpConversionPattern<tosa::GatherOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(tosa::GatherOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    Value valuesIn = adaptor.getOperands()[0];
    Value indicesIn = adaptor.getOperands()[1];

    auto valuesTy = dyn_cast<RankedTensorType>(op.getValues().getType());
    auto resultTy = dyn_cast<RankedTensorType>(op.getType());
    if (!valuesTy || !resultTy)
      return rewriter.notifyMatchFailure(op, "unranked tensors not supported");

    auto dynamicDims = inferDynamicDimsForGather(
        rewriter, op.getLoc(), adaptor.getValues(), adaptor.getIndices());

    Location loc = op.getLoc();
    Type resultElementTy = resultTy.getElementType();

    // Destination tensor receiving the gathered values.
    Value destTensor = rewriter
                           .create<tensor::EmptyOp>(loc, resultTy.getShape(),
                                                    resultElementTy,
                                                    dynamicDims)
                           .getResult();

    // The indices operand is indexed by (batch, index); the output uses the
    // identity map over (batch, index, channel).
    SmallVector<AffineMap, 2> affineMaps = {
        AffineMap::get(
            /*dimCount=*/resultTy.getRank(), /*symbolCount=*/0,
            {rewriter.getAffineDimExpr(0), rewriter.getAffineDimExpr(1)},
            rewriter.getContext()),
        rewriter.getMultiDimIdentityMap(resultTy.getRank())};

    auto gatherGeneric = rewriter.create<linalg::GenericOp>(
        loc, ArrayRef<Type>({resultTy}), ValueRange{indicesIn},
        ValueRange{destTensor}, affineMaps,
        getNParallelLoopsAttrs(resultTy.getRank()),
        [&](OpBuilder &b, Location nestedLoc, ValueRange args) {
          // Read values[batch, indices[batch, i], channel].
          Value idxValue = args[0];
          Value batchIdx = rewriter.create<linalg::IndexOp>(nestedLoc, 0);
          Value gatherIdx = rewriter.create<arith::IndexCastOp>(
              nestedLoc, rewriter.getIndexType(), idxValue);
          Value channelIdx = rewriter.create<linalg::IndexOp>(nestedLoc, 2);
          Value loaded = rewriter.create<tensor::ExtractOp>(
              nestedLoc, valuesIn,
              ValueRange{batchIdx, gatherIdx, channelIdx});
          rewriter.create<linalg::YieldOp>(nestedLoc, loaded);
        });
    rewriter.replaceOp(op, gatherGeneric.getResult(0));
    return success();
  }

  // Collects the dynamic extents of the result: batch and channels come from
  // `values` (dims 0 and 2), the gathered count from `indices` (dim 1).
  static llvm::SmallVector<Value> inferDynamicDimsForGather(OpBuilder &builder,
                                                            Location loc,
                                                            Value values,
                                                            Value indices) {
    llvm::SmallVector<Value> dynSizes;
    auto appendIfDynamic = [&](Value source, int64_t dim) {
      OpFoldResult size = tensor::getMixedSize(builder, loc, source, dim);
      if (auto dynSize = llvm::dyn_cast_if_present<Value>(size))
        dynSizes.push_back(dynSize);
    };
    appendIfDynamic(values, 0);
    appendIfDynamic(indices, 1);
    appendIfDynamic(values, 2);
    return dynSizes;
  }
};
// Lowers the TableOp to a series of gathers and numeric operations. This
// includes interpolation between the high/low values. For the I8 variant, this
// simplifies to a single gather operation.
class TableConverter : public OpRewritePattern<tosa::TableOp> {
public:
  using OpRewritePattern<tosa::TableOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::TableOp op,
                                PatternRewriter &rewriter) const final {
    auto loc = op.getLoc();
    Value input = op.getInput1();
    Value table = op.getTable();
    auto inputTy = cast<ShapedType>(input.getType());
    auto tableTy = cast<ShapedType>(table.getType());
    auto resultTy = cast<ShapedType>(op.getType());

    auto inputElementTy = inputTy.getElementType();
    auto tableElementTy = tableTy.getElementType();
    auto resultElementTy = resultTy.getElementType();

    // Dynamic extents of the result, taken from the input operand.
    SmallVector<Value> dynDims;
    for (int i = 0; i < resultTy.getRank(); ++i) {
      if (inputTy.isDynamicDim(i)) {
        dynDims.push_back(
            rewriter.create<tensor::DimOp>(loc, op.getOperand(0), i));
      }
    }

    auto emptyTensor = rewriter
                           .create<tensor::EmptyOp>(loc, resultTy.getShape(),
                                                    resultElementTy, dynDims)
                           .getResult();

    SmallVector<AffineMap, 2> affineMaps = {
        rewriter.getMultiDimIdentityMap(resultTy.getRank()),
        rewriter.getMultiDimIdentityMap(resultTy.getRank())};

    auto genericOp = rewriter.create<linalg::GenericOp>(
        loc, resultTy, ValueRange({input}), ValueRange{emptyTensor}, affineMaps,
        getNParallelLoopsAttrs(resultTy.getRank()));
    rewriter.replaceOp(op, genericOp.getResult(0));

    // Build the region body; unsupported type combinations fall through to
    // the notifyMatchFailure below.
    {
      OpBuilder::InsertionGuard regionGuard(rewriter);
      Block *block = rewriter.createBlock(
          &genericOp.getRegion(), genericOp.getRegion().end(),
          TypeRange({inputElementTy, resultElementTy}), {loc, loc});

      auto inputValue = block->getArgument(0);
      rewriter.setInsertionPointToStart(block);
      if (inputElementTy.isInteger(8) && tableElementTy.isInteger(8) &&
          resultElementTy.isInteger(8)) {
        // i8 case: a single gather. Bias the signed input by 128 so it
        // indexes the 256-entry table.
        Value index = rewriter.create<arith::IndexCastOp>(
            loc, rewriter.getIndexType(), inputValue);
        Value offset = rewriter.create<arith::ConstantIndexOp>(loc, 128);
        index = rewriter.create<arith::AddIOp>(loc, rewriter.getIndexType(),
                                               index, offset);
        Value extract =
            rewriter.create<tensor::ExtractOp>(loc, table, ValueRange{index});
        rewriter.create<linalg::YieldOp>(loc, extract);
        return success();
      }

      if (inputElementTy.isInteger(16) && tableElementTy.isInteger(16) &&
          resultElementTy.isInteger(32)) {
        // i16 -> i32 case: gather the two surrounding table entries and
        // interpolate using the low 7 bits of the biased input.
        Value extend = rewriter.create<arith::ExtSIOp>(
            loc, rewriter.getI32Type(), inputValue);

        auto offset = rewriter.create<arith::ConstantOp>(
            loc, rewriter.getI32IntegerAttr(32768));
        auto seven = rewriter.create<arith::ConstantOp>(
            loc, rewriter.getI32IntegerAttr(7));
        auto one = rewriter.create<arith::ConstantOp>(
            loc, rewriter.getI32IntegerAttr(1));
        auto b1111111 = rewriter.create<arith::ConstantOp>(
            loc, rewriter.getI32IntegerAttr(127));

        // Compute the index and fractional part from the input value:
        //   value = value + 32768
        //   index = value >> 7;
        //   fraction = value & 0x7F  (low 7 bits)
        auto extendAdd = rewriter.create<arith::AddIOp>(loc, extend, offset);
        Value index = rewriter.create<arith::ShRUIOp>(loc, extendAdd, seven);
        Value fraction =
            rewriter.create<arith::AndIOp>(loc, extendAdd, b1111111);

        // Extract the base and next values from the table.
        //   base = (int32_t) table[index];
        //   next = (int32_t) table[index + 1];
        Value indexPlusOne = rewriter.create<arith::AddIOp>(loc, index, one);

        index = rewriter.create<arith::IndexCastOp>(
            loc, rewriter.getIndexType(), index);
        indexPlusOne = rewriter.create<arith::IndexCastOp>(
            loc, rewriter.getIndexType(), indexPlusOne);

        Value base =
            rewriter.create<tensor::ExtractOp>(loc, table, ValueRange{index});
        Value next = rewriter.create<tensor::ExtractOp>(
            loc, table, ValueRange{indexPlusOne});

        base =
            rewriter.create<arith::ExtSIOp>(loc, rewriter.getI32Type(), base);
        next =
            rewriter.create<arith::ExtSIOp>(loc, rewriter.getI32Type(), next);

        // Use the fractional part to interpolate between the input values:
        //   result = (base << 7) + (next - base) * fraction
        Value baseScaled = rewriter.create<arith::ShLIOp>(loc, base, seven);
        Value diff = rewriter.create<arith::SubIOp>(loc, next, base);
        Value diffScaled = rewriter.create<arith::MulIOp>(loc, diff, fraction);
        Value result =
            rewriter.create<arith::AddIOp>(loc, baseScaled, diffScaled);
        rewriter.create<linalg::YieldOp>(loc, result);
        return success();
      }
    }

    return rewriter.notifyMatchFailure(
        op, "unable to create body for tosa.table op");
  }
};
// Lowers tosa.rfft2d by materializing the DFT definition as a linalg.generic:
// for each output position (batch, oy, ox) the reduction dims (iy, ix)
// accumulate valReal * cos(angle) into the real output and subtract
// valReal * sin(angle) from the imaginary output.
struct RFFT2dConverter final : public OpRewritePattern<RFFT2dOp> {
  using OpRewritePattern<RFFT2dOp>::OpRewritePattern;

  static bool isRankedTensor(Type type) { return isa<RankedTensorType>(type); }

  // Computes (ofr / 2) + 1 as an OpFoldResult — the reduced width of the
  // RFFT output.
  static OpFoldResult halfPlusOne(OpBuilder &builder, Location loc,
                                  OpFoldResult ofr) {
    auto one = builder.create<arith::ConstantIndexOp>(loc, 1);
    auto two = builder.create<arith::ConstantIndexOp>(loc, 2);

    auto value = getValueOrCreateConstantIndexOp(builder, loc, ofr);
    auto divBy2 = builder.createOrFold<arith::DivUIOp>(loc, value, two);
    auto plusOne = builder.createOrFold<arith::AddIOp>(loc, divBy2, one);
    return getAsOpFoldResult(plusOne);
  }

  // Derives the [N, H, W/2 + 1] result type from the input, collecting any
  // dynamic extents into `dynamicSizes`.
  static RankedTensorType
  computeOutputShape(OpBuilder &builder, Location loc, Value input,
                     llvm::SmallVectorImpl<Value> &dynamicSizes) {
    // Get [N, H, W]
    auto dims = tensor::getMixedSizes(builder, loc, input);

    // Set W = (W / 2) + 1 to account for the half-sized W dimension of the
    // output tensors.
    dims[2] = halfPlusOne(builder, loc, dims[2]);

    llvm::SmallVector<int64_t, 3> staticSizes;
    dispatchIndexOpFoldResults(dims, dynamicSizes, staticSizes);

    auto elementType = cast<RankedTensorType>(input.getType()).getElementType();
    return RankedTensorType::get(staticSizes, elementType);
  }

  // Materializes a zero-filled tensor of `type` — the accumulator init for
  // the reduction.
  static Value createZeroTensor(PatternRewriter &rewriter, Location loc,
                                RankedTensorType type,
                                llvm::ArrayRef<Value> dynamicSizes) {
    auto emptyTensor =
        rewriter.create<tensor::EmptyOp>(loc, type, dynamicSizes);
    auto fillValueAttr = rewriter.getZeroAttr(type.getElementType());
    auto fillValue = rewriter.create<arith::ConstantOp>(loc, fillValueAttr);
    auto filledTensor = rewriter
                            .create<linalg::FillOp>(loc, ValueRange{fillValue},
                                                    ValueRange{emptyTensor})
                            .result();
    return filledTensor;
  }

  // Casts an index value to the float `type`, going through an unsigned
  // integer of sufficient width (i64 when the float is wider than 32 bits).
  static Value castIndexToFloat(OpBuilder &builder, Location loc,
                                FloatType type, Value value) {
    auto integerVal = builder.create<arith::IndexCastUIOp>(
        loc,
        type.getIntOrFloatBitWidth() > 32 ? builder.getI64Type()
                                          : builder.getI32Type(),
        value);

    return builder.create<arith::UIToFPOp>(loc, type, integerVal);
  }

  // linalg.index for loop dimension `index`, converted to float `type`.
  static Value createLinalgIndex(OpBuilder &builder, Location loc,
                                 FloatType type, int64_t index) {
    auto indexVal = builder.create<linalg::IndexOp>(loc, index);
    return castIndexToFloat(builder, loc, type, indexVal);
  }

  // Shorthand for building a list of affine dim expressions.
  template <typename... Args>
  static llvm::SmallVector<AffineExpr, 4> affineDimsExpr(OpBuilder &builder,
                                                         Args... args) {
    return {builder.getAffineDimExpr(args)...};
  }

  LogicalResult matchAndRewrite(RFFT2dOp rfft2d,
                                PatternRewriter &rewriter) const override {
    if (!llvm::all_of(rfft2d->getOperandTypes(), isRankedTensor) ||
        !llvm::all_of(rfft2d->getResultTypes(), isRankedTensor)) {
      return rewriter.notifyMatchFailure(rfft2d,
                                         "only supports ranked tensors");
    }

    auto loc = rfft2d.getLoc();
    auto input = rfft2d.getInputReal();
    auto elementType =
        dyn_cast<FloatType>(cast<ShapedType>(input.getType()).getElementType());
    if (!elementType)
      return rewriter.notifyMatchFailure(rfft2d,
                                         "only supports float element types");

    // Compute the output type and set of dynamic sizes
    llvm::SmallVector<Value> dynamicSizes;
    auto outputType = computeOutputShape(rewriter, loc, input, dynamicSizes);

    // Iterator types for the linalg.generic implementation: the first three
    // dims (batch, oy, ox) are parallel, the last two (iy, ix) reduce.
    llvm::SmallVector<utils::IteratorType, 5> iteratorTypes = {
        utils::IteratorType::parallel, utils::IteratorType::parallel,
        utils::IteratorType::parallel, utils::IteratorType::reduction,
        utils::IteratorType::reduction};

    // Inputs/outputs to the linalg.generic implementation
    llvm::SmallVector<Value> genericOpInputs = {input};
    llvm::SmallVector<Value> genericOpOutputs = {
        createZeroTensor(rewriter, loc, outputType, dynamicSizes),
        createZeroTensor(rewriter, loc, outputType, dynamicSizes)};

    // Indexing maps for input and output tensors
    auto indexingMaps = AffineMap::inferFromExprList(
        llvm::ArrayRef{affineDimsExpr(rewriter, 0, 3, 4),
                       affineDimsExpr(rewriter, 0, 1, 2),
                       affineDimsExpr(rewriter, 0, 1, 2)},
        rewriter.getContext());

    // Width and height dimensions of the original input.
    auto dimH = rewriter.createOrFold<tensor::DimOp>(loc, input, 1);
    auto dimW = rewriter.createOrFold<tensor::DimOp>(loc, input, 2);

    // Constants and dimension sizes
    auto twoPiAttr = rewriter.getFloatAttr(elementType, 6.283185307179586);
    auto twoPi = rewriter.create<arith::ConstantOp>(loc, twoPiAttr);
    auto constH = castIndexToFloat(rewriter, loc, elementType, dimH);
    auto constW = castIndexToFloat(rewriter, loc, elementType, dimW);

    auto buildBody = [&](OpBuilder &builder, Location loc, ValueRange args) {
      Value valReal = args[0];
      Value sumReal = args[1];
      Value sumImag = args[2];

      // Indices for angle computation
      Value oy = builder.create<linalg::IndexOp>(loc, 1);
      Value ox = builder.create<linalg::IndexOp>(loc, 2);
      Value iy = builder.create<linalg::IndexOp>(loc, 3);
      Value ix = builder.create<linalg::IndexOp>(loc, 4);

      // Calculating angle without integer parts of components as sin/cos are
      // periodic: angle = 2 * pi() * ( ( (iy * oy) % H) / H + ( (ix * ox) % W )
      // / W);
      auto iyXoy = builder.create<index::MulOp>(loc, iy, oy);
      auto ixXox = builder.create<index::MulOp>(loc, ix, ox);

      auto iyRem = builder.create<index::RemUOp>(loc, iyXoy, dimH);
      auto ixRem = builder.create<index::RemUOp>(loc, ixXox, dimW);

      auto iyRemFloat = castIndexToFloat(builder, loc, elementType, iyRem);
      auto ixRemFloat = castIndexToFloat(builder, loc, elementType, ixRem);

      auto yComponent = builder.create<arith::DivFOp>(loc, iyRemFloat, constH);
      auto xComponent = builder.create<arith::DivFOp>(loc, ixRemFloat, constW);
      auto sumXY = builder.create<arith::AddFOp>(loc, yComponent, xComponent);
      auto angle = builder.create<arith::MulFOp>(loc, twoPi, sumXY);

      // realComponent = valReal * cos(angle)
      // imagComponent = valReal * sin(angle)
      auto cosAngle = builder.create<math::CosOp>(loc, angle);
      auto sinAngle = builder.create<math::SinOp>(loc, angle);
      auto realComponent =
          builder.create<arith::MulFOp>(loc, valReal, cosAngle);
      auto imagComponent =
          builder.create<arith::MulFOp>(loc, valReal, sinAngle);

      // outReal = sumReal + realComponent
      // outImag = sumImag - imagComponent
      auto outReal = builder.create<arith::AddFOp>(loc, sumReal, realComponent);
      auto outImag = builder.create<arith::SubFOp>(loc, sumImag, imagComponent);

      builder.create<linalg::YieldOp>(loc, ValueRange{outReal, outImag});
    };

    rewriter.replaceOpWithNewOp<linalg::GenericOp>(
        rfft2d, rfft2d.getResultTypes(), genericOpInputs, genericOpOutputs,
        indexingMaps, iteratorTypes, buildBody);

    return success();
  }
};
struct FFT2dConverter final : OpRewritePattern<FFT2dOp> {
  using OpRewritePattern::OpRewritePattern;

  /// Lower tosa.fft2d to a linalg.generic that accumulates the 2-D DFT sums.
  /// The generic iterates (batch, out_y, out_x) in parallel and reduces over
  /// (in_y, in_x); the `inverse` attribute flips the sign of the exponent.
  LogicalResult matchAndRewrite(FFT2dOp fft2d,
                                PatternRewriter &rewriter) const override {
    const bool allRanked =
        llvm::all_of(fft2d->getOperandTypes(),
                     RFFT2dConverter::isRankedTensor) &&
        llvm::all_of(fft2d->getResultTypes(), RFFT2dConverter::isRankedTensor);
    if (!allRanked)
      return rewriter.notifyMatchFailure(fft2d, "only supports ranked tensors");

    Location loc = fft2d.getLoc();
    Value realInput = fft2d.getInputReal();
    Value imagInput = fft2d.getInputImag();
    BoolAttr inverse = fft2d.getInverseAttr();

    auto elemTy = cast<FloatType>(
        cast<ShapedType>(realInput.getType()).getElementType());
    [[maybe_unused]] auto imagElemTy = cast<FloatType>(
        cast<ShapedType>(imagInput.getType()).getElementType());
    assert(elemTy == imagElemTy);

    // Derive the accumulator type along with any dynamic dimension operands
    // needed to materialize the zero-initialized output tensors.
    SmallVector<Value> dynamicSizes;
    SmallVector<int64_t, 3> staticSizes;
    // Mixed [N, H, W] sizes of the (real) input.
    auto mixedSizes = tensor::getMixedSizes(rewriter, loc, realInput);
    dispatchIndexOpFoldResults(mixedSizes, dynamicSizes, staticSizes);
    auto accType = RankedTensorType::get(staticSizes, elemTy);

    // (batch, out_y, out_x) are parallel; (in_y, in_x) are reduced.
    SmallVector<utils::IteratorType, 5> iteratorTypes = {
        utils::IteratorType::parallel, utils::IteratorType::parallel,
        utils::IteratorType::parallel, utils::IteratorType::reduction,
        utils::IteratorType::reduction};

    SmallVector<Value> genericOpInputs = {realInput, imagInput};
    SmallVector<Value> genericOpOutputs = {
        RFFT2dConverter::createZeroTensor(rewriter, loc, accType, dynamicSizes),
        RFFT2dConverter::createZeroTensor(rewriter, loc, accType,
                                          dynamicSizes)};

    // Inputs are addressed by (batch, in_y, in_x); outputs by
    // (batch, out_y, out_x).
    auto indexingMaps = AffineMap::inferFromExprList(
        ArrayRef{RFFT2dConverter::affineDimsExpr(rewriter, 0, 3, 4),
                 RFFT2dConverter::affineDimsExpr(rewriter, 0, 3, 4),
                 RFFT2dConverter::affineDimsExpr(rewriter, 0, 1, 2),
                 RFFT2dConverter::affineDimsExpr(rewriter, 0, 1, 2)},
        rewriter.getContext());

    // Height and width of the input, needed for the twiddle-factor angle.
    auto dimH = rewriter.createOrFold<tensor::DimOp>(loc, realInput, 1);
    auto dimW = rewriter.createOrFold<tensor::DimOp>(loc, realInput, 2);

    auto twoPiAttr = rewriter.getFloatAttr(elemTy, 6.283185307179586);
    auto twoPi = rewriter.create<arith::ConstantOp>(loc, twoPiAttr);
    Value constH =
        RFFT2dConverter::castIndexToFloat(rewriter, loc, elemTy, dimH);
    Value constW =
        RFFT2dConverter::castIndexToFloat(rewriter, loc, elemTy, dimW);

    auto buildBody = [&](OpBuilder &builder, Location loc, ValueRange args) {
      Value valReal = args[0];
      Value valImag = args[1];
      Value sumReal = args[2];
      Value sumImag = args[3];

      // Spatial indices used to compute the angle.
      Value oy = builder.create<linalg::IndexOp>(loc, 1);
      Value ox = builder.create<linalg::IndexOp>(loc, 2);
      Value iy = builder.create<linalg::IndexOp>(loc, 3);
      Value ix = builder.create<linalg::IndexOp>(loc, 4);

      // angle = sign * 2*pi * (((iy * oy) mod H) / H + ((ix * ox) mod W) / W)
      // The integer parts are dropped via the remainders; sin/cos are 2*pi
      // periodic, so the result is unchanged while the products stay small.
      auto iyXoy = builder.create<index::MulOp>(loc, iy, oy);
      auto ixXox = builder.create<index::MulOp>(loc, ix, ox);
      auto iyRem = builder.create<index::RemUOp>(loc, iyXoy, dimH);
      auto ixRem = builder.create<index::RemUOp>(loc, ixXox, dimW);
      auto iyRemFloat =
          RFFT2dConverter::castIndexToFloat(builder, loc, elemTy, iyRem);
      auto ixRemFloat =
          RFFT2dConverter::castIndexToFloat(builder, loc, elemTy, ixRem);
      auto yComponent = builder.create<arith::DivFOp>(loc, iyRemFloat, constH);
      auto xComponent = builder.create<arith::DivFOp>(loc, ixRemFloat, constW);
      auto sumXY = builder.create<arith::AddFOp>(loc, yComponent, xComponent);
      Value angle = builder.create<arith::MulFOp>(loc, twoPi, sumXY);
      if (inverse.getValue()) {
        // The inverse transform negates the exponent's sign.
        Value negOne = rewriter.create<arith::ConstantOp>(
            loc, rewriter.getFloatAttr(elemTy, -1.0));
        angle = builder.create<arith::MulFOp>(loc, angle, negOne);
      }

      // realComponent =  valReal * cos(angle) + valImag * sin(angle)
      // imagComponent = -valReal * sin(angle) + valImag * cos(angle)
      auto cosAngle = builder.create<math::CosOp>(loc, angle);
      auto sinAngle = builder.create<math::SinOp>(loc, angle);
      auto rcos = builder.create<arith::MulFOp>(loc, valReal, cosAngle);
      auto rsin = builder.create<arith::MulFOp>(loc, valImag, sinAngle);
      auto realComponent = builder.create<arith::AddFOp>(loc, rcos, rsin);
      auto icos = builder.create<arith::MulFOp>(loc, valImag, cosAngle);
      auto isin = builder.create<arith::MulFOp>(loc, valReal, sinAngle);
      auto imagComponent = builder.create<arith::SubFOp>(loc, icos, isin);

      // Fold both components into the running accumulators (imagComponent
      // already carries its sign, so both updates are additions).
      auto outReal = builder.create<arith::AddFOp>(loc, sumReal, realComponent);
      auto outImag = builder.create<arith::AddFOp>(loc, sumImag, imagComponent);
      builder.create<linalg::YieldOp>(loc, ValueRange{outReal, outImag});
    };

    rewriter.replaceOpWithNewOp<linalg::GenericOp>(
        fft2d, fft2d.getResultTypes(), genericOpInputs, genericOpOutputs,
        indexingMaps, iteratorTypes, buildBody);
    return success();
  }
};
} // namespace
/// Populate \p patterns with all TOSA-to-Linalg lowering patterns. Resize
/// lowerings are registered with ascending benefits so that the most
/// specialized converter wins when several of them match the same op.
void mlir::tosa::populateTosaToLinalgConversionPatterns(
    const TypeConverter &converter, RewritePatternSet *patterns) {
  // We have multiple resize converters to handle degenerate cases.
  patterns->add<GenericResizeConverter>(patterns->getContext(),
                                        /*benefit=*/100);
  patterns->add<ResizeUnaryConverter>(patterns->getContext(),
                                      /*benefit=*/200);
  patterns->add<MaterializeResizeBroadcast>(patterns->getContext(),
                                            /*benefit=*/300);
  // Element-wise converters take the type converter so operand/result types
  // can be materialized during conversion.
  patterns->add<
      // clang-format off
      PointwiseConverter<tosa::AddOp>,
      PointwiseConverter<tosa::SubOp>,
      PointwiseConverter<tosa::MulOp>,
      PointwiseConverter<tosa::IntDivOp>,
      PointwiseConverter<tosa::NegateOp>,
      PointwiseConverter<tosa::PowOp>,
      PointwiseConverter<tosa::ReciprocalOp>,
      PointwiseConverter<tosa::RsqrtOp>,
      PointwiseConverter<tosa::LogOp>,
      PointwiseConverter<tosa::ExpOp>,
      PointwiseConverter<tosa::AbsOp>,
      PointwiseConverter<tosa::SinOp>,
      PointwiseConverter<tosa::CosOp>,
      PointwiseConverter<tosa::TanhOp>,
      PointwiseConverter<tosa::ErfOp>,
      PointwiseConverter<tosa::BitwiseAndOp>,
      PointwiseConverter<tosa::BitwiseOrOp>,
      PointwiseConverter<tosa::BitwiseNotOp>,
      PointwiseConverter<tosa::BitwiseXorOp>,
      PointwiseConverter<tosa::LogicalAndOp>,
      PointwiseConverter<tosa::LogicalNotOp>,
      PointwiseConverter<tosa::LogicalOrOp>,
      PointwiseConverter<tosa::LogicalXorOp>,
      PointwiseConverter<tosa::CastOp>,
      PointwiseConverter<tosa::LogicalLeftShiftOp>,
      PointwiseConverter<tosa::LogicalRightShiftOp>,
      PointwiseConverter<tosa::ArithmeticRightShiftOp>,
      PointwiseConverter<tosa::ClzOp>,
      PointwiseConverter<tosa::SelectOp>,
      PointwiseConverter<tosa::GreaterOp>,
      PointwiseConverter<tosa::GreaterEqualOp>,
      PointwiseConverter<tosa::EqualOp>,
      PointwiseConverter<tosa::MaximumOp>,
      PointwiseConverter<tosa::MinimumOp>,
      PointwiseConverter<tosa::CeilOp>,
      PointwiseConverter<tosa::FloorOp>,
      PointwiseConverter<tosa::ClampOp>,
      PointwiseConverter<tosa::SigmoidOp>
        >(converter, patterns->getContext());
  // Structured-op converters (reductions, gathers, FFTs, ...) are plain
  // rewrite patterns and do not need the type converter.
  patterns->add<
      IdentityNConverter<tosa::IdentityOp>,
      ReduceConverter<tosa::ReduceAllOp>,
      ReduceConverter<tosa::ReduceAnyOp>,
      ReduceConverter<tosa::ReduceMinOp>,
      ReduceConverter<tosa::ReduceMaxOp>,
      ReduceConverter<tosa::ReduceSumOp>,
      ReduceConverter<tosa::ReduceProductOp>,
      ArgMaxConverter,
      GatherConverter,
      RescaleConverter,
      ReverseConverter,
      RFFT2dConverter,
      FFT2dConverter,
      TableConverter,
      TileConverter>(patterns->getContext());
  // clang-format on
}