Skip to content

Commit

Permalink
optimize webhook serialization and logging (#1651)
Browse files: browse the repository at this point in the history
* optimize webhook serialization and logging
* optimize logging by binding structlog proxies
* fix tests

---------

Signed-off-by: technillogue <[email protected]>
  • Loading branch information
technillogue committed May 8, 2024
1 parent 12b0abe commit 08f3780
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 30 deletions.
20 changes: 14 additions & 6 deletions python/cog/server/clients.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,10 @@

import httpx
import structlog
from fastapi.encoders import jsonable_encoder

from .. import types
from ..schema import Status, WebhookEvent
from ..schema import PredictionResponse, Status, WebhookEvent
from ..types import Path
from .eventtypes import PredictionInput
from .response_throttler import ResponseThrottler
Expand Down Expand Up @@ -105,6 +106,7 @@ def __init__(self) -> None:
self.retry_webhook_client = httpx_retry_client()
self.file_client = httpx_file_client()
self.download_client = httpx.AsyncClient(follow_redirects=True, http2=True)
self.log = structlog.get_logger(__name__).bind()

async def aclose(self) -> None:
# not used but it's not actually critical to close them
Expand All @@ -119,26 +121,29 @@ async def send_webhook(
self, url: str, response: Dict[str, Any], event: WebhookEvent
) -> None:
if Status.is_terminal(response["status"]):
log.info("sending terminal webhook with status %s", response["status"])
self.log.info("sending terminal webhook with status %s", response["status"])
# For terminal updates, retry persistently
await self.retry_webhook_client.post(url, json=response)
else:
log.info("sending webhook with status %s", response["status"])
self.log.info("sending webhook with status %s", response["status"])
# For other requests, don't retry, and ignore any errors
try:
await self.webhook_client.post(url, json=response)
except httpx.RequestError:
log.warn("caught exception while sending webhook", exc_info=True)
self.log.warn("caught exception while sending webhook", exc_info=True)

def make_webhook_sender(
self, url: Optional[str], webhook_events_filter: Collection[WebhookEvent]
) -> WebhookSenderType:
throttler = ResponseThrottler(response_interval=_response_interval)

async def sender(response: Any, event: WebhookEvent) -> None:
async def sender(response: PredictionResponse, event: WebhookEvent) -> None:
if url and event in webhook_events_filter:
if throttler.should_send_response(response):
await self.send_webhook(url, response, event)
# jsonable_encoder is quite slow in context, it would be ideal
# to skip the heavy parts of this for well-known output types
dict_response = jsonable_encoder(response.dict(exclude_unset=True))
await self.send_webhook(url, dict_response, event)
throttler.update_last_sent_response_time()

return sender
Expand Down Expand Up @@ -213,6 +218,9 @@ async def upload_files(self, obj: Any, url: Optional[str]) -> Any:
Iterates through an object from make_encodeable and uploads any files.
When a file is encountered, it will be passed to upload_file. Any paths will be opened and converted to files.
"""
# skip four isinstance checks for fast text models
if type(obj) == str: # noqa: E721
return obj
# # it would be kind of cleaner to make the default file_url
# # instead of skipping entirely, we need to convert to datauri
# if url is None:
Expand Down
7 changes: 3 additions & 4 deletions python/cog/server/response_throttler.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,15 @@
import time
from typing import Any, Dict

from ..schema import Status
from ..schema import PredictionResponse, Status


class ResponseThrottler:
def __init__(self, response_interval: float) -> None:
self.last_sent_response_time = 0.0
self.response_interval = response_interval

def should_send_response(self, response: Dict[str, Any]) -> bool:
if Status.is_terminal(response["status"]):
def should_send_response(self, response: PredictionResponse) -> bool:
if Status.is_terminal(response.status):
return True

return self.seconds_since_last_response() >= self.response_interval
Expand Down
25 changes: 14 additions & 11 deletions python/cog/server/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
import httpx
import structlog
from attrs import define
from fastapi.encoders import jsonable_encoder

from .. import schema, types
from .clients import SKIP_START_EVENT, ClientManager
Expand Down Expand Up @@ -72,6 +71,9 @@ def __init__(

self.client_manager = ClientManager()

# bind logger instead of the module-level logger proxy for performance
self.log = log.bind()

def make_error_handler(self, activity: str) -> Callable[[RunnerTask], None]:
def handle_error(task: RunnerTask) -> None:
exc = task.exception()
Expand All @@ -83,7 +85,7 @@ def handle_error(task: RunnerTask) -> None:
try:
raise exc
except Exception:
log.error(f"caught exception while running {activity}", exc_info=True)
self.log.error(f"caught exception while running {activity}", exc_info=True)
if self._shutdown_event is not None:
self._shutdown_event.set()

Expand Down Expand Up @@ -121,7 +123,7 @@ def predict(
# if upload url was not set, we can respect output_file_prefix
# but maybe we should just throw an error
upload_url = request.output_file_prefix or self._upload_url
event_handler = PredictionEventHandler(request, self.client_manager, upload_url)
event_handler = PredictionEventHandler(request, self.client_manager, upload_url, self.log)
self._response = event_handler.response

prediction_input = PredictionInput.from_request(request)
Expand All @@ -143,13 +145,13 @@ async def async_predict_handling_errors() -> schema.PredictionResponse:
tb = traceback.format_exc()
await event_handler.append_logs(tb)
await event_handler.failed(error=str(e))
log.warn("failed to download url path from input", exc_info=True)
self.log.warn("failed to download url path from input", exc_info=True)
return event_handler.response
except Exception as e:
tb = traceback.format_exc()
await event_handler.append_logs(tb)
await event_handler.failed(error=str(e))
log.error("caught exception while running prediction", exc_info=True)
self.log.error("caught exception while running prediction", exc_info=True)
if self._shutdown_event is not None:
self._shutdown_event.set()
raise # we don't actually want to raise anymore but w/e
Expand Down Expand Up @@ -195,8 +197,10 @@ def __init__(
request: schema.PredictionRequest,
client_manager: ClientManager,
upload_url: Optional[str],
logger: Optional[structlog.BoundLogger] = None,
) -> None:
log.info("starting prediction")
self.logger = logger or log.bind()
self.logger.info("starting prediction")
# maybe this should be a deep copy to not share File state with child worker
self.p = schema.PredictionResponse(**request.dict())
self.p.status = schema.Status.PROCESSING
Expand Down Expand Up @@ -244,7 +248,7 @@ async def append_logs(self, logs: str) -> None:
await self._send_webhook(schema.WebhookEvent.LOGS)

async def succeeded(self) -> None:
log.info("prediction succeeded")
self.logger.info("prediction succeeded")
self.p.status = schema.Status.SUCCEEDED
self._set_completed_at()
# These have been set already: this is to convince the typechecker of
Expand All @@ -257,14 +261,14 @@ async def succeeded(self) -> None:
await self._send_webhook(schema.WebhookEvent.COMPLETED)

async def failed(self, error: str) -> None:
log.info("prediction failed", error=error)
self.logger.info("prediction failed", error=error)
self.p.status = schema.Status.FAILED
self.p.error = error
self._set_completed_at()
await self._send_webhook(schema.WebhookEvent.COMPLETED)

async def canceled(self) -> None:
log.info("prediction canceled")
self.logger.info("prediction canceled")
self.p.status = schema.Status.CANCELED
self._set_completed_at()
await self._send_webhook(schema.WebhookEvent.COMPLETED)
Expand All @@ -273,8 +277,7 @@ def _set_completed_at(self) -> None:
self.p.completed_at = datetime.now(tz=timezone.utc)

async def _send_webhook(self, event: schema.WebhookEvent) -> None:
dict_response = jsonable_encoder(self.response.dict(exclude_unset=True))
await self._webhook_sender(dict_response, event)
await self._webhook_sender(self.response, event)

async def _upload_files(self, output: Any) -> Any:
try:
Expand Down
21 changes: 12 additions & 9 deletions python/tests/server/test_response_throttler.py
Original file line number Diff line number Diff line change
@@ -1,35 +1,38 @@
import time

from cog.schema import Status
from cog.schema import PredictionResponse, Status
from cog.server.response_throttler import ResponseThrottler

processing = PredictionResponse(input={}, status=Status.PROCESSING)
succeeded = PredictionResponse(input={}, status=Status.SUCCEEDED)


def test_zero_interval():
throttler = ResponseThrottler(response_interval=0)

assert throttler.should_send_response({"status": Status.PROCESSING})
assert throttler.should_send_response(processing)
throttler.update_last_sent_response_time()
assert throttler.should_send_response({"status": Status.SUCCEEDED})
assert throttler.should_send_response(succeeded)


def test_terminal_status():
throttler = ResponseThrottler(response_interval=10)

assert throttler.should_send_response({"status": Status.PROCESSING})
assert throttler.should_send_response(processing)
throttler.update_last_sent_response_time()
assert not throttler.should_send_response({"status": Status.PROCESSING})
assert not throttler.should_send_response(processing)
throttler.update_last_sent_response_time()
assert throttler.should_send_response({"status": Status.SUCCEEDED})
assert throttler.should_send_response(succeeded)


def test_nonzero_internal():
throttler = ResponseThrottler(response_interval=0.2)

assert throttler.should_send_response({"status": Status.PROCESSING})
assert throttler.should_send_response(processing)
throttler.update_last_sent_response_time()
assert not throttler.should_send_response({"status": Status.PROCESSING})
assert not throttler.should_send_response(processing)
throttler.update_last_sent_response_time()

time.sleep(0.3)

assert throttler.should_send_response({"status": Status.PROCESSING})
assert throttler.should_send_response(processing)

0 comments on commit 08f3780

Please sign in to comment.