[flang][cuda] Lower clock() to NNVM op (#149228)

Also use a same gen function for all NVVM time ops.
This commit is contained in:
Valentin Clement (バレンタイン クレメン) 2025-07-16 17:24:17 -07:00 committed by GitHub
parent b52cf756ce
commit 4cf7670b01
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 28 additions and 14 deletions

View File

@ -282,7 +282,6 @@ struct IntrinsicLibrary {
llvm::ArrayRef<mlir::Value> args);
mlir::Value genGetUID(mlir::Type resultType,
llvm::ArrayRef<mlir::Value> args);
mlir::Value genGlobalTimer(mlir::Type, llvm::ArrayRef<mlir::Value>);
fir::ExtendedValue genHostnm(std::optional<mlir::Type> resultType,
llvm::ArrayRef<fir::ExtendedValue> args);
fir::ExtendedValue genIall(mlir::Type, llvm::ArrayRef<fir::ExtendedValue>);
@ -377,6 +376,8 @@ struct IntrinsicLibrary {
fir::ExtendedValue genNorm2(mlir::Type, llvm::ArrayRef<fir::ExtendedValue>);
mlir::Value genNot(mlir::Type, llvm::ArrayRef<mlir::Value>);
fir::ExtendedValue genNull(mlir::Type, llvm::ArrayRef<fir::ExtendedValue>);
template <typename OpTy>
mlir::Value genNVVMTime(mlir::Type, llvm::ArrayRef<mlir::Value>);
fir::ExtendedValue genPack(mlir::Type, llvm::ArrayRef<fir::ExtendedValue>);
fir::ExtendedValue genParity(mlir::Type, llvm::ArrayRef<fir::ExtendedValue>);
void genPerror(llvm::ArrayRef<fir::ExtendedValue>);

View File

@ -385,6 +385,7 @@ static constexpr IntrinsicHandler handlers[]{
&I::genChdir,
{{{"name", asAddr}, {"status", asAddr, handleDynamicOptional}}},
/*isElemental=*/false},
{"clock", &I::genNVVMTime<mlir::NVVM::ClockOp>, {}, /*isElemental=*/false},
{"clock64", &I::genClock64, {}, /*isElemental=*/false},
{"cmplx",
&I::genCmplx,
@ -503,7 +504,10 @@ static constexpr IntrinsicHandler handlers[]{
{"getgid", &I::genGetGID},
{"getpid", &I::genGetPID},
{"getuid", &I::genGetUID},
{"globaltimer", &I::genGlobalTimer, {}, /*isElemental=*/false},
{"globaltimer",
&I::genNVVMTime<mlir::NVVM::GlobalTimerOp>,
{},
/*isElemental=*/false},
{"hostnm",
&I::genHostnm,
{{{"c", asBox}, {"status", asAddr, handleDynamicOptional}}},
@ -4320,13 +4324,6 @@ mlir::Value IntrinsicLibrary::genGetUID(mlir::Type resultType,
fir::runtime::genGetUID(builder, loc));
}
// GLOBALTIMER
mlir::Value IntrinsicLibrary::genGlobalTimer(mlir::Type resultType,
llvm::ArrayRef<mlir::Value> args) {
assert(args.size() == 0 && "globalTimer takes no args");
return builder.create<mlir::NVVM::GlobalTimerOp>(loc, resultType).getResult();
}
// GET_COMMAND_ARGUMENT
void IntrinsicLibrary::genGetCommandArgument(
llvm::ArrayRef<fir::ExtendedValue> args) {
@ -7207,6 +7204,14 @@ IntrinsicLibrary::genNull(mlir::Type, llvm::ArrayRef<fir::ExtendedValue> args) {
return fir::MutableBoxValue(boxStorage, mold->nonDeferredLenParams(), {});
}
// CLOCK, GLOBALTIMER
template <typename OpTy>
mlir::Value IntrinsicLibrary::genNVVMTime(mlir::Type resultType,
llvm::ArrayRef<mlir::Value> args) {
assert(args.size() == 0 && "expect no arguments");
return builder.create<OpTy>(loc, resultType).getResult();
}
// PACK
fir::ExtendedValue
IntrinsicLibrary::genPack(mlir::Type resultType,

View File

@ -957,11 +957,21 @@ implicit none
! Time function
interface
attributes(device) integer function clock()
end function
end interface
interface
attributes(device) integer(8) function clock64()
end function
end interface
interface
attributes(device) integer(8) function globalTimer()
end function
end interface
! Warp Match Functions
interface match_all_sync
@ -1613,11 +1623,6 @@ implicit none
end function
end interface
interface
attributes(device) integer(8) function globalTimer()
end function
end interface
contains
attributes(device) subroutine syncthreads()

View File

@ -10,6 +10,7 @@ attributes(global) subroutine devsub()
integer(4) :: ai
integer(8) :: al
integer(8) :: time
integer :: smalltime
call syncthreads()
call syncwarp(1)
@ -45,6 +46,7 @@ attributes(global) subroutine devsub()
ai = atomicinc(ai, 1_4)
ai = atomicdec(ai, 1_4)
smalltime = clock()
time = clock64()
time = globalTimer()
@ -84,6 +86,7 @@ end
! CHECK: %{{.*}} = llvm.atomicrmw uinc_wrap %{{.*}}, %{{.*}} seq_cst : !llvm.ptr, i32
! CHECK: %{{.*}} = llvm.atomicrmw udec_wrap %{{.*}}, %{{.*}} seq_cst : !llvm.ptr, i32
! CHECK: %{{.*}} = nvvm.read.ptx.sreg.clock : i32
! CHECK: fir.call @llvm.nvvm.read.ptx.sreg.clock64()
! CHECK: %{{.*}} = nvvm.read.ptx.sreg.globaltimer : i64