[MLIR][Complex] Check for FastMathFlag in DivOp folder (#176249)
- Fold DivOp with LHS that has NaN as real or imag to Complex of NaNs - Fold `div(a, Complex<1, 0>) -> a` if fast math flag with nnan is set
This commit is contained in:
parent
67fcdc9016
commit
fd4dec9b1e
@ -371,35 +371,38 @@ OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult DivOp::fold(FoldAdaptor adaptor) {
|
||||
auto rhs = adaptor.getRhs();
|
||||
auto lhs = adaptor.getLhs();
|
||||
Attribute rhs = adaptor.getRhs();
|
||||
Attribute lhs = adaptor.getLhs();
|
||||
|
||||
// We can't fold without knowing that LHS isn't NaN
|
||||
if (!rhs || !lhs)
|
||||
return {};
|
||||
// complex.div(complex.constant<NaN, NaN>, a) -> complex.constant<NaN, NaN>
|
||||
// complex.div(complex.constant<NaN, a>, b) -> complex.constant<NaN, NaN>
|
||||
// complex.div(complex.constant<a, NaN>, b) -> complex.constant<NaN, NaN>
|
||||
bool isLhsComplexHasNan = false;
|
||||
ArrayAttr lhsArrayAttr = dyn_cast_if_present<ArrayAttr>(lhs);
|
||||
if (lhsArrayAttr && lhsArrayAttr.size() == 2) {
|
||||
APFloat lhsReal = cast<FloatAttr>(lhsArrayAttr[0]).getValue();
|
||||
APFloat lhsImag = cast<FloatAttr>(lhsArrayAttr[1]).getValue();
|
||||
isLhsComplexHasNan = lhsReal.isNaN() || lhsImag.isNaN();
|
||||
if (isLhsComplexHasNan) {
|
||||
Attribute nanValue = lhsReal.isNaN() ? lhsArrayAttr[0] : lhsArrayAttr[1];
|
||||
return ArrayAttr::get(getContext(), {nanValue, nanValue});
|
||||
}
|
||||
}
|
||||
|
||||
ArrayAttr rhsArrayAttr = dyn_cast<ArrayAttr>(rhs);
|
||||
ArrayAttr rhsArrayAttr = dyn_cast_if_present<ArrayAttr>(rhs);
|
||||
if (!rhsArrayAttr || rhsArrayAttr.size() != 2)
|
||||
return {};
|
||||
|
||||
ArrayAttr lhsArrayAttr = dyn_cast<ArrayAttr>(lhs);
|
||||
if (!lhsArrayAttr || lhsArrayAttr.size() != 2)
|
||||
return {};
|
||||
|
||||
// Fold only if RHS is complex.constant<1.0, 0.0>
|
||||
APFloat rhsImag = cast<FloatAttr>(rhsArrayAttr[1]).getValue();
|
||||
if (!rhsImag.isZero())
|
||||
APFloat rhsReal = cast<FloatAttr>(rhsArrayAttr[0]).getValue();
|
||||
if (!rhsImag.isZero() || rhsReal != APFloat(rhsReal.getSemantics(), 1))
|
||||
return {};
|
||||
|
||||
APFloat lhsReal = cast<FloatAttr>(lhsArrayAttr[0]).getValue();
|
||||
APFloat lhsImag = cast<FloatAttr>(lhsArrayAttr[1]).getValue();
|
||||
if (lhsReal.isNaN() || lhsImag.isNaN()) {
|
||||
Attribute nanValue = lhsReal.isNaN() ? lhsArrayAttr[0] : lhsArrayAttr[1];
|
||||
return ArrayAttr::get(getContext(), {nanValue, nanValue});
|
||||
}
|
||||
|
||||
// complex.div(a, complex.constant<1.0, 0.0>) -> a
|
||||
APFloat rhsReal = cast<FloatAttr>(rhsArrayAttr[0]).getValue();
|
||||
if (rhsReal == APFloat(rhsReal.getSemantics(), 1))
|
||||
// Fold to LHS if it doesn't contains NaNs or fast math flag nan is set
|
||||
// complex.div(a, complex.constant<1.0, 0.0>) fastmath<nnan> -> a
|
||||
if ((lhsArrayAttr && !isLhsComplexHasNan) ||
|
||||
arith::bitEnumContainsAll(getFastmath(), arith::FastMathFlags::nnan))
|
||||
return getLhs();
|
||||
|
||||
return {};
|
||||
|
||||
@ -327,8 +327,8 @@ func.func @div_one_f128() -> complex<f128> {
|
||||
return %div : complex<f128>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: div_op_with_rhs_has_nan
|
||||
func.func @div_op_with_rhs_has_nan() -> complex<f32> {
|
||||
// CHECK-LABEL: div_op_with_rhs_has_nan_real
|
||||
func.func @div_op_with_rhs_has_nan_real() -> complex<f32> {
|
||||
%a = complex.constant [0x7fffffff : f32, 1.0 : f32]: complex<f32>
|
||||
%b = complex.constant [1.0: f32, 0.0 : f32]: complex<f32>
|
||||
%div = complex.div %a, %b : complex<f32>
|
||||
@ -336,3 +336,45 @@ func.func @div_op_with_rhs_has_nan() -> complex<f32> {
|
||||
// CHECK: return %[[DIV]] : complex<f32>
|
||||
return %div : complex<f32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: div_op_with_rhs_has_nan_imag
|
||||
func.func @div_op_with_rhs_has_nan_imag() -> complex<f32> {
|
||||
%a = complex.constant [1.0 : f32, 0x7fffffff : f32]: complex<f32>
|
||||
%b = complex.constant [1.0: f32, 0.0 : f32]: complex<f32>
|
||||
%div = complex.div %a, %b : complex<f32>
|
||||
// CHECK: %[[DIV:.*]] = complex.constant [0x7FFFFFFF : f32, 0x7FFFFFFF : f32] : complex<f32>
|
||||
// CHECK: return %[[DIV]] : complex<f32>
|
||||
return %div : complex<f32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: div_op_with_rhs_has_nan_real_imag
|
||||
func.func @div_op_with_rhs_has_nan_real_imag() -> complex<f32> {
|
||||
%a = complex.constant [0x7fffffff : f32, 0x7fffffff : f32]: complex<f32>
|
||||
%b = complex.constant [1.0: f32, 0.0 : f32]: complex<f32>
|
||||
%div = complex.div %a, %b : complex<f32>
|
||||
// CHECK: %[[DIV:.*]] = complex.constant [0x7FFFFFFF : f32, 0x7FFFFFFF : f32] : complex<f32>
|
||||
// CHECK: return %[[DIV]] : complex<f32>
|
||||
return %div : complex<f32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: div_op_non_constant_lhs_with_fast_math
|
||||
func.func @div_op_non_constant_lhs_with_fast_math(%arg0: f32, %arg1: f32) -> complex<f32> {
|
||||
%a = complex.create %arg0, %arg1 : complex<f32>
|
||||
%b = complex.constant [1.0 : f32, 0.0 : f32] : complex<f32>
|
||||
%div = complex.div %a, %b fastmath<nnan> : complex<f32>
|
||||
// CHECK: %[[COMPLEX:.*]] = complex.create %arg0, %arg1 : complex<f32>
|
||||
// CHECK: return %[[COMPLEX]] : complex<f32>
|
||||
return %div: complex<f32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: div_op_non_constant_lhs_without_fast_math
|
||||
func.func @div_op_non_constant_lhs_without_fast_math(%arg0: f32, %arg1: f32) -> complex<f32> {
|
||||
%a = complex.create %arg0, %arg1 : complex<f32>
|
||||
%b = complex.constant [1.0 : f32, 0.0 : f32] : complex<f32>
|
||||
%div = complex.div %a, %b : complex<f32>
|
||||
// CHECK: %[[B:.*]] = complex.constant [1.000000e+00 : f32, 0.000000e+00 : f32] : complex<f32>
|
||||
// CHECK: %[[A:.*]] = complex.create %arg0, %arg1 : complex<f32>
|
||||
// CHECK: %[[DIV:.*]] = complex.div %[[A]], %[[B]] : complex<f32>
|
||||
// CHECK: return %[[DIV]] : complex<f32>
|
||||
return %div: complex<f32>
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user