Skip to content

Commit

Permalink
fix transactions in multithreading contexts (#53)
Browse files Browse the repository at this point in the history
Changes:

- fix transactions with multithreading
- stack the transaction backend instances on the transaction stack and
  remove ACTIVE_TRANSACTIONS (unreliable with multithreading)
- bump version
- cleanup utils move  async helpers to corresponding files
- decrease connection counter on failed start calls (transaction) when not using `__aenter__`
  • Loading branch information
devkral authored Sep 5, 2024
1 parent 61d3581 commit fca84df
Show file tree
Hide file tree
Showing 9 changed files with 272 additions and 273 deletions.
2 changes: 1 addition & 1 deletion databasez/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from databasez.core import Database, DatabaseURL

__version__ = "0.10.1"
__version__ = "0.10.2"

__all__ = ["Database", "DatabaseURL"]
4 changes: 2 additions & 2 deletions databasez/core/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from .connection import Connection
from .database import Database, init
from .databaseurl import DatabaseURL
from .transaction import ACTIVE_TRANSACTIONS, Transaction
from .transaction import Transaction

__all__ = ["Connection", "Database", "init", "DatabaseURL", "Transaction", "ACTIVE_TRANSACTIONS"]
__all__ = ["Connection", "Database", "init", "DatabaseURL", "Transaction"]
58 changes: 54 additions & 4 deletions databasez/core/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,14 @@
import asyncio
import typing
import weakref
from functools import partial
from threading import Event, Lock, Thread, current_thread
from types import TracebackType

from sqlalchemy import text

from databasez import interfaces
from databasez.utils import multiloop_protector
from databasez.utils import _arun_with_timeout, arun_coroutine_threadsafe, multiloop_protector

from .transaction import Transaction

Expand Down Expand Up @@ -59,7 +60,50 @@ def _init_thread(
database._global_connection._isolation_thread = None # type: ignore


class AsyncHelperConnection:
def __init__(
self,
connection: Connection,
fn: typing.Callable,
args: typing.Any,
kwargs: typing.Any,
timeout: typing.Optional[float],
) -> None:
self.connection = connection
self.fn = partial(fn, self.connection, *args, **kwargs)
self.timeout = timeout
self.ctm = None

async def call(self) -> typing.Any:
# is automatically awaited
result = await _arun_with_timeout(self.fn(), self.timeout)
return result

async def acall(self) -> typing.Any:
return await arun_coroutine_threadsafe(
self.call(), self.connection._loop, self.connection.poll_interval
)

def __await__(self) -> typing.Any:
return self.acall().__await__()

async def __aiter__(self) -> typing.Any:
result = await self.acall()
try:
while True:
yield await arun_coroutine_threadsafe(
_arun_with_timeout(result.__anext__(), self.timeout),
self.connection._loop,
self.connection.poll_interval,
)
except StopAsyncIteration:
pass


class Connection:
# async helper
async_helper: typing.Type[AsyncHelperConnection] = AsyncHelperConnection

def __init__(
self, database: Database, force_rollback: bool = False, full_isolation: bool = False
) -> None:
Expand All @@ -86,11 +130,18 @@ def __init__(
self._connection.owner = self
self._connection_counter = 0

self._transaction_stack: typing.List[Transaction] = []
# for keeping weak references to transactions active
self._transaction_stack: typing.List[
typing.Tuple[Transaction, interfaces.TransactionBackend]
] = []

self._force_rollback = force_rollback
self.connection_transaction: typing.Optional[Transaction] = None

@multiloop_protector(True)
def _get_connection_backend(self) -> interfaces.ConnectionBackend:
return self._connection

@multiloop_protector(False, passthrough_timeout=True) # fail when specifying timeout
async def _aenter(self) -> None:
async with self._connection_lock:
Expand All @@ -111,8 +162,7 @@ async def _aenter(self) -> None:
self.connection_transaction = self.transaction(
force_rollback=self._force_rollback
)
# make re-entrant, we have already the connection lock
await self.connection_transaction.start(True)
await self.connection_transaction.start()
except BaseException as e:
self._connection_counter -= 1
raise e
Expand Down
54 changes: 53 additions & 1 deletion databasez/core/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,12 @@
from types import TracebackType

from databasez import interfaces
from databasez.utils import DATABASEZ_POLL_INTERVAL, arun_coroutine_threadsafe, multiloop_protector
from databasez.utils import (
DATABASEZ_POLL_INTERVAL,
_arun_with_timeout,
arun_coroutine_threadsafe,
multiloop_protector,
)

from .connection import Connection
from .databaseurl import DatabaseURL
Expand Down Expand Up @@ -118,6 +123,51 @@ def __delete__(self, obj: Database) -> None:
obj._force_rollback.set(None)


class AsyncHelperDatabase:
def __init__(
self,
database: Database,
fn: typing.Callable,
args: typing.Any,
kwargs: typing.Any,
timeout: typing.Optional[float],
) -> None:
self.database = database
self.fn = fn
self.args = args
self.kwargs = kwargs
self.timeout = timeout
self.ctm = None

async def call(self) -> typing.Any:
async with self.database as database:
return await _arun_with_timeout(
self.fn(database, *self.args, **self.kwargs), self.timeout
)

def __await__(self) -> typing.Any:
return self.call().__await__()

async def __aenter__(self) -> typing.Any:
database = await self.database.__aenter__()
self.ctm = await _arun_with_timeout(
self.fn(database, *self.args, **self.kwargs), timeout=self.timeout
)
return await self.ctm.__aenter__()

async def __aexit__(
self,
exc_type: typing.Optional[typing.Type[BaseException]] = None,
exc_value: typing.Optional[BaseException] = None,
traceback: typing.Optional[TracebackType] = None,
) -> None:
assert self.ctm is not None
try:
await _arun_with_timeout(self.ctm.__aexit__(exc_type, exc_value, traceback), None)
finally:
await self.database.__aexit__()


class Database:
"""
An abstraction on the top of the EncodeORM databases.Database object.
Expand Down Expand Up @@ -156,6 +206,8 @@ class Database:
_force_rollback: ForceRollback
# descriptor
force_rollback = ForceRollbackDescriptor()
# async helper
async_helper: typing.Type[AsyncHelperDatabase] = AsyncHelperDatabase

def __init__(
self,
Expand Down
Loading

0 comments on commit fca84df

Please sign in to comment.