Skip to content

Commit

Permalink
feat: Enable BEM Chatbot (#41)
Browse files Browse the repository at this point in the history
  • Loading branch information
KevinJBoyer authored Aug 7, 2024
1 parent 8531fc8 commit c61ec0a
Show file tree
Hide file tree
Showing 6 changed files with 106 additions and 23 deletions.
5 changes: 3 additions & 2 deletions app/src/chainlit.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from src import chat_engine
from src.app_config import app_config
from src.chat_engine import ChatEngineInterface
from src.format import format_guru_cards
from src.generate import get_models
from src.login import require_login

Expand Down Expand Up @@ -143,11 +142,13 @@ async def on_message(message: cl.Message) -> None:
engine: chat_engine.ChatEngineInterface = cl.user_session.get("chat_engine")
try:
result = await cl.make_async(lambda: engine.on_message(question=message.content))()
msg_content = result.response + format_guru_cards(

msg_content = result.response + engine.formatter(
docs_shown_max_num=engine.docs_shown_max_num,
docs_shown_min_score=engine.docs_shown_min_score,
chunks_with_scores=result.chunks_with_scores,
)

chunk_titles_and_scores: dict[str, float] = {}
for chunk_with_score in result.chunks_with_scores:
title = chunk_with_score.chunk.document.name
Expand Down
28 changes: 15 additions & 13 deletions app/src/chat_engine.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import logging
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Sequence
from typing import Callable, Sequence

from src.db.models.document import ChunkWithScore
from src.format import format_bem_documents, format_guru_cards
from src.generate import generate
from src.retrieve import retrieve_with_scores
from src.util.class_utils import all_subclasses
Expand All @@ -21,6 +22,9 @@ class ChatEngineInterface(ABC):
engine_id: str
name: str

# Function for formatting responses
formatter: Callable

# Thresholds that determine which retrieved documents are shown in the UI
docs_shown_max_num: int = 5
docs_shown_min_score: float = 0.65
Expand Down Expand Up @@ -58,7 +62,7 @@ def create_engine(engine_id: str) -> ChatEngineInterface | None:


# Subclasses of ChatEngineInterface can be extracted into a separate file if it gets too large
class GuruBaseEngine(ChatEngineInterface):
class BaseEngine(ChatEngineInterface):
datasets: list[str] = []
llm: str = "gpt-4o"

Expand Down Expand Up @@ -86,24 +90,22 @@ def on_message(self, question: str) -> OnMessageResult:
return OnMessageResult(response, chunks_with_scores)


class GuruMultiprogramEngine(GuruBaseEngine):
class GuruMultiprogramEngine(BaseEngine):
    """Chat engine that retrieves from the multi-program Guru card dataset."""

    engine_id: str = "guru-multiprogram"
    name: str = "Guru Multi-program Chat Engine"
    datasets = ["guru-multiprogram"]
    # staticmethod so the formatter is stored as a plain callable and
    # is not bound as an instance method when accessed via an instance
    formatter = staticmethod(format_guru_cards)


class GuruSnapEngine(GuruBaseEngine):
class GuruSnapEngine(BaseEngine):
    """Chat engine that retrieves from the SNAP-specific Guru card dataset."""

    engine_id: str = "guru-snap"
    name: str = "Guru SNAP Chat Engine"
    datasets = ["guru-snap"]
    # staticmethod so the formatter is stored as a plain callable and
    # is not bound as an instance method when accessed via an instance
    formatter = staticmethod(format_guru_cards)


class PolicyMichiganEngine(ChatEngineInterface):
engine_id: str = "policy-mi"
name: str = "Michigan Bridges Policy Manual Chat Engine"

def on_message(self, question: str) -> OnMessageResult:
logger.warning("TODO: Retrieve from MI Policy Manual")
chunks: Sequence[ChunkWithScore] = []
response = "TEMP: Replace with generated response once chunks are correct"
return OnMessageResult(response, chunks)
class BridgesEligibilityManualEngine(BaseEngine):
    """Chat engine that retrieves from the Michigan Bridges Eligibility Manual (BEM) dataset."""

    engine_id: str = "bridges-eligibility-manual"
    name: str = "Michigan Bridges Eligibility Manual Chat Engine"
    datasets = ["bridges-eligibility-manual"]
    # staticmethod so the formatter is stored as a plain callable and
    # is not bound as an instance method when accessed via an instance
    formatter = staticmethod(format_bem_documents)
7 changes: 7 additions & 0 deletions app/src/db/models/document.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,3 +46,10 @@ class Chunk(Base, IdMixin, TimestampMixin):
class ChunkWithScore:
chunk: Chunk
score: float


@dataclass
class DocumentWithMaxScore:
    """A document paired with the highest similarity score among its chunks.

    Used to deduplicate and rank documents for display in the UI.
    """

    document: Document
    # The maximum similarity score of all Chunks associated with that document
    max_score: float
41 changes: 40 additions & 1 deletion app/src/format.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import random
from typing import Sequence

from src.db.models.document import ChunkWithScore
from src.db.models.document import ChunkWithScore, Document, DocumentWithMaxScore

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -49,3 +49,42 @@ def format_guru_cards(
</div>
</div>"""
return "<h3>Related Guru cards</h3>" + cards_html


def _get_bem_documents_to_show(
    docs_shown_max_num: int,
    docs_shown_min_score: float,
    chunks_with_scores: Sequence[ChunkWithScore],
) -> Sequence[Document]:
    """Deduplicate the documents behind the scored chunks and return the top ones.

    A document's score is the best score among its chunks that clear
    docs_shown_min_score; documents are returned highest score first,
    capped at docs_shown_max_num.
    """
    scored_docs: list[DocumentWithMaxScore] = []
    for candidate in chunks_with_scores:
        if candidate.score < docs_shown_min_score:
            continue
        doc = candidate.chunk.document
        # Linear scan is fine: the candidate pool is a small retrieval top-k.
        match = None
        for entry in scored_docs:
            if entry.document == doc:
                match = entry
                break
        if match is None:
            scored_docs.append(DocumentWithMaxScore(doc, candidate.score))
        elif candidate.score > match.max_score:
            match.max_score = candidate.score

    # Best-scoring documents first.
    scored_docs.sort(key=lambda entry: entry.max_score, reverse=True)

    return [entry.document for entry in scored_docs[:docs_shown_max_num]]


def format_bem_documents(
    docs_shown_max_num: int,
    docs_shown_min_score: float,
    chunks_with_scores: Sequence[ChunkWithScore],
) -> str:
    """Render the BEM source documents for a response as an HTML list.

    Shows at most docs_shown_max_num deduplicated documents whose best
    chunk score meets docs_shown_min_score.
    """
    docs_to_show = _get_bem_documents_to_show(
        docs_shown_max_num, docs_shown_min_score, chunks_with_scores
    )
    items = "".join(f"<li>{doc.name}</li>" for doc in docs_to_show)
    return f"<h3>Source(s)</h3><ul>{items}</ul>"
11 changes: 9 additions & 2 deletions app/tests/src/test_chat_engine.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from src import chat_engine
from src.chat_engine import GuruMultiprogramEngine, GuruSnapEngine
from src.chat_engine import BridgesEligibilityManualEngine, GuruMultiprogramEngine, GuruSnapEngine


def test_available_engines():
Expand All @@ -8,7 +8,7 @@ def test_available_engines():
assert len(engines) > 0
assert "guru-multiprogram" in engines
assert "guru-snap" in engines
assert "policy-mi" in engines
assert "bridges-eligibility-manual" in engines


def test_create_engine_Guru_Multiprogram():
Expand All @@ -23,3 +23,10 @@ def test_create_engine_Guru_SNAP():
engine = chat_engine.create_engine(engine_id)
assert engine is not None
assert engine.name == GuruSnapEngine.name


def test_create_engine_BridgesEligibilityManualEngine():
    """create_engine resolves the BEM engine id to the right engine class."""
    engine = chat_engine.create_engine("bridges-eligibility-manual")
    assert engine is not None
    assert engine.name == BridgesEligibilityManualEngine.name
37 changes: 32 additions & 5 deletions app/tests/src/test_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@

from sqlalchemy import delete

from src.db.models.document import Chunk, ChunkWithScore, Document
from src.format import format_guru_cards
from src.db.models.document import ChunkWithScore, Document
from src.format import format_bem_documents, format_guru_cards
from src.retrieve import retrieve_with_scores
from tests.src.db.models.factories import ChunkFactory, DocumentFactory
from tests.src.test_retrieve import _create_chunks


Expand Down Expand Up @@ -43,9 +44,9 @@ def test_format_guru_cards_with_score(monkeypatch, app_config, db_session, enabl

def _chunks_with_scores():
    """Build three factory-made chunks with descending similarity scores."""
    return [ChunkWithScore(ChunkFactory.build(), score) for score in (0.99, 0.90, 0.85)]


Expand All @@ -61,3 +62,29 @@ def test_format_guru_cards_given_docs_shown_max_num_and_min_score():
docs_shown_max_num=2, docs_shown_min_score=0.91, chunks_with_scores=_chunks_with_scores()
)
assert len(_unique_accordion_ids(html)) == 1


def test_format_bem_documents():
    """Documents are filtered by min score, deduplicated, sorted by best
    chunk score, and capped at docs_shown_max_num."""
    docs = DocumentFactory.build_batch(4)

    def scored_chunk(doc, score):
        return ChunkWithScore(ChunkFactory.build(document=doc), score)

    chunks_with_scores = [
        # docs[0] is dropped: its only chunk scores below docs_shown_min_score
        scored_chunk(docs[0], 0.90),
        # docs[1] clears the threshold but is squeezed out by docs_shown_max_num=2,
        # having the lowest best-score of the three qualifying documents
        scored_chunk(docs[1], 0.92),
        # docs[2] is shown: one of its chunks clears the threshold
        scored_chunk(docs[2], 0.90),
        scored_chunk(docs[2], 0.93),
        # docs[3] is shown exactly once (deduplicated) and listed first,
        # since it has the highest score
        scored_chunk(docs[3], 0.94),
        scored_chunk(docs[3], 0.95),
    ]

    html = format_bem_documents(
        docs_shown_max_num=2, docs_shown_min_score=0.91, chunks_with_scores=chunks_with_scores
    )
    assert html == f"<h3>Source(s)</h3><ul><li>{docs[3].name}</li><li>{docs[2].name}</li></ul>"

0 comments on commit c61ec0a

Please sign in to comment.