test: Add script to test model loading below n_parameters threshold #1698

Merged 35 commits on Jan 9, 2025 (changes shown from 33 commits).
76d2e96
add model loading test for models below 2B params
isaac-chung Jan 3, 2025
a9d0c44
add failure message to include model name
isaac-chung Jan 3, 2025
766aad2
use the real get_model_meta
isaac-chung Jan 3, 2025
230d4f2
use cache folder
isaac-chung Jan 3, 2025
99abdb5
teardown per function
isaac-chung Jan 3, 2025
0cbdaa0
fix directory removal
isaac-chung Jan 3, 2025
59cc65b
write to file
isaac-chung Jan 4, 2025
ea1d21f
wip loading from before
isaac-chung Jan 4, 2025
129e8cc
wip
isaac-chung Jan 4, 2025
8fbb48f
Rename model_loading_testing.py to model_loading.py
isaac-chung Jan 4, 2025
fb95ee7
Delete tests/test_models/test_model_loading.py
isaac-chung Jan 4, 2025
41c4b5c
checks for models below 2B
isaac-chung Jan 5, 2025
9af61d0
try not using cache folder
isaac-chung Jan 5, 2025
b8777d1
update script with scan_cache_dir and add args
isaac-chung Jan 6, 2025
bd56f86
add github CI: detect changed model files and run model loading test
isaac-chung Jan 6, 2025
dcdd80a
install all model dependencies
isaac-chung Jan 6, 2025
64d9c83
dependency installations and move file location
isaac-chung Jan 7, 2025
0eef873
should trigger a model load test in CI
isaac-chung Jan 7, 2025
86ad348
find correct commit for diff
isaac-chung Jan 7, 2025
9cf1280
explicitly fetch base branch
isaac-chung Jan 7, 2025
3982311
add make command
isaac-chung Jan 7, 2025
6fbaf0f
try to run in python instead and add pytest
isaac-chung Jan 7, 2025
8830034
fix attribute error and add read mode
isaac-chung Jan 7, 2025
b1c2021
separate script calling
isaac-chung Jan 7, 2025
fc89ce0
Merge branch 'main' of https://github.com/embeddings-benchmark/mteb i…
isaac-chung Jan 7, 2025
d843138
let pip install be cached and specify repo path
isaac-chung Jan 7, 2025
f994ab1
check ancestry
isaac-chung Jan 7, 2025
95d804d
add cache and rebase
isaac-chung Jan 7, 2025
a85a2cd
try to merge instead of rebase
isaac-chung Jan 7, 2025
609c883
try without merge base
isaac-chung Jan 8, 2025
44ccf08
check if file exists first
isaac-chung Jan 8, 2025
d479c5f
Apply suggestions from code review
isaac-chung Jan 8, 2025
fb26eab
Update .github/workflows/model_loading.yml
isaac-chung Jan 8, 2025
3dcaa96
Merge branch 'main' into add-model-load-test-below-n_param_threshold
isaac-chung Jan 9, 2025
a9ffc88
address review comments to run test once from CI and not pytest
isaac-chung Jan 9, 2025
24 changes: 24 additions & 0 deletions .github/workflows/model_loading.yml
@@ -0,0 +1,24 @@
name: Model Loading

on:
pull_request:
paths:
- 'mteb/models/**.py'

jobs:
extract-and-run:
runs-on: ubuntu-latest

steps:
- name: Checkout repository
uses: actions/checkout@v3

- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: '3.10'
cache: 'pip'

- name: Install dependencies and run tests
run: |
make model-load-test
3 changes: 3 additions & 0 deletions .gitignore
@@ -145,3 +145,6 @@ tests/create_meta/model_card.md
# removed results from mteb repo they are now available at: https://github.com/embeddings-benchmark/results
results/
uv.lock

# model loading tests
model_names.txt
9 changes: 8 additions & 1 deletion Makefile
@@ -35,4 +35,11 @@ pr:
build-docs:
@echo "--- 📚 Building documentation ---"
# since we do not have a documentation site, this just build tables for the .md files
python docs/create_tasks_table.py


model-load-test:
@echo "--- 🚀 Running model load test ---"
pip install ".[dev,speedtask,pylate,gritlm,xformers,model2vec]"
python scripts/extract_model_names.py
python tests/test_models/test_model_loading.py
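The make target above chains the two scripts through a plain-text handoff: `extract_model_names.py` writes the changed model names, space-separated, into `model_names.txt` (hence the new `.gitignore` entry), and the loading test reads them back. A minimal sketch of that handoff — the temp directory and model names here are illustrative, not from the PR:

```python
# Sketch of the model_names.txt handoff between the two scripts run by
# `make model-load-test`. Directory and model names are made up.
import tempfile
from pathlib import Path

with tempfile.TemporaryDirectory() as tmp:
    output_file = Path(tmp) / "model_names.txt"

    # extract_model_names.py side: write space-separated names
    model_names = ["org/model-a", "org/model-b"]
    output_file.write_text(" ".join(model_names))

    # test_model_loading.py side: read them back
    loaded = output_file.read_text().split()

print(loaded)
```

Space-separated plain text keeps the handoff shell-friendly, at the cost of not supporting names that contain spaces.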
2 changes: 1 addition & 1 deletion mteb/models/instruct_wrapper.py
@@ -24,7 +24,7 @@ def instruct_wrapper(
from gritlm import GritLM
except ImportError:
raise ImportError(
f"Please install `pip install gritlm` to use {model_name_or_path}."
f"Please install `pip install mteb[gritlm]` to use {model_name_or_path}."
)

class InstructWrapper(GritLM, Wrapper):
2 changes: 2 additions & 0 deletions pyproject.toml
@@ -65,6 +65,8 @@ openai = ["openai>=1.41.0", "tiktoken>=0.8.0"]
model2vec = ["model2vec>=0.3.0"]
pylate = ["pylate>=1.1.4"]
bm25s = ["bm25s>=0.2.6", "PyStemmer>=2.2.0.3"]
gritlm = ["gritlm>=1.0.2"]
xformers = ["xformers>=0.0.29"]


[tool.coverage.report]
63 changes: 63 additions & 0 deletions scripts/extract_model_names.py
@@ -0,0 +1,63 @@
from __future__ import annotations

import ast
import sys
from pathlib import Path

from git import Repo


def get_changed_files(base_branch="main"):
repo_path = Path(__file__).parent.parent
repo = Repo(repo_path)
repo.remotes.origin.fetch(base_branch)

base_commit = repo.commit(f"origin/{base_branch}")
head_commit = repo.commit("HEAD")

diff = repo.git.diff("--name-only", base_commit, head_commit)

changed_files = diff.splitlines()
return [
f for f in changed_files if f.startswith("mteb/models/") and f.endswith(".py")
]


def extract_model_names(files: list[str]) -> list[str]:
model_names = []
for file in files:
with open(file) as f:
tree = ast.parse(f.read())
for node in ast.walk(tree):
if isinstance(node, ast.Assign):
for target in node.targets:
if (
isinstance(target, ast.Name)
and isinstance(node.value, ast.Call)
and isinstance(node.value.func, ast.Name)
and node.value.func.id == "ModelMeta"
):
model_name = next(
(
kw.value.value
for kw in node.value.keywords
if kw.arg == "name"
),
None,
)
if model_name:
model_names.append(model_name)
return model_names


if __name__ == "__main__":
"""
Can pass in base branch as an argument. Defaults to 'main'.
e.g. python extract_model_names.py mieb
"""
base_branch = sys.argv[1] if len(sys.argv) > 1 else "main"
changed_files = get_changed_files(base_branch)
model_names = extract_model_names(changed_files)
output_file = Path(__file__).parent / "model_names.txt"
with output_file.open("w") as f:
f.write(" ".join(model_names))
116 changes: 116 additions & 0 deletions scripts/model_loading.py
@@ -0,0 +1,116 @@
from __future__ import annotations

import argparse
import json
import logging
from pathlib import Path

from huggingface_hub import scan_cache_dir

from mteb import get_model, get_model_meta
from mteb.models.overview import MODEL_REGISTRY

logging.basicConfig(level=logging.INFO)


def teardown_function():
hf_cache_info = scan_cache_dir()
all_revisions = []
for repo in list(hf_cache_info.repos):
for revision in list(repo.revisions):
all_revisions.append(revision.commit_hash)

delete_strategy = scan_cache_dir().delete_revisions(*all_revisions)
print("Will free " + delete_strategy.expected_freed_size_str)
delete_strategy.execute()


def get_model_below_n_param_threshold(model_name: str) -> str:
"""Test that we can get all models with a number of parameters below a threshold."""
model_meta = get_model_meta(model_name=model_name)
assert model_meta is not None
if model_meta.n_parameters is not None:
if model_meta.n_parameters >= 2e9:
return "Over threshold. Not tested."
elif "API" in model_meta.framework:
try:
m = get_model(model_name)
if m is not None:
del m
return "None"
except Exception as e:
logging.warning(f"Failed to load model {model_name} with error {e}")
return str(e)
try:
m = get_model(model_name)
if m is not None:
del m
return "None"
except Exception as e:
logging.warning(f"Failed to load model {model_name} with error {e}")
return str(e)
finally:
teardown_function()


def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--omit_previous_success",
action="store_true",
default=False,
help="Omit models that have been successfully loaded in the past",
)
parser.add_argument(
"--run_missing",
action="store_true",
default=False,
help="Run the missing models in the registry that are missing from existing results.",
)
parser.add_argument(
"--model_name",
type=str,
nargs="+",
default=None,
help="Run the script for specific model names, e.g. model_1, model_2",
)

return parser.parse_args()


if __name__ == "__main__":
output_file = (
Path(__file__).parent.parent
/ "tests"
/ "test_models"
/ "model_load_failures.json"
)

args = parse_args()

# Load existing results if the file exists
results = {}
if output_file.exists():
with output_file.open("r") as f:
results = json.load(f)

if args.model_name:
all_model_names = args.model_name
else:
omit_keys = []
if args.run_missing:
omit_keys = list(results.keys())
elif args.omit_previous_success:
omit_keys = [k for k, v in results.items() if v == "None"]

all_model_names = list(set(MODEL_REGISTRY.keys()) - set(omit_keys))

for model_name in all_model_names:
error_msg = get_model_below_n_param_threshold(model_name)
results[model_name] = error_msg

results = dict(sorted(results.items()))

# Write the results to the file after each iteration
with output_file.open("w") as f:
json.dump(results, f, indent=4)
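The flag handling above reduces to a set difference over the registry. A hedged sketch of that selection logic, with made-up model names and a `results` dict mimicking `model_load_failures.json` (where the string `"None"` marks a previous success):

```python
# Sketch of the --run_missing / --omit_previous_success selection logic.
# Registry contents and results values are illustrative.
registry = {"model-a", "model-b", "model-c", "model-d"}
results = {"model-a": "None", "model-b": "ImportError: ..."}

def models_to_run(registry, results, run_missing=False, omit_previous_success=False):
    if run_missing:
        omit = set(results)  # skip anything with a recorded result, pass or fail
    elif omit_previous_success:
        omit = {k for k, v in results.items() if v == "None"}  # retry failures
    else:
        omit = set()  # run everything
    return sorted(registry - omit)

print(models_to_run(registry, results, run_missing=True))
```

Because results are re-merged into the JSON file after each run, repeated invocations with `--omit_previous_success` converge on just the persistently failing models.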