Skip to content

Commit

Permalink
feat: Require login (#17)
Browse files Browse the repository at this point in the history
* Initial implementation

* Refine implementation

* PR Feedback

* Add comment to require_login
  • Loading branch information
KevinJBoyer authored Jul 2, 2024
1 parent 718763c commit cae05e6
Show file tree
Hide file tree
Showing 8 changed files with 102 additions and 24 deletions.
10 changes: 4 additions & 6 deletions app/local.env
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,6 @@ LOG_ENABLE_AUDIT=FALSE
# Change the message length for the human readable formatter
# LOG_HUMAN_READABLE_FORMATTER__MESSAGE_WIDTH=50

############################
# Authentication
############################
# The auth token used by the local endpoints
API_AUTH_TOKEN=LOCAL_AUTH_12345678

############################
# DB Environment Variables
############################
Expand Down Expand Up @@ -72,4 +66,8 @@ AWS_SECRET_ACCESS_KEY=DO_NOT_SET_HERE

AWS_DEFAULT_REGION=us-east-1

###########################
# DST app configuration
###########################

EMBEDDING_MODEL=/app/models/multi-qa-mpnet-base-dot-v1
13 changes: 13 additions & 0 deletions app/src/app_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,17 @@


class AppConfig(PydanticBaseEnvConfig):
# These are default configuration values for the app, and
# are shared across both local and deployed environments.

# These values are overridden by environment variables.

# To override these values for local development, set them
# in .env (if they should be set just for you), or set
# them in local.env (if they should be committed to the repo.)

# To customize these values in deployed environments, set
# them in infra/app/app-config/env-config/environment-variables.tf

embedding_model: str = "multi-qa-mpnet-base-dot-v1"
global_password: str | None = None
9 changes: 4 additions & 5 deletions app/src/chainlit.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,16 @@
import logging

from sentence_transformers import SentenceTransformer

import chainlit as cl
import src.adapters.db as db
from src.app_config import AppConfig
from src.format import format_guru_cards
from src.generate import generate
from src.login import require_login
from src.retrieve import retrieve
from src.shared import get_embedding_model

logger = logging.getLogger(__name__)

embedding_model = SentenceTransformer(AppConfig().embedding_model)
require_login()


@cl.on_message
Expand All @@ -21,7 +20,7 @@ async def main(message: cl.Message) -> None:
with db.PostgresDBClient().get_session() as db_session:
chunks = retrieve(
db_session,
embedding_model,
get_embedding_model(),
message.content,
)

Expand Down
4 changes: 2 additions & 2 deletions app/src/ingest_guru_cards.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
from smart_open import open

import src.adapters.db as db
from src.app_config import AppConfig
from src.db.models.document import Chunk, Document
from src.shared import get_embedding_model
from src.util.html import get_text_from_html

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -59,7 +59,7 @@ def main() -> None:

logger.info(f"Processing Guru cards at {guru_cards_filepath}")

embedding_model = SentenceTransformer(AppConfig().embedding_model)
embedding_model = get_embedding_model()

with db.PostgresDBClient().get_session() as db_session:
_ingest_cards(db_session, embedding_model, guru_cards_filepath)
Expand Down
18 changes: 18 additions & 0 deletions app/src/login.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import chainlit as cl
from src.shared import get_app_config


def login_callback(username: str, password: str) -> cl.User | None:
if password == get_app_config().global_password:
return cl.User(identifier=username)
else:
return None


def require_login() -> None:
# In addition to setting GLOBAL_PASSWORD, Chainlit also
# requires CHAINLIT_AUTH_SECRET to be set (to sign the
# authorization tokens.)

if get_app_config().global_password:
cl.password_auth_callback(login_callback)
15 changes: 15 additions & 0 deletions app/src/shared.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from functools import cache

from sentence_transformers import SentenceTransformer

from src.app_config import AppConfig


@cache
def get_app_config() -> AppConfig:
return AppConfig()


@cache
def get_embedding_model() -> SentenceTransformer:
return SentenceTransformer(get_app_config().embedding_model)
34 changes: 34 additions & 0 deletions app/tests/src/test_login.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import os

import chainlit.config
import src.shared
from src.login import login_callback, require_login


def test_require_login_no_password(monkeypatch):
if "GLOBAL_PASSWORD" in os.environ:
monkeypatch.delenv("GLOBAL_PASSWORD")

# Rebuild AppConfig with new environment variables
src.shared.get_app_config.cache_clear()

require_login()

assert not chainlit.config.code.password_auth_callback


def test_require_login_with_password(monkeypatch):
monkeypatch.setenv("GLOBAL_PASSWORD", "password")
src.shared.get_app_config.cache_clear()

require_login()

assert chainlit.config.code.password_auth_callback


def test_login_callback(monkeypatch):
monkeypatch.setenv("GLOBAL_PASSWORD", "correct pass")
src.shared.get_app_config.cache_clear()

assert login_callback("some user", "wrong pass") is None
assert login_callback("some user", "correct pass").identifier == "some user"
23 changes: 12 additions & 11 deletions infra/app/app-config/env-config/environment-variables.tf
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,18 @@ locals {
# }
# }
secrets = {
# Example generated secret
# RANDOM_SECRET = {
# manage_method = "generated"
# secret_store_name = "/${var.app_name}-${var.environment}/random-secret"
# }

# This is used by Chainlit to sign authentication tokens
# See: https://docs.chainlit.io/authentication/overview
CHAINLIT_AUTH_SECRET = {
manage_method = "generated"
secret_store_name = "/${var.app_name}-${var.environment}/CHAINLIT_AUTH_SECRET"
}

GLOBAL_PASSWORD = {
manage_method = "manual"
secret_store_name = "/${var.app_name}-${var.environment}/GLOBAL_PASSWORD"
}

OPENAI_API_KEY = {
manage_method = "manual"
Expand All @@ -33,11 +40,5 @@ locals {
manage_method = "manual"
secret_store_name = "/${var.app_name}-${var.environment}/LITERAL_API_KEY"
}

# Example secret that references a manually created secret
# SECRET_SAUCE = {
# manage_method = "manual"
# secret_store_name = "/${var.app_name}-${var.environment}/secret-sauce"
# }
}
}

0 comments on commit cae05e6

Please sign in to comment.