
Commit 6d6b273

Merge pull request #32 from Sefaria/fix/url-match-in-graph-query

fix: exact url match during graph retrieval, number of neighbors, and …
Paul-Yu-Chun-Chang authored Sep 11, 2024
2 parents 81e9d54 + 70d1443 commit 6d6b273
Showing 2 changed files with 5 additions and 8 deletions.
VirtualHavruta/util.py (6 changes: 1 addition & 5 deletions)
@@ -97,7 +97,7 @@ def get_node_data(node: "Node") -> dict:
     record: dict = next(iter(record.values()))
     return record
 
-def convert_node_to_doc(node: "Node", base_url: str= "https://www.sefaria.org/") -> Document:
+def convert_node_to_doc(node: "Node") -> Document:
     """
     Convert a node from the graph database to a Document object.
@@ -109,10 +109,6 @@ def convert_node_to_doc(node: "Node", base_url: str= "https://www.sefaria.org/")
     """
     node_data: dict = get_node_data(node)
     metadata = {k:v for k, v in node_data.items() if not k.startswith("content")}
-    new_reference_part = metadata["url"].replace(base_url, "")
-    new_category = metadata["primaryDocCategory"]
-    metadata["source"] = f"Reference: {new_reference_part}. Version Title: -, Document Category: {new_category}, URL: {metadata['url']}"
-
     page_content = dict_to_yaml_str(node_data.get("content")) if isinstance(node_data.get("content"), dict) else node_data.get("content", "")
     return ChunkDocument(
         page_content=page_content,
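With the source-string construction removed, the surviving metadata comprehension simply drops every content-prefixed key before the node becomes a Document. A runnable sketch with a hypothetical node payload (field names taken from the diff):

    # Runnable sketch (hypothetical payload) of the metadata split kept by this commit:
    node_data = {
        "url": "https://www.sefaria.org/Genesis.1",
        "primaryDocCategory": "Tanakh",
        "content": "In the beginning...",
    }
    metadata = {k: v for k, v in node_data.items() if not k.startswith("content")}
    print(metadata)  # {'url': 'https://www.sefaria.org/Genesis.1', 'primaryDocCategory': 'Tanakh'}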
VirtualHavruta/vh.py (7 changes: 4 additions & 3 deletions)
@@ -414,7 +414,7 @@ def retrieve_nodes_matching_linker_results(self, linker_results: list[dict], msg
                 url_to_node[url] = node
             else:
                 url_to_node[url].metadata["source"] += " | " + node.metadata["source"]
-        self.logger.info(f"MsgID={msg_id}. [LINKER-GRAPH RETRIEVAL] Graph nodes retrieved using linker URLs: {url_to_node}")
+        self.logger.info(f"MsgID={msg_id}. [LINKER-GRAPH RETRIEVAL] Graph nodes retrieved using linker URLs: {['URL='+url+' SOURCE='+node.metadata['source'] for url, node in url_to_node.items()]}")
         return list(url_to_node.values())
 
     def get_retrieval_results_knowledge_graph(self, url: str, direction: str, order: int, score_central_node: float, filter_mode_nodes: str|None = None, msg_id: str = '') -> list[tuple[Document, float]]:
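The new log line emits a compact list of URL/source pairs instead of dumping the raw dict of node objects. A runnable sketch with a hypothetical node stub and URL:

    # Runnable sketch (hypothetical data) of the new log payload:
    class _Node:
        def __init__(self, source: str):
            self.metadata = {"source": source}

    url_to_node = {"https://www.sefaria.org/Genesis.1": _Node("Reference: Genesis 1")}
    payload = ['URL=' + url + ' SOURCE=' + node.metadata['source']
               for url, node in url_to_node.items()]
    print(payload)  # ['URL=https://www.sefaria.org/Genesis.1 SOURCE=Reference: Genesis 1']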
@@ -527,7 +527,7 @@ def query_graph_db_by_url(self, urls: list[str]) -> list[Document]:
         query_parameters = {"urls": urls}
         query_string="""
         MATCH (n:Records)
-        WHERE any(substring IN $urls WHERE n.url CONTAINS substring)
+        WHERE any(substring IN $urls WHERE n.url = substring)
         RETURN n
         """
         with neo4j.GraphDatabase.driver(self.config["database"]["kg"]["url"], auth=(self.config["database"]["kg"]["username"], self.config["database"]["kg"]["password"])) as driver:
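The switch from CONTAINS to = is the exact-URL-match fix named in the commit title: substring matching can pull in records whose URLs merely embed the queried one. A runnable Python sketch of the difference, using hypothetical URLs:

    # Hypothetical URLs: "Genesis.1" is a prefix of "Genesis.12" and "Genesis.1.5",
    # so the old CONTAINS predicate over-matched where the new equality does not.
    node_urls = [
        "https://www.sefaria.org/Genesis.1",
        "https://www.sefaria.org/Genesis.12",
        "https://www.sefaria.org/Genesis.1.5",
    ]
    queried = "https://www.sefaria.org/Genesis.1"

    contains_matches = [u for u in node_urls if queried in u]  # old behavior: all three
    exact_matches = [u for u in node_urls if u == queried]     # new behavior: exact URL only
    print(contains_matches)
    print(exact_matches)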
@@ -1232,7 +1232,8 @@ def graph_traversal_retriever(self,
                 score_central_node=6.0,
                 msg_id=msg_id
             )
-            neighbor_nodes = [node for node, _ in neighbor_nodes_scores]
+            # Limit the number of neighbors to the top 15
+            neighbor_nodes = [node for node, _ in neighbor_nodes_scores][:15]
             candidate_chunks = self.get_chunks_corresponding_to_nodes(neighbor_nodes, msg_id=msg_id)
             # avoid re-adding the top chunk
             candidate_chunks = [chunk for chunk in candidate_chunks if chunk not in collected_chunks]
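The [:15] slice caps traversal fan-out by keeping the first 15 entries, which assumes neighbor_nodes_scores already arrives ranked by score. A hedged sketch that makes the ranking explicit before the cut (the sort and the sample data are assumptions, not part of the commit):

    # neighbor_nodes_scores as produced upstream: (node, score) pairs; data here is hypothetical.
    neighbor_nodes_scores = [("nodeA", 6.0), ("nodeB", 3.5), ("nodeC", 5.2)]
    # Sort by descending score before taking the top 15, in case ordering is not guaranteed.
    ranked = sorted(neighbor_nodes_scores, key=lambda pair: pair[1], reverse=True)
    neighbor_nodes = [node for node, _score in ranked[:15]]
    print(neighbor_nodes)  # ['nodeA', 'nodeC', 'nodeB']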
