[Flang] Add new ConvertComplexPow pass for Flang (#158642)
This PR introduces a new `ConvertComplexPow` pass for Flang that handles complex power operations. The change forces lowering to complex.pow operations when `--math-runtime=precise` is not used, then uses the `ConvertComplexPow` pass to convert these operations back to library calls. - Adds a new `ConvertComplexPow` pass that converts complex.pow ops to appropriate runtime library calls - Updates complex power lowering to use `complex.pow` operations by default instead of direct library calls #158722 Adds a new `complex.powi` op enabling algebraic optimisations.
This commit is contained in:
parent
01fca01d3b
commit
54677d66c4
@ -555,6 +555,17 @@ def SimplifyFIROperations : Pass<"simplify-fir-operations", "mlir::ModuleOp"> {
|
||||
"Prefer expanding without using Fortran runtime calls.">];
|
||||
}
|
||||
|
||||
def ConvertComplexPow : Pass<"convert-complex-pow", "mlir::ModuleOp"> {
|
||||
let summary = "Convert complex.pow operations to library calls";
|
||||
let description = [{
|
||||
Replace `complex.pow` operations with calls to the appropriate
|
||||
Fortran runtime or libm functions.
|
||||
}];
|
||||
let dependentDialects = ["fir::FIROpsDialect", "mlir::func::FuncDialect",
|
||||
"mlir::complex::ComplexDialect",
|
||||
"mlir::arith::ArithDialect"];
|
||||
}
|
||||
|
||||
def OptimizeArrayRepacking
|
||||
: Pass<"optimize-array-repacking", "mlir::func::FuncOp"> {
|
||||
let summary = "Optimizes redundant array repacking operations";
|
||||
|
||||
@ -135,6 +135,7 @@ struct MLIRToLLVMPassPipelineConfig : public FlangEPCallBacks {
|
||||
bool NSWOnLoopVarInc = true; ///< Add nsw flag to loop variable increments.
|
||||
bool EnableOpenMP = false; ///< Enable OpenMP lowering.
|
||||
bool EnableOpenMPSimd = false; ///< Enable OpenMP simd-only mode.
|
||||
bool SkipConvertComplexPow = false; ///< Do not run complex pow conversion.
|
||||
std::string InstrumentFunctionEntry =
|
||||
""; ///< Name of the instrument-function that is called on each
|
||||
///< function-entry
|
||||
|
||||
@ -738,6 +738,8 @@ void CodeGenAction::generateLLVMIR() {
|
||||
pm.enableVerifier(/*verifyPasses=*/true);
|
||||
|
||||
MLIRToLLVMPassPipelineConfig config(level, opts, mathOpts);
|
||||
llvm::Triple pipelineTriple(invoc.getTargetOpts().triple);
|
||||
config.SkipConvertComplexPow = pipelineTriple.isAMDGCN();
|
||||
fir::registerDefaultInlinerPass(config);
|
||||
|
||||
if (auto vsr = getVScaleRange(ci)) {
|
||||
|
||||
@ -1327,18 +1327,18 @@ mlir::Value genComplexPow(fir::FirOpBuilder &builder, mlir::Location loc,
|
||||
const MathOperation &mathOp,
|
||||
mlir::FunctionType mathLibFuncType,
|
||||
llvm::ArrayRef<mlir::Value> args) {
|
||||
bool isAMDGPU = fir::getTargetTriple(builder.getModule()).isAMDGCN();
|
||||
if (!isAMDGPU)
|
||||
if (mathRuntimeVersion == preciseVersion)
|
||||
return genLibCall(builder, loc, mathOp, mathLibFuncType, args);
|
||||
|
||||
auto complexTy = mlir::cast<mlir::ComplexType>(mathLibFuncType.getInput(0));
|
||||
auto realTy = complexTy.getElementType();
|
||||
mlir::Value realExp = builder.createConvert(loc, realTy, args[1]);
|
||||
mlir::Value zero = builder.createRealConstant(loc, realTy, 0);
|
||||
mlir::Value complexExp =
|
||||
builder.create<mlir::complex::CreateOp>(loc, complexTy, realExp, zero);
|
||||
mlir::Value result =
|
||||
builder.create<mlir::complex::PowOp>(loc, args[0], complexExp);
|
||||
mlir::Value exp = args[1];
|
||||
if (!mlir::isa<mlir::ComplexType>(exp.getType())) {
|
||||
auto realTy = complexTy.getElementType();
|
||||
mlir::Value realExp = builder.createConvert(loc, realTy, exp);
|
||||
mlir::Value zero = builder.createRealConstant(loc, realTy, 0);
|
||||
exp =
|
||||
builder.create<mlir::complex::CreateOp>(loc, complexTy, realExp, zero);
|
||||
}
|
||||
mlir::Value result = builder.create<mlir::complex::PowOp>(loc, args[0], exp);
|
||||
result = builder.createConvert(loc, mathLibFuncType.getResult(0), result);
|
||||
return result;
|
||||
}
|
||||
@ -1668,11 +1668,11 @@ static constexpr MathOperation mathOperations[] = {
|
||||
{"pow", RTNAME_STRING(PowF128), FuncTypeReal16Real16Real16, genLibF128Call},
|
||||
{"pow", "cpowf",
|
||||
genFuncType<Ty::Complex<4>, Ty::Complex<4>, Ty::Complex<4>>,
|
||||
genComplexMathOp<mlir::complex::PowOp>},
|
||||
genComplexPow},
|
||||
{"pow", "cpow", genFuncType<Ty::Complex<8>, Ty::Complex<8>, Ty::Complex<8>>,
|
||||
genComplexMathOp<mlir::complex::PowOp>},
|
||||
genComplexPow},
|
||||
{"pow", RTNAME_STRING(CPowF128), FuncTypeComplex16Complex16Complex16,
|
||||
genLibF128Call},
|
||||
genComplexPow},
|
||||
{"pow", RTNAME_STRING(FPow4i),
|
||||
genFuncType<Ty::Real<4>, Ty::Real<4>, Ty::Integer<4>>,
|
||||
genMathOp<mlir::math::FPowIOp>},
|
||||
@ -1698,7 +1698,7 @@ static constexpr MathOperation mathOperations[] = {
|
||||
genFuncType<Ty::Complex<8>, Ty::Complex<8>, Ty::Integer<4>>,
|
||||
genComplexPow},
|
||||
{"pow", RTNAME_STRING(cqpowi), FuncTypeComplex16Complex16Integer4,
|
||||
genLibF128Call},
|
||||
genComplexPow},
|
||||
{"pow", RTNAME_STRING(cpowk),
|
||||
genFuncType<Ty::Complex<4>, Ty::Complex<4>, Ty::Integer<8>>,
|
||||
genComplexPow},
|
||||
@ -1706,7 +1706,7 @@ static constexpr MathOperation mathOperations[] = {
|
||||
genFuncType<Ty::Complex<8>, Ty::Complex<8>, Ty::Integer<8>>,
|
||||
genComplexPow},
|
||||
{"pow", RTNAME_STRING(cqpowk), FuncTypeComplex16Complex16Integer8,
|
||||
genLibF128Call},
|
||||
genComplexPow},
|
||||
{"pow-unsigned", RTNAME_STRING(UPow1),
|
||||
genFuncType<Ty::Integer<1>, Ty::Integer<1>, Ty::Integer<1>>, genLibCall},
|
||||
{"pow-unsigned", RTNAME_STRING(UPow2),
|
||||
|
||||
@ -226,6 +226,8 @@ void createDefaultFIROptimizerPassPipeline(mlir::PassManager &pm,
|
||||
|
||||
pm.addPass(mlir::createCanonicalizerPass(config));
|
||||
pm.addPass(fir::createSimplifyRegionLite());
|
||||
if (!pc.SkipConvertComplexPow)
|
||||
pm.addPass(fir::createConvertComplexPow());
|
||||
pm.addPass(mlir::createCSEPass());
|
||||
|
||||
if (pc.OptLevel.isOptimizingForSpeed())
|
||||
|
||||
@ -35,6 +35,7 @@ add_flang_library(FIRTransforms
|
||||
GenRuntimeCallsForTest.cpp
|
||||
SimplifyFIROperations.cpp
|
||||
OptimizeArrayRepacking.cpp
|
||||
ConvertComplexPow.cpp
|
||||
|
||||
DEPENDS
|
||||
CUFAttrs
|
||||
|
||||
123
flang/lib/Optimizer/Transforms/ConvertComplexPow.cpp
Normal file
123
flang/lib/Optimizer/Transforms/ConvertComplexPow.cpp
Normal file
@ -0,0 +1,123 @@
|
||||
//===- ConvertComplexPow.cpp - Convert complex.pow to library calls -------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "flang/Common/static-multimap-view.h"
|
||||
#include "flang/Optimizer/Builder/FIRBuilder.h"
|
||||
#include "flang/Optimizer/Dialect/FIRDialect.h"
|
||||
#include "flang/Optimizer/Transforms/Passes.h"
|
||||
#include "flang/Runtime/entry-names.h"
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/Complex/IR/Complex.h"
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
|
||||
namespace fir {
|
||||
#define GEN_PASS_DEF_CONVERTCOMPLEXPOW
|
||||
#include "flang/Optimizer/Transforms/Passes.h.inc"
|
||||
} // namespace fir
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace {
|
||||
class ConvertComplexPowPass
|
||||
: public fir::impl::ConvertComplexPowBase<ConvertComplexPowPass> {
|
||||
public:
|
||||
void getDependentDialects(DialectRegistry ®istry) const override {
|
||||
registry.insert<fir::FIROpsDialect, complex::ComplexDialect,
|
||||
arith::ArithDialect, func::FuncDialect>();
|
||||
}
|
||||
void runOnOperation() override;
|
||||
};
|
||||
} // namespace
|
||||
|
||||
// Helper to declare or get a math library function.
|
||||
static func::FuncOp getOrDeclare(fir::FirOpBuilder &builder, Location loc,
|
||||
StringRef name, FunctionType type) {
|
||||
if (auto func = builder.getNamedFunction(name))
|
||||
return func;
|
||||
auto func = builder.createFunction(loc, name, type);
|
||||
func->setAttr(fir::getSymbolAttrName(), builder.getStringAttr(name));
|
||||
func->setAttr(fir::FIROpsDialect::getFirRuntimeAttrName(),
|
||||
builder.getUnitAttr());
|
||||
return func;
|
||||
}
|
||||
|
||||
static bool isZero(Value v) {
|
||||
if (auto cst = v.getDefiningOp<arith::ConstantOp>())
|
||||
if (auto attr = dyn_cast<FloatAttr>(cst.getValue()))
|
||||
return attr.getValue().isZero();
|
||||
return false;
|
||||
}
|
||||
|
||||
void ConvertComplexPowPass::runOnOperation() {
|
||||
ModuleOp mod = getOperation();
|
||||
fir::FirOpBuilder builder(mod, fir::getKindMapping(mod));
|
||||
|
||||
mod.walk([&](complex::PowOp op) {
|
||||
builder.setInsertionPoint(op);
|
||||
Location loc = op.getLoc();
|
||||
auto complexTy = cast<ComplexType>(op.getType());
|
||||
auto elemTy = complexTy.getElementType();
|
||||
|
||||
Value base = op.getLhs();
|
||||
Value rhs = op.getRhs();
|
||||
|
||||
Value intExp;
|
||||
if (auto create = rhs.getDefiningOp<complex::CreateOp>()) {
|
||||
if (isZero(create.getImaginary())) {
|
||||
if (auto conv = create.getReal().getDefiningOp<fir::ConvertOp>()) {
|
||||
if (auto intTy = dyn_cast<IntegerType>(conv.getValue().getType()))
|
||||
intExp = conv.getValue();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func::FuncOp callee;
|
||||
SmallVector<Value> args;
|
||||
if (intExp) {
|
||||
unsigned realBits = cast<FloatType>(elemTy).getWidth();
|
||||
unsigned intBits = cast<IntegerType>(intExp.getType()).getWidth();
|
||||
auto funcTy = builder.getFunctionType(
|
||||
{complexTy, builder.getIntegerType(intBits)}, {complexTy});
|
||||
if (realBits == 32 && intBits == 32)
|
||||
callee = getOrDeclare(builder, loc, RTNAME_STRING(cpowi), funcTy);
|
||||
else if (realBits == 32 && intBits == 64)
|
||||
callee = getOrDeclare(builder, loc, RTNAME_STRING(cpowk), funcTy);
|
||||
else if (realBits == 64 && intBits == 32)
|
||||
callee = getOrDeclare(builder, loc, RTNAME_STRING(zpowi), funcTy);
|
||||
else if (realBits == 64 && intBits == 64)
|
||||
callee = getOrDeclare(builder, loc, RTNAME_STRING(zpowk), funcTy);
|
||||
else if (realBits == 128 && intBits == 32)
|
||||
callee = getOrDeclare(builder, loc, RTNAME_STRING(cqpowi), funcTy);
|
||||
else if (realBits == 128 && intBits == 64)
|
||||
callee = getOrDeclare(builder, loc, RTNAME_STRING(cqpowk), funcTy);
|
||||
else
|
||||
return;
|
||||
args = {base, intExp};
|
||||
} else {
|
||||
unsigned realBits = cast<FloatType>(elemTy).getWidth();
|
||||
auto funcTy =
|
||||
builder.getFunctionType({complexTy, complexTy}, {complexTy});
|
||||
if (realBits == 32)
|
||||
callee = getOrDeclare(builder, loc, "cpowf", funcTy);
|
||||
else if (realBits == 64)
|
||||
callee = getOrDeclare(builder, loc, "cpow", funcTy);
|
||||
else if (realBits == 128)
|
||||
callee = getOrDeclare(builder, loc, RTNAME_STRING(CPowF128), funcTy);
|
||||
else
|
||||
return;
|
||||
args = {base, rhs};
|
||||
}
|
||||
|
||||
auto call = fir::CallOp::create(builder, loc, callee, args);
|
||||
if (auto fmf = op.getFastmathAttr())
|
||||
call.setFastmathAttr(fmf);
|
||||
op.replaceAllUsesWith(call.getResult(0));
|
||||
op.erase();
|
||||
});
|
||||
}
|
||||
@ -69,6 +69,7 @@ end program
|
||||
! CHECK-NEXT: SCFToControlFlow
|
||||
! CHECK-NEXT: Canonicalizer
|
||||
! CHECK-NEXT: SimplifyRegionLite
|
||||
! CHECK-NEXT: ConvertComplexPow
|
||||
! CHECK-NEXT: CSE
|
||||
! CHECK-NEXT: (S) 0 num-cse'd - Number of operations CSE'd
|
||||
! CHECK-NEXT: (S) 0 num-dce'd - Number of operations DCE'd
|
||||
|
||||
@ -96,6 +96,7 @@ end program
|
||||
! ALL-NEXT: SCFToControlFlow
|
||||
! ALL-NEXT: Canonicalizer
|
||||
! ALL-NEXT: SimplifyRegionLite
|
||||
! ALL-NEXT: ConvertComplexPow
|
||||
! ALL-NEXT: CSE
|
||||
! ALL-NEXT: (S) 0 num-cse'd - Number of operations CSE'd
|
||||
! ALL-NEXT: (S) 0 num-dce'd - Number of operations DCE'd
|
||||
|
||||
@ -127,6 +127,7 @@ end program
|
||||
! ALL-NEXT: SCFToControlFlow
|
||||
! ALL-NEXT: Canonicalizer
|
||||
! ALL-NEXT: SimplifyRegionLite
|
||||
! ALL-NEXT: ConvertComplexPow
|
||||
! ALL-NEXT: CSE
|
||||
! ALL-NEXT: (S) 0 num-cse'd - Number of operations CSE'd
|
||||
! ALL-NEXT: (S) 0 num-dce'd - Number of operations DCE'd
|
||||
|
||||
@ -125,6 +125,7 @@ func.func @_QQmain() {
|
||||
// PASSES-NEXT: SCFToControlFlow
|
||||
// PASSES-NEXT: Canonicalizer
|
||||
// PASSES-NEXT: SimplifyRegionLite
|
||||
// PASSES-NEXT: ConvertComplexPow
|
||||
// PASSES-NEXT: CSE
|
||||
// PASSES-NEXT: (S) 0 num-cse'd - Number of operations CSE'd
|
||||
// PASSES-NEXT: (S) 0 num-dce'd - Number of operations DCE'd
|
||||
|
||||
@ -168,7 +168,7 @@ end subroutine
|
||||
! CHECK: %[[VAL_5:.*]]:2 = hlfir.declare %{{.*}}z"} : (!fir.ref<complex<f32>>, !fir.dscope) -> (!fir.ref<complex<f32>>, !fir.ref<complex<f32>>)
|
||||
! CHECK: %[[VAL_6:.*]] = fir.load %[[VAL_4]]#0 : !fir.ref<complex<f32>>
|
||||
! CHECK: %[[VAL_7:.*]] = fir.load %[[VAL_5]]#0 : !fir.ref<complex<f32>>
|
||||
! CHECK: %[[VAL_8:.*]] = fir.call @cpowf(%[[VAL_6]], %[[VAL_7]]) fastmath<contract> : (complex<f32>, complex<f32>) -> complex<f32>
|
||||
! CHECK: %[[VAL_8:.*]] = complex.pow %[[VAL_6]], %[[VAL_7]] fastmath<contract> : complex<f32>
|
||||
|
||||
|
||||
subroutine real_to_int_power(x, y, z)
|
||||
@ -193,7 +193,7 @@ end subroutine
|
||||
! CHECK: %[[VAL_5:.*]]:2 = hlfir.declare %{{.*}}z"} : (!fir.ref<i32>, !fir.dscope) -> (!fir.ref<i32>, !fir.ref<i32>)
|
||||
! CHECK: %[[VAL_6:.*]] = fir.load %[[VAL_4]]#0 : !fir.ref<complex<f32>>
|
||||
! CHECK: %[[VAL_7:.*]] = fir.load %[[VAL_5]]#0 : !fir.ref<i32>
|
||||
! CHECK: %[[VAL_8:.*]] = fir.call @_FortranAcpowi(%[[VAL_6]], %[[VAL_7]]) fastmath<contract> : (complex<f32>, i32) -> complex<f32>
|
||||
! CHECK: %[[VAL_8:.*]] = complex.pow
|
||||
|
||||
subroutine extremum(c, n, l)
|
||||
integer(8), intent(in) :: l
|
||||
|
||||
@ -1,9 +1,10 @@
|
||||
! REQUIRES: flang-supports-f128-math
|
||||
! RUN: bbc -emit-fir %s -o - | FileCheck %s
|
||||
! RUN: bbc --math-runtime=precise -emit-fir %s -o - | FileCheck %s
|
||||
! RUN: bbc --math-runtime=precise -emit-fir %s -o - | FileCheck %s --check-prefixes="PRECISE"
|
||||
! RUN: %flang_fc1 -emit-fir %s -o - | FileCheck %s
|
||||
|
||||
! CHECK: fir.call @_FortranACPowF128({{.*}}){{.*}}: (complex<f128>, complex<f128>) -> complex<f128>
|
||||
! PRECISE: fir.call @_FortranACPowF128({{.*}}){{.*}}: (complex<f128>, complex<f128>) -> complex<f128>
|
||||
! CHECK: complex.pow %{{.*}}, %{{.*}} fastmath<contract> : complex<f128>
|
||||
complex(16) :: a, b
|
||||
b = a ** b
|
||||
end
|
||||
|
||||
@ -1,9 +1,10 @@
|
||||
! REQUIRES: flang-supports-f128-math
|
||||
! RUN: bbc -emit-fir %s -o - | FileCheck %s
|
||||
! RUN: bbc --math-runtime=precise -emit-fir %s -o - | FileCheck %s
|
||||
! RUN: bbc --math-runtime=precise -emit-fir %s -o - | FileCheck %s --check-prefixes="PRECISE"
|
||||
! RUN: %flang_fc1 -emit-fir %s -o - | FileCheck %s
|
||||
|
||||
! CHECK: fir.call @_FortranAcqpowi({{.*}}){{.*}}: (complex<f128>, i32) -> complex<f128>
|
||||
! PRECISE: fir.call @_FortranAcqpowi({{.*}}){{.*}}: (complex<f128>, i32) -> complex<f128>
|
||||
! CHECK: complex.pow %{{.*}}, %{{.*}} fastmath<contract> : complex<f128>
|
||||
complex(16) :: a
|
||||
integer(4) :: b
|
||||
b = a ** b
|
||||
|
||||
@ -1,9 +1,10 @@
|
||||
! REQUIRES: flang-supports-f128-math
|
||||
! RUN: bbc -emit-fir %s -o - | FileCheck %s
|
||||
! RUN: bbc --math-runtime=precise -emit-fir %s -o - | FileCheck %s
|
||||
! RUN: bbc --math-runtime=precise -emit-fir %s -o - | FileCheck %s --check-prefixes="PRECISE"
|
||||
! RUN: %flang_fc1 -emit-fir %s -o - | FileCheck %s
|
||||
|
||||
! CHECK: fir.call @_FortranAcqpowk({{.*}}){{.*}}: (complex<f128>, i64) -> complex<f128>
|
||||
! PRECISE: fir.call @_FortranAcqpowk({{.*}}){{.*}}: (complex<f128>, i64) -> complex<f128>
|
||||
! CHECK: complex.pow %{{.*}}, %{{.*}} fastmath<contract> : complex<f128>
|
||||
complex(16) :: a
|
||||
integer(8) :: b
|
||||
b = a ** b
|
||||
|
||||
@ -1,10 +1,10 @@
|
||||
! RUN: bbc -emit-fir %s -o - | FileCheck %s --check-prefixes="CHECK,PRECISE"
|
||||
! RUN: bbc --math-runtime=precise -emit-fir %s -o - | FileCheck %s --check-prefixes="PRECISE"
|
||||
! RUN: bbc --force-mlir-complex -emit-fir %s -o - | FileCheck %s --check-prefixes="FAST"
|
||||
! RUN: %flang_fc1 -emit-fir %s -o - | FileCheck %s --check-prefixes="CHECK,PRECISE"
|
||||
! RUN: %flang_fc1 -fapprox-func -emit-fir %s -o - | FileCheck %s --check-prefixes="CHECK,FAST"
|
||||
! RUN: %flang_fc1 -emit-fir -mllvm --math-runtime=precise %s -o - | FileCheck %s --check-prefixes="PRECISE"
|
||||
! RUN: %flang_fc1 -emit-fir -mllvm --force-mlir-complex %s -o - | FileCheck %s --check-prefixes="FAST"
|
||||
! RUN: bbc -emit-fir %s -o - | FileCheck %s
|
||||
! RUN: bbc --math-runtime=precise -emit-fir %s -o - | FileCheck %s --check-prefix=PRECISE
|
||||
! RUN: bbc --force-mlir-complex -emit-fir %s -o - | FileCheck %s
|
||||
! RUN: %flang_fc1 -emit-fir %s -o - | FileCheck %s
|
||||
! RUN: %flang_fc1 -fapprox-func -emit-fir %s -o - | FileCheck %s
|
||||
! RUN: %flang_fc1 -emit-fir -mllvm --math-runtime=precise %s -o - | FileCheck %s --check-prefix=PRECISE
|
||||
! RUN: %flang_fc1 -emit-fir -mllvm --force-mlir-complex %s -o - | FileCheck %s
|
||||
|
||||
! Test power operation lowering
|
||||
|
||||
@ -96,7 +96,8 @@ subroutine pow_c4_i4(x, y, z)
|
||||
complex :: x, z
|
||||
integer :: y
|
||||
z = x ** y
|
||||
! CHECK: call @_FortranAcpowi
|
||||
! CHECK: complex.pow
|
||||
! PRECISE: fir.call @_FortranAcpowi
|
||||
end subroutine
|
||||
|
||||
! CHECK-LABEL: pow_c4_i8
|
||||
@ -104,7 +105,8 @@ subroutine pow_c4_i8(x, y, z)
|
||||
complex :: x, z
|
||||
integer(8) :: y
|
||||
z = x ** y
|
||||
! CHECK: call @_FortranAcpowk
|
||||
! CHECK: complex.pow
|
||||
! PRECISE: fir.call @_FortranAcpowk
|
||||
end subroutine
|
||||
|
||||
! CHECK-LABEL: pow_c8_i4
|
||||
@ -112,7 +114,8 @@ subroutine pow_c8_i4(x, y, z)
|
||||
complex(8) :: x, z
|
||||
integer :: y
|
||||
z = x ** y
|
||||
! CHECK: call @_FortranAzpowi
|
||||
! CHECK: complex.pow
|
||||
! PRECISE: fir.call @_FortranAzpowi
|
||||
end subroutine
|
||||
|
||||
! CHECK-LABEL: pow_c8_i8
|
||||
@ -120,22 +123,23 @@ subroutine pow_c8_i8(x, y, z)
|
||||
complex(8) :: x, z
|
||||
integer(8) :: y
|
||||
z = x ** y
|
||||
! CHECK: call @_FortranAzpowk
|
||||
! CHECK: complex.pow
|
||||
! PRECISE: fir.call @_FortranAzpowk
|
||||
end subroutine
|
||||
|
||||
! CHECK-LABEL: pow_c4_c4
|
||||
subroutine pow_c4_c4(x, y, z)
|
||||
complex :: x, y, z
|
||||
z = x ** y
|
||||
! FAST: complex.pow %{{.*}}, %{{.*}} : complex<f32>
|
||||
! PRECISE: call @cpowf
|
||||
! CHECK: complex.pow %{{.*}}, %{{.*}} : complex<f32>
|
||||
! PRECISE: fir.call @cpowf
|
||||
end subroutine
|
||||
|
||||
! CHECK-LABEL: pow_c8_c8
|
||||
subroutine pow_c8_c8(x, y, z)
|
||||
complex(8) :: x, y, z
|
||||
z = x ** y
|
||||
! FAST: complex.pow %{{.*}}, %{{.*}} : complex<f64>
|
||||
! PRECISE: call @cpow
|
||||
! CHECK: complex.pow %{{.*}}, %{{.*}} : complex<f64>
|
||||
! PRECISE: fir.call @cpow
|
||||
end subroutine
|
||||
|
||||
|
||||
111
flang/test/Transforms/convert-complex-pow.fir
Normal file
111
flang/test/Transforms/convert-complex-pow.fir
Normal file
@ -0,0 +1,111 @@
|
||||
// RUN: fir-opt --convert-complex-pow %s | FileCheck %s
|
||||
|
||||
module {
|
||||
func.func @pow_c4_i4(%arg0: complex<f32>, %arg1: i32) -> complex<f32> {
|
||||
%c0 = arith.constant 0.0 : f32
|
||||
%0 = fir.convert %arg1 : (i32) -> f32
|
||||
%1 = complex.create %0, %c0 : complex<f32>
|
||||
%2 = complex.pow %arg0, %1 : complex<f32>
|
||||
return %2 : complex<f32>
|
||||
}
|
||||
|
||||
func.func @pow_c4_i8(%arg0: complex<f32>, %arg1: i64) -> complex<f32> {
|
||||
%c0 = arith.constant 0.0 : f32
|
||||
%0 = fir.convert %arg1 : (i64) -> f32
|
||||
%1 = complex.create %0, %c0 : complex<f32>
|
||||
%2 = complex.pow %arg0, %1 : complex<f32>
|
||||
return %2 : complex<f32>
|
||||
}
|
||||
|
||||
func.func @pow_c8_i4(%arg0: complex<f64>, %arg1: i32) -> complex<f64> {
|
||||
%c0 = arith.constant 0.0 : f64
|
||||
%0 = fir.convert %arg1 : (i32) -> f64
|
||||
%1 = complex.create %0, %c0 : complex<f64>
|
||||
%2 = complex.pow %arg0, %1 : complex<f64>
|
||||
return %2 : complex<f64>
|
||||
}
|
||||
|
||||
func.func @pow_c8_i8(%arg0: complex<f64>, %arg1: i64) -> complex<f64> {
|
||||
%c0 = arith.constant 0.0 : f64
|
||||
%0 = fir.convert %arg1 : (i64) -> f64
|
||||
%1 = complex.create %0, %c0 : complex<f64>
|
||||
%2 = complex.pow %arg0, %1 : complex<f64>
|
||||
return %2 : complex<f64>
|
||||
}
|
||||
|
||||
func.func @pow_c16_i4(%arg0: complex<f128>, %arg1: i32) -> complex<f128> {
|
||||
%c0 = arith.constant 0.0 : f128
|
||||
%0 = fir.convert %arg1 : (i32) -> f128
|
||||
%1 = complex.create %0, %c0 : complex<f128>
|
||||
%2 = complex.pow %arg0, %1 : complex<f128>
|
||||
return %2 : complex<f128>
|
||||
}
|
||||
|
||||
func.func @pow_c16_i8(%arg0: complex<f128>, %arg1: i64) -> complex<f128> {
|
||||
%c0 = arith.constant 0.0 : f128
|
||||
%0 = fir.convert %arg1 : (i64) -> f128
|
||||
%1 = complex.create %0, %c0 : complex<f128>
|
||||
%2 = complex.pow %arg0, %1 : complex<f128>
|
||||
return %2 : complex<f128>
|
||||
}
|
||||
|
||||
func.func @pow_c4_fast(%arg0: complex<f32>, %arg1: f32) -> complex<f32> {
|
||||
%c1 = arith.constant 1.0 : f32
|
||||
%0 = complex.create %arg1, %c1 : complex<f32>
|
||||
%1 = complex.pow %arg0, %0 fastmath<fast> : complex<f32>
|
||||
return %1 : complex<f32>
|
||||
}
|
||||
|
||||
func.func @pow_c8_complex(%arg0: complex<f64>, %arg1: f64) -> complex<f64> {
|
||||
%c2 = arith.constant 2.0 : f64
|
||||
%0 = complex.create %arg1, %c2 : complex<f64>
|
||||
%1 = complex.pow %arg0, %0 : complex<f64>
|
||||
return %1 : complex<f64>
|
||||
}
|
||||
|
||||
func.func @pow_c16_complex(%arg0: complex<f128>, %arg1: f128) -> complex<f128> {
|
||||
%c3 = arith.constant 3.0 : f128
|
||||
%0 = complex.create %arg1, %c3 : complex<f128>
|
||||
%1 = complex.pow %arg0, %0 : complex<f128>
|
||||
return %1 : complex<f128>
|
||||
}
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func.func @pow_c4_i4(
|
||||
// CHECK: fir.call @_FortranAcpowi(%{{.*}}, %{{.*}}) : (complex<f32>, i32) -> complex<f32>
|
||||
// CHECK-NOT: complex.pow
|
||||
|
||||
// CHECK-LABEL: func.func @pow_c4_i8(
|
||||
// CHECK: fir.call @_FortranAcpowk(%{{.*}}, %{{.*}}) : (complex<f32>, i64) -> complex<f32>
|
||||
// CHECK-NOT: complex.pow
|
||||
|
||||
// CHECK-LABEL: func.func @pow_c8_i4(
|
||||
// CHECK: fir.call @_FortranAzpowi(%{{.*}}, %{{.*}}) : (complex<f64>, i32) -> complex<f64>
|
||||
// CHECK-NOT: complex.pow
|
||||
|
||||
// CHECK-LABEL: func.func @pow_c8_i8(
|
||||
// CHECK: fir.call @_FortranAzpowk(%{{.*}}, %{{.*}}) : (complex<f64>, i64) -> complex<f64>
|
||||
// CHECK-NOT: complex.pow
|
||||
|
||||
// CHECK-LABEL: func.func @pow_c16_i4(
|
||||
// CHECK: fir.call @_FortranAcqpowi(%{{.*}}, %{{.*}}) : (complex<f128>, i32) -> complex<f128>
|
||||
// CHECK-NOT: complex.pow
|
||||
|
||||
// CHECK-LABEL: func.func @pow_c16_i8(
|
||||
// CHECK: fir.call @_FortranAcqpowk(%{{.*}}, %{{.*}}) : (complex<f128>, i64) -> complex<f128>
|
||||
// CHECK-NOT: complex.pow
|
||||
|
||||
// CHECK-LABEL: func.func @pow_c4_fast(
|
||||
// CHECK: %[[EXP:.*]] = complex.create %{{.*}}, %{{.*}} : complex<f32>
|
||||
// CHECK: fir.call @cpowf(%{{.*}}, %[[EXP]]) fastmath<fast> : (complex<f32>, complex<f32>) -> complex<f32>
|
||||
// CHECK-NOT: complex.pow
|
||||
|
||||
// CHECK-LABEL: func.func @pow_c8_complex(
|
||||
// CHECK: %[[EXP:.*]] = complex.create %{{.*}}, %{{.*}} : complex<f64>
|
||||
// CHECK: fir.call @cpow(%{{.*}}, %[[EXP]]) : (complex<f64>, complex<f64>) -> complex<f64>
|
||||
// CHECK-NOT: complex.pow
|
||||
|
||||
// CHECK-LABEL: func.func @pow_c16_complex(
|
||||
// CHECK: %[[EXP:.*]] = complex.create %{{.*}}, %{{.*}} : complex<f128>
|
||||
// CHECK: fir.call @_FortranACPowF128(%{{.*}}, %[[EXP]]) : (complex<f128>, complex<f128>) -> complex<f128>
|
||||
// CHECK-NOT: complex.pow
|
||||
@ -538,6 +538,7 @@ static llvm::LogicalResult convertFortranSourceToMLIR(
|
||||
|
||||
// Add O2 optimizer pass pipeline.
|
||||
MLIRToLLVMPassPipelineConfig config(llvm::OptimizationLevel::O2);
|
||||
config.SkipConvertComplexPow = targetMachine.getTargetTriple().isAMDGCN();
|
||||
if (enableOpenMP)
|
||||
config.EnableOpenMP = true;
|
||||
config.NSWOnLoopVarInc = !integerWrapAround;
|
||||
|
||||
@ -64,9 +64,12 @@ struct PowOpToROCDLLibraryCalls : public OpRewritePattern<complex::PowOp> {
|
||||
LogicalResult matchAndRewrite(complex::PowOp op,
|
||||
PatternRewriter &rewriter) const final {
|
||||
Location loc = op.getLoc();
|
||||
Value logBase = complex::LogOp::create(rewriter, loc, op.getLhs());
|
||||
Value mul = complex::MulOp::create(rewriter, loc, op.getRhs(), logBase);
|
||||
Value exp = complex::ExpOp::create(rewriter, loc, mul);
|
||||
auto fastmath = op.getFastmathAttr();
|
||||
Value logBase =
|
||||
complex::LogOp::create(rewriter, loc, op.getLhs(), fastmath);
|
||||
Value mul =
|
||||
complex::MulOp::create(rewriter, loc, op.getRhs(), logBase, fastmath);
|
||||
Value exp = complex::ExpOp::create(rewriter, loc, mul, fastmath);
|
||||
rewriter.replaceOp(op, exp);
|
||||
return success();
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user