Summary
I'm doing some benchmarking with torchtitan on H100s to compare the experimental feature in #778 with other training configurations.
One important comparison is between:
- float8nocompile (handwritten Triton kernels for the fp8 conversion of nn.Linear layers), and
- production float8 training with torch.compile applied only to the nn.Linear layers (a rough sketch of this setup is below).
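For clarity, by "compiling only the nn.Linear layers" I mean roughly the following. This is a simplified sketch, not the exact code path behind the TORCHTITAN_COMPILE_LINEAR_ONLY flag used in the repro below:

```python
import torch
import torch.nn as nn

def compile_linear_only(module: nn.Module) -> None:
    """Wrap every nn.Linear submodule in torch.compile, leaving the rest of the model eager."""
    for name, child in module.named_children():
        if isinstance(child, nn.Linear):
            # torch.compile returns an OptimizedModule, which is still an nn.Module,
            # so it can be swapped in place of the original linear layer.
            setattr(module, name, torch.compile(child))
        else:
            compile_linear_only(child)

# The baseline it is compared against simply compiles the whole model:
#   model = torch.compile(model)
```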
I've run this comparison using:
- no AC
- full AC
- selective per op AC (see the sketch after this list)
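By "selective per op AC" I mean PyTorch's selective activation checkpointing, where a per-op policy decides which activation outputs are saved and which are recomputed in the backward pass. A minimal sketch, assuming a recent PyTorch with the public selective-checkpoint API; the op list here is illustrative, not torchtitan's exact policy:

```python
from functools import partial

import torch
from torch.utils.checkpoint import (
    CheckpointPolicy,
    checkpoint,
    create_selective_checkpoint_contexts,
)

# Ops whose outputs are saved during the forward pass; everything else is
# recomputed during the backward pass. Illustrative list only.
ops_to_save = {
    torch.ops.aten.mm.default,
    torch.ops.aten._scaled_dot_product_flash_attention.default,
}

def policy_fn(ctx, op, *args, **kwargs):
    return (
        CheckpointPolicy.MUST_SAVE
        if op in ops_to_save
        else CheckpointPolicy.PREFER_RECOMPUTE
    )

# Toy block standing in for a transformer layer.
block = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.GELU(), torch.nn.Linear(64, 64))
x = torch.randn(8, 64, requires_grad=True)

# Checkpoint the block with the per-op save/recompute policy applied.
out = checkpoint(
    block,
    x,
    use_reentrant=False,
    context_fn=partial(create_selective_checkpoint_contexts, policy_fn),
)
out.sum().backward()
```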
However, specifically when using selective per op AC, there is a massive drop in performance for production float8 training with torch.compile applied only to the nn.Linear layers, compared to using torch.compile on the full model.
This does not occur when using no AC or full AC.
I would expect some performance degradation from compiling only the nn.Linear layers instead of the full model, but the drop-off is massive (see the benchmark screenshot below): TFLOPS drops from 386.58 down to 125.88!
I looked at the traces for these 2 runs and found some surprising issues:
In the forward pass of the slow configuration (only nn.Linear layers compiled), aten::cross_entropy_loss seems to run on the CPU without dispatching any CUDA kernels, causing it to take 267ms vs 71us in the fully compiled version (a ~3800x slowdown); a standalone sanity check for this is sketched after these findings.
Only nn.Linears compiled (267ms):
Full model compiled (71us):
In the backward pass of the slow configuration (only nn.Linear layers compiled) there is an extremely long/slow FSDP::post_backward_reduce call that does not appear in the fully compiled version (or rather, it is orders of magnitude faster).
Only nn.Linears compiled:
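Regarding the first finding (aten::cross_entropy_loss apparently not dispatching CUDA kernels): as a standalone sanity check, something like the sketch below can be used to confirm that cross_entropy on CUDA tensors does launch CUDA kernels when run in isolation. The shapes are illustrative, roughly Llama-3-sized logits:

```python
import torch
import torch.nn.functional as F
from torch.profiler import ProfilerActivity, profile

# Illustrative shapes: (tokens, vocab_size) roughly in the Llama 3 ballpark.
logits = torch.randn(1024, 128256, device="cuda", requires_grad=True)
targets = torch.randint(0, 128256, (1024,), device="cuda")

with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    loss = F.cross_entropy(logits, targets)
    loss.backward()

# If cross_entropy is behaving normally, CUDA kernels should show up here.
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
```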
Steps to reproduce
Use training_configs/llama3_8b.toml to run prod float8 + fully compiled model + selective per op AC on H100s:
NGPU=4 CONFIG_FILE="./train_configs/llama3_8b.toml" ./run_llama_train.sh
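For reference, the settings I changed from the default llama3_8b.toml for this run are roughly the following. Treat this as a sketch rather than an exact diff; the section and key names are from the torchtitan version I'm on and may differ slightly:

```toml
[training]
compile = true                # full-model torch.compile

[float8]
enable_float8_linear = true   # production float8 training

[activation_checkpoint]
mode = "selective"
selective_ac_option = "op"    # selective per op AC
```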
Use training_configs/llama3_8b.toml to run prod float8 + only linear layers compiled + selective per op AC on H100s (I don't think the number of GPUs matters):
TORCHTITAN_COMPILE_LINEAR_ONLY=1 NGPU=4 CONFIG_FILE="./train_configs/llama3_8b.toml" ./run_llama_train.sh
cc @vkuzo @soulitzer