//===-- CufOpConversion.cpp -----------------------------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "flang/Common/Fortran.h" #include "flang/Optimizer/Builder/Runtime/RTBuilder.h" #include "flang/Optimizer/CodeGen/TypeConverter.h" #include "flang/Optimizer/Dialect/CUF/CUFOps.h" #include "flang/Optimizer/Dialect/FIRDialect.h" #include "flang/Optimizer/Dialect/FIROps.h" #include "flang/Optimizer/HLFIR/HLFIROps.h" #include "flang/Optimizer/Support/DataLayout.h" #include "flang/Runtime/CUDA/descriptor.h" #include "flang/Runtime/CUDA/memory.h" #include "flang/Runtime/allocatable.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" namespace fir { #define GEN_PASS_DEF_CUFOPCONVERSION #include "flang/Optimizer/Transforms/Passes.h.inc" } // namespace fir using namespace fir; using namespace mlir; using namespace Fortran::runtime; using namespace Fortran::runtime::cuda; namespace { template static bool needDoubleDescriptor(OpTy op) { if (auto declareOp = mlir::dyn_cast_or_null(op.getBox().getDefiningOp())) { if (mlir::isa_and_nonnull( declareOp.getMemref().getDefiningOp())) { if (declareOp.getDataAttr() && *declareOp.getDataAttr() == cuf::DataAttribute::Pinned) return false; return true; } } else if (auto declareOp = mlir::dyn_cast_or_null( op.getBox().getDefiningOp())) { if (mlir::isa_and_nonnull( declareOp.getMemref().getDefiningOp())) { if (declareOp.getDataAttr() && *declareOp.getDataAttr() == cuf::DataAttribute::Pinned) return false; return true; } } return false; } template static mlir::LogicalResult convertOpToCall(OpTy op, mlir::PatternRewriter &rewriter, mlir::func::FuncOp func) { auto mod = op->template getParentOfType(); fir::FirOpBuilder builder(rewriter, mod); mlir::Location loc = op.getLoc(); auto fTy = func.getFunctionType(); mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc); mlir::Value sourceLine = fir::factory::locationToLineNo(builder, loc, fTy.getInput(4)); mlir::Value hasStat = op.getHasStat() ? builder.createBool(loc, true) : builder.createBool(loc, false); mlir::Value errmsg; if (op.getErrmsg()) { errmsg = op.getErrmsg(); } else { mlir::Type boxNoneTy = fir::BoxType::get(builder.getNoneType()); errmsg = builder.create(loc, boxNoneTy).getResult(); } llvm::SmallVector args{fir::runtime::createArguments( builder, loc, fTy, op.getBox(), hasStat, errmsg, sourceFile, sourceLine)}; auto callOp = builder.create(loc, func, args); rewriter.replaceOp(op, callOp); return mlir::success(); } struct CufAllocateOpConversion : public mlir::OpRewritePattern { using OpRewritePattern::OpRewritePattern; mlir::LogicalResult matchAndRewrite(cuf::AllocateOp op, mlir::PatternRewriter &rewriter) const override { // TODO: Allocation with source will need a new entry point in the runtime. if (op.getSource()) return mlir::failure(); // TODO: Allocation using different stream. if (op.getStream()) return mlir::failure(); // TODO: Pinned is a reference to a logical value that can be set to true // when pinned allocation succeed. This will require a new entry point. if (op.getPinned()) return mlir::failure(); // TODO: Allocation of module variable will need more work as the descriptor // will be duplicated and needs to be synced after allocation. if (needDoubleDescriptor(op)) return mlir::failure(); // Allocation for local descriptor falls back on the standard runtime // AllocatableAllocate as the dedicated allocator is set in the descriptor // before the call. auto mod = op->template getParentOfType(); fir::FirOpBuilder builder(rewriter, mod); mlir::Location loc = op.getLoc(); mlir::func::FuncOp func = fir::runtime::getRuntimeFunc(loc, builder); return convertOpToCall(op, rewriter, func); } }; struct CufDeallocateOpConversion : public mlir::OpRewritePattern { using OpRewritePattern::OpRewritePattern; mlir::LogicalResult matchAndRewrite(cuf::DeallocateOp op, mlir::PatternRewriter &rewriter) const override { // TODO: Allocation of module variable will need more work as the descriptor // will be duplicated and needs to be synced after allocation. if (needDoubleDescriptor(op)) return mlir::failure(); // Deallocation for local descriptor falls back on the standard runtime // AllocatableDeallocate as the dedicated deallocator is set in the // descriptor before the call. auto mod = op->getParentOfType(); fir::FirOpBuilder builder(rewriter, mod); mlir::Location loc = op.getLoc(); mlir::func::FuncOp func = fir::runtime::getRuntimeFunc(loc, builder); return convertOpToCall(op, rewriter, func); } }; static bool inDeviceContext(mlir::Operation *op) { if (op->getParentOfType()) return true; if (auto funcOp = op->getParentOfType()) { if (auto cudaProcAttr = funcOp.getOperation()->getAttrOfType( cuf::getProcAttrName())) { return cudaProcAttr.getValue() != cuf::ProcAttribute::Host && cudaProcAttr.getValue() != cuf::ProcAttribute::HostDevice; } } return false; } struct CufAllocOpConversion : public mlir::OpRewritePattern { using OpRewritePattern::OpRewritePattern; CufAllocOpConversion(mlir::MLIRContext *context, mlir::DataLayout *dl, fir::LLVMTypeConverter *typeConverter) : OpRewritePattern(context), dl{dl}, typeConverter{typeConverter} {} mlir::LogicalResult matchAndRewrite(cuf::AllocOp op, mlir::PatternRewriter &rewriter) const override { auto boxTy = mlir::dyn_cast_or_null(op.getInType()); // Only convert cuf.alloc that allocates a descriptor. if (!boxTy) return failure(); if (inDeviceContext(op.getOperation())) { // In device context just replace the cuf.alloc operation with a fir.alloc // the cuf.free will be removed. rewriter.replaceOpWithNewOp( op, op.getInType(), op.getUniqName() ? *op.getUniqName() : "", op.getBindcName() ? *op.getBindcName() : "", op.getTypeparams(), op.getShape()); return mlir::success(); } auto mod = op->getParentOfType(); fir::FirOpBuilder builder(rewriter, mod); mlir::Location loc = op.getLoc(); mlir::func::FuncOp func = fir::runtime::getRuntimeFunc(loc, builder); auto fTy = func.getFunctionType(); mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc); mlir::Value sourceLine = fir::factory::locationToLineNo(builder, loc, fTy.getInput(2)); mlir::Type structTy = typeConverter->convertBoxTypeAsStruct(boxTy); std::size_t boxSize = dl->getTypeSizeInBits(structTy) / 8; mlir::Value sizeInBytes = builder.createIntegerConstant(loc, builder.getIndexType(), boxSize); llvm::SmallVector args{fir::runtime::createArguments( builder, loc, fTy, sizeInBytes, sourceFile, sourceLine)}; auto callOp = builder.create(loc, func, args); auto convOp = builder.createConvert(loc, op.getResult().getType(), callOp.getResult(0)); rewriter.replaceOp(op, convOp); return mlir::success(); } private: mlir::DataLayout *dl; fir::LLVMTypeConverter *typeConverter; }; struct CufFreeOpConversion : public mlir::OpRewritePattern { using OpRewritePattern::OpRewritePattern; mlir::LogicalResult matchAndRewrite(cuf::FreeOp op, mlir::PatternRewriter &rewriter) const override { // Only convert cuf.free on descriptor. if (!mlir::isa(op.getDevptr().getType())) return failure(); auto refTy = mlir::dyn_cast(op.getDevptr().getType()); if (!mlir::isa(refTy.getEleTy())) return failure(); if (inDeviceContext(op.getOperation())) { rewriter.eraseOp(op); return mlir::success(); } auto mod = op->getParentOfType(); fir::FirOpBuilder builder(rewriter, mod); mlir::Location loc = op.getLoc(); mlir::func::FuncOp func = fir::runtime::getRuntimeFunc(loc, builder); auto fTy = func.getFunctionType(); mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc); mlir::Value sourceLine = fir::factory::locationToLineNo(builder, loc, fTy.getInput(2)); llvm::SmallVector args{fir::runtime::createArguments( builder, loc, fTy, op.getDevptr(), sourceFile, sourceLine)}; builder.create(loc, func, args); rewriter.eraseOp(op); return mlir::success(); } }; static int computeWidth(mlir::Location loc, mlir::Type type, fir::KindMapping &kindMap) { auto eleTy = fir::unwrapSequenceType(type); int width = 0; if (auto t{mlir::dyn_cast(eleTy)}) { width = t.getWidth() / 8; } else if (auto t{mlir::dyn_cast(eleTy)}) { width = t.getWidth() / 8; } else if (eleTy.isInteger(1)) { width = 1; } else if (auto t{mlir::dyn_cast(eleTy)}) { int kind = t.getFKind(); width = kindMap.getLogicalBitsize(kind) / 8; } else if (auto t{mlir::dyn_cast(eleTy)}) { int kind = t.getFKind(); int elemSize = kindMap.getRealBitsize(kind) / 8; width = 2 * elemSize; } else { llvm::report_fatal_error("unsupported type"); } return width; } static mlir::Value createConvertOp(mlir::PatternRewriter &rewriter, mlir::Location loc, mlir::Type toTy, mlir::Value val) { if (val.getType() != toTy) return rewriter.create(loc, toTy, val); return val; } struct CufDataTransferOpConversion : public mlir::OpRewritePattern { using OpRewritePattern::OpRewritePattern; mlir::LogicalResult matchAndRewrite(cuf::DataTransferOp op, mlir::PatternRewriter &rewriter) const override { mlir::Type srcTy = fir::unwrapRefType(op.getSrc().getType()); mlir::Type dstTy = fir::unwrapRefType(op.getDst().getType()); // Only convert cuf.data_transfer with at least one descripor. if (!mlir::isa(srcTy) && !mlir::isa(dstTy)) return failure(); unsigned mode; if (op.getTransferKind() == cuf::DataTransferKind::HostDevice) { mode = kHostToDevice; } else if (op.getTransferKind() == cuf::DataTransferKind::DeviceHost) { mode = kDeviceToHost; } else if (op.getTransferKind() == cuf::DataTransferKind::DeviceDevice) { mode = kDeviceToDevice; } auto mod = op->getParentOfType(); fir::FirOpBuilder builder(rewriter, mod); mlir::Location loc = op.getLoc(); if (mlir::isa(srcTy) && mlir::isa(dstTy)) { // Transfer between two descriptor. mlir::func::FuncOp func = fir::runtime::getRuntimeFunc( loc, builder); auto fTy = func.getFunctionType(); mlir::Value modeValue = builder.createIntegerConstant(loc, builder.getI32Type(), mode); mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc); mlir::Value sourceLine = fir::factory::locationToLineNo(builder, loc, fTy.getInput(4)); mlir::Value dst = builder.loadIfRef(loc, op.getDst()); mlir::Value src = builder.loadIfRef(loc, op.getSrc()); llvm::SmallVector args{fir::runtime::createArguments( builder, loc, fTy, dst, src, modeValue, sourceFile, sourceLine)}; builder.create(loc, func, args); rewriter.eraseOp(op); } else if (mlir::isa(dstTy) && fir::isa_trivial(srcTy)) { // Scalar to descriptor transfer. mlir::Value val = op.getSrc(); if (op.getSrc().getDefiningOp() && mlir::isa(op.getSrc().getDefiningOp())) { mlir::Value alloc = builder.createTemporary(loc, srcTy); builder.create(loc, op.getSrc(), alloc); val = alloc; } mlir::func::FuncOp func = fir::runtime::getRuntimeFunc(loc, builder); auto fTy = func.getFunctionType(); mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc); mlir::Value sourceLine = fir::factory::locationToLineNo(builder, loc, fTy.getInput(3)); mlir::Value dst = builder.loadIfRef(loc, op.getDst()); llvm::SmallVector args{fir::runtime::createArguments( builder, loc, fTy, dst, val, sourceFile, sourceLine)}; builder.create(loc, func, args); rewriter.eraseOp(op); } else { mlir::Value modeValue = builder.createIntegerConstant(loc, builder.getI32Type(), mode); // Type used to compute the width. mlir::Type computeType = dstTy; auto seqTy = mlir::dyn_cast(dstTy); bool dstIsDesc = false; if (mlir::isa(dstTy)) { dstIsDesc = true; computeType = srcTy; seqTy = mlir::dyn_cast(srcTy); } fir::KindMapping kindMap{fir::getKindMapping(mod)}; int width = computeWidth(loc, computeType, kindMap); mlir::Value nbElement; mlir::Type idxTy = rewriter.getIndexType(); if (!op.getShape()) { nbElement = rewriter.create( loc, idxTy, rewriter.getIntegerAttr(idxTy, seqTy.getConstantArraySize())); } else { auto shapeOp = mlir::dyn_cast(op.getShape().getDefiningOp()); nbElement = createConvertOp(rewriter, loc, idxTy, shapeOp.getExtents()[0]); for (unsigned i = 1; i < shapeOp.getExtents().size(); ++i) { auto operand = createConvertOp(rewriter, loc, idxTy, shapeOp.getExtents()[i]); nbElement = rewriter.create(loc, nbElement, operand); } } mlir::Value widthValue = rewriter.create( loc, idxTy, rewriter.getIntegerAttr(idxTy, width)); mlir::Value bytes = rewriter.create(loc, nbElement, widthValue); mlir::func::FuncOp func = dstIsDesc ? fir::runtime::getRuntimeFunc( loc, builder) : fir::runtime::getRuntimeFunc( loc, builder); auto fTy = func.getFunctionType(); mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc); mlir::Value sourceLine = fir::factory::locationToLineNo(builder, loc, fTy.getInput(5)); mlir::Value dst = dstIsDesc ? builder.loadIfRef(loc, op.getDst()) : op.getDst(); mlir::Value src = mlir::isa(srcTy) ? builder.loadIfRef(loc, op.getSrc()) : op.getSrc(); llvm::SmallVector args{ fir::runtime::createArguments(builder, loc, fTy, dst, src, bytes, modeValue, sourceFile, sourceLine)}; builder.create(loc, func, args); rewriter.eraseOp(op); } return mlir::success(); } }; class CufOpConversion : public fir::impl::CufOpConversionBase { public: void runOnOperation() override { auto *ctx = &getContext(); mlir::RewritePatternSet patterns(ctx); mlir::ConversionTarget target(*ctx); mlir::Operation *op = getOperation(); mlir::ModuleOp module = mlir::dyn_cast(op); if (!module) return signalPassFailure(); std::optional dl = fir::support::getOrSetDataLayout(module, /*allowDefaultLayout=*/false); fir::LLVMTypeConverter typeConverter(module, /*applyTBAA=*/false, /*forceUnifiedTBAATree=*/false, *dl); target.addDynamicallyLegalOp([](::cuf::AllocOp op) { return !mlir::isa(op.getInType()); }); target.addDynamicallyLegalOp([](::cuf::FreeOp op) { if (auto refTy = mlir::dyn_cast_or_null( op.getDevptr().getType())) { return !mlir::isa(refTy.getEleTy()); } return true; }); target.addDynamicallyLegalOp( [](::cuf::AllocateOp op) { return needDoubleDescriptor(op); }); target.addDynamicallyLegalOp( [](::cuf::DeallocateOp op) { return needDoubleDescriptor(op); }); target.addDynamicallyLegalOp( [](::cuf::DataTransferOp op) { mlir::Type srcTy = fir::unwrapRefType(op.getSrc().getType()); mlir::Type dstTy = fir::unwrapRefType(op.getDst().getType()); return !mlir::isa(srcTy) && !mlir::isa(dstTy); }); target.addLegalDialect(); patterns.insert(ctx, &*dl, &typeConverter); patterns.insert(ctx); if (mlir::failed(mlir::applyPartialConversion(getOperation(), target, std::move(patterns)))) { mlir::emitError(mlir::UnknownLoc::get(ctx), "error in CUF op conversion\n"); signalPassFailure(); } } }; } // namespace