Skip to content

Commit

Permalink
track served models with their session info so if they are deleted wh…
Browse files Browse the repository at this point in the history
…ile being served, ewe can safely remove their session data (#1474)
  • Loading branch information
DhanshreeA authored Dec 30, 2024
1 parent ad004a3 commit b3112b3
Show file tree
Hide file tree
Showing 5 changed files with 85 additions and 5 deletions.
3 changes: 2 additions & 1 deletion ersilia/cli/commands/close.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from .. import echo
from ... import ErsiliaModel
from ...core.session import Session

from ...utils.session import deregister_model_session

def close_cmd():
"""
Expand Down Expand Up @@ -35,6 +35,7 @@ def close():
return
mdl = ErsiliaModel(model_id, service_class=service_class)
mdl.close()
deregister_model_session(model_id)
echo(":no_entry: Model {0} closed".format(mdl.model_id), fg="green")

return close
3 changes: 3 additions & 0 deletions ersilia/cli/commands/serve.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from ... import ErsiliaModel
from ..messages import ModelNotFound
from ...store.utils import OutputSource, ModelNotInStore, store_has_model
from ...utils.session import register_model_session


def serve_cmd():
Expand Down Expand Up @@ -76,6 +77,8 @@ def serve(model, output_source, lake, port, track):
if mdl.url is None:
echo("No URL found. Service unsuccessful.", fg="red")
return

register_model_session(mdl.model_id, mdl.session._session_dir)
echo(
":rocket: Serving model {0}: {1}".format(mdl.model_id, mdl.slug), fg="green"
)
Expand Down
4 changes: 2 additions & 2 deletions ersilia/core/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@ class Session(ErsiliaBase):
"""
def __init__(self, config_json):
ErsiliaBase.__init__(self, config_json=config_json, credentials_json=None)
session_dir = get_session_dir()
self.session_file = os.path.join(session_dir, SESSION_JSON)
self._session_dir = get_session_dir()
self.session_file = os.path.join(self._session_dir, SESSION_JSON)

def current_model_id(self):
"""
Expand Down
5 changes: 5 additions & 0 deletions ersilia/hub/delete/delete.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from ..bundle.status import ModelStatus

from ...default import ISAURA_FILE_TAG, ISAURA_FILE_TAG_LOCAL
from ...utils.session import get_model_session, remove_session_dir, deregister_model_session


def rmtree(path):
Expand Down Expand Up @@ -573,6 +574,10 @@ def can_be_deleted(self, model_id: str) -> Tuple[bool, str]:
bool
True if the model can be deleted, False otherwise.
"""
mdl_session = get_model_session(model_id)
if mdl_session:
remove_session_dir(mdl_session)
deregister_model_session(model_id)
needs_delete = self._needs_delete(model_id)
mc = ModelCard(config_json=self.config_json).get(model_id)
model_source = ModelCatalog(config_json=self.config_json)._get_model_source(mc)
Expand Down
75 changes: 73 additions & 2 deletions ersilia/utils/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import psutil
import json

from ..default import SESSIONS_DIR, LOGS_DIR, CONTAINER_LOGS_TMP_DIR, SESSION_JSON
from ..default import SESSIONS_DIR, LOGS_DIR, CONTAINER_LOGS_TMP_DIR, SESSION_JSON, EOS, MODELS_JSON


def get_current_pid():
Expand Down Expand Up @@ -93,7 +93,8 @@ def remove_session_dir(session_name):
The name of the session.
"""
session_dir = os.path.join(SESSIONS_DIR, session_name)
shutil.rmtree(session_dir)
if os.path.exists(session_dir):
shutil.rmtree(session_dir)


def determine_orphaned_session():
Expand Down Expand Up @@ -140,3 +141,73 @@ def get_session_id():
The session ID.
"""
return f"session_{get_parent_pid()}"


def register_model_session(model_id, session_dir):
"""
Register a model with a session.
Parameters
----------
model_id : str
The model ID.
session_dir : str
The session directory.
"""
file_path = os.path.join(EOS, MODELS_JSON)

if not os.path.exists(file_path):
with open(file_path, "w") as f:
json.dump({}, f, indent=4)

with open(file_path, "r") as f:
models = json.load(f)

if (
model_id not in models
): # TODO This would have implications when we try to run the same model across multiple sessions
models[model_id] = session_dir
with open(file_path, "w") as f:
json.dump(models, f, indent=4)


def get_model_session(model_id):
"""
Get the model session.
Parameters
----------
model_id : str
The model ID.
Returns
-------
str
The session ID.
"""
file_path = os.path.join(EOS, MODELS_JSON)
if not os.path.exists(file_path):
return None
with open(file_path, "r") as f:
models = json.load(f)
return models.get(model_id, None)


def deregister_model_session(model_id):
"""
Remove a model from a session.
Parameters
----------
model_id : str
The model ID.
"""
file_path = os.path.join(EOS, MODELS_JSON)
if not os.path.exists(file_path):
return
with open(file_path, "r") as f:
models = json.load(f)
if model_id in models:
del models[model_id]
with open(file_path, "w") as f:
json.dump(models, f, indent=4)

0 comments on commit b3112b3

Please sign in to comment.