3870 lines
148 KiB
C++
3870 lines
148 KiB
C++
//===- OpenMPDialect.cpp - MLIR Dialect for OpenMP implementation ---------===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
//
|
|
// This file implements the OpenMP dialect and its operations.
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
|
|
#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
|
|
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
|
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
|
|
#include "mlir/Dialect/OpenMP/OpenMPClauseOperands.h"
|
|
#include "mlir/IR/Attributes.h"
|
|
#include "mlir/IR/BuiltinAttributes.h"
|
|
#include "mlir/IR/DialectImplementation.h"
|
|
#include "mlir/IR/OpImplementation.h"
|
|
#include "mlir/IR/OperationSupport.h"
|
|
#include "mlir/IR/SymbolTable.h"
|
|
#include "mlir/Interfaces/FoldInterfaces.h"
|
|
|
|
#include "llvm/ADT/ArrayRef.h"
|
|
#include "llvm/ADT/PostOrderIterator.h"
|
|
#include "llvm/ADT/STLExtras.h"
|
|
#include "llvm/ADT/STLForwardCompat.h"
|
|
#include "llvm/ADT/SmallString.h"
|
|
#include "llvm/ADT/StringExtras.h"
|
|
#include "llvm/ADT/StringRef.h"
|
|
#include "llvm/ADT/TypeSwitch.h"
|
|
#include "llvm/ADT/bit.h"
|
|
#include "llvm/Frontend/OpenMP/OMPConstants.h"
|
|
#include <cstddef>
|
|
#include <iterator>
|
|
#include <optional>
|
|
#include <variant>
|
|
|
|
#include "mlir/Dialect/OpenMP/OpenMPOpsDialect.cpp.inc"
|
|
#include "mlir/Dialect/OpenMP/OpenMPOpsEnums.cpp.inc"
|
|
#include "mlir/Dialect/OpenMP/OpenMPOpsInterfaces.cpp.inc"
|
|
#include "mlir/Dialect/OpenMP/OpenMPTypeInterfaces.cpp.inc"
|
|
|
|
using namespace mlir;
|
|
using namespace mlir::omp;
|
|
|
|
/// Wraps `attrs` in an ArrayAttr. An empty list yields a null attribute
/// instead of an empty array.
static ArrayAttr makeArrayAttr(MLIRContext *context,
                               llvm::ArrayRef<Attribute> attrs) {
  if (attrs.empty())
    return nullptr;
  return ArrayAttr::get(context, attrs);
}
|
|
|
|
/// Wraps `boolArray` in a DenseBoolArrayAttr. An empty list yields a null
/// attribute instead of an empty array.
static DenseBoolArrayAttr
makeDenseBoolArrayAttr(MLIRContext *ctx, const ArrayRef<bool> boolArray) {
  if (boolArray.empty())
    return nullptr;
  return DenseBoolArrayAttr::get(ctx, boolArray);
}
|
|
|
|
namespace {
|
|
struct MemRefPointerLikeModel
|
|
: public PointerLikeType::ExternalModel<MemRefPointerLikeModel,
|
|
MemRefType> {
|
|
Type getElementType(Type pointer) const {
|
|
return llvm::cast<MemRefType>(pointer).getElementType();
|
|
}
|
|
};
|
|
|
|
struct LLVMPointerPointerLikeModel
|
|
: public PointerLikeType::ExternalModel<LLVMPointerPointerLikeModel,
|
|
LLVM::LLVMPointerType> {
|
|
Type getElementType(Type pointer) const { return Type(); }
|
|
};
|
|
} // namespace
|
|
|
|
/// Registers the operations, attributes and types of the dialect and attaches
/// the external interface models defined for types and operations of other
/// dialects.
void OpenMPDialect::initialize() {
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"
      >();
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc"
      >();
  addTypes<
#define GET_TYPEDEF_LIST
#include "mlir/Dialect/OpenMP/OpenMPOpsTypes.cpp.inc"
      >();

  declarePromisedInterface<ConvertToLLVMPatternInterface, OpenMPDialect>();

  // Every interface below is attached within the context owning this dialect.
  MLIRContext &ctx = *getContext();

  // Pointer-like type models for memrefs and LLVM pointers.
  MemRefType::attachInterface<MemRefPointerLikeModel>(ctx);
  LLVM::LLVMPointerType::attachInterface<LLVMPointerPointerLikeModel>(ctx);

  // Attach the default offload module interface to the module op, so that
  // offload functionality can be accessed through it.
  mlir::ModuleOp::attachInterface<mlir::omp::OffloadModuleDefaultModel>(ctx);

  // Attach the default declare target interface to operations which can be
  // marked as declare target (global operations and functions/subroutines in
  // dialects that Fortran, or other languages that lower to MLIR, translate
  // to).
  mlir::LLVM::GlobalOp::attachInterface<
      mlir::omp::DeclareTargetDefaultModel<mlir::LLVM::GlobalOp>>(ctx);
  mlir::LLVM::LLVMFuncOp::attachInterface<
      mlir::omp::DeclareTargetDefaultModel<mlir::LLVM::LLVMFuncOp>>(ctx);
  mlir::func::FuncOp::attachInterface<
      mlir::omp::DeclareTargetDefaultModel<mlir::func::FuncOp>>(ctx);
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Parser and printer for Allocate Clause
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Parse an allocate clause with allocators and a list of operands with types.
|
|
///
|
|
/// allocate-operand-list :: = allocate-operand |
|
|
/// allocator-operand `,` allocate-operand-list
|
|
/// allocate-operand :: = ssa-id-and-type -> ssa-id-and-type
|
|
/// ssa-id-and-type ::= ssa-id `:` type
|
|
static ParseResult parseAllocateAndAllocator(
|
|
OpAsmParser &parser,
|
|
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &allocateVars,
|
|
SmallVectorImpl<Type> &allocateTypes,
|
|
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &allocatorVars,
|
|
SmallVectorImpl<Type> &allocatorTypes) {
|
|
|
|
return parser.parseCommaSeparatedList([&]() {
|
|
OpAsmParser::UnresolvedOperand operand;
|
|
Type type;
|
|
if (parser.parseOperand(operand) || parser.parseColonType(type))
|
|
return failure();
|
|
allocatorVars.push_back(operand);
|
|
allocatorTypes.push_back(type);
|
|
if (parser.parseArrow())
|
|
return failure();
|
|
if (parser.parseOperand(operand) || parser.parseColonType(type))
|
|
return failure();
|
|
|
|
allocateVars.push_back(operand);
|
|
allocateTypes.push_back(type);
|
|
return success();
|
|
});
|
|
}
|
|
|
|
/// Print allocate clause
|
|
static void printAllocateAndAllocator(OpAsmPrinter &p, Operation *op,
|
|
OperandRange allocateVars,
|
|
TypeRange allocateTypes,
|
|
OperandRange allocatorVars,
|
|
TypeRange allocatorTypes) {
|
|
for (unsigned i = 0; i < allocateVars.size(); ++i) {
|
|
std::string separator = i == allocateVars.size() - 1 ? "" : ", ";
|
|
p << allocatorVars[i] << " : " << allocatorTypes[i] << " -> ";
|
|
p << allocateVars[i] << " : " << allocateTypes[i] << separator;
|
|
}
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Parser and printer for a clause attribute (StringEnumAttr)
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
template <typename ClauseAttr>
|
|
static ParseResult parseClauseAttr(AsmParser &parser, ClauseAttr &attr) {
|
|
using ClauseT = decltype(std::declval<ClauseAttr>().getValue());
|
|
StringRef enumStr;
|
|
SMLoc loc = parser.getCurrentLocation();
|
|
if (parser.parseKeyword(&enumStr))
|
|
return failure();
|
|
if (std::optional<ClauseT> enumValue = symbolizeEnum<ClauseT>(enumStr)) {
|
|
attr = ClauseAttr::get(parser.getContext(), *enumValue);
|
|
return success();
|
|
}
|
|
return parser.emitError(loc, "invalid clause value: '") << enumStr << "'";
|
|
}
|
|
|
|
template <typename ClauseAttr>
|
|
void printClauseAttr(OpAsmPrinter &p, Operation *op, ClauseAttr attr) {
|
|
p << stringifyEnum(attr.getValue());
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Parser and printer for Linear Clause
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// linear ::= `linear` `(` linear-list `)`
|
|
/// linear-list := linear-val | linear-val linear-list
|
|
/// linear-val := ssa-id-and-type `=` ssa-id-and-type
|
|
static ParseResult parseLinearClause(
|
|
OpAsmParser &parser,
|
|
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &linearVars,
|
|
SmallVectorImpl<Type> &linearTypes,
|
|
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &linearStepVars) {
|
|
return parser.parseCommaSeparatedList([&]() {
|
|
OpAsmParser::UnresolvedOperand var;
|
|
Type type;
|
|
OpAsmParser::UnresolvedOperand stepVar;
|
|
if (parser.parseOperand(var) || parser.parseEqual() ||
|
|
parser.parseOperand(stepVar) || parser.parseColonType(type))
|
|
return failure();
|
|
|
|
linearVars.push_back(var);
|
|
linearTypes.push_back(type);
|
|
linearStepVars.push_back(stepVar);
|
|
return success();
|
|
});
|
|
}
|
|
|
|
/// Print Linear Clause
|
|
static void printLinearClause(OpAsmPrinter &p, Operation *op,
|
|
ValueRange linearVars, TypeRange linearTypes,
|
|
ValueRange linearStepVars) {
|
|
size_t linearVarsSize = linearVars.size();
|
|
for (unsigned i = 0; i < linearVarsSize; ++i) {
|
|
std::string separator = i == linearVarsSize - 1 ? "" : ", ";
|
|
p << linearVars[i];
|
|
if (linearStepVars.size() > i)
|
|
p << " = " << linearStepVars[i];
|
|
p << " : " << linearVars[i].getType() << separator;
|
|
}
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Verifier for Nontemporal Clause
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Verify the nontemporal clause: no variable may be listed more than once
/// (OpenMP 5.0, section 2.9.3.1).
static LogicalResult verifyNontemporalClause(Operation *op,
                                             OperandRange nontemporalVars) {
  DenseSet<Value> seenVars;
  for (Value var : nontemporalVars) {
    const bool isNew = seenVars.insert(var).second;
    if (!isNew)
      return op->emitOpError() << "nontemporal variable used more than once";
  }
  return success();
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Parser, verifier and printer for Aligned Clause
|
|
//===----------------------------------------------------------------------===//
|
|
/// Verify the aligned clause.
///
/// - Without aligned variables, no alignments attribute may be present.
/// - With aligned variables, there must be exactly one alignment value per
///   variable.
/// - No variable may be listed more than once (OpenMP 4.5, section 2.8.1).
/// - Every alignment must be an integer attribute greater than 0
///   (OpenMP 4.5, section 2.8.1).
static LogicalResult verifyAlignedClause(Operation *op,
                                         std::optional<ArrayAttr> alignments,
                                         OperandRange alignedVars) {
  if (alignedVars.empty()) {
    if (alignments)
      return op->emitOpError() << "unexpected alignment values attribute";
    return success();
  }

  // Check if number of alignment values equals to number of aligned variables.
  if (!alignments || alignments->size() != alignedVars.size())
    return op->emitOpError()
           << "expected as many alignment values as aligned variables";

  // Check if each var is aligned only once.
  DenseSet<Value> alignedItems;
  for (Value var : alignedVars)
    if (!alignedItems.insert(var).second)
      return op->emitOpError() << "aligned variable used more than once";

  // `alignments` is known to hold a value here (checked above), so it can be
  // iterated directly. Check that all alignment values are positive integers.
  for (Attribute alignment : *alignments) {
    auto intAttr = llvm::dyn_cast<IntegerAttr>(alignment);
    if (!intAttr)
      return op->emitOpError() << "expected integer alignment";
    if (intAttr.getValue().sle(0))
      return op->emitOpError() << "alignment should be greater than 0";
  }

  return success();
}
|
|
|
|
/// aligned ::= `aligned` `(` aligned-list `)`
|
|
/// aligned-list := aligned-val | aligned-val aligned-list
|
|
/// aligned-val := ssa-id-and-type `->` alignment
|
|
static ParseResult
|
|
parseAlignedClause(OpAsmParser &parser,
|
|
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &alignedVars,
|
|
SmallVectorImpl<Type> &alignedTypes,
|
|
ArrayAttr &alignmentsAttr) {
|
|
SmallVector<Attribute> alignmentVec;
|
|
if (failed(parser.parseCommaSeparatedList([&]() {
|
|
if (parser.parseOperand(alignedVars.emplace_back()) ||
|
|
parser.parseColonType(alignedTypes.emplace_back()) ||
|
|
parser.parseArrow() ||
|
|
parser.parseAttribute(alignmentVec.emplace_back())) {
|
|
return failure();
|
|
}
|
|
return success();
|
|
})))
|
|
return failure();
|
|
SmallVector<Attribute> alignments(alignmentVec.begin(), alignmentVec.end());
|
|
alignmentsAttr = ArrayAttr::get(parser.getContext(), alignments);
|
|
return success();
|
|
}
|
|
|
|
/// Print Aligned Clause
|
|
static void printAlignedClause(OpAsmPrinter &p, Operation *op,
|
|
ValueRange alignedVars, TypeRange alignedTypes,
|
|
std::optional<ArrayAttr> alignments) {
|
|
for (unsigned i = 0; i < alignedVars.size(); ++i) {
|
|
if (i != 0)
|
|
p << ", ";
|
|
p << alignedVars[i] << " : " << alignedVars[i].getType();
|
|
p << " -> " << (*alignments)[i];
|
|
}
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Parser, printer and verifier for Schedule Clause
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Validate and normalize the schedule modifiers collected by the parser.
///
/// At most two modifiers are accepted and each must be a known
/// ScheduleModifier. A lone `simd` is rewritten to the pair `none`, `simd`.
/// With two modifiers, the first must not be `simd` and the second must be
/// `simd`. Errors are reported at the operation name location.
static ParseResult
verifyScheduleModifiers(OpAsmParser &parser,
                        SmallVectorImpl<SmallString<12>> &modifiers) {
  if (modifiers.size() > 2)
    return parser.emitError(parser.getNameLoc()) << " unexpected modifier(s)";

  // A string that cannot be translated to an enum value is not a modifier.
  for (const auto &modifier : modifiers) {
    if (!symbolizeScheduleModifier(modifier))
      return parser.emitError(parser.getNameLoc())
             << " unknown modifier type: " << modifier;
  }

  if (modifiers.size() == 2) {
    // Only "<non-simd>, simd" is a valid two-modifier sequence.
    const bool firstIsSimd =
        symbolizeScheduleModifier(modifiers[0]) == ScheduleModifier::simd;
    const bool secondIsSimd =
        symbolizeScheduleModifier(modifiers[1]) == ScheduleModifier::simd;
    if (firstIsSimd || !secondIsSimd)
      return parser.emitError(parser.getNameLoc())
             << " incorrect modifier order";
    return success();
  }

  // A single "simd" modifier moves to index 1 and a "none" modifier takes
  // its place at index 0.
  if (modifiers.size() == 1 &&
      symbolizeScheduleModifier(modifiers[0]) == ScheduleModifier::simd) {
    modifiers.push_back(modifiers[0]);
    modifiers[0] = stringifyScheduleModifier(ScheduleModifier::none);
  }
  return success();
}
|
|
|
|
/// schedule ::= `schedule` `(` sched-list `)`
|
|
/// sched-list ::= sched-val | sched-val sched-list |
|
|
/// sched-val `,` sched-modifier
|
|
/// sched-val ::= sched-with-chunk | sched-wo-chunk
|
|
/// sched-with-chunk ::= sched-with-chunk-types (`=` ssa-id-and-type)?
|
|
/// sched-with-chunk-types ::= `static` | `dynamic` | `guided`
|
|
/// sched-wo-chunk ::= `auto` | `runtime`
|
|
/// sched-modifier ::= sched-mod-val | sched-mod-val `,` sched-mod-val
|
|
/// sched-mod-val ::= `monotonic` | `nonmonotonic` | `simd` | `none`
|
|
static ParseResult
|
|
parseScheduleClause(OpAsmParser &parser, ClauseScheduleKindAttr &scheduleAttr,
|
|
ScheduleModifierAttr &scheduleMod, UnitAttr &scheduleSimd,
|
|
std::optional<OpAsmParser::UnresolvedOperand> &chunkSize,
|
|
Type &chunkType) {
|
|
StringRef keyword;
|
|
if (parser.parseKeyword(&keyword))
|
|
return failure();
|
|
std::optional<mlir::omp::ClauseScheduleKind> schedule =
|
|
symbolizeClauseScheduleKind(keyword);
|
|
if (!schedule)
|
|
return parser.emitError(parser.getNameLoc()) << " expected schedule kind";
|
|
|
|
scheduleAttr = ClauseScheduleKindAttr::get(parser.getContext(), *schedule);
|
|
switch (*schedule) {
|
|
case ClauseScheduleKind::Static:
|
|
case ClauseScheduleKind::Dynamic:
|
|
case ClauseScheduleKind::Guided:
|
|
if (succeeded(parser.parseOptionalEqual())) {
|
|
chunkSize = OpAsmParser::UnresolvedOperand{};
|
|
if (parser.parseOperand(*chunkSize) || parser.parseColonType(chunkType))
|
|
return failure();
|
|
} else {
|
|
chunkSize = std::nullopt;
|
|
}
|
|
break;
|
|
case ClauseScheduleKind::Auto:
|
|
case ClauseScheduleKind::Runtime:
|
|
chunkSize = std::nullopt;
|
|
}
|
|
|
|
// If there is a comma, we have one or more modifiers..
|
|
SmallVector<SmallString<12>> modifiers;
|
|
while (succeeded(parser.parseOptionalComma())) {
|
|
StringRef mod;
|
|
if (parser.parseKeyword(&mod))
|
|
return failure();
|
|
modifiers.push_back(mod);
|
|
}
|
|
|
|
if (verifyScheduleModifiers(parser, modifiers))
|
|
return failure();
|
|
|
|
if (!modifiers.empty()) {
|
|
SMLoc loc = parser.getCurrentLocation();
|
|
if (std::optional<ScheduleModifier> mod =
|
|
symbolizeScheduleModifier(modifiers[0])) {
|
|
scheduleMod = ScheduleModifierAttr::get(parser.getContext(), *mod);
|
|
} else {
|
|
return parser.emitError(loc, "invalid schedule modifier");
|
|
}
|
|
// Only SIMD attribute is allowed here!
|
|
if (modifiers.size() > 1) {
|
|
assert(symbolizeScheduleModifier(modifiers[1]) == ScheduleModifier::simd);
|
|
scheduleSimd = UnitAttr::get(parser.getBuilder().getContext());
|
|
}
|
|
}
|
|
|
|
return success();
|
|
}
|
|
|
|
/// Print schedule clause
|
|
static void printScheduleClause(OpAsmPrinter &p, Operation *op,
|
|
ClauseScheduleKindAttr scheduleKind,
|
|
ScheduleModifierAttr scheduleMod,
|
|
UnitAttr scheduleSimd, Value scheduleChunk,
|
|
Type scheduleChunkType) {
|
|
p << stringifyClauseScheduleKind(scheduleKind.getValue());
|
|
if (scheduleChunk)
|
|
p << " = " << scheduleChunk << " : " << scheduleChunk.getType();
|
|
if (scheduleMod)
|
|
p << ", " << stringifyScheduleModifier(scheduleMod.getValue());
|
|
if (scheduleSimd)
|
|
p << ", simd";
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Parser and printer for Order Clause
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// order ::= `order` `(` [order-modifier ':'] concurrent `)`
|
|
// order-modifier ::= reproducible | unconstrained
|
|
static ParseResult parseOrderClause(OpAsmParser &parser,
|
|
ClauseOrderKindAttr &order,
|
|
OrderModifierAttr &orderMod) {
|
|
StringRef enumStr;
|
|
SMLoc loc = parser.getCurrentLocation();
|
|
if (parser.parseKeyword(&enumStr))
|
|
return failure();
|
|
if (std::optional<OrderModifier> enumValue =
|
|
symbolizeOrderModifier(enumStr)) {
|
|
orderMod = OrderModifierAttr::get(parser.getContext(), *enumValue);
|
|
if (parser.parseOptionalColon())
|
|
return failure();
|
|
loc = parser.getCurrentLocation();
|
|
if (parser.parseKeyword(&enumStr))
|
|
return failure();
|
|
}
|
|
if (std::optional<ClauseOrderKind> enumValue =
|
|
symbolizeClauseOrderKind(enumStr)) {
|
|
order = ClauseOrderKindAttr::get(parser.getContext(), *enumValue);
|
|
return success();
|
|
}
|
|
return parser.emitError(loc, "invalid clause value: '") << enumStr << "'";
|
|
}
|
|
|
|
/// Print the argument list of an order clause: `modifier:` if a modifier is
/// present, followed by the order kind if present.
static void printOrderClause(OpAsmPrinter &p, Operation *op,
                             ClauseOrderKindAttr order,
                             OrderModifierAttr orderMod) {
  if (orderMod) {
    OrderModifier modifier = orderMod.getValue();
    p << stringifyOrderModifier(modifier) << ":";
  }
  if (!order)
    return;
  p << stringifyClauseOrderKind(order.getValue());
}
|
|
|
|
template <typename ClauseTypeAttr, typename ClauseType>
|
|
static ParseResult
|
|
parseGranularityClause(OpAsmParser &parser, ClauseTypeAttr &prescriptiveness,
|
|
std::optional<OpAsmParser::UnresolvedOperand> &operand,
|
|
Type &operandType,
|
|
std::optional<ClauseType> (*symbolizeClause)(StringRef),
|
|
StringRef clauseName) {
|
|
StringRef enumStr;
|
|
if (succeeded(parser.parseOptionalKeyword(&enumStr))) {
|
|
if (std::optional<ClauseType> enumValue = symbolizeClause(enumStr)) {
|
|
prescriptiveness = ClauseTypeAttr::get(parser.getContext(), *enumValue);
|
|
if (parser.parseComma())
|
|
return failure();
|
|
} else {
|
|
return parser.emitError(parser.getCurrentLocation())
|
|
<< "invalid " << clauseName << " modifier : '" << enumStr << "'";
|
|
;
|
|
}
|
|
}
|
|
|
|
OpAsmParser::UnresolvedOperand var;
|
|
if (succeeded(parser.parseOperand(var))) {
|
|
operand = var;
|
|
} else {
|
|
return parser.emitError(parser.getCurrentLocation())
|
|
<< "expected " << clauseName << " operand";
|
|
}
|
|
|
|
if (operand.has_value()) {
|
|
if (parser.parseColonType(operandType))
|
|
return failure();
|
|
}
|
|
|
|
return success();
|
|
}
|
|
|
|
template <typename ClauseTypeAttr, typename ClauseType>
|
|
static void
|
|
printGranularityClause(OpAsmPrinter &p, Operation *op,
|
|
ClauseTypeAttr prescriptiveness, Value operand,
|
|
mlir::Type operandType,
|
|
StringRef (*stringifyClauseType)(ClauseType)) {
|
|
|
|
if (prescriptiveness)
|
|
p << stringifyClauseType(prescriptiveness.getValue()) << ", ";
|
|
|
|
if (operand)
|
|
p << operand << ": " << operandType;
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Parser and printer for grainsize Clause
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// grainsize ::= `grainsize` `(` [strict ':'] grain-size `)`
|
|
static ParseResult
|
|
parseGrainsizeClause(OpAsmParser &parser, ClauseGrainsizeTypeAttr &grainsizeMod,
|
|
std::optional<OpAsmParser::UnresolvedOperand> &grainsize,
|
|
Type &grainsizeType) {
|
|
return parseGranularityClause<ClauseGrainsizeTypeAttr, ClauseGrainsizeType>(
|
|
parser, grainsizeMod, grainsize, grainsizeType,
|
|
&symbolizeClauseGrainsizeType, "grainsize");
|
|
}
|
|
|
|
/// Print the argument list of a grainsize clause by forwarding to
/// printGranularityClause with stringifyClauseGrainsizeType.
static void printGrainsizeClause(OpAsmPrinter &p, Operation *op,
                                 ClauseGrainsizeTypeAttr grainsizeMod,
                                 Value grainsize, mlir::Type grainsizeType) {
  StringRef (*stringifyModifier)(ClauseGrainsizeType) =
      &stringifyClauseGrainsizeType;
  printGranularityClause<ClauseGrainsizeTypeAttr, ClauseGrainsizeType>(
      p, op, grainsizeMod, grainsize, grainsizeType, stringifyModifier);
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Parser and printer for num_tasks Clause
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// numtask ::= `num_tasks` `(` [strict ':'] num-tasks `)`
|
|
static ParseResult
|
|
parseNumTasksClause(OpAsmParser &parser, ClauseNumTasksTypeAttr &numTasksMod,
|
|
std::optional<OpAsmParser::UnresolvedOperand> &numTasks,
|
|
Type &numTasksType) {
|
|
return parseGranularityClause<ClauseNumTasksTypeAttr, ClauseNumTasksType>(
|
|
parser, numTasksMod, numTasks, numTasksType, &symbolizeClauseNumTasksType,
|
|
"num_tasks");
|
|
}
|
|
|
|
/// Print the argument list of a num_tasks clause by forwarding to
/// printGranularityClause with stringifyClauseNumTasksType.
static void printNumTasksClause(OpAsmPrinter &p, Operation *op,
                                ClauseNumTasksTypeAttr numTasksMod,
                                Value numTasks, mlir::Type numTasksType) {
  StringRef (*stringifyModifier)(ClauseNumTasksType) =
      &stringifyClauseNumTasksType;
  printGranularityClause<ClauseNumTasksTypeAttr, ClauseNumTasksType>(
      p, op, numTasksMod, numTasks, numTasksType, stringifyModifier);
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Parsers for operations including clauses that define entry block arguments.
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
namespace {
|
|
struct MapParseArgs {
|
|
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars;
|
|
SmallVectorImpl<Type> &types;
|
|
MapParseArgs(SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars,
|
|
SmallVectorImpl<Type> &types)
|
|
: vars(vars), types(types) {}
|
|
};
|
|
struct PrivateParseArgs {
|
|
llvm::SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars;
|
|
llvm::SmallVectorImpl<Type> &types;
|
|
ArrayAttr &syms;
|
|
UnitAttr &needsBarrier;
|
|
DenseI64ArrayAttr *mapIndices;
|
|
PrivateParseArgs(SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars,
|
|
SmallVectorImpl<Type> &types, ArrayAttr &syms,
|
|
UnitAttr &needsBarrier,
|
|
DenseI64ArrayAttr *mapIndices = nullptr)
|
|
: vars(vars), types(types), syms(syms), needsBarrier(needsBarrier),
|
|
mapIndices(mapIndices) {}
|
|
};
|
|
|
|
struct ReductionParseArgs {
|
|
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars;
|
|
SmallVectorImpl<Type> &types;
|
|
DenseBoolArrayAttr &byref;
|
|
ArrayAttr &syms;
|
|
ReductionModifierAttr *modifier;
|
|
ReductionParseArgs(SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars,
|
|
SmallVectorImpl<Type> &types, DenseBoolArrayAttr &byref,
|
|
ArrayAttr &syms, ReductionModifierAttr *mod = nullptr)
|
|
: vars(vars), types(types), byref(byref), syms(syms), modifier(mod) {}
|
|
};
|
|
|
|
struct AllRegionParseArgs {
|
|
std::optional<MapParseArgs> hasDeviceAddrArgs;
|
|
std::optional<MapParseArgs> hostEvalArgs;
|
|
std::optional<ReductionParseArgs> inReductionArgs;
|
|
std::optional<MapParseArgs> mapArgs;
|
|
std::optional<PrivateParseArgs> privateArgs;
|
|
std::optional<ReductionParseArgs> reductionArgs;
|
|
std::optional<ReductionParseArgs> taskReductionArgs;
|
|
std::optional<MapParseArgs> useDeviceAddrArgs;
|
|
std::optional<MapParseArgs> useDevicePtrArgs;
|
|
};
|
|
} // namespace
|
|
|
|
/// Keyword that, when it follows a parsed clause, sets the clause's
/// "needs barrier" unit attribute.
static constexpr StringRef getPrivateNeedsBarrierSpelling() {
  return StringRef("private_barrier");
}
|
|
|
|
/// Parse the parenthesized body of a clause that defines entry block
/// arguments:
///
///   `(` [`mod` `:` modifier `,`] entry (`,` entry)* `:` type (`,` type)* `)`
///   [`private_barrier`]
///   entry ::= [`byref`] [symbol] operand `->` argument
///             [`[` `map_idx` `=` integer `]`]
///
/// The optional pieces are only parsed when the matching output pointer is
/// non-null: `modifier` enables the `mod:` prefix, `byref` the `byref`
/// keyword, `symbols` the (then mandatory) symbol, `mapIndices` the
/// `[map_idx = N]` suffix and `needsBarrier` the trailing keyword.
///
/// \param operands receives one operand per entry.
/// \param types receives the types listed after `:`; there must be as many
///        types as operands.
/// \param regionPrivateArgs receives one block argument per entry, appended
///        after any arguments already present; each new argument is assigned
///        the corresponding parsed type.
/// \param symbols if non-null, set to the array of parsed symbols.
/// \param mapIndices if non-null, set to the parsed map indices, using -1 for
///        entries without a `map_idx`; left untouched if no entry was parsed.
/// \param byref if non-null, set to the per-entry `byref` flags (null
///        attribute if there are none).
/// \param modifier if non-null and `mod:` is present, set to the parsed
///        reduction modifier.
/// \param needsBarrier if non-null and the trailing keyword is present, set
///        to a unit attribute.
///
/// NOTE(review): the size check and the argument sub-range below use the
/// total sizes of `operands` and `types`, so both are presumably expected to
/// be empty on entry — confirm against callers.
static ParseResult parseClauseWithRegionArgs(
    OpAsmParser &parser,
    SmallVectorImpl<OpAsmParser::UnresolvedOperand> &operands,
    SmallVectorImpl<Type> &types,
    SmallVectorImpl<OpAsmParser::Argument> &regionPrivateArgs,
    ArrayAttr *symbols = nullptr, DenseI64ArrayAttr *mapIndices = nullptr,
    DenseBoolArrayAttr *byref = nullptr,
    ReductionModifierAttr *modifier = nullptr,
    UnitAttr *needsBarrier = nullptr) {
  SmallVector<SymbolRefAttr> symbolVec;
  SmallVector<int64_t> mapIndicesVec;
  SmallVector<bool> isByRefVec;
  // Index of the first block argument added by this clause.
  unsigned regionArgOffset = regionPrivateArgs.size();

  if (parser.parseLParen())
    return failure();

  if (modifier && succeeded(parser.parseOptionalKeyword("mod"))) {
    StringRef enumStr;
    if (parser.parseColon() || parser.parseKeyword(&enumStr) ||
        parser.parseComma())
      return failure();
    std::optional<ReductionModifier> enumValue =
        symbolizeReductionModifier(enumStr);
    if (!enumValue.has_value())
      return failure();
    *modifier = ReductionModifierAttr::get(parser.getContext(), *enumValue);
    if (!*modifier)
      return failure();
  }

  if (parser.parseCommaSeparatedList([&]() {
        if (byref)
          isByRefVec.push_back(
              parser.parseOptionalKeyword("byref").succeeded());

        if (symbols && parser.parseAttribute(symbolVec.emplace_back()))
          return failure();

        if (parser.parseOperand(operands.emplace_back()) ||
            parser.parseArrow() ||
            parser.parseArgument(regionPrivateArgs.emplace_back()))
          return failure();

        if (mapIndices) {
          if (parser.parseOptionalLSquare().succeeded()) {
            if (parser.parseKeyword("map_idx") || parser.parseEqual() ||
                parser.parseInteger(mapIndicesVec.emplace_back()) ||
                parser.parseRSquare())
              return failure();
          } else {
            // -1 marks an entry without an explicit map index.
            mapIndicesVec.push_back(-1);
          }
        }

        return success();
      }))
    return failure();

  if (parser.parseColon())
    return failure();

  if (parser.parseCommaSeparatedList([&]() {
        if (parser.parseType(types.emplace_back()))
          return failure();

        return success();
      }))
    return failure();

  if (operands.size() != types.size())
    return failure();

  if (parser.parseRParen())
    return failure();

  if (needsBarrier) {
    if (parser.parseOptionalKeyword(getPrivateNeedsBarrierSpelling())
            .succeeded())
      *needsBarrier = mlir::UnitAttr::get(parser.getContext());
  }

  // Assign the parsed types to the block arguments added by this clause.
  auto *argsBegin = regionPrivateArgs.begin();
  MutableArrayRef argsSubrange(argsBegin + regionArgOffset,
                               argsBegin + regionArgOffset + types.size());
  for (auto [prv, type] : llvm::zip_equal(argsSubrange, types)) {
    prv.type = type;
  }

  if (symbols) {
    SmallVector<Attribute> symbolAttrs(symbolVec.begin(), symbolVec.end());
    *symbols = ArrayAttr::get(parser.getContext(), symbolAttrs);
  }

  // `mapIndicesVec` is only ever filled when `mapIndices` is non-null.
  if (!mapIndicesVec.empty())
    *mapIndices =
        mlir::DenseI64ArrayAttr::get(parser.getContext(), mapIndicesVec);

  if (byref)
    *byref = makeDenseBoolArrayAttr(parser.getContext(), isByRefVec);

  return success();
}
|
|
|
|
/// Parse an optional clause introduced by `keyword` that only carries
/// operands and types. Succeeds without consuming anything if the keyword is
/// absent; fails if the keyword is present but `mapArgs` is not set.
static ParseResult parseBlockArgClause(
    OpAsmParser &parser,
    llvm::SmallVectorImpl<OpAsmParser::Argument> &entryBlockArgs,
    StringRef keyword, std::optional<MapParseArgs> mapArgs) {
  if (failed(parser.parseOptionalKeyword(keyword)))
    return success();

  // The clause appeared on an operation that does not accept it.
  if (!mapArgs)
    return failure();

  return parseClauseWithRegionArgs(parser, mapArgs->vars, mapArgs->types,
                                   entryBlockArgs);
}
|
|
|
|
/// Parse an optional private-like clause introduced by `keyword`, including
/// its symbols, optional map indices and the trailing "needs barrier"
/// keyword. Succeeds without consuming anything if the keyword is absent;
/// fails if the keyword is present but `privateArgs` is not set.
static ParseResult parseBlockArgClause(
    OpAsmParser &parser,
    llvm::SmallVectorImpl<OpAsmParser::Argument> &entryBlockArgs,
    StringRef keyword, std::optional<PrivateParseArgs> privateArgs) {
  if (failed(parser.parseOptionalKeyword(keyword)))
    return success();

  // The clause appeared on an operation that does not accept it.
  if (!privateArgs)
    return failure();

  return parseClauseWithRegionArgs(
      parser, privateArgs->vars, privateArgs->types, entryBlockArgs,
      &privateArgs->syms, privateArgs->mapIndices, /*byref=*/nullptr,
      /*modifier=*/nullptr, &privateArgs->needsBarrier);
}
|
|
|
|
/// Parse an optional reduction-like clause introduced by `keyword`,
/// including its symbols, by-ref flags and optional modifier. Succeeds
/// without consuming anything if the keyword is absent; fails if the keyword
/// is present but `reductionArgs` is not set.
static ParseResult parseBlockArgClause(
    OpAsmParser &parser,
    llvm::SmallVectorImpl<OpAsmParser::Argument> &entryBlockArgs,
    StringRef keyword, std::optional<ReductionParseArgs> reductionArgs) {
  if (failed(parser.parseOptionalKeyword(keyword)))
    return success();

  // The clause appeared on an operation that does not accept it.
  if (!reductionArgs)
    return failure();

  return parseClauseWithRegionArgs(
      parser, reductionArgs->vars, reductionArgs->types, entryBlockArgs,
      &reductionArgs->syms, /*mapIndices=*/nullptr, &reductionArgs->byref,
      reductionArgs->modifier);
}
|
|
|
|
/// Parse every clause that defines entry block arguments, followed by the
/// region itself.
///
/// The clauses are optional and are looked for in this fixed order:
/// `has_device_addr`, `host_eval`, `in_reduction`, `map_entries`, `private`,
/// `reduction`, `task_reduction`, `use_device_addr`, `use_device_ptr`. The
/// block arguments they define become the entry block arguments of `region`,
/// in that same order. If a clause fails to parse, an
/// "invalid `<clause>` format" error is emitted at the current location.
static ParseResult parseBlockArgRegion(OpAsmParser &parser, Region &region,
                                       AllRegionParseArgs args) {
  llvm::SmallVector<OpAsmParser::Argument> entryBlockArgs;

  // Parses one optional clause. The error message is derived from the
  // keyword so that it always names the clause that actually failed.
  auto parseClause = [&](StringRef keyword,
                         const auto &clauseArgs) -> ParseResult {
    if (failed(
            parseBlockArgClause(parser, entryBlockArgs, keyword, clauseArgs)))
      return parser.emitError(parser.getCurrentLocation())
             << "invalid `" << keyword << "` format";
    return success();
  };

  if (parseClause("has_device_addr", args.hasDeviceAddrArgs) ||
      parseClause("host_eval", args.hostEvalArgs) ||
      parseClause("in_reduction", args.inReductionArgs) ||
      parseClause("map_entries", args.mapArgs) ||
      parseClause("private", args.privateArgs) ||
      parseClause("reduction", args.reductionArgs) ||
      parseClause("task_reduction", args.taskReductionArgs) ||
      parseClause("use_device_addr", args.useDeviceAddrArgs) ||
      parseClause("use_device_ptr", args.useDevicePtrArgs))
    return failure();

  return parser.parseRegion(region, entryBlockArgs);
}
|
|
|
|
// These parseXyz functions correspond to the custom<Xyz> definitions
|
|
// in the .td file(s).
|
|
static ParseResult parseTargetOpRegion(
|
|
OpAsmParser &parser, Region ®ion,
|
|
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &hasDeviceAddrVars,
|
|
SmallVectorImpl<Type> &hasDeviceAddrTypes,
|
|
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &hostEvalVars,
|
|
SmallVectorImpl<Type> &hostEvalTypes,
|
|
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &inReductionVars,
|
|
SmallVectorImpl<Type> &inReductionTypes,
|
|
DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms,
|
|
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &mapVars,
|
|
SmallVectorImpl<Type> &mapTypes,
|
|
llvm::SmallVectorImpl<OpAsmParser::UnresolvedOperand> &privateVars,
|
|
llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms,
|
|
UnitAttr &privateNeedsBarrier, DenseI64ArrayAttr &privateMaps) {
|
|
AllRegionParseArgs args;
|
|
args.hasDeviceAddrArgs.emplace(hasDeviceAddrVars, hasDeviceAddrTypes);
|
|
args.hostEvalArgs.emplace(hostEvalVars, hostEvalTypes);
|
|
args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
|
|
inReductionByref, inReductionSyms);
|
|
args.mapArgs.emplace(mapVars, mapTypes);
|
|
args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
|
|
privateNeedsBarrier, &privateMaps);
|
|
return parseBlockArgRegion(parser, region, args);
|
|
}
|
|
|
|
/// Parser for `custom<InReductionPrivateRegion>`: registers the output storage
/// for the in_reduction and private clauses and defers to parseBlockArgRegion
/// for the actual parsing of the clauses and the region.
static ParseResult parseInReductionPrivateRegion(
    OpAsmParser &parser, Region &region,
    SmallVectorImpl<OpAsmParser::UnresolvedOperand> &inReductionVars,
    SmallVectorImpl<Type> &inReductionTypes,
    DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms,
    llvm::SmallVectorImpl<OpAsmParser::UnresolvedOperand> &privateVars,
    llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms,
    UnitAttr &privateNeedsBarrier) {
  AllRegionParseArgs clauseArgs;
  clauseArgs.inReductionArgs.emplace(inReductionVars, inReductionTypes,
                                     inReductionByref, inReductionSyms);
  clauseArgs.privateArgs.emplace(privateVars, privateTypes, privateSyms,
                                 privateNeedsBarrier);
  return parseBlockArgRegion(parser, region, clauseArgs);
}
|
|
|
|
/// Parser for `custom<InReductionPrivateReductionRegion>`: registers the
/// output storage for the in_reduction, private and reduction clauses (the
/// latter including its modifier) and defers to parseBlockArgRegion.
static ParseResult parseInReductionPrivateReductionRegion(
    OpAsmParser &parser, Region &region,
    SmallVectorImpl<OpAsmParser::UnresolvedOperand> &inReductionVars,
    SmallVectorImpl<Type> &inReductionTypes,
    DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms,
    llvm::SmallVectorImpl<OpAsmParser::UnresolvedOperand> &privateVars,
    llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms,
    UnitAttr &privateNeedsBarrier, ReductionModifierAttr &reductionMod,
    SmallVectorImpl<OpAsmParser::UnresolvedOperand> &reductionVars,
    SmallVectorImpl<Type> &reductionTypes, DenseBoolArrayAttr &reductionByref,
    ArrayAttr &reductionSyms) {
  AllRegionParseArgs clauseArgs;
  clauseArgs.inReductionArgs.emplace(inReductionVars, inReductionTypes,
                                     inReductionByref, inReductionSyms);
  clauseArgs.privateArgs.emplace(privateVars, privateTypes, privateSyms,
                                 privateNeedsBarrier);
  // The reduction modifier is handed over by pointer so the callee can fill
  // it in.
  clauseArgs.reductionArgs.emplace(reductionVars, reductionTypes,
                                   reductionByref, reductionSyms,
                                   &reductionMod);
  return parseBlockArgRegion(parser, region, clauseArgs);
}
|
|
|
|
/// Parser for `custom<PrivateRegion>`: registers the output storage for the
/// private clause only and defers to parseBlockArgRegion.
static ParseResult parsePrivateRegion(
    OpAsmParser &parser, Region &region,
    llvm::SmallVectorImpl<OpAsmParser::UnresolvedOperand> &privateVars,
    llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms,
    UnitAttr &privateNeedsBarrier) {
  AllRegionParseArgs clauseArgs;
  clauseArgs.privateArgs.emplace(privateVars, privateTypes, privateSyms,
                                 privateNeedsBarrier);
  return parseBlockArgRegion(parser, region, clauseArgs);
}
|
|
|
|
/// Parser for `custom<PrivateReductionRegion>`: registers the output storage
/// for the private and reduction clauses (the latter including its modifier)
/// and defers to parseBlockArgRegion.
static ParseResult parsePrivateReductionRegion(
    OpAsmParser &parser, Region &region,
    llvm::SmallVectorImpl<OpAsmParser::UnresolvedOperand> &privateVars,
    llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms,
    UnitAttr &privateNeedsBarrier, ReductionModifierAttr &reductionMod,
    SmallVectorImpl<OpAsmParser::UnresolvedOperand> &reductionVars,
    SmallVectorImpl<Type> &reductionTypes, DenseBoolArrayAttr &reductionByref,
    ArrayAttr &reductionSyms) {
  AllRegionParseArgs clauseArgs;
  clauseArgs.privateArgs.emplace(privateVars, privateTypes, privateSyms,
                                 privateNeedsBarrier);
  clauseArgs.reductionArgs.emplace(reductionVars, reductionTypes,
                                   reductionByref, reductionSyms,
                                   &reductionMod);
  return parseBlockArgRegion(parser, region, clauseArgs);
}
|
|
|
|
/// Parser for `custom<TaskReductionRegion>`: registers the output storage for
/// the task_reduction clause only and defers to parseBlockArgRegion.
static ParseResult parseTaskReductionRegion(
    OpAsmParser &parser, Region &region,
    SmallVectorImpl<OpAsmParser::UnresolvedOperand> &taskReductionVars,
    SmallVectorImpl<Type> &taskReductionTypes,
    DenseBoolArrayAttr &taskReductionByref, ArrayAttr &taskReductionSyms) {
  AllRegionParseArgs clauseArgs;
  clauseArgs.taskReductionArgs.emplace(taskReductionVars, taskReductionTypes,
                                       taskReductionByref, taskReductionSyms);
  return parseBlockArgRegion(parser, region, clauseArgs);
}
|
|
|
|
/// Parser for `custom<UseDeviceAddrUseDevicePtrRegion>`: registers the output
/// storage for the use_device_addr and use_device_ptr clauses and defers to
/// parseBlockArgRegion.
static ParseResult parseUseDeviceAddrUseDevicePtrRegion(
    OpAsmParser &parser, Region &region,
    SmallVectorImpl<OpAsmParser::UnresolvedOperand> &useDeviceAddrVars,
    SmallVectorImpl<Type> &useDeviceAddrTypes,
    SmallVectorImpl<OpAsmParser::UnresolvedOperand> &useDevicePtrVars,
    SmallVectorImpl<Type> &useDevicePtrTypes) {
  AllRegionParseArgs clauseArgs;
  clauseArgs.useDeviceAddrArgs.emplace(useDeviceAddrVars, useDeviceAddrTypes);
  clauseArgs.useDevicePtrArgs.emplace(useDevicePtrVars, useDevicePtrTypes);
  return parseBlockArgRegion(parser, region, clauseArgs);
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Printers for operations including clauses that define entry block arguments.
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
namespace {
|
|
struct MapPrintArgs {
|
|
ValueRange vars;
|
|
TypeRange types;
|
|
MapPrintArgs(ValueRange vars, TypeRange types) : vars(vars), types(types) {}
|
|
};
|
|
struct PrivatePrintArgs {
|
|
ValueRange vars;
|
|
TypeRange types;
|
|
ArrayAttr syms;
|
|
UnitAttr needsBarrier;
|
|
DenseI64ArrayAttr mapIndices;
|
|
PrivatePrintArgs(ValueRange vars, TypeRange types, ArrayAttr syms,
|
|
UnitAttr needsBarrier, DenseI64ArrayAttr mapIndices)
|
|
: vars(vars), types(types), syms(syms), needsBarrier(needsBarrier),
|
|
mapIndices(mapIndices) {}
|
|
};
|
|
struct ReductionPrintArgs {
|
|
ValueRange vars;
|
|
TypeRange types;
|
|
DenseBoolArrayAttr byref;
|
|
ArrayAttr syms;
|
|
ReductionModifierAttr modifier;
|
|
ReductionPrintArgs(ValueRange vars, TypeRange types, DenseBoolArrayAttr byref,
|
|
ArrayAttr syms, ReductionModifierAttr mod = nullptr)
|
|
: vars(vars), types(types), byref(byref), syms(syms), modifier(mod) {}
|
|
};
|
|
struct AllRegionPrintArgs {
|
|
std::optional<MapPrintArgs> hasDeviceAddrArgs;
|
|
std::optional<MapPrintArgs> hostEvalArgs;
|
|
std::optional<ReductionPrintArgs> inReductionArgs;
|
|
std::optional<MapPrintArgs> mapArgs;
|
|
std::optional<PrivatePrintArgs> privateArgs;
|
|
std::optional<ReductionPrintArgs> reductionArgs;
|
|
std::optional<ReductionPrintArgs> taskReductionArgs;
|
|
std::optional<MapPrintArgs> useDeviceAddrArgs;
|
|
std::optional<MapPrintArgs> useDevicePtrArgs;
|
|
};
|
|
} // namespace
|
|
|
|
/// Prints one clause of the form
///   `name(` [`mod: ` modifier `, `] entry (`, ` entry)* ` : ` types `) `
/// where each entry is
///   [`byref `] [symbol ` `] operand ` -> ` block-arg [` [map_idx=`N`]`]
/// followed by the private-needs-barrier spelling when `needsBarrier` is set.
/// Nothing at all is printed when `argsSubrange` is empty.
///
/// `symbols`, `mapIndices` and `byref` are optional; when one is missing it is
/// replaced by a neutral per-operand placeholder (null symbol, index -1,
/// `false`) so that all ranges can be zipped together.
static void printClauseWithRegionArgs(
    OpAsmPrinter &p, MLIRContext *ctx, StringRef clauseName,
    ValueRange argsSubrange, ValueRange operands, TypeRange types,
    ArrayAttr symbols = nullptr, DenseI64ArrayAttr mapIndices = nullptr,
    DenseBoolArrayAttr byref = nullptr,
    ReductionModifierAttr modifier = nullptr, UnitAttr needsBarrier = nullptr) {
  if (argsSubrange.empty())
    return;

  p << clauseName << "(";
  if (modifier)
    p << "mod: " << stringifyReductionModifier(modifier.getValue()) << ", ";

  // Fill in placeholders for the optional per-entry attributes.
  const size_t numOperands = operands.size();
  if (!symbols) {
    llvm::SmallVector<Attribute> noSymbols(numOperands, nullptr);
    symbols = ArrayAttr::get(ctx, noSymbols);
  }
  if (!mapIndices) {
    llvm::SmallVector<int64_t> noIndices(numOperands, -1);
    mapIndices = DenseI64ArrayAttr::get(ctx, noIndices);
  }
  if (!byref) {
    mlir::SmallVector<bool> noByref(numOperands, false);
    byref = DenseBoolArrayAttr::get(ctx, noByref);
  }

  auto printEntry = [&p](auto entry) {
    auto [operand, blockArg, symbol, mapIdx, passedByRef] = entry;
    if (passedByRef)
      p << "byref ";
    if (symbol)
      p << symbol << " ";

    p << operand << " -> " << blockArg;

    // -1 marks "no associated map entry".
    if (mapIdx != -1)
      p << " [map_idx=" << mapIdx << "]";
  };
  llvm::interleaveComma(llvm::zip_equal(operands, argsSubrange, symbols,
                                        mapIndices.asArrayRef(),
                                        byref.asArrayRef()),
                        p, printEntry);

  p << " : ";
  llvm::interleaveComma(types, p);
  p << ") ";

  if (needsBarrier)
    p << getPrivateNeedsBarrierSpelling() << " ";
}
|
|
|
|
/// Prints a clause that only has operands and types. Does nothing when
/// `mapArgs` is not populated.
static void printBlockArgClause(OpAsmPrinter &p, MLIRContext *ctx,
                                StringRef clauseName, ValueRange argsSubrange,
                                std::optional<MapPrintArgs> mapArgs) {
  if (!mapArgs)
    return;

  printClauseWithRegionArgs(p, ctx, clauseName, argsSubrange, mapArgs->vars,
                            mapArgs->types);
}
|
|
|
|
/// Prints a private clause, including privatizer symbols, map indices and the
/// needs-barrier marker. Does nothing when `privateArgs` is not populated.
static void printBlockArgClause(OpAsmPrinter &p, MLIRContext *ctx,
                                StringRef clauseName, ValueRange argsSubrange,
                                std::optional<PrivatePrintArgs> privateArgs) {
  if (!privateArgs)
    return;

  const PrivatePrintArgs &priv = *privateArgs;
  printClauseWithRegionArgs(p, ctx, clauseName, argsSubrange, priv.vars,
                            priv.types, priv.syms, priv.mapIndices,
                            /*byref=*/nullptr,
                            /*modifier=*/nullptr, priv.needsBarrier);
}
|
|
|
|
static void
|
|
printBlockArgClause(OpAsmPrinter &p, MLIRContext *ctx, StringRef clauseName,
|
|
ValueRange argsSubrange,
|
|
std::optional<ReductionPrintArgs> reductionArgs) {
|
|
if (reductionArgs)
|
|
printClauseWithRegionArgs(p, ctx, clauseName, argsSubrange,
|
|
reductionArgs->vars, reductionArgs->types,
|
|
reductionArgs->syms, /*mapIndices=*/nullptr,
|
|
reductionArgs->byref, reductionArgs->modifier);
|
|
}
|
|
|
|
static void printBlockArgRegion(OpAsmPrinter &p, Operation *op, Region ®ion,
|
|
const AllRegionPrintArgs &args) {
|
|
auto iface = llvm::cast<mlir::omp::BlockArgOpenMPOpInterface>(op);
|
|
MLIRContext *ctx = op->getContext();
|
|
|
|
printBlockArgClause(p, ctx, "has_device_addr",
|
|
iface.getHasDeviceAddrBlockArgs(),
|
|
args.hasDeviceAddrArgs);
|
|
printBlockArgClause(p, ctx, "host_eval", iface.getHostEvalBlockArgs(),
|
|
args.hostEvalArgs);
|
|
printBlockArgClause(p, ctx, "in_reduction", iface.getInReductionBlockArgs(),
|
|
args.inReductionArgs);
|
|
printBlockArgClause(p, ctx, "map_entries", iface.getMapBlockArgs(),
|
|
args.mapArgs);
|
|
printBlockArgClause(p, ctx, "private", iface.getPrivateBlockArgs(),
|
|
args.privateArgs);
|
|
printBlockArgClause(p, ctx, "reduction", iface.getReductionBlockArgs(),
|
|
args.reductionArgs);
|
|
printBlockArgClause(p, ctx, "task_reduction",
|
|
iface.getTaskReductionBlockArgs(),
|
|
args.taskReductionArgs);
|
|
printBlockArgClause(p, ctx, "use_device_addr",
|
|
iface.getUseDeviceAddrBlockArgs(),
|
|
args.useDeviceAddrArgs);
|
|
printBlockArgClause(p, ctx, "use_device_ptr",
|
|
iface.getUseDevicePtrBlockArgs(), args.useDevicePtrArgs);
|
|
|
|
p.printRegion(region, /*printEntryBlockArgs=*/false);
|
|
}
|
|
|
|
// These printXyz functions correspond to the custom<Xyz> definitions
// in the .td file(s).
|
|
static void printTargetOpRegion(
|
|
OpAsmPrinter &p, Operation *op, Region ®ion,
|
|
ValueRange hasDeviceAddrVars, TypeRange hasDeviceAddrTypes,
|
|
ValueRange hostEvalVars, TypeRange hostEvalTypes,
|
|
ValueRange inReductionVars, TypeRange inReductionTypes,
|
|
DenseBoolArrayAttr inReductionByref, ArrayAttr inReductionSyms,
|
|
ValueRange mapVars, TypeRange mapTypes, ValueRange privateVars,
|
|
TypeRange privateTypes, ArrayAttr privateSyms, UnitAttr privateNeedsBarrier,
|
|
DenseI64ArrayAttr privateMaps) {
|
|
AllRegionPrintArgs args;
|
|
args.hasDeviceAddrArgs.emplace(hasDeviceAddrVars, hasDeviceAddrTypes);
|
|
args.hostEvalArgs.emplace(hostEvalVars, hostEvalTypes);
|
|
args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
|
|
inReductionByref, inReductionSyms);
|
|
args.mapArgs.emplace(mapVars, mapTypes);
|
|
args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
|
|
privateNeedsBarrier, privateMaps);
|
|
printBlockArgRegion(p, op, region, args);
|
|
}
|
|
|
|
static void printInReductionPrivateRegion(
|
|
OpAsmPrinter &p, Operation *op, Region ®ion, ValueRange inReductionVars,
|
|
TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref,
|
|
ArrayAttr inReductionSyms, ValueRange privateVars, TypeRange privateTypes,
|
|
ArrayAttr privateSyms, UnitAttr privateNeedsBarrier) {
|
|
AllRegionPrintArgs args;
|
|
args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
|
|
inReductionByref, inReductionSyms);
|
|
args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
|
|
privateNeedsBarrier,
|
|
/*mapIndices=*/nullptr);
|
|
printBlockArgRegion(p, op, region, args);
|
|
}
|
|
|
|
static void printInReductionPrivateReductionRegion(
|
|
OpAsmPrinter &p, Operation *op, Region ®ion, ValueRange inReductionVars,
|
|
TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref,
|
|
ArrayAttr inReductionSyms, ValueRange privateVars, TypeRange privateTypes,
|
|
ArrayAttr privateSyms, UnitAttr privateNeedsBarrier,
|
|
ReductionModifierAttr reductionMod, ValueRange reductionVars,
|
|
TypeRange reductionTypes, DenseBoolArrayAttr reductionByref,
|
|
ArrayAttr reductionSyms) {
|
|
AllRegionPrintArgs args;
|
|
args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
|
|
inReductionByref, inReductionSyms);
|
|
args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
|
|
privateNeedsBarrier,
|
|
/*mapIndices=*/nullptr);
|
|
args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
|
|
reductionSyms, reductionMod);
|
|
printBlockArgRegion(p, op, region, args);
|
|
}
|
|
|
|
static void printPrivateRegion(OpAsmPrinter &p, Operation *op, Region ®ion,
|
|
ValueRange privateVars, TypeRange privateTypes,
|
|
ArrayAttr privateSyms,
|
|
UnitAttr privateNeedsBarrier) {
|
|
AllRegionPrintArgs args;
|
|
args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
|
|
privateNeedsBarrier,
|
|
/*mapIndices=*/nullptr);
|
|
printBlockArgRegion(p, op, region, args);
|
|
}
|
|
|
|
static void printPrivateReductionRegion(
|
|
OpAsmPrinter &p, Operation *op, Region ®ion, ValueRange privateVars,
|
|
TypeRange privateTypes, ArrayAttr privateSyms, UnitAttr privateNeedsBarrier,
|
|
ReductionModifierAttr reductionMod, ValueRange reductionVars,
|
|
TypeRange reductionTypes, DenseBoolArrayAttr reductionByref,
|
|
ArrayAttr reductionSyms) {
|
|
AllRegionPrintArgs args;
|
|
args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
|
|
privateNeedsBarrier,
|
|
/*mapIndices=*/nullptr);
|
|
args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
|
|
reductionSyms, reductionMod);
|
|
printBlockArgRegion(p, op, region, args);
|
|
}
|
|
|
|
static void printTaskReductionRegion(OpAsmPrinter &p, Operation *op,
|
|
Region ®ion,
|
|
ValueRange taskReductionVars,
|
|
TypeRange taskReductionTypes,
|
|
DenseBoolArrayAttr taskReductionByref,
|
|
ArrayAttr taskReductionSyms) {
|
|
AllRegionPrintArgs args;
|
|
args.taskReductionArgs.emplace(taskReductionVars, taskReductionTypes,
|
|
taskReductionByref, taskReductionSyms);
|
|
printBlockArgRegion(p, op, region, args);
|
|
}
|
|
|
|
static void printUseDeviceAddrUseDevicePtrRegion(OpAsmPrinter &p, Operation *op,
|
|
Region ®ion,
|
|
ValueRange useDeviceAddrVars,
|
|
TypeRange useDeviceAddrTypes,
|
|
ValueRange useDevicePtrVars,
|
|
TypeRange useDevicePtrTypes) {
|
|
AllRegionPrintArgs args;
|
|
args.useDeviceAddrArgs.emplace(useDeviceAddrVars, useDeviceAddrTypes);
|
|
args.useDevicePtrArgs.emplace(useDevicePtrVars, useDevicePtrTypes);
|
|
printBlockArgRegion(p, op, region, args);
|
|
}
|
|
|
|
/// Verifies Reduction Clause
|
|
static LogicalResult
|
|
verifyReductionVarList(Operation *op, std::optional<ArrayAttr> reductionSyms,
|
|
OperandRange reductionVars,
|
|
std::optional<ArrayRef<bool>> reductionByref) {
|
|
if (!reductionVars.empty()) {
|
|
if (!reductionSyms || reductionSyms->size() != reductionVars.size())
|
|
return op->emitOpError()
|
|
<< "expected as many reduction symbol references "
|
|
"as reduction variables";
|
|
if (reductionByref && reductionByref->size() != reductionVars.size())
|
|
return op->emitError() << "expected as many reduction variable by "
|
|
"reference attributes as reduction variables";
|
|
} else {
|
|
if (reductionSyms)
|
|
return op->emitOpError() << "unexpected reduction symbol references";
|
|
return success();
|
|
}
|
|
|
|
// TODO: The followings should be done in
|
|
// SymbolUserOpInterface::verifySymbolUses.
|
|
DenseSet<Value> accumulators;
|
|
for (auto args : llvm::zip(reductionVars, *reductionSyms)) {
|
|
Value accum = std::get<0>(args);
|
|
|
|
if (!accumulators.insert(accum).second)
|
|
return op->emitOpError() << "accumulator variable used more than once";
|
|
|
|
Type varType = accum.getType();
|
|
auto symbolRef = llvm::cast<SymbolRefAttr>(std::get<1>(args));
|
|
auto decl =
|
|
SymbolTable::lookupNearestSymbolFrom<DeclareReductionOp>(op, symbolRef);
|
|
if (!decl)
|
|
return op->emitOpError() << "expected symbol reference " << symbolRef
|
|
<< " to point to a reduction declaration";
|
|
|
|
if (decl.getAccumulatorType() && decl.getAccumulatorType() != varType)
|
|
return op->emitOpError()
|
|
<< "expected accumulator (" << varType
|
|
<< ") to be the same type as reduction declaration ("
|
|
<< decl.getAccumulatorType() << ")";
|
|
}
|
|
|
|
return success();
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Parser, printer and verifier for Copyprivate
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// copyprivate-entry-list ::= copyprivate-entry
|
|
/// | copyprivate-entry-list `,` copyprivate-entry
|
|
/// copyprivate-entry ::= ssa-id `->` symbol-ref `:` type
|
|
static ParseResult parseCopyprivate(
|
|
OpAsmParser &parser,
|
|
SmallVectorImpl<OpAsmParser::UnresolvedOperand> ©privateVars,
|
|
SmallVectorImpl<Type> ©privateTypes, ArrayAttr ©privateSyms) {
|
|
SmallVector<SymbolRefAttr> symsVec;
|
|
if (failed(parser.parseCommaSeparatedList([&]() {
|
|
if (parser.parseOperand(copyprivateVars.emplace_back()) ||
|
|
parser.parseArrow() ||
|
|
parser.parseAttribute(symsVec.emplace_back()) ||
|
|
parser.parseColonType(copyprivateTypes.emplace_back()))
|
|
return failure();
|
|
return success();
|
|
})))
|
|
return failure();
|
|
SmallVector<Attribute> syms(symsVec.begin(), symsVec.end());
|
|
copyprivateSyms = ArrayAttr::get(parser.getContext(), syms);
|
|
return success();
|
|
}
|
|
|
|
/// Print Copyprivate clause
|
|
static void printCopyprivate(OpAsmPrinter &p, Operation *op,
|
|
OperandRange copyprivateVars,
|
|
TypeRange copyprivateTypes,
|
|
std::optional<ArrayAttr> copyprivateSyms) {
|
|
if (!copyprivateSyms.has_value())
|
|
return;
|
|
llvm::interleaveComma(
|
|
llvm::zip(copyprivateVars, *copyprivateSyms, copyprivateTypes), p,
|
|
[&](const auto &args) {
|
|
p << std::get<0>(args) << " -> " << std::get<1>(args) << " : "
|
|
<< std::get<2>(args);
|
|
});
|
|
}
|
|
|
|
/// Verifies CopyPrivate Clause
|
|
static LogicalResult
|
|
verifyCopyprivateVarList(Operation *op, OperandRange copyprivateVars,
|
|
std::optional<ArrayAttr> copyprivateSyms) {
|
|
size_t copyprivateSymsSize =
|
|
copyprivateSyms.has_value() ? copyprivateSyms->size() : 0;
|
|
if (copyprivateSymsSize != copyprivateVars.size())
|
|
return op->emitOpError() << "inconsistent number of copyprivate vars (= "
|
|
<< copyprivateVars.size()
|
|
<< ") and functions (= " << copyprivateSymsSize
|
|
<< "), both must be equal";
|
|
if (!copyprivateSyms.has_value())
|
|
return success();
|
|
|
|
for (auto copyprivateVarAndSym :
|
|
llvm::zip(copyprivateVars, *copyprivateSyms)) {
|
|
auto symbolRef =
|
|
llvm::cast<SymbolRefAttr>(std::get<1>(copyprivateVarAndSym));
|
|
std::optional<std::variant<mlir::func::FuncOp, mlir::LLVM::LLVMFuncOp>>
|
|
funcOp;
|
|
if (mlir::func::FuncOp mlirFuncOp =
|
|
SymbolTable::lookupNearestSymbolFrom<mlir::func::FuncOp>(op,
|
|
symbolRef))
|
|
funcOp = mlirFuncOp;
|
|
else if (mlir::LLVM::LLVMFuncOp llvmFuncOp =
|
|
SymbolTable::lookupNearestSymbolFrom<mlir::LLVM::LLVMFuncOp>(
|
|
op, symbolRef))
|
|
funcOp = llvmFuncOp;
|
|
|
|
auto getNumArguments = [&] {
|
|
return std::visit([](auto &f) { return f.getNumArguments(); }, *funcOp);
|
|
};
|
|
|
|
auto getArgumentType = [&](unsigned i) {
|
|
return std::visit([i](auto &f) { return f.getArgumentTypes()[i]; },
|
|
*funcOp);
|
|
};
|
|
|
|
if (!funcOp)
|
|
return op->emitOpError() << "expected symbol reference " << symbolRef
|
|
<< " to point to a copy function";
|
|
|
|
if (getNumArguments() != 2)
|
|
return op->emitOpError()
|
|
<< "expected copy function " << symbolRef << " to have 2 operands";
|
|
|
|
Type argTy = getArgumentType(0);
|
|
if (argTy != getArgumentType(1))
|
|
return op->emitOpError() << "expected copy function " << symbolRef
|
|
<< " arguments to have the same type";
|
|
|
|
Type varType = std::get<0>(copyprivateVarAndSym).getType();
|
|
if (argTy != varType)
|
|
return op->emitOpError()
|
|
<< "expected copy function arguments' type (" << argTy
|
|
<< ") to be the same as copyprivate variable's type (" << varType
|
|
<< ")";
|
|
}
|
|
|
|
return success();
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Parser, printer and verifier for DependVarList
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// depend-entry-list ::= depend-entry
|
|
/// | depend-entry-list `,` depend-entry
|
|
/// depend-entry ::= depend-kind `->` ssa-id `:` type
|
|
static ParseResult
|
|
parseDependVarList(OpAsmParser &parser,
|
|
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &dependVars,
|
|
SmallVectorImpl<Type> &dependTypes, ArrayAttr &dependKinds) {
|
|
SmallVector<ClauseTaskDependAttr> kindsVec;
|
|
if (failed(parser.parseCommaSeparatedList([&]() {
|
|
StringRef keyword;
|
|
if (parser.parseKeyword(&keyword) || parser.parseArrow() ||
|
|
parser.parseOperand(dependVars.emplace_back()) ||
|
|
parser.parseColonType(dependTypes.emplace_back()))
|
|
return failure();
|
|
if (std::optional<ClauseTaskDepend> keywordDepend =
|
|
(symbolizeClauseTaskDepend(keyword)))
|
|
kindsVec.emplace_back(
|
|
ClauseTaskDependAttr::get(parser.getContext(), *keywordDepend));
|
|
else
|
|
return failure();
|
|
return success();
|
|
})))
|
|
return failure();
|
|
SmallVector<Attribute> kinds(kindsVec.begin(), kindsVec.end());
|
|
dependKinds = ArrayAttr::get(parser.getContext(), kinds);
|
|
return success();
|
|
}
|
|
|
|
/// Print Depend clause
|
|
static void printDependVarList(OpAsmPrinter &p, Operation *op,
|
|
OperandRange dependVars, TypeRange dependTypes,
|
|
std::optional<ArrayAttr> dependKinds) {
|
|
|
|
for (unsigned i = 0, e = dependKinds->size(); i < e; ++i) {
|
|
if (i != 0)
|
|
p << ", ";
|
|
p << stringifyClauseTaskDepend(
|
|
llvm::cast<mlir::omp::ClauseTaskDependAttr>((*dependKinds)[i])
|
|
.getValue())
|
|
<< " -> " << dependVars[i] << " : " << dependTypes[i];
|
|
}
|
|
}
|
|
|
|
/// Verifies Depend clause
|
|
static LogicalResult verifyDependVarList(Operation *op,
|
|
std::optional<ArrayAttr> dependKinds,
|
|
OperandRange dependVars) {
|
|
if (!dependVars.empty()) {
|
|
if (!dependKinds || dependKinds->size() != dependVars.size())
|
|
return op->emitOpError() << "expected as many depend values"
|
|
" as depend variables";
|
|
} else {
|
|
if (dependKinds && !dependKinds->empty())
|
|
return op->emitOpError() << "unexpected depend values";
|
|
return success();
|
|
}
|
|
|
|
return success();
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Parser, printer and verifier for Synchronization Hint (2.17.12)
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Parses a Synchronization Hint clause. The value of hint is an integer
|
|
/// which is a combination of different hints from `omp_sync_hint_t`.
|
|
///
|
|
/// hint-clause = `hint` `(` hint-value `)`
|
|
static ParseResult parseSynchronizationHint(OpAsmParser &parser,
|
|
IntegerAttr &hintAttr) {
|
|
StringRef hintKeyword;
|
|
int64_t hint = 0;
|
|
if (succeeded(parser.parseOptionalKeyword("none"))) {
|
|
hintAttr = IntegerAttr::get(parser.getBuilder().getI64Type(), 0);
|
|
return success();
|
|
}
|
|
auto parseKeyword = [&]() -> ParseResult {
|
|
if (failed(parser.parseKeyword(&hintKeyword)))
|
|
return failure();
|
|
if (hintKeyword == "uncontended")
|
|
hint |= 1;
|
|
else if (hintKeyword == "contended")
|
|
hint |= 2;
|
|
else if (hintKeyword == "nonspeculative")
|
|
hint |= 4;
|
|
else if (hintKeyword == "speculative")
|
|
hint |= 8;
|
|
else
|
|
return parser.emitError(parser.getCurrentLocation())
|
|
<< hintKeyword << " is not a valid hint";
|
|
return success();
|
|
};
|
|
if (parser.parseCommaSeparatedList(parseKeyword))
|
|
return failure();
|
|
hintAttr = IntegerAttr::get(parser.getBuilder().getI64Type(), hint);
|
|
return success();
|
|
}
|
|
|
|
/// Prints a Synchronization Hint clause
|
|
static void printSynchronizationHint(OpAsmPrinter &p, Operation *op,
|
|
IntegerAttr hintAttr) {
|
|
int64_t hint = hintAttr.getInt();
|
|
|
|
if (hint == 0) {
|
|
p << "none";
|
|
return;
|
|
}
|
|
|
|
// Helper function to get n-th bit from the right end of `value`
|
|
auto bitn = [](int value, int n) -> bool { return value & (1 << n); };
|
|
|
|
bool uncontended = bitn(hint, 0);
|
|
bool contended = bitn(hint, 1);
|
|
bool nonspeculative = bitn(hint, 2);
|
|
bool speculative = bitn(hint, 3);
|
|
|
|
SmallVector<StringRef> hints;
|
|
if (uncontended)
|
|
hints.push_back("uncontended");
|
|
if (contended)
|
|
hints.push_back("contended");
|
|
if (nonspeculative)
|
|
hints.push_back("nonspeculative");
|
|
if (speculative)
|
|
hints.push_back("speculative");
|
|
|
|
llvm::interleaveComma(hints, p);
|
|
}
|
|
|
|
/// Verifies a synchronization hint clause
|
|
static LogicalResult verifySynchronizationHint(Operation *op, uint64_t hint) {
|
|
|
|
// Helper function to get n-th bit from the right end of `value`
|
|
auto bitn = [](int value, int n) -> bool { return value & (1 << n); };
|
|
|
|
bool uncontended = bitn(hint, 0);
|
|
bool contended = bitn(hint, 1);
|
|
bool nonspeculative = bitn(hint, 2);
|
|
bool speculative = bitn(hint, 3);
|
|
|
|
if (uncontended && contended)
|
|
return op->emitOpError() << "the hints omp_sync_hint_uncontended and "
|
|
"omp_sync_hint_contended cannot be combined";
|
|
if (nonspeculative && speculative)
|
|
return op->emitOpError() << "the hints omp_sync_hint_nonspeculative and "
|
|
"omp_sync_hint_speculative cannot be combined.";
|
|
return success();
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Parser, printer and verifier for Target
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// Helper function to get bitwise AND of `value` and 'flag'
|
|
uint64_t mapTypeToBitFlag(uint64_t value,
|
|
llvm::omp::OpenMPOffloadMappingFlags flag) {
|
|
return value & llvm::to_underlying(flag);
|
|
}
|
|
|
|
/// Parses a map_entries map type from a string format back into its numeric
|
|
/// value.
|
|
///
|
|
/// map-clause = `map_clauses ( ( `(` `always, `? `implicit, `? `ompx_hold, `?
|
|
/// `close, `? `present, `? ( `to` | `from` | `delete` `)` )+ `)` )
|
|
static ParseResult parseMapClause(OpAsmParser &parser, IntegerAttr &mapType) {
|
|
llvm::omp::OpenMPOffloadMappingFlags mapTypeBits =
|
|
llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE;
|
|
|
|
// This simply verifies the correct keyword is read in, the
|
|
// keyword itself is stored inside of the operation
|
|
auto parseTypeAndMod = [&]() -> ParseResult {
|
|
StringRef mapTypeMod;
|
|
if (parser.parseKeyword(&mapTypeMod))
|
|
return failure();
|
|
|
|
if (mapTypeMod == "always")
|
|
mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS;
|
|
|
|
if (mapTypeMod == "implicit")
|
|
mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT;
|
|
|
|
if (mapTypeMod == "ompx_hold")
|
|
mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_OMPX_HOLD;
|
|
|
|
if (mapTypeMod == "close")
|
|
mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE;
|
|
|
|
if (mapTypeMod == "present")
|
|
mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PRESENT;
|
|
|
|
if (mapTypeMod == "to")
|
|
mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO;
|
|
|
|
if (mapTypeMod == "from")
|
|
mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM;
|
|
|
|
if (mapTypeMod == "tofrom")
|
|
mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO |
|
|
llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM;
|
|
|
|
if (mapTypeMod == "delete")
|
|
mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_DELETE;
|
|
|
|
if (mapTypeMod == "return_param")
|
|
mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM;
|
|
|
|
return success();
|
|
};
|
|
|
|
if (parser.parseCommaSeparatedList(parseTypeAndMod))
|
|
return failure();
|
|
|
|
mapType = parser.getBuilder().getIntegerAttr(
|
|
parser.getBuilder().getIntegerType(64, /*isSigned=*/false),
|
|
llvm::to_underlying(mapTypeBits));
|
|
|
|
return success();
|
|
}
|
|
|
|
/// Prints a map_entries map type from its numeric value out into its string
|
|
/// format.
|
|
static void printMapClause(OpAsmPrinter &p, Operation *op,
|
|
IntegerAttr mapType) {
|
|
uint64_t mapTypeBits = mapType.getUInt();
|
|
|
|
bool emitAllocRelease = true;
|
|
llvm::SmallVector<std::string, 4> mapTypeStrs;
|
|
|
|
// handling of always, close, present placed at the beginning of the string
|
|
// to aid readability
|
|
if (mapTypeToBitFlag(mapTypeBits,
|
|
llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS))
|
|
mapTypeStrs.push_back("always");
|
|
if (mapTypeToBitFlag(mapTypeBits,
|
|
llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT))
|
|
mapTypeStrs.push_back("implicit");
|
|
if (mapTypeToBitFlag(mapTypeBits,
|
|
llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_OMPX_HOLD))
|
|
mapTypeStrs.push_back("ompx_hold");
|
|
if (mapTypeToBitFlag(mapTypeBits,
|
|
llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE))
|
|
mapTypeStrs.push_back("close");
|
|
if (mapTypeToBitFlag(mapTypeBits,
|
|
llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PRESENT))
|
|
mapTypeStrs.push_back("present");
|
|
|
|
// special handling of to/from/tofrom/delete and release/alloc, release +
|
|
// alloc are the abscense of one of the other flags, whereas tofrom requires
|
|
// both the to and from flag to be set.
|
|
bool to = mapTypeToBitFlag(mapTypeBits,
|
|
llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO);
|
|
bool from = mapTypeToBitFlag(
|
|
mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM);
|
|
if (to && from) {
|
|
emitAllocRelease = false;
|
|
mapTypeStrs.push_back("tofrom");
|
|
} else if (from) {
|
|
emitAllocRelease = false;
|
|
mapTypeStrs.push_back("from");
|
|
} else if (to) {
|
|
emitAllocRelease = false;
|
|
mapTypeStrs.push_back("to");
|
|
}
|
|
if (mapTypeToBitFlag(mapTypeBits,
|
|
llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_DELETE)) {
|
|
emitAllocRelease = false;
|
|
mapTypeStrs.push_back("delete");
|
|
}
|
|
if (mapTypeToBitFlag(
|
|
mapTypeBits,
|
|
llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM)) {
|
|
emitAllocRelease = false;
|
|
mapTypeStrs.push_back("return_param");
|
|
}
|
|
if (emitAllocRelease)
|
|
mapTypeStrs.push_back("exit_release_or_enter_alloc");
|
|
|
|
for (unsigned int i = 0; i < mapTypeStrs.size(); ++i) {
|
|
p << mapTypeStrs[i];
|
|
if (i + 1 < mapTypeStrs.size()) {
|
|
p << ", ";
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Parses a comma-separated sequence of bracketed integer lists, e.g.
/// `[0, 1], [2]`, into an array attribute whose elements are arrays of 64-bit
/// integer attributes (one inner array per bracketed list). At least one
/// bracketed list is required.
static ParseResult parseMembersIndex(OpAsmParser &parser,
                                     ArrayAttr &membersIdx) {
  SmallVector<Attribute> indexLists;

  do {
    // Indices of the bracketed list currently being parsed.
    SmallVector<Attribute> indices;
    auto parseOneIndex = [&]() -> ParseResult {
      int64_t index;
      if (parser.parseInteger(index))
        return failure();
      indices.push_back(
          IntegerAttr::get(parser.getBuilder().getIntegerType(64),
                           APInt(64, index, /*isSigned=*/false)));
      return success();
    };

    if (failed(parser.parseLSquare()))
      return failure();
    if (parser.parseCommaSeparatedList(parseOneIndex))
      return failure();
    if (failed(parser.parseRSquare()))
      return failure();

    indexLists.push_back(ArrayAttr::get(parser.getContext(), indices));
  } while (succeeded(parser.parseOptionalComma()));

  if (!indexLists.empty())
    membersIdx = ArrayAttr::get(parser.getContext(), indexLists);

  return success();
}
|
|
|
|
/// Prints a members-index attribute as a comma-separated sequence of
/// bracketed integer lists, e.g. `[0, 1], [2]`. Prints nothing when the
/// attribute is null.
static void printMembersIndex(OpAsmPrinter &p, MapInfoOp op,
                              ArrayAttr membersIdx) {
  if (!membersIdx)
    return;

  bool isFirstList = true;
  for (Attribute listAttr : membersIdx) {
    if (!isFirstList)
      p << ", ";
    isFirstList = false;

    p << "[";
    llvm::interleaveComma(cast<ArrayAttr>(listAttr).getValue(), p,
                          [&p](Attribute indexAttr) {
                            p << cast<IntegerAttr>(indexAttr).getInt();
                          });
    p << "]";
  }
}
|
|
|
|
/// Prints the keyword corresponding to the given variable capture kind
/// (`ByRef`, `ByCopy`, `VLAType` or `This`). Nothing is printed for any other
/// value.
static void printCaptureType(OpAsmPrinter &p, Operation *op,
                             VariableCaptureKindAttr mapCaptureType) {
  const mlir::omp::VariableCaptureKind kind = mapCaptureType.getValue();

  StringRef keyword;
  if (kind == mlir::omp::VariableCaptureKind::ByRef)
    keyword = "ByRef";
  else if (kind == mlir::omp::VariableCaptureKind::ByCopy)
    keyword = "ByCopy";
  else if (kind == mlir::omp::VariableCaptureKind::VLAType)
    keyword = "VLAType";
  else if (kind == mlir::omp::VariableCaptureKind::This)
    keyword = "This";

  p << keyword;
}
|
|
|
|
/// Parses a variable capture kind keyword (`This`, `ByRef`, `ByCopy` or
/// `VLAType`) and stores the corresponding attribute in `mapCaptureType`.
///
/// Returns failure if no keyword is present. An unrecognized keyword is
/// reported as a parse error instead of silently succeeding with
/// `mapCaptureType` left unset.
static ParseResult parseCaptureType(OpAsmParser &parser,
                                    VariableCaptureKindAttr &mapCaptureType) {
  // Remember where the keyword starts so a diagnostic can point at it.
  auto keywordLoc = parser.getCurrentLocation();

  StringRef mapCaptureKey;
  if (parser.parseKeyword(&mapCaptureKey))
    return failure();

  mlir::omp::VariableCaptureKind kind;
  if (mapCaptureKey == "This")
    kind = mlir::omp::VariableCaptureKind::This;
  else if (mapCaptureKey == "ByRef")
    kind = mlir::omp::VariableCaptureKind::ByRef;
  else if (mapCaptureKey == "ByCopy")
    kind = mlir::omp::VariableCaptureKind::ByCopy;
  else if (mapCaptureKey == "VLAType")
    kind = mlir::omp::VariableCaptureKind::VLAType;
  else
    return parser.emitError(keywordLoc)
           << "invalid capture type keyword '" << mapCaptureKey
           << "', expected one of: This, ByRef, ByCopy, VLAType";

  mapCaptureType =
      mlir::omp::VariableCaptureKindAttr::get(parser.getContext(), kind);
  return success();
}
|
|
|
|
/// Verifies the `map` operands attached to the target-related operation `op`.
///
/// Every operand must have a defining operation. Operands defined by a
/// `MapInfoOp` are checked against the map types accepted by the kind of
/// operation `op` is; operands defined by anything else are only accepted when
/// `op` is a `DeclareMapperInfoOp`. For `TargetUpdateOp`, the variables updated
/// with `to` and with `from` are tracked across operands so that the same
/// variable cannot be given both motions.
static LogicalResult verifyMapClause(Operation *op, OperandRange mapVars) {
  using MapFlags = llvm::omp::OpenMPOffloadMappingFlags;
  using PointerLikeValue = mlir::TypedValue<mlir::omp::PointerLikeType>;

  llvm::DenseSet<PointerLikeValue> updateToVars;
  llvm::DenseSet<PointerLikeValue> updateFromVars;

  for (Value mapVar : mapVars) {
    Operation *definingOp = mapVar.getDefiningOp();
    if (!definingOp)
      return emitError(op->getLoc(), "missing map operation");

    auto mapInfoOp = dyn_cast<mlir::omp::MapInfoOp>(definingOp);
    if (!mapInfoOp) {
      if (!isa<DeclareMapperInfoOp>(op))
        return emitError(op->getLoc(),
                         "map argument is not a map entry operation");
      continue;
    }

    const uint64_t mapTypeBits = mapInfoOp.getMapType();
    auto hasFlag = [mapTypeBits](MapFlags flag) {
      return mapTypeToBitFlag(mapTypeBits, flag);
    };

    // Map types.
    const bool to = hasFlag(MapFlags::OMP_MAP_TO);
    const bool from = hasFlag(MapFlags::OMP_MAP_FROM);
    const bool del = hasFlag(MapFlags::OMP_MAP_DELETE);

    // Map type modifiers.
    const bool always = hasFlag(MapFlags::OMP_MAP_ALWAYS);
    const bool close = hasFlag(MapFlags::OMP_MAP_CLOSE);
    const bool implicit = hasFlag(MapFlags::OMP_MAP_IMPLICIT);

    if (isa<TargetDataOp, TargetOp>(op) && del)
      return emitError(op->getLoc(),
                       "to, from, tofrom and alloc map types are permitted");

    if (isa<TargetEnterDataOp>(op) && (from || del))
      return emitError(op->getLoc(), "to and alloc map types are permitted");

    if (isa<TargetExitDataOp>(op) && to)
      return emitError(op->getLoc(),
                       "from, release and delete map types are permitted");

    if (!isa<TargetUpdateOp>(op))
      continue;

    // Remaining checks only apply to target update operations.
    if (del || (!to && !from))
      return emitError(op->getLoc(),
                       "at least one of to or from map types must be "
                       "specified, other map types are not permitted");

    auto updateVar = mapInfoOp.getVarPtr();

    const bool conflictingMotion =
        (to && from) || (to && updateFromVars.contains(updateVar)) ||
        (from && updateToVars.contains(updateVar));
    if (conflictingMotion)
      return emitError(
          op->getLoc(),
          "either to or from map types can be specified, not both");

    if (always || close || implicit)
      return emitError(
          op->getLoc(),
          "present, mapper and iterator map type modifiers are permitted");

    // Exactly one of `to` / `from` is set here; record the variable in the
    // matching set.
    (to ? updateToVars : updateFromVars).insert(updateVar);
  }

  return success();
}
|
|
|
|
/// Checks that, when a `private_maps` attribute is present on `targetOp`, it
/// has exactly one entry per `private` operand.
static LogicalResult verifyPrivateVarsMapping(TargetOp targetOp) {
  std::optional<DenseI64ArrayAttr> mapIndices = targetOp.getPrivateMapsAttr();

  // Nothing to verify when none of the private operands are mapped.
  const bool hasMapIndices = mapIndices.has_value() && mapIndices.value();
  if (!hasMapIndices)
    return success();

  const auto numPrivateVars =
      static_cast<int64_t>(targetOp.getPrivateVars().size());
  if (mapIndices.value().size() == numPrivateVars)
    return success();

  return emitError(targetOp.getLoc(), "sizes of `private` operand range and "
                                      "`private_maps` attribute mismatch");
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// MapInfoOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Verifies that every value in `vars` is the result of a `MapInfoOp`.
/// On failure, emits an op error on `op` that names the offending clause
/// through `clauseName`.
static LogicalResult verifyMapInfoDefinedArgs(Operation *op,
                                              StringRef clauseName,
                                              OperandRange vars) {
  auto isDefinedByMapInfo = [](Value var) {
    return llvm::isa_and_present<MapInfoOp>(var.getDefiningOp());
  };

  if (llvm::all_of(vars, isDefinedByMapInfo))
    return success();

  return op->emitOpError()
         << "'" << clauseName
         << "' arguments must be defined by 'omp.map.info' ops";
}
|
|
|
|
/// Verifies that a mapper id, if given, resolves to a `DeclareMapperOp`
/// symbol, and that all `members` operands are defined by `omp.map.info` ops.
LogicalResult MapInfoOp::verify() {
  if (getMapperId()) {
    auto mapperOp = SymbolTable::lookupNearestSymbolFrom<omp::DeclareMapperOp>(
        *this, getMapperIdAttr());
    if (!mapperOp)
      return emitError("invalid mapper id");
  }

  return verifyMapInfoDefinedArgs(*this, "members", getMembers());
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// TargetDataOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Builds the operation from a clause operand structure by forwarding its
/// fields to the builder taking individual operands.
void TargetDataOp::build(OpBuilder &builder, OperationState &state,
                         const TargetDataOperands &clauses) {
  build(builder, state, clauses.device, clauses.ifExpr, clauses.mapVars,
        clauses.useDeviceAddrVars, clauses.useDevicePtrVars);
}
|
|
|
|
/// Verifies that at least one of the map / use_device_ptr / use_device_addr
/// operand lists is non-empty, that the use_device_* operands are defined by
/// `omp.map.info` ops, and that the map clause is valid for this operation.
LogicalResult TargetDataOp::verify() {
  const bool hasNoDataOperands = getMapVars().empty() &&
                                 getUseDevicePtrVars().empty() &&
                                 getUseDeviceAddrVars().empty();
  if (hasNoDataOperands)
    return mlir::emitError(getLoc(),
                           "At least one of map, use_device_ptr_vars, or "
                           "use_device_addr_vars operand must be present");

  // Checked in this order so the first failing clause is the one reported.
  if (failed(verifyMapInfoDefinedArgs(*this, "use_device_ptr",
                                      getUseDevicePtrVars())) ||
      failed(verifyMapInfoDefinedArgs(*this, "use_device_addr",
                                      getUseDeviceAddrVars())))
    return failure();

  return verifyMapClause(*this, getMapVars());
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// TargetEnterDataOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Builds the operation from a clause operand structure. The depend kinds are
/// converted to an array attribute before being forwarded.
void TargetEnterDataOp::build(
    OpBuilder &builder, OperationState &state,
    const TargetEnterExitUpdateDataOperands &clauses) {
  ArrayAttr dependKinds =
      makeArrayAttr(builder.getContext(), clauses.dependKinds);
  build(builder, state, dependKinds, clauses.dependVars, clauses.device,
        clauses.ifExpr, clauses.mapVars, clauses.nowait);
}
|
|
|
|
/// Verifies the depend clause first and, if it is valid, the map clause.
LogicalResult TargetEnterDataOp::verify() {
  if (failed(verifyDependVarList(*this, getDependKinds(), getDependVars())))
    return failure();

  return verifyMapClause(*this, getMapVars());
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// TargetExitDataOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Builds the operation from a clause operand structure. The depend kinds are
/// converted to an array attribute before being forwarded.
void TargetExitDataOp::build(OpBuilder &builder, OperationState &state,
                             const TargetEnterExitUpdateDataOperands &clauses) {
  ArrayAttr dependKinds =
      makeArrayAttr(builder.getContext(), clauses.dependKinds);
  build(builder, state, dependKinds, clauses.dependVars, clauses.device,
        clauses.ifExpr, clauses.mapVars, clauses.nowait);
}
|
|
|
|
/// Verifies the depend clause first and, if it is valid, the map clause.
LogicalResult TargetExitDataOp::verify() {
  if (failed(verifyDependVarList(*this, getDependKinds(), getDependVars())))
    return failure();

  return verifyMapClause(*this, getMapVars());
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// TargetUpdateOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Builds the operation from a clause operand structure. The depend kinds are
/// converted to an array attribute before being forwarded.
void TargetUpdateOp::build(OpBuilder &builder, OperationState &state,
                           const TargetEnterExitUpdateDataOperands &clauses) {
  ArrayAttr dependKinds =
      makeArrayAttr(builder.getContext(), clauses.dependKinds);
  build(builder, state, dependKinds, clauses.dependVars, clauses.device,
        clauses.ifExpr, clauses.mapVars, clauses.nowait);
}
|
|
|
|
/// Verifies the depend clause first and, if it is valid, the map clause.
LogicalResult TargetUpdateOp::verify() {
  if (failed(verifyDependVarList(*this, getDependKinds(), getDependVars())))
    return failure();

  return verifyMapClause(*this, getMapVars());
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// TargetOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Builds the operation from a clause operand structure. Clauses that are not
/// yet stored in the operation are passed as empty / null placeholders.
void TargetOp::build(OpBuilder &builder, OperationState &state,
                     const TargetOperands &clauses) {
  MLIRContext *ctx = builder.getContext();
  ArrayAttr dependKinds = makeArrayAttr(ctx, clauses.dependKinds);
  ArrayAttr privateSyms = makeArrayAttr(ctx, clauses.privateSyms);

  // TODO Store clauses in op: allocateVars, allocatorVars, inReductionVars,
  // inReductionByref, inReductionSyms.
  build(builder, state, /*allocate_vars=*/{}, /*allocator_vars=*/{},
        clauses.bare, dependKinds, clauses.dependVars, clauses.device,
        clauses.hasDeviceAddrVars, clauses.hostEvalVars, clauses.ifExpr,
        /*in_reduction_vars=*/{}, /*in_reduction_byref=*/nullptr,
        /*in_reduction_syms=*/nullptr, clauses.isDevicePtrVars,
        clauses.mapVars, clauses.nowait, clauses.privateVars, privateSyms,
        clauses.privateNeedsBarrier, clauses.threadLimit,
        /*private_maps=*/nullptr);
}
|
|
|
|
/// Runs the depend, has_device_addr and map clause verifiers in that order,
/// stopping at the first failure, and finally checks the private variable
/// mapping.
LogicalResult TargetOp::verify() {
  if (failed(verifyDependVarList(*this, getDependKinds(), getDependVars())) ||
      failed(verifyMapInfoDefinedArgs(*this, "has_device_addr",
                                      getHasDeviceAddrVars())) ||
      failed(verifyMapClause(*this, getMapVars())))
    return failure();

  return verifyPrivateVarsMapping(*this);
}
|
|
|
|
/// Verifies the contents of the target region: at most one directly nested
/// `omp.teams`, and `host_eval` block arguments used only as specific clause
/// operands of `omp.teams`, `omp.parallel` and `omp.loop_nest`.
LogicalResult TargetOp::verifyRegions() {
  auto nestedTeamsOps = getOps<TeamsOp>();
  if (std::distance(nestedTeamsOps.begin(), nestedTeamsOps.end()) > 1)
    return emitError("target containing multiple 'omp.teams' nested ops");

  // Check that host_eval values are only used in legal ways.
  bool tripCountOnHost;
  Operation *innermostCaptured = getInnermostCapturedOmpOp();
  TargetExecMode mode = getKernelExecFlags(innermostCaptured, &tripCountOnHost);

  for (Value hostEvalArg :
       cast<BlockArgOpenMPOpInterface>(getOperation()).getHostEvalBlockArgs()) {
    for (Operation *user : hostEvalArg.getUsers()) {
      if (auto teamsOp = dyn_cast<TeamsOp>(user)) {
        const bool isLegalUse = llvm::is_contained({teamsOp.getNumTeamsLower(),
                                                    teamsOp.getNumTeamsUpper(),
                                                    teamsOp.getThreadLimit()},
                                                   hostEvalArg);
        if (!isLegalUse)
          return emitOpError() << "host_eval argument only legal as 'num_teams' "
                                  "and 'thread_limit' in 'omp.teams'";
        continue;
      }

      if (auto parallelOp = dyn_cast<ParallelOp>(user)) {
        const bool isLegalUse = mode == TargetExecMode::spmd &&
                                parallelOp->isAncestor(innermostCaptured) &&
                                hostEvalArg == parallelOp.getNumThreads();
        if (!isLegalUse)
          return emitOpError()
                 << "host_eval argument only legal as 'num_threads' in "
                    "'omp.parallel' when representing target SPMD";
        continue;
      }

      if (auto loopNestOp = dyn_cast<LoopNestOp>(user)) {
        const bool isLegalUse =
            tripCountOnHost &&
            loopNestOp.getOperation() == innermostCaptured &&
            (llvm::is_contained(loopNestOp.getLoopLowerBounds(), hostEvalArg) ||
             llvm::is_contained(loopNestOp.getLoopUpperBounds(), hostEvalArg) ||
             llvm::is_contained(loopNestOp.getLoopSteps(), hostEvalArg));
        if (!isLegalUse)
          return emitOpError() << "host_eval argument only legal as loop bounds "
                                  "and steps in 'omp.loop_nest' when trip count "
                                  "must be evaluated in the host";
        continue;
      }

      // Any other kind of user is not allowed.
      return emitOpError() << "host_eval argument illegal use in '"
                           << user->getName() << "' operation";
    }
  }
  return success();
}
|
|
|
|
/// Walks the operations nested in `rootOp` from outermost to innermost and
/// returns the last operation accepted as "captured", or null if none is.
///
/// Only operations of `rootOp`'s dialect that hold regions are candidates. A
/// candidate is rejected, and the search stops, if `siblingAllowedFn` returns
/// false for any other operation in the candidate's parent region. When
/// `checkSingleMandatoryExec` is set, a candidate is also rejected if its
/// block can be re-entered through its successors or does not dominate every
/// reachable exit block of its parent region.
static Operation *
findCapturedOmpOp(Operation *rootOp, bool checkSingleMandatoryExec,
                  llvm::function_ref<bool(Operation *)> siblingAllowedFn) {
  assert(rootOp && "expected valid operation");

  Dialect *rootDialect = rootOp->getDialect();
  DominanceInfo dominance;
  Operation *innermostCaptured = nullptr;

  // Process in pre-order to check operations from outermost to innermost,
  // ensuring we only enter the region of an operation if it meets the criteria
  // for being captured. We stop the exploration of nested operations as soon as
  // we process a region holding no operations to be captured.
  rootOp->walk<WalkOrder::PreOrder>([&](Operation *candidate) {
    if (candidate == rootOp)
      return WalkResult::advance();

    // Ignore operations of other dialects or omp operations with no regions,
    // because these will only be checked if they are siblings of an omp
    // operation that can potentially be captured.
    if (candidate->getDialect() != rootDialect ||
        candidate->getNumRegions() == 0)
      return WalkResult::skip();

    // This operation cannot be captured if it can be executed more than once
    // (i.e. its block's successors can reach it) or if it's not guaranteed to
    // be executed before all exits of the region (i.e. it doesn't dominate all
    // blocks with no successors reachable from the entry block).
    if (checkSingleMandatoryExec) {
      Block *candidateBlock = candidate->getBlock();

      const bool mayRunRepeatedly =
          llvm::any_of(candidateBlock->getSuccessors(), [&](Block *successor) {
            return successor->isReachable(candidateBlock);
          });
      if (mayRunRepeatedly)
        return WalkResult::interrupt();

      const bool mayBeBypassed =
          llvm::any_of(*candidate->getParentRegion(), [&](Block &block) {
            return dominance.isReachableFromEntry(&block) &&
                   block.hasNoSuccessors() &&
                   !dominance.dominates(candidateBlock, &block);
          });
      if (mayBeBypassed)
        return WalkResult::interrupt();
    }

    // Don't capture this op if it has a not-allowed sibling, and stop recursing
    // into nested operations.
    const bool hasDisallowedSibling = llvm::any_of(
        candidate->getParentRegion()->getOps(), [&](Operation &sibling) {
          return &sibling != candidate && !siblingAllowedFn(&sibling);
        });
    if (hasDisallowedSibling)
      return WalkResult::interrupt();

    // Don't continue capturing nested operations if we reach an omp.loop_nest.
    // Otherwise, process the contents of this operation.
    innermostCaptured = candidate;
    if (llvm::isa<LoopNestOp>(candidate))
      return WalkResult::interrupt();
    return WalkResult::advance();
  });

  return innermostCaptured;
}
|
|
|
|
/// Returns the innermost operation nested in this target region that
/// `findCapturedOmpOp` accepts, using the single-mandatory-execution check and
/// the sibling filter defined below.
Operation *TargetOp::getInnermostCapturedOmpOp() {
  auto *ompDialect = getContext()->getLoadedDialect<omp::OpenMPDialect>();

  // A sibling is allowed if it is an OpenMP terminator, or a non-OpenMP op
  // that either does not implement the memory effect interface or reports no
  // write effect on an automatic allocation scope resource.
  auto isSiblingAllowed = [ompDialect](Operation *sibling) -> bool {
    if (!sibling)
      return false;

    if (sibling->getDialect() == ompDialect)
      return sibling->hasTrait<OpTrait::IsTerminator>();

    auto memOp = dyn_cast<MemoryEffectOpInterface>(sibling);
    if (!memOp)
      return true;

    SmallVector<SideEffects::EffectInstance<MemoryEffects::Effect>, 4> effects;
    memOp.getEffects(effects);
    return llvm::none_of(effects, [](MemoryEffects::EffectInstance &effect) {
      return isa<MemoryEffects::Write>(effect.getEffect()) &&
             isa<SideEffects::AutomaticAllocationScopeResource>(
                 effect.getResource());
    });
  };

  return findCapturedOmpOp(*this, /*checkSingleMandatoryExec=*/true,
                           isSiblingAllowed);
}
|
|
|
|
/// Determines the execution mode of the target region that `capturedOp`
/// belongs to, based on the loop wrappers surrounding a captured
/// `omp.loop_nest`.
///
/// If `hostEvalTripCount` is non-null it is always written: `false` by
/// default, `true` for the wrapper nests below that set it. Any pattern that
/// is not recognized yields `TargetExecMode::generic`.
TargetExecMode TargetOp::getKernelExecFlags(Operation *capturedOp,
                                            bool *hostEvalTripCount) {
  // TODO: Support detection of bare kernel mode.
  // A non-null captured op is only valid if it resides inside of a TargetOp
  // and is the result of calling getInnermostCapturedOmpOp() on it.
  TargetOp targetOp =
      capturedOp ? capturedOp->getParentOfType<TargetOp>() : nullptr;
  assert((!capturedOp ||
          (targetOp && targetOp.getInnermostCapturedOmpOp() == capturedOp)) &&
         "unexpected captured op");

  auto setHostEvalTripCount = [hostEvalTripCount](bool value) {
    if (hostEvalTripCount)
      *hostEvalTripCount = value;
  };
  setHostEvalTripCount(false);

  // If it's not capturing a loop, it's a default target region.
  if (!isa_and_present<LoopNestOp>(capturedOp))
    return TargetExecMode::generic;

  // Get the innermost non-simd loop wrapper.
  SmallVector<LoopWrapperInterface> loopWrappers;
  cast<LoopNestOp>(capturedOp).gatherWrappers(loopWrappers);
  assert(!loopWrappers.empty());

  LoopWrapperInterface *wrapperIt = loopWrappers.begin();
  if (isa<SimdOp>(wrapperIt))
    wrapperIt = std::next(wrapperIt);

  const auto numWrappers = std::distance(wrapperIt, loopWrappers.end());
  if (numWrappers < 1 || numWrappers > 2)
    return TargetExecMode::generic;

  // Detect target-teams-distribute-parallel-wsloop[-simd].
  if (numWrappers == 2) {
    if (!isa<WsloopOp>(wrapperIt))
      return TargetExecMode::generic;

    wrapperIt = std::next(wrapperIt);
    if (!isa<DistributeOp>(wrapperIt))
      return TargetExecMode::generic;

    Operation *parallelOp = (*wrapperIt)->getParentOp();
    if (!isa_and_present<ParallelOp>(parallelOp))
      return TargetExecMode::generic;

    Operation *teamsOp = parallelOp->getParentOp();
    if (!isa_and_present<TeamsOp>(teamsOp))
      return TargetExecMode::generic;

    if (teamsOp->getParentOp() != targetOp.getOperation())
      return TargetExecMode::generic;

    setHostEvalTripCount(true);
    return TargetExecMode::spmd;
  }

  // Detect target-teams-distribute[-simd] and target-teams-loop.
  if (isa<DistributeOp, LoopOp>(wrapperIt)) {
    Operation *teamsOp = (*wrapperIt)->getParentOp();
    if (!isa_and_present<TeamsOp>(teamsOp))
      return TargetExecMode::generic;

    if (teamsOp->getParentOp() != targetOp.getOperation())
      return TargetExecMode::generic;

    setHostEvalTripCount(true);
    return isa<LoopOp>(wrapperIt) ? TargetExecMode::spmd
                                  : TargetExecMode::generic;
  }

  // Detect target-parallel-wsloop[-simd].
  if (isa<WsloopOp>(wrapperIt)) {
    Operation *parallelOp = (*wrapperIt)->getParentOp();
    if (isa_and_present<ParallelOp>(parallelOp) &&
        parallelOp->getParentOp() == targetOp.getOperation())
      return TargetExecMode::spmd;
  }

  return TargetExecMode::generic;
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// ParallelOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Builds an operation with no clause operands and then attaches the given
/// attributes to the operation state.
void ParallelOp::build(OpBuilder &builder, OperationState &state,
                       ArrayRef<NamedAttribute> attributes) {
  build(builder, state,
        /*allocate_vars=*/ValueRange(),
        /*allocator_vars=*/ValueRange(),
        /*if_expr=*/nullptr,
        /*num_threads=*/nullptr,
        /*private_vars=*/ValueRange(),
        /*private_syms=*/nullptr,
        /*private_needs_barrier=*/nullptr,
        /*proc_bind_kind=*/nullptr,
        /*reduction_mod =*/nullptr,
        /*reduction_vars=*/ValueRange(),
        /*reduction_byref=*/nullptr,
        /*reduction_syms=*/nullptr);
  state.addAttributes(attributes);
}
|
|
|
|
/// Builds the operation from a clause operand structure, converting the
/// symbol and by-ref lists to attributes before forwarding them.
void ParallelOp::build(OpBuilder &builder, OperationState &state,
                       const ParallelOperands &clauses) {
  MLIRContext *ctx = builder.getContext();
  auto privateSyms = makeArrayAttr(ctx, clauses.privateSyms);
  auto reductionByref = makeDenseBoolArrayAttr(ctx, clauses.reductionByref);
  auto reductionSyms = makeArrayAttr(ctx, clauses.reductionSyms);

  build(builder, state, clauses.allocateVars, clauses.allocatorVars,
        clauses.ifExpr, clauses.numThreads, clauses.privateVars, privateSyms,
        clauses.privateNeedsBarrier, clauses.procBindKind,
        clauses.reductionMod, clauses.reductionVars, reductionByref,
        reductionSyms);
}
|
|
|
|
template <typename OpType>
|
|
static LogicalResult verifyPrivateVarList(OpType &op) {
|
|
auto privateVars = op.getPrivateVars();
|
|
auto privateSyms = op.getPrivateSymsAttr();
|
|
|
|
if (privateVars.empty() && (privateSyms == nullptr || privateSyms.empty()))
|
|
return success();
|
|
|
|
auto numPrivateVars = privateVars.size();
|
|
auto numPrivateSyms = (privateSyms == nullptr) ? 0 : privateSyms.size();
|
|
|
|
if (numPrivateVars != numPrivateSyms)
|
|
return op.emitError() << "inconsistent number of private variables and "
|
|
"privatizer op symbols, private vars: "
|
|
<< numPrivateVars
|
|
<< " vs. privatizer op symbols: " << numPrivateSyms;
|
|
|
|
for (auto privateVarInfo : llvm::zip_equal(privateVars, privateSyms)) {
|
|
Type varType = std::get<0>(privateVarInfo).getType();
|
|
SymbolRefAttr privateSym = cast<SymbolRefAttr>(std::get<1>(privateVarInfo));
|
|
PrivateClauseOp privatizerOp =
|
|
SymbolTable::lookupNearestSymbolFrom<PrivateClauseOp>(op, privateSym);
|
|
|
|
if (privatizerOp == nullptr)
|
|
return op.emitError() << "failed to lookup privatizer op with symbol: '"
|
|
<< privateSym << "'";
|
|
|
|
Type privatizerType = privatizerOp.getArgType();
|
|
|
|
if (privatizerType && (varType != privatizerType))
|
|
return op.emitError()
|
|
<< "type mismatch between a "
|
|
<< (privatizerOp.getDataSharingType() ==
|
|
DataSharingClauseType::Private
|
|
? "private"
|
|
: "firstprivate")
|
|
<< " variable and its privatizer op, var type: " << varType
|
|
<< " vs. privatizer op type: " << privatizerType;
|
|
}
|
|
|
|
return success();
|
|
}
|
|
|
|
/// Verifies the allocate, private and reduction clauses, in that order.
LogicalResult ParallelOp::verify() {
  const auto numAllocateVars = getAllocateVars().size();
  const auto numAllocatorVars = getAllocatorVars().size();
  if (numAllocateVars != numAllocatorVars)
    return emitError(
        "expected equal sizes for allocate and allocator variables");

  if (failed(verifyPrivateVarList(*this)))
    return failure();

  return verifyReductionVarList(*this, getReductionSyms(), getReductionVars(),
                                getReductionByref());
}
|
|
|
|
/// Verifies the nesting of `omp.distribute` inside this operation: at most
/// one is allowed, its presence must coincide with the composite marker, and
/// in the composite case no other non-terminator OpenMP op may be a direct
/// child.
LogicalResult ParallelOp::verifyRegions() {
  auto nestedDistributeOps = getOps<DistributeOp>();
  const int numNestedDistribute =
      std::distance(nestedDistributeOps.begin(), nestedDistributeOps.end());
  if (numNestedDistribute > 1)
    return emitError()
           << "multiple 'omp.distribute' nested inside of 'omp.parallel'";

  // No nested distribute: the operation must not be marked composite.
  if (numNestedDistribute != 1) {
    if (isComposite())
      return emitError()
             << "'omp.composite' attribute present in non-composite operation";
    return success();
  }

  // Exactly one nested distribute: the operation must be marked composite.
  if (!isComposite())
    return emitError()
           << "'omp.composite' attribute missing from composite operation";

  auto *ompDialect = getContext()->getLoadedDialect<OpenMPDialect>();
  Operation *distributeOp = *nestedDistributeOps.begin();
  for (Operation &nestedOp : getOps()) {
    const bool isOtherOmpOp =
        &nestedOp != distributeOp && nestedOp.getDialect() == ompDialect;
    if (isOtherOmpOp && !nestedOp.hasTrait<OpTrait::IsTerminator>())
      return emitError() << "unexpected OpenMP operation inside of composite "
                            "'omp.parallel': "
                         << nestedOp.getName();
  }
  return success();
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// TeamsOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Returns true if none of the ancestors of `op` belongs to the OpenMP
/// dialect.
///
/// `Operation::getDialect()` may return null (e.g. for an operation whose
/// dialect is not loaded), so the null-tolerant `isa_and_present` is used;
/// such ancestors are treated as not being OpenMP operations.
static bool opInGlobalImplicitParallelRegion(Operation *op) {
  while ((op = op->getParentOp()))
    if (isa_and_present<OpenMPDialect>(op->getDialect()))
      return false;
  return true;
}
|
|
|
|
/// Builds the operation from a clause operand structure. Private clause
/// operands are not yet stored and are passed as empty / null placeholders.
void TeamsOp::build(OpBuilder &builder, OperationState &state,
                    const TeamsOperands &clauses) {
  MLIRContext *ctx = builder.getContext();
  auto reductionByref = makeDenseBoolArrayAttr(ctx, clauses.reductionByref);
  auto reductionSyms = makeArrayAttr(ctx, clauses.reductionSyms);

  // TODO Store clauses in op: privateVars, privateSyms, privateNeedsBarrier
  build(builder, state, clauses.allocateVars, clauses.allocatorVars,
        clauses.ifExpr, clauses.numTeamsLower, clauses.numTeamsUpper,
        /*private_vars=*/{}, /*private_syms=*/nullptr,
        /*private_needs_barrier=*/nullptr, clauses.reductionMod,
        clauses.reductionVars, reductionByref, reductionSyms,
        clauses.threadLimit);
}
|
|
|
|
/// Verifies the placement of the operation and its num_teams, allocate and
/// reduction clauses.
///
/// The operation must either be directly nested in an `omp.target` or have no
/// OpenMP dialect ancestor. An operation without a parent is accepted: the
/// parent check uses the null-tolerant `isa_and_present` so that verifying a
/// detached operation does not trip the `isa<>` null assertion.
LogicalResult TeamsOp::verify() {
  // Check parent region
  // TODO If nested inside of a target region, also check that it does not
  // contain any statements, declarations or directives other than this
  // omp.teams construct. The issue is how to support the initialization of
  // this operation's own arguments (allow SSA values across omp.target?).
  Operation *op = getOperation();
  if (!isa_and_present<TargetOp>(op->getParentOp()) &&
      !opInGlobalImplicitParallelRegion(op))
    return emitError("expected to be nested inside of omp.target or not nested "
                     "in any OpenMP dialect operations");

  // Check for num_teams clause restrictions
  if (auto numTeamsLowerBound = getNumTeamsLower()) {
    auto numTeamsUpperBound = getNumTeamsUpper();
    if (!numTeamsUpperBound)
      return emitError("expected num_teams upper bound to be defined if the "
                       "lower bound is defined");
    if (numTeamsLowerBound.getType() != numTeamsUpperBound.getType())
      return emitError(
          "expected num_teams upper bound and lower bound to be the same type");
  }

  // Check for allocate clause restrictions
  if (getAllocateVars().size() != getAllocatorVars().size())
    return emitError(
        "expected equal sizes for allocate and allocator variables");

  return verifyReductionVarList(*this, getReductionSyms(), getReductionVars(),
                                getReductionByref());
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// SectionOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Returns the private variables of the parent operation.
OperandRange SectionOp::getPrivateVars() {
  auto parentOp = getParentOp();
  return parentOp.getPrivateVars();
}
|
|
|
|
/// Returns the reduction variables of the parent operation.
OperandRange SectionOp::getReductionVars() {
  auto parentOp = getParentOp();
  return parentOp.getReductionVars();
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// SectionsOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Builds the operation from a clause operand structure. Private clause
/// operands are not yet stored and are passed as empty / null placeholders.
void SectionsOp::build(OpBuilder &builder, OperationState &state,
                       const SectionsOperands &clauses) {
  MLIRContext *ctx = builder.getContext();
  auto reductionByref = makeDenseBoolArrayAttr(ctx, clauses.reductionByref);
  auto reductionSyms = makeArrayAttr(ctx, clauses.reductionSyms);

  // TODO Store clauses in op: privateVars, privateSyms, privateNeedsBarrier
  build(builder, state, clauses.allocateVars, clauses.allocatorVars,
        clauses.nowait, /*private_vars=*/{},
        /*private_syms=*/nullptr, /*private_needs_barrier=*/nullptr,
        clauses.reductionMod, clauses.reductionVars, reductionByref,
        reductionSyms);
}
|
|
|
|
/// Verifies the allocate clause sizes and then the reduction clause.
LogicalResult SectionsOp::verify() {
  const auto numAllocateVars = getAllocateVars().size();
  const auto numAllocatorVars = getAllocatorVars().size();
  if (numAllocateVars != numAllocatorVars)
    return emitError(
        "expected equal sizes for allocate and allocator variables");

  return verifyReductionVarList(*this, getReductionSyms(), getReductionVars(),
                                getReductionByref());
}
|
|
|
|
/// Verifies that the first block of the region only contains `omp.section`
/// operations and the terminator.
LogicalResult SectionsOp::verifyRegions() {
  auto isSectionOrTerminator = [](Operation &nestedOp) {
    return isa<SectionOp, TerminatorOp>(nestedOp);
  };

  if (!llvm::all_of(*getRegion().begin(), isSectionOrTerminator))
    return emitOpError()
           << "expected omp.section op or terminator op inside region";

  return success();
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// SingleOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Builds the operation from a clause operand structure. Private clause
/// operands are not yet stored and are passed as empty / null placeholders.
void SingleOp::build(OpBuilder &builder, OperationState &state,
                     const SingleOperands &clauses) {
  ArrayAttr copyprivateSyms =
      makeArrayAttr(builder.getContext(), clauses.copyprivateSyms);

  // TODO Store clauses in op: privateVars, privateSyms, privateNeedsBarrier
  build(builder, state, clauses.allocateVars, clauses.allocatorVars,
        clauses.copyprivateVars, copyprivateSyms, clauses.nowait,
        /*private_vars=*/{}, /*private_syms=*/nullptr,
        /*private_needs_barrier=*/nullptr);
}
|
|
|
|
/// Verifies the allocate clause sizes and then the copyprivate clause.
LogicalResult SingleOp::verify() {
  // Check for allocate clause restrictions
  const auto numAllocateVars = getAllocateVars().size();
  const auto numAllocatorVars = getAllocatorVars().size();
  if (numAllocateVars != numAllocatorVars)
    return emitError(
        "expected equal sizes for allocate and allocator variables");

  return verifyCopyprivateVarList(*this, getCopyprivateVars(),
                                  getCopyprivateSyms());
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// WorkshareOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Builds the operation from a clause operand structure; only the `nowait`
/// field is forwarded.
void WorkshareOp::build(OpBuilder &builder, OperationState &state,
                        const WorkshareOperands &clauses) {
  build(builder, state, clauses.nowait);
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// WorkshareLoopWrapperOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Verifies that the operation has a `WorkshareOp` ancestor.
LogicalResult WorkshareLoopWrapperOp::verify() {
  if ((*this)->getParentOfType<WorkshareOp>())
    return success();

  return emitOpError() << "must be nested in an omp.workshare";
}
|
|
|
|
/// Verifies that this loop wrapper is standalone: its parent is not a loop
/// wrapper and it does not itself nest another wrapper.
LogicalResult WorkshareLoopWrapperOp::verifyRegions() {
  const bool hasWrapperParent =
      isa_and_present<LoopWrapperInterface>((*this)->getParentOp());
  if (hasWrapperParent || getNestedWrapper())
    return emitOpError() << "expected to be a standalone loop wrapper";

  return success();
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// LoopWrapperInterface
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Verifies the structural requirements of a loop wrapper operation: it must
/// carry the `NoTerminator` and `SingleBlock` traits, own exactly one region,
/// and that region must hold exactly one operation which is either another
/// loop wrapper or an `omp.loop_nest`.
LogicalResult LoopWrapperInterface::verifyImpl() {
  Operation *op = this->getOperation();
  if (!op->hasTrait<OpTrait::NoTerminator>() ||
      !op->hasTrait<OpTrait::SingleBlock>())
    return emitOpError() << "loop wrapper must also have the `NoTerminator` "
                            "and `SingleBlock` traits";

  if (op->getNumRegions() != 1)
    return emitOpError() << "loop wrapper does not contain exactly one region";

  // Bind the single region by reference (the declaration had been garbled
  // into a non-identifier character sequence).
  Region &region = op->getRegion(0);
  if (range_size(region.getOps()) != 1)
    return emitOpError()
           << "loop wrapper does not contain exactly one nested op";

  Operation &firstOp = *region.op_begin();
  if (!isa<LoopNestOp, LoopWrapperInterface>(firstOp))
    return emitOpError() << "nested in loop wrapper is not another loop "
                            "wrapper or `omp.loop_nest`";

  return success();
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// LoopOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Builds the operation from a clause operand structure, converting the
/// symbol and by-ref lists to attributes before forwarding them.
void LoopOp::build(OpBuilder &builder, OperationState &state,
                   const LoopOperands &clauses) {
  MLIRContext *ctx = builder.getContext();
  auto privateSyms = makeArrayAttr(ctx, clauses.privateSyms);
  auto reductionByref = makeDenseBoolArrayAttr(ctx, clauses.reductionByref);
  auto reductionSyms = makeArrayAttr(ctx, clauses.reductionSyms);

  build(builder, state, clauses.bindKind, clauses.privateVars, privateSyms,
        clauses.privateNeedsBarrier, clauses.order, clauses.orderMod,
        clauses.reductionMod, clauses.reductionVars, reductionByref,
        reductionSyms);
}
|
|
|
|
/// Verifies the reduction clause of the operation.
LogicalResult LoopOp::verify() {
  LogicalResult reductionResult = verifyReductionVarList(
      *this, getReductionSyms(), getReductionVars(), getReductionByref());
  return reductionResult;
}
|
|
|
|
/// Verifies that this loop wrapper is standalone: its parent is not a loop
/// wrapper and it does not itself nest another wrapper.
LogicalResult LoopOp::verifyRegions() {
  const bool hasWrapperParent =
      llvm::isa_and_present<LoopWrapperInterface>((*this)->getParentOp());
  if (hasWrapperParent || getNestedWrapper())
    return emitOpError() << "expected to be a standalone loop wrapper";

  return success();
}
|
|
|
|
//===----------------------------------------------------------------------===//
// WsloopOp
//===----------------------------------------------------------------------===//
|
|
|
|
/// Builds an `omp.wsloop` with every clause left empty/unset and then attaches
/// the caller-provided \p attributes to the operation state.
void WsloopOp::build(OpBuilder &builder, OperationState &state,
                     ArrayRef<NamedAttribute> attributes) {
  build(builder, state,
        /*allocate_vars=*/{},
        /*allocator_vars=*/{},
        /*linear_vars=*/ValueRange(),
        /*linear_step_vars=*/ValueRange(),
        /*nowait=*/false,
        /*order=*/nullptr,
        /*order_mod=*/nullptr,
        /*ordered=*/nullptr,
        /*private_vars=*/{},
        /*private_syms=*/nullptr,
        /*private_needs_barrier=*/false,
        /*reduction_mod=*/nullptr,
        /*reduction_vars=*/ValueRange(),
        /*reduction_byref=*/nullptr,
        /*reduction_syms=*/nullptr,
        /*schedule_kind=*/nullptr,
        /*schedule_chunk=*/nullptr,
        /*schedule_mod=*/nullptr,
        /*schedule_simd=*/false);

  // Extra attributes are added after the clause defaults.
  state.addAttributes(attributes);
}
|
|
|
|
/// Builds an `omp.wsloop` from the clause operands collected in \p clauses.
/// The allocate/allocator variables in \p clauses are currently not forwarded.
void WsloopOp::build(OpBuilder &builder, OperationState &state,
                     const WsloopOperands &clauses) {
  MLIRContext *context = builder.getContext();

  // Symbol lists and by-ref flags are passed on as attributes.
  ArrayAttr privateSyms = makeArrayAttr(context, clauses.privateSyms);
  auto reductionByref = makeDenseBoolArrayAttr(context, clauses.reductionByref);
  ArrayAttr reductionSyms = makeArrayAttr(context, clauses.reductionSyms);

  // TODO: Store clauses in op: allocateVars, allocatorVars
  WsloopOp::build(builder, state,
                  /*allocate_vars=*/{}, /*allocator_vars=*/{},
                  clauses.linearVars, clauses.linearStepVars, clauses.nowait,
                  clauses.order, clauses.orderMod, clauses.ordered,
                  clauses.privateVars, privateSyms,
                  clauses.privateNeedsBarrier, clauses.reductionMod,
                  clauses.reductionVars, reductionByref, reductionSyms,
                  clauses.scheduleKind, clauses.scheduleChunk,
                  clauses.scheduleMod, clauses.scheduleSimd);
}
|
|
|
|
/// Delegates verification of the reduction clause operands (symbols,
/// variables and by-ref flags) to `verifyReductionVarList`.
LogicalResult WsloopOp::verify() {
  LogicalResult reductionStatus = verifyReductionVarList(
      *this, getReductionSyms(), getReductionVars(), getReductionByref());
  return reductionStatus;
}
|
|
|
|
/// Checks that the `omp.composite` attribute agrees with how this
/// `omp.wsloop` is combined with other loop wrappers, and that the only
/// wrapper nested directly inside it is `omp.simd`.
LogicalResult WsloopOp::verifyRegions() {
  Operation *parentOp = (*this)->getParentOp();
  const bool hasWrapperParent =
      llvm::isa_and_present<LoopWrapperInterface>(parentOp);
  const bool composite = isComposite();

  LoopWrapperInterface nested = getNestedWrapper();
  if (!nested) {
    // Leaf wrapper: it is composite exactly when its parent is a wrapper.
    if (composite && !hasWrapperParent)
      return emitError()
             << "'omp.composite' attribute present in non-composite wrapper";
    if (!composite && hasWrapperParent)
      return emitError()
             << "'omp.composite' attribute missing from composite wrapper";
    return success();
  }

  // A wrapper containing another wrapper is always part of a composite.
  if (!composite)
    return emitError()
           << "'omp.composite' attribute missing from composite wrapper";

  // Check for the allowed leaf constructs that may appear in a composite
  // construct directly after DO/FOR.
  if (!isa<SimdOp>(nested))
    return emitError() << "only supported nested wrapper is 'omp.simd'";

  return success();
}
|
|
|
|
//===----------------------------------------------------------------------===//
// Simd construct [2.9.3.1]
//===----------------------------------------------------------------------===//
|
|
|
|
/// Builds an `omp.simd` from the clause operands collected in \p clauses.
/// Linear variables and their steps are currently not forwarded.
void SimdOp::build(OpBuilder &builder, OperationState &state,
                   const SimdOperands &clauses) {
  MLIRContext *context = builder.getContext();

  // Alignment values, symbol lists and by-ref flags are passed as attributes.
  ArrayAttr alignments = makeArrayAttr(context, clauses.alignments);
  ArrayAttr privateSyms = makeArrayAttr(context, clauses.privateSyms);
  auto reductionByref = makeDenseBoolArrayAttr(context, clauses.reductionByref);
  ArrayAttr reductionSyms = makeArrayAttr(context, clauses.reductionSyms);

  // TODO Store clauses in op: linearVars, linearStepVars
  SimdOp::build(builder, state, clauses.alignedVars, alignments,
                clauses.ifExpr,
                /*linear_vars=*/{}, /*linear_step_vars=*/{},
                clauses.nontemporalVars, clauses.order, clauses.orderMod,
                clauses.privateVars, privateSyms, clauses.privateNeedsBarrier,
                clauses.reductionMod, clauses.reductionVars, reductionByref,
                reductionSyms, clauses.safelen, clauses.simdlen);
}
|
|
|
|
/// Verifies the clauses of an `omp.simd`:
///  - `simdlen` must not exceed `safelen` when both are present,
///  - the aligned and nontemporal clauses must pass their own checks,
///  - the `omp.composite` attribute must be present exactly when the parent
///    operation is a loop wrapper,
///  - every privatizer must resolve and must not be a firstprivate one.
LogicalResult SimdOp::verify() {
  auto simdlen = getSimdlen();
  auto safelen = getSafelen();
  if (simdlen.has_value() && safelen.has_value() && *simdlen > *safelen)
    return emitOpError()
           << "simdlen clause and safelen clause are both present, but the "
              "simdlen value is not less than or equal to safelen value";

  if (failed(verifyAlignedClause(*this, getAlignments(), getAlignedVars())))
    return failure();

  if (failed(verifyNontemporalClause(*this, getNontemporalVars())))
    return failure();

  // `omp.simd` never wraps another wrapper, so it is composite only as the
  // leaf of another loop wrapper.
  const bool hasWrapperParent =
      llvm::isa_and_present<LoopWrapperInterface>((*this)->getParentOp());
  const bool composite = isComposite();

  if (hasWrapperParent && !composite)
    return emitError()
           << "'omp.composite' attribute missing from composite wrapper";

  if (composite && !hasWrapperParent)
    return emitError()
           << "'omp.composite' attribute present in non-composite wrapper";

  // Firstprivate is not allowed for SIMD in the standard. Check that none of
  // the private decls are for firstprivate.
  std::optional<ArrayAttr> privateSyms = getPrivateSyms();
  if (!privateSyms)
    return success();

  for (const Attribute &symAttr : *privateSyms) {
    auto symRef = cast<SymbolRefAttr>(symAttr);
    auto privatizer =
        SymbolTable::lookupNearestSymbolFrom<omp::PrivateClauseOp>(
            getOperation(), symRef);
    if (!privatizer)
      return emitError() << "Cannot find privatizer '" << symRef << "'";

    const bool isFirstPrivate = privatizer.getDataSharingType() ==
                                DataSharingClauseType::FirstPrivate;
    if (isFirstPrivate)
      return emitError() << "FIRSTPRIVATE cannot be used with SIMD";
  }

  return success();
}
|
|
|
|
/// Rejects an `omp.simd` whose nested operation is another loop wrapper.
LogicalResult SimdOp::verifyRegions() {
  LoopWrapperInterface nested = getNestedWrapper();
  if (!nested)
    return success();

  return emitOpError() << "must wrap an 'omp.loop_nest' directly";
}
|
|
|
|
//===----------------------------------------------------------------------===//
// Distribute construct [2.9.4.1]
//===----------------------------------------------------------------------===//
|
|
|
|
/// Builds an `omp.distribute` from the clause operands collected in
/// \p clauses by forwarding them to the operand/attribute-based overload.
void DistributeOp::build(OpBuilder &builder, OperationState &state,
                         const DistributeOperands &clauses) {
  // The privatizer symbols are passed on as an array attribute.
  ArrayAttr privateSyms =
      makeArrayAttr(builder.getContext(), clauses.privateSyms);

  DistributeOp::build(builder, state, clauses.allocateVars,
                      clauses.allocatorVars, clauses.distScheduleStatic,
                      clauses.distScheduleChunkSize, clauses.order,
                      clauses.orderMod, clauses.privateVars, privateSyms,
                      clauses.privateNeedsBarrier);
}
|
|
|
|
/// Verifies that a chunk size is only given together with
/// `dist_schedule_static` and that the allocate and allocator variable lists
/// have the same length.
LogicalResult DistributeOp::verify() {
  const bool hasChunkSize = static_cast<bool>(getDistScheduleChunkSize());
  if (hasChunkSize && !getDistScheduleStatic())
    return emitOpError() << "chunk size set without "
                            "dist_schedule_static being present";

  const auto numAllocateVars = getAllocateVars().size();
  const auto numAllocatorVars = getAllocatorVars().size();
  if (numAllocateVars != numAllocatorVars)
    return emitError(
        "expected equal sizes for allocate and allocator variables");

  return success();
}
|
|
|
|
/// Checks that the `omp.composite` attribute is present exactly when this
/// `omp.distribute` wraps another loop wrapper, and that the nested wrapper is
/// `omp.simd`, or `omp.wsloop` under a composite `omp.parallel` parent.
LogicalResult DistributeOp::verifyRegions() {
  LoopWrapperInterface nested = getNestedWrapper();

  // Without a nested wrapper the operation must not be marked composite.
  if (!nested) {
    if (isComposite())
      return emitError()
             << "'omp.composite' attribute present in non-composite wrapper";
    return success();
  }

  if (!isComposite())
    return emitError()
           << "'omp.composite' attribute missing from composite wrapper";

  // Check for the allowed leaf constructs that may appear in a composite
  // construct directly after DISTRIBUTE.
  if (isa<WsloopOp>(nested)) {
    Operation *parentOp = (*this)->getParentOp();
    const bool underCompositeParallel =
        llvm::dyn_cast_if_present<ParallelOp>(parentOp) &&
        cast<ComposableOpInterface>(parentOp).isComposite();
    if (!underCompositeParallel)
      return emitError() << "an 'omp.wsloop' nested wrapper is only allowed "
                            "when a composite 'omp.parallel' is the direct "
                            "parent";
    return success();
  }

  if (!isa<SimdOp>(nested))
    return emitError() << "only supported nested wrappers are 'omp.simd' and "
                          "'omp.wsloop'";

  return success();
}
|
|
|
|
//===----------------------------------------------------------------------===//
// DeclareMapperOp / DeclareMapperInfoOp
//===----------------------------------------------------------------------===//
|
|
|
|
/// Delegates verification of the map variables to `verifyMapClause`.
LogicalResult DeclareMapperInfoOp::verify() {
  LogicalResult mapClauseStatus = verifyMapClause(*this, getMapVars());
  return mapClauseStatus;
}
|
|
|
|
/// Checks that the first block of the region is terminated by a
/// `DeclareMapperInfoOp`.
LogicalResult DeclareMapperOp::verifyRegions() {
  Operation *terminator = getRegion().getBlocks().front().getTerminator();
  if (llvm::isa_and_present<DeclareMapperInfoOp>(terminator))
    return success();

  return emitOpError() << "expected terminator to be a DeclareMapperInfoOp";
}
|
|
|
|
//===----------------------------------------------------------------------===//
// DeclareReductionOp
//===----------------------------------------------------------------------===//
|
|
|
|
/// Verifies the regions of a reduction declaration against the reduction
/// type returned by `getType()`:
///  - alloc region (optional): every yield returns one value of that type,
///  - initializer region (required): one argument without an alloc region,
///    two with one; all arguments and every yield use the reduction type,
///  - reduction region (required): two arguments of the reduction type and
///    every yield returns one value of that type,
///  - atomic reduction region (optional): two arguments of the same
///    pointer-like type whose element type, if known, is the reduction type,
///  - cleanup region (optional): one argument of the reduction type.
LogicalResult DeclareReductionOp::verifyRegions() {
  // True if the yield returns exactly one value of the reduction type.
  auto yieldsReductionValue = [&](YieldOp yieldOp) {
    return yieldOp.getResults().size() == 1 &&
           yieldOp.getResults().getTypes()[0] == getType();
  };

  const bool hasAllocRegion = !getAllocRegion().empty();
  if (hasAllocRegion) {
    for (YieldOp yieldOp : getAllocRegion().getOps<YieldOp>()) {
      if (!yieldsReductionValue(yieldOp))
        return emitOpError() << "expects alloc region to yield a value "
                                "of the reduction type";
    }
  }

  if (getInitializerRegion().empty())
    return emitOpError() << "expects non-empty initializer region";
  Block &initializerEntryBlock = getInitializerRegion().front();

  // The number of initializer arguments depends on the alloc region.
  switch (initializerEntryBlock.getNumArguments()) {
  case 1:
    if (hasAllocRegion)
      return emitOpError() << "expects two arguments to the initializer region "
                              "when an allocation region is used";
    break;
  case 2:
    if (!hasAllocRegion)
      return emitOpError() << "expects one argument to the initializer region "
                              "when no allocation region is used";
    break;
  default:
    return emitOpError()
           << "expects one or two arguments to the initializer region";
  }

  for (mlir::Value initArg : initializerEntryBlock.getArguments()) {
    if (initArg.getType() != getType())
      return emitOpError() << "expects initializer region argument to match "
                              "the reduction type";
  }

  for (YieldOp yieldOp : getInitializerRegion().getOps<YieldOp>()) {
    if (!yieldsReductionValue(yieldOp))
      return emitOpError() << "expects initializer region to yield a value "
                              "of the reduction type";
  }

  if (getReductionRegion().empty())
    return emitOpError() << "expects non-empty reduction region";
  Block &reductionEntryBlock = getReductionRegion().front();
  const bool reductionArgsOk =
      reductionEntryBlock.getNumArguments() == 2 &&
      reductionEntryBlock.getArgumentTypes()[0] ==
          reductionEntryBlock.getArgumentTypes()[1] &&
      reductionEntryBlock.getArgumentTypes()[0] == getType();
  if (!reductionArgsOk)
    return emitOpError() << "expects reduction region with two arguments of "
                            "the reduction type";

  for (YieldOp yieldOp : getReductionRegion().getOps<YieldOp>()) {
    if (!yieldsReductionValue(yieldOp))
      return emitOpError() << "expects reduction region to yield a value "
                              "of the reduction type";
  }

  if (!getAtomicReductionRegion().empty()) {
    Block &atomicReductionEntryBlock = getAtomicReductionRegion().front();
    const bool atomicArgsMatch =
        atomicReductionEntryBlock.getNumArguments() == 2 &&
        atomicReductionEntryBlock.getArgumentTypes()[0] ==
            atomicReductionEntryBlock.getArgumentTypes()[1];
    if (!atomicArgsMatch)
      return emitOpError() << "expects atomic reduction region with two "
                              "arguments of the same type";

    // An absent element type is accepted; a present one must match.
    auto ptrType = llvm::dyn_cast<PointerLikeType>(
        atomicReductionEntryBlock.getArgumentTypes()[0]);
    const bool elementTypeMismatch =
        ptrType && ptrType.getElementType() &&
        ptrType.getElementType() != getType();
    if (!ptrType || elementTypeMismatch)
      return emitOpError() << "expects atomic reduction region arguments to "
                              "be accumulators containing the reduction type";
  }

  if (getCleanupRegion().empty())
    return success();

  Block &cleanupEntryBlock = getCleanupRegion().front();
  const bool cleanupArgOk =
      cleanupEntryBlock.getNumArguments() == 1 &&
      cleanupEntryBlock.getArgument(0).getType() == getType();
  if (!cleanupArgOk)
    return emitOpError() << "expects cleanup region with one argument "
                            "of the reduction type";

  return success();
}
|
|
|
|
//===----------------------------------------------------------------------===//
// TaskOp
//===----------------------------------------------------------------------===//
|
|
|
|
/// Builds an `omp.task` from the clause operands collected in \p clauses by
/// forwarding them to the operand/attribute-based `build` overload.
void TaskOp::build(OpBuilder &builder, OperationState &state,
                   const TaskOperands &clauses) {
  MLIRContext *context = builder.getContext();

  // Dependence kinds, symbol lists and by-ref flags are passed as attributes.
  ArrayAttr dependKinds = makeArrayAttr(context, clauses.dependKinds);
  auto inReductionByref =
      makeDenseBoolArrayAttr(context, clauses.inReductionByref);
  ArrayAttr inReductionSyms = makeArrayAttr(context, clauses.inReductionSyms);
  ArrayAttr privateSyms = makeArrayAttr(context, clauses.privateSyms);

  TaskOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
                dependKinds, clauses.dependVars, clauses.final, clauses.ifExpr,
                clauses.inReductionVars, inReductionByref, inReductionSyms,
                clauses.mergeable, clauses.priority,
                /*private_vars=*/clauses.privateVars,
                /*private_syms=*/privateSyms, clauses.privateNeedsBarrier,
                clauses.untied, clauses.eventHandle);
}
|
|
|
|
/// Verifies the depend clause first and, if that succeeds, the in_reduction
/// clause operands.
LogicalResult TaskOp::verify() {
  if (failed(verifyDependVarList(*this, getDependKinds(), getDependVars())))
    return failure();

  return verifyReductionVarList(*this, getInReductionSyms(),
                                getInReductionVars(), getInReductionByref());
}
|
|
|
|
//===----------------------------------------------------------------------===//
// TaskgroupOp
//===----------------------------------------------------------------------===//
|
|
|
|
/// Builds an `omp.taskgroup` from the clause operands collected in
/// \p clauses by forwarding them to the operand/attribute-based overload.
void TaskgroupOp::build(OpBuilder &builder, OperationState &state,
                        const TaskgroupOperands &clauses) {
  MLIRContext *context = builder.getContext();

  // By-ref flags and reduction symbols are passed on as attributes.
  auto taskReductionByref =
      makeDenseBoolArrayAttr(context, clauses.taskReductionByref);
  ArrayAttr taskReductionSyms =
      makeArrayAttr(context, clauses.taskReductionSyms);

  TaskgroupOp::build(builder, state, clauses.allocateVars,
                     clauses.allocatorVars, clauses.taskReductionVars,
                     taskReductionByref, taskReductionSyms);
}
|
|
|
|
/// Delegates verification of the task_reduction clause operands to
/// `verifyReductionVarList`.
LogicalResult TaskgroupOp::verify() {
  LogicalResult reductionStatus =
      verifyReductionVarList(*this, getTaskReductionSyms(),
                             getTaskReductionVars(), getTaskReductionByref());
  return reductionStatus;
}
|
|
|
|
//===----------------------------------------------------------------------===//
// TaskloopOp
//===----------------------------------------------------------------------===//
|
|
|
|
/// Builds an `omp.taskloop` from the clause operands collected in \p clauses
/// by forwarding them to the operand/attribute-based `build` overload.
void TaskloopOp::build(OpBuilder &builder, OperationState &state,
                       const TaskloopOperands &clauses) {
  MLIRContext *context = builder.getContext();

  // Symbol lists and by-ref flags are passed on as attributes.
  auto inReductionByref =
      makeDenseBoolArrayAttr(context, clauses.inReductionByref);
  ArrayAttr inReductionSyms = makeArrayAttr(context, clauses.inReductionSyms);
  ArrayAttr privateSyms = makeArrayAttr(context, clauses.privateSyms);
  auto reductionByref = makeDenseBoolArrayAttr(context, clauses.reductionByref);
  ArrayAttr reductionSyms = makeArrayAttr(context, clauses.reductionSyms);

  TaskloopOp::build(builder, state, clauses.allocateVars,
                    clauses.allocatorVars, clauses.final, clauses.grainsizeMod,
                    clauses.grainsize, clauses.ifExpr, clauses.inReductionVars,
                    inReductionByref, inReductionSyms, clauses.mergeable,
                    clauses.nogroup, clauses.numTasksMod, clauses.numTasks,
                    clauses.priority,
                    /*private_vars=*/clauses.privateVars,
                    /*private_syms=*/privateSyms, clauses.privateNeedsBarrier,
                    clauses.reductionMod, clauses.reductionVars, reductionByref,
                    reductionSyms, clauses.untied);
}
|
|
|
|
/// Verifies the clauses of an `omp.taskloop`:
///  - allocate and allocator variable lists have equal length,
///  - the reduction and in_reduction operands pass `verifyReductionVarList`,
///  - `nogroup` is not combined with a reduction clause,
///  - no value is listed in both the reduction and in_reduction clauses,
///  - `grainsize` and `num_tasks` are not both present.
LogicalResult TaskloopOp::verify() {
  if (getAllocateVars().size() != getAllocatorVars().size())
    return emitError(
        "expected equal sizes for allocate and allocator variables");

  if (failed(verifyReductionVarList(*this, getReductionSyms(),
                                    getReductionVars(), getReductionByref())))
    return failure();
  if (failed(verifyReductionVarList(*this, getInReductionSyms(),
                                    getInReductionVars(),
                                    getInReductionByref())))
    return failure();

  if (!getReductionVars().empty() && getNogroup())
    return emitError("if a reduction clause is present on the taskloop "
                     "directive, the nogroup clause must not be specified");

  // A list item may appear in only one of the two reduction clauses.
  const bool sharedReductionItem =
      llvm::any_of(getReductionVars(), [&](Value reductionVar) {
        return llvm::is_contained(getInReductionVars(), reductionVar);
      });
  if (sharedReductionItem)
    return emitError("the same list item cannot appear in both a reduction "
                     "and an in_reduction clause");

  if (getGrainsize() && getNumTasks())
    return emitError(
        "the grainsize clause and num_tasks clause are mutually exclusive and "
        "may not appear on the same taskloop directive");

  return success();
}
|
|
|
|
/// Checks that the `omp.composite` attribute is present exactly when this
/// `omp.taskloop` wraps another loop wrapper, and that such a nested wrapper
/// is `omp.simd`.
LogicalResult TaskloopOp::verifyRegions() {
  LoopWrapperInterface nested = getNestedWrapper();

  // Without a nested wrapper the operation must not be marked composite.
  if (!nested) {
    if (isComposite())
      return emitError()
             << "'omp.composite' attribute present in non-composite wrapper";
    return success();
  }

  if (!isComposite())
    return emitError()
           << "'omp.composite' attribute missing from composite wrapper";

  // Check for the allowed leaf constructs that may appear in a composite
  // construct directly after TASKLOOP.
  if (!isa<SimdOp>(nested))
    return emitError() << "only supported nested wrapper is 'omp.simd'";

  return success();
}
|
|
|
|
//===----------------------------------------------------------------------===//
// LoopNestOp
//===----------------------------------------------------------------------===//
|
|
|
|
/// Parses the custom assembly form of `omp.loop_nest`:
///
///   (ivs) : type = (lbs) to (ubs) [inclusive] step (steps) region [attr-dict]
///
/// All induction variables and bounds share the single parsed type. Operands
/// are added to \p result in the order lower bounds, upper bounds, steps.
ParseResult LoopNestOp::parse(OpAsmParser &parser, OperationState &result) {
  // Induction variables, in parentheses, followed by their common type.
  SmallVector<OpAsmParser::Argument> ivs;
  Type loopVarType;
  if (parser.parseArgumentList(ivs, OpAsmParser::Delimiter::Paren))
    return failure();
  if (parser.parseColonType(loopVarType))
    return failure();

  // Loop bounds: one lower and one upper bound per induction variable.
  SmallVector<OpAsmParser::UnresolvedOperand> lbs, ubs;
  if (parser.parseEqual())
    return failure();
  if (parser.parseOperandList(lbs, ivs.size(), OpAsmParser::Delimiter::Paren))
    return failure();
  if (parser.parseKeyword("to"))
    return failure();
  if (parser.parseOperandList(ubs, ivs.size(), OpAsmParser::Delimiter::Paren))
    return failure();

  for (OpAsmParser::Argument &ivArg : ivs)
    ivArg.type = loopVarType;

  // Optional "inclusive" flag.
  if (succeeded(parser.parseOptionalKeyword("inclusive")))
    result.addAttribute("loop_inclusive",
                        UnitAttr::get(parser.getBuilder().getContext()));

  // Step values: one per induction variable.
  SmallVector<OpAsmParser::UnresolvedOperand> steps;
  if (parser.parseKeyword("step"))
    return failure();
  if (parser.parseOperandList(steps, ivs.size(),
                              OpAsmParser::Delimiter::Paren))
    return failure();

  // Loop body, with the induction variables as entry block arguments.
  Region *body = result.addRegion();
  if (parser.parseRegion(*body, ivs))
    return failure();

  // Resolve the operands against the common loop variable type.
  if (parser.resolveOperands(lbs, loopVarType, result.operands))
    return failure();
  if (parser.resolveOperands(ubs, loopVarType, result.operands))
    return failure();
  if (parser.resolveOperands(steps, loopVarType, result.operands))
    return failure();

  // Optional trailing attribute dictionary.
  return parser.parseOptionalAttrDict(result.attributes);
}
|
|
|
|
/// Prints the custom assembly form of `omp.loop_nest`:
///
///   (ivs) : type = (lbs) to (ubs) [inclusive] step (steps) region
///
/// The printed type is that of the first region argument.
void LoopNestOp::print(OpAsmPrinter &p) {
  Region &region = getRegion();
  auto args = region.getArguments();
  p << " (" << args << ") : " << args[0].getType() << " = ("
    << getLoopLowerBounds() << ") to (" << getLoopUpperBounds() << ") ";
  if (getLoopInclusive())
    p << "inclusive ";
  p << "step (" << getLoopSteps() << ") ";
  // The entry block arguments were already printed as the induction variables.
  p.printRegion(region, /*printEntryBlockArgs=*/false);
}
|
|
|
|
/// Builds an `omp.loop_nest` from the bounds, steps and inclusive flag
/// collected in \p clauses.
void LoopNestOp::build(OpBuilder &builder, OperationState &state,
                       const LoopNestOperands &clauses) {
  LoopNestOp::build(builder, state,
                    clauses.loopLowerBounds,
                    clauses.loopUpperBounds,
                    clauses.loopSteps,
                    clauses.loopInclusive);
}
|
|
|
|
/// Verifies that the loop nest represents at least one loop, has as many
/// induction variables as lower bounds with matching types, and is directly
/// nested in a loop wrapper.
LogicalResult LoopNestOp::verify() {
  auto lowerBounds = getLoopLowerBounds();
  if (lowerBounds.empty())
    return emitOpError() << "must represent at least one loop";

  if (lowerBounds.size() != getIVs().size())
    return emitOpError() << "number of range arguments and IVs do not match";

  // Each induction variable must have the type of its lower bound.
  for (auto [lowerBound, inductionVar] :
       llvm::zip_equal(lowerBounds, getIVs())) {
    if (lowerBound.getType() != inductionVar.getType())
      return emitOpError()
             << "range argument type does not match corresponding IV type";
  }

  Operation *parentOp = (*this)->getParentOp();
  if (!llvm::dyn_cast_if_present<LoopWrapperInterface>(parentOp))
    return emitOpError() << "expects parent op to be a loop wrapper";

  return success();
}
|
|
|
|
/// Appends to \p wrappers the chain of loop wrappers enclosing this loop
/// nest, innermost first, stopping at the first ancestor that is not a loop
/// wrapper.
void LoopNestOp::gatherWrappers(
    SmallVectorImpl<LoopWrapperInterface> &wrappers) {
  for (Operation *ancestor = (*this)->getParentOp();;
       ancestor = ancestor->getParentOp()) {
    auto wrapper = llvm::dyn_cast_if_present<LoopWrapperInterface>(ancestor);
    if (!wrapper)
      break;
    wrappers.push_back(wrapper);
  }
}
|
|
|
|
//===----------------------------------------------------------------------===//
// OpenMP canonical loop handling
//===----------------------------------------------------------------------===//
|
|
|
|
/// Decomposes a canonical loop info (CLI) value into the `NewCliOp` defining
/// it, the operand use that generates it and the operand use that consumes
/// it. Either use is null when absent. A null \p cli yields a null op and two
/// null uses.
std::tuple<NewCliOp, OpOperand *, OpOperand *>
mlir::omp::decodeCli(Value cli) {
  // Defining a CLI for a generated loop is optional; if there is none then
  // there is no follow-up transformation.
  if (!cli)
    return {{}, nullptr, nullptr};

  assert(cli.getType() == CanonicalLoopInfoType::get(cli.getContext()) &&
         "Unexpected type of cli");

  NewCliOp create = cast<NewCliOp>(cli.getDefiningOp());
  OpOperand *generatorUse = nullptr;
  OpOperand *consumerUse = nullptr;

  // Classify every use as either the generator or the consumer of the CLI.
  for (OpOperand &use : cli.getUses()) {
    auto transformOp = cast<LoopTransformationInterface>(use.getOwner());
    unsigned operandIdx = use.getOperandNumber();

    if (transformOp.isGeneratee(operandIdx)) {
      assert(!generatorUse && "Each CLI may have at most one def");
      generatorUse = &use;
      continue;
    }

    if (transformOp.isApplyee(operandIdx)) {
      assert(!consumerUse && "Each CLI may have at most one consumer");
      consumerUse = &use;
      continue;
    }

    llvm_unreachable("Unexpected operand for a CLI");
  }

  return {create, generatorUse, consumerUse};
}
|
|
|
|
/// Builds an `omp.new_cli`, whose only result has the canonical loop info
/// type.
void NewCliOp::build(::mlir::OpBuilder &odsBuilder,
                     ::mlir::OperationState &odsState) {
  Type cliType = CanonicalLoopInfoType::get(odsBuilder.getContext());
  odsState.addTypes(cliType);
}
|
|
|
|
/// Suggests a name for the CLI result, derived from the operation that
/// generates the CLI (see the naming scheme described in the body).
void NewCliOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
  Value result = getResult();
  auto [newCli, gen, cons] = decodeCli(result);

  // Derive the CLI variable name from its generator:
  //  * "canonloop" for omp.canonical_loop
  //  * "transformed" placeholder for any other generator (custom names for
  //    loop transformation generatees are still a TODO and assert)
  //  * "cli" as fallback if no generator
  //  * "_s<idx>" suffix for nested loops, where <idx> is the sequential order
  //    at that level
  //  * "_r<idx>" suffix for operations with multiple regions, where <idx> is
  //    the index of that region
  std::string cliName{"cli"};
  if (gen) {
    cliName =
        TypeSwitch<Operation *, std::string>(gen->getOwner())
            .Case([&](CanonicalLoopOp op) {
              // Find the canonical loop nesting: for each ancestor level add
              // an "s<idx>" component and, if the parent has several
              // regions, an "r<idx>" component. Components are collected
              // innermost first and reversed when the name is assembled.
              SmallVector<std::string> components;
              Operation *o = op.getOperation();
              while (o) {
                // Do not walk past isolated-from-above operations.
                if (o->hasTrait<mlir::OpTrait::IsIsolatedFromAbove>())
                  break;

                Region *r = o->getParentRegion();
                if (!r)
                  break;

                // Position of `o` among the region-holding operations of `r`,
                // visiting blocks in reverse post-order from the entry block.
                auto getSequentialIndex = [](Region *r, Operation *o) {
                  llvm::ReversePostOrderTraversal<Block *> traversal(
                      &r->getBlocks().front());
                  size_t idx = 0;
                  for (Block *b : traversal) {
                    for (Operation &op : *b) {
                      if (&op == o)
                        return idx;
                      // Only consider operations that are containers as
                      // possible children
                      if (!op.getRegions().empty())
                        idx += 1;
                    }
                  }
                  llvm_unreachable("Operation not part of the region");
                };
                size_t sequentialIdx = getSequentialIndex(r, o);
                components.push_back(("s" + Twine(sequentialIdx)).str());

                Operation *parent = r->getParentOp();
                if (!parent)
                  break;

                // If the operation has more than one region, also count in
                // which of the regions
                if (parent->getRegions().size() > 1) {
                  auto getRegionIndex = [](Operation *o, Region *r) {
                    for (auto [idx, region] :
                         llvm::enumerate(o->getRegions())) {
                      if (&region == r)
                        return idx;
                    }
                    llvm_unreachable("Region not child its parent operation");
                  };
                  size_t regionIdx = getRegionIndex(parent, r);
                  components.push_back(("r" + Twine(regionIdx)).str());
                }

                // next parent
                o = parent;
              }

              // Assemble the name from the outermost component inwards.
              SmallString<64> Name("canonloop");
              for (const std::string &s : reverse(components)) {
                Name += '_';
                Name += s;
              }

              return Name;
            })
            .Case([&](UnrollHeuristicOp op) -> std::string {
              llvm_unreachable("heuristic unrolling does not generate a loop");
            })
            .Default([&](Operation *op) {
              assert(false && "TODO: Custom name for this operation");
              return "transformed";
            });
  }

  setNameFn(result, cliName);
}
|
|
|
|
/// Verifies how the CLI result is used: at most one generator, at most one
/// consumer, and a consumer only together with a generator. Offending users
/// are attached to the diagnostic as notes.
LogicalResult NewCliOp::verify() {
  Value cli = getResult();

  assert(cli.getType() == CanonicalLoopInfoType::get(cli.getContext()) &&
         "Unexpected type of cli");

  // Check that the CLI is used in at most one generator and one consumer.
  OpOperand *generatorUse = nullptr;
  OpOperand *consumerUse = nullptr;
  for (mlir::OpOperand &use : cli.getUses()) {
    Operation *user = use.getOwner();
    auto transformOp = cast<mlir::omp::LoopTransformationInterface>(user);
    unsigned operandIdx = use.getOperandNumber();

    if (transformOp.isGeneratee(operandIdx)) {
      if (generatorUse) {
        InFlightDiagnostic error =
            emitOpError("CLI must have at most one generator");
        error.attachNote(generatorUse->getOwner()->getLoc())
            .append("first generator here:");
        error.attachNote(user->getLoc()).append("second generator here:");
        return error;
      }

      generatorUse = &use;
      continue;
    }

    if (transformOp.isApplyee(operandIdx)) {
      if (consumerUse) {
        InFlightDiagnostic error =
            emitOpError("CLI must have at most one consumer");
        error.attachNote(consumerUse->getOwner()->getLoc())
            .append("first consumer here:")
            .appendOp(*consumerUse->getOwner(),
                      OpPrintingFlags().printGenericOpForm());
        error.attachNote(user->getLoc())
            .append("second consumer here:")
            .appendOp(*user, OpPrintingFlags().printGenericOpForm());
        return error;
      }

      consumerUse = &use;
      continue;
    }

    llvm_unreachable("Unexpected operand for a CLI");
  }

  // If the CLI is source of a transformation, it must have a generator
  if (consumerUse && !generatorUse) {
    InFlightDiagnostic error = emitOpError("CLI has no generator");
    error.attachNote(consumerUse->getOwner()->getLoc())
        .append("see consumer here: ")
        .appendOp(*consumerUse->getOwner(),
                  OpPrintingFlags().printGenericOpForm());
    return error;
  }

  return success();
}
|
|
|
|
/// Builds an `omp.canonical_loop` without a CLI: the operands are the trip
/// count followed by a null value, and one empty region is added.
void CanonicalLoopOp::build(OpBuilder &odsBuilder, OperationState &odsState,
                            Value tripCount) {
  odsState.addOperands(tripCount);
  // Null value in the CLI operand position.
  odsState.addOperands(Value());

  (void)odsState.addRegion();
}
|
|
|
|
/// Builds an `omp.canonical_loop` with the operands \p tripCount and \p cli
/// (in that order) and one empty region.
void CanonicalLoopOp::build(OpBuilder &odsBuilder, OperationState &odsState,
                            Value tripCount, ::mlir::Value cli) {
  odsState.addOperands(tripCount);
  odsState.addOperands(cli);

  (void)odsState.addRegion();
}
|
|
|
|
/// Names the first block of the loop region "body_entry".
void CanonicalLoopOp::getAsmBlockNames(OpAsmSetBlockNameFn setNameFn) {
  Block *entryBlock = &getRegion().front();
  setNameFn(entryBlock, "body_entry");
}
|
|
|
|
void CanonicalLoopOp::getAsmBlockArgumentNames(Region ®ion,
|
|
OpAsmSetValueNameFn setNameFn) {
|
|
setNameFn(region.getArgument(0), "iv");
|
|
}
|
|
|
|
/// Prints the custom assembly form of `omp.canonical_loop`:
///
///   [(cli)] iv : type in range(tripcount) region [attr-dict]
void CanonicalLoopOp::print(OpAsmPrinter &p) {
  // The CLI operand is optional and printed in parentheses when present.
  if (Value cli = getCli())
    p << '(' << cli << ')';

  Value inductionVar = getInductionVar();
  p << ' ' << inductionVar << " : " << inductionVar.getType() << " in range("
    << getTripCount() << ") ";

  // The induction variable was printed above, so omit the entry block args.
  p.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
                /*printBlockTerminators=*/true);

  p.printOptionalAttrDict((*this)->getAttrs());
}
|
|
|
|
/// Parses the custom assembly form of `omp.canonical_loop`:
///
///   [(cli)] iv : type in range(tripcount) region [attr-dict]
///
/// The trip count operand is added to \p result first, the optional CLI
/// operand last.
mlir::ParseResult CanonicalLoopOp::parse(::mlir::OpAsmParser &parser,
                                         ::mlir::OperationState &result) {
  CanonicalLoopInfoType cliType =
      CanonicalLoopInfoType::get(parser.getContext());

  // Parse the (optional) omp.cli identifier.
  OpAsmParser::UnresolvedOperand cli;
  SmallVector<mlir::Value, 1> cliOperand;
  if (succeeded(parser.parseOptionalLParen())) {
    if (parser.parseOperand(cli))
      return failure();
    if (parser.resolveOperand(cli, cliType, cliOperand))
      return failure();
    if (parser.parseRParen())
      return failure();
  }

  // We derive the type of tripCount from inductionVariable. MLIR requires the
  // type of tripCount to be known when calling resolveOperand, so we have to
  // parse the type before processing the inductionVariable.
  OpAsmParser::Argument inductionVariable;
  OpAsmParser::UnresolvedOperand tripcount;
  if (parser.parseArgument(inductionVariable, /*allowType*/ true))
    return failure();
  if (parser.parseKeyword("in") || parser.parseKeyword("range"))
    return failure();
  if (parser.parseLParen() || parser.parseOperand(tripcount) ||
      parser.parseRParen())
    return failure();
  if (parser.resolveOperand(tripcount, inductionVariable.type,
                            result.operands))
    return failure();

  // Parse the loop body.
  Region *body = result.addRegion();
  if (parser.parseRegion(*body, {inductionVariable}))
    return failure();

  // We parsed the cli operand first, but because it is optional, it must be
  // last in the operand list.
  result.operands.append(cliOperand);

  // Parse the optional attribute list.
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();

  return mlir::success();
}
|
|
|
|
/// Verifies that a non-empty loop region has exactly one argument and that
/// this argument has the same type as the trip count.
LogicalResult CanonicalLoopOp::verify() {
  // The region's entry must accept the induction variable
  // It can also be empty if just created
  if (!getRegion().empty()) {
    Region &region = getRegion();
    if (region.getNumArguments() != 1)
      return emitOpError(
          "Canonical loop region must have exactly one argument");

    if (getInductionVar().getType() != getTripCount().getType())
      return emitOpError(
          "Region argument must be the same type as the trip count");
  }

  return success();
}
|
|
|
|
/// Returns the first argument of the loop region, which serves as the
/// induction variable.
Value CanonicalLoopOp::getInductionVar() {
  Value inductionVar = getRegion().getArgument(0);
  return inductionVar;
}
|
|
|
|
std::pair<unsigned, unsigned>
|
|
CanonicalLoopOp::getApplyeesODSOperandIndexAndLength() {
|
|
// No applyees
|
|
return {0, 0};
|
|
}
|
|
|
|
std::pair<unsigned, unsigned>
|
|
CanonicalLoopOp::getGenerateesODSOperandIndexAndLength() {
|
|
return getODSOperandIndexAndLength(odsIndex_cli);
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
// UnrollHeuristicOp
//===----------------------------------------------------------------------===//
|
|
|
|
/// Builds an `omp.unroll_heuristic` whose single operand is \p cli.
void UnrollHeuristicOp::build(::mlir::OpBuilder &odsBuilder,
                              ::mlir::OperationState &odsState,
                              ::mlir::Value cli) {
  ::mlir::Value applyee = cli;
  odsState.addOperands(applyee);
}
|
|
|
|
/// Prints the custom assembly form: the applyee in parentheses followed by
/// the optional attribute dictionary.
void UnrollHeuristicOp::print(OpAsmPrinter &p) {
  p << '(';
  p << getApplyee();
  p << ')';

  p.printOptionalAttrDict((*this)->getAttrs());
}
|
|
|
|
/// Parses the custom assembly form of `omp.unroll_heuristic`:
///
///   (applyee) [-> ()] [attr-dict]
mlir::ParseResult UnrollHeuristicOp::parse(::mlir::OpAsmParser &parser,
                                           ::mlir::OperationState &result) {
  auto cliType = CanonicalLoopInfoType::get(parser.getContext());

  // The applyee, in parentheses, resolved as a CLI-typed operand.
  OpAsmParser::UnresolvedOperand applyee;
  if (parser.parseLParen())
    return failure();
  if (parser.parseOperand(applyee))
    return failure();
  if (parser.resolveOperand(applyee, cliType, result.operands))
    return failure();
  if (parser.parseRParen())
    return failure();

  // Optional output loop (full unrolling has none)
  if (succeeded(parser.parseOptionalArrow())) {
    if (parser.parseLParen() || parser.parseRParen())
      return failure();
  }

  // Parse the optional attribute list.
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();

  return mlir::success();
}
|
|
|
|
std::pair<unsigned, unsigned>
|
|
UnrollHeuristicOp ::getApplyeesODSOperandIndexAndLength() {
|
|
return getODSOperandIndexAndLength(odsIndex_applyee);
|
|
}
|
|
|
|
std::pair<unsigned, unsigned>
|
|
UnrollHeuristicOp::getGenerateesODSOperandIndexAndLength() {
|
|
return {0, 0};
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
// Critical construct (2.17.1)
//===----------------------------------------------------------------------===//
|
|
|
|
/// Builds an `omp.critical.declare` from the symbol name and hint stored in
/// \p clauses.
void CriticalDeclareOp::build(OpBuilder &builder, OperationState &state,
                              const CriticalDeclareOperands &clauses) {
  CriticalDeclareOp::build(builder, state,
                           clauses.symName,
                           clauses.hint);
}
|
|
|
|
/// Delegates verification of the hint value to `verifySynchronizationHint`.
LogicalResult CriticalDeclareOp::verify() {
  LogicalResult hintStatus = verifySynchronizationHint(*this, getHint());
  return hintStatus;
}
|
|
|
|
/// Checks that the name attribute, when present, resolves to a
/// `CriticalDeclareOp` looked up from this operation through \p symbolTable.
LogicalResult CriticalOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
  // Nothing to resolve when no name attribute is set.
  SymbolRefAttr symbolRef = getNameAttr();
  if (!symbolRef)
    return success();

  auto decl =
      symbolTable.lookupNearestSymbolFrom<CriticalDeclareOp>(*this, symbolRef);
  if (decl)
    return success();

  return emitOpError() << "expected symbol reference " << symbolRef
                       << " to point to a critical declaration";
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Ordered construct
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Shared nesting check for the ordered operations.
///
/// An op with at least one region may appear outside of any `LoopNestOp`; an
/// op without regions must be nested in one. When a loop nest is found, the
/// op directly enclosing it must be either a `WsloopOp` carrying an ordered
/// attribute (with a zero value for ops with regions and a non-zero value for
/// ops without regions) or a `SimdOp`.
static LogicalResult verifyOrderedParent(Operation &op) {
  const bool isRegionVariant = op.getNumRegions() > 0;

  auto enclosingNest = op.getParentOfType<LoopNestOp>();
  if (!enclosingNest) {
    // TODO: Consider if this needs to be the case only for the standalone
    // variant of the ordered construct.
    if (!isRegionVariant)
      return op.emitOpError() << "must be nested inside of a loop";
    return success();
  }

  Operation *wrapper = enclosingNest->getParentOp();
  auto wsloopOp = dyn_cast<WsloopOp>(wrapper);
  if (!wsloopOp) {
    // The only other accepted wrapper is a simd loop.
    if (isa<SimdOp>(wrapper))
      return success();
    return op.emitOpError() << "must be nested inside of a worksharing, simd "
                               "or worksharing simd loop";
  }

  IntegerAttr orderedAttr = wsloopOp.getOrderedAttr();
  if (!orderedAttr)
    return op.emitOpError() << "the enclosing worksharing-loop region must "
                               "have an ordered clause";

  // A zero value is treated as "no parameter present" on the ordered clause.
  const bool hasParameter = orderedAttr.getInt() != 0;

  if (isRegionVariant && hasParameter)
    return op.emitOpError() << "the enclosing loop's ordered clause must not "
                               "have a parameter present";

  if (!isRegionVariant && !hasParameter)
    return op.emitOpError() << "the enclosing loop's ordered clause must "
                               "have a parameter present";

  return success();
}
|
|
|
|
/// Builds the op from a clause-operands structure by forwarding its doacross
/// depend type, loop count and depend variables to the field-wise builder.
void OrderedOp::build(OpBuilder &builder, OperationState &state,
                      const OrderedOperands &clauses) {
  build(builder, state, clauses.doacrossDependType, clauses.doacrossNumLoops,
        clauses.doacrossDependVars);
}
|
|
|
|
/// Verifies the nesting of the op and that its doacross loop count matches
/// the ordered value of the enclosing worksharing loop.
///
/// The verifier fails with the mismatch diagnostic when there is no enclosing
/// `WsloopOp`, when either the wsloop's ordered value or this op's doacross
/// loop count is absent, or when the two values differ.
LogicalResult OrderedOp::verify() {
  if (failed(verifyOrderedParent(**this)))
    return failure();

  auto wrapper = (*this)->getParentOfType<WsloopOp>();

  bool loopCountsMatch = false;
  if (wrapper) {
    // Both values are optional. The nearest enclosing wsloop is not
    // necessarily the direct wrapper validated by verifyOrderedParent (e.g.
    // when the loop nest is wrapped by a simd op), so neither optional may be
    // dereferenced without checking that it holds a value.
    auto orderedLoops = wrapper.getOrdered();
    auto doacrossLoops = getDoacrossNumLoops();
    loopCountsMatch =
        orderedLoops && doacrossLoops && *orderedLoops == *doacrossLoops;
  }

  if (!loopCountsMatch)
    return emitOpError() << "number of variables in depend clause does not "
                         << "match number of iteration variables in the "
                         << "doacross loop";

  return success();
}
|
|
|
|
/// Builds the op from a clause-operands structure by forwarding its
/// `parLevelSimd` field to the field-wise builder.
void OrderedRegionOp::build(OpBuilder &builder, OperationState &state,
                            const OrderedRegionOperands &clauses) {
  build(builder, state, clauses.parLevelSimd);
}
|
|
|
|
/// Verifies the op by running the shared ordered-parent nesting check on the
/// underlying operation.
LogicalResult OrderedRegionOp::verify() {
  Operation &self = *getOperation();
  return verifyOrderedParent(self);
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// TaskwaitOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Builds the op from a clause-operands structure. The contents of `clauses`
/// are currently not used: the op is built with no depend kinds, no depend
/// variables and no nowait attribute.
void TaskwaitOp::build(OpBuilder &builder, OperationState &state,
                       const TaskwaitOperands &clauses) {
  // TODO Store clauses in op: dependKinds, dependVars, nowait.
  build(builder, state,
        /*depend_kinds=*/nullptr,
        /*depend_vars=*/{},
        /*nowait=*/nullptr);
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Verifier for AtomicReadOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Verifies an atomic read: runs the common checks, rejects the `acq_rel` and
/// `release` memory orders, then checks the synchronization hint.
LogicalResult AtomicReadOp::verify() {
  if (failed(verifyCommon()))
    return failure();

  const auto order = getMemoryOrder();
  const bool orderDisallowed =
      order && (*order == ClauseMemoryOrderKind::Acq_rel ||
                *order == ClauseMemoryOrderKind::Release);
  if (orderDisallowed)
    return emitError(
        "memory-order must not be acq_rel or release for atomic reads");

  return verifySynchronizationHint(*this, getHint());
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Verifier for AtomicWriteOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Verifies an atomic write: runs the common checks, rejects the `acq_rel`
/// and `acquire` memory orders, then checks the synchronization hint.
LogicalResult AtomicWriteOp::verify() {
  if (failed(verifyCommon()))
    return failure();

  const auto order = getMemoryOrder();
  const bool orderDisallowed =
      order && (*order == ClauseMemoryOrderKind::Acq_rel ||
                *order == ClauseMemoryOrderKind::Acquire);
  if (orderDisallowed)
    return emitError(
        "memory-order must not be acq_rel or acquire for atomic writes");

  return verifySynchronizationHint(*this, getHint());
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Verifier for AtomicUpdateOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Canonicalizes an atomic update:
///  - an update reported as a no-op is erased;
///  - an update for which a write value is available is replaced by an
///    `AtomicWriteOp` on the same location, keeping the hint and memory-order
///    attributes.
/// Returns failure when neither rewrite applies.
LogicalResult AtomicUpdateOp::canonicalize(AtomicUpdateOp op,
                                           PatternRewriter &rewriter) {
  if (op.isNoOp()) {
    rewriter.eraseOp(op);
    return success();
  }

  Value writeVal = op.getWriteOpVal();
  if (!writeVal)
    return failure();

  rewriter.replaceOpWithNewOp<AtomicWriteOp>(
      op, op.getX(), writeVal, op.getHintAttr(), op.getMemoryOrderAttr());
  return success();
}
|
|
|
|
/// Verifies an atomic update: runs the common checks, rejects the `acq_rel`
/// and `acquire` memory orders, then checks the synchronization hint.
LogicalResult AtomicUpdateOp::verify() {
  if (failed(verifyCommon()))
    return failure();

  const auto order = getMemoryOrder();
  const bool orderDisallowed =
      order && (*order == ClauseMemoryOrderKind::Acq_rel ||
                *order == ClauseMemoryOrderKind::Acquire);
  if (orderDisallowed)
    return emitError(
        "memory-order must not be acq_rel or acquire for atomic updates");

  return verifySynchronizationHint(*this, getHint());
}
|
|
|
|
/// Region verification is delegated entirely to the common region checks.
LogicalResult AtomicUpdateOp::verifyRegions() {
  LogicalResult regionCheck = verifyRegionsCommon();
  return regionCheck;
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Verifier for AtomicCaptureOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Returns the first op of the capture region if it is an `AtomicReadOp`,
/// otherwise the second op cast to `AtomicReadOp` (null if it is not one).
AtomicReadOp AtomicCaptureOp::getAtomicReadOp() {
  auto leading = dyn_cast<AtomicReadOp>(getFirstOp());
  return leading ? leading : dyn_cast<AtomicReadOp>(getSecondOp());
}
|
|
|
|
/// Returns the first op of the capture region if it is an `AtomicWriteOp`,
/// otherwise the second op cast to `AtomicWriteOp` (null if it is not one).
AtomicWriteOp AtomicCaptureOp::getAtomicWriteOp() {
  auto leading = dyn_cast<AtomicWriteOp>(getFirstOp());
  return leading ? leading : dyn_cast<AtomicWriteOp>(getSecondOp());
}
|
|
|
|
/// Returns the first op of the capture region if it is an `AtomicUpdateOp`,
/// otherwise the second op cast to `AtomicUpdateOp` (null if it is not one).
AtomicUpdateOp AtomicCaptureOp::getAtomicUpdateOp() {
  auto leading = dyn_cast<AtomicUpdateOp>(getFirstOp());
  return leading ? leading : dyn_cast<AtomicUpdateOp>(getSecondOp());
}
|
|
|
|
/// Verifies the op by checking its synchronization hint.
LogicalResult AtomicCaptureOp::verify() {
  auto hintValue = getHint();
  return verifySynchronizationHint(*this, hintValue);
}
|
|
|
|
/// Verifies the capture region: runs the common region checks, then rejects
/// `hint` and `memory_order` attributes on either of the two nested
/// operations.
LogicalResult AtomicCaptureOp::verifyRegions() {
  if (failed(verifyRegionsCommon()))
    return failure();

  // True when either nested operation carries the named attribute.
  auto nestedOpHasAttr = [&](StringRef attrName) -> bool {
    return getFirstOp()->getAttr(attrName) ||
           getSecondOp()->getAttr(attrName);
  };

  if (nestedOpHasAttr("hint"))
    return emitOpError(
        "operations inside capture region must not have hint clause");

  if (nestedOpHasAttr("memory_order"))
    return emitOpError(
        "operations inside capture region must not have memory_order clause");

  return success();
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// CancelOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Builds the op from a clause-operands structure by forwarding its
/// `cancelDirective` and `ifExpr` fields to the field-wise builder.
void CancelOp::build(OpBuilder &builder, OperationState &state,
                     const CancelOperands &clauses) {
  build(builder, state, clauses.cancelDirective, clauses.ifExpr);
}
|
|
|
|
/// Walks up the chain of parent operations of `thisOp` and returns the first
/// ancestor that belongs to the same dialect, or null if there is none.
static Operation *getParentInSameDialect(Operation *thisOp) {
  for (Operation *ancestor = thisOp->getParentOp(); ancestor;
       ancestor = ancestor->getParentOp())
    if (ancestor->getDialect() == thisOp->getDialect())
      return ancestor;

  return nullptr;
}
|
|
|
|
/// Verifies that the cancel op is nested in a construct matching its
/// cancellation directive:
///  - parallel:  the nearest same-dialect ancestor must be a `ParallelOp`;
///  - loop:      the parent of that ancestor must be a `WsloopOp` without
///               nowait or ordered attributes;
///  - sections:  the parent of that ancestor must be a `SectionsOp` without
///               nowait;
///  - taskgroup: the ancestor must be a `TaskOp`, or its parent a
///               `TaskloopOp`.
/// An op with no same-dialect ancestor is rejected as orphaned.
LogicalResult CancelOp::verify() {
  ClauseCancellationConstructType cct = getCancelDirective();
  // The next OpenMP operation in the chain of parents
  Operation *structuralParent = getParentInSameDialect((*this).getOperation());
  if (!structuralParent)
    return emitOpError() << "Orphaned cancel construct";

  // Operation enclosing the structural parent. This is null when the
  // structural parent has no parent operation, so it must only be inspected
  // with the null-tolerant cast helpers below: plain isa<>/dyn_cast<> assert
  // on a null operand.
  Operation *enclosingOp = structuralParent->getParentOp();

  if ((cct == ClauseCancellationConstructType::Parallel) &&
      !mlir::isa<ParallelOp>(structuralParent)) {
    return emitOpError() << "cancel parallel must appear "
                         << "inside a parallel region";
  }
  if (cct == ClauseCancellationConstructType::Loop) {
    // structural parent will be omp.loop_nest, directly nested inside
    // omp.wsloop
    auto wsloopOp = llvm::dyn_cast_if_present<WsloopOp>(enclosingOp);

    if (!wsloopOp) {
      return emitOpError()
             << "cancel loop must appear inside a worksharing-loop region";
    }
    if (wsloopOp.getNowaitAttr()) {
      return emitError() << "A worksharing construct that is canceled "
                         << "must not have a nowait clause";
    }
    if (wsloopOp.getOrderedAttr()) {
      return emitError() << "A worksharing construct that is canceled "
                         << "must not have an ordered clause";
    }

  } else if (cct == ClauseCancellationConstructType::Sections) {
    // structural parent will be an omp.section, directly nested inside
    // omp.sections
    auto sectionsOp = llvm::dyn_cast_if_present<SectionsOp>(enclosingOp);
    if (!sectionsOp) {
      return emitOpError() << "cancel sections must appear "
                           << "inside a sections region";
    }
    if (sectionsOp.getNowait()) {
      return emitError() << "A sections construct that is canceled "
                         << "must not have a nowait clause";
    }
  }
  if ((cct == ClauseCancellationConstructType::Taskgroup) &&
      (!mlir::isa<omp::TaskOp>(structuralParent) &&
       !llvm::isa_and_present<omp::TaskloopOp>(enclosingOp))) {
    return emitOpError() << "cancel taskgroup must appear "
                         << "inside a task region";
  }
  return success();
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// CancellationPointOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Builds the op from a clause-operands structure by forwarding its
/// `cancelDirective` field to the field-wise builder.
void CancellationPointOp::build(OpBuilder &builder, OperationState &state,
                                const CancellationPointOperands &clauses) {
  build(builder, state, clauses.cancelDirective);
}
|
|
|
|
/// Verifies that the cancellation point is nested in a construct matching its
/// cancellation directive:
///  - parallel:  the nearest same-dialect ancestor must be a `ParallelOp`;
///  - loop:      the parent of that ancestor must be a `WsloopOp`;
///  - sections:  the ancestor must be a `SectionOp`;
///  - taskgroup: the ancestor must be a `TaskOp`.
/// An op with no same-dialect ancestor is rejected as orphaned.
LogicalResult CancellationPointOp::verify() {
  ClauseCancellationConstructType cct = getCancelDirective();
  // The next OpenMP operation in the chain of parents
  Operation *structuralParent = getParentInSameDialect((*this).getOperation());
  if (!structuralParent)
    return emitOpError() << "Orphaned cancellation point";

  if ((cct == ClauseCancellationConstructType::Parallel) &&
      !mlir::isa<ParallelOp>(structuralParent)) {
    return emitOpError() << "cancellation point parallel must appear "
                         << "inside a parallel region";
  }
  // Structural parent here will be an omp.loop_nest. Get the parent of that to
  // find the wsloop. The parent may be null when the structural parent has no
  // enclosing operation, so use the null-tolerant isa variant: plain isa<>
  // asserts on a null operand.
  if ((cct == ClauseCancellationConstructType::Loop) &&
      !llvm::isa_and_present<WsloopOp>(structuralParent->getParentOp())) {
    return emitOpError() << "cancellation point loop must appear "
                         << "inside a worksharing-loop region";
  }
  if ((cct == ClauseCancellationConstructType::Sections) &&
      !mlir::isa<omp::SectionOp>(structuralParent)) {
    return emitOpError() << "cancellation point sections must appear "
                         << "inside a sections region";
  }
  if ((cct == ClauseCancellationConstructType::Taskgroup) &&
      !mlir::isa<omp::TaskOp>(structuralParent)) {
    return emitOpError() << "cancellation point taskgroup must appear "
                         << "inside a task region";
  }
  return success();
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// MapBoundsOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Requires that at least one of the extent or the upper bound is provided.
LogicalResult MapBoundsOp::verify() {
  if (getExtent() || getUpperBound())
    return success();

  return emitError("expected extent or upperbound.");
}
|
|
|
|
/// Convenience builder that creates the op with the `Private` data-sharing
/// type. The result types argument is ignored.
void PrivateClauseOp::build(OpBuilder &odsBuilder, OperationState &odsState,
                            TypeRange /*result_types*/, StringAttr symName,
                            TypeAttr type) {
  auto privateKind = DataSharingClauseTypeAttr::get(
      odsBuilder.getContext(), DataSharingClauseType::Private);
  PrivateClauseOp::build(odsBuilder, odsState, symName, type, privateKind);
}
|
|
|
|
/// Verifies the `init`, `copy` and `dealloc` regions of the privatizer.
///
/// Checks performed:
///  - every argument of every region has the op's argument type;
///  - a non-empty `init` region takes 2 arguments and yields one value of the
///    argument type;
///  - `private` clauses have no `copy` region, `firstprivate` clauses have a
///    `copy` region taking 2 arguments and yielding one value of the argument
///    type;
///  - a non-empty `dealloc` region takes 1 argument and yields nothing.
LogicalResult PrivateClauseOp::verifyRegions() {
  Type argType = getArgType();

  // Checks the terminator of a block. Only exit blocks (blocks without
  // successors) are constrained: they must end in `omp.yield`, yielding
  // either nothing or exactly one value of `argType` depending on
  // `yieldsValue`.
  auto verifyTerminator = [&](Operation *terminator,
                              bool yieldsValue) -> LogicalResult {
    if (!terminator->getBlock()->getSuccessors().empty())
      return success();

    if (!llvm::isa<YieldOp>(terminator))
      return mlir::emitError(terminator->getLoc())
             << "expected exit block terminator to be an `omp.yield` op.";

    YieldOp yieldOp = llvm::cast<YieldOp>(terminator);
    TypeRange yieldedTypes = yieldOp.getResults().getTypes();

    if (!yieldsValue) {
      if (yieldedTypes.empty())
        return success();

      return mlir::emitError(terminator->getLoc())
             << "Did not expect any values to be yielded.";
    }

    if (yieldedTypes.size() == 1 && yieldedTypes.front() == argType)
      return success();

    auto error = mlir::emitError(yieldOp.getLoc())
                 << "Invalid yielded value. Expected type: " << argType
                 << ", got: ";

    if (yieldedTypes.empty())
      error << "None";
    else
      error << yieldedTypes;

    return error;
  };

  // Checks the argument count of a non-empty region and the terminator of
  // each of its blocks.
  auto verifyRegion = [&](Region &region, unsigned expectedNumArgs,
                          StringRef regionName,
                          bool yieldsValue) -> LogicalResult {
    assert(!region.empty());

    if (region.getNumArguments() != expectedNumArgs)
      return mlir::emitError(region.getLoc())
             << "`" << regionName << "`: "
             << "expected " << expectedNumArgs
             << " region arguments, got: " << region.getNumArguments();

    for (Block &block : region) {
      // MLIR will verify the absence of the terminator for us.
      if (!block.mightHaveTerminator())
        continue;

      if (failed(verifyTerminator(block.getTerminator(), yieldsValue)))
        return failure();
    }

    return success();
  };

  // Ensure all of the region arguments have the same type
  for (Region *region : getRegions())
    for (Type ty : region->getArgumentTypes())
      if (ty != argType)
        return emitError() << "Region argument type mismatch: got " << ty
                           << " expected " << argType << ".";

  mlir::Region &initRegion = getInitRegion();
  if (!initRegion.empty() &&
      failed(verifyRegion(initRegion, /*expectedNumArgs=*/2, "init",
                          /*yieldsValue=*/true)))
    return failure();

  DataSharingClauseType dsType = getDataSharingType();

  if (dsType == DataSharingClauseType::Private && !getCopyRegion().empty())
    return emitError("`private` clauses do not require a `copy` region.");

  if (dsType == DataSharingClauseType::FirstPrivate && getCopyRegion().empty())
    return emitError(
        "`firstprivate` clauses require at least a `copy` region.");

  // At this point a firstprivate clause is known to have a non-empty copy
  // region, which satisfies the non-empty precondition of verifyRegion.
  if (dsType == DataSharingClauseType::FirstPrivate &&
      failed(verifyRegion(getCopyRegion(), /*expectedNumArgs=*/2, "copy",
                          /*yieldsValue=*/true)))
    return failure();

  if (!getDeallocRegion().empty() &&
      failed(verifyRegion(getDeallocRegion(), /*expectedNumArgs=*/1, "dealloc",
                          /*yieldsValue=*/false)))
    return failure();

  return success();
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Spec 5.2: Masked construct (10.5)
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Builds the op from a clause-operands structure by forwarding its
/// `filteredThreadId` field to the field-wise builder.
void MaskedOp::build(OpBuilder &builder, OperationState &state,
                     const MaskedOperands &clauses) {
  build(builder, state, clauses.filteredThreadId);
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Spec 5.2: Scan construct (5.6)
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Builds the op from a clause-operands structure by forwarding its inclusive
/// and exclusive variable lists to the field-wise builder.
void ScanOp::build(OpBuilder &builder, OperationState &state,
                   const ScanOperands &clauses) {
  build(builder, state, clauses.inclusiveVars, clauses.exclusiveVars);
}
|
|
|
|
/// Verifies that exactly one of the exclusive / inclusive variable lists is
/// present and that the op is enclosed in a `WsloopOp` or `SimdOp` whose
/// reduction modifier is `inscan`.
LogicalResult ScanOp::verify() {
  if (hasExclusiveVars() == hasInclusiveVars())
    return emitError(
        "Exactly one of EXCLUSIVE or INCLUSIVE clause is expected");

  // True when `parentOp` is non-null and carries the inscan reduction
  // modifier.
  auto hasInscanModifier = [](auto parentOp) -> bool {
    if (!parentOp)
      return false;
    auto modifierAttr = parentOp.getReductionModAttr();
    return modifierAttr && modifierAttr.getValue() == ReductionModifier::inscan;
  };

  // The enclosing simd op is only consulted when no enclosing worksharing
  // loop with the inscan modifier was found.
  if (hasInscanModifier((*this)->getParentOfType<WsloopOp>()) ||
      hasInscanModifier((*this)->getParentOfType<SimdOp>()))
    return success();

  return emitError("SCAN directive needs to be enclosed within a parent "
                   "worksharing loop construct or SIMD construct with INSCAN "
                   "reduction modifier");
}
|
|
|
|
/// Verifies align clause in allocate directive
|
|
|
|
LogicalResult AllocateDirOp::verify() {
|
|
std::optional<uint64_t> align = this->getAlign();
|
|
|
|
if (align.has_value()) {
|
|
if ((align.value() > 0) && !llvm::has_single_bit(align.value()))
|
|
return emitError() << "ALIGN value : " << align.value()
|
|
<< " must be power of 2";
|
|
}
|
|
|
|
return success();
|
|
}
|
|
|
|
#define GET_ATTRDEF_CLASSES
|
|
#include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc"
|
|
|
|
#define GET_OP_CLASSES
|
|
#include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"
|
|
|
|
#define GET_TYPEDEF_CLASSES
|
|
#include "mlir/Dialect/OpenMP/OpenMPOpsTypes.cpp.inc"
|