@@ -38,15 +38,13 @@ LLVMBasedBackwardsICFG::LLVMBasedBackwardsICFG(LLVMBasedICFG &ICFG)
3838 : ForwardICFG(ICFG) {
3939 auto CgCopy = ForwardICFG.CallGraph ;
4040 boost::copy_graph (boost::make_reverse_graph (CgCopy), ForwardICFG.CallGraph );
41+ createBackwardRets ();
4142}
4243
43- LLVMBasedBackwardsICFG::LLVMBasedBackwardsICFG (
44- ProjectIRDB &IRDB, CallGraphAnalysisType CGType,
45- const std::set<std::string> &EntryPoints, LLVMTypeHierarchy *TH,
46- LLVMPointsToInfo *PT, Soundness S)
47- : ForwardICFG(IRDB, CGType, EntryPoints, TH, PT, S) {
48- auto CgCopy = ForwardICFG.CallGraph ;
49- boost::copy_graph (boost::make_reverse_graph (CgCopy), ForwardICFG.CallGraph );
44+ void LLVMBasedBackwardsICFG::createBackwardRets () {
45+ for (const auto *Function : getAllFunctions ()) {
46+ BackwardRetToFunction[BackwardRets[Function].getInstance ()] = Function;
47+ }
5048}
5149
5250bool LLVMBasedBackwardsICFG::isIndirectFunctionCall (
@@ -88,15 +86,11 @@ std::set<const llvm::Instruction *>
8886LLVMBasedBackwardsICFG::getReturnSitesOfCallAt (
8987 const llvm::Instruction *N) const {
9088 std::set<const llvm::Instruction *> ReturnSites;
91- if (const auto *Call = llvm::dyn_cast<llvm::CallInst >(N)) {
89+ if (const auto *Call = llvm::dyn_cast<llvm::CallBase >(N)) {
9290 for (const auto *Succ : this ->getSuccsOf (Call)) {
9391 ReturnSites.insert (Succ);
9492 }
9593 }
96- if (const auto *Invoke = llvm::dyn_cast<llvm::InvokeInst>(N)) {
97- ReturnSites.insert (&Invoke->getNormalDest ()->back ());
98- ReturnSites.insert (&Invoke->getUnwindDest ()->back ());
99- }
10094 return ReturnSites;
10195}
10296
@@ -105,6 +99,51 @@ LLVMBasedBackwardsICFG::allNonCallStartNodes() const {
10599 return ForwardICFG.allNonCallStartNodes ();
106100}
107101
102+ const llvm::Function *
103+ LLVMBasedBackwardsICFG::getFunctionOf (const llvm::Instruction *Stmt) const {
104+ auto BackwardRetIt = BackwardRetToFunction.find (Stmt);
105+ if (BackwardRetIt != BackwardRetToFunction.end ()) {
106+ return BackwardRetIt->second ;
107+ }
108+ return Stmt->getFunction ();
109+ }
110+
111+ std::vector<const llvm::Instruction *>
112+ LLVMBasedBackwardsICFG::getPredsOf (const llvm::Instruction *Stmt) const {
113+ auto BackwardRetIt = BackwardRetToFunction.find (Stmt);
114+ if (BackwardRetIt == BackwardRetToFunction.end ()) {
115+ return LLVMBasedBackwardCFG::getPredsOf (Stmt);
116+ }
117+ auto ExitPoints =
118+ LLVMBasedBackwardCFG::getExitPointsOf (BackwardRetIt->second );
119+ return {ExitPoints.begin (), ExitPoints.end ()};
120+ }
121+
122+ std::vector<const llvm::Instruction *>
123+ LLVMBasedBackwardsICFG::getSuccsOf (const llvm::Instruction *Stmt) const {
124+ if (isExitInst (Stmt)) {
125+ return {};
126+ }
127+ std::vector<const llvm::Instruction *> Succs =
128+ LLVMBasedBackwardCFG::getSuccsOf (Stmt);
129+ if (Succs.size () == 0 ) {
130+ assert (Stmt->getParent ()->getParent () && " Could not find parent of stmt's parent " );
131+ Succs.push_back (
132+ BackwardRets.at (Stmt->getParent ()->getParent ()).getInstance ());
133+ }
134+ return Succs;
135+ }
136+
137+ bool LLVMBasedBackwardsICFG::isExitInst (const llvm::Instruction *Stmt) const {
138+ return (Stmt->getParent () == nullptr &&
139+ BackwardRetToFunction.count (Stmt) > 0 );
140+ }
141+
142+ std::set<const llvm::Instruction *>
143+ LLVMBasedBackwardsICFG::getExitPointsOf (const llvm::Function *Fun) const {
144+ return {BackwardRets.at (Fun).getInstance ()};
145+ }
146+
108147void LLVMBasedBackwardsICFG::mergeWith (const LLVMBasedBackwardsICFG &Other) {
109148 ForwardICFG.mergeWith (Other.ForwardICFG );
110149}
0 commit comments