diff --git a/bothub/shared/utils/rasa_components/registry.py b/bothub/shared/utils/rasa_components/registry.py index 96d9198..b25cf56 100644 --- a/bothub/shared/utils/rasa_components/registry.py +++ b/bothub/shared/utils/rasa_components/registry.py @@ -80,14 +80,8 @@ "bert_multilang": bert_embeddings_post_processor, } -model_config_url = { - "bert_portuguese": "https://bothub-nlp-models.s3.amazonaws.com/bert-portuguese/config.json", - "bert_english": "https://bothub-nlp-models.s3.amazonaws.com/bert-english/config.json", - "bert_multilang": "https://bothub-nlp-models.s3.amazonaws.com/bert_multilang/config.json" -} - -model_download_url = { - "bert_portuguese": "https://bothub-nlp-models.s3.amazonaws.com/bert-portuguese/pytorch_model.bin", - "bert_english": "https://bothub-nlp-models.s3.amazonaws.com/bert-english/tf_model.h5", - "bert_multilang": "https://bothub-nlp-models.s3.amazonaws.com/bert_multilang/tf_model.h5" +model_url = { + "bert_portuguese": "https://bothub-nlp-models.s3.amazonaws.com/bert/bert_portuguese.zip", + "bert_english": "https://bothub-nlp-models.s3.amazonaws.com/bert/bert_english.zip", + "bert_multilang": "https://bothub-nlp-models.s3.amazonaws.com/bert/bert_multilang.zip" } diff --git a/bothub/shared/utils/scripts/download_models.py b/bothub/shared/utils/scripts/download_models.py index f2e7ec8..671333a 100644 --- a/bothub/shared/utils/scripts/download_models.py +++ b/bothub/shared/utils/scripts/download_models.py @@ -11,7 +11,7 @@ import logging import plac import requests -import posixpath +import zipfile from decouple import config from spacy.cli import download @@ -21,11 +21,8 @@ sys.path.insert( 1, os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "..", "..")) ) -from bothub.shared.utils.rasa_components.registry import ( - from_pt_dict, - model_download_url, - model_config_url, -) + +from bothub.shared.utils.rasa_components.registry import model_url logger = logging.getLogger(__name__) @@ -55,19 +52,18 @@ def download_file(url, file_name): def download_bert(model_name): - # model_dir = posixpath.join("bothub", "nlu_worker", model_name) model_dir = model_name os.makedirs(model_dir, exist_ok=True) - from_pt = from_pt_dict.get(model_name, False) - model_url = model_download_url.get(model_name) - config_url = model_config_url.get(model_name) + zipped_file_name = "temp.zip" + url = model_url.get(model_name) + logger.info(f"downloading {model_name} . . .") + download_file(url, zipped_file_name) - logger.info("downloading bert") - model_file_name = "pytorch_model.bin" if from_pt else "tf_model.h5" - download_file(model_url, posixpath.join(model_dir, model_file_name)) - download_file(config_url, posixpath.join(model_dir, "config.json")) - logger.info("finished downloading bert") + logger.info(f"extracting {model_name} . . .") + with zipfile.ZipFile(zipped_file_name, 'r') as zip_ref: + zip_ref.extractall(model_dir) + os.remove(zipped_file_name) def cast_supported_languages(languages):