From cf56030d7186fc398734f89c636e71b71274b2d9 Mon Sep 17 00:00:00 2001 From: Vincent Chen Date: Fri, 24 May 2024 14:33:57 -0700 Subject: [PATCH] Wct save interval (#3264) * wct time * pre-commit * timeunit * checker * precommit * test case * test * adding minute and hour * syntax * precommit * precommit * time conversion * syntax * precommit * timedelta compatibility * typo * no wct for trainer and scheduler * precommit * robust assertion * robust assertion * precommit * resolving doc comments * precommit * added testcases * precommit * precommit * change asserts to errors * update parsing * precommit * datetime * datetime * precommit * debug * parsable * force int * typo * typo * debug * debug * mihir refactor * precommit * bool to str * mihir nit * precommit * rerun pr rev * rerun pr rev --------- Co-authored-by: Mihir Patel Co-authored-by: Daniel King <43149077+dakinggg@users.noreply.github.com> --- composer/core/time.py | 55 +++++++++++++++++++++++++++++++++++- composer/optim/scheduler.py | 3 +- composer/trainer/trainer.py | 4 +++ composer/utils/misc.py | 23 ++++++++++++--- docs/source/trainer/time.rst | 8 +++++- tests/test_time.py | 6 ++-- tests/utils/test_misc.py | 45 +++++++++++++++++++---------- 7 files changed, 120 insertions(+), 24 deletions(-) diff --git a/composer/core/time.py b/composer/core/time.py index f05b521614..1fe34cad63 100644 --- a/composer/core/time.py +++ b/composer/core/time.py @@ -27,6 +27,30 @@ __all__ = ['TimeUnit', 'Time', 'Timestamp', 'ensure_time'] +def verify_wct(timestamp: str) -> str: + """Return a valid datetime formated wct timestamp if input is a valid wct. + + Args: + timestamp (str): A string that represents a timestamp in wct. + + Returns: + str: a properly formatted datetime if input is valid else None + """ + if 'h' not in timestamp: + timestamp = '0h' + timestamp + if 'm' not in timestamp: + timestamp = timestamp.replace('h', 'h0m') + if 's' not in timestamp: + timestamp = timestamp + '0s' + + pattern = r'^(\d+h)?(\d+m)?(\d+s)?$' + match = re.match(pattern, timestamp) + if bool(match): + return timestamp + else: + raise ValueError(f'{timestamp} was passed in, which does not fit XXhYYmZZs formatting') + + class TimeUnit(StringEnum): """Enum class to represent units of time for the training process. @@ -44,6 +68,7 @@ class TimeUnit(StringEnum): SAMPLE = 'sp' TOKEN = 'tok' DURATION = 'dur' + SECOND = 'sec' # regex for parsing time string, matches timeunit and chars prior to unit as value @@ -212,6 +237,25 @@ def from_duration(cls, duration: float) -> Time: """ return cls(duration, TimeUnit.DURATION) + @classmethod + def from_timedelta(cls, timestring: str) -> Time: + """Create a :class:`Time` with units of :attr:`TimeUnit.SECOND`. + + Equivalent to ``Time(batch, TimeUnit.SECOND)``. + + Args: + timestring (int): timedelta string in _h_m_s. + + Returns: + Time: :class:`Time` instance, in seconds. + """ + # Convert timestring to be strptime parsable + verified_wct = verify_wct(timestring) + time_struct = datetime.datetime.strptime(verified_wct, '%Hh%Mm%Ss') + delta = datetime.timedelta(hours=time_struct.hour, minutes=time_struct.minute, seconds=time_struct.second) + total_seconds = delta.total_seconds() + return cls(int(total_seconds), TimeUnit.SECOND) + @property def value(self) -> TValue: """The value of the time, as a number.""" @@ -392,6 +436,12 @@ def from_timestring(cls, timestring: str) -> Time: Returns: Time: An instance of :class:`Time`. """ + # Handle TimeDelta matching first + try: + return Time.from_timedelta(timestring) + except ValueError: + pass + match = _TIME_STR_REGEX.findall(timestring) if len(match) != 1: raise ValueError(f'Invalid time string: {timestring}') @@ -647,6 +697,8 @@ def get(self, unit: Union[str, TimeUnit]) -> Time[int]: return self.sample if unit == TimeUnit.TOKEN: return self.token + if unit == TimeUnit.SECOND: + return Time(int(self._total_wct.total_seconds()) if self._total_wct else 0, TimeUnit.SECOND) raise ValueError(f'Invalid unit: {unit}') def _parse(self, other: Union[int, float, Time, str]) -> Time: @@ -944,4 +996,5 @@ def ensure_time(maybe_time: Union[Time, str, int], int_unit: Union[TimeUnit, str Returns: Time: An instance of :class:`.Time`. """ - return Time.from_input(maybe_time, int_unit) + time_obj = Time.from_input(maybe_time, int_unit) + return time_obj diff --git a/composer/optim/scheduler.py b/composer/optim/scheduler.py index 8064f196e3..3d98cfa12e 100644 --- a/composer/optim/scheduler.py +++ b/composer/optim/scheduler.py @@ -138,7 +138,8 @@ def __call__(self, state: State, ssr: float = 1.0) -> float: def _convert_time(time: Union[str, Time[int], Time[float]], state: State, ssr: float = 1.0) -> Time[int]: if isinstance(time, str): time = Time.from_timestring(time) - + if time.unit == TimeUnit.SECOND: + raise ValueError('Wall clock time not an allowed time unit.') assert state.max_duration is not None, 'max_duration should be set whenever schedulers are invoked' if time.unit == TimeUnit.DURATION: diff --git a/composer/trainer/trainer.py b/composer/trainer/trainer.py index 2b363d9db0..a09a33018b 100644 --- a/composer/trainer/trainer.py +++ b/composer/trainer/trainer.py @@ -1574,6 +1574,8 @@ def __init__( # Max Duration if max_duration is not None: self.state.max_duration = ensure_time(max_duration, TimeUnit.EPOCH) + if self.state.max_duration.unit == TimeUnit.SECOND: + raise ValueError('Wall clock time not an allowed time unit.') self.logger.log_hyperparameters({'rank_zero_seed': rank_zero_seed}) @@ -2157,6 +2159,8 @@ def fit( # Max Duration if duration is not None: duration = ensure_time(duration, TimeUnit.EPOCH) + if duration.unit == TimeUnit.SECOND: + raise ValueError('Wall clock time not an allowed time unit.') # Effectively increment the max duration (if not resetting the Time) # or set the max_duration (if resetting the time -- self.state.timestamp.get(duration.unit) will be 0) # It is important to set the duration, rather than incrementing it, as ``duration`` could be in diff --git a/composer/utils/misc.py b/composer/utils/misc.py index 31480a11b6..df5a14214e 100644 --- a/composer/utils/misc.py +++ b/composer/utils/misc.py @@ -87,11 +87,18 @@ def create_interval_scheduler( interval_event = Event.EPOCH_CHECKPOINT if checkpoint_events else Event.EPOCH_END elif time_interval.unit == TimeUnit.ITERATION: interval_event = Event.ITERATION_CHECKPOINT if checkpoint_events else Event.ITERATION_END - elif time_interval.unit in {TimeUnit.BATCH, TimeUnit.TOKEN, TimeUnit.SAMPLE, TimeUnit.DURATION}: + elif time_interval.unit in { + TimeUnit.BATCH, + TimeUnit.TOKEN, + TimeUnit.SAMPLE, + TimeUnit.DURATION, + TimeUnit.SECOND, + }: interval_event = Event.BATCH_CHECKPOINT if checkpoint_events else Event.BATCH_END else: raise NotImplementedError( - f'Unknown interval: {time_interval.unit}. Must be TimeUnit.ITERATION, TimeUnit.EPOCH, TimeUnit.BATCH, TimeUnit.TOKEN, or TimeUnit.SAMPLE.', + f'Unknown interval: {time_interval.unit}. Must be TimeUnit.ITERATION, TimeUnit.EPOCH, TimeUnit.BATCH, TimeUnit.TOKEN, ' +\ + 'TimeUnit.SAMPLE, TimeUnit.SECOND', ) last_batch_seen = -1 @@ -113,7 +120,14 @@ def check_interval(state: State, event: Event): if include_end_of_training and event in final_events and elapsed_duration >= 1.0 and state.timestamp.batch != last_batch_seen: return True - if time_interval.unit in {TimeUnit.ITERATION, TimeUnit.EPOCH, TimeUnit.BATCH, TimeUnit.TOKEN, TimeUnit.SAMPLE}: + if time_interval.unit in { + TimeUnit.ITERATION, + TimeUnit.EPOCH, + TimeUnit.BATCH, + TimeUnit.TOKEN, + TimeUnit.SAMPLE, + TimeUnit.SECOND, + }: previous_count = state.previous_timestamp.get(time_interval.unit) count = state.timestamp.get(time_interval.unit) # If the eval_interval is a duration, we will track progress in terms of the unit of max_duration @@ -123,7 +137,8 @@ def check_interval(state: State, event: Event): count = state.timestamp.get(state.max_duration.unit) else: raise NotImplementedError( - f'Unknown interval: {time_interval.unit}. Must be TimeUnit.ITERATION, TimeUnit.EPOCH, TimeUnit.BATCH, TimeUnit.TOKEN, or TimeUnit.SAMPLE.', + f'Unknown interval: {time_interval.unit}. Must be TimeUnit.ITERATION, TimeUnit.EPOCH, TimeUnit.BATCH, TimeUnit.TOKEN, ' +\ + 'TimeUnit.SAMPLE, TimeUnit.SECOND', ) threshold_passed = math.floor(previous_count / time_interval.value) != math.floor(count / time_interval.value) diff --git a/docs/source/trainer/time.rst b/docs/source/trainer/time.rst index 84013bcaa8..1ae75b494e 100644 --- a/docs/source/trainer/time.rst +++ b/docs/source/trainer/time.rst @@ -18,12 +18,18 @@ can be provided as a string: "Samples", ``"sp"``, ``"2048sp"``, :attr:`.TimeUnit.SAMPLE` "Tokens", ``"tok"``, ``"93874tok"``, :attr:`.TimeUnit.TOKEN` "Duration", ``"dur"``, ``"0.7dur"``, :attr:`.TimeUnit.DURATION` + "Seconds", ``"sec"``, ``"30sec"``, :attr:`.TimeUnit.SECOND` Duration is defined as a multiplier of the ``max_duration``. These above string inputs are valid when an argument accepts the |Time| type. There are some exceptions -- for example ``dur`` is not valid when -setting ``max_duration`` as that is circular. +setting ``max_duration`` as that is circular and seconds cannot be used +for schedulers and ``max_duration``. + +Using timedelta strings are also supported and will be converted into +seconds in the |Time| class. For instance, something like 1h20m40s is +supported and will be converted to Time(4840, TimeUnit.SECOND). Users can also specify milestones for objects such as learning rate schedulers in units of ``duration``, e.g. ``0.1dur``. This makes it easy to build recipes diff --git a/tests/test_time.py b/tests/test_time.py index 75b8d69e5f..b5fad369d9 100644 --- a/tests/test_time.py +++ b/tests/test_time.py @@ -20,6 +20,7 @@ ['4_000tok', 4000, TimeUnit.TOKEN], ['4_00_0tok', 4000, TimeUnit.TOKEN], ['0.5dur', 0.5, TimeUnit.DURATION], + ['1h20m40s', 4840, TimeUnit.SECOND], ], ) def test_time_parse(time_string: str, expected_value: int, expected_unit: TimeUnit): @@ -37,6 +38,7 @@ def test_time_parse(time_string: str, expected_value: int, expected_unit: TimeUn ['3sp', Time(3, TimeUnit.SAMPLE)], ['4tok', Time(4, TimeUnit.TOKEN)], ['0.5dur', Time(0.5, TimeUnit.DURATION)], + ['6sec', Time(6, TimeUnit.SECOND)], ], ) def test_to_timestring(expected_timestring: str, time: Time): @@ -254,12 +256,12 @@ def test_timestamp_repr(): assert timestamp == eval(repr(timestamp)) -@pytest.mark.parametrize('time_string', ['1.1iter', '1.5ep', '2.1ba', '3.2sp', '3.4tok']) +@pytest.mark.parametrize('time_string', ['1.1iter', '1.5ep', '2.1ba', '3.2sp', '3.4tok', '0.1sec']) def test_timestep_bad_strings(time_string: str): with pytest.raises(TypeError): Time.from_timestring(time_string) -@pytest.mark.parametrize('time_string', ['0.5dur', '1.0iter', '2.0ep', '3.000ba', '030.0sp']) +@pytest.mark.parametrize('time_string', ['0.5dur', '1.0iter', '2.0ep', '3.000ba', '030.0sp', '30sec']) def test_timestep_valid_strings(time_string: str): Time.from_timestring(time_string) diff --git a/tests/utils/test_misc.py b/tests/utils/test_misc.py index 8359f8ce62..c4992a21a4 100644 --- a/tests/utils/test_misc.py +++ b/tests/utils/test_misc.py @@ -1,6 +1,8 @@ -# Copyright 2022 MosaicML Composer authors +# Copyright 2024 MosaicML Composer authors # SPDX-License-Identifier: Apache-2.0 +import datetime + import pytest from composer.core import Event, Time, Timestamp @@ -9,14 +11,22 @@ class DummyState: - def __init__(self, current_batches: int, max_duration: str, dataloader_len: str): - self.previous_timestamp = Timestamp(batch=current_batches - 1) - self.timestamp = Timestamp(batch=current_batches) + def __init__(self, current_batches: int, max_duration: str, dataloader_len: str, seconds_per_batch: int): + self.previous_timestamp = Timestamp( + batch=current_batches - 1, + total_wct=datetime.timedelta(seconds=(current_batches - 1) * seconds_per_batch), + ) + self.timestamp = Timestamp( + batch=current_batches - 1, + total_wct=datetime.timedelta(seconds=current_batches * seconds_per_batch), + ) self.max_duration = Time.from_timestring(max_duration) self.dataloader_len = Time.from_timestring(dataloader_len) + self.seconds_per_batch = seconds_per_batch + self.total_elapsed_time = datetime.timedelta(seconds=current_batches * seconds_per_batch) def get_elapsed_duration(self): - return 0 + return self.total_elapsed_time.total_seconds() / self.max_duration.value def test_partial_format(): @@ -38,16 +48,20 @@ def test_partial_format(): @pytest.mark.parametrize( - 'interval,current_batches,max_duration,dataloader_len,expected', + 'interval,current_batches,max_duration,dataloader_len,seconds_per_batch,expected', [ - ('0.25dur', 1, '1ep', '1ba', True), - ('0.25dur', 1, '1ep', '4ba', True), - ('0.25dur', 2, '1ep', '5ba', True), - ('0.25dur', 1, '1ep', '5ba', False), - ('0.25dur', 1, '1ba', '1ba', True), - ('0.25dur', 1, '4ba', '4ba', True), - ('0.25dur', 2, '5ba', '5ba', True), - ('0.25dur', 1, '5ba', '5ba', False), + ('0.25dur', 1, '1ep', '1ba', 10, True), + ('0.25dur', 1, '1ep', '4ba', 10, True), + ('0.25dur', 2, '1ep', '5ba', 10, True), + ('0.25dur', 1, '1ep', '5ba', 10, True), + ('0.25dur', 1, '1ba', '1ba', 10, True), + ('0.25dur', 1, '4ba', '4ba', 10, True), + ('0.25dur', 2, '5ba', '5ba', 10, True), + ('0.25dur', 1, '5ba', '5ba', 10, True), + ('10sec', 1, '6ba', '1ba', 10, True), + ('10sec', 5, '6ba', '1ba', 10, True), + ('10sec', 6, '6ba', '1ba', 10, True), + ('20sec', 2, '6ba', '1ba', 1, False), ], ) def test_interval_scheduler( @@ -55,10 +69,11 @@ def test_interval_scheduler( current_batches: int, max_duration: str, dataloader_len: str, + seconds_per_batch: int, expected: bool, ): interval_scheduler = create_interval_scheduler(interval) - dummy_state = DummyState(current_batches, max_duration, dataloader_len) + dummy_state = DummyState(current_batches, max_duration, dataloader_len, seconds_per_batch) event = Event.BATCH_CHECKPOINT