- Change OpClass new method addition to find and eliminate any existing methods that are made redundant by the newly added method, as well as detect if the newly added method will be redundant and return nullptr in that case. - To facilitate that, add the notion of resolved and unresolved parameters, where resolved parameters have each parameter type known, so that redundancy checks on methods with same name but different parameter types can be done. - Eliminate existing code to avoid adding conflicting/redundant build methods and rely on this new mechanism to eliminate conflicting build methods. Fixes https://bugs.llvm.org/show_bug.cgi?id=47095 Differential Revision: https://reviews.llvm.org/D87059
336 lines
11 KiB
C++
336 lines
11 KiB
C++
//===- OpClass.cpp - Helper classes for Op C++ code emission --------------===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/TableGen/OpClass.h"
|
|
|
|
#include "mlir/TableGen/Format.h"
|
|
#include "llvm/ADT/Sequence.h"
|
|
#include "llvm/ADT/Twine.h"
|
|
#include "llvm/Support/Debug.h"
|
|
#include "llvm/Support/raw_ostream.h"
|
|
#include <unordered_set>
|
|
|
|
#define DEBUG_TYPE "mlir-tblgen-opclass"
|
|
|
|
using namespace mlir;
|
|
using namespace mlir::tblgen;
|
|
|
|
namespace {
|
|
|
|
// Returns space to be emitted after the given C++ `type`. return "" if the
|
|
// ends with '&' or '*', or is empty, else returns " ".
|
|
StringRef getSpaceAfterType(StringRef type) {
|
|
return (type.empty() || type.endswith("&") || type.endswith("*")) ? "" : " ";
|
|
}
|
|
|
|
} // namespace
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// OpMethodParameter definitions
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// Writes this parameter to `os` as its type followed by its name, separated
// by whatever `getSpaceAfterType` yields for the type. An `/*optional*/`
// marker is emitted first when the PP_Optional property bit is set. The
// ` = <defaultValue>` suffix is written only when `emitDefault` is true and
// the default value is non-empty.
void OpMethodParameter::writeTo(raw_ostream &os, bool emitDefault) const {
  if (properties & PP_Optional) {
    os << "/*optional*/";
  }

  os << type;
  os << getSpaceAfterType(type);
  os << name;

  // Nothing more to print unless a default was requested and one exists.
  if (!emitDefault || defaultValue.empty())
    return;
  os << " = " << defaultValue;
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// OpMethodParameters definitions
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// Factory methods to construct the correct type of `OpMethodParameters`
|
|
// object based on the arguments.
|
|
std::unique_ptr<OpMethodParameters> OpMethodParameters::create() {
|
|
return std::make_unique<OpMethodResolvedParameters>();
|
|
}
|
|
|
|
std::unique_ptr<OpMethodParameters>
|
|
OpMethodParameters::create(StringRef params) {
|
|
return std::make_unique<OpMethodUnresolvedParameters>(params);
|
|
}
|
|
|
|
std::unique_ptr<OpMethodParameters>
|
|
OpMethodParameters::create(llvm::SmallVectorImpl<OpMethodParameter> &¶ms) {
|
|
return std::make_unique<OpMethodResolvedParameters>(std::move(params));
|
|
}
|
|
|
|
std::unique_ptr<OpMethodParameters>
|
|
OpMethodParameters::create(StringRef type, StringRef name,
|
|
StringRef defaultValue) {
|
|
return std::make_unique<OpMethodResolvedParameters>(type, name, defaultValue);
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// OpMethodUnresolvedParameters definitions
|
|
//===----------------------------------------------------------------------===//
|
|
// Writes the parameter list for a method declaration. The stored parameter
// string is emitted verbatim.
void OpMethodUnresolvedParameters::writeDeclTo(raw_ostream &os) const {
  StringRef declText = parameters;
  os << declText;
}
|
|
|
|
// Writes the parameter list for a method definition, i.e. the stored
// parameter string with default-value initializers stripped.
//
// The string is scanned repeatedly: everything before the next '=' is kept,
// and everything from that '=' through the following ',' is dropped.
// TODO: Using '=' and ',' as delimiters is incorrect for an initializer list
// with more than one element. Change to a more robust approach.
void OpMethodUnresolvedParameters::writeDefTo(raw_ostream &os) const {
  llvm::SmallVector<StringRef, 4> keptPieces;
  for (StringRef remaining = parameters; !remaining.empty();) {
    // Text before '=' is kept as-is; the rest holds the initializer.
    auto aroundEquals = remaining.split("=");
    keptPieces.push_back(aroundEquals.first);
    // Skip the initializer by resuming after the next ','.
    auto aroundComma = aroundEquals.second.split(',');
    remaining = aroundComma.second;
  }
  llvm::interleaveComma(keptPieces, os,
                        [&](StringRef piece) { os << piece; });
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// OpMethodResolvedParameters definitions
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// Returns true if a method with these parameters makes a method with parameters
|
|
// `other` redundant. This should return true only if all possible calls to the
|
|
// other method can be replaced by calls to this method.
|
|
bool OpMethodResolvedParameters::makesRedundant(
|
|
const OpMethodResolvedParameters &other) const {
|
|
const size_t otherNumParams = other.getNumParameters();
|
|
const size_t thisNumParams = getNumParameters();
|
|
|
|
// All calls to the other method can be replaced this method only if this
|
|
// method has the same or more arguments number of arguments as the other, and
|
|
// the common arguments have the same type.
|
|
if (thisNumParams < otherNumParams)
|
|
return false;
|
|
for (int idx : llvm::seq<int>(0, otherNumParams))
|
|
if (parameters[idx].getType() != other.parameters[idx].getType())
|
|
return false;
|
|
|
|
// If all the common arguments have the same type, we can elide the other
|
|
// method if this method has the same number of arguments as other or the
|
|
// first argument after the common ones has a default value (and by C++
|
|
// requirement, all the later ones will also have a default value).
|
|
return thisNumParams == otherNumParams ||
|
|
parameters[otherNumParams].hasDefaultValue();
|
|
}
|
|
|
|
// Writes the parameter list for a method declaration: each parameter is
// written through its own `writeDeclTo`, and the parameters are joined with
// `llvm::interleaveComma`.
void OpMethodResolvedParameters::writeDeclTo(raw_ostream &os) const {
  llvm::interleaveComma(parameters, os, [&](const OpMethodParameter &param) {
    param.writeDeclTo(os);
  });
}
|
|
|
|
// Writes the parameter list for a method definition: each parameter is
// written through its own `writeDefTo`, and the parameters are joined with
// `llvm::interleaveComma`.
void OpMethodResolvedParameters::writeDefTo(raw_ostream &os) const {
  llvm::interleaveComma(parameters, os, [&](const OpMethodParameter &param) {
    param.writeDefTo(os);
  });
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// OpMethodSignature definitions
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// Returns if a method with this signature makes a method with `other` signature
|
|
// redundant. Only supports resolved parameters.
|
|
bool OpMethodSignature::makesRedundant(const OpMethodSignature &other) const {
|
|
if (methodName != other.methodName)
|
|
return false;
|
|
auto *resolvedThis = dyn_cast<OpMethodResolvedParameters>(parameters.get());
|
|
auto *resolvedOther =
|
|
dyn_cast<OpMethodResolvedParameters>(other.parameters.get());
|
|
if (resolvedThis && resolvedOther)
|
|
return resolvedThis->makesRedundant(*resolvedOther);
|
|
return false;
|
|
}
|
|
|
|
// Writes the signature as it appears in a declaration: the return type, the
// method name, and the parenthesized parameter list in its declaration form.
void OpMethodSignature::writeDeclTo(raw_ostream &os) const {
  os << returnType;
  os << getSpaceAfterType(returnType);
  os << methodName;

  os << "(";
  parameters->writeDeclTo(os);
  os << ")";
}
|
|
|
|
// Writes the signature as it appears in a definition: the return type, the
// method name qualified as `<namePrefix>::` when `namePrefix` is non-empty,
// and the parenthesized parameter list in its definition form.
void OpMethodSignature::writeDefTo(raw_ostream &os,
                                   StringRef namePrefix) const {
  os << returnType << getSpaceAfterType(returnType);

  // Only emit the scope qualifier when there is a prefix to qualify with.
  if (!namePrefix.empty())
    os << namePrefix << "::";
  os << methodName;

  os << "(";
  parameters->writeDefTo(os);
  os << ")";
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// OpMethodBody definitions
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// Constructs a method body. The body is effective, i.e. the `operator<<`
// overloads actually append content, only when `declOnly` is false.
OpMethodBody::OpMethodBody(bool declOnly)
    : isEffective(!declOnly) {}
|
|
|
|
// Appends the string form of `content` to the body when the body is
// effective; otherwise does nothing. Returns `*this` to allow chaining.
OpMethodBody &OpMethodBody::operator<<(Twine content) {
  if (!isEffective)
    return *this;

  body.append(content.str());
  return *this;
}
|
|
|
|
// Appends the decimal text of `content` (via `std::to_string`) to the body
// when the body is effective; otherwise does nothing. Returns `*this` to
// allow chaining.
OpMethodBody &OpMethodBody::operator<<(int content) {
  if (!isEffective)
    return *this;

  body.append(std::to_string(content));
  return *this;
}
|
|
|
|
// Appends the rendered string of the format object `content` to the body when
// the body is effective; otherwise does nothing. Returns `*this` to allow
// chaining.
OpMethodBody &OpMethodBody::operator<<(const FmtObjectBase &content) {
  if (!isEffective)
    return *this;

  body.append(content.str());
  return *this;
}
|
|
|
|
// Writes the accumulated body to `os` with all leading newline characters
// dropped, and makes sure the output ends with a newline: one is added when
// the trimmed body is empty or does not already end in '\n'.
void OpMethodBody::writeTo(raw_ostream &os) const {
  const auto isNewline = [](char c) { return c == '\n'; };
  StringRef trimmed = StringRef(body).drop_while(isNewline);
  os << trimmed;

  const bool endsWithNewline = !trimmed.empty() && trimmed.back() == '\n';
  if (!endsWithNewline)
    os << "\n";
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// OpMethod definitions
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// Writes the declaration of this method to `os`: two spaces of indentation,
// a `static ` qualifier when the method is static, the signature in its
// declaration form, and a terminating ';' (no trailing newline).
void OpMethod::writeDeclTo(raw_ostream &os) const {
  os.indent(2) << (isStatic() ? "static " : "");
  methodSignature.writeDeclTo(os);
  os << ";";
}
|
|
|
|
// Writes the definition of this method to `os`: the signature in its
// definition form (qualified with `namePrefix`), followed by the body wrapped
// in braces. Nothing is written for a declaration-only method.
void OpMethod::writeDefTo(raw_ostream &os, StringRef namePrefix) const {
  if (!(properties & MP_Declaration)) {
    methodSignature.writeDefTo(os, namePrefix);

    os << " {\n";
    methodBody.writeTo(os);
    os << "}";
  }
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// OpConstructor definitions
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// Appends `name(value)` to the constructor's member initializer list. The
// first initializer added is introduced with " : "; each later one is
// separated from the previous with ", ".
void OpConstructor::addMemberInitializer(StringRef name, StringRef value) {
  const char *separator = memberInitializers.empty() ? " : " : ", ";
  memberInitializers.append(
      std::string(llvm::formatv("{0}{1}({2})", separator, name, value)));
}
|
|
|
|
// Writes the definition of this constructor to `os`: the signature in its
// definition form (qualified with `namePrefix`), the accumulated member
// initializers, and the body wrapped in braces. Nothing is written for a
// declaration-only constructor.
void OpConstructor::writeDefTo(raw_ostream &os, StringRef namePrefix) const {
  if (!(properties & MP_Declaration)) {
    methodSignature.writeDefTo(os, namePrefix);

    os << " ";
    os << memberInitializers;
    os << " {\n";

    methodBody.writeTo(os);
    os << "}";
  }
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Class definitions
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// Constructs a class description whose class name is `name`.
Class::Class(StringRef name)
    : className(name) {}
|
|
|
|
// Records a new field for this class. The stored text is `<type> <name>`,
// extended with ` = <defaultValue>` when `defaultValue` is non-empty.
void Class::newField(StringRef type, StringRef name, StringRef defaultValue) {
  std::string fieldText = formatv("{0} {1}", type, name).str();
  if (!defaultValue.empty())
    fieldText = formatv("{0} = {1}", fieldText, defaultValue).str();
  fields.push_back(std::move(fieldText));
}
|
|
// Writes the full class declaration to `os`: a `public:` section with the
// declarations of all non-private methods, then a `private:` section (always
// emitted) with the declarations of the private methods, if any, followed by
// the recorded fields.
void Class::writeDeclTo(raw_ostream &os) const {
  os << "class " << className << " {\n";
  os << "public:\n";

  // First pass: emit the public methods and note whether a private one exists.
  bool sawPrivateMethod = false;
  forAllMethods([&](const OpMethod &method) {
    if (method.isPrivate()) {
      sawPrivateMethod = true;
      return;
    }
    method.writeDeclTo(os);
    os << '\n';
  });

  os << '\n';
  os << "private:\n";

  // Second pass, only when needed: emit the private methods.
  if (sawPrivateMethod) {
    forAllMethods([&](const OpMethod &method) {
      if (!method.isPrivate())
        return;
      method.writeDeclTo(os);
      os << '\n';
    });
    os << '\n';
  }

  for (const auto &fieldText : fields) {
    os.indent(2);
    os << fieldText << ";\n";
  }
  os << "};\n";
}
|
|
|
|
// Writes the definitions of all methods of this class to `os`, each one
// qualified with the class name and followed by a blank line.
void Class::writeDefTo(raw_ostream &os) const {
  const auto emitDefinition = [&](const OpMethod &method) {
    method.writeDefTo(os, className);
    os << "\n\n";
  };
  forAllMethods(emitDefinition);
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// OpClass definitions
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// Constructs an op class description named `name`. `extraClassDeclaration`
// is stored for later emission in the class declaration.
OpClass::OpClass(StringRef name, StringRef extraClassDeclaration)
    : Class(name),
      extraClassDeclaration(extraClassDeclaration) {}
|
|
|
|
// Adds `trait` to this op class. A trait that was already added is ignored;
// new traits are appended to `traitsVec`.
void OpClass::addTrait(Twine trait) {
  std::string traitName = trait.str();

  // The set insertion reports whether this trait is seen for the first time.
  const bool isNewTrait = traitsSet.insert(traitName).second;
  if (!isNewTrait)
    return;
  traitsVec.push_back(std::move(traitName));
}
|
|
|
|
// Writes the op class declaration to `os`: the class header deriving from
// `::mlir::Op<...>` with all added traits, the `Op::Op` and `Adaptor` using
// declarations, the declarations of all non-private methods, the extra class
// declaration text when present, and finally a `private:` section with the
// private method declarations when at least one exists.
void OpClass::writeDeclTo(raw_ostream &os) const {
  // Class header, listing the class itself and then every trait.
  os << "class " << className;
  os << " : public ::mlir::Op<" << className;
  for (const auto &traitName : traitsVec) {
    os << ", ";
    os << traitName;
  }
  os << "> {\npublic:\n";
  os << " using Op::Op;\n";
  os << " using Adaptor = " << className << "Adaptor;\n";

  // First pass: emit the public methods and note whether a private one exists.
  bool sawPrivateMethod = false;
  forAllMethods([&](const OpMethod &method) {
    if (method.isPrivate()) {
      sawPrivateMethod = true;
      return;
    }
    method.writeDeclTo(os);
    os << "\n";
  });

  // TODO: Add line control markers to make errors easier to debug.
  if (!extraClassDeclaration.empty()) {
    os << extraClassDeclaration;
    os << "\n";
  }

  // Second pass, only when needed: emit the private methods.
  if (sawPrivateMethod) {
    os << "\nprivate:\n";
    forAllMethods([&](const OpMethod &method) {
      if (!method.isPrivate())
        return;
      method.writeDeclTo(os);
      os << "\n";
    });
  }

  os << "};\n";
}
|