[CIR] Add support for calling virtual functions (#153893)

This change adds support for calling virtual functions. This includes
adding the cir.vtable.get_virtual_fn_addr operation to lookup the
address of the function being called from an object's vtable.
This commit is contained in:
Andy Kaylor 2025-08-18 15:56:33 -07:00 committed by GitHub
parent 61a859bf6f
commit 7ac4d9bd53
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
14 changed files with 304 additions and 21 deletions

View File

@ -157,6 +157,20 @@ public:
return create<cir::ComplexImagOp>(loc, operandTy.getElementType(), operand);
}
cir::LoadOp createLoad(mlir::Location loc, mlir::Value ptr,
uint64_t alignment = 0) {
mlir::IntegerAttr alignmentAttr = getAlignmentAttr(alignment);
assert(!cir::MissingFeatures::opLoadStoreVolatile());
assert(!cir::MissingFeatures::opLoadStoreMemOrder());
return cir::LoadOp::create(*this, loc, ptr, /*isDeref=*/false,
alignmentAttr);
}
mlir::Value createAlignedLoad(mlir::Location loc, mlir::Value ptr,
uint64_t alignment) {
return createLoad(loc, ptr, alignment);
}
mlir::Value createNot(mlir::Value value) {
return create<cir::UnaryOp>(value.getLoc(), value.getType(),
cir::UnaryOpKind::Not, value);

View File

@ -1838,6 +1838,54 @@ def CIR_VTableGetVPtrOp : CIR_Op<"vtable.get_vptr", [Pure]> {
}];
}
//===----------------------------------------------------------------------===//
// VTableGetVirtualFnAddrOp
//===----------------------------------------------------------------------===//
def CIR_VTableGetVirtualFnAddrOp : CIR_Op<"vtable.get_virtual_fn_addr", [
Pure
]> {
let summary = "Get a the address of a virtual function pointer";
let description = [{
The `vtable.get_virtual_fn_addr` operation retrieves the address of a
virtual function pointer from an object's vtable (__vptr).
This is an abstraction to perform the basic pointer arithmetic to get
the address of the virtual function pointer, which can then be loaded and
called.
The `vptr` operand must be a `!cir.ptr<!cir.vptr>` value, which would
have been returned by a previous call to `cir.vatble.get_vptr`. The
`index` operand is an index of the virtual function in the vtable.
The return type is a pointer-to-pointer to the function type.
Example:
```mlir
%2 = cir.load %0 : !cir.ptr<!cir.ptr<!rec_C>>, !cir.ptr<!rec_C>
%3 = cir.vtable.get_vptr %2 : !cir.ptr<!rec_C> -> !cir.ptr<!cir.vptr>
%4 = cir.load %3 : !cir.ptr<!cir.vptr>, !cir.vptr
%5 = cir.vtable.get_virtual_fn_addr %4[2] : !cir.vptr
-> !cir.ptr<!cir.ptr<!cir.func<(!cir.ptr<!rec_C>) -> !s32i>>>
%6 = cir.load align(8) %5 : !cir.ptr<!cir.ptr<!cir.func<(!cir.ptr<!rec_C>)
-> !s32i>>>,
!cir.ptr<!cir.func<(!cir.ptr<!rec_C>) -> !s32i>>
%7 = cir.call %6(%2) : (!cir.ptr<!cir.func<(!cir.ptr<!rec_C>) -> !s32i>>,
!cir.ptr<!rec_C>) -> !s32i
```
}];
let arguments = (ins
Arg<CIR_VPtrType, "vptr", [MemRead]>:$vptr,
I64Attr:$index);
let results = (outs CIR_PointerType:$result);
let assemblyFormat = [{
$vptr `[` $index `]` attr-dict
`:` qualified(type($vptr)) `->` qualified(type($result))
}];
}
//===----------------------------------------------------------------------===//
// SetBitfieldOp
//===----------------------------------------------------------------------===//

View File

