[mlir][OpDSL] Rename PrimFn to ArithFn.

The revision renames `PrimFn` to `ArithFn`. The name resembles the newly introduced arith dialect that implements most of the arithmetic functions. An exception are log/exp that are part of the math dialect.

Depends On D115239

Reviewed By: stellaraccident

Differential Revision: https://reviews.llvm.org/D115240
This commit is contained in:
gysit 2022-01-07 12:37:52 +00:00
parent 15757ea80a
commit cf05668c17
11 changed files with 208 additions and 186 deletions

View File

@ -177,14 +177,20 @@ TODO: Introduce a directive to fix the dimension bindings.
Reduction dimensions are inferred to be any dimensions on the RHS that are not
on the LHS.
A number of arithmetic primitive functions are supported:
A number of arithmetic functions are supported:
* `PrimFn.add(a, b)` (also via overloading the binary `+` operator)
* `PrimFn.exp(a)`
* `PrimFn.log(a)`
* `PrimFn.mul(a, b)` (also via overloading the binary `*` operator)
* `PrimFn.max(a, b)`
* `PrimFn.sub(a, b)` (also via overloading the binary `-` operator)
* `ArithFn.add(a, b)` (also via overloading the binary `+` operator)
* `ArithFn.exp(a)`
* `ArithFn.log(a)`
* `ArithFn.mul(a, b)` (also via overloading the binary `*` operator)
* `ArithFn.max(a, b)`
* `ArithFn.min(a, b)`
* `ArithFn.sub(a, b)` (also via overloading the binary `-` operator)
* `ArithFn.max_unsigned(a, b)`
* `ArithFn.min_unsigned(a, b)`
As the integer types are signless, signedness is implement by different
functions that treat integers as signed or unsigned values.
Reduction functions can appear as the outer-most function on the RHS:
@ -233,6 +239,8 @@ The following examples illustrate the lowering of signed and unsigned functions:
* cast(F32 -> I32) -> `arith.FPToSIOp`
* cast_unsigned(I32 -> I64) -> `arith.ExtUIOp`
* cast_unsigned(F32 -> I32) -> `arith.FPToUIOp`
* max -> `arith.MaxSIOp`
* max_unsinged -> `arith.MaxUIOp`
Not all functions are applicable for all numeric types, and on mismatch, op
verification will fail.

View File

