Skip to content
This repository has been archived by the owner on Nov 13, 2024. It is now read-only.

Commit

Permalink
Merge pull request #144 from pinecone-io/refactor_context_content
Browse files Browse the repository at this point in the history
Context must contain a ContextContent that implements to_text()
  • Loading branch information
igiloh-pinecone authored Nov 6, 2023
2 parents 4e7a2ce + 3c828c4 commit 2104b65
Show file tree
Hide file tree
Showing 9 changed files with 141 additions and 107 deletions.
52 changes: 40 additions & 12 deletions src/canopy/context_engine/context_builder/stuffing.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,39 @@
from itertools import zip_longest
from typing import List, Tuple

from pydantic import BaseModel

from canopy.context_engine.context_builder.base import ContextBuilder
from canopy.context_engine.models import ContextQueryResult, ContextSnippet
from canopy.knowledge_base.models import QueryResult, DocumentWithScore
from canopy.tokenizer import Tokenizer
from canopy.models.data_models import Context
from canopy.models.data_models import Context, ContextContent


# ------------- DATA MODELS -------------

class ContextSnippet(BaseModel):
    """One piece of retrieved text together with the source it came from."""

    # Identifier of the origin of the text (compared against a document's
    # "source" field in the end-to-end test).
    source: str
    # The snippet's textual content.
    text: str


class ContextQueryResult(BaseModel):
    """The snippets collected for a single query."""

    # Text of the query these snippets belong to.
    query: str
    # Snippets gathered for the query; the builder starts from an empty list
    # and appends to it.
    snippets: List[ContextSnippet]


class StuffingContextContent(ContextContent):
    """Context content made of a plain list of per-query results.

    The list is held as the model's ``__root__`` value rather than under a
    named field.
    """

    __root__: List[ContextQueryResult]

    def dict(self, **kwargs):
        """Return the ``'__root__'`` entry of the parent's ``dict()`` output.

        Any keyword arguments are forwarded unchanged to the parent ``dict()``.
        """
        exported = super().dict(**kwargs)
        return exported['__root__']

    def to_text(self, **kwargs):
        """Render this content as text by returning ``self.json(**kwargs)``.

        For StuffingContextBuilder the text representation is simply a json;
        other ContextContent subclasses may render into text differently.
        """
        return self.json(**kwargs)


# ------------- CONTEXT BUILDER -------------

class StuffingContextBuilder(ContextBuilder):

Expand All @@ -24,12 +51,11 @@ def build(self,
ContextQueryResult(query=qr.query, snippets=[])
for qr in query_results]
debug_info = {"num_docs": len(sorted_docs_with_origin)}
context = Context(content=context_query_results,
num_tokens=0,
debug_info=debug_info)
content = StuffingContextContent(__root__=context_query_results)

if self._tokenizer.token_count(context.to_text()) > max_context_tokens:
return Context(content=[], num_tokens=0, debug_info=debug_info)
if self._tokenizer.token_count(content.to_text()) > max_context_tokens:
return Context(content=StuffingContextContent(__root__=[]),
num_tokens=1, debug_info=debug_info)

seen_doc_ids = set()
for doc, origin_query_idx in sorted_docs_with_origin:
Expand All @@ -41,15 +67,17 @@ def build(self,
snippet)
seen_doc_ids.add(doc.id)
# if the context is too long, remove the snippet
if self._tokenizer.token_count(context.to_text()) > max_context_tokens:
if self._tokenizer.token_count(content.to_text()) > max_context_tokens:
context_query_results[origin_query_idx].snippets.pop()

# remove queries with no snippets
context.content = [qr for qr in context_query_results
if len(qr.snippets) > 0]
content = StuffingContextContent(
__root__=[qr for qr in context_query_results if len(qr.snippets) > 0]
)

context.num_tokens = self._tokenizer.token_count(context.to_text())
return context
return Context(content=content,
num_tokens=self._tokenizer.token_count(content.to_text()),
debug_info=debug_info)

