
B.jit_to_numpy prevents JAX transforms in AbstractObservations.__init__ #26

Closed
Alan-Chen99 opened this issue Jan 30, 2023 · 2 comments

Comments

@Alan-Chen99

For example, this code doesn't work:

import jax
import jax.numpy as jnp
from stheno.jax import GP, EQ
from lab import B

def compute(y):
    f = GP(EQ())
    f = f | (f(jnp.array([0., 1.])), y)
    return B.dense(f(0).mean)[0][0]

jax.vmap(compute)(jnp.array([[1., 2.], [3., 4.]]))

Converting to NumPy arrays is presumably also slow?

Can this be changed so that it calls jax.lax.cond under the hood instead?
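For reference, here is a minimal sketch, independent of the stheno/lab internals, of what this would look like: branching on a traced value with jax.lax.cond keeps the computation traceable, so it still works under jax.vmap (whereas converting to a concrete NumPy array to branch in Python would not). The safe_log function here is just a made-up illustration:

```python
import jax
import jax.numpy as jnp

def safe_log(x):
    # Both branches are traced; the choice is made at run time, so this
    # remains compatible with jax.jit and jax.vmap.
    return jax.lax.cond(x > 0, jnp.log, jnp.zeros_like, x)

# Works under vmap: under batching, cond lowers to a select over both branches.
out = jax.vmap(safe_log)(jnp.array([jnp.e, -1.0]))  # → [1., 0.]
```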

@Alan-Chen99
Author

Alan-Chen99 commented Jan 30, 2023

It has occurred to me that this is a duplicate of #21; for reference, one can fix it as follows:

import jax
import jax.numpy as jnp
from stheno.jax import GP, EQ
from lab import B

@B.jit
def compute(y):
    f = GP(EQ())
    f = f | (f(jnp.array([0., 1.])), y)
    return B.dense(f(0).mean)[0][0]

compute(jnp.array([1., 2.]))
print(jax.make_jaxpr(compute)(jnp.array([1., 2.])))

Note that one has to call compute once first, or else tracing does not work.

@wesselb
Owner

wesselb commented Jan 30, 2023

Hey @Alan-Chen99! I'm glad you managed to get to the bottom of this. I realise the extra compilation step isn't an ideal solution. Possibly using B.jit_to_numpy by default is a bit too aggressive, and it should be replaced by something that supports JAX tracing without the extra compilation step, e.g. jax.lax.cond as you suggest.
