Skip to content

Commit

Permalink
Merge pull request #313 from bothub-it/staging
Browse files (browse the repository at this point in the history)
Staging
  • Loading branch information
lucasagra authored Nov 6, 2021
2 parents f7829e2 + b6e023e commit df7a7ea
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 25 deletions.
14 changes: 4 additions & 10 deletions bothub/shared/utils/rasa_components/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,14 +80,8 @@
"bert_multilang": bert_embeddings_post_processor,
}

model_config_url = {
"bert_portuguese": "https://bothub-nlp-models.s3.amazonaws.com/bert-portuguese/config.json",
"bert_english": "https://bothub-nlp-models.s3.amazonaws.com/bert-english/config.json",
"bert_multilang": "https://bothub-nlp-models.s3.amazonaws.com/bert_multilang/config.json"
}

model_download_url = {
"bert_portuguese": "https://bothub-nlp-models.s3.amazonaws.com/bert-portuguese/pytorch_model.bin",
"bert_english": "https://bothub-nlp-models.s3.amazonaws.com/bert-english/tf_model.h5",
"bert_multilang": "https://bothub-nlp-models.s3.amazonaws.com/bert_multilang/tf_model.h5"
model_url = {
"bert_portuguese": "https://bothub-nlp-models.s3.amazonaws.com/bert/bert_portuguese.zip",
"bert_english": "https://bothub-nlp-models.s3.amazonaws.com/bert/bert_english.zip",
"bert_multilang": "https://bothub-nlp-models.s3.amazonaws.com/bert/bert_multilang.zip"
}
26 changes: 11 additions & 15 deletions bothub/shared/utils/scripts/download_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import logging
import plac
import requests
import posixpath
import zipfile

from decouple import config
from spacy.cli import download
Expand All @@ -21,11 +21,8 @@
sys.path.insert(
1, os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "..", ".."))
)
from bothub.shared.utils.rasa_components.registry import (
from_pt_dict,
model_download_url,
model_config_url,
)

from bothub.shared.utils.rasa_components.registry import model_url

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -55,19 +52,18 @@ def download_file(url, file_name):


def download_bert(model_name):
# model_dir = posixpath.join("bothub", "nlu_worker", model_name)
model_dir = model_name
os.makedirs(model_dir, exist_ok=True)

from_pt = from_pt_dict.get(model_name, False)
model_url = model_download_url.get(model_name)
config_url = model_config_url.get(model_name)
zipped_file_name = "temp.zip"
url = model_url.get(model_name)
logger.info(f"downloading {model_name} . . .")
download_file(url, zipped_file_name)

logger.info("downloading bert")
model_file_name = "pytorch_model.bin" if from_pt else "tf_model.h5"
download_file(model_url, posixpath.join(model_dir, model_file_name))
download_file(config_url, posixpath.join(model_dir, "config.json"))
logger.info("finished downloading bert")
logger.info(f"extracting {model_name} . . .")
with zipfile.ZipFile(zipped_file_name, 'r') as zip_ref:
zip_ref.extractall(model_dir)
os.remove(zipped_file_name)


def cast_supported_languages(languages):
Expand Down

0 comments on commit df7a7ea

Please sign in to comment.