[flang][cuda] Update some bind names to the fast versions and add __sincosf (#153744)

Use the fast versions (`__nv_fast_*`) in the bind names and reorder these
fast-math functions. Add the missing __sincosf interface.
This commit is contained in:
Valentin Clement (バレンタイン クレメン) 2025-08-15 11:07:15 -07:00 committed by GitHub
parent ed6d505fab
commit 3720d8b52d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 70 additions and 54 deletions

View File

@ -394,11 +394,25 @@ implicit none
end interface end interface
interface interface
attributes(device) real(4) function __cosf(x) bind(c, name='__nv_cosf') attributes(device) real(4) function __cosf(x) bind(c, name='__nv_fast_cosf')
real(4), value :: x real(4), value :: x
end function end function
end interface end interface
! Generic interface __exp10f with a single specific procedure: a device
! function that takes one default-real argument by value and returns a
! default real. It is bound to the C symbol `__nv_fast_exp10f`.
interface __exp10f
attributes(device) real function __exp10f(r) bind(c, name='__nv_fast_exp10f')
! The (d) letter applies only to r; presumably it relaxes CUDA data-attribute
! matching for the actual argument -- confirm against flang's IGNORE_TKR docs.
!dir$ ignore_tkr (d) r
real, value :: r
end function
end interface
! Generic interface __expf with a single specific procedure: a device
! function that takes one default-real argument by value and returns a
! default real. It is bound to the C symbol `__nv_fast_expf`.
interface __expf
attributes(device) real function __expf(r) bind(c, name='__nv_fast_expf')
! The (d) letter applies only to r; presumably it relaxes CUDA data-attribute
! matching for the actual argument -- confirm against flang's IGNORE_TKR docs.
!dir$ ignore_tkr (d) r
real, value :: r
end function
end interface
interface __fdividef interface __fdividef
attributes(device) real function __fdividef(r,d) bind(c, name='__nv_fast_fdividef') attributes(device) real function __fdividef(r,d) bind(c, name='__nv_fast_fdividef')
!dir$ ignore_tkr (d) r, (d) d !dir$ ignore_tkr (d) r, (d) d
@ -406,15 +420,51 @@ implicit none
end function end function
end interface end interface
! Generic interface __log10f with a single specific procedure: a device
! function that takes one default-real argument by value and returns a
! default real. It is bound to the C symbol `__nv_fast_log10f`.
interface __log10f
attributes(device) real function __log10f(r) bind(c, name='__nv_fast_log10f')
! The (d) letter applies only to r; presumably it relaxes CUDA data-attribute
! matching for the actual argument -- confirm against flang's IGNORE_TKR docs.
!dir$ ignore_tkr (d) r
real, value :: r
end function
end interface
! Generic interface __log2f with a single specific procedure: a device
! function that takes one default-real argument by value and returns a
! default real. It is bound to the C symbol `__nv_fast_log2f`.
interface __log2f
attributes(device) real function __log2f(r) bind(c, name='__nv_fast_log2f')
! The (d) letter applies only to r; presumably it relaxes CUDA data-attribute
! matching for the actual argument -- confirm against flang's IGNORE_TKR docs.
!dir$ ignore_tkr (d) r
real, value :: r
end function
end interface
! Generic interface __logf with a single specific procedure: a device
! function that takes one default-real argument by value and returns a
! default real. It is bound to the C symbol `__nv_fast_logf`.
interface __logf
attributes(device) real function __logf(r) bind(c, name='__nv_fast_logf')
! The (d) letter applies only to r; presumably it relaxes CUDA data-attribute
! matching for the actual argument -- confirm against flang's IGNORE_TKR docs.
!dir$ ignore_tkr (d) r
real, value :: r
end function
end interface
! Unnamed interface block declaring the device function __powf: two real(4)
! arguments passed by value, real(4) result, bound to the C symbol
! `__nv_fast_powf`.
interface
attributes(device) real(4) function __powf(x,y) bind(c, name='__nv_fast_powf')
! NOTE(review): only x carries an explicit (d) letter list here, whereas the
! __fdividef declaration in this file repeats it for each dummy
! ("(d) r, (d) d"). A dummy listed without letters may be treated more
! permissively by the compiler -- confirm that this is intended for y.
!dir$ ignore_tkr (d) x, y
real(4), value :: x, y
end function
end interface
! Generic interface __sincosf with a single specific procedure: a device
! subroutine bound to the C symbol `__nv_fast_sincosf`.
!   r    - default real, passed by value (the input).
!   s, c - default reals without the VALUE attribute, so they are passed by
!          reference; presumably they receive the sine and cosine of r --
!          TODO confirm against the bound routine's documentation.
interface __sincosf
attributes(device) subroutine __sincosf(r, s, c) bind(c, name='__nv_fast_sincosf')
! The (d) letter is given separately for each of the three dummies.
!dir$ ignore_tkr (d) r, (d) s, (d) c
real, value :: r
real :: s, c
end subroutine
end interface
interface __sinf interface __sinf
attributes(device) real function __sinf(r) bind(c, name='__nv_sinf') attributes(device) real function __sinf(r) bind(c, name='__nv_fast_sinf')
!dir$ ignore_tkr (d) r !dir$ ignore_tkr (d) r
real, value :: r real, value :: r
end function end function
end interface end interface
interface __tanf interface __tanf
attributes(device) real function __tanf(r) bind(c, name='__nv_tanf') attributes(device) real function __tanf(r) bind(c, name='__nv_fast_tanf')
!dir$ ignore_tkr (d) r !dir$ ignore_tkr (d) r
real, value :: r real, value :: r
end function end function
@ -1078,13 +1128,6 @@ implicit none
end function end function
end interface end interface
interface
attributes(device) real(4) function __powf(x,y) bind(c, name='__nv_powf')
!dir$ ignore_tkr (d) x, y
real(4), value :: x, y
end function
end interface
interface __brev interface __brev
attributes(device) integer function __brev(i) bind(c, name='__nv_brev') attributes(device) integer function __brev(i) bind(c, name='__nv_brev')
!dir$ ignore_tkr (d) i !dir$ ignore_tkr (d) i
@ -1944,41 +1987,6 @@ implicit none
end function end function
end interface end interface
interface __log2f
attributes(device) real function __log2f(r) bind(c, name='__nv_log2f')
!dir$ ignore_tkr (d) r
real, value :: r
end function
end interface
interface __log10f
attributes(device) real function __log10f(r) bind(c, name='__nv_log10f')
!dir$ ignore_tkr (d) r
real, value :: r
end function
end interface
interface __logf
attributes(device) real function __logf(r) bind(c, name='__nv_logf')
!dir$ ignore_tkr (d) r
real, value :: r
end function
end interface
interface __expf
attributes(device) real function __expf(r) bind(c, name='__nv_expf')
!dir$ ignore_tkr (d) r
real, value :: r
end function
end interface
interface __exp10f
attributes(device) real function __exp10f(r) bind(c, name='__nv_exp10f')
!dir$ ignore_tkr (d) r
real, value :: r
end function
end interface
contains contains
attributes(device) subroutine syncthreads() attributes(device) subroutine syncthreads()

View File

@ -140,7 +140,7 @@ end
! CHECK: %{{.*}} = fir.call @__nv_brevll(%{{.*}}) proc_attrs<bind_c> fastmath<contract> : (i64) -> i64 ! CHECK: %{{.*}} = fir.call @__nv_brevll(%{{.*}}) proc_attrs<bind_c> fastmath<contract> : (i64) -> i64
! CHECK: %{{.*}} = fir.call @__nv_clz(%{{.*}}) proc_attrs<bind_c> fastmath<contract> : (i32) -> i32 ! CHECK: %{{.*}} = fir.call @__nv_clz(%{{.*}}) proc_attrs<bind_c> fastmath<contract> : (i32) -> i32
! CHECK: %{{.*}} = fir.call @__nv_clzll(%{{.*}}) proc_attrs<bind_c> fastmath<contract> : (i64) -> i32 ! CHECK: %{{.*}} = fir.call @__nv_clzll(%{{.*}}) proc_attrs<bind_c> fastmath<contract> : (i64) -> i32
! CHECK: %{{.*}} = fir.call @__nv_cosf(%{{.*}}) proc_attrs<bind_c> fastmath<contract> : (f32) -> f32 ! CHECK: %{{.*}} = fir.call @__nv_fast_cosf(%{{.*}}) proc_attrs<bind_c> fastmath<contract> : (f32) -> f32
! CHECK: %{{.*}} = fir.call @__nv_ddiv_rn(%{{.*}}, %{{.*}}) proc_attrs<bind_c> fastmath<contract> : (f64, f64) -> f64 ! CHECK: %{{.*}} = fir.call @__nv_ddiv_rn(%{{.*}}, %{{.*}}) proc_attrs<bind_c> fastmath<contract> : (f64, f64) -> f64
! CHECK: %{{.*}} = fir.call @__nv_ddiv_rz(%{{.*}}, %{{.*}}) proc_attrs<bind_c> fastmath<contract> : (f64, f64) -> f64 ! CHECK: %{{.*}} = fir.call @__nv_ddiv_rz(%{{.*}}, %{{.*}}) proc_attrs<bind_c> fastmath<contract> : (f64, f64) -> f64
! CHECK: %{{.*}} = fir.call @__nv_ddiv_ru(%{{.*}}, %{{.*}}) proc_attrs<bind_c> fastmath<contract> : (f64, f64) -> f64 ! CHECK: %{{.*}} = fir.call @__nv_ddiv_ru(%{{.*}}, %{{.*}}) proc_attrs<bind_c> fastmath<contract> : (f64, f64) -> f64
@ -159,7 +159,7 @@ end
! CHECK: %{{.*}} = fir.call @__nv_double2uint_rz(%{{.*}}) proc_attrs<bind_c> fastmath<contract> : (f64) -> i32 ! CHECK: %{{.*}} = fir.call @__nv_double2uint_rz(%{{.*}}) proc_attrs<bind_c> fastmath<contract> : (f64) -> i32
! CHECK: %{{.*}} = fir.call @__nv_mul24(%{{.*}}, %{{.*}}) proc_attrs<bind_c> fastmath<contract> : (i32, i32) -> i32 ! CHECK: %{{.*}} = fir.call @__nv_mul24(%{{.*}}, %{{.*}}) proc_attrs<bind_c> fastmath<contract> : (i32, i32) -> i32
! CHECK: %{{.*}} = fir.call @__nv_umul24(%{{.*}}, %{{.*}}) proc_attrs<bind_c> fastmath<contract> : (i32, i32) -> i32 ! CHECK: %{{.*}} = fir.call @__nv_umul24(%{{.*}}, %{{.*}}) proc_attrs<bind_c> fastmath<contract> : (i32, i32) -> i32
! CHECK: %{{.*}} = fir.call @__nv_powf(%{{.*}}, %{{.*}}) proc_attrs<bind_c> fastmath<contract> : (f32, f32) -> f32 ! CHECK: %{{.*}} = fir.call @__nv_fast_powf(%{{.*}}, %{{.*}}) proc_attrs<bind_c> fastmath<contract> : (f32, f32) -> f32
! CHECK: %{{.*}} = fir.call @__nv_ull2double_rd(%{{.*}}) proc_attrs<bind_c> fastmath<contract> : (i64) -> f64 ! CHECK: %{{.*}} = fir.call @__nv_ull2double_rd(%{{.*}}) proc_attrs<bind_c> fastmath<contract> : (i64) -> f64
! CHECK: %{{.*}} = fir.call @__nv_ull2double_rn(%{{.*}}) proc_attrs<bind_c> fastmath<contract> : (i64) -> f64 ! CHECK: %{{.*}} = fir.call @__nv_ull2double_rn(%{{.*}}) proc_attrs<bind_c> fastmath<contract> : (i64) -> f64
! CHECK: %{{.*}} = fir.call @__nv_ull2double_ru(%{{.*}}) proc_attrs<bind_c> fastmath<contract> : (i64) -> f64 ! CHECK: %{{.*}} = fir.call @__nv_ull2double_ru(%{{.*}}) proc_attrs<bind_c> fastmath<contract> : (i64) -> f64

View File

@ -83,9 +83,17 @@ attributes(global) subroutine test_log()
end subroutine end subroutine
! CHECK-LABEL: _QPtest_log ! CHECK-LABEL: _QPtest_log
! CHECK: %{{.*}} = fir.call @__nv_logf(%{{.*}}) proc_attrs<bind_c> fastmath<contract> : (f32) -> f32 ! CHECK: %{{.*}} = fir.call @__nv_fast_logf(%{{.*}}) proc_attrs<bind_c> fastmath<contract> : (f32) -> f32
! CHECK: %{{.*}} = fir.call @__nv_log2f(%{{.*}}) proc_attrs<bind_c> fastmath<contract> : (f32) -> f32 ! CHECK: %{{.*}} = fir.call @__nv_fast_log2f(%{{.*}}) proc_attrs<bind_c> fastmath<contract> : (f32) -> f32
! CHECK: %{{.*}} = fir.call @__nv_log10f(%{{.*}}) proc_attrs<bind_c> fastmath<contract> : (f32) -> f32 ! CHECK: %{{.*}} = fir.call @__nv_fast_log10f(%{{.*}}) proc_attrs<bind_c> fastmath<contract> : (f32) -> f32
! Lowering test for the __sincosf interface: a global subroutine that declares
! three default-real locals and passes them to __sincosf. None of the locals is
! assigned beforehand; the FileCheck directives that follow this subroutine
! only match the emitted call and its signature.
attributes(global) subroutine test_sincosf()
real :: r, s, c
call __sincosf(r, s, c)
end subroutine
! CHECK-LABEL: _QPtest_sincosf
! CHECK: fir.call @__nv_fast_sincosf(%{{.*}}, %{{.*}}#0, %{{.*}}#0) proc_attrs<bind_c> fastmath<contract> : (f32, !fir.ref<f32>, !fir.ref<f32>) -> ()
attributes(global) subroutine test_sinf() attributes(global) subroutine test_sinf()
real :: res real :: res
@ -94,7 +102,7 @@ attributes(global) subroutine test_sinf()
end subroutine end subroutine
! CHECK-LABEL: _QPtest_sinf ! CHECK-LABEL: _QPtest_sinf
! CHECK: %{{.*}} = fir.call @__nv_sinf(%{{.*}}) proc_attrs<bind_c> fastmath<contract> : (f32) -> f32 ! CHECK: %{{.*}} = fir.call @__nv_fast_sinf(%{{.*}}) proc_attrs<bind_c> fastmath<contract> : (f32) -> f32
attributes(global) subroutine test_tanf() attributes(global) subroutine test_tanf()
real :: res real :: res
@ -103,7 +111,7 @@ attributes(global) subroutine test_tanf()
end subroutine end subroutine
! CHECK-LABEL: _QPtest_tanf ! CHECK-LABEL: _QPtest_tanf
! CHECK: %{{.*}} = fir.call @__nv_tanf(%{{.*}}) proc_attrs<bind_c> fastmath<contract> : (f32) -> f32 ! CHECK: %{{.*}} = fir.call @__nv_fast_tanf(%{{.*}}) proc_attrs<bind_c> fastmath<contract> : (f32) -> f32
attributes(global) subroutine test_exp() attributes(global) subroutine test_exp()
real :: res real :: res
@ -113,8 +121,8 @@ attributes(global) subroutine test_exp()
end subroutine end subroutine
! CHECK-LABEL: _QPtest_exp ! CHECK-LABEL: _QPtest_exp
! CHECK: %{{.*}} = fir.call @__nv_expf(%{{.*}}) proc_attrs<bind_c> fastmath<contract> : (f32) -> f32 ! CHECK: %{{.*}} = fir.call @__nv_fast_expf(%{{.*}}) proc_attrs<bind_c> fastmath<contract> : (f32) -> f32
! CHECK: %{{.*}} = fir.call @__nv_exp10f(%{{.*}}) proc_attrs<bind_c> fastmath<contract> : (f32) -> f32 ! CHECK: %{{.*}} = fir.call @__nv_fast_exp10f(%{{.*}}) proc_attrs<bind_c> fastmath<contract> : (f32) -> f32
attributes(global) subroutine test_double2ll_rX() attributes(global) subroutine test_double2ll_rX()
integer(8) :: res integer(8) :: res