[MLIR][Complex] Check for FastMathFlag in DivOp folder (#176249)

- Fold DivOp with LHS that has NaN as real or imag to Complex of NaNs
- Fold `div(a, Complex<1, 0>) -> a` if fast math flag with nnan is set
This commit is contained in:
Amr Hesham 2026-02-20 22:04:07 +01:00 committed by GitHub
parent 67fcdc9016
commit fd4dec9b1e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 68 additions and 23 deletions

View File

@ -371,35 +371,38 @@ OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
//===----------------------------------------------------------------------===//
OpFoldResult DivOp::fold(FoldAdaptor adaptor) {
auto rhs = adaptor.getRhs();
auto lhs = adaptor.getLhs();
Attribute rhs = adaptor.getRhs();
Attribute lhs = adaptor.getLhs();
// We can't fold without knowing that LHS isn't NaN
if (!rhs || !lhs)
return {};
// complex.div(complex.constant<NaN, NaN>, a) -> complex.constant<NaN, NaN>
// complex.div(complex.constant<NaN, a>, b) -> complex.constant<NaN, NaN>
// complex.div(complex.constant<a, NaN>, b) -> complex.constant<NaN, NaN>
bool isLhsComplexHasNan = false;
ArrayAttr lhsArrayAttr = dyn_cast_if_present<ArrayAttr>(lhs);
if (lhsArrayAttr && lhsArrayAttr.size() == 2) {
APFloat lhsReal = cast<FloatAttr>(lhsArrayAttr[0]).getValue();
APFloat lhsImag = cast<FloatAttr>(lhsArrayAttr[1]).getValue();
isLhsComplexHasNan = lhsReal.isNaN() || lhsImag.isNaN();
if (isLhsComplexHasNan) {
Attribute nanValue = lhsReal.isNaN() ? lhsArrayAttr[0] : lhsArrayAttr[1];
return ArrayAttr::get(getContext(), {nanValue, nanValue});
}
}
ArrayAttr rhsArrayAttr = dyn_cast<ArrayAttr>(rhs);
ArrayAttr rhsArrayAttr = dyn_cast_if_present<ArrayAttr>(rhs);
if (!rhsArrayAttr || rhsArrayAttr.size() != 2)
return {};
ArrayAttr lhsArrayAttr = dyn_cast<ArrayAttr>(lhs);
if (!lhsArrayAttr || lhsArrayAttr.size() != 2)
return {};
// Fold only if RHS is complex.constant<1.0, 0.0>
APFloat rhsImag = cast<FloatAttr>(rhsArrayAttr[1]).getValue();
if (!rhsImag.isZero())
APFloat rhsReal = cast<FloatAttr>(rhsArrayAttr[0]).getValue();
if (!rhsImag.isZero() || rhsReal != APFloat(rhsReal.getSemantics(), 1))
return {};
APFloat lhsReal = cast<FloatAttr>(lhsArrayAttr[0]).getValue();
APFloat lhsImag = cast<FloatAttr>(lhsArrayAttr[1]).getValue();
if (lhsReal.isNaN() || lhsImag.isNaN()) {
Attribute nanValue = lhsReal.isNaN() ? lhsArrayAttr[0] : lhsArrayAttr[1];
return ArrayAttr::get(getContext(), {nanValue, nanValue});
}
// complex.div(a, complex.constant<1.0, 0.0>) -> a
APFloat rhsReal = cast<FloatAttr>(rhsArrayAttr[0]).getValue();
if (rhsReal == APFloat(rhsReal.getSemantics(), 1))
// Fold to LHS if it doesn't contains NaNs or fast math flag nan is set
// complex.div(a, complex.constant<1.0, 0.0>) fastmath<nnan> -> a
if ((lhsArrayAttr && !isLhsComplexHasNan) ||
arith::bitEnumContainsAll(getFastmath(), arith::FastMathFlags::nnan))
return getLhs();
return {};

View File

@ -327,8 +327,8 @@ func.func @div_one_f128() -> complex<f128> {
return %div : complex<f128>
}
// CHECK-LABEL: div_op_with_rhs_has_nan
func.func @div_op_with_rhs_has_nan() -> complex<f32> {
// CHECK-LABEL: div_op_with_rhs_has_nan_real
func.func @div_op_with_rhs_has_nan_real() -> complex<f32> {
%a = complex.constant [0x7fffffff : f32, 1.0 : f32]: complex<f32>
%b = complex.constant [1.0: f32, 0.0 : f32]: complex<f32>
%div = complex.div %a, %b : complex<f32>
@ -336,3 +336,45 @@ func.func @div_op_with_rhs_has_nan() -> complex<f32> {
// CHECK: return %[[DIV]] : complex<f32>
return %div : complex<f32>
}
// CHECK-LABEL: div_op_with_rhs_has_nan_imag
func.func @div_op_with_rhs_has_nan_imag() -> complex<f32> {
%a = complex.constant [1.0 : f32, 0x7fffffff : f32]: complex<f32>
%b = complex.constant [1.0: f32, 0.0 : f32]: complex<f32>
%div = complex.div %a, %b : complex<f32>
// CHECK: %[[DIV:.*]] = complex.constant [0x7FFFFFFF : f32, 0x7FFFFFFF : f32] : complex<f32>
// CHECK: return %[[DIV]] : complex<f32>
return %div : complex<f32>
}
// CHECK-LABEL: div_op_with_rhs_has_nan_real_imag
func.func @div_op_with_rhs_has_nan_real_imag() -> complex<f32> {
%a = complex.constant [0x7fffffff : f32, 0x7fffffff : f32]: complex<f32>
%b = complex.constant [1.0: f32, 0.0 : f32]: complex<f32>
%div = complex.div %a, %b : complex<f32>
// CHECK: %[[DIV:.*]] = complex.constant [0x7FFFFFFF : f32, 0x7FFFFFFF : f32] : complex<f32>
// CHECK: return %[[DIV]] : complex<f32>
return %div : complex<f32>
}
// CHECK-LABEL: div_op_non_constant_lhs_with_fast_math
func.func @div_op_non_constant_lhs_with_fast_math(%arg0: f32, %arg1: f32) -> complex<f32> {
%a = complex.create %arg0, %arg1 : complex<f32>
%b = complex.constant [1.0 : f32, 0.0 : f32] : complex<f32>
%div = complex.div %a, %b fastmath<nnan> : complex<f32>
// CHECK: %[[COMPLEX:.*]] = complex.create %arg0, %arg1 : complex<f32>
// CHECK: return %[[COMPLEX]] : complex<f32>
return %div: complex<f32>
}
// CHECK-LABEL: div_op_non_constant_lhs_without_fast_math
func.func @div_op_non_constant_lhs_without_fast_math(%arg0: f32, %arg1: f32) -> complex<f32> {
%a = complex.create %arg0, %arg1 : complex<f32>
%b = complex.constant [1.0 : f32, 0.0 : f32] : complex<f32>
%div = complex.div %a, %b : complex<f32>
// CHECK: %[[B:.*]] = complex.constant [1.000000e+00 : f32, 0.000000e+00 : f32] : complex<f32>
// CHECK: %[[A:.*]] = complex.create %arg0, %arg1 : complex<f32>
// CHECK: %[[DIV:.*]] = complex.div %[[A]], %[[B]] : complex<f32>
// CHECK: return %[[DIV]] : complex<f32>
return %div: complex<f32>
}