//===- TosaToLinalg.cpp - Lowering Tosa to Linalg Dialect -----------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // These rewriters lower from the Tosa to the Linalg dialect. // //===----------------------------------------------------------------------===// #include "mlir/Conversion/TosaToLinalg/TosaToLinalg.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Arith/Utils/Utils.h" #include "mlir/Dialect/Index/IR/IndexOps.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tosa/IR/TosaOps.h" #include "mlir/Dialect/Tosa/Utils/ConversionUtils.h" #include "mlir/Dialect/Utils/ReshapeOpsUtils.h" #include "mlir/Dialect/Utils/StaticValueUtils.h" #include "mlir/IR/Matchers.h" #include "mlir/IR/OpDefinition.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Transforms/DialectConversion.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/Sequence.h" #include using namespace mlir; using namespace mlir::tosa; // Helper function to materialize the semantically correct compare and select // operations given a binary operation with a specific NaN propagation mode. // // In the case of "PROPAGATE" semantics no compare and selection is required and // this function does nothing. // // In the case of "IGNORE" semantics this function materializes a comparison of // the current operands to the op which will return true for any NaN // argument and then selects between the non-NaN operation argument and the // calculated result based on whether the lhs or rhs is NaN or not. 
In pseudo // code: // // In the case that the op is operating on non floating point types we ignore // the attribute completely, this is consistent with the TOSA spec which has // the following wording: "This attribute is ignored by non floating-point // types." // // binary(lhs, rhs): // result = op(lhs, rhs) // if lhs == NaN return rhs // if rhs == NaN return lhs // return result template static Value materializeBinaryNanCheckIfRequired(OpTy op, PatternRewriter &rewriter, Value lhs, Value rhs, Value result) { // NaN propagation has no meaning for non floating point types. if (!isa(getElementTypeOrSelf(lhs))) return result; auto nanMode = op.getNanMode(); if (nanMode == "PROPAGATE") return result; // Unordered comparison of NaN against itself will always return true. Value lhsIsNaN = arith::CmpFOp::create(rewriter, op.getLoc(), arith::CmpFPredicate::UNO, lhs, lhs); Value rhsIsNaN = arith::CmpFOp::create(rewriter, op.getLoc(), arith::CmpFPredicate::UNO, rhs, rhs); Value rhsOrResult = arith::SelectOp::create(rewriter, op.getLoc(), lhsIsNaN, rhs, result); return arith::SelectOp::create(rewriter, op.getLoc(), rhsIsNaN, lhs, rhsOrResult); } static Value createLinalgBodyCalculationForElementwiseOp( Operation *op, ValueRange args, ArrayRef resultTypes, ConversionPatternRewriter &rewriter) { Location loc = op->getLoc(); auto elementTy = cast(op->getOperand(0).getType()).getElementType(); // tosa::AbsOp if (isa(op) && isa(elementTy)) return math::AbsFOp::create(rewriter, loc, resultTypes, args); if (isa(op) && isa(elementTy)) { auto zero = arith::ConstantOp::create(rewriter, loc, rewriter.getZeroAttr(elementTy)); auto neg = arith::SubIOp::create(rewriter, loc, zero, args[0]); return arith::MaxSIOp::create(rewriter, loc, args[0], neg); } // tosa::AddOp if (isa(op) && isa(elementTy)) return arith::AddFOp::create(rewriter, loc, resultTypes, args); if (isa(op) && isa(elementTy)) return arith::AddIOp::create(rewriter, loc, resultTypes, args); // tosa::SubOp if (isa(op) && 
isa(elementTy)) return arith::SubFOp::create(rewriter, loc, resultTypes, args); if (isa(op) && isa(elementTy)) return arith::SubIOp::create(rewriter, loc, resultTypes, args); // tosa::IntDivOp if (isa(op) && isa(elementTy)) return arith::DivSIOp::create(rewriter, loc, resultTypes, args); // tosa::ReciprocalOp if (isa(op) && isa(elementTy)) { auto one = arith::ConstantOp::create(rewriter, loc, FloatAttr::get(elementTy, 1)); return arith::DivFOp::create(rewriter, loc, resultTypes, one, args[0]); } // tosa::MulOp if (isa(op)) { auto shiftVal = cast(op).getShift(); DenseElementsAttr shiftElem; if (!matchPattern(shiftVal, m_Constant(&shiftElem))) { (void)rewriter.notifyMatchFailure(op, "shift value of mul not found"); return nullptr; } int32_t shift = shiftElem.getValues()[0].getInt(); if (isa(elementTy)) { if (shift != 0) { (void)rewriter.notifyMatchFailure(op, "Cannot have shift value for float"); return nullptr; } return arith::MulFOp::create(rewriter, loc, resultTypes, args[0], args[1]); } if (isa(elementTy)) { Value a = args[0]; Value b = args[1]; if (shift > 0) { auto shiftConst = arith::ConstantIntOp::create(rewriter, loc, shift, /*bitwidth=*/8); if (!a.getType().isInteger(32)) a = arith::ExtSIOp::create(rewriter, loc, rewriter.getI32Type(), a); if (!b.getType().isInteger(32)) b = arith::ExtSIOp::create(rewriter, loc, rewriter.getI32Type(), b); auto result = tosa::ApplyScaleOp::create( rewriter, loc, rewriter.getI32Type(), a, b, shiftConst, rewriter.getStringAttr("SINGLE_ROUND")); if (elementTy.isInteger(32)) return result; return arith::TruncIOp::create(rewriter, loc, elementTy, result); } int aWidth = a.getType().getIntOrFloatBitWidth(); int bWidth = b.getType().getIntOrFloatBitWidth(); int cWidth = resultTypes[0].getIntOrFloatBitWidth(); if (aWidth < cWidth) a = arith::ExtSIOp::create(rewriter, loc, resultTypes[0], a); if (bWidth < cWidth) b = arith::ExtSIOp::create(rewriter, loc, resultTypes[0], b); return arith::MulIOp::create(rewriter, loc, resultTypes, a, 
b); } } // tosa::NegateOp if (isa(op)) { auto negate = cast(op); FailureOr maybeInZp = negate.getInput1ZeroPoint(); if (failed(maybeInZp)) { (void)rewriter.notifyMatchFailure( op, "input1 zero point cannot be statically determined"); return nullptr; } FailureOr maybeOutZp = negate.getOutputZeroPoint(); if (failed(maybeOutZp)) { (void)rewriter.notifyMatchFailure( op, "output zero point cannot be statically determined"); return nullptr; } int64_t inZp = *maybeInZp; int64_t outZp = *maybeOutZp; if (isa(elementTy)) return arith::NegFOp::create(rewriter, loc, resultTypes, args[0]); if (isa(elementTy)) { if (!inZp && !outZp) { auto constant = arith::ConstantOp::create( rewriter, loc, IntegerAttr::get(elementTy, 0)); return arith::SubIOp::create(rewriter, loc, resultTypes, constant, args[0]); } // Compute the maximum value that can occur in the intermediate buffer. const int32_t inputBitWidth = elementTy.getIntOrFloatBitWidth(); const int64_t zpAdd = inZp + outZp; const int64_t maxValue = APInt::getSignedMaxValue(inputBitWidth).getSExtValue() + std::abs(zpAdd) + 1; // Convert that maximum value into the maximum bitwidth needed to // represent it. We assume 48-bit numbers may be supported further in // the pipeline. int intermediateBitWidth = 64; if (maxValue <= APInt::getSignedMaxValue(16).getSExtValue()) { intermediateBitWidth = 16; } else if (maxValue <= APInt::getSignedMaxValue(32).getSExtValue()) { intermediateBitWidth = 32; } else if (maxValue <= APInt::getSignedMaxValue(48).getSExtValue()) { intermediateBitWidth = 48; } Type intermediateType = rewriter.getIntegerType(intermediateBitWidth); Value zpAddValue = arith::ConstantOp::create( rewriter, loc, rewriter.getIntegerAttr(intermediateType, zpAdd)); // The negation can be applied by doing: // outputValue = inZp + outZp - inputValue auto ext = arith::ExtSIOp::create(rewriter, loc, intermediateType, args[0]); auto sub = arith::SubIOp::create(rewriter, loc, zpAddValue, ext); // Clamp to the negation range. 
Value min = arith::ConstantIntOp::create( rewriter, loc, intermediateType, APInt::getSignedMinValue(inputBitWidth).getSExtValue()); Value max = arith::ConstantIntOp::create( rewriter, loc, intermediateType, APInt::getSignedMaxValue(inputBitWidth).getSExtValue()); auto clamp = clampIntHelper(loc, sub, min, max, rewriter, false); // Truncate to the final value. return arith::TruncIOp::create(rewriter, loc, elementTy, clamp); } } // tosa::BitwiseAndOp if (isa(op) && isa(elementTy)) return arith::AndIOp::create(rewriter, loc, resultTypes, args); // tosa::BitwiseOrOp if (isa(op) && isa(elementTy)) return arith::OrIOp::create(rewriter, loc, resultTypes, args); // tosa::BitwiseNotOp if (isa(op) && isa(elementTy)) { auto allOnesAttr = rewriter.getIntegerAttr( elementTy, APInt::getAllOnes(elementTy.getIntOrFloatBitWidth())); auto allOnes = arith::ConstantOp::create(rewriter, loc, allOnesAttr); return arith::XOrIOp::create(rewriter, loc, resultTypes, args[0], allOnes); } // tosa::BitwiseXOrOp if (isa(op) && isa(elementTy)) return arith::XOrIOp::create(rewriter, loc, resultTypes, args); // tosa::LogicalLeftShiftOp if (isa(op) && isa(elementTy)) return arith::ShLIOp::create(rewriter, loc, resultTypes, args); // tosa::LogicalRightShiftOp if (isa(op) && isa(elementTy)) return arith::ShRUIOp::create(rewriter, loc, resultTypes, args); // tosa::ArithmeticRightShiftOp if (isa(op) && isa(elementTy)) { auto result = arith::ShRSIOp::create(rewriter, loc, resultTypes, args); auto round = cast(op->getAttr("round")).getValue(); if (!round) { return result; } Type i1Ty = IntegerType::get(rewriter.getContext(), /*width=*/1); auto one = arith::ConstantOp::create(rewriter, loc, IntegerAttr::get(elementTy, 1)); auto zero = arith::ConstantOp::create(rewriter, loc, IntegerAttr::get(elementTy, 0)); auto i1one = arith::ConstantOp::create(rewriter, loc, IntegerAttr::get(i1Ty, 1)); // Checking that input2 != 0 auto shiftValueGreaterThanZero = arith::CmpIOp::create( rewriter, loc, 
arith::CmpIPredicate::sgt, args[1], zero); // Checking for the last bit of input1 to be 1 auto subtract = arith::SubIOp::create(rewriter, loc, resultTypes, args[1], one); auto shifted = arith::ShRSIOp::create(rewriter, loc, resultTypes, args[0], subtract) ->getResults(); auto truncated = arith::TruncIOp::create(rewriter, loc, i1Ty, shifted, ArrayRef()); auto isInputOdd = arith::AndIOp::create(rewriter, loc, i1Ty, truncated, i1one); auto shouldRound = arith::AndIOp::create( rewriter, loc, i1Ty, shiftValueGreaterThanZero, isInputOdd); auto extended = arith::ExtUIOp::create(rewriter, loc, resultTypes, shouldRound); return arith::AddIOp::create(rewriter, loc, resultTypes, result, extended); } // tosa::ClzOp if (isa(op) && isa(elementTy)) { return math::CountLeadingZerosOp::create(rewriter, loc, elementTy, args[0]); } // tosa::LogicalAnd if (isa(op) && elementTy.isInteger(1)) return arith::AndIOp::create(rewriter, loc, resultTypes, args); // tosa::LogicalNot if (isa(op) && elementTy.isInteger(1)) { auto one = arith::ConstantOp::create(rewriter, loc, rewriter.getIntegerAttr(elementTy, 1)); return arith::XOrIOp::create(rewriter, loc, resultTypes, args[0], one); } // tosa::LogicalOr if (isa(op) && elementTy.isInteger(1)) return arith::OrIOp::create(rewriter, loc, resultTypes, args); // tosa::LogicalXor if (isa(op) && elementTy.isInteger(1)) return arith::XOrIOp::create(rewriter, loc, resultTypes, args); // tosa::PowOp if (isa(op) && isa(elementTy)) return mlir::math::PowFOp::create(rewriter, loc, resultTypes, args); // tosa::RsqrtOp if (isa(op) && isa(elementTy)) return mlir::math::RsqrtOp::create(rewriter, loc, resultTypes, args); // tosa::LogOp if (isa(op) && isa(elementTy)) return mlir::math::LogOp::create(rewriter, loc, resultTypes, args); // tosa::ExpOp if (isa(op) && isa(elementTy)) return mlir::math::ExpOp::create(rewriter, loc, resultTypes, args); // tosa::SinOp if (isa(op) && isa(elementTy)) return mlir::math::SinOp::create(rewriter, loc, resultTypes, args); // 
tosa::CosOp if (isa(op) && isa(elementTy)) return mlir::math::CosOp::create(rewriter, loc, resultTypes, args); // tosa::TanhOp if (isa(op) && isa(elementTy)) return mlir::math::TanhOp::create(rewriter, loc, resultTypes, args); // tosa::ErfOp if (isa(op) && llvm::isa(elementTy)) return mlir::math::ErfOp::create(rewriter, loc, resultTypes, args); // tosa::GreaterOp if (isa(op) && isa(elementTy)) return arith::CmpFOp::create(rewriter, loc, arith::CmpFPredicate::OGT, args[0], args[1]); if (isa(op) && elementTy.isSignlessInteger()) return arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::sgt, args[0], args[1]); // tosa::GreaterEqualOp if (isa(op) && isa(elementTy)) return arith::CmpFOp::create(rewriter, loc, arith::CmpFPredicate::OGE, args[0], args[1]); if (isa(op) && elementTy.isSignlessInteger()) return arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::sge, args[0], args[1]); // tosa::EqualOp if (isa(op) && isa(elementTy)) return arith::CmpFOp::create(rewriter, loc, arith::CmpFPredicate::OEQ, args[0], args[1]); if (isa(op) && elementTy.isSignlessInteger()) return arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::eq, args[0], args[1]); // tosa::SelectOp if (isa(op)) { elementTy = cast(op->getOperand(1).getType()).getElementType(); if (isa(elementTy) || isa(elementTy)) return arith::SelectOp::create(rewriter, loc, args[0], args[1], args[2]); } // tosa::MaximumOp if (isa(op) && isa(elementTy)) { auto max = arith::MaximumFOp::create(rewriter, loc, args[0], args[1]); return materializeBinaryNanCheckIfRequired(llvm::cast(op), rewriter, args[0], args[1], max); } if (isa(op) && elementTy.isSignlessInteger()) { return arith::MaxSIOp::create(rewriter, loc, args[0], args[1]); } // tosa::MinimumOp if (isa(op) && isa(elementTy)) { auto min = arith::MinimumFOp::create(rewriter, loc, args[0], args[1]); return materializeBinaryNanCheckIfRequired(llvm::cast(op), rewriter, args[0], args[1], min); } if (isa(op) && elementTy.isSignlessInteger()) { return 
arith::MinSIOp::create(rewriter, loc, args[0], args[1]); } // tosa::CeilOp if (isa(op) && isa(elementTy)) return math::CeilOp::create(rewriter, loc, resultTypes, args); // tosa::FloorOp if (isa(op) && isa(elementTy)) return math::FloorOp::create(rewriter, loc, resultTypes, args); // tosa::ClampOp if (isa(op) && isa(elementTy)) { bool losesInfo = false; APFloat minApf = cast(op->getAttr("min_val")).getValue(); APFloat maxApf = cast(op->getAttr("max_val")).getValue(); minApf.convert(cast(elementTy).getFloatSemantics(), APFloat::rmNearestTiesToEven, &losesInfo); maxApf.convert(cast(elementTy).getFloatSemantics(), APFloat::rmNearestTiesToEven, &losesInfo); auto min = arith::ConstantOp::create( rewriter, loc, elementTy, rewriter.getFloatAttr(elementTy, minApf)); auto max = arith::ConstantOp::create( rewriter, loc, elementTy, rewriter.getFloatAttr(elementTy, maxApf)); auto result = clampFloatHelper(loc, args[0], min, max, rewriter); auto clampOp = llvm::cast(op); const auto nanMode = clampOp.getNanMode(); // NaN propagation has no meaning for non floating point types. if (!isa(elementTy)) return result; // In the case of "PROPAGATE" semantics no compare and selection is // required. if (nanMode == "PROPAGATE") return result; // In the case of "IGNORE" semantics materialize a comparison // of the current operand to the reduction which will return true for a NaN // argument and then selects between the initial reduction value and the // calculated result based on whether the argument is NaN or not. In pseudo // code: // // reduce(x, init): // result = op(init, x) // return init if x == NaN else result // Unordered comparison of NaN against itself will always return true. Value isNaN = arith::CmpFOp::create( rewriter, op->getLoc(), arith::CmpFPredicate::UNO, args[0], args[0]); // TOSA specifies that in "ignore" NaN mode the result is "min" if the input // is NaN. 
return arith::SelectOp::create(rewriter, op->getLoc(), isNaN, min, result); } if (isa(op) && isa(elementTy)) { auto intTy = cast(elementTy); int64_t min = cast(op->getAttr("min_val")).getValue().getSExtValue(); int64_t max = cast(op->getAttr("max_val")).getValue().getSExtValue(); int64_t minRepresentable = std::numeric_limits::min(); int64_t maxRepresentable = std::numeric_limits::max(); if (intTy.isUnsignedInteger()) { minRepresentable = 0; if (intTy.getIntOrFloatBitWidth() <= 63) { maxRepresentable = (int64_t)APInt::getMaxValue(intTy.getIntOrFloatBitWidth()) .getZExtValue(); } } else if (intTy.getIntOrFloatBitWidth() <= 64) { // Ensure that min & max fit into signed n-bit constants. minRepresentable = APInt::getSignedMinValue(intTy.getIntOrFloatBitWidth()) .getSExtValue(); maxRepresentable = APInt::getSignedMaxValue(intTy.getIntOrFloatBitWidth()) .getSExtValue(); } // Ensure that the bounds are representable as n-bit signed/unsigned // integers. min = std::max(min, minRepresentable); max = std::max(max, minRepresentable); min = std::min(min, maxRepresentable); max = std::min(max, maxRepresentable); auto minVal = arith::ConstantIntOp::create(rewriter, loc, min, intTy.getIntOrFloatBitWidth()); auto maxVal = arith::ConstantIntOp::create(rewriter, loc, max, intTy.getIntOrFloatBitWidth()); return clampIntHelper(loc, args[0], minVal, maxVal, rewriter, intTy.isUnsignedInteger()); } // tosa::SigmoidOp if (isa(op) && isa(elementTy)) { auto one = arith::ConstantOp::create(rewriter, loc, FloatAttr::get(elementTy, 1)); auto negate = arith::NegFOp::create(rewriter, loc, resultTypes, args[0]); auto exp = mlir::math::ExpOp::create(rewriter, loc, resultTypes, negate); auto added = arith::AddFOp::create(rewriter, loc, resultTypes, exp, one); return arith::DivFOp::create(rewriter, loc, resultTypes, one, added); } // tosa::CastOp if (isa(op)) { Type srcTy = elementTy; Type dstTy = resultTypes.front(); if (!srcTy.isIntOrFloat() || !dstTy.isIntOrFloat()) { 
(void)rewriter.notifyMatchFailure(op, "unsupported type"); return nullptr; } bool bitExtend = srcTy.getIntOrFloatBitWidth() < dstTy.getIntOrFloatBitWidth(); if (srcTy == dstTy) return args.front(); if (isa(srcTy) && isa(dstTy) && bitExtend) return arith::ExtFOp::create(rewriter, loc, resultTypes, args, ArrayRef()); if (isa(srcTy) && isa(dstTy) && !bitExtend) return arith::TruncFOp::create(rewriter, loc, resultTypes, args, ArrayRef()); // 1-bit integers need to be treated as signless. if (srcTy.isInteger(1) && arith::UIToFPOp::areCastCompatible(srcTy, dstTy)) return arith::UIToFPOp::create(rewriter, loc, resultTypes, args, ArrayRef()); if (srcTy.isInteger(1) && isa(dstTy) && bitExtend) return arith::ExtUIOp::create(rewriter, loc, resultTypes, args, ArrayRef()); // Unsigned integers need an unrealized cast so that they can be passed // to UIToFP. if (srcTy.isUnsignedInteger() && isa(dstTy)) { auto unrealizedCast = rewriter .create( loc, rewriter.getIntegerType(srcTy.getIntOrFloatBitWidth()), args[0]) .getResult(0); return arith::UIToFPOp::create(rewriter, loc, resultTypes[0], unrealizedCast); } // All other si-to-fp conversions should be handled by SIToFP. if (arith::SIToFPOp::areCastCompatible(srcTy, dstTy)) return arith::SIToFPOp::create(rewriter, loc, resultTypes, args, ArrayRef()); // Casting to boolean, floats need to only be checked as not-equal to zero. if (isa(srcTy) && dstTy.isInteger(1)) { Value zero = arith::ConstantOp::create(rewriter, loc, rewriter.getFloatAttr(srcTy, 0.0)); return arith::CmpFOp::create(rewriter, loc, arith::CmpFPredicate::UNE, args.front(), zero); } if (arith::FPToSIOp::areCastCompatible(srcTy, dstTy)) { auto rounded = math::RoundEvenOp::create(rewriter, loc, args[0]); const auto &fltSemantics = cast(srcTy).getFloatSemantics(); // Check whether neither int min nor int max can be represented in the // input floating-point type due to too short exponent range. 
if (static_cast(dstTy.getIntOrFloatBitWidth()) - 1 > APFloat::semanticsMaxExponent(fltSemantics)) { // Use cmp + select to replace infinites by int min / int max. Other // integral values can be represented in the integer space. auto conv = arith::FPToSIOp::create(rewriter, loc, dstTy, rounded); auto posInf = arith::ConstantOp::create( rewriter, loc, rewriter.getFloatAttr(getElementTypeOrSelf(srcTy), APFloat::getInf(fltSemantics))); auto negInf = arith::ConstantOp::create( rewriter, loc, rewriter.getFloatAttr( getElementTypeOrSelf(srcTy), APFloat::getInf(fltSemantics, /*Negative=*/true))); auto overflow = arith::CmpFOp::create( rewriter, loc, arith::CmpFPredicate::UEQ, rounded, posInf); auto underflow = arith::CmpFOp::create( rewriter, loc, arith::CmpFPredicate::UEQ, rounded, negInf); auto intMin = arith::ConstantOp::create( rewriter, loc, rewriter.getIntegerAttr( getElementTypeOrSelf(dstTy), APInt::getSignedMinValue(dstTy.getIntOrFloatBitWidth()))); auto intMax = arith::ConstantOp::create( rewriter, loc, rewriter.getIntegerAttr( getElementTypeOrSelf(dstTy), APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth()))); auto maxClamped = arith::SelectOp::create(rewriter, loc, overflow, intMax, conv); return arith::SelectOp::create(rewriter, loc, underflow, intMin, maxClamped); } auto intMinFP = arith::ConstantOp::create( rewriter, loc, rewriter.getFloatAttr( getElementTypeOrSelf(srcTy), APInt::getSignedMinValue(dstTy.getIntOrFloatBitWidth()) .getSExtValue())); // Check whether the mantissa has enough bits to represent int max. if (cast(srcTy).getFPMantissaWidth() >= dstTy.getIntOrFloatBitWidth() - 1) { // Int min can also be represented since it is a power of two and thus // consists of a single leading bit. Therefore we can clamp the input // in the floating-point domain. 
auto intMaxFP = arith::ConstantOp::create( rewriter, loc, rewriter.getFloatAttr( getElementTypeOrSelf(srcTy), APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth()) .getSExtValue())); Value clamped = clampFloatHelper(loc, rounded, intMinFP, intMaxFP, rewriter); return arith::FPToSIOp::create(rewriter, loc, dstTy, clamped); } // Due to earlier check we know exponant range is big enough to represent // int min. We can therefore rely on int max + 1 being representable as // well because it's just int min with a positive sign. So clamp the min // value and compare against that to select the max int value if needed. auto intMaxPlusOneFP = arith::ConstantOp::create( rewriter, loc, rewriter.getFloatAttr( getElementTypeOrSelf(srcTy), static_cast( APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth()) .getSExtValue()) + 1.0f)); auto intMax = arith::ConstantOp::create( rewriter, loc, rewriter.getIntegerAttr( getElementTypeOrSelf(dstTy), APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth()))); auto minClampedFP = arith::MaximumFOp::create(rewriter, loc, rounded, intMinFP); auto minClamped = arith::FPToSIOp::create(rewriter, loc, dstTy, minClampedFP); auto overflow = arith::CmpFOp::create( rewriter, loc, arith::CmpFPredicate::UGE, rounded, intMaxPlusOneFP); return arith::SelectOp::create(rewriter, loc, overflow, intMax, minClamped); } // Casting to boolean, integers need to only be checked as not-equal to // zero. 
if (isa(srcTy) && dstTy.isInteger(1)) {
  Value zero = arith::ConstantIntOp::create(rewriter, loc, 0,
                                            srcTy.getIntOrFloatBitWidth());
  return arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::ne,
                               args.front(), zero);
}

