diff --git a/flang/include/flang/Common/Fortran-features.h b/flang/include/flang/Common/Fortran-features.h
index 07ed7f43c1e7..f57fcdc895ad 100644
--- a/flang/include/flang/Common/Fortran-features.h
+++ b/flang/include/flang/Common/Fortran-features.h
@@ -49,7 +49,7 @@ ENUM_CLASS(LanguageFeature, BackslashEscapes, OldDebugLines,
     IndistinguishableSpecifics, SubroutineAndFunctionSpecifics,
     EmptySequenceType, NonSequenceCrayPointee, BranchIntoConstruct,
     BadBranchTarget, ConvertedArgument, HollerithPolymorphic, ListDirectedSize,
-    NonBindCInteroperability)
+    NonBindCInteroperability, CudaManaged, CudaUnified)
 
 // Portability and suspicious usage warnings
 ENUM_CLASS(UsageWarning, Portability, PointerToUndefinable,
@@ -81,6 +81,8 @@ public:
     disable_.set(LanguageFeature::OpenACC);
     disable_.set(LanguageFeature::OpenMP);
     disable_.set(LanguageFeature::CUDA); // !@cuf
+    disable_.set(LanguageFeature::CudaManaged);
+    disable_.set(LanguageFeature::CudaUnified);
     disable_.set(LanguageFeature::ImplicitNoneTypeNever);
     disable_.set(LanguageFeature::ImplicitNoneTypeAlways);
     disable_.set(LanguageFeature::DefaultSave);
diff --git a/flang/include/flang/Common/Fortran.h b/flang/include/flang/Common/Fortran.h
index 3b965fe60c2f..0701e3e8b64c 100644
--- a/flang/include/flang/Common/Fortran.h
+++ b/flang/include/flang/Common/Fortran.h
@@ -19,6 +19,7 @@
 #include <string>
 
 namespace Fortran::common {
+class LanguageFeatureControl;
 
 // Fortran has five kinds of intrinsic data types, plus the derived types.
 ENUM_CLASS(TypeCategory, Integer, Real, Complex, Character, Logical, Derived)
@@ -115,7 +116,8 @@ static constexpr IgnoreTKRSet ignoreTKRAll{IgnoreTKR::Type, IgnoreTKR::Kind,
 std::string AsFortran(IgnoreTKRSet);
 
 bool AreCompatibleCUDADataAttrs(std::optional<CUDADataAttr>,
-    std::optional<CUDADataAttr>, IgnoreTKRSet, bool allowUnifiedMatchingRule);
+    std::optional<CUDADataAttr>, IgnoreTKRSet, bool allowUnifiedMatchingRule,
+    const LanguageFeatureControl *features = nullptr);
 
 static constexpr char blankCommonObjectName[] = "__BLNK__";
 
diff --git a/flang/lib/Common/Fortran.cpp b/flang/lib/Common/Fortran.cpp
index 170ce8c22509..c014b1263a67 100644
--- a/flang/lib/Common/Fortran.cpp
+++ b/flang/lib/Common/Fortran.cpp
@@ -7,6 +7,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "flang/Common/Fortran.h"
+#include "flang/Common/Fortran-features.h"
 
 namespace Fortran::common {
 
@@ -102,7 +103,13 @@ std::string AsFortran(IgnoreTKRSet tkr) {
 /// dummy argument attribute while `y` represents the actual argument attribute.
 bool AreCompatibleCUDADataAttrs(std::optional<CUDADataAttr> x,
     std::optional<CUDADataAttr> y, IgnoreTKRSet ignoreTKR,
-    bool allowUnifiedMatchingRule) {
+    bool allowUnifiedMatchingRule, const LanguageFeatureControl *features) {
+  // When CudaManaged or CudaUnified is enabled, the unified matching rule
+  // below also accepts an actual argument with no CUDA data attribute.
+  bool isCudaManaged{
+      features && features->IsEnabled(common::LanguageFeature::CudaManaged)};
+  bool isCudaUnified{
+      features && features->IsEnabled(common::LanguageFeature::CudaUnified)};
   if (!x && !y) {
     return true;
   } else if (x && y && *x == *y) {
@@ -120,19 +127,27 @@ bool AreCompatibleCUDADataAttrs(std::optional<CUDADataAttr> x,
     return true;
   } else if (allowUnifiedMatchingRule) {
     if (!x) { // Dummy argument has no attribute -> host
-      if (y && (*y == CUDADataAttr::Managed || *y == CUDADataAttr::Unified)) {
+      if ((y && (*y == CUDADataAttr::Managed || *y == CUDADataAttr::Unified)) ||
+          (!y && (isCudaUnified || isCudaManaged))) {
         return true;
       }
     } else {
-      if (*x == CUDADataAttr::Device && y &&
-          (*y == CUDADataAttr::Managed || *y == CUDADataAttr::Unified)) {
-        return true;
-      } else if (*x == CUDADataAttr::Managed && y &&
-          *y == CUDADataAttr::Unified) {
-        return true;
-      } else if (*x == CUDADataAttr::Unified && y &&
-          *y == CUDADataAttr::Managed) {
-        return true;
+      if (*x == CUDADataAttr::Device) {
+        if ((y &&
+                (*y == CUDADataAttr::Managed || *y == CUDADataAttr::Unified)) ||
+            (!y && (isCudaUnified || isCudaManaged))) {
+          return true;
+        }
+      } else if (*x == CUDADataAttr::Managed) {
+        if ((y && *y == CUDADataAttr::Unified) ||
+            (!y && (isCudaUnified || isCudaManaged))) {
+          return true;
+        }
+      } else if (*x == CUDADataAttr::Unified) {
+        if ((y && *y == CUDADataAttr::Managed) ||
+            (!y && (isCudaUnified || isCudaManaged))) {
+          return true;
+        }
       }
     }
     return false;
diff --git a/flang/lib/Semantics/check-call.cpp b/flang/lib/Semantics/check-call.cpp
index 94afcbb68b34..8f51ef5ebeba 100644
--- a/flang/lib/Semantics/check-call.cpp
+++ b/flang/lib/Semantics/check-call.cpp
@@ -914,7 +914,7 @@ static void CheckExplicitDataArg(const characteristics::DummyDataObject &dummy,
       }
       if (!common::AreCompatibleCUDADataAttrs(dummyDataAttr, actualDataAttr,
               dummy.ignoreTKR,
-              /*allowUnifiedMatchingRule=*/true)) {
+              /*allowUnifiedMatchingRule=*/true, &context.languageFeatures())) {
         auto toStr{[](std::optional<common::CUDADataAttr> x) {
           return x ? "ATTRIBUTES("s +
                   parser::ToUpperCaseLetters(common::EnumToString(*x)) + ")"s
diff --git a/flang/lib/Semantics/expression.cpp b/flang/lib/Semantics/expression.cpp
index c503ea3f0246..06e38da6626a 100644
--- a/flang/lib/Semantics/expression.cpp
+++ b/flang/lib/Semantics/expression.cpp
@@ -2501,8 +2501,15 @@ static constexpr int cudaInfMatchingValue{std::numeric_limits<int>::max()};
 
 // Compute the matching distance as described in section 3.2.3 of the CUDA
 // Fortran references.
-static int GetMatchingDistance(const characteristics::DummyArgument &dummy,
+static int GetMatchingDistance(const common::LanguageFeatureControl &features,
+    const characteristics::DummyArgument &dummy,
     const std::optional<ActualArgument> &actual) {
+  // With CudaManaged or CudaUnified enabled, an actual argument without a CUDA
+  // data attribute gets a finite distance to device/managed/unified dummies.
+  bool isCudaManaged{features.IsEnabled(common::LanguageFeature::CudaManaged)};
+  bool isCudaUnified{features.IsEnabled(common::LanguageFeature::CudaUnified)};
+  CHECK(!(isCudaUnified && isCudaManaged) && "expect only one enabled.");
+
   std::optional<common::CUDADataAttr> actualDataAttr, dummyDataAttr;
   if (actual) {
     if (auto *expr{actual->UnwrapExpr()}) {
@@ -2529,6 +2536,9 @@ static int GetMatchingDistance(const characteristics::DummyArgument &dummy,
 
   if (!dummyDataAttr) {
     if (!actualDataAttr) {
+      if (isCudaUnified || isCudaManaged) {
+        return 3;
+      }
       return 0;
     } else if (*actualDataAttr == common::CUDADataAttr::Device) {
       return cudaInfMatchingValue;
@@ -2538,6 +2548,9 @@ static int GetMatchingDistance(const characteristics::DummyArgument &dummy,
     }
   } else if (*dummyDataAttr == common::CUDADataAttr::Device) {
     if (!actualDataAttr) {
+      if (isCudaUnified || isCudaManaged) {
+        return 2;
+      }
       return cudaInfMatchingValue;
     } else if (*actualDataAttr == common::CUDADataAttr::Device) {
       return 0;
@@ -2546,7 +2559,10 @@ static int GetMatchingDistance(const characteristics::DummyArgument &dummy,
       return 2;
     }
   } else if (*dummyDataAttr == common::CUDADataAttr::Managed) {
-    if (!actualDataAttr || *actualDataAttr == common::CUDADataAttr::Device) {
+    if (!actualDataAttr) {
+      return isCudaUnified ? 1 : isCudaManaged ? 0 : cudaInfMatchingValue;
+    }
+    if (*actualDataAttr == common::CUDADataAttr::Device) {
       return cudaInfMatchingValue;
     } else if (*actualDataAttr == common::CUDADataAttr::Managed) {
       return 0;
@@ -2554,7 +2570,10 @@ static int GetMatchingDistance(const characteristics::DummyArgument &dummy,
       return 1;
     }
   } else if (*dummyDataAttr == common::CUDADataAttr::Unified) {
-    if (!actualDataAttr || *actualDataAttr == common::CUDADataAttr::Device) {
+    if (!actualDataAttr) {
+      return isCudaUnified ? 0 : isCudaManaged ? 1 : cudaInfMatchingValue;
+    }
+    if (*actualDataAttr == common::CUDADataAttr::Device) {
       return cudaInfMatchingValue;
     } else if (*actualDataAttr == common::CUDADataAttr::Managed) {
       return 1;
@@ -2566,6 +2585,7 @@ static int GetMatchingDistance(const characteristics::DummyArgument &dummy,
 }
 
 static int ComputeCudaMatchingDistance(
+    const common::LanguageFeatureControl &features,
     const characteristics::Procedure &procedure,
     const ActualArguments &actuals) {
   const auto &dummies{procedure.dummyArguments};
@@ -2574,7 +2594,7 @@ static int ComputeCudaMatchingDistance(
   for (std::size_t i{0}; i < dummies.size(); ++i) {
     const characteristics::DummyArgument &dummy{dummies[i]};
     const std::optional<ActualArgument> &actual{actuals[i]};
-    int d{GetMatchingDistance(dummy, actual)};
+    int d{GetMatchingDistance(features, dummy, actual)};
     if (d == cudaInfMatchingValue)
       return d;
     distance += d;
@@ -2666,7 +2686,8 @@ std::pair<const Symbol *, bool> ExpressionAnalyzer::ResolveGeneric(
             CheckCompatibleArguments(*procedure, localActuals)) {
           if ((procedure->IsElemental() && elemental) ||
               (!procedure->IsElemental() && nonElemental)) {
-            int d{ComputeCudaMatchingDistance(*procedure, localActuals)};
+            int d{ComputeCudaMatchingDistance(
+                context_.languageFeatures(), *procedure, localActuals)};
             if (d != crtMatchingDistance) {
               if (d > crtMatchingDistance) {
                 continue;
@@ -2688,8 +2709,8 @@ std::pair<const Symbol *, bool> ExpressionAnalyzer::ResolveGeneric(
             } else {
               elemental = &specific;
             }
-            crtMatchingDistance =
-                ComputeCudaMatchingDistance(*procedure, localActuals);
+            crtMatchingDistance = ComputeCudaMatchingDistance(
+                context_.languageFeatures(), *procedure, localActuals);
           }
         }
       }
diff --git a/flang/test/Semantics/cuf14.cuf b/flang/test/Semantics/cuf14.cuf
new file mode 100644
index 000000000000..29c9ecf90677
--- /dev/null
+++ b/flang/test/Semantics/cuf14.cuf
@@ -0,0 +1,55 @@
+! RUN: bbc -emit-hlfir -fcuda -gpu=unified %s -o - | FileCheck %s
+
+module matching
+  interface host_and_device
+    module procedure sub_host
+    module procedure sub_device
+  end interface
+
+  interface all
+    module procedure sub_host
+    module procedure sub_device
+    module procedure sub_managed
+    module procedure sub_unified
+  end interface
+
+  interface all_without_unified
+    module procedure sub_host
+    module procedure sub_device
+    module procedure sub_managed
+  end interface
+
+contains
+  subroutine sub_host(a)
+    integer :: a(:)
+  end
+
+  subroutine sub_device(a)
+    integer, device :: a(:)
+  end
+
+  subroutine sub_managed(a)
+    integer, managed :: a(:)
+  end
+
+  subroutine sub_unified(a)
+    integer, unified :: a(:)
+  end
+end module
+
+program m
+  use matching
+
+  integer, allocatable :: actual_host(:)
+
+  allocate(actual_host(10))
+
+  call host_and_device(actual_host) ! Should resolve to sub_device
+  call all(actual_host) ! Should resolve to sub_unified
+  call all_without_unified(actual_host) ! Should resolve to sub_managed
+end
+
+! CHECK: fir.call @_QMmatchingPsub_device
+! CHECK: fir.call @_QMmatchingPsub_unified
+! CHECK: fir.call @_QMmatchingPsub_managed
+
diff --git a/flang/test/Semantics/cuf15.cuf b/flang/test/Semantics/cuf15.cuf
new file mode 100644
index 000000000000..030dd6ff8ffe
--- /dev/null
+++ b/flang/test/Semantics/cuf15.cuf
@@ -0,0 +1,55 @@
+! RUN: bbc -emit-hlfir -fcuda -gpu=managed %s -o - | FileCheck %s
+
+module matching
+  interface host_and_device
+    module procedure sub_host
+    module procedure sub_device
+  end interface
+
+  interface all
+    module procedure sub_host
+    module procedure sub_device
+    module procedure sub_managed
+    module procedure sub_unified
+  end interface
+
+  interface all_without_managed
+    module procedure sub_host
+    module procedure sub_device
+    module procedure sub_unified
+  end interface
+
+contains
+  subroutine sub_host(a)
+    integer :: a(:)
+  end
+
+  subroutine sub_device(a)
+    integer, device :: a(:)
+  end
+
+  subroutine sub_managed(a)
+    integer, managed :: a(:)
+  end
+
+  subroutine sub_unified(a)
+    integer, unified :: a(:)
+  end
+end module
+
+program m
+  use matching
+
+  integer, allocatable :: actual_host(:)
+
+  allocate(actual_host(10))
+
+  call host_and_device(actual_host) ! Should resolve to sub_device
+  call all(actual_host) ! Should resolve to sub_managed
+  call all_without_managed(actual_host) ! Should resolve to sub_unified
+end
+
+! CHECK: fir.call @_QMmatchingPsub_device
+! CHECK: fir.call @_QMmatchingPsub_managed
+! CHECK: fir.call @_QMmatchingPsub_unified
+
diff --git a/flang/tools/bbc/bbc.cpp b/flang/tools/bbc/bbc.cpp
index ee2ff8562e9f..f7092d35eeb5 100644
--- a/flang/tools/bbc/bbc.cpp
+++ b/flang/tools/bbc/bbc.cpp
@@ -204,6 +204,10 @@ static llvm::cl::opt<bool> enableCUDA("fcuda",
                                       llvm::cl::desc("enable CUDA Fortran"),
                                       llvm::cl::init(false));
 
+static llvm::cl::opt<std::string>
+    enableGPUMode("gpu", llvm::cl::desc("Enable GPU Mode managed|unified"),
+                  llvm::cl::init(""));
+
 static llvm::cl::opt<bool> fixedForm("ffixed-form",
                                      llvm::cl::desc("enable fixed form"),
                                      llvm::cl::init(false));
@@ -495,6 +499,12 @@ int main(int argc, char **argv) {
     options.features.Enable(Fortran::common::LanguageFeature::CUDA);
   }
 
+  if (enableGPUMode == "managed") {
+    options.features.Enable(Fortran::common::LanguageFeature::CudaManaged);
+  } else if (enableGPUMode == "unified") {
+    options.features.Enable(Fortran::common::LanguageFeature::CudaUnified);
+  }
+
   if (fixedForm) {
     options.isFixedForm = fixedForm;
   }