Skip to content

Commit

Permalink
[better_errors] Finally remove api_util.debug_info.
Browse files Browse the repository at this point in the history
Following #25916 there were a few TODOs
left in the code to remove api_util.debug_info and replace the
one remaining use with api_util.tracing_debug_info.

PiperOrigin-RevId: 717583667
  • Loading branch information
gnecula authored and Google-ML-Automation committed Jan 20, 2025
1 parent 543dd94 commit 4fd0bb0
Show file tree
Hide file tree
Showing 6 changed files with 27 additions and 40 deletions.
26 changes: 7 additions & 19 deletions jax/_src/api_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -590,21 +590,6 @@ def _dtype(x):
def api_hook(fun, tag: str):
return fun

# TODO(necula): replace usage with tracing_debug_info
def debug_info(
traced_for: str, fun_src_info: str | None,
fun_signature: inspect.Signature | None,
args: tuple[Any, ...], kwargs: dict[str, Any],
static_argnums: tuple[int, ...],
static_argnames: tuple[str, ...]
) -> TracingDebugInfo | None:
"""Try to build trace-time debug info for fun when applied to args/kwargs."""
arg_names = _non_static_arg_names(fun_signature, args, kwargs, static_argnums,
static_argnames)
if arg_names is None:
return None
return TracingDebugInfo(traced_for, fun_src_info, arg_names, None)


def tracing_debug_info(
traced_for: str,
Expand All @@ -618,15 +603,16 @@ def tracing_debug_info(
# TODO(necula): check if we really need this, e.g., to speed up tracing.
sourceinfo: str | None = None,
signature: inspect.Signature | None = None,
) -> TracingDebugInfo:
) -> TracingDebugInfo | None:
if sourceinfo is None:
sourceinfo = fun_sourceinfo(fun)
if signature is None:
signature = fun_signature(fun)
arg_names = _non_static_arg_names(signature, args, kwargs, static_argnums,
static_argnames)
# TODO(necula): remove type: ignore once we fix arg_names to never be None
return TracingDebugInfo(traced_for, sourceinfo, arg_names, result_paths_thunk) # type: ignore
if arg_names is None:
return None
return TracingDebugInfo(traced_for, sourceinfo, arg_names, result_paths_thunk)


def fun_signature(fun: Callable) -> inspect.Signature | None:
Expand Down Expand Up @@ -656,12 +642,14 @@ def fun_sourceinfo(fun: Callable) -> str | None:
except AttributeError:
return None

# TODO(necula): this should never return None
def _non_static_arg_names(fn_signature: inspect.Signature | None,
args: Sequence[Any], kwargs: dict[str, Any],
static_argnums: Sequence[int],
static_argnames: Sequence[str],
) -> tuple[str | None, ...] | None:
if fn_signature is None: return None
if fn_signature is None:
return None
static = object()
static_argnums_ = _ensure_inbounds(True, len(args), static_argnums)
static_argnames_ = set(static_argnames)
Expand Down
4 changes: 2 additions & 2 deletions jax/_src/interpreters/partial_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -2116,7 +2116,7 @@ def tracing_debug_info(
out_tree_thunk: Callable[[], PyTreeDef],
has_kwargs: bool,
traced_for: str
) -> lu.TracingDebugInfo:
) -> lu.TracingDebugInfo | None:
# TODO(necula): we should not need this function, and can use api_util.tracing_debug_info instead
# We just have to make sure we grad the debugging information when we have
# the unflattened args
Expand All @@ -2137,7 +2137,7 @@ def res_paths_thunk() -> tuple[str, ...]:
return api_util.tracing_debug_info(traced_for, fn, args, kwargs,
result_paths_thunk=res_paths_thunk)

