//===-- CodeGen.cpp -- bridge to lower to LLVM ----------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/ // //===----------------------------------------------------------------------===// #include "flang/Optimizer/CodeGen/CodeGen.h" #include "PassDetail.h" #include "flang/Optimizer/Dialect/FIROps.h" #include "flang/Optimizer/Dialect/FIRType.h" #include "flang/Optimizer/Support/FIRContext.h" #include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h" #include "mlir/Conversion/LLVMCommon/Pattern.h" #include "mlir/Conversion/LLVMCommon/TypeConverter.h" #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Matchers.h" #include "mlir/Pass/Pass.h" #include "llvm/ADT/ArrayRef.h" #define DEBUG_TYPE "flang-codegen" // fir::LLVMTypeConverter for converting to LLVM IR dialect types. 
#include "TypeConverter.h"

namespace {
/// Base class for FIR-to-LLVM conversion patterns.
///
/// Wraps `mlir::ConvertOpToLLVMPattern<FromOp>` and gives subclasses access
/// to the FIR-specific type converter the pattern was constructed with.
template <typename FromOp>
class FIROpConversion : public mlir::ConvertOpToLLVMPattern<FromOp> {
public:
  explicit FIROpConversion(fir::LLVMTypeConverter &lowering)
      : mlir::ConvertOpToLLVMPattern<FromOp>(lowering) {}

protected:
  /// Convert \p ty to its LLVM dialect counterpart using the FIR converter.
  mlir::Type convertType(mlir::Type ty) const {
    return lowerTy().convertType(ty);
  }

  /// Return the type converter of this pattern as a fir::LLVMTypeConverter.
  fir::LLVMTypeConverter &lowerTy() const {
    // The only constructor takes a fir::LLVMTypeConverter, so the stored
    // converter always has that dynamic type.
    return *static_cast<fir::LLVMTypeConverter *>(this->getTypeConverter());
  }
};

/// FIR conversion pattern that converts the op's result type up front.
///
/// `matchAndRewrite` converts `op.getType()` and forwards the converted type
/// to `doRewrite`, which subclasses implement instead of `matchAndRewrite`.
template <typename FromOp>
class FIROpAndTypeConversion : public FIROpConversion<FromOp> {
public:
  using FIROpConversion<FromOp>::FIROpConversion;
  using OpAdaptor = typename FromOp::Adaptor;

  mlir::LogicalResult
  matchAndRewrite(FromOp op, OpAdaptor adaptor,
                  mlir::ConversionPatternRewriter &rewriter) const final {
    mlir::Type ty = this->convertType(op.getType());
    return doRewrite(op, ty, adaptor, rewriter);
  }

