Skip to content

Commit

Permalink
[better_errors] Try to fix the arg_names is None issue
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 717523621
  • Loading branch information
gnecula authored and Google-ML-Automation committed Jan 20, 2025
1 parent ce48f64 commit 5f27211
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 10 deletions.
8 changes: 6 additions & 2 deletions jax/_src/api_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

from collections.abc import Callable, Iterable, Sequence
import inspect
import logging
import operator
from functools import partial, lru_cache
from typing import Any
Expand Down Expand Up @@ -661,15 +662,18 @@ def _non_static_arg_names(fn_signature: inspect.Signature | None,
static_argnums: Sequence[int],
static_argnames: Sequence[str],
) -> tuple[str | None, ...] | None:
if fn_signature is None: return None
if fn_signature is None:
logging.info("XXX fn_signature is None")
return None
static = object()
static_argnums_ = _ensure_inbounds(True, len(args), static_argnums)
static_argnames_ = set(static_argnames)
args_ = [static if i in static_argnums_ else x for i, x in enumerate(args)]
kwargs = {k:static if k in static_argnames_ else x for k, x in kwargs.items()}
try:
ba = fn_signature.bind(*args_, **kwargs)
except (ValueError, TypeError):
except (ValueError, TypeError) as e:
logging.info("XXX Failed to bind signature: %s", e)
return None
return tuple(f'{name}{keystr(path)}' for name, x in ba.arguments.items()
for path, l in generate_key_paths(x) if l is not static)
Expand Down
22 changes: 14 additions & 8 deletions jax/_src/pjit.py
Original file line number Diff line number Diff line change
Expand Up @@ -566,16 +566,22 @@ def _infer_params_impl(

axes_specs = _flat_axes_specs(ji.abstracted_axes, *args, **kwargs)

dbg = debug_info('jit', ji.fun_sourceinfo, ji.fun_signature, args, kwargs,
ji.static_argnums, ji.static_argnames)
dbg1 = debug_info('jit', ji.fun_sourceinfo, ji.fun_signature, args, kwargs,
ji.static_argnums, ji.static_argnames)
# TODO(necula): replace the above with below.
# haiku/_src/integration:hk_transforms_test fails
# dbg = tracing_debug_info('jit', fun, args, kwargs,
# static_argnums=ji.static_argnums,
# static_argnames=ji.static_argnames,
# TODO(necula): do we really need this, e.g., for tracing speed
# sourceinfo = ji.fun_sourceinfo,
# signature = ji.fun_signature)
dbg = tracing_debug_info('jit', fun, args, kwargs,
static_argnums=ji.static_argnums,
static_argnames=ji.static_argnames,
# TODO(necula): do we really need this, e.g., for tracing speed
sourceinfo = ji.fun_sourceinfo,
signature = ji.fun_signature)
assert (dbg is None) == (dbg1 is None), (dbg, dbg1)
if dbg is not None and dbg1 is not None:
assert (dbg.func_src_info == dbg1.func_src_info), (dbg, dbg1)
assert (dbg.traced_for == dbg1.traced_for), (dbg, dbg1)
assert (dbg.arg_names == dbg1.arg_names), (dbg, dbg1)

f = lu.wrap_init(fun)
f, res_paths = result_paths(f)
f, dyn_args = argnums_partial_except(f, ji.static_argnums, args, allow_invalid=True)
Expand Down

0 comments on commit 5f27211

Please sign in to comment.