Adds a minimal but viable implementation of string arrays (using `numpy.dtypes.StringDType`) in JAX. Currently, this only supports creating a string array with either `jax.numpy.asarray` or `jax.device_put` and reading it back with `jax.device_get`.
#70394