Tres Popp 68f58812e3 [mlir] Move casting calls from methods to function calls
The MLIR classes Type/Attribute/Operation/Op/Value support
cast/dyn_cast/isa/dyn_cast_or_null functionality through llvm's doCast
machinery, in addition to defining methods with the same names.
This change migrates uses of those methods to the corresponding function
calls, which have been decided on as the more consistent style.

Note that some classes, such as AffineExpr, still only define the methods
directly; adding support for the free-function cast/isa calls to them is
not included in this work.
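
For illustration, the migration has this shape (a minimal sketch; `attr`
and `process` are placeholder names, not code from this patch):

```
// Deprecated member-function style:
if (auto intAttr = attr.dyn_cast<IntegerAttr>())
  process(intAttr.getInt());

// Preferred free-function style:
if (auto intAttr = llvm::dyn_cast<IntegerAttr>(attr))
  process(intAttr.getInt());
```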

Context:
- https://mlir.llvm.org/deprecation/ at "Use the free function variants
  for dyn_cast/cast/isa/…"
- Original discussion at https://discourse.llvm.org/t/preferred-casting-style-going-forward/68443

Implementation:
This patch updates all remaining uses of the deprecated functionality in
mlir/. This was done with clang-tidy as described below, plus manual
modifications to GPUBase.td and OpenMPOpsInterfaces.td.

The steps are listed separately here, since git strips comment lines from
commit messages:
0. Retrieve the change from the following branch comparison, which adds an
   extra check to clang-tidy:
   main...tpopp:llvm-project:tidy-cast-check
1. Build clang-tidy
2. Run clang-tidy over the entire codebase with all checks disabled except
   the newly added one, including on all header files.
3. Delete the .inc files that were also modified, so the next build
   regenerates them in a pristine state.

```
ninja -C $BUILD_DIR clang-tidy

run-clang-tidy -clang-tidy-binary=$BUILD_DIR/bin/clang-tidy -checks='-*,misc-cast-functions'\
               -header-filter=mlir/ mlir/* -fix

rm -rf $BUILD_DIR/tools/mlir/**/*.inc
```

Differential Revision: https://reviews.llvm.org/D151542
2023-05-26 10:29:55 +02:00

//===- Shape.cpp - MLIR Shape Operations ----------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include <utility>

#include "mlir/Dialect/Shape/IR/Shape.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/CommonFolders.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Traits.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/FunctionImplementation.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Transforms/InliningUtils.h"
#include "llvm/ADT/SetOperations.h"
#include "llvm/ADT/SmallString.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/raw_ostream.h"

using namespace mlir;
using namespace mlir::shape;

#include "mlir/Dialect/Shape/IR/ShapeOpsDialect.cpp.inc"

namespace {
#include "ShapeCanonicalization.inc"
} // namespace

RankedTensorType shape::getExtentTensorType(MLIRContext *ctx, int64_t rank) {
  return RankedTensorType::get({rank}, IndexType::get(ctx));
}

bool shape::isExtentTensorType(Type type) {
  auto ranked = llvm::dyn_cast<RankedTensorType>(type);
  return ranked && ranked.getRank() == 1 && ranked.getElementType().isIndex();
}

LogicalResult shape::getShapeVec(Value input,
                                 SmallVectorImpl<int64_t> &shapeValues) {
  if (auto inputOp = input.getDefiningOp<ShapeOfOp>()) {
    auto type = llvm::cast<ShapedType>(inputOp.getArg().getType());
    if (!type.hasRank())
      return failure();
    llvm::append_range(shapeValues, type.getShape());
    return success();
  }
  DenseIntElementsAttr attr;
  if (matchPattern(input, m_Constant(&attr))) {
    llvm::append_range(shapeValues, attr.getValues<int64_t>());
    return success();
  }
  return failure();
}

static bool isErrorPropagationPossible(TypeRange operandTypes) {
  return llvm::any_of(operandTypes, [](Type ty) {
    return llvm::isa<SizeType, ShapeType, ValueShapeType>(ty);
  });
}

static LogicalResult verifySizeOrIndexOp(Operation *op) {
  assert(op != nullptr && op->getNumResults() == 1);
  Type resultTy = op->getResultTypes().front();
  if (isErrorPropagationPossible(op->getOperandTypes())) {
    if (!llvm::isa<SizeType>(resultTy))
      return op->emitOpError()
             << "if at least one of the operands can hold error values then "
                "the result must be of type `size` to propagate them";
  }
  return success();
}

static LogicalResult verifyShapeOrExtentTensorOp(Operation *op) {
  assert(op != nullptr && op->getNumResults() == 1);
  Type resultTy = op->getResultTypes().front();
  if (isErrorPropagationPossible(op->getOperandTypes())) {
    if (!llvm::isa<ShapeType>(resultTy))
      return op->emitOpError()
             << "if at least one of the operands can hold error values then "
                "the result must be of type `shape` to propagate them";
  }
  return success();
}

template <typename... Ty>
static bool eachHasOnlyOneOfTypes(TypeRange typeRange) {
  return typeRange.size() == 1 && llvm::isa<Ty...>(typeRange.front());
}

template <typename... Ty, typename... ranges>
static bool eachHasOnlyOneOfTypes(TypeRange l, ranges... rs) {
  return eachHasOnlyOneOfTypes<Ty...>(l) && eachHasOnlyOneOfTypes<Ty...>(rs...);
}

//===----------------------------------------------------------------------===//
// InlinerInterface
//===----------------------------------------------------------------------===//

namespace {
/// This class defines the interface for inlining shape dialect ops.
struct ShapeInlinerInterface : public DialectInlinerInterface {
  using DialectInlinerInterface::DialectInlinerInterface;

  // Returns true if the given region 'src' can be inlined into the region
  // 'dest' that is attached to an operation registered to the current dialect.
  bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
                       IRMapping &) const final {
    return true;
  }

  // Returns true if the given operation 'op', that is registered to this
  // dialect, can be inlined into the region 'dest' that is attached to an
  // operation registered to the current dialect.
  bool isLegalToInline(Operation *op, Region *dest, bool wouldBeCloned,
                       IRMapping &) const final {
    return true;
  }
};
} // namespace

void ShapeDialect::initialize() {
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc"
      >();
  addTypes<
#define GET_TYPEDEF_LIST
#include "mlir/Dialect/Shape/IR/ShapeOpsTypes.cpp.inc"
      >();
  addInterfaces<ShapeInlinerInterface>();
  // Allow unknown operations during prototyping and testing. As the dialect
  // is still evolving, this makes it simple to start with unregistered ops
  // and to try different variants before actually defining an op.
  allowUnknownOperations();
}

Operation *ShapeDialect::materializeConstant(OpBuilder &builder,
                                             Attribute value, Type type,
                                             Location loc) {
  if (llvm::isa<ShapeType>(type) || isExtentTensorType(type))
    return builder.create<ConstShapeOp>(
        loc, type, llvm::cast<DenseIntElementsAttr>(value));
  if (llvm::isa<SizeType>(type))
    return builder.create<ConstSizeOp>(loc, type,
                                       llvm::cast<IntegerAttr>(value));
  if (llvm::isa<WitnessType>(type))
    return builder.create<ConstWitnessOp>(loc, type,
                                          llvm::cast<BoolAttr>(value));
  return arith::ConstantOp::materialize(builder, value, type, loc);
}

LogicalResult ShapeDialect::verifyOperationAttribute(Operation *op,
                                                     NamedAttribute attribute) {
  // Verify shape.lib attribute.
  if (attribute.getName() == "shape.lib") {
    if (!op->hasTrait<OpTrait::SymbolTable>())
      return op->emitError(
          "shape.lib attribute may only be on op implementing SymbolTable");
    if (auto symbolRef = llvm::dyn_cast<SymbolRefAttr>(attribute.getValue())) {
      auto *symbol = SymbolTable::lookupSymbolIn(op, symbolRef);
      if (!symbol)
        return op->emitError("shape function library ")
               << symbolRef << " not found";
      return isa<shape::FunctionLibraryOp>(symbol)
                 ? success()
                 : op->emitError()
                       << symbolRef << " required to be shape function library";
    }
    if (auto arr = llvm::dyn_cast<ArrayAttr>(attribute.getValue())) {
      // Verify all entries are function libraries and mappings in libraries
      // refer to unique ops.
      DenseSet<StringAttr> key;
      for (auto it : arr) {
        if (!llvm::isa<SymbolRefAttr>(it))
          return op->emitError(
              "only SymbolRefAttr allowed in shape.lib attribute array");
        auto shapeFnLib = dyn_cast<shape::FunctionLibraryOp>(
            SymbolTable::lookupSymbolIn(op, llvm::cast<SymbolRefAttr>(it)));
        if (!shapeFnLib)
          return op->emitError()
                 << it << " does not refer to FunctionLibraryOp";
        for (auto mapping : shapeFnLib.getMapping()) {
          if (!key.insert(mapping.getName()).second) {
            return op->emitError("only one op to shape mapping allowed, found "
                                 "multiple for `")
                   << mapping.getName() << "`";
          }
        }
      }
      return success();
    }
    return op->emitError("only SymbolRefAttr or array of SymbolRefAttrs "
                         "allowed as shape.lib attribute");
  }
  return success();
}

//===----------------------------------------------------------------------===//
// AnyOp
//===----------------------------------------------------------------------===//

// TODO: Canonicalization should be implemented for shapes that can be
// determined through mixtures of the known dimensions of the inputs.
OpFoldResult AnyOp::fold(FoldAdaptor adaptor) {
  // Only the last operand is checked because AnyOp is commutative.
  if (adaptor.getInputs().back())
    return adaptor.getInputs().back();
  return nullptr;
}

//===----------------------------------------------------------------------===//
// AssumingOp
//===----------------------------------------------------------------------===//

ParseResult AssumingOp::parse(OpAsmParser &parser, OperationState &result) {
  result.regions.reserve(1);
  Region *doRegion = result.addRegion();
  auto &builder = parser.getBuilder();
  OpAsmParser::UnresolvedOperand cond;
  if (parser.parseOperand(cond) ||
      parser.resolveOperand(cond, builder.getType<WitnessType>(),
                            result.operands))
    return failure();
  // Parse optional results type list.
  if (parser.parseOptionalArrowTypeList(result.types))
    return failure();
  // Parse the region and add a terminator if elided.
  if (parser.parseRegion(*doRegion, /*arguments=*/{}, /*argTypes=*/{}))
    return failure();
  AssumingOp::ensureTerminator(*doRegion, parser.getBuilder(), result.location);
  // Parse the optional attribute list.
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();
  return success();
}

void AssumingOp::print(OpAsmPrinter &p) {
  bool yieldsResults = !getResults().empty();
  p << " " << getWitness();
  if (yieldsResults)
    p << " -> (" << getResultTypes() << ")";
  p << ' ';
  p.printRegion(getDoRegion(),
                /*printEntryBlockArgs=*/false,
                /*printBlockTerminators=*/yieldsResults);
  p.printOptionalAttrDict((*this)->getAttrs());
}

namespace {
// Removes AssumingOp with a passing witness and inlines the region.
struct AssumingWithTrue : public OpRewritePattern<AssumingOp> {
  using OpRewritePattern<AssumingOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AssumingOp op,
                                PatternRewriter &rewriter) const override {
    auto witness = op.getWitness().getDefiningOp<ConstWitnessOp>();
    if (!witness || !witness.getPassingAttr())
      return failure();
    AssumingOp::inlineRegionIntoParent(op, rewriter);
    return success();
  }
};

struct AssumingOpRemoveUnusedResults : public OpRewritePattern<AssumingOp> {
  using OpRewritePattern<AssumingOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AssumingOp op,
                                PatternRewriter &rewriter) const override {
    Block *body = op.getBody();
    auto yieldOp = llvm::cast<AssumingYieldOp>(body->getTerminator());
    // Find used values.
    SmallVector<Value, 4> newYieldOperands;
    for (auto [opResult, yieldOperand] :
         llvm::zip(op.getResults(), yieldOp.getOperands())) {
      if (!opResult.getUses().empty()) {
        newYieldOperands.push_back(yieldOperand);
      }
    }
    // Rewrite only if redundant results exist.
    if (newYieldOperands.size() == yieldOp->getNumOperands())
      return failure();
    // Replace yield op in the old assuming op's body and move the entire
    // region to the new assuming op.
    rewriter.setInsertionPointToEnd(body);
    auto newYieldOp =
        rewriter.replaceOpWithNewOp<AssumingYieldOp>(yieldOp, newYieldOperands);
    rewriter.setInsertionPoint(op);
    auto newOp = rewriter.create<AssumingOp>(
        op.getLoc(), newYieldOp->getOperandTypes(), op.getWitness());
    newOp.getDoRegion().takeBody(op.getDoRegion());
    // Use the new results to replace the previously used ones.
    SmallVector<Value, 4> replacementValues;
    auto src = newOp.getResults().begin();
    for (auto it : op.getResults()) {
      if (it.getUses().empty())
        replacementValues.push_back(nullptr);
      else
        replacementValues.push_back(*src++);
    }
    rewriter.replaceOp(op, replacementValues);
    return success();
  }
};
} // namespace

void AssumingOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                             MLIRContext *context) {
  patterns.add<AssumingOpRemoveUnusedResults, AssumingWithTrue>(context);
}

// See RegionBranchOpInterface in Interfaces/ControlFlowInterfaces.td
void AssumingOp::getSuccessorRegions(
    std::optional<unsigned> index, ArrayRef<Attribute> operands,
    SmallVectorImpl<RegionSuccessor> &regions) {
  // AssumingOp has unconditional control flow into the region and back to the
  // parent, so return the correct RegionSuccessor purely based on the index
  // being None or 0.
  if (index) {
    regions.push_back(RegionSuccessor(getResults()));
    return;
  }
  regions.push_back(RegionSuccessor(&getDoRegion()));
}

void AssumingOp::inlineRegionIntoParent(AssumingOp &op,
                                        PatternRewriter &rewriter) {
  auto *blockBeforeAssuming = rewriter.getInsertionBlock();
  auto *assumingBlock = op.getBody();
  auto initPosition = rewriter.getInsertionPoint();
  auto *blockAfterAssuming =
      rewriter.splitBlock(blockBeforeAssuming, initPosition);
  // Remove the AssumingOp and AssumingYieldOp.
  auto &yieldOp = assumingBlock->back();
  rewriter.inlineRegionBefore(op.getDoRegion(), blockAfterAssuming);
  rewriter.replaceOp(op, yieldOp.getOperands());
  rewriter.eraseOp(&yieldOp);
  // Merge blocks together as there was no branching behavior from the
  // AssumingOp.
  rewriter.mergeBlocks(assumingBlock, blockBeforeAssuming);
  rewriter.mergeBlocks(blockAfterAssuming, blockBeforeAssuming);
}

void AssumingOp::build(
    OpBuilder &builder, OperationState &result, Value witness,
    function_ref<SmallVector<Value, 2>(OpBuilder &, Location)> bodyBuilder) {
  result.addOperands(witness);
  Region *bodyRegion = result.addRegion();
  bodyRegion->push_back(new Block);
  Block &bodyBlock = bodyRegion->front();
  // Build body.
  OpBuilder::InsertionGuard guard(builder);
  builder.setInsertionPointToStart(&bodyBlock);
  SmallVector<Value, 2> yieldValues = bodyBuilder(builder, result.location);
  builder.create<AssumingYieldOp>(result.location, yieldValues);
  SmallVector<Type, 2> assumingTypes;
  for (Value v : yieldValues)
    assumingTypes.push_back(v.getType());
  result.addTypes(assumingTypes);
}

//===----------------------------------------------------------------------===//
// AddOp
//===----------------------------------------------------------------------===//

LogicalResult mlir::shape::AddOp::inferReturnTypes(
    MLIRContext *context, std::optional<Location> location, ValueRange operands,
    DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  if (llvm::isa<SizeType>(operands[0].getType()) ||
      llvm::isa<SizeType>(operands[1].getType()))
    inferredReturnTypes.assign({SizeType::get(context)});
  else
    inferredReturnTypes.assign({IndexType::get(context)});
  return success();
}

bool mlir::shape::AddOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
  // SizeType is compatible with IndexType.
  return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
}

OpFoldResult mlir::shape::AddOp::fold(FoldAdaptor adaptor) {
  // add(x, 0) -> x
  if (matchPattern(getRhs(), m_Zero()))
    return getLhs();
  return constFoldBinaryOp<IntegerAttr>(
      adaptor.getOperands(),
      [](APInt a, const APInt &b) { return std::move(a) + b; });
}

LogicalResult shape::AddOp::verify() { return verifySizeOrIndexOp(*this); }

//===----------------------------------------------------------------------===//
// AssumingAllOp
//===----------------------------------------------------------------------===//

namespace {
// Merge multiple `shape.assuming_all` operations together.
//
// %0 = shape.assuming_all %w0, %w1
// %1 = shape.assuming_all %w2, %0
//
// to:
//
// %0 = shape.assuming_all %w0, %w1, %w2
struct MergeAssumingAllOps : public OpRewritePattern<AssumingAllOp> {
  using OpRewritePattern<AssumingAllOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AssumingAllOp op,
                                PatternRewriter &rewriter) const override {
    SmallVector<Value> operands;
    for (Value operand : op.getInputs()) {
      if (auto assumeAll = operand.getDefiningOp<AssumingAllOp>())
        operands.append(assumeAll.operand_begin(), assumeAll->operand_end());
      else
        operands.push_back(operand);
    }
    // We didn't find any other `assuming_all` ops to merge with.
    if (operands.size() == op.getNumOperands())
      return failure();
    // Replace with a new `assuming_all` operation with merged constraints.
    rewriter.replaceOpWithNewOp<AssumingAllOp>(op, operands);
    return success();
  }
};

// Eliminate `cstr_broadcastable` operands from an `assuming_all` operation
// that are subsumed by others.
//
// %0 = shape.cstr_broadcastable %shape0, %shape1
// %1 = shape.cstr_broadcastable %shape0, %shape1, %shape2
//
// %2 = shape.cstr_broadcastable %shape3, %shape4
// %3 = shape.cstr_broadcastable %shape3, %shape4, %shape5
//
// %4 = shape.assuming_all %0, %1, %2, %3
//
// to:
//
// %0 = shape.cstr_broadcastable %shape0, %shape1, %shape2
// %1 = shape.cstr_broadcastable %shape3, %shape4, %shape5
// %2 = shape.assuming_all %0, %1
//
// In this example, if shapes [0, 1, 2] are broadcastable, then it means that
// shapes [0, 1] are broadcastable too, and can be removed from the list of
// constraints. If shapes [0, 1, 2] are not broadcastable, then it doesn't
// matter if shapes [0, 1] are broadcastable (same for shapes [3, 4, 5]).
struct AssumingAllOfCstrBroadcastable : public OpRewritePattern<AssumingAllOp> {
  using OpRewritePattern<AssumingAllOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AssumingAllOp op,
                                PatternRewriter &rewriter) const override {
    // Collect all `CstrBroadcastableOp` operands first.
    SetVector<CstrBroadcastableOp> operands;
    for (Value operand : op.getInputs()) {
      // TODO: Apply this optimization if some of the witnesses are not
      // produced by the `cstr_broadcastable`.
      auto broadcastable = operand.getDefiningOp<CstrBroadcastableOp>();
      if (!broadcastable)
        return failure();
      operands.insert(broadcastable);
    }
    // Skip trivial `assuming_all` operations.
    if (operands.size() <= 1)
      return failure();
    // Collect shapes checked by `cstr_broadcastable` operands.
    SmallVector<std::pair<CstrBroadcastableOp, DenseSet<Value>>> shapes;
    for (auto cstr : operands) {
      DenseSet<Value> shapesSet(cstr->operand_begin(), cstr->operand_end());
      shapes.emplace_back(cstr, std::move(shapesSet));
    }
    // Sort by the number of shape operands (larger to smaller).
    llvm::sort(shapes, [](auto a, auto b) {
      return a.first.getNumOperands() > b.first.getNumOperands();
    });
    // We start from the `cstr_broadcastable` operations with the largest
    // number of shape operands, and remove redundant `cstr_broadcastable`
    // operations. We do this until we find a set of `cstr_broadcastable`
    // operations with non-overlapping constraints.
    SmallVector<CstrBroadcastableOp> markedForErase;
    for (unsigned i = 0; i < shapes.size(); ++i) {
      auto isSubset = [&](auto pair) {
        return llvm::set_is_subset(pair.second, shapes[i].second);
      };
      // Mark redundant `cstr_broadcastable` operations to be erased.
      auto *it = std::remove_if(shapes.begin() + i + 1, shapes.end(), isSubset);
      for (auto *it0 = it; it0 < shapes.end(); ++it0)
        markedForErase.push_back(it0->first);
      shapes.erase(it, shapes.end());
    }
    // We didn't find any operands that could be removed.
    if (markedForErase.empty())
      return failure();
    // Collect the remaining non-overlapping `cstr_broadcastable` constraints.
    SmallVector<Value> uniqueConstraints;
    for (auto &shape : shapes)
      uniqueConstraints.push_back(shape.first.getResult());
    // Replace with a new `assuming_all` operation ...
    rewriter.replaceOpWithNewOp<AssumingAllOp>(op, uniqueConstraints);
    // ... and maybe erase `cstr_broadcastable` ops without uses.
    for (auto &op : markedForErase)
      if (op->use_empty())
        rewriter.eraseOp(op);
    return success();
  }
};

struct AssumingAllToCstrEqCanonicalization
    : public OpRewritePattern<AssumingAllOp> {
  using OpRewritePattern<AssumingAllOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AssumingAllOp op,
                                PatternRewriter &rewriter) const override {
    SmallVector<Value, 8> shapes;
    for (Value w : op.getInputs()) {
      auto cstrEqOp = w.getDefiningOp<CstrEqOp>();
      if (!cstrEqOp)
        return failure();
      bool disjointShapes = llvm::none_of(cstrEqOp.getShapes(), [&](Value s) {
        return llvm::is_contained(shapes, s);
      });
      if (!shapes.empty() && !cstrEqOp.getShapes().empty() && disjointShapes)
        return failure();
      shapes.append(cstrEqOp.getShapes().begin(), cstrEqOp.getShapes().end());
    }
    rewriter.replaceOpWithNewOp<CstrEqOp>(op, shapes);
    return success();
  }
};

template <typename OpTy>
struct RemoveDuplicateOperandsPattern : public OpRewritePattern<OpTy> {
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy op,
                                PatternRewriter &rewriter) const override {
    // Find unique operands.
    SetVector<Value> unique(op.operand_begin(), op.operand_end());
    // Reduce op to equivalent with unique operands.
    if (unique.size() < op.getNumOperands()) {
      rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(),
                                        unique.takeVector(), op->getAttrs());
      return success();
    }
    return failure();
  }
};
} // namespace

void AssumingAllOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                MLIRContext *context) {
  patterns
      .add<MergeAssumingAllOps, AssumingAllOneOp,
           AssumingAllOfCstrBroadcastable, AssumingAllToCstrEqCanonicalization,
           RemoveDuplicateOperandsPattern<AssumingAllOp>>(context);
}

OpFoldResult AssumingAllOp::fold(FoldAdaptor adaptor) {
  // Iterate in reverse to first handle all constant operands. They are
  // guaranteed to be the tail of the inputs because this is commutative.
  for (int idx = adaptor.getInputs().size() - 1; idx >= 0; idx--) {
    Attribute a = adaptor.getInputs()[idx];
    // Cannot fold if any inputs are not constant.
    if (!a)
      return nullptr;
    // We do not need to keep statically known values after handling them in
    // this method.
    getOperation()->eraseOperand(idx);
    // Always false if any input is statically known false.
    if (!llvm::cast<BoolAttr>(a).getValue())
      return a;
  }
  // If this is reached, all inputs were statically known passing.
  return BoolAttr::get(getContext(), true);
}

LogicalResult AssumingAllOp::verify() {
  // Ensure that AssumingAllOp contains at least one operand.
  if (getNumOperands() == 0)
    return emitOpError("no operands specified");
  return success();
}

//===----------------------------------------------------------------------===//
// BroadcastOp
//===----------------------------------------------------------------------===//

OpFoldResult BroadcastOp::fold(FoldAdaptor adaptor) {
  if (getShapes().size() == 1) {
    // Otherwise, we need a cast which would be a canonicalization, not
    // folding.
    if (getShapes().front().getType() != getType())
      return nullptr;
    return getShapes().front();
  }
  // TODO: Support folding with more than 2 input shapes.
  if (getShapes().size() > 2)
    return nullptr;
  if (!adaptor.getShapes()[0] || !adaptor.getShapes()[1])
    return nullptr;
  auto lhsShape = llvm::to_vector<6>(
      llvm::cast<DenseIntElementsAttr>(adaptor.getShapes()[0])
          .getValues<int64_t>());
  auto rhsShape = llvm::to_vector<6>(
      llvm::cast<DenseIntElementsAttr>(adaptor.getShapes()[1])
          .getValues<int64_t>());
  SmallVector<int64_t, 6> resultShape;
  // If the shapes are not compatible, we can't fold it.
  // TODO: Fold to an "error".
  if (!OpTrait::util::getBroadcastedShape(lhsShape, rhsShape, resultShape))
    return nullptr;
  Builder builder(getContext());
  return builder.getIndexTensorAttr(resultShape);
}

LogicalResult BroadcastOp::verify() {
  return verifyShapeOrExtentTensorOp(*this);
}

namespace {
template <typename OpTy>
struct RemoveEmptyShapeOperandsPattern : public OpRewritePattern<OpTy> {
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy op,
                                PatternRewriter &rewriter) const override {
    auto isPotentiallyNonEmptyShape = [](Value shape) {
      if (auto extentTensorTy =
              llvm::dyn_cast<RankedTensorType>(shape.getType())) {
        if (extentTensorTy.getDimSize(0) == 0)
          return false;
      }
      if (auto constShape = shape.getDefiningOp<ConstShapeOp>()) {
        if (constShape.getShape().empty())
          return false;
      }
      return true;
    };
    auto newOperands = llvm::to_vector<8>(
        llvm::make_filter_range(op->getOperands(), isPotentiallyNonEmptyShape));
    // Reduce op to equivalent without empty shape operands.
    if (newOperands.size() < op.getNumOperands()) {
      rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(), newOperands,
                                        op->getAttrs());
      return success();
    }
    return failure();
  }
};

struct BroadcastForwardSingleOperandPattern
    : public OpRewritePattern<BroadcastOp> {
  using OpRewritePattern<BroadcastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(BroadcastOp op,
                                PatternRewriter &rewriter) const override {
    if (op.getNumOperands() != 1)
      return failure();
    Value replacement = op.getShapes().front();
    // Insert cast if needed.
    if (replacement.getType() != op.getType()) {
      auto loc = op.getLoc();
      if (llvm::isa<ShapeType>(op.getType())) {
        replacement = rewriter.create<FromExtentTensorOp>(loc, replacement);
      } else {
        assert(!llvm::isa<ShapeType>(op.getType()) &&
               !llvm::isa<ShapeType>(replacement.getType()) &&
               "expect extent tensor cast");
        replacement =
            rewriter.create<tensor::CastOp>(loc, op.getType(), replacement);
      }
    }
    rewriter.replaceOp(op, replacement);
    return success();
  }
};

struct BroadcastFoldConstantOperandsPattern
    : public OpRewritePattern<BroadcastOp> {
  using OpRewritePattern<BroadcastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(BroadcastOp op,
                                PatternRewriter &rewriter) const override {
    SmallVector<int64_t, 8> foldedConstantShape;
    SmallVector<Value, 8> newShapeOperands;
    for (Value shape : op.getShapes()) {
      if (auto constShape = shape.getDefiningOp<ConstShapeOp>()) {
        SmallVector<int64_t, 8> newFoldedConstantShape;
        if (OpTrait::util::getBroadcastedShape(
                foldedConstantShape,
                llvm::to_vector<8>(constShape.getShape().getValues<int64_t>()),
                newFoldedConstantShape)) {
          foldedConstantShape = newFoldedConstantShape;
          continue;
        }
      }
      newShapeOperands.push_back(shape);
    }
    // Need at least two constant operands to fold anything.
    if (op.getNumOperands() - newShapeOperands.size() < 2)
      return failure();
    auto foldedConstantOperandsTy = RankedTensorType::get(
        {static_cast<int64_t>(foldedConstantShape.size())},
        rewriter.getIndexType());
    newShapeOperands.push_back(rewriter.create<ConstShapeOp>(
        op.getLoc(), foldedConstantOperandsTy,
        rewriter.getIndexTensorAttr(foldedConstantShape)));
    rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(),
                                             newShapeOperands);
    return success();
  }
};

template <typename OpTy>
struct CanonicalizeCastExtentTensorOperandsPattern
    : public OpRewritePattern<OpTy> {
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy op,
                                PatternRewriter &rewriter) const override {
    // Canonicalize operands.
    bool anyChange = false;
    auto canonicalizeOperand = [&](Value operand) -> Value {
      if (auto castOp = operand.getDefiningOp<tensor::CastOp>()) {
        // Only eliminate the cast if it holds no shape information.
        bool isInformationLosingCast =
            llvm::cast<RankedTensorType>(castOp.getType()).isDynamicDim(0);
        if (isInformationLosingCast) {
          anyChange = true;
          return castOp.getSource();
        }
      }
      return operand;
    };
    auto newOperands = llvm::to_vector<8>(
        llvm::map_range(op.getOperands(), canonicalizeOperand));
    // Rewrite op if any change required.
    if (!anyChange)
      return failure();
    rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(), newOperands);
    return success();
  }
};

struct BroadcastConcretizeResultTypePattern
    : public OpRewritePattern<BroadcastOp> {
  using OpRewritePattern<BroadcastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(BroadcastOp op,
                                PatternRewriter &rewriter) const override {
    // Only concretize dynamic extent tensor result types.
    auto resultTy = llvm::dyn_cast<RankedTensorType>(op.getType());
    if (!resultTy || !resultTy.isDynamicDim(0))
      return failure();
    // Infer resulting shape rank if possible.
    int64_t maxRank = 0;
    for (Value shape : op.getShapes()) {
      if (auto extentTensorTy =
              llvm::dyn_cast<RankedTensorType>(shape.getType())) {
        // Cannot infer resulting shape rank if any operand is dynamically
        // ranked.
        if (extentTensorTy.isDynamicDim(0))
          return failure();
        maxRank = std::max(maxRank, extentTensorTy.getDimSize(0));
      }
    }
    auto newOp = rewriter.create<BroadcastOp>(
        op.getLoc(), getExtentTensorType(getContext(), maxRank),
        op.getShapes());
    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), newOp);
    return success();
  }
};
} // namespace

void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                              MLIRContext *context) {
  patterns.add<BroadcastConcretizeResultTypePattern,
               BroadcastFoldConstantOperandsPattern,
               BroadcastForwardSingleOperandPattern,
               CanonicalizeCastExtentTensorOperandsPattern<BroadcastOp>,
               RemoveDuplicateOperandsPattern<BroadcastOp>,
               RemoveEmptyShapeOperandsPattern<BroadcastOp>>(context);
}

//===----------------------------------------------------------------------===//
// ConcatOp
//===----------------------------------------------------------------------===//

OpFoldResult ConcatOp::fold(FoldAdaptor adaptor) {
  if (!adaptor.getLhs() || !adaptor.getRhs())
    return nullptr;
  auto lhsShape = llvm::to_vector<6>(
      llvm::cast<DenseIntElementsAttr>(adaptor.getLhs()).getValues<int64_t>());
  auto rhsShape = llvm::to_vector<6>(
      llvm::cast<DenseIntElementsAttr>(adaptor.getRhs()).getValues<int64_t>());
  SmallVector<int64_t, 6> resultShape;
  resultShape.append(lhsShape.begin(), lhsShape.end());
  resultShape.append(rhsShape.begin(), rhsShape.end());
  Builder builder(getContext());
  return builder.getIndexTensorAttr(resultShape);
}

//===----------------------------------------------------------------------===//
// ConstShapeOp
//===----------------------------------------------------------------------===//

void ConstShapeOp::print(OpAsmPrinter &p) {
  p << " ";
  p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{"shape"});
  p << "[";
  interleaveComma(getShape().getValues<int64_t>(), p);
  p << "] : ";
  p.printType(getType());
}

ParseResult ConstShapeOp::parse(OpAsmParser &parser, OperationState &result) {
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();
  // We piggy-back on ArrayAttr parsing, though we don't internally store the
  // shape as an ArrayAttr.
  // TODO: Implement custom parser and maybe make syntax a bit more concise.
  Attribute extentsRaw;
  NamedAttrList dummy;
  if (parser.parseAttribute(extentsRaw, "dummy", dummy))
    return failure();
  auto extentsArray = llvm::dyn_cast<ArrayAttr>(extentsRaw);
  if (!extentsArray)
    return failure();
  SmallVector<int64_t, 6> ints;
  for (Attribute extent : extentsArray) {
    IntegerAttr attr = llvm::dyn_cast<IntegerAttr>(extent);
    if (!attr)
      return failure();
    ints.push_back(attr.getInt());
  }
  Builder &builder = parser.getBuilder();
  result.addAttribute("shape", builder.getIndexTensorAttr(ints));
  Type resultTy;
  if (parser.parseColonType(resultTy))
    return failure();
  result.types.push_back(resultTy);
  return success();
}

OpFoldResult ConstShapeOp::fold(FoldAdaptor) { return getShapeAttr(); }

void ConstShapeOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                               MLIRContext *context) {
  patterns.add<TensorCastConstShape>(context);
}

LogicalResult mlir::shape::ConstShapeOp::inferReturnTypes(
    MLIRContext *context, std::optional<Location> location, ValueRange operands,
    DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  Builder b(context);
  Properties *prop = properties.as<Properties *>();
  DenseIntElementsAttr shape;
  // TODO: this is only exercised by the Python bindings codepath, which does
  // not support properties.
  if (prop)
    shape = prop->shape;
  else
    shape = attributes.getAs<DenseIntElementsAttr>("shape");
  if (!shape)
    return emitOptionalError(location, "missing shape attribute");
  inferredReturnTypes.assign({RankedTensorType::get(
      {static_cast<int64_t>(shape.size())}, b.getIndexType())});
  return success();
}

bool mlir::shape::ConstShapeOp::isCompatibleReturnTypes(TypeRange l,
                                                        TypeRange r) {
  if (l.size() != 1 || r.size() != 1)
    return false;
  Type lhs = l.front();
  Type rhs = r.front();
  if (llvm::isa<ShapeType>(lhs) || llvm::isa<ShapeType>(rhs))
    // Shape type is compatible with all other valid return types.
    return true;
  return lhs == rhs;
}

//===----------------------------------------------------------------------===//
// CstrBroadcastableOp
//===----------------------------------------------------------------------===//

void CstrBroadcastableOp::getCanonicalizationPatterns(
    RewritePatternSet &patterns, MLIRContext *context) {
  // Canonicalization patterns have overlap with the considerations during
  // folding, in case additional shape information is inferred at some point
  // that does not result in folding.
  patterns.add<CanonicalizeCastExtentTensorOperandsPattern<CstrBroadcastableOp>,
               CstrBroadcastableEqOps,
               RemoveDuplicateOperandsPattern<CstrBroadcastableOp>,
               RemoveEmptyShapeOperandsPattern<CstrBroadcastableOp>>(context);
}

// Return true if there is at most one attribute not representing a scalar
// broadcast.
static bool hasAtMostSingleNonScalar(ArrayRef<Attribute> attributes) {
  bool nonScalarSeen = false;
  for (Attribute a : attributes) {
    if (!a || llvm::cast<DenseIntElementsAttr>(a).getNumElements() != 0) {
      if (nonScalarSeen)
        return false;
      nonScalarSeen = true;
    }
  }
  return true;
}

OpFoldResult CstrBroadcastableOp::fold(FoldAdaptor adaptor) {
  // No broadcasting is needed if all operands but one are scalar.
  if (hasAtMostSingleNonScalar(adaptor.getShapes()))
    return BoolAttr::get(getContext(), true);
  if ([&] {
        SmallVector<SmallVector<int64_t, 6>, 6> extents;
        for (const auto &operand : adaptor.getShapes()) {
          if (!operand)
            return false;
          extents.push_back(llvm::to_vector<6>(
              llvm::cast<DenseIntElementsAttr>(operand).getValues<int64_t>()));
        }
        return OpTrait::util::staticallyKnownBroadcastable(extents);
      }())
    return BoolAttr::get(getContext(), true);
  // Lastly, see if folding can be completed based on what constraints are
  // known on the input shapes.
  if ([&] {
        SmallVector<SmallVector<int64_t, 6>, 6> extents;
        for (auto shapeValue : getShapes()) {
          extents.emplace_back();
          if (failed(getShapeVec(shapeValue, extents.back())))
            return false;
        }
        return OpTrait::util::staticallyKnownBroadcastable(extents);
      }())
    return BoolAttr::get(getContext(), true);
  // Because a failing witness result here represents an eventual assertion
  // failure, we do not replace it with a constant witness.
  return nullptr;
}

LogicalResult CstrBroadcastableOp::verify() {
  // Ensure that CstrBroadcastableOp contains at least two operands.
  if (getNumOperands() < 2)
    return emitOpError("required at least 2 input shapes");
  return success();
}

//===----------------------------------------------------------------------===//
// CstrEqOp
//===----------------------------------------------------------------------===//

void CstrEqOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                           MLIRContext *context) {
  // If inputs are equal, return passing witness.
  patterns.add<CstrEqEqOps>(context);
}

OpFoldResult CstrEqOp::fold(FoldAdaptor adaptor) {
  if (llvm::all_of(adaptor.getShapes(), [&](Attribute a) {
        return a && a == adaptor.getShapes().front();
      }))
    return BoolAttr::get(getContext(), true);
  // Because a failing witness result here represents an eventual assertion
  // failure, we do not try to replace it with a constant witness. Similarly,
  // we cannot fold if there are any non-constant inputs.
  return nullptr;
}

//===----------------------------------------------------------------------===//
// ConstSizeOp
//===----------------------------------------------------------------------===//

void ConstSizeOp::build(OpBuilder &builder, OperationState &result,
                        int64_t value) {
  build(builder, result, builder.getIndexAttr(value));
}

OpFoldResult ConstSizeOp::fold(FoldAdaptor) { return getValueAttr(); }

void ConstSizeOp::getAsmResultNames(
    llvm::function_ref<void(Value, StringRef)> setNameFn) {
  SmallString<4> buffer;
  llvm::raw_svector_ostream os(buffer);
  os << "c" << getValue();
  setNameFn(getResult(), os.str());
}

//===----------------------------------------------------------------------===//
// ConstWitnessOp
//===----------------------------------------------------------------------===//

OpFoldResult ConstWitnessOp::fold(FoldAdaptor) { return getPassingAttr(); }

//===----------------------------------------------------------------------===//
// CstrRequireOp
//===----------------------------------------------------------------------===//

OpFoldResult CstrRequireOp::fold(FoldAdaptor adaptor) {
  return adaptor.getPred();
}

//===----------------------------------------------------------------------===//
// DimOp
//===----------------------------------------------------------------------===//

std::optional<int64_t> DimOp::getConstantIndex() {
  if (auto constSizeOp = getIndex().getDefiningOp<ConstSizeOp>())
    return constSizeOp.getValue().getLimitedValue();
  if (auto constantOp = getIndex().getDefiningOp<arith::ConstantOp>())
    return llvm::cast<IntegerAttr>(constantOp.getValue()).getInt();
  return std::nullopt;
}

OpFoldResult DimOp::fold(FoldAdaptor adaptor) {
  Type valType = getValue().getType();
  auto valShapedType = llvm::dyn_cast<ShapedType>(valType);
  if (!valShapedType || !valShapedType.hasRank())
    return nullptr;
  std::optional<int64_t> index = getConstantIndex();
  if (!index.has_value())
    return nullptr;
  if (index.value() >= valShapedType.getRank())
    return nullptr;
  auto extent = valShapedType.getDimSize(*index);
  if (ShapedType::isDynamic(extent))
    return nullptr;
  return IntegerAttr::get(IndexType::get(getContext()), extent);
}

LogicalResult mlir::shape::DimOp::inferReturnTypes(
    MLIRContext *context, std::optional<Location> location, ValueRange operands,
    DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  DimOpAdaptor dimOp(operands);
  inferredReturnTypes.assign({dimOp.getIndex().getType()});
  return success();
}

bool mlir::shape::DimOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
  return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
}

LogicalResult mlir::shape::DimOp::verify() {
  auto st = llvm::cast<ShapedType>(getValue().getType());
  if (!st.hasRank())
    return success();
  if (auto index = getConstantIndex()) {
    if (*index < 0 || *index >= st.getRank())
      return emitOpError("index is out of range");
  }
  return success();
}

//===----------------------------------------------------------------------===//
// DivOp
//===----------------------------------------------------------------------===//

OpFoldResult DivOp::fold(FoldAdaptor adaptor) {
  auto lhs = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getLhs());
  if (!lhs)
    return nullptr;
  auto rhs = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getRhs());
  if (!rhs)
    return nullptr;
  // Division in APInt does not follow floor(lhs / rhs) semantics when the
  // result is negative; rather, APInt rounds toward zero, so adjust the
  // quotient here when there is a nonzero remainder.
  APInt quotient, remainder;
  APInt::sdivrem(lhs.getValue(), rhs.getValue(), quotient, remainder);
  if (quotient.isNegative() && !remainder.isZero()) {
    quotient -= 1;
  }
  Type indexTy = IndexType::get(getContext());
  return IntegerAttr::get(indexTy, quotient);
}

LogicalResult mlir::shape::DivOp::inferReturnTypes(
    MLIRContext *context, std::optional<Location> location, ValueRange operands,
    DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  if (llvm::isa<SizeType>(operands[0].getType()) ||
      llvm::isa<SizeType>(operands[1].getType()))
    inferredReturnTypes.assign({SizeType::get(context)});
  else
    inferredReturnTypes.assign({IndexType::get(context)});
  return success();
}

bool mlir::shape::DivOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
  // SizeType is compatible with IndexType.
  return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
}

LogicalResult DivOp::verify() { return verifySizeOrIndexOp(*this); }

//===----------------------------------------------------------------------===//
// ShapeEqOp
//===----------------------------------------------------------------------===//

OpFoldResult ShapeEqOp::fold(FoldAdaptor adaptor) {
  bool allSame = true;
  if (!adaptor.getShapes().empty() && !adaptor.getShapes().front())
    return {};
  for (Attribute operand : adaptor.getShapes().drop_front()) {
    if (!operand)
      return {};
    allSame = allSame && operand == adaptor.getShapes().front();
  }
  return BoolAttr::get(getContext(), allSame);
}

//===----------------------------------------------------------------------===//
// IndexToSizeOp
//===----------------------------------------------------------------------===//

OpFoldResult IndexToSizeOp::fold(FoldAdaptor adaptor) {
  // Constant values of both types, `shape.size` and `index`, are represented
  // as `IntegerAttr`s which makes constant folding simple.
  if (Attribute arg = adaptor.getArg())
    return arg;
  return {};
}

void IndexToSizeOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                MLIRContext *context) {
  patterns.add<SizeToIndexToSizeCanonicalization>(context);
}

//===----------------------------------------------------------------------===//
// FromExtentsOp
//===----------------------------------------------------------------------===//

OpFoldResult FromExtentsOp::fold(FoldAdaptor adaptor) {
  if (llvm::any_of(adaptor.getExtents(), [](Attribute a) { return !a; }))
    return nullptr;
  SmallVector<int64_t, 6> extents;
  for (auto attr : adaptor.getExtents())
    extents.push_back(llvm::cast<IntegerAttr>(attr).getInt());
  Builder builder(getContext());
  return builder.getIndexTensorAttr(extents);
}

//===----------------------------------------------------------------------===//
// FunctionLibraryOp
//===----------------------------------------------------------------------===//

void FunctionLibraryOp::build(OpBuilder &builder, OperationState &result,
                              StringRef name) {
  result.attributes.push_back(builder.getNamedAttr(
      ::mlir::SymbolTable::getSymbolAttrName(), builder.getStringAttr(name)));
}

FuncOp FunctionLibraryOp::getShapeFunction(Operation *op) {
  auto attr = llvm::dyn_cast_or_null<FlatSymbolRefAttr>(
      getMapping().get(op->getName().getIdentifier()));
  if (!attr)
    return nullptr;
  return lookupSymbol<FuncOp>(attr);
}

ParseResult FunctionLibraryOp::parse(OpAsmParser &parser,
                                     OperationState &result) {
  // Parse the op name.
  StringAttr nameAttr;
  if (parser.parseSymbolName(nameAttr, ::mlir::SymbolTable::getSymbolAttrName(),
                             result.attributes))
    return failure();
  if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
    return failure();
  auto *bodyRegion = result.addRegion();
  if (parser.parseRegion(*bodyRegion))
    return failure();
  if (parser.parseKeyword("mapping"))
    return failure();
  DictionaryAttr mappingAttr;
  if (parser.parseAttribute(mappingAttr,
                            parser.getBuilder().getType<NoneType>(), "mapping",
                            result.attributes))
    return failure();
  return success();
}

void FunctionLibraryOp::print(OpAsmPrinter &p) {
  p << ' ';
  p.printSymbolName(getName());
  p.printOptionalAttrDictWithKeyword(
      (*this)->getAttrs(), {mlir::SymbolTable::getSymbolAttrName(), "mapping"});
  p << ' ';
  p.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
                /*printBlockTerminators=*/false);
  p << " mapping ";
  p.printAttributeWithoutType(getMappingAttr());
}

//===----------------------------------------------------------------------===//
// FuncOp
//===----------------------------------------------------------------------===//

FuncOp FuncOp::create(Location location, StringRef name, FunctionType type,
                      ArrayRef<NamedAttribute> attrs) {
  OpBuilder builder(location->getContext());
  OperationState state(location, getOperationName());
  FuncOp::build(builder, state, name, type, attrs);
  return cast<FuncOp>(Operation::create(state));
}

FuncOp FuncOp::create(Location location, StringRef name, FunctionType type,
                      Operation::dialect_attr_range attrs) {
  SmallVector<NamedAttribute, 8> attrRef(attrs);
  return create(location, name, type, llvm::ArrayRef(attrRef));
}

FuncOp FuncOp::create(Location location, StringRef name, FunctionType type,
                      ArrayRef<NamedAttribute> attrs,
                      ArrayRef<DictionaryAttr> argAttrs) {
  FuncOp func = create(location, name, type, attrs);
  func.setAllArgAttrs(argAttrs);
  return func;
}

void FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name,
                   FunctionType type, ArrayRef<NamedAttribute> attrs,
                   ArrayRef<DictionaryAttr> argAttrs) {
  state.addAttribute(FuncOp::getSymNameAttrName(state.name),
                     builder.getStringAttr(name));
  state.addAttribute(FuncOp::getFunctionTypeAttrName(state.name),
                     TypeAttr::get(type));
  state.attributes.append(attrs.begin(), attrs.end());
  state.addRegion();
  if (argAttrs.empty())
    return;
  assert(type.getNumInputs() == argAttrs.size());
  function_interface_impl::addArgAndResultAttrs(
      builder, state, argAttrs, /*resultAttrs=*/std::nullopt,
      getArgAttrsAttrName(state.name), getResAttrsAttrName(state.name));
}

ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) {
  auto buildFuncType =
      [](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results,
         function_interface_impl::VariadicFlag,
         std::string &) { return builder.getFunctionType(argTypes, results); };
  return function_interface_impl::parseFunctionOp(
      parser, result, /*allowVariadic=*/false,
      getFunctionTypeAttrName(result.name), buildFuncType,
      getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
}

void FuncOp::print(OpAsmPrinter &p) {
  function_interface_impl::printFunctionOp(
      p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(),
      getArgAttrsAttrName(), getResAttrsAttrName());
}

//===----------------------------------------------------------------------===//
// GetExtentOp
//===----------------------------------------------------------------------===//

std::optional<int64_t> GetExtentOp::getConstantDim() {
  if (auto constSizeOp = getDim().getDefiningOp<ConstSizeOp>())
    return constSizeOp.getValue().getLimitedValue();
  if (auto constantOp = getDim().getDefiningOp<arith::ConstantOp>())
    return llvm::cast<IntegerAttr>(constantOp.getValue()).getInt();
  return std::nullopt;
}

OpFoldResult GetExtentOp::fold(FoldAdaptor adaptor) {
  auto elements =
      llvm::dyn_cast_if_present<DenseIntElementsAttr>(adaptor.getShape());
  if (!elements)
    return nullptr;
  std::optional<int64_t> dim = getConstantDim();
  if (!dim.has_value())
    return nullptr;
  if (dim.value() >= elements.getNumElements())
    return nullptr;
  return elements.getValues<Attribute>()[(uint64_t)dim.value()];
}

void GetExtentOp::build(OpBuilder &builder, OperationState &result, Value shape,
                        int64_t dim) {
  auto loc = result.location;
  auto dimAttr = builder.getIndexAttr(dim);
  if (llvm::isa<ShapeType>(shape.getType())) {
    Value dim = builder.create<ConstSizeOp>(loc, dimAttr);
    build(builder, result, builder.getType<SizeType>(), shape, dim);
  } else {
    Value dim =
        builder.create<arith::ConstantOp>(loc, builder.getIndexType(), dimAttr);
    build(builder, result, builder.getIndexType(), shape, dim);
  }
}

LogicalResult mlir::shape::GetExtentOp::inferReturnTypes(
    MLIRContext *context, std::optional<Location> location, ValueRange operands,
    DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  inferredReturnTypes.assign({IndexType::get(context)});
  return success();
}

bool mlir::shape::GetExtentOp::isCompatibleReturnTypes(TypeRange l,
                                                       TypeRange r) {
  // SizeType is compatible with IndexType.
  return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
}

LogicalResult GetExtentOp::verify() { return verifySizeOrIndexOp(*this); }

//===----------------------------------------------------------------------===//
// IsBroadcastableOp
//===----------------------------------------------------------------------===//

void IsBroadcastableOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                    MLIRContext *context) {
  patterns.add<RemoveDuplicateOperandsPattern<IsBroadcastableOp>>(context);
}

OpFoldResult IsBroadcastableOp::fold(FoldAdaptor adaptor) {
  // Can always broadcast fewer than two shapes.
  if (adaptor.getShapes().size() < 2) {
    return BoolAttr::get(getContext(), true);
  }
  return nullptr;
}

//===----------------------------------------------------------------------===//
// MeetOp
//===----------------------------------------------------------------------===//

LogicalResult mlir::shape::MeetOp::inferReturnTypes(
    MLIRContext *context, std::optional<Location> location, ValueRange operands,
    DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  if (operands.empty())
    return failure();
  auto isShapeType = [](Type arg) {
    if (llvm::isa<ShapeType>(arg))
      return true;
    return isExtentTensorType(arg);
  };
  ValueRange::type_range types = operands.getTypes();
  Type acc = types.front();
  for (auto t : drop_begin(types)) {
    Type l = acc, r = t;
    if (!llvm::isa<ShapeType, SizeType>(l))
      std::swap(l, r);
    // Handle sizes, propagate error type if present.
    if (llvm::isa<SizeType>(l)) {
      if (llvm::isa<SizeType, IndexType>(r))
        acc = l;
      else
        return emitOptionalError(location, "requires all sizes or shapes");
    } else if (llvm::isa<IndexType>(l)) {
      if (llvm::isa<IndexType>(r))
        acc = r;
      else
        return emitOptionalError(location, "requires all sizes or shapes");
    } else if (llvm::isa<ShapeType>(l)) {
      // Handle shapes, propagate error type if present.
      if (isShapeType(r))
        acc = l;
      else
        return emitOptionalError(location, "requires all sizes or shapes");
    } else if (isExtentTensorType(l)) {
      auto rank1 = llvm::cast<RankedTensorType>(l).getShape()[0];
      auto rank2 = llvm::cast<RankedTensorType>(r).getShape()[0];
      if (ShapedType::isDynamic(rank1))
        acc = l;
      else if (ShapedType::isDynamic(rank2))
        acc = r;
      else if (rank1 != rank2)
        return emitOptionalError(location, "unequal shape cardinality");
      else
        acc = l;
    }
  }
  inferredReturnTypes.assign({acc});
  return success();
}

bool mlir::shape::MeetOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
  if (l.size() != 1 || r.size() != 1)
    return false;
  if (l == r)
    return true;
  Type lhs = l.front();
  Type rhs = r.front();
  if (!llvm::isa<ShapeType, SizeType>(lhs))
    std::swap(lhs, rhs);
  if (llvm::isa<SizeType>(lhs))
    return llvm::isa<SizeType, IndexType>(rhs);
  if (llvm::isa<ShapeType>(lhs))
    return llvm::isa<ShapeType, TensorType>(rhs);
  if (succeeded(verifyCompatibleShapes({lhs, rhs})))
    return true;
  return false;
}

//===----------------------------------------------------------------------===//
// RankOp
//===----------------------------------------------------------------------===//

OpFoldResult shape::RankOp::fold(FoldAdaptor adaptor) {
  auto shape =
      llvm::dyn_cast_if_present<DenseIntElementsAttr>(adaptor.getShape());
  if (!shape)
    return {};
  int64_t rank = shape.getNumElements();
  Builder builder(getContext());
  return builder.getIndexAttr(rank);
}

/// Evaluate the `rank` operation for shapes of ranked tensors at compile time.
/// Constant folding fails in cases where only the rank is constant, not the
/// shape itself.
/// This canonicalization matches `shape.rank(shape.shape_of(%ranked_tensor))`.
///
/// Example:
///
/// %shape = shape.shape_of %ranked_tensor : tensor<1x2x?xf32>
/// %rank = shape.rank %shape
///
/// becomes
///
/// %rank = shape.const_size 3
namespace {
struct RankShapeOfCanonicalizationPattern
    : public OpRewritePattern<shape::RankOp> {
  using OpRewritePattern<shape::RankOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(shape::RankOp op,
                                PatternRewriter &rewriter) const override {
    auto shapeOfOp = op.getShape().getDefiningOp<ShapeOfOp>();
    if (!shapeOfOp)
      return failure();
    auto rankedTensorType =
        llvm::dyn_cast<RankedTensorType>(shapeOfOp.getArg().getType());
    if (!rankedTensorType)
      return failure();
    int64_t rank = rankedTensorType.getRank();
    if (llvm::isa<IndexType>(op.getType())) {
      rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(op.getOperation(),
                                                          rank);
    } else if (llvm::isa<shape::SizeType>(op.getType())) {
      rewriter.replaceOpWithNewOp<shape::ConstSizeOp>(op.getOperation(), rank);
    } else {
      return failure();
    }
    return success();
  }
};
} // namespace

void shape::RankOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                MLIRContext *context) {
  patterns.add<RankShapeOfCanonicalizationPattern>(context);
}

LogicalResult mlir::shape::RankOp::inferReturnTypes(
    MLIRContext *context, std::optional<Location> location, ValueRange operands,
    DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  if (llvm::isa<ShapeType>(operands[0].getType()))
    inferredReturnTypes.assign({SizeType::get(context)});
  else
    inferredReturnTypes.assign({IndexType::get(context)});
  return success();
}

bool mlir::shape::RankOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
  // SizeType is compatible with IndexType.
  return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
}

LogicalResult shape::RankOp::verify() { return verifySizeOrIndexOp(*this); }

//===----------------------------------------------------------------------===//
// NumElementsOp
//===----------------------------------------------------------------------===//

OpFoldResult NumElementsOp::fold(FoldAdaptor adaptor) {
  // Fold only when the argument is constant.
  Attribute shape = adaptor.getShape();
  if (!shape)
    return {};
  APInt product(64, 1);
  for (auto value : llvm::cast<DenseIntElementsAttr>(shape))
    product *= value;
  Builder builder(getContext());
  return builder.getIndexAttr(product.getLimitedValue());
}

LogicalResult mlir::shape::NumElementsOp::inferReturnTypes(
    MLIRContext *context, std::optional<Location> location, ValueRange operands,
    DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  if (llvm::isa<ShapeType>(operands[0].getType()))
    inferredReturnTypes.assign({SizeType::get(context)});
  else
    inferredReturnTypes.assign({IndexType::get(context)});
  return success();
}

bool mlir::shape::NumElementsOp::isCompatibleReturnTypes(TypeRange l,
                                                         TypeRange r) {
  // SizeType is compatible with IndexType.
  return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
}

LogicalResult shape::NumElementsOp::verify() {
  return verifySizeOrIndexOp(*this);
}

//===----------------------------------------------------------------------===//
// MaxOp
//===----------------------------------------------------------------------===//

OpFoldResult MaxOp::fold(FoldAdaptor adaptor) {
  // If operands are equal, just propagate one.
  if (getLhs() == getRhs())
    return getLhs();
  return nullptr;
}
LogicalResult mlir::shape::MaxOp::inferReturnTypes(
MLIRContext *context, std::optional<Location> location, ValueRange operands,
DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
SmallVectorImpl<Type> &inferredReturnTypes) {
if (operands[0].getType() == operands[1].getType())
inferredReturnTypes.assign({operands[0].getType()});
else
inferredReturnTypes.assign({SizeType::get(context)});
return success();
}
bool mlir::shape::MaxOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
if (l.size() != 1 || r.size() != 1)
return false;
if (llvm::isa<ShapeType>(l.front()) && llvm::isa<ShapeType>(r.front()))
return true;
if (llvm::isa<SizeType>(l.front()) && llvm::isa<SizeType>(r.front()))
return true;
return false;
}
//===----------------------------------------------------------------------===//
// MinOp
//===----------------------------------------------------------------------===//
OpFoldResult MinOp::fold(FoldAdaptor adaptor) {
// If operands are equal, just propagate one.
if (getLhs() == getRhs())
return getLhs();
return nullptr;
}
LogicalResult mlir::shape::MinOp::inferReturnTypes(
MLIRContext *context, std::optional<Location> location, ValueRange operands,
DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
SmallVectorImpl<Type> &inferredReturnTypes) {
if (operands[0].getType() == operands[1].getType())
inferredReturnTypes.assign({operands[0].getType()});
else
inferredReturnTypes.assign({SizeType::get(context)});
return success();
}
bool mlir::shape::MinOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
if (l.size() != 1 || r.size() != 1)
return false;
if (llvm::isa<ShapeType>(l.front()) && llvm::isa<ShapeType>(r.front()))
return true;
if (llvm::isa<SizeType>(l.front()) && llvm::isa<SizeType>(r.front()))
return true;
return false;
}
//===----------------------------------------------------------------------===//
// MulOp
//===----------------------------------------------------------------------===//
OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
auto lhs = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getLhs());
if (!lhs)
return nullptr;
auto rhs = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getRhs());
if (!rhs)
return nullptr;
APInt folded = lhs.getValue() * rhs.getValue();
Type indexTy = IndexType::get(getContext());
return IntegerAttr::get(indexTy, folded);
}
LogicalResult mlir::shape::MulOp::inferReturnTypes(
MLIRContext *context, std::optional<Location> location, ValueRange operands,
DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
SmallVectorImpl<Type> &inferredReturnTypes) {
if (llvm::isa<SizeType>(operands[0].getType()) ||
llvm::isa<SizeType>(operands[1].getType()))
inferredReturnTypes.assign({SizeType::get(context)});
else
inferredReturnTypes.assign({IndexType::get(context)});
return success();
}

bool mlir::shape::MulOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
// SizeType is compatible with IndexType.
return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
}

LogicalResult shape::MulOp::verify() { return verifySizeOrIndexOp(*this); }

//===----------------------------------------------------------------------===//
// ShapeOfOp
//===----------------------------------------------------------------------===//

OpFoldResult ShapeOfOp::fold(FoldAdaptor) {
auto type = llvm::dyn_cast<ShapedType>(getOperand().getType());
if (!type || !type.hasStaticShape())
return nullptr;
Builder builder(getContext());
return builder.getIndexTensorAttr(type.getShape());
}
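
// Illustrative fold (assumed IR): a statically shaped argument folds to a
// constant index tensor:
//   %0 = shape.shape_of %arg : tensor<2x3xf32> -> tensor<2xindex>
// folds to the dense constant [2, 3].
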
namespace {
struct ShapeOfWithTensor : public OpRewritePattern<shape::ShapeOfOp> {
using OpRewritePattern<shape::ShapeOfOp>::OpRewritePattern;
LogicalResult matchAndRewrite(shape::ShapeOfOp op,
PatternRewriter &rewriter) const override {
if (!llvm::isa<ShapedType>(op.getArg().getType()))
return failure();
if (llvm::isa<ShapedType>(op.getType()))
return failure();
rewriter.replaceOpWithNewOp<shape::ShapeOfOp>(op.getOperation(),
op.getArg());
return success();
}
};
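
// A sketch of the rewrite performed by ShapeOfWithTensor above (assumed IR):
// ```
// %0 = shape.shape_of %arg : tensor<?x?xf32> -> !shape.shape
// ```
// becomes
// ```
// %0 = shape.shape_of %arg : tensor<?x?xf32> -> tensor<2xindex>
// ```
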
// Canonicalize
// ```
// %0 = shape.shape_of %arg : tensor<?x?x?xf32> -> tensor<3xindex>
// %1 = tensor.cast %0 : tensor<3xindex> to tensor<?xindex>
// ```
// to
// ```
// %1 = shape.shape_of %arg : tensor<?x?x?xf32> -> tensor<?xindex>
// ```
struct ShapeOfCastExtentTensor : public OpRewritePattern<tensor::CastOp> {
using OpRewritePattern<tensor::CastOp>::OpRewritePattern;
LogicalResult matchAndRewrite(tensor::CastOp op,
PatternRewriter &rewriter) const override {
auto ty = llvm::dyn_cast<RankedTensorType>(op.getType());
if (!ty || ty.getRank() != 1)
return failure();
auto shapeOfOp = op.getSource().getDefiningOp<ShapeOfOp>();
if (!shapeOfOp)
return failure();
// The argument type must be ranked and, when the cast result's extent is
// static, its rank must match that extent.
auto argTy = llvm::dyn_cast<RankedTensorType>(shapeOfOp.getArg().getType());
if (!argTy || (!ty.isDynamicDim(0) && ty.getDimSize(0) != argTy.getRank()))
return failure();
rewriter.replaceOpWithNewOp<ShapeOfOp>(op, ty, shapeOfOp.getArg());
return success();
}
};
} // namespace

void ShapeOfOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
MLIRContext *context) {
patterns.add<ShapeOfCastExtentTensor, ShapeOfWithTensor,
ExtractFromShapeOfExtentTensor>(context);
}

LogicalResult mlir::shape::ShapeOfOp::inferReturnTypes(
MLIRContext *context, std::optional<Location> location, ValueRange operands,
DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
SmallVectorImpl<Type> &inferredReturnTypes) {
if (llvm::isa<ValueShapeType>(operands[0].getType()))
inferredReturnTypes.assign({ShapeType::get(context)});
else {
auto shapedTy = llvm::cast<ShapedType>(operands[0].getType());
int64_t rank =
shapedTy.hasRank() ? shapedTy.getRank() : ShapedType::kDynamic;
Type indexTy = IndexType::get(context);
Type extentTensorTy = RankedTensorType::get({rank}, indexTy);
inferredReturnTypes.assign({extentTensorTy});
}
return success();
}

bool mlir::shape::ShapeOfOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
if (l.size() != 1 || r.size() != 1)
return false;
if (l == r)
return true;
Type lhs = l.front();
Type rhs = r.front();
if (!llvm::isa<ShapeType, ShapedType>(lhs) ||
!llvm::isa<ShapeType, ShapedType>(rhs))
return false;
if (llvm::isa<ShapeType>(lhs) || llvm::isa<ShapeType>(rhs))
// Shape type is compatible with all other valid return types.
return true;
if (succeeded(verifyCompatibleShapes({lhs, rhs})))
return true;
return false;
}

LogicalResult shape::ShapeOfOp::verify() {
return verifyShapeOrExtentTensorOp(*this);
}

//===----------------------------------------------------------------------===//
// SizeToIndexOp
//===----------------------------------------------------------------------===//

OpFoldResult SizeToIndexOp::fold(FoldAdaptor adaptor) {
// Constant values of both types, `shape.size` and `index`, are represented
// as `IntegerAttr`s, which makes constant folding simple.
if (Attribute arg = adaptor.getArg())
return arg;
return OpFoldResult();
}
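
// Illustrative fold (assumed IR): because both constant kinds share the same
// attribute representation, a constant operand passes through unchanged:
//   %s = shape.const_size 42
//   %i = shape.size_to_index %s : !shape.size
// folds %i to the constant index 42.
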
void SizeToIndexOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
MLIRContext *context) {
patterns.add<IndexToSizeToIndexCanonicalization>(context);
}

bool SizeToIndexOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
if (inputs.size() != 1 || outputs.size() != 1)
return false;
return llvm::isa<IndexType, SizeType>(inputs[0]) &&
llvm::isa<IndexType>(outputs[0]);
}

//===----------------------------------------------------------------------===//
// YieldOp
//===----------------------------------------------------------------------===//

LogicalResult shape::YieldOp::verify() {
auto *parentOp = (*this)->getParentOp();
auto results = parentOp->getResults();
auto operands = getOperands();
if (parentOp->getNumResults() != getNumOperands())
return emitOpError() << "number of operands does not match number of "
"results of its parent";
for (auto e : llvm::zip(results, operands))
if (std::get<0>(e).getType() != std::get<1>(e).getType())
return emitOpError() << "types mismatch between yield op and its parent";
return success();
}

//===----------------------------------------------------------------------===//
// SplitAtOp
//===----------------------------------------------------------------------===//

LogicalResult SplitAtOp::fold(FoldAdaptor adaptor,
SmallVectorImpl<OpFoldResult> &results) {
if (!adaptor.getOperand() || !adaptor.getIndex())
return failure();
auto shapeVec = llvm::to_vector<6>(
llvm::cast<DenseIntElementsAttr>(adaptor.getOperand()).getValues<int64_t>());
auto shape = llvm::ArrayRef(shapeVec);
auto splitPoint = llvm::cast<IntegerAttr>(adaptor.getIndex()).getInt();
// Verify that the split point is in the correct range.
// TODO: Constant fold to an "error".
int64_t rank = shape.size();
if (-rank > splitPoint || splitPoint > rank)
return failure();
if (splitPoint < 0)
splitPoint += shape.size();
Builder builder(adaptor.getOperand().getContext());
results.push_back(builder.getIndexTensorAttr(shape.take_front(splitPoint)));
results.push_back(builder.getIndexTensorAttr(shape.drop_front(splitPoint)));
return success();
}
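
// A sketch of the fold above in the abstract notation used by the op's
// documentation (illustrative only):
//   split_at([4, 5, 6], index=1)  -> [4], [5, 6]
//   split_at([4, 5, 6], index=-2) -> [4], [5, 6]
// An out-of-range split point (e.g. index=4 here) does not fold.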

//===----------------------------------------------------------------------===//
// ToExtentTensorOp
//===----------------------------------------------------------------------===//

OpFoldResult ToExtentTensorOp::fold(FoldAdaptor adaptor) {
if (!adaptor.getInput())
return OpFoldResult();
Builder builder(getContext());
auto shape = llvm::to_vector<6>(
llvm::cast<DenseIntElementsAttr>(adaptor.getInput()).getValues<int64_t>());
auto type = RankedTensorType::get({static_cast<int64_t>(shape.size())},
builder.getIndexType());
return DenseIntElementsAttr::get(type, shape);
}
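
// Illustrative fold (assumed IR): a constant input folds to a dense index
// tensor:
//   %s = shape.const_shape [2, 3] : !shape.shape
//   %t = shape.to_extent_tensor %s : !shape.shape -> tensor<2xindex>
// folds %t to the dense constant [2, 3].
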
bool ToExtentTensorOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
if (inputs.size() != 1 || outputs.size() != 1)
return false;
if (auto inputTensor = llvm::dyn_cast<RankedTensorType>(inputs[0])) {
if (!llvm::isa<IndexType>(inputTensor.getElementType()) ||
inputTensor.getRank() != 1)
return false;
} else if (!llvm::isa<ShapeType>(inputs[0])) {
return false;
}
TensorType outputTensor = llvm::dyn_cast<TensorType>(outputs[0]);
return outputTensor && llvm::isa<IndexType>(outputTensor.getElementType());
}

//===----------------------------------------------------------------------===//
// ReduceOp
//===----------------------------------------------------------------------===//

void ReduceOp::build(OpBuilder &builder, OperationState &result, Value shape,
ValueRange initVals) {
result.addOperands(shape);
result.addOperands(initVals);
Region *bodyRegion = result.addRegion();
bodyRegion->push_back(new Block);
Block &bodyBlock = bodyRegion->front();
bodyBlock.addArgument(builder.getIndexType(), result.location);
Type elementType;
if (auto tensorType = llvm::dyn_cast<TensorType>(shape.getType()))
elementType = tensorType.getElementType();
else
elementType = SizeType::get(builder.getContext());
bodyBlock.addArgument(elementType, shape.getLoc());
for (Value initVal : initVals) {
bodyBlock.addArgument(initVal.getType(), initVal.getLoc());
result.addTypes(initVal.getType());
}
}
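
// Illustrative use of the block built above (assumed IR; %shape and %c1 are
// placeholder names): a reduction that multiplies all extents, i.e. computes
// the number of elements:
//   %r = shape.reduce(%shape, %c1) : tensor<3xindex> -> index {
//     ^bb0(%i: index, %extent: index, %acc: index):
//       %next = arith.muli %acc, %extent : index
//       shape.yield %next : index
//   }
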
LogicalResult ReduceOp::verify() {
// Verify block arg types.
Block &block = getRegion().front();
// The block takes index, extent, and aggregated values as arguments.
auto blockArgsCount = getInitVals().size() + 2;
if (block.getNumArguments() != blockArgsCount)
return emitOpError() << "ReduceOp body is expected to have "
<< blockArgsCount << " arguments";
// The first block argument is the index and must always be of type `index`.
if (!llvm::isa<IndexType>(block.getArgument(0).getType()))
return emitOpError(
"argument 0 of ReduceOp body is expected to be of IndexType");
// The second block argument is the extent and must be of type `size` or
// `index`, depending on whether the reduce operation is applied to a shape or
// to an extent tensor.
Type extentTy = block.getArgument(1).getType();
if (llvm::isa<ShapeType>(getShape().getType())) {
if (!llvm::isa<SizeType>(extentTy))
return emitOpError("argument 1 of ReduceOp body is expected to be of "
"SizeType if the ReduceOp operates on a ShapeType");
} else {
if (!llvm::isa<IndexType>(extentTy))
return emitOpError(
"argument 1 of ReduceOp body is expected to be of IndexType if the "
"ReduceOp operates on an extent tensor");
}
for (const auto &type : llvm::enumerate(getInitVals()))
if (block.getArgument(type.index() + 2).getType() != type.value().getType())
return emitOpError() << "type mismatch between argument "
<< type.index() + 2
<< " of ReduceOp body and initial value "
<< type.index();
return success();
}

ParseResult ReduceOp::parse(OpAsmParser &parser, OperationState &result) {
// Parse operands.
SmallVector<OpAsmParser::UnresolvedOperand, 3> operands;
Type shapeOrExtentTensorType;
if (parser.parseOperandList(operands, /*requiredOperandCount=*/-1,
OpAsmParser::Delimiter::Paren) ||
parser.parseColonType(shapeOrExtentTensorType) ||
parser.parseOptionalArrowTypeList(result.types))
return failure();
// Resolve operands.
auto initVals = llvm::ArrayRef(operands).drop_front();
if (parser.resolveOperand(operands.front(), shapeOrExtentTensorType,
result.operands) ||
parser.resolveOperands(initVals, result.types, parser.getNameLoc(),
result.operands))
return failure();
// Parse the body.
Region *body = result.addRegion();
if (parser.parseRegion(*body, /*args=*/{}, /*argTypes=*/{}))
return failure();
// Parse attributes.
if (parser.parseOptionalAttrDict(result.attributes))
return failure();
return success();
}

void ReduceOp::print(OpAsmPrinter &p) {
p << '(' << getShape() << ", " << getInitVals()
<< ") : " << getShape().getType();
p.printOptionalArrowTypeList(getResultTypes());
p << ' ';
p.printRegion(getRegion());
p.printOptionalAttrDict((*this)->getAttrs());
}

#define GET_OP_CLASSES
#include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc"
#define GET_TYPEDEF_CLASSES
#include "mlir/Dialect/Shape/IR/ShapeOpsTypes.cpp.inc"