[better_errors] Improvements in propagation of debugging info #25916

Merged 1 commit on Jan 20, 2025
23 changes: 14 additions & 9 deletions jax/_src/ad_checkpoint.py
@@ -32,8 +32,7 @@
 from jax._src import effects
 from jax._src import source_info_util
 from jax._src import traceback_util
-from jax._src.api_util import (
-    flatten_fun, debug_info, fun_sourceinfo, fun_signature)
+from jax._src import api_util
 from jax._src.interpreters import ad
 from jax._src.interpreters import batching
 from jax._src.interpreters import mlir
@@ -42,7 +41,7 @@
 from jax._src.lax import convolution as lax_convolution
 from jax._src.lib.mlir.dialects import hlo
 from jax._src.traceback_util import api_boundary
-from jax._src.tree_util import tree_flatten, tree_unflatten, tree_structure
+from jax._src.tree_util import PyTreeDef, tree_flatten, tree_unflatten, tree_structure
 from jax._src.util import (unzip2, wraps, split_list, partition_list, safe_map,
                            safe_zip, merge_lists, weakref_lru_cache)

@@ -324,8 +323,9 @@ def foo(x, y):
   @wraps(fun)
   @api_boundary
   def fun_remat(*args, **kwargs):
-    debug = debug_info("checkpoint / remat", fun_sourceinfo(fun),
-                       fun_signature(fun), args, kwargs, static_argnums, ())
+    debug = api_util.tracing_debug_info(
+        "checkpoint / remat", fun,
+        args, kwargs, static_argnums=static_argnums)
     fun_, args = _remat_static_argnums(fun, static_argnums, args)
     args_flat, in_tree = tree_flatten((args, kwargs))
     in_avals = [core.shaped_abstractify(x) for x in args_flat]
@@ -415,8 +415,12 @@ def new_fun(*dyn_args, **kwargs):
 # This helper is similar to those in control_flow/common.py, but with
 # remat-specific errors.
 @weakref_lru_cache
-def _trace_to_jaxpr(fun, in_tree, in_avals, debug):
-  flat_fun, out_tree = flatten_fun(lu.wrap_init(fun), in_tree)
+def _trace_to_jaxpr(fun: Callable,
+                    in_tree: PyTreeDef,
+                    in_avals: Sequence[core.AbstractValue],
+                    debug: lu.TracingDebugInfo
+                    ) -> tuple[core.Jaxpr, Sequence[Any], PyTreeDef]:
+  flat_fun, out_tree = api_util.flatten_fun(lu.wrap_init(fun), in_tree)
   try:
     jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals, debug)
   except core.ConcretizationTypeError as e:
@@ -530,7 +534,8 @@ def remat_jvp(primals, tangents, jaxpr, prevent_cse, differentiated, policy):
 
 effects.remat_allowed_effects.add_type(lax_internal.InOutFeedEffect)
 
-def remat_partial_eval(trace, *tracers, jaxpr, **params):
+def remat_partial_eval(trace: pe.JaxprTrace, *tracers: core.Tracer,
+                       jaxpr: core.Jaxpr, **params):
   assert not jaxpr.constvars
   disallowed_effects = effects.remat_allowed_effects.filter_not_in(jaxpr.effects)
   if disallowed_effects:
@@ -567,7 +572,7 @@ def remat_partial_eval(trace, *tracers, jaxpr, **params):
   # set up unknown outputs with a recipe to call remat
   res_tracers = map(trace.new_instantiated_const, residuals)
   _, tracers_staged = partition_list(in_used_staged, tracers)
-  in_jaxpr_tracers = res_tracers + map(trace.instantiate_const, tracers_staged)
+  in_jaxpr_tracers = res_tracers + map(trace.instantiate_const, tracers_staged)  # type: ignore
   out_jaxpr_tracers = [pe.JaxprTracer(trace, pe.PartialVal.unknown(x.aval), None)
                        for x in jaxpr_unknown.outvars]
   new_params = dict(params, jaxpr=jaxpr_unknown, differentiated=True)
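
For context, the public entry point this code path serves: `jax.checkpoint` takes `static_argnums`, and the debug info built in `fun_remat` above is what names the traced arguments in remat error messages. A minimal sketch of a caller (the function and arguments here are illustrative, not from this PR):

import functools
import jax
import jax.numpy as jnp

# `n` is static: it is excluded from tracing and from the recorded arg names.
@functools.partial(jax.checkpoint, static_argnums=(1,))
def f(x, n):
  return jnp.sin(x) * n

jax.grad(f)(1.0, 2)  # differentiates w.r.t. the traced argument x only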
18 changes: 6 additions & 12 deletions jax/_src/api.py
@@ -42,7 +42,6 @@
     tree_map, tree_flatten, tree_unflatten, tree_structure, tree_transpose,
     tree_leaves, Partial, PyTreeDef, all_leaves, keystr, broadcast_prefix,
     prefix_errors, generate_key_paths, tree_flatten_with_path)
-from jax._src import api_util
 from jax._src import config
 from jax._src import core
 from jax._src import dispatch
@@ -61,8 +60,8 @@
     flatten_fun, flatten_fun_nokwargs, flatten_fun_nokwargs2, argnums_partial,
     flatten_axes, donation_vector,
     rebase_donate_argnums, _ensure_index, _ensure_index_tuple,
-    apply_flat_fun_nokwargs, check_callable, debug_info,
-    result_paths, flat_out_axes, debug_info_final, fun_sourceinfo)
+    apply_flat_fun_nokwargs, check_callable, tracing_debug_info,
+    result_paths, flat_out_axes, debug_info_final)
 from jax._src.lax import lax as lax_internal
 from jax._src.lib import jax_jit
 from jax._src.lib import xla_client as xc
@@ -456,10 +455,7 @@ def value_and_grad_f(*args, **kwargs):
       raise TypeError(f"differentiating with respect to {argnums=} requires at least "
                       f"{max_argnum + 1} positional arguments to be passed by the caller, "
                       f"but got only {len(args)} positional arguments.")
-    fun_src_info = fun_sourceinfo(fun)
-    fun_signature = api_util.fun_signature(fun)
-    dbg = debug_info('value_and_grad', fun_src_info, fun_signature,
-                     args, kwargs, (), ())
+    dbg = tracing_debug_info('value_and_grad', fun, args, kwargs)
 
     f = lu.wrap_init(fun, params=kwargs, debug_info=dbg)
     f_partial, dyn_args = argnums_partial(f, argnums, args,
@@ -1405,11 +1401,9 @@ def _prepare_pmap(fun, in_axes, out_axes, static_broadcasted_tuple,
   if in_devices is not None and len(in_devices) == 0:
     raise ValueError("'devices' argument to pmap must be non-empty, or None.")
 
-  src = fun_sourceinfo(fun)
-  signature = api_util.fun_signature(fun)
-
-  dbg = debug_info('pmap', src, signature, args, kwargs,
-                   static_broadcasted_tuple, ())
+  dbg = tracing_debug_info(
+      'pmap', fun, args, kwargs,
+      static_argnums=static_broadcasted_tuple)
 
   f = lu.wrap_init(fun)
   if static_broadcasted_tuple:
38 changes: 33 additions & 5 deletions jax/_src/api_util.py
@@ -116,6 +116,7 @@ def flattened_fun_in_tree(
         (args, store, f is flatten_fun.args[0])
         for (f, args), store in zip(fn.transforms, fn.stores) if f in flattens)
   except ValueError:
+    # When `fn` is not the result of flatten_fun or flatten_fun_nokwargs
    return None
   else:
     return in_tree, lambda: out_tree_store.val, has_kwargs  # type: ignore[union-attr]
@@ -589,7 +590,7 @@ def _dtype(x):
 def api_hook(fun, tag: str):
   return fun
 
-
+# TODO(necula): replace usage with tracing_debug_info
 def debug_info(
     traced_for: str, fun_src_info: str | None,
     fun_signature: inspect.Signature | None,
@@ -598,12 +599,36 @@ def debug_info(
     static_argnames: tuple[str, ...]
 ) -> TracingDebugInfo | None:
   """Try to build trace-time debug info for fun when applied to args/kwargs."""
-  arg_names = _arg_names(fun_signature, args, kwargs, static_argnums,
-                         static_argnames)
+  arg_names = _non_static_arg_names(fun_signature, args, kwargs, static_argnums,
+                                    static_argnames)
   if arg_names is None:
     return None
   return TracingDebugInfo(traced_for, fun_src_info, arg_names, None)
 
+
+def tracing_debug_info(
+    traced_for: str,
+    fun: Callable,
+    args: Sequence[Any],
+    kwargs: dict[str, Any],
+    *,
+    static_argnums: tuple[int, ...] = (),
+    static_argnames: tuple[str, ...] = (),
+    result_paths_thunk: Callable[[], tuple[str, ...]] | None = None,
+    # TODO(necula): check if we really need this, e.g., to speed up tracing.
+    sourceinfo: str | None = None,
+    signature: inspect.Signature | None = None,
+) -> TracingDebugInfo:
+  if sourceinfo is None:
+    sourceinfo = fun_sourceinfo(fun)
+  if signature is None:
+    signature = fun_signature(fun)
+  arg_names = _non_static_arg_names(signature, args, kwargs, static_argnums,
+                                    static_argnames)
+  # TODO(necula): remove type: ignore once we fix arg_names to never be None
+  return TracingDebugInfo(traced_for, sourceinfo, arg_names, result_paths_thunk)  # type: ignore
+
+
 def fun_signature(fun: Callable) -> inspect.Signature | None:
   try:
     return inspect.signature(fun)
@@ -631,8 +656,11 @@ def fun_sourceinfo(fun: Callable) -> str | None:
   except AttributeError:
     return None
 
-def _arg_names(fn_signature, args, kwargs, static_argnums, static_argnames,
-               ) -> tuple[str, ...] | None:
+def _non_static_arg_names(fn_signature: inspect.Signature | None,
+                          args: Sequence[Any], kwargs: dict[str, Any],
+                          static_argnums: Sequence[int],
+                          static_argnames: Sequence[str],
+                          ) -> tuple[str | None, ...] | None:
   if fn_signature is None: return None
   static = object()
   static_argnums_ = _ensure_inbounds(True, len(args), static_argnums)
@@ -665,7 +693,7 @@ def add_jaxpr_debug_info(jaxpr: core.Jaxpr,
     result_paths = trace_debug.result_paths_thunk()  # type: ignore
   debug_info = core.JaxprDebugInfo(
       trace_debug.traced_for, trace_debug.func_src_info,
-      trace_debug.arg_names, tuple(result_paths))
+      trace_debug.arg_names, tuple(result_paths))  # type: ignore
   return jaxpr.replace(debug_info=debug_info)
 
 def debug_info_final(f: lu.WrappedFun, dbg: TracingDebugInfo | None,
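
To illustrate the contract of the new `tracing_debug_info` helper above, here is a hand-run sketch (not a test from this PR; `my_transform` and `f` are made-up names, and the expected values follow from how `_non_static_arg_names` filters static arguments):

from jax._src import api_util

def f(x, y, *, scale=2.0):
  return x * scale + y

# `scale` is declared static, so it is dropped from the recorded arg names.
dbg = api_util.tracing_debug_info(
    "my_transform", f, (1.0, 2.0), {"scale": 3.0},
    static_argnames=("scale",))
# Expected: dbg.traced_for == "my_transform" and dbg.arg_names == ("x", "y")
print(dbg.traced_for, dbg.arg_names)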
4 changes: 2 additions & 2 deletions jax/_src/core.py
@@ -140,8 +140,8 @@ def __init__(self, constvars: Sequence[Var], invars: Sequence[Var],
     self._eqns = list(eqns)
     self._effects = effects
     self._debug_info = debug_info
-    assert (not debug_info or len(debug_info.arg_names) == len(invars) and
-            len(debug_info.result_paths) == len(outvars))
+    assert (not debug_info or debug_info.arg_names is None or len(debug_info.arg_names) == len(invars)), (debug_info, invars)
+    assert (not debug_info or len(debug_info.result_paths) == len(outvars)), (debug_info, outvars)
 
   def __str__(self):
     return str(self.pretty_print())
6 changes: 3 additions & 3 deletions jax/_src/custom_derivatives.py
@@ -30,8 +30,8 @@
 from jax._src.ad_util import (
     stop_gradient_p, SymbolicZero, Zero, zeros_like_aval)
 from jax._src.api_util import (
-    argnums_partial, flatten_fun_nokwargs, resolve_kwargs, fun_signature,
-    _arg_names)
+    argnums_partial, flatten_fun_nokwargs, resolve_kwargs, fun_signature,
+    _non_static_arg_names)
 from jax._src.errors import UnexpectedTracerError
 from jax._src.state.types import AbstractRef
 from jax._src.interpreters import ad
@@ -636,7 +636,7 @@ def _check_for_aliased_refs(f, nondiff_argnums, args):
   for i, x in enumerate(leaves):
     if (isinstance((a := core.get_aval(x)), AbstractRef) and
         (dup_idx := refs.setdefault(id(core.get_referent(x)), i)) != i):
-      arg_names = _arg_names(fun_signature(f), args, {}, nondiff_argnums, ())
+      arg_names = _non_static_arg_names(fun_signature(f), args, {}, nondiff_argnums, ())
       if arg_names is None:
         arg_names = [f'flat index {j}' for j in range(len(leaves))]
       raise ValueError(
70 changes: 38 additions & 32 deletions jax/_src/interpreters/partial_eval.py
@@ -35,18 +35,17 @@
 from jax._src import source_info_util
 from jax._src import compute_on
 from jax._src import xla_metadata as xla_metadata_lib
-from jax._src.api_util import (flattened_fun_in_tree, flatten_fun_nokwargs,
-                               fun_sourceinfo)
+from jax._src.api_util import (flattened_fun_in_tree, flatten_fun_nokwargs)
 from jax._src.core import (Trace, Tracer, TraceTag, Jaxpr, Literal, get_aval,
                            AbstractValue, ClosedJaxpr, new_jaxpr_eqn,
                            Var, DropVar, Atom,
                            JaxprEqn, Primitive, ShapedArray, DShapedArray,
                            mapped_aval, unmapped_aval, DBIdx, InDBIdx, OutDBIdx,
                            InputType, OutputType, get_referent, JaxprEqnContext)
 from jax._src.state.types import AbstractRef
+from jax._src import tree_util
 from jax._src.tree_util import (PyTreeDef, treedef_tuple, tree_unflatten,
-                                tree_flatten, tree_structure, generate_key_paths,
-                                keystr)
+                                tree_flatten, tree_structure)
 from jax._src.util import (unzip2, safe_zip, safe_map, toposort, split_list,
                            merge_lists, partition_list, OrderedSet,
                            as_hashable_function, weakref_lru_cache, subs_list)
@@ -579,9 +578,9 @@ def trace_to_jaxpr_nounits(
 # TODO(mattjj): superfluous wrapper...?
 @lu.transformation2
 def trace_to_subjaxpr_nounits(
-    f,
+    f: Callable,
     trace: JaxprTrace,
-    instantiate: bool | Sequence[bool],
+    instantiate: Sequence[bool] | bool,
     in_pvals: Sequence[PartialVal]):
   assert all(isinstance(pv, PartialVal) for pv in in_pvals), in_pvals
   out_tracers, jaxpr, out_consts, env = _trace_to_subjaxpr_nounits(
@@ -607,7 +606,9 @@ def trace_to_subjaxpr_nounits2(
   del out_tracers
   return jaxpr, (out_pvals, out_consts, env)
 
-def _trace_to_subjaxpr_nounits(f, trace:JaxprTrace, instantiate, in_pvals):
+def _trace_to_subjaxpr_nounits(f: Callable, trace: JaxprTrace,
+                               instantiate: Sequence[bool] | bool,
+                               in_pvals: Sequence[PartialVal]):
   in_knowns = [pval.is_known() for pval in in_pvals]
   in_consts = [pval.get_known() for pval in in_pvals if pval.is_known()]
   in_tracers = [trace.new_arg(pval) for pval in in_pvals if not pval.is_known()]
@@ -1903,7 +1904,8 @@ def default_process_primitive(self, primitive, tracers, params):
     self.frame.add_eqn(eqn)
     return out_tracers if primitive.multiple_results else out_tracers.pop()
 
-  def process_call(self, call_primitive, f, explicit_tracers, params):
+  def process_call(self, call_primitive, f: lu.WrappedFun,
+                   explicit_tracers, params):
     if f.in_type is None:
       f = lu.annotate(f, tuple((get_aval(t), True) for t in explicit_tracers))
     implicit_tracers = _extract_implicit_args(self, f.in_type, explicit_tracers)
@@ -1915,7 +1917,7 @@ def process_call(self, call_primitive, f, explicit_tracers, params):
       return core.eval_jaxpr(jaxpr, consts, *in_tracers,
                              propagate_source_info=False)
     source_info = source_info_util.current()
-    out_tracers = []
+    out_tracers: list[Tracer] = []
     for aval, _ in out_type:
       if type(aval) is DShapedArray:
         shape = [[*consts, *in_tracers][d.val] if type(d) is InDBIdx else
@@ -2110,35 +2112,39 @@ def _jvp_jaxpr_zeros(f, store, in_zeros, zero_avals, *primal_tangent_avals):
 # Callers should be using linear_util.debug_info instead!
 def tracing_debug_info(
     fn: Callable,
-    in_tree: PyTreeDef | None,
-    out_tree_thunk: Callable[[], PyTreeDef] | None,
+    in_tree: PyTreeDef,
+    out_tree_thunk: Callable[[], PyTreeDef],
     has_kwargs: bool,
     traced_for: str
 ) -> lu.TracingDebugInfo:
-  src_info = fun_sourceinfo(fn)
-  arg_names: tuple[str | None, ...] | None
+  # TODO(necula): we should not need this function, and can use
+  # api_util.tracing_debug_info instead. We just have to make sure we grab the
+  # debugging information when we have the unflattened args.
+  # TODO(necula): in general we can just pretend the leaves are booleans, but
+  # when we use custom pytrees, the flattening functions may check the type
+  # of the argument.
   try:
     dummy_args = tree_unflatten(in_tree, [False] * in_tree.num_leaves)  # type: ignore
-    args, kwargs = dummy_args if has_kwargs else (dummy_args, {})
-    ba = api_util.fun_signature(fn).bind(*args, **kwargs)  # type: ignore
-    arg_names = tuple(f'{name}{keystr(path)}' for name, dummy in ba.arguments.items()
-                      for path, _ in generate_key_paths(dummy))
   except:
-    arg_names = None  # TODO(necula): we should not need this
-  def result_paths():
-    try:
-      out_tree = out_tree_thunk()
-      dummy_result = tree_unflatten(out_tree, [False] * out_tree.num_leaves)
-    except:
-      return None  # TODO(necula): this does not seem to be needed
-    return tuple(path for path, _ in generate_key_paths(dummy_result))
-  # TODO(necula): clean up the type: ignore below
-  return lu.TracingDebugInfo(traced_for, src_info, arg_names, result_paths)  # type: ignore[arg-type]
+    # TODO(necula): remove this catch-all. Repro in batching_test:test_basic_jit
+    dummy_args = ([False], {}) if has_kwargs else [False]
+  args, kwargs = dummy_args if has_kwargs else (dummy_args, {})  # type: ignore
+  def res_paths_thunk() -> tuple[str, ...]:
+    out_tree = out_tree_thunk()
+    dummy_result = tree_unflatten(out_tree, [False] * out_tree.num_leaves)
+    return tuple(tree_util.keystr(path)
+                 for path, _ in tree_util.generate_key_paths(dummy_result))
+  return api_util.tracing_debug_info(traced_for, fn, args, kwargs,
+                                     result_paths_thunk=res_paths_thunk)
 
 def tracing_debug_info_final(fn: lu.WrappedFun, traced_for: str) -> lu.TracingDebugInfo:
-  in_tree, out_tree, has_kws = flattened_fun_in_tree(fn) or (None, None, False)
-  return tracing_debug_info(fn.f, in_tree, out_tree, has_kws, traced_for)
+  fn_trees = flattened_fun_in_tree(fn)
+  if fn_trees is None:
+    # TODO(necula): eliminate this branch
+    return lu.TracingDebugInfo(traced_for, api_util.fun_sourceinfo(fn.f),
+                               (None,), None)
+  in_tree, out_tree_thunk, has_kws = fn_trees
+  return tracing_debug_info(fn.f, in_tree, out_tree_thunk, has_kws, traced_for)
 
 @profiler.annotate_function
 def trace_to_jaxpr_dynamic(
@@ -2178,7 +2184,7 @@ def _check_no_returned_refs(
       raise ValueError(
           f"function returned a mutable array reference of type {a.str_short()}, "
          "but mutable array references cannot be returned.")
-    loc = (f' at output tree path {keystr(ls[i])}'  # type: ignore
+    loc = (f' at output tree path {tree_util.keystr(ls[i])}'  # type: ignore
           if (dbg.result_paths_thunk and
               (ls := dbg.result_paths_thunk()) and
               ls[i]) else '')
@@ -2190,7 +2196,7 @@ def _check_no_returned_refs(
       origin_info = ('\n\nThe returned mutable array was created on line '
                      f'{source_info_util.summarize(eqn.source_info)}.')
     elif v in frame.invars:
-      arg_name = dbg.arg_names[frame.invars.index(v)]
+      arg_name = dbg.arg_names[frame.invars.index(v)]  # type: ignore
       origin_info = ('\n\nThe returned mutable array was passed in as the '
                      f'argument {arg_name}.')
     else:
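
For a concrete picture of what the rewritten `res_paths_thunk` produces, the same recipe can be run by hand on a made-up output pytree (illustrative only; the real thunk obtains its tree from `out_tree_thunk`):

from jax import tree_util

out = {"loss": 0.0, "aux": (1.0, 2.0)}
_, out_tree = tree_util.tree_flatten(out)
# Pretend the leaves are booleans, as the comment in the diff describes.
dummy = tree_util.tree_unflatten(out_tree, [False] * out_tree.num_leaves)
paths = tuple(tree_util.keystr(p)
              for p, _ in tree_util.generate_key_paths(dummy))
# paths == ("['aux'][0]", "['aux'][1]", "['loss']") since dict keys flatten in sorted order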
4 changes: 2 additions & 2 deletions jax/_src/lax/control_flow/conditionals.py
@@ -139,7 +139,7 @@ def switch(index, branches, *operands):
   ops_avals = tuple(map(core.get_aval, ops))
 
   if config.mutable_array_checks.value:
-    dbg = pe.tracing_debug_info(branches[0], ops_tree, None, False, 'switch')
+    dbg = pe.tracing_debug_info(branches[0], ops_tree, None, False, 'switch')  # type: ignore
     _check_no_aliased_ref_args(dbg, ops_avals, ops)
 
   jaxprs, consts, out_trees = _initial_style_jaxprs_with_common_consts(
@@ -238,7 +238,7 @@ def cond(pred, true_fun, false_fun, *operands):
   ops_avals = tuple(map(core.get_aval, ops))
 
   if config.mutable_array_checks.value:
-    dbg = pe.tracing_debug_info(true_fun, ops_tree, None, False, 'cond')
+    dbg = pe.tracing_debug_info(true_fun, ops_tree, None, False, 'cond')  # type: ignore
     _check_no_aliased_ref_args(dbg, ops_avals, ops)
   jaxprs, consts, out_trees = _initial_style_jaxprs_with_common_consts(
       (true_fun, false_fun), ops_tree, ops_avals, 'cond')
2 changes: 1 addition & 1 deletion jax/_src/lax/control_flow/loops.py
@@ -275,7 +275,7 @@ def scan(f, init, xs, length=None):
 
   if config.mutable_array_checks.value:
     in_flat, in_tree = tree_flatten((init, xs))
-    dbg = pe.tracing_debug_info(f, in_tree, None, False, 'scan')
+    dbg = pe.tracing_debug_info(f, in_tree, None, False, 'scan')  # type: ignore
     in_avals = tuple(_map(core.get_aval, in_flat))
     _check_no_aliased_ref_args(dbg, in_avals, in_flat)
 