Relaxing divisibility constraints on num_canonical_nodes and num_physical_nodes #476

Merged
10 commits merged on Oct 26, 2023
2 changes: 1 addition & 1 deletion streaming/base/batching/per_stream.py
@@ -57,7 +57,7 @@ def generate_work_per_stream_batching(dataset: StreamingDataset, world: World, e
         stream_partition = get_partitions(dataset.partition_algo, samples_in_stream,
                                           dataset.num_canonical_nodes, world.num_nodes,
                                           world.ranks_per_node, world.workers_per_rank, batch_size,
-                                          0)
+                                          0, dataset.initial_physical_nodes)
         if dataset.shuffle:
             # Ratio of stream's shuffle block size to overall shuffle block size should be the
             # same as the ratio of the stream's samples to overall samples.
3 changes: 2 additions & 1 deletion streaming/base/batching/random.py
@@ -53,7 +53,8 @@ def generate_work_random_batching(dataset: StreamingDataset, world: World, epoch
     # batch) such that we have an elastically deterministic sample order.
     big_ids = get_partitions(dataset.partition_algo, dataset.epoch_size,
                              dataset.num_canonical_nodes, world.num_nodes, world.ranks_per_node,
-                             world.workers_per_rank, dataset.batch_size, sample_in_epoch)
+                             world.workers_per_rank, dataset.batch_size, sample_in_epoch,
+                             dataset.initial_physical_nodes)
 
     # If we need to shuffle, shuffle in a node-aware and *underlying* shard-aware way.
     if dataset.shuffle:
3 changes: 2 additions & 1 deletion streaming/base/batching/stratified.py
@@ -68,7 +68,8 @@ def generate_work_stratified_batching(dataset: StreamingDataset, world: World, e
         # We also handle used samples (drop_first) at the end.
         stream_partition = get_partitions(dataset.partition_algo, samples_in_stream,
                                           dataset.num_canonical_nodes, 1, world.ranks_per_node,
-                                          world.workers_per_rank, 1, 0)
+                                          world.workers_per_rank, 1, 0,
+                                          dataset.initial_physical_nodes)
         if dataset.shuffle:
             # Ratio of stream's shuffle block size to overall shuffle block size should be the
             # same as the ratio of the stream's samples to overall samples.
19 changes: 18 additions & 1 deletion streaming/base/dataset.py
@@ -341,6 +341,10 @@ def __init__(self,
         self.shuffle_block_size = shuffle_block_size
         self.batching_method = batching_method
 
+        # Initialize initial_physical_nodes to None. If we are resuming, then we will set it to the
+        # number of physical nodes of the initial run in the _resume function.
+        self.initial_physical_nodes = None
+
         # Check streams vs remote/local.
         if bool(streams) == (bool(remote) or bool(local)):
             raise ValueError(
@@ -678,6 +682,9 @@ def _resume(self, world: World, epoch: int) -> Tuple[int, int]:
         sample_in_epoch = obj['sample_in_epoch']
         self.num_canonical_nodes = obj['num_canonical_nodes']
         self.shuffle_seed = obj['shuffle_seed']
+        # Ensure that we are backwards compatible with old checkpoint dataset state, since the
+        # 'initial_physical_nodes' key may not be present.
+        self.initial_physical_nodes = obj.get('initial_physical_nodes', None)
         self._set_predownload()
 
         return epoch, sample_in_epoch
@@ -732,11 +739,21 @@ def state_dict(self, num_samples: int, from_beginning: bool) -> Dict[str, Any]:
             sample_in_epoch = num_samples
         else:
             sample_in_epoch = offset + num_samples
+
+        if self.initial_physical_nodes is None:
+            # In this case, we are running for the first time, so we set initial_physical_nodes
+            # to the current number of physical nodes.
+            initial_physical_nodes = world.num_nodes
+        else:
+            # In this case, initial_physical_nodes has already been set from an initial run. We
+            # keep this value persisted in the state across the total run duration.
+            initial_physical_nodes = self.initial_physical_nodes
         return {
             'epoch': epoch,
             'sample_in_epoch': sample_in_epoch,
             'num_canonical_nodes': self.num_canonical_nodes,
-            'shuffle_seed': self.shuffle_seed
+            'shuffle_seed': self.shuffle_seed,
+            'initial_physical_nodes': initial_physical_nodes,
         }
 
     def load_state_dict(self, obj: Dict[str, Any]) -> None:
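For context, a minimal sketch of the state round trip these changes enable. All values are hypothetical; only the dict keys come from the diff above.

    # Hypothetical values throughout; keys match the state_dict() diff above.
    # A run that started on 4 physical nodes captures that count at first save:
    state = {
        'epoch': 2,
        'sample_in_epoch': 40960,
        'num_canonical_nodes': 8,
        'shuffle_seed': 9176,
        'initial_physical_nodes': 4,  # world.num_nodes at the initial run
    }

    # On resumption (even resized to, say, 6 nodes), _resume() reads it back.
    # Old checkpoints lack the key, so .get() falls back to None and behavior
    # matches pre-PR partitioning.
    initial_physical_nodes = state.get('initial_physical_nodes', None)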
9 changes: 7 additions & 2 deletions streaming/base/partition/__init__.py
@@ -9,9 +9,11 @@
 from numpy.typing import NDArray
 
 from streaming.base.partition.orig import get_partitions_orig
+from streaming.base.partition.relaxed import get_partitions_relaxed
 
 algos = {
     'orig': get_partitions_orig,
+    'relaxed': get_partitions_relaxed,
 }
 
 
@@ -22,7 +24,8 @@ def get_partitions(algo: str,
                    ranks_per_node: int,
                    workers_per_rank: int,
                    batch_size: Optional[int] = None,
-                   drop_first: int = 0) -> NDArray[np.int64]:
+                   drop_first: int = 0,
+                   initial_physical_nodes: Optional[int] = None) -> NDArray[np.int64]:
     """Partition the given number of samples to nodes, ranks, and workers.
 
     Either canonical or physical nodes must be evenly divisible by the other.
@@ -41,11 +44,13 @@ def get_partitions(algo: str,
         batch_size (int, optional): Batch size of its DataLoader, which affects how the dataset is
             partitioned over the workers. Defaults to ``None``.
         drop_first (int): Number of samples seen already, which are dropped. Defaults to ``0``.
+        initial_physical_nodes (int, optional): Number of physical nodes at the start of training.
+            Defaults to ``None``.
 
     Returns:
         NDArray[np.int64]: Partitions of shape (physical nodes, ranks per node, workers per rank,
             batches per worker, batch size).
     """
     get = algos[algo]
     return get(num_samples, num_canonical_nodes, num_physical_nodes, ranks_per_node,
-               workers_per_rank, batch_size, drop_first)
+               workers_per_rank, batch_size, drop_first, initial_physical_nodes)
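As a usage sketch, the new algorithm is selected through the same dispatcher. The argument values below are illustrative only, not from the PR:

    from streaming.base.partition import get_partitions

    # Resumption scenario: training started on 4 nodes and resumed on 3. NCN=8 and
    # PN=3 do not divide one another, which 'orig' would reject, but 'relaxed' only
    # needs the global batch size (3 * 8 * 4 = 96) to be divisible by the initial
    # device count (4 * 8 = 32).
    ids = get_partitions('relaxed',
                         num_samples=1152,
                         num_canonical_nodes=8,
                         num_physical_nodes=3,
                         ranks_per_node=8,
                         workers_per_rank=2,
                         batch_size=4,
                         drop_first=0,
                         initial_physical_nodes=4)
    print(ids.shape)  # (3, 8, 2, 6, 4): nodes, ranks, workers, batches per worker, batch size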
5 changes: 4 additions & 1 deletion streaming/base/partition/orig.py
@@ -19,7 +19,8 @@ def get_partitions_orig(num_samples: int,
                         ranks_per_node: int,
                         workers_per_rank: int,
                         batch_size: Optional[int] = None,
-                        drop_first: int = 0) -> NDArray[np.int64]:
+                        drop_first: int = 0,
+                        initial_physical_nodes: Optional[int] = None) -> NDArray[np.int64]:
     """Partition the given number of samples to nodes, ranks, and workers.
 
     Either canonical or physical nodes must be evenly divisible by the other.
@@ -37,6 +38,8 @@ def get_partitions_orig(num_samples: int,
         batch_size (int, optional): Batch size of its DataLoader, which affects how the dataset is
             partitioned over the workers. Defaults to ``None``.
         drop_first (int): Number of samples seen already, which are dropped. Defaults to ``0``.
+        initial_physical_nodes (int, optional): Number of physical nodes at the start of training.
+            Defaults to ``None``.
 
     Returns:
         NDArray[np.int64]: Partitions of shape (physical nodes, ranks per node, workers per rank,
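A note on why get_partitions_orig gains a parameter it does not appear to use: both algorithms must share one calling convention so that get_partitions can pass initial_physical_nodes unconditionally. A hedged contrast with toy numbers (this assumes orig raises ValueError on the divisibility violation, per its docstring):

    from streaming.base.partition import get_partitions

    # NCN=4 and PN=3 divide neither way, so the original algorithm refuses...
    try:
        get_partitions('orig', 48, 4, 3, 2, 1, 2, 0)
    except ValueError as err:
        print(f'orig rejected NCN=4, PN=3: {err}')

    # ...while 'relaxed' accepts it, given where training started.
    ids = get_partitions('relaxed', 48, 4, 3, 2, 1, 2, 0, initial_physical_nodes=2)
    print(ids.shape)  # (3, 2, 1, 4, 2)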
98 changes: 98 additions & 0 deletions streaming/base/partition/relaxed.py
@@ -0,0 +1,98 @@
+# Copyright 2023 MosaicML Streaming authors
+# SPDX-License-Identifier: Apache-2.0
+
+"""Apportion shards/samples to nodes/ranks/workers for elastically deterministic sample order."""
+
+import logging
+from typing import Optional
+
+import numpy as np
+from numpy.typing import NDArray
+
+from streaming.base.partition.orig import get_partitions_orig
+
+logger = logging.getLogger(__name__)
+
+
+def get_partitions_relaxed(num_samples: int,
+                           num_canonical_nodes: int,
+                           num_physical_nodes: int,
+                           ranks_per_node: int,
+                           workers_per_rank: int,
+                           batch_size: Optional[int] = None,
+                           drop_first: int = 0,
+                           initial_physical_nodes: Optional[int] = None) -> NDArray[np.int64]:
+    """Partition the given number of samples to nodes, ranks, and workers.
+
+    Either canonical or physical nodes must be evenly divisible by the other when partitioning over
+    the initial number of physical nodes. For partitions during resumption, the only constraint
+    is that the global batch size, which remains constant during training, must be evenly divisible
+    by the total number of devices, which is num_physical_nodes * ranks_per_node.
+
+    It is suggested to set num_canonical_nodes higher than your expected number of physical nodes,
+    because scaling your number of nodes below that level may result in more shards being used
+    across node boundaries due to preserving the same global sample order.
+
+    Args:
+        num_samples (int): Dataset size.
+        num_canonical_nodes (int): Number of canonical nodes.
+        num_physical_nodes (int): Number of physical nodes.
+        ranks_per_node (int): Number of ranks per node.
+        workers_per_rank (int): Number of worker partitions per rank.
+        batch_size (int, optional): Batch size of its DataLoader, which affects how the dataset is
+            partitioned over the workers. Defaults to ``None``.
+        drop_first (int): Number of samples seen already, which are dropped. Defaults to ``0``.
+        initial_physical_nodes (int, optional): Number of physical nodes at the start of training.
+            Defaults to ``None``.
+
+    Returns:
+        NDArray[np.int64]: Partitions of shape (physical nodes, ranks per node, workers per rank,
+            batches per worker, batch size).
+    """
+    if num_samples <= drop_first:
+        raise ValueError(f'Resuming further into the dataset ({drop_first}) than it has samples ' +
+                         f'({num_samples})')
+
+    if initial_physical_nodes is None or (num_physical_nodes <= num_canonical_nodes and
+                                          num_canonical_nodes % num_physical_nodes == 0) or \
+       (num_physical_nodes > num_canonical_nodes and
+        num_physical_nodes % num_canonical_nodes == 0):
+        # Case 1: We are partitioning for the first time. Use the original partitions algorithm,
+        # which also requires that NCN be divisible by PN or vice versa.
+        # Case 2: PN <= NCN and PN evenly divides NCN. The original partition algo can be used,
+        # and will give better downloads per node as well.
+        # Case 3: PN > NCN and NCN evenly divides PN. The original partition algo can be used.
+        return get_partitions_orig(num_samples, num_canonical_nodes, num_physical_nodes,
+                                   ranks_per_node, workers_per_rank, batch_size, drop_first)
+    else:
+        batch_size = batch_size or 1
+        # First, make a partition over the initial number of physical nodes and device batch size.
+        # We assume that ranks_per_node and workers_per_rank stay constant during resumptions.
+        global_batch_size = num_physical_nodes * ranks_per_node * batch_size
+        initial_total_devices = initial_physical_nodes * ranks_per_node
+        # Check for divisibility of the current global batch size and the initial total devices.
+        # This should be true since the global batch size should not change in the middle of
+        # training.
+        if global_batch_size % initial_total_devices != 0:
+            raise ValueError(f'A global batch size of {global_batch_size} is not evenly ' +
+                             f'divisible by the initial total number of devices of ' +
+                             f'{initial_total_devices}. Make sure that when using ' +
+                             f'the `relaxed` partitioning algorithm, the global batch size does ' +
+                             f'not change during resumption of training.')
+        initial_batch_size = global_batch_size // initial_total_devices
+        partition = get_partitions_orig(num_samples, num_canonical_nodes, initial_physical_nodes,
+                                        ranks_per_node, workers_per_rank, initial_batch_size,
+                                        drop_first)
+
+        # Flatten the initial partition in order of traversal.
+        # partition was originally (nodes, ranks, workers, batches per worker, batch size)
+        # in-order, the dimensions are (batches per worker, workers, nodes, ranks, batch size)
+        partition = partition.transpose(3, 2, 0, 1, 4).flatten()
+
+        # Reshape the in-order traversal of the partition to the new physical nodes and batch size.
+        partition = partition.reshape(-1, workers_per_rank, num_physical_nodes, ranks_per_node,
+                                      batch_size)
+
+        # Re-transpose this partition matrix back to the original format below and return it:
+        # (physical nodes, ranks per node, workers per rank, batches per worker, batch size)
+        return partition.transpose(2, 3, 1, 0, 4)
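To make the flatten-and-reshape step concrete, a small self-check with toy sizes (reusing the numbers from the note under orig.py above; not from the PR itself). A job that starts on 2 nodes and resumes on 3, with the global batch size held at 12, should traverse samples in the same global order before and after resizing:

    import numpy as np

    from streaming.base.partition import get_partitions

    # Initial run: 48 samples, NCN=4, 2 nodes x 2 ranks x 1 worker, device batch 3
    # (global batch 2 * 2 * 3 = 12).
    initial = get_partitions('orig', 48, 4, 2, 2, 1, 3, 0)

    # Resumption: 3 nodes x 2 ranks x 1 worker, device batch 2 (global batch
    # 3 * 2 * 2 = 12, unchanged).
    resumed = get_partitions('relaxed', 48, 4, 3, 2, 1, 2, 0, initial_physical_nodes=2)

    # In-order traversal is (batches per worker, workers, nodes, ranks, batch size),
    # the same transpose get_partitions_relaxed uses internally.
    order_initial = initial.transpose(3, 2, 0, 1, 4).flatten()
    order_resumed = resumed.transpose(3, 2, 0, 1, 4).flatten()

    # The global sample order is preserved; only its slicing across nodes changed.
    assert np.array_equal(order_initial, order_resumed)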