-
Notifications
You must be signed in to change notification settings - Fork 51
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Remove batch-static tensor from dataset class and models #13
Conversation
@@ -115,7 +112,6 @@ def predict_step( | |||
( | |||
prev_state, | |||
prev_prev_state, | |||
batch_static_features, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Note that the batch-static features are now put as the first feature dimension in forcing. Earlier they were stacked right on top of forcing in this tensor. This results in no change to how grid_features
looks like for a sample. Importantly, this means that models trained before this PR can be loaded and works without any problems.
@sadamov Hope it's ok that I put you to review PRs like this :) I think it's valuable to get a second pair of eyes to look at the changes, and also good for you to get an update on small things I am changing. The changes to the MEPS Dataset class here are not very important, this is really motivated by moving away from things being too specific for that data. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I agree with both the general direction of and the explicit changes to the codebase.
- In general, making the dataloader more flexible and reducing the complexity of input feature types, allows for easier onboarding of new collaborators.
- I tested the explicit changes with the meps_example dataset and training is successful without batch_static_features.
Thanks for taking a look! I just realized I forgot to change |
Squashed commit of the following: commit b0050b9 Author: Joel Oskarsson <[email protected]> Date: Mon Mar 18 10:56:45 2024 +0100 Remove batch-static tensor from dataset class and models (mllam#13) * Bake the batch-static features into the normal forcing in the MEPS Dataset class. * Change the Dataset class to only return 3 tensors per sample (init, target, forcing). * Remove the batch-static tensor from being extracted from the batch and passed around in the graph-based models. This while making sure that input dimensions line up so older checkpoints can still be loaded correctly. commit 0669ff4 Author: Joel Oskarsson <[email protected]> Date: Thu Feb 29 11:50:27 2024 +0100 Re-define RMSE metric to take sqrt after sample averaging (mllam#10)
The batch-static tensor contained forcing that differed between initialization times, but stayed static for all lead times of a forecast. For the MEPS data we used this for the land-water-mask, as this could be different throughout the year, but we could not produce separate values per lead time (as all other forcing).
This PR removes the batch-static features as an explicit extra input. The motivation is:
None
.This PR changes: