Reduce Drift Response and Streaming endpoint (#1624)
* Adding basic wrappers for reduce in drift

* Add response_type parameter to run_drift_search and enhance reduce response functionality

* Add streaming endpoint

* Semver

* Spellcheck

* Ruff checks

* Count tokens on reduce

* Use list comprehension and remove llm_params map in favor of just using kwargs
AlonsoGuevara authored Jan 15, 2025
1 parent 4637270 commit 3defab2
Showing 16 changed files with 809 additions and 581 deletions.
4 changes: 4 additions & 0 deletions .semversioner/next-release/patch-20250115181733910773.json
@@ -0,0 +1,4 @@
{
"type": "patch",
"description": "Add Drift Reduce response and streaming endpoint"
}
2 changes: 2 additions & 0 deletions graphrag/api/__init__.py
@@ -13,6 +13,7 @@
basic_search,
basic_search_streaming,
drift_search,
drift_search_streaming,
global_search,
global_search_streaming,
local_search,
@@ -29,6 +30,7 @@
"local_search",
"local_search_streaming",
"drift_search",
"drift_search_streaming",
"basic_search",
"basic_search_streaming",
# prompt tuning API
98 changes: 89 additions & 9 deletions graphrag/api/query.py
@@ -348,6 +348,87 @@ async def local_search_streaming(
yield stream_chunk


@validate_call(config={"arbitrary_types_allowed": True})
async def drift_search_streaming(
config: GraphRagConfig,
nodes: pd.DataFrame,
entities: pd.DataFrame,
community_reports: pd.DataFrame,
text_units: pd.DataFrame,
relationships: pd.DataFrame,
community_level: int,
response_type: str,
query: str,
) -> AsyncGenerator:
"""Perform a DRIFT search and return the context data and response.
Parameters
----------
- config (GraphRagConfig): A graphrag configuration (from settings.yaml)
- nodes (pd.DataFrame): A DataFrame containing the final nodes (from create_final_nodes.parquet)
- entities (pd.DataFrame): A DataFrame containing the final entities (from create_final_entities.parquet)
- community_reports (pd.DataFrame): A DataFrame containing the final community reports (from create_final_community_reports.parquet)
- text_units (pd.DataFrame): A DataFrame containing the final text units (from create_final_text_units.parquet)
- relationships (pd.DataFrame): A DataFrame containing the final relationships (from create_final_relationships.parquet)
- community_level (int): The community level to search at.
- query (str): The user query to search for.
Returns
-------
TODO: Document the search response type and format.
Raises
------
TODO: Document any exceptions to expect.
"""
vector_store_args = config.embeddings.vector_store
logger.info(f"Vector Store Args: {redact(vector_store_args)}") # type: ignore # noqa

description_embedding_store = _get_embedding_store(
config_args=vector_store_args, # type: ignore
embedding_name=entity_description_embedding,
)

full_content_embedding_store = _get_embedding_store(
config_args=vector_store_args, # type: ignore
embedding_name=community_full_content_embedding,
)

entities_ = read_indexer_entities(nodes, entities, community_level)
reports = read_indexer_reports(community_reports, nodes, community_level)
read_indexer_report_embeddings(reports, full_content_embedding_store)
prompt = _load_search_prompt(config.root_dir, config.drift_search.prompt)
reduce_prompt = _load_search_prompt(
config.root_dir, config.drift_search.reduce_prompt
)

search_engine = get_drift_search_engine(
config=config,
reports=reports,
text_units=read_indexer_text_units(text_units),
entities=entities_,
relationships=read_indexer_relationships(relationships),
description_embedding_store=description_embedding_store, # type: ignore
local_system_prompt=prompt,
reduce_system_prompt=reduce_prompt,
response_type=response_type,
)

search_result = search_engine.astream_search(query=query)

# when streaming results, a context data object is returned as the first result
# and the query response in subsequent tokens
context_data = None
get_context_data = True
async for stream_chunk in search_result:
if get_context_data:
context_data = _reformat_context_data(stream_chunk) # type: ignore
yield context_data
get_context_data = False
else:
yield stream_chunk
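
For reference, a minimal consumer of the new generator might look like the sketch below: the first yielded item is the reformatted context data, and every subsequent item is a response token. The project root, parquet file names, community level, and the load_config helper are assumptions for illustration, not part of this diff.

import asyncio
from pathlib import Path

import pandas as pd

import graphrag.api as api
from graphrag.config.load_config import load_config  # assumed loader; adjust to your setup


async def main() -> None:
    root = Path("./ragtest")  # assumed project root with settings.yaml and an output folder
    config = load_config(root)
    output = root / "output"
    frames = {
        name: pd.read_parquet(output / f"{name}.parquet")
        for name in (
            "create_final_nodes",
            "create_final_entities",
            "create_final_community_reports",
            "create_final_text_units",
            "create_final_relationships",
        )
    }

    context_data = None
    first_chunk = True
    async for chunk in api.drift_search_streaming(
        config=config,
        nodes=frames["create_final_nodes"],
        entities=frames["create_final_entities"],
        community_reports=frames["create_final_community_reports"],
        text_units=frames["create_final_text_units"],
        relationships=frames["create_final_relationships"],
        community_level=2,
        response_type="Multiple Paragraphs",
        query="Who are the key figures and how are they connected?",
    ):
        if first_chunk:
            context_data = chunk  # the context data object always arrives first
            first_chunk = False
        else:
            print(chunk, end="", flush=True)  # response tokens follow
    print()


asyncio.run(main())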


@validate_call(config={"arbitrary_types_allowed": True})
async def drift_search(
config: GraphRagConfig,
@@ -357,6 +438,7 @@
text_units: pd.DataFrame,
relationships: pd.DataFrame,
community_level: int,
response_type: str,
query: str,
) -> tuple[
str | dict[str, Any] | list[dict[str, Any]],
@@ -400,6 +482,10 @@
reports = read_indexer_reports(community_reports, nodes, community_level)
read_indexer_report_embeddings(reports, full_content_embedding_store)
prompt = _load_search_prompt(config.root_dir, config.drift_search.prompt)
reduce_prompt = _load_search_prompt(
config.root_dir, config.drift_search.reduce_prompt
)

search_engine = get_drift_search_engine(
config=config,
reports=reports,
@@ -408,21 +494,15 @@
relationships=read_indexer_relationships(relationships),
description_embedding_store=description_embedding_store, # type: ignore
local_system_prompt=prompt,
reduce_system_prompt=reduce_prompt,
response_type=response_type,
)

result: SearchResult = await search_engine.asearch(query=query)
response = result.response
context_data = _reformat_context_data(result.context_data) # type: ignore

# TODO: Map/reduce the response to a single string with a comprehensive answer including all follow-ups
# For the time being, return highest scoring response (position 0) and context data
match response:
case dict():
return response["nodes"][0]["answer"], context_data # type: ignore
case str():
return response, context_data
case list():
return response, context_data
return response, context_data
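
The non-streaming counterpart is called the same way (a sketch reusing the frames from the previous example, run inside an async function); with the new reduce step in place, the first element of the returned tuple is expected to be a single synthesized answer rather than the raw highest-scoring follow-up node.

response, context_data = await api.drift_search(
    config=config,
    nodes=frames["create_final_nodes"],
    entities=frames["create_final_entities"],
    community_reports=frames["create_final_community_reports"],
    text_units=frames["create_final_text_units"],
    relationships=frames["create_final_relationships"],
    community_level=2,
    response_type="Multiple Paragraphs",
    query="Who are the key figures and how are they connected?",
)
print(response)  # one reduced answer covering the follow-ups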


@validate_call(config={"arbitrary_types_allowed": True})
6 changes: 5 additions & 1 deletion graphrag/cli/initialize.py
@@ -14,7 +14,10 @@
from graphrag.prompts.index.entity_extraction import GRAPH_EXTRACTION_PROMPT
from graphrag.prompts.index.summarize_descriptions import SUMMARIZE_PROMPT
from graphrag.prompts.query.basic_search_system_prompt import BASIC_SEARCH_SYSTEM_PROMPT
from graphrag.prompts.query.drift_search_system_prompt import DRIFT_LOCAL_SYSTEM_PROMPT
from graphrag.prompts.query.drift_search_system_prompt import (
DRIFT_LOCAL_SYSTEM_PROMPT,
DRIFT_REDUCE_PROMPT,
)
from graphrag.prompts.query.global_search_knowledge_system_prompt import (
GENERAL_KNOWLEDGE_INSTRUCTION,
)
@@ -57,6 +60,7 @@ def initialize_project_at(path: Path) -> None:
"claim_extraction": CLAIM_EXTRACTION_PROMPT,
"community_report": COMMUNITY_REPORT_PROMPT,
"drift_search_system_prompt": DRIFT_LOCAL_SYSTEM_PROMPT,
"drift_reduce_prompt": DRIFT_REDUCE_PROMPT,
"global_search_map_system_prompt": MAP_SYSTEM_PROMPT,
"global_search_reduce_system_prompt": REDUCE_SYSTEM_PROMPT,
"global_search_knowledge_system_prompt": GENERAL_KNOWLEDGE_INSTRUCTION,
3 changes: 2 additions & 1 deletion graphrag/cli/main.py
@@ -460,7 +460,8 @@ def _query_cli(
data_dir=data,
root_dir=root,
community_level=community_level,
streaming=False, # Drift search does not support streaming (yet)
streaming=streaming,
response_type=response_type,
query=query,
)
case SearchType.BASIC:
33 changes: 29 additions & 4 deletions graphrag/cli/query.py
@@ -202,6 +202,7 @@ def run_drift_search(
data_dir: Path | None,
root_dir: Path,
community_level: int,
response_type: str,
streaming: bool,
query: str,
):
@@ -234,8 +235,33 @@

# call the Query API
if streaming:
error_msg = "Streaming is not supported yet for DRIFT search."
raise NotImplementedError(error_msg)

async def run_streaming_search():
full_response = ""
context_data = None
get_context_data = True
async for stream_chunk in api.drift_search_streaming(
config=config,
nodes=final_nodes,
entities=final_entities,
community_reports=final_community_reports,
text_units=final_text_units,
relationships=final_relationships,
community_level=community_level,
response_type=response_type,
query=query,
):
if get_context_data:
context_data = stream_chunk
get_context_data = False
else:
full_response += stream_chunk
print(stream_chunk, end="") # noqa: T201
sys.stdout.flush() # flush output buffer to display text immediately
print() # noqa: T201
return full_response, context_data

return asyncio.run(run_streaming_search())

# not streaming
response, context_data = asyncio.run(
@@ -247,6 +273,7 @@
text_units=final_text_units,
relationships=final_relationships,
community_level=community_level,
response_type=response_type,
query=query,
)
)
@@ -281,8 +308,6 @@ def run_basic_search(
)
final_text_units: pd.DataFrame = dataframe_dict["create_final_text_units"]

print(streaming) # noqa: T201

# # call the Query API
if streaming:

5 changes: 5 additions & 0 deletions graphrag/config/create_graphrag_config.py
@@ -589,6 +589,7 @@ def hydrate_parallelization_params(
):
drift_search_model = DRIFTSearchConfig(
prompt=reader.str("prompt") or None,
reduce_prompt=reader.str("reduce_prompt") or None,
temperature=reader.float("llm_temperature")
or defs.DRIFT_SEARCH_LLM_TEMPERATURE,
top_p=reader.float("llm_top_p") or defs.DRIFT_SEARCH_LLM_TOP_P,
@@ -597,6 +598,10 @@
or defs.DRIFT_SEARCH_MAX_TOKENS,
data_max_tokens=reader.int("data_max_tokens")
or defs.DRIFT_SEARCH_DATA_MAX_TOKENS,
reduce_max_tokens=reader.int("reduce_max_tokens")
or defs.DRIFT_SEARCH_REDUCE_MAX_TOKENS,
reduce_temperature=reader.float("reduce_temperature")
or defs.DRIFT_SEARCH_REDUCE_LLM_TEMPERATURE,
concurrency=reader.int("concurrency") or defs.DRIFT_SEARCH_CONCURRENCY,
drift_k_followups=reader.int("drift_k_followups")
or defs.DRIFT_SEARCH_K_FOLLOW_UPS,
3 changes: 3 additions & 0 deletions graphrag/config/defaults.py
@@ -149,6 +149,9 @@
DRIFT_SEARCH_PRIMER_FOLDS = 5
DRIFT_SEARCH_PRIMER_MAX_TOKENS = 12_000

DRIFT_SEARCH_REDUCE_LLM_TEMPERATURE = 0
DRIFT_SEARCH_REDUCE_MAX_TOKENS = 2_000

DRIFT_LOCAL_SEARCH_TEXT_UNIT_PROP = 0.9
DRIFT_LOCAL_SEARCH_COMMUNITY_PROP = 0.1
DRIFT_LOCAL_SEARCH_TOP_K_MAPPED_ENTITIES = 10
1 change: 1 addition & 0 deletions graphrag/config/init_content.py
@@ -132,6 +132,7 @@
drift_search:
prompt: "prompts/drift_search_system_prompt.txt"
reduce_prompt: "prompts/drift_search_reduce_prompt.txt"
basic_search:
prompt: "prompts/basic_search_system_prompt.txt"
13 changes: 13 additions & 0 deletions graphrag/config/models/drift_search_config.py
@@ -14,6 +14,9 @@ class DRIFTSearchConfig(BaseModel):
prompt: str | None = Field(
description="The drift search prompt to use.", default=None
)
reduce_prompt: str | None = Field(
description="The drift search reduce prompt to use.", default=None
)
temperature: float = Field(
description="The temperature to use for token generation.",
default=defs.DRIFT_SEARCH_LLM_TEMPERATURE,
@@ -35,6 +38,16 @@
default=defs.DRIFT_SEARCH_DATA_MAX_TOKENS,
)

reduce_max_tokens: int = Field(
description="The reduce llm maximum tokens response to produce.",
default=defs.DRIFT_SEARCH_REDUCE_MAX_TOKENS,
)

reduce_temperature: float = Field(
description="The temperature to use for token generation in reduce.",
default=defs.DRIFT_SEARCH_REDUCE_LLM_TEMPERATURE,
)

concurrency: int = Field(
description="The number of concurrent requests.",
default=defs.DRIFT_SEARCH_CONCURRENCY,
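
For the new configuration surface, a small sketch; the values shown simply restate the shipped defaults, so an empty constructor is equivalent.

from graphrag.config.models.drift_search_config import DRIFTSearchConfig

drift_config = DRIFTSearchConfig(
    reduce_prompt=None,       # None falls back to the bundled DRIFT_REDUCE_PROMPT at search time
    reduce_max_tokens=2_000,  # defs.DRIFT_SEARCH_REDUCE_MAX_TOKENS
    reduce_temperature=0,     # defs.DRIFT_SEARCH_REDUCE_LLM_TEMPERATURE
)
print(drift_config.reduce_max_tokens)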
4 changes: 1 addition & 3 deletions graphrag/prompts/query/drift_search_system_prompt.py
@@ -106,7 +106,7 @@
---Target response length and format---
Multiple paragraphs
{response_type}
---Goal---
@@ -133,8 +133,6 @@
Add sections and commentary to the response as appropriate for the length and format. Style the response in markdown. Now answer the following query using the data above:
{query}
"""


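A quick check of the template change above (a sketch; the hunk context suggests the edited constant is DRIFT_LOCAL_SYSTEM_PROMPT):

from graphrag.prompts.query.drift_search_system_prompt import DRIFT_LOCAL_SYSTEM_PROMPT

# The hard-coded "Multiple paragraphs" instruction is replaced by a placeholder
# that the search engine fills with the requested response_type at query time.
print("{response_type}" in DRIFT_LOCAL_SYSTEM_PROMPT)  # True after this change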
4 changes: 4 additions & 0 deletions graphrag/query/factory.py
@@ -173,7 +173,9 @@ def get_drift_search_engine(
entities: list[Entity],
relationships: list[Relationship],
description_embedding_store: BaseVectorStore,
response_type: str,
local_system_prompt: str | None = None,
reduce_system_prompt: str | None = None,
) -> DRIFTSearch:
"""Create a local search engine based on data + configuration."""
llm = get_llm(config)
@@ -191,7 +193,9 @@
entity_text_embeddings=description_embedding_store,
text_units=text_units,
local_system_prompt=local_system_prompt,
reduce_system_prompt=reduce_system_prompt,
config=config.drift_search,
response_type=response_type,
),
token_encoder=token_encoder,
)
graphrag/query/structured_search/drift_search/drift_context.py
@@ -19,6 +19,7 @@
from graphrag.model.text_unit import TextUnit
from graphrag.prompts.query.drift_search_system_prompt import (
DRIFT_LOCAL_SYSTEM_PROMPT,
DRIFT_REDUCE_PROMPT,
)
from graphrag.query.context_builder.entity_extraction import EntityVectorStoreKey
from graphrag.query.llm.base import BaseTextEmbedding
@@ -51,13 +52,16 @@ def __init__(
config: DRIFTSearchConfig | None = None,
local_system_prompt: str | None = None,
local_mixed_context: LocalSearchMixedContext | None = None,
reduce_system_prompt: str | None = None,
response_type: str | None = None,
):
"""Initialize the DRIFT search context builder with necessary components."""
self.config = config or DRIFTSearchConfig()
self.chat_llm = chat_llm
self.text_embedder = text_embedder
self.token_encoder = token_encoder
self.local_system_prompt = local_system_prompt or DRIFT_LOCAL_SYSTEM_PROMPT
self.reduce_system_prompt = reduce_system_prompt or DRIFT_REDUCE_PROMPT

self.entities = entities
self.entity_text_embeddings = entity_text_embeddings
Expand All @@ -67,6 +71,8 @@ def __init__(
self.covariates = covariates
self.embedding_vectorstore_key = embedding_vectorstore_key

self.response_type = response_type

self.local_mixed_context = (
local_mixed_context or self.init_local_context_builder()
)
