217 lines
8.1 KiB
C++
217 lines
8.1 KiB
C++
//===- RewriteAsConstant.cpp - Patterns to rewrite tensor ops as constants ===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
//
|
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
|
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
|
|
#include "mlir/Dialect/Utils/IndexingUtils.h"
|
|
#include "mlir/IR/Matchers.h"
|
|
#include "mlir/IR/PatternMatch.h"
|
|
|
|
#include "llvm/ADT/TypeSwitch.h"
|
|
|
|
using namespace mlir;
|
|
using namespace mlir::tensor;
|
|
|
|
namespace {
|
|
|
|
/// Rewrite tensor.generate with arith.constant if the yielded value is a
|
|
/// constant and the tensor type is static.
|
|
struct GenerateToConstant : public OpRewritePattern<GenerateOp> {
|
|
using OpRewritePattern<GenerateOp>::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(GenerateOp generateOp,
|
|
PatternRewriter &rewriter) const override {
|
|
auto tensorType =
|
|
llvm::cast<RankedTensorType>(generateOp.getResult().getType());
|
|
if (!tensorType.hasStaticShape())
|
|
return failure();
|
|
auto terminatorOp =
|
|
cast<tensor::YieldOp>(generateOp.getBody().front().getTerminator());
|
|
Attribute attr;
|
|
if (!matchPattern(terminatorOp.getValue(), m_Constant(&attr)))
|
|
return failure();
|
|
Operation *constantOp =
|
|
rewriter.getContext()
|
|
->getLoadedDialect<TensorDialect>()
|
|
->materializeConstant(rewriter,
|
|
DenseElementsAttr::get(tensorType, attr),
|
|
tensorType, generateOp->getLoc());
|
|
if (!constantOp)
|
|
return failure();
|
|
rewriter.replaceOp(generateOp, constantOp->getResults());
|
|
return success();
|
|
}
|
|
};
|
|
|
|
/// Transform a linear index from one indexing space to another given:
|
|
///
|
|
/// - the shape of the source indexing space,
|
|
/// - the strides of the target indexing space,
|
|
/// - a linear index into the source indexing space.
|
|
///
|
|
/// This function is logically a sequence of linearize/delinearize over
|
|
/// different bases but avoids allocating intermediate SmallVectors.
|
|
int64_t transformIndexSpace(ArrayRef<int64_t> inputShape,
|
|
ArrayRef<int64_t> outputStrides,
|
|
int64_t srcLinearIndex) {
|
|
assert(inputShape.size() == outputStrides.size());
|
|
|
|
int64_t dstLinearIndex = 0;
|
|
|
|
for (int64_t dim = inputShape.size() - 1; dim >= 0; --dim) {
|
|
// Compute the index into the current dimension of the source tensor.
|
|
// `quotient` is the remaining linear index after accounting for the
|
|
// current dimension.
|
|
//
|
|
// `remainder` is the index into the source tensor for the current
|
|
// dimension.
|
|
auto [quotient, remainder] = std::div(srcLinearIndex, inputShape[dim]);
|
|
|
|
srcLinearIndex = quotient;
|
|
|
|
// Add the contribution of the current dimension to the output using the
|
|
// permutation map.
|
|
dstLinearIndex += outputStrides[dim] * remainder;
|
|
}
|
|
|
|
return dstLinearIndex;
|
|
}
|
|
|
|
template <typename ElemType, typename AttrType>
|
|
Value constantFoldPadOp(PatternRewriter &rewriter, Location loc,
|
|
DenseElementsAttr input, AttrType padValue,
|
|
ArrayRef<int64_t> padLow, ArrayRef<int64_t> padHigh) {
|
|
auto inputValues = input.tryGetValues<ElemType>();
|
|
if (failed(inputValues))
|
|
return nullptr;
|
|
|
|
auto oldShape = input.getType().getShape();
|
|
|
|
// Compute the output shape of the new value.
|
|
auto newShape =
|
|
llvm::map_to_vector(llvm::zip(oldShape, padLow, padHigh),
|
|
[](std::tuple<int64_t, int64_t, int64_t> pack) {
|
|
auto [old, low, high] = pack;
|
|
return old + low + high;
|
|
});
|
|
|
|
int64_t outputSize = computeProduct(newShape);
|
|
|
|
// Fully initialize the vector with the padding value.
|
|
// The non-padded area will then be copied.
|
|
SmallVector<ElemType> values(outputSize, padValue.getValue());
|
|
|
|
// Strides for input and output are used to transform between the indexing
|
|
// space of the input and output tensors.
|
|
SmallVector<int64_t> outputStrides = computeStrides(newShape);
|
|
|
|
// The contribution of the low padding to the offset in the output tensor.
|
|
// This is the starting position of the source tensor within the padding
|
|
// tensor.
|
|
int64_t startingOffset = linearize(padLow, outputStrides);
|
|
|
|
// Copy values from the input tensor to the corresponding sub-region
|
|
// of the output tensor.
|
|
for (auto [inputIndex, inputValue] : llvm::enumerate(*inputValues)) {
|
|
auto outputIndex = transformIndexSpace(oldShape, outputStrides, inputIndex);
|
|
values[outputIndex + startingOffset] = inputValue;
|
|
}
|
|
|
|
// Create an attribute for the folded value.
|
|
auto newType = input.getType().clone(newShape);
|
|
auto newAttr = DenseElementsAttr::get(newType, values);
|
|
|
|
Operation *constantOp =
|
|
rewriter.getContext()
|
|
->getLoadedDialect<TensorDialect>()
|
|
->materializeConstant(rewriter, newAttr, newType, loc);
|
|
|
|
return constantOp ? constantOp->getResult(0) : nullptr;
|
|
}
|
|
|
|
/// Folds a `tensor.pad` whose source, padding value and low/high padding
/// amounts are all constants into a single constant.
///
/// The fold is only attempted when the op is not marked `nofold` and the
/// user-provided control function accepts the pad's source operand. If the
/// folded constant's type differs from the pad's result type, a `tensor.cast`
/// to the result type is inserted.
struct PadOpToConstant final : public OpRewritePattern<PadOp> {

