diff --git a/jax/BUILD b/jax/BUILD index d35ff0e399a6..4260a8a1acb2 100644 --- a/jax/BUILD +++ b/jax/BUILD @@ -1198,5 +1198,6 @@ pytype_library( ":util", ":xla_bridge", "//jax/_src/lib", + "//jax/extend:ifrt_programs", ] + py_deps("numpy") + py_deps("cloudpickle"), ) diff --git a/jax/experimental/colocated_python/func.py b/jax/experimental/colocated_python/func.py index 3e95ddf03c7e..6639e7eefdd6 100644 --- a/jax/experimental/colocated_python/func.py +++ b/jax/experimental/colocated_python/func.py @@ -28,7 +28,8 @@ from jax._src.traceback_util import api_boundary from jax._src.util import wraps from jax.experimental.colocated_python import func_backend -from jax.experimental.colocated_python.serialization import _deserialize_specs, _make_specs_for_serialized_specs, _serialize_specs +from jax.experimental.colocated_python.serialization import _deserialize_specs, _make_specs_for_serialized_specs, _serialize, _serialize_specs +from jax.extend.ifrt_programs import ifrt_programs ShapeDtypeStructTree = Any # PyTree[api.ShapeDtypeStruct] @@ -141,8 +142,13 @@ def _compile_to_executable( devices: xc.DeviceList, ) -> Callable[..., Any]: """Compiles a Python function into a runtime executable.""" - # TODO(hyeontaek): Wrap fun as CustomCallProgram and compile it into an - # executable. + pickled_function = _serialize(fun) + program = ifrt_programs.make_colocated_python_program( + name, pickled_function, devices, in_specs_leaves, out_specs_leaves + ) + # TODO(hyeontaek): Compile the program and use the executable. + del program + del name del in_specs_leaves del out_specs_leaves diff --git a/tests/BUILD b/tests/BUILD index f0668b42b309..92a6ed99ceca 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -1387,6 +1387,7 @@ jax_multiplatform_test( srcs = ["colocated_python_test.py"], deps = [ "//jax:experimental_colocated_python", + "//jax/extend:ifrt_programs", ], ) diff --git a/tests/colocated_python_test.py b/tests/colocated_python_test.py index 9f65e3aeced4..f86a68a998f3 100644 --- a/tests/colocated_python_test.py +++ b/tests/colocated_python_test.py @@ -22,6 +22,8 @@ from jax._src.lib import xla_extension_version # pylint: disable=g-importing-member from jax.experimental import colocated_python from jax.experimental.colocated_python import func as colocated_python_func +from jax.experimental.colocated_python import serialization +from jax.extend.ifrt_programs import ifrt_programs import jax.numpy as jnp import numpy as np @@ -77,8 +79,22 @@ class ColocatedPythonTest(jtu.JaxTestCase): def setUp(self): super().setUp() - if xla_extension_version < 290: - self.skipTest("Requires xla_extension_version >= 290") + if xla_extension_version < 298: + self.skipTest("Requires xla_extension_version >= 298") + + def testMakeColocatedPythonProgram(self): + def add_one(x): + return x + 1 + + cpu_devices = _colocated_cpu_devices(jax.local_devices()) + sharding = jax.sharding.SingleDeviceSharding(cpu_devices[0]) + aval = jax.ShapeDtypeStruct((), jnp.int32, sharding=sharding) + + pickled_function = serialization._serialize(add_one) + program = ifrt_programs.make_colocated_python_program( + "add_one", pickled_function, [cpu_devices[0]], [aval], [aval] + ) + del program def testSimpleFunction(self): @colocated_python.colocated_python