diff --git a/ersilia/core/model.py b/ersilia/core/model.py index a137be80e..7831be542 100644 --- a/ersilia/core/model.py +++ b/ersilia/core/model.py @@ -4,15 +4,15 @@ import time import types import asyncio -import importlib import collections -import __main__ as main +import sys from click import secho as echo # Style-aware echo from .. import logger from ..serve.api import Api from .session import Session +from ..hub.fetch.fetch import ModelFetcher from .base import ErsiliaBase from ..lake.base import LakeBase from ..utils import tmp_pid_file @@ -128,7 +128,7 @@ def __init__( else: self.logger.set_verbosity(0) else: - if not hasattr(main, "__file__"): + if hasattr(sys, 'ps1'): self.logger.set_verbosity(0) self.save_to_lake = save_to_lake if self.save_to_lake: @@ -152,6 +152,7 @@ def __init__( self.service_class = service_class mdl = ModelBase(model) self._is_valid = mdl.is_valid() + assert self._is_valid, "The identifier {0} is not valid. Please visit the Ersilia Model Hub for valid identifiers".format( model ) @@ -174,13 +175,11 @@ def __init__( self.logger.debug("Unable to capture user input. Fetching anyway.") do_fetch = True if do_fetch: - fetch = importlib.import_module("ersilia.hub.fetch.fetch") - mf = fetch.ModelFetcher( + mf = ModelFetcher( config_json=self.config_json, credentials_json=self.credentials_json ) asyncio.run(mf.fetch(self.model_id)) - else: - return + self.api_schema = ApiSchema( model_id=self.model_id, config_json=self.config_json ) @@ -213,30 +212,13 @@ def __init__( def fetch(self): """ - Fetch the model if not available locally. - - This method fetches the model from the Ersilia Model Hub if it is not available locally. + This method fetches the model from the Ersilia Model Hub. """ - if not self._is_available_locally and self.fetch_if_not_available: - self.logger.info("Model is not available locally") - try: - do_fetch = yes_no_input( - "Requested model {0} is not available locally. Do you want to fetch it? [Y/n]".format( - self.model_id - ), - default_answer="Y", - ) - except: - self.logger.debug("Unable to capture user input. Fetching anyway.") - do_fetch = True - if do_fetch: - fetch = importlib.import_module("ersilia.hub.fetch.fetch") - mf = fetch.ModelFetcher( - config_json=self.config_json, credentials_json=self.credentials_json - ) - asyncio.run(mf.fetch(self.model_id)) - else: - return + mf = ModelFetcher( + config_json=self.config_json, credentials_json=self.credentials_json + ) + asyncio.run(mf.fetch(self.model_id)) + def __enter__(self): """ diff --git a/test/test_models.py b/test/test_models.py index 630e562bf..91e1dd490 100644 --- a/test/test_models.py +++ b/test/test_models.py @@ -87,7 +87,7 @@ def mock_run(): @patch("ersilia.core.model.ErsiliaModel") -def test_models( +def test_model_with_prior_fetching( mock_ersilia_model, mock_fetcher, mock_session, @@ -124,3 +124,42 @@ def test_models( assert mock_serve.called assert mock_run.called assert mock_close.called + + +@patch("ersilia.core.model.ErsiliaModel") +def test_model_with_no_prior_fetching( + mock_ersilia_model, + mock_fetcher, + mock_session, + mock_set_apis, + mock_convn_api_get_apis, + mock_api_task, + mock_serve, + mock_run, + mock_close, +): + MODEL_ID = MODELS[1] + INPUT = "CCCC" + + em = ErsiliaModel( + model=MODEL_ID, service_class="docker", output_source="LOCAL_ONLY" + ) + + em.fetch() + + result = em.run( + input=INPUT, + output="result.csv", + batch_size=100, + track_run=False, + try_standard=False, + ) + + em.serve() + em.close() + + assert result == RESULTS[1] + assert mock_fetcher.called + assert mock_serve.called + assert mock_run.called + assert mock_close.called