From cae05e663dcfc6bd0b17466d1465051d9e5c7276 Mon Sep 17 00:00:00 2001 From: Kevin Boyer Date: Tue, 2 Jul 2024 13:40:43 -0400 Subject: [PATCH] feat: Require login (#17) * Initial implementation * Refine implementation * PR Feedback * Add comment to require_login --- app/local.env | 10 +++--- app/src/app_config.py | 13 +++++++ app/src/chainlit.py | 9 +++-- app/src/ingest_guru_cards.py | 4 +-- app/src/login.py | 18 ++++++++++ app/src/shared.py | 15 ++++++++ app/tests/src/test_login.py | 34 +++++++++++++++++++ .../env-config/environment-variables.tf | 23 +++++++------ 8 files changed, 102 insertions(+), 24 deletions(-) create mode 100644 app/src/login.py create mode 100644 app/src/shared.py create mode 100644 app/tests/src/test_login.py diff --git a/app/local.env b/app/local.env index 99cc9889..f4bd6e74 100644 --- a/app/local.env +++ b/app/local.env @@ -31,12 +31,6 @@ LOG_ENABLE_AUDIT=FALSE # Change the message length for the human readable formatter # LOG_HUMAN_READABLE_FORMATTER__MESSAGE_WIDTH=50 -############################ -# Authentication -############################ -# The auth token used by the local endpoints -API_AUTH_TOKEN=LOCAL_AUTH_12345678 - ############################ # DB Environment Variables ############################ @@ -72,4 +66,8 @@ AWS_SECRET_ACCESS_KEY=DO_NOT_SET_HERE AWS_DEFAULT_REGION=us-east-1 +########################### +# DST app configuration +########################### + EMBEDDING_MODEL=/app/models/multi-qa-mpnet-base-dot-v1 \ No newline at end of file diff --git a/app/src/app_config.py b/app/src/app_config.py index fb13ff33..5cbe8294 100644 --- a/app/src/app_config.py +++ b/app/src/app_config.py @@ -2,4 +2,17 @@ class AppConfig(PydanticBaseEnvConfig): + # These are default configuration values for the app, and + # are shared across both local and deployed environments. + + # These values are overridden by environment variables. + + # To override these values for local development, set them + # in .env (if they should be set just for you), or set + # them in local.env (if they should be committed to the repo.) + + # To customize these values in deployed environments, set + # them in infra/app/app-config/env-config/environment-variables.tf + embedding_model: str = "multi-qa-mpnet-base-dot-v1" + global_password: str | None = None diff --git a/app/src/chainlit.py b/app/src/chainlit.py index 2d5dffd5..fe197a7e 100644 --- a/app/src/chainlit.py +++ b/app/src/chainlit.py @@ -1,17 +1,16 @@ import logging -from sentence_transformers import SentenceTransformer - import chainlit as cl import src.adapters.db as db -from src.app_config import AppConfig from src.format import format_guru_cards from src.generate import generate +from src.login import require_login from src.retrieve import retrieve +from src.shared import get_embedding_model logger = logging.getLogger(__name__) -embedding_model = SentenceTransformer(AppConfig().embedding_model) +require_login() @cl.on_message @@ -21,7 +20,7 @@ async def main(message: cl.Message) -> None: with db.PostgresDBClient().get_session() as db_session: chunks = retrieve( db_session, - embedding_model, + get_embedding_model(), message.content, ) diff --git a/app/src/ingest_guru_cards.py b/app/src/ingest_guru_cards.py index e47f9492..2ce039d0 100644 --- a/app/src/ingest_guru_cards.py +++ b/app/src/ingest_guru_cards.py @@ -6,8 +6,8 @@ from smart_open import open import src.adapters.db as db -from src.app_config import AppConfig from src.db.models.document import Chunk, Document +from src.shared import get_embedding_model from src.util.html import get_text_from_html logger = logging.getLogger(__name__) @@ -59,7 +59,7 @@ def main() -> None: logger.info(f"Processing Guru cards at {guru_cards_filepath}") - embedding_model = SentenceTransformer(AppConfig().embedding_model) + embedding_model = get_embedding_model() with db.PostgresDBClient().get_session() as db_session: _ingest_cards(db_session, embedding_model, guru_cards_filepath) diff --git a/app/src/login.py b/app/src/login.py new file mode 100644 index 00000000..0695ef23 --- /dev/null +++ b/app/src/login.py @@ -0,0 +1,18 @@ +import chainlit as cl +from src.shared import get_app_config + + +def login_callback(username: str, password: str) -> cl.User | None: + if password == get_app_config().global_password: + return cl.User(identifier=username) + else: + return None + + +def require_login() -> None: + # In addition to setting GLOBAL_PASSWORD, Chainlit also + # requires CHAINLIT_AUTH_SECRET to be set (to sign the + # authorization tokens.) + + if get_app_config().global_password: + cl.password_auth_callback(login_callback) diff --git a/app/src/shared.py b/app/src/shared.py new file mode 100644 index 00000000..a22b3c9f --- /dev/null +++ b/app/src/shared.py @@ -0,0 +1,15 @@ +from functools import cache + +from sentence_transformers import SentenceTransformer + +from src.app_config import AppConfig + + +@cache +def get_app_config() -> AppConfig: + return AppConfig() + + +@cache +def get_embedding_model() -> SentenceTransformer: + return SentenceTransformer(get_app_config().embedding_model) diff --git a/app/tests/src/test_login.py b/app/tests/src/test_login.py new file mode 100644 index 00000000..799a31e1 --- /dev/null +++ b/app/tests/src/test_login.py @@ -0,0 +1,34 @@ +import os + +import chainlit.config +import src.shared +from src.login import login_callback, require_login + + +def test_require_login_no_password(monkeypatch): + if "GLOBAL_PASSWORD" in os.environ: + monkeypatch.delenv("GLOBAL_PASSWORD") + + # Rebuild AppConfig with new environment variables + src.shared.get_app_config.cache_clear() + + require_login() + + assert not chainlit.config.code.password_auth_callback + + +def test_require_login_with_password(monkeypatch): + monkeypatch.setenv("GLOBAL_PASSWORD", "password") + src.shared.get_app_config.cache_clear() + + require_login() + + assert chainlit.config.code.password_auth_callback + + +def test_login_callback(monkeypatch): + monkeypatch.setenv("GLOBAL_PASSWORD", "correct pass") + src.shared.get_app_config.cache_clear() + + assert login_callback("some user", "wrong pass") is None + assert login_callback("some user", "correct pass").identifier == "some user" diff --git a/infra/app/app-config/env-config/environment-variables.tf b/infra/app/app-config/env-config/environment-variables.tf index 9a9f1d0e..a8a73c25 100644 --- a/infra/app/app-config/env-config/environment-variables.tf +++ b/infra/app/app-config/env-config/environment-variables.tf @@ -19,11 +19,18 @@ locals { # } # } secrets = { - # Example generated secret - # RANDOM_SECRET = { - # manage_method = "generated" - # secret_store_name = "/${var.app_name}-${var.environment}/random-secret" - # } + + # This is used by Chainlit to sign authentication tokens + # See: https://docs.chainlit.io/authentication/overview + CHAINLIT_AUTH_SECRET = { + manage_method = "generated" + secret_store_name = "/${var.app_name}-${var.environment}/CHAINLIT_AUTH_SECRET" + } + + GLOBAL_PASSWORD = { + manage_method = "manual" + secret_store_name = "/${var.app_name}-${var.environment}/GLOBAL_PASSWORD" + } OPENAI_API_KEY = { manage_method = "manual" @@ -33,11 +40,5 @@ locals { manage_method = "manual" secret_store_name = "/${var.app_name}-${var.environment}/LITERAL_API_KEY" } - - # Example secret that references a manually created secret - # SECRET_SAUCE = { - # manage_method = "manual" - # secret_store_name = "/${var.app_name}-${var.environment}/secret-sauce" - # } } }