[Flang] Add new ConvertComplexPow pass for Flang (#158642)

This PR introduces a new `ConvertComplexPow` pass for Flang that handles
complex power operations. The change forces lowering to `complex.pow`
operations when `--math-runtime=precise` is not used, then uses the
`ConvertComplexPow` pass to convert these operations back to library
calls.

- Adds a new `ConvertComplexPow` pass that converts `complex.pow` ops to
the appropriate runtime library calls
- Updates complex power lowering to use `complex.pow` operations by
default instead of direct library calls

PR #158722 adds a new `complex.powi` op, enabling algebraic optimisations.
This commit is contained in:
Akash Banerjee 2025-09-19 01:51:10 +01:00 committed by GitHub
parent 01fca01d3b
commit 54677d66c4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
19 changed files with 307 additions and 41 deletions

View File

@ -555,6 +555,17 @@ def SimplifyFIROperations : Pass<"simplify-fir-operations", "mlir::ModuleOp"> {
"Prefer expanding without using Fortran runtime calls.">];
}
// Pass record for `convert-complex-pow`, anchored on `mlir::ModuleOp`.
// The summary/description below are the user-visible pass documentation.
def ConvertComplexPow : Pass<"convert-complex-pow", "mlir::ModuleOp"> {
let summary = "Convert complex.pow operations to library calls";
let description = [{
Replace `complex.pow` operations with calls to the appropriate
Fortran runtime or libm functions.
}];
// Dialects declared as dependencies of this pass: FIR, func, complex and
// arith.
let dependentDialects = ["fir::FIROpsDialect", "mlir::func::FuncDialect",
"mlir::complex::ComplexDialect",
"mlir::arith::ArithDialect"];
}
def OptimizeArrayRepacking
: Pass<"optimize-array-repacking", "mlir::func::FuncOp"> {
let summary = "Optimizes redundant array repacking operations";

View File

@ -135,6 +135,7 @@ struct MLIRToLLVMPassPipelineConfig : public FlangEPCallBacks {
bool NSWOnLoopVarInc = true; ///< Add nsw flag to loop variable increments.
bool EnableOpenMP = false; ///< Enable OpenMP lowering.
bool EnableOpenMPSimd = false; ///< Enable OpenMP simd-only mode.
bool SkipConvertComplexPow = false; ///< Do not run complex pow conversion.
std::string InstrumentFunctionEntry =
""; ///< Name of the instrument-function that is called on each
///< function-entry

View File

@ -738,6 +738,8 @@ void CodeGenAction::generateLLVMIR() {
pm.enableVerifier(/*verifyPasses=*/true);
MLIRToLLVMPassPipelineConfig config(level, opts, mathOpts);
llvm::Triple pipelineTriple(invoc.getTargetOpts().triple);
config.SkipConvertComplexPow = pipelineTriple.isAMDGCN();
fir::registerDefaultInlinerPass(config);
if (auto vsr = getVScaleRange(ci)) {

View File

@ -1327,18 +1327,18 @@ mlir::Value genComplexPow(fir::FirOpBuilder &builder, mlir::Location loc,
const MathOperation &mathOp,
mlir::FunctionType mathLibFuncType,
llvm::ArrayRef<mlir::Value> args) {
bool isAMDGPU = fir::getTargetTriple(builder.getModule()).isAMDGCN();
if (!isAMDGPU)
if (mathRuntimeVersion == preciseVersion)
return genLibCall(builder, loc, mathOp, mathLibFuncType, args);
auto complexTy = mlir::cast<mlir::ComplexType>(mathLibFuncType.getInput(0));
auto realTy = complexTy.getElementType();
mlir::Value realExp = builder.createConvert(loc, realTy, args[1]);
mlir::Value zero = builder.createRealConstant(loc, realTy, 0);
mlir::Value complexExp =
builder.create<mlir::complex::CreateOp>(loc, complexTy, realExp, zero);
mlir::Value result =
builder.create<mlir::complex::PowOp>(loc, args[0], complexExp);
mlir::Value exp = args[1];
if (!mlir::isa<mlir::ComplexType>(exp.getType())) {
auto realTy = complexTy.getElementType();
mlir::Value realExp = builder.createConvert(loc, realTy, exp);
mlir::Value zero = builder.createRealConstant(loc, realTy, 0);
exp =
builder.create<mlir::complex::CreateOp>(loc, complexTy, realExp, zero);
}
mlir::Value result = builder.create<mlir::complex::PowOp>(loc, args[0], exp);
result = builder.createConvert(loc, mathLibFuncType.getResult(0), result);
return result;
}
@ -1668,11 +1668,11 @@ static constexpr MathOperation mathOperations[] = {
{"pow", RTNAME_STRING(PowF128), FuncTypeReal16Real16Real16, genLibF128Call},
{"pow", "cpowf",
genFuncType<Ty::Complex<4>, Ty::Complex<4>, Ty::Complex<4>>,
genComplexMathOp<mlir::complex::PowOp>},
genComplexPow},
{"pow", "cpow", genFuncType<Ty::Complex<8>, Ty::Complex<8>, Ty::Complex<8>>,
genComplexMathOp<mlir::complex::PowOp>},
genComplexPow},
{"pow", RTNAME_STRING(CPowF128), FuncTypeComplex16Complex16Complex16,
genLibF128Call},
genComplexPow},
{"pow", RTNAME_STRING(FPow4i),
genFuncType<Ty::Real<4>, Ty::Real<4>, Ty::Integer<4>>,
genMathOp<mlir::math::FPowIOp>},
@ -1698,7 +1698,7 @@ static constexpr MathOperation mathOperations[] = {
genFuncType<Ty::Complex<8>, Ty::Complex<8>, Ty::Integer<4>>,
genComplexPow},
{"pow", RTNAME_STRING(cqpowi), FuncTypeComplex16Complex16Integer4,
genLibF128Call},
genComplexPow},
{"pow", RTNAME_STRING(cpowk),
genFuncType<Ty::Complex<4>, Ty::Complex<4>, Ty::Integer<8>>,
genComplexPow},
@ -1706,7 +1706,7 @@ static constexpr MathOperation mathOperations[] = {
genFuncType<Ty::Complex<8>, Ty::Complex<8>, Ty::Integer<8>>,
genComplexPow},
{"pow", RTNAME_STRING(cqpowk), FuncTypeComplex16Complex16Integer8,
genLibF128Call},
genComplexPow},
{"pow-unsigned", RTNAME_STRING(UPow1),
genFuncType<Ty::Integer<1>, Ty::Integer<1>, Ty::Integer<1>>, genLibCall},
{"pow-unsigned", RTNAME_STRING(UPow2),

View File

@ -226,6 +226,8 @@ void createDefaultFIROptimizerPassPipeline(mlir::PassManager &pm,
pm.addPass(mlir::createCanonicalizerPass(config));
pm.addPass(fir::createSimplifyRegionLite());
if (!pc.SkipConvertComplexPow)
pm.addPass(fir::createConvertComplexPow());
pm.addPass(mlir::createCSEPass());
if (pc.OptLevel.isOptimizingForSpeed())

View File

@ -35,6 +35,7 @@ add_flang_library(FIRTransforms
GenRuntimeCallsForTest.cpp
SimplifyFIROperations.cpp
OptimizeArrayRepacking.cpp
ConvertComplexPow.cpp
DEPENDS
CUFAttrs

View File

@ -0,0 +1,123 @@
//===- ConvertComplexPow.cpp - Convert complex.pow to library calls -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "flang/Common/static-multimap-view.h"
#include "flang/Optimizer/Builder/FIRBuilder.h"
#include "flang/Optimizer/Dialect/FIRDialect.h"
#include "flang/Optimizer/Transforms/Passes.h"
#include "flang/Runtime/entry-names.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Pass/Pass.h"
namespace fir {
#define GEN_PASS_DEF_CONVERTCOMPLEXPOW
#include "flang/Optimizer/Transforms/Passes.h.inc"
} // namespace fir
using namespace mlir;
namespace {
/// Module-level pass that rewrites `complex.pow` operations into `fir.call`
/// operations (the rewrite itself lives in runOnOperation()).
class ConvertComplexPowPass
    : public fir::impl::ConvertComplexPowBase<ConvertComplexPowPass> {
public:
  /// Register every dialect this pass depends on: FIR, complex, arith and
  /// func.
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<fir::FIROpsDialect>();
    registry.insert<complex::ComplexDialect>();
    registry.insert<arith::ArithDialect>();
    registry.insert<func::FuncDialect>();
  }

  /// Walk the module and replace each supported `complex.pow`.
  void runOnOperation() override;
};
} // namespace
/// Return the function called \p name as reported by the builder, creating a
/// new declaration of type \p type at \p loc when none is found.
///
/// A freshly created declaration is tagged with the FIR symbol attribute
/// (set to \p name) and with the FIR runtime unit attribute. An already
/// existing function is returned as-is; its type is not compared with
/// \p type.
static func::FuncOp getOrDeclare(fir::FirOpBuilder &builder, Location loc,
                                 StringRef name, FunctionType type) {
  func::FuncOp decl = builder.getNamedFunction(name);
  if (!decl) {
    decl = builder.createFunction(loc, name, type);
    decl->setAttr(fir::getSymbolAttrName(), builder.getStringAttr(name));
    decl->setAttr(fir::FIROpsDialect::getFirRuntimeAttrName(),
                  builder.getUnitAttr());
  }
  return decl;
}
/// Return true only when \p v is produced by an `arith.constant` whose value
/// is a float attribute that reports itself as zero. Any other producer, or
/// a non-float constant, yields false.
static bool isZero(Value v) {
  auto constant = v.getDefiningOp<arith::ConstantOp>();
  if (!constant)
    return false;
  auto floatAttr = dyn_cast<FloatAttr>(constant.getValue());
  return floatAttr && floatAttr.getValue().isZero();
}
void ConvertComplexPowPass::runOnOperation() {
ModuleOp mod = getOperation();
fir::FirOpBuilder builder(mod, fir::getKindMapping(mod));
mod.walk([&](complex::PowOp op) {
builder.setInsertionPoint(op);
Location loc = op.getLoc();
auto complexTy = cast<ComplexType>(op.getType());
auto elemTy = complexTy.getElementType();
Value base = op.getLhs();
Value rhs = op.getRhs();
Value intExp;
if (auto create = rhs.getDefiningOp<complex::CreateOp>()) {
if (isZero(create.getImaginary())) {
if (auto conv = create.getReal().getDefiningOp<fir::ConvertOp>()) {
if (auto intTy = dyn_cast<IntegerType>(conv.getValue().getType()))
intExp = conv.getValue();
}
}
}
func::FuncOp callee;
SmallVector<Value> args;
if (intExp) {
unsigned realBits = cast<FloatType>(elemTy).getWidth();
unsigned intBits = cast<IntegerType>(intExp.getType()).getWidth();
auto funcTy = builder.getFunctionType(
{complexTy, builder.getIntegerType(intBits)}, {complexTy});
if (realBits == 32 && intBits == 32)
callee = getOrDeclare(builder, loc, RTNAME_STRING(cpowi), funcTy);
else if (realBits == 32 && intBits == 64)
callee = getOrDeclare(builder, loc, RTNAME_STRING(cpowk), funcTy);
else if (realBits == 64 && intBits == 32)
callee = getOrDeclare(builder, loc, RTNAME_STRING(zpowi), funcTy);
else if (realBits == 64 && intBits == 64)
callee = getOrDeclare(builder, loc, RTNAME_STRING(zpowk), funcTy);
else if (realBits == 128 && intBits == 32)
callee = getOrDeclare(builder, loc, RTNAME_STRING(cqpowi), funcTy);
else if (realBits == 128 && intBits == 64)
callee = getOrDeclare(builder, loc, RTNAME_STRING(cqpowk), funcTy);
else
return;
args = {base, intExp};
} else {
unsigned realBits = cast<FloatType>(elemTy).getWidth();
auto funcTy =
builder.getFunctionType({complexTy, complexTy}, {complexTy});
if (realBits == 32)
callee = getOrDeclare(builder, loc, "cpowf", funcTy);
else if (realBits == 64)
callee = getOrDeclare(builder, loc, "cpow", funcTy);
else if (realBits == 128)
callee = getOrDeclare(builder, loc, RTNAME_STRING(CPowF128), funcTy);
else
return;
args = {base, rhs};
}
auto call = fir::CallOp::create(builder, loc, callee, args);
if (auto fmf = op.getFastmathAttr())
call.setFastmathAttr(fmf);
op.replaceAllUsesWith(call.getResult(0));
op.erase();
});
}

View File

@ -69,6 +69,7 @@ end program
! CHECK-NEXT: SCFToControlFlow
! CHECK-NEXT: Canonicalizer
! CHECK-NEXT: SimplifyRegionLite
! CHECK-NEXT: ConvertComplexPow
! CHECK-NEXT: CSE
! CHECK-NEXT: (S) 0 num-cse'd - Number of operations CSE'd
! CHECK-NEXT: (S) 0 num-dce'd - Number of operations DCE'd

View File

@ -96,6 +96,7 @@ end program
! ALL-NEXT: SCFToControlFlow
! ALL-NEXT: Canonicalizer
! ALL-NEXT: SimplifyRegionLite
! ALL-NEXT: ConvertComplexPow
! ALL-NEXT: CSE
! ALL-NEXT: (S) 0 num-cse'd - Number of operations CSE'd
! ALL-NEXT: (S) 0 num-dce'd - Number of operations DCE'd

View File

@ -127,6 +127,7 @@ end program
! ALL-NEXT: SCFToControlFlow
! ALL-NEXT: Canonicalizer
! ALL-NEXT: SimplifyRegionLite
! ALL-NEXT: ConvertComplexPow
! ALL-NEXT: CSE
! ALL-NEXT: (S) 0 num-cse'd - Number of operations CSE'd
! ALL-NEXT: (S) 0 num-dce'd - Number of operations DCE'd

View File

@ -125,6 +125,7 @@ func.func @_QQmain() {
// PASSES-NEXT: SCFToControlFlow
// PASSES-NEXT: Canonicalizer
// PASSES-NEXT: SimplifyRegionLite
// PASSES-NEXT: ConvertComplexPow
// PASSES-NEXT: CSE
// PASSES-NEXT: (S) 0 num-cse'd - Number of operations CSE'd
// PASSES-NEXT: (S) 0 num-dce'd - Number of operations DCE'd

View File

@ -168,7 +168,7 @@ end subroutine
! CHECK: %[[VAL_5:.*]]:2 = hlfir.declare %{{.*}}z"} : (!fir.ref<complex<f32>>, !fir.dscope) -> (!fir.ref<complex<f32>>, !fir.ref<complex<f32>>)
! CHECK: %[[VAL_6:.*]] = fir.load %[[VAL_4]]#0 : !fir.ref<complex<f32>>
! CHECK: %[[VAL_7:.*]] = fir.load %[[VAL_5]]#0 : !fir.ref<complex<f32>>
! CHECK: %[[VAL_8:.*]] = fir.call @cpowf(%[[VAL_6]], %[[VAL_7]]) fastmath<contract> : (complex<f32>, complex<f32>) -> complex<f32>
! CHECK: %[[VAL_8:.*]] = complex.pow %[[VAL_6]], %[[VAL_7]] fastmath<contract> : complex<f32>
subroutine real_to_int_power(x, y, z)
@ -193,7 +193,7 @@ end subroutine
! CHECK: %[[VAL_5:.*]]:2 = hlfir.declare %{{.*}}z"} : (!fir.ref<i32>, !fir.dscope) -> (!fir.ref<i32>, !fir.ref<i32>)
! CHECK: %[[VAL_6:.*]] = fir.load %[[VAL_4]]#0 : !fir.ref<complex<f32>>
! CHECK: %[[VAL_7:.*]] = fir.load %[[VAL_5]]#0 : !fir.ref<i32>
! CHECK: %[[VAL_8:.*]] = fir.call @_FortranAcpowi(%[[VAL_6]], %[[VAL_7]]) fastmath<contract> : (complex<f32>, i32) -> complex<f32>
! CHECK: %[[VAL_8:.*]] = complex.pow
subroutine extremum(c, n, l)
integer(8), intent(in) :: l

View File

@ -1,9 +1,10 @@
! REQUIRES: flang-supports-f128-math
! RUN: bbc -emit-fir %s -o - | FileCheck %s
! RUN: bbc --math-runtime=precise -emit-fir %s -o - | FileCheck %s
! RUN: bbc --math-runtime=precise -emit-fir %s -o - | FileCheck %s --check-prefixes="PRECISE"
! RUN: %flang_fc1 -emit-fir %s -o - | FileCheck %s
! CHECK: fir.call @_FortranACPowF128({{.*}}){{.*}}: (complex<f128>, complex<f128>) -> complex<f128>
! PRECISE: fir.call @_FortranACPowF128({{.*}}){{.*}}: (complex<f128>, complex<f128>) -> complex<f128>
! CHECK: complex.pow %{{.*}}, %{{.*}} fastmath<contract> : complex<f128>
complex(16) :: a, b
b = a ** b
end

View File

@ -1,9 +1,10 @@
! REQUIRES: flang-supports-f128-math
! RUN: bbc -emit-fir %s -o - | FileCheck %s
! RUN: bbc --math-runtime=precise -emit-fir %s -o - | FileCheck %s
! RUN: bbc --math-runtime=precise -emit-fir %s -o - | FileCheck %s --check-prefixes="PRECISE"
! RUN: %flang_fc1 -emit-fir %s -o - | FileCheck %s
! CHECK: fir.call @_FortranAcqpowi({{.*}}){{.*}}: (complex<f128>, i32) -> complex<f128>
! PRECISE: fir.call @_FortranAcqpowi({{.*}}){{.*}}: (complex<f128>, i32) -> complex<f128>
! CHECK: complex.pow %{{.*}}, %{{.*}} fastmath<contract> : complex<f128>
complex(16) :: a
integer(4) :: b
b = a ** b

View File

@ -1,9 +1,10 @@
! REQUIRES: flang-supports-f128-math
! RUN: bbc -emit-fir %s -o - | FileCheck %s
! RUN: bbc --math-runtime=precise -emit-fir %s -o - | FileCheck %s
! RUN: bbc --math-runtime=precise -emit-fir %s -o - | FileCheck %s --check-prefixes="PRECISE"
! RUN: %flang_fc1 -emit-fir %s -o - | FileCheck %s
! CHECK: fir.call @_FortranAcqpowk({{.*}}){{.*}}: (complex<f128>, i64) -> complex<f128>
! PRECISE: fir.call @_FortranAcqpowk({{.*}}){{.*}}: (complex<f128>, i64) -> complex<f128>
! CHECK: complex.pow %{{.*}}, %{{.*}} fastmath<contract> : complex<f128>
complex(16) :: a
integer(8) :: b
b = a ** b

View File

@ -1,10 +1,10 @@
! RUN: bbc -emit-fir %s -o - | FileCheck %s --check-prefixes="CHECK,PRECISE"
! RUN: bbc --math-runtime=precise -emit-fir %s -o - | FileCheck %s --check-prefixes="PRECISE"
! RUN: bbc --force-mlir-complex -emit-fir %s -o - | FileCheck %s --check-prefixes="FAST"
! RUN: %flang_fc1 -emit-fir %s -o - | FileCheck %s --check-prefixes="CHECK,PRECISE"
! RUN: %flang_fc1 -fapprox-func -emit-fir %s -o - | FileCheck %s --check-prefixes="CHECK,FAST"
! RUN: %flang_fc1 -emit-fir -mllvm --math-runtime=precise %s -o - | FileCheck %s --check-prefixes="PRECISE"
! RUN: %flang_fc1 -emit-fir -mllvm --force-mlir-complex %s -o - | FileCheck %s --check-prefixes="FAST"
! RUN: bbc -emit-fir %s -o - | FileCheck %s
! RUN: bbc --math-runtime=precise -emit-fir %s -o - | FileCheck %s --check-prefix=PRECISE
! RUN: bbc --force-mlir-complex -emit-fir %s -o - | FileCheck %s
! RUN: %flang_fc1 -emit-fir %s -o - | FileCheck %s
! RUN: %flang_fc1 -fapprox-func -emit-fir %s -o - | FileCheck %s
! RUN: %flang_fc1 -emit-fir -mllvm --math-runtime=precise %s -o - | FileCheck %s --check-prefix=PRECISE
! RUN: %flang_fc1 -emit-fir -mllvm --force-mlir-complex %s -o - | FileCheck %s
! Test power operation lowering
@ -96,7 +96,8 @@ subroutine pow_c4_i4(x, y, z)
complex :: x, z
integer :: y
z = x ** y
! CHECK: call @_FortranAcpowi
! CHECK: complex.pow
! PRECISE: fir.call @_FortranAcpowi
end subroutine
! CHECK-LABEL: pow_c4_i8
@ -104,7 +105,8 @@ subroutine pow_c4_i8(x, y, z)
complex :: x, z
integer(8) :: y
z = x ** y
! CHECK: call @_FortranAcpowk
! CHECK: complex.pow
! PRECISE: fir.call @_FortranAcpowk
end subroutine
! CHECK-LABEL: pow_c8_i4
@ -112,7 +114,8 @@ subroutine pow_c8_i4(x, y, z)
complex(8) :: x, z
integer :: y
z = x ** y
! CHECK: call @_FortranAzpowi
! CHECK: complex.pow
! PRECISE: fir.call @_FortranAzpowi
end subroutine
! CHECK-LABEL: pow_c8_i8
@ -120,22 +123,23 @@ subroutine pow_c8_i8(x, y, z)
complex(8) :: x, z
integer(8) :: y
z = x ** y
! CHECK: call @_FortranAzpowk
! CHECK: complex.pow
! PRECISE: fir.call @_FortranAzpowk
end subroutine
! CHECK-LABEL: pow_c4_c4
subroutine pow_c4_c4(x, y, z)
complex :: x, y, z
z = x ** y
! FAST: complex.pow %{{.*}}, %{{.*}} : complex<f32>
! PRECISE: call @cpowf
! CHECK: complex.pow %{{.*}}, %{{.*}} : complex<f32>
! PRECISE: fir.call @cpowf
end subroutine
! CHECK-LABEL: pow_c8_c8
subroutine pow_c8_c8(x, y, z)
complex(8) :: x, y, z
z = x ** y
! FAST: complex.pow %{{.*}}, %{{.*}} : complex<f64>
! PRECISE: call @cpow
! CHECK: complex.pow %{{.*}}, %{{.*}} : complex<f64>
! PRECISE: fir.call @cpow
end subroutine

View File

@ -0,0 +1,111 @@
// RUN: fir-opt --convert-complex-pow %s | FileCheck %s
// Inputs for the convert-complex-pow pass. The first six functions build the
// exponent as complex.create(fir.convert(<integer>), 0.0); the last three use
// an exponent with a non-zero constant imaginary part.
module {
// complex<f32> base, i32 exponent converted to f32, zero imaginary part.
func.func @pow_c4_i4(%arg0: complex<f32>, %arg1: i32) -> complex<f32> {
%c0 = arith.constant 0.0 : f32
%0 = fir.convert %arg1 : (i32) -> f32
%1 = complex.create %0, %c0 : complex<f32>
%2 = complex.pow %arg0, %1 : complex<f32>
return %2 : complex<f32>
}
// complex<f32> base, i64 exponent converted to f32, zero imaginary part.
func.func @pow_c4_i8(%arg0: complex<f32>, %arg1: i64) -> complex<f32> {
%c0 = arith.constant 0.0 : f32
%0 = fir.convert %arg1 : (i64) -> f32
%1 = complex.create %0, %c0 : complex<f32>
%2 = complex.pow %arg0, %1 : complex<f32>
return %2 : complex<f32>
}
// complex<f64> base, i32 exponent converted to f64, zero imaginary part.
func.func @pow_c8_i4(%arg0: complex<f64>, %arg1: i32) -> complex<f64> {
%c0 = arith.constant 0.0 : f64
%0 = fir.convert %arg1 : (i32) -> f64
%1 = complex.create %0, %c0 : complex<f64>
%2 = complex.pow %arg0, %1 : complex<f64>
return %2 : complex<f64>
}
// complex<f64> base, i64 exponent converted to f64, zero imaginary part.
func.func @pow_c8_i8(%arg0: complex<f64>, %arg1: i64) -> complex<f64> {
%c0 = arith.constant 0.0 : f64
%0 = fir.convert %arg1 : (i64) -> f64
%1 = complex.create %0, %c0 : complex<f64>
%2 = complex.pow %arg0, %1 : complex<f64>
return %2 : complex<f64>
}
// complex<f128> base, i32 exponent converted to f128, zero imaginary part.
func.func @pow_c16_i4(%arg0: complex<f128>, %arg1: i32) -> complex<f128> {
%c0 = arith.constant 0.0 : f128
%0 = fir.convert %arg1 : (i32) -> f128
%1 = complex.create %0, %c0 : complex<f128>
%2 = complex.pow %arg0, %1 : complex<f128>
return %2 : complex<f128>
}
// complex<f128> base, i64 exponent converted to f128, zero imaginary part.
func.func @pow_c16_i8(%arg0: complex<f128>, %arg1: i64) -> complex<f128> {
%c0 = arith.constant 0.0 : f128
%0 = fir.convert %arg1 : (i64) -> f128
%1 = complex.create %0, %c0 : complex<f128>
%2 = complex.pow %arg0, %1 : complex<f128>
return %2 : complex<f128>
}
// complex<f32> exponent with imaginary part 1.0; the pow carries
// fastmath<fast>.
func.func @pow_c4_fast(%arg0: complex<f32>, %arg1: f32) -> complex<f32> {
%c1 = arith.constant 1.0 : f32
%0 = complex.create %arg1, %c1 : complex<f32>
%1 = complex.pow %arg0, %0 fastmath<fast> : complex<f32>
return %1 : complex<f32>
}
// complex<f64> exponent with imaginary part 2.0.
func.func @pow_c8_complex(%arg0: complex<f64>, %arg1: f64) -> complex<f64> {
%c2 = arith.constant 2.0 : f64
%0 = complex.create %arg1, %c2 : complex<f64>
%1 = complex.pow %arg0, %0 : complex<f64>
return %1 : complex<f64>
}
// complex<f128> exponent with imaginary part 3.0.
func.func @pow_c16_complex(%arg0: complex<f128>, %arg1: f128) -> complex<f128> {
%c3 = arith.constant 3.0 : f128
%0 = complex.create %arg1, %c3 : complex<f128>
%1 = complex.pow %arg0, %0 : complex<f128>
return %1 : complex<f128>
}
}
// CHECK-LABEL: func.func @pow_c4_i4(
// CHECK: fir.call @_FortranAcpowi(%{{.*}}, %{{.*}}) : (complex<f32>, i32) -> complex<f32>
// CHECK-NOT: complex.pow
// CHECK-LABEL: func.func @pow_c4_i8(
// CHECK: fir.call @_FortranAcpowk(%{{.*}}, %{{.*}}) : (complex<f32>, i64) -> complex<f32>
// CHECK-NOT: complex.pow
// CHECK-LABEL: func.func @pow_c8_i4(
// CHECK: fir.call @_FortranAzpowi(%{{.*}}, %{{.*}}) : (complex<f64>, i32) -> complex<f64>
// CHECK-NOT: complex.pow
// CHECK-LABEL: func.func @pow_c8_i8(
// CHECK: fir.call @_FortranAzpowk(%{{.*}}, %{{.*}}) : (complex<f64>, i64) -> complex<f64>
// CHECK-NOT: complex.pow
// CHECK-LABEL: func.func @pow_c16_i4(
// CHECK: fir.call @_FortranAcqpowi(%{{.*}}, %{{.*}}) : (complex<f128>, i32) -> complex<f128>
// CHECK-NOT: complex.pow
// CHECK-LABEL: func.func @pow_c16_i8(
// CHECK: fir.call @_FortranAcqpowk(%{{.*}}, %{{.*}}) : (complex<f128>, i64) -> complex<f128>
// CHECK-NOT: complex.pow
// CHECK-LABEL: func.func @pow_c4_fast(
// CHECK: %[[EXP:.*]] = complex.create %{{.*}}, %{{.*}} : complex<f32>
// CHECK: fir.call @cpowf(%{{.*}}, %[[EXP]]) fastmath<fast> : (complex<f32>, complex<f32>) -> complex<f32>
// CHECK-NOT: complex.pow
// CHECK-LABEL: func.func @pow_c8_complex(
// CHECK: %[[EXP:.*]] = complex.create %{{.*}}, %{{.*}} : complex<f64>
// CHECK: fir.call @cpow(%{{.*}}, %[[EXP]]) : (complex<f64>, complex<f64>) -> complex<f64>
// CHECK-NOT: complex.pow
// CHECK-LABEL: func.func @pow_c16_complex(
// CHECK: %[[EXP:.*]] = complex.create %{{.*}}, %{{.*}} : complex<f128>
// CHECK: fir.call @_FortranACPowF128(%{{.*}}, %[[EXP]]) : (complex<f128>, complex<f128>) -> complex<f128>
// CHECK-NOT: complex.pow

View File

@ -538,6 +538,7 @@ static llvm::LogicalResult convertFortranSourceToMLIR(
// Add O2 optimizer pass pipeline.
MLIRToLLVMPassPipelineConfig config(llvm::OptimizationLevel::O2);
config.SkipConvertComplexPow = targetMachine.getTargetTriple().isAMDGCN();
if (enableOpenMP)
config.EnableOpenMP = true;
config.NSWOnLoopVarInc = !integerWrapAround;

View File

@ -64,9 +64,12 @@ struct PowOpToROCDLLibraryCalls : public OpRewritePattern<complex::PowOp> {
LogicalResult matchAndRewrite(complex::PowOp op,
PatternRewriter &rewriter) const final {
Location loc = op.getLoc();
Value logBase = complex::LogOp::create(rewriter, loc, op.getLhs());
Value mul = complex::MulOp::create(rewriter, loc, op.getRhs(), logBase);
Value exp = complex::ExpOp::create(rewriter, loc, mul);
auto fastmath = op.getFastmathAttr();
Value logBase =
complex::LogOp::create(rewriter, loc, op.getLhs(), fastmath);
Value mul =
complex::MulOp::create(rewriter, loc, op.getRhs(), logBase, fastmath);
Value exp = complex::ExpOp::create(rewriter, loc, mul, fastmath);
rewriter.replaceOp(op, exp);
return success();
}