Skip to content

Commit

Permalink
[better_errors] Improvements in propagation of debugging info
Browse files Browse the repository at this point in the history
Added some documentation for `TracingDebugInfo` (docstring, comments
about `arg_names`, since it was not obvious to me that this would
flatten the non-static arguments).

Laying the groundwork for the unification of the old `api_util.debug_info`
and `partial_eval.tracing_debug_info`: we rename the former to
`api_util.tracing_debug_info`, we move into it the calls to
`fun_sourceinfo` and `fun_signature` (which were done by the callers
until now), and we rewrite the latter in terms
of the former. We leave the actual replacement of the
latter with the former throughout for a future PR.

While doing the above, cleaned up the one case where `partial_eval.tracing_debug_info`
received None for `in_tree` and `out_tree_thunk`. The function contained
catch-all exception clauses to handle those, but in doing so it masked other places
where we fail to collect debug info due to programming mistakes. E.g., in
one place we passed a `WrappedFun` instead of a `Callable`, resulting in missing debugging info.

Added more type declarations.

Added a `state_test` that exhibits a failure to track debugging information, which manifests
as a leaked tracer without function provenance. This will be fixed in a subsequent PR.
  • Loading branch information
gnecula committed Jan 20, 2025
1 parent d415c80 commit 688051a
Show file tree
Hide file tree
Showing 12 changed files with 162 additions and 85 deletions.
23 changes: 14 additions & 9 deletions jax/_src/ad_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,7 @@
from jax._src import effects
from jax._src import source_info_util
from jax._src import traceback_util
from jax._src.api_util import (
flatten_fun, debug_info, fun_sourceinfo, fun_signature)
from jax._src import api_util
from jax._src.interpreters import ad
from jax._src.interpreters import batching
from jax._src.interpreters import mlir
Expand All @@ -42,7 +41,7 @@
from jax._src.lax import convolution as lax_convolution
from jax._src.lib.mlir.dialects import hlo
from jax._src.traceback_util import api_boundary
from jax._src.tree_util import tree_flatten, tree_unflatten, tree_structure
from jax._src.tree_util import PyTreeDef, tree_flatten, tree_unflatten, tree_structure
from jax._src.util import (unzip2, wraps, split_list, partition_list, safe_map,
safe_zip, merge_lists, weakref_lru_cache)

Expand Down Expand Up @@ -324,8 +323,9 @@ def foo(x, y):
@wraps(fun)
@api_boundary
def fun_remat(*args, **kwargs):
debug = debug_info("checkpoint / remat", fun_sourceinfo(fun),
fun_signature(fun), args, kwargs, static_argnums, ())
debug = api_util.tracing_debug_info(
"checkpoint / remat", fun,
args, kwargs, static_argnums=static_argnums)
fun_, args = _remat_static_argnums(fun, static_argnums, args)
args_flat, in_tree = tree_flatten((args, kwargs))
in_avals = [core.shaped_abstractify(x) for x in args_flat]
Expand Down Expand Up @@ -415,8 +415,12 @@ def new_fun(*dyn_args, **kwargs):
# This helper is similar to those in control_flow/common.py, but with
# remat-specific errors.
@weakref_lru_cache
def _trace_to_jaxpr(fun, in_tree, in_avals, debug):
flat_fun, out_tree = flatten_fun(lu.wrap_init(fun), in_tree)
def _trace_to_jaxpr(fun: Callable,
in_tree: PyTreeDef,
in_avals: Sequence[core.AbstractValue],
debug: lu.TracingDebugInfo
) -> tuple[core.Jaxpr, Sequence[Any], PyTreeDef]:
flat_fun, out_tree = api_util.flatten_fun(lu.wrap_init(fun), in_tree)
try:
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals, debug)
except core.ConcretizationTypeError as e:
Expand Down Expand Up @@ -530,7 +534,8 @@ def remat_jvp(primals, tangents, jaxpr, prevent_cse, differentiated, policy):

effects.remat_allowed_effects.add_type(lax_internal.InOutFeedEffect)

