-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: Add MockSentenceTransformer (#5)
* Add MockSentenceTransformer * Fix typos * Disable vulnerability scans
- Loading branch information
1 parent
8e604d1
commit 6596da7
Showing
9 changed files
with
1,086 additions
and
14 deletions.
There are no files selected for viewing
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -31,4 +31,7 @@ coverage.* | |
poetry-installer-error-*.log | ||
|
||
# Chainlit | ||
.chainlit | ||
.chainlit | ||
|
||
# Cached models | ||
models/ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,4 +2,4 @@ | |
|
||
|
||
class AppConfig(PydanticBaseEnvConfig): | ||
... | ||
embedding_model: str = "multi-qa-mpnet-base-dot-v1" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
import math | ||
|
||
|
||
class MockSentenceTransformer:
    """Lightweight stand-in for a sentence-transformer embedding model.

    Provides ``max_seq_length``, ``tokenizer`` and ``encode`` without loading
    a real model. Embeddings are derived only from the average length of the
    whitespace-separated tokens in the input text.
    """

    def __init__(self, *args, **kwargs):
        """Create the mock; all positional and keyword arguments are ignored."""
        # Imitate multi-qa-mpnet-base-dot-v1
        self.max_seq_length = 512
        self.tokenizer = MockTokenizer()

    def encode(self, text, **kwargs):
        """
        Encode text into a 768-dimensional embedding that allows for similarity comparison via the dot product.
        The embedding represents the average word length of the text.

        Args:
            text: The string to embed.
            **kwargs: Accepted and ignored.

        Returns:
            A list of 768 numbers whose components sum to 1. Only the first
            two components can be non-zero. Text with no tokens (empty or
            whitespace-only) is encoded with angle 0, i.e. ``[1.0, 0.0, ...]``.
        """
        tokens = self.tokenizer.tokenize(text)

        if tokens:
            average_token_length = sum(len(token) for token in tokens) / len(tokens)
            # Convert average word length to an angle
            angle = (1 / average_token_length) * 2 * math.pi
        else:
            # No tokens means there is no average length to compute; fall back
            # to a fixed angle instead of dividing by zero.
            angle = 0.0

        # Pad the 2-d point on the unit circle to length 768
        embedding = [math.cos(angle), math.sin(angle)] + ([0] * 766)

        # Normalize the embedding so its components sum to 1. The divisor is
        # computed once rather than once per element.
        # NOTE(review): cos(angle) + sin(angle) is close to zero when the
        # average token length is near 8/3 or 8/7, which produces very large
        # components. Left as-is to keep existing embeddings unchanged.
        total = sum(embedding)
        return [x / total for x in embedding]


class MockTokenizer:
    """Minimal tokenizer that splits text on runs of whitespace."""

    def tokenize(self, text):
        """Return the whitespace-separated tokens of ``text`` as a list."""
        return text.split()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
from tests.mock.mock_sentence_transformer import MockSentenceTransformer | ||
|
||
|
||
def test_mock_sentence_transformer():
    """Check the mock's attributes, embedding size and sum, and dot-product similarity ordering."""
    embedding_model = MockSentenceTransformer()

    assert embedding_model.max_seq_length == 512
    assert embedding_model.tokenizer.tokenize("Hello, world!") == ["Hello,", "world!"]

    hello_embedding = embedding_model.encode("Hello, world!")
    assert len(hello_embedding) == 768
    # The components should sum to about 1, with some tolerance for floating
    # point imprecision. abs() makes the check two-sided, so a sum well below 1
    # fails as well as a sum well above 1.
    assert abs(sum(hello_embedding) - 1) < 0.01

    # Test that we can compare similarity with dot product,
    # where sentences with closer average word lengths are considered more similar
    long_text = embedding_model.encode(
        "Incomprehensibility characterizes unintelligible, overwhelmingly convoluted dissertations."
    )
    medium_text = embedding_model.encode(
        "Curiosity inspires creative, innovative communities worldwide."
    )
    short_text = embedding_model.encode("The quick brown red fox jumps.")

    def dot_product(v1, v2):
        # strict=True raises if the two vectors differ in length
        return sum(x * y for x, y in zip(v1, v2, strict=True))

    assert dot_product(long_text, long_text) > dot_product(long_text, medium_text)
    assert dot_product(long_text, medium_text) > dot_product(long_text, short_text)
    assert dot_product(medium_text, medium_text) > dot_product(medium_text, short_text)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters