Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

End-to-end support for concurrent async models #2066

Merged
merged 16 commits into from
Dec 6, 2024

Conversation

philandstuff
Copy link
Contributor

This builds on the work in #2057 and wires it up end-to-end.

We can now support async models with a max concurrency configured, and submit
multiple predictions concurrently to them.

We only support python 3.11 for async models; this is so that we can use
asyncio.TaskGroup to keep track of multiple predictions in flight and ensure
they all complete when shutting down.

The cog http server was already async, but at one point it called wait() on a
concurrent.futures.Future() which blocked the event loop and therefore prevented
concurrent prediction requests (when not using prefer-async, which is how the
tests run). I have updated this code to wait on asyncio.wrap_future(fut)
instead which does not block the event loop. As part of this I have updated the
training endpoints to also be asynchronous.

We now have three places in the code which keep track of how many predictions
are in flight: PredictionRunner, Worker and _ChildWorker all do their own
bookkeeping. I'm not sure this is the best design but it works.

The code is now an uneasy mix of threaded and asyncio code. This is evident in
the usage of threading.Lock, which wouldn't be needed if we were 100% async (and
I'm not sure if it's actually needed currently; I just added it to be safe).

@philandstuff philandstuff requested a review from a team November 26, 2024 11:33
Copy link
Member

@erbridge erbridge left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks pretty good to me. A few minor comments and suggestions.

python/cog/server/runner.py Show resolved Hide resolved
python/cog/server/runner.py Outdated Show resolved Hide resolved
python/cog/server/runner.py Show resolved Hide resolved
python/cog/server/worker.py Outdated Show resolved Hide resolved
python/cog/server/runner.py Show resolved Hide resolved
python/cog/server/worker.py Outdated Show resolved Hide resolved
Copy link
Contributor

@aron aron left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've just tested this locally with a basic async model that yields strings at an interval. Works wonderfully.

python/cog/server/runner.py Show resolved Hide resolved
python/cog/server/worker.py Outdated Show resolved Hide resolved
python/cog/server/worker.py Outdated Show resolved Hide resolved
@aron aron force-pushed the support-concurrent-predictions-in-child branch 2 times, most recently from e307a39 to 41adaa9 Compare November 26, 2024 15:40
@aron
Copy link
Contributor

aron commented Nov 26, 2024

Cutting a 0.14.0-alpha.1 pre-release build https://github.com/replicate/cog/actions/runs/12035403546

Abandoned this after chatting with @nickstenning

@aron
Copy link
Contributor

aron commented Nov 26, 2024

This is probably for later, but I don't think we have the correct support for typing the output of these async predict functions. For example async def predict(...) -> Iterator[str] will raise type errors as it should be AsyncGenerator

@aron aron force-pushed the support-concurrent-predictions-in-child branch 3 times, most recently from 05b7c34 to 5cdaf9b Compare November 29, 2024 15:51
philandstuff and others added 16 commits December 6, 2024 15:28
We require python >=3.11 to support asyncio.TaskGroup
This builds on the work in #2057 and wires it up end-to-end.

We can now support async models with a max concurrency configured, and submit
multiple predictions concurrently to them.

We only support python 3.11 for async models; this is so that we can use
asyncio.TaskGroup to keep track of multiple predictions in flight and ensure
they all complete when shutting down.

The cog http server was already async, but at one point it called wait() on a
concurrent.futures.Future() which blocked the event loop and therefore prevented
concurrent prediction requests (when not using prefer-async, which is how the
tests run).  I have updated this code to wait on asyncio.wrap_future(fut)
instead which does not block the event loop.  As part of this I have updated the
training endpoints to also be asynchronous.

We now have three places in the code which keep track of how many predictions
are in flight: PredictionRunner, Worker and _ChildWorker all do their own
bookkeeping. I'm not sure this is the best design but it works.

The code is now an uneasy mix of threaded and asyncio code.  This is evident in
the usage of threading.Lock, which wouldn't be needed if we were 100% async (and
I'm not sure if it's actually needed currently; I just added it to be safe).
The use of `Optional` allowed `None` as a valid value. This has been
changed to use `NotRequired` which allows the field to be omitted but
must always be an integer when present.
Inside the worker we track predictions by tag not exterenal predicition
IDs, this commit updates the variable names to reflect this.
the `for tag in done_tags:` was resetting the existing `tag` variable and
breaking things.
This commit manually calls `flush()` on the `SimpleStreamWrapper` each
time the string provided to `write()` contains a newline character.

The previous implementation assumed that the underlying TextIOWrapper
class would call our custom `flush()` method but this is not the case as
`TextIOWrapper` is implemented in C and calls into the compiled code.
This helps with the transition between the main cog branch and the
experimental `async` branch. Models built for the `async` branch
can be run on either without code changes.
This copies the functionality over from the `async` branch and allows us
to test models on production. This will emit a `DeprecationWarning`.
Users will need to add `warnings.filterwarnings("once",
DeprecationWarning)` to their code to see the error.
@aron aron force-pushed the support-concurrent-predictions-in-child branch from 6b7f18d to 925ad8d Compare December 6, 2024 15:28
@aron
Copy link
Contributor

aron commented Dec 6, 2024

@philandstuff I've tested this pretty thoroughly on a combination of models:

  1. File outputs sync/async
  2. Text generators sync/async
  3. Emit metrics sync/async
  4. flux-dev which was having issues with setup runs
  5. Models calling self.log()

@aron aron merged commit 3ca6205 into main Dec 6, 2024
19 checks passed
@aron aron deleted the support-concurrent-predictions-in-child branch December 6, 2024 17:16
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants