[MLIR][ArmSVE] Add an ArmSVE dialect operation mapping to bfmmla (#145064)

This commit is contained in:
Momchil Velikov 2025-06-27 15:37:13 +01:00 committed by GitHub
parent da2969b105
commit 3876e887d0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 113 additions and 1 deletions

View File

@ -293,6 +293,35 @@ def UsmmlaOp : ArmSVE_Op<"usmmla", [Pure,
"$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `to` type($dst)";
}
def BfmmlaOp : ArmSVE_IntrOp<"bfmmla", [Pure,
                                        AllTypesMatch<["src1", "src2"]>,
                                        AllTypesMatch<["acc", "res"]>,
                                        ]> {
  let summary = "BFloat16 matrix multiply-accumulate";
  let description = [{
    BFMMLA: BFloat16 matrix multiply-accumulate into 2x2 matrices.

    This operation multiplies the 2x4 BFloat16 matrix held in each 128-bit
    segment of the first source vector by the 4x2 BFloat16 matrix in the
    corresponding segment of the second source vector, then accumulates
    this intermediate result with the 2x2 Float32 matrix in the corresponding
    segment of the accumulator vector, yielding the final 2x2 Float32
    segment of the result.

    Source:
    https://developer.arm.com/documentation/100987/0000
  }];
  // Supports (vector<[8]xbf16>, vector<[8]xbf16>) -> (vector<[4]xf32>)
  // Operand order is (accumulator, first source, second source); the
  // accumulator and the result share one type, as do the two sources
  // (see the AllTypesMatch traits above).
  let arguments = (ins
    ScalableVectorOfLengthAndType<[4], [F32]>:$acc,
    ScalableVectorOfLengthAndType<[8], [BF16]>:$src1,
    ScalableVectorOfLengthAndType<[8], [BF16]>:$src2
  );
  let results = (outs ScalableVectorOfLengthAndType<[4], [F32]>:$res);
  // Only the source and result types are spelled in the custom syntax; the
  // types of $src2 and $acc follow from the AllTypesMatch constraints.
  let assemblyFormat =
    "$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `to` type($res)";
}
class SvboolTypeConstraint<string lhsArg, string rhsArg> : TypesMatchWith<
"expected corresponding svbool type widened to [16]xi1",
lhsArg, rhsArg,

View File

@ -220,7 +220,8 @@ void mlir::populateArmSVELegalizeForLLVMExportPatterns(
void mlir::configureArmSVELegalizeForExportTarget(
LLVMConversionTarget &target) {
// clang-format off
target.addLegalOp<ConvertFromSvboolIntrOp,
target.addLegalOp<BfmmlaOp,
ConvertFromSvboolIntrOp,
ConvertToSvboolIntrOp,
DupQLaneIntrOp,
PselIntrOp,

View File

@ -72,3 +72,63 @@ func.func @arm_sve_psel_bad_vector_type(%a : vector<[7]xi1>, %index: index) {
arm_sve.psel %a, %a[%index] : vector<[7]xi1>, vector<[7]xi1>
return
}
// -----
// Negative test: the source operands use f16 elements instead of bf16, so
// verification is expected to reject operand #1 (the first source vector).
func.func @bfmmla_invalid_element_type_lhs_rhs(%acc: vector<[4]xf32>,
%lhs: vector<[8]xf16>,
%rhs: vector<[8]xf16>) -> vector<[4]xf32> {
// expected-error@+1 {{operand #1 must be scalable vector of bfloat16 type values of length 8, but got 'vector<[8]xf16>'}}
%0 = arm_sve.intr.bfmmla %acc, %lhs, %rhs : vector<[8]xf16> to vector<[4]xf32>
return %0 : vector<[4]xf32>
}
// -----
// Negative test: the source operands are scalable bf16 vectors of length 4
// instead of 8, so verification is expected to reject operand #1.
func.func @bfmmla_invalid_dimension_lhs_rhs(%acc: vector<[4]xf32>,
                                            %lhs: vector<[4]xbf16>,
                                            %rhs: vector<[4]xbf16>) -> vector<[4]xf32> {
  // expected-error@+1 {{operand #1 must be scalable vector of bfloat16 type values of length 8, but got 'vector<[4]xbf16>'}}
  %0 = arm_sve.intr.bfmmla %acc, %lhs, %rhs : vector<[4]xbf16> to vector<[4]xf32>
  return %0 : vector<[4]xf32>
}
// -----
// Negative test: the source operands are fixed-length (non-scalable) bf16
// vectors, so verification is expected to reject operand #1.
func.func @bfmmla_fixed_dimension_lhs_rhs(%acc: vector<[4]xf32>,
                                          %lhs: vector<8xbf16>,
                                          %rhs: vector<8xbf16>) -> vector<[4]xf32> {
  // expected-error@+1 {{operand #1 must be scalable vector of bfloat16 type values of length 8, but got 'vector<8xbf16>'}}
  %0 = arm_sve.intr.bfmmla %acc, %lhs, %rhs : vector<8xbf16> to vector<[4]xf32>
  return %0 : vector<[4]xf32>
}
// -----
// Negative test: the accumulator uses i32 elements instead of f32, so
// verification is expected to reject operand #0 (the accumulator).
func.func @bfmmla_invalid_element_type_acc(%acc: vector<[4]xi32>,
%lhs: vector<[8]xbf16>,
%rhs: vector<[8]xbf16>) -> vector<[4]xi32> {
// expected-error@+1 {{operand #0 must be scalable vector of 32-bit float values of length 4, but got 'vector<[4]xi32>'}}
%0 = arm_sve.intr.bfmmla %acc, %lhs, %rhs : vector<[8]xbf16> to vector<[4]xi32>
return %0 : vector<[4]xi32>
}
// -----
// Negative test: the accumulator is a scalable f32 vector of length 8
// instead of 4, so verification is expected to reject operand #0.
func.func @bfmmla_invalid_dimension_acc(%acc: vector<[8]xf32>,
%lhs: vector<[8]xbf16>,
%rhs: vector<[8]xbf16>) -> vector<[8]xf32> {
// expected-error@+1 {{operand #0 must be scalable vector of 32-bit float values of length 4, but got 'vector<[8]xf32>'}}
%0 = arm_sve.intr.bfmmla %acc, %lhs, %rhs : vector<[8]xbf16> to vector<[8]xf32>
return %0 : vector<[8]xf32>
}
// -----
// Negative test: the accumulator is a fixed-length (non-scalable) f32
// vector, so verification is expected to reject operand #0.
func.func @bfmmla_fixed_dimension_acc(%acc: vector<4xf32>,
%lhs: vector<[8]xbf16>,
%rhs: vector<[8]xbf16>) -> vector<4xf32> {
// expected-error@+1 {{operand #0 must be scalable vector of 32-bit float values of length 4, but got 'vector<4xf32>'}}
%0 = arm_sve.intr.bfmmla %acc, %lhs, %rhs : vector<[8]xbf16> to vector<4xf32>
return %0 : vector<4xf32>
}

View File

@ -55,6 +55,16 @@ func.func @arm_sve_usmmla(%a: vector<[16]xi8>,
// -----
// Round-trip test for arm_sve.intr.bfmmla: two scalable bf16 source vectors
// and an f32 accumulator go in, and the op must print back in custom syntax.
func.func @arm_sve_bfmmla(%lhs: vector<[8]xbf16>,
                          %rhs: vector<[8]xbf16>,
                          %acc: vector<[4]xf32>) -> vector<[4]xf32> {
  // CHECK: arm_sve.intr.bfmmla {{.*}}: vector<[8]xbf16> to vector<[4]xf32>
  %res = arm_sve.intr.bfmmla %acc, %lhs, %rhs : vector<[8]xbf16> to vector<[4]xf32>
  return %res : vector<[4]xf32>
}
// -----
func.func @arm_sve_masked_arithi(%a: vector<[4]xi32>,
%b: vector<[4]xi32>,
%c: vector<[4]xi32>,

View File

@ -60,6 +60,18 @@ llvm.func @arm_sve_usmmla(%arg0: vector<[16]xi8>,
llvm.return %0 : vector<[4]xi32>
}
// CHECK-LABEL: define <vscale x 4 x float> @arm_sve_bfmmla
// Translation test: the generic-form arm_sve.intr.bfmmla op, with the
// accumulator passed as its first operand, must become a call to the
// llvm.aarch64.sve.bfmmla intrinsic.
llvm.func @arm_sve_bfmmla(%lhs: vector<[8]xbf16>,
                          %rhs: vector<[8]xbf16>,
                          %acc: vector<[4]xf32>)
                          -> vector<[4]xf32> {
  // CHECK: call <vscale x 4 x float> @llvm.aarch64.sve.bfmmla(<vscale x 4 x float>
  %res = "arm_sve.intr.bfmmla"(%acc, %lhs, %rhs) :
    (vector<[4]xf32>, vector<[8]xbf16>, vector<[8]xbf16>)
    -> vector<[4]xf32>
  llvm.return %res : vector<[4]xf32>
}
// CHECK-LABEL: define <vscale x 4 x i32> @arm_sve_arithi
llvm.func @arm_sve_arithi(%arg0: vector<[4]xi32>,
%arg1: vector<[4]xi32>,