  /// Rewrite \p op, where \p ty is the LLVM dialect type its result type
  /// converts to.
  virtual mlir::LogicalResult
  doRewrite(FromOp op, mlir::Type ty, OpAdaptor adaptor,
            mlir::ConversionPatternRewriter &rewriter) const = 0;
};

// Lower `fir.address_of` operation to `llvm.mlir.addressof` operation.
struct AddrOfOpConversion : public FIROpConversion { using FIROpConversion::FIROpConversion; mlir::LogicalResult matchAndRewrite(fir::AddrOfOp addr, OpAdaptor adaptor, mlir::ConversionPatternRewriter &rewriter) const override { auto ty = convertType(addr.getType()); rewriter.replaceOpWithNewOp( addr, ty, addr.symbol().getRootReference().getValue()); return success(); } }; // `fir.call` -> `llvm.call` struct CallOpConversion : public FIROpConversion { using FIROpConversion::FIROpConversion; mlir::LogicalResult matchAndRewrite(fir::CallOp call, OpAdaptor adaptor, mlir::ConversionPatternRewriter &rewriter) const override { SmallVector resultTys; for (auto r : call.getResults()) resultTys.push_back(convertType(r.getType())); rewriter.replaceOpWithNewOp( call, resultTys, adaptor.getOperands(), call->getAttrs()); return success(); } }; static mlir::Type getComplexEleTy(mlir::Type complex) { if (auto cc = complex.dyn_cast()) return cc.getElementType(); return complex.cast().getElementType(); } /// convert value of from-type to value of to-type struct ConvertOpConversion : public FIROpConversion { using FIROpConversion::FIROpConversion; static bool isFloatingPointTy(mlir::Type ty) { return ty.isa(); } mlir::LogicalResult matchAndRewrite(fir::ConvertOp convert, OpAdaptor adaptor, mlir::ConversionPatternRewriter &rewriter) const override { auto fromTy = convertType(convert.value().getType()); auto toTy = convertType(convert.res().getType()); mlir::Value op0 = adaptor.getOperands()[0]; if (fromTy == toTy) { rewriter.replaceOp(convert, op0); return success(); } auto loc = convert.getLoc(); auto convertFpToFp = [&](mlir::Value val, unsigned fromBits, unsigned toBits, mlir::Type toTy) -> mlir::Value { if (fromBits == toBits) { // TODO: Converting between two floating-point representations with the // same bitwidth is not allowed for now. 
mlir::emitError(loc, "cannot implicitly convert between two floating-point " "representations of the same bitwidth"); return {}; } if (fromBits > toBits) return rewriter.create(loc, toTy, val); return rewriter.create(loc, toTy, val); }; // Complex to complex conversion. if (fir::isa_complex(convert.value().getType()) && fir::isa_complex(convert.res().getType())) { // Special case: handle the conversion of a complex such that both the // real and imaginary parts are converted together. auto zero = mlir::ArrayAttr::get(convert.getContext(), rewriter.getI32IntegerAttr(0)); auto one = mlir::ArrayAttr::get(convert.getContext(), rewriter.getI32IntegerAttr(1)); auto ty = convertType(getComplexEleTy(convert.value().getType())); auto rp = rewriter.create(loc, ty, op0, zero); auto ip = rewriter.create(loc, ty, op0, one); auto nt = convertType(getComplexEleTy(convert.res().getType())); auto fromBits = mlir::LLVM::getPrimitiveTypeSizeInBits(ty); auto toBits = mlir::LLVM::getPrimitiveTypeSizeInBits(nt); auto rc = convertFpToFp(rp, fromBits, toBits, nt); auto ic = convertFpToFp(ip, fromBits, toBits, nt); auto un = rewriter.create(loc, toTy); auto i1 = rewriter.create(loc, toTy, un, rc, zero); rewriter.replaceOpWithNewOp(convert, toTy, i1, ic, one); return mlir::success(); } // Floating point to floating point conversion. if (isFloatingPointTy(fromTy)) { if (isFloatingPointTy(toTy)) { auto fromBits = mlir::LLVM::getPrimitiveTypeSizeInBits(fromTy); auto toBits = mlir::LLVM::getPrimitiveTypeSizeInBits(toTy); auto v = convertFpToFp(op0, fromBits, toBits, toTy); rewriter.replaceOp(convert, v); return mlir::success(); } if (toTy.isa()) { rewriter.replaceOpWithNewOp(convert, toTy, op0); return mlir::success(); } } else if (fromTy.isa()) { // Integer to integer conversion. 
if (toTy.isa()) { auto fromBits = mlir::LLVM::getPrimitiveTypeSizeInBits(fromTy); auto toBits = mlir::LLVM::getPrimitiveTypeSizeInBits(toTy); assert(fromBits != toBits); if (fromBits > toBits) { rewriter.replaceOpWithNewOp(convert, toTy, op0); return mlir::success(); } rewriter.replaceOpWithNewOp(convert, toTy, op0); return mlir::success(); } // Integer to floating point conversion. if (isFloatingPointTy(toTy)) { rewriter.replaceOpWithNewOp(convert, toTy, op0); return mlir::success(); } // Integer to pointer conversion. if (toTy.isa()) { rewriter.replaceOpWithNewOp(convert, toTy, op0); return mlir::success(); } } else if (fromTy.isa()) { // Pointer to integer conversion. if (toTy.isa()) { rewriter.replaceOpWithNewOp(convert, toTy, op0); return mlir::success(); } // Pointer to pointer conversion. if (toTy.isa()) { rewriter.replaceOpWithNewOp(convert, toTy, op0); return mlir::success(); } } return emitError(loc) << "cannot convert " << fromTy << " to " << toTy; } }; /// Lower `fir.has_value` operation to `llvm.return` operation. struct HasValueOpConversion : public FIROpConversion { using FIROpConversion::FIROpConversion; mlir::LogicalResult matchAndRewrite(fir::HasValueOp op, OpAdaptor adaptor, mlir::ConversionPatternRewriter &rewriter) const override { rewriter.replaceOpWithNewOp(op, adaptor.getOperands()); return success(); } }; /// Lower `fir.global` operation to `llvm.global` operation. /// `fir.insert_on_range` operations are replaced with constant dense attribute /// if they are applied on the full range. 
struct GlobalOpConversion : public FIROpConversion { using FIROpConversion::FIROpConversion; mlir::LogicalResult matchAndRewrite(fir::GlobalOp global, OpAdaptor adaptor, mlir::ConversionPatternRewriter &rewriter) const override { auto tyAttr = convertType(global.getType()); if (global.getType().isa()) tyAttr = tyAttr.cast().getElementType(); auto loc = global.getLoc(); mlir::Attribute initAttr{}; if (global.initVal()) initAttr = global.initVal().getValue(); auto linkage = convertLinkage(global.linkName()); auto isConst = global.constant().hasValue(); auto g = rewriter.create( loc, tyAttr, isConst, linkage, global.sym_name(), initAttr); auto &gr = g.getInitializerRegion(); rewriter.inlineRegionBefore(global.region(), gr, gr.end()); if (!gr.empty()) { // Replace insert_on_range with a constant dense attribute if the // initialization is on the full range. auto insertOnRangeOps = gr.front().getOps(); for (auto insertOp : insertOnRangeOps) { if (isFullRange(insertOp.coor(), insertOp.getType())) { auto seqTyAttr = convertType(insertOp.getType()); auto *op = insertOp.val().getDefiningOp(); auto constant = mlir::dyn_cast(op); if (!constant) { auto convertOp = mlir::dyn_cast(op); if (!convertOp) continue; constant = cast( convertOp.value().getDefiningOp()); } mlir::Type vecType = mlir::VectorType::get( insertOp.getType().getShape(), constant.getType()); auto denseAttr = mlir::DenseElementsAttr::get( vecType.cast(), constant.value()); rewriter.setInsertionPointAfter(insertOp); rewriter.replaceOpWithNewOp( insertOp, seqTyAttr, denseAttr); } } } rewriter.eraseOp(global); return success(); } bool isFullRange(mlir::ArrayAttr indexes, fir::SequenceType seqTy) const { auto extents = seqTy.getShape(); if (indexes.size() / 2 != extents.size()) return false; for (unsigned i = 0; i < indexes.size(); i += 2) { if (indexes[i].cast().getInt() != 0) return false; if (indexes[i + 1].cast().getInt() != extents[i / 2] - 1) return false; } return true; } // TODO: String comparaison should be 
avoided. Replace linkName with an // enumeration. mlir::LLVM::Linkage convertLinkage(Optional optLinkage) const { if (optLinkage.hasValue()) { auto name = optLinkage.getValue(); if (name == "internal") return mlir::LLVM::Linkage::Internal; if (name == "linkonce") return mlir::LLVM::Linkage::Linkonce; if (name == "common") return mlir::LLVM::Linkage::Common; if (name == "weak") return mlir::LLVM::Linkage::Weak; } return mlir::LLVM::Linkage::External; } }; template void selectMatchAndRewrite(fir::LLVMTypeConverter &lowering, OP select, typename OP::Adaptor adaptor, mlir::ConversionPatternRewriter &rewriter) { unsigned conds = select.getNumConditions(); auto cases = select.getCases().getValue(); mlir::Value selector = adaptor.selector(); auto loc = select.getLoc(); assert(conds > 0 && "select must have cases"); llvm::SmallVector destinations; llvm::SmallVector destinationsOperands; mlir::Block *defaultDestination; mlir::ValueRange defaultOperands; llvm::SmallVector caseValues; for (unsigned t = 0; t != conds; ++t) { mlir::Block *dest = select.getSuccessor(t); auto destOps = select.getSuccessorOperands(adaptor.getOperands(), t); const mlir::Attribute &attr = cases[t]; if (auto intAttr = attr.template dyn_cast()) { destinations.push_back(dest); destinationsOperands.push_back(destOps.hasValue() ? *destOps : ValueRange()); caseValues.push_back(intAttr.getInt()); continue; } assert(attr.template dyn_cast_or_null()); assert((t + 1 == conds) && "unit must be last"); defaultDestination = dest; defaultOperands = destOps.hasValue() ? *destOps : ValueRange(); } // LLVM::SwitchOp takes a i32 type for the selector. 
if (select.getSelector().getType() != rewriter.getI32Type()) selector = rewriter.create(loc, rewriter.getI32Type(), selector); rewriter.replaceOpWithNewOp( select, selector, /*defaultDestination=*/defaultDestination, /*defaultOperands=*/defaultOperands, /*caseValues=*/caseValues, /*caseDestinations=*/destinations, /*caseOperands=*/destinationsOperands, /*branchWeights=*/ArrayRef()); } /// conversion of fir::SelectOp to an if-then-else ladder struct SelectOpConversion : public FIROpConversion { using FIROpConversion::FIROpConversion; mlir::LogicalResult matchAndRewrite(fir::SelectOp op, OpAdaptor adaptor, mlir::ConversionPatternRewriter &rewriter) const override { selectMatchAndRewrite(lowerTy(), op, adaptor, rewriter); return success(); } }; /// `fir.load` --> `llvm.load` struct LoadOpConversion : public FIROpConversion { using FIROpConversion::FIROpConversion; mlir::LogicalResult matchAndRewrite(fir::LoadOp load, OpAdaptor adaptor, mlir::ConversionPatternRewriter &rewriter) const override { // fir.box is a special case because it is considered as an ssa values in // fir, but it is lowered as a pointer to a descriptor. So fir.ref // and fir.box end up being the same llvm types and loading a // fir.ref is actually a no op in LLVM. 
if (load.getType().isa()) { rewriter.replaceOp(load, adaptor.getOperands()[0]); } else { mlir::Type ty = convertType(load.getType()); ArrayRef at = load->getAttrs(); rewriter.replaceOpWithNewOp( load, ty, adaptor.getOperands(), at); } return success(); } }; /// conversion of fir::SelectRankOp to an if-then-else ladder struct SelectRankOpConversion : public FIROpConversion { using FIROpConversion::FIROpConversion; mlir::LogicalResult matchAndRewrite(fir::SelectRankOp op, OpAdaptor adaptor, mlir::ConversionPatternRewriter &rewriter) const override { selectMatchAndRewrite(lowerTy(), op, adaptor, rewriter); return success(); } }; /// `fir.store` --> `llvm.store` struct StoreOpConversion : public FIROpConversion { using FIROpConversion::FIROpConversion; mlir::LogicalResult matchAndRewrite(fir::StoreOp store, OpAdaptor adaptor, mlir::ConversionPatternRewriter &rewriter) const override { if (store.value().getType().isa()) { // fir.box value is actually in memory, load it first before storing it. 
mlir::Location loc = store.getLoc(); mlir::Type boxPtrTy = adaptor.getOperands()[0].getType(); auto val = rewriter.create( loc, boxPtrTy.cast().getElementType(), adaptor.getOperands()[0]); rewriter.replaceOpWithNewOp( store, val, adaptor.getOperands()[1]); } else { rewriter.replaceOpWithNewOp( store, adaptor.getOperands()[0], adaptor.getOperands()[1]); } return success(); } }; /// convert to LLVM IR dialect `undef` struct UndefOpConversion : public FIROpConversion { using FIROpConversion::FIROpConversion; mlir::LogicalResult matchAndRewrite(fir::UndefOp undef, OpAdaptor, mlir::ConversionPatternRewriter &rewriter) const override { rewriter.replaceOpWithNewOp( undef, convertType(undef.getType())); return success(); } }; /// `fir.unreachable` --> `llvm.unreachable` struct UnreachableOpConversion : public FIROpConversion { using FIROpConversion::FIROpConversion; mlir::LogicalResult matchAndRewrite(fir::UnreachableOp unreach, OpAdaptor adaptor, mlir::ConversionPatternRewriter &rewriter) const override { rewriter.replaceOpWithNewOp(unreach); return success(); } }; struct ZeroOpConversion : public FIROpConversion { using FIROpConversion::FIROpConversion; mlir::LogicalResult matchAndRewrite(fir::ZeroOp zero, OpAdaptor, mlir::ConversionPatternRewriter &rewriter) const override { auto ty = convertType(zero.getType()); if (ty.isa()) { rewriter.replaceOpWithNewOp(zero, ty); } else if (ty.isa()) { rewriter.replaceOpWithNewOp( zero, ty, mlir::IntegerAttr::get(zero.getType(), 0)); } else if (mlir::LLVM::isCompatibleFloatingPointType(ty)) { rewriter.replaceOpWithNewOp( zero, ty, mlir::FloatAttr::get(zero.getType(), 0.0)); } else { // TODO: create ConstantAggregateZero for FIR aggregate/array types. return rewriter.notifyMatchFailure( zero, "conversion of fir.zero with aggregate type not implemented yet"); } return success(); } }; // Code shared between insert_value and extract_value Ops. 
struct ValueOpCommon { // Translate the arguments pertaining to any multidimensional array to // row-major order for LLVM-IR. static void toRowMajor(SmallVectorImpl &attrs, mlir::Type ty) { assert(ty && "type is null"); const auto end = attrs.size(); for (std::remove_const_t i = 0; i < end; ++i) { if (auto seq = ty.dyn_cast()) { const auto dim = getDimension(seq); if (dim > 1) { auto ub = std::min(i + dim, end); std::reverse(attrs.begin() + i, attrs.begin() + ub); i += dim - 1; } ty = getArrayElementType(seq); } else if (auto st = ty.dyn_cast()) { ty = st.getBody()[attrs[i].cast().getInt()]; } else { llvm_unreachable("index into invalid type"); } } } static llvm::SmallVector collectIndices(mlir::ConversionPatternRewriter &rewriter, mlir::ArrayAttr arrAttr) { llvm::SmallVector attrs; for (auto i = arrAttr.begin(), e = arrAttr.end(); i != e; ++i) { if (i->isa()) { attrs.push_back(*i); } else { auto fieldName = i->cast().getValue(); ++i; auto ty = i->cast().getValue(); auto index = ty.cast().getFieldIndex(fieldName); attrs.push_back(mlir::IntegerAttr::get(rewriter.getI32Type(), index)); } } return attrs; } private: static unsigned getDimension(mlir::LLVM::LLVMArrayType ty) { unsigned result = 1; for (auto eleTy = ty.getElementType().dyn_cast(); eleTy; eleTy = eleTy.getElementType().dyn_cast()) ++result; return result; } static mlir::Type getArrayElementType(mlir::LLVM::LLVMArrayType ty) { auto eleTy = ty.getElementType(); while (auto arrTy = eleTy.dyn_cast()) eleTy = arrTy.getElementType(); return eleTy; } }; /// Extract a subobject value from an ssa-value of aggregate type struct ExtractValueOpConversion : public FIROpAndTypeConversion, public ValueOpCommon { using FIROpAndTypeConversion::FIROpAndTypeConversion; mlir::LogicalResult doRewrite(fir::ExtractValueOp extractVal, mlir::Type ty, OpAdaptor adaptor, mlir::ConversionPatternRewriter &rewriter) const override { auto attrs = collectIndices(rewriter, extractVal.coor()); toRowMajor(attrs, 
adaptor.getOperands()[0].getType()); auto position = mlir::ArrayAttr::get(extractVal.getContext(), attrs); rewriter.replaceOpWithNewOp( extractVal, ty, adaptor.getOperands()[0], position); return success(); } }; /// InsertValue is the generalized instruction for the composition of new /// aggregate type values. struct InsertValueOpConversion : public FIROpAndTypeConversion, public ValueOpCommon { using FIROpAndTypeConversion::FIROpAndTypeConversion; mlir::LogicalResult doRewrite(fir::InsertValueOp insertVal, mlir::Type ty, OpAdaptor adaptor, mlir::ConversionPatternRewriter &rewriter) const override { auto attrs = collectIndices(rewriter, insertVal.coor()); toRowMajor(attrs, adaptor.getOperands()[0].getType()); auto position = mlir::ArrayAttr::get(insertVal.getContext(), attrs); rewriter.replaceOpWithNewOp( insertVal, ty, adaptor.getOperands()[0], adaptor.getOperands()[1], position); return success(); } }; /// InsertOnRange inserts a value into a sequence over a range of offsets. struct InsertOnRangeOpConversion : public FIROpAndTypeConversion { using FIROpAndTypeConversion::FIROpAndTypeConversion; // Increments an array of subscripts in a row major fasion. void incrementSubscripts(const SmallVector &dims, SmallVector &subscripts) const { for (size_t i = dims.size(); i > 0; --i) { if (++subscripts[i - 1] < dims[i - 1]) { return; } subscripts[i - 1] = 0; } } mlir::LogicalResult doRewrite(fir::InsertOnRangeOp range, mlir::Type ty, OpAdaptor adaptor, mlir::ConversionPatternRewriter &rewriter) const override { llvm::SmallVector dims; auto type = adaptor.getOperands()[0].getType(); // Iteratively extract the array dimensions from the type. 
while (auto t = type.dyn_cast()) { dims.push_back(t.getNumElements()); type = t.getElementType(); } SmallVector lBounds; SmallVector uBounds; // Extract integer value from the attribute SmallVector coordinates = llvm::to_vector<4>( llvm::map_range(range.coor(), [](Attribute a) -> int64_t { return a.cast().getInt(); })); // Unzip the upper and lower bound and convert to a row major format. for (auto i = coordinates.rbegin(), e = coordinates.rend(); i != e; ++i) { uBounds.push_back(*i++); lBounds.push_back(*i); } auto &subscripts = lBounds; auto loc = range.getLoc(); mlir::Value lastOp = adaptor.getOperands()[0]; mlir::Value insertVal = adaptor.getOperands()[1]; auto i64Ty = rewriter.getI64Type(); while (subscripts != uBounds) { // Convert uint64_t's to Attribute's. SmallVector subscriptAttrs; for (const auto &subscript : subscripts) subscriptAttrs.push_back(IntegerAttr::get(i64Ty, subscript)); lastOp = rewriter.create( loc, ty, lastOp, insertVal, ArrayAttr::get(range.getContext(), subscriptAttrs)); incrementSubscripts(dims, subscripts); } // Convert uint64_t's to Attribute's. 
SmallVector subscriptAttrs; for (const auto &subscript : subscripts) subscriptAttrs.push_back( IntegerAttr::get(rewriter.getI64Type(), subscript)); mlir::ArrayRef arrayRef(subscriptAttrs); rewriter.replaceOpWithNewOp( range, ty, lastOp, insertVal, ArrayAttr::get(range.getContext(), arrayRef)); return success(); } }; // // Primitive operations on Complex types // /// Generate inline code for complex addition/subtraction template mlir::LLVM::InsertValueOp complexSum(OPTY sumop, mlir::ValueRange opnds, mlir::ConversionPatternRewriter &rewriter, fir::LLVMTypeConverter &lowering) { mlir::Value a = opnds[0]; mlir::Value b = opnds[1]; auto loc = sumop.getLoc(); auto ctx = sumop.getContext(); auto c0 = mlir::ArrayAttr::get(ctx, rewriter.getI32IntegerAttr(0)); auto c1 = mlir::ArrayAttr::get(ctx, rewriter.getI32IntegerAttr(1)); mlir::Type eleTy = lowering.convertType(getComplexEleTy(sumop.getType())); mlir::Type ty = lowering.convertType(sumop.getType()); auto x0 = rewriter.create(loc, eleTy, a, c0); auto y0 = rewriter.create(loc, eleTy, a, c1); auto x1 = rewriter.create(loc, eleTy, b, c0); auto y1 = rewriter.create(loc, eleTy, b, c1); auto rx = rewriter.create(loc, eleTy, x0, x1); auto ry = rewriter.create(loc, eleTy, y0, y1); auto r0 = rewriter.create(loc, ty); auto r1 = rewriter.create(loc, ty, r0, rx, c0); return rewriter.create(loc, ty, r1, ry, c1); } struct AddcOpConversion : public FIROpConversion { using FIROpConversion::FIROpConversion; mlir::LogicalResult matchAndRewrite(fir::AddcOp addc, OpAdaptor adaptor, mlir::ConversionPatternRewriter &rewriter) const override { // given: (x + iy) + (x' + iy') // result: (x + x') + i(y + y') auto r = complexSum(addc, adaptor.getOperands(), rewriter, lowerTy()); rewriter.replaceOp(addc, r.getResult()); return success(); } }; struct SubcOpConversion : public FIROpConversion { using FIROpConversion::FIROpConversion; mlir::LogicalResult matchAndRewrite(fir::SubcOp subc, OpAdaptor adaptor, mlir::ConversionPatternRewriter &rewriter) 
const override { // given: (x + iy) - (x' + iy') // result: (x - x') + i(y - y') auto r = complexSum(subc, adaptor.getOperands(), rewriter, lowerTy()); rewriter.replaceOp(subc, r.getResult()); return success(); } }; /// Inlined complex multiply struct MulcOpConversion : public FIROpConversion { using FIROpConversion::FIROpConversion; mlir::LogicalResult matchAndRewrite(fir::MulcOp mulc, OpAdaptor adaptor, mlir::ConversionPatternRewriter &rewriter) const override { // TODO: Can we use a call to __muldc3 ? // given: (x + iy) * (x' + iy') // result: (xx'-yy')+i(xy'+yx') mlir::Value a = adaptor.getOperands()[0]; mlir::Value b = adaptor.getOperands()[1]; auto loc = mulc.getLoc(); auto *ctx = mulc.getContext(); auto c0 = mlir::ArrayAttr::get(ctx, rewriter.getI32IntegerAttr(0)); auto c1 = mlir::ArrayAttr::get(ctx, rewriter.getI32IntegerAttr(1)); mlir::Type eleTy = convertType(getComplexEleTy(mulc.getType())); mlir::Type ty = convertType(mulc.getType()); auto x0 = rewriter.create(loc, eleTy, a, c0); auto y0 = rewriter.create(loc, eleTy, a, c1); auto x1 = rewriter.create(loc, eleTy, b, c0); auto y1 = rewriter.create(loc, eleTy, b, c1); auto xx = rewriter.create(loc, eleTy, x0, x1); auto yx = rewriter.create(loc, eleTy, y0, x1); auto xy = rewriter.create(loc, eleTy, x0, y1); auto ri = rewriter.create(loc, eleTy, xy, yx); auto yy = rewriter.create(loc, eleTy, y0, y1); auto rr = rewriter.create(loc, eleTy, xx, yy); auto ra = rewriter.create(loc, ty); auto r1 = rewriter.create(loc, ty, ra, rr, c0); auto r0 = rewriter.create(loc, ty, r1, ri, c1); rewriter.replaceOp(mulc, r0.getResult()); return success(); } }; /// Inlined complex division struct DivcOpConversion : public FIROpConversion { using FIROpConversion::FIROpConversion; mlir::LogicalResult matchAndRewrite(fir::DivcOp divc, OpAdaptor adaptor, mlir::ConversionPatternRewriter &rewriter) const override { // TODO: Can we use a call to __divdc3 instead? // Just generate inline code for now. 
// given: (x + iy) / (x' + iy') // result: ((xx'+yy')/d) + i((yx'-xy')/d) where d = x'x' + y'y' mlir::Value a = adaptor.getOperands()[0]; mlir::Value b = adaptor.getOperands()[1]; auto loc = divc.getLoc(); auto *ctx = divc.getContext(); auto c0 = mlir::ArrayAttr::get(ctx, rewriter.getI32IntegerAttr(0)); auto c1 = mlir::ArrayAttr::get(ctx, rewriter.getI32IntegerAttr(1)); mlir::Type eleTy = convertType(getComplexEleTy(divc.getType())); mlir::Type ty = convertType(divc.getType()); auto x0 = rewriter.create(loc, eleTy, a, c0); auto y0 = rewriter.create(loc, eleTy, a, c1); auto x1 = rewriter.create(loc, eleTy, b, c0); auto y1 = rewriter.create(loc, eleTy, b, c1); auto xx = rewriter.create(loc, eleTy, x0, x1); auto x1x1 = rewriter.create(loc, eleTy, x1, x1); auto yx = rewriter.create(loc, eleTy, y0, x1); auto xy = rewriter.create(loc, eleTy, x0, y1); auto yy = rewriter.create(loc, eleTy, y0, y1); auto y1y1 = rewriter.create(loc, eleTy, y1, y1); auto d = rewriter.create(loc, eleTy, x1x1, y1y1); auto rrn = rewriter.create(loc, eleTy, xx, yy); auto rin = rewriter.create(loc, eleTy, yx, xy); auto rr = rewriter.create(loc, eleTy, rrn, d); auto ri = rewriter.create(loc, eleTy, rin, d); auto ra = rewriter.create(loc, ty); auto r1 = rewriter.create(loc, ty, ra, rr, c0); auto r0 = rewriter.create(loc, ty, r1, ri, c1); rewriter.replaceOp(divc, r0.getResult()); return success(); } }; /// Inlined complex negation struct NegcOpConversion : public FIROpConversion { using FIROpConversion::FIROpConversion; mlir::LogicalResult matchAndRewrite(fir::NegcOp neg, OpAdaptor adaptor, mlir::ConversionPatternRewriter &rewriter) const override { // given: -(x + iy) // result: -x - iy auto *ctxt = neg.getContext(); auto eleTy = convertType(getComplexEleTy(neg.getType())); auto ty = convertType(neg.getType()); auto loc = neg.getLoc(); mlir::Value o0 = adaptor.getOperands()[0]; auto c0 = mlir::ArrayAttr::get(ctxt, rewriter.getI32IntegerAttr(0)); auto c1 = mlir::ArrayAttr::get(ctxt, 
rewriter.getI32IntegerAttr(1)); auto rp = rewriter.create(loc, eleTy, o0, c0); auto ip = rewriter.create(loc, eleTy, o0, c1); auto nrp = rewriter.create(loc, eleTy, rp); auto nip = rewriter.create(loc, eleTy, ip); auto r = rewriter.create(loc, ty, o0, nrp, c0); rewriter.replaceOpWithNewOp(neg, ty, r, nip, c1); return success(); } }; } // namespace namespace { /// Convert FIR dialect to LLVM dialect /// /// This pass lowers all FIR dialect operations to LLVM IR dialect. An /// MLIR pass is used to lower residual Std dialect to LLVM IR dialect. /// /// This pass is not complete yet. We are upstreaming it in small patches. class FIRToLLVMLowering : public fir::FIRToLLVMLoweringBase { public: mlir::ModuleOp getModule() { return getOperation(); } void runOnOperation() override final { auto mod = getModule(); if (!forcedTargetTriple.empty()) { fir::setTargetTriple(mod, forcedTargetTriple); } auto *context = getModule().getContext(); fir::LLVMTypeConverter typeConverter{getModule()}; mlir::OwningRewritePatternList pattern(context); pattern.insert(typeConverter); mlir::populateStdToLLVMConversionPatterns(typeConverter, pattern); mlir::arith::populateArithmeticToLLVMConversionPatterns(typeConverter, pattern); mlir::ConversionTarget target{*context}; target.addLegalDialect(); // required NOPs for applying a full conversion target.addLegalOp(); // apply the patterns if (mlir::failed(mlir::applyFullConversion(getModule(), target, std::move(pattern)))) { signalPassFailure(); } } }; } // namespace std::unique_ptr fir::createFIRToLLVMPass() { return std::make_unique(); }