Adding JAX implementation to the Dive into Deep Learning book #12246
astonzhang asked this question in Show and tell
We are currently adding a JAX implementation to our Dive into Deep Learning (D2L.ai) book and would like to seek early feedback from the JAX community.
For context, D2L aims to make deep learning approachable by teaching readers both the concepts and the code. The entire book is drafted in Jupyter notebooks, seamlessly integrating exposition, figures, math, and interactive examples with self-contained code. We previously implemented the book in other frameworks, such as PyTorch, MXNet, and TensorFlow. To keep the book instruction-friendly, we adopted an object-oriented design so that when a new section only introduces a new model or dataset, we only need to re-implement (e.g., subclass) the `Module`/`DataModule` class. This way, readers can focus on what changes in each section, the model (when new models are described) or the dataset (when new tasks are described), without going through the implementation of the entire training pipeline again and again.

As of now, we have added the JAX implementation for Chapter 3 on the jax branch of our repo, which can be previewed online:
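To illustrate the object-oriented design described above, here is a minimal, hypothetical sketch (the class and method names are illustrative, not D2L's actual API): a generic `Trainer` is written once against a `Module` interface, and each new section only subclasses `Module` to introduce its model.

```python
# Hypothetical sketch of the design: a shared training pipeline (Trainer)
# works with any Module subclass, so a new section only defines a new model.
# Names here are illustrative, not D2L's actual API.

class Module:
    """Base class: the interface the Trainer relies on."""
    def step(self, batch, lr):
        raise NotImplementedError


class Trainer:
    """Generic training loop, written once and reused across sections."""
    def __init__(self, max_epochs):
        self.max_epochs = max_epochs

    def fit(self, model, data):
        for _ in range(self.max_epochs):
            for batch in data:
                model.step(batch, lr=0.1)
        return model


class ScalarRegression(Module):
    """A new section only subclasses Module to describe its model."""
    def __init__(self):
        self.w = 0.0  # single learnable parameter

    def step(self, batch, lr):
        x, y = batch
        grad = 2 * (self.w * x - y) * x  # d/dw of the squared error
        self.w -= lr * grad


data = [(1.0, 2.0), (2.0, 4.0)] * 50  # toy dataset with y = 2x
model = Trainer(max_epochs=5).fit(ScalarRegression(), data)
print(round(model.w, 2))  # converges to the true slope 2.0
```

The point is that `Trainer.fit` never changes across sections; only the `Module` subclass does.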
If you spot anything related to JAX that can be improved, please let me and @AnirudhDagar know. In addition, we plan to use PyTorch data loaders, which unfortunately introduces an extra library dependency; if JAX plans to provide its own data loaders, please share your roadmap.
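For readers wondering how PyTorch data loaders fit into a JAX pipeline: since JAX consumes NumPy arrays directly, a common pattern is a collate function that stacks samples into NumPy arrays instead of PyTorch tensors. The sketch below shows such a function standalone (it is the kind of thing one would pass as `collate_fn` to `torch.utils.data.DataLoader`, but the function itself is framework-agnostic, so the example stays dependency-free):

```python
# Hypothetical sketch: a collate function that turns a list of samples
# into NumPy arrays, which JAX can consume directly. In practice this
# would be passed as collate_fn to torch.utils.data.DataLoader.
import numpy as np

def numpy_collate(batch):
    """Stack a list of samples (possibly tuples of arrays) into NumPy arrays."""
    if isinstance(batch[0], (tuple, list)):
        # Transpose [(x1, y1), (x2, y2), ...] into ([x1, x2, ...], [y1, y2, ...])
        return tuple(numpy_collate(samples) for samples in zip(*batch))
    return np.stack(batch)

samples = [(np.array([1.0, 2.0]), 0), (np.array([3.0, 4.0]), 1)]
X, y = numpy_collate(samples)
print(X.shape, y.shape)  # (2, 2) (2,)
```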
Thanks!