diff --git a/src/exploit_iq_commons/utils/chain_of_calls_retriever.py b/src/exploit_iq_commons/utils/chain_of_calls_retriever.py index 0233564a..7f61f5a4 100644 --- a/src/exploit_iq_commons/utils/chain_of_calls_retriever.py +++ b/src/exploit_iq_commons/utils/chain_of_calls_retriever.py @@ -227,6 +227,8 @@ def __find_caller_function_dfs(self, document_function: Document, function_packa if parents: direct_parents.extend(parents) function_name_to_search = self.language_parser.get_function_name(document_function) + if not function_name_to_search: + return None if function_name_to_search == self.language_parser.get_constructor_method_name(): function_name_to_search = self.language_parser.get_class_name_from_class_function(document_function) function_file_name = document_function.metadata.get('source') @@ -319,6 +321,8 @@ def get_possible_docs(self, function_name_to_search: str, package: str, exclusio (self.language_parser.is_function(doc) or self.language_parser.is_script_language()) and not self._is_doc_excluded(doc, exclusions)] + if not function_name_to_search: + return [] return [doc for doc in filter_1 if doc.page_content.__contains__(f"{function_name_to_search}(")] def __find_caller_functions_bfs(self, document_function: Document, function_package: str, @@ -344,6 +348,8 @@ def __find_caller_functions_bfs(self, document_function: Document, function_pack # direct_parents.extend([function_package]) # gets list of documents to search in only from parents of function' package. function_name_to_search = self.language_parser.get_function_name(document_function) + if not function_name_to_search: + return [] function_file_name = document_function.metadata.get('source') relevant_docs_to_search_in = list() # Search for caller functions only at parents according to dependency tree. 
@@ -365,7 +371,12 @@ def __find_caller_functions_bfs(self, document_function: Document, function_pack file_name = doc.metadata.get('source') if doc.metadata.get('state') == "invalid": continue - func_name = self.language_parser.get_function_name(doc) + try: + func_name = self.language_parser.get_function_name(doc) + except ValueError: + continue + if not func_name: + continue # check for same doc if (function_name_to_search == func_name) and (file_name == function_file_name): continue @@ -438,6 +449,8 @@ def _breadth_first_search(self, matching_documents: List[Document], target_funct logger.debug("get_relevant_documents: invalid function %s", target_doc.metadata['source']) continue function_name = self.language_parser.get_function_name(target_doc) + if not function_name: + continue function_file = target_doc.metadata.get('source') hashed_value = calculate_hashable_string_for_function(function_file, function_name) @@ -577,9 +590,9 @@ def get_relevant_documents(self, query: str) -> tuple[List[Document], bool]: root_package = [key for (key, value) in self.tree_dict.items() if ROOT_LEVEL_SENTINEL in value] prefix_of_3rd_parties_libs = self.language_parser.dir_name_for_3rd_party_packages() # find all parents ( all importing packages) of the ibput package so we'll have candidate pkgs to search in. 
- parents = list({self.language_parser.get_package_names(doc)[1] for doc in importing_docs if + parents = list({self.language_parser.get_package_names(doc)[-1] for doc in importing_docs if doc.metadata['source'].startswith( - prefix_of_3rd_parties_libs) and self.language_parser.get_package_names(doc)[1] + prefix_of_3rd_parties_libs) and self.language_parser.get_package_names(doc)[-1] in self.tree_dict.keys()}) for doc in importing_docs: if not doc.metadata.get('source').startswith(prefix_of_3rd_parties_libs): @@ -651,7 +664,11 @@ def __find_initial_function(self, function_name: str, package_name: str, documen self.get_functions_for_package(package_name, relevant_docs, sources_location_packages=True), self.get_functions_for_package(package_name, relevant_docs, sources_location_packages=False), ): - if function_name.lower() == self.language_parser.get_function_name(document).lower(): + doc_func_name = self.language_parser.get_function_name(document) + if not doc_func_name: + logger.warning("Skipping document with empty function name: %s", document.metadata.get('source', '')) + continue + if function_name.lower() == doc_func_name.lower(): package_exclusions.append(document) return document @@ -684,7 +701,10 @@ def print_call_hierarchy(self, call_hierarchy_list: list[Document]) -> list[str] package_name = package_function.metadata['source'] try: function_name = self.language_parser.get_function_name(package_function) - current_level = f"(package={package_name},function={function_name},depth={i})" + if not function_name: + current_level = f"(document={package_name},depth={i})" + else: + current_level = f"(package={package_name},function={function_name},depth={i})" except ValueError: current_level = f"(document={package_name},depth={i})" results.append(current_level) diff --git a/src/exploit_iq_commons/utils/document_embedding.py b/src/exploit_iq_commons/utils/document_embedding.py index ff068171..ad805cd1 100644 --- a/src/exploit_iq_commons/utils/document_embedding.py 
+++ b/src/exploit_iq_commons/utils/document_embedding.py @@ -147,6 +147,10 @@ def lazy_parse(self, blob: Blob) -> typing.Iterator[Document]: ) return + segmenter_cls = self.LANGUAGE_SEGMENTERS.get(language) + if segmenter_cls is not None and hasattr(segmenter_cls, "should_skip") and isinstance(blob.source, str) and segmenter_cls.should_skip(blob.source): + return + if self.parser_threshold >= len(code.splitlines()): yield Document( page_content=code, diff --git a/src/exploit_iq_commons/utils/functions_parsers/javascript_functions_parser.py b/src/exploit_iq_commons/utils/functions_parsers/javascript_functions_parser.py index 6ed0895d..e2b37612 100644 --- a/src/exploit_iq_commons/utils/functions_parsers/javascript_functions_parser.py +++ b/src/exploit_iq_commons/utils/functions_parsers/javascript_functions_parser.py @@ -23,73 +23,119 @@ logger = LoggingFactory.get_agent_logger(__name__) +# Keywords whose control-flow syntax resembles a function call — if(cond), +# for(...), while(...), new Foo(), typeof x — causing false positives in the +# fallback pattern of get_function_name. Keywords like case, default, export, +# import, in, super are deliberately excluded: they never appear with call-like +# name(...) syntax, and they ARE valid method names (obj.default(), query.in()). +_JS_KEYWORDS = frozenset({ + 'if', 'for', 'while', 'switch', 'catch', 'with', 'do', 'else', + 'return', 'throw', 'new', 'delete', 'typeof', 'void', + 'try', 'finally', 'function', 'class', 'extends', + 'break', 'continue', 'var', 'let', 'const', 'yield', 'await', + 'debugger', 'instanceof', +}) + class JavaScriptFunctionsParser(LanguageFunctionsParser): + def __init__(self): + self._class_hierarchy_cache: dict[str, tuple[str | None, str | None]] = {} + self._class_hierarchy_cache_key: int = -1 + def get_function_name(self, function: Document) -> str: - """ - Extract function name from JavaScript code. 
- - Handles various JavaScript function patterns: - - function myFunc() {...} - - const myFunc = () => {...} - - async function myFunc() {...} - - class methods - - arrow functions without parentheses: const name = param => ... - - computed property methods: [Symbol.iterator]() {...} - """ + """Extract function name from JavaScript/TypeScript code.""" content = function.page_content if not self.is_function(function): raise ValueError('Only function document is supported') - # Try to match function declarations: function name(...) or async function name(...) - match = re.search(r'(?:async\s+)?function\s+(\w+)\s*\(', content) + if re.match(r'export\s+default\s+(?:async\s+)?function\s*\*?\s*\(', content): + return '' + + if re.match(r'export\s+default\s+(?:async\s+)?\(', content): + return '' + + if re.match(r'export\s+default\s+(?:async\s+)?[\w$]+\s*=>', content): + return '' + + body_start = content.find('{') + header = content[:body_start] if body_start != -1 else content + # For arrow functions with destructured params like ({a, b}) => {body}, + # the first '{' is inside the params, not the body start. + # Extend header to include '=>' only when followed by '{' or end (top-level arrow). 
+ arrow_header = header + if body_start != -1 and not header.rstrip().endswith(')') and not re.match(r'(?:(?:async|static|get|set)\s+)*[\w$]+\s*\(', header): + search_from = body_start + while search_from < len(content): + arrow_pos = content.find('=>', search_from) + if arrow_pos == -1 or arrow_pos > body_start + 500: + break + after_arrow = content[arrow_pos + 2:].lstrip() + if after_arrow.startswith('{') or not after_arrow: + arrow_header = content[:arrow_pos + 2] + break + search_from = arrow_pos + 2 + + match = re.search(r'(?:async\s+)?function\b\s*\*?\s*([\w$]+)\s*\(', header) + if match: + return match.group(1) + + if '=>' in arrow_header: + match = re.search(r'(?:const|let|var)\s+([\w$]+)\s*=', arrow_header) + if match: + return match.group(1) + + match = re.search(r'(?:const|let|var)\s+([\w$]+)\s*=\s*(?:async\s+)?[\w$]+\s*=>', header) if match: return match.group(1) - # Try to match arrow functions with parentheses: const name = (...) => or const name = async (...) => - match = re.search(r'(?:const|let|var)\s+(\w+)\s*=\s*(?:async\s*)?\([^)]*\)\s*=>', content) + match = re.search(r'(?:const|let|var)\s+([\w$]+)\s*=\s*(?:async\s+)?function\s*\(', header) if match: return match.group(1) - # Try to match arrow functions without parentheses (single parameter): const name = param => ... - # This pattern matches: const name = identifier => - match = re.search(r'(?:const|let|var)\s+(\w+)\s*=\s*(?:async\s*)?\w+\s*=>', content) + match = re.search(r'(?:const|let|var)\s+([\w$]+)\s*=\s*[\w$.]+\s*\(', header) if match: return match.group(1) - # Try to match function expressions: const name = function(...) { - match = re.search(r'(?:const|let|var)\s+(\w+)\s*=\s*(?:async\s+)?function\s*\(', content) + match = re.search(r'([\w$]+)\s*=\s*(?:async\s+)?function\s*\*?\s*[\w$]*\s*\(', header) if match: return match.group(1) - # Try to match wrapper calls: var name = wrapper(function(...) 
{ or var name = wrapper(() => - match = re.search(r'(?:const|let|var)\s+(\w+)\s*=\s*\w+\s*\(', content) + if '=>' in arrow_header: + match = re.search(r'([\w$]+)\s*=(?!>)', arrow_header) + if match: + return match.group(1) + + match = re.search(r'([\w$]+)\s*=\s*(?:async\s+)?[\w$]+\s*=>', header) if match: return match.group(1) - # Try to match property assignment functions: obj.name = function(...) { - match = re.search(r'(?:[\w.]+)\.(\w+)\s*=\s*(?:async\s+)?function\s*\(', content) + if '=>' in arrow_header: + match = re.search(r'^([\w$]+)\s*:', arrow_header, re.MULTILINE) + if match: + return match.group(1) + + match = re.search(r'^([\w$]+)\s*:\s*(?:async\s+)?[\w$]+\s*=>', content, re.MULTILINE) if match: return match.group(1) - # Try to match property assignment arrow functions: obj.name = (...) => - match = re.search(r'(?:[\w.]+)\.(\w+)\s*=\s*(?:async\s*)?\([^)]*\)\s*=>', content) + match = re.search(r'^([\w$]+)\s*:\s*(?:async\s+)?function\s*\(', content, re.MULTILINE) if match: return match.group(1) - # Try to match computed property methods: [Symbol.iterator]() { or [expr]() { - # Return the computed expression inside brackets - match = re.search(r'^\s*\[([^\]]+)\]\s*\([^)]*\)\s*\{', content, re.MULTILINE) + match = re.search(r"""^["']([\w$.!-]+)["']\s*[:(\s]""", content, re.MULTILINE) if match: return match.group(1) - # Try to match class methods: methodName(...) { or async methodName(...) 
{ - match = re.search(r'(?:async\s+)?(\w+)\s*\([^)]*\)\s*\{', content) + match = re.search(r'^\s*\*?\s*(?:static\s+)?(?:async\s+)?(?:get\s+|set\s+)?\[([^\]]+)\]\s*\(', content, re.MULTILINE) if match: return match.group(1) + for match in re.finditer(r'(?:static\s+)?(?:async\s+)?(?:get\s+|set\s+)?([\w$]+)\s*[<(]', content): + if match.group(1) not in _JS_KEYWORDS: + return match.group(1) + logger.warning(f"Could not extract function name from: {content[:100]}") return '' @@ -119,7 +165,10 @@ def is_function(self, function: Document) -> bool: content = function.page_content.strip() - if re.match(r'^\s*(?:export\s+(?:default\s+)?)?class\s+\w+', content): + if re.match(r'^\s*(?:export\s+(?:default\s+)?)?class\b', content): + return False + + if re.match(r'^\s*(?:export\s+)?(?:declare\s+)?(?:const\s+)?(?:interface|enum)\s+', content): return False return True @@ -130,31 +179,36 @@ def supported_files_extensions(self) -> list[str]: def _get_function_calls(self, caller_function: Document, callee_function_name: str, code_documents: dict[str, Document] = None) -> list[str]: """ Extract all function calls matching the given callee_function_name from caller_function. - + Detects both: - Direct function calls: functionName() or obj.functionName() - Callback references: setTimeout(functionName, ...) 
or arr.forEach(functionName) - + Also handles aliased imports in JavaScript: - ES6: import { originalName as alias } from 'module' - CommonJS: const { originalName: alias } = require('module') - + Args: caller_function: Document containing the caller function code callee_function_name: Name of the function to find calls to code_documents: Optional dict mapping source paths to full file Documents (for alias resolution) - + Returns: List of function call patterns found (e.g., ['func', 'obj.func']) """ + if not callee_function_name: + return [] + content = caller_function.page_content content = '\n'.join([line if not self.is_comment_line(line) else ''for line in content.splitlines()]) - direct_call_pattern = rf'((?:[\w.()]+\.)?(? bool: - calls = self._get_function_calls(caller_function, callee_function_name) + calls = self._get_function_calls(caller_function, callee_function_name, code_documents) if not calls: return False @@ -242,10 +296,59 @@ def search_for_called_function(self, caller_function: Document, callee_function_ return False + @staticmethod + def _build_class_hierarchy(code_documents: dict[str, Document]) -> dict[str, tuple[str | None, str | None]]: + """ + Build a class hierarchy index from code_documents in a single pass. + + Returns: + dict mapping child_class -> (extends_clause, immediate_parent) + extends_clause is the raw text after 'extends' (for mixin matching), + immediate_parent is the resolved parent class name. 
+ """ + hierarchy: dict[str, tuple[str | None, str | None]] = {} + + es6_pattern = re.compile(r'class\s+([\w$]+)\s+extends\s+([^{]+)') + proto_patterns = [ + re.compile(r'(?:util\.)?inherits\s*\(\s*([\w$]+)\s*,\s*([\w$]+)\s*\)'), + re.compile(r'([\w$]+)\.prototype\s*=\s*Object\.create\s*\(\s*([\w$]+)\.prototype\s*\)'), + re.compile(r'Object\.setPrototypeOf\s*\(\s*([\w$]+)\.prototype\s*,\s*([\w$]+)\.prototype\s*\)'), + ] + + for doc in code_documents.values(): + content = doc.page_content + + for match in es6_pattern.finditer(content): + child = match.group(1) + extends_clause = match.group(2).strip() + if innermost := re.search(r'([\w$]+)\s*\)+\s*$', extends_clause): + parent = innermost.group(1) + elif simple := re.match(r'^([\w$]+)\s*$', extends_clause): + parent = simple.group(1) + else: + parent = None + hierarchy[child] = (extends_clause, parent) + + for pattern in proto_patterns: + for match in pattern.finditer(content): + child = match.group(1) + parent = match.group(2) + if child not in hierarchy: + hierarchy[child] = (None, parent) + + return hierarchy + + def _get_class_hierarchy(self, code_documents: dict[str, Document]) -> dict[str, tuple[str | None, str | None]]: + cache_key = id(code_documents) + if self._class_hierarchy_cache_key != cache_key: + self._class_hierarchy_cache = self._build_class_hierarchy(code_documents) + self._class_hierarchy_cache_key = cache_key + return self._class_hierarchy_cache + def _is_subclass_of(self, child_class: str, parent_class: str, code_documents: dict[str, Document], visited: set = None) -> bool: """ - Check if child_class extends parent_class by searching simplified_code documents. + Check if child_class extends parent_class using pre-built hierarchy index. 
Supports: - Simple: class X extends Y - Mixin: class X extends Mixin(Y) @@ -260,95 +363,71 @@ def _is_subclass_of(self, child_class: str, parent_class: str, code_documents: d visited = set() if child_class in visited: - return False # Prevent infinite recursion on circular references + return False visited.add(child_class) - es6_class_pattern = rf'class\s+{re.escape(child_class)}\s+extends\s+([^{{]+)' + hierarchy = self._get_class_hierarchy(code_documents) + entry = hierarchy.get(child_class) + if not entry: + return False - for doc in code_documents.values(): - if match := re.search(es6_class_pattern, doc.page_content): - extends_clause = match.group(1).strip() - # Direct match in extends clause (handles mixins too) - if re.search(rf'\b{re.escape(parent_class)}\b', extends_clause): - return True + extends_clause, immediate_parent = entry + + if extends_clause and re.search(rf'\b{re.escape(parent_class)}\b', extends_clause): + return True - immediate_parent = self._get_parent(child_class, code_documents) if immediate_parent: - # Direct match: immediate parent is the target if immediate_parent == parent_class: return True - # Transitive check: follow the chain return self._is_subclass_of(immediate_parent, parent_class, code_documents, visited) return False def _get_direct_parent(self, child_class: str, code_documents: dict[str, Document]) -> str | None: - """ - Extract the immediate parent class name from ES6 class extends syntax. 
- - Handles: - - Simple: class X extends Y -> returns Y - - Mixin: class X extends Mixin(Y) -> returns Y (the innermost/base class) - - Chained: class X extends Mixin1(Mixin2(Y)) -> returns Y - """ - class_pattern = rf'class\s+{re.escape(child_class)}\s+extends\s+([^{{]+)' - - for doc in code_documents.values(): - if match := re.search(class_pattern, doc.page_content): - extends_clause = match.group(1).strip() - # Find the last identifier before closing parentheses - if innermost_match := re.search(r'(\w+)\s*\)+\s*$', extends_clause): - return innermost_match.group(1) - # Simple case: class X extends Y (no parentheses) - if parent_match := re.match(r'^(\w+)\s*$', extends_clause): - return parent_match.group(1) + """Extract the immediate parent from ES6 class extends syntax using the hierarchy index.""" + hierarchy = self._get_class_hierarchy(code_documents) + entry = hierarchy.get(child_class) + if entry and entry[0] is not None: + return entry[1] return None def _get_prototype_parent(self, child_class: str, code_documents: dict[str, Document]) -> str | None: - """ - Extract parent class from prototype-based inheritance patterns. 
- - Handles: - - util.inherits(Child, Parent) - - inherits(Child, Parent) - - Child.prototype = Object.create(Parent.prototype) - - Object.setPrototypeOf(Child.prototype, Parent.prototype) - """ - patterns = [ - # util.inherits(Child, Parent) or inherits(Child, Parent) - rf'(?:util\.)?inherits\s*\(\s*{re.escape(child_class)}\s*,\s*(\w+)\s*\)', - # Child.prototype = Object.create(Parent.prototype) - rf'{re.escape(child_class)}\.prototype\s*=\s*Object\.create\s*\(\s*(\w+)\.prototype\s*\)', - # Object.setPrototypeOf(Child.prototype, Parent.prototype) - rf'Object\.setPrototypeOf\s*\(\s*{re.escape(child_class)}\.prototype\s*,\s*(\w+)\.prototype\s*\)', - ] - - for doc in code_documents.values(): - for pattern in patterns: - if match := re.search(pattern, doc.page_content): - return match.group(1) + """Extract parent from prototype-based inheritance using the hierarchy index.""" + hierarchy = self._get_class_hierarchy(code_documents) + entry = hierarchy.get(child_class) + if entry and entry[0] is None and entry[1] is not None: + return entry[1] return None def _get_parent(self, child_class: str, code_documents: dict[str, Document]) -> str | None: - return self._get_direct_parent(child_class, code_documents) or \ - self._get_prototype_parent(child_class, code_documents) + hierarchy = self._get_class_hierarchy(code_documents) + entry = hierarchy.get(child_class) + if entry: + return entry[1] + return None def create_map_of_local_vars(self, functions_methods_documents: list[Document]) -> dict[str, dict]: mappings = {} - + for func_method in functions_methods_documents: - func_key = f"{self.get_function_name(func_method)}@{func_method.metadata['source']}" + try: + func_name = self.get_function_name(func_method) + except ValueError: + continue + if not func_name: + continue + func_key = f"{func_name}@{func_method.metadata.get('source', '?')}" content = func_method.page_content - + all_vars = {} - + # Extract parameters from function signature param_match = 
re.search(r'\(([^)]*)\)', content) if param_match: all_vars.update(self._parse_declarations(param_match.group(1), is_param=True)) - + all_vars['return_types'] = [] - + # Extract local variables from function body if param_match: first_brace = content.find('{', param_match.end()) @@ -362,13 +441,13 @@ def create_map_of_local_vars(self, functions_methods_documents: list[Document]) if not self.is_comment_line(statement): if match := re.match(r'(const|let|var)\s+(.+)', statement): all_vars.update(self._parse_declarations(match.group(2), is_param=False)) - + # Add 'this' reference for class/object methods if class_name := self.get_class_name_from_class_function(func_method): all_vars['this'] = {"value": f'{class_name}()', "type": class_name} - + mappings[func_key] = all_vars - + return mappings @staticmethod @@ -724,6 +803,8 @@ def _trace_variable_to_value(self, variable_name: str, lines: list[str], depth: return '' def is_package_imported(self, code_content: str, identifier: str, callee_package: str = "") -> bool: + if not identifier: + return False if callee_package and callee_package not in code_content: return False @@ -745,7 +826,7 @@ def is_package_imported(self, code_content: str, identifier: str, callee_package else: # import { template } from 'lodash' # import * as lodash from 'lodash' - before_from = line.split('from')[0] + before_from = line.split(' from ')[0] if re.search(rf'\b{re.escape(identifier)}\b', before_from): return True else: @@ -768,7 +849,7 @@ def is_package_imported(self, code_content: str, identifier: str, callee_package if 'export' in line and 'from' in line: if callee_package and (f"'{callee_package}'" in line or f'"{callee_package}"' in line): if identifier: - before_from = line.split('from')[0] + before_from = line.split(' from ')[0] if re.search(rf'\b{re.escape(identifier)}\b', before_from): return True else: @@ -911,11 +992,42 @@ def _check_package_reexport(self, function_name: str, source_file: str, document return False + @staticmethod 
+ def _is_exported_by_name(name: str, full_file_content: str) -> bool: + """Check if a name is exported via ES6 named export, export default, or CommonJS.""" + escaped = re.escape(name) + + # ES6: export default name + if re.search(rf'export\s+default\s+{escaped}\b', full_file_content): + return True + + # ES6: export { name } or export { x as name } + if re.search(rf'export\s+\{{[^}}]*\b{escaped}\b[^}}]*\}}', full_file_content): + return True + + # CommonJS: module.exports = name + if re.search(rf'module\.exports\s*=\s*{escaped}\b', full_file_content): + return True + + # CommonJS: module.exports = { name } or module.exports = { name: name } + if re.search(rf'module\.exports\s*=\s*\{{[^}}]*\b{escaped}\b[^}}]*\}}', full_file_content): + return True + + # CommonJS: module.exports.name = ... + if re.search(rf'module\.exports\.{escaped}\s*=', full_file_content): + return True + + # CommonJS: exports.name = ... + if re.search(rf'exports\.{escaped}\s*=', full_file_content): + return True + + return False + @staticmethod def _is_exportable_class(class_name: str, full_file_content: str) -> bool: """ Check if a class is exported using any export syntax. 
- + Handles: - ES6: export class ClassName - ES6: export default ClassName or export default class ClassName @@ -927,39 +1039,14 @@ def _is_exportable_class(class_name: str, full_file_content: str) -> bool: # ES6: export class ClassName or export default class ClassName if re.search(rf'export\s+(?:default\s+)?class\s+{re.escape(class_name)}\b', full_file_content): return True - - # ES6: export default ClassName (after class definition) - if re.search(rf'export\s+default\s+{re.escape(class_name)}\b', full_file_content): - return True - - # ES6: export { ClassName } - export_pattern = rf'export\s+\{{[^}}]*\b{re.escape(class_name)}\b[^}}]*\}}' - if re.search(export_pattern, full_file_content): - return True - - # CommonJS: module.exports = ClassName - if re.search(rf'module\.exports\s*=\s*{re.escape(class_name)}\b', full_file_content): - return True - - # CommonJS: module.exports = { ClassName } or module.exports = { ClassName: ClassName } - if re.search(rf'module\.exports\s*=\s*\{{[^}}]*\b{re.escape(class_name)}\b[^}}]*\}}', full_file_content): - return True - - # CommonJS: module.exports.ClassName = ClassName - if re.search(rf'module\.exports\.{re.escape(class_name)}\s*=', full_file_content): - return True - - # CommonJS: exports.ClassName = ClassName - if re.search(rf'exports\.{re.escape(class_name)}\s*=', full_file_content): - return True - - return False + + return JavaScriptFunctionsParser._is_exported_by_name(class_name, full_file_content) @staticmethod def _is_exportable_function(function_name: str, full_file_content: str) -> bool: """ Check if a standalone function is exported using any export syntax. 
- + Handles: - ES6: export function functionName - ES6: export default functionName @@ -969,32 +1056,15 @@ def _is_exportable_function(function_name: str, full_file_content: str) -> bool: - CommonJS: module.exports.functionName = functionName - CommonJS: exports.functionName = functionName """ - # ES6: export { functionName } or export { func as functionName } - export_pattern = rf'export\s+\{{[^}}]*\b{re.escape(function_name)}\b[^}}]*\}}' - if re.search(export_pattern, full_file_content): - return True - - # ES6: export default functionName - if re.search(rf'export\s+default\s+{re.escape(function_name)}\b', full_file_content): - return True - - # CommonJS: module.exports = functionName - if re.search(rf'module\.exports\s*=\s*{re.escape(function_name)}\b', full_file_content): - return True - - # CommonJS: module.exports = { functionName } or module.exports = { functionName: functionName } - if re.search(rf'module\.exports\s*=\s*\{{[^}}]*\b{re.escape(function_name)}\b[^}}]*\}}', full_file_content): + # ES6: export function functionName / export async function functionName + if re.search(rf'export\s+(?:async\s+)?function\b\s*\*?\s*{re.escape(function_name)}\b', full_file_content): return True - - # CommonJS: module.exports.functionName = ... - if re.search(rf'module\.exports\.{re.escape(function_name)}\s*=', full_file_content): - return True - - # CommonJS: exports.functionName = ... 
- if re.search(rf'exports\.{re.escape(function_name)}\s*=', full_file_content): + + # ES6: export const/let/var functionName = + if re.search(rf'export\s+(?:const|let|var)\s+{re.escape(function_name)}\b', full_file_content): return True - - return False + + return JavaScriptFunctionsParser._is_exported_by_name(function_name, full_file_content) def document_imports_package(self, documents: dict[str, Document], package_name: str) -> list[Document]: diff --git a/src/exploit_iq_commons/utils/functions_parsers/lang_functions_parsers.py b/src/exploit_iq_commons/utils/functions_parsers/lang_functions_parsers.py index 7c7f7146..65cc2c6d 100644 --- a/src/exploit_iq_commons/utils/functions_parsers/lang_functions_parsers.py +++ b/src/exploit_iq_commons/utils/functions_parsers/lang_functions_parsers.py @@ -173,6 +173,8 @@ def is_doc_type(self, doc: Document) -> bool: return doc.page_content.startswith(self.get_type_reserved_word()) def filter_docs_by_func_pkg_name(self, function_name: str, package_name: str, documents: list[Document]) -> list[Document]: + if not function_name: + return [] relevant_docs = [ doc for doc in documents if doc.metadata.get('source').__contains__(package_name) and diff --git a/src/exploit_iq_commons/utils/javascript_extended_segmenter.py b/src/exploit_iq_commons/utils/javascript_extended_segmenter.py index d5168a36..28bb9211 100644 --- a/src/exploit_iq_commons/utils/javascript_extended_segmenter.py +++ b/src/exploit_iq_commons/utils/javascript_extended_segmenter.py @@ -13,386 +13,264 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Any +import threading from typing import List -from typing import Tuple -import esprima -from langchain_community.document_loaders.parsers.language.javascript import JavaScriptSegmenter +from tree_sitter_languages import get_language +from langchain_community.document_loaders.parsers.language.tree_sitter_segmenter import TreeSitterSegmenter from exploit_iq_commons.logging.loggers_factory import LoggingFactory logger = LoggingFactory.get_agent_logger(__name__) +_JS_EXCLUDE_DIRS: frozenset[str] = frozenset({ + "dist", "build/static", "coverage", ".nyc_output", +}) -class ExtendedJavaScriptSegmenter(JavaScriptSegmenter): - """Extended JavaScript segmenter that handles shebang and ES optional chaining.""" +_JS_EXCLUDE_SUFFIXES: tuple[str, ...] = ( + ".min.js", ".min.css", ".bundle.js", +) + +# Tree-sitter query matching JavaScript constructs that the agent tools +# need for function extraction and call-chain analysis. +# +# Named exports (export function foo, export class Bar, export const x = () => {}) +# are captured by the inner declaration patterns — the export keyword is NOT +# included in those chunks. Only anonymous default exports need explicit +# export_statement patterns since they have no matching inner declaration. 
# Tree-sitter query selecting the nodes that become "chunks":
#   * function / generator / class declarations
#   * let/const/var declarations whose value is a function, an arrow function,
#     a call receiving a function argument, or an object literal with methods
#   * export statements wrapping a function, arrow function or class expression
#   * assignment statements whose right-hand side is a function
JS_CHUNK_QUERY = """
[
  (function_declaration) @function
  (generator_function_declaration) @generator

  (class_declaration) @class

  (lexical_declaration
    (variable_declarator
      value: [(arrow_function) (function)])) @func_expr

  (variable_declaration
    (variable_declarator
      value: [(arrow_function) (function)])) @var_func_expr

  (variable_declaration
    (variable_declarator
      value: (call_expression
        arguments: (arguments [(function) (arrow_function)])))) @wrapped_func

  (lexical_declaration
    (variable_declarator
      value: (call_expression
        arguments: (arguments [(function) (arrow_function)])))) @wrapped_func_const

  (lexical_declaration
    (variable_declarator
      value: (object
        [(method_definition) (pair value: [(function) (arrow_function)])]))) @obj_with_methods

  (variable_declaration
    (variable_declarator
      value: (object
        [(method_definition) (pair value: [(function) (arrow_function)])]))) @var_obj_with_methods

  (export_statement (function)) @export_default_func
  (export_statement (arrow_function)) @export_default_arrow
  (export_statement (class)) @export_default_class

  (expression_statement
    (assignment_expression
      right: [(function) (arrow_function)])) @assign_func
]
"""


class ExtendedJavaScriptSegmenter(TreeSitterSegmenter):
    """JavaScript segmenter using tree-sitter for fast, modern-syntax-aware parsing.

    Handles optional chaining, nullish coalescing, top-level await, and all
    other modern JS syntax that esprima-python cannot parse. Each instance
    parses its source at most once and reuses the tree across
    ``is_valid`` / ``extract_functions_classes`` / ``simplify_code`` calls.
    """

    # Class-level caches, filled on first instantiation and shared by all
    # instances. ``_js_language`` is assigned last inside the lock, so a
    # non-None value implies both compiled queries are already available.
    _js_language = None
    _js_chunk_query = None
    _js_error_query = None
    _init_lock = threading.Lock()

    @staticmethod
    def should_skip(path: str) -> bool:
        """Return True if the file at *path* is a JS build artifact that should not be indexed.

        Anything below a ``node_modules`` directory is never skipped. Otherwise
        a file is skipped when its name ends with one of ``_JS_EXCLUDE_SUFFIXES``
        or when its leading one or two directories are listed in
        ``_JS_EXCLUDE_DIRS``.
        """
        # "." segments carry no location information; dropping them lets
        # "./dist/app.js" be classified exactly like "dist/app.js".
        parts = [seg for seg in path.replace("\\", "/").split("/") if seg != "."]
        if "node_modules" in parts[:-1]:
            return False
        if path.endswith(_JS_EXCLUDE_SUFFIXES):
            return True
        # Only root-level directories are checked — deeper dist/build
        # directories are treated as legitimate source.
        if len(parts) > 2 and f"{parts[0]}/{parts[1]}" in _JS_EXCLUDE_DIRS:
            return True
        return len(parts) > 1 and parts[0] in _JS_EXCLUDE_DIRS

    def __init__(self, code: str):
        """Store *code*, lazily build the shared language/query caches, and flag shebang files.

        Args:
            code: JavaScript source text to segment.
        """
        super().__init__(code)
        cls = type(self)
        if cls._js_language is None:
            with cls._init_lock:
                # Re-check under the lock: another thread may have finished
                # the initialisation while this one was waiting.
                if cls._js_language is None:
                    lang = get_language("javascript")
                    cls._js_chunk_query = lang.query(JS_CHUNK_QUERY)
                    cls._js_error_query = lang.query("(ERROR) @error")
                    cls._js_language = lang
        self._cached_tree = None

        if code.startswith("#!"):
            logger.warning("File contains a shebang line. Skipping parsing.")
            self.skip_file = True
        else:
            self.skip_file = False

    # -- TreeSitterSegmenter contract ----------------------------------------

    def get_language(self):
        """Return the cached tree-sitter JavaScript language object."""
        return self._js_language

    def get_chunk_query(self) -> str:
        """Return the tree-sitter query text that selects chunk nodes."""
        return JS_CHUNK_QUERY

    def make_line_comment(self, text: str) -> str:
        """Wrap *text* in a JavaScript single-line comment."""
        return f"// {text}"

    # -- Cached tree ---------------------------------------------------------

    def _get_tree(self):
        """Parse ``self.code`` on first use and return the cached tree afterwards."""
        if self._cached_tree is None:
            from tree_sitter import Parser

            parser = Parser()
            parser.set_language(self._js_language)
            self._cached_tree = parser.parse(bytes(self.code, "utf-8"))
        return self._cached_tree

    def _iter_chunk_nodes(self, tree):
        """Yield chunk-query captures, skipping any that reuse an already-claimed line.

        A node is dropped when any line of its span belongs to a node yielded
        earlier, so nested or repeated captures of the same region appear once.
        """
        claimed_lines = set()
        for node, _capture_name in self._js_chunk_query.captures(tree.root_node):
            span = range(node.start_point[0], node.end_point[0] + 1)
            if not claimed_lines.isdisjoint(span):
                continue
            claimed_lines.update(span)
            yield node

    # -- Public API ----------------------------------------------------------
    # Overrides of TreeSitterSegmenter methods to use the cached tree/query and
    # to extend extract_functions_classes with class/object method extraction.

    def is_valid(self) -> bool:
        """Return True when the file is not skipped and its tree has no ERROR nodes."""
        if self.skip_file:
            return False
        return len(self._js_error_query.captures(self._get_tree().root_node)) == 0

    def extract_functions_classes(self) -> List[str]:
        """Return the source text of every chunk, followed by individual methods.

        Returns:
            Chunk texts in capture order, then class methods, then object-literal
            methods; the latter two carry a trailing ``//(class: Name)`` line.
            An empty list when the file is skipped.
        """
        if self.skip_file:
            return []

        tree = self._get_tree()
        chunks: List[str] = [
            node.text.decode("utf-8") for node in self._iter_chunk_nodes(tree)
        ]
        chunks.extend(self._extract_class_methods(tree))
        chunks.extend(self._extract_object_methods(tree))
        return chunks

    def simplify_code(self) -> str:
        """Return the source with each chunk collapsed to a ``// Code for: ...`` line.

        The first line of every chunk is replaced by a comment quoting it; the
        remaining lines of the chunk are removed. Skipped files are returned
        unchanged.
        """
        if self.skip_file:
            return self.code

        simplified_lines = list(self.source_lines)
        dropped = set()  # indices of chunk body lines to leave out of the result
        for node in self._iter_chunk_nodes(self._get_tree()):
            first_line = node.start_point[0]
            last_line = node.end_point[0]
            simplified_lines[first_line] = self.make_line_comment(
                f"Code for: {self.source_lines[first_line]}"
            )
            dropped.update(range(first_line + 1, last_line + 1))

        return "\n".join(
            line for index, line in enumerate(simplified_lines) if index not in dropped
        )

    # -- Private helpers -----------------------------------------------------

    def _extract_class_methods(self, tree) -> List[str]:
        """Extract individual methods from class declarations, annotated with class name."""
        methods: List[str] = []
        for node_type in ("class_declaration", "class"):
            for class_node in self._walk_type(tree.root_node, node_type):
                class_name = self._child_text(class_node, "identifier") or "AnonymousClass"
                body = self._child_by_type(class_node, "class_body")
                if body is None:
                    continue
                methods.extend(
                    f"{member.text.decode('utf-8')}\n//(class: {class_name})"
                    for member in body.children
                    if member.type == "method_definition"
                )
        return methods

    def _extract_object_methods(self, tree) -> List[str]:
        """Extract individual methods from object literals in variable declarations."""
        methods: List[str] = []
        for decl_type in ("lexical_declaration", "variable_declaration"):
            for decl_node in self._walk_type(tree.root_node, decl_type):
                for declarator in decl_node.children:
                    if declarator.type != "variable_declarator":
                        continue
                    obj_name = self._child_text(declarator, "identifier") or "AnonymousObject"
                    obj_node = self._child_by_type(declarator, "object")
                    if obj_node is None:
                        continue
                    for prop in obj_node.children:
                        if prop.type == "method_definition":
                            is_method = True
                        elif prop.type == "pair":
                            # A pair counts only when it directly holds a function value.
                            is_method = (
                                self._child_by_type(prop, "function") is not None
                                or self._child_by_type(prop, "arrow_function") is not None
                            )
                        else:
                            is_method = False
                        if is_method:
                            # Same annotation format as class methods.
                            methods.append(f"{prop.text.decode('utf-8')}\n//(class: {obj_name})")
        return methods

    @staticmethod
    def _walk_type(root, node_type: str):
        """Yield top-level children matching node_type, including inside export_statement."""
        for child in root.children:
            if child.type == node_type:
                yield child
            elif child.type == "export_statement":
                for grandchild in child.children:
                    if grandchild.type == node_type:
                        yield grandchild

    @staticmethod
    def _child_by_type(node, child_type: str):
        """Return the first direct child of *node* whose type is *child_type*, or None."""
        for child in node.children:
            if child.type == child_type:
                return child
        return None

    @staticmethod
    def _child_text(node, child_type: str) -> str | None:
        """Return the decoded text of the first direct child of *child_type*, or None."""
        child = ExtendedJavaScriptSegmenter._child_by_type(node, child_type)
        return child.text.decode("utf-8") if child is not None else None
def test_anonymous_default_export_class_methods():
    """Methods of 'export default class { ... }' get //(class: AnonymousClass) annotations."""
    code = """
export default class {
    connect() {
        return this.db.connect();
    }

    disconnect() {
        this.db.close();
    }
}
"""
    segmenter = ExtendedJavaScriptSegmenter(code)
    functions = segmenter.extract_functions_classes()

    # Expected: the class chunk plus 2 methods carrying the AnonymousClass annotation.
    method_functions = [f for f in functions if '//(class: AnonymousClass)' in f]
    assert len(method_functions) == 2
    assert any('connect' in f for f in method_functions)
    assert any('disconnect' in f for f in method_functions)


def test_named_default_export_class_methods():
    """Methods of 'export default class Foo { ... }' get //(class: Foo) annotations."""
    code = """
export default class Foo {
    bar() {
        return 1;
    }
}
"""
    segmenter = ExtendedJavaScriptSegmenter(code)
    functions = segmenter.extract_functions_classes()

    method_functions = [f for f in functions if '//(class: Foo)' in f]
    assert len(method_functions) == 1
    assert 'bar' in method_functions[0]


# =============================================================================
# should_skip: only root-level artifact directories are excluded
# =============================================================================

class TestShouldSkipEdgeCases:
    """Regression tests: should_skip must exclude root-level artifact
    directories (dist/, build/static/) and artifact suffixes, but keep
    node_modules content and nested directories such as 'src/dist/'."""

    def test_node_modules_dist_preserved(self):
        """Third-party dist/ inside node_modules should NOT be skipped."""
        assert not ExtendedJavaScriptSegmenter.should_skip(
            "node_modules/lodash/dist/lodash.js"
        )

    def test_node_modules_nested_dist_preserved(self):
        """Nested dist/ inside node_modules should NOT be skipped."""
        assert not ExtendedJavaScriptSegmenter.should_skip(
            "node_modules/@babel/core/dist/index.js"
        )

    def test_root_dist_skipped(self):
        """Root-level dist/ IS a build artifact — should be skipped."""
        assert ExtendedJavaScriptSegmenter.should_skip("dist/bundle.js")

    def test_root_build_static_skipped(self):
        """Root-level build/static/ IS a build artifact — should be skipped."""
        assert ExtendedJavaScriptSegmenter.should_skip("build/static/js/main.js")

    def test_minified_file_skipped(self):
        """.min.js files outside node_modules should be skipped."""
        assert ExtendedJavaScriptSegmenter.should_skip("lib/utils.min.js")

    def test_src_dist_should_not_be_skipped(self):
        """dist/ inside src/ is likely application source, not a build artifact.
        Guards against matching artifact directory names below the root level."""
        assert not ExtendedJavaScriptSegmenter.should_skip(
            "src/dist/utils.js"
        ), "src/dist/utils.js is likely app source, not a build artifact"

    def test_lib_dist_should_not_be_skipped(self):
        """dist/ inside lib/ is likely a sub-module, not a build artifact."""
        assert not ExtendedJavaScriptSegmenter.should_skip(
            "lib/dist/helpers.js"
        ), "lib/dist/helpers.js is likely app source, not a build artifact"

    def test_regular_app_file_not_skipped(self):
        """Regular application files should not be skipped."""
        assert not ExtendedJavaScriptSegmenter.should_skip("src/app/main.js")

    def test_bundle_file_skipped(self):
        """.bundle.js files outside node_modules should be skipped."""
        assert ExtendedJavaScriptSegmenter.should_skip("assets/app.bundle.js")


# =============================================================================
# Thread-safe class-level init
# =============================================================================

class TestThreadSafeInit:
    """Class-level caches (_js_language, _js_chunk_query, _js_error_query)
    are initialized on first instantiation. These tests guard against a
    thread observing _js_language set while _js_chunk_query is still None."""

    def test_concurrent_init_produces_valid_segmenters(self):
        """Create segmenters from multiple threads and verify they all parse correctly."""
        from concurrent.futures import ThreadPoolExecutor, as_completed

        code = "function hello() { return 42; }"
        results = []
        errors = []

        def create_and_extract(idx):
            seg = ExtendedJavaScriptSegmenter(code)
            funcs = seg.extract_functions_classes()
            return idx, funcs

        with ThreadPoolExecutor(max_workers=8) as pool:
            futures = [pool.submit(create_and_extract, i) for i in range(20)]
            for f in as_completed(futures):
                try:
                    idx, funcs = f.result()
                    results.append((idx, funcs))
                except Exception as exc:
                    # Collected rather than raised so every failure is reported at once.
                    errors.append(exc)

        assert not errors, f"Concurrent init raised: {errors}"
        assert len(results) == 20
        for idx, funcs in results:
            assert len(funcs) == 1, f"Thread {idx}: expected 1 function, got {len(funcs)}"

    def test_class_caches_are_set_after_init(self):
        """After any instantiation, all three class caches must be non-None."""
        ExtendedJavaScriptSegmenter("var x = 1;")
        assert ExtendedJavaScriptSegmenter._js_language is not None
        assert ExtendedJavaScriptSegmenter._js_chunk_query is not None
        assert ExtendedJavaScriptSegmenter._js_error_query is not None


# =============================================================================
# should_skip: the node_modules check must take precedence over the
# .min.js / .bundle.js suffix check
# =============================================================================

class TestShouldSkipMinJsInNodeModules:
    """Third-party .min.js files inside node_modules should NOT be skipped."""

    def test_minified_in_node_modules_should_not_be_skipped(self):
        """node_modules/jquery/dist/jquery.min.js is legitimate third-party source."""
        assert not ExtendedJavaScriptSegmenter.should_skip(
            "node_modules/jquery/dist/jquery.min.js"
        ), ".min.js inside node_modules should not be skipped"

    def test_minified_bundle_in_node_modules_should_not_be_skipped(self):
        """node_modules/lodash/lodash.min.js is legitimate."""
        assert not ExtendedJavaScriptSegmenter.should_skip(
            "node_modules/lodash/lodash.min.js"
        )

    def test_bundle_in_node_modules_should_not_be_skipped(self):
        """node_modules/chart.js/dist/chart.bundle.js is legitimate."""
        assert not ExtendedJavaScriptSegmenter.should_skip(
            "node_modules/chart.js/dist/chart.bundle.js"
        )

    def test_app_level_minified_still_skipped(self):
        """App-level .min.js files are build artifacts — should be skipped."""
        assert ExtendedJavaScriptSegmenter.should_skip("lib/utils.min.js")

    def test_app_level_bundle_still_skipped(self):
        """App-level .bundle.js files are build artifacts — should be skipped."""
        assert ExtendedJavaScriptSegmenter.should_skip("assets/app.bundle.js")
item.toUpperCase()", + "transform" + ), + ( + "callback: async result => {\n await process(result);\n}", + "callback" + ), + # Object literal property with function expression + ( + "initialize: function() {\n this.setup();\n}", + "initialize" + ), + ( + "process: async function(data) {\n await save(data);\n}", + "process" + ), + # Property assignment with single-param arrow + ( + "module.exports.handler = event => {\n return event.body;\n}", + "handler" + ), + ( + "self.callback = async result => {\n await log(result);\n}", + "callback" + ), + # Dollar sign in identifiers + ( + "function $apply(scope) {\n scope.$digest();\n}", + "$apply" + ), + ( + "const $timeout = (fn, delay) => {\n setTimeout(fn, delay);\n}", + "$timeout" + ), + ( + "function not$(value) {\n return !value;\n}", + "not$" + ), + ( + "const in$ = (key, obj) => key in obj;", + "in$" + ), + # String-keyed object property + ( + '"content-type"(value) {\n return value.toLowerCase();\n}', + "content-type" + ), + ( + "'set-cookie': (value) => {\n return parse(value);\n}", + "set-cookie" + ), + # Static getter/setter with computed property + ( + "static [Symbol.species]() {\n return Array;\n}", + "Symbol.species" + ), + ( + "static async [Symbol.asyncIterator]() {\n yield 1;\n}", + "Symbol.asyncIterator" + ), + ( + "get [Symbol.toStringTag]() {\n return 'Custom';\n}", + "Symbol.toStringTag" + ), + # TypeScript generics in function signatures + ( + "function assert(value: T): asserts value {\n if (!value) throw new Error();\n}", + "assert" + ), + ( + "function identity(arg: T): T {\n return arg;\n}", + "identity" + ), + # Constructor with destructured default params + ( + "constructor({name = 'default', age = 0} = {}) {\n this.name = name;\n}//(class: Config)", + "constructor" + ), + # Wrapper call pattern (const name = wrapper(...)) + ( + "const memoized = memoize(computeValue)", + "memoized" + ), + ( + "const debounced = debounce(handleInput, 300)", + "debounced" + ), +]) +def 
test_get_function_name_new_patterns(parser, code, expected_name): + """Test get_function_name with generator and TypeScript patterns.""" + doc = Document( + page_content=code, + metadata={'source': 'test.js', 'content_type': 'functions_classes'} + ) + result = parser.get_function_name(doc) + assert result == expected_name + + +@pytest.mark.parametrize("code", [ + # TypeScript interface declarations + "interface UserConfig {\n name: string;\n age: number;\n}", + "export interface ApiResponse {\n data: any;\n status: number;\n}", + "declare interface Window {\n customProp: string;\n}", + # TypeScript enum declarations + "enum Color {\n Red,\n Green,\n Blue\n}", + "export enum Direction {\n Up,\n Down,\n Left,\n Right\n}", + "const enum HttpStatus {\n OK = 200,\n NotFound = 404\n}", + "export const enum LogLevel {\n Debug,\n Info,\n Error\n}", + "declare enum Platform {\n Web,\n Mobile\n}", +]) +def test_is_function_rejects_typescript_constructs(parser, code): + """Test that is_function returns False for TypeScript interface/enum declarations.""" + doc = Document( + page_content=code, + metadata={'source': 'test.ts', 'content_type': 'functions_classes'} + ) + assert parser.is_function(doc) is False + + @pytest.mark.parametrize("code", [ # Class declarations should raise ValueError "class MyClass {\n constructor() {}\n}", @@ -1440,7 +1572,7 @@ class TypeTest { ), ]) def test_is_exported_function(parser, description, function_content, file_content, expected): - source = "node_modules/packageurl-js/index.js", + source = "node_modules/packageurl-js/index.js" func_doc = Document( page_content=function_content, metadata={'source': source, 'content_type': 'functions_classes'} @@ -2194,18 +2326,18 @@ def test_is_exportable_function(parser, function_name, file_content, expected, d True, ), ( - "Dynamic import without await - .then pattern", + "Dynamic import without await - .then pattern (empty identifier → False)", """import('lodash').then(mod => mod.template());""", "", 
"lodash", - True, + False, ), ( - "Dynamic import side-effect only", + "Dynamic import side-effect only (empty identifier → False)", """await import('side-effect-pkg');""", "", "side-effect-pkg", - True, + False, ), ( "Dynamic import with backticks", @@ -2399,11 +2531,11 @@ def test_is_exportable_function(parser, function_name, file_content, expected, d True, ), ( - "Dynamic import without assignment", + "Dynamic import without assignment (empty identifier → False)", """(await import('lodash')).template();""", "", "lodash", - True, + False, ), ( "Multiple dynamic imports - check first", @@ -2449,169 +2581,1922 @@ def test_is_package_imported_dynamic_imports(parser, description, code_content, # Tests for is_package_imported - Multi-function imports and word boundary edge cases # ============================================================================ -@pytest.mark.parametrize("description,code_content,identifier,callee_package,expected", [ - # ======================================================================== - # Multi-function imports - should match exact identifiers - # ======================================================================== - ( - "Multi-function import - exact match for function1", - """import { function1, function2 } from "./utils.js";""", - "function1", - "./utils.js", - True, - ), - ( - "Multi-function import - exact match for function2", - """import { function1, function2 } from "./utils.js";""", - "function2", - "./utils.js", - True, - ), - # ======================================================================== - # Substring edge cases - should NOT match partial identifiers - # ======================================================================== +# ============================================================================ +# Tests for empty-name CCA guards +# ============================================================================ + +def test_get_function_calls_empty_name_returns_empty(parser): + """_get_function_calls 
must return [] when callee_function_name is empty.""" + caller = Document( + page_content="function test() {\n anything();\n something.method();\n}", + metadata={'source': 'test.js', 'content_type': 'functions_classes'} + ) + assert parser._get_function_calls(caller, '') == [] + assert parser._get_function_calls(caller, None) == [] + + +def test_is_package_imported_empty_identifier_returns_false(parser): + """is_package_imported must return False when identifier is empty.""" + code = "import { template } from 'lodash';\nconst x = require('express');" + assert parser.is_package_imported(code, '', 'lodash') is False + assert parser.is_package_imported(code, '', '') is False + + +def test_filter_docs_by_func_pkg_name_empty_name_returns_empty(parser): + """filter_docs_by_func_pkg_name must return [] when function_name is empty.""" + docs = [ + Document(page_content="function template() {}", metadata={'source': 'node_modules/lodash/template.js'}), + Document(page_content="function parse() {}", metadata={'source': 'node_modules/lodash/parse.js'}), + ] + assert parser.filter_docs_by_func_pkg_name('', 'lodash', docs) == [] + + +def test_search_for_called_function_empty_callee_name(parser): + """search_for_called_function returns False when callee_function_name is empty. + + Without the guard on _get_function_calls, an empty callee name would + cause the regex to match spuriously and return True for any caller that imports + the callee package. 
+ """ + caller = Document( + page_content="function handler() {\n lodash.template('hello');\n}", + metadata={'source': 'app/handler.js', 'content_type': 'functions_classes'} + ) + callee = Document( + page_content="function template(str) { return str; }", + metadata={'source': 'node_modules/lodash/template.js', 'content_type': 'functions_classes'} + ) + code_docs = { + 'app/handler.js': Document( + page_content="import { template } from 'lodash';\nfunction handler() { lodash.template('hello'); }", + metadata={'source': 'app/handler.js', 'content_type': 'simplified_code'} + ) + } + + result = parser.search_for_called_function( + caller_function=caller, + callee_function_name='', + callee_function=callee, + callee_function_package='lodash', + code_documents=code_docs, + type_documents=[], + callee_function_file_name='node_modules/lodash/template.js', + fields_of_types={}, + functions_local_variables_index={}, + documents_of_functions=[], + type_inheritance={} + ) + assert result is False + + +# ============================================================================ +# Tests for dotted call chain RHS in variable declarations +# ============================================================================ + +@pytest.mark.parametrize("code,expected_name", [ ( - "Multi-function import - 'func' should NOT match 'function1'", - """import { function1, function2 } from "./utils.js";""", - "func", - "./utils.js", - False, + "var changes = ts.textChanges.ChangeTracker.with(context, function(tracker) {\n tracker.apply();\n})", + "changes" ), ( - "Multi-function import - 'add' should NOT match 'addHandler'", - """import { addHandler, removeHandler } from "./utils.js";""", - "add", - "./utils.js", - False, + "const result = obj.method.call(this, arg)", + "result" ), ( - "Single function import - 'parse' should NOT match 'parseAsync'", - """import { parseAsync } from "some-lib";""", - "parse", - "some-lib", - False, + "let parser = JSON.parse(data)", + "parser" ), - # 
======================================================================== - # CommonJS require - substring edge cases - # ======================================================================== +]) +def test_get_function_name_dotted_rhs(parser, code, expected_name): + """Variable declarations with dotted call chains on the RHS should extract the variable name.""" + doc = Document( + page_content=code, + metadata={'source': 'test.js', 'content_type': 'functions_classes'} + ) + assert parser.get_function_name(doc) == expected_name + + +# ============================================================================ +# Tests for bare assignment function +# ============================================================================ + +@pytest.mark.parametrize("code,expected_name", [ ( - "CommonJS destructured require - exact match", - """const { function1, function2 } = require('./utils.js');""", - "function1", - "./utils.js", - True, + "name = function(x, y) {\n return x + y;\n}", + "name" ), ( - "CommonJS destructured require - 'func' should NOT match 'function1'", - """const { function1, function2 } = require('./utils.js');""", - "func", - "./utils.js", - False, + "callback = function() {\n return 1;\n}", + "callback" ), - # ======================================================================== - # Re-exports - substring edge cases - # ======================================================================== ( - "Re-export - exact match", - """export { function1, function2 } from "./utils.js";""", - "function1", - "./utils.js", - True, + "handler = async function(req, res) {\n await handle(req);\n}", + "handler" ), ( - "Re-export - 'func' should NOT match 'function1'", - """export { function1, function2 } from "./utils.js";""", - "func", - "./utils.js", - False, + "gen = function*() {\n yield 1;\n}", + "gen" ), ]) -def test_is_package_imported_word_boundaries(parser, description, code_content, identifier, callee_package, expected): - """Test is_package_imported with 
word boundary edge cases to prevent substring false positives.""" - result = parser.is_package_imported(code_content, identifier, callee_package) - assert result == expected, f"{description}: expected {expected}, got {result}" +def test_get_function_name_bare_assignment_function(parser, code, expected_name): + """Bare assignment to function expression (no const/let/var) should extract the LHS name.""" + doc = Document( + page_content=code, + metadata={'source': 'test.js', 'content_type': 'functions_classes'} + ) + assert parser.get_function_name(doc) == expected_name -@pytest.mark.parametrize("function_content,expected_name", [ - # Object methods now use //(class: name) pattern +# ============================================================================ +# Tests for bare assignment arrow +# ============================================================================ + +@pytest.mark.parametrize("code,expected_name", [ ( - "fetchData(url) {\n return fetch(url);\n}\n//(class: apiUtils)", - "apiUtils" + "resolve = (id, options) => {\n return lookup(id);\n}", + "resolve" ), ( - "format: function(data) {\n return JSON.stringify(data);\n}\n//(class: helpers)", - "helpers" + "process = async (data) => {\n await save(data);\n}", + "process" ), - # Standalone function (no annotation) ( - "function regularFunc() {\n return true;\n}", - None + "transform = x => x * 2", + "transform" ), ]) -def test_get_class_name_for_object_methods(parser, function_content, expected_name): - """Test that object methods use the same //(class: name) pattern as classes.""" +def test_get_function_name_bare_assignment_arrow(parser, code, expected_name): + """Bare assignment to arrow function (no const/let/var) should extract the LHS name.""" doc = Document( - page_content=function_content, + page_content=code, metadata={'source': 'test.js', 'content_type': 'functions_classes'} ) - result = parser.get_class_name_from_class_function(doc) - assert result == expected_name + assert 
parser.get_function_name(doc) == expected_name -@pytest.mark.parametrize("description,function_content,file_content,expected", [ - ( - "Object method should be exported when object is in module.exports", - "fetchData(url) {\n return fetch(url);\n}\n//(class: apiUtils)", - """const apiUtils = { - fetchData(url) { - return fetch(url); - } -}; +# ============================================================================ +# Tests for anonymous export default +# ============================================================================ -module.exports = { apiUtils }; -""", - True, +@pytest.mark.parametrize("code", [ + "export default function() {\n return 42;\n}", + "export default async function() {\n await fetch('/api');\n}", + "export default function*() {\n yield 1;\n}", +]) +def test_get_function_name_anonymous_export_default(parser, code): + """Anonymous export default should return empty string.""" + doc = Document( + page_content=code, + metadata={'source': 'test.js', 'content_type': 'functions_classes'} + ) + assert parser.get_function_name(doc) == '' + + +# ============================================================================ +# Tests for generator computed property +# ============================================================================ + +def test_get_function_name_generator_computed_property(parser): + """Generator computed property with * prefix should extract the symbol expression.""" + doc = Document( + page_content="*[Symbol.iterator]() {\n for (const item of this.items) {\n yield item;\n }\n}", + metadata={'source': 'test.js', 'content_type': 'functions_classes'} + ) + assert parser.get_function_name(doc) == 'Symbol.iterator' + + +# ============================================================================ +# Tests for keyword rejection in fallback pattern +# ============================================================================ + +@pytest.mark.parametrize("code,expected_name", [ + ( + "doStuff() {\n if (x) {\n return;\n 
}\n}", + "doStuff" ), ( - "Object method should be exported with ES6 named export", - "format(data) {\n return JSON.stringify(data);\n}\n//(class: helpers)", - """const helpers = { - format(data) { - return JSON.stringify(data); - } -}; - -export { helpers }; -""", - True, + "render() {\n for (let i = 0; i < 10; i++) {\n items.push(i);\n }\n}", + "render" ), ( - "Object method should NOT be exported when object is not exported", - "internalMethod() {\n return 'internal';\n}\n//(class: privateUtils)", - """const privateUtils = { - internalMethod() { - return 'internal'; - } -}; - -// No exports -""", - False, + "handle() {\n while (queue.length) {\n process(queue.shift());\n }\n}", + "handle" + ), + ( + "execute() {\n switch(type) {\n case('a'):\n break;\n }\n}", + "execute" ), ]) -def test_is_exported_function_for_objects(parser, description, function_content, file_content, expected): - """Test is_exported_function correctly handles object methods with //(class: name) annotation.""" - source = "node_modules/test-pkg/index.js" - func_doc = Document( - page_content=function_content, - metadata={'source': source, 'content_type': 'functions_classes'} +def test_get_function_name_keyword_rejection(parser, code, expected_name): + """Keyword filtering should skip JS keywords and find the real method name.""" + doc = Document( + page_content=code, + metadata={'source': 'test.js', 'content_type': 'functions_classes'} ) - full_file_doc = Document( - page_content=file_content, - metadata={'source': source, 'content_type': 'simplified_code'} + assert parser.get_function_name(doc) == expected_name + + +@pytest.mark.parametrize("code", [ + "var reporter = {\n onTestStart: function(test) { if (test.pending) return; }\n}", + "if (x) {\n for (y in z) {\n while (true) { break; }\n }\n}", +]) +def test_get_function_name_all_keywords_returns_empty(parser, code): + """When only keywords match the fallback pattern, return empty string.""" + doc = Document( + page_content=code, + 
metadata={'source': 'test.js', 'content_type': 'functions_classes'} ) - documents_of_full_sources = {source: full_file_doc} - - result = parser.is_exported_function(func_doc, documents_of_full_sources) + assert parser.get_function_name(doc) == '' + + +# ============================================================================ +# Tests for keyword set trimming: removed keywords that are valid method names +# ============================================================================ + +@pytest.mark.parametrize("code, expected_name", [ + # 'default' removed from _JS_KEYWORDS — valid method in commander.js etc. + ("default(value, description) {\n this._defaultValue = value;\n}", "default"), + # 'case' removed — can be a method name (e.g., case(value) in switch builders) + ("case(value) {\n this.cases.push(value);\n}", "case"), + # 'export' removed — can be a method name (e.g., export(data) in serializers) + ("export(data) {\n return JSON.stringify(data);\n}", "export"), + # 'import' removed — can be a method name + ("import(module) {\n return require(module);\n}", "import"), + # 'in' removed — can be a method name (e.g., in(collection) in query builders) + ("in(collection) {\n return this.where(collection);\n}", "in"), + # 'of' removed — can be a method name + ("of(items) {\n return new Collection(items);\n}", "of"), + # 'super' removed — can be used as method name in some patterns + ("super(args) {\n return parent.call(this, args);\n}", "super"), + # 'this' removed — can be used as method name + ("this(args) {\n return self.init(args);\n}", "this"), +]) +def test_get_function_name_removed_keywords_are_valid_methods(parser, code, expected_name): + """Keywords removed from _JS_KEYWORDS should be extractable as method names.""" + doc = Document( + page_content=code, + metadata={'source': 'test.js', 'content_type': 'functions_classes'} + ) + assert parser.get_function_name(doc) == expected_name + + +@pytest.mark.parametrize("code", [ + # Keywords still in 
_JS_KEYWORDS should NOT be extracted (no non-keyword calls in body) + "if (x) { return (y) }", + "for (var y in z) { break; }", + "while (true) { continue; }", + "switch (x) { break; }", + "return (x)", + "throw (x)", + "new (x)", + "typeof (x)", + "delete (x)", +]) +def test_get_function_name_retained_keywords_still_rejected(parser, code): + """Keywords still in _JS_KEYWORDS should return empty string.""" + doc = Document( + page_content=code, + metadata={'source': 'test.js', 'content_type': 'functions_classes'} + ) + assert parser.get_function_name(doc) == '' + + +# ============================================================================ +# Tests for variable declaration matching restricted to header +# ============================================================================ + +@pytest.mark.parametrize("code, expected_name", [ + # Should NOT match const/let/var assignments inside function body + ("async method() {\n const x = obj.call(y);\n return x;\n}", "method"), + ("render() {\n const sanitized = this.sanitize(input);\n return sanitized;\n}", "render"), + # Should match in header (no body yet) + ("const result = helper.create(args)", "result"), + # Multi-line header before body + ("const handler =\n middleware.wrap(\n fn\n )", "handler"), +]) +def test_get_function_name_pattern5_body_restriction(parser, code, expected_name): + """Variable declaration matching should only search before the first '{' to avoid body matches.""" + doc = Document( + page_content=code, + metadata={'source': 'test.js', 'content_type': 'functions_classes'} + ) + assert parser.get_function_name(doc) == expected_name + + +# ============================================================================ +# Tests for body-matching prevention: patterns should not match inside bodies +# ============================================================================ + +@pytest.mark.parametrize("code, expected_name", [ + # Function declaration inside body should not hijack name + ("method() 
{ function inner() {} }", "method"), + ("render() { async function fetchData() {} }", "render"), + # Const single-param arrow inside body + ("method() { const helper = x => x }", "method"), + # Property function assignment inside body + ("method() { this.handler = function() {} }", "method"), + # Property single-param arrow inside body + ("method() { this.cb = val => val }", "method"), + # Bare assignment function inside body + ("method() { callback = function() {} }", "method"), + # Bare single-param arrow inside body + ("method() { transform = x => x * 2 }", "method"), +]) +def test_get_function_name_body_matching_prevention(parser, code, expected_name): + """Patterns restricted to header should not match declarations inside function bodies.""" + doc = Document( + page_content=code, + metadata={'source': 'test.js', 'content_type': 'functions_classes'} + ) + assert parser.get_function_name(doc) == expected_name + + +# ============================================================================ +# Tests for generator name without space: function*name() +# ============================================================================ + +@pytest.mark.parametrize("code, expected_name", [ + ("function*myGen() { yield 1; }", "myGen"), + ("function *myGen() { yield 1; }", "myGen"), + ("function * myGen() { yield 1; }", "myGen"), + ("async function*streamGen() { yield 1; }", "streamGen"), +]) +def test_get_function_name_generator_no_space(parser, code, expected_name): + """Generator declarations should match function*name() without space after *.""" + doc = Document( + page_content=code, + metadata={'source': 'test.js', 'content_type': 'functions_classes'} + ) + assert parser.get_function_name(doc) == expected_name + + +def test_get_function_name_no_functionName_false_match(parser): + """Word boundary should not match 'functionName(' as 'Name'.""" + doc = Document( + page_content="functionName(x) { return x; }", + metadata={'source': 'test.js', 'content_type': 
'functions_classes'} + ) + assert parser.get_function_name(doc) == 'functionName' + + +# ============================================================================ +# Tests for function_called_from_caller_body regex fixes +# ============================================================================ + +def test_function_called_from_caller_body_finds_call(): + """Should return truthy when function is called (not declared) in body.""" + from exploit_iq_commons.utils.chain_of_calls_retriever import ChainOfCallsRetriever + from exploit_iq_commons.utils.functions_parsers.javascript_functions_parser import JavaScriptFunctionsParser + + parser = JavaScriptFunctionsParser() + retriever = ChainOfCallsRetriever.__new__(ChainOfCallsRetriever) + retriever.language_parser = parser + + doc = Document( + page_content="function handler(req) {\n const result = process(req.body);\n return result;\n}", + metadata={'source': 'app/handler.js', 'content_type': 'functions_classes'} + ) + assert retriever.function_called_from_caller_body(doc, "process") + + +# ============================================================================ +# Tests for function_called_from_caller_body empty-name pass-through +# ============================================================================ + +def test_function_called_from_caller_body_empty_name_passthrough(): + """function_called_from_caller_body returns True for empty function_to_search (pass-through). + + Empty function_to_search is used by __find_initial_function as a "no filter" — + all documents pass through. Empty-name defense is handled by BFS guards. 
+ """ + from exploit_iq_commons.utils.chain_of_calls_retriever import ChainOfCallsRetriever + from exploit_iq_commons.utils.functions_parsers.javascript_functions_parser import JavaScriptFunctionsParser + + parser = JavaScriptFunctionsParser() + retriever = ChainOfCallsRetriever.__new__(ChainOfCallsRetriever) + retriever.language_parser = parser + + doc = Document( + page_content="function handler() {\n someLib.process(data);\n}", + metadata={'source': 'app/handler.js', 'content_type': 'functions_classes'} + ) + assert retriever.function_called_from_caller_body(doc, "") is True + assert retriever.function_called_from_caller_body(doc, " ") is True + + +# ============================================================================ +# Tests for _breadth_first_search empty-name guard +# ============================================================================ + +def test_bfs_skips_empty_name_docs(): + """BFS should skip documents whose get_function_name returns empty string.""" + from unittest.mock import MagicMock, patch + from exploit_iq_commons.utils.chain_of_calls_retriever import ChainOfCallsRetriever, _SearchCtx + + parser = JavaScriptFunctionsParser() + retriever = ChainOfCallsRetriever.__new__(ChainOfCallsRetriever) + retriever.language_parser = parser + retriever.documents = [] + retriever.sort_docs = {} + + empty_name_doc = Document( + page_content="var reporter = {\n if (x) {}\n}", + metadata={'source': 'test.js', 'content_type': 'functions_classes', 'state': None} + ) + + ctx = _SearchCtx() + result, found = retriever._breadth_first_search([], empty_name_doc, 'test-pkg', ctx) + assert found is False + assert result == [] + + +# ============================================================================ +# Tests for __find_caller_functions_bfs candidate empty-name guard +# ============================================================================ + +def test_bfs_caller_skips_empty_name_candidates(): + """__find_caller_functions_bfs should skip 
candidate docs with empty function names.""" + from unittest.mock import MagicMock, patch + from collections import defaultdict + from exploit_iq_commons.utils.chain_of_calls_retriever import ChainOfCallsRetriever, _SearchCtx + + parser = JavaScriptFunctionsParser() + retriever = ChainOfCallsRetriever.__new__(ChainOfCallsRetriever) + retriever.language_parser = parser + retriever.ecosystem = 'javascript' + + target_doc = Document( + page_content="function vulnerable() {\n return 'exploit';\n}", + metadata={'source': 'node_modules/lib/index.js', 'content_type': 'functions_classes'} + ) + + empty_name_candidate = Document( + page_content="var reporter = {\n if (x) {}\n}", + metadata={'source': 'node_modules/parent/index.js', 'content_type': 'functions_classes', 'state': None} + ) + + retriever.sort_docs = defaultdict(list, {'parent-pkg': [empty_name_candidate]}) + retriever.documents = [empty_name_candidate] + retriever.documents_of_full_sources = {} + retriever.documents_of_types = [] + retriever.types_classes_fields_mapping = {} + retriever.functions_local_variables_index = {} + retriever.documents_of_functions = [] + + ctx = _SearchCtx() + + with patch.object(retriever, '_get_parents', return_value=['parent-pkg']): + result = retriever._ChainOfCallsRetriever__find_caller_functions_bfs( + document_function=target_doc, + function_package='lib', + ctx=ctx + ) + + assert result == [] + + +@pytest.mark.parametrize("description,code_content,identifier,callee_package,expected", [ + # ======================================================================== + # Multi-function imports - should match exact identifiers + # ======================================================================== + ( + "Multi-function import - exact match for function1", + """import { function1, function2 } from "./utils.js";""", + "function1", + "./utils.js", + True, + ), + ( + "Multi-function import - exact match for function2", + """import { function1, function2 } from "./utils.js";""", + 
"function2", + "./utils.js", + True, + ), + # ======================================================================== + # Substring edge cases - should NOT match partial identifiers + # ======================================================================== + ( + "Multi-function import - 'func' should NOT match 'function1'", + """import { function1, function2 } from "./utils.js";""", + "func", + "./utils.js", + False, + ), + ( + "Multi-function import - 'add' should NOT match 'addHandler'", + """import { addHandler, removeHandler } from "./utils.js";""", + "add", + "./utils.js", + False, + ), + ( + "Single function import - 'parse' should NOT match 'parseAsync'", + """import { parseAsync } from "some-lib";""", + "parse", + "some-lib", + False, + ), + # ======================================================================== + # CommonJS require - substring edge cases + # ======================================================================== + ( + "CommonJS destructured require - exact match", + """const { function1, function2 } = require('./utils.js');""", + "function1", + "./utils.js", + True, + ), + ( + "CommonJS destructured require - 'func' should NOT match 'function1'", + """const { function1, function2 } = require('./utils.js');""", + "func", + "./utils.js", + False, + ), + # ======================================================================== + # Re-exports - substring edge cases + # ======================================================================== + ( + "Re-export - exact match", + """export { function1, function2 } from "./utils.js";""", + "function1", + "./utils.js", + True, + ), + ( + "Re-export - 'func' should NOT match 'function1'", + """export { function1, function2 } from "./utils.js";""", + "func", + "./utils.js", + False, + ), +]) +def test_is_package_imported_word_boundaries(parser, description, code_content, identifier, callee_package, expected): + """Test is_package_imported with word boundary edge cases to prevent substring false 
positives.""" + result = parser.is_package_imported(code_content, identifier, callee_package) assert result == expected, f"{description}: expected {expected}, got {result}" + +@pytest.mark.parametrize("function_content,expected_name", [ + # Object methods now use //(class: name) pattern + ( + "fetchData(url) {\n return fetch(url);\n}\n//(class: apiUtils)", + "apiUtils" + ), + ( + "format: function(data) {\n return JSON.stringify(data);\n}\n//(class: helpers)", + "helpers" + ), + # Standalone function (no annotation) + ( + "function regularFunc() {\n return true;\n}", + None + ), +]) +def test_get_class_name_for_object_methods(parser, function_content, expected_name): + """Test that object methods use the same //(class: name) pattern as classes.""" + doc = Document( + page_content=function_content, + metadata={'source': 'test.js', 'content_type': 'functions_classes'} + ) + result = parser.get_class_name_from_class_function(doc) + assert result == expected_name + + +@pytest.mark.parametrize("description,function_content,file_content,expected", [ + ( + "Object method should be exported when object is in module.exports", + "fetchData(url) {\n return fetch(url);\n}\n//(class: apiUtils)", + """const apiUtils = { + fetchData(url) { + return fetch(url); + } +}; + +module.exports = { apiUtils }; +""", + True, + ), + ( + "Object method should be exported with ES6 named export", + "format(data) {\n return JSON.stringify(data);\n}\n//(class: helpers)", + """const helpers = { + format(data) { + return JSON.stringify(data); + } +}; + +export { helpers }; +""", + True, + ), + ( + "Object method should NOT be exported when object is not exported", + "internalMethod() {\n return 'internal';\n}\n//(class: privateUtils)", + """const privateUtils = { + internalMethod() { + return 'internal'; + } +}; + +// No exports +""", + False, + ), +]) +def test_is_exported_function_for_objects(parser, description, function_content, file_content, expected): + """Test is_exported_function correctly 
handles object methods with //(class: name) annotation.""" + source = "node_modules/test-pkg/index.js" + func_doc = Document( + page_content=function_content, + metadata={'source': source, 'content_type': 'functions_classes'} + ) + full_file_doc = Document( + page_content=file_content, + metadata={'source': source, 'content_type': 'simplified_code'} + ) + documents_of_full_sources = {source: full_file_doc} + + result = parser.is_exported_function(func_doc, documents_of_full_sources) + assert result == expected, f"{description}: expected {expected}, got {result}" + + +# ============================================================================= +# should_skip — node_modules/dist should NOT be skipped +# ============================================================================= + +@pytest.mark.parametrize("path,expected", [ + ("node_modules/express/dist/router.js", False), + ("node_modules/@babel/core/dist/index.js", False), + ("node_modules/pkg/dist/build/static/app.js", False), + ("dist/bundle.js", True), + ("src/dist/helper.js", False), + ("build/static/app.js", True), + ("app.min.js", True), + ("src/index.js", False), + ("coverage/report.js", True), +]) +def test_should_skip_node_modules_dist(path, expected): + """dist/ inside node_modules/ is legitimate source, not a build artifact.""" + assert ExtendedJavaScriptSegmenter.should_skip(path) == expected + + +# ============================================================================= +# Optional chaining in _get_function_calls +# ============================================================================= + +def test_get_function_calls_optional_chaining(parser): + """_get_function_calls should detect calls through optional chaining (?.).""" + caller = Document( + page_content="function test() {\n obj?.method(arg);\n}", + metadata={'source': 'test.js', 'content_type': 'functions_classes'} + ) + calls = parser._get_function_calls(caller, "method") + assert any("method" in c for c in calls), f"Expected 
'method' call, got {calls}" + + +def test_get_function_calls_chained_optional(parser): + """_get_function_calls should handle chained optional access a?.b?.method().""" + caller = Document( + page_content="function test() {\n a?.b?.method(x);\n}", + metadata={'source': 'test.js', 'content_type': 'functions_classes'} + ) + calls = parser._get_function_calls(caller, "method") + assert any("method" in c for c in calls), f"Expected 'method' call, got {calls}" + + +# ============================================================================= +# Position check — destructuring defaults inside bodies +# ============================================================================= + +@pytest.mark.parametrize("code,expected_name", [ + ( + "method({callback = (x) => x}) {\n return callback;\n}", + "method", + ), + ( + "const handler = ({a, b}) => { return a + b; }", + "handler", + ), + ( + "resolve = ({id}) => { doStuff(); }", + "resolve", + ), + ( + "function outer() {\n const inner = (x) => x;\n}", + "outer", + ), +]) +def test_position_check_destructuring_in_body(parser, code, expected_name): + """Patterns with [^)]* should not match destructuring defaults inside bodies.""" + doc = Document( + page_content=code, + metadata={'source': 'test.js', 'content_type': 'functions_classes'} + ) + assert parser.get_function_name(doc) == expected_name + + +# ============================================================================= +# Function expression with destructuring uses header +# ============================================================================= + +def test_pattern5_function_expr_with_destructuring(parser): + """Function expressions with destructuring params should extract the variable name.""" + doc = Document( + page_content="const fn = function({a, b}) { return a + b; }", + metadata={'source': 'test.js', 'content_type': 'functions_classes'} + ) + assert parser.get_function_name(doc) == "fn" + + +# 
============================================================================= +# Bare assignment and property patterns +# ============================================================================= + +@pytest.mark.parametrize("code,expected_name", [ + ( + "obj.method = function() { return 1; }", + "method", + ), + ( + "obj.handler = (x) => x * 2", + "handler", + ), + ( + "exports.default = async function(req, res) { res.send(); }", + "default", + ), + ( + "module.exports.init = function() {}", + "init", + ), +]) +def test_property_assignment_after_pattern_removal(parser, code, expected_name): + """Property assignments (obj.name = ...) should still work after removing patterns 6-8.""" + doc = Document( + page_content=code, + metadata={'source': 'test.js', 'content_type': 'functions_classes'} + ) + assert parser.get_function_name(doc) == expected_name + + +# ============================================================================= +# Optional chaining '?' leaks into identifiers in _get_function_calls +# ============================================================================= + +class TestOptionalChainingInFunctionCalls: + """_get_function_calls regex captures '?' from optional chaining, which + then leaks into search_for_called_function's identifier split, preventing + import resolution and local-var lookup.""" + + def test_optional_chaining_direct_call_detected(self, parser): + """_get_function_calls should detect calls through optional chaining.""" + caller = Document( + page_content="function process(res) {\n return res?.json();\n}", + metadata={'source': 'app/handler.js', 'content_type': 'functions_classes'} + ) + calls = parser._get_function_calls(caller, "json") + assert len(calls) >= 1, "Should detect json() through optional chaining" + + def test_optional_chaining_identifier_has_no_question_mark(self, parser): + """After splitting a qualified optional-chaining call by '.', identifiers + should NOT contain '?'. The '?' 
is syntax, not part of the name.""" + caller = Document( + page_content="function process(res) {\n return res?.json();\n}", + metadata={'source': 'app/handler.js', 'content_type': 'functions_classes'} + ) + calls = parser._get_function_calls(caller, "json") + for call in calls: + parts = call.split('.') + for part in parts: + clean = part.rstrip('(') + assert '?' not in clean, ( + f"Identifier '{clean}' still contains '?' from optional chaining" + ) + + def test_optional_chaining_search_for_called_function(self, parser): + """search_for_called_function should find calls through optional chaining + when the identifier matches an imported package.""" + caller = Document( + page_content=( + "import axios from 'axios';\n" + "function fetchData() {\n" + " return axios?.get('/api/data');\n" + "}" + ), + metadata={'source': 'app/api.js', 'content_type': 'functions_classes'} + ) + callee = Document( + page_content="get(url) {\n return this.request('GET', url);\n}\n//(class: axios)", + metadata={'source': 'node_modules/axios/index.js', 'content_type': 'functions_classes'} + ) + code_docs = { + 'app/api.js': Document( + page_content="import axios from 'axios';\nfunction fetchData() { return axios?.get('/api/data'); }", + metadata={'source': 'app/api.js', 'content_type': 'simplified_code'} + ) + } + result = parser.search_for_called_function( + caller_function=caller, + callee_function_name="get", + callee_function=callee, + callee_function_package="axios", + code_documents=code_docs, + type_documents=[], + callee_function_file_name='node_modules/axios/index.js', + fields_of_types={}, + functions_local_variables_index={}, + documents_of_functions=[], + ) + assert result is True, ( + "search_for_called_function should match 'axios?.get()' — " + "the '?' 
is syntax, not part of the identifier" + ) + + +# ============================================================================= +# Named export chunks lose 'export' prefix — is_exported_function +# fallback doesn't check for 'export function name' pattern +# ============================================================================= + +class TestNamedExportWithoutPrefix: + """Tree-sitter extracts 'function compile(...)' without the 'export' prefix. + is_exported_function should still detect it as exported via the fallback.""" + + def test_named_export_function_without_prefix(self, parser): + """A chunk 'function compile(input) {...}' from an 'export function compile' + should be detected as exported by checking the full source.""" + func_doc = Document( + page_content="function compile(input) {\n return transform(input);\n}", + metadata={'source': 'node_modules/handlebars/index.js', 'content_type': 'functions_classes'} + ) + full_source = Document( + page_content=( + "import { transform } from './transform';\n" + "export function compile(input) {\n" + " return transform(input);\n" + "}\n" + "export function precompile(input) {\n" + " return parse(input);\n" + "}\n" + ), + metadata={'source': 'node_modules/handlebars/index.js', 'content_type': 'simplified_code'} + ) + docs_of_full_sources = {'node_modules/handlebars/index.js': full_source} + result = parser.is_exported_function(func_doc, docs_of_full_sources) + assert result is True, ( + "Function 'compile' is exported via 'export function compile' in source, " + "even though the chunk itself lacks the 'export' prefix" + ) + + def test_named_export_async_function_without_prefix(self, parser): + """Async export: 'async function fetchData(...)' without 'export' in chunk.""" + func_doc = Document( + page_content="async function fetchData(url) {\n return await fetch(url);\n}", + metadata={'source': 'node_modules/my-lib/api.js', 'content_type': 'functions_classes'} + ) + full_source = Document( + page_content="export 
async function fetchData(url) {\n return await fetch(url);\n}\n", + metadata={'source': 'node_modules/my-lib/api.js', 'content_type': 'simplified_code'} + ) + docs_of_full_sources = {'node_modules/my-lib/api.js': full_source} + result = parser.is_exported_function(func_doc, docs_of_full_sources) + assert result is True, ( + "Async function 'fetchData' is exported via 'export async function fetchData' in source" + ) + + def test_non_exported_function_still_false(self, parser): + """A function that is genuinely not exported should still return False.""" + func_doc = Document( + page_content="function internal() {\n return 'private';\n}", + metadata={'source': 'node_modules/pkg/utils.js', 'content_type': 'functions_classes'} + ) + full_source = Document( + page_content=( + "function internal() {\n return 'private';\n}\n" + "function external() {\n return internal();\n}\n" + "module.exports = external;\n" + ), + metadata={'source': 'node_modules/pkg/utils.js', 'content_type': 'simplified_code'} + ) + docs_of_full_sources = {'node_modules/pkg/utils.js': full_source} + result = parser.is_exported_function(func_doc, docs_of_full_sources) + assert result is False + + +# ============================================================================= +# BFS ValueError on class docs — get_function_name raises for classes +# ============================================================================= + +class TestGetFunctionNameOnClassDocs: + """get_function_name raises ValueError on class documents. 
In BFS, this + aborts the entire search loop.""" + + def test_named_class_raises_valueerror(self, parser): + """get_function_name should raise ValueError for named class docs.""" + doc = Document( + page_content="class Router {\n constructor() {}\n route(path) {}\n}", + metadata={'source': 'app/router.js', 'content_type': 'functions_classes'} + ) + assert not parser.is_function(doc) + with pytest.raises(ValueError, match="Only function document"): + parser.get_function_name(doc) + + def test_anonymous_class_is_function_returns_false(self, parser): + """is_function should return False for anonymous class expressions, + preventing ValueError in get_function_name downstream.""" + doc = Document( + page_content="export default class {\n connect() {}\n disconnect() {}\n}", + metadata={'source': 'node_modules/pkg/db.js', 'content_type': 'functions_classes'} + ) + assert not parser.is_function(doc), ( + "Anonymous class 'export default class { ... }' should NOT pass is_function. " + "The regex requires class\\s+[\\w$]+ (a named identifier after 'class')." + ) + + +# ============================================================================= +# Nested parens in arrow function default parameters +# ============================================================================= + +class TestNestedParensInArrowDefaults: + """Arrow functions with nested parens in default params (e.g. getDefaults()) should extract the variable name.""" + + def test_arrow_with_function_call_default(self, parser): + """Arrow function with a function call as default parameter value.""" + doc = Document( + page_content="const handler = (options = getDefaults()) => {\n process(options);\n}", + metadata={'source': 'app/handler.js', 'content_type': 'functions_classes'} + ) + name = parser.get_function_name(doc) + assert name == "handler", ( + f"Expected 'handler' but got '{name}'. [^)]* in arrow regex stops at " + "the inner ')' of getDefaults(), causing fallthrough to catch-all." 
+ ) + + def test_arrow_with_nested_method_call_default(self, parser): + """Arrow with obj.method() call in default parameter.""" + doc = Document( + page_content="const processor = (cfg = Config.load()) => {\n return cfg.run();\n}", + metadata={'source': 'app/process.js', 'content_type': 'functions_classes'} + ) + name = parser.get_function_name(doc) + assert name == "processor", f"Expected 'processor' but got '{name}'" + + def test_arrow_with_simple_default_still_works(self, parser): + """Ensure simple defaults (no nested parens) still work after the fix.""" + doc = Document( + page_content="const greet = (name = 'world') => {\n console.log(name);\n}", + metadata={'source': 'app/greet.js', 'content_type': 'functions_classes'} + ) + assert parser.get_function_name(doc) == "greet" + + def test_arrow_with_numeric_default_still_works(self, parser): + """Numeric default: no nested parens.""" + doc = Document( + page_content="const delay = (ms = 1000) => {\n return new Promise(r => setTimeout(r, ms));\n}", + metadata={'source': 'app/delay.js', 'content_type': 'functions_classes'} + ) + assert parser.get_function_name(doc) == "delay" + + +# ============================================================================= +# Anonymous export default arrow returns inner function name +# ============================================================================= + +class TestAnonymousExportDefaultArrow: + """export default (x) => {...} with inner named functions should return ''.""" + + def test_anonymous_default_arrow_with_inner_function(self, parser): + """An anonymous export default arrow should return '' even when the body + contains named functions.""" + doc = Document( + page_content=( + "export default async (req) => {\n" + " function validate(r) { return r.ok; }\n" + " return validate(req);\n" + "}" + ), + metadata={'source': 'node_modules/pkg/handler.js', 'content_type': 'functions_classes'} + ) + name = parser.get_function_name(doc) + assert name == '', ( + 
f"Expected '' for anonymous export default arrow, but got '{name}'. " + "The catch-all matched an inner function name from the body." + ) + + def test_anonymous_default_arrow_simple(self, parser): + """Simple anonymous export default arrow with no inner functions.""" + doc = Document( + page_content="export default (x, y) => {\n return x + y;\n}", + metadata={'source': 'node_modules/pkg/add.js', 'content_type': 'functions_classes'} + ) + name = parser.get_function_name(doc) + assert name == '', f"Expected '' for anonymous arrow, got '{name}'" + + def test_named_default_function_still_works(self, parser): + """Named export default function should still return the name.""" + doc = Document( + page_content="export default function handler(req) {\n return req.body;\n}", + metadata={'source': 'node_modules/pkg/handler.js', 'content_type': 'functions_classes'} + ) + assert parser.get_function_name(doc) == "handler" + + +# ============================================================================= +# is_function returns True for anonymous class expressions +# ============================================================================= + +class TestGetFunctionNamePerformance: + """get_function_name must not catastrophically backtrack on large function bodies + containing many parenthesized expressions (e.g., compiled TypeScript output).""" + + def test_large_prototype_method_completes_fast(self, parser): + """Reproduces XMLSerializerImpl.js — prototype method assignment with a + 17K+ char body full of var declarations and parenthesized calls. 
+ Before fix, get_function_name took 34.8s on this content because + arrow-function regex patterns searched the full content.""" + import time + body_lines = [] + for i in range(300): + body_lines.append(f" var markup{i} = this._serialize(node, (opts{i} || {{}}));") + body_lines.append(f" if (requireWellFormed && (node.localName.indexOf(':') !== -1 ||") + body_lines.append(f" !algorithm_1.xml_isName(node.localName))) {{") + body_lines.append(f" throw new Error('not well-formed: ' + node.localName);") + body_lines.append(f" }}") + body = "\n".join(body_lines) + content = ( + f"XMLSerializerImpl.prototype._serializeElementNS = " + f"function (node, namespace, prefixMap, prefixIndex, requireWellFormed) {{\n" + f" var e_1, _a;\n{body}\n }}" + ) + assert len(content) > 15000 + doc = Document( + page_content=content, + metadata={'source': 'node_modules/@oozcitak/dom/lib/serializer/XMLSerializerImpl.js', + 'content_type': 'functions_classes'} + ) + t0 = time.monotonic() + name = parser.get_function_name(doc) + elapsed = time.monotonic() - t0 + assert name == "_serializeElementNS" + assert elapsed < 1.0, f"get_function_name took {elapsed:.2f}s (should be <1s)" + + def test_arrow_with_destructured_params_still_works(self, parser): + """Arrow function with destructured params — first '{' is in params, + not in the body. 
Ensures arrow_header covers the full signature.""" + import time + content = ( + "const handleRequest = ({method, url, headers}) => {\n" + " const response = fetch(url);\n" + " let status = 200;\n" + " return response;\n" + "}" + ) + doc = Document( + page_content=content, + metadata={'source': 'app/request.js', 'content_type': 'functions_classes'} + ) + t0 = time.monotonic() + name = parser.get_function_name(doc) + elapsed = time.monotonic() - t0 + assert name == "handleRequest" + assert elapsed < 1.0 + + def test_small_doc_with_parens_in_body(self, parser): + """Reproduces lodash LazyWrapper — anonymous function assigned to + computed property, body has var = (expr) patterns that triggered + catastrophic regex backtracking even on 556 chars.""" + import time + content = ( + "LazyWrapper.prototype[methodName] = function(n) {\n" + " n = n === undefined ? 1 : nativeMax(toInteger(n), 0);\n" + " var result = (this.__filtered__ && !index)\n" + " ? new LazyWrapper(this)\n" + " : this.clone();\n" + " if (result.__filtered__) {\n" + " result.__takeCount__ = nativeMin(n, result.__takeCount__);\n" + " } else {\n" + " result.__views__.push({\n" + " 'size': nativeMin(n, MAX_ARRAY_LENGTH),\n" + " 'type': methodName + (result.__dir__ < 0 ? 
'Right' : '')\n" + " });\n" + " }\n" + " return result;\n" + " };" + ) + doc = Document( + page_content=content, + metadata={'source': 'node_modules/lodash/lodash.js', + 'content_type': 'functions_classes'} + ) + t0 = time.monotonic() + name = parser.get_function_name(doc) + elapsed = time.monotonic() - t0 + assert elapsed < 1.0, f"get_function_name took {elapsed:.2f}s (should be <1s)" + + def test_babel_iife_wrapped_arrow_completes_fast(self, parser): + """Reproduces babel package.js — IIFE-wrapped arrow with nested parens + in the header like (0, _utils.makeStaticFileCache)((filepath, content) => + caused catastrophic regex backtracking (38.8s on 662 chars).""" + import time + content = ( + "const readConfigPackage = " + "(0, _utils.makeStaticFileCache)((filepath, content) => {\n" + " let options;\n" + " try {\n" + " options = (0, _json.parse)(content);\n" + " } catch (err) {\n" + " throw new _configError.default(\n" + " `Error while parsing JSON - ${err.message}`, filepath);\n" + " }\n" + " if (!options) throw new _configError.default(\n" + " `No config detected in ${filepath}`, filepath);\n" + " if (typeof options !== 'object') throw new _configError.default(\n" + " `Config returned typeof ${typeof options}`, filepath);\n" + " if (Array.isArray(options)) throw new _configError.default(\n" + " `Config returned an array`, filepath);\n" + " delete options['$schema'];\n" + " return {\n" + " filepath,\n" + " dirname: _path().dirname(filepath),\n" + " options\n" + " };\n" + "});" + ) + doc = Document( + page_content=content, + metadata={'source': 'node_modules/@babel/core/lib/config/files/package.js', + 'content_type': 'functions_classes'} + ) + t0 = time.monotonic() + name = parser.get_function_name(doc) + elapsed = time.monotonic() - t0 + assert elapsed < 1.0, f"get_function_name took {elapsed:.2f}s (should be <1s)" + assert name == "readConfigPackage", f"Expected 'readConfigPackage', got '{name}'" + + def test_constructor_with_inner_arrow_not_confused(self, 
parser): + """Reproduces PackageURL constructor — body contains arrow functions + like `key => {` which caused arrow_header to extend into the body, + then P9 matched `required =` instead of extracting `constructor`.""" + content = ( + "constructor(type, namespace, name, version, qualifiers, subpath) {\n" + " let required = { 'type': type, 'name': name };\n" + " Object.keys(required).forEach(key => {\n" + " if (!required[key]) {\n" + " throw new Error('Invalid purl: \"' + key + '\" is a required field.');\n" + " }\n" + " });\n" + " }\n" + "//(class: PackageURL)" + ) + doc = Document( + page_content=content, + metadata={'source': 'node_modules/packageurl-js/lib/package-url.js', + 'content_type': 'functions_classes'} + ) + name = parser.get_function_name(doc) + assert name == "constructor", f"Expected 'constructor', got '{name}'" + + +class TestIsFunctionAnonymousClass: + """is_function regex requires class\\s+[\\w$]+ (a named class). Anonymous + classes like 'export default class { ... }' bypass the check.""" + + def test_anonymous_class_is_not_function(self, parser): + """Anonymous default export class should not be considered a function.""" + doc = Document( + page_content="export default class {\n connect() {}\n}", + metadata={'source': 'node_modules/pkg/db.js', 'content_type': 'functions_classes'} + ) + assert not parser.is_function(doc), ( + "Anonymous class should not pass is_function" + ) + + def test_anonymous_class_with_methods_is_not_function(self, parser): + """Anonymous class with multiple methods.""" + doc = Document( + page_content=( + "export default class {\n" + " constructor() { this.x = 1; }\n" + " render() { return this.x; }\n" + "}" + ), + metadata={'source': 'node_modules/pkg/component.js', 'content_type': 'functions_classes'} + ) + assert not parser.is_function(doc) + + def test_named_class_still_not_function(self, parser): + """Named classes should still return False (existing behavior).""" + doc = Document( + page_content="class MyClass {\n 
constructor() {}\n}", + metadata={'source': 'app/my.js', 'content_type': 'functions_classes'} + ) + assert not parser.is_function(doc) + + def test_regular_function_still_is_function(self, parser): + """Regular functions should still return True.""" + doc = Document( + page_content="function doStuff() {\n return 1;\n}", + metadata={'source': 'app/util.js', 'content_type': 'functions_classes'} + ) + assert parser.is_function(doc) + + +# ============================================================================= +# Generator export regex misses function* (no space) +# ============================================================================= + +class TestGeneratorExportDetection: + """_is_exportable_function must detect generator functions exported as + 'export function* genName()' where * immediately follows 'function'.""" + + def test_export_function_star_no_space(self, parser): + result = parser._is_exportable_function( + "genValues", + "export function* genValues() { yield 1; }" + ) + assert result is True, "function* (no space) should be detected as exported" + + def test_export_function_space_star(self, parser): + result = parser._is_exportable_function( + "genValues", + "export function *genValues() { yield 1; }" + ) + assert result is True, "function * (space before star) should be detected" + + def test_export_function_star_space(self, parser): + result = parser._is_exportable_function( + "genValues", + "export function * genValues() { yield 1; }" + ) + assert result is True, "function * genValues (space around star) should be detected" + + def test_export_async_function_star(self, parser): + result = parser._is_exportable_function( + "genValues", + "export async function* genValues() { yield 1; }" + ) + assert result is True, "async function* should be detected" + + def test_non_generator_still_works(self, parser): + result = parser._is_exportable_function( + "myFunc", + "export function myFunc() { return 1; }" + ) + assert result is True, "regular 
export function should still work" + + +# ============================================================================= +# Anonymous export-default arrow without parens +# ============================================================================= + +class TestAnonymousExportDefaultArrowNoParens: + """export default x => { innerFunc(); } should return '' not 'innerFunc'.""" + + def test_single_param_no_parens(self, parser): + doc = Document( + page_content="export default x => {\n innerFunc();\n}", + metadata={'source': 'node_modules/pkg/index.js', 'content_type': 'functions_classes'} + ) + name = parser.get_function_name(doc) + assert name == '', f"Expected '' for anonymous single-param arrow, got '{name}'" + + def test_single_param_async_no_parens(self, parser): + doc = Document( + page_content="export default async req => {\n validate(req);\n return respond();\n}", + metadata={'source': 'node_modules/pkg/handler.js', 'content_type': 'functions_classes'} + ) + name = parser.get_function_name(doc) + assert name == '', f"Expected '' for async single-param arrow, got '{name}'" + + def test_single_param_expression_body(self, parser): + doc = Document( + page_content="export default x => x + 1", + metadata={'source': 'node_modules/pkg/inc.js', 'content_type': 'functions_classes'} + ) + name = parser.get_function_name(doc) + assert name == '', f"Expected '' for expression-body arrow, got '{name}'" + + +# ============================================================================= +# arrow_header extends into inner arrows in destructured params +# ============================================================================= + +class TestArrowHeaderInnerArrowExtension: + """Method definitions with destructured params containing inner arrows + should not have arrow_header extend into the params.""" + + def test_method_with_inner_arrow_default(self, parser): + doc = Document( + page_content="method({cb = x => {return x}}) {\n doStuff();\n}", + metadata={'source': 
'node_modules/pkg/util.js', 'content_type': 'functions_classes'} + ) + name = parser.get_function_name(doc) + assert name == 'method', f"Expected 'method' but got '{name}'" + + def test_method_with_complex_arrow_default(self, parser): + doc = Document( + page_content="handle({onError = err => {log(err)}, timeout = 5000}) {\n process();\n}", + metadata={'source': 'node_modules/pkg/handler.js', 'content_type': 'functions_classes'} + ) + name = parser.get_function_name(doc) + assert name == 'handle', f"Expected 'handle' but got '{name}'" + + +# ============================================================================= +# Anonymous function name collisions in local-var map +# ============================================================================= + +class TestLocalVarMapCollisions: + """Multiple anonymous functions from same file should not overwrite each other.""" + + def test_multiple_anonymous_same_file_no_collision(self, parser): + """Anonymous functions (empty name) should be skipped rather than collide.""" + docs = [ + Document( + page_content="(x) => {\n const a = x + 1;\n return a;\n}", + metadata={'source': 'app/utils.js', 'content_type': 'functions_classes'} + ), + Document( + page_content="(y) => {\n const b = y * 2;\n return b;\n}", + metadata={'source': 'app/utils.js', 'content_type': 'functions_classes'} + ), + ] + mappings = parser.create_map_of_local_vars(docs) + assert len(mappings) == 0, ( + f"Expected 0 entries (anonymous functions skipped), got {len(mappings)}. " + "Empty function names cannot be looked up and should not be indexed." 
+ ) + + def test_named_functions_still_indexed(self, parser): + """Named functions should still be indexed normally.""" + docs = [ + Document( + page_content="function add(x) {\n const result = x + 1;\n return result;\n}", + metadata={'source': 'app/utils.js', 'content_type': 'functions_classes'} + ), + Document( + page_content="function multiply(y) {\n const product = y * 2;\n return product;\n}", + metadata={'source': 'app/utils.js', 'content_type': 'functions_classes'} + ), + ] + mappings = parser.create_map_of_local_vars(docs) + assert len(mappings) == 2 + assert 'add@app/utils.js' in mappings + assert 'multiply@app/utils.js' in mappings + + +# ============================================================================= +# _is_exportable duplication +# ============================================================================= + +class TestIsExportableSharedLogic: + """_is_exportable_class and _is_exportable_function should share CommonJS checks.""" + + def test_class_commonjs_module_exports_direct(self, parser): + assert parser._is_exportable_class("MyClass", "module.exports = MyClass;") + + def test_class_commonjs_module_exports_object(self, parser): + assert parser._is_exportable_class("MyClass", "module.exports = { MyClass };") + + def test_class_commonjs_module_exports_property(self, parser): + assert parser._is_exportable_class("MyClass", "module.exports.MyClass = MyClass;") + + def test_class_commonjs_exports_property(self, parser): + assert parser._is_exportable_class("MyClass", "exports.MyClass = MyClass;") + + def test_function_commonjs_module_exports_direct(self, parser): + assert parser._is_exportable_function("myFunc", "module.exports = myFunc;") + + def test_function_commonjs_module_exports_object(self, parser): + assert parser._is_exportable_function("myFunc", "module.exports = { myFunc };") + + def test_function_commonjs_module_exports_property(self, parser): + assert parser._is_exportable_function("myFunc", "module.exports.myFunc = myFunc;") 
+ + def test_function_commonjs_exports_property(self, parser): + assert parser._is_exportable_function("myFunc", "exports.myFunc = myFunc;") + + +# ============================================================================= +# Bug: is_package_imported split('from') corrupts identifiers containing "from" +# Lines 799, 822: line.split('from')[0] splits on substrings like "fromString" +# ============================================================================= + +class TestIsPackageImportedFromSplit: + """split('from') breaks identifiers that start with or contain 'from'.""" + + def test_es6_import_identifier_starting_with_from(self, parser): + """import { fromString } from 'packageurl-js' should detect 'fromString'.""" + code = "import { fromString } from 'packageurl-js';" + assert parser.is_package_imported(code, "fromString", "packageurl-js") + + def test_es6_import_identifier_from_inside_name(self, parser): + """import { transformData } from 'utils' should detect 'transformData'.""" + code = "import { transformData } from 'utils';" + assert parser.is_package_imported(code, "transformData", "utils") + + def test_reexport_identifier_starting_with_from(self, parser): + """export { fromBuffer } from 'uuid' should detect 'fromBuffer'.""" + code = "export { fromBuffer } from 'uuid';" + assert parser.is_package_imported(code, "fromBuffer", "uuid") + + def test_normal_import_still_works(self, parser): + """import { template } from 'lodash' should still work after fix.""" + code = "import { template } from 'lodash';" + assert parser.is_package_imported(code, "template", "lodash") + + +# ============================================================================= +# Bug: _get_function_calls called without code_documents, alias resolution dead +# Line 236: search_for_called_function doesn't pass code_documents +# ============================================================================= + +class TestAliasResolutionPassthrough: + """search_for_called_function must 
pass code_documents so alias resolution works.""" + + def test_aliased_import_detected_via_search_for_called_function(self, parser): + """When an aliased import exists, search_for_called_function should detect it.""" + caller_code = "function handler() {\n myAlias();\n}" + caller_doc = Document( + page_content=caller_code, + metadata={"source": "handlers/main.js"} + ) + + callee_code = "function original() { return 1; }" + callee_doc = Document( + page_content=callee_code, + metadata={"source": "node_modules/somelib/index.js"} + ) + + full_file_code = "import { original as myAlias } from 'somelib';\n\n" + caller_code + full_file_doc = Document( + page_content=full_file_code, + metadata={"source": "handlers/main.js"} + ) + + code_documents = {"handlers/main.js": full_file_doc} + + result = parser.search_for_called_function( + caller_function=caller_doc, + callee_function_name="original", + callee_function=callee_doc, + callee_function_package="somelib", + code_documents=code_documents, + type_documents=[], + callee_function_file_name="index.js", + fields_of_types={}, + functions_local_variables_index={}, + documents_of_functions=[], + ) + assert result is True, "Alias 'myAlias' for 'original' should be detected" + + +# ============================================================================= +# Bug: $ missing from regex lookbehind in _get_function_calls +# Lines 196, 199: (? 0, "template() should match" + + def test_jquery_dollar_qualified_call_detected(self, parser): + """$.ajax() is a call to ajax through qualifier $ — should be detected.""" + code = "function fetch() {\n $.ajax('/api');\n}" + doc = Document(page_content=code, metadata={"source": "app.js"}) + calls = parser._get_function_calls(doc, "ajax") + # $.ajax() is a legitimate call to ajax. The $ qualifier isn't in the + # qualifier char class [\w.?()], so the match is 'ajax' not '$.ajax'. + # This is correct — the call IS to ajax. 
+        assert len(calls) > 0, "$.ajax() should be detected as a call to ajax"
+
+
+# =============================================================================
+# Bug: get_package_names returns single-element list for JS
+# chain_of_calls_retriever.py:593 accesses [1] → IndexError
+# =============================================================================
+
+class TestGetPackageNamesSingleElement:
+    """JS get_package_names returns single-element list, so [1] index crashes."""
+
+    def test_third_party_package_returns_single_element(self, parser):
+        """Third-party packages return ['package_name'] — only one element."""
+        doc = Document(page_content="function foo() {}",
+                       metadata={"source": "node_modules/lodash/index.js"})
+        result = parser.get_package_names(doc)
+        assert result == ["lodash"]
+        assert len(result) == 1
+
+    def test_third_party_scoped_returns_single_element(self, parser):
+        """Scoped packages return ['@scope/pkg'] — still one element."""
+        doc = Document(page_content="function foo() {}",
+                       metadata={"source": "node_modules/@babel/core/index.js"})
+        result = parser.get_package_names(doc)
+        assert result == ["@babel/core"]
+        assert len(result) == 1
+
+    def test_root_project_returns_single_element(self, parser):
+        """Root project files return ['root_project'] — single element."""
+        doc = Document(page_content="function foo() {}",
+                       metadata={"source": "src/app.js"})
+        result = parser.get_package_names(doc)
+        assert result == ["root_project"]
+        assert len(result) == 1
+
+    def test_index_0_always_valid(self, parser):
+        """Accessing [0] should always work for any document."""
+        # Covers a third-party path, a project path and a scoped-package path.
+        for source in ["node_modules/express/index.js", "src/app.js",
+                       "node_modules/@types/node/index.d.ts"]:
+            doc = Document(page_content="var x;", metadata={"source": source})
+            names = parser.get_package_names(doc)
+            assert len(names) >= 1
+            _ = names[0]  # Should never raise
+
+    def test_index_1_raises_for_js(self, parser):
+        """Accessing [1] should raise IndexError — this is the bug CCA hits."""
+        doc = Document(page_content="function foo() {}",
+                       metadata={"source": "node_modules/lodash/index.js"})
+        result = parser.get_package_names(doc)
+        # Pins the single-element shape by indexing past it.
+        with pytest.raises(IndexError):
+            _ = result[1]
+
+
+# =============================================================================
+# Bug: print_call_hierarchy doesn't handle empty function_name string
+# chain_of_calls_retriever.py:702 — ValueError is caught but empty string isn't
+# =============================================================================
+
+class TestPrintCallHierarchyEmptyName:
+    """print_call_hierarchy catches ValueError but doesn't check for empty string."""
+
+    def test_get_function_name_empty_string_handled(self, parser):
+        """Documents where get_function_name returns '' should not crash hierarchy."""
+        # Comment-only JS source: there is no function whose name could be extracted.
+        doc = Document(
+            page_content="// just a comment block\n/* nothing here */",
+            metadata={"source": "utils.js"}
+        )
+        try:
+            name = parser.get_function_name(doc)
+        except ValueError:
+            name = None
+
+        # The bug: if get_function_name returns '' instead of raising,
+        # print_call_hierarchy formats it as (package=...,function=,depth=0)
+        # which is a meaningless entry in the call hierarchy.
+        # The fix in chain_of_calls_retriever.py guards against empty strings.
+        # NOTE(review): this assertion only constrains the type of `name`; it
+        # also passes for '', so it does not itself exercise the empty-name
+        # guard described above.
+        assert name is None or isinstance(name, str)
+
+
+# =============================================================================
+# _build_class_hierarchy index and _is_subclass_of optimization
+# =============================================================================
+
+class TestBuildClassHierarchy:
+    """Verify the class hierarchy index is built correctly from code_documents."""
+
+    def test_simple_es6_extends(self, parser):
+        """`class Dog extends Animal` is indexed under 'Dog' with parent 'Animal'."""
+        docs = {
+            "a.js": Document(page_content="class Dog extends Animal { bark() {} }",
+                             metadata={"source": "a.js"})
+        }
+        hierarchy = parser._build_class_hierarchy(docs)
+        assert "Dog" in hierarchy
+        # Each entry unpacks into (extends_clause, parent).
+        extends_clause, parent = hierarchy["Dog"]
+        assert parent == "Animal"
+        assert "Animal" in extends_clause
+
+    def test_mixin_extends(self, parser):
+        """`extends EventEmitter(Base)` resolves to parent 'Base'; the clause keeps 'EventEmitter'."""
+        docs = {
+            "a.js": Document(page_content="class MyComponent extends EventEmitter(Base) { }",
+                             metadata={"source": "a.js"})
+        }
+        hierarchy = parser._build_class_hierarchy(docs)
+        assert "MyComponent" in hierarchy
+        extends_clause, parent = hierarchy["MyComponent"]
+        assert parent == "Base"
+        assert "EventEmitter" in extends_clause
+
+    def test_chained_mixin(self, parser):
+        """Nested mixins resolve to the innermost name; the clause keeps both wrappers."""
+        docs = {
+            "a.js": Document(page_content="class Widget extends Draggable(Resizable(Component)) { }",
+                             metadata={"source": "a.js"})
+        }
+        hierarchy = parser._build_class_hierarchy(docs)
+        extends_clause, parent = hierarchy["Widget"]
+        assert parent == "Component"
+        assert "Draggable" in extends_clause
+        assert "Resizable" in extends_clause
+
+    def test_prototype_inherits(self, parser):
+        """`util.inherits(ReadStream, EventEmitter)` is indexed with extends_clause None."""
+        docs = {
+            "a.js": Document(page_content="util.inherits(ReadStream, EventEmitter);",
+                             metadata={"source": "a.js"})
+        }
+        hierarchy = parser._build_class_hierarchy(docs)
+        assert "ReadStream" in hierarchy
+        extends_clause, parent = hierarchy["ReadStream"]
+        assert extends_clause is None
+        assert parent == "EventEmitter"
+
+    def test_object_create_prototype(self, parser):
+        """`Child.prototype = Object.create(Parent.prototype)` is indexed as (None, 'Parent')."""
+        docs = {
+            "a.js": Document(page_content="Child.prototype = Object.create(Parent.prototype);",
+                             metadata={"source": "a.js"})
+        }
+        hierarchy = parser._build_class_hierarchy(docs)
+        assert "Child" in hierarchy
+        assert hierarchy["Child"] == (None, "Parent")
+
+    def test_set_prototype_of(self, parser):
+        """`Object.setPrototypeOf(Sub.prototype, Super.prototype)` is indexed as (None, 'Super')."""
+        docs = {
+            "a.js": Document(page_content="Object.setPrototypeOf(Sub.prototype, Super.prototype);",
+                             metadata={"source": "a.js"})
+        }
+        hierarchy = parser._build_class_hierarchy(docs)
+        assert "Sub" in hierarchy
+        assert hierarchy["Sub"] == (None, "Super")
+
+    def test_multiple_classes_across_files(self, parser):
+        """Classes from several documents share one index; a class without `extends` gets no entry."""
+        docs = {
+            "a.js": Document(page_content="class A extends B { }", metadata={"source": "a.js"}),
+            "b.js": Document(page_content="class B extends C { }", metadata={"source": "b.js"}),
+            "c.js": Document(page_content="class C { }", metadata={"source": "c.js"}),
+        }
+        hierarchy = parser._build_class_hierarchy(docs)
+        assert hierarchy["A"] == ("B", "B")
+        assert hierarchy["B"] == ("C", "C")
+        assert "C" not in hierarchy
+
+    def test_es6_takes_precedence_over_prototype(self, parser):
+        """If both ES6 and prototype patterns exist, ES6 wins (indexed first)."""
+        docs = {
+            "a.js": Document(
+                page_content="class X extends Y { }\nutil.inherits(X, Z);",
+                metadata={"source": "a.js"})
+        }
+        hierarchy = parser._build_class_hierarchy(docs)
+        assert hierarchy["X"][1] == "Y"
+
+    def test_empty_documents(self, parser):
+        """An empty document dict produces an empty hierarchy."""
+        hierarchy = parser._build_class_hierarchy({})
+        assert hierarchy == {}
+
+
+class TestIsSubclassOfOptimized:
+    """Verify _is_subclass_of correctness with the hierarchy index."""
+
+    def test_direct_subclass(self, parser):
+        """A class is a subclass of the name in its own `extends` clause."""
+        docs = {
+            "a.js": Document(page_content="class Dog extends Animal { }",
+                             metadata={"source": "a.js"})
+        }
+        assert parser._is_subclass_of("Dog", "Animal", docs)
+
+    def test_not_subclass(self, parser):
+        """A name that does not appear in the hierarchy is not reported as a parent."""
+        docs = {
+            "a.js": Document(page_content="class Dog extends Animal { }",
+                             metadata={"source": "a.js"})
+        }
+        assert not parser._is_subclass_of("Dog", "Vehicle", docs)
+
+    def test_transitive_chain(self, parser):
+        """A → B → C declared in separate files makes A a subclass of C."""
+        docs = {
+            "a.js": Document(page_content="class A extends B { }", metadata={"source": "a.js"}),
+            "b.js": Document(page_content="class B extends C { }", metadata={"source": "b.js"}),
+        }
+        assert parser._is_subclass_of("A", "C", docs)
+
+    def test_mixin_match(self, parser):
+        """Both the mixin wrapper name and its innermost argument count as parents."""
+        docs = {
+            "a.js": Document(page_content="class X extends Mixin(Base) { }",
+                             metadata={"source": "a.js"})
+        }
+        assert parser._is_subclass_of("X", "Mixin", docs)
+        assert parser._is_subclass_of("X", "Base", docs)
+
+    def test_circular_reference_no_infinite_loop(self, parser):
+        """A cyclic hierarchy (A → B → A) must terminate and report no match for an unrelated name."""
+        docs = {
+            "a.js": Document(page_content="class A extends B { }", metadata={"source": "a.js"}),
+            "b.js": Document(page_content="class B extends A { }", metadata={"source": "b.js"}),
+        }
+        assert not parser._is_subclass_of("A", "Z", docs)
+
+    def test_empty_child_or_parent(self, parser):
+        """An empty or None child name, or an empty parent name, never matches."""
+        docs = {"a.js": Document(page_content="class X extends Y { }", metadata={"source": "a.js"})}
+        assert not parser._is_subclass_of("", "Y", docs)
+        assert not parser._is_subclass_of("X", "", docs)
+        assert not parser._is_subclass_of(None, "Y", docs)
+
+    def test_prototype_inheritance(self, parser):
+        """A `util.inherits` relationship is honoured by _is_subclass_of."""
+        docs = {
+            "a.js": Document(page_content="util.inherits(ReadStream, EventEmitter);",
+                             metadata={"source": "a.js"})
+        }
+        assert parser._is_subclass_of("ReadStream", "EventEmitter", docs)
+
+    def test_hierarchy_cache_reused(self, parser):
+        """Same code_documents dict should reuse cached hierarchy."""
+        docs = {
+            "a.js": Document(page_content="class A extends B { }", metadata={"source": "a.js"})
+        }
+        parser._is_subclass_of("A", "B", docs)
+        cache_key_1 = parser._class_hierarchy_cache_key
+
+        # Second query against the very same dict object.
+        parser._is_subclass_of("A", "C", docs)
+        cache_key_2 = parser._class_hierarchy_cache_key
+
+        assert cache_key_1 == cache_key_2
+
+    def test_different_docs_rebuilds_cache(self, parser):
+        """Querying with a different code_documents dict must change the cache key."""
+        docs1 = {"a.js": Document(page_content="class A extends B { }", metadata={"source": "a.js"})}
+        docs2 = {"a.js": Document(page_content="class X extends Y { }", metadata={"source": "a.js"})}
+
+        parser._is_subclass_of("A", "B", docs1)
+        key1 = parser._class_hierarchy_cache_key
+
+        parser._is_subclass_of("X", "Y", docs2)
+        key2 = parser._class_hierarchy_cache_key
+
+        assert key1 != key2
+
+
+class TestGetParentOptimized:
+    """Verify _get_direct_parent, _get_prototype_parent, _get_parent with index."""
+
+    def test_get_direct_parent(self, parser):
+        """A plain `extends Y` returns 'Y'."""
+        docs = {"a.js": Document(page_content="class X extends Y { }", metadata={"source": "a.js"})}
+        assert parser._get_direct_parent("X", docs) == "Y"
+
+    def test_get_direct_parent_mixin(self, parser):
+        """`extends Mixin(Base)` returns the innermost name 'Base'."""
+        docs = {"a.js": Document(page_content="class X extends Mixin(Base) { }", metadata={"source": "a.js"})}
+        assert parser._get_direct_parent("X", docs) == "Base"
+
+    def test_get_direct_parent_not_found(self, parser):
+        """A class without `extends` has no direct parent (None)."""
+        docs = {"a.js": Document(page_content="class X { }", metadata={"source": "a.js"})}
+        assert parser._get_direct_parent("X", docs) is None
+
+    def test_get_prototype_parent(self, parser):
+        """`util.inherits(Child, Parent)` returns 'Parent'."""
+        docs = {"a.js": Document(page_content="util.inherits(Child, Parent);", metadata={"source": "a.js"})}
+        assert parser._get_prototype_parent("Child", docs) == "Parent"
+
+    def test_get_prototype_parent_not_found(self, parser):
+        """A class declared only with ES6 `extends` has no prototype parent (None)."""
+        docs = {"a.js": Document(page_content="class X extends Y { }", metadata={"source": "a.js"})}
+        assert parser._get_prototype_parent("X", docs) is None
+
+    def test_get_parent_prefers_es6(self, parser):
+        """When both forms are present, _get_parent returns the ES6 parent 'Y'."""
+        docs = {"a.js": Document(
+            page_content="class X extends Y { }\nutil.inherits(X, Z);",
+            metadata={"source": "a.js"})}
+        assert parser._get_parent("X", docs) == "Y"
+
+
+# =============================================================================
+# Bug: _build_class_hierarchy uses \w+ which misses $ in JS identifiers
+# =============================================================================
+
+class TestDollarSignInClassHierarchy:
+    """$ is a valid JS identifier char — hierarchy builder must handle it."""
+
+    def test_dollar_prefixed_es6_class(self, parser):
+        """class $Component extends Base should be in the hierarchy."""
+        docs = {
+            "a.js": Document(page_content="class $Component extends Base { }",
+                             metadata={"source": "a.js"})
+        }
+        hierarchy = parser._build_class_hierarchy(docs)
+        assert "$Component" in hierarchy, f"$Component not found in hierarchy: {hierarchy}"
+        _, parent = hierarchy["$Component"]
+        assert parent == "Base"
+
+    def test_dollar_prefixed_parent(self, parser):
+        """class Child extends $Base should resolve $Base as parent."""
+        docs = {
+            "a.js": Document(page_content="class Child extends $Base { }",
+                             metadata={"source": "a.js"})
+        }
+        hierarchy = parser._build_class_hierarchy(docs)
+        assert "Child" in hierarchy
+        _, parent = hierarchy["Child"]
+        assert parent == "$Base"
+
+    def test_dollar_prefixed_mixin_parent(self, parser):
+        """class Foo extends Mixin($Bar) should resolve $Bar as innermost parent."""
+        docs = {
+            "a.js": Document(page_content="class Foo extends Mixin($Bar) { }",
+                             metadata={"source": "a.js"})
+        }
+        hierarchy = parser._build_class_hierarchy(docs)
+        assert "Foo" in hierarchy
+        _, parent = hierarchy["Foo"]
+        assert parent == "$Bar"
+
+    def test_dollar_prefixed_prototype_inherits(self, parser):
+        """util.inherits($Stream, EventEmitter) should capture $Stream."""
+        docs = {
+            "a.js": Document(page_content="util.inherits($Stream, EventEmitter);",
+                             metadata={"source": "a.js"})
+        }
+        hierarchy = parser._build_class_hierarchy(docs)
+        assert "$Stream" in hierarchy
+        assert hierarchy["$Stream"] == (None, "EventEmitter")
+
+    def test_is_subclass_of_with_dollar(self, parser):
+        """_is_subclass_of should work with $-prefixed class names."""
+        docs = {
+            "a.js": Document(page_content="class $Widget extends Component { }",
+                             metadata={"source": "a.js"})
+        }
+        assert parser._is_subclass_of("$Widget", "Component", docs)
+
+    def test_dollar_in_both_child_and_parent(self, parser):
+        """Both child and parent have $ prefix."""
+        docs = {
+            "a.js": Document(page_content="class $Child extends $Parent { }",
+                             metadata={"source": "a.js"})
+        }
+        hierarchy = parser._build_class_hierarchy(docs)
+        assert "$Child" in hierarchy
+        _, parent = hierarchy["$Child"]
+        assert parent == "$Parent"
+
+
+# =============================================================================
+# Bug: recursive _get_function_calls omits code_documents — chained alias broken
+# =============================================================================
+
+class TestChainedAliasResolution:
+    """Recursive _get_function_calls must pass code_documents for chained aliases."""
+
+    def test_alias_of_alias_es6(self, parser):
+        """import { orig as mid } then import { mid as final } — final() should match orig."""
+        # The caller document holds only the function body ...
+        caller_code = "function handler() {\n final();\n}"
+        caller_doc = Document(
+            page_content=caller_code,
+            metadata={"source": "handler.js"}
+        )
+
+        # ... while the document registered under the same source path in
+        # code_documents also carries the import lines that define the aliases.
+        full_file_code = (
+            "import { orig as mid } from 'pkg';\n"
+            "import { mid as final } from './re-export';\n"
+            "\n" + caller_code
+        )
+        full_file_doc = Document(
+            page_content=full_file_code,
+            metadata={"source": "handler.js"}
+        )
+
+        code_documents = {"handler.js": full_file_doc}
+
+        calls = parser._get_function_calls(caller_doc, "orig", code_documents)
+        assert "final" in calls, (
+            f"Chained alias orig→mid→final should resolve. Got: {calls}"
+        )
+
+    def test_alias_commonjs_chain(self, parser):
+        """const { orig: mid } = require('a') then const { mid: local } = require('b')."""
+        # Same layout as the ES6 case: body-only caller document, plus a
+        # full-file document (with the require lines) in code_documents.
+        caller_code = "function run() {\n local();\n}"
+        caller_doc = Document(
+            page_content=caller_code,
+            metadata={"source": "run.js"}
+        )
+
+        full_file_code = (
+            "const { orig: mid } = require('a');\n"
+            "const { mid: local } = require('b');\n"
+            "\n" + caller_code
+        )
+        full_file_doc = Document(
+            page_content=full_file_code,
+            metadata={"source": "run.js"}
+        )
+
+        code_documents = {"run.js": full_file_doc}
+
+        calls = parser._get_function_calls(caller_doc, "orig", code_documents)
+        assert "local" in calls, (
+            f"Chained CommonJS alias orig→mid→local should resolve. Got: {calls}"
+        )
+