jax compilation is slower with MPI #25848

Open
SUSYUSTC opened this issue Jan 12, 2025 · 1 comment
Labels: bug (Something isn't working)

@SUSYUSTC

Description

When I use jax.jit together with mpi4py, the first (compiling) call is roughly 1.4-1.5x slower: about 1.76-1.91 s per process under MPI versus 1.27 s in a single process. This suggests that each process compiles the code independently, even though all processes run on the same device. Is there a way for all processes to share the compiled function, so that compilation under MPI is no slower than in a single process? Thanks!

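```python
import time
import numpy as np
import jax
jax.config.update('jax_platform_name', 'cpu')  # force the CPU backend
print(jax.devices())
import jax.numpy as jnp
from mpi4py import MPI  # importing mpi4py initializes MPI in every rank


@jax.jit
def func(M):
    # The Python loop is unrolled at trace time into 1000 chained matmuls,
    # so the first call is dominated by tracing and compilation.
    for _ in range(1000):
        M = M.dot(jnp.eye(len(M)))
    return M


M = np.random.random((20, 20))
t1 = time.time()
func(M)  # first call: trace + compile + run
t2 = time.time()
func(M)  # second call: reuses the cached executable
t3 = time.time()
print(t2 - t1, t3 - t2)
```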

The output is:

```
(py311) jiace@:~$ python test_jit_mpi.py
[CpuDevice(id=0)]
1.2670481204986572 3.0994415283203125e-05
(py311) jiace@:~$ mpiexec -n 8 python test_jit_mpi.py
[CpuDevice(id=0)]
[CpuDevice(id=0)]
[CpuDevice(id=0)]
[CpuDevice(id=0)]
[CpuDevice(id=0)]
[CpuDevice(id=0)]
[CpuDevice(id=0)]
[CpuDevice(id=0)]
1.7631428241729736 4.172325134277344e-05
1.809206247329712 4.7206878662109375e-05
1.8106226921081543 4.5299530029296875e-05
1.8143680095672607 4.744529724121094e-05
1.824890375137329 4.3392181396484375e-05
1.8354761600494385 0.00010800361633300781
1.8431165218353271 6.937980651855469e-05
1.9080822467803955 5.269050598144531e-05
```
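To see how much of the first-call time is actually compilation, the steps can be timed separately with JAX's ahead-of-time API (`jax.jit(...).lower(...).compile()`). The following is a minimal sketch, assuming the same `func` and input shape as above. It omits MPI, so it can be run once on its own and once under `mpiexec` for comparison, and it calls `block_until_ready()` so that asynchronous dispatch does not hide the execution time.

```python
import time

import jax
import jax.numpy as jnp
import numpy as np

jax.config.update("jax_platforms", "cpu")  # match the report, which forces the CPU backend


def func(M):
    # Same body as in the report: the Python loop is unrolled into 1000 matmuls at trace time.
    for _ in range(1000):
        M = M.dot(jnp.eye(len(M)))
    return M


M = np.random.random((20, 20))
jitted = jax.jit(func)

t0 = time.perf_counter()
lowered = jitted.lower(M)     # trace func and lower it to StableHLO
t1 = time.perf_counter()
compiled = lowered.compile()  # the XLA compilation step
t2 = time.perf_counter()
out = compiled(M).block_until_ready()  # wait for the result, not just the async dispatch
t3 = time.perf_counter()

trace_s, compile_s, run_s = t1 - t0, t2 - t1, t3 - t2
print(f"trace: {trace_s:.3f} s  compile: {compile_s:.3f} s  run: {run_s:.6f} s")
```

If most of the extra time under `mpiexec` appears in the compile column, the slowdown comes from the per-process compilation step rather than from execution.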

System info (python version, jaxlib version, accelerator, etc.)

```
jax:    0.4.38
jaxlib: 0.4.38
numpy:  1.26.1
python: 3.11.5 | packaged by conda-forge | (main, Aug 27 2023, 03:34:09) [GCC 12.3.0]
device info: NVIDIA GeForce RTX 2060-1, 1 local devices
process_count: 1
platform: uname_result(system='Linux', node='jiace-XPS-8930', release='5.4.0-150-generic', version='#167~18.04.1-Ubuntu SMP Wed May 24 00:51:42 UTC 2023', machine='x86_64')
```
@SUSYUSTC added the bug (Something isn't working) label on Jan 12, 2025
@ASKabalan

Hello

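This happens because JAX is not aware that it is running as part of a distributed (multi-process) job.

You need to initialize JAX's distributed runtime with jax.distributed.initialize() at the start of the program, before any JAX computation: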

```python
import time
import numpy as np
import jax
jax.config.update('jax_platform_name', 'cpu')
jax.distributed.initialize()  # connect this process to the other ranks; must run before any JAX computation
print(jax.devices())
import jax.numpy as jnp
from mpi4py import MPI


@jax.jit
def func(M):
    for _ in range(1000):
        M = M.dot(jnp.eye(len(M)))
    return M


M = np.random.random((20, 20))
t1 = time.time()
func(M)
t2 = time.time()
func(M)
t3 = time.time()
print(t2 - t1, t3 - t2)

jax.distributed.shutdown()  # tear down the distributed runtime before exiting
```

If you launch the script with mpirun, mpiexec, or srun, you don't need to pass any arguments to jax.distributed.initialize(); JAX detects Open MPI and Slurm environments automatically.
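When automatic detection does not apply (for example, with an MPI implementation other than Open MPI), the coordinator address, process count, and process id can be passed to `jax.distributed.initialize()` explicitly. The following is a minimal sketch, assuming mpi4py is installed, that rank 0's hostname is reachable from all ranks, and that TCP port 12345 on that host is free. The port and the helper `coordinator_address` are hypothetical choices for illustration. Rank 0's hostname is broadcast to every rank, and the script falls back to a single, non-distributed process when mpi4py is missing.

```python
import socket

import jax

try:
    from mpi4py import MPI
    comm = MPI.COMM_WORLD
    rank, size = comm.Get_rank(), comm.Get_size()
except ImportError:
    # mpi4py not installed: behave as a single, non-distributed process
    comm, rank, size = None, 0, 1

COORDINATOR_PORT = 12345  # hypothetical: any free TCP port on rank 0's host


def coordinator_address(comm, rank, port=COORDINATOR_PORT):
    """Hypothetical helper: return "host:port" of rank 0, shared with every rank."""
    host = socket.gethostname() if rank == 0 else None
    if comm is not None:
        host = comm.bcast(host, root=0)  # collective: every rank must call this
    return f"{host}:{port}"


if size > 1:
    # Must run before any other JAX computation in this process.
    jax.distributed.initialize(
        coordinator_address=coordinator_address(comm, rank),
        num_processes=size,
        process_id=rank,
    )

print(f"process {jax.process_index()} of {jax.process_count()}")
```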

You can read more in the JAX documentation on multi-process environments.
