[flang] Fix fir.call setCalleeFromCallable (#187124)

The CallOpInterface setCalleeFromCallable allows either value or
SymbolRef to be passed in. However, the implementation showed an issue
because while it was able to set attribute, it would fall-through and
also try to set value.

This PR improves the implementation to handle updating the callee even
when switching modes (direct vs indirect) and adds testing for these
APIs.
This commit is contained in:
Razvan Lupusoru 2026-03-18 10:49:06 -07:00 committed by GitHub
parent ba231aaaa0
commit f8db5db958
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 466 additions and 7 deletions

View File

@ -2700,13 +2700,7 @@ def fir_CallOp : fir_Op<"call",
}
/// Set the callee for this operation.
void setCalleeFromCallable(mlir::CallInterfaceCallable callee) {
if (auto calling =
(*this)->getAttrOfType<mlir::SymbolRefAttr>(getCalleeAttrName()))
(*this)->setAttr(getCalleeAttrName(),
llvm::cast<mlir::SymbolRefAttr>(callee));
setOperand(0, llvm::cast<mlir::Value>(callee));
}
void setCalleeFromCallable(mlir::CallInterfaceCallable callee);
}];
}

View File

@ -1350,6 +1350,56 @@ void fir::CallOp::build(mlir::OpBuilder &builder, mlir::OperationState &result,
result.addTypes(results);
}
void fir::CallOp::setCalleeFromCallable(mlir::CallInterfaceCallable callee) {
if (auto symbolRef = llvm::dyn_cast<mlir::SymbolRefAttr>(callee)) {
// Handling a direct call.
bool wasIndirect = llvm::isa<mlir::Value>(getCallableForCallee());
(*this)->setAttr(getCalleeAttrName(), symbolRef);
// If it was indirect before, the operand list and associated attributes
// needs to be fixed up.
if (wasIndirect) {
assert(getNumOperands() > 0 && "indirect call must have callee operand");
(*this)->eraseOperand(0);
// Fix arg_attrs to remove the first (callee) operand if needed.
if (auto argAttrs = getArgAttrsAttr()) {
// Since we already removed the first operand, check that number
// of attributes is one more than number of operands.
assert(argAttrs.size() == getNumOperands() + 1 &&
"arg_attrs must be one-per-operand");
llvm::SmallVector<mlir::Attribute> newAttrs(argAttrs.begin() + 1,
argAttrs.end());
if (newAttrs.empty())
(*this)->removeAttr(getArgAttrsAttrName());
else
(*this)->setAttr(getArgAttrsAttrName(),
mlir::ArrayAttr::get(getContext(), newAttrs));
}
}
return;
}
// The provided callee makes this an indirect call now.
bool wasIndirect = llvm::isa<mlir::Value>(getCallableForCallee());
(*this)->removeAttr(getCalleeAttrNameStr());
mlir::Value calleeVal = llvm::cast<mlir::Value>(callee);
if (wasIndirect) {
setOperand(0, calleeVal);
} else {
(*this)->insertOperands(0, calleeVal);
// Make arg_attrs consistent in size with operands by adding an empty dict
// for the callee.
if (auto argAttrs = getArgAttrsAttr()) {
assert(argAttrs.size() == getNumOperands() - 1 &&
"arg_attrs must be one-per-operand");
llvm::SmallVector<mlir::Attribute> newAttrs;
newAttrs.reserve(1 + argAttrs.size());
newAttrs.push_back(mlir::DictionaryAttr::get(getContext(), {}));
newAttrs.append(argAttrs.begin(), argAttrs.end());
(*this)->setAttr(getArgAttrsAttrName(),
mlir::ArrayAttr::get(getContext(), newAttrs));
}
}
}
//===----------------------------------------------------------------------===//
// CharConvertOp
//===----------------------------------------------------------------------===//

View File

