From b2d5df801a8ddaef5f418c3377b47c170f31c248 Mon Sep 17 00:00:00 2001 From: ioangatop Date: Thu, 7 Mar 2024 12:07:05 +0100 Subject: [PATCH] Allow reading remote configs (#205) --- src/eva/setup.py | 10 ++++++++++ src/eva/trainers/_logging.py | 13 +++++++++---- 2 files changed, 19 insertions(+), 4 deletions(-) diff --git a/src/eva/setup.py b/src/eva/setup.py index 7a2e5810..185d0a79 100644 --- a/src/eva/setup.py +++ b/src/eva/setup.py @@ -4,9 +4,18 @@ import sys import warnings +import jsonargparse from loguru import logger +def _configure_jsonargparse() -> None: + """Configures the `jsonargparse` library.""" + jsonargparse.set_config_read_mode( + urls_enabled=True, + fsspec_enabled=True, + ) + + def _initialize_logger() -> None: """Initializes, manipulates and customizes the logger. @@ -44,6 +53,7 @@ def _enable_mps_fallback() -> None: def setup() -> None: """Sets up the environment before the module is imported.""" + _configure_jsonargparse() _initialize_logger() _suppress_warnings() _enable_mps_fallback() diff --git a/src/eva/trainers/_logging.py b/src/eva/trainers/_logging.py index 892891f7..75334ed7 100644 --- a/src/eva/trainers/_logging.py +++ b/src/eva/trainers/_logging.py @@ -4,6 +4,7 @@ import sys from datetime import datetime +from lightning_fabric.utilities import cloud_io from loguru import logger @@ -33,7 +34,7 @@ def _generate_config_hash(max_hash_len: int = 8) -> str: config_path = _fetch_config_path() if config_path is None: logger.warning( - "No or multiple configuration file found from command line arguments." + "No or multiple configuration file found from command line arguments. " "No configuration hash code will created for this experiment." ) return "" @@ -50,7 +51,8 @@ def _fetch_config_path() -> str | None: Returns: The path to the configuration file. """ - config_paths = [f for f in sys.argv if f.endswith(".yaml")] + inputs = sys.argv + config_paths = [inputs[i + 1] for i, arg in enumerate(inputs) if arg == "--config"] if len(config_paths) == 0 or len(config_paths) > 1: # TODO combine the multiple configuration files # and produced hash for the merged one. @@ -69,8 +71,11 @@ def _generate_hash_from_config(path: str, max_hash_len: int = 8) -> str: Returns: Hash of the configuration file content. """ - with open(path, "r") as stream: - config = stream.read().encode("utf-8") + fs = cloud_io.get_filesystem(path) + with fs.open(path, "r") as stream: + config = stream.read() + if isinstance(config, str): + config = config.encode("utf-8") config_sha256 = hashlib.sha256(config) hash_id = config_sha256.hexdigest() return hash_id[:max_hash_len]