Skip to content

Commit

Permalink
[JAX] Add Python binding for building a colocated Python program
Browse files Browse the repository at this point in the history
This change adds a Python binding that makes `ifrt::CustomCallProgram` for a
colocated Python program. This Python binding will be used internally in the
colocated Python API implementation. The API does not yet compile the program
into an executable, which will be added separately.

PiperOrigin-RevId: 700443656
  • Loading branch information
hyeontaek authored and Google-ML-Automation committed Nov 26, 2024
1 parent 6763fcf commit bbaec6e
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 5 deletions.
1 change: 1 addition & 0 deletions jax/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -1198,5 +1198,6 @@ pytype_library(
":util",
":xla_bridge",
"//jax/_src/lib",
"//jax/extend:ifrt_programs",
] + py_deps("numpy") + py_deps("cloudpickle"),
)
12 changes: 9 additions & 3 deletions jax/experimental/colocated_python/func.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@
from jax._src.traceback_util import api_boundary
from jax._src.util import wraps
from jax.experimental.colocated_python import func_backend
from jax.experimental.colocated_python.serialization import _deserialize_specs, _make_specs_for_serialized_specs, _serialize_specs
from jax.experimental.colocated_python.serialization import _deserialize_specs, _make_specs_for_serialized_specs, _serialize, _serialize_specs
from jax.extend.ifrt_programs import ifrt_programs

ShapeDtypeStructTree = Any # PyTree[api.ShapeDtypeStruct]

Expand Down Expand Up @@ -141,8 +142,13 @@ def _compile_to_executable(
devices: xc.DeviceList,
) -> Callable[..., Any]:
"""Compiles a Python function into a runtime executable."""
# TODO(hyeontaek): Wrap fun as CustomCallProgram and compile it into an
# executable.
pickled_function = _serialize(fun)
program = ifrt_programs.make_colocated_python_program(
name, pickled_function, devices, in_specs_leaves, out_specs_leaves
)
# TODO(hyeontaek): Compile the program and use the executable.
del program

del name
del in_specs_leaves
del out_specs_leaves
Expand Down
1 change: 1 addition & 0 deletions tests/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -1387,6 +1387,7 @@ jax_multiplatform_test(
srcs = ["colocated_python_test.py"],
deps = [
"//jax:experimental_colocated_python",
"//jax/extend:ifrt_programs",
],
)

Expand Down
20 changes: 18 additions & 2 deletions tests/colocated_python_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
from jax._src.lib import xla_extension_version # pylint: disable=g-importing-member
from jax.experimental import colocated_python
from jax.experimental.colocated_python import func as colocated_python_func
from jax.experimental.colocated_python import serialization
from jax.extend.ifrt_programs import ifrt_programs
import jax.numpy as jnp
import numpy as np

Expand Down Expand Up @@ -77,8 +79,22 @@ class ColocatedPythonTest(jtu.JaxTestCase):

def setUp(self):
super().setUp()
if xla_extension_version < 290:
self.skipTest("Requires xla_extension_version >= 290")
if xla_extension_version < 298:
self.skipTest("Requires xla_extension_version >= 298")

def testMakeColocatedPythonProgram(self):
def add_one(x):
return x + 1

cpu_devices = _colocated_cpu_devices(jax.local_devices())
sharding = jax.sharding.SingleDeviceSharding(cpu_devices[0])
aval = jax.ShapeDtypeStruct((), jnp.int32, sharding=sharding)

pickled_function = serialization._serialize(add_one)
program = ifrt_programs.make_colocated_python_program(
"add_one", pickled_function, [cpu_devices[0]], [aval], [aval]
)
del program

def testSimpleFunction(self):
@colocated_python.colocated_python
Expand Down

0 comments on commit bbaec6e

Please sign in to comment.