Port of μP to JAX (Haiku specifically) #10528
davisyoshida
asked this question in
Show and tell
μP is a method for reparameterizing NNs so that hyperparameter optima stay fixed as you scale them up. I wrote a port to JAX here.
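To make the reparameterization concrete, here is a minimal, self-contained sketch of the μP scaling rules for a hidden weight matrix under Adam (this is illustrative only, not the API of the port): initialization variance scales like 1/fan_in, and the per-layer learning rate also scales like 1/fan_in, which is why a learning rate tuned at a small base width transfers to larger widths.

```python
import math

def mup_hidden_scales(base_width, width, base_lr=1e-3):
    """Return (init_std, lr) for a hidden weight matrix at `width`,
    given hyperparameters tuned at `base_width`.

    Illustrative μP rules for Adam:
      - init std ~ 1/sqrt(fan_in)  (variance ~ 1/fan_in)
      - learning rate ~ 1/fan_in
    """
    ratio = width / base_width
    init_std = 1.0 / math.sqrt(width)  # variance scales like 1/fan_in
    lr = base_lr / ratio               # Adam LR scales like 1/fan_in
    return init_std, lr

# Doubling the width halves the hidden-layer Adam learning rate,
# so the *tuned* base_lr stays optimal across widths.
std, lr = mup_hidden_scales(base_width=128, width=256)
```

The names `mup_hidden_scales` and `base_lr` are hypothetical; the actual port organizes these multipliers differently.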
As an example, I trained a series of transformers on PTB, with and without μP. With μP, the optimal learning rate stays fixed as the models are scaled up.
(In the paper, the authors do this experiment but for LMs with 6.7B parameters).
This is specifically for Haiku, but if people are interested in making versions for Flax etc., it might be helpful to take a look at this. I had to change the design substantially from the authors' original repo, since a lot of things that are easy in Torch (via mutability) aren't in JAX, and vice versa. I'm also interested in any suggestions on improving the design or increasing usability.
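One way to picture the functional design constraint: where the Torch repo can attach multipliers by mutating modules and optimizer state in place, a JAX version has to thread per-parameter multipliers through explicitly, e.g. as a pytree with the same structure as the parameters. A hypothetical sketch (plain nested dicts standing in for pytrees; `scale_updates` and `lr_mults` are illustrative names, not from the port):

```python
def scale_updates(updates, lr_mults):
    """Multiply each update leaf by its μP learning-rate multiplier.

    Both arguments are pytrees (here: plain nested dicts) with the
    same structure, mirroring how jax.tree_util maps over params.
    """
    return {
        k: scale_updates(v, lr_mults[k]) if isinstance(v, dict) else v * lr_mults[k]
        for k, v in updates.items()
    }

# Per-parameter multipliers: hidden weights get LR scaled down with
# width, biases are left alone (a common μP convention).
updates = {"dense": {"w": 0.5, "b": 0.1}}
lr_mults = {"dense": {"w": 0.5, "b": 1.0}}
scaled = scale_updates(updates, lr_mults)
```

In real JAX code the recursion would be a single `jax.tree_util.tree_map` over the two pytrees; the point is that the multipliers travel as data alongside the params rather than living inside mutable module state.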