
Pytorch Dataset Class that Reads From Zarr Archive #24

Closed
sadamov opened this issue May 3, 2024 · 1 comment · Fixed by #66

sadamov (Collaborator) commented May 3, 2024

Summary
Since the weather community, and especially ECMWF, has moved towards a single zarr archive that contains all the data in the state (domain) and another that contains all the data in the boundary, this project should follow the same approach. Zarr has many advantages: parallel computing with dask, lazy loading with xarray, and efficient compression with different algorithms and chunking.

Specifics
There are three main data-processing steps in the current pipeline. This is a proposal for how the work would be split between them:

  • Data-Preprocessing
    • Usually some format like GRIB2 is converted into xarray->zarr. This step is out of scope
    • Pre-computation of forcings, static and grid features
    • Computation of normalization constants (stats) and inverse variances (a minimal sketch follows this list)
    • Generating the boundary mask
  • Pytorch Dataset [on CPU]:
    • Reshaping of 3D variables into stacked 2D variables
    • Split data into train/val/test based on some indicator (e.g. time)
    • Generate the windowed indices for forcing and boundary
  • Pytorch Model [on GPU]
    • Normalization of the data
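
A minimal sketch of the stats computation mentioned above, assuming a hypothetical state.zarr archive with time and grid dimensions (all paths and variable names are illustrative, not part of the proposal):

```python
import xarray as xr

# Hypothetical archive and dimension names, for illustration only.
state = xr.open_zarr("state.zarr")  # lazily opens the archive

# Per-variable mean and standard deviation over time and grid,
# computed lazily and in parallel by dask.
mean = state.mean(dim=["time", "grid"])
std = state.std(dim=["time", "grid"])

# Inverse variances, e.g. as per-variable loss weights.
inv_var = 1.0 / std**2

# Persist the stats as a small zarr store next to the data.
stats = xr.merge([
    mean.rename({v: f"{v}_mean" for v in mean.data_vars}),
    std.rename({v: f"{v}_std" for v in std.data_vars}),
    inv_var.rename({v: f"{v}_inv_var" for v in inv_var.data_vars}),
])
stats.to_zarr("normalization_stats.zarr", mode="w")
```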

Interfaces

  • Data-Preprocessing
    • Input: out of scope
    • Output: one or multiple zarr files
  • Pytorch Dataset [on CPU]:
    • Input: one or multiple zarr files
    • Output: 5 pytorch tensors with the following dimensions (see the sketch after this list):
      init_states: (2, N_grid, features_dim)
      target_states: (n_lead_times, N_grid, features_dim)
      forcing: (n_lead_times, N_grid, forcing_windowed_dim) # window_steps * n_forcing
      boundary: (n_lead_times, N_grid, boundary_windowed_dim) # window_steps * n_boundary
      batch_times: (2 + n_lead_times)[str]
  • Pytorch Model [on GPU]
    • Input: 5 pytorch tensors (batched with Pytorch DataLoader)
    • Output: out of scope
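
To make the tensor contract above concrete, here is a minimal sketch of such a dataset. The zarr layouts (time/grid/feature dimensions), the variable names ("state", "forcing", "boundary") and the exact window alignment are assumptions for illustration, not fixed by this proposal:

```python
import torch
import xarray as xr


class WeatherDataset(torch.utils.data.Dataset):
    """Sketch of a zarr-backed dataset producing the five tensors above."""

    def __init__(self, state_zarr, boundary_zarr, n_lead_times, window_steps):
        self.state = xr.open_zarr(state_zarr)        # dims: (time, grid, feature)
        self.boundary = xr.open_zarr(boundary_zarr)  # dims: (time, grid, feature)
        self.n_lead_times = n_lead_times
        self.window_steps = window_steps
        self.sample_len = 2 + n_lead_times  # 2 init states + n_lead_times targets

    def __len__(self):
        # Leave room for the targets and the trailing forcing/boundary window.
        return self.state.sizes["time"] - self.n_lead_times - self.window_steps

    def _window(self, da, idx):
        # One window of window_steps consecutive steps per lead time, stacked
        # along the feature dim: (n_lead_times, N_grid, window_steps * n_vars)
        windows = []
        for t in range(self.n_lead_times):
            w = torch.tensor(
                da.isel(time=slice(idx + 2 + t, idx + 2 + t + self.window_steps)).values
            )  # (window_steps, N_grid, n_vars)
            windows.append(torch.cat(list(w), dim=-1))
        return torch.stack(windows)

    def __getitem__(self, idx):
        # Lazy slice; values are only read from disk on .values access.
        sample = self.state["state"].isel(time=slice(idx, idx + self.sample_len))
        states = torch.tensor(sample.values)  # (2 + n_lead_times, N_grid, features_dim)

        init_states = states[:2]    # (2, N_grid, features_dim)
        target_states = states[2:]  # (n_lead_times, N_grid, features_dim)
        forcing = self._window(self.state["forcing"], idx)
        boundary = self._window(self.boundary["boundary"], idx)
        batch_times = [str(t) for t in sample["time"].values]  # 2 + n_lead_times

        return init_states, target_states, forcing, boundary, batch_times
```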

Implementation
One example of such a pytorch dataset and dataloader can be found here for inspiration: https://github.com/MeteoSwiss/neural-lam/blob/main/neural_lam/weather_dataset.py It needs quite a bit of work, however.
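
For illustration, wrapping the dataset sketched above in a standard DataLoader could look like this (paths and hyperparameters are placeholders):

```python
from torch.utils.data import DataLoader

# Hypothetical usage; the train split would come from the time-based
# train/val/test indicator mentioned in the Specifics section.
train_set = WeatherDataset("state_train.zarr", "boundary_train.zarr",
                           n_lead_times=19, window_steps=3)
train_loader = DataLoader(train_set, batch_size=4, shuffle=True,
                          num_workers=8)  # zarr chunks are read in the workers

init_states, target_states, forcing, boundary, batch_times = next(iter(train_loader))
# init_states: (4, 2, N_grid, features_dim) -- the DataLoader prepends the batch dim
```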

[draw.io diagram: dataset]

sadamov added the enhancement (New feature or request) label May 3, 2024
sadamov self-assigned this May 3, 2024
joeloskarsson (Collaborator) commented

This could very nicely use https://lightning.ai/docs/pytorch/stable/common/lightning_module.html#on-after-batch-transfer to normalize the data once it is on the GPU. Using the hook makes sure that you never forget about it (all batches on the GPU are normalized).
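
A minimal sketch of that hook, assuming the five-tensor batch from above and mean/std stats computed in preprocessing (the module and buffer names are placeholders):

```python
import pytorch_lightning as pl


class ForecastModule(pl.LightningModule):
    """Sketch of GPU-side normalization via on_after_batch_transfer."""

    def __init__(self, data_mean, data_std):
        super().__init__()
        # Registered as buffers so the stats move to the GPU with the module.
        self.register_buffer("data_mean", data_mean)
        self.register_buffer("data_std", data_std)

    def on_after_batch_transfer(self, batch, dataloader_idx):
        # Runs after the batch has been moved to the device, so every
        # batch the model sees is guaranteed to be normalized.
        init_states, target_states, forcing, boundary, batch_times = batch
        init_states = (init_states - self.data_mean) / self.data_std
        target_states = (target_states - self.data_mean) / self.data_std
        return init_states, target_states, forcing, boundary, batch_times
```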
