-
Notifications
You must be signed in to change notification settings - Fork 616
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Capture] Add execution support for broadcast_in_dim_p
and iota_p
#6865
base: dynamic-capture-hop-2
Are you sure you want to change the base?
Conversation
Codecov ReportAll modified and coverable lines are covered by tests ✅
Additional details and impacted files@@ Coverage Diff @@
## dynamic-capture-hop-2 #6865 +/- ##
========================================================
Coverage ? 99.60%
========================================================
Files ? 477
Lines ? 45263
Branches ? 0
========================================================
Hits ? 45083
Misses ? 180
Partials ? 0 ☔ View full report in Codecov by Sentry. |
broadcast_in_dim_p
broadcast_in_dim_p
and iota_p
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good. I'd suggest expanding the docstring of _fill_in_shape_with_dyn_shape
, and also adding some more inline comments to the primitive registrations about the changes. I had to look at the jaxpr of a dummy function locally to convince myself why this works 😅 . Maybe it's worth including some info in the intro_to_dynamic_shapes.md
file.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good. I think a few more dev comments should be added, otherwise this looks approval ready to me.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks excellent, I just don't understand the examples in the docstrings. If you could clarify that it would be great 👍
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🚀
Co-authored-by: Pietropaolo Frisoni <[email protected]> Co-authored-by: Mudit Pandey <[email protected]>
Context:
While we can capture the creation of arrays with a dynamic shape like
jnp.ones
,jnp.zeros
, andjnp.arange
, jax does not support executing these equations:I think
jax.core.eval_jaxpr
is trying to jit and lower thebroadcast_in_dim
primitive, but can't. This problem propogates to executing any jaxpr with aPlxprInterpreter
Description of the Change:
Registers special handling for
broadcast_in_dim
andiota
.Benefits:
We can execute the creation of dynamically shaped arrays.
Possible Drawbacks:
Sensitive to the location of
jax.lax.broadcast_in_dim_p
andjax.lax.iota_p
. So this may break if jax moves things around.Related GitHub Issues:
[sc-82653] [sc-82656]