[NFC][MLIR][TableGen] Eliminate llvm:: for commonly used types (#112456)
Eliminate `llvm::` namespace qualifier for commonly used types in MLIR TableGen backends to reduce code clutter.
This commit is contained in:
parent
6e02e19cd3
commit
659192b184
@ -17,6 +17,12 @@
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::tblgen;
|
||||
using llvm::DefInit;
|
||||
using llvm::Init;
|
||||
using llvm::ListInit;
|
||||
using llvm::Record;
|
||||
using llvm::RecordVal;
|
||||
using llvm::StringInit;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AttrOrTypeBuilder
|
||||
@ -35,14 +41,13 @@ bool AttrOrTypeBuilder::hasInferredContextParameter() const {
|
||||
// AttrOrTypeDef
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
AttrOrTypeDef::AttrOrTypeDef(const llvm::Record *def) : def(def) {
|
||||
AttrOrTypeDef::AttrOrTypeDef(const Record *def) : def(def) {
|
||||
// Populate the builders.
|
||||
auto *builderList =
|
||||
dyn_cast_or_null<llvm::ListInit>(def->getValueInit("builders"));
|
||||
const auto *builderList =
|
||||
dyn_cast_or_null<ListInit>(def->getValueInit("builders"));
|
||||
if (builderList && !builderList->empty()) {
|
||||
for (const llvm::Init *init : builderList->getValues()) {
|
||||
AttrOrTypeBuilder builder(cast<llvm::DefInit>(init)->getDef(),
|
||||
def->getLoc());
|
||||
for (const Init *init : builderList->getValues()) {
|
||||
AttrOrTypeBuilder builder(cast<DefInit>(init)->getDef(), def->getLoc());
|
||||
|
||||
// Ensure that all parameters have names.
|
||||
for (const AttrOrTypeBuilder::Parameter ¶m :
|
||||
@ -56,16 +61,16 @@ AttrOrTypeDef::AttrOrTypeDef(const llvm::Record *def) : def(def) {
|
||||
|
||||
// Populate the traits.
|
||||
if (auto *traitList = def->getValueAsListInit("traits")) {
|
||||
SmallPtrSet<const llvm::Init *, 32> traitSet;
|
||||
SmallPtrSet<const Init *, 32> traitSet;
|
||||
traits.reserve(traitSet.size());
|
||||
llvm::unique_function<void(const llvm::ListInit *)> processTraitList =
|
||||
[&](const llvm::ListInit *traitList) {
|
||||
llvm::unique_function<void(const ListInit *)> processTraitList =
|
||||
[&](const ListInit *traitList) {
|
||||
for (auto *traitInit : *traitList) {
|
||||
if (!traitSet.insert(traitInit).second)
|
||||
continue;
|
||||
|
||||
// If this is an interface, add any bases to the trait list.
|
||||
auto *traitDef = cast<llvm::DefInit>(traitInit)->getDef();
|
||||
auto *traitDef = cast<DefInit>(traitInit)->getDef();
|
||||
if (traitDef->isSubClassOf("Interface")) {
|
||||
if (auto *bases = traitDef->getValueAsListInit("baseInterfaces"))
|
||||
processTraitList(bases);
|
||||
@ -111,7 +116,7 @@ AttrOrTypeDef::AttrOrTypeDef(const llvm::Record *def) : def(def) {
|
||||
}
|
||||
|
||||
Dialect AttrOrTypeDef::getDialect() const {
|
||||
auto *dialect = dyn_cast<llvm::DefInit>(def->getValue("dialect")->getValue());
|
||||
const auto *dialect = dyn_cast<DefInit>(def->getValue("dialect")->getValue());
|
||||
return Dialect(dialect ? dialect->getDef() : nullptr);
|
||||
}
|
||||
|
||||
@ -126,8 +131,8 @@ StringRef AttrOrTypeDef::getCppBaseClassName() const {
|
||||
}
|
||||
|
||||
bool AttrOrTypeDef::hasDescription() const {
|
||||
const llvm::RecordVal *desc = def->getValue("description");
|
||||
return desc && isa<llvm::StringInit>(desc->getValue());
|
||||
const RecordVal *desc = def->getValue("description");
|
||||
return desc && isa<StringInit>(desc->getValue());
|
||||
}
|
||||
|
||||
StringRef AttrOrTypeDef::getDescription() const {
|
||||
@ -135,8 +140,8 @@ StringRef AttrOrTypeDef::getDescription() const {
|
||||
}
|
||||
|
||||
bool AttrOrTypeDef::hasSummary() const {
|
||||
const llvm::RecordVal *summary = def->getValue("summary");
|
||||
return summary && isa<llvm::StringInit>(summary->getValue());
|
||||
const RecordVal *summary = def->getValue("summary");
|
||||
return summary && isa<StringInit>(summary->getValue());
|
||||
}
|
||||
|
||||
StringRef AttrOrTypeDef::getSummary() const {
|
||||
@ -249,9 +254,9 @@ StringRef TypeDef::getTypeName() const {
|
||||
template <typename InitT>
|
||||
auto AttrOrTypeParameter::getDefValue(StringRef name) const {
|
||||
std::optional<decltype(std::declval<InitT>().getValue())> result;
|
||||
if (auto *param = dyn_cast<llvm::DefInit>(getDef()))
|
||||
if (auto *init = param->getDef()->getValue(name))
|
||||
if (auto *value = dyn_cast_or_null<InitT>(init->getValue()))
|
||||
if (const auto *param = dyn_cast<DefInit>(getDef()))
|
||||
if (const auto *init = param->getDef()->getValue(name))
|
||||
if (const auto *value = dyn_cast_or_null<InitT>(init->getValue()))
|
||||
result = value->getValue();
|
||||
return result;
|
||||
}
|
||||
@ -270,20 +275,20 @@ std::string AttrOrTypeParameter::getAccessorName() const {
|
||||
}
|
||||
|
||||
std::optional<StringRef> AttrOrTypeParameter::getAllocator() const {
|
||||
return getDefValue<llvm::StringInit>("allocator");
|
||||
return getDefValue<StringInit>("allocator");
|
||||
}
|
||||
|
||||
StringRef AttrOrTypeParameter::getComparator() const {
|
||||
return getDefValue<llvm::StringInit>("comparator").value_or("$_lhs == $_rhs");
|
||||
return getDefValue<StringInit>("comparator").value_or("$_lhs == $_rhs");
|
||||
}
|
||||
|
||||
StringRef AttrOrTypeParameter::getCppType() const {
|
||||
if (auto *stringType = dyn_cast<llvm::StringInit>(getDef()))
|
||||
if (auto *stringType = dyn_cast<StringInit>(getDef()))
|
||||
return stringType->getValue();
|
||||
auto cppType = getDefValue<llvm::StringInit>("cppType");
|
||||
auto cppType = getDefValue<StringInit>("cppType");
|
||||
if (cppType)
|
||||
return *cppType;
|
||||
if (auto *init = dyn_cast<llvm::DefInit>(getDef()))
|
||||
if (const auto *init = dyn_cast<DefInit>(getDef()))
|
||||
llvm::PrintFatalError(
|
||||
init->getDef()->getLoc(),
|
||||
Twine("Missing `cppType` field in Attribute/Type parameter: ") +
|
||||
@ -295,34 +300,33 @@ StringRef AttrOrTypeParameter::getCppType() const {
|
||||
}
|
||||
|
||||
StringRef AttrOrTypeParameter::getCppAccessorType() const {
|
||||
return getDefValue<llvm::StringInit>("cppAccessorType")
|
||||
.value_or(getCppType());
|
||||
return getDefValue<StringInit>("cppAccessorType").value_or(getCppType());
|
||||
}
|
||||
|
||||
StringRef AttrOrTypeParameter::getCppStorageType() const {
|
||||
return getDefValue<llvm::StringInit>("cppStorageType").value_or(getCppType());
|
||||
return getDefValue<StringInit>("cppStorageType").value_or(getCppType());
|
||||
}
|
||||
|
||||
StringRef AttrOrTypeParameter::getConvertFromStorage() const {
|
||||
return getDefValue<llvm::StringInit>("convertFromStorage").value_or("$_self");
|
||||
return getDefValue<StringInit>("convertFromStorage").value_or("$_self");
|
||||
}
|
||||
|
||||
std::optional<StringRef> AttrOrTypeParameter::getParser() const {
|
||||
return getDefValue<llvm::StringInit>("parser");
|
||||
return getDefValue<StringInit>("parser");
|
||||
}
|
||||
|
||||
std::optional<StringRef> AttrOrTypeParameter::getPrinter() const {
|
||||
return getDefValue<llvm::StringInit>("printer");
|
||||
return getDefValue<StringInit>("printer");
|
||||
}
|
||||
|
||||
std::optional<StringRef> AttrOrTypeParameter::getSummary() const {
|
||||
return getDefValue<llvm::StringInit>("summary");
|
||||
return getDefValue<StringInit>("summary");
|
||||
}
|
||||
|
||||
StringRef AttrOrTypeParameter::getSyntax() const {
|
||||
if (auto *stringType = dyn_cast<llvm::StringInit>(getDef()))
|
||||
if (auto *stringType = dyn_cast<StringInit>(getDef()))
|
||||
return stringType->getValue();
|
||||
return getDefValue<llvm::StringInit>("syntax").value_or(getCppType());
|
||||
return getDefValue<StringInit>("syntax").value_or(getCppType());
|
||||
}
|
||||
|
||||
bool AttrOrTypeParameter::isOptional() const {
|
||||
@ -330,17 +334,14 @@ bool AttrOrTypeParameter::isOptional() const {
|
||||
}
|
||||
|
||||
std::optional<StringRef> AttrOrTypeParameter::getDefaultValue() const {
|
||||
std::optional<StringRef> result =
|
||||
getDefValue<llvm::StringInit>("defaultValue");
|
||||
std::optional<StringRef> result = getDefValue<StringInit>("defaultValue");
|
||||
return result && !result->empty() ? result : std::nullopt;
|
||||
}
|
||||
|
||||
const llvm::Init *AttrOrTypeParameter::getDef() const {
|
||||
return def->getArg(index);
|
||||
}
|
||||
const Init *AttrOrTypeParameter::getDef() const { return def->getArg(index); }
|
||||
|
||||
std::optional<Constraint> AttrOrTypeParameter::getConstraint() const {
|
||||
if (auto *param = dyn_cast<llvm::DefInit>(getDef()))
|
||||
if (const auto *param = dyn_cast<DefInit>(getDef()))
|
||||
if (param->getDef()->isSubClassOf("Constraint"))
|
||||
return Constraint(param->getDef());
|
||||
return std::nullopt;
|
||||
@ -351,8 +352,8 @@ std::optional<Constraint> AttrOrTypeParameter::getConstraint() const {
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
bool AttributeSelfTypeParameter::classof(const AttrOrTypeParameter *param) {
|
||||
const llvm::Init *paramDef = param->getDef();
|
||||
if (auto *paramDefInit = dyn_cast<llvm::DefInit>(paramDef))
|
||||
const Init *paramDef = param->getDef();
|
||||
if (const auto *paramDefInit = dyn_cast<DefInit>(paramDef))
|
||||
return paramDefInit->getDef()->isSubClassOf("AttributeSelfTypeParameter");
|
||||
return false;
|
||||
}
|
||||
|
||||
@ -71,7 +71,7 @@ StringRef Attribute::getReturnType() const {
|
||||
// Return the type constraint corresponding to the type of this attribute, or
|
||||
// std::nullopt if this is not a TypedAttr.
|
||||
std::optional<Type> Attribute::getValueType() const {
|
||||
if (auto *defInit = dyn_cast<llvm::DefInit>(def->getValueInit("valueType")))
|
||||
if (const auto *defInit = dyn_cast<DefInit>(def->getValueInit("valueType")))
|
||||
return Type(defInit->getDef());
|
||||
return std::nullopt;
|
||||
}
|
||||
@ -92,8 +92,7 @@ StringRef Attribute::getConstBuilderTemplate() const {
|
||||
}
|
||||
|
||||
Attribute Attribute::getBaseAttr() const {
|
||||
if (const auto *defInit =
|
||||
llvm::dyn_cast<llvm::DefInit>(def->getValueInit("baseAttr"))) {
|
||||
if (const auto *defInit = dyn_cast<DefInit>(def->getValueInit("baseAttr"))) {
|
||||
return Attribute(defInit).getBaseAttr();
|
||||
}
|
||||
return *this;
|
||||
@ -132,7 +131,7 @@ Dialect Attribute::getDialect() const {
|
||||
return Dialect(nullptr);
|
||||
}
|
||||
|
||||
const llvm::Record &Attribute::getDef() const { return *def; }
|
||||
const Record &Attribute::getDef() const { return *def; }
|
||||
|
||||
ConstantAttr::ConstantAttr(const DefInit *init) : def(init->getDef()) {
|
||||
assert(def->isSubClassOf("ConstantAttr") &&
|
||||
@ -147,12 +146,12 @@ StringRef ConstantAttr::getConstantValue() const {
|
||||
return def->getValueAsString("value");
|
||||
}
|
||||
|
||||
EnumAttrCase::EnumAttrCase(const llvm::Record *record) : Attribute(record) {
|
||||
EnumAttrCase::EnumAttrCase(const Record *record) : Attribute(record) {
|
||||
assert(isSubClassOf("EnumAttrCaseInfo") &&
|
||||
"must be subclass of TableGen 'EnumAttrInfo' class");
|
||||
}
|
||||
|
||||
EnumAttrCase::EnumAttrCase(const llvm::DefInit *init)
|
||||
EnumAttrCase::EnumAttrCase(const DefInit *init)
|
||||
: EnumAttrCase(init->getDef()) {}
|
||||
|
||||
StringRef EnumAttrCase::getSymbol() const {
|
||||
@ -163,16 +162,16 @@ StringRef EnumAttrCase::getStr() const { return def->getValueAsString("str"); }
|
||||
|
||||
int64_t EnumAttrCase::getValue() const { return def->getValueAsInt("value"); }
|
||||
|
||||
const llvm::Record &EnumAttrCase::getDef() const { return *def; }
|
||||
const Record &EnumAttrCase::getDef() const { return *def; }
|
||||
|
||||
EnumAttr::EnumAttr(const llvm::Record *record) : Attribute(record) {
|
||||
EnumAttr::EnumAttr(const Record *record) : Attribute(record) {
|
||||
assert(isSubClassOf("EnumAttrInfo") &&
|
||||
"must be subclass of TableGen 'EnumAttr' class");
|
||||
}
|
||||
|
||||
EnumAttr::EnumAttr(const llvm::Record &record) : Attribute(&record) {}
|
||||
EnumAttr::EnumAttr(const Record &record) : Attribute(&record) {}
|
||||
|
||||
EnumAttr::EnumAttr(const llvm::DefInit *init) : EnumAttr(init->getDef()) {}
|
||||
EnumAttr::EnumAttr(const DefInit *init) : EnumAttr(init->getDef()) {}
|
||||
|
||||
bool EnumAttr::classof(const Attribute *attr) {
|
||||
return attr->isSubClassOf("EnumAttrInfo");
|
||||
@ -218,8 +217,8 @@ std::vector<EnumAttrCase> EnumAttr::getAllCases() const {
|
||||
std::vector<EnumAttrCase> cases;
|
||||
cases.reserve(inits->size());
|
||||
|
||||
for (const llvm::Init *init : *inits) {
|
||||
cases.emplace_back(cast<llvm::DefInit>(init));
|
||||
for (const Init *init : *inits) {
|
||||
cases.emplace_back(cast<DefInit>(init));
|
||||
}
|
||||
|
||||
return cases;
|
||||
@ -229,7 +228,7 @@ bool EnumAttr::genSpecializedAttr() const {
|
||||
return def->getValueAsBit("genSpecializedAttr");
|
||||
}
|
||||
|
||||
const llvm::Record *EnumAttr::getBaseAttrClass() const {
|
||||
const Record *EnumAttr::getBaseAttrClass() const {
|
||||
return def->getValueAsDef("baseAttrClass");
|
||||
}
|
||||
|
||||
|
||||
@ -12,6 +12,11 @@
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::tblgen;
|
||||
using llvm::DagInit;
|
||||
using llvm::DefInit;
|
||||
using llvm::Init;
|
||||
using llvm::Record;
|
||||
using llvm::StringInit;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Builder::Parameter
|
||||
@ -19,9 +24,9 @@ using namespace mlir::tblgen;
|
||||
|
||||
/// Return a string containing the C++ type of this parameter.
|
||||
StringRef Builder::Parameter::getCppType() const {
|
||||
if (const auto *stringInit = dyn_cast<llvm::StringInit>(def))
|
||||
if (const auto *stringInit = dyn_cast<StringInit>(def))
|
||||
return stringInit->getValue();
|
||||
const llvm::Record *record = cast<llvm::DefInit>(def)->getDef();
|
||||
const Record *record = cast<DefInit>(def)->getDef();
|
||||
// Inlining the first part of `Record::getValueAsString` to give better
|
||||
// error messages.
|
||||
const llvm::RecordVal *type = record->getValue("type");
|
||||
@ -35,9 +40,9 @@ StringRef Builder::Parameter::getCppType() const {
|
||||
/// Return an optional string containing the default value to use for this
|
||||
/// parameter.
|
||||
std::optional<StringRef> Builder::Parameter::getDefaultValue() const {
|
||||
if (isa<llvm::StringInit>(def))
|
||||
if (isa<StringInit>(def))
|
||||
return std::nullopt;
|
||||
const llvm::Record *record = cast<llvm::DefInit>(def)->getDef();
|
||||
const Record *record = cast<DefInit>(def)->getDef();
|
||||
std::optional<StringRef> value =
|
||||
record->getValueAsOptionalString("defaultValue");
|
||||
return value && !value->empty() ? value : std::nullopt;
|
||||
@ -47,18 +52,17 @@ std::optional<StringRef> Builder::Parameter::getDefaultValue() const {
|
||||
// Builder
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
Builder::Builder(const llvm::Record *record, ArrayRef<SMLoc> loc)
|
||||
: def(record) {
|
||||
Builder::Builder(const Record *record, ArrayRef<SMLoc> loc) : def(record) {
|
||||
// Initialize the parameters of the builder.
|
||||
const llvm::DagInit *dag = def->getValueAsDag("dagParams");
|
||||
auto *defInit = dyn_cast<llvm::DefInit>(dag->getOperator());
|
||||
const DagInit *dag = def->getValueAsDag("dagParams");
|
||||
auto *defInit = dyn_cast<DefInit>(dag->getOperator());
|
||||
if (!defInit || defInit->getDef()->getName() != "ins")
|
||||
PrintFatalError(def->getLoc(), "expected 'ins' in builders");
|
||||
|
||||
bool seenDefaultValue = false;
|
||||
for (unsigned i = 0, e = dag->getNumArgs(); i < e; ++i) {
|
||||
const llvm::StringInit *paramName = dag->getArgName(i);
|
||||
const llvm::Init *paramValue = dag->getArg(i);
|
||||
const StringInit *paramName = dag->getArgName(i);
|
||||
const Init *paramValue = dag->getArg(i);
|
||||
Parameter param(paramName ? paramName->getValue()
|
||||
: std::optional<StringRef>(),
|
||||
paramValue);
|
||||
|
||||
@ -24,32 +24,32 @@ using namespace mlir::tblgen;
|
||||
|
||||
/// Generate a unique label based on the current file name to prevent name
|
||||
/// collisions if multiple generated files are included at once.
|
||||
static std::string getUniqueOutputLabel(const llvm::RecordKeeper &records,
|
||||
static std::string getUniqueOutputLabel(const RecordKeeper &records,
|
||||
StringRef tag) {
|
||||
// Use the input file name when generating a unique name.
|
||||
std::string inputFilename = records.getInputFilename();
|
||||
|
||||
// Drop all but the base filename.
|
||||
StringRef nameRef = llvm::sys::path::filename(inputFilename);
|
||||
StringRef nameRef = sys::path::filename(inputFilename);
|
||||
nameRef.consume_back(".td");
|
||||
|
||||
// Sanitize any invalid characters.
|
||||
std::string uniqueName(tag);
|
||||
for (char c : nameRef) {
|
||||
if (llvm::isAlnum(c) || c == '_')
|
||||
if (isAlnum(c) || c == '_')
|
||||
uniqueName.push_back(c);
|
||||
else
|
||||
uniqueName.append(llvm::utohexstr((unsigned char)c));
|
||||
uniqueName.append(utohexstr((unsigned char)c));
|
||||
}
|
||||
return uniqueName;
|
||||
}
|
||||
|
||||
StaticVerifierFunctionEmitter::StaticVerifierFunctionEmitter(
|
||||
raw_ostream &os, const llvm::RecordKeeper &records, StringRef tag)
|
||||
raw_ostream &os, const RecordKeeper &records, StringRef tag)
|
||||
: os(os), uniqueOutputLabel(getUniqueOutputLabel(records, tag)) {}
|
||||
|
||||
void StaticVerifierFunctionEmitter::emitOpConstraints(
|
||||
ArrayRef<const llvm::Record *> opDefs) {
|
||||
ArrayRef<const Record *> opDefs) {
|
||||
NamespaceEmitter namespaceEmitter(os, Operator(*opDefs[0]).getCppNamespace());
|
||||
emitTypeConstraints();
|
||||
emitAttrConstraints();
|
||||
@ -58,7 +58,7 @@ void StaticVerifierFunctionEmitter::emitOpConstraints(
|
||||
}
|
||||
|
||||
void StaticVerifierFunctionEmitter::emitPatternConstraints(
|
||||
const llvm::ArrayRef<DagLeaf> constraints) {
|
||||
const ArrayRef<DagLeaf> constraints) {
|
||||
collectPatternConstraints(constraints);
|
||||
emitPatternConstraints();
|
||||
}
|
||||
@ -298,7 +298,7 @@ void StaticVerifierFunctionEmitter::collectOpConstraints(
|
||||
}
|
||||
|
||||
void StaticVerifierFunctionEmitter::collectPatternConstraints(
|
||||
const llvm::ArrayRef<DagLeaf> constraints) {
|
||||
const ArrayRef<DagLeaf> constraints) {
|
||||
for (auto &leaf : constraints) {
|
||||
assert(leaf.isOperandMatcher() || leaf.isAttrMatcher());
|
||||
collectConstraint(
|
||||
@ -313,7 +313,7 @@ void StaticVerifierFunctionEmitter::collectPatternConstraints(
|
||||
|
||||
std::string mlir::tblgen::escapeString(StringRef value) {
|
||||
std::string ret;
|
||||
llvm::raw_string_ostream os(ret);
|
||||
raw_string_ostream os(ret);
|
||||
os.write_escaped(value);
|
||||
return ret;
|
||||
}
|
||||
|
||||
@ -16,17 +16,22 @@
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::tblgen;
|
||||
using llvm::DagInit;
|
||||
using llvm::DefInit;
|
||||
using llvm::Init;
|
||||
using llvm::ListInit;
|
||||
using llvm::Record;
|
||||
using llvm::StringInit;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// InterfaceMethod
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
InterfaceMethod::InterfaceMethod(const llvm::Record *def) : def(def) {
|
||||
const llvm::DagInit *args = def->getValueAsDag("arguments");
|
||||
InterfaceMethod::InterfaceMethod(const Record *def) : def(def) {
|
||||
const DagInit *args = def->getValueAsDag("arguments");
|
||||
for (unsigned i = 0, e = args->getNumArgs(); i != e; ++i) {
|
||||
arguments.push_back(
|
||||
{llvm::cast<llvm::StringInit>(args->getArg(i))->getValue(),
|
||||
args->getArgNameStr(i)});
|
||||
arguments.push_back({cast<StringInit>(args->getArg(i))->getValue(),
|
||||
args->getArgNameStr(i)});
|
||||
}
|
||||
}
|
||||
|
||||
@ -72,18 +77,17 @@ bool InterfaceMethod::arg_empty() const { return arguments.empty(); }
|
||||
// Interface
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
Interface::Interface(const llvm::Record *def) : def(def) {
|
||||
Interface::Interface(const Record *def) : def(def) {
|
||||
assert(def->isSubClassOf("Interface") &&
|
||||
"must be subclass of TableGen 'Interface' class");
|
||||
|
||||
// Initialize the interface methods.
|
||||
auto *listInit = dyn_cast<llvm::ListInit>(def->getValueInit("methods"));
|
||||
for (const llvm::Init *init : listInit->getValues())
|
||||
methods.emplace_back(cast<llvm::DefInit>(init)->getDef());
|
||||
auto *listInit = dyn_cast<ListInit>(def->getValueInit("methods"));
|
||||
for (const Init *init : listInit->getValues())
|
||||
methods.emplace_back(cast<DefInit>(init)->getDef());
|
||||
|
||||
// Initialize the interface base classes.
|
||||
auto *basesInit =
|
||||
dyn_cast<llvm::ListInit>(def->getValueInit("baseInterfaces"));
|
||||
auto *basesInit = dyn_cast<ListInit>(def->getValueInit("baseInterfaces"));
|
||||
// Chained inheritance will produce duplicates in the base interface set.
|
||||
StringSet<> basesAdded;
|
||||
llvm::unique_function<void(Interface)> addBaseInterfaceFn =
|
||||
@ -98,8 +102,8 @@ Interface::Interface(const llvm::Record *def) : def(def) {
|
||||
baseInterfaces.push_back(std::make_unique<Interface>(baseInterface));
|
||||
basesAdded.insert(baseInterface.getName());
|
||||
};
|
||||
for (const llvm::Init *init : basesInit->getValues())
|
||||
addBaseInterfaceFn(Interface(cast<llvm::DefInit>(init)->getDef()));
|
||||
for (const Init *init : basesInit->getValues())
|
||||
addBaseInterfaceFn(Interface(cast<DefInit>(init)->getDef()));
|
||||
}
|
||||
|
||||
// Return the name of this interface.
|
||||
|
||||
@ -35,9 +35,12 @@ using namespace mlir::tblgen;
|
||||
|
||||
using llvm::DagInit;
|
||||
using llvm::DefInit;
|
||||
using llvm::Init;
|
||||
using llvm::ListInit;
|
||||
using llvm::Record;
|
||||
using llvm::StringInit;
|
||||
|
||||
Operator::Operator(const llvm::Record &def)
|
||||
Operator::Operator(const Record &def)
|
||||
: dialect(def.getValueAsDef("opDialect")), def(def) {
|
||||
// The first `_` in the op's TableGen def name is treated as separating the
|
||||
// dialect prefix and the op class name. The dialect prefix will be ignored if
|
||||
@ -179,7 +182,7 @@ StringRef Operator::getExtraClassDefinition() const {
|
||||
return def.getValueAsString(attr);
|
||||
}
|
||||
|
||||
const llvm::Record &Operator::getDef() const { return def; }
|
||||
const Record &Operator::getDef() const { return def; }
|
||||
|
||||
bool Operator::skipDefaultBuilders() const {
|
||||
return def.getValueAsBit("skipDefaultBuilders");
|
||||
@ -429,7 +432,7 @@ void Operator::populateTypeInferenceInfo(
|
||||
// Use `AllTypesMatch` and `TypesMatchWith` operation traits to build the
|
||||
// result type inference graph.
|
||||
for (const Trait &trait : traits) {
|
||||
const llvm::Record &def = trait.getDef();
|
||||
const Record &def = trait.getDef();
|
||||
|
||||
// If the infer type op interface was manually added, then treat it as
|
||||
// intention that the op needs special handling.
|
||||
@ -614,9 +617,8 @@ void Operator::populateOpStructure() {
|
||||
def.getLoc(),
|
||||
"unsupported attribute modelling, only single class expected");
|
||||
}
|
||||
attributes.push_back(
|
||||
{cast<llvm::StringInit>(val.getNameInit())->getValue(),
|
||||
Attribute(cast<DefInit>(val.getValue()))});
|
||||
attributes.push_back({cast<StringInit>(val.getNameInit())->getValue(),
|
||||
Attribute(cast<DefInit>(val.getValue()))});
|
||||
}
|
||||
}
|
||||
|
||||
@ -701,7 +703,7 @@ void Operator::populateOpStructure() {
|
||||
// tablegen is easy, making them unique less so, so dedupe here.
|
||||
if (auto *traitList = def.getValueAsListInit("traits")) {
|
||||
// This is uniquing based on pointers of the trait.
|
||||
SmallPtrSet<const llvm::Init *, 32> traitSet;
|
||||
SmallPtrSet<const Init *, 32> traitSet;
|
||||
traits.reserve(traitSet.size());
|
||||
|
||||
// The declaration order of traits imply the verification order of traits.
|
||||
@ -721,8 +723,8 @@ void Operator::populateOpStructure() {
|
||||
" to precede it in traits list");
|
||||
};
|
||||
|
||||
std::function<void(const llvm::ListInit *)> insert;
|
||||
insert = [&](const llvm::ListInit *traitList) {
|
||||
std::function<void(const ListInit *)> insert;
|
||||
insert = [&](const ListInit *traitList) {
|
||||
for (auto *traitInit : *traitList) {
|
||||
auto *def = cast<DefInit>(traitInit)->getDef();
|
||||
if (def->isSubClassOf("TraitList")) {
|
||||
@ -777,11 +779,10 @@ void Operator::populateOpStructure() {
|
||||
}
|
||||
|
||||
// Populate the builders.
|
||||
auto *builderList =
|
||||
dyn_cast_or_null<llvm::ListInit>(def.getValueInit("builders"));
|
||||
auto *builderList = dyn_cast_or_null<ListInit>(def.getValueInit("builders"));
|
||||
if (builderList && !builderList->empty()) {
|
||||
for (const llvm::Init *init : builderList->getValues())
|
||||
builders.emplace_back(cast<llvm::DefInit>(init)->getDef(), def.getLoc());
|
||||
for (const Init *init : builderList->getValues())
|
||||
builders.emplace_back(cast<DefInit>(init)->getDef(), def.getLoc());
|
||||
} else if (skipDefaultBuilders()) {
|
||||
PrintFatalError(
|
||||
def.getLoc(),
|
||||
@ -814,13 +815,12 @@ StringRef Operator::getSummary() const {
|
||||
|
||||
bool Operator::hasAssemblyFormat() const {
|
||||
auto *valueInit = def.getValueInit("assemblyFormat");
|
||||
return isa<llvm::StringInit>(valueInit);
|
||||
return isa<StringInit>(valueInit);
|
||||
}
|
||||
|
||||
StringRef Operator::getAssemblyFormat() const {
|
||||
return TypeSwitch<const llvm::Init *, StringRef>(
|
||||
def.getValueInit("assemblyFormat"))
|
||||
.Case<llvm::StringInit>([&](auto *init) { return init->getValue(); });
|
||||
return TypeSwitch<const Init *, StringRef>(def.getValueInit("assemblyFormat"))
|
||||
.Case<StringInit>([&](auto *init) { return init->getValue(); });
|
||||
}
|
||||
|
||||
void Operator::print(llvm::raw_ostream &os) const {
|
||||
@ -833,9 +833,9 @@ void Operator::print(llvm::raw_ostream &os) const {
|
||||
}
|
||||
}
|
||||
|
||||
auto Operator::VariableDecoratorIterator::unwrap(const llvm::Init *init)
|
||||
auto Operator::VariableDecoratorIterator::unwrap(const Init *init)
|
||||
-> VariableDecorator {
|
||||
return VariableDecorator(cast<llvm::DefInit>(init)->getDef());
|
||||
return VariableDecorator(cast<DefInit>(init)->getDef());
|
||||
}
|
||||
|
||||
auto Operator::getArgToOperandOrAttribute(int index) const
|
||||
|
||||
@ -26,7 +26,12 @@
|
||||
using namespace mlir;
|
||||
using namespace tblgen;
|
||||
|
||||
using llvm::DagInit;
|
||||
using llvm::dbgs;
|
||||
using llvm::DefInit;
|
||||
using llvm::formatv;
|
||||
using llvm::IntInit;
|
||||
using llvm::Record;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// DagLeaf
|
||||
@ -61,31 +66,31 @@ bool DagLeaf::isStringAttr() const { return isa<llvm::StringInit>(def); }
|
||||
Constraint DagLeaf::getAsConstraint() const {
|
||||
assert((isOperandMatcher() || isAttrMatcher()) &&
|
||||
"the DAG leaf must be operand or attribute");
|
||||
return Constraint(cast<llvm::DefInit>(def)->getDef());
|
||||
return Constraint(cast<DefInit>(def)->getDef());
|
||||
}
|
||||
|
||||
ConstantAttr DagLeaf::getAsConstantAttr() const {
|
||||
assert(isConstantAttr() && "the DAG leaf must be constant attribute");
|
||||
return ConstantAttr(cast<llvm::DefInit>(def));
|
||||
return ConstantAttr(cast<DefInit>(def));
|
||||
}
|
||||
|
||||
EnumAttrCase DagLeaf::getAsEnumAttrCase() const {
|
||||
assert(isEnumAttrCase() && "the DAG leaf must be an enum attribute case");
|
||||
return EnumAttrCase(cast<llvm::DefInit>(def));
|
||||
return EnumAttrCase(cast<DefInit>(def));
|
||||
}
|
||||
|
||||
std::string DagLeaf::getConditionTemplate() const {
|
||||
return getAsConstraint().getConditionTemplate();
|
||||
}
|
||||
|
||||
llvm::StringRef DagLeaf::getNativeCodeTemplate() const {
|
||||
StringRef DagLeaf::getNativeCodeTemplate() const {
|
||||
assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall");
|
||||
return cast<llvm::DefInit>(def)->getDef()->getValueAsString("expression");
|
||||
return cast<DefInit>(def)->getDef()->getValueAsString("expression");
|
||||
}
|
||||
|
||||
int DagLeaf::getNumReturnsOfNativeCode() const {
|
||||
assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall");
|
||||
return cast<llvm::DefInit>(def)->getDef()->getValueAsInt("numReturns");
|
||||
return cast<DefInit>(def)->getDef()->getValueAsInt("numReturns");
|
||||
}
|
||||
|
||||
std::string DagLeaf::getStringAttr() const {
|
||||
@ -93,7 +98,7 @@ std::string DagLeaf::getStringAttr() const {
|
||||
return def->getAsUnquotedString();
|
||||
}
|
||||
bool DagLeaf::isSubClassOf(StringRef superclass) const {
|
||||
if (auto *defInit = dyn_cast_or_null<llvm::DefInit>(def))
|
||||
if (auto *defInit = dyn_cast_or_null<DefInit>(def))
|
||||
return defInit->getDef()->isSubClassOf(superclass);
|
||||
return false;
|
||||
}
|
||||
@ -108,7 +113,7 @@ void DagLeaf::print(raw_ostream &os) const {
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
bool DagNode::isNativeCodeCall() const {
|
||||
if (auto *defInit = dyn_cast_or_null<llvm::DefInit>(node->getOperator()))
|
||||
if (auto *defInit = dyn_cast_or_null<DefInit>(node->getOperator()))
|
||||
return defInit->getDef()->isSubClassOf("NativeCodeCall");
|
||||
return false;
|
||||
}
|
||||
@ -119,25 +124,24 @@ bool DagNode::isOperation() const {
|
||||
!isVariadic();
|
||||
}
|
||||
|
||||
llvm::StringRef DagNode::getNativeCodeTemplate() const {
|
||||
StringRef DagNode::getNativeCodeTemplate() const {
|
||||
assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall");
|
||||
return cast<llvm::DefInit>(node->getOperator())
|
||||
return cast<DefInit>(node->getOperator())
|
||||
->getDef()
|
||||
->getValueAsString("expression");
|
||||
}
|
||||
|
||||
int DagNode::getNumReturnsOfNativeCode() const {
|
||||
assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall");
|
||||
return cast<llvm::DefInit>(node->getOperator())
|
||||
return cast<DefInit>(node->getOperator())
|
||||
->getDef()
|
||||
->getValueAsInt("numReturns");
|
||||
}
|
||||
|
||||
llvm::StringRef DagNode::getSymbol() const { return node->getNameStr(); }
|
||||
StringRef DagNode::getSymbol() const { return node->getNameStr(); }
|
||||
|
||||
Operator &DagNode::getDialectOp(RecordOperatorMap *mapper) const {
|
||||
const llvm::Record *opDef =
|
||||
cast<llvm::DefInit>(node->getOperator())->getDef();
|
||||
const Record *opDef = cast<DefInit>(node->getOperator())->getDef();
|
||||
auto [it, inserted] = mapper->try_emplace(opDef);
|
||||
if (inserted)
|
||||
it->second = std::make_unique<Operator>(opDef);
|
||||
@ -158,11 +162,11 @@ int DagNode::getNumOps() const {
|
||||
int DagNode::getNumArgs() const { return node->getNumArgs(); }
|
||||
|
||||
bool DagNode::isNestedDagArg(unsigned index) const {
|
||||
return isa<llvm::DagInit>(node->getArg(index));
|
||||
return isa<DagInit>(node->getArg(index));
|
||||
}
|
||||
|
||||
DagNode DagNode::getArgAsNestedDag(unsigned index) const {
|
||||
return DagNode(dyn_cast_or_null<llvm::DagInit>(node->getArg(index)));
|
||||
return DagNode(dyn_cast_or_null<DagInit>(node->getArg(index)));
|
||||
}
|
||||
|
||||
DagLeaf DagNode::getArgAsLeaf(unsigned index) const {
|
||||
@ -175,27 +179,27 @@ StringRef DagNode::getArgName(unsigned index) const {
|
||||
}
|
||||
|
||||
bool DagNode::isReplaceWithValue() const {
|
||||
auto *dagOpDef = cast<llvm::DefInit>(node->getOperator())->getDef();
|
||||
auto *dagOpDef = cast<DefInit>(node->getOperator())->getDef();
|
||||
return dagOpDef->getName() == "replaceWithValue";
|
||||
}
|
||||
|
||||
bool DagNode::isLocationDirective() const {
|
||||
auto *dagOpDef = cast<llvm::DefInit>(node->getOperator())->getDef();
|
||||
auto *dagOpDef = cast<DefInit>(node->getOperator())->getDef();
|
||||
return dagOpDef->getName() == "location";
|
||||
}
|
||||
|
||||
bool DagNode::isReturnTypeDirective() const {
|
||||
auto *dagOpDef = cast<llvm::DefInit>(node->getOperator())->getDef();
|
||||
auto *dagOpDef = cast<DefInit>(node->getOperator())->getDef();
|
||||
return dagOpDef->getName() == "returnType";
|
||||
}
|
||||
|
||||
bool DagNode::isEither() const {
|
||||
auto *dagOpDef = cast<llvm::DefInit>(node->getOperator())->getDef();
|
||||
auto *dagOpDef = cast<DefInit>(node->getOperator())->getDef();
|
||||
return dagOpDef->getName() == "either";
|
||||
}
|
||||
|
||||
bool DagNode::isVariadic() const {
|
||||
auto *dagOpDef = cast<llvm::DefInit>(node->getOperator())->getDef();
|
||||
auto *dagOpDef = cast<DefInit>(node->getOperator())->getDef();
|
||||
return dagOpDef->getName() == "variadic";
|
||||
}
|
||||
|
||||
@ -246,7 +250,7 @@ std::string SymbolInfoMap::SymbolInfo::getVarName(StringRef name) const {
|
||||
}
|
||||
|
||||
std::string SymbolInfoMap::SymbolInfo::getVarTypeStr(StringRef name) const {
|
||||
LLVM_DEBUG(llvm::dbgs() << "getVarTypeStr for '" << name << "': ");
|
||||
LLVM_DEBUG(dbgs() << "getVarTypeStr for '" << name << "': ");
|
||||
switch (kind) {
|
||||
case Kind::Attr: {
|
||||
if (op)
|
||||
@ -277,26 +281,26 @@ std::string SymbolInfoMap::SymbolInfo::getVarTypeStr(StringRef name) const {
|
||||
}
|
||||
|
||||
std::string SymbolInfoMap::SymbolInfo::getVarDecl(StringRef name) const {
|
||||
LLVM_DEBUG(llvm::dbgs() << "getVarDecl for '" << name << "': ");
|
||||
LLVM_DEBUG(dbgs() << "getVarDecl for '" << name << "': ");
|
||||
std::string varInit = kind == Kind::Operand ? "(op0->getOperands())" : "";
|
||||
return std::string(
|
||||
formatv("{0} {1}{2};\n", getVarTypeStr(name), getVarName(name), varInit));
|
||||
}
|
||||
|
||||
std::string SymbolInfoMap::SymbolInfo::getArgDecl(StringRef name) const {
|
||||
LLVM_DEBUG(llvm::dbgs() << "getArgDecl for '" << name << "': ");
|
||||
LLVM_DEBUG(dbgs() << "getArgDecl for '" << name << "': ");
|
||||
return std::string(
|
||||
formatv("{0} &{1}", getVarTypeStr(name), getVarName(name)));
|
||||
}
|
||||
|
||||
std::string SymbolInfoMap::SymbolInfo::getValueAndRangeUse(
|
||||
StringRef name, int index, const char *fmt, const char *separator) const {
|
||||
LLVM_DEBUG(llvm::dbgs() << "getValueAndRangeUse for '" << name << "': ");
|
||||
LLVM_DEBUG(dbgs() << "getValueAndRangeUse for '" << name << "': ");
|
||||
switch (kind) {
|
||||
case Kind::Attr: {
|
||||
assert(index < 0);
|
||||
auto repl = formatv(fmt, name);
|
||||
LLVM_DEBUG(llvm::dbgs() << repl << " (Attr)\n");
|
||||
LLVM_DEBUG(dbgs() << repl << " (Attr)\n");
|
||||
return std::string(repl);
|
||||
}
|
||||
case Kind::Operand: {
|
||||
@ -307,11 +311,11 @@ std::string SymbolInfoMap::SymbolInfo::getValueAndRangeUse(
|
||||
// the value itself.
|
||||
if (operand->isVariableLength() && !getVariadicSubIndex().has_value()) {
|
||||
auto repl = formatv(fmt, name);
|
||||
LLVM_DEBUG(llvm::dbgs() << repl << " (VariadicOperand)\n");
|
||||
LLVM_DEBUG(dbgs() << repl << " (VariadicOperand)\n");
|
||||
return std::string(repl);
|
||||
}
|
||||
auto repl = formatv(fmt, formatv("(*{0}.begin())", name));
|
||||
LLVM_DEBUG(llvm::dbgs() << repl << " (SingleOperand)\n");
|
||||
LLVM_DEBUG(dbgs() << repl << " (SingleOperand)\n");
|
||||
return std::string(repl);
|
||||
}
|
||||
case Kind::Result: {
|
||||
@ -323,14 +327,14 @@ std::string SymbolInfoMap::SymbolInfo::getValueAndRangeUse(
|
||||
if (!op->getResult(index).isVariadic())
|
||||
v = std::string(formatv("(*{0}.begin())", v));
|
||||
auto repl = formatv(fmt, v);
|
||||
LLVM_DEBUG(llvm::dbgs() << repl << " (SingleResult)\n");
|
||||
LLVM_DEBUG(dbgs() << repl << " (SingleResult)\n");
|
||||
return std::string(repl);
|
||||
}
|
||||
|
||||
// If this op has no result at all but still we bind a symbol to it, it
|
||||
// means we want to capture the op itself.
|
||||
if (op->getNumResults() == 0) {
|
||||
LLVM_DEBUG(llvm::dbgs() << name << " (Op)\n");
|
||||
LLVM_DEBUG(dbgs() << name << " (Op)\n");
|
||||
return formatv(fmt, name);
|
||||
}
|
||||
|
||||
@ -347,14 +351,14 @@ std::string SymbolInfoMap::SymbolInfo::getValueAndRangeUse(
|
||||
values.push_back(std::string(formatv(fmt, v)));
|
||||
}
|
||||
auto repl = llvm::join(values, separator);
|
||||
LLVM_DEBUG(llvm::dbgs() << repl << " (VariadicResult)\n");
|
||||
LLVM_DEBUG(dbgs() << repl << " (VariadicResult)\n");
|
||||
return repl;
|
||||
}
|
||||
case Kind::Value: {
|
||||
assert(index < 0);
|
||||
assert(op == nullptr);
|
||||
auto repl = formatv(fmt, name);
|
||||
LLVM_DEBUG(llvm::dbgs() << repl << " (Value)\n");
|
||||
LLVM_DEBUG(dbgs() << repl << " (Value)\n");
|
||||
return std::string(repl);
|
||||
}
|
||||
case Kind::MultipleValues: {
|
||||
@ -363,13 +367,13 @@ std::string SymbolInfoMap::SymbolInfo::getValueAndRangeUse(
|
||||
if (index >= 0) {
|
||||
std::string repl =
|
||||
formatv(fmt, std::string(formatv("{0}[{1}]", name, index)));
|
||||
LLVM_DEBUG(llvm::dbgs() << repl << " (MultipleValues)\n");
|
||||
LLVM_DEBUG(dbgs() << repl << " (MultipleValues)\n");
|
||||
return repl;
|
||||
}
|
||||
// If it doesn't specify certain element, unpack them all.
|
||||
auto repl =
|
||||
formatv(fmt, std::string(formatv("{0}.begin(), {0}.end()", name)));
|
||||
LLVM_DEBUG(llvm::dbgs() << repl << " (MultipleValues)\n");
|
||||
LLVM_DEBUG(dbgs() << repl << " (MultipleValues)\n");
|
||||
return std::string(repl);
|
||||
}
|
||||
}
|
||||
@ -378,19 +382,19 @@ std::string SymbolInfoMap::SymbolInfo::getValueAndRangeUse(
|
||||
|
||||
std::string SymbolInfoMap::SymbolInfo::getAllRangeUse(
|
||||
StringRef name, int index, const char *fmt, const char *separator) const {
|
||||
LLVM_DEBUG(llvm::dbgs() << "getAllRangeUse for '" << name << "': ");
|
||||
LLVM_DEBUG(dbgs() << "getAllRangeUse for '" << name << "': ");
|
||||
switch (kind) {
|
||||
case Kind::Attr:
|
||||
case Kind::Operand: {
|
||||
assert(index < 0 && "only allowed for symbol bound to result");
|
||||
auto repl = formatv(fmt, name);
|
||||
LLVM_DEBUG(llvm::dbgs() << repl << " (Operand/Attr)\n");
|
||||
LLVM_DEBUG(dbgs() << repl << " (Operand/Attr)\n");
|
||||
return std::string(repl);
|
||||
}
|
||||
case Kind::Result: {
|
||||
if (index >= 0) {
|
||||
auto repl = formatv(fmt, formatv("{0}.getODSResults({1})", name, index));
|
||||
LLVM_DEBUG(llvm::dbgs() << repl << " (SingleResult)\n");
|
||||
LLVM_DEBUG(dbgs() << repl << " (SingleResult)\n");
|
||||
return std::string(repl);
|
||||
}
|
||||
|
||||
@ -404,14 +408,14 @@ std::string SymbolInfoMap::SymbolInfo::getAllRangeUse(
|
||||
formatv(fmt, formatv("{0}.getODSResults({1})", name, i))));
|
||||
}
|
||||
auto repl = llvm::join(values, separator);
|
||||
LLVM_DEBUG(llvm::dbgs() << repl << " (VariadicResult)\n");
|
||||
LLVM_DEBUG(dbgs() << repl << " (VariadicResult)\n");
|
||||
return repl;
|
||||
}
|
||||
case Kind::Value: {
|
||||
assert(index < 0 && "only allowed for symbol bound to result");
|
||||
assert(op == nullptr);
|
||||
auto repl = formatv(fmt, formatv("{{{0}}", name));
|
||||
LLVM_DEBUG(llvm::dbgs() << repl << " (Value)\n");
|
||||
LLVM_DEBUG(dbgs() << repl << " (Value)\n");
|
||||
return std::string(repl);
|
||||
}
|
||||
case Kind::MultipleValues: {
|
||||
@ -420,12 +424,12 @@ std::string SymbolInfoMap::SymbolInfo::getAllRangeUse(
|
||||
if (index >= 0) {
|
||||
std::string repl =
|
||||
formatv(fmt, std::string(formatv("{0}[{1}]", name, index)));
|
||||
LLVM_DEBUG(llvm::dbgs() << repl << " (MultipleValues)\n");
|
||||
LLVM_DEBUG(dbgs() << repl << " (MultipleValues)\n");
|
||||
return repl;
|
||||
}
|
||||
auto repl =
|
||||
formatv(fmt, std::string(formatv("{0}.begin(), {0}.end()", name)));
|
||||
LLVM_DEBUG(llvm::dbgs() << repl << " (MultipleValues)\n");
|
||||
LLVM_DEBUG(dbgs() << repl << " (MultipleValues)\n");
|
||||
return std::string(repl);
|
||||
}
|
||||
}
|
||||
@ -614,7 +618,7 @@ void SymbolInfoMap::assignUniqueAlternativeNames() {
|
||||
// Pattern
|
||||
//==----------------------------------------------------------------------===//
|
||||
|
||||
Pattern::Pattern(const llvm::Record *def, RecordOperatorMap *mapper)
|
||||
Pattern::Pattern(const Record *def, RecordOperatorMap *mapper)
|
||||
: def(*def), recordOpMap(mapper) {}
|
||||
|
||||
DagNode Pattern::getSourcePattern() const {
|
||||
@ -628,26 +632,26 @@ int Pattern::getNumResultPatterns() const {
|
||||
|
||||
DagNode Pattern::getResultPattern(unsigned index) const {
|
||||
auto *results = def.getValueAsListInit("resultPatterns");
|
||||
return DagNode(cast<llvm::DagInit>(results->getElement(index)));
|
||||
return DagNode(cast<DagInit>(results->getElement(index)));
|
||||
}
|
||||
|
||||
void Pattern::collectSourcePatternBoundSymbols(SymbolInfoMap &infoMap) {
|
||||
LLVM_DEBUG(llvm::dbgs() << "start collecting source pattern bound symbols\n");
|
||||
LLVM_DEBUG(dbgs() << "start collecting source pattern bound symbols\n");
|
||||
collectBoundSymbols(getSourcePattern(), infoMap, /*isSrcPattern=*/true);
|
||||
LLVM_DEBUG(llvm::dbgs() << "done collecting source pattern bound symbols\n");
|
||||
LLVM_DEBUG(dbgs() << "done collecting source pattern bound symbols\n");
|
||||
|
||||
LLVM_DEBUG(llvm::dbgs() << "start assigning alternative names for symbols\n");
|
||||
LLVM_DEBUG(dbgs() << "start assigning alternative names for symbols\n");
|
||||
infoMap.assignUniqueAlternativeNames();
|
||||
LLVM_DEBUG(llvm::dbgs() << "done assigning alternative names for symbols\n");
|
||||
LLVM_DEBUG(dbgs() << "done assigning alternative names for symbols\n");
|
||||
}
|
||||
|
||||
void Pattern::collectResultPatternBoundSymbols(SymbolInfoMap &infoMap) {
|
||||
LLVM_DEBUG(llvm::dbgs() << "start collecting result pattern bound symbols\n");
|
||||
LLVM_DEBUG(dbgs() << "start collecting result pattern bound symbols\n");
|
||||
for (int i = 0, e = getNumResultPatterns(); i < e; ++i) {
|
||||
auto pattern = getResultPattern(i);
|
||||
collectBoundSymbols(pattern, infoMap, /*isSrcPattern=*/false);
|
||||
}
|
||||
LLVM_DEBUG(llvm::dbgs() << "done collecting result pattern bound symbols\n");
|
||||
LLVM_DEBUG(dbgs() << "done collecting result pattern bound symbols\n");
|
||||
}
|
||||
|
||||
const Operator &Pattern::getSourceRootOp() {
|
||||
@ -664,7 +668,7 @@ std::vector<AppliedConstraint> Pattern::getConstraints() const {
|
||||
ret.reserve(listInit->size());
|
||||
|
||||
for (auto *it : *listInit) {
|
||||
auto *dagInit = dyn_cast<llvm::DagInit>(it);
|
||||
auto *dagInit = dyn_cast<DagInit>(it);
|
||||
if (!dagInit)
|
||||
PrintFatalError(&def, "all elements in Pattern multi-entity "
|
||||
"constraints should be DAG nodes");
|
||||
@ -680,7 +684,7 @@ std::vector<AppliedConstraint> Pattern::getConstraints() const {
|
||||
entities.emplace_back(argName->getValue());
|
||||
}
|
||||
|
||||
ret.emplace_back(cast<llvm::DefInit>(dagInit->getOperator())->getDef(),
|
||||
ret.emplace_back(cast<DefInit>(dagInit->getOperator())->getDef(),
|
||||
dagInit->getNameStr(), std::move(entities));
|
||||
}
|
||||
return ret;
|
||||
@ -693,19 +697,19 @@ int Pattern::getNumSupplementalPatterns() const {
|
||||
|
||||
DagNode Pattern::getSupplementalPattern(unsigned index) const {
|
||||
auto *results = def.getValueAsListInit("supplementalPatterns");
|
||||
return DagNode(cast<llvm::DagInit>(results->getElement(index)));
|
||||
return DagNode(cast<DagInit>(results->getElement(index)));
|
||||
}
|
||||
|
||||
int Pattern::getBenefit() const {
|
||||
// The initial benefit value is a heuristic with number of ops in the source
|
||||
// pattern.
|
||||
int initBenefit = getSourcePattern().getNumOps();
|
||||
const llvm::DagInit *delta = def.getValueAsDag("benefitDelta");
|
||||
if (delta->getNumArgs() != 1 || !isa<llvm::IntInit>(delta->getArg(0))) {
|
||||
const DagInit *delta = def.getValueAsDag("benefitDelta");
|
||||
if (delta->getNumArgs() != 1 || !isa<IntInit>(delta->getArg(0))) {
|
||||
PrintFatalError(&def,
|
||||
"The 'addBenefit' takes and only takes one integer value");
|
||||
}
|
||||
return initBenefit + dyn_cast<llvm::IntInit>(delta->getArg(0))->getValue();
|
||||
return initBenefit + dyn_cast<IntInit>(delta->getArg(0))->getValue();
|
||||
}
|
||||
|
||||
std::vector<Pattern::IdentifierLine> Pattern::getLocation() const {
|
||||
@ -736,8 +740,8 @@ void Pattern::collectBoundSymbols(DagNode tree, SymbolInfoMap &infoMap,
|
||||
if (tree.isNativeCodeCall()) {
|
||||
if (!treeName.empty()) {
|
||||
if (!isSrcPattern) {
|
||||
LLVM_DEBUG(llvm::dbgs() << "found symbol bound to NativeCodeCall: "
|
||||
<< treeName << '\n');
|
||||
LLVM_DEBUG(dbgs() << "found symbol bound to NativeCodeCall: "
|
||||
<< treeName << '\n');
|
||||
verifyBind(
|
||||
infoMap.bindValues(treeName, tree.getNumReturnsOfNativeCode()),
|
||||
treeName);
|
||||
@ -820,8 +824,8 @@ void Pattern::collectBoundSymbols(DagNode tree, SymbolInfoMap &infoMap,
|
||||
// The name attached to the DAG node's operator is for representing the
|
||||
// results generated from this op. It should be remembered as bound results.
|
||||
if (!treeName.empty()) {
|
||||
LLVM_DEBUG(llvm::dbgs()
|
||||
<< "found symbol bound to op result: " << treeName << '\n');
|
||||
LLVM_DEBUG(dbgs() << "found symbol bound to op result: " << treeName
|
||||
<< '\n');
|
||||
verifyBind(infoMap.bindOpResult(treeName, op), treeName);
|
||||
}
|
||||
|
||||
@ -896,8 +900,8 @@ void Pattern::collectBoundSymbols(DagNode tree, SymbolInfoMap &infoMap,
|
||||
auto treeArgName = tree.getArgName(i);
|
||||
// `$_` is a special symbol meaning ignore the current argument.
|
||||
if (!treeArgName.empty() && treeArgName != "_") {
|
||||
LLVM_DEBUG(llvm::dbgs() << "found symbol bound to op argument: "
|
||||
<< treeArgName << '\n');
|
||||
LLVM_DEBUG(dbgs() << "found symbol bound to op argument: "
|
||||
<< treeArgName << '\n');
|
||||
verifyBind(infoMap.bindOpArgument(tree, treeArgName, op, opArgIdx),
|
||||
treeArgName);
|
||||
}
|
||||
|
||||
@ -20,15 +20,18 @@
|
||||
|
||||
using namespace mlir;
|
||||
using namespace tblgen;
|
||||
using llvm::Init;
|
||||
using llvm::Record;
|
||||
using llvm::SpecificBumpPtrAllocator;
|
||||
|
||||
// Construct a Predicate from a record.
|
||||
Pred::Pred(const llvm::Record *record) : def(record) {
|
||||
Pred::Pred(const Record *record) : def(record) {
|
||||
assert(def->isSubClassOf("Pred") &&
|
||||
"must be a subclass of TableGen 'Pred' class");
|
||||
}
|
||||
|
||||
// Construct a Predicate from an initializer.
|
||||
Pred::Pred(const llvm::Init *init) {
|
||||
Pred::Pred(const Init *init) {
|
||||
if (const auto *defInit = dyn_cast_or_null<llvm::DefInit>(init))
|
||||
def = defInit->getDef();
|
||||
}
|
||||
@ -48,12 +51,12 @@ bool Pred::isCombined() const {
|
||||
|
||||
ArrayRef<SMLoc> Pred::getLoc() const { return def->getLoc(); }
|
||||
|
||||
CPred::CPred(const llvm::Record *record) : Pred(record) {
|
||||
CPred::CPred(const Record *record) : Pred(record) {
|
||||
assert(def->isSubClassOf("CPred") &&
|
||||
"must be a subclass of Tablegen 'CPred' class");
|
||||
}
|
||||
|
||||
CPred::CPred(const llvm::Init *init) : Pred(init) {
|
||||
CPred::CPred(const Init *init) : Pred(init) {
|
||||
assert((!def || def->isSubClassOf("CPred")) &&
|
||||
"must be a subclass of Tablegen 'CPred' class");
|
||||
}
|
||||
@ -64,22 +67,22 @@ std::string CPred::getConditionImpl() const {
|
||||
return std::string(def->getValueAsString("predExpr"));
|
||||
}
|
||||
|
||||
CombinedPred::CombinedPred(const llvm::Record *record) : Pred(record) {
|
||||
CombinedPred::CombinedPred(const Record *record) : Pred(record) {
|
||||
assert(def->isSubClassOf("CombinedPred") &&
|
||||
"must be a subclass of Tablegen 'CombinedPred' class");
|
||||
}
|
||||
|
||||
CombinedPred::CombinedPred(const llvm::Init *init) : Pred(init) {
|
||||
CombinedPred::CombinedPred(const Init *init) : Pred(init) {
|
||||
assert((!def || def->isSubClassOf("CombinedPred")) &&
|
||||
"must be a subclass of Tablegen 'CombinedPred' class");
|
||||
}
|
||||
|
||||
const llvm::Record *CombinedPred::getCombinerDef() const {
|
||||
const Record *CombinedPred::getCombinerDef() const {
|
||||
assert(def->getValue("kind") && "CombinedPred must have a value 'kind'");
|
||||
return def->getValueAsDef("kind");
|
||||
}
|
||||
|
||||
std::vector<const llvm::Record *> CombinedPred::getChildren() const {
|
||||
std::vector<const Record *> CombinedPred::getChildren() const {
|
||||
assert(def->getValue("children") &&
|
||||
"CombinedPred must have a value 'children'");
|
||||
return def->getValueAsListOfDefs("children");
|
||||
@ -156,7 +159,7 @@ static void performSubstitutions(std::string &str,
|
||||
// All nodes are created within "allocator".
|
||||
static PredNode *
|
||||
buildPredicateTree(const Pred &root,
|
||||
llvm::SpecificBumpPtrAllocator<PredNode> &allocator,
|
||||
SpecificBumpPtrAllocator<PredNode> &allocator,
|
||||
ArrayRef<Subst> substitutions) {
|
||||
auto *rootNode = allocator.Allocate();
|
||||
new (rootNode) PredNode;
|
||||
@ -351,7 +354,7 @@ static std::string getCombinedCondition(const PredNode &root) {
|
||||
}
|
||||
|
||||
std::string CombinedPred::getConditionImpl() const {
|
||||
llvm::SpecificBumpPtrAllocator<PredNode> allocator;
|
||||
SpecificBumpPtrAllocator<PredNode> allocator;
|
||||
auto *predicateTree = buildPredicateTree(*this, allocator, {});
|
||||
predicateTree =
|
||||
propagateGroundTruth(predicateTree,
|
||||
|
||||
@ -18,6 +18,7 @@
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::tblgen;
|
||||
using llvm::Record;
|
||||
|
||||
TypeConstraint::TypeConstraint(const llvm::DefInit *init)
|
||||
: TypeConstraint(init->getDef()) {}
|
||||
@ -42,7 +43,7 @@ StringRef TypeConstraint::getVariadicOfVariadicSegmentSizeAttr() const {
|
||||
// Returns the builder call for this constraint if this is a buildable type,
|
||||
// returns std::nullopt otherwise.
|
||||
std::optional<StringRef> TypeConstraint::getBuilderCall() const {
|
||||
const llvm::Record *baseType = def;
|
||||
const Record *baseType = def;
|
||||
if (isVariableLength())
|
||||
baseType = baseType->getValueAsDef("baseType");
|
||||
|
||||
@ -64,7 +65,7 @@ StringRef TypeConstraint::getCppType() const {
|
||||
return def->getValueAsString("cppType");
|
||||
}
|
||||
|
||||
Type::Type(const llvm::Record *record) : TypeConstraint(record) {}
|
||||
Type::Type(const Record *record) : TypeConstraint(record) {}
|
||||
|
||||
Dialect Type::getDialect() const {
|
||||
return Dialect(def->getValueAsDef("dialect"));
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user