[flang][cuda] Introduce cuf.set_allocator_idx operation (#148717)

This commit is contained in:
Valentin Clement (バレンタイン クレメン) 2025-07-14 17:23:18 -07:00 committed by GitHub
parent 5eecec8e81
commit 2c6771889a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 120 additions and 2 deletions

View File

@ -62,6 +62,15 @@ void RTDEF(CUFDescriptorCheckSection)(
}
}
void RTDEF(CUFSetAllocatorIndex)(
Descriptor *, int index, const char *sourceFile, int sourceLine) {
if (!desc) {
Terminator terminator{sourceFile, sourceLine};
terminator.Crash("descriptor is null");
}
desc->SetAllocIdx(index);
}
RT_EXT_API_GROUP_END
}
} // namespace Fortran::runtime::cuda

View File

@ -72,3 +72,13 @@ TEST(AllocatableCUFTest, DescriptorAllocationTest) {
EXPECT_TRUE(desc != nullptr);
RTNAME(CUFFreeDescriptor)(desc);
}
TEST(AllocatableCUFTest, CUFSetAllocatorIndex) {
using Fortran::common::TypeCategory;
RTNAME(CUFRegisterAllocator)();
// REAL(4), DEVICE, ALLOCATABLE :: a(:)
auto a{createAllocatable(TypeCategory::Real, 4)};
EXPECT_EQ((int)kDefaultAllocator, a->GetAllocIdx());
RTNAME(CUFSetAllocatorIndex)(*a, kDeviceAllocatorPos, __FILE__, __LINE__);
EXPECT_EQ((int)kDeviceAllocatorPos, a->GetAllocIdx());
}

View File

@ -31,6 +31,10 @@ void genSyncGlobalDescriptor(fir::FirOpBuilder &builder, mlir::Location loc,
void genDescriptorCheckSection(fir::FirOpBuilder &builder, mlir::Location loc,
mlir::Value desc);
/// Generate runtime call to set the allocator index in the descriptor.
void genSetAllocatorIndex(fir::FirOpBuilder &builder, mlir::Location loc,
mlir::Value desc, mlir::Value index);
} // namespace fir::runtime::cuda
#endif // FORTRAN_OPTIMIZER_BUILDER_RUNTIME_CUDA_DESCRIPTOR_H_

View File

@ -388,4 +388,25 @@ def cuf_StreamCastOp : cuf_Op<"stream_cast", [NoMemoryEffect]> {
let hasVerifier = 1;
}
def cuf_SetAllocatorIndexOp : cuf_Op<"set_allocator_idx", []> {
let summary = "Set the allocator index in a descriptor";
let description = [{
Allocator index in the Fortran descriptor is used to retrived the correct
CUDA allocator to allocate the memory on the device.
In many cases the allocator index is set when the descriptor is created. For
device components, the descriptor is part of the derived-type itself and
needs to be set after the derived-type is allocated in managed memory.
}];
let arguments = (ins Arg<fir_ReferenceType, "", [MemRead, MemWrite]>:$box,
cuf_DataAttributeAttr:$data_attr);
let assemblyFormat = [{
$box `:` qualified(type($box)) attr-dict
}];
let hasVerifier = 1;
}
#endif // FORTRAN_DIALECT_CUF_CUF_OPS

View File

@ -41,6 +41,10 @@ void RTDECL(CUFSyncGlobalDescriptor)(
void RTDECL(CUFDescriptorCheckSection)(
const Descriptor *, const char *sourceFile = nullptr, int sourceLine = 0);
/// Set the allocator index with the provided value.
void RTDECL(CUFSetAllocatorIndex)(Descriptor *, int index,
const char *sourceFile = nullptr, int sourceLine = 0);
} // extern "C"
} // namespace Fortran::runtime::cuda

View File

@ -47,3 +47,18 @@ void fir::runtime::cuda::genDescriptorCheckSection(fir::FirOpBuilder &builder,
builder, loc, fTy, desc, sourceFile, sourceLine)};
builder.create<fir::CallOp>(loc, func, args);
}
void fir::runtime::cuda::genSetAllocatorIndex(fir::FirOpBuilder &builder,
mlir::Location loc,
mlir::Value desc,
mlir::Value index) {
mlir::func::FuncOp func =
fir::runtime::getRuntimeFunc<mkRTKey(CUFSetAllocatorIndex)>(loc, builder);
auto fTy = func.getFunctionType();
mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc);
mlir::Value sourceLine =
fir::factory::locationToLineNo(builder, loc, fTy.getInput(3));
llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments(
builder, loc, fTy, desc, index, sourceFile, sourceLine)};
builder.create<fir::CallOp>(loc, func, args);
}

View File

@ -345,6 +345,17 @@ llvm::LogicalResult cuf::StreamCastOp::verify() {
return checkStreamType(*this);
}
//===----------------------------------------------------------------------===//
// SetAllocatorOp
//===----------------------------------------------------------------------===//
llvm::LogicalResult cuf::SetAllocatorIndexOp::verify() {
if (!mlir::isa<fir::BaseBoxType>(fir::unwrapRefType(getBox().getType())))
return emitOpError(
"expect box to be a reference to class or box type value");
return mlir::success();
}
// Tablegen operators
#define GET_OP_CLASSES

