Skip to content

Commit

Permalink
Support async setup method on Predictor
Browse files Browse the repository at this point in the history
This commit introduces the ability to define an async `setup` function
on your predictor. For simplicity an async `setup()` function is only
supported alongside an async `predict()` function. An error will be
raised during setup if this is not the case.

Various pieces of the code have been extracted into smaller methods
in order to achieve this. A new `_handle_setup_error` context manager
has been created to handle setup errors and send appropriate `Done`
event over the worker channel.

The `_setup()` method has been split into two phases, first we perform
validation on the requirements for async/concurrency support. Then we
attempt to run the `setup()` method either as a direct call for the
non-async path or as part of the event loop in the async path.
  • Loading branch information
Aron Carroll committed Dec 12, 2024
1 parent 08e00ec commit 541dfae
Showing 1 changed file with 102 additions and 41 deletions.
143 changes: 102 additions & 41 deletions python/cog/server/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,13 +402,27 @@ def run(self) -> None:
# it has sent a error Done event and we're done here.
if not self._predictor:
return
self._predictor.log = self._log # type: ignore

if not self._presetup_validation(redirector):
return

self._predictor.log = self._log # type: ignore
predict = get_predict(self._predictor)

if self._is_async:
assert isinstance(redirector, SimpleStreamRedirector)
self._setup(redirector)
asyncio.run(self._aloop(predict, redirector))
predictor = self._predictor

async def _runner():
if hasattr(predictor, "setup") and inspect.iscoroutinefunction(
predictor.setup
):
await self._asetup(redirector)
else:
self._setup(redirector)
await self._aloop(predict, redirector)

asyncio.run(_runner())
else:
# We use SIGUSR1 to signal an interrupt for cancelation.
signal.signal(signal.SIGUSR1, self._signal_handler)
Expand Down Expand Up @@ -458,22 +472,14 @@ def _load_predictor(self) -> Optional[BasePredictor]:

return None

def _setup(
self, redirector: Union[StreamRedirector, SimpleStreamRedirector]
) -> None:
done = Done()
try:
def _presetup_validation(
self,
redirector: Union[StreamRedirector, SimpleStreamRedirector],
) -> bool:
is_valid = False
with self._handle_setup_error(redirector):
assert self._predictor

# Could be a function or a class
if hasattr(self._predictor, "setup"):
if not has_setup_weights(predictor):
predictor.setup()
return

weights = extract_setup_weights(predictor)
predictor.setup(weights=weights) # type: ignore

predict = get_predict(self._predictor)

is_async_predictor = inspect.iscoroutinefunction(
Expand All @@ -492,31 +498,52 @@ def _setup(
"max_concurrency > 1 requires an async predict function, e.g. `async def predict()`"
)

except Exception as e: # pylint: disable=broad-exception-caught
traceback.print_exc()
done.error = True
done.error_detail = str(e)
except BaseException as e:
# For SystemExit and friends we attempt to add some useful context
# to the logs, but reraise to ensure the process dies.
traceback.print_exc()
done.error = True
done.error_detail = str(e)
raise
finally:
try:
redirector.drain(timeout=10)
except TimeoutError:
self._events.send(
Envelope(
event=Log(
"WARNING: logs may be truncated due to excessive volume.",
source="stderr",
)
)
if (
hasattr(self._predictor, "setup")
and inspect.iscoroutinefunction(self._predictor.setup)
and not is_async_predictor
):
raise FatalWorkerException(
"Invalid predictor: to use an async setup method you must use an async predict method"
)
raise
self._events.send(Envelope(event=done))

is_valid = True

return is_valid

def _setup(
self, redirector: Union[StreamRedirector, SimpleStreamRedirector]
) -> None:
with self._handle_setup_error(redirector):
assert self._predictor

# Could be a function or a class
if not hasattr(self._predictor, "setup"):
return

if not has_setup_weights(self._predictor):
self._predictor.setup()
return

weights = extract_setup_weights(self._predictor)
self._predictor.setup(weights=weights) # type: ignore

async def _asetup(
self, redirector: Union[StreamRedirector, SimpleStreamRedirector]
) -> None:
with self._handle_setup_error(redirector):
assert self._predictor

# Could be a function or a class
if not hasattr(self._predictor, "setup"):
return

if not has_setup_weights(self._predictor):
await self._predictor.setup() # type: ignore
return

weights = extract_setup_weights(self._predictor)
await self._predictor.setup(weights=weights) # type: ignore

def _loop(
self,
Expand Down Expand Up @@ -664,6 +691,40 @@ async def _apredict(
)
)

@contextlib.contextmanager
def _handle_setup_error(
self,
redirector: Union[SimpleStreamRedirector, StreamRedirector],
) -> Iterator[None]:
done = Done()
try:
yield
except Exception as e: # pylint: disable=broad-exception-caught
traceback.print_exc()
done.error = True
done.error_detail = str(e)
except BaseException as e:
# For SystemExit and friends we attempt to add some useful context
# to the logs, but reraise to ensure the process dies.
traceback.print_exc()
done.error = True
done.error_detail = str(e)
raise
finally:
try:
redirector.drain(timeout=10)
except TimeoutError:
self._events.send(
Envelope(
event=Log(
"WARNING: logs may be truncated due to excessive volume.",
source="stderr",
)
)
)
raise
self._events.send(Envelope(event=done))

@contextlib.contextmanager
def _handle_predict_error(
self,
Expand Down

0 comments on commit 541dfae

Please sign in to comment.