jax.nn.dot_product_attention CuDNN implementation raises tensor stride error during jit compile #25986

Open
liamclarkza opened this issue Jan 20, 2025 · 0 comments
Labels: bug (Something isn't working)

Description

I am currently hitting a cuDNN error relating to the stride of my K matrix when using jax.nn.dot_product_attention within a Flax model. The error is raised during jit compilation and stems from the cuDNN dimension checks here. I have checked the shapes and sharding of the inputs, but I am not sure what exactly is causing the striding issue with the k tensor, and I am struggling to find a way to debug it further.

With the implementation argument set to 'xla', the model jits successfully and I am able to train with it.

The shapes of q, k, and v are all (8, 2048, 40, 128), and all three are sharded along the first (batch) dimension with the following sharding:
NamedSharding(mesh=Mesh('dp': 1, 'fsdp': 8), spec=PartitionSpec('fsdp',), memory_kind=device).

The function is called as below:

jax.nn.dot_product_attention(
    q.astype(jnp.bfloat16),
    k.astype(jnp.bfloat16),
    v.astype(jnp.bfloat16),
    mask=None, # I have tested with/without masking but get the same error either way
    scale=float(q.shape[-1] ** -0.5),
    implementation='cudnn',
)
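
For reference, here is a minimal standalone sketch of the setup described above. The mesh construction and variable names are illustrative (my actual inputs are produced inside a Flax model), and I have not confirmed that this exact standalone version triggers the same error, but it mirrors the shapes, dtypes, and sharding involved:

import jax
import jax.numpy as jnp
from jax.experimental import mesh_utils
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Illustrative mesh matching the sharding above: 'dp' = 1, 'fsdp' = 8.
devices = mesh_utils.create_device_mesh((1, 8))
mesh = Mesh(devices, axis_names=('dp', 'fsdp'))
sharding = NamedSharding(mesh, PartitionSpec('fsdp'))  # shard the batch dimension over 'fsdp'

B, T, H, D = 8, 2048, 40, 128
q = jax.device_put(jnp.zeros((B, T, H, D), jnp.bfloat16), sharding)
k = jax.device_put(jnp.zeros((B, T, H, D), jnp.bfloat16), sharding)
v = jax.device_put(jnp.zeros((B, T, H, D), jnp.bfloat16), sharding)

@jax.jit
def attn(q, k, v):
    return jax.nn.dot_product_attention(
        q, k, v,
        scale=float(q.shape[-1] ** -0.5),
        implementation='cudnn',
    )

out = attn(q, k, v)  # the error below is raised while this jit-compiles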

This gives the following error:

*** truncated ***
File "/app/.venv/lib/python3.11/site-packages/jax/_src/compiler.py", line 427, in compile_or_get_cached
return _compile_and_write_cache(
^^^^^^^^^^^^^^^^^^^^^^^^^
File "/app/.venv/lib/python3.11/site-packages/jax/_src/compiler.py", line 655, in _compile_and_write_cache
executable = backend_compile(
^^^^^^^^^^^^^^^^
File "/app/.venv/lib/python3.11/site-packages/jax/_src/profiler.py", line 333, in wrapper
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/app/.venv/lib/python3.11/site-packages/jax/_src/compiler.py", line 273, in backend_compile
raise e
File "/app/.venv/lib/python3.11/site-packages/jax/_src/compiler.py", line 267, in backend_compile
return backend.compile(built_c, compile_options=options)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
jaxlib.xla_extension.XlaRuntimeError: INTERNAL: The stride for the last dimension corresponding to the embedding size per head should be 1 for input_names::K
in external/xla/xla/stream_executor/cuda/cuda_dnn.cc(8221): 'graph_.build_operation_graph(cudnn->handle())'
File "/app/.venv/lib/python3.11/site-packages/jax/_src/profiler.py", line 333, in wrapper
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/app/.venv/lib/python3.11/site-packages/jax/_src/compiler.py", line 273, in backend_compile
raise e
File "/app/.venv/lib/python3.11/site-packages/jax/_src/compiler.py", line 267, in backend_compile
return backend.compile(built_c, compile_options=options)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
jaxlib.xla_extension.XlaRuntimeError: INTERNAL: The stride for the last dimension corresponding to the embedding size per head should be 1 for input_names::K
in external/xla/xla/stream_executor/cuda/cuda_dnn.cc(8221): 'graph_.build_operation_graph(cudnn->handle())'

If there are any ways to further debug the striding of the underlying tensor and, if possible, to force a contiguous layout that matches its shape, please let me know.
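
So far the only ways I have found to get more visibility are the standard XLA dumps and the AOT lowering APIs. A sketch, reusing attn, q, k, and v from the snippet above (the script name and dump path are illustrative):

# Dump the HLO around each compiler pass; since the failure happens during
# compilation, the pre-layout-assignment dumps should show what feeds the
# cuDNN attention custom call.
#   XLA_FLAGS=--xla_dump_to=/tmp/xla_dump python train.py
#
# The pre-compilation StableHLO can also be printed from Python:
lowered = attn.lower(q, k, v)
print(lowered.as_text())  # StableHLO; layouts are only assigned later by XLA
# lowered.compile()       # this is the step that raises the INTERNAL error above

Neither of these tells me how to force a contiguous layout for k, though.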

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.4.38
jaxlib: 0.4.38
numpy:  2.0.2
python: 3.11.11 (main, Jan 14 2025, 22:49:08) [Clang 19.1.6 ]
device info: NVIDIA H100 80GB HBM3-8, 8 local devices
process_count: 1
platform: uname_result(system='Linux', node='experiment-2eb4a7d7-dad7-head', release='6.8.0-49-generic', version='#49~22.04.1-Ubuntu SMP PREEMPT_DYNAMIC Wed Nov  6 17:42:15 UTC 2', machine='x86_64')


$ nvidia-smi
Mon Jan 20 11:16:51 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.127.05             Driver Version: 550.127.05     CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA H100 80GB HBM3          On  |   00000000:19:00.0 Off |                    0 |
| N/A   38C    P0            120W /  700W |     550MiB /  81559MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   1  NVIDIA H100 80GB HBM3          On  |   00000000:3B:00.0 Off |                    0 |
| N/A   36C    P0            118W /  700W |     538MiB /  81559MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   2  NVIDIA H100 80GB HBM3          On  |   00000000:4C:00.0 Off |                    0 |
| N/A   33C    P0            115W /  700W |     538MiB /  81559MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   3  NVIDIA H100 80GB HBM3          On  |   00000000:5D:00.0 Off |                    0 |
| N/A   36C    P0            123W /  700W |     538MiB /  81559MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   4  NVIDIA H100 80GB HBM3          On  |   00000000:9B:00.0 Off |                    0 |
| N/A   38C    P0            118W /  700W |     538MiB /  81559MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   5  NVIDIA H100 80GB HBM3          On  |   00000000:BB:00.0 Off |                    0 |
| N/A   36C    P0            114W /  700W |     538MiB /  81559MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   6  NVIDIA H100 80GB HBM3          On  |   00000000:CB:00.0 Off |                    0 |
| N/A   37C    P0            123W /  700W |     538MiB /  81559MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   7  NVIDIA H100 80GB HBM3          On  |   00000000:DB:00.0 Off |                    0 |
| N/A   34C    P0            117W /  700W |     538MiB /  81559MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
                                                                                         
+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI        PID   Type   Process name                              GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|    0   N/A  N/A     23980      C   /app/.venv/bin/python                           0MiB |
|    1   N/A  N/A     23980      C   /app/.venv/bin/python                           0MiB |
|    2   N/A  N/A     23980      C   /app/.venv/bin/python                           0MiB |
|    3   N/A  N/A     23980      C   /app/.venv/bin/python                           0MiB |
|    4   N/A  N/A     23980      C   /app/.venv/bin/python                           0MiB |
|    5   N/A  N/A     23980      C   /app/.venv/bin/python                           0MiB |
|    6   N/A  N/A     23980      C   /app/.venv/bin/python                           0MiB |
|    7   N/A  N/A     23980      C   /app/.venv/bin/python                           0MiB |
+-----------------------------------------------------------------------------------------+
liamclarkza added the bug (Something isn't working) label on Jan 20, 2025