@ -95,7 +95,6 @@ struct MissingFeatures {
static bool opCallArgEvaluationOrder() { return false; }
static bool opCallCallConv() { return false; }
static bool opCallMustTail() { return false; }
static bool opCallVirtual() { return false; }
static bool opCallInAlloca() { return false; }
static bool opCallAttrs() { return false; }
static bool opCallSurroundingTry() { return false; }
@ -204,6 +203,7 @@ struct MissingFeatures {
static bool dataLayoutTypeAllocSize() { return false; }
static bool dataLayoutTypeStoreSize() { return false; }
static bool deferredCXXGlobalInit() { return false; }
static bool devirtualizeMemberFunction() { return false; }
static bool ehCleanupFlags() { return false; }
static bool ehCleanupScope() { return false; }
static bool ehCleanupScopeRequiresEHCleanup() { return false; }
@ -215,6 +215,7 @@ struct MissingFeatures {
static bool emitLValueAlignmentAssumption() { return false; }
static bool emitNullabilityCheck() { return false; }
static bool emitTypeCheck() { return false; }
static bool emitTypeMetadataCodeForVCall() { return false; }
static bool fastMathFlags() { return false; }
static bool fpConstraints() { return false; }
static bool generateDebugInfo() { return false; }

View File

@ -63,6 +63,16 @@ public:
/// parameter.
virtual bool needsVTTParameter(clang::GlobalDecl gd) { return false; }
/// Perform ABI-specific "this" argument adjustment required prior to
/// a call of a virtual function.
/// The "VirtualCall" argument is true iff the call itself is virtual.
virtual Address adjustThisArgumentForVirtualFunctionCall(CIRGenFunction &cgf,
clang::GlobalDecl gd,
Address thisPtr,
bool virtualCall) {
return thisPtr;
}
/// Build a parameter variable suitable for 'this'.
void buildThisParam(CIRGenFunction &cgf, FunctionArgList &params);
@ -100,6 +110,13 @@ public:
virtual cir::GlobalOp getAddrOfVTable(const CXXRecordDecl *rd,
CharUnits vptrOffset) = 0;
/// Build a virtual function pointer in the ABI-specific way.
virtual CIRGenCallee getVirtualFunctionPointer(CIRGenFunction &cgf,
clang::GlobalDecl gd,
Address thisAddr,
mlir::Type ty,
SourceLocation loc) = 0;
/// Get the address point of the vtable for the given base subobject.
virtual mlir::Value
getVTableAddressPoint(BaseSubobject base,

View File

@ -79,11 +79,10 @@ RValue CIRGenFunction::emitCXXMemberOrOperatorMemberCallExpr(
const Expr *base) {
assert(isa<CXXMemberCallExpr>(ce) || isa<CXXOperatorCallExpr>(ce));
if (md->isVirtual()) {
cgm.errorNYI(ce->getSourceRange(),
"emitCXXMemberOrOperatorMemberCallExpr: virtual call");
return RValue::get(nullptr);
}
// Compute the object pointer.
bool canUseVirtualCall = md->isVirtual() && !hasQualifier;
const CXXMethodDecl *devirtualizedMethod = nullptr;
assert(!cir::MissingFeatures::devirtualizeMemberFunction());
// Note on trivial assignment
// --------------------------
@ -127,7 +126,8 @@ RValue CIRGenFunction::emitCXXMemberOrOperatorMemberCallExpr(
return RValue::get(nullptr);
// Compute the function type we're calling
const CXXMethodDecl *calleeDecl = md;
const CXXMethodDecl *calleeDecl =
devirtualizedMethod ? devirtualizedMethod : md;
const CIRGenFunctionInfo *fInfo = nullptr;
if (isa<CXXDestructorDecl>(calleeDecl)) {
cgm.errorNYI(ce->getSourceRange(),
@ -137,25 +137,46 @@ RValue CIRGenFunction::emitCXXMemberOrOperatorMemberCallExpr(
fInfo = &cgm.getTypes().arrangeCXXMethodDeclaration(calleeDecl);
mlir::Type ty = cgm.getTypes().getFunctionType(*fInfo);
cir::FuncType ty = cgm.getTypes().getFunctionType(*fInfo);
assert(!cir::MissingFeatures::sanitizers());
assert(!cir::MissingFeatures::emitTypeCheck());
// C++ [class.virtual]p12:
// Explicit qualification with the scope operator (5.1) suppresses the
// virtual call mechanism.
//
// We also don't emit a virtual call if the base expression has a record type
// because then we know what the type is.
bool useVirtualCall = canUseVirtualCall && !devirtualizedMethod;
if (isa<CXXDestructorDecl>(calleeDecl)) {
cgm.errorNYI(ce->getSourceRange(),
"emitCXXMemberOrOperatorMemberCallExpr: destructor call");
return RValue::get(nullptr);
}
assert(!cir::MissingFeatures::sanitizers());
if (getLangOpts().AppleKext) {
cgm.errorNYI(ce->getSourceRange(),
"emitCXXMemberOrOperatorMemberCallExpr: AppleKext");
return RValue::get(nullptr);
CIRGenCallee callee;
if (useVirtualCall) {
callee = CIRGenCallee::forVirtual(ce, md, thisPtr.getAddress(), ty);
} else {
assert(!cir::MissingFeatures::sanitizers());
if (getLangOpts().AppleKext) {
cgm.errorNYI(ce->getSourceRange(),
"emitCXXMemberOrOperatorMemberCallExpr: AppleKext");
return RValue::get(nullptr);
}
callee = CIRGenCallee::forDirect(cgm.getAddrOfFunction(calleeDecl, ty),
GlobalDecl(calleeDecl));
}
if (md->isVirtual()) {
Address newThisAddr =
cgm.getCXXABI().adjustThisArgumentForVirtualFunctionCall(
*this, calleeDecl, thisPtr.getAddress(), useVirtualCall);
thisPtr.setAddress(newThisAddr);
}
CIRGenCallee callee =
CIRGenCallee::forDirect(cgm.getAddrOfFunction(md, ty), GlobalDecl(md));
return emitCXXMemberOrOperatorCall(
calleeDecl, callee, returnValue, thisPtr.getPointer(),

View File

@ -56,7 +56,12 @@ cir::FuncType CIRGenTypes::getFunctionType(const CIRGenFunctionInfo &info) {
}
CIRGenCallee CIRGenCallee::prepareConcreteCallee(CIRGenFunction &cgf) const {
assert(!cir::MissingFeatures::opCallVirtual());
if (isVirtual()) {
const CallExpr *ce = getVirtualCallExpr();
return cgf.cgm.getCXXABI().getVirtualFunctionPointer(
cgf, getVirtualMethodDecl(), getThisAddress(), getVirtualFunctionType(),
ce ? ce->getBeginLoc() : SourceLocation());
}
return *this;
}

View File

@ -47,8 +47,9 @@ class CIRGenCallee {
Invalid,
Builtin,
PseudoDestructor,
Virtual,
Last = Builtin,
Last = Virtual
};
struct BuiltinInfoStorage {
@ -58,6 +59,12 @@ class CIRGenCallee {
struct PseudoDestructorInfoStorage {
const clang::CXXPseudoDestructorExpr *expr;
};
struct VirtualInfoStorage {
const clang::CallExpr *ce;
clang::GlobalDecl md;
Address addr;
cir::FuncType fTy;
};
SpecialKind kindOrFunctionPtr;
@ -65,6 +72,7 @@ class CIRGenCallee {
CIRGenCalleeInfo abstractInfo;
BuiltinInfoStorage builtinInfo;
PseudoDestructorInfoStorage pseudoDestructorInfo;
VirtualInfoStorage virtualInfo;
};
explicit CIRGenCallee(SpecialKind kind) : kindOrFunctionPtr(kind) {}
@ -128,7 +136,8 @@ public:
CIRGenCallee prepareConcreteCallee(CIRGenFunction &cgf) const;
CIRGenCalleeInfo getAbstractInfo() const {
assert(!cir::MissingFeatures::opCallVirtual());
if (isVirtual())
return virtualInfo.md;
assert(isOrdinary());
return abstractInfo;
}
@ -138,6 +147,39 @@ public:
return reinterpret_cast<mlir::Operation *>(kindOrFunctionPtr);
}
bool isVirtual() const { return kindOrFunctionPtr == SpecialKind::Virtual; }
static CIRGenCallee forVirtual(const clang::CallExpr *ce,
clang::GlobalDecl md, Address addr,
cir::FuncType fTy) {
CIRGenCallee result(SpecialKind::Virtual);
result.virtualInfo.ce = ce;
result.virtualInfo.md = md;
result.virtualInfo.addr = addr;
result.virtualInfo.fTy = fTy;
return result;
}
const clang::CallExpr *getVirtualCallExpr() const {
assert(isVirtual());
return virtualInfo.ce;
}
clang::GlobalDecl getVirtualMethodDecl() const {
assert(isVirtual());
return virtualInfo.md;
}
Address getThisAddress() const {
assert(isVirtual());
return virtualInfo.addr;
}
cir::FuncType getVirtualFunctionType() const {
assert(isVirtual());
return virtualInfo.fTy;
}
void setFunctionPointer(mlir::Operation *functionPtr) {
assert(isOrdinary());
kindOrFunctionPtr = SpecialKind(reinterpret_cast<uintptr_t>(functionPtr));

View File

@ -657,6 +657,20 @@ Address CIRGenFunction::getAddressOfBaseClass(
return value;
}
// TODO(cir): this can be shared with LLVM codegen.
bool CIRGenFunction::shouldEmitVTableTypeCheckedLoad(const CXXRecordDecl *rd) {
assert(!cir::MissingFeatures::hiddenVisibility());
if (!cgm.getCodeGenOpts().WholeProgramVTables)
return false;
if (cgm.getCodeGenOpts().VirtualFunctionElimination)
return true;
assert(!cir::MissingFeatures::sanitizers());
return false;
}
mlir::Value CIRGenFunction::getVTablePtr(mlir::Location loc, Address thisAddr,
const CXXRecordDecl *rd) {
auto vtablePtr = cir::VTableGetVPtrOp::create(

View File

@ -552,6 +552,11 @@ public:
mlir::Value getVTablePtr(mlir::Location loc, Address thisAddr,
const clang::CXXRecordDecl *vtableClass);
/// Returns whether we should perform a type checked load when loading a
/// virtual function for virtual calls to members of RD. This is generally
/// true when both vcall CFI and whole-program-vtables are enabled.
bool shouldEmitVTableTypeCheckedLoad(const CXXRecordDecl *rd);
/// A scope within which we are constructing the fields of an object which
/// might use a CXXDefaultInitExpr. This stashes away a 'this' value to use if
/// we need to evaluate the CXXDefaultInitExpr within the evaluation.

View File

@ -69,6 +69,10 @@ public:
cir::GlobalOp getAddrOfVTable(const CXXRecordDecl *rd,
CharUnits vptrOffset) override;
CIRGenCallee getVirtualFunctionPointer(CIRGenFunction &cgf,
clang::GlobalDecl gd, Address thisAddr,
mlir::Type ty,
SourceLocation loc) override;
mlir::Value getVTableAddressPoint(BaseSubobject base,
const CXXRecordDecl *vtableClass) override;
@ -349,6 +353,50 @@ cir::GlobalOp CIRGenItaniumCXXABI::getAddrOfVTable(const CXXRecordDecl *rd,
return vtable;
}
CIRGenCallee CIRGenItaniumCXXABI::getVirtualFunctionPointer(
CIRGenFunction &cgf, clang::GlobalDecl gd, Address thisAddr, mlir::Type ty,
SourceLocation srcLoc) {
CIRGenBuilderTy &builder = cgm.getBuilder();
mlir::Location loc = cgf.getLoc(srcLoc);
cir::PointerType tyPtr = builder.getPointerTo(ty);
auto *methodDecl = cast<CXXMethodDecl>(gd.getDecl());
mlir::Value vtable = cgf.getVTablePtr(loc, thisAddr, methodDecl->getParent());
uint64_t vtableIndex = cgm.getItaniumVTableContext().getMethodVTableIndex(gd);
mlir::Value vfunc{};
if (cgf.shouldEmitVTableTypeCheckedLoad(methodDecl->getParent())) {
cgm.errorNYI(loc, "getVirtualFunctionPointer: emitVTableTypeCheckedLoad");
} else {
assert(!cir::MissingFeatures::emitTypeMetadataCodeForVCall());
mlir::Value vfuncLoad;
if (cgm.getItaniumVTableContext().isRelativeLayout()) {
assert(!cir::MissingFeatures::vtableRelativeLayout());
cgm.errorNYI(loc, "getVirtualFunctionPointer: isRelativeLayout");
} else {
auto vtableSlotPtr = cir::VTableGetVirtualFnAddrOp::create(
builder, loc, builder.getPointerTo(tyPtr), vtable, vtableIndex);
vfuncLoad = builder.createAlignedLoad(
loc, vtableSlotPtr, cgf.getPointerAlign().getQuantity());
}
// Add !invariant.load md to virtual function load to indicate that
// function didn't change inside vtable.
// It's safe to add it without -fstrict-vtable-pointers, but it would not
// help in devirtualization because it will only matter if we will have 2
// the same virtual function loads from the same vtable load, which won't
// happen without enabled devirtualization with -fstrict-vtable-pointers.
if (cgm.getCodeGenOpts().OptimizationLevel > 0 &&
cgm.getCodeGenOpts().StrictVTablePointers) {
cgm.errorNYI(loc, "getVirtualFunctionPointer: strictVTablePointers");
}
vfunc = vfuncLoad;
}
CIRGenCallee callee(gd, vfunc.getDefiningOp());
return callee;
}
mlir::Value
CIRGenItaniumCXXABI::getVTableAddressPoint(BaseSubobject base,
const CXXRecordDecl *vtableClass) {

View File

@ -212,6 +212,14 @@ public:
return Address(getPointer(), elementType, getAlignment());
}
void setAddress(Address address) {
assert(isSimple());
v = address.getPointer();
elementType = address.getElementType();
alignment = address.getAlignment().getQuantity();
assert(!cir::MissingFeatures::addressIsKnownNonNull());
}
const clang::Qualifiers &getQuals() const { return quals; }
clang::Qualifiers &getQuals() { return quals; }

View File

@ -1105,8 +1105,7 @@ rewriteCallOrInvoke(mlir::Operation *op, mlir::ValueRange callOperands,
auto calleeTy = op->getOperands().front().getType();
auto calleePtrTy = cast<cir::PointerType>(calleeTy);
auto calleeFuncTy = cast<cir::FuncType>(calleePtrTy.getPointee());
calleeFuncTy.dump();
converter->convertType(calleeFuncTy).dump();
llvm::append_range(adjustedCallOperands, callOperands);
llvmFnTy = cast<mlir::LLVM::LLVMFunctionType>(
converter->convertType(calleeFuncTy));
}
@ -2231,6 +2230,9 @@ static void prepareTypeConverter(mlir::LLVMTypeConverter &converter,
return llvmStruct;
});
converter.addConversion([&](cir::VoidType type) -> mlir::Type {
return mlir::LLVM::LLVMVoidType::get(type.getContext());
});
}
// The applyPartialConversion function traverses blocks in the dominance order,
@ -2385,7 +2387,8 @@ void ConvertCIRToLLVMPass::runOnOperation() {
CIRToLLVMVecSplatOpLowering,
CIRToLLVMVecTernaryOpLowering,
CIRToLLVMVTableAddrPointOpLowering,
CIRToLLVMVTableGetVPtrOpLowering
CIRToLLVMVTableGetVPtrOpLowering,
CIRToLLVMVTableGetVirtualFnAddrOpLowering
// clang-format on
>(converter, patterns.getContext());
@ -2521,6 +2524,19 @@ mlir::LogicalResult CIRToLLVMVTableGetVPtrOpLowering::matchAndRewrite(
return mlir::success();
}
mlir::LogicalResult CIRToLLVMVTableGetVirtualFnAddrOpLowering::matchAndRewrite(
cir::VTableGetVirtualFnAddrOp op, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const {
mlir::Type targetType = getTypeConverter()->convertType(op.getType());
auto eltType = mlir::LLVM::LLVMPointerType::get(rewriter.getContext());
llvm::SmallVector<mlir::LLVM::GEPArg> offsets =
llvm::SmallVector<mlir::LLVM::GEPArg>{op.getIndex()};
rewriter.replaceOpWithNewOp<mlir::LLVM::GEPOp>(
op, targetType, eltType, adaptor.getVptr(), offsets,
mlir::LLVM::GEPNoWrapFlags::inbounds);
return mlir::success();
}
mlir::LogicalResult CIRToLLVMStackSaveOpLowering::matchAndRewrite(
cir::StackSaveOp op, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const {

View File

@ -497,6 +497,17 @@ public:
mlir::ConversionPatternRewriter &) const override;
};
class CIRToLLVMVTableGetVirtualFnAddrOpLowering
: public mlir::OpConversionPattern<cir::VTableGetVirtualFnAddrOp> {
public:
using mlir::OpConversionPattern<
cir::VTableGetVirtualFnAddrOp>::OpConversionPattern;
mlir::LogicalResult
matchAndRewrite(cir::VTableGetVirtualFnAddrOp op, OpAdaptor,
mlir::ConversionPatternRewriter &) const override;
};
class CIRToLLVMStackSaveOpLowering
: public mlir::OpConversionPattern<cir::StackSaveOp> {
public:

View File

@ -46,3 +46,36 @@ A::A() {}
// NOTE: The GEP in OGCG looks very different from the one generated with CIR,
// but it is equivalent. The OGCG GEP indexes by base pointer, then
// structure, then array, whereas the CIR GEP indexes by byte offset.
void f1(A *a) {
a->f('c');
}
// CIR: cir.func{{.*}} @_Z2f1P1A(%arg0: !cir.ptr<!rec_A> {{.*}})
// CIR: %[[A_ADDR:.*]] = cir.alloca !cir.ptr<!rec_A>
// CIR: cir.store %arg0, %[[A_ADDR]]
// CIR: %[[A:.*]] = cir.load{{.*}} %[[A_ADDR]]
// CIR: %[[C_LITERAL:.*]] = cir.const #cir.int<99> : !s8i
// CIR: %[[VPTR_ADDR:.*]] = cir.vtable.get_vptr %[[A]] : !cir.ptr<!rec_A> -> !cir.ptr<!cir.vptr>
// CIR: %[[VPTR:.*]] = cir.load{{.*}} %[[VPTR_ADDR]] : !cir.ptr<!cir.vptr>, !cir.vptr
// CIR: %[[FN_PTR_PTR:.*]] = cir.vtable.get_virtual_fn_addr %[[VPTR]][0] : !cir.vptr -> !cir.ptr<!cir.ptr<!cir.func<(!cir.ptr<!rec_A>, !s8i)>>>
// CIR: %[[FN_PTR:.*]] = cir.load{{.*}} %[[FN_PTR_PTR:.*]] : !cir.ptr<!cir.ptr<!cir.func<(!cir.ptr<!rec_A>, !s8i)>>>, !cir.ptr<!cir.func<(!cir.ptr<!rec_A>, !s8i)>>
// CIR: cir.call %[[FN_PTR]](%[[A]], %[[C_LITERAL]]) : (!cir.ptr<!cir.func<(!cir.ptr<!rec_A>, !s8i)>>, !cir.ptr<!rec_A>, !s8i) -> ()
// LLVM: define{{.*}} void @_Z2f1P1A(ptr %[[ARG0:.*]])
// LLVM: %[[A_ADDR:.*]] = alloca ptr
// LLVM: store ptr %[[ARG0]], ptr %[[A_ADDR]]
// LLVM: %[[A:.*]] = load ptr, ptr %[[A_ADDR]]
// LLVM: %[[VPTR:.*]] = load ptr, ptr %[[A]]
// LLVM: %[[FN_PTR_PTR:.*]] = getelementptr inbounds ptr, ptr %[[VPTR]], i32 0
// LLVM: %[[FN_PTR:.*]] = load ptr, ptr %[[FN_PTR_PTR]]
// LLVM: call void %[[FN_PTR]](ptr %[[A]], i8 99)
// OGCG: define{{.*}} void @_Z2f1P1A(ptr {{.*}} %[[ARG0:.*]])
// OGCG: %[[A_ADDR:.*]] = alloca ptr
// OGCG: store ptr %[[ARG0]], ptr %[[A_ADDR]]
// OGCG: %[[A:.*]] = load ptr, ptr %[[A_ADDR]]
// OGCG: %[[VPTR:.*]] = load ptr, ptr %[[A]]
// OGCG: %[[FN_PTR_PTR:.*]] = getelementptr inbounds ptr, ptr %[[VPTR]], i64 0
// OGCG: %[[FN_PTR:.*]] = load ptr, ptr %[[FN_PTR_PTR]]
// OGCG: call void %[[FN_PTR]](ptr {{.*}} %[[A]], i8 {{.*}} 99)