From 39d6df44a49edbf8aee5aeaff713f43655191cb4 Mon Sep 17 00:00:00 2001 From: Chen Qian Date: Mon, 11 Dec 2023 11:34:53 -0800 Subject: [PATCH] Use async logging in MLflowLogger (#2693) * async mlflow logging Signed-off-by: chenmoneygithub * small fix Signed-off-by: chenmoneygithub * clean up * fix test * fix tests * deflake * pin mlflow --------- Signed-off-by: chenmoneygithub --- composer/loggers/mlflow_logger.py | 33 +++++++++-------------------- setup.py | 2 +- tests/loggers/test_mlflow_logger.py | 9 +++----- 3 files changed, 14 insertions(+), 30 deletions(-) diff --git a/composer/loggers/mlflow_logger.py b/composer/loggers/mlflow_logger.py index c3a2051334..a28238937e 100644 --- a/composer/loggers/mlflow_logger.py +++ b/composer/loggers/mlflow_logger.py @@ -62,11 +62,11 @@ def __init__( flush_interval: int = 10, model_registry_prefix: str = '', model_registry_uri: Optional[str] = None, + synchronous: bool = False, ) -> None: try: import mlflow from mlflow import MlflowClient - from mlflow.utils.autologging_utils import MlflowAutologgingQueueingClient except ImportError as e: raise MissingConditionalImportError(extra_deps_group='mlflow', conda_package='mlflow', @@ -78,6 +78,7 @@ def __init__( self.tags = tags self.model_registry_prefix = model_registry_prefix self.model_registry_uri = model_registry_uri + self.synchronous = synchronous if self.model_registry_uri == 'databricks-uc': if len(self.model_registry_prefix.split('.')) != 2: raise ValueError(f'When registering to Unity Catalog, model_registry_prefix must be in the format ' + @@ -98,12 +99,7 @@ def __init__( self.experiment_name = os.getenv(mlflow.environment_variables.MLFLOW_EXPERIMENT_NAME.name, DEFAULT_MLFLOW_EXPERIMENT_NAME) self._mlflow_client = MlflowClient(self.tracking_uri) - # Create an instance of MlflowAutologgingQueueingClient - an optimized version - # of MlflowClient - that automatically batches metrics together and supports - # asynchronous logging for improved performance - self._optimized_mlflow_client = MlflowAutologgingQueueingClient(self.tracking_uri) - # Set experiment. We use MlflowClient for experiment retrieval and creation - # because MlflowAutologgingQueueingClient doesn't support it + # Set experiment. env_exp_id = os.getenv(mlflow.environment_variables.MLFLOW_EXPERIMENT_ID.name, None) if env_exp_id is not None: self._experiment_id = env_exp_id @@ -154,26 +150,24 @@ def log_table(self, columns: List[str], rows: List[List[Any]], name: str = 'Tabl ) def log_metrics(self, metrics: Dict[str, Any], step: Optional[int] = None) -> None: + from mlflow import log_metrics if self._enabled: # Convert all metrics to floats to placate mlflow. metrics = {k: float(v) for k, v in metrics.items()} - self._optimized_mlflow_client.log_metrics( - run_id=self._run_id, + log_metrics( metrics=metrics, step=step, + synchronous=self.synchronous, ) - time_since_flush = time.time() - self._last_flush_time - if time_since_flush >= self._flush_interval: - self._optimized_mlflow_client.flush(synchronous=False) - self._last_flush_time = time.time() def log_hyperparameters(self, hyperparameters: Dict[str, Any]): + from mlflow import log_params + if self._enabled: - self._optimized_mlflow_client.log_params( - run_id=self._run_id, + log_params( params=hyperparameters, + synchronous=self.synchronous, ) - self._optimized_mlflow_client.flush(synchronous=False) def register_model( self, @@ -269,16 +263,9 @@ def post_close(self): if self._enabled: import mlflow - # We use MlflowClient for run termination because MlflowAutologgingQueueingClient's - # run termination relies on scheduling Python futures, which is not supported within - # the Python atexit handler in which post_close() is called self._mlflow_client.set_terminated(self._run_id) mlflow.end_run() - def _flush(self): - """Test-only method to synchronously flush all queued metrics.""" - return self._optimized_mlflow_client.flush(synchronous=True) - def _convert_to_mlflow_image(image: Union[np.ndarray, torch.Tensor], channels_last: bool) -> np.ndarray: if isinstance(image, torch.Tensor): diff --git a/setup.py b/setup.py index 9c6317b3d5..6bf8fa3154 100644 --- a/setup.py +++ b/setup.py @@ -223,7 +223,7 @@ def package_files(prefix: str, directory: str, extension: str): ] extra_deps['mlflow'] = [ - 'mlflow>=2.5.0,<3.0', + 'mlflow>=2.8.1,<3.0', ] extra_deps['pandas'] = ['pandas>=2.0.0,<3.0'] diff --git a/tests/loggers/test_mlflow_logger.py b/tests/loggers/test_mlflow_logger.py index 055454670e..180bc666b0 100644 --- a/tests/loggers/test_mlflow_logger.py +++ b/tests/loggers/test_mlflow_logger.py @@ -4,6 +4,7 @@ import csv import json import os +import time from pathlib import Path from unittest.mock import MagicMock @@ -284,7 +285,6 @@ def test_mlflow_log_model(tmp_path, tiny_gpt2_model, tiny_gpt2_tokenizer): metadata={'task': 'llm/v1/completions'}, task='text-generation', ) - test_mlflow_logger._flush() test_mlflow_logger.post_close() run = _get_latest_mlflow_run(mlflow_exp_name, tracking_uri=mlflow_uri) @@ -328,7 +328,6 @@ def test_mlflow_save_model(tmp_path, tiny_gpt2_model, tiny_gpt2_tokenizer): metadata={'task': 'llm/v1/completions'}, task='text-generation', ) - test_mlflow_logger._flush() test_mlflow_logger.post_close() loaded_model = mlflow.transformers.load_model(local_mlflow_save_path, return_type='components') @@ -372,7 +371,6 @@ def test_mlflow_register_model(tmp_path, monkeypatch): registry_uri='databricks-uc') assert mlflow.get_registry_uri() == 'databricks-uc' - test_mlflow_logger._flush() test_mlflow_logger.post_close() @@ -411,7 +409,6 @@ def test_mlflow_register_model_non_databricks(tmp_path, monkeypatch): tags=None, registry_uri='my_registry_uri') - test_mlflow_logger._flush() test_mlflow_logger.post_close() @@ -456,7 +453,8 @@ def test_mlflow_logging_works(tmp_path, device): eval_interval=eval_interval, device=device) trainer.fit() - test_mlflow_logger._flush() + # Allow async logging to finish. + time.sleep(3) test_mlflow_logger.post_close() run = _get_latest_mlflow_run( @@ -527,7 +525,6 @@ def before_forward(self, state: State, logger: Logger): device=device) trainer.fit() - test_mlflow_logger._flush() test_mlflow_logger.post_close() run = _get_latest_mlflow_run(