Tensor Parallelism Tests (#3620)
Co-authored-by: Eitan Turok <[email protected]>
Co-authored-by: Mihir Patel <[email protected]>
3 people authored Oct 2, 2024
1 parent 3eda9cf commit f76c2ff
Showing 4 changed files with 398 additions and 21 deletions.
6 changes: 6 additions & 0 deletions tests/common/__init__.py
@@ -8,6 +8,7 @@
     InfiniteClassificationDataset,
     ParityDataset,
     RandomClassificationDataset,
+    RandomClassificationDatasetReplicated,
     RandomImageDataset,
     RandomSegmentationDataset,
     RandomTextClassificationDataset,
@@ -21,13 +22,15 @@
     EmbeddedWeightTiedModel,
     EmptyModel,
     EvenSimplerMLP,
+    SimpleComposerMLP,
     SimpleConvModel,
     SimpleMLP,
     SimpleModel,
     SimpleModelWithDropout,
     SimpleTransformerClassifier,
     SimpleTransformerMaskedLM,
     SimpleWeightTiedModel,
+    TPSimpleComposerMLP,
     ZeroModel,
     composer_resnet,
 )
@@ -42,6 +45,7 @@ def get_module_subclasses(module: types.ModuleType, cls: type) -> list[type]:
 __all__ = [
     'assert_state_equivalent',
     'RandomClassificationDataset',
+    'RandomClassificationDatasetReplicated',
     'RandomTextClassificationDataset',
     'RandomTextLMDataset',
     'RandomImageDataset',
@@ -67,4 +71,6 @@ def get_module_subclasses(module: types.ModuleType, cls: type) -> list[type]:
     'composer_resnet',
     'SimpleMLP',
     'EvenSimplerMLP',
+    'SimpleComposerMLP',
+    'TPSimpleComposerMLP',
 ]
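Both new helpers are re-exported from `tests.common`, so the TP tests can import them from one place. A minimal usage sketch (the argument values are illustrative, not taken from this commit):

    from tests.common import RandomClassificationDatasetReplicated, TPSimpleComposerMLP

    # Hypothetical shapes: an 8-feature, 3-class toy problem.
    model = TPSimpleComposerMLP(num_features=8, device='cpu')
    dataset = RandomClassificationDatasetReplicated(shape=(8,), size=16, num_classes=3)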
69 changes: 60 additions & 9 deletions tests/common/datasets.py
@@ -8,7 +8,7 @@
 from torch.utils.data import DataLoader, Dataset, IterableDataset
 from torchvision.datasets import VisionDataset
 
-from composer.utils import dist
+from composer.utils import dist, reproducibility
 from tests.common.models import configure_tiny_bert_tokenizer, configure_tiny_gpt2_tokenizer
 
 
@@ -55,12 +55,19 @@ class RandomClassificationDataset(Dataset):
         num_classes (int): number of classes (default: 2)
     """
 
-    def __init__(self, shape: Sequence[int] = (1, 1, 1), size: int = 100, num_classes: int = 2):
-        self.size = size
-        self.shape = shape
-        self.num_classes = num_classes
-        self.x = None
-        self.y = None
+    def __init__(
+        self,
+        shape: Sequence[int] = (1, 1, 1),
+        size: int = 100,
+        num_classes: int = 2,
+        device: Optional[torch.device] = None,
+    ):
+        self.size: int = size
+        self.shape: Sequence[int] = shape
+        self.num_classes: int = num_classes
+        self.device: Optional[torch.device] = device
+        self.x: Optional[torch.Tensor] = None
+        self.y: Optional[torch.Tensor] = None
 
     def __len__(self):
         return self.size
@@ -69,12 +76,56 @@ def __getitem__(self, index: int):
         # Note: lazily generate data so it runs after Composer seeds everything, giving the same
         # dataset across multiple calls when using the same seed.
         if self.x is None:
-            self.x = torch.randn(self.size, *self.shape)
+            self.x = torch.randn(
+                self.size,
+                *self.shape,
+                device=self.device,
+            )
         if self.y is None:
-            self.y = torch.randint(0, self.num_classes, size=(self.size,))
+            self.y = torch.randint(0, self.num_classes, size=(self.size,), device=self.device)
         return self.x[index], self.y[index]
 
 
+class RandomClassificationDatasetReplicated(RandomClassificationDataset):
+    """Like RandomClassificationDataset but samples are replicated across tensor parallelism groups."""
+
+    def __init__(
+        self,
+        shape: Sequence[int] = (1, 1, 1),
+        size: int = 100,
+        num_classes: int = 2,
+        device: Optional[torch.device] = None,
+        seed: int = 44,
+        replication: Optional[int] = 2,
+    ):
+        super().__init__(shape, size, num_classes, device)
+        self.rank = dist.get_local_rank()
+        self.world_size = dist.get_world_size()
+        assert replication is not None
+        self.n_tp_groups = replication  # the number of tp groups that we are replicating across
+        self.seed = seed
+
+    def _generate_data(self):
+        tp_group_id = self.rank // self.n_tp_groups
+        seed = self.seed + tp_group_id  # all ranks in the same TP group have the same seed
+        reproducibility.seed_all(seed)
+        self.x = torch.randn(self.size, *self.shape, device=self.device)
+        self.y = torch.randint(0, self.num_classes, size=(self.size,), device=self.device)
+
+    def __len__(self):
+        return self.size
+
+    def __getitem__(self, idx):
+        if self.x is None and self.y is None:
+            self._generate_data()
+
+        assert self.x is not None
+        assert self.y is not None
+
+        rank_idx = idx // self.world_size
+        return self.x[rank_idx], self.y[rank_idx]
+
+
 class RandomImageDataset(VisionDataset):
     """ Image Classification dataset with values drawn from a normal distribution
     Args:
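Within a tensor-parallel group, every rank derives the same seed (`seed + tp_group_id`), so `_generate_data` produces identical tensors on all of a group's ranks; the `idx // self.world_size` mapping then appears to assume a sampler that hands rank `r` the strided global indices `r, r + world_size, r + 2 * world_size, ...`, which it collapses back to local positions. A small sketch of the grouping arithmetic (illustrative values, not code from this commit):

    # With replication=2 and 4 ranks, ranks 0-1 form TP group 0 and ranks 2-3
    # form TP group 1, so each pair seeds (and therefore samples) identically.
    replication = 2
    base_seed = 44  # the dataset's default seed
    for rank in range(4):
        tp_group_id = rank // replication
        print(f'rank {rank} -> seed {base_seed + tp_group_id}')  # 44, 44, 45, 45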
14 changes: 13 additions & 1 deletion tests/common/models.py
@@ -135,13 +135,25 @@ def forward(self, x):
 # test ComposerModels instead of nn.Module.
 class SimpleComposerMLP(ComposerClassifier):
 
-    def __init__(self, num_features: int, device: str, num_classes: int = 3):
+    def __init__(self, num_features: int, device: Union[str, torch.device], num_classes: int = 3):
         fc1 = torch.nn.Linear(num_features, num_features, device=device, bias=False)
         fc2 = torch.nn.Linear(num_features, num_classes, device=device, bias=False)
         net = torch.nn.Sequential(fc1, torch.nn.ReLU(), fc2)
         super().__init__(num_classes=num_classes, module=net)
 
 
+# Like SimpleComposerMLP, but saves each layer as an attribute, which is necessary for applying tensor parallelism to it.
+class TPSimpleComposerMLP(ComposerClassifier):
+
+    def __init__(self, num_features: int, device: Union[str, torch.device], num_classes: int = 3):
+        fc1 = torch.nn.Linear(num_features, num_features, device=device, bias=False)
+        fc2 = torch.nn.Linear(num_features, num_classes, device=device, bias=False)
+        net = torch.nn.Sequential(fc1, torch.nn.ReLU(), fc2)
+        super().__init__(num_classes=num_classes, module=net)
+
+        self.fc1 = fc1
+        self.fc2 = fc2
+
+
 class SimpleWeightTiedModel(ComposerClassifier):
     """Small classification model with tied weights.
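Exposing `fc1` and `fc2` as attributes gives a tensor-parallel layer plan stable module names to target; because the Sequential holds the same layer objects, sharding them through the attributes shards the forward pass too. A sketch of how such a plan might look with PyTorch's TP API (not code from this commit; assumes an already-initialized process group spanning 2 GPUs):

    from torch.distributed.device_mesh import init_device_mesh
    from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel, parallelize_module

    from tests.common import TPSimpleComposerMLP

    model = TPSimpleComposerMLP(num_features=16, device='cuda')

    # Shard fc1 over output columns and fc2 over input rows, the classic
    # Megatron pattern: the ReLU runs on sharded activations and only fc2's
    # output needs a collective to re-assemble the result.
    tp_mesh = init_device_mesh('cuda', (2,))
    plan = {'fc1': ColwiseParallel(), 'fc2': RowwiseParallel()}
    parallelize_module(model, tp_mesh, plan)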