@ -41,13 +41,13 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarAssign
arg: C
value: !ScalarExpression
scalar_apply:
arith_fn:
fn_name: add
operands:
- !ScalarExpression
scalar_arg: C
- !ScalarExpression
scalar_apply:
arith_fn:
fn_name: mul
operands:
- !ScalarExpression
@ -105,13 +105,13 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarAssign
arg: C
value: !ScalarExpression
scalar_apply:
arith_fn:
fn_name: add
operands:
- !ScalarExpression
scalar_arg: C
- !ScalarExpression
scalar_apply:
arith_fn:
fn_name: mul
operands:
- !ScalarExpression
@ -179,17 +179,17 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarAssign
arg: C
value: !ScalarExpression
scalar_apply:
arith_fn:
fn_name: add
operands:
- !ScalarExpression
scalar_arg: C
- !ScalarExpression
scalar_apply:
arith_fn:
fn_name: mul
operands:
- !ScalarExpression
scalar_apply:
arith_fn:
fn_name: sub
operands:
- !ScalarExpression
@ -207,7 +207,7 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarExpression
scalar_arg: AZp
- !ScalarExpression
scalar_apply:
arith_fn:
fn_name: sub
operands:
- !ScalarExpression
@ -276,13 +276,13 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarAssign
arg: accum
value: !ScalarExpression
scalar_apply:
arith_fn:
fn_name: add
operands:
- !ScalarExpression
scalar_arg: accum
- !ScalarExpression
scalar_apply:
arith_fn:
fn_name: mul
operands:
- !ScalarExpression
@ -341,13 +341,13 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarAssign
arg: C
value: !ScalarExpression
scalar_apply:
arith_fn:
fn_name: add
operands:
- !ScalarExpression
scalar_arg: C
- !ScalarExpression
scalar_apply:
arith_fn:
fn_name: mul
operands:
- !ScalarExpression
@ -416,17 +416,17 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarAssign
arg: C
value: !ScalarExpression
scalar_apply:
arith_fn:
fn_name: add
operands:
- !ScalarExpression
scalar_arg: C
- !ScalarExpression
scalar_apply:
arith_fn:
fn_name: mul
operands:
- !ScalarExpression
scalar_apply:
arith_fn:
fn_name: sub
operands:
- !ScalarExpression
@ -444,7 +444,7 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarExpression
scalar_arg: AZp
- !ScalarExpression
scalar_apply:
arith_fn:
fn_name: sub
operands:
- !ScalarExpression
@ -501,13 +501,13 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarAssign
arg: x
value: !ScalarExpression
scalar_apply:
arith_fn:
fn_name: add
operands:
- !ScalarExpression
scalar_arg: x
- !ScalarExpression
scalar_apply:
arith_fn:
fn_name: mul
operands:
- !ScalarExpression
@ -564,13 +564,13 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarAssign
arg: x
value: !ScalarExpression
scalar_apply:
arith_fn:
fn_name: add
operands:
- !ScalarExpression
scalar_arg: x
- !ScalarExpression
scalar_apply:
arith_fn:
fn_name: mul
operands:
- !ScalarExpression
@ -628,13 +628,13 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarAssign
arg: C
value: !ScalarExpression
scalar_apply:
arith_fn:
fn_name: add
operands:
- !ScalarExpression
scalar_arg: C
- !ScalarExpression
scalar_apply:
arith_fn:
fn_name: mul
operands:
- !ScalarExpression
@ -690,13 +690,13 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarAssign
arg: C
value: !ScalarExpression
scalar_apply:
arith_fn:
fn_name: add
operands:
- !ScalarExpression
scalar_arg: C
- !ScalarExpression
scalar_apply:
arith_fn:
fn_name: mul
operands:
- !ScalarExpression
@ -753,13 +753,13 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarAssign
arg: O
value: !ScalarExpression
scalar_apply:
arith_fn:
fn_name: add
operands:
- !ScalarExpression
scalar_arg: O
- !ScalarExpression
scalar_apply:
arith_fn:
fn_name: mul
operands:
- !ScalarExpression
@ -818,13 +818,13 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarAssign
arg: O
value: !ScalarExpression
scalar_apply:
arith_fn:
fn_name: add
operands:
- !ScalarExpression
scalar_arg: O
- !ScalarExpression
scalar_apply:
arith_fn:
fn_name: mul
operands:
- !ScalarExpression
@ -886,13 +886,13 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarAssign
arg: O
value: !ScalarExpression
scalar_apply:
arith_fn:
fn_name: add
operands:
- !ScalarExpression
scalar_arg: O
- !ScalarExpression
scalar_apply:
arith_fn:
fn_name: mul
operands:
- !ScalarExpression
@ -964,13 +964,13 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarAssign
arg: O
value: !ScalarExpression
scalar_apply:
arith_fn:
fn_name: add
operands:
- !ScalarExpression
scalar_arg: O
- !ScalarExpression
scalar_apply:
arith_fn:
fn_name: mul
operands:
- !ScalarExpression
@ -1054,13 +1054,13 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarAssign
arg: O
value: !ScalarExpression
scalar_apply:
arith_fn:
fn_name: add
operands:
- !ScalarExpression
scalar_arg: O
- !ScalarExpression
scalar_apply:
arith_fn:
fn_name: mul
operands:
- !ScalarExpression
@ -1157,17 +1157,17 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarAssign
arg: O
value: !ScalarExpression
scalar_apply:
arith_fn:
fn_name: add
operands:
- !ScalarExpression
scalar_arg: O
- !ScalarExpression
scalar_apply:
arith_fn:
fn_name: mul
operands:
- !ScalarExpression
scalar_apply:
arith_fn:
fn_name: sub
operands:
- !ScalarExpression
@ -1185,7 +1185,7 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarExpression
scalar_arg: IZp
- !ScalarExpression
scalar_apply:
arith_fn:
fn_name: sub
operands:
- !ScalarExpression
@ -1269,13 +1269,13 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarAssign
arg: O
value: !ScalarExpression
scalar_apply:
arith_fn:
fn_name: add
operands:
- !ScalarExpression
scalar_arg: O
- !ScalarExpression
scalar_apply:
arith_fn:
fn_name: mul
operands:
- !ScalarExpression
@ -1359,13 +1359,13 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarAssign
arg: O
value: !ScalarExpression
scalar_apply:
arith_fn:
fn_name: add
operands:
- !ScalarExpression
scalar_arg: O
- !ScalarExpression
scalar_apply:
arith_fn:
fn_name: mul
operands:
- !ScalarExpression
@ -1436,13 +1436,13 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarAssign
arg: O
value: !ScalarExpression
scalar_apply:
arith_fn:
fn_name: add
operands:
- !ScalarExpression
scalar_arg: O
- !ScalarExpression
scalar_apply:
arith_fn:
fn_name: mul
operands:
- !ScalarExpression
@ -1519,13 +1519,13 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarAssign
arg: O
value: !ScalarExpression
scalar_apply:
arith_fn:
fn_name: add
operands:
- !ScalarExpression
scalar_arg: O
- !ScalarExpression
scalar_apply:
arith_fn:
fn_name: mul
operands:
- !ScalarExpression
@ -1613,17 +1613,17 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarAssign
arg: O
value: !ScalarExpression
scalar_apply:
arith_fn:
fn_name: add
operands:
- !ScalarExpression
scalar_arg: O
- !ScalarExpression
scalar_apply:
arith_fn:
fn_name: mul
operands:
- !ScalarExpression
scalar_apply:
arith_fn:
fn_name: sub
operands:
- !ScalarExpression
@ -1641,7 +1641,7 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarExpression
scalar_arg: IZp
- !ScalarExpression
scalar_apply:
arith_fn:
fn_name: sub
operands:
- !ScalarExpression
@ -1721,13 +1721,13 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarAssign
arg: O
value: !ScalarExpression
scalar_apply:
arith_fn:
fn_name: add
operands:
- !ScalarExpression
scalar_arg: O
- !ScalarExpression
scalar_apply:
arith_fn:
fn_name: mul
operands:
- !ScalarExpression
@ -1819,17 +1819,17 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarAssign
arg: O
value: !ScalarExpression
scalar_apply:
arith_fn:
fn_name: add
operands:
- !ScalarExpression
scalar_arg: O
- !ScalarExpression
scalar_apply:
arith_fn:
fn_name: mul
operands:
- !ScalarExpression
scalar_apply:
arith_fn:
fn_name: sub
operands:
- !ScalarExpression
@ -1847,7 +1847,7 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarExpression
scalar_arg: IZp
- !ScalarExpression
scalar_apply:
arith_fn:
fn_name: sub
operands:
- !ScalarExpression
@ -1923,7 +1923,7 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarAssign
arg: O
value: !ScalarExpression
scalar_apply:
arith_fn:
fn_name: add
operands:
- !ScalarExpression
@ -1994,7 +1994,7 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarAssign
arg: O
value: !ScalarExpression
scalar_apply:
arith_fn:
fn_name: max
operands:
- !ScalarExpression
@ -2065,7 +2065,7 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarAssign
arg: O
value: !ScalarExpression
scalar_apply:
arith_fn:
fn_name: max_unsigned
operands:
- !ScalarExpression
@ -2136,7 +2136,7 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarAssign
arg: O
value: !ScalarExpression
scalar_apply:
arith_fn:
fn_name: max
operands:
- !ScalarExpression
@ -2207,7 +2207,7 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarAssign
arg: O
value: !ScalarExpression
scalar_apply:
arith_fn:
fn_name: min
operands:
- !ScalarExpression
@ -2278,7 +2278,7 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarAssign
arg: O
value: !ScalarExpression
scalar_apply:
arith_fn:
fn_name: min_unsigned
operands:
- !ScalarExpression
@ -2355,7 +2355,7 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarAssign
arg: O
value: !ScalarExpression
scalar_apply:
arith_fn:
fn_name: add
operands:
- !ScalarExpression
@ -2432,7 +2432,7 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarAssign
arg: O
value: !ScalarExpression
scalar_apply:
arith_fn:
fn_name: max
operands:
- !ScalarExpression
@ -2509,7 +2509,7 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarAssign
arg: O
value: !ScalarExpression
scalar_apply:
arith_fn:
fn_name: min
operands:
- !ScalarExpression
@ -2572,15 +2572,15 @@ structured_op: !LinalgStructuredOpConfig
type_var: T
operands:
- !ScalarExpression
scalar_apply:
arith_fn:
fn_name: add
operands:
- !ScalarExpression
scalar_apply:
arith_fn:
fn_name: mul
operands:
- !ScalarExpression
scalar_apply:
arith_fn:
fn_name: add
operands:
- !ScalarExpression
@ -2596,15 +2596,15 @@ structured_op: !LinalgStructuredOpConfig
type_var: F64
operands:
- !ScalarExpression
scalar_apply:
arith_fn:
fn_name: add
operands:
- !ScalarExpression
scalar_apply:
arith_fn:
fn_name: mul
operands:
- !ScalarExpression
scalar_apply:
arith_fn:
fn_name: add
operands:
- !ScalarExpression
@ -2615,15 +2615,15 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarExpression
scalar_index: 1
- !ScalarExpression
scalar_apply:
arith_fn:
fn_name: add
operands:
- !ScalarExpression
scalar_apply:
arith_fn:
fn_name: mul
operands:
- !ScalarExpression
scalar_apply:
arith_fn:
fn_name: add
operands:
- !ScalarExpression
@ -2664,11 +2664,11 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarExpression
scalar_const: '12345 : i64'
- !ScalarExpression
scalar_apply:
arith_fn:
fn_name: mul
operands:
- !ScalarExpression
scalar_apply:
arith_fn:
fn_name: sub
operands:
- !ScalarExpression
@ -2716,11 +2716,11 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarAssign
arg: O
value: !ScalarExpression
scalar_apply:
arith_fn:
fn_name: log
operands:
- !ScalarExpression
scalar_apply:
arith_fn:
fn_name: add
operands:
- !ScalarExpression
@ -2731,7 +2731,7 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarExpression
scalar_const: '1.000000e+00 : f64'
- !ScalarExpression
scalar_apply:
arith_fn:
fn_name: exp
operands:
- !ScalarExpression

