Skip to content

Commit

Permalink
Merge branch 'main' into DST-304-metadataFix
Browse files Browse the repository at this point in the history
  • Loading branch information
ccheng26 authored Jul 26, 2024
2 parents f10f85f + c753bb5 commit b438c1c
Show file tree
Hide file tree
Showing 15 changed files with 187 additions and 41 deletions.
3 changes: 3 additions & 0 deletions app/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -244,3 +244,6 @@ endif

ingest-guru-cards: check-ingest-arguments
$(PY_RUN_CMD) ingest-guru-cards $(DATASET_ID) $(BENEFIT_PROGRAM) $(BENEFIT_REGION) $(FILEPATH)

ingest-policy-pdfs: check-ingest-arguments
$(PY_RUN_CMD) ingest-policy-pdfs $(DATASET_ID) $(BENEFIT_PROGRAM) $(BENEFIT_REGION) $(FILEPATH)
6 changes: 5 additions & 1 deletion app/local.env
Original file line number Diff line number Diff line change
Expand Up @@ -70,4 +70,8 @@ AWS_DEFAULT_REGION=us-east-1
# DST app configuration
###########################

EMBEDDING_MODEL=/app/models/multi-qa-mpnet-base-dot-v1
# Default chat engine
# CHAT_ENGINE=guru-snap

# Path to embedding model used for vector database
EMBEDDING_MODEL=/app/models/multi-qa-mpnet-base-cos-v1
1 change: 1 addition & 0 deletions app/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ db-migrate = "src.db.migrations.run:up"
db-migrate-down = "src.db.migrations.run:down"
db-migrate-down-all = "src.db.migrations.run:downall"
ingest-guru-cards = "src.ingest_guru_cards:main"
ingest-policy-pdfs = "src.ingest_policy_pdfs:main"

[tool.black]
line-length = 100
Expand Down
2 changes: 1 addition & 1 deletion app/src/app_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ class AppConfig(PydanticBaseEnvConfig):
# To customize these values in deployed environments, set
# them in infra/app/app-config/env-config/environment-variables.tf

embedding_model: str = "multi-qa-mpnet-base-dot-v1"
embedding_model: str = "multi-qa-mpnet-base-cos-v1"
global_password: str | None = None
host: str = "127.0.0.1"
port: int = 8080
Expand Down
22 changes: 2 additions & 20 deletions app/src/ingest_guru_cards.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from src.app_config import app_config
from src.db.models.document import Chunk, Document
from src.util.html import get_text_from_html
from src.util.ingest_utils import process_and_ingest_sys_args

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -60,23 +61,4 @@ def main() -> None:
)
return

# TODO: improve command-line argument handling using getopt module
dataset_id = sys.argv[1]
benefit_program = sys.argv[2]
benefit_region = sys.argv[3]
guru_cards_filepath = sys.argv[4]

logger.info(
f"Processing Guru cards {dataset_id} at {guru_cards_filepath} for {benefit_program} in {benefit_region}"
)

doc_attribs = {
"dataset": dataset_id,
"program": benefit_program,
"region": benefit_region,
}
with app_config.db_session() as db_session:
_ingest_cards(db_session, guru_cards_filepath, doc_attribs)
db_session.commit()

logger.info("Finished processing Guru cards.")
process_and_ingest_sys_args(sys, logger, _ingest_cards)
38 changes: 38 additions & 0 deletions app/src/ingest_policy_pdfs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import logging
import sys

import src.adapters.db as db
from src.app_config import app_config
from src.util.file_util import get_files
from src.util.ingest_utils import process_and_ingest_sys_args

logger = logging.getLogger(__name__)

# Print INFO messages since this is often run from the terminal
# during local development
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")


def _ingest_policy_pdfs(
    db_session: db.Session,
    pdf_file_dir: str,
    doc_attribs: dict[str, str],
) -> None:
    """Scan a directory (local or S3) for PDF files and log each one found.

    NOTE(review): currently only logs each PDF along with the ingestion
    context — chunking/embedding of PDF content is not implemented yet.
    """
    embedding_model = app_config.sentence_transformer
    # get_files handles both local paths and s3:// URLs; keep only PDFs.
    pdf_files = (file for file in get_files(pdf_file_dir) if file.endswith(".pdf"))
    for file in pdf_files:
        logger.info(
            f"Processing pdf file: {file} at {pdf_file_dir} using {embedding_model}, {db_session}, with {doc_attribs}"
        )


def main() -> None:
    """CLI entry point: validate the argument count, then hand off to the shared ingestion driver."""
    # Four positional arguments are required (argv[0] is the program name).
    if len(sys.argv) >= 5:
        process_and_ingest_sys_args(sys, logger, _ingest_policy_pdfs)
        return

    logger.warning(
        "Expecting 4 arguments: DATASET_ID BENEFIT_PROGRAM BENEFIT_REGION FILEPATH\n but got: %s",
        sys.argv[1:],
    )
