Skip to content

Commit

Permalink
Merge pull request #25988 from gnecula:debug_info_tests
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 717592689
  • Loading branch information
Google-ML-Automation committed Jan 20, 2025
2 parents 799eb98 + e5d89e7 commit e41f4ca
Show file tree
Hide file tree
Showing 5 changed files with 478 additions and 287 deletions.
6 changes: 6 additions & 0 deletions tests/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,12 @@ jax_multiplatform_test(
shard_count = 10,
)

jax_multiplatform_test(
name = "debug_info_test",
srcs = ["debug_info_test.py"],
enable_configs = ["tpu_v3_2x2"],
)

jax_multiplatform_test(
name = "device_test",
srcs = ["device_test.py"],
Expand Down
231 changes: 0 additions & 231 deletions tests/api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1322,75 +1322,6 @@ def test_jit_lower_compile_executable(self):
self.assertIsNotNone(f.runtime_executable())
self.assertIsNotNone(g.runtime_executable())

def test_jit_lower_arg_info(self):
def f(x, y, *args, **kwargs):
return y['hi'] + args[1] + sum(kwargs.values())

lowered = jax.jit(f).lower({'hi': 1.}, {'hi': 2.}, 3., 4., z=5., w=6.)
hlo_str = lowered.as_text("stablehlo", debug_info=True)
self.assertNotIn("\"x\"", hlo_str)
self.assertIn("y['hi']", hlo_str)
self.assertNotIn("args[0]", hlo_str)
self.assertIn("args[1]", hlo_str)
self.assertIn("kwargs['z']", hlo_str)
self.assertIn("kwargs['w']", hlo_str)

hlo_str = lowered.as_text("stablehlo", debug_info=False)
for s in ("\"x\"", "y['hi']", "args[0]", "args[1]", "kwargs['z']", "kwargs['w']"):
self.assertNotIn(s, hlo_str)

@parameterized.parameters([0, 2, [(0, 2)]])
def test_jit_lower_arg_info_static_argnums(self, static_argnums):
def f(x, y, *args, **kwargs):
return y['hi'] + args[1] + sum(kwargs.values())

lowered = jax.jit(f, static_argnums=static_argnums).lower(
(1.,), {'hi': 2.}, 3., 4., z=5., w=6.)

hlo_str = lowered.as_text("stablehlo", debug_info=True)
self.assertNotIn("\"x\"", hlo_str)
self.assertIn("y['hi']", hlo_str)
self.assertNotIn("args[0]", hlo_str)
self.assertIn("args[1]", hlo_str)
self.assertIn("kwargs['z']", hlo_str)
self.assertIn("kwargs['w']", hlo_str)

hlo_str = lowered.as_text("stablehlo", debug_info=False)
for s in ("\"x\"", "y['hi']", "args[0]", "args[1]", "kwargs['z']", "kwargs['w']"):
self.assertNotIn(s, hlo_str)

@parameterized.parameters(['a', 'b', [('a', 'b')]])
def test_jit_lower_arg_info_static_argnames(self, static_argnames):
def f(x, y, *args, **kwargs):
return y['hi'] + args[1] + kwargs['z'] + kwargs['w']

lowered = jax.jit(f, static_argnames=static_argnames).lower(
(1.,), {'hi': 2.}, 3., 4., z=5., w=6., a=7., b=8.)
hlo_str = lowered.as_text("stablehlo", debug_info=True)
self.assertNotIn("\"x\"", hlo_str)
self.assertIn("y['hi']", hlo_str)
self.assertNotIn("args[0]", hlo_str)
self.assertIn("args[1]", hlo_str)
self.assertIn("kwargs['z']", hlo_str)
self.assertIn("kwargs['w']", hlo_str)
self.assertNotIn("kwargs['a']", hlo_str)
self.assertNotIn("kwargs['b']", hlo_str)

hlo_str = lowered.as_text("stablehlo", debug_info=False)
for s in (
"\"x\"", "y['hi']", "args[0]", "args[1]", "kwargs['z']",
"kwargs['w']", "kwargs['a']", "kwargs['b']"
):
self.assertNotIn(s, hlo_str)

def test_jit_lower_result_info(self):
def f(x, y, z):
return {'a': x, 'b': [y]}

hlo_str = jax.jit(f).lower(1., (2,), [3]).as_text("stablehlo", debug_info=True)
self.assertIn("jax.result_info = \"['a']\"", hlo_str)
self.assertIn("jax.result_info = \"['b'][0][0]\"", hlo_str)

def test_jit_lower_compile_with_compiler_options(self):
def f(x):
return jnp.sqrt(x ** 2) + 1.
Expand Down Expand Up @@ -3598,21 +3529,6 @@ def f(_, __):
# level, which is no longer live.
jax.jit(jnp.add)(jnp.ones(()), count)

def test_escaped_tracer_transform_name(self):
with self.assertRaisesRegex(UnexpectedTracerError,
"for jit"):
jax.jit(self.helper_save_tracer)(1)
_ = self._saved_tracer+1

with self.assertRaisesRegex(UnexpectedTracerError,
"for pmap"):
jax.pmap(self.helper_save_tracer)(jnp.ones((1, 2)))
_ = self._saved_tracer+1

with self.assertRaisesRegex(UnexpectedTracerError,
"for jit"):
jax.eval_shape(self.helper_save_tracer, 1)
_ = self._saved_tracer+1

def test_escaped_tracer_shape_dtype(self):
with self.assertRaisesRegex(core.UnexpectedTracerError, r"int32\[4,3\]"):
Expand Down Expand Up @@ -3659,120 +3575,6 @@ def f():

f() # doesn't crash

def test_concrete_error_because_arg_unary(self):
@jax.jit
def f(x):
if x > 0:
return x
else:
return 0

msg = r"on the value of the argument x"
with self.assertRaisesRegex(core.ConcretizationTypeError, msg):
f(1)

def test_concrete_error_because_arg_binary(self):
@jax.jit
def f(x, y):
if x > y:
return x
else:
return y

msg = r"on the values of the arguments x and y"
with self.assertRaisesRegex(core.ConcretizationTypeError, msg):
f(1, 2)

def test_concrete_error_because_arg_ternary(self):
@jax.jit
def f(x, y, z):
if x > z:
return x
else:
return y

msg = r"on the values of the arguments x and z"
with self.assertRaisesRegex(core.ConcretizationTypeError, msg):
f(1, 2, 3)

with self.assertRaisesRegex(core.ConcretizationTypeError, msg):
f(1, 2, z=3)

with self.assertRaisesRegex(core.ConcretizationTypeError, msg):
f(1, y=2, z=3)

def test_concrete_error_because_arg_varargs(self):
@jax.jit
def f(*args):
x, y, z = args
if x > z:
return x
else:
return y

msg = r"on the values of the arguments args"
with self.assertRaisesRegex(core.ConcretizationTypeError, msg):
f(1, 2, 3)

def test_concrete_error_because_arg_kwargs(self):
@jax.jit
def f(**kwargs):
x, y, z = kwargs['x'], kwargs['y'], kwargs['z']
if x > z:
return x
else:
return y

msg = r"on the values of the arguments kwargs"
with self.assertRaisesRegex(core.ConcretizationTypeError, msg):
f(x=1, y=2, z=3)

def test_concrete_error_because_arg_pytree(self):
@jax.jit
def f(xy, z):
x, y = xy
if x > 0:
return x
else:
return y

msg = r"on the value of the argument xy"
with self.assertRaisesRegex(core.ConcretizationTypeError, msg):
f((1, 2), z=3)

def test_concrete_error_because_const(self):
@jax.jit
def f():
assert jnp.add(1, 1) > 0

msg = "on these lines"
with self.assertRaisesRegex(core.ConcretizationTypeError, msg):
f()

def test_concrete_error_because_const_2(self):
@jax.jit
def f():
result = sum(jnp.add(1, 1) for _ in range(6))
assert result > 0

msg = "Additional originating lines are not shown."
with self.assertRaisesRegex(core.ConcretizationTypeError, msg):
f()

def test_concrete_error_with_nested_call(self):
@jax.jit
def f(x, y):
if y:
return x

@jax.jit
def g(x):
return f(x, True)

msg = r"on the value of the argument y"
with self.assertRaisesRegex(core.ConcretizationTypeError, msg):
g(1)

def test_linearize_aux(self):
def fn(x):
return x * 2 - 3, x > 0
Expand Down Expand Up @@ -4940,39 +4742,6 @@ def f2(x):
expected = f_lin_expected(3.)
self.assertAllClose(ans, expected, check_dtypes=False)

def test_remat_concrete_error(self):
@jax.remat # no static_argnums or concrete
def g(x):
if x > 0:
return lax.sin(x)
else:
return lax.cos(x)

with self.assertRaisesRegex(core.ConcretizationTypeError, "static_argnums"):
g(3.)

@partial(jax.remat, static_argnums=(0,)) # using static_argnums but...
def g(x):
if x > 0: # jnp operations still get staged!
return lax.sin(x)
else:
return lax.cos(x)

with self.assertRaisesRegex(core.ConcretizationTypeError, "static_argnums"):
g(jnp.array(3.))

# But don't raise an error mentioning static_argnums here:
@jax.remat
def g(x):
jax.jit(lambda: 0 if jnp.add(1, 1) else 0)()
return lax.sin(x)

try:
g(jnp.array(3.))
except core.ConcretizationTypeError as e:
msg = str(e)
self.assertNotIn('static_argnums', msg)

@unittest.skip
def test_remat_grad_python_control_flow_static_argnums(self):
@partial(jax.remat, static_argnums=(0,))
Expand Down
Loading

0 comments on commit e41f4ca

Please sign in to comment.