More flexibility in NTK recording #110

Open
KonstiNik opened this issue Feb 23, 2024 · 2 comments

@KonstiNik
Member

By default, the current implementation of the `JaxRecorder` uses the `ntk_apply` function defined by the model.
For some NTK computations (such as the loss NTK or some Fisher-based NTK approximations), the function whose NTK is computed is not the model's apply function.
It would therefore make sense to let the user set this function manually for each recorder, since one may want to record several versions of the NTK during a single training run.

The suggestion is to move the `ntk_apply` function out of the model into a separate class.
This class handles all of the NTK computation and is constructed from an apply function (e.g., the model's).
We would need one NTK computation class for each model.
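A minimal JAX sketch of the proposed split, under stated assumptions: the class name `NTKCalculator`, the `apply_fn(params, x)` signature, and the toy linear model are hypothetical and not the project's API. The NTK computed here is the empirical one, J J^T, where J is the Jacobian of the flattened batch outputs with respect to the flattened parameters.

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree


class NTKCalculator:
    """Hypothetical stand-alone NTK computation, decoupled from the model.

    It only needs an apply function with signature apply_fn(params, x) -> outputs.
    """

    def __init__(self, apply_fn):
        self.apply_fn = apply_fn

    def __call__(self, params, x):
        # Flatten the parameter pytree so the Jacobian is a plain 2D matrix.
        flat_params, unravel = ravel_pytree(params)

        def flat_outputs(flat):
            # All outputs for the batch, flattened to shape (n_samples * n_outputs,).
            return jnp.ravel(self.apply_fn(unravel(flat), x))

        jac = jax.jacrev(flat_outputs)(flat_params)  # (n * k, n_params)
        return jac @ jac.T  # empirical NTK, (n * k, n * k)


# Toy linear model, used only for illustration.
def linear_apply(params, x):
    return x @ params["w"] + params["b"]


params = {"w": jnp.ones((3, 2)), "b": jnp.zeros(2)}
x = jnp.arange(12.0).reshape(4, 3)
ntk_fn = NTKCalculator(linear_apply)
ntk = ntk_fn(params, x)
print(ntk.shape)  # (8, 8)
```

Because the calculator only sees a callable, the same code works for any function of the parameters, which is what would let each recorder be given its own NTK definition.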

@KonstiNik KonstiNik changed the title "NTK Recorder more flexible" to "More flexibility in NTK recording" on Feb 23, 2024
@SamTov
Member

SamTov commented Mar 5, 2024

Do you think a single NTK function should handle all of these different versions? Since the "loss NTK" and the Fisher are different things, it would make more sense to have a separate loss-NTK or Fisher calculator rather than overloading the single NTK computation with a bunch of additional arguments and options. They can share a backend, though.
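One possible reading of "separate calculators, shared backend", as a hypothetical JAX sketch: `empirical_ntk`, `ModelNTK`, `LossNTK`, and the squared-error loss are illustrative names, not existing code, and a Fisher calculator is not shown. The backend only computes the Gram matrix of a Jacobian; each calculator decides which function of the parameters gets differentiated.

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree


def empirical_ntk(fn, params):
    """Hypothetical shared backend: Gram matrix of the Jacobian of fn(params)."""
    flat_params, unravel = ravel_pytree(params)
    jac = jax.jacrev(lambda flat: jnp.ravel(fn(unravel(flat))))(flat_params)
    return jac @ jac.T


class ModelNTK:
    """Hypothetical calculator for the NTK of the model outputs."""

    def __init__(self, apply_fn):
        self.apply_fn = apply_fn

    def __call__(self, params, x):
        return empirical_ntk(lambda p: self.apply_fn(p, x), params)


class LossNTK:
    """Hypothetical calculator for the NTK of the per-sample loss."""

    def __init__(self, apply_fn, loss_fn):
        self.apply_fn = apply_fn
        self.loss_fn = loss_fn  # loss_fn(outputs, targets) -> per-sample losses

    def __call__(self, params, x, targets):
        return empirical_ntk(
            lambda p: self.loss_fn(self.apply_fn(p, x), targets), params
        )


# Toy model and loss, used only for illustration.
def linear_apply(params, x):
    return x @ params["w"] + params["b"]


def squared_error(outputs, targets):
    return 0.5 * jnp.sum((outputs - targets) ** 2, axis=-1)


params = {"w": jnp.ones((3, 2)), "b": jnp.zeros(2)}
x = jnp.arange(12.0).reshape(4, 3)
y = jnp.zeros((4, 2))
model_ntk = ModelNTK(linear_apply)(params, x)  # one row/column per output entry
loss_ntk = LossNTK(linear_apply, squared_error)(params, x, y)  # one per sample
```

The two calculators keep their own signatures (the loss NTK needs targets), so neither has to carry the other's options.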

@KonstiNik
Member Author

I agree. My point is rather that the NTK computation is currently part of the model, so if you want to record the NTK of something other than the exact model function, this is not possible at the moment.
One example: you want to record the NTK of the model with a softmax output, but the loss is softmax-cross-entropy, which already includes the softmax (so the model itself outputs logits). Another example is recording the NTK of the model composed with the loss function.
Moving the NTK computation out of the model into a separate class would also affect the recorders, since the NTK computation would then be passed to them as a callable.
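A hypothetical JAX sketch of a recorder that receives its NTK computation as a callable, covering both examples above: the NTK of the softmax outputs while the model itself returns logits, and the NTK of model + per-sample softmax-cross-entropy. `NTKRecorder`, `empirical_ntk`, and recording only the NTK trace are assumptions made for brevity, not the project's recorder API.

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree


def empirical_ntk(fn, params):
    """Gram matrix of the Jacobian of fn(params) w.r.t. the flattened params."""
    flat_params, unravel = ravel_pytree(params)
    jac = jax.jacrev(lambda flat: jnp.ravel(fn(unravel(flat))))(flat_params)
    return jac @ jac.T


class NTKRecorder:
    """Hypothetical recorder that is handed its NTK computation as a callable."""

    def __init__(self, ntk_fn):
        self.ntk_fn = ntk_fn  # ntk_fn(params) -> NTK matrix
        self.traces = []

    def record(self, params):
        # Store only the trace to keep the sketch short.
        self.traces.append(float(jnp.trace(self.ntk_fn(params))))


def logits_apply(params, x):
    # The model returns logits; the softmax lives in the loss.
    return x @ params["w"] + params["b"]


x = jax.random.normal(jax.random.PRNGKey(0), (5, 3))
y = jax.nn.one_hot(jnp.array([0, 1, 2, 0, 1]), 3)
params = {"w": 0.1 * jnp.ones((3, 3)), "b": jnp.zeros(3)}

# 1) NTK of the softmax outputs, although training uses softmax-cross-entropy on logits.
softmax_recorder = NTKRecorder(
    lambda p: empirical_ntk(lambda q: jax.nn.softmax(logits_apply(q, x)), p)
)


# 2) NTK of model + loss (per-sample softmax-cross-entropy).
def per_sample_ce(p):
    return -jnp.sum(y * jax.nn.log_softmax(logits_apply(p, x)), axis=-1)


loss_recorder = NTKRecorder(lambda p: empirical_ntk(per_sample_ce, p))

# Both recorders observe the same training state, each with its own NTK definition.
for recorder in (softmax_recorder, loss_recorder):
    recorder.record(params)
```

Since the recorder never touches the model, several recorders with different NTK callables can run side by side during one training.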
