[mlir][linalg] Add unsigned min/max/cast function to OpDSL.

Update OpDSL to support unsigned integers by adding unsigned min/max/cast signatures. Add tests in OpDSL and on the C++ side to verify the proper signed and unsigned operations are emitted.

The patch addresses an issue brought up in https://reviews.llvm.org/D111170.

Reviewed By: rsuderman

Differential Revision: https://reviews.llvm.org/D111230
This commit is contained in:
Tobias Gysi 2021-10-07 06:26:38 +00:00
parent 06404d5488
commit 3fe7fe4424
10 changed files with 601 additions and 158 deletions

View File

@ -56,12 +56,78 @@ structured_op: !LinalgStructuredOpConfig
operands:
- !ScalarExpression
scalar_arg: A
is_unsigned_cast: false
- !ScalarExpression
symbolic_cast:
type_var: U
operands:
- !ScalarExpression
scalar_arg: B
is_unsigned_cast: false
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: matmul_unsigned
cpp_class_name: MatmulUnsignedOp
doc: |-
Performs an unsigned matrix multiplication of two 2D inputs.
Numeric casting is performed on the operands to the inner multiply, promoting
them to the same data type as the accumulator/output.
implements:
- LinalgContractionOpInterface
structured_op: !LinalgStructuredOpConfig
args:
- !LinalgOperandDefConfig
name: A
usage: InputOperand
type_var: T1
shape_map: affine_map<()[s0, s1, s2] -> (s0, s1)>
- !LinalgOperandDefConfig
name: B
usage: InputOperand
type_var: T2
shape_map: affine_map<()[s0, s1, s2] -> (s1, s2)>
- !LinalgOperandDefConfig
name: C
usage: OutputOperand
type_var: U
shape_map: affine_map<()[s0, s1, s2] -> (s0, s2)>
indexing_maps: !LinalgIndexingMapsConfig
static_indexing_maps:
- affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0, d2)>
- affine_map<(d0, d1, d2)[s0, s1, s2] -> (d2, d1)>
- affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0, d1)>
iterator_types:
- parallel
- parallel
- reduction
assignments:
- !ScalarAssign
arg: C
value: !ScalarExpression
scalar_apply:
fn_name: add
operands:
- !ScalarExpression
scalar_arg: C
- !ScalarExpression
scalar_apply:
fn_name: mul
operands:
- !ScalarExpression
symbolic_cast:
type_var: U
operands:
- !ScalarExpression
scalar_arg: A
is_unsigned_cast: true
- !ScalarExpression
symbolic_cast:
type_var: U
operands:
- !ScalarExpression
scalar_arg: B
is_unsigned_cast: true
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: quantized_matmul
@ -132,12 +198,14 @@ structured_op: !LinalgStructuredOpConfig
operands:
- !ScalarExpression
scalar_arg: A
is_unsigned_cast: false
- !ScalarExpression
symbolic_cast:
type_var: U
operands:
- !ScalarExpression
scalar_arg: AZp
is_unsigned_cast: false
- !ScalarExpression
scalar_apply:
fn_name: sub
@ -148,12 +216,14 @@ structured_op: !LinalgStructuredOpConfig
operands:
- !ScalarExpression
scalar_arg: B
is_unsigned_cast: false
- !ScalarExpression
symbolic_cast:
type_var: U
operands:
- !ScalarExpression
scalar_arg: BZp
is_unsigned_cast: false
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: mmt4d
@ -221,12 +291,14 @@ structured_op: !LinalgStructuredOpConfig
operands:
- !ScalarExpression
scalar_arg: lhs
is_unsigned_cast: false
- !ScalarExpression
symbolic_cast:
type_var: AccumType
operands:
- !ScalarExpression
scalar_arg: rhs
is_unsigned_cast: false
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: batch_matmul
@ -284,12 +356,14 @@ structured_op: !LinalgStructuredOpConfig
operands:
- !ScalarExpression
scalar_arg: A
is_unsigned_cast: false
- !ScalarExpression
symbolic_cast:
type_var: U
operands:
- !ScalarExpression
scalar_arg: B
is_unsigned_cast: false
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: quantized_batch_matmul
@ -361,12 +435,14 @@ structured_op: !LinalgStructuredOpConfig
operands:
- !ScalarExpression
scalar_arg: A
is_unsigned_cast: false
- !ScalarExpression
symbolic_cast:
type_var: U
operands:
- !ScalarExpression
scalar_arg: AZp
is_unsigned_cast: false
- !ScalarExpression
scalar_apply:
fn_name: sub
@ -377,12 +453,14 @@ structured_op: !LinalgStructuredOpConfig
operands:
- !ScalarExpression
scalar_arg: B
is_unsigned_cast: false
- !ScalarExpression
symbolic_cast:
type_var: U
operands:
- !ScalarExpression
scalar_arg: BZp
is_unsigned_cast: false
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: matvec
@ -438,12 +516,14 @@ structured_op: !LinalgStructuredOpConfig
operands:
- !ScalarExpression
scalar_arg: A
is_unsigned_cast: false
- !ScalarExpression
symbolic_cast:
type_var: U
operands:
- !ScalarExpression
scalar_arg: y
is_unsigned_cast: false
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: vecmat
@ -499,12 +579,14 @@ structured_op: !LinalgStructuredOpConfig
operands:
- !ScalarExpression
scalar_arg: y
is_unsigned_cast: false
- !ScalarExpression
symbolic_cast:
type_var: U
operands:
- !ScalarExpression
scalar_arg: A
is_unsigned_cast: false
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: batch_matvec
@ -561,12 +643,14 @@ structured_op: !LinalgStructuredOpConfig
operands:
- !ScalarExpression
scalar_arg: A
is_unsigned_cast: false
- !ScalarExpression
symbolic_cast:
type_var: U
operands:
- !ScalarExpression
scalar_arg: B
is_unsigned_cast: false
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: dot
@ -621,12 +705,14 @@ structured_op: !LinalgStructuredOpConfig
operands:
- !ScalarExpression
scalar_arg: A
is_unsigned_cast: false
- !ScalarExpression
symbolic_cast:
type_var: U
operands:
- !ScalarExpression
scalar_arg: B
is_unsigned_cast: false
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: conv_1d
@ -682,12 +768,14 @@ structured_op: !LinalgStructuredOpConfig
operands:
- !ScalarExpression
scalar_arg: I
is_unsigned_cast: false
- !ScalarExpression
symbolic_cast:
type_var: U
operands:
- !ScalarExpression
scalar_arg: K
is_unsigned_cast: false
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: conv_2d
@ -745,12 +833,14 @@ structured_op: !LinalgStructuredOpConfig
operands:
- !ScalarExpression
scalar_arg: I
is_unsigned_cast: false
- !ScalarExpression
symbolic_cast:
type_var: U
operands:
- !ScalarExpression
scalar_arg: K
is_unsigned_cast: false
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: conv_3d
@ -811,12 +901,14 @@ structured_op: !LinalgStructuredOpConfig
operands:
- !ScalarExpression
scalar_arg: I
is_unsigned_cast: false
- !ScalarExpression
symbolic_cast:
type_var: U
operands:
- !ScalarExpression
scalar_arg: K
is_unsigned_cast: false
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: conv_1d_nwc_wcf
@ -887,12 +979,14 @@ structured_op: !LinalgStructuredOpConfig
operands:
- !ScalarExpression
scalar_arg: I
is_unsigned_cast: false
- !ScalarExpression
symbolic_cast:
type_var: U
operands:
- !ScalarExpression
scalar_arg: K
is_unsigned_cast: false
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: conv_2d_nhwc_hwcf
@ -975,12 +1069,14 @@ structured_op: !LinalgStructuredOpConfig
operands:
- !ScalarExpression
scalar_arg: I
is_unsigned_cast: false
- !ScalarExpression
symbolic_cast:
type_var: U
operands:
- !ScalarExpression
scalar_arg: K
is_unsigned_cast: false
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: conv_2d_nhwc_hwcf_q
@ -1080,12 +1176,14 @@ structured_op: !LinalgStructuredOpConfig
operands:
- !ScalarExpression
scalar_arg: I
is_unsigned_cast: false
- !ScalarExpression
symbolic_cast:
type_var: U
operands:
- !ScalarExpression
scalar_arg: IZp
is_unsigned_cast: false
- !ScalarExpression
scalar_apply:
fn_name: sub
@ -1096,12 +1194,14 @@ structured_op: !LinalgStructuredOpConfig
operands:
- !ScalarExpression
scalar_arg: K
is_unsigned_cast: false
- !ScalarExpression
symbolic_cast:
type_var: U
operands:
- !ScalarExpression
scalar_arg: KZp
is_unsigned_cast: false
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: conv_2d_nchw_fchw
@ -1184,12 +1284,14 @@ structured_op: !LinalgStructuredOpConfig
operands:
- !ScalarExpression
scalar_arg: I
is_unsigned_cast: false
- !ScalarExpression
symbolic_cast:
type_var: U
operands:
- !ScalarExpression
scalar_arg: K
is_unsigned_cast: false
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: conv_3d_ndhwc_dhwcf
@ -1272,12 +1374,14 @@ structured_op: !LinalgStructuredOpConfig
operands:
- !ScalarExpression
scalar_arg: I
is_unsigned_cast: false
- !ScalarExpression
symbolic_cast:
type_var: U
operands:
- !ScalarExpression
scalar_arg: K
is_unsigned_cast: false
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: depthwise_conv2D_nhw
@ -1353,12 +1457,14 @@ structured_op: !LinalgStructuredOpConfig
operands:
- !ScalarExpression
scalar_arg: I
is_unsigned_cast: false
- !ScalarExpression
symbolic_cast:
type_var: U
operands:
- !ScalarExpression
scalar_arg: K
is_unsigned_cast: false
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: depthwise_conv2D_nhw_q
@ -1449,12 +1555,14 @@ structured_op: !LinalgStructuredOpConfig
operands:
- !ScalarExpression
scalar_arg: I
is_unsigned_cast: false
- !ScalarExpression
symbolic_cast:
type_var: U
operands:
- !ScalarExpression
scalar_arg: IZp
is_unsigned_cast: false
- !ScalarExpression
scalar_apply:
fn_name: sub
@ -1465,12 +1573,14 @@ structured_op: !LinalgStructuredOpConfig
operands:
- !ScalarExpression
scalar_arg: K
is_unsigned_cast: false
- !ScalarExpression
symbolic_cast:
type_var: U
operands:
- !ScalarExpression
scalar_arg: KZp
is_unsigned_cast: false
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: depthwise_conv2D_nhwc
@ -1549,12 +1659,14 @@ structured_op: !LinalgStructuredOpConfig
operands:
- !ScalarExpression
scalar_arg: I
is_unsigned_cast: false
- !ScalarExpression
symbolic_cast:
type_var: U
operands:
- !ScalarExpression
scalar_arg: K
is_unsigned_cast: false
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: depthwise_conv2D_nhwc_q
@ -1649,12 +1761,14 @@ structured_op: !LinalgStructuredOpConfig
operands:
- !ScalarExpression
scalar_arg: I
is_unsigned_cast: false
- !ScalarExpression
symbolic_cast:
type_var: U
operands:
- !ScalarExpression
scalar_arg: IZp
is_unsigned_cast: false
- !ScalarExpression
scalar_apply:
fn_name: sub
@ -1665,12 +1779,14 @@ structured_op: !LinalgStructuredOpConfig
operands:
- !ScalarExpression
scalar_arg: K
is_unsigned_cast: false
- !ScalarExpression
symbolic_cast:
type_var: U
operands:
- !ScalarExpression
scalar_arg: KZp
is_unsigned_cast: false
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: pooling_nhwc_sum
@ -1741,6 +1857,7 @@ structured_op: !LinalgStructuredOpConfig
operands:
- !ScalarExpression
scalar_arg: I
is_unsigned_cast: false
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: pooling_nhwc_max
@ -1811,6 +1928,78 @@ structured_op: !LinalgStructuredOpConfig
operands:
- !ScalarExpression
scalar_arg: I
is_unsigned_cast: false
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: pooling_nhwc_max_unsigned
cpp_class_name: PoolingNhwcMaxUnsignedOp
doc: |-
Performs unsigned max pooling.
Numeric casting is performed on the input operand, promoting it to the same
data type as the accumulator/output.
implements:
- LinalgConvolutionOpInterface
structured_op: !LinalgStructuredOpConfig
args:
- !LinalgOperandDefConfig
name: I
usage: InputOperand
type_var: T1
shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s0, s1 *
s2 + s3 * s4, s5 * s6 + s7 * s8, s9)>
- !LinalgOperandDefConfig
name: K
usage: InputOperand
type_var: T2
shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s3, s7)>
- !LinalgOperandDefConfig
name: O
usage: OutputOperand
type_var: U
shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s0, s1, s5,
s9)>
- !LinalgOperandDefConfig
name: strides
usage: IndexAttribute
type_var: I64
attribute_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s2, s6)>
- !LinalgOperandDefConfig
name: dilations
usage: IndexAttribute
type_var: I64
attribute_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s4, s8)>
indexing_maps: !LinalgIndexingMapsConfig
static_indexing_maps:
- affine_map<(d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9]
-> (d0, d1 * s2 + d3 * s4, d2 * s6 + d4 * s8, d5)>
- affine_map<(d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9]
-> (d3, d4)>
- affine_map<(d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9]
-> (d0, d1, d2, d5)>
iterator_types:
- parallel
- parallel
- parallel
- reduction
- reduction
- parallel
assignments:
- !ScalarAssign
arg: O
value: !ScalarExpression
scalar_apply:
fn_name: max_unsigned
operands:
- !ScalarExpression
scalar_arg: O
- !ScalarExpression
symbolic_cast:
type_var: U
operands:
- !ScalarExpression
scalar_arg: I
is_unsigned_cast: true
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: pooling_nchw_max
@ -1881,6 +2070,7 @@ structured_op: !LinalgStructuredOpConfig
operands:
- !ScalarExpression
scalar_arg: I
is_unsigned_cast: false
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: pooling_nhwc_min
@ -1951,6 +2141,78 @@ structured_op: !LinalgStructuredOpConfig
operands:
- !ScalarExpression
scalar_arg: I
is_unsigned_cast: false
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: pooling_nhwc_min_unsigned
cpp_class_name: PoolingNhwcMinUnsignedOp
doc: |-
Performs unsigned min pooling.
Numeric casting is performed on the input operand, promoting it to the same
data type as the accumulator/output.
implements:
- LinalgConvolutionOpInterface
structured_op: !LinalgStructuredOpConfig
args:
- !LinalgOperandDefConfig
name: I
usage: InputOperand
type_var: T1
shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s0, s1 *
s2 + s3 * s4, s5 * s6 + s7 * s8, s9)>
- !LinalgOperandDefConfig
name: K
usage: InputOperand
type_var: T2
shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s3, s7)>
- !LinalgOperandDefConfig
name: O
usage: OutputOperand
type_var: U
shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s0, s1, s5,
s9)>
- !LinalgOperandDefConfig
name: strides
usage: IndexAttribute
type_var: I64
attribute_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s2, s6)>
- !LinalgOperandDefConfig
name: dilations
usage: IndexAttribute
type_var: I64
attribute_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s4, s8)>
indexing_maps: !LinalgIndexingMapsConfig
static_indexing_maps:
- affine_map<(d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9]
-> (d0, d1 * s2 + d3 * s4, d2 * s6 + d4 * s8, d5)>
- affine_map<(d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9]
-> (d3, d4)>
- affine_map<(d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9]
-> (d0, d1, d2, d5)>
iterator_types:
- parallel
- parallel
- parallel
- reduction
- reduction
- parallel
assignments:
- !ScalarAssign
arg: O
value: !ScalarExpression
scalar_apply:
fn_name: min_unsigned
operands:
- !ScalarExpression
scalar_arg: O
- !ScalarExpression
symbolic_cast:
type_var: U
operands:
- !ScalarExpression
scalar_arg: I
is_unsigned_cast: true
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: pooling_ndhwc_sum
@ -2027,6 +2289,7 @@ structured_op: !LinalgStructuredOpConfig
operands:
- !ScalarExpression
scalar_arg: I
is_unsigned_cast: false
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: pooling_ndhwc_max
@ -2103,6 +2366,7 @@ structured_op: !LinalgStructuredOpConfig
operands:
- !ScalarExpression
scalar_arg: I
is_unsigned_cast: false
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: pooling_ndhwc_min
@ -2179,6 +2443,7 @@ structured_op: !LinalgStructuredOpConfig
operands:
- !ScalarExpression
scalar_arg: I
is_unsigned_cast: false
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: fill_rng_2d
@ -2246,6 +2511,7 @@ structured_op: !LinalgStructuredOpConfig
operands:
- !ScalarExpression
scalar_const: '2147483647 : i64'
is_unsigned_cast: false
- !ScalarExpression
symbolic_cast:
type_var: F64
@ -2268,6 +2534,7 @@ structured_op: !LinalgStructuredOpConfig
operands:
- !ScalarExpression
scalar_index: 1
is_unsigned_cast: false
- !ScalarExpression
scalar_apply:
fn_name: add
@ -2286,6 +2553,7 @@ structured_op: !LinalgStructuredOpConfig
operands:
- !ScalarExpression
scalar_index: 0
is_unsigned_cast: false
- !ScalarExpression
scalar_arg: seed
- !ScalarExpression
@ -2294,24 +2562,29 @@ structured_op: !LinalgStructuredOpConfig
operands:
- !ScalarExpression
scalar_const: '1103515245 : i64'
is_unsigned_cast: false
- !ScalarExpression
symbolic_cast:
type_var: I32
operands:
- !ScalarExpression
scalar_const: '12345 : i64'
is_unsigned_cast: false
- !ScalarExpression
symbolic_cast:
type_var: I32
operands:
- !ScalarExpression
scalar_const: '1103515245 : i64'
is_unsigned_cast: false
- !ScalarExpression
symbolic_cast:
type_var: I32
operands:
- !ScalarExpression
scalar_const: '12345 : i64'
is_unsigned_cast: false
is_unsigned_cast: false
- !ScalarExpression
scalar_apply:
fn_name: mul
@ -2330,8 +2603,10 @@ structured_op: !LinalgStructuredOpConfig
operands:
- !ScalarExpression
scalar_const: '2.3283063999999999E-10 : f64'
is_unsigned_cast: false
- !ScalarExpression
scalar_arg: min
is_unsigned_cast: false
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: soft_plus_2d
@ -2377,6 +2652,7 @@ structured_op: !LinalgStructuredOpConfig
operands:
- !ScalarExpression
scalar_const: '1.000000e+00 : f64'
is_unsigned_cast: false
- !ScalarExpression
scalar_apply:
fn_name: exp
@ -2387,3 +2663,4 @@ structured_op: !LinalgStructuredOpConfig
operands:
- !ScalarExpression
scalar_arg: I
is_unsigned_cast: false

