Registering generic pytree nodes via getattr() and __setattr__ #25760
Comments
Thanks for bringing this up! This is pretty similar to the Still, there appears to be some appetite for unflattening via |
What about simply updating The main downside of this that comes to mind is that explaining the required invariant is a little bit more complex than "[class attributes] can be passed as keywords to the class constructor to create a copy of the object." |
Is it worth upstreaming some version of |
What could make sense for JAX is a basic, explicit unflatten method that avoids calling |
Fair enough if that's not desired. FWIW ~half of the complexity is stuff I'd be happy to cut if it bought a standardised solution. Just wanted to make the offer if JAX was going to eventually climb the same hill anyway. (Addressing Jake's concern |
In the past we have discussed inheritance-based pytree registration approaches, and chose not to go that route (though it's fine for downstream libraries to choose that style of API!) |
I came across a similar issue, when working on my little
from dataclasses import dataclass
from functools import partial

import jax
import jax.numpy as jnp
from jax.tree_util import register_dataclass


# Bar has no static fields, so meta_fields is empty.
@partial(register_dataclass, data_fields=["x"], meta_fields=[])
@dataclass
class Bar:
    x: jax.Array

    @classmethod
    def from_any_array(cls, x):
        """Initialize from any array supporting the buffer protocol"""
        return cls(x=jnp.asarray(x))

I realize that some people might dislike the additional code, but it is not really boilerplate, as the function name provides additional meaningful context for its arguments. Explicit is better than implicit. Compared to any other solution, this design pattern is by far the simplest. Maybe the JAX documentation would benefit from documenting this pattern as well, especially in the context of using |
A persistent challenge when writing libraries that use JAX is that all containers need to be pytree types.

The easiest way to do this is by defining custom dataclasses and using utilities like `jax.tree_util.register_dataclass`. However, this limits you to writing simple dataclasses, and precludes common design patterns such as casting to arrays in `__init__`: https://jax.readthedocs.io/en/latest/pytrees.html#custom-pytrees-and-initialization

Alternatively, you can write your own flattening/unflattening functions and register them, e.g., with `register_pytree_node`. This is more powerful, but requires considerable care.

It occurs to me that it is possible to write generic flattening/unflattening logic similar to `register_dataclass` but that handles almost any reasonable Python class, including classes with custom initialization. If desired, `data_fields` and `meta_fields` may even be checked against `__dict__`/`__slots__` to verify that they are comprehensive.

By using `object.__setattr__` to construct the state, we can handle immutable dataclasses and classes that do validation in `__init__`.

The implementation here also works on dataclasses, so in principle `register_dataclass` (which already works on "dataclass like" non-dataclasses) could simply be modified to use `object.__setattr__` instead of the current unflattener, which calls the constructor. This would require updating both the Python and C++ dataclass unflatteners, and might have performance implications.
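The generic flattening/unflattening idea described above could look roughly like the following. This is only a minimal sketch, not the original implementation from this issue and not an existing JAX API: the helper name `register_generic_pytree`, its `check` flag, and the example class `Scaled` are hypothetical. It assumes instances store their state in `__dict__` (classes using `__slots__` would need a different completeness check) and that all meta fields are hashable.

```python
import jax
import jax.numpy as jnp
from jax import tree_util


def register_generic_pytree(cls, data_fields, meta_fields, check=True):
    """Hypothetical helper: register `cls` as a pytree without calling __init__."""
    data_fields = tuple(data_fields)
    meta_fields = tuple(meta_fields)

    def flatten(obj):
        if check:
            # Verify the declared fields cover the instance state.
            # Only __dict__ is inspected in this sketch.
            declared = set(data_fields) | set(meta_fields)
            actual = set(vars(obj))
            if declared != actual:
                raise ValueError(
                    f"declared fields {sorted(declared)} do not match "
                    f"instance attributes {sorted(actual)}"
                )
        children = tuple(getattr(obj, name) for name in data_fields)
        aux = tuple(getattr(obj, name) for name in meta_fields)
        return children, aux

    def unflatten(aux, children):
        # Skip __init__ and any custom __setattr__ (e.g. frozen dataclasses),
        # so validation or casting in the constructor never sees tracers
        # or placeholder leaves.
        obj = object.__new__(cls)
        for name, value in zip(meta_fields, aux):
            object.__setattr__(obj, name, value)
        for name, value in zip(data_fields, children):
            object.__setattr__(obj, name, value)
        return obj

    tree_util.register_pytree_node(cls, flatten, unflatten)
    return cls


class Scaled:
    """Example class that casts and validates in __init__."""

    def __init__(self, x, name):
        x = jnp.asarray(x)  # casting in __init__
        if x.ndim != 1:  # validation in __init__
            raise ValueError("x must be one-dimensional")
        self.x = x
        self.name = name


register_generic_pytree(Scaled, data_fields=["x"], meta_fields=["name"])

s = Scaled([1.0, 2.0], name="demo")
doubled = jax.tree_util.tree_map(lambda a: a * 2, s)
# The mapped leaf is a scalar, which __init__ would reject; unflattening
# through object.__setattr__ does not run that validation.
summed = jax.tree_util.tree_map(lambda a: a.sum(), s)
```

The trade-off is that unflattened objects can hold state the constructor would have rejected, which is usually what JAX transformations need but means `__init__` invariants are only guaranteed for user-constructed instances.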