Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Async setup #2089

Merged
merged 4 commits into from
Dec 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 3 additions & 8 deletions python/cog/base_predictor.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,13 @@
from abc import ABC, abstractmethod
from typing import Any, Optional, Union
from typing import Any, Optional

from .types import (
File as CogFile,
)
from .types import (
Path as CogPath,
)
from .types import Weights


class BasePredictor(ABC):
def setup(
self,
weights: Optional[Union[CogFile, CogPath, str]] = None, # pylint: disable=unused-argument
weights: Optional[Weights] = None, # pylint: disable=unused-argument
) -> None:
"""
An optional method to prepare the model so multiple predictions run efficiently.
Expand Down
18 changes: 9 additions & 9 deletions python/cog/predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,11 @@
import enum
import importlib.util
import inspect
import io
import os.path
import sys
import types
import uuid
from collections.abc import Iterable, Iterator
from pathlib import Path
from typing import (
Any,
Callable,
Expand Down Expand Up @@ -38,6 +36,7 @@
from .types import (
PYDANTIC_V2,
Input,
Weights,
)
from .types import (
File as CogFile,
Expand All @@ -60,15 +59,16 @@
]


def run_setup(predictor: BasePredictor) -> None:
def has_setup_weights(predictor: BasePredictor) -> bool:
    """Return True when ``get_weights_type`` reports a weights type for ``predictor.setup``.

    A ``None`` result from ``get_weights_type`` means there is nothing to pass,
    so ``setup()`` can be invoked without arguments.
    """
    return get_weights_type(predictor.setup) is not None

# No weights need to be passed, so just run setup() without any arguments.
if weights_type is None:
predictor.setup()
return

weights: Union[io.IOBase, Path, str, None]
def extract_setup_weights(predictor: BasePredictor) -> Optional[Weights]:
weights_type = get_weights_type(predictor.setup)
assert weights_type

weights: Optional[Weights]

weights_url = os.environ.get("COG_WEIGHTS")
weights_path = "weights"
Expand Down Expand Up @@ -119,7 +119,7 @@ def run_setup(predictor: BasePredictor) -> None:
else:
weights = None

predictor.setup(weights=weights) # type: ignore
return weights


def get_weights_type(setup_function: Callable[[Any], None]) -> Optional[Any]:
Expand Down
164 changes: 116 additions & 48 deletions python/cog/server/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,12 @@

from ..base_predictor import BasePredictor
from ..json import make_encodeable
from ..predictor import get_predict, load_predictor_from_ref, run_setup
from ..predictor import (
extract_setup_weights,
get_predict,
has_setup_weights,
load_predictor_from_ref,
)
from ..types import PYDANTIC_V2, URLPath
from ..wait import wait_for_env
from .connection import AsyncConnection, LockedConnection
Expand Down Expand Up @@ -366,7 +371,7 @@ def __init__(

# for synchronous predictors only! async predictors use _tag_var instead
self._sync_tag: Optional[str] = None
self._is_async = is_async
self._has_async_predictor = is_async

super().__init__()

Expand All @@ -379,7 +384,7 @@ def run(self) -> None:
# Initially, we ignore SIGUSR1.
signal.signal(signal.SIGUSR1, signal.SIG_IGN)

if self._is_async:
if self._has_async_predictor:
redirector = SimpleStreamRedirector(
callback=self._stream_write_hook,
tee=self._tee_output,
Expand All @@ -397,13 +402,27 @@ def run(self) -> None:
# it has sent a error Done event and we're done here.
if not self._predictor:
return
self._predictor.log = self._log # type: ignore

if not self._validate_predictor(redirector):
return

self._predictor.log = self._log # type: ignore
predict = get_predict(self._predictor)
if self._is_async:

if self._has_async_predictor:
assert isinstance(redirector, SimpleStreamRedirector)
self._setup(redirector)
asyncio.run(self._aloop(predict, redirector))
predictor = self._predictor

async def _runner() -> None:
if hasattr(predictor, "setup") and inspect.iscoroutinefunction(
predictor.setup
):
await self._asetup(redirector)
else:
self._setup(redirector)
await self._aloop(predict, redirector)

asyncio.run(_runner())
else:
# We use SIGUSR1 to signal an interrupt for cancelation.
signal.signal(signal.SIGUSR1, self._signal_handler)
Expand Down Expand Up @@ -453,60 +472,71 @@ def _load_predictor(self) -> Optional[BasePredictor]:

return None

def _setup(
self, redirector: Union[StreamRedirector, SimpleStreamRedirector]
) -> None:
done = Done()
try:
def _validate_predictor(
self,
redirector: Union[StreamRedirector, SimpleStreamRedirector],
) -> bool:
with self._handle_setup_error(redirector):
assert self._predictor

# Could be a function or a class
if hasattr(self._predictor, "setup"):
run_setup(self._predictor)

predict = get_predict(self._predictor)

is_async_predictor = inspect.iscoroutinefunction(
predict
) or inspect.isasyncgenfunction(predict)

# Async models require python >= 3.11 so we can use asyncio.TaskGroup
# We should check for this before getting to this point
if is_async_predictor and sys.version_info < (3, 11):
if self._has_async_predictor and sys.version_info < (3, 11):
raise FatalWorkerException(
"Cog requires Python >=3.11 for `async def predict()` support"
)

if self._max_concurrency > 1 and not is_async_predictor:
if self._max_concurrency > 1 and not self._has_async_predictor:
raise FatalWorkerException(
"max_concurrency > 1 requires an async predict function, e.g. `async def predict()`"
)

except Exception as e: # pylint: disable=broad-exception-caught
traceback.print_exc()
done.error = True
done.error_detail = str(e)
except BaseException as e:
# For SystemExit and friends we attempt to add some useful context
# to the logs, but reraise to ensure the process dies.
traceback.print_exc()
done.error = True
done.error_detail = str(e)
raise
finally:
try:
redirector.drain(timeout=10)
except TimeoutError:
self._events.send(
Envelope(
event=Log(
"WARNING: logs may be truncated due to excessive volume.",
source="stderr",
)
)
if (
hasattr(self._predictor, "setup")
and inspect.iscoroutinefunction(self._predictor.setup)
and not self._has_async_predictor
):
raise FatalWorkerException(
"Invalid predictor: to use an async setup method you must use an async predict method"
)
raise
self._events.send(Envelope(event=done))

return True

return False

def _setup(
    self, redirector: Union[StreamRedirector, SimpleStreamRedirector]
) -> None:
    """Invoke the predictor's synchronous ``setup``, if it has one.

    The call runs inside ``_handle_setup_error`` with ``ensure_done_event=True``,
    so that helper is asked to emit a Done event on success as well as on failure.
    """
    with self._handle_setup_error(redirector, ensure_done_event=True):
        assert self._predictor
        target = self._predictor

        # Could be a function or a class
        if hasattr(target, "setup"):
            if has_setup_weights(target):
                # Resolve the weights first, then hand them to setup().
                weights = extract_setup_weights(target)
                target.setup(weights=weights)  # type: ignore
            else:
                target.setup()

async def _asetup(
    self, redirector: Union[StreamRedirector, SimpleStreamRedirector]
) -> None:
    """Await the predictor's coroutine ``setup``, if it has one.

    Coroutine counterpart of ``_setup``: the same steps, but the ``setup`` call
    is awaited. It runs inside ``_handle_setup_error`` with
    ``ensure_done_event=True``, so that helper is asked to emit a Done event on
    success as well as on failure.
    """
    with self._handle_setup_error(redirector, ensure_done_event=True):
        assert self._predictor
        target = self._predictor

        # Could be a function or a class
        if hasattr(target, "setup"):
            if has_setup_weights(target):
                # Resolve the weights first, then hand them to setup().
                weights = extract_setup_weights(target)
                await target.setup(weights=weights)  # type: ignore
            else:
                await target.setup()  # type: ignore

def _loop(
self,
Expand Down Expand Up @@ -654,6 +684,44 @@ async def _apredict(
)
)

@contextlib.contextmanager
def _handle_setup_error(
    self,
    redirector: Union[SimpleStreamRedirector, StreamRedirector],
    *,
    ensure_done_event: bool = False,
) -> Iterator[None]:
    """Wrap a setup step: record failures, drain logs, and emit a Done event.

    Any exception raised by the wrapped body has its traceback printed and is
    recorded on a ``Done`` event. ``Exception`` subclasses are then suppressed;
    other ``BaseException`` types are re-raised.

    On exit the redirector is always drained with a 10 second timeout. If the
    drain times out, a warning ``Log`` event is sent and the ``TimeoutError``
    propagates. Otherwise the ``Done`` event is sent when an error was recorded
    or when ``ensure_done_event`` is true.
    """
    outcome = Done()
    try:
        yield
    except BaseException as exc:  # pylint: disable=broad-exception-caught
        traceback.print_exc()
        outcome.error = True
        outcome.error_detail = str(exc)
        if not isinstance(exc, Exception):
            # For SystemExit and friends we attempt to add some useful context
            # to the logs, but reraise to ensure the process dies.
            raise
    finally:
        try:
            redirector.drain(timeout=10)
        except TimeoutError:
            truncation_notice = Log(
                "WARNING: logs may be truncated due to excessive volume.",
                source="stderr",
            )
            self._events.send(Envelope(event=truncation_notice))
            raise

        should_report = outcome.error or ensure_done_event
        if should_report:
            self._events.send(Envelope(event=outcome))

@contextlib.contextmanager
def _handle_predict_error(
self,
Expand Down
3 changes: 3 additions & 0 deletions python/cog/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -521,6 +521,9 @@ def __get_validators__(cls) -> Iterator[Any]:
yield cls.validate


Weights = Union[File, Path, str]


def get_filename_from_urlopen(resp: urllib.response.addinfourl) -> str:
mime_type = resp.headers.get_content_type()
extension = mimetypes.guess_extension(mime_type)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import asyncio


class Predictor:
async def setup(self) -> None:
self.loop = asyncio.get_running_loop()
Expand Down
12 changes: 12 additions & 0 deletions python/tests/server/fixtures/setup_async.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
class Predictor:
    """Test fixture whose ``setup`` and ``predict`` are both coroutines."""

    async def setup(self) -> None:
        """Print progress markers around an awaited ``download`` call."""
        print("setup starting...")
        await self.download()
        print("setup complete!")

    async def download(self) -> None:
        """Print a completion message; awaited from ``setup``."""
        print("download complete!")

    async def predict(self) -> str:
        """Print a marker line and return the fixed string ``"output"``."""
        print("running prediction")
        return "output"
9 changes: 9 additions & 0 deletions python/tests/server/fixtures/setup_async_with_sync_predict.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
class Predictor:
    """Test fixture pairing a coroutine ``setup`` with a synchronous ``predict``."""

    async def setup(self) -> None:
        """Await ``download``."""
        await self.download()

    async def download(self) -> None:
        """Print a single message; awaited from ``setup``."""
        print("setup used asyncio.run! it's not very effective...")

    def predict(self) -> str:
        """Return the fixed string ``"output"`` (plain function, not a coroutine)."""
        return "output"
Loading
Loading