diff --git a/jax/_src/ad_checkpoint.py b/jax/_src/ad_checkpoint.py
index a9f348fe71e1..e56a9ca598cb 100644
--- a/jax/_src/ad_checkpoint.py
+++ b/jax/_src/ad_checkpoint.py
@@ -32,8 +32,7 @@
 from jax._src import effects
 from jax._src import source_info_util
 from jax._src import traceback_util
-from jax._src.api_util import (
-    flatten_fun, debug_info, fun_sourceinfo, fun_signature)
+from jax._src import api_util
 from jax._src.interpreters import ad
 from jax._src.interpreters import batching
 from jax._src.interpreters import mlir
@@ -42,7 +41,7 @@
 from jax._src.lax import convolution as lax_convolution
 from jax._src.lib.mlir.dialects import hlo
 from jax._src.traceback_util import api_boundary
-from jax._src.tree_util import tree_flatten, tree_unflatten, tree_structure
+from jax._src.tree_util import PyTreeDef, tree_flatten, tree_unflatten, tree_structure
 from jax._src.util import (unzip2, wraps, split_list, partition_list,
                            safe_map, safe_zip, merge_lists, weakref_lru_cache)
 
@@ -324,8 +323,9 @@ def foo(x, y):
   @wraps(fun)
   @api_boundary
   def fun_remat(*args, **kwargs):
-    debug = debug_info("checkpoint / remat", fun_sourceinfo(fun),
-                       fun_signature(fun), args, kwargs, static_argnums, ())
+    debug = api_util.tracing_debug_info(
+        "checkpoint / remat", fun,
+        args, kwargs, static_argnums=static_argnums)
     fun_, args = _remat_static_argnums(fun, static_argnums, args)
     args_flat, in_tree = tree_flatten((args, kwargs))
     in_avals = [core.shaped_abstractify(x) for x in args_flat]
@@ -415,8 +415,12 @@ def new_fun(*dyn_args, **kwargs):
 # This helper is similar to those in control_flow/common.py, but with
 # remat-specific errors.
 @weakref_lru_cache
-def _trace_to_jaxpr(fun, in_tree, in_avals, debug):
-  flat_fun, out_tree = flatten_fun(lu.wrap_init(fun), in_tree)
+def _trace_to_jaxpr(fun: Callable,
+                    in_tree: PyTreeDef,
+                    in_avals: Sequence[core.AbstractValue],
+                    debug: lu.TracingDebugInfo
+                    ) -> tuple[core.Jaxpr, Sequence[Any], PyTreeDef]:
+  flat_fun, out_tree = api_util.flatten_fun(lu.wrap_init(fun), in_tree)
   try:
     jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals, debug)
   except core.ConcretizationTypeError as e:
@@ -530,7 +534,8 @@ def remat_jvp(primals, tangents, jaxpr, prevent_cse, differentiated, policy):
 
 effects.remat_allowed_effects.add_type(lax_internal.InOutFeedEffect)
 
-def remat_partial_eval(trace, *tracers, jaxpr, **params):
+def remat_partial_eval(trace: core.JaxprTrace, *tracers: core.Tracer,
+                       jaxpr: core.Jaxpr, **params):
   assert not jaxpr.constvars
   disallowed_effects = effects.remat_allowed_effects.filter_not_in(jaxpr.effects)
   if disallowed_effects:
diff --git a/jax/_src/api.py b/jax/_src/api.py
index 29118ba53baa..6eb7f437c529 100644
--- a/jax/_src/api.py
+++ b/jax/_src/api.py
@@ -42,7 +42,6 @@
     tree_map, tree_flatten, tree_unflatten, tree_structure, tree_transpose,
     tree_leaves, Partial, PyTreeDef, all_leaves, keystr, broadcast_prefix,
     prefix_errors, generate_key_paths, tree_flatten_with_path)
-from jax._src import api_util
 from jax._src import config
 from jax._src import core
 from jax._src import dispatch
@@ -61,7 +60,7 @@
     flatten_fun, flatten_fun_nokwargs, flatten_fun_nokwargs2, argnums_partial,
     flatten_axes, donation_vector, rebase_donate_argnums,
     _ensure_index, _ensure_index_tuple,
-    apply_flat_fun_nokwargs, check_callable, debug_info,
+    apply_flat_fun_nokwargs, check_callable, tracing_debug_info,
     result_paths, flat_out_axes, debug_info_final, fun_sourceinfo)
 from jax._src.lax import lax as lax_internal
 from jax._src.lib import jax_jit
@@ -454,10 +453,7 @@ def value_and_grad_f(*args, **kwargs):
       raise TypeError(f"differentiating with respect to {argnums=} requires at least "
                       f"{max_argnum + 1} positional arguments to be passed by the caller, "
                       f"but got only {len(args)} positional arguments.")
-    fun_src_info = fun_sourceinfo(fun)
-    fun_signature = api_util.fun_signature(fun)
-    dbg = debug_info('value_and_grad', fun_src_info, fun_signature,
-                     args, kwargs, (), ())
+    dbg = tracing_debug_info('value_and_grad', fun, args, kwargs)
 
     f = lu.wrap_init(fun, params=kwargs, debug_info=dbg)
     f_partial, dyn_args = argnums_partial(f, argnums, args,
@@ -1403,11 +1399,9 @@ def _prepare_pmap(fun, in_axes, out_axes, static_broadcasted_tuple,
   if in_devices is not None and len(in_devices) == 0:
     raise ValueError("'devices' argument to pmap must be non-empty, or None.")
 
-  src = fun_sourceinfo(fun)
-  signature = api_util.fun_signature(fun)
-
-  dbg = debug_info('pmap', src, signature, args, kwargs,
-                   static_broadcasted_tuple, ())
+  dbg = tracing_debug_info(
+      'pmap', fun, args, kwargs,
+      static_argnums=static_broadcasted_tuple)
 
   f = lu.wrap_init(fun)
   if static_broadcasted_tuple:
diff --git a/jax/_src/api_util.py b/jax/_src/api_util.py
index 930011fe27ef..544b44ec75ff 100644
--- a/jax/_src/api_util.py
+++ b/jax/_src/api_util.py
@@ -116,6 +116,7 @@ def flattened_fun_in_tree(
         (args, store, f is flatten_fun.args[0])
         for (f, args), store in zip(fn.transforms, fn.stores) if f in flattens)
   except ValueError:
+    # When `fn` is not the result of flatten_fun or flatten_fun_nokwargs
     return None
   else:
     return in_tree, lambda: out_tree_store.val, has_kwargs  # type: ignore[union-attr]
@@ -590,19 +591,20 @@ def api_hook(fun, tag: str):
   return fun
 
 
-def debug_info(
-    traced_for: str, fun_src_info: str | None,
-    fun_signature: inspect.Signature | None,
-    args: tuple[Any, ...], kwargs: dict[str, Any],
-    static_argnums: tuple[int, ...],
-    static_argnames: tuple[str, ...]
-) -> TracingDebugInfo | None:
-  """Try to build trace-time debug info for fun when applied to args/kwargs."""
-  arg_names = _arg_names(fun_signature, args, kwargs, static_argnums,
+def tracing_debug_info(
+    traced_for: str,
+    fun: Callable,
+    args: Sequence[Any],
+    kwargs: dict[str, Any],
+    static_argnums: tuple[int, ...] = (),
+    static_argnames: tuple[str, ...] = ()
+) -> TracingDebugInfo:
+  sourceinfo = fun_sourceinfo(fun)
+  signature = fun_signature(fun)
+  arg_names = _arg_names(signature, args, kwargs, static_argnums,
                          static_argnames)
-  if arg_names is None:
-    return None
-  return TracingDebugInfo(traced_for, fun_src_info, arg_names, None)
+  return TracingDebugInfo(traced_for, sourceinfo, arg_names, None)
+
 
 def fun_signature(fun: Callable) -> inspect.Signature | None:
   try:
@@ -631,18 +633,25 @@ def fun_sourceinfo(fun: Callable) -> str | None:
   except AttributeError:
     return None
 
-def _arg_names(fn_signature, args, kwargs, static_argnums, static_argnames,
-               ) -> tuple[str, ...] | None:
-  if fn_signature is None: return None
+def _arg_names(fn_signature: inspect.Signature | None,
+               args, kwargs,
+               static_argnums: Sequence[int],
+               static_argnames: Sequence[str],
+               ) -> tuple[str | None, ...]:
+  """The names of the non-static args and argnames."""
   static = object()
   static_argnums_ = _ensure_inbounds(True, len(args), static_argnums)
   static_argnames_ = set(static_argnames)
   args_ = [static if i in static_argnums_ else x for i, x in enumerate(args)]
-  kwargs = {k:static if k in static_argnames_ else x for k, x in kwargs.items()}
+  kwargs = {k: static if k in static_argnames_ else x for k, x in kwargs.items()}
+  nr_non_static_args = (len(args) - len(static_argnums_) +
+                        len(kwargs) - len(static_argnames_))
+  if fn_signature is None:
+    return (None,) * nr_non_static_args
   try:
     ba = fn_signature.bind(*args_, **kwargs)
   except (ValueError, TypeError):
-    return None
+    return (None,) * nr_non_static_args
   return tuple(f'{name}{keystr(path)}' for name, x in ba.arguments.items()
               for path, l in generate_key_paths(x) if l is not static)
 
diff --git a/jax/_src/interpreters/partial_eval.py b/jax/_src/interpreters/partial_eval.py
index 43a771cd2b0a..575057858cd3 100644
--- a/jax/_src/interpreters/partial_eval.py
+++ b/jax/_src/interpreters/partial_eval.py
@@ -579,9 +579,9 @@ def trace_to_jaxpr_nounits(
 # TODO(mattjj): superfluous wrapper...?
 @lu.transformation2
 def trace_to_subjaxpr_nounits(
-    f,
+    f: Callable,
     trace: JaxprTrace,
-    instantiate: bool | Sequence[bool],
+    instantiate: Sequence[bool] | bool,
     in_pvals: Sequence[PartialVal]):
   assert all(isinstance(pv, PartialVal) for pv in in_pvals), in_pvals
   out_tracers, jaxpr, out_consts, env = _trace_to_subjaxpr_nounits(
@@ -607,7 +607,9 @@ def trace_to_subjaxpr_nounits2(
   del out_tracers
   return jaxpr, (out_pvals, out_consts, env)
 
-def _trace_to_subjaxpr_nounits(f, trace:JaxprTrace, instantiate, in_pvals):
+def _trace_to_subjaxpr_nounits(f: Callable, trace: JaxprTrace,
+                               instantiate: Sequence[bool] | bool,
+                               in_pvals: Sequence[PartialVal]):
   in_knowns  = [pval.is_known()     for pval in in_pvals]
   in_consts  = [pval.get_known()    for pval in in_pvals if pval.is_known()]
   in_tracers = [trace.new_arg(pval) for pval in in_pvals if not pval.is_known()]
@@ -1903,7 +1905,8 @@ def default_process_primitive(self, primitive, tracers, params):
     self.frame.add_eqn(eqn)
     return out_tracers if primitive.multiple_results else out_tracers.pop()
 
-  def process_call(self, call_primitive, f, explicit_tracers, params):
+  def process_call(self, call_primitive, f: lu.WrappedFun,
+                   explicit_tracers, params):
     if f.in_type is None:
       f = lu.annotate(f, tuple((get_aval(t), True) for t in explicit_tracers))
     implicit_tracers = _extract_implicit_args(self, f.in_type, explicit_tracers)
@@ -1915,7 +1918,7 @@ def process_call(self, call_primitive, f: lu.WrappedFun,
       return core.eval_jaxpr(jaxpr, consts, *in_tracers,
                              propagate_source_info=False)
     source_info = source_info_util.current()
-    out_tracers = []
+    out_tracers: list[Tracer] = []
     for aval, _ in out_type:
       if type(aval) is DShapedArray:
         shape = [[*consts, *in_tracers][d.val] if type(d) is InDBIdx else
@@ -2110,36 +2113,33 @@ def _jvp_jaxpr_zeros(f, store, in_zeros, zero_avals, *primal_tangent_avals):
 # Callers should be using linear_util.debug_info instead!
 def tracing_debug_info(
     fn: Callable,
-    in_tree: PyTreeDef | None,
-    out_tree_thunk: Callable[[], PyTreeDef] | None,
+    in_tree: PyTreeDef,
+    out_tree_thunk: Callable[[], PyTreeDef],
     has_kwargs: bool,
     traced_for: str
 ) -> lu.TracingDebugInfo:
+  # TODO(necula): we should not need this function, and can use api_util.tracing_debug_info instead
   src_info = fun_sourceinfo(fn)
-  arg_names: tuple[str | None, ...] | None
-  try:
-    dummy_args = tree_unflatten(in_tree, [False] * in_tree.num_leaves)  # type: ignore
-    args, kwargs = dummy_args if has_kwargs else (dummy_args, {})
-    ba = api_util.fun_signature(fn).bind(*args, **kwargs)  # type: ignore
-    arg_names = tuple(f'{name}{keystr(path)}' for name, dummy in ba.arguments.items()
-                      for path, _ in generate_key_paths(dummy))
-  except:
-    arg_names = None  # TODO(necula): we should not need this
-  def result_paths():
-    try:
-      out_tree = out_tree_thunk()
-      dummy_result = tree_unflatten(out_tree, [False] * out_tree.num_leaves)
-    except:
-      return None  # TODO(necula): this does not seem to be needed
-    return tuple(path for path, _ in generate_key_paths(dummy_result))
-  # TODO(necula): clean up the type: ignore below
-  return lu.TracingDebugInfo(traced_for, src_info, arg_names, result_paths)  # type: ignore[arg-type]
+  signature = api_util.fun_signature(fn)
+  dummy_args = tree_unflatten(in_tree, [False] * in_tree.num_leaves)  # type: ignore
+  args, kwargs = dummy_args if has_kwargs else (dummy_args, {})
+  arg_names = api_util._arg_names(signature, args, kwargs, (), ())
+  def result_paths() -> tuple[str, ...]:
+    out_tree = out_tree_thunk()
+    dummy_result = tree_unflatten(out_tree, [False] * out_tree.num_leaves)
+    return tuple(path for path, _ in generate_key_paths(dummy_result))  # type: ignore
+  return lu.TracingDebugInfo(traced_for, src_info, arg_names, result_paths)
+
 
 def tracing_debug_info_final(fn: lu.WrappedFun, traced_for: str) -> lu.TracingDebugInfo:
-  in_tree, out_tree, has_kws = flattened_fun_in_tree(fn) or (None, None, False)
+  fn_trees = flattened_fun_in_tree(fn)
+  if fn_trees is None:
+    # TODO(necula): eliminate this branch
+    return lu.TracingDebugInfo(traced_for, api_util.fun_sourceinfo(fn.f),
+                               (None,), None)
+  in_tree, out_tree, has_kws = fn_trees
   return tracing_debug_info(fn.f, in_tree, out_tree, has_kws, traced_for)
 
-
 @profiler.annotate_function
 def trace_to_jaxpr_dynamic(
     fun: lu.WrappedFun,
diff --git a/jax/_src/lax/control_flow/conditionals.py b/jax/_src/lax/control_flow/conditionals.py
index a9bcfd24d1c7..17116e857671 100644
--- a/jax/_src/lax/control_flow/conditionals.py
+++ b/jax/_src/lax/control_flow/conditionals.py
@@ -139,7 +139,7 @@ def switch(index, branches, *operands):
   ops_avals = tuple(map(core.get_aval, ops))
 
   if config.mutable_array_checks.value:
-    dbg = pe.tracing_debug_info(branches[0], ops_tree, None, False, 'switch')
+    dbg = pe.tracing_debug_info(branches[0], ops_tree, None, False, 'switch')  # type: ignore
     _check_no_aliased_ref_args(dbg, ops_avals, ops)
 
   jaxprs, consts, out_trees = _initial_style_jaxprs_with_common_consts(
@@ -238,7 +238,7 @@ def cond(pred, true_fun, false_fun, *operands):
   ops_avals = tuple(map(core.get_aval, ops))
 
   if config.mutable_array_checks.value:
-    dbg = pe.tracing_debug_info(true_fun, ops_tree, None, False, 'cond')
+    dbg = pe.tracing_debug_info(true_fun, ops_tree, None, False, 'cond')  # type: ignore
     _check_no_aliased_ref_args(dbg, ops_avals, ops)
   jaxprs, consts, out_trees = _initial_style_jaxprs_with_common_consts(
       (true_fun, false_fun), ops_tree, ops_avals, 'cond')
diff --git a/jax/_src/lax/control_flow/loops.py b/jax/_src/lax/control_flow/loops.py
index c9c4eaf8c25c..ab51297ade40 100644
--- a/jax/_src/lax/control_flow/loops.py
+++ b/jax/_src/lax/control_flow/loops.py
@@ -275,7 +275,7 @@ def scan(f, init, xs, length=None):
 
   if config.mutable_array_checks.value:
     in_flat, in_tree = tree_flatten((init, xs))
-    dbg = pe.tracing_debug_info(f, in_tree, None, False, 'scan')
+    dbg = pe.tracing_debug_info(f, in_tree, None, False, 'scan')  # type: ignore
     in_avals = tuple(_map(core.get_aval, in_flat))
     _check_no_aliased_ref_args(dbg, in_avals, in_flat)
 
diff --git a/jax/_src/pallas/pallas_call.py b/jax/_src/pallas/pallas_call.py
index d0fab8613312..41f73df46e64 100644
--- a/jax/_src/pallas/pallas_call.py
+++ b/jax/_src/pallas/pallas_call.py
@@ -1810,13 +1810,11 @@ def wrapped(*args):
     kernel_fun_sig = api_util.fun_signature(kernel)
     arg_names = None
     if kernel_fun_sig:
-      kernel_debug_info = api_util.debug_info(
+      kernel_debug_info = api_util.tracing_debug_info(
           "pallas_call kernel",
-          kernel_src_info,
-          kernel_fun_sig,
-          [1] * len(kernel_fun_sig.parameters), {}, (), ())
-      if kernel_debug_info:
-        arg_names = kernel_debug_info.arg_names
+          kernel,
+          [1] * len(kernel_fun_sig.parameters), {})
+      arg_names = kernel_debug_info.arg_names
       del kernel_debug_info
     in_origins = tuple(in_path_to_input_origin(p, arg_names)
                        for p in in_paths)
diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py
index ccd087107f6e..39b2520a6c69 100644
--- a/jax/_src/pjit.py
+++ b/jax/_src/pjit.py
@@ -49,7 +49,7 @@
 from jax._src.api_util import (
     argnums_partial_except, flatten_axes, flatten_fun, flatten_fun_nokwargs,
     donation_vector, check_callable, resolve_argnums,
-    argnames_partial_except, debug_info, result_paths, add_jaxpr_debug_info,
+    argnames_partial_except, tracing_debug_info, result_paths, add_jaxpr_debug_info,
     hoist_obj_attrs, _check_no_aliased_ref_args,
     _check_no_aliased_closed_over_refs)
 from jax._src.interpreters import partial_eval as pe
@@ -176,7 +176,7 @@ def __eq__(self, other):
     return self is other
 
 
-def _python_pjit_helper(fun, jit_info, *args, **kwargs):
+def _python_pjit_helper(fun: Callable, jit_info: PjitInfo, *args, **kwargs):
   p, args_flat = _infer_params(fun, jit_info, args, kwargs)
 
   for arg in args_flat:
@@ -566,9 +566,10 @@
 
   axes_specs = _flat_axes_specs(ji.abstracted_axes, *args, **kwargs)
 
-  dbg = debug_info('jit', ji.fun_sourceinfo, ji.fun_signature, args, kwargs,
-                   ji.static_argnums, ji.static_argnames)
-  f = lu.wrap_init(fun)
+  dbg = tracing_debug_info('jit', fun, args, kwargs,
+                           static_argnums=ji.static_argnums,
+                           static_argnames=ji.static_argnames)
+  f = lu.wrap_init(fun, debug_info=dbg)
   f, res_paths = result_paths(f)
   f, dyn_args = argnums_partial_except(f, ji.static_argnums, args, allow_invalid=True)
   del args
@@ -732,8 +733,9 @@ def _infer_params(
   signature, dynargs = jax_jit.parse_arguments(
       args, tuple(kwargs.values()), tuple(kwargs.keys()), ji.static_argnums,
       ji.static_argnames, tree_util.default_registry)
-  dbg = debug_info('jit', ji.fun_sourceinfo, ji.fun_signature, args, kwargs,
-                   ji.static_argnums, ji.static_argnames)
+  dbg = tracing_debug_info('jit', fun, args, kwargs,
+                           static_argnums=ji.static_argnums,
+                           static_argnames=ji.static_argnames)
   avals = _infer_input_type(fun, dbg, dynargs)
   entry = _infer_params_cached(fun, ji, signature, avals, pjit_mesh, resource_env)
   if entry.pjit_params is None:
@@ -744,7 +746,9 @@
     entry.pjit_params = p
   return entry.pjit_params, entry.pjit_params.consts + dynargs
 
-def _infer_input_type(fun, dbg, explicit_args) -> tuple[core.AbstractValue, ...]:
+def _infer_input_type(fun: Callable,
+                      dbg: lu.TracingDebugInfo | None,
+                      explicit_args) -> tuple[core.AbstractValue, ...]:
   avals = []
   try:
     for i, x in enumerate(explicit_args):
@@ -1301,7 +1305,7 @@ def _create_pjit_jaxpr(
   with dispatch.log_elapsed_time(
       "Finished tracing + transforming {fun_name} for pjit in {elapsed_time:.9f} sec",
       fun_name=fun.__name__, event=dispatch.JAXPR_TRACE_EVENT):
-    pe_debug = debug_info and pe.tracing_debug_info_final(fun, debug_info.traced_for)
+    pe_debug = debug_info or pe.tracing_debug_info_final(fun, debug_info.traced_for)
     if config.dynamic_shapes.value:
       jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_dynamic2(
           lu.annotate(fun, cast(core.InputType, in_type)), debug_info=pe_debug)
@@ -1672,7 +1676,7 @@ def _pjit_call_impl_python(
   if compiled._auto_spmd_lowering and config.enable_checks.value:
     pxla.check_array_xla_sharding_layout_match(
         args, compiled._in_shardings, compiled._in_layouts,
-        jaxpr.jaxpr.tracing_debug_info, compiled._kept_var_idx)
+        jaxpr.jaxpr._debug_info, compiled._kept_var_idx)
   if config.distributed_debug.value:
     # Defensively only perform fingerprint logic if debug logging is enabled
     # NOTE(skyewm): I didn't benchmark this
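
A minimal usage sketch of the consolidated helper introduced above, assuming the api_util.tracing_debug_info signature added to jax/_src/api_util.py in this patch. The loss function, its example arguments, and the names quoted in comments are illustrative assumptions rather than part of the change, and the snippet only runs against a JAX checkout that includes it.

# Sketch: build trace-time debug info directly from the callable, as the
# updated call sites in api.py, pjit.py and ad_checkpoint.py now do.
from jax._src import api_util
from jax._src import linear_util as lu


def loss(params, batch, *, scale=1.0):  # hypothetical example function
  return scale * sum(p * b for p, b in zip(params, batch))


# The helper recovers the source info and signature from `loss` itself, so
# callers no longer pass fun_sourceinfo/fun_signature as separate arguments.
dbg = api_util.tracing_debug_info(
    "example", loss, ((1.0, 2.0), (3.0, 4.0)), {"scale": 0.5},
    static_argnames=("scale",))

# arg_names covers only the non-static flattened arguments, e.g.
# ('params[0]', 'params[1]', 'batch[0]', 'batch[1]').
print(dbg.arg_names)

# The result can be attached when wrapping the callable for tracing, mirroring
# the lu.wrap_init(fun, debug_info=dbg) calls updated in this diff.
wrapped = lu.wrap_init(loss, debug_info=dbg)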