diff --git a/.gitignore b/.gitignore index 42d1c07026..a15dac2aec 100644 --- a/.gitignore +++ b/.gitignore @@ -38,3 +38,4 @@ _build _build_html __pycache__ .eggs +.vscode diff --git a/superdesk/celery_app.py b/superdesk/celery_app/__init__.py similarity index 52% rename from superdesk/celery_app.py rename to superdesk/celery_app/__init__.py index 9f8d5ce2b3..0880af45a3 100644 --- a/superdesk/celery_app.py +++ b/superdesk/celery_app/__init__.py @@ -2,102 +2,29 @@ # # This file is part of Superdesk. # -# Copyright 2013, 2014 Sourcefabric z.u. and contributors. +# Copyright 2013 to present Sourcefabric z.u. and contributors. # # For the full copyright and license information, please see the # AUTHORS and LICENSE files distributed with this source code, or # at https://www.sourcefabric.org/superdesk/license -""" -Created on May 29, 2014 - -@author: ioan -""" - import redis -import arrow -import werkzeug -import superdesk -from bson import ObjectId from celery import Celery -from kombu.serialization import register -from eve.io.mongo import MongoJSONEncoder -from eve.utils import str_to_date -from superdesk.core import json, get_current_app, get_app_config -from superdesk.errors import SuperdeskError -from superdesk.logging import logger - - -celery = Celery(__name__) -TaskBase = celery.Task - - -def try_cast(v): - # False and 0 are getting converted to datetime by arrow - if v is None or isinstance(v, bool) or v == 0: - return v - - try: - str_to_date(v) # try if it matches format - return arrow.get(v).datetime # return timezone aware time - except Exception: - try: - return ObjectId(v) - except Exception: - return v - - -def dumps(o): - with get_current_app().app_context(): - return MongoJSONEncoder().encode(o) +from .context_task import HybridAppContextTask +from .serializer import CELERY_SERIALIZER_NAME, ContextAwareSerializerFactory -def loads(s): - o = json.loads(s) - with get_current_app().app_context(): - return serialize(o) - - -def serialize(o): - if 
isinstance(o, list): - return [serialize(item) for item in o] - elif isinstance(o, dict): - if o.get("kwargs") and not isinstance(o["kwargs"], dict): - o["kwargs"] = json.loads(o["kwargs"]) - return {k: serialize(v) for k, v in o.items()} - else: - return try_cast(o) - - -register("eve/json", dumps, loads, content_type="application/json") - - -def handle_exception(exc): - """Log exception to logger.""" - logger.exception(exc) - - -class AppContextTask(TaskBase): # type: ignore - abstract = True - serializer = "eve/json" - app_errors = ( - SuperdeskError, - werkzeug.exceptions.InternalServerError, # mongo layer err - ) - - def __call__(self, *args, **kwargs): - with get_current_app().app_context(): - try: - return super().__call__(*args, **kwargs) - except self.app_errors as e: - handle_exception(e) +from superdesk.logging import logger +from superdesk.core import get_current_app, get_app_config - def on_failure(self, exc, task_id, args, kwargs, einfo): - with get_current_app().app_context(): - handle_exception(exc) +# custom serializer with Kombu for Celery's message serialization +serializer_factory = ContextAwareSerializerFactory(get_current_app) +serializer_factory.register_serializer(CELERY_SERIALIZER_NAME) -celery.Task = AppContextTask +# set up celery with our custom Task which handles async/sync tasks + app context +celery = Celery(__name__) +celery.Task = HybridAppContextTask def init_celery(app): diff --git a/superdesk/celery_app/context_task.py b/superdesk/celery_app/context_task.py new file mode 100644 index 0000000000..dbbf9fca45 --- /dev/null +++ b/superdesk/celery_app/context_task.py @@ -0,0 +1,97 @@ +import asyncio +import werkzeug + +from celery import Task +from typing import Any, Callable, Tuple, Dict + +from superdesk.logging import logger +from superdesk.errors import SuperdeskError +from superdesk.celery_app.serializer import CELERY_SERIALIZER_NAME + + +class HybridAppContextTask(Task): + """ + A task class that supports running both 
synchronous and asynchronous tasks within the Flask application context. + It handles exceptions specifically defined in `app_errors` and logs them. + """ + + abstract = True + serializer = CELERY_SERIALIZER_NAME + app_errors = (SuperdeskError, werkzeug.exceptions.InternalServerError) + + def get_current_app(self): + """ + Method that is intended to be overridden so the module gets to use the right app + context + """ + from superdesk.core import get_current_app + + return get_current_app() + + def __call__(self, *args: Any, **kwargs: Any) -> Any: + """ + Executes the task function, determining if it should be run synchronously or asynchronously. + + Args: + args: Positional arguments passed to the task function. + kwargs: Keyword arguments passed to the task function. + """ + # TODO-ASYNC: update once we are fully using Quart + with self.get_current_app().app_context(): + task_func = self.run + + try: + # handle async tasks if needed + if asyncio.iscoroutinefunction(task_func): + return self.run_async(task_func, *args, **kwargs) + + # run sync otherwise + return super().__call__(*args, **kwargs) + except self.app_errors as e: + self.handle_exception(e) + + def run_async(self, task_func: Callable, *args: Any, **kwargs: Any) -> Any: + """ + Runs the task asynchronously, utilizing the current asyncio event loop. Captures + and handles exceptions defined in `app_errors`. + + Args: + task_func: The coroutine function representing the task to be executed. + args: Positional arguments for the task. + kwargs: Keyword arguments for the task. + + Returns: + If the event loop is running, returns an asyncio.Task that represents the execution of the coroutine. + Otherwise it runs the task to completion and returns its result. + """ + + loop = asyncio.get_event_loop() + + # We need a wrapper to handle exceptions inside the async function because asyncio + # does not propagate them in the same way as synchronous exceptions. 
This ensures that + # `app_errors` exceptions are handled and logged regardless of where they occur within the event loop + async def wrapper(): + try: + return await task_func(*args, **kwargs) + except self.app_errors as e: + self.handle_exception(e) + return None + + if not loop.is_running(): + return loop.run_until_complete(wrapper()) + + return asyncio.create_task(wrapper()) + + def handle_exception(self, exc: Exception) -> None: + """ + Logs an exception using the configured logger from `superdesk.logging`. + """ + logger.exception(f"Error handling task: {str(exc)}") + + def on_failure(self, exc: Exception, task_id: str, args: Tuple, kwargs: Dict, einfo: str) -> None: + """ + Handles task failure by logging the exception within the Flask application context. + """ + # TODO-ASYNC: update once we are fully using Quart + with self.get_current_app().app_context(): + self.handle_exception(exc) diff --git a/superdesk/celery_app/serializer.py b/superdesk/celery_app/serializer.py new file mode 100644 index 0000000000..24710b4f32 --- /dev/null +++ b/superdesk/celery_app/serializer.py @@ -0,0 +1,109 @@ +import arrow +from bson import ObjectId +from typing import Any, Callable + +from eve.utils import str_to_date +from eve.io.mongo import MongoJSONEncoder +from kombu.serialization import register + +from superdesk.core import json +from superdesk.core.web.types import WSGIApp + + +CELERY_SERIALIZER_NAME = "context-aware/json" + + +class ContextAwareSerializerFactory: + """ + A factory class for creating serializers that automatically handle + the execution within a specific application context. + """ + + def __init__(self, get_current_app: Callable[[], WSGIApp]): + """ + Initializes the ContextAwareSerializerFactory with a callable to retrieve the current application. + + Args: + get_current_app: A callable that returns the current Flask/WSGIApp application, whose `app_context()` is entered around (de)serialization. 
+ """ + self.get_current_app = get_current_app + + def try_cast(self, value: Any) -> str | Any: + """ + Tries to cast the given value to an appropriate type (datetime or ObjectId) or returns it unchanged. + + Args: + value (Any): The value to cast. + + Returns: + Any: The cast value, or the original value if no casting was possible. + """ + if value is None or isinstance(value, bool) or value == 0: + return value + + try: + str_to_date(value) + return arrow.get(value).datetime # timezone aware time + + except Exception: + try: + return ObjectId(value) + except Exception: + return value + + def dumps(self, o: Any) -> str: + """ + Serializes the given object into a JSON string, executing within the application context. + + Args: + o (Any): The object to serialize. + + Returns: + str: The serialized JSON string. + """ + with self.get_current_app().app_context(): + return MongoJSONEncoder().encode(o) + + def loads(self, s: str) -> Any: + """ + Deserializes the given JSON string into a Python object, executing within the application context. + + Args: + s (str): The JSON string to deserialize. + + Returns: + Any: The deserialized object. + """ + o = json.loads(s) + with self.get_current_app().app_context(): + return self.serialize(o) + + def serialize(self, o: Any) -> Any: + """ + Recursively processes lists and dictionaries, casting leaf values with `try_cast`. + + Args: + o (Any): The object to serialize. + + Returns: + Any: The serialized object. + """ + + if isinstance(o, list): + return [self.serialize(item) for item in o] + elif isinstance(o, dict): + if o.get("kwargs") and not isinstance(o["kwargs"], dict): + o["kwargs"] = json.loads(o["kwargs"]) + return {k: self.serialize(v) for k, v in o.items()} + else: + return self.try_cast(o) + + def register_serializer(self, name: str, content_type: str = "application/json") -> None: + """ + Registers a custom serializer with Kombu, which is used by Celery for message serialization. 
+ + Args: + name (str): The name under which the serializer should be registered. + content_type (str): The MIME type associated with the serializer. + """ + register(name, self.dumps, self.loads, content_type=content_type) diff --git a/superdesk/io/commands/update_ingest.py b/superdesk/io/commands/update_ingest.py index 3a81c9de43..817951258e 100644 --- a/superdesk/io/commands/update_ingest.py +++ b/superdesk/io/commands/update_ingest.py @@ -10,14 +10,17 @@ import bson +import pytz import logging +import superdesk + from datetime import timedelta, timezone, datetime -import pytz from werkzeug.exceptions import HTTPException +from superdesk.celery_app import CELERY_SERIALIZER_NAME from superdesk.core import get_app_config, get_current_app from superdesk.resource_fields import ID_FIELD -import superdesk + from superdesk.activity import ACTIVITY_EVENT, notify_and_add_activity from superdesk.celery_app import celery from superdesk.celery_task_utils import get_lock_id @@ -270,7 +273,9 @@ def run(self, provider_name=None, sync=False): if sync: update_provider.apply(kwargs=kwargs) else: - update_provider.apply_async(expires=get_task_ttl(provider), kwargs=kwargs, serializer="eve/json") + update_provider.apply_async( + expires=get_task_ttl(provider), kwargs=kwargs, serializer=CELERY_SERIALIZER_NAME + ) def update_last_item_updated(update, items): diff --git a/tests/celery_app/__init__.py b/tests/celery_app/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/celery_app/context_task_test.py b/tests/celery_app/context_task_test.py new file mode 100644 index 0000000000..d00cb3af08 --- /dev/null +++ b/tests/celery_app/context_task_test.py @@ -0,0 +1,48 @@ +import asyncio +from unittest.mock import patch + +from superdesk.errors import SuperdeskError +from superdesk.celery_app import HybridAppContextTask +from superdesk.tests.asyncio import AsyncFlaskTestCase + + +class TestHybridAppContextTask(AsyncFlaskTestCase): + async def test_sync_task(self): + 
@self.app.celery.task(base=HybridAppContextTask) + def sync_task(): + return "sync result" + + result = sync_task.apply_async().get() + self.assertEqual(result, "sync result") + + async def test_async_task(self): + @self.app.celery.task(base=HybridAppContextTask) + async def async_task(): + await asyncio.sleep(0.1) + return "async result" + + result = await async_task.apply_async().get() + self.assertEqual(result, "async result") + + async def test_sync_task_exception(self): + @self.app.celery.task(base=HybridAppContextTask) + def sync_task_exception(): + raise SuperdeskError("Test exception") + + with patch("superdesk.celery_app.context_task.logger") as mock_logger: + sync_task_exception.apply_async().get(propagate=True) + expected_exc = SuperdeskError("Test exception") + expected_msg = f"Error handling task: {str(expected_exc)}" + mock_logger.exception.assert_called_once_with(expected_msg) + + async def test_async_task_exception(self): + @self.app.celery.task(base=HybridAppContextTask) + async def async_task_exception(): + raise SuperdeskError("Async exception") + + with patch("superdesk.celery_app.context_task.logger") as mock_logger: + await async_task_exception.apply_async().get() + + expected_exc = SuperdeskError("Async exception") + expected_msg = f"Error handling task: {str(expected_exc)}" + mock_logger.exception.assert_called_once_with(expected_msg) diff --git a/tests/celery_app/serializer_test.py b/tests/celery_app/serializer_test.py new file mode 100644 index 0000000000..c3a2f7c1da --- /dev/null +++ b/tests/celery_app/serializer_test.py @@ -0,0 +1,67 @@ +from bson import ObjectId +from datetime import datetime +from unittest.mock import MagicMock + +from superdesk.celery_app.serializer import ContextAwareSerializerFactory +from superdesk.tests import TestCase + + +class TestContextAwareSerializerFactory(TestCase): + def setUp(self): + self.get_current_app = MagicMock(return_value=self.app) + self.factory = 
ContextAwareSerializerFactory(self.get_current_app) + + def test_try_cast_object_id(self): + obj_id = ObjectId() + result = self.factory.try_cast(str(obj_id)) + self.assertEqual(result, obj_id) + + def test_try_cast_datetime(self): + date_str = "2021-09-10T14:31:09+0000" + result = self.factory.try_cast(date_str) + self.assertIsInstance(result, datetime) + + def test_dumps(self): + obj = {"test": "data"} + serialized = self.factory.dumps(obj) + self.assertEqual(serialized, '{"test": "data"}') + + def test_serialize_dict(self): + obj = {"key": "2021-09-10T14:31:09+0000", "nested": [{"_id": "528de7b03b80a13eefc5e610"}]} + result = self.factory.serialize(obj) + self.assertIsInstance(result["key"], datetime) + self.assertIsInstance(result["nested"][0]["_id"], ObjectId) + + def test_serialize_list(self): + obj = ["528de7b03b80a13eefc5e610", "2021-09-10T14:31:09+0000"] + result = self.factory.serialize(obj) + self.assertIsInstance(result[0], ObjectId) + self.assertIsInstance(result[1], datetime) + + def test_loads_args(self): + _id = "528de7b03b80a13eefc5e610" + obj = b'{"args": [{"_id": "528de7b03b80a13eefc5e610", "_updated": "2014-09-10T14:31:09+0000"}]}' + result = self.factory.loads(obj) + self.assertEqual(result["args"][0]["_id"], ObjectId(_id)) + self.assertIsInstance(result["args"][0]["_updated"], datetime) + + def test_loads_kwargs(self): + obj = b"""{"kwargs": "{}", "pid": 24998, "eta": null}""" + result = self.factory.loads(obj) + self.assertEqual({}, result["kwargs"]) + self.assertIsNone(result["eta"]) + + def test_loads_lists(self): + obj = b"""[{}, {"foo": null}]""" + result = self.factory.loads(obj) + self.assertEqual([{}, {"foo": None}], result) + + def test_loads_zero(self): + obj = b"""[0]""" + result = self.factory.loads(obj) + self.assertEqual([0], result) + + def test_loads_boolean(self): + obj = b"""[{"foo": false, "bar": true}]""" + result = self.factory.loads(obj) + self.assertEqual([{"foo": False, "bar": True}], result) diff --git 
a/tests/celery_tests.py b/tests/celery_tests.py deleted file mode 100644 index 9f20a77192..0000000000 --- a/tests/celery_tests.py +++ /dev/null @@ -1,57 +0,0 @@ -# -*- coding: utf-8; -*- -# -# This file is part of Superdesk. -# -# Copyright 2013, 2014 Sourcefabric z.u. and contributors. -# -# For the full copyright and license information, please see the -# AUTHORS and LICENSE files distributed with this source code, or -# at https://www.sourcefabric.org/superdesk/license - -from datetime import datetime - -from bson import ObjectId -from eve.utils import date_to_str - -from superdesk.tests import TestCase -from superdesk.celery_app import try_cast, loads - - -class CeleryTestCase(TestCase): - _id = ObjectId("528de7b03b80a13eefc5e610") - - def test_cast_objectid(self): - self.assertEqual(try_cast(str(self._id)), self._id) - - def test_cast_datetime(self): - date = datetime(2012, 12, 12, 12, 12, 12, 0) - with self.app.app_context(): - s = date_to_str(date) - self.assertEqual(try_cast(s).day, date.day) - - def test_loads_args(self): - s = b'{"args": [{"_id": "528de7b03b80a13eefc5e610", "_updated": "2014-09-10T14:31:09+0000"}]}' - o = loads(s) - self.assertEqual(o["args"][0]["_id"], self._id) - self.assertIsInstance(o["args"][0]["_updated"], datetime) - - def test_loads_kwargs(self): - s = b"""{"kwargs": "{}", "pid": 24998, "eta": null}""" - o = loads(s) - self.assertEqual({}, o["kwargs"]) - self.assertIsNone(o["eta"]) - - def test_loads_lists(self): - s = b"""[{}, {"foo": null}]""" - o = loads(s) - self.assertEqual([{}, {"foo": None}], o) - - def test_loads_zero(self): - s = b"""[0]""" - o = loads(s) - self.assertEqual([0], o) - - def test_loads_boolean(self): - s = b"""[{"foo": false, "bar": true}]""" - o = loads(s) - self.assertEqual([{"foo": False, "bar": True}], o)