def remat_partial_eval(trace, *tracers, jaxpr, **params):
def remat_partial_eval(trace: pe.JaxprTrace, *tracers: core.Tracer,
jaxpr: core.Jaxpr, **params):
assert not jaxpr.constvars
disallowed_effects = effects.remat_allowed_effects.filter_not_in(jaxpr.effects)
if disallowed_effects:
Expand Down Expand Up @@ -567,7 +572,7 @@ def remat_partial_eval(trace, *tracers, jaxpr, **params):
# set up unknown outputs with a recipe to call remat
res_tracers = map(trace.new_instantiated_const, residuals)
_, tracers_staged = partition_list(in_used_staged, tracers)
in_jaxpr_tracers = res_tracers + map(trace.instantiate_const, tracers_staged)
in_jaxpr_tracers = res_tracers + map(trace.instantiate_const, tracers_staged) # type: ignore
out_jaxpr_tracers = [pe.JaxprTracer(trace, pe.PartialVal.unknown(x.aval), None)
for x in jaxpr_unknown.outvars]
new_params = dict(params, jaxpr=jaxpr_unknown, differentiated=True)
Expand Down
18 changes: 6 additions & 12 deletions jax/_src/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@
tree_map, tree_flatten, tree_unflatten, tree_structure, tree_transpose,
tree_leaves, Partial, PyTreeDef, all_leaves, keystr, broadcast_prefix,
prefix_errors, generate_key_paths, tree_flatten_with_path)
from jax._src import api_util
from jax._src import config
from jax._src import core
from jax._src import dispatch
Expand All @@ -61,8 +60,8 @@
flatten_fun, flatten_fun_nokwargs, flatten_fun_nokwargs2, argnums_partial,
flatten_axes, donation_vector,
rebase_donate_argnums, _ensure_index, _ensure_index_tuple,
apply_flat_fun_nokwargs, check_callable, debug_info,
result_paths, flat_out_axes, debug_info_final, fun_sourceinfo)
apply_flat_fun_nokwargs, check_callable, tracing_debug_info,
result_paths, flat_out_axes, debug_info_final)
from jax._src.lax import lax as lax_internal
from jax._src.lib import jax_jit
from jax._src.lib import xla_client as xc
Expand Down Expand Up @@ -456,10 +455,7 @@ def value_and_grad_f(*args, **kwargs):
raise TypeError(f"differentiating with respect to {argnums=} requires at least "
f"{max_argnum + 1} positional arguments to be passed by the caller, "
f"but got only {len(args)} positional arguments.")
fun_src_info = fun_sourceinfo(fun)
fun_signature = api_util.fun_signature(fun)
dbg = debug_info('value_and_grad', fun_src_info, fun_signature,
args, kwargs, (), ())
dbg = tracing_debug_info('value_and_grad', fun, args, kwargs)

