When I use jax.jit with mpi4py, compilation is ~1.5x slower, which suggests that each process compiles the code independently even when all processes are on the same device. Is there a way for all processes to share the compiled function, so that compilation under MPI is no slower? Thanks!
import time
import numpy as np
import jax
jax.config.update('jax_platform_name', 'cpu')
print(jax.devices())
import jax.numpy as jnp
from mpi4py import MPI
@jax.jit
def func(M):
    for _ in range(1000):
        M = M.dot(jnp.eye(len(M)))
    return M
M = np.random.random((20, 20))
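# First call includes tracing and compilation; block until the result is ready.
t1 = time.time()
func(M).block_until_ready()
t2 = time.time()
# Second call reuses the already compiled function.
func(M).block_until_ready()
t3 = time.time()
print(t2 - t1, t3 - t2)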
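The timing above mixes tracing, compilation, and execution in the first call. A minimal sketch of how the compile step might be timed on its own, using JAX's ahead-of-time `lower(...).compile()` path, is shown here. The names `repeated_matmul` and `time_compile_and_run` are hypothetical helpers introduced only for illustration, and the loop is shortened to 100 iterations to keep the sketch quick. It does not use MPI itself; the assumption is that running it under each MPI rank would show whether the compile step is the part that slows down.

```python
import time

import jax
import jax.numpy as jnp
import numpy as np


def repeated_matmul(M):
    # Same pattern as func above, with fewer iterations (assumption: 100
    # instead of 1000, only to keep this sketch fast).
    for _ in range(100):
        M = M.dot(jnp.eye(len(M)))
    return M


def time_compile_and_run(fn, x):
    """Return (compile_seconds, run_seconds, result) for fn applied to x."""
    t0 = time.perf_counter()
    lowered = jax.jit(fn).lower(x)  # trace and lower for this input shape/dtype
    compiled = lowered.compile()    # XLA compilation happens here
    t1 = time.perf_counter()
    result = compiled(x).block_until_ready()  # execution only, no compilation
    t2 = time.perf_counter()
    return t1 - t0, t2 - t1, result


M = np.random.random((20, 20))
x = jnp.asarray(M)  # convert once so lowering and the call see the same type
compile_s, run_s, out = time_compile_and_run(repeated_matmul, x)
print(f"compile: {compile_s:.3f}s, run: {run_s:.6f}s")
```

If the compile time reported per process grows with the number of ranks while the run time stays flat, that would be consistent with every process compiling independently.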
The output is
System info (python version, jaxlib version, accelerator, etc.)