Skip to content

Commit

Permalink
switching off force_rollback, refcounting Database connect/disconnect…
Browse files Browse the repository at this point in the history
… calls. (#34)

Changes:

- make force_rollback smarter
- allow switching back from force_rollback
- allow nesting calls of Database
- document __copy__ and  force_rollback
- optimize delete/reset path for global weakmaps
- bump version
  • Loading branch information
devkral authored Aug 8, 2024
1 parent 4f753cd commit 6b9a0a9
Show file tree
Hide file tree
Showing 10 changed files with 252 additions and 72 deletions.
2 changes: 1 addition & 1 deletion databasez/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from databasez.core import Database, DatabaseURL

__version__ = "0.8.5"
__version__ = "0.9.0"

__all__ = ["Database", "DatabaseURL"]
2 changes: 1 addition & 1 deletion databasez/core/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ async def __aenter__(self) -> Connection:
self.connection_transaction = self.transaction(
existing_transaction=raw_transaction
)
# we don't need to call __aenter__ of connection_transaction
# we don't need to call __aenter__ of connection_transaction, it is not on the stack
except BaseException as e:
self._connection_counter -= 1
raise e
Expand Down
146 changes: 109 additions & 37 deletions databasez/core/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import logging
import typing
import weakref
from contextvars import ContextVar
from functools import lru_cache
from types import TracebackType

Expand Down Expand Up @@ -61,6 +62,61 @@ def init() -> None:
)


ACTIVE_FORCE_ROLLBACKS: ContextVar[
typing.Optional[weakref.WeakKeyDictionary[ForceRollback, bool]]
] = ContextVar("ACTIVE_FORCE_ROLLBACKS", default=None)


class ForceRollback:
default: bool

def __init__(self, default: bool):
self.default = default

def set(self, value: typing.Union[bool, None] = None) -> None:
force_rollbacks = ACTIVE_FORCE_ROLLBACKS.get()
if force_rollbacks is None:
# shortcut, we don't need to initialize anything for None (reset)
if value is None:
return
force_rollbacks = weakref.WeakKeyDictionary()
else:
force_rollbacks = force_rollbacks.copy()
if value is None:
force_rollbacks.pop(self, None)
else:
force_rollbacks[self] = value
# it is always a copy required to prevent sideeffects between the contexts
ACTIVE_FORCE_ROLLBACKS.set(force_rollbacks)

def __bool__(self) -> bool:
force_rollbacks = ACTIVE_FORCE_ROLLBACKS.get()
if force_rollbacks is None:
return self.default
return force_rollbacks.get(self, self.default)

@contextlib.contextmanager
def __call__(self, force_rollback: bool = True) -> typing.Iterator[None]:
initial = bool(self)
self.set(force_rollback)
try:
yield
finally:
self.set(initial)


class ForceRollbackDescriptor:
def __get__(self, obj: Database, objtype: typing.Type[Database]) -> ForceRollback:
return obj._force_rollback

def __set__(self, obj: Database, value: typing.Union[bool, None]) -> None:
assert value is None or isinstance(value, bool), f"Invalid type: {value!r}."
obj._force_rollback.set(value)

def __delete__(self, obj: Database) -> None:
obj._force_rollback.set(None)


class Database:
"""
An abstraction on the top of the EncodeORM databases.Database object.
Expand Down Expand Up @@ -89,8 +145,10 @@ class Database:
backend: interfaces.DatabaseBackend
url: DatabaseURL
options: typing.Any
_force_rollback: bool
is_connected: bool = False
_force_rollback: ForceRollback
# descriptor
force_rollback = ForceRollbackDescriptor()

def __init__(
self,
Expand All @@ -108,9 +166,7 @@ def __init__(
self.url = url.url
self.options = url.options
if force_rollback is None:
self._force_rollback = url._force_rollback
else:
self._force_rollback = force_rollback
force_rollback = bool(url.force_rollback)
else:
url = DatabaseURL(url)
if config and "connection" in config:
Expand All @@ -121,7 +177,9 @@ def __init__(
self.backend, self.url, self.options = self.apply_database_url_and_options(
url, **options
)
self._force_rollback = bool(force_rollback)
if force_rollback is None:
force_rollback = False
self._force_rollback = ForceRollback(force_rollback)
self.backend.owner = self
self._connection_map = weakref.WeakKeyDictionary()

Expand All @@ -130,6 +188,9 @@ def __init__(
self._global_connection: typing.Optional[Connection] = None
self._global_transaction: typing.Optional[Transaction] = None

self.ref_counter: int = 0
self.ref_lock: asyncio.Lock = asyncio.Lock()

def __copy__(self) -> Database:
return self.__class__(self)

Expand All @@ -155,45 +216,65 @@ def _connection(self, connection: typing.Optional[Connection]) -> typing.Optiona

return self._connection

async def inc_refcount(self) -> bool:
async with self.ref_lock:
self.ref_counter += 1
# on the first call is count is 1 because of the former +1
if self.ref_counter == 1:
return True
return False

async def decr_refcount(self) -> bool:
async with self.ref_lock:
self.ref_counter -= 1
# on the last call, the count is 0
if self.ref_counter == 0:
return True
return False

async def connect(self) -> None:
"""
Establish the connection pool.
"""
if self.is_connected:
logger.debug("Already connected, skipping connection")
if not await self.inc_refcount():
assert self.is_connected, "ref_count < 0"
return None

await self.backend.connect(self.url, **self.options)
logger.info("Connected to database %s", self.url.obscure_password, extra=CONNECT_EXTRA)
self.is_connected = True

if self._force_rollback:
assert self._global_connection is None
assert self._global_transaction is None
assert self._global_connection is None
assert self._global_transaction is None

self._global_connection = Connection(self, self.backend)
self._global_transaction = self._global_connection.transaction(force_rollback=True)
self._global_connection = Connection(self, self.backend)
self._global_transaction = self._global_connection.transaction(force_rollback=True)
await self._global_transaction.__aenter__()

await self._global_transaction.__aenter__()

async def disconnect(self) -> None:
async def disconnect(self, force: bool = False) -> None:
"""
Close all connections in the connection pool.
"""
if not self.is_connected:
logger.debug("Already disconnected, skipping disconnection")
return None
if not await self.decr_refcount() or force:
if not self.is_connected:
logger.debug("Already disconnected, skipping disconnection")
return None
if force:
logger.warning("Force disconnect, despite refcount not 0")
else:
return None

if self._force_rollback:
assert self._global_connection is not None
assert self._global_transaction is not None
assert self._global_connection is not None
assert self._global_transaction is not None

await self._global_transaction.__aexit__()
await self._global_transaction.__aexit__()
assert (
self._global_connection._connection_counter == 0
), f"global connection active: {self._global_connection._connection_counter}"

self._global_transaction = None
self._global_connection = None
else:
self._connection = None
self._global_transaction = None
self._global_connection = None
self._connection = None

await self.backend.disconnect()
logger.info(
Expand Down Expand Up @@ -296,8 +377,8 @@ async def drop_all(self, meta: MetaData, **kwargs: typing.Any) -> None:
await connection.drop_all(meta, **kwargs)

def connection(self) -> Connection:
if self._global_connection is not None:
return self._global_connection
if self.force_rollback:
return typing.cast(Connection, self._global_connection)

if not self._connection:
self._connection = Connection(self, self.backend)
Expand All @@ -307,15 +388,6 @@ def connection(self) -> Connection:
def engine(self) -> typing.Optional[AsyncEngine]:
return self.backend.engine

@contextlib.contextmanager
def force_rollback(self) -> typing.Iterator[None]:
initial = self._force_rollback
self._force_rollback = True
try:
yield
finally:
self._force_rollback = initial

@classmethod
def get_backends(
cls,
Expand Down
38 changes: 24 additions & 14 deletions databasez/core/transaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,9 @@ def _transaction(
) -> typing.Optional[interfaces.TransactionBackend]:
transactions = ACTIVE_TRANSACTIONS.get()
if transactions is None:
# shortcut, we don't need to initialize anything for None (remove transaction)
if transaction is None:
return None
transactions = weakref.WeakKeyDictionary()
else:
transactions = transactions.copy()
Expand All @@ -62,8 +65,10 @@ def _transaction(
transactions.pop(self, None)
else:
transactions[self] = transaction

# It is always a copy required to
# prevent sideeffects between contexts
ACTIVE_TRANSACTIONS.set(transactions)

return transactions.get(self, None)

async def __aenter__(self) -> Transaction:
Expand Down Expand Up @@ -107,33 +112,38 @@ async def wrapper(*args: typing.Any, **kwargs: typing.Any) -> typing.Any:
return wrapper # type: ignore

async def start(self) -> Transaction:
async with self.connection._transaction_lock:
is_root = not self.connection._transaction_stack
_transaction = self.connection._connection.transaction(self._existing_transaction)
connection = self.connection
async with connection._transaction_lock:
is_root = not connection._transaction_stack
_transaction = connection._connection.transaction(self._existing_transaction)
_transaction.owner = self
await self.connection.__aenter__()
await connection.__aenter__()
if self._existing_transaction is None:
await _transaction.start(is_root=is_root, **self._extra_options)
self._transaction = _transaction
self.connection._transaction_stack.append(self)
connection._transaction_stack.append(self)
return self

async def commit(self) -> None:
async with self.connection._transaction_lock:
connection = self.connection
async with connection._transaction_lock:
_transaction = self._transaction
if _transaction is not None:
# delete transaction from ACTIVE_TRANSACTIONS
self._transaction = None
assert self.connection._transaction_stack[-1] is self
self.connection._transaction_stack.pop()
assert connection._transaction_stack[-1] is self
connection._transaction_stack.pop()
await _transaction.commit()
await self.connection.__aexit__()
await connection.__aexit__()

async def rollback(self) -> None:
async with self.connection._transaction_lock:
connection = self.connection
async with connection._transaction_lock:
_transaction = self._transaction
if _transaction is not None:
# delete transaction from ACTIVE_TRANSACTIONS
self._transaction = None
assert self.connection._transaction_stack[-1] is self
self.connection._transaction_stack.pop()
assert connection._transaction_stack[-1] is self
connection._transaction_stack.pop()
await _transaction.rollback()
await self.connection.__aexit__()
await connection.__aexit__()
5 changes: 3 additions & 2 deletions databasez/interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ def __init__(
connection: ConnectionBackend,
existing_transaction: typing.Optional[Transaction] = None,
):
# cannot be a weak ref otherwise connections get lost when retrieving them via transactions
self.connection = connection
self.raw_transaction = existing_transaction

Expand Down Expand Up @@ -235,14 +236,14 @@ def __copy__(self) -> DatabaseBackend:

@property
def owner(self) -> typing.Optional[RootDatabase]:
result = self.__dict__.get("root")
result = self.__dict__.get("owner")
if result is None:
return None
return typing.cast("RootDatabase", result())

@owner.setter
def owner(self, value: RootDatabase) -> None:
self.__dict__["root"] = weakref.ref(value)
self.__dict__["owner"] = weakref.ref(value)

@abstractmethod
async def connect(self, database_url: DatabaseURL, **options: typing.Any) -> None:
Expand Down
4 changes: 2 additions & 2 deletions databasez/testclient.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ async def drop_database(self, url: typing.Union[str, "sa.URL", DatabaseURL]) ->

await db_client.disconnect()

async def disconnect(self) -> None:
async def disconnect(self, force: bool = False) -> None:
if self.drop:
await self.drop_database(self.test_db_url)
await super().disconnect()
await super().disconnect(force)
22 changes: 19 additions & 3 deletions docs/connections-and-transactions.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,14 @@ from databasez import Database

**Parameters**

* **url** - The `url` of the connection string.
* **url** - The `url` of the connection string or a Database object to copy from.

<sup>Default: `None`</sup>

* **force_rollback** - A boolean flag indicating if it should force the rollback.
* **force_rollback** - An optional boolean flag for force_rollback. Overwritable at runtime possible.
Note: when None it copies the value from the provided Database object or sets it to False.

<sup>Default: `False`</sup>
<sup>Default: `None`</sup>

* **config** - A python like dictionary as alternative to the `url` that contains the information
to connect to the database.
Expand All @@ -35,6 +36,21 @@ to connect to the database.
at the same time.


**Attributes***

* **force_rollback**:
It evaluates its trueness value to the active value of force_rollback for this context.
You can delete it to reset it (`del database.force_rollback`) (it uses the descriptor magic).

**Functions**

* **__copy__** - Either usable directly or via copy from the copy module. A fresh Database object with the same options as the existing is created.
Note: for creating a copy with overwritten initial force_rollback you can use: `Database(database_obj, force_rollback=False)`.
Note: you have to connect it.

* **force_rollback(force_rollback=True)**: - The magic attribute is also function returning a context-manager for temporary overwrites of force_rollback.


## Connecting and disconnecting

You can control the database connection, by using it as a async context manager.
Expand Down
Loading

0 comments on commit 6b9a0a9

Please sign in to comment.