Get "invalid value (nan) encountered in jit" even when jit disabled globally #25701
Comments
Thanks for the report! Just to be clear, …
jax.config.update('jax_disable_jit', True)
Even with this change, however, the same error message appears. This is due to the fact that … The fix, if we were to do it, would be to add a specific check for … What do you think?
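For reference, here is a minimal sketch of the setup under discussion. It is our own illustration, not code from this thread, and the helper name `nan_is_reported` is hypothetical. It disables jit for the whole process with `jax.config.update` and turns on `jax_debug_nans`. Under those assumptions, `jnp.log(-10.0)` still raises `FloatingPointError`, which is the behavior this issue reports.

```python
import jax
import jax.numpy as jnp


def nan_is_reported(x: float) -> bool:
    """Hypothetical helper: does jax_debug_nans raise on jnp.log(x) with jit disabled?"""
    # Process-wide switches; the jax.disable_jit() context manager would instead
    # only apply inside a `with` block.
    jax.config.update("jax_disable_jit", True)
    jax.config.update("jax_debug_nans", True)
    try:
        jnp.log(x)
        return False
    except FloatingPointError as err:
        # The exact message text differs between JAX versions.
        print(type(err).__name__, err)
        return True
    finally:
        # Restore the defaults so later code is unaffected.
        jax.config.update("jax_debug_nans", False)
        jax.config.update("jax_disable_jit", False)


print(nan_is_reported(-10.0))  # log of a negative number is nan
print(nan_is_reported(10.0))   # finite result, nothing to report
```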
Thanks for your response! My question is more about the other side: the message suggests that there is an optimized function that generates a nan and a de-optimized function that does not. In this case, however, log(-10) is always nan, so why does it say that some de-optimized function does not generate a nan? In my real setting I have a complicated function that raises the same error (maybe related to the autodiff of svd, which is naturally unstable), and I hoped that disabling jit globally would get rid of the nan, but it doesn't.
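To illustrate the point about log(-10), here is a small sketch (our own, not from the thread, with `jax_debug_nans` left off). It shows that the result is nan both when compiled with `jax.jit` and when run eagerly under `jax.disable_jit()`. No "de-optimized" version of this computation could avoid the nan.

```python
import math

import jax
import jax.numpy as jnp

x = -10.0

# Compiled: jnp.log runs inside an XLA computation.
compiled = float(jax.jit(jnp.log)(x))

# Eager: jax.disable_jit() makes jit-decorated code run as plain Python inside the block.
with jax.disable_jit():
    eager = float(jnp.log(x))

# Both paths evaluate the log of a negative number, which is nan either way.
print(math.isnan(compiled), math.isnan(eager))  # True True
```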
Sorry, that error message is busted. #25519 will fix it, but for now the "de-optimized function doesn't generate a nan" message erroneously appears every time.
Here's what I get when I run on the #25519 branch:
I think this is a decent error message, but the phrasing is still a bit confusing when disable_jit=True is set. @emilyfertig let's think about disable_jit=True, either in #25519 or in a follow-up. @SUSYUSTC sorry for the confusion! We hope to land that PR soon. In the meantime, if you think debug_nans isn't giving you a useful error message where it should, you could try patching in that branch.
Thanks a lot! That is indeed what I expected. Another question concerns nans that appear during autodiff. Say I have a function whose value is valid but whose gradient is not. Can JAX give me a traceback that tells me exactly which line generates the nan in a very complicated function? Is this feature currently available in this branch, or somewhere in the main branch? Here's an example:
Clearly the issue is that sqrt is not differentiable at 0.0, so I'd hope the error could point to the
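The original example code is not preserved on this page. The sketch below uses a hypothetical function f(x) = sqrt(x**2) to show the situation described: the value at 0.0 is finite, but the gradient is nan. The chain rule multiplies the infinite derivative of sqrt at 0 by the zero derivative of x**2. With `jax_debug_nans` enabled, computing the gradient raises `FloatingPointError` instead of silently returning nan.

```python
import math

import jax
import jax.numpy as jnp


def f(x):
    # Hypothetical example: finite value everywhere, but at x = 0 the gradient is
    # (0.5 / sqrt(0)) * (2 * 0) = inf * 0 = nan.
    return jnp.sqrt(x ** 2)


value = float(f(0.0))
grad = float(jax.grad(f)(0.0))
print(value, math.isnan(grad))  # 0.0 True

# With jax_debug_nans on, JAX raises as soon as an operation outputs a nan.
jax.config.update("jax_debug_nans", True)
try:
    jax.grad(f)(0.0)
    raised = False
except FloatingPointError as err:
    print(err)
    raised = True
finally:
    jax.config.update("jax_debug_nans", False)
```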
Here's the traceback on the branch:
We'd like to improve this further since it's not as easy to read as we'd like, but there's interesting information there:
The reason the issue in this particular example shows up with the … If we instead set …
That points to the divide-by-zero that happens in the JVP of … WDYT?
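This page does not preserve which setting or traceback was shown above, so the following is only a complementary, hedged sketch. JAX's `jax_debug_infs` flag raises on inf values the same way `jax_debug_nans` raises on nan. For the same hypothetical f(x) = sqrt(x**2), the first non-finite value in the gradient computation should be an inf from dividing by sqrt(0) when differentiating sqrt. Enabling `jax_debug_infs` is therefore expected to stop there, before the inf * 0 product turns it into a nan.

```python
import jax
import jax.numpy as jnp


def f(x):
    # Same hypothetical example: sqrt has an infinite derivative at 0.
    return jnp.sqrt(x ** 2)


jax.config.update("jax_debug_infs", True)  # raise FloatingPointError on inf outputs
try:
    jax.grad(f)(0.0)
    raised = False
except FloatingPointError as err:
    print(err)  # exact message text varies by JAX version
    raised = True
finally:
    jax.config.update("jax_debug_infs", False)
```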
Thanks a lot for the explanation! It is super clear. I actually get the same message on the main branch, but I didn't understand what it meant.
Description
In the following code there is clearly no jit anywhere, but the error suggests that the issue comes from jit.
Error information:
System info (python version, jaxlib version, accelerator, etc.)
jax: 0.4.36
jaxlib: 0.4.36
numpy: 1.26.1
python: 3.11.5 | packaged by conda-forge | (main, Aug 27 2023, 03:34:09) [GCC 12.3.0]
device info: NVIDIA GeForce RTX 2060-1, 1 local devices"
process_count: 1
platform: uname_result(system='Linux', node='jiace-XPS-8930', release='5.4.0-150-generic', version='#167~18.04.1-Ubuntu SMP Wed May 24 00:51:42 UTC 2023', machine='x86_64')
$ nvidia-smi
Mon Dec 30 22:29:29 2024
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.147.05 Driver Version: 525.147.05 CUDA Version: 12.0 |
|-------------------------------+----------------------+----------------------+
| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|===============================+======================+======================|
| 0 NVIDIA GeForce ... On | 00000000:01:00.0 Off | N/A |
| 32% 28C P2 15W / 160W | 131MiB / 6144MiB | 0% Default |
| | | N/A |
+-------------------------------+----------------------+----------------------+
+-----------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=============================================================================|
| 0 N/A N/A 2745 G /usr/lib/xorg/Xorg 16MiB |
| 0 N/A N/A 13172 C ...envs/py311/bin/python3.11 110MiB |
+-----------------------------------------------------------------------------+