[flang][CUDA] Apply intrinsic operator overrides (#151018)
Fortran's intrinsic numeric and relational operators can be overridden with explicit interfaces so long as one or more of the dummy arguments have the DEVICE attribute. Semantics already allows this without complaint, but fails to replace the operations with the defined specific procedure calls when analyzing expressions.
This commit is contained in:
parent
0d6a67c1ad
commit
b01ab5318e
@ -162,7 +162,6 @@ public:
|
||||
warningsAreErrors_ = x;
|
||||
return *this;
|
||||
}
|
||||
|
||||
SemanticsContext &set_debugModuleWriter(bool x) {
|
||||
debugModuleWriter_ = x;
|
||||
return *this;
|
||||
|
@ -761,14 +761,13 @@ void CUDAChecker::Enter(const parser::AssignmentStmt &x) {
|
||||
// legal.
|
||||
if (nbLhs == 0 && nbRhs > 1) {
|
||||
context_.Say(lhsLoc,
|
||||
"More than one reference to a CUDA object on the right hand side of the assigment"_err_en_US);
|
||||
"More than one reference to a CUDA object on the right hand side of the assignment"_err_en_US);
|
||||
}
|
||||
|
||||
if (Fortran::evaluate::HasCUDADeviceAttrs(assign->lhs) &&
|
||||
Fortran::evaluate::HasCUDAImplicitTransfer(assign->rhs)) {
|
||||
if (evaluate::HasCUDADeviceAttrs(assign->lhs) &&
|
||||
evaluate::HasCUDAImplicitTransfer(assign->rhs)) {
|
||||
if (GetNbOfCUDAManagedOrUnifiedSymbols(assign->lhs) == 1 &&
|
||||
GetNbOfCUDAManagedOrUnifiedSymbols(assign->rhs) == 1 &&
|
||||
GetNbOfCUDADeviceSymbols(assign->rhs) == 1) {
|
||||
GetNbOfCUDAManagedOrUnifiedSymbols(assign->rhs) == 1 && nbRhs == 1) {
|
||||
return; // This is a special case handled on the host.
|
||||
}
|
||||
context_.Say(lhsLoc, "Unsupported CUDA data transfer"_err_en_US);
|
||||
|
@ -2081,7 +2081,7 @@ static bool ConflictsWithIntrinsicAssignment(const Procedure &proc) {
|
||||
}
|
||||
|
||||
static bool ConflictsWithIntrinsicOperator(
|
||||
const GenericKind &kind, const Procedure &proc) {
|
||||
const GenericKind &kind, const Procedure &proc, SemanticsContext &context) {
|
||||
if (!kind.IsIntrinsicOperator()) {
|
||||
return false;
|
||||
}
|
||||
@ -2167,7 +2167,7 @@ bool CheckHelper::CheckDefinedOperator(SourceName opName, GenericKind kind,
|
||||
}
|
||||
} else if (!checkDefinedOperatorArgs(opName, specific, proc)) {
|
||||
return false; // error was reported
|
||||
} else if (ConflictsWithIntrinsicOperator(kind, proc)) {
|
||||
} else if (ConflictsWithIntrinsicOperator(kind, proc, context_)) {
|
||||
msg = "%s function '%s' conflicts with intrinsic operator"_err_en_US;
|
||||
}
|
||||
if (msg) {
|
||||
|
@ -165,10 +165,17 @@ public:
|
||||
bool CheckForNullPointer(const char *where = "as an operand here");
|
||||
bool CheckForAssumedRank(const char *where = "as an operand here");
|
||||
|
||||
bool AnyCUDADeviceData() const;
|
||||
// Returns true if an interface has been defined for an intrinsic operator
|
||||
// with one or more device operands.
|
||||
bool HasDeviceDefinedIntrinsicOpOverride(const char *) const;
|
||||
template <typename E> bool HasDeviceDefinedIntrinsicOpOverride(E opr) const {
|
||||
return HasDeviceDefinedIntrinsicOpOverride(
|
||||
context_.context().languageFeatures().GetNames(opr));
|
||||
}
|
||||
|
||||
// Find and return a user-defined operator or report an error.
|
||||
// The provided message is used if there is no such operator.
|
||||
// If a definedOpSymbolPtr is provided, the caller must check
|
||||
// for its accessibility.
|
||||
MaybeExpr TryDefinedOp(
|
||||
const char *, parser::MessageFixedText, bool isUserOp = false);
|
||||
template <typename E>
|
||||
@ -183,6 +190,8 @@ public:
|
||||
void Dump(llvm::raw_ostream &);
|
||||
|
||||
private:
|
||||
bool HasDeviceDefinedIntrinsicOpOverride(
|
||||
const std::vector<const char *> &) const;
|
||||
MaybeExpr TryDefinedOp(
|
||||
const std::vector<const char *> &, parser::MessageFixedText);
|
||||
MaybeExpr TryBoundOp(const Symbol &, int passIndex);
|
||||
@ -202,7 +211,7 @@ private:
|
||||
void SayNoMatch(
|
||||
const std::string &, bool isAssignment = false, bool isAmbiguous = false);
|
||||
std::string TypeAsFortran(std::size_t);
|
||||
bool AnyUntypedOrMissingOperand();
|
||||
bool AnyUntypedOrMissingOperand() const;
|
||||
|
||||
ExpressionAnalyzer &context_;
|
||||
ActualArguments actuals_;
|
||||
@ -4497,13 +4506,20 @@ void ArgumentAnalyzer::Analyze(
|
||||
bool ArgumentAnalyzer::IsIntrinsicRelational(RelationalOperator opr,
|
||||
const DynamicType &leftType, const DynamicType &rightType) const {
|
||||
CHECK(actuals_.size() == 2);
|
||||
return semantics::IsIntrinsicRelational(
|
||||
opr, leftType, GetRank(0), rightType, GetRank(1));
|
||||
return !(context_.context().languageFeatures().IsEnabled(
|
||||
common::LanguageFeature::CUDA) &&
|
||||
HasDeviceDefinedIntrinsicOpOverride(opr)) &&
|
||||
semantics::IsIntrinsicRelational(
|
||||
opr, leftType, GetRank(0), rightType, GetRank(1));
|
||||
}
|
||||
|
||||
bool ArgumentAnalyzer::IsIntrinsicNumeric(NumericOperator opr) const {
|
||||
std::optional<DynamicType> leftType{GetType(0)};
|
||||
if (actuals_.size() == 1) {
|
||||
if (context_.context().languageFeatures().IsEnabled(
|
||||
common::LanguageFeature::CUDA) &&
|
||||
HasDeviceDefinedIntrinsicOpOverride(AsFortran(opr))) {
|
||||
return false;
|
||||
} else if (actuals_.size() == 1) {
|
||||
if (IsBOZLiteral(0)) {
|
||||
return opr == NumericOperator::Add; // unary '+'
|
||||
} else {
|
||||
@ -4617,6 +4633,53 @@ bool ArgumentAnalyzer::CheckForAssumedRank(const char *where) {
|
||||
return true;
|
||||
}
|
||||
|
||||
bool ArgumentAnalyzer::AnyCUDADeviceData() const {
|
||||
for (const std::optional<ActualArgument> &arg : actuals_) {
|
||||
if (arg) {
|
||||
if (const Expr<SomeType> *expr{arg->UnwrapExpr()}) {
|
||||
if (HasCUDADeviceAttrs(*expr)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
// Some operations can be defined with explicit non-type-bound interfaces
|
||||
// that would erroneously conflict with intrinsic operations in their
|
||||
// types and ranks but have one or more dummy arguments with the DEVICE
|
||||
// attribute.
|
||||
bool ArgumentAnalyzer::HasDeviceDefinedIntrinsicOpOverride(
|
||||
const char *opr) const {
|
||||
if (AnyCUDADeviceData() && !AnyUntypedOrMissingOperand()) {
|
||||
std::string oprNameString{"operator("s + opr + ')'};
|
||||
parser::CharBlock oprName{oprNameString};
|
||||
parser::Messages buffer;
|
||||
auto restorer{context_.GetContextualMessages().SetMessages(buffer)};
|
||||
const auto &scope{context_.context().FindScope(source_)};
|
||||
if (Symbol * generic{scope.FindSymbol(oprName)}) {
|
||||
parser::Name name{generic->name(), generic};
|
||||
const Symbol *resultSymbol{nullptr};
|
||||
if (context_.AnalyzeDefinedOp(
|
||||
name, ActualArguments{actuals_}, resultSymbol)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
bool ArgumentAnalyzer::HasDeviceDefinedIntrinsicOpOverride(
|
||||
const std::vector<const char *> &oprNames) const {
|
||||
for (const char *opr : oprNames) {
|
||||
if (HasDeviceDefinedIntrinsicOpOverride(opr)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
MaybeExpr ArgumentAnalyzer::TryDefinedOp(
|
||||
const char *opr, parser::MessageFixedText error, bool isUserOp) {
|
||||
if (AnyUntypedOrMissingOperand()) {
|
||||
@ -5135,7 +5198,7 @@ std::string ArgumentAnalyzer::TypeAsFortran(std::size_t i) {
|
||||
}
|
||||
}
|
||||
|
||||
bool ArgumentAnalyzer::AnyUntypedOrMissingOperand() {
|
||||
bool ArgumentAnalyzer::AnyUntypedOrMissingOperand() const {
|
||||
for (const auto &actual : actuals_) {
|
||||
if (!actual ||
|
||||
(!actual->GetType() && !IsBareNullPointer(actual->UnwrapExpr()))) {
|
||||
|
49
flang/test/Semantics/bug1214.cuf
Normal file
49
flang/test/Semantics/bug1214.cuf
Normal file
@ -0,0 +1,49 @@
|
||||
! RUN: %flang_fc1 -fdebug-unparse %s 2>&1 | FileCheck %s
|
||||
module overrides
|
||||
type realResult
|
||||
real a
|
||||
end type
|
||||
interface operator(*)
|
||||
procedure :: multHostDevice, multDeviceHost
|
||||
end interface
|
||||
interface assignment(=)
|
||||
procedure :: assignHostResult, assignDeviceResult
|
||||
end interface
|
||||
contains
|
||||
elemental function multHostDevice(x, y) result(result)
|
||||
real, intent(in) :: x
|
||||
real, intent(in), device :: y
|
||||
type(realResult) result
|
||||
result%a = x * y
|
||||
end
|
||||
elemental function multDeviceHost(x, y) result(result)
|
||||
real, intent(in), device :: x
|
||||
real, intent(in) :: y
|
||||
type(realResult) result
|
||||
result%a = x * y
|
||||
end
|
||||
elemental subroutine assignHostResult(lhs, rhs)
|
||||
real, intent(out) :: lhs
|
||||
type(realResult), intent(in) :: rhs
|
||||
lhs = rhs%a
|
||||
end
|
||||
elemental subroutine assignDeviceResult(lhs, rhs)
|
||||
real, intent(out), device :: lhs
|
||||
type(realResult), intent(in) :: rhs
|
||||
lhs = rhs%a
|
||||
end
|
||||
end
|
||||
|
||||
program p
|
||||
use overrides
|
||||
real, device :: da, db
|
||||
real :: ha, hb
|
||||
!CHECK: CALL assigndeviceresult(db,multhostdevice(2._4,da))
|
||||
db = 2. * da
|
||||
!CHECK: CALL assigndeviceresult(db,multdevicehost(da,2._4))
|
||||
db = da * 2.
|
||||
!CHECK: CALL assignhostresult(ha,multhostdevice(2._4,da))
|
||||
ha = 2. * da
|
||||
!CHECK: CALL assignhostresult(ha,multdevicehost(da,2._4))
|
||||
ha = da * 2.
|
||||
end
|
@ -16,7 +16,7 @@ subroutine sub1()
|
||||
real, device :: adev(10), bdev(10)
|
||||
real :: ahost(10)
|
||||
|
||||
!ERROR: More than one reference to a CUDA object on the right hand side of the assigment
|
||||
!ERROR: More than one reference to a CUDA object on the right hand side of the assignment
|
||||
ahost = adev + bdev
|
||||
|
||||
ahost = adev + adev
|
||||
|
Loading…
x
Reference in New Issue
Block a user