For example, this code doesn't work:
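import jax
import jax.numpy as jnp
from stheno.jax import GP, EQ
from lab import B

def compute(y):
    # Condition a GP with an EQ kernel on observations y at inputs 0 and 1,
    # then return the posterior mean at input 0 as a scalar.
    f = GP(EQ())
    f = f | (f(jnp.array([0., 1.])), y)
    return B.dense(f(0).mean)[0][0]

# Batch over two sets of observations; this is the call that fails.
jax.vmap(compute)(jnp.array([[1., 2.], [3., 4.]]))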
Converting to NumPy arrays is also slow, I guess?
Can this be changed so that it calls jax.lax.cond under the hood instead?
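As a general illustration of why a NumPy conversion breaks jax.vmap, here is a minimal plain-JAX sketch. It does not reproduce Stheno's or LAB's actual code path, and the function names are hypothetical. Inside vmap the argument is an abstract tracer, so converting it to a NumPy array raises an error, while the equivalent jax.numpy code traces fine.

```python
import jax
import jax.numpy as jnp
import numpy as np


def square_via_numpy(x):
    # np.asarray needs a concrete value; under jax.vmap, x is an abstract tracer.
    return jnp.asarray(np.asarray(x) ** 2)


def square_via_jnp(x):
    # Stays inside JAX, so vmap and jit can trace it.
    return x ** 2


xs = jnp.array([1.0, 2.0, 3.0])

print(jax.vmap(square_via_jnp)(xs))

try:
    jax.vmap(square_via_numpy)(xs)
except jax.errors.TracerArrayConversionError as err:
    print("vmap failed:", type(err).__name__)
```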
It has occurred to me that this is a duplicate of #21. For reference, one can work around it as follows:
import jax
import jax.numpy as jnp
from stheno.jax import GP, EQ
from lab import B

@B.jit
def compute(y):
    f = GP(EQ())
    f = f | (f(jnp.array([0., 1.])), y)
    return B.dense(f(0).mean)[0][0]

# Call compute once before tracing it (see the note below).
compute(jnp.array([1., 2.]))
print(jax.make_jaxpr(compute)(jnp.array([1., 2.])))
Note that one has to call compute first or else tracing does not work.
Hey @Alan-Chen99! I’m glad that you managed to get to the bottom of this. I realise the extra compilation step isn’t an ideal solution. Possibly, using B.jit_to_numpy by default is a bit too aggressive, and it should be replaced by something that supports JAX tracing without the extra compilation step, e.g. jax.lax.cond as you suggest.
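For reference, a traceable branch with jax.lax.cond looks roughly like this in plain JAX. This is a minimal sketch: the clipped_log_* names are hypothetical, and it is not how Stheno or LAB currently implement conditioning. A Python `if` on a traced value needs a concrete boolean and fails, whereas jax.lax.cond(pred, true_fun, false_fun, *operands) can be traced. Under jax.vmap with a batched predicate, JAX lowers cond to a select, so both branches are evaluated.

```python
import jax
import jax.numpy as jnp


def clipped_log_python_if(x):
    # A Python `if` needs a concrete boolean, so this fails under vmap/jit.
    if x > 0:
        return jnp.log(x)
    return jnp.zeros_like(x)


def clipped_log_cond(x):
    # jax.lax.cond is traceable; under vmap it becomes a select.
    return jax.lax.cond(x > 0, jnp.log, jnp.zeros_like, x)


xs = jnp.array([1.0, -1.0, jnp.e])

print(jax.vmap(clipped_log_cond)(xs))

try:
    jax.vmap(clipped_log_python_if)(xs)
except jax.errors.ConcretizationTypeError as err:
    print("Python if failed:", type(err).__name__)
```

Because the batched case evaluates both branches, the discarded branch may produce values such as NaN (here jnp.log(-1.0)); select simply ignores them.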