[MLIR] Add cpow support in ComplexToROCDLLibraryCalls (#153183)
This PR adds support for complex power operations (`cpow`) in the `ComplexToROCDLLibraryCalls` conversion pass, specifically targeting AMDGPU architectures. The implementation optimises complex exponentiation by using mathematical identities and special-case handling for small integer powers. - Force lowering to `complex.pow` operations for the `amdgcn-amd-amdhsa` target instead of using library calls - Convert `complex.pow(z, w)` to `complex.exp(w * complex.log(z))` using mathematical identity
This commit is contained in:
parent
65de318d18
commit
d69ccded4f
@ -1287,6 +1287,26 @@ mlir::Value genComplexMathOp(fir::FirOpBuilder &builder, mlir::Location loc,
|
||||
return result;
|
||||
}
|
||||
|
||||
mlir::Value genComplexPow(fir::FirOpBuilder &builder, mlir::Location loc,
|
||||
const MathOperation &mathOp,
|
||||
mlir::FunctionType mathLibFuncType,
|
||||
llvm::ArrayRef<mlir::Value> args) {
|
||||
bool isAMDGPU = fir::getTargetTriple(builder.getModule()).isAMDGCN();
|
||||
if (!isAMDGPU)
|
||||
return genLibCall(builder, loc, mathOp, mathLibFuncType, args);
|
||||
|
||||
auto complexTy = mlir::cast<mlir::ComplexType>(mathLibFuncType.getInput(0));
|
||||
auto realTy = complexTy.getElementType();
|
||||
mlir::Value realExp = builder.createConvert(loc, realTy, args[1]);
|
||||
mlir::Value zero = builder.createRealConstant(loc, realTy, 0);
|
||||
mlir::Value complexExp =
|
||||
builder.create<mlir::complex::CreateOp>(loc, complexTy, realExp, zero);
|
||||
mlir::Value result =
|
||||
builder.create<mlir::complex::PowOp>(loc, args[0], complexExp);
|
||||
result = builder.createConvert(loc, mathLibFuncType.getResult(0), result);
|
||||
return result;
|
||||
}
|
||||
|
||||
/// Mapping between mathematical intrinsic operations and MLIR operations
|
||||
/// of some appropriate dialect (math, complex, etc.) or libm calls.
|
||||
/// TODO: support remaining Fortran math intrinsics.
|
||||
@ -1636,15 +1656,19 @@ static constexpr MathOperation mathOperations[] = {
|
||||
genFuncType<Ty::Real<16>, Ty::Real<16>, Ty::Integer<8>>,
|
||||
genMathOp<mlir::math::FPowIOp>},
|
||||
{"pow", RTNAME_STRING(cpowi),
|
||||
genFuncType<Ty::Complex<4>, Ty::Complex<4>, Ty::Integer<4>>, genLibCall},
|
||||
genFuncType<Ty::Complex<4>, Ty::Complex<4>, Ty::Integer<4>>,
|
||||
genComplexPow},
|
||||
{"pow", RTNAME_STRING(zpowi),
|
||||
genFuncType<Ty::Complex<8>, Ty::Complex<8>, Ty::Integer<4>>, genLibCall},
|
||||
genFuncType<Ty::Complex<8>, Ty::Complex<8>, Ty::Integer<4>>,
|
||||
genComplexPow},
|
||||
{"pow", RTNAME_STRING(cqpowi), FuncTypeComplex16Complex16Integer4,
|
||||
genLibF128Call},
|
||||
{"pow", RTNAME_STRING(cpowk),
|
||||
genFuncType<Ty::Complex<4>, Ty::Complex<4>, Ty::Integer<8>>, genLibCall},
|
||||
genFuncType<Ty::Complex<4>, Ty::Complex<4>, Ty::Integer<8>>,
|
||||
genComplexPow},
|
||||
{"pow", RTNAME_STRING(zpowk),
|
||||
genFuncType<Ty::Complex<8>, Ty::Complex<8>, Ty::Integer<8>>, genLibCall},
|
||||
genFuncType<Ty::Complex<8>, Ty::Complex<8>, Ty::Integer<8>>,
|
||||
genComplexPow},
|
||||
{"pow", RTNAME_STRING(cqpowk), FuncTypeComplex16Complex16Integer8,
|
||||
genLibF128Call},
|
||||
{"remainder", "remainderf",
|
||||
@ -4044,14 +4068,13 @@ void IntrinsicLibrary::genExecuteCommandLine(
|
||||
mlir::Value waitAddr = fir::getBase(wait);
|
||||
mlir::Value waitIsPresentAtRuntime =
|
||||
builder.genIsNotNullAddr(loc, waitAddr);
|
||||
waitBool = builder
|
||||
waitBool =
|
||||
builder
|
||||
.genIfOp(loc, {i1Ty}, waitIsPresentAtRuntime,
|
||||
/*withElseRegion=*/true)
|
||||
.genThen([&]() {
|
||||
auto waitLoad =
|
||||
fir::LoadOp::create(builder, loc, waitAddr);
|
||||
mlir::Value cast =
|
||||
builder.createConvert(loc, i1Ty, waitLoad);
|
||||
auto waitLoad = fir::LoadOp::create(builder, loc, waitAddr);
|
||||
mlir::Value cast = builder.createConvert(loc, i1Ty, waitLoad);
|
||||
fir::ResultOp::create(builder, loc, cast);
|
||||
})
|
||||
.genElse([&]() {
|
||||
|
@ -1,21 +1,27 @@
|
||||
! REQUIRES: amdgpu-registered-target
|
||||
! RUN: %flang_fc1 -triple amdgcn-amd-amdhsa -emit-fir -flang-deprecated-no-hlfir %s -o - | FileCheck %s
|
||||
! RUN: %flang_fc1 -triple amdgcn-amd-amdhsa -emit-fir %s -o - | FileCheck %s
|
||||
|
||||
! CHECK-LABEL: func @_QPcabsf_test(
|
||||
! CHECK: complex.abs
|
||||
! CHECK-NOT: fir.call @cabsf
|
||||
subroutine cabsf_test(a, b)
|
||||
complex :: a
|
||||
real :: b
|
||||
b = abs(a)
|
||||
end subroutine
|
||||
|
||||
! CHECK-LABEL: func @_QPcabsf_test(
|
||||
! CHECK: complex.abs
|
||||
! CHECK-NOT: fir.call @cabsf
|
||||
|
||||
! CHECK-LABEL: func @_QPcexpf_test(
|
||||
! CHECK: complex.exp
|
||||
! CHECK-NOT: fir.call @cexpf
|
||||
subroutine cexpf_test(a, b)
|
||||
complex :: a, b
|
||||
b = exp(a)
|
||||
end subroutine
|
||||
|
||||
! CHECK-LABEL: func @_QPcexpf_test(
|
||||
! CHECK: complex.exp
|
||||
! CHECK-NOT: fir.call @cexpf
|
||||
! CHECK-LABEL: func @_QPpow_test(
|
||||
! CHECK: complex.pow
|
||||
! CHECK-NOT: fir.call @_FortranAcpowi
|
||||
subroutine pow_test(a, b, c)
|
||||
complex :: a, b, c
|
||||
a = b**c
|
||||
end subroutine pow_test
|
||||
|
@ -56,10 +56,26 @@ struct ComplexOpToROCDLLibraryCalls : public OpRewritePattern<Op> {
|
||||
private:
|
||||
std::string funcName;
|
||||
};
|
||||
|
||||
// Rewrite complex.pow(z, w) -> complex.exp(w * complex.log(z))
|
||||
struct PowOpToROCDLLibraryCalls : public OpRewritePattern<complex::PowOp> {
|
||||
using OpRewritePattern<complex::PowOp>::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(complex::PowOp op,
|
||||
PatternRewriter &rewriter) const final {
|
||||
Location loc = op.getLoc();
|
||||
Value logBase = rewriter.create<complex::LogOp>(loc, op.getLhs());
|
||||
Value mul = rewriter.create<complex::MulOp>(loc, op.getRhs(), logBase);
|
||||
Value exp = rewriter.create<complex::ExpOp>(loc, mul);
|
||||
rewriter.replaceOp(op, exp);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
void mlir::populateComplexToROCDLLibraryCallsConversionPatterns(
|
||||
RewritePatternSet &patterns) {
|
||||
patterns.add<PowOpToROCDLLibraryCalls>(patterns.getContext());
|
||||
patterns.add<ComplexOpToROCDLLibraryCalls<complex::AbsOp, Float32Type>>(
|
||||
patterns.getContext(), "__ocml_cabs_f32");
|
||||
patterns.add<ComplexOpToROCDLLibraryCalls<complex::AbsOp, Float64Type>>(
|
||||
@ -110,9 +126,10 @@ void ConvertComplexToROCDLLibraryCallsPass::runOnOperation() {
|
||||
|
||||
ConversionTarget target(getContext());
|
||||
target.addLegalDialect<func::FuncDialect>();
|
||||
target.addLegalOp<complex::MulOp>();
|
||||
target.addIllegalOp<complex::AbsOp, complex::CosOp, complex::ExpOp,
|
||||
complex::LogOp, complex::SinOp, complex::SqrtOp,
|
||||
complex::TanOp, complex::TanhOp>();
|
||||
complex::LogOp, complex::PowOp, complex::SinOp,
|
||||
complex::SqrtOp, complex::TanOp, complex::TanhOp>();
|
||||
if (failed(applyPartialConversion(op, target, std::move(patterns))))
|
||||
signalPassFailure();
|
||||
}
|
||||
|
@ -1,4 +1,4 @@
|
||||
// RUN: mlir-opt %s -convert-complex-to-rocdl-library-calls | FileCheck %s
|
||||
// RUN: mlir-opt %s --allow-unregistered-dialect -convert-complex-to-rocdl-library-calls | FileCheck %s
|
||||
|
||||
// CHECK-DAG: @__ocml_cabs_f32(complex<f32>) -> f32
|
||||
// CHECK-DAG: @__ocml_cabs_f64(complex<f64>) -> f64
|
||||
@ -57,6 +57,17 @@ func.func @log_caller(%f: complex<f32>, %d: complex<f64>) -> (complex<f32>, comp
|
||||
return %lf, %ld : complex<f32>, complex<f64>
|
||||
}
|
||||
|
||||
//CHECK-LABEL: @pow_caller
|
||||
//CHECK: (%[[Z:.*]]: complex<f32>, %[[W:.*]]: complex<f32>)
|
||||
func.func @pow_caller(%z: complex<f32>, %w: complex<f32>) -> complex<f32> {
|
||||
// CHECK: %[[LOG:.*]] = call @__ocml_clog_f32(%[[Z]])
|
||||
// CHECK: %[[MUL:.*]] = complex.mul %[[W]], %[[LOG]]
|
||||
// CHECK: %[[EXP:.*]] = call @__ocml_cexp_f32(%[[MUL]])
|
||||
// CHECK: return %[[EXP]]
|
||||
%r = complex.pow %z, %w : complex<f32>
|
||||
return %r : complex<f32>
|
||||
}
|
||||
|
||||
//CHECK-LABEL: @sin_caller
|
||||
func.func @sin_caller(%f: complex<f32>, %d: complex<f64>) -> (complex<f32>, complex<f64>) {
|
||||
// CHECK: %[[SF:.*]] = call @__ocml_csin_f32(%{{.*}})
|
||||
|
Loading…
x
Reference in New Issue
Block a user