From 3720d8b52d664c7e3620404d1a2d12cee13677f3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Valentin=20Clement=20=28=E3=83=90=E3=83=AC=E3=83=B3?= =?UTF-8?q?=E3=82=BF=E3=82=A4=E3=83=B3=20=E3=82=AF=E3=83=AC=E3=83=A1?= =?UTF-8?q?=E3=83=B3=29?= Date: Fri, 15 Aug 2025 11:07:15 -0700 Subject: [PATCH] [flang][cuda] Update some bind name to fast version and add __sincosf (#153744) Use the fast version in the bind name and reorder these fast math functions. Add missing __sincosf interface. --- flang/module/cudadevice.f90 | 98 ++++++++++++---------- flang/test/Lower/CUDA/cuda-device-proc.cuf | 4 +- flang/test/Lower/CUDA/cuda-libdevice.cuf | 22 +++-- 3 files changed, 70 insertions(+), 54 deletions(-) diff --git a/flang/module/cudadevice.f90 b/flang/module/cudadevice.f90 index d21ee9865899..1598c64db2cb 100644 --- a/flang/module/cudadevice.f90 +++ b/flang/module/cudadevice.f90 @@ -394,11 +394,25 @@ implicit none end interface interface - attributes(device) real(4) function __cosf(x) bind(c, name='__nv_cosf') + attributes(device) real(4) function __cosf(x) bind(c, name='__nv_fast_cosf') real(4), value :: x end function end interface + interface __exp10f + attributes(device) real function __exp10f(r) bind(c, name='__nv_fast_exp10f') + !dir$ ignore_tkr (d) r + real, value :: r + end function + end interface + + interface __expf + attributes(device) real function __expf(r) bind(c, name='__nv_fast_expf') + !dir$ ignore_tkr (d) r + real, value :: r + end function + end interface + interface __fdividef attributes(device) real function __fdividef(r,d) bind(c, name='__nv_fast_fdividef') !dir$ ignore_tkr (d) r, (d) d @@ -406,15 +420,51 @@ implicit none end function end interface + interface __log10f + attributes(device) real function __log10f(r) bind(c, name='__nv_fast_log10f') + !dir$ ignore_tkr (d) r + real, value :: r + end function + end interface + + interface __log2f + attributes(device) real function __log2f(r) bind(c, name='__nv_fast_log2f') + !dir$ ignore_tkr (d) r + real, value :: r + end function + end interface + + interface __logf + attributes(device) real function __logf(r) bind(c, name='__nv_fast_logf') + !dir$ ignore_tkr (d) r + real, value :: r + end function + end interface + + interface + attributes(device) real(4) function __powf(x,y) bind(c, name='__nv_fast_powf') + !dir$ ignore_tkr (d) x, y + real(4), value :: x, y + end function + end interface + + interface __sincosf + attributes(device) subroutine __sincosf(r, s, c) bind(c, name='__nv_fast_sincosf') + !dir$ ignore_tkr (d) r, (d) s, (d) c + real, value :: r + real :: s, c + end subroutine + end interface + interface __sinf - attributes(device) real function __sinf(r) bind(c, name='__nv_sinf') + attributes(device) real function __sinf(r) bind(c, name='__nv_fast_sinf') !dir$ ignore_tkr (d) r real, value :: r end function end interface interface __tanf - attributes(device) real function __tanf(r) bind(c, name='__nv_tanf') + attributes(device) real function __tanf(r) bind(c, name='__nv_fast_tanf') !dir$ ignore_tkr (d) r real, value :: r end function @@ -1078,13 +1128,6 @@ implicit none end function end interface - interface - attributes(device) real(4) function __powf(x,y) bind(c, name='__nv_powf') - !dir$ ignore_tkr (d) x, y - real(4), value :: x, y - end function - end interface - interface __brev attributes(device) integer function __brev(i) bind(c, name='__nv_brev') !dir$ ignore_tkr (d) i @@ -1944,41 +1987,6 @@ implicit none end function end interface - interface __log2f - attributes(device) real function __log2f(r) bind(c, name='__nv_log2f') - !dir$ ignore_tkr (d) r - real, value :: r - end function - end interface - - interface __log10f - attributes(device) real function __log10f(r) bind(c, name='__nv_log10f') - !dir$ ignore_tkr (d) r - real, value :: r - end function - end interface - - interface __logf - attributes(device) real function __logf(r) bind(c, name='__nv_logf') - !dir$ ignore_tkr (d) r - real, value :: r - end function - end interface - - interface __expf - attributes(device) real function __expf(r) bind(c, name='__nv_expf') - !dir$ ignore_tkr (d) r - real, value :: r - end function - end interface - - interface __exp10f - attributes(device) real function __exp10f(r) bind(c, name='__nv_exp10f') - !dir$ ignore_tkr (d) r - real, value :: r - end function - end interface - contains attributes(device) subroutine syncthreads() diff --git a/flang/test/Lower/CUDA/cuda-device-proc.cuf b/flang/test/Lower/CUDA/cuda-device-proc.cuf index a6e8c69b2e52..5e1f6b66d1d5 100644 --- a/flang/test/Lower/CUDA/cuda-device-proc.cuf +++ b/flang/test/Lower/CUDA/cuda-device-proc.cuf @@ -140,7 +140,7 @@ end ! CHECK: %{{.*}} = fir.call @__nv_brevll(%{{.*}}) proc_attrs fastmath : (i64) -> i64 ! CHECK: %{{.*}} = fir.call @__nv_clz(%{{.*}}) proc_attrs fastmath : (i32) -> i32 ! CHECK: %{{.*}} = fir.call @__nv_clzll(%{{.*}}) proc_attrs fastmath : (i64) -> i32 -! CHECK: %{{.*}} = fir.call @__nv_cosf(%{{.*}}) proc_attrs fastmath : (f32) -> f32 +! CHECK: %{{.*}} = fir.call @__nv_fast_cosf(%{{.*}}) proc_attrs fastmath : (f32) -> f32 ! CHECK: %{{.*}} = fir.call @__nv_ddiv_rn(%{{.*}}, %{{.*}}) proc_attrs fastmath : (f64, f64) -> f64 ! CHECK: %{{.*}} = fir.call @__nv_ddiv_rz(%{{.*}}, %{{.*}}) proc_attrs fastmath : (f64, f64) -> f64 ! CHECK: %{{.*}} = fir.call @__nv_ddiv_ru(%{{.*}}, %{{.*}}) proc_attrs fastmath : (f64, f64) -> f64 @@ -159,7 +159,7 @@ end ! CHECK: %{{.*}} = fir.call @__nv_double2uint_rz(%{{.*}}) proc_attrs fastmath : (f64) -> i32 ! CHECK: %{{.*}} = fir.call @__nv_mul24(%{{.*}}, %{{.*}}) proc_attrs fastmath : (i32, i32) -> i32 ! CHECK: %{{.*}} = fir.call @__nv_umul24(%{{.*}}, %{{.*}}) proc_attrs fastmath : (i32, i32) -> i32 -! CHECK: %{{.*}} = fir.call @__nv_powf(%{{.*}}, %{{.*}}) proc_attrs fastmath : (f32, f32) -> f32 +! CHECK: %{{.*}} = fir.call @__nv_fast_powf(%{{.*}}, %{{.*}}) proc_attrs fastmath : (f32, f32) -> f32 ! CHECK: %{{.*}} = fir.call @__nv_ull2double_rd(%{{.*}}) proc_attrs fastmath : (i64) -> f64 ! CHECK: %{{.*}} = fir.call @__nv_ull2double_rn(%{{.*}}) proc_attrs fastmath : (i64) -> f64 ! CHECK: %{{.*}} = fir.call @__nv_ull2double_ru(%{{.*}}) proc_attrs fastmath : (i64) -> f64 diff --git a/flang/test/Lower/CUDA/cuda-libdevice.cuf b/flang/test/Lower/CUDA/cuda-libdevice.cuf index 0e024f06d06d..d243c49f0516 100644 --- a/flang/test/Lower/CUDA/cuda-libdevice.cuf +++ b/flang/test/Lower/CUDA/cuda-libdevice.cuf @@ -83,9 +83,17 @@ attributes(global) subroutine test_log() end subroutine ! CHECK-LABEL: _QPtest_log -! CHECK: %{{.*}} = fir.call @__nv_logf(%{{.*}}) proc_attrs fastmath : (f32) -> f32 -! CHECK: %{{.*}} = fir.call @__nv_log2f(%{{.*}}) proc_attrs fastmath : (f32) -> f32 -! CHECK: %{{.*}} = fir.call @__nv_log10f(%{{.*}}) proc_attrs fastmath : (f32) -> f32 +! CHECK: %{{.*}} = fir.call @__nv_fast_logf(%{{.*}}) proc_attrs fastmath : (f32) -> f32 +! CHECK: %{{.*}} = fir.call @__nv_fast_log2f(%{{.*}}) proc_attrs fastmath : (f32) -> f32 +! CHECK: %{{.*}} = fir.call @__nv_fast_log10f(%{{.*}}) proc_attrs fastmath : (f32) -> f32 + +attributes(global) subroutine test_sincosf() + real :: r, s, c + call __sincosf(r, s, c) +end subroutine + +! CHECK-LABEL: _QPtest_sincosf +! CHECK: fir.call @__nv_fast_sincosf(%{{.*}}, %{{.*}}#0, %{{.*}}#0) proc_attrs fastmath : (f32, !fir.ref, !fir.ref) -> () attributes(global) subroutine test_sinf() real :: res @@ -94,7 +102,7 @@ attributes(global) subroutine test_sinf() end subroutine ! CHECK-LABEL: _QPtest_sinf -! CHECK: %{{.*}} = fir.call @__nv_sinf(%{{.*}}) proc_attrs fastmath : (f32) -> f32 +! CHECK: %{{.*}} = fir.call @__nv_fast_sinf(%{{.*}}) proc_attrs fastmath : (f32) -> f32 attributes(global) subroutine test_tanf() real :: res @@ -103,7 +111,7 @@ attributes(global) subroutine test_tanf() end subroutine ! CHECK-LABEL: _QPtest_tanf -! CHECK: %{{.*}} = fir.call @__nv_tanf(%{{.*}}) proc_attrs fastmath : (f32) -> f32 +! CHECK: %{{.*}} = fir.call @__nv_fast_tanf(%{{.*}}) proc_attrs fastmath : (f32) -> f32 attributes(global) subroutine test_exp() real :: res @@ -113,8 +121,8 @@ attributes(global) subroutine test_exp() end subroutine ! CHECK-LABEL: _QPtest_exp -! CHECK: %{{.*}} = fir.call @__nv_expf(%{{.*}}) proc_attrs fastmath : (f32) -> f32 -! CHECK: %{{.*}} = fir.call @__nv_exp10f(%{{.*}}) proc_attrs fastmath : (f32) -> f32 +! CHECK: %{{.*}} = fir.call @__nv_fast_expf(%{{.*}}) proc_attrs fastmath : (f32) -> f32 +! CHECK: %{{.*}} = fir.call @__nv_fast_exp10f(%{{.*}}) proc_attrs fastmath : (f32) -> f32 attributes(global) subroutine test_double2ll_rX() integer(8) :: res