[flang][openmp] Changes for invoking scan Op (#123254)

This commit is contained in:
Anchu Rajendran S 2025-02-05 06:55:32 -08:00 committed by GitHub
parent 290a0d8752
commit ccd92ec4c6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 128 additions and 47 deletions

View File

@ -344,6 +344,20 @@ bool ClauseProcessor::processDistSchedule(
return false;
}
bool ClauseProcessor::processExclusive(
mlir::Location currentLocation,
mlir::omp::ExclusiveClauseOps &result) const {
if (auto *clause = findUniqueClause<omp::clause::Exclusive>()) {
for (const Object &object : clause->v) {
const semantics::Symbol *symbol = object.sym();
mlir::Value symVal = converter.getSymbolAddress(*symbol);
result.exclusiveVars.push_back(symVal);
}
return true;
}
return false;
}
bool ClauseProcessor::processFilter(lower::StatementContext &stmtCtx,
mlir::omp::FilterClauseOps &result) const {
if (auto *clause = findUniqueClause<omp::clause::Filter>()) {
@ -380,6 +394,20 @@ bool ClauseProcessor::processHint(mlir::omp::HintClauseOps &result) const {
return false;
}
bool ClauseProcessor::processInclusive(
mlir::Location currentLocation,
mlir::omp::InclusiveClauseOps &result) const {
if (auto *clause = findUniqueClause<omp::clause::Inclusive>()) {
for (const Object &object : clause->v) {
const semantics::Symbol *symbol = object.sym();
mlir::Value symVal = converter.getSymbolAddress(*symbol);
result.inclusiveVars.push_back(symVal);
}
return true;
}
return false;
}
bool ClauseProcessor::processMergeable(
mlir::omp::MergeableClauseOps &result) const {
return markClauseOccurrence<omp::clause::Mergeable>(result.mergeable);
@ -1135,10 +1163,9 @@ bool ClauseProcessor::processReduction(
llvm::SmallVector<mlir::Attribute> reductionDeclSymbols;
llvm::SmallVector<const semantics::Symbol *> reductionSyms;
ReductionProcessor rp;
rp.addDeclareReduction(currentLocation, converter, clause,
reductionVars, reduceVarByRef,
reductionDeclSymbols, reductionSyms);
rp.processReductionArguments(
currentLocation, converter, clause, reductionVars, reduceVarByRef,
reductionDeclSymbols, reductionSyms, result.reductionMod);
// Copy local lists into the output.
llvm::copy(reductionVars, std::back_inserter(result.reductionVars));
llvm::copy(reduceVarByRef, std::back_inserter(result.reductionByref));

View File

@ -64,6 +64,8 @@ public:
bool processDeviceType(mlir::omp::DeviceTypeClauseOps &result) const;
bool processDistSchedule(lower::StatementContext &stmtCtx,
mlir::omp::DistScheduleClauseOps &result) const;
bool processExclusive(mlir::Location currentLocation,
mlir::omp::ExclusiveClauseOps &result) const;
bool processFilter(lower::StatementContext &stmtCtx,
mlir::omp::FilterClauseOps &result) const;
bool processFinal(lower::StatementContext &stmtCtx,
@ -72,6 +74,8 @@ public:
mlir::omp::HasDeviceAddrClauseOps &result,
llvm::SmallVectorImpl<const semantics::Symbol *> &isDeviceSyms) const;
bool processHint(mlir::omp::HintClauseOps &result) const;
bool processInclusive(mlir::Location currentLocation,
mlir::omp::InclusiveClauseOps &result) const;
bool processMergeable(mlir::omp::MergeableClauseOps &result) const;
bool processNowait(mlir::omp::NowaitClauseOps &result) const;
bool processNumTeams(lower::StatementContext &stmtCtx,

View File

@ -736,8 +736,8 @@ Enter make(const parser::OmpClause::Enter &inp,
Exclusive make(const parser::OmpClause::Exclusive &inp,
semantics::SemanticsContext &semaCtx) {
// inp -> empty
llvm_unreachable("Empty: exclusive");
// inp.v -> parser::OmpObjectList
return Exclusive{makeObjects(/*List=*/inp.v, semaCtx)};
}
Fail make(const parser::OmpClause::Fail &inp,
@ -846,8 +846,8 @@ If make(const parser::OmpClause::If &inp,
Inclusive make(const parser::OmpClause::Inclusive &inp,
semantics::SemanticsContext &semaCtx) {
// inp -> empty
llvm_unreachable("Empty: inclusive");
// inp.v -> parser::OmpObjectList
return Inclusive{makeObjects(/*List=*/inp.v, semaCtx)};
}
Indirect make(const parser::OmpClause::Indirect &inp,

View File

@ -1584,6 +1584,15 @@ static void genParallelClauses(
cp.processReduction(loc, clauseOps, reductionSyms);
}
static void genScanClauses(lower::AbstractConverter &converter,
semantics::SemanticsContext &semaCtx,
const List<Clause> &clauses, mlir::Location loc,
mlir::omp::ScanOperands &clauseOps) {
ClauseProcessor cp(converter, semaCtx, clauses);
cp.processInclusive(loc, clauseOps);
cp.processExclusive(loc, clauseOps);
}
static void genSectionsClauses(
lower::AbstractConverter &converter, semantics::SemanticsContext &semaCtx,
const List<Clause> &clauses, mlir::Location loc,
@ -1981,6 +1990,16 @@ genParallelOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
return parallelOp;
}
static mlir::omp::ScanOp
genScanOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
semantics::SemanticsContext &semaCtx, mlir::Location loc,
const ConstructQueue &queue, ConstructQueue::const_iterator item) {
mlir::omp::ScanOperands clauseOps;
genScanClauses(converter, semaCtx, item->clauses, loc, clauseOps);
return converter.getFirOpBuilder().create<mlir::omp::ScanOp>(
converter.getCurrentLocation(), clauseOps);
}
/// This breaks the normal prototype of the gen*Op functions: adding the
/// sectionBlocks argument so that the enclosed section constructs can be
/// lowered here with correct reduction symbol remapping.
@ -2990,7 +3009,7 @@ static void genOMPDispatch(lower::AbstractConverter &converter,
genStandaloneParallel(converter, symTable, semaCtx, eval, loc, queue, item);
break;
case llvm::omp::Directive::OMPD_scan:
TODO(loc, "Unhandled directive " + llvm::omp::getOpenMPDirectiveName(dir));
genScanOp(converter, symTable, semaCtx, loc, queue, item);
break;
case llvm::omp::Directive::OMPD_section:
llvm_unreachable("genOMPDispatch: OMPD_section");

View File

@ -31,6 +31,9 @@ static llvm::cl::opt<bool> forceByrefReduction(
llvm::cl::desc("Pass all reduction arguments by reference"),
llvm::cl::Hidden);
using ReductionModifier =
Fortran::lower::omp::clause::Reduction::ReductionModifier;
namespace Fortran {
namespace lower {
namespace omp {
@ -518,18 +521,36 @@ static bool doReductionByRef(mlir::Value reductionVar) {
return false;
}
void ReductionProcessor::addDeclareReduction(
mlir::omp::ReductionModifier translateReductionModifier(ReductionModifier mod) {
switch (mod) {
case ReductionModifier::Default:
return mlir::omp::ReductionModifier::defaultmod;
case ReductionModifier::Inscan:
return mlir::omp::ReductionModifier::inscan;
case ReductionModifier::Task:
return mlir::omp::ReductionModifier::task;
}
return mlir::omp::ReductionModifier::defaultmod;
}
void ReductionProcessor::processReductionArguments(
mlir::Location currentLocation, lower::AbstractConverter &converter,
const omp::clause::Reduction &reduction,
llvm::SmallVectorImpl<mlir::Value> &reductionVars,
llvm::SmallVectorImpl<bool> &reduceVarByRef,
llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSymbols) {
llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSymbols,
mlir::omp::ReductionModifierAttr &reductionMod) {
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
if (std::get<std::optional<omp::clause::Reduction::ReductionModifier>>(
reduction.t))
TODO(currentLocation, "Reduction modifiers are not supported");
auto mod = std::get<std::optional<ReductionModifier>>(reduction.t);
if (mod.has_value()) {
if (mod.value() == ReductionModifier::Task)
TODO(currentLocation, "Reduction modifier `task` is not supported");
else
reductionMod = mlir::omp::ReductionModifierAttr::get(
firOpBuilder.getContext(), translateReductionModifier(mod.value()));
}
mlir::omp::DeclareReductionOp decl;
const auto &redOperatorList{

View File

@ -19,6 +19,7 @@
#include "flang/Parser/parse-tree.h"
#include "flang/Semantics/symbol.h"
#include "flang/Semantics/type.h"
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/Types.h"
@ -120,13 +121,14 @@ public:
/// Creates a reduction declaration and associates it with an OpenMP block
/// directive.
static void addDeclareReduction(
static void processReductionArguments(
mlir::Location currentLocation, lower::AbstractConverter &converter,
const omp::clause::Reduction &reduction,
llvm::SmallVectorImpl<mlir::Value> &reductionVars,
llvm::SmallVectorImpl<bool> &reduceVarByRef,
llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSymbols);
llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSymbols,
mlir::omp::ReductionModifierAttr &reductionMod);
};
template <typename FloatOp, typename IntegerOp>

View File

@ -1,15 +0,0 @@
! RUN: %not_todo_cmd bbc -emit-fir -fopenmp -o - %s 2>&1 | FileCheck %s
! RUN: %not_todo_cmd %flang_fc1 -emit-fir -fopenmp -o - %s 2>&1 | FileCheck %s
! CHECK: not yet implemented: Reduction modifiers are not supported
subroutine reduction_inscan()
integer :: i,j
i = 0
!$omp do reduction(inscan, +:i)
do j=1,10
!$omp scan inclusive(i)
i = i + 1
end do
!$omp end do
end subroutine reduction_inscan

View File

@ -1,14 +0,0 @@
! RUN: %not_todo_cmd bbc -emit-fir -fopenmp -o - %s 2>&1 | FileCheck %s
! RUN: %not_todo_cmd %flang_fc1 -emit-fir -fopenmp -o - %s 2>&1 | FileCheck %s
! CHECK: not yet implemented: Reduction modifiers are not supported
subroutine foo()
integer :: i, j
j = 0
!$omp do reduction (inscan, *: j)
do i = 1, 10
!$omp scan inclusive(j)
j = j + 1
end do
end subroutine

View File

@ -1,7 +1,7 @@
! RUN: %not_todo_cmd bbc -emit-fir -fopenmp -o - %s 2>&1 | FileCheck %s
! RUN: %not_todo_cmd %flang_fc1 -emit-fir -fopenmp -o - %s 2>&1 | FileCheck %s
! CHECK: not yet implemented: Reduction modifiers are not supported
! CHECK: not yet implemented: Reduction modifier `task` is not supported
subroutine reduction_task()
integer :: i
i = 0

View File

@ -0,0 +1,36 @@
! RUN: bbc -emit-hlfir -fopenmp %s -o - | FileCheck %s
! RUN: %flang_fc1 -emit-hlfir -fopenmp %s -o - | FileCheck %s
! CHECK: omp.wsloop reduction(mod: inscan, @add_reduction_i32 %{{.*}} -> %[[RED_ARG_1:.*]] : {{.*}}) {
! CHECK: %[[RED_DECL_1:.*]]:2 = hlfir.declare %[[RED_ARG_1]]
! CHECK: omp.scan inclusive(%[[RED_DECL_1]]#1 : {{.*}})
subroutine inclusive_scan(a, b, n)
implicit none
integer a(:), b(:)
integer x, k, n
!$omp parallel do reduction(inscan, +: x)
do k = 1, n
x = x + a(k)
!$omp scan inclusive(x)
b(k) = x
end do
end subroutine inclusive_scan
! CHECK: omp.wsloop reduction(mod: inscan, @add_reduction_i32 %{{.*}} -> %[[RED_ARG_2:.*]] : {{.*}}) {
! CHECK: %[[RED_DECL_2:.*]]:2 = hlfir.declare %[[RED_ARG_2]]
! CHECK: omp.scan exclusive(%[[RED_DECL_2]]#1 : {{.*}})
subroutine exclusive_scan(a, b, n)
implicit none
integer a(:), b(:)
integer x, k, n
!$omp parallel do reduction(inscan, +: x)
do k = 1, n
x = x + a(k)
!$omp scan exclusive(x)
b(k) = x
end do
end subroutine exclusive_scan

View File

@ -226,7 +226,7 @@ void mlir::configureOpenMPToLLVMConversionLegality(
target.addDynamicallyLegalOp<
omp::AtomicReadOp, omp::AtomicWriteOp, omp::CancellationPointOp,
omp::CancelOp, omp::CriticalDeclareOp, omp::FlushOp, omp::MapBoundsOp,
omp::MapInfoOp, omp::OrderedOp, omp::TargetEnterDataOp,
omp::MapInfoOp, omp::OrderedOp, omp::ScanOp, omp::TargetEnterDataOp,
omp::TargetExitDataOp, omp::TargetUpdateOp, omp::ThreadprivateOp,
omp::YieldOp>([&](Operation *op) {
return typeConverter.isLegal(op->getOperandTypes()) &&
@ -274,6 +274,7 @@ void mlir::populateOpenMPToLLVMConversionPatterns(LLVMTypeConverter &converter,
RegionLessOpConversion<omp::CancelOp>,
RegionLessOpConversion<omp::CriticalDeclareOp>,
RegionLessOpConversion<omp::OrderedOp>,
RegionLessOpConversion<omp::ScanOp>,
RegionLessOpConversion<omp::TargetEnterDataOp>,
RegionLessOpConversion<omp::TargetExitDataOp>,
RegionLessOpConversion<omp::TargetUpdateOp>,