diff --git a/docs/release-notes.md b/docs/release-notes.md index 21c10bb3..256c2012 100644 --- a/docs/release-notes.md +++ b/docs/release-notes.md @@ -36,6 +36,7 @@ hide: - Invalidating caused schema errors. - ManyToMany and ForeignKey fields didn't worked when referencing tenant models. - ManyToMany fields didn't worked when specified on tenant models. +- Fix transaction method to work on instance and class. ### BREAKING diff --git a/edgy/contrib/multi_tenancy/base.py b/edgy/contrib/multi_tenancy/base.py index 9463c6dd..e9d84c83 100644 --- a/edgy/contrib/multi_tenancy/base.py +++ b/edgy/contrib/multi_tenancy/base.py @@ -30,9 +30,9 @@ def real_add_to_registry(cls, **kwargs: Any) -> type["BaseModelType"]: and not cls.meta.abstract and not cls.__is_proxy_model__ ): - assert cls.__reflected__ is False, ( - "Reflected models are not compatible with multi_tenancy" - ) + assert ( + cls.__reflected__ is False + ), "Reflected models are not compatible with multi_tenancy" if not cls.meta.register_default: # remove from models diff --git a/edgy/core/connection/registry.py b/edgy/core/connection/registry.py index ee4d9e37..2db17b9c 100644 --- a/edgy/core/connection/registry.py +++ b/edgy/core/connection/registry.py @@ -249,9 +249,9 @@ def callback(model_class: type["BaseModelType"]) -> None: if "content_type" in model_class.meta.fields: return related_name = f"reverse_{model_class.__name__.lower()}" - assert related_name not in real_content_type.meta.fields, ( - f"duplicate model name: {model_class.__name__}" - ) + assert ( + related_name not in real_content_type.meta.fields + ), f"duplicate model name: {model_class.__name__}" field_args: dict[str, Any] = { "name": "content_type", diff --git a/edgy/core/db/models/metaclasses.py b/edgy/core/db/models/metaclasses.py index b29e5f37..aed0f8fd 100644 --- a/edgy/core/db/models/metaclasses.py +++ b/edgy/core/db/models/metaclasses.py @@ -34,6 +34,8 @@ from edgy.exceptions import ImproperlyConfigured, TableBuildError if TYPE_CHECKING: + from databasez.core.transaction import Transaction + from edgy.core.connection import Database from edgy.core.db.models import Model from edgy.core.db.models.types import BaseModelType @@ -879,6 +881,12 @@ def signals(cls) -> signals_module.Broadcaster: meta: MetaInfo = cls.meta return meta.signals + def transaction(cls, *, force_rollback: bool = False, **kwargs: Any) -> Transaction: + """Return database transaction for the assigned database""" + return cast( + "Transaction", cls.database.transaction(force_rollback=force_rollback, **kwargs) + ) + def table_schema( cls, schema: Union[str, None] = None, diff --git a/edgy/core/db/models/mixins/db.py b/edgy/core/db/models/mixins/db.py index 4b58acb3..ee8ef9ef 100644 --- a/edgy/core/db/models/mixins/db.py +++ b/edgy/core/db/models/mixins/db.py @@ -169,6 +169,10 @@ def _set_related_name_for_foreign_keys( class DatabaseMixin: _removed_copy_keys: ClassVar[set[str]] = _removed_copy_keys + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + self.__dict__["transaction"] = self.not_set_transaction + @classmethod def real_add_to_registry( cls: type[BaseModelType], @@ -809,8 +813,14 @@ def _get_indexes(cls, index: Index) -> Optional[sqlalchemy.Index]: ), ) - def transaction(self, *, force_rollback: bool = False, **kwargs: Any) -> Transaction: - """Return database transaction for the assigned database""" + def not_set_transaction( + self=None, *, force_rollback: bool = False, **kwargs: Any + ) -> Transaction: + """ + Return database transaction for the assigned database. + + This method is automatically assigned to transaction masking the metaclass transaction for instances. + """ return cast( "Transaction", self.database.transaction(force_rollback=force_rollback, **kwargs) ) diff --git a/edgy/core/db/models/types.py b/edgy/core/db/models/types.py index d32c0394..40f125e2 100644 --- a/edgy/core/db/models/types.py +++ b/edgy/core/db/models/types.py @@ -2,23 +2,17 @@ from abc import ABC, abstractmethod from collections.abc import Iterable, Sequence -from typing import ( - TYPE_CHECKING, - Any, - ClassVar, - Optional, - Union, -) +from typing import TYPE_CHECKING, Any, ClassVar, Optional, Union if TYPE_CHECKING: import sqlalchemy - from databasez.core.transaction import Transaction from edgy.core.connection.database import Database from edgy.core.db.models.base import BaseModel from edgy.core.db.models.managers import BaseManager from edgy.core.db.models.metaclasses import MetaInfo from edgy.core.db.querysets.base import QuerySet + from edgy.protocols.transaction_call import TransactionCallProtocol class DescriptiveMeta: @@ -59,6 +53,7 @@ class BaseModelType(ABC): query_related: ClassVar[BaseManager] meta: ClassVar[MetaInfo] Meta: ClassVar[DescriptiveMeta] = DescriptiveMeta() + transaction: ClassVar[TransactionCallProtocol] __parent__: ClassVar[Union[type[BaseModelType], None]] = None __is_proxy_model__: ClassVar[bool] = False @@ -80,10 +75,6 @@ def identifying_db_fields(self) -> Any: def can_load(self) -> bool: """identifying_db_fields are completely specified.""" - @abstractmethod - def transaction(self, *, force_rollback: bool = False, **kwargs: Any) -> Transaction: - """Return database transaction for the assigned database.""" - @abstractmethod def get_columns_for_name(self, name: str) -> Sequence[sqlalchemy.Column]: """Helper for retrieving columns from field name.""" diff --git a/edgy/protocols/transaction_call.py b/edgy/protocols/transaction_call.py new file mode 100644 index 00000000..aec3863e --- /dev/null +++ b/edgy/protocols/transaction_call.py @@ -0,0 +1,10 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Protocol + +if TYPE_CHECKING: + from databasez.core.transaction import Transaction + + +class TransactionCallProtocol(Protocol): + def __call__(instance: Any, *, force_rollback: bool = False, **kwargs: Any) -> Transaction: ... diff --git a/tests/models/test_model_class.py b/tests/models/test_model_class.py index ccbc598a..d13297bb 100644 --- a/tests/models/test_model_class.py +++ b/tests/models/test_model_class.py @@ -59,6 +59,12 @@ def test_model_class(): assert isinstance(User.query.meta.fields["name"], Field) +def test_transactions(): + user = User(id=1) + User.transaction() + user.transaction() + + def test_model_pk(): user = User(pk=1) assert user.pk == 1