7 changes: 5 additions & 2 deletions app/src/retrieve.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@ def retrieve_with_scores(
).all()

for chunk, score in chunks_with_scores:
logger.info(f"Retrieved: {chunk.document.name!r} with score {score}")
# Confirmed that the `max_inner_product` method returns the same score as using sentence_transformers.util.dot_score
# used in code at https://huggingface.co/sentence-transformers/multi-qa-mpnet-base-cos-v1
logger.info(f"Retrieved: {chunk.document.name!r} with score {-score}")

return [ChunkWithScore(chunk, score) for chunk, score in chunks_with_scores]
# Scores from the DB query are negated, presumably to reverse the default sort order
return [ChunkWithScore(chunk, -score) for chunk, score in chunks_with_scores]
33 changes: 33 additions & 0 deletions app/src/util/embedding_models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
from sentence_transformers import SentenceTransformer, util


def test_sentence_transformer(embedding_model: str) -> None:
    """
    Exercise the named embedding model and print similarity scores.

    The model is downloaded automatically to ~/.cache/huggingface/hub on first
    use. The printed dot-product scores can be compared against those produced
    by pgvector's max_inner_product.
    """
    model = SentenceTransformer(embedding_model)
    text = "Curiosity inspires creative, innovative communities worldwide."
    embedding = model.encode(text)
    print("=== ", embedding_model, len(embedding))

    queries = [
        text,
        "How does curiosity inspire communities?",
        "What's the best pet?",
        "What's the meaning of life?",
    ]
    for query in queries:
        # Scoring approach adapted from
        # https://huggingface.co/sentence-transformers/multi-qa-mpnet-base-cos-v1
        score = util.dot_score(embedding, model.encode(query))
        print("Score:", score.item(), "for:", query)


# To run: python -m src.util.embedding_models
if __name__ == "__main__":
    # Compare the cos-v1 and dot-v1 variants of the same base model.
    for model in ("multi-qa-mpnet-base-cos-v1", "multi-qa-mpnet-base-dot-v1"):
        print(model)
        test_sentence_transformer(model)
12 changes: 12 additions & 0 deletions app/src/util/file_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,18 @@ def get_file_name(path: str) -> str:
return os.path.basename(path)


def get_files(path: str) -> list[str]:
    """Return a list of paths to all files in a directory, whether on local disk or on S3.

    Args:
        path: Either a local directory path or an "s3://bucket/prefix" URL.

    Returns:
        For S3, full "s3://bucket/key" URLs for every object under the prefix;
        for local paths, string paths of every file found recursively.
    """
    if is_s3_path(path):
        bucket_name, prefix = split_s3_url(path)
        bucket = boto3.resource("s3").Bucket(bucket_name)
        return [f"s3://{bucket_name}/{obj.key}" for obj in bucket.objects.filter(Prefix=prefix)]

    # Use Path rather than PosixPath: PosixPath cannot be instantiated on
    # Windows, and Path resolves to the correct flavor per platform.
    from pathlib import Path

    return [str(file) for file in Path(path).rglob("*") if file.is_file()]


##################################
# S3 Utilities
##################################
Expand Down
35 changes: 35 additions & 0 deletions app/src/util/ingest_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import getopt
from logging import Logger
from types import ModuleType
from typing import Callable

from src.app_config import app_config


def process_and_ingest_sys_args(sys: ModuleType, logger: Logger, ingestion_call: Callable) -> None:
    """Read positional CLI arguments and run an ingestion function inside a DB session.

    Expects sys.argv to contain, after the program name:
        DATASET_ID BENEFIT_PROGRAM BENEFIT_REGION FILEPATH
    Callers are expected to validate the argument count before calling.

    Args:
        sys: The sys module (passed in so callers control which argv is read).
        logger: Logger used for progress messages.
        ingestion_call: Called as ingestion_call(db_session, file_dir, doc_attribs);
            the session is committed after it returns.
    """
    # The previous getopt.getopt call passed a malformed longopts list
    # (["DATASET_ID BENEFIT_PROGRAM BENEFIT_REGION FILEPATH)"]) — a single
    # space-joined string with a stray ")". It matched no options and raised
    # GetoptError for any "-"-prefixed argument. These are plain positional
    # arguments, so read them directly.
    dataset_id, benefit_program, benefit_region, file_dir = sys.argv[1:5]

    logger.info(
        "Processing files %s at %s for %s in %s",
        dataset_id,
        file_dir,
        benefit_program,
        benefit_region,
    )

    doc_attribs = {
        "dataset": dataset_id,
        "program": benefit_program,
        "region": benefit_region,
    }

    with app_config.db_session() as db_session:
        ingestion_call(db_session, file_dir, doc_attribs)
        db_session.commit()

    logger.info("Finished processing")
