fix: Ensure UI feedback message for batch processing completion #170

app/src/batch_process.py — 27 changes: 26 additions & 1 deletion
@@ -1,24 +1,46 @@
 import csv
+import logging
 import tempfile
 
 import chainlit as cl
 from src.chat_engine import ChatEngineInterface
 from src.citations import simplify_citation_numbers
 
+logger = logging.getLogger(__name__)
+
+
 async def batch_process(file_path: str, engine: ChatEngineInterface) -> str:
+    logger.info("Starting batch processing of file: %r", file_path)
     with open(file_path, mode="r", newline="", encoding="utf-8") as csvfile:
         reader = csv.DictReader(csvfile)
 
         if not reader.fieldnames or "question" not in reader.fieldnames:
+            logger.error("Invalid CSV format: missing 'question' column in %r", file_path)
             raise ValueError("CSV file must contain a 'question' column.")
 
         rows = list(reader)  # Convert reader to list to preserve order
         questions = [row["question"] for row in rows]
+        total_questions = len(questions)
+        logger.info("Found %d questions to process", total_questions)
 
         # Process questions sequentially to avoid thread-safety issues with LiteLLM
         # Previous parallel implementation caused high CPU usage due to potential thread-safety
         # concerns in the underlying LLM client libraries
-        processed_data = [_process_question(q, engine) for q in questions]
+        processed_data = []
+
+        progress_msg = cl.Message(content="Received file, starting batch processing...")
+        await progress_msg.send()
+
+        for i, q in enumerate(questions, 1):
+            # Update progress message
+            progress_msg.content = f"Processing question {i} of {total_questions}..."
+            await progress_msg.update()
+            logger.info("Processing question %d/%d", i, total_questions)
+
+            processed_data.append(_process_question(q, engine))
+
+        # Clean up progress message
+        await progress_msg.remove()
 
         # Update rows with processed data while preserving original order
         for row, data in zip(rows, processed_data, strict=True):
@@ -37,10 +59,12 @@ async def batch_process(file_path: str, engine: ChatEngineInterface) -> str:
     writer.writerows(rows)
     result_file.close()
 
+    logger.info("Batch processing complete. Results written to: %r", result_file.name)
     return result_file.name


 def _process_question(question: str, engine: ChatEngineInterface) -> dict[str, str | None]:
+    logger.debug("Processing question: %r", question)
     result = engine.on_message(question=question, chat_history=[])
     final_result = simplify_citation_numbers(result)
 
@@ -58,4 +82,5 @@ def _process_question(question: str, engine: ChatEngineInterface) -> dict[str, str | None]:
             citation_key + "_text": subsection.text,
         }
 
+    logger.debug("Question processed with %d citations", len(final_result.subsections))
     return result_table
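
A quick way to exercise the new progress flow end to end, without a Chainlit session or a live LLM, is to reuse the same two stubs the tests lean on. The driver below is a hypothetical sketch, not part of this PR: StubMessage, the canned _process_question lambda, and the throwaway CSV are illustration-only stand-ins, and it assumes src.batch_process is importable as in this repo.

import asyncio
import csv
import tempfile

import chainlit as cl

from src import batch_process as bp


class StubMessage:
    """Stands in for cl.Message; send/update/remove become awaitable no-ops."""

    def __init__(self, content: str) -> None:
        self.content = content

    async def send(self) -> None: ...

    async def update(self) -> None: ...

    async def remove(self) -> None: ...


async def main() -> None:
    # Same idea as the test fixture: swap out chainlit.Message so the
    # progress calls in batch_process don't need a running Chainlit app.
    cl.Message = StubMessage
    # Canned answer, mirroring mock__process_question in the tests.
    bp._process_question = lambda question, engine: {"answer": f"stub answer to {question!r}"}

    # Write a throwaway CSV with the required "question" column.
    with tempfile.NamedTemporaryFile(mode="w", suffix=".csv", delete=False, newline="") as f:
        csv.writer(f).writerows([["question"], ["What is AI?"], ["What is EDD?"]])

    # engine is never touched once _process_question is stubbed out.
    result_path = await bp.batch_process(f.name, engine=None)
    print("results written to:", result_path)


if __name__ == "__main__":
    asyncio.run(main())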
app/src/chainlit.py — 12 changes: 4 additions & 8 deletions
@@ -213,6 +213,7 @@ async def on_message(message: cl.Message) -> None:
             metadata=_get_retrieval_metadata(result),
         ).send()
     except Exception as err:  # pylint: disable=broad-exception-caught
+        logger.exception("Error processing message: %r", message.content)
         await cl.Message(
             author="backend",
             metadata={"error_class": err.__class__.__name__, "error": str(err)},
@@ -248,26 +249,21 @@ def _get_retrieval_metadata(result: OnMessageResult) -> dict:


 async def _batch_proccessing(file: AskFileResponse) -> None:
-    await cl.Message(
-        author="backend",
-        content="Received file, processing...",
-    ).send()
-
     try:
         engine: chat_engine.ChatEngineInterface = cl.user_session.get("chat_engine")
         result_file_path = await batch_process(file.path, engine)
 
         # E.g., "abcd.csv" to "abcd_results.csv"
         result_file_name = file.name.removesuffix(".csv") + "_results.csv"
 
         await cl.Message(
             author="backend",
             content="File processed, results attached.",
             elements=[cl.File(name=result_file_name, path=result_file_path)],
         ).send()
 
     except ValueError as err:
+        logger.error("Error processing file %r: %s", file.name, err)
         await cl.Message(
             author="backend",
             metadata={"error_class": err.__class__.__name__, "error": str(err)},
-            content=f"{err.__class__.__name__}: {err}",
+            content=f"Error processing file: {err}",
         ).send()
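
One edge case worth noting in the naming scheme above: str.removesuffix is a no-op when the suffix is absent, so a non-.csv upload name simply gets "_results.csv" appended rather than raising. A quick illustration:

# Naming rule used by _batch_proccessing above:
assert "abcd.csv".removesuffix(".csv") + "_results.csv" == "abcd_results.csv"
# removesuffix is a no-op when ".csv" is absent, so odd upload names still
# produce a usable attachment name instead of raising:
assert "questions.txt".removesuffix(".csv") + "_results.csv" == "questions.txt_results.csv"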
app/tests/src/test_batch_process.py — 20 changes: 19 additions & 1 deletion
@@ -1,3 +1,5 @@
+from unittest.mock import AsyncMock, MagicMock
+
 import pytest
 
 from src import chat_engine
@@ -30,6 +32,22 @@ def invalid_csv(tmp_path):
     return str(csv_path)
 
 
+@pytest.fixture
+def mock_chainlit_message(monkeypatch):
+    mock_message = MagicMock()
+    mock_message.send = AsyncMock()
+    mock_message.update = AsyncMock()
+    mock_message.remove = AsyncMock()
+
+    class MockMessage:
+        def __init__(self, content):
+            self.content = content
+            for attr, value in mock_message.__dict__.items():
+                setattr(self, attr, value)
+
+    monkeypatch.setattr("chainlit.Message", MockMessage)
+
+
 @pytest.mark.asyncio
 async def test_batch_process_invalid(invalid_csv, engine):
     engine = chat_engine.create_engine("ca-edd-web")
@@ -38,7 +56,7 @@ async def test_batch_process_invalid(invalid_csv, engine):


 @pytest.mark.asyncio
-async def test_batch_process(monkeypatch, sample_csv, engine):
+async def test_batch_process(monkeypatch, sample_csv, engine, mock_chainlit_message):
     def mock__process_question(question, engine):
         if question == "What is AI?":
             return {"answer": "Answer to What is AI?", "field_2": "value_2"}
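
Because MockMessage copies the send/update/remove AsyncMocks out of the one shared MagicMock, every message created during a test reports into the same mocks. That leaves room for a follow-up test asserting the progress flow itself. The sketch below is hypothetical, not part of this PR: it assumes the fixture is extended to end with "return mock_message" (today it returns nothing) and that the module under test is importable as "from src import batch_process".

@pytest.mark.asyncio
async def test_batch_process_reports_progress(
    monkeypatch, sample_csv, engine, mock_chainlit_message
):
    # Stub out the per-question work, mirroring the existing test's approach.
    monkeypatch.setattr(
        batch_process, "_process_question", lambda question, engine: {"answer": "stub"}
    )
    await batch_process.batch_process(sample_csv, engine)

    # All MockMessage instances share the fixture's AsyncMocks, so the
    # progress lifecycle is visible here: one send, one remove, and an
    # update for every question in the CSV.
    mock_chainlit_message.send.assert_awaited_once()
    mock_chainlit_message.remove.assert_awaited_once()
    assert mock_chainlit_message.update.await_count >= 1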