Commit 7702271

Merged in main and attempt to make work.
dorbaker committed Jan 22, 2025
2 parents 1a3014a + c644338 commit 7702271
Showing 107 changed files with 2,268 additions and 3,615 deletions.
4 changes: 4 additions & 0 deletions .semversioner/next-release/minor-20250109223356701278.json
@@ -0,0 +1,4 @@
+ {
+ "type": "minor",
+ "description": "Remove config inheritance, hydration, and automatic env var overlays."
+ }
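
With inheritance, hydration, and automatic env-var overlays removed, everything the pipeline needs must be spelled out in the project's settings file and loaded explicitly. A minimal sketch of the explicit loading pattern, using the same `load_config` call the migration notebook below relies on (the project path is a placeholder):

```python
from pathlib import Path

from graphrag.config.load_config import load_config

# Environment variables are no longer overlaid onto the config automatically,
# so the loaded settings file is the single source of truth.
config = load_config(Path("./my_project"))  # placeholder project root
print(type(config).__name__)  # GraphRagConfig
```
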
4 changes: 4 additions & 0 deletions .semversioner/next-release/patch-20250121205226363912.json
@@ -0,0 +1,4 @@
+ {
+ "type": "patch",
+ "description": "Fix DRIFT search on Azure AI Search."
+ }
82 changes: 37 additions & 45 deletions docs/examples_notebooks/index_migration.ipynb
@@ -2,7 +2,7 @@
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"execution_count": 66,
"metadata": {},
"outputs": [],
"source": [
@@ -25,7 +25,7 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 67,
"metadata": {},
"outputs": [],
"source": [
@@ -37,27 +37,28 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from pathlib import Path\n",
"\n",
"from graphrag.config.load_config import load_config\n",
"from graphrag.config.resolve_path import resolve_paths\n",
"from graphrag.index.create_pipeline_config import create_pipeline_config\n",
"from graphrag.storage.factory import create_storage\n",
"from graphrag.storage.factory import StorageFactory\n",
"\n",
"# This first block does some config loading, path resolution, and translation that is normally done by the CLI/API when running a full workflow\n",
"config = load_config(Path(PROJECT_DIRECTORY))\n",
"resolve_paths(config)\n",
"pipeline_config = create_pipeline_config(config)\n",
"storage = create_storage(pipeline_config.storage)"
"storage_config = config.storage.model_dump() # type: ignore\n",
"storage = StorageFactory().create_storage(\n",
" storage_type=storage_config[\"type\"], kwargs=storage_config\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 69,
"metadata": {},
"outputs": [],
"source": [
@@ -68,7 +69,7 @@
},
{
"cell_type": "code",
"execution_count": 63,
"execution_count": 70,
"metadata": {},
"outputs": [],
"source": [
@@ -97,7 +98,7 @@
},
{
"cell_type": "code",
"execution_count": 64,
"execution_count": 71,
"metadata": {},
"outputs": [],
"source": [
@@ -108,22 +109,16 @@
"# First we'll go through any parquet files that had model changes and update them\n",
"# The new data model may have removed excess columns as well, but we will only make the minimal changes required for compatibility\n",
"\n",
"final_documents = await load_table_from_storage(\n",
" \"create_final_documents.parquet\", storage\n",
")\n",
"final_text_units = await load_table_from_storage(\n",
" \"create_final_text_units.parquet\", storage\n",
")\n",
"final_entities = await load_table_from_storage(\"create_final_entities.parquet\", storage)\n",
"final_nodes = await load_table_from_storage(\"create_final_nodes.parquet\", storage)\n",
"final_documents = await load_table_from_storage(\"create_final_documents\", storage)\n",
"final_text_units = await load_table_from_storage(\"create_final_text_units\", storage)\n",
"final_entities = await load_table_from_storage(\"create_final_entities\", storage)\n",
"final_nodes = await load_table_from_storage(\"create_final_nodes\", storage)\n",
"final_relationships = await load_table_from_storage(\n",
" \"create_final_relationships.parquet\", storage\n",
")\n",
"final_communities = await load_table_from_storage(\n",
" \"create_final_communities.parquet\", storage\n",
" \"create_final_relationships\", storage\n",
")\n",
"final_communities = await load_table_from_storage(\"create_final_communities\", storage)\n",
"final_community_reports = await load_table_from_storage(\n",
" \"create_final_community_reports.parquet\", storage\n",
" \"create_final_community_reports\", storage\n",
")\n",
"\n",
"\n",
@@ -183,44 +178,41 @@
" parent_df, on=\"community\", how=\"left\"\n",
" )\n",
"\n",
"await write_table_to_storage(final_documents, \"create_final_documents.parquet\", storage)\n",
"await write_table_to_storage(final_documents, \"create_final_documents\", storage)\n",
"await write_table_to_storage(final_text_units, \"create_final_text_units\", storage)\n",
"await write_table_to_storage(final_entities, \"create_final_entities\", storage)\n",
"await write_table_to_storage(final_nodes, \"create_final_nodes\", storage)\n",
"await write_table_to_storage(final_relationships, \"create_final_relationships\", storage)\n",
"await write_table_to_storage(final_communities, \"create_final_communities\", storage)\n",
"await write_table_to_storage(\n",
- " final_text_units, \"create_final_text_units.parquet\", storage\n",
- ")\n",
- "await write_table_to_storage(final_entities, \"create_final_entities.parquet\", storage)\n",
- "await write_table_to_storage(final_nodes, \"create_final_nodes.parquet\", storage)\n",
- "await write_table_to_storage(\n",
- " final_relationships, \"create_final_relationships.parquet\", storage\n",
- ")\n",
- "await write_table_to_storage(\n",
- " final_communities, \"create_final_communities.parquet\", storage\n",
- ")\n",
- "await write_table_to_storage(\n",
- " final_community_reports, \"create_final_community_reports.parquet\", storage\n",
+ " final_community_reports, \"create_final_community_reports\", storage\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from graphrag.cache.factory import create_cache\n",
"from graphrag.cache.factory import CacheFactory\n",
"from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks\n",
"from graphrag.index.config.embeddings import get_embedded_fields, get_embedding_settings\n",
"from graphrag.index.flows.generate_text_embeddings import generate_text_embeddings\n",
"\n",
"# We only need to re-run the embeddings workflow, to ensure that embeddings for all required search fields are in place\n",
"# We'll construct the context and run this function flow directly to avoid everything else\n",
"\n",
"workflow = next(\n",
" (x for x in pipeline_config.workflows if x.name == \"generate_text_embeddings\"), None\n",
")\n",
"config = workflow.config\n",
"text_embed = config.get(\"text_embed\", {})\n",
"embedded_fields = config.get(\"embedded_fields\", {})\n",
"\n",
"embedded_fields = get_embedded_fields(config)\n",
"text_embed = get_embedding_settings(config)\n",
"callbacks = NoopWorkflowCallbacks()\n",
"cache = create_cache(pipeline_config.cache, PROJECT_DIRECTORY)\n",
"cache_config = config.cache.model_dump() # type: ignore\n",
"cache = CacheFactory().create_cache(\n",
" cache_type=cache_config[\"type\"], # type: ignore\n",
" root_dir=PROJECT_DIRECTORY,\n",
" kwargs=cache_config,\n",
")\n",
"\n",
"await generate_text_embeddings(\n",
" final_documents=None,\n",
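
The bulk of the notebook changes swap the old module-level `create_storage`/`create_cache` helpers for factory classes, with the relevant config section dumped to a plain dict and its `type` field selecting the backend. A condensed sketch of the new pattern, lifted from the diff above (`PROJECT_DIRECTORY` is a placeholder):

```python
from pathlib import Path

from graphrag.cache.factory import CacheFactory
from graphrag.config.load_config import load_config
from graphrag.storage.factory import StorageFactory

PROJECT_DIRECTORY = "./my_project"  # placeholder
config = load_config(Path(PROJECT_DIRECTORY))

# The storage factory receives the dumped config section as kwargs;
# storage_config["type"] picks the backend implementation.
storage_config = config.storage.model_dump()
storage = StorageFactory().create_storage(
    storage_type=storage_config["type"], kwargs=storage_config
)

# The cache factory follows the same shape, plus an explicit root_dir.
cache_config = config.cache.model_dump()
cache = CacheFactory().create_cache(
    cache_type=cache_config["type"], root_dir=PROJECT_DIRECTORY, kwargs=cache_config
)
```

Note that table names also lose their `.parquet` suffixes in the `load_table_from_storage`/`write_table_to_storage` calls, suggesting the storage layer now resolves the file extension itself.
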
15 changes: 2 additions & 13 deletions graphrag/api/index.py
@@ -11,7 +11,7 @@
import logging

from graphrag.cache.noop_pipeline_cache import NoopPipelineCache
- from graphrag.callbacks.factory import create_pipeline_reporter
+ from graphrag.callbacks.reporting import create_pipeline_reporter
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
from graphrag.config.enums import CacheType
from graphrag.config.models.graph_rag_config import GraphRagConfig
@@ -25,8 +25,6 @@

async def build_index(
config: GraphRagConfig,
- run_id: str = "",
- is_resume_run: bool = False,
memory_profile: bool = False,
callbacks: list[WorkflowCallbacks] | None = None,
progress_logger: ProgressLogger | None = None,
@@ -37,10 +35,6 @@ async def build_index(
----------
config : GraphRagConfig
The configuration.
- run_id : str
- The run id. Creates a output directory with this name.
- is_resume_run : bool default=False
- Whether to resume a previous index run.
memory_profile : bool
Whether to enable memory profiling.
callbacks : list[WorkflowCallbacks] | None default=None
@@ -53,11 +47,7 @@ async def build_index(
list[PipelineRunResult]
The list of pipeline run results
"""
- is_update_run = bool(config.update_index_storage)
-
- if is_resume_run and is_update_run:
- msg = "Cannot resume and update a run at the same time."
- raise ValueError(msg)
+ is_update_run = bool(config.update_index_output)

pipeline_cache = (
NoopPipelineCache() if config.cache.type == CacheType.none is None else None
@@ -79,7 +69,6 @@
cache=pipeline_cache,
callbacks=callbacks,
logger=progress_logger,
- run_id=run_id,
is_update_run=is_update_run,
):
outputs.append(output)
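
With `run_id` and `is_resume_run` removed, `build_index` infers an update run solely from `config.update_index_output`, and callers pass little more than the config. A minimal calling sketch under the new signature (the `workflow`/`errors` fields on `PipelineRunResult` are assumptions for illustration):

```python
import asyncio
from pathlib import Path

from graphrag.api.index import build_index
from graphrag.config.load_config import load_config

async def main() -> None:
    config = load_config(Path("./my_project"))  # placeholder project root
    # No run_id or is_resume_run: an update run happens only when
    # config.update_index_output is set.
    results = await build_index(config=config)
    for result in results:
        print(result.workflow, result.errors)  # assumed result fields

asyncio.run(main())
```
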
13 changes: 9 additions & 4 deletions graphrag/api/prompt_tune.py
@@ -17,7 +17,7 @@
from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.index.llm.load_llm import load_llm
from graphrag.logger.print_progress import PrintProgressLogger
- from graphrag.prompt_tune.defaults import MAX_TOKEN_COUNT
+ from graphrag.prompt_tune.defaults import MAX_TOKEN_COUNT, PROMPT_TUNING_MODEL_ID
from graphrag.prompt_tune.generator.community_report_rating import (
generate_community_report_rating,
)
@@ -95,9 +95,11 @@ async def generate_indexing_prompts(
)

# Create LLM from config
+ # TODO: Expose way to specify Prompt Tuning model ID through config
+ default_llm_settings = config.get_language_model_config(PROMPT_TUNING_MODEL_ID)
llm = load_llm(
"prompt_tuning",
- config.llm,
+ default_llm_settings,
cache=None,
callbacks=NoopWorkflowCallbacks(),
)
@@ -120,14 +122,17 @@
)

entity_types = None
+ entity_extraction_llm_settings = config.get_language_model_config(
+ config.entity_extraction.model_id
+ )
if discover_entity_types:
logger.info("Generating entity types...")
entity_types = await generate_entity_types(
llm,
domain=domain,
persona=persona,
docs=doc_list,
- json_mode=config.llm.model_supports_json or False,
+ json_mode=entity_extraction_llm_settings.model_supports_json or False,
)

logger.info("Generating entity relationship examples...")
@@ -147,7 +152,7 @@
examples=examples,
language=language,
json_mode=False, # config.llm.model_supports_json should be used, but these prompts are used in non-json mode by the index engine
- encoding_model=config.encoding_model,
+ encoding_model=entity_extraction_llm_settings.encoding_model,
max_token_count=max_tokens,
min_examples_required=min_examples_required,
)
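
The prompt-tuning changes replace the single global `config.llm` with per-task lookups through `config.get_language_model_config(...)`, keyed either by the built-in `PROMPT_TUNING_MODEL_ID` default or by the model ID a workflow's own config section names. A hedged sketch of how the two lookups compose, using only accessors visible in the diff:

```python
from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.prompt_tune.defaults import PROMPT_TUNING_MODEL_ID

def resolve_model_settings(config: GraphRagConfig):
    # Prompt tuning falls back to its package-level default model ID until a
    # config option is exposed (see the TODO in the diff).
    tuning_settings = config.get_language_model_config(PROMPT_TUNING_MODEL_ID)
    # Entity extraction resolves whichever model its own config section names,
    # carrying model_supports_json and encoding_model along with it.
    extraction_settings = config.get_language_model_config(
        config.entity_extraction.model_id
    )
    return tuning_settings, extraction_settings
```
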
18 changes: 10 additions & 8 deletions graphrag/api/query.py
@@ -23,12 +23,12 @@
import pandas as pd
from pydantic import validate_call

- from graphrag.config.models.graph_rag_config import GraphRagConfig
- from graphrag.index.config.embeddings import (
+ from graphrag.config.embeddings import (
community_full_content_embedding,
entity_description_embedding,
text_unit_text_embedding,
)
+ from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.logger.print_progress import PrintProgressLogger
from graphrag.query.factory import (
get_basic_search_engine,
@@ -410,7 +412,9 @@ async def local_search(
------
TODO: Document any exceptions to expect.
"""
- vector_store_args = config.embeddings.vector_store
+ vector_store_args = config.vector_store
+ if type(config.vector_store) is not list:
+ vector_store_args = config.vector_store.model_dump()
logger.info(f"Vector Store Args: {redact(vector_store_args)}") # type: ignore # noqa

description_embedding_store = get_embedding_store(
@@ -476,7 +478,7 @@ async def local_search_streaming(
------
TODO: Document any exceptions to expect.
"""
- vector_store_args = config.embeddings.vector_store
+ vector_store_args = config.vector_store.model_dump()
logger.info(f"Vector Store Args: {redact(vector_store_args)}") # type: ignore # noqa

description_embedding_store = get_embedding_store(
@@ -785,7 +787,7 @@ async def drift_search(
------
TODO: Document any exceptions to expect.
"""
- vector_store_args = config.embeddings.vector_store
+ vector_store_args = config.vector_store.model_dump()
logger.info(f"Vector Store Args: {redact(vector_store_args)}") # type: ignore # noqa

description_embedding_store = get_embedding_store(
@@ -859,7 +861,7 @@ async def drift_search_streaming(
------
TODO: Document any exceptions to expect.
"""
- vector_store_args = config.embeddings.vector_store
+ vector_store_args = config.vector_store.model_dump()
logger.info(f"Vector Store Args: {redact(vector_store_args)}") # type: ignore # noqa

description_embedding_store = get_embedding_store(
@@ -1143,7 +1145,7 @@ async def basic_search(
------
TODO: Document any exceptions to expect.
"""
- vector_store_args = config.embeddings.vector_store
+ vector_store_args = config.vector_store.model_dump()
logger.info(f"Vector Store Args: {redact(vector_store_args)}") # type: ignore # noqa

description_embedding_store = get_embedding_store(
@@ -1188,7 +1190,7 @@ async def basic_search_streaming(
------
TODO: Document any exceptions to expect.
"""
- vector_store_args = config.embeddings.vector_store
+ vector_store_args = config.vector_store.model_dump()
logger.info(f"Vector Store Args: {redact(vector_store_args)}") # type: ignore # noqa

description_embedding_store = get_embedding_store(
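
Every search entry point in `query.py` makes the same substitution: `vector_store` now hangs off the top-level config rather than `config.embeddings`, and pydantic models are dumped to plain dicts before being logged and handed to the embedding store. Only `local_search` guards for a list of stores; a small sketch of that normalization as the diff writes it:

```python
from graphrag.config.models.graph_rag_config import GraphRagConfig

def get_vector_store_args(config: GraphRagConfig):
    # config.vector_store may be a list of store configs (multi-store setups)
    # or a single pydantic model, which is dumped to a plain dict.
    if type(config.vector_store) is not list:
        return config.vector_store.model_dump()
    return config.vector_store
```

The streaming and DRIFT variants call `config.vector_store.model_dump()` unconditionally, so the list guard appears to be specific to `local_search` at this commit.
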
45 changes: 0 additions & 45 deletions graphrag/callbacks/factory.py

This file was deleted.
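
The deleted module previously exported `create_pipeline_reporter`, which `graphrag/api/index.py` now imports from its new home (see the import change above):

```python
# New import location after graphrag/callbacks/factory.py was deleted:
from graphrag.callbacks.reporting import create_pipeline_reporter
```
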