if (isa(srcTy) && isa(dstTy) && bitExtend)
  return arith::ExtSIOp::create(rewriter, loc, resultTypes, args,
                                ArrayRef());

if (isa(srcTy) && isa(dstTy) && !bitExtend) {
  return arith::TruncIOp::create(rewriter, loc, dstTy, args[0]);
}
}

(void)rewriter.notifyMatchFailure(
    op, "unhandled op for linalg body calculation for elementwise op");
return nullptr;
}

// Cache of 'index'-typed 'arith.constant' values keyed by their int64_t value;
// populated and consulted by createIndex().
using IndexPool = DenseMap;

// Emit an 'arith.constant' op for the given index if it has not been created
// yet, or return an existing constant. This will prevent an excessive creation
// of redundant constants, easing readability of emitted code for unit tests.
//
// A cached constant is only valid where it dominates its uses; code emitting
// into a nested region uses a fresh pool (see the 'then' region built in
// broadcastDynamicDimension).
static Value createIndex(PatternRewriter &rewriter, Location loc,
                         IndexPool &indexPool, int64_t index) {
  auto [it, inserted] = indexPool.try_emplace(index);
  if (inserted)
    it->second =
        arith::ConstantOp::create(rewriter, loc, rewriter.getIndexAttr(index));
  return it->second;
}

// Emit a 'tensor.dim' op reading dimension 'index' of 'tensor'. The index
// operand is obtained from 'indexPool' through createIndex().
static Value getTensorDim(PatternRewriter &rewriter, Location loc,
                          IndexPool &indexPool, Value tensor, int64_t index) {
  auto indexValue = createIndex(rewriter, loc, indexPool, index);
  return tensor::DimOp::create(rewriter, loc, tensor, indexValue).getResult();
}

// Return the extent of dimension 'index' of the ranked 'tensor': an index
// attribute when the extent is static, otherwise a 'tensor.dim' value emitted
// via getTensorDim(). Both preconditions are asserted.
static OpFoldResult getOrFoldTensorDim(PatternRewriter &rewriter, Location loc,
                                       IndexPool &indexPool, Value tensor,
                                       int64_t index) {
  auto shapedType = dyn_cast(tensor.getType());
  assert(shapedType && shapedType.hasRank() && "expected a ranked shaped type");
  assert(index >= 0 && index < shapedType.getRank() && "index out of bounds");
  if (shapedType.isDynamicDim(index))
    return getTensorDim(rewriter, loc, indexPool, tensor, index);
  return rewriter.getIndexAttr(shapedType.getDimSize(index));
}

// Return true when the 'isRanked' type check holds for every operand and
// every result of 'operation' (presumably "is a ranked tensor"; the checked
// type is not visible in this listing).
static bool operandsAndResultsRanked(Operation *operation) {
  auto isRanked = [](Value value) {
    return isa(value.getType());
  };
  return llvm::all_of(operation->getOperands(), isRanked) &&
         llvm::all_of(operation->getResults(), isRanked);
}

