Skip to content

Commit 69e6455

Browse files
committed
Merge branch 'f-FixVTabForCollapsedTypes' into development
2 parents 33cc0ca + 1b558c2 commit 69e6455

10 files changed

Lines changed: 100 additions & 82 deletions

File tree

include/phasar/PhasarLLVM/ControlFlow/Resolver/Resolver.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#ifndef PHASAR_PHASARLLVM_CONTROLFLOW_RESOLVER_RESOLVER_H_
1818
#define PHASAR_PHASARLLVM_CONTROLFLOW_RESOLVER_RESOLVER_H_
1919

20+
#include <optional>
2021
#include <set>
2122
#include <string>
2223

@@ -33,7 +34,7 @@ namespace psr {
3334
class ProjectIRDB;
3435
class LLVMTypeHierarchy;
3536

36-
int getVFTIndex(const llvm::CallBase *CallSite);
37+
std::optional<unsigned> getVFTIndex(const llvm::CallBase *CallSite);
3738

3839
const llvm::StructType *getReceiverType(const llvm::CallBase *CallSite);
3940

include/phasar/PhasarLLVM/TypeHierarchy/LLVMTypeHierarchy.h

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,17 @@ class LLVMTypeHierarchy
8888
using out_edge_iterator = boost::graph_traits<bidigraph_t>::out_edge_iterator;
8989
using in_edge_iterator = boost::graph_traits<bidigraph_t>::in_edge_iterator;
9090

91+
static inline constexpr llvm::StringLiteral StructPrefix = "struct.";
92+
static inline constexpr llvm::StringLiteral ClassPrefix = "class.";
93+
static inline constexpr llvm::StringLiteral VTablePrefix = "_ZTV";
94+
static inline constexpr llvm::StringLiteral VTablePrefixDemang =
95+
"vtable for ";
96+
static inline constexpr llvm::StringLiteral TypeInfoPrefix = "_ZTI";
97+
static inline constexpr llvm::StringLiteral TypeInfoPrefixDemang =
98+
"typeinfo for ";
99+
static inline constexpr llvm::StringLiteral PureVirtualCallName =
100+
"__cxa_pure_virtual";
101+
91102
private:
92103
bidigraph_t TypeGraph;
93104
std::unordered_map<const llvm::StructType *, vertex_t> TypeVertexMap;
@@ -102,18 +113,6 @@ class LLVMTypeHierarchy
102113
// map from clearname to vtable variable
103114
std::unordered_map<std::string, const llvm::GlobalVariable *> ClearNameTVMap;
104115

105-
static const std::string StructPrefix;
106-
107-
static const std::string ClassPrefix;
108-
109-
static const std::string VTablePrefix;
110-
111-
static const std::string VTablePrefixDemang;
112-
113-
static const std::string TypeInfoPrefix;
114-
115-
static const std::string TypeInfoPrefixDemang;
116-
117116
static std::string removeStructOrClassPrefix(const llvm::StructType &T);
118117

119118
static std::string removeStructOrClassPrefix(const std::string &TypeName);

include/phasar/PhasarLLVM/TypeHierarchy/LLVMVFTable.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
namespace llvm {
2121
class Function;
22+
class ConstantStruct;
2223
} // namespace llvm
2324

2425
namespace psr {
@@ -84,6 +85,9 @@ class LLVMVFTable : public VFTable<const llvm::Function *> {
8485
end() const {
8586
return VFT.end();
8687
};
88+
89+
[[nodiscard]] static std::vector<const llvm::Function *>
90+
getVFVectorFromIRVTable(const llvm::ConstantStruct &);
8791
};
8892

8993
} // namespace psr

lib/PhasarLLVM/ControlFlow/Resolver/CHAResolver.cpp

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,8 @@ auto CHAResolver::resolveVirtualCall(const llvm::CallBase *CallSite)
3737
// Leading to SEGFAULT in Unittests. Error only when run in Debug mode
3838
// << llvmIRToString(CallSite));
3939

40-
auto VFTIdx = getVFTIndex(CallSite);
41-
if (VFTIdx < 0) {
40+
auto RetrievedVtableIndex = getVFTIndex(CallSite);
41+
if (!RetrievedVtableIndex.has_value()) {
4242
// An error occured
4343
LOG_IF_ENABLE(
4444
BOOST_LOG_SEV(lg::get(), DEBUG)
@@ -50,8 +50,10 @@ auto CHAResolver::resolveVirtualCall(const llvm::CallBase *CallSite)
5050
return {};
5151
}
5252

53+
auto VtableIndex = RetrievedVtableIndex.value();
54+
5355
LOG_IF_ENABLE(BOOST_LOG_SEV(lg::get(), DEBUG)
54-
<< "Virtual function table entry is: " << VFTIdx);
56+
<< "Virtual function table entry is: " << VtableIndex);
5557

5658
const auto *ReceiverTy = getReceiverType(CallSite);
5759

@@ -62,7 +64,7 @@ auto CHAResolver::resolveVirtualCall(const llvm::CallBase *CallSite)
6264

6365
for (const auto &FallbackTy : FallbackTys) {
6466
const auto *Target =
65-
getNonPureVirtualVFTEntry(FallbackTy, VFTIdx, CallSite);
67+
getNonPureVirtualVFTEntry(FallbackTy, VtableIndex, CallSite);
6668
if (Target) {
6769
PossibleCallees.insert(Target);
6870
}

lib/PhasarLLVM/ControlFlow/Resolver/DTAResolver.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -161,8 +161,8 @@ auto DTAResolver::resolveVirtualCall(const llvm::CallBase *CallSite)
161161
LOG_IF_ENABLE(BOOST_LOG_SEV(lg::get(), DEBUG)
162162
<< "Call virtual function: " << llvmIRToString(CallSite));
163163

164-
auto VtableIndex = getVFTIndex(CallSite);
165-
if (VtableIndex < 0) {
164+
auto RetrievedVtableIndex = getVFTIndex(CallSite);
165+
if (!RetrievedVtableIndex.has_value()) {
166166
// An error occured
167167
LOG_IF_ENABLE(BOOST_LOG_SEV(lg::get(), DEBUG)
168168
<< "Error with resolveVirtualCall : impossible to retrieve "
@@ -171,6 +171,8 @@ auto DTAResolver::resolveVirtualCall(const llvm::CallBase *CallSite)
171171
return {};
172172
}
173173

174+
auto VtableIndex = RetrievedVtableIndex.value();
175+
174176
LOG_IF_ENABLE(BOOST_LOG_SEV(lg::get(), DEBUG)
175177
<< "Virtual function table entry is: " << VtableIndex);
176178

lib/PhasarLLVM/ControlFlow/Resolver/OTFResolver.cpp

Lines changed: 30 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -90,8 +90,8 @@ auto OTFResolver::resolveVirtualCall(const llvm::CallBase *CallSite)
9090
LOG_IF_ENABLE(BOOST_LOG_SEV(lg::get(), DEBUG)
9191
<< "Call virtual function: " << llvmIRToString(CallSite));
9292

93-
auto VtableIndex = getVFTIndex(CallSite);
94-
if (VtableIndex < 0) {
93+
auto RetrievedVtableIndex = getVFTIndex(CallSite);
94+
if (!RetrievedVtableIndex.has_value()) {
9595
// An error occured
9696
LOG_IF_ENABLE(BOOST_LOG_SEV(lg::get(), DEBUG)
9797
<< "Error with resolveVirtualCall : impossible to retrieve "
@@ -100,34 +100,42 @@ auto OTFResolver::resolveVirtualCall(const llvm::CallBase *CallSite)
100100
return {};
101101
}
102102

103+
auto VtableIndex = RetrievedVtableIndex.value();
104+
103105
LOG_IF_ENABLE(BOOST_LOG_SEV(lg::get(), DEBUG)
104106
<< "Virtual function table entry is: " << VtableIndex);
105107

106108
const llvm::Value *Receiver = CallSite->getArgOperand(0);
107109

108-
// Use points-to information to resolve the indirect call
109-
auto AllocSites = PT.getReachableAllocationSites(Receiver);
110-
auto PossibleAllocatedTypes = getReachableTypes(*AllocSites);
111-
112-
// Now we must check if we have found some allocated struct types
113-
set<const llvm::StructType *> PossibleTypes;
114-
for (const auto *Type : PossibleAllocatedTypes) {
115-
if (const auto *StructType =
116-
llvm::dyn_cast<llvm::StructType>(stripPointer(Type))) {
117-
PossibleTypes.insert(StructType);
118-
}
119-
}
110+
if (CallSite->getCalledOperand() &&
111+
CallSite->getCalledOperand()->getType()->isPointerTy()) {
112+
if (const auto *FTy = llvm::dyn_cast<llvm::FunctionType>(
113+
CallSite->getCalledOperand()->getType()->getPointerElementType())) {
120114

121-
for (const auto *PossibleTypeStruct : PossibleTypes) {
122-
const auto *Target =
123-
getNonPureVirtualVFTEntry(PossibleTypeStruct, VtableIndex, CallSite);
124-
if (Target) {
125-
PossibleCallTargets.insert(Target);
115+
auto PTS = PT.getPointsToSet(CallSite->getCalledOperand(), CallSite);
116+
for (const auto *P : *PTS) {
117+
if (auto *PGV = llvm::dyn_cast<llvm::GlobalVariable>(P)) {
118+
if (PGV->hasName() &&
119+
PGV->getName().startswith(LLVMTypeHierarchy::VTablePrefix) &&
120+
PGV->hasInitializer()) {
121+
if (auto *PCS = llvm::dyn_cast<llvm::ConstantStruct>(
122+
PGV->getInitializer())) {
123+
auto VFs = LLVMVFTable::getVFVectorFromIRVTable(*PCS);
124+
if (VtableIndex >= VFs.size()) {
125+
continue;
126+
}
127+
auto *Callee = VFs[VtableIndex];
128+
if (Callee == nullptr || !Callee->hasName() ||
129+
Callee->getName() == LLVMTypeHierarchy::PureVirtualCallName) {
130+
continue;
131+
}
132+
PossibleCallTargets.insert(Callee);
133+
}
134+
}
135+
}
136+
}
126137
}
127138
}
128-
if (PossibleCallTargets.empty()) {
129-
return CHAResolver::resolveVirtualCall(CallSite);
130-
}
131139

132140
return PossibleCallTargets;
133141
}

lib/PhasarLLVM/ControlFlow/Resolver/RTAResolver.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,8 +56,8 @@ auto RTAResolver::resolveVirtualCall(const llvm::CallBase *CallSite)
5656
LOG_IF_ENABLE(BOOST_LOG_SEV(lg::get(), DEBUG)
5757
<< "Call virtual function: " << llvmIRToString(CallSite));
5858

59-
auto VtableIndex = getVFTIndex(CallSite);
60-
if (VtableIndex < 0) {
59+
auto RetrievedVtableIndex = getVFTIndex(CallSite);
60+
if (!RetrievedVtableIndex.has_value()) {
6161
// An error occured
6262
LOG_IF_ENABLE(BOOST_LOG_SEV(lg::get(), DEBUG)
6363
<< "Error with resolveVirtualCall : impossible to retrieve "
@@ -66,6 +66,8 @@ auto RTAResolver::resolveVirtualCall(const llvm::CallBase *CallSite)
6666
return {};
6767
}
6868

69+
auto VtableIndex = RetrievedVtableIndex.value();
70+
6971
LOG_IF_ENABLE(BOOST_LOG_SEV(lg::get(), DEBUG)
7072
<< "Virtual function table entry is: " << VtableIndex);
7173

lib/PhasarLLVM/ControlFlow/Resolver/Resolver.cpp

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
* Author: nicolas bellec
1515
*/
1616

17+
#include <optional>
1718
#include <set>
1819

1920
#include "llvm/IR/Constants.h"
@@ -30,23 +31,23 @@ using namespace psr;
3031

3132
namespace psr {
3233

33-
int getVFTIndex(const llvm::CallBase *CallSite) {
34+
std::optional<unsigned> getVFTIndex(const llvm::CallBase *CallSite) {
3435
// deal with a virtual member function
3536
// retrieve the vtable entry that is called
3637
const auto *Load =
3738
llvm::dyn_cast<llvm::LoadInst>(CallSite->getCalledOperand());
3839
if (Load == nullptr) {
39-
return -1;
40+
return std::nullopt;
4041
}
4142
const auto *GEP =
4243
llvm::dyn_cast<llvm::GetElementPtrInst>(Load->getPointerOperand());
4344
if (GEP == nullptr) {
44-
return -2;
45+
return std::nullopt;
4546
}
4647
if (auto *CI = llvm::dyn_cast<llvm::ConstantInt>(GEP->getOperand(1))) {
4748
return CI->getZExtValue();
4849
}
49-
return -3;
50+
return std::nullopt;
5051
}
5152

5253
const llvm::StructType *getReceiverType(const llvm::CallBase *CallSite) {

lib/PhasarLLVM/TypeHierarchy/LLVMTypeHierarchy.cpp

Lines changed: 1 addition & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -49,18 +49,6 @@ using namespace std;
4949

5050
namespace psr {
5151

52-
const std::string LLVMTypeHierarchy::StructPrefix = "struct.";
53-
54-
const std::string LLVMTypeHierarchy::ClassPrefix = "class.";
55-
56-
const std::string LLVMTypeHierarchy::VTablePrefix = "_ZTV";
57-
58-
const std::string LLVMTypeHierarchy::VTablePrefixDemang = "vtable for ";
59-
60-
const std::string LLVMTypeHierarchy::TypeInfoPrefix = "_ZTI";
61-
62-
const std::string LLVMTypeHierarchy::TypeInfoPrefixDemang = "typeinfo for ";
63-
6452
LLVMTypeHierarchy::VertexProperties::VertexProperties(
6553
const llvm::StructType *Type)
6654
: Type(Type), ReachableTypes({Type}) {}
@@ -203,29 +191,7 @@ LLVMTypeHierarchy::getVirtualFunctions(const llvm::Module &M,
203191
}
204192
if (const auto *I =
205193
llvm::dyn_cast<llvm::ConstantStruct>(TI->getInitializer())) {
206-
for (const auto &Op : I->operands()) {
207-
if (auto *CA = llvm::dyn_cast<llvm::ConstantArray>(Op)) {
208-
for (auto &COp : CA->operands()) {
209-
if (auto *CE = llvm::dyn_cast<llvm::ConstantExpr>(COp)) {
210-
std::unique_ptr<llvm::Instruction, decltype(&deleteValue)> AsI(
211-
CE->getAsInstruction(), &deleteValue);
212-
if (auto *BC = llvm::dyn_cast<llvm::BitCastInst>(AsI.get())) {
213-
// if the entry is a GlobalAlias, get its Aliasee
214-
auto *ENTRY = BC->getOperand(0);
215-
while (auto *GA = llvm::dyn_cast<llvm::GlobalAlias>(ENTRY)) {
216-
ENTRY = GA->getAliasee();
217-
}
218-
219-
if (ENTRY->hasName()) {
220-
if (auto *F = M.getFunction(ENTRY->getName())) {
221-
VFS.push_back(F);
222-
}
223-
}
224-
}
225-
}
226-
}
227-
}
228-
}
194+
VFS = LLVMVFTable::getVFVectorFromIRVTable(*I);
229195
}
230196
}
231197
}

lib/PhasarLLVM/TypeHierarchy/LLVMVFTable.cpp

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
#include <utility>
1313

1414
#include "llvm/IR/Function.h"
15+
#include "llvm/IR/GlobalAlias.h"
16+
#include "llvm/IR/Operator.h"
1517

1618
#include "phasar/PhasarLLVM/TypeHierarchy/LLVMVFTable.h"
1719

@@ -48,4 +50,35 @@ nlohmann::json LLVMVFTable::getAsJson() const {
4850
return J;
4951
}
5052

53+
std::vector<const llvm::Function *>
54+
LLVMVFTable::getVFVectorFromIRVTable(const llvm::ConstantStruct &VT) {
55+
std::vector<const llvm::Function *> VFS;
56+
for (const auto &Op : VT.operands()) {
57+
if (const auto *CA = llvm::dyn_cast<llvm::ConstantArray>(Op)) {
58+
// Start iterating at offset 2, because offset 0 is vbase offset, offset 1
59+
// is RTTI
60+
for (auto It = std::next(CA->operands().begin(), 2);
61+
It != CA->operands().end(); ++It) {
62+
const auto &COp = *It;
63+
if (const auto *CE = llvm::dyn_cast<llvm::ConstantExpr>(COp)) {
64+
if (const auto *BC = llvm::dyn_cast<llvm::BitCastOperator>(CE)) {
65+
// if the entry is a GlobalAlias, get its Aliasee
66+
auto *Entry = BC->getOperand(0);
67+
while (auto *GA = llvm::dyn_cast<llvm::GlobalAlias>(Entry)) {
68+
Entry = GA->getAliasee();
69+
}
70+
auto *F = llvm::dyn_cast<llvm::Function>(Entry);
71+
VFS.push_back(F);
72+
} else {
73+
VFS.push_back(nullptr);
74+
}
75+
} else {
76+
VFS.push_back(nullptr);
77+
}
78+
}
79+
}
80+
}
81+
return VFS;
82+
}
83+
5184
} // namespace psr

0 commit comments

Comments
 (0)