enable TritonFusedRMSNorm with local_map annotation #364
Conversation
[ghstack-poisoned]
ghstack-source-id: 7af2e42fd9611d83840bc14c9f023cfd65033f21 Pull Request resolved: #364
Note: this test requires pytorch/pytorch#126924 to land first.
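(For context on the approach the title refers to: `local_map` lets a function written for ordinary tensors run on DTensor inputs by unwrapping each DTensor argument to its local shard, calling the function, and re-wrapping the output with declared placements. Below is a minimal sketch of how a Triton fused RMSNorm kernel could be annotated this way, assuming the experimental `torch.distributed._tensor.experimental.local_map` API; the `TritonFusedRMSNorm` autograd function and the exact placements are illustrative, not necessarily the PR's literal code.)

```
# Minimal sketch (illustrative; placements assumed, not the PR's literal code).
from functools import partial

from torch.distributed._tensor import Replicate, Shard
from torch.distributed._tensor.experimental import local_map


@partial(
    local_map,
    # re-wrap the plain-tensor output as a DTensor sharded on the sequence dim
    out_placements=[Shard(1)],
    # one entry per argument of fused_rms_norm_fn:
    #   x sequence-sharded, weight replicated, eps is a plain float (None)
    in_placements=([Shard(1)], [Replicate()], None),
)
def fused_rms_norm_fn(x, weight, eps=1e-6):
    # x and weight arrive here as local torch.Tensor shards,
    # so the Triton kernel never sees a DTensor
    return TritonFusedRMSNorm.apply(x, weight, eps)
```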
Looks good - thanks for implementing this!
Two minor nits; the main one is that I'm not sure we want to leave TP on by default in debug_model.toml.
We should wait for the PyTorch PR to land before landing this :)
[ghstack-poisoned]
ghstack-source-id: 6125011aba1a4bd9521fb4a3b761b62285ea6195 Pull Request resolved: #364
**Test Plan**

Here's the output of running `CONFIG_FILE=./train_configs/llama3_8b.toml NGPU=8 LOG_RANK=0,1,2,3,4,5,6,7 ./run_llama_train.sh` using 4-way Tensor Parallel (`tensor_parallel_degree = 4`):

1. with `norm_type = "rmsnorm"`

```
[rank0]:2024-06-05 11:57:35,505 - root - INFO - step: 1 loss: 12.2703 memory: 24.66GiB(31.15%) wps: 143 mfu: 2.66%
[rank0]:2024-06-05 11:57:35,505 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank0]:2024-06-05 11:58:11,490 - root - INFO - step: 10 loss: 11.0446 memory: 31.96GiB(40.37%) wps: 512 mfu: 9.51%
[rank0]:2024-06-05 11:58:46,488 - root - INFO - step: 20 loss: 9.2321 memory: 31.96GiB(40.37%) wps: 586 mfu: 10.87%
[rank0]:2024-06-05 11:59:22,462 - root - INFO - step: 30 loss: 8.2184 memory: 31.96GiB(40.37%) wps: 570 mfu: 10.58%
[rank0]:2024-06-05 11:59:57,301 - root - INFO - step: 40 loss: 7.6220 memory: 31.96GiB(40.37%) wps: 589 mfu: 10.93%
[rank0]:2024-06-05 12:00:32,254 - root - INFO - step: 50 loss: 7.5399 memory: 31.96GiB(40.37%) wps: 587 mfu: 10.89%
[rank0]:2024-06-05 12:01:07,155 - root - INFO - step: 60 loss: 7.3179 memory: 31.96GiB(40.37%) wps: 588 mfu: 10.91%
[rank0]:2024-06-05 12:01:41,999 - root - INFO - step: 70 loss: 7.3508 memory: 31.96GiB(40.37%) wps: 589 mfu: 10.92%
[rank0]:2024-06-05 12:02:17,093 - root - INFO - step: 80 loss: 7.2696 memory: 31.96GiB(40.37%) wps: 584 mfu: 10.85%
[rank0]:2024-06-05 12:02:52,009 - root - INFO - step: 90 loss: 7.0481 memory: 31.96GiB(40.37%) wps: 588 mfu: 10.91%
[rank0]:2024-06-05 12:03:27,715 - root - INFO - step: 100 loss: 6.9623 memory: 31.96GiB(40.37%) wps: 575 mfu: 10.67%
```

2. with `norm_type = "fused_rmsnorm"`

```
[rank0]:2024-06-05 12:08:35,004 - root - INFO - step: 1 loss: 12.2422 memory: 24.62GiB(31.10%) wps: 95 mfu: 1.76%
[rank0]:2024-06-05 12:08:35,004 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank0]:2024-06-05 12:09:12,401 - root - INFO - step: 10 loss: 11.0361 memory: 32.09GiB(40.54%) wps: 493 mfu: 9.15%
[rank0]:2024-06-05 12:09:49,380 - root - INFO - step: 20 loss: 9.2725 memory: 32.09GiB(40.54%) wps: 554 mfu: 10.29%
[rank0]:2024-06-05 12:10:26,645 - root - INFO - step: 30 loss: 8.2091 memory: 32.09GiB(40.54%) wps: 550 mfu: 10.21%
[rank0]:2024-06-05 12:11:03,616 - root - INFO - step: 40 loss: 7.5601 memory: 32.09GiB(40.54%) wps: 555 mfu: 10.30%
[rank0]:2024-06-05 12:11:40,625 - root - INFO - step: 50 loss: 7.5144 memory: 32.09GiB(40.54%) wps: 554 mfu: 10.29%
[rank0]:2024-06-05 12:12:17,768 - root - INFO - step: 60 loss: 7.3869 memory: 32.09GiB(40.54%) wps: 552 mfu: 10.25%
[rank0]:2024-06-05 12:12:54,820 - root - INFO - step: 70 loss: 7.3358 memory: 32.09GiB(40.54%) wps: 553 mfu: 10.27%
[rank0]:2024-06-05 12:13:31,817 - root - INFO - step: 80 loss: 7.2085 memory: 32.09GiB(40.54%) wps: 554 mfu: 10.29%
[rank0]:2024-06-05 12:14:09,156 - root - INFO - step: 90 loss: 7.0140 memory: 32.09GiB(40.54%) wps: 549 mfu: 10.19%
[rank0]:2024-06-05 12:14:48,518 - root - INFO - step: 100 loss: 6.9507 memory: 32.09GiB(40.54%) wps: 521 mfu: 9.67%
```

[ghstack-poisoned]
ghstack-source-id: 213ef4323f9888463076ea580c3b72e2359ec492 Pull Request resolved: #364
**Summary**

This PR enables the use of TritonFusedRMSNorm with Tensor Parallel, with a 7%-8% performance gain compared to RMSNorm with TP.

**Test Plan**

Here's the output of running `CONFIG_FILE=./train_configs/llama3_8b.toml NGPU=8 LOG_RANK=0,1,2,3,4,5,6,7 ./run_llama_train.sh` using 4-way Tensor Parallel (`tensor_parallel_degree = 4`):

1. with `norm_type = "rmsnorm"`

```
[rank2]:2024-06-12 13:55:25,005 - root - INFO - step: 1 loss: 12.2971 memory: 23.68GiB(29.92%) wps: 258 mfu: 4.79%
[rank2]:2024-06-12 13:55:43,082 - root - INFO - step: 5 loss: 11.6237 memory: 30.98GiB(39.14%) wps: 453 mfu: 8.41%
[rank2]:2024-06-12 13:56:00,742 - root - INFO - step: 10 loss: 10.7210 memory: 30.98GiB(39.14%) wps: 580 mfu: 10.77%
[rank2]:2024-06-12 13:56:18,274 - root - INFO - step: 15 loss: 9.4563 memory: 30.98GiB(39.14%) wps: 585 mfu: 10.85%
[rank2]:2024-06-12 13:56:35,888 - root - INFO - step: 20 loss: 8.9246 memory: 30.98GiB(39.14%) wps: 582 mfu: 10.80%
```

2. with `norm_type = "fused_rmsnorm"`

```
[rank2]:2024-06-12 13:52:48,671 - root - INFO - step: 1 loss: 12.2779 memory: 23.64GiB(29.86%) wps: 186 mfu: 3.45%
[rank2]:2024-06-12 13:53:06,983 - root - INFO - step: 5 loss: 11.6073 memory: 31.11GiB(39.31%) wps: 447 mfu: 8.30%
[rank2]:2024-06-12 13:53:23,895 - root - INFO - step: 10 loss: 10.6355 memory: 31.11GiB(39.31%) wps: 606 mfu: 11.25%
[rank2]:2024-06-12 13:53:41,108 - root - INFO - step: 15 loss: 9.5591 memory: 31.11GiB(39.31%) wps: 596 mfu: 11.05%
[rank2]:2024-06-12 13:53:58,045 - root - INFO - step: 20 loss: 9.0287 memory: 31.11GiB(39.31%) wps: 605 mfu: 11.23%
```

[ghstack-poisoned]
ghstack-source-id: bcd66ee725966f2a1670ef25b79240ebab5af249 Pull Request resolved: #364
Please see inlined comments
Your perf benchmark seems to be using batch_size = 1; can you rerun with batch_size = 4 and update the perf table?
**Summary**

This PR enables the use of TritonFusedRMSNorm with Tensor Parallel, with a 7%-8% performance gain compared to RMSNorm with TP.

**Test Plan**

Here's the output of running `CONFIG_FILE=./train_configs/llama3_8b.toml NGPU=8 LOG_RANK=0,1,2,3,4,5,6,7 ./run_llama_train.sh` using 4-way Tensor Parallel (`tensor_parallel_degree = 4`). Detailed settings:

```
[job]
dump_folder = "./outputs"
description = "Llama 3 8B training"

[profiling]
enable_profiling = false
save_traces_folder = "profile_trace"
profile_freq = 100

[metrics]
log_freq = 5
enable_tensorboard = false
save_tb_folder = "tb"

[model]
name = "llama3"
flavor = "8B"
norm_type = "rmsnorm"  # [layernorm / np_layernorm / rmsnorm / fused_rmsnorm]
tokenizer_path = "./torchtitan/datasets/tokenizer/original/tokenizer.model"

[optimizer]
name = "AdamW"
lr = 3e-4

[training]
batch_size = 4
seq_len = 8192
warmup_steps = 200  # lr scheduler warm up
max_norm = 1.0  # grad norm clipping
steps = 100
data_parallel_degree = -1
tensor_parallel_degree = 4
pipeline_parallel_degree = 1
fp8_linear = ""
compile = false
dataset = "c4_mini"

[checkpoint]
enable_checkpoint = false
folder = "checkpoint"
interval_type = "steps"
interval = 500
model_weights_only = false
export_dtype = "float32"
async_mode = "disabled"  # ["disabled", "async", "async_with_pinned_mem"]

[activation_checkpoint]
mode = 'selective'  # ['none', 'selective', 'full']
selective_ac_option = 'op'  # 'int' = ac every positive int layer or 'op', ac based on ops policy
```

1. with `norm_type = "rmsnorm"`

```
[rank2]:2024-06-13 00:47:55,607 - root - INFO - step: 1 loss: 12.2262 memory: 57.70GiB(72.89%) wps: 429 mfu: 7.96%
[rank2]:2024-06-13 00:48:57,536 - root - INFO - step: 5 loss: 11.4801 memory: 65.53GiB(82.78%) wps: 529 mfu: 9.82%
[rank2]:2024-06-13 00:50:05,746 - root - INFO - step: 10 loss: 10.2305 memory: 65.53GiB(82.78%) wps: 601 mfu: 11.15%
[rank2]:2024-06-13 00:51:14,343 - root - INFO - step: 15 loss: 9.3287 memory: 65.53GiB(82.78%) wps: 597 mfu: 11.09%
[rank2]:2024-06-13 00:52:22,325 - root - INFO - step: 20 loss: 8.7126 memory: 65.53GiB(82.78%) wps: 603 mfu: 11.19%
[rank2]:2024-06-13 00:53:31,605 - root - INFO - step: 25 loss: 8.2011 memory: 65.53GiB(82.78%) wps: 591 mfu: 10.98%
[rank2]:2024-06-13 00:54:39,861 - root - INFO - step: 30 loss: 7.7424 memory: 65.53GiB(82.78%) wps: 600 mfu: 11.14%
[rank2]:2024-06-13 00:55:47,782 - root - INFO - step: 35 loss: 7.4964 memory: 65.53GiB(82.78%) wps: 603 mfu: 11.20%
[rank2]:2024-06-13 00:56:55,927 - root - INFO - step: 40 loss: 7.2799 memory: 65.53GiB(82.78%) wps: 601 mfu: 11.16%
[rank2]:2024-06-13 00:58:04,445 - root - INFO - step: 45 loss: 7.2280 memory: 65.53GiB(82.78%) wps: 598 mfu: 11.10%
[rank2]:2024-06-13 00:59:12,503 - root - INFO - step: 50 loss: 7.0669 memory: 65.53GiB(82.78%) wps: 602 mfu: 11.17%
[rank2]:2024-06-13 01:00:21,395 - root - INFO - step: 55 loss: 6.9967 memory: 65.53GiB(82.78%) wps: 595 mfu: 11.04%
[rank2]:2024-06-13 01:01:29,641 - root - INFO - step: 60 loss: 7.0763 memory: 65.53GiB(82.78%) wps: 600 mfu: 11.14%
[rank2]:2024-06-13 01:02:37,572 - root - INFO - step: 65 loss: 6.9260 memory: 65.53GiB(82.78%) wps: 603 mfu: 11.20%
[rank2]:2024-06-13 01:03:45,755 - root - INFO - step: 70 loss: 6.9757 memory: 65.53GiB(82.78%) wps: 601 mfu: 11.15%
[rank2]:2024-06-13 01:04:54,015 - root - INFO - step: 75 loss: 6.8074 memory: 65.53GiB(82.78%) wps: 600 mfu: 11.14%
[rank2]:2024-06-13 01:06:02,682 - root - INFO - step: 80 loss: 6.7362 memory: 65.53GiB(82.78%) wps: 597 mfu: 11.08%
[rank2]:2024-06-13 01:07:11,232 - root - INFO - step: 85 loss: 6.7016 memory: 65.53GiB(82.78%) wps: 598 mfu: 11.09%
[rank2]:2024-06-13 01:08:19,973 - root - INFO - step: 90 loss: 6.6640 memory: 65.53GiB(82.78%) wps: 596 mfu: 11.06%
[rank2]:2024-06-13 01:09:27,858 - root - INFO - step: 95 loss: 6.7214 memory: 65.53GiB(82.78%) wps: 604 mfu: 11.20%
[rank2]:2024-06-13 01:10:36,136 - root - INFO - step: 100 loss: 6.5953 memory: 65.53GiB(82.78%) wps: 600 mfu: 11.14%
```

2. with `norm_type = "fused_rmsnorm"`

```
[rank2]:2024-06-13 00:19:33,609 - root - INFO - step: 1 loss: 12.2194 memory: 57.31GiB(72.40%) wps: 412 mfu: 7.64%
[rank2]:2024-06-13 00:20:29,175 - root - INFO - step: 5 loss: 11.4519 memory: 65.13GiB(82.29%) wps: 590 mfu: 10.95%
[rank2]:2024-06-13 00:21:33,667 - root - INFO - step: 10 loss: 10.2199 memory: 65.13GiB(82.29%) wps: 635 mfu: 11.79%
[rank2]:2024-06-13 00:22:37,492 - root - INFO - step: 15 loss: 9.3509 memory: 65.13GiB(82.29%) wps: 642 mfu: 11.92%
[rank2]:2024-06-13 00:23:42,592 - root - INFO - step: 20 loss: 8.7972 memory: 65.13GiB(82.29%) wps: 629 mfu: 11.68%
[rank2]:2024-06-13 00:24:46,466 - root - INFO - step: 25 loss: 8.2348 memory: 65.13GiB(82.29%) wps: 642 mfu: 11.91%
[rank2]:2024-06-13 00:25:50,900 - root - INFO - step: 30 loss: 7.7037 memory: 65.13GiB(82.29%) wps: 636 mfu: 11.80%
[rank2]:2024-06-13 00:26:54,794 - root - INFO - step: 35 loss: 7.4639 memory: 65.13GiB(82.29%) wps: 641 mfu: 11.90%
[rank2]:2024-06-13 00:27:59,235 - root - INFO - step: 40 loss: 7.2406 memory: 65.13GiB(82.29%) wps: 636 mfu: 11.80%
[rank2]:2024-06-13 00:29:03,304 - root - INFO - step: 45 loss: 7.1822 memory: 65.13GiB(82.29%) wps: 640 mfu: 11.87%
[rank2]:2024-06-13 00:30:07,607 - root - INFO - step: 50 loss: 7.0580 memory: 65.13GiB(82.29%) wps: 637 mfu: 11.83%
[rank2]:2024-06-13 00:31:11,764 - root - INFO - step: 55 loss: 6.9888 memory: 65.13GiB(82.29%) wps: 639 mfu: 11.86%
[rank2]:2024-06-13 00:32:16,001 - root - INFO - step: 60 loss: 7.0387 memory: 65.13GiB(82.29%) wps: 638 mfu: 11.84%
[rank2]:2024-06-13 00:33:20,137 - root - INFO - step: 65 loss: 6.9199 memory: 65.13GiB(82.29%) wps: 639 mfu: 11.86%
[rank2]:2024-06-13 00:34:24,424 - root - INFO - step: 70 loss: 6.9503 memory: 65.13GiB(82.29%) wps: 637 mfu: 11.83%
[rank2]:2024-06-13 00:35:28,722 - root - INFO - step: 75 loss: 6.7960 memory: 65.13GiB(82.29%) wps: 637 mfu: 11.83%
[rank2]:2024-06-13 00:36:32,865 - root - INFO - step: 80 loss: 6.6798 memory: 65.13GiB(82.29%) wps: 639 mfu: 11.86%
[rank2]:2024-06-13 00:37:36,981 - root - INFO - step: 85 loss: 6.6504 memory: 65.13GiB(82.29%) wps: 639 mfu: 11.86%
[rank2]:2024-06-13 00:38:41,407 - root - INFO - step: 90 loss: 6.6655 memory: 65.13GiB(82.29%) wps: 636 mfu: 11.81%
[rank2]:2024-06-13 00:39:45,981 - root - INFO - step: 95 loss: 6.7359 memory: 65.13GiB(82.29%) wps: 635 mfu: 11.78%
[rank2]:2024-06-13 00:40:50,146 - root - INFO - step: 100 loss: 6.5410 memory: 65.13GiB(82.29%) wps: 639 mfu: 11.85%
```

[ghstack-poisoned]
ghstack-source-id: f01e04f5b60f17f83bd63846db83e2528c8881a4 Pull Request resolved: #364
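(A quick back-of-the-envelope check on the quoted gain, using steady-state wps values eyeballed from the step 10-100 lines of the two `batch_size = 4` runs above:)

```
# Rough steady-state throughput comparison from the logs above (assumed
# representative values; individual steps vary by a few wps).
rmsnorm_wps = 600   # typical wps with norm_type = "rmsnorm"
fused_wps = 637     # typical wps with norm_type = "fused_rmsnorm"

gain = fused_wps / rmsnorm_wps - 1
print(f"fused_rmsnorm throughput gain: {gain:.1%}")  # ~6.2%, in the ballpark of the quoted 7%-8%
```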
@XilunWu The WPS for 8B in your summary still doesn't look right. I have exactly the same settings, but the WPS on my side is something like this:
**Summary**

This PR enables the use of TritonFusedRMSNorm with Tensor Parallel, with a 7%-8% performance gain compared to RMSNorm with TP.

**Test Plan**

Here's the output of running `CONFIG_FILE=./train_configs/llama3_8b.toml NGPU=8 LOG_RANK=0,1,2,3,4,5,6,7 ./run_llama_train.sh` using 4-way Tensor Parallel (`tensor_parallel_degree = 4`). Detailed settings:

```
[job]
dump_folder = "./outputs"
description = "Llama 3 8B training"

[profiling]
enable_profiling = false
save_traces_folder = "profile_trace"
profile_freq = 100

[metrics]
log_freq = 5
enable_tensorboard = false
save_tb_folder = "tb"

[model]
name = "llama3"
flavor = "8B"
norm_type = "rmsnorm"  # [layernorm / np_layernorm / rmsnorm / fused_rmsnorm]
tokenizer_path = "./torchtitan/datasets/tokenizer/original/tokenizer.model"

[optimizer]
name = "AdamW"
lr = 3e-4

[training]
batch_size = 4
seq_len = 8192
warmup_steps = 200  # lr scheduler warm up
max_norm = 1.0  # grad norm clipping
steps = 100
data_parallel_degree = -1
tensor_parallel_degree = 4
pipeline_parallel_degree = 1
fp8_linear = ""
compile = false
dataset = "c4_mini"

[checkpoint]
enable_checkpoint = false
folder = "checkpoint"
interval_type = "steps"
interval = 500
model_weights_only = false
export_dtype = "float32"
async_mode = "disabled"  # ["disabled", "async", "async_with_pinned_mem"]

[activation_checkpoint]
mode = 'selective'  # ['none', 'selective', 'full']
selective_ac_option = 'op'  # 'int' = ac every positive int layer or 'op', ac based on ops policy
```

1. with `norm_type = "rmsnorm"`

```
[rank2]:2024-06-13 00:47:55,607 - root - INFO - step: 1 loss: 12.2262 memory: 57.70GiB(72.89%) wps: 429 mfu: 7.96%
[rank2]:2024-06-13 00:48:57,536 - root - INFO - step: 5 loss: 11.4801 memory: 65.53GiB(82.78%) wps: 529 mfu: 9.82%
[rank2]:2024-06-13 00:50:05,746 - root - INFO - step: 10 loss: 10.2305 memory: 65.53GiB(82.78%) wps: 601 mfu: 11.15%
[rank2]:2024-06-13 00:51:14,343 - root - INFO - step: 15 loss: 9.3287 memory: 65.53GiB(82.78%) wps: 597 mfu: 11.09%
[rank2]:2024-06-13 00:52:22,325 - root - INFO - step: 20 loss: 8.7126 memory: 65.53GiB(82.78%) wps: 603 mfu: 11.19%
[rank2]:2024-06-13 00:53:31,605 - root - INFO - step: 25 loss: 8.2011 memory: 65.53GiB(82.78%) wps: 591 mfu: 10.98%
[rank2]:2024-06-13 00:54:39,861 - root - INFO - step: 30 loss: 7.7424 memory: 65.53GiB(82.78%) wps: 600 mfu: 11.14%
[rank2]:2024-06-13 00:55:47,782 - root - INFO - step: 35 loss: 7.4964 memory: 65.53GiB(82.78%) wps: 603 mfu: 11.20%
[rank2]:2024-06-13 00:56:55,927 - root - INFO - step: 40 loss: 7.2799 memory: 65.53GiB(82.78%) wps: 601 mfu: 11.16%
[rank2]:2024-06-13 00:58:04,445 - root - INFO - step: 45 loss: 7.2280 memory: 65.53GiB(82.78%) wps: 598 mfu: 11.10%
[rank2]:2024-06-13 00:59:12,503 - root - INFO - step: 50 loss: 7.0669 memory: 65.53GiB(82.78%) wps: 602 mfu: 11.17%
[rank2]:2024-06-13 01:00:21,395 - root - INFO - step: 55 loss: 6.9967 memory: 65.53GiB(82.78%) wps: 595 mfu: 11.04%
[rank2]:2024-06-13 01:01:29,641 - root - INFO - step: 60 loss: 7.0763 memory: 65.53GiB(82.78%) wps: 600 mfu: 11.14%
[rank2]:2024-06-13 01:02:37,572 - root - INFO - step: 65 loss: 6.9260 memory: 65.53GiB(82.78%) wps: 603 mfu: 11.20%
[rank2]:2024-06-13 01:03:45,755 - root - INFO - step: 70 loss: 6.9757 memory: 65.53GiB(82.78%) wps: 601 mfu: 11.15%
[rank2]:2024-06-13 01:04:54,015 - root - INFO - step: 75 loss: 6.8074 memory: 65.53GiB(82.78%) wps: 600 mfu: 11.14%
[rank2]:2024-06-13 01:06:02,682 - root - INFO - step: 80 loss: 6.7362 memory: 65.53GiB(82.78%) wps: 597 mfu: 11.08%
[rank2]:2024-06-13 01:07:11,232 - root - INFO - step: 85 loss: 6.7016 memory: 65.53GiB(82.78%) wps: 598 mfu: 11.09%
[rank2]:2024-06-13 01:08:19,973 - root - INFO - step: 90 loss: 6.6640 memory: 65.53GiB(82.78%) wps: 596 mfu: 11.06%
[rank2]:2024-06-13 01:09:27,858 - root - INFO - step: 95 loss: 6.7214 memory: 65.53GiB(82.78%) wps: 604 mfu: 11.20%
[rank2]:2024-06-13 01:10:36,136 - root - INFO - step: 100 loss: 6.5953 memory: 65.53GiB(82.78%) wps: 600 mfu: 11.14%
```

2. with `norm_type = "fused_rmsnorm"`

```
[rank2]:2024-06-13 00:19:33,609 - root - INFO - step: 1 loss: 12.2194 memory: 57.31GiB(72.40%) wps: 412 mfu: 7.64%
[rank2]:2024-06-13 00:20:29,175 - root - INFO - step: 5 loss: 11.4519 memory: 65.13GiB(82.29%) wps: 590 mfu: 10.95%
[rank2]:2024-06-13 00:21:33,667 - root - INFO - step: 10 loss: 10.2199 memory: 65.13GiB(82.29%) wps: 635 mfu: 11.79%
[rank2]:2024-06-13 00:22:37,492 - root - INFO - step: 15 loss: 9.3509 memory: 65.13GiB(82.29%) wps: 642 mfu: 11.92%
[rank2]:2024-06-13 00:23:42,592 - root - INFO - step: 20 loss: 8.7972 memory: 65.13GiB(82.29%) wps: 629 mfu: 11.68%
[rank2]:2024-06-13 00:24:46,466 - root - INFO - step: 25 loss: 8.2348 memory: 65.13GiB(82.29%) wps: 642 mfu: 11.91%
[rank2]:2024-06-13 00:25:50,900 - root - INFO - step: 30 loss: 7.7037 memory: 65.13GiB(82.29%) wps: 636 mfu: 11.80%
[rank2]:2024-06-13 00:26:54,794 - root - INFO - step: 35 loss: 7.4639 memory: 65.13GiB(82.29%) wps: 641 mfu: 11.90%
[rank2]:2024-06-13 00:27:59,235 - root - INFO - step: 40 loss: 7.2406 memory: 65.13GiB(82.29%) wps: 636 mfu: 11.80%
[rank2]:2024-06-13 00:29:03,304 - root - INFO - step: 45 loss: 7.1822 memory: 65.13GiB(82.29%) wps: 640 mfu: 11.87%
[rank2]:2024-06-13 00:30:07,607 - root - INFO - step: 50 loss: 7.0580 memory: 65.13GiB(82.29%) wps: 637 mfu: 11.83%
[rank2]:2024-06-13 00:31:11,764 - root - INFO - step: 55 loss: 6.9888 memory: 65.13GiB(82.29%) wps: 639 mfu: 11.86%
[rank2]:2024-06-13 00:32:16,001 - root - INFO - step: 60 loss: 7.0387 memory: 65.13GiB(82.29%) wps: 638 mfu: 11.84%
[rank2]:2024-06-13 00:33:20,137 - root - INFO - step: 65 loss: 6.9199 memory: 65.13GiB(82.29%) wps: 639 mfu: 11.86%
[rank2]:2024-06-13 00:34:24,424 - root - INFO - step: 70 loss: 6.9503 memory: 65.13GiB(82.29%) wps: 637 mfu: 11.83%
[rank2]:2024-06-13 00:35:28,722 - root - INFO - step: 75 loss: 6.7960 memory: 65.13GiB(82.29%) wps: 637 mfu: 11.83%
[rank2]:2024-06-13 00:36:32,865 - root - INFO - step: 80 loss: 6.6798 memory: 65.13GiB(82.29%) wps: 639 mfu: 11.86%
[rank2]:2024-06-13 00:37:36,981 - root - INFO - step: 85 loss: 6.6504 memory: 65.13GiB(82.29%) wps: 639 mfu: 11.86%
[rank2]:2024-06-13 00:38:41,407 - root - INFO - step: 90 loss: 6.6655 memory: 65.13GiB(82.29%) wps: 636 mfu: 11.81%
[rank2]:2024-06-13 00:39:45,981 - root - INFO - step: 95 loss: 6.7359 memory: 65.13GiB(82.29%) wps: 635 mfu: 11.78%
[rank2]:2024-06-13 00:40:50,146 - root - INFO - step: 100 loss: 6.5410 memory: 65.13GiB(82.29%) wps: 639 mfu: 11.85%
```

[ghstack-poisoned]
ghstack-source-id: 421081d28e170466302269c3a4703d2291e76e55 Pull Request resolved: #364
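For readers unfamiliar with the approach, the core idea is to annotate the fused-norm entry point with DTensor's experimental `local_map` (the API added in pytorch/pytorch#126924), which declares how DTensor inputs and outputs map to plain local tensors so the Triton kernel runs per TP shard without ever seeing DTensor. Below is a minimal sketch, not the PR's exact code: `_rms_norm_local` stands in for the real Triton kernel, and the placements assume the sequence-parallel layout where activations are sharded on dim 1 and the norm weight is replicated.

```python
import torch
from torch.distributed._tensor import Replicate, Shard
from torch.distributed._tensor.experimental import local_map


def _rms_norm_local(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Stand-in for the Triton fused kernel: receives plain local shards.
    # In the real implementation this is where TritonFusedRMSNorm.apply runs.
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) * weight


# Declare the DTensor <-> local-tensor contract: activation sharded on the
# sequence dim, weight replicated, eps is a non-tensor arg (None placement).
fused_rms_norm_fn = local_map(
    _rms_norm_local,
    out_placements=[Shard(1)],
    in_placements=([Shard(1)], [Replicate()], None),
)
```

Calling `fused_rms_norm_fn` on DTensor activations then executes the body on each rank's local shard and re-wraps the result as a DTensor with the declared `Shard(1)` placement.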
lgtm! Thanks for working on this, please fix the CPU test failure before landing
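On the CPU test failure: a common cause in this setup is that Triton is only importable on GPU builds. The guard below is a generic pattern for keeping CPU-only CI green, shown purely as illustration; it is not necessarily the fix applied in this PR, and `check_norm_availability` is a hypothetical helper name.

```python
# Generic guard for optional Triton kernels on CPU-only CI (illustrative).
try:
    import triton  # noqa: F401
    HAS_TRITON = True
except ImportError:
    HAS_TRITON = False


def check_norm_availability(norm_type: str) -> None:
    """Fail fast with a clear message instead of an import-time CI crash."""
    if norm_type == "fused_rmsnorm" and not HAS_TRITON:
        raise NotImplementedError(
            "norm_type='fused_rmsnorm' requires Triton, which needs a GPU build"
        )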
ghstack-source-id: e7a19e3525c6965969150a90b3344ada3e7d4c83 Pull Request resolved: #364
ghstack-source-id: 2d3741d4caaf02724c08f149c1c989992a42793f Pull Request resolved: #364
Summary This PR enables the use of TritonFusedRMSNorm with Tensor Parallel with 7%-8% performance gain compared to RMSNorm with TP. #364
Stack from ghstack (oldest at bottom):
Summary
This PR enables the use of TritonFusedRMSNorm with Tensor Parallel, with a 7%-8% performance gain compared to RMSNorm with TP.
Test Plan
Here's the output of running `CONFIG_FILE=./train_configs/llama3_8b.toml NGPU=8 LOG_RANK=0,1,2,3,4,5,6,7 ./run_llama_train.sh` using 4-way Tensor Parallel (`tensor_parallel_degree = 4`). Detailed settings and the per-step logs for `norm_type = "rmsnorm"` vs. `norm_type = "fused_rmsnorm"` are as in the expanded description above; the fused variant is selected by flipping `norm_type` in the `[model]` section of the config.
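As a rough cross-check of the quoted gain, one can compare steady-state throughput between the two logs. This is a back-of-envelope sketch; the numbers are copied from the step-100 log lines in the expanded description above.

```python
# Steady-state readings from the step-100 log lines above.
rmsnorm_wps, rmsnorm_mfu = 600, 0.1114
fused_wps, fused_mfu = 639, 0.1185

wps_gain = (fused_wps - rmsnorm_wps) / rmsnorm_wps
mfu_gain = (fused_mfu - rmsnorm_mfu) / rmsnorm_mfu
print(f"wps gain: {wps_gain:.1%}, mfu gain: {mfu_gain:.1%}")  # ~6.5% / ~6.4%
```

Individual steps land roughly in a 4%-9% band (e.g. step 15: 642 vs. 597 wps, about 7.5%), bracketing the 7%-8% figure quoted in the summary.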