// Compute the runtime dimension size for dimension 'dim' of the output by
// inspecting input 'operands', all of which are expected to have the same rank.
// This function returns a pair {targetSize, masterOperand}.
//
// The runtime size of the output dimension is returned either as a statically
// computed attribute or as a runtime SSA value.
//
// If the target size was inferred directly from one dominating operand, that
// operand is returned in 'masterOperand'. If the target size is inferred from
// multiple operands, 'masterOperand' is set to nullptr.
//
// Resolution order:
//   1. the first operand with a static extent > 1 along 'dim' decides;
//   2. if no operand is dynamic along 'dim', the extent is taken to be 1 and
//      the first operand is reported as master;
//   3. a single dynamic operand supplies the extent through 'tensor.dim';
//   4. otherwise the unsigned maximum ('arith.maxui') of all dynamic extents
//      is emitted and no master operand is reported.
// NOTE(review): step 2 does not distinguish static extents of 0 from 1.
static std::pair
computeTargetSize(PatternRewriter &rewriter, Location loc, IndexPool &indexPool,
                  ValueRange operands, int64_t dim) {
  // If any input operand contains a static size greater than 1 for this
  // dimension, that is the target size. An occurrence of an additional static
  // dimension greater than 1 with a different value is undefined behavior.
  for (auto operand : operands) {
    auto size = cast(operand.getType()).getDimSize(dim);
    if (ShapedType::isStatic(size) && size > 1)
      return {rewriter.getIndexAttr(size), operand};
  }

  // Filter operands with dynamic dimension
  auto operandsWithDynamicDim =
      llvm::filter_to_vector(operands, [&](Value operand) {
        return cast(operand.getType()).isDynamicDim(dim);
      });

  // If no operand has a dynamic dimension, it means all sizes were 1
  if (operandsWithDynamicDim.empty())
    return {rewriter.getIndexAttr(1), operands.front()};

  // Emit code that computes the runtime size for this dimension. If there is
  // only one operand with a dynamic dimension, it is considered the master
  // operand that determines the runtime size of the output dimension.
  auto targetSize =
      getTensorDim(rewriter, loc, indexPool, operandsWithDynamicDim[0], dim);
  if (operandsWithDynamicDim.size() == 1)
    return {targetSize, operandsWithDynamicDim[0]};

  // Calculate maximum size among all dynamic dimensions
  for (size_t i = 1; i < operandsWithDynamicDim.size(); i++) {
    auto nextSize =
        getTensorDim(rewriter, loc, indexPool, operandsWithDynamicDim[i], dim);
    targetSize = arith::MaxUIOp::create(rewriter, loc, targetSize, nextSize);
  }
  return {targetSize, nullptr};
}

// Compute the runtime output size for all dimensions. This function returns
// a pair {targetShape, masterOperands}.
//
// Both vectors have one entry per dimension of the first operand's rank;
// entry 'd' is the result of computeTargetSize(..., d).
static std::pair, SmallVector>
computeTargetShape(PatternRewriter &rewriter, Location loc,
                   IndexPool &indexPool, ValueRange operands) {
  assert(!operands.empty());
  auto rank = cast(operands.front().getType()).getRank();
  SmallVector targetShape;
  SmallVector masterOperands;
  for (auto dim : llvm::seq(0, rank)) {
    auto [targetSize, masterOperand] =
        computeTargetSize(rewriter, loc, indexPool, operands, dim);
    targetShape.push_back(targetSize);
    masterOperands.push_back(masterOperand);
  }
  return {targetShape, masterOperands};
}

// Broadcast 'operand' along dimension 'dim' to 'targetSize' when needed.
//
// Returns 'operand' unchanged when 'dim' is static, or when 'operand' is the
// master operand 'targetSize' was derived from. Otherwise emits an 'scf.if'
// on "runtime extent of 'dim' == 1":
//   - 'then': a 'linalg.generic' reading index 0 along 'dim' into a
//     'tensor.empty' whose 'dim' extent is 'targetSize' (other extents taken
//     from 'operand'), followed by a cast back to the operand's type;
//   - 'else': yields 'operand' as-is.
// The 'scf.if' result is returned.
// NOTE(review): a runtime extent other than 1 is assumed to already equal
// 'targetSize'; that is not checked here.
static Value broadcastDynamicDimension(PatternRewriter &rewriter, Location loc,
                                       IndexPool &indexPool, Value operand,
                                       int64_t dim, OpFoldResult targetSize,
                                       Value masterOperand) {
  // Nothing to do if this is a static dimension
  auto rankedTensorType = cast(operand.getType());
  if (!rankedTensorType.isDynamicDim(dim))
    return operand;

  // If the target size for this dimension was directly inferred by only taking
  // this operand into account, there is no need to broadcast. This is an
  // optimization that will prevent redundant control flow, and constitutes the
  // main motivation for tracking "master operands".
  if (operand == masterOperand)
    return operand;

  // Affine maps for 'linalg.generic' op
  auto rank = rankedTensorType.getRank();
  SmallVector affineExprs;
  for (auto index : llvm::seq(0, rank)) {
    // The broadcast dimension always reads element 0 of the input.
    auto affineExpr = index == dim ? rewriter.getAffineConstantExpr(0)
                                   : rewriter.getAffineDimExpr(index);
    affineExprs.push_back(affineExpr);
  }
  auto broadcastAffineMap =
      AffineMap::get(rank, 0, affineExprs, rewriter.getContext());
  auto identityAffineMap = rewriter.getMultiDimIdentityMap(rank);
  SmallVector affineMaps = {broadcastAffineMap, identityAffineMap};

  // Check if broadcast is necessary
  auto one = createIndex(rewriter, loc, indexPool, 1);
  auto runtimeSize = getTensorDim(rewriter, loc, indexPool, operand, dim);
  auto broadcastNecessary = arith::CmpIOp::create(
      rewriter, loc, arith::CmpIPredicate::eq, runtimeSize, one);

  // Emit 'then' region of 'scf.if'
  auto emitThenRegion = [&](OpBuilder &opBuilder, Location loc) {
    // It is not safe to cache constants across regions.
    // New constants could potentially violate dominance requirements.
    IndexPool localPool;

    // Emit 'tensor.empty' op
    SmallVector outputTensorShape;
    for (auto index : llvm::seq(0, rank)) {
      auto size = index == dim ? targetSize
                               : getOrFoldTensorDim(rewriter, loc, localPool,
                                                    operand, index);
      outputTensorShape.push_back(size);
    }
    Value outputTensor = tensor::EmptyOp::create(
        opBuilder, loc, outputTensorShape, rankedTensorType.getElementType());

    // Emit 'linalg.generic' op
    auto resultTensor =
        opBuilder
            .create(
                loc, outputTensor.getType(), operand, outputTensor, affineMaps,
                getNParallelLoopsAttrs(rank),
                [&](OpBuilder &opBuilder, Location loc, ValueRange blockArgs) {
                  // Emit 'linalg.yield' op
                  linalg::YieldOp::create(opBuilder, loc, blockArgs.front());
                })
            .getResult(0);

    // Cast to original operand type if necessary
    auto castResultTensor = rewriter.createOrFold(
        loc, operand.getType(), resultTensor);

    // Emit 'scf.yield' op
    scf::YieldOp::create(opBuilder, loc, castResultTensor);
  };

  // Emit 'else' region of 'scf.if'
  auto emitElseRegion = [&](OpBuilder &opBuilder, Location loc) {
    scf::YieldOp::create(opBuilder, loc, operand);
  };

  // Emit 'scf.if' op
  auto ifOp = scf::IfOp::create(rewriter, loc, broadcastNecessary,
                                emitThenRegion, emitElseRegion);
  return ifOp.getResult(0);
}

// Apply broadcastDynamicDimension() to every dimension of 'operand' in order,
// threading the possibly-replaced value through. 'targetShape' and
// 'masterOperands' must hold one entry per dimension (asserted).
static Value broadcastDynamicDimensions(PatternRewriter &rewriter, Location loc,
                                        IndexPool &indexPool, Value operand,
                                        ArrayRef targetShape,
                                        ArrayRef masterOperands) {
  int64_t rank = cast(operand.getType()).getRank();
  assert((int64_t)targetShape.size() == rank);
  assert((int64_t)masterOperands.size() == rank);
  for (auto index : llvm::seq(0, rank))
    operand =
        broadcastDynamicDimension(rewriter, loc, indexPool, operand, index,
                                  targetShape[index], masterOperands[index]);
  return operand;
}

// Broadcast each of 'operands' against the shared target shape. A single
// operand is returned unchanged since there is nothing to broadcast against.
static SmallVector
broadcastDynamicDimensions(PatternRewriter &rewriter, Location loc,
                           IndexPool &indexPool, ValueRange operands,
                           ArrayRef targetShape,
                           ArrayRef masterOperands) {
  // No need to broadcast for unary operations
  if (operands.size() == 1)
    return operands;

  // Broadcast dynamic dimensions operand by operand
  return llvm::map_to_vector(operands, [&](Value operand) {
    return broadcastDynamicDimensions(rewriter, loc, indexPool, operand,
                                      targetShape, masterOperands);
  });
}

// Emit the 'linalg.generic' computing 'operation' elementwise over
// 'operands' into a 'tensor.empty' of shape 'targetShape', then replace
// 'operation' with the result (cast to the converted result type if needed).
//
// Fails when the result type cannot be converted, or when
// createLinalgBodyCalculationForElementwiseOp() yields no value.
static LogicalResult
emitElementwiseComputation(ConversionPatternRewriter &rewriter, Location loc,
                           Operation *operation, ValueRange operands,
                           ArrayRef targetShape,
                           const TypeConverter &converter) {
  // Generate output tensor
  auto resultType = cast_or_null(
      converter.convertType(operation->getResultTypes().front()));
  if (!resultType) {
    return rewriter.notifyMatchFailure(operation, "failed to convert type");
  }
  Value outputTensor = tensor::EmptyOp::create(rewriter, loc, targetShape,
                                               resultType.getElementType());

  // Create affine maps. Input affine maps broadcast static dimensions of size
  // 1. The output affine map is an identity map.
  //
  auto rank = resultType.getRank();
  auto affineMaps = llvm::map_to_vector(operands, [&](Value operand) {
    auto shape = cast(operand.getType()).getShape();
    SmallVector affineExprs;
    for (auto it : llvm::enumerate(shape)) {
      // Prefer producing identity maps whenever possible (i.e. no
      // broadcasting needed) because some transforms (like reshape folding)
      // do not support affine constant exprs.
      bool requiresBroadcast =
          (it.value() == 1 && resultType.getDimSize(it.index()) != 1);
      auto affineExpr = requiresBroadcast
                            ? rewriter.getAffineConstantExpr(0)
                            : rewriter.getAffineDimExpr(it.index());
      affineExprs.push_back(affineExpr);
    }
    return AffineMap::get(rank, 0, affineExprs, rewriter.getContext());
  });
  affineMaps.push_back(rewriter.getMultiDimIdentityMap(rank));

  // Emit 'linalg.generic' op
  bool encounteredError = false;
  auto linalgOp = linalg::GenericOp::create(
      rewriter, loc, outputTensor.getType(), operands, outputTensor, affineMaps,
      getNParallelLoopsAttrs(rank),
      [&](OpBuilder &opBuilder, Location loc, ValueRange blockArgs) {
        // NOTE(review): 'operands' may be only a prefix of the op's operands
        // (see getBroadcastableOperands), in which case this range also
        // includes the output block argument; presumably the body builder
        // does not read it as an input - confirm.
        Value opResult = createLinalgBodyCalculationForElementwiseOp(
            operation, blockArgs.take_front(operation->getNumOperands()),
            {resultType.getElementType()}, rewriter);
        if (!opResult) {
          encounteredError = true;
          return;
        }
        linalg::YieldOp::create(opBuilder, loc, opResult);
      });
  if (encounteredError)
    return rewriter.notifyMatchFailure(
        operation, "unable to create linalg.generic body for elementwise op");

  // Cast 'linalg.generic' result into original result type if needed
  auto castResult = rewriter.createOrFold(
      loc, resultType, linalgOp->getResult(0));
  rewriter.replaceOp(operation, castResult);
  return success();
}

// Return the prefix of 'operands' that takes part in broadcasting. For the two
// op kinds special-cased below, the trailing operands named in the comments
// are excluded and therefore do not become 'linalg.generic' inputs.
static ValueRange getBroadcastableOperands(Operation *operation,
                                           ValueRange operands) {
  // Shift cannot broadcast
  if (isa(operation))
    return operands.take_front(2);
  // Input1_zp and output_zp cannot broadcast
  if (isa(operation))
    return operands.take_front(1);
  return operands;
}

// Shared entry point for the elementwise patterns: verifies ranked types,
// computes the broadcast target shape, inserts runtime broadcasts for dynamic
// dimensions, and emits the elementwise 'linalg.generic'.
static LogicalResult
elementwiseMatchAndRewriteHelper(Operation *operation, ValueRange operands,
                                 ConversionPatternRewriter &rewriter,
                                 const TypeConverter &converter) {

  // Collect op properties
  assert(operation->getNumResults() == 1 && "elementwise op expects 1 result");
  assert(operation->getNumOperands() >= 1 &&
         "elementwise op expects at least 1 operand");
  if (!operandsAndResultsRanked(operation))
    return rewriter.notifyMatchFailure(operation,
                                       "Unranked tensors not supported");

  // Lower operation
  IndexPool indexPool;
  auto loc = operation->getLoc();
  auto operandsToBroadcast = getBroadcastableOperands(operation, operands);
  auto [targetShape, masterOperands] =
      computeTargetShape(rewriter, loc, indexPool, operandsToBroadcast);
  auto broadcastOperands =
      broadcastDynamicDimensions(rewriter, loc, indexPool, operandsToBroadcast,
                                 targetShape, masterOperands);
  return emitElementwiseComputation(rewriter, loc, operation, broadcastOperands,
                                    targetShape, converter);
}

// Returns the constant initial value for a given reduction operation. The
// attribute type varies depending on the element type required.
//
// Float min/max-style initial values use +/-APFloat::getLargest (finite), not
// infinity. Returns a null attribute when (op, elementTy) is not handled.
static TypedAttr createInitialValueForReduceOp(Operation *op, Type elementTy,
                                               PatternRewriter &rewriter) {
  if (isa(op) && isa(elementTy))
    return rewriter.getFloatAttr(elementTy, 0.0);

  if (isa(op) && isa(elementTy))
    return rewriter.getIntegerAttr(elementTy, 0);

  if (isa(op) && isa(elementTy))
    return rewriter.getFloatAttr(elementTy, 1.0);

  if (isa(op) && isa(elementTy))
    return rewriter.getIntegerAttr(elementTy, 1);

  if (isa(op) && isa(elementTy))
    return rewriter.getFloatAttr(
        elementTy, APFloat::getLargest(
                       cast(elementTy).getFloatSemantics(), false));

  if (isa(op) && isa(elementTy))
    return rewriter.getIntegerAttr(
        elementTy, APInt::getSignedMaxValue(elementTy.getIntOrFloatBitWidth()));

  if (isa(op) && isa(elementTy))
    return rewriter.getFloatAttr(
        elementTy, APFloat::getLargest(
                       cast(elementTy).getFloatSemantics(), true));

  if (isa(op) && isa(elementTy))
    return rewriter.getIntegerAttr(
        elementTy, APInt::getSignedMinValue(elementTy.getIntOrFloatBitWidth()));

  if (isa(op) && elementTy.isInteger(1))
    return rewriter.getIntegerAttr(elementTy, APInt::getAllOnes(1));

  if (isa(op) && elementTy.isInteger(1))
    return rewriter.getIntegerAttr(elementTy, APInt::getZero(1));

  if (isa(op) && isa(elementTy))
    return rewriter.getFloatAttr(
        elementTy, APFloat::getLargest(
                       cast(elementTy).getFloatSemantics(), true));

  if (isa(op) && isa(elementTy))
    return rewriter.getIntegerAttr(
        elementTy, APInt::getSignedMinValue(elementTy.getIntOrFloatBitWidth()));

  return {};
}

// Creates the body calculation for a reduction. The operations vary depending
// on the input type.
//
// 'args' holds the two values combined per step (the input element and the
// running accumulator, as passed by reduceMatchAndRewriteHelper). Float
// min/max use arith.minimumf/maximumf; integer min/max are signed. Returns a
// null Value for unhandled (op, elementTy) combinations.
static Value createLinalgBodyCalculationForReduceOp(Operation *op,
                                                    ValueRange args,
                                                    Type elementTy,
                                                    PatternRewriter &rewriter) {
  Location loc = op->getLoc();
  if (isa(op) && isa(elementTy)) {
    return arith::AddFOp::create(rewriter, loc, args);
  }

  if (isa(op) && isa(elementTy)) {
    return arith::AddIOp::create(rewriter, loc, args);
  }

  if (isa(op) && isa(elementTy)) {
    return arith::MulFOp::create(rewriter, loc, args);
  }

  if (isa(op) && isa(elementTy)) {
    return arith::MulIOp::create(rewriter, loc, args);
  }

  if (isa(op) && isa(elementTy)) {
    return arith::MinimumFOp::create(rewriter, loc, args[0], args[1]);
  }

  if (isa(op) && isa(elementTy)) {
    return arith::MinSIOp::create(rewriter, loc, args[0], args[1]);
  }

  if (isa(op) && isa(elementTy)) {
    return arith::MaximumFOp::create(rewriter, loc, args[0], args[1]);
  }

  if (isa(op) && isa(elementTy)) {
    return arith::MaxSIOp::create(rewriter, loc, args[0], args[1]);
  }

  if (isa(op) && elementTy.isInteger(1))
    return arith::AndIOp::create(rewriter, loc, args);

  if (isa(op) && elementTy.isInteger(1))
    return arith::OrIOp::create(rewriter, loc, args);

  return {};
}

