[mlir][OpDSL] Add TypeFn
class.
This revision introduces a the `TypeFn` class that similar to the `PrimFn` class contains an extensible set of type conversion functions. Having the same mechanism for both type conversion functions and arithmetic functions improves code consistency. Additionally, having an explicit function class and function name is a prerequisite to specify a conversion or arithmetic function via attribute. In a follow up commits, we will introduce function attributes to make OpDSL operations more generic. In particular, the goal is to handle signed and unsigned computation in one operations. Today, there is a linalg.matmul and a linalg.matmul_unsigned. The commit implements the following changes: - Introduce the class of type conversion functions `TypeFn` - Replace the hardwired cast and cast_unsigned ops by the `TypeFn` counterparts - Adapt the python and C++ code generation paths to support the new cast operations Example: ``` cast(U, A[D.m, D.k]) ``` changes to ``` TypeFn.cast(U, A[D.m, D.k]) ``` Depends On D115237 Reviewed By: stellaraccident Differential Revision: https://reviews.llvm.org/D115239
This commit is contained in:
parent
babad7c566
commit
15757ea80a
@ -56,7 +56,7 @@ def matmul(A=TensorDef(T1, S.M, S.K),
|
|||||||
"""
|
"""
|
||||||
domain(D.m, D.n, D.k)
|
domain(D.m, D.n, D.k)
|
||||||
implements(ContractionOpInterface)
|
implements(ContractionOpInterface)
|
||||||
C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.k, D.n])
|
C[D.m, D.n] += TypeFn.cast(U, A[D.m, D.k]) * TypeFn.cast(U, B[D.k, D.n])
|
||||||
```
|
```
|
||||||
|
|
||||||
Here we have a simple type polymorphic contraction that takes arguments `A` and
|
Here we have a simple type polymorphic contraction that takes arguments `A` and
|
||||||
@ -159,8 +159,8 @@ def pooling_poly(
|
|||||||
O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True),
|
O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True),
|
||||||
strides=IndexAttrDef(S.SH, S.SW),
|
strides=IndexAttrDef(S.SH, S.SW),
|
||||||
dilations=IndexAttrDef(S.DH, S.DW)):
|
dilations=IndexAttrDef(S.DH, S.DW)):
|
||||||
O[D.n, D.oh, D.ow, D.c] += \
|
O[D.n, D.oh, D.ow, D.c] += TypeFn.cast(U,
|
||||||
cast(U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c])
|
I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c])
|
||||||
```
|
```
|
||||||
|
|
||||||
The pooling operation does not access the shape-only tensor `K`. Instead, the
|
The pooling operation does not access the shape-only tensor `K`. Instead, the
|
||||||
@ -192,10 +192,18 @@ Reduction functions can appear as the outer-most function on the RHS:
|
|||||||
* `ReduceFn.mul`
|
* `ReduceFn.mul`
|
||||||
* `ReduceFn.max`
|
* `ReduceFn.max`
|
||||||
|
|
||||||
|
Additionally, type conversion functions cast an operand to a target type:
|
||||||
|
|
||||||
|
* `TypeFn.cast(TypeVar, operand)`
|
||||||
|
* `TypeFn.cast_unsigned(TypeVar, operand)`
|
||||||
|
|
||||||
|
As the integer types are signless, signedness is implement by different
|
||||||
|
functions that treat integers as signed (`TypeFn.cast`) or unsigned
|
||||||
|
(`TypeFn.cast_unsigned`) values.
|
||||||
|
|
||||||
There are also special forms:
|
There are also special forms:
|
||||||
|
|
||||||
* `cast(TypeVar, operand)` casts the `operand` to the target type `TypeVar`.
|
* `const(value)` returns a constant value.
|
||||||
* `const(TypeVar, value)` returns a constant value of type `TypeVar`.
|
|
||||||
* `index(dim)` returns the iteration index in the given dimension `dim`.
|
* `index(dim)` returns the iteration index in the given dimension `dim`.
|
||||||
|
|
||||||
## Types
|
## Types
|
||||||
@ -206,18 +214,25 @@ output types of constructed ops. An exception are predefined types such as
|
|||||||
computations with a type that is independent of the input and output types. For
|
computations with a type that is independent of the input and output types. For
|
||||||
example, parts of floating point computation may require double precision
|
example, parts of floating point computation may require double precision
|
||||||
arithmetic despite all inputs and outputs being single precision values.
|
arithmetic despite all inputs and outputs being single precision values.
|
||||||
Assignment expressions with no `cast` calls will generally require uniform types
|
Assignment expressions with no `TypeFn.cast` calls will generally require
|
||||||
throughout and will fail to verify if violated. The presence of a `cast` allows
|
uniform types throughout and will fail to verify if violated. The presence of a
|
||||||
for a limited form of numeric type conversion between element types that can be
|
`TypeFn.cast` or `TypeFn.cast_unsigned` allows for a limited form of numeric
|
||||||
derived from inputs and outputs (and in the future, attributes). `cast` calls
|
type conversion between element types that can be derived from inputs and
|
||||||
with a `TypeVar` first argument are emitted as `symbolic_cast` primitives in the
|
outputs (and in the future, attributes). `TypeFn.cast` calls with a `TypeVar`
|
||||||
YAML definition.
|
first argument are emitted as `type_fn` primitives in the YAML definition.
|
||||||
|
|
||||||
Casting will perform `int<->float` and `index->int` type conversions and will
|
Casting will perform `int<->float` and `index->int` type conversions and will
|
||||||
perform any necessary extension or truncation within type family. Note that
|
perform any necessary extension or truncation within the type family. The
|
||||||
presently, any integer type is assumed to be signed for the purpose of
|
integer types themselves are signless and signedness is implemented by
|
||||||
determining how to extend or truncate. Supporting unsigned integer types is left
|
functions/operations. The `TypeFn.cast` function treats all integers as signed,
|
||||||
for future work.
|
while `TypeFn.cast_unsigned` treats them as unsigned.
|
||||||
|
|
||||||
|
The following examples illustrate the lowering of signed and unsigned functions:
|
||||||
|
|
||||||
|
* cast(I32 -> I64) -> `arith.ExtSIOp`
|
||||||
|
* cast(F32 -> I32) -> `arith.FPToSIOp`
|
||||||
|
* cast_unsigned(I32 -> I64) -> `arith.ExtUIOp`
|
||||||
|
* cast_unsigned(F32 -> I32) -> `arith.FPToUIOp`
|
||||||
|
|
||||||
Not all functions are applicable for all numeric types, and on mismatch, op
|
Not all functions are applicable for all numeric types, and on mismatch, op
|
||||||
verification will fail.
|
verification will fail.
|
||||||
|
@ -51,19 +51,19 @@ structured_op: !LinalgStructuredOpConfig
|
|||||||
fn_name: mul
|
fn_name: mul
|
||||||
operands:
|
operands:
|
||||||
- !ScalarExpression
|
- !ScalarExpression
|
||||||
symbolic_cast:
|
type_fn:
|
||||||
|
fn_name: cast
|
||||||
type_var: U
|
type_var: U
|
||||||
operands:
|
operands:
|
||||||
- !ScalarExpression
|
- !ScalarExpression
|
||||||
scalar_arg: A
|
scalar_arg: A
|
||||||
is_unsigned_cast: false
|
|
||||||
- !ScalarExpression
|
- !ScalarExpression
|
||||||
symbolic_cast:
|
type_fn:
|
||||||
|
fn_name: cast
|
||||||
type_var: U
|
type_var: U
|
||||||
operands:
|
operands:
|
||||||
- !ScalarExpression
|
- !ScalarExpression
|
||||||
scalar_arg: B
|
scalar_arg: B
|
||||||
is_unsigned_cast: false
|
|
||||||
--- !LinalgOpConfig
|
--- !LinalgOpConfig
|
||||||
metadata: !LinalgOpMetadata
|
metadata: !LinalgOpMetadata
|
||||||
name: matmul_unsigned
|
name: matmul_unsigned
|
||||||
@ -115,19 +115,19 @@ structured_op: !LinalgStructuredOpConfig
|
|||||||
fn_name: mul
|
fn_name: mul
|
||||||
operands:
|
operands:
|
||||||
- !ScalarExpression
|
- !ScalarExpression
|
||||||
symbolic_cast:
|
type_fn:
|
||||||
|
fn_name: cast_unsigned
|
||||||
type_var: U
|
type_var: U
|
||||||
operands:
|
operands:
|
||||||
- !ScalarExpression
|
- !ScalarExpression
|
||||||
scalar_arg: A
|
scalar_arg: A
|
||||||
is_unsigned_cast: true
|
|
||||||
- !ScalarExpression
|
- !ScalarExpression
|
||||||
symbolic_cast:
|
type_fn:
|
||||||
|
fn_name: cast_unsigned
|
||||||
type_var: U
|
type_var: U
|
||||||
operands:
|
operands:
|
||||||
- !ScalarExpression
|
- !ScalarExpression
|
||||||
scalar_arg: B
|
scalar_arg: B
|
||||||
is_unsigned_cast: true
|
|
||||||
--- !LinalgOpConfig
|
--- !LinalgOpConfig
|
||||||
metadata: !LinalgOpMetadata
|
metadata: !LinalgOpMetadata
|
||||||
name: quantized_matmul
|
name: quantized_matmul
|
||||||
@ -193,37 +193,37 @@ structured_op: !LinalgStructuredOpConfig
|
|||||||
fn_name: sub
|
fn_name: sub
|
||||||
operands:
|
operands:
|
||||||
- !ScalarExpression
|
- !ScalarExpression
|
||||||
symbolic_cast:
|
type_fn:
|
||||||
|
fn_name: cast
|
||||||
type_var: U
|
type_var: U
|
||||||
operands:
|
operands:
|
||||||
- !ScalarExpression
|
- !ScalarExpression
|
||||||
scalar_arg: A
|
scalar_arg: A
|
||||||
is_unsigned_cast: false
|
|
||||||
- !ScalarExpression
|
- !ScalarExpression
|
||||||
symbolic_cast:
|
type_fn:
|
||||||
|
fn_name: cast
|
||||||
type_var: U
|
type_var: U
|
||||||
operands:
|
operands:
|
||||||
- !ScalarExpression
|
- !ScalarExpression
|
||||||
scalar_arg: AZp
|
scalar_arg: AZp
|
||||||
is_unsigned_cast: false
|
|
||||||
- !ScalarExpression
|
- !ScalarExpression
|
||||||
scalar_apply:
|
scalar_apply:
|
||||||
fn_name: sub
|
fn_name: sub
|
||||||
operands:
|
operands:
|
||||||
- !ScalarExpression
|
- !ScalarExpression
|
||||||
symbolic_cast:
|
type_fn:
|
||||||
|
fn_name: cast
|
||||||
type_var: U
|
type_var: U
|
||||||
operands:
|
operands:
|
||||||
- !ScalarExpression
|
- !ScalarExpression
|
||||||
scalar_arg: B
|
scalar_arg: B
|
||||||
is_unsigned_cast: false
|
|
||||||
- !ScalarExpression
|
- !ScalarExpression
|
||||||
symbolic_cast:
|
type_fn:
|
||||||
|
fn_name: cast
|
||||||
type_var: U
|
type_var: U
|
||||||
operands:
|
operands:
|
||||||
- !ScalarExpression
|
- !ScalarExpression
|
||||||
scalar_arg: BZp
|
scalar_arg: BZp
|
||||||
is_unsigned_cast: false
|
|
||||||
--- !LinalgOpConfig
|
--- !LinalgOpConfig
|
||||||
metadata: !LinalgOpMetadata
|
metadata: !LinalgOpMetadata
|
||||||
name: mmt4d
|
name: mmt4d
|
||||||
@ -286,19 +286,19 @@ structured_op: !LinalgStructuredOpConfig
|
|||||||
fn_name: mul
|
fn_name: mul
|
||||||
operands:
|
operands:
|
||||||
- !ScalarExpression
|
- !ScalarExpression
|
||||||
symbolic_cast:
|
type_fn:
|
||||||
|
fn_name: cast
|
||||||
type_var: AccumType
|
type_var: AccumType
|
||||||
operands:
|
operands:
|
||||||
- !ScalarExpression
|
- !ScalarExpression
|
||||||
scalar_arg: lhs
|
scalar_arg: lhs
|
||||||
is_unsigned_cast: false
|
|
||||||
- !ScalarExpression
|
- !ScalarExpression
|
||||||
symbolic_cast:
|
type_fn:
|
||||||
|
fn_name: cast
|
||||||
type_var: AccumType
|
type_var: AccumType
|
||||||
operands:
|
operands:
|
||||||
- !ScalarExpression
|
- !ScalarExpression
|
||||||
scalar_arg: rhs
|
scalar_arg: rhs
|
||||||
is_unsigned_cast: false
|
|
||||||
--- !LinalgOpConfig
|
--- !LinalgOpConfig
|
||||||
metadata: !LinalgOpMetadata
|
metadata: !LinalgOpMetadata
|
||||||
name: batch_matmul
|
name: batch_matmul
|
||||||
@ -351,19 +351,19 @@ structured_op: !LinalgStructuredOpConfig
|
|||||||
fn_name: mul
|
fn_name: mul
|
||||||
operands:
|
operands:
|
||||||
- !ScalarExpression
|
- !ScalarExpression
|
||||||
symbolic_cast:
|
type_fn:
|
||||||
|
fn_name: cast
|
||||||
type_var: U
|
type_var: U
|
||||||
operands:
|
operands:
|
||||||
- !ScalarExpression
|
- !ScalarExpression
|
||||||
scalar_arg: A
|
scalar_arg: A
|
||||||
is_unsigned_cast: false
|
|
||||||
- !ScalarExpression
|
- !ScalarExpression
|
||||||
symbolic_cast:
|
type_fn:
|
||||||
|
fn_name: cast
|
||||||
type_var: U
|
type_var: U
|
||||||
operands:
|
operands:
|
||||||
- !ScalarExpression
|
- !ScalarExpression
|
||||||
scalar_arg: B
|
scalar_arg: B
|
||||||
is_unsigned_cast: false
|
|
||||||
--- !LinalgOpConfig
|
--- !LinalgOpConfig
|
||||||
metadata: !LinalgOpMetadata
|
metadata: !LinalgOpMetadata
|
||||||
name: quantized_batch_matmul
|
name: quantized_batch_matmul
|
||||||
@ -430,37 +430,37 @@ structured_op: !LinalgStructuredOpConfig
|
|||||||
fn_name: sub
|
fn_name: sub
|
||||||
operands:
|
operands:
|
||||||
- !ScalarExpression
|
- !ScalarExpression
|
||||||
symbolic_cast:
|
type_fn:
|
||||||
|
fn_name: cast
|
||||||
type_var: U
|
type_var: U
|
||||||
operands:
|
operands:
|
||||||
- !ScalarExpression
|
- !ScalarExpression
|
||||||
scalar_arg: A
|
scalar_arg: A
|
||||||
is_unsigned_cast: false
|
|
||||||
- !ScalarExpression
|
- !ScalarExpression
|
||||||
symbolic_cast:
|
type_fn:
|
||||||
|
fn_name: cast
|
||||||
type_var: U
|
type_var: U
|
||||||
operands:
|
operands:
|
||||||
- !ScalarExpression
|
- !ScalarExpression
|
||||||
scalar_arg: AZp
|
scalar_arg: AZp
|
||||||
is_unsigned_cast: false
|
|
||||||
- !ScalarExpression
|
- !ScalarExpression
|
||||||
scalar_apply:
|
scalar_apply:
|
||||||
fn_name: sub
|
fn_name: sub
|
||||||
operands:
|
operands:
|
||||||
- !ScalarExpression
|
- !ScalarExpression
|
||||||
symbolic_cast:
|
type_fn:
|
||||||
|
fn_name: cast
|
||||||
type_var: U
|
type_var: U
|
||||||
operands:
|
operands:
|
||||||
- !ScalarExpression
|
- !ScalarExpression
|
||||||
scalar_arg: B
|
scalar_arg: B
|
||||||
is_unsigned_cast: false
|
|
||||||
- !ScalarExpression
|
- !ScalarExpression
|
||||||
symbolic_cast:
|
type_fn:
|
||||||
|
fn_name: cast
|
||||||
type_var: U
|
type_var: U
|
||||||
operands:
|
operands:
|
||||||
- !ScalarExpression
|
- !ScalarExpression
|
||||||
scalar_arg: BZp
|
scalar_arg: BZp
|
||||||
is_unsigned_cast: false
|
|
||||||
--- !LinalgOpConfig
|
--- !LinalgOpConfig
|
||||||
metadata: !LinalgOpMetadata
|
metadata: !LinalgOpMetadata
|
||||||
name: matvec
|
name: matvec
|
||||||
@ -511,19 +511,19 @@ structured_op: !LinalgStructuredOpConfig
|
|||||||
fn_name: mul
|
fn_name: mul
|
||||||
operands:
|
operands:
|
||||||
- !ScalarExpression
|
- !ScalarExpression
|
||||||
symbolic_cast:
|
type_fn:
|
||||||
|
fn_name: cast
|
||||||
type_var: U
|
type_var: U
|
||||||
operands:
|
operands:
|
||||||
- !ScalarExpression
|
- !ScalarExpression
|
||||||
scalar_arg: A
|
scalar_arg: A
|
||||||
is_unsigned_cast: false
|
|
||||||
- !ScalarExpression
|
- !ScalarExpression
|
||||||
symbolic_cast:
|
type_fn:
|
||||||
|
fn_name: cast
|
||||||
type_var: U
|
type_var: U
|
||||||
operands:
|
operands:
|
||||||
- !ScalarExpression
|
- !ScalarExpression
|
||||||
scalar_arg: y
|
scalar_arg: y
|
||||||
is_unsigned_cast: false
|
|
||||||
--- !LinalgOpConfig
|
--- !LinalgOpConfig
|
||||||
metadata: !LinalgOpMetadata
|
metadata: !LinalgOpMetadata
|
||||||
name: vecmat
|
name: vecmat
|
||||||
@ -574,19 +574,19 @@ structured_op: !LinalgStructuredOpConfig
|
|||||||
fn_name: mul
|
fn_name: mul
|
||||||
operands:
|
operands:
|
||||||
- !ScalarExpression
|
- !ScalarExpression
|
||||||
symbolic_cast:
|
type_fn:
|
||||||
|
fn_name: cast
|
||||||
type_var: U
|
type_var: U
|
||||||
operands:
|
operands:
|
||||||
- !ScalarExpression
|
- !ScalarExpression
|
||||||
scalar_arg: y
|
scalar_arg: y
|
||||||
is_unsigned_cast: false
|
|
||||||
- !ScalarExpression
|
- !ScalarExpression
|
||||||
symbolic_cast:
|
type_fn:
|
||||||
|
fn_name: cast
|
||||||
type_var: U
|
type_var: U
|
||||||
operands:
|
operands:
|
||||||
- !ScalarExpression
|
- !ScalarExpression
|
||||||
scalar_arg: A
|
scalar_arg: A
|
||||||
is_unsigned_cast: false
|
|
||||||
--- !LinalgOpConfig
|
--- !LinalgOpConfig
|
||||||
metadata: !LinalgOpMetadata
|
metadata: !LinalgOpMetadata
|
||||||
name: batch_matvec
|
name: batch_matvec
|
||||||
@ -638,19 +638,19 @@ structured_op: !LinalgStructuredOpConfig
|
|||||||
fn_name: mul
|
fn_name: mul
|
||||||
operands:
|
operands:
|
||||||
- !ScalarExpression
|
- !ScalarExpression
|
||||||
symbolic_cast:
|
type_fn:
|
||||||
|
fn_name: cast
|
||||||
type_var: U
|
type_var: U
|
||||||
operands:
|
operands:
|
||||||
- !ScalarExpression
|
- !ScalarExpression
|
||||||
scalar_arg: A
|
scalar_arg: A
|
||||||
is_unsigned_cast: false
|
|
||||||
- !ScalarExpression
|
- !ScalarExpression
|
||||||
symbolic_cast:
|
type_fn:
|
||||||
|
fn_name: cast
|
||||||
type_var: U
|
type_var: U
|
||||||
operands:
|
operands:
|
||||||
- !ScalarExpression
|
- !ScalarExpression
|
||||||
scalar_arg: B
|
scalar_arg: B
|
||||||
is_unsigned_cast: false
|
|
||||||
--- !LinalgOpConfig
|
--- !LinalgOpConfig
|
||||||
metadata: !LinalgOpMetadata
|
metadata: !LinalgOpMetadata
|
||||||
name: dot
|
name: dot
|
||||||
@ -700,19 +700,19 @@ structured_op: !LinalgStructuredOpConfig
|
|||||||
fn_name: mul
|
fn_name: mul
|
||||||
operands:
|
operands:
|
||||||
- !ScalarExpression
|
- !ScalarExpression
|
||||||
symbolic_cast:
|
type_fn:
|
||||||
|
fn_name: cast
|
||||||
type_var: U
|
type_var: U
|
||||||
operands:
|
operands:
|
||||||
- !ScalarExpression
|
- !ScalarExpression
|
||||||
scalar_arg: A
|
scalar_arg: A
|
||||||
is_unsigned_cast: false
|
|
||||||
- !ScalarExpression
|
- !ScalarExpression
|
||||||
symbolic_cast:
|
type_fn:
|
||||||
|
fn_name: cast
|
||||||
type_var: U
|
type_var: U
|
||||||
operands:
|
operands:
|
||||||
- !ScalarExpression
|
- !ScalarExpression
|
||||||
scalar_arg: B
|
scalar_arg: B
|
||||||
is_unsigned_cast: false
|
|
||||||
--- !LinalgOpConfig
|
--- !LinalgOpConfig
|
||||||
metadata: !LinalgOpMetadata
|
metadata: !LinalgOpMetadata
|
||||||
name: conv_1d
|
name: conv_1d
|
||||||
@ -763,19 +763,19 @@ structured_op: !LinalgStructuredOpConfig
|
|||||||
fn_name: mul
|
fn_name: mul
|
||||||
operands:
|
operands:
|
||||||
- !ScalarExpression
|
- !ScalarExpression
|
||||||
symbolic_cast:
|
type_fn:
|
||||||
|
fn_name: cast
|
||||||
type_var: U
|
type_var: U
|
||||||
operands:
|
operands:
|
||||||
- !ScalarExpression
|
- !ScalarExpression
|
||||||
scalar_arg: I
|
scalar_arg: I
|
||||||
is_unsigned_cast: false
|
|
||||||
- !ScalarExpression
|
- !ScalarExpression
|
||||||
symbolic_cast:
|
type_fn:
|
||||||
|
fn_name: cast
|
||||||
type_var: U
|
type_var: U
|
||||||
operands:
|
operands:
|
||||||
- !ScalarExpression
|
- !ScalarExpression
|
||||||
scalar_arg: K
|
scalar_arg: K
|
||||||
is_unsigned_cast: false
|
|
||||||
--- !LinalgOpConfig
|
--- !LinalgOpConfig
|
||||||
metadata: !LinalgOpMetadata
|
metadata: !LinalgOpMetadata
|
||||||
name: conv_2d
|
name: conv_2d
|
||||||
@ -828,19 +828,19 @@ structured_op: !LinalgStructuredOpConfig
|
|||||||
fn_name: mul
|
fn_name: mul
|
||||||
operands:
|
operands:
|
||||||
- !ScalarExpression
|
- !ScalarExpression
|
||||||
symbolic_cast:
|
type_fn:
|
||||||
|
fn_name: cast
|
||||||
type_var: U
|
type_var: U
|
||||||
operands:
|
operands:
|
||||||
- !ScalarExpression
|
- !ScalarExpression
|
||||||
scalar_arg: I
|
scalar_arg: I
|
||||||
is_unsigned_cast: false
|
|
||||||
- !ScalarExpression
|
- !ScalarExpression
|
||||||
symbolic_cast:
|
type_fn:
|
||||||
|
fn_name: cast
|
||||||
type_var: U
|
type_var: U
|
||||||
operands:
|
operands:
|
||||||
- !ScalarExpression
|
- !ScalarExpression
|
||||||
scalar_arg: K
|
scalar_arg: K
|
||||||
is_unsigned_cast: false
|
|
||||||
--- !LinalgOpConfig
|
--- !LinalgOpConfig
|
||||||
metadata: !LinalgOpMetadata
|
metadata: !LinalgOpMetadata
|
||||||
name: conv_3d
|
name: conv_3d
|
||||||
@ -896,19 +896,19 @@ structured_op: !LinalgStructuredOpConfig
|
|||||||
fn_name: mul
|
fn_name: mul
|
||||||
operands:
|
operands:
|
||||||
- !ScalarExpression
|
- !ScalarExpression
|
||||||
symbolic_cast:
|
type_fn:
|
||||||
|
fn_name: cast
|
||||||
type_var: U
|
type_var: U
|
||||||
operands:
|
operands:
|
||||||
- !ScalarExpression
|
- !ScalarExpression
|
||||||
scalar_arg: I
|
scalar_arg: I
|
||||||
is_unsigned_cast: false
|
|
||||||
- !ScalarExpression
|
- !ScalarExpression
|
||||||
symbolic_cast:
|
type_fn:
|
||||||
|
fn_name: cast
|
||||||
type_var: U
|
type_var: U
|
||||||
operands:
|
operands:
|
||||||
- !ScalarExpression
|
- !ScalarExpression
|
||||||
scalar_arg: K
|
scalar_arg: K
|
||||||
is_unsigned_cast: false
|
|
||||||
--- !LinalgOpConfig
|
--- !LinalgOpConfig
|
||||||
metadata: !LinalgOpMetadata
|
metadata: !LinalgOpMetadata
|
||||||
name: conv_1d_nwc_wcf
|
name: conv_1d_nwc_wcf
|
||||||
@ -974,19 +974,19 @@ structured_op: !LinalgStructuredOpConfig
|
|||||||
fn_name: mul
|
fn_name: mul
|
||||||
operands:
|
operands:
|
||||||
- !ScalarExpression
|
- !ScalarExpression
|
||||||
symbolic_cast:
|
type_fn:
|
||||||
|
fn_name: cast
|
||||||
type_var: U
|
type_var: U
|
||||||
operands:
|
operands:
|
||||||
- !ScalarExpression
|
- !ScalarExpression
|
||||||
scalar_arg: I
|
scalar_arg: I
|
||||||
is_unsigned_cast: false
|
|
||||||
- !ScalarExpression
|
- !ScalarExpression
|
||||||
symbolic_cast:
|
type_fn:
|
||||||
|
fn_name: cast
|
||||||
type_var: U
|
type_var: U
|
||||||
operands:
|
operands:
|
||||||
- !ScalarExpression
|
- !ScalarExpression
|
||||||
scalar_arg: K
|
scalar_arg: K
|
||||||
is_unsigned_cast: false
|
|
||||||
--- !LinalgOpConfig
|
--- !LinalgOpConfig
|
||||||
metadata: !LinalgOpMetadata
|
metadata: !LinalgOpMetadata
|
||||||
name: conv_2d_nhwc_hwcf
|
name: conv_2d_nhwc_hwcf
|
||||||
@ -1064,19 +1064,19 @@ structured_op: !LinalgStructuredOpConfig
|
|||||||
fn_name: mul
|
fn_name: mul
|
||||||
operands:
|
operands:
|
||||||
- !ScalarExpression
|
- !ScalarExpression
|
||||||
symbolic_cast:
|
type_fn:
|
||||||
|
fn_name: cast
|
||||||
type_var: U
|
type_var: U
|
||||||
operands:
|
operands:
|
||||||
- !ScalarExpression
|
- !ScalarExpression
|
||||||
scalar_arg: I
|
scalar_arg: I
|
||||||
is_unsigned_cast: false
|
|
||||||
- !ScalarExpression
|
- !ScalarExpression
|
||||||
symbolic_cast:
|
type_fn:
|
||||||
|
fn_name: cast
|
||||||
type_var: U
|
type_var: U
|
||||||
operands:
|
operands:
|
||||||
- !ScalarExpression
|
- !ScalarExpression
|
||||||
scalar_arg: K
|
scalar_arg: K
|
||||||
is_unsigned_cast: false
|
|
||||||
--- !LinalgOpConfig
|
--- !LinalgOpConfig
|
||||||
metadata: !LinalgOpMetadata
|
metadata: !LinalgOpMetadata
|
||||||
name: conv_2d_nhwc_hwcf_q
|
name: conv_2d_nhwc_hwcf_q
|
||||||
@ -1171,37 +1171,37 @@ structured_op: !LinalgStructuredOpConfig
|
|||||||
fn_name: sub
|
fn_name: sub
|
||||||
operands:
|
operands:
|
||||||
- !ScalarExpression
|
- !ScalarExpression
|
||||||
symbolic_cast:
|
type_fn:
|
||||||
|
fn_name: cast
|
||||||
type_var: U
|
type_var: U
|
||||||
operands:
|
operands:
|
||||||
- !ScalarExpression
|
- !ScalarExpression
|
||||||
scalar_arg: I
|
scalar_arg: I
|
||||||
is_unsigned_cast: false
|
|
||||||
- !ScalarExpression
|
- !ScalarExpression
|
||||||
symbolic_cast:
|
type_fn:
|
||||||
|
fn_name: cast
|
||||||
type_var: U
|
type_var: U
|
||||||
operands:
|
operands:
|
||||||
- !ScalarExpression
|
- !ScalarExpression
|
||||||
scalar_arg: IZp
|
scalar_arg: IZp
|
||||||
is_unsigned_cast: false
|
|
||||||
- !ScalarExpression
|
- !ScalarExpression
|
||||||
scalar_apply:
|
scalar_apply:
|
||||||
fn_name: sub
|
fn_name: sub
|
||||||
operands:
|
operands:
|
||||||
- !ScalarExpression
|
- !ScalarExpression
|
||||||
symbolic_cast:
|
type_fn:
|
||||||
|
fn_name: cast
|
||||||
type_var: U
|
type_var: U
|
||||||
operands:
|
operands:
|
||||||
- !ScalarExpression
|
- !ScalarExpression
|
||||||
scalar_arg: K
|
scalar_arg: K
|
||||||
is_unsigned_cast: false
|
|
||||||
- !ScalarExpression
|
- !ScalarExpression
|
||||||
symbolic_cast:
|
type_fn:
|
||||||
|
fn_name: cast
|
||||||
type_var: U
|
type_var: U
|
||||||
operands:
|
operands:
|
||||||
- !ScalarExpression
|
- !ScalarExpression
|
||||||
scalar_arg: KZp
|
scalar_arg: KZp
|
||||||
is_unsigned_cast: false
|
|
||||||
--- !LinalgOpConfig
|
--- !LinalgOpConfig
|
||||||
metadata: !LinalgOpMetadata
|
metadata: !LinalgOpMetadata
|
||||||
name: conv_2d_nchw_fchw
|
name: conv_2d_nchw_fchw
|
||||||
@ -1279,19 +1279,19 @@ structured_op: !LinalgStructuredOpConfig
|
|||||||
fn_name: mul
|
fn_name: mul
|
||||||
operands:
|
operands:
|
||||||
- !ScalarExpression
|
- !ScalarExpression
|
||||||
symbolic_cast:
|
type_fn:
|
||||||
|
fn_name: cast
|
||||||
type_var: U
|
type_var: U
|
||||||
operands:
|
operands:
|
||||||
- !ScalarExpression
|
- !ScalarExpression
|
||||||
scalar_arg: I
|
scalar_arg: I
|
||||||
is_unsigned_cast: false
|
|
||||||
- !ScalarExpression
|
- !ScalarExpression
|
||||||
symbolic_cast:
|
type_fn:
|
||||||
|
fn_name: cast
|
||||||
type_var: U
|
type_var: U
|
||||||
operands:
|
operands:
|
||||||
- !ScalarExpression
|
- !ScalarExpression
|
||||||
scalar_arg: K
|
scalar_arg: K
|
||||||
is_unsigned_cast: false
|
|
||||||
--- !LinalgOpConfig
|
--- !LinalgOpConfig
|
||||||
metadata: !LinalgOpMetadata
|
metadata: !LinalgOpMetadata
|
||||||
name: conv_3d_ndhwc_dhwcf
|
name: conv_3d_ndhwc_dhwcf
|
||||||
@ -1369,19 +1369,19 @@ structured_op: !LinalgStructuredOpConfig
|
|||||||
fn_name: mul
|
fn_name: mul
|
||||||
operands:
|
operands:
|
||||||
- !ScalarExpression
|
- !ScalarExpression
|
||||||
symbolic_cast:
|
type_fn:
|
||||||
|
fn_name: cast
|
||||||
type_var: U
|
type_var: U
|
||||||
operands:
|
operands:
|
||||||
- !ScalarExpression
|
- !ScalarExpression
|
||||||
scalar_arg: I
|
scalar_arg: I
|
||||||
is_unsigned_cast: false
|
|
||||||
- !ScalarExpression
|
- !ScalarExpression
|
||||||
symbolic_cast:
|
type_fn:
|
||||||
|
fn_name: cast
|
||||||
type_var: U
|
type_var: U
|
||||||
operands:
|
operands:
|
||||||
- !ScalarExpression
|
- !ScalarExpression
|
||||||
scalar_arg: K
|
scalar_arg: K
|
||||||
is_unsigned_cast: false
|
|
||||||
--- !LinalgOpConfig
|
--- !LinalgOpConfig
|
||||||
metadata: !LinalgOpMetadata
|
metadata: !LinalgOpMetadata
|
||||||
name: depthwise_conv_1d_nwc_wc
|
name: depthwise_conv_1d_nwc_wc
|
||||||
@ -1446,19 +1446,19 @@ structured_op: !LinalgStructuredOpConfig
|
|||||||
fn_name: mul
|
fn_name: mul
|
||||||
operands:
|
operands:
|
||||||
- !ScalarExpression
|
- !ScalarExpression
|
||||||
symbolic_cast:
|
type_fn:
|
||||||
|
fn_name: cast
|
||||||
type_var: U
|
type_var: U
|
||||||
operands:
|
operands:
|
||||||
- !ScalarExpression
|
- !ScalarExpression
|
||||||
scalar_arg: I
|
scalar_arg: I
|
||||||
is_unsigned_cast: false
|
|
||||||
- !ScalarExpression
|
- !ScalarExpression
|
||||||
symbolic_cast:
|
type_fn:
|
||||||
|
fn_name: cast
|
||||||
type_var: U
|
type_var: U
|
||||||
operands:
|
operands:
|
||||||
- !ScalarExpression
|
- !ScalarExpression
|
||||||
scalar_arg: K
|
scalar_arg: K
|
||||||
is_unsigned_cast: false
|
|
||||||
--- !LinalgOpConfig
|
--- !LinalgOpConfig
|
||||||
metadata: !LinalgOpMetadata
|
metadata: !LinalgOpMetadata
|
||||||
name: depthwise_conv_2d_nhwc_hwc
|
name: depthwise_conv_2d_nhwc_hwc
|
||||||
@ -1529,19 +1529,19 @@ structured_op: !LinalgStructuredOpConfig
|
|||||||
fn_name: mul
|
fn_name: mul
|
||||||
operands:
|
operands:
|
||||||
- !ScalarExpression
|
- !ScalarExpression
|
||||||
symbolic_cast:
|
type_fn:
|
||||||
|
fn_name: cast
|
||||||
type_var: U
|
type_var: U
|
||||||
operands:
|
operands:
|
||||||
- !ScalarExpression
|
- !ScalarExpression
|
||||||
scalar_arg: I
|
scalar_arg: I
|
||||||
is_unsigned_cast: false
|
|
||||||
- !ScalarExpression
|
- !ScalarExpression
|
||||||
symbolic_cast:
|
type_fn:
|
||||||
|
fn_name: cast
|
||||||
type_var: U
|
type_var: U
|
||||||
operands:
|
operands:
|
||||||
- !ScalarExpression
|
- !ScalarExpression
|
||||||
scalar_arg: K
|
scalar_arg: K
|
||||||
is_unsigned_cast: false
|
|
||||||
--- !LinalgOpConfig
|
--- !LinalgOpConfig
|
||||||
metadata: !LinalgOpMetadata
|
metadata: !LinalgOpMetadata
|
||||||
name: depthwise_conv_2d_nhwc_hwc_q
|
name: depthwise_conv_2d_nhwc_hwc_q
|
||||||
@ -1627,37 +1627,37 @@ structured_op: !LinalgStructuredOpConfig
|
|||||||
fn_name: sub
|
fn_name: sub
|
||||||
operands:
|
operands:
|
||||||
- !ScalarExpression
|
- !ScalarExpression
|
||||||
symbolic_cast:
|
type_fn:
|
||||||
|
fn_name: cast
|
||||||
type_var: U
|
type_var: U
|
||||||
operands:
|
operands:
|
||||||
- !ScalarExpression
|
- !ScalarExpression
|
||||||
scalar_arg: I
|
scalar_arg: I
|
||||||
is_unsigned_cast: false
|
|
||||||
- !ScalarExpression
|
- !ScalarExpression
|
||||||
symbolic_cast:
|
type_fn:
|
||||||
|
fn_name: cast
|
||||||
type_var: U
|
type_var: U
|
||||||
operands:
|
operands:
|
||||||
- !ScalarExpression
|
- !ScalarExpression
|
||||||
scalar_arg: IZp
|
scalar_arg: IZp
|
||||||
is_unsigned_cast: false
|
|
||||||
- !ScalarExpression
|
- !ScalarExpression
|
||||||
scalar_apply:
|
scalar_apply:
|
||||||
fn_name: sub
|
fn_name: sub
|
||||||
operands:
|
operands:
|
||||||
- !ScalarExpression
|
- !ScalarExpression
|
||||||
symbolic_cast:
|
type_fn:
|
||||||
|
fn_name: cast
|
||||||
type_var: U
|
type_var: U
|
||||||
operands:
|
operands:
|
||||||
- !ScalarExpression
|
- !ScalarExpression
|
||||||
scalar_arg: K
|
scalar_arg: K
|
||||||
is_unsigned_cast: false
|
|
||||||
- !ScalarExpression
|
- !ScalarExpression
|
||||||
symbolic_cast:
|
type_fn:
|
||||||
|
fn_name: cast
|
||||||
type_var: U
|
type_var: U
|
||||||
operands:
|
operands:
|
||||||
- !ScalarExpression
|
- !ScalarExpression
|
||||||
scalar_arg: KZp
|
scalar_arg: KZp
|
||||||
is_unsigned_cast: false
|
|
||||||
--- !LinalgOpConfig
|
--- !LinalgOpConfig
|
||||||
metadata: !LinalgOpMetadata
|
metadata: !LinalgOpMetadata
|
||||||
name: depthwise_conv_2d_nhwc_hwcm
|
name: depthwise_conv_2d_nhwc_hwcm
|
||||||
@ -1731,19 +1731,19 @@ structured_op: !LinalgStructuredOpConfig
|
|||||||
fn_name: mul
|
fn_name: mul
|
||||||
operands:
|
operands:
|
||||||
- !ScalarExpression
|
- !ScalarExpression
|
||||||
symbolic_cast:
|
type_fn:
|
||||||
|
fn_name: cast
|
||||||
type_var: U
|
type_var: U
|
||||||
operands:
|
operands:
|
||||||
- !ScalarExpression
|
- !ScalarExpression
|
||||||
scalar_arg: I
|
scalar_arg: I
|
||||||
is_unsigned_cast: false
|
|
||||||
- !ScalarExpression
|
- !ScalarExpression
|
||||||
symbolic_cast:
|
type_fn:
|
||||||
|
fn_name: cast
|
||||||
type_var: U
|
type_var: U
|
||||||
operands:
|
operands:
|
||||||
- !ScalarExpression
|
- !ScalarExpression
|
||||||
scalar_arg: K
|
scalar_arg: K
|
||||||
is_unsigned_cast: false
|
|
||||||
--- !LinalgOpConfig
|
--- !LinalgOpConfig
|
||||||
metadata: !LinalgOpMetadata
|
metadata: !LinalgOpMetadata
|
||||||
name: depthwise_conv_2d_nhwc_hwcm_q
|
name: depthwise_conv_2d_nhwc_hwcm_q
|
||||||
@ -1833,37 +1833,37 @@ structured_op: !LinalgStructuredOpConfig
|
|||||||
fn_name: sub
|
fn_name: sub
|
||||||
operands:
|
operands:
|
||||||
- !ScalarExpression
|
- !ScalarExpression
|
||||||
symbolic_cast:
|
type_fn:
|
||||||
|
fn_name: cast
|
||||||
type_var: U
|
type_var: U
|
||||||
operands:
|
operands:
|
||||||
- !ScalarExpression
|
- !ScalarExpression
|
||||||
scalar_arg: I
|
scalar_arg: I
|
||||||
is_unsigned_cast: false
|
|
||||||
- !ScalarExpression
|
- !ScalarExpression
|
||||||
symbolic_cast:
|
type_fn:
|
||||||
|
fn_name: cast
|
||||||
type_var: U
|
type_var: U
|
||||||
operands:
|
operands:
|
||||||
- !ScalarExpression
|
- !ScalarExpression
|
||||||
scalar_arg: IZp
|
scalar_arg: IZp
|
||||||
is_unsigned_cast: false
|
|
||||||
- !ScalarExpression
|
- !ScalarExpression
|
||||||
scalar_apply:
|
scalar_apply:
|
||||||
fn_name: sub
|
fn_name: sub
|
||||||
operands:
|
operands:
|
||||||
- !ScalarExpression
|
- !ScalarExpression
|
||||||
symbolic_cast:
|
type_fn:
|
||||||
|
fn_name: cast
|
||||||
type_var: U
|
type_var: U
|
||||||
operands:
|
operands:
|
||||||
- !ScalarExpression
|
- !ScalarExpression
|
||||||
scalar_arg: K
|
scalar_arg: K
|
||||||
is_unsigned_cast: false
|
|
||||||
- !ScalarExpression
|
- !ScalarExpression
|
||||||
symbolic_cast:
|
type_fn:
|
||||||
|
fn_name: cast
|
||||||
type_var: U
|
type_var: U
|
||||||
operands:
|
operands:
|
||||||
- !ScalarExpression
|
- !ScalarExpression
|
||||||
scalar_arg: KZp
|
scalar_arg: KZp
|
||||||
is_unsigned_cast: false
|
|
||||||
--- !LinalgOpConfig
|
--- !LinalgOpConfig
|
||||||
metadata: !LinalgOpMetadata
|
metadata: !LinalgOpMetadata
|
||||||
name: pooling_nhwc_sum
|
name: pooling_nhwc_sum
|
||||||
@ -1929,12 +1929,12 @@ structured_op: !LinalgStructuredOpConfig
|
|||||||
- !ScalarExpression
|
- !ScalarExpression
|
||||||
scalar_arg: O
|
scalar_arg: O
|
||||||
- !ScalarExpression
|
- !ScalarExpression
|
||||||
symbolic_cast:
|
type_fn:
|
||||||
|
fn_name: cast
|
||||||
type_var: U
|
type_var: U
|
||||||
operands:
|
operands:
|
||||||
- !ScalarExpression
|
- !ScalarExpression
|
||||||
scalar_arg: I
|
scalar_arg: I
|
||||||
is_unsigned_cast: false
|
|
||||||
--- !LinalgOpConfig
|
--- !LinalgOpConfig
|
||||||
metadata: !LinalgOpMetadata
|
metadata: !LinalgOpMetadata
|
||||||
name: pooling_nhwc_max
|
name: pooling_nhwc_max
|
||||||
@ -2000,12 +2000,12 @@ structured_op: !LinalgStructuredOpConfig
|
|||||||
- !ScalarExpression
|
- !ScalarExpression
|
||||||
scalar_arg: O
|
scalar_arg: O
|
||||||
- !ScalarExpression
|
- !ScalarExpression
|
||||||
symbolic_cast:
|
type_fn:
|
||||||
|
fn_name: cast
|
||||||
type_var: U
|
type_var: U
|
||||||
operands:
|
operands:
|
||||||
- !ScalarExpression
|
- !ScalarExpression
|
||||||
scalar_arg: I
|
scalar_arg: I
|
||||||
is_unsigned_cast: false
|
|
||||||
--- !LinalgOpConfig
|
--- !LinalgOpConfig
|
||||||
metadata: !LinalgOpMetadata
|
metadata: !LinalgOpMetadata
|
||||||
name: pooling_nhwc_max_unsigned
|
name: pooling_nhwc_max_unsigned
|
||||||
@ -2071,12 +2071,12 @@ structured_op: !LinalgStructuredOpConfig
|
|||||||
- !ScalarExpression
|
- !ScalarExpression
|
||||||
scalar_arg: O
|
scalar_arg: O
|
||||||
- !ScalarExpression
|
- !ScalarExpression
|
||||||
symbolic_cast:
|
type_fn:
|
||||||
|
fn_name: cast_unsigned
|
||||||
type_var: U
|
type_var: U
|
||||||
operands:
|
operands:
|
||||||
- !ScalarExpression
|
- !ScalarExpression
|
||||||
scalar_arg: I
|
scalar_arg: I
|
||||||
is_unsigned_cast: true
|
|
||||||
--- !LinalgOpConfig
|
--- !LinalgOpConfig
|
||||||
metadata: !LinalgOpMetadata
|
metadata: !LinalgOpMetadata
|
||||||
name: pooling_nchw_max
|
name: pooling_nchw_max
|
||||||
@ -2142,12 +2142,12 @@ structured_op: !LinalgStructuredOpConfig
|
|||||||
- !ScalarExpression
|
- !ScalarExpression
|
||||||
scalar_arg: O
|
scalar_arg: O
|
||||||
- !ScalarExpression
|
- !ScalarExpression
|
||||||
symbolic_cast:
|
type_fn:
|
||||||
|
fn_name: cast
|
||||||
type_var: U
|
type_var: U
|
||||||
operands:
|
operands:
|
||||||
- !ScalarExpression
|
- !ScalarExpression
|
||||||
scalar_arg: I
|
scalar_arg: I
|
||||||
is_unsigned_cast: false
|
|
||||||
--- !LinalgOpConfig
|
--- !LinalgOpConfig
|
||||||
metadata: !LinalgOpMetadata
|
metadata: !LinalgOpMetadata
|
||||||
name: pooling_nhwc_min
|
name: pooling_nhwc_min
|
||||||
@ -2213,12 +2213,12 @@ structured_op: !LinalgStructuredOpConfig
|
|||||||
- !ScalarExpression
|
- !ScalarExpression
|
||||||
scalar_arg: O
|
scalar_arg: O
|
||||||
- !ScalarExpression
|
- !ScalarExpression
|
||||||
symbolic_cast:
|
type_fn:
|
||||||
|
fn_name: cast
|
||||||
type_var: U
|
type_var: U
|
||||||
operands:
|
operands:
|
||||||
- !ScalarExpression
|
- !ScalarExpression
|
||||||
scalar_arg: I
|
scalar_arg: I
|
||||||
is_unsigned_cast: false
|
|
||||||
--- !LinalgOpConfig
|
--- !LinalgOpConfig
|
||||||
metadata: !LinalgOpMetadata
|
metadata: !LinalgOpMetadata
|
||||||
name: pooling_nhwc_min_unsigned
|
name: pooling_nhwc_min_unsigned
|
||||||
@ -2284,12 +2284,12 @@ structured_op: !LinalgStructuredOpConfig
|
|||||||
- !ScalarExpression
|
- !ScalarExpression
|
||||||
scalar_arg: O
|
scalar_arg: O
|
||||||
- !ScalarExpression
|
- !ScalarExpression
|
||||||
symbolic_cast:
|
type_fn:
|
||||||
|
fn_name: cast_unsigned
|
||||||
type_var: U
|
type_var: U
|
||||||
operands:
|
operands:
|
||||||
- !ScalarExpression
|
- !ScalarExpression
|
||||||
scalar_arg: I
|
scalar_arg: I
|
||||||
is_unsigned_cast: true
|
|
||||||
--- !LinalgOpConfig
|
--- !LinalgOpConfig
|
||||||
metadata: !LinalgOpMetadata
|
metadata: !LinalgOpMetadata
|
||||||
name: pooling_ndhwc_sum
|
name: pooling_ndhwc_sum
|
||||||
@ -2361,12 +2361,12 @@ structured_op: !LinalgStructuredOpConfig
|
|||||||
- !ScalarExpression
|
- !ScalarExpression
|
||||||
scalar_arg: O
|
scalar_arg: O
|
||||||
- !ScalarExpression
|
- !ScalarExpression
|
||||||
symbolic_cast:
|
type_fn:
|
||||||
|
fn_name: cast
|
||||||
type_var: U
|
type_var: U
|
||||||
operands:
|
operands:
|
||||||
- !ScalarExpression
|
- !ScalarExpression
|
||||||
scalar_arg: I
|
scalar_arg: I
|
||||||
is_unsigned_cast: false
|
|
||||||
--- !LinalgOpConfig
|
--- !LinalgOpConfig
|
||||||
metadata: !LinalgOpMetadata
|
metadata: !LinalgOpMetadata
|
||||||
name: pooling_ndhwc_max
|
name: pooling_ndhwc_max
|
||||||
@ -2438,12 +2438,12 @@ structured_op: !LinalgStructuredOpConfig
|
|||||||
- !ScalarExpression
|
- !ScalarExpression
|
||||||
scalar_arg: O
|
scalar_arg: O
|
||||||
- !ScalarExpression
|
- !ScalarExpression
|
||||||
symbolic_cast:
|
type_fn:
|
||||||
|
fn_name: cast
|
||||||
type_var: U
|
type_var: U
|
||||||
operands:
|
operands:
|
||||||
- !ScalarExpression
|
- !ScalarExpression
|
||||||
scalar_arg: I
|
scalar_arg: I
|
||||||
is_unsigned_cast: false
|
|
||||||
--- !LinalgOpConfig
|
--- !LinalgOpConfig
|
||||||
metadata: !LinalgOpMetadata
|
metadata: !LinalgOpMetadata
|
||||||
name: pooling_ndhwc_min
|
name: pooling_ndhwc_min
|
||||||
@ -2515,12 +2515,12 @@ structured_op: !LinalgStructuredOpConfig
|
|||||||
- !ScalarExpression
|
- !ScalarExpression
|
||||||
scalar_arg: O
|
scalar_arg: O
|
||||||
- !ScalarExpression
|
- !ScalarExpression
|
||||||
symbolic_cast:
|
type_fn:
|
||||||
|
fn_name: cast
|
||||||
type_var: U
|
type_var: U
|
||||||
operands:
|
operands:
|
||||||
- !ScalarExpression
|
- !ScalarExpression
|
||||||
scalar_arg: I
|
scalar_arg: I
|
||||||
is_unsigned_cast: false
|
|
||||||
--- !LinalgOpConfig
|
--- !LinalgOpConfig
|
||||||
metadata: !LinalgOpMetadata
|
metadata: !LinalgOpMetadata
|
||||||
name: fill_rng_2d
|
name: fill_rng_2d
|
||||||
@ -2567,7 +2567,8 @@ structured_op: !LinalgStructuredOpConfig
|
|||||||
- !ScalarAssign
|
- !ScalarAssign
|
||||||
arg: O
|
arg: O
|
||||||
value: !ScalarExpression
|
value: !ScalarExpression
|
||||||
symbolic_cast:
|
type_fn:
|
||||||
|
fn_name: cast
|
||||||
type_var: T
|
type_var: T
|
||||||
operands:
|
operands:
|
||||||
- !ScalarExpression
|
- !ScalarExpression
|
||||||
@ -2583,14 +2584,15 @@ structured_op: !LinalgStructuredOpConfig
|
|||||||
fn_name: add
|
fn_name: add
|
||||||
operands:
|
operands:
|
||||||
- !ScalarExpression
|
- !ScalarExpression
|
||||||
symbolic_cast:
|
type_fn:
|
||||||
|
fn_name: cast
|
||||||
type_var: F64
|
type_var: F64
|
||||||
operands:
|
operands:
|
||||||
- !ScalarExpression
|
- !ScalarExpression
|
||||||
scalar_const: '2147483647 : i64'
|
scalar_const: '2147483647 : i64'
|
||||||
is_unsigned_cast: false
|
|
||||||
- !ScalarExpression
|
- !ScalarExpression
|
||||||
symbolic_cast:
|
type_fn:
|
||||||
|
fn_name: cast
|
||||||
type_var: F64
|
type_var: F64
|
||||||
operands:
|
operands:
|
||||||
- !ScalarExpression
|
- !ScalarExpression
|
||||||
@ -2606,12 +2608,12 @@ structured_op: !LinalgStructuredOpConfig
|
|||||||
fn_name: add
|
fn_name: add
|
||||||
operands:
|
operands:
|
||||||
- !ScalarExpression
|
- !ScalarExpression
|
||||||
symbolic_cast:
|
type_fn:
|
||||||
|
fn_name: cast
|
||||||
type_var: I32
|
type_var: I32
|
||||||
operands:
|
operands:
|
||||||
- !ScalarExpression
|
- !ScalarExpression
|
||||||
scalar_index: 1
|
scalar_index: 1
|
||||||
is_unsigned_cast: false
|
|
||||||
- !ScalarExpression
|
- !ScalarExpression
|
||||||
scalar_apply:
|
scalar_apply:
|
||||||
fn_name: add
|
fn_name: add
|
||||||
@ -2625,43 +2627,42 @@ structured_op: !LinalgStructuredOpConfig
|
|||||||
fn_name: add
|
fn_name: add
|
||||||
operands:
|
operands:
|
||||||
- !ScalarExpression
|
- !ScalarExpression
|
||||||
symbolic_cast:
|
type_fn:
|
||||||
|
fn_name: cast
|
||||||
type_var: I32
|
type_var: I32
|
||||||
operands:
|
operands:
|
||||||
- !ScalarExpression
|
- !ScalarExpression
|
||||||
scalar_index: 0
|
scalar_index: 0
|
||||||
is_unsigned_cast: false
|
|
||||||
- !ScalarExpression
|
- !ScalarExpression
|
||||||
scalar_arg: seed
|
scalar_arg: seed
|
||||||
- !ScalarExpression
|
- !ScalarExpression
|
||||||
symbolic_cast:
|
type_fn:
|
||||||
|
fn_name: cast
|
||||||
type_var: I32
|
type_var: I32
|
||||||
operands:
|
operands:
|
||||||
- !ScalarExpression
|
- !ScalarExpression
|
||||||
scalar_const: '1103515245 : i64'
|
scalar_const: '1103515245 : i64'
|
||||||
is_unsigned_cast: false
|
|
||||||
- !ScalarExpression
|
- !ScalarExpression
|
||||||
symbolic_cast:
|
type_fn:
|
||||||
|
fn_name: cast
|
||||||
type_var: I32
|
type_var: I32
|
||||||
operands:
|
operands:
|
||||||
- !ScalarExpression
|
- !ScalarExpression
|
||||||
scalar_const: '12345 : i64'
|
scalar_const: '12345 : i64'
|
||||||
is_unsigned_cast: false
|
|
||||||
- !ScalarExpression
|
- !ScalarExpression
|
||||||
symbolic_cast:
|
type_fn:
|
||||||
|
fn_name: cast
|
||||||
type_var: I32
|
type_var: I32
|
||||||
operands:
|
operands:
|
||||||
- !ScalarExpression
|
- !ScalarExpression
|
||||||
scalar_const: '1103515245 : i64'
|
scalar_const: '1103515245 : i64'
|
||||||
is_unsigned_cast: false
|
|
||||||
- !ScalarExpression
|
- !ScalarExpression
|
||||||
symbolic_cast:
|
type_fn:
|
||||||
|
fn_name: cast
|
||||||
type_var: I32
|
type_var: I32
|
||||||
operands:
|
operands:
|
||||||
- !ScalarExpression
|
- !ScalarExpression
|
||||||
scalar_const: '12345 : i64'
|
scalar_const: '12345 : i64'
|
||||||
is_unsigned_cast: false
|
|
||||||
is_unsigned_cast: false
|
|
||||||
- !ScalarExpression
|
- !ScalarExpression
|
||||||
scalar_apply:
|
scalar_apply:
|
||||||
fn_name: mul
|
fn_name: mul
|
||||||
@ -2675,15 +2676,14 @@ structured_op: !LinalgStructuredOpConfig
|
|||||||
- !ScalarExpression
|
- !ScalarExpression
|
||||||
scalar_arg: min
|
scalar_arg: min
|
||||||
- !ScalarExpression
|
- !ScalarExpression
|
||||||
symbolic_cast:
|
type_fn:
|
||||||
|
fn_name: cast
|
||||||
type_var: F64
|
type_var: F64
|
||||||
operands:
|
operands:
|
||||||
- !ScalarExpression
|
- !ScalarExpression
|
||||||
scalar_const: '2.3283063999999999E-10 : f64'
|
scalar_const: '2.3283063999999999E-10 : f64'
|
||||||
is_unsigned_cast: false
|
|
||||||
- !ScalarExpression
|
- !ScalarExpression
|
||||||
scalar_arg: min
|
scalar_arg: min
|
||||||
is_unsigned_cast: false
|
|
||||||
--- !LinalgOpConfig
|
--- !LinalgOpConfig
|
||||||
metadata: !LinalgOpMetadata
|
metadata: !LinalgOpMetadata
|
||||||
name: soft_plus_2d
|
name: soft_plus_2d
|
||||||
@ -2724,20 +2724,20 @@ structured_op: !LinalgStructuredOpConfig
|
|||||||
fn_name: add
|
fn_name: add
|
||||||
operands:
|
operands:
|
||||||
- !ScalarExpression
|
- !ScalarExpression
|
||||||
symbolic_cast:
|
type_fn:
|
||||||
|
fn_name: cast
|
||||||
type_var: U
|
type_var: U
|
||||||
operands:
|
operands:
|
||||||
- !ScalarExpression
|
- !ScalarExpression
|
||||||
scalar_const: '1.000000e+00 : f64'
|
scalar_const: '1.000000e+00 : f64'
|
||||||
is_unsigned_cast: false
|
|
||||||
- !ScalarExpression
|
- !ScalarExpression
|
||||||
scalar_apply:
|
scalar_apply:
|
||||||
fn_name: exp
|
fn_name: exp
|
||||||
operands:
|
operands:
|
||||||
- !ScalarExpression
|
- !ScalarExpression
|
||||||
symbolic_cast:
|
type_fn:
|
||||||
|
fn_name: cast
|
||||||
type_var: U
|
type_var: U
|
||||||
operands:
|
operands:
|
||||||
- !ScalarExpression
|
- !ScalarExpression
|
||||||
scalar_arg: I
|
scalar_arg: I
|
||||||
is_unsigned_cast: false
|
|
||||||
|
@ -147,11 +147,13 @@ static LogicalResult foldMemRefCastInTiledLoopOp(TiledLoopOp op) {
|
|||||||
// Region builder helper.
|
// Region builder helper.
|
||||||
// TODO: Move this to a utility library.
|
// TODO: Move this to a utility library.
|
||||||
// The public methods on this class are referenced directly from generated code
|
// The public methods on this class are referenced directly from generated code
|
||||||
// and bind by name to math functions in the DSL as:
|
// and bind by name to math and type conversion functions in the DSL as:
|
||||||
// `applyfn__{fnName}`
|
// `applyfn__{fnName}`
|
||||||
|
// `typefn__{fnName}`
|
||||||
// Examples:
|
// Examples:
|
||||||
// `applyfn__add`
|
// `applyfn__add`
|
||||||
// `applyfn__mul`
|
// `applyfn__mul`
|
||||||
|
// `typefn__cast`
|
||||||
// The naming convention is intentional in order to match snake-cased DSL names.
|
// The naming convention is intentional in order to match snake-cased DSL names.
|
||||||
// See mlir-linalg-ods-yaml-gen.cpp for the code that mates to this class.
|
// See mlir-linalg-ods-yaml-gen.cpp for the code that mates to this class.
|
||||||
//
|
//
|
||||||
@ -228,6 +230,16 @@ public:
|
|||||||
return operand;
|
return operand;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// NOLINTNEXTLINE(*-identifier-naming): externally called.
|
||||||
|
Value typefn__cast(Type toType, Value operand) {
|
||||||
|
return cast(toType, operand, false);
|
||||||
|
}
|
||||||
|
|
||||||
|
// NOLINTNEXTLINE(*-identifier-naming): externally called.
|
||||||
|
Value typefn__cast_unsigned(Type toType, Value operand) {
|
||||||
|
return cast(toType, operand, true);
|
||||||
|
}
|
||||||
|
|
||||||
// NOLINTNEXTLINE(*-identifier-naming): externally called.
|
// NOLINTNEXTLINE(*-identifier-naming): externally called.
|
||||||
Value applyfn__add(Value lhs, Value rhs) {
|
Value applyfn__add(Value lhs, Value rhs) {
|
||||||
OpBuilder builder = getBuilder();
|
OpBuilder builder = getBuilder();
|
||||||
|
@ -314,6 +314,39 @@ class Comprehension:
|
|||||||
return f"{defs_repr} = {values_repr}"
|
return f"{defs_repr} = {values_repr}"
|
||||||
|
|
||||||
|
|
||||||
|
class TypeFnType:
|
||||||
|
"""Type conversion function.
|
||||||
|
|
||||||
|
A type conversion function takes a target type and a tensor expression and
|
||||||
|
returns the casted tensor expression.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, fn_name: str):
|
||||||
|
self.fn_name = fn_name
|
||||||
|
|
||||||
|
def __call__(self, type_var: TypeVar,
|
||||||
|
arg: TensorExpression) -> "TensorTypeFn":
|
||||||
|
return TensorTypeFn(self, type_var, arg)
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return f"{self.fn_name}"
|
||||||
|
|
||||||
|
|
||||||
|
class TypeFn:
|
||||||
|
"""Type conversion function namespace.
|
||||||
|
|
||||||
|
As the integer types are signless, signedness is implement by different cast
|
||||||
|
functions that treat integers as signed (`cast`) or unsigned
|
||||||
|
(`cast_unsigned`) values.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
- cast(I32 -> I64) -> `arith.ExtSIOp`
|
||||||
|
- cast_unsigned(I32 -> I64) -> `arith.ExtUIOp`
|
||||||
|
"""
|
||||||
|
cast = TypeFnType("cast")
|
||||||
|
cast_unsigned = TypeFnType("cast_unsigned")
|
||||||
|
|
||||||
|
|
||||||
class PrimFnType:
|
class PrimFnType:
|
||||||
"""Primitive operations."""
|
"""Primitive operations."""
|
||||||
|
|
||||||
@ -391,6 +424,26 @@ class PrimApply(TensorExpression):
|
|||||||
return f"{repr(self.prim)}({', '.join(repr(a) for a in self.args)})"
|
return f"{repr(self.prim)}({', '.join(repr(a) for a in self.args)})"
|
||||||
|
|
||||||
|
|
||||||
|
class TensorTypeFn(TensorExpression):
|
||||||
|
"""Application of a type conversion function."""
|
||||||
|
|
||||||
|
def __init__(self, type_fn: TypeFn, type_var: TypeVar, arg: TensorExpression):
|
||||||
|
self.type_fn = type_fn
|
||||||
|
self.type_var = type_var
|
||||||
|
self.arg = arg
|
||||||
|
|
||||||
|
def to_scalar_expression(self) -> ScalarExpression:
|
||||||
|
return ScalarTypeFn(self.type_fn.fn_name, self.type_var,
|
||||||
|
self.arg.to_scalar_expression()).expr()
|
||||||
|
|
||||||
|
def visit_tensor_exprs(self, callback):
|
||||||
|
super().visit_tensor_exprs(callback)
|
||||||
|
self.arg.visit_tensor_exprs(callback)
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return f"{repr(self.type_fn)}({type_var}, {self.arg})"
|
||||||
|
|
||||||
|
|
||||||
class const(TensorExpression):
|
class const(TensorExpression):
|
||||||
"""Returns the given constant floating point or integer value."""
|
"""Returns the given constant floating point or integer value."""
|
||||||
|
|
||||||
@ -433,36 +486,6 @@ class index(TensorExpression):
|
|||||||
return f"index({repr(self.dim)})"
|
return f"index({repr(self.dim)})"
|
||||||
|
|
||||||
|
|
||||||
class cast(TensorExpression):
|
|
||||||
"""Casts the element type to a type (typically symbolic TypeVar)."""
|
|
||||||
|
|
||||||
def __init__(self, to_type: TypeVar, operand: TensorExpression):
|
|
||||||
self.to_type = to_type
|
|
||||||
self.operand = operand
|
|
||||||
|
|
||||||
def to_scalar_expression(self) -> ScalarExpression:
|
|
||||||
return ScalarSymbolicCast(self.to_type, self.operand.to_scalar_expression(),
|
|
||||||
False).expr()
|
|
||||||
|
|
||||||
def visit_tensor_exprs(self, callback):
|
|
||||||
super().visit_tensor_exprs(callback)
|
|
||||||
self.operand.visit_tensor_exprs(callback)
|
|
||||||
|
|
||||||
def __repr__(self):
|
|
||||||
return f"cast({self.to_type}, {repr(self.operand)})"
|
|
||||||
|
|
||||||
|
|
||||||
class cast_unsigned(cast):
|
|
||||||
"""Casts the element type to an unsigned type (typically symbolic TypeVar)."""
|
|
||||||
|
|
||||||
def to_scalar_expression(self) -> ScalarExpression:
|
|
||||||
return ScalarSymbolicCast(self.to_type, self.operand.to_scalar_expression(),
|
|
||||||
True).expr()
|
|
||||||
|
|
||||||
def __repr__(self):
|
|
||||||
return f"cast_unsigned({self.to_type}, {repr(self.operand)})"
|
|
||||||
|
|
||||||
|
|
||||||
class ReduceApply(TensorExpression):
|
class ReduceApply(TensorExpression):
|
||||||
"""Application of a reduction.
|
"""Application of a reduction.
|
||||||
|
|
||||||
|
@ -2,7 +2,7 @@
|
|||||||
# See https://llvm.org/LICENSE.txt for license information.
|
# See https://llvm.org/LICENSE.txt for license information.
|
||||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
|
||||||
from typing import Dict, List, Sequence, Tuple, Union
|
from typing import Callable, Dict, List, Sequence, Tuple, Union
|
||||||
|
|
||||||
from .....ir import *
|
from .....ir import *
|
||||||
|
|
||||||
@ -24,6 +24,7 @@ __all__ = [
|
|||||||
|
|
||||||
ValueList = Union[Sequence[Value], OpResultList]
|
ValueList = Union[Sequence[Value], OpResultList]
|
||||||
|
|
||||||
|
|
||||||
def isa(cls: Type, ty: Type):
|
def isa(cls: Type, ty: Type):
|
||||||
try:
|
try:
|
||||||
cls(ty)
|
cls(ty)
|
||||||
@ -221,24 +222,38 @@ class _BodyBuilder:
|
|||||||
IntegerType.get_signless(64), expr.scalar_index.dim)
|
IntegerType.get_signless(64), expr.scalar_index.dim)
|
||||||
return linalg.IndexOp(dim_attr).result
|
return linalg.IndexOp(dim_attr).result
|
||||||
elif expr.scalar_apply:
|
elif expr.scalar_apply:
|
||||||
try:
|
fn = self._get_function(f"_eval_{expr.scalar_apply.fn_name}")
|
||||||
fn = getattr(self, f"_eval_{expr.scalar_apply.fn_name}")
|
|
||||||
except AttributeError:
|
|
||||||
raise ValueError(
|
|
||||||
f"Function '{expr.scalar_apply.fn_name}' is not a known "
|
|
||||||
"scalar body function")
|
|
||||||
operand_values = [
|
operand_values = [
|
||||||
self.expression(operand) for operand in expr.scalar_apply.operands
|
self.expression(operand) for operand in expr.scalar_apply.operands
|
||||||
]
|
]
|
||||||
return fn(*operand_values)
|
return fn(*operand_values)
|
||||||
elif expr.symbolic_cast:
|
elif expr.type_fn:
|
||||||
operand_value = self.expression(expr.symbolic_cast.operand)
|
fn = self._get_function(f"_typefn_{expr.type_fn.fn_name}")
|
||||||
return self.cast(expr.symbolic_cast.to_type.name, operand_value,
|
operand = self.expression(expr.type_fn.operand)
|
||||||
expr.symbolic_cast.is_unsigned_cast)
|
return fn(expr.type_fn.type_var.name, operand)
|
||||||
raise NotImplementedError(f"Unimplemented scalar body expression: {expr}")
|
raise NotImplementedError(f"Unimplemented scalar body expression: {expr}")
|
||||||
|
|
||||||
def cast(self, type_var_name: str, operand: Value,
|
def yield_outputs(self, *output_names: str):
|
||||||
is_unsigned_cast: bool) -> Value:
|
output_values = []
|
||||||
|
for n in output_names:
|
||||||
|
try:
|
||||||
|
output_values.append(self.yield_mapping[n])
|
||||||
|
except KeyError:
|
||||||
|
raise ValueError(f"Body assignments do not assign all outputs: "
|
||||||
|
f"missing '{n}'")
|
||||||
|
linalg.YieldOp(output_values)
|
||||||
|
|
||||||
|
def _get_function(self, fn_name: str) -> Callable:
|
||||||
|
try:
|
||||||
|
fn = getattr(self, f"{fn_name}")
|
||||||
|
except AttributeError:
|
||||||
|
raise ValueError(f"Function '{fn_name}' is not a known function")
|
||||||
|
return fn
|
||||||
|
|
||||||
|
def _cast(self,
|
||||||
|
type_var_name: str,
|
||||||
|
operand: Value,
|
||||||
|
is_unsigned_cast: bool = False) -> Value:
|
||||||
try:
|
try:
|
||||||
to_type = self.type_mapping[type_var_name]
|
to_type = self.type_mapping[type_var_name]
|
||||||
except KeyError:
|
except KeyError:
|
||||||
@ -289,15 +304,11 @@ class _BodyBuilder:
|
|||||||
raise ValueError(f"Unable to cast body expression from {operand_type} to "
|
raise ValueError(f"Unable to cast body expression from {operand_type} to "
|
||||||
f"{to_type}")
|
f"{to_type}")
|
||||||
|
|
||||||
def yield_outputs(self, *output_names: str):
|
def _typefn_cast(self, type_var_name: str, operand: Value) -> Value:
|
||||||
output_values = []
|
return self._cast(type_var_name, operand, False)
|
||||||
for n in output_names:
|
|
||||||
try:
|
def _typefn_cast_unsigned(self, type_var_name: str, operand: Value) -> Value:
|
||||||
output_values.append(self.yield_mapping[n])
|
return self._cast(type_var_name, operand, True)
|
||||||
except KeyError:
|
|
||||||
raise ValueError(f"Body assignments do not assign all outputs: "
|
|
||||||
f"missing '{n}'")
|
|
||||||
linalg.YieldOp(output_values)
|
|
||||||
|
|
||||||
def _eval_add(self, lhs: Value, rhs: Value) -> Value:
|
def _eval_add(self, lhs: Value, rhs: Value) -> Value:
|
||||||
if _is_floating_point_type(lhs.type):
|
if _is_floating_point_type(lhs.type):
|
||||||
|
@ -21,11 +21,11 @@ from .types import *
|
|||||||
__all__ = [
|
__all__ = [
|
||||||
"ScalarAssign",
|
"ScalarAssign",
|
||||||
"ScalarApplyFn",
|
"ScalarApplyFn",
|
||||||
|
"ScalarTypeFn",
|
||||||
"ScalarArg",
|
"ScalarArg",
|
||||||
"ScalarConst",
|
"ScalarConst",
|
||||||
"ScalarIndex",
|
"ScalarIndex",
|
||||||
"ScalarExpression",
|
"ScalarExpression",
|
||||||
"ScalarSymbolicCast",
|
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@ -43,6 +43,22 @@ class ScalarApplyFn:
|
|||||||
return f"ScalarApplyFn<{self.fn_name}>({', '.join(self.operands)})"
|
return f"ScalarApplyFn<{self.fn_name}>({', '.join(self.operands)})"
|
||||||
|
|
||||||
|
|
||||||
|
class ScalarTypeFn:
|
||||||
|
"""A type of ScalarExpression that applies a type conversion function."""
|
||||||
|
|
||||||
|
def __init__(self, fn_name: str, type_var: TypeVar,
|
||||||
|
operand: "ScalarExpression"):
|
||||||
|
self.fn_name = fn_name
|
||||||
|
self.type_var = type_var
|
||||||
|
self.operand = operand
|
||||||
|
|
||||||
|
def expr(self) -> "ScalarExpression":
|
||||||
|
return ScalarExpression(type_fn=self)
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return f"ScalarTypeFn<{self.fn_name}>({self.type_var}, {self.operand})"
|
||||||
|
|
||||||
|
|
||||||
class ScalarArg:
|
class ScalarArg:
|
||||||
"""A type of ScalarExpression that references a named argument."""
|
"""A type of ScalarExpression that references a named argument."""
|
||||||
|
|
||||||
@ -82,27 +98,12 @@ class ScalarIndex:
|
|||||||
return f"(ScalarIndex({self.dim})"
|
return f"(ScalarIndex({self.dim})"
|
||||||
|
|
||||||
|
|
||||||
class ScalarSymbolicCast:
|
|
||||||
"""A type of ScalarExpression that symbolically casts an operand to a TypeVar."""
|
|
||||||
|
|
||||||
def __init__(self, to_type: TypeVar, operand: "ScalarExpression",
|
|
||||||
is_unsigned_cast: bool):
|
|
||||||
self.to_type = to_type
|
|
||||||
self.operand = operand
|
|
||||||
self.is_unsigned_cast = is_unsigned_cast
|
|
||||||
|
|
||||||
def expr(self) -> "ScalarExpression":
|
|
||||||
return ScalarExpression(symbolic_cast=self)
|
|
||||||
|
|
||||||
def __repr__(self):
|
|
||||||
return f"ScalarSymbolicCast({self.to_type}, {self.operand}, {self.is_unsigned_cast})"
|
|
||||||
|
|
||||||
|
|
||||||
class ScalarExpression(YAMLObject):
|
class ScalarExpression(YAMLObject):
|
||||||
"""An expression on scalar values.
|
"""An expression on scalar values.
|
||||||
|
|
||||||
Can be one of:
|
Can be one of:
|
||||||
- ScalarApplyFn
|
- ScalarApplyFn
|
||||||
|
- ScalarTypeFn
|
||||||
- ScalarArg
|
- ScalarArg
|
||||||
- ScalarConst
|
- ScalarConst
|
||||||
- ScalarIndex
|
- ScalarIndex
|
||||||
@ -112,19 +113,19 @@ class ScalarExpression(YAMLObject):
|
|||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
scalar_apply: Optional[ScalarApplyFn] = None,
|
scalar_apply: Optional[ScalarApplyFn] = None,
|
||||||
|
type_fn: Optional[ScalarTypeFn] = None,
|
||||||
scalar_arg: Optional[ScalarArg] = None,
|
scalar_arg: Optional[ScalarArg] = None,
|
||||||
scalar_const: Optional[ScalarConst] = None,
|
scalar_const: Optional[ScalarConst] = None,
|
||||||
scalar_index: Optional[ScalarIndex] = None,
|
scalar_index: Optional[ScalarIndex] = None):
|
||||||
symbolic_cast: Optional[ScalarSymbolicCast] = None):
|
if (bool(scalar_apply) + bool(type_fn) + bool(scalar_arg) +
|
||||||
if (bool(scalar_apply) + bool(scalar_arg) + bool(scalar_const) +
|
bool(scalar_const) + bool(scalar_index)) != 1:
|
||||||
bool(scalar_index) + bool(symbolic_cast)) != 1:
|
raise ValueError("One of 'scalar_apply', 'type_fn', 'scalar_arg', "
|
||||||
raise ValueError("One of 'scalar_apply', 'scalar_arg', 'scalar_const', "
|
"'scalar_const', 'scalar_index', must be specified")
|
||||||
"'scalar_index', 'symbolic_cast' must be specified")
|
|
||||||
self.scalar_apply = scalar_apply
|
self.scalar_apply = scalar_apply
|
||||||
|
self.type_fn = type_fn
|
||||||
self.scalar_arg = scalar_arg
|
self.scalar_arg = scalar_arg
|
||||||
self.scalar_const = scalar_const
|
self.scalar_const = scalar_const
|
||||||
self.scalar_index = scalar_index
|
self.scalar_index = scalar_index
|
||||||
self.symbolic_cast = symbolic_cast
|
|
||||||
|
|
||||||
def to_yaml_custom_dict(self):
|
def to_yaml_custom_dict(self):
|
||||||
if self.scalar_apply:
|
if self.scalar_apply:
|
||||||
@ -133,21 +134,22 @@ class ScalarExpression(YAMLObject):
|
|||||||
fn_name=self.scalar_apply.fn_name,
|
fn_name=self.scalar_apply.fn_name,
|
||||||
operands=list(self.scalar_apply.operands),
|
operands=list(self.scalar_apply.operands),
|
||||||
))
|
))
|
||||||
|
if self.type_fn:
|
||||||
|
# Note that even though operands must be arity 1, we write it the
|
||||||
|
# same way as for apply because it allows handling code to be more
|
||||||
|
# generic vs having a special form.
|
||||||
|
return dict(
|
||||||
|
type_fn=dict(
|
||||||
|
fn_name=self.type_fn.fn_name,
|
||||||
|
type_var=self.type_fn.type_var.name,
|
||||||
|
operands=[self.type_fn.operand],
|
||||||
|
))
|
||||||
elif self.scalar_arg:
|
elif self.scalar_arg:
|
||||||
return dict(scalar_arg=self.scalar_arg.arg)
|
return dict(scalar_arg=self.scalar_arg.arg)
|
||||||
elif self.scalar_const:
|
elif self.scalar_const:
|
||||||
return dict(scalar_const=self.scalar_const.value)
|
return dict(scalar_const=self.scalar_const.value)
|
||||||
elif self.scalar_index:
|
elif self.scalar_index:
|
||||||
return dict(scalar_index=self.scalar_index.dim)
|
return dict(scalar_index=self.scalar_index.dim)
|
||||||
elif self.symbolic_cast:
|
|
||||||
# Note that even though operands must be arity 1, we write it the
|
|
||||||
# same way as for apply because it allows handling code to be more
|
|
||||||
# generic vs having a special form.
|
|
||||||
return dict(
|
|
||||||
symbolic_cast=dict(
|
|
||||||
type_var=self.symbolic_cast.to_type.name,
|
|
||||||
operands=[self.symbolic_cast.operand],
|
|
||||||
is_unsigned_cast=self.symbolic_cast.is_unsigned_cast))
|
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unexpected ScalarExpression type: {self}")
|
raise ValueError(f"Unexpected ScalarExpression type: {self}")
|
||||||
|
|
||||||
|
@ -18,7 +18,7 @@ def matmul(
|
|||||||
"""
|
"""
|
||||||
domain(D.m, D.n, D.k)
|
domain(D.m, D.n, D.k)
|
||||||
implements(ContractionOpInterface)
|
implements(ContractionOpInterface)
|
||||||
C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.k, D.n])
|
C[D.m, D.n] += TypeFn.cast(U, A[D.m, D.k]) * TypeFn.cast(U, B[D.k, D.n])
|
||||||
|
|
||||||
|
|
||||||
@linalg_structured_op
|
@linalg_structured_op
|
||||||
@ -33,7 +33,8 @@ def matmul_unsigned(
|
|||||||
"""
|
"""
|
||||||
domain(D.m, D.n, D.k)
|
domain(D.m, D.n, D.k)
|
||||||
implements(ContractionOpInterface)
|
implements(ContractionOpInterface)
|
||||||
C[D.m, D.n] += cast_unsigned(U, A[D.m, D.k]) * cast_unsigned(U, B[D.k, D.n])
|
C[D.m, D.n] += TypeFn.cast_unsigned(U, A[D.m, D.k]) * TypeFn.cast_unsigned(
|
||||||
|
U, B[D.k, D.n])
|
||||||
|
|
||||||
|
|
||||||
@linalg_structured_op
|
@linalg_structured_op
|
||||||
@ -51,8 +52,8 @@ def quantized_matmul(
|
|||||||
matmul.
|
matmul.
|
||||||
"""
|
"""
|
||||||
domain(D.m, D.n, D.k)
|
domain(D.m, D.n, D.k)
|
||||||
C[D.m, D.n] += (cast(U, A[D.m, D.k]) - cast(U, AZp)) * (
|
C[D.m, D.n] += (TypeFn.cast(U, A[D.m, D.k]) - TypeFn.cast(U, AZp)) * (
|
||||||
cast(U, B[D.k, D.n]) - cast(U, BZp))
|
TypeFn.cast(U, B[D.k, D.n]) - TypeFn.cast(U, BZp))
|
||||||
|
|
||||||
|
|
||||||
@linalg_structured_op
|
@linalg_structured_op
|
||||||
@ -72,9 +73,9 @@ def mmt4d(
|
|||||||
"""
|
"""
|
||||||
domain(D.m, D.n, D.k, D.m0, D.n0, D.k0)
|
domain(D.m, D.n, D.k, D.m0, D.n0, D.k0)
|
||||||
implements(ContractionOpInterface)
|
implements(ContractionOpInterface)
|
||||||
accum[D.m, D.n, D.m0,
|
accum[D.m, D.n, D.m0, D.n0] += TypeFn.cast(
|
||||||
D.n0] += cast(TV.AccumType, lhs[D.m, D.k, D.m0, D.k0]) * cast(
|
TV.AccumType, lhs[D.m, D.k, D.m0, D.k0]) * TypeFn.cast(
|
||||||
TV.AccumType, rhs[D.n, D.k, D.n0, D.k0])
|
TV.AccumType, rhs[D.n, D.k, D.n0, D.k0])
|
||||||
|
|
||||||
|
|
||||||
@linalg_structured_op
|
@linalg_structured_op
|
||||||
@ -89,7 +90,8 @@ def batch_matmul(
|
|||||||
"""
|
"""
|
||||||
domain(D.b, D.m, D.n, D.k)
|
domain(D.b, D.m, D.n, D.k)
|
||||||
implements(ContractionOpInterface)
|
implements(ContractionOpInterface)
|
||||||
C[D.b, D.m, D.n] += cast(U, A[D.b, D.m, D.k]) * cast(U, B[D.b, D.k, D.n])
|
C[D.b, D.m,
|
||||||
|
D.n] += TypeFn.cast(U, A[D.b, D.m, D.k]) * TypeFn.cast(U, B[D.b, D.k, D.n])
|
||||||
|
|
||||||
|
|
||||||
@linalg_structured_op
|
@linalg_structured_op
|
||||||
@ -107,8 +109,9 @@ def quantized_batch_matmul(
|
|||||||
matmul.
|
matmul.
|
||||||
"""
|
"""
|
||||||
domain(D.b, D.m, D.n, D.k)
|
domain(D.b, D.m, D.n, D.k)
|
||||||
C[D.b, D.m, D.n] += (cast(U, A[D.b, D.m, D.k]) - cast(U, AZp)) * (
|
C[D.b, D.m,
|
||||||
cast(U, B[D.b, D.k, D.n]) - cast(U, BZp))
|
D.n] += (TypeFn.cast(U, A[D.b, D.m, D.k]) - TypeFn.cast(U, AZp)) * (
|
||||||
|
TypeFn.cast(U, B[D.b, D.k, D.n]) - TypeFn.cast(U, BZp))
|
||||||
|
|
||||||
|
|
||||||
@linalg_structured_op
|
@linalg_structured_op
|
||||||
@ -123,7 +126,7 @@ def matvec(
|
|||||||
"""
|
"""
|
||||||
domain(D.m, D.n)
|
domain(D.m, D.n)
|
||||||
implements(ContractionOpInterface)
|
implements(ContractionOpInterface)
|
||||||
x[D.m] += cast(U, A[D.m, D.n]) * cast(U, y[D.n])
|
x[D.m] += TypeFn.cast(U, A[D.m, D.n]) * TypeFn.cast(U, y[D.n])
|
||||||
|
|
||||||
|
|
||||||
@linalg_structured_op
|
@linalg_structured_op
|
||||||
@ -138,7 +141,7 @@ def vecmat(
|
|||||||
"""
|
"""
|
||||||
domain(D.n, D.m)
|
domain(D.n, D.m)
|
||||||
implements(ContractionOpInterface)
|
implements(ContractionOpInterface)
|
||||||
x[D.n] += cast(U, y[D.m]) * cast(U, A[D.m, D.n])
|
x[D.n] += TypeFn.cast(U, y[D.m]) * TypeFn.cast(U, A[D.m, D.n])
|
||||||
|
|
||||||
|
|
||||||
@linalg_structured_op
|
@linalg_structured_op
|
||||||
@ -153,7 +156,7 @@ def batch_matvec(
|
|||||||
"""
|
"""
|
||||||
domain(D.b, D.m, D.k)
|
domain(D.b, D.m, D.k)
|
||||||
implements(ContractionOpInterface)
|
implements(ContractionOpInterface)
|
||||||
C[D.b, D.m] += cast(U, A[D.b, D.m, D.k]) * cast(U, B[D.b, D.k])
|
C[D.b, D.m] += TypeFn.cast(U, A[D.b, D.m, D.k]) * TypeFn.cast(U, B[D.b, D.k])
|
||||||
|
|
||||||
|
|
||||||
@linalg_structured_op
|
@linalg_structured_op
|
||||||
@ -165,7 +168,7 @@ def dot(
|
|||||||
them to the same data type as the accumulator/output.
|
them to the same data type as the accumulator/output.
|
||||||
"""
|
"""
|
||||||
implements(ContractionOpInterface)
|
implements(ContractionOpInterface)
|
||||||
C[None] += cast(U, A[D.m]) * cast(U, B[D.m])
|
C[None] += TypeFn.cast(U, A[D.m]) * TypeFn.cast(U, B[D.m])
|
||||||
|
|
||||||
|
|
||||||
@linalg_structured_op
|
@linalg_structured_op
|
||||||
@ -180,7 +183,7 @@ def conv_1d(
|
|||||||
"""
|
"""
|
||||||
implements(ConvolutionOpInterface)
|
implements(ConvolutionOpInterface)
|
||||||
domain(D.ow, D.kw)
|
domain(D.ow, D.kw)
|
||||||
O[D.ow] += cast(U, I[D.ow + D.kw]) * cast(U, K[D.kw])
|
O[D.ow] += TypeFn.cast(U, I[D.ow + D.kw]) * TypeFn.cast(U, K[D.kw])
|
||||||
|
|
||||||
|
|
||||||
@linalg_structured_op
|
@linalg_structured_op
|
||||||
@ -195,7 +198,8 @@ def conv_2d(
|
|||||||
"""
|
"""
|
||||||
implements(ConvolutionOpInterface)
|
implements(ConvolutionOpInterface)
|
||||||
domain(D.oh, D.ow, D.kh, D.kw)
|
domain(D.oh, D.ow, D.kh, D.kw)
|
||||||
O[D.oh, D.ow] += cast(U, I[D.oh + D.kh, D.ow + D.kw]) * cast(U, K[D.kh, D.kw])
|
O[D.oh, D.ow] += TypeFn.cast(U, I[D.oh + D.kh, D.ow + D.kw]) * TypeFn.cast(
|
||||||
|
U, K[D.kh, D.kw])
|
||||||
|
|
||||||
|
|
||||||
@linalg_structured_op
|
@linalg_structured_op
|
||||||
@ -211,8 +215,8 @@ def conv_3d(
|
|||||||
implements(ConvolutionOpInterface)
|
implements(ConvolutionOpInterface)
|
||||||
domain(D.od, D.oh, D.ow, D.kd, D.kh, D.kw)
|
domain(D.od, D.oh, D.ow, D.kd, D.kh, D.kw)
|
||||||
O[D.od, D.oh,
|
O[D.od, D.oh,
|
||||||
D.ow] += cast(U, I[D.od + D.kd, D.oh + D.kh, D.ow + D.kw]) * cast(
|
D.ow] += TypeFn.cast(U, I[D.od + D.kd, D.oh + D.kh, D.ow +
|
||||||
U, K[D.kd, D.kh, D.kw])
|
D.kw]) * TypeFn.cast(U, K[D.kd, D.kh, D.kw])
|
||||||
|
|
||||||
|
|
||||||
@linalg_structured_op
|
@linalg_structured_op
|
||||||
@ -229,8 +233,9 @@ def conv_1d_nwc_wcf(
|
|||||||
"""
|
"""
|
||||||
implements(ConvolutionOpInterface)
|
implements(ConvolutionOpInterface)
|
||||||
domain(D.n, D.ow, D.f, D.kw, D.c)
|
domain(D.n, D.ow, D.f, D.kw, D.c)
|
||||||
O[D.n, D.ow, D.f] += cast(U, I[D.n, D.ow * S.SW + D.kw * S.DW, D.c]) * cast(
|
O[D.n, D.ow,
|
||||||
U, K[D.kw, D.c, D.f])
|
D.f] += TypeFn.cast(U, I[D.n, D.ow * S.SW + D.kw * S.DW,
|
||||||
|
D.c]) * TypeFn.cast(U, K[D.kw, D.c, D.f])
|
||||||
|
|
||||||
|
|
||||||
@linalg_structured_op
|
@linalg_structured_op
|
||||||
@ -252,9 +257,9 @@ def conv_2d_nhwc_hwcf(
|
|||||||
"""
|
"""
|
||||||
implements(ConvolutionOpInterface)
|
implements(ConvolutionOpInterface)
|
||||||
domain(D.n, D.oh, D.ow, D.f, D.kh, D.kw, D.c)
|
domain(D.n, D.oh, D.ow, D.f, D.kh, D.kw, D.c)
|
||||||
O[D.n, D.oh, D.ow, D.f] += cast(
|
O[D.n, D.oh, D.ow, D.f] += TypeFn.cast(
|
||||||
U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW,
|
U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW,
|
||||||
D.c]) * cast(U, K[D.kh, D.kw, D.c, D.f])
|
D.c]) * TypeFn.cast(U, K[D.kh, D.kw, D.c, D.f])
|
||||||
|
|
||||||
|
|
||||||
@linalg_structured_op
|
@linalg_structured_op
|
||||||
@ -280,10 +285,10 @@ def conv_2d_nhwc_hwcf_q(
|
|||||||
implements(ConvolutionOpInterface)
|
implements(ConvolutionOpInterface)
|
||||||
domain(D.n, D.oh, D.ow, D.f, D.kh, D.kw, D.c)
|
domain(D.n, D.oh, D.ow, D.f, D.kh, D.kw, D.c)
|
||||||
O[D.n, D.oh, D.ow,
|
O[D.n, D.oh, D.ow,
|
||||||
D.f] += (cast(
|
D.f] += (TypeFn.cast(
|
||||||
U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c]) -
|
U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c]) -
|
||||||
cast(U, IZp)) * (
|
TypeFn.cast(U, IZp)) * (
|
||||||
cast(U, K[D.kh, D.kw, D.c, D.f]) - cast(U, KZp))
|
TypeFn.cast(U, K[D.kh, D.kw, D.c, D.f]) - TypeFn.cast(U, KZp))
|
||||||
|
|
||||||
|
|
||||||
@linalg_structured_op
|
@linalg_structured_op
|
||||||
@ -305,9 +310,9 @@ def conv_2d_nchw_fchw(
|
|||||||
"""
|
"""
|
||||||
implements(ConvolutionOpInterface)
|
implements(ConvolutionOpInterface)
|
||||||
domain(D.n, D.f, D.oh, D.ow, D.c, D.kh, D.kw)
|
domain(D.n, D.f, D.oh, D.ow, D.c, D.kh, D.kw)
|
||||||
O[D.n, D.f, D.oh, D.ow] += cast(
|
O[D.n, D.f, D.oh, D.ow] += TypeFn.cast(
|
||||||
U, I[D.n, D.c, D.oh * S.SH + D.kh * S.DH,
|
U, I[D.n, D.c, D.oh * S.SH + D.kh * S.DH,
|
||||||
D.ow * S.SW + D.kw * S.DW]) * cast(U, K[D.f, D.c, D.kh, D.kw])
|
D.ow * S.SW + D.kw * S.DW]) * TypeFn.cast(U, K[D.f, D.c, D.kh, D.kw])
|
||||||
|
|
||||||
|
|
||||||
@linalg_structured_op
|
@linalg_structured_op
|
||||||
@ -325,9 +330,9 @@ def conv_3d_ndhwc_dhwcf(
|
|||||||
"""
|
"""
|
||||||
implements(ConvolutionOpInterface)
|
implements(ConvolutionOpInterface)
|
||||||
domain(D.n, D.od, D.oh, D.ow, D.f, D.kd, D.kh, D.kw, D.c)
|
domain(D.n, D.od, D.oh, D.ow, D.f, D.kd, D.kh, D.kw, D.c)
|
||||||
O[D.n, D.od, D.oh, D.ow, D.f] += cast(
|
O[D.n, D.od, D.oh, D.ow, D.f] += TypeFn.cast(
|
||||||
U, I[D.n, D.od * S.SD + D.kd * S.DD, D.oh * S.SH + D.kh * S.DH,
|
U, I[D.n, D.od * S.SD + D.kd * S.DD, D.oh * S.SH + D.kh * S.DH,
|
||||||
D.ow * S.SW + D.kw * S.DW, D.c]) * cast(
|
D.ow * S.SW + D.kw * S.DW, D.c]) * TypeFn.cast(
|
||||||
U, K[D.kd, D.kh, D.kw, D.c, D.f])
|
U, K[D.kd, D.kh, D.kw, D.c, D.f])
|
||||||
|
|
||||||
|
|
||||||
@ -347,8 +352,8 @@ def depthwise_conv_1d_nwc_wc(
|
|||||||
implements(ConvolutionOpInterface)
|
implements(ConvolutionOpInterface)
|
||||||
domain(D.n, D.ow, D.ic, D.kw)
|
domain(D.n, D.ow, D.ic, D.kw)
|
||||||
O[D.n, D.ow, D.ic] += \
|
O[D.n, D.ow, D.ic] += \
|
||||||
cast(U, I[D.n, D.ow * S.SW + D.kw * S.DW, D.ic]) * \
|
TypeFn.cast(U, I[D.n, D.ow * S.SW + D.kw * S.DW, D.ic]) * \
|
||||||
cast(U, K[D.kw, D.ic])
|
TypeFn.cast(U, K[D.kw, D.ic])
|
||||||
|
|
||||||
|
|
||||||
@linalg_structured_op
|
@linalg_structured_op
|
||||||
@ -367,9 +372,9 @@ def depthwise_conv_2d_nhwc_hwc(
|
|||||||
"""
|
"""
|
||||||
implements(ConvolutionOpInterface)
|
implements(ConvolutionOpInterface)
|
||||||
domain(D.n, D.oh, D.ow, D.ic, D.kh, D.kw)
|
domain(D.n, D.oh, D.ow, D.ic, D.kh, D.kw)
|
||||||
O[D.n, D.oh, D.ow, D.ic] += cast(
|
O[D.n, D.oh, D.ow, D.ic] += TypeFn.cast(
|
||||||
U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW,
|
U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW,
|
||||||
D.ic]) * cast(U, K[D.kh, D.kw, D.ic])
|
D.ic]) * TypeFn.cast(U, K[D.kh, D.kw, D.ic])
|
||||||
|
|
||||||
|
|
||||||
@linalg_structured_op
|
@linalg_structured_op
|
||||||
@ -389,10 +394,11 @@ def depthwise_conv_2d_nhwc_hwc_q(
|
|||||||
"""
|
"""
|
||||||
implements(ConvolutionOpInterface)
|
implements(ConvolutionOpInterface)
|
||||||
domain(D.n, D.oh, D.ow, D.ic, D.kh, D.kw)
|
domain(D.n, D.oh, D.ow, D.ic, D.kh, D.kw)
|
||||||
O[D.n, D.oh, D.ow, D.ic] += (
|
O[D.n, D.oh, D.ow,
|
||||||
(cast(U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW,
|
D.ic] += ((TypeFn.cast(
|
||||||
D.ic]) - cast(U, IZp)) *
|
U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.ic]) -
|
||||||
(cast(U, K[D.kh, D.kw, D.ic]) - cast(U, KZp)))
|
TypeFn.cast(U, IZp)) *
|
||||||
|
(TypeFn.cast(U, K[D.kh, D.kw, D.ic]) - TypeFn.cast(U, KZp)))
|
||||||
|
|
||||||
|
|
||||||
@linalg_structured_op
|
@linalg_structured_op
|
||||||
@ -410,9 +416,9 @@ def depthwise_conv_2d_nhwc_hwcm(
|
|||||||
"""
|
"""
|
||||||
implements(ConvolutionOpInterface)
|
implements(ConvolutionOpInterface)
|
||||||
domain(D.n, D.oh, D.ow, D.ic, D.cm, D.kh, D.kw)
|
domain(D.n, D.oh, D.ow, D.ic, D.cm, D.kh, D.kw)
|
||||||
O[D.n, D.oh, D.ow, D.ic, D.cm] += cast(
|
O[D.n, D.oh, D.ow, D.ic, D.cm] += TypeFn.cast(
|
||||||
U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW,
|
U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW,
|
||||||
D.ic]) * cast(U, K[D.kh, D.kw, D.ic, D.cm])
|
D.ic]) * TypeFn.cast(U, K[D.kh, D.kw, D.ic, D.cm])
|
||||||
|
|
||||||
|
|
||||||
@linalg_structured_op
|
@linalg_structured_op
|
||||||
@ -432,10 +438,11 @@ def depthwise_conv_2d_nhwc_hwcm_q(
|
|||||||
"""
|
"""
|
||||||
implements(ConvolutionOpInterface)
|
implements(ConvolutionOpInterface)
|
||||||
domain(D.n, D.oh, D.ow, D.ic, D.cm, D.kh, D.kw)
|
domain(D.n, D.oh, D.ow, D.ic, D.cm, D.kh, D.kw)
|
||||||
O[D.n, D.oh, D.ow, D.ic, D.cm] += (
|
O[D.n, D.oh, D.ow, D.ic,
|
||||||
(cast(U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW,
|
D.cm] += ((TypeFn.cast(
|
||||||
D.ic]) - cast(U, IZp)) *
|
U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.ic]) -
|
||||||
(cast(U, K[D.kh, D.kw, D.ic, D.cm]) - cast(U, KZp)))
|
TypeFn.cast(U, IZp)) *
|
||||||
|
(TypeFn.cast(U, K[D.kh, D.kw, D.ic, D.cm]) - TypeFn.cast(U, KZp)))
|
||||||
|
|
||||||
|
|
||||||
@linalg_structured_op
|
@linalg_structured_op
|
||||||
@ -453,7 +460,7 @@ def pooling_nhwc_sum(
|
|||||||
"""
|
"""
|
||||||
implements(ConvolutionOpInterface)
|
implements(ConvolutionOpInterface)
|
||||||
domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c)
|
domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c)
|
||||||
O[D.n, D.oh, D.ow, D.c] += cast(
|
O[D.n, D.oh, D.ow, D.c] += TypeFn.cast(
|
||||||
U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c])
|
U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c])
|
||||||
|
|
||||||
|
|
||||||
@ -473,8 +480,8 @@ def pooling_nhwc_max(
|
|||||||
implements(ConvolutionOpInterface)
|
implements(ConvolutionOpInterface)
|
||||||
domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c)
|
domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c)
|
||||||
O[D.n, D.oh, D.ow, D.c] = ReduceFn.max(D.kh, D.kw)(
|
O[D.n, D.oh, D.ow, D.c] = ReduceFn.max(D.kh, D.kw)(
|
||||||
cast(U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW,
|
TypeFn.cast(
|
||||||
D.c]))
|
U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c]))
|
||||||
|
|
||||||
|
|
||||||
@linalg_structured_op
|
@linalg_structured_op
|
||||||
@ -493,7 +500,7 @@ def pooling_nhwc_max_unsigned(
|
|||||||
implements(ConvolutionOpInterface)
|
implements(ConvolutionOpInterface)
|
||||||
domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c)
|
domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c)
|
||||||
O[D.n, D.oh, D.ow, D.c] = ReduceFn.max_unsigned(D.kh, D.kw)(
|
O[D.n, D.oh, D.ow, D.c] = ReduceFn.max_unsigned(D.kh, D.kw)(
|
||||||
cast_unsigned(
|
TypeFn.cast_unsigned(
|
||||||
U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c]))
|
U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c]))
|
||||||
|
|
||||||
|
|
||||||
@ -513,8 +520,9 @@ def pooling_nchw_max(
|
|||||||
implements(ConvolutionOpInterface)
|
implements(ConvolutionOpInterface)
|
||||||
domain(D.n, D.c, D.oh, D.ow, D.kh, D.kw)
|
domain(D.n, D.c, D.oh, D.ow, D.kh, D.kw)
|
||||||
O[D.n, D.c, D.oh, D.ow] = ReduceFn.max(D.kh, D.kw)(
|
O[D.n, D.c, D.oh, D.ow] = ReduceFn.max(D.kh, D.kw)(
|
||||||
cast(U, I[D.n, D.c, D.oh * S.SH + D.kh * S.DH,
|
TypeFn.cast(
|
||||||
D.ow * S.SW + D.kw * S.DW,]))
|
U, I[D.n, D.c, D.oh * S.SH + D.kh * S.DH,
|
||||||
|
D.ow * S.SW + D.kw * S.DW,]))
|
||||||
|
|
||||||
|
|
||||||
@linalg_structured_op
|
@linalg_structured_op
|
||||||
@ -533,8 +541,8 @@ def pooling_nhwc_min(
|
|||||||
implements(ConvolutionOpInterface)
|
implements(ConvolutionOpInterface)
|
||||||
domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c)
|
domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c)
|
||||||
O[D.n, D.oh, D.ow, D.c] = ReduceFn.min(D.kh, D.kw)(
|
O[D.n, D.oh, D.ow, D.c] = ReduceFn.min(D.kh, D.kw)(
|
||||||
cast(U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW,
|
TypeFn.cast(
|
||||||
D.c]))
|
U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c]))
|
||||||
|
|
||||||
|
|
||||||
@linalg_structured_op
|
@linalg_structured_op
|
||||||
@ -553,7 +561,7 @@ def pooling_nhwc_min_unsigned(
|
|||||||
implements(ConvolutionOpInterface)
|
implements(ConvolutionOpInterface)
|
||||||
domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c)
|
domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c)
|
||||||
O[D.n, D.oh, D.ow, D.c] = ReduceFn.min_unsigned(D.kh, D.kw)(
|
O[D.n, D.oh, D.ow, D.c] = ReduceFn.min_unsigned(D.kh, D.kw)(
|
||||||
cast_unsigned(
|
TypeFn.cast_unsigned(
|
||||||
U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c]))
|
U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c]))
|
||||||
|
|
||||||
|
|
||||||
@ -572,7 +580,7 @@ def pooling_ndhwc_sum(
|
|||||||
"""
|
"""
|
||||||
implements(ConvolutionOpInterface)
|
implements(ConvolutionOpInterface)
|
||||||
domain(D.n, D.od, D.oh, D.ow, D.kd, D.kh, D.kw, D.c)
|
domain(D.n, D.od, D.oh, D.ow, D.kd, D.kh, D.kw, D.c)
|
||||||
O[D.n, D.od, D.oh, D.ow, D.c] += cast(
|
O[D.n, D.od, D.oh, D.ow, D.c] += TypeFn.cast(
|
||||||
U, I[D.n, D.od * S.SD + D.kd * S.DD, D.oh * S.SH + D.kh * S.DH,
|
U, I[D.n, D.od * S.SD + D.kd * S.DD, D.oh * S.SH + D.kh * S.DH,
|
||||||
D.ow * S.SW + D.kw * S.DW, D.c])
|
D.ow * S.SW + D.kw * S.DW, D.c])
|
||||||
|
|
||||||
@ -593,7 +601,7 @@ def pooling_ndhwc_max(
|
|||||||
implements(ConvolutionOpInterface)
|
implements(ConvolutionOpInterface)
|
||||||
domain(D.n, D.od, D.oh, D.ow, D.kd, D.kh, D.kw, D.c)
|
domain(D.n, D.od, D.oh, D.ow, D.kd, D.kh, D.kw, D.c)
|
||||||
O[D.n, D.od, D.oh, D.ow, D.c] = ReduceFn.max(D.kd, D.kh, D.kw)(
|
O[D.n, D.od, D.oh, D.ow, D.c] = ReduceFn.max(D.kd, D.kh, D.kw)(
|
||||||
cast(
|
TypeFn.cast(
|
||||||
U, I[D.n, D.od * S.SD + D.kd * S.DD, D.oh * S.SH + D.kh * S.DH,
|
U, I[D.n, D.od * S.SD + D.kd * S.DD, D.oh * S.SH + D.kh * S.DH,
|
||||||
D.ow * S.SW + D.kw * S.DW, D.c]))
|
D.ow * S.SW + D.kw * S.DW, D.c]))
|
||||||
|
|
||||||
@ -614,7 +622,7 @@ def pooling_ndhwc_min(
|
|||||||
implements(ConvolutionOpInterface)
|
implements(ConvolutionOpInterface)
|
||||||
domain(D.n, D.od, D.oh, D.ow, D.kd, D.kh, D.kw, D.c)
|
domain(D.n, D.od, D.oh, D.ow, D.kd, D.kh, D.kw, D.c)
|
||||||
O[D.n, D.od, D.oh, D.ow, D.c] = ReduceFn.min(D.kd, D.kh, D.kw)(
|
O[D.n, D.od, D.oh, D.ow, D.c] = ReduceFn.min(D.kd, D.kh, D.kw)(
|
||||||
cast(
|
TypeFn.cast(
|
||||||
U, I[D.n, D.od * S.SD + D.kd * S.DD, D.oh * S.SH + D.kh * S.DH,
|
U, I[D.n, D.od * S.SD + D.kd * S.DD, D.oh * S.SH + D.kh * S.DH,
|
||||||
D.ow * S.SW + D.kw * S.DW, D.c]))
|
D.ow * S.SW + D.kw * S.DW, D.c]))
|
||||||
|
|
||||||
@ -636,14 +644,15 @@ def fill_rng_2d(
|
|||||||
the range of the generated random numbers.
|
the range of the generated random numbers.
|
||||||
"""
|
"""
|
||||||
domain(D.m, D.n)
|
domain(D.m, D.n)
|
||||||
multiplier = cast(I32, const(1103515245))
|
multiplier = TypeFn.cast(I32, const(1103515245))
|
||||||
increment = cast(I32, const(12345))
|
increment = TypeFn.cast(I32, const(12345))
|
||||||
rand1 = (cast(I32, index(D.m)) + seed) * multiplier + increment
|
rand1 = (TypeFn.cast(I32, index(D.m)) + seed) * multiplier + increment
|
||||||
rand2 = (cast(I32, index(D.n)) + rand1) * multiplier + increment
|
rand2 = (TypeFn.cast(I32, index(D.n)) + rand1) * multiplier + increment
|
||||||
inv_range = cast(F64, const(2.3283064e-10))
|
inv_range = TypeFn.cast(F64, const(2.3283064e-10))
|
||||||
offset = cast(F64, const(2147483647))
|
offset = TypeFn.cast(F64, const(2147483647))
|
||||||
scaling = (max - min) * inv_range
|
scaling = (max - min) * inv_range
|
||||||
O[D.m, D.n] = cast(T, (offset + cast(F64, rand2)) * scaling + min)
|
O[D.m, D.n] = TypeFn.cast(T,
|
||||||
|
(offset + TypeFn.cast(F64, rand2)) * scaling + min)
|
||||||
|
|
||||||
|
|
||||||
@linalg_structured_op
|
@linalg_structured_op
|
||||||
@ -656,4 +665,4 @@ def soft_plus_2d(
|
|||||||
"""
|
"""
|
||||||
domain(D.m, D.n)
|
domain(D.m, D.n)
|
||||||
O[D.m, D.n] = \
|
O[D.m, D.n] = \
|
||||||
PrimFn.log(cast(U, const(1.0)) + PrimFn.exp(cast(U, I[D.m, D.n])))
|
PrimFn.log(TypeFn.cast(U, const(1.0)) + PrimFn.exp(TypeFn.cast(U, I[D.m, D.n])))
|
||||||
|
@ -38,19 +38,19 @@ structured_op: !LinalgStructuredOpConfig
|
|||||||
fn_name: add
|
fn_name: add
|
||||||
operands:
|
operands:
|
||||||
- !ScalarExpression
|
- !ScalarExpression
|
||||||
symbolic_cast:
|
type_fn:
|
||||||
|
fn_name: cast
|
||||||
type_var: T
|
type_var: T
|
||||||
operands:
|
operands:
|
||||||
- !ScalarExpression
|
- !ScalarExpression
|
||||||
scalar_const: '42 : i64'
|
scalar_const: '42 : i64'
|
||||||
is_unsigned_cast: false
|
|
||||||
- !ScalarExpression
|
- !ScalarExpression
|
||||||
symbolic_cast:
|
type_fn:
|
||||||
|
fn_name: cast_unsigned
|
||||||
type_var: T
|
type_var: T
|
||||||
operands:
|
operands:
|
||||||
- !ScalarExpression
|
- !ScalarExpression
|
||||||
scalar_index: 1
|
scalar_index: 1
|
||||||
is_unsigned_cast: true
|
|
||||||
|
|
||||||
# ODS-LABEL: def Test1Op : LinalgStructuredBase_Op<"test1"
|
# ODS-LABEL: def Test1Op : LinalgStructuredBase_Op<"test1"
|
||||||
|
|
||||||
@ -86,9 +86,9 @@ structured_op: !LinalgStructuredOpConfig
|
|||||||
# IMPL-LABEL: void Test1Op::regionBuilder(
|
# IMPL-LABEL: void Test1Op::regionBuilder(
|
||||||
# IMPL: ImplicitLocOpBuilder &b, Block &block)
|
# IMPL: ImplicitLocOpBuilder &b, Block &block)
|
||||||
# IMPL: Value [[VAL0:[a-z0-9]+]] = helper.constant("42 : i64");
|
# IMPL: Value [[VAL0:[a-z0-9]+]] = helper.constant("42 : i64");
|
||||||
# IMPL-DAG: Value [[VAL1:[a-z0-9]+]] = helper.cast(block.getArgument(0).getType(), [[VAL0]], false);
|
# IMPL-DAG: Value [[VAL1:[a-z0-9]+]] = helper.typefn__cast(block.getArgument(0).getType(), [[VAL0]]);
|
||||||
# IMPL-DAG: Value [[VAL2:[a-z0-9]+]] = helper.index(1);
|
# IMPL-DAG: Value [[VAL2:[a-z0-9]+]] = helper.index(1);
|
||||||
# IMPL-DAG: Value [[VAL3:[a-z0-9]+]] = helper.cast(block.getArgument(0).getType(), [[VAL2]], true);
|
# IMPL-DAG: Value [[VAL3:[a-z0-9]+]] = helper.typefn__cast_unsigned(block.getArgument(0).getType(), [[VAL2]]);
|
||||||
# IMPL-DAG: Value [[VAL4:[a-z0-9]+]] = helper.applyfn__add([[VAL1]], [[VAL3]]);
|
# IMPL-DAG: Value [[VAL4:[a-z0-9]+]] = helper.applyfn__add([[VAL1]], [[VAL3]]);
|
||||||
|
|
||||||
|
|
||||||
|
@ -23,7 +23,7 @@ def matmul(
|
|||||||
A=TensorDef(T, S.M, S.K),
|
A=TensorDef(T, S.M, S.K),
|
||||||
B=TensorDef(T, S.K, S.N),
|
B=TensorDef(T, S.K, S.N),
|
||||||
C=TensorDef(U, S.M, S.N, output=True)):
|
C=TensorDef(U, S.M, S.N, output=True)):
|
||||||
C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.k, D.n])
|
C[D.m, D.n] += TypeFn.cast(U, A[D.m, D.k]) * TypeFn.cast(U, B[D.k, D.n])
|
||||||
|
|
||||||
|
|
||||||
# CHECK: ---
|
# CHECK: ---
|
||||||
|
@ -15,11 +15,11 @@ from mlir.dialects.linalg.opdsl.lang import *
|
|||||||
# CHECK: scalar_apply:
|
# CHECK: scalar_apply:
|
||||||
# CHECK: fn_name: mul
|
# CHECK: fn_name: mul
|
||||||
# CHECK: operands:
|
# CHECK: operands:
|
||||||
# CHECK: symbolic_cast:
|
# CHECK: type_fn:
|
||||||
# CHECK: type_var: U
|
# CHECK: type_var: U
|
||||||
# CHECK: operands:
|
# CHECK: operands:
|
||||||
# CHECK: scalar_arg: A
|
# CHECK: scalar_arg: A
|
||||||
# CHECK: symbolic_cast:
|
# CHECK: type_fn:
|
||||||
# CHECK: type_var: U
|
# CHECK: type_var: U
|
||||||
# CHECK: operands:
|
# CHECK: operands:
|
||||||
# CHECK: scalar_arg: B
|
# CHECK: scalar_arg: B
|
||||||
@ -28,7 +28,7 @@ def matmul(
|
|||||||
A=TensorDef(T, S.M, S.K),
|
A=TensorDef(T, S.M, S.K),
|
||||||
B=TensorDef(T, S.K, S.N),
|
B=TensorDef(T, S.K, S.N),
|
||||||
C=TensorDef(U, S.M, S.N, output=True)):
|
C=TensorDef(U, S.M, S.N, output=True)):
|
||||||
C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.k, D.n])
|
C[D.m, D.n] += TypeFn.cast(U, A[D.m, D.k]) * TypeFn.cast(U, B[D.k, D.n])
|
||||||
|
|
||||||
|
|
||||||
# CHECK: ---
|
# CHECK: ---
|
||||||
@ -42,23 +42,23 @@ def matmul(
|
|||||||
# CHECK: scalar_apply:
|
# CHECK: scalar_apply:
|
||||||
# CHECK: fn_name: add
|
# CHECK: fn_name: add
|
||||||
# CHECK: operands:
|
# CHECK: operands:
|
||||||
# CHECK: symbolic_cast:
|
# CHECK: type_fn:
|
||||||
# CHECK: type_var: T
|
# CHECK: type_var: T
|
||||||
# CHECK: operands:
|
# CHECK: operands:
|
||||||
# CHECK: scalar_const: '3.1415926535897931 : f64'
|
# CHECK: scalar_const: '3.1415926535897931 : f64'
|
||||||
# CHECK: symbolic_cast:
|
# CHECK: type_fn:
|
||||||
# CHECK: type_var: T
|
# CHECK: type_var: T
|
||||||
# CHECK: operands:
|
# CHECK: operands:
|
||||||
# CHECK: scalar_const: '42 : i64'
|
# CHECK: scalar_const: '42 : i64'
|
||||||
# CHECK: symbolic_cast:
|
# CHECK: type_fn:
|
||||||
# CHECK: type_var: T
|
# CHECK: type_var: T
|
||||||
# CHECK: operands:
|
# CHECK: operands:
|
||||||
# CHECK: scalar_const: '1.{{[0]*}}e+03 : f64'
|
# CHECK: scalar_const: '1.{{[0]*}}e+03 : f64'
|
||||||
@linalg_structured_op
|
@linalg_structured_op
|
||||||
def constants(O=TensorDef(T, S.M, S.K, output=True)):
|
def constants(O=TensorDef(T, S.M, S.K, output=True)):
|
||||||
pi = cast(T, const(3.1415926535897931))
|
pi = TypeFn.cast(T, const(3.1415926535897931))
|
||||||
cst42 = cast(T, const(42))
|
cst42 = TypeFn.cast(T, const(42))
|
||||||
cst1000 = cast(T, const(1e+3))
|
cst1000 = TypeFn.cast(T, const(1e+3))
|
||||||
O[D.m, D.n] = pi + cst42 - cst1000
|
O[D.m, D.n] = pi + cst42 - cst1000
|
||||||
|
|
||||||
|
|
||||||
|
@ -19,9 +19,9 @@ def conv_poly(
|
|||||||
strides=IndexAttrDef(S.SH, S.SW),
|
strides=IndexAttrDef(S.SH, S.SW),
|
||||||
dilations=IndexAttrDef(S.DH, S.DW)):
|
dilations=IndexAttrDef(S.DH, S.DW)):
|
||||||
domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c)
|
domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c)
|
||||||
O[D.n, D.oh, D.ow, D.c] += cast(
|
O[D.n, D.oh, D.ow, D.c] += TypeFn.cast(
|
||||||
U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW,
|
U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW,
|
||||||
D.c]) * cast(U, K[D.kh, D.kw, D.c])
|
D.c]) * TypeFn.cast(U, K[D.kh, D.kw, D.c])
|
||||||
|
|
||||||
|
|
||||||
with Context() as ctx, Location.unknown():
|
with Context() as ctx, Location.unknown():
|
||||||
|
@ -26,7 +26,7 @@ def matmul_poly(
|
|||||||
B=TensorDef(T2, S.K, S.N),
|
B=TensorDef(T2, S.K, S.N),
|
||||||
C=TensorDef(U, S.M, S.N, output=True)):
|
C=TensorDef(U, S.M, S.N, output=True)):
|
||||||
domain(D.m, D.n, D.k)
|
domain(D.m, D.n, D.k)
|
||||||
C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.k, D.n])
|
C[D.m, D.n] += TypeFn.cast(U, A[D.m, D.k]) * TypeFn.cast(U, B[D.k, D.n])
|
||||||
|
|
||||||
|
|
||||||
@linalg_structured_op
|
@linalg_structured_op
|
||||||
@ -35,7 +35,8 @@ def matmul_unsigned_poly(
|
|||||||
B=TensorDef(T2, S.K, S.N),
|
B=TensorDef(T2, S.K, S.N),
|
||||||
C=TensorDef(U, S.M, S.N, output=True)):
|
C=TensorDef(U, S.M, S.N, output=True)):
|
||||||
domain(D.m, D.n, D.k)
|
domain(D.m, D.n, D.k)
|
||||||
C[D.m, D.n] += cast_unsigned(U, A[D.m, D.k]) * cast_unsigned(U, B[D.k, D.n])
|
C[D.m, D.n] += TypeFn.cast_unsigned(U, A[D.m, D.k]) * TypeFn.cast_unsigned(
|
||||||
|
U, B[D.k, D.n])
|
||||||
|
|
||||||
|
|
||||||
with Context() as ctx, Location.unknown():
|
with Context() as ctx, Location.unknown():
|
||||||
|
@ -14,27 +14,29 @@ from mlir.dialects.linalg.opdsl.lang import *
|
|||||||
# - exponential functions
|
# - exponential functions
|
||||||
# - custom op names.
|
# - custom op names.
|
||||||
|
|
||||||
|
|
||||||
@linalg_structured_op
|
@linalg_structured_op
|
||||||
def fill_rng_poly(
|
def fill_rng_poly(
|
||||||
min=ScalarDef(F64),
|
min=ScalarDef(F64),
|
||||||
max=ScalarDef(F64),
|
max=ScalarDef(F64),
|
||||||
seed=ScalarDef(I32),
|
seed=ScalarDef(I32),
|
||||||
O=TensorDef(T, S.M, S.N, output=True)):
|
O=TensorDef(T, S.M, S.N, output=True)):
|
||||||
multiplier = cast(I32, const(1103515245))
|
multiplier = TypeFn.cast(I32, const(1103515245))
|
||||||
increment = cast(I32, const(12345))
|
increment = TypeFn.cast(I32, const(12345))
|
||||||
rand1 = (cast(I32, index(D.m)) + seed) * multiplier + increment
|
rand1 = (TypeFn.cast(I32, index(D.m)) + seed) * multiplier + increment
|
||||||
rand2 = (cast(I32, index(D.n)) + rand1) * multiplier + increment
|
rand2 = (TypeFn.cast(I32, index(D.n)) + rand1) * multiplier + increment
|
||||||
inv_range = cast(F64, const(2.3283064e-10))
|
inv_range = TypeFn.cast(F64, const(2.3283064e-10))
|
||||||
offset = cast(F64, const(2147483647))
|
offset = TypeFn.cast(F64, const(2147483647))
|
||||||
scaling = (max - min) * inv_range
|
scaling = (max - min) * inv_range
|
||||||
O[D.m, D.n] = cast(T, (offset + cast(F64, rand2)) * scaling + min)
|
O[D.m, D.n] = TypeFn.cast(T,
|
||||||
|
(offset + TypeFn.cast(F64, rand2)) * scaling + min)
|
||||||
|
|
||||||
|
|
||||||
@linalg_structured_op
|
@linalg_structured_op
|
||||||
def soft_plus_poly(
|
def soft_plus_poly(
|
||||||
I=TensorDef(T, S.M, S.N), O=TensorDef(U, S.M, S.N, output=True)):
|
I=TensorDef(T, S.M, S.N), O=TensorDef(U, S.M, S.N, output=True)):
|
||||||
O[D.m, D.n] = \
|
O[D.m, D.n] = PrimFn.log(
|
||||||
PrimFn.log(cast(U, const(1.0)) + cast(U, PrimFn.exp(I[D.m, D.n])))
|
TypeFn.cast(U, const(1.0)) + TypeFn.cast(U, PrimFn.exp(I[D.m, D.n])))
|
||||||
|
|
||||||
|
|
||||||
@linalg_structured_op(op_name="custom_op_name")
|
@linalg_structured_op(op_name="custom_op_name")
|
||||||
|
@ -20,8 +20,8 @@ def pooling_max_poly(
|
|||||||
dilations=IndexAttrDef(S.DH, S.DW)):
|
dilations=IndexAttrDef(S.DH, S.DW)):
|
||||||
domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c)
|
domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c)
|
||||||
O[D.n, D.oh, D.ow, D.c] = ReduceFn.max(D.kh, D.kw)(
|
O[D.n, D.oh, D.ow, D.c] = ReduceFn.max(D.kh, D.kw)(
|
||||||
cast(U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW,
|
TypeFn.cast(
|
||||||
D.c]))
|
U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c]))
|
||||||
|
|
||||||
|
|
||||||
@linalg_structured_op
|
@linalg_structured_op
|
||||||
@ -33,7 +33,7 @@ def pooling_max_unsigned_poly(
|
|||||||
dilations=IndexAttrDef(S.DH, S.DW)):
|
dilations=IndexAttrDef(S.DH, S.DW)):
|
||||||
domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c)
|
domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c)
|
||||||
O[D.n, D.oh, D.ow, D.c] = ReduceFn.max_unsigned(D.kh, D.kw)(
|
O[D.n, D.oh, D.ow, D.c] = ReduceFn.max_unsigned(D.kh, D.kw)(
|
||||||
cast_unsigned(
|
TypeFn.cast_unsigned(
|
||||||
U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c]))
|
U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c]))
|
||||||
|
|
||||||
|
|
||||||
@ -46,8 +46,8 @@ def pooling_min_poly(
|
|||||||
dilations=IndexAttrDef(S.DH, S.DW)):
|
dilations=IndexAttrDef(S.DH, S.DW)):
|
||||||
domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c)
|
domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c)
|
||||||
O[D.n, D.oh, D.ow, D.c] = ReduceFn.min(D.kh, D.kw)(
|
O[D.n, D.oh, D.ow, D.c] = ReduceFn.min(D.kh, D.kw)(
|
||||||
cast(U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW,
|
TypeFn.cast(
|
||||||
D.c]))
|
U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c]))
|
||||||
|
|
||||||
|
|
||||||
@linalg_structured_op
|
@linalg_structured_op
|
||||||
@ -59,7 +59,7 @@ def pooling_min_unsigned_poly(
|
|||||||
dilations=IndexAttrDef(S.DH, S.DW)):
|
dilations=IndexAttrDef(S.DH, S.DW)):
|
||||||
domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c)
|
domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c)
|
||||||
O[D.n, D.oh, D.ow, D.c] = ReduceFn.min_unsigned(D.kh, D.kw)(
|
O[D.n, D.oh, D.ow, D.c] = ReduceFn.min_unsigned(D.kh, D.kw)(
|
||||||
cast_unsigned(
|
TypeFn.cast_unsigned(
|
||||||
U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c]))
|
U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c]))
|
||||||
|
|
||||||
|
|
||||||
|
@ -13,4 +13,4 @@ def matmul(
|
|||||||
B=TensorDef(T, S.K, S.N),
|
B=TensorDef(T, S.K, S.N),
|
||||||
C=TensorDef(U, S.M, S.N, output=True)):
|
C=TensorDef(U, S.M, S.N, output=True)):
|
||||||
implements(ContractionOpInterface)
|
implements(ContractionOpInterface)
|
||||||
C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.k, D.n])
|
C[D.m, D.n] += TypeFn.cast(U, A[D.m, D.k]) * TypeFn.cast(U, B[D.k, D.n])
|
||||||
|
@ -24,7 +24,7 @@ def matmul(
|
|||||||
B=TensorDef(T, S.K, S.N),
|
B=TensorDef(T, S.K, S.N),
|
||||||
C=TensorDef(U, S.M, S.N, output=True)):
|
C=TensorDef(U, S.M, S.N, output=True)):
|
||||||
domain(D.m, D.n, D.k)
|
domain(D.m, D.n, D.k)
|
||||||
C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.k, D.n])
|
C[D.m, D.n] += TypeFn.cast(U, A[D.m, D.k]) * TypeFn.cast(U, B[D.k, D.n])
|
||||||
|
|
||||||
|
|
||||||
# Verifies that assignment to a scalar (represented as [None]) is represented
|
# Verifies that assignment to a scalar (represented as [None]) is represented
|
||||||
@ -42,7 +42,7 @@ def matmul(
|
|||||||
# CHECK-NEXT: - reduction
|
# CHECK-NEXT: - reduction
|
||||||
@linalg_structured_op
|
@linalg_structured_op
|
||||||
def dot(A=TensorDef(T, S.M), B=TensorDef(T, S.M), C=TensorDef(U, output=True)):
|
def dot(A=TensorDef(T, S.M), B=TensorDef(T, S.M), C=TensorDef(U, output=True)):
|
||||||
C[None] += cast(U, A[D.m]) * cast(U, B[D.m])
|
C[None] += TypeFn.cast(U, A[D.m]) * TypeFn.cast(U, B[D.m])
|
||||||
|
|
||||||
|
|
||||||
# Verifies that the index_dims of shape-only operands translate to correct
|
# Verifies that the index_dims of shape-only operands translate to correct
|
||||||
@ -65,4 +65,4 @@ def pool(
|
|||||||
K=TensorDef(T, S.K, index_dims=[D.k]),
|
K=TensorDef(T, S.K, index_dims=[D.k]),
|
||||||
O=TensorDef(U, S.O, output=True)):
|
O=TensorDef(U, S.O, output=True)):
|
||||||
domain(D.o, D.k)
|
domain(D.o, D.k)
|
||||||
O[D.o] += cast(U, I[D.o * 2 + D.k])
|
O[D.o] += TypeFn.cast(U, I[D.o * 2 + D.k])
|
||||||
|
@ -89,12 +89,12 @@ struct ScalarApply {
|
|||||||
std::vector<ScalarExpression> operands;
|
std::vector<ScalarExpression> operands;
|
||||||
};
|
};
|
||||||
|
|
||||||
struct ScalarSymbolicCast {
|
struct ScalarTypeFn {
|
||||||
|
std::string fnName;
|
||||||
std::string typeVar;
|
std::string typeVar;
|
||||||
// NOTE: This must be of arity 1, but to break the self-referential cycle,
|
// NOTE: This must be of arity 1, but to break the self-referential cycle,
|
||||||
// we use a heap allocated vector.
|
// we use a heap allocated vector.
|
||||||
std::vector<ScalarExpression> operands;
|
std::vector<ScalarExpression> operands;
|
||||||
bool isUnsignedCast;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
struct ScalarExpression {
|
struct ScalarExpression {
|
||||||
@ -102,7 +102,7 @@ struct ScalarExpression {
|
|||||||
Optional<std::string> constant;
|
Optional<std::string> constant;
|
||||||
Optional<int64_t> index;
|
Optional<int64_t> index;
|
||||||
Optional<ScalarApply> apply;
|
Optional<ScalarApply> apply;
|
||||||
Optional<ScalarSymbolicCast> symbolicCast;
|
Optional<ScalarTypeFn> typeFn;
|
||||||
};
|
};
|
||||||
|
|
||||||
struct ScalarAssign {
|
struct ScalarAssign {
|
||||||
@ -141,7 +141,8 @@ namespace yaml {
|
|||||||
/// Top-level type containing op metadata and one of a concrete op type.
|
/// Top-level type containing op metadata and one of a concrete op type.
|
||||||
/// Currently, the only defined op type is `structured_op` (maps to
|
/// Currently, the only defined op type is `structured_op` (maps to
|
||||||
/// `LinalgStructuredOpConfig`).
|
/// `LinalgStructuredOpConfig`).
|
||||||
template <> struct MappingTraits<LinalgOpConfig> {
|
template <>
|
||||||
|
struct MappingTraits<LinalgOpConfig> {
|
||||||
static void mapping(IO &io, LinalgOpConfig &info) {
|
static void mapping(IO &io, LinalgOpConfig &info) {
|
||||||
io.mapOptional("metadata", info.metadata);
|
io.mapOptional("metadata", info.metadata);
|
||||||
io.mapOptional("structured_op", info.structuredOp);
|
io.mapOptional("structured_op", info.structuredOp);
|
||||||
@ -154,7 +155,8 @@ template <> struct MappingTraits<LinalgOpConfig> {
|
|||||||
/// - List of indexing maps (see `LinalgIndexingMaps`).
|
/// - List of indexing maps (see `LinalgIndexingMaps`).
|
||||||
/// - Iterator types (see `LinalgIteratorTypeDef`).
|
/// - Iterator types (see `LinalgIteratorTypeDef`).
|
||||||
/// - List of scalar level assignment (see `ScalarAssign`).
|
/// - List of scalar level assignment (see `ScalarAssign`).
|
||||||
template <> struct MappingTraits<LinalgStructuredOpConfig> {
|
template <>
|
||||||
|
struct MappingTraits<LinalgStructuredOpConfig> {
|
||||||
static void mapping(IO &io, LinalgStructuredOpConfig &info) {
|
static void mapping(IO &io, LinalgStructuredOpConfig &info) {
|
||||||
io.mapRequired("args", info.args);
|
io.mapRequired("args", info.args);
|
||||||
io.mapRequired("indexing_maps", info.indexingMaps);
|
io.mapRequired("indexing_maps", info.indexingMaps);
|
||||||
@ -177,7 +179,8 @@ template <> struct MappingTraits<LinalgStructuredOpConfig> {
|
|||||||
/// attribute symbols. During op creation these symbols are replaced by the
|
/// attribute symbols. During op creation these symbols are replaced by the
|
||||||
/// corresponding `name` attribute values. Only attribute arguments have
|
/// corresponding `name` attribute values. Only attribute arguments have
|
||||||
/// an `attribute_map`.
|
/// an `attribute_map`.
|
||||||
template <> struct MappingTraits<LinalgOperandDef> {
|
template <>
|
||||||
|
struct MappingTraits<LinalgOperandDef> {
|
||||||
static void mapping(IO &io, LinalgOperandDef &info) {
|
static void mapping(IO &io, LinalgOperandDef &info) {
|
||||||
io.mapRequired("name", info.name);
|
io.mapRequired("name", info.name);
|
||||||
io.mapRequired("usage", info.usage);
|
io.mapRequired("usage", info.usage);
|
||||||
@ -188,7 +191,8 @@ template <> struct MappingTraits<LinalgOperandDef> {
|
|||||||
};
|
};
|
||||||
|
|
||||||
/// Usage enum for a named argument.
|
/// Usage enum for a named argument.
|
||||||
template <> struct ScalarEnumerationTraits<LinalgOperandDefUsage> {
|
template <>
|
||||||
|
struct ScalarEnumerationTraits<LinalgOperandDefUsage> {
|
||||||
static void enumeration(IO &io, LinalgOperandDefUsage &value) {
|
static void enumeration(IO &io, LinalgOperandDefUsage &value) {
|
||||||
io.enumCase(value, "InputOperand", LinalgOperandDefUsage::input);
|
io.enumCase(value, "InputOperand", LinalgOperandDefUsage::input);
|
||||||
io.enumCase(value, "OutputOperand", LinalgOperandDefUsage::output);
|
io.enumCase(value, "OutputOperand", LinalgOperandDefUsage::output);
|
||||||
@ -197,7 +201,8 @@ template <> struct ScalarEnumerationTraits<LinalgOperandDefUsage> {
|
|||||||
};
|
};
|
||||||
|
|
||||||
/// Iterator type enum.
|
/// Iterator type enum.
|
||||||
template <> struct ScalarEnumerationTraits<LinalgIteratorTypeDef> {
|
template <>
|
||||||
|
struct ScalarEnumerationTraits<LinalgIteratorTypeDef> {
|
||||||
static void enumeration(IO &io, LinalgIteratorTypeDef &value) {
|
static void enumeration(IO &io, LinalgIteratorTypeDef &value) {
|
||||||
io.enumCase(value, "parallel", LinalgIteratorTypeDef::parallel);
|
io.enumCase(value, "parallel", LinalgIteratorTypeDef::parallel);
|
||||||
io.enumCase(value, "reduction", LinalgIteratorTypeDef::reduction);
|
io.enumCase(value, "reduction", LinalgIteratorTypeDef::reduction);
|
||||||
@ -205,7 +210,8 @@ template <> struct ScalarEnumerationTraits<LinalgIteratorTypeDef> {
|
|||||||
};
|
};
|
||||||
|
|
||||||
/// Metadata about the op (name, C++ name, and documentation).
|
/// Metadata about the op (name, C++ name, and documentation).
|
||||||
template <> struct MappingTraits<LinalgOpMetadata> {
|
template <>
|
||||||
|
struct MappingTraits<LinalgOpMetadata> {
|
||||||
static void mapping(IO &io, LinalgOpMetadata &info) {
|
static void mapping(IO &io, LinalgOpMetadata &info) {
|
||||||
io.mapRequired("name", info.name);
|
io.mapRequired("name", info.name);
|
||||||
io.mapRequired("cpp_class_name", info.cppClassName);
|
io.mapRequired("cpp_class_name", info.cppClassName);
|
||||||
@ -219,7 +225,8 @@ template <> struct MappingTraits<LinalgOpMetadata> {
|
|||||||
/// some symbols that bind to attributes of the op. Each indexing map must
|
/// some symbols that bind to attributes of the op. Each indexing map must
|
||||||
/// be normalized over the same list of dimensions, and its symbols must
|
/// be normalized over the same list of dimensions, and its symbols must
|
||||||
/// match the symbols for argument shapes.
|
/// match the symbols for argument shapes.
|
||||||
template <> struct MappingTraits<LinalgIndexingMapsConfig> {
|
template <>
|
||||||
|
struct MappingTraits<LinalgIndexingMapsConfig> {
|
||||||
static void mapping(IO &io, LinalgIndexingMapsConfig &info) {
|
static void mapping(IO &io, LinalgIndexingMapsConfig &info) {
|
||||||
io.mapOptional("static_indexing_maps", info.staticIndexingMaps);
|
io.mapOptional("static_indexing_maps", info.staticIndexingMaps);
|
||||||
}
|
}
|
||||||
@ -229,7 +236,8 @@ template <> struct MappingTraits<LinalgIndexingMapsConfig> {
|
|||||||
/// - The `arg` name must match a named output.
|
/// - The `arg` name must match a named output.
|
||||||
/// - The `value` is a scalar expression for computing the value to
|
/// - The `value` is a scalar expression for computing the value to
|
||||||
/// assign (see `ScalarExpression`).
|
/// assign (see `ScalarExpression`).
|
||||||
template <> struct MappingTraits<ScalarAssign> {
|
template <>
|
||||||
|
struct MappingTraits<ScalarAssign> {
|
||||||
static void mapping(IO &io, ScalarAssign &info) {
|
static void mapping(IO &io, ScalarAssign &info) {
|
||||||
io.mapRequired("arg", info.arg);
|
io.mapRequired("arg", info.arg);
|
||||||
io.mapRequired("value", info.value);
|
io.mapRequired("value", info.value);
|
||||||
@ -240,14 +248,15 @@ template <> struct MappingTraits<ScalarAssign> {
|
|||||||
/// - `scalar_arg`: Name of an argument to the op.
|
/// - `scalar_arg`: Name of an argument to the op.
|
||||||
/// - `scalar_apply`: Result of evaluating a named function (see
|
/// - `scalar_apply`: Result of evaluating a named function (see
|
||||||
/// `ScalarApply`).
|
/// `ScalarApply`).
|
||||||
/// - `symbolic_cast`: Cast to a symbolic TypeVar bound elsewhere.
|
/// - `type_fn`: A named type conversion function (see `ScalarTypeFn`).
|
||||||
template <> struct MappingTraits<ScalarExpression> {
|
template <>
|
||||||
|
struct MappingTraits<ScalarExpression> {
|
||||||
static void mapping(IO &io, ScalarExpression &info) {
|
static void mapping(IO &io, ScalarExpression &info) {
|
||||||
io.mapOptional("scalar_arg", info.arg);
|
io.mapOptional("scalar_arg", info.arg);
|
||||||
io.mapOptional("scalar_const", info.constant);
|
io.mapOptional("scalar_const", info.constant);
|
||||||
io.mapOptional("scalar_index", info.index);
|
io.mapOptional("scalar_index", info.index);
|
||||||
io.mapOptional("scalar_apply", info.apply);
|
io.mapOptional("scalar_apply", info.apply);
|
||||||
io.mapOptional("symbolic_cast", info.symbolicCast);
|
io.mapOptional("type_fn", info.typeFn);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -256,24 +265,27 @@ template <> struct MappingTraits<ScalarExpression> {
|
|||||||
/// functions include:
|
/// functions include:
|
||||||
/// - `add(lhs, rhs)`
|
/// - `add(lhs, rhs)`
|
||||||
/// - `mul(lhs, rhs)`
|
/// - `mul(lhs, rhs)`
|
||||||
template <> struct MappingTraits<ScalarApply> {
|
template <>
|
||||||
|
struct MappingTraits<ScalarApply> {
|
||||||
static void mapping(IO &io, ScalarApply &info) {
|
static void mapping(IO &io, ScalarApply &info) {
|
||||||
io.mapRequired("fn_name", info.fnName);
|
io.mapRequired("fn_name", info.fnName);
|
||||||
io.mapRequired("operands", info.operands);
|
io.mapRequired("operands", info.operands);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
template <> struct MappingTraits<ScalarSymbolicCast> {
|
template <>
|
||||||
static void mapping(IO &io, ScalarSymbolicCast &info) {
|
struct MappingTraits<ScalarTypeFn> {
|
||||||
|
static void mapping(IO &io, ScalarTypeFn &info) {
|
||||||
|
io.mapRequired("fn_name", info.fnName);
|
||||||
io.mapRequired("type_var", info.typeVar);
|
io.mapRequired("type_var", info.typeVar);
|
||||||
io.mapRequired("operands", info.operands);
|
io.mapRequired("operands", info.operands);
|
||||||
io.mapRequired("is_unsigned_cast", info.isUnsignedCast);
|
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
/// Helper mapping which accesses an AffineMapAttr as a serialized string of
|
/// Helper mapping which accesses an AffineMapAttr as a serialized string of
|
||||||
/// the same.
|
/// the same.
|
||||||
template <> struct ScalarTraits<SerializedAffineMap> {
|
template <>
|
||||||
|
struct ScalarTraits<SerializedAffineMap> {
|
||||||
static void output(const SerializedAffineMap &value, void *rawYamlContext,
|
static void output(const SerializedAffineMap &value, void *rawYamlContext,
|
||||||
raw_ostream &out) {
|
raw_ostream &out) {
|
||||||
assert(value.affineMapAttr);
|
assert(value.affineMapAttr);
|
||||||
@ -949,33 +961,33 @@ void {0}::regionBuilder(ImplicitLocOpBuilder &b, Block &block) {{
|
|||||||
interleaveToString(operandCppValues, ", ")));
|
interleaveToString(operandCppValues, ", ")));
|
||||||
return cppIdent;
|
return cppIdent;
|
||||||
}
|
}
|
||||||
if (expression.symbolicCast) {
|
if (expression.typeFn) {
|
||||||
// Symbolic cast.
|
// Symbolic cast.
|
||||||
// Operands must be arity 1.
|
// Operands must be arity 1.
|
||||||
if (expression.symbolicCast->operands.size() != 1) {
|
if (expression.typeFn->operands.size() != 1) {
|
||||||
emitError(genContext.getLoc())
|
emitError(genContext.getLoc())
|
||||||
<< "symbolic_cast operand arity must be 1";
|
<< "type conversion operand arity must be 1";
|
||||||
return None;
|
return None;
|
||||||
}
|
}
|
||||||
Optional<std::string> operandCppValue =
|
Optional<std::string> operandCppValue =
|
||||||
generateExpression(expression.symbolicCast->operands[0]);
|
generateExpression(expression.typeFn->operands[0]);
|
||||||
if (!operandCppValue)
|
if (!operandCppValue)
|
||||||
return None;
|
return None;
|
||||||
|
|
||||||
Optional<std::string> typeCppValue =
|
Optional<std::string> typeCppValue =
|
||||||
findTypeValue(expression.symbolicCast->typeVar, args);
|
findTypeValue(expression.typeFn->typeVar, args);
|
||||||
if (!typeCppValue) {
|
if (!typeCppValue) {
|
||||||
emitError(genContext.getLoc())
|
emitError(genContext.getLoc())
|
||||||
<< "type variable " << expression.symbolicCast->typeVar
|
<< "type variable " << expression.typeFn->typeVar
|
||||||
<< ", used in a symbolic cast must map to a predefined or "
|
<< ", used in a type conversion, must map to a predefined or "
|
||||||
<< "an argument type but it does not";
|
<< "an argument type but it does not";
|
||||||
return None;
|
return None;
|
||||||
}
|
}
|
||||||
std::string cppIdent = llvm::formatv("value{0}", ++localCounter);
|
std::string cppIdent = llvm::formatv("value{0}", ++localCounter);
|
||||||
stmts.push_back(
|
stmts.push_back(
|
||||||
llvm::formatv("Value {0} = helper.cast({1}, {2}, {3});", cppIdent,
|
llvm::formatv("Value {0} = helper.typefn__{1}({2}, {3});",
|
||||||
typeCppValue.getValue(), *operandCppValue,
|
cppIdent, expression.typeFn->fnName,
|
||||||
expression.symbolicCast->isUnsignedCast));
|
typeCppValue.getValue(), *operandCppValue));
|
||||||
return cppIdent;
|
return cppIdent;
|
||||||
}
|
}
|
||||||
emitError(genContext.getLoc()) << "unknown ScalarExpression type";
|
emitError(genContext.getLoc()) << "unknown ScalarExpression type";
|
||||||
|
Loading…
x
Reference in New Issue
Block a user