diff --git a/src/zeroband/config.py b/src/zeroband/config.py index 4ee7caf0..17745a4a 100644 --- a/src/zeroband/config.py +++ b/src/zeroband/config.py @@ -1,5 +1,3 @@ - - from typing import Literal from pydantic import model_validator diff --git a/src/zeroband/data.py b/src/zeroband/data.py index c5cd133f..297e6a19 100644 --- a/src/zeroband/data.py +++ b/src/zeroband/data.py @@ -417,7 +417,7 @@ def load_all_datasets( split_world_size = world_size - logger.info(f"Loading Train dataset(s)") + logger.info("Loading Train dataset(s)") ds = _load_datasets( dataset_names=data_config.dataset_name_or_paths, diff --git a/src/zeroband/train.py b/src/zeroband/train.py index 27157e6a..25cd424a 100644 --- a/src/zeroband/train.py +++ b/src/zeroband/train.py @@ -17,7 +17,7 @@ from zeroband.comms import ElasticDeviceMesh from zeroband.loss import cross_entropy_max_z_loss from zeroband.models.llama.model import create_block_mask_from_seqlens -from zeroband.config import Config, MemoryProfilerConfig +from zeroband.config import Config #, MemoryProfilerConfig from zeroband.utils import ( FakeTokenizer,