View File

@ -196,7 +196,7 @@ public:
// If the cast cannot be performed, a warning will be issued and the
// operand returned as-is (which will presumably yield a verification
// issue downstream).
Value cast(Type toType, Value operand) {
Value cast(Type toType, Value operand, bool isUnsignedCast) {
OpBuilder builder = getBuilder();
auto loc = operand.getLoc();
@ -204,23 +204,32 @@ public:
return operand;
if (auto toIntType = toType.dyn_cast<IntegerType>()) {
// If operand is floating point, cast directly to the int type.
if (operand.getType().isa<FloatType>())
if (operand.getType().isa<FloatType>()) {
if (isUnsignedCast)
return builder.create<FPToUIOp>(loc, toType, operand);
return builder.create<FPToSIOp>(loc, toType, operand);
}
// Cast index operands directly to the int type.
if (operand.getType().isIndex())
return builder.create<IndexCastOp>(loc, toType, operand);
if (auto fromIntType = operand.getType().dyn_cast<IntegerType>()) {
// Either sign extend or truncate.
if (toIntType.getWidth() > fromIntType.getWidth())
// Either extend or truncate.
if (toIntType.getWidth() > fromIntType.getWidth()) {
if (isUnsignedCast)
return builder.create<ZeroExtendIOp>(loc, toType, operand);
return builder.create<SignExtendIOp>(loc, toType, operand);
}
if (toIntType.getWidth() < fromIntType.getWidth())
return builder.create<TruncateIOp>(loc, toType, operand);
}
} else if (auto toFloatType = toType.dyn_cast<FloatType>()) {
// If operand is integer, cast directly to the float type.
// Note that it is unclear how to cast from BF16<->FP16.
if (operand.getType().isa<IntegerType>())
if (operand.getType().isa<IntegerType>()) {
if (isUnsignedCast)
return builder.create<UIToFPOp>(loc, toFloatType, operand);
return builder.create<SIToFPOp>(loc, toFloatType, operand);
}
if (auto fromFloatType = operand.getType().dyn_cast<FloatType>()) {
if (toFloatType.getWidth() > fromFloatType.getWidth())
return builder.create<FPExtOp>(loc, toFloatType, operand);
@ -284,6 +293,15 @@ public:
llvm_unreachable("unsupported non numeric type");
}
// Emits the elementwise maximum of `lhs` and `rhs`, treating integer
// operands as unsigned. Floating-point values carry no signedness, so the
// ordinary float maximum is emitted for them.
Value applyfn__max_unsigned(Value lhs, Value rhs) {
  OpBuilder builder = getBuilder();
  auto loc = lhs.getLoc();
  if (isFloatingPoint(lhs))
    return builder.create<MaxFOp>(loc, lhs, rhs);
  if (isInteger(lhs))
    return builder.create<MaxUIOp>(loc, lhs, rhs);
  llvm_unreachable("unsupported non numeric type");
}
Value applyfn__min(Value lhs, Value rhs) {
OpBuilder builder = getBuilder();
if (isFloatingPoint(lhs))
@ -293,6 +311,15 @@ public:
llvm_unreachable("unsupported non numeric type");
}
// Emits the elementwise minimum of `lhs` and `rhs`, treating integer
// operands as unsigned. Floating-point values carry no signedness, so the
// ordinary float minimum is emitted for them.
Value applyfn__min_unsigned(Value lhs, Value rhs) {
  OpBuilder builder = getBuilder();
  auto loc = lhs.getLoc();
  if (isFloatingPoint(lhs))
    return builder.create<MinFOp>(loc, lhs, rhs);
  if (isInteger(lhs))
    return builder.create<MinUIOp>(loc, lhs, rhs);
  llvm_unreachable("unsupported non numeric type");
}
void yieldOutputs(ValueRange values) {
assert(!values.empty() && "linalg ops must yield outputs");
if (values.empty())

View File

@ -340,6 +340,8 @@ class PrimFn:
max = PrimFnType("max")
min = PrimFnType("min")
sub = PrimFnType("sub")
max_unsigned = PrimFnType("max_unsigned")
min_unsigned = PrimFnType("min_unsigned")
class ReduceFnType:
@ -365,6 +367,8 @@ class ReduceFn:
mul = PrimFn.mul.reduce
max = PrimFn.max.reduce
min = PrimFn.min.reduce
max_unsigned = PrimFn.max_unsigned.reduce
min_unsigned = PrimFn.min_unsigned.reduce
class PrimApply(TensorExpression):
@ -438,8 +442,8 @@ class cast(TensorExpression):
self.operand = operand
def to_scalar_expression(self) -> ScalarExpression:
return ScalarSymbolicCast(self.to_type,
self.operand.to_scalar_expression()).expr()
return ScalarSymbolicCast(self.to_type, self.operand.to_scalar_expression(),
False).expr()
def visit_tensor_exprs(self, callback):
super().visit_tensor_exprs(callback)
@ -449,6 +453,17 @@ class cast(TensorExpression):
return f"cast({self.to_type}, {repr(self.operand)})"
class cast_unsigned(cast):
  """Casts the element type to an unsigned type (typically symbolic TypeVar)."""

  def to_scalar_expression(self) -> ScalarExpression:
    # Identical to `cast` except that the symbolic cast is flagged as
    # unsigned (True), which later selects zext/uitofp/fptoui lowerings.
    operand_expr = self.operand.to_scalar_expression()
    return ScalarSymbolicCast(self.to_type, operand_expr, True).expr()

  def __repr__(self):
    return f"cast_unsigned({self.to_type}, {repr(self.operand)})"
class ReduceApply(TensorExpression):
"""Application of a reduction.

View File

@ -230,10 +230,12 @@ class _BodyBuilder:
return fn(*operand_values)
elif expr.symbolic_cast:
operand_value = self.expression(expr.symbolic_cast.operand)
return self.cast(expr.symbolic_cast.to_type.name, operand_value)
return self.cast(expr.symbolic_cast.to_type.name, operand_value,
expr.symbolic_cast.is_unsigned_cast)
raise NotImplementedError(f"Unimplemented scalar body expression: {expr}")
def cast(self, type_var_name: str, operand: Value) -> Value:
def cast(self, type_var_name: str, operand: Value,
is_unsigned_cast: bool) -> Value:
try:
to_type = self.type_mapping[type_var_name]
except KeyError:
@ -242,29 +244,37 @@ class _BodyBuilder:
if operand.type == to_type:
return operand
if _is_integer_type(to_type):
return self._cast_to_integer(to_type, operand)
return self._cast_to_integer(to_type, operand, is_unsigned_cast)
elif _is_floating_point_type(to_type):
return self._cast_to_floating_point(to_type, operand)
return self._cast_to_floating_point(to_type, operand, is_unsigned_cast)
def _cast_to_integer(self, to_type: Type, operand: Value) -> Value:
def _cast_to_integer(self, to_type: Type, operand: Value,
is_unsigned_cast: bool) -> Value:
to_width = IntegerType(to_type).width
operand_type = operand.type
if _is_floating_point_type(operand_type):
if is_unsigned_cast:
return std.FPToUIOp(to_type, operand).result
return std.FPToSIOp(to_type, operand).result
if _is_index_type(operand_type):
return std.IndexCastOp(to_type, operand).result
# Assume integer.
from_width = IntegerType(operand_type).width
if to_width > from_width:
if is_unsigned_cast:
return std.ZeroExtendIOp(to_type, operand).result
return std.SignExtendIOp(to_type, operand).result
elif to_width < from_width:
return std.TruncateIOp(to_type, operand).result
raise ValueError(f"Unable to cast body expression from {operand_type} to "
f"{to_type}")
def _cast_to_floating_point(self, to_type: Type, operand: Value) -> Value:
def _cast_to_floating_point(self, to_type: Type, operand: Value,
is_unsigned_cast: bool) -> Value:
operand_type = operand.type
if _is_integer_type(operand_type):
if is_unsigned_cast:
return std.UIToFPOp(to_type, operand).result
return std.SIToFPOp(to_type, operand).result
# Assume FloatType.
to_width = _get_floating_point_width(to_type)
@ -324,6 +334,13 @@ class _BodyBuilder:
return std.MaxSIOp(lhs.type, lhs, rhs).result
raise NotImplementedError("Unsupported 'max' operand: {lhs}")
def _eval_max_unsigned(self, lhs: Value, rhs: Value) -> Value:
  """Emits an unsigned maximum of `lhs` and `rhs`.

  Floats have no signedness, so the regular float maximum is used; integer
  and index operands use the unsigned comparison op.
  """
  if _is_floating_point_type(lhs.type):
    return std.MaxFOp(lhs.type, lhs, rhs).result
  if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
    return std.MaxUIOp(lhs.type, lhs, rhs).result
  # Fixed: the original message lacked the f-prefix, so "{lhs}" was emitted
  # literally instead of interpolating the offending operand.
  raise NotImplementedError(f"Unsupported 'max_unsigned' operand: {lhs}")
def _eval_min(self, lhs: Value, rhs: Value) -> Value:
if _is_floating_point_type(lhs.type):
return std.MinFOp(lhs.type, lhs, rhs).result
@ -331,6 +348,12 @@ class _BodyBuilder:
return std.MinSIOp(lhs.type, lhs, rhs).result
raise NotImplementedError("Unsupported 'min' operand: {lhs}")
def _eval_min_unsigned(self, lhs: Value, rhs: Value) -> Value:
  """Emits an unsigned minimum of `lhs` and `rhs`.

  Floats have no signedness, so the regular float minimum is used; integer
  and index operands use the unsigned comparison op.
  """
  if _is_floating_point_type(lhs.type):
    return std.MinFOp(lhs.type, lhs, rhs).result
  if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
    return std.MinUIOp(lhs.type, lhs, rhs).result
  # Fixed: the original message lacked the f-prefix, so "{lhs}" was emitted
  # literally instead of interpolating the offending operand.
  raise NotImplementedError(f"Unsupported 'min_unsigned' operand: {lhs}")
def _infer_structured_outs(op_config: LinalgStructuredOpConfig,
in_arg_defs: Sequence[OperandDefConfig],

View File

@ -85,15 +85,17 @@ class ScalarIndex:
class ScalarSymbolicCast:
  """A type of ScalarExpression that symbolically casts an operand to a TypeVar.

  The rendered diff left both the pre- and post-commit `__init__`/`__repr__`
  lines in place; this is the post-commit definition carrying the
  `is_unsigned_cast` flag added for unsigned cast support.
  """

  def __init__(self, to_type: TypeVar, operand: "ScalarExpression",
               is_unsigned_cast: bool):
    # to_type: target type variable of the cast.
    # operand: the scalar expression being cast.
    # is_unsigned_cast: selects unsigned lowerings (zext/uitofp/fptoui).
    self.to_type = to_type
    self.operand = operand
    self.is_unsigned_cast = is_unsigned_cast

  def expr(self) -> "ScalarExpression":
    return ScalarExpression(symbolic_cast=self)

  def __repr__(self):
    return f"ScalarSymbolicCast({self.to_type}, {self.operand}, {self.is_unsigned_cast})"
class ScalarExpression(YAMLObject):
@ -144,7 +146,8 @@ class ScalarExpression(YAMLObject):
return dict(
symbolic_cast=dict(
type_var=self.symbolic_cast.to_type.name,
operands=[self.symbolic_cast.operand]))
operands=[self.symbolic_cast.operand],
is_unsigned_cast=self.symbolic_cast.is_unsigned_cast))
else:
raise ValueError(f"Unexpected ScalarExpression type: {self}")

View File

@ -20,6 +20,20 @@ def matmul(
implements(ContractionOpInterface)
C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.k, D.n])
@linalg_structured_op
def matmul_unsigned(
    A=TensorDef(T1, S.M, S.K),
    B=TensorDef(T2, S.K, S.N),
    C=TensorDef(U, S.M, S.N, output=True)):
  """Performs an unsigned matrix multiplication of two 2D inputs.

  Numeric casting is performed on the operands to the inner multiply, promoting
  them to the same data type as the accumulator/output.
  """
  domain(D.m, D.n, D.k)
  implements(ContractionOpInterface)
  # cast_unsigned marks the promotions as unsigned (zext/uitofp lowerings).
  C[D.m, D.n] += cast_unsigned(U, A[D.m, D.k]) * cast_unsigned(U, B[D.k, D.n])
@linalg_structured_op
def quantized_matmul(
A=TensorDef(T1, S.M, S.K),
@ -411,6 +425,24 @@ def pooling_nhwc_max(
cast(U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW,
D.c]))
@linalg_structured_op
def pooling_nhwc_max_unsigned(
    I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW,
                S.C),
    K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]),
    O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True),
    strides=AttributeDef(S.SH, S.SW),
    dilations=AttributeDef(S.DH, S.DW)):
  """Performs unsigned max pooling.

  Numeric casting is performed on the input operand, promoting it to the same
  data type as the accumulator/output.
  """
  implements(ConvolutionOpInterface)
  domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c)
  # Reduce with the unsigned max and promote the input with an unsigned cast.
  O[D.n, D.oh, D.ow, D.c] = ReduceFn.max_unsigned(D.kh, D.kw)(
      cast_unsigned(
          U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c]))
@linalg_structured_op
def pooling_nchw_max(
I=TensorDef(T1, S.N, S.C, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW),
@ -447,6 +479,23 @@ def pooling_nhwc_min(
cast(U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW,
D.c]))
@linalg_structured_op
def pooling_nhwc_min_unsigned(
    I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW,
                S.C),
    K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]),
    O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True),
    strides=AttributeDef(S.SH, S.SW),
    dilations=AttributeDef(S.DH, S.DW)):
  """Performs unsigned min pooling.

  Numeric casting is performed on the input operand, promoting it to the same
  data type as the accumulator/output.
  """
  implements(ConvolutionOpInterface)
  domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c)
  # Reduce with the unsigned min and promote the input with an unsigned cast.
  O[D.n, D.oh, D.ow, D.c] = ReduceFn.min_unsigned(D.kh, D.kw)(
      cast_unsigned(
          U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c]))
@linalg_structured_op
def pooling_ndhwc_sum(

View File

@ -1,35 +1,108 @@
// RUN: mlir-opt %s -split-input-file -linalg-generalize-named-ops | FileCheck %s
func @generalize_matmul_tensor_f32(%A : tensor<16x8xf32>, %B: tensor<8x32xf32>, %C: tensor<16x32xf32>) -> tensor<16x32xf32> {
%0 = linalg.matmul ins(%A, %B: tensor<16x8xf32>, tensor<8x32xf32>)
// Verifies that mixing different argument types is legal.
func @generalize_matmul_tensor_f16f64f32(%A : tensor<16x8xf16>, %B: tensor<8x32xf64>, %C: tensor<16x32xf32>) -> tensor<16x32xf32> {
%0 = linalg.matmul ins(%A, %B: tensor<16x8xf16>, tensor<8x32xf64>)
outs(%C: tensor<16x32xf32>) -> tensor<16x32xf32>
return %0: tensor<16x32xf32>
}
// CHECK-LABEL: @generalize_matmul_tensor_f32
// CHECK: ^{{.*}}(%[[A_ARG:.+]]: f32, %[[B_ARG:.+]]: f32, %[[C_ARG:.+]]: f32)
// CHECK-NEXT: %[[MUL:.+]] = mulf %[[A_ARG]], %[[B_ARG]] : f32
// CHECK-LABEL: @generalize_matmul_tensor_f16f64f32
// CHECK: ^{{.*}}(%[[A_ARG:.+]]: f16, %[[B_ARG:.+]]: f64, %[[C_ARG:.+]]: f32)
// Verify floating point extension and truncation.
// CHECK-NEXT: %[[A_CAST:.+]] = fpext %[[A_ARG]] : f16 to f32
// CHECK-NEXT: %[[B_CAST:.+]] = fptrunc %[[B_ARG]] : f64 to f32
// CHECK-NEXT: %[[MUL:.+]] = mulf %[[A_CAST]], %[[B_CAST]] : f32
// CHECK-NEXT: %[[ADD:.+]] = addf %[[C_ARG]], %[[MUL]] : f32
// CHECK-NEXT: linalg.yield %[[ADD]] : f32
// CHECK-NEXT: -> tensor<16x32xf32>
// -----
func @generalize_matmul_tensor_i32(%A : tensor<16x8xi32>, %B: tensor<8x32xi32>, %C: tensor<16x32xi32>) -> tensor<16x32xi32> {
%0 = linalg.matmul ins(%A, %B: tensor<16x8xi32>, tensor<8x32xi32>)
// Verifies that mixing different argument types is legal.
func @generalize_matmul_tensor_i16i64i32(%A : tensor<16x8xi16>, %B: tensor<8x32xi64>, %C: tensor<16x32xi32>) -> tensor<16x32xi32> {
%0 = linalg.matmul ins(%A, %B: tensor<16x8xi16>, tensor<8x32xi64>)
outs(%C: tensor<16x32xi32>) -> tensor<16x32xi32>
return %0: tensor<16x32xi32>
}
// CHECK-LABEL: @generalize_matmul_tensor_i32
// CHECK: ^{{.*}}(%[[A_ARG:.+]]: i32, %[[B_ARG:.+]]: i32, %[[C_ARG:.+]]: i32)
// CHECK-NEXT: %[[MUL:.+]] = muli %[[A_ARG]], %[[B_ARG]] : i32
// CHECK-LABEL: @generalize_matmul_tensor_i16i64i32
// CHECK: ^{{.*}}(%[[A_ARG:.+]]: i16, %[[B_ARG:.+]]: i64, %[[C_ARG:.+]]: i32)
// Verify signed integer extension and truncation.
// CHECK-NEXT: %[[A_CAST:.+]] = sexti %[[A_ARG]] : i16 to i32
// CHECK-NEXT: %[[B_CAST:.+]] = trunci %[[B_ARG]] : i64 to i32
// CHECK-NEXT: %[[MUL:.+]] = muli %[[A_CAST]], %[[B_CAST]] : i32
// CHECK-NEXT: %[[ADD:.+]] = addi %[[C_ARG]], %[[MUL]] : i32
// CHECK-NEXT: linalg.yield %[[ADD]] : i32
// CHECK-NEXT: -> tensor<16x32xi32>
// -----
func @generalize_matmul_tensor_i16i64f32(%A : tensor<16x8xi16>, %B: tensor<8x32xi64>, %C: tensor<16x32xf32>) -> tensor<16x32xf32> {
%0 = linalg.matmul ins(%A, %B: tensor<16x8xi16>, tensor<8x32xi64>)
outs(%C: tensor<16x32xf32>) -> tensor<16x32xf32>
return %0: tensor<16x32xf32>
}
// CHECK-LABEL: @generalize_matmul_tensor_i16i64f32
// Verify signed integer to floating point cast.
// CHECK: = sitofp
// CHECK: = sitofp
// -----
func @generalize_matmul_tensor_f16f64i32(%A : tensor<16x8xf16>, %B: tensor<8x32xf64>, %C: tensor<16x32xi32>) -> tensor<16x32xi32> {
%0 = linalg.matmul ins(%A, %B: tensor<16x8xf16>, tensor<8x32xf64>)
outs(%C: tensor<16x32xi32>) -> tensor<16x32xi32>
return %0: tensor<16x32xi32>
}
// CHECK-LABEL: @generalize_matmul_tensor_f16f64i32
// Verify floating point to signed integer cast.
// CHECK: = fptosi
// CHECK: = fptosi
// -----
// The unsigned matmul variant must use zero-extension (zexti) rather than
// sign-extension when widening the i16 operand; narrowing i64 still truncates.
func @generalize_matmul_unsigned_tensor_i16i64i32(%A : tensor<16x8xi16>, %B: tensor<8x32xi64>, %C: tensor<16x32xi32>) -> tensor<16x32xi32> {
%0 = linalg.matmul_unsigned ins(%A, %B: tensor<16x8xi16>, tensor<8x32xi64>)
outs(%C: tensor<16x32xi32>) -> tensor<16x32xi32>
return %0: tensor<16x32xi32>
}
// CHECK-LABEL: @generalize_matmul_unsigned_tensor_i16i64i32
// Verify unsigned integer extension and truncation.
// CHECK: = zexti
// CHECK: = trunci
// -----
// Unsigned matmul with integer inputs and a floating point accumulator: both
// operand casts must lower to uitofp instead of sitofp.
func @generalize_matmul_unsigned_tensor_i16i64f32(%A : tensor<16x8xi16>, %B: tensor<8x32xi64>, %C: tensor<16x32xf32>) -> tensor<16x32xf32> {
%0 = linalg.matmul_unsigned ins(%A, %B: tensor<16x8xi16>, tensor<8x32xi64>)
outs(%C: tensor<16x32xf32>) -> tensor<16x32xf32>
return %0: tensor<16x32xf32>
}
// CHECK-LABEL: @generalize_matmul_unsigned_tensor_i16i64f32
// Verify unsigned integer to floating point cast.
// CHECK: = uitofp
// CHECK: = uitofp
// -----
// Unsigned matmul with floating point inputs and an integer accumulator: both
// operand casts must lower to fptoui instead of fptosi.
func @generalize_matmul_unsigned_tensor_f16f64i32(%A : tensor<16x8xf16>, %B: tensor<8x32xf64>, %C: tensor<16x32xi32>) -> tensor<16x32xi32> {
%0 = linalg.matmul_unsigned ins(%A, %B: tensor<16x8xf16>, tensor<8x32xf64>)
outs(%C: tensor<16x32xi32>) -> tensor<16x32xi32>
return %0: tensor<16x32xi32>
}
// CHECK-LABEL: @generalize_matmul_unsigned_tensor_f16f64i32
// Verify floating point to unsigned integer cast.
// CHECK: = fptoui
// CHECK: = fptoui
// -----
func @generalize_pooling_nhwc_max_f32(%input : tensor<1x4x16x1xf32>, %shape: tensor<2x2xf32>, %output: tensor<1x2x4x1xf32>) -> tensor<1x2x4x1xf32> {
%0 = linalg.pooling_nhwc_max {dilations = dense<[1, 2]> : tensor<2xi64>, strides = dense<[2, 4]> : tensor<2xi64>}
ins(%input, %shape : tensor<1x4x16x1xf32>, tensor<2x2xf32>) outs(%output : tensor<1x2x4x1xf32>) -> tensor<1x2x4x1xf32>
@ -51,10 +124,20 @@ func @generalize_pooling_nhwc_max_i32(%input : tensor<1x4x16x1xi32>, %shape: ten
}
// CHECK-LABEL: @generalize_pooling_nhwc_max_i32
// CHECK: ^{{.*}}(%[[IN_ARG:.+]]: i32, %[[SHAPE_ARG:.+]]: i32, %[[OUT_ARG:.+]]: i32)
// CHECK-NEXT: %[[MAX:.+]] = maxsi %[[OUT_ARG]], %[[IN_ARG]] : i32
// CHECK-NEXT: linalg.yield %[[MAX]] : i32
// CHECK-NEXT: -> tensor<1x2x4x1xi32>
// Verify signed integer maximum.
// CHECK: = maxsi
// -----
// The unsigned max-pooling variant must reduce with maxui rather than maxsi.
func @generalize_pooling_nhwc_max_unsigned_i32(%input : tensor<1x4x16x1xi32>, %shape: tensor<2x2xi32>, %output: tensor<1x2x4x1xi32>) -> tensor<1x2x4x1xi32> {
%0 = linalg.pooling_nhwc_max_unsigned {dilations = dense<[1, 2]> : tensor<2xi64>, strides = dense<[2, 4]> : tensor<2xi64>}
ins(%input, %shape : tensor<1x4x16x1xi32>, tensor<2x2xi32>) outs(%output : tensor<1x2x4x1xi32>) -> tensor<1x2x4x1xi32>
return %0: tensor<1x2x4x1xi32>
}
// CHECK-LABEL: @generalize_pooling_nhwc_max_unsigned_i32
// Verify unsigned integer maximum.
// CHECK: = maxui
// -----
@ -79,10 +162,20 @@ func @generalize_pooling_nhwc_min_i32(%input : tensor<1x4x16x1xi32>, %shape: ten
}
// CHECK-LABEL: @generalize_pooling_nhwc_min_i32
// CHECK: ^{{.*}}(%[[IN_ARG:.+]]: i32, %[[SHAPE_ARG:.+]]: i32, %[[OUT_ARG:.+]]: i32)
// CHECK-NEXT: %[[MIN:.+]] = minsi %[[OUT_ARG]], %[[IN_ARG]] : i32
// CHECK-NEXT: linalg.yield %[[MIN]] : i32
// CHECK-NEXT: -> tensor<1x2x4x1xi32>
// Verify signed integer minimum.
// CHECK: = minsi
// -----
// The unsigned min-pooling variant must reduce with minui rather than minsi.
func @generalize_pooling_nhwc_min_unsigned_i32(%input : tensor<1x4x16x1xi32>, %shape: tensor<2x2xi32>, %output: tensor<1x2x4x1xi32>) -> tensor<1x2x4x1xi32> {
%0 = linalg.pooling_nhwc_min_unsigned {dilations = dense<[1, 2]> : tensor<2xi64>, strides = dense<[2, 4]> : tensor<2xi64>}
ins(%input, %shape : tensor<1x4x16x1xi32>, tensor<2x2xi32>) outs(%output : tensor<1x2x4x1xi32>) -> tensor<1x2x4x1xi32>
return %0: tensor<1x2x4x1xi32>
}
// CHECK-LABEL: @generalize_pooling_nhwc_min_unsigned_i32
// Verify unsigned integer minimum.
// CHECK: = minui
// -----
@ -169,122 +262,3 @@ func @generalize_soft_plus_2d_f32(%input: tensor<16x32xf32>, %output: tensor<16x
// CHECK-NEXT: %[[LOG:.+]] = math.log %[[SUM]] : f32
// CHECK-NEXT: linalg.yield %[[LOG]] : f32
// CHECK-NEXT: -> tensor<16x32xf32>
// -----
// Verifies floating point to signed integer cast (fptosi) and that the
// multiply/accumulate are emitted in the integer accumulator type.
func @generalize_matmul_tensor_f32_f32_i16(%A : tensor<16x8xf32>, %B: tensor<8x32xf32>, %C: tensor<16x32xi16>) -> tensor<16x32xi16> {
%0 = linalg.matmul ins(%A, %B: tensor<16x8xf32>, tensor<8x32xf32>)
outs(%C: tensor<16x32xi16>) -> tensor<16x32xi16>
return %0: tensor<16x32xi16>
}
// CHECK-LABEL: @generalize_matmul_tensor_f32_f32_i16
// CHECK: ^{{.*}}(%[[A_ARG:.+]]: f32, %[[B_ARG:.+]]: f32, %[[C_ARG:.+]]: i16)
// CHECK-NEXT: %[[A_CAST:.+]] = fptosi %[[A_ARG]] : f32 to i16
// CHECK-NEXT: %[[B_CAST:.+]] = fptosi %[[B_ARG]] : f32 to i16
// CHECK-NEXT: %[[MUL:.+]] = muli %[[A_CAST]], %[[B_CAST]] : i16
// CHECK-NEXT: %[[ADD:.+]] = addi %[[C_ARG]], %[[MUL]] : i16
// CHECK-NEXT: linalg.yield %[[ADD]] : i16
// CHECK-NEXT: -> tensor<16x32xi16>
// -----
// Verifies sign extension cast (sexti) when widening i8 operands to the i32
// accumulator type.
func @generalize_matmul_tensor_i8_i8_i32(%A : tensor<16x8xi8>, %B: tensor<8x32xi8>, %C: tensor<16x32xi32>) -> tensor<16x32xi32> {
%0 = linalg.matmul ins(%A, %B: tensor<16x8xi8>, tensor<8x32xi8>)
outs(%C: tensor<16x32xi32>) -> tensor<16x32xi32>
return %0: tensor<16x32xi32>
}
// CHECK-LABEL: @generalize_matmul_tensor_i8_i8_i32
// CHECK: ^{{.*}}(%[[A_ARG:.+]]: i8, %[[B_ARG:.+]]: i8, %[[C_ARG:.+]]: i32)
// CHECK-NEXT: %[[A_CAST:.+]] = sexti %[[A_ARG]] : i8 to i32
// CHECK-NEXT: %[[B_CAST:.+]] = sexti %[[B_ARG]] : i8 to i32
// CHECK-NEXT: %[[MUL:.+]] = muli %[[A_CAST]], %[[B_CAST]] : i32
// CHECK-NEXT: %[[ADD:.+]] = addi %[[C_ARG]], %[[MUL]] : i32
// CHECK-NEXT: linalg.yield %[[ADD]] : i32
// CHECK-NEXT: -> tensor<16x32xi32>
// -----
// Verifies that mixing different argument types (i8 and i16) is legal; both
// operands are sign-extended to the i32 accumulator type.
func @generalize_matmul_tensor_i8_i16_i32(%A : tensor<16x8xi8>, %B: tensor<8x32xi16>, %C: tensor<16x32xi32>) -> tensor<16x32xi32> {
%0 = linalg.matmul ins(%A, %B: tensor<16x8xi8>, tensor<8x32xi16>)
outs(%C: tensor<16x32xi32>) -> tensor<16x32xi32>
return %0: tensor<16x32xi32>
}
// CHECK-LABEL: @generalize_matmul_tensor_i8_i16_i32
// CHECK: ^{{.*}}(%[[A_ARG:.+]]: i8, %[[B_ARG:.+]]: i16, %[[C_ARG:.+]]: i32)
// CHECK-NEXT: %[[A_CAST:.+]] = sexti %[[A_ARG]] : i8 to i32
// CHECK-NEXT: %[[B_CAST:.+]] = sexti %[[B_ARG]] : i16 to i32
// CHECK-NEXT: %[[MUL:.+]] = muli %[[A_CAST]], %[[B_CAST]] : i32
// CHECK-NEXT: %[[ADD:.+]] = addi %[[C_ARG]], %[[MUL]] : i32
// CHECK-NEXT: linalg.yield %[[ADD]] : i32
// CHECK-NEXT: -> tensor<16x32xi32>
// -----
// Somewhat non-sensical but checks integer truncation cast (trunci) when the
// accumulator type is narrower than the operand types.
func @generalize_matmul_tensor_i32_i32_i16(%A : tensor<16x8xi32>, %B: tensor<8x32xi32>, %C: tensor<16x32xi16>) -> tensor<16x32xi16> {
%0 = linalg.matmul ins(%A, %B: tensor<16x8xi32>, tensor<8x32xi32>)
outs(%C: tensor<16x32xi16>) -> tensor<16x32xi16>
return %0: tensor<16x32xi16>
}
// CHECK-LABEL: @generalize_matmul_tensor_i32_i32_i16
// CHECK: ^{{.*}}(%[[A_ARG:.+]]: i32, %[[B_ARG:.+]]: i32, %[[C_ARG:.+]]: i16)
// CHECK-NEXT: %[[A_CAST:.+]] = trunci %[[A_ARG]] : i32 to i16
// CHECK-NEXT: %[[B_CAST:.+]] = trunci %[[B_ARG]] : i32 to i16
// CHECK-NEXT: %[[MUL:.+]] = muli %[[A_CAST]], %[[B_CAST]] : i16
// CHECK-NEXT: %[[ADD:.+]] = addi %[[C_ARG]], %[[MUL]] : i16
// CHECK-NEXT: linalg.yield %[[ADD]] : i16
// CHECK-NEXT: -> tensor<16x32xi16>
// -----
// Verifies signed integer to floating point cast (sitofp) and that the
// multiply/accumulate switch to the floating point variants (mulf/addf).
func @generalize_matmul_tensor_i8_i8_f32(%A : tensor<16x8xi8>, %B: tensor<8x32xi8>, %C: tensor<16x32xf32>) -> tensor<16x32xf32> {
%0 = linalg.matmul ins(%A, %B: tensor<16x8xi8>, tensor<8x32xi8>)
outs(%C: tensor<16x32xf32>) -> tensor<16x32xf32>
return %0: tensor<16x32xf32>
}
// CHECK-LABEL: @generalize_matmul_tensor_i8_i8_f32
// CHECK: ^{{.*}}(%[[A_ARG:.+]]: i8, %[[B_ARG:.+]]: i8, %[[C_ARG:.+]]: f32)
// CHECK-NEXT: %[[A_CAST:.+]] = sitofp %[[A_ARG]] : i8 to f32
// CHECK-NEXT: %[[B_CAST:.+]] = sitofp %[[B_ARG]] : i8 to f32
// CHECK-NEXT: %[[MUL:.+]] = mulf %[[A_CAST]], %[[B_CAST]] : f32
// CHECK-NEXT: %[[ADD:.+]] = addf %[[C_ARG]], %[[MUL]] : f32
// CHECK-NEXT: linalg.yield %[[ADD]] : f32
// CHECK-NEXT: -> tensor<16x32xf32>
// -----
// Verifies floating point extension cast (fpext) when widening f16 operands to
// the f32 accumulator type.
func @generalize_matmul_tensor_f16_f16_f32(%A : tensor<16x8xf16>, %B: tensor<8x32xf16>, %C: tensor<16x32xf32>) -> tensor<16x32xf32> {
%0 = linalg.matmul ins(%A, %B: tensor<16x8xf16>, tensor<8x32xf16>)
outs(%C: tensor<16x32xf32>) -> tensor<16x32xf32>
return %0: tensor<16x32xf32>
}
// CHECK-LABEL: @generalize_matmul_tensor_f16_f16_f32
// CHECK: ^{{.*}}(%[[A_ARG:.+]]: f16, %[[B_ARG:.+]]: f16, %[[C_ARG:.+]]: f32)
// CHECK-NEXT: %[[A_CAST:.+]] = fpext %[[A_ARG]] : f16 to f32
// CHECK-NEXT: %[[B_CAST:.+]] = fpext %[[B_ARG]] : f16 to f32
// CHECK-NEXT: %[[MUL:.+]] = mulf %[[A_CAST]], %[[B_CAST]] : f32
// CHECK-NEXT: %[[ADD:.+]] = addf %[[C_ARG]], %[[MUL]] : f32
// CHECK-NEXT: linalg.yield %[[ADD]] : f32
// CHECK-NEXT: -> tensor<16x32xf32>
// -----
// Verifies floating point truncation cast (fptrunc) when narrowing f64
// operands to the f32 accumulator type.
func @generalize_matmul_tensor_f64_f64_f32(%A : tensor<16x8xf64>, %B: tensor<8x32xf64>, %C: tensor<16x32xf32>) -> tensor<16x32xf32> {
%0 = linalg.matmul ins(%A, %B: tensor<16x8xf64>, tensor<8x32xf64>)
outs(%C: tensor<16x32xf32>) -> tensor<16x32xf32>
return %0: tensor<16x32xf32>
}
// CHECK-LABEL: @generalize_matmul_tensor_f64_f64_f32
// CHECK: ^{{.*}}(%[[A_ARG:.+]]: f64, %[[B_ARG:.+]]: f64, %[[C_ARG:.+]]: f32)
// CHECK-NEXT: %[[A_CAST:.+]] = fptrunc %[[A_ARG]] : f64 to f32
// CHECK-NEXT: %[[B_CAST:.+]] = fptrunc %[[B_ARG]] : f64 to f32
// CHECK-NEXT: %[[MUL:.+]] = mulf %[[A_CAST]], %[[B_CAST]] : f32
// CHECK-NEXT: %[[ADD:.+]] = addf %[[C_ARG]], %[[MUL]] : f32
// CHECK-NEXT: linalg.yield %[[ADD]] : f32
// CHECK-NEXT: -> tensor<16x32xf32>

View File

@ -43,12 +43,14 @@ structured_op: !LinalgStructuredOpConfig
operands:
- !ScalarExpression
scalar_const: '42 : i64'
is_unsigned_cast: false
- !ScalarExpression
symbolic_cast:
type_var: T
operands:
- !ScalarExpression
scalar_index: 1
is_unsigned_cast: true
# ODS-LABEL: def Test1Op : LinalgStructuredBase_Op<"test1"
@ -84,9 +86,9 @@ structured_op: !LinalgStructuredOpConfig
# IMPL-LABEL: void Test1Op::regionBuilder(
# IMPL: ImplicitLocOpBuilder &b, Block &block)
# IMPL: Value [[VAL0:[a-z0-9]+]] = helper.constant("42 : i64");
# IMPL-DAG: Value [[VAL1:[a-z0-9]+]] = helper.cast(block.getArgument(0).getType(), [[VAL0]]);
# IMPL-DAG: Value [[VAL1:[a-z0-9]+]] = helper.cast(block.getArgument(0).getType(), [[VAL0]], false);
# IMPL-DAG: Value [[VAL2:[a-z0-9]+]] = helper.index(1);
# IMPL-DAG: Value [[VAL3:[a-z0-9]+]] = helper.cast(block.getArgument(0).getType(), [[VAL2]]);
# IMPL-DAG: Value [[VAL3:[a-z0-9]+]] = helper.cast(block.getArgument(0).getType(), [[VAL2]], true);
# IMPL-DAG: Value [[VAL4:[a-z0-9]+]] = helper.applyfn__add([[VAL1]], [[VAL3]]);

View File

@ -29,6 +29,15 @@ def matmul_poly(
C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.k, D.n])
@linalg_structured_op
def matmul_unsigned_poly(
A=TensorDef(T1, S.M, S.K),
B=TensorDef(T2, S.K, S.N),
C=TensorDef(U, S.M, S.N, output=True)):
domain(D.m, D.n, D.k)
C[D.m, D.n] += cast_unsigned(U, A[D.m, D.k]) * cast_unsigned(U, B[D.k, D.n])
@linalg_structured_op
def conv_poly(
I=TensorDef(T1, S.N, S.IH, S.IW, S.C),
@ -54,6 +63,17 @@ def pooling_max_poly(
cast(U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW,
D.c]))
@linalg_structured_op
def pooling_max_unsigned_poly(
I=TensorDef(T1, S.N, S.H, S.W, S.C),
K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]),
O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True),
strides=AttributeDef(S.SH, S.SW),
dilations=AttributeDef(S.DH, S.DW)):
domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c)
O[D.n, D.oh, D.ow, D.c] = ReduceFn.max_unsigned(D.kh, D.kw)(
cast_unsigned(
U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c]))
@linalg_structured_op
def pooling_min_poly(
@ -67,6 +87,17 @@ def pooling_min_poly(
cast(U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW,
D.c]))
@linalg_structured_op
def pooling_min_unsigned_poly(
I=TensorDef(T1, S.N, S.H, S.W, S.C),
K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]),
O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True),
strides=AttributeDef(S.SH, S.SW),
dilations=AttributeDef(S.DH, S.DW)):
domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c)
O[D.n, D.oh, D.ow, D.c] = ReduceFn.min_unsigned(D.kh, D.kw)(
cast_unsigned(
U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c]))
@linalg_structured_op
def fill_rng_poly(
@ -147,6 +178,15 @@ with Context() as ctx, Location.unknown():
def test_i8i8i32_matmul(lhs, rhs, init_result):
return matmul_poly(lhs, rhs, outs=[init_result])
# CHECK-LABEL: @test_i8i8i32_matmul_unsigned
# CHECK: = zexti
# CHECK: = zexti
@builtin.FuncOp.from_py_func(
RankedTensorType.get((4, 16), i8), RankedTensorType.get((16, 8), i8),
RankedTensorType.get((4, 8), i32))
def test_i8i8i32_matmul_unsigned(lhs, rhs, init_result):
return matmul_unsigned_poly(lhs, rhs, outs=[init_result])
# CHECK-LABEL: @test_i8i16i32_matmul
# CHECK: ^{{.*}}(%[[A_ARG:.+]]: i8, %[[B_ARG:.+]]: i16, %[[C_ARG:.+]]: i32)
# CHECK-NEXT: %[[A_CAST:.+]] = sexti %[[A_ARG]] : i8 to i32
@ -189,6 +229,15 @@ with Context() as ctx, Location.unknown():
def test_i8i8f32_matmul(lhs, rhs, init_result):
return matmul_poly(lhs, rhs, outs=[init_result])
# CHECK-LABEL: @test_i8i8f32_matmul_unsigned
# CHECK: = uitofp
# CHECK: = uitofp
@builtin.FuncOp.from_py_func(
RankedTensorType.get((4, 16), i8), RankedTensorType.get((16, 8), i8),
RankedTensorType.get((4, 8), f32))
def test_i8i8f32_matmul_unsigned(lhs, rhs, init_result):
return matmul_unsigned_poly(lhs, rhs, outs=[init_result])
# CHECK-LABEL: @test_f16f16f32_matmul
# CHECK: ^{{.*}}(%[[A_ARG:.+]]: f16, %[[B_ARG:.+]]: f16, %[[C_ARG:.+]]: f32)
# CHECK-NEXT: %[[A_CAST:.+]] = fpext %[[A_ARG]] : f16 to f32
@ -252,6 +301,16 @@ with Context() as ctx, Location.unknown():
return pooling_max_poly(
input, shape, outs=[init_result], strides=[2, 4], dilations=[1, 2])
# CHECK-LABEL: @test_f32i32_max_unsigned_pooling
# CHECK: = fptoui
# CHECK: = maxui
@builtin.FuncOp.from_py_func(
RankedTensorType.get((4, 16), f32), RankedTensorType.get((2, 2), f32),
RankedTensorType.get((2, 4), i32))
def test_f32i32_max_unsigned_pooling(input, shape, init_result):
return pooling_max_unsigned_poly(
input, shape, outs=[init_result], strides=[2, 4], dilations=[1, 2])
# CHECK-LABEL: @test_f32f32_max_pooling
# CHECK: linalg.generic
# CHECK-SAME: indexing_maps = [#[[$CONV_MAP_I]], #[[$POOL_MAP_K]], #[[$CONV_MAP_O]]]
@ -268,6 +327,7 @@ with Context() as ctx, Location.unknown():
input, shape, outs=[init_result], strides=[2, 4], dilations=[1, 2])
# CHECK-LABEL: @test_f32i32_min_pooling
# CHECK: = fptosi
# CHECK: = minsi
@builtin.FuncOp.from_py_func(
RankedTensorType.get((4, 16), f32), RankedTensorType.get((2, 2), f32),
@ -276,6 +336,16 @@ with Context() as ctx, Location.unknown():
return pooling_min_poly(
input, shape, outs=[init_result], strides=[2, 4], dilations=[1, 2])
# CHECK-LABEL: @test_f32i32_min_unsigned_pooling
# CHECK: = fptoui
# CHECK: = minui
@builtin.FuncOp.from_py_func(
RankedTensorType.get((4, 16), f32), RankedTensorType.get((2, 2), f32),
RankedTensorType.get((2, 4), i32))
def test_f32i32_min_unsigned_pooling(input, shape, init_result):
return pooling_min_unsigned_poly(
input, shape, outs=[init_result], strides=[2, 4], dilations=[1, 2])
# CHECK-LABEL: @test_f32f32_min_pooling
# CHECK: = minf
@builtin.FuncOp.from_py_func(

View File

@ -95,6 +95,7 @@ struct ScalarSymbolicCast {
// NOTE: This must be of arity 1, but to break the self-referential cycle,
// we use a heap allocated vector.
std::vector<ScalarExpression> operands;
bool isUnsignedCast;
};
struct ScalarExpression {
@ -278,6 +279,7 @@ struct MappingTraits<ScalarSymbolicCast> {
static void mapping(IO &io, ScalarSymbolicCast &info) {
io.mapRequired("type_var", info.typeVar);
io.mapRequired("operands", info.operands);
io.mapRequired("is_unsigned_cast", info.isUnsignedCast);
}
};
@ -986,9 +988,10 @@ void {0}::regionBuilder(ImplicitLocOpBuilder &b, Block &block) {{
return None;
}
std::string cppIdent = llvm::formatv("value{0}", ++localCounter);
stmts.push_back(llvm::formatv("Value {0} = helper.cast({1}, {2});",
cppIdent, typeCppValue.getValue(),
*operandCppValue));
stmts.push_back(
llvm::formatv("Value {0} = helper.cast({1}, {2}, {3});", cppIdent,
typeCppValue.getValue(), *operandCppValue,
expression.symbolicCast->isUnsignedCast));
return cppIdent;
}
emitError(genContext.getLoc()) << "unknown ScalarExpression type";