From 99c3fa733d961b3758a4fa6303d38b617f2f7adb Mon Sep 17 00:00:00 2001 From: Izel Levy Date: Mon, 12 Feb 2024 12:29:34 +0200 Subject: [PATCH] Add support for Pydantic v2 --- pyproject.toml | 2 +- src/canopy/chat_engine/chat_engine.py | 2 +- .../context_builder/stuffing.py | 14 +++++------- src/canopy/context_engine/context_engine.py | 2 +- src/canopy/knowledge_base/knowledge_base.py | 6 ++--- .../knowledge_base/record_encoder/dense.py | 4 ++-- .../knowledge_base/record_encoder/hybrid.py | 2 +- src/canopy/knowledge_base/reranker/cohere.py | 2 +- src/canopy/llm/cohere.py | 4 ++-- src/canopy/llm/models.py | 11 +++++----- src/canopy/llm/openai.py | 16 +++++++------- src/canopy/models/api_models.py | 9 ++++---- src/canopy/models/data_models.py | 21 +++++++++--------- src/canopy/tokenizer/cohere.py | 4 ++-- src/canopy/tokenizer/llama.py | 2 +- src/canopy/tokenizer/openai.py | 2 +- src/canopy_cli/cli.py | 2 +- src/canopy_server/app.py | 2 +- src/canopy_server/models/v1/api_models.py | 6 ++--- tests/e2e/test_app.py | 12 +++++----- tests/system/llm/test_cohere.py | 2 +- tests/system/llm/test_openai.py | 2 +- tests/unit/chat_engine/test_chat_engine.py | 2 +- tests/unit/chunker/test_markdown_chunker.py | 18 +++++++-------- .../test_recursive_character_chunker.py | 22 +++++++++---------- tests/unit/chunker/test_token_chunker.py | 18 +++++++-------- .../test_stuffing_context_builder.py | 14 ++++++------ .../context_engine/test_context_engine.py | 6 ++--- .../test_raising_history_pruner.py | 2 +- .../test_recent_history_pruner.py | 3 +-- .../base_test_record_encoder.py | 4 ++-- tests/unit/stubs/stub_record_encoder.py | 4 ++-- 32 files changed, 108 insertions(+), 114 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 5636c238..f8a865e7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,7 +16,7 @@ pinecone-client = [{ version = "^3.0.0" }, python-dotenv = "^1.0.0" openai = "^1.2.3" tiktoken = "^0.3.3" -pydantic = "^1.10.7" +pydantic = "^2.0.0" pandas-stubs = "^2.0.3.230814" fastapi = ">=0.93.0, <1.0.0" uvicorn = ">=0.20.0, <1.0.0" diff --git a/src/canopy/chat_engine/chat_engine.py b/src/canopy/chat_engine/chat_engine.py index c75801ca..645fca17 100644 --- a/src/canopy/chat_engine/chat_engine.py +++ b/src/canopy/chat_engine/chat_engine.py @@ -221,7 +221,7 @@ def chat(self, model_params=model_params_dict) debug_info = {} if CANOPY_DEBUG_INFO: - debug_info['context'] = context.dict() + debug_info['context'] = context.model_dump() debug_info['context'].update(context.debug_info) if stream: diff --git a/src/canopy/context_engine/context_builder/stuffing.py b/src/canopy/context_engine/context_builder/stuffing.py index a0624863..67f34282 100644 --- a/src/canopy/context_engine/context_builder/stuffing.py +++ b/src/canopy/context_engine/context_builder/stuffing.py @@ -1,3 +1,4 @@ +import json from itertools import zip_longest from typing import List, Tuple @@ -23,15 +24,12 @@ class ContextQueryResult(BaseModel): class StuffingContextContent(ContextContent): - __root__: List[ContextQueryResult] - - def dict(self, **kwargs): - return super().dict(**kwargs)['__root__'] + root: List[ContextQueryResult] # In the case of StuffingContextBuilder, we simply want the text representation to # be a json. 
Other ContextContent subclasses may render into text differently def to_text(self, **kwargs): - return self.json(**kwargs) + return json.dumps(self.model_dump(), **kwargs) # ------------- CONTEXT BUILDER ------------- @@ -52,10 +50,10 @@ def build(self, ContextQueryResult(query=qr.query, snippets=[]) for qr in query_results] debug_info = {"num_docs": len(sorted_docs_with_origin), "snippet_ids": []} - content = StuffingContextContent(__root__=context_query_results) + content = StuffingContextContent(root=context_query_results) if self._tokenizer.token_count(content.to_text()) > max_context_tokens: - return Context(content=StuffingContextContent(__root__=[]), + return Context(content=StuffingContextContent(root=[]), num_tokens=1, debug_info=debug_info) seen_doc_ids = set() @@ -78,7 +76,7 @@ def build(self, # remove queries with no snippets content = StuffingContextContent( - __root__=[qr for qr in context_query_results if len(qr.snippets) > 0] + root=[qr for qr in context_query_results if len(qr.snippets) > 0] ) return Context(content=content, diff --git a/src/canopy/context_engine/context_engine.py b/src/canopy/context_engine/context_engine.py index 173adf44..31cd397a 100644 --- a/src/canopy/context_engine/context_engine.py +++ b/src/canopy/context_engine/context_engine.py @@ -110,7 +110,7 @@ def query(self, queries: List[Query], if CANOPY_DEBUG_INFO: context.debug_info["query_results"] = [ - {**qr.dict(), **qr.debug_info} for qr in query_results + {**qr.model_dump(), **qr.debug_info} for qr in query_results ] return context diff --git a/src/canopy/knowledge_base/knowledge_base.py b/src/canopy/knowledge_base/knowledge_base.py index 23e400b0..67688187 100644 --- a/src/canopy/knowledge_base/knowledge_base.py +++ b/src/canopy/knowledge_base/knowledge_base.py @@ -444,7 +444,7 @@ def query(self, query=rr.query, documents=[ DocumentWithScore( - **d.dict(exclude={ + **d.model_dump(exclude={ 'document_id' }) ) @@ -454,13 +454,13 @@ def query(self, query=r.query, documents=[ DocumentWithScore( - **d.dict(exclude={ + **d.model_dump(exclude={ 'document_id' }) ) for d in r.documents ] - ).dict()} if CANOPY_DEBUG_INFO else {} + ).model_dump()} if CANOPY_DEBUG_INFO else {} ) for rr, r in zip(ranked_results, results) ] diff --git a/src/canopy/knowledge_base/record_encoder/dense.py b/src/canopy/knowledge_base/record_encoder/dense.py index 7e605a25..3cf70e07 100644 --- a/src/canopy/knowledge_base/record_encoder/dense.py +++ b/src/canopy/knowledge_base/record_encoder/dense.py @@ -40,7 +40,7 @@ def _encode_documents_batch(self, encoded chunks: A list of KBEncodedDocChunk, with the `values` field populated by the generated embeddings vector. """ # noqa: E501 dense_values = self._dense_encoder.encode_documents([d.text for d in documents]) - return [KBEncodedDocChunk(**d.dict(), values=v) for d, v in + return [KBEncodedDocChunk(**d.model_dump(), values=v) for d, v in zip(documents, dense_values)] def _encode_queries_batch(self, queries: List[Query]) -> List[KBQuery]: @@ -52,7 +52,7 @@ def _encode_queries_batch(self, queries: List[Query]) -> List[KBQuery]: encoded queries: A list of KBQuery, with the `values` field populated by the generated embeddings vector. 
""" # noqa: E501 dense_values = self._dense_encoder.encode_queries([q.text for q in queries]) - return [KBQuery(**q.dict(), values=v) for q, v in zip(queries, dense_values)] + return [KBQuery(**q.model_dump(), values=v) for q, v in zip(queries, dense_values)] @cached_property def dimension(self) -> int: diff --git a/src/canopy/knowledge_base/record_encoder/hybrid.py b/src/canopy/knowledge_base/record_encoder/hybrid.py index ecc428ae..0fa5ff8a 100644 --- a/src/canopy/knowledge_base/record_encoder/hybrid.py +++ b/src/canopy/knowledge_base/record_encoder/hybrid.py @@ -124,7 +124,7 @@ def _encode_queries_batch(self, queries: List[Query]) -> List[KBQuery]: zip(dense_queries, sparse_values) ] - return [q.copy(update=dict(values=v, sparse_values=sv)) for q, (v, sv) in + return [q.model_copy(update=dict(values=v, sparse_values=sv)) for q, (v, sv) in zip(dense_queries, scaled_values)] @property diff --git a/src/canopy/knowledge_base/reranker/cohere.py b/src/canopy/knowledge_base/reranker/cohere.py index 2ab3a2b8..615581cb 100644 --- a/src/canopy/knowledge_base/reranker/cohere.py +++ b/src/canopy/knowledge_base/reranker/cohere.py @@ -70,7 +70,7 @@ def rerank(self, results: List[KBQueryResult]) -> List[KBQueryResult]: reranked_docs = [] for rerank_result in response: - doc = result.documents[rerank_result.index].copy( + doc = result.documents[rerank_result.index].model_copy( deep=True, update=dict(score=rerank_result.relevance_score) ) diff --git a/src/canopy/llm/cohere.py b/src/canopy/llm/cohere.py index f16f9fd1..cb55e75b 100644 --- a/src/canopy/llm/cohere.py +++ b/src/canopy/llm/cohere.py @@ -396,8 +396,8 @@ def generate_documents_from_stuffing_context_content( """ documents = [] - for result in content.__root__: + for result in content.root: for snippet in result.snippets: - documents.append(snippet.dict()) + documents.append(snippet.model_dump()) return documents diff --git a/src/canopy/llm/models.py b/src/canopy/llm/models.py index ed0c33a2..fe65997b 100644 --- a/src/canopy/llm/models.py +++ b/src/canopy/llm/models.py @@ -1,6 +1,6 @@ from typing import Optional, List, Union -from pydantic import BaseModel +from pydantic import BaseModel, model_serializer class FunctionPrimitiveProperty(BaseModel): @@ -17,8 +17,8 @@ class FunctionArrayProperty(BaseModel): # because the model is more struggling with them description: str - def dict(self, *args, **kwargs): - super_dict = super().dict(*args, **kwargs) + def model_dump(self, *args, **kwargs): + super_dict = super().model_dump(*args, **kwargs) if "items_type" in super_dict: super_dict["type"] = "array" super_dict["items"] = {"type": super_dict.pop("items_type")} @@ -32,11 +32,12 @@ class FunctionParameters(BaseModel): required_properties: List[FunctionProperty] optional_properties: List[FunctionProperty] = [] - def dict(self, *args, **kwargs): + @model_serializer() + def serialize_model(self): return { "type": "object", "properties": { - pro.name: pro.dict(exclude_none=True, exclude={"name"}) + pro.name: pro.model_dump(exclude_none=True, exclude={"name"}) for pro in self.required_properties + self.optional_properties }, "required": [pro.name for pro in self.required_properties], diff --git a/src/canopy/llm/openai.py b/src/canopy/llm/openai.py index 3e73248d..89c3ea84 100644 --- a/src/canopy/llm/openai.py +++ b/src/canopy/llm/openai.py @@ -121,8 +121,8 @@ def chat_completion(self, system_message = system_prompt else: system_message = system_prompt + f"\nContext: {context.to_text()}" - messages = [SystemMessage(content=system_message).dict() - ] + 
[m.dict() for m in chat_history] + messages = [SystemMessage(content=system_message).model_dump() + ] + [m.model_dump() for m in chat_history] try: response = self._client.chat.completions.create(model=model, messages=messages, @@ -133,12 +133,12 @@ def chat_completion(self, def streaming_iterator(response): for chunk in response: - yield StreamingChatChunk.parse_obj(chunk) + yield StreamingChatChunk.model_validate(chunk.model_dump()) if stream: return streaming_iterator(response) - return ChatResponse.parse_obj(response) + return ChatResponse.model_validate(response.model_dump()) @retry( reraise=True, @@ -206,10 +206,10 @@ def enforced_function_call(self, model = model_params_dict.pop("model", self.model_name) function_dict = cast(ChatCompletionToolParam, - {"type": "function", "function": function.dict()}) + {"type": "function", "function": function.model_dump()}) - messages = [SystemMessage(content=system_prompt).dict() - ] + [m.dict() for m in chat_history] + messages = [SystemMessage(content=system_prompt).model_dump() + ] + [m.model_dump() for m in chat_history] try: chat_completion = self._client.chat.completions.create( model=model, @@ -226,7 +226,7 @@ def enforced_function_call(self, result = chat_completion.choices[0].message.tool_calls[0].function.arguments arguments = json.loads(result) - jsonschema.validate(instance=arguments, schema=function.parameters.dict()) + jsonschema.validate(instance=arguments, schema=function.parameters.model_dump()) return arguments async def achat_completion(self, diff --git a/src/canopy/models/api_models.py b/src/canopy/models/api_models.py index 53a93585..2b70ac5b 100644 --- a/src/canopy/models/api_models.py +++ b/src/canopy/models/api_models.py @@ -1,6 +1,6 @@ from typing import Optional, Sequence, Iterable -from pydantic import BaseModel, Field, validator +from pydantic import BaseModel, Field, validator, ConfigDict, computed_field from canopy.models.data_models import MessageBase @@ -20,11 +20,10 @@ class _StreamChoice(BaseModel): class TokenCounts(BaseModel): prompt_tokens: int completion_tokens: int - total_tokens: Optional[int] = None - @validator("total_tokens", always=True) - def calc_total_tokens(cls, v, values, **kwargs): - return values["prompt_tokens"] + values["completion_tokens"] + @computed_field + def total_tokens(self) -> int: + return self.prompt_tokens + self.completion_tokens class ChatResponse(BaseModel): diff --git a/src/canopy/models/data_models.py b/src/canopy/models/data_models.py index 59badae5..4edfcb0b 100644 --- a/src/canopy/models/data_models.py +++ b/src/canopy/models/data_models.py @@ -2,8 +2,8 @@ from enum import Enum from typing import Optional, List, Union, Dict, Literal -from pydantic import BaseModel, Field, validator, Extra -from typing import TypedDict +from pydantic import field_validator, ConfigDict, BaseModel, Field, RootModel +from typing_extensions import TypedDict Metadata = Dict[str, Union[str, int, float, List[str]]] @@ -42,11 +42,10 @@ class Document(BaseModel): default_factory=dict, description="The document metadata. 
To learn more about metadata, see https://docs.pinecone.io/docs/manage-data", # noqa: E501 ) + model_config = ConfigDict(extra="forbid", coerce_numbers_to_str=True) - class Config: - extra = Extra.forbid - - @validator("metadata") + @field_validator("metadata") + @classmethod def metadata_reseved_fields(cls, v): if "text" in v: raise ValueError('Metadata cannot contain reserved field "text"') @@ -57,7 +56,7 @@ def metadata_reseved_fields(cls, v): return v -class ContextContent(BaseModel, ABC): +class ContextContent(RootModel, ABC): # Any context should be able to be represented as well formatted text. # In the most minimal case, that could simply be a call to `.json()`. @abstractmethod @@ -69,10 +68,10 @@ def __str__(self): class StringContextContent(ContextContent): - __root__: str + root: str def to_text(self, **kwargs) -> str: - return self.__root__ + return self.root class Context(BaseModel): @@ -98,8 +97,8 @@ class MessageBase(BaseModel): "Can be one of ['User', 'Assistant', 'System']") content: str = Field(description="The contents of the message.") - def dict(self, *args, **kwargs): - d = super().dict(*args, **kwargs) + def model_dump(self, *args, **kwargs): + d = super().model_dump(*args, **kwargs) d["role"] = d["role"].value return d diff --git a/src/canopy/tokenizer/cohere.py b/src/canopy/tokenizer/cohere.py index 172bbac0..0f00540c 100644 --- a/src/canopy/tokenizer/cohere.py +++ b/src/canopy/tokenizer/cohere.py @@ -105,7 +105,7 @@ def messages_token_count(self, messages: Messages) -> int: num_tokens = 0 for message in messages: num_tokens += self.MESSAGE_TOKENS_OVERHEAD - for key, value in message.dict().items(): + for key, value in message.model_dump().items(): num_tokens += self.token_count(value) num_tokens += self.FIXED_PREFIX_TOKENS return num_tokens @@ -191,7 +191,7 @@ def messages_token_count(self, messages: Messages) -> int: num_tokens = 0 for message in messages: num_tokens += self.MESSAGE_TOKENS_OVERHEAD - for key, value in message.dict().items(): + for key, value in message.model_dump().items(): num_tokens += self.token_count(value) num_tokens += self.FIXED_PREFIX_TOKENS return num_tokens diff --git a/src/canopy/tokenizer/llama.py b/src/canopy/tokenizer/llama.py index 9a71b50c..4c4ddc84 100644 --- a/src/canopy/tokenizer/llama.py +++ b/src/canopy/tokenizer/llama.py @@ -111,7 +111,7 @@ def messages_token_count(self, messages: Messages) -> int: num_tokens = 0 for message in messages: num_tokens += self.MESSAGE_TOKENS_OVERHEAD - for key, value in message.dict().items(): + for key, value in message.model_dump().items(): num_tokens += self.token_count(value) num_tokens += self.FIXED_PREFIX_TOKENS return num_tokens diff --git a/src/canopy/tokenizer/openai.py b/src/canopy/tokenizer/openai.py index 2c00256a..fc34a8d9 100644 --- a/src/canopy/tokenizer/openai.py +++ b/src/canopy/tokenizer/openai.py @@ -91,7 +91,7 @@ def messages_token_count(self, messages: Messages) -> int: num_tokens = 0 for message in messages: num_tokens += self.MESSAGE_TOKENS_OVERHEAD - for key, value in message.dict().items(): + for key, value in message.model_dump().items(): num_tokens += self.token_count(value) num_tokens += self.FIXED_PREFIX_TOKENS return num_tokens diff --git a/src/canopy_cli/cli.py b/src/canopy_cli/cli.py index 5e1bbd6b..4022cbc6 100644 --- a/src/canopy_cli/cli.py +++ b/src/canopy_cli/cli.py @@ -408,7 +408,7 @@ def upsert(index_name: str, ) raise CLIError(msg) pd.options.display.max_colwidth = 20 - click.echo(pd.DataFrame([doc.dict(exclude_none=True) for doc in data[:5]])) + 
click.echo(pd.DataFrame([doc.model_dump(exclude_none=True) for doc in data[:5]])) click.echo(click.style(f"\nTotal records: {len(data)}")) click.confirm(click.style("\nDoes this data look right?", fg="red"), abort=True) diff --git a/src/canopy_server/app.py b/src/canopy_server/app.py index 19b2f793..69c048d0 100644 --- a/src/canopy_server/app.py +++ b/src/canopy_server/app.py @@ -132,7 +132,7 @@ async def chat( session_id = request.user or "None" # noqa: F841 question_id = str(uuid.uuid4()) logger.debug(f"Received chat request: {request.messages[-1].content}") - model_params = request.dict(exclude={"messages", "stream"}) + model_params = request.model_dump(exclude={"messages", "stream"}) answer = await run_in_threadpool( chat_engine.chat, messages=request.messages, diff --git a/src/canopy_server/models/v1/api_models.py b/src/canopy_server/models/v1/api_models.py index 8d2c5aab..db1080f6 100644 --- a/src/canopy_server/models/v1/api_models.py +++ b/src/canopy_server/models/v1/api_models.py @@ -1,6 +1,6 @@ from typing import Dict, List, Optional, Union -from pydantic import BaseModel, Field +from pydantic import ConfigDict, BaseModel, Field from canopy.models.data_models import Messages, Query, Document @@ -70,9 +70,7 @@ class ChatRequest(BaseModel): default=None, description="A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse. Unused, reserved for future extensions", # noqa: E501 ) - - class Config: - extra = "ignore" + model_config = ConfigDict(extra="ignore") class ContextQueryRequest(BaseModel): diff --git a/tests/e2e/test_app.py b/tests/e2e/test_app.py index 6ffc39e9..4fab8312 100644 --- a/tests/e2e/test_app.py +++ b/tests/e2e/test_app.py @@ -104,7 +104,7 @@ def test_health(client): assert health_response.is_success assert ( health_response.json() - == HealthStatus(pinecone_status="OK", llm_status="OK").dict() + == HealthStatus(pinecone_status="OK", llm_status="OK").model_dump() ) @@ -112,7 +112,7 @@ def test_upsert(client, namespace_prefix): # Upsert a document to the index upsert_response = client.post( f"{namespace_prefix}context/upsert", - json=upsert_payload.dict()) + json=upsert_payload.model_dump()) assert upsert_response.is_success @@ -133,7 +133,7 @@ def test_query(client, namespace_prefix): query_response = client.post( f"{namespace_prefix}context/query", - json=query_payload.dict()) + json=query_payload.model_dump()) assert query_response.is_success query_response = query_response.json() @@ -143,12 +143,12 @@ def test_query(client, namespace_prefix): stuffing_content = json.loads(query_response["content"]) assert ( stuffing_content[0]["query"] - == query_payload.dict()["queries"][0]["text"] + == query_payload.model_dump()["queries"][0]["text"] and stuffing_content[0]["snippets"][0]["text"] - == upsert_payload.dict()["documents"][0]["text"] + == upsert_payload.model_dump()["documents"][0]["text"] ) assert (stuffing_content[0]["snippets"][0]["source"] == - upsert_payload.dict()["documents"][0]["source"]) + upsert_payload.model_dump()["documents"][0]["source"]) def test_chat_required_params(client, namespace_prefix): diff --git a/tests/system/llm/test_cohere.py b/tests/system/llm/test_cohere.py index 69a66571..45efe3e8 100644 --- a/tests/system/llm/test_cohere.py +++ b/tests/system/llm/test_cohere.py @@ -223,7 +223,7 @@ def test_chat_completion_with_stuffing_context_snippets(cohere_llm, expected_chat_kwargs, system_prompt): cohere_llm._client = MagicMock(wraps=cohere_llm._client) - content = StuffingContextContent(__root__=[ + 
content = StuffingContextContent(root=[ ContextQueryResult(query="", snippets=[ ContextSnippet( source="https://www.example.com/document", diff --git a/tests/system/llm/test_openai.py b/tests/system/llm/test_openai.py index b0d0c21f..f52a22b3 100644 --- a/tests/system/llm/test_openai.py +++ b/tests/system/llm/test_openai.py @@ -111,7 +111,7 @@ def test_chat_completion_with_context(openai_llm, messages): chat_history=messages, context=Context( content=StringContextContent( - __root__="context from kb" + root="context from kb" ), num_tokens=5 )) diff --git a/tests/unit/chat_engine/test_chat_engine.py b/tests/unit/chat_engine/test_chat_engine.py index 97628019..92ec7d62 100644 --- a/tests/unit/chat_engine/test_chat_engine.py +++ b/tests/unit/chat_engine/test_chat_engine.py @@ -60,7 +60,7 @@ def _get_inputs_and_expected(self, mock_queries = [Query(text="How does photosynthesis work?")] mock_context = Context( content=StuffingContextContent( - __root__=[ContextQueryResult( + root=[ContextQueryResult( query="How does photosynthesis work?", snippets=[ContextSnippet(source="ref 1", diff --git a/tests/unit/chunker/test_markdown_chunker.py b/tests/unit/chunker/test_markdown_chunker.py index db9692b9..4025e5bc 100644 --- a/tests/unit/chunker/test_markdown_chunker.py +++ b/tests/unit/chunker/test_markdown_chunker.py @@ -127,7 +127,7 @@ def expected_chunks(documents): '\ntext in level 3\n#### Level 4\ntext in level 4\n##### Level 5' '\ntext in level 5\n###### Level 6\ntext in level 6', source='doc_1', - metadata={'test': '1'}, + metadata={'test': 1}, document_id='test_document_1'), KBDocChunk( @@ -139,7 +139,7 @@ def expected_chunks(documents): '~~Strikethrough text~~\n\n' '## Another second level header\ntext after second level header', source='doc_1', - metadata={'test': '1'}, + metadata={'test': 1}, document_id='test_document_1'), KBDocChunk( @@ -156,13 +156,13 @@ def expected_chunks(documents): '\n\n## Blockquotes\n\n' '> This is a blockquote.', source='doc_1', - metadata={'test': '1'}, + metadata={'test': 1}, document_id='test_document_1'), KBDocChunk(id='test_document_1_3', text='## long text', source='doc_1', - metadata={'test': '1'}, + metadata={'test': 1}, document_id='test_document_1'), KBDocChunk(id='test_document_1_4', @@ -176,7 +176,7 @@ def expected_chunks(documents): 'Inside, not gold, But memories and ' 'tales. 
Of', source='doc_1', - metadata={'test': '1'}, + metadata={'test': 1}, document_id='test_document_1'), KBDocChunk(id='test_document_1_5', @@ -185,7 +185,7 @@ def expected_chunks(documents): 'Of brave ancestors, And ' 'magical whales.', source='doc_1', - metadata={'test': '1'}, + metadata={'test': 1}, document_id='test_document_1'), KBDocChunk(id='test_document_1_6', @@ -195,7 +195,7 @@ def expected_chunks(documents): "\nThe village united, " "Bathed in tales' light.", source='doc_1', - metadata={'test': '1'}, + metadata={'test': 1}, document_id='test_document_1'), KBDocChunk(id='test_document_1_7', @@ -206,13 +206,13 @@ def expected_chunks(documents): "\n```\n## table" "\na | b | c\n--- | --- | ---\n1 | 2 | 3", source='doc_1', - metadata={'test': '1'}, + metadata={'test': 1}, document_id='test_document_1'), KBDocChunk(id='test_document_3_0', text='# short markdown\nmarkdown is short', source='', - metadata={'test': '2'}, + metadata={'test': 2}, document_id='test_document_3') ] return chunks diff --git a/tests/unit/chunker/test_recursive_character_chunker.py b/tests/unit/chunker/test_recursive_character_chunker.py index e3fdbd22..70f195e1 100644 --- a/tests/unit/chunker/test_recursive_character_chunker.py +++ b/tests/unit/chunker/test_recursive_character_chunker.py @@ -19,49 +19,49 @@ def expected_chunks(documents): return [ KBDocChunk(id='test_document_1_0', text='I am a', - metadata={'test': '1'}, + metadata={'test': 1}, document_id='test_document_1'), KBDocChunk(id='test_document_1_1', text='a simple test', - metadata={'test': '1'}, + metadata={'test': 1}, document_id='test_document_1'), KBDocChunk(id='test_document_1_2', text='test string to', - metadata={'test': '1'}, + metadata={'test': 1}, document_id='test_document_1'), KBDocChunk(id='test_document_1_3', text='to check the', - metadata={'test': '1'}, + metadata={'test': 1}, document_id='test_document_1'), KBDocChunk(id='test_document_1_4', text='the happy path', - metadata={'test': '1'}, + metadata={'test': 1}, document_id='test_document_1'), KBDocChunk(id='test_document_1_5', text='path of this', - metadata={'test': '1'}, + metadata={'test': 1}, document_id='test_document_1'), KBDocChunk(id='test_document_1_6', text='this simple chunker', - metadata={'test': '1'}, + metadata={'test': 1}, document_id='test_document_1'), KBDocChunk(id='test_document_2_0', text='another simple test', - metadata={'test': '2'}, + metadata={'test': 2}, document_id='test_document_2', source='doc_2'), KBDocChunk(id='test_document_2_1', text='test string', - metadata={'test': '2'}, + metadata={'test': 2}, document_id='test_document_2', source='doc_2'), KBDocChunk(id='test_document_3_0', text='sho', - metadata={'test': '2'}, + metadata={'test': 2}, document_id='test_document_3', source='doc_3'), KBDocChunk(id='test_document_3_1', text='ort', - metadata={'test': '2'}, + metadata={'test': 2}, document_id='test_document_3', source='doc_3')] diff --git a/tests/unit/chunker/test_token_chunker.py b/tests/unit/chunker/test_token_chunker.py index b5550e3c..b6a1d739 100644 --- a/tests/unit/chunker/test_token_chunker.py +++ b/tests/unit/chunker/test_token_chunker.py @@ -19,32 +19,32 @@ def chunker(): def expected_chunks(documents): return [KBDocChunk(id='test_document_1_0', text='I am a simple test', - metadata={'test': '1'}, + metadata={'test': 1}, document_id='test_document_1'), KBDocChunk(id='test_document_1_1', text='simple test string to check', - metadata={'test': '1'}, + metadata={'test': 1}, document_id='test_document_1'), KBDocChunk(id='test_document_1_2', 
text='to check the happy path', - metadata={'test': '1'}, + metadata={'test': 1}, document_id='test_document_1'), KBDocChunk(id='test_document_1_3', text='happy path of this simple', - metadata={'test': '1'}, + metadata={'test': 1}, document_id='test_document_1'), KBDocChunk(id='test_document_1_4', text='this simple chunker', - metadata={'test': '1'}, + metadata={'test': 1}, document_id='test_document_1',), KBDocChunk(id='test_document_2_0', text='another simple test string', - metadata={'test': '2'}, + metadata={'test': 2}, document_id='test_document_2', source='doc_2'), KBDocChunk(id='test_document_3_0', text='short', - metadata={'test': '2'}, + metadata={'test': 2}, document_id='test_document_3', source='doc_3'), ] @@ -59,11 +59,11 @@ def test_chunk_single_document_zero_overlap(chunker): expected = [KBDocChunk(id='test_document_1_0', text='I am a test string', - metadata={'test': '1'}, + metadata={'test': 1}, document_id='test_document_1'), KBDocChunk(id='test_document_1_1', text='with no overlap', - metadata={'test': '1'}, + metadata={'test': 1}, document_id='test_document_1')] for actual_chunk, expected_chunk in zip(actual, expected): diff --git a/tests/unit/context_builder/test_stuffing_context_builder.py b/tests/unit/context_builder/test_stuffing_context_builder.py index 4881926b..b6e25670 100644 --- a/tests/unit/context_builder/test_stuffing_context_builder.py +++ b/tests/unit/context_builder/test_stuffing_context_builder.py @@ -50,7 +50,7 @@ def setup_method(self): ]) ] self.full_context = Context( - content=StuffingContextContent(__root__=[ + content=StuffingContextContent(root=[ ContextQueryResult(query="test query 1", snippets=[ ContextSnippet( @@ -80,7 +80,7 @@ def test_context_fits_within_max_tokens(self): def test_context_exceeds_max_tokens(self): context = self.builder.build(self.query_results, max_context_tokens=30) - expected_context = Context(content=StuffingContextContent(__root__=[ + expected_context = Context(content=StuffingContextContent(root=[ ContextQueryResult(query="test query 1", snippets=[ ContextSnippet( @@ -102,7 +102,7 @@ def test_context_exceeds_max_tokens_unordered(self): self.query_results[0].documents[0].text = self.text1 * 100 context = self.builder.build(self.query_results, max_context_tokens=20) - expected_context = Context(content=StuffingContextContent(__root__=[ + expected_context = Context(content=StuffingContextContent(root=[ ContextQueryResult(query="test query 2", snippets=[ ContextSnippet( @@ -118,17 +118,17 @@ def test_context_exceeds_max_tokens_unordered(self): def test_whole_query_results_not_fit(self): context = self.builder.build(self.query_results, max_context_tokens=10) assert context.num_tokens == 1 - assert context.content == [] + assert context.content.model_dump() == [] def test_max_tokens_zero(self): context = self.builder.build(self.query_results, max_context_tokens=0) self.assert_num_tokens(context, 1) - assert context.content == [] + assert context.content.model_dump() == [] def test_empty_query_results(self): context = self.builder.build([], max_context_tokens=100) self.assert_num_tokens(context, 1) - assert context.content == [] + assert context.content.model_dump() == [] def test_documents_with_duplicates(self): duplicate_query_results = self.query_results + [ @@ -173,7 +173,7 @@ def test_empty_documents(self): context = self.builder.build( empty_query_results, max_context_tokens=100) self.assert_num_tokens(context, 1) - assert context.content == [] + assert context.content.model_dump() == [] def assert_num_tokens(self, 
context: Context, max_tokens: int): assert context.num_tokens <= max_tokens diff --git a/tests/unit/context_engine/test_context_engine.py b/tests/unit/context_engine/test_context_engine.py index 1ed2b52b..f86dc66c 100644 --- a/tests/unit/context_engine/test_context_engine.py +++ b/tests/unit/context_engine/test_context_engine.py @@ -186,11 +186,11 @@ def test_context_query_result_to_text(): query_result = ContextQueryResult(query="How does photosynthesis work?", snippets=[ContextSnippet(text="42", source="ref")]) - context = Context(content=StuffingContextContent(__root__=[query_result]), + context = Context(content=StuffingContextContent(root=[query_result]), num_tokens=1) - assert context.to_text() == json.dumps([query_result.dict()]) - assert context.to_text(indent=2) == json.dumps([query_result.dict()], indent=2) + assert context.to_text() == json.dumps([query_result.model_dump()]) + assert context.to_text(indent=2) == json.dumps([query_result.model_dump()], indent=2) @pytest.mark.asyncio diff --git a/tests/unit/history_pruner/test_raising_history_pruner.py b/tests/unit/history_pruner/test_raising_history_pruner.py index 03a3956c..8bb62a3e 100644 --- a/tests/unit/history_pruner/test_raising_history_pruner.py +++ b/tests/unit/history_pruner/test_raising_history_pruner.py @@ -7,7 +7,7 @@ SAMPLE_CONTEXT = Context(content=StringContextContent( - __root__="Some context information"), num_tokens=3 + root="Some context information"), num_tokens=3 ) SYSTEM_PROMPT = "This is a system prompt." diff --git a/tests/unit/history_pruner/test_recent_history_pruner.py b/tests/unit/history_pruner/test_recent_history_pruner.py index 7cf73a7e..ec3346d7 100644 --- a/tests/unit/history_pruner/test_recent_history_pruner.py +++ b/tests/unit/history_pruner/test_recent_history_pruner.py @@ -6,8 +6,7 @@ from canopy.tokenizer import Tokenizer -SAMPLE_CONTEXT = Context(content=StringContextContent( - __root__="Some context information"), num_tokens=3 +SAMPLE_CONTEXT = Context(content=StringContextContent(root="Some context information"), num_tokens=3 ) SYSTEM_PROMPT = "This is a system prompt." 
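The history-pruner hunks above, like the stuffing-builder and context-engine test changes before them, all trace back to the one genuinely breaking change in this migration: Pydantic v2 removes `__root__`-style custom root models in favor of RootModel, with the wrapped value living in a field literally named `root` (which is why `ContextContent` now subclasses `RootModel` in data_models.py). A minimal standalone sketch of the pattern, assuming pydantic>=2; the class mirrors `StringContextContent` but is illustrative only:

from pydantic import RootModel


class StringContent(RootModel):
    root: str

    def to_text(self) -> str:
        return self.root


content = StringContent(root="Some context information")
# A RootModel serializes to its bare root value rather than to {"root": ...},
# which is why the stuffing-builder tests compare
# context.content.model_dump() against a plain list.
assert content.model_dump() == "Some context information"
assert content.to_text() == "Some context information"
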
diff --git a/tests/unit/record_encoder/base_test_record_encoder.py b/tests/unit/record_encoder/base_test_record_encoder.py index 627d12ca..16658e50 100644 --- a/tests/unit/record_encoder/base_test_record_encoder.py +++ b/tests/unit/record_encoder/base_test_record_encoder.py @@ -47,14 +47,14 @@ def queries(): @pytest.fixture def expected_encoded_documents(documents, inner_encoder): values = inner_encoder.encode_documents([d.text for d in documents]) - return [KBEncodedDocChunk(**d.dict(), values=v) for d, v in + return [KBEncodedDocChunk(**d.model_dump(), values=v) for d, v in zip(documents, values)] @staticmethod @pytest.fixture def expected_encoded_queries(queries, inner_encoder): values = inner_encoder.encode_queries([q.text for q in queries]) - return [KBQuery(**q.dict(), values=v) for q, v in zip(queries, values)] + return [KBQuery(**q.model_dump(), values=v) for q, v in zip(queries, values)] @staticmethod def test_dimension(record_encoder, expected_dimension): diff --git a/tests/unit/stubs/stub_record_encoder.py b/tests/unit/stubs/stub_record_encoder.py index 2b77f170..7222d0df 100644 --- a/tests/unit/stubs/stub_record_encoder.py +++ b/tests/unit/stubs/stub_record_encoder.py @@ -22,7 +22,7 @@ def _encode_documents_batch(self, values = self._dense_encoder.encode_documents(doc.text) result.append( KBEncodedDocChunk( - **doc.dict(), + **doc.model_dump(), values=values)) return result @@ -33,7 +33,7 @@ def _encode_queries_batch(self, for query in queries: values = self._dense_encoder.encode_queries(query.text) result.append( - KBQuery(**query.dict(), + KBQuery(**query.model_dump(), values=values)) return result
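The rest of the patch is the mechanical rename layer of Pydantic's v1-to-v2 migration map: .dict() becomes .model_dump(), .json() becomes .model_dump_json(), .copy() becomes .model_copy(), .parse_obj() becomes .model_validate(), @validator becomes @field_validator, and an inner `class Config` becomes `model_config = ConfigDict(...)`. A minimal sketch pulling the common renames together, assuming pydantic>=2; `Doc` is a hypothetical model for illustration, not one from this codebase:

from pydantic import BaseModel, ConfigDict, computed_field, field_validator


class Doc(BaseModel):
    # v1: class Config: extra = Extra.forbid
    model_config = ConfigDict(extra="forbid")

    text: str
    tokens: int = 0

    # v1: @validator("text")
    @field_validator("text")
    @classmethod
    def text_not_empty(cls, v: str) -> str:
        if not v:
            raise ValueError("text must not be empty")
        return v

    # v1 idiom: an Optional field filled by @validator(..., always=True),
    # as TokenCounts.total_tokens did before this patch.
    @computed_field
    def text_length(self) -> int:
        return len(self.text)


doc = Doc(text="hello")
doc.model_dump()                        # v1: doc.dict(); includes the computed text_length
doc.model_dump_json()                   # v1: doc.json()
doc.model_copy(update={"tokens": 5})    # v1: doc.copy(update=...)
Doc.model_validate({"text": "hello"})   # v1: Doc.parse_obj(...)

One subtlety worth noting: overriding model_dump() directly (as MessageBase and FunctionArrayProperty still do) works when the method is called explicitly, but v2 no longer routes nested serialization through it, which is why FunctionParameters moved to the @model_serializer hook.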