enable TritonFusedRMSNorm with local_map annotation #364

Merged: 16 commits, Jun 14, 2024

Conversation

@XilunWu XilunWu commented May 25, 2024

Stack from ghstack (oldest at bottom):

**Summary**
This PR enables the use of TritonFusedRMSNorm with Tensor Parallel, yielding a 7%-8% performance gain over RMSNorm with TP.
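
For context, `local_map` is the experimental DTensor API that lets a function written against plain `torch.Tensor` (here, the Triton kernel) run on DTensor inputs by declaring the expected input/output placements. Below is a minimal sketch of the idea; the function names and placements are illustrative, and the eager function stands in for the Triton kernel (the PR's actual wiring lives in torchtitan/models/norms.py):

```python
# Sketch only: exposing a shard-local RMSNorm to Tensor Parallel via local_map.
import torch
from torch.distributed._tensor import Replicate, Shard
from torch.distributed._tensor.experimental import local_map


def _rms_norm_local(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Plain-tensor RMSNorm; under local_map it only ever sees local shards.
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) * weight


# local_map unwraps DTensor args into their local tensors, calls the function,
# and wraps the result back into a DTensor with the declared placements. With
# activations sharded on the sequence dim (Shard(1)) and the norm weight
# replicated, the row-wise norm runs shard-locally with no extra communication.
fused_rms_norm = local_map(
    _rms_norm_local,
    out_placements=[Shard(1)],
    in_placements=([Shard(1)], [Replicate()], None),  # (x, weight, eps)
)
```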

**Test Plan**
Here's the output of running `CONFIG_FILE=./train_configs/llama3_8b.toml NGPU=8 LOG_RANK=0,1,2,3,4,5,6,7 ./run_llama_train.sh` using 4-way Tensor Parallel (`tensor_parallel_degree = 4`). Detailed settings:
```
[job]
dump_folder = "./outputs"
description = "Llama 3 8B training"

[profiling]
enable_profiling = false
save_traces_folder = "profile_trace"
profile_freq = 100

[metrics]
log_freq = 5
enable_tensorboard = false
save_tb_folder = "tb"

[model]
name = "llama3"
flavor = "8B"
norm_type = "rmsnorm"  # [layernorm / np_layernorm / rmsnorm / fused_rmsnorm]
tokenizer_path = "./torchtitan/datasets/tokenizer/original/tokenizer.model"

[optimizer]
name = "AdamW"
lr = 3e-4

[training]
batch_size = 4
seq_len = 8192
warmup_steps = 200  # lr scheduler warm up
max_norm = 1.0  # grad norm clipping
steps = 100
data_parallel_degree = -1
tensor_parallel_degree = 4
pipeline_parallel_degree = 1
fp8_linear = ""
compile = false
dataset = "c4_mini"

[checkpoint]
enable_checkpoint = false
folder = "checkpoint"
interval_type = "steps"
interval = 500
model_weights_only = false
export_dtype = "float32"
async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"]

[activation_checkpoint]
mode = 'selective'  # ['none', 'selective', 'full']
selective_ac_option = 'op'  # 'int' = ac every positive int layer or 'op', ac based on ops policy
  1. with norm_type = "rmsnorm"
[rank2]:2024-06-13 00:47:55,607 - root - INFO - step:  1  loss: 12.2262  memory: 57.70GiB(72.89%)  wps: 429  mfu: 7.96%
[rank2]:2024-06-13 00:48:57,536 - root - INFO - step:  5  loss: 11.4801  memory: 65.53GiB(82.78%)  wps: 529  mfu: 9.82%
[rank2]:2024-06-13 00:50:05,746 - root - INFO - step: 10  loss: 10.2305  memory: 65.53GiB(82.78%)  wps: 601  mfu: 11.15%
[rank2]:2024-06-13 00:51:14,343 - root - INFO - step: 15  loss:  9.3287  memory: 65.53GiB(82.78%)  wps: 597  mfu: 11.09%
[rank2]:2024-06-13 00:52:22,325 - root - INFO - step: 20  loss:  8.7126  memory: 65.53GiB(82.78%)  wps: 603  mfu: 11.19%
[rank2]:2024-06-13 00:53:31,605 - root - INFO - step: 25  loss:  8.2011  memory: 65.53GiB(82.78%)  wps: 591  mfu: 10.98%
[rank2]:2024-06-13 00:54:39,861 - root - INFO - step: 30  loss:  7.7424  memory: 65.53GiB(82.78%)  wps: 600  mfu: 11.14%
[rank2]:2024-06-13 00:55:47,782 - root - INFO - step: 35  loss:  7.4964  memory: 65.53GiB(82.78%)  wps: 603  mfu: 11.20%
[rank2]:2024-06-13 00:56:55,927 - root - INFO - step: 40  loss:  7.2799  memory: 65.53GiB(82.78%)  wps: 601  mfu: 11.16%
[rank2]:2024-06-13 00:58:04,445 - root - INFO - step: 45  loss:  7.2280  memory: 65.53GiB(82.78%)  wps: 598  mfu: 11.10%
[rank2]:2024-06-13 00:59:12,503 - root - INFO - step: 50  loss:  7.0669  memory: 65.53GiB(82.78%)  wps: 602  mfu: 11.17%
[rank2]:2024-06-13 01:00:21,395 - root - INFO - step: 55  loss:  6.9967  memory: 65.53GiB(82.78%)  wps: 595  mfu: 11.04%
[rank2]:2024-06-13 01:01:29,641 - root - INFO - step: 60  loss:  7.0763  memory: 65.53GiB(82.78%)  wps: 600  mfu: 11.14%
[rank2]:2024-06-13 01:02:37,572 - root - INFO - step: 65  loss:  6.9260  memory: 65.53GiB(82.78%)  wps: 603  mfu: 11.20%
[rank2]:2024-06-13 01:03:45,755 - root - INFO - step: 70  loss:  6.9757  memory: 65.53GiB(82.78%)  wps: 601  mfu: 11.15%
[rank2]:2024-06-13 01:04:54,015 - root - INFO - step: 75  loss:  6.8074  memory: 65.53GiB(82.78%)  wps: 600  mfu: 11.14%
[rank2]:2024-06-13 01:06:02,682 - root - INFO - step: 80  loss:  6.7362  memory: 65.53GiB(82.78%)  wps: 597  mfu: 11.08%
[rank2]:2024-06-13 01:07:11,232 - root - INFO - step: 85  loss:  6.7016  memory: 65.53GiB(82.78%)  wps: 598  mfu: 11.09%
[rank2]:2024-06-13 01:08:19,973 - root - INFO - step: 90  loss:  6.6640  memory: 65.53GiB(82.78%)  wps: 596  mfu: 11.06%
[rank2]:2024-06-13 01:09:27,858 - root - INFO - step: 95  loss:  6.7214  memory: 65.53GiB(82.78%)  wps: 604  mfu: 11.20%
[rank2]:2024-06-13 01:10:36,136 - root - INFO - step: 100  loss:  6.5953  memory: 65.53GiB(82.78%)  wps: 600  mfu: 11.14%
  1. with norm_type = "fused_rmsnorm"
[rank2]:2024-06-13 00:19:33,609 - root - INFO - step:  1  loss: 12.2194  memory: 57.31GiB(72.40%)  wps: 412  mfu: 7.64%
[rank2]:2024-06-13 00:20:29,175 - root - INFO - step:  5  loss: 11.4519  memory: 65.13GiB(82.29%)  wps: 590  mfu: 10.95%
[rank2]:2024-06-13 00:21:33,667 - root - INFO - step: 10  loss: 10.2199  memory: 65.13GiB(82.29%)  wps: 635  mfu: 11.79%
[rank2]:2024-06-13 00:22:37,492 - root - INFO - step: 15  loss:  9.3509  memory: 65.13GiB(82.29%)  wps: 642  mfu: 11.92%
[rank2]:2024-06-13 00:23:42,592 - root - INFO - step: 20  loss:  8.7972  memory: 65.13GiB(82.29%)  wps: 629  mfu: 11.68%
[rank2]:2024-06-13 00:24:46,466 - root - INFO - step: 25  loss:  8.2348  memory: 65.13GiB(82.29%)  wps: 642  mfu: 11.91%
[rank2]:2024-06-13 00:25:50,900 - root - INFO - step: 30  loss:  7.7037  memory: 65.13GiB(82.29%)  wps: 636  mfu: 11.80%
[rank2]:2024-06-13 00:26:54,794 - root - INFO - step: 35  loss:  7.4639  memory: 65.13GiB(82.29%)  wps: 641  mfu: 11.90%
[rank2]:2024-06-13 00:27:59,235 - root - INFO - step: 40  loss:  7.2406  memory: 65.13GiB(82.29%)  wps: 636  mfu: 11.80%
[rank2]:2024-06-13 00:29:03,304 - root - INFO - step: 45  loss:  7.1822  memory: 65.13GiB(82.29%)  wps: 640  mfu: 11.87%
[rank2]:2024-06-13 00:30:07,607 - root - INFO - step: 50  loss:  7.0580  memory: 65.13GiB(82.29%)  wps: 637  mfu: 11.83%
[rank2]:2024-06-13 00:31:11,764 - root - INFO - step: 55  loss:  6.9888  memory: 65.13GiB(82.29%)  wps: 639  mfu: 11.86%
[rank2]:2024-06-13 00:32:16,001 - root - INFO - step: 60  loss:  7.0387  memory: 65.13GiB(82.29%)  wps: 638  mfu: 11.84%
[rank2]:2024-06-13 00:33:20,137 - root - INFO - step: 65  loss:  6.9199  memory: 65.13GiB(82.29%)  wps: 639  mfu: 11.86%
[rank2]:2024-06-13 00:34:24,424 - root - INFO - step: 70  loss:  6.9503  memory: 65.13GiB(82.29%)  wps: 637  mfu: 11.83%
[rank2]:2024-06-13 00:35:28,722 - root - INFO - step: 75  loss:  6.7960  memory: 65.13GiB(82.29%)  wps: 637  mfu: 11.83%
[rank2]:2024-06-13 00:36:32,865 - root - INFO - step: 80  loss:  6.6798  memory: 65.13GiB(82.29%)  wps: 639  mfu: 11.86%
[rank2]:2024-06-13 00:37:36,981 - root - INFO - step: 85  loss:  6.6504  memory: 65.13GiB(82.29%)  wps: 639  mfu: 11.86%
[rank2]:2024-06-13 00:38:41,407 - root - INFO - step: 90  loss:  6.6655  memory: 65.13GiB(82.29%)  wps: 636  mfu: 11.81%
[rank2]:2024-06-13 00:39:45,981 - root - INFO - step: 95  loss:  6.7359  memory: 65.13GiB(82.29%)  wps: 635  mfu: 11.78%
[rank2]:2024-06-13 00:40:50,146 - root - INFO - step: 100  loss:  6.5410  memory: 65.13GiB(82.29%)  wps: 639  mfu: 11.85%
```
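
As a cross-check of the wps figures above, a quick back-of-the-envelope sketch (my assumptions, not stated in the logs: wps is per-GPU tokens/sec, and NGPU=8 with `tensor_parallel_degree = 4` implies a data-parallel degree of 2):

```python
# Tokens processed per optimizer step across the whole job.
ngpu, dp_degree, batch_size, seq_len = 8, 2, 4, 8192
tokens_per_step = batch_size * seq_len * dp_degree   # 65,536

# rmsnorm: steps 95 -> 100 span 01:09:27.858 -> 01:10:36.136 = 68.28 s.
rms_step = 68.28 / 5                                 # ~13.66 s/step
print(tokens_per_step / rms_step / ngpu)             # ~600, matches "wps: 600"

# fused_rmsnorm: steps 95 -> 100 span 00:39:45.981 -> 00:40:50.146 = 64.17 s.
fused_step = 64.17 / 5                               # ~12.83 s/step
print(tokens_per_step / fused_step / ngpu)           # ~638, matches "wps: 639"

print(rms_step / fused_step - 1)                     # ~0.064, i.e. ~6-7% faster on this run
```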

XilunWu added a commit that referenced this pull request May 25, 2024
ghstack-source-id: 7af2e42fd9611d83840bc14c9f023cfd65033f21
Pull Request resolved: #364
@facebook-github-bot added the CLA Signed label May 25, 2024
XilunWu commented May 25, 2024

note: this test requires pytorch/pytorch#126924 to land first
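
One way to encode that ordering in the test itself would be a skip guard along these lines (hypothetical sketch, not part of this PR; an import check is only a coarse proxy for that specific fix having landed):

```python
import pytest

try:
    # local_map is the DTensor API this PR's wrapper relies on.
    from torch.distributed._tensor.experimental import local_map  # noqa: F401

    HAS_LOCAL_MAP = True
except ImportError:
    HAS_LOCAL_MAP = False


@pytest.mark.skipif(not HAS_LOCAL_MAP, reason="requires DTensor local_map")
def test_fused_rms_norm_tp():
    ...  # the actual TP/fused-RMSNorm checks would go here
```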

@lessw2020 lessw2020 left a comment

looks good - thanks for implementing this!
two minor nits; the main one is I'm not sure we want to leave TP on by default for debug_model.toml

@wanchaol wanchaol left a comment

We should wait for the pytorch PR to land before landing this :)

XilunWu added a commit that referenced this pull request Jun 4, 2024
ghstack-source-id: 6125011aba1a4bd9521fb4a3b761b62285ea6195
Pull Request resolved: #364
**Test Plan**
Here's the output of running `CONFIG_FILE=./train_configs/llama3_8b.toml NGPU=8 LOG_RANK=0,1,2,3,4,5,6,7 ./run_llama_train.sh` using 4-way Tensor Parallel (`tensor_parallel_degree = 4`):
1. with `norm_type = "rmsnorm"`
```
[rank0]:2024-06-05 11:57:35,505 - root - INFO - step:  1  loss: 12.2703  memory: 24.66GiB(31.15%)  wps: 143  mfu: 2.66%
[rank0]:2024-06-05 11:57:35,505 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank0]:2024-06-05 11:58:11,490 - root - INFO - step: 10  loss: 11.0446  memory: 31.96GiB(40.37%)  wps: 512  mfu: 9.51%
[rank0]:2024-06-05 11:58:46,488 - root - INFO - step: 20  loss:  9.2321  memory: 31.96GiB(40.37%)  wps: 586  mfu: 10.87%
[rank0]:2024-06-05 11:59:22,462 - root - INFO - step: 30  loss:  8.2184  memory: 31.96GiB(40.37%)  wps: 570  mfu: 10.58%
[rank0]:2024-06-05 11:59:57,301 - root - INFO - step: 40  loss:  7.6220  memory: 31.96GiB(40.37%)  wps: 589  mfu: 10.93%
[rank0]:2024-06-05 12:00:32,254 - root - INFO - step: 50  loss:  7.5399  memory: 31.96GiB(40.37%)  wps: 587  mfu: 10.89%
[rank0]:2024-06-05 12:01:07,155 - root - INFO - step: 60  loss:  7.3179  memory: 31.96GiB(40.37%)  wps: 588  mfu: 10.91%
[rank0]:2024-06-05 12:01:41,999 - root - INFO - step: 70  loss:  7.3508  memory: 31.96GiB(40.37%)  wps: 589  mfu: 10.92%
[rank0]:2024-06-05 12:02:17,093 - root - INFO - step: 80  loss:  7.2696  memory: 31.96GiB(40.37%)  wps: 584  mfu: 10.85%
[rank0]:2024-06-05 12:02:52,009 - root - INFO - step: 90  loss:  7.0481  memory: 31.96GiB(40.37%)  wps: 588  mfu: 10.91%
[rank0]:2024-06-05 12:03:27,715 - root - INFO - step: 100  loss:  6.9623  memory: 31.96GiB(40.37%)  wps: 575  mfu: 10.67%
```

2. with `norm_type = "fused_rmsnorm"`
```
[rank0]:2024-06-05 12:08:35,004 - root - INFO - step:  1  loss: 12.2422  memory: 24.62GiB(31.10%)  wps: 95  mfu: 1.76%
[rank0]:2024-06-05 12:08:35,004 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank0]:2024-06-05 12:09:12,401 - root - INFO - step: 10  loss: 11.0361  memory: 32.09GiB(40.54%)  wps: 493  mfu: 9.15%
[rank0]:2024-06-05 12:09:49,380 - root - INFO - step: 20  loss:  9.2725  memory: 32.09GiB(40.54%)  wps: 554  mfu: 10.29%
[rank0]:2024-06-05 12:10:26,645 - root - INFO - step: 30  loss:  8.2091  memory: 32.09GiB(40.54%)  wps: 550  mfu: 10.21%
[rank0]:2024-06-05 12:11:03,616 - root - INFO - step: 40  loss:  7.5601  memory: 32.09GiB(40.54%)  wps: 555  mfu: 10.30%
[rank0]:2024-06-05 12:11:40,625 - root - INFO - step: 50  loss:  7.5144  memory: 32.09GiB(40.54%)  wps: 554  mfu: 10.29%
[rank0]:2024-06-05 12:12:17,768 - root - INFO - step: 60  loss:  7.3869  memory: 32.09GiB(40.54%)  wps: 552  mfu: 10.25%
[rank0]:2024-06-05 12:12:54,820 - root - INFO - step: 70  loss:  7.3358  memory: 32.09GiB(40.54%)  wps: 553  mfu: 10.27%
[rank0]:2024-06-05 12:13:31,817 - root - INFO - step: 80  loss:  7.2085  memory: 32.09GiB(40.54%)  wps: 554  mfu: 10.29%
[rank0]:2024-06-05 12:14:09,156 - root - INFO - step: 90  loss:  7.0140  memory: 32.09GiB(40.54%)  wps: 549  mfu: 10.19%
[rank0]:2024-06-05 12:14:48,518 - root - INFO - step: 100  loss:  6.9507  memory: 32.09GiB(40.54%)  wps: 521  mfu: 9.67%
```

[ghstack-poisoned]
XilunWu added a commit that referenced this pull request Jun 9, 2024
ghstack-source-id: 213ef4323f9888463076ea580c3b72e2359ec492
Pull Request resolved: #364
**Summary**
This PR enables the use of TritonFusedRMSNorm with Tensor Parallel, yielding a 7%-8% performance gain over RMSNorm with TP.

**Test Plan**
Here's the output of running `CONFIG_FILE=./train_configs/llama3_8b.toml NGPU=8 LOG_RANK=0,1,2,3,4,5,6,7 ./run_llama_train.sh` using 4-way Tensor Parallel (`tensor_parallel_degree = 4`):
1. with `norm_type = "rmsnorm"`
```
[rank2]:2024-06-12 13:55:25,005 - root - INFO - step:  1  loss: 12.2971  memory: 23.68GiB(29.92%)  wps: 258  mfu: 4.79%
[rank2]:2024-06-12 13:55:43,082 - root - INFO - step:  5  loss: 11.6237  memory: 30.98GiB(39.14%)  wps: 453  mfu: 8.41%
[rank2]:2024-06-12 13:56:00,742 - root - INFO - step: 10  loss: 10.7210  memory: 30.98GiB(39.14%)  wps: 580  mfu: 10.77%
[rank2]:2024-06-12 13:56:18,274 - root - INFO - step: 15  loss:  9.4563  memory: 30.98GiB(39.14%)  wps: 585  mfu: 10.85%
[rank2]:2024-06-12 13:56:35,888 - root - INFO - step: 20  loss:  8.9246  memory: 30.98GiB(39.14%)  wps: 582  mfu: 10.80%
```

2. with `norm_type = "fused_rmsnorm"`
```
[rank2]:2024-06-12 13:52:48,671 - root - INFO - step:  1  loss: 12.2779  memory: 23.64GiB(29.86%)  wps: 186  mfu: 3.45%
[rank2]:2024-06-12 13:53:06,983 - root - INFO - step:  5  loss: 11.6073  memory: 31.11GiB(39.31%)  wps: 447  mfu: 8.30%
[rank2]:2024-06-12 13:53:23,895 - root - INFO - step: 10  loss: 10.6355  memory: 31.11GiB(39.31%)  wps: 606  mfu: 11.25%
[rank2]:2024-06-12 13:53:41,108 - root - INFO - step: 15  loss:  9.5591  memory: 31.11GiB(39.31%)  wps: 596  mfu: 11.05%
[rank2]:2024-06-12 13:53:58,045 - root - INFO - step: 20  loss:  9.0287  memory: 31.11GiB(39.31%)  wps: 605  mfu: 11.23%
```

[ghstack-poisoned]
XilunWu added a commit that referenced this pull request Jun 12, 2024
ghstack-source-id: bcd66ee725966f2a1670ef25b79240ebab5af249
Pull Request resolved: #364
XilunWu added a commit that referenced this pull request Jun 12, 2024
ghstack-source-id: 8eb1f9ab7cb9e09840ec6982865b4dc032b3f7bc
Pull Request resolved: #364

@wanchaol wanchaol left a comment

Please see the inline comments

Inline review comments on torchtitan/models/norms.py and test/test_fused_rms_norm.py (outdated, resolved)
@wanchaol

your perf benchmark seems to be using batch_size = 1; can you rerun with batch_size = 4 and update the perf table?

XilunWu added a commit that referenced this pull request Jun 12, 2024
ghstack-source-id: 070cee68512a890d0f2780732ade3b29b7142948
Pull Request resolved: #364
XilunWu added a commit that referenced this pull request Jun 13, 2024
ghstack-source-id: f01e04f5b60f17f83bd63846db83e2528c8881a4
Pull Request resolved: #364
XilunWu added a commit that referenced this pull request Jun 13, 2024
ghstack-source-id: 169b1a9f70dce3d0acdd56889e94f3976fecb811
Pull Request resolved: #364
@XilunWu XilunWu requested a review from wanchaol June 13, 2024 17:11
XilunWu added a commit that referenced this pull request Jun 13, 2024
ghstack-source-id: 0b3c077b185c24438b24365a37e64aab959ac32e
Pull Request resolved: #364
XilunWu added a commit that referenced this pull request Jun 13, 2024
ghstack-source-id: ab2f2bef8172fe3620fd454ef6c83643529eb5e5
Pull Request resolved: #364
XilunWu added a commit that referenced this pull request Jun 13, 2024
ghstack-source-id: 169b1a9f70dce3d0acdd56889e94f3976fecb811
Pull Request resolved: #364
**Summary**
This PR enables the use of TritonFusedRMSNorm with Tensor Parallel with 7%-8% performance gain compared to RMSNorm with TP.

**Test Plan**
Here's the output of running `CONFIG_FILE=./train_configs/llama3_8b.toml NGPU=8 LOG_RANK=0,1,2,3,4,5,6,7 ./run_llama_train.sh` using 4-way Tensor Parallel (`tensor_parallel_degree = 4`). Detailed settings:
```
[job]
dump_folder = "./outputs"
description = "Llama 3 8B training"

[profiling]
enable_profiling = false
save_traces_folder = "profile_trace"
profile_freq = 100

[metrics]
log_freq = 5
enable_tensorboard = false
save_tb_folder = "tb"

[model]
name = "llama3"
flavor = "8B"
norm_type = "rmsnorm"  # [layernorm / np_layernorm / rmsnorm / fused_rmsnorm]
tokenizer_path = "./torchtitan/datasets/tokenizer/original/tokenizer.model"

[optimizer]
name = "AdamW"
lr = 3e-4

[training]
batch_size = 4
seq_len = 8192
warmup_steps = 200  # lr scheduler warm up
max_norm = 1.0  # grad norm clipping
steps = 100
data_parallel_degree = -1
tensor_parallel_degree = 4
pipeline_parallel_degree = 1
fp8_linear = ""
compile = false
dataset = "c4_mini"

[checkpoint]
enable_checkpoint = false
folder = "checkpoint"
interval_type = "steps"
interval = 500
model_weights_only = false
export_dtype = "float32"
async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"]

[activation_checkpoint]
mode = 'selective'  # ['none', 'selective', 'full']
selective_ac_option = 'op'  # 'int' = ac every positive int layer or 'op', ac based on ops policy
```
1. with `norm_type = "rmsnorm"`
```
[rank2]:2024-06-13 00:47:55,607 - root - INFO - step:  1  loss: 12.2262  memory: 57.70GiB(72.89%)  wps: 429  mfu: 7.96%
[rank2]:2024-06-13 00:48:57,536 - root - INFO - step:  5  loss: 11.4801  memory: 65.53GiB(82.78%)  wps: 529  mfu: 9.82%
[rank2]:2024-06-13 00:50:05,746 - root - INFO - step: 10  loss: 10.2305  memory: 65.53GiB(82.78%)  wps: 601  mfu: 11.15%
[rank2]:2024-06-13 00:51:14,343 - root - INFO - step: 15  loss:  9.3287  memory: 65.53GiB(82.78%)  wps: 597  mfu: 11.09%
[rank2]:2024-06-13 00:52:22,325 - root - INFO - step: 20  loss:  8.7126  memory: 65.53GiB(82.78%)  wps: 603  mfu: 11.19%
[rank2]:2024-06-13 00:53:31,605 - root - INFO - step: 25  loss:  8.2011  memory: 65.53GiB(82.78%)  wps: 591  mfu: 10.98%
[rank2]:2024-06-13 00:54:39,861 - root - INFO - step: 30  loss:  7.7424  memory: 65.53GiB(82.78%)  wps: 600  mfu: 11.14%
[rank2]:2024-06-13 00:55:47,782 - root - INFO - step: 35  loss:  7.4964  memory: 65.53GiB(82.78%)  wps: 603  mfu: 11.20%
[rank2]:2024-06-13 00:56:55,927 - root - INFO - step: 40  loss:  7.2799  memory: 65.53GiB(82.78%)  wps: 601  mfu: 11.16%
[rank2]:2024-06-13 00:58:04,445 - root - INFO - step: 45  loss:  7.2280  memory: 65.53GiB(82.78%)  wps: 598  mfu: 11.10%
[rank2]:2024-06-13 00:59:12,503 - root - INFO - step: 50  loss:  7.0669  memory: 65.53GiB(82.78%)  wps: 602  mfu: 11.17%
[rank2]:2024-06-13 01:00:21,395 - root - INFO - step: 55  loss:  6.9967  memory: 65.53GiB(82.78%)  wps: 595  mfu: 11.04%
[rank2]:2024-06-13 01:01:29,641 - root - INFO - step: 60  loss:  7.0763  memory: 65.53GiB(82.78%)  wps: 600  mfu: 11.14%
[rank2]:2024-06-13 01:02:37,572 - root - INFO - step: 65  loss:  6.9260  memory: 65.53GiB(82.78%)  wps: 603  mfu: 11.20%
[rank2]:2024-06-13 01:03:45,755 - root - INFO - step: 70  loss:  6.9757  memory: 65.53GiB(82.78%)  wps: 601  mfu: 11.15%
[rank2]:2024-06-13 01:04:54,015 - root - INFO - step: 75  loss:  6.8074  memory: 65.53GiB(82.78%)  wps: 600  mfu: 11.14%
[rank2]:2024-06-13 01:06:02,682 - root - INFO - step: 80  loss:  6.7362  memory: 65.53GiB(82.78%)  wps: 597  mfu: 11.08%
[rank2]:2024-06-13 01:07:11,232 - root - INFO - step: 85  loss:  6.7016  memory: 65.53GiB(82.78%)  wps: 598  mfu: 11.09%
[rank2]:2024-06-13 01:08:19,973 - root - INFO - step: 90  loss:  6.6640  memory: 65.53GiB(82.78%)  wps: 596  mfu: 11.06%
[rank2]:2024-06-13 01:09:27,858 - root - INFO - step: 95  loss:  6.7214  memory: 65.53GiB(82.78%)  wps: 604  mfu: 11.20%
[rank2]:2024-06-13 01:10:36,136 - root - INFO - step: 100  loss:  6.5953  memory: 65.53GiB(82.78%)  wps: 600  mfu: 11.14%
```

2. with `norm_type = "fused_rmsnorm"`
```
[rank2]:2024-06-13 00:19:33,609 - root - INFO - step:  1  loss: 12.2194  memory: 57.31GiB(72.40%)  wps: 412  mfu: 7.64%
[rank2]:2024-06-13 00:20:29,175 - root - INFO - step:  5  loss: 11.4519  memory: 65.13GiB(82.29%)  wps: 590  mfu: 10.95%
[rank2]:2024-06-13 00:21:33,667 - root - INFO - step: 10  loss: 10.2199  memory: 65.13GiB(82.29%)  wps: 635  mfu: 11.79%
[rank2]:2024-06-13 00:22:37,492 - root - INFO - step: 15  loss:  9.3509  memory: 65.13GiB(82.29%)  wps: 642  mfu: 11.92%
[rank2]:2024-06-13 00:23:42,592 - root - INFO - step: 20  loss:  8.7972  memory: 65.13GiB(82.29%)  wps: 629  mfu: 11.68%
[rank2]:2024-06-13 00:24:46,466 - root - INFO - step: 25  loss:  8.2348  memory: 65.13GiB(82.29%)  wps: 642  mfu: 11.91%
[rank2]:2024-06-13 00:25:50,900 - root - INFO - step: 30  loss:  7.7037  memory: 65.13GiB(82.29%)  wps: 636  mfu: 11.80%
[rank2]:2024-06-13 00:26:54,794 - root - INFO - step: 35  loss:  7.4639  memory: 65.13GiB(82.29%)  wps: 641  mfu: 11.90%
[rank2]:2024-06-13 00:27:59,235 - root - INFO - step: 40  loss:  7.2406  memory: 65.13GiB(82.29%)  wps: 636  mfu: 11.80%
[rank2]:2024-06-13 00:29:03,304 - root - INFO - step: 45  loss:  7.1822  memory: 65.13GiB(82.29%)  wps: 640  mfu: 11.87%
[rank2]:2024-06-13 00:30:07,607 - root - INFO - step: 50  loss:  7.0580  memory: 65.13GiB(82.29%)  wps: 637  mfu: 11.83%
[rank2]:2024-06-13 00:31:11,764 - root - INFO - step: 55  loss:  6.9888  memory: 65.13GiB(82.29%)  wps: 639  mfu: 11.86%
[rank2]:2024-06-13 00:32:16,001 - root - INFO - step: 60  loss:  7.0387  memory: 65.13GiB(82.29%)  wps: 638  mfu: 11.84%
[rank2]:2024-06-13 00:33:20,137 - root - INFO - step: 65  loss:  6.9199  memory: 65.13GiB(82.29%)  wps: 639  mfu: 11.86%
[rank2]:2024-06-13 00:34:24,424 - root - INFO - step: 70  loss:  6.9503  memory: 65.13GiB(82.29%)  wps: 637  mfu: 11.83%
[rank2]:2024-06-13 00:35:28,722 - root - INFO - step: 75  loss:  6.7960  memory: 65.13GiB(82.29%)  wps: 637  mfu: 11.83%
[rank2]:2024-06-13 00:36:32,865 - root - INFO - step: 80  loss:  6.6798  memory: 65.13GiB(82.29%)  wps: 639  mfu: 11.86%
[rank2]:2024-06-13 00:37:36,981 - root - INFO - step: 85  loss:  6.6504  memory: 65.13GiB(82.29%)  wps: 639  mfu: 11.86%
[rank2]:2024-06-13 00:38:41,407 - root - INFO - step: 90  loss:  6.6655  memory: 65.13GiB(82.29%)  wps: 636  mfu: 11.81%
[rank2]:2024-06-13 00:39:45,981 - root - INFO - step: 95  loss:  6.7359  memory: 65.13GiB(82.29%)  wps: 635  mfu: 11.78%
[rank2]:2024-06-13 00:40:50,146 - root - INFO - step: 100  loss:  6.5410  memory: 65.13GiB(82.29%)  wps: 639  mfu: 11.85%
```
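
For context on the `local_map` annotation this PR relies on, here is a minimal, hypothetical sketch of how a shard-local kernel can be made to accept DTensor inputs under Tensor Parallel. This is not the code in this PR: the eager RMSNorm stand-in for the Triton kernel, the argument layout, and the placements (sequence-sharded activation, replicated weight on a 1-D TP mesh) are illustrative assumptions on top of the `torch.distributed._tensor.experimental.local_map` API.

```python
# Hypothetical sketch, not the code in this PR.
import torch
from torch.distributed._tensor import Replicate, Shard
from torch.distributed._tensor.experimental import local_map


def _rms_norm_local(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6):
    # Stand-in for the Triton kernel: plain eager RMSNorm on the local shard.
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) * weight


# local_map unwraps DTensor args into their local shards (validated against
# in_placements), runs the function on each rank, and rewraps the output
# according to out_placements. `None` marks a non-DTensor argument (eps).
rms_norm_tp = local_map(
    _rms_norm_local,
    out_placements=[Shard(1)],
    in_placements=([Shard(1)], [Replicate()], None),
)
```

The appeal of the annotation is that the kernel itself stays oblivious to DTensor; only the wrapper carries the placement contract, which is what lets a fused Triton kernel run under TP without a hand-written sharding rule.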

XilunWu added a commit that referenced this pull request Jun 13, 2024
ghstack-source-id: ab2f2bef8172fe3620fd454ef6c83643529eb5e5
Pull Request resolved: #364
XilunWu added a commit that referenced this pull request Jun 13, 2024
ghstack-source-id: 778bc18e027af989f522e5f7bd6177c3cfe34521
Pull Request resolved: #364
@wanchaol
Contributor

@XilunWu The WPS for 8B in your summary is still not looking right. I have exactly the same settings, but the WPS on my side is something like this:

```
[rank0]:2024-06-13 12:53:16,156 - root - INFO - step:  1  loss: 12.2550  memory: 33.35GiB(35.09%)  wps: 542  mfu: 3.17%
[rank0]:2024-06-13 12:53:16,156 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank0]:2024-06-13 12:53:42,562 - root - INFO - step: 10  loss: 10.7798  memory: 41.18GiB(43.33%)  wps: 2,792  mfu: 16.35%
[rank0]:2024-06-13 12:54:04,065 - root - INFO - step: 20  loss:  9.1087  memory: 41.18GiB(43.33%)  wps: 3,812  mfu: 22.32%
[rank0]:2024-06-13 12:54:25,626 - root - INFO - step: 30  loss:  7.9951  memory: 41.18GiB(43.33%)  wps: 3,802  mfu: 22.27%
```

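When comparing numbers across these runs, note that, roughly, `wps` is tokens processed per second per device and `mfu` is that throughput measured against an analytic FLOPs-per-token estimate and the GPU's peak FLOPs, so both move with the parallelism layout, not just kernel speed. A rough sketch of the arithmetic (hypothetical helper, not torchtitan's actual metrics code):

```python
# Rough sketch of the wps/mfu bookkeeping (hypothetical helper).
def throughput(tokens_on_rank: int, seconds: float,
               flops_per_token: float, peak_flops_per_sec: float):
    wps = tokens_on_rank / seconds  # tokens ("words") per second
    mfu_pct = 100.0 * flops_per_token * wps / peak_flops_per_sec  # % of peak
    return wps, mfu_pct
```
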
XilunWu added a commit that referenced this pull request Jun 13, 2024
ghstack-source-id: 421081d28e170466302269c3a4703d2291e76e55
Pull Request resolved: #364
Contributor

@wanchaol wanchaol left a comment

lgtm! Thanks for working on this, please fix the CPU test failure before landing

XilunWu added a commit that referenced this pull request Jun 13, 2024
ghstack-source-id: e7a19e3525c6965969150a90b3344ada3e7d4c83
Pull Request resolved: #364
XilunWu added a commit that referenced this pull request Jun 13, 2024
ghstack-source-id: 2d3741d4caaf02724c08f149c1c989992a42793f
Pull Request resolved: #364
@XilunWu XilunWu merged commit 264340d into gh/XilunWu/2/base Jun 14, 2024
6 checks passed
XilunWu added a commit that referenced this pull request Jun 14, 2024
Summary: This PR enables the use of TritonFusedRMSNorm with Tensor Parallel with 7%-8% performance gain compared to RMSNorm with TP. #364
tianyu-l pushed a commit to tianyu-l/torchtitan_intern24 that referenced this pull request Aug 16, 2024
Summary: This PR enables the use of TritonFusedRMSNorm with Tensor Parallel with 7%-8% performance gain compared to RMSNorm with TP. pytorch#364
tianyu-l pushed a commit that referenced this pull request Aug 16, 2024
ghstack-source-id: 2d3741d4caaf02724c08f149c1c989992a42793f
Pull Request resolved: #364
tianyu-l pushed a commit that referenced this pull request Aug 16, 2024
This PR enables the use of TritonFusedRMSNorm with Tensor Parallel with 7%-8% performance gain compared to RMSNorm with TP.
philippguevorguian pushed a commit to YerevaNN/YNNtitan that referenced this pull request Aug 17, 2024
Summary: This PR enables the use of TritonFusedRMSNorm with Tensor Parallel with 7%-8% performance gain compared to RMSNorm with TP. pytorch#364