Model global weights are suddenly zero within local training code #3132
Replies: 4 comments 2 replies
-
Update:
-
@virginiafdez Thanks for your interest and question! Could you switch to NVFlare 2.5 or the main branch and try again?
-
Hello! I've managed to continue debugging the application. The weights are actually not zero, but they are still causing my training to produce NaNs, which does not happen when SVT is not used. I've narrowed it down to the impact the privacy filters have on the normalisation layers (the NaNs only appear when eval() is used). As for switching to NVFlare 2.5: I am unable to, as the production system we are using is on 2.2.5 and cannot currently be updated.
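The eval()-only failure described above can be reproduced with the normalisation formula alone. A minimal pure-Python sketch (the values and `eps` are illustrative, not taken from the actual model): in eval() mode the stored running statistics are used directly, so if a filter perturbs the running variance below zero, the square root is undefined and frameworks return NaN, while train() mode recomputes batch statistics and is unaffected.

```python
import math

def batchnorm_eval(x, running_mean, running_var, eps=1e-5):
    """Eval-mode batch norm for a single value: uses the stored running
    statistics instead of batch statistics."""
    denom_sq = running_var + eps
    if denom_sq < 0:
        # sqrt of a negative variance is undefined -> NaN in frameworks
        return float("nan")
    return (x - running_mean) / math.sqrt(denom_sq)

# With clean statistics the output is well behaved:
clean = batchnorm_eval(1.0, running_mean=0.0, running_var=1.0)

# If a privacy filter pushes running_var below zero, eval() yields NaN:
noisy = batchnorm_eval(1.0, running_mean=0.0, running_var=-0.3)
print(clean, noisy)
```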
-
More updates on this issue, and a summary of what happens in each scenario. I've been varying the following:
Regardless of whether SVT is applied, with all the weights (the ones coming out of SVT) loaded, the training results are fine. If I evaluate in train() mode, the results are fine as well. The problem only appears when I use SVT and validate in eval() mode, as I should: that is when the NaNs show up.
Since the running means and variances are not learnable, does it make sense for them to go through the filters at all? I assume I could write a customised filter excluding specific keys and try it out. I was also curious about the minimal-impact parameters you can set for SVT to avoid modifying the weights too much (I know the point of SVT is precisely to modify the weights, but this is useful for testing when the filter has to be present and only its parameters can be tweaked). Beyond that, the issue is clarified, as this seems to be the root cause of the problem, unless anything on your side points in a different direction.
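As a starting point for the customised filter mentioned above, here is a minimal sketch. This is not the actual NVFlare filter API; `perturb` stands in for whatever SVT-style transformation the real filter applies, and the excluded key suffixes assume PyTorch-style state-dict naming.

```python
# Suffixes of keys that hold non-learnable normalisation statistics in
# PyTorch-style state dicts; these are tracked rather than trained, so
# perturbing them can break eval() mode.
EXCLUDED_SUFFIXES = ("running_mean", "running_var", "num_batches_tracked")

def filter_weights(weights, perturb):
    """Apply `perturb` to every entry except normalisation statistics.

    `weights` maps parameter names to values; `perturb` is whatever
    privacy transformation (e.g. SVT) should be applied to the rest.
    """
    return {
        name: value if name.endswith(EXCLUDED_SUFFIXES) else perturb(value)
        for name, value in weights.items()
    }

weights = {
    "conv1.weight": [0.5, -0.2],
    "bn1.running_mean": [0.1],
    "bn1.running_var": [1.0],
}
# Toy perturbation standing in for SVT: shift every value slightly.
filtered = filter_weights(weights, perturb=lambda v: [x + 0.01 for x in v])
print(filtered)
```

In NVFlare this logic would live inside a result filter configured for the job, wrapping the existing SVT transformation so the running statistics pass through untouched.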
-
Python version (python3 -V): 3.8
NVFlare version (python3 -m pip list | grep "nvflare"): 2.2.5
NVFlare branch (if running examples, please use the branch that corresponds to the NVFlare version, git branch): No response
Operating system: Ubuntu 22.04
Have you successfully run any of the following examples?
Please describe your question
I am relatively new to NVFlare, and I am running training code that I need to tidy up. While running it on the NVFlare simulator, I am seeing that the weights initially passed from the server to each site are set to zero.
The model pre-loads weights when it is created. In the scatter-and-gather workflow, I print said weights by accessing the ._global_weights attribute of the fl_ctx object, and they are fine (not zero).
But then, right after, the broadcast_and_wait function is called to launch the training task. If I go to the execute method of that task, which I imagine is the first thing that gets executed, and print the fl_ctx weights, they are zero.
Is there anything I should be aware of that might be causing this? Any function that is potentially zeroing or losing the weights that are correctly loaded in the server and passed on to the task?
Thanks a lot for any input!!
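One way to narrow down where the weights are lost is to log a small per-layer summary at each hand-off point (server side before broadcast, and inside the task's execute method) and compare. A minimal, framework-agnostic sketch (the helper name is hypothetical):

```python
def summarize_weights(weights):
    """Return a per-layer summary to spot all-zero or degenerate tensors."""
    summary = {}
    for name, values in weights.items():
        flat = list(values)
        summary[name] = {
            "n": len(flat),
            "min": min(flat),
            "max": max(flat),
            "all_zero": all(v == 0 for v in flat),
        }
    return summary

# Comparing the same dict on the server and in the executor quickly
# shows whether the zeros appear before or after transmission.
server_side = {"fc.weight": [0.3, -0.1, 0.7]}
client_side = {"fc.weight": [0.0, 0.0, 0.0]}
print(summarize_weights(server_side))
print(summarize_weights(client_side))
```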