15 changes: 2 additions & 13 deletions app/tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import logging
from functools import cached_property

import _pytest.monkeypatch
import boto3
Expand Down Expand Up @@ -109,18 +108,8 @@ def enable_factory_create(monkeypatch, db_session) -> db.Session:

@pytest.fixture
def app_config(monkeypatch, db_session):
class MockAppConfig:
def db_session(self):
return db_session

@cached_property
def sentence_transformer(self):
return MockSentenceTransformer()

mock_app_config = MockAppConfig()
monkeypatch.setattr(AppConfig, "db_session", mock_app_config.db_session)
monkeypatch.setattr(AppConfig, "sentence_transformer", mock_app_config.sentence_transformer)
return mock_app_config
monkeypatch.setattr(AppConfig, "db_session", lambda _self: db_session)
monkeypatch.setattr(AppConfig, "sentence_transformer", MockSentenceTransformer())


####################
Expand Down
2 changes: 1 addition & 1 deletion app/tests/src/test_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ def _get_chunks_with_scores():
return retrieve_with_scores("Very tiny words.", k=2)


def test_format_guru_cards_with_score(monkeypatch, app_config, db_session, enable_factory_create):
def test_format_guru_cards_with_score(app_config, db_session, enable_factory_create):
db_session.execute(delete(Document))

chunks_with_scores = _get_chunks_with_scores()
Expand Down
46 changes: 46 additions & 0 deletions app/tests/src/test_ingest_policy_pdfs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
import io
import logging
import tempfile

import pytest
from sqlalchemy import delete

from src.db.models.document import Document
from src.ingest_policy_pdfs import _ingest_policy_pdfs


@pytest.fixture
def policy_local_file():
    """Yield a temporary directory containing one empty policy*.pdf file."""
    with tempfile.TemporaryDirectory() as pdf_dir:
        # delete=False so the file survives until the directory is cleaned up.
        tempfile.NamedTemporaryFile(prefix="policy", suffix=".pdf", dir=pdf_dir, delete=False)
        yield pdf_dir


@pytest.fixture
def policy_s3_file(mock_s3_bucket_resource):
    """Upload a minimal fake PDF to the mocked S3 bucket and return its s3:// URL."""
    fake_pdf = io.BytesIO(b"%PDF-1.4\n%Fake PDF content for testing\n")
    mock_s3_bucket_resource.put_object(Body=fake_pdf, Key="policy.pdf")
    return "s3://test_bucket/policy.pdf"


# Shared document attributes passed to every ingestion call in these tests.
doc_attribs = dict(
    dataset="test_dataset",
    program="test_benefit_program",
    region="Michigan",
)


@pytest.mark.parametrize("file_location", ["local", "s3"])
def test__ingest_policy_pdfs(
    caplog, app_config, db_session, policy_s3_file, policy_local_file, file_location
):
    """Ingesting a PDF directory (local or S3) logs a processing message per PDF."""
    db_session.execute(delete(Document))

    pdf_source = policy_local_file if file_location == "local" else policy_s3_file
    with caplog.at_level(logging.INFO):
        _ingest_policy_pdfs(db_session, pdf_source, doc_attribs)

    assert any(message.startswith("Processing pdf file:") for message in caplog.messages)
4 changes: 2 additions & 2 deletions app/tests/src/test_retrieve.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,6 @@ def test_retrieve_with_scores(app_config, db_session, enable_factory_create):

assert len(results) == 2
assert results[0].chunk == short_chunk
assert results[0].score == -0.7071067690849304
assert results[0].score == 0.7071067690849304
assert results[1].chunk == medium_chunk
assert results[1].score == -0.25881901383399963
assert results[1].score == 0.25881901383399963
2 changes: 1 addition & 1 deletion docs/app/getting-started.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ A very simple [docker-compose.yml](/docker-compose.yml) has been included to sup
**Note:** Run everything from within the `/app` folder:

1. Set up an (empty) local secrets file: `touch .env` and copy the provided example Docker override: `cp ../docker-compose.override.yml.example ../docker-compose.override.yml`
2. Download the `multi-qa-mpnet-base-dot-v1` model into the `models` directory: `git clone https://huggingface.co/sentence-transformers/multi-qa-mpnet-base-dot-v1 models/multi-qa-mpnet-base-dot-v1`
2. Download the embedding model into the `models` directory: `git clone https://huggingface.co/sentence-transformers/multi-qa-mpnet-base-cos-v1 models/multi-qa-mpnet-base-cos-v1`
3. Run `make init start` to build the image and start the container.
4. Navigate to `localhost:8000/chat` to access the Chainlit UI.
5. Run `make run-logs` to see the logs of the running application container
Expand Down

0 comments on commit b438c1c

Please sign in to comment.