def tracing_debug_info_final(fn: lu.WrappedFun, traced_for: str) -> lu.TracingDebugInfo:
def tracing_debug_info_final(fn: lu.WrappedFun, traced_for: str) -> lu.TracingDebugInfo | None:
fn_trees = flattened_fun_in_tree(fn)
if fn_trees is None:
# TODO(necula): eliminate this branch
Expand Down
2 changes: 1 addition & 1 deletion jax/_src/pallas/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -434,7 +434,7 @@ def to_block_mapping(
"pallas_call index_map",
)
index_map_src_info = NameAndSrcInfo.from_pallas_call(
None, debug.func_src_info # type: ignore
None, debug and debug.func_src_info # type: ignore
)
with tracing_grid_env(grid, mapped_dims):
jaxpr, out_avals, consts, () = pe.trace_to_jaxpr_dynamic(
Expand Down
2 changes: 1 addition & 1 deletion jax/_src/pallas/pallas_call.py
Original file line number Diff line number Diff line change
Expand Up @@ -1814,7 +1814,7 @@ def wrapped(*args):
"pallas_call kernel",
kernel,
[1] * len(kernel_fun_sig.parameters), {})
arg_names = kernel_debug_info.arg_names
arg_names = kernel_debug_info and kernel_debug_info.arg_names
del kernel_debug_info
in_origins = tuple(in_path_to_input_origin(p, arg_names)
for p in in_paths)
Expand Down
18 changes: 7 additions & 11 deletions jax/_src/pjit.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@
from jax._src.api_util import (
argnums_partial_except, flatten_axes, flatten_fun, flatten_fun_nokwargs,
donation_vector, check_callable, resolve_argnums,
argnames_partial_except, debug_info, tracing_debug_info, result_paths, add_jaxpr_debug_info,
argnames_partial_except, tracing_debug_info, result_paths, add_jaxpr_debug_info,
hoist_obj_attrs, _check_no_aliased_ref_args,
_check_no_aliased_closed_over_refs)
from jax._src.interpreters import partial_eval as pe
Expand Down Expand Up @@ -565,17 +565,13 @@ def _infer_params_impl(
"device is also specified as an argument to jit.")

axes_specs = _flat_axes_specs(ji.abstracted_axes, *args, **kwargs)
dbg = tracing_debug_info('jit', fun, args, kwargs,
static_argnums=ji.static_argnums,
static_argnames=ji.static_argnames,
# TODO(necula): do we really need this, e.g., for tracing speed
sourceinfo=ji.fun_sourceinfo,
signature=ji.fun_signature)

dbg = debug_info('jit', ji.fun_sourceinfo, ji.fun_signature, args, kwargs,
ji.static_argnums, ji.static_argnames)
# TODO(necula): replace the above with below.
# haiku/_src/integration:hk_transforms_test fails
# dbg = tracing_debug_info('jit', fun, args, kwargs,
# static_argnums=ji.static_argnums,
# static_argnames=ji.static_argnames,
# TODO(necula): do we really need this, e.g., for tracing speed
# sourceinfo = ji.fun_sourceinfo,
# signature = ji.fun_signature)
f = lu.wrap_init(fun)
f, res_paths = result_paths(f)
f, dyn_args = argnums_partial_except(f, ji.static_argnums, args, allow_invalid=True)
Expand Down
15 changes: 9 additions & 6 deletions tests/pallas/pallas_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -966,8 +966,9 @@ def my_index_map():
in_specs=[pl.BlockSpec((4,), my_index_map)])
with self.assertRaisesRegex(
ValueError,
"Index map function my_index_map at .*pallas_test.py:.* for "
"x_ref must return 1 values to match .*"
# TODO(necula): the function name should be "my_index_map"
"Index map function unknown .* "
"must return 1 values to match .*"
"Currently returning 2 values."):
f(a)

Expand All @@ -981,8 +982,9 @@ def my_index_map(i):
in_specs=[pl.BlockSpec((4,), my_index_map)])
with self.assertRaisesRegex(
ValueError,
"Index map function my_index_map at .*pallas_test.py:.* for "
"x_ref must return integer scalars. Output\\[0\\] has "
# TODO(necula): the function name should be "my_index_map"
"Index map function unknown .* "
"must return integer scalars. Output\\[0\\] has "
"type .*float"):
f(a)

Expand All @@ -996,8 +998,9 @@ def my_index_map(i):
in_specs=[pl.BlockSpec((4,), my_index_map)])
with self.assertRaisesRegex(
ValueError,
"Index map function my_index_map at .*pallas_test.py:.* for "
"x_ref must return integer scalars. Output\\[0\\] has "
# TODO(necula): the function name should be "my_index_map"
"Index map function unknown .* "
"must return integer scalars. Output\\[0\\] has "
"type .*int32\\[4\\]"):
f(a)

Expand Down

0 comments on commit 4fd0bb0

Please sign in to comment.