[X86] Add tests showing failure to concat fma chain which share concatenated operands (#173403)

We often have fma chains that reuse operands down the chain (e.g mathlib
taylor series expansion) - FMA(FMA(X,Y,Z),X,W) etc.

For these cases combineConcatVectorOps fails to account that the same
operands will be concatenated down the recursion chain.
This commit is contained in:
Simon Pilgrim 2025-12-23 18:46:30 +00:00 committed by GitHub
parent 67f2a22a23
commit 4e44e87617
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -200,6 +200,75 @@ define <16 x float> @concat_fmsub_v16f32_v8f32_constant_split(<8 x float> %a0, <
ret <16 x float> %res
}
; 1 repeated concat + 1 constant - we only have to concat 2 operands down the fma chain
define <8 x float> @concat_fma_v8f32_v4f32_constant_repeatedop(<4 x float> %a0, <4 x float> %a1, <4 x float> %a2, <4 x float> %a3) {
; FMA4-LABEL: concat_fma_v8f32_v4f32_constant_repeatedop:
; FMA4: # %bb.0:
; FMA4-NEXT: vmovaps {{.*#+}} xmm4 = [1.0E+0,1.0E+0,1.0E+0,1.0E+0]
; FMA4-NEXT: vmovaps {{.*#+}} xmm5 = [2.0E+0,2.0E+0,2.0E+0,2.0E+0]
; FMA4-NEXT: vfmaddps {{.*#+}} xmm2 = (xmm0 * xmm2) + xmm4
; FMA4-NEXT: vfmaddps {{.*#+}} xmm3 = (xmm1 * xmm3) + xmm4
; FMA4-NEXT: vfmaddps {{.*#+}} xmm0 = (xmm0 * xmm2) + xmm5
; FMA4-NEXT: vfmaddps {{.*#+}} xmm1 = (xmm1 * xmm3) + xmm5
; FMA4-NEXT: vinsertf128 $1, %xmm1, %ymm0, %ymm0
; FMA4-NEXT: retq
;
; FMA3-LABEL: concat_fma_v8f32_v4f32_constant_repeatedop:
; FMA3: # %bb.0:
; FMA3-NEXT: # kill: def $xmm2 killed $xmm2 def $ymm2
; FMA3-NEXT: vbroadcastss {{.*#+}} xmm4 = [1.0E+0,1.0E+0,1.0E+0,1.0E+0]
; FMA3-NEXT: vfmadd213ps {{.*#+}} xmm2 = (xmm0 * xmm2) + xmm4
; FMA3-NEXT: vfmadd213ps {{.*#+}} xmm3 = (xmm1 * xmm3) + xmm4
; FMA3-NEXT: vbroadcastss {{.*#+}} xmm4 = [2.0E+0,2.0E+0,2.0E+0,2.0E+0]
; FMA3-NEXT: vfmadd213ps {{.*#+}} xmm2 = (xmm0 * xmm2) + xmm4
; FMA3-NEXT: vfmadd213ps {{.*#+}} xmm3 = (xmm1 * xmm3) + xmm4
; FMA3-NEXT: vinsertf128 $1, %xmm3, %ymm2, %ymm0
; FMA3-NEXT: retq
%l0 = call <4 x float> @llvm.fma.v4f32(<4 x float> %a0, <4 x float> %a2, <4 x float> splat (float 1.000000e+00))
%h0 = call <4 x float> @llvm.fma.v4f32(<4 x float> %a1, <4 x float> %a3, <4 x float> splat (float 1.000000e+00))
%l1 = call <4 x float> @llvm.fma.v4f32(<4 x float> %a0, <4 x float> %l0, <4 x float> splat (float 2.000000e+00))
%h1 = call <4 x float> @llvm.fma.v4f32(<4 x float> %a1, <4 x float> %h0, <4 x float> splat (float 2.000000e+00))
%r = shufflevector <4 x float> %l1, <4 x float> %h1, <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
ret <8 x float> %r
}
define <8 x double> @concat_fma_fmsub_v8f64_v4f64_constant_repeatedop_commute(<4 x double> %a0, <4 x double> %a1, <4 x double> %a2, <4 x double> %a3) {
; FMA4-LABEL: concat_fma_fmsub_v8f64_v4f64_constant_repeatedop_commute:
; FMA4: # %bb.0:
; FMA4-NEXT: vmovapd {{.*#+}} ymm4 = [-2.0E+0,-2.0E+0,-2.0E+0,-2.0E+0]
; FMA4-NEXT: vfmaddpd {{.*#+}} ymm2 = (ymm2 * ymm0) + ymm4
; FMA4-NEXT: vfmaddpd {{.*#+}} ymm3 = (ymm3 * ymm1) + ymm4
; FMA4-NEXT: vfmsubpd {{.*#+}} ymm0 = (ymm0 * ymm2) - ymm4
; FMA4-NEXT: vfmsubpd {{.*#+}} ymm1 = (ymm1 * ymm3) - ymm4
; FMA4-NEXT: retq
;
; AVX2-LABEL: concat_fma_fmsub_v8f64_v4f64_constant_repeatedop_commute:
; AVX2: # %bb.0:
; AVX2-NEXT: vbroadcastsd {{.*#+}} ymm4 = [-2.0E+0,-2.0E+0,-2.0E+0,-2.0E+0]
; AVX2-NEXT: vfmadd213pd {{.*#+}} ymm2 = (ymm0 * ymm2) + ymm4
; AVX2-NEXT: vfmadd213pd {{.*#+}} ymm3 = (ymm1 * ymm3) + ymm4
; AVX2-NEXT: vfmsub213pd {{.*#+}} ymm0 = (ymm2 * ymm0) - ymm4
; AVX2-NEXT: vfmsub213pd {{.*#+}} ymm1 = (ymm3 * ymm1) - ymm4
; AVX2-NEXT: retq
;
; AVX512-LABEL: concat_fma_fmsub_v8f64_v4f64_constant_repeatedop_commute:
; AVX512: # %bb.0:
; AVX512-NEXT: vbroadcastsd {{.*#+}} ymm4 = [-2.0E+0,-2.0E+0,-2.0E+0,-2.0E+0]
; AVX512-NEXT: # kill: def $ymm2 killed $ymm2 def $zmm2
; AVX512-NEXT: vfmadd213pd {{.*#+}} ymm2 = (ymm0 * ymm2) + ymm4
; AVX512-NEXT: vfmadd213pd {{.*#+}} ymm3 = (ymm1 * ymm3) + ymm4
; AVX512-NEXT: vfmsub213pd {{.*#+}} ymm2 = (ymm0 * ymm2) - ymm4
; AVX512-NEXT: vfmsub213pd {{.*#+}} ymm3 = (ymm1 * ymm3) - ymm4
; AVX512-NEXT: vinsertf64x4 $1, %ymm3, %zmm2, %zmm0
; AVX512-NEXT: retq
%l0 = call <4 x double> @llvm.fma.v4f32(<4 x double> %a2, <4 x double> %a0, <4 x double> splat (double -2.000000e+00))
%h0 = call <4 x double> @llvm.fma.v4f32(<4 x double> %a3, <4 x double> %a1, <4 x double> splat (double -2.000000e+00))
%l1 = call <4 x double> @llvm.fma.v4f32(<4 x double> %a0, <4 x double> %l0, <4 x double> splat (double +2.000000e+00))
%h1 = call <4 x double> @llvm.fma.v4f32(<4 x double> %a1, <4 x double> %h0, <4 x double> splat (double +2.000000e+00))
%r = shufflevector <4 x double> %l1, <4 x double> %h1, <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
ret <8 x double> %r
}
; negative - too many operands to concat
define <8 x float> @concat_fmadd_v8f32_v4f32(<4 x float> %a0, <4 x float> %a1, <4 x float> %b0, <4 x float> %b1, <4 x float> %c0, <4 x float> %c1) {
; FMA4-LABEL: concat_fmadd_v8f32_v4f32: