From ce5bc9ad3e3c1b73e2fe3f3555d77403588ed155 Mon Sep 17 00:00:00 2001 From: Yoom Lam Date: Tue, 6 Aug 2024 08:09:02 -0500 Subject: [PATCH] feat: Use URL query params to set initial chat settings (#39) --- app/src/chainlit.py | 64 +++++++++++++++++++++------------- app/src/chat_engine.py | 1 + app/tests/src/test_chainlit.py | 18 ++++++++++ 3 files changed, 59 insertions(+), 24 deletions(-) create mode 100644 app/tests/src/test_chainlit.py diff --git a/app/src/chainlit.py b/app/src/chainlit.py index e4369a9d..39989b0d 100644 --- a/app/src/chainlit.py +++ b/app/src/chainlit.py @@ -19,7 +19,11 @@ @cl.on_chat_start async def start() -> None: - engine_id = engine_url_query_value() + url = cl.user_session.get("http_referer") + logger.debug("Referer URL: %s", url) + query_values = url_query_values(url) + + engine_id = query_values.pop("engine", app_config.chat_engine) logger.info("Engine ID: %s", engine_id) engine = _init_chat_engine(engine_id) @@ -31,7 +35,17 @@ async def start() -> None: ).send() return - settings = await _init_chat_settings(engine) + input_widgets = _init_chat_settings(engine, query_values) + settings = await cl.ChatSettings(input_widgets).send() + logger.info("Initialized settings: %s", pprint.pformat(settings, indent=4)) + if query_values: + logger.warning("Unused URL query parameters: %r", query_values) + await cl.Message( + author="backend", + metadata={"url": url}, + content=f"Unused URL query parameters: {query_values}", + ).send() + await cl.Message( author="backend", metadata={"engine": engine_id, "settings": str(settings)}, @@ -39,6 +53,15 @@ async def start() -> None: ).send() +def url_query_values(url: str) -> dict[str, str]: + # Using this suggestion: https://github.com/Chainlit/chainlit/issues/144#issuecomment-2227543547 + parsed_url = urlparse(url) + # For a given query key, only the first value is used + query_values = {key: values[0] for key, values in parse_qs(parsed_url.query).items()} + logger.info("URL query values: %r", query_values) + return query_values + + def _init_chat_engine(engine_id: str) -> ChatEngineInterface | None: engine = chat_engine.create_engine(engine_id) if engine: @@ -47,16 +70,19 @@ def _init_chat_engine(engine_id: str) -> ChatEngineInterface | None: return None -async def _init_chat_settings(engine: ChatEngineInterface) -> dict[str, Any]: +def _init_chat_settings( + engine: ChatEngineInterface, query_values: dict[str, str] +) -> list[InputWidget]: input_widgets: list[InputWidget] = [ - _WIDGET_FACTORIES[setting_name](getattr(engine, setting_name)) + _WIDGET_FACTORIES[setting_name]( + query_values.pop(setting_name, None) + or getattr(app_config, setting_name, None) + or getattr(engine, setting_name) + ) for setting_name in engine.user_settings if setting_name in _WIDGET_FACTORIES ] - input_widgets.append(_WIDGET_FACTORIES["llm"](app_config.llm or getattr(engine, "llm", None))) - settings = await cl.ChatSettings(input_widgets).send() - logger.info("Initialized settings: %s", pprint.pformat(settings, indent=4)) - return settings + return input_widgets @cl.on_settings_update @@ -83,26 +109,26 @@ def update_settings(settings: dict[str, Any]) -> Any: max=10, step=1, ), - "retrieval_k_min_score": lambda default_value: Slider( + "retrieval_k_min_score": lambda initial_value: Slider( id="retrieval_k_min_score", label="Minimum document score required for generating LLM response", - initial=default_value, + initial=initial_value, min=-1, max=1, step=0.25, ), - "docs_shown_max_num": lambda default_value: Slider( + "docs_shown_max_num": lambda initial_value: Slider( id="docs_shown_max_num", label="Maximum number of retrieved documents to show in the UI", - initial=default_value, + initial=initial_value, min=0, max=10, step=1, ), - "docs_shown_min_score": lambda default_value: Slider( + "docs_shown_min_score": lambda initial_value: Slider( id="docs_shown_min_score", label="Minimum document score required to show document in the UI", - initial=default_value, + initial=initial_value, min=-1, max=1, step=0.25, @@ -110,16 +136,6 @@ def update_settings(settings: dict[str, Any]) -> Any: } -def engine_url_query_value() -> str: - url = cl.user_session.get("http_referer") - logger.debug("Referer URL: %s", url) - - # Using this suggestion: https://github.com/Chainlit/chainlit/issues/144#issuecomment-2227543547 - parsed_url = urlparse(url) - qs = parse_qs(parsed_url.query) - return qs.get("engine", [app_config.chat_engine])[0] - - @cl.on_message async def on_message(message: cl.Message) -> None: logger.info("Received: %r", message.content) diff --git a/app/src/chat_engine.py b/app/src/chat_engine.py index c4daaaf4..cf5e61da 100644 --- a/app/src/chat_engine.py +++ b/app/src/chat_engine.py @@ -67,6 +67,7 @@ class GuruBaseEngine(ChatEngineInterface): retrieval_k_min_score: float = 0.45 user_settings = [ + "llm", "retrieval_k", "retrieval_k_min_score", "docs_shown_max_num", diff --git a/app/tests/src/test_chainlit.py b/app/tests/src/test_chainlit.py new file mode 100644 index 00000000..bc2a4bbb --- /dev/null +++ b/app/tests/src/test_chainlit.py @@ -0,0 +1,18 @@ +from src import chainlit, chat_engine + + +def test_url_query_values(monkeypatch): + monkeypatch.setenv("OPENAI_API_KEY", "mock_key") + + url = "https://example.com/chat/?engine=guru-snap&llm=gpt-4o&retrieval_k=3&someunknownparam=42" + query_values = chainlit.url_query_values(url) + engine_id = query_values.pop("engine") + assert engine_id == "guru-snap" + + engine = chat_engine.create_engine(engine_id) + input_widgets = chainlit._init_chat_settings(engine, query_values) + assert len(input_widgets) == len(engine.user_settings) + + # Only 1 query parameter remains + assert len(query_values) == 1 + assert query_values["someunknownparam"] == "42"