@ -34,6 +34,7 @@ add_flang_unittest(FlangOptimizerTests
Builder/Runtime/ReductionTest.cpp
Builder/Runtime/StopTest.cpp
Builder/Runtime/TransformationalTest.cpp
FIRCallInterfaceTest.cpp
FIRContextTest.cpp
FIRTypesTest.cpp
FortranVariableTest.cpp

View File

@ -0,0 +1,414 @@
//===- FIRCallInterfaceTest.cpp - fir::CallOp setCalleeFromCallable tests -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Tests for CallOpInterface on fir::CallOp.
//
//===----------------------------------------------------------------------===//
#include "gtest/gtest.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Interfaces/CallInterfaces.h"
#include "flang/Optimizer/Dialect/FIROps.h"
#include "flang/Optimizer/Support/InitFIR.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/Support/Casting.h"
using namespace mlir;
namespace {
static bool isSymbolRef(mlir::CallInterfaceCallable callable) {
return llvm::isa<SymbolRefAttr>(callable);
}
static bool isValue(mlir::CallInterfaceCallable callable) {
return llvm::isa<Value>(callable);
}
/// Creates a module and a function with entry block. Builder insertion point is
/// set to the block start. Returns (func, block) so tests can create calls in
/// the block and use block arguments as callee/args.
std::pair<func::FuncOp, Block *> createModuleWithFunction(
OpBuilder &builder, Location loc, StringRef name, FunctionType funcType) {
ModuleOp module = ModuleOp::create(builder, loc);
builder.setInsertionPointToStart(module.getBody());
auto func = func::FuncOp::create(builder, loc, name, funcType);
func.setPrivate();
Block *block = func.addEntryBlock();
builder.setInsertionPointToStart(block);
return {func, block};
}
} // namespace
struct FIRCallInterfaceTest : public testing::Test {
void SetUp() override { fir::support::loadDialects(context); }
MLIRContext context;
};
TEST_F(FIRCallInterfaceTest, setCalleeFromCallable_directToDirect) {
OpBuilder builder(&context);
auto loc = builder.getUnknownLoc();
auto funcType = builder.getFunctionType({}, {});
(void)createModuleWithFunction(builder, loc, "target", funcType);
// Direct call: fir.call @target()
auto callTargetRef = FlatSymbolRefAttr::get(&context, "target");
auto callOp = fir::CallOp::create(
builder, loc, callTargetRef, llvm::ArrayRef<mlir::Type>{}, ValueRange{});
ASSERT_TRUE(isSymbolRef(callOp.getCallableForCallee()));
EXPECT_EQ(callOp.getNumOperands(), 0u);
// Change to another symbol; should remain direct with no extra operand.
auto newCallTargetRef = FlatSymbolRefAttr::get(&context, "other");
callOp.setCalleeFromCallable(newCallTargetRef);
EXPECT_TRUE(isSymbolRef(callOp.getCallableForCallee()));
EXPECT_EQ(llvm::cast<SymbolRefAttr>(callOp.getCallableForCallee())
.getRootReference()
.getValue(),
"other");
EXPECT_EQ(callOp.getNumOperands(), 0u);
EXPECT_TRUE(callOp->getAttr(fir::CallOp::getCalleeAttrNameStr()));
}
TEST_F(FIRCallInterfaceTest, setCalleeFromCallable_indirectToDirect) {
OpBuilder builder(&context);
auto loc = builder.getUnknownLoc();
auto funcType = builder.getFunctionType({}, {});
auto containerType = builder.getFunctionType({funcType}, {});
auto [func, block] =
createModuleWithFunction(builder, loc, "container", containerType);
// Indirect call: fir.call %arg0()
Value callTargetValue = block->getArgument(0);
auto callOp = fir::CallOp::create(builder, loc, SymbolRefAttr{},
llvm::ArrayRef<mlir::Type>{}, ValueRange{callTargetValue});
ASSERT_TRUE(isValue(callOp.getCallableForCallee()));
EXPECT_EQ(callOp.getNumOperands(), 1u);
EXPECT_FALSE(callOp->getAttr(fir::CallOp::getCalleeAttrNameStr()));
// Switch to direct call; operand 0 must be removed.
auto callTargetRef = FlatSymbolRefAttr::get(&context, "direct_target");
callOp.setCalleeFromCallable(callTargetRef);
EXPECT_TRUE(isSymbolRef(callOp.getCallableForCallee()));
EXPECT_EQ(llvm::cast<SymbolRefAttr>(callOp.getCallableForCallee())
.getRootReference()
.getValue(),
"direct_target");
EXPECT_EQ(callOp.getNumOperands(), 0u);
EXPECT_TRUE(callOp->getAttr(fir::CallOp::getCalleeAttrNameStr()));
}
TEST_F(FIRCallInterfaceTest, setCalleeFromCallable_directToIndirect) {
OpBuilder builder(&context);
auto loc = builder.getUnknownLoc();
auto funcType = builder.getFunctionType({}, {});
auto containerType = builder.getFunctionType({funcType}, {});
auto [func, block] =
createModuleWithFunction(builder, loc, "container", containerType);
// Direct call first
auto callTargetRef = FlatSymbolRefAttr::get(&context, "target");
auto callOp = fir::CallOp::create(
builder, loc, callTargetRef, llvm::ArrayRef<mlir::Type>{}, ValueRange{});
ASSERT_TRUE(isSymbolRef(callOp.getCallableForCallee()));
EXPECT_EQ(callOp.getNumOperands(), 0u);
// Switch to indirect; attribute must be unset, operand 0 set.
Value callTargetValue = block->getArgument(0);
callOp.setCalleeFromCallable(callTargetValue);
EXPECT_TRUE(isValue(callOp.getCallableForCallee()));
EXPECT_EQ(callOp.getNumOperands(), 1u);
EXPECT_EQ(callOp.getOperand(0), callTargetValue);
EXPECT_FALSE(callOp->getAttr(fir::CallOp::getCalleeAttrNameStr()));
}
TEST_F(FIRCallInterfaceTest, setCalleeFromCallable_indirectToIndirect) {
OpBuilder builder(&context);
auto loc = builder.getUnknownLoc();
auto funcType = builder.getFunctionType({}, {});
auto containerType = builder.getFunctionType({funcType, funcType}, {});
auto [func, block] =
createModuleWithFunction(builder, loc, "container", containerType);
Value callTarget0 = block->getArgument(0);
Value callTarget1 = block->getArgument(1);
// Indirect call: fir.call %arg0()
auto callOp = fir::CallOp::create(builder, loc, SymbolRefAttr{},
llvm::ArrayRef<mlir::Type>{}, ValueRange{callTarget0});
ASSERT_TRUE(isValue(callOp.getCallableForCallee()));
EXPECT_EQ(callOp.getNumOperands(), 1u);
EXPECT_EQ(callOp.getOperand(0), callTarget0);
// Switch to other indirect call target; should remain indirect, operand 0
// updated.
callOp.setCalleeFromCallable(callTarget1);
EXPECT_TRUE(isValue(callOp.getCallableForCallee()));
EXPECT_EQ(callOp.getNumOperands(), 1u);
EXPECT_EQ(callOp.getOperand(0), callTarget1);
EXPECT_FALSE(callOp->getAttr(fir::CallOp::getCalleeAttrNameStr()));
}
TEST_F(FIRCallInterfaceTest, setCalleeFromCallable_directToIndirect_withArgs) {
OpBuilder builder(&context);
auto loc = builder.getUnknownLoc();
auto i32Ty = builder.getI32Type();
auto funcType = builder.getFunctionType({i32Ty}, {});
auto containerType = builder.getFunctionType({funcType, i32Ty}, {});
auto [func, block] =
createModuleWithFunction(builder, loc, "container", containerType);
Value calleeVal = block->getArgument(0);
Value argVal = block->getArgument(1);
// Direct call with one argument: fir.call @target(%arg)
auto callTargetRef = FlatSymbolRefAttr::get(&context, "target");
auto callOp = fir::CallOp::create(builder, loc, callTargetRef,
llvm::ArrayRef<mlir::Type>{}, ValueRange{argVal});
ASSERT_TRUE(isSymbolRef(callOp.getCallableForCallee()));
EXPECT_EQ(callOp.getNumOperands(), 1u);
EXPECT_EQ(callOp.getOperand(0), argVal);
// Switch to indirect; callee must be inserted at 0, arg preserved.
callOp.setCalleeFromCallable(calleeVal);
EXPECT_TRUE(isValue(callOp.getCallableForCallee()));
EXPECT_EQ(callOp.getNumOperands(), 2u);
EXPECT_EQ(callOp.getOperand(0), calleeVal);
EXPECT_EQ(callOp.getOperand(1), argVal);
EXPECT_FALSE(callOp->getAttr(fir::CallOp::getCalleeAttrNameStr()));
}
TEST_F(FIRCallInterfaceTest, setCalleeFromCallable_indirectToDirect_withArgs) {
OpBuilder builder(&context);
auto loc = builder.getUnknownLoc();
auto i32Ty = builder.getI32Type();
auto funcType = builder.getFunctionType({i32Ty}, {});
auto containerType = builder.getFunctionType({funcType, i32Ty}, {});
auto [func, block] =
createModuleWithFunction(builder, loc, "container", containerType);
Value calleeVal = block->getArgument(0);
Value argVal = block->getArgument(1);
// Indirect call with one argument: fir.call %callee(%arg)
auto callOp = fir::CallOp::create(builder, loc, SymbolRefAttr{},
llvm::ArrayRef<mlir::Type>{}, ValueRange{calleeVal, argVal});
ASSERT_TRUE(isValue(callOp.getCallableForCallee()));
EXPECT_EQ(callOp.getNumOperands(), 2u);
EXPECT_EQ(callOp.getOperand(0), calleeVal);
EXPECT_EQ(callOp.getOperand(1), argVal);
// Switch to direct; callee operand must be removed, arg preserved.
auto callTargetRef = FlatSymbolRefAttr::get(&context, "direct_target");
callOp.setCalleeFromCallable(callTargetRef);
EXPECT_TRUE(isSymbolRef(callOp.getCallableForCallee()));
EXPECT_EQ(callOp.getNumOperands(), 1u);
EXPECT_EQ(callOp.getOperand(0), argVal);
EXPECT_TRUE(callOp->getAttr(fir::CallOp::getCalleeAttrNameStr()));
}
TEST_F(
FIRCallInterfaceTest, setCalleeFromCallable_indirectToIndirect_withArgs) {
OpBuilder builder(&context);
auto loc = builder.getUnknownLoc();
auto i32Ty = builder.getI32Type();
auto funcType = builder.getFunctionType({i32Ty}, {});
auto containerType = builder.getFunctionType({funcType, funcType, i32Ty}, {});
auto [func, block] =
createModuleWithFunction(builder, loc, "container", containerType);
Value callee0 = block->getArgument(0);
Value callee1 = block->getArgument(1);
Value argVal = block->getArgument(2);
// Indirect call with one argument: fir.call %callee0(%arg)
auto callOp = fir::CallOp::create(builder, loc, SymbolRefAttr{},
llvm::ArrayRef<mlir::Type>{}, ValueRange{callee0, argVal});
ASSERT_TRUE(isValue(callOp.getCallableForCallee()));
EXPECT_EQ(callOp.getNumOperands(), 2u);
EXPECT_EQ(callOp.getOperand(0), callee0);
EXPECT_EQ(callOp.getOperand(1), argVal);
// Switch to other indirect callee; operand 0 updated, arg preserved.
callOp.setCalleeFromCallable(callee1);
EXPECT_TRUE(isValue(callOp.getCallableForCallee()));
EXPECT_EQ(callOp.getNumOperands(), 2u);
EXPECT_EQ(callOp.getOperand(0), callee1);
EXPECT_EQ(callOp.getOperand(1), argVal);
EXPECT_FALSE(callOp->getAttr(fir::CallOp::getCalleeAttrNameStr()));
}
static ArrayAttr makeArgAttrs(
MLIRContext *ctx, llvm::ArrayRef<DictionaryAttr> dicts) {
llvm::SmallVector<Attribute> attrs(dicts.begin(), dicts.end());
return ArrayAttr::get(ctx, attrs);
}
static DictionaryAttr makeTestArgDict(MLIRContext *ctx, StringRef value) {
return DictionaryAttr::get(ctx,
{NamedAttribute(
StringAttr::get(ctx, "test.attr"), StringAttr::get(ctx, value))});
}
TEST_F(
FIRCallInterfaceTest, setCalleeFromCallable_directToDirect_withArgAttrs) {
OpBuilder builder(&context);
auto loc = builder.getUnknownLoc();
auto i32Ty = builder.getI32Type();
auto funcType = builder.getFunctionType({i32Ty}, {});
auto containerType = builder.getFunctionType({funcType, i32Ty}, {});
auto [func, block] =
createModuleWithFunction(builder, loc, "container", containerType);
Value argVal = block->getArgument(1);
// Direct call with one argument and arg_attrs.
auto callTargetRef = FlatSymbolRefAttr::get(&context, "target");
auto callOp = fir::CallOp::create(builder, loc, callTargetRef,
llvm::ArrayRef<mlir::Type>{}, ValueRange{argVal});
callOp->setAttr(callOp.getArgAttrsAttrName(),
makeArgAttrs(&context, {makeTestArgDict(&context, "arg0")}));
ASSERT_TRUE(isSymbolRef(callOp.getCallableForCallee()));
EXPECT_EQ(callOp.getNumOperands(), 1u);
// Switch to another direct callee
auto newCallTargetRef = FlatSymbolRefAttr::get(&context, "other_target");
callOp.setCalleeFromCallable(newCallTargetRef);
EXPECT_TRUE(isSymbolRef(callOp.getCallableForCallee()));
EXPECT_EQ(llvm::cast<SymbolRefAttr>(callOp.getCallableForCallee())
.getRootReference()
.getValue(),
"other_target");
EXPECT_EQ(callOp.getNumOperands(), 1u);
EXPECT_EQ(callOp.getOperand(0), argVal);
ArrayAttr argAttrs = callOp.getArgAttrsAttr();
ASSERT_TRUE(argAttrs);
ASSERT_EQ(argAttrs.size(), 1u);
EXPECT_EQ(llvm::cast<DictionaryAttr>(argAttrs[0]).get("test.attr"),
StringAttr::get(&context, "arg0"));
}
TEST_F(
FIRCallInterfaceTest, setCalleeFromCallable_directToIndirect_withArgAttrs) {
OpBuilder builder(&context);
auto loc = builder.getUnknownLoc();
auto i32Ty = builder.getI32Type();
auto funcType = builder.getFunctionType({i32Ty}, {});
auto containerType = builder.getFunctionType({funcType, i32Ty}, {});
auto [func, block] =
createModuleWithFunction(builder, loc, "container", containerType);
Value calleeVal = block->getArgument(0);
Value argVal = block->getArgument(1);
// Direct call with one argument and arg_attrs for that argument.
auto callTargetRef = FlatSymbolRefAttr::get(&context, "target");
auto callOp = fir::CallOp::create(builder, loc, callTargetRef,
llvm::ArrayRef<mlir::Type>{}, ValueRange{argVal});
callOp->setAttr(callOp.getArgAttrsAttrName(),
makeArgAttrs(&context, {makeTestArgDict(&context, "arg0")}));
ASSERT_TRUE(isSymbolRef(callOp.getCallableForCallee()));
EXPECT_EQ(callOp.getNumOperands(), 1u);
// Switch to indirect
callOp.setCalleeFromCallable(calleeVal);
EXPECT_TRUE(isValue(callOp.getCallableForCallee()));
EXPECT_EQ(callOp.getNumOperands(), 2u);
EXPECT_EQ(callOp.getOperand(0), calleeVal);
EXPECT_EQ(callOp.getOperand(1), argVal);
ArrayAttr argAttrs = callOp.getArgAttrsAttr();
ASSERT_TRUE(argAttrs);
ASSERT_EQ(argAttrs.size(), 2u);
// First entry is empty dict for callee.
EXPECT_TRUE(llvm::cast<DictionaryAttr>(argAttrs[0]).empty());
// Second entry preserves the argument's attribute.
auto argDict = llvm::cast<DictionaryAttr>(argAttrs[1]);
EXPECT_EQ(argDict.get("test.attr"), StringAttr::get(&context, "arg0"));
}
TEST_F(
FIRCallInterfaceTest, setCalleeFromCallable_indirectToDirect_withArgAttrs) {
OpBuilder builder(&context);
auto loc = builder.getUnknownLoc();
auto i32Ty = builder.getI32Type();
auto funcType = builder.getFunctionType({i32Ty}, {});
auto containerType = builder.getFunctionType({funcType, i32Ty}, {});
auto [func, block] =
createModuleWithFunction(builder, loc, "container", containerType);
Value calleeVal = block->getArgument(0);
Value argVal = block->getArgument(1);
// Indirect call with callee + one argument
auto callOp = fir::CallOp::create(builder, loc, SymbolRefAttr{},
llvm::ArrayRef<mlir::Type>{}, ValueRange{calleeVal, argVal});
callOp->setAttr(callOp.getArgAttrsAttrName(),
makeArgAttrs(&context,
{DictionaryAttr::get(&context, {}),
makeTestArgDict(&context, "arg0")}));
ASSERT_TRUE(isValue(callOp.getCallableForCallee()));
EXPECT_EQ(callOp.getNumOperands(), 2u);
// Switch to direct
auto callTargetRef = FlatSymbolRefAttr::get(&context, "direct_target");
callOp.setCalleeFromCallable(callTargetRef);
EXPECT_TRUE(isSymbolRef(callOp.getCallableForCallee()));
EXPECT_EQ(callOp.getNumOperands(), 1u);
EXPECT_EQ(callOp.getOperand(0), argVal);
ArrayAttr argAttrs = callOp.getArgAttrsAttr();
ASSERT_TRUE(argAttrs);
ASSERT_EQ(argAttrs.size(), 1u);
EXPECT_EQ(llvm::cast<DictionaryAttr>(argAttrs[0]).get("test.attr"),
StringAttr::get(&context, "arg0"));
}
TEST_F(FIRCallInterfaceTest,
setCalleeFromCallable_indirectToIndirect_withArgAttrs) {
OpBuilder builder(&context);
auto loc = builder.getUnknownLoc();
auto i32Ty = builder.getI32Type();
auto funcType = builder.getFunctionType({i32Ty}, {});
auto containerType = builder.getFunctionType({funcType, funcType, i32Ty}, {});
auto [func, block] =
createModuleWithFunction(builder, loc, "container", containerType);
Value callee0 = block->getArgument(0);
Value callee1 = block->getArgument(1);
Value argVal = block->getArgument(2);
// Indirect call with one argument and arg_attrs
auto callOp = fir::CallOp::create(builder, loc, SymbolRefAttr{},
llvm::ArrayRef<mlir::Type>{}, ValueRange{callee0, argVal});
callOp->setAttr(callOp.getArgAttrsAttrName(),
makeArgAttrs(&context,
{DictionaryAttr::get(&context, {}),
makeTestArgDict(&context, "arg0")}));
ASSERT_TRUE(isValue(callOp.getCallableForCallee()));
EXPECT_EQ(callOp.getNumOperands(), 2u);
// Switch to other indirect callee
callOp.setCalleeFromCallable(callee1);
EXPECT_TRUE(isValue(callOp.getCallableForCallee()));
EXPECT_EQ(callOp.getNumOperands(), 2u);
EXPECT_EQ(callOp.getOperand(0), callee1);
EXPECT_EQ(callOp.getOperand(1), argVal);
ArrayAttr argAttrs = callOp.getArgAttrsAttr();
ASSERT_TRUE(argAttrs);
ASSERT_EQ(argAttrs.size(), 2u);
EXPECT_TRUE(llvm::cast<DictionaryAttr>(argAttrs[0]).empty());
EXPECT_EQ(llvm::cast<DictionaryAttr>(argAttrs[1]).get("test.attr"),
StringAttr::get(&context, "arg0"));
}