DeepSpeedPlugin with activation checkpoint fails #9144
-
Hi there,

Minimal code to reproduce:

```python
import os

import deepspeed
import pytorch_lightning as pl
import torch
from deepspeed.ops.adam import FusedAdam
from pytorch_lightning.plugins import DeepSpeedPlugin
from pytorch_lightning.utilities.types import STEP_OUTPUT
from torch import nn
from torch.utils.data import DataLoader, RandomSampler


class PlModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.model = nn.Linear(1, 1)

    def forward(self, batch):
        return self.model(batch)

    def training_step(self, batch, batch_idx) -> STEP_OUTPUT:
        res = deepspeed.checkpointing.checkpoint(self.model, batch)
        return nn.MSELoss()(res, torch.zeros_like(res, device=res.device))

    def configure_optimizers(self):
        return FusedAdam(self.parameters(), lr=0.1)


if __name__ == '__main__':
    trainer = pl.Trainer(gpus=-1, precision=16, plugins=DeepSpeedPlugin(stage=3, partition_activations=True))
    model = PlModel()
    dataset = torch.rand(100, 1)
    dl = DataLoader(dataset, batch_size=1, num_workers=os.cpu_count(),
                    sampler=RandomSampler(dataset))
    trainer.fit(model, dl)
```

pytorch-lightning version: 1.3.3
-
Thanks @nachshonc! I've managed to reproduce the same case without Deepspeed, using `torch.utils.checkpoint` and our bug report model:

```python
import torch
import torch.utils.checkpoint
from pytorch_lightning import LightningModule, Trainer
from torch.utils.data import DataLoader, Dataset


class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return torch.utils.checkpoint.checkpoint(self.layer, x)

    def training_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("train_loss", loss)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("valid_loss", loss)

    def test_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("test_loss", loss)

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)


def run():
    train_data = DataLoader(RandomDataset(32, 64), batch_size=2)
    val_data = DataLoader(RandomDataset(32, 64), batch_size=2)

    model = BoringModel()
    trainer = Trainer(max_epochs=1)
    trainer.fit(model, train_dataloaders=train_data, val_dataloaders=val_data)


if __name__ == "__main__":
    run()
```

I think the issue arises because the entire model's forward pass is checkpointed while the input tensors don't require gradients, so the autograd engine has nothing to propagate gradients through. Activation checkpointing only makes sense if you have intermediate layers that produce expensive activations. For example, swap the model out to look like this:

```python
class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer_h = torch.nn.Linear(32, 32)
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        x = torch.utils.checkpoint.checkpoint(self.layer_h, x)
        return self.layer(x)
```

Activation checkpointing just means that on the backward pass we'll need to re-compute the activations (unless you use CPU checkpointing with Deepspeed, where the activations are instead moved to CPU memory). In particular, there is no point checkpointing the final layer, as it would immediately need to be re-computed anyway:

```python
class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer_h = torch.nn.Linear(32, 32)
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        x = self.layer_h(x)
        return torch.utils.checkpoint.checkpoint(self.layer, x)  # no point doing this!
```

We should definitely make the docs clearer for this, I'll make this an issue :)
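A related workaround, sketched here for illustration and not part of the reply above, is to flag the checkpointed input as requiring gradients: `torch.utils.checkpoint` only participates in the backward pass when at least one of its tensor inputs requires grad, so marking the input keeps the checkpointed block connected to the autograd graph. The class name below is hypothetical.

```python
import torch
import torch.utils.checkpoint
from pytorch_lightning import LightningModule


class BoringModelGradInput(LightningModule):  # hypothetical name, for illustration only
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        # Batches from a DataLoader arrive with requires_grad=False; flagging the
        # input connects the checkpointed output to the autograd graph, so the
        # layer's parameters receive gradients on backward.
        x = x.requires_grad_(True)
        return torch.utils.checkpoint.checkpoint(self.layer, x)
```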
-
@SeanNaren Detailed answer! 👍 But I wonder how to wrap an inner model such as a HuggingFace transformer:

```python
import deepspeed
import torch
from pytorch_lightning import LightningModule
from transformers import AutoModel


class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.bert_layer = AutoModel.from_pretrained("bert-base-uncased")
        self.mlp_layer = torch.nn.Linear(768, 2)

    def forward(self, input_ids, attention_mask, token_type_ids):
        bert_output = deepspeed.checkpointing.checkpoint(
            self.bert_layer, input_ids, attention_mask, token_type_ids
        )
        # ⬆️ this API returns a tuple of strings: ('last_hidden_state', 'pooler_output'), not the embeddings
        cls_output = bert_output.last_hidden_state[:, 0, :]
        mlp_output = self.mlp_layer(cls_output)
        return mlp_output
```
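One possible way around the tuple issue, sketched here rather than taken from this thread, is not to wrap the whole HuggingFace model in a checkpoint call at all and instead let transformers checkpoint its own encoder layers. This assumes a transformers version that provides `gradient_checkpointing_enable()`; the class name below is hypothetical.

```python
import torch
from pytorch_lightning import LightningModule
from transformers import AutoModel


class BertClassifier(LightningModule):  # hypothetical name, for illustration only
    def __init__(self):
        super().__init__()
        self.bert_layer = AutoModel.from_pretrained("bert-base-uncased")
        # Checkpoint each transformer block internally instead of wrapping the whole
        # model, so forward still returns the usual HuggingFace output object.
        self.bert_layer.gradient_checkpointing_enable()
        self.mlp_layer = torch.nn.Linear(768, 2)

    def forward(self, input_ids, attention_mask, token_type_ids):
        bert_output = self.bert_layer(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )
        cls_output = bert_output.last_hidden_state[:, 0, :]
        return self.mlp_layer(cls_output)
```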
-
Deepspeed itself only provides pipeline parallelism (PP), and Deepspeed PP is incompatible with ZeRO stage 2 and stage 3 (ref: https://deepspeed.readthedocs.io/en/latest/pipeline.html). Furthermore, ZeRO stage 3 computes activations independently on each GPU for its own sub-batch, which makes activation partitioning itself meaningless. In conclusion, if you're using the DeepSpeedStrategy in PyTorch Lightning, applying activation partitioning doesn't offer much benefit.
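For reference, here is a sketch of how the activation-checkpointing options discussed in this thread are passed through Lightning's DeepSpeed plugin. Only `partition_activations` appears in the original repro; the other argument names are assumptions about the `DeepSpeedPlugin` signature of that era, so check your installed version before relying on them.

```python
from pytorch_lightning import Trainer
from pytorch_lightning.plugins import DeepSpeedPlugin

# Assumed arguments beyond partition_activations: cpu_checkpointing and
# contiguous_memory_optimization map onto DeepSpeed's activation_checkpointing config.
trainer = Trainer(
    gpus=-1,
    precision=16,
    plugins=DeepSpeedPlugin(
        stage=3,
        partition_activations=True,        # shard checkpointed activations across GPUs
        cpu_checkpointing=True,            # offload checkpointed activations to CPU memory
        contiguous_memory_optimization=False,
    ),
)
```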