Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[v2] add similarity_fn in ModelMeta #1759

Open
wants to merge 27 commits into
base: v2.0.0
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
d71718b
add dotwrapper
Samoed Jan 10, 2025
d50fd88
lint
Samoed Jan 10, 2025
7d1e949
make cleaner
Samoed Jan 10, 2025
9e9a111
add poc similarity_fn in ModelMeta
sam-hey Jan 10, 2025
e4a692f
ref: rename EvaluationFunction to ScoringFunction
sam-hey Jan 11, 2025
1865345
make cos_sim default
sam-hey Jan 11, 2025
f34f110
Revert "make cleaner"
sam-hey Jan 11, 2025
49a954e
Revert "add dotwrapper"
sam-hey Jan 11, 2025
d9ebe97
lint
sam-hey Jan 11, 2025
4c89681
fix: _run_eval no co tracking
sam-hey Jan 12, 2025
fae6e31
Merge remote-tracking branch 'mteb/v2.0.0' into fix_contriever
sam-hey Jan 12, 2025
6298d75
fix: bm25s
sam-hey Jan 12, 2025
5a023d6
add enum to models
sam-hey Jan 12, 2025
8ad1e88
add mapping st sim fn name to mteb sim fn name
sam-hey Jan 12, 2025
700ad58
fix model meta use new fn for sim operators
sam-hey Jan 12, 2025
8cffb6a
add max_sim
sam-hey Jan 12, 2025
bf0cf07
fix: colbert & rm similarity_fn_name
sam-hey Jan 13, 2025
3391e1e
ci: skip AfriSentiLID for now (#1785)
isaac-chung Jan 13, 2025
7bb43ab
Merge branch 'v2.0.0' into fix_contriever
sam-hey Jan 13, 2025
4fabb09
test: add test for bm25s and ColBERT
sam-hey Jan 13, 2025
1442673
lint
sam-hey Jan 13, 2025
bb4beec
feat: add mapping for max_sim from pylate
sam-hey Jan 13, 2025
0f923c1
test: bm25s skip
sam-hey Jan 13, 2025
f4779c7
fix: MaxSim as max_sim match pylate & rm Enum in models
sam-hey Jan 13, 2025
89d1ae8
Merge remote-tracking branch 'mteb/v2.0.0' into fix_contriever
sam-hey Jan 14, 2025
07f4d6a
rm enum
sam-hey Jan 15, 2025
6c425f4
update tests skip
sam-hey Jan 16, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion mteb/evaluation/evaluators/model_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,8 +321,11 @@ def _full_corpus_search(
query_embeddings = torch.as_tensor(query_embeddings).to(device)
sub_corpus_embeddings = torch.as_tensor(sub_corpus_embeddings).to(device)

# TODO: If the model is not DRESModel
Samoed marked this conversation as resolved.
Show resolved Hide resolved
score_function = (
self.model.similarity if hasattr(self.model, "similarity") else cos_sim
self.model.similarity
if hasattr(self.model, "similarity")
else self.model.model.mteb_model_meta.get_evaluation_function()
sam-hey marked this conversation as resolved.
Show resolved Hide resolved
)

with torch.inference_mode():
Expand Down
23 changes: 22 additions & 1 deletion mteb/model_meta.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import logging
from enum import Enum
from functools import partial
from typing import TYPE_CHECKING, Any, Callable, Literal

Expand All @@ -9,6 +10,7 @@
from mteb.abstasks.AbsTask import AbsTask
from mteb.abstasks.TaskMetadata import STR_DATE, STR_URL
from mteb.encoder_interface import Encoder
from mteb.evaluation.evaluators.utils import cos_sim, dot_score

from .languages import ISO_LANGUAGE_SCRIPT

Expand Down Expand Up @@ -51,6 +53,11 @@ def get_loader_name(
return loader.__name__


class EvaluationFunction(str, Enum):
sam-hey marked this conversation as resolved.
Show resolved Hide resolved
DOT_PRODUCT: str = "dot_score"
COSINE: str = "cos_sim"


class ModelMeta(BaseModel):
"""The model metadata object.

Expand Down Expand Up @@ -101,13 +108,27 @@ class ModelMeta(BaseModel):
public_training_code: bool | None = None
framework: list[FRAMEWORKS] = []
reference: STR_URL | None = None
similarity_fn_name: DISTANCE_METRICS | None = None
similarity_fn_name: DISTANCE_METRICS | EvaluationFunction | None = None
sam-hey marked this conversation as resolved.
Show resolved Hide resolved
use_instructions: bool | None = None
training_datasets: dict[str, list[str]] | None = None
adapted_from: str | None = None
superseded_by: str | None = None
citation: str | None = None

def get_evaluation_function(self) -> Callable:
    """Return the scoring callable selected by ``similarity_fn_name``.

    ``"cosine"`` or ``EvaluationFunction.COSINE`` selects ``cos_sim``;
    ``"dot"`` or ``EvaluationFunction.DOT_PRODUCT`` selects ``dot_score``.

    Returns:
        The ``cos_sim`` or ``dot_score`` function object itself; it is not
        called here.

    Raises:
        ValueError: If ``similarity_fn_name`` is any other value,
            including ``None``.
    """
    fn_name = self.similarity_fn_name
    # Tuple membership falls back to ``==``, so both the plain string
    # spelling and the enum member are accepted for each metric.
    if fn_name in ("cosine", EvaluationFunction.COSINE):
        return cos_sim
    if fn_name in ("dot", EvaluationFunction.DOT_PRODUCT):
        return dot_score
    raise ValueError(f"Unknown similarity function {fn_name}")

def to_dict(self):
dict_repr = self.model_dump()
loader = dict_repr.pop("loader", None)
Expand Down
4 changes: 2 additions & 2 deletions mteb/models/e5_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from functools import partial

from mteb.encoder_interface import PromptType
from mteb.model_meta import ModelMeta, sentence_transformers_loader
from mteb.model_meta import EvaluationFunction, ModelMeta, sentence_transformers_loader

E5_PAPER_RELEASE_DATE = "2024-02-08"
XLMR_LANGUAGES = [
Expand Down Expand Up @@ -149,7 +149,7 @@
license="mit",
max_tokens=512,
reference="https://huggingface.co/intfloat/multilingual-e5-small",
similarity_fn_name="cosine",
similarity_fn_name=EvaluationFunction.COSINE,
framework=["Sentence Transformers", "PyTorch"],
use_instructions=True,
citation=MULTILINGUAL_E5_CITATION,
Expand Down
8 changes: 7 additions & 1 deletion mteb/models/sentence_transformer_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from mteb.encoder_interface import PromptType

from ..evaluation import dot_distance
from .wrapper import Wrapper

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -59,7 +60,7 @@ def __init__(
if isinstance(self.model, CrossEncoder):
self.predict = self._predict

if hasattr(self.model, "similarity"):
if hasattr(self.model, "similarity") and not hasattr(self, "similarity"):
self.similarity = self.model.similarity

def encode(
Expand Down Expand Up @@ -125,3 +126,8 @@ def _predict(
convert_to_numpy=True,
**kwargs,
)


class SentenceTransformerWrapperDotSimilarity(SentenceTransformerWrapper):
    """``SentenceTransformerWrapper`` whose ``similarity`` delegates to ``dot_distance``."""

    def similarity(self, embedding1: np.ndarray, embedding2: np.ndarray) -> float:
        """Score two embeddings by passing them to ``dot_distance``.

        Args:
            embedding1: First embedding, forwarded unchanged.
            embedding2: Second embedding, forwarded unchanged.

        Returns:
            Whatever ``dot_distance`` returns for the two inputs.
            NOTE(review): annotated as ``float``; the actual return type
            depends on ``dot_distance``, which is not visible here -- confirm.
        """
        score = dot_distance(embedding1, embedding2)
        return score
10 changes: 10 additions & 0 deletions mteb/models/sentence_transformers_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,12 @@

from __future__ import annotations

from functools import partial

from mteb.model_meta import ModelMeta
from mteb.models.sentence_transformer_wrapper import (
SentenceTransformerWrapperDotSimilarity,
)

paraphrase_langs = [
"ara_Arab",
Expand Down Expand Up @@ -375,6 +380,11 @@
)

contriever = ModelMeta(
loader=partial(
SentenceTransformerWrapperDotSimilarity,
model="facebook/contriever-msmarco",
revision="abe8c1493371369031bcb1e02acb754cf4e162fa",
),
name="facebook/contriever-msmarco",
languages=["eng-Latn"],
open_weights=True,
Expand Down
Loading