-
Notifications
You must be signed in to change notification settings - Fork 2.9k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add an experimental interface for customizing DCE behavior.
We use dead code elimination (DCE) throughout JAX core to remove unused computations from Jaxprs. This typically works transparently when we're just using `lax` primitives, but opaque calls to `pallas_call` or `ffi_call` can't be cleaned up this way. For many kernels however, the author will know how to generate a more efficient call for specific patterns of used outputs, so it is useful to provide a mechanism for customizing this behavior. In #22735, I attempted to automatically tackle one specific example of this that comes up frequently, but there have been feature requests for a more general API. This version is bare bones and probably rough around the edges, but it could be a useful starting point for iteration. PiperOrigin-RevId: 716596154
- Loading branch information
1 parent
a4a657b
commit e9c5244
Showing
5 changed files
with
470 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,300 @@ | ||
# Copyright 2025 The JAX Authors. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from collections.abc import Callable, Sequence | ||
import functools | ||
from typing import Any | ||
|
||
from jax._src import api_util | ||
from jax._src import core | ||
from jax._src import custom_api_util | ||
from jax._src import linear_util as lu | ||
from jax._src import source_info_util | ||
from jax._src import traceback_util | ||
from jax._src import tree_util | ||
from jax._src import util | ||
from jax._src.interpreters import ad | ||
from jax._src.interpreters import batching | ||
from jax._src.interpreters import mlir | ||
from jax._src.interpreters import partial_eval as pe | ||
|
||
source_info_util.register_exclusion(__file__) | ||
traceback_util.register_exclusion(__file__) | ||
|
||
map, unsafe_map = util.safe_map, map | ||
zip, unsafe_zip = util.safe_zip, zip | ||
|
||
|
||
@custom_api_util.register_custom_decorator_type | ||
class custom_dce: | ||
"""Customize the DCE behavior of a JAX-transformable function. | ||
JAX uses dead code elimination (DCE) to remove unused computations from a | ||
JAX program. This typically works transparently when the program is | ||
completely specified by known JAX operations, but opaque kernels like calls | ||
to :py:func:`~jax.experimental.pallas.pallas_call` or | ||
:py:func:`~jax.ffi.ffi_call`, for example, may cause problems. | ||
This decorator allows users to customize the DCE behavior of a function by | ||
defining a custom DCE rule. This custom rule is then invoked during DCE with | ||
a Pytree of ``bool``s indicating which outputs should be computed. | ||
For example:: | ||
@jax.experimental.custom_dce.custom_dce | ||
@jax.custom_vjp | ||
def f(x, y): | ||
return jnp.sin(x) * y, x * jnp.sin(y) | ||
@f.def_dce | ||
def f_dce_rule(used_outs, x, y): | ||
outs = [] | ||
if used_outs[0]: | ||
outs.append(jnp.sin(x) * y) | ||
if used_outs[1]: | ||
outs.append(x * jnp.sin(y)) | ||
return tuple(outs), (True, True) | ||
In this example, ``used_outs`` is a ``tuple`` with two elements indicating | ||
which outputs are required. Then, the DCE rule returns only the required | ||
outputs, and another Pytree of ``bool``s indicating which inputs were used. | ||
""" | ||
|
||
fun: Callable[..., Any] | ||
dce_rule: Callable[..., Any] | None | ||
|
||
def __init__(self, fun: Callable[..., Any]): | ||
functools.update_wrapper(self, fun) | ||
self.fun = fun | ||
self.dce_rule = None | ||
|
||
__getattr__ = custom_api_util.forward_attr | ||
|
||
def def_dce( | ||
self, | ||
dce_rule: Callable[..., Any], | ||
) -> Callable[..., Any]: | ||
"""Define a custom DCE rule for this function. | ||
Args: | ||
dce_rule: A function that takes a Pytree of ``bool``s indicating which | ||
outputs should be computed as the first argument, and then the original | ||
function's arguments. This rule must return a pair of outputs where the | ||
the first element is a Pytree including only the required outputs, and | ||
the second element is a Pytree of ``bool``s indicating which inputs were | ||
used. | ||
""" | ||
self.dce_rule = dce_rule | ||
return dce_rule | ||
|
||
@traceback_util.api_boundary | ||
def __call__(self, *args, **kwargs): | ||
args = api_util.resolve_kwargs(self.fun, args, kwargs) | ||
fun_name = util.fun_name(self.fun) | ||
if self.dce_rule is None: | ||
raise AttributeError( | ||
f"No DCE rule defined for custom_dce function {fun_name} using " | ||
"def_dce." | ||
) | ||
rule_name = util.fun_name(self.dce_rule) | ||
args_flat, in_tree = tree_util.tree_flatten(args) | ||
flat_fun, out_tree = api_util.flatten_fun_nokwargs( | ||
lu.wrap_init(self.fun), in_tree | ||
) | ||
in_avals = [core.get_aval(x) for x in args_flat] | ||
|
||
@pe._memoize | ||
def dce_jaxpr_thunk(*used_outs: bool): | ||
flat_rule, aux = flatten_dce_rule( | ||
lu.wrap_init(self.dce_rule), | ||
fun_name, | ||
rule_name, | ||
used_outs, | ||
in_tree, | ||
out_tree(), | ||
) | ||
debug = pe.tracing_debug_info( | ||
self.dce_rule, in_tree, lambda: aux()[0], False, "custom_dce_rule" | ||
) | ||
dce_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic( | ||
flat_rule, in_avals, debug | ||
) | ||
assert not consts | ||
_, used_ins = aux() | ||
|
||
# TODO(danfm): Update debug info. | ||
invars = [v for used, v in zip(used_ins, dce_jaxpr.invars) if used] | ||
dce_jaxpr = dce_jaxpr.replace(invars=invars) | ||
|
||
return pe.close_jaxpr(dce_jaxpr), used_ins | ||
|
||
debug = pe.tracing_debug_info( | ||
self.fun, in_tree, out_tree, False, "custom_dce" | ||
) | ||
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals, debug) | ||
assert not consts | ||
closed_call = pe.close_jaxpr(jaxpr) | ||
out_flat = custom_dce_p.bind( | ||
*args_flat, fun_jaxpr=closed_call, dce_jaxpr_thunk=dce_jaxpr_thunk | ||
) | ||
return tree_util.tree_unflatten(out_tree(), out_flat) | ||
|
||
|
||
@lu.transformation_with_aux2 | ||
def flatten_dce_rule( | ||
f, store, primal_name, rule_name, used_outs, in_tree, out_tree, *args_flat | ||
): | ||
py_used_outs = tree_util.tree_unflatten(out_tree, used_outs) | ||
py_args = tree_util.tree_unflatten(in_tree, args_flat) | ||
py_out = f(py_used_outs, *py_args) | ||
if not isinstance(py_out, (tuple, list)) or len(py_out) != 2: | ||
raise TypeError( | ||
f"Custom DCE rule {rule_name} for function {primal_name} must produce " | ||
"a pair (list or tuple of length 2) where the first element only " | ||
"includes the requested outputs, and the second element indicates " | ||
f"which inputs were used. Instead, {rule_name} returned {py_out}." | ||
) | ||
py_out, used_ins = py_out | ||
out_flat, rule_out_tree = tree_util.tree_flatten(py_out) | ||
if len(out_flat) != sum(used_outs): | ||
raise TypeError( | ||
f"Custom DCE rule {rule_name} for function {primal_name} must produce " | ||
"a pair (list or tuple of length 2) where the first element only " | ||
f"includes the requested outputs. {rule_name} returned {py_out} with " | ||
f"{len(out_flat)} leaves, but {sum(used_outs)} were expected." | ||
) | ||
|
||
if isinstance(used_ins, list): | ||
used_ins = tuple(used_ins) | ||
used_ins_flat, used_ins_tree = tree_util.tree_flatten(used_ins) | ||
if used_ins_tree != in_tree: | ||
raise TypeError( | ||
f"Custom DCE rule {rule_name} for function {primal_name} must produce " | ||
"a pair (list or tuple of length 2) where the second element has the " | ||
"same container (pytree) structure as the function's input. " | ||
f"{rule_name} returned {used_ins_tree}, but {in_tree} was expected." | ||
) | ||
used_ins_flat = map(bool, used_ins_flat) | ||
|
||
store.store((rule_out_tree, used_ins_flat)) | ||
return out_flat | ||
|
||
|
||
def custom_dce_impl(*args, fun_jaxpr, **_): | ||
return core.jaxpr_as_fun(fun_jaxpr)(*args) | ||
|
||
|
||
def custom_dce_abstract_eval(*args, fun_jaxpr, **_): | ||
del args # unused | ||
return fun_jaxpr.out_avals, fun_jaxpr.effects | ||
|
||
|
||
def custom_dce_batching(axis_data, args, dims, *, fun_jaxpr, dce_jaxpr_thunk): | ||
in_batched = [d is not batching.not_mapped for d in dims] | ||
args = [ | ||
batching.moveaxis(x, d, 0) if b else x | ||
for b, x, d in zip(in_batched, args, dims) | ||
] | ||
batched_fun_jaxpr, out_batched = batching.batch_jaxpr( | ||
fun_jaxpr, axis_data, in_batched, True | ||
) | ||
|
||
@pe._memoize | ||
def batched_dce_jaxpr_thunk(*used_outs: bool): | ||
dce_jaxpr, used_ins = dce_jaxpr_thunk(*used_outs) | ||
dce_jaxpr_batched, _ = batching.batch_jaxpr( | ||
dce_jaxpr, | ||
axis_data, | ||
[b for used, b in zip(used_ins, in_batched) if used], | ||
True, | ||
) | ||
return dce_jaxpr_batched, used_ins | ||
|
||
out_flat = custom_dce_p.bind( | ||
*args, | ||
fun_jaxpr=batched_fun_jaxpr, | ||
dce_jaxpr_thunk=batched_dce_jaxpr_thunk, | ||
) | ||
out_dims = [0 if b else batching.not_mapped for b in out_batched] | ||
return out_flat, out_dims | ||
|
||
|
||
def custom_dce_jvp(primals, tangents, *, fun_jaxpr, **_): | ||
in_nz = [not isinstance(t, ad.Zero) for t in tangents] | ||
tangents = [t for nz, t in zip(in_nz, tangents) if nz] | ||
jvp_jaxpr, out_nz = ad.jvp_jaxpr(fun_jaxpr, in_nz, False) | ||
|
||
# TODO(danfm): We should avoid losing the DCE rule here, but it is more | ||
# straightforward to implement it like this to start. Instead, we should | ||
# bind a custom_dce primitive. To support that, we would need to add a | ||
# partial eval rule, and maybe a transpose rule. | ||
out = core.call_p.bind( | ||
lu.wrap_init(core.jaxpr_as_fun(jvp_jaxpr)), *primals, *tangents | ||
) | ||
|
||
out_primals, out_tangents = util.split_list(out, [len(out_nz)]) | ||
out_tangents_iter = iter(out_tangents) | ||
out_tangents = [ | ||
next(out_tangents_iter) if nz else ad.Zero.from_primal_value(p) | ||
for p, nz in zip(out_primals, out_nz) | ||
] | ||
return out_primals, out_tangents | ||
|
||
|
||
def custom_dce_rule(used_outs: Sequence[bool], eqn: core.JaxprEqn): | ||
if not any(used_outs) and not pe.has_effects(eqn): | ||
return [False] * len(eqn.invars), None | ||
|
||
dce_jaxpr_thunk = eqn.params["dce_jaxpr_thunk"] | ||
jaxpr, used_ins = dce_jaxpr_thunk(*used_outs) | ||
invars = [v for used, v in zip(used_ins, eqn.invars) if used] | ||
outvars = [v for used, v in zip(used_outs, eqn.outvars) if used] | ||
|
||
@pe._memoize | ||
def new_dce_jaxpr_thunk(*new_used_outs: bool): | ||
all_used_outs = util.merge_lists( | ||
used_outs, | ||
[False] * (len(used_outs) - len(new_used_outs)), | ||
new_used_outs, | ||
) | ||
new_jaxpr, all_used_ins = dce_jaxpr_thunk(*all_used_outs) | ||
not_used, new_used_ins = util.partition_list(used_ins, all_used_ins) | ||
assert not any(not_used) | ||
return new_jaxpr, new_used_ins | ||
|
||
new_params = dict(eqn.params) | ||
new_params["dce_jaxpr_thunk"] = new_dce_jaxpr_thunk | ||
new_params["fun_jaxpr"] = jaxpr | ||
new_eqn = pe.new_jaxpr_eqn( | ||
invars, | ||
outvars, | ||
custom_dce_p, | ||
new_params, | ||
jaxpr.effects, | ||
eqn.source_info, | ||
eqn.ctx, | ||
) | ||
return used_ins, new_eqn | ||
|
||
|
||
custom_dce_p = core.Primitive("custom_dce_call") | ||
custom_dce_p.multiple_results = True | ||
custom_dce_p.def_impl(custom_dce_impl) | ||
custom_dce_p.def_effectful_abstract_eval(custom_dce_abstract_eval) | ||
mlir.register_lowering( | ||
custom_dce_p, mlir.lower_fun(custom_dce_impl, multiple_results=True) | ||
) | ||
batching.fancy_primitive_batchers[custom_dce_p] = custom_dce_batching | ||
ad.primitive_jvps[custom_dce_p] = custom_dce_jvp | ||
pe.dce_rules[custom_dce_p] = custom_dce_rule |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
# Copyright 2025 The JAX Authors. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from jax._src.custom_dce import ( | ||
custom_dce as custom_dce, | ||
custom_dce_p as custom_dce_p, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.