Skip to content

Commit

Permalink
[better_errors] Finally remove api_util.debug_info.
Browse files Browse the repository at this point in the history
Following #25916 there were a few TODOs
left in the code to remove api_util.debug_info and replace the
one remaining use with api_util.tracing_debug_info.

PiperOrigin-RevId: 717523621
  • Loading branch information
gnecula authored and Google-ML-Automation committed Jan 20, 2025
1 parent ce48f64 commit a8e29ef
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 30 deletions.
26 changes: 7 additions & 19 deletions jax/_src/api_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -590,21 +590,6 @@ def _dtype(x):
def api_hook(fun, tag: str):
return fun

# TODO(necula): replace usage with tracing_debug_info
def debug_info(
traced_for: str, fun_src_info: str | None,
fun_signature: inspect.Signature | None,
args: tuple[Any, ...], kwargs: dict[str, Any],
static_argnums: tuple[int, ...],
static_argnames: tuple[str, ...]
) -> TracingDebugInfo | None:
"""Try to build trace-time debug info for fun when applied to args/kwargs."""
arg_names = _non_static_arg_names(fun_signature, args, kwargs, static_argnums,
static_argnames)
if arg_names is None:
return None
return TracingDebugInfo(traced_for, fun_src_info, arg_names, None)


def tracing_debug_info(
traced_for: str,
Expand All @@ -618,15 +603,16 @@ def tracing_debug_info(
# TODO(necula): check if we really need this, e.g., to speed up tracing.
sourceinfo: str | None = None,
signature: inspect.Signature | None = None,
) -> TracingDebugInfo:
) -> TracingDebugInfo | None:
if sourceinfo is None:
sourceinfo = fun_sourceinfo(fun)
if signature is None:
signature = fun_signature(fun)
arg_names = _non_static_arg_names(signature, args, kwargs, static_argnums,
static_argnames)
# TODO(necula): remove type: ignore once we fix arg_names to never be None
return TracingDebugInfo(traced_for, sourceinfo, arg_names, result_paths_thunk) # type: ignore
if arg_names is None:
return None
return TracingDebugInfo(traced_for, sourceinfo, arg_names, result_paths_thunk)


def fun_signature(fun: Callable) -> inspect.Signature | None:
Expand Down Expand Up @@ -656,12 +642,14 @@ def fun_sourceinfo(fun: Callable) -> str | None:
except AttributeError:
return None

# TODO(necula): this should never return None
def _non_static_arg_names(fn_signature: inspect.Signature | None,
args: Sequence[Any], kwargs: dict[str, Any],
static_argnums: Sequence[int],
static_argnames: Sequence[str],
) -> tuple[str | None, ...] | None:
if fn_signature is None: return None
if fn_signature is None:
return None
static = object()
static_argnums_ = _ensure_inbounds(True, len(args), static_argnums)
static_argnames_ = set(static_argnames)
Expand Down
18 changes: 7 additions & 11 deletions jax/_src/pjit.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@
from jax._src.api_util import (
argnums_partial_except, flatten_axes, flatten_fun, flatten_fun_nokwargs,
donation_vector, check_callable, resolve_argnums,
argnames_partial_except, debug_info, tracing_debug_info, result_paths, add_jaxpr_debug_info,
argnames_partial_except, tracing_debug_info, result_paths, add_jaxpr_debug_info,
hoist_obj_attrs, _check_no_aliased_ref_args,
_check_no_aliased_closed_over_refs)
from jax._src.interpreters import partial_eval as pe
Expand Down Expand Up @@ -565,17 +565,13 @@ def _infer_params_impl(
"device is also specified as an argument to jit.")

axes_specs = _flat_axes_specs(ji.abstracted_axes, *args, **kwargs)
dbg = tracing_debug_info('jit', fun, args, kwargs,
static_argnums=ji.static_argnums,
static_argnames=ji.static_argnames,
# TODO(necula): do we really need this, e.g., for tracing speed
sourceinfo=ji.fun_sourceinfo,
signature=ji.fun_signature)

dbg = debug_info('jit', ji.fun_sourceinfo, ji.fun_signature, args, kwargs,
ji.static_argnums, ji.static_argnames)
# TODO(necula): replace the above with below.
# haiku/_src/integration:hk_transforms_test fails
# dbg = tracing_debug_info('jit', fun, args, kwargs,
# static_argnums=ji.static_argnums,
# static_argnames=ji.static_argnames,
# TODO(necula): do we really need this, e.g., for tracing speed
# sourceinfo = ji.fun_sourceinfo,
# signature = ji.fun_signature)
f = lu.wrap_init(fun)
f, res_paths = result_paths(f)
f, dyn_args = argnums_partial_except(f, ji.static_argnums, args, allow_invalid=True)
Expand Down

0 comments on commit a8e29ef

Please sign in to comment.