Skip to content

Commit

Permalink
Improve filtering of attributes
Browse files Browse the repository at this point in the history
  • Loading branch information
NeilGirdhar committed Apr 30, 2024
1 parent 0544100 commit bae6003
Showing 1 changed file with 19 additions and 27 deletions.
46 changes: 19 additions & 27 deletions tjax/_src/display/display_generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,20 @@ def is_node_type(x: type[Any]) -> bool:
is_node_type = nnx.graph.is_node_type


def attribute_filter(value: Any, attribute_name: str) -> bool:
is_private = attribute_name.startswith('_')
if flax_loaded:
from flax.experimental import nnx # noqa: PLC0415
if isinstance(value, nnx.State) and is_private:
return False
if (isinstance(value, nnx.Variable | nnx.VariableState)
and (is_private or attribute_name.endswith('_hooks'))):
return False
if isinstance(value, FlaxModule) and attribute_name.startswith('_graph_node__'):
return False
return True


@singledispatch
def display_generic(value: Any,
*,
Expand All @@ -75,10 +89,7 @@ def display_generic(value: Any,
return x
if is_dataclass(value) and not isinstance(value, type):
return _display_dataclass(value, seen=seen, key=key)
hide_private = False
if flax_loaded:
hide_private = isinstance(value, FlaxState)
return _display_object(value, seen=seen, key=key, hide_private=hide_private)
return _display_object(value, seen=seen, key=key)
# _assemble(key, Text(str(value), style=_unknown_color))


Expand Down Expand Up @@ -222,27 +233,6 @@ def _(value: PyTreeDef,
return retval


if flax_loaded:
@display_generic.register(FlaxVariable)
def _(value: nnx.Variable[Any],
*,
seen: MutableSet[int] | None = None,
key: str = '',
) -> Tree:
if seen is None:
seen = set()
if (x := _verify(value, seen, key)) is not None:
return x
retval = display_class(key, type(value))
variables = _variables(value)
variables = {key: sub_value
for key, sub_value in variables.items()
if not (key.startswith('_') or key.endswith('_hooks') and value)}
for name, sub_value in variables.items():
retval.children.append(display_generic(sub_value, seen=seen, key=name))
return retval


# Public unexported functions ----------------------------------------------------------------------
def display_class(key: str, cls: type[Any]) -> Tree:
name = cls.__name__
Expand Down Expand Up @@ -270,14 +260,16 @@ def _display_dataclass(value: DataclassInstance,
for field_info in fields(value):
name = field_info.name
names.add(name)
if not attribute_filter(value, name):
continue
display_name = name
if not field_info.init:
display_name += ' (module)'
sub_value = getattr(value, name, None)
retval.children.append(display_generic(sub_value, seen=seen, key=display_name))
variables = _variables(value)
for name, sub_value in variables.items():
if name in names:
if name in names or not attribute_filter(value, name):
continue
retval.children.append(display_generic(sub_value, seen=seen, key=name + '*'))
return retval
Expand All @@ -292,7 +284,7 @@ def _display_object(value: Any,
retval = display_class(key, type(value))
variables = _variables(value)
for name, sub_value in variables.items():
if hide_private and name.startswith('_'):
if not attribute_filter(value, name):
continue
retval.children.append(display_generic(sub_value, seen=seen, key=name))
return retval
Expand Down

0 comments on commit bae6003

Please sign in to comment.