  PadOpToConstant(MLIRContext *context, const ControlFoldFn &controlFn,
                  PatternBenefit benefit = 1)
      : OpRewritePattern<PadOp>(context, benefit), controlFn{controlFn} {}

  LogicalResult matchAndRewrite(PadOp padTensorOp,
                                PatternRewriter &rewriter) const override {
    if (padTensorOp.getNofold())
      return rewriter.notifyMatchFailure(
          padTensorOp, "refusing to fold nofold pad operation");

    // The padded source must be a constant tensor.
    TypedValue<RankedTensorType> source = padTensorOp.getSource();
    DenseElementsAttr sourceAttr = nullptr;
    if (!matchPattern(source, m_Constant(&sourceAttr)))
      return failure();

    // The padding value must exist and be a constant as well.
    Value padScalar = padTensorOp.getConstantPaddingValue();
    Attribute padScalarAttr = nullptr;
    const bool hasConstantPadValue =
        padScalar && matchPattern(padScalar, m_Constant(&padScalarAttr));
    if (!hasConstantPadValue)
      return rewriter.notifyMatchFailure(padTensorOp,
                                         "unable to get constant value");

    // Both the low and the high padding amounts must be known constants.
    auto lowAmounts = getConstantIntValues(padTensorOp.getMixedLowPad());
    auto highAmounts = getConstantIntValues(padTensorOp.getMixedHighPad());
    if (!(lowAmounts && highAmounts))
      return rewriter.notifyMatchFailure(padTensorOp,
                                         "unable to extract constant padding");

    // All structural requirements hold; let the control function decide
    // whether this particular operand should be folded.
    if (!controlFn(&padTensorOp.getSourceMutable()))
      return rewriter.notifyMatchFailure(padTensorOp,
                                         "not folding due to cost function");

    Location loc = padTensorOp.getLoc();

    // Only float and integer padding values are handled; anything else
    // leaves `folded` null.
    Value folded;
    if (auto floatPad = llvm::dyn_cast<FloatAttr>(padScalarAttr)) {
      folded = constantFoldPadOp<llvm::APFloat>(
          rewriter, loc, sourceAttr, floatPad, *lowAmounts, *highAmounts);
    } else if (auto integerPad = llvm::dyn_cast<IntegerAttr>(padScalarAttr)) {
      folded = constantFoldPadOp<llvm::APInt>(
          rewriter, loc, sourceAttr, integerPad, *lowAmounts, *highAmounts);
    }

    if (!folded)
      return rewriter.notifyMatchFailure(padTensorOp,
                                         "tensor type not supported");

    // Reconcile the folded constant's type with the pad's declared result
    // type when they differ.
    RankedTensorType resultType = padTensorOp.getResult().getType();
    if (folded.getType() != resultType)
      folded = tensor::CastOp::create(rewriter, loc, resultType, folded);

    rewriter.replaceOp(padTensorOp, folded);
    return success();
  }

private:
  /// Callback consulted with the pad's source operand before folding; a
  /// false result vetoes the fold.
  ControlFoldFn controlFn;
};
|
|
|
|
} // namespace
|
|
|
|
/// Adds the patterns defined in this file to `patterns`: `GenerateToConstant`
/// and `PadOpToConstant`, the latter being gated by `controlFn`.
void mlir::tensor::populateRewriteAsConstantPatterns(
    RewritePatternSet &patterns, const ControlFoldFn &controlFn) {
  MLIRContext *context = patterns.getContext();
  patterns.add<GenerateToConstant>(context);
  patterns.add<PadOpToConstant>(context, controlFn);
}
|