Description

I am hitting a cuDNN error about the stride of my K matrix when using jax.nn.dot_product_attention inside a Flax model. The error occurs when the model is jitted, and it originates from the cuDNN dimension checks here. I have checked the shapes and sharding of the inputs, but I cannot tell what is causing the striding issue with the k tensor, and I am struggling to find a way to debug it further.
With the implementation argument set to 'xla', the model jits and I am able to train with it.
The shapes of q, k, and v are all (8, 2048, 40, 128) (batch, sequence, heads, head_dim), and all three are sharded along the first (batch) dimension with the following sharding: NamedSharding(mesh=Mesh('dp': 1, 'fsdp': 8), spec=PartitionSpec('fsdp',), memory_kind=device).
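To make the setup concrete, this is roughly how the inputs end up sharded (a simplified sketch: the real q/k/v come out of the model's projections, and the mesh construction and zeros below are placeholders):

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Placeholder reconstruction of the setup: a 1 x 8 device mesh ('dp', 'fsdp'),
# with q/k/v of shape (batch, seq, heads, head_dim) = (8, 2048, 40, 128),
# batch-sharded along the 'fsdp' axis.
devices = np.array(jax.devices()).reshape(1, 8)
mesh = Mesh(devices, axis_names=('dp', 'fsdp'))
sharding = NamedSharding(mesh, P('fsdp'))  # shards dim 0 only

q = jax.device_put(jnp.zeros((8, 2048, 40, 128), jnp.bfloat16), sharding)
k = jax.device_put(jnp.zeros((8, 2048, 40, 128), jnp.bfloat16), sharding)
v = jax.device_put(jnp.zeros((8, 2048, 40, 128), jnp.bfloat16), sharding)
```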
The function is called as follows:

```python
jax.nn.dot_product_attention(
    q.astype(jnp.bfloat16),
    k.astype(jnp.bfloat16),
    v.astype(jnp.bfloat16),
    mask=None,  # I have tested with/without masking but get the same error either way
    scale=float(q.shape[-1] ** -0.5),
    implementation='cudnn',
)
```
This gives the following error:
```
*** truncated ***
File "/app/.venv/lib/python3.11/site-packages/jax/_src/compiler.py", line 427, in compile_or_get_cached
return _compile_and_write_cache(
^^^^^^^^^^^^^^^^^^^^^^^^^
File "/app/.venv/lib/python3.11/site-packages/jax/_src/compiler.py", line 655, in _compile_and_write_cache
executable = backend_compile(
^^^^^^^^^^^^^^^^
File "/app/.venv/lib/python3.11/site-packages/jax/_src/profiler.py", line 333, in wrapper
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/app/.venv/lib/python3.11/site-packages/jax/_src/compiler.py", line 273, in backend_compile
raise e
File "/app/.venv/lib/python3.11/site-packages/jax/_src/compiler.py", line 267, in backend_compile
return backend.compile(built_c, compile_options=options)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
jaxlib.xla_extension.XlaRuntimeError: INTERNAL: The stride for the last dimension corresponding to the embedding size per head should be 1 for input_names::K
in external/xla/xla/stream_executor/cuda/cuda_dnn.cc(8221): 'graph_.build_operation_graph(cudnn->handle())'
```
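In case it is useful, the inspection I can do from the JAX side looks roughly like this (train_step, state, and batch are placeholders for my actual jitted step and its inputs); none of it directly exposes the strides that cuDNN is complaining about:

```python
import jax

# Inside the attention module, confirm the sharding the traced k actually has
# (this works under jit, unlike reading k.sharding directly on a tracer):
jax.debug.inspect_array_sharding(k, callback=print)

# Lower (but don't compile) the jitted train step and look at the IR that
# produces k before it reaches the attention call; the cuDNN stride error is
# only raised later, during compilation, so compiled HLO isn't available here.
lowered = jax.jit(train_step).lower(state, batch)  # placeholder names
print(lowered.as_text())  # StableHLO, i.e. before XLA's layout assignment

# Dumping XLA's HLO to disk and searching for the cudnn attention custom call
# is another way to see what the operands look like:
#   XLA_FLAGS=--xla_dump_to=/tmp/xla_dump python train.py
```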
File "/app/.venv/lib/python3.11/site-packages/jax/_src/profiler.py", line 333, in wrapper
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/app/.venv/lib/python3.11/site-packages/jax/_src/compiler.py", line 273, in backend_compile
raise e
File "/app/.venv/lib/python3.11/site-packages/jax/_src/compiler.py", line 267, in backend_compile
return backend.compile(built_c, compile_options=options)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
jaxlib.xla_extension.XlaRuntimeError: INTERNAL: The stride for the last dimension corresponding to the embedding size per head should be 1 for input_names::K
in external/xla/xla/stream_executor/cuda/cuda_dnn.cc(8221): 'graph_.build_operation_graph(cudnn->handle())'
If there is a way to further inspect the strides/layout of the underlying k tensor, and, if possible, to force a contiguous layout that matches its logical shape, please let me know.
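One speculative idea, which I have not verified actually changes the layout that cuDNN sees, is to put an optimization barrier in front of the operands, in case XLA is folding an upstream transpose (e.g. from a fused QKV projection) into k's layout:

```python
import jax
import jax.numpy as jnp

def attention(q, k, v):
    # Speculative: stop XLA from folding upstream ops into the operand layouts
    # of the cuDNN custom call by forcing q/k/v through a barrier first.
    q, k, v = jax.lax.optimization_barrier((q, k, v))
    return jax.nn.dot_product_attention(
        q.astype(jnp.bfloat16),
        k.astype(jnp.bfloat16),
        v.astype(jnp.bfloat16),
        scale=float(q.shape[-1] ** -0.5),
        implementation='cudnn',
    )
```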
System info (python version, jaxlib version, accelerator, etc.)