ZBVZeroBubble error #774

Open
hhaAndroid opened this issue Jan 3, 2025 · 5 comments
@hhaAndroid

[experimental]
context_parallel_degree = 1
pipeline_parallel_degree = 4
pipeline_parallel_microbatches = 8
pipeline_parallel_schedule = 'ZBVZeroBubble'

world-size is 8
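For context, a rough sketch of the layout this config implies (assuming the remaining world size goes to data parallelism, which is my assumption and not stated in the config):

# Assumed layout for the config above; the dp split is an assumption, not from the config.
world_size = 8
pipeline_parallel_degree = 4                          # 4 pipeline (PP) ranks
dp_degree = world_size // pipeline_parallel_degree    # 2, assumed data-parallel
# ZBVZeroBubble is a "V" schedule: 2 model chunks per PP rank,
# i.e. 4 * 2 = 8 pipeline stages, stepped over the configured 8 microbatches.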

[rank3]:[rank3]: Traceback (most recent call last):
[rank3]:[rank3]:   File "code/torchtitan/train.py", line 429, in <module>
[rank3]:[rank3]:     main(config)
[rank3]:[rank3]:   File "miniconda3/envs/torchtitan/lib/python3.10/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 355, in wrapper
[rank3]:[rank3]:     return f(*args, **kwargs)
[rank3]:[rank3]:   File "/code/torchtitan/train.py", line 290, in main
[rank3]:[rank3]:     pp_schedule.step()
[rank3]:[rank3]:   File "/miniconda3/envs/torchtitan/lib/python3.10/site-packages/torch/distributed/pipelining/schedules.py", line 1194, in step
[rank3]:[rank3]:     self._step_microbatches(args_split, kwargs_split, targets_split, losses)
[rank3]:[rank3]:   File "/miniconda3/envs/torchtitan/lib/python3.10/site-packages/torch/distributed/pipelining/schedules.py", line 1372, in _step_microbatches
[rank3]:[rank3]:     raise e
[rank3]:[rank3]:   File "/miniconda3/envs/torchtitan/lib/python3.10/site-packages/torch/distributed/pipelining/schedules.py", line 1251, in _step_microbatches
[rank3]:[rank3]:     stage = stage_index_to_stage[stage_index]
[rank3]:[rank3]: KeyError: 6
[rank2]:[rank2]:E0103 05:54:18.993137 491746 site-packages/torch/distributed/pipelining/schedules.py:1358] [Rank 1] pipeline schedule ScheduleZBVZeroBubble caught the following exception at time_step 6 when running action 6F0
[rank2]:[rank2]:E0103 05:54:18.996414 491746 site-packages/torch/distributed/pipelining/schedules.py:1366]          Rank 0  Rank 1  Rank 2  Rank 3 
[rank2]:[rank2]:E0103 05:54:18.996414 491746 site-packages/torch/distributed/pipelining/schedules.py:1366] Step 00: 0F0                            
[rank2]:[rank2]:E0103 05:54:18.996414 491746 site-packages/torch/distributed/pipelining/schedules.py:1366] Step 01: 0F1     1F0                    
[rank2]:[rank2]:E0103 05:54:18.996414 491746 site-packages/torch/distributed/pipelining/schedules.py:1366] Step 02: 0F2     1F1     2F0            
[rank2]:[rank2]:E0103 05:54:18.996414 491746 site-packages/torch/distributed/pipelining/schedules.py:1366] Step 03: 0F3     1F2     2F1     3F0    
[rank2]:[rank2]:E0103 05:54:18.996414 491746 site-packages/torch/distributed/pipelining/schedules.py:1366] Step 04: 0F4     1F3     2F2     4F0    
[rank2]:[rank2]:E0103 05:54:18.996414 491746 site-packages/torch/distributed/pipelining/schedules.py:1366] Step 05: 0F5     1F4     5F0     3F1    
[rank2]:[rank2]:E0103 05:54:18.996414 491746 site-packages/torch/distributed/pipelining/schedules.py:1366] Step 06: 0F6     6F0     2F3     4F1     <-- ERROR HERE
[rank2]:[rank2]:E0103 05:54:18.996414 491746 site-packages/torch/distributed/pipelining/schedules.py:1366] Step 07: 7F0     1F5     5F1     3F2    
[rank2]:[rank2]:E0103 05:54:18.996414 491746 site-packages/torch/distributed/pipelining/schedules.py:1366] Step 08: 7I0     6F1     2F4     4F2    
@hhaAndroid
Author

If I replace the 'loop' style with the 'v' style, I get:

[rank2]:[rank2]:   File "/miniconda3/envs/torchtitan/lib/python3.10/site-packages/torch/distributed/pipelining/stage.py", line 1467, in _prepare_forward_infra
[rank2]:[rank2]:     outputs = self._shape_inference(args, kwargs)
[rank2]:[rank2]:   File "/miniconda3/envs/torchtitan/lib/python3.10/site-packages/torch/distributed/pipelining/stage.py", line 1405, in _shape_inference
[rank2]:[rank2]:     outputs = self.submod(*args, **kwargs)
[rank2]:[rank2]:   File "/miniconda3/envs/torchtitan/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
[rank2]:[rank2]:     return self._call_impl(*args, **kwargs)
[rank2]:[rank2]:   File "/miniconda3/envs/torchtitan/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1845, in _call_impl
[rank2]:[rank2]:     return inner()
[rank2]:[rank2]:   File "/miniconda3/envs/torchtitan/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1772, in inner
[rank2]:[rank2]:     args_kwargs_result = hook(self, args, kwargs)  # type: ignore[misc]
[rank2]:[rank2]:   File "/miniconda3/envs/torchtitan/lib/python3.10/site-packages/torch/distributed/fsdp/_fully_shard/_fsdp_state.py", line 71, in fsdp_hook_wrapper
[rank2]:[rank2]:     return torch._dynamo.disable(func, recursive=True)(*args, **kwargs)
[rank2]:[rank2]:   File "/miniconda3/envs/torchtitan/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 744, in _fn
[rank2]:[rank2]:     return fn(*args, **kwargs)
[rank2]:[rank2]:   File "/miniconda3/envs/torchtitan/lib/python3.10/site-packages/torch/distributed/fsdp/_fully_shard/_fsdp_state.py", line 231, in _pre_forward
[rank2]:[rank2]:     args, kwargs = self._root_pre_forward(module, args, kwargs)
[rank2]:[rank2]:   File "/miniconda3/envs/torchtitan/lib/python3.10/site-packages/torch/distributed/fsdp/_fully_shard/_fsdp_state.py", line 144, in _root_pre_forward
[rank2]:[rank2]:     args, kwargs = args_tuple[0], kwargs_tuple[0]
[rank2]:[rank2]: IndexError: tuple index out of range

@wconstab
Contributor

wconstab commented Jan 3, 2025

Can you include your full repro instructions (at least: the whole toml file, the command line, and the branch/commit of torchtitan and pytorch being used)? Just sharing the pipeline-specific snippet from the config isn't enough for me to run and reproduce your issue.

@tianyu-l
Contributor

tianyu-l commented Jan 4, 2025

@wconstab I think this feature is missing a proper interface in torchtitan. Here's what @H-Huang told me:

ZBVSchedule requires that the stages on each rank be ordered differently (v-formation). In titan we do interleaved by default, so something like:

Rank 0: 0, 4, 8
Rank 1: 1, 5, 9
Rank 2: 2, 6, 10
Rank 3: 3, 7, 11

Allocating the stages in a V shape instead requires something like:

Rank 0: 0 7 8
Rank 1: 1 6 9
Rank 2: 2 5 10
Rank 3: 3 4 11
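To make the two orderings concrete, here is a minimal sketch (not torchtitan's actual implementation; stage_ids is a hypothetical helper) that reproduces the assignments above:

# Hypothetical helper, for illustration only; torchtitan's real logic lives in
# stage_ids_this_rank and may differ in signature and details.
def stage_ids(rank: int, pp_size: int, num_stages: int, style: str = "loop") -> list[int]:
    chunks = num_stages // pp_size
    if style == "loop":
        # Interleaved ("loop"): rank r gets r, r + pp_size, r + 2*pp_size, ...
        return [rank + c * pp_size for c in range(chunks)]
    if style == "v":
        # "V": even chunks count forward, odd chunks count backward,
        # so consecutive chunks on a rank sit at mirrored positions.
        return [
            c * pp_size + rank if c % 2 == 0 else (c + 1) * pp_size - 1 - rank
            for c in range(chunks)
        ]
    raise ValueError(f"unknown style {style!r}")

# pp_size=4, num_stages=12 reproduces the tables above, e.g.:
# stage_ids(0, 4, 12, "loop") -> [0, 4, 8]    stage_ids(0, 4, 12, "v") -> [0, 7, 8]
# stage_ids(3, 4, 12, "loop") -> [3, 7, 11]   stage_ids(3, 4, 12, "v") -> [3, 4, 11]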

Did you mean to do InterleavedZeroBubSchedule instead? Or were you trying out the ZBV schedule? ZBV is new, I still haven't added a test for it in titan, and I need to figure out a UX-friendly way to do it.

@hhaAndroid
Author

Compared to this issue, #773 is more important to me; I hope there is a way to address it, as the ZeroBubble TGS is quite low, which makes it ineffective in practice.

@hhaAndroid
Author

Re: "Did you mean to do InterleavedZeroBubSchedule instead?"

I found that the stage_ids_this_rank function in titan supports the 'v' style, so I modified the code, but I encountered some issues as well.
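With the sketch above and this issue's config (pp=4, hence 8 stages at 2 chunks per rank), the difference the two styles make on a single rank looks roughly like:

# Illustrative values only; PP rank 1 is the rank the error log shows running action 6F0.
pp_rank, pp_size, num_stages = 1, 4, 8
stage_ids(pp_rank, pp_size, num_stages, style="loop")  # -> [1, 5]
stage_ids(pp_rank, pp_size, num_stages, style="v")     # -> [1, 6]

With the loop-style allocation there is no stage 6 on that rank, which is presumably why the ZBVZeroBubble action 6F0 in the log above fails with KeyError: 6.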

@tianyu-l tianyu-l added the bug Something isn't working label Jan 7, 2025