Skip to content

Commit

Permalink
[better_errors] Ensure debug_info.arg_names is never None.
Browse files Browse the repository at this point in the history
Most places in the code assumed this already, but often
that usage is error reporting code, which is not yet well tested.

When we cannot get the `inspect.Signature` we generate
the flattened argument names as: `args[0]`, `args[1]`,
`kwargs['foo']`, ... Previously, in these cases we
returned `arg_names` is None.

We also add support for `api_util.fun_sourceinfo` even
for cases when the `fun.__code__` is not available. In
those cases we used to say that `fun_sourceinfo` is
`None`. Now, we use the string representation of `fun`.
This should not have an effect on caches because the
`fun_sourceinfo` is part of a cache key only when
`fun` is also part of the same key.
  • Loading branch information
gnecula committed Jan 20, 2025
1 parent e41f4ca commit 9470b02
Show file tree
Hide file tree
Showing 9 changed files with 129 additions and 34 deletions.
8 changes: 5 additions & 3 deletions jax/_src/ad_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -454,9 +454,10 @@ def f_(*args):
out_tree = lambda: tree_structure(out_shape)
assert len(jaxpr.invars) == len(in_leaves)
dbg = pe.tracing_debug_info(f, in_tree, out_tree, True, "saved_residuals")
return _saved_residuals(jaxpr, dbg.arg_names) # type: ignore
return _saved_residuals(jaxpr, dbg.arg_names)

def _saved_residuals(jaxpr, arg_names) -> list[tuple[core.AbstractValue, str]]:
def _saved_residuals(jaxpr: core.Jaxpr,
arg_names: tuple[str | None, ...]) -> list[tuple[core.AbstractValue, str]]:
res_lits = [x for x in jaxpr.outvars if isinstance(x, core.Literal)]
res_vars = {x for x in jaxpr.outvars if not isinstance(x, core.Literal)}

Expand Down Expand Up @@ -587,7 +588,8 @@ def remat_partial_eval(trace: pe.JaxprTrace, *tracers: core.Tracer,
_, staged_unk = partition_list(in_used_staged, in_unknowns)
res_invars, _ = partition_list(staged_unk, jaxpr_unknown.invars[num_res:])
res_outvars = jaxpr_known.outvars[len(jaxpr_known.outvars) - num_res:]
body_res = _saved_residuals(jaxpr_known.replace(outvars=res_outvars), None)
body_res = _saved_residuals(jaxpr_known.replace(outvars=res_outvars),
[""] * len(jaxpr_known.invars))
logger.log(log_level,
'remat-decorated function ' +
'saving inputs with shapes:\n' * bool(res_invars) +
Expand Down
47 changes: 32 additions & 15 deletions jax/_src/api_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -603,15 +603,13 @@ def tracing_debug_info(
# TODO(necula): check if we really need this, e.g., to speed up tracing.
sourceinfo: str | None = None,
signature: inspect.Signature | None = None,
) -> TracingDebugInfo | None:
) -> TracingDebugInfo:
if sourceinfo is None:
sourceinfo = fun_sourceinfo(fun)
if signature is None:
signature = fun_signature(fun)
arg_names = _non_static_arg_names(signature, args, kwargs, static_argnums,
static_argnames)
if arg_names is None:
return None
return TracingDebugInfo(traced_for, sourceinfo, arg_names, result_paths_thunk)


Expand Down Expand Up @@ -639,28 +637,47 @@ def fun_sourceinfo(fun: Callable) -> str | None:
filename = fun.__code__.co_filename
lineno = fun.__code__.co_firstlineno
return f"{fun.__name__} at {filename}:{lineno}"
except AttributeError:
return None
except AttributeError as e:
try:
return str(fun)
except:
return "<unknown>"

# TODO(necula): this should never return None
def _non_static_arg_names(fn_signature: inspect.Signature | None,
args: Sequence[Any], kwargs: dict[str, Any],
static_argnums: Sequence[int],
static_argnames: Sequence[str],
) -> tuple[str | None, ...] | None:
if fn_signature is None:
return None
) -> tuple[str | None, ...]:
"""Returns the names of the non-static arguments.
If the `fn_signature` is given then we get from it the names of the
top-level arguments, else we use names like `args[0[]`, `args[1]`, etc.
Since this is used early in the transformations, we allow even cases when
the args and kwargs do not match the `inspect.Signature`.
"""
static = object()
static_argnums_ = _ensure_inbounds(True, len(args), static_argnums)
static_argnames_ = set(static_argnames)
args_ = [static if i in static_argnums_ else x for i, x in enumerate(args)]
kwargs = {k:static if k in static_argnames_ else x for k, x in kwargs.items()}
try:
ba = fn_signature.bind(*args_, **kwargs)
except (ValueError, TypeError):
return None
return tuple(f'{name}{keystr(path)}' for name, x in ba.arguments.items()
for path, l in generate_key_paths(x) if l is not static)
kwargs_ = {k:static if k in static_argnames_ else x for k, x in kwargs.items()}
if fn_signature is not None:
try:
ba = fn_signature.bind(*args_, **kwargs_)
except (ValueError, TypeError):
pass
else:
return tuple(f'{name}{keystr(path)}' for name, x in ba.arguments.items()
for path, l in generate_key_paths(x) if l is not static)
args_arg_names = tuple(f'args{keystr(path)}'
for path, l in generate_key_paths(args_)
if l is not static)
kwargs_arg_names = tuple(f'kwargs{keystr(path)}'
for path, l in generate_key_paths(kwargs_)
if l is not static)
arg_names = args_arg_names + kwargs_arg_names
return arg_names