@staticmethod
def _round_robin_sort(
Expand Down
18 changes: 0 additions & 18 deletions src/canopy/context_engine/models.py

This file was deleted.

15 changes: 6 additions & 9 deletions src/canopy/models/data_models.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from abc import ABC, abstractmethod
from enum import Enum
from typing import Optional, List, Union, Dict, Sequence, Literal
from typing import Optional, List, Union, Dict, Literal

from pydantic import BaseModel, Field, validator, Extra

Expand Down Expand Up @@ -56,27 +56,24 @@ def metadata_reseved_fields(cls, v):
return v


class _ContextContent(BaseModel, ABC):
class ContextContent(BaseModel, ABC):
# Any context should be able to be represented as well formatted text.
# In the most minimal case, that could simply be a call to `.json()`.
@abstractmethod
def to_text(self, **kwargs) -> str:
pass


ContextContent = Union[_ContextContent, Sequence[_ContextContent]]
def __str__(self):
return self.to_text()


class Context(BaseModel):
content: ContextContent
num_tokens: int = Field(exclude=True)
num_tokens: int
debug_info: dict = Field(default_factory=dict, exclude=True)

def to_text(self, **kwargs) -> str:
if isinstance(self.content, Sequence):
return "\n".join([c.to_text(**kwargs) for c in self.content])
else:
return self.content.to_text(**kwargs)
return self.content.to_text(**kwargs)


# --------------------- LLM models ------------------------
Expand Down
5 changes: 5 additions & 0 deletions src/canopy_server/api_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,11 @@ class ContextQueryRequest(BaseModel):
max_tokens: int


class ContextResponse(BaseModel):
    """Response body of the ``/context/query`` endpoint."""

    # Text rendering of the context content (the endpoint fills this with
    # ``context.content.to_text()``).
    content: str
    # Token count reported for the context (taken from ``context.num_tokens``).
    num_tokens: int


class ContextUpsertRequest(BaseModel):
documents: List[Document]
batch_size: int = Field(
Expand Down
11 changes: 6 additions & 5 deletions src/canopy_server/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
StreamingChatResponse,
ChatResponse,
)
from canopy.models.data_models import Context, UserMessage, ContextContent
from canopy.models.data_models import Context, UserMessage
from .api_models import (
ChatRequest,
ContextQueryRequest,
Expand All @@ -36,6 +36,7 @@
ShutdownResponse,
SuccessUpsertResponse,
SuccessDeleteResponse,
ContextResponse,
)

from canopy.llm.openai import OpenAILLM
Expand Down Expand Up @@ -127,14 +128,14 @@ def stringify_content(response: StreamingChatResponse):

@app.post(
"/context/query",
response_model=ContextContent,
response_model=ContextResponse,
responses={
500: {"description": "Failed to query the knowledge base or build the context"}
},
)
async def query(
request: ContextQueryRequest = Body(...),
) -> ContextContent:
) -> ContextResponse:
"""
Query the knowledge base for relevant context.
The returned text may be structured or unstructured, depending on the Canopy configuration.
Expand All @@ -147,8 +148,8 @@ async def query(
queries=request.queries,
max_context_tokens=request.max_tokens,
)

return context.content
return ContextResponse(content=context.content.to_text(),
num_tokens=context.num_tokens)

except Exception as e:
logger.exception(e)
Expand Down
22 changes: 13 additions & 9 deletions tests/e2e/test_app.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import json
import os
from typing import List

Expand All @@ -12,7 +13,7 @@

from canopy_server.app import app
from canopy_server.api_models import (HealthStatus, ContextUpsertRequest,
ContextQueryRequest)
ContextQueryRequest, )
from .. import Tokenizer

upsert_payload = ContextUpsertRequest(
Expand All @@ -27,14 +28,14 @@
)


@retry(stop=stop_after_attempt(60), wait=wait_fixed(1))
@retry(reraise=True, stop=stop_after_attempt(60), wait=wait_fixed(1))
def assert_vector_ids_exist(vector_ids: List[str],
knowledge_base: KnowledgeBase):
fetch_response = knowledge_base._index.fetch(ids=vector_ids)
assert all([v_id in fetch_response["vectors"] for v_id in vector_ids])


@retry(stop=stop_after_attempt(60), wait=wait_fixed(1))
@retry(reraise=True, stop=stop_after_attempt(60), wait=wait_fixed(1))
def assert_vector_ids_not_exist(vector_ids: List[str],
knowledge_base: KnowledgeBase):
fetch_response = knowledge_base._index.fetch(ids=vector_ids)
Expand Down Expand Up @@ -98,9 +99,10 @@ def test_upsert(client):
assert upsert_response.is_success


@retry(stop=stop_after_attempt(60), wait=wait_fixed(1))
@retry(reraise=True, stop=stop_after_attempt(60), wait=wait_fixed(1))
def test_query(client):
# fetch the context with all the right filters
tokenizer = Tokenizer()
query_payload = ContextQueryRequest(
queries=[
{
Expand All @@ -115,16 +117,18 @@ def test_query(client):
query_response = client.post("/context/query", json=query_payload.dict())
assert query_response.is_success

# test response is as expected on /query
response_as_json = query_response.json()
query_response = query_response.json()
assert (query_response["num_tokens"] ==
len(tokenizer.tokenize(query_response["content"])))

stuffing_content = json.loads(query_response["content"])
assert (
response_as_json[0]["query"]
stuffing_content[0]["query"]
== query_payload.dict()["queries"][0]["text"]
and response_as_json[0]["snippets"][0]["text"]
and stuffing_content[0]["snippets"][0]["text"]
== upsert_payload.dict()["documents"][0]["text"]
)
assert (response_as_json[0]["snippets"][0]["source"] ==
assert (stuffing_content[0]["snippets"][0]["source"] ==
upsert_payload.dict()["documents"][0]["source"])


Expand Down
20 changes: 12 additions & 8 deletions tests/unit/chat_engine/test_chat_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@
from canopy.chat_engine import ChatEngine
from canopy.chat_engine.query_generator import QueryGenerator
from canopy.context_engine import ContextEngine
from canopy.context_engine.models import ContextQueryResult, ContextSnippet
from canopy.context_engine.context_builder.stuffing import (ContextSnippet,
ContextQueryResult,
StuffingContextContent, )
from canopy.llm import BaseLLM
from canopy.models.data_models import SystemMessage
from canopy.models.api_models import ChatResponse, _Choice, TokenCounts
Expand Down Expand Up @@ -58,13 +60,15 @@ def _get_inputs_and_expected(self,
]
mock_queries = [Query(text="How does photosynthesis work?")]
mock_context = Context(
content=ContextQueryResult(
query="How does photosynthesis work?",

snippets=[ContextSnippet(source="ref 1",
text=self._generate_text(snippet_length)),
ContextSnippet(source="ref 2",
text=self._generate_text(12))]
content=StuffingContextContent(
__root__=[ContextQueryResult(
query="How does photosynthesis work?",

snippets=[ContextSnippet(source="ref 1",
text=self._generate_text(snippet_length)),
ContextSnippet(source="ref 2",
text=self._generate_text(12))]
)]
),
num_tokens=1 # TODO: This is a dummy value. Need to improve.
)
Expand Down
Loading

0 comments on commit 2104b65

Please sign in to comment.