From 9470b02ee7586b71c3abf318bf0e4adc73ffd902 Mon Sep 17 00:00:00 2001 From: George Necula Date: Mon, 20 Jan 2025 17:17:44 +0100 Subject: [PATCH] [better_errors] Ensure debug_info.arg_names is never None. Most places in the code assumed this already, but often that usage is error reporting code, which is not yet well tested. When we cannot get the `inspect.Signature` we generate the flattened argument names as: `args[0]`, `args[1]`, `kwargs['foo']`, ... Previously, in these cases `arg_names` was `None`. We also add support for `api_util.fun_sourceinfo` even for cases when the `fun.__code__` is not available. In those cases we used to say that `fun_sourceinfo` is `None`. Now, we use the string representation of `fun`. This should not have an effect on caches because the `fun_sourceinfo` is part of a cache key only when `fun` is also part of the same key. --- jax/_src/ad_checkpoint.py | 8 ++-- jax/_src/api_util.py | 47 ++++++++++++------ jax/_src/core.py | 10 +++- jax/_src/interpreters/partial_eval.py | 4 +- jax/_src/linear_util.py | 3 +- jax/_src/pallas/core.py | 16 +++---- jax/_src/pallas/pallas_call.py | 2 +- jax/_src/pjit.py | 5 +- tests/debug_info_test.py | 68 +++++++++++++++++++++++++++ 9 files changed, 129 insertions(+), 34 deletions(-) diff --git a/jax/_src/ad_checkpoint.py b/jax/_src/ad_checkpoint.py index a55432c5fe5c..4a32edc5fe3d 100644 --- a/jax/_src/ad_checkpoint.py +++ b/jax/_src/ad_checkpoint.py @@ -454,9 +454,10 @@ def f_(*args): out_tree = lambda: tree_structure(out_shape) assert len(jaxpr.invars) == len(in_leaves) dbg = pe.tracing_debug_info(f, in_tree, out_tree, True, "saved_residuals") - return _saved_residuals(jaxpr, dbg.arg_names) # type: ignore + return _saved_residuals(jaxpr, dbg.arg_names) -def _saved_residuals(jaxpr, arg_names) -> list[tuple[core.AbstractValue, str]]: +def _saved_residuals(jaxpr: core.Jaxpr, + arg_names: tuple[str | None, ...]) -> list[tuple[core.AbstractValue, str]]: res_lits = [x for x in jaxpr.outvars if 
isinstance(x, core.Literal)] res_vars = {x for x in jaxpr.outvars if not isinstance(x, core.Literal)} @@ -587,7 +588,8 @@ def remat_partial_eval(trace: pe.JaxprTrace, *tracers: core.Tracer, _, staged_unk = partition_list(in_used_staged, in_unknowns) res_invars, _ = partition_list(staged_unk, jaxpr_unknown.invars[num_res:]) res_outvars = jaxpr_known.outvars[len(jaxpr_known.outvars) - num_res:] - body_res = _saved_residuals(jaxpr_known.replace(outvars=res_outvars), None) + body_res = _saved_residuals(jaxpr_known.replace(outvars=res_outvars), + [""] * len(jaxpr_known.invars)) logger.log(log_level, 'remat-decorated function ' + 'saving inputs with shapes:\n' * bool(res_invars) + diff --git a/jax/_src/api_util.py b/jax/_src/api_util.py index f35ae05850e9..28508be72a6c 100644 --- a/jax/_src/api_util.py +++ b/jax/_src/api_util.py @@ -603,15 +603,13 @@ def tracing_debug_info( # TODO(necula): check if we really need this, e.g., to speed up tracing. sourceinfo: str | None = None, signature: inspect.Signature | None = None, -) -> TracingDebugInfo | None: +) -> TracingDebugInfo: if sourceinfo is None: sourceinfo = fun_sourceinfo(fun) if signature is None: signature = fun_signature(fun) arg_names = _non_static_arg_names(signature, args, kwargs, static_argnums, static_argnames) - if arg_names is None: - return None return TracingDebugInfo(traced_for, sourceinfo, arg_names, result_paths_thunk) @@ -639,28 +637,47 @@ def fun_sourceinfo(fun: Callable) -> str | None: filename = fun.__code__.co_filename lineno = fun.__code__.co_firstlineno return f"{fun.__name__} at {filename}:{lineno}" - except AttributeError: - return None + except AttributeError as e: + try: + return str(fun) + except: + return "" # TODO(necula): this should never return None def _non_static_arg_names(fn_signature: inspect.Signature | None, args: Sequence[Any], kwargs: dict[str, Any], static_argnums: Sequence[int], static_argnames: Sequence[str], - ) -> tuple[str | None, ...] 
| None: - if fn_signature is None: - return None + ) -> tuple[str | None, ...]: + """Returns the names of the non-static arguments. + + If the `fn_signature` is given then we get from it the names of the + top-level arguments, else we use names like `args[0]`, `args[1]`, etc. + + Since this is used early in the transformations, we allow even cases when + the args and kwargs do not match the `inspect.Signature`. + """ static = object() static_argnums_ = _ensure_inbounds(True, len(args), static_argnums) static_argnames_ = set(static_argnames) args_ = [static if i in static_argnums_ else x for i, x in enumerate(args)] - kwargs = {k:static if k in static_argnames_ else x for k, x in kwargs.items()} - try: - ba = fn_signature.bind(*args_, **kwargs) - except (ValueError, TypeError): - return None - return tuple(f'{name}{keystr(path)}' for name, x in ba.arguments.items() - for path, l in generate_key_paths(x) if l is not static) + kwargs_ = {k:static if k in static_argnames_ else x for k, x in kwargs.items()} + if fn_signature is not None: + try: + ba = fn_signature.bind(*args_, **kwargs_) + except (ValueError, TypeError): + pass + else: + return tuple(f'{name}{keystr(path)}' for name, x in ba.arguments.items() + for path, l in generate_key_paths(x) if l is not static) + args_arg_names = tuple(f'args{keystr(path)}' + for path, l in generate_key_paths(args_) + if l is not static) + kwargs_arg_names = tuple(f'kwargs{keystr(path)}' + for path, l in generate_key_paths(kwargs_) + if l is not static) + arg_names = args_arg_names + kwargs_arg_names + return arg_names @lu.transformation_with_aux2 def result_paths(_fun, _store, *args, **kwargs): diff --git a/jax/_src/core.py b/jax/_src/core.py index df061d5f8b8f..f036a5c8ab0e 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -79,8 +79,14 @@ class JaxprDebugInfo(NamedTuple): traced_for: str # e.g. 'jit', 'scan', etc - func_src_info: str | None # e.g. 
f'{fun.__name__} at {filename}:{lineno}' - arg_names: tuple[str | None, ...] # e.g. ('args[0]', ... ) + # e.g. f'{fun.__name__} at {filename}:{lineno}' or {fun.__name__}. + func_src_info: str + + # The paths of the flattened non-static argnames, + # e.g. ('x', 'dict_arg["a"]', ... ). + # Uses `None` for the args that do not correspond to user-named arguments, + # e.g., tangent args in jax.jvp. + arg_names: tuple[str | None, ...] result_paths: tuple[str, ...] # e.g. ('[0]', '[1]', ...) class Jaxpr: diff --git a/jax/_src/interpreters/partial_eval.py b/jax/_src/interpreters/partial_eval.py index cc3399f4e4e2..ea1df444cd9c 100644 --- a/jax/_src/interpreters/partial_eval.py +++ b/jax/_src/interpreters/partial_eval.py @@ -1545,7 +1545,7 @@ def _origin_msg(self): return "" origin = ("The error occurred while tracing the function " - f"{dbg.func_src_info or ''} for {dbg.traced_for}. ") + f"{dbg.func_src_info} for {dbg.traced_for}. ") if invar_pos and dbg.arg_names: try: arg_names = [dbg.arg_names[i] for i in invar_pos] @@ -2116,7 +2116,7 @@ def tracing_debug_info( out_tree_thunk: Callable[[], PyTreeDef], has_kwargs: bool, traced_for: str -) -> lu.TracingDebugInfo | None: +) -> lu.TracingDebugInfo: # TODO(necula): we should not need this function, and can use api_util.tracing_debug_info instead # We just have to make sure we grad the debugging information when we have # the unflattened args diff --git a/jax/_src/linear_util.py b/jax/_src/linear_util.py index 919cd90c3521..f80a41d4aa8d 100644 --- a/jax/_src/linear_util.py +++ b/jax/_src/linear_util.py @@ -259,7 +259,8 @@ class TracingDebugInfo(NamedTuple): Formed just before staging to a jaxpr and read in trace-time error messages. """ traced_for: str # e.g. 'jit', 'scan', etc - func_src_info: str | None # e.g. f'{fun.__name__} at {filename}:{lineno}' + # e.g. f'{fun.__name__} at {filename}:{lineno}' or {fun.__name__}. + func_src_info: str # The paths of the flattened non-static argnames, # e.g. ('x', 'dict_arg["a"]', ... ). 
diff --git a/jax/_src/pallas/core.py b/jax/_src/pallas/core.py index df825f4e20a1..684500c99d95 100644 --- a/jax/_src/pallas/core.py +++ b/jax/_src/pallas/core.py @@ -75,6 +75,7 @@ class CompilerParams(Protocol): __dataclass_fields__: ClassVar[dict[str, dataclasses.Field[Any]]] +# TODO(necula): clean up the splitting of the fun_sourceinfo @dataclasses.dataclass(frozen=True) class NameAndSrcInfo: #: The name of the pallas_call or the name of the kernel function. @@ -91,7 +92,7 @@ def __str__(self): @staticmethod def from_pallas_call(pallas_call_name: str | None, - src_info : str | None) -> NameAndSrcInfo: + src_info: str) -> NameAndSrcInfo: """Formats the name and the source info. Args: @@ -101,16 +102,15 @@ def from_pallas_call(pallas_call_name: str | None, """ if pallas_call_name is not None: pallas_call_name = mlir._module_name_regex.sub("_", pallas_call_name) - if src_info is None: - return NameAndSrcInfo( - "unknown" if pallas_call_name is None else pallas_call_name, - "") if pallas_call_name is not None: return NameAndSrcInfo(pallas_call_name, f"for kernel function {src_info}") - src_info_parts = src_info.split(" ") - return NameAndSrcInfo(src_info_parts[0], - " ".join(src_info_parts[1:])) + src_info_parts = src_info.split(" at ") + if len(src_info_parts) > 1: + return NameAndSrcInfo(src_info_parts[0], + " ".join(src_info_parts[1:])) + else: + return NameAndSrcInfo(src_info_parts[0], "") split_list = util.split_list diff --git a/jax/_src/pallas/pallas_call.py b/jax/_src/pallas/pallas_call.py index 64cd93ba1136..87a63db3928a 100644 --- a/jax/_src/pallas/pallas_call.py +++ b/jax/_src/pallas/pallas_call.py @@ -1814,7 +1814,7 @@ def wrapped(*args): "pallas_call kernel", kernel, [1] * len(kernel_fun_sig.parameters), {}) - arg_names = kernel_debug_info and kernel_debug_info.arg_names + arg_names = kernel_debug_info.arg_names del kernel_debug_info in_origins = tuple(in_path_to_input_origin(p, arg_names) for p in in_paths) diff --git a/jax/_src/pjit.py 
b/jax/_src/pjit.py index 636d8a68f142..229088e16387 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -537,7 +537,7 @@ class PjitParams(NamedTuple): in_tree: PyTreeDef out_tree: PyTreeDef donated_invars: tuple[bool, ...] - arg_names: tuple[str | None, ...] | None + arg_names: tuple[str | None, ...] num_consts: int attrs_tracked: list[tuple[PyTreeDef, PyTreeDef, tuple[Any, str]]] abstract_mesh: AbstractMesh @@ -1189,7 +1189,8 @@ def unpack(key): # have we seen this function before at all? fun_name = getattr(f, '__qualname__', f) if debug_info is not None and debug_info.func_src_info: - _, _, *rest = debug_info.func_src_info.split(' ') + # TODO(necula): clean up the extraction of the source info + _, *rest = debug_info.func_src_info.split(' at ') src_info = " defined at " + ' '.join(rest) else: src_info = '' diff --git a/tests/debug_info_test.py b/tests/debug_info_test.py index 3722e83e74df..5409f677926a 100644 --- a/tests/debug_info_test.py +++ b/tests/debug_info_test.py @@ -22,6 +22,7 @@ from absl.testing import absltest, parameterized import jax from jax import lax +from jax._src import api_util from jax._src import config from jax._src import core from jax._src import test_util as jtu @@ -40,6 +41,73 @@ class DebugInfoTest(jtu.JaxTestCase): + def test_debug_info_basic(self): + def my_f(x, y, z, w): + pass + + dbg = api_util.tracing_debug_info("jit", my_f, (1, 2), dict(z=3, w=4)) + self.assertRegex(dbg.func_src_info, r"^my_f at .*debug_info_test.*") + self.assertEqual(dbg.arg_names, ("x", "y", "z", "w")) + self.assertIsNone(dbg.result_paths_thunk) + + def test_debug_info_arg_passed_as_kwarg(self): + def my_f(x, y, z): + pass + + dbg = api_util.tracing_debug_info("jit", my_f, (1, 2), dict(z=3)) + self.assertEqual(dbg.arg_names, ("x", "y", "z")) + + def test_debug_info_pytrees(self): + def my_f(x_tree, *, y_tree): + pass + + dbg = api_util.tracing_debug_info("jit", my_f, ((1, 2),), + dict(y_tree=dict(z=3, w=4))) + self.assertEqual(dbg.arg_names, ("x_tree[0]", 
"x_tree[1]", + "y_tree['w']", "y_tree['z']")) + + def test_debug_info_with_statics(self): + def my_f(x, y, *, z, w): + pass + + dbg = api_util.tracing_debug_info("jit", my_f, (1, 2), dict(z=3, w=4), + static_argnums=(1,), + static_argnames=("w",)) + self.assertEqual(dbg.arg_names, ("x", "z")) + + def test_debug_info_with_pytrees_and_statics(self): + def my_f(x, y, *, z, w): + pass + + dbg = api_util.tracing_debug_info("jit", my_f, ((1, 2), (2, 3)), + dict(z=(3, 4), w=(5, 6)), + static_argnums=(1,), + static_argnames=("w",)) + self.assertEqual(dbg.arg_names, ("x[0]", "x[1]", "z[0]", "z[1]")) + + def test_debug_info_too_many_args(self): + def my_f(x): + pass + + dbg = api_util.tracing_debug_info("jit", my_f, (1, 2, 3), dict(z=3)) + self.assertEqual(dbg.arg_names, ('args[0]', 'args[1]', 'args[2]', "kwargs['z']")) + + def test_debug_info_no_source_info_built_in(self): + # built-in function "max" does not have an inspect.Signature + dbg = api_util.tracing_debug_info("jit", max, (1,), {}) + self.assertEqual(dbg.func_src_info, "<built-in function max>") + self.assertEqual(dbg.arg_names, ("args[0]",)) + + def test_debug_info_no_source_info_callable(self): + class Foo: + x: int + def __call__(self, y): + return self.x + y + + dbg = api_util.tracing_debug_info("jit", Foo(), (1,), {}) + self.assertRegex(dbg.func_src_info, "<.*Foo.*>") + self.assertEqual(dbg.arg_names, ("y",)) + def helper_save_tracer(self, x): self._saved_tracer = x return x