2025-07-31 19:16:42 +02:00

448 lines
17 KiB
C++

//===- LoweringPrepare.cpp - preparation work for LLVM lowering -----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "PassDetail.h"
#include "clang/AST/ASTContext.h"
#include "clang/AST/CharUnits.h"
#include "clang/CIR/Dialect/Builder/CIRBaseBuilder.h"
#include "clang/CIR/Dialect/IR/CIRDialect.h"
#include "clang/CIR/Dialect/IR/CIROpsEnums.h"
#include "clang/CIR/Dialect/Passes.h"
#include "clang/CIR/MissingFeatures.h"
#include <memory>
using namespace mlir;
using namespace cir;
namespace {
/// Pass that rewrites a fixed set of CIR operations (array ctor/dtor ops,
/// complex-related casts, complex multiplication and complex unary ops) into
/// simpler CIR operations.
struct LoweringPreparePass : public LoweringPrepareBase<LoweringPreparePass> {
  LoweringPreparePass() = default;

  /// Pass entry point: collects the operations handled by this pass and
  /// rewrites each of them through runOnOp().
  void runOnOperation() override;

  /// Dispatches a single collected operation to the matching lower* routine.
  void runOnOp(mlir::Operation *op);

  /// Per-operation lowering routines.
  void lowerCastOp(cir::CastOp op);
  void lowerComplexMulOp(cir::ComplexMulOp op);
  void lowerUnaryOp(cir::UnaryOp op);
  void lowerArrayDtor(cir::ArrayDtor op);
  void lowerArrayCtor(cir::ArrayCtor op);

  /// Returns the function named `name` visible from the tracked module,
  /// creating a private declaration with the given type and linkage when it
  /// does not exist yet.
  cir::FuncOp buildRuntimeFunction(
      mlir::OpBuilder &builder, llvm::StringRef name, mlir::Location loc,
      cir::FuncType type,
      cir::GlobalLinkageKind linkage = cir::GlobalLinkageKind::ExternalLinkage);

  ///
  /// AST related
  /// -----------

  /// AST context supplied through setASTContext(). It stays null when the
  /// pass is created without one; it is explicitly initialized so that a
  /// missing context is a detectable null pointer instead of an
  /// indeterminate value.
  clang::ASTContext *astCtx = nullptr;

  /// Tracks current module.
  mlir::ModuleOp mlirModule;

