
Commit

Merge branch 'dev' into checkpoint_saver
bigning authored Jun 17, 2024
2 parents 4f3108c + cca51e2 commit 11307f0
Showing 8 changed files with 248 additions and 16 deletions.
93 changes: 79 additions & 14 deletions composer/callbacks/system_metrics_monitor.py
@@ -9,6 +9,7 @@
import os

import psutil
import torch

from composer.core import Callback, Event, State
from composer.loggers import Logger
@@ -19,13 +20,52 @@

__all__ = ['SystemMetricsMonitor']

_GPU_METRICS = [
'gpu_percentage',
'memory_percentage',
'gpu_temperature_C',
'gpu_power_usage_W',
]


class SystemMetricsMonitor(Callback):
"""Track system metrics."""
"""Logs GPU/CPU metrics.
GPU Metrics:
gpu_percentage: Occupancy rate, percent of time over sampling period during which one or more kernels was executing on the GPU.
memory_percentage: Percent of time over sampling period during which global memory was being read or written.
gpu_temperature_C: Temperature of device, in Celsius.
gpu_power_usage_W: Power usage of device, in Watts.
By default, only the maximum and minimum values for these metrics, alongside their respective ranks in the key names,
are logged on the :attr:`.Event.BATCH_START`, :attr:`.Event.EVAL_BATCH_START`, :attr:`.Event.PREDICT_BATCH_START`
events for every batch. If log_all_data is set to True, all values for these metrics across all ranks are logged on the
above events for every batch.
Example:
.. doctest::
def __init__(self, gpu_available: bool = False) -> None:
>>> from composer import Trainer
>>> from composer.callbacks import SystemMetricsMonitor
>>> # constructing trainer object with this callback
>>> trainer = Trainer(
... model=model,
... train_dataloader=train_dataloader,
... eval_dataloader=eval_dataloader,
... optimizers=optimizer,
... max_duration='1ep',
... callbacks=[SystemMetricsMonitor()],
... )
Args:
log_all_data (bool, optional): True if user wants to log data for all ranks, not just the min/max.
Defaults to False.
"""

def __init__(self, log_all_data: bool = False) -> None:
super().__init__()
self.gpu_available = gpu_available
self.gpu_available = torch.cuda.is_available()
self.log_all_data = log_all_data
if self.gpu_available:
try:
import pynvml
@@ -46,9 +86,23 @@ def run_event(self, event: Event, state: State, logger: Logger):
]:
local_node_system_metrics = self.compute_system_metrics()
all_system_metrics = dist.all_gather_object(local_node_system_metrics)
system_metrics = {
key: value for local_metrics in all_system_metrics for key, value in local_metrics.items()
}
system_metrics = {}

if self.log_all_data:
for rank, metrics in enumerate(all_system_metrics):
for key, value in metrics.items():
if key in _GPU_METRICS:
system_metrics[f'{key}_rank_{rank}'] = value
else:
system_metrics[key] = value

else:
system_metrics = self.compute_gpu_min_max_metrics(all_system_metrics, state)
for rank, metrics in enumerate(all_system_metrics):
for key, value in metrics.items():
if key not in _GPU_METRICS:
system_metrics[key] = value

logger.log_metrics(system_metrics)

def compute_system_metrics(self):
@@ -58,17 +112,14 @@ def compute_system_metrics(self):
if self.gpu_available:
import pynvml
local_rank = dist.get_local_rank()
global_rank = dist.get_global_rank()
handle = pynvml.nvmlDeviceGetHandleByIndex(local_rank)
memory = pynvml.nvmlDeviceGetMemoryInfo(handle)
system_metrics[f'device{global_rank}_memory_total'] = memory.total
system_metrics[f'device{global_rank}_memory_free'] = memory.free
system_metrics[f'device{global_rank}_memory_used'] = memory.used
device_utilization = pynvml.nvmlDeviceGetUtilizationRates(handle)
system_metrics[f'device{global_rank}_gpu_percentage'] = device_utilization.gpu
system_metrics[f'device{global_rank}_memory_percentage'] = device_utilization.memory
system_metrics['gpu_percentage'] = device_utilization.gpu
system_metrics['memory_percentage'] = device_utilization.memory
temperature = pynvml.nvmlDeviceGetTemperature(handle, pynvml.NVML_TEMPERATURE_GPU)
system_metrics[f'device{global_rank}_gpu_temperature'] = temperature
system_metrics['gpu_temperature_C'] = temperature
power = pynvml.nvmlDeviceGetPowerUsage(handle) / 1000.0 # convert from mW to W
system_metrics['gpu_power_usage_W'] = power

# Get metrics for the system
cpu_percent = psutil.cpu_percent()
Expand All @@ -83,3 +134,17 @@ def compute_system_metrics(self):
for k, v in network_usage.items():
system_metrics[f'network_{k}'] = v
return system_metrics

def compute_gpu_min_max_metrics(self, all_metrics, state):
min_max_metrics = {}

if self.gpu_available:
for key in _GPU_METRICS:
values = torch.tensor([metrics_for_cur_rank[key] for metrics_for_cur_rank in all_metrics])
values = state.device.tensor_to_device(values)
min_rank = int(torch.argmin(values).item())
max_rank = int(torch.argmax(values).item())
min_max_metrics[f'min_{key}_rank_{min_rank}'] = values[min_rank].item()
min_max_metrics[f'max_{key}_rank_{max_rank}'] = values[max_rank].item()

return min_max_metrics
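For reference, a minimal standalone sketch of the metric-key format this callback now logs. The two ranks and their values below are hypothetical, and the min/max reduction is done in plain Python rather than with torch/dist, purely for illustration:

    # Hypothetical GPU metrics gathered from two ranks (illustrative values only).
    all_system_metrics = [
        {'gpu_percentage': 91.0, 'memory_percentage': 40.0, 'gpu_temperature_C': 61.0, 'gpu_power_usage_W': 240.0},
        {'gpu_percentage': 78.0, 'memory_percentage': 35.0, 'gpu_temperature_C': 58.0, 'gpu_power_usage_W': 215.0},
    ]
    logged = {}
    for key in ['gpu_percentage', 'memory_percentage', 'gpu_temperature_C', 'gpu_power_usage_W']:
        values = [metrics[key] for metrics in all_system_metrics]
        min_rank, max_rank = values.index(min(values)), values.index(max(values))
        # The rank that produced each extreme value is embedded in the metric name.
        logged[f'min_{key}_rank_{min_rank}'] = values[min_rank]
        logged[f'max_{key}_rank_{max_rank}'] = values[max_rank]
    # logged now holds keys such as 'min_gpu_percentage_rank_1' and 'max_gpu_percentage_rank_0'.
    # With log_all_data=True, the callback instead logs every rank's value under keys
    # like 'gpu_percentage_rank_0', 'gpu_percentage_rank_1', and so on.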
7 changes: 7 additions & 0 deletions composer/loggers/mlflow_logger.py
@@ -66,6 +69,9 @@ class MLFlowLogger(LoggerDestination):
resume (bool, optional): If ``True``, Composer will search for an existing run tagged with
the `run_name` and resume it. If no existing run is found, a new run will be created.
If ``False``, Composer will create a new run. (default: ``False``)
logging_buffer_seconds (int, optional): The amount of time, in seconds, that MLflow
waits before sending logs to the MLflow tracking server. Metrics/params/tags logged
within this buffer time will be grouped in batches before being sent to the backend.
"""

def __init__(
@@ -85,6 +88,7 @@ def __init__(
ignore_hyperparameters: Optional[list[str]] = None,
run_group: Optional[str] = None,
resume: bool = False,
logging_buffer_seconds: Optional[int] = 10,
) -> None:
try:
import mlflow
@@ -116,6 +120,9 @@ def __init__(
)
self.resume = resume

if logging_buffer_seconds:
os.environ['MLFLOW_ASYNC_LOGGING_BUFFERING_SECONDS'] = str(logging_buffer_seconds)

self._rank_zero_only = rank_zero_only
self._last_flush_time = time.time()
self._flush_interval = flush_interval
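A minimal usage sketch for the new parameter; the tracking URI and experiment name below are placeholders, not values from this diff:

    from composer.loggers import MLFlowLogger

    # Buffer metrics/params/tags for up to 30 seconds before they are flushed to the tracking
    # server; under the hood this sets the MLFLOW_ASYNC_LOGGING_BUFFERING_SECONDS environment variable.
    mlflow_logger = MLFlowLogger(
        tracking_uri='file:///tmp/my-mlflow-runs',   # placeholder URI
        experiment_name='buffered-logging-example',  # placeholder experiment name
        logging_buffer_seconds=30,
    )

Passing logging_buffer_seconds=None (or 0) leaves the environment variable untouched and falls back to MLflow's default behavior.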
2 changes: 2 additions & 0 deletions composer/trainer/_patch_pytorch.py
@@ -933,6 +933,8 @@ def device_mesh__getitem__(self, mesh_dim_names: Union[str, tuple[str]]) -> 'DeviceMesh':
return submesh

else:
from torch.distributed.device_mesh import _mesh_resources

def create_child_mesh(
self, parent_mesh: 'DeviceMesh', submesh_dim_names: Tuple[str, ...],
) -> 'DeviceMesh':
73 changes: 73 additions & 0 deletions composer/utils/dist.py
@@ -37,6 +37,8 @@
import logging
import os
import pickle
import random
import string
import sys
import time
from contextlib import contextmanager
@@ -627,6 +629,77 @@ def get_sampler(
)


def get_node_signal_file_name(rng: Optional[random.Random] = None) -> str:
"""Returns a file name to use for a file based wait within a node.
The file name will contain a randomly generated string to avoid conflicts.
Note: This file name will be the same on each node, so that it can be used for a file based wait.
Returns:
str: The name of the file that will be created to signal the end of a node's training.
"""
if rng is None:
rng = random.Random()

random_string = ''.join(rng.choices(string.ascii_letters + string.digits, k=6))
node_rank = get_node_rank()
file_name_list = [f'._signal_file_node{node_rank}_{random_string}']
dist.broadcast_object_list(file_name_list, src=0)
return file_name_list[0]


def write_signal_file(signal_file_name: str, dir_path: Optional[str] = None) -> str:
"""Writes a signal file to the specified directory.
This function creates a signal file in the specified directory. If the directory does not exist, it will be created.
Note: Only local rank zero writes the signal file. All other ranks are expected to wait for the signal file.
Args:
signal_file_name (str): The name of the signal file.
dir_path (str, optional): The full path to the directory in which to create the signal file. If ``None``,
the current working directory will be used.
"""
if dir_path is not None:
os.makedirs(dir_path, exist_ok=True)

signal_file_path = os.path.join(dir_path or os.getcwd(), signal_file_name)
if get_local_rank() == 0:
with open(signal_file_path, 'w') as _f:
_f.write('local rank zero done')

return signal_file_path


@contextmanager
def busy_wait_for_local_rank_zero(dir_path: Optional[str] = None):
"""Busy waits for the signal file to be created by local rank zero.
This function will wait for the signal file to be created by local rank zero. It will
check every 0.1 seconds for the existence of the file.
Args:
dir_path (str, optional): The directory in which to look for the signal file. If ``None``,
the current working directory will be used.
"""
# Get unique file name
signal_file_name = get_node_signal_file_name()

# All ranks yield execution to allow local rank zero to run the code it needs to
yield

# Local rank zero writes the signal file; all other ranks just get the expected path
signal_file_path = write_signal_file(signal_file_name=signal_file_name, dir_path=dir_path)

# Wait for the signal file to be created by local rank zero
with local_rank_zero_download_and_wait(signal_file_path):
# Sync all ranks across nodes as busy wait only is within node
dist.barrier()

# Remove the signal file
if get_local_rank() == 0:
os.remove(signal_file_path)


@contextmanager
def local_rank_zero_download_and_wait(expected_file_path: str):
"""Context manager to wait for a file to exist on all ranks except local rank zero.
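A minimal sketch of how these utilities compose in a multi-rank run; the download_dataset helper and the shared directory path are hypothetical placeholders, and the composer.utils.dist import path is assumed:

    import os

    from composer.utils import dist

    def download_dataset(path: str) -> None:
        # Hypothetical placeholder for work that only local rank zero should perform.
        os.makedirs(path, exist_ok=True)

    shared_dir = '/tmp/shared-dataset'  # placeholder directory visible to every rank on the node

    # Every rank enters the context; after the body runs, local rank zero writes the signal file
    # and the remaining local ranks busy-wait on it before a final cross-node barrier.
    with dist.busy_wait_for_local_rank_zero(shared_dir):
        if dist.get_local_rank() == 0:
            download_dataset(shared_dir)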
1 change: 1 addition & 0 deletions setup.py
@@ -225,6 +225,7 @@ def package_files(prefix: str, directory: str, extension: str):
extra_deps['mlflow'] = [
'mlflow>=2.11.1,<3.0',
'databricks-sdk==0.28.0',
'pynvml>=11.5.0,<12',
]

extra_deps['pandas'] = ['pandas>=2.0.0,<3.0']
4 changes: 2 additions & 2 deletions tests/callbacks/test_system_metrics_monitor.py
@@ -13,7 +13,7 @@
@pytest.mark.gpu
def test_system_metrics_monitor_gpu():
# Construct the trainer
system_metrics_monitor = SystemMetricsMonitor(gpu_available=True)
system_metrics_monitor = SystemMetricsMonitor()
in_memory_logger = InMemoryLogger()
trainer = Trainer(
model=SimpleModel(),
@@ -24,7 +24,7 @@ def test_system_metrics_monitor_gpu():
)
trainer.fit()

assert 'device0_gpu_percentage' in in_memory_logger.data
assert 'min_gpu_percentage_rank_0' in in_memory_logger.data
assert 'cpu_percentage' in in_memory_logger.data


37 changes: 37 additions & 0 deletions tests/loggers/test_mlflow_logger.py
@@ -798,6 +798,43 @@ def test_rename_metrics(self, device, num_batches, tmp_path):
assert not os.path.exists(metric_file)


def test_mlflow_logging_time_buffer(tmp_path):
mlflow = pytest.importorskip('mlflow')
if not hasattr(mlflow.environment_variables, 'MLFLOW_ASYNC_LOGGING_BUFFERING_SECONDS'):
pytest.skip(f'MLFlow {mlflow.__version__} does not support async logging buffer seconds.')

with patch('mlflow.store.tracking.file_store.FileStore.log_batch') as mock_log_batch:

mlflow_uri = tmp_path / Path('my-test-mlflow-uri')
experiment_name = 'mlflow_logging_test'
mock_state = MagicMock()
mock_logger = MagicMock()

test_mlflow_logger = MLFlowLogger(
tracking_uri=mlflow_uri,
experiment_name=experiment_name,
log_system_metrics=True,
run_name='test_run',
logging_buffer_seconds=2,
)
test_mlflow_logger.init(state=mock_state, logger=mock_logger)
test_mlflow_logger.log_hyperparameters({'name': 'test'})
steps = 10
for i in range(steps):
metrics = {
'foo': i,
'bar': i,
}
test_mlflow_logger.log_metrics(metrics, step=i)
test_mlflow_logger.post_close()

# There will be 2 calls to `log_batch`, one from `start_run` with tags, and one from the metrics
# and hyperparameters logging.
assert mock_log_batch.call_count == 2
assert len(mock_log_batch.call_args_list[0][1]['metrics']) == 0
assert len(mock_log_batch.call_args_list[1][1]['metrics']) == 2 * steps


def test_mlflow_resume_run(tmp_path):
mlflow = pytest.importorskip('mlflow')

47 changes: 47 additions & 0 deletions tests/utils/test_dist.py
@@ -1,6 +1,8 @@
# Copyright 2022 MosaicML Composer authors
# SPDX-License-Identifier: Apache-2.0

import os
import time
from unittest.mock import patch

import pytest
@@ -27,3 +29,48 @@ def test_run_local_rank_first_context_runs_properly():
# so dist is initialized here and this code should run without error
with dist.run_local_rank_zero_first():
pass


@pytest.mark.world_size(2)
def test_get_node_signal_file_name():
file_name = dist.get_node_signal_file_name()
gathered_file_names = dist.all_gather_object(file_name)

assert len(gathered_file_names) == 2
assert gathered_file_names[0] == gathered_file_names[1]
assert gathered_file_names[0] == file_name
assert file_name.startswith('._signal_file_node0_')
assert len(file_name) == len('._signal_file_node0_') + 6


@pytest.mark.world_size(2)
def test_write_signal_file(tmp_path):
file_name = dist.get_node_signal_file_name()
file_path = os.path.join(tmp_path, file_name)
dist.write_signal_file(file_name, tmp_path)

# tmp_path will be different on each rank, and only rank zero
# should have written a file
if dist.get_local_rank() == 0:
assert os.path.exists(file_path)
else:
assert not os.path.exists(file_path)


@pytest.mark.world_size(2)
def test_busy_wait_for_local_rank_zero(tmp_path):
gathered_tmp_path = dist.all_gather_object(tmp_path)[0]

dist.barrier()
start_time = time.time()
assert os.listdir(gathered_tmp_path) == []
with dist.busy_wait_for_local_rank_zero(gathered_tmp_path):
if dist.get_local_rank() == 0:
time.sleep(0.5)

end_time = time.time()
total_time = end_time - start_time
gathered_times = dist.all_gather_object(total_time)
assert os.listdir(gathered_tmp_path) == []
assert len(gathered_times) == 2
assert abs(gathered_times[0] - gathered_times[1]) < 0.1
