Use async logging in MLflowLogger (mosaicml#2693)
* async mlflow logging

Signed-off-by: chenmoneygithub <[email protected]>

* small fix

Signed-off-by: chenmoneygithub <[email protected]>

* clean up

* fix test

* fix tests

* deflake

* pin mlflow

---------

Signed-off-by: chenmoneygithub <[email protected]>
chenmoneygithub authored Dec 11, 2023
1 parent 236b738 commit 39d6df4
Showing 3 changed files with 14 additions and 30 deletions.
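
For orientation before the diff: the change drops the private MlflowAutologgingQueueingClient in favor of MLflow's fluent logging APIs, which accept a `synchronous` flag as of MLflow 2.8. A rough standalone sketch of that API (metric names and values are illustrative, not taken from this commit):

    import mlflow

    with mlflow.start_run():
        # Queued and written by a background thread; the call returns immediately.
        mlflow.log_metrics({'train/loss': 0.25}, step=10, synchronous=False)
        mlflow.log_params({'lr': 3e-4}, synchronous=False)
        # Blocking variant: waits until the values reach the tracking store.
        mlflow.log_metrics({'eval/accuracy': 0.91}, step=10, synchronous=True)
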
33 changes: 10 additions & 23 deletions composer/loggers/mlflow_logger.py
@@ -62,11 +62,11 @@ def __init__(
         flush_interval: int = 10,
         model_registry_prefix: str = '',
         model_registry_uri: Optional[str] = None,
+        synchronous: bool = False,
     ) -> None:
         try:
             import mlflow
             from mlflow import MlflowClient
-            from mlflow.utils.autologging_utils import MlflowAutologgingQueueingClient
         except ImportError as e:
             raise MissingConditionalImportError(extra_deps_group='mlflow',
                                                 conda_package='mlflow',
@@ -78,6 +78,7 @@ def __init__(
         self.tags = tags
         self.model_registry_prefix = model_registry_prefix
         self.model_registry_uri = model_registry_uri
+        self.synchronous = synchronous
         if self.model_registry_uri == 'databricks-uc':
             if len(self.model_registry_prefix.split('.')) != 2:
                 raise ValueError(f'When registering to Unity Catalog, model_registry_prefix must be in the format ' +
@@ -98,12 +99,7 @@ def __init__(
             self.experiment_name = os.getenv(mlflow.environment_variables.MLFLOW_EXPERIMENT_NAME.name,
                                              DEFAULT_MLFLOW_EXPERIMENT_NAME)
         self._mlflow_client = MlflowClient(self.tracking_uri)
-        # Create an instance of MlflowAutologgingQueueingClient - an optimized version
-        # of MlflowClient - that automatically batches metrics together and supports
-        # asynchronous logging for improved performance
-        self._optimized_mlflow_client = MlflowAutologgingQueueingClient(self.tracking_uri)
-        # Set experiment. We use MlflowClient for experiment retrieval and creation
-        # because MlflowAutologgingQueueingClient doesn't support it
+        # Set experiment.
         env_exp_id = os.getenv(mlflow.environment_variables.MLFLOW_EXPERIMENT_ID.name, None)
         if env_exp_id is not None:
             self._experiment_id = env_exp_id
@@ -154,26 +150,24 @@ def log_table(self, columns: List[str], rows: List[List[Any]], name: str = 'Tabl
         )

     def log_metrics(self, metrics: Dict[str, Any], step: Optional[int] = None) -> None:
+        from mlflow import log_metrics
         if self._enabled:
             # Convert all metrics to floats to placate mlflow.
             metrics = {k: float(v) for k, v in metrics.items()}
-            self._optimized_mlflow_client.log_metrics(
-                run_id=self._run_id,
+            log_metrics(
                 metrics=metrics,
                 step=step,
+                synchronous=self.synchronous,
             )
-            time_since_flush = time.time() - self._last_flush_time
-            if time_since_flush >= self._flush_interval:
-                self._optimized_mlflow_client.flush(synchronous=False)
-                self._last_flush_time = time.time()

     def log_hyperparameters(self, hyperparameters: Dict[str, Any]):
+        from mlflow import log_params
+
         if self._enabled:
-            self._optimized_mlflow_client.log_params(
-                run_id=self._run_id,
+            log_params(
                 params=hyperparameters,
+                synchronous=self.synchronous,
             )
-            self._optimized_mlflow_client.flush(synchronous=False)

     def register_model(
         self,
@@ -269,16 +263,9 @@ def post_close(self):
         if self._enabled:
             import mlflow

-            # We use MlflowClient for run termination because MlflowAutologgingQueueingClient's
-            # run termination relies on scheduling Python futures, which is not supported within
-            # the Python atexit handler in which post_close() is called
             self._mlflow_client.set_terminated(self._run_id)
             mlflow.end_run()

-    def _flush(self):
-        """Test-only method to synchronously flush all queued metrics."""
-        return self._optimized_mlflow_client.flush(synchronous=True)
-

 def _convert_to_mlflow_image(image: Union[np.ndarray, torch.Tensor], channels_last: bool) -> np.ndarray:
     if isinstance(image, torch.Tensor):
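
On the Composer side, the new constructor argument is the only knob; a hedged usage sketch (`experiment_name` is assumed from the existing MLflowLogger signature, and only `synchronous` is introduced by this commit):

    from composer.loggers import MLflowLogger

    # Default after this change: log_metrics/log_hyperparameters hand work to
    # MLflow's background queue and return immediately.
    async_logger = MLflowLogger(experiment_name='demo-experiment')

    # Opt back into blocking writes, e.g. for tests that read the run right away.
    sync_logger = MLflowLogger(experiment_name='demo-experiment', synchronous=True)

Either instance is passed to the Trainer through its loggers argument as usual.
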
2 changes: 1 addition & 1 deletion setup.py
@@ -223,7 +223,7 @@ def package_files(prefix: str, directory: str, extension: str):
 ]

 extra_deps['mlflow'] = [
-    'mlflow>=2.5.0,<3.0',
+    'mlflow>=2.8.1,<3.0',
 ]

 extra_deps['pandas'] = ['pandas>=2.0.0,<3.0']
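
The raised lower bound tracks the `synchronous` keyword used above, which older MLflow releases do not expose. A quick guard along these lines (hypothetical, not part of the commit) fails fast on stale installs:

    # Hypothetical sanity check mirroring the new pin.
    from packaging import version

    import mlflow

    assert version.parse(mlflow.__version__) >= version.parse('2.8.1'), \
        'Async logging in MLflowLogger requires mlflow>=2.8.1'
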
9 changes: 3 additions & 6 deletions tests/loggers/test_mlflow_logger.py
@@ -4,6 +4,7 @@
 import csv
 import json
 import os
+import time
 from pathlib import Path
 from unittest.mock import MagicMock

@@ -284,7 +285,6 @@ def test_mlflow_log_model(tmp_path, tiny_gpt2_model, tiny_gpt2_tokenizer):
         metadata={'task': 'llm/v1/completions'},
         task='text-generation',
     )
-    test_mlflow_logger._flush()
     test_mlflow_logger.post_close()

     run = _get_latest_mlflow_run(mlflow_exp_name, tracking_uri=mlflow_uri)
@@ -328,7 +328,6 @@ def test_mlflow_save_model(tmp_path, tiny_gpt2_model, tiny_gpt2_tokenizer):
         metadata={'task': 'llm/v1/completions'},
         task='text-generation',
     )
-    test_mlflow_logger._flush()
     test_mlflow_logger.post_close()

     loaded_model = mlflow.transformers.load_model(local_mlflow_save_path, return_type='components')
@@ -372,7 +371,6 @@ def test_mlflow_register_model(tmp_path, monkeypatch):
                                          registry_uri='databricks-uc')
     assert mlflow.get_registry_uri() == 'databricks-uc'

-    test_mlflow_logger._flush()
     test_mlflow_logger.post_close()


@@ -411,7 +409,6 @@ def test_mlflow_register_model_non_databricks(tmp_path, monkeypatch):
                                          tags=None,
                                          registry_uri='my_registry_uri')

-    test_mlflow_logger._flush()
     test_mlflow_logger.post_close()


@@ -456,7 +453,8 @@ def test_mlflow_logging_works(tmp_path, device):
                       eval_interval=eval_interval,
                       device=device)
     trainer.fit()
-    test_mlflow_logger._flush()
+    # Allow async logging to finish.
+    time.sleep(3)
     test_mlflow_logger.post_close()

     run = _get_latest_mlflow_run(
@@ -527,7 +525,6 @@ def before_forward(self, state: State, logger: Logger):
                       device=device)

     trainer.fit()
-    test_mlflow_logger._flush()
     test_mlflow_logger.post_close()

     run = _get_latest_mlflow_run(
