diff --git a/.semversioner/next-release/patch-20250121150223319652.json b/.semversioner/next-release/patch-20250121150223319652.json
new file mode 100644
index 0000000000..704e78c52e
--- /dev/null
+++ b/.semversioner/next-release/patch-20250121150223319652.json
@@ -0,0 +1,4 @@
+{
+    "type": "patch",
+    "description": "implemented multi-index querying for api layer"
+}
diff --git a/graphrag/api/index.py b/graphrag/api/index.py
index 1609a07e53..c7d77ef479
--- a/graphrag/api/index.py
+++ b/graphrag/api/index.py
@@ -18,6 +18,7 @@
 from graphrag.index.run.run_workflows import run_workflows
 from graphrag.index.typing import PipelineRunResult
 from graphrag.logger.base import ProgressLogger
+from graphrag.utils.api import get_workflows_list
 
 log = logging.getLogger(__name__)
@@ -60,7 +61,7 @@ async def build_index(
     if memory_profile:
         log.warning("New pipeline does not yet support memory profiling.")
 
-    workflows = _get_workflows_list(config)
+    workflows = get_workflows_list(config)
 
     async for output in run_workflows(
         workflows,
@@ -79,20 +80,3 @@ async def build_index(
             progress_logger.info(str(output.result))
 
     return outputs
-
-
-def _get_workflows_list(config: GraphRagConfig) -> list[str]:
-    return [
-        "create_base_text_units",
-        "create_final_documents",
-        "extract_graph",
-        "compute_communities",
-        "create_final_entities",
-        "create_final_relationships",
-        "create_final_nodes",
-        "create_final_communities",
-        *(["create_final_covariates"] if config.claim_extraction.enabled else []),
-        "create_final_text_units",
-        "create_final_community_reports",
-        "generate_text_embeddings",
-    ]
diff --git a/graphrag/api/query.py b/graphrag/api/query.py
index 15e575896a..60b7b943d5
--- a/graphrag/api/query.py
+++ b/graphrag/api/query.py
@@ -18,7 +18,6 @@
 """
 
 from collections.abc import AsyncGenerator
-from pathlib import Path
 from typing import TYPE_CHECKING, Any
 
 import pandas as pd
@@ -26,7 +25,6 @@
 
 from graphrag.config.embeddings import (
     community_full_content_embedding,
-    create_collection_name,
     entity_description_embedding,
     text_unit_text_embedding,
 )
@@ -47,9 +45,13 @@
     read_indexer_reports,
     read_indexer_text_units,
 )
+from graphrag.utils.api import (
+    get_embedding_store,
+    load_search_prompt,
+    reformat_context_data,
+    update_context_data,
+)
 from graphrag.utils.cli import redact
-from graphrag.vector_stores.base import BaseVectorStore
-from graphrag.vector_stores.factory import VectorStoreFactory
 
 if TYPE_CHECKING:
     from graphrag.query.structured_search.base import SearchResult
@@ -102,11 +104,11 @@ async def global_search(
         dynamic_community_selection=dynamic_community_selection,
     )
     entities_ = read_indexer_entities(nodes, entities, community_level=community_level)
-    map_prompt = _load_search_prompt(config.root_dir, config.global_search.map_prompt)
-    reduce_prompt = _load_search_prompt(
+    map_prompt = load_search_prompt(config.root_dir, config.global_search.map_prompt)
+    reduce_prompt = load_search_prompt(
         config.root_dir, config.global_search.reduce_prompt
     )
-    knowledge_prompt = _load_search_prompt(
+    knowledge_prompt = load_search_prompt(
         config.root_dir, config.global_search.knowledge_prompt
     )
@@ -123,7 +125,7 @@ async def global_search(
     )
     result: SearchResult = await search_engine.asearch(query=query)
     response = result.response
-    context_data = _reformat_context_data(result.context_data)  # type: ignore
+    context_data = reformat_context_data(result.context_data)  # type: ignore
 
     return response, context_data
@@ -171,11 +173,11 @@ async def global_search_streaming(
         dynamic_community_selection=dynamic_community_selection,
     )
     entities_ = read_indexer_entities(nodes, entities, community_level=community_level)
-    map_prompt = _load_search_prompt(config.root_dir, config.global_search.map_prompt)
-    reduce_prompt = _load_search_prompt(
+    map_prompt = load_search_prompt(config.root_dir, config.global_search.map_prompt)
+    reduce_prompt = load_search_prompt(
         config.root_dir, config.global_search.reduce_prompt
     )
-    knowledge_prompt = _load_search_prompt(
+    knowledge_prompt = load_search_prompt(
         config.root_dir, config.global_search.knowledge_prompt
     )
@@ -198,13 +200,177 @@ async def global_search_streaming(
     get_context_data = True
     async for stream_chunk in search_result:
         if get_context_data:
-            context_data = _reformat_context_data(stream_chunk)  # type: ignore
+            context_data = reformat_context_data(stream_chunk)  # type: ignore
             yield context_data
             get_context_data = False
         else:
             yield stream_chunk
 
 
+@validate_call(config={"arbitrary_types_allowed": True})
+async def multi_index_global_search(
+    config: GraphRagConfig,
+    nodes_list: list[pd.DataFrame],
+    entities_list: list[pd.DataFrame],
+    communities_list: list[pd.DataFrame],
+    community_reports_list: list[pd.DataFrame],
+    index_names: list[str],
+    community_level: int | None,
+    dynamic_community_selection: bool,
+    response_type: str,
+    streaming: bool,
+    query: str,
+) -> (
+    tuple[
+        str | dict[str, Any] | list[dict[str, Any]],
+        str | list[pd.DataFrame] | dict[str, pd.DataFrame],
+    ]
+    | AsyncGenerator
+):
+    """Perform a global search across multiple indexes and return the context data and response.
+
+    Parameters
+    ----------
+    - config (GraphRagConfig): A graphrag configuration (from settings.yaml)
+    - nodes_list (list[pd.DataFrame]): A list of DataFrames containing the final nodes (from create_final_nodes.parquet)
+    - entities_list (list[pd.DataFrame]): A list of DataFrames containing the final entities (from create_final_entities.parquet)
+    - communities_list (list[pd.DataFrame]): A list of DataFrames containing the final communities (from create_final_communities.parquet)
+    - community_reports_list (list[pd.DataFrame]): A list of DataFrames containing the final community reports (from create_final_community_reports.parquet)
+    - index_names (list[str]): A list of index names.
+    - community_level (int): The community level to search at.
+    - dynamic_community_selection (bool): Enable dynamic community selection instead of using all community reports at a fixed level. Note that you can still provide community_level to cap the maximum level to search.
+    - response_type (str): The type of response to return.
+    - streaming (bool): Whether to stream the results or not.
+    - query (str): The user query to search for.
+
+    Returns
+    -------
+    TODO: Document the search response type and format.
+
+    Raises
+    ------
+    TODO: Document any exceptions to expect.
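+
+    Examples
+    --------
+    Illustrative call only -- the DataFrame variables are placeholders for
+    parquet outputs you have already loaded, and the index names are made up:
+
+    >>> response, context = await multi_index_global_search(
+    ...     config,
+    ...     nodes_list=[nodes_a, nodes_b],
+    ...     entities_list=[entities_a, entities_b],
+    ...     communities_list=[communities_a, communities_b],
+    ...     community_reports_list=[reports_a, reports_b],
+    ...     index_names=["index-a", "index-b"],
+    ...     community_level=2,
+    ...     dynamic_community_selection=False,
+    ...     response_type="Multiple Paragraphs",
+    ...     streaming=False,
+    ...     query="What are the top themes?",
+    ... )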
+ """ + # Streaming not supported yet + if streaming: + message = "Streaming not yet implemented for multi_global_search" + raise NotImplementedError(message) + + links = { + "nodes": {}, + "community": {}, + "community_reports": {}, + "entities": {}, + } + max_vals = { + "nodes": -1, + "community": -1, + "community_reports": -1, + "entities": -1, + } + + communities_dfs = [] + community_reports_dfs = [] + entities_dfs = [] + nodes_dfs = [] + + for idx, index_name in enumerate(index_names): + # Prepare each index's nodes dataframe for merging + nodes_df = nodes_list[idx] + nodes_df["community"] = nodes_df["community"].astype(int) + for i in nodes_df["human_readable_id"]: + links["nodes"][i + max_vals["nodes"] + 1] = { + "index_name": index_name, + "id": i, + } + if max_vals["nodes"] != -1: + nodes_df["human_readable_id"] += max_vals["nodes"] + 1 + nodes_df["community"] = nodes_df["community"].apply( + lambda x: x + max_vals["community_reports"] + 1 if x != -1 else x + ) + nodes_df["title"] = nodes_df["title"].apply( + lambda x, index_name=index_name: x + f"-{index_name}" + ) + max_vals["nodes"] = int(nodes_df["human_readable_id"].max()) + nodes_dfs.append(nodes_df) + + # Prepare each index's community reports dataframe for merging + community_reports_df = community_reports_list[idx] + community_reports_df["community"] = community_reports_df["community"].astype( + int + ) + for i in community_reports_df["community"]: + links["community_reports"][i + max_vals["community_reports"] + 1] = { + "index_name": index_name, + "id": str(i), + } + community_reports_df["community"] += max_vals["community_reports"] + 1 + community_reports_df["human_readable_id"] += max_vals["community_reports"] + 1 + max_vals["community_reports"] = int(community_reports_df["community"].max()) + community_reports_dfs.append(community_reports_df) + + # Prepare each index's communities dataframe for merging + communities_df = communities_list[idx] + communities_df["community"] = communities_df["community"].astype(int) + communities_df["parent"] = communities_df["parent"].astype(int) + for i in communities_df["community"]: + links["community"][i + max_vals["community"] + 1] = { + "index_name": index_name, + "id": str(i), + } + communities_df["community"] += max_vals["community"] + 1 + communities_df["parent"] = communities_df["parent"].apply( + lambda x: x if x == -1 else x + max_vals["community"] + 1 + ) + communities_df["human_readable_id"] += max_vals["community"] + 1 + max_vals["community"] = int(communities_df["community"].max()) + communities_dfs.append(communities_df) + + # Prepare each index's entities dataframe for merging + entities_df = entities_list[idx] + for i in entities_df["human_readable_id"]: + links["entities"][i + max_vals["entities"] + 1] = { + "index_name": index_name, + "id": i, + } + entities_df["human_readable_id"] += max_vals["entities"] + 1 + entities_df["title"] = entities_df["title"].apply( + lambda x, index_name=index_name: x + f"-{index_name}" + ) + entities_df["text_unit_ids"] = entities_df["text_unit_ids"].apply( + lambda x, index_name=index_name: [i + f"-{index_name}" for i in x] + ) + max_vals["entities"] = int(entities_df["human_readable_id"].max()) + entities_dfs.append(entities_df) + + # Merge the dataframes + nodes_combined = pd.concat(nodes_dfs, axis=0, ignore_index=True, sort=False) + community_reports_combined = pd.concat( + community_reports_dfs, axis=0, ignore_index=True, sort=False + ) + entities_combined = pd.concat(entities_dfs, axis=0, ignore_index=True, sort=False) + 
communities_combined = pd.concat( + communities_dfs, axis=0, ignore_index=True, sort=False + ) + + result = await global_search( + config, + nodes=nodes_combined, + entities=entities_combined, + communities=communities_combined, + community_reports=community_reports_combined, + community_level=community_level, + dynamic_community_selection=dynamic_community_selection, + response_type=response_type, + query=query, + ) + + # Update the context data by linking index names and community ids + context = update_context_data(result[1], links) + + return (result[0], context) + + @validate_call(config={"arbitrary_types_allowed": True}) async def local_search( config: GraphRagConfig, @@ -244,17 +410,18 @@ async def local_search( ------ TODO: Document any exceptions to expect. """ - vector_store_args = config.vector_store.model_dump() + vector_store_args = {} + for index, store in config.vector_store.items(): + vector_store_args[index] = store.model_dump() logger.info(f"Vector Store Args: {redact(vector_store_args)}") # type: ignore # noqa - description_embedding_store = _get_embedding_store( + description_embedding_store = get_embedding_store( config_args=vector_store_args, # type: ignore embedding_name=entity_description_embedding, ) - entities_ = read_indexer_entities(nodes, entities, community_level) covariates_ = read_indexer_covariates(covariates) if covariates is not None else [] - prompt = _load_search_prompt(config.root_dir, config.local_search.prompt) + prompt = load_search_prompt(config.root_dir, config.local_search.prompt) search_engine = get_local_search_engine( config=config, @@ -270,7 +437,7 @@ async def local_search( result: SearchResult = await search_engine.asearch(query=query) response = result.response - context_data = _reformat_context_data(result.context_data) # type: ignore + context_data = reformat_context_data(result.context_data) # type: ignore return response, context_data @@ -310,17 +477,19 @@ async def local_search_streaming( ------ TODO: Document any exceptions to expect. 
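+
+    Notes
+    -----
+    The returned generator yields the context data first, followed by the
+    response tokens as they arrive.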
""" - vector_store_args = config.vector_store.model_dump() + vector_store_args = {} + for index, store in config.vector_store.items(): + vector_store_args[index] = store.model_dump() logger.info(f"Vector Store Args: {redact(vector_store_args)}") # type: ignore # noqa - description_embedding_store = _get_embedding_store( + description_embedding_store = get_embedding_store( config_args=vector_store_args, # type: ignore embedding_name=entity_description_embedding, ) entities_ = read_indexer_entities(nodes, entities, community_level) covariates_ = read_indexer_covariates(covariates) if covariates is not None else [] - prompt = _load_search_prompt(config.root_dir, config.local_search.prompt) + prompt = load_search_prompt(config.root_dir, config.local_search.prompt) search_engine = get_local_search_engine( config=config, @@ -341,7 +510,7 @@ async def local_search_streaming( get_context_data = True async for stream_chunk in search_result: if get_context_data: - context_data = _reformat_context_data(stream_chunk) # type: ignore + context_data = reformat_context_data(stream_chunk) # type: ignore yield context_data get_context_data = False else: @@ -349,7 +518,242 @@ async def local_search_streaming( @validate_call(config={"arbitrary_types_allowed": True}) -async def drift_search_streaming( +async def multi_index_local_search( + config: GraphRagConfig, + nodes_list: list[pd.DataFrame], + entities_list: list[pd.DataFrame], + community_reports_list: list[pd.DataFrame], + text_units_list: list[pd.DataFrame], + relationships_list: list[pd.DataFrame], + covariates_list: list[pd.DataFrame] | None, + index_names: list[str], + community_level: int, + response_type: str, + streaming: bool, + query: str, +) -> ( + tuple[ + str | dict[str, Any] | list[dict[str, Any]], + str | list[pd.DataFrame] | dict[str, pd.DataFrame], + ] + | AsyncGenerator +): + """Perform a local search across multiple indexes and return the context data and response. + + Parameters + ---------- + - config (GraphRagConfig): A graphrag configuration (from settings.yaml) + - nodes_list (list[pd.DataFrame]): A list of DataFrames containing the final nodes (from create_final_nodes.parquet) + - entities_list (list[pd.DataFrame]): A list of DataFrames containing the final entities (from create_final_entities.parquet) + - community_reports_list (list[pd.DataFrame]): A list of DataFrames containing the final community reports (from create_final_community_reports.parquet) + - text_units_list (list[pd.DataFrame]): A list of DataFrames containing the final text units (from create_final_text_units.parquet) + - relationships_list (list[pd.DataFrame]): A list of DataFrames containing the final relationships (from create_final_relationships.parquet) + - covariates_list (list[pd.DataFrame]): [Optional] A list of DataFrames containing the final covariates (from create_final_covariates.parquet) + - index_names (list[str]): A list of index names. + - community_level (int): The community level to search at. + - response_type (str): The response type to return. + - streaming (bool): Whether to stream the results or not. + - query (str): The user query to search for. + + Returns + ------- + TODO: Document the search response type and format. + + Raises + ------ + TODO: Document any exceptions to expect. 
+ """ + # Streaming not supported yet + if streaming: + message = "Streaming not yet implemented for multi_index_local_search" + raise NotImplementedError(message) + + links = { + "nodes": {}, + "community_reports": {}, + "entities": {}, + "text_units": {}, + "relationships": {}, + "covariates": {}, + } + max_vals = { + "nodes": -1, + "community_reports": -1, + "entities": -1, + "text_units": 0, + "relationships": -1, + "covariates": 0, + } + + community_reports_dfs = [] + entities_dfs = [] + nodes_dfs = [] + relationships_dfs = [] + text_units_dfs = [] + covariates_dfs = [] + + for idx, index_name in enumerate(index_names): + # Prepare each index's nodes dataframe for merging + nodes_df = nodes_list[idx] + nodes_df["community"] = nodes_df["community"].astype(int) + for i in nodes_df["human_readable_id"]: + links["nodes"][i + max_vals["nodes"] + 1] = { + "index_name": index_name, + "id": i, + } + if max_vals["nodes"] != -1: + nodes_df["human_readable_id"] += max_vals["nodes"] + 1 + nodes_df["community"] = nodes_df["community"].apply( + lambda x: x + max_vals["community_reports"] + 1 if x != -1 else x + ) + nodes_df["title"] = nodes_df["title"].apply( + lambda x, index_name=index_name: x + f"-{index_name}" + ) + nodes_df["id"] = nodes_df["id"].apply( + lambda x, index_name=index_name: x + f"-{index_name}" + ) + max_vals["nodes"] = int(nodes_df["human_readable_id"].max()) + nodes_dfs.append(nodes_df) + + # Prepare each index's community reports dataframe for merging + community_reports_df = community_reports_list[idx] + community_reports_df["community"] = community_reports_df["community"].astype( + int + ) + for i in community_reports_df["community"]: + links["community_reports"][i + max_vals["community_reports"] + 1] = { + "index_name": index_name, + "id": str(i), + } + community_reports_df["community"] += max_vals["community_reports"] + 1 + community_reports_df["human_readable_id"] += max_vals["community_reports"] + 1 + max_vals["community_reports"] = int(community_reports_df["community"].max()) + community_reports_dfs.append(community_reports_df) + + # Prepare each index's entities dataframe for merging + entities_df = entities_list[idx] + for i in entities_df["human_readable_id"]: + links["entities"][i + max_vals["entities"] + 1] = { + "index_name": index_name, + "id": i, + } + entities_df["human_readable_id"] += max_vals["entities"] + 1 + entities_df["title"] = entities_df["title"].apply( + lambda x, index_name=index_name: x + f"-{index_name}" + ) + entities_df["id"] = entities_df["id"].apply( + lambda x, index_name=index_name: x + f"-{index_name}" + ) + entities_df["text_unit_ids"] = entities_df["text_unit_ids"].apply( + lambda x, index_name=index_name: [i + f"-{index_name}" for i in x] + ) + max_vals["entities"] = int(entities_df["human_readable_id"].max()) + entities_dfs.append(entities_df) + + # Prepare each index's relationships dataframe for merging + relationships_df = relationships_list[idx] + for i in relationships_df["human_readable_id"].astype(int): + links["relationships"][i + max_vals["relationships"] + 1] = { + "index_name": index_name, + "id": i, + } + if max_vals["relationships"] != -1: + col = ( + relationships_df["human_readable_id"].astype(int) + + max_vals["relationships"] + + 1 + ) + relationships_df["human_readable_id"] = col.astype(str) + relationships_df["source"] = relationships_df["source"].apply( + lambda x, index_name=index_name: x + f"-{index_name}" + ) + relationships_df["target"] = relationships_df["target"].apply( + lambda x, index_name=index_name: x + 
f"-{index_name}" + ) + relationships_df["text_unit_ids"] = relationships_df["text_unit_ids"].apply( + lambda x, index_name=index_name: [i + f"-{index_name}" for i in x] + ) + max_vals["relationships"] = int(relationships_df["human_readable_id"].max()) + relationships_dfs.append(relationships_df) + + # Prepare each index's text units dataframe for merging + text_units_df = text_units_list[idx] + for i in range(text_units_df.shape[0]): + links["text_units"][i + max_vals["text_units"]] = { + "index_name": index_name, + "id": i, + } + text_units_df["id"] = text_units_df["id"].apply( + lambda x, index_name=index_name: f"{x}-{index_name}" + ) + text_units_df["human_readable_id"] = ( + text_units_df["human_readable_id"] + max_vals["text_units"] + ) + max_vals["text_units"] += text_units_df.shape[0] + text_units_dfs.append(text_units_df) + + # If presents, prepare each index's covariates dataframe for merging + if covariates_list is not None: + covariates_df = covariates_list[idx] + for i in covariates_df["human_readable_id"].astype(int): + links["covariates"][i + max_vals["covariates"]] = { + "index_name": index_name, + "id": i, + } + covariates_df["id"] = covariates_df["id"].apply( + lambda x, index_name=index_name: f"{x}-{index_name}" + ) + covariates_df["human_readable_id"] = ( + covariates_df["human_readable_id"] + max_vals["covariates"] + ) + covariates_df["text_unit_id"] = covariates_df["text_unit_id"].apply( + lambda x, index_name=index_name: x + f"-{index_name}" + ) + covariates_df["subject_id"] = covariates_df["subject_id"].apply( + lambda x, index_name=index_name: x + f"-{index_name}" + ) + max_vals["covariates"] += covariates_df.shape[0] + covariates_dfs.append(covariates_df) + + # Merge the dataframes + nodes_combined = pd.concat(nodes_dfs, axis=0, ignore_index=True, sort=False) + community_reports_combined = pd.concat( + community_reports_dfs, axis=0, ignore_index=True, sort=False + ) + entities_combined = pd.concat(entities_dfs, axis=0, ignore_index=True, sort=False) + relationships_combined = pd.concat( + relationships_dfs, axis=0, ignore_index=True, sort=False + ) + text_units_combined = pd.concat( + text_units_dfs, axis=0, ignore_index=True, sort=False + ) + covariates_combined = None + if len(covariates_dfs) > 0: + covariates_combined = pd.concat( + covariates_dfs, axis=0, ignore_index=True, sort=False + ) + + result = await local_search( + config, + nodes=nodes_combined, + entities=entities_combined, + community_reports=community_reports_combined, + text_units=text_units_combined, + relationships=relationships_combined, + covariates=covariates_combined, + community_level=community_level, + response_type=response_type, + query=query, + ) + + # Update the context data by linking index names and community ids + context = update_context_data(result[1], links) + + return (result[0], context) + + +@validate_call(config={"arbitrary_types_allowed": True}) +async def drift_search( config: GraphRagConfig, nodes: pd.DataFrame, entities: pd.DataFrame, @@ -359,7 +763,10 @@ async def drift_search_streaming( community_level: int, response_type: str, query: str, -) -> AsyncGenerator: +) -> tuple[ + str | dict[str, Any] | list[dict[str, Any]], + str | list[pd.DataFrame] | dict[str, pd.DataFrame], +]: """Perform a DRIFT search and return the context data and response. Parameters @@ -381,15 +788,17 @@ async def drift_search_streaming( ------ TODO: Document any exceptions to expect. 
""" - vector_store_args = config.vector_store.model_dump() + vector_store_args = {} + for index, store in config.vector_store.items(): + vector_store_args[index] = store.model_dump() logger.info(f"Vector Store Args: {redact(vector_store_args)}") # type: ignore # noqa - description_embedding_store = _get_embedding_store( + description_embedding_store = get_embedding_store( config_args=vector_store_args, # type: ignore embedding_name=entity_description_embedding, ) - full_content_embedding_store = _get_embedding_store( + full_content_embedding_store = get_embedding_store( config_args=vector_store_args, # type: ignore embedding_name=community_full_content_embedding, ) @@ -397,11 +806,10 @@ async def drift_search_streaming( entities_ = read_indexer_entities(nodes, entities, community_level) reports = read_indexer_reports(community_reports, nodes, community_level) read_indexer_report_embeddings(reports, full_content_embedding_store) - prompt = _load_search_prompt(config.root_dir, config.drift_search.prompt) - reduce_prompt = _load_search_prompt( + prompt = load_search_prompt(config.root_dir, config.drift_search.prompt) + reduce_prompt = load_search_prompt( config.root_dir, config.drift_search.reduce_prompt ) - search_engine = get_drift_search_engine( config=config, reports=reports, @@ -414,23 +822,17 @@ async def drift_search_streaming( response_type=response_type, ) - search_result = search_engine.astream_search(query=query) + result: SearchResult = await search_engine.asearch(query=query) + response = result.response + context_data = {} + for key in result.context_data: + context_data[key] = reformat_context_data(result.context_data[key]) # type: ignore - # when streaming results, a context data object is returned as the first result - # and the query response in subsequent tokens - context_data = None - get_context_data = True - async for stream_chunk in search_result: - if get_context_data: - context_data = _reformat_context_data(stream_chunk) # type: ignore - yield context_data - get_context_data = False - else: - yield stream_chunk + return response, context_data @validate_call(config={"arbitrary_types_allowed": True}) -async def drift_search( +async def drift_search_streaming( config: GraphRagConfig, nodes: pd.DataFrame, entities: pd.DataFrame, @@ -440,10 +842,7 @@ async def drift_search( community_level: int, response_type: str, query: str, -) -> tuple[ - str | dict[str, Any] | list[dict[str, Any]], - str | list[pd.DataFrame] | dict[str, pd.DataFrame], -]: +) -> AsyncGenerator: """Perform a DRIFT search and return the context data and response. Parameters @@ -465,15 +864,17 @@ async def drift_search( ------ TODO: Document any exceptions to expect. 
""" - vector_store_args = config.vector_store.model_dump() + vector_store_args = {} + for index, store in config.vector_store.items(): + vector_store_args[index] = store.model_dump() logger.info(f"Vector Store Args: {redact(vector_store_args)}") # type: ignore # noqa - description_embedding_store = _get_embedding_store( + description_embedding_store = get_embedding_store( config_args=vector_store_args, # type: ignore embedding_name=entity_description_embedding, ) - full_content_embedding_store = _get_embedding_store( + full_content_embedding_store = get_embedding_store( config_args=vector_store_args, # type: ignore embedding_name=community_full_content_embedding, ) @@ -481,8 +882,8 @@ async def drift_search( entities_ = read_indexer_entities(nodes, entities, community_level) reports = read_indexer_reports(community_reports, nodes, community_level) read_indexer_report_embeddings(reports, full_content_embedding_store) - prompt = _load_search_prompt(config.root_dir, config.drift_search.prompt) - reduce_prompt = _load_search_prompt( + prompt = load_search_prompt(config.root_dir, config.drift_search.prompt) + reduce_prompt = load_search_prompt( config.root_dir, config.drift_search.reduce_prompt ) @@ -498,11 +899,230 @@ async def drift_search( response_type=response_type, ) - result: SearchResult = await search_engine.asearch(query=query) - response = result.response - context_data = _reformat_context_data(result.context_data) # type: ignore + search_result = search_engine.astream_search(query=query) - return response, context_data + # when streaming results, a context data object is returned as the first result + # and the query response in subsequent tokens + context_data = None + get_context_data = True + async for stream_chunk in search_result: + if get_context_data: + context_data = reformat_context_data(stream_chunk) # type: ignore + yield context_data + get_context_data = False + else: + yield stream_chunk + + +@validate_call(config={"arbitrary_types_allowed": True}) +async def multi_index_drift_search( + config: GraphRagConfig, + nodes_list: list[pd.DataFrame], + entities_list: list[pd.DataFrame], + community_reports_list: list[pd.DataFrame], + text_units_list: list[pd.DataFrame], + relationships_list: list[pd.DataFrame], + index_names: list[str], + community_level: int, + response_type: str, + streaming: bool, + query: str, +) -> ( + tuple[ + str | dict[str, Any] | list[dict[str, Any]], + str | list[pd.DataFrame] | dict[str, pd.DataFrame], + ] + | AsyncGenerator +): + """Perform a DRIFT search across multiple indexes and return the context data and response. + + Parameters + ---------- + - config (GraphRagConfig): A graphrag configuration (from settings.yaml) + - nodes_list (list[pd.DataFrame]): A list of DataFrames containing the final nodes (from create_final_nodes.parquet) + - entities_list (list[pd.DataFrame]): A list of DataFrames containing the final entities (from create_final_entities.parquet) + - community_reports_list (list[pd.DataFrame]): A list of DataFrames containing the final community reports (from create_final_community_reports.parquet) + - text_units_list (list[pd.DataFrame]): A list of DataFrames containing the final text units (from create_final_text_units.parquet) + - relationships_list (list[pd.DataFrame]): A list of DataFrames containing the final relationships (from create_final_relationships.parquet) + - index_names (list[str]): A list of index names. + - community_level (int): The community level to search at. 
+    - response_type (str): The response type to return.
+    - streaming (bool): Whether to stream the results or not.
+    - query (str): The user query to search for.
+
+    Returns
+    -------
+    TODO: Document the search response type and format.
+
+    Raises
+    ------
+    TODO: Document any exceptions to expect.
+    """
+    # Streaming not supported yet
+    if streaming:
+        message = "Streaming not yet implemented for multi_index_drift_search"
+        raise NotImplementedError(message)
+
+    links = {
+        "nodes": {},
+        "community_reports": {},
+        "entities": {},
+        "text_units": {},
+        "relationships": {},
+    }
+    max_vals = {
+        "nodes": -1,
+        "community_reports": -1,
+        "entities": -1,
+        "text_units": 0,
+        "relationships": -1,
+    }
+
+    community_reports_dfs = []
+    entities_dfs = []
+    nodes_dfs = []
+    relationships_dfs = []
+    text_units_dfs = []
+
+    for idx, index_name in enumerate(index_names):
+        # Prepare each index's nodes dataframe for merging
+        nodes_df = nodes_list[idx]
+        nodes_df["community"] = nodes_df["community"].astype(int)
+        for i in nodes_df["human_readable_id"]:
+            links["nodes"][i + max_vals["nodes"] + 1] = {
+                "index_name": index_name,
+                "id": i,
+            }
+        if max_vals["nodes"] != -1:
+            nodes_df["human_readable_id"] += max_vals["nodes"] + 1
+        nodes_df["community"] = nodes_df["community"].apply(
+            lambda x: x + max_vals["community_reports"] + 1 if x != -1 else x
+        )
+        nodes_df["title"] = nodes_df["title"].apply(
+            lambda x, index_name=index_name: x + f"-{index_name}"
+        )
+        nodes_df["id"] = nodes_df["id"].apply(
+            lambda x, index_name=index_name: x + f"-{index_name}"
+        )
+        max_vals["nodes"] = int(nodes_df["human_readable_id"].max())
+        nodes_dfs.append(nodes_df)
+
+        # Prepare each index's community reports dataframe for merging
+        community_reports_df = community_reports_list[idx]
+        community_reports_df["community"] = community_reports_df["community"].astype(
+            int
+        )
+        for i in community_reports_df["community"]:
+            links["community_reports"][i + max_vals["community_reports"] + 1] = {
+                "index_name": index_name,
+                "id": str(i),
+            }
+        community_reports_df["community"] += max_vals["community_reports"] + 1
+        community_reports_df["human_readable_id"] += max_vals["community_reports"] + 1
+        community_reports_df["id"] = community_reports_df["id"].apply(
+            lambda x, index_name=index_name: x + f"-{index_name}"
+        )
+        max_vals["community_reports"] = int(community_reports_df["community"].max())
+        community_reports_dfs.append(community_reports_df)
+
+        # Prepare each index's entities dataframe for merging
+        entities_df = entities_list[idx]
+        for i in entities_df["human_readable_id"]:
+            links["entities"][i + max_vals["entities"] + 1] = {
+                "index_name": index_name,
+                "id": i,
+            }
+        entities_df["human_readable_id"] += max_vals["entities"] + 1
+        entities_df["title"] = entities_df["title"].apply(
+            lambda x, index_name=index_name: x + f"-{index_name}"
+        )
+        entities_df["id"] = entities_df["id"].apply(
+            lambda x, index_name=index_name: x + f"-{index_name}"
+        )
+        entities_df["text_unit_ids"] = entities_df["text_unit_ids"].apply(
+            lambda x, index_name=index_name: [i + f"-{index_name}" for i in x]
+        )
+        max_vals["entities"] = int(entities_df["human_readable_id"].max())
+        entities_dfs.append(entities_df)
+
+        # Prepare each index's relationships dataframe for merging
+        relationships_df = relationships_list[idx]
+        for i in relationships_df["human_readable_id"].astype(int):
+            links["relationships"][i + max_vals["relationships"] + 1] = {
+                "index_name": index_name,
+                "id": i,
+            }
+        if max_vals["relationships"] != -1:
+            col = (
+                relationships_df["human_readable_id"].astype(int)
+                + max_vals["relationships"]
+                + 1
+            )
+            relationships_df["human_readable_id"] = col.astype(str)
+        relationships_df["source"] = relationships_df["source"].apply(
+            lambda x, index_name=index_name: x + f"-{index_name}"
+        )
+        relationships_df["target"] = relationships_df["target"].apply(
+            lambda x, index_name=index_name: x + f"-{index_name}"
+        )
+        relationships_df["text_unit_ids"] = relationships_df["text_unit_ids"].apply(
+            lambda x, index_name=index_name: [i + f"-{index_name}" for i in x]
+        )
+        max_vals["relationships"] = int(
+            relationships_df["human_readable_id"].astype(int).max()
+        )
+
+        relationships_dfs.append(relationships_df)
+
+        # Prepare each index's text units dataframe for merging
+        text_units_df = text_units_list[idx]
+        for i in range(text_units_df.shape[0]):
+            links["text_units"][i + max_vals["text_units"]] = {
+                "index_name": index_name,
+                "id": i,
+            }
+        text_units_df["id"] = text_units_df["id"].apply(
+            lambda x, index_name=index_name: f"{x}-{index_name}"
+        )
+        text_units_df["human_readable_id"] = (
+            text_units_df["human_readable_id"] + max_vals["text_units"]
+        )
+        max_vals["text_units"] += text_units_df.shape[0]
+        text_units_dfs.append(text_units_df)
+
+    # Merge the dataframes
+    nodes_combined = pd.concat(nodes_dfs, axis=0, ignore_index=True, sort=False)
+    community_reports_combined = pd.concat(
+        community_reports_dfs, axis=0, ignore_index=True, sort=False
+    )
+    entities_combined = pd.concat(entities_dfs, axis=0, ignore_index=True, sort=False)
+    relationships_combined = pd.concat(
+        relationships_dfs, axis=0, ignore_index=True, sort=False
+    )
+    text_units_combined = pd.concat(
+        text_units_dfs, axis=0, ignore_index=True, sort=False
+    )
+
+    result = await drift_search(
+        config,
+        nodes=nodes_combined,
+        entities=entities_combined,
+        community_reports=community_reports_combined,
+        text_units=text_units_combined,
+        relationships=relationships_combined,
+        community_level=community_level,
+        response_type=response_type,
+        query=query,
+    )
+
+    # Update the context data by linking index names and community ids
+    context = {}
+    if type(result[1]) is dict:
+        for key in result[1]:
+            context[key] = update_context_data(result[1][key], links)
+    else:
+        context = result[1]
+    return (result[0], context)
 
 
 @validate_call(config={"arbitrary_types_allowed": True})
@@ -520,7 +1140,6 @@ async def basic_search(
     ----------
     - config (GraphRagConfig): A graphrag configuration (from settings.yaml)
     - text_units (pd.DataFrame): A DataFrame containing the final text units (from create_final_text_units.parquet)
-    - response_type (str): The response type to return.
     - query (str): The user query to search for.
 
     Returns
@@ -531,15 +1150,17 @@ async def basic_search(
     ------
     TODO: Document any exceptions to expect.
""" - vector_store_args = config.vector_store.model_dump() + vector_store_args = {} + for index, store in config.vector_store.items(): + vector_store_args[index] = store.model_dump() logger.info(f"Vector Store Args: {redact(vector_store_args)}") # type: ignore # noqa - description_embedding_store = _get_embedding_store( + description_embedding_store = get_embedding_store( config_args=vector_store_args, # type: ignore embedding_name=text_unit_text_embedding, ) - prompt = _load_search_prompt(config.root_dir, config.basic_search.prompt) + prompt = load_search_prompt(config.root_dir, config.basic_search.prompt) search_engine = get_basic_search_engine( config=config, @@ -550,7 +1171,7 @@ async def basic_search( result: SearchResult = await search_engine.asearch(query=query) response = result.response - context_data = _reformat_context_data(result.context_data) # type: ignore + context_data = reformat_context_data(result.context_data) # type: ignore return response, context_data @@ -576,15 +1197,19 @@ async def basic_search_streaming( ------ TODO: Document any exceptions to expect. """ - vector_store_args = config.vector_store.model_dump() + vector_store_args = {} + for index, store in config.vector_store.items(): + vector_store_args[index] = store.model_dump() + else: + vector_store_args = None logger.info(f"Vector Store Args: {redact(vector_store_args)}") # type: ignore # noqa - description_embedding_store = _get_embedding_store( + description_embedding_store = get_embedding_store( config_args=vector_store_args, # type: ignore embedding_name=text_unit_text_embedding, ) - prompt = _load_search_prompt(config.root_dir, config.basic_search.prompt) + prompt = load_search_prompt(config.root_dir, config.basic_search.prompt) search_engine = get_basic_search_engine( config=config, @@ -601,70 +1226,83 @@ async def basic_search_streaming( get_context_data = True async for stream_chunk in search_result: if get_context_data: - context_data = _reformat_context_data(stream_chunk) # type: ignore + context_data = reformat_context_data(stream_chunk) # type: ignore yield context_data get_context_data = False else: yield stream_chunk -def _get_embedding_store( - config_args: dict, - embedding_name: str, -) -> BaseVectorStore: - """Get the embedding description store.""" - vector_store_type = config_args["type"] - collection_name = create_collection_name( - config_args.get("container_name", "default"), embedding_name - ) - embedding_store = VectorStoreFactory().create_vector_store( - vector_store_type=vector_store_type, - kwargs={**config_args, "collection_name": collection_name}, - ) - embedding_store.connect(**config_args) - return embedding_store - +@validate_call(config={"arbitrary_types_allowed": True}) +async def multi_index_basic_search( + config: GraphRagConfig, + text_units_list: list[pd.DataFrame], + index_names: list[str], + streaming: bool, + query: str, +) -> ( + tuple[ + str | dict[str, Any] | list[dict[str, Any]], + str | list[pd.DataFrame] | dict[str, pd.DataFrame], + ] + | AsyncGenerator +): + """Perform a basic search across multiple indexes and return the context data and response. -def _reformat_context_data(context_data: dict) -> dict: - """ - Reformats context_data for all query responses. + Parameters + ---------- + - config (GraphRagConfig): A graphrag configuration (from settings.yaml) + - text_units_list (list[pd.DataFrame]): A list of DataFrames containing the final text units (from create_final_text_units.parquet) + - index_names (list[str]): A list of index names. 
+    - streaming (bool): Whether to stream the results or not.
+    - query (str): The user query to search for.
 
-    Reformats a dictionary of dataframes into a dictionary of lists.
-    One list entry for each record. Records are grouped by original
-    dictionary keys.
+    Returns
+    -------
+    TODO: Document the search response type and format.
 
-    Note: depending on which query algorithm is used, the context_data may not
-    contain the same information (keys). In this case, the default behavior will be to
-    set these keys as empty lists to preserve a standard output format.
+    Raises
+    ------
+    TODO: Document any exceptions to expect.
+    """
-    final_format = {
-        "reports": [],
-        "entities": [],
-        "relationships": [],
-        "claims": [],
-        "sources": [],
-    }
-    for key in context_data:
-        records = (
-            context_data[key].to_dict(orient="records")
-            if context_data[key] is not None and not isinstance(context_data[key], dict)
-            else context_data[key]
-        )
-        if len(records) < 1:
-            continue
-        final_format[key] = records
-    return final_format
+    # Streaming not supported yet
+    if streaming:
+        message = "Streaming not yet implemented for multi_index_basic_search"
+        raise NotImplementedError(message)
+
+    links = {
+        "text_units": {},
+    }
+    max_vals = {
+        "text_units": 0,
+    }
 
-def _load_search_prompt(root_dir: str, prompt_config: str | None) -> str | None:
-    """
-    Load the search prompt from disk if configured.
+    text_units_dfs = []
+
+    for idx, index_name in enumerate(index_names):
+        # Prepare each index's text units dataframe for merging
+        text_units_df = text_units_list[idx]
+        for i in range(text_units_df.shape[0]):
+            links["text_units"][i + max_vals["text_units"]] = {
+                "index_name": index_name,
+                "id": i,
+            }
+        text_units_df["id"] = text_units_df["id"].apply(
+            lambda x, index_name=index_name: f"{x}-{index_name}"
+        )
+        text_units_df["human_readable_id"] = (
+            text_units_df["human_readable_id"] + max_vals["text_units"]
+        )
+        max_vals["text_units"] += text_units_df.shape[0]
+        text_units_dfs.append(text_units_df)
 
-    If not, leave it empty - the search functions will load their defaults.
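+    # At this point every text unit id carries a "-{index_name}" suffix and the
+    # human_readable_id ranges of the individual indexes no longer overlap.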
+    # Merge the dataframes
+    text_units_combined = pd.concat(
+        text_units_dfs, axis=0, ignore_index=True, sort=False
+    )
-    """
-    if prompt_config:
-        prompt_file = Path(root_dir) / prompt_config
-        if prompt_file.exists():
-            return prompt_file.read_bytes().decode(encoding="utf-8")
-    return None
+    return await basic_search(
+        config,
+        text_units=text_units_combined,
+        query=query,
+    )
diff --git a/graphrag/config/defaults.py b/graphrag/config/defaults.py
index 1feeacd625..cb961b8b51
--- a/graphrag/config/defaults.py
+++ b/graphrag/config/defaults.py
@@ -106,6 +106,7 @@
 VECTOR_STORE_DB_URI = str(Path(OUTPUT_BASE_DIR) / "lancedb")
 VECTOR_STORE_CONTAINER_NAME = "default"
 VECTOR_STORE_OVERWRITE = True
+VECTOR_STORE_INDEX_NAME = "output"
 
 # Local Search
 LOCAL_SEARCH_TEXT_UNIT_PROP = 0.5
diff --git a/graphrag/config/embeddings.py b/graphrag/config/embeddings.py
index 5b7404606d..a322290125
--- a/graphrag/config/embeddings.py
+++ b/graphrag/config/embeddings.py
@@ -57,7 +57,14 @@ def get_embedding_settings(
     embeddings_llm_settings = settings.get_language_model_config(
         settings.embeddings.model_id
     )
-    vector_store_settings = settings.vector_store.model_dump()
+    num_entries = len(settings.vector_store)
+    if num_entries == 1:
+        store = next(iter(settings.vector_store.values()))
+        vector_store_settings = store.model_dump()
+    else:
+        # The vector_store dict should only have more than one entry for multi-index query
+        vector_store_settings = None
+
     if vector_store_settings is None:
         return {
             "strategy": settings.embeddings.resolved_strategy(embeddings_llm_settings)
diff --git a/graphrag/config/init_content.py b/graphrag/config/init_content.py
index db42095724..f510ef7f2b
--- a/graphrag/config/init_content.py
+++ b/graphrag/config/init_content.py
@@ -40,6 +40,7 @@
 #     deployment_name:
 
 vector_store:
+  {defs.VECTOR_STORE_INDEX_NAME}:
     type: {defs.VECTOR_STORE_TYPE}
     db_uri: {defs.VECTOR_STORE_DB_URI}
     container_name: {defs.VECTOR_STORE_CONTAINER_NAME}
diff --git a/graphrag/config/models/graph_rag_config.py b/graphrag/config/models/graph_rag_config.py
index 82a333d9ed..8c3b9e867d
--- a/graphrag/config/models/graph_rag_config.py
+++ b/graphrag/config/models/graph_rag_config.py
@@ -224,20 +224,20 @@ def _validate_update_index_output_base_dir(self) -> None:
     )
     """The basic search configuration."""
 
-    vector_store: VectorStoreConfig = Field(
-        description="The vector store configuration.", default=VectorStoreConfig()
+    vector_store: dict[str, VectorStoreConfig] = Field(
+        description="The vector store configuration.",
+        default={"default": VectorStoreConfig()},
     )
     """The vector store configuration."""
 
     def _validate_vector_store_db_uri(self) -> None:
         """Validate the vector store configuration."""
-        if self.vector_store.type == VectorStoreType.LanceDB:
-            if not self.vector_store.db_uri or self.vector_store.db_uri.strip == "":
-                msg = "Vector store URI is required for LanceDB. Please rerun `graphrag init` and set the vector store configuration."
-                raise ValueError(msg)
-            self.vector_store.db_uri = str(
-                (Path(self.root_dir) / self.vector_store.db_uri).resolve()
-            )
+        for store in self.vector_store.values():
+            if store.type == VectorStoreType.LanceDB:
+                if not store.db_uri or store.db_uri.strip() == "":
+                    msg = "Vector store URI is required for LanceDB. Please rerun `graphrag init` and set the vector store configuration."
+                    raise ValueError(msg)
+                store.db_uri = str((Path(self.root_dir) / store.db_uri).resolve())
 
     def get_language_model_config(self, model_id: str) -> LanguageModelConfig:
         """Get a model configuration by ID.
diff --git a/graphrag/config/models/vector_store_config.py b/graphrag/config/models/vector_store_config.py
index 292f69c884..055f559e32
--- a/graphrag/config/models/vector_store_config.py
+++ b/graphrag/config/models/vector_store_config.py
@@ -45,10 +45,16 @@ def _validate_url(self) -> None:
             msg = "vector_store.url is required when vector_store.type == azure_ai_search. Please rerun `graphrag init` and select the correct vector store type."
             raise ValueError(msg)
 
-        if self.type != VectorStoreType.AzureAISearch and (
+        if self.type == VectorStoreType.CosmosDB and (
+            self.url is None or self.url.strip() == ""
+        ):
+            msg = "vector_store.url is required when vector_store.type == cosmos_db. Please rerun `graphrag init` and select the correct vector store type."
+            raise ValueError(msg)
+
+        if self.type == VectorStoreType.LanceDB and (
             self.url is not None and self.url.strip() != ""
         ):
-            msg = "vector_store.url is only used when vector_store.type == azure_ai_search. Please rerun `graphrag init` and select the correct vector store type."
+            msg = "vector_store.url is only used when vector_store.type == azure_ai_search or vector_store.type == cosmos_db. Please rerun `graphrag init` and select the correct vector store type."
             raise ValueError(msg)
 
     api_key: str | None = Field(
@@ -61,12 +67,16 @@ def _validate_url(self) -> None:
         default=None,
     )
 
-    container_name: str = Field(
-        description="The database name to use.",
+    container_name: str | list[str] = Field(
+        description="The container name to use.",
         default=defs.VECTOR_STORE_CONTAINER_NAME,
     )
 
-    overwrite: bool = Field(
+    database_name: str | None = Field(
+        description="The database name to use when type == cosmos_db.", default=None
+    )
+
+    overwrite: bool | list[str] = Field(
         description="Overwrite the existing data.", default=defs.VECTOR_STORE_OVERWRITE
     )
diff --git a/graphrag/utils/api.py b/graphrag/utils/api.py
new file mode 100644
index 0000000000..c714e754b9
--- /dev/null
+++ b/graphrag/utils/api.py
@@ -0,0 +1,251 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""API functions for the GraphRAG module."""
+
+from pathlib import Path
+from typing import Any
+
+from graphrag.config.embeddings import create_collection_name
+from graphrag.config.models.graph_rag_config import GraphRagConfig
+from graphrag.model.types import TextEmbedder
+from graphrag.vector_stores.base import (
+    BaseVectorStore,
+    VectorStoreDocument,
+    VectorStoreSearchResult,
+)
+from graphrag.vector_stores.factory import VectorStoreFactory
+
+
+class MultiVectorStore(BaseVectorStore):
+    """Multi Vector Store wrapper implementation."""
+
+    def __init__(
+        self,
+        embedding_stores: list[BaseVectorStore],
+        index_names: list[str],
+    ):
+        self.embedding_stores = embedding_stores
+        self.index_names = index_names
+
+    def load_documents(
+        self, documents: list[VectorStoreDocument], overwrite: bool = True
+    ) -> None:
+        """Load documents into the vector store."""
+        msg = "load_documents method not implemented"
+        raise NotImplementedError(msg)
+
+    def connect(self, **kwargs: Any) -> Any:
+        """Connect to vector storage."""
+        msg = "connect method not implemented"
+        raise NotImplementedError(msg)
+
+    def filter_by_id(self, include_ids: list[str] | list[int]) -> Any:
+        """Build a query filter to filter documents by id."""
+        msg = "filter_by_id method not implemented"
+        raise NotImplementedError(msg)
+
+    def search_by_id(self, id: str) -> VectorStoreDocument:
+        """Search for a document by id."""
+        # ids carry a trailing "-{index_name}" suffix, so split once from the right
+        search_index_id, search_index_name = id.rsplit("-", 1)
+        for index_name, embedding_store in zip(
+            self.index_names, self.embedding_stores, strict=False
+        ):
+            if index_name == search_index_name:
+                return embedding_store.search_by_id(search_index_id)
+        else:
+            message = f"Index {search_index_name} not found."
+            raise ValueError(message)
+
+    def similarity_search_by_vector(
+        self, query_embedding: list[float], k: int = 10, **kwargs: Any
+    ) -> list[VectorStoreSearchResult]:
+        """Perform a vector-based similarity search."""
+        all_results = []
+        for index_name, embedding_store in zip(
+            self.index_names, self.embedding_stores, strict=False
+        ):
+            results = embedding_store.similarity_search_by_vector(
+                query_embedding=query_embedding, k=k
+            )
+            mod_results = []
+            for r in results:
+                r.document.id = str(r.document.id) + f"-{index_name}"
+                mod_results += [r]
+            all_results += mod_results
+        return sorted(all_results, key=lambda x: x.score, reverse=True)[:k]
+
+    def similarity_search_by_text(
+        self, text: str, text_embedder: TextEmbedder, k: int = 10, **kwargs: Any
+    ) -> list[VectorStoreSearchResult]:
+        """Perform a text-based similarity search."""
+        query_embedding = text_embedder(text)
+        if query_embedding:
+            return self.similarity_search_by_vector(
+                query_embedding=query_embedding, k=k
+            )
+        return []
+
+
+def get_embedding_store(
+    config_args: dict[str, dict],
+    embedding_name: str,
+) -> BaseVectorStore:
+    """Get the embedding description store."""
+    num_indexes = len(config_args)
+    embedding_stores = []
+    index_names = []
+    for index, store in config_args.items():
+        vector_store_type = store["type"]
+        collection_name = create_collection_name(
+            store.get("container_name", "default"), embedding_name
+        )
+        embedding_store = VectorStoreFactory().create_vector_store(
+            vector_store_type=vector_store_type,
+            kwargs={**store, "collection_name": collection_name},
+        )
+        embedding_store.connect(**store)
+        # If there is only a single index, return the embedding store directly
+        if num_indexes == 1:
+            return embedding_store
+        embedding_stores.append(embedding_store)
+        index_names.append(index)
+    return MultiVectorStore(embedding_stores, index_names)
+
+
+def reformat_context_data(context_data: dict) -> dict:
+    """
+    Reformats context_data for all query responses.
+
+    Reformats a dictionary of dataframes into a dictionary of lists.
+    One list entry for each record. Records are grouped by original
+    dictionary keys.
+
+    Note: depending on which query algorithm is used, the context_data may not
+    contain the same information (keys). In this case, the default behavior will be to
+    set these keys as empty lists to preserve a standard output format.
+    """
+    final_format = {
+        "reports": [],
+        "entities": [],
+        "relationships": [],
+        "claims": [],
+        "sources": [],
+    }
+    for key in context_data:
+        records = (
+            context_data[key].to_dict(orient="records")
+            if context_data[key] is not None and not isinstance(context_data[key], dict)
+            else context_data[key]
+        )
+        if len(records) < 1:
+            continue
+        final_format[key] = records
+    return final_format
+
+
+def update_context_data(
+    context_data: Any,
+    links: dict[str, Any],
+) -> Any:
+    """
+    Update context data with the links dict so that it contains both the index name and community id.
+
+    Parameters
+    ----------
+    - context_data (str | list[pd.DataFrame] | dict[str, pd.DataFrame]): The context data to update.
+    - links (dict[str, Any]): A dictionary of links to the original dataframes.
+
+    Returns
+    -------
+    str | list[pd.DataFrame] | dict[str, pd.DataFrame]: The updated context data.
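+
+    Notes
+    -----
+    links is expected to be the mapping built by the multi_index_* search
+    functions: links[dataset][merged_id] -> {"index_name": ..., "id": ...}.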
+ """ + updated_context_data = {} + for key in context_data: + updated_entry = [] + if key == "reports": + updated_entry = [ + dict( + {k: entry[k] for k in entry}, + index_name=links["community_reports"][int(entry["id"])][ + "index_name" + ], + index_id=links["community_reports"][int(entry["id"])]["id"], + ) + for entry in context_data[key] + ] + if key == "entities": + updated_entry = [ + dict( + {k: entry[k] for k in entry}, + entity=entry["entity"].split("-")[0], + index_name=links["entities"][int(entry["id"])]["index_name"], + index_id=links["entities"][int(entry["id"])]["id"], + ) + for entry in context_data[key] + ] + if key == "relationships": + updated_entry = [ + dict( + {k: entry[k] for k in entry}, + source=entry["source"].split("-")[0], + target=entry["target"].split("-")[0], + index_name=links["relationships"][int(entry["id"])]["index_name"], + index_id=links["relationships"][int(entry["id"])]["id"], + ) + for entry in context_data[key] + ] + if key == "claims": + updated_entry = [ + dict( + {k: entry[k] for k in entry}, + entity=entry["entity"].split("-")[0], + index_name=links["covariates"][int(entry["id"])]["index_name"], + index_id=links["covariates"][int(entry["id"])]["id"], + ) + for entry in context_data[key] + ] + if key == "sources": + updated_entry = [ + dict( + {k: entry[k] for k in entry}, + index_name=links["text_units"][int(entry["id"])]["index_name"], + index_id=links["text_units"][int(entry["id"])]["id"], + ) + for entry in context_data[key] + ] + updated_context_data[key] = updated_entry + return updated_context_data + + +def load_search_prompt(root_dir: str, prompt_config: str | None) -> str | None: + """ + Load the search prompt from disk if configured. + + If not, leave it empty - the search functions will load their defaults. 
+ + """ + if prompt_config: + prompt_file = Path(root_dir) / prompt_config + if prompt_file.exists(): + return prompt_file.read_bytes().decode(encoding="utf-8") + return None + + +def get_workflows_list(config: GraphRagConfig) -> list[str]: + """Return a list of workflows for the indexing pipeline.""" + return [ + "create_base_text_units", + "create_final_documents", + "extract_graph", + "compute_communities", + "create_final_entities", + "create_final_relationships", + "create_final_nodes", + "create_final_communities", + *(["create_final_covariates"] if config.claim_extraction.enabled else []), + "create_final_text_units", + "create_final_community_reports", + "generate_text_embeddings", + ] diff --git a/tests/fixtures/azure/settings.yml b/tests/fixtures/azure/settings.yml index 5ecec80990..3f054b6717 100644 --- a/tests/fixtures/azure/settings.yml +++ b/tests/fixtures/azure/settings.yml @@ -3,10 +3,11 @@ claim_extraction: embeddings: vector_store: - type: "azure_ai_search" - url: ${AZURE_AI_SEARCH_URL_ENDPOINT} - api_key: ${AZURE_AI_SEARCH_API_KEY} - container_name: "azure_ci" + output: + type: "azure_ai_search" + url: ${AZURE_AI_SEARCH_URL_ENDPOINT} + api_key: ${AZURE_AI_SEARCH_API_KEY} + container_name: "azure_ci" input: type: blob diff --git a/tests/fixtures/min-csv/settings.yml b/tests/fixtures/min-csv/settings.yml index bfe242d63b..09642c9260 100644 --- a/tests/fixtures/min-csv/settings.yml +++ b/tests/fixtures/min-csv/settings.yml @@ -26,10 +26,11 @@ models: async_mode: threaded vector_store: - type: "lancedb" - db_uri: "./tests/fixtures/min-csv/lancedb" - container_name: "lancedb_ci" - overwrite: True + output: + type: "lancedb" + db_uri: "./tests/fixtures/min-csv/lancedb" + container_name: "lancedb_ci" + overwrite: True input: file_type: csv diff --git a/tests/fixtures/text/settings.yml b/tests/fixtures/text/settings.yml index f949a2dd6f..09b5f13d38 100644 --- a/tests/fixtures/text/settings.yml +++ b/tests/fixtures/text/settings.yml @@ -26,10 +26,11 @@ models: async_mode: threaded vector_store: - type: "azure_ai_search" - url: ${AZURE_AI_SEARCH_URL_ENDPOINT} - api_key: ${AZURE_AI_SEARCH_API_KEY} - container_name: "simple_text_ci" + output: + type: "azure_ai_search" + url: ${AZURE_AI_SEARCH_URL_ENDPOINT} + api_key: ${AZURE_AI_SEARCH_API_KEY} + container_name: "simple_text_ci" claim_extraction: enabled: true diff --git a/tests/smoke/test_fixtures.py b/tests/smoke/test_fixtures.py index cdf9d7389a..c90f15a246 100644 --- a/tests/smoke/test_fixtures.py +++ b/tests/smoke/test_fixtures.py @@ -288,13 +288,6 @@ def test_fixture( result = self.__run_query(root, query) print(f"Query: {query}\nResponse: {result.stdout}") - # Check stderr because lancedb logs path creating as WARN which leads to false negatives - stderror = ( - result.stderr if "No existing dataset at" not in result.stderr else "" - ) - - assert stderror == "" or stderror.replace("\n", "") in KNOWN_WARNINGS, ( - f"Query failed with error: {stderror}" - ) + assert result.returncode == 0, "Query failed" assert result.stdout is not None, "Query returned no output" assert len(result.stdout) > 0, "Query returned empty output" diff --git a/tests/unit/config/utils.py b/tests/unit/config/utils.py index 124929e57a..0e2e55f429 100644 --- a/tests/unit/config/utils.py +++ b/tests/unit/config/utils.py @@ -50,13 +50,15 @@ DEFAULT_GRAPHRAG_CONFIG_SETTINGS = { "models": DEFAULT_MODEL_CONFIG, "vector_store": { - "type": defs.VECTOR_STORE_TYPE, - "db_uri": defs.VECTOR_STORE_DB_URI, - "container_name": defs.VECTOR_STORE_CONTAINER_NAME, - 
"overwrite": defs.VECTOR_STORE_OVERWRITE, - "url": None, - "api_key": None, - "audience": None, + "default": { + "type": defs.VECTOR_STORE_TYPE, + "db_uri": defs.VECTOR_STORE_DB_URI, + "container_name": defs.VECTOR_STORE_CONTAINER_NAME, + "overwrite": defs.VECTOR_STORE_OVERWRITE, + "url": None, + "api_key": None, + "audience": None, + }, }, "reporting": { "type": defs.REPORTING_TYPE, @@ -283,14 +285,23 @@ def assert_language_model_configs( assert expected.responses is None -def assert_vector_store_configs(actual: VectorStoreConfig, expected: VectorStoreConfig): - assert actual.type == expected.type - assert actual.db_uri == expected.db_uri - assert actual.container_name == expected.container_name - assert actual.overwrite == expected.overwrite - assert actual.url == expected.url - assert actual.api_key == expected.api_key - assert actual.audience == expected.audience +def assert_vector_store_configs( + actual: dict[str, VectorStoreConfig], + expected: dict[str, VectorStoreConfig], +): + assert type(actual) is type(expected) + assert len(actual) == len(expected) + for (index_a, store_a), (index_e, store_e) in zip( + actual.items(), expected.items(), strict=True + ): + assert index_a == index_e + assert store_a.type == store_e.type + assert store_a.db_uri == store_e.db_uri + assert store_a.url == store_e.url + assert store_a.api_key == store_e.api_key + assert store_a.audience == store_e.audience + assert store_a.container_name == store_e.container_name + assert store_a.overwrite == store_e.overwrite def assert_reporting_configs(