Skip to content

Commit

Permalink
Wct save interval (mosaicml#3264)
Browse files Browse the repository at this point in the history
* wct time

* pre-commit

* timeunit

* checker

* precommit

* test case

* test

* adding minute and hour

* syntax

* precommit

* precommit

* time conversion

* syntax

* precommit

* timedelta compatibility

* typo

* no wct for trainer and scheduler

* precommit

* robust assertion

* robust assertion

* precommit

* resolving doc comments

* precommit

* added testcases

* precommit

* precommit

* change asserts to errors

* update parsing

* precommit

* datetime

* datetime

* precommit

* debug

* parsable

* force int

* typo

* typo

* debug

* debug

* mihir refactor

* precommit

* bool to str

* mihir nit

* precommit

* rerun pr rev

* rerun pr rev

---------

Co-authored-by: Mihir Patel <[email protected]>
Co-authored-by: Daniel King <[email protected]>
  • Loading branch information
3 people authored May 24, 2024
1 parent 79e79eb commit cf56030
Show file tree
Hide file tree
Showing 7 changed files with 120 additions and 24 deletions.
55 changes: 54 additions & 1 deletion composer/core/time.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,30 @@
__all__ = ['TimeUnit', 'Time', 'Timestamp', 'ensure_time']


def verify_wct(timestamp: str) -> str:
"""Return a valid datetime formated wct timestamp if input is a valid wct.
Args:
timestamp (str): A string that represents a timestamp in wct.
Returns:
str: a properly formatted datetime if input is valid else None
"""
if 'h' not in timestamp:
timestamp = '0h' + timestamp
if 'm' not in timestamp:
timestamp = timestamp.replace('h', 'h0m')
if 's' not in timestamp:
timestamp = timestamp + '0s'

pattern = r'^(\d+h)?(\d+m)?(\d+s)?$'
match = re.match(pattern, timestamp)
if bool(match):
return timestamp
else:
raise ValueError(f'{timestamp} was passed in, which does not fit XXhYYmZZs formatting')


class TimeUnit(StringEnum):
"""Enum class to represent units of time for the training process.
Expand All @@ -44,6 +68,7 @@ class TimeUnit(StringEnum):
SAMPLE = 'sp'
TOKEN = 'tok'
DURATION = 'dur'
SECOND = 'sec'


# regex for parsing time string, matches timeunit and chars prior to unit as value
Expand Down Expand Up @@ -212,6 +237,25 @@ def from_duration(cls, duration: float) -> Time:
"""
return cls(duration, TimeUnit.DURATION)

@classmethod
def from_timedelta(cls, timestring: str) -> Time:
"""Create a :class:`Time` with units of :attr:`TimeUnit.SECOND`.
Equivalent to ``Time(batch, TimeUnit.SECOND)``.
Args:
timestring (int): timedelta string in _h_m_s.
Returns:
Time: :class:`Time` instance, in seconds.
"""
# Convert timestring to be strptime parsable
verified_wct = verify_wct(timestring)
time_struct = datetime.datetime.strptime(verified_wct, '%Hh%Mm%Ss')
delta = datetime.timedelta(hours=time_struct.hour, minutes=time_struct.minute, seconds=time_struct.second)
total_seconds = delta.total_seconds()
return cls(int(total_seconds), TimeUnit.SECOND)

@property
def value(self) -> TValue:
"""The value of the time, as a number."""
Expand Down Expand Up @@ -392,6 +436,12 @@ def from_timestring(cls, timestring: str) -> Time:
Returns:
Time: An instance of :class:`Time`.
"""
# Handle TimeDelta matching first
try:
return Time.from_timedelta(timestring)
except ValueError:
pass

match = _TIME_STR_REGEX.findall(timestring)
if len(match) != 1:
raise ValueError(f'Invalid time string: {timestring}')
Expand Down Expand Up @@ -647,6 +697,8 @@ def get(self, unit: Union[str, TimeUnit]) -> Time[int]:
return self.sample
if unit == TimeUnit.TOKEN:
return self.token
if unit == TimeUnit.SECOND:
return Time(int(self._total_wct.total_seconds()) if self._total_wct else 0, TimeUnit.SECOND)
raise ValueError(f'Invalid unit: {unit}')

def _parse(self, other: Union[int, float, Time, str]) -> Time:
Expand Down Expand Up @@ -944,4 +996,5 @@ def ensure_time(maybe_time: Union[Time, str, int], int_unit: Union[TimeUnit, str
Returns:
Time: An instance of :class:`.Time`.
"""
return Time.from_input(maybe_time, int_unit)
time_obj = Time.from_input(maybe_time, int_unit)
return time_obj
3 changes: 2 additions & 1 deletion composer/optim/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,8 @@ def __call__(self, state: State, ssr: float = 1.0) -> float:
def _convert_time(time: Union[str, Time[int], Time[float]], state: State, ssr: float = 1.0) -> Time[int]:
if isinstance(time, str):
time = Time.from_timestring(time)

if time.unit == TimeUnit.SECOND:
raise ValueError('Wall clock time not an allowed time unit.')
assert state.max_duration is not None, 'max_duration should be set whenever schedulers are invoked'

if time.unit == TimeUnit.DURATION:
Expand Down
4 changes: 4 additions & 0 deletions composer/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1574,6 +1574,8 @@ def __init__(
# Max Duration
if max_duration is not None:
self.state.max_duration = ensure_time(max_duration, TimeUnit.EPOCH)
if self.state.max_duration.unit == TimeUnit.SECOND:
raise ValueError('Wall clock time not an allowed time unit.')

self.logger.log_hyperparameters({'rank_zero_seed': rank_zero_seed})

Expand Down Expand Up @@ -2157,6 +2159,8 @@ def fit(
# Max Duration
if duration is not None:
duration = ensure_time(duration, TimeUnit.EPOCH)
if duration.unit == TimeUnit.SECOND:
raise ValueError('Wall clock time not an allowed time unit.')
# Effectively increment the max duration (if not resetting the Time)
# or set the max_duration (if resetting the time -- self.state.timestamp.get(duration.unit) will be 0)
# It is important to set the duration, rather than incrementing it, as ``duration`` could be in
Expand Down
23 changes: 19 additions & 4 deletions composer/utils/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,11 +87,18 @@ def create_interval_scheduler(
interval_event = Event.EPOCH_CHECKPOINT if checkpoint_events else Event.EPOCH_END
elif time_interval.unit == TimeUnit.ITERATION:
interval_event = Event.ITERATION_CHECKPOINT if checkpoint_events else Event.ITERATION_END
elif time_interval.unit in {TimeUnit.BATCH, TimeUnit.TOKEN, TimeUnit.SAMPLE, TimeUnit.DURATION}:
elif time_interval.unit in {
TimeUnit.BATCH,
TimeUnit.TOKEN,
TimeUnit.SAMPLE,
TimeUnit.DURATION,
TimeUnit.SECOND,
}:
interval_event = Event.BATCH_CHECKPOINT if checkpoint_events else Event.BATCH_END
else:
raise NotImplementedError(
f'Unknown interval: {time_interval.unit}. Must be TimeUnit.ITERATION, TimeUnit.EPOCH, TimeUnit.BATCH, TimeUnit.TOKEN, or TimeUnit.SAMPLE.',
f'Unknown interval: {time_interval.unit}. Must be TimeUnit.ITERATION, TimeUnit.EPOCH, TimeUnit.BATCH, TimeUnit.TOKEN, ' +\
'TimeUnit.SAMPLE, TimeUnit.SECOND',
)

last_batch_seen = -1
Expand All @@ -113,7 +120,14 @@ def check_interval(state: State, event: Event):
if include_end_of_training and event in final_events and elapsed_duration >= 1.0 and state.timestamp.batch != last_batch_seen:
return True

if time_interval.unit in {TimeUnit.ITERATION, TimeUnit.EPOCH, TimeUnit.BATCH, TimeUnit.TOKEN, TimeUnit.SAMPLE}:
if time_interval.unit in {
TimeUnit.ITERATION,
TimeUnit.EPOCH,
TimeUnit.BATCH,
TimeUnit.TOKEN,
TimeUnit.SAMPLE,
TimeUnit.SECOND,
}:
previous_count = state.previous_timestamp.get(time_interval.unit)
count = state.timestamp.get(time_interval.unit)
# If the eval_interval is a duration, we will track progress in terms of the unit of max_duration
Expand All @@ -123,7 +137,8 @@ def check_interval(state: State, event: Event):
count = state.timestamp.get(state.max_duration.unit)
else:
raise NotImplementedError(
f'Unknown interval: {time_interval.unit}. Must be TimeUnit.ITERATION, TimeUnit.EPOCH, TimeUnit.BATCH, TimeUnit.TOKEN, or TimeUnit.SAMPLE.',
f'Unknown interval: {time_interval.unit}. Must be TimeUnit.ITERATION, TimeUnit.EPOCH, TimeUnit.BATCH, TimeUnit.TOKEN, ' +\
'TimeUnit.SAMPLE, TimeUnit.SECOND',
)

threshold_passed = math.floor(previous_count / time_interval.value) != math.floor(count / time_interval.value)
Expand Down
8 changes: 7 additions & 1 deletion docs/source/trainer/time.rst
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,18 @@ can be provided as a string:
"Samples", ``"sp"``, ``"2048sp"``, :attr:`.TimeUnit.SAMPLE`
"Tokens", ``"tok"``, ``"93874tok"``, :attr:`.TimeUnit.TOKEN`
"Duration", ``"dur"``, ``"0.7dur"``, :attr:`.TimeUnit.DURATION`
"Seconds", ``"sec"``, ``"30sec"``, :attr:`.TimeUnit.SECOND`

Duration is defined as a multiplier of the ``max_duration``.

These above string inputs are valid when an argument accepts the |Time|
type. There are some exceptions -- for example ``dur`` is not valid when
setting ``max_duration`` as that is circular.
setting ``max_duration`` as that is circular and seconds cannot be used
for schedulers and ``max_duration``.

Using timedelta strings are also supported and will be converted into
seconds in the |Time| class. For instance, something like 1h20m40s is
supported and will be converted to Time(4840, TimeUnit.SECOND).

Users can also specify milestones for objects such as learning rate schedulers
in units of ``duration``, e.g. ``0.1dur``. This makes it easy to build recipes
Expand Down
6 changes: 4 additions & 2 deletions tests/test_time.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
['4_000tok', 4000, TimeUnit.TOKEN],
['4_00_0tok', 4000, TimeUnit.TOKEN],
['0.5dur', 0.5, TimeUnit.DURATION],
['1h20m40s', 4840, TimeUnit.SECOND],
],
)
def test_time_parse(time_string: str, expected_value: int, expected_unit: TimeUnit):
Expand All @@ -37,6 +38,7 @@ def test_time_parse(time_string: str, expected_value: int, expected_unit: TimeUn
['3sp', Time(3, TimeUnit.SAMPLE)],
['4tok', Time(4, TimeUnit.TOKEN)],
['0.5dur', Time(0.5, TimeUnit.DURATION)],
['6sec', Time(6, TimeUnit.SECOND)],
],
)
def test_to_timestring(expected_timestring: str, time: Time):
Expand Down Expand Up @@ -254,12 +256,12 @@ def test_timestamp_repr():
assert timestamp == eval(repr(timestamp))


@pytest.mark.parametrize('time_string', ['1.1iter', '1.5ep', '2.1ba', '3.2sp', '3.4tok'])
@pytest.mark.parametrize('time_string', ['1.1iter', '1.5ep', '2.1ba', '3.2sp', '3.4tok', '0.1sec'])
def test_timestep_bad_strings(time_string: str):
with pytest.raises(TypeError):
Time.from_timestring(time_string)


@pytest.mark.parametrize('time_string', ['0.5dur', '1.0iter', '2.0ep', '3.000ba', '030.0sp'])
@pytest.mark.parametrize('time_string', ['0.5dur', '1.0iter', '2.0ep', '3.000ba', '030.0sp', '30sec'])
def test_timestep_valid_strings(time_string: str):
Time.from_timestring(time_string)
45 changes: 30 additions & 15 deletions tests/utils/test_misc.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# Copyright 2022 MosaicML Composer authors
# Copyright 2024 MosaicML Composer authors
# SPDX-License-Identifier: Apache-2.0

import datetime

import pytest

from composer.core import Event, Time, Timestamp
Expand All @@ -9,14 +11,22 @@

class DummyState:

def __init__(self, current_batches: int, max_duration: str, dataloader_len: str):
self.previous_timestamp = Timestamp(batch=current_batches - 1)
self.timestamp = Timestamp(batch=current_batches)
def __init__(self, current_batches: int, max_duration: str, dataloader_len: str, seconds_per_batch: int):
self.previous_timestamp = Timestamp(
batch=current_batches - 1,
total_wct=datetime.timedelta(seconds=(current_batches - 1) * seconds_per_batch),
)
self.timestamp = Timestamp(
batch=current_batches - 1,
total_wct=datetime.timedelta(seconds=current_batches * seconds_per_batch),
)
self.max_duration = Time.from_timestring(max_duration)
self.dataloader_len = Time.from_timestring(dataloader_len)
self.seconds_per_batch = seconds_per_batch
self.total_elapsed_time = datetime.timedelta(seconds=current_batches * seconds_per_batch)

def get_elapsed_duration(self):
return 0
return self.total_elapsed_time.total_seconds() / self.max_duration.value


def test_partial_format():
Expand All @@ -38,27 +48,32 @@ def test_partial_format():


@pytest.mark.parametrize(
'interval,current_batches,max_duration,dataloader_len,expected',
'interval,current_batches,max_duration,dataloader_len,seconds_per_batch,expected',
[
('0.25dur', 1, '1ep', '1ba', True),
('0.25dur', 1, '1ep', '4ba', True),
('0.25dur', 2, '1ep', '5ba', True),
('0.25dur', 1, '1ep', '5ba', False),
('0.25dur', 1, '1ba', '1ba', True),
('0.25dur', 1, '4ba', '4ba', True),
('0.25dur', 2, '5ba', '5ba', True),
('0.25dur', 1, '5ba', '5ba', False),
('0.25dur', 1, '1ep', '1ba', 10, True),
('0.25dur', 1, '1ep', '4ba', 10, True),
('0.25dur', 2, '1ep', '5ba', 10, True),
('0.25dur', 1, '1ep', '5ba', 10, True),
('0.25dur', 1, '1ba', '1ba', 10, True),
('0.25dur', 1, '4ba', '4ba', 10, True),
('0.25dur', 2, '5ba', '5ba', 10, True),
('0.25dur', 1, '5ba', '5ba', 10, True),
('10sec', 1, '6ba', '1ba', 10, True),
('10sec', 5, '6ba', '1ba', 10, True),
('10sec', 6, '6ba', '1ba', 10, True),
('20sec', 2, '6ba', '1ba', 1, False),
],
)
def test_interval_scheduler(
interval: str,
current_batches: int,
max_duration: str,
dataloader_len: str,
seconds_per_batch: int,
expected: bool,
):
interval_scheduler = create_interval_scheduler(interval)
dummy_state = DummyState(current_batches, max_duration, dataloader_len)
dummy_state = DummyState(current_batches, max_duration, dataloader_len, seconds_per_batch)

event = Event.BATCH_CHECKPOINT

Expand Down

0 comments on commit cf56030

Please sign in to comment.