Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add an example of training a tabular model on multiple GPUs #474

Merged
merged 15 commits into from
Dec 30, 2024

Conversation

akihironitta
Copy link
Member

@akihironitta akihironitta commented Dec 27, 2024

Adds an example for training a model on multiple GPUs.

Usage

$ python examples/trompt_multi_gpu.py

Changes

To enable it, this PR modifies Trompt accordingly:

- out = Trompt(...).forward_stacked(tf)  # forward_stacked is removed
+ out = Trompt(...)(tf)
  assert out.size() == (batch_size, num_layers, num_channels)

Some highlights in the script

  • Unlike examples/trompt.py, it computes training accuracy batch-wise instead of computing it at the end of each epoch to save time (although model parameters change over steps within each epoch).
  • It reduces losses and metrics with all_reduce and torchmetrics's API across all ranks at the end of each epoch.
  • It avoids unnecessary device synchronisations within each epoch, e.g., by not calling float(cuda_tensor), and by setting repeat_interleave(..., output_size=...).
  • ... (I'm happy to elaborate if anything in the script is unclear.)

Benchmark results from --dataset jannis on g6.12xlarge with four L4 GPUs

four GPUs single GPU
time per training step 0.725 seconds 0.741 seconds
time per training epoch 29 seconds 117 seconds
test accuracy 80.23 % 80.29 %

@@ -122,7 +122,7 @@ def reset_parameters(self) -> None:
trompt_conv.reset_parameters()
self.trompt_decoder.reset_parameters()

def forward_stacked(self, tf: TensorFrame) -> Tensor:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's make sure this change does not break the example code.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Confirmed the change doesn't break these scripts across all supported task types:

examples/trompt.py
benchmark/data_frame_benchmark.py
benchmark/data_frame_text_benchmark.py

@akihironitta akihironitta merged commit 655730c into master Dec 30, 2024
14 checks passed
@akihironitta akihironitta deleted the aki/ddp branch December 30, 2024 11:41
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants