From 8d033f39415cd00840a1e0f6305d453ca6032abf Mon Sep 17 00:00:00 2001 From: Isaac Chung Date: Thu, 9 Jan 2025 17:42:42 +0200 Subject: [PATCH] test: Add script to test model loading below n_parameters threshold (#1698) * add model loading test for models below 2B params * add failure message to include model namne * use the real get_model_meta * use cache folder * teardown per function * fix directory removal * write to file * wip loading from before * wip * Rename model_loading_testing.py to model_loading.py * Delete tests/test_models/test_model_loading.py * checks for models below 2B * try not using cache folder * update script with scan_cache_dir and add args * add github CI: detect changed model files and run model loading test * install all model dependencies * dependecy installations and move file location * should trigger a model load test in CI * find correct commit for diff * explicity fetch base branch * add make command * try to run in python instead and add pytest * fix attribute error and add read mode * separate script calling * let pip install be cached and specify repo path * check ancestry * add cache and rebase * try to merge instead of rebase * try without merge base * check if file exists first * Apply suggestions from code review Co-authored-by: Kenneth Enevoldsen * Update .github/workflows/model_loading.yml Co-authored-by: Kenneth Enevoldsen * address review comments to run test once from CI and not pytest --------- Co-authored-by: Kenneth Enevoldsen --- .github/workflows/model_loading.yml | 24 +++ .gitignore | 3 + Makefile | 9 +- mteb/models/instruct_wrapper.py | 2 +- pyproject.toml | 2 + scripts/extract_model_names.py | 63 +++++++ tests/test_models/model_load_failures.json | 197 +++++++++++++++++++++ tests/test_models/model_loading.py | 127 +++++++++++++ 8 files changed, 425 insertions(+), 2 deletions(-) create mode 100644 .github/workflows/model_loading.yml create mode 100644 scripts/extract_model_names.py create mode 100644 tests/test_models/model_load_failures.json create mode 100644 tests/test_models/model_loading.py diff --git a/.github/workflows/model_loading.yml b/.github/workflows/model_loading.yml new file mode 100644 index 0000000000..8707a9c1d6 --- /dev/null +++ b/.github/workflows/model_loading.yml @@ -0,0 +1,24 @@ +name: Model Loading + +on: + pull_request: + paths: + - 'mteb/models/**.py' + +jobs: + extract-and-run: + runs-on: ubuntu-latest + + steps: + - name: Checkout repository + uses: actions/checkout@v3 + + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: '3.10' + cache: 'pip' + + - name: Install dependencies and run tests + run: | + make model-load-test diff --git a/.gitignore b/.gitignore index 868f0f1745..977fe8dc1a 100644 --- a/.gitignore +++ b/.gitignore @@ -145,3 +145,6 @@ tests/create_meta/model_card.md # removed results from mteb repo they are now available at: https://github.com/embeddings-benchmark/results results/ uv.lock + +# model loading tests +model_names.txt \ No newline at end of file diff --git a/Makefile b/Makefile index c1404270d9..6e8647a2ce 100644 --- a/Makefile +++ b/Makefile @@ -35,4 +35,11 @@ pr: build-docs: @echo "--- 📚 Building documentation ---" # since we do not have a documentation site, this just build tables for the .md files - python docs/create_tasks_table.py \ No newline at end of file + python docs/create_tasks_table.py + + +model-load-test: + @echo "--- 🚀 Running model load test ---" + pip install ".[dev, speedtask, pylate,gritlm,xformers,model2vec]" + python scripts/extract_model_names.py + python tests/test_models/model_loading.py --model_name_file scripts/model_names.txt \ No newline at end of file diff --git a/mteb/models/instruct_wrapper.py b/mteb/models/instruct_wrapper.py index 303a386836..2ee3a09b56 100644 --- a/mteb/models/instruct_wrapper.py +++ b/mteb/models/instruct_wrapper.py @@ -24,7 +24,7 @@ def instruct_wrapper( from gritlm import GritLM except ImportError: raise ImportError( - f"Please install `pip install gritlm` to use {model_name_or_path}." + f"Please install `pip install mteb[gritlm]` to use {model_name_or_path}." ) class InstructWrapper(GritLM, Wrapper): diff --git a/pyproject.toml b/pyproject.toml index ed02ec8845..c8e6f51885 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -65,6 +65,8 @@ openai = ["openai>=1.41.0", "tiktoken>=0.8.0"] model2vec = ["model2vec>=0.3.0"] pylate = ["pylate>=1.1.4"] bm25s = ["bm25s>=0.2.6", "PyStemmer>=2.2.0.3"] +gritlm = ["gritlm>=1.0.2"] +xformers = ["xformers>=0.0.29"] [tool.coverage.report] diff --git a/scripts/extract_model_names.py b/scripts/extract_model_names.py new file mode 100644 index 0000000000..dbe99a990e --- /dev/null +++ b/scripts/extract_model_names.py @@ -0,0 +1,63 @@ +from __future__ import annotations + +import ast +import sys +from pathlib import Path + +from git import Repo + + +def get_changed_files(base_branch="main"): + repo_path = Path(__file__).parent.parent + repo = Repo(repo_path) + repo.remotes.origin.fetch(base_branch) + + base_commit = repo.commit(f"origin/{base_branch}") + head_commit = repo.commit("HEAD") + + diff = repo.git.diff("--name-only", base_commit, head_commit) + + changed_files = diff.splitlines() + return [ + f for f in changed_files if f.startswith("mteb/models/") and f.endswith(".py") + ] + + +def extract_model_names(files: list[str]) -> list[str]: + model_names = [] + for file in files: + with open(file) as f: + tree = ast.parse(f.read()) + for node in ast.walk(tree): + if isinstance(node, ast.Assign): + for target in node.targets: + if ( + isinstance(target, ast.Name) + and isinstance(node.value, ast.Call) + and isinstance(node.value.func, ast.Name) + and node.value.func.id == "ModelMeta" + ): + model_name = next( + ( + kw.value.value + for kw in node.value.keywords + if kw.arg == "name" + ), + None, + ) + if model_name: + model_names.append(model_name) + return model_names + + +if __name__ == "__main__": + """ + Can pass in base branch as an argument. Defaults to 'main'. + e.g. python extract_model_names.py mieb + """ + base_branch = sys.argv[1] if len(sys.argv) > 1 else "main" + changed_files = get_changed_files(base_branch) + model_names = extract_model_names(changed_files) + output_file = Path(__file__).parent / "model_names.txt" + with output_file.open("w") as f: + f.write(" ".join(model_names)) diff --git a/tests/test_models/model_load_failures.json b/tests/test_models/model_load_failures.json new file mode 100644 index 0000000000..f1be1c940b --- /dev/null +++ b/tests/test_models/model_load_failures.json @@ -0,0 +1,197 @@ +{ + "Alibaba-NLP/gte-Qwen1.5-7B-instruct": "Over threshold. Not tested.", + "Alibaba-NLP/gte-Qwen2-1.5B-instruct": "None", + "Alibaba-NLP/gte-Qwen2-7B-instruct": "Over threshold. Not tested.", + "BAAI/bge-base-en-v1.5": "None", + "BAAI/bge-large-en-v1.5": "Over threshold. Not tested.", + "BAAI/bge-reranker-v2-m3": "None", + "BAAI/bge-small-en-v1.5": "None", + "BAAI/bge-small-en-v1.5 BAAI/bge-base-en-v1.5 BAAI/bge-large-en-v1.5": null, + "BeastyZ/e5-R-mistral-7b": "Over threshold. Not tested.", + "Cohere/Cohere-embed-english-light-v3.0": "None", + "Cohere/Cohere-embed-english-v3.0": "None", + "Cohere/Cohere-embed-multilingual-light-v3.0": "None", + "Cohere/Cohere-embed-multilingual-v3.0": "None", + "DeepPavlov/distilrubert-small-cased-conversational": "None", + "DeepPavlov/rubert-base-cased": "None", + "DeepPavlov/rubert-base-cased-sentence": "None", + "Gameselo/STS-multilingual-mpnet-base-v2": "None", + "GritLM/GritLM-7B": "Over threshold. Not tested.", + "GritLM/GritLM-8x7B": "Over threshold. Not tested.", + "HIT-TMG/KaLM-embedding-multilingual-mini-instruct-v1": "None", + "HIT-TMG/KaLM-embedding-multilingual-mini-v1": "None", + "Haon-Chen/speed-embedding-7b-instruct": "Over threshold. Not tested.", + "Hum-Works/lodestone-base-4096-v1": "None", + "Jaume/gemma-2b-embeddings": "Over threshold. Not tested.", + "Lajavaness/bilingual-embedding-base": "None", + "Lajavaness/bilingual-embedding-large": "None", + "Lajavaness/bilingual-embedding-small": "None", + "Linq-AI-Research/Linq-Embed-Mistral": "Over threshold. Not tested.", + "McGill-NLP/LLM2Vec-Llama-2-7b-chat-hf-mntp-supervised": "Over threshold. Not tested.", + "McGill-NLP/LLM2Vec-Llama-2-7b-chat-hf-mntp-unsup-simcse": "Over threshold. Not tested.", + "McGill-NLP/LLM2Vec-Meta-Llama-3-8B-Instruct-mntp-supervised": "Over threshold. Not tested.", + "McGill-NLP/LLM2Vec-Meta-Llama-3-8B-Instruct-mntp-unsup-simcse": "Over threshold. Not tested.", + "McGill-NLP/LLM2Vec-Mistral-7B-Instruct-v2-mntp-supervised": "Over threshold. Not tested.", + "McGill-NLP/LLM2Vec-Mistral-7B-Instruct-v2-mntp-unsup-simcse": "Over threshold. Not tested.", + "McGill-NLP/LLM2Vec-Sheared-LLaMA-mntp-supervised": "Over threshold. Not tested.", + "McGill-NLP/LLM2Vec-Sheared-LLaMA-mntp-unsup-simcse": "Over threshold. Not tested.", + "Mihaiii/Bulbasaur": "None", + "Mihaiii/Ivysaur": "None", + "Mihaiii/Squirtle": "None", + "Mihaiii/Venusaur": "None", + "Mihaiii/Wartortle": "None", + "Mihaiii/gte-micro": "None", + "Mihaiii/gte-micro-v4": "None", + "Omartificial-Intelligence-Space/Arabert-all-nli-triplet-Matryoshka": "None", + "Omartificial-Intelligence-Space/Arabic-MiniLM-L12-v2-all-nli-triplet": "None", + "Omartificial-Intelligence-Space/Arabic-all-nli-triplet-Matryoshka": "None", + "Omartificial-Intelligence-Space/Arabic-labse-Matryoshka": "None", + "Omartificial-Intelligence-Space/Arabic-mpnet-base-all-nli-triplet": "None", + "Omartificial-Intelligence-Space/Marbert-all-nli-triplet-Matryoshka": "None", + "OrdalieTech/Solon-embeddings-large-0.1": "None", + "OrlikB/KartonBERT-USE-base-v1": "None", + "OrlikB/st-polish-kartonberta-base-alpha-v1": "None", + "Salesforce/SFR-Embedding-2_R": "Over threshold. Not tested.", + "Salesforce/SFR-Embedding-Mistral": "Over threshold. Not tested.", + "Snowflake/snowflake-arctic-embed-l": "None", + "Snowflake/snowflake-arctic-embed-l-v2.0": "None", + "Snowflake/snowflake-arctic-embed-m": "None", + "Snowflake/snowflake-arctic-embed-m-long": "None", + "Snowflake/snowflake-arctic-embed-m-v1.5": "None", + "Snowflake/snowflake-arctic-embed-m-v2.0": "None", + "Snowflake/snowflake-arctic-embed-s": "None", + "Snowflake/snowflake-arctic-embed-xs": "None", + "WhereIsAI/UAE-Large-V1": "None", + "aari1995/German_Semantic_STS_V2": "None", + "abhinand/MedEmbed-small-v0.1": "None", + "ai-forever/ru-en-RoSBERTa": "None", + "ai-forever/sbert_large_mt_nlu_ru": "None", + "ai-forever/sbert_large_nlu_ru": "None", + "avsolatorio/GIST-Embedding-v0": "None", + "avsolatorio/GIST-all-MiniLM-L6-v2": "None", + "avsolatorio/GIST-large-Embedding-v0": "None", + "avsolatorio/GIST-small-Embedding-v0": "None", + "avsolatorio/NoInstruct-small-Embedding-v0": "None", + "bigscience/sgpt-bloom-7b1-msmarco": "None", + "bm25s": "None", + "brahmairesearch/slx-v0.1": "None", + "castorini/monobert-large-msmarco": "None", + "castorini/monot5-3b-msmarco-10k": "None", + "castorini/monot5-base-msmarco-10k": "None", + "castorini/monot5-large-msmarco-10k": "None", + "castorini/monot5-small-msmarco-10k": "None", + "castorini/repllama-v1-7b-lora-passage": "You are trying to access a gated repo.\nMake sure to have access to it at https://huggingface.co/meta-llama/Llama-2-7b-hf.\n401 Client Error. (Request ID: Root=1-67794457-7e56cbf325381c760c430207;a79cc472-a4fc-49dc-80f0-9d4b8cb5ef42)\n\nCannot access gated repo for url https://huggingface.co/meta-llama/Llama-2-7b-hf/resolve/main/config.json.\nAccess to model meta-llama/Llama-2-7b-hf is restricted. You must have access to it and be authenticated to access it. Please log in.", + "cointegrated/LaBSE-en-ru": "None", + "cointegrated/rubert-tiny": "None", + "cointegrated/rubert-tiny2": "None", + "colbert-ir/colbertv2.0": "None", + "consciousAI/cai-lunaris-text-embeddings": "None", + "consciousAI/cai-stellaris-text-embeddings": "None", + "deepfile/embedder-100p": "None", + "deepvk/USER-base": "None", + "deepvk/USER-bge-m3": "None", + "deepvk/deberta-v1-base": "None", + "dunzhang/stella_en_1.5B_v5": "None", + "dunzhang/stella_en_400M_v5": "None", + "dwzhu/e5-base-4k": "None", + "google/flan-t5-base": "None", + "google/flan-t5-large": "None", + "google/flan-t5-xl": "None", + "google/flan-t5-xxl": "None", + "google/text-embedding-004": "None", + "google/text-embedding-005": "None", + "google/text-multilingual-embedding-002": "None", + "ibm-granite/granite-embedding-107m-multilingual": "None", + "ibm-granite/granite-embedding-125m-english": "None", + "ibm-granite/granite-embedding-278m-multilingual": "None", + "ibm-granite/granite-embedding-30m-english": "None", + "infgrad/jasper_en_vision_language_v1": "Over threshold. Not tested.", + "infgrad/stella-base-en-v2": "None", + "intfloat/e5-base": "None", + "intfloat/e5-base-v2": "None", + "intfloat/e5-large": "None", + "intfloat/e5-large-v2": "None", + "intfloat/e5-mistral-7b-instruct": "Over threshold. Not tested.", + "intfloat/e5-small": "None", + "intfloat/e5-small-v2": "None", + "intfloat/multilingual-e5-base": "None", + "intfloat/multilingual-e5-large": "None", + "intfloat/multilingual-e5-large-instruct": "None", + "intfloat/multilingual-e5-small": "None", + "izhx/udever-bloom-1b1": "None", + "izhx/udever-bloom-3b": "None", + "izhx/udever-bloom-560m": "None", + "izhx/udever-bloom-7b1": "None", + "jhu-clsp/FollowIR-7B": "None", + "jinaai/jina-colbert-v2": "None", + "jinaai/jina-embedding-b-en-v1": "None", + "jinaai/jina-embedding-s-en-v1": "None", + "jinaai/jina-embeddings-v2-base-en": "None", + "jinaai/jina-embeddings-v2-small-en": "None", + "jinaai/jina-embeddings-v3": "None", + "jinaai/jina-reranker-v2-base-multilingual": "None", + "keeeeenw/MicroLlama-text-embedding": "None", + "malenia1/ternary-weight-embedding": "None", + "manu/bge-m3-custom-fr": "None", + "manu/sentence_croissant_alpha_v0.2": "None", + "manu/sentence_croissant_alpha_v0.3": "Over threshold. Not tested.", + "manu/sentence_croissant_alpha_v0.4": "Over threshold. Not tested.", + "meta-llama/Llama-2-7b-chat-hf": "None", + "meta-llama/Llama-2-7b-hf": "None", + "minishlab/M2V_base_glove": "None", + "minishlab/M2V_base_glove_subword": "None", + "minishlab/M2V_base_output": "None", + "minishlab/M2V_multilingual_output": "None", + "minishlab/potion-base-2M": "None", + "minishlab/potion-base-4M": "None", + "minishlab/potion-base-8M": "None", + "mistralai/Mistral-7B-Instruct-v0.2": "None", + "mixedbread-ai/mxbai-embed-large-v1": "None", + "nomic-ai/nomic-embed-text-v1": "None", + "nomic-ai/nomic-embed-text-v1-ablated": "None", + "nomic-ai/nomic-embed-text-v1-unsupervised": "None", + "nomic-ai/nomic-embed-text-v1.5": "None", + "nvidia/NV-Embed-v1": "Over threshold. Not tested.", + "nvidia/NV-Embed-v2": "Over threshold. Not tested.", + "omarelshehy/arabic-english-sts-matryoshka": "None", + "openai/text-embedding-3-large": "None", + "openai/text-embedding-3-small": "None", + "openai/text-embedding-ada-002": "None", + "openbmb/MiniCPM-Embedding": "Over threshold. Not tested.", + "samaya-ai/RepLLaMA-reproduced": "You are trying to access a gated repo.\nMake sure to have access to it at https://huggingface.co/meta-llama/Llama-2-7b-hf.\n401 Client Error. (Request ID: Root=1-6779403c-1bd84d333e938afa4e7cf86b;b873eea6-3c10-4659-b6da-2288d83e721b)\n\nCannot access gated repo for url https://huggingface.co/meta-llama/Llama-2-7b-hf/resolve/main/config.json.\nAccess to model meta-llama/Llama-2-7b-hf is restricted. You must have access to it and be authenticated to access it. Please log in.", + "samaya-ai/promptriever-llama2-7b-v1": "You are trying to access a gated repo.\nMake sure to have access to it at https://huggingface.co/meta-llama/Llama-2-7b-hf.\n401 Client Error. (Request ID: Root=1-677940f7-6c2bfcaa7985abb1165185ff;efdd2ef8-60a0-45c3-a92b-b24784b30b43)\n\nCannot access gated repo for url https://huggingface.co/meta-llama/Llama-2-7b-hf/resolve/main/config.json.\nAccess to model meta-llama/Llama-2-7b-hf is restricted. You must have access to it and be authenticated to access it. Please log in.", + "samaya-ai/promptriever-llama3.1-8b-instruct-v1": "You are trying to access a gated repo.\nMake sure to have access to it at https://huggingface.co/meta-llama/Meta-Llama-3.1-8B-Instruct.\n401 Client Error. (Request ID: Root=1-6779430b-3277d7961f3c88ab56ecf91f;a476a013-b28f-47c6-bd95-e3d6fe823468)\n\nCannot access gated repo for url https://huggingface.co/meta-llama/Meta-Llama-3.1-8B-Instruct/resolve/main/config.json.\nAccess to model meta-llama/Llama-3.1-8B-Instruct is restricted. You must have access to it and be authenticated to access it. Please log in.", + "samaya-ai/promptriever-llama3.1-8b-v1": "You are trying to access a gated repo.\nMake sure to have access to it at https://huggingface.co/meta-llama/Meta-Llama-3.1-8B.\n401 Client Error. (Request ID: Root=1-677bba8f-608cf825273d8d2b0670b5ad;066bb2fa-3bef-4fb9-b3cb-4c5ffee41047)\n\nCannot access gated repo for url https://huggingface.co/meta-llama/Meta-Llama-3.1-8B/resolve/main/config.json.\nAccess to model meta-llama/Llama-3.1-8B is restricted. You must have access to it and be authenticated to access it. Please log in.", + "samaya-ai/promptriever-mistral-v0.1-7b-v1": "You are trying to access a gated repo.\nMake sure to have access to it at https://huggingface.co/mistralai/Mistral-7B-v0.1.\n401 Client Error. (Request ID: Root=1-67794457-688a6d9c24a9e8f15cf70d28;da3a233f-7c7c-4919-9cee-72a1d66acdb6)\n\nCannot access gated repo for url https://huggingface.co/mistralai/Mistral-7B-v0.1/resolve/main/config.json.\nAccess to model mistralai/Mistral-7B-v0.1 is restricted. You must have access to it and be authenticated to access it. Please log in.", + "sdadas/mmlw-e5-base": "None", + "sdadas/mmlw-e5-large": "None", + "sdadas/mmlw-e5-small": "None", + "sdadas/mmlw-roberta-base": "None", + "sdadas/mmlw-roberta-large": "None", + "sentence-transformer/multi-qa-MiniLM-L6-cos-v1": "sentence-transformer/multi-qa-MiniLM-L6-cos-v1 is not a local folder and is not a valid model identifier listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a token having permission to this repo either by logging in with `huggingface-cli login` or by passing `token=`", + "sentence-transformers/LaBSE": "None", + "sentence-transformers/all-MiniLM-L12-v2": "None", + "sentence-transformers/all-MiniLM-L6-v2": "None", + "sentence-transformers/all-mpnet-base-v2": "None", + "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2": "None", + "sentence-transformers/paraphrase-multilingual-mpnet-base-v2": "None", + "sergeyzh/LaBSE-ru-turbo": "None", + "sergeyzh/rubert-tiny-turbo": "None", + "shibing624/text2vec-base-multilingual": "None", + "silma-ai/silma-embeddding-matryoshka-v0.1": "None", + "thenlper/gte-base": "None", + "thenlper/gte-large": "None", + "thenlper/gte-small": "None", + "unicamp-dl/mt5-13b-mmarco-100k": "None", + "unicamp-dl/mt5-base-mmarco-v2": "None", + "voyage-large-2": "None", + "voyageai/voyage-2": "None", + "voyageai/voyage-3": "None", + "voyageai/voyage-3-lite": "None", + "voyageai/voyage-code-2": "None", + "voyageai/voyage-finance-2": "None", + "voyageai/voyage-large-2-instruct": "None", + "voyageai/voyage-law-2": "None", + "voyageai/voyage-multilingual-2": "None", + "zeta-alpha-ai/Zeta-Alpha-E5-Mistral": "Over threshold. Not tested." +} \ No newline at end of file diff --git a/tests/test_models/model_loading.py b/tests/test_models/model_loading.py new file mode 100644 index 0000000000..3f22db733f --- /dev/null +++ b/tests/test_models/model_loading.py @@ -0,0 +1,127 @@ +from __future__ import annotations + +import argparse +import json +import logging +from pathlib import Path + +from huggingface_hub import scan_cache_dir + +from mteb import get_model, get_model_meta +from mteb.models.overview import MODEL_REGISTRY + +logging.basicConfig(level=logging.INFO) + + +def teardown_function(): + hf_cache_info = scan_cache_dir() + all_revisions = [] + for repo in list(hf_cache_info.repos): + for revision in list(repo.revisions): + all_revisions.append(revision.commit_hash) + + delete_strategy = scan_cache_dir().delete_revisions(*all_revisions) + print("Will free " + delete_strategy.expected_freed_size_str) + delete_strategy.execute() + + +def get_model_below_n_param_threshold(model_name: str) -> str: + """Test that we can get all models with a number of parameters below a threshold.""" + model_meta = get_model_meta(model_name=model_name) + assert model_meta is not None + if model_meta.n_parameters is not None: + if model_meta.n_parameters >= 2e9: + return "Over threshold. Not tested." + elif "API" in model_meta.framework: + try: + m = get_model(model_name) + if m is not None: + del m + return "None" + except Exception as e: + logging.warning(f"Failed to load model {model_name} with error {e}") + return e.__str__() + try: + m = get_model(model_name) + if m is not None: + del m + return "None" + except Exception as e: + logging.warning(f"Failed to load model {model_name} with error {e}") + return e.__str__() + finally: + teardown_function() + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--omit_previous_success", + action="store_true", + default=False, + help="Omit models that have been successfully loaded in the past", + ) + parser.add_argument( + "--run_missing", + action="store_true", + default=False, + help="Run the missing models in the registry that are missing from existing results.", + ) + parser.add_argument( + "--model_name", + type=str, + nargs="+", + default=None, + help="Run the script for specific model names, e.g. model_1, model_2", + ) + parser.add_argument( + "--model_name_file", + type=str, + default=None, + help="Filename containing space-separated model names to test.", + ) + + return parser.parse_args() + + +if __name__ == "__main__": + output_file = Path(__file__).parent / "model_load_failures.json" + + args = parse_args() + + # Load existing results if the file exists + results = {} + if output_file.exists(): + with output_file.open("r") as f: + results = json.load(f) + + if args.model_name: + all_model_names = args.model_name + elif args.model_name_file: + all_model_names = [] + if Path(args.model_name_file).exists(): + with open(args.model_name_file) as f: + all_model_names = f.read().strip().split() + else: + logging.warning( + f"Model name file {args.model_name_file} does not exist. Exiting." + ) + exit(1) + else: + omit_keys = [] + if args.run_missing: + omit_keys = list(results.keys()) + elif args.omit_previous_success: + omit_keys = [k for k, v in results.items() if v == "None"] + + all_model_names = list(set(MODEL_REGISTRY.keys()) - set(omit_keys)) + + for model_name in all_model_names: + error_msg = get_model_below_n_param_threshold(model_name) + results[model_name] = error_msg + + results = dict(sorted(results.items())) + + # Write the results to the file after each iteration + with output_file.open("w") as f: + json.dump(results, f, indent=4)