[mlir][OpDSL] Rename PrimFn
to ArithFn
.
The revision renames `PrimFn` to `ArithFn`. The name resembles the newly introduced arith dialect that implements most of the arithmetic functions. An exception are log/exp that are part of the math dialect. Depends On D115239 Reviewed By: stellaraccident Differential Revision: https://reviews.llvm.org/D115240
This commit is contained in:
parent
15757ea80a
commit
cf05668c17
@ -177,14 +177,20 @@ TODO: Introduce a directive to fix the dimension bindings.
|
||||
Reduction dimensions are inferred to be any dimensions on the RHS that are not
|
||||
on the LHS.
|
||||
|
||||
A number of arithmetic primitive functions are supported:
|
||||
A number of arithmetic functions are supported:
|
||||
|
||||
* `PrimFn.add(a, b)` (also via overloading the binary `+` operator)
|
||||
* `PrimFn.exp(a)`
|
||||
* `PrimFn.log(a)`
|
||||
* `PrimFn.mul(a, b)` (also via overloading the binary `*` operator)
|
||||
* `PrimFn.max(a, b)`
|
||||
* `PrimFn.sub(a, b)` (also via overloading the binary `-` operator)
|
||||
* `ArithFn.add(a, b)` (also via overloading the binary `+` operator)
|
||||
* `ArithFn.exp(a)`
|
||||
* `ArithFn.log(a)`
|
||||
* `ArithFn.mul(a, b)` (also via overloading the binary `*` operator)
|
||||
* `ArithFn.max(a, b)`
|
||||
* `ArithFn.min(a, b)`
|
||||
* `ArithFn.sub(a, b)` (also via overloading the binary `-` operator)
|
||||
* `ArithFn.max_unsigned(a, b)`
|
||||
* `ArithFn.min_unsigned(a, b)`
|
||||
|
||||
As the integer types are signless, signedness is implement by different
|
||||
functions that treat integers as signed or unsigned values.
|
||||
|
||||
Reduction functions can appear as the outer-most function on the RHS:
|
||||
|
||||
@ -233,6 +239,8 @@ The following examples illustrate the lowering of signed and unsigned functions:
|
||||
* cast(F32 -> I32) -> `arith.FPToSIOp`
|
||||
* cast_unsigned(I32 -> I64) -> `arith.ExtUIOp`
|
||||
* cast_unsigned(F32 -> I32) -> `arith.FPToUIOp`
|
||||
* max -> `arith.MaxSIOp`
|
||||
* max_unsinged -> `arith.MaxUIOp`
|
||||
|
||||
Not all functions are applicable for all numeric types, and on mismatch, op
|
||||
verification will fail.
|
||||
|
@ -41,13 +41,13 @@ structured_op: !LinalgStructuredOpConfig
|
||||
- !ScalarAssign
|
||||
arg: C
|
||||
value: !ScalarExpression
|
||||
scalar_apply:
|
||||
arith_fn:
|
||||
fn_name: add
|
||||
operands:
|
||||
- !ScalarExpression
|
||||
scalar_arg: C
|
||||
- !ScalarExpression
|
||||
scalar_apply:
|
||||
arith_fn:
|
||||
fn_name: mul
|
||||
operands:
|
||||
- !ScalarExpression
|
||||
@ -105,13 +105,13 @@ structured_op: !LinalgStructuredOpConfig
|
||||
- !ScalarAssign
|
||||
arg: C
|
||||
value: !ScalarExpression
|
||||
scalar_apply:
|
||||
arith_fn:
|
||||
fn_name: add
|
||||
operands:
|
||||
- !ScalarExpression
|
||||
scalar_arg: C
|
||||
- !ScalarExpression
|
||||
scalar_apply:
|
||||
arith_fn:
|
||||
fn_name: mul
|
||||
operands:
|
||||
- !ScalarExpression
|
||||
@ -179,17 +179,17 @@ structured_op: !LinalgStructuredOpConfig
|
||||
- !ScalarAssign
|
||||
arg: C
|
||||
value: !ScalarExpression
|
||||
scalar_apply:
|
||||
arith_fn:
|
||||
fn_name: add
|
||||
operands:
|
||||
- !ScalarExpression
|
||||
scalar_arg: C
|
||||
- !ScalarExpression
|
||||
scalar_apply:
|
||||
arith_fn:
|
||||
fn_name: mul
|
||||
operands:
|
||||
- !ScalarExpression
|
||||
scalar_apply:
|
||||
arith_fn:
|
||||
fn_name: sub
|
||||
operands:
|
||||
- !ScalarExpression
|
||||
@ -207,7 +207,7 @@ structured_op: !LinalgStructuredOpConfig
|
||||
- !ScalarExpression
|
||||
scalar_arg: AZp
|
||||
- !ScalarExpression
|
||||
scalar_apply:
|
||||
arith_fn:
|
||||
fn_name: sub
|
||||
operands:
|
||||
- !ScalarExpression
|
||||
@ -276,13 +276,13 @@ structured_op: !LinalgStructuredOpConfig
|
||||
- !ScalarAssign
|
||||
arg: accum
|
||||
value: !ScalarExpression
|
||||
scalar_apply:
|
||||
arith_fn:
|
||||
fn_name: add
|
||||
operands:
|
||||
- !ScalarExpression
|
||||
scalar_arg: accum
|
||||
- !ScalarExpression
|
||||
scalar_apply:
|
||||
arith_fn:
|
||||
fn_name: mul
|
||||
operands:
|
||||
- !ScalarExpression
|
||||
@ -341,13 +341,13 @@ structured_op: !LinalgStructuredOpConfig
|
||||
- !ScalarAssign
|
||||
arg: C
|
||||
value: !ScalarExpression
|
||||
scalar_apply:
|
||||
arith_fn:
|
||||
fn_name: add
|
||||
operands:
|
||||
- !ScalarExpression
|
||||
scalar_arg: C
|
||||
- !ScalarExpression
|
||||
scalar_apply:
|
||||
arith_fn:
|
||||
fn_name: mul
|
||||
operands:
|
||||
- !ScalarExpression
|
||||
@ -416,17 +416,17 @@ structured_op: !LinalgStructuredOpConfig
|
||||
- !ScalarAssign
|
||||
arg: C
|
||||
value: !ScalarExpression
|
||||
scalar_apply:
|
||||
arith_fn:
|
||||
fn_name: add
|
||||
operands:
|
||||
- !ScalarExpression
|
||||
scalar_arg: C
|
||||
- !ScalarExpression
|
||||
scalar_apply:
|
||||
arith_fn:
|
||||
fn_name: mul
|
||||
operands:
|
||||
- !ScalarExpression
|
||||
scalar_apply:
|
||||
arith_fn:
|
||||
fn_name: sub
|
||||
operands:
|
||||
- !ScalarExpression
|
||||
@ -444,7 +444,7 @@ structured_op: !LinalgStructuredOpConfig
|
||||
- !ScalarExpression
|
||||
scalar_arg: AZp
|
||||
- !ScalarExpression
|
||||
scalar_apply:
|
||||
arith_fn:
|
||||
fn_name: sub
|
||||
operands:
|
||||
- !ScalarExpression
|
||||
@ -501,13 +501,13 @@ structured_op: !LinalgStructuredOpConfig
|
||||
- !ScalarAssign
|
||||
arg: x
|
||||
value: !ScalarExpression
|
||||
scalar_apply:
|
||||
arith_fn:
|
||||
fn_name: add
|
||||
operands:
|
||||
- !ScalarExpression
|
||||
scalar_arg: x
|
||||
- !ScalarExpression
|
||||
scalar_apply:
|
||||
arith_fn:
|
||||
fn_name: mul
|
||||
operands:
|
||||
- !ScalarExpression
|
||||
@ -564,13 +564,13 @@ structured_op: !LinalgStructuredOpConfig
|
||||
- !ScalarAssign
|
||||
arg: x
|
||||
value: !ScalarExpression
|
||||
scalar_apply:
|
||||
arith_fn:
|
||||
fn_name: add
|
||||
operands:
|
||||
- !ScalarExpression
|
||||
scalar_arg: x
|
||||
- !ScalarExpression
|
||||
scalar_apply:
|
||||
arith_fn:
|
||||
fn_name: mul
|
||||
operands:
|
||||
- !ScalarExpression
|
||||
@ -628,13 +628,13 @@ structured_op: !LinalgStructuredOpConfig
|
||||
- !ScalarAssign
|
||||
arg: C
|
||||
value: !ScalarExpression
|
||||
scalar_apply:
|
||||
arith_fn:
|
||||
fn_name: add
|
||||
operands:
|
||||
- !ScalarExpression
|
||||
scalar_arg: C
|
||||
- !ScalarExpression
|
||||
scalar_apply:
|
||||
arith_fn:
|
||||
fn_name: mul
|
||||
operands:
|
||||
- !ScalarExpression
|
||||
@ -690,13 +690,13 @@ structured_op: !LinalgStructuredOpConfig
|
||||
- !ScalarAssign
|
||||
arg: C
|
||||
value: !ScalarExpression
|
||||
scalar_apply:
|
||||
arith_fn:
|
||||
fn_name: add
|
||||
operands:
|
||||
- !ScalarExpression
|
||||
scalar_arg: C
|
||||
- !ScalarExpression
|
||||
scalar_apply:
|
||||
arith_fn:
|
||||
fn_name: mul
|
||||
operands:
|
||||
- !ScalarExpression
|
||||
@ -753,13 +753,13 @@ structured_op: !LinalgStructuredOpConfig
|
||||
- !ScalarAssign
|
||||
arg: O
|
||||
value: !ScalarExpression
|
||||
scalar_apply:
|
||||
arith_fn:
|
||||
fn_name: add
|
||||
operands:
|
||||
- !ScalarExpression
|
||||
scalar_arg: O
|
||||
- !ScalarExpression
|
||||
scalar_apply:
|
||||
arith_fn:
|
||||
fn_name: mul
|
||||
operands:
|
||||
- !ScalarExpression
|
||||
@ -818,13 +818,13 @@ structured_op: !LinalgStructuredOpConfig
|
||||
- !ScalarAssign
|
||||
arg: O
|
||||
value: !ScalarExpression
|
||||
scalar_apply:
|
||||
arith_fn:
|
||||
fn_name: add
|
||||
operands:
|
||||
- !ScalarExpression
|
||||
scalar_arg: O
|
||||
- !ScalarExpression
|
||||
scalar_apply:
|
||||
arith_fn:
|
||||
fn_name: mul
|
||||
operands:
|
||||
- !ScalarExpression
|
||||
@ -886,13 +886,13 @@ structured_op: !LinalgStructuredOpConfig
|
||||
- !ScalarAssign
|
||||
arg: O
|
||||
value: !ScalarExpression
|
||||
scalar_apply:
|
||||
arith_fn:
|
||||
fn_name: add
|
||||
operands:
|
||||
- !ScalarExpression
|
||||
scalar_arg: O
|
||||
- !ScalarExpression
|
||||
scalar_apply:
|
||||
arith_fn:
|
||||
fn_name: mul
|
||||
operands:
|
||||
- !ScalarExpression
|
||||
@ -964,13 +964,13 @@ structured_op: !LinalgStructuredOpConfig
|
||||
- !ScalarAssign
|
||||
arg: O
|
||||
value: !ScalarExpression
|
||||
scalar_apply:
|
||||
arith_fn:
|
||||
fn_name: add
|
||||
operands:
|
||||
- !ScalarExpression
|
||||
scalar_arg: O
|
||||
- !ScalarExpression
|
||||
scalar_apply:
|
||||
arith_fn:
|
||||
fn_name: mul
|
||||
operands:
|
||||
- !ScalarExpression
|
||||
@ -1054,13 +1054,13 @@ structured_op: !LinalgStructuredOpConfig
|
||||
- !ScalarAssign
|
||||
arg: O
|
||||
value: !ScalarExpression
|
||||
scalar_apply:
|
||||
arith_fn:
|
||||
fn_name: add
|
||||
operands:
|
||||
- !ScalarExpression
|
||||
scalar_arg: O
|
||||
- !ScalarExpression
|
||||
scalar_apply:
|
||||
arith_fn:
|
||||
fn_name: mul
|
||||
operands:
|
||||
- !ScalarExpression
|
||||
@ -1157,17 +1157,17 @@ structured_op: !LinalgStructuredOpConfig
|
||||
- !ScalarAssign
|
||||
arg: O
|
||||
value: !ScalarExpression
|
||||
scalar_apply:
|
||||
arith_fn:
|
||||
fn_name: add
|
||||
operands:
|
||||
- !ScalarExpression
|
||||
scalar_arg: O
|
||||
- !ScalarExpression
|
||||
scalar_apply:
|
||||
arith_fn:
|
||||
fn_name: mul
|
||||
operands:
|
||||
- !ScalarExpression
|
||||
scalar_apply:
|
||||
arith_fn:
|
||||
fn_name: sub
|
||||
operands:
|
||||
- !ScalarExpression
|
||||
@ -1185,7 +1185,7 @@ structured_op: !LinalgStructuredOpConfig
|
||||
- !ScalarExpression
|
||||
scalar_arg: IZp
|
||||
- !ScalarExpression
|
||||
scalar_apply:
|
||||
arith_fn:
|
||||
fn_name: sub
|
||||
operands:
|
||||
- !ScalarExpression
|
||||
@ -1269,13 +1269,13 @@ structured_op: !LinalgStructuredOpConfig
|
||||
- !ScalarAssign
|
||||
arg: O
|
||||
value: !ScalarExpression
|
||||
scalar_apply:
|
||||
arith_fn:
|
||||
fn_name: add
|
||||
operands:
|
||||
- !ScalarExpression
|
||||
scalar_arg: O
|
||||
- !ScalarExpression
|
||||
scalar_apply:
|
||||
arith_fn:
|
||||
fn_name: mul
|
||||
operands:
|
||||
- !ScalarExpression
|
||||
@ -1359,13 +1359,13 @@ structured_op: !LinalgStructuredOpConfig
|
||||
- !ScalarAssign
|
||||
arg: O
|
||||
value: !ScalarExpression
|
||||
scalar_apply:
|
||||
arith_fn:
|
||||
fn_name: add
|
||||
operands:
|
||||
- !ScalarExpression
|
||||
scalar_arg: O
|
||||
- !ScalarExpression
|
||||
scalar_apply:
|
||||
arith_fn:
|
||||
fn_name: mul
|
||||
operands:
|
||||
- !ScalarExpression
|
||||
@ -1436,13 +1436,13 @@ structured_op: !LinalgStructuredOpConfig
|
||||
- !ScalarAssign
|
||||
arg: O
|
||||
value: !ScalarExpression
|
||||
scalar_apply:
|
||||
arith_fn:
|
||||
fn_name: add
|
||||
operands:
|
||||
- !ScalarExpression
|
||||
scalar_arg: O
|
||||
- !ScalarExpression
|
||||
scalar_apply:
|
||||
arith_fn:
|
||||
fn_name: mul
|
||||
operands:
|
||||
- !ScalarExpression
|
||||
@ -1519,13 +1519,13 @@ structured_op: !LinalgStructuredOpConfig
|
||||
- !ScalarAssign
|
||||
arg: O
|
||||
value: !ScalarExpression
|
||||
scalar_apply:
|
||||
arith_fn:
|
||||
fn_name: add
|
||||
operands:
|
||||
- !ScalarExpression
|
||||
scalar_arg: O
|
||||
- !ScalarExpression
|
||||
scalar_apply:
|
||||
arith_fn:
|
||||
fn_name: mul
|
||||
operands:
|
||||
- !ScalarExpression
|
||||
@ -1613,17 +1613,17 @@ structured_op: !LinalgStructuredOpConfig
|
||||
- !ScalarAssign
|
||||
arg: O
|
||||
value: !ScalarExpression
|
||||
scalar_apply:
|
||||
arith_fn:
|
||||
fn_name: add
|
||||
operands:
|
||||
- !ScalarExpression
|
||||
scalar_arg: O
|
||||
- !ScalarExpression
|
||||
scalar_apply:
|
||||
arith_fn:
|
||||
fn_name: mul
|
||||
operands:
|
||||
- !ScalarExpression
|
||||
scalar_apply:
|
||||
arith_fn:
|
||||
fn_name: sub
|
||||
operands:
|
||||
- !ScalarExpression
|
||||
@ -1641,7 +1641,7 @@ structured_op: !LinalgStructuredOpConfig
|
||||
- !ScalarExpression
|
||||
scalar_arg: IZp
|
||||
- !ScalarExpression
|
||||
scalar_apply:
|
||||
arith_fn:
|
||||
fn_name: sub
|
||||
operands:
|
||||
- !ScalarExpression
|
||||
@ -1721,13 +1721,13 @@ structured_op: !LinalgStructuredOpConfig
|
||||
- !ScalarAssign
|
||||
arg: O
|
||||
value: !ScalarExpression
|
||||
scalar_apply:
|
||||
arith_fn:
|
||||
fn_name: add
|
||||
operands:
|
||||
- !ScalarExpression
|
||||
scalar_arg: O
|
||||
- !ScalarExpression
|
||||
scalar_apply:
|
||||
arith_fn:
|
||||
fn_name: mul
|
||||
operands:
|
||||
- !ScalarExpression
|
||||
@ -1819,17 +1819,17 @@ structured_op: !LinalgStructuredOpConfig
|
||||
- !ScalarAssign
|
||||
arg: O
|
||||
value: !ScalarExpression
|
||||
scalar_apply:
|
||||
arith_fn:
|
||||
fn_name: add
|
||||
operands:
|
||||
- !ScalarExpression
|
||||
scalar_arg: O
|
||||
- !ScalarExpression
|
||||
scalar_apply:
|
||||
arith_fn:
|
||||
fn_name: mul
|
||||
operands:
|
||||
- !ScalarExpression
|
||||
scalar_apply:
|
||||
arith_fn:
|
||||
fn_name: sub
|
||||
operands:
|
||||
- !ScalarExpression
|
||||
@ -1847,7 +1847,7 @@ structured_op: !LinalgStructuredOpConfig
|
||||
- !ScalarExpression
|
||||
scalar_arg: IZp
|
||||
- !ScalarExpression
|
||||
scalar_apply:
|
||||
arith_fn:
|
||||
fn_name: sub
|
||||
operands:
|
||||
- !ScalarExpression
|
||||
@ -1923,7 +1923,7 @@ structured_op: !LinalgStructuredOpConfig
|
||||
- !ScalarAssign
|
||||
arg: O
|
||||
value: !ScalarExpression
|
||||
scalar_apply:
|
||||
arith_fn:
|
||||
fn_name: add
|
||||
operands:
|
||||
- !ScalarExpression
|
||||
@ -1994,7 +1994,7 @@ structured_op: !LinalgStructuredOpConfig
|
||||
- !ScalarAssign
|
||||
arg: O
|
||||
value: !ScalarExpression
|
||||
scalar_apply:
|
||||
arith_fn:
|
||||
fn_name: max
|
||||
operands:
|
||||
- !ScalarExpression
|
||||
@ -2065,7 +2065,7 @@ structured_op: !LinalgStructuredOpConfig
|
||||
- !ScalarAssign
|
||||
arg: O
|
||||
value: !ScalarExpression
|
||||
scalar_apply:
|
||||
arith_fn:
|
||||
fn_name: max_unsigned
|
||||
operands:
|
||||
- !ScalarExpression
|
||||
@ -2136,7 +2136,7 @@ structured_op: !LinalgStructuredOpConfig
|
||||
- !ScalarAssign
|
||||
arg: O
|
||||
value: !ScalarExpression
|
||||
scalar_apply:
|
||||
arith_fn:
|
||||
fn_name: max
|
||||
operands:
|
||||
- !ScalarExpression
|
||||
@ -2207,7 +2207,7 @@ structured_op: !LinalgStructuredOpConfig
|
||||
- !ScalarAssign
|
||||
arg: O
|
||||
value: !ScalarExpression
|
||||
scalar_apply:
|
||||
arith_fn:
|
||||
fn_name: min
|
||||
operands:
|
||||
- !ScalarExpression
|
||||
@ -2278,7 +2278,7 @@ structured_op: !LinalgStructuredOpConfig
|
||||
- !ScalarAssign
|
||||
arg: O
|
||||
value: !ScalarExpression
|
||||
scalar_apply:
|
||||
arith_fn:
|
||||
fn_name: min_unsigned
|
||||
operands:
|
||||
- !ScalarExpression
|
||||
@ -2355,7 +2355,7 @@ structured_op: !LinalgStructuredOpConfig
|
||||
- !ScalarAssign
|
||||
arg: O
|
||||
value: !ScalarExpression
|
||||
scalar_apply:
|
||||
arith_fn:
|
||||
fn_name: add
|
||||
operands:
|
||||
- !ScalarExpression
|
||||
@ -2432,7 +2432,7 @@ structured_op: !LinalgStructuredOpConfig
|
||||
- !ScalarAssign
|
||||
arg: O
|
||||
value: !ScalarExpression
|
||||
scalar_apply:
|
||||
arith_fn:
|
||||
fn_name: max
|
||||
operands:
|
||||
- !ScalarExpression
|
||||
@ -2509,7 +2509,7 @@ structured_op: !LinalgStructuredOpConfig
|
||||
- !ScalarAssign
|
||||
arg: O
|
||||
value: !ScalarExpression
|
||||
scalar_apply:
|
||||
arith_fn:
|
||||
fn_name: min
|
||||
operands:
|
||||
- !ScalarExpression
|
||||
@ -2572,15 +2572,15 @@ structured_op: !LinalgStructuredOpConfig
|
||||
type_var: T
|
||||
operands:
|
||||
- !ScalarExpression
|
||||
scalar_apply:
|
||||
arith_fn:
|
||||
fn_name: add
|
||||
operands:
|
||||
- !ScalarExpression
|
||||
scalar_apply:
|
||||
arith_fn:
|
||||
fn_name: mul
|
||||
operands:
|
||||
- !ScalarExpression
|
||||
scalar_apply:
|
||||
arith_fn:
|
||||
fn_name: add
|
||||
operands:
|
||||
- !ScalarExpression
|
||||
@ -2596,15 +2596,15 @@ structured_op: !LinalgStructuredOpConfig
|
||||
type_var: F64
|
||||
operands:
|
||||
- !ScalarExpression
|
||||
scalar_apply:
|
||||
arith_fn:
|
||||
fn_name: add
|
||||
operands:
|
||||
- !ScalarExpression
|
||||
scalar_apply:
|
||||
arith_fn:
|
||||
fn_name: mul
|
||||
operands:
|
||||
- !ScalarExpression
|
||||
scalar_apply:
|
||||
arith_fn:
|
||||
fn_name: add
|
||||
operands:
|
||||
- !ScalarExpression
|
||||
@ -2615,15 +2615,15 @@ structured_op: !LinalgStructuredOpConfig
|
||||
- !ScalarExpression
|
||||
scalar_index: 1
|
||||
- !ScalarExpression
|
||||
scalar_apply:
|
||||
arith_fn:
|
||||
fn_name: add
|
||||
operands:
|
||||
- !ScalarExpression
|
||||
scalar_apply:
|
||||
arith_fn:
|
||||
fn_name: mul
|
||||
operands:
|
||||
- !ScalarExpression
|
||||
scalar_apply:
|
||||
arith_fn:
|
||||
fn_name: add
|
||||
operands:
|
||||
- !ScalarExpression
|
||||
@ -2664,11 +2664,11 @@ structured_op: !LinalgStructuredOpConfig
|
||||
- !ScalarExpression
|
||||
scalar_const: '12345 : i64'
|
||||
- !ScalarExpression
|
||||
scalar_apply:
|
||||
arith_fn:
|
||||
fn_name: mul
|
||||
operands:
|
||||
- !ScalarExpression
|
||||
scalar_apply:
|
||||
arith_fn:
|
||||
fn_name: sub
|
||||
operands:
|
||||
- !ScalarExpression
|
||||
@ -2716,11 +2716,11 @@ structured_op: !LinalgStructuredOpConfig
|
||||
- !ScalarAssign
|
||||
arg: O
|
||||
value: !ScalarExpression
|
||||
scalar_apply:
|
||||
arith_fn:
|
||||
fn_name: log
|
||||
operands:
|
||||
- !ScalarExpression
|
||||
scalar_apply:
|
||||
arith_fn:
|
||||
fn_name: add
|
||||
operands:
|
||||
- !ScalarExpression
|
||||
@ -2731,7 +2731,7 @@ structured_op: !LinalgStructuredOpConfig
|
||||
- !ScalarExpression
|
||||
scalar_const: '1.000000e+00 : f64'
|
||||
- !ScalarExpression
|
||||
scalar_apply:
|
||||
arith_fn:
|
||||
fn_name: exp
|
||||
operands:
|
||||
- !ScalarExpression
|
||||
|
@ -148,11 +148,11 @@ static LogicalResult foldMemRefCastInTiledLoopOp(TiledLoopOp op) {
|
||||
// TODO: Move this to a utility library.
|
||||
// The public methods on this class are referenced directly from generated code
|
||||
// and bind by name to math and type conversion functions in the DSL as:
|
||||
// `applyfn__{fnName}`
|
||||
// `arithfn__{fnName}`
|
||||
// `typefn__{fnName}`
|
||||
// Examples:
|
||||
// `applyfn__add`
|
||||
// `applyfn__mul`
|
||||
// `arithfn__add`
|
||||
// `arithfn__mul`
|
||||
// `typefn__cast`
|
||||
// The naming convention is intentional in order to match snake-cased DSL names.
|
||||
// See mlir-linalg-ods-yaml-gen.cpp for the code that mates to this class.
|
||||
@ -241,7 +241,7 @@ public:
|
||||
}
|
||||
|
||||
// NOLINTNEXTLINE(*-identifier-naming): externally called.
|
||||
Value applyfn__add(Value lhs, Value rhs) {
|
||||
Value arithfn__add(Value lhs, Value rhs) {
|
||||
OpBuilder builder = getBuilder();
|
||||
if (isFloatingPoint(lhs))
|
||||
return builder.create<arith::AddFOp>(lhs.getLoc(), lhs, rhs);
|
||||
@ -251,7 +251,7 @@ public:
|
||||
}
|
||||
|
||||
// NOLINTNEXTLINE(*-identifier-naming): externally called.
|
||||
Value applyfn__exp(Value x) {
|
||||
Value arithfn__exp(Value x) {
|
||||
OpBuilder builder = getBuilder();
|
||||
if (isFloatingPoint(x))
|
||||
return builder.create<math::ExpOp>(x.getLoc(), x);
|
||||
@ -259,7 +259,7 @@ public:
|
||||
}
|
||||
|
||||
// NOLINTNEXTLINE(*-identifier-naming): externally called.
|
||||
Value applyfn__log(Value x) {
|
||||
Value arithfn__log(Value x) {
|
||||
OpBuilder builder = getBuilder();
|
||||
if (isFloatingPoint(x))
|
||||
return builder.create<math::LogOp>(x.getLoc(), x);
|
||||
@ -267,7 +267,7 @@ public:
|
||||
}
|
||||
|
||||
// NOLINTNEXTLINE(*-identifier-naming): externally called.
|
||||
Value applyfn__sub(Value lhs, Value rhs) {
|
||||
Value arithfn__sub(Value lhs, Value rhs) {
|
||||
OpBuilder builder = getBuilder();
|
||||
if (isFloatingPoint(lhs))
|
||||
return builder.create<arith::SubFOp>(lhs.getLoc(), lhs, rhs);
|
||||
@ -277,7 +277,7 @@ public:
|
||||
}
|
||||
|
||||
// NOLINTNEXTLINE(*-identifier-naming): externally called.
|
||||
Value applyfn__mul(Value lhs, Value rhs) {
|
||||
Value arithfn__mul(Value lhs, Value rhs) {
|
||||
OpBuilder builder = getBuilder();
|
||||
if (isFloatingPoint(lhs))
|
||||
return builder.create<arith::MulFOp>(lhs.getLoc(), lhs, rhs);
|
||||
@ -287,7 +287,7 @@ public:
|
||||
}
|
||||
|
||||
// NOLINTNEXTLINE(*-identifier-naming): externally called.
|
||||
Value applyfn__max(Value lhs, Value rhs) {
|
||||
Value arithfn__max(Value lhs, Value rhs) {
|
||||
OpBuilder builder = getBuilder();
|
||||
if (isFloatingPoint(lhs))
|
||||
return builder.create<arith::MaxFOp>(lhs.getLoc(), lhs, rhs);
|
||||
@ -297,7 +297,7 @@ public:
|
||||
}
|
||||
|
||||
// NOLINTNEXTLINE(*-identifier-naming): externally called.
|
||||
Value applyfn__max_unsigned(Value lhs, Value rhs) {
|
||||
Value arithfn__max_unsigned(Value lhs, Value rhs) {
|
||||
OpBuilder builder = getBuilder();
|
||||
if (isFloatingPoint(lhs))
|
||||
return builder.create<arith::MaxFOp>(lhs.getLoc(), lhs, rhs);
|
||||
@ -307,7 +307,7 @@ public:
|
||||
}
|
||||
|
||||
// NOLINTNEXTLINE(*-identifier-naming): externally called.
|
||||
Value applyfn__min(Value lhs, Value rhs) {
|
||||
Value arithfn__min(Value lhs, Value rhs) {
|
||||
OpBuilder builder = getBuilder();
|
||||
if (isFloatingPoint(lhs))
|
||||
return builder.create<arith::MinFOp>(lhs.getLoc(), lhs, rhs);
|
||||
@ -317,7 +317,7 @@ public:
|
||||
}
|
||||
|
||||
// NOLINTNEXTLINE(*-identifier-naming): externally called.
|
||||
Value applyfn__min_unsigned(Value lhs, Value rhs) {
|
||||
Value arithfn__min_unsigned(Value lhs, Value rhs) {
|
||||
OpBuilder builder = getBuilder();
|
||||
if (isFloatingPoint(lhs))
|
||||
return builder.create<arith::MinFOp>(lhs.getLoc(), lhs, rhs);
|
||||
|
@ -77,13 +77,13 @@ class TensorExpression:
|
||||
self.visit_tensor_exprs(visit_scalar_def)
|
||||
|
||||
def __add__(self, rhs: "TensorExpression") -> "TensorExpression":
|
||||
return PrimFn.add(self, rhs)
|
||||
return ArithFn.add(self, rhs)
|
||||
|
||||
def __mul__(self, rhs) -> "TensorExpression":
|
||||
return PrimFn.mul(self, rhs)
|
||||
return ArithFn.mul(self, rhs)
|
||||
|
||||
def __sub__(self, rhs) -> "TensorExpression":
|
||||
return PrimFn.sub(self, rhs)
|
||||
return ArithFn.sub(self, rhs)
|
||||
|
||||
def __hash__(self):
|
||||
return hash(id(self))
|
||||
@ -347,42 +347,55 @@ class TypeFn:
|
||||
cast_unsigned = TypeFnType("cast_unsigned")
|
||||
|
||||
|
||||
class PrimFnType:
|
||||
"""Primitive operations."""
|
||||
class ArithFnType:
|
||||
"""Arithmetic function.
|
||||
|
||||
def __init__(self, prim_name: str):
|
||||
self.prim_name = prim_name
|
||||
An arithmetic function takes one ore more tensor expressions and returns the
|
||||
function evaluation result.
|
||||
"""
|
||||
|
||||
def __call__(self, *args):
|
||||
return PrimApply(self, args)
|
||||
def __init__(self, fn_name: str):
|
||||
self.fn_name = fn_name
|
||||
|
||||
def __call__(self, *args) -> "TensorArithFn":
|
||||
return TensorArithFn(self, args)
|
||||
|
||||
def reduce(self, *reduce_dims: DimDef):
|
||||
"""Shortcut to create a Reduce operation from this primitive."""
|
||||
"""Shortcut to create a Reduce operation from this function."""
|
||||
return ReduceFnType(self, *reduce_dims)
|
||||
|
||||
def __repr__(self):
|
||||
return f"{self.prim_name}"
|
||||
return f"{self.fn_name}"
|
||||
|
||||
|
||||
class PrimFn:
|
||||
add = PrimFnType("add")
|
||||
exp = PrimFnType("exp")
|
||||
log = PrimFnType("log")
|
||||
mul = PrimFnType("mul")
|
||||
max = PrimFnType("max")
|
||||
min = PrimFnType("min")
|
||||
sub = PrimFnType("sub")
|
||||
max_unsigned = PrimFnType("max_unsigned")
|
||||
min_unsigned = PrimFnType("min_unsigned")
|
||||
class ArithFn:
|
||||
"""Arithmetic function namespace.
|
||||
|
||||
As the integer types are signless, signedness is implement by different
|
||||
functions that treat integers as signed or unsigned values.
|
||||
|
||||
Examples:
|
||||
- max -> `arith.MaxSIOp`
|
||||
- max_unsinged -> `arith.MaxUIOp`
|
||||
"""
|
||||
add = ArithFnType("add")
|
||||
exp = ArithFnType("exp")
|
||||
log = ArithFnType("log")
|
||||
mul = ArithFnType("mul")
|
||||
max = ArithFnType("max")
|
||||
min = ArithFnType("min")
|
||||
sub = ArithFnType("sub")
|
||||
max_unsigned = ArithFnType("max_unsigned")
|
||||
min_unsigned = ArithFnType("min_unsigned")
|
||||
|
||||
|
||||
class ReduceFnType:
|
||||
"""A reduction operator that reduces into its LHS from its RHS."""
|
||||
|
||||
def __init__(self, operator: PrimFnType, *reduce_dims: DimDef):
|
||||
"""Initializes the ReduceFn with a primitive function and dims."""
|
||||
if not isinstance(operator, PrimFnType):
|
||||
raise ValueError(f"Reduce expected a Prim operator but got {operator}")
|
||||
def __init__(self, operator: ArithFnType, *reduce_dims: DimDef):
|
||||
"""Initializes the ReduceFn with an airthmetic function and dims."""
|
||||
if not isinstance(operator, ArithFnType):
|
||||
raise ValueError(f"Reduce expected a ArithFnType but got {operator}")
|
||||
self.operator = operator
|
||||
self.reduce_dims = tuple(reduce_dims)
|
||||
|
||||
@ -390,28 +403,28 @@ class ReduceFnType:
|
||||
return ReduceApply(self, args)
|
||||
|
||||
def __repr__(self):
|
||||
return (f"reduce_{self.operator.prim_name}"
|
||||
return (f"reduce_{self.operator.fn_name}"
|
||||
f"({', '.join(repr(d) for d in self.reduce_dims)})")
|
||||
|
||||
|
||||
class ReduceFn:
|
||||
add = PrimFn.add.reduce
|
||||
mul = PrimFn.mul.reduce
|
||||
max = PrimFn.max.reduce
|
||||
min = PrimFn.min.reduce
|
||||
max_unsigned = PrimFn.max_unsigned.reduce
|
||||
min_unsigned = PrimFn.min_unsigned.reduce
|
||||
add = ArithFn.add.reduce
|
||||
mul = ArithFn.mul.reduce
|
||||
max = ArithFn.max.reduce
|
||||
min = ArithFn.min.reduce
|
||||
max_unsigned = ArithFn.max_unsigned.reduce
|
||||
min_unsigned = ArithFn.min_unsigned.reduce
|
||||
|
||||
|
||||
class PrimApply(TensorExpression):
|
||||
"""Application of a primitive."""
|
||||
class TensorArithFn(TensorExpression):
|
||||
"""Application of an arithmetic function."""
|
||||
|
||||
def __init__(self, prim: PrimFnType, args: Sequence[TensorExpression]):
|
||||
self.prim = prim
|
||||
def __init__(self, arith_fn: ArithFnType, args: Sequence[TensorExpression]):
|
||||
self.arith_fn = arith_fn
|
||||
self.args = tuple(args)
|
||||
|
||||
def to_scalar_expression(self) -> ScalarExpression:
|
||||
return ScalarApplyFn(self.prim.prim_name,
|
||||
return ScalarArithFn(self.arith_fn.fn_name,
|
||||
*[arg.to_scalar_expression() for arg in self.args
|
||||
]).expr()
|
||||
|
||||
@ -421,7 +434,7 @@ class PrimApply(TensorExpression):
|
||||
arg.visit_tensor_exprs(callback)
|
||||
|
||||
def __repr__(self):
|
||||
return f"{repr(self.prim)}({', '.join(repr(a) for a in self.args)})"
|
||||
return f"{repr(self.arith_fn)}({', '.join(repr(a) for a in self.args)})"
|
||||
|
||||
|
||||
class TensorTypeFn(TensorExpression):
|
||||
@ -503,7 +516,7 @@ class ReduceApply(TensorExpression):
|
||||
f"bound to its lhs: {self}")
|
||||
full_args = [self.lhs.to_scalar_expression()
|
||||
] + [arg.to_scalar_expression() for arg in self.args]
|
||||
return ScalarApplyFn(self.reduce.operator.prim_name, *full_args).expr()
|
||||
return ScalarArithFn(self.reduce.operator.fn_name, *full_args).expr()
|
||||
|
||||
def visit_tensor_exprs(self, callback):
|
||||
for arg in self.args:
|
||||
|
@ -221,10 +221,10 @@ class _BodyBuilder:
|
||||
dim_attr = IntegerAttr.get(
|
||||
IntegerType.get_signless(64), expr.scalar_index.dim)
|
||||
return linalg.IndexOp(dim_attr).result
|
||||
elif expr.scalar_apply:
|
||||
fn = self._get_function(f"_eval_{expr.scalar_apply.fn_name}")
|
||||
elif expr.arith_fn:
|
||||
fn = self._get_function(f"_arithfn_{expr.arith_fn.fn_name}")
|
||||
operand_values = [
|
||||
self.expression(operand) for operand in expr.scalar_apply.operands
|
||||
self.expression(operand) for operand in expr.arith_fn.operands
|
||||
]
|
||||
return fn(*operand_values)
|
||||
elif expr.type_fn:
|
||||
@ -310,59 +310,59 @@ class _BodyBuilder:
|
||||
def _typefn_cast_unsigned(self, type_var_name: str, operand: Value) -> Value:
|
||||
return self._cast(type_var_name, operand, True)
|
||||
|
||||
def _eval_add(self, lhs: Value, rhs: Value) -> Value:
|
||||
def _arithfn_add(self, lhs: Value, rhs: Value) -> Value:
|
||||
if _is_floating_point_type(lhs.type):
|
||||
return arith.AddFOp(lhs, rhs).result
|
||||
if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
|
||||
return arith.AddIOp(lhs, rhs).result
|
||||
raise NotImplementedError("Unsupported 'add' operand: {lhs}")
|
||||
|
||||
def _eval_exp(self, x: Value) -> Value:
|
||||
def _arithfn_exp(self, x: Value) -> Value:
|
||||
if _is_floating_point_type(x.type):
|
||||
return math.ExpOp(x).result
|
||||
raise NotImplementedError("Unsupported 'exp' operand: {x}")
|
||||
|
||||
def _eval_log(self, x: Value) -> Value:
|
||||
def _arithfn_log(self, x: Value) -> Value:
|
||||
if _is_floating_point_type(x.type):
|
||||
return math.LogOp(x).result
|
||||
raise NotImplementedError("Unsupported 'log' operand: {x}")
|
||||
|
||||
def _eval_sub(self, lhs: Value, rhs: Value) -> Value:
|
||||
def _arithfn_sub(self, lhs: Value, rhs: Value) -> Value:
|
||||
if _is_floating_point_type(lhs.type):
|
||||
return arith.SubFOp(lhs, rhs).result
|
||||
if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
|
||||
return arith.SubIOp(lhs, rhs).result
|
||||
raise NotImplementedError("Unsupported 'sub' operand: {lhs}")
|
||||
|
||||
def _eval_mul(self, lhs: Value, rhs: Value) -> Value:
|
||||
def _arithfn_mul(self, lhs: Value, rhs: Value) -> Value:
|
||||
if _is_floating_point_type(lhs.type):
|
||||
return arith.MulFOp(lhs, rhs).result
|
||||
if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
|
||||
return arith.MulIOp(lhs, rhs).result
|
||||
raise NotImplementedError("Unsupported 'mul' operand: {lhs}")
|
||||
|
||||
def _eval_max(self, lhs: Value, rhs: Value) -> Value:
|
||||
def _arithfn_max(self, lhs: Value, rhs: Value) -> Value:
|
||||
if _is_floating_point_type(lhs.type):
|
||||
return arith.MaxFOp(lhs, rhs).result
|
||||
if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
|
||||
return arith.MaxSIOp(lhs, rhs).result
|
||||
raise NotImplementedError("Unsupported 'max' operand: {lhs}")
|
||||
|
||||
def _eval_max_unsigned(self, lhs: Value, rhs: Value) -> Value:
|
||||
def _arithfn_max_unsigned(self, lhs: Value, rhs: Value) -> Value:
|
||||
if _is_floating_point_type(lhs.type):
|
||||
return arith.MaxFOp(lhs, rhs).result
|
||||
if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
|
||||
return arith.MaxUIOp(lhs, rhs).result
|
||||
raise NotImplementedError("Unsupported 'max_unsigned' operand: {lhs}")
|
||||
|
||||
def _eval_min(self, lhs: Value, rhs: Value) -> Value:
|
||||
def _arithfn_min(self, lhs: Value, rhs: Value) -> Value:
|
||||
if _is_floating_point_type(lhs.type):
|
||||
return arith.MinFOp(lhs, rhs).result
|
||||
if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
|
||||
return arith.MinSIOp(lhs, rhs).result
|
||||
raise NotImplementedError("Unsupported 'min' operand: {lhs}")
|
||||
|
||||
def _eval_min_unsigned(self, lhs: Value, rhs: Value) -> Value:
|
||||
def _arithfn_min_unsigned(self, lhs: Value, rhs: Value) -> Value:
|
||||
if _is_floating_point_type(lhs.type):
|
||||
return arith.MinFOp(lhs, rhs).result
|
||||
if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
|
||||
|
@ -20,7 +20,7 @@ from .types import *
|
||||
|
||||
__all__ = [
|
||||
"ScalarAssign",
|
||||
"ScalarApplyFn",
|
||||
"ScalarArithFn",
|
||||
"ScalarTypeFn",
|
||||
"ScalarArg",
|
||||
"ScalarConst",
|
||||
@ -29,18 +29,18 @@ __all__ = [
|
||||
]
|
||||
|
||||
|
||||
class ScalarApplyFn:
|
||||
"""A type of ScalarExpression that applies a named function to operands."""
|
||||
class ScalarArithFn:
|
||||
"""A type of ScalarExpression that applies an arithmetic function."""
|
||||
|
||||
def __init__(self, fn_name: str, *operands: "ScalarExpression"):
|
||||
self.fn_name = fn_name
|
||||
self.operands = operands
|
||||
|
||||
def expr(self) -> "ScalarExpression":
|
||||
return ScalarExpression(scalar_apply=self)
|
||||
return ScalarExpression(arith_fn=self)
|
||||
|
||||
def __repr__(self):
|
||||
return f"ScalarApplyFn<{self.fn_name}>({', '.join(self.operands)})"
|
||||
return f"ScalarArithFn<{self.fn_name}>({', '.join(self.operands)})"
|
||||
|
||||
|
||||
class ScalarTypeFn:
|
||||
@ -102,7 +102,7 @@ class ScalarExpression(YAMLObject):
|
||||
"""An expression on scalar values.
|
||||
|
||||
Can be one of:
|
||||
- ScalarApplyFn
|
||||
- ScalarArithFn
|
||||
- ScalarTypeFn
|
||||
- ScalarArg
|
||||
- ScalarConst
|
||||
@ -112,27 +112,27 @@ class ScalarExpression(YAMLObject):
|
||||
yaml_tag = "!ScalarExpression"
|
||||
|
||||
def __init__(self,
|
||||
scalar_apply: Optional[ScalarApplyFn] = None,
|
||||
arith_fn: Optional[ScalarArithFn] = None,
|
||||
type_fn: Optional[ScalarTypeFn] = None,
|
||||
scalar_arg: Optional[ScalarArg] = None,
|
||||
scalar_const: Optional[ScalarConst] = None,
|
||||
scalar_index: Optional[ScalarIndex] = None):
|
||||
if (bool(scalar_apply) + bool(type_fn) + bool(scalar_arg) +
|
||||
bool(scalar_const) + bool(scalar_index)) != 1:
|
||||
raise ValueError("One of 'scalar_apply', 'type_fn', 'scalar_arg', "
|
||||
if (bool(arith_fn) + bool(type_fn) + bool(scalar_arg) + bool(scalar_const) +
|
||||
bool(scalar_index)) != 1:
|
||||
raise ValueError("One of 'arith_fn', 'type_fn', 'scalar_arg', "
|
||||
"'scalar_const', 'scalar_index', must be specified")
|
||||
self.scalar_apply = scalar_apply
|
||||
self.arith_fn = arith_fn
|
||||
self.type_fn = type_fn
|
||||
self.scalar_arg = scalar_arg
|
||||
self.scalar_const = scalar_const
|
||||
self.scalar_index = scalar_index
|
||||
|
||||
def to_yaml_custom_dict(self):
|
||||
if self.scalar_apply:
|
||||
if self.arith_fn:
|
||||
return dict(
|
||||
scalar_apply=dict(
|
||||
fn_name=self.scalar_apply.fn_name,
|
||||
operands=list(self.scalar_apply.operands),
|
||||
arith_fn=dict(
|
||||
fn_name=self.arith_fn.fn_name,
|
||||
operands=list(self.arith_fn.operands),
|
||||
))
|
||||
if self.type_fn:
|
||||
# Note that even though operands must be arity 1, we write it the
|
||||
|
@ -665,4 +665,4 @@ def soft_plus_2d(
|
||||
"""
|
||||
domain(D.m, D.n)
|
||||
O[D.m, D.n] = \
|
||||
PrimFn.log(TypeFn.cast(U, const(1.0)) + PrimFn.exp(TypeFn.cast(U, I[D.m, D.n])))
|
||||
ArithFn.log(TypeFn.cast(U, const(1.0)) + ArithFn.exp(TypeFn.cast(U, I[D.m, D.n])))
|
||||
|
@ -34,7 +34,7 @@ structured_op: !LinalgStructuredOpConfig
|
||||
- !ScalarAssign
|
||||
arg: O
|
||||
value: !ScalarExpression
|
||||
scalar_apply:
|
||||
arith_fn:
|
||||
fn_name: add
|
||||
operands:
|
||||
- !ScalarExpression
|
||||
@ -89,7 +89,7 @@ structured_op: !LinalgStructuredOpConfig
|
||||
# IMPL-DAG: Value [[VAL1:[a-z0-9]+]] = helper.typefn__cast(block.getArgument(0).getType(), [[VAL0]]);
|
||||
# IMPL-DAG: Value [[VAL2:[a-z0-9]+]] = helper.index(1);
|
||||
# IMPL-DAG: Value [[VAL3:[a-z0-9]+]] = helper.typefn__cast_unsigned(block.getArgument(0).getType(), [[VAL2]]);
|
||||
# IMPL-DAG: Value [[VAL4:[a-z0-9]+]] = helper.applyfn__add([[VAL1]], [[VAL3]]);
|
||||
# IMPL-DAG: Value [[VAL4:[a-z0-9]+]] = helper.arithfn__add([[VAL1]], [[VAL3]]);
|
||||
|
||||
|
||||
# @linalg_structured_op
|
||||
|
@ -9,10 +9,10 @@ from mlir.dialects.linalg.opdsl.lang import *
|
||||
# CHECK: -
|
||||
# CHECK: arg: C
|
||||
# CHECK: value:
|
||||
# CHECK: scalar_apply:
|
||||
# CHECK: arith_fn:
|
||||
# CHECK: fn_name: add
|
||||
# CHECK: operands:
|
||||
# CHECK: scalar_apply:
|
||||
# CHECK: arith_fn:
|
||||
# CHECK: fn_name: mul
|
||||
# CHECK: operands:
|
||||
# CHECK: type_fn:
|
||||
@ -36,10 +36,10 @@ def matmul(
|
||||
# CHECK: assignments:
|
||||
# CHECK: -
|
||||
# CHECK: arg: O
|
||||
# CHECK: scalar_apply:
|
||||
# CHECK: arith_fn:
|
||||
# CHECK: fn_name: sub
|
||||
# CHECK: operands:
|
||||
# CHECK: scalar_apply:
|
||||
# CHECK: arith_fn:
|
||||
# CHECK: fn_name: add
|
||||
# CHECK: operands:
|
||||
# CHECK: type_fn:
|
||||
@ -67,7 +67,7 @@ def constants(O=TensorDef(T, S.M, S.K, output=True)):
|
||||
# CHECK: assignments:
|
||||
# CHECK: -
|
||||
# CHECK: arg: O
|
||||
# CHECK: scalar_apply:
|
||||
# CHECK: arith_fn:
|
||||
# CHECK: fn_name: add
|
||||
# CHECK: operands:
|
||||
# CHECK: scalar_index: 1
|
||||
|
@ -35,8 +35,8 @@ def fill_rng_poly(
|
||||
@linalg_structured_op
|
||||
def soft_plus_poly(
|
||||
I=TensorDef(T, S.M, S.N), O=TensorDef(U, S.M, S.N, output=True)):
|
||||
O[D.m, D.n] = PrimFn.log(
|
||||
TypeFn.cast(U, const(1.0)) + TypeFn.cast(U, PrimFn.exp(I[D.m, D.n])))
|
||||
O[D.m, D.n] = ArithFn.log(
|
||||
TypeFn.cast(U, const(1.0)) + TypeFn.cast(U, ArithFn.exp(I[D.m, D.n])))
|
||||
|
||||
|
||||
@linalg_structured_op(op_name="custom_op_name")
|
||||
|
@ -82,7 +82,7 @@ struct LinalgIndexingMapsConfig {
|
||||
|
||||
struct ScalarExpression;
|
||||
|
||||
struct ScalarApply {
|
||||
struct ScalarArithFn {
|
||||
std::string fnName;
|
||||
// NOTE: Must be pure heap allocated container (not SmallVector)
|
||||
// due to recursive data type.
|
||||
@ -101,7 +101,7 @@ struct ScalarExpression {
|
||||
Optional<std::string> arg;
|
||||
Optional<std::string> constant;
|
||||
Optional<int64_t> index;
|
||||
Optional<ScalarApply> apply;
|
||||
Optional<ScalarArithFn> arithFn;
|
||||
Optional<ScalarTypeFn> typeFn;
|
||||
};
|
||||
|
||||
@ -245,9 +245,10 @@ struct MappingTraits<ScalarAssign> {
|
||||
};
|
||||
|
||||
/// A scalar expression (RHS of an assignment). Must be one of:
|
||||
/// - `scalar_arg`: Name of an argument to the op.
|
||||
/// - `scalar_apply`: Result of evaluating a named function (see
|
||||
/// `ScalarApply`).
|
||||
/// - `scalar_arg`: An operation argument.
|
||||
/// - `scalar_const`: A constant definition.
|
||||
/// - `scalar_index`: An iteration index.
|
||||
/// - `arith_fn`: A named arithmetic function (see `ScalarArithFn`).
|
||||
/// - `type_fn`: A named type conversion function (see `ScalarTypeFn`).
|
||||
template <>
|
||||
struct MappingTraits<ScalarExpression> {
|
||||
@ -255,7 +256,7 @@ struct MappingTraits<ScalarExpression> {
|
||||
io.mapOptional("scalar_arg", info.arg);
|
||||
io.mapOptional("scalar_const", info.constant);
|
||||
io.mapOptional("scalar_index", info.index);
|
||||
io.mapOptional("scalar_apply", info.apply);
|
||||
io.mapOptional("arith_fn", info.arithFn);
|
||||
io.mapOptional("type_fn", info.typeFn);
|
||||
}
|
||||
};
|
||||
@ -266,8 +267,8 @@ struct MappingTraits<ScalarExpression> {
|
||||
/// - `add(lhs, rhs)`
|
||||
/// - `mul(lhs, rhs)`
|
||||
template <>
|
||||
struct MappingTraits<ScalarApply> {
|
||||
static void mapping(IO &io, ScalarApply &info) {
|
||||
struct MappingTraits<ScalarArithFn> {
|
||||
static void mapping(IO &io, ScalarArithFn &info) {
|
||||
io.mapRequired("fn_name", info.fnName);
|
||||
io.mapRequired("operands", info.operands);
|
||||
}
|
||||
@ -944,11 +945,11 @@ void {0}::regionBuilder(ImplicitLocOpBuilder &b, Block &block) {{
|
||||
cppIdent, *expression.index));
|
||||
return cppIdent;
|
||||
}
|
||||
if (expression.apply) {
|
||||
if (expression.arithFn) {
|
||||
// Apply function.
|
||||
// Recursively generate operands.
|
||||
SmallVector<std::string> operandCppValues;
|
||||
for (ScalarExpression &operand : expression.apply->operands) {
|
||||
for (ScalarExpression &operand : expression.arithFn->operands) {
|
||||
auto operandCppValue = generateExpression(operand);
|
||||
if (!operandCppValue)
|
||||
return None;
|
||||
@ -956,8 +957,8 @@ void {0}::regionBuilder(ImplicitLocOpBuilder &b, Block &block) {{
|
||||
}
|
||||
std::string cppIdent = llvm::formatv("value{0}", ++localCounter);
|
||||
stmts.push_back(
|
||||
llvm::formatv("Value {0} = helper.applyfn__{1}({2});", cppIdent,
|
||||
expression.apply->fnName,
|
||||
llvm::formatv("Value {0} = helper.arithfn__{1}({2});", cppIdent,
|
||||
expression.arithFn->fnName,
|
||||
interleaveToString(operandCppValues, ", ")));
|
||||
return cppIdent;
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user