View File

@ -22,6 +22,7 @@
#include "flang/Runtime/CUDA/memory.h"
#include "flang/Runtime/CUDA/pointer.h"
#include "flang/Runtime/allocatable.h"
#include "flang/Runtime/allocator-registry-consts.h"
#include "flang/Support/Fortran.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Dialect/DLTI/DLTI.h"
@ -923,6 +924,34 @@ struct CUFSyncDescriptorOpConversion
}
};
struct CUFSetAllocatorIndexOpConversion
: public mlir::OpRewritePattern<cuf::SetAllocatorIndexOp> {
using OpRewritePattern::OpRewritePattern;
mlir::LogicalResult
matchAndRewrite(cuf::SetAllocatorIndexOp op,
mlir::PatternRewriter &rewriter) const override {
auto mod = op->getParentOfType<mlir::ModuleOp>();
fir::FirOpBuilder builder(rewriter, mod);
mlir::Location loc = op.getLoc();
int idx = kDefaultAllocator;
if (op.getDataAttr() == cuf::DataAttribute::Device) {
idx = kDeviceAllocatorPos;
} else if (op.getDataAttr() == cuf::DataAttribute::Managed) {
idx = kManagedAllocatorPos;
} else if (op.getDataAttr() == cuf::DataAttribute::Unified) {
idx = kUnifiedAllocatorPos;
} else if (op.getDataAttr() == cuf::DataAttribute::Pinned) {
idx = kPinnedAllocatorPos;
}
mlir::Value index =
builder.createIntegerConstant(loc, builder.getI32Type(), idx);
fir::runtime::cuda::genSetAllocatorIndex(builder, loc, op.getBox(), index);
op.erase();
return mlir::success();
}
};
class CUFOpConversion : public fir::impl::CUFOpConversionBase<CUFOpConversion> {
public:
void runOnOperation() override {
@ -984,8 +1013,8 @@ void cuf::populateCUFToFIRConversionPatterns(
const mlir::SymbolTable &symtab, mlir::RewritePatternSet &patterns) {
patterns.insert<CUFAllocOpConversion>(patterns.getContext(), &dl, &converter);
patterns.insert<CUFAllocateOpConversion, CUFDeallocateOpConversion,
CUFFreeOpConversion, CUFSyncDescriptorOpConversion>(
patterns.getContext());
CUFFreeOpConversion, CUFSyncDescriptorOpConversion,
CUFSetAllocatorIndexOpConversion>(patterns.getContext());
patterns.insert<CUFDataTransferOpConversion>(patterns.getContext(), symtab,
&dl, &converter);
patterns.insert<CUFLaunchOpConversion, CUFDeviceAddressOpConversion>(

View File

@ -94,4 +94,19 @@ func.func @_QQalloc_char() attributes {fir.bindc_name = "alloc_char"} {
// CHECK: %[[BYTES_CONV:.*]] = fir.convert %[[BYTES]] : (index) -> i64
// CHECK: fir.call @_FortranACUFMemAlloc(%[[BYTES_CONV]], %c0{{.*}}, %{{.*}}, %{{.*}}) {cuf.data_attr = #cuf.cuda<device>} : (i64, i32, !fir.ref<i8>, i32) -> !fir.llvm_ptr<i8>
func.func @_QQsetalloc() {
%0 = cuf.alloc !fir.type<_QMm1Tdt1{a2:!fir.box<!fir.heap<!fir.array<?xf32>>>}> {bindc_name = "d1", data_attr = #cuf.cuda<managed>, uniq_name = "_QFEd1"} -> !fir.ref<!fir.type<_QMm1Tdt1{a2:!fir.box<!fir.heap<!fir.array<?xf32>>>}>>
%1 = fir.coordinate_of %0, a2 : (!fir.ref<!fir.type<_QMm1Tdt1{a2:!fir.box<!fir.heap<!fir.array<?xf32>>>}>>) -> !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>
cuf.set_allocator_idx %1 : !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>> {data_attr = #cuf.cuda<device>}
return
}
// CHECK-LABEL: func.func @_QQsetalloc() {
// CHECK: %[[DT:.*]] = fir.call @_FortranACUFMemAlloc
// CHECK: %[[CONV:.*]] = fir.convert %[[DT]] : (!fir.llvm_ptr<i8>) -> !fir.ref<!fir.type<_QMm1Tdt1{a2:!fir.box<!fir.heap<!fir.array<?xf32>>>}>>
// CHECK: %[[COMP:.*]] = fir.coordinate_of %[[CONV]], a2 : (!fir.ref<!fir.type<_QMm1Tdt1{a2:!fir.box<!fir.heap<!fir.array<?xf32>>>}>>) -> !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>
// CHECK: %[[DESC:.*]] = fir.convert %[[COMP]] : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>) -> !fir.ref<!fir.box<none>>
// CHECK: fir.call @_FortranACUFSetAllocatorIndex(%[[DESC]], %c2{{.*}}, %{{.*}}, %{{.*}}) : (!fir.ref<!fir.box<none>>, i32, !fir.ref<i8>, i32) -> ()
} // end module