[flang][cuda] Lower attribute for local variable (#81076)

This is a first simple patch to introduce a new FIR attribute to carry
the CUDA variable attribute information to hlfir.declare and fir.declare
operations. It currently lowers this information for local variables.

The texture attribute is omitted since it is rejected by semantic and
will not make its way to MLIR.

This new attribute is added as optional attribute to the hlfir.declare
and fir.declare operations.
This commit is contained in:
Valentin Clement (バレンタイン クレメン) 2024-02-08 10:03:08 -08:00 committed by GitHub
parent 758fd59d01
commit abc4f74df7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 121 additions and 21 deletions

View File

@ -137,6 +137,12 @@ translateSymbolAttributes(mlir::MLIRContext *mlirContext,
fir::FortranVariableFlagsEnum extraFlags = fir::FortranVariableFlagsEnum extraFlags =
fir::FortranVariableFlagsEnum::None); fir::FortranVariableFlagsEnum::None);
/// Translate the CUDA Fortran attributes of \p sym into the FIR CUDA attribute
/// representation.
fir::CUDAAttributeAttr
translateSymbolCUDAAttribute(mlir::MLIRContext *mlirContext,
const Fortran::semantics::Symbol &sym);
/// Map a symbol to a given fir::ExtendedValue. This will generate an /// Map a symbol to a given fir::ExtendedValue. This will generate an
/// hlfir.declare when lowering to HLFIR and map the hlfir.declare result to the /// hlfir.declare when lowering to HLFIR and map the hlfir.declare result to the
/// symbol. /// symbol.

View File

@ -233,11 +233,11 @@ translateToExtendedValue(mlir::Location loc, fir::FirOpBuilder &builder,
fir::FortranVariableOpInterface fortranVariable); fir::FortranVariableOpInterface fortranVariable);
/// Generate declaration for a fir::ExtendedValue in memory. /// Generate declaration for a fir::ExtendedValue in memory.
fir::FortranVariableOpInterface genDeclare(mlir::Location loc, fir::FortranVariableOpInterface
fir::FirOpBuilder &builder, genDeclare(mlir::Location loc, fir::FirOpBuilder &builder,
const fir::ExtendedValue &exv, const fir::ExtendedValue &exv, llvm::StringRef name,
llvm::StringRef name, fir::FortranVariableFlagsAttr flags,
fir::FortranVariableFlagsAttr flags); fir::CUDAAttributeAttr cudaAttr = {});
/// Generate an hlfir.associate to build a variable from an expression value. /// Generate an hlfir.associate to build a variable from an expression value.
/// The type of the variable must be provided so that scalar logicals are /// The type of the variable must be provided so that scalar logicals are

View File

@ -55,7 +55,28 @@ def fir_FortranVariableFlagsAttr : fir_Attr<"FortranVariableFlags"> {
let returnType = "::fir::FortranVariableFlagsEnum"; let returnType = "::fir::FortranVariableFlagsEnum";
let convertFromStorage = "$_self.getFlags()"; let convertFromStorage = "$_self.getFlags()";
let constBuilderCall = let constBuilderCall =
"::fir::FortranVariableFlagsAttr::get($_builder.getContext(), $0)"; "::fir::FortranVariableFlagsAttr::get($_builder.getContext(), $0)";
}
def CUDAconstant : I32EnumAttrCase<"Constant", 0, "constant">;
def CUDAdevice : I32EnumAttrCase<"Device", 1, "device">;
def CUDAmanaged : I32EnumAttrCase<"Managed", 2, "managed">;
def CUDApinned : I32EnumAttrCase<"Pinned", 3, "pinned">;
def CUDAshared : I32EnumAttrCase<"Shared", 4, "shared">;
def CUDAunified : I32EnumAttrCase<"Unified", 5, "unified">;
// Texture is omitted since it is obsolete and rejected by semantic.
def fir_CUDAAttribute : I32EnumAttr<
"CUDAAttribute",
"CUDA Fortran variable attributes",
[CUDAconstant, CUDAdevice, CUDAmanaged, CUDApinned, CUDAshared,
CUDAunified]> {
let genSpecializedAttr = 0;
let cppNamespace = "::fir";
}
def fir_CUDAAttributeAttr : EnumAttr<fir_Dialect, fir_CUDAAttribute, "cuda"> {
let assemblyFormat = [{ ```<` $value `>` }];
} }
def fir_BoxFieldAttr : I32EnumAttr< def fir_BoxFieldAttr : I32EnumAttr<

