Commit

fix: Ensure UI feedback message for batch processing completion (#170)

fg-nava authored Jan 10, 2025
1 parent 91e3d27 commit 025092d
Showing 3 changed files with 49 additions and 10 deletions.
app/src/batch_process.py (26 additions, 1 deletion)
@@ -1,24 +1,46 @@
 import csv
+import logging
 import tempfile
 
 import chainlit as cl
 from src.chat_engine import ChatEngineInterface
 from src.citations import simplify_citation_numbers
 
+logger = logging.getLogger(__name__)
+
 
 async def batch_process(file_path: str, engine: ChatEngineInterface) -> str:
+    logger.info("Starting batch processing of file: %r", file_path)
     with open(file_path, mode="r", newline="", encoding="utf-8") as csvfile:
         reader = csv.DictReader(csvfile)
 
         if not reader.fieldnames or "question" not in reader.fieldnames:
+            logger.error("Invalid CSV format: missing 'question' column in %r", file_path)
             raise ValueError("CSV file must contain a 'question' column.")
 
         rows = list(reader)  # Convert reader to list to preserve order
     questions = [row["question"] for row in rows]
+    total_questions = len(questions)
+    logger.info("Found %d questions to process", total_questions)
 
     # Process questions sequentially to avoid thread-safety issues with LiteLLM
     # Previous parallel implementation caused high CPU usage due to potential thread-safety
     # concerns in the underlying LLM client libraries
-    processed_data = [_process_question(q, engine) for q in questions]
+    processed_data = []
+
+    progress_msg = cl.Message(content="Received file, starting batch processing...")
+    await progress_msg.send()
+
+    for i, q in enumerate(questions, 1):
+        # Update progress message
+        progress_msg.content = f"Processing question {i} of {total_questions}..."
+        await progress_msg.update()
+        logger.info("Processing question %d/%d", i, total_questions)
+
+        processed_data.append(_process_question(q, engine))
+
+    # Clean up progress message
+    await progress_msg.remove()
 
     # Update rows with processed data while preserving original order
     for row, data in zip(rows, processed_data, strict=True):
@@ -37,10 +59,12 @@ async def batch_process(file_path: str, engine: ChatEngineInterface) -> str:
     writer.writerows(rows)
     result_file.close()
 
+    logger.info("Batch processing complete. Results written to: %r", result_file.name)
     return result_file.name
 
 
 def _process_question(question: str, engine: ChatEngineInterface) -> dict[str, str | None]:
+    logger.debug("Processing question: %r", question)
     result = engine.on_message(question=question, chat_history=[])
     final_result = simplify_citation_numbers(result)
 
@@ -58,4 +82,5 @@ def _process_question(question: str, engine: ChatEngineInterface) -> dict[str, str | None]:
             citation_key + "_text": subsection.text,
         }
 
+    logger.debug("Question processed with %d citations", len(final_result.subsections))
     return result_table
app/src/chainlit.py (4 additions, 8 deletions)
@@ -213,6 +213,7 @@ async def on_message(message: cl.Message) -> None:
             metadata=_get_retrieval_metadata(result),
         ).send()
     except Exception as err:  # pylint: disable=broad-exception-caught
+        logger.exception("Error processing message: %r", message.content)
         await cl.Message(
             author="backend",
             metadata={"error_class": err.__class__.__name__, "error": str(err)},
@@ -248,26 +249,21 @@ def _get_retrieval_metadata(result: OnMessageResult) -> dict:
 
 
 async def _batch_proccessing(file: AskFileResponse) -> None:
-    await cl.Message(
-        author="backend",
-        content="Received file, processing...",
-    ).send()
-
     try:
         engine: chat_engine.ChatEngineInterface = cl.user_session.get("chat_engine")
         result_file_path = await batch_process(file.path, engine)
 
         # E.g., "abcd.csv" to "abcd_results.csv"
         result_file_name = file.name.removesuffix(".csv") + "_results.csv"
 
         await cl.Message(
             author="backend",
             content="File processed, results attached.",
             elements=[cl.File(name=result_file_name, path=result_file_path)],
         ).send()
 
     except ValueError as err:
+        logger.error("Error processing file %r: %s", file.name, err)
         await cl.Message(
             author="backend",
             metadata={"error_class": err.__class__.__name__, "error": str(err)},
-            content=f"{err.__class__.__name__}: {err}",
+            content=f"Error processing file: {err}",
         ).send()
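Design note: the up-front "Received file, processing..." message is removed here because batch_process now creates and owns the progress message, so keeping both would leave a stale status line in the UI. On the filename derivation, str.removesuffix (Python 3.9+) only strips an exact trailing match, so non-CSV names pass through unchanged; a small illustration of the expression used above:

def result_name(name: str) -> str:
    # Mirrors the derivation above: "abcd.csv" -> "abcd_results.csv".
    return name.removesuffix(".csv") + "_results.csv"


assert result_name("abcd.csv") == "abcd_results.csv"
# A non-matching suffix is left intact rather than raising:
assert result_name("notes.txt") == "notes.txt_results.csv"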
app/tests/src/test_batch_process.py (19 additions, 1 deletion)
@@ -1,3 +1,5 @@
+from unittest.mock import AsyncMock, MagicMock
+
 import pytest
 
 from src import chat_engine
@@ -30,6 +32,22 @@ def invalid_csv(tmp_path):
     return str(csv_path)
 
 
+@pytest.fixture
+def mock_chainlit_message(monkeypatch):
+    mock_message = MagicMock()
+    mock_message.send = AsyncMock()
+    mock_message.update = AsyncMock()
+    mock_message.remove = AsyncMock()
+
+    class MockMessage:
+        def __init__(self, content):
+            self.content = content
+            for attr, value in mock_message.__dict__.items():
+                setattr(self, attr, value)
+
+    monkeypatch.setattr("chainlit.Message", MockMessage)
+
+
 @pytest.mark.asyncio
 async def test_batch_process_invalid(invalid_csv, engine):
     engine = chat_engine.create_engine("ca-edd-web")
@@ -38,7 +56,7 @@ async def test_batch_process_invalid(invalid_csv, engine):
 
 
 @pytest.mark.asyncio
-async def test_batch_process(monkeypatch, sample_csv, engine):
+async def test_batch_process(monkeypatch, sample_csv, engine, mock_chainlit_message):
     def mock__process_question(question, engine):
         if question == "What is AI?":
             return {"answer": "Answer to What is AI?", "field_2": "value_2"}
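The fixture's send, update, and remove must be AsyncMock rather than plain MagicMock attributes because batch_process awaits each of them; awaiting the return value of a plain MagicMock call raises TypeError. A self-contained illustration of that distinction, runnable without pytest:

import asyncio
from unittest.mock import AsyncMock, MagicMock


async def demo() -> None:
    fake = MagicMock()
    fake.send = AsyncMock()
    await fake.send()  # an AsyncMock call returns an awaitable
    fake.send.assert_awaited_once()

    try:
        await fake.update()  # a plain MagicMock call result is not awaitable
    except TypeError as err:
        print(f"TypeError as expected: {err}")


asyncio.run(demo())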
