llvm-project/mlir/lib/Target/Cpp/TranslateToCpp.cpp
Kazu Hirata 1a0f482de8
[mlir] Remove unused includes (NFC) (#150476)
These are identified by misc-include-cleaner.  I've filtered out those
that break builds.  Also, I'm staying away from llvm-config.h,
config.h, and Compiler.h, which likely cause platform- or
compiler-specific build failures.
2025-07-24 11:23:53 -07:00

1902 lines
62 KiB
C++

//===- TranslateToCpp.cpp - Translating to C++ calls ----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/EmitC/IR/EmitC.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/Support/IndentedOstream.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Target/Cpp/CppEmitter.h"
#include "llvm/ADT/ScopedHashTable.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/FormatVariadic.h"
#include <stack>
#define DEBUG_TYPE "translate-to-cpp"
using namespace mlir;
using namespace mlir::emitc;
using llvm::formatv;
/// Convenience functions to produce interleaved output with callbacks that
/// return a LogicalResult. These differ from the ones in STLExtras in that the
/// function applied to each element does not return a string.
///
/// Invokes `eachFn` on every element of [begin, end) and `betweenFn` between
/// two consecutive elements. Returns failure as soon as `eachFn` fails and
/// success otherwise (including for an empty range).
template <typename ForwardIterator, typename UnaryFunctor,
          typename NullaryFunctor>
inline LogicalResult
interleaveWithError(ForwardIterator begin, ForwardIterator end,
                    UnaryFunctor eachFn, NullaryFunctor betweenFn) {
  bool isFirstElement = true;
  for (; begin != end; ++begin) {
    // The separator is only emitted in front of the second and later elements.
    if (!isFirstElement)
      betweenFn();
    isFirstElement = false;
    if (failed(eachFn(*begin)))
      return failure();
  }
  return success();
}
/// Container overload: forwards the range [c.begin(), c.end()) to the
/// iterator-based interleaveWithError.
template <typename Container, typename UnaryFunctor, typename NullaryFunctor>
inline LogicalResult interleaveWithError(const Container &c,
                                         UnaryFunctor eachFn,
                                         NullaryFunctor betweenFn) {
  auto first = c.begin();
  auto last = c.end();
  return interleaveWithError(first, last, eachFn, betweenFn);
}
/// Applies `eachFn` to every element of `c`, writing ", " to `os` between
/// consecutive elements. Returns failure as soon as `eachFn` fails.
template <typename Container, typename UnaryFunctor>
inline LogicalResult interleaveCommaWithError(const Container &c,
                                              raw_ostream &os,
                                              UnaryFunctor eachFn) {
  auto emitSeparator = [&os]() { os << ", "; };
  return interleaveWithError(c.begin(), c.end(), eachFn, emitSeparator);
}
/// Return the precedence of an operator as an integer; higher values imply
/// higher precedence. Operations sharing a precedence level are listed
/// together. Emits an error on the operation and returns failure if the
/// operation (or its cmp predicate) is not supported.
static FailureOr<int> getOperatorPrecedence(Operation *operation) {
  return llvm::TypeSwitch<Operation *, FailureOr<int>>(operation)
      .Case<emitc::CallOp, emitc::CallOpaqueOp, emitc::LoadOp>(
          [](auto) { return 16; })
      .Case<emitc::ApplyOp, emitc::BitwiseNotOp, emitc::CastOp,
            emitc::LogicalNotOp, emitc::UnaryMinusOp, emitc::UnaryPlusOp>(
          [](auto) { return 15; })
      .Case<emitc::DivOp, emitc::MulOp, emitc::RemOp>([](auto) { return 13; })
      .Case<emitc::AddOp, emitc::SubOp>([](auto) { return 12; })
      .Case<emitc::BitwiseLeftShiftOp, emitc::BitwiseRightShiftOp>(
          [](auto) { return 11; })
      // The precedence of a comparison depends on its predicate.
      .Case<emitc::CmpOp>([](auto cmpOp) -> FailureOr<int> {
        switch (cmpOp.getPredicate()) {
        case emitc::CmpPredicate::three_way:
          return 10;
        case emitc::CmpPredicate::lt:
        case emitc::CmpPredicate::le:
        case emitc::CmpPredicate::gt:
        case emitc::CmpPredicate::ge:
          return 9;
        case emitc::CmpPredicate::eq:
        case emitc::CmpPredicate::ne:
          return 8;
        }
        return cmpOp->emitError("unsupported cmp predicate");
      })
      .Case<emitc::BitwiseAndOp>([](auto) { return 7; })
      .Case<emitc::BitwiseXorOp>([](auto) { return 6; })
      .Case<emitc::BitwiseOrOp>([](auto) { return 5; })
      .Case<emitc::LogicalAndOp>([](auto) { return 4; })
      .Case<emitc::LogicalOrOp>([](auto) { return 3; })
      .Case<emitc::ConditionalOp>([](auto) { return 2; })
      .Default([](auto op) { return op->emitError("unsupported operation"); });
}
namespace {
/// Emitter that uses dialect specific emitters to emit C++ code.
///
/// Holds the output stream, the value-to-name and block-to-label mappings
/// (both scoped), and the state of the expression currently being emitted.
struct CppEmitter {
explicit CppEmitter(raw_ostream &os, bool declareVariablesAtTop,
StringRef fileId);
/// Emits attribute or returns failure.
LogicalResult emitAttribute(Location loc, Attribute attr);
/// Emits operation 'op' with/without trailing semicolon or returns failure.
///
/// For operations that should never be followed by a semicolon, like ForOp,
/// the `trailingSemicolon` argument is ignored and a semicolon is not
/// emitted.
LogicalResult emitOperation(Operation &op, bool trailingSemicolon);
/// Emits type 'type' or returns failure.
LogicalResult emitType(Location loc, Type type);
/// Emits array of types as a std::tuple of the emitted types.
/// - emits void for an empty array;
/// - emits the type of the only element for arrays of size one;
/// - emits a std::tuple otherwise;
LogicalResult emitTypes(Location loc, ArrayRef<Type> types);
/// Emits array of types as a std::tuple of the emitted types independently of
/// the array size.
LogicalResult emitTupleType(Location loc, ArrayRef<Type> types);
/// Emits an assignment for a variable which has been declared previously.
LogicalResult emitVariableAssignment(OpResult result);
/// Emits a variable declaration for a result of an operation.
LogicalResult emitVariableDeclaration(OpResult result,
bool trailingSemicolon);
/// Emits a declaration of a variable with the given type and name.
LogicalResult emitVariableDeclaration(Location loc, Type type,
StringRef name);
/// Emits the variable declaration and assignment prefix for 'op'.
/// - emits separate variable followed by std::tie for multi-valued operation;
/// - emits single type followed by variable for single result;
/// - emits nothing if no value produced by op;
/// Emits final '=' operator where a type is produced. Returns failure if
/// any result type could not be converted.
LogicalResult emitAssignPrefix(Operation &op);
/// Emits a global variable declaration or definition.
LogicalResult emitGlobalVariable(GlobalOp op);
/// Emits a label for the block.
LogicalResult emitLabel(Block &block);
/// Emits the operands and attributes of the operation. All operands are
/// emitted first and then all attributes in alphabetical order.
LogicalResult emitOperandsAndAttributes(Operation &op,
ArrayRef<StringRef> exclude = {});
/// Emits the operands of the operation. All operands are emitted in order.
LogicalResult emitOperands(Operation &op);
/// Emits a value as an operand of an operation.
LogicalResult emitOperand(Value value);
/// Emit an expression as a C expression.
LogicalResult emitExpression(ExpressionOp expressionOp);
/// Insert the expression representing the operation into the value cache.
void cacheDeferredOpResult(Value value, StringRef str);
/// Return the existing or a new name for a Value.
StringRef getOrCreateName(Value val);
/// Return the existing or a new name for a loop induction variable of an
/// emitc::ForOp.
StringRef getOrCreateInductionVarName(Value val);
/// Returns the textual representation of a subscript operation.
std::string getSubscriptName(emitc::SubscriptOp op);
/// Returns the textual representation of a member (of object) operation.
std::string createMemberAccess(emitc::MemberOp op);
/// Returns the textual representation of a member of pointer operation.
std::string createMemberAccess(emitc::MemberOfPtrOp op);
/// Return the existing or a new label of a Block.
StringRef getOrCreateName(Block &block);
/// Whether to map an mlir integer to an unsigned integer in C++.
bool shouldMapToUnsigned(IntegerType::SignednessSemantics val);
/// Abstract RAII helper to manage entering/exiting C++ scopes.
///
/// Construction opens a new scope in both the value mapper and the block
/// mapper and pushes a copy of the current top of `labelInScopeCount`;
/// destruction pops that entry again. The constructor is protected, so only
/// derived scope types can be instantiated.
struct Scope {
~Scope() { emitter.labelInScopeCount.pop(); }
private:
llvm::ScopedHashTableScope<Value, std::string> valueMapperScope;
llvm::ScopedHashTableScope<Block *, std::string> blockMapperScope;
protected:
Scope(CppEmitter &emitter)
: valueMapperScope(emitter.valueMapper),
blockMapperScope(emitter.blockMapper), emitter(emitter) {
// The nested scope starts out with the count of the enclosing scope.
emitter.labelInScopeCount.push(emitter.labelInScopeCount.top());
}
CppEmitter &emitter;
};
/// RAII helper to manage entering/exiting functions, while re-using
/// value names.
struct FunctionScope : Scope {
FunctionScope(CppEmitter &emitter) : Scope(emitter) {
// Re-use value names.
emitter.resetValueCounter();
}
};
/// RAII helper to manage entering/exiting emitc::forOp loops and
/// handle induction variable naming. Increases the emitter's loop nesting
/// level on construction and decreases it on destruction.
struct LoopScope : Scope {
LoopScope(CppEmitter &emitter) : Scope(emitter) {
emitter.increaseLoopNestingLevel();
}
~LoopScope() { emitter.decreaseLoopNestingLevel(); }
};
/// Returns whether the Value is assigned to a C++ variable in the scope.
bool hasValueInScope(Value val);
/// Returns whether a label is assigned to the block.
bool hasBlockLabel(Block &block);
/// Returns the output stream.
raw_indented_ostream &ostream() { return os; };
/// Returns if all variables for op results and basic block arguments need to
/// be declared at the beginning of a function.
bool shouldDeclareVariablesAtTop() { return declareVariablesAtTop; };
/// Returns whether this file op should be emitted: the emitter must have a
/// non-empty file id and the id of `file` must be equal to it.
bool shouldEmitFile(FileOp file) {
return !fileId.empty() && file.getId() == fileId;
}
/// Get expression currently being emitted.
ExpressionOp getEmittedExpression() { return emittedExpression; }
/// Determine whether given value is part of the expression potentially being
/// emitted. Returns false if no expression is being emitted or if the value
/// has no defining op (e.g. a block argument).
bool isPartOfCurrentExpression(Value value) {
if (!emittedExpression)
return false;
Operation *def = value.getDefiningOp();
if (!def)
return false;
// The value belongs to the expression if the op defining it is nested
// directly in the expression op being emitted.
auto operandExpression = dyn_cast<ExpressionOp>(def->getParentOp());
return operandExpression == emittedExpression;
};
/// Resets the value counter to 0.
void resetValueCounter();
/// Increases the loop nesting level by 1.
void increaseLoopNestingLevel();
/// Decreases the loop nesting level by 1.
void decreaseLoopNestingLevel();
private:
using ValueMapper = llvm::ScopedHashTable<Value, std::string>;
using BlockMapper = llvm::ScopedHashTable<Block *, std::string>;
/// Output stream to emit to.
raw_indented_ostream os;
/// Boolean to enforce that all variables for op results and block
/// arguments are declared at the beginning of the function. This also
/// includes results from ops located in nested regions.
bool declareVariablesAtTop;
/// Only emit file ops whose id matches this value.
std::string fileId;
/// Map from value to the name of the C++ variable that holds it.
ValueMapper valueMapper;
/// Map from block to name of C++ label.
BlockMapper blockMapper;
/// Default values representing outermost scope.
llvm::ScopedHashTableScope<Value, std::string> defaultValueMapperScope;
llvm::ScopedHashTableScope<Block *, std::string> defaultBlockMapperScope;
/// Stack with one entry per open Scope; every Scope pushes a copy of the
/// current top on entry and pops it on exit.
/// NOTE(review): presumably the number of labels created in the scope;
/// the code using the value is not part of this excerpt.
std::stack<int64_t> labelInScopeCount;
/// Keeps track of the amount of nested loops the emitter currently operates
/// in.
uint64_t loopNestingLevel{0};
/// Emitter-level count of created values to enable unique identifiers.
unsigned int valueCount{0};
/// State of the current expression being emitted.
ExpressionOp emittedExpression;
/// Stack of operator precedences maintained while emitting an expression;
/// the innermost entry is at the back.
SmallVector<int> emittedExpressionPrecedence;
/// Pushes `precedence` onto the precedence stack.
void pushExpressionPrecedence(int precedence) {
emittedExpressionPrecedence.push_back(precedence);
}
/// Pops the most recently pushed precedence.
void popExpressionPrecedence() { emittedExpressionPrecedence.pop_back(); }
/// Precedence value used when the precedence stack is empty.
static int lowestPrecedence() { return 0; }
/// Returns the precedence on top of the stack, or lowestPrecedence() if the
/// stack is empty.
int getExpressionPrecedence() {
if (emittedExpressionPrecedence.empty())
return lowestPrecedence();
return emittedExpressionPrecedence.back();
}
};
} // namespace
/// Determine whether expression \p op should be emitted in a deferred way.
/// A null \p op is never deferred.
static bool hasDeferredEmission(Operation *op) {
  if (!op)
    return false;
  return isa<emitc::GetGlobalOp, emitc::LiteralOp, emitc::MemberOp,
             emitc::MemberOfPtrOp, emitc::SubscriptOp>(op);
}
/// Determine whether expression \p expressionOp should be emitted inline, i.e.
/// as part of its user. Inlining is recommended for any expression that can be
/// inlined unless it is used by another expression, under the assumption that
/// any expression fusion/re-materialization was taken care of by
/// transformations run by the backend.
static bool shouldBeInlined(ExpressionOp expressionOp) {
  // Expressions explicitly marked as not inlinable stay out of line, and so do
  // expressions with side effects, to prevent side-effect reordering.
  if (expressionOp.getDoNotInline() || expressionOp.hasSideEffects())
    return false;

  // Only expressions with exactly one use are candidates.
  Value exprResult = expressionOp.getResult();
  if (!exprResult.hasOneUse())
    return false;
  Operation *soleUser = *exprResult.getUsers().begin();

  // Operations with deferred emission require their operands to be
  // materialized as variables.
  if (hasDeferredEmission(soleUser))
    return false;

  // A user implementing CExpressionInterface could have been merged into the
  // expression op if inlining had been intended, so do not inline into it.
  if (isa<emitc::CExpressionInterface>(*soleUser))
    return false;
  return true;
}
/// Emits the first result of `operation` initialized with `value`.
///
/// When variables are declared at the top of the function, only an assignment
/// is emitted (nothing at all if `value` is an empty emitc::OpaqueAttr).
/// Otherwise a declaration is emitted, which is left uninitialized if `value`
/// is an empty emitc::OpaqueAttr.
static LogicalResult printConstantOp(CppEmitter &emitter, Operation *operation,
                                     Attribute value) {
  OpResult result = operation->getResult(0);

  // An empty opaque attribute means the op carries no value.
  auto opaqueAttr = dyn_cast<emitc::OpaqueAttr>(value);
  bool hasNoValue = opaqueAttr && opaqueAttr.getValue().empty();

  if (emitter.shouldDeclareVariablesAtTop()) {
    // The variable was already declared when printing the FuncOp, so at most
    // an assignment is needed here.
    if (hasNoValue)
      return success();
    if (failed(emitter.emitVariableAssignment(result)))
      return failure();
  } else {
    // The semicolon gets printed by the emitOperation function.
    if (hasNoValue)
      return emitter.emitVariableDeclaration(result,
                                             /*trailingSemicolon=*/false);
    if (failed(emitter.emitAssignPrefix(*operation)))
      return failure();
  }
  return emitter.emitAttribute(operation->getLoc(), value);
}
/// Emits `constantOp` via printConstantOp using the op's value attribute.
static LogicalResult printOperation(CppEmitter &emitter,
                                    emitc::ConstantOp constantOp) {
  return printConstantOp(emitter, constantOp.getOperation(),
                         constantOp.getValue());
}
/// Emits `variableOp` via printConstantOp using the op's value attribute.
static LogicalResult printOperation(CppEmitter &emitter,
                                    emitc::VariableOp variableOp) {
  return printConstantOp(emitter, variableOp.getOperation(),
                         variableOp.getValue());
}
/// Emits `globalOp` by forwarding it to CppEmitter::emitGlobalVariable.
static LogicalResult printOperation(CppEmitter &emitter,
                                    emitc::GlobalOp globalOp) {
  LogicalResult emitted = emitter.emitGlobalVariable(globalOp);
  return emitted;
}
/// Emits `assignOp` as an assignment to the first result of the op defining
/// the assigned variable, followed by the assigned value.
static LogicalResult printOperation(CppEmitter &emitter,
                                    emitc::AssignOp assignOp) {
  Operation *variableDef = assignOp.getVar().getDefiningOp();
  OpResult target = variableDef->getResult(0);
  if (failed(emitter.emitVariableAssignment(target)))
    return failure();

  Value assignedValue = assignOp.getValue();
  return emitter.emitOperand(assignedValue);
}
/// Emits `loadOp` as its assignment prefix followed by the op's operand.
static LogicalResult printOperation(CppEmitter &emitter, emitc::LoadOp loadOp) {
  Operation &op = *loadOp.getOperation();
  if (failed(emitter.emitAssignPrefix(op)))
    return failure();

  Value source = loadOp.getOperand();
  return emitter.emitOperand(source);
}
/// Emits `operation` as `<assign prefix><operand 0> <binaryOperator> <operand
/// 1>`. Returns failure if the prefix or either operand cannot be emitted.
static LogicalResult printBinaryOperation(CppEmitter &emitter,
                                          Operation *operation,
                                          StringRef binaryOperator) {
  raw_ostream &out = emitter.ostream();

  if (failed(emitter.emitAssignPrefix(*operation)) ||
      failed(emitter.emitOperand(operation->getOperand(0))))
    return failure();

  out << " " << binaryOperator << " ";
  return emitter.emitOperand(operation->getOperand(1));
}
/// Emits `operation` as `<assign prefix><unaryOperator><operand 0>`. Returns
/// failure if the prefix or the operand cannot be emitted.
static LogicalResult printUnaryOperation(CppEmitter &emitter,
                                         Operation *operation,
                                         StringRef unaryOperator) {
  raw_ostream &out = emitter.ostream();

  if (failed(emitter.emitAssignPrefix(*operation)))
    return failure();

  out << unaryOperator;
  return emitter.emitOperand(operation->getOperand(0));
}
/// Emits `addOp` through printBinaryOperation with the `+` operator.
static LogicalResult printOperation(CppEmitter &emitter, emitc::AddOp addOp) {
  return printBinaryOperation(emitter, addOp.getOperation(), "+");
}
/// Emits `divOp` through printBinaryOperation with the `/` operator.
static LogicalResult printOperation(CppEmitter &emitter, emitc::DivOp divOp) {
  return printBinaryOperation(emitter, divOp.getOperation(), "/");
}
/// Emits `mulOp` through printBinaryOperation with the `*` operator.
static LogicalResult printOperation(CppEmitter &emitter, emitc::MulOp mulOp) {
  return printBinaryOperation(emitter, mulOp.getOperation(), "*");
}
/// Emits `remOp` through printBinaryOperation with the `%` operator.
static LogicalResult printOperation(CppEmitter &emitter, emitc::RemOp remOp) {
  return printBinaryOperation(emitter, remOp.getOperation(), "%");
}
/// Emits `subOp` through printBinaryOperation with the `-` operator.
static LogicalResult printOperation(CppEmitter &emitter, emitc::SubOp subOp) {
  return printBinaryOperation(emitter, subOp.getOperation(), "-");
}
/// Emits every operation of `region` except the last one, each with a
/// trailing semicolon, followed by a `break;` statement.
static LogicalResult emitSwitchCase(CppEmitter &emitter,
                                    raw_indented_ostream &os, Region &region) {
  Region::OpIterator opIt = region.op_begin();
  Region::OpIterator opEnd = region.op_end();
  // Stop one short of the end: the last operation of the region is skipped.
  while (std::next(opIt) != opEnd) {
    if (failed(emitter.emitOperation(*opIt, /*trailingSemicolon=*/true)))
      return failure();
    ++opIt;
  }
  os << "break;\n";
  return success();
}
/// Emits `switchOp` as a C++ `switch` statement: one braced `case` per case
/// value/region pair followed by a braced `default` built from the default
/// region.
static LogicalResult printOperation(CppEmitter &emitter,
                                    emitc::SwitchOp switchOp) {
  raw_indented_ostream &os = emitter.ostream();

  os << "switch (";
  if (failed(emitter.emitOperand(switchOp.getArg())))
    return failure();
  os << ") {";

  for (auto [caseValue, caseRegion] :
       llvm::zip(switchOp.getCases(), switchOp.getCaseRegions())) {
    os << "\ncase " << caseValue << ": {\n";
    os.indent();
    if (failed(emitSwitchCase(emitter, os, caseRegion)))
      return failure();
    os.unindent() << "}";
  }

  os << "\ndefault: {\n";
  os.indent();
  if (failed(emitSwitchCase(emitter, os, switchOp.getDefaultRegion())))
    return failure();
  // Closes the default block and then the switch statement itself.
  os.unindent() << "}\n}";
  return success();
}
/// Emits `cmpOp` through printBinaryOperation, using the C++ comparison
/// operator that corresponds to the op's predicate.
static LogicalResult printOperation(CppEmitter &emitter, emitc::CmpOp cmpOp) {
  // Maps a predicate to its operator spelling; yields an empty string for a
  // value not covered by the switch.
  auto spellPredicate = [](emitc::CmpPredicate predicate) -> StringRef {
    switch (predicate) {
    case emitc::CmpPredicate::eq:
      return "==";
    case emitc::CmpPredicate::ne:
      return "!=";
    case emitc::CmpPredicate::lt:
      return "<";
    case emitc::CmpPredicate::le:
      return "<=";
    case emitc::CmpPredicate::gt:
      return ">";
    case emitc::CmpPredicate::ge:
      return ">=";
    case emitc::CmpPredicate::three_way:
      return "<=>";
    }
    return StringRef();
  };

  return printBinaryOperation(emitter, cmpOp.getOperation(),
                              spellPredicate(cmpOp.getPredicate()));
}
/// Emits `conditionalOp` as `<assign prefix><condition> ? <true value> :
/// <false value>`.
static LogicalResult printOperation(CppEmitter &emitter,
                                    emitc::ConditionalOp conditionalOp) {
  raw_ostream &out = emitter.ostream();

  if (failed(emitter.emitAssignPrefix(*conditionalOp.getOperation())) ||
      failed(emitter.emitOperand(conditionalOp.getCondition())))
    return failure();

  out << " ? ";
  if (failed(emitter.emitOperand(conditionalOp.getTrueValue())))
    return failure();

  out << " : ";
  return emitter.emitOperand(conditionalOp.getFalseValue());
}
/// Emits `verbatimOp` by walking its parsed format string: string items are
/// written as is, every other item is replaced by the next format argument.
static LogicalResult printOperation(CppEmitter &emitter,
                                    emitc::VerbatimOp verbatimOp) {
  raw_ostream &out = emitter.ostream();

  FailureOr<SmallVector<ReplacementItem>> parsedItems =
      verbatimOp.parseFormatString();
  if (failed(parsedItems))
    return failure();

  // Format arguments are consumed in order, one per non-string item.
  auto nextArg = verbatimOp.getFmtArgs().begin();
  for (ReplacementItem &item : *parsedItems) {
    auto *literalText = std::get_if<StringRef>(&item);
    if (literalText) {
      out << *literalText;
      continue;
    }
    Value argument = *nextArg;
    ++nextArg;
    if (failed(emitter.emitOperand(argument)))
      return failure();
  }
  return success();
}
/// Emits `branchOp` as one assignment per successor block argument (from the
/// corresponding branch operand) followed by a `goto` to the successor's
/// label. Fails if no label is assigned to the successor block.
static LogicalResult printOperation(CppEmitter &emitter,
                                    cf::BranchOp branchOp) {
  raw_ostream &out = emitter.ostream();
  Block &target = *branchOp.getSuccessor();

  for (auto [operand, argument] :
       llvm::zip(branchOp.getOperands(), target.getArguments())) {
    out << emitter.getOrCreateName(argument) << " = "
        << emitter.getOrCreateName(operand) << ";\n";
  }

  out << "goto ";
  if (!emitter.hasBlockLabel(target))
    return branchOp.emitOpError("unable to find label for successor block");
  out << emitter.getOrCreateName(target);
  return success();
}
/// Emits `condBranchOp` as an `if`/`else` statement. Each arm assigns the
/// branch operands to the arguments of its destination block and then jumps
/// to the destination's label with `goto`. Fails if a destination block has
/// no label assigned.
static LogicalResult printOperation(CppEmitter &emitter,
                                    cf::CondBranchOp condBranchOp) {
  raw_indented_ostream &os = emitter.ostream();

  // Emits the body shared by both arms: block argument assignments followed
  // by the jump to `destination`.
  auto emitArm = [&](auto operands, Block &destination) -> LogicalResult {
    for (auto [operand, argument] :
         llvm::zip(operands, destination.getArguments())) {
      os << emitter.getOrCreateName(argument) << " = "
         << emitter.getOrCreateName(operand) << ";\n";
    }
    os << "goto ";
    if (!emitter.hasBlockLabel(destination))
      return condBranchOp.emitOpError(
          "unable to find label for successor block");
    os << emitter.getOrCreateName(destination) << ";\n";
    return success();
  };

  os << "if (";
  if (failed(emitter.emitOperand(condBranchOp.getCondition())))
    return failure();
  os << ") {\n";
  os.indent();

  // If condition is true.
  if (failed(emitArm(condBranchOp.getTrueOperands(),
                     *condBranchOp.getTrueDest())))
    return failure();
  os.unindent() << "} else {\n";
  os.indent();

  // If condition is false.
  if (failed(emitArm(condBranchOp.getFalseOperands(),
                     *condBranchOp.getFalseDest())))
    return failure();
  os.unindent() << "}";
  return success();
}
/// Emits `callOp` as `<assign prefix><callee>(<operands>)`.
static LogicalResult printCallOperation(CppEmitter &emitter, Operation *callOp,
                                        StringRef callee) {
  Operation &op = *callOp;
  if (failed(emitter.emitAssignPrefix(op)))
    return failure();

  raw_ostream &out = emitter.ostream();
  out << callee << "(";
  LogicalResult operandsEmitted = emitter.emitOperands(op);
  if (failed(operandsEmitted))
    return failure();
  out << ")";
  return success();
}
/// Emits `callOp` through printCallOperation using the op's callee name.
static LogicalResult printOperation(CppEmitter &emitter, func::CallOp callOp) {
  return printCallOperation(emitter, callOp.getOperation(),
                            callOp.getCallee());
}
/// Emits `callOp` through printCallOperation using the op's callee name.
static LogicalResult printOperation(CppEmitter &emitter, emitc::CallOp callOp) {
  return printCallOperation(emitter, callOp.getOperation(),
                            callOp.getCallee());
}
/// Emits `callOpaqueOp` as `<assign prefix><callee>[<template args>](<args>)`.
///
/// If the op has an explicit argument list, every index-typed integer
/// attribute in it is treated as the index of an operand and replaced by that
/// operand's name; all other attributes are emitted as attributes. Without an
/// argument list the op's operands are emitted in order.
static LogicalResult printOperation(CppEmitter &emitter,
                                    emitc::CallOpaqueOp callOpaqueOp) {
  raw_ostream &os = emitter.ostream();
  Operation &op = *callOpaqueOp.getOperation();

  if (failed(emitter.emitAssignPrefix(op)))
    return failure();
  os << callOpaqueOp.getCallee();

  // Template arguments can't refer to SSA values, so the attributes supplying
  // them can be emitted as is. Integer attributes need no special handling
  // here, unlike for the call arguments below.
  if (callOpaqueOp.getTemplateArgs()) {
    auto emitTemplateArg = [&](Attribute attr) -> LogicalResult {
      return emitter.emitAttribute(op.getLoc(), attr);
    };
    os << "<";
    if (failed(interleaveCommaWithError(*callOpaqueOp.getTemplateArgs(), os,
                                        emitTemplateArg)))
      return failure();
    os << ">";
  }

  auto emitArg = [&](Attribute attr) -> LogicalResult {
    auto intAttr = dyn_cast<IntegerAttr>(attr);
    if (!intAttr || !intAttr.getType().isIndex())
      return emitter.emitAttribute(op.getLoc(), attr);

    // Index attributes are treated specially as operand index.
    int64_t operandIndex = intAttr.getInt();
    Value operand = op.getOperand(operandIndex);
    if (!emitter.hasValueInScope(operand))
      return op.emitOpError("operand ")
             << operandIndex << "'s value not defined in scope";
    os << emitter.getOrCreateName(operand);
    return success();
  };

  os << "(";
  if (callOpaqueOp.getArgs()) {
    if (failed(interleaveCommaWithError(*callOpaqueOp.getArgs(), os, emitArg)))
      return failure();
  } else {
    if (failed(emitter.emitOperands(op)))
      return failure();
  }
  os << ")";
  return success();
}
/// Emits `applyOp` as `<assign prefix><applicable operator><operand name>`.
static LogicalResult printOperation(CppEmitter &emitter,
                                    emitc::ApplyOp applyOp) {
  raw_ostream &out = emitter.ostream();

  if (failed(emitter.emitAssignPrefix(*applyOp.getOperation())))
    return failure();

  // The operand is referenced through its variable name.
  out << applyOp.getApplicableOperator();
  StringRef operandName = emitter.getOrCreateName(applyOp.getOperand());
  out << operandName;
  return success();
}
/// Emits `bitwiseAndOp` through printBinaryOperation with the `&` operator.
static LogicalResult printOperation(CppEmitter &emitter,
                                    emitc::BitwiseAndOp bitwiseAndOp) {
  return printBinaryOperation(emitter, bitwiseAndOp.getOperation(), "&");
}
/// Emits `bitwiseLeftShiftOp` through printBinaryOperation with the `<<`
/// operator.
static LogicalResult
printOperation(CppEmitter &emitter,
               emitc::BitwiseLeftShiftOp bitwiseLeftShiftOp) {
  return printBinaryOperation(emitter, bitwiseLeftShiftOp.getOperation(),
                              "<<");
}
/// Emits `bitwiseNotOp` through printUnaryOperation with the `~` operator.
static LogicalResult printOperation(CppEmitter &emitter,
                                    emitc::BitwiseNotOp bitwiseNotOp) {
  return printUnaryOperation(emitter, bitwiseNotOp.getOperation(), "~");
}
/// Emits `bitwiseOrOp` through printBinaryOperation with the `|` operator.
static LogicalResult printOperation(CppEmitter &emitter,
                                    emitc::BitwiseOrOp bitwiseOrOp) {
  return printBinaryOperation(emitter, bitwiseOrOp.getOperation(), "|");
}
/// Emits `bitwiseRightShiftOp` through printBinaryOperation with the `>>`
/// operator.
static LogicalResult
printOperation(CppEmitter &emitter,
               emitc::BitwiseRightShiftOp bitwiseRightShiftOp) {
  return printBinaryOperation(emitter, bitwiseRightShiftOp.getOperation(),
                              ">>");
}
/// Emits `bitwiseXorOp` through printBinaryOperation with the `^` operator.
static LogicalResult printOperation(CppEmitter &emitter,
                                    emitc::BitwiseXorOp bitwiseXorOp) {
  return printBinaryOperation(emitter, bitwiseXorOp.getOperation(), "^");
}
/// Emits `unaryPlusOp` through printUnaryOperation with the `+` operator.
static LogicalResult printOperation(CppEmitter &emitter,
                                    emitc::UnaryPlusOp unaryPlusOp) {
  return printUnaryOperation(emitter, unaryPlusOp.getOperation(), "+");
}
/// Emits `unaryMinusOp` through printUnaryOperation with the `-` operator.
static LogicalResult printOperation(CppEmitter &emitter,
                                    emitc::UnaryMinusOp unaryMinusOp) {
  return printUnaryOperation(emitter, unaryMinusOp.getOperation(), "-");
}
/// Emits `castOp` as `<assign prefix>(<result type>) <operand>`.
static LogicalResult printOperation(CppEmitter &emitter, emitc::CastOp castOp) {
  Operation &op = *castOp.getOperation();
  raw_ostream &out = emitter.ostream();

  if (failed(emitter.emitAssignPrefix(op)))
    return failure();

  // The target type of the cast is the type of the op's first result.
  Type targetType = op.getResult(0).getType();
  out << "(";
  if (failed(emitter.emitType(op.getLoc(), targetType)))
    return failure();
  out << ") ";
  return emitter.emitOperand(castOp.getOperand());
}
/// Emits `expressionOp` as an assignment prefix followed by the expression,
/// unless shouldBeInlined reports that the expression is to be emitted inline,
/// in which case nothing is emitted here.
static LogicalResult printOperation(CppEmitter &emitter,
                                    emitc::ExpressionOp expressionOp) {
  const bool emittedInline = shouldBeInlined(expressionOp);
  if (emittedInline)
    return success();

  if (failed(emitter.emitAssignPrefix(*expressionOp.getOperation())))
    return failure();
  return emitter.emitExpression(expressionOp);
}
/// Emits `includeOp` as an `#include` directive, using angle brackets for
/// standard includes and double quotes otherwise.
static LogicalResult printOperation(CppEmitter &emitter,
                                    emitc::IncludeOp includeOp) {
  raw_ostream &out = emitter.ostream();
  const bool isStandard = includeOp.getIsStandardInclude();

  out << "#include ";
  out << (isStandard ? "<" : "\"") << includeOp.getInclude()
      << (isStandard ? ">" : "\"");
  return success();
}
/// Emits `logicalAndOp` through printBinaryOperation with the `&&` operator.
static LogicalResult printOperation(CppEmitter &emitter,
                                    emitc::LogicalAndOp logicalAndOp) {
  return printBinaryOperation(emitter, logicalAndOp.getOperation(), "&&");
}
/// Emits `logicalNotOp` through printUnaryOperation with the `!` operator.
static LogicalResult printOperation(CppEmitter &emitter,
                                    emitc::LogicalNotOp logicalNotOp) {
  return printUnaryOperation(emitter, logicalNotOp.getOperation(), "!");
}
/// Emits `logicalOrOp` through printBinaryOperation with the `||` operator.
static LogicalResult printOperation(CppEmitter &emitter,
                                    emitc::LogicalOrOp logicalOrOp) {
  return printBinaryOperation(emitter, logicalOrOp.getOperation(), "||");
}
/// Emits `forOp` as a C++ `for` loop of the form
/// `for (<type> <iv> = <lower>; <iv> < <upper>; <iv> += <step>) { ... }`.
/// The body consists of all operations of the loop region except the trailing
/// yield op.
static LogicalResult printOperation(CppEmitter &emitter, emitc::ForOp forOp) {
  raw_indented_ostream &os = emitter.ostream();
  Value inductionVar = forOp.getInductionVar();

  // Init statement: declaration of the induction variable.
  os << "for (";
  if (failed(emitter.emitType(forOp.getLoc(), inductionVar.getType())))
    return failure();
  os << " " << emitter.getOrCreateInductionVarName(inductionVar) << " = ";
  if (failed(emitter.emitOperand(forOp.getLowerBound())))
    return failure();
  os << "; ";

  // Condition. An upper bound that is an expression which will be inlined is
  // wrapped in parentheses in order to guarantee its precedence and
  // associativity.
  os << emitter.getOrCreateInductionVarName(inductionVar) << " < ";
  Value upperBound = forOp.getUpperBound();
  bool parenthesizeUpperBound = false;
  if (auto boundExpression =
          dyn_cast_if_present<ExpressionOp>(upperBound.getDefiningOp()))
    parenthesizeUpperBound = shouldBeInlined(boundExpression);
  if (parenthesizeUpperBound)
    os << "(";
  if (failed(emitter.emitOperand(upperBound)))
    return failure();
  if (parenthesizeUpperBound)
    os << ")";
  os << "; ";

  // Increment.
  os << emitter.getOrCreateInductionVarName(inductionVar) << " += ";
  if (failed(emitter.emitOperand(forOp.getStep())))
    return failure();
  os << ") {\n";
  os.indent();

  // The loop scope is opened only for the body.
  CppEmitter::LoopScope lScope(emitter);
  auto bodyOps = forOp.getRegion().getOps();
  // We skip the trailing yield op.
  auto opIt = bodyOps.begin();
  while (std::next(opIt) != bodyOps.end()) {
    if (failed(emitter.emitOperation(*opIt, /*trailingSemicolon=*/true)))
      return failure();
    ++opIt;
  }

  os.unindent() << "}";
  return success();
}
/// Emits `ifOp` as a C++ `if` statement; an `else` branch is emitted only if
/// the else region is not empty.
static LogicalResult printOperation(CppEmitter &emitter, emitc::IfOp ifOp) {
  raw_indented_ostream &os = emitter.ostream();

  // Emits all operations of `region` except the last one, which is expected
  // to be emitc::yield.
  auto emitRegionBody = [&emitter](Region &region) -> LogicalResult {
    Region::OpIterator opIt = region.op_begin();
    Region::OpIterator opEnd = region.op_end();
    while (std::next(opIt) != opEnd) {
      if (failed(emitter.emitOperation(*opIt, /*trailingSemicolon=*/true)))
        return failure();
      ++opIt;
    }
    assert(isa<emitc::YieldOp>(*opIt) &&
           "Expected last operation in the region to be emitc::yield");
    return success();
  };

  os << "if (";
  if (failed(emitter.emitOperand(ifOp.getCondition())))
    return failure();
  os << ") {\n";
  os.indent();
  if (failed(emitRegionBody(ifOp.getThenRegion())))
    return failure();
  os.unindent() << "}";

  Region &elseRegion = ifOp.getElseRegion();
  if (elseRegion.empty())
    return success();

  os << " else {\n";
  os.indent();
  if (failed(emitRegionBody(elseRegion)))
    return failure();
  os.unindent() << "}";
  return success();
}
/// Emits `returnOp` as a `return` statement:
/// - a plain `return` without operands;
/// - `return <operand>` for a single operand;
/// - `return std::make_tuple(<operands>)` for multiple operands.
static LogicalResult printOperation(CppEmitter &emitter,
                                    func::ReturnOp returnOp) {
  raw_ostream &out = emitter.ostream();
  out << "return";

  const unsigned numOperands = returnOp.getNumOperands();
  if (numOperands == 0)
    return success();

  if (numOperands == 1) {
    out << " ";
    return emitter.emitOperand(returnOp.getOperand(0));
  }

  out << " std::make_tuple(";
  if (failed(emitter.emitOperandsAndAttributes(*returnOp.getOperation())))
    return failure();
  out << ")";
  return success();
}
/// Emits `returnOp` as `return`, followed by the op's operand if it has one.
static LogicalResult printOperation(CppEmitter &emitter,
                                    emitc::ReturnOp returnOp) {
  raw_ostream &out = emitter.ostream();
  out << "return";

  const bool returnsValue = returnOp.getNumOperands() != 0;
  if (!returnsValue)
    return success();

  out << " ";
  return emitter.emitOperand(returnOp.getOperand());
}
/// Emits every operation nested directly in `moduleOp`, without requesting a
/// trailing semicolon. Stops at the first operation that fails to emit.
static LogicalResult printOperation(CppEmitter &emitter, ModuleOp moduleOp) {
  for (Operation &nestedOp : moduleOp) {
    LogicalResult emitted =
        emitter.emitOperation(nestedOp, /*trailingSemicolon=*/false);
    if (failed(emitted))
      return failure();
  }
  return success();
}
/// Emits `classOp` as a C++ class definition with all members placed in a
/// `public:` section. The class is marked `final` if the op carries the final
/// specifier.
static LogicalResult printOperation(CppEmitter &emitter, ClassOp classOp) {
  raw_indented_ostream &os = emitter.ostream();

  // Class head.
  os << "class " << classOp.getSymName();
  if (classOp.getFinalSpecifier())
    os << " final";
  os << " {\n public:\n";

  // Members: the operations nested in the class op.
  os.indent();
  for (Operation &member : classOp) {
    LogicalResult emitted =
        emitter.emitOperation(member, /*trailingSemicolon=*/false);
    if (failed(emitted))
      return failure();
  }
  os.unindent();

  os << "};";
  return success();
}
/// Emits a field declaration of the form `<type> <name>;` for `emitc.field`.
static LogicalResult printOperation(CppEmitter &emitter, FieldOp fieldOp) {
  raw_ostream &out = emitter.ostream();

  LogicalResult typeStatus =
      emitter.emitType(fieldOp->getLoc(), fieldOp.getType());
  if (failed(typeStatus))
    return failure();

  out << " " << fieldOp.getSymName() << ";";
  return success();
}
/// Emits `<type> <result> = <field name>` for `emitc.get_field`, i.e. a local
/// declaration initialized from the named field.
static LogicalResult printOperation(CppEmitter &emitter,
                                    GetFieldOp getFieldOp) {
  raw_indented_ostream &out = emitter.ostream();
  Value fieldValue = getFieldOp.getResult();

  // Left-hand side: the declared type followed by the result's name.
  if (failed(emitter.emitType(getFieldOp->getLoc(), fieldValue.getType())))
    return failure();
  out << " ";
  if (failed(emitter.emitOperand(fieldValue)))
    return failure();

  // Right-hand side: the referenced field.
  out << " = ";
  out << getFieldOp.getFieldName().str();
  return success();
}
/// Emits the operations nested in an `emitc.file`, but only when the emitter
/// reports that this file should be emitted; otherwise nothing is written.
static LogicalResult printOperation(CppEmitter &emitter, FileOp file) {
  if (emitter.shouldEmitFile(file)) {
    for (Operation &nested : file)
      if (failed(emitter.emitOperation(nested, /*trailingSemicolon=*/false)))
        return failure();
  }
  return success();
}
/// Emits a comma-separated list of parameter types (no parameter names),
/// using the location of `functionOp` for diagnostics.
static LogicalResult printFunctionArgs(CppEmitter &emitter,
                                       Operation *functionOp,
                                       ArrayRef<Type> arguments) {
  raw_indented_ostream &out = emitter.ostream();
  auto emitArgType = [&](Type argType) -> LogicalResult {
    return emitter.emitType(functionOp->getLoc(), argType);
  };
  return interleaveCommaWithError(arguments, out, emitArgType);
}
/// Emits a comma-separated list of named parameter declarations, one per
/// block argument, using the location of `functionOp` for diagnostics.
static LogicalResult printFunctionArgs(CppEmitter &emitter,
                                       Operation *functionOp,
                                       Region::BlockArgListType arguments) {
  raw_indented_ostream &out = emitter.ostream();
  auto emitParameter = [&](BlockArgument param) -> LogicalResult {
    return emitter.emitVariableDeclaration(functionOp->getLoc(),
                                           param.getType(),
                                           emitter.getOrCreateName(param));
  };
  return interleaveCommaWithError(arguments, out, emitParameter);
}
/// Emits the body of a function: optional up-front declarations of all result
/// variables, declarations for the arguments of every non-entry block, and
/// then each block's label (when it has predecessors) and operations.
///
/// \param emitter    Emitter that owns the output stream and name tables.
/// \param functionOp The function operation whose body is printed; also used
///                   as the anchor for diagnostics.
/// \param blocks     The blocks of the function body, entry block first.
/// \returns failure() if any declaration, label or operation cannot be
///          emitted.
static LogicalResult printFunctionBody(CppEmitter &emitter,
                                       Operation *functionOp,
                                       Region::BlockListType &blocks) {
  raw_indented_ostream &os = emitter.ostream();
  os.indent();

  if (emitter.shouldDeclareVariablesAtTop()) {
    // Declare all variables that hold op results including those from nested
    // regions.
    WalkResult result =
        functionOp->walk<WalkOrder::PreOrder>([&](Operation *op) -> WalkResult {
          // The walk also visits `functionOp` itself, which has no parent when
          // the function is translated stand-alone, so the parent must be
          // null-checked before it is inspected.
          if (llvm::isa_and_present<emitc::ExpressionOp>(op->getParentOp()) ||
              (isa<emitc::ExpressionOp>(op) &&
               shouldBeInlined(cast<emitc::ExpressionOp>(op))))
            return WalkResult::skip();
          for (OpResult result : op->getResults()) {
            if (failed(emitter.emitVariableDeclaration(
                    result, /*trailingSemicolon=*/true))) {
              return WalkResult(
                  op->emitError("unable to declare result variable for op"));
            }
          }
          return WalkResult::advance();
        });
    if (result.wasInterrupted())
      return failure();
  }

  // Create label names for basic blocks.
  for (Block &block : blocks) {
    emitter.getOrCreateName(block);
  }

  // Declare variables for basic block arguments. The entry block is skipped:
  // its arguments are the function parameters.
  for (Block &block : llvm::drop_begin(blocks)) {
    for (BlockArgument &arg : block.getArguments()) {
      if (emitter.hasValueInScope(arg))
        return functionOp->emitOpError(" block argument #")
               << arg.getArgNumber() << " is out of scope";
      if (isa<ArrayType, LValueType>(arg.getType()))
        return functionOp->emitOpError("cannot emit block argument #")
               << arg.getArgNumber() << " with type " << arg.getType();
      if (failed(
              emitter.emitType(block.getParentOp()->getLoc(), arg.getType()))) {
        return failure();
      }
      os << " " << emitter.getOrCreateName(arg) << ";\n";
    }
  }

  for (Block &block : blocks) {
    // Only print a label if the block has predecessors.
    if (!block.hasNoPredecessors()) {
      if (failed(emitter.emitLabel(block)))
        return failure();
    }
    for (Operation &op : block.getOperations()) {
      if (failed(emitter.emitOperation(op, /*trailingSemicolon=*/true)))
        return failure();
    }
  }

  os.unindent();
  return success();
}
/// Emits a complete C++ function definition for `func.func`: return type,
/// name, parameter list and body. Rejects functions with several blocks when
/// variables are not declared at the top, lvalue-typed arguments and
/// array-typed results.
static LogicalResult printOperation(CppEmitter &emitter,
                                    func::FuncOp functionOp) {
  // Multi-block functions are only supported with up-front declarations.
  const bool hasMultipleBlocks = functionOp.getBlocks().size() > 1;
  if (!emitter.shouldDeclareVariablesAtTop() && hasMultipleBlocks)
    return functionOp.emitOpError(
        "with multiple blocks needs variables declared at top");

  if (llvm::any_of(functionOp.getArgumentTypes(), llvm::IsaPred<LValueType>))
    return functionOp.emitOpError()
           << "cannot emit lvalue type as argument type";

  if (llvm::any_of(functionOp.getResultTypes(), llvm::IsaPred<ArrayType>))
    return functionOp.emitOpError() << "cannot emit array type as result type";

  CppEmitter::FunctionScope scope(emitter);
  raw_indented_ostream &out = emitter.ostream();
  Operation *rawOp = functionOp.getOperation();

  // Signature.
  if (failed(emitter.emitTypes(functionOp.getLoc(),
                               functionOp.getFunctionType().getResults())))
    return failure();
  out << " " << functionOp.getName() << "(";
  if (failed(printFunctionArgs(emitter, rawOp, functionOp.getArguments())))
    return failure();
  out << ") {\n";

  // Body.
  if (failed(printFunctionBody(emitter, rawOp, functionOp.getBlocks())))
    return failure();
  out << "}\n";
  return success();
}
/// Emits `emitc.func` either as a declaration (external functions: parameter
/// types only, terminated by `;`) or as a full definition with body. Any
/// specifiers attached to the op are written in front of the return type.
static LogicalResult printOperation(CppEmitter &emitter,
                                    emitc::FuncOp functionOp) {
  // Multi-block functions are only supported with up-front declarations.
  const bool hasMultipleBlocks = functionOp.getBlocks().size() > 1;
  if (!emitter.shouldDeclareVariablesAtTop() && hasMultipleBlocks)
    return functionOp.emitOpError(
        "with multiple blocks needs variables declared at top");

  CppEmitter::FunctionScope scope(emitter);
  raw_indented_ostream &out = emitter.ostream();

  if (functionOp.getSpecifiers())
    for (Attribute spec : functionOp.getSpecifiersAttr())
      out << cast<StringAttr>(spec).str() << " ";

  if (failed(emitter.emitTypes(functionOp.getLoc(),
                               functionOp.getFunctionType().getResults())))
    return failure();
  out << " " << functionOp.getName() << "(";

  Operation *rawOp = functionOp.getOperation();

  // Declaration only: there are no block arguments to name.
  if (functionOp.isExternal()) {
    if (failed(
            printFunctionArgs(emitter, rawOp, functionOp.getArgumentTypes())))
      return failure();
    out << ");";
    return success();
  }

  // Full definition.
  if (failed(printFunctionArgs(emitter, rawOp, functionOp.getArguments())))
    return failure();
  out << ") {\n";
  if (failed(printFunctionBody(emitter, rawOp, functionOp.getBlocks())))
    return failure();
  out << "}\n";
  return success();
}
/// Emits a forward declaration for the `emitc.func` referenced by an
/// `emitc.declare_func`: specifiers, return type, name and parameter list,
/// terminated by `;`.
///
/// \returns failure() (with a diagnostic on `declareFuncOp`) if the referenced
///          function cannot be found, or if any type fails to emit.
static LogicalResult printOperation(CppEmitter &emitter,
                                    DeclareFuncOp declareFuncOp) {
  raw_indented_ostream &os = emitter.ostream();
  CppEmitter::FunctionScope scope(emitter);

  auto functionOp = SymbolTable::lookupNearestSymbolFrom<emitc::FuncOp>(
      declareFuncOp, declareFuncOp.getSymNameAttr());
  // Report the unresolved reference instead of failing silently.
  if (!functionOp)
    return declareFuncOp.emitOpError("unable to resolve ")
           << declareFuncOp.getSymNameAttr() << " to an emitc.func";

  if (functionOp.getSpecifiers()) {
    for (Attribute specifier : functionOp.getSpecifiersAttr()) {
      os << cast<StringAttr>(specifier).str() << " ";
    }
  }

  if (failed(emitter.emitTypes(functionOp.getLoc(),
                               functionOp.getFunctionType().getResults())))
    return failure();
  os << " " << functionOp.getName();

  os << "(";
  Operation *operation = functionOp.getOperation();
  // An external function has no body and hence no block arguments; fall back
  // to its argument types so the declared parameter list is not left empty.
  if (functionOp.isExternal()) {
    if (failed(printFunctionArgs(emitter, operation,
                                 functionOp.getArgumentTypes())))
      return failure();
  } else {
    if (failed(
            printFunctionArgs(emitter, operation, functionOp.getArguments())))
      return failure();
  }
  os << ");";

  return success();
}
/// Builds an emitter writing to `os`. Stores the configuration, opens the
/// default scopes on the value and block name tables, and seeds the label
/// counter stack with an initial count of zero.
CppEmitter::CppEmitter(raw_ostream &os, bool declareVariablesAtTop,
                       StringRef fileId)
    : os(os),
      declareVariablesAtTop(declareVariablesAtTop),
      fileId(fileId.str()),
      defaultValueMapperScope(valueMapper),
      defaultBlockMapperScope(blockMapper) {
  // Initial label counter for the outermost scope.
  labelInScopeCount.push(0);
}
/// Builds the textual subscript expression for `op`: the name of the
/// subscripted value followed by one `[index]` suffix per index operand.
std::string CppEmitter::getSubscriptName(emitc::SubscriptOp op) {
  std::string text;
  llvm::raw_string_ostream stream(text);

  stream << getOrCreateName(op.getValue());
  for (auto idx : op.getIndices())
    stream << "[" << getOrCreateName(idx) << "]";

  return text;
}
/// Builds the textual member access `<operand>.<member>` for `op`.
std::string CppEmitter::createMemberAccess(emitc::MemberOp op) {
  std::string text;
  llvm::raw_string_ostream stream(text);
  stream << getOrCreateName(op.getOperand()) << "." << op.getMember();
  return text;
}
/// Builds the textual pointer member access `<operand>-><member>` for `op`.
std::string CppEmitter::createMemberAccess(emitc::MemberOfPtrOp op) {
  std::string text;
  llvm::raw_string_ostream stream(text);
  stream << getOrCreateName(op.getOperand()) << "->" << op.getMember();
  return text;
}
/// Records `str` as the name of `value` unless the value already has an entry
/// in the value name table.
void CppEmitter::cacheDeferredOpResult(Value value, StringRef str) {
  if (valueMapper.count(value))
    return;
  valueMapper.insert(value, str.str());
}
/// Looks up the name registered for `val`; when there is none yet, registers
/// a fresh `v<N>` name (N being the incremented value counter) and returns it.
StringRef CppEmitter::getOrCreateName(Value val) {
  if (valueMapper.count(val))
    return *valueMapper.begin(val);

  assert(!hasDeferredEmission(val.getDefiningOp()) &&
         "cacheDeferredOpResult should have been called on this value, "
         "update the emitOperation function.");
  valueMapper.insert(val, formatv("v{0}", ++valueCount));
  return *valueMapper.begin(val);
}
/// Looks up the name registered for the loop induction variable `val`, or
/// registers a new one. New names consist of a letter picked from the current
/// loop nesting level ('i' for level 0 up to 't') followed by the incremented
/// value counter; once the letters run out, the prefix 'u' is used instead.
StringRef CppEmitter::getOrCreateInductionVarName(Value val) {
  if (valueMapper.count(val))
    return *valueMapper.begin(val);

  int64_t identifier = 'i' + loopNestingLevel;
  const bool hasLetterAvailable = identifier >= 'i' && identifier <= 't';
  if (hasLetterAvailable)
    valueMapper.insert(val, formatv("{0}{1}", static_cast<char>(identifier),
                                    ++valueCount));
  else
    // Nesting is deeper than the available letters: continue with uX.
    valueMapper.insert(val, formatv("u{0}", ++valueCount));

  return *valueMapper.begin(val);
}
/// Looks up the label registered for `block`; when there is none yet,
/// registers a fresh `label<N>` using the incremented label counter of the
/// current scope.
StringRef CppEmitter::getOrCreateName(Block &block) {
  Block *key = &block;
  if (!blockMapper.count(key)) {
    blockMapper.insert(key, formatv("label{0}", ++labelInScopeCount.top()));
  }
  return *blockMapper.begin(key);
}
/// Returns true only for `IntegerType::Unsigned`; signless and signed
/// integers both yield false.
bool CppEmitter::shouldMapToUnsigned(IntegerType::SignednessSemantics val) {
  switch (val) {
  case IntegerType::Unsigned:
    return true;
  case IntegerType::Signless:
  case IntegerType::Signed:
    return false;
  }
  llvm_unreachable("Unexpected IntegerType::SignednessSemantics");
}
/// Returns true if `val` has an entry in the value name table.
bool CppEmitter::hasValueInScope(Value val) {
  return valueMapper.count(val) != 0;
}
/// Returns true if `block` has an entry in the block label table.
bool CppEmitter::hasBlockLabel(Block &block) {
  Block *key = &block;
  return blockMapper.count(key) != 0;
}
/// Emits the C++ spelling of `attr`.
///
/// Handled attribute kinds:
///  - float attributes and dense float elements of type f16/bf16/f32/f64,
///  - integer attributes and dense integer elements of integer or index type,
///  - emitc opaque attributes (written verbatim),
///  - symbol references with at most one nested reference (only the root
///    reference is written),
///  - type attributes (delegated to emitType).
/// Dense attributes are written as a brace-enclosed, comma-separated list.
///
/// \param loc  Location used for diagnostics.
/// \param attr Attribute to emit.
/// \returns failure() with a diagnostic at `loc` for any other attribute.
LogicalResult CppEmitter::emitAttribute(Location loc, Attribute attr) {
  // 1-bit integers are written as `true`/`false`, everything else as a
  // decimal literal that is signed unless `isUnsigned` is set.
  auto printInt = [&](const APInt &val, bool isUnsigned) {
    if (val.getBitWidth() == 1) {
      if (val.getBoolValue())
        os << "true";
      else
        os << "false";
    } else {
      SmallString<128> strValue;
      val.toString(strValue, 10, !isUnsigned, false);
      os << strValue;
    }
  };

  // Finite values get a literal plus a suffix chosen from the float
  // semantics; NaN and infinities are written as the NAN / INFINITY names.
  auto printFloat = [&](const APFloat &val) {
    if (val.isFinite()) {
      SmallString<128> strValue;
      // Use default values of toString except don't truncate zeros.
      val.toString(strValue, 0, 0, false);
      os << strValue;
      switch (llvm::APFloatBase::SemanticsToEnum(val.getSemantics())) {
      case llvm::APFloatBase::S_IEEEhalf:
        os << "f16";
        break;
      case llvm::APFloatBase::S_BFloat:
        os << "bf16";
        break;
      case llvm::APFloatBase::S_IEEEsingle:
        os << "f";
        break;
      case llvm::APFloatBase::S_IEEEdouble:
        break;
      default:
        llvm_unreachable("unsupported floating point type");
      }
    } else if (val.isNaN()) {
      os << "NAN";
    } else if (val.isInfinity()) {
      if (val.isNegative())
        os << "-";
      os << "INFINITY";
    }
  };

  // Print floating point attributes.
  if (auto fAttr = dyn_cast<FloatAttr>(attr)) {
    if (!isa<Float16Type, BFloat16Type, Float32Type, Float64Type>(
            fAttr.getType())) {
      return emitError(
          loc, "expected floating point attribute to be f16, bf16, f32 or f64");
    }
    printFloat(fAttr.getValue());
    return success();
  }
  if (auto dense = dyn_cast<DenseFPElementsAttr>(attr)) {
    if (!isa<Float16Type, BFloat16Type, Float32Type, Float64Type>(
            dense.getElementType())) {
      return emitError(
          loc, "expected floating point attribute to be f16, bf16, f32 or f64");
    }
    os << '{';
    interleaveComma(dense, os, [&](const APFloat &val) { printFloat(val); });
    os << '}';
    return success();
  }

  // Print integer attributes.
  if (auto iAttr = dyn_cast<IntegerAttr>(attr)) {
    if (auto iType = dyn_cast<IntegerType>(iAttr.getType())) {
      printInt(iAttr.getValue(), shouldMapToUnsigned(iType.getSignedness()));
      return success();
    }
    if (isa<IndexType>(iAttr.getType())) {
      printInt(iAttr.getValue(), false);
      return success();
    }
  }
  if (auto dense = dyn_cast<DenseIntElementsAttr>(attr)) {
    // Query the element type from the attribute itself rather than casting
    // its type to TensorType, so dense attributes of other shaped types do
    // not hit a failing cast.
    Type elementType = dense.getElementType();
    if (auto iType = dyn_cast<IntegerType>(elementType)) {
      os << '{';
      interleaveComma(dense, os, [&](const APInt &val) {
        printInt(val, shouldMapToUnsigned(iType.getSignedness()));
      });
      os << '}';
      return success();
    }
    if (isa<IndexType>(elementType)) {
      os << '{';
      interleaveComma(dense, os,
                      [&](const APInt &val) { printInt(val, false); });
      os << '}';
      return success();
    }
  }

  // Print opaque attributes.
  if (auto oAttr = dyn_cast<emitc::OpaqueAttr>(attr)) {
    os << oAttr.getValue();
    return success();
  }

  // Print symbolic reference attributes.
  if (auto sAttr = dyn_cast<SymbolRefAttr>(attr)) {
    if (sAttr.getNestedReferences().size() > 1)
      return emitError(loc, "attribute has more than 1 nested reference");
    os << sAttr.getRootReference().getValue();
    return success();
  }

  // Print type attributes.
  if (auto type = dyn_cast<TypeAttr>(attr))
    return emitType(loc, type.getValue());

  return emitError(loc, "cannot emit attribute: ") << attr;
}
/// Emits `expressionOp` as a single inline C++ expression, starting from its
/// root operation. While emission is in progress the op is recorded as the
/// currently emitted expression and the root's operator precedence sits on
/// the precedence stack; both are cleared again on success.
LogicalResult CppEmitter::emitExpression(ExpressionOp expressionOp) {
  assert(emittedExpressionPrecedence.empty() &&
         "Expected precedence stack to be empty");

  Operation *root = expressionOp.getRootOp();
  emittedExpression = expressionOp;

  FailureOr<int> rootPrecedence = getOperatorPrecedence(root);
  if (failed(rootPrecedence))
    return failure();

  pushExpressionPrecedence(rootPrecedence.value());
  LogicalResult emitted = emitOperation(*root, /*trailingSemicolon=*/false);
  if (failed(emitted))
    return failure();
  popExpressionPrecedence();

  assert(emittedExpressionPrecedence.empty() &&
         "Expected precedence stack to be empty");
  emittedExpression = nullptr;
  return success();
}
/// Emits `value` as an operand. A value that belongs to the expression being
/// emitted is written as a nested sub-expression (parenthesized when needed),
/// a value produced by an inlinable expression op is written as that
/// expression, and any other value is written by name.
LogicalResult CppEmitter::emitOperand(Value value) {
  if (!isPartOfCurrentExpression(value)) {
    auto definingExpr =
        dyn_cast_if_present<ExpressionOp>(value.getDefiningOp());
    if (definingExpr && shouldBeInlined(definingExpr))
      return emitExpression(definingExpr);

    os << getOrCreateName(value);
    return success();
  }

  Operation *def = value.getDefiningOp();
  assert(def && "Expected operand to be defined by an operation");

  FailureOr<int> subPrecedence = getOperatorPrecedence(def);
  if (failed(subPrecedence))
    return failure();

  // A sub-expression whose precedence is not strictly higher than that of the
  // enclosing one is wrapped in parentheses; depending on the shape of the
  // expression tree it could otherwise be evaluated in the wrong order.
  const bool needsParens = subPrecedence.value() <= getExpressionPrecedence();
  if (needsParens)
    os << "(";

  pushExpressionPrecedence(subPrecedence.value());
  if (failed(emitOperation(*def, /*trailingSemicolon=*/false)))
    return failure();

  if (needsParens)
    os << ")";
  popExpressionPrecedence();
  return success();
}
/// Emits all operands of `op` as a comma-separated list.
LogicalResult CppEmitter::emitOperands(Operation &op) {
  auto emitOne = [&](Value operand) -> LogicalResult {
    // While an expression is being emitted, each operand is emitted under the
    // lowest precedence, which is pushed before and popped after the operand.
    if (getEmittedExpression())
      pushExpressionPrecedence(lowestPrecedence());

    if (failed(emitOperand(operand)))
      return failure();

    if (getEmittedExpression())
      popExpressionPrecedence();
    return success();
  };
  return interleaveCommaWithError(op.getOperands(), os, emitOne);
}
/// Emits the operands of `op` followed by its attributes as one
/// comma-separated list. Each attribute is preceded by a `/* name */`
/// comment.
///
/// \param op      Operation whose operands and attributes are emitted.
/// \param exclude Names of attributes that must not be emitted.
/// \returns failure() if any operand or attribute cannot be emitted.
LogicalResult
CppEmitter::emitOperandsAndAttributes(Operation &op,
                                      ArrayRef<StringRef> exclude) {
  if (failed(emitOperands(op)))
    return failure();

  // Drop excluded attributes up front. Skipping them only inside the
  // per-element callback would still let the interleave helper print a
  // separator for each of them, leaving stray commas in the output.
  SmallVector<NamedAttribute> emittedAttrs;
  for (NamedAttribute attr : op.getAttrs()) {
    if (!llvm::is_contained(exclude, attr.getName().strref()))
      emittedAttrs.push_back(attr);
  }

  // Insert comma in between operands and non-filtered attributes if needed.
  if (op.getNumOperands() > 0 && !emittedAttrs.empty())
    os << ", ";

  // Emit attributes.
  auto emitNamedAttribute = [&](NamedAttribute attr) -> LogicalResult {
    os << "/* " << attr.getName().getValue() << " */";
    return emitAttribute(op.getLoc(), attr.getValue());
  };
  return interleaveCommaWithError(emittedAttrs, os, emitNamedAttribute);
}
/// Emits `<name> = ` for an already declared result variable; reports an
/// error on the defining op when the result has no name in scope.
LogicalResult CppEmitter::emitVariableAssignment(OpResult result) {
  if (hasValueInScope(result)) {
    os << getOrCreateName(result) << " = ";
    return success();
  }
  return result.getDefiningOp()->emitOpError(
      "result variable for the operation has not been declared");
}
/// Emits a declaration for the variable holding `result`, optionally followed
/// by `;` and a newline. Nothing is emitted for results of ops with deferred
/// emission; a result that already has a name in scope is an error.
LogicalResult CppEmitter::emitVariableDeclaration(OpResult result,
                                                  bool trailingSemicolon) {
  Operation *producer = result.getDefiningOp();
  if (hasDeferredEmission(producer))
    return success();

  if (hasValueInScope(result))
    return producer->emitError(
        "result variable for the operation already declared");

  LogicalResult declared = emitVariableDeclaration(
      result.getOwner()->getLoc(), result.getType(), getOrCreateName(result));
  if (failed(declared))
    return failure();

  if (trailingSemicolon)
    os << ";\n";
  return success();
}
/// Emits a global variable definition for `op`: an optional `extern` or
/// `static` specifier (`extern` wins when both are set), an optional `const`,
/// the declaration itself, an optional initializer, and the closing `;`.
LogicalResult CppEmitter::emitGlobalVariable(GlobalOp op) {
  // Storage specifier: at most one of extern/static is written.
  if (op.getExternSpecifier()) {
    os << "extern ";
  } else if (op.getStaticSpecifier()) {
    os << "static ";
  }
  if (op.getConstSpecifier())
    os << "const ";

  LogicalResult declared =
      emitVariableDeclaration(op->getLoc(), op.getType(), op.getSymName());
  if (failed(declared))
    return failure();

  if (std::optional<Attribute> init = op.getInitialValue()) {
    os << " = ";
    if (failed(emitAttribute(op->getLoc(), *init)))
      return failure();
  }

  os << ";";
  return success();
}
/// Emits the left-hand side that receives the results of `op`:
///  - nothing for ops without results or ops emitted inside an expression,
///  - `<name> = ` or `<type> <name> = ` for a single result,
///  - `std::tie(<names>) = ` for several results, preceded by one declaration
///    per result when variables are not declared at the top.
LogicalResult CppEmitter::emitAssignPrefix(Operation &op) {
  // Ops that are part of an expression do not get an assignment prefix.
  if (getEmittedExpression())
    return success();

  const unsigned resultCount = op.getNumResults();
  if (resultCount == 0)
    return success();

  if (resultCount == 1) {
    OpResult onlyResult = op.getResult(0);
    if (shouldDeclareVariablesAtTop())
      return emitVariableAssignment(onlyResult);

    if (failed(
            emitVariableDeclaration(onlyResult, /*trailingSemicolon=*/false)))
      return failure();
    os << " = ";
    return success();
  }

  // Several results: make sure each one is declared, then tie them together.
  if (!shouldDeclareVariablesAtTop()) {
    for (OpResult res : op.getResults())
      if (failed(emitVariableDeclaration(res, /*trailingSemicolon=*/true)))
        return failure();
  }
  os << "std::tie(";
  interleaveComma(op.getResults(), os,
                  [&](Value res) { os << getOrCreateName(res); });
  os << ") = ";
  return success();
}
/// Emits `<label>:` for `block` on its own line; reports an error on the
/// block's parent op if no label was registered for the block.
LogicalResult CppEmitter::emitLabel(Block &block) {
  const bool labelKnown = hasBlockLabel(block);
  if (!labelKnown)
    return block.getParentOp()->emitError("label for block not found");

  // FIXME: Add feature in `raw_indented_ostream` to ignore indent for block
  // label instead of using `getOStream`.
  raw_ostream &unindented = os.getOStream();
  unindented << getOrCreateName(block) << ":\n";
  return success();
}
/// Emits a single operation by dispatching on its concrete type.
///
/// Most supported ops are forwarded to the matching `printOperation`
/// overload. GetGlobalOp, LiteralOp, MemberOp, MemberOfPtrOp and SubscriptOp
/// emit nothing here; instead their textual form is cached as the name of
/// their result. Ops without a printer produce an error.
///
/// After a successful dispatch a line terminator is written (`;\n` or `\n`
/// depending on `trailingSemicolon` and the op kind), unless the op has
/// deferred emission, an expression is currently being emitted, or the op is
/// an expression op that should be inlined.
LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) {
LogicalResult status =
llvm::TypeSwitch<Operation *, LogicalResult>(&op)
// Builtin ops.
.Case<ModuleOp>([&](auto op) { return printOperation(*this, op); })
// CF ops.
.Case<cf::BranchOp, cf::CondBranchOp>(
[&](auto op) { return printOperation(*this, op); })
// EmitC ops.
.Case<emitc::AddOp, emitc::ApplyOp, emitc::AssignOp,
emitc::BitwiseAndOp, emitc::BitwiseLeftShiftOp,
emitc::BitwiseNotOp, emitc::BitwiseOrOp,
emitc::BitwiseRightShiftOp, emitc::BitwiseXorOp, emitc::CallOp,
emitc::CallOpaqueOp, emitc::CastOp, emitc::ClassOp,
emitc::CmpOp, emitc::ConditionalOp, emitc::ConstantOp,
emitc::DeclareFuncOp, emitc::DivOp, emitc::ExpressionOp,
emitc::FieldOp, emitc::FileOp, emitc::ForOp, emitc::FuncOp,
emitc::GetFieldOp, emitc::GlobalOp, emitc::IfOp,
emitc::IncludeOp, emitc::LoadOp, emitc::LogicalAndOp,
emitc::LogicalNotOp, emitc::LogicalOrOp, emitc::MulOp,
emitc::RemOp, emitc::ReturnOp, emitc::SubOp, emitc::SwitchOp,
emitc::UnaryMinusOp, emitc::UnaryPlusOp, emitc::VariableOp,
emitc::VerbatimOp>(
[&](auto op) { return printOperation(*this, op); })
// Func ops.
.Case<func::CallOp, func::FuncOp, func::ReturnOp>(
[&](auto op) { return printOperation(*this, op); })
// Ops below are not printed; their text is cached as the result's name.
.Case<emitc::GetGlobalOp>([&](auto op) {
cacheDeferredOpResult(op.getResult(), op.getName());
return success();
})
.Case<emitc::LiteralOp>([&](auto op) {
cacheDeferredOpResult(op.getResult(), op.getValue());
return success();
})
.Case<emitc::MemberOp>([&](auto op) {
cacheDeferredOpResult(op.getResult(), createMemberAccess(op));
return success();
})
.Case<emitc::MemberOfPtrOp>([&](auto op) {
cacheDeferredOpResult(op.getResult(), createMemberAccess(op));
return success();
})
.Case<emitc::SubscriptOp>([&](auto op) {
cacheDeferredOpResult(op.getResult(), getSubscriptName(op));
return success();
})
.Default([&](Operation *) {
return op.emitOpError("unable to find printer for op");
});
if (failed(status))
return failure();
// Ops with deferred emission wrote nothing, so no line terminator either.
if (hasDeferredEmission(&op))
return success();
// No terminator inside an expression or for an inlined expression op.
if (getEmittedExpression() ||
(isa<emitc::ExpressionOp>(op) &&
shouldBeInlined(cast<emitc::ExpressionOp>(op))))
return success();
// Never emit a semicolon for some operations, especially if ending with
// `}`.
trailingSemicolon &=
!isa<cf::CondBranchOp, emitc::DeclareFuncOp, emitc::FileOp, emitc::ForOp,
emitc::IfOp, emitc::IncludeOp, emitc::SwitchOp, emitc::VerbatimOp>(
op);
os << (trailingSemicolon ? ";\n" : "\n");
return success();
}
/// Emits `<type> <name>` for a variable. For emitc array types the element
/// type is written in front of the name and one `[dim]` suffix per dimension
/// is appended after it.
LogicalResult CppEmitter::emitVariableDeclaration(Location loc, Type type,
                                                  StringRef name) {
  auto arrayTy = dyn_cast<emitc::ArrayType>(type);

  // Arrays spell their element type in front of the name.
  Type leadingType = arrayTy ? arrayTy.getElementType() : type;
  if (failed(emitType(loc, leadingType)))
    return failure();
  os << " " << name;

  if (arrayTy) {
    for (auto extent : arrayTy.getShape())
      os << "[" << extent << "]";
  }
  return success();
}
/// Emits the C++ spelling of `type`, or reports an error at `loc` when the
/// type cannot be represented.
///
/// Mapping, in the order it is checked:
///  - integers: `bool` for width 1, `[u]intN_t` for widths 8/16/32/64,
///  - floats: `_Float16`, `__bf16`, `float` (32 bit), `double` (64 bit),
///  - index and emitc size_t -> `size_t`, ssize_t -> `ssize_t`,
///    ptrdiff_t -> `ptrdiff_t`,
///  - ranked, statically shaped tensors -> `Tensor<elem, dims...>`,
///  - tuples -> `std::tuple<...>`, opaque types -> their verbatim value,
///  - emitc arrays -> element type followed by `[dim]` per dimension,
///  - lvalues -> the wrapped value type, pointers -> pointee followed by `*`.
LogicalResult CppEmitter::emitType(Location loc, Type type) {
  if (auto intTy = dyn_cast<IntegerType>(type)) {
    const unsigned width = intTy.getWidth();
    if (width == 1) {
      os << "bool";
      return success();
    }
    if (width == 8 || width == 16 || width == 32 || width == 64) {
      const char *prefix =
          shouldMapToUnsigned(intTy.getSignedness()) ? "uint" : "int";
      os << prefix << width << "_t";
      return success();
    }
    return emitError(loc, "cannot emit integer type ") << type;
  }

  if (auto floatTy = dyn_cast<FloatType>(type)) {
    const unsigned width = floatTy.getWidth();
    if (width == 16) {
      // Two 16-bit formats are supported; any other one is rejected.
      if (llvm::isa<Float16Type>(type)) {
        os << "_Float16";
        return success();
      }
      if (llvm::isa<BFloat16Type>(type)) {
        os << "__bf16";
        return success();
      }
      return emitError(loc, "cannot emit float type ") << type;
    }
    if (width == 32) {
      os << "float";
      return success();
    }
    if (width == 64) {
      os << "double";
      return success();
    }
    return emitError(loc, "cannot emit float type ") << type;
  }

  if (isa<IndexType>(type)) {
    os << "size_t";
    return success();
  }
  if (isa<emitc::SizeTType>(type)) {
    os << "size_t";
    return success();
  }
  if (isa<emitc::SignedSizeTType>(type)) {
    os << "ssize_t";
    return success();
  }
  if (isa<emitc::PtrDiffTType>(type)) {
    os << "ptrdiff_t";
    return success();
  }

  if (auto tensorTy = dyn_cast<TensorType>(type)) {
    if (!tensorTy.hasRank())
      return emitError(loc, "cannot emit unranked tensor type");
    if (!tensorTy.hasStaticShape())
      return emitError(loc, "cannot emit tensor type with non static shape");
    os << "Tensor<";
    Type elementTy = tensorTy.getElementType();
    if (isa<ArrayType>(elementTy))
      return emitError(loc, "cannot emit tensor of array type ") << type;
    if (failed(emitType(loc, elementTy)))
      return failure();
    for (auto extent : tensorTy.getShape())
      os << ", " << extent;
    os << ">";
    return success();
  }

  if (auto tupleTy = dyn_cast<TupleType>(type))
    return emitTupleType(loc, tupleTy.getTypes());

  if (auto opaqueTy = dyn_cast<emitc::OpaqueType>(type)) {
    os << opaqueTy.getValue();
    return success();
  }

  if (auto arrayTy = dyn_cast<emitc::ArrayType>(type)) {
    if (failed(emitType(loc, arrayTy.getElementType())))
      return failure();
    for (auto extent : arrayTy.getShape())
      os << "[" << extent << "]";
    return success();
  }

  if (auto lvalueTy = dyn_cast<emitc::LValueType>(type))
    return emitType(loc, lvalueTy.getValueType());

  if (auto pointerTy = dyn_cast<emitc::PointerType>(type)) {
    Type pointee = pointerTy.getPointee();
    if (isa<ArrayType>(pointee))
      return emitError(loc, "cannot emit pointer to array type ") << type;
    if (failed(emitType(loc, pointee)))
      return failure();
    os << "*";
    return success();
  }

  return emitError(loc, "cannot emit type ") << type;
}
/// Emits a type list as a single C++ type: `void` for an empty list, the type
/// itself for one element, and a `std::tuple` for several elements.
LogicalResult CppEmitter::emitTypes(Location loc, ArrayRef<Type> types) {
  if (types.empty()) {
    os << "void";
    return success();
  }
  if (types.size() == 1)
    return emitType(loc, types.front());
  return emitTupleType(loc, types);
}
/// Emits `std::tuple<T0, T1, ...>` for `types`. Array element types are
/// rejected with an error at `loc`.
LogicalResult CppEmitter::emitTupleType(Location loc, ArrayRef<Type> types) {
  const bool containsArray = llvm::any_of(types, llvm::IsaPred<ArrayType>);
  if (containsArray)
    return emitError(loc, "cannot emit tuple of array type");

  os << "std::tuple<";
  auto emitElement = [&](Type elementType) {
    return emitType(loc, elementType);
  };
  if (failed(interleaveCommaWithError(types, os, emitElement)))
    return failure();
  os << ">";
  return success();
}
/// Resets the counter used to number generated value names back to zero.
void CppEmitter::resetValueCounter() {
  valueCount = 0;
}
/// Increments the loop nesting level by one.
void CppEmitter::increaseLoopNestingLevel() {
  ++loopNestingLevel;
}
/// Decrements the loop nesting level by one.
void CppEmitter::decreaseLoopNestingLevel() {
  --loopNestingLevel;
}
/// Translates `op` to C++ source written to `os`, using a fresh emitter
/// configured with `declareVariablesAtTop` and `fileId`. No trailing
/// semicolon is requested for the top-level operation.
LogicalResult emitc::translateToCpp(Operation *op, raw_ostream &os,
                                    bool declareVariablesAtTop,
                                    StringRef fileId) {
  CppEmitter cppEmitter(os, declareVariablesAtTop, fileId);
  Operation &root = *op;
  return cppEmitter.emitOperation(root, /*trailingSemicolon=*/false);
}