[mlir][spirv] Reject coop matrix operands on unsupported arithmetic ops (#147230)

Cooperative matrix operands are only supported for `add/sub/mul/div`
binary arithmetic ops, but currently all binary arithmetic ops accept
cooperative matrix operands, including `mod/rem`. This change fixes this
behaviour.
This commit is contained in:
Darren Wihandi 2025-07-08 10:44:37 -04:00 committed by GitHub
parent 517cda12e5
commit 4a68562e9a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 81 additions and 21 deletions

View File

@ -24,8 +24,25 @@ class SPIRV_ArithmeticBinaryOp<string mnemonic, Type type,
SPIRV_BinaryOp<mnemonic, type, type,
!listconcat(traits,
[Pure, AllTypesMatch<["operand1", "operand2", "result"]>])> {
// In addition to normal types arithmetic instructions can support cooperative
// matrix.
let arguments = (ins
SPIRV_ScalarOrVectorOf<type>:$operand1,
SPIRV_ScalarOrVectorOf<type>:$operand2
);
let results = (outs
SPIRV_ScalarOrVectorOf<type>:$result
);
let assemblyFormat = "operands attr-dict `:` type($result)";
}
class SPIRV_ArithmeticBinaryOpWithCoopMatrix<string mnemonic, Type type,
list<Trait> traits = []> :
// Operands type same as result type.
SPIRV_BinaryOp<mnemonic, type, type,
!listconcat(traits,
[Pure, AllTypesMatch<["operand1", "operand2", "result"]>])> {
// In addition to normal types these arithmetic instructions can support
// cooperative matrix.
let arguments = (ins
SPIRV_ScalarOrVectorOrCoopMatrixOf<type>:$operand1,
SPIRV_ScalarOrVectorOrCoopMatrixOf<type>:$operand2
@ -82,7 +99,7 @@ class SPIRV_ArithmeticExtendedBinaryOp<string mnemonic,
// -----
def SPIRV_FAddOp : SPIRV_ArithmeticBinaryOp<"FAdd", SPIRV_Float, [Commutative]> {
def SPIRV_FAddOp : SPIRV_ArithmeticBinaryOpWithCoopMatrix<"FAdd", SPIRV_Float, [Commutative]> {
let summary = "Floating-point addition of Operand 1 and Operand 2.";
let description = [{
@ -104,7 +121,7 @@ def SPIRV_FAddOp : SPIRV_ArithmeticBinaryOp<"FAdd", SPIRV_Float, [Commutative]>
// -----
def SPIRV_FDivOp : SPIRV_ArithmeticBinaryOp<"FDiv", SPIRV_Float, []> {
def SPIRV_FDivOp : SPIRV_ArithmeticBinaryOpWithCoopMatrix<"FDiv", SPIRV_Float, []> {
let summary = "Floating-point division of Operand 1 divided by Operand 2.";
let description = [{
@ -154,7 +171,7 @@ def SPIRV_FModOp : SPIRV_ArithmeticBinaryOp<"FMod", SPIRV_Float, []> {
// -----
def SPIRV_FMulOp : SPIRV_ArithmeticBinaryOp<"FMul", SPIRV_Float, [Commutative]> {
def SPIRV_FMulOp : SPIRV_ArithmeticBinaryOpWithCoopMatrix<"FMul", SPIRV_Float, [Commutative]> {
let summary = "Floating-point multiplication of Operand 1 and Operand 2.";
let description = [{
@ -229,7 +246,7 @@ def SPIRV_FRemOp : SPIRV_ArithmeticBinaryOp<"FRem", SPIRV_Float, []> {
// -----
def SPIRV_FSubOp : SPIRV_ArithmeticBinaryOp<"FSub", SPIRV_Float, []> {
def SPIRV_FSubOp : SPIRV_ArithmeticBinaryOpWithCoopMatrix<"FSub", SPIRV_Float, []> {
let summary = "Floating-point subtraction of Operand 2 from Operand 1.";
let description = [{
@ -251,9 +268,9 @@ def SPIRV_FSubOp : SPIRV_ArithmeticBinaryOp<"FSub", SPIRV_Float, []> {
// -----
def SPIRV_IAddOp : SPIRV_ArithmeticBinaryOp<"IAdd",
SPIRV_Integer,
[Commutative, UsableInSpecConstantOp]> {
def SPIRV_IAddOp : SPIRV_ArithmeticBinaryOpWithCoopMatrix<"IAdd",
SPIRV_Integer,
[Commutative, UsableInSpecConstantOp]> {
let summary = "Integer addition of Operand 1 and Operand 2.";
let description = [{
@ -322,9 +339,9 @@ def SPIRV_IAddCarryOp : SPIRV_ArithmeticExtendedBinaryOp<"IAddCarry",
// -----
def SPIRV_IMulOp : SPIRV_ArithmeticBinaryOp<"IMul",
SPIRV_Integer,
[Commutative, UsableInSpecConstantOp]> {
def SPIRV_IMulOp : SPIRV_ArithmeticBinaryOpWithCoopMatrix<"IMul",
SPIRV_Integer,
[Commutative, UsableInSpecConstantOp]> {
let summary = "Integer multiplication of Operand 1 and Operand 2.";
let description = [{
@ -354,9 +371,9 @@ def SPIRV_IMulOp : SPIRV_ArithmeticBinaryOp<"IMul",
// -----
def SPIRV_ISubOp : SPIRV_ArithmeticBinaryOp<"ISub",
SPIRV_Integer,
[UsableInSpecConstantOp]> {
def SPIRV_ISubOp : SPIRV_ArithmeticBinaryOpWithCoopMatrix<"ISub",
SPIRV_Integer,
[UsableInSpecConstantOp]> {
let summary = "Integer subtraction of Operand 2 from Operand 1.";
let description = [{
@ -460,9 +477,9 @@ def SPIRV_DotOp : SPIRV_Op<"Dot",
// -----
def SPIRV_SDivOp : SPIRV_ArithmeticBinaryOp<"SDiv",
SPIRV_Integer,
[UsableInSpecConstantOp]> {
def SPIRV_SDivOp : SPIRV_ArithmeticBinaryOpWithCoopMatrix<"SDiv",
SPIRV_Integer,
[UsableInSpecConstantOp]> {
let summary = "Signed-integer division of Operand 1 divided by Operand 2.";
let description = [{
@ -622,9 +639,9 @@ def SPIRV_SRemOp : SPIRV_ArithmeticBinaryOp<"SRem",
// -----
def SPIRV_UDivOp : SPIRV_ArithmeticBinaryOp<"UDiv",
SPIRV_Integer,
[UnsignedOp, UsableInSpecConstantOp]> {
def SPIRV_UDivOp : SPIRV_ArithmeticBinaryOpWithCoopMatrix<"UDiv",
SPIRV_Integer,
[UnsignedOp, UsableInSpecConstantOp]> {
let summary = "Unsigned-integer division of Operand 1 divided by Operand 2.";
let description = [{

View File

@ -577,3 +577,46 @@ spirv.func @matrix_times_scalar(%a: !spirv.coopmatrix<2x2xf32, Workgroup, Matrix
%p = spirv.MatrixTimesScalar %a, %b : !spirv.coopmatrix<2x2xf32, Workgroup, MatrixA>, f16
spirv.Return
}
// -----
// These binary arithmetic instructions do not support coop matrix operands.
spirv.func @fmod(%a: !spirv.coopmatrix<2x2xf32, Subgroup, MatrixA>, %b: !spirv.coopmatrix<2x2xf32, Subgroup, MatrixA>) "None" {
// expected-error @+1 {{op operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values of length 2/3/4/8/16}}
%p = spirv.FMod %a, %b : !spirv.coopmatrix<2x2xf32, Subgroup, MatrixA>
spirv.Return
}
// -----
spirv.func @frem(%a: !spirv.coopmatrix<2x2xf32, Subgroup, MatrixA>, %b: !spirv.coopmatrix<2x2xf32, Subgroup, MatrixA>) "None" {
// expected-error @+1 {{op operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values of length 2/3/4/8/16}}
%p = spirv.FRem %a, %b : !spirv.coopmatrix<2x2xf32, Subgroup, MatrixA>
spirv.Return
}
// -----
spirv.func @smod(%a: !spirv.coopmatrix<2x2xi32, Subgroup, MatrixA>, %b: !spirv.coopmatrix<2x2xi32, Subgroup, MatrixA>) "None" {
// expected-error @+1 {{operand #0 must be 8/16/32/64-bit integer or vector of 8/16/32/64-bit integer values of length 2/3/4/8/16}}
%p = spirv.SMod %a, %b : !spirv.coopmatrix<2x2xi32, Subgroup, MatrixA>
spirv.Return
}
// -----
spirv.func @srem(%a: !spirv.coopmatrix<2x2xi32, Subgroup, MatrixA>, %b: !spirv.coopmatrix<2x2xi32, Subgroup, MatrixA>) "None" {
// expected-error @+1 {{operand #0 must be 8/16/32/64-bit integer or vector of 8/16/32/64-bit integer values of length 2/3/4/8/16}}
%p = spirv.SRem %a, %b : !spirv.coopmatrix<2x2xi32, Subgroup, MatrixA>
spirv.Return
}
// -----
spirv.func @umod(%a: !spirv.coopmatrix<2x2xi32, Subgroup, MatrixA>, %b: !spirv.coopmatrix<2x2xi32, Subgroup, MatrixA>) "None" {
// expected-error @+1 {{operand #0 must be 8/16/32/64-bit integer or vector of 8/16/32/64-bit integer values of length 2/3/4/8/16}}
%p = spirv.UMod %a, %b : !spirv.coopmatrix<2x2xi32, Subgroup, MatrixA>
spirv.Return
}
// -----