  /// Records the AST context used by the array ctor/dtor lowering.
  void setASTContext(clang::ASTContext *c) { astCtx = c; }
};
} // namespace
/// Looks up the symbol `name` starting from the tracked module and returns it
/// when it is a cir::FuncOp. Otherwise a new function with the requested type
/// and linkage is created at the builder's current insertion point, marked
/// with private symbol visibility, and returned.
cir::FuncOp LoweringPreparePass::buildRuntimeFunction(
    mlir::OpBuilder &builder, llvm::StringRef name, mlir::Location loc,
    cir::FuncType type, cir::GlobalLinkageKind linkage) {
  // Fast path: the function is already known.
  auto symbolName = StringAttr::get(mlirModule->getContext(), name);
  mlir::Operation *existingSymbol =
      SymbolTable::lookupNearestSymbolFrom(mlirModule, symbolName);
  if (auto existingFunc = dyn_cast_or_null<FuncOp>(existingSymbol))
    return existingFunc;

  // Slow path: materialize a fresh declaration.
  auto newFunc = builder.create<cir::FuncOp>(loc, name, type);
  auto linkageAttr =
      cir::GlobalLinkageKindAttr::get(builder.getContext(), linkage);
  newFunc.setLinkageAttr(linkageAttr);
  mlir::SymbolTable::setSymbolVisibility(
      newFunc, mlir::SymbolTable::Visibility::Private);
  assert(!cir::MissingFeatures::opFuncExtraAttrs());
  return newFunc;
}
/// Builds, right before `op`, a complex value whose real part is the cast's
/// source operand and whose imaginary part is the null value of the source
/// operand's type.
static mlir::Value lowerScalarToComplexCast(mlir::MLIRContext &ctx,
                                            cir::CastOp op) {
  cir::CIRBaseBuilderTy builder(ctx);
  builder.setInsertionPoint(op);

  const mlir::Location loc = op.getLoc();
  mlir::Value realPart = op.getSrc();
  mlir::Value imagPart = builder.getNullValue(realPart.getType(), loc);
  return builder.createComplexCreate(loc, realPart, imagPart);
}
/// Lowers a cast from a complex value to a scalar, emitting the new
/// operations right before `op`.
///
/// When the destination type is not cir::BoolType the result is simply the
/// real part of the source. For a boolean destination both parts are cast to
/// bool with `elemToBoolKind` and combined with a logical or.
static mlir::Value lowerComplexToScalarCast(mlir::MLIRContext &ctx,
                                            cir::CastOp op,
                                            cir::CastKind elemToBoolKind) {
  cir::CIRBaseBuilderTy builder(ctx);
  builder.setInsertionPoint(op);

  const mlir::Location loc = op.getLoc();
  mlir::Value complexSrc = op.getSrc();

  // The real part is needed on both paths and is always the first op emitted.
  mlir::Value realPart = builder.createComplexReal(loc, complexSrc);
  const bool toBool = mlir::isa<cir::BoolType>(op.getType());
  if (!toBool)
    return realPart;

  // Complex cast to bool: (bool)(a+bi) => (bool)a || (bool)b
  mlir::Value imagPart = builder.createComplexImag(loc, complexSrc);
  cir::BoolType boolTy = builder.getBoolTy();
  mlir::Value realAsBool =
      builder.createCast(loc, elemToBoolKind, realPart, boolTy);
  mlir::Value imagAsBool =
      builder.createCast(loc, elemToBoolKind, imagPart, boolTy);
  return builder.createLogicalOr(loc, realAsBool, imagAsBool);
}
/// Lowers a complex-to-complex cast by casting the real and the imaginary
/// part separately with `scalarCastKind` to the destination element type and
/// re-assembling a complex value. New operations are emitted right before
/// `op`.
static mlir::Value lowerComplexToComplexCast(mlir::MLIRContext &ctx,
                                             cir::CastOp op,
                                             cir::CastKind scalarCastKind) {
  CIRBaseBuilderTy builder(ctx);
  builder.setInsertionPoint(op);

  const mlir::Location loc = op.getLoc();
  mlir::Value complexSrc = op.getSrc();
  mlir::Type targetElemTy =
      mlir::cast<cir::ComplexType>(op.getType()).getElementType();

  // Casts one scalar component to the destination element type.
  auto castComponent = [&](mlir::Value component) -> mlir::Value {
    return builder.createCast(loc, scalarCastKind, component, targetElemTy);
  };

  // Extract both parts first, then cast them, real part before imaginary.
  mlir::Value realPart = builder.createComplexReal(loc, complexSrc);
  mlir::Value imagPart = builder.createComplexImag(loc, complexSrc);
  mlir::Value castedReal = castComponent(realPart);
  mlir::Value castedImag = castComponent(imagPart);
  return builder.createComplexCreate(loc, castedReal, castedImag);
}
/// Rewrites the complex-related cast kinds into operations on the individual
/// components. Cast kinds not listed below are left untouched. When a
/// replacement value was produced, all uses of `op` are redirected to it and
/// `op` is erased.
void LoweringPreparePass::lowerCastOp(cir::CastOp op) {
  mlir::MLIRContext &ctx = getContext();
  const cir::CastKind kind = op.getKind();

  mlir::Value replacement;
  switch (kind) {
  // scalar -> complex
  case cir::CastKind::float_to_complex:
  case cir::CastKind::int_to_complex:
    replacement = lowerScalarToComplexCast(ctx, op);
    break;

  // complex -> scalar
  case cir::CastKind::float_complex_to_real:
  case cir::CastKind::int_complex_to_real:
    replacement = lowerComplexToScalarCast(ctx, op, kind);
    break;
  case cir::CastKind::float_complex_to_bool:
    replacement =
        lowerComplexToScalarCast(ctx, op, cir::CastKind::float_to_bool);
    break;
  case cir::CastKind::int_complex_to_bool:
    replacement = lowerComplexToScalarCast(ctx, op, cir::CastKind::int_to_bool);
    break;

  // complex -> complex
  case cir::CastKind::float_complex:
    replacement = lowerComplexToComplexCast(ctx, op, cir::CastKind::floating);
    break;
  case cir::CastKind::float_complex_to_int_complex:
    replacement =
        lowerComplexToComplexCast(ctx, op, cir::CastKind::float_to_int);
    break;
  case cir::CastKind::int_complex:
    replacement = lowerComplexToComplexCast(ctx, op, cir::CastKind::integral);
    break;
  case cir::CastKind::int_complex_to_float_complex:
    replacement =
        lowerComplexToComplexCast(ctx, op, cir::CastKind::int_to_float);
    break;

  // Everything else is not handled by this pass.
  default:
    return;
  }

  if (!replacement)
    return;
  op.replaceAllUsesWith(replacement);
  op.erase();
}
/// Emits a call to the runtime library routine implementing a complex binary
/// operation and returns the call's result.
///
/// The routine's name is obtained from `libFuncNameGetter` using the float
/// semantics of `ty`'s element type. Its signature takes four values of the
/// element type (lhs real, lhs imag, rhs real, rhs imag) and returns `ty`.
/// The routine is looked up, or declared at the start of the module body when
/// missing, through `pass.buildRuntimeFunction`.
static mlir::Value buildComplexBinOpLibCall(
    LoweringPreparePass &pass, CIRBaseBuilderTy &builder,
    llvm::StringRef (*libFuncNameGetter)(llvm::APFloat::Semantics),
    mlir::Location loc, cir::ComplexType ty, mlir::Value lhsReal,
    mlir::Value lhsImag, mlir::Value rhsReal, mlir::Value rhsImag) {
  auto fpElemTy = mlir::cast<cir::FPTypeInterface>(ty.getElementType());
  const llvm::APFloat::Semantics semantics =
      llvm::APFloat::SemanticsToEnum(fpElemTy.getFloatSemantics());
  llvm::StringRef calleeName = libFuncNameGetter(semantics);

  llvm::SmallVector<mlir::Type, 4> paramTypes(4, fpElemTy);
  cir::FuncType calleeTy = cir::FuncType::get(paramTypes, ty);

  // Insert a declaration for the runtime function to be used in Complex
  // multiplication and division when needed. The guard restores the caller's
  // insertion point once the declaration has been handled.
  cir::FuncOp callee = [&]() -> cir::FuncOp {
    mlir::OpBuilder::InsertionGuard restoreInsertionPoint{builder};
    builder.setInsertionPointToStart(pass.mlirModule.getBody());
    return pass.buildRuntimeFunction(builder, calleeName, loc, calleeTy);
  }();

  cir::CallOp libCall =
      builder.createCallOp(loc, callee, {lhsReal, lhsImag, rhsReal, rhsImag});
  return libCall.getResult();
}
/// Maps floating-point semantics to the name of the runtime library routine
/// used for complex multiplication of that element type. Any semantics not
/// listed here is treated as unreachable.
static llvm::StringRef
getComplexMulLibCallName(llvm::APFloat::Semantics semantics) {
  switch (semantics) {
  case llvm::APFloat::S_IEEEhalf:
    return "__mulhc3";
  case llvm::APFloat::S_IEEEsingle:
    return "__mulsc3";
  case llvm::APFloat::S_IEEEdouble:
    return "__muldc3";
  case llvm::APFloat::S_x87DoubleExtended:
    return "__mulxc3";
  // Both of these formats map to the same routine name.
  case llvm::APFloat::S_PPCDoubleDouble:
  case llvm::APFloat::S_IEEEquad:
    return "__multc3";
  default:
    llvm_unreachable("unsupported floating point type");
  }
}
/// Emits the component-wise expansion of a complex multiplication and returns
/// the resulting complex value.
///
/// The algebraic form is always emitted. It is returned directly when the
/// element type is an integer type or when the op's range kind is Basic,
/// Improved or Promoted. Otherwise a cir::TernaryOp is emitted that yields
/// the result of the runtime library call when both parts of the algebraic
/// result are NaN, and the algebraic result otherwise.
static mlir::Value lowerComplexMul(LoweringPreparePass &pass,
                                   CIRBaseBuilderTy &builder,
                                   mlir::Location loc, cir::ComplexMulOp op,
                                   mlir::Value lhsReal, mlir::Value lhsImag,
                                   mlir::Value rhsReal, mlir::Value rhsImag) {
  auto emitBinop = [&](mlir::Value lhs, cir::BinOpKind kind,
                       mlir::Value rhs) -> mlir::Value {
    return builder.createBinop(loc, lhs, kind, rhs);
  };

  // (a+bi) * (c+di) = (ac-bd) + (ad+bc)i
  mlir::Value ac = emitBinop(lhsReal, cir::BinOpKind::Mul, rhsReal);
  mlir::Value bd = emitBinop(lhsImag, cir::BinOpKind::Mul, rhsImag);
  mlir::Value ad = emitBinop(lhsReal, cir::BinOpKind::Mul, rhsImag);
  mlir::Value bc = emitBinop(lhsImag, cir::BinOpKind::Mul, rhsReal);
  mlir::Value realPart = emitBinop(ac, cir::BinOpKind::Sub, bd);
  mlir::Value imagPart = emitBinop(ad, cir::BinOpKind::Add, bc);
  mlir::Value algebraicProduct =
      builder.createComplexCreate(loc, realPart, imagPart);

  cir::ComplexType complexTy = op.getType();
  const cir::ComplexRangeKind rangeKind = op.getRange();
  const bool isIntegerComplex =
      mlir::isa<cir::IntType>(complexTy.getElementType());
  const bool needsNaNFallback =
      !isIntegerComplex && rangeKind != cir::ComplexRangeKind::Basic &&
      rangeKind != cir::ComplexRangeKind::Improved &&
      rangeKind != cir::ComplexRangeKind::Promoted;
  if (!needsNaNFallback)
    return algebraicProduct;

  assert(!cir::MissingFeatures::fastMathFlags());

  // Only when both the real and the imaginary part of the algebraic result
  // are NaN is the multiplication recomputed through a library call.
  mlir::Value realIsNaN = builder.createIsNaN(loc, realPart);
  mlir::Value imagIsNaN = builder.createIsNaN(loc, imagPart);
  mlir::Value bothAreNaN = builder.createLogicalAnd(loc, realIsNaN, imagIsNaN);

  // Both region builders deliberately use the outer `builder` rather than the
  // OpBuilder argument they receive.
  auto emitLibCallBranch = [&](mlir::OpBuilder &, mlir::Location) {
    mlir::Value libCallProduct = buildComplexBinOpLibCall(
        pass, builder, &getComplexMulLibCallName, loc, complexTy, lhsReal,
        lhsImag, rhsReal, rhsImag);
    builder.createYield(loc, libCallProduct);
  };
  auto emitAlgebraicBranch = [&](mlir::OpBuilder &, mlir::Location) {
    builder.createYield(loc, algebraicProduct);
  };

  auto selectOp = builder.create<cir::TernaryOp>(
      loc, bothAreNaN, emitLibCallBranch, emitAlgebraicBranch);
  return selectOp.getResult();
}
/// Replaces a complex multiplication with its component-wise expansion. The
/// new operations are emitted right after `op`; afterwards all uses of `op`
/// are redirected to the expansion and `op` is erased.
void LoweringPreparePass::lowerComplexMulOp(cir::ComplexMulOp op) {
  cir::CIRBaseBuilderTy builder(getContext());
  builder.setInsertionPointAfter(op);
  const mlir::Location loc = op.getLoc();

  // Split both operands into their components, left operand first.
  mlir::TypedValue<cir::ComplexType> lhsOperand = op.getLhs();
  mlir::TypedValue<cir::ComplexType> rhsOperand = op.getRhs();
  mlir::Value a = builder.createComplexReal(loc, lhsOperand);
  mlir::Value b = builder.createComplexImag(loc, lhsOperand);
  mlir::Value c = builder.createComplexReal(loc, rhsOperand);
  mlir::Value d = builder.createComplexImag(loc, rhsOperand);

  mlir::Value product = lowerComplexMul(*this, builder, loc, op, a, b, c, d);
  op.replaceAllUsesWith(product);
  op.erase();
}
/// Expands a unary operation on a complex value into operations on its real
/// and imaginary parts. Unary operations whose result type is not a complex
/// type are left untouched.
///
///   Inc / Dec    : applied to the real part only.
///   Plus / Minus : applied to both parts.
///   Not          : real part kept, imaginary part negated.
void LoweringPreparePass::lowerUnaryOp(cir::UnaryOp op) {
  if (!mlir::isa<cir::ComplexType>(op.getType()))
    return;

  const mlir::Location loc = op.getLoc();
  const cir::UnaryOpKind kind = op.getKind();

  CIRBaseBuilderTy builder(getContext());
  builder.setInsertionPointAfter(op);

  mlir::Value input = op.getInput();
  mlir::Value inputReal = builder.createComplexReal(loc, input);
  mlir::Value inputImag = builder.createComplexImag(loc, input);

  mlir::Value newReal;
  mlir::Value newImag;
  switch (kind) {
  case cir::UnaryOpKind::Inc:
  case cir::UnaryOpKind::Dec: {
    newReal = builder.createUnaryOp(loc, kind, inputReal);
    newImag = inputImag;
    break;
  }
  case cir::UnaryOpKind::Plus:
  case cir::UnaryOpKind::Minus: {
    newReal = builder.createUnaryOp(loc, kind, inputReal);
    newImag = builder.createUnaryOp(loc, kind, inputImag);
    break;
  }
  case cir::UnaryOpKind::Not: {
    newReal = inputReal;
    newImag = builder.createUnaryOp(loc, cir::UnaryOpKind::Minus, inputImag);
    break;
  }
  }

  mlir::Value replacement = builder.createComplexCreate(loc, newReal, newImag);
  op.replaceAllUsesWith(replacement);
  op.erase();
}
/// Expands an array constructor/destructor operation into an explicit
/// cir::DoWhileOp that invokes the element ctor/dtor call found inside `op`
/// once per array element, then erases `op`.
///
/// \param builder   Builder positioned where the loop is to be emitted.
/// \param astCtx    AST context used to query the bit width of the signed
///                  size type; must not be null.
/// \param op        The array ctor/dtor operation being lowered. The last
///                  cir::CallOp found while walking it is moved into the loop
///                  body.
/// \param eltTy     Type of the element pointer values the loop iterates with.
/// \param arrayAddr Address of the array.
/// \param arrayLen  Number of array elements.
/// \param isCtor    True to iterate forward (construction), false to iterate
///                  backward (destruction).
///
/// Elements are constructed in order and destructed in reverse. The forward
/// loop calls the ctor on the current element and then advances. The backward
/// loop starts one past the last element and steps back *before* calling the
/// dtor, so that the first element is destroyed too and the loop terminates
/// exactly at the array start (including for single-element arrays).
static void lowerArrayDtorCtorIntoLoop(cir::CIRBaseBuilderTy &builder,
                                       clang::ASTContext *astCtx,
                                       mlir::Operation *op, mlir::Type eltTy,
                                       mlir::Value arrayAddr, uint64_t arrayLen,
                                       bool isCtor) {
  assert(astCtx && "ASTContext is required to lower array ctor/dtor ops");

  // A do-while body always executes at least once, so an empty array must
  // never reach the loop: there is nothing to construct or destroy.
  if (arrayLen == 0) {
    op->erase();
    return;
  }

  // Generate loop to call into ctor/dtor for every element.
  mlir::Location loc = op->getLoc();

  // TODO: instead of getting the size from the AST context, create alias for
  // PtrDiffTy and unify with CIRGen stuff.
  const unsigned sizeTypeSize =
      astCtx->getTypeSize(astCtx->getSignedSizeType());

  // `end` is the one-past-the-end element pointer for both directions.
  mlir::Value endOffsetVal =
      builder.getUnsignedInt(loc, arrayLen, sizeTypeSize);

  auto begin = cir::CastOp::create(builder, loc, eltTy,
                                   cir::CastKind::array_to_ptrdecay, arrayAddr);
  mlir::Value end =
      cir::PtrStrideOp::create(builder, loc, eltTy, begin, endOffsetVal);
  mlir::Value start = isCtor ? begin : end;
  mlir::Value stop = isCtor ? end : begin;

  // Temporary that holds the loop's current element pointer.
  mlir::Value tmpAddr = builder.createAlloca(
      loc, /*addr type*/ builder.getPointerTo(eltTy),
      /*var type*/ eltTy, "__array_idx", builder.getAlignmentAttr(1));
  builder.createStore(loc, start, tmpAddr);

  cir::DoWhileOp loop = builder.createDoWhile(
      loc,
      /*condBuilder=*/
      [&](mlir::OpBuilder &b, mlir::Location loc) {
        // Keep iterating until the stored pointer reaches `stop`.
        auto currentElement = b.create<cir::LoadOp>(loc, eltTy, tmpAddr);
        mlir::Type boolTy = cir::BoolType::get(b.getContext());
        auto cmp = builder.create<cir::CmpOp>(loc, boolTy, cir::CmpOpKind::ne,
                                              currentElement, stop);
        builder.createCondition(cmp);
      },
      /*bodyBuilder=*/
      [&](mlir::OpBuilder &b, mlir::Location loc) {
        auto currentElement = b.create<cir::LoadOp>(loc, eltTy, tmpAddr);

        // The ctor/dtor call lives inside `op`; it is moved into the loop body
        // and retargeted at the element handled by this iteration.
        cir::CallOp elementCall;
        op->walk([&](cir::CallOp c) { elementCall = c; });
        assert(elementCall && "expected ctor call");

        // Array elements get constructed in order but destructed in reverse.
        mlir::Value elementToStore;
        if (isCtor) {
          // Construct the current element, then advance by one.
          mlir::Value stride = builder.getUnsignedInt(loc, 1, sizeTypeSize);
          elementCall->moveBefore(stride.getDefiningOp());
          elementCall->setOperand(0, currentElement);
          elementToStore = cir::PtrStrideOp::create(builder, loc, eltTy,
                                                    currentElement, stride);
        } else {
          // Step back by one first, then destroy the element stepped onto.
          mlir::Value stride = builder.getSignedInt(loc, -1, sizeTypeSize);
          elementToStore = cir::PtrStrideOp::create(builder, loc, eltTy,
                                                    currentElement, stride);
          elementCall->moveAfter(elementToStore.getDefiningOp());
          elementCall->setOperand(0, elementToStore);
        }

        // Store the element pointer to the temporary variable
        builder.createStore(loc, elementToStore, tmpAddr);
        builder.createYield(loc);
      });

  op->replaceAllUsesWith(loop);
  op->erase();
}
/// Lowers an array destructor op into an element-wise loop emitted right
/// after it. The element pointer type is taken from the first argument of the
/// op's first region and the element count from the pointed-to array type.
void LoweringPreparePass::lowerArrayDtor(cir::ArrayDtor op) {
  assert(!cir::MissingFeatures::vlas());

  CIRBaseBuilderTy builder(getContext());
  builder.setInsertionPointAfter(op.getOperation());

  auto arrayAddr = op.getAddr();
  auto arrayTy = mlir::cast<cir::ArrayType>(arrayAddr.getType().getPointee());
  mlir::Type elementPtrTy = op->getRegion(0).getArgument(0).getType();

  lowerArrayDtorCtorIntoLoop(builder, astCtx, op, elementPtrTy, arrayAddr,
                             arrayTy.getSize(), /*isCtor=*/false);
}
/// Lowers an array constructor op into an element-wise loop emitted right
/// after it. The element pointer type is taken from the first argument of the
/// op's first region and the element count from the pointed-to array type.
void LoweringPreparePass::lowerArrayCtor(cir::ArrayCtor op) {
  assert(!cir::MissingFeatures::vlas());

  cir::CIRBaseBuilderTy builder(getContext());
  builder.setInsertionPointAfter(op.getOperation());

  auto arrayAddr = op.getAddr();
  auto arrayTy = mlir::cast<cir::ArrayType>(arrayAddr.getType().getPointee());
  mlir::Type elementPtrTy = op->getRegion(0).getArgument(0).getType();

  lowerArrayDtorCtorIntoLoop(builder, astCtx, op, elementPtrTy, arrayAddr,
                             arrayTy.getSize(), /*isCtor=*/true);
}
/// Routes `op` to the lowering routine for its concrete operation type.
/// Operations of any other type are ignored.
void LoweringPreparePass::runOnOp(mlir::Operation *op) {
  if (auto arrayCtorOp = mlir::dyn_cast<cir::ArrayCtor>(op))
    return lowerArrayCtor(arrayCtorOp);
  if (auto arrayDtorOp = mlir::dyn_cast<cir::ArrayDtor>(op))
    return lowerArrayDtor(arrayDtorOp);
  if (auto castOp = mlir::dyn_cast<cir::CastOp>(op))
    return lowerCastOp(castOp);
  if (auto complexMulOp = mlir::dyn_cast<cir::ComplexMulOp>(op))
    return lowerComplexMulOp(complexMulOp);
  if (auto unaryOp = mlir::dyn_cast<cir::UnaryOp>(op))
    return lowerUnaryOp(unaryOp);
}
void LoweringPreparePass::runOnOperation() {
mlir::Operation *op = getOperation();
if (isa<::mlir::ModuleOp>(op))
mlirModule = cast<::mlir::ModuleOp>(op);
llvm::SmallVector<mlir::Operation *> opsToTransform;
op->walk([&](mlir::Operation *op) {
if (mlir::isa<cir::ArrayCtor, cir::ArrayDtor, cir::CastOp,
cir::ComplexMulOp, cir::UnaryOp>(op))
opsToTransform.push_back(op);
});
for (mlir::Operation *o : opsToTransform)
runOnOp(o);
}
std::unique_ptr<Pass> mlir::createLoweringPreparePass() {
return std::make_unique<LoweringPreparePass>();
}
/// Creates the pass and hands it the AST context it should use.
std::unique_ptr<Pass>
mlir::createLoweringPreparePass(clang::ASTContext *astCtx) {
  auto loweringPrepare = std::make_unique<LoweringPreparePass>();
  loweringPrepare->setASTContext(astCtx);
  // A local returned by name is implicitly moved, including through the
  // converting unique_ptr constructor.
  return loweringPrepare;
}