Model init with HuggingFace model #743
cc: @weifengpy @mori360
👋 Gentle bump on this: mainly to see if there is a workaround for the above issue 👀
It depends on where the peak memory occurs. If it's during fully_shard, the full_state_dict is sharded into a local_state_dict while the full copy is still alive, so memory goes up (full_state_dict + local_state_dict > full_state_dict).
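To make the inequality above concrete, here is a rough per-rank estimate that counts parameter memory only. It ignores gradients, optimizer state, activations, and padding. The helper name and the numbers (1B parameters, fp32, 2 GPUs) are hypothetical round figures, not measurements from this thread.

```python
def per_rank_param_bytes(num_params: int, bytes_per_param: int,
                         world_size: int, load_full_first: bool) -> int:
    """Rough peak parameter memory on one rank (hypothetical helper).

    load_full_first=True models materializing the full state dict on every
    rank and then sharding it, so the full copy and the local shard coexist.
    """
    full = num_params * bytes_per_param
    local = -(-full // world_size)  # ceiling division: ~1/world_size of the bytes
    return full + local if load_full_first else local


for load_full_first in (True, False):
    peak = per_rank_param_bytes(1_000_000_000, 4, 2, load_full_first)
    print(f"load_full_first={load_full_first}: {peak / 1e9:.1f} GB")
```

With these assumed numbers, the full-then-shard path peaks at about 6.0 GB per rank. If each rank only ever holds its own shard, the peak is about 2.0 GB.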
Could you give more details on the safetensors loading, since I could reproduce the huge memory cost?
I see. Ideally I am looking for an approach which allows me to load the sharded models on each GPU without loading the
I downloaded the model.safetensors for the
I am trying to mimic TorchTitan's implementation but with a HuggingFace model
This is a simple repro of my implementation which can be run using:
The flow is very similar to TorchTitan's, except that TorchTitan makes an explicit call to re-initialise the weights after materialising them. Since I wish to load weights from a pretrained HF model, it's a bit challenging. The above code throws an error where I call
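The flow described above can be sketched in a single process with plain PyTorch. First build the module under the meta device, then allocate storage with to_empty(), then copy the pretrained values in at the point where TorchTitan would call its weight initialiser. Some parts are stand-ins. The tiny nn.Sequential is a hypothetical substitute for the HF model, and pretrained_sd stands in for weights read from the checkpoint. fully_shard is left out so the sketch runs without a process group.

```python
import torch
import torch.nn as nn


def build_model() -> nn.Module:
    # Hypothetical stand-in for the HF model; the same steps apply to any nn.Module.
    return nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 4))


# Stand-in for the pretrained weights (e.g. read from model.safetensors).
torch.manual_seed(0)
pretrained_sd = build_model().state_dict()

# 1. Construct on the meta device: shapes and dtypes only, no real storage.
with torch.device("meta"):
    model = build_model()

# 2. Allocate (uninitialized) storage. In a TorchTitan-style flow, fully_shard
#    is applied before this step so each rank only allocates its own shards.
model.to_empty(device="cpu")

# 3. Where TorchTitan would re-initialise the weights, copy pretrained values instead.
model.load_state_dict(pretrained_sd)
```

With FSDP2 applied, step 3 is where the mismatch appears: the sharded parameters are DTensors, while a plain pretrained state dict holds regular tensors.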
However, @fegin, please correct me if I'm wrong. Also, should we update model.init_weight() in torchtitan, in the path from model.init_weight() to checkpoint.load(), so that weights are initialized param by param?
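Param-by-param loading could look roughly like the sketch below. It assumes the checkpoint can be read one tensor at a time; safetensors files do allow key-by-key reads. lazy_checkpoint and load_param_by_param are hypothetical helpers, and the random tensors stand in for real weights. The model is left unsharded so the example stays single-process.

```python
import torch
import torch.nn as nn


def lazy_checkpoint(shapes, seed=0):
    # Hypothetical reader that yields one full tensor at a time, so the whole
    # checkpoint never has to sit in memory at once.
    gen = torch.Generator().manual_seed(seed)
    for name, shape in shapes.items():
        yield name, torch.randn(shape, generator=gen)


def load_param_by_param(model, checkpoint):
    """Copy each tensor into its already-materialized parameter, one at a time."""
    params = dict(model.named_parameters())
    loaded = set()
    with torch.no_grad():
        for name, tensor in checkpoint:
            params[name].copy_(tensor)  # shapes and dtypes must match
            loaded.add(name)
            del tensor  # drop the reference before the next tensor is read
    missing = set(params) - loaded
    if missing:
        raise KeyError(f"checkpoint is missing: {sorted(missing)}")
    return model


with torch.device("meta"):
    model = nn.Linear(4, 2)
model.to_empty(device="cpu")

shapes = {name: p.shape for name, p in model.named_parameters()}
load_param_by_param(model, lazy_checkpoint(shapes))
```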
Yes, @mori360, since you have implemented this feature, the OOM should be avoidable with
I am writing a simple script to run FSDP2 (fully_shard) on the pythia-1b model available on HuggingFace. I am currently running the model on 1 node with 2 devices, following the meta-device initialisation from the FSDP2 docs. However, I think something is wrong with my implementation, since the peak memory usage with FSDP is the same as without FSDP (~1GB). Further, I get an OOM on my device when I try the pythia-2.8b model. Following is a snippet of how I am initialising the model on a meta device using HuggingFace APIs:

This is not very straightforward, since the shards expect DTensors when the weights are being loaded via load_checkpoint_and_dispatch. I am looking for suggestions on a good way to make FSDP2 work with HuggingFace models. I don't think accelerate supports FSDP2 yet.