feature: Log metadata for retrieved chunks (#50)
KevinJBoyer authored Aug 13, 2024
1 parent b0c8719 commit 43fc8da
Showing 5 changed files with 61 additions and 29 deletions.
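In short: each bot reply previously attached a flat {document title: score} dict as metadata, so multiple chunks retrieved from the same document collapsed into a single entry; this commit logs a structured per-chunk list instead. A minimal sketch of the shape change (values are illustrative, and the integer chunk ids are an assumption about the model):

```python
# Before this commit: metadata keyed by document title, so chunks from
# the same document overwrite each other's scores in the merged dict.
old_metadata = {"Some Document": 0.99, "Another Document": 0.90}

# After this commit: one entry per retrieved chunk, chunk identity kept.
new_metadata = {
    "chunks": [
        {"document.name": "Some Document", "chunk.id": 1, "score": 0.99},
        {"document.name": "Some Document", "chunk.id": 2, "score": 0.90},
    ]
}
```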
25 changes: 17 additions & 8 deletions app/src/chainlit.py
@@ -1,13 +1,14 @@
 import logging
 import pprint
-from typing import Any
+from typing import Any, Sequence
 from urllib.parse import parse_qs, urlparse
 
 import chainlit as cl
 from chainlit.input_widget import InputWidget, Select, Slider
 from src import chat_engine
 from src.app_config import app_config
 from src.chat_engine import ChatEngineInterface
+from src.db.models.document import ChunkWithScore
 from src.generate import get_models
 from src.login import require_login

@@ -47,7 +48,7 @@ async def start() -> None:
 
     await cl.Message(
         author="backend",
-        metadata={"engine": engine_id, "settings": str(settings)},
+        metadata={"engine": engine_id, "settings": settings},
         content=f"{engine.name} started with settings:\n{pprint.pformat(settings, indent=3)}",
     ).send()

@@ -149,14 +150,9 @@ async def on_message(message: cl.Message) -> None:
             chunks_with_scores=result.chunks_with_scores,
         )
 
-        chunk_titles_and_scores: dict[str, float] = {}
-        for chunk_with_score in result.chunks_with_scores:
-            title = chunk_with_score.chunk.document.name
-            chunk_titles_and_scores |= {title: chunk_with_score.score}
-
         await cl.Message(
             content=msg_content,
-            metadata=chunk_titles_and_scores,
+            metadata=_get_retrieval_metadata(result.chunks_with_scores),
         ).send()
     except Exception as err:  # pylint: disable=broad-exception-caught
         await cl.Message(
@@ -166,3 +162,16 @@ async def on_message(message: cl.Message) -> None:
         ).send()
         # Re-raise error to have it in the logs
         raise err
+
+
+def _get_retrieval_metadata(chunks_with_scores: Sequence[ChunkWithScore]) -> dict:
+    return {
+        "chunks": [
+            {
+                "document.name": chunk_with_score.chunk.document.name,
+                "chunk.id": chunk_with_score.chunk.id,
+                "score": chunk_with_score.score,
+            }
+            for chunk_with_score in chunks_with_scores
+        ]
+    }
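For reference, a self-contained sketch of what the new helper produces, using stand-in dataclasses for the real models in src.db.models.document (the minimal field sets and the int chunk id are assumptions for illustration):

```python
from dataclasses import dataclass

# Stand-ins for the real models in src.db.models.document; field sets
# and types here are assumptions, kept minimal for illustration.
@dataclass
class Document:
    name: str

@dataclass
class Chunk:
    id: int
    document: Document

@dataclass
class ChunkWithScore:
    chunk: Chunk
    score: float

def _get_retrieval_metadata(chunks_with_scores) -> dict:
    # Mirrors the helper added above: one flat dict per retrieved chunk.
    return {
        "chunks": [
            {
                "document.name": cws.chunk.document.name,
                "chunk.id": cws.chunk.id,
                "score": cws.score,
            }
            for cws in chunks_with_scores
        ]
    }

doc = Document(name="Doc A")
print(_get_retrieval_metadata([ChunkWithScore(Chunk(1, doc), 0.99)]))
# {'chunks': [{'document.name': 'Doc A', 'chunk.id': 1, 'score': 0.99}]}
```

The flat "document.name"/"chunk.id" keys keep each chunk entry self-describing, which presumably makes the entries easier to filter in downstream log tooling.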
10 changes: 10 additions & 0 deletions app/tests/conftest.py
@@ -9,6 +9,7 @@
 import tests.src.db.models.factories as factories
 from src.app_config import AppConfig
 from src.db import models
+from src.db.models.document import ChunkWithScore
 from src.util.local import load_local_env_vars
 from tests.lib import db_testing
 from tests.mock.mock_sentence_transformer import MockSentenceTransformer
@@ -156,3 +157,12 @@ def mock_s3_bucket_resource(mock_s3):
 @pytest.fixture
 def mock_s3_bucket(mock_s3_bucket_resource):
     yield mock_s3_bucket_resource.name
+
+
+@pytest.fixture
+def chunks_with_scores():
+    return [
+        ChunkWithScore(factories.ChunkFactory.build(), 0.99),
+        ChunkWithScore(factories.ChunkFactory.build(), 0.90),
+        ChunkWithScore(factories.ChunkFactory.build(), 0.85),
+    ]
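Since the fixture now lives in conftest.py, any test under app/tests can request it by parameter name. A hypothetical usage example (this test is not part of the commit):

```python
# Hypothetical test, for illustration only: pytest injects the
# chunks_with_scores fixture from conftest.py by parameter name.
def test_chunks_with_scores_are_descending(chunks_with_scores):
    scores = [cws.score for cws in chunks_with_scores]
    assert scores == [0.99, 0.90, 0.85]  # matches the fixture above
```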
23 changes: 23 additions & 0 deletions app/tests/src/test_chainlit.py
@@ -1,4 +1,5 @@
 from src import chainlit, chat_engine
+from src.chainlit import _get_retrieval_metadata
 
 
 def test_url_query_values(monkeypatch):
@@ -16,3 +17,25 @@ def test_url_query_values(monkeypatch):
     # Only 1 query parameter remains
     assert len(query_values) == 1
     assert query_values["someunknownparam"] == "42"
+
+
+def test__get_retrieval_metadata(chunks_with_scores):
+    assert _get_retrieval_metadata(chunks_with_scores) == {
+        "chunks": [
+            {
+                "document.name": chunks_with_scores[0].chunk.document.name,
+                "chunk.id": chunks_with_scores[0].chunk.id,
+                "score": chunks_with_scores[0].score,
+            },
+            {
+                "document.name": chunks_with_scores[1].chunk.document.name,
+                "chunk.id": chunks_with_scores[1].chunk.id,
+                "score": chunks_with_scores[1].score,
+            },
+            {
+                "document.name": chunks_with_scores[2].chunk.document.name,
+                "chunk.id": chunks_with_scores[2].chunk.id,
+                "score": chunks_with_scores[2].score,
+            },
+        ]
+    }
16 changes: 4 additions & 12 deletions app/tests/src/test_format.py
@@ -42,26 +42,18 @@ def test_format_guru_cards_with_score(monkeypatch, app_config, db_session, enabl
     assert len(_unique_accordion_ids(html + next_html)) == 2 * len(chunks_with_scores)
 
 
-def _chunks_with_scores():
-    return [
-        ChunkWithScore(ChunkFactory.build(), 0.99),
-        ChunkWithScore(ChunkFactory.build(), 0.90),
-        ChunkWithScore(ChunkFactory.build(), 0.85),
-    ]
-
-
-def test_format_guru_cards_given_chunks_shown_max_num():
+def test_format_guru_cards_given_chunks_shown_max_num(chunks_with_scores):
     html = format_guru_cards(
-        chunks_shown_max_num=2, chunks_shown_min_score=0.8, chunks_with_scores=_chunks_with_scores()
+        chunks_shown_max_num=2, chunks_shown_min_score=0.8, chunks_with_scores=chunks_with_scores
     )
     assert len(_unique_accordion_ids(html)) == 2
 
 
-def test_format_guru_cards_given_chunks_shown_max_num_and_min_score():
+def test_format_guru_cards_given_chunks_shown_max_num_and_min_score(chunks_with_scores):
     html = format_guru_cards(
         chunks_shown_max_num=2,
         chunks_shown_min_score=0.91,
-        chunks_with_scores=_chunks_with_scores(),
+        chunks_with_scores=chunks_with_scores,
     )
     assert len(_unique_accordion_ids(html)) == 1
16 changes: 7 additions & 9 deletions app/tests/src/test_generate.py
@@ -2,10 +2,8 @@
 
 import ollama
 
-from src.db.models.document import ChunkWithScore
 from src.generate import PROMPT, generate, get_models
 from tests.mock import mock_completion
-from tests.src.db.models.factories import ChunkFactory
 
 
 def ollama_model_list():
@@ -113,18 +111,18 @@ def test_generate(monkeypatch):
     assert generate("gpt-4o", "some query") == expected_response
 
 
-def test_generate_with_context_with_score(monkeypatch):
+def test_generate_with_context_with_score(monkeypatch, chunks_with_scores):
     monkeypatch.setattr("src.generate.completion", mock_completion.mock_completion)
-    context = [
-        ChunkWithScore(ChunkFactory.build(), 0.2000),
-        ChunkWithScore(ChunkFactory.build(), -0.3000),
-    ]
-    context_text = f"{context[0].chunk.document.name}\n{context[0].chunk.content}\n\n{context[1].chunk.document.name}\n{context[1].chunk.content}"
+    context_text = (
+        f"{chunks_with_scores[0].chunk.document.name}\n{chunks_with_scores[0].chunk.content}\n\n"
+        + f"{chunks_with_scores[1].chunk.document.name}\n{chunks_with_scores[1].chunk.content}\n\n"
+        + f"{chunks_with_scores[2].chunk.document.name}\n{chunks_with_scores[2].chunk.content}"
+    )
     expected_response = (
         'Called gpt-4o with [{"content": "'
         + PROMPT
         + '", "role": "system"}, {"content": "Use the following context to answer the question: '
         + context_text
         + '", "role": "system"}, {"content": "some query", "role": "user"}]'
     )
-    assert generate("gpt-4o", "some query", context=context) == expected_response
+    assert generate("gpt-4o", "some query", context=chunks_with_scores) == expected_response
