Trying to encapsulate jax code into a class #14347
Have you tried the Equinox library (https://github.com/patrick-kidger/equinox)? It provides a convenient base module from which you can build classes that are correctly structured for JAX.
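Your code is not working as expected because you are marking `self` as static when JIT-compiling the `update` method. This is incompatible with in-place mutation of a class instance, and you are mutating your class instance in the loop when you set or modify the `params` attribute. JIT looks up compiled code using the hash and equality of its static arguments, and mutating an attribute does not change them by default. As a result, later calls silently reuse the version compiled with the old attribute values. See https://jax.readthedocs.io/en/latest/faq.html#how-to-use-jit-with-methods for a discussion of this problem and some potential solutions.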
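The sketch below illustrates both the pitfall and one fix. The fix registers the class as a pytree, which is one of the approaches the linked FAQ discusses. The class names `StaticModel` and `LinearModel`, the `scale` attribute, and the `{"w", "b"}` layout of `params` are all hypothetical stand-ins for the code in the question. The first part shows a static `self` returning a stale result after mutation. The second part passes `self` as a traced pytree argument and returns a new instance from `update`.

```python
from functools import partial

import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class, tree_map


# --- Pitfall: `self` marked static while the instance is mutated ----------
class StaticModel:
    def __init__(self, scale):
        self.scale = scale  # hypothetical attribute that is mutated later

    @partial(jax.jit, static_argnums=0)
    def apply(self, x):
        # `self.scale` is baked into the compiled function as a constant.
        return self.scale * x


# --- One fix: register the class as a pytree so `self` is traced ----------
@register_pytree_node_class
class LinearModel:
    def __init__(self, params):
        self.params = params  # hypothetical layout: {"w": array, "b": array}

    def tree_flatten(self):
        # Children are traced by jit/grad; aux data (None here) is static.
        return (self.params,), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)

    def loss(self, x, y):
        pred = self.params["w"] * x + self.params["b"]
        return jnp.mean((pred - y) ** 2)

    @jax.jit
    def update(self, x, y, lr):
        # Gradients w.r.t. the pytree `self` come back as a LinearModel.
        grads = jax.grad(LinearModel.loss)(self, x, y)
        new_params = tree_map(lambda p, g: p - lr * g, self.params, grads.params)
        # Return a new instance instead of mutating self.params in place.
        return LinearModel(new_params)


m = StaticModel(2.0)
first = m.apply(3.0)  # compiles with scale = 2.0
m.scale = 10.0        # mutation leaves hash/eq of `m` unchanged...
stale = m.apply(3.0)  # ...so the cached compilation (scale = 2.0) is reused

x = jnp.linspace(-1.0, 1.0, 32)
y = 2.0 * x + 1.0  # toy target: w = 2, b = 1
model = LinearModel({"w": jnp.zeros(()), "b": jnp.zeros(())})
for _ in range(200):
    model = model.update(x, y, 0.1)
```

In the pitfall, `stale` still equals `2.0 * 3.0` even though `m.scale` was changed to `10.0`. In the pytree version, new parameter values flow into `update` as ordinary traced inputs. The loop rebinds `model` each step instead of modifying it in place.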