diff --git a/python/cog/server/worker.py b/python/cog/server/worker.py index 0af1838989..f404aa9c00 100644 --- a/python/cog/server/worker.py +++ b/python/cog/server/worker.py @@ -4,7 +4,6 @@ import inspect import multiprocessing import os -import select import signal import sys import threading @@ -13,6 +12,7 @@ import uuid import warnings import weakref +from concurrent import futures from concurrent.futures import Future, ThreadPoolExecutor from enum import Enum, auto, unique from multiprocessing.connection import Connection @@ -126,18 +126,16 @@ def __init__( self._predictions_lock = threading.Lock() self._predictions_in_flight: Dict[Optional[str], PredictionState] = {} - recv_conn, send_conn = _spawn.Pipe(duplex=False) - self._request_send_conn = send_conn - self._request_recv_conn = recv_conn - - self._pool = ThreadPoolExecutor(max_workers=1) + self._event_consumer_pool = ThreadPoolExecutor(max_workers=1) + self._prediction_start_pool = ThreadPoolExecutor(max_workers=max_concurrency) + self._input_download_pool = ThreadPoolExecutor(max_workers=8) self._event_consumer = None def setup(self) -> "Future[Done]": self._assert_state(WorkerState.NEW) self._state = WorkerState.STARTING self._child.start() - self._event_consumer = self._pool.submit(self._consume_events) + self._event_consumer = self._event_consumer_pool.submit(self._consume_events) return self._setup_result def predict( @@ -163,9 +161,57 @@ def predict( self._assert_state(WorkerState.READY) result = Future() self._predictions_in_flight[tag] = PredictionState(tag, payload, result) - self._request_send_conn.send(PredictionRequest(tag)) + + self._prediction_start_pool.submit(self._start_prediction(tag, payload)) return result + def _start_prediction( + self, tag: Optional[str], payload: Dict[str, Any] + ) -> Callable[[], None]: + def start_prediction() -> None: + try: + to_await = [] + futs = {} + # Prepare payload asynchronously (download URLPath objects) + for k, v in payload.items(): + # Check if v is an instance of URLPath + if isinstance(v, URLPath): + futs[k] = self._input_download_pool.submit(v.convert) + to_await.append(futs[k]) + # Check if v is a list of URLPath instances + elif isinstance(v, list) and all( + isinstance(item, URLPath) for item in v + ): + futs[k] = [ + self._input_download_pool.submit(item.convert) for item in v + ] + to_await += futs[k] + futures.wait(to_await, return_when=futures.FIRST_EXCEPTION) + for k, v in futs.items(): + if isinstance(v, list): + payload[k] = [] + for fut in v: + # the future may not be done if and only if another + # future finished with an exception + if fut.done(): + payload[k].append(fut.result()) + elif isinstance(v, Future): + if v.done(): + payload[k] = v.result() + # send the prediction to the child to start + self._events.send( + Envelope( + event=PredictionInput(payload=payload), + tag=tag, + ) + ) + except Exception as e: + done = Done(error=True, error_detail=str(e)) + self._publish(Envelope(done, tag)) + self._complete_prediction(done, tag) + + return start_prediction + def subscribe( self, subscriber: Callable[[_PublicEventType], None], @@ -195,7 +241,7 @@ def shutdown(self, timeout: Optional[float] = None) -> None: if self._event_consumer: self._event_consumer.result(timeout=timeout) - self._pool.shutdown() + self._event_consumer_pool.shutdown() def terminate(self) -> None: """ @@ -209,10 +255,15 @@ def terminate(self) -> None: self._child.terminate() self._child.join() - self._pool.shutdown(wait=False) + self._event_consumer_pool.shutdown(wait=False) def cancel(self, tag: Optional[str] = None) -> None: - self._request_send_conn.send(CancelRequest(tag)) + with self._predictions_lock: + predict_state = self._predictions_in_flight.get(tag) + if predict_state and not predict_state.cancel_sent: + self._child.send_cancel_signal() + self._events.send(Envelope(event=Cancel(), tag=tag)) + predict_state.cancel_sent = True def _assert_state(self, state: WorkerState) -> None: if self._state != state: @@ -268,48 +319,14 @@ def _consume_events_inner(self) -> None: # Main event loop while self._child.is_alive(): - # see if we have any new prediction requests - - read_socks, _, _ = select.select( - [self._request_recv_conn, self._events], [], [], 0.1 - ) - if self._request_recv_conn in read_socks: - ev = self._request_recv_conn.recv() - if isinstance(ev, PredictionRequest): - with self._predictions_lock: - state = self._predictions_in_flight[ev.tag] - - # Prepare payload (download URLPath objects) - # FIXME this blocks the event loop, which is bad in concurrent mode - try: - _prepare_payload(state.payload) - except Exception as e: - done = Done(error=True, error_detail=str(e)) - self._publish(Envelope(done, state.tag)) - self._complete_prediction(done, state.tag) - else: - # Start the prediction - self._events.send( - Envelope( - event=PredictionInput(payload=state.payload), - tag=state.tag, - ) - ) - elif isinstance(ev, CancelRequest): - with self._predictions_lock: - predict_state = self._predictions_in_flight.get(ev.tag) - if predict_state and not predict_state.cancel_sent: - self._child.send_cancel_signal() - self._events.send(Envelope(event=Cancel(), tag=ev.tag)) - predict_state.cancel_sent = True - else: - log.warn("unrecognized request event: {ev}") + # wait for events from the child worker + if not self._events.poll(0.1): + continue - if self._events in read_socks: - ev = self._events.recv() - self._publish(ev) - if isinstance(ev.event, Done): - self._complete_prediction(ev.event, ev.tag) + ev = self._events.recv() + self._publish(ev) + if isinstance(ev.event, Done): + self._complete_prediction(ev.event, ev.tag) # If we dropped off the end off the end of the loop, it's because the # child process died. First, process any remaining messages on the connection @@ -844,13 +861,3 @@ def make_worker( ) parent = Worker(child=child, events=parent_conn, max_concurrency=max_concurrency) return parent - - -def _prepare_payload(payload: Dict[str, Any]) -> None: - for k, v in payload.items(): - # Check if v is an instance of URLPath - if isinstance(v, URLPath): - payload[k] = v.convert() - # Check if v is a list of URLPath instances - elif isinstance(v, list) and all(isinstance(item, URLPath) for item in v): - payload[k] = [item.convert() for item in v]