[openacc][flang] Support two type bindName representation in acc routine (#149147)

Based on the OpenACC specification — which states that if the bind name
is given as an identifier it should be resolved according to the
compiled language, and if given as a string it should be used unmodified
— we introduce two distinct `bindName` representations for `acc routine`
to handle each case appropriately: one as an array of `SymbolRefAttr`
for identifiers and another as an array of `StringAttr` for strings.

To ensure correct correspondence between bind names and devices, this
patch also introduces two separate sets of device attributes. The
routine operation is extended accordingly, along with the necessary
updates to the OpenACC dialect and its lowering.
This commit is contained in:
delaram-talaashrafi 2025-07-17 12:38:02 -04:00 committed by GitHub
parent 661cbd5a52
commit 0dae924c1f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 209 additions and 73 deletions

View File

@ -4414,10 +4414,34 @@ getAttributeValueByDeviceType(llvm::SmallVector<mlir::Attribute> &attributes,
return std::nullopt;
}
// Helper function to extract string value from bind name variant
static std::optional<llvm::StringRef> getBindNameStringValue(
const std::optional<std::variant<mlir::SymbolRefAttr, mlir::StringAttr>>
&bindNameValue) {
if (!bindNameValue.has_value())
return std::nullopt;
return std::visit(
[](const auto &attr) -> std::optional<llvm::StringRef> {
if constexpr (std::is_same_v<std::decay_t<decltype(attr)>,
mlir::StringAttr>) {
return attr.getValue();
} else if constexpr (std::is_same_v<std::decay_t<decltype(attr)>,
mlir::SymbolRefAttr>) {
return attr.getLeafReference();
} else {
return std::nullopt;
}
},
bindNameValue.value());
}
static bool compareDeviceTypeInfo(
mlir::acc::RoutineOp op,
llvm::SmallVector<mlir::Attribute> &bindNameArrayAttr,
llvm::SmallVector<mlir::Attribute> &bindNameDeviceTypeArrayAttr,
llvm::SmallVector<mlir::Attribute> &bindIdNameArrayAttr,
llvm::SmallVector<mlir::Attribute> &bindStrNameArrayAttr,
llvm::SmallVector<mlir::Attribute> &bindIdNameDeviceTypeArrayAttr,
llvm::SmallVector<mlir::Attribute> &bindStrNameDeviceTypeArrayAttr,
llvm::SmallVector<mlir::Attribute> &gangArrayAttr,
llvm::SmallVector<mlir::Attribute> &gangDimArrayAttr,
llvm::SmallVector<mlir::Attribute> &gangDimDeviceTypeArrayAttr,
@ -4427,9 +4451,13 @@ static bool compareDeviceTypeInfo(
for (uint32_t dtypeInt = 0;
dtypeInt != mlir::acc::getMaxEnumValForDeviceType(); ++dtypeInt) {
auto dtype = static_cast<mlir::acc::DeviceType>(dtypeInt);
if (op.getBindNameValue(dtype) !=
getAttributeValueByDeviceType<llvm::StringRef, mlir::StringAttr>(
bindNameArrayAttr, bindNameDeviceTypeArrayAttr, dtype))
auto bindNameValue = getBindNameStringValue(op.getBindNameValue(dtype));
if (bindNameValue !=
getAttributeValueByDeviceType<llvm::StringRef, mlir::StringAttr>(
bindIdNameArrayAttr, bindIdNameDeviceTypeArrayAttr, dtype) &&
bindNameValue !=
getAttributeValueByDeviceType<llvm::StringRef, mlir::StringAttr>(
bindStrNameArrayAttr, bindStrNameDeviceTypeArrayAttr, dtype))
return false;
if (op.hasGang(dtype) != hasDeviceType(gangArrayAttr, dtype))
return false;
@ -4476,8 +4504,10 @@ getArrayAttrOrNull(fir::FirOpBuilder &builder,
void createOpenACCRoutineConstruct(
Fortran::lower::AbstractConverter &converter, mlir::Location loc,
mlir::ModuleOp mod, mlir::func::FuncOp funcOp, std::string funcName,
bool hasNohost, llvm::SmallVector<mlir::Attribute> &bindNames,
llvm::SmallVector<mlir::Attribute> &bindNameDeviceTypes,
bool hasNohost, llvm::SmallVector<mlir::Attribute> &bindIdNames,
llvm::SmallVector<mlir::Attribute> &bindStrNames,
llvm::SmallVector<mlir::Attribute> &bindIdNameDeviceTypes,
llvm::SmallVector<mlir::Attribute> &bindStrNameDeviceTypes,
llvm::SmallVector<mlir::Attribute> &gangDeviceTypes,
llvm::SmallVector<mlir::Attribute> &gangDimValues,
llvm::SmallVector<mlir::Attribute> &gangDimDeviceTypes,
@ -4490,7 +4520,8 @@ void createOpenACCRoutineConstruct(
0) {
// If the routine is already specified with the same clauses, just skip
// the operation creation.
if (compareDeviceTypeInfo(routineOp, bindNames, bindNameDeviceTypes,
if (compareDeviceTypeInfo(routineOp, bindIdNames, bindStrNames,
bindIdNameDeviceTypes, bindStrNameDeviceTypes,
gangDeviceTypes, gangDimValues,
gangDimDeviceTypes, seqDeviceTypes,
workerDeviceTypes, vectorDeviceTypes) &&
@ -4507,8 +4538,10 @@ void createOpenACCRoutineConstruct(
modBuilder.create<mlir::acc::RoutineOp>(
loc, routineOpStr,
mlir::SymbolRefAttr::get(builder.getContext(), funcName),
getArrayAttrOrNull(builder, bindNames),
getArrayAttrOrNull(builder, bindNameDeviceTypes),
getArrayAttrOrNull(builder, bindIdNames),
getArrayAttrOrNull(builder, bindStrNames),
getArrayAttrOrNull(builder, bindIdNameDeviceTypes),
getArrayAttrOrNull(builder, bindStrNameDeviceTypes),
getArrayAttrOrNull(builder, workerDeviceTypes),
getArrayAttrOrNull(builder, vectorDeviceTypes),
getArrayAttrOrNull(builder, seqDeviceTypes), hasNohost,
@ -4525,8 +4558,10 @@ static void interpretRoutineDeviceInfo(
llvm::SmallVector<mlir::Attribute> &seqDeviceTypes,
llvm::SmallVector<mlir::Attribute> &vectorDeviceTypes,
llvm::SmallVector<mlir::Attribute> &workerDeviceTypes,
llvm::SmallVector<mlir::Attribute> &bindNameDeviceTypes,
llvm::SmallVector<mlir::Attribute> &bindNames,
llvm::SmallVector<mlir::Attribute> &bindIdNameDeviceTypes,
llvm::SmallVector<mlir::Attribute> &bindStrNameDeviceTypes,
llvm::SmallVector<mlir::Attribute> &bindIdNames,
llvm::SmallVector<mlir::Attribute> &bindStrNames,
llvm::SmallVector<mlir::Attribute> &gangDeviceTypes,
llvm::SmallVector<mlir::Attribute> &gangDimValues,
llvm::SmallVector<mlir::Attribute> &gangDimDeviceTypes) {
@ -4559,16 +4594,18 @@ static void interpretRoutineDeviceInfo(
if (dinfo.bindNameOpt().has_value()) {
const auto &bindName = dinfo.bindNameOpt().value();
mlir::Attribute bindNameAttr;
if (const auto &bindStr{std::get_if<std::string>(&bindName)}) {
if (const auto &bindSym{
std::get_if<Fortran::semantics::SymbolRef>(&bindName)}) {
bindNameAttr = builder.getSymbolRefAttr(converter.mangleName(*bindSym));
bindIdNames.push_back(bindNameAttr);
bindIdNameDeviceTypes.push_back(getDeviceTypeAttr());
} else if (const auto &bindStr{std::get_if<std::string>(&bindName)}) {
bindNameAttr = builder.getStringAttr(*bindStr);
} else if (const auto &bindSym{
std::get_if<Fortran::semantics::SymbolRef>(&bindName)}) {
bindNameAttr = builder.getStringAttr(converter.mangleName(*bindSym));
bindStrNames.push_back(bindNameAttr);
bindStrNameDeviceTypes.push_back(getDeviceTypeAttr());
} else {
llvm_unreachable("Unsupported bind name type");
}
bindNames.push_back(bindNameAttr);
bindNameDeviceTypes.push_back(getDeviceTypeAttr());
}
}
@ -4584,8 +4621,9 @@ void Fortran::lower::genOpenACCRoutineConstruct(
bool hasNohost{false};
llvm::SmallVector<mlir::Attribute> seqDeviceTypes, vectorDeviceTypes,
workerDeviceTypes, bindNameDeviceTypes, bindNames, gangDeviceTypes,
gangDimDeviceTypes, gangDimValues;
workerDeviceTypes, bindIdNameDeviceTypes, bindStrNameDeviceTypes,
bindIdNames, bindStrNames, gangDeviceTypes, gangDimDeviceTypes,
gangDimValues;
for (const Fortran::semantics::OpenACCRoutineInfo &info : routineInfos) {
// Device Independent Attributes
@ -4594,24 +4632,26 @@ void Fortran::lower::genOpenACCRoutineConstruct(
}
// Note: Device Independent Attributes are set to the
// none device type in `info`.
interpretRoutineDeviceInfo(converter, info, seqDeviceTypes,
vectorDeviceTypes, workerDeviceTypes,
bindNameDeviceTypes, bindNames, gangDeviceTypes,
gangDimValues, gangDimDeviceTypes);
interpretRoutineDeviceInfo(
converter, info, seqDeviceTypes, vectorDeviceTypes, workerDeviceTypes,
bindIdNameDeviceTypes, bindStrNameDeviceTypes, bindIdNames,
bindStrNames, gangDeviceTypes, gangDimValues, gangDimDeviceTypes);
// Device Dependent Attributes
for (const Fortran::semantics::OpenACCRoutineDeviceTypeInfo &dinfo :
info.deviceTypeInfos()) {
interpretRoutineDeviceInfo(
converter, dinfo, seqDeviceTypes, vectorDeviceTypes,
workerDeviceTypes, bindNameDeviceTypes, bindNames, gangDeviceTypes,
gangDimValues, gangDimDeviceTypes);
interpretRoutineDeviceInfo(converter, dinfo, seqDeviceTypes,
vectorDeviceTypes, workerDeviceTypes,
bindIdNameDeviceTypes, bindStrNameDeviceTypes,
bindIdNames, bindStrNames, gangDeviceTypes,
gangDimValues, gangDimDeviceTypes);
}
}
createOpenACCRoutineConstruct(
converter, loc, mod, funcOp, funcName, hasNohost, bindNames,
bindNameDeviceTypes, gangDeviceTypes, gangDimValues, gangDimDeviceTypes,
seqDeviceTypes, workerDeviceTypes, vectorDeviceTypes);
converter, loc, mod, funcOp, funcName, hasNohost, bindIdNames,
bindStrNames, bindIdNameDeviceTypes, bindStrNameDeviceTypes,
gangDeviceTypes, gangDimValues, gangDimDeviceTypes, seqDeviceTypes,
workerDeviceTypes, vectorDeviceTypes);
}
static void

View File

@ -2,13 +2,14 @@
! RUN: bbc -fopenacc -emit-hlfir %s -o - | FileCheck %s
! CHECK: acc.routine @[[r14:.*]] func(@_QPacc_routine19) bind("_QPacc_routine17" [#acc.device_type<host>], "_QPacc_routine17" [#acc.device_type<default>], "_QPacc_routine16" [#acc.device_type<multicore>])
! CHECK: acc.routine @[[r13:.*]] func(@_QPacc_routine18) bind("_QPacc_routine17" [#acc.device_type<host>], "_QPacc_routine16" [#acc.device_type<multicore>])
! CHECK: acc.routine @[[r14:.*]] func(@_QPacc_routine19) bind(@_QPacc_routine17 [#acc.device_type<host>], @_QPacc_routine17
! [#acc.device_type<default>], @_QPacc_routine16 [#acc.device_type<multicore>])
! CHECK: acc.routine @[[r13:.*]] func(@_QPacc_routine18) bind(@_QPacc_routine17 [#acc.device_type<host>], @_QPacc_routine16 [#acc.device_type<multicore>])
! CHECK: acc.routine @[[r12:.*]] func(@_QPacc_routine17) worker ([#acc.device_type<host>]) vector ([#acc.device_type<multicore>])
! CHECK: acc.routine @[[r11:.*]] func(@_QPacc_routine16) gang([#acc.device_type<nvidia>]) seq ([#acc.device_type<host>])
! CHECK: acc.routine @[[r10:.*]] func(@_QPacc_routine11) seq
! CHECK: acc.routine @[[r09:.*]] func(@_QPacc_routine10) seq
! CHECK: acc.routine @[[r08:.*]] func(@_QPacc_routine9) bind("_QPacc_routine9a")
! CHECK: acc.routine @[[r08:.*]] func(@_QPacc_routine9) bind(@_QPacc_routine9a)
! CHECK: acc.routine @[[r07:.*]] func(@_QPacc_routine8) bind("routine8_")
! CHECK: acc.routine @[[r06:.*]] func(@_QPacc_routine7) gang(dim: 1 : i64)
! CHECK: acc.routine @[[r05:.*]] func(@_QPacc_routine6) nohost

View File

@ -30,6 +30,6 @@ end interface
end subroutine
! CHECK: acc.routine @acc_routine_1 func(@_QPsub2) worker nohost
! CHECK: acc.routine @acc_routine_0 func(@_QPsub1) bind("_QPsub2") worker
! CHECK: acc.routine @acc_routine_0 func(@_QPsub1) bind(@_QPsub2) worker
! CHECK: func.func @_QPsub1(%arg0: !fir.box<!fir.array<?xf32>> {fir.bindc_name = "a"}) attributes {acc.routine_info = #acc.routine_info<[@acc_routine_0]>}
! CHECK: func.func @_QPsub2(%arg0: !fir.box<!fir.array<?xf32>> {fir.bindc_name = "a"}) attributes {acc.routine_info = #acc.routine_info<[@acc_routine_1]>}

View File

@ -29,6 +29,7 @@
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/LoopLikeInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include <variant>
#define GET_TYPEDEF_CLASSES
#include "mlir/Dialect/OpenACC/OpenACCOpsTypes.h.inc"

View File

@ -2772,8 +2772,10 @@ def OpenACC_RoutineOp : OpenACC_Op<"routine", [IsolatedFromAbove]> {
}];
let arguments = (ins SymbolNameAttr:$sym_name, SymbolRefAttr:$func_name,
OptionalAttr<StrArrayAttr>:$bindName,
OptionalAttr<DeviceTypeArrayAttr>:$bindNameDeviceType,
OptionalAttr<SymbolRefArrayAttr>:$bindIdName,
OptionalAttr<StrArrayAttr>:$bindStrName,
OptionalAttr<DeviceTypeArrayAttr>:$bindIdNameDeviceType,
OptionalAttr<DeviceTypeArrayAttr>:$bindStrNameDeviceType,
OptionalAttr<DeviceTypeArrayAttr>:$worker,
OptionalAttr<DeviceTypeArrayAttr>:$vector,
OptionalAttr<DeviceTypeArrayAttr>:$seq, UnitAttr:$nohost,
@ -2815,14 +2817,14 @@ def OpenACC_RoutineOp : OpenACC_Op<"routine", [IsolatedFromAbove]> {
std::optional<int64_t> getGangDimValue();
std::optional<int64_t> getGangDimValue(mlir::acc::DeviceType deviceType);
std::optional<llvm::StringRef> getBindNameValue();
std::optional<llvm::StringRef> getBindNameValue(mlir::acc::DeviceType deviceType);
std::optional<::std::variant<mlir::SymbolRefAttr, mlir::StringAttr>> getBindNameValue();
std::optional<::std::variant<mlir::SymbolRefAttr, mlir::StringAttr>> getBindNameValue(mlir::acc::DeviceType deviceType);
}];
let assemblyFormat = [{
$sym_name `func` `(` $func_name `)`
oilist (
`bind` `(` custom<BindName>($bindName, $bindNameDeviceType) `)`
`bind` `(` custom<BindName>($bindIdName, $bindStrName ,$bindIdNameDeviceType, $bindStrNameDeviceType) `)`
| `gang` `` custom<RoutineGangClause>($gang, $gangDim, $gangDimDeviceType)
| `worker` custom<DeviceTypeArrayAttr>($worker)
| `vector` custom<DeviceTypeArrayAttr>($vector)

View File

@ -21,6 +21,7 @@
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/LogicalResult.h"
#include <variant>
using namespace mlir;
using namespace acc;
@ -3461,40 +3462,88 @@ LogicalResult acc::RoutineOp::verify() {
return success();
}
static ParseResult parseBindName(OpAsmParser &parser, mlir::ArrayAttr &bindName,
mlir::ArrayAttr &deviceTypes) {
llvm::SmallVector<mlir::Attribute> bindNameAttrs;
llvm::SmallVector<mlir::Attribute> deviceTypeAttrs;
static ParseResult parseBindName(OpAsmParser &parser,
mlir::ArrayAttr &bindIdName,
mlir::ArrayAttr &bindStrName,
mlir::ArrayAttr &deviceIdTypes,
mlir::ArrayAttr &deviceStrTypes) {
llvm::SmallVector<mlir::Attribute> bindIdNameAttrs;
llvm::SmallVector<mlir::Attribute> bindStrNameAttrs;
llvm::SmallVector<mlir::Attribute> deviceIdTypeAttrs;
llvm::SmallVector<mlir::Attribute> deviceStrTypeAttrs;
if (failed(parser.parseCommaSeparatedList([&]() {
if (parser.parseAttribute(bindNameAttrs.emplace_back()))
mlir::Attribute newAttr;
bool isSymbolRefAttr;
auto parseResult = parser.parseAttribute(newAttr);
if (auto symbolRefAttr = dyn_cast<mlir::SymbolRefAttr>(newAttr)) {
bindIdNameAttrs.push_back(symbolRefAttr);
isSymbolRefAttr = true;
} else if (auto stringAttr = dyn_cast<mlir::StringAttr>(newAttr)) {
bindStrNameAttrs.push_back(stringAttr);
isSymbolRefAttr = false;
}
if (parseResult)
return failure();
if (failed(parser.parseOptionalLSquare())) {
deviceTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
parser.getContext(), mlir::acc::DeviceType::None));
if (isSymbolRefAttr) {
deviceIdTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
parser.getContext(), mlir::acc::DeviceType::None));
} else {
deviceStrTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
parser.getContext(), mlir::acc::DeviceType::None));
}
} else {
if (parser.parseAttribute(deviceTypeAttrs.emplace_back()) ||
parser.parseRSquare())
return failure();
if (isSymbolRefAttr) {
if (parser.parseAttribute(deviceIdTypeAttrs.emplace_back()) ||
parser.parseRSquare())
return failure();
} else {
if (parser.parseAttribute(deviceStrTypeAttrs.emplace_back()) ||
parser.parseRSquare())
return failure();
}
}
return success();
})))
return failure();
bindName = ArrayAttr::get(parser.getContext(), bindNameAttrs);
deviceTypes = ArrayAttr::get(parser.getContext(), deviceTypeAttrs);
bindIdName = ArrayAttr::get(parser.getContext(), bindIdNameAttrs);
bindStrName = ArrayAttr::get(parser.getContext(), bindStrNameAttrs);
deviceIdTypes = ArrayAttr::get(parser.getContext(), deviceIdTypeAttrs);
deviceStrTypes = ArrayAttr::get(parser.getContext(), deviceStrTypeAttrs);
return success();
}
static void printBindName(mlir::OpAsmPrinter &p, mlir::Operation *op,
std::optional<mlir::ArrayAttr> bindName,
std::optional<mlir::ArrayAttr> deviceTypes) {
llvm::interleaveComma(llvm::zip(*bindName, *deviceTypes), p,
[&](const auto &pair) {
p << std::get<0>(pair);
printSingleDeviceType(p, std::get<1>(pair));
});
std::optional<mlir::ArrayAttr> bindIdName,
std::optional<mlir::ArrayAttr> bindStrName,
std::optional<mlir::ArrayAttr> deviceIdTypes,
std::optional<mlir::ArrayAttr> deviceStrTypes) {
// Create combined vectors for all bind names and device types
llvm::SmallVector<mlir::Attribute> allBindNames;
llvm::SmallVector<mlir::Attribute> allDeviceTypes;
// Append bindIdName and deviceIdTypes
if (hasDeviceTypeValues(deviceIdTypes)) {
allBindNames.append(bindIdName->begin(), bindIdName->end());
allDeviceTypes.append(deviceIdTypes->begin(), deviceIdTypes->end());
}
// Append bindStrName and deviceStrTypes
if (hasDeviceTypeValues(deviceStrTypes)) {
allBindNames.append(bindStrName->begin(), bindStrName->end());
allDeviceTypes.append(deviceStrTypes->begin(), deviceStrTypes->end());
}
// Print the combined sequence
if (!allBindNames.empty())
llvm::interleaveComma(llvm::zip(allBindNames, allDeviceTypes), p,
[&](const auto &pair) {
p << std::get<0>(pair);
printSingleDeviceType(p, std::get<1>(pair));
});
}
static ParseResult parseRoutineGangClause(OpAsmParser &parser,
@ -3654,19 +3703,32 @@ bool RoutineOp::hasSeq(mlir::acc::DeviceType deviceType) {
return hasDeviceType(getSeq(), deviceType);
}
std::optional<llvm::StringRef> RoutineOp::getBindNameValue() {
std::optional<std::variant<mlir::SymbolRefAttr, mlir::StringAttr>>
RoutineOp::getBindNameValue() {
return getBindNameValue(mlir::acc::DeviceType::None);
}
std::optional<llvm::StringRef>
std::optional<std::variant<mlir::SymbolRefAttr, mlir::StringAttr>>
RoutineOp::getBindNameValue(mlir::acc::DeviceType deviceType) {
if (!hasDeviceTypeValues(getBindNameDeviceType()))
if (!hasDeviceTypeValues(getBindIdNameDeviceType()) &&
!hasDeviceTypeValues(getBindStrNameDeviceType())) {
return std::nullopt;
if (auto pos = findSegment(*getBindNameDeviceType(), deviceType)) {
auto attr = (*getBindName())[*pos];
auto stringAttr = dyn_cast<mlir::StringAttr>(attr);
return stringAttr.getValue();
}
if (auto pos = findSegment(*getBindIdNameDeviceType(), deviceType)) {
auto attr = (*getBindIdName())[*pos];
auto symbolRefAttr = dyn_cast<mlir::SymbolRefAttr>(attr);
assert(symbolRefAttr && "expected SymbolRef");
return symbolRefAttr;
}
if (auto pos = findSegment(*getBindStrNameDeviceType(), deviceType)) {
auto attr = (*getBindStrName())[*pos];
auto stringAttr = dyn_cast<mlir::StringAttr>(attr);
assert(stringAttr && "expected String");
return stringAttr;
}
return std::nullopt;
}

View File

@ -519,14 +519,44 @@ TEST_F(OpenACCOpsTest, routineOpTest) {
op->removeGangDimDeviceTypeAttr();
op->removeGangDimAttr();
op->setBindNameDeviceTypeAttr(b.getArrayAttr({dtypeNone}));
op->setBindNameAttr(b.getArrayAttr({b.getStringAttr("fname")}));
op->setBindIdNameDeviceTypeAttr(
b.getArrayAttr({DeviceTypeAttr::get(&context, DeviceType::Host)}));
op->setBindStrNameDeviceTypeAttr(b.getArrayAttr({dtypeNone}));
op->setBindIdNameAttr(
b.getArrayAttr({SymbolRefAttr::get(&context, "test_symbol")}));
op->setBindStrNameAttr(b.getArrayAttr({b.getStringAttr("fname")}));
EXPECT_TRUE(op->getBindNameValue().has_value());
EXPECT_EQ(op->getBindNameValue().value(), "fname");
for (auto d : dtypesWithoutNone)
EXPECT_FALSE(op->getBindNameValue(d).has_value());
op->removeBindNameDeviceTypeAttr();
op->removeBindNameAttr();
EXPECT_TRUE(op->getBindNameValue(DeviceType::Host).has_value());
EXPECT_EQ(std::visit(
[](const auto &attr) -> std::string {
if constexpr (std::is_same_v<std::decay_t<decltype(attr)>,
mlir::StringAttr>) {
return attr.str();
} else {
return attr.getLeafReference().str();
}
},
op->getBindNameValue().value()),
"fname");
EXPECT_EQ(std::visit(
[](const auto &attr) -> std::string {
if constexpr (std::is_same_v<std::decay_t<decltype(attr)>,
mlir::StringAttr>) {
return attr.str();
} else {
return attr.getLeafReference().str();
}
},
op->getBindNameValue(DeviceType::Host).value()),
"test_symbol");
for (auto d : dtypesWithoutNone) {
if (d != DeviceType::Host)
EXPECT_FALSE(op->getBindNameValue(d).has_value());
}
op->removeBindIdNameDeviceTypeAttr();
op->removeBindStrNameDeviceTypeAttr();
op->removeBindIdNameAttr();
op->removeBindStrNameAttr();
}
template <typename Op>