Skip to content

Commit

Permalink
[SDESK-7287] Support async tasks in Celery (superdesk#2658)
Browse files Browse the repository at this point in the history
* Implement hybrid async/sync Task class

SDESK-7287

* Improve task error handling and add tests

SDESK-7287

* Rename module

SDESK-7287

* Minor improvements to await task's result

SDESK-7287

* Add a couple of TODOs

SDESK-7287
  • Loading branch information
eos87 authored Aug 13, 2024
1 parent 966d380 commit 3ceeb01
Show file tree
Hide file tree
Showing 9 changed files with 341 additions and 144 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -38,3 +38,4 @@ _build
_build_html
__pycache__
.eggs
.vscode
95 changes: 11 additions & 84 deletions superdesk/celery_app.py → superdesk/celery_app/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,102 +2,29 @@
#
# This file is part of Superdesk.
#
# Copyright 2013, 2014 Sourcefabric z.u. and contributors.
# Copyright 2013 to present Sourcefabric z.u. and contributors.
#
# For the full copyright and license information, please see the
# AUTHORS and LICENSE files distributed with this source code, or
# at https://www.sourcefabric.org/superdesk/license

"""
Created on May 29, 2014
@author: ioan
"""

import redis
import arrow
import werkzeug
import superdesk
from bson import ObjectId
from celery import Celery
from kombu.serialization import register
from eve.io.mongo import MongoJSONEncoder
from eve.utils import str_to_date
from superdesk.core import json, get_current_app, get_app_config
from superdesk.errors import SuperdeskError
from superdesk.logging import logger


celery = Celery(__name__)
TaskBase = celery.Task


def try_cast(v):
    """Convert ``v`` to a timezone-aware datetime or an ``ObjectId`` when it looks like one.

    ``None``, booleans and zero are returned untouched because arrow would otherwise
    turn ``False`` and ``0`` into datetimes. Any value that matches neither the date
    format nor an ObjectId is returned unchanged.
    """
    if v is None or isinstance(v, bool) or v == 0:
        return v

    # First attempt: a date string in the expected format.
    try:
        str_to_date(v)  # only used to check that the value matches the format
        return arrow.get(v).datetime  # timezone aware time
    except Exception:
        pass

    # Second attempt: an ObjectId; otherwise give the value back as-is.
    try:
        return ObjectId(v)
    except Exception:
        return v


def dumps(o):
    """Encode ``o`` to a JSON string with ``MongoJSONEncoder`` inside the current app's context."""
    app = get_current_app()
    with app.app_context():
        encoder = MongoJSONEncoder()
        return encoder.encode(o)

from .context_task import HybridAppContextTask
from .serializer import CELERY_SERIALIZER_NAME, ContextAwareSerializerFactory

def loads(s):
    """Parse the JSON string ``s`` and post-process the result with ``serialize``.

    The JSON parsing happens first; the post-processing runs inside the current
    app's context.
    """
    parsed = json.loads(s)
    app = get_current_app()
    with app.app_context():
        return serialize(parsed)


def serialize(o):
    """Recursively post-process a decoded JSON value.

    Lists and dicts are walked recursively and every leaf value is passed through
    ``try_cast``. A ``"kwargs"`` entry that arrives as a JSON-encoded string is
    decoded before being processed.

    Args:
        o: The decoded JSON value (list, dict or scalar).

    Returns:
        A new structure with the same shape; the input is never modified.
    """
    if isinstance(o, list):
        return [serialize(item) for item in o]

    if isinstance(o, dict):
        result = {}
        for key, value in o.items():
            # Only textual payloads can be JSON-decoded; any other non-dict value
            # (list, number, ...) would make ``json.loads`` raise ``TypeError``.
            if key == "kwargs" and value and isinstance(value, (str, bytes, bytearray)):
                value = json.loads(value)
            result[key] = serialize(value)
        return result

    return try_cast(o)


register("eve/json", dumps, loads, content_type="application/json")


def handle_exception(exc):
    """Report ``exc`` through the module logger's ``exception`` method."""
    logger.exception(exc)


class AppContextTask(TaskBase): # type: ignore
abstract = True
serializer = "eve/json"
app_errors = (
SuperdeskError,
werkzeug.exceptions.InternalServerError, # mongo layer err
)

def __call__(self, *args, **kwargs):
with get_current_app().app_context():
try:
return super().__call__(*args, **kwargs)
except self.app_errors as e:
handle_exception(e)
from superdesk.logging import logger
from superdesk.core import get_current_app, get_app_config

def on_failure(self, exc, task_id, args, kwargs, einfo):
with get_current_app().app_context():
handle_exception(exc)

# custom serializer with Kombu for Celery's message serialization
serializer_factory = ContextAwareSerializerFactory(get_current_app)
serializer_factory.register_serializer(CELERY_SERIALIZER_NAME)

celery.Task = AppContextTask
# set up celery with our custom Task which handles async/sync tasks + app context
celery = Celery(__name__)
celery.Task = HybridAppContextTask


def init_celery(app):
Expand Down
97 changes: 97 additions & 0 deletions superdesk/celery_app/context_task.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
import asyncio
import werkzeug

from celery import Task
from typing import Any, Callable, Tuple, Dict

from superdesk.logging import logger
from superdesk.errors import SuperdeskError
from superdesk.celery_app.serializer import CELERY_SERIALIZER_NAME


class HybridAppContextTask(Task):
    """
    Celery task base class that runs synchronous and asynchronous task functions
    inside the application context.

    Exceptions listed in ``app_errors`` are logged and swallowed, in which case the
    task call evaluates to ``None``; any other exception propagates to the caller.
    """

    abstract = True
    serializer = CELERY_SERIALIZER_NAME
    app_errors = (SuperdeskError, werkzeug.exceptions.InternalServerError)

    def get_current_app(self):
        """
        Return the application whose context the task runs in.

        Intended to be overridden so a module can make the task use the right app
        context.
        """
        from superdesk.core import get_current_app

        return get_current_app()

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        """
        Execute the task function, choosing the async or sync path based on ``self.run``.

        Args:
            args: Positional arguments passed to the task function.
            kwargs: Keyword arguments passed to the task function.

        Returns:
            The task result (see ``run_async`` for coroutine tasks), or ``None`` when
            an exception from ``app_errors`` was raised and logged.
        """
        # TODO-ASYNC: update once we are fully using Quart
        with self.get_current_app().app_context():
            task_func = self.run

            try:
                # handle async tasks if needed
                if asyncio.iscoroutinefunction(task_func):
                    return self.run_async(task_func, *args, **kwargs)

                # run sync otherwise
                return super().__call__(*args, **kwargs)
            except self.app_errors as e:
                self.handle_exception(e)
                return None

    def run_async(self, task_func: Callable, *args: Any, **kwargs: Any) -> Any:
        """
        Run a coroutine task function, logging exceptions defined in ``app_errors``.

        Args:
            task_func: The coroutine function representing the task to be executed.
            args: Positional arguments for the task.
            kwargs: Keyword arguments for the task.

        Returns:
            If an event loop is already running in this thread, an ``asyncio.Task``
            wrapping the execution of the coroutine. Otherwise the coroutine is run
            to completion and its result is returned.
        """

        # We need a wrapper to handle exceptions inside the async function: when the
        # coroutine is scheduled as an asyncio.Task it finishes after ``__call__`` has
        # returned, so the ``except`` clause there would never see the error.
        async def wrapper():
            try:
                return await task_func(*args, **kwargs)
            except self.app_errors as e:
                self.handle_exception(e)
                return None

        try:
            running_loop = asyncio.get_running_loop()
        except RuntimeError:
            # no loop is running in this thread
            running_loop = None

        if running_loop is not None:
            return running_loop.create_task(wrapper())

        return self._get_or_create_event_loop().run_until_complete(wrapper())

    @staticmethod
    def _get_or_create_event_loop() -> asyncio.AbstractEventLoop:
        """
        Return this thread's event loop, creating and installing a new one when the
        thread has none or when its loop has already been closed.
        """
        try:
            loop = asyncio.get_event_loop()
        except RuntimeError:
            # e.g. a worker thread that never had an event loop set
            loop = None

        if loop is None or loop.is_closed():
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)

        return loop

    def handle_exception(self, exc: Exception) -> None:
        """
        Log ``exc`` through the logger imported from ``superdesk.logging``.

        The exception is passed explicitly as ``exc_info`` so its own traceback is
        logged even when this method is called outside of an ``except`` block.
        """
        logger.exception(f"Error handling task: {str(exc)}", exc_info=exc)

    def on_failure(self, exc: Exception, task_id: str, args: Tuple, kwargs: Dict, einfo: Any) -> None:
        """
        Handle task failure by logging the exception within the application context.
        """
        # TODO-ASYNC: update once we are fully using Quart
        with self.get_current_app().app_context():
            self.handle_exception(exc)
