From 025092df145f5c5a06f65b1842c180cd8b12eb98 Mon Sep 17 00:00:00 2001
From: fg-nava <189638926+fg-nava@users.noreply.github.com>
Date: Fri, 10 Jan 2025 10:28:28 -0800
Subject: [PATCH] fix: Ensure UI feedback message for batch processing completion (#170)

---
 app/src/batch_process.py            | 27 ++++++++++++++++++++++++++-
 app/src/chainlit.py                 | 12 ++++--------
 app/tests/src/test_batch_process.py | 20 +++++++++++++++++++-
 3 files changed, 49 insertions(+), 10 deletions(-)

diff --git a/app/src/batch_process.py b/app/src/batch_process.py
index 9853afeb..41882d6b 100644
--- a/app/src/batch_process.py
+++ b/app/src/batch_process.py
@@ -1,24 +1,46 @@
 import csv
+import logging
 import tempfile
 
+import chainlit as cl
 from src.chat_engine import ChatEngineInterface
 from src.citations import simplify_citation_numbers
 
+logger = logging.getLogger(__name__)
+
 
 async def batch_process(file_path: str, engine: ChatEngineInterface) -> str:
+    logger.info("Starting batch processing of file: %r", file_path)
     with open(file_path, mode="r", newline="", encoding="utf-8") as csvfile:
         reader = csv.DictReader(csvfile)
 
         if not reader.fieldnames or "question" not in reader.fieldnames:
+            logger.error("Invalid CSV format: missing 'question' column in %r", file_path)
             raise ValueError("CSV file must contain a 'question' column.")
 
         rows = list(reader)  # Convert reader to list to preserve order
         questions = [row["question"] for row in rows]
+        total_questions = len(questions)
+        logger.info("Found %d questions to process", total_questions)
 
         # Process questions sequentially to avoid thread-safety issues with LiteLLM
         # Previous parallel implementation caused high CPU usage due to potential thread-safety
         # concerns in the underlying LLM client libraries
-        processed_data = [_process_question(q, engine) for q in questions]
+        processed_data = []
+
+        progress_msg = cl.Message(content="Received file, starting batch processing...")
+        await progress_msg.send()
+
+        for i, q in enumerate(questions, 1):
+            # Update progress message
+            progress_msg.content = f"Processing question {i} of {total_questions}..."
+            await progress_msg.update()
+            logger.info("Processing question %d/%d", i, total_questions)
+
+            processed_data.append(_process_question(q, engine))
+
+        # Clean up progress message
+        await progress_msg.remove()
 
     # Update rows with processed data while preserving original order
     for row, data in zip(rows, processed_data, strict=True):
@@ -37,10 +59,12 @@ async def batch_process(file_path: str, engine: ChatEngineInterface) -> str:
         writer.writerows(rows)
         result_file.close()
 
+    logger.info("Batch processing complete. Results written to: %r", result_file.name)
     return result_file.name
 
 
 def _process_question(question: str, engine: ChatEngineInterface) -> dict[str, str | None]:
+    logger.debug("Processing question: %r", question)
     result = engine.on_message(question=question, chat_history=[])
     final_result = simplify_citation_numbers(result)
 
@@ -58,4 +82,5 @@ def _process_question(question: str, engine: ChatEngineInterface) -> dict[str, s
             citation_key + "_text": subsection.text,
         }
 
+    logger.debug("Question processed with %d citations", len(final_result.subsections))
     return result_table
diff --git a/app/src/chainlit.py b/app/src/chainlit.py
index dda635ea..c5b2a45d 100644
--- a/app/src/chainlit.py
+++ b/app/src/chainlit.py
@@ -213,6 +213,7 @@ async def on_message(message: cl.Message) -> None:
             metadata=_get_retrieval_metadata(result),
         ).send()
     except Exception as err:  # pylint: disable=broad-exception-caught
+        logger.exception("Error processing message: %r", message.content)
         await cl.Message(
             author="backend",
             metadata={"error_class": err.__class__.__name__, "error": str(err)},
@@ -248,26 +249,21 @@ def _get_retrieval_metadata(result: OnMessageResult) -> dict:
     }
 
 
 async def _batch_proccessing(file: AskFileResponse) -> None:
-    await cl.Message(
-        author="backend",
-        content="Received file, processing...",
-    ).send()
-
     try:
         engine: chat_engine.ChatEngineInterface = cl.user_session.get("chat_engine")
         result_file_path = await batch_process(file.path, engine)
 
         # E.g., "abcd.csv" to "abcd_results.csv"
         result_file_name = file.name.removesuffix(".csv") + "_results.csv"
-
         await cl.Message(
+            author="backend",
             content="File processed, results attached.",
             elements=[cl.File(name=result_file_name, path=result_file_path)],
         ).send()
     except ValueError as err:
+        logger.error("Error processing file %r: %s", file.name, err)
         await cl.Message(
             author="backend",
-            metadata={"error_class": err.__class__.__name__, "error": str(err)},
-            content=f"{err.__class__.__name__}: {err}",
+            content=f"Error processing file: {err}",
         ).send()
diff --git a/app/tests/src/test_batch_process.py b/app/tests/src/test_batch_process.py
index ae27fc08..19119cdb 100644
--- a/app/tests/src/test_batch_process.py
+++ b/app/tests/src/test_batch_process.py
@@ -1,3 +1,5 @@
+from unittest.mock import AsyncMock, MagicMock
+
 import pytest
 
 from src import chat_engine
@@ -30,6 +32,22 @@ def invalid_csv(tmp_path):
     return str(csv_path)
 
 
+@pytest.fixture
+def mock_chainlit_message(monkeypatch):
+    mock_message = MagicMock()
+    mock_message.send = AsyncMock()
+    mock_message.update = AsyncMock()
+    mock_message.remove = AsyncMock()
+
+    class MockMessage:
+        def __init__(self, content):
+            self.content = content
+            for attr, value in mock_message.__dict__.items():
+                setattr(self, attr, value)
+
+    monkeypatch.setattr("chainlit.Message", MockMessage)
+
+
 @pytest.mark.asyncio
 async def test_batch_process_invalid(invalid_csv, engine):
     engine = chat_engine.create_engine("ca-edd-web")
@@ -38,7 +56,7 @@ async def test_batch_process_invalid(invalid_csv, engine):
 
 
 @pytest.mark.asyncio
-async def test_batch_process(monkeypatch, sample_csv, engine):
+async def test_batch_process(monkeypatch, sample_csv, engine, mock_chainlit_message):
     def mock__process_question(question, engine):
         if question == "What is AI?":
             return {"answer": "Answer to What is AI?", "field_2": "value_2"}
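
Illustration (not part of the commit): the progress loop added above leans on
the send/update/remove lifecycle of chainlit.Message, which is exactly the
surface the mock_chainlit_message fixture stubs out. Below is a minimal,
self-contained sketch of that interaction, runnable without Chainlit; it is a
simplified analogue of the fixture (fresh mocks per instance rather than
shared ones), and the demo() function and its three-question loop are
hypothetical stand-ins for batch_process driving a parsed CSV:

    import asyncio
    from unittest.mock import AsyncMock

    class MockMessage:
        # Simplified analogue of the test fixture: once
        # monkeypatch.setattr("chainlit.Message", MockMessage) runs,
        # cl.Message(...) inside batch_process constructs these instead.
        def __init__(self, content):
            self.content = content
            self.send = AsyncMock()
            self.update = AsyncMock()
            self.remove = AsyncMock()

    async def demo():
        msg = MockMessage(content="Received file, starting batch processing...")
        await msg.send()
        for i in range(1, 4):  # hypothetical stand-in for 3 CSV questions
            msg.content = f"Processing question {i} of 3..."
            await msg.update()
        await msg.remove()  # progress message is cleaned up on completion
        assert msg.update.await_count == 3  # one UI update per question

    asyncio.run(demo())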