ptxas unsupported version error #25853

Open
lengstrom opened this issue Jan 13, 2025 · 9 comments
@lengstrom

Description

Minimal example:

from jax.random import PRNGKey
PRNGKey(0)

yields the error

  File "/mnt/xfs/home/engstrom/conda_envs/benclip/lib/python3.11/site-packages/jax/_src/random.py", line 246, in PRNGKey
    return _return_prng_keys(True, _key('PRNGKey', seed, impl))
                                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/xfs/home/engstrom/conda_envs/benclip/lib/python3.11/site-packages/jax/_src/random.py", line 198, in _key
    return prng.random_seed(seed, impl=impl)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/xfs/home/engstrom/conda_envs/benclip/lib/python3.11/site-packages/jax/_src/prng.py", line 541, in random_seed
    return random_seed_p.bind(seeds_arr, impl=impl)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/xfs/home/engstrom/conda_envs/benclip/lib/python3.11/site-packages/jax/_src/core.py", line 463, in bind
    return self.bind_with_trace(prev_trace, args, params)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/xfs/home/engstrom/conda_envs/benclip/lib/python3.11/site-packages/jax/_src/core.py", line 468, in bind_with_trace
    return trace.process_primitive(self, args, params)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/xfs/home/engstrom/conda_envs/benclip/lib/python3.11/site-packages/jax/_src/core.py", line 941, in process_primitive
    return primitive.impl(*args, **params)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/xfs/home/engstrom/conda_envs/benclip/lib/python3.11/site-packages/jax/_src/prng.py", line 553, in random_seed_impl
    base_arr = random_seed_impl_base(seeds, impl=impl)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/xfs/home/engstrom/conda_envs/benclip/lib/python3.11/site-packages/jax/_src/prng.py", line 558, in random_seed_impl_base
    return seed(seeds)
           ^^^^^^^^^^^
  File "/mnt/xfs/home/engstrom/conda_envs/benclip/lib/python3.11/site-packages/jax/_src/prng.py", line 774, in threefry_seed
    return _threefry_seed(seed)
           ^^^^^^^^^^^^^^^^^^^^
jaxlib.xla_extension.XlaRuntimeError: INTERNAL: ptxas exited with non-zero error code 65280, output: ptxas /tmp/tempfile-deep-chungus-11.csail.mit.edu-e7d9f2edf16d3f14-465726-62b8fece3009c, line 5; fatal   : Unsupported .version 8.3; current version is '8.2'
ptxas fatal   : Ptx assembly aborted due to errors

--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.4.38
jaxlib: 0.4.38
numpy:  2.0.2
python: 3.11.8 | packaged by conda-forge | (main, Feb 16 2024, 20:53:32) [GCC 12.3.0]
device info: NVIDIA A100 80GB PCIe-4, 4 local devices"
process_count: 1
platform: uname_result(system='Linux', node='asdf.csail.mit.edu', release='5.15.0-130-generic', version='#140-Ubuntu SMP Wed Dec 18 17:59:53 UTC 2024', machine='x86_64')


$ nvidia-smi
Mon Jan 13 00:45:15 2025
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 565.57.01              Driver Version: 565.57.01      CUDA Version: 12.7     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA A100 80GB PCIe          On  |   00000000:01:00.0 Off |                    0 |
| N/A   53C    P0             75W /  300W |   17865MiB /  81920MiB |      0%      Default |
@lengstrom added the "bug" (Something isn't working) label on Jan 13, 2025
@jakevdp changed the title from "Error on making PRNGKey" to "ptxas unsupported version error" on Jan 13, 2025
@jakevdp
Collaborator

jakevdp commented Jan 13, 2025

xref #25344

@lengstrom
Author

From #25718: adding the pip-packaged ptxas to the PATH was an effective workaround.

export PATH=$(python -c "import site; print(site.getsitepackages()[0] + '/nvidia/cuda_nvcc/bin')"):$PATH
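
The same workaround can be applied from inside Python, which helps when the shell environment is not under your control (for example in a notebook). This is a minimal sketch, not part of JAX: the helper names `packaged_nvcc_bin_dirs` and `prepend_to_path` are hypothetical, and it assumes the pip wheels place ptxas under `nvidia/cuda_nvcc/bin` inside site-packages, as the shell command above does.

```python
import os
import site
from pathlib import Path


def packaged_nvcc_bin_dirs(site_dirs=None):
    """Return existing nvidia/cuda_nvcc/bin dirs under the given site-packages dirs."""
    if site_dirs is None:
        site_dirs = site.getsitepackages()
    # Assumed layout: <site-packages>/nvidia/cuda_nvcc/bin (same as the shell one-liner).
    candidates = [Path(d) / "nvidia" / "cuda_nvcc" / "bin" for d in site_dirs]
    return [str(p) for p in candidates if p.is_dir()]


def prepend_to_path(dirs, env=None):
    """Move `dirs` to the front of PATH in `env` (default: os.environ)."""
    if env is None:
        env = os.environ
    current = [p for p in env.get("PATH", "").split(os.pathsep) if p]
    # Drop any existing copies so each directory appears only once.
    kept = [p for p in current if p not in dirs]
    env["PATH"] = os.pathsep.join(list(dirs) + kept)
    return env["PATH"]


if __name__ == "__main__":
    # Assumption: this runs before `import jax`, so any ptxas subprocess
    # launched later inherits the updated PATH.
    print(prepend_to_path(packaged_nvcc_bin_dirs()))
```

If no packaged directory exists, `packaged_nvcc_bin_dirs()` returns an empty list and PATH is left effectively unchanged.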

@dfm
Collaborator

dfm commented Jan 15, 2025

Thanks for the report and the update. This workaround shouldn't be needed anymore, so let's look into what happened. Can you share all the steps that you used to install JAX? It might also be useful to know what the following outputs (before you change PATH):

import jax
print(jax._src.lib.cuda_path)

@lengstrom
Author

The output I get for this is:

/mnt/xfs/home/engstrom/conda_envs/ffcv_2/lib/python3.12/site-packages/nvidia/cuda_nvcc

I think that the problem is that I have /usr/local/cuda-12.2/bin in my path, which points to outdated binaries? Removing this from my PATH fixes the problem.
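
A quick way to check this hypothesis is to read the two versions out of the error message and list every ptxas that PATH would resolve, in lookup order. The sketch below uses hypothetical helper names (`parse_ptx_mismatch`, `find_on_path`) and only the standard library. The reported "current version is '8.2'" is consistent with a ptxas from a CUDA 12.2 toolkit, since PTX ISA 8.2 is the version that shipped with CUDA 12.2.

```python
import os
import re

# Matches the ptxas complaint quoted in the traceback in this issue.
_PTX_MISMATCH = re.compile(
    r"Unsupported \.version (\d+)\.(\d+); current version is '(\d+)\.(\d+)'"
)


def parse_ptx_mismatch(message):
    """Return ((requested), (supported)) PTX versions, or None if no match."""
    m = _PTX_MISMATCH.search(message)
    if m is None:
        return None
    a, b, c, d = (int(g) for g in m.groups())
    return (a, b), (c, d)


def find_on_path(name="ptxas", path=None):
    """List every executable called `name` on PATH, in lookup order."""
    if path is None:
        path = os.environ.get("PATH", "")
    hits = []
    for d in path.split(os.pathsep):
        candidate = os.path.join(d, name)
        if d and os.path.isfile(candidate) and os.access(candidate, os.X_OK):
            hits.append(candidate)
    return hits


if __name__ == "__main__":
    # The first entry is the one a plain PATH lookup would pick.
    for hit in find_on_path("ptxas"):
        print(hit)
```

If the first path printed is under `/usr/local/cuda-12.2/bin`, a plain PATH lookup would pick the older system ptxas ahead of the pip-packaged one.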

@dfm
Collaborator

dfm commented Jan 15, 2025

Thanks! Yeah, but XLA should search that cuda_path/bin first. Can you please let me know how you installed JAX (e.g. with conda or pip)?
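
The lookup order described here (preferred directory first, PATH only as a fallback) can be sketched as follows. This illustrates the intended behaviour only and is not XLA's actual implementation; `resolve_tool` is a hypothetical name.

```python
import os
import shutil


def resolve_tool(name, preferred_dirs, path=None):
    """Return the first `name` found in preferred_dirs, else fall back to PATH."""
    for d in preferred_dirs:
        candidate = os.path.join(d, name)
        if os.path.isfile(candidate) and os.access(candidate, os.X_OK):
            return candidate
    # Nothing in the preferred directories: use an ordinary PATH lookup.
    return shutil.which(name, path=path)
```

With `preferred_dirs` set to the `bin` directory under the printed `cuda_path`, an older ptxas in `/usr/local/cuda-12.2/bin` would be reached only if the packaged one were missing.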

@lengstrom
Author

I used pip to install jax via pip install "jax[cuda12]"!

@dfm
Collaborator

dfm commented Jan 15, 2025

Thanks. I am able to reproduce this issue, and I think I tracked down the place where XLA is ignoring JAX's CUDA path, but I'm not totally sure how to fix it. For now, I'm glad that you found a workaround, and I'll keep pushing on getting this fixed.

@lengstrom
Author

Awesome, happy to have helped!

@dfm
Collaborator

dfm commented Jan 17, 2025

This should be fixed in the next JAX release thanks to this change: openxla/xla#21547