@lu.transformation_with_aux2
def result_paths(_fun, _store, *args, **kwargs):
Expand Down
10 changes: 8 additions & 2 deletions jax/_src/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,14 @@

class JaxprDebugInfo(NamedTuple):
traced_for: str # e.g. 'jit', 'scan', etc
func_src_info: str | None # e.g. f'{fun.__name__} at {filename}:{lineno}'
arg_names: tuple[str | None, ...] # e.g. ('args[0]', ... )
# e.g. f'{fun.__name__} at {filename}:{lineno}' or {fun.__name__}.
func_src_info: str

# The paths of the flattened non-static argnames,
# e.g. ('x', 'dict_arg["a"]', ... ).
# Uses `None` for the args that do not correspond to user-named arguments,
# e.g., tangent args in jax.jvp.
arg_names: tuple[str | None, ...]
result_paths: tuple[str, ...] # e.g. ('[0]', '[1]', ...)

class Jaxpr:
Expand Down
4 changes: 2 additions & 2 deletions jax/_src/interpreters/partial_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -1545,7 +1545,7 @@ def _origin_msg(self):
return ""

origin = ("The error occurred while tracing the function "
f"{dbg.func_src_info or '<unknown>'} for {dbg.traced_for}. ")
f"{dbg.func_src_info} for {dbg.traced_for}. ")
if invar_pos and dbg.arg_names:
try:
arg_names = [dbg.arg_names[i] for i in invar_pos]
Expand Down Expand Up @@ -2116,7 +2116,7 @@ def tracing_debug_info(
out_tree_thunk: Callable[[], PyTreeDef],
has_kwargs: bool,
traced_for: str
) -> lu.TracingDebugInfo | None:
) -> lu.TracingDebugInfo:
# TODO(necula): we should not need this function, and can use api_util.tracing_debug_info instead
# We just have to make sure we grad the debugging information when we have
# the unflattened args
Expand Down
3 changes: 2 additions & 1 deletion jax/_src/linear_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,8 @@ class TracingDebugInfo(NamedTuple):
Formed just before staging to a jaxpr and read in trace-time error messages.
"""
traced_for: str # e.g. 'jit', 'scan', etc
func_src_info: str | None # e.g. f'{fun.__name__} at {filename}:{lineno}'
# e.g. f'{fun.__name__} at {filename}:{lineno}' or {fun.__name__}.
func_src_info: str

# The paths of the flattened non-static argnames,
# e.g. ('x', 'dict_arg["a"]', ... ).
Expand Down
16 changes: 8 additions & 8 deletions jax/_src/pallas/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ class CompilerParams(Protocol):
__dataclass_fields__: ClassVar[dict[str, dataclasses.Field[Any]]]


# TODO(necula): clean up the splitting of the fun_sourceinfo
@dataclasses.dataclass(frozen=True)
class NameAndSrcInfo:
#: The name of the pallas_call or the name of the kernel function.
Expand All @@ -91,7 +92,7 @@ def __str__(self):

@staticmethod
def from_pallas_call(pallas_call_name: str | None,
src_info : str | None) -> NameAndSrcInfo:
src_info: str) -> NameAndSrcInfo:
"""Formats the name and the source info.
Args:
Expand All @@ -101,16 +102,15 @@ def from_pallas_call(pallas_call_name: str | None,
"""
if pallas_call_name is not None:
pallas_call_name = mlir._module_name_regex.sub("_", pallas_call_name)
if src_info is None:
return NameAndSrcInfo(
"unknown" if pallas_call_name is None else pallas_call_name,
"")
if pallas_call_name is not None:
return NameAndSrcInfo(pallas_call_name,
f"for kernel function {src_info}")
src_info_parts = src_info.split(" ")
return NameAndSrcInfo(src_info_parts[0],
" ".join(src_info_parts[1:]))
src_info_parts = src_info.split(" at ")
if len(src_info_parts) > 1:
return NameAndSrcInfo(src_info_parts[0],
" ".join(src_info_parts[1:]))
else:
return NameAndSrcInfo(src_info_parts[0], "")


split_list = util.split_list
Expand Down
2 changes: 1 addition & 1 deletion jax/_src/pallas/pallas_call.py
Original file line number Diff line number Diff line change
Expand Up @@ -1814,7 +1814,7 @@ def wrapped(*args):
"pallas_call kernel",
kernel,
[1] * len(kernel_fun_sig.parameters), {})
arg_names = kernel_debug_info and kernel_debug_info.arg_names
arg_names = kernel_debug_info.arg_names
del kernel_debug_info
in_origins = tuple(in_path_to_input_origin(p, arg_names)
for p in in_paths)
Expand Down
5 changes: 3 additions & 2 deletions jax/_src/pjit.py
Original file line number Diff line number Diff line change
Expand Up @@ -537,7 +537,7 @@ class PjitParams(NamedTuple):
in_tree: PyTreeDef
out_tree: PyTreeDef
donated_invars: tuple[bool, ...]
arg_names: tuple[str | None, ...] | None
arg_names: tuple[str | None, ...]
num_consts: int
attrs_tracked: list[tuple[PyTreeDef, PyTreeDef, tuple[Any, str]]]
abstract_mesh: AbstractMesh
Expand Down Expand Up @@ -1189,7 +1189,8 @@ def unpack(key):
# have we seen this function before at all?
fun_name = getattr(f, '__qualname__', f)
if debug_info is not None and debug_info.func_src_info:
_, _, *rest = debug_info.func_src_info.split(' ')
# TODO(necula): clean up the extraction of the source info
_, *rest = debug_info.func_src_info.split(' at ')
src_info = " defined at " + ' '.join(rest)
else:
src_info = ''
Expand Down
68 changes: 68 additions & 0 deletions tests/debug_info_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from absl.testing import absltest, parameterized
import jax
from jax import lax
from jax._src import api_util
from jax._src import config
from jax._src import core
from jax._src import test_util as jtu
Expand All @@ -40,6 +41,73 @@

class DebugInfoTest(jtu.JaxTestCase):

def test_debug_info_basic(self):
def my_f(x, y, z, w):
pass

dbg = api_util.tracing_debug_info("jit", my_f, (1, 2), dict(z=3, w=4))
self.assertRegex(dbg.func_src_info, r"^my_f at .*debug_info_test.*")
self.assertEqual(dbg.arg_names, ("x", "y", "z", "w"))
self.assertIsNone(dbg.result_paths_thunk)

def test_debug_info_arg_passed_as_kwarg(self):
def my_f(x, y, z):
pass

dbg = api_util.tracing_debug_info("jit", my_f, (1, 2), dict(z=3))
self.assertEqual(dbg.arg_names, ("x", "y", "z"))

def test_debug_info_pytrees(self):
def my_f(x_tree, *, y_tree):
pass

dbg = api_util.tracing_debug_info("jit", my_f, ((1, 2),),
dict(y_tree=dict(z=3, w=4)))
self.assertEqual(dbg.arg_names, ("x_tree[0]", "x_tree[1]",
"y_tree['w']", "y_tree['z']"))

def test_debug_info_with_statics(self):
def my_f(x, y, *, z, w):
pass

dbg = api_util.tracing_debug_info("jit", my_f, (1, 2), dict(z=3, w=4),
static_argnums=(1,),
static_argnames=("w",))
self.assertEqual(dbg.arg_names, ("x", "z"))

def test_debug_info_with_pytrees_and_statics(self):
def my_f(x, y, *, z, w):
pass

dbg = api_util.tracing_debug_info("jit", my_f, ((1, 2), (2, 3)),
dict(z=(3, 4), w=(5, 6)),
static_argnums=(1,),
static_argnames=("w",))
self.assertEqual(dbg.arg_names, ("x[0]", "x[1]", "z[0]", "z[1]"))

def test_debug_info_too_many_args(self):
def my_f(x):
pass

dbg = api_util.tracing_debug_info("jit", my_f, (1, 2, 3), dict(z=3))
self.assertEqual(dbg.arg_names, ('args[0]', 'args[1]', 'args[2]', "kwargs['z']"))

def test_debug_info_no_source_info_built_in(self):
# built-in function "int" does not have an inspect.Signature
dbg = api_util.tracing_debug_info("jit", max, (1,), {})
self.assertEqual(dbg.func_src_info, "<built-in function max>")
self.assertEqual(dbg.arg_names, ("args[0]",))

def test_debug_info_no_source_info_callable(self):
class Foo:
x: int
def __call__(self, y):
return self.x + y

dbg = api_util.tracing_debug_info("jit", Foo(), (1,), {})
self.assertRegex(dbg.func_src_info, "<.*Foo.*>")
self.assertEqual(dbg.arg_names, ("y",))

def helper_save_tracer(self, x):
self._saved_tracer = x
return x
Expand Down

0 comments on commit 9470b02

Please sign in to comment.