Gradient clipping doesn't work with FSDP CPU offloading #1977

Closed
acisseJZhong opened this issue Nov 9, 2024 · 16 comments
Labels
bug Something isn't working

Comments

@acisseJZhong
Contributor

acisseJZhong commented Nov 9, 2024

I am running the full finetune distributed recipe. When setting clip_grad_norm: 1.0 and fsdp_cpu_offload: True, it raises the error:
RuntimeError: No backend type associated with device type cpu

Full error stack trace:

[rank2]: Traceback (most recent call last):
[rank2]:   File "/torchtune/recipes/full_finetune_distributed.py", line 847, in <module>
[rank2]:     sys.exit(recipe_main())
[rank2]:              ^^^^^^^^^^^^^
[rank2]:   File "/torchtune/torchtune/config/_parse.py", line 99, in wrapper
[rank2]:     sys.exit(recipe_main(conf))
[rank2]:              ^^^^^^^^^^^^^^^^^
[rank2]:   File "/torchtune/recipes/full_finetune_distributed.py", line 842, in recipe_main
[rank2]:     recipe.train()
[rank2]:   File "/torchtune/recipes/full_finetune_distributed.py", line 740, in train
[rank2]:     grad_norm = torch.nn.utils.clip_grad_norm_(
[rank2]:                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]:   File "/home/jessicazhong/.conda/envs/torchtune/lib/python3.11/site-packages/torch/nn/utils/clip_grad.py", line 30, in _no_grad_wrapper
[rank2]:     return func(*args, **kwargs)
[rank2]:            ^^^^^^^^^^^^^^^^^^^^^
[rank2]:   File "/home/jessicazhong/.conda/envs/torchtune/lib/python3.11/site-packages/torch/nn/utils/clip_grad.py", line 105, in clip_grad_norm_
[rank2]:     clip_coef = max_norm / (total_norm + 1e-6)
[rank2]:                 ~~~~~~~~~^~~~~~~~~~~~~~~~~~~~~
[rank2]:   File "/home/jessicazhong/.conda/envs/torchtune/lib/python3.11/site-packages/torch/_tensor.py", line 39, in wrapped
[rank2]:     return f(*args, **kwargs)
[rank2]:            ^^^^^^^^^^^^^^^^^^
[rank2]:   File "/home/jessicazhong/.conda/envs/torchtune/lib/python3.11/site-packages/torch/_tensor.py", line 1075, in __rdiv__
[rank2]:     return self.reciprocal() * other
[rank2]:            ^^^^^^^^^^^^^^^^^
[rank2]:   File "/home/jessicazhong/.conda/envs/torchtune/lib/python3.11/site-packages/torch/_compile.py", line 32, in inner
[rank2]:     return disable_fn(*args, **kwargs)
[rank2]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]:   File "/home/jessicazhong/.conda/envs/torchtune/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 721, in _fn
[rank2]:     return fn(*args, **kwargs)
[rank2]:            ^^^^^^^^^^^^^^^^^^^
[rank2]:   File "/home/jessicazhong/.conda/envs/torchtune/lib/python3.11/site-packages/torch/distributed/tensor/_api.py", line 340, in __torch_dispatch__
[rank2]:     return DTensor._op_dispatcher.dispatch(
[rank2]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]:   File "/home/jessicazhong/.conda/envs/torchtune/lib/python3.11/site-packages/torch/distributed/tensor/_dispatch.py", line 181, in dispatch
[rank2]:     self.redistribute_local_args(
[rank2]:   File "/home/jessicazhong/.conda/envs/torchtune/lib/python3.11/site-packages/torch/distributed/tensor/_dispatch.py", line 317, in redistribute_local_args
[rank2]:     resharded_local_tensor = redistribute_local_tensor(
[rank2]:                              ^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]:   File "/home/jessicazhong/.conda/envs/torchtune/lib/python3.11/site-packages/torch/distributed/tensor/_redistribute.py", line 208, in redistribute_local_tensor
[rank2]:     new_local_tensor = partial_spec._reduce_value(
[rank2]:                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]:   File "/home/jessicazhong/.conda/envs/torchtune/lib/python3.11/site-packages/torch/distributed/tensor/_ops/_math_ops.py", line 126, in _reduce_value
[rank2]:     reduced_tensor = super()._reduce_value(tensor, mesh, mesh_dim)
[rank2]:                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]:   File "/home/jessicazhong/.conda/envs/torchtune/lib/python3.11/site-packages/torch/distributed/tensor/placement_types.py", line 599, in _reduce_value
[rank2]:     return funcol.all_reduce(
[rank2]:            ^^^^^^^^^^^^^^^^^^
[rank2]:   File "/home/jessicazhong/.conda/envs/torchtune/lib/python3.11/site-packages/torch/distributed/_functional_collectives.py", line 176, in all_reduce
[rank2]:     tensor = torch.ops._c10d_functional.all_reduce(self, reduceOp.lower(), group_name)
[rank2]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]:   File "/home/jessicazhong/.conda/envs/torchtune/lib/python3.11/site-packages/torch/_ops.py", line 1123, in __call__
[rank2]:     return self._op(*args, **(kwargs or {}))
[rank2]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]: RuntimeError: No backend type associated with device type cpu

Wondering how we should fix this error?

@acisseJZhong changed the title from "Gradient clipping doesn't work with FSDP cpu offloading" to "Gradient clipping doesn't work with FSDP CPU offloading" Nov 9, 2024
@acisseJZhong added the bug label Nov 9, 2024
@felipemello1
Contributor

@ebsmothers, do you think it would make sense to ping someone from FSDP?

@RdoubleA
Contributor

RdoubleA commented Nov 9, 2024

Could you try modifying the init_process_group call to use the gloo backend for CPU? Perhaps it should initialize both NCCL for GPU and gloo for CPU?

https://github.com/pytorch/torchtune/blob/main/recipes/full_finetune_distributed.py#L903
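For reference, a minimal sketch of what that suggestion could look like (illustrative only, not the actual recipe code; the exact call site in torchtune may differ):

import torch.distributed as dist

# Register NCCL for CUDA tensors and gloo for CPU tensors on the same default
# process group, instead of NCCL only.
dist.init_process_group(backend="cuda:nccl,cpu:gloo")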

@ebsmothers
Contributor

I don’t think we want to modify init_process_group here. To me that error indicates that we are trying to call some comms primitive on a tensor that’s already on CPU, which we shouldn’t be doing. Initializing process group on CPU would only be helpful if we actually want distributed training on CPU, which we don’t. Let’s debug a bit more and then we can loop in distributed folks if needed.

@gau-nernst
Contributor

I believe that when CPU offload is used in FSDP, gradients are transferred to CPU during the backward pass (to free up gradient memory, similar to optimizer-in-backward) so that the optimizer step can be performed on CPU. That's probably why you see the cpu device there: the gradients are on CPU now. They are DTensors, so when you run gradient clipping, which calls .sum() or something of the sort, it will try to do an all-reduce, hence the error.

It's probably faster to check with the distributed folks whether FSDP with CPU offload supports gradient clipping in general. Even if it is technically possible (e.g. doing the clipping on CPU), I think it would be too slow and would possibly require changes to FSDP internals.
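To make the failure point concrete, here is a rough hand-written sketch of what global grad-norm clipping boils down to for sharded gradients. This is not the torchtune or DTensor code path (there the reduction happens implicitly when a Partial placement is redistributed), and the function name is made up; the point is that the all-reduce runs on whatever device the gradients live on:

import torch
import torch.distributed as dist

def clip_sharded_grad_norm_(params, max_norm, eps=1e-6):
    # Each rank only holds its local shard of every gradient.
    grads = [p.grad for p in params if p.grad is not None]
    # Local contribution to the squared global L2 norm.
    total_sq = torch.stack([g.detach().float().norm(2) ** 2 for g in grads]).sum()
    # The collective: with fsdp_cpu_offload the gradients (and hence total_sq)
    # live on CPU, so this call needs a CPU-capable backend such as gloo.
    dist.all_reduce(total_sq, op=dist.ReduceOp.SUM)
    total_norm = total_sq.sqrt()
    clip_coef = torch.clamp(max_norm / (total_norm + eps), max=1.0)
    for g in grads:
        g.detach().mul_(clip_coef.to(g.device))
    return total_norm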

@vancoykendall
Contributor

Looks like the torchtitan repo ran into the same issue and someone created a quick workaround in a separate branch:
https://github.com/pytorch/torchtitan/pull/622/files

@gordicaleksa

gordicaleksa commented Nov 29, 2024

I'm hitting the same issue when doing a full fine-tune of a 70B Llama on a single node.

Is there any "proper" way of solving this? I'll check out the torchtitan solution.

edit: the torchtitan solution is basically just "don't use grad clipping"?

@ebsmothers
Contributor

@gordicaleksa yeah, you're right: it seems to me like the error in that PR is the same as what's being described in this issue. cc @weifengpy @mori360, what is the status of pytorch/torchtitan#622?

@mori360
Contributor

mori360 commented Dec 2, 2024

https://github.com/pytorch/torchtitan/pull/622/files

@ebsmothers
The branch in #622 has not landed; instead we added an optional buffer_device input to handle the CPU offloading case in pytorch/torchtitan#624. Only in the CPU offloading case is buffer_device="cuda"; otherwise buffer_device=None and self.freqs_cis.device is used.

To deal with the backend issue across device types during offloading, we used
torch.distributed.init_process_group(backend="cuda:nccl,cpu:gloo", ...)
instead of
torch.distributed.init_process_group(backend="nccl", ...)

@weifengpy
Contributor

https://github.com/pytorch/torchtitan/pull/622/files

The branch here has not landed; instead we applied an optional buffer_device input to solve the CPU offloading case. pytorch/torchtitan#624

I thought we landed a PR to support CPU offloading in torchtitan?

@weifengpy
Contributor

RuntimeError: No backend type associated with device type cpu

@mori360 the error is RuntimeError: No backend type associated with device type cpu; it's different from buffer_device. I remember you mentioned that initializing the CPU backend would resolve the issue?

@mori360
Contributor

mori360 commented Dec 2, 2024

RuntimeError: No backend type associated with device type cpu

@mori360 the error is RuntimeError: No backend type associated with device type cpu; it's different from buffer_device. I remember you mentioned that initializing the CPU backend would resolve the issue?

We used "cuda:nccl,cpu:gloo" to solve the backend issue on cpu offloading.
https://github.com/pytorch/torchtitan/blob/3e3909a0c1d4c451a44cda6c32139dbda69961c0/torchtitan/utils.py#L196-L198

@ebsmothers
Contributor

@mori360 can you share more info on why adding a gloo backend for CPU solves the issue? Is there some 1:1 mapping between FSDP's process group on CUDA and a CPU process group when CPU offloading is enabled? My assumption was that any CPU offloading would offload to a single CPU process, but maybe that was incorrect?

@mori360
Contributor

mori360 commented Dec 2, 2024

@mori360 can you share more info on why adding a gloo backend for CPU solves the issue? Is there some 1:1 mapping between FSDP's process group on CUDA and a CPU process group when CPU offloading is enabled? My assumption was that any CPU offloading would offload to a single CPU process, but maybe that was incorrect?

Yeah, CPU offloading would offload to a single CPU process; however, gradient clipping needs communication in _NormPartial, which requires adding a gloo backend for CPU.
@wz337
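A minimal standalone sketch (not from the thread) of why the two-backend init helps, assuming a node with at least 2 GPUs and a launch via torchrun --nproc_per_node=2: collectives on CPU tensors get routed to gloo, while CUDA tensors still go through NCCL.

import torch
import torch.distributed as dist

# With only backend="nccl", the CPU all-reduce below fails with
# "No backend type associated with device type cpu".
dist.init_process_group(backend="cuda:nccl,cpu:gloo")
torch.cuda.set_device(dist.get_rank())

cpu_tensor = torch.ones(1)                  # e.g. a grad norm computed on offloaded grads
dist.all_reduce(cpu_tensor)                 # routed to gloo

cuda_tensor = torch.ones(1, device="cuda")  # regular on-device communication
dist.all_reduce(cuda_tensor)                # routed to NCCL

dist.destroy_process_group()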

@ebsmothers
Contributor

Thanks @mori360 and @weifengpy for the explanation here. I guess @RdoubleA was right from the outset (sorry for derailing things). I just opened #2108 for this.

@joecummings
Contributor

@ebsmothers Can this issue be closed now?

@ebsmothers
Contributor

Yeah we can close this
