llvm-project/mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td
River Riddle 1d7120c69a [mlir] Split out AttrDef/TypeDef and pattern constructs from OpBase.td
OpBase.td has formed into a huge monolith of all ODS constructs. This
commit starts to rectify that by splitting out some constructs into their
own .td files.

Differential Revision: https://reviews.llvm.org/D118636
2022-03-15 00:18:03 -07:00

55 lines
1.7 KiB
TableGen

include "mlir/IR/PatternBase.td"
include "mlir/Dialect/Shape/IR/ShapeOps.td"
include "mlir/Dialect/Tensor/IR/TensorOps.td"
// Constraint: every value in the range bound to $0 compares equal to the
// first value of that range.
def AllInputShapesEq : Constraint<CPred<[{
  llvm::all_of($0, [&](mlir::Value shape) { return shape == $0[0]; })
}]>>;
// Constraint: the range bound to $0 holds exactly one element.
def HasSingleElement : Constraint<CPred<"$0.size() == 1">>;
// Constraint: the type of the value bound to $0 is a ShapedType whose shape
// is fully static.
//
// `cast` is used rather than `dyn_cast` because the result is dereferenced
// unconditionally: a failed `dyn_cast` would yield a null type handle and
// `hasStaticShape()` would then be called on it. `cast` makes the
// "must be a ShapedType" precondition explicit (asserted in debug builds).
def HasStaticShape : Constraint<CPred<[{
  $0.getType().cast<mlir::ShapedType>().hasStaticShape()
}]>>;
// Native helper: yields `front()` of the range bound to $0, i.e. its first
// element.
def TakeFront : NativeCodeCall<[{ $0.front() }]>;
// Canonicalization patterns.
// A Shape_AssumingAllOp whose operand list has exactly one element is
// replaced by that operand.
def AssumingAllOneOp : Pat<
  (Shape_AssumingAllOp $args),
  (replaceWithValue $args),
  [(HasSingleElement $args)]>;
// A Shape_CstrBroadcastableOp whose operands are all the same value is
// replaced by a Shape_ConstWitnessOp carrying `true`.
def CstrBroadcastableEqOps : Pat<
  (Shape_CstrBroadcastableOp:$op $shapes),
  (Shape_ConstWitnessOp ConstBoolAttrTrue),
  [(AllInputShapesEq $shapes)]>;
// A Shape_CstrEqOp whose operands are all the same value is replaced by a
// Shape_ConstWitnessOp carrying `true`.
def CstrEqEqOps : Pat<
  (Shape_CstrEqOp:$op $shapes),
  (Shape_ConstWitnessOp ConstBoolAttrTrue),
  [(AllInputShapesEq $shapes)]>;
// Shape_SizeToIndexOp applied to the result of Shape_IndexToSizeOp is a
// round trip; replace it with the original value $arg.
def IndexToSizeToIndexCanonicalization : Pat<
  (Shape_SizeToIndexOp
    (Shape_IndexToSizeOp $arg)),
  (replaceWithValue $arg)>;
// Shape_IndexToSizeOp applied to the result of Shape_SizeToIndexOp is a
// round trip; replace it with the original value $arg.
// NOTE(review): this presumes $arg already has the type produced by the
// outer Shape_IndexToSizeOp -- confirm against the operand types accepted by
// Shape_SizeToIndexOp.
def SizeToIndexToSizeCanonicalization : Pat<
  (Shape_IndexToSizeOp
    (Shape_SizeToIndexOp $arg)),
  (replaceWithValue $arg)>;
// Collapse tensor.cast(const_shape) into a single const_shape. The new
// const_shape takes on the cast's destination type, so the rewrite only
// applies when the cast result has a static shape.
def TensorCastConstShape : Pat<
  (Tensor_CastOp:$res
    (Shape_ConstShapeOp $arg)),
  (Shape_ConstShapeOp $arg),
  [(HasStaticShape $res)]>;
// Rewrite tensor.extract(shape_of(arg), indices) as tensor.dim(arg, index).
// shape_of always produces a 1D tensor, so the first entry of $indices is
// the dimension being queried.
def ExtractFromShapeOfExtentTensor : Pat<
  (Tensor_ExtractOp
    (Shape_ShapeOfOp $arg),
    $indices),
  (Tensor_DimOp
    $arg,
    (TakeFront $indices))>;