Skip to content

Commit

Permalink
Merge pull request #25713 from jakevdp:debug-printoptions
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 711671926
  • Loading branch information
Google-ML-Automation committed Jan 3, 2025
2 parents 57b2154 + 3306063 commit 0f4677b
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 4 deletions.
9 changes: 5 additions & 4 deletions jax/_src/debugging.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,8 +300,9 @@ def check_unused_args(self, used_args, args, kwargs):

formatter = _DebugPrintFormatChecker()

def _format_print_callback(fmt: str, *args, **kwargs):
sys.stdout.write(fmt.format(*args, **kwargs) + "\n")
def _format_print_callback(fmt: str, np_printoptions, *args, **kwargs):
with np.printoptions(**np_printoptions):
sys.stdout.write(fmt.format(*args, **kwargs) + "\n")

def debug_print(fmt: str, *args, ordered: bool = False, **kwargs) -> None:
"""Prints values and works in staged out JAX functions.
Expand Down Expand Up @@ -338,8 +339,8 @@ def debug_print(fmt: str, *args, **kwargs):
# Check that we provide the correct arguments to be formatted.
formatter.format(fmt, *args, **kwargs)

debug_callback(functools.partial(_format_print_callback, fmt), *args,
**kwargs, ordered=ordered)
debug_callback(functools.partial(_format_print_callback, fmt, np.get_printoptions()),
*args, **kwargs, ordered=ordered)


# Sharding visualization
Expand Down
23 changes: 23 additions & 0 deletions tests/debugging_primitives_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,29 @@ def f(x):
[ 1 2 3 4 5 6 7 8 9 10 12 13 14]]
"""))

def test_debug_print_respects_numpy_printoptions(self):
def f(x):
with np.printoptions(precision=2, suppress=True):
jax.debug.print("{}", x)
x = np.array([1.2345, 2.3456, 1E-7])

# Default numpy print options:
with jtu.capture_stdout() as output:
jax.debug.print("{}", x)
self.assertEqual(output(), "[1.2345e+00 2.3456e+00 1.0000e-07]\n")

# Modified print options without JIT:
with jtu.capture_stdout() as output:
f(x)
jax.effects_barrier()
self.assertEqual(output(), "[1.23 2.35 0. ]\n")

# Modified print options with JIT:
with jtu.capture_stdout() as output:
jax.jit(f)(x)
jax.effects_barrier()
self.assertEqual(output(), "[1.23 2.35 0. ]\n")


class DebugPrintTransformationTest(jtu.JaxTestCase):

Expand Down

0 comments on commit 0f4677b

Please sign in to comment.