109 changes: 109 additions & 0 deletions superdesk/celery_app/serializer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
import arrow
from bson import ObjectId
from typing import Any, Callable

from eve.utils import str_to_date
from eve.io.mongo import MongoJSONEncoder
from kombu.serialization import register

from superdesk.core import json
from superdesk.core.web.types import WSGIApp


CELERY_SERIALIZER_NAME = "context-aware/json"


class ContextAwareSerializerFactory:
    """
    Factory for serializer callables that run inside the application context
    obtained from the ``get_current_app`` callable.
    """

    def __init__(self, get_current_app: Callable[[], "WSGIApp"]):
        """
        Store the callable used to retrieve the current application.

        Args:
            get_current_app: A callable that returns the application object whose
                ``app_context()`` is entered while (de)serializing.
        """
        self.get_current_app = get_current_app

    def try_cast(self, value: Any) -> Any:
        """
        Try to cast the given value to a datetime or an ObjectId, or return it unchanged.

        Args:
            value: The value to be cast.

        Returns:
            The cast value, or the original value if no casting was possible.
        """
        # False and 0 would otherwise be converted to datetime by arrow
        if value is None or isinstance(value, bool) or value == 0:
            return value

        try:
            str_to_date(value)  # only checks that the value matches the date format
            return arrow.get(value).datetime  # timezone aware time
        except Exception:
            try:
                return ObjectId(value)
            except Exception:
                return value

    def dumps(self, o: Any) -> str:
        """
        Serialize the given object into a JSON string within the application context.

        Args:
            o: The object to serialize.

        Returns:
            The JSON string produced by ``MongoJSONEncoder``.
        """
        with self.get_current_app().app_context():
            return MongoJSONEncoder().encode(o)

    def loads(self, s: str) -> Any:
        """
        Deserialize the given JSON string into a Python object.

        The string is parsed first; the result is then post-processed with
        ``serialize`` within the application context.

        Args:
            s: The JSON string to deserialize.

        Returns:
            The deserialized object.
        """
        o = json.loads(s)
        with self.get_current_app().app_context():
            return self.serialize(o)

    def serialize(self, o: Any) -> Any:
        """
        Recursively post-process a decoded JSON value.

        Lists and dicts are walked recursively and every leaf value goes through
        ``try_cast``. A ``"kwargs"`` entry that arrives as a JSON-encoded string is
        decoded before being processed.

        Args:
            o: The decoded value (list, dict or scalar).

        Returns:
            A new structure with the same shape; the input is never modified.
        """
        if isinstance(o, list):
            return [self.serialize(item) for item in o]

        if isinstance(o, dict):
            result = {}
            for key, value in o.items():
                # Only textual payloads can be JSON-decoded; any other non-dict value
                # (list, number, ...) would make ``json.loads`` raise ``TypeError``.
                if key == "kwargs" and value and isinstance(value, (str, bytes, bytearray)):
                    value = json.loads(value)
                result[key] = self.serialize(value)
            return result

        return self.try_cast(o)

    def register_serializer(self, name: str, content_type: str = "application/json") -> None:
        """
        Register ``dumps``/``loads`` as a custom serializer with Kombu.

        Args:
            name: The name under which the serializer should be registered.
            content_type: The MIME type associated with the serializer.
        """
        register(name, self.dumps, self.loads, content_type=content_type)
11 changes: 8 additions & 3 deletions superdesk/io/commands/update_ingest.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,17 @@


import bson
import pytz
import logging
import superdesk

from datetime import timedelta, timezone, datetime
import pytz
from werkzeug.exceptions import HTTPException

from superdesk.celery_app import CELERY_SERIALIZER_NAME
from superdesk.core import get_app_config, get_current_app
from superdesk.resource_fields import ID_FIELD
import superdesk

from superdesk.activity import ACTIVITY_EVENT, notify_and_add_activity
from superdesk.celery_app import celery
from superdesk.celery_task_utils import get_lock_id
Expand Down Expand Up @@ -270,7 +273,9 @@ def run(self, provider_name=None, sync=False):
if sync:
update_provider.apply(kwargs=kwargs)
else:
update_provider.apply_async(expires=get_task_ttl(provider), kwargs=kwargs, serializer="eve/json")
update_provider.apply_async(
expires=get_task_ttl(provider), kwargs=kwargs, serializer=CELERY_SERIALIZER_NAME
)


def update_last_item_updated(update, items):
Expand Down
Empty file added tests/celery_app/__init__.py
Empty file.
Loading

0 comments on commit 3ceeb01

Please sign in to comment.