// Performs the match and rewrite for reduction operations. This includes
// declaring a correctly sized initial value, and the linalg.reduce operation
// that reduces across the specified axis.
template static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis, PatternRewriter &rewriter) { auto loc = op->getLoc(); auto inputTy = dyn_cast(op->getOperand(0).getType()); auto resultTy = dyn_cast(op->getResult(0).getType()); if (!inputTy || !resultTy) return rewriter.notifyMatchFailure(op, "unranked tensors not supported"); auto elementTy = resultTy.getElementType(); Value input = op->getOperand(0); SmallVector reduceShape; SmallVector dynDims; for (unsigned i = 0; i < inputTy.getRank(); i++) { if (axis != i) { reduceShape.push_back(inputTy.getDimSize(i)); if (inputTy.isDynamicDim(i)) dynDims.push_back(tensor::DimOp::create(rewriter, loc, input, i)); } } SmallVector inputs, outputs; inputs.push_back(input); // First fill the output buffer with the init value. auto emptyTensor = rewriter .create(loc, reduceShape, resultTy.getElementType(), dynDims) .getResult(); auto fillValueAttr = createInitialValueForReduceOp(op, elementTy, rewriter); if (!fillValueAttr) return rewriter.notifyMatchFailure( op, "No initial value found for reduction operation"); auto fillValue = arith::ConstantOp::create(rewriter, loc, fillValueAttr); auto filledTensor = rewriter .create(loc, ValueRange{fillValue}, ValueRange{emptyTensor}) .result(); outputs.push_back(filledTensor); bool isNanIgnoreMode = false; if constexpr (std::is_same_v || std::is_same_v) { // NaN propagation has no meaning for non floating point types. if (isa(elementTy) && op.getNanMode() == "IGNORE") { isNanIgnoreMode = true; // Because the TOSA spec requires the result be NaN iff all elements in // the reduction are NaN we can't simply perform a compare and select. // Additionally we have to keep track of whether we've seen any non-NaN // values and then do a final select based on this predicate. 
auto trueAttr = rewriter.getBoolAttr(true); auto trueValue = arith::ConstantOp::create(rewriter, loc, trueAttr); auto emptyBoolTensor = rewriter .create(loc, reduceShape, trueValue.getType(), dynDims) .getResult(); auto allResultsNaNTensor = rewriter .create(loc, ValueRange{trueValue}, ValueRange{emptyBoolTensor}) .result(); // Note that because the linalg::ReduceOp has two variadic arguments // (inputs and outputs) and it has the SameVariadicOperandSize trait we // need to have the same number of inputs and outputs. // // The second input isn't actually used anywhere since the value used to // update the NaN flag is calculated inside the body of the reduction and // then used to update an out value. // In order to satisfy type constraints we just pass another copy of the // input here. inputs.push_back(input); outputs.push_back(allResultsNaNTensor); } } bool didEncounterError = false; linalg::LinalgOp linalgOp = linalg::ReduceOp::create( rewriter, loc, inputs, outputs, axis, [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange blockArgs) { std::array binaryArgs{ blockArgs[0], isNanIgnoreMode ? blockArgs[2] : blockArgs[1]}; auto result = createLinalgBodyCalculationForReduceOp( op, binaryArgs, elementTy, rewriter); if (result) didEncounterError = true; SmallVector resultsToYield; if (isNanIgnoreMode) { auto inputValue = blockArgs[0]; auto initialValue = blockArgs[2]; auto oldAllResultsNanFlagValue = blockArgs[3]; // Unordered comparison of NaN against itself will always return true. Value isNaN = arith::CmpFOp::create(nestedBuilder, op->getLoc(), arith::CmpFPredicate::UNO, inputValue, inputValue); // If we've encountered a NaN, take the non-NaN value. auto selectOp = arith::SelectOp::create(nestedBuilder, op->getLoc(), isNaN, initialValue, result); // Update the flag which keeps track of whether we have seen a non-NaN // value. 
auto newAllResultsNanFlagValue = arith::AndIOp::create( nestedBuilder, op->getLoc(), oldAllResultsNanFlagValue, isNaN); resultsToYield.push_back(selectOp); resultsToYield.push_back(newAllResultsNanFlagValue); } else { resultsToYield.push_back(result); } linalg::YieldOp::create(nestedBuilder, loc, resultsToYield); }); if (!didEncounterError) return rewriter.notifyMatchFailure( op, "unable to create linalg.generic body for reduce op"); if (isNanIgnoreMode) { // Materialize a check to see whether we encountered any non-NaN values, if // we didn't we need to select a tensor of NaNs since the result will just // be the initial identity value propagated through all the compares and // selects inside the reduction. // Create a tensor full of NaNs. auto nanValueAttr = rewriter.getFloatAttr( elementTy, APFloat::getNaN(cast(elementTy).getFloatSemantics(), false)); auto nanValue = arith::ConstantOp::create(rewriter, loc, nanValueAttr); auto emptyNanTensor = rewriter .create(loc, reduceShape, resultTy.getElementType(), dynDims) .getResult(); auto nanFilledTensor = rewriter .create(loc, ValueRange{nanValue}, ValueRange{emptyNanTensor}) .result(); // Create an empty tensor, non need to fill this since it will be // overwritten by the select. auto finalEmptyTensor = rewriter .create(loc, reduceShape, resultTy.getElementType(), dynDims) .getResult(); // Do a selection between the tensors akin to: // result = NaN if "all results NaN" else result. SmallVector ins, outs; ins.push_back(linalgOp->getOpResult(1)); ins.push_back(nanFilledTensor); ins.push_back(linalgOp->getResult(0)); outs.push_back(finalEmptyTensor); auto linalgSelect = linalg::SelectOp::create(rewriter, op->getLoc(), ins, outs); linalgOp = linalgSelect; } SmallVector reassociationMap; uint64_t expandInputRank = cast(linalgOp->getResults()[0].getType()).getRank(); reassociationMap.resize(expandInputRank); for (uint64_t i = 0; i < expandInputRank; i++) { int32_t dimToPush = i > axis ? 
i + 1 : i; reassociationMap[i].push_back(rewriter.getAffineDimExpr(dimToPush)); } if (expandInputRank != 0) { int32_t expandedDim = axis < expandInputRank ? axis : expandInputRank - 1; reassociationMap[expandedDim].push_back( rewriter.getAffineDimExpr(expandedDim + 1)); } // Lower directly to `tensor::ExpandShapeOp` instead of `tosa::ReshapeOp`, // since here we know which dimension to expand, and `tosa::ReshapeOp` would // not have access to such information. This matters when handling dynamically // sized tensors. rewriter.replaceOpWithNewOp( op, resultTy, linalgOp->getResults()[0], reassociationMap); return success(); } namespace { template class PointwiseConverter : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; using typename OpConversionPattern::OpAdaptor; LogicalResult matchAndRewrite(SrcOp op, OpAdaptor operands, ConversionPatternRewriter &rewriter) const final { return elementwiseMatchAndRewriteHelper( op, operands.getOperands(), rewriter, *this->getTypeConverter()); } }; class RescaleConverter : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(tosa::RescaleOp op, PatternRewriter &rewriter) const final { auto loc = op.getLoc(); auto input = op.getInput(); auto inputTy = cast(op.getInput().getType()); auto outputTy = cast(op.getOutput().getType()); unsigned rank = inputTy.getRank(); // This is an illegal configuration. 
terminate and log an error if (op.getRoundingMode() == "INEXACT_ROUND") return rewriter.notifyMatchFailure( op, "tosa.rescale with rounding mode = 'INEXACT_ROUND' is not " "currently supported"); if (op.getRoundingMode() == "DOUBLE_ROUND" && !op.getScale32()) return rewriter.notifyMatchFailure( op, "tosa.rescale requires scale32 for double_round to be true"); if (!isa(inputTy.getElementType())) return rewriter.notifyMatchFailure(op, "only support integer type"); SmallVector dynDims; for (int i = 0; i < outputTy.getRank(); i++) { if (outputTy.isDynamicDim(i)) { dynDims.push_back(tensor::DimOp::create(rewriter, loc, input, i)); } } // The shift and multiplier values. DenseElementsAttr shiftElems; if (!matchPattern(op.getShift(), m_Constant(&shiftElems))) return rewriter.notifyMatchFailure( op, "tosa.rescale requires constant shift input values"); DenseElementsAttr multiplierElems; if (!matchPattern(op.getMultiplier(), m_Constant(&multiplierElems))) return rewriter.notifyMatchFailure( op, "tosa.rescale requires constant multiplier input values"); llvm::SmallVector shiftValues = llvm::to_vector(shiftElems.getValues()); // explicit cast is required here llvm::SmallVector multiplierValues = llvm::to_vector( llvm::map_range(multiplierElems.getValues(), [](IntegerAttr attr) -> int32_t { return static_cast(attr.getInt()); })); // If we shift by more than the bitwidth, this just sets to 0. for (int i = 0, s = multiplierValues.size(); i < s; i++) { if (shiftValues[i] > 63) { shiftValues[i] = 0; multiplierValues[i] = 0; } } // Double round only occurs if shift is greater than 31, check that this // is ever true. bool doubleRound = op.getRoundingMode() == "DOUBLE_ROUND" && llvm::any_of(shiftValues, [](int32_t v) { return v > 31; }); StringAttr roundingMode = doubleRound ? 
rewriter.getStringAttr("DOUBLE_ROUND") : rewriter.getStringAttr("SINGLE_ROUND"); SmallVector indexingMaps = { rewriter.getMultiDimIdentityMap(rank)}; SmallVector genericInputs = {input}; // If we are rescaling per-channel then we need to store the multiplier // values in a buffer. Value multiplierConstant; int64_t multiplierArg = 0; if (multiplierValues.size() == 1) { multiplierConstant = arith::ConstantOp::create( rewriter, loc, rewriter.getI32IntegerAttr(multiplierValues.front())); } else { SmallVector multiplierExprs{ rewriter.getAffineDimExpr(rank - 1)}; auto multiplierType = RankedTensorType::get({static_cast(multiplierValues.size())}, rewriter.getI32Type()); genericInputs.push_back(arith::ConstantOp::create( rewriter, loc, DenseIntElementsAttr::get(multiplierType, multiplierValues))); indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank, /*symbolCount=*/0, multiplierExprs, rewriter.getContext())); multiplierArg = indexingMaps.size() - 1; } // If we are rescaling per-channel then we need to store the shift // values in a buffer. Value shiftConstant; int64_t shiftArg = 0; if (shiftValues.size() == 1) { shiftConstant = arith::ConstantOp::create( rewriter, loc, rewriter.getI8IntegerAttr(shiftValues.front())); } else { SmallVector shiftExprs = { rewriter.getAffineDimExpr(rank - 1)}; auto shiftType = RankedTensorType::get({static_cast(shiftValues.size())}, rewriter.getIntegerType(8)); genericInputs.push_back(arith::ConstantOp::create( rewriter, loc, DenseIntElementsAttr::get(shiftType, shiftValues))); indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank, /*symbolCount=*/0, shiftExprs, rewriter.getContext())); shiftArg = indexingMaps.size() - 1; } // Indexing maps for output values. indexingMaps.push_back(rewriter.getMultiDimIdentityMap(rank)); // Construct the indexing maps needed for linalg.generic ops. 
Value emptyTensor = tensor::EmptyOp::create( rewriter, loc, outputTy.getShape(), outputTy.getElementType(), ArrayRef({dynDims})); auto linalgOp = linalg::GenericOp::create( rewriter, loc, outputTy, genericInputs, ValueRange{emptyTensor}, indexingMaps, getNParallelLoopsAttrs(rank), [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange blockArgs) { Value value = blockArgs[0]; Type valueTy = value.getType(); FailureOr maybeIZp = op.getInputZeroPoint(); if (failed(maybeIZp)) { (void)rewriter.notifyMatchFailure( op, "input zero point cannot be statically determined"); return; } const int32_t inBitwidth = valueTy.getIntOrFloatBitWidth(); // Extend zeropoint for sub-32bits widths. const int32_t inAttrBitwidth = inBitwidth > 32 ? inBitwidth : 32; auto inputZp = arith::ConstantOp::create( nestedBuilder, loc, IntegerAttr::get(rewriter.getIntegerType(inAttrBitwidth), *maybeIZp)); FailureOr maybeOZp = op.getOutputZeroPoint(); if (failed(maybeOZp)) { (void)rewriter.notifyMatchFailure( op, "output zero point cannot be statically determined"); return; }; IntegerType outIntType = cast(blockArgs.back().getType()); unsigned outBitWidth = outIntType.getWidth(); const int32_t outAttrBitwidth = 32; assert(outBitWidth <= 32 && "Unexpected output zeropoint bitwidth"); auto outputZp = arith::ConstantOp::create( nestedBuilder, loc, IntegerAttr::get(rewriter.getIntegerType(outAttrBitwidth), *maybeOZp)); Value multiplier = multiplierConstant ? multiplierConstant : blockArgs[multiplierArg]; Value shift = shiftConstant ? 
shiftConstant : blockArgs[shiftArg]; if (valueTy.isUnsignedInteger()) { value = nestedBuilder .create( nestedLoc, nestedBuilder.getIntegerType( valueTy.getIntOrFloatBitWidth()), value) .getResult(0); } if (valueTy.getIntOrFloatBitWidth() < 32) { if (op.getInputUnsigned()) { value = arith::ExtUIOp::create(nestedBuilder, nestedLoc, nestedBuilder.getI32Type(), value); } else { value = arith::ExtSIOp::create(nestedBuilder, nestedLoc, nestedBuilder.getI32Type(), value); } } value = arith::SubIOp::create(nestedBuilder, nestedLoc, value, inputZp); value = tosa::ApplyScaleOp::create(nestedBuilder, loc, nestedBuilder.getI32Type(), value, multiplier, shift, roundingMode); // Move to the new zero-point. value = arith::AddIOp::create(nestedBuilder, nestedLoc, value, outputZp); // Saturate to the output size. int32_t intMin = APInt::getSignedMinValue(outBitWidth).getSExtValue(); int32_t intMax = APInt::getSignedMaxValue(outBitWidth).getSExtValue(); // Unsigned integers have a difference output value. if (op.getOutputUnsigned()) { intMin = 0; intMax = APInt::getMaxValue(outBitWidth).getZExtValue(); } auto intMinVal = arith::ConstantOp::create( nestedBuilder, loc, nestedBuilder.getI32IntegerAttr(intMin)); auto intMaxVal = arith::ConstantOp::create( nestedBuilder, loc, nestedBuilder.getI32IntegerAttr(intMax)); value = clampIntHelper(nestedLoc, value, intMinVal, intMaxVal, nestedBuilder, /*isUnsigned=*/false); if (outIntType.getWidth() < 32) { value = arith::TruncIOp::create( nestedBuilder, nestedLoc, rewriter.getIntegerType(outIntType.getWidth()), value); } if (outIntType.isUnsignedInteger()) { value = nestedBuilder .create(nestedLoc, outIntType, value) .getResult(0); } linalg::YieldOp::create(nestedBuilder, loc, value); }); rewriter.replaceOp(op, linalgOp->getResults()); return success(); } }; // Handle the resize case where the input is a 1x1 image. This case // can entirely avoiding having extract operations which target much // more difficult to optimize away. 
class ResizeUnaryConverter : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(tosa::ResizeOp op, PatternRewriter &rewriter) const final { Location loc = op.getLoc(); ImplicitLocOpBuilder builder(loc, rewriter); auto input = op.getInput(); auto inputTy = cast(input.getType()); auto resultTy = cast(op.getType()); const bool isBilinear = op.getMode() == "BILINEAR"; auto inputH = inputTy.getDimSize(1); auto inputW = inputTy.getDimSize(2); auto outputH = resultTy.getDimSize(1); auto outputW = resultTy.getDimSize(2); if (inputH != 1 || inputW != 1 || outputH != 1 || outputW != 1) return rewriter.notifyMatchFailure( op, "tosa.resize is not a pure 1x1->1x1 image operation"); // TODO(suderman): These string values should be declared the TOSA dialect. if (op.getMode() != "NEAREST_NEIGHBOR" && op.getMode() != "BILINEAR") return rewriter.notifyMatchFailure( op, "tosa.resize mode should be NEAREST_NEIGHBOR or BILINEAR"); if (inputTy == resultTy) { rewriter.replaceOp(op, input); return success(); } SmallVector scale; if (!tosa::getConstShapeValues(op.getScale().getDefiningOp(), scale)) { return failure(); } // Collapse the unit width and height away. SmallVector reassociationMap(2); reassociationMap[0].push_back(builder.getAffineDimExpr(0)); reassociationMap[1].push_back(builder.getAffineDimExpr(1)); reassociationMap[1].push_back(builder.getAffineDimExpr(2)); reassociationMap[1].push_back(builder.getAffineDimExpr(3)); auto collapseTy = RankedTensorType::get({inputTy.getDimSize(0), inputTy.getDimSize(3)}, inputTy.getElementType()); Value collapse = tensor::CollapseShapeOp::create(builder, collapseTy, input, reassociationMap); // Get any dynamic shapes that appear in the input format. 
llvm::SmallVector outputDynSize; if (inputTy.isDynamicDim(0)) outputDynSize.push_back(tensor::DimOp::create(builder, input, 0)); if (inputTy.isDynamicDim(3)) outputDynSize.push_back(tensor::DimOp::create(builder, input, 3)); // Generate the elementwise operation for casting scaling the input value. auto genericTy = collapseTy.clone(resultTy.getElementType()); Value empty = tensor::EmptyOp::create(builder, genericTy.getShape(), resultTy.getElementType(), outputDynSize); auto genericMap = rewriter.getMultiDimIdentityMap(genericTy.getRank()); SmallVector iterators(genericTy.getRank(), utils::IteratorType::parallel); auto generic = linalg::GenericOp::create( builder, genericTy, ValueRange{collapse}, ValueRange{empty}, ArrayRef{genericMap, genericMap}, iterators, [=](OpBuilder &b, Location loc, ValueRange args) { Value value = args[0]; // This is the quantized case. if (inputTy.getElementType() != resultTy.getElementType()) { value = arith::ExtSIOp::create(b, loc, resultTy.getElementType(), value); if (isBilinear && scale[0] != 0) { Value scaleY = arith::ConstantOp::create( b, loc, b.getI32IntegerAttr(scale[0])); value = arith::MulIOp::create(b, loc, value, scaleY); } if (isBilinear && scale[2] != 0) { Value scaleX = arith::ConstantOp::create( b, loc, b.getI32IntegerAttr(scale[2])); value = arith::MulIOp::create(b, loc, value, scaleX); } } linalg::YieldOp::create(b, loc, value); }); rewriter.replaceOpWithNewOp( op, resultTy, generic.getResults()[0], reassociationMap); return success(); } }; // TOSA resize with width or height of 1 may be broadcasted to a wider // dimension. This is done by materializing a new tosa.resize without // the broadcasting behavior, and an explicit broadcast afterwards. 
class MaterializeResizeBroadcast : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(tosa::ResizeOp op, PatternRewriter &rewriter) const final { Location loc = op.getLoc(); ImplicitLocOpBuilder builder(loc, rewriter); auto input = op.getInput(); auto inputTy = dyn_cast(input.getType()); auto resultTy = dyn_cast(op.getType()); if (!inputTy || !resultTy) return rewriter.notifyMatchFailure(op, "requires ranked input/output types"); auto batch = inputTy.getDimSize(0); auto channels = inputTy.getDimSize(3); auto inputH = inputTy.getDimSize(1); auto inputW = inputTy.getDimSize(2); auto outputH = resultTy.getDimSize(1); auto outputW = resultTy.getDimSize(2); if ((inputH != 1 || outputH == 1) && (inputW != 1 || outputW == 1)) return rewriter.notifyMatchFailure( op, "tosa.resize has no broadcasting behavior"); // For any dimension that is broadcastable we generate a width of 1 // on the output. llvm::SmallVector resizeShape; resizeShape.push_back(batch); resizeShape.push_back(inputH == 1 ? 1 : outputH); resizeShape.push_back(inputW == 1 ? 1 : outputW); resizeShape.push_back(channels); auto resizeTy = resultTy.clone(resizeShape); auto resize = tosa::ResizeOp::create(builder, resizeTy, input, op.getScale(), op.getOffset(), op.getBorder(), op.getMode()); // Collapse an unit result dims. 
SmallVector reassociationMap(2); reassociationMap[0].push_back(builder.getAffineDimExpr(0)); reassociationMap.back().push_back(builder.getAffineDimExpr(1)); if (inputH != 1) reassociationMap.push_back({}); reassociationMap.back().push_back(builder.getAffineDimExpr(2)); if (inputW != 1) reassociationMap.push_back({}); reassociationMap.back().push_back(builder.getAffineDimExpr(3)); llvm::SmallVector collapseShape = {batch}; if (inputH != 1) collapseShape.push_back(outputH); if (inputW != 1) collapseShape.push_back(outputW); collapseShape.push_back(channels); auto collapseTy = resultTy.clone(collapseShape); Value collapse = tensor::CollapseShapeOp::create(builder, collapseTy, resize, reassociationMap); // Broadcast the collapsed shape to the output result. llvm::SmallVector outputDynSize; if (inputTy.isDynamicDim(0)) outputDynSize.push_back(tensor::DimOp::create(builder, input, 0)); if (inputTy.isDynamicDim(3)) outputDynSize.push_back(tensor::DimOp::create(builder, input, 3)); SmallVector iterators(resultTy.getRank(), utils::IteratorType::parallel); Value empty = tensor::EmptyOp::create( builder, resultTy.getShape(), resultTy.getElementType(), outputDynSize); SmallVector inputExprs{rewriter.getAffineDimExpr(0)}; if (inputH != 1) inputExprs.push_back(rewriter.getAffineDimExpr(1)); if (inputW != 1) inputExprs.push_back(rewriter.getAffineDimExpr(2)); inputExprs.push_back(rewriter.getAffineDimExpr(3)); auto inputMap = AffineMap::get(resultTy.getRank(), /*symbolCount=*/0, inputExprs, rewriter.getContext()); auto outputMap = rewriter.getMultiDimIdentityMap(resultTy.getRank()); rewriter.replaceOpWithNewOp( op, resultTy, ValueRange{collapse}, ValueRange{empty}, ArrayRef{inputMap, outputMap}, iterators, [=](OpBuilder &b, Location loc, ValueRange args) { Value value = args[0]; linalg::YieldOp::create(b, loc, value); }); return success(); } }; class GenericResizeConverter : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult 
matchAndRewrite(tosa::ResizeOp op, PatternRewriter &rewriter) const final { Location loc = op.getLoc(); ImplicitLocOpBuilder b(loc, rewriter); auto input = op.getInput(); auto inputTy = cast(input.getType()); auto resultTy = cast(op.getType()); auto resultETy = resultTy.getElementType(); bool floatingPointMode = resultETy.isF16() || resultETy.isF32(); auto floatTy = resultETy.isF16() ? b.getF16Type() : b.getF32Type(); auto imageH = inputTy.getShape()[1]; auto imageW = inputTy.getShape()[2]; auto dynamicDimsOr = checkHasDynamicBatchDims(rewriter, op, {input, op.getOutput()}); if (!dynamicDimsOr.has_value()) return rewriter.notifyMatchFailure( op, "unable to get dynamic dimensions of tosa.resize"); if (op.getMode() != "NEAREST_NEIGHBOR" && op.getMode() != "BILINEAR") return rewriter.notifyMatchFailure( op, "tosa.resize mode should be NEAREST_NEIGHBOR or BILINEAR"); SmallVector affineMaps = { rewriter.getMultiDimIdentityMap(resultTy.getRank())}; auto emptyTensor = tensor::EmptyOp::create(b, resultTy.getShape(), resultETy, *dynamicDimsOr); auto genericOp = linalg::GenericOp::create( b, resultTy, ValueRange({}), ValueRange{emptyTensor}, affineMaps, getNParallelLoopsAttrs(resultTy.getRank())); Value resize = genericOp.getResult(0); { OpBuilder::InsertionGuard regionGuard(b); b.createBlock(&genericOp.getRegion(), genericOp.getRegion().end(), TypeRange({resultETy}), loc); Value batch = linalg::IndexOp::create(b, 0); Value y = linalg::IndexOp::create(b, 1); Value x = linalg::IndexOp::create(b, 2); Value channel = linalg::IndexOp::create(b, 3); Value zeroI32 = arith::ConstantOp::create(b, b.getZeroAttr(b.getI32Type())); Value zeroFp = arith::ConstantOp::create(b, b.getZeroAttr(floatTy)); Value hMax = arith::ConstantOp::create(b, b.getI32IntegerAttr(imageH - 1)); Value wMax = arith::ConstantOp::create(b, b.getI32IntegerAttr(imageW - 1)); Value inY = arith::IndexCastOp::create(b, b.getI32Type(), y); Value inX = arith::IndexCastOp::create(b, b.getI32Type(), x); SmallVector 
scale, offset, border; if (!tosa::getConstShapeValues(op.getScale().getDefiningOp(), scale) || !tosa::getConstShapeValues(op.getOffset().getDefiningOp(), offset) || !tosa::getConstShapeValues(op.getBorder().getDefiningOp(), border)) { return rewriter.notifyMatchFailure( op, "tosa.resize scale/offset/border should have compile time " "constant values."); } Value yScaleN, yScaleD, xScaleN, xScaleD; yScaleN = arith::ConstantOp::create(b, b.getI32IntegerAttr(scale[0])); yScaleD = arith::ConstantOp::create(b, b.getI32IntegerAttr(scale[1])); xScaleN = arith::ConstantOp::create(b, b.getI32IntegerAttr(scale[2])); xScaleD = arith::ConstantOp::create(b, b.getI32IntegerAttr(scale[3])); Value yOffset, xOffset, yBorder, xBorder; yOffset = arith::ConstantOp::create(b, b.getI32IntegerAttr(offset[0])); xOffset = arith::ConstantOp::create(b, b.getI32IntegerAttr(offset[1])); yBorder = arith::ConstantOp::create(b, b.getI32IntegerAttr(border[0])); xBorder = arith::ConstantOp::create(b, b.getI32IntegerAttr(border[1])); // Compute the ix and dx values for both the X and Y dimensions. auto getIndexAndDeltaFp = [&](Value &index, Value &delta, Value in, Value scaleN, Value scaleD, Value offset, int size, ImplicitLocOpBuilder &b) { if (size == 1) { index = zeroI32; delta = zeroFp; return; } // x = x * scale_d + offset; // ix = floor(x / scale_n) Value val = arith::MulIOp::create(b, in, scaleD); val = arith::AddIOp::create(b, val, offset); index = arith::FloorDivSIOp::create(b, val, scaleN); // rx = x % scale_n // dx = rx / scale_n Value r = arith::RemSIOp::create(b, val, scaleN); Value rFp = arith::SIToFPOp::create(b, floatTy, r); Value scaleNfp = arith::UIToFPOp::create(b, floatTy, scaleN); delta = arith::DivFOp::create(b, rFp, scaleNfp); }; // Compute the ix and dx values for the X and Y dimensions - int case. 
auto getIndexAndDeltaInt = [&](Value &index, Value &delta, Value in, Value scaleN, Value scaleD, Value offset, int size, ImplicitLocOpBuilder &b) { if (size == 1) { index = zeroI32; delta = zeroI32; return; } // x = x * scale_d + offset; // ix = floor(x / scale_n) // dx = x - ix * scale_n; Value val = arith::MulIOp::create(b, in, scaleD); val = arith::AddIOp::create(b, val, offset); index = arith::DivSIOp::create(b, val, scaleN); delta = arith::MulIOp::create(b, index, scaleN); delta = arith::SubIOp::create(b, val, delta); }; Value ix, iy, dx, dy; if (floatingPointMode) { getIndexAndDeltaFp(iy, dy, inY, yScaleN, yScaleD, yOffset, imageH, b); getIndexAndDeltaFp(ix, dx, inX, xScaleN, xScaleD, xOffset, imageW, b); } else { getIndexAndDeltaInt(iy, dy, inY, yScaleN, yScaleD, yOffset, imageH, b); getIndexAndDeltaInt(ix, dx, inX, xScaleN, xScaleD, xOffset, imageW, b); } if (op.getMode() == "NEAREST_NEIGHBOR") { auto one = arith::ConstantOp::create(b, b.getI32IntegerAttr(1)); auto getNearestIndexAndClamp = [&](Value val, Value dval, Value scale, Value max, int size, ImplicitLocOpBuilder &b) -> Value { if (size == 1) { return arith::ConstantIndexOp::create(b, 0); } Value pred; if (floatingPointMode) { auto h = arith::ConstantOp::create(b, b.getFloatAttr(floatTy, 0.5f)); pred = arith::CmpFOp::create(b, arith::CmpFPredicate::OGE, dval, h); } else { Value dvalDouble = arith::ShLIOp::create(b, dval, one); pred = arith::CmpIOp::create(b, arith::CmpIPredicate::sge, dvalDouble, scale); } auto offset = arith::SelectOp::create(b, pred, one, zeroI32); val = arith::AddIOp::create(b, val, offset); val = clampIntHelper(loc, val, zeroI32, max, b, /*isUnsigned=*/false); return arith::IndexCastOp::create(b, b.getIndexType(), val); }; iy = getNearestIndexAndClamp(iy, dy, yScaleN, hMax, imageH, b); ix = getNearestIndexAndClamp(ix, dx, xScaleN, wMax, imageW, b); Value result = tensor::ExtractOp::create( b, input, ValueRange{batch, iy, ix, channel}); linalg::YieldOp::create(b, result); } else 
{ // The mode here must be BILINEAR. assert(op.getMode() == "BILINEAR"); auto oneVal = arith::ConstantOp::create(b, b.getI32IntegerAttr(1)); auto getClampedIdxs = [&](Value &val0, Value &val1, int size, Value in, Value max, ImplicitLocOpBuilder &b) { val0 = in; val1 = arith::AddIOp::create(b, val0, oneVal); val0 = clampIntHelper(loc, val0, zeroI32, max, b, /*isUnsigned=*/false); val1 = clampIntHelper(loc, val1, zeroI32, max, b, /*isUnsigned=*/false); val0 = arith::IndexCastOp::create(b, b.getIndexType(), val0); val1 = arith::IndexCastOp::create(b, b.getIndexType(), val1); }; // Linalg equivalent to the section below: // int16_t iy0 = apply_max(iy, 0); // int16_t iy1 = apply_min(iy + 1, IH - 1); // int16_t ix0 = apply_max(ix, 0); // int16_t ix1 = apply_min(ix + 1, IW - 1); Value x0, x1, y0, y1; getClampedIdxs(y0, y1, imageH, iy, hMax, b); getClampedIdxs(x0, x1, imageW, ix, wMax, b); Value y0x0 = tensor::ExtractOp::create( b, input, ValueRange{batch, y0, x0, channel}); Value y0x1 = tensor::ExtractOp::create( b, input, ValueRange{batch, y0, x1, channel}); Value y1x0 = tensor::ExtractOp::create( b, input, ValueRange{batch, y1, x0, channel}); Value y1x1 = tensor::ExtractOp::create( b, input, ValueRange{batch, y1, x1, channel}); if (floatingPointMode) { auto oneVal = arith::ConstantOp::create(b, b.getFloatAttr(floatTy, 1.0f)); auto interpolate = [&](Value val0, Value val1, Value delta, int inputSize, ImplicitLocOpBuilder &b) -> Value { if (inputSize == 1) return val0; Value oneMinusDelta = arith::SubFOp::create(b, oneVal, delta); Value mul0 = arith::MulFOp::create(b, val0, oneMinusDelta); Value mul1 = arith::MulFOp::create(b, val1, delta); return arith::AddFOp::create(b, mul0, mul1); }; // Linalg equivalent to the section below: // topAcc = v00 * (unit_x - dx); // topAcc += v01 * dx; Value topAcc = interpolate(y0x0, y0x1, dx, imageW, b); // Linalg equivalent to the section below: // bottomAcc = v10 * (unit_x - dx); // bottomAcc += v11 * dx; Value bottomAcc = 
interpolate(y1x0, y1x1, dx, imageW, b); // Linalg equivalent to the section below: // result = topAcc * (unit_y - dy) + bottomAcc * dy Value result = interpolate(topAcc, bottomAcc, dy, imageH, b); linalg::YieldOp::create(b, result); } else { // Perform in quantized space. y0x0 = arith::ExtSIOp::create(b, resultETy, y0x0); y0x1 = arith::ExtSIOp::create(b, resultETy, y0x1); y1x0 = arith::ExtSIOp::create(b, resultETy, y1x0); y1x1 = arith::ExtSIOp::create(b, resultETy, y1x1); const int64_t deltaBitwidth = dx.getType().getIntOrFloatBitWidth(); if (resultETy.getIntOrFloatBitWidth() > deltaBitwidth) { dx = arith::ExtSIOp::create(b, resultETy, dx); dy = arith::ExtSIOp::create(b, resultETy, dy); } Value yScaleNExt = yScaleN; Value xScaleNExt = xScaleN; const int64_t scaleBitwidth = xScaleN.getType().getIntOrFloatBitWidth(); if (resultETy.getIntOrFloatBitWidth() > scaleBitwidth) { yScaleNExt = arith::ExtSIOp::create(b, resultETy, yScaleN); xScaleNExt = arith::ExtSIOp::create(b, resultETy, xScaleN); } auto interpolate = [](Value val0, Value val1, Value weight1, Value scale, int inputSize, ImplicitLocOpBuilder &b) -> Value { if (inputSize == 1) return arith::MulIOp::create(b, val0, scale); Value weight0 = arith::SubIOp::create(b, scale, weight1); Value mul0 = arith::MulIOp::create(b, val0, weight0); Value mul1 = arith::MulIOp::create(b, val1, weight1); return arith::AddIOp::create(b, mul0, mul1); }; Value topAcc = interpolate(y0x0, y0x1, dx, xScaleNExt, imageW, b); Value bottomAcc = interpolate(y1x0, y1x1, dx, xScaleNExt, imageW, b); Value result = interpolate(topAcc, bottomAcc, dy, yScaleNExt, imageH, b); linalg::YieldOp::create(b, result); } } } rewriter.replaceOp(op, resize); return success(); } }; // At the codegen level any identity operations should be removed. Any cases // where identity is load-bearing (e.g. cross device computation) should be // handled before lowering to codegen. 
template class IdentityNConverter : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(SrcOp op, PatternRewriter &rewriter) const final { rewriter.replaceOp(op, op.getOperation()->getOperands()); return success(); } }; template class ReduceConverter : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(SrcOp reduceOp, PatternRewriter &rewriter) const final { return reduceMatchAndRewriteHelper(reduceOp, reduceOp.getAxis(), rewriter); } }; class ReverseConverter : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(tosa::ReverseOp op, PatternRewriter &rewriter) const final { auto loc = op.getLoc(); Value input = op.getInput1(); auto inputTy = cast(input.getType()); auto resultTy = cast(op.getType()); auto axis = op.getAxis(); SmallVector dynDims; for (int i = 0; i < inputTy.getRank(); i++) { if (inputTy.isDynamicDim(i)) { dynDims.push_back(tensor::DimOp::create(rewriter, loc, input, i)); } } Value axisDimSize = tensor::DimOp::create(rewriter, loc, input, axis); // First fill the output buffer with the init value. 
auto emptyTensor = rewriter .create(loc, inputTy.getShape(), inputTy.getElementType(), ArrayRef({dynDims})) .getResult(); SmallVector affineMaps = { rewriter.getMultiDimIdentityMap(resultTy.getRank())}; rewriter.replaceOpWithNewOp( op, resultTy, ArrayRef({}), ValueRange{emptyTensor}, affineMaps, getNParallelLoopsAttrs(resultTy.getRank()), [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) { llvm::SmallVector indices; for (unsigned int i = 0; i < inputTy.getRank(); i++) { Value index = linalg::IndexOp::create(rewriter, nestedLoc, i).getResult(); if (i == axis) { auto one = arith::ConstantIndexOp::create(rewriter, nestedLoc, 1); auto sizeMinusOne = arith::SubIOp::create(rewriter, nestedLoc, axisDimSize, one); index = arith::SubIOp::create(rewriter, nestedLoc, sizeMinusOne, index); } indices.push_back(index); } auto extract = tensor::ExtractOp::create(nestedBuilder, nestedLoc, input, indices); linalg::YieldOp::create(nestedBuilder, op.getLoc(), extract.getResult()); }); return success(); } }; // This converter translate a tile operation to a reshape, broadcast, reshape. // The first reshape minimally expands each tiled dimension to include a // proceding size-1 dim. This dim is then broadcasted to the appropriate // multiple. struct TileConverter : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(tosa::TileOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto loc = op.getLoc(); auto input = op.getInput1(); auto inputTy = cast(input.getType()); auto inputShape = inputTy.getShape(); auto resultTy = cast(op.getType()); auto elementTy = inputTy.getElementType(); int64_t rank = inputTy.getRank(); SmallVector multiples; if (failed(op.getConstantMultiples(multiples))) return failure(); // Broadcast the newly added dimensions to their appropriate multiple. 
SmallVector genericShape; for (int i = 0; i < rank; i++) { int64_t dim = multiples[i]; genericShape.push_back(dim == -1 ? ShapedType::kDynamic : dim); genericShape.push_back(inputShape[i]); } SmallVector dynDims; for (int i = 0; i < inputTy.getRank(); i++) { if (inputTy.isDynamicDim(i) || multiples[i] == -1) { dynDims.push_back(tensor::DimOp::create(rewriter, loc, input, i)); } } auto emptyTensor = tensor::EmptyOp::create( rewriter, op.getLoc(), genericShape, elementTy, dynDims); // We needs to map the input shape to the non-broadcasted dimensions. SmallVector dimExprs; dimExprs.reserve(rank); for (unsigned i = 0; i < rank; ++i) dimExprs.push_back(rewriter.getAffineDimExpr(i * 2 + 1)); auto readAffineMap = AffineMap::get(/*dimCount=*/rank * 2, /*symbolCount=*/0, dimExprs, rewriter.getContext()); SmallVector affineMaps = { readAffineMap, rewriter.getMultiDimIdentityMap(genericShape.size())}; auto genericOp = linalg::GenericOp::create( rewriter, loc, RankedTensorType::get(genericShape, elementTy), input, ValueRange{emptyTensor}, affineMaps, getNParallelLoopsAttrs(genericShape.size()), [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) { linalg::YieldOp::create(nestedBuilder, op.getLoc(), *args.begin()); }); auto shapeValue = getTosaConstShape( rewriter, loc, mlir::tosa::convertFromMlirShape(resultTy.getShape())); rewriter.replaceOpWithNewOp( op, resultTy, genericOp.getResult(0), shapeValue); return success(); } }; // Tosa argmax lowering represents the ArgMax op as an linalg.indexed_generic // op, producing two output buffers. // // The first output buffer contains the index of the found maximum value. It is // initialized to 0 and is resulting integer type. // // The second output buffer contains the maximum value found. It is initialized // to the minimum representable value of the input element type. After being // populated by indexed_generic, this buffer is disgarded as only the index is // requested. 
// // The indexed_generic op updates both the maximum value and index if the // current value exceeds the running max. class ArgMaxConverter : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(tosa::ArgMaxOp argmaxOp, PatternRewriter &rewriter) const final { auto loc = argmaxOp.getLoc(); Value input = argmaxOp.getInput(); auto inputTy = cast(input.getType()); auto resultTy = cast(argmaxOp.getOutput().getType()); auto inElementTy = inputTy.getElementType(); auto outElementTy = resultTy.getElementType(); int axis = argmaxOp.getAxis(); auto resultMaxTy = RankedTensorType::get(resultTy.getShape(), inElementTy); if (!isa(outElementTy)) return rewriter.notifyMatchFailure( argmaxOp, "tosa.arg_max to linalg.* requires integer-like result type"); SmallVector dynDims; for (int i = 0; i < inputTy.getRank(); i++) { if (inputTy.isDynamicDim(i) && i != axis) { dynDims.push_back(tensor::DimOp::create(rewriter, loc, input, i)); } } // First fill the output buffer for the index. auto emptyTensorIdx = rewriter .create(loc, resultTy.getShape(), outElementTy, dynDims) .getResult(); auto fillValueIdx = arith::ConstantOp::create( rewriter, loc, rewriter.getIntegerAttr(outElementTy, 0)); auto filledTensorIdx = rewriter .create(loc, ValueRange{fillValueIdx}, ValueRange{emptyTensorIdx}) .result(); // Second fill the output buffer for the running max. auto emptyTensorMax = rewriter .create(loc, resultTy.getShape(), inElementTy, dynDims) .getResult(); auto fillValueMaxAttr = createInitialValueForReduceOp(argmaxOp, inElementTy, rewriter); if (!fillValueMaxAttr) return rewriter.notifyMatchFailure( argmaxOp, "unsupported tosa.argmax element type"); auto fillValueMax = arith::ConstantOp::create(rewriter, loc, fillValueMaxAttr); auto filledTensorMax = rewriter .create(loc, ValueRange{fillValueMax}, ValueRange{emptyTensorMax}) .result(); // We need to reduce along the arg-max axis, with parallel operations along // the rest. 
SmallVector iteratorTypes; iteratorTypes.resize(inputTy.getRank(), utils::IteratorType::parallel); iteratorTypes[axis] = utils::IteratorType::reduction; SmallVector srcExprs; SmallVector dstExprs; for (int i = 0, rank = inputTy.getRank(); i != rank; ++i) { srcExprs.push_back(mlir::getAffineDimExpr(i, rewriter.getContext())); if (axis != i) dstExprs.push_back(mlir::getAffineDimExpr(i, rewriter.getContext())); } bool didEncounterError = false; auto maps = AffineMap::inferFromExprList({srcExprs, dstExprs, dstExprs}, rewriter.getContext()); auto linalgOp = linalg::GenericOp::create( rewriter, loc, ArrayRef({resultTy, resultMaxTy}), input, ValueRange({filledTensorIdx, filledTensorMax}), maps, iteratorTypes, [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange blockArgs) { auto newValue = blockArgs[0]; auto oldIndex = blockArgs[1]; auto oldValue = blockArgs[2]; Value newIndex = arith::IndexCastOp::create( rewriter, nestedLoc, oldIndex.getType(), linalg::IndexOp::create(rewriter, loc, axis)); Value predicate; if (isa(inElementTy)) { if (argmaxOp.getNanMode() == "IGNORE") { // Only update index & max value for non NaN values. If all // values are NaNs, the initial index will be return which is 0. 
predicate = arith::CmpFOp::create(rewriter, nestedLoc, arith::CmpFPredicate::OGT, newValue, oldValue); } else { // Update max value if either of the following is true: // - new value is bigger // - cur max is not NaN and new value is NaN Value gt = arith::CmpFOp::create(rewriter, nestedLoc, arith::CmpFPredicate::UGT, newValue, oldValue); Value oldNonNaN = arith::CmpFOp::create(rewriter, nestedLoc, arith::CmpFPredicate::ORD, oldValue, oldValue); predicate = arith::AndIOp::create( rewriter, nestedLoc, rewriter.getI1Type(), gt, oldNonNaN); } } else if (isa(inElementTy)) { predicate = arith::CmpIOp::create(rewriter, nestedLoc, arith::CmpIPredicate::sgt, newValue, oldValue); } else { didEncounterError = true; return; } auto resultMax = arith::SelectOp::create( rewriter, nestedLoc, predicate, newValue, oldValue); auto resultIndex = arith::SelectOp::create( rewriter, nestedLoc, predicate, newIndex, oldIndex); linalg::YieldOp::create(nestedBuilder, nestedLoc, ValueRange({resultIndex, resultMax})); }); if (didEncounterError) return rewriter.notifyMatchFailure( argmaxOp, "unsupported tosa.argmax element type"); rewriter.replaceOp(argmaxOp, linalgOp.getResult(0)); return success(); } }; class GatherConverter : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(tosa::GatherOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const final { auto input = adaptor.getOperands()[0]; auto indices = adaptor.getOperands()[1]; auto valuesTy = dyn_cast(op.getValues().getType()); auto resultTy = dyn_cast(op.getType()); if (!valuesTy || !resultTy) return rewriter.notifyMatchFailure(op, "unranked tensors not supported"); auto dynamicDims = inferDynamicDimsForGather( rewriter, op.getLoc(), adaptor.getValues(), adaptor.getIndices()); auto resultElementTy = resultTy.getElementType(); auto loc = op.getLoc(); auto emptyTensor = rewriter .create(loc, resultTy.getShape(), resultElementTy, dynamicDims) .getResult(); 
SmallVector affineMaps = { AffineMap::get( /*dimCount=*/resultTy.getRank(), /*symbolCount=*/0, {rewriter.getAffineDimExpr(0), rewriter.getAffineDimExpr(1)}, rewriter.getContext()), rewriter.getMultiDimIdentityMap(resultTy.getRank())}; auto genericOp = linalg::GenericOp::create( rewriter, loc, ArrayRef({resultTy}), ValueRange{indices}, ValueRange{emptyTensor}, affineMaps, getNParallelLoopsAttrs(resultTy.getRank()), [&](OpBuilder &b, Location loc, ValueRange args) { auto indexValue = args[0]; auto index0 = linalg::IndexOp::create(rewriter, loc, 0); Value index1 = arith::IndexCastOp::create( rewriter, loc, rewriter.getIndexType(), indexValue); auto index2 = linalg::IndexOp::create(rewriter, loc, 2); Value extract = tensor::ExtractOp::create( rewriter, loc, input, ValueRange{index0, index1, index2}); linalg::YieldOp::create(rewriter, loc, extract); }); rewriter.replaceOp(op, genericOp.getResult(0)); return success(); } static llvm::SmallVector inferDynamicDimsForGather(OpBuilder &builder, Location loc, Value values, Value indices) { llvm::SmallVector results; auto addDynamicDimension = [&](Value source, int64_t dim) { auto sz = tensor::getMixedSize(builder, loc, source, dim); if (auto dimValue = llvm::dyn_cast_if_present(sz)) results.push_back(dimValue); }; addDynamicDimension(values, 0); addDynamicDimension(indices, 1); addDynamicDimension(values, 2); return results; } }; // Lowerings the TableOp to a series of gathers and numerica operations. This // includes interpolation between the high/low values. For the I8 varient, this // simplifies to a single gather operation. 
class TableConverter : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(tosa::TableOp op, PatternRewriter &rewriter) const final { auto loc = op.getLoc(); Value input = op.getInput1(); Value table = op.getTable(); auto inputTy = cast(input.getType()); auto tableTy = cast(table.getType()); auto resultTy = cast(op.getType()); auto inputElementTy = inputTy.getElementType(); auto tableElementTy = tableTy.getElementType(); auto resultElementTy = resultTy.getElementType(); SmallVector dynDims; for (int i = 0; i < resultTy.getRank(); ++i) { if (inputTy.isDynamicDim(i)) { dynDims.push_back( tensor::DimOp::create(rewriter, loc, op.getOperand(0), i)); } } auto emptyTensor = rewriter .create(loc, resultTy.getShape(), resultElementTy, dynDims) .getResult(); SmallVector affineMaps = { rewriter.getMultiDimIdentityMap(resultTy.getRank()), rewriter.getMultiDimIdentityMap(resultTy.getRank())}; auto genericOp = linalg::GenericOp::create( rewriter, loc, resultTy, ValueRange({input}), ValueRange{emptyTensor}, affineMaps, getNParallelLoopsAttrs(resultTy.getRank())); rewriter.replaceOp(op, genericOp.getResult(0)); { OpBuilder::InsertionGuard regionGuard(rewriter); Block *block = rewriter.createBlock( &genericOp.getRegion(), genericOp.getRegion().end(), TypeRange({inputElementTy, resultElementTy}), {loc, loc}); auto inputValue = block->getArgument(0); rewriter.setInsertionPointToStart(block); if (inputElementTy.isInteger(8) && tableElementTy.isInteger(8) && resultElementTy.isInteger(8)) { Value index = arith::IndexCastOp::create( rewriter, loc, rewriter.getIndexType(), inputValue); Value offset = arith::ConstantIndexOp::create(rewriter, loc, 128); index = arith::AddIOp::create(rewriter, loc, rewriter.getIndexType(), index, offset); Value extract = tensor::ExtractOp::create(rewriter, loc, table, ValueRange{index}); linalg::YieldOp::create(rewriter, loc, extract); return success(); } if (inputElementTy.isInteger(16) && 
tableElementTy.isInteger(16) && resultElementTy.isInteger(32)) { Value extend = arith::ExtSIOp::create( rewriter, loc, rewriter.getI32Type(), inputValue); auto offset = arith::ConstantOp::create( rewriter, loc, rewriter.getI32IntegerAttr(32768)); auto seven = arith::ConstantOp::create(rewriter, loc, rewriter.getI32IntegerAttr(7)); auto one = arith::ConstantOp::create(rewriter, loc, rewriter.getI32IntegerAttr(1)); auto b1111111 = arith::ConstantOp::create( rewriter, loc, rewriter.getI32IntegerAttr(127)); // Compute the index and fractional part from the input value: // value = value + 32768 // index = value >> 7; // fraction = 0x01111111 & value auto extendAdd = arith::AddIOp::create(rewriter, loc, extend, offset); Value index = arith::ShRUIOp::create(rewriter, loc, extendAdd, seven); Value fraction = arith::AndIOp::create(rewriter, loc, extendAdd, b1111111); // Extract the base and next values from the table. // base = (int32_t) table[index]; // next = (int32_t) table[index + 1]; Value indexPlusOne = arith::AddIOp::create(rewriter, loc, index, one); index = arith::IndexCastOp::create(rewriter, loc, rewriter.getIndexType(), index); indexPlusOne = arith::IndexCastOp::create( rewriter, loc, rewriter.getIndexType(), indexPlusOne); Value base = tensor::ExtractOp::create(rewriter, loc, table, ValueRange{index}); Value next = tensor::ExtractOp::create(rewriter, loc, table, ValueRange{indexPlusOne}); base = arith::ExtSIOp::create(rewriter, loc, rewriter.getI32Type(), base); next = arith::ExtSIOp::create(rewriter, loc, rewriter.getI32Type(), next); // Use the fractional part to interpolate between the input values: // result = (base << 7) + (next - base) * fraction Value baseScaled = arith::ShLIOp::create(rewriter, loc, base, seven); Value diff = arith::SubIOp::create(rewriter, loc, next, base); Value diffScaled = arith::MulIOp::create(rewriter, loc, diff, fraction); Value result = arith::AddIOp::create(rewriter, loc, baseScaled, diffScaled); 
linalg::YieldOp::create(rewriter, loc, result); return success(); } } return rewriter.notifyMatchFailure( op, "unable to create body for tosa.table op"); } };

/// Lowers tosa.rfft2d to a single linalg.generic that evaluates the 2-D
/// real-input DFT by direct summation.
///
/// Loop dimensions: d0 = n, d1 = oy, d2 = ox (parallel, output position) and
/// d3 = iy, d4 = ix (reduction, input position). The input is indexed as
/// (d0, d3, d4) and both outputs as (d0, d1, d2). Every output element
/// therefore accumulates a contribution from every input element of the same
/// batch, i.e. this is a naive DFT, not a fast-Fourier algorithm.
///
/// The output W extent is W / 2 + 1 (see computeOutputShape). Both accumulator
/// tensors start at zero (see createZeroTensor).
struct RFFT2dConverter final : public OpRewritePattern {
  using OpRewritePattern::OpRewritePattern;

  /// Predicate used by both FFT lowerings to reject operand/result types that
  /// are not ranked tensors ("only supports ranked tensors").
  static bool isRankedTensor(Type type) { return isa(type); }

  /// Returns `ofr / 2 + 1` as an index-typed OpFoldResult. Uses createOrFold
  /// so that a static input size can fold instead of emitting ops.
  static OpFoldResult halfPlusOne(OpBuilder &builder, Location loc,
                                  OpFoldResult ofr) {
    auto one = arith::ConstantIndexOp::create(builder, loc, 1);
    auto two = arith::ConstantIndexOp::create(builder, loc, 2);

    auto value = getValueOrCreateConstantIndexOp(builder, loc, ofr);
    auto divBy2 = builder.createOrFold(loc, value, two);
    auto plusOne = builder.createOrFold(loc, divBy2, one);
    return getAsOpFoldResult(plusOne);
  }

  /// Computes the RFFT output type from `input`. The result keeps N and H,
  /// replaces W with W / 2 + 1, and keeps the input element type.
  ///
  /// \param dynamicSizes  Out-param that receives the SSA values of the
  ///                      dynamic output extents, in dimension order, ready to
  ///                      be passed to tensor.empty.
  /// NOTE(review): indexes dims[2] directly, so this assumes a rank-3 input;
  /// presumably guaranteed by the tosa.rfft2d verifier - confirm.
  static RankedTensorType
  computeOutputShape(OpBuilder &builder, Location loc, Value input,
                     llvm::SmallVectorImpl &dynamicSizes) {
    // Get [N, H, W]
    auto dims = tensor::getMixedSizes(builder, loc, input);

    // Set W = (W / 2) + 1 to account for the half-sized W dimension of the
    // output tensors.
    dims[2] = halfPlusOne(builder, loc, dims[2]);

    // Split the mixed sizes into static extents (for the type) and dynamic
    // values (returned through `dynamicSizes`).
    llvm::SmallVector staticSizes;
    dispatchIndexOpFoldResults(dims, dynamicSizes, staticSizes);

    auto elementType = cast(input.getType()).getElementType();
    return RankedTensorType::get(staticSizes, elementType);
  }

  /// Creates a tensor of `type` (with `dynamicSizes` for its dynamic extents)
  /// filled with the zero value of its element type. It is used as the initial
  /// value of the reduction accumulators.
  static Value createZeroTensor(PatternRewriter &rewriter, Location loc,
                                RankedTensorType type,
                                llvm::ArrayRef dynamicSizes) {
    auto emptyTensor =
        tensor::EmptyOp::create(rewriter, loc, type, dynamicSizes);
    auto fillValueAttr = rewriter.getZeroAttr(type.getElementType());
    auto fillValue = arith::ConstantOp::create(rewriter, loc, fillValueAttr);
    // NOTE(review): this is the only place in this section that still uses the
    // `rewriter.create` form instead of `Op::create(rewriter, ...)`.
    auto filledTensor = rewriter
                            .create(loc, ValueRange{fillValue},
                                    ValueRange{emptyTensor})
                            .result();
    return filledTensor;
  }

  /// Converts an index value to floating-point `type`.
  ///
  /// The index is first cast as unsigned to i64 when `type` is wider than 32
  /// bits (otherwise to i32), then converted with an unsigned int-to-float
  /// conversion. Negative index values are therefore not expected here.
  static Value castIndexToFloat(OpBuilder &builder, Location loc,
                                FloatType type, Value value) {
    auto integerVal = arith::IndexCastUIOp::create(
        builder, loc,
        type.getIntOrFloatBitWidth() > 32 ? builder.getI64Type()
                                          : builder.getI32Type(),
        value);

    return arith::UIToFPOp::create(builder, loc, type, integerVal);
  }

  /// Reads the current iteration index of loop dimension `index` via
  /// linalg.index and converts it to `type`.
  /// NOTE(review): not called by either FFT lowering in this section.
  static Value createLinalgIndex(OpBuilder &builder, Location loc,
                                 FloatType type, int64_t index) {
    auto indexVal = linalg::IndexOp::create(builder, loc, index);
    return castIndexToFloat(builder, loc, type, indexVal);
  }

  /// Builds the list of affine dim expressions d`args`..., in argument order.
  /// It is used to spell out the per-operand indexing maps.
  template
  static llvm::SmallVector affineDimsExpr(OpBuilder &builder,
                                                      Args... args) {
    return {builder.getAffineDimExpr(args)...};
  }

  LogicalResult matchAndRewrite(RFFT2dOp rfft2d,
                                PatternRewriter &rewriter) const override {
    if (!llvm::all_of(rfft2d->getOperandTypes(), isRankedTensor) ||
        !llvm::all_of(rfft2d->getResultTypes(), isRankedTensor)) {
      return rewriter.notifyMatchFailure(rfft2d,
                                         "only supports ranked tensors");
    }

    auto loc = rfft2d.getLoc();
    auto input = rfft2d.getInputReal();
    auto elementType =
        dyn_cast(cast(input.getType()).getElementType());

    if (!elementType)
      return rewriter.notifyMatchFailure(rfft2d,
                                         "only supports float element types");

    // Compute the output type and set of dynamic sizes
    llvm::SmallVector dynamicSizes;
    auto outputType = computeOutputShape(rewriter, loc, input, dynamicSizes);

    // Iterator types for the linalg.generic implementation:
    // (n, oy, ox) are parallel, (iy, ix) are reduced over.
    llvm::SmallVector iteratorTypes = {
        utils::IteratorType::parallel, utils::IteratorType::parallel,
        utils::IteratorType::parallel, utils::IteratorType::reduction,
        utils::IteratorType::reduction};

    // Inputs/outputs to the linalg.generic implementation. The two outputs
    // (real and imaginary sums) start at zero and act as accumulators.
    llvm::SmallVector genericOpInputs = {input};
    llvm::SmallVector genericOpOutputs = {
        createZeroTensor(rewriter, loc, outputType, dynamicSizes),
        createZeroTensor(rewriter, loc, outputType, dynamicSizes)};

    // Indexing maps for input and output tensors:
    // input[n, iy, ix] -> (d0, d3, d4); both outputs[n, oy, ox] -> (d0, d1, d2).
    auto indexingMaps = AffineMap::inferFromExprList(
        llvm::ArrayRef{affineDimsExpr(rewriter, 0, 3, 4),
                       affineDimsExpr(rewriter, 0, 1, 2),
                       affineDimsExpr(rewriter, 0, 1, 2)},
        rewriter.getContext());

    // Width and height dimensions of the original input.
    auto dimH = rewriter.createOrFold(loc, input, 1);
    auto dimW = rewriter.createOrFold(loc, input, 2);

    // Constants and dimension sizes
    auto twoPiAttr = rewriter.getFloatAttr(elementType, 6.283185307179586);
    auto twoPi = arith::ConstantOp::create(rewriter, loc, twoPiAttr);
    auto constH = castIndexToFloat(rewriter, loc, elementType, dimH);
    auto constW = castIndexToFloat(rewriter, loc, elementType, dimW);

    // Region body: block arguments are (input value, real accumulator,
    // imaginary accumulator), matching genericOpInputs then genericOpOutputs.
    auto buildBody = [&](OpBuilder &builder, Location loc, ValueRange args) {
      Value valReal = args[0];
      Value sumReal = args[1];
      Value sumImag = args[2];

      // Indices for angle computation
      Value oy = linalg::IndexOp::create(builder, loc, 1);
      Value ox = linalg::IndexOp::create(builder, loc, 2);
      Value iy = linalg::IndexOp::create(builder, loc, 3);
      Value ix = linalg::IndexOp::create(builder, loc, 4);

      // Calculating angle without integer parts of components as sin/cos are
      // periodic: angle = 2 * pi() * ( ( (iy * oy) % H) / H + ( (ix * ox) % W )
      // / W);
      // The remainders are taken in index arithmetic before converting to
      // float, so each fraction lies in [0, 1) and the angle stays small
      // (presumably to preserve precision for large index products).
      auto iyXoy = index::MulOp::create(builder, loc, iy, oy);
      auto ixXox = index::MulOp::create(builder, loc, ix, ox);

      auto iyRem = index::RemUOp::create(builder, loc, iyXoy, dimH);
      auto ixRem = index::RemUOp::create(builder, loc, ixXox, dimW);

      auto iyRemFloat = castIndexToFloat(builder, loc, elementType, iyRem);
      auto ixRemFloat = castIndexToFloat(builder, loc, elementType, ixRem);

      auto yComponent = arith::DivFOp::create(builder, loc, iyRemFloat, constH);
      auto xComponent = arith::DivFOp::create(builder, loc, ixRemFloat, constW);
      auto sumXY = arith::AddFOp::create(builder, loc, yComponent, xComponent);
      auto angle = arith::MulFOp::create(builder, loc, twoPi, sumXY);

      // realComponent = valReal * cos(angle)
      // imagComponent = valReal * sin(angle)
      auto cosAngle = math::CosOp::create(builder, loc, angle);
      auto sinAngle = math::SinOp::create(builder, loc, angle);
      auto realComponent =
          arith::MulFOp::create(builder, loc, valReal, cosAngle);
      auto imagComponent =
          arith::MulFOp::create(builder, loc, valReal, sinAngle);

      // outReal = sumReal + realComponent
      // outImag = sumImag - imagComponent
      // (the imaginary part is subtracted: forward transform, exp(-i * angle)).
      auto outReal = arith::AddFOp::create(builder, loc, sumReal, realComponent);
      auto outImag = arith::SubFOp::create(builder, loc, sumImag, imagComponent);

      linalg::YieldOp::create(builder, loc, ValueRange{outReal, outImag});
    };

    rewriter.replaceOpWithNewOp(
        rfft2d, rfft2d.getResultTypes(), genericOpInputs, genericOpOutputs,
        indexingMaps, iteratorTypes, buildBody);

    return success();
  }
};

/// Lowers tosa.fft2d (complex input given as separate real and imaginary
/// tensors) to a single linalg.generic that evaluates the 2-D DFT by direct
/// summation.
///
/// It uses the same loop structure as RFFT2dConverter: d0 = n, d1 = oy, d2 = ox
/// (parallel) and d3 = iy, d4 = ix (reduction). Both inputs are indexed as
/// (d0, d3, d4) and both outputs as (d0, d1, d2). Unlike the RFFT lowering, the
/// output shape equals the input shape.
///
/// When the `inverse` attribute is set the angle is negated. No 1 / (H * W)
/// scaling is applied in this lowering, in either direction.
struct FFT2dConverter final : OpRewritePattern {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(FFT2dOp fft2d,
                                PatternRewriter &rewriter) const override {
    if (!llvm::all_of(fft2d->getOperandTypes(),
                      RFFT2dConverter::isRankedTensor) ||
        !llvm::all_of(fft2d->getResultTypes(),
                      RFFT2dConverter::isRankedTensor)) {
      return rewriter.notifyMatchFailure(fft2d, "only supports ranked tensors");
    }

    Location loc = fft2d.getLoc();
    Value input_real = fft2d.getInputReal();
    Value input_imag = fft2d.getInputImag();
    BoolAttr inverse = fft2d.getInverseAttr();

    // NOTE(review): unlike RFFT2dConverter, these are unchecked casts rather
    // than a dyn_cast with a match failure; presumably the op verifier already
    // guarantees float element types - confirm.
    auto real_el_ty = cast(
        cast(input_real.getType()).getElementType());
    // Only referenced by the assert below, hence [[maybe_unused]] for NDEBUG.
    [[maybe_unused]] auto imag_el_ty = cast(
        cast(input_imag.getType()).getElementType());

    assert(real_el_ty == imag_el_ty);

    // Compute the output type and set of dynamic sizes
    SmallVector dynamicSizes;

    // Get [N, H, W]
    auto dims = tensor::getMixedSizes(rewriter, loc, input_real);

    SmallVector staticSizes;
    dispatchIndexOpFoldResults(dims, dynamicSizes, staticSizes);

    // Output shape is identical to the (real) input shape.
    auto outputType = RankedTensorType::get(staticSizes, real_el_ty);

    // Iterator types for the linalg.generic implementation:
    // (n, oy, ox) are parallel, (iy, ix) are reduced over.
    SmallVector iteratorTypes = {
        utils::IteratorType::parallel, utils::IteratorType::parallel,
        utils::IteratorType::parallel, utils::IteratorType::reduction,
        utils::IteratorType::reduction};

    // Inputs/outputs to the linalg.generic implementation. The two outputs
    // (real and imaginary sums) start at zero and act as accumulators.
    SmallVector genericOpInputs = {input_real, input_imag};
    SmallVector genericOpOutputs = {
        RFFT2dConverter::createZeroTensor(rewriter, loc, outputType,
                                          dynamicSizes),
        RFFT2dConverter::createZeroTensor(rewriter, loc, outputType,
                                          dynamicSizes)};

    // Indexing maps for input and output tensors:
    // both inputs -> (d0, d3, d4); both outputs -> (d0, d1, d2).
    auto indexingMaps = AffineMap::inferFromExprList(
        ArrayRef{RFFT2dConverter::affineDimsExpr(rewriter, 0, 3, 4),
                 RFFT2dConverter::affineDimsExpr(rewriter, 0, 3, 4),
                 RFFT2dConverter::affineDimsExpr(rewriter, 0, 1, 2),
                 RFFT2dConverter::affineDimsExpr(rewriter, 0, 1, 2)},
        rewriter.getContext());

    // Width and height dimensions of the original input.
    auto dimH = rewriter.createOrFold(loc, input_real, 1);
    auto dimW = rewriter.createOrFold(loc, input_real, 2);

    // Constants and dimension sizes
    auto twoPiAttr = rewriter.getFloatAttr(real_el_ty, 6.283185307179586);
    auto twoPi = arith::ConstantOp::create(rewriter, loc, twoPiAttr);
    Value constH =
        RFFT2dConverter::castIndexToFloat(rewriter, loc, real_el_ty, dimH);
    Value constW =
        RFFT2dConverter::castIndexToFloat(rewriter, loc, real_el_ty, dimW);

    // Region body: block arguments are (real input, imaginary input, real
    // accumulator, imaginary accumulator).
    auto buildBody = [&](OpBuilder &builder, Location loc, ValueRange args) {
      Value valReal = args[0];
      Value valImag = args[1];
      Value sumReal = args[2];
      Value sumImag = args[3];

      // Indices for angle computation
      Value oy = linalg::IndexOp::create(builder, loc, 1);
      Value ox = linalg::IndexOp::create(builder, loc, 2);
      Value iy = linalg::IndexOp::create(builder, loc, 3);
      Value ix = linalg::IndexOp::create(builder, loc, 4);

      // float_t angle = sign_val * 2 * pi() * ( ( (iy * oy) % H) / H + ( (ix *
      // ox) % W ) / W);
      // sign_val is -1.0 when `inverse` is set (applied below), else 1.0.
      auto iyXoy = index::MulOp::create(builder, loc, iy, oy);
      auto ixXox = index::MulOp::create(builder, loc, ix, ox);

      auto iyRem = index::RemUOp::create(builder, loc, iyXoy, dimH);
      auto ixRem = index::RemUOp::create(builder, loc, ixXox, dimW);

      auto iyRemFloat =
          RFFT2dConverter::castIndexToFloat(builder, loc, real_el_ty, iyRem);
      auto ixRemFloat =
          RFFT2dConverter::castIndexToFloat(builder, loc, real_el_ty, ixRem);

      auto yComponent = arith::DivFOp::create(builder, loc, iyRemFloat, constH);
      auto xComponent = arith::DivFOp::create(builder, loc, ixRemFloat, constW);

      auto sumXY = arith::AddFOp::create(builder, loc, yComponent, xComponent);
      auto angle = arith::MulFOp::create(builder, loc, twoPi, sumXY);

      if (inverse.getValue()) {
        // NOTE(review): the -1.0 constant is built with the captured `rewriter`
        // instead of this lambda's `builder`; using `builder` would be more
        // consistent with the rest of the body - confirm both share the
        // insertion point before relying on either.
        angle = arith::MulFOp::create(
            builder, loc, angle,
            arith::ConstantOp::create(rewriter, loc,
                                      rewriter.getFloatAttr(real_el_ty, -1.0)));
      }

      // realComponent = val_real * cos(a) + val_imag * sin(a);
      // imagComponent = -val_real * sin(a) + val_imag * cos(a);
      auto cosAngle = math::CosOp::create(builder, loc, angle);
      auto sinAngle = math::SinOp::create(builder, loc, angle);

      auto rcos = arith::MulFOp::create(builder, loc, valReal, cosAngle);
      auto rsin = arith::MulFOp::create(builder, loc, valImag, sinAngle);
      auto realComponent = arith::AddFOp::create(builder, loc, rcos, rsin);

      auto icos = arith::MulFOp::create(builder, loc, valImag, cosAngle);
      auto isin = arith::MulFOp::create(builder, loc, valReal, sinAngle);

      auto imagComponent = arith::SubFOp::create(builder, loc, icos, isin);

      // outReal = sumReal + realComponent
      // outImag = sumImag + imagComponent
      // (the minus sign on val_real * sin(a) is already inside imagComponent,
      // so both accumulators are updated with an addition).
      auto outReal = arith::AddFOp::create(builder, loc, sumReal, realComponent);
      auto outImag = arith::AddFOp::create(builder, loc, sumImag, imagComponent);

      linalg::YieldOp::create(builder, loc, ValueRange{outReal, outImag});
    };

    rewriter.replaceOpWithNewOp(
        fft2d, fft2d.getResultTypes(), genericOpInputs, genericOpOutputs,
        indexingMaps, iteratorTypes, buildBody);

    return success();
  }
};

} // namespace

/// Adds the TOSA-to-Linalg rewrite patterns defined in this file to
/// `patterns`. The pointwise converters are constructed with `converter`; the
/// remaining patterns receive only the context.
void mlir::tosa::populateTosaToLinalgConversionPatterns(
    const TypeConverter &converter, RewritePatternSet *patterns) {

  // We have multiple resize converters to handle degenerate cases.
  // Their distinct benefits (100 / 200 / 300) set their relative priority.
  patterns->add(patterns->getContext(),
                /*benefit=*/100);
  patterns->add(patterns->getContext(),
                /*benefit=*/200);
  patterns->add(patterns->getContext(),
                /*benefit=*/300);

  patterns->add<
      // clang-format off
      PointwiseConverter,
      PointwiseConverter,
      PointwiseConverter,
      PointwiseConverter,
      PointwiseConverter,
      PointwiseConverter,
      PointwiseConverter,
      PointwiseConverter,
      PointwiseConverter,
      PointwiseConverter,
      PointwiseConverter,
      PointwiseConverter,
      PointwiseConverter,
      PointwiseConverter,
      PointwiseConverter,
      PointwiseConverter,
      PointwiseConverter,
      PointwiseConverter,
      PointwiseConverter,
      PointwiseConverter,
      PointwiseConverter,
      PointwiseConverter,
      PointwiseConverter,
      PointwiseConverter,
      PointwiseConverter,
      PointwiseConverter,
      PointwiseConverter,
      PointwiseConverter,
      PointwiseConverter,
      PointwiseConverter,
      PointwiseConverter,
      PointwiseConverter,
      PointwiseConverter,
      PointwiseConverter,
      PointwiseConverter,
      PointwiseConverter,
      PointwiseConverter,
      PointwiseConverter
        >(converter, patterns->getContext());

  patterns->add<
      IdentityNConverter,
      ReduceConverter,
      ReduceConverter,
      ReduceConverter,
      ReduceConverter,
      ReduceConverter,
      ReduceConverter,
      ArgMaxConverter,
      GatherConverter,
      RescaleConverter,
      ReverseConverter,
      RFFT2dConverter,
      FFT2dConverter,
      TableConverter,
      TileConverter>(patterns->getContext());
  // clang-format on
}