View File

@ -3027,7 +3027,8 @@ def fir_DeclareOp : fir_Op<"declare", [AttrSizedOperandSegments,
Optional<AnyShapeOrShiftType>:$shape, Optional<AnyShapeOrShiftType>:$shape,
Variadic<AnyIntegerType>:$typeparams, Variadic<AnyIntegerType>:$typeparams,
Builtin_StringAttr:$uniq_name, Builtin_StringAttr:$uniq_name,
OptionalAttr<fir_FortranVariableFlagsAttr>:$fortran_attrs OptionalAttr<fir_FortranVariableFlagsAttr>:$fortran_attrs,
OptionalAttr<fir_CUDAAttributeAttr>:$cuda_attr
); );
let results = (outs AnyRefOrBox); let results = (outs AnyRefOrBox);

View File

@ -88,7 +88,8 @@ def hlfir_DeclareOp : hlfir_Op<"declare", [AttrSizedOperandSegments,
Optional<AnyShapeOrShiftType>:$shape, Optional<AnyShapeOrShiftType>:$shape,
Variadic<AnyIntegerType>:$typeparams, Variadic<AnyIntegerType>:$typeparams,
Builtin_StringAttr:$uniq_name, Builtin_StringAttr:$uniq_name,
OptionalAttr<fir_FortranVariableFlagsAttr>:$fortran_attrs OptionalAttr<fir_FortranVariableFlagsAttr>:$fortran_attrs,
OptionalAttr<fir_CUDAAttributeAttr>:$cuda_attr
); );
let results = (outs AnyFortranVariable, AnyRefOrBoxLike); let results = (outs AnyFortranVariable, AnyRefOrBoxLike);
@ -101,7 +102,8 @@ def hlfir_DeclareOp : hlfir_Op<"declare", [AttrSizedOperandSegments,
let builders = [ let builders = [
OpBuilder<(ins "mlir::Value":$memref, "llvm::StringRef":$uniq_name, OpBuilder<(ins "mlir::Value":$memref, "llvm::StringRef":$uniq_name,
CArg<"mlir::Value", "{}">:$shape, CArg<"mlir::ValueRange", "{}">:$typeparams, CArg<"mlir::Value", "{}">:$shape, CArg<"mlir::ValueRange", "{}">:$typeparams,
CArg<"fir::FortranVariableFlagsAttr", "{}">:$fortran_attrs)>]; CArg<"fir::FortranVariableFlagsAttr", "{}">:$fortran_attrs,
CArg<"fir::CUDAAttributeAttr", "{}">:$cuda_attr)>];
let extraClassDeclaration = [{ let extraClassDeclaration = [{
/// Get the variable original base (same as input). It lacks /// Get the variable original base (same as input). It lacks

View File

@ -1579,6 +1579,38 @@ fir::FortranVariableFlagsAttr Fortran::lower::translateSymbolAttributes(
return fir::FortranVariableFlagsAttr::get(mlirContext, flags); return fir::FortranVariableFlagsAttr::get(mlirContext, flags);
} }
fir::CUDAAttributeAttr Fortran::lower::translateSymbolCUDAAttribute(
mlir::MLIRContext *mlirContext, const Fortran::semantics::Symbol &sym) {
std::optional<Fortran::common::CUDADataAttr> cudaAttr =
Fortran::semantics::GetCUDADataAttr(&sym);
if (cudaAttr) {
fir::CUDAAttribute attr;
switch (*cudaAttr) {
case Fortran::common::CUDADataAttr::Constant:
attr = fir::CUDAAttribute::Constant;
break;
case Fortran::common::CUDADataAttr::Device:
attr = fir::CUDAAttribute::Device;
break;
case Fortran::common::CUDADataAttr::Managed:
attr = fir::CUDAAttribute::Managed;
break;
case Fortran::common::CUDADataAttr::Pinned:
attr = fir::CUDAAttribute::Pinned;
break;
case Fortran::common::CUDADataAttr::Shared:
attr = fir::CUDAAttribute::Shared;
break;
case Fortran::common::CUDADataAttr::Texture:
// Obsolete attribute
break;
}
return fir::CUDAAttributeAttr::get(mlirContext, attr);
}
return {};
}
/// Map a symbol to its FIR address and evaluated specification expressions. /// Map a symbol to its FIR address and evaluated specification expressions.
/// Not for symbols lowered to fir.box. /// Not for symbols lowered to fir.box.
/// Will optionally create fir.declare. /// Will optionally create fir.declare.
@ -1618,6 +1650,8 @@ static void genDeclareSymbol(Fortran::lower::AbstractConverter &converter,
auto name = converter.mangleName(sym); auto name = converter.mangleName(sym);
fir::FortranVariableFlagsAttr attributes = fir::FortranVariableFlagsAttr attributes =
Fortran::lower::translateSymbolAttributes(builder.getContext(), sym); Fortran::lower::translateSymbolAttributes(builder.getContext(), sym);
fir::CUDAAttributeAttr cudaAttr =
Fortran::lower::translateSymbolCUDAAttribute(builder.getContext(), sym);
if (isCrayPointee) { if (isCrayPointee) {
mlir::Type baseType = mlir::Type baseType =
@ -1664,7 +1698,7 @@ static void genDeclareSymbol(Fortran::lower::AbstractConverter &converter,
return; return;
} }
auto newBase = builder.create<hlfir::DeclareOp>( auto newBase = builder.create<hlfir::DeclareOp>(
loc, base, name, shapeOrShift, lenParams, attributes); loc, base, name, shapeOrShift, lenParams, attributes, cudaAttr);
symMap.addVariableDefinition(sym, newBase, force); symMap.addVariableDefinition(sym, newBase, force);
return; return;
} }
@ -1709,9 +1743,12 @@ void Fortran::lower::genDeclareSymbol(
fir::FortranVariableFlagsAttr attributes = fir::FortranVariableFlagsAttr attributes =
Fortran::lower::translateSymbolAttributes( Fortran::lower::translateSymbolAttributes(
builder.getContext(), sym.GetUltimate(), extraFlags); builder.getContext(), sym.GetUltimate(), extraFlags);
fir::CUDAAttributeAttr cudaAttr =
Fortran::lower::translateSymbolCUDAAttribute(builder.getContext(),
sym.GetUltimate());
auto name = converter.mangleName(sym); auto name = converter.mangleName(sym);
hlfir::EntityWithAttributes declare = hlfir::EntityWithAttributes declare =
hlfir::genDeclare(loc, builder, exv, name, attributes); hlfir::genDeclare(loc, builder, exv, name, attributes, cudaAttr);
symMap.addVariableDefinition(sym, declare.getIfVariableInterface(), force); symMap.addVariableDefinition(sym, declare.getIfVariableInterface(), force);
return; return;
} }

View File

@ -198,7 +198,8 @@ mlir::Value hlfir::Entity::getFirBase() const {
fir::FortranVariableOpInterface fir::FortranVariableOpInterface
hlfir::genDeclare(mlir::Location loc, fir::FirOpBuilder &builder, hlfir::genDeclare(mlir::Location loc, fir::FirOpBuilder &builder,
const fir::ExtendedValue &exv, llvm::StringRef name, const fir::ExtendedValue &exv, llvm::StringRef name,
fir::FortranVariableFlagsAttr flags) { fir::FortranVariableFlagsAttr flags,
fir::CUDAAttributeAttr cudaAttr) {
mlir::Value base = fir::getBase(exv); mlir::Value base = fir::getBase(exv);
assert(fir::conformsWithPassByRef(base.getType()) && assert(fir::conformsWithPassByRef(base.getType()) &&
@ -228,7 +229,7 @@ hlfir::genDeclare(mlir::Location loc, fir::FirOpBuilder &builder,
}, },
[](const auto &) {}); [](const auto &) {});
auto declareOp = builder.create<hlfir::DeclareOp>( auto declareOp = builder.create<hlfir::DeclareOp>(
loc, base, name, shapeOrShift, lenParams, flags); loc, base, name, shapeOrShift, lenParams, flags, cudaAttr);
return mlir::cast<fir::FortranVariableOpInterface>(declareOp.getOperation()); return mlir::cast<fir::FortranVariableOpInterface>(declareOp.getOperation());
} }

View File

@ -14,6 +14,7 @@
#include "flang/Optimizer/Dialect/FIRDialect.h" #include "flang/Optimizer/Dialect/FIRDialect.h"
#include "flang/Optimizer/Dialect/Support/KindMapping.h" #include "flang/Optimizer/Dialect/Support/KindMapping.h"
#include "mlir/IR/AttributeSupport.h" #include "mlir/IR/AttributeSupport.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/DialectImplementation.h" #include "mlir/IR/DialectImplementation.h"
#include "llvm/ADT/SmallString.h" #include "llvm/ADT/SmallString.h"
@ -297,5 +298,5 @@ void fir::printFirAttribute(FIROpsDialect *dialect, mlir::Attribute attr,
void FIROpsDialect::registerAttributes() { void FIROpsDialect::registerAttributes() {
addAttributes<ClosedIntervalAttr, ExactTypeAttr, FortranVariableFlagsAttr, addAttributes<ClosedIntervalAttr, ExactTypeAttr, FortranVariableFlagsAttr,
LowerBoundAttr, PointIntervalAttr, RealAttr, SubclassAttr, LowerBoundAttr, PointIntervalAttr, RealAttr, SubclassAttr,
UpperBoundAttr>(); UpperBoundAttr, CUDAAttributeAttr>();
} }

View File

@ -123,14 +123,15 @@ void hlfir::DeclareOp::build(mlir::OpBuilder &builder,
mlir::OperationState &result, mlir::Value memref, mlir::OperationState &result, mlir::Value memref,
llvm::StringRef uniq_name, mlir::Value shape, llvm::StringRef uniq_name, mlir::Value shape,
mlir::ValueRange typeparams, mlir::ValueRange typeparams,
fir::FortranVariableFlagsAttr fortran_attrs) { fir::FortranVariableFlagsAttr fortran_attrs,
fir::CUDAAttributeAttr cuda_attr) {
auto nameAttr = builder.getStringAttr(uniq_name); auto nameAttr = builder.getStringAttr(uniq_name);
mlir::Type inputType = memref.getType(); mlir::Type inputType = memref.getType();
bool hasExplicitLbs = hasExplicitLowerBounds(shape); bool hasExplicitLbs = hasExplicitLowerBounds(shape);
mlir::Type hlfirVariableType = mlir::Type hlfirVariableType =
getHLFIRVariableType(inputType, hasExplicitLbs); getHLFIRVariableType(inputType, hasExplicitLbs);
build(builder, result, {hlfirVariableType, inputType}, memref, shape, build(builder, result, {hlfirVariableType, inputType}, memref, shape,
typeparams, nameAttr, fortran_attrs); typeparams, nameAttr, fortran_attrs, cuda_attr);
} }
mlir::LogicalResult hlfir::DeclareOp::verify() { mlir::LogicalResult hlfir::DeclareOp::verify() {

View File

@ -320,12 +320,16 @@ public:
mlir::Location loc = declareOp->getLoc(); mlir::Location loc = declareOp->getLoc();
mlir::Value memref = declareOp.getMemref(); mlir::Value memref = declareOp.getMemref();
fir::FortranVariableFlagsAttr fortranAttrs; fir::FortranVariableFlagsAttr fortranAttrs;
fir::CUDAAttributeAttr cudaAttr;
if (auto attrs = declareOp.getFortranAttrs()) if (auto attrs = declareOp.getFortranAttrs())
fortranAttrs = fortranAttrs =
fir::FortranVariableFlagsAttr::get(rewriter.getContext(), *attrs); fir::FortranVariableFlagsAttr::get(rewriter.getContext(), *attrs);
if (auto attr = declareOp.getCudaAttr())
cudaAttr = fir::CUDAAttributeAttr::get(rewriter.getContext(), *attr);
auto firDeclareOp = rewriter.create<fir::DeclareOp>( auto firDeclareOp = rewriter.create<fir::DeclareOp>(
loc, memref.getType(), memref, declareOp.getShape(), loc, memref.getType(), memref, declareOp.getShape(),
declareOp.getTypeparams(), declareOp.getUniqName(), fortranAttrs); declareOp.getTypeparams(), declareOp.getUniqName(), fortranAttrs,
cudaAttr);
// Propagate other attributes from hlfir.declare to fir.declare. // Propagate other attributes from hlfir.declare to fir.declare.
// OpenACC's acc.declare is one example. Right now, the propagation // OpenACC's acc.declare is one example. Right now, the propagation

View File

@ -0,0 +1,22 @@
! RUN: bbc -emit-hlfir -fcuda %s -o - | FileCheck %s
! RUN: bbc -emit-hlfir -fcuda %s -o - | fir-opt -convert-hlfir-to-fir | FileCheck %s --check-prefix=FIR
! Test lowering of CUDA attribute on local variables.
subroutine local_var_attrs
real, constant :: rc
real, device :: rd
real, allocatable, managed :: rm
real, allocatable, pinned :: rp
end subroutine
! CHECK-LABEL: func.func @_QPlocal_var_attrs()
! CHECK: %{{.*}}:2 = hlfir.declare %{{.*}} {cuda_attr = #fir.cuda<constant>, uniq_name = "_QFlocal_var_attrsErc"} : (!fir.ref<f32>) -> (!fir.ref<f32>, !fir.ref<f32>)
! CHECK: %{{.*}}:2 = hlfir.declare %{{.*}} {cuda_attr = #fir.cuda<device>, uniq_name = "_QFlocal_var_attrsErd"} : (!fir.ref<f32>) -> (!fir.ref<f32>, !fir.ref<f32>)
! CHECK: %{{.*}}:2 = hlfir.declare %{{.*}} {cuda_attr = #fir.cuda<managed>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFlocal_var_attrsErm"} : (!fir.ref<!fir.box<!fir.heap<f32>>>) -> (!fir.ref<!fir.box<!fir.heap<f32>>>, !fir.ref<!fir.box<!fir.heap<f32>>>)
! CHECK: %{{.*}}:2 = hlfir.declare %{{.*}} {cuda_attr = #fir.cuda<pinned>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFlocal_var_attrsErp"} : (!fir.ref<!fir.box<!fir.heap<f32>>>) -> (!fir.ref<!fir.box<!fir.heap<f32>>>, !fir.ref<!fir.box<!fir.heap<f32>>>)
! FIR: %{{.*}} = fir.declare %{{.*}} {cuda_attr = #fir.cuda<constant>, uniq_name = "_QFlocal_var_attrsErc"} : (!fir.ref<f32>) -> !fir.ref<f32>
! FIR: %{{.*}} = fir.declare %{{.*}} {cuda_attr = #fir.cuda<device>, uniq_name = "_QFlocal_var_attrsErd"} : (!fir.ref<f32>) -> !fir.ref<f32>
! FIR: %{{.*}} = fir.declare %{{.*}} {cuda_attr = #fir.cuda<managed>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFlocal_var_attrsErm"} : (!fir.ref<!fir.box<!fir.heap<f32>>>) -> !fir.ref<!fir.box<!fir.heap<f32>>>
! FIR: %{{.*}} = fir.declare %{{.*}} {cuda_attr = #fir.cuda<pinned>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFlocal_var_attrsErp"} : (!fir.ref<!fir.box<!fir.heap<f32>>>) -> !fir.ref<!fir.box<!fir.heap<f32>>>

View File

@ -49,7 +49,8 @@ TEST_F(FortranVariableTest, SimpleScalar) {
auto name = mlir::StringAttr::get(&context, "x"); auto name = mlir::StringAttr::get(&context, "x");
auto declare = builder->create<fir::DeclareOp>(loc, addr.getType(), addr, auto declare = builder->create<fir::DeclareOp>(loc, addr.getType(), addr,
/*shape=*/mlir::Value{}, /*typeParams=*/std::nullopt, name, /*shape=*/mlir::Value{}, /*typeParams=*/std::nullopt, name,
/*fortran_attrs=*/fir::FortranVariableFlagsAttr{}); /*fortran_attrs=*/fir::FortranVariableFlagsAttr{},
/*cuda_attr=*/fir::CUDAAttributeAttr{});
fir::FortranVariableOpInterface fortranVariable = declare; fir::FortranVariableOpInterface fortranVariable = declare;
EXPECT_FALSE(fortranVariable.isArray()); EXPECT_FALSE(fortranVariable.isArray());
@ -74,7 +75,8 @@ TEST_F(FortranVariableTest, CharacterScalar) {
auto name = mlir::StringAttr::get(&context, "x"); auto name = mlir::StringAttr::get(&context, "x");
auto declare = builder->create<fir::DeclareOp>(loc, addr.getType(), addr, auto declare = builder->create<fir::DeclareOp>(loc, addr.getType(), addr,
/*shape=*/mlir::Value{}, typeParams, name, /*shape=*/mlir::Value{}, typeParams, name,
/*fortran_attrs=*/fir::FortranVariableFlagsAttr{}); /*fortran_attrs=*/fir::FortranVariableFlagsAttr{},
/*cuda_attr=*/fir::CUDAAttributeAttr{});
fir::FortranVariableOpInterface fortranVariable = declare; fir::FortranVariableOpInterface fortranVariable = declare;
EXPECT_FALSE(fortranVariable.isArray()); EXPECT_FALSE(fortranVariable.isArray());
@ -104,7 +106,8 @@ TEST_F(FortranVariableTest, SimpleArray) {
auto name = mlir::StringAttr::get(&context, "x"); auto name = mlir::StringAttr::get(&context, "x");
auto declare = builder->create<fir::DeclareOp>(loc, addr.getType(), addr, auto declare = builder->create<fir::DeclareOp>(loc, addr.getType(), addr,
shape, /*typeParams*/ std::nullopt, name, shape, /*typeParams*/ std::nullopt, name,
/*fortran_attrs=*/fir::FortranVariableFlagsAttr{}); /*fortran_attrs=*/fir::FortranVariableFlagsAttr{},
/*cuda_attr=*/fir::CUDAAttributeAttr{});
fir::FortranVariableOpInterface fortranVariable = declare; fir::FortranVariableOpInterface fortranVariable = declare;
EXPECT_TRUE(fortranVariable.isArray()); EXPECT_TRUE(fortranVariable.isArray());
@ -134,7 +137,8 @@ TEST_F(FortranVariableTest, CharacterArray) {
auto name = mlir::StringAttr::get(&context, "x"); auto name = mlir::StringAttr::get(&context, "x");
auto declare = builder->create<fir::DeclareOp>(loc, addr.getType(), addr, auto declare = builder->create<fir::DeclareOp>(loc, addr.getType(), addr,
shape, typeParams, name, shape, typeParams, name,
/*fortran_attrs=*/fir::FortranVariableFlagsAttr{}); /*fortran_attrs=*/fir::FortranVariableFlagsAttr{},
/*cuda_attr=*/fir::CUDAAttributeAttr{});
fir::FortranVariableOpInterface fortranVariable = declare; fir::FortranVariableOpInterface fortranVariable = declare;
EXPECT_TRUE(fortranVariable.isArray()); EXPECT_TRUE(fortranVariable.isArray());