diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp index 25682cba5620..51eb33dec186 100644 --- a/flang/lib/Lower/OpenACC.cpp +++ b/flang/lib/Lower/OpenACC.cpp @@ -4414,10 +4414,34 @@ getAttributeValueByDeviceType(llvm::SmallVector &attributes, return std::nullopt; } +// Helper function to extract string value from bind name variant +static std::optional getBindNameStringValue( + const std::optional> + &bindNameValue) { + if (!bindNameValue.has_value()) + return std::nullopt; + + return std::visit( + [](const auto &attr) -> std::optional { + if constexpr (std::is_same_v, + mlir::StringAttr>) { + return attr.getValue(); + } else if constexpr (std::is_same_v, + mlir::SymbolRefAttr>) { + return attr.getLeafReference(); + } else { + return std::nullopt; + } + }, + bindNameValue.value()); +} + static bool compareDeviceTypeInfo( mlir::acc::RoutineOp op, - llvm::SmallVector &bindNameArrayAttr, - llvm::SmallVector &bindNameDeviceTypeArrayAttr, + llvm::SmallVector &bindIdNameArrayAttr, + llvm::SmallVector &bindStrNameArrayAttr, + llvm::SmallVector &bindIdNameDeviceTypeArrayAttr, + llvm::SmallVector &bindStrNameDeviceTypeArrayAttr, llvm::SmallVector &gangArrayAttr, llvm::SmallVector &gangDimArrayAttr, llvm::SmallVector &gangDimDeviceTypeArrayAttr, @@ -4427,9 +4451,13 @@ static bool compareDeviceTypeInfo( for (uint32_t dtypeInt = 0; dtypeInt != mlir::acc::getMaxEnumValForDeviceType(); ++dtypeInt) { auto dtype = static_cast(dtypeInt); - if (op.getBindNameValue(dtype) != - getAttributeValueByDeviceType( - bindNameArrayAttr, bindNameDeviceTypeArrayAttr, dtype)) + auto bindNameValue = getBindNameStringValue(op.getBindNameValue(dtype)); + if (bindNameValue != + getAttributeValueByDeviceType( + bindIdNameArrayAttr, bindIdNameDeviceTypeArrayAttr, dtype) && + bindNameValue != + getAttributeValueByDeviceType( + bindStrNameArrayAttr, bindStrNameDeviceTypeArrayAttr, dtype)) return false; if (op.hasGang(dtype) != hasDeviceType(gangArrayAttr, dtype)) return false; @@ -4476,8 +4504,10 @@ getArrayAttrOrNull(fir::FirOpBuilder &builder, void createOpenACCRoutineConstruct( Fortran::lower::AbstractConverter &converter, mlir::Location loc, mlir::ModuleOp mod, mlir::func::FuncOp funcOp, std::string funcName, - bool hasNohost, llvm::SmallVector &bindNames, - llvm::SmallVector &bindNameDeviceTypes, + bool hasNohost, llvm::SmallVector &bindIdNames, + llvm::SmallVector &bindStrNames, + llvm::SmallVector &bindIdNameDeviceTypes, + llvm::SmallVector &bindStrNameDeviceTypes, llvm::SmallVector &gangDeviceTypes, llvm::SmallVector &gangDimValues, llvm::SmallVector &gangDimDeviceTypes, @@ -4490,7 +4520,8 @@ void createOpenACCRoutineConstruct( 0) { // If the routine is already specified with the same clauses, just skip // the operation creation. - if (compareDeviceTypeInfo(routineOp, bindNames, bindNameDeviceTypes, + if (compareDeviceTypeInfo(routineOp, bindIdNames, bindStrNames, + bindIdNameDeviceTypes, bindStrNameDeviceTypes, gangDeviceTypes, gangDimValues, gangDimDeviceTypes, seqDeviceTypes, workerDeviceTypes, vectorDeviceTypes) && @@ -4507,8 +4538,10 @@ void createOpenACCRoutineConstruct( modBuilder.create( loc, routineOpStr, mlir::SymbolRefAttr::get(builder.getContext(), funcName), - getArrayAttrOrNull(builder, bindNames), - getArrayAttrOrNull(builder, bindNameDeviceTypes), + getArrayAttrOrNull(builder, bindIdNames), + getArrayAttrOrNull(builder, bindStrNames), + getArrayAttrOrNull(builder, bindIdNameDeviceTypes), + getArrayAttrOrNull(builder, bindStrNameDeviceTypes), getArrayAttrOrNull(builder, workerDeviceTypes), getArrayAttrOrNull(builder, vectorDeviceTypes), getArrayAttrOrNull(builder, seqDeviceTypes), hasNohost, @@ -4525,8 +4558,10 @@ static void interpretRoutineDeviceInfo( llvm::SmallVector &seqDeviceTypes, llvm::SmallVector &vectorDeviceTypes, llvm::SmallVector &workerDeviceTypes, - llvm::SmallVector &bindNameDeviceTypes, - llvm::SmallVector &bindNames, + llvm::SmallVector &bindIdNameDeviceTypes, + llvm::SmallVector &bindStrNameDeviceTypes, + llvm::SmallVector &bindIdNames, + llvm::SmallVector &bindStrNames, llvm::SmallVector &gangDeviceTypes, llvm::SmallVector &gangDimValues, llvm::SmallVector &gangDimDeviceTypes) { @@ -4559,16 +4594,18 @@ static void interpretRoutineDeviceInfo( if (dinfo.bindNameOpt().has_value()) { const auto &bindName = dinfo.bindNameOpt().value(); mlir::Attribute bindNameAttr; - if (const auto &bindStr{std::get_if(&bindName)}) { + if (const auto &bindSym{ + std::get_if(&bindName)}) { + bindNameAttr = builder.getSymbolRefAttr(converter.mangleName(*bindSym)); + bindIdNames.push_back(bindNameAttr); + bindIdNameDeviceTypes.push_back(getDeviceTypeAttr()); + } else if (const auto &bindStr{std::get_if(&bindName)}) { bindNameAttr = builder.getStringAttr(*bindStr); - } else if (const auto &bindSym{ - std::get_if(&bindName)}) { - bindNameAttr = builder.getStringAttr(converter.mangleName(*bindSym)); + bindStrNames.push_back(bindNameAttr); + bindStrNameDeviceTypes.push_back(getDeviceTypeAttr()); } else { llvm_unreachable("Unsupported bind name type"); } - bindNames.push_back(bindNameAttr); - bindNameDeviceTypes.push_back(getDeviceTypeAttr()); } } @@ -4584,8 +4621,9 @@ void Fortran::lower::genOpenACCRoutineConstruct( bool hasNohost{false}; llvm::SmallVector seqDeviceTypes, vectorDeviceTypes, - workerDeviceTypes, bindNameDeviceTypes, bindNames, gangDeviceTypes, - gangDimDeviceTypes, gangDimValues; + workerDeviceTypes, bindIdNameDeviceTypes, bindStrNameDeviceTypes, + bindIdNames, bindStrNames, gangDeviceTypes, gangDimDeviceTypes, + gangDimValues; for (const Fortran::semantics::OpenACCRoutineInfo &info : routineInfos) { // Device Independent Attributes @@ -4594,24 +4632,26 @@ void Fortran::lower::genOpenACCRoutineConstruct( } // Note: Device Independent Attributes are set to the // none device type in `info`. - interpretRoutineDeviceInfo(converter, info, seqDeviceTypes, - vectorDeviceTypes, workerDeviceTypes, - bindNameDeviceTypes, bindNames, gangDeviceTypes, - gangDimValues, gangDimDeviceTypes); + interpretRoutineDeviceInfo( + converter, info, seqDeviceTypes, vectorDeviceTypes, workerDeviceTypes, + bindIdNameDeviceTypes, bindStrNameDeviceTypes, bindIdNames, + bindStrNames, gangDeviceTypes, gangDimValues, gangDimDeviceTypes); // Device Dependent Attributes for (const Fortran::semantics::OpenACCRoutineDeviceTypeInfo &dinfo : info.deviceTypeInfos()) { - interpretRoutineDeviceInfo( - converter, dinfo, seqDeviceTypes, vectorDeviceTypes, - workerDeviceTypes, bindNameDeviceTypes, bindNames, gangDeviceTypes, - gangDimValues, gangDimDeviceTypes); + interpretRoutineDeviceInfo(converter, dinfo, seqDeviceTypes, + vectorDeviceTypes, workerDeviceTypes, + bindIdNameDeviceTypes, bindStrNameDeviceTypes, + bindIdNames, bindStrNames, gangDeviceTypes, + gangDimValues, gangDimDeviceTypes); } } createOpenACCRoutineConstruct( - converter, loc, mod, funcOp, funcName, hasNohost, bindNames, - bindNameDeviceTypes, gangDeviceTypes, gangDimValues, gangDimDeviceTypes, - seqDeviceTypes, workerDeviceTypes, vectorDeviceTypes); + converter, loc, mod, funcOp, funcName, hasNohost, bindIdNames, + bindStrNames, bindIdNameDeviceTypes, bindStrNameDeviceTypes, + gangDeviceTypes, gangDimValues, gangDimDeviceTypes, seqDeviceTypes, + workerDeviceTypes, vectorDeviceTypes); } static void diff --git a/flang/test/Lower/OpenACC/acc-routine.f90 b/flang/test/Lower/OpenACC/acc-routine.f90 index 789f3a57e1f7..1a63b4120235 100644 --- a/flang/test/Lower/OpenACC/acc-routine.f90 +++ b/flang/test/Lower/OpenACC/acc-routine.f90 @@ -2,13 +2,14 @@ ! RUN: bbc -fopenacc -emit-hlfir %s -o - | FileCheck %s -! CHECK: acc.routine @[[r14:.*]] func(@_QPacc_routine19) bind("_QPacc_routine17" [#acc.device_type], "_QPacc_routine17" [#acc.device_type], "_QPacc_routine16" [#acc.device_type]) -! CHECK: acc.routine @[[r13:.*]] func(@_QPacc_routine18) bind("_QPacc_routine17" [#acc.device_type], "_QPacc_routine16" [#acc.device_type]) +! CHECK: acc.routine @[[r14:.*]] func(@_QPacc_routine19) bind(@_QPacc_routine17 [#acc.device_type], @_QPacc_routine17 +! [#acc.device_type], @_QPacc_routine16 [#acc.device_type]) +! CHECK: acc.routine @[[r13:.*]] func(@_QPacc_routine18) bind(@_QPacc_routine17 [#acc.device_type], @_QPacc_routine16 [#acc.device_type]) ! CHECK: acc.routine @[[r12:.*]] func(@_QPacc_routine17) worker ([#acc.device_type]) vector ([#acc.device_type]) ! CHECK: acc.routine @[[r11:.*]] func(@_QPacc_routine16) gang([#acc.device_type]) seq ([#acc.device_type]) ! CHECK: acc.routine @[[r10:.*]] func(@_QPacc_routine11) seq ! CHECK: acc.routine @[[r09:.*]] func(@_QPacc_routine10) seq -! CHECK: acc.routine @[[r08:.*]] func(@_QPacc_routine9) bind("_QPacc_routine9a") +! CHECK: acc.routine @[[r08:.*]] func(@_QPacc_routine9) bind(@_QPacc_routine9a) ! CHECK: acc.routine @[[r07:.*]] func(@_QPacc_routine8) bind("routine8_") ! CHECK: acc.routine @[[r06:.*]] func(@_QPacc_routine7) gang(dim: 1 : i64) ! CHECK: acc.routine @[[r05:.*]] func(@_QPacc_routine6) nohost diff --git a/flang/test/Lower/OpenACC/acc-routine03.f90 b/flang/test/Lower/OpenACC/acc-routine03.f90 index 85e4ef580f98..ddd6bda0367e 100644 --- a/flang/test/Lower/OpenACC/acc-routine03.f90 +++ b/flang/test/Lower/OpenACC/acc-routine03.f90 @@ -30,6 +30,6 @@ end interface end subroutine ! CHECK: acc.routine @acc_routine_1 func(@_QPsub2) worker nohost -! CHECK: acc.routine @acc_routine_0 func(@_QPsub1) bind("_QPsub2") worker +! CHECK: acc.routine @acc_routine_0 func(@_QPsub1) bind(@_QPsub2) worker ! CHECK: func.func @_QPsub1(%arg0: !fir.box> {fir.bindc_name = "a"}) attributes {acc.routine_info = #acc.routine_info<[@acc_routine_0]>} ! CHECK: func.func @_QPsub2(%arg0: !fir.box> {fir.bindc_name = "a"}) attributes {acc.routine_info = #acc.routine_info<[@acc_routine_1]>} diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACC.h b/mlir/include/mlir/Dialect/OpenACC/OpenACC.h index 4eb666239d4e..8f87235fcd23 100644 --- a/mlir/include/mlir/Dialect/OpenACC/OpenACC.h +++ b/mlir/include/mlir/Dialect/OpenACC/OpenACC.h @@ -29,6 +29,7 @@ #include "mlir/Interfaces/ControlFlowInterfaces.h" #include "mlir/Interfaces/LoopLikeInterface.h" #include "mlir/Interfaces/SideEffectInterfaces.h" +#include #define GET_TYPEDEF_CLASSES #include "mlir/Dialect/OpenACC/OpenACCOpsTypes.h.inc" diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td index 66378f116784..96b9adcc53b3 100644 --- a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td +++ b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td @@ -2772,8 +2772,10 @@ def OpenACC_RoutineOp : OpenACC_Op<"routine", [IsolatedFromAbove]> { }]; let arguments = (ins SymbolNameAttr:$sym_name, SymbolRefAttr:$func_name, - OptionalAttr:$bindName, - OptionalAttr:$bindNameDeviceType, + OptionalAttr:$bindIdName, + OptionalAttr:$bindStrName, + OptionalAttr:$bindIdNameDeviceType, + OptionalAttr:$bindStrNameDeviceType, OptionalAttr:$worker, OptionalAttr:$vector, OptionalAttr:$seq, UnitAttr:$nohost, @@ -2815,14 +2817,14 @@ def OpenACC_RoutineOp : OpenACC_Op<"routine", [IsolatedFromAbove]> { std::optional getGangDimValue(); std::optional getGangDimValue(mlir::acc::DeviceType deviceType); - std::optional getBindNameValue(); - std::optional getBindNameValue(mlir::acc::DeviceType deviceType); + std::optional<::std::variant> getBindNameValue(); + std::optional<::std::variant> getBindNameValue(mlir::acc::DeviceType deviceType); }]; let assemblyFormat = [{ $sym_name `func` `(` $func_name `)` oilist ( - `bind` `(` custom($bindName, $bindNameDeviceType) `)` + `bind` `(` custom($bindIdName, $bindStrName ,$bindIdNameDeviceType, $bindStrNameDeviceType) `)` | `gang` `` custom($gang, $gangDim, $gangDimDeviceType) | `worker` custom($worker) | `vector` custom($vector) diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp index f2eab62b286a..fbc1f003ab64 100644 --- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp +++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp @@ -21,6 +21,7 @@ #include "llvm/ADT/SmallSet.h" #include "llvm/ADT/TypeSwitch.h" #include "llvm/Support/LogicalResult.h" +#include using namespace mlir; using namespace acc; @@ -3461,40 +3462,88 @@ LogicalResult acc::RoutineOp::verify() { return success(); } -static ParseResult parseBindName(OpAsmParser &parser, mlir::ArrayAttr &bindName, - mlir::ArrayAttr &deviceTypes) { - llvm::SmallVector bindNameAttrs; - llvm::SmallVector deviceTypeAttrs; +static ParseResult parseBindName(OpAsmParser &parser, + mlir::ArrayAttr &bindIdName, + mlir::ArrayAttr &bindStrName, + mlir::ArrayAttr &deviceIdTypes, + mlir::ArrayAttr &deviceStrTypes) { + llvm::SmallVector bindIdNameAttrs; + llvm::SmallVector bindStrNameAttrs; + llvm::SmallVector deviceIdTypeAttrs; + llvm::SmallVector deviceStrTypeAttrs; if (failed(parser.parseCommaSeparatedList([&]() { - if (parser.parseAttribute(bindNameAttrs.emplace_back())) + mlir::Attribute newAttr; + bool isSymbolRefAttr; + auto parseResult = parser.parseAttribute(newAttr); + if (auto symbolRefAttr = dyn_cast(newAttr)) { + bindIdNameAttrs.push_back(symbolRefAttr); + isSymbolRefAttr = true; + } else if (auto stringAttr = dyn_cast(newAttr)) { + bindStrNameAttrs.push_back(stringAttr); + isSymbolRefAttr = false; + } + if (parseResult) return failure(); if (failed(parser.parseOptionalLSquare())) { - deviceTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get( - parser.getContext(), mlir::acc::DeviceType::None)); + if (isSymbolRefAttr) { + deviceIdTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get( + parser.getContext(), mlir::acc::DeviceType::None)); + } else { + deviceStrTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get( + parser.getContext(), mlir::acc::DeviceType::None)); + } } else { - if (parser.parseAttribute(deviceTypeAttrs.emplace_back()) || - parser.parseRSquare()) - return failure(); + if (isSymbolRefAttr) { + if (parser.parseAttribute(deviceIdTypeAttrs.emplace_back()) || + parser.parseRSquare()) + return failure(); + } else { + if (parser.parseAttribute(deviceStrTypeAttrs.emplace_back()) || + parser.parseRSquare()) + return failure(); + } } return success(); }))) return failure(); - bindName = ArrayAttr::get(parser.getContext(), bindNameAttrs); - deviceTypes = ArrayAttr::get(parser.getContext(), deviceTypeAttrs); + bindIdName = ArrayAttr::get(parser.getContext(), bindIdNameAttrs); + bindStrName = ArrayAttr::get(parser.getContext(), bindStrNameAttrs); + deviceIdTypes = ArrayAttr::get(parser.getContext(), deviceIdTypeAttrs); + deviceStrTypes = ArrayAttr::get(parser.getContext(), deviceStrTypeAttrs); return success(); } static void printBindName(mlir::OpAsmPrinter &p, mlir::Operation *op, - std::optional bindName, - std::optional deviceTypes) { - llvm::interleaveComma(llvm::zip(*bindName, *deviceTypes), p, - [&](const auto &pair) { - p << std::get<0>(pair); - printSingleDeviceType(p, std::get<1>(pair)); - }); + std::optional bindIdName, + std::optional bindStrName, + std::optional deviceIdTypes, + std::optional deviceStrTypes) { + // Create combined vectors for all bind names and device types + llvm::SmallVector allBindNames; + llvm::SmallVector allDeviceTypes; + + // Append bindIdName and deviceIdTypes + if (hasDeviceTypeValues(deviceIdTypes)) { + allBindNames.append(bindIdName->begin(), bindIdName->end()); + allDeviceTypes.append(deviceIdTypes->begin(), deviceIdTypes->end()); + } + + // Append bindStrName and deviceStrTypes + if (hasDeviceTypeValues(deviceStrTypes)) { + allBindNames.append(bindStrName->begin(), bindStrName->end()); + allDeviceTypes.append(deviceStrTypes->begin(), deviceStrTypes->end()); + } + + // Print the combined sequence + if (!allBindNames.empty()) + llvm::interleaveComma(llvm::zip(allBindNames, allDeviceTypes), p, + [&](const auto &pair) { + p << std::get<0>(pair); + printSingleDeviceType(p, std::get<1>(pair)); + }); } static ParseResult parseRoutineGangClause(OpAsmParser &parser, @@ -3654,19 +3703,32 @@ bool RoutineOp::hasSeq(mlir::acc::DeviceType deviceType) { return hasDeviceType(getSeq(), deviceType); } -std::optional RoutineOp::getBindNameValue() { +std::optional> +RoutineOp::getBindNameValue() { return getBindNameValue(mlir::acc::DeviceType::None); } -std::optional +std::optional> RoutineOp::getBindNameValue(mlir::acc::DeviceType deviceType) { - if (!hasDeviceTypeValues(getBindNameDeviceType())) + if (!hasDeviceTypeValues(getBindIdNameDeviceType()) && + !hasDeviceTypeValues(getBindStrNameDeviceType())) { return std::nullopt; - if (auto pos = findSegment(*getBindNameDeviceType(), deviceType)) { - auto attr = (*getBindName())[*pos]; - auto stringAttr = dyn_cast(attr); - return stringAttr.getValue(); } + + if (auto pos = findSegment(*getBindIdNameDeviceType(), deviceType)) { + auto attr = (*getBindIdName())[*pos]; + auto symbolRefAttr = dyn_cast(attr); + assert(symbolRefAttr && "expected SymbolRef"); + return symbolRefAttr; + } + + if (auto pos = findSegment(*getBindStrNameDeviceType(), deviceType)) { + auto attr = (*getBindStrName())[*pos]; + auto stringAttr = dyn_cast(attr); + assert(stringAttr && "expected String"); + return stringAttr; + } + return std::nullopt; } diff --git a/mlir/unittests/Dialect/OpenACC/OpenACCOpsTest.cpp b/mlir/unittests/Dialect/OpenACC/OpenACCOpsTest.cpp index aa16421cbec5..836efdb307f9 100644 --- a/mlir/unittests/Dialect/OpenACC/OpenACCOpsTest.cpp +++ b/mlir/unittests/Dialect/OpenACC/OpenACCOpsTest.cpp @@ -519,14 +519,44 @@ TEST_F(OpenACCOpsTest, routineOpTest) { op->removeGangDimDeviceTypeAttr(); op->removeGangDimAttr(); - op->setBindNameDeviceTypeAttr(b.getArrayAttr({dtypeNone})); - op->setBindNameAttr(b.getArrayAttr({b.getStringAttr("fname")})); + op->setBindIdNameDeviceTypeAttr( + b.getArrayAttr({DeviceTypeAttr::get(&context, DeviceType::Host)})); + op->setBindStrNameDeviceTypeAttr(b.getArrayAttr({dtypeNone})); + op->setBindIdNameAttr( + b.getArrayAttr({SymbolRefAttr::get(&context, "test_symbol")})); + op->setBindStrNameAttr(b.getArrayAttr({b.getStringAttr("fname")})); EXPECT_TRUE(op->getBindNameValue().has_value()); - EXPECT_EQ(op->getBindNameValue().value(), "fname"); - for (auto d : dtypesWithoutNone) - EXPECT_FALSE(op->getBindNameValue(d).has_value()); - op->removeBindNameDeviceTypeAttr(); - op->removeBindNameAttr(); + EXPECT_TRUE(op->getBindNameValue(DeviceType::Host).has_value()); + EXPECT_EQ(std::visit( + [](const auto &attr) -> std::string { + if constexpr (std::is_same_v, + mlir::StringAttr>) { + return attr.str(); + } else { + return attr.getLeafReference().str(); + } + }, + op->getBindNameValue().value()), + "fname"); + EXPECT_EQ(std::visit( + [](const auto &attr) -> std::string { + if constexpr (std::is_same_v, + mlir::StringAttr>) { + return attr.str(); + } else { + return attr.getLeafReference().str(); + } + }, + op->getBindNameValue(DeviceType::Host).value()), + "test_symbol"); + for (auto d : dtypesWithoutNone) { + if (d != DeviceType::Host) + EXPECT_FALSE(op->getBindNameValue(d).has_value()); + } + op->removeBindIdNameDeviceTypeAttr(); + op->removeBindStrNameDeviceTypeAttr(); + op->removeBindIdNameAttr(); + op->removeBindStrNameAttr(); } template