Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix transaction method to work on instance and class #263

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/release-notes.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ hide:
- Invalidating caused schema errors.
- ManyToMany and ForeignKey fields didn't worked when referencing tenant models.
- ManyToMany fields didn't worked when specified on tenant models.
- Fix transaction method to work on instance and class.

### BREAKING

Expand Down
6 changes: 3 additions & 3 deletions edgy/contrib/multi_tenancy/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,9 @@ def real_add_to_registry(cls, **kwargs: Any) -> type["BaseModelType"]:
and not cls.meta.abstract
and not cls.__is_proxy_model__
):
assert cls.__reflected__ is False, (
"Reflected models are not compatible with multi_tenancy"
)
assert (
cls.__reflected__ is False
), "Reflected models are not compatible with multi_tenancy"

if not cls.meta.register_default:
# remove from models
Expand Down
6 changes: 3 additions & 3 deletions edgy/core/connection/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,9 +249,9 @@ def callback(model_class: type["BaseModelType"]) -> None:
if "content_type" in model_class.meta.fields:
return
related_name = f"reverse_{model_class.__name__.lower()}"
assert related_name not in real_content_type.meta.fields, (
f"duplicate model name: {model_class.__name__}"
)
assert (
related_name not in real_content_type.meta.fields
), f"duplicate model name: {model_class.__name__}"

field_args: dict[str, Any] = {
"name": "content_type",
Expand Down
8 changes: 8 additions & 0 deletions edgy/core/db/models/metaclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@
from edgy.exceptions import ImproperlyConfigured, TableBuildError

if TYPE_CHECKING:
from databasez.core.transaction import Transaction

from edgy.core.connection import Database
from edgy.core.db.models import Model
from edgy.core.db.models.types import BaseModelType
Expand Down Expand Up @@ -879,6 +881,12 @@ def signals(cls) -> signals_module.Broadcaster:
meta: MetaInfo = cls.meta
return meta.signals

def transaction(cls, *, force_rollback: bool = False, **kwargs: Any) -> Transaction:
"""Return database transaction for the assigned database"""
return cast(
"Transaction", cls.database.transaction(force_rollback=force_rollback, **kwargs)
)

def table_schema(
cls,
schema: Union[str, None] = None,
Expand Down
14 changes: 12 additions & 2 deletions edgy/core/db/models/mixins/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,10 @@ def _set_related_name_for_foreign_keys(
class DatabaseMixin:
_removed_copy_keys: ClassVar[set[str]] = _removed_copy_keys

def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
self.__dict__["transaction"] = self.not_set_transaction

@classmethod
def real_add_to_registry(
cls: type[BaseModelType],
Expand Down Expand Up @@ -809,8 +813,14 @@ def _get_indexes(cls, index: Index) -> Optional[sqlalchemy.Index]:
),
)

def transaction(self, *, force_rollback: bool = False, **kwargs: Any) -> Transaction:
"""Return database transaction for the assigned database"""
def not_set_transaction(
self=None, *, force_rollback: bool = False, **kwargs: Any
) -> Transaction:
"""
Return database transaction for the assigned database.

This method is automatically assigned to transaction masking the metaclass transaction for instances.
"""
return cast(
"Transaction", self.database.transaction(force_rollback=force_rollback, **kwargs)
)
15 changes: 3 additions & 12 deletions edgy/core/db/models/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,23 +2,17 @@

from abc import ABC, abstractmethod
from collections.abc import Iterable, Sequence
from typing import (
TYPE_CHECKING,
Any,
ClassVar,
Optional,
Union,
)
from typing import TYPE_CHECKING, Any, ClassVar, Optional, Union

if TYPE_CHECKING:
import sqlalchemy
from databasez.core.transaction import Transaction

from edgy.core.connection.database import Database
from edgy.core.db.models.base import BaseModel
from edgy.core.db.models.managers import BaseManager
from edgy.core.db.models.metaclasses import MetaInfo
from edgy.core.db.querysets.base import QuerySet
from edgy.protocols.transaction_call import TransactionCallProtocol


class DescriptiveMeta:
Expand Down Expand Up @@ -59,6 +53,7 @@ class BaseModelType(ABC):
query_related: ClassVar[BaseManager]
meta: ClassVar[MetaInfo]
Meta: ClassVar[DescriptiveMeta] = DescriptiveMeta()
transaction: ClassVar[TransactionCallProtocol]

__parent__: ClassVar[Union[type[BaseModelType], None]] = None
__is_proxy_model__: ClassVar[bool] = False
Expand All @@ -80,10 +75,6 @@ def identifying_db_fields(self) -> Any:
def can_load(self) -> bool:
"""identifying_db_fields are completely specified."""

@abstractmethod
def transaction(self, *, force_rollback: bool = False, **kwargs: Any) -> Transaction:
"""Return database transaction for the assigned database."""

@abstractmethod
def get_columns_for_name(self, name: str) -> Sequence[sqlalchemy.Column]:
"""Helper for retrieving columns from field name."""
Expand Down
10 changes: 10 additions & 0 deletions edgy/protocols/transaction_call.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Any, Protocol

if TYPE_CHECKING:
from databasez.core.transaction import Transaction


class TransactionCallProtocol(Protocol):
def __call__(instance: Any, *, force_rollback: bool = False, **kwargs: Any) -> Transaction: ...
6 changes: 6 additions & 0 deletions tests/models/test_model_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,12 @@ def test_model_class():
assert isinstance(User.query.meta.fields["name"], Field)


def test_transactions():
user = User(id=1)
User.transaction()
user.transaction()


def test_model_pk():
user = User(pk=1)
assert user.pk == 1
Expand Down
Loading