
The MLIR classes Type/Attribute/Operation/Op/Value support cast/dyn_cast/isa/dyn_cast_or_null functionality through llvm's doCast functionality in addition to defining methods with the same name. This change begins the migration of uses of the method variants to the corresponding free function calls, which have been decided on as the more consistent style. Note that there still exist classes that only define methods directly, such as AffineExpr, and this change does not include work to support the functional cast/isa calls for them.

Context:
- https://mlir.llvm.org/deprecation/ at "Use the free function variants for dyn_cast/cast/isa/…"
- Original discussion at https://discourse.llvm.org/t/preferred-casting-style-going-forward/68443

Implementation: This patch updates all remaining uses of the deprecated functionality in mlir/. This was done with clang-tidy as described below, plus further manual modifications to GPUBase.td and OpenMPOpsInterfaces.td.

Steps are described per line, since comments are removed by git:
0. Retrieve the change from the following branch to build clang-tidy with an additional check: main...tpopp:llvm-project:tidy-cast-check
1. Build clang-tidy.
2. Run clang-tidy over the entire codebase, disabling all checks and enabling only the relevant one; run on all header files as well.
3. Delete the .inc files that were also modified, so the next build regenerates them in a pure state.

```
ninja -C $BUILD_DIR clang-tidy

run-clang-tidy -clang-tidy-binary=$BUILD_DIR/bin/clang-tidy \
    -checks='-*,misc-cast-functions' -header-filter=mlir/ mlir/* -fix

rm -rf $BUILD_DIR/tools/mlir/**/*.inc
```

Differential Revision: https://reviews.llvm.org/D151542
//===- TosaCanonicalizations.cpp - Canonicalization patterns & folders ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// \file
// TOSA canonicalization patterns and folders.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Quant/QuantOps.h"
|
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
|
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
|
|
#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
|
|
#include "mlir/Dialect/Tosa/Utils/QuantUtils.h"
|
|
#include "mlir/Dialect/Tosa/Utils/ShapeUtils.h"
|
|
#include "mlir/IR/BuiltinTypes.h"
|
|
#include "mlir/IR/DialectImplementation.h"
|
|
#include "mlir/IR/Matchers.h"
|
|
#include "mlir/IR/PatternMatch.h"
|
|
#include "mlir/Transforms/FoldUtils.h"
|
|
#include "mlir/Transforms/InliningUtils.h"
|
|
#include "mlir/Transforms/RegionUtils.h"
|
|
#include "llvm/ADT/APFloat.h"
|
|
#include "llvm/ADT/DenseMap.h"
|
|
#include "llvm/ADT/TypeSwitch.h"
|
|
|
|
#include <functional>
|
|
|
|
using namespace mlir;
using namespace mlir::tosa;

//===----------------------------------------------------------------------===//
// Operator Canonicalizers.
//===----------------------------------------------------------------------===//

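// Replaces a tosa.concat that has a single input with that input, inserting a
// tensor.cast when the operand and result types differ.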
struct ConcatOptimization : public OpRewritePattern<tosa::ConcatOp> {
  using OpRewritePattern<tosa::ConcatOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::ConcatOp op,
                                PatternRewriter &rewriter) const override {
    if (op.getInput1().size() != 1)
      return failure();
    if (op.getInput1().front().getType() != op.getType()) {
      rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(),
                                                  op.getInput1().front());
      return success();
    }

    rewriter.replaceOp(op, op.getInput1().front());
    return success();
  }
};

void ConcatOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
  results.add<ConcatOptimization>(context);
}

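// Collapses reshape(reshape(x)) into a single tosa.reshape that uses the outer
// op's new shape.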
struct ReshapeReshapeOptimization : public OpRewritePattern<tosa::ReshapeOp> {
  using OpRewritePattern<tosa::ReshapeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::ReshapeOp op,
                                PatternRewriter &rewriter) const override {
    Value input = op.getInput1();
    Operation *definingOp = input.getDefiningOp();
    if (!definingOp)
      return failure();

    if (tosa::ReshapeOp reshapeOp = dyn_cast<tosa::ReshapeOp>(definingOp)) {
      rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
          op, op.getType(), reshapeOp.getInput1(), op.getNewShape());
      return success();
    }

    return failure();
  }
};

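// Replaces a tosa.reshape of a constant with a new tosa.const holding the
// reshaped values, provided the constant is splat or has only one use.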
struct ReshapeConstOptimization : public OpRewritePattern<tosa::ReshapeOp> {
  using OpRewritePattern<tosa::ReshapeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::ReshapeOp op,
                                PatternRewriter &rewriter) const override {
    Value input = op.getInput1();
    ShapedType inputTy = llvm::cast<ShapedType>(input.getType());
    ShapedType resultTy = llvm::cast<ShapedType>(op.getType());

    if (inputTy.getElementType() != resultTy.getElementType())
      return rewriter.notifyMatchFailure(op, "element type does not match.");

    // Check if the input is constant.
    DenseElementsAttr inputAttr;
    if (!matchPattern(input, m_Constant(&inputAttr)))
      return rewriter.notifyMatchFailure(op, "Non-constant input.");

    // Bail out if the input has more than one consumer and is not splat.
    if (!input.hasOneUse() && !inputAttr.isSplat())
      return rewriter.notifyMatchFailure(op,
                                         "Used more than once or not-splat");

    // Build a new const op with the correct output shape.
    DenseElementsAttr outputAttr = inputAttr.reshape(
        llvm::cast<ShapedType>(inputAttr.getType()).clone(op.getNewShape()));
    rewriter.replaceOpWithNewOp<tosa::ConstOp>(op, resultTy, outputAttr);
    return success();
  }
};

void ReshapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results.add<ReshapeReshapeOptimization>(context);
  results.add<ReshapeConstOptimization>(context);
}

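// Canonicalizes select(logical_not(pred), a, b) into select(pred, b, a) by
// swapping the on-true/on-false operands in place.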
LogicalResult SelectOp::canonicalize(SelectOp op, PatternRewriter &rewriter) {
  auto notOp = op.getPred().getDefiningOp<tosa::LogicalNotOp>();
  if (!notOp)
    return failure();
  rewriter.updateRootInPlace(op, [&]() {
    op.getOperation()->setOperands(
        {notOp.getInput1(), op.getOnFalse(), op.getOnTrue()});
  });
  return success();
}

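// Fuses transpose(transpose(x)) into a single tosa.transpose with the composed
// permutation.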
struct ConsolidateTransposeOptimization
    : public OpRewritePattern<tosa::TransposeOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::TransposeOp transposeOp,
                                PatternRewriter &rewriter) const override {
    // Input is also TransposeOp - transpose(transpose(A)).
    auto innerTranspose =
        transposeOp.getInput1().getDefiningOp<tosa::TransposeOp>();
    if (!innerTranspose)
      return rewriter.notifyMatchFailure(transposeOp,
                                         "input must be transpose operation");

    SmallVector<int64_t> transposePerms, innerTransposePerms;
    if (transposeOp.getConstantPerms(transposePerms).failed())
      return rewriter.notifyMatchFailure(transposeOp,
                                         "transpose perms must be constant");
    if (innerTranspose.getConstantPerms(innerTransposePerms).failed())
      return rewriter.notifyMatchFailure(
          transposeOp, "inner transpose perms must be constant");
    if (transposePerms.size() != innerTransposePerms.size())
      return rewriter.notifyMatchFailure(
          transposeOp,
          "transpose and inner transpose perms sizes must be equal");
    if (transposePerms.empty())
      return rewriter.notifyMatchFailure(
          transposeOp, "transpose perms sizes must be positive");

    // Consolidate transposes into one transpose.
    SmallVector<int32_t> perms(transposePerms.size());
    for (int i = 0, s = transposePerms.size(); i < s; ++i)
      perms[i] = innerTransposePerms[transposePerms[i]];

    auto permsTy =
        RankedTensorType::get(transposePerms.size(), rewriter.getI32Type());
    auto permsAttr = DenseIntElementsAttr::get(permsTy, perms);
    Value permsValue =
        rewriter.create<arith::ConstantOp>(transposeOp.getLoc(), permsAttr);

    rewriter.replaceOpWithNewOp<tosa::TransposeOp>(
        transposeOp, transposeOp.getResult().getType(),
        innerTranspose.getInput1(), permsValue);

    return success();
  }
};

// Determines when a tosa.transpose only permutes unit dimensions and can
// therefore be rewritten as a tosa.reshape.
struct TransposeIsReshape : public OpRewritePattern<tosa::TransposeOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::TransposeOp op,
                                PatternRewriter &rewriter) const override {
    DenseIntElementsAttr permAttr;
    if (!matchPattern(op.getPerms(), m_Constant(&permAttr)))
      return rewriter.notifyMatchFailure(op, "Non-constant permutation");

    if (op.getInput1().getDefiningOp<tosa::TransposeOp>())
      return rewriter.notifyMatchFailure(
          op, "Src is from transpose, can compose transposes");

    Value result = op.getResult();
    for (Operation *subop : result.getUsers()) {
      if (dyn_cast_or_null<tosa::TransposeOp>(subop))
        return rewriter.notifyMatchFailure(
            op, "Dest is used by transpose, can compose transposes");
    }

    auto input = op.getInput1();
    auto inputTy = llvm::cast<ShapedType>(input.getType());
    if (!inputTy.hasRank())
      return rewriter.notifyMatchFailure(op, "Unranked input.");

    int64_t numDynDims = 0;
    for (int i = 0; i < inputTy.getRank(); ++i)
      if (inputTy.isDynamicDim(i))
        numDynDims++;

    if (numDynDims > 1)
      return rewriter.notifyMatchFailure(op, "Has more than one dynamic dim.");

    SmallVector<int64_t> permValues = llvm::to_vector<6>(
        llvm::map_range(permAttr.getValues<APInt>(),
                        [](const APInt &val) { return val.getSExtValue(); }));

    SmallVector<int64_t> nonZeroPerms;
    nonZeroPerms.reserve(permValues.size());
    for (auto idx : permValues) {
      auto sz = inputTy.getDimSize(idx);
      if (sz != 1)
        nonZeroPerms.push_back(idx);
    }

    for (int i = 1, s = nonZeroPerms.size(); i < s; ++i)
      if (nonZeroPerms[i - 1] > nonZeroPerms[i])
        return rewriter.notifyMatchFailure(op,
                                           "Transpose changes memory layout.");

    SmallVector<int64_t> newShape;
    newShape.reserve(inputTy.getRank());
    for (int i = 0, s = inputTy.getRank(); i < s; ++i)
      newShape.push_back(inputTy.getDimSize(permValues[i]));

    rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
        op, op.getType(), op.getInput1(),
        rewriter.getDenseI64ArrayAttr(newShape));
    return success();
  }
};

void TransposeOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                              MLIRContext *context) {
  results.add<ConsolidateTransposeOptimization, TransposeIsReshape>(context);
}

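// Materializes an explicit pad_const operand on tosa.pad when one is not
// provided, using zero or the input zero point for quantized types.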
struct MaterializePadValue : public OpRewritePattern<tosa::PadOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::PadOp op,
                                PatternRewriter &rewriter) const override {
    if (op.getPadConst())
      return failure();

    auto input = op.getInput1();
    auto padding = op.getPadding();

    ShapedType inputTy = llvm::cast<ShapedType>(input.getType());
    Type elementTy = inputTy.getElementType();

    Attribute constantAttr;
    if (llvm::isa<FloatType>(elementTy)) {
      constantAttr = rewriter.getFloatAttr(elementTy, 0.0);
    } else if (llvm::isa<IntegerType>(elementTy) && !op.getQuantizationInfo()) {
      constantAttr = rewriter.getIntegerAttr(elementTy, 0);
    } else if (llvm::isa<IntegerType>(elementTy) && op.getQuantizationInfo()) {
      auto value = op.getQuantizationInfo()->getInputZp();
      constantAttr = rewriter.getIntegerAttr(elementTy, value);
    }

    if (!constantAttr) {
      return rewriter.notifyMatchFailure(
          op,
          "tosa.pad to linalg lowering encountered an unknown element type");
    }

    auto denseAttr = DenseElementsAttr::get(
        RankedTensorType::get({}, elementTy), constantAttr);
    auto constantVal = rewriter.create<tosa::ConstOp>(
        op.getLoc(), denseAttr.getType(), denseAttr);

    rewriter.replaceOpWithNewOp<tosa::PadOp>(
        op, op.getType(), ValueRange{input, padding, constantVal},
        op->getAttrs());
    return success();
  }
};

void PadOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                        MLIRContext *context) {
  results.add<MaterializePadValue>(context);
}

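// Replaces a tosa.max_pool2d with its input when both input and output have
// static 1x1 spatial dimensions, since pooling a single element is a no-op.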
struct MaxPool2dIsNoOp : public OpRewritePattern<tosa::MaxPool2dOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::MaxPool2dOp op,
                                PatternRewriter &rewriter) const override {
    Value input = op.getInput();
    Value output = op.getOutput();
    ShapedType inputType = llvm::cast<ShapedType>(input.getType());
    ShapedType outputType = llvm::cast<ShapedType>(output.getType());

    if (!inputType.hasStaticShape() || !outputType.hasStaticShape()) {
      return failure();
    }

    // If the output and input spatial shapes are 1x1, then this is a no-op.
    ArrayRef<int64_t> outputShape = outputType.getShape();
    if (outputShape[1] != 1 || outputShape[2] != 1) {
      return failure();
    }

    ArrayRef<int64_t> inputShape = inputType.getShape();
    if (inputShape[1] != 1 || inputShape[2] != 1) {
      return failure();
    }

    rewriter.replaceOp(op, input);
    return success();
  }
};

void MaxPool2dOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                              MLIRContext *context) {
  results.add<MaxPool2dIsNoOp>(context);
}

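// Removes a tosa.clamp whose bounds cover the entire representable range of
// the input element type, making the clamp a no-op.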
struct ClampIsNoOp : public OpRewritePattern<tosa::ClampOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::ClampOp op,
                                PatternRewriter &rewriter) const override {
    Value input = op.getInput();
    auto inputType = llvm::dyn_cast<RankedTensorType>(op.getInput().getType());
    // Guard against unranked inputs before querying the element type.
    if (!inputType)
      return failure();
    auto inputElementType = inputType.getElementType();

    if (!inputType.hasStaticShape()) {
      return failure();
    }

    if (inputElementType.isF32()) {
      auto minClamp = op.getMinFp();
      auto maxClamp = op.getMaxFp();
      bool isMin = (minClamp.isLargest() || minClamp.isInfinity()) &&
                   minClamp.isNegative();
      bool isMax = (maxClamp.isLargest() || maxClamp.isInfinity()) &&
                   !maxClamp.isNegative();

      if (isMin && isMax) {
        rewriter.replaceOp(op, input);
        return success();
      }
      return failure();
    }

    if (inputElementType.isUnsignedInteger()) {
      int64_t minClamp = op.getMinInt();
      int64_t maxClamp = op.getMaxInt();

      int64_t intMin =
          APInt::getMinValue(inputElementType.getIntOrFloatBitWidth())
              .getZExtValue();
      int64_t intMax =
          APInt::getMaxValue(inputElementType.getIntOrFloatBitWidth())
              .getZExtValue();

      if (minClamp <= intMin && maxClamp >= intMax) {
        rewriter.replaceOp(op, input);
        return success();
      }
      return failure();
    }

    if (llvm::isa<IntegerType>(inputElementType)) {
      int64_t minClamp = op.getMinInt();
      int64_t maxClamp = op.getMaxInt();

      int64_t intMin =
          APInt::getSignedMinValue(inputElementType.getIntOrFloatBitWidth())
              .getSExtValue();
      int64_t intMax =
          APInt::getSignedMaxValue(inputElementType.getIntOrFloatBitWidth())
              .getSExtValue();

      if (minClamp <= intMin && maxClamp >= intMax) {
        rewriter.replaceOp(op, input);
        return success();
      }
      return failure();
    }

    return failure();
  }
};

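// Folds clamp(clamp(x)) into a single tosa.clamp over the intersection of the
// two clamp ranges.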
struct ClampClampOptimization : public OpRewritePattern<tosa::ClampOp> {
  using OpRewritePattern<tosa::ClampOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::ClampOp op,
                                PatternRewriter &rewriter) const override {
    Value input = op.getInput();

    Operation *definingOp = input.getDefiningOp();
    if (!definingOp)
      return failure();

    if (tosa::ClampOp clampOp = dyn_cast<tosa::ClampOp>(definingOp)) {
      auto minFp = std::max(op.getMinFp(), clampOp.getMinFp()).convertToFloat();
      auto maxFp = std::min(op.getMaxFp(), clampOp.getMaxFp()).convertToFloat();

      auto minInt = std::max(op.getMinInt(), clampOp.getMinInt());
      auto maxInt = std::min(op.getMaxInt(), clampOp.getMaxInt());

      rewriter.replaceOpWithNewOp<tosa::ClampOp>(
          op, op.getType(), clampOp.getInput(),
          rewriter.getI64IntegerAttr(minInt),
          rewriter.getI64IntegerAttr(maxInt), rewriter.getF32FloatAttr(minFp),
          rewriter.getF32FloatAttr(maxFp));
      return success();
    }

    return failure();
  }
};

void ClampOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                          MLIRContext *context) {
  results.add<ClampIsNoOp>(context);
  results.add<ClampClampOptimization>(context);
}

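// Rewrites slice(concat(...)) into a slice of the single concat operand that
// fully contains the sliced region along the concatenation axis.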
struct ConcatSliceOptimization : public OpRewritePattern<tosa::SliceOp> {
  using OpRewritePattern<tosa::SliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::SliceOp sliceOp,
                                PatternRewriter &rewriter) const override {
    Value sliceInput = sliceOp.getInput();
    auto concatOp = sliceInput.getDefiningOp<tosa::ConcatOp>();
    if (!concatOp)
      return rewriter.notifyMatchFailure(
          sliceOp, "slice input must be concat operation");

    OperandRange inputs = concatOp.getInput1();
    auto concatType = dyn_cast<RankedTensorType>(concatOp.getType());
    if (!concatType || !concatType.hasStaticShape())
      return rewriter.notifyMatchFailure(
          sliceOp, "slice input must be a static ranked tensor");
    int32_t axis = concatOp.getAxis();

    llvm::SmallVector<int64_t> sliceStart(sliceOp.getStart());
    llvm::ArrayRef<int64_t> sliceSize = sliceOp.getSize();

    // Validate the slice on the concatenated axis. Slicing along this
    // axis should span only one of the inputs to the concatenate
    // operation.
    std::optional<Value> replaceWithSlice;
    for (auto input : inputs) {
      auto inputType = dyn_cast<RankedTensorType>(input.getType());
      if (!inputType || !inputType.hasStaticShape())
        return rewriter.notifyMatchFailure(
            sliceOp, "concat input must be a static ranked tensor");

      if (sliceStart[axis] >= 0 &&
          (sliceStart[axis] + sliceSize[axis]) <= inputType.getDimSize(axis)) {
        // Use the start adjusted to be relative to the matched operand.
        replaceWithSlice =
            rewriter
                .create<tosa::SliceOp>(
                    sliceOp.getLoc(), sliceOp.getType(), input,
                    rewriter.getDenseI64ArrayAttr(sliceStart),
                    rewriter.getDenseI64ArrayAttr(sliceSize))
                .getResult();
        break;
      }
      sliceStart[axis] -= inputType.getDimSize(axis);
    }

    if (!replaceWithSlice)
      return rewriter.notifyMatchFailure(
          sliceOp, "corresponding concat input not found for slice");

    rewriter.replaceOp(sliceOp, replaceWithSlice.value());
    return success();
  }
};

void SliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                          MLIRContext *context) {
  results.add<ConcatSliceOptimization>(context);
}

//===----------------------------------------------------------------------===//
// Operator Folders.
//===----------------------------------------------------------------------===//

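// Folds an elementwise binary op when both operands are splat constants with
// matching element types, using IntFolder for integer elements and FloatFolder
// for floating-point elements.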
template <typename IntFolder, typename FloatFolder>
DenseElementsAttr binaryFolder(DenseElementsAttr lhs, DenseElementsAttr rhs,
                               RankedTensorType returnTy) {
  if (rhs && lhs && rhs.isSplat() && lhs.isSplat()) {
    auto lETy = llvm::cast<ShapedType>(lhs.getType()).getElementType();
    auto rETy = llvm::cast<ShapedType>(rhs.getType()).getElementType();
    if (lETy != rETy)
      return {};

    if (llvm::isa<IntegerType>(lETy)) {
      APInt l = lhs.getSplatValue<APInt>();
      APInt r = rhs.getSplatValue<APInt>();
      auto result = IntFolder()(l, r);
      return DenseElementsAttr::get(returnTy, result);
    }

    if (llvm::isa<FloatType>(lETy)) {
      APFloat l = lhs.getSplatValue<APFloat>();
      APFloat r = rhs.getSplatValue<APFloat>();
      auto result = FloatFolder()(l, r);
      return DenseElementsAttr::get(returnTy, result);
    }
  }

  return {};
}

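// Helpers that recognize splat constants equal to the additive identity (zero)
// or, for isSplatOne, the multiplicative identity (1.0 for floats, 1 << shift
// for integers as used by tosa.mul).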
static bool isSplatZero(Type elemType, DenseElementsAttr val) {
  if (llvm::isa<FloatType>(elemType))
    return val && val.isSplat() && val.getSplatValue<APFloat>().isZero();
  if (llvm::isa<IntegerType>(elemType))
    return val && val.isSplat() && val.getSplatValue<APInt>().isZero();
  return false;
}

static bool isSplatOne(Type elemType, DenseElementsAttr val, int64_t shift) {
  if (llvm::isa<FloatType>(elemType))
    return val && val.isSplat() &&
           val.getSplatValue<APFloat>().isExactlyValue(1.0);
  if (llvm::isa<IntegerType>(elemType)) {
    const int64_t shifted = 1LL << shift;
    return val && val.isSplat() &&
           val.getSplatValue<APInt>().getSExtValue() == shifted;
  }
  return false;
}

OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
  auto lhsTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
  auto rhsTy = llvm::dyn_cast<RankedTensorType>(getInput2().getType());
  auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
  if (!lhsTy || !rhsTy || !resultTy)
    return {};

  auto resultETy = resultTy.getElementType();
  auto lhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
  auto rhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());

  if (lhsTy == resultTy && isSplatZero(resultETy, rhsAttr))
    return getInput1();
  if (rhsTy == resultTy && isSplatZero(resultETy, lhsAttr))
    return getInput2();

  if (!lhsAttr || !rhsAttr)
    return {};

  return binaryFolder<std::plus<APInt>, std::plus<APFloat>>(lhsAttr, rhsAttr,
                                                            resultTy);
}

OpFoldResult DivOp::fold(FoldAdaptor adaptor) {
  auto lhsTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
  auto rhsTy = llvm::dyn_cast<RankedTensorType>(getInput2().getType());
  auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
  if (!lhsTy || !rhsTy || !resultTy)
    return {};
  if (lhsTy != rhsTy)
    return {};

  auto resultETy = resultTy.getElementType();
  auto lhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
  auto rhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
  if (lhsAttr && lhsAttr.isSplat()) {
    if (llvm::isa<IntegerType>(resultETy) &&
        lhsAttr.getSplatValue<APInt>().isZero())
      return lhsAttr;
  }

  if (rhsAttr && rhsAttr.isSplat()) {
    if (llvm::isa<IntegerType>(resultETy) &&
        rhsAttr.getSplatValue<APInt>().isOne())
      return getInput1();
  }

  if (rhsAttr && lhsAttr && rhsAttr.isSplat() && lhsAttr.isSplat()) {
    if (llvm::isa<IntegerType>(resultETy)) {
      APInt l = lhsAttr.getSplatValue<APInt>();
      APInt r = rhsAttr.getSplatValue<APInt>();
      // Do not fold a division by zero.
      if (r.isZero())
        return {};
      APInt result = l.sdiv(r);
      return DenseElementsAttr::get(resultTy, result);
    }
  }

  return {};
}

namespace {
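// Constant folds tosa.mul with splat operands; for integer element types the
// product is computed at twice the bitwidth and shifted right by the op's
// shift amount before truncating back.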
DenseElementsAttr mulBinaryFolder(DenseElementsAttr lhs, DenseElementsAttr rhs,
                                  RankedTensorType ty, int32_t shift) {
  if (rhs && lhs && rhs.isSplat() && lhs.isSplat()) {
    if (llvm::isa<IntegerType>(ty.getElementType())) {
      APInt l = lhs.getSplatValue<APInt>();
      APInt r = rhs.getSplatValue<APInt>();

      if (shift == 0) {
        return DenseElementsAttr::get(ty, l * r);
      }

      auto bitwidth = ty.getElementType().getIntOrFloatBitWidth();
      l = l.sext(bitwidth * 2);
      r = r.sext(bitwidth * 2);
      auto result = l * r;
      result.lshrInPlace(shift);
      result = result.trunc(bitwidth);
      return DenseElementsAttr::get(ty, result);
    }

    if (llvm::isa<FloatType>(ty.getElementType())) {
      APFloat l = lhs.getSplatValue<APFloat>();
      APFloat r = rhs.getSplatValue<APFloat>();
      APFloat result = l * r;
      return DenseElementsAttr::get(ty, result);
    }
  }

  return {};
}
} // namespace

OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
  auto lhs = getInput1();
  auto rhs = getInput2();
  auto lhsTy = llvm::dyn_cast<RankedTensorType>(lhs.getType());
  auto rhsTy = llvm::dyn_cast<RankedTensorType>(rhs.getType());
  auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
  if (!lhsTy || !rhsTy || !resultTy)
    return {};

  auto resultETy = resultTy.getElementType();
  auto lhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
  auto rhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());

  const int64_t shift = llvm::isa<IntegerType>(resultETy) ? getShift() : 0;
  if (rhsTy == resultTy) {
    if (isSplatZero(resultETy, lhsAttr))
      return lhsAttr.resizeSplat(resultTy);
    if (isSplatOne(resultETy, lhsAttr, shift))
      return rhs;
  }
  if (lhsTy == resultTy) {
    if (isSplatZero(resultETy, rhsAttr))
      return rhsAttr.resizeSplat(resultTy);
    if (isSplatOne(resultETy, rhsAttr, shift))
      return lhs;
  }

  return mulBinaryFolder(lhsAttr, rhsAttr, resultTy, getShift());
}

OpFoldResult SubOp::fold(FoldAdaptor adaptor) {
  auto lhsTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
  auto rhsTy = llvm::dyn_cast<RankedTensorType>(getInput2().getType());
  auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
  if (!lhsTy || !rhsTy || !resultTy)
    return {};

  auto resultETy = resultTy.getElementType();
  auto lhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
  auto rhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());

  if (lhsTy == resultTy && isSplatZero(resultETy, rhsAttr))
    return getInput1();

  if (!lhsAttr || !rhsAttr)
    return {};

  return binaryFolder<std::minus<APInt>, std::minus<APFloat>>(lhsAttr, rhsAttr,
                                                              resultTy);
}

namespace {
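// Functors used with binaryFolder to fold comparison ops; integer comparisons
// are signed and every result is returned as a 1-bit value.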
template <typename Cmp>
struct ComparisonFold {
  ComparisonFold() = default;
  APInt operator()(const APInt &l, const APInt &r) {
    return APInt(1, Cmp()(l, r));
  }

  APInt operator()(const APFloat &l, const APFloat &r) {
    return APInt(1, Cmp()(l, r));
  }
};

struct APIntFoldGreater {
  APIntFoldGreater() = default;
  APInt operator()(const APInt &l, const APInt &r) {
    return APInt(1, l.sgt(r));
  }
};

struct APIntFoldGreaterEqual {
  APIntFoldGreaterEqual() = default;
  APInt operator()(const APInt &l, const APInt &r) {
    return APInt(1, l.sge(r));
  }
};
} // namespace

OpFoldResult GreaterOp::fold(FoldAdaptor adaptor) {
  auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
  auto lhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
  auto rhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());

  if (!lhsAttr || !rhsAttr)
    return {};

  return binaryFolder<APIntFoldGreater, ComparisonFold<std::greater<APFloat>>>(
      lhsAttr, rhsAttr, resultTy);
}

OpFoldResult GreaterEqualOp::fold(FoldAdaptor adaptor) {
  auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
  auto lhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
  auto rhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());

  if (!lhsAttr || !rhsAttr)
    return {};

  return binaryFolder<APIntFoldGreaterEqual,
                      ComparisonFold<std::greater_equal<APFloat>>>(
      lhsAttr, rhsAttr, resultTy);
}

OpFoldResult EqualOp::fold(FoldAdaptor adaptor) {
  auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
  auto lhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
  auto rhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
  Value lhs = getInput1();
  Value rhs = getInput2();
  auto lhsTy = llvm::cast<ShapedType>(lhs.getType());

  // Comparing an integer value against itself is always true. We cannot do
  // the same for floating point because NaN compares unequal to itself.
  if (llvm::isa<IntegerType>(lhsTy.getElementType()) && resultTy &&
      resultTy.hasStaticShape() && lhs == rhs) {
    return DenseElementsAttr::get(resultTy, true);
  }

  if (!lhsAttr || !rhsAttr)
    return {};

  return binaryFolder<ComparisonFold<std::equal_to<APInt>>,
                      ComparisonFold<std::equal_to<APFloat>>>(lhsAttr, rhsAttr,
                                                              resultTy);
}

OpFoldResult CastOp::fold(FoldAdaptor adaptor) {
  if (getInput().getType() == getType())
    return getInput();

  auto operand = llvm::dyn_cast_if_present<ElementsAttr>(adaptor.getInput());
  if (!operand)
    return {};

  auto inTy = llvm::cast<ShapedType>(getInput().getType());
  auto outTy = llvm::cast<ShapedType>(getType());
  auto inETy = inTy.getElementType();
  auto outETy = outTy.getElementType();

  if (operand.isSplat()) {
    if (llvm::isa<FloatType>(inETy) && llvm::isa<FloatType>(outETy)) {
      bool overflow;
      auto splatVal = operand.getSplatValue<APFloat>();
      auto &semantics = llvm::cast<FloatType>(outETy).getFloatSemantics();
      splatVal.convert(semantics, llvm::RoundingMode::NearestTiesToEven,
                       &overflow);
      return SplatElementsAttr::get(outTy, splatVal);
    }

    if (llvm::isa<IntegerType>(inETy) && llvm::isa<FloatType>(outETy)) {
      auto unsign = llvm::cast<IntegerType>(inETy).isUnsignedInteger();
      APFloat splatVal(llvm::cast<FloatType>(outETy).getFloatSemantics());
      splatVal.convertFromAPInt(operand.getSplatValue<APInt>(), !unsign,
                                llvm::RoundingMode::NearestTiesToEven);
      return SplatElementsAttr::get(outTy, splatVal);
    }

    if (llvm::isa<FloatType>(inETy) && llvm::isa<IntegerType>(outETy)) {
      auto unsign = llvm::cast<IntegerType>(outETy).isUnsignedInteger();
      auto intVal = APSInt(
          llvm::cast<IntegerType>(outETy).getIntOrFloatBitWidth(), unsign);
      auto floatVal = operand.getSplatValue<APFloat>();
      bool exact;
      floatVal.convertToInteger(intVal, llvm::RoundingMode::TowardZero, &exact);
      return SplatElementsAttr::get(outTy, intVal);
    }

    if (llvm::isa<IntegerType>(inETy) && llvm::isa<IntegerType>(outETy)) {
      auto unsignIn = llvm::cast<IntegerType>(inETy).isUnsignedInteger();
      bool trunc =
          inETy.getIntOrFloatBitWidth() > outETy.getIntOrFloatBitWidth();
      auto intVal = operand.getSplatValue<APInt>();
      auto bitwidth = outETy.getIntOrFloatBitWidth();

      if (trunc) {
        intVal = intVal.trunc(bitwidth);
      } else if (unsignIn) {
        intVal = intVal.zext(bitwidth);
      } else {
        intVal = intVal.sext(bitwidth);
      }

      return SplatElementsAttr::get(outTy, intVal);
    }
  }

  return {};
}

OpFoldResult ConstOp::fold(FoldAdaptor adaptor) { return getValueAttr(); }

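// A reduction along an axis of size one returns the input unchanged, so each
// reduce op folds to its input in that case.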
#define REDUCE_FOLDER(OP)                                                      \
  OpFoldResult OP::fold(FoldAdaptor adaptor) {                                 \
    ShapedType inputTy = llvm::cast<ShapedType>(getInput().getType());         \
    if (!inputTy.hasRank())                                                    \
      return {};                                                               \
    if (inputTy.getDimSize(getAxis()) == 1)                                    \
      return getInput();                                                       \
    return {};                                                                 \
  }

REDUCE_FOLDER(ReduceAllOp)
REDUCE_FOLDER(ReduceAnyOp)
REDUCE_FOLDER(ReduceMaxOp)
REDUCE_FOLDER(ReduceMinOp)
REDUCE_FOLDER(ReduceProdOp)
REDUCE_FOLDER(ReduceSumOp)
#undef REDUCE_FOLDER

OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) {
  auto inputTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
  auto outputTy = llvm::dyn_cast<RankedTensorType>(getType());

  if (!inputTy || !outputTy)
    return {};

  if (inputTy == outputTy)
    return getInput1();

  auto operand =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
  if (operand && outputTy.hasStaticShape() && operand.isSplat()) {
    return SplatElementsAttr::get(outputTy, operand.getSplatValue<Attribute>());
  }

  return {};
}

OpFoldResult PadOp::fold(FoldAdaptor adaptor) {
  // If the padding is all zeros we can fold this operation away.
  if (adaptor.getPadding()) {
    auto densePad = llvm::cast<DenseElementsAttr>(adaptor.getPadding());
    if (densePad.isSplat() && densePad.getSplatValue<APInt>().isZero()) {
      return getInput1();
    }
  }

  return {};
}

// Fold away cases where a tosa.resize operation returns a copy
// of the input image.
OpFoldResult ResizeOp::fold(FoldAdaptor adaptor) {
  ArrayRef<int64_t> offset = getOffset();
  ArrayRef<int64_t> border = getBorder();
  ArrayRef<int64_t> scale = getScale();

  // Check unit scaling.
  if (scale[0] != scale[1] || scale[2] != scale[3]) {
    return {};
  }

  // There should be no offset.
  if (offset[0] != 0 || offset[1] != 0) {
    return {};
  }

  // There should be no border.
  if (border[0] != 0 || border[1] != 0) {
    return {};
  }

  auto input = getInput();
  auto inputTy = llvm::cast<RankedTensorType>(input.getType());
  auto resultTy = llvm::cast<RankedTensorType>(getType());
  if (inputTy != resultTy)
    return {};

  return input;
}

OpFoldResult ReverseOp::fold(FoldAdaptor adaptor) {
  auto operand = getInput();
  auto operandTy = llvm::cast<ShapedType>(operand.getType());
  auto axis = getAxis();
  auto operandAttr =
      llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getInput());
  if (operandAttr)
    return operandAttr;

  // If the dim-length is 1, tosa.reverse is a no-op.
  if (operandTy.hasRank() && operandTy.getDimSize(axis) == 1)
    return operand;

  return {};
}

OpFoldResult SliceOp::fold(FoldAdaptor adaptor) {
  auto inputTy = llvm::dyn_cast<RankedTensorType>(getInput().getType());
  auto outputTy = llvm::dyn_cast<RankedTensorType>(getType());

  if (!inputTy || !outputTy)
    return {};

  if (inputTy == outputTy && inputTy.hasStaticShape())
    return getInput();

  if (!adaptor.getInput())
    return {};

  // Cannot create an ElementsAttr from non-int/float/index types.
  if (!inputTy.getElementType().isIntOrIndexOrFloat() ||
      !outputTy.getElementType().isIntOrIndexOrFloat())
    return {};

  auto operand = llvm::cast<ElementsAttr>(adaptor.getInput());
  if (operand.isSplat() && outputTy.hasStaticShape()) {
    return SplatElementsAttr::get(outputTy, operand.getSplatValue<Attribute>());
  }

  if (inputTy.hasStaticShape() && outputTy.hasStaticShape() &&
      outputTy.getNumElements() == 1) {
    llvm::SmallVector<uint64_t> indices(getStart());
    auto value = operand.getValues<Attribute>()[indices];
    return SplatElementsAttr::get(outputTy, value);
  }

  return {};
}

OpFoldResult tosa::SelectOp::fold(FoldAdaptor adaptor) {
  if (getOnTrue() == getOnFalse())
    return getOnTrue();

  auto predicate =
      llvm::dyn_cast_if_present<DenseIntElementsAttr>(adaptor.getPred());
  if (!predicate)
    return {};

  if (!predicate.isSplat())
    return {};
  return predicate.getSplatValue<APInt>().getBoolValue() ? getOnTrue()
                                                         : getOnFalse();
}

OpFoldResult TileOp::fold(FoldAdaptor adaptor) {
  bool allOnes = llvm::all_of(getMultiples(), [](int64_t v) { return v == 1; });
  if (allOnes && getInput1().getType() == getType())
    return getInput1();
  return {};
}

OpFoldResult TransposeOp::fold(FoldAdaptor adaptor) {
  auto inputTy = llvm::cast<ShapedType>(getInput1().getType());
  auto resultTy = llvm::cast<ShapedType>(getType());

  // Transposing splat values just means reshaping.
  if (auto input =
          llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1())) {
    if (input.isSplat() && resultTy.hasStaticShape() &&
        inputTy.getElementType() == resultTy.getElementType())
      return input.reshape(resultTy);
  }

  // Transpose does not change the input type.
  if (getInput1().getType() != getType())
    return {};

  // Transpose is not the identity transpose.
  SmallVector<int64_t> perms;
  if (getConstantPerms(perms).failed())
    return {};

  if (!llvm::equal(llvm::seq<int64_t>(0, perms.size()), perms))
    return {};

  return getInput1();
}

OpFoldResult tosa::LogOp::fold(FoldAdaptor adaptor) {
  auto input = getInput1();
  // Element-wise log(exp(x)) = x
  if (auto op = input.getDefiningOp<tosa::ExpOp>()) {
    return op.getInput1();
  }

  return {};
}

OpFoldResult tosa::ExpOp::fold(FoldAdaptor adaptor) {
  auto input = getInput1();
  // Element-wise exp(log(x)) = x
  if (auto op = input.getDefiningOp<tosa::LogOp>()) {
    return op.getInput1();
  }

  return {};
}

OpFoldResult tosa::NegateOp::fold(FoldAdaptor adaptor) {
  auto input = getInput1();
  // Element-wise negate(negate(x)) = x
  if (auto op = input.getDefiningOp<tosa::NegateOp>()) {
    return op.getInput1();
  }

  return {};
}

OpFoldResult tosa::AbsOp::fold(FoldAdaptor adaptor) {
  auto input = getInput1();
  // Element-wise abs(abs(x)) = abs(x)
  if (auto op = input.getDefiningOp<tosa::AbsOp>()) {
    return input;
  }

  return {};
}

OpFoldResult ConcatOp::fold(FoldAdaptor adaptor) {
  // Fold consecutive concats on the same axis into a single op.
  // Keep track of the operands so we are able to construct a new concat
  // later. Conservatively assume that we double the number of operands when
  // folding.
  SmallVector<Value, 8> concatOperands;
  concatOperands.reserve(2 * getNumOperands());

  // Find all operands that are foldable concats.
  bool foundFoldableConcat = false;
  for (Value operand : getOperands()) {
    concatOperands.emplace_back(operand);

    auto producer = dyn_cast_or_null<ConcatOp>(operand.getDefiningOp());
    if (!producer)
      continue;

    // Not foldable if the axes are not the same.
    if (getAxis() != producer.getAxis())
      continue;

    // Replace the original operand with all incoming operands.
    foundFoldableConcat = true;
    concatOperands.pop_back();
    llvm::append_range(concatOperands, producer->getOperands());
  }

  if (!foundFoldableConcat)
    return {};

  getOperation()->setOperands(concatOperands);
  return getResult();
}