From 680b2ace562e75582a28eda1ab7a53df002df7f2 Mon Sep 17 00:00:00 2001 From: dhirenmathur Date: Tue, 12 Nov 2024 23:39:03 +0530 Subject: [PATCH 1/5] java parsing --- app/modules/auth/auth_service.py | 3 +- .../graph_construction/code_graph_service.py | 65 +-- .../graph_construction/parsing_controller.py | 15 +- .../graph_construction/parsing_repomap.py | 378 +++++++++--------- .../graph_construction/parsing_service.py | 2 +- .../queries/tree-sitter-java-tags.scm | 39 +- 6 files changed, 277 insertions(+), 225 deletions(-) diff --git a/app/modules/auth/auth_service.py b/app/modules/auth/auth_service.py index f6649b50..15965baf 100644 --- a/app/modules/auth/auth_service.py +++ b/app/modules/auth/auth_service.py @@ -42,10 +42,11 @@ async def check_auth( HTTPBearer(auto_error=False) ), ): + return {"user_id": "WKyrZNjOflYSr9q8Jm7JcHqqwSr1", "email": "test@test.com"} + # Check if the application is in debug mode if os.getenv("isDevelopmentMode") == "enabled" and credential is None: request.state.user = {"user_id": os.getenv("defaultUsername")} - return {"user_id": os.getenv("defaultUsername")} else: if credential is None: raise HTTPException( diff --git a/app/modules/parsing/graph_construction/code_graph_service.py b/app/modules/parsing/graph_construction/code_graph_service.py index 18f651da..f1f4afd7 100644 --- a/app/modules/parsing/graph_construction/code_graph_service.py +++ b/app/modules/parsing/graph_construction/code_graph_service.py @@ -1,7 +1,7 @@ import hashlib import logging from typing import Dict, Optional - +import time from neo4j import GraphDatabase from sqlalchemy.orm import Session @@ -43,10 +43,16 @@ def create_and_store_graph(self, repo_dir, project_id, user_id): nx_graph = self.repo_map.create_graph(repo_dir) with self.driver.session() as session: - # Create nodes - import time + # First, clear any existing data for this project + session.run( + """ + MATCH (n {repoId: $project_id}) + DETACH DELETE n + """, + project_id=project_id, + ) - start_time = time.time() # Start timing + start_time = time.time() node_count = nx_graph.number_of_nodes() logging.info(f"Creating {node_count} nodes") @@ -55,25 +61,38 @@ def create_and_store_graph(self, repo_dir, project_id, user_id): for i in range(0, node_count, batch_size): batch_nodes = list(nx_graph.nodes(data=True))[i : i + batch_size] nodes_to_create = [] - for node in batch_nodes: - node_type = node[1].get("type") - label = node_type.capitalize() if node_type else "UNKNOWN" - node_data = { - "name": node[0], - "file_path": node[1].get("file", ""), - "start_line": node[1].get("line", -1), - "end_line": node[1].get("end_line", -1), + + for node_id, node_data in batch_nodes: + # Get the node type and ensure it's one of our expected types + node_type = node_data.get("type", "UNKNOWN") + if node_type == "UNKNOWN": + continue + # Initialize labels with NODE + labels = ["NODE"] + + # Add specific type label if it's a valid type + if node_type in ["FILE", "CLASS", "FUNCTION", "INTERFACE"]: + labels.append(node_type) + + # Prepare node data + processed_node = { + "name": node_data.get("name", node_id), # Use node_id as fallback + "file_path": node_data.get("file", ""), + "start_line": node_data.get("line", -1), + "end_line": node_data.get("end_line", -1), "repoId": project_id, - "node_id": CodeGraphService.generate_node_id(node[0], user_id), + "node_id": CodeGraphService.generate_node_id(node_id, user_id), "entityId": user_id, - "type": node_type if node_type else "Unknown", - "text": node[1].get("text", ""), - "labels": ["NODE", label], + 
"type": node_type, + "text": node_data.get("text", ""), + "labels": labels, } - # Remove any null values from node_data - node_data = {k: v for k, v in node_data.items() if v is not None} - nodes_to_create.append(node_data) + + # Remove None values + processed_node = {k: v for k, v in processed_node.items() if v is not None} + nodes_to_create.append(processed_node) + # Create nodes with labels session.run( """ UNWIND $nodes AS node @@ -112,11 +131,9 @@ def create_and_store_graph(self, repo_dir, project_id, user_id): edges=edges_to_create, ) - end_time = time.time() # End timing - logging.info( - f"Time taken to create graph and search index: {end_time - start_time:.2f} seconds" - ) - + end_time = time.time() + logging.info(f"Time taken to create graph and search index: {end_time - start_time:.2f} seconds") + def cleanup_graph(self, project_id: str): with self.driver.session() as session: session.run( diff --git a/app/modules/parsing/graph_construction/parsing_controller.py b/app/modules/parsing/graph_construction/parsing_controller.py index a15386e5..c46724d3 100644 --- a/app/modules/parsing/graph_construction/parsing_controller.py +++ b/app/modules/parsing/graph_construction/parsing_controller.py @@ -60,12 +60,13 @@ async def parse_directory( response = {"project_id": project_id, "status": project_status} # Check commit status - is_latest = ( - await parse_helper.check_commit_status(project_id) - if not demo_project - else True - ) + # is_latest = ( + # await parse_helper.check_commit_status(project_id) + # if not demo_project + # else True + # ) + is_latest = False if not is_latest or project_status != ProjectStatusEnum.READY.value: cleanup_graph = True @@ -178,9 +179,9 @@ async def handle_new_project( } logger.info(f"Submitting parsing task for new project {new_project_id}") - + repo_name = repo_details.repo_name or repo_details.repo_path.split("/")[-1] await project_manager.register_project( - repo_details.repo_name, repo_details.branch_name, user_id, new_project_id + repo_name, repo_details.branch_name, user_id, new_project_id ) asyncio.create_task( GithubService(db).get_project_structure_async(new_project_id) diff --git a/app/modules/parsing/graph_construction/parsing_repomap.py b/app/modules/parsing/graph_construction/parsing_repomap.py index 1c6daf03..d16b6709 100644 --- a/app/modules/parsing/graph_construction/parsing_repomap.py +++ b/app/modules/parsing/graph_construction/parsing_repomap.py @@ -160,26 +160,47 @@ def get_tags_raw(self, fname, rel_fname): # Run the tags queries query = language.query(query_scm) captures = query.captures(tree.root_node) - captures = list(captures) - saw = set() + + # Enhanced debugging + print(f"Processing file: {fname}") + print(f"Language detected: {lang}") + for node, tag in captures: + node_text = node.text.decode('utf-8') + print(f"Captured node: {node_text} with tag: {tag}") + if tag.startswith("name.definition."): kind = "def" - type = tag.split(".")[-1] # + type = tag.split(".")[-1] + # Special handling for Java methods + if type == "method": + print(f"Found method definition: {node_text}") elif tag.startswith("name.reference."): kind = "ref" - type = tag.split(".")[-1] # + type = tag.split(".")[-1] + # Special handling for Java method calls + if type == "method": + print(f"Found method reference: {node_text}") else: continue saw.add(kind) + # Enhanced node text extraction for Java methods + if lang == "java" and type == "method": + # Handle method calls with object references (e.g., productService.listAllProducts()) + parent = node.parent + 
if parent and parent.type == "method_invocation": + object_node = parent.child_by_field_name("object") + if object_node: + node_text = f"{object_node.text.decode('utf-8')}.{node_text}" + result = Tag( rel_fname=rel_fname, fname=fname, - name=node.text.decode("utf-8"), + name=node_text, kind=kind, line=node.start_point[0], end_line=node.end_point[0], @@ -524,205 +545,202 @@ def render_tree(self, abs_fname, rel_fname, lois): self.tree_cache[key] = res return res - def create_graph(self, repo_dir): - start_time = time.time() - logging.info("Starting parsing of codebase") + def create_relationship(G, source, target, relationship_type, seen_relationships, extra_data=None): + """Helper to create relationships with proper direction checking""" + if source == target: + return False + + # Determine correct direction based on node types + source_data = G.nodes[source] + target_data = G.nodes[target] + + # Prevent duplicate bidirectional relationships + rel_key = (source, target, relationship_type) + reverse_key = (target, source, relationship_type) + + if rel_key in seen_relationships or reverse_key in seen_relationships: + return False + + # Only create relationship if we have right direction: + # 1. Interface method implementations should point to interface declaration + # 2. Method calls should point to method definitions + # 3. Class references should point to class definitions + valid_direction = False + + if relationship_type == "REFERENCES": + # Implementation -> Interface + if (source_data.get('type') == 'FUNCTION' and + target_data.get('type') == 'FUNCTION' and + 'Impl' in source): # Implementation class + valid_direction = True + + # Caller -> Callee + elif source_data.get('type') == 'FUNCTION': + valid_direction = True + + # Class Usage -> Class Definition + elif target_data.get('type') == 'CLASS': + valid_direction = True + + if valid_direction: + G.add_edge(source, target, + type=relationship_type, + **(extra_data or {})) + seen_relationships.add(rel_key) + return True + + return False - G = nx.MultiDiGraph() - defines = defaultdict(list) - references = defaultdict(list) - file_count = 0 + def create_graph(self, repo_dir): + G = nx.MultiDiGraph() + defines = defaultdict(set) + references = defaultdict(set) + seen_relationships = set() + + logging.info("Starting graph creation with detailed debugging...") + for root, dirs, files in os.walk(repo_dir): - # Ignore folders starting with '.' 
if any(part.startswith(".") for part in root.split(os.sep)): continue - + for file in files: - file_count += 1 - file_path = os.path.join(root, file) rel_path = os.path.relpath(file_path, repo_dir) if not self.parse_helper.is_text_file(file_path): continue - tags = self.get_tags(file_path, rel_path) - - # Extract full file content - file_content = self.io.read_text(file_path) or "" - if not file_content.endswith("\n"): - file_content += "\n" - - # Parse the file using tree-sitter - language = RepoMap.get_language_for_file(file_path) - if language: - parser = Parser() - parser.set_language(language) - tree = parser.parse(bytes(file_content, "utf8")) - root_node = tree.root_node + logging.info(f"\nProcessing file: {rel_path}") + + # Add file node + file_node_name = rel_path + if not G.has_node(file_node_name): + G.add_node( + file_node_name, + file=rel_path, + type="FILE", + text=self.io.read_text(file_path) or "", + line=0, + end_line=0, + name=rel_path.split("/")[-1], + ) + logging.info(f"Added FILE node: {file_node_name}") current_class = None - current_function = None - for tag in tags: + current_method = None + + # Process all tags in file + for tag in self.get_tags(file_path, rel_path): + logging.debug(f"Processing tag: {tag.kind} {tag.name} (type: {tag.type})") + if tag.kind == "def": if tag.type == "class": + node_type = "CLASS" + current_class = tag.name + current_method = None + logging.debug(f"Entered class context: {current_class}") + elif tag.type == "interface": + node_type = "INTERFACE" current_class = tag.name - current_function = None - node_type = "class" - elif tag.type == "function": - current_function = tag.name - node_type = "function" + current_method = None + logging.debug(f"Entered interface context: {current_class}") + elif tag.type in ["method", "function"]: + node_type = "FUNCTION" + current_method = tag.name + logging.debug(f"Entered method context: {current_method} in class {current_class}") else: - node_type = "other" - - node_name = f"{rel_path}:{tag.name}" + continue - # Extract code for the current tag using AST - if language: - node = RepoMap.find_node_by_range( - root_node, tag.line, node_type - ) - if node: - code_context = file_content[ - node.start_byte : node.end_byte - ] - node_end_line = ( - node.end_point[0] + 1 - ) # Adding 1 to match 1-based line numbering - else: - code_context = "" - node_end_line = tag.end_line - continue + # Create fully qualified node name + if current_class: + node_name = f"{rel_path}:{current_class}.{tag.name}" else: - code_context = "" - node_end_line = tag.end_line - continue + node_name = f"{rel_path}:{tag.name}" + + logging.info(f"Creating {node_type} node: {node_name}") - defines[tag.name].append( - ( + # Add node + if not G.has_node(node_name): + G.add_node( node_name, - tag.line, - node_end_line, - node_type, - rel_path, - current_class, - ) - ) - G.add_node( - node_name, - file=rel_path, - line=tag.line, - end_line=node_end_line, - type=tag.type, - text=code_context, - ) - elif tag.kind == "ref": - source = ( - f"{current_class}.{current_function}" - if current_class and current_function - else ( - f"{rel_path}:{current_function}" - if current_function - else rel_path - ) - ) - references[tag.name].append( - ( - source, - tag.line, - tag.end_line, - tag.type, - rel_path, - current_class, + file=rel_path, + line=tag.line, + end_line=tag.end_line, + type=node_type, + name=tag.name, + class_name=current_class ) - ) - - # Add a node for the entire file - G.add_node( - rel_path, - file=rel_path, - type="file", - 
text=file_content, - ) + + # Add CONTAINS relationship from file + rel_key = (file_node_name, node_name, "CONTAINS") + if rel_key not in seen_relationships: + G.add_edge( + file_node_name, + node_name, + type="CONTAINS", + ident=tag.name + ) + seen_relationships.add(rel_key) + logging.info(f"Added CONTAINS relationship: {file_node_name} -> {node_name}") + + # Record definition + defines[tag.name].add(node_name) + logging.debug(f"Recorded definition of {tag.name} as {node_name}") + elif tag.kind == "ref": + # Handle references + if current_class and current_method: + source = f"{rel_path}:{current_class}.{current_method}" + elif current_method: + source = f"{rel_path}:{current_method}" + else: + source = rel_path + + logging.debug(f"Found reference to {tag.name} from {source}") + references[tag.name].add(( + source, + tag.line, + tag.end_line, + current_class, + current_method + )) + + logging.info("\nDefinitions collected:") + for ident, nodes in defines.items(): + logging.info(f"{ident}: {nodes}") + + logging.info("\nReferences collected:") for ident, refs in references.items(): - if ident in defines: - if len(defines[ident]) == 1: - target, def_line, end_def_line, def_type, def_file, def_class = ( - defines[ident][0] - ) - for ( - source, - ref_line, - end_ref_line, - ref_type, - ref_file, - ref_class, - ) in refs: - G.add_edge( - source, - target, - type=ref_type, - ident=ident, - ref_line=ref_line, - end_ref_line=end_ref_line, - def_line=def_line, - end_def_line=end_def_line, - ) - else: - for ( - source, - ref_line, - end_ref_line, - ref_type, - ref_file, - ref_class, - ) in refs: - best_match = None - best_match_score = -1 - for ( - target, - def_line, - end_def_line, - def_type, - def_file, - def_class, - ) in defines[ident]: - if source != target: - match_score = 0 - if ref_file == def_file: - match_score += 2 - elif os.path.dirname(ref_file) == os.path.dirname( - def_file - ): - match_score += 1 - if ref_class == def_class: - match_score += 1 - if match_score > best_match_score: - best_match = ( - target, - def_line, - end_def_line, - def_type, - ) - best_match_score = match_score - - if best_match: - target, def_line, end_def_line, def_type = best_match - G.add_edge( - source, - target, - type=ref_type, - ident=ident, - ref_line=ref_line, - end_ref_line=end_ref_line, - def_line=def_line, - end_def_line=end_def_line, - ) + logging.info(f"{ident}: {refs}") - end_time = time.time() - logging.info(f"Parsing completed, time taken: {end_time - start_time} seconds") + logging.info("\nCreating REFERENCES relationships:") + # Second pass - create REFERENCES relationships + for ident, refs in references.items(): + target_nodes = defines.get(ident, set()) + logging.info(f"\nProcessing references to {ident}") + logging.info(f"Target nodes: {target_nodes}") + + for source, line, end_line, src_class, src_method in refs: + logging.info(f"Processing reference from {source}") + + for target in target_nodes: + if source == target: + logging.debug(f"Skipping self-reference: {source} -> {target}") + continue + + if G.has_node(source) and G.has_node(target): + RepoMap.create_relationship(G, source, target, "REFERENCES", + seen_relationships, + {"ident": ident, + "ref_line": line, + "end_ref_line": end_line}) + + logging.info("\nFinal graph statistics:") + logging.info(f"Nodes: {G.number_of_nodes()}") + logging.info(f"Edges: {G.number_of_edges()}") + logging.info(f"Unique relationships tracked: {len(seen_relationships)}") + return G @staticmethod @@ -755,9 +773,9 @@ def 
get_language_for_file(file_path): def find_node_by_range(root_node, start_line, node_type): def traverse(node): if node.start_point[0] <= start_line and node.end_point[0] >= start_line: - if node_type == "function" and node.type == "function_definition": + if node_type == "FUNCTION" and node.type in ["function_definition", "method","method_declaration", "function"]: return node - elif node_type == "class" and node.type == "class_definition": + elif node_type in ["CLASS", "INTERFACE"] and node.type in ["class_definition", "interface", "class", "class_declaration", "interface_declaration"]: return node for child in node.children: result = traverse(child) diff --git a/app/modules/parsing/graph_construction/parsing_service.py b/app/modules/parsing/graph_construction/parsing_service.py index fcc099f8..a2024c3f 100644 --- a/app/modules/parsing/graph_construction/parsing_service.py +++ b/app/modules/parsing/graph_construction/parsing_service.py @@ -207,7 +207,7 @@ async def analyze_directory( project_id, ProjectStatusEnum.PARSED ) # Generate docstrings using InferenceService - await self.inference_service.run_inference(project_id) + #await self.inference_service.run_inference(project_id) logger.info(f"DEBUGNEO4J: After inference project {project_id}") self.inference_service.log_graph_stats(project_id) await self.project_service.update_project_status( diff --git a/app/modules/parsing/graph_construction/queries/tree-sitter-java-tags.scm b/app/modules/parsing/graph_construction/queries/tree-sitter-java-tags.scm index 3b7290d4..fdd00d21 100644 --- a/app/modules/parsing/graph_construction/queries/tree-sitter-java-tags.scm +++ b/app/modules/parsing/graph_construction/queries/tree-sitter-java-tags.scm @@ -1,20 +1,35 @@ -(class_declaration - name: (identifier) @name.definition.class) @definition.class - +; Methods (method_declaration - name: (identifier) @name.definition.method) @definition.method + name: (identifier) @name.definition.method) +; Method invocations (method_invocation - name: (identifier) @name.reference.call - arguments: (argument_list) @reference.call) + name: (identifier) @name.reference.method) + +; Class definitions +(class_declaration + name: (identifier) @name.definition.class) +; Interface definitions (interface_declaration - name: (identifier) @name.definition.interface) @definition.interface + name: (identifier) @name.definition.interface) + +; Field declarations +(field_declaration + declarator: (variable_declarator + name: (identifier) @name.definition.field)) + +; Variable declarations +(local_variable_declaration + declarator: (variable_declarator + name: (identifier) @name.definition.variable)) -(type_list - (type_identifier) @name.reference.implementation) @reference.implementation +; Method parameters +(formal_parameter + name: (identifier) @name.definition.parameter) -(object_creation_expression - type: (type_identifier) @name.reference.class) @reference.class +; References to types +(type_identifier) @name.reference.type -(superclass (type_identifier) @name.reference.class) @reference.class +; References to variables and fields +(identifier) @name.reference.variable From de287ab25c1f6406701873e54a9eaac5c6dffaea Mon Sep 17 00:00:00 2001 From: dhirenmathur Date: Mon, 9 Dec 2024 21:36:25 +0530 Subject: [PATCH 2/5] cleanup wip --- app/modules/auth/auth_service.py | 2 +- .../graph_construction/parsing_controller.py | 11 ++++---- .../graph_construction/parsing_repomap.py | 27 +------------------ 3 files changed, 7 insertions(+), 33 deletions(-) diff --git 
a/app/modules/auth/auth_service.py b/app/modules/auth/auth_service.py index 15965baf..5a49e51e 100644 --- a/app/modules/auth/auth_service.py +++ b/app/modules/auth/auth_service.py @@ -42,11 +42,11 @@ async def check_auth( HTTPBearer(auto_error=False) ), ): - return {"user_id": "WKyrZNjOflYSr9q8Jm7JcHqqwSr1", "email": "test@test.com"} # Check if the application is in debug mode if os.getenv("isDevelopmentMode") == "enabled" and credential is None: request.state.user = {"user_id": os.getenv("defaultUsername")} + return {"user_id": os.getenv("defaultUsername")} else: if credential is None: raise HTTPException( diff --git a/app/modules/parsing/graph_construction/parsing_controller.py b/app/modules/parsing/graph_construction/parsing_controller.py index c46724d3..e5cde74d 100644 --- a/app/modules/parsing/graph_construction/parsing_controller.py +++ b/app/modules/parsing/graph_construction/parsing_controller.py @@ -60,13 +60,12 @@ async def parse_directory( response = {"project_id": project_id, "status": project_status} # Check commit status - # is_latest = ( - # await parse_helper.check_commit_status(project_id) - # if not demo_project - # else True - # ) + is_latest = ( + await parse_helper.check_commit_status(project_id) + if not demo_project + else True + ) - is_latest = False if not is_latest or project_status != ProjectStatusEnum.READY.value: cleanup_graph = True diff --git a/app/modules/parsing/graph_construction/parsing_repomap.py b/app/modules/parsing/graph_construction/parsing_repomap.py index d16b6709..9d889ed1 100644 --- a/app/modules/parsing/graph_construction/parsing_repomap.py +++ b/app/modules/parsing/graph_construction/parsing_repomap.py @@ -598,7 +598,6 @@ def create_graph(self, repo_dir): references = defaultdict(set) seen_relationships = set() - logging.info("Starting graph creation with detailed debugging...") for root, dirs, files in os.walk(repo_dir): if any(part.startswith(".") for part in root.split(os.sep)): @@ -632,23 +631,19 @@ def create_graph(self, repo_dir): # Process all tags in file for tag in self.get_tags(file_path, rel_path): - logging.debug(f"Processing tag: {tag.kind} {tag.name} (type: {tag.type})") if tag.kind == "def": if tag.type == "class": node_type = "CLASS" current_class = tag.name current_method = None - logging.debug(f"Entered class context: {current_class}") elif tag.type == "interface": node_type = "INTERFACE" current_class = tag.name current_method = None - logging.debug(f"Entered interface context: {current_class}") elif tag.type in ["method", "function"]: node_type = "FUNCTION" current_method = tag.name - logging.debug(f"Entered method context: {current_method} in class {current_class}") else: continue @@ -658,7 +653,6 @@ def create_graph(self, repo_dir): else: node_name = f"{rel_path}:{tag.name}" - logging.info(f"Creating {node_type} node: {node_name}") # Add node if not G.has_node(node_name): @@ -682,11 +676,9 @@ def create_graph(self, repo_dir): ident=tag.name ) seen_relationships.add(rel_key) - logging.info(f"Added CONTAINS relationship: {file_node_name} -> {node_name}") # Record definition defines[tag.name].add(node_name) - logging.debug(f"Recorded definition of {tag.name} as {node_name}") elif tag.kind == "ref": # Handle references @@ -697,7 +689,6 @@ def create_graph(self, repo_dir): else: source = rel_path - logging.debug(f"Found reference to {tag.name} from {source}") references[tag.name].add(( source, tag.line, @@ -706,23 +697,11 @@ def create_graph(self, repo_dir): current_method )) - logging.info("\nDefinitions collected:") - for 
ident, nodes in defines.items(): - logging.info(f"{ident}: {nodes}") - logging.info("\nReferences collected:") - for ident, refs in references.items(): - logging.info(f"{ident}: {refs}") - - logging.info("\nCreating REFERENCES relationships:") - # Second pass - create REFERENCES relationships for ident, refs in references.items(): target_nodes = defines.get(ident, set()) - logging.info(f"\nProcessing references to {ident}") - logging.info(f"Target nodes: {target_nodes}") - + for source, line, end_line, src_class, src_method in refs: - logging.info(f"Processing reference from {source}") for target in target_nodes: if source == target: @@ -736,10 +715,6 @@ def create_graph(self, repo_dir): "ref_line": line, "end_ref_line": end_line}) - logging.info("\nFinal graph statistics:") - logging.info(f"Nodes: {G.number_of_nodes()}") - logging.info(f"Edges: {G.number_of_edges()}") - logging.info(f"Unique relationships tracked: {len(seen_relationships)}") return G From 8509d6870e7a12c0eac00478c3f8049805e9fab8 Mon Sep 17 00:00:00 2001 From: dhirenmathur Date: Fri, 13 Dec 2024 12:08:32 +0530 Subject: [PATCH 3/5] undo commented out inference step --- app/modules/parsing/graph_construction/parsing_service.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/app/modules/parsing/graph_construction/parsing_service.py b/app/modules/parsing/graph_construction/parsing_service.py index c0764aaf..3678818c 100644 --- a/app/modules/parsing/graph_construction/parsing_service.py +++ b/app/modules/parsing/graph_construction/parsing_service.py @@ -234,7 +234,7 @@ async def analyze_directory( project_id, ProjectStatusEnum.PARSED ) # Generate docstrings using InferenceService - #await self.inference_service.run_inference(project_id) + await self.inference_service.run_inference(project_id) logger.info(f"DEBUGNEO4J: After inference project {project_id}") self.inference_service.log_graph_stats(project_id) await self.project_service.update_project_status( From 0321d06c1bf0c304058751008ac4262bf0476636 Mon Sep 17 00:00:00 2001 From: dhirenmathur Date: Fri, 13 Dec 2024 12:58:51 +0530 Subject: [PATCH 4/5] precommit --- app/modules/auth/auth_service.py | 1 - .../code_provider/github/github_router.py | 5 +- .../graph_construction/code_graph_service.py | 25 +-- .../graph_construction/parsing_controller.py | 23 ++- .../graph_construction/parsing_repomap.py | 145 +++++++++--------- app/modules/usage/usage_controller.py | 10 +- app/modules/usage/usage_router.py | 9 +- app/modules/usage/usage_schema.py | 6 +- app/modules/usage/usage_service.py | 16 +- 9 files changed, 130 insertions(+), 110 deletions(-) diff --git a/app/modules/auth/auth_service.py b/app/modules/auth/auth_service.py index 97011008..d03692cc 100644 --- a/app/modules/auth/auth_service.py +++ b/app/modules/auth/auth_service.py @@ -49,7 +49,6 @@ async def check_auth( HTTPBearer(auto_error=False) ), ): - # Check if the application is in debug mode if os.getenv("isDevelopmentMode") == "enabled" and credential is None: request.state.user = {"user_id": os.getenv("defaultUsername")} diff --git a/app/modules/code_provider/github/github_router.py b/app/modules/code_provider/github/github_router.py index 7bdaf559..4a8433e5 100644 --- a/app/modules/code_provider/github/github_router.py +++ b/app/modules/code_provider/github/github_router.py @@ -16,7 +16,7 @@ async def get_user_repos( ): user_repo_list = await GithubController(db).get_user_repos(user=user) user_repo_list["repositories"].extend(config_provider.get_demo_repo_list()) - + # Remove duplicates while 
preserving order seen = set() deduped_repos = [] @@ -24,11 +24,10 @@ async def get_user_repos( # Create tuple of values to use as hash key repo_key = repo["full_name"] - if repo_key not in seen: seen.add(repo_key) deduped_repos.append(repo) - + user_repo_list["repositories"] = deduped_repos return user_repo_list diff --git a/app/modules/parsing/graph_construction/code_graph_service.py b/app/modules/parsing/graph_construction/code_graph_service.py index f1f4afd7..455bdc53 100644 --- a/app/modules/parsing/graph_construction/code_graph_service.py +++ b/app/modules/parsing/graph_construction/code_graph_service.py @@ -1,7 +1,8 @@ import hashlib import logging +import time from typing import Dict, Optional -import time + from neo4j import GraphDatabase from sqlalchemy.orm import Session @@ -61,7 +62,7 @@ def create_and_store_graph(self, repo_dir, project_id, user_id): for i in range(0, node_count, batch_size): batch_nodes = list(nx_graph.nodes(data=True))[i : i + batch_size] nodes_to_create = [] - + for node_id, node_data in batch_nodes: # Get the node type and ensure it's one of our expected types node_type = node_data.get("type", "UNKNOWN") @@ -69,14 +70,16 @@ def create_and_store_graph(self, repo_dir, project_id, user_id): continue # Initialize labels with NODE labels = ["NODE"] - + # Add specific type label if it's a valid type if node_type in ["FILE", "CLASS", "FUNCTION", "INTERFACE"]: labels.append(node_type) - + # Prepare node data processed_node = { - "name": node_data.get("name", node_id), # Use node_id as fallback + "name": node_data.get( + "name", node_id + ), # Use node_id as fallback "file_path": node_data.get("file", ""), "start_line": node_data.get("line", -1), "end_line": node_data.get("end_line", -1), @@ -87,9 +90,11 @@ def create_and_store_graph(self, repo_dir, project_id, user_id): "text": node_data.get("text", ""), "labels": labels, } - + # Remove None values - processed_node = {k: v for k, v in processed_node.items() if v is not None} + processed_node = { + k: v for k, v in processed_node.items() if v is not None + } nodes_to_create.append(processed_node) # Create nodes with labels @@ -132,8 +137,10 @@ def create_and_store_graph(self, repo_dir, project_id, user_id): ) end_time = time.time() - logging.info(f"Time taken to create graph and search index: {end_time - start_time:.2f} seconds") - + logging.info( + f"Time taken to create graph and search index: {end_time - start_time:.2f} seconds" + ) + def cleanup_graph(self, project_id: str): with self.driver.session() as session: session.run( diff --git a/app/modules/parsing/graph_construction/parsing_controller.py b/app/modules/parsing/graph_construction/parsing_controller.py index 97b264f2..1e83b6a1 100644 --- a/app/modules/parsing/graph_construction/parsing_controller.py +++ b/app/modules/parsing/graph_construction/parsing_controller.py @@ -75,7 +75,7 @@ async def parse_directory( project = await project_manager.get_project_from_db( repo_name, repo_details.branch_name, user_id ) - + # First check if this is a demo project that hasn't been accessed by this user yet if not project and repo_details.repo_name in demo_repos: existing_project = await project_manager.get_global_project_from_db( @@ -134,15 +134,17 @@ async def parse_directory( project_manager, db, ) - + # Handle existing projects (including previously duplicated demo projects) if project: project_id = project.id is_latest = await parse_helper.check_commit_status(project_id) - + if not is_latest or project.status != ProjectStatusEnum.READY.value: cleanup_graph = True 
- logger.info(f"Submitting parsing task for existing project {project_id}") + logger.info( + f"Submitting parsing task for existing project {project_id}" + ) process_parsing.delay( repo_details.model_dump(), user_id, @@ -150,8 +152,10 @@ async def parse_directory( project_id, cleanup_graph, ) - - await project_manager.update_project_status(project_id, ProjectStatusEnum.SUBMITTED) + + await project_manager.update_project_status( + project_id, ProjectStatusEnum.SUBMITTED + ) PostHogClient().send_event( user_id, "parsed_repo_event", @@ -161,8 +165,11 @@ async def parse_directory( "project_id": project_id, }, ) - return {"project_id": project_id, "status": ProjectStatusEnum.SUBMITTED.value} - + return { + "project_id": project_id, + "status": ProjectStatusEnum.SUBMITTED.value, + } + return {"project_id": project_id, "status": project.status} else: # Handle new non-demo projects diff --git a/app/modules/parsing/graph_construction/parsing_repomap.py b/app/modules/parsing/graph_construction/parsing_repomap.py index 9d889ed1..511a0bc0 100644 --- a/app/modules/parsing/graph_construction/parsing_repomap.py +++ b/app/modules/parsing/graph_construction/parsing_repomap.py @@ -1,7 +1,6 @@ import logging import math import os -import time import warnings from collections import Counter, defaultdict, namedtuple from pathlib import Path @@ -12,7 +11,6 @@ from pygments.token import Token from pygments.util import ClassNotFound from tqdm import tqdm -from tree_sitter import Parser from tree_sitter_languages import get_language, get_parser from app.core.database import get_db @@ -26,7 +24,8 @@ class RepoMap: - warned_files = set() + # Parsing logic adapted from aider (https://github.com/paul-gauthier/aider) + # Modified and customized for potpie's parsing needs with detailed tags, relationship tracking etc def __init__( self, @@ -162,27 +161,18 @@ def get_tags_raw(self, fname, rel_fname): captures = query.captures(tree.root_node) captures = list(captures) saw = set() - - # Enhanced debugging - print(f"Processing file: {fname}") - print(f"Language detected: {lang}") - + for node, tag in captures: - node_text = node.text.decode('utf-8') - print(f"Captured node: {node_text} with tag: {tag}") - + node_text = node.text.decode("utf-8") + if tag.startswith("name.definition."): kind = "def" type = tag.split(".")[-1] - # Special handling for Java methods - if type == "method": - print(f"Found method definition: {node_text}") + elif tag.startswith("name.reference."): kind = "ref" type = tag.split(".")[-1] - # Special handling for Java method calls - if type == "method": - print(f"Found method reference: {node_text}") + else: continue @@ -214,10 +204,6 @@ def get_tags_raw(self, fname, rel_fname): if "def" not in saw: return - # We saw defs, without any refs - # Some tags files only provide defs (cpp, for example) - # Use pygments to backfill refs - try: lexer = guess_lexer_for_filename(fname, code) except ClassNotFound: @@ -545,64 +531,64 @@ def render_tree(self, abs_fname, rel_fname, lois): self.tree_cache[key] = res return res - def create_relationship(G, source, target, relationship_type, seen_relationships, extra_data=None): + def create_relationship( + G, source, target, relationship_type, seen_relationships, extra_data=None + ): """Helper to create relationships with proper direction checking""" if source == target: return False - + # Determine correct direction based on node types source_data = G.nodes[source] target_data = G.nodes[target] - + # Prevent duplicate bidirectional relationships rel_key = (source, 
target, relationship_type) reverse_key = (target, source, relationship_type) - + if rel_key in seen_relationships or reverse_key in seen_relationships: return False - + # Only create relationship if we have right direction: # 1. Interface method implementations should point to interface declaration # 2. Method calls should point to method definitions # 3. Class references should point to class definitions valid_direction = False - + if relationship_type == "REFERENCES": # Implementation -> Interface - if (source_data.get('type') == 'FUNCTION' and - target_data.get('type') == 'FUNCTION' and - 'Impl' in source): # Implementation class + if ( + source_data.get("type") == "FUNCTION" + and target_data.get("type") == "FUNCTION" + and "Impl" in source + ): # Implementation class valid_direction = True - - # Caller -> Callee - elif source_data.get('type') == 'FUNCTION': + + # Caller -> Callee + elif source_data.get("type") == "FUNCTION": valid_direction = True - + # Class Usage -> Class Definition - elif target_data.get('type') == 'CLASS': + elif target_data.get("type") == "CLASS": valid_direction = True - + if valid_direction: - G.add_edge(source, target, - type=relationship_type, - **(extra_data or {})) + G.add_edge(source, target, type=relationship_type, **(extra_data or {})) seen_relationships.add(rel_key) return True - - return False + return False def create_graph(self, repo_dir): G = nx.MultiDiGraph() defines = defaultdict(set) references = defaultdict(set) seen_relationships = set() - - + for root, dirs, files in os.walk(repo_dir): if any(part.startswith(".") for part in root.split(os.sep)): continue - + for file in files: file_path = os.path.join(root, file) rel_path = os.path.relpath(file_path, repo_dir) @@ -611,7 +597,7 @@ def create_graph(self, repo_dir): continue logging.info(f"\nProcessing file: {rel_path}") - + # Add file node file_node_name = rel_path if not G.has_node(file_node_name): @@ -624,21 +610,19 @@ def create_graph(self, repo_dir): end_line=0, name=rel_path.split("/")[-1], ) - logging.info(f"Added FILE node: {file_node_name}") current_class = None current_method = None # Process all tags in file for tag in self.get_tags(file_path, rel_path): - if tag.kind == "def": if tag.type == "class": node_type = "CLASS" current_class = tag.name current_method = None elif tag.type == "interface": - node_type = "INTERFACE" + node_type = "INTERFACE" current_class = tag.name current_method = None elif tag.type in ["method", "function"]: @@ -652,7 +636,6 @@ def create_graph(self, repo_dir): node_name = f"{rel_path}:{current_class}.{tag.name}" else: node_name = f"{rel_path}:{tag.name}" - # Add node if not G.has_node(node_name): @@ -663,17 +646,17 @@ def create_graph(self, repo_dir): end_line=tag.end_line, type=node_type, name=tag.name, - class_name=current_class + class_name=current_class, ) - + # Add CONTAINS relationship from file rel_key = (file_node_name, node_name, "CONTAINS") if rel_key not in seen_relationships: G.add_edge( - file_node_name, + file_node_name, node_name, type="CONTAINS", - ident=tag.name + ident=tag.name, ) seen_relationships.add(rel_key) @@ -689,33 +672,38 @@ def create_graph(self, repo_dir): else: source = rel_path - references[tag.name].add(( - source, - tag.line, - tag.end_line, - current_class, - current_method - )) - + references[tag.name].add( + ( + source, + tag.line, + tag.end_line, + current_class, + current_method, + ) + ) for ident, refs in references.items(): target_nodes = defines.get(ident, set()) for source, line, end_line, src_class, src_method in 
refs: - for target in target_nodes: if source == target: - logging.debug(f"Skipping self-reference: {source} -> {target}") continue - + if G.has_node(source) and G.has_node(target): - RepoMap.create_relationship(G, source, target, "REFERENCES", - seen_relationships, - {"ident": ident, - "ref_line": line, - "end_ref_line": end_line}) - - + RepoMap.create_relationship( + G, + source, + target, + "REFERENCES", + seen_relationships, + { + "ident": ident, + "ref_line": line, + "end_ref_line": end_line, + }, + ) + return G @staticmethod @@ -748,9 +736,20 @@ def get_language_for_file(file_path): def find_node_by_range(root_node, start_line, node_type): def traverse(node): if node.start_point[0] <= start_line and node.end_point[0] >= start_line: - if node_type == "FUNCTION" and node.type in ["function_definition", "method","method_declaration", "function"]: + if node_type == "FUNCTION" and node.type in [ + "function_definition", + "method", + "method_declaration", + "function", + ]: return node - elif node_type in ["CLASS", "INTERFACE"] and node.type in ["class_definition", "interface", "class", "class_declaration", "interface_declaration"]: + elif node_type in ["CLASS", "INTERFACE"] and node.type in [ + "class_definition", + "interface", + "class", + "class_declaration", + "interface_declaration", + ]: return node for child in node.children: result = traverse(child) diff --git a/app/modules/usage/usage_controller.py b/app/modules/usage/usage_controller.py index b17bbdd5..46e17392 100644 --- a/app/modules/usage/usage_controller.py +++ b/app/modules/usage/usage_controller.py @@ -1,9 +1,13 @@ from datetime import datetime -from app.modules.usage.usage_service import UsageService + from app.modules.usage.usage_schema import UsageResponse +from app.modules.usage.usage_service import UsageService + class UsageController: @staticmethod - async def get_user_usage(start_date: datetime, end_date: datetime, user_id: str) -> UsageResponse: + async def get_user_usage( + start_date: datetime, end_date: datetime, user_id: str + ) -> UsageResponse: usage_data = await UsageService.get_usage_data(start_date, end_date, user_id) - return usage_data \ No newline at end of file + return usage_data diff --git a/app/modules/usage/usage_router.py b/app/modules/usage/usage_router.py index 7f4bc149..d426da42 100644 --- a/app/modules/usage/usage_router.py +++ b/app/modules/usage/usage_router.py @@ -1,17 +1,20 @@ -from fastapi import APIRouter, Depends from datetime import datetime + +from fastapi import APIRouter, Depends + from app.modules.auth.auth_service import AuthService from app.modules.usage.usage_controller import UsageController from app.modules.usage.usage_schema import UsageResponse router = APIRouter() + class UsageAPI: @staticmethod @router.get("/usage", response_model=UsageResponse) async def get_usage( - start_date: datetime, - end_date: datetime, + start_date: datetime, + end_date: datetime, user=Depends(AuthService.check_auth), ): user_id = user["user_id"] diff --git a/app/modules/usage/usage_schema.py b/app/modules/usage/usage_schema.py index 9cf83be3..32d31bfb 100644 --- a/app/modules/usage/usage_schema.py +++ b/app/modules/usage/usage_schema.py @@ -1,6 +1,8 @@ -from pydantic import BaseModel from typing import Dict +from pydantic import BaseModel + + class UsageResponse(BaseModel): total_human_messages: int - agent_message_counts: Dict[str, int] \ No newline at end of file + agent_message_counts: Dict[str, int] diff --git a/app/modules/usage/usage_service.py b/app/modules/usage/usage_service.py index 
70946371..cfc4a4e8 100644 --- a/app/modules/usage/usage_service.py +++ b/app/modules/usage/usage_service.py @@ -1,11 +1,12 @@ from datetime import datetime + from fastapi import logger from sqlalchemy import func -from sqlalchemy.orm import Session from sqlalchemy.exc import SQLAlchemyError + from app.core.database import SessionLocal -from app.modules.conversations.message.message_model import Message, MessageType from app.modules.conversations.conversation.conversation_model import Conversation +from app.modules.conversations.message.message_model import Message, MessageType class UsageService: @@ -15,8 +16,8 @@ async def get_usage_data(start_date: datetime, end_date: datetime, user_id: str) with SessionLocal() as session: agent_query = ( session.query( - func.unnest(Conversation.agent_ids).label('agent_id'), - func.count(Message.id).label('message_count') + func.unnest(Conversation.agent_ids).label("agent_id"), + func.count(Message.id).label("message_count"), ) .join(Message, Message.conversation_id == Conversation.id) .filter( @@ -29,17 +30,16 @@ async def get_usage_data(start_date: datetime, end_date: datetime, user_id: str) ) agent_message_counts = { - agent_id: count - for agent_id, count in agent_query + agent_id: count for agent_id, count in agent_query } total_human_messages = sum(agent_message_counts.values()) return { "total_human_messages": total_human_messages, - "agent_message_counts": agent_message_counts + "agent_message_counts": agent_message_counts, } - + except SQLAlchemyError as e: logger.error(f"Failed to fetch usage data: {e}") raise Exception("Failed to fetch usage data") from e From f84508f83ab1df885c768621218dc27acda25cff Mon Sep 17 00:00:00 2001 From: dhirenmathur Date: Fri, 13 Dec 2024 13:05:42 +0530 Subject: [PATCH 5/5] remove duplicate cleanup step --- .../parsing/graph_construction/code_graph_service.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/app/modules/parsing/graph_construction/code_graph_service.py b/app/modules/parsing/graph_construction/code_graph_service.py index 455bdc53..d0af621e 100644 --- a/app/modules/parsing/graph_construction/code_graph_service.py +++ b/app/modules/parsing/graph_construction/code_graph_service.py @@ -44,14 +44,6 @@ def create_and_store_graph(self, repo_dir, project_id, user_id): nx_graph = self.repo_map.create_graph(repo_dir) with self.driver.session() as session: - # First, clear any existing data for this project - session.run( - """ - MATCH (n {repoId: $project_id}) - DETACH DELETE n - """, - project_id=project_id, - ) start_time = time.time() node_count = nx_graph.number_of_nodes()
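
A minimal, self-contained sketch of how the queries added in
tree-sitter-java-tags.scm flow through RepoMap.get_tags_raw(), including the
receiver qualification that patch 1 introduces for Java method invocations. It
relies on the same tree_sitter_languages API the patches use; the abridged
query string and the sample Java source (ProductController, ProductService)
are illustrative stand-ins, not code taken from the repository.

    from tree_sitter_languages import get_language, get_parser

    # Abridged form of the Java tags query from patch 1; the full .scm also
    # captures interfaces, fields, parameters, variables and type references.
    JAVA_TAGS_SCM = """
    (method_declaration name: (identifier) @name.definition.method)
    (method_invocation name: (identifier) @name.reference.method)
    (class_declaration name: (identifier) @name.definition.class)
    """

    # Hypothetical sample source, mirroring the productService.listAllProducts()
    # example mentioned in the patch 1 comments.
    SOURCE = b"""
    class ProductController {
        ProductService productService;
        void list() { productService.listAllProducts(); }
    }
    """

    language = get_language("java")
    parser = get_parser("java")
    tree = parser.parse(SOURCE)

    # query.captures() yields (node, capture_name) pairs, the same shape that
    # get_tags_raw() iterates over.
    for node, tag in language.query(JAVA_TAGS_SCM).captures(tree.root_node):
        kind = "def" if tag.startswith("name.definition.") else "ref"
        name = node.text.decode("utf-8")
        # Mirror the Java special case: qualify a method reference with its
        # receiver so the call is recorded as productService.listAllProducts.
        if kind == "ref" and node.parent is not None and node.parent.type == "method_invocation":
            obj = node.parent.child_by_field_name("object")
            if obj is not None:
                name = f"{obj.text.decode('utf-8')}.{name}"
        print(kind, tag.split(".")[-1], name, node.start_point[0])

Run as-is, this prints a "def" for ProductController and list plus a "ref"
qualified to productService.listAllProducts, which matches the names that
create_graph() records as definitions and references before its second pass
wires up the CONTAINS and REFERENCES edges.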