f = lu.wrap_init(fun, params=kwargs, debug_info=dbg)
f_partial, dyn_args = argnums_partial(f, argnums, args,
Expand Down Expand Up @@ -1405,11 +1401,9 @@ def _prepare_pmap(fun, in_axes, out_axes, static_broadcasted_tuple,
if in_devices is not None and len(in_devices) == 0:
raise ValueError("'devices' argument to pmap must be non-empty, or None.")

src = fun_sourceinfo(fun)
signature = api_util.fun_signature(fun)

dbg = debug_info('pmap', src, signature, args, kwargs,
static_broadcasted_tuple, ())
dbg = tracing_debug_info(
'pmap', fun, args, kwargs,
static_argnums=static_broadcasted_tuple)

f = lu.wrap_init(fun)
if static_broadcasted_tuple:
Expand Down
37 changes: 32 additions & 5 deletions jax/_src/api_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ def flattened_fun_in_tree(
(args, store, f is flatten_fun.args[0])
for (f, args), store in zip(fn.transforms, fn.stores) if f in flattens)
except ValueError:
# When `fn` is not the result of flatten_fun or flatten_fun_nokwargs
return None
else:
return in_tree, lambda: out_tree_store.val, has_kwargs # type: ignore[union-attr]
Expand Down Expand Up @@ -589,7 +590,7 @@ def _dtype(x):
def api_hook(fun, tag: str):
return fun


# TODO(necula): replace usage with tracing_debug_info
def debug_info(
traced_for: str, fun_src_info: str | None,
fun_signature: inspect.Signature | None,
Expand All @@ -598,12 +599,35 @@ def debug_info(
static_argnames: tuple[str, ...]
) -> TracingDebugInfo | None:
"""Try to build trace-time debug info for fun when applied to args/kwargs."""
arg_names = _arg_names(fun_signature, args, kwargs, static_argnums,
arg_names = _non_static_arg_names(fun_signature, args, kwargs, static_argnums,
static_argnames)
if arg_names is None:
return None
return TracingDebugInfo(traced_for, fun_src_info, arg_names, None)


# Build the trace-time `TracingDebugInfo` for `fun` applied to `args`/`kwargs`.
#
# Unlike `debug_info`, this takes the function itself and computes its source
# info and signature internally (via `fun_sourceinfo` and `fun_signature`),
# unless the caller supplies them through `sourceinfo`/`signature`.
def tracing_debug_info(
traced_for: str,
fun: Callable,
args: Sequence[Any],
kwargs: dict[str, Any],
*,
static_argnums: tuple[int, ...] = (),
static_argnames: tuple[str, ...] = (),
# Passed through unchanged as the last field of the returned `TracingDebugInfo`.
result_paths_thunk: Callable[[], tuple[str, ...]] | None = None,
# TODO(necula): check if we really need this, e.g., to speed up tracing.
sourceinfo: str | None = None,
signature: inspect.Signature | None = None,
) -> TracingDebugInfo:
# Only inspect `fun` for the pieces the caller did not precompute.
if sourceinfo is None:
sourceinfo = fun_sourceinfo(fun)
if signature is None:
signature = fun_signature(fun)
# NOTE: `fun_signature` is annotated as possibly returning None, and
# `_non_static_arg_names` returns None for a None signature, so `arg_names`
# may be None here.
arg_names = _non_static_arg_names(signature, args, kwargs, static_argnums,
static_argnames)
return TracingDebugInfo(traced_for, sourceinfo, arg_names, result_paths_thunk)


def fun_signature(fun: Callable) -> inspect.Signature | None:
try:
return inspect.signature(fun)
Expand Down Expand Up @@ -631,8 +655,11 @@ def fun_sourceinfo(fun: Callable) -> str | None:
except AttributeError:
return None

def _arg_names(fn_signature, args, kwargs, static_argnums, static_argnames,
) -> tuple[str, ...] | None:
def _non_static_arg_names(fn_signature: inspect.Signature | None,
args: Sequence[Any], kwargs: dict[str, Any],
static_argnums: Sequence[int],
static_argnames: Sequence[str],
) -> tuple[str | None, ...] | None:
if fn_signature is None: return None
static = object()
static_argnums_ = _ensure_inbounds(True, len(args), static_argnums)
Expand Down Expand Up @@ -665,7 +692,7 @@ def add_jaxpr_debug_info(jaxpr: core.Jaxpr,
result_paths = trace_debug.result_paths_thunk() # type: ignore
debug_info = core.JaxprDebugInfo(
trace_debug.traced_for, trace_debug.func_src_info,
trace_debug.arg_names, tuple(result_paths))
trace_debug.arg_names, tuple(result_paths)) # type: ignore
return jaxpr.replace(debug_info=debug_info)

def debug_info_final(f: lu.WrappedFun, dbg: TracingDebugInfo | None,
Expand Down
6 changes: 3 additions & 3 deletions jax/_src/custom_derivatives.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@
from jax._src.ad_util import (
stop_gradient_p, SymbolicZero, Zero, zeros_like_aval)
from jax._src.api_util import (
argnums_partial, flatten_fun_nokwargs, resolve_kwargs, fun_signature,
_arg_names)
argnums_partial, flatten_fun_nokwargs, resolve_kwargs, fun_signature,
_non_static_arg_names)
from jax._src.errors import UnexpectedTracerError
from jax._src.state.types import AbstractRef
from jax._src.interpreters import ad
Expand Down Expand Up @@ -636,7 +636,7 @@ def _check_for_aliased_refs(f, nondiff_argnums, args):
for i, x in enumerate(leaves):
if (isinstance((a := core.get_aval(x)), AbstractRef) and
(dup_idx := refs.setdefault(id(core.get_referent(x)), i)) != i):
arg_names = _arg_names(fun_signature(f), args, {}, nondiff_argnums, ())
arg_names = _non_static_arg_names(fun_signature(f), args, {}, nondiff_argnums, ())
if arg_names is None:
arg_names = [f'flat index {j}' for j in range(len(leaves))]
raise ValueError(
Expand Down
70 changes: 38 additions & 32 deletions jax/_src/interpreters/partial_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,18 +35,17 @@
from jax._src import source_info_util
from jax._src import compute_on
from jax._src import xla_metadata as xla_metadata_lib
from jax._src.api_util import (flattened_fun_in_tree, flatten_fun_nokwargs,
fun_sourceinfo)
from jax._src.api_util import (flattened_fun_in_tree, flatten_fun_nokwargs)
from jax._src.core import (Trace, Tracer, TraceTag, Jaxpr, Literal, get_aval,
AbstractValue, ClosedJaxpr, new_jaxpr_eqn,
Var, DropVar, Atom,
JaxprEqn, Primitive, ShapedArray, DShapedArray,
mapped_aval, unmapped_aval, DBIdx, InDBIdx, OutDBIdx,
InputType, OutputType, get_referent, JaxprEqnContext)
from jax._src.state.types import AbstractRef
from jax._src import tree_util
from jax._src.tree_util import (PyTreeDef, treedef_tuple, tree_unflatten,
tree_flatten, tree_structure, generate_key_paths,
keystr)
tree_flatten, tree_structure)
from jax._src.util import (unzip2, safe_zip, safe_map, toposort, split_list,
merge_lists, partition_list, OrderedSet,
as_hashable_function, weakref_lru_cache, subs_list)
Expand Down Expand Up @@ -579,9 +578,9 @@ def trace_to_jaxpr_nounits(
# TODO(mattjj): superfluous wrapper...?
@lu.transformation2
def trace_to_subjaxpr_nounits(
f,
f: Callable,
trace: JaxprTrace,
instantiate: bool | Sequence[bool],
instantiate: Sequence[bool] | bool,
in_pvals: Sequence[PartialVal]):
assert all(isinstance(pv, PartialVal) for pv in in_pvals), in_pvals
out_tracers, jaxpr, out_consts, env = _trace_to_subjaxpr_nounits(
Expand All @@ -607,7 +606,9 @@ def trace_to_subjaxpr_nounits2(
del out_tracers
return jaxpr, (out_pvals, out_consts, env)

def _trace_to_subjaxpr_nounits(f, trace:JaxprTrace, instantiate, in_pvals):
def _trace_to_subjaxpr_nounits(f: Callable, trace: JaxprTrace,
instantiate: Sequence[bool] | bool,
in_pvals: Sequence[PartialVal]):
in_knowns = [pval.is_known() for pval in in_pvals]
in_consts = [pval.get_known() for pval in in_pvals if pval.is_known()]
in_tracers = [trace.new_arg(pval) for pval in in_pvals if not pval.is_known()]
Expand Down Expand Up @@ -1903,7 +1904,8 @@ def default_process_primitive(self, primitive, tracers, params):
self.frame.add_eqn(eqn)
return out_tracers if primitive.multiple_results else out_tracers.pop()

def process_call(self, call_primitive, f, explicit_tracers, params):
def process_call(self, call_primitive, f: lu.WrappedFun,
explicit_tracers, params):
if f.in_type is None:
f = lu.annotate(f, tuple((get_aval(t), True) for t in explicit_tracers))
implicit_tracers = _extract_implicit_args(self, f.in_type, explicit_tracers)
Expand All @@ -1915,7 +1917,7 @@ def process_call(self, call_primitive, f, explicit_tracers, params):
return core.eval_jaxpr(jaxpr, consts, *in_tracers,
propagate_source_info=False)
source_info = source_info_util.current()
out_tracers = []
out_tracers: list[Tracer] = []
for aval, _ in out_type:
if type(aval) is DShapedArray:
shape = [[*consts, *in_tracers][d.val] if type(d) is InDBIdx else
Expand Down Expand Up @@ -2110,35 +2112,39 @@ def _jvp_jaxpr_zeros(f, store, in_zeros, zero_avals, *primal_tangent_avals):
# Callers should be using linear_util.debug_info instead!
def tracing_debug_info(
fn: Callable,
in_tree: PyTreeDef | None,
out_tree_thunk: Callable[[], PyTreeDef] | None,
in_tree: PyTreeDef,
out_tree_thunk: Callable[[], PyTreeDef],
has_kwargs: bool,
traced_for: str
) -> lu.TracingDebugInfo:
src_info = fun_sourceinfo(fn)
arg_names: tuple[str | None, ...] | None
# TODO(necula): we should not need this function, and can use api_util.tracing_debug_info instead
# We just have to make sure we grab the debugging information when we have
# the unflattened args
# TODO(necula): in general we can just pretend the leaves are booleans, but
# when we use custom pytrees, the flattening functions may check the type
# of the argument
try:
dummy_args = tree_unflatten(in_tree, [False] * in_tree.num_leaves) # type: ignore
args, kwargs = dummy_args if has_kwargs else (dummy_args, {})
ba = api_util.fun_signature(fn).bind(*args, **kwargs) # type: ignore
arg_names = tuple(f'{name}{keystr(path)}' for name, dummy in ba.arguments.items()
for path, _ in generate_key_paths(dummy))
except:
arg_names = None # TODO(necula): we should not need this
def result_paths():
try:
out_tree = out_tree_thunk()
dummy_result = tree_unflatten(out_tree, [False] * out_tree.num_leaves)
except:
return None # TODO(necula): this does not seem to be needed
return tuple(path for path, _ in generate_key_paths(dummy_result))
# TODO(necula): clean up the type: ignore below
return lu.TracingDebugInfo(traced_for, src_info, arg_names, result_paths) # type: ignore[arg-type]
# TODO(necula): remove this catch-all. Repro in batching_test:test_basic_jit
dummy_args = ([False], {}) if has_kwargs else [False]
args, kwargs = dummy_args if has_kwargs else (dummy_args, {}) # type: ignore
def res_paths_thunk() -> tuple[str, ...]:
out_tree = out_tree_thunk()
dummy_result = tree_unflatten(out_tree, [False] * out_tree.num_leaves)
return tuple(tree_util.keystr(path)
for path, _ in tree_util.generate_key_paths(dummy_result))
return api_util.tracing_debug_info(traced_for, fn, args, kwargs,
result_paths_thunk=res_paths_thunk)

def tracing_debug_info_final(fn: lu.WrappedFun, traced_for: str) -> lu.TracingDebugInfo:
in_tree, out_tree, has_kws = flattened_fun_in_tree(fn) or (None, None, False)
return tracing_debug_info(fn.f, in_tree, out_tree, has_kws, traced_for)

fn_trees = flattened_fun_in_tree(fn)
if fn_trees is None:
# TODO(necula): eliminate this branch
return lu.TracingDebugInfo(traced_for, api_util.fun_sourceinfo(fn.f),
(None,), None)
in_tree, out_tree_thunk, has_kws = fn_trees
return tracing_debug_info(fn.f, in_tree, out_tree_thunk, has_kws, traced_for)

@profiler.annotate_function
def trace_to_jaxpr_dynamic(
Expand Down Expand Up @@ -2178,7 +2184,7 @@ def _check_no_returned_refs(
raise ValueError(
f"function returned a mutable array reference of type {a.str_short()}, "
"but mutable array references cannot be returned.")
loc = (f' at output tree path {keystr(ls[i])}' # type: ignore
loc = (f' at output tree path {tree_util.keystr(ls[i])}' # type: ignore
if (dbg.result_paths_thunk and
(ls := dbg.result_paths_thunk()) and
ls[i]) else '')
Expand All @@ -2190,7 +2196,7 @@ def _check_no_returned_refs(
origin_info = ('\n\nThe returned mutable array was created on line '
f'{source_info_util.summarize(eqn.source_info)}.')
elif v in frame.invars:
arg_name = dbg.arg_names[frame.invars.index(v)]
arg_name = dbg.arg_names[frame.invars.index(v)] # type: ignore
origin_info = ('\n\nThe returned mutable array was passed in as the '
f'argument {arg_name}.')
else:
Expand Down
4 changes: 2 additions & 2 deletions jax/_src/lax/control_flow/conditionals.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ def switch(index, branches, *operands):
ops_avals = tuple(map(core.get_aval, ops))

if config.mutable_array_checks.value:
dbg = pe.tracing_debug_info(branches[0], ops_tree, None, False, 'switch')
dbg = pe.tracing_debug_info(branches[0], ops_tree, None, False, 'switch') # type: ignore
_check_no_aliased_ref_args(dbg, ops_avals, ops)

jaxprs, consts, out_trees = _initial_style_jaxprs_with_common_consts(
Expand Down Expand Up @@ -238,7 +238,7 @@ def cond(pred, true_fun, false_fun, *operands):
ops_avals = tuple(map(core.get_aval, ops))

if config.mutable_array_checks.value:
dbg = pe.tracing_debug_info(true_fun, ops_tree, None, False, 'cond')
dbg = pe.tracing_debug_info(true_fun, ops_tree, None, False, 'cond') # type: ignore
_check_no_aliased_ref_args(dbg, ops_avals, ops)
jaxprs, consts, out_trees = _initial_style_jaxprs_with_common_consts(
(true_fun, false_fun), ops_tree, ops_avals, 'cond')
Expand Down
2 changes: 1 addition & 1 deletion jax/_src/lax/control_flow/loops.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,7 @@ def scan(f, init, xs, length=None):

if config.mutable_array_checks.value:
in_flat, in_tree = tree_flatten((init, xs))
dbg = pe.tracing_debug_info(f, in_tree, None, False, 'scan')
dbg = pe.tracing_debug_info(f, in_tree, None, False, 'scan') # type: ignore
in_avals = tuple(_map(core.get_aval, in_flat))
_check_no_aliased_ref_args(dbg, in_avals, in_flat)

Expand Down
Loading

0 comments on commit 688051a

Please sign in to comment.