feat: add custom app exception handling & auto CUDA OOM exception handling (#274)

* feat: add app exception handling

* feat: add exception handlers

* feat: add tests

* refactor: use the old style

* feat: check for cuda message too

* fix: remove app start

* refactor: use app client

* fix: app client startup

* fix: remove limit on log test
badayvedat authored Aug 4, 2024
1 parent 5d6799d commit c010931
Showing 6 changed files with 173 additions and 21 deletions.
42 changes: 37 additions & 5 deletions projects/fal/src/fal/api.py
@@ -44,7 +44,13 @@
 import fal.flags as flags
 from fal._serialization import include_modules_from, patch_pickle
 from fal.container import ContainerImage
-from fal.exceptions import FalServerlessException
+from fal.exceptions import (
+    AppException,
+    CUDAOutOfMemoryException,
+    FalServerlessException,
+    FieldException,
+)
+from fal.exceptions._cuda import _is_cuda_oom_exception
 from fal.logging.isolate import IsolateLogPrinter
 from fal.sdk import (
     FAL_SERVERLESS_DEFAULT_KEEP_ALIVE,
@@ -1002,13 +1008,39 @@ async def not_found_exception_handler(request: Request, exc: HTTPException):
             # If it's not a generic 404, just return the original message.
             return JSONResponse({"detail": exc.detail}, 404)
 
+        @_app.exception_handler(AppException)
+        async def app_exception_handler(request: Request, exc: AppException):
+            return JSONResponse({"detail": exc.message}, exc.status_code)
+
+        @_app.exception_handler(FieldException)
+        async def field_exception_handler(request: Request, exc: FieldException):
+            return JSONResponse(exc.to_pydantic_format(), exc.status_code)
+
+        @_app.exception_handler(CUDAOutOfMemoryException)
+        async def cuda_out_of_memory_exception_handler(
+            request: Request, exc: CUDAOutOfMemoryException
+        ):
+            return JSONResponse({"detail": exc.message}, exc.status_code)
+
         @_app.exception_handler(Exception)
         async def traceback_logging_exception_handler(request: Request, exc: Exception):
-            print(
-                json.dumps(
-                    {"traceback": "".join(traceback.format_exception(exc)[::-1])}  # type: ignore
+            _, MINOR, *_ = sys.version_info
+
+            # traceback.format_exception() has a different signature in Python >=3.10
+            if MINOR >= 10:
+                formatted_exception = traceback.format_exception(exc)  # type: ignore
+            else:
+                formatted_exception = traceback.format_exception(
+                    type(exc), exc, exc.__traceback__
                 )
-            )
+
+            print(json.dumps({"traceback": "".join(formatted_exception[::-1])}))
+
+            if _is_cuda_oom_exception(exc):
+                return await cuda_out_of_memory_exception_handler(
+                    request, CUDAOutOfMemoryException()
+                )
+
             return JSONResponse({"detail": "Internal Server Error"}, 500)
 
         routes = self.collect_routes()
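
The handlers above follow the standard FastAPI/Starlette exception-handler pattern: a domain exception carries its own status code and message, and the catch-all Exception handler downgrades anything that looks like CUDA OOM to a 503 instead of a generic 500. A minimal, self-contained sketch of the same idea (DemoAppException, demo_app, and the /boom route are illustrative names, not part of this commit):

    from dataclasses import dataclass

    from fastapi import FastAPI, Request
    from fastapi.responses import JSONResponse


    @dataclass
    class DemoAppException(Exception):  # stand-in for fal.exceptions.AppException
        message: str
        status_code: int


    demo_app = FastAPI()


    @demo_app.exception_handler(DemoAppException)
    async def demo_app_exception_handler(request: Request, exc: DemoAppException):
        # The raised exception decides the status code and message the client sees.
        return JSONResponse({"detail": exc.message}, status_code=exc.status_code)


    @demo_app.exception_handler(Exception)
    async def demo_fallback_handler(request: Request, exc: Exception):
        # Unexpected errors stay a generic 500, unless they look like CUDA OOM.
        if isinstance(exc, RuntimeError) and "out of memory" in str(exc):
            return JSONResponse({"detail": "CUDA error: out of memory"}, status_code=503)
        return JSONResponse({"detail": "Internal Server Error"}, status_code=500)


    @demo_app.post("/boom")
    def boom():
        raise DemoAppException(message="quota exceeded", status_code=429)
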
5 changes: 2 additions & 3 deletions projects/fal/src/fal/app.py
@@ -142,15 +142,14 @@ def _print_logs():
     try:
         with httpx.Client() as client:
             retries = 100
-            while retries:
+            for _ in range(retries):
                 resp = client.get(info.url + "/health")
 
                 if resp.is_success:
                     break
-                elif resp.status_code != 500:
+                elif resp.status_code not in (500, 404):
                     resp.raise_for_status()
                 time.sleep(0.1)
-                retries -= 1
 
         client = cls(app_cls, info.url)
         yield client
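
The startup wait is now a bounded poll that also tolerates 404s while routes are still being registered. Roughly the pattern, as a standalone sketch (wait_until_healthy is a hypothetical helper, not fal's actual code):

    import time

    import httpx


    def wait_until_healthy(base_url: str, retries: int = 100, delay: float = 0.1) -> bool:
        """Poll base_url + '/health' until it responds successfully or retries run out."""
        with httpx.Client() as client:
            for _ in range(retries):
                resp = client.get(base_url + "/health")
                if resp.is_success:
                    return True
                # 500 and 404 are treated as "still starting up"; anything else is fatal.
                if resp.status_code not in (500, 404):
                    resp.raise_for_status()
                time.sleep(delay)
        return False
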
3 changes: 2 additions & 1 deletion projects/fal/src/fal/exceptions/__init__.py
@@ -1,3 +1,4 @@
 from __future__ import annotations
 
-from ._base import FalServerlessException  # noqa: F401
+from ._base import AppException, FalServerlessException, FieldException  # noqa: F401
+from ._cuda import CUDAOutOfMemoryException  # noqa: F401
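
With these re-exports, an app author can raise the new exceptions directly from an endpoint to control the HTTP response. A sketch modeled on the test app further down (MyApp, Input, Output, and the /generate route are illustrative, not part of this commit):

    import fal
    from fal.exceptions import AppException, FieldException
    from pydantic import BaseModel


    class Input(BaseModel):
        prompt: str = ""


    class Output(BaseModel):
        text: str


    class MyApp(fal.App, keep_alive=300):
        machine_type = "XS"

        @fal.endpoint("/generate")
        def generate(self, input: Input) -> Output:
            if not input.prompt:
                # Returned to the caller as a pydantic-style 422 validation error.
                raise FieldException(field="prompt", message="prompt must not be empty")
            if len(input.prompt) > 1_000:
                # Returned to the caller as {"detail": "prompt too long"} with status 400.
                raise AppException(message="prompt too long", status_code=400)
            return Output(text=input.prompt.upper())
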
43 changes: 43 additions & 0 deletions projects/fal/src/fal/exceptions/_base.py
@@ -1,7 +1,50 @@
 from __future__ import annotations
 
+from dataclasses import dataclass
+from typing import Sequence
+
 
 class FalServerlessException(Exception):
     """Base exception type for fal Serverless related flows and APIs."""
 
     pass
+
+
+@dataclass
+class AppException(FalServerlessException):
+    """
+    Base exception class for application-specific errors.
+
+    Attributes:
+        message: A descriptive message explaining the error.
+        status_code: The HTTP status code associated with the error.
+    """
+
+    message: str
+    status_code: int
+
+
+@dataclass
+class FieldException(FalServerlessException):
+    """Exception raised for errors related to specific fields.
+
+    Attributes:
+        field: The field that caused the error.
+        message: A descriptive message explaining the error.
+        status_code: The HTTP status code associated with the error. Defaults to 422
+        type: The type of error. Defaults to "value_error"
+    """
+
+    field: str
+    message: str
+    status_code: int = 422
+    type: str = "value_error"
+
+    def to_pydantic_format(self) -> Sequence[dict]:
+        return [
+            {
+                "loc": ["body", self.field],
+                "msg": self.message,
+                "type": self.type,
+            }
+        ]
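
to_pydantic_format() mirrors FastAPI's request-validation error shape, so a FieldException raised from app code renders like an ordinary 422 validation error. A quick illustration, assuming the class above:

    from fal.exceptions import FieldException

    exc = FieldException(field="rhs", message="rhs must be an integer")
    print(exc.to_pydantic_format())
    # [{'loc': ['body', 'rhs'], 'msg': 'rhs must be an integer', 'type': 'value_error'}]
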
44 changes: 44 additions & 0 deletions projects/fal/src/fal/exceptions/_cuda.py
@@ -0,0 +1,44 @@
+from __future__ import annotations
+
+from dataclasses import dataclass
+
+from ._base import AppException
+
+# PyTorch error message for out of memory
+_CUDA_OOM_MESSAGE = "CUDA error: out of memory"
+
+# Special status code for CUDA out of memory errors
+_CUDA_OOM_STATUS_CODE = 503
+
+
+@dataclass
+class CUDAOutOfMemoryException(AppException):
+    """Exception raised when a CUDA operation runs out of memory."""
+
+    message: str = _CUDA_OOM_MESSAGE
+    status_code: int = _CUDA_OOM_STATUS_CODE
+
+
+# based on https://github.com/Lightning-AI/pytorch-lightning/blob/37e04d075a5532c69b8ac7457795b4345cca30cc/src/lightning/pytorch/utilities/memory.py#L49
+def _is_cuda_oom_exception(exception: BaseException) -> bool:
+    return _is_cuda_out_of_memory(exception) or _is_cudnn_snafu(exception)
+
+
+# based on https://github.com/BlackHC/toma/blob/master/toma/torch_cuda_memory.py
+def _is_cuda_out_of_memory(exception: BaseException) -> bool:
+    return (
+        isinstance(exception, RuntimeError)
+        and len(exception.args) == 1
+        and "CUDA" in exception.args[0]
+        and "out of memory" in exception.args[0]
+    )
+
+
+# based on https://github.com/BlackHC/toma/blob/master/toma/torch_cuda_memory.py
+def _is_cudnn_snafu(exception: BaseException) -> bool:
+    # For/because of https://github.com/pytorch/pytorch/issues/4107
+    return (
+        isinstance(exception, RuntimeError)
+        and len(exception.args) == 1
+        and "cuDNN error: CUDNN_STATUS_NOT_SUPPORTED." in exception.args[0]
+    )
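
The detection is purely type- and string-based, so it needs no torch import. A few illustrative checks against the helpers above:

    from fal.exceptions._cuda import _is_cuda_oom_exception

    assert _is_cuda_oom_exception(RuntimeError("CUDA error: out of memory"))
    assert _is_cuda_oom_exception(RuntimeError("CUDA out of memory"))  # as raised by the test app below
    assert _is_cuda_oom_exception(RuntimeError("cuDNN error: CUDNN_STATUS_NOT_SUPPORTED."))
    assert not _is_cuda_oom_exception(ValueError("CUDA error: out of memory"))  # not a RuntimeError
    assert not _is_cuda_oom_exception(RuntimeError("device-side assert triggered"))
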
57 changes: 45 additions & 12 deletions projects/fal/tests/test_apps.py
@@ -10,8 +10,11 @@
 import httpx
 import pytest
 from fal import apps
+from fal.app import AppClient
 from fal.cli.deploy import _get_user
 from fal.container import ContainerImage
+from fal.exceptions import AppException, FieldException
+from fal.exceptions._cuda import _CUDA_OOM_MESSAGE, _CUDA_OOM_STATUS_CODE
 from fal.rest_client import REST_CLIENT
 from fal.workflows import Workflow
 from fastapi import WebSocket
@@ -139,9 +142,25 @@ class ExceptionApp(fal.App, keep_alive=300, max_concurrency=1):
     machine_type = "XS"
 
     @fal.endpoint("/fail")
-    def reset(self) -> Output:
+    def fail(self) -> Output:
         raise Exception("this app is designed to fail!")
 
+    @fal.endpoint("/app-exception")
+    def app_exception(self) -> Output:
+        raise AppException(message="this app is designed to fail", status_code=401)
+
+    @fal.endpoint("/field-exception")
+    def field_exception(self, input: Input) -> Output:
+        raise FieldException(
+            field="rhs",
+            message="rhs must be an integer",
+        )
+
+    @fal.endpoint("/cuda-exception")
+    def cuda_exception(self) -> Output:
+        # mimicking error message from PyTorch (https://github.com/pytorch/pytorch/blob/6c65fd03942415b68040e102c44cf5109d2d851e/c10/cuda/CUDACachingAllocator.cpp#L1234C12-L1234C30)
+        raise RuntimeError("CUDA out of memory")
+
 
 class RTInput(BaseModel):
     prompt: str
@@ -270,14 +289,8 @@ def test_stateful_app():
 @pytest.fixture(scope="module")
 def test_exception_app():
     # Create a temporary app, register it, and return the ID of it.
-
-    app = fal.wrap_app(ExceptionApp)
-    app_revision = app.host.register(
-        func=app.func,
-        options=app.options,
-    )
-    user = _get_user()
-    yield f"{user.user_id}/{app_revision}"
+    with AppClient.connect(ExceptionApp) as client:
+        yield client
 
 
 @pytest.fixture(scope="module")
@@ -592,10 +605,12 @@ def test_workflows(test_app: str):
     assert data["result"] == 10
 
 
-def test_traceback_logs(test_exception_app: str):
+def test_traceback_logs(test_exception_app: AppClient):
     date = datetime.utcnow().isoformat()
+
     with pytest.raises(HTTPStatusError):
-        apps.run(test_exception_app, arguments={}, path="/fail")
+        test_exception_app.fail({})
+
     with httpx.Client(
         base_url=REST_CLIENT.base_url,
         headers=REST_CLIENT.get_headers(),
@@ -604,11 +619,29 @@ def test_traceback_logs(test_exception_app: str):
         # Give some time for logs to propagate through the logging subsystem.
         time.sleep(5)
         response = client.get(
-            REST_CLIENT.base_url + f"/logs/?traceback=true&limit=10&since={date}"
+            REST_CLIENT.base_url + f"/logs/?traceback=true&since={date}"
         )
         for log in json.loads(response.text):
             assert log["message"].count("\n") > 1, "Logs are multi-line"
             assert '{"traceback":' not in log["message"], "Logs are not JSON-wrapped"
             assert (
                 "this app is designed to fail" in log["message"]
             ), "Logs contain the traceback message"
+
+
+def test_app_exceptions(test_exception_app: AppClient):
+    with pytest.raises(HTTPStatusError) as app_exc:
+        test_exception_app.app_exception({})
+
+    assert app_exc.value.response.status_code == 401
+
+    with pytest.raises(HTTPStatusError) as field_exc:
+        test_exception_app.field_exception({"lhs": 1, "rhs": "2"})
+
+    assert field_exc.value.response.status_code == 422
+
+    with pytest.raises(HTTPStatusError) as cuda_exc:
+        test_exception_app.cuda_exception({})
+
+    assert cuda_exc.value.response.status_code == _CUDA_OOM_STATUS_CODE
+    assert cuda_exc.value.response.json()["detail"] == _CUDA_OOM_MESSAGE
