Skip to content

Commit

Permalink
Show seen array dimensions
Browse files Browse the repository at this point in the history
  • Loading branch information
NeilGirdhar committed May 19, 2024
1 parent 4f8e1f8 commit de7b4ad
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 8 deletions.
25 changes: 24 additions & 1 deletion tests/test_display.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import jax.numpy as jnp
import numpy as np
from jax import enable_custom_prng, jit, tree, vmap
from jax import Array, enable_custom_prng, jit, tree, vmap
from jax.random import key
from pytest import CaptureFixture
from rich.console import Console
Expand Down Expand Up @@ -176,6 +176,29 @@ class C:
""")


def test_seen_array(capsys: CaptureFixture[str],
console: Console) -> None:
@dataclass
class C:
x: Array
y: Array

z = jnp.zeros(2)
tree_def = C(z, z)

print_generic(tree_def, immediate=True, console=console)
assert isinstance(console.file, StringIO)
captured = console.file.getvalue()
verify(captured,
"""
C[dataclass]
├── x=Jax Array (2,) float64
│ └── 0.0000 │ 0.0000
└── y=Jax Array (2,) float64
└── y=<seen>
""")


if __name__ == "__main__":
@dataclass
class Triplet:
Expand Down
15 changes: 8 additions & 7 deletions tjax/_src/display/display_generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,17 +143,18 @@ def _(value: Array,
) -> Tree:
if seen is None:
seen = set()
if (x := _verify(value, seen, key)) is not None:
return x
retval = _assemble(key,
Text(f"Jax Array {value.shape} {value.dtype}",
style=_jax_array_color))
try:
np_value = np.asarray(value)
except TracerArrayConversionError:
pass
if (x := _verify(value, seen, key)) is not None:
retval.add(x)
else:
_show_array(retval, np_value)
try:
np_value = np.asarray(value)
except TracerArrayConversionError:
retval.add(_assemble(key, Text('<error>', style=_seen_color)))
else:
_show_array(retval, np_value)
return retval


Expand Down

0 comments on commit de7b4ad

Please sign in to comment.