[flang] Add support for lowering directives at the CONTAINS level (#95123)

There is currently support for lowering directives that appear outside
of a module or procedure, or inside the body of a module or procedure.
Extend this to support directives at the CONTAINS level of a module or
procedure, such as directives 3, 5, 7 9, and 10 in:

    !dir$ some directive 1
    module m
      !dir$ some directive 2
    contains
      !dir$ some directive 3
      subroutine p
        !dir$ some directive 4
      contains
        !dir$ some directive 5
        subroutine s1
          !dir$ some directive 6
        end subroutine s1
        !dir$ some directive 7
        subroutine s2
          !dir$ some directive 8
        end subroutine s2
        !dir$ some directive 9
      end subroutine p
      !dir$ some directive 10
    end module m
    !dir$ some directive 11

This is done by looking for CONTAINS statements at the module or
procedure level, while ignoring CONTAINS statements at the derived type
level.
This commit is contained in:
vdonaldson 2024-06-12 09:35:14 -04:00 committed by GitHub
parent 47afa10bba
commit 87374a8cff
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 243 additions and 91 deletions

View File

@ -31,11 +31,14 @@
namespace Fortran::lower::pft {
struct CompilerDirectiveUnit;
struct Evaluation;
struct Program;
struct ModuleLikeUnit;
struct FunctionLikeUnit;
struct ModuleLikeUnit;
struct Program;
using ContainedUnit = std::variant<CompilerDirectiveUnit, FunctionLikeUnit>;
using ContainedUnitList = std::list<ContainedUnit>;
using EvaluationList = std::list<Evaluation>;
/// Provide a variant like container that can hold references. It can hold
@ -594,8 +597,8 @@ VariableList getDependentVariableList(const Fortran::semantics::Symbol &);
void dump(VariableList &, std::string s = {}); // `s` is an optional dump label
/// Function-like units may contain evaluations (executable statements) and
/// nested function-like units (internal procedures and function statements).
/// Function-like units may contain evaluations (executable statements),
/// directives, and internal (nested) function-like units.
struct FunctionLikeUnit : public ProgramUnit {
// wrapper statements for function-like syntactic structures
using FunctionStatement =
@ -697,10 +700,10 @@ struct FunctionLikeUnit : public ProgramUnit {
std::optional<FunctionStatement> beginStmt;
FunctionStatement endStmt;
const semantics::Scope *scope;
EvaluationList evaluationList;
LabelEvalMap labelEvaluationMap;
SymbolLabelMap assignSymbolLabelMap;
std::list<FunctionLikeUnit> nestedFunctions;
ContainedUnitList containedUnitList;
EvaluationList evaluationList;
/// <Symbol, Evaluation> pairs for each entry point. The pair at index 0
/// is the primary entry point; remaining pairs are alternate entry points.
/// The primary entry point symbol is Null for an anonymous program.
@ -746,7 +749,7 @@ struct ModuleLikeUnit : public ProgramUnit {
ModuleStatement beginStmt;
ModuleStatement endStmt;
std::list<FunctionLikeUnit> nestedFunctions;
ContainedUnitList containedUnitList;
EvaluationList evaluationList;
};

View File

@ -302,28 +302,32 @@ public:
bool hasMainProgram = false;
const Fortran::semantics::Symbol *globalOmpRequiresSymbol = nullptr;
for (Fortran::lower::pft::Program::Units &u : pft.getUnits()) {
std::visit(Fortran::common::visitors{
[&](Fortran::lower::pft::FunctionLikeUnit &f) {
if (f.isMainProgram())
hasMainProgram = true;
declareFunction(f);
if (!globalOmpRequiresSymbol)
globalOmpRequiresSymbol = f.getScope().symbol();
},
[&](Fortran::lower::pft::ModuleLikeUnit &m) {
lowerModuleDeclScope(m);
for (Fortran::lower::pft::FunctionLikeUnit &f :
m.nestedFunctions)
declareFunction(f);
},
[&](Fortran::lower::pft::BlockDataUnit &b) {
if (!globalOmpRequiresSymbol)
globalOmpRequiresSymbol = b.symTab.symbol();
},
[&](Fortran::lower::pft::CompilerDirectiveUnit &d) {},
[&](Fortran::lower::pft::OpenACCDirectiveUnit &d) {},
},
u);
std::visit(
Fortran::common::visitors{
[&](Fortran::lower::pft::FunctionLikeUnit &f) {
if (f.isMainProgram())
hasMainProgram = true;
declareFunction(f);
if (!globalOmpRequiresSymbol)
globalOmpRequiresSymbol = f.getScope().symbol();
},
[&](Fortran::lower::pft::ModuleLikeUnit &m) {
lowerModuleDeclScope(m);
for (Fortran::lower::pft::ContainedUnit &unit :
m.containedUnitList)
if (auto *f =
std::get_if<Fortran::lower::pft::FunctionLikeUnit>(
&unit))
declareFunction(*f);
},
[&](Fortran::lower::pft::BlockDataUnit &b) {
if (!globalOmpRequiresSymbol)
globalOmpRequiresSymbol = b.symTab.symbol();
},
[&](Fortran::lower::pft::CompilerDirectiveUnit &d) {},
[&](Fortran::lower::pft::OpenACCDirectiveUnit &d) {},
},
u);
}
// Create definitions of intrinsic module constants.
@ -387,13 +391,15 @@ public:
// Compute the set of host associated entities from the nested functions.
llvm::SetVector<const Fortran::semantics::Symbol *> escapeHost;
for (Fortran::lower::pft::FunctionLikeUnit &f : funit.nestedFunctions)
collectHostAssociatedVariables(f, escapeHost);
for (Fortran::lower::pft::ContainedUnit &unit : funit.containedUnitList)
if (auto *f = std::get_if<Fortran::lower::pft::FunctionLikeUnit>(&unit))
collectHostAssociatedVariables(*f, escapeHost);
funit.setHostAssociatedSymbols(escapeHost);
// Declare internal procedures
for (Fortran::lower::pft::FunctionLikeUnit &f : funit.nestedFunctions)
declareFunction(f);
for (Fortran::lower::pft::ContainedUnit &unit : funit.containedUnitList)
if (auto *f = std::get_if<Fortran::lower::pft::FunctionLikeUnit>(&unit))
declareFunction(*f);
}
/// Get the scope that is defining or using \p sym. The returned scope is not
@ -5356,8 +5362,9 @@ private:
endNewFunction(funit);
}
funit.setActiveEntry(0);
for (Fortran::lower::pft::FunctionLikeUnit &f : funit.nestedFunctions)
lowerFunc(f); // internal procedure
for (Fortran::lower::pft::ContainedUnit &unit : funit.containedUnitList)
if (auto *f = std::get_if<Fortran::lower::pft::FunctionLikeUnit>(&unit))
lowerFunc(*f); // internal procedure
}
/// Lower module variable definitions to fir::globalOp and OpenMP/OpenACC
@ -5381,8 +5388,9 @@ private:
/// Lower functions contained in a module.
void lowerMod(Fortran::lower::pft::ModuleLikeUnit &mod) {
for (Fortran::lower::pft::FunctionLikeUnit &f : mod.nestedFunctions)
lowerFunc(f);
for (Fortran::lower::pft::ContainedUnit &unit : mod.containedUnitList)
if (auto *f = std::get_if<Fortran::lower::pft::FunctionLikeUnit>(&unit))
lowerFunc(*f);
}
void setCurrentPosition(const Fortran::parser::CharBlock &position) {

View File

@ -209,6 +209,20 @@ public:
}
}
bool Pre(const parser::SpecificationPart &) {
++specificationPartLevel;
return true;
}
void Post(const parser::SpecificationPart &) { --specificationPartLevel; }
bool Pre(const parser::ContainsStmt &) {
if (!specificationPartLevel) {
assert(containsStmtStack.size() && "empty contains stack");
containsStmtStack.back() = true;
}
return false;
}
// Module like
bool Pre(const parser::Module &node) { return enterModule(node); }
bool Pre(const parser::Submodule &node) { return enterModule(node); }
@ -249,15 +263,21 @@ public:
whereBody.u);
}
// CompilerDirective have special handling in case they are top level
// directives (i.e. they do not belong to a ProgramUnit).
// A CompilerDirective may appear outside any program unit, after a module
// or function contains statement, or inside a module or function.
bool Pre(const parser::CompilerDirective &directive) {
assert(pftParentStack.size() > 0 &&
"At least the Program must be a parent");
if (pftParentStack.back().isA<lower::pft::Program>()) {
addUnit(
lower::pft::CompilerDirectiveUnit(directive, pftParentStack.back()));
assert(pftParentStack.size() > 0 && "no program");
lower::pft::PftNode &node = pftParentStack.back();
if (node.isA<lower::pft::Program>()) {
addUnit(lower::pft::CompilerDirectiveUnit(directive, node));
return false;
} else if ((node.isA<lower::pft::ModuleLikeUnit>() ||
node.isA<lower::pft::FunctionLikeUnit>())) {
assert(containsStmtStack.size() && "empty contains stack");
if (containsStmtStack.back()) {
addContainedUnit(lower::pft::CompilerDirectiveUnit{directive, node});
return false;
}
}
return enterConstructOrDirective(directive);
}
@ -277,9 +297,10 @@ private:
/// Initialize a new module-like unit and make it the builder's focus.
template <typename A>
bool enterModule(const A &mod) {
Fortran::lower::pft::ModuleLikeUnit &unit =
lower::pft::ModuleLikeUnit &unit =
addUnit(lower::pft::ModuleLikeUnit{mod, pftParentStack.back()});
functionList = &unit.nestedFunctions;
containsStmtStack.push_back(false);
containedUnitList = &unit.containedUnitList;
pushEvaluationList(&unit.evaluationList);
pftParentStack.emplace_back(unit);
LLVM_DEBUG(dumpScope(&unit.getScope()));
@ -287,6 +308,7 @@ private:
}
void exitModule() {
containsStmtStack.pop_back();
if (!evaluationListStack.empty())
popEvaluationList();
pftParentStack.pop_back();
@ -344,12 +366,13 @@ private:
const semantics::SemanticsContext &semanticsContext) {
cleanModuleEvaluationList();
endFunctionBody(); // enclosing host subprogram body, if any
Fortran::lower::pft::FunctionLikeUnit &unit =
addFunction(lower::pft::FunctionLikeUnit{func, pftParentStack.back(),
semanticsContext});
lower::pft::FunctionLikeUnit &unit =
addContainedUnit(lower::pft::FunctionLikeUnit{
func, pftParentStack.back(), semanticsContext});
labelEvaluationMap = &unit.labelEvaluationMap;
assignSymbolLabelMap = &unit.assignSymbolLabelMap;
functionList = &unit.nestedFunctions;
containsStmtStack.push_back(false);
containedUnitList = &unit.containedUnitList;
pushEvaluationList(&unit.evaluationList);
pftParentStack.emplace_back(unit);
LLVM_DEBUG(dumpScope(&unit.getScope()));
@ -361,6 +384,7 @@ private:
endFunctionBody();
analyzeBranches(nullptr, *evaluationListStack.back()); // add branch links
processEntryPoints();
containsStmtStack.pop_back();
popEvaluationList();
labelEvaluationMap = nullptr;
assignSymbolLabelMap = nullptr;
@ -371,7 +395,7 @@ private:
/// Initialize a new construct or directive and make it the builder's focus.
template <typename A>
bool enterConstructOrDirective(const A &constructOrDirective) {
Fortran::lower::pft::Evaluation &eval = addEvaluation(
lower::pft::Evaluation &eval = addEvaluation(
lower::pft::Evaluation{constructOrDirective, pftParentStack.back()});
eval.evaluationList.reset(new lower::pft::EvaluationList);
pushEvaluationList(eval.evaluationList.get());
@ -381,7 +405,7 @@ private:
}
void exitConstructOrDirective() {
auto isOpenMPLoopConstruct = [](Fortran::lower::pft::Evaluation *eval) {
auto isOpenMPLoopConstruct = [](lower::pft::Evaluation *eval) {
if (const auto *ompConstruct = eval->getIf<parser::OpenMPConstruct>())
if (std::holds_alternative<parser::OpenMPLoopConstruct>(
ompConstruct->u))
@ -396,8 +420,7 @@ private:
// construct region must have an exit target inside the region.
// This is not applicable to the OpenMP loop construct since the
// end of the loop is an available target inside the region.
Fortran::lower::pft::EvaluationList &evaluationList =
*eval->evaluationList;
lower::pft::EvaluationList &evaluationList = *eval->evaluationList;
if (!evaluationList.empty() && evaluationList.back().isConstruct()) {
static const parser::ContinueStmt exitTarget{};
addEvaluation(
@ -413,15 +436,15 @@ private:
void resetFunctionState() {
if (!pftParentStack.empty()) {
pftParentStack.back().visit(common::visitors{
[&](lower::pft::ModuleLikeUnit &p) {
containedUnitList = &p.containedUnitList;
},
[&](lower::pft::FunctionLikeUnit &p) {
functionList = &p.nestedFunctions;
containedUnitList = &p.containedUnitList;
labelEvaluationMap = &p.labelEvaluationMap;
assignSymbolLabelMap = &p.assignSymbolLabelMap;
},
[&](lower::pft::ModuleLikeUnit &p) {
functionList = &p.nestedFunctions;
},
[&](auto &) { functionList = nullptr; },
[&](auto &) { containedUnitList = nullptr; },
});
}
}
@ -433,12 +456,11 @@ private:
}
template <typename A>
A &addFunction(A &&func) {
if (functionList) {
functionList->emplace_back(std::move(func));
return functionList->back();
}
return addUnit(std::move(func));
A &addContainedUnit(A &&unit) {
if (!containedUnitList)
return addUnit(std::move(unit));
containedUnitList->emplace_back(std::move(unit));
return std::get<A>(containedUnitList->back());
}
// ActionStmt has a couple of non-conforming cases, explicitly handled here.
@ -459,7 +481,6 @@ private:
/// Append an Evaluation to the end of the current list.
lower::pft::Evaluation &addEvaluation(lower::pft::Evaluation &&eval) {
assert(functionList && "not in a function");
assert(!evaluationListStack.empty() && "empty evaluation list stack");
if (!constructAndDirectiveStack.empty())
eval.parentConstruct = constructAndDirectiveStack.back();
@ -499,15 +520,15 @@ private:
/// push a new list on the stack of Evaluation lists
void pushEvaluationList(lower::pft::EvaluationList *evaluationList) {
assert(functionList && "not in a function");
assert(evaluationList && evaluationList->empty() &&
"evaluation list isn't correct");
"invalid evaluation list");
evaluationListStack.emplace_back(evaluationList);
}
/// pop the current list and return to the last Evaluation list
void popEvaluationList() {
assert(functionList && "not in a function");
assert(!evaluationListStack.empty() &&
"trying to pop an empty evaluationListStack");
evaluationListStack.pop_back();
}
@ -1089,9 +1110,8 @@ private:
std::vector<lower::pft::PftNode> pftParentStack;
const semantics::SemanticsContext &semanticsContext;
/// functionList points to the internal or module procedure function list
/// of a FunctionLikeUnit or a ModuleLikeUnit. It may be null.
std::list<lower::pft::FunctionLikeUnit> *functionList{};
llvm::SmallVector<bool> containsStmtStack{};
lower::pft::ContainedUnitList *containedUnitList{};
std::vector<lower::pft::Evaluation *> constructAndDirectiveStack{};
std::vector<lower::pft::Evaluation *> doConstructStack{};
/// evaluationListStack is the current nested construct evaluationList state.
@ -1099,6 +1119,7 @@ private:
llvm::DenseMap<parser::Label, lower::pft::Evaluation *> *labelEvaluationMap{};
lower::pft::SymbolLabelMap *assignSymbolLabelMap{};
std::map<std::string, lower::pft::Evaluation *> constructNameMap{};
int specificationPartLevel{};
lower::pft::Evaluation *lastLexicalEvaluation{};
};
@ -1201,11 +1222,15 @@ public:
outputStream << " -> " << eval.controlSuccessor->printIndex;
else if (eval.isA<parser::EntryStmt>() && eval.lexicalSuccessor)
outputStream << " -> " << eval.lexicalSuccessor->printIndex;
bool extraNewline = false;
if (!eval.position.empty())
outputStream << ": " << eval.position.ToString();
else if (auto *dir = eval.getIf<Fortran::parser::CompilerDirective>())
else if (auto *dir = eval.getIf<parser::CompilerDirective>()) {
extraNewline = dir->source.ToString().back() == '\n';
outputStream << ": !" << dir->source.ToString();
outputStream << '\n';
}
if (!extraNewline)
outputStream << '\n';
if (eval.hasNestedEvaluations()) {
dumpEvaluationList(outputStream, *eval.evaluationList, indent + 1);
outputStream << indentString << "<<End " << name << bang << ">>\n";
@ -1265,13 +1290,7 @@ public:
outputStream << ": " << header;
outputStream << '\n';
dumpEvaluationList(outputStream, functionLikeUnit.evaluationList);
if (!functionLikeUnit.nestedFunctions.empty()) {
outputStream << "\nContains\n";
for (const lower::pft::FunctionLikeUnit &func :
functionLikeUnit.nestedFunctions)
dumpFunctionLikeUnit(outputStream, func);
outputStream << "End Contains\n";
}
dumpContainedUnitList(outputStream, functionLikeUnit.containedUnitList);
outputStream << "End " << unitKind << ' ' << name << "\n\n";
}
@ -1298,11 +1317,8 @@ public:
});
outputStream << unitKind << ' ' << name << ": " << header << '\n';
dumpEvaluationList(outputStream, moduleLikeUnit.evaluationList);
outputStream << "Contains\n";
for (const lower::pft::FunctionLikeUnit &func :
moduleLikeUnit.nestedFunctions)
dumpFunctionLikeUnit(outputStream, func);
outputStream << "End Contains\nEnd " << unitKind << ' ' << name << "\n\n";
dumpContainedUnitList(outputStream, moduleLikeUnit.containedUnitList);
outputStream << "End " << unitKind << ' ' << name << "\n\n";
}
// Top level directives
@ -1311,9 +1327,34 @@ public:
const lower::pft::CompilerDirectiveUnit &directive) {
outputStream << getNodeIndex(directive) << " ";
outputStream << "CompilerDirective: !";
outputStream << directive.get<Fortran::parser::CompilerDirective>()
.source.ToString();
outputStream << "\nEnd CompilerDirective\n\n";
bool extraNewline =
directive.get<parser::CompilerDirective>().source.ToString().back() ==
'\n';
outputStream
<< directive.get<parser::CompilerDirective>().source.ToString();
if (!extraNewline)
outputStream << "\n";
outputStream << "\n";
}
void dumpContainedUnitList(
llvm::raw_ostream &outputStream,
const lower::pft::ContainedUnitList &containedUnitList) {
if (containedUnitList.empty())
return;
outputStream << "\nContains\n";
for (const lower::pft::ContainedUnit &unit : containedUnitList)
if (const auto *func = std::get_if<lower::pft::FunctionLikeUnit>(&unit)) {
dumpFunctionLikeUnit(outputStream, *func);
} else if (const auto *dir =
std::get_if<lower::pft::CompilerDirectiveUnit>(&unit)) {
outputStream << getNodeIndex(*dir) << " ";
dumpEvaluation(outputStream,
lower::pft::Evaluation{
dir->get<parser::CompilerDirective>(), dir->parent});
outputStream << "\n";
}
outputStream << "End Contains\n";
}
void
@ -1321,8 +1362,8 @@ public:
const lower::pft::OpenACCDirectiveUnit &directive) {
outputStream << getNodeIndex(directive) << " ";
outputStream << "OpenACCDirective: !$acc ";
outputStream << directive.get<Fortran::parser::OpenACCRoutineConstruct>()
.source.ToString();
outputStream
<< directive.get<parser::OpenACCRoutineConstruct>().source.ToString();
outputStream << "\nEnd OpenACCDirective\n\n";
}

View File

@ -0,0 +1,100 @@
! RUN: bbc -pft-test -o %t %s | FileCheck %s
module mm
!dir$ some directive 1
type t
logical :: tag
contains
final :: fin
end type
!dir$ some directive 2
contains
!dir$ some directive 3
subroutine fin(x)
type(t), intent(inout) :: x
x%tag =.true.
!dir$ some directive 4
call s1
call s2
print*, 'fin', x
contains
!dir$ some directive 5
subroutine s1
print*, 's1'
!dir$ some directive 6
end subroutine s1
!dir$ some directive 7
subroutine s2
!dir$ some directive 8
if (.true.) then
!dir$ some directive 9
print*, 's2'
!dir$ some directive 10
endif
!dir$ some directive 11
end subroutine s2
!dir$ some directive 12
end subroutine fin
!dir$ some directive 13
end module mm
!dir$ some directive 14
end
! CHECK: Module mm: module mm
! CHECK: CompilerDirective: !some directive 1
! CHECK: CompilerDirective: !some directive 2
! CHECK: Contains
! CHECK: CompilerDirective: !some directive 3
! CHECK: Subroutine fin: subroutine fin(x)
! CHECK: AssignmentStmt: x%tag =.true.
! CHECK: CompilerDirective: !some directive 4
! CHECK: CallStmt: call s1
! CHECK: CallStmt: call s2
! CHECK: PrintStmt: print*, 'fin', x
! CHECK: EndSubroutineStmt: end subroutine fin
! CHECK: Contains
! CHECK: CompilerDirective: !some directive 5
! CHECK: Subroutine s1: subroutine s1
! CHECK: PrintStmt: print*, 's1'
! CHECK: CompilerDirective: !some directive 6
! CHECK: EndSubroutineStmt: end subroutine s1
! CHECK: End Subroutine s1
! CHECK: CompilerDirective: !some directive 7
! CHECK: Subroutine s2: subroutine s2
! CHECK: CompilerDirective: !some directive 8
! CHECK: <<IfConstruct>> -> 7
! CHECK: IfThenStmt -> 7: if(.true.) then
! CHECK: ^CompilerDirective: !some directive 9
! CHECK: PrintStmt: print*, 's2'
! CHECK: CompilerDirective: !some directive 10
! CHECK: EndIfStmt: endif
! CHECK: <<End IfConstruct>>
! CHECK: CompilerDirective: !some directive 11
! CHECK: EndSubroutineStmt: end subroutine s2
! CHECK: End Subroutine s2
! CHECK: CompilerDirective: !some directive 12
! CHECK: End Contains
! CHECK: End Subroutine fin
! CHECK: CompilerDirective: !some directive 13
! CHECK: End Contains
! CHECK: End Module mm
! CHECK: CompilerDirective: !some directive 14
! CHECK: Program <anonymous>
! CHECK: EndProgramStmt: end
! CHECK: End Program <anonymous>