Skip to content
This repository has been archived by the owner on Nov 13, 2024. It is now read-only.

Commit

Permalink
support kwargs for to_text of context
Browse files Browse the repository at this point in the history
  • Loading branch information
acatav committed Oct 26, 2023
1 parent d586a48 commit 85cf6c4
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 6 deletions.
4 changes: 2 additions & 2 deletions src/canopy/context_engine/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,5 +14,5 @@ class ContextQueryResult(ContextContent):
query: str
snippets: List[ContextSnippet]

def to_text(self):
return self.json()
def to_text(self, **kwargs):
return self.json(**kwargs)
8 changes: 4 additions & 4 deletions src/canopy/models/data_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ class ContextContent(BaseModel, ABC):
# Any context should be able to be represented as well formatted text.
# In the most minimal case, that could simply be a call to `.json()`.
@abstractmethod
def to_text(self) -> str:
def to_text(self, **kwargs) -> str:
pass


Expand All @@ -52,11 +52,11 @@ class Context(BaseModel):
num_tokens: int = Field(exclude=True)
debug_info: dict = Field(default_factory=dict, exclude=True)

def to_text(self) -> str:
def to_text(self, **kwargs) -> str:
if isinstance(self.content, ContextContent):
return self.content.to_text()
return self.content.to_text(**kwargs)
else:
return "\n".join([c.to_text() for c in self.content])
return "\n".join([c.to_text(**kwargs) for c in self.content])


# --------------------- LLM models ------------------------
Expand Down
13 changes: 13 additions & 0 deletions tests/unit/context_engine/test_context_engine.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
import json

import pytest
from unittest.mock import create_autospec

from canopy.context_engine import ContextEngine
from canopy.context_engine.context_builder.base import ContextBuilder
from canopy.context_engine.models import ContextQueryResult, ContextSnippet
from canopy.knowledge_base.base import BaseKnowledgeBase
from canopy.knowledge_base.models import QueryResult, DocumentWithScore
from canopy.models.data_models import Query, Context, ContextContent
Expand Down Expand Up @@ -175,6 +178,16 @@ def test_empty_query_results(context_engine,

assert result == mock_context

@staticmethod
def test_context_query_result_to_text():
query_result = ContextQueryResult(query="How does photosynthesis work?",
snippets=[ContextSnippet(text="42",
source="ref")])
context = Context(content=query_result, num_tokens=1)

assert context.to_text() == json.dumps(query_result.dict())
assert context.to_text(indent=2) == json.dumps(query_result.dict(), indent=2)

@staticmethod
@pytest.mark.asyncio
async def test_aquery_not_implemented(context_engine):
Expand Down

0 comments on commit 85cf6c4

Please sign in to comment.