[flang][cuda] Lower clock() to NVVM op (#149228)

Also use the same gen function for all NVVM time ops.
This commit is contained in:
Valentin Clement (バレンタイン クレメン) 2025-07-16 17:24:17 -07:00 committed by GitHub
parent b52cf756ce
commit 4cf7670b01
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 28 additions and 14 deletions

View File

@ -282,7 +282,6 @@ struct IntrinsicLibrary {
llvm::ArrayRef<mlir::Value> args); llvm::ArrayRef<mlir::Value> args);
mlir::Value genGetUID(mlir::Type resultType, mlir::Value genGetUID(mlir::Type resultType,
llvm::ArrayRef<mlir::Value> args); llvm::ArrayRef<mlir::Value> args);
mlir::Value genGlobalTimer(mlir::Type, llvm::ArrayRef<mlir::Value>);
fir::ExtendedValue genHostnm(std::optional<mlir::Type> resultType, fir::ExtendedValue genHostnm(std::optional<mlir::Type> resultType,
llvm::ArrayRef<fir::ExtendedValue> args); llvm::ArrayRef<fir::ExtendedValue> args);
fir::ExtendedValue genIall(mlir::Type, llvm::ArrayRef<fir::ExtendedValue>); fir::ExtendedValue genIall(mlir::Type, llvm::ArrayRef<fir::ExtendedValue>);
@ -377,6 +376,8 @@ struct IntrinsicLibrary {
fir::ExtendedValue genNorm2(mlir::Type, llvm::ArrayRef<fir::ExtendedValue>); fir::ExtendedValue genNorm2(mlir::Type, llvm::ArrayRef<fir::ExtendedValue>);
mlir::Value genNot(mlir::Type, llvm::ArrayRef<mlir::Value>); mlir::Value genNot(mlir::Type, llvm::ArrayRef<mlir::Value>);
fir::ExtendedValue genNull(mlir::Type, llvm::ArrayRef<fir::ExtendedValue>); fir::ExtendedValue genNull(mlir::Type, llvm::ArrayRef<fir::ExtendedValue>);
template <typename OpTy>
mlir::Value genNVVMTime(mlir::Type, llvm::ArrayRef<mlir::Value>);
fir::ExtendedValue genPack(mlir::Type, llvm::ArrayRef<fir::ExtendedValue>); fir::ExtendedValue genPack(mlir::Type, llvm::ArrayRef<fir::ExtendedValue>);
fir::ExtendedValue genParity(mlir::Type, llvm::ArrayRef<fir::ExtendedValue>); fir::ExtendedValue genParity(mlir::Type, llvm::ArrayRef<fir::ExtendedValue>);
void genPerror(llvm::ArrayRef<fir::ExtendedValue>); void genPerror(llvm::ArrayRef<fir::ExtendedValue>);

View File

@ -385,6 +385,7 @@ static constexpr IntrinsicHandler handlers[]{
&I::genChdir, &I::genChdir,
{{{"name", asAddr}, {"status", asAddr, handleDynamicOptional}}}, {{{"name", asAddr}, {"status", asAddr, handleDynamicOptional}}},
/*isElemental=*/false}, /*isElemental=*/false},
{"clock", &I::genNVVMTime<mlir::NVVM::ClockOp>, {}, /*isElemental=*/false},
{"clock64", &I::genClock64, {}, /*isElemental=*/false}, {"clock64", &I::genClock64, {}, /*isElemental=*/false},
{"cmplx", {"cmplx",
&I::genCmplx, &I::genCmplx,
@ -503,7 +504,10 @@ static constexpr IntrinsicHandler handlers[]{
{"getgid", &I::genGetGID}, {"getgid", &I::genGetGID},
{"getpid", &I::genGetPID}, {"getpid", &I::genGetPID},
{"getuid", &I::genGetUID}, {"getuid", &I::genGetUID},
{"globaltimer", &I::genGlobalTimer, {}, /*isElemental=*/false}, {"globaltimer",
&I::genNVVMTime<mlir::NVVM::GlobalTimerOp>,
{},
/*isElemental=*/false},
{"hostnm", {"hostnm",
&I::genHostnm, &I::genHostnm,
{{{"c", asBox}, {"status", asAddr, handleDynamicOptional}}}, {{{"c", asBox}, {"status", asAddr, handleDynamicOptional}}},
@ -4320,13 +4324,6 @@ mlir::Value IntrinsicLibrary::genGetUID(mlir::Type resultType,
fir::runtime::genGetUID(builder, loc)); fir::runtime::genGetUID(builder, loc));
} }
// GLOBALTIMER
// Builds an mlir::NVVM::GlobalTimerOp whose result has type `resultType` and
// returns that result. The intrinsic takes no arguments, so `args` must be
// empty.
mlir::Value IntrinsicLibrary::genGlobalTimer(mlir::Type resultType,
                                             llvm::ArrayRef<mlir::Value> args) {
  assert(args.empty() && "globalTimer takes no args");
  auto timerOp = builder.create<mlir::NVVM::GlobalTimerOp>(loc, resultType);
  return timerOp.getResult();
}
// GET_COMMAND_ARGUMENT // GET_COMMAND_ARGUMENT
void IntrinsicLibrary::genGetCommandArgument( void IntrinsicLibrary::genGetCommandArgument(
llvm::ArrayRef<fir::ExtendedValue> args) { llvm::ArrayRef<fir::ExtendedValue> args) {
@ -7207,6 +7204,14 @@ IntrinsicLibrary::genNull(mlir::Type, llvm::ArrayRef<fir::ExtendedValue> args) {
return fir::MutableBoxValue(boxStorage, mold->nonDeferredLenParams(), {}); return fir::MutableBoxValue(boxStorage, mold->nonDeferredLenParams(), {});
} }
// CLOCK, GLOBALTIMER
// Shared lowering for the argument-less NVVM time intrinsics. `OpTy` is the
// NVVM operation to create (the handler table instantiates it with
// mlir::NVVM::ClockOp and mlir::NVVM::GlobalTimerOp); the op is built with
// `resultType` as its result type and its single result is returned.
template <typename OpTy>
mlir::Value IntrinsicLibrary::genNVVMTime(mlir::Type resultType,
                                          llvm::ArrayRef<mlir::Value> args) {
  assert(args.empty() && "expect no arguments");
  auto timeOp = builder.create<OpTy>(loc, resultType);
  return timeOp.getResult();
}
// PACK // PACK
fir::ExtendedValue fir::ExtendedValue
IntrinsicLibrary::genPack(mlir::Type resultType, IntrinsicLibrary::genPack(mlir::Type resultType,

View File

@ -957,11 +957,21 @@ implicit none
! Time function ! Time function
! Device function clock(): takes no arguments and returns a default-kind
! integer.
interface
attributes(device) integer function clock()
end function
end interface
interface interface
attributes(device) integer(8) function clock64() attributes(device) integer(8) function clock64()
end function end function
end interface end interface
! Device function globalTimer(): takes no arguments and returns an
! integer(8) value.
interface
attributes(device) integer(8) function globalTimer()
end function
end interface
! Warp Match Functions ! Warp Match Functions
interface match_all_sync interface match_all_sync
@ -1613,11 +1623,6 @@ implicit none
end function end function
end interface end interface
! Device function globalTimer(): takes no arguments and returns an
! integer(8) value.
interface
attributes(device) integer(8) function globalTimer()
end function
end interface
contains contains
attributes(device) subroutine syncthreads() attributes(device) subroutine syncthreads()

View File

@ -10,6 +10,7 @@ attributes(global) subroutine devsub()
integer(4) :: ai integer(4) :: ai
integer(8) :: al integer(8) :: al
integer(8) :: time integer(8) :: time
integer :: smalltime
call syncthreads() call syncthreads()
call syncwarp(1) call syncwarp(1)
@ -45,6 +46,7 @@ attributes(global) subroutine devsub()
ai = atomicinc(ai, 1_4) ai = atomicinc(ai, 1_4)
ai = atomicdec(ai, 1_4) ai = atomicdec(ai, 1_4)
smalltime = clock()
time = clock64() time = clock64()
time = globalTimer() time = globalTimer()
@ -84,6 +86,7 @@ end
! CHECK: %{{.*}} = llvm.atomicrmw uinc_wrap %{{.*}}, %{{.*}} seq_cst : !llvm.ptr, i32 ! CHECK: %{{.*}} = llvm.atomicrmw uinc_wrap %{{.*}}, %{{.*}} seq_cst : !llvm.ptr, i32
! CHECK: %{{.*}} = llvm.atomicrmw udec_wrap %{{.*}}, %{{.*}} seq_cst : !llvm.ptr, i32 ! CHECK: %{{.*}} = llvm.atomicrmw udec_wrap %{{.*}}, %{{.*}} seq_cst : !llvm.ptr, i32
! CHECK: %{{.*}} = nvvm.read.ptx.sreg.clock : i32
! CHECK: fir.call @llvm.nvvm.read.ptx.sreg.clock64() ! CHECK: fir.call @llvm.nvvm.read.ptx.sreg.clock64()
! CHECK: %{{.*}} = nvvm.read.ptx.sreg.globaltimer : i64 ! CHECK: %{{.*}} = nvvm.read.ptx.sreg.globaltimer : i64