diff --git a/python/cog/base_predictor.py b/python/cog/base_predictor.py
index 41393ae37a..3cc5899603 100644
--- a/python/cog/base_predictor.py
+++ b/python/cog/base_predictor.py
@@ -1,18 +1,13 @@
 from abc import ABC, abstractmethod
-from typing import Any, Optional, Union
+from typing import Any, Optional
 
-from .types import (
-    File as CogFile,
-)
-from .types import (
-    Path as CogPath,
-)
+from .types import Weights
 
 
 class BasePredictor(ABC):
     def setup(
         self,
-        weights: Optional[Union[CogFile, CogPath, str]] = None,  # pylint: disable=unused-argument
+        weights: Optional[Weights] = None,  # pylint: disable=unused-argument
     ) -> None:
         """
         An optional method to prepare the model so multiple predictions run efficiently.
diff --git a/python/cog/predictor.py b/python/cog/predictor.py
index 4eef0fe481..a9b32f7553 100644
--- a/python/cog/predictor.py
+++ b/python/cog/predictor.py
@@ -2,13 +2,11 @@
 import enum
 import importlib.util
 import inspect
-import io
 import os.path
 import sys
 import types
 import uuid
 from collections.abc import Iterable, Iterator
-from pathlib import Path
 from typing import (
     Any,
     Callable,
@@ -38,6 +36,7 @@
 from .types import (
     PYDANTIC_V2,
     Input,
+    Weights,
 )
 from .types import (
     File as CogFile,
@@ -60,15 +59,16 @@
 ]
 
 
-def run_setup(predictor: BasePredictor) -> None:
+def has_setup_weights(predictor: BasePredictor) -> bool:
     weights_type = get_weights_type(predictor.setup)
+    return weights_type is not None
 
-    # No weights need to be passed, so just run setup() without any arguments.
-    if weights_type is None:
-        predictor.setup()
-        return
 
-    weights: Union[io.IOBase, Path, str, None]
+def extract_setup_weights(predictor: BasePredictor) -> Optional[Weights]:
+    weights_type = get_weights_type(predictor.setup)
+    assert weights_type
+
+    weights: Optional[Weights]
 
     weights_url = os.environ.get("COG_WEIGHTS")
     weights_path = "weights"
@@ -119,7 +119,7 @@ def run_setup(predictor: BasePredictor) -> None:
     else:
         weights = None
 
-    predictor.setup(weights=weights)  # type: ignore
+    return weights
 
 
 def get_weights_type(setup_function: Callable[[Any], None]) -> Optional[Any]:
diff --git a/python/cog/server/worker.py b/python/cog/server/worker.py
index 2cba7c3f7e..a91ca85d92 100644
--- a/python/cog/server/worker.py
+++ b/python/cog/server/worker.py
@@ -31,7 +31,12 @@
 
 from ..base_predictor import BasePredictor
 from ..json import make_encodeable
-from ..predictor import get_predict, load_predictor_from_ref, run_setup
+from ..predictor import (
+    extract_setup_weights,
+    get_predict,
+    has_setup_weights,
+    load_predictor_from_ref,
+)
 from ..types import PYDANTIC_V2, URLPath
 from ..wait import wait_for_env
 from .connection import AsyncConnection, LockedConnection
@@ -366,7 +371,7 @@ def __init__(
         # for synchronous predictors only! async predictors use _tag_var instead
         self._sync_tag: Optional[str] = None
 
-        self._is_async = is_async
+        self._has_async_predictor = is_async
 
         super().__init__()
 
@@ -379,7 +384,7 @@ def run(self) -> None:
         # Initially, we ignore SIGUSR1.
         signal.signal(signal.SIGUSR1, signal.SIG_IGN)
 
-        if self._is_async:
+        if self._has_async_predictor:
             redirector = SimpleStreamRedirector(
                 callback=self._stream_write_hook,
                 tee=self._tee_output,
@@ -397,13 +402,27 @@
         # it has sent a error Done event and we're done here.
         if not self._predictor:
             return
-        self._predictor.log = self._log  # type: ignore
+        if not self._validate_predictor(redirector):
+            return
+
+        self._predictor.log = self._log  # type: ignore
 
         predict = get_predict(self._predictor)
-        if self._is_async:
+
+        if self._has_async_predictor:
             assert isinstance(redirector, SimpleStreamRedirector)
-            self._setup(redirector)
-            asyncio.run(self._aloop(predict, redirector))
+            predictor = self._predictor
+
+            async def _runner() -> None:
+                if hasattr(predictor, "setup") and inspect.iscoroutinefunction(
+                    predictor.setup
+                ):
+                    await self._asetup(redirector)
+                else:
+                    self._setup(redirector)
+                await self._aloop(predict, redirector)
+
+            asyncio.run(_runner())
         else:
             # We use SIGUSR1 to signal an interrupt for cancelation.
             signal.signal(signal.SIGUSR1, self._signal_handler)
@@ -453,60 +472,71 @@ def _load_predictor(self) -> Optional[BasePredictor]:
 
         return None
 
-    def _setup(
-        self, redirector: Union[StreamRedirector, SimpleStreamRedirector]
-    ) -> None:
-        done = Done()
-        try:
+    def _validate_predictor(
+        self,
+        redirector: Union[StreamRedirector, SimpleStreamRedirector],
+    ) -> bool:
+        with self._handle_setup_error(redirector):
             assert self._predictor
 
-            # Could be a function or a class
-            if hasattr(self._predictor, "setup"):
-                run_setup(self._predictor)
-
-            predict = get_predict(self._predictor)
-
-            is_async_predictor = inspect.iscoroutinefunction(
-                predict
-            ) or inspect.isasyncgenfunction(predict)
-
             # Async models require python >= 3.11 so we can use asyncio.TaskGroup
             # We should check for this before getting to this point
-            if is_async_predictor and sys.version_info < (3, 11):
+            if self._has_async_predictor and sys.version_info < (3, 11):
                 raise FatalWorkerException(
                     "Cog requires Python >=3.11 for `async def predict()` support"
                 )
 
-            if self._max_concurrency > 1 and not is_async_predictor:
+            if self._max_concurrency > 1 and not self._has_async_predictor:
                 raise FatalWorkerException(
                     "max_concurrency > 1 requires an async predict function, e.g. `async def predict()`"
                 )
 
-        except Exception as e:  # pylint: disable=broad-exception-caught
-            traceback.print_exc()
-            done.error = True
-            done.error_detail = str(e)
-        except BaseException as e:
-            # For SystemExit and friends we attempt to add some useful context
-            # to the logs, but reraise to ensure the process dies.
-            traceback.print_exc()
-            done.error = True
-            done.error_detail = str(e)
-            raise
-        finally:
-            try:
-                redirector.drain(timeout=10)
-            except TimeoutError:
-                self._events.send(
-                    Envelope(
-                        event=Log(
-                            "WARNING: logs may be truncated due to excessive volume.",
-                            source="stderr",
-                        )
-                    )
+            if (
+                hasattr(self._predictor, "setup")
+                and inspect.iscoroutinefunction(self._predictor.setup)
+                and not self._has_async_predictor
+            ):
+                raise FatalWorkerException(
+                    "Invalid predictor: to use an async setup method you must use an async predict method"
                 )
-                raise
-        self._events.send(Envelope(event=done))
+
+            return True
+
+        return False
+
+    def _setup(
+        self, redirector: Union[StreamRedirector, SimpleStreamRedirector]
+    ) -> None:
+        with self._handle_setup_error(redirector, ensure_done_event=True):
+            assert self._predictor
+
+            # Could be a function or a class
+            if not hasattr(self._predictor, "setup"):
+                return
+
+            if not has_setup_weights(self._predictor):
+                self._predictor.setup()
+                return
+
+            weights = extract_setup_weights(self._predictor)
+            self._predictor.setup(weights=weights)  # type: ignore
+
+    async def _asetup(
+        self, redirector: Union[StreamRedirector, SimpleStreamRedirector]
+    ) -> None:
+        with self._handle_setup_error(redirector, ensure_done_event=True):
+            assert self._predictor
+
+            # Could be a function or a class
+            if not hasattr(self._predictor, "setup"):
+                return
+
+            if not has_setup_weights(self._predictor):
+                await self._predictor.setup()  # type: ignore
+                return
+
+            weights = extract_setup_weights(self._predictor)
+            await self._predictor.setup(weights=weights)  # type: ignore
 
     def _loop(
         self,
@@ -654,6 +684,44 @@ async def _apredict(
                 )
             )
 
+    @contextlib.contextmanager
+    def _handle_setup_error(
+        self,
+        redirector: Union[SimpleStreamRedirector, StreamRedirector],
+        *,
+        ensure_done_event: bool = False,
+    ) -> Iterator[None]:
+        done = Done()
+        try:
+            yield
+        except Exception as e:  # pylint: disable=broad-exception-caught
+            traceback.print_exc()
+            done.error = True
+            done.error_detail = str(e)
+        except BaseException as e:
+            # For SystemExit and friends we attempt to add some useful context
+            # to the logs, but reraise to ensure the process dies.
+            traceback.print_exc()
+            done.error = True
+            done.error_detail = str(e)
+            raise
+        finally:
+            try:
+                redirector.drain(timeout=10)
+            except TimeoutError:
+                self._events.send(
+                    Envelope(
+                        event=Log(
+                            "WARNING: logs may be truncated due to excessive volume.",
+                            source="stderr",
+                        )
+                    )
+                )
+                raise
+
+            if done.error or ensure_done_event:
+                self._events.send(Envelope(event=done))
+
     @contextlib.contextmanager
     def _handle_predict_error(
         self,
diff --git a/python/cog/types.py b/python/cog/types.py
index c27247afa9..afb71d0586 100644
--- a/python/cog/types.py
+++ b/python/cog/types.py
@@ -521,6 +521,9 @@ def __get_validators__(cls) -> Iterator[Any]:
         yield cls.validate
 
 
+Weights = Union[File, Path, str]
+
+
 def get_filename_from_urlopen(resp: urllib.response.addinfourl) -> str:
     mime_type = resp.headers.get_content_type()
     extension = mimetypes.guess_extension(mime_type)
diff --git a/python/tests/server/fixtures/async_setup_uses_same_loop_as_predict.py b/python/tests/server/fixtures/async_setup_uses_same_loop_as_predict.py
index f454034036..92501755d9 100644
--- a/python/tests/server/fixtures/async_setup_uses_same_loop_as_predict.py
+++ b/python/tests/server/fixtures/async_setup_uses_same_loop_as_predict.py
@@ -1,6 +1,5 @@
 import asyncio
 
-
 class Predictor:
     async def setup(self) -> None:
         self.loop = asyncio.get_running_loop()
diff --git a/python/tests/server/fixtures/setup_async.py b/python/tests/server/fixtures/setup_async.py
new file mode 100644
index 0000000000..5415f2d6cd
--- /dev/null
+++ b/python/tests/server/fixtures/setup_async.py
@@ -0,0 +1,12 @@
+class Predictor:
+    async def download(self) -> None:
+        print("download complete!")
+
+    async def setup(self) -> None:
+        print("setup starting...")
+        await self.download()
+        print("setup complete!")
+
+    async def predict(self) -> str:
+        print("running prediction")
+        return "output"
diff --git a/python/tests/server/fixtures/setup_async_with_sync_predict.py b/python/tests/server/fixtures/setup_async_with_sync_predict.py
new file mode 100644
index 0000000000..0ccf302109
--- /dev/null
+++ b/python/tests/server/fixtures/setup_async_with_sync_predict.py
@@ -0,0 +1,9 @@
+class Predictor:
+    async def download(self) -> None:
+        print("setup used asyncio.run! it's not very effective...")
+
+    async def setup(self) -> None:
+        await self.download()
+
+    def predict(self) -> str:
+        return "output"
diff --git a/python/tests/server/test_worker.py b/python/tests/server/test_worker.py
index d7aa69cad1..6681ed36ae 100644
--- a/python/tests/server/test_worker.py
+++ b/python/tests/server/test_worker.py
@@ -335,7 +335,7 @@ def test_can_run_predictions_concurrently_on_async_predictor(worker):
 @pytest.mark.skipif(
     sys.version_info >= (3, 11), reason="Testing error message on python versions <3.11"
 )
-@uses_worker("simple_async", setup=False)
+@uses_worker("simple_async", setup=False, is_async=True)
 def test_async_predictor_on_python_3_10_or_older_raises_error(worker):
     fut = worker.setup()
     result = Result()
@@ -351,6 +351,57 @@ def test_async_predictor_on_python_3_10_or_older_raises_error(worker):
     )
 
 
+@uses_worker(
+    "setup_async", max_concurrency=1, min_python=(3, 11), is_async=True, setup=False
+)
+def test_setup_async(worker: Worker):
+    fut = worker.setup()
+    setup_result = Result()
+    setup_sid = worker.subscribe(setup_result.handle_event)
+
+    # with pytest.raises(FatalWorkerException):
+    fut.result()
+    worker.unsubscribe(setup_sid)
+
+    assert setup_result.stdout_lines == [
+        "setup starting...\n",
+        "download complete!\n",
+        "setup complete!\n",
+    ]
+
+    predict_result = Result()
+    predict_sid = worker.subscribe(predict_result.handle_event, tag="p1")
+    worker.predict({}, tag="p1").result()
+
+    assert predict_result.done
+    assert predict_result.output == "output"
+    assert predict_result.stdout_lines == ["running prediction\n"]
+
+    worker.unsubscribe(predict_sid)
+
+
+@uses_worker(
+    "setup_async_with_sync_predict",
+    max_concurrency=1,
+    min_python=(3, 11),
+    is_async=False,
+    setup=False,
+)
+def test_setup_async_with_sync_predict_raises_error(worker: Worker):
+    fut = worker.setup()
+    result = Result()
+    worker.subscribe(result.handle_event)
+
+    with pytest.raises(FatalWorkerException):
+        fut.result()
+    assert result.done
+    assert result.done.error
+    assert (
+        result.done.error_detail
+        == "Invalid predictor: to use an async setup method you must use an async predict method"
+    )
+
+
 @uses_worker("simple", max_concurrency=5, setup=False)
 def test_concurrency_with_sync_predictor_raises_error(worker):
     fut = worker.setup()
@@ -555,6 +606,12 @@ def test_graceful_shutdown(worker):
     assert fut.result() == Done()
 
 
+@uses_worker("async_setup_uses_same_loop_as_predict", min_python=(3, 11), is_async=True)
+def test_async_setup_uses_same_loop_as_predict(worker: Worker):
+    result = _process(worker, lambda: worker.predict({}), tag=None)
+    assert result, "Expected worker to return True to assert same event loop"
+
+
 @frozen
 class SetupState:
     fut: "Future[Done]"