Skip to content

Commit b8cb1b9

Browse files
author
Martin Mory
committed
fix VTables for collapsed types. llvm-link collapsed multiple C++ types with identical members to one single LLVM type. The current virtual function call resolution then only considers the VTable of the LLVM type that represents all of the collapsed type. Now, we look for all VTables that the function pointer aliases with and extract the possible callees from them.
1 parent 90fafd8 commit b8cb1b9

4 files changed

Lines changed: 61 additions & 43 deletions

File tree

include/phasar/PhasarLLVM/TypeHierarchy/LLVMVFTable.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
namespace llvm {
2121
class Function;
22+
class ConstantStruct;
2223
} // namespace llvm
2324

2425
namespace psr {
@@ -84,6 +85,9 @@ class LLVMVFTable : public VFTable<const llvm::Function *> {
8485
end() const {
8586
return VFT.end();
8687
};
88+
89+
[[nodiscard]] static std::vector<const llvm::Function *>
90+
getVFVectorFromIRVTable(const llvm::ConstantStruct *);
8791
};
8892

8993
} // namespace psr

lib/PhasarLLVM/ControlFlow/Resolver/OTFResolver.cpp

Lines changed: 25 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -105,29 +105,34 @@ auto OTFResolver::resolveVirtualCall(const llvm::CallBase *CallSite)
105105

106106
const llvm::Value *Receiver = CallSite->getArgOperand(0);
107107

108-
// Use points-to information to resolve the indirect call
109-
auto AllocSites = PT.getReachableAllocationSites(Receiver);
110-
auto PossibleAllocatedTypes = getReachableTypes(*AllocSites);
111-
112-
// Now we must check if we have found some allocated struct types
113-
set<const llvm::StructType *> PossibleTypes;
114-
for (const auto *Type : PossibleAllocatedTypes) {
115-
if (const auto *StructType =
116-
llvm::dyn_cast<llvm::StructType>(stripPointer(Type))) {
117-
PossibleTypes.insert(StructType);
118-
}
119-
}
108+
if (CallSite->getCalledOperand() &&
109+
CallSite->getCalledOperand()->getType()->isPointerTy()) {
110+
if (const auto *FTy = llvm::dyn_cast<llvm::FunctionType>(
111+
CallSite->getCalledOperand()->getType()->getPointerElementType())) {
120112

121-
for (const auto *PossibleTypeStruct : PossibleTypes) {
122-
const auto *Target =
123-
getNonPureVirtualVFTEntry(PossibleTypeStruct, VtableIndex, CallSite);
124-
if (Target) {
125-
PossibleCallTargets.insert(Target);
113+
auto PTS = PT.getPointsToSet(CallSite->getCalledOperand(), CallSite);
114+
for (const auto *P : *PTS) {
115+
if (auto *PGV = llvm::dyn_cast<llvm::GlobalVariable>(P)) {
116+
if (PGV->hasName() && PGV->getName().startswith("_ZTV") &&
117+
PGV->hasInitializer()) {
118+
if (auto *PCS = llvm::dyn_cast<llvm::ConstantStruct>(
119+
PGV->getInitializer())) {
120+
auto VFs = LLVMVFTable::getVFVectorFromIRVTable(PCS);
121+
if (VtableIndex < 0 || VtableIndex >= VFs.size()) {
122+
continue;
123+
}
124+
auto *Callee = VFs[VtableIndex];
125+
if (Callee == nullptr || !Callee->hasName() ||
126+
Callee->getName() == "__cxa_pure_virtual") {
127+
continue;
128+
}
129+
PossibleCallTargets.insert(Callee);
130+
}
131+
}
132+
}
133+
}
126134
}
127135
}
128-
if (PossibleCallTargets.empty()) {
129-
return CHAResolver::resolveVirtualCall(CallSite);
130-
}
131136

132137
return PossibleCallTargets;
133138
}

lib/PhasarLLVM/TypeHierarchy/LLVMTypeHierarchy.cpp

Lines changed: 1 addition & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -203,29 +203,7 @@ LLVMTypeHierarchy::getVirtualFunctions(const llvm::Module &M,
203203
}
204204
if (const auto *I =
205205
llvm::dyn_cast<llvm::ConstantStruct>(TI->getInitializer())) {
206-
for (const auto &Op : I->operands()) {
207-
if (auto *CA = llvm::dyn_cast<llvm::ConstantArray>(Op)) {
208-
for (auto &COp : CA->operands()) {
209-
if (auto *CE = llvm::dyn_cast<llvm::ConstantExpr>(COp)) {
210-
std::unique_ptr<llvm::Instruction, decltype(&deleteValue)> AsI(
211-
CE->getAsInstruction(), &deleteValue);
212-
if (auto *BC = llvm::dyn_cast<llvm::BitCastInst>(AsI.get())) {
213-
// if the entry is a GlobalAlias, get its Aliasee
214-
auto *ENTRY = BC->getOperand(0);
215-
while (auto *GA = llvm::dyn_cast<llvm::GlobalAlias>(ENTRY)) {
216-
ENTRY = GA->getAliasee();
217-
}
218-
219-
if (ENTRY->hasName()) {
220-
if (auto *F = M.getFunction(ENTRY->getName())) {
221-
VFS.push_back(F);
222-
}
223-
}
224-
}
225-
}
226-
}
227-
}
228-
}
206+
VFS = LLVMVFTable::getVFVectorFromIRVTable(I);
229207
}
230208
}
231209
}

lib/PhasarLLVM/TypeHierarchy/LLVMVFTable.cpp

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
#include <utility>
1313

1414
#include "llvm/IR/Function.h"
15+
#include "llvm/IR/GlobalAlias.h"
16+
#include "llvm/IR/Operator.h"
1517

1618
#include "phasar/PhasarLLVM/TypeHierarchy/LLVMVFTable.h"
1719

@@ -48,4 +50,33 @@ nlohmann::json LLVMVFTable::getAsJson() const {
4850
return J;
4951
}
5052

53+
std::vector<const llvm::Function *>
54+
LLVMVFTable::getVFVectorFromIRVTable(const llvm::ConstantStruct *VT) {
55+
std::vector<const llvm::Function *> VFS;
56+
for (const auto &Op : VT->operands()) {
57+
if (auto *CA = llvm::dyn_cast<llvm::ConstantArray>(Op)) {
58+
for (auto It = CA->operands().begin() + 2; It != CA->operands().end();
59+
++It) {
60+
auto &COp = *It;
61+
if (auto *CE = llvm::dyn_cast<llvm::ConstantExpr>(COp)) {
62+
if (auto *BC = llvm::dyn_cast<llvm::BitCastOperator>(CE)) {
63+
// if the entry is a GlobalAlias, get its Aliasee
64+
auto *ENTRY = BC->getOperand(0);
65+
while (auto *GA = llvm::dyn_cast<llvm::GlobalAlias>(ENTRY)) {
66+
ENTRY = GA->getAliasee();
67+
}
68+
auto *F = llvm::dyn_cast<llvm::Function>(ENTRY);
69+
VFS.push_back(F);
70+
} else {
71+
VFS.push_back(nullptr);
72+
}
73+
} else {
74+
VFS.push_back(nullptr);
75+
}
76+
}
77+
}
78+
}
79+
return VFS;
80+
}
81+
5182
} // namespace psr

0 commit comments

Comments
 (0)