From 48c8d1bb17e4fda210e08cd7013ddf3ea7acb759 Mon Sep 17 00:00:00 2001 From: ilai Date: Sun, 5 Nov 2023 23:16:03 +0200 Subject: [PATCH 1/7] Context must contain a ContextContent that implements to_text() In order to support our current StuffingContextBuilder, I added a new StuffingConxtextContent which inherits from ContextContent and implement to_text() correctly. The app's `/query` path returns a `str`, which is the only guaranteed format of Context. It can be any strucured on unstrucutured data - depending on the ContextBuilder --- .../context_builder/stuffing.py | 19 +++--- src/canopy/context_engine/models.py | 22 ++++++- src/canopy/models/data_models.py | 10 +--- src/canopy_server/app.py | 6 +- tests/e2e/test_app.py | 9 +-- tests/unit/chat_engine/test_chat_engine.py | 19 +++--- .../test_stuffing_context_builder.py | 58 ++++++++++--------- .../context_engine/test_context_engine.py | 16 ++--- 8 files changed, 91 insertions(+), 68 deletions(-) diff --git a/src/canopy/context_engine/context_builder/stuffing.py b/src/canopy/context_engine/context_builder/stuffing.py index 18ebc56a..8024d547 100644 --- a/src/canopy/context_engine/context_builder/stuffing.py +++ b/src/canopy/context_engine/context_builder/stuffing.py @@ -2,7 +2,8 @@ from typing import List, Tuple from canopy.context_engine.context_builder.base import ContextBuilder -from canopy.context_engine.models import ContextQueryResult, ContextSnippet +from canopy.context_engine.models import (ContextQueryResult, ContextSnippet, + StuffingContextContent, ) from canopy.knowledge_base.models import QueryResult, DocumentWithScore from canopy.tokenizer import Tokenizer from canopy.models.data_models import Context @@ -24,12 +25,15 @@ def build(self, ContextQueryResult(query=qr.query, snippets=[]) for qr in query_results] debug_info = {"num_docs": len(sorted_docs_with_origin)} - context = Context(content=context_query_results, - num_tokens=0, - debug_info=debug_info) + context = Context( + content=StuffingContextContent(__root__=context_query_results), + num_tokens=0, + debug_info=debug_info + ) if self._tokenizer.token_count(context.to_text()) > max_context_tokens: - return Context(content=[], num_tokens=0, debug_info=debug_info) + return Context(content=StuffingContextContent(__root__=[]), + num_tokens=1, debug_info=debug_info) seen_doc_ids = set() for doc, origin_query_idx in sorted_docs_with_origin: @@ -45,8 +49,9 @@ def build(self, context_query_results[origin_query_idx].snippets.pop() # remove queries with no snippets - context.content = [qr for qr in context_query_results - if len(qr.snippets) > 0] + context.content = StuffingContextContent( + __root__=[qr for qr in context_query_results if len(qr.snippets) > 0] + ) context.num_tokens = self._tokenizer.token_count(context.to_text()) return context diff --git a/src/canopy/context_engine/models.py b/src/canopy/context_engine/models.py index 5f273425..1d0266fe 100644 --- a/src/canopy/context_engine/models.py +++ b/src/canopy/context_engine/models.py @@ -1,8 +1,8 @@ -from typing import List +from typing import List, Union from pydantic import BaseModel -from canopy.models.data_models import _ContextContent +from canopy.models.data_models import ContextContent class ContextSnippet(BaseModel): @@ -10,9 +10,25 @@ class ContextSnippet(BaseModel): text: str -class ContextQueryResult(_ContextContent): +class ContextQueryResult(BaseModel): query: str snippets: List[ContextSnippet] + +class StuffingContextContent(ContextContent): + __root__: Union[ContextQueryResult, List[ContextQueryResult]] + + def dict(self, **kwargs): + return super().dict(**kwargs)['__root__'] + + def __iter__(self): + return iter(self.__root__) + + def __getitem__(self, item): + return self.__root__[item] + + def __len__(self): + return len(self.__root__) + def to_text(self, **kwargs): return self.json(**kwargs) diff --git a/src/canopy/models/data_models.py b/src/canopy/models/data_models.py index 3f57ccb6..1fb365d1 100644 --- a/src/canopy/models/data_models.py +++ b/src/canopy/models/data_models.py @@ -56,7 +56,7 @@ def metadata_reseved_fields(cls, v): return v -class _ContextContent(BaseModel, ABC): +class ContextContent(BaseModel, ABC): # Any context should be able to be represented as well formatted text. # In the most minimal case, that could simply be a call to `.json()`. @abstractmethod @@ -64,19 +64,13 @@ def to_text(self, **kwargs) -> str: pass -ContextContent = Union[_ContextContent, Sequence[_ContextContent]] - - class Context(BaseModel): content: ContextContent num_tokens: int = Field(exclude=True) debug_info: dict = Field(default_factory=dict, exclude=True) def to_text(self, **kwargs) -> str: - if isinstance(self.content, Sequence): - return "\n".join([c.to_text(**kwargs) for c in self.content]) - else: - return self.content.to_text(**kwargs) + return self.content.to_text(**kwargs) # --------------------- LLM models ------------------------ diff --git a/src/canopy_server/app.py b/src/canopy_server/app.py index 3341c7bb..84aaf530 100644 --- a/src/canopy_server/app.py +++ b/src/canopy_server/app.py @@ -127,14 +127,13 @@ def stringify_content(response: StreamingChatResponse): @app.post( "/context/query", - response_model=ContextContent, responses={ 500: {"description": "Failed to query the knowledge base or build the context"} }, ) async def query( request: ContextQueryRequest = Body(...), -) -> ContextContent: +) -> str: """ Query the knowledge base for relevant context. The returned text may be structured or unstructured, depending on the Canopy configuration. @@ -147,8 +146,7 @@ async def query( queries=request.queries, max_context_tokens=request.max_tokens, ) - - return context.content + return context.to_text() except Exception as e: logger.exception(e) diff --git a/tests/e2e/test_app.py b/tests/e2e/test_app.py index 41357f4a..fdd8c148 100644 --- a/tests/e2e/test_app.py +++ b/tests/e2e/test_app.py @@ -1,3 +1,4 @@ +import json import os from typing import List @@ -27,14 +28,14 @@ ) -@retry(stop=stop_after_attempt(60), wait=wait_fixed(1)) +@retry(reraise=True, stop=stop_after_attempt(60), wait=wait_fixed(1)) def assert_vector_ids_exist(vector_ids: List[str], knowledge_base: KnowledgeBase): fetch_response = knowledge_base._index.fetch(ids=vector_ids) assert all([v_id in fetch_response["vectors"] for v_id in vector_ids]) -@retry(stop=stop_after_attempt(60), wait=wait_fixed(1)) +@retry(reraise=True, stop=stop_after_attempt(60), wait=wait_fixed(1)) def assert_vector_ids_not_exist(vector_ids: List[str], knowledge_base: KnowledgeBase): fetch_response = knowledge_base._index.fetch(ids=vector_ids) @@ -98,7 +99,7 @@ def test_upsert(client): assert upsert_response.is_success -@retry(stop=stop_after_attempt(60), wait=wait_fixed(1)) +@retry(reraise=True, stop=stop_after_attempt(60), wait=wait_fixed(1)) def test_query(client): # fetch the context with all the right filters query_payload = ContextQueryRequest( @@ -116,7 +117,7 @@ def test_query(client): assert query_response.is_success # test response is as expected on /query - response_as_json = query_response.json() + response_as_json = json.loads(query_response.json()) assert ( response_as_json[0]["query"] diff --git a/tests/unit/chat_engine/test_chat_engine.py b/tests/unit/chat_engine/test_chat_engine.py index 9841424a..330fb616 100644 --- a/tests/unit/chat_engine/test_chat_engine.py +++ b/tests/unit/chat_engine/test_chat_engine.py @@ -6,7 +6,8 @@ from canopy.chat_engine import ChatEngine from canopy.chat_engine.query_generator import QueryGenerator from canopy.context_engine import ContextEngine -from canopy.context_engine.models import ContextQueryResult, ContextSnippet +from canopy.context_engine.models import (ContextQueryResult, ContextSnippet, + StuffingContextContent, ) from canopy.llm import BaseLLM from canopy.models.data_models import SystemMessage from canopy.models.api_models import ChatResponse, _Choice, TokenCounts @@ -58,13 +59,15 @@ def _get_inputs_and_expected(self, ] mock_queries = [Query(text="How does photosynthesis work?")] mock_context = Context( - content=ContextQueryResult( - query="How does photosynthesis work?", - - snippets=[ContextSnippet(source="ref 1", - text=self._generate_text(snippet_length)), - ContextSnippet(source="ref 2", - text=self._generate_text(12))] + content=StuffingContextContent( + __root__=ContextQueryResult( + query="How does photosynthesis work?", + + snippets=[ContextSnippet(source="ref 1", + text=self._generate_text(snippet_length)), + ContextSnippet(source="ref 2", + text=self._generate_text(12))] + ) ), num_tokens=1 # TODO: This is a dummy value. Need to improve. ) diff --git a/tests/unit/context_builder/test_stuffing_context_builder.py b/tests/unit/context_builder/test_stuffing_context_builder.py index bfd9a899..3d487712 100644 --- a/tests/unit/context_builder/test_stuffing_context_builder.py +++ b/tests/unit/context_builder/test_stuffing_context_builder.py @@ -1,6 +1,6 @@ from canopy.context_engine.models import \ - ContextSnippet, ContextQueryResult -from canopy.models.data_models import Context + (ContextSnippet, ContextQueryResult, StuffingContextContent, ) +from canopy.models.data_models import Context, ContextContent from ..stubs.stub_tokenizer import StubTokenizer from canopy.knowledge_base.models import \ QueryResult, DocumentWithScore @@ -46,22 +46,25 @@ def setup_method(self): score=1.0) ]) ] - self.full_context = Context(content=[ - ContextQueryResult(query="test query 1", - snippets=[ - ContextSnippet( - text=self.text1, source="test_source1"), - ContextSnippet( - text=self.text2, source="test_source2") - ]), - ContextQueryResult(query="test query 2", - snippets=[ - ContextSnippet( - text=self.text3, source="test_source3"), - ContextSnippet( - text=self.text4, source="test_source4") - ]) - ], num_tokens=0) + self.full_context = Context( + content=StuffingContextContent(__root__=[ + ContextQueryResult(query="test query 1", + snippets=[ + ContextSnippet( + text=self.text1, source="test_source1"), + ContextSnippet( + text=self.text2, source="test_source2") + ]), + ContextQueryResult(query="test query 2", + snippets=[ + ContextSnippet( + text=self.text3, source="test_source3"), + ContextSnippet( + text=self.text4, source="test_source4") + ]) + ]), + num_tokens=0 + ) self.full_context.num_tokens = self.tokenizer.token_count( self.full_context.to_text()) @@ -74,7 +77,7 @@ def test_context_fits_within_max_tokens(self): def test_context_exceeds_max_tokens(self): context = self.builder.build(self.query_results, max_context_tokens=30) - expected_context = Context(content=[ + expected_context = Context(content=StuffingContextContent(__root__=[ ContextQueryResult(query="test query 1", snippets=[ ContextSnippet( @@ -85,7 +88,7 @@ def test_context_exceeds_max_tokens(self): ContextSnippet( text=self.text3, source="test_source3"), ]) - ], num_tokens=0) + ]), num_tokens=0) expected_context.num_tokens = self.tokenizer.token_count( expected_context.to_text()) @@ -96,13 +99,13 @@ def test_context_exceeds_max_tokens_unordered(self): self.query_results[0].documents[0].text = self.text1 * 100 context = self.builder.build(self.query_results, max_context_tokens=20) - expected_context = Context(content=[ + expected_context = Context(content=StuffingContextContent(__root__=[ ContextQueryResult(query="test query 2", snippets=[ ContextSnippet( text=self.text3, source="test_source3"), ]) - ], num_tokens=0) + ]), num_tokens=0) expected_context.num_tokens = self.tokenizer.token_count( expected_context.to_text()) @@ -111,18 +114,18 @@ def test_context_exceeds_max_tokens_unordered(self): def test_whole_query_results_not_fit(self): context = self.builder.build(self.query_results, max_context_tokens=10) - assert context.num_tokens == 0 + assert context.num_tokens == 1 assert context.content == [] def test_max_tokens_zero(self): context = self.builder.build(self.query_results, max_context_tokens=0) - self.assert_num_tokens(context, 0) + self.assert_num_tokens(context, 1) assert context.content == [] def test_empty_query_results(self): context = self.builder.build([], max_context_tokens=100) - self.assert_num_tokens(context, 0) - assert len(context.content) == 0 + self.assert_num_tokens(context, 1) + assert context.content == [] def test_documents_with_duplicates(self): duplicate_query_results = self.query_results + [ @@ -165,7 +168,7 @@ def test_empty_documents(self): ] context = self.builder.build( empty_query_results, max_context_tokens=100) - self.assert_num_tokens(context, 0) + self.assert_num_tokens(context, 1) assert context.content == [] def assert_num_tokens(self, context: Context, max_tokens: int): @@ -175,6 +178,7 @@ def assert_num_tokens(self, context: Context, max_tokens: int): @staticmethod def assert_contexts_equal(actual: Context, expected: Context): + assert isinstance(actual.content, ContextContent) assert actual.num_tokens == expected.num_tokens assert len(actual.content) == len(expected.content) for actual_qr, expected_qr in zip(actual.content, expected.content): diff --git a/tests/unit/context_engine/test_context_engine.py b/tests/unit/context_engine/test_context_engine.py index a102db36..17c85ee8 100644 --- a/tests/unit/context_engine/test_context_engine.py +++ b/tests/unit/context_engine/test_context_engine.py @@ -5,10 +5,11 @@ from canopy.context_engine import ContextEngine from canopy.context_engine.context_builder.base import ContextBuilder -from canopy.context_engine.models import ContextQueryResult, ContextSnippet +from canopy.context_engine.models import (ContextQueryResult, ContextSnippet, + StuffingContextContent, ) from canopy.knowledge_base.base import BaseKnowledgeBase from canopy.knowledge_base.models import QueryResult, DocumentWithScore -from canopy.models.data_models import Query, Context, _ContextContent +from canopy.models.data_models import Query, Context, ContextContent class TestContextEngine: @@ -68,7 +69,7 @@ def test_query(context_engine, queries = [Query(text="How does photosynthesis work?")] max_context_tokens = 100 - mock_context_content = create_autospec(_ContextContent) + mock_context_content = create_autospec(ContextContent) mock_context_content.to_text.return_value = sample_context_text mock_context = Context(content=mock_context_content, num_tokens=21) @@ -93,7 +94,7 @@ def test_query_with_metadata_filter(context_engine, queries = [Query(text="How does photosynthesis work?")] max_context_tokens = 100 - mock_context_content = create_autospec(_ContextContent) + mock_context_content = create_autospec(ContextContent) mock_context_content.to_text.return_value = sample_context_text mock_context = Context(content=mock_context_content, num_tokens=21) @@ -149,7 +150,7 @@ def test_multiple_queries(context_engine, mock_knowledge_base.query.return_value = extended_mock_query_result combined_text = sample_context_text + "\n" + text - mock_context_content = create_autospec(_ContextContent) + mock_context_content = create_autospec(ContextContent) mock_context_content.to_text.return_value = combined_text mock_context = Context(content=mock_context_content, num_tokens=40) @@ -168,7 +169,7 @@ def test_empty_query_results(context_engine, mock_knowledge_base.query.return_value = [] - mock_context_content = create_autospec(_ContextContent) + mock_context_content = create_autospec(ContextContent) mock_context_content.to_text.return_value = "" mock_context = Context(content=mock_context_content, num_tokens=0) @@ -183,7 +184,8 @@ def test_context_query_result_to_text(): query_result = ContextQueryResult(query="How does photosynthesis work?", snippets=[ContextSnippet(text="42", source="ref")]) - context = Context(content=query_result, num_tokens=1) + context = Context(content=StuffingContextContent(__root__=query_result), + num_tokens=1) assert context.to_text() == json.dumps(query_result.dict()) assert context.to_text(indent=2) == json.dumps(query_result.dict(), indent=2) From 457b61bda55cc74d2c288db5e281cc6eefb34bc1 Mon Sep 17 00:00:00 2001 From: ilai Date: Sun, 5 Nov 2023 23:44:14 +0200 Subject: [PATCH 2/7] linters --- examples/canopy-lib-quickstart.ipynb | 4 ++-- src/canopy/models/data_models.py | 2 +- src/canopy_server/app.py | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/examples/canopy-lib-quickstart.ipynb b/examples/canopy-lib-quickstart.ipynb index ac4981bf..a3e805cc 100644 --- a/examples/canopy-lib-quickstart.ipynb +++ b/examples/canopy-lib-quickstart.ipynb @@ -32,8 +32,8 @@ "output_type": "stream", "text": [ "\n", - "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip available: \u001b[0m\u001b[31;49m22.2.2\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m23.3.1\u001b[0m\n", - "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpip install --upgrade pip\u001b[0m\n" + "\u001B[1m[\u001B[0m\u001B[34;49mnotice\u001B[0m\u001B[1;39;49m]\u001B[0m\u001B[39;49m A new release of pip available: \u001B[0m\u001B[31;49m22.2.2\u001B[0m\u001B[39;49m -> \u001B[0m\u001B[32;49m23.3.1\u001B[0m\n", + "\u001B[1m[\u001B[0m\u001B[34;49mnotice\u001B[0m\u001B[1;39;49m]\u001B[0m\u001B[39;49m To update, run: \u001B[0m\u001B[32;49mpip install --upgrade pip\u001B[0m\n" ] } ], diff --git a/src/canopy/models/data_models.py b/src/canopy/models/data_models.py index 1fb365d1..b56ef185 100644 --- a/src/canopy/models/data_models.py +++ b/src/canopy/models/data_models.py @@ -1,6 +1,6 @@ from abc import ABC, abstractmethod from enum import Enum -from typing import Optional, List, Union, Dict, Sequence, Literal +from typing import Optional, List, Union, Dict, Literal from pydantic import BaseModel, Field, validator, Extra diff --git a/src/canopy_server/app.py b/src/canopy_server/app.py index 84aaf530..a1415e04 100644 --- a/src/canopy_server/app.py +++ b/src/canopy_server/app.py @@ -26,7 +26,7 @@ StreamingChatResponse, ChatResponse, ) -from canopy.models.data_models import Context, UserMessage, ContextContent +from canopy.models.data_models import Context, UserMessage from .api_models import ( ChatRequest, ContextQueryRequest, From 50efc687f689f62fa94f0cd4f08613b17d91e0fd Mon Sep 17 00:00:00 2001 From: ilai Date: Mon, 6 Nov 2023 10:56:08 +0200 Subject: [PATCH 3/7] [context] Simplify ContextContent - Made StuffingContextContent always a List - Slightly improved readability of `StuffingContextBuilder` --- .../context_engine/context_builder/stuffing.py | 17 +++++++---------- src/canopy/context_engine/models.py | 4 ++-- tests/unit/chat_engine/test_chat_engine.py | 4 ++-- .../unit/context_engine/test_context_engine.py | 6 +++--- 4 files changed, 14 insertions(+), 17 deletions(-) diff --git a/src/canopy/context_engine/context_builder/stuffing.py b/src/canopy/context_engine/context_builder/stuffing.py index 8024d547..4b8402c6 100644 --- a/src/canopy/context_engine/context_builder/stuffing.py +++ b/src/canopy/context_engine/context_builder/stuffing.py @@ -25,13 +25,9 @@ def build(self, ContextQueryResult(query=qr.query, snippets=[]) for qr in query_results] debug_info = {"num_docs": len(sorted_docs_with_origin)} - context = Context( - content=StuffingContextContent(__root__=context_query_results), - num_tokens=0, - debug_info=debug_info - ) + content = StuffingContextContent(__root__=context_query_results) - if self._tokenizer.token_count(context.to_text()) > max_context_tokens: + if self._tokenizer.token_count(content.to_text()) > max_context_tokens: return Context(content=StuffingContextContent(__root__=[]), num_tokens=1, debug_info=debug_info) @@ -45,16 +41,17 @@ def build(self, snippet) seen_doc_ids.add(doc.id) # if the context is too long, remove the snippet - if self._tokenizer.token_count(context.to_text()) > max_context_tokens: + if self._tokenizer.token_count(content.to_text()) > max_context_tokens: context_query_results[origin_query_idx].snippets.pop() # remove queries with no snippets - context.content = StuffingContextContent( + content = StuffingContextContent( __root__=[qr for qr in context_query_results if len(qr.snippets) > 0] ) - context.num_tokens = self._tokenizer.token_count(context.to_text()) - return context + return Context(content=content, + num_tokens=self._tokenizer.token_count(content.to_text()), + debug_info=debug_info) @staticmethod def _round_robin_sort( diff --git a/src/canopy/context_engine/models.py b/src/canopy/context_engine/models.py index 1d0266fe..2b21e20a 100644 --- a/src/canopy/context_engine/models.py +++ b/src/canopy/context_engine/models.py @@ -1,4 +1,4 @@ -from typing import List, Union +from typing import List from pydantic import BaseModel @@ -16,7 +16,7 @@ class ContextQueryResult(BaseModel): class StuffingContextContent(ContextContent): - __root__: Union[ContextQueryResult, List[ContextQueryResult]] + __root__: List[ContextQueryResult] def dict(self, **kwargs): return super().dict(**kwargs)['__root__'] diff --git a/tests/unit/chat_engine/test_chat_engine.py b/tests/unit/chat_engine/test_chat_engine.py index 330fb616..d6381a0c 100644 --- a/tests/unit/chat_engine/test_chat_engine.py +++ b/tests/unit/chat_engine/test_chat_engine.py @@ -60,14 +60,14 @@ def _get_inputs_and_expected(self, mock_queries = [Query(text="How does photosynthesis work?")] mock_context = Context( content=StuffingContextContent( - __root__=ContextQueryResult( + __root__=[ContextQueryResult( query="How does photosynthesis work?", snippets=[ContextSnippet(source="ref 1", text=self._generate_text(snippet_length)), ContextSnippet(source="ref 2", text=self._generate_text(12))] - ) + )] ), num_tokens=1 # TODO: This is a dummy value. Need to improve. ) diff --git a/tests/unit/context_engine/test_context_engine.py b/tests/unit/context_engine/test_context_engine.py index 17c85ee8..61977f2a 100644 --- a/tests/unit/context_engine/test_context_engine.py +++ b/tests/unit/context_engine/test_context_engine.py @@ -184,11 +184,11 @@ def test_context_query_result_to_text(): query_result = ContextQueryResult(query="How does photosynthesis work?", snippets=[ContextSnippet(text="42", source="ref")]) - context = Context(content=StuffingContextContent(__root__=query_result), + context = Context(content=StuffingContextContent(__root__=[query_result]), num_tokens=1) - assert context.to_text() == json.dumps(query_result.dict()) - assert context.to_text(indent=2) == json.dumps(query_result.dict(), indent=2) + assert context.to_text() == json.dumps([query_result.dict()]) + assert context.to_text(indent=2) == json.dumps([query_result.dict()], indent=2) @staticmethod @pytest.mark.asyncio From f0b40e970ac23f44545848b98f3d3235d34a8b92 Mon Sep 17 00:00:00 2001 From: ilai Date: Mon, 6 Nov 2023 11:03:01 +0200 Subject: [PATCH 4/7] [context] StuffingContextContent - Removed special iterator functions I changed the tests to use explicit json.loads() --- src/canopy/context_engine/models.py | 9 -------- .../test_stuffing_context_builder.py | 23 +++++++++++-------- 2 files changed, 14 insertions(+), 18 deletions(-) diff --git a/src/canopy/context_engine/models.py b/src/canopy/context_engine/models.py index 2b21e20a..89618451 100644 --- a/src/canopy/context_engine/models.py +++ b/src/canopy/context_engine/models.py @@ -21,14 +21,5 @@ class StuffingContextContent(ContextContent): def dict(self, **kwargs): return super().dict(**kwargs)['__root__'] - def __iter__(self): - return iter(self.__root__) - - def __getitem__(self, item): - return self.__root__[item] - - def __len__(self): - return len(self.__root__) - def to_text(self, **kwargs): return self.json(**kwargs) diff --git a/tests/unit/context_builder/test_stuffing_context_builder.py b/tests/unit/context_builder/test_stuffing_context_builder.py index 3d487712..e362d8e5 100644 --- a/tests/unit/context_builder/test_stuffing_context_builder.py +++ b/tests/unit/context_builder/test_stuffing_context_builder.py @@ -1,3 +1,5 @@ +import json + from canopy.context_engine.models import \ (ContextSnippet, ContextQueryResult, StuffingContextContent, ) from canopy.models.data_models import Context, ContextContent @@ -153,7 +155,8 @@ def test_source_metadata_missing(self): context = self.builder.build( missing_metadata_query_results, max_context_tokens=100) self.assert_num_tokens(context, 100) - assert context.content[0].snippets[0].source == "" + content = json.loads(context.to_text()) + assert content[0]["snippets"][0]["source"] == "" def test_empty_documents(self): empty_query_results = [ @@ -180,11 +183,13 @@ def assert_num_tokens(self, context: Context, max_tokens: int): def assert_contexts_equal(actual: Context, expected: Context): assert isinstance(actual.content, ContextContent) assert actual.num_tokens == expected.num_tokens - assert len(actual.content) == len(expected.content) - for actual_qr, expected_qr in zip(actual.content, expected.content): - assert actual_qr.query == expected_qr.query - assert len(actual_qr.snippets) == len(expected_qr.snippets) - for actual_snippet, expected_snippet in zip(actual_qr.snippets, - expected_qr.snippets): - assert actual_snippet.text == expected_snippet.text - assert actual_snippet.source == expected_snippet.source + actual_content = json.loads(actual.to_text()) + expected_content = json.loads(expected.to_text()) + assert len(actual_content) == len(expected_content) + for actual_qr, expected_qr in zip(actual_content, expected_content): + assert actual_qr["query"] == expected_qr["query"] + assert len(actual_qr["snippets"]) == len(expected_qr["snippets"]) + for actual_snippet, expected_snippet in zip(actual_qr["snippets"], + expected_qr["snippets"]): + assert actual_snippet["text"] == expected_snippet["text"] + assert actual_snippet["source"] == expected_snippet["source"] From 02637da6819dc18d74fa92cbfe4e2ef6ff70e03f Mon Sep 17 00:00:00 2001 From: ilai Date: Mon, 6 Nov 2023 13:43:53 +0200 Subject: [PATCH 5/7] [context] Moved SuffingContextBuilder's data models into the same file Makes the code more readable and explicit --- .../context_builder/stuffing.py | 34 +++++++++++++++++-- src/canopy/context_engine/models.py | 25 -------------- src/canopy/models/data_models.py | 15 ++++++-- src/canopy_server/app.py | 5 +-- tests/unit/chat_engine/test_chat_engine.py | 5 +-- .../test_stuffing_context_builder.py | 5 +-- .../context_engine/test_context_engine.py | 5 +-- 7 files changed, 56 insertions(+), 38 deletions(-) delete mode 100644 src/canopy/context_engine/models.py diff --git a/src/canopy/context_engine/context_builder/stuffing.py b/src/canopy/context_engine/context_builder/stuffing.py index 4b8402c6..5aa9b21b 100644 --- a/src/canopy/context_engine/context_builder/stuffing.py +++ b/src/canopy/context_engine/context_builder/stuffing.py @@ -1,13 +1,41 @@ from itertools import zip_longest from typing import List, Tuple +from pydantic import BaseModel + from canopy.context_engine.context_builder.base import ContextBuilder -from canopy.context_engine.models import (ContextQueryResult, ContextSnippet, - StuffingContextContent, ) from canopy.knowledge_base.models import QueryResult, DocumentWithScore from canopy.tokenizer import Tokenizer -from canopy.models.data_models import Context +from canopy.models.data_models import Context, ContextContent + + +# ------------- DATA MODELS ------------- + +class ContextSnippet(BaseModel): + source: str + text: str + + +class ContextQueryResult(BaseModel): + query: str + snippets: List[ContextSnippet] + + +class StuffingContextContent(ContextContent): + __root__: List[ContextQueryResult] + + def dict(self, **kwargs): + return super().dict(**kwargs)['__root__'] + + # In the case of StuffingContextBuilder, we simply want the text representation to + # be a json. Other ContextContent subclasses may render into text differently + def to_text(self, **kwargs): + # We can't use self.json() since this is mapped back to self.to_text() in the + # base class, which would cause infinite recursion. + return super(ContextContent, self).json(**kwargs) + +# ------------- CONTEXT BUILDER ------------- class StuffingContextBuilder(ContextBuilder): diff --git a/src/canopy/context_engine/models.py b/src/canopy/context_engine/models.py deleted file mode 100644 index 89618451..00000000 --- a/src/canopy/context_engine/models.py +++ /dev/null @@ -1,25 +0,0 @@ -from typing import List - -from pydantic import BaseModel - -from canopy.models.data_models import ContextContent - - -class ContextSnippet(BaseModel): - source: str - text: str - - -class ContextQueryResult(BaseModel): - query: str - snippets: List[ContextSnippet] - - -class StuffingContextContent(ContextContent): - __root__: List[ContextQueryResult] - - def dict(self, **kwargs): - return super().dict(**kwargs)['__root__'] - - def to_text(self, **kwargs): - return self.json(**kwargs) diff --git a/src/canopy/models/data_models.py b/src/canopy/models/data_models.py index b56ef185..78a07dd8 100644 --- a/src/canopy/models/data_models.py +++ b/src/canopy/models/data_models.py @@ -1,6 +1,6 @@ from abc import ABC, abstractmethod from enum import Enum -from typing import Optional, List, Union, Dict, Literal +from typing import Optional, List, Union, Dict, Literal, Any from pydantic import BaseModel, Field, validator, Extra @@ -63,15 +63,26 @@ class ContextContent(BaseModel, ABC): def to_text(self, **kwargs) -> str: pass + def __str__(self): + return self.to_text() + + def json(self, **kwargs): + return self.to_text(**kwargs) + class Context(BaseModel): content: ContextContent - num_tokens: int = Field(exclude=True) + num_tokens: int debug_info: dict = Field(default_factory=dict, exclude=True) def to_text(self, **kwargs) -> str: return self.content.to_text(**kwargs) + class Config: + @staticmethod + # Override the JSON schema, to show `content` as a string in the docs + def schema_extra(schema: dict[str, Any]) -> None: + schema['properties']['content'] = {'type': 'String', 'title': 'content'} # --------------------- LLM models ------------------------ diff --git a/src/canopy_server/app.py b/src/canopy_server/app.py index a1415e04..79af0868 100644 --- a/src/canopy_server/app.py +++ b/src/canopy_server/app.py @@ -127,13 +127,14 @@ def stringify_content(response: StreamingChatResponse): @app.post( "/context/query", + response_model=Context, responses={ 500: {"description": "Failed to query the knowledge base or build the context"} }, ) async def query( request: ContextQueryRequest = Body(...), -) -> str: +) -> Context: """ Query the knowledge base for relevant context. The returned text may be structured or unstructured, depending on the Canopy configuration. @@ -146,7 +147,7 @@ async def query( queries=request.queries, max_context_tokens=request.max_tokens, ) - return context.to_text() + return context except Exception as e: logger.exception(e) diff --git a/tests/unit/chat_engine/test_chat_engine.py b/tests/unit/chat_engine/test_chat_engine.py index d6381a0c..ea8d6415 100644 --- a/tests/unit/chat_engine/test_chat_engine.py +++ b/tests/unit/chat_engine/test_chat_engine.py @@ -6,8 +6,9 @@ from canopy.chat_engine import ChatEngine from canopy.chat_engine.query_generator import QueryGenerator from canopy.context_engine import ContextEngine -from canopy.context_engine.models import (ContextQueryResult, ContextSnippet, - StuffingContextContent, ) +from canopy.context_engine.context_builder.stuffing import (ContextSnippet, + ContextQueryResult, + StuffingContextContent, ) from canopy.llm import BaseLLM from canopy.models.data_models import SystemMessage from canopy.models.api_models import ChatResponse, _Choice, TokenCounts diff --git a/tests/unit/context_builder/test_stuffing_context_builder.py b/tests/unit/context_builder/test_stuffing_context_builder.py index e362d8e5..4881926b 100644 --- a/tests/unit/context_builder/test_stuffing_context_builder.py +++ b/tests/unit/context_builder/test_stuffing_context_builder.py @@ -1,7 +1,8 @@ import json -from canopy.context_engine.models import \ - (ContextSnippet, ContextQueryResult, StuffingContextContent, ) +from canopy.context_engine.context_builder.stuffing import (ContextSnippet, + ContextQueryResult, + StuffingContextContent, ) from canopy.models.data_models import Context, ContextContent from ..stubs.stub_tokenizer import StubTokenizer from canopy.knowledge_base.models import \ diff --git a/tests/unit/context_engine/test_context_engine.py b/tests/unit/context_engine/test_context_engine.py index 61977f2a..ec17c55c 100644 --- a/tests/unit/context_engine/test_context_engine.py +++ b/tests/unit/context_engine/test_context_engine.py @@ -5,8 +5,9 @@ from canopy.context_engine import ContextEngine from canopy.context_engine.context_builder.base import ContextBuilder -from canopy.context_engine.models import (ContextQueryResult, ContextSnippet, - StuffingContextContent, ) +from canopy.context_engine.context_builder.stuffing import (ContextSnippet, + ContextQueryResult, + StuffingContextContent, ) from canopy.knowledge_base.base import BaseKnowledgeBase from canopy.knowledge_base.models import QueryResult, DocumentWithScore from canopy.models.data_models import Query, Context, ContextContent From db9ca51ce36c2285aebe78be8fbb11fd43eaec9c Mon Sep 17 00:00:00 2001 From: ilai Date: Mon, 6 Nov 2023 16:41:25 +0200 Subject: [PATCH 6/7] [app] `/query` return type - added ContextResponse model KISS solution - simply return a different model than the actual internal `Context` --- .../context_engine/context_builder/stuffing.py | 6 ++---- src/canopy/models/data_models.py | 8 -------- src/canopy_server/api_models.py | 5 +++++ src/canopy_server/app.py | 6 ++++-- tests/e2e/test_app.py | 15 +++++++++------ 5 files changed, 20 insertions(+), 20 deletions(-) diff --git a/src/canopy/context_engine/context_builder/stuffing.py b/src/canopy/context_engine/context_builder/stuffing.py index 5aa9b21b..e1cd7c8d 100644 --- a/src/canopy/context_engine/context_builder/stuffing.py +++ b/src/canopy/context_engine/context_builder/stuffing.py @@ -30,12 +30,10 @@ def dict(self, **kwargs): # In the case of StuffingContextBuilder, we simply want the text representation to # be a json. Other ContextContent subclasses may render into text differently def to_text(self, **kwargs): - # We can't use self.json() since this is mapped back to self.to_text() in the - # base class, which would cause infinite recursion. - return super(ContextContent, self).json(**kwargs) + return self.json(**kwargs) -# ------------- CONTEXT BUILDER ------------- +# ------------- CONTEXT BUILDER ------------- class StuffingContextBuilder(ContextBuilder): diff --git a/src/canopy/models/data_models.py b/src/canopy/models/data_models.py index 78a07dd8..58c702b2 100644 --- a/src/canopy/models/data_models.py +++ b/src/canopy/models/data_models.py @@ -66,9 +66,6 @@ def to_text(self, **kwargs) -> str: def __str__(self): return self.to_text() - def json(self, **kwargs): - return self.to_text(**kwargs) - class Context(BaseModel): content: ContextContent @@ -78,11 +75,6 @@ class Context(BaseModel): def to_text(self, **kwargs) -> str: return self.content.to_text(**kwargs) - class Config: - @staticmethod - # Override the JSON schema, to show `content` as a string in the docs - def schema_extra(schema: dict[str, Any]) -> None: - schema['properties']['content'] = {'type': 'String', 'title': 'content'} # --------------------- LLM models ------------------------ diff --git a/src/canopy_server/api_models.py b/src/canopy_server/api_models.py index e965b8cb..49a7872a 100644 --- a/src/canopy_server/api_models.py +++ b/src/canopy_server/api_models.py @@ -31,6 +31,11 @@ class ContextQueryRequest(BaseModel): max_tokens: int +class ContextResponse(BaseModel): + content: str + num_tokens: int + + class ContextUpsertRequest(BaseModel): documents: List[Document] batch_size: int = Field( diff --git a/src/canopy_server/app.py b/src/canopy_server/app.py index 79af0868..8e18be1c 100644 --- a/src/canopy_server/app.py +++ b/src/canopy_server/app.py @@ -36,6 +36,7 @@ ShutdownResponse, SuccessUpsertResponse, SuccessDeleteResponse, + ContextResponse, ) from canopy.llm.openai import OpenAILLM @@ -127,7 +128,7 @@ def stringify_content(response: StreamingChatResponse): @app.post( "/context/query", - response_model=Context, + response_model=ContextResponse, responses={ 500: {"description": "Failed to query the knowledge base or build the context"} }, @@ -147,7 +148,8 @@ async def query( queries=request.queries, max_context_tokens=request.max_tokens, ) - return context + return ContextResponse(content=context.content.to_text(), + num_tokens=context.num_tokens) except Exception as e: logger.exception(e) diff --git a/tests/e2e/test_app.py b/tests/e2e/test_app.py index fdd8c148..141c3b4a 100644 --- a/tests/e2e/test_app.py +++ b/tests/e2e/test_app.py @@ -13,7 +13,7 @@ from canopy_server.app import app from canopy_server.api_models import (HealthStatus, ContextUpsertRequest, - ContextQueryRequest) + ContextQueryRequest, ContextResponse, ) from .. import Tokenizer upsert_payload = ContextUpsertRequest( @@ -102,6 +102,7 @@ def test_upsert(client): @retry(reraise=True, stop=stop_after_attempt(60), wait=wait_fixed(1)) def test_query(client): # fetch the context with all the right filters + tokenizer = Tokenizer() query_payload = ContextQueryRequest( queries=[ { @@ -116,16 +117,18 @@ def test_query(client): query_response = client.post("/context/query", json=query_payload.dict()) assert query_response.is_success - # test response is as expected on /query - response_as_json = json.loads(query_response.json()) + query_response = query_response.json() + assert (query_response["num_tokens"] == + len(tokenizer.tokenize(query_response["content"]))) + stuffing_content = json.loads(query_response["content"]) assert ( - response_as_json[0]["query"] + stuffing_content[0]["query"] == query_payload.dict()["queries"][0]["text"] - and response_as_json[0]["snippets"][0]["text"] + and stuffing_content[0]["snippets"][0]["text"] == upsert_payload.dict()["documents"][0]["text"] ) - assert (response_as_json[0]["snippets"][0]["source"] == + assert (stuffing_content[0]["snippets"][0]["source"] == upsert_payload.dict()["documents"][0]["source"]) From 3c828c405badf61f11a2e047dc169a15b11ad1b4 Mon Sep 17 00:00:00 2001 From: ilai Date: Mon, 6 Nov 2023 16:49:49 +0200 Subject: [PATCH 7/7] Linter fixes + wrong return type --- src/canopy/models/data_models.py | 2 +- src/canopy_server/app.py | 2 +- tests/e2e/test_app.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/canopy/models/data_models.py b/src/canopy/models/data_models.py index 58c702b2..dbaa8096 100644 --- a/src/canopy/models/data_models.py +++ b/src/canopy/models/data_models.py @@ -1,6 +1,6 @@ from abc import ABC, abstractmethod from enum import Enum -from typing import Optional, List, Union, Dict, Literal, Any +from typing import Optional, List, Union, Dict, Literal from pydantic import BaseModel, Field, validator, Extra diff --git a/src/canopy_server/app.py b/src/canopy_server/app.py index 8e18be1c..6c67494a 100644 --- a/src/canopy_server/app.py +++ b/src/canopy_server/app.py @@ -135,7 +135,7 @@ def stringify_content(response: StreamingChatResponse): ) async def query( request: ContextQueryRequest = Body(...), -) -> Context: +) -> ContextResponse: """ Query the knowledge base for relevant context. The returned text may be structured or unstructured, depending on the Canopy configuration. diff --git a/tests/e2e/test_app.py b/tests/e2e/test_app.py index 141c3b4a..70e5adca 100644 --- a/tests/e2e/test_app.py +++ b/tests/e2e/test_app.py @@ -13,7 +13,7 @@ from canopy_server.app import app from canopy_server.api_models import (HealthStatus, ContextUpsertRequest, - ContextQueryRequest, ContextResponse, ) + ContextQueryRequest, ) from .. import Tokenizer upsert_payload = ContextUpsertRequest(