diff --git a/python/cog/server/clients.py b/python/cog/server/clients.py index 14725271dd..068ed95d3c 100644 --- a/python/cog/server/clients.py +++ b/python/cog/server/clients.py @@ -7,9 +7,10 @@ import httpx import structlog +from fastapi.encoders import jsonable_encoder from .. import types -from ..schema import Status, WebhookEvent +from ..schema import PredictionResponse, Status, WebhookEvent from ..types import Path from .eventtypes import PredictionInput from .response_throttler import ResponseThrottler @@ -105,6 +106,7 @@ def __init__(self) -> None: self.retry_webhook_client = httpx_retry_client() self.file_client = httpx_file_client() self.download_client = httpx.AsyncClient(follow_redirects=True, http2=True) + self.log = structlog.get_logger(__name__).bind() async def aclose(self) -> None: # not used but it's not actually critical to close them @@ -119,26 +121,29 @@ async def send_webhook( self, url: str, response: Dict[str, Any], event: WebhookEvent ) -> None: if Status.is_terminal(response["status"]): - log.info("sending terminal webhook with status %s", response["status"]) + self.log.info("sending terminal webhook with status %s", response["status"]) # For terminal updates, retry persistently await self.retry_webhook_client.post(url, json=response) else: - log.info("sending webhook with status %s", response["status"]) + self.log.info("sending webhook with status %s", response["status"]) # For other requests, don't retry, and ignore any errors try: await self.webhook_client.post(url, json=response) except httpx.RequestError: - log.warn("caught exception while sending webhook", exc_info=True) + self.log.warn("caught exception while sending webhook", exc_info=True) def make_webhook_sender( self, url: Optional[str], webhook_events_filter: Collection[WebhookEvent] ) -> WebhookSenderType: throttler = ResponseThrottler(response_interval=_response_interval) - async def sender(response: Any, event: WebhookEvent) -> None: + async def sender(response: PredictionResponse, event: WebhookEvent) -> None: if url and event in webhook_events_filter: if throttler.should_send_response(response): - await self.send_webhook(url, response, event) + # jsonable_encoder is quite slow in context, it would be ideal + # to skip the heavy parts of this for well-known output types + dict_response = jsonable_encoder(response.dict(exclude_unset=True)) + await self.send_webhook(url, dict_response, event) throttler.update_last_sent_response_time() return sender @@ -213,6 +218,9 @@ async def upload_files(self, obj: Any, url: Optional[str]) -> Any: Iterates through an object from make_encodeable and uploads any files. When a file is encountered, it will be passed to upload_file. Any paths will be opened and converted to files. """ + # skip four isinstance checks for fast text models + if type(obj) == str: # noqa: E721 + return obj # # it would be kind of cleaner to make the default file_url # # instead of skipping entirely, we need to convert to datauri # if url is None: diff --git a/python/cog/server/response_throttler.py b/python/cog/server/response_throttler.py index 41e2ed0312..64c188c667 100644 --- a/python/cog/server/response_throttler.py +++ b/python/cog/server/response_throttler.py @@ -1,7 +1,6 @@ import time -from typing import Any, Dict -from ..schema import Status +from ..schema import PredictionResponse, Status class ResponseThrottler: @@ -9,8 +8,8 @@ def __init__(self, response_interval: float) -> None: self.last_sent_response_time = 0.0 self.response_interval = response_interval - def should_send_response(self, response: Dict[str, Any]) -> bool: - if Status.is_terminal(response["status"]): + def should_send_response(self, response: PredictionResponse) -> bool: + if Status.is_terminal(response.status): return True return self.seconds_since_last_response() >= self.response_interval diff --git a/python/cog/server/runner.py b/python/cog/server/runner.py index 654495f13e..0a65db8bb4 100644 --- a/python/cog/server/runner.py +++ b/python/cog/server/runner.py @@ -9,7 +9,6 @@ import httpx import structlog from attrs import define -from fastapi.encoders import jsonable_encoder from .. import schema, types from .clients import SKIP_START_EVENT, ClientManager @@ -72,6 +71,9 @@ def __init__( self.client_manager = ClientManager() + # bind logger instead of the module-level logger proxy for performance + self.log = log.bind() + def make_error_handler(self, activity: str) -> Callable[[RunnerTask], None]: def handle_error(task: RunnerTask) -> None: exc = task.exception() @@ -83,7 +85,7 @@ def handle_error(task: RunnerTask) -> None: try: raise exc except Exception: - log.error(f"caught exception while running {activity}", exc_info=True) + self.log.error(f"caught exception while running {activity}", exc_info=True) if self._shutdown_event is not None: self._shutdown_event.set() @@ -121,7 +123,7 @@ def predict( # if upload url was not set, we can respect output_file_prefix # but maybe we should just throw an error upload_url = request.output_file_prefix or self._upload_url - event_handler = PredictionEventHandler(request, self.client_manager, upload_url) + event_handler = PredictionEventHandler(request, self.client_manager, upload_url, self.log) self._response = event_handler.response prediction_input = PredictionInput.from_request(request) @@ -143,13 +145,13 @@ async def async_predict_handling_errors() -> schema.PredictionResponse: tb = traceback.format_exc() await event_handler.append_logs(tb) await event_handler.failed(error=str(e)) - log.warn("failed to download url path from input", exc_info=True) + self.log.warn("failed to download url path from input", exc_info=True) return event_handler.response except Exception as e: tb = traceback.format_exc() await event_handler.append_logs(tb) await event_handler.failed(error=str(e)) - log.error("caught exception while running prediction", exc_info=True) + self.log.error("caught exception while running prediction", exc_info=True) if self._shutdown_event is not None: self._shutdown_event.set() raise # we don't actually want to raise anymore but w/e @@ -195,8 +197,10 @@ def __init__( request: schema.PredictionRequest, client_manager: ClientManager, upload_url: Optional[str], + logger: Optional[structlog.BoundLogger] = None, ) -> None: - log.info("starting prediction") + self.logger = logger or log.bind() + self.logger.info("starting prediction") # maybe this should be a deep copy to not share File state with child worker self.p = schema.PredictionResponse(**request.dict()) self.p.status = schema.Status.PROCESSING @@ -244,7 +248,7 @@ async def append_logs(self, logs: str) -> None: await self._send_webhook(schema.WebhookEvent.LOGS) async def succeeded(self) -> None: - log.info("prediction succeeded") + self.logger.info("prediction succeeded") self.p.status = schema.Status.SUCCEEDED self._set_completed_at() # These have been set already: this is to convince the typechecker of @@ -257,14 +261,14 @@ async def succeeded(self) -> None: await self._send_webhook(schema.WebhookEvent.COMPLETED) async def failed(self, error: str) -> None: - log.info("prediction failed", error=error) + self.logger.info("prediction failed", error=error) self.p.status = schema.Status.FAILED self.p.error = error self._set_completed_at() await self._send_webhook(schema.WebhookEvent.COMPLETED) async def canceled(self) -> None: - log.info("prediction canceled") + self.logger.info("prediction canceled") self.p.status = schema.Status.CANCELED self._set_completed_at() await self._send_webhook(schema.WebhookEvent.COMPLETED) @@ -273,8 +277,7 @@ def _set_completed_at(self) -> None: self.p.completed_at = datetime.now(tz=timezone.utc) async def _send_webhook(self, event: schema.WebhookEvent) -> None: - dict_response = jsonable_encoder(self.response.dict(exclude_unset=True)) - await self._webhook_sender(dict_response, event) + await self._webhook_sender(self.response, event) async def _upload_files(self, output: Any) -> Any: try: diff --git a/python/tests/server/test_response_throttler.py b/python/tests/server/test_response_throttler.py index 99985d337a..34ab339a48 100644 --- a/python/tests/server/test_response_throttler.py +++ b/python/tests/server/test_response_throttler.py @@ -1,35 +1,38 @@ import time -from cog.schema import Status +from cog.schema import PredictionResponse, Status from cog.server.response_throttler import ResponseThrottler +processing = PredictionResponse(input={}, status=Status.PROCESSING) +succeeded = PredictionResponse(input={}, status=Status.SUCCEEDED) + def test_zero_interval(): throttler = ResponseThrottler(response_interval=0) - assert throttler.should_send_response({"status": Status.PROCESSING}) + assert throttler.should_send_response(processing) throttler.update_last_sent_response_time() - assert throttler.should_send_response({"status": Status.SUCCEEDED}) + assert throttler.should_send_response(succeeded) def test_terminal_status(): throttler = ResponseThrottler(response_interval=10) - assert throttler.should_send_response({"status": Status.PROCESSING}) + assert throttler.should_send_response(processing) throttler.update_last_sent_response_time() - assert not throttler.should_send_response({"status": Status.PROCESSING}) + assert not throttler.should_send_response(processing) throttler.update_last_sent_response_time() - assert throttler.should_send_response({"status": Status.SUCCEEDED}) + assert throttler.should_send_response(succeeded) def test_nonzero_internal(): throttler = ResponseThrottler(response_interval=0.2) - assert throttler.should_send_response({"status": Status.PROCESSING}) + assert throttler.should_send_response(processing) throttler.update_last_sent_response_time() - assert not throttler.should_send_response({"status": Status.PROCESSING}) + assert not throttler.should_send_response(processing) throttler.update_last_sent_response_time() time.sleep(0.3) - assert throttler.should_send_response({"status": Status.PROCESSING}) + assert throttler.should_send_response(processing)