feat: Normalize similarity scores of retrieved Guru cards (#36)
yoomlam authored Jul 25, 2024
1 parent aafb698 commit c753bb5
Showing 6 changed files with 47 additions and 7 deletions.
app/local.env (6 changes: 5 additions & 1 deletion)
@@ -70,4 +70,8 @@ AWS_DEFAULT_REGION=us-east-1
# DST app configuration
###########################

-EMBEDDING_MODEL=/app/models/multi-qa-mpnet-base-dot-v1
+# Default chat engine
+# CHAT_ENGINE=guru-snap
+
+# Path to embedding model used for vector database
+EMBEDDING_MODEL=/app/models/multi-qa-mpnet-base-cos-v1
app/src/app_config.py (2 changes: 1 addition & 1 deletion)
@@ -20,7 +20,7 @@ class AppConfig(PydanticBaseEnvConfig):
# To customize these values in deployed environments, set
# them in infra/app/app-config/env-config/environment-variables.tf

embedding_model: str = "multi-qa-mpnet-base-dot-v1"
embedding_model: str = "multi-qa-mpnet-base-cos-v1"
global_password: str | None = None
host: str = "127.0.0.1"
port: int = 8080
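For context on how the local.env entry and this default interact: AppConfig extends PydanticBaseEnvConfig, so the EMBEDDING_MODEL variable set in app/local.env above presumably overrides the embedding_model default at runtime. A minimal sketch of that behavior, assuming PydanticBaseEnvConfig behaves like pydantic_settings.BaseSettings; the stand-in class below is illustrative, not the repo's actual code:

# Hedged sketch: BaseSettings is a stand-in for PydanticBaseEnvConfig.
import os
from pydantic_settings import BaseSettings

class AppConfig(BaseSettings):
    embedding_model: str = "multi-qa-mpnet-base-cos-v1"

print(AppConfig().embedding_model)  # "multi-qa-mpnet-base-cos-v1" (code default)

# Field names are matched case-insensitively against environment variables,
# so an EMBEDDING_MODEL entry (as in local.env) wins when it is set.
os.environ["EMBEDDING_MODEL"] = "/app/models/multi-qa-mpnet-base-cos-v1"
print(AppConfig().embedding_model)  # "/app/models/multi-qa-mpnet-base-cos-v1"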
app/src/retrieve.py (7 changes: 5 additions & 2 deletions)
@@ -38,6 +38,9 @@ def retrieve_with_scores(
).all()

for chunk, score in chunks_with_scores:
logger.info(f"Retrieved: {chunk.document.name!r} with score {score}")
# Confirmed that the `max_inner_product` method returns the same score as using sentence_transformers.util.dot_score
# used in code at https://huggingface.co/sentence-transformers/multi-qa-mpnet-base-cos-v1
logger.info(f"Retrieved: {chunk.document.name!r} with score {-score}")

-return [ChunkWithScore(chunk, score) for chunk, score in chunks_with_scores]
+# Scores from the DB query are negated, presumably to reverse the default sort order
+return [ChunkWithScore(chunk, -score) for chunk, score in chunks_with_scores]
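The sign flip works because pgvector's max_inner_product (the <#> operator) returns the negative inner product so that an ascending ORDER BY surfaces the most similar rows first; negating the returned value recovers the actual dot-product similarity, which for normalized cos-v1 embeddings equals cosine similarity. A minimal sketch of the pattern, assuming a pgvector/SQLAlchemy setup; the Chunk model and column names below are illustrative, not this repo's exact schema:

from pgvector.sqlalchemy import Vector
from sqlalchemy import Column, Integer, Text, select
from sqlalchemy.orm import declarative_base

Base = declarative_base()

class Chunk(Base):  # hypothetical model, simplified stand-in for the repo's Chunk
    __tablename__ = "chunk"
    id = Column(Integer, primary_key=True)
    content = Column(Text)
    embedding = Column(Vector(768))  # multi-qa-mpnet models emit 768-dimensional vectors

def retrieve_top_chunks(db_session, query_embedding, limit=5):
    # max_inner_product compiles to `embedding <#> :query`, which pgvector defines as the
    # NEGATIVE inner product so that ascending ORDER BY returns the closest rows first.
    neg_score = Chunk.embedding.max_inner_product(query_embedding)
    rows = db_session.execute(select(Chunk, neg_score).order_by(neg_score).limit(limit)).all()
    # Negate to get the actual dot product; with unit-length embeddings this is the
    # cosine similarity, so scores land in [-1, 1] with higher meaning more similar.
    return [(chunk, -score) for chunk, score in rows]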
app/src/util/embedding_models.py (new file: 33 additions & 0 deletions)
@@ -0,0 +1,33 @@
from sentence_transformers import SentenceTransformer, util


def test_sentence_transformer(embedding_model: str) -> None:
"""
Exercises the specified embedding model and calculates scores from the embedding vectors.
The embedding model is downloaded automatically to ~/.cache/huggingface/hub if it does not already exist.
Use the scores to confirm/compare against those of pgvector's max_inner_product.
"""
transformer = SentenceTransformer(embedding_model)
# transformer.save(f"sentence_transformers/{embedding_model}")
text = "Curiosity inspires creative, innovative communities worldwide."
embedding = transformer.encode(text)
print("=== ", embedding_model, len(embedding))

for query in [
text,
"How does curiosity inspire communities?",
"What's the best pet?",
"What's the meaning of life?",
]:
query_embedding = transformer.encode(query)
# Code adapted from https://huggingface.co/sentence-transformers/multi-qa-mpnet-base-cos-v1
score = util.dot_score(embedding, query_embedding)
print("Score:", score.item(), "for:", query)


# To run: python -m src.util.embedding_models
if __name__ == "__main__":
embedding_models = ["multi-qa-mpnet-base-cos-v1", "multi-qa-mpnet-base-dot-v1"]
for model in embedding_models:
print(model)
test_sentence_transformer(model)
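One property worth checking with this script: the multi-qa-mpnet-base-cos-v1 model card describes its embeddings as normalized, so util.dot_score and util.cos_sim should agree, which is what keeps the retrieval scores within [-1, 1]. A small hedged sketch of that check (the tolerance is an arbitrary choice):

import numpy as np
from sentence_transformers import SentenceTransformer, util

transformer = SentenceTransformer("multi-qa-mpnet-base-cos-v1")
doc = transformer.encode("Curiosity inspires creative, innovative communities worldwide.")
query = transformer.encode("How does curiosity inspire communities?")

# If the embeddings are unit-length, dot product and cosine similarity coincide.
print("norms:", np.linalg.norm(doc), np.linalg.norm(query))
dot = util.dot_score(doc, query).item()
cos = util.cos_sim(doc, query).item()
print("dot:", dot, "cos:", cos)
assert abs(dot - cos) < 1e-6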
app/tests/src/test_retrieve.py (4 changes: 2 additions & 2 deletions)
@@ -79,6 +79,6 @@ def test_retrieve_with_scores(app_config, db_session, enable_factory_create):

assert len(results) == 2
assert results[0].chunk == short_chunk
-assert results[0].score == -0.7071067690849304
+assert results[0].score == 0.7071067690849304
assert results[1].chunk == medium_chunk
-assert results[1].score == -0.25881901383399963
+assert results[1].score == 0.25881901383399963
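As an aside on the expected values: 0.7071... matches cos(45°) to float32 precision and 0.2588... is very close to cos(75°), which is consistent with reading the un-negated scores as cosine similarities of unit-length embeddings. A quick check (an observation about the numbers, not a claim about how the test fixtures are built):

import math
print(math.cos(math.radians(45)))  # 0.7071067811865476, float32 ~ 0.7071067690849304
print(math.cos(math.radians(75)))  # 0.25881904510252074, close to 0.25881901383399963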
docs/app/getting-started.md (2 changes: 1 addition & 1 deletion)
@@ -26,7 +26,7 @@ A very simple [docker-compose.yml](/docker-compose.yml) has been included to sup
**Note:** Run everything from within the `/app` folder:

1. Set up an (empty) local secrets file: `touch .env` and copy the provided example Docker override: `cp ../docker-compose.override.yml.example ../docker-compose.override.yml`
-2. Download the `multi-qa-mpnet-base-dot-v1` model into the `models` directory: `git clone https://huggingface.co/sentence-transformers/multi-qa-mpnet-base-dot-v1 models/multi-qa-mpnet-base-dot-v1`
+2. Download the embedding model into the `models` directory: `git clone https://huggingface.co/sentence-transformers/multi-qa-mpnet-base-cos-v1 models/multi-qa-mpnet-base-cos-v1`
3. Run `make init start` to build the image and start the container.
4. Navigate to `localhost:8000/chat` to access the Chainlit UI.
5. Run `make run-logs` to see the logs of the running application container
