diff --git a/tests/BUILD b/tests/BUILD index cfa2083b2735..6a8518299863 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -38,6 +38,12 @@ jax_multiplatform_test( shard_count = 10, ) +jax_multiplatform_test( + name = "debug_info_test", + srcs = ["debug_info_test.py"], + enable_configs = ["tpu_v3_2x2"], +) + jax_multiplatform_test( name = "device_test", srcs = ["device_test.py"], diff --git a/tests/api_test.py b/tests/api_test.py index fbe3610b1f08..5bf9a1203cc1 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -1322,75 +1322,6 @@ def test_jit_lower_compile_executable(self): self.assertIsNotNone(f.runtime_executable()) self.assertIsNotNone(g.runtime_executable()) - def test_jit_lower_arg_info(self): - def f(x, y, *args, **kwargs): - return y['hi'] + args[1] + sum(kwargs.values()) - - lowered = jax.jit(f).lower({'hi': 1.}, {'hi': 2.}, 3., 4., z=5., w=6.) - hlo_str = lowered.as_text("stablehlo", debug_info=True) - self.assertNotIn("\"x\"", hlo_str) - self.assertIn("y['hi']", hlo_str) - self.assertNotIn("args[0]", hlo_str) - self.assertIn("args[1]", hlo_str) - self.assertIn("kwargs['z']", hlo_str) - self.assertIn("kwargs['w']", hlo_str) - - hlo_str = lowered.as_text("stablehlo", debug_info=False) - for s in ("\"x\"", "y['hi']", "args[0]", "args[1]", "kwargs['z']", "kwargs['w']"): - self.assertNotIn(s, hlo_str) - - @parameterized.parameters([0, 2, [(0, 2)]]) - def test_jit_lower_arg_info_static_argnums(self, static_argnums): - def f(x, y, *args, **kwargs): - return y['hi'] + args[1] + sum(kwargs.values()) - - lowered = jax.jit(f, static_argnums=static_argnums).lower( - (1.,), {'hi': 2.}, 3., 4., z=5., w=6.) - - hlo_str = lowered.as_text("stablehlo", debug_info=True) - self.assertNotIn("\"x\"", hlo_str) - self.assertIn("y['hi']", hlo_str) - self.assertNotIn("args[0]", hlo_str) - self.assertIn("args[1]", hlo_str) - self.assertIn("kwargs['z']", hlo_str) - self.assertIn("kwargs['w']", hlo_str) - - hlo_str = lowered.as_text("stablehlo", debug_info=False) - for s in ("\"x\"", "y['hi']", "args[0]", "args[1]", "kwargs['z']", "kwargs['w']"): - self.assertNotIn(s, hlo_str) - - @parameterized.parameters(['a', 'b', [('a', 'b')]]) - def test_jit_lower_arg_info_static_argnames(self, static_argnames): - def f(x, y, *args, **kwargs): - return y['hi'] + args[1] + kwargs['z'] + kwargs['w'] - - lowered = jax.jit(f, static_argnames=static_argnames).lower( - (1.,), {'hi': 2.}, 3., 4., z=5., w=6., a=7., b=8.) - hlo_str = lowered.as_text("stablehlo", debug_info=True) - self.assertNotIn("\"x\"", hlo_str) - self.assertIn("y['hi']", hlo_str) - self.assertNotIn("args[0]", hlo_str) - self.assertIn("args[1]", hlo_str) - self.assertIn("kwargs['z']", hlo_str) - self.assertIn("kwargs['w']", hlo_str) - self.assertNotIn("kwargs['a']", hlo_str) - self.assertNotIn("kwargs['b']", hlo_str) - - hlo_str = lowered.as_text("stablehlo", debug_info=False) - for s in ( - "\"x\"", "y['hi']", "args[0]", "args[1]", "kwargs['z']", - "kwargs['w']", "kwargs['a']", "kwargs['b']" - ): - self.assertNotIn(s, hlo_str) - - def test_jit_lower_result_info(self): - def f(x, y, z): - return {'a': x, 'b': [y]} - - hlo_str = jax.jit(f).lower(1., (2,), [3]).as_text("stablehlo", debug_info=True) - self.assertIn("jax.result_info = \"['a']\"", hlo_str) - self.assertIn("jax.result_info = \"['b'][0][0]\"", hlo_str) - def test_jit_lower_compile_with_compiler_options(self): def f(x): return jnp.sqrt(x ** 2) + 1. @@ -3598,21 +3529,6 @@ def f(_, __): # level, which is no longer live. jax.jit(jnp.add)(jnp.ones(()), count) - def test_escaped_tracer_transform_name(self): - with self.assertRaisesRegex(UnexpectedTracerError, - "for jit"): - jax.jit(self.helper_save_tracer)(1) - _ = self._saved_tracer+1 - - with self.assertRaisesRegex(UnexpectedTracerError, - "for pmap"): - jax.pmap(self.helper_save_tracer)(jnp.ones((1, 2))) - _ = self._saved_tracer+1 - - with self.assertRaisesRegex(UnexpectedTracerError, - "for jit"): - jax.eval_shape(self.helper_save_tracer, 1) - _ = self._saved_tracer+1 def test_escaped_tracer_shape_dtype(self): with self.assertRaisesRegex(core.UnexpectedTracerError, r"int32\[4,3\]"): @@ -3659,120 +3575,6 @@ def f(): f() # doesn't crash - def test_concrete_error_because_arg_unary(self): - @jax.jit - def f(x): - if x > 0: - return x - else: - return 0 - - msg = r"on the value of the argument x" - with self.assertRaisesRegex(core.ConcretizationTypeError, msg): - f(1) - - def test_concrete_error_because_arg_binary(self): - @jax.jit - def f(x, y): - if x > y: - return x - else: - return y - - msg = r"on the values of the arguments x and y" - with self.assertRaisesRegex(core.ConcretizationTypeError, msg): - f(1, 2) - - def test_concrete_error_because_arg_ternary(self): - @jax.jit - def f(x, y, z): - if x > z: - return x - else: - return y - - msg = r"on the values of the arguments x and z" - with self.assertRaisesRegex(core.ConcretizationTypeError, msg): - f(1, 2, 3) - - with self.assertRaisesRegex(core.ConcretizationTypeError, msg): - f(1, 2, z=3) - - with self.assertRaisesRegex(core.ConcretizationTypeError, msg): - f(1, y=2, z=3) - - def test_concrete_error_because_arg_varargs(self): - @jax.jit - def f(*args): - x, y, z = args - if x > z: - return x - else: - return y - - msg = r"on the values of the arguments args" - with self.assertRaisesRegex(core.ConcretizationTypeError, msg): - f(1, 2, 3) - - def test_concrete_error_because_arg_kwargs(self): - @jax.jit - def f(**kwargs): - x, y, z = kwargs['x'], kwargs['y'], kwargs['z'] - if x > z: - return x - else: - return y - - msg = r"on the values of the arguments kwargs" - with self.assertRaisesRegex(core.ConcretizationTypeError, msg): - f(x=1, y=2, z=3) - - def test_concrete_error_because_arg_pytree(self): - @jax.jit - def f(xy, z): - x, y = xy - if x > 0: - return x - else: - return y - - msg = r"on the value of the argument xy" - with self.assertRaisesRegex(core.ConcretizationTypeError, msg): - f((1, 2), z=3) - - def test_concrete_error_because_const(self): - @jax.jit - def f(): - assert jnp.add(1, 1) > 0 - - msg = "on these lines" - with self.assertRaisesRegex(core.ConcretizationTypeError, msg): - f() - - def test_concrete_error_because_const_2(self): - @jax.jit - def f(): - result = sum(jnp.add(1, 1) for _ in range(6)) - assert result > 0 - - msg = "Additional originating lines are not shown." - with self.assertRaisesRegex(core.ConcretizationTypeError, msg): - f() - - def test_concrete_error_with_nested_call(self): - @jax.jit - def f(x, y): - if y: - return x - - @jax.jit - def g(x): - return f(x, True) - - msg = r"on the value of the argument y" - with self.assertRaisesRegex(core.ConcretizationTypeError, msg): - g(1) - def test_linearize_aux(self): def fn(x): return x * 2 - 3, x > 0 @@ -4940,39 +4742,6 @@ def f2(x): expected = f_lin_expected(3.) self.assertAllClose(ans, expected, check_dtypes=False) - def test_remat_concrete_error(self): - @jax.remat # no static_argnums or concrete - def g(x): - if x > 0: - return lax.sin(x) - else: - return lax.cos(x) - - with self.assertRaisesRegex(core.ConcretizationTypeError, "static_argnums"): - g(3.) - - @partial(jax.remat, static_argnums=(0,)) # using static_argnums but... - def g(x): - if x > 0: # jnp operations still get staged! - return lax.sin(x) - else: - return lax.cos(x) - - with self.assertRaisesRegex(core.ConcretizationTypeError, "static_argnums"): - g(jnp.array(3.)) - - # But don't raise an error mentioning static_argnums here: - @jax.remat - def g(x): - jax.jit(lambda: 0 if jnp.add(1, 1) else 0)() - return lax.sin(x) - - try: - g(jnp.array(3.)) - except core.ConcretizationTypeError as e: - msg = str(e) - self.assertNotIn('static_argnums', msg) - @unittest.skip def test_remat_grad_python_control_flow_static_argnums(self): @partial(jax.remat, static_argnums=(0,)) diff --git a/tests/debug_info_test.py b/tests/debug_info_test.py new file mode 100644 index 000000000000..3722e83e74df --- /dev/null +++ b/tests/debug_info_test.py @@ -0,0 +1,472 @@ +# Copyright 2018 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import contextlib +import functools +import math +import operator + +from absl.testing import absltest, parameterized +import jax +from jax import lax +from jax._src import config +from jax._src import core +from jax._src import test_util as jtu +from jax._src.compilation_cache import is_persistent_cache_enabled +import jax.custom_batching +import jax.custom_derivatives +import jax.custom_transpose +from jax.errors import UnexpectedTracerError +import jax.numpy as jnp +import numpy as np + + +config.parse_flags_with_absl() +jtu.request_cpu_devices(8) + + +class DebugInfoTest(jtu.JaxTestCase): + + def helper_save_tracer(self, x): + self._saved_tracer = x + return x + + def test_jit_lower_arg_info(self): + def f(x, y, *args, **kwargs): + return y['hi'] + args[1] + sum(kwargs.values()) + + lowered = jax.jit(f).lower({'hi': 1.}, {'hi': 2.}, 3., 4., z=5., w=6.) + hlo_str = lowered.as_text("stablehlo", debug_info=True) + self.assertNotIn("\"x\"", hlo_str) + self.assertIn("y['hi']", hlo_str) + self.assertNotIn("args[0]", hlo_str) + self.assertIn("args[1]", hlo_str) + self.assertIn("kwargs['z']", hlo_str) + self.assertIn("kwargs['w']", hlo_str) + + hlo_str = lowered.as_text("stablehlo", debug_info=False) + for s in ("\"x\"", "y['hi']", "args[0]", "args[1]", "kwargs['z']", "kwargs['w']"): + self.assertNotIn(s, hlo_str) + + @parameterized.parameters([0, 2, [(0, 2)]]) + def test_jit_lower_arg_info_static_argnums(self, static_argnums): + def f(x, y, *args, **kwargs): + return y['hi'] + args[1] + sum(kwargs.values()) + + lowered = jax.jit(f, static_argnums=static_argnums).lower( + (1.,), {'hi': 2.}, 3., 4., z=5., w=6.) + + hlo_str = lowered.as_text("stablehlo", debug_info=True) + self.assertNotIn("\"x\"", hlo_str) + self.assertIn("y['hi']", hlo_str) + self.assertNotIn("args[0]", hlo_str) + self.assertIn("args[1]", hlo_str) + self.assertIn("kwargs['z']", hlo_str) + self.assertIn("kwargs['w']", hlo_str) + + hlo_str = lowered.as_text("stablehlo", debug_info=False) + for s in ("\"x\"", "y['hi']", "args[0]", "args[1]", "kwargs['z']", "kwargs['w']"): + self.assertNotIn(s, hlo_str) + + @parameterized.parameters(['a', 'b', [('a', 'b')]]) + def test_jit_lower_arg_info_static_argnames(self, static_argnames): + def f(x, y, *args, **kwargs): + return y['hi'] + args[1] + kwargs['z'] + kwargs['w'] + + lowered = jax.jit(f, static_argnames=static_argnames).lower( + (1.,), {'hi': 2.}, 3., 4., z=5., w=6., a=7., b=8.) + hlo_str = lowered.as_text("stablehlo", debug_info=True) + self.assertNotIn("\"x\"", hlo_str) + self.assertIn("y['hi']", hlo_str) + self.assertNotIn("args[0]", hlo_str) + self.assertIn("args[1]", hlo_str) + self.assertIn("kwargs['z']", hlo_str) + self.assertIn("kwargs['w']", hlo_str) + self.assertNotIn("kwargs['a']", hlo_str) + self.assertNotIn("kwargs['b']", hlo_str) + + hlo_str = lowered.as_text("stablehlo", debug_info=False) + for s in ( + "\"x\"", "y['hi']", "args[0]", "args[1]", "kwargs['z']", + "kwargs['w']", "kwargs['a']", "kwargs['b']" + ): + self.assertNotIn(s, hlo_str) + + def test_jit_lower_result_info(self): + def f(x, y, z): + return {'a': x, 'b': [y]} + + hlo_str = jax.jit(f).lower(1., (2,), [3]).as_text("stablehlo", debug_info=True) + self.assertIn("jax.result_info = \"['a']\"", hlo_str) + self.assertIn("jax.result_info = \"['b'][0][0]\"", hlo_str) + + def test_jit_lower_compile_arg_type_mismatch(self): + def f(x): + return jnp.sqrt(x ** 2) + 1. + + x = jnp.array(1, dtype=int) + x_f32 = x.astype(jnp.float32) + x_i32 = x.astype(jnp.int32) + f_exe = jax.jit(f).lower(x_f32).compile() + self.assertRaisesRegex( + TypeError, + r"Argument types differ .*" + r"The mismatches are:\n" + r"Argument 'x' compiled with.*float32.*and called with.*int32.*", + lambda: f_exe(x_i32)) + + def test_jit_bad_input(self): + def f(x): + return x + + err_str = ("Error interpreting argument to .* as an abstract array. The problematic " + "value is of type .* and was passed to the function at path x.") + with self.assertRaisesRegex(TypeError, err_str): + jax.jit(f)("foo") + + # Jax type objects aren't valid data arguments. + with self.assertRaisesRegex(TypeError, err_str): + jax.jit(f)(jnp.int32) + + @jtu.thread_unsafe_test() # logging is not thread-safe + def test_cache_miss_explanations(self): + @jax.jit + def f(x, y): + return jnp.sin(x) * y['hi'] + + x = jnp.float32(1.) + y = {'hi': jnp.arange(3., dtype='float32')} + + expected_log_len = 1 if not is_persistent_cache_enabled() else 3 + + # print on first miss, not on hit + with config.explain_cache_misses(True): + with self.assertLogs(level='WARNING') as cm: + f(x, y) + f(x, y) + self.assertLen(cm.output, expected_log_len) + msg = cm.output[0] + self.assertIn('TRACING CACHE MISS', msg) + self.assertIn('never seen function', msg) + + # shape change + y_ = {'hi': jnp.arange(4, dtype='float32')} + with config.explain_cache_misses(True): + with self.assertLogs(level='WARNING') as cm: + f(x, y_) + self.assertLen(cm.output, expected_log_len) + msg = cm.output[0] + self.assertIn('never seen input type signature', msg) + self.assertIn('closest seen input type signature has 1 mismatches', msg) + self.assertIn('seen f32[3], but now given f32[4]', msg) + + # weak type change (assuming no x64) + if not config.enable_x64.value: + with config.explain_cache_misses(True): + with self.assertLogs(level='WARNING') as cm: + f(1., y) + self.assertLen(cm.output, expected_log_len) + msg = cm.output[0] + self.assertIn('weak_type=True', msg) + self.assertIn('https://jax.readthedocs.io/en/latest/type_promotion.html#weak-types', msg) + + # kwarg change + with config.explain_cache_misses(True): + with self.assertLogs(level='WARNING') as cm: + f(1, y=y) + self.assertLen(cm.output, expected_log_len) + msg = cm.output[0] + self.assertIn('never seen passing 1 positional args and 1 keyword args', msg) + + # tracing config change + with config.explain_cache_misses(True): + with self.assertLogs(level='WARNING') as cm: + with jax.numpy_rank_promotion('warn'): + f(x, y) + # depending on the backend, we may or may not get persistent cache warnings + self.assertTrue(1 <= len(cm.output) <= expected_log_len) + msg = cm.output[0] + self.assertIn("tracing context doesn't match", msg) + + @jtu.thread_unsafe_test() # logging is not thread-safe + def test_cache_miss_explanations_new_function_in_loop(self): + @jax.jit + def f(x, y): + return jnp.sin(x) * y['hi'] + + x = jnp.float32(1.) + + with config.explain_cache_misses(True): + with self.assertLogs(level='WARNING') as cm: + for _ in range(2): + jax.jit(lambda x: 2 * x)(3) + if is_persistent_cache_enabled(): + # number of warnings depends on the backend + self.assertTrue(4 <= len(cm.output) <= 6) + msg = cm.output[3] + self.assertIn('another function defined on the same line', msg) + else: + self.assertLen(cm.output, 2) + _, msg = cm.output + self.assertIn('another function defined on the same line', msg) + + @jtu.thread_unsafe_test() # logging is not thread-safe + def test_cache_miss_explanations_unpacks_transforms(self): + # Tests that the explain_tracing_cache_miss() function does not throw an + # error when unpacking `transforms` with a length greater than 3. + @jax.jit + def f(key): + return jax.random.truncated_normal(key, 1, 1, dtype=jax.numpy.float32) + + with config.explain_cache_misses(True): + with self.assertLogs(level="WARNING") as cm: + f(jax.random.key(seed=123)) + + if is_persistent_cache_enabled(): + # 5 warnings from tracing cache, 5-10 from persistent cache depending on + # the backend + self.assertTrue(10 <= len(cm.output) <= 15) + self.assertTrue(any("TRACING CACHE MISS" in msg for msg in cm.output)) + else: + self.assertLen(cm.output, 5) + for msg in cm.output: + self.assertIn("TRACING CACHE MISS", msg) + + def test_cache_miss_explanations_no_source_info(self): + # ``operator.add`` is a built-in function and does not have source info. + with config.explain_cache_misses(True): + jax.jit(operator.add)(42, 24) + + + def test_concrete_error_because_arg_unary(self): + @jax.jit + def f(x): + if x > 0: + return x + else: + return 0 + + msg = r"on the value of the argument x" + with self.assertRaisesRegex(core.ConcretizationTypeError, msg): + f(1) + + def test_concrete_error_because_arg_binary(self): + @jax.jit + def f(x, y): + if x > y: + return x + else: + return y + + msg = r"on the values of the arguments x and y" + with self.assertRaisesRegex(core.ConcretizationTypeError, msg): + f(1, 2) + + def test_concrete_error_because_arg_ternary(self): + @jax.jit + def f(x, y, z): + if x > z: + return x + else: + return y + + msg = r"on the values of the arguments x and z" + with self.assertRaisesRegex(core.ConcretizationTypeError, msg): + f(1, 2, 3) + + with self.assertRaisesRegex(core.ConcretizationTypeError, msg): + f(1, 2, z=3) + + with self.assertRaisesRegex(core.ConcretizationTypeError, msg): + f(1, y=2, z=3) + + def test_concrete_error_because_arg_varargs(self): + @jax.jit + def f(*args): + x, y, z = args + if x > z: + return x + else: + return y + + msg = r"on the values of the arguments args" + with self.assertRaisesRegex(core.ConcretizationTypeError, msg): + f(1, 2, 3) + + def test_concrete_error_because_arg_kwargs(self): + @jax.jit + def f(**kwargs): + x, y, z = kwargs['x'], kwargs['y'], kwargs['z'] + if x > z: + return x + else: + return y + + msg = r"on the values of the arguments kwargs" + with self.assertRaisesRegex(core.ConcretizationTypeError, msg): + f(x=1, y=2, z=3) + + def test_concrete_error_because_arg_pytree(self): + @jax.jit + def f(xy, z): + x, y = xy + if x > 0: + return x + else: + return y + + msg = r"on the value of the argument xy" + with self.assertRaisesRegex(core.ConcretizationTypeError, msg): + f((1, 2), z=3) + + def test_concrete_error_because_const(self): + @jax.jit + def f(): + assert jnp.add(1, 1) > 0 + + msg = "on these lines" + with self.assertRaisesRegex(core.ConcretizationTypeError, msg): + f() + + def test_concrete_error_because_const_2(self): + @jax.jit + def f(): + result = sum(jnp.add(1, 1) for _ in range(6)) + assert result > 0 + + msg = "Additional originating lines are not shown." + with self.assertRaisesRegex(core.ConcretizationTypeError, msg): + f() + + def test_concrete_error_with_nested_call(self): + @jax.jit + def f(x, y): + if y: + return x + + @jax.jit + def g(x): + return f(x, True) + + msg = r"on the value of the argument y" + with self.assertRaisesRegex(core.ConcretizationTypeError, msg): + g(1) + + def test_escaped_tracer_transform_name(self): + with self.assertRaisesRegex(UnexpectedTracerError, + "for jit"): + jax.jit(self.helper_save_tracer)(1) + _ = self._saved_tracer+1 + + with self.assertRaisesRegex(UnexpectedTracerError, + "for pmap"): + jax.pmap(self.helper_save_tracer)(jnp.ones((1, 2))) + _ = self._saved_tracer+1 + + with self.assertRaisesRegex(UnexpectedTracerError, + "for jit"): + jax.eval_shape(self.helper_save_tracer, 1) + _ = self._saved_tracer+1 + + def test_remat_concrete_error(self): + @jax.remat # no static_argnums or concrete + def g(x): + if x > 0: + return lax.sin(x) + else: + return lax.cos(x) + + with self.assertRaisesRegex(core.ConcretizationTypeError, "static_argnums"): + g(3.) + + @functools.partial(jax.remat, static_argnums=(0,)) # using static_argnums but... + def g(x): + if x > 0: # jnp operations still get staged! + return lax.sin(x) + else: + return lax.cos(x) + + with self.assertRaisesRegex(core.ConcretizationTypeError, "static_argnums"): + g(jnp.array(3.)) + + # But don't raise an error mentioning static_argnums here: + @jax.remat + def g(x): + jax.jit(lambda: 0 if jnp.add(1, 1) else 0)() + return lax.sin(x) + + try: + g(jnp.array(3.)) + except core.ConcretizationTypeError as e: + msg = str(e) + self.assertNotIn('static_argnums', msg) + + +class EagerPmapMixin: + + def setUp(self): + super().setUp() + stack = contextlib.ExitStack() + stack.enter_context(jtu.thread_local_config_context(jax_disable_jit=True, jax_eager_pmap=True)) + stack.enter_context(jtu.ignore_warning( + message="Some donated buffers were not usable", category=UserWarning)) + self.addCleanup(stack.close) + + +@jtu.pytest_mark_if_available('multiaccelerator') +class PythonPmapEagerTest(EagerPmapMixin, jtu.JaxTestCase): + def test_pmap_lower_arg_info(self): + def f(x, y, *args, **kwargs): + return y['hi'] + args[1] + sum(kwargs.values()) + + lowered = jax.pmap(f).lower( + {'hi': jnp.array([1.])}, {'hi': jnp.array([2.])}, jnp.array([3.]), + jnp.array([4.]), z=jnp.array([5.]), w=jnp.array([6.])) + hlo_str = lowered.as_text("stablehlo", debug_info=True) + self.assertNotIn("\"x\"", hlo_str) + self.assertIn("y['hi']", hlo_str) + self.assertIn("args[0]", hlo_str) + self.assertIn("args[1]", hlo_str) + self.assertIn("kwargs['z']", hlo_str) + self.assertIn("kwargs['w']", hlo_str) + + def test_pmap_lower_result_info(self): + def f(x, y, z): + return {'a': x, 'b': [y]} + + lowered = jax.pmap(f).lower(jnp.array([1.]), (jnp.array([2]),), + [jnp.array([3])]) + hlo_str = lowered.as_text("stablehlo", debug_info=True) + self.assertIn("jax.result_info = \"['a']\"", hlo_str) + self.assertIn("jax.result_info = \"['b'][0][0]\"", hlo_str) + + def testLowerCompileArgTypeMismatch(self): + f = jax.pmap(lambda x: x - lax.pmean(x, 'i'), axis_name='i') + shape = (jax.device_count(), 4) + x = np.arange(math.prod(shape), dtype=int).reshape(shape) + x_f32 = x.astype(jnp.float32) + x_i32 = x.astype(jnp.int32) + f_exe = f.lower(x_f32).compile() + self.assertRaisesRegex( + TypeError, + r"Argument types differ .*" + r"The mismatches are:\n" + r"Argument 'x' compiled with.*float32.*and called with.*int32.*", + lambda: f_exe(x_i32)) + + +if __name__ == '__main__': + absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/lax_control_flow_test.py b/tests/lax_control_flow_test.py index 3944c92d4368..175b96a9c5a6 100644 --- a/tests/lax_control_flow_test.py +++ b/tests/lax_control_flow_test.py @@ -29,7 +29,6 @@ import jax from jax._src import core from jax import dtypes -from jax.errors import UnexpectedTracerError from jax import lax from jax import random from jax._src import test_util as jtu @@ -589,20 +588,6 @@ def test_fori_loop_returns_init_with_nonpositive_length( init = jnp.float32(10) self.assertEqual(fori_loop_with_static_upper_and_lower(init), init) - def test_fori_error_points_to_user_code(self): - # See https://github.com/jax-ml/jax/issues/23637 - def my_body(_, c): - return bool(c) - - with self.assertRaisesRegex( - jax.errors.TracerBoolConversionError, - "occurred while tracing the function my_body at .*control_flow_test.py.* for scan"): - jax.lax.fori_loop(0, 5, my_body, 3.) - - with self.assertRaisesRegex( - jax.errors.TracerBoolConversionError, - "occurred while tracing the function my_body at .*control_flow_test.py.* for while_loop"): - jax.jit(lambda ubound: jax.lax.fori_loop(0, ubound, my_body, 3.))(5) def testForiLoopBatched(self): def body_fun(i, loop_carry): @@ -2750,22 +2735,6 @@ def cond_fun(val): self.assertAllClose(deriv(my_pow)(3.0, 1), 1.0, check_dtypes=False) - def test_unexpected_tracer_error(self): - with self.assertRaisesRegex(UnexpectedTracerError, "for while_loop"): - lst = [] - def side_effecting_body(val): - lst.append(val) - return val+1 - lax.while_loop(lambda x: x < 2, side_effecting_body, 1) - lst[0] += 1 - - with self.assertRaisesRegex(UnexpectedTracerError, "for scan"): - lst = [] - def side_effecting_scan(carry, val): - lst.append(val) - return carry, val+1 - lax.scan(side_effecting_scan, None, jnp.ones((2, 2))) - lst[0] += 1 def test_while_loop_fixed_point_with_batched_pred_and_consts(self): def f(i, x): diff --git a/tests/pmap_test.py b/tests/pmap_test.py index 9f00cd7878f4..af2d03e2945d 100644 --- a/tests/pmap_test.py +++ b/tests/pmap_test.py @@ -2119,31 +2119,6 @@ def test_remat_of_pmap_policy(self, remat): self.assertEqual(jaxpr_text.count(' sin '), 1) self.assertEqual(jaxpr_text.count(' cos '), 2) - def test_pmap_lower_arg_info(self): - def f(x, y, *args, **kwargs): - return y['hi'] + args[1] + sum(kwargs.values()) - - lowered = jax.pmap(f).lower( - {'hi': jnp.array([1.])}, {'hi': jnp.array([2.])}, jnp.array([3.]), - jnp.array([4.]), z=jnp.array([5.]), w=jnp.array([6.])) - hlo_str = lowered.as_text("stablehlo", debug_info=True) - self.assertNotIn("\"x\"", hlo_str) - self.assertIn("y['hi']", hlo_str) - self.assertIn("args[0]", hlo_str) - self.assertIn("args[1]", hlo_str) - self.assertIn("kwargs['z']", hlo_str) - self.assertIn("kwargs['w']", hlo_str) - - def test_pmap_lower_result_info(self): - def f(x, y, z): - return {'a': x, 'b': [y]} - - lowered = jax.pmap(f).lower(jnp.array([1.]), (jnp.array([2]),), - [jnp.array([3])]) - hlo_str = lowered.as_text("stablehlo", debug_info=True) - self.assertIn("jax.result_info = \"['a']\"", hlo_str) - self.assertIn("jax.result_info = \"['b'][0][0]\"", hlo_str) - def test_axis_name_shadowing_with_vmap(self): # vmap-of-pmap with mismatched axis sizes jax.vmap(jax.pmap(lambda x: 2 * x, axis_name='i'),