[RFC] boring stuff: define a sampler interface #31
Comments
qdbp changed the title from "boring stuff: define a sampler interface" to "[RFC] boring stuff: define a sampler interface" on Oct 7, 2024.
I am working on adding Entropix to vLLM so I can plug it into my inference workflows. I'll reference this issue here.
We should fold the "key" into either the config or the state; having it as a top-level argument is messy. If it should be common to all models, dataclass inheritance will come in handy here.
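A minimal sketch of what that suggestion might look like, assuming a hypothetical `SamplerState` base dataclass that owns the key and a hypothetical `TemperatureState` subclass that inherits it. To keep the sketch runnable without JAX, the key is a plain integer seed and `random.Random` stands in for `jax.random`; a real JAX version would derive the next key with `jax.random.split` instead.

```python
import math
import random
from dataclasses import dataclass, replace
from typing import Sequence, Tuple


@dataclass(frozen=True)
class SamplerState:
    # Hypothetical base state: every sampler state carries its own key,
    # so the key no longer needs to be a separate top-level argument.
    key: int  # stand-in for a jax.random key


@dataclass(frozen=True)
class TemperatureState(SamplerState):
    # Hypothetical sampler-specific state; `key` is inherited from the base.
    temperature: float = 1.0


def sample(logits: Sequence[float], state: TemperatureState) -> Tuple[int, TemperatureState]:
    # Deterministic given the key, like a pure JAX sampling function.
    rng = random.Random(state.key)
    top = max(logits)
    # Unnormalized softmax weights at the given temperature (shifted by the max for stability).
    weights = [math.exp((x - top) / state.temperature) for x in logits]
    token = rng.choices(range(len(logits)), weights=weights, k=1)[0]
    # Return an updated state with a fresh key, mimicking a key split in JAX.
    return token, replace(state, key=rng.getrandbits(32))
```

Because `key` lives on the base class, every sampler's call has the same `(logits, state)` shape, and the key is threaded through the returned state rather than passed separately.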
tl;dr: samplers should be swappable and composable, and for that we need a common interface.
There's a lot of hot stuff in the pipeline re. MCTS, the vanilla sampler, etc.
One thing I'm afraid of is that there's going to be a lot of spaghetti involving bespoke/subtly different ways to call different samplers, which will make benchmarking and comparison painful.
I want to get ahead of this issue by defining a common interface for samplers. Since this is Python, I think it should be a `Protocol`, something like:
which is a light touch (no need to inherit) but can still be type-checked. This should be a generic enough framework for people to implement their favorite MuZero etc. and have it all plug into the same harnesses.
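For illustration only, here is a minimal sketch of the kind of `Protocol` described above; the names `Sampler`, `GreedySampler`, and `run_harness` and the logits-in, token-out signature are assumptions, not the snippet from the original post. `GreedySampler` never inherits from `Sampler` yet satisfies it structurally, and `@runtime_checkable` allows an `isinstance` check (which only verifies that `__call__` exists, not its signature).

```python
from typing import List, Protocol, Sequence, runtime_checkable


@runtime_checkable
class Sampler(Protocol):
    # Hypothetical minimal interface: map one step's logits to a token id.
    def __call__(self, logits: Sequence[float]) -> int: ...


class GreedySampler:
    # No inheritance from Sampler: matching the method shape is enough.
    def __call__(self, logits: Sequence[float]) -> int:
        return max(range(len(logits)), key=lambda i: logits[i])


def run_harness(sampler: Sampler, batch: List[Sequence[float]]) -> List[int]:
    # A shared benchmarking harness depends only on the protocol,
    # so any conforming sampler can be swapped in.
    return [sampler(logits) for logits in batch]
```

A static checker such as mypy verifies the full signature when a `GreedySampler` is passed to `run_harness`; the runtime check is only a coarse guard.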
My goal here is to have an easy-to-maintain sampler benchmarking suite with easy plug-and-play samplers.
EDIT
Given the JAX idiom of passing state in as an argument and returning the updated state (and to support some sampler work of my own, tee hee), I think it makes sense to expand this interface to include an `ST` type var.
EDIT 2
Since we're returning a `jax.Array` in place of the token output, I propose that a sample call be allowed to return either a single token or an entire sequence of tokens at once. Callers should be able to handle either case (and really, a single token is just a sequence of length 1).