[flang][cuda] Introduce cuf.set_allocator_idx operation (#148717)
This commit is contained in:
parent
5eecec8e81
commit
2c6771889a
@ -62,6 +62,15 @@ void RTDEF(CUFDescriptorCheckSection)(
|
||||
}
|
||||
}
|
||||
|
||||
void RTDEF(CUFSetAllocatorIndex)(
|
||||
Descriptor *, int index, const char *sourceFile, int sourceLine) {
|
||||
if (!desc) {
|
||||
Terminator terminator{sourceFile, sourceLine};
|
||||
terminator.Crash("descriptor is null");
|
||||
}
|
||||
desc->SetAllocIdx(index);
|
||||
}
|
||||
|
||||
RT_EXT_API_GROUP_END
|
||||
}
|
||||
} // namespace Fortran::runtime::cuda
|
||||
|
@ -72,3 +72,13 @@ TEST(AllocatableCUFTest, DescriptorAllocationTest) {
|
||||
EXPECT_TRUE(desc != nullptr);
|
||||
RTNAME(CUFFreeDescriptor)(desc);
|
||||
}
|
||||
|
||||
TEST(AllocatableCUFTest, CUFSetAllocatorIndex) {
|
||||
using Fortran::common::TypeCategory;
|
||||
RTNAME(CUFRegisterAllocator)();
|
||||
// REAL(4), DEVICE, ALLOCATABLE :: a(:)
|
||||
auto a{createAllocatable(TypeCategory::Real, 4)};
|
||||
EXPECT_EQ((int)kDefaultAllocator, a->GetAllocIdx());
|
||||
RTNAME(CUFSetAllocatorIndex)(*a, kDeviceAllocatorPos, __FILE__, __LINE__);
|
||||
EXPECT_EQ((int)kDeviceAllocatorPos, a->GetAllocIdx());
|
||||
}
|
||||
|
@ -31,6 +31,10 @@ void genSyncGlobalDescriptor(fir::FirOpBuilder &builder, mlir::Location loc,
|
||||
void genDescriptorCheckSection(fir::FirOpBuilder &builder, mlir::Location loc,
|
||||
mlir::Value desc);
|
||||
|
||||
/// Generate runtime call to set the allocator index in the descriptor.
|
||||
void genSetAllocatorIndex(fir::FirOpBuilder &builder, mlir::Location loc,
|
||||
mlir::Value desc, mlir::Value index);
|
||||
|
||||
} // namespace fir::runtime::cuda
|
||||
|
||||
#endif // FORTRAN_OPTIMIZER_BUILDER_RUNTIME_CUDA_DESCRIPTOR_H_
|
||||
|
@ -388,4 +388,25 @@ def cuf_StreamCastOp : cuf_Op<"stream_cast", [NoMemoryEffect]> {
|
||||
let hasVerifier = 1;
|
||||
}
|
||||
|
||||
def cuf_SetAllocatorIndexOp : cuf_Op<"set_allocator_idx", []> {
|
||||
let summary = "Set the allocator index in a descriptor";
|
||||
|
||||
let description = [{
|
||||
Allocator index in the Fortran descriptor is used to retrived the correct
|
||||
CUDA allocator to allocate the memory on the device.
|
||||
In many cases the allocator index is set when the descriptor is created. For
|
||||
device components, the descriptor is part of the derived-type itself and
|
||||
needs to be set after the derived-type is allocated in managed memory.
|
||||
}];
|
||||
|
||||
let arguments = (ins Arg<fir_ReferenceType, "", [MemRead, MemWrite]>:$box,
|
||||
cuf_DataAttributeAttr:$data_attr);
|
||||
|
||||
let assemblyFormat = [{
|
||||
$box `:` qualified(type($box)) attr-dict
|
||||
}];
|
||||
|
||||
let hasVerifier = 1;
|
||||
}
|
||||
|
||||
#endif // FORTRAN_DIALECT_CUF_CUF_OPS
|
||||
|
@ -41,6 +41,10 @@ void RTDECL(CUFSyncGlobalDescriptor)(
|
||||
void RTDECL(CUFDescriptorCheckSection)(
|
||||
const Descriptor *, const char *sourceFile = nullptr, int sourceLine = 0);
|
||||
|
||||
/// Set the allocator index with the provided value.
|
||||
void RTDECL(CUFSetAllocatorIndex)(Descriptor *, int index,
|
||||
const char *sourceFile = nullptr, int sourceLine = 0);
|
||||
|
||||
} // extern "C"
|
||||
|
||||
} // namespace Fortran::runtime::cuda
|
||||
|
@ -47,3 +47,18 @@ void fir::runtime::cuda::genDescriptorCheckSection(fir::FirOpBuilder &builder,
|
||||
builder, loc, fTy, desc, sourceFile, sourceLine)};
|
||||
builder.create<fir::CallOp>(loc, func, args);
|
||||
}
|
||||
|
||||
void fir::runtime::cuda::genSetAllocatorIndex(fir::FirOpBuilder &builder,
|
||||
mlir::Location loc,
|
||||
mlir::Value desc,
|
||||
mlir::Value index) {
|
||||
mlir::func::FuncOp func =
|
||||
fir::runtime::getRuntimeFunc<mkRTKey(CUFSetAllocatorIndex)>(loc, builder);
|
||||
auto fTy = func.getFunctionType();
|
||||
mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc);
|
||||
mlir::Value sourceLine =
|
||||
fir::factory::locationToLineNo(builder, loc, fTy.getInput(3));
|
||||
llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments(
|
||||
builder, loc, fTy, desc, index, sourceFile, sourceLine)};
|
||||
builder.create<fir::CallOp>(loc, func, args);
|
||||
}
|
||||
|
@ -345,6 +345,17 @@ llvm::LogicalResult cuf::StreamCastOp::verify() {
|
||||
return checkStreamType(*this);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// SetAllocatorOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
llvm::LogicalResult cuf::SetAllocatorIndexOp::verify() {
|
||||
if (!mlir::isa<fir::BaseBoxType>(fir::unwrapRefType(getBox().getType())))
|
||||
return emitOpError(
|
||||
"expect box to be a reference to class or box type value");
|
||||
return mlir::success();
|
||||
}
|
||||
|
||||
// Tablegen operators
|
||||
|
||||
#define GET_OP_CLASSES
|
||||
|
@ -22,6 +22,7 @@
|
||||
#include "flang/Runtime/CUDA/memory.h"
|
||||
#include "flang/Runtime/CUDA/pointer.h"
|
||||
#include "flang/Runtime/allocatable.h"
|
||||
#include "flang/Runtime/allocator-registry-consts.h"
|
||||
#include "flang/Support/Fortran.h"
|
||||
#include "mlir/Conversion/LLVMCommon/Pattern.h"
|
||||
#include "mlir/Dialect/DLTI/DLTI.h"
|
||||
@ -923,6 +924,34 @@ struct CUFSyncDescriptorOpConversion
|
||||
}
|
||||
};
|
||||
|
||||
struct CUFSetAllocatorIndexOpConversion
|
||||
: public mlir::OpRewritePattern<cuf::SetAllocatorIndexOp> {
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
|
||||
mlir::LogicalResult
|
||||
matchAndRewrite(cuf::SetAllocatorIndexOp op,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
auto mod = op->getParentOfType<mlir::ModuleOp>();
|
||||
fir::FirOpBuilder builder(rewriter, mod);
|
||||
mlir::Location loc = op.getLoc();
|
||||
int idx = kDefaultAllocator;
|
||||
if (op.getDataAttr() == cuf::DataAttribute::Device) {
|
||||
idx = kDeviceAllocatorPos;
|
||||
} else if (op.getDataAttr() == cuf::DataAttribute::Managed) {
|
||||
idx = kManagedAllocatorPos;
|
||||
} else if (op.getDataAttr() == cuf::DataAttribute::Unified) {
|
||||
idx = kUnifiedAllocatorPos;
|
||||
} else if (op.getDataAttr() == cuf::DataAttribute::Pinned) {
|
||||
idx = kPinnedAllocatorPos;
|
||||
}
|
||||
mlir::Value index =
|
||||
builder.createIntegerConstant(loc, builder.getI32Type(), idx);
|
||||
fir::runtime::cuda::genSetAllocatorIndex(builder, loc, op.getBox(), index);
|
||||
op.erase();
|
||||
return mlir::success();
|
||||
}
|
||||
};
|
||||
|
||||
class CUFOpConversion : public fir::impl::CUFOpConversionBase<CUFOpConversion> {
|
||||
public:
|
||||
void runOnOperation() override {
|
||||
@ -984,8 +1013,8 @@ void cuf::populateCUFToFIRConversionPatterns(
|
||||
const mlir::SymbolTable &symtab, mlir::RewritePatternSet &patterns) {
|
||||
patterns.insert<CUFAllocOpConversion>(patterns.getContext(), &dl, &converter);
|
||||
patterns.insert<CUFAllocateOpConversion, CUFDeallocateOpConversion,
|
||||
CUFFreeOpConversion, CUFSyncDescriptorOpConversion>(
|
||||
patterns.getContext());
|
||||
CUFFreeOpConversion, CUFSyncDescriptorOpConversion,
|
||||
CUFSetAllocatorIndexOpConversion>(patterns.getContext());
|
||||
patterns.insert<CUFDataTransferOpConversion>(patterns.getContext(), symtab,
|
||||
&dl, &converter);
|
||||
patterns.insert<CUFLaunchOpConversion, CUFDeviceAddressOpConversion>(
|
||||
|
@ -94,4 +94,19 @@ func.func @_QQalloc_char() attributes {fir.bindc_name = "alloc_char"} {
|
||||
// CHECK: %[[BYTES_CONV:.*]] = fir.convert %[[BYTES]] : (index) -> i64
|
||||
// CHECK: fir.call @_FortranACUFMemAlloc(%[[BYTES_CONV]], %c0{{.*}}, %{{.*}}, %{{.*}}) {cuf.data_attr = #cuf.cuda<device>} : (i64, i32, !fir.ref<i8>, i32) -> !fir.llvm_ptr<i8>
|
||||
|
||||
|
||||
func.func @_QQsetalloc() {
|
||||
%0 = cuf.alloc !fir.type<_QMm1Tdt1{a2:!fir.box<!fir.heap<!fir.array<?xf32>>>}> {bindc_name = "d1", data_attr = #cuf.cuda<managed>, uniq_name = "_QFEd1"} -> !fir.ref<!fir.type<_QMm1Tdt1{a2:!fir.box<!fir.heap<!fir.array<?xf32>>>}>>
|
||||
%1 = fir.coordinate_of %0, a2 : (!fir.ref<!fir.type<_QMm1Tdt1{a2:!fir.box<!fir.heap<!fir.array<?xf32>>>}>>) -> !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>
|
||||
cuf.set_allocator_idx %1 : !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>> {data_attr = #cuf.cuda<device>}
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func.func @_QQsetalloc() {
|
||||
// CHECK: %[[DT:.*]] = fir.call @_FortranACUFMemAlloc
|
||||
// CHECK: %[[CONV:.*]] = fir.convert %[[DT]] : (!fir.llvm_ptr<i8>) -> !fir.ref<!fir.type<_QMm1Tdt1{a2:!fir.box<!fir.heap<!fir.array<?xf32>>>}>>
|
||||
// CHECK: %[[COMP:.*]] = fir.coordinate_of %[[CONV]], a2 : (!fir.ref<!fir.type<_QMm1Tdt1{a2:!fir.box<!fir.heap<!fir.array<?xf32>>>}>>) -> !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>
|
||||
// CHECK: %[[DESC:.*]] = fir.convert %[[COMP]] : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>) -> !fir.ref<!fir.box<none>>
|
||||
// CHECK: fir.call @_FortranACUFSetAllocatorIndex(%[[DESC]], %c2{{.*}}, %{{.*}}, %{{.*}}) : (!fir.ref<!fir.box<none>>, i32, !fir.ref<i8>, i32) -> ()
|
||||
|
||||
} // end module
|
||||
|
Loading…
x
Reference in New Issue
Block a user