View File

@ -148,11 +148,11 @@ static LogicalResult foldMemRefCastInTiledLoopOp(TiledLoopOp op) {
// TODO: Move this to a utility library.
// The public methods on this class are referenced directly from generated code
// and bind by name to math and type conversion functions in the DSL as:
// `applyfn__{fnName}`
// `arithfn__{fnName}`
// `typefn__{fnName}`
// Examples:
// `applyfn__add`
// `applyfn__mul`
// `arithfn__add`
// `arithfn__mul`
// `typefn__cast`
// The naming convention is intentional in order to match snake-cased DSL names.
// See mlir-linalg-ods-yaml-gen.cpp for the code that mates to this class.
@ -241,7 +241,7 @@ public:
}
// NOLINTNEXTLINE(*-identifier-naming): externally called.
Value applyfn__add(Value lhs, Value rhs) {
Value arithfn__add(Value lhs, Value rhs) {
OpBuilder builder = getBuilder();
if (isFloatingPoint(lhs))
return builder.create<arith::AddFOp>(lhs.getLoc(), lhs, rhs);
@ -251,7 +251,7 @@ public:
}
// NOLINTNEXTLINE(*-identifier-naming): externally called.
Value applyfn__exp(Value x) {
Value arithfn__exp(Value x) {
OpBuilder builder = getBuilder();
if (isFloatingPoint(x))
return builder.create<math::ExpOp>(x.getLoc(), x);
@ -259,7 +259,7 @@ public:
}
// NOLINTNEXTLINE(*-identifier-naming): externally called.
Value applyfn__log(Value x) {
Value arithfn__log(Value x) {
OpBuilder builder = getBuilder();
if (isFloatingPoint(x))
return builder.create<math::LogOp>(x.getLoc(), x);
@ -267,7 +267,7 @@ public:
}
// NOLINTNEXTLINE(*-identifier-naming): externally called.
Value applyfn__sub(Value lhs, Value rhs) {
Value arithfn__sub(Value lhs, Value rhs) {
OpBuilder builder = getBuilder();
if (isFloatingPoint(lhs))
return builder.create<arith::SubFOp>(lhs.getLoc(), lhs, rhs);
@ -277,7 +277,7 @@ public:
}
// NOLINTNEXTLINE(*-identifier-naming): externally called.
Value applyfn__mul(Value lhs, Value rhs) {
Value arithfn__mul(Value lhs, Value rhs) {
OpBuilder builder = getBuilder();
if (isFloatingPoint(lhs))
return builder.create<arith::MulFOp>(lhs.getLoc(), lhs, rhs);
@ -287,7 +287,7 @@ public:
}
// NOLINTNEXTLINE(*-identifier-naming): externally called.
Value applyfn__max(Value lhs, Value rhs) {
Value arithfn__max(Value lhs, Value rhs) {
OpBuilder builder = getBuilder();
if (isFloatingPoint(lhs))
return builder.create<arith::MaxFOp>(lhs.getLoc(), lhs, rhs);
@ -297,7 +297,7 @@ public:
}
// NOLINTNEXTLINE(*-identifier-naming): externally called.
Value applyfn__max_unsigned(Value lhs, Value rhs) {
Value arithfn__max_unsigned(Value lhs, Value rhs) {
OpBuilder builder = getBuilder();
if (isFloatingPoint(lhs))
return builder.create<arith::MaxFOp>(lhs.getLoc(), lhs, rhs);
@ -307,7 +307,7 @@ public:
}
// NOLINTNEXTLINE(*-identifier-naming): externally called.
Value applyfn__min(Value lhs, Value rhs) {
Value arithfn__min(Value lhs, Value rhs) {
OpBuilder builder = getBuilder();
if (isFloatingPoint(lhs))
return builder.create<arith::MinFOp>(lhs.getLoc(), lhs, rhs);
@ -317,7 +317,7 @@ public:
}
// NOLINTNEXTLINE(*-identifier-naming): externally called.
Value applyfn__min_unsigned(Value lhs, Value rhs) {
Value arithfn__min_unsigned(Value lhs, Value rhs) {
OpBuilder builder = getBuilder();
if (isFloatingPoint(lhs))
return builder.create<arith::MinFOp>(lhs.getLoc(), lhs, rhs);

View File

@ -77,13 +77,13 @@ class TensorExpression:
self.visit_tensor_exprs(visit_scalar_def)
def __add__(self, rhs: "TensorExpression") -> "TensorExpression":
return PrimFn.add(self, rhs)
return ArithFn.add(self, rhs)
def __mul__(self, rhs) -> "TensorExpression":
return PrimFn.mul(self, rhs)
return ArithFn.mul(self, rhs)
def __sub__(self, rhs) -> "TensorExpression":
return PrimFn.sub(self, rhs)
return ArithFn.sub(self, rhs)
def __hash__(self):
return hash(id(self))
@ -347,42 +347,55 @@ class TypeFn:
cast_unsigned = TypeFnType("cast_unsigned")
class PrimFnType:
"""Primitive operations."""
class ArithFnType:
"""Arithmetic function.
def __init__(self, prim_name: str):
self.prim_name = prim_name
An arithmetic function takes one ore more tensor expressions and returns the
function evaluation result.
"""
def __call__(self, *args):
return PrimApply(self, args)
def __init__(self, fn_name: str):
self.fn_name = fn_name
def __call__(self, *args) -> "TensorArithFn":
return TensorArithFn(self, args)
def reduce(self, *reduce_dims: DimDef):
"""Shortcut to create a Reduce operation from this primitive."""
"""Shortcut to create a Reduce operation from this function."""
return ReduceFnType(self, *reduce_dims)
def __repr__(self):
return f"{self.prim_name}"
return f"{self.fn_name}"
class PrimFn:
add = PrimFnType("add")
exp = PrimFnType("exp")
log = PrimFnType("log")
mul = PrimFnType("mul")
max = PrimFnType("max")
min = PrimFnType("min")
sub = PrimFnType("sub")
max_unsigned = PrimFnType("max_unsigned")
min_unsigned = PrimFnType("min_unsigned")
class ArithFn:
"""Arithmetic function namespace.
As the integer types are signless, signedness is implement by different
functions that treat integers as signed or unsigned values.
Examples:
- max -> `arith.MaxSIOp`
- max_unsinged -> `arith.MaxUIOp`
"""
add = ArithFnType("add")
exp = ArithFnType("exp")
log = ArithFnType("log")
mul = ArithFnType("mul")
max = ArithFnType("max")
min = ArithFnType("min")
sub = ArithFnType("sub")
max_unsigned = ArithFnType("max_unsigned")
min_unsigned = ArithFnType("min_unsigned")
class ReduceFnType:
"""A reduction operator that reduces into its LHS from its RHS."""
def __init__(self, operator: PrimFnType, *reduce_dims: DimDef):
"""Initializes the ReduceFn with a primitive function and dims."""
if not isinstance(operator, PrimFnType):
raise ValueError(f"Reduce expected a Prim operator but got {operator}")
def __init__(self, operator: ArithFnType, *reduce_dims: DimDef):
"""Initializes the ReduceFn with an airthmetic function and dims."""
if not isinstance(operator, ArithFnType):
raise ValueError(f"Reduce expected a ArithFnType but got {operator}")
self.operator = operator
self.reduce_dims = tuple(reduce_dims)
@ -390,28 +403,28 @@ class ReduceFnType:
return ReduceApply(self, args)
def __repr__(self):
return (f"reduce_{self.operator.prim_name}"
return (f"reduce_{self.operator.fn_name}"
f"({', '.join(repr(d) for d in self.reduce_dims)})")
class ReduceFn:
add = PrimFn.add.reduce
mul = PrimFn.mul.reduce
max = PrimFn.max.reduce
min = PrimFn.min.reduce
max_unsigned = PrimFn.max_unsigned.reduce
min_unsigned = PrimFn.min_unsigned.reduce
add = ArithFn.add.reduce
mul = ArithFn.mul.reduce
max = ArithFn.max.reduce
min = ArithFn.min.reduce
max_unsigned = ArithFn.max_unsigned.reduce
min_unsigned = ArithFn.min_unsigned.reduce
class PrimApply(TensorExpression):
"""Application of a primitive."""
class TensorArithFn(TensorExpression):
"""Application of an arithmetic function."""
def __init__(self, prim: PrimFnType, args: Sequence[TensorExpression]):
self.prim = prim
def __init__(self, arith_fn: ArithFnType, args: Sequence[TensorExpression]):
self.arith_fn = arith_fn
self.args = tuple(args)
def to_scalar_expression(self) -> ScalarExpression:
return ScalarApplyFn(self.prim.prim_name,
return ScalarArithFn(self.arith_fn.fn_name,
*[arg.to_scalar_expression() for arg in self.args
]).expr()
@ -421,7 +434,7 @@ class PrimApply(TensorExpression):
arg.visit_tensor_exprs(callback)
def __repr__(self):
return f"{repr(self.prim)}({', '.join(repr(a) for a in self.args)})"
return f"{repr(self.arith_fn)}({', '.join(repr(a) for a in self.args)})"
class TensorTypeFn(TensorExpression):
@ -503,7 +516,7 @@ class ReduceApply(TensorExpression):
f"bound to its lhs: {self}")
full_args = [self.lhs.to_scalar_expression()
] + [arg.to_scalar_expression() for arg in self.args]
return ScalarApplyFn(self.reduce.operator.prim_name, *full_args).expr()
return ScalarArithFn(self.reduce.operator.fn_name, *full_args).expr()
def visit_tensor_exprs(self, callback):
for arg in self.args:

View File

@ -221,10 +221,10 @@ class _BodyBuilder:
dim_attr = IntegerAttr.get(
IntegerType.get_signless(64), expr.scalar_index.dim)
return linalg.IndexOp(dim_attr).result
elif expr.scalar_apply:
fn = self._get_function(f"_eval_{expr.scalar_apply.fn_name}")
elif expr.arith_fn:
fn = self._get_function(f"_arithfn_{expr.arith_fn.fn_name}")
operand_values = [
self.expression(operand) for operand in expr.scalar_apply.operands
self.expression(operand) for operand in expr.arith_fn.operands
]
return fn(*operand_values)
elif expr.type_fn:
@ -310,59 +310,59 @@ class _BodyBuilder:
def _typefn_cast_unsigned(self, type_var_name: str, operand: Value) -> Value:
return self._cast(type_var_name, operand, True)
def _eval_add(self, lhs: Value, rhs: Value) -> Value:
def _arithfn_add(self, lhs: Value, rhs: Value) -> Value:
if _is_floating_point_type(lhs.type):
return arith.AddFOp(lhs, rhs).result
if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
return arith.AddIOp(lhs, rhs).result
raise NotImplementedError("Unsupported 'add' operand: {lhs}")
def _eval_exp(self, x: Value) -> Value:
def _arithfn_exp(self, x: Value) -> Value:
if _is_floating_point_type(x.type):
return math.ExpOp(x).result
raise NotImplementedError("Unsupported 'exp' operand: {x}")
def _eval_log(self, x: Value) -> Value:
def _arithfn_log(self, x: Value) -> Value:
if _is_floating_point_type(x.type):
return math.LogOp(x).result
raise NotImplementedError("Unsupported 'log' operand: {x}")
def _eval_sub(self, lhs: Value, rhs: Value) -> Value:
def _arithfn_sub(self, lhs: Value, rhs: Value) -> Value:
if _is_floating_point_type(lhs.type):
return arith.SubFOp(lhs, rhs).result
if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
return arith.SubIOp(lhs, rhs).result
raise NotImplementedError("Unsupported 'sub' operand: {lhs}")
def _eval_mul(self, lhs: Value, rhs: Value) -> Value:
def _arithfn_mul(self, lhs: Value, rhs: Value) -> Value:
if _is_floating_point_type(lhs.type):
return arith.MulFOp(lhs, rhs).result
if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
return arith.MulIOp(lhs, rhs).result
raise NotImplementedError("Unsupported 'mul' operand: {lhs}")
def _eval_max(self, lhs: Value, rhs: Value) -> Value:
def _arithfn_max(self, lhs: Value, rhs: Value) -> Value:
if _is_floating_point_type(lhs.type):
return arith.MaxFOp(lhs, rhs).result
if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
return arith.MaxSIOp(lhs, rhs).result
raise NotImplementedError("Unsupported 'max' operand: {lhs}")
def _eval_max_unsigned(self, lhs: Value, rhs: Value) -> Value:
def _arithfn_max_unsigned(self, lhs: Value, rhs: Value) -> Value:
if _is_floating_point_type(lhs.type):
return arith.MaxFOp(lhs, rhs).result
if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
return arith.MaxUIOp(lhs, rhs).result
raise NotImplementedError("Unsupported 'max_unsigned' operand: {lhs}")
def _eval_min(self, lhs: Value, rhs: Value) -> Value:
def _arithfn_min(self, lhs: Value, rhs: Value) -> Value:
if _is_floating_point_type(lhs.type):
return arith.MinFOp(lhs, rhs).result
if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
return arith.MinSIOp(lhs, rhs).result
raise NotImplementedError("Unsupported 'min' operand: {lhs}")
def _eval_min_unsigned(self, lhs: Value, rhs: Value) -> Value:
def _arithfn_min_unsigned(self, lhs: Value, rhs: Value) -> Value:
if _is_floating_point_type(lhs.type):
return arith.MinFOp(lhs, rhs).result
if _is_integer_type(lhs.type) or _is_index_type(lhs.type):

View File

@ -20,7 +20,7 @@ from .types import *
__all__ = [
"ScalarAssign",
"ScalarApplyFn",
"ScalarArithFn",
"ScalarTypeFn",
"ScalarArg",
"ScalarConst",
@ -29,18 +29,18 @@ __all__ = [
]
class ScalarApplyFn:
"""A type of ScalarExpression that applies a named function to operands."""
class ScalarArithFn:
"""A type of ScalarExpression that applies an arithmetic function."""
def __init__(self, fn_name: str, *operands: "ScalarExpression"):
self.fn_name = fn_name
self.operands = operands
def expr(self) -> "ScalarExpression":
return ScalarExpression(scalar_apply=self)
return ScalarExpression(arith_fn=self)
def __repr__(self):
return f"ScalarApplyFn<{self.fn_name}>({', '.join(self.operands)})"
return f"ScalarArithFn<{self.fn_name}>({', '.join(self.operands)})"
class ScalarTypeFn:
@ -102,7 +102,7 @@ class ScalarExpression(YAMLObject):
"""An expression on scalar values.
Can be one of:
- ScalarApplyFn
- ScalarArithFn
- ScalarTypeFn
- ScalarArg
- ScalarConst
@ -112,27 +112,27 @@ class ScalarExpression(YAMLObject):
yaml_tag = "!ScalarExpression"
def __init__(self,
scalar_apply: Optional[ScalarApplyFn] = None,
arith_fn: Optional[ScalarArithFn] = None,
type_fn: Optional[ScalarTypeFn] = None,
scalar_arg: Optional[ScalarArg] = None,
scalar_const: Optional[ScalarConst] = None,
scalar_index: Optional[ScalarIndex] = None):
if (bool(scalar_apply) + bool(type_fn) + bool(scalar_arg) +
bool(scalar_const) + bool(scalar_index)) != 1:
raise ValueError("One of 'scalar_apply', 'type_fn', 'scalar_arg', "
if (bool(arith_fn) + bool(type_fn) + bool(scalar_arg) + bool(scalar_const) +
bool(scalar_index)) != 1:
raise ValueError("One of 'arith_fn', 'type_fn', 'scalar_arg', "
"'scalar_const', 'scalar_index', must be specified")
self.scalar_apply = scalar_apply
self.arith_fn = arith_fn
self.type_fn = type_fn
self.scalar_arg = scalar_arg
self.scalar_const = scalar_const
self.scalar_index = scalar_index
def to_yaml_custom_dict(self):
if self.scalar_apply:
if self.arith_fn:
return dict(
scalar_apply=dict(
fn_name=self.scalar_apply.fn_name,
operands=list(self.scalar_apply.operands),
arith_fn=dict(
fn_name=self.arith_fn.fn_name,
operands=list(self.arith_fn.operands),
))
if self.type_fn:
# Note that even though operands must be arity 1, we write it the

View File

@ -665,4 +665,4 @@ def soft_plus_2d(
"""
domain(D.m, D.n)
O[D.m, D.n] = \
PrimFn.log(TypeFn.cast(U, const(1.0)) + PrimFn.exp(TypeFn.cast(U, I[D.m, D.n])))
ArithFn.log(TypeFn.cast(U, const(1.0)) + ArithFn.exp(TypeFn.cast(U, I[D.m, D.n])))

View File

@ -34,7 +34,7 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarAssign
arg: O
value: !ScalarExpression
scalar_apply:
arith_fn:
fn_name: add
operands:
- !ScalarExpression
@ -89,7 +89,7 @@ structured_op: !LinalgStructuredOpConfig
# IMPL-DAG: Value [[VAL1:[a-z0-9]+]] = helper.typefn__cast(block.getArgument(0).getType(), [[VAL0]]);
# IMPL-DAG: Value [[VAL2:[a-z0-9]+]] = helper.index(1);
# IMPL-DAG: Value [[VAL3:[a-z0-9]+]] = helper.typefn__cast_unsigned(block.getArgument(0).getType(), [[VAL2]]);
# IMPL-DAG: Value [[VAL4:[a-z0-9]+]] = helper.applyfn__add([[VAL1]], [[VAL3]]);
# IMPL-DAG: Value [[VAL4:[a-z0-9]+]] = helper.arithfn__add([[VAL1]], [[VAL3]]);
# @linalg_structured_op

View File

@ -9,10 +9,10 @@ from mlir.dialects.linalg.opdsl.lang import *
# CHECK: -
# CHECK: arg: C
# CHECK: value:
# CHECK: scalar_apply:
# CHECK: arith_fn:
# CHECK: fn_name: add
# CHECK: operands:
# CHECK: scalar_apply:
# CHECK: arith_fn:
# CHECK: fn_name: mul
# CHECK: operands:
# CHECK: type_fn:
@ -36,10 +36,10 @@ def matmul(
# CHECK: assignments:
# CHECK: -
# CHECK: arg: O
# CHECK: scalar_apply:
# CHECK: arith_fn:
# CHECK: fn_name: sub
# CHECK: operands:
# CHECK: scalar_apply:
# CHECK: arith_fn:
# CHECK: fn_name: add
# CHECK: operands:
# CHECK: type_fn:
@ -67,7 +67,7 @@ def constants(O=TensorDef(T, S.M, S.K, output=True)):
# CHECK: assignments:
# CHECK: -
# CHECK: arg: O
# CHECK: scalar_apply:
# CHECK: arith_fn:
# CHECK: fn_name: add
# CHECK: operands:
# CHECK: scalar_index: 1

View File

@ -35,8 +35,8 @@ def fill_rng_poly(
@linalg_structured_op
def soft_plus_poly(
I=TensorDef(T, S.M, S.N), O=TensorDef(U, S.M, S.N, output=True)):
O[D.m, D.n] = PrimFn.log(
TypeFn.cast(U, const(1.0)) + TypeFn.cast(U, PrimFn.exp(I[D.m, D.n])))
O[D.m, D.n] = ArithFn.log(
TypeFn.cast(U, const(1.0)) + TypeFn.cast(U, ArithFn.exp(I[D.m, D.n])))
@linalg_structured_op(op_name="custom_op_name")

View File

@ -82,7 +82,7 @@ struct LinalgIndexingMapsConfig {
struct ScalarExpression;
struct ScalarApply {
struct ScalarArithFn {
std::string fnName;
// NOTE: Must be pure heap allocated container (not SmallVector)
// due to recursive data type.
@ -101,7 +101,7 @@ struct ScalarExpression {
Optional<std::string> arg;
Optional<std::string> constant;
Optional<int64_t> index;
Optional<ScalarApply> apply;
Optional<ScalarArithFn> arithFn;
Optional<ScalarTypeFn> typeFn;
};
@ -245,9 +245,10 @@ struct MappingTraits<ScalarAssign> {
};
/// A scalar expression (RHS of an assignment). Must be one of:
/// - `scalar_arg`: Name of an argument to the op.
/// - `scalar_apply`: Result of evaluating a named function (see
/// `ScalarApply`).
/// - `scalar_arg`: An operation argument.
/// - `scalar_const`: A constant definition.
/// - `scalar_index`: An iteration index.
/// - `arith_fn`: A named arithmetic function (see `ScalarArithFn`).
/// - `type_fn`: A named type conversion function (see `ScalarTypeFn`).
template <>
struct MappingTraits<ScalarExpression> {
@ -255,7 +256,7 @@ struct MappingTraits<ScalarExpression> {
io.mapOptional("scalar_arg", info.arg);
io.mapOptional("scalar_const", info.constant);
io.mapOptional("scalar_index", info.index);
io.mapOptional("scalar_apply", info.apply);
io.mapOptional("arith_fn", info.arithFn);
io.mapOptional("type_fn", info.typeFn);
}
};
@ -266,8 +267,8 @@ struct MappingTraits<ScalarExpression> {
/// - `add(lhs, rhs)`
/// - `mul(lhs, rhs)`
template <>
struct MappingTraits<ScalarApply> {
static void mapping(IO &io, ScalarApply &info) {
struct MappingTraits<ScalarArithFn> {
static void mapping(IO &io, ScalarArithFn &info) {
io.mapRequired("fn_name", info.fnName);
io.mapRequired("operands", info.operands);
}
@ -944,11 +945,11 @@ void {0}::regionBuilder(ImplicitLocOpBuilder &b, Block &block) {{
cppIdent, *expression.index));
return cppIdent;
}
if (expression.apply) {
if (expression.arithFn) {
// Apply function.
// Recursively generate operands.
SmallVector<std::string> operandCppValues;
for (ScalarExpression &operand : expression.apply->operands) {
for (ScalarExpression &operand : expression.arithFn->operands) {
auto operandCppValue = generateExpression(operand);
if (!operandCppValue)
return None;
@ -956,8 +957,8 @@ void {0}::regionBuilder(ImplicitLocOpBuilder &b, Block &block) {{
}
std::string cppIdent = llvm::formatv("value{0}", ++localCounter);
stmts.push_back(
llvm::formatv("Value {0} = helper.applyfn__{1}({2});", cppIdent,
expression.apply->fnName,
llvm::formatv("Value {0} = helper.arithfn__{1}({2});", cppIdent,
expression.arithFn->fnName,
interleaveToString(operandCppValues, ", ")));
return cppIdent;
}