From 46616537f7d555b8bfcae945527e512c44e6f608 Mon Sep 17 00:00:00 2001
From: janEbert
Date: Tue, 26 Nov 2024 14:13:41 +0100
Subject: [PATCH] Expose `DistributedSampler` RNG seed argument

Relies on https://github.com/mosaicml/composer/pull/3724.
---
 llmfoundry/data/finetuning/dataloader.py | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/llmfoundry/data/finetuning/dataloader.py b/llmfoundry/data/finetuning/dataloader.py
index 661729ff8a..bd130bf77c 100644
--- a/llmfoundry/data/finetuning/dataloader.py
+++ b/llmfoundry/data/finetuning/dataloader.py
@@ -63,6 +63,7 @@ def build_finetuning_dataloader(
     prefetch_factor: int = 2,
     persistent_workers: bool = True,
     timeout: int = 0,
+    shuffle_seed: int = 0,
 ) -> DataSpec:
     """Builds a finetuning dataloader for training or evaluating.
 
@@ -168,6 +169,9 @@ def build_finetuning_dataloader(
         timeout (int, optional): If positive, the timeout value for collecting a batch
             from workers. Should always be non-negative. The default is 0. This argument
             is passed directly to the pytorch :class:`DataLoader`.
+        shuffle_seed (int, optional): Initialization value for the random number generator of the
+            distributed sampler. Only relevant if `dataset.shuffle=True`. The default is 0. This
+            argument is passed directly to the PyTorch :class:`DistributedSampler`.
 
     See :class:`DataLoader` for standard argument options to the pytorch dataloader, such as
     `drop_last`, `num_workers`, etc.
@@ -336,6 +340,7 @@
             replication_factor if replication_factor > 1 else None,
             rank=dist.get_global_rank() // replication_factor
             if replication_factor > 1 else None,
+            seed=shuffle_seed,
         )
 
     assert streaming_dataset is not None  # for pyright
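
Usage sketch: the call below shows how the new argument would be passed
through. It assumes the `build_finetuning_dataloader` signature at this
commit (leading arguments `tokenizer`, `device_batch_size`, `dataset`,
`num_workers`), and the tokenizer name and dataset config values are
placeholders, not part of the patch; only `shuffle_seed` is introduced
here. Per the linked composer PR, the seed should reach the underlying
:class:`DistributedSampler` through composer's `dist.get_sampler`.

    from transformers import AutoTokenizer

    from llmfoundry.data.finetuning.dataloader import \
        build_finetuning_dataloader

    # Placeholder tokenizer; any PreTrainedTokenizerBase works.
    tokenizer = AutoTokenizer.from_pretrained('mosaicml/mpt-7b')

    data_spec = build_finetuning_dataloader(
        tokenizer=tokenizer,
        device_batch_size=8,
        dataset={
            'hf_name': 'tatsu-lab/alpaca',  # placeholder dataset config
            'split': 'train',
            'max_seq_len': 2048,
            'shuffle': True,  # shuffle_seed only matters when shuffling
        },
        num_workers=4,
        shuffle_seed=1234,  # new: forwarded to DistributedSampler(seed=...)
    )

Runs that share a `shuffle_seed` should see the same per-epoch sample
order from the distributed sampler, which makes shuffling reproducible
across restarts.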