Registering generic pytree nodes via getattr() and __setattr__ #25760

Open
shoyer opened this issue Jan 7, 2025 · 7 comments
Labels
enhancement New feature or request

Comments

@shoyer
Collaborator

shoyer commented Jan 7, 2025

A persistent challenge when writing libraries that use JAX is that all containers need to be pytree types.

The easiest way to do this is by defining custom dataclasses and using utilities like jax.tree_util.register_dataclass. However, this limits you to writing simple dataclasses, and precludes common design patterns such as casting to arrays in __init__: https://jax.readthedocs.io/en/latest/pytrees.html#custom-pytrees-and-initialization
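
To make the limitation concrete, here is a minimal sketch of a class that casts in __init__ and is rebuilt by calling its constructor. The class name `Casting` is hypothetical, and the `object()` leaf stands in for the kind of placeholder value that the linked documentation section warns about; the exact exception type may vary across JAX versions.

```python
import jax.numpy as jnp
from jax import tree_util


class Casting:
  """Hypothetical class that converts its argument to an array in __init__."""

  def __init__(self, x):
    self.x = jnp.asarray(x)


# Unflatten by calling the constructor, as a naive registration would.
tree_util.register_pytree_node(
    Casting,
    lambda obj: ((obj.x,), None),
    lambda aux_data, children: Casting(*children),
)

leaves, treedef = tree_util.tree_flatten(Casting([1.0, 2.0]))

# Rebuilding with a non-array leaf re-runs __init__, which tries to cast it.
try:
  tree_util.tree_unflatten(treedef, [object()])
  constructor_failed = False
except (TypeError, ValueError):
  constructor_failed = True

print(constructor_failed)
```

Any unflattening function that goes through the constructor inherits this behavior, which is why casting and validation in __init__ interact badly with constructor-based registration.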

Alternatively, you can write your own flattening/unflattening functions and register them, e.g., with register_pytree_node. This is more powerful, but requires considerable care.
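
For reference, a minimal sketch of what manual registration looks like. The `Point` class and its field names are hypothetical; only `register_pytree_node` and `tree_map` are real JAX APIs here.

```python
import jax.numpy as jnp
from jax import tree_util


class Point:
  """Hypothetical container: `coords` is array data, `name` is static metadata."""

  def __init__(self, coords, name):
    self.coords = coords
    self.name = name


def flatten_point(p):
  # Returns (children, aux_data); aux_data must be hashable and comparable.
  return (p.coords,), p.name


def unflatten_point(aux_data, children):
  # Note the argument order: aux_data first, then children.
  (coords,) = children
  return Point(coords, aux_data)


tree_util.register_pytree_node(Point, flatten_point, unflatten_point)

doubled = tree_util.tree_map(lambda a: a * 2, Point(jnp.arange(3), 'p'))
print(doubled.coords, doubled.name)
```

The care mentioned above comes from keeping the two functions in sync by hand: every attribute must appear in exactly one of the children or the auxiliary data, and the unflattener must restore them in the same order.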

It occurs to me that it is possible to write generic flattening/unflattening logic similar to register_dataclass but that handles almost any reasonable Python class, including classes with custom initialization. If desired, data_fields and meta_fields may even be checked against __dict__/__slots__ to verify that they are comprehensive:

from jax import tree_util


def register_pytree_via_state(
    cls: type,
    data_fields: list[str],
    meta_fields: list[str],
) -> type:
  expected_fields = set(data_fields + meta_fields)

  def flatten_with_keys(obj):
    try:
      actual_fields = obj.__dict__.keys()
    except AttributeError:
      # All Python objects without __dict__ have __slots__.
      # __slots__ may be a str or iterable of strings:
      # https://docs.python.org/3/reference/datamodel.html#slots
      slots = obj.__slots__
      actual_fields = {slots} if isinstance(slots, str) else set(slots)

    if actual_fields != expected_fields:
      raise TypeError(
          'unexpected attributes on object: '
          f'got {sorted(actual_fields)}, expected {sorted(expected_fields)}'
      )

    children_with_keys = [
        (tree_util.GetAttrKey(k), getattr(obj, k)) for k in data_fields
    ]
    aux_data = tuple((k, getattr(obj, k)) for k in meta_fields)
    return children_with_keys, aux_data

  def unflatten_func(aux_data, children):
    result = object.__new__(cls)
    for k, v in zip(data_fields, children):
      object.__setattr__(result, k, v)
    for k, v in aux_data:
      object.__setattr__(result, k, v)
    return result

  tree_util.register_pytree_with_keys(cls, flatten_with_keys, unflatten_func)
  return cls

By using object.__setattr__ to construct the state, we can handle immutable dataclasses and classes that do validation in __init__:

import functools
import dataclasses
import jax
import jax.numpy as jnp


@functools.partial(
    register_pytree_via_state, data_fields=['x'], meta_fields=['y']
)
@dataclasses.dataclass(frozen=True)
class Foo:
  x: int
  y: str

print(jax.jit(lambda x: x)(Foo(1, 'two')))
# Foo(x=Array(1, dtype=int32, weak_type=True), y='two')


@functools.partial(register_pytree_via_state, data_fields=['x'], meta_fields=[])
class Bar:
  def __init__(self, x):
    self.x = jnp.asarray(x)

print(jax.vmap(lambda x: x)(Bar(jnp.arange(3))).x)
# Array([0, 1, 2], dtype=int32)

The implementation here works on dataclasses too, so in principle register_dataclass (which already works on "dataclass like" non-dataclasses) could simply be modified to use object.__setattr__ instead of the current unflattener, which calls the constructor. This would require updating both the Python and C++ dataclass unflatteners, and might have performance implications.

@shoyer shoyer added the enhancement New feature or request label Jan 7, 2025
@jakevdp
Collaborator

jakevdp commented Jan 8, 2025

Thanks for bringing this up! This is pretty similar to the register_simple idea I had a while ago: #21245 At the time we decided that register_dataclass was sufficient for the use-cases that it covered, and with the enhancements in #24700 register_dataclass has become even more streamlined.

Still, there appears to be some appetite for unflattening via setitem as in that PR (e.g. #25486 is a similar request), so maybe it's worth revisiting. Though I worry that we're growing too many ways to register a pytree...

@shoyer
Collaborator Author

shoyer commented Jan 8, 2025

What about simply updating register_dataclass to use object.__setattr__ to construct new objects rather than calling the constructor directly?

The main downside of this that comes to mind is that explaining the required invariant is a little bit more complex than "[class attributes] can be passed as keywords to the class constructor to create a copy of the object."

@patrick-kidger
Collaborator

Is it worth upstreaming some version of equinox.Module? It already handles things like custom __init__ methods and converters, as this issue is concerned with. (And already works via dataclasses, as register_dataclass is already concerned with.)

@shoyer
Collaborator Author

shoyer commented Jan 13, 2025

equinox.Module is great, but is also way more sophisticated than anything built-in to JAX, with over 1000 lines of code including comments & documentation. I think such complexity is better saved for external libraries.

What could make sense for JAX is a basic, explicit unflatten method that avoids calling __init__. This is the ~7 lines of code suggested in my first post, and which not coincidentally can also be found in Equinox:
https://github.com/patrick-kidger/equinox/blob/36f81165d53328a9e83fc1dc67afc0dc25fa48ec/equinox/_module.py#L953-L963
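
As a pure-Python illustration of what "avoids calling __init__" means in practice, the sketch below rebuilds a frozen dataclass without running its validation. The `Positive` class is hypothetical and is not taken from JAX or Equinox.

```python
import dataclasses


@dataclasses.dataclass(frozen=True)
class Positive:
  """Hypothetical frozen dataclass that validates its field on construction."""
  value: float

  def __post_init__(self):
    if self.value <= 0:
      raise ValueError('value must be positive')


# The normal constructor runs __post_init__ and rejects the value.
try:
  Positive(-1.0)
  init_rejected = False
except ValueError:
  init_rejected = True

# object.__new__ allocates the instance without calling __init__, and
# object.__setattr__ bypasses the frozen dataclass's own __setattr__.
restored = object.__new__(Positive)
object.__setattr__(restored, 'value', -1.0)

print(init_rejected, restored)
```

This is the trade-off of the approach: unflattening can never fail inside user code, but class invariants are only checked when the user calls the constructor themselves.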

@patrick-kidger
Collaborator

patrick-kidger commented Jan 13, 2025

Fair enough if that's not desired. FWIW ~half of the complexity is stuff I'd be happy to cut if it bought a standardised solution. Just wanted to make the offer if JAX was going to eventually climb the same hill anyway. (Addressing Jake's concern: "Though I worry that we're growing too many ways to register a pytree...")

@jakevdp
Collaborator

jakevdp commented Jan 13, 2025

In the past we have discussed inheritance-based pytree registration approaches, and chose not to go that route (though it's fine for downstream libraries to choose that style of API!)
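
For readers unfamiliar with the term, an inheritance-based approach could look roughly like the sketch below, where subclassing a base class triggers registration. `PyTreeBase` and the `data_fields` class attribute are hypothetical conventions invented for this illustration; this is not how JAX or any particular downstream library implements it.

```python
import jax
import jax.numpy as jnp
from jax import tree_util


class PyTreeBase:
  """Hypothetical base class: every subclass is registered as a pytree."""

  # Assumed convention: subclasses list their array-valued attributes here.
  data_fields: tuple = ()

  def __init_subclass__(cls, **kwargs):
    super().__init_subclass__(**kwargs)

    def flatten(obj):
      return tuple(getattr(obj, k) for k in cls.data_fields), None

    def unflatten(aux_data, children):
      obj = object.__new__(cls)  # bypass __init__
      for k, v in zip(cls.data_fields, children):
        object.__setattr__(obj, k, v)
      return obj

    tree_util.register_pytree_node(cls, flatten, unflatten)


class Scale(PyTreeBase):
  data_fields = ('factor',)

  def __init__(self, factor):
    self.factor = jnp.asarray(factor)


out = jax.jit(lambda s: s)(Scale(2.0))
print(type(out).__name__, out.factor)
```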

@adonath
Contributor

adonath commented Jan 14, 2025

I came across a similar issue, when working on my little gmmx library. The solution for me was to entirely skip any implementation of __init__ and just work with alternative constructors instead. I found this also leads to a very clean and explicit API design. Consider the following example:

from dataclasses import dataclass
from functools import partial

import jax
from jax import numpy as jnp
from jax.tree_util import register_dataclass

@partial(register_dataclass, data_fields=["x"], meta_fields=[])
@dataclass
class Bar:
    x: jax.Array

    @classmethod
    def from_any_array(cls, x):
        """Initialize from any array supporting the buffer protocol"""
        return cls(x=jnp.asarray(x))

I realize that some people might dislike the additional code, but it is not really boilerplate, as the function name provides additional meaningful context for its arguments. Explicit is better than implicit. Compared to any other solution, this design pattern is by far the simplest. Maybe the JAX documentation would benefit from documenting this pattern as well, especially in the context of using register_dataclass.
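
A slightly fuller sketch of the same pattern, with one metadata field and a round trip through jax.jit. The `Mixture` class, its fields, and the `from_unnormalized` constructor are hypothetical names chosen for illustration and are not taken from gmmx.

```python
from dataclasses import dataclass
from functools import partial

import jax
import jax.numpy as jnp
import numpy as np
from jax.tree_util import register_dataclass


@partial(register_dataclass, data_fields=["weights"], meta_fields=["name"])
@dataclass
class Mixture:
    weights: jax.Array
    name: str

    @classmethod
    def from_unnormalized(cls, weights, name="mixture"):
        """Cast and normalize here, so the generated __init__ stays trivial."""
        w = jnp.asarray(weights, dtype=jnp.float32)
        return cls(weights=w / w.sum(), name=name)


mix = Mixture.from_unnormalized(np.array([1.0, 3.0]))

# The generated __init__ does no casting, so JAX can rebuild the object freely.
scaled = jax.jit(lambda m: jax.tree_util.tree_map(lambda w: 2 * w, m))(mix)
print(scaled)
```

Because all conversion lives in the classmethod, the default constructor satisfies the register_dataclass requirement that fields can be passed back as keywords to create a copy.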

4 participants