Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[v2] add similarity_fn in ModelMeta #1759

Open
wants to merge 27 commits into
base: v2.0.0
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 16 commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
d71718b
add dotwrapper
Samoed Jan 10, 2025
d50fd88
lint
Samoed Jan 10, 2025
7d1e949
make cleaner
Samoed Jan 10, 2025
9e9a111
add poc similarity_fn in ModelMeta
sam-hey Jan 10, 2025
e4a692f
ref: rename EvaluationFunction to ScoringFunction
sam-hey Jan 11, 2025
1865345
make cos_sim default
sam-hey Jan 11, 2025
f34f110
Revert "make cleaner"
sam-hey Jan 11, 2025
49a954e
Revert "add dotwrapper"
sam-hey Jan 11, 2025
d9ebe97
lint
sam-hey Jan 11, 2025
4c89681
fix: _run_eval no co tracking
sam-hey Jan 12, 2025
fae6e31
Merge remote-tracking branch 'mteb/v2.0.0' into fix_contriever
sam-hey Jan 12, 2025
6298d75
fix: bm25s
sam-hey Jan 12, 2025
5a023d6
add enum to models
sam-hey Jan 12, 2025
8ad1e88
add mapping st sim fn name to mteb sim fn name
sam-hey Jan 12, 2025
700ad58
fix model meta use new fn for sim operators
sam-hey Jan 12, 2025
8cffb6a
add max_sim
sam-hey Jan 12, 2025
bf0cf07
fix: colbert & rm similarity_fn_name
sam-hey Jan 13, 2025
3391e1e
ci: skip AfriSentiLID for now (#1785)
isaac-chung Jan 13, 2025
7bb43ab
Merge branch 'v2.0.0' into fix_contriever
sam-hey Jan 13, 2025
4fabb09
test: add test for bm25s and ColBERT
sam-hey Jan 13, 2025
1442673
lint
sam-hey Jan 13, 2025
bb4beec
feat: add mapping for max_sim from pylate
sam-hey Jan 13, 2025
0f923c1
test: bm25s skip
sam-hey Jan 13, 2025
f4779c7
fix: MaxSim as max_sim match pylate & rm Enum in models
sam-hey Jan 13, 2025
89d1ae8
Merge remote-tracking branch 'mteb/v2.0.0' into fix_contriever
sam-hey Jan 14, 2025
07f4d6a
rm enum
sam-hey Jan 15, 2025
6c425f4
update tests skip
sam-hey Jan 16, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion mteb/evaluation/MTEB.py
Original file line number Diff line number Diff line change
Expand Up @@ -589,7 +589,6 @@ def run(
task,
model,
split,
output_folder,
subsets_to_run=subsets_to_run,
encode_kwargs=encode_kwargs,
**kwargs,
Expand Down
1 change: 1 addition & 0 deletions mteb/evaluation/evaluators/RetrievalEvaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ def __call__(
self.top_k,
task_name=self.task_name, # type: ignore
instructions=instructions,
score_function="bm25",
**kwargs,
)
else:
Expand Down
16 changes: 13 additions & 3 deletions mteb/evaluation/evaluators/model_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,9 +321,19 @@ def _full_corpus_search(
query_embeddings = torch.as_tensor(query_embeddings).to(device)
sub_corpus_embeddings = torch.as_tensor(sub_corpus_embeddings).to(device)

score_function = (
self.model.similarity if hasattr(self.model, "similarity") else cos_sim
)
if hasattr(self.model.model, "mteb_model_meta") or hasattr(
self.model, "similarity"
):
score_function = (
self.model.similarity
if hasattr(self.model, "similarity")
else self.model.model.mteb_model_meta.get_similarity_function()
)
else:
logger.warning(
"The model does not provide `mteb_model_meta`; defaulting to the cosine similarity function."
)
score_function = cos_sim

with torch.inference_mode():
scores = score_function(query_embeddings, sub_corpus_embeddings)
Expand Down
27 changes: 27 additions & 0 deletions mteb/evaluation/evaluators/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,33 @@ def _cos_sim_core(a_tensor, b_tensor):
return _cos_sim_core(a, b)


def max_sim(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    """Computes the max-similarity max_sim(a[i], b[j]) for all i and j.

    Used for late-interaction (ColBERT-style) models: for each query token the
    best-matching document token score is taken, then summed over query tokens.
    Works with a Tensor of the shape (batch_size, num_tokens, token_dim);
    2D inputs are treated as a single batch.

    Args:
        a: Query embeddings of shape (num_queries, num_tokens, token_dim)
            or (num_tokens, token_dim).
        b: Document embeddings of shape (num_docs, num_tokens, token_dim)
            or (num_tokens, token_dim).

    Return:
        Matrix with res[i][j] = max_sim(a[i], b[j])
    """  # noqa: D402
    # NOTE: original signature had a stray leading `self` parameter, but this
    # is a module-level function called as score_function(queries, corpus) —
    # that extra parameter made every two-argument call fail.
    if not isinstance(a, torch.Tensor):
        a = torch.tensor(a, dtype=torch.float32)

    if not isinstance(b, torch.Tensor):
        b = torch.tensor(b, dtype=torch.float32)

    # Promote un-batched (num_tokens, token_dim) inputs to a batch of one.
    if len(a.shape) == 2:
        a = a.unsqueeze(0)

    if len(b.shape) == 2:
        b = b.unsqueeze(0)

    # Token-level similarity for every (query, doc) pair:
    # scores[a_i, b_j, s, t] = <a[a_i, s], b[b_j, t]>
    scores = torch.einsum(
        "ash,bth->abst",
        a,
        b,
    )

    # Max over document tokens, summed over query tokens (MaxSim reduction).
    return scores.max(axis=-1).values.sum(axis=-1)

def dot_score(a: torch.Tensor, b: torch.Tensor):
"""Computes the dot-product dot_prod(a[i], b[j]) for all i and j.
:return: Matrix with res[i][j] = dot_prod(a[i], b[j])
Expand Down
41 changes: 38 additions & 3 deletions mteb/model_meta.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
from __future__ import annotations

import logging
from enum import Enum
from functools import partial
from typing import TYPE_CHECKING, Any, Callable, Literal

from pydantic import BaseModel, ConfigDict
import numpy as np
from pydantic import BaseModel, ConfigDict, field_validator

from mteb.abstasks.AbsTask import AbsTask
from mteb.abstasks.TaskMetadata import STR_DATE, STR_URL
from mteb.encoder_interface import Encoder
from mteb.evaluation.evaluators.utils import cos_sim, dot_score, max_sim

from .languages import ISO_LANGUAGE_SCRIPT

Expand All @@ -30,7 +33,6 @@
"PyLate",
"ColBERT",
]
DISTANCE_METRICS = Literal["cosine", "max_sim", "dot"]


def sentence_transformers_loader(
Expand All @@ -51,6 +53,12 @@ def get_loader_name(
return loader.__name__


class ScoringFunction(str, Enum):
    """Similarity/scoring functions a model can declare in its `ModelMeta`.

    Subclasses `str` so values compare and serialize as plain strings
    (e.g. in pydantic dumps and JSON).
    """

    # NOTE: enum members must not carry type annotations — `NAME: str = value`
    # confuses type checkers and is rejected by the enum docs; plain
    # assignments below define the members.
    DOT_PRODUCT = "dot"
    COSINE = "cos_sim"
    MAX_SIM = "max_sim"


class ModelMeta(BaseModel):
"""The model metadata object.

Expand Down Expand Up @@ -99,13 +107,40 @@ class ModelMeta(BaseModel):
public_training_code: bool | None = None
framework: list[FRAMEWORKS] = []
reference: STR_URL | None = None
similarity_fn_name: DISTANCE_METRICS | None = None
similarity_fn_name: ScoringFunction | None = None
use_instructions: bool | None = None
training_datasets: dict[str, list[str]] | None = None
adapted_from: str | None = None
superseded_by: str | None = None
citation: str | None = None

@field_validator("similarity_fn_name", mode="before")
@classmethod
def validate_similarity_fn_name(cls, value):
    """Converts the similarity function name to the corresponding enum value.

    sentence_transformers uses Literal['cosine', 'dot', 'euclidean',
    'manhattan'] for similarity_fn_name; the overlapping names are mapped to
    their mteb `ScoringFunction` equivalents. Enum instances and None pass
    through unchanged.

    Raises:
        ValueError: If the value is not a ScoringFunction, None, or a
            supported sentence_transformers name.
    """
    # Already canonical (or intentionally unset) — nothing to convert.
    if isinstance(value, ScoringFunction) or value is None:
        return value
    mapping = {
        "cosine": ScoringFunction.COSINE,
        "dot": ScoringFunction.DOT_PRODUCT,
    }
    if value in mapping:
        return mapping[value]
    raise ValueError(f"Invalid similarity function name: {value}")

def get_similarity_function(self) -> Callable[[np.ndarray, np.ndarray], np.ndarray]:
    """Resolve this model's declared similarity function to a callable.

    Returns:
        A callable mapping (query_embeddings, corpus_embeddings) to a score
        matrix, chosen from cos_sim / dot_score / max_sim according to
        `similarity_fn_name`.

    Raises:
        ValueError: If `similarity_fn_name` is None or not a recognized
            ScoringFunction member.
    """
    if self.similarity_fn_name == ScoringFunction.COSINE:
        return cos_sim
    elif self.similarity_fn_name == ScoringFunction.DOT_PRODUCT:
        return dot_score
    elif self.similarity_fn_name == ScoringFunction.MAX_SIM:
        return max_sim
    elif self.similarity_fn_name is None:
        raise ValueError("Similarity function not specified.")
    # Previously an unrecognized value silently fell through and returned
    # None, deferring the failure to the call site — fail loudly here instead.
    raise ValueError(f"Unsupported similarity function: {self.similarity_fn_name}")

def to_dict(self):
dict_repr = self.model_dump()
loader = dict_repr.pop("loader", None)
Expand Down
20 changes: 10 additions & 10 deletions mteb/models/arctic_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from functools import partial

from mteb.model_meta import ModelMeta, sentence_transformers_loader
from mteb.model_meta import ModelMeta, ScoringFunction, sentence_transformers_loader

LANGUAGES_V2_0 = [
"afr_Latn",
Expand Down Expand Up @@ -102,7 +102,7 @@
embed_dim=768,
license="apache-2.0",
reference="https://huggingface.co/Snowflake/snowflake-arctic-embed-m-v1.5",
similarity_fn_name="cosine",
similarity_fn_name=ScoringFunction.COSINE,
use_instructions=False,
adapted_from=None,
superseded_by=None,
Expand Down Expand Up @@ -135,7 +135,7 @@
embed_dim=384,
license="apache-2.0",
reference="https://huggingface.co/Snowflake/snowflake-arctic-embed-xs",
similarity_fn_name="cosine",
similarity_fn_name=ScoringFunction.COSINE,
use_instructions=True,
adapted_from="sentence-transformers/all-MiniLM-L6-v2",
superseded_by=None,
Expand Down Expand Up @@ -185,7 +185,7 @@
embed_dim=384,
license="apache-2.0",
reference="https://huggingface.co/Snowflake/snowflake-arctic-embed-s",
similarity_fn_name="cosine",
similarity_fn_name=ScoringFunction.COSINE,
use_instructions=True,
adapted_from="intfloat/e5-small-unsupervised",
superseded_by=None,
Expand Down Expand Up @@ -235,7 +235,7 @@
embed_dim=768,
license="apache-2.0",
reference="https://huggingface.co/Snowflake/snowflake-arctic-embed-m",
similarity_fn_name="cosine",
similarity_fn_name=ScoringFunction.COSINE,
use_instructions=True,
adapted_from="intfloat/e5-base-unsupervised",
superseded_by="Snowflake/snowflake-arctic-embed-m-v1.5",
Expand Down Expand Up @@ -285,7 +285,7 @@
embed_dim=768,
license="apache-2.0",
reference="https://huggingface.co/Snowflake/snowflake-arctic-embed-m-long",
similarity_fn_name="cosine",
similarity_fn_name=ScoringFunction.COSINE,
use_instructions=True,
adapted_from="nomic-ai/nomic-embed-text-v1-unsupervised",
superseded_by="Snowflake/snowflake-arctic-embed-m-v2.0",
Expand Down Expand Up @@ -335,7 +335,7 @@
embed_dim=1024,
license="apache-2.0",
reference="https://huggingface.co/Snowflake/snowflake-arctic-embed-l",
similarity_fn_name="cosine",
similarity_fn_name=ScoringFunction.COSINE,
use_instructions=True,
adapted_from="intfloat/e5-base-unsupervised",
superseded_by="Snowflake/snowflake-arctic-embed-l-v2.0",
Expand Down Expand Up @@ -387,7 +387,7 @@
embed_dim=768,
license="apache-2.0",
reference="https://huggingface.co/Snowflake/snowflake-arctic-embed-m-v1.5",
similarity_fn_name="cosine",
similarity_fn_name=ScoringFunction.COSINE,
use_instructions=True,
adapted_from=None,
superseded_by="Snowflake/snowflake-arctic-embed-m-v2.0",
Expand All @@ -411,7 +411,7 @@
embed_dim=768,
license="apache-2.0",
reference="https://huggingface.co/Snowflake/snowflake-arctic-embed-m-v2.0",
similarity_fn_name="cosine",
similarity_fn_name=ScoringFunction.COSINE,
use_instructions=True,
adapted_from="Alibaba-NLP/gte-multilingual-base",
superseded_by=None,
Expand Down Expand Up @@ -460,7 +460,7 @@
embed_dim=1024,
license="apache-2.0",
reference="https://huggingface.co/Snowflake/snowflake-arctic-embed-l-v2.0",
similarity_fn_name="cosine",
similarity_fn_name=ScoringFunction.COSINE,
use_instructions=True,
adapted_from="BAAI/bge-m3-retromae",
superseded_by=None,
Expand Down
8 changes: 4 additions & 4 deletions mteb/models/bge_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from functools import partial

from mteb.model_meta import ModelMeta, sentence_transformers_loader
from mteb.model_meta import ModelMeta, ScoringFunction, sentence_transformers_loader

model_prompts = {"query": "Represent this sentence for searching relevant passages: "}
BGE_15_CITATION = """@misc{bge_embedding,
Expand Down Expand Up @@ -31,7 +31,7 @@
license="mit",
max_tokens=512,
reference="https://huggingface.co/BAAI/bge-small-en-v1.5",
similarity_fn_name="cosine",
similarity_fn_name=ScoringFunction.COSINE,
framework=["Sentence Transformers", "PyTorch"],
use_instructions=True,
citation=BGE_15_CITATION,
Expand Down Expand Up @@ -85,7 +85,7 @@
license="mit",
max_tokens=512,
reference="https://huggingface.co/BAAI/bge-base-en-v1.5",
similarity_fn_name="cosine",
similarity_fn_name=ScoringFunction.COSINE,
framework=["Sentence Transformers", "PyTorch"],
use_instructions=True,
citation=BGE_15_CITATION,
Expand Down Expand Up @@ -139,7 +139,7 @@
license="mit",
max_tokens=512,
reference="https://huggingface.co/BAAI/bge-large-en-v1.5",
similarity_fn_name="cosine",
similarity_fn_name=ScoringFunction.COSINE,
framework=["Sentence Transformers", "PyTorch"],
use_instructions=True,
citation=BGE_15_CITATION,
Expand Down
10 changes: 5 additions & 5 deletions mteb/models/cohere_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import tqdm

from mteb.encoder_interface import PromptType
from mteb.model_meta import ModelMeta
from mteb.model_meta import ModelMeta, ScoringFunction

from .wrapper import Wrapper

Expand Down Expand Up @@ -231,7 +231,7 @@ def encode(
embed_dim=512,
reference="https://cohere.com/blog/introducing-embed-v3",
license=None,
similarity_fn_name="cosine",
similarity_fn_name=ScoringFunction.COSINE,
framework=["API"],
use_instructions=True,
public_training_data=False, # assumed
Expand All @@ -255,7 +255,7 @@ def encode(
max_tokens=512,
embed_dim=1024,
license=None,
similarity_fn_name="cosine",
similarity_fn_name=ScoringFunction.COSINE,
framework=["API"],
use_instructions=True,
public_training_data=False, # assumed
Expand All @@ -279,7 +279,7 @@ def encode(
max_tokens=512,
embed_dim=384,
license=None,
similarity_fn_name="cosine",
similarity_fn_name=ScoringFunction.COSINE,
framework=["API"],
use_instructions=True,
public_training_data=False, # assumed
Expand All @@ -303,7 +303,7 @@ def encode(
max_tokens=512,
embed_dim=384,
license=None,
similarity_fn_name="cosine",
similarity_fn_name=ScoringFunction.COSINE,
framework=["API"],
use_instructions=True,
public_training_data=False, # assumed
Expand Down
6 changes: 3 additions & 3 deletions mteb/models/colbert_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import torch

from mteb.encoder_interface import PromptType
from mteb.model_meta import ModelMeta
from mteb.model_meta import ModelMeta, ScoringFunction

from .wrapper import Wrapper

Expand Down Expand Up @@ -158,7 +158,7 @@ def similarity(self, a: np.ndarray, b: np.ndarray) -> np.ndarray:
max_tokens=180, # Reduced for Benchmarking - see ColBERT paper
embed_dim=None, # Bag of Embeddings (128) for each token
license="mit",
similarity_fn_name="max_sim",
similarity_fn_name=ScoringFunction.MAX_SIM,
framework=["PyLate", "ColBERT"],
reference="https://huggingface.co/colbert-ir/colbertv2.0",
use_instructions=False,
Expand Down Expand Up @@ -209,7 +209,7 @@ def similarity(self, a: np.ndarray, b: np.ndarray) -> np.ndarray:
max_tokens=8192,
embed_dim=None, # Bag of Embeddings (128) for each token
license="cc-by-nc-4.0",
similarity_fn_name="max_sim",
similarity_fn_name=ScoringFunction.MAX_SIM,
framework=["PyLate", "ColBERT"],
reference="https://huggingface.co/jinaai/jina-colbert-v2",
use_instructions=False,
Expand Down
6 changes: 3 additions & 3 deletions mteb/models/e5_instruct.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import torch

from mteb.model_meta import ModelMeta
from mteb.model_meta import ModelMeta, ScoringFunction

from .e5_models import E5_PAPER_RELEASE_DATE, XLMR_LANGUAGES
from .instruct_wrapper import instruct_wrapper
Expand Down Expand Up @@ -32,7 +32,7 @@
revision="baa7be480a7de1539afce709c8f13f833a510e0a",
release_date=E5_PAPER_RELEASE_DATE,
framework=["GritLM", "PyTorch"],
similarity_fn_name="cosine",
similarity_fn_name=ScoringFunction.COSINE,
use_instructions=True,
reference="https://huggingface.co/intfloat/multilingual-e5-large-instruct",
n_parameters=560_000_000,
Expand Down Expand Up @@ -66,7 +66,7 @@
revision="07163b72af1488142a360786df853f237b1a3ca1",
release_date=E5_PAPER_RELEASE_DATE,
framework=["GritLM", "PyTorch"],
similarity_fn_name="cosine",
similarity_fn_name=ScoringFunction.COSINE,
use_instructions=True,
reference="https://huggingface.co/intfloat/e5-mistral-7b-instruct",
n_parameters=7_111_000_000,
Expand Down
Loading
Loading