From 6b9a0a9aa09390307d446520bf787ab65fb4c1e8 Mon Sep 17 00:00:00 2001 From: Alexander Date: Thu, 8 Aug 2024 10:07:00 +0200 Subject: [PATCH] switching off force_rollback, refcounting Database connect/disconnect calls. (#34) Changes: - make force_rollback smarter - allow switching back from force_rollback - allow nesting calls of Database - document __copy__ and force_rollback - optimize delete/reset path for global weakmaps - bump version --- databasez/__init__.py | 2 +- databasez/core/connection.py | 2 +- databasez/core/database.py | 146 ++++++++++++++++++++------- databasez/core/transaction.py | 38 ++++--- databasez/interfaces.py | 5 +- databasez/testclient.py | 4 +- docs/connections-and-transactions.md | 22 +++- docs/release-notes.md | 20 ++++ tests/test_databases.py | 67 ++++++++++-- tests/test_transactions.py | 18 +++- 10 files changed, 252 insertions(+), 72 deletions(-) diff --git a/databasez/__init__.py b/databasez/__init__.py index 2777fa9..7be2e17 100644 --- a/databasez/__init__.py +++ b/databasez/__init__.py @@ -1,5 +1,5 @@ from databasez.core import Database, DatabaseURL -__version__ = "0.8.5" +__version__ = "0.9.0" __all__ = ["Database", "DatabaseURL"] diff --git a/databasez/core/connection.py b/databasez/core/connection.py index 23acbec..e3ebb70 100644 --- a/databasez/core/connection.py +++ b/databasez/core/connection.py @@ -44,7 +44,7 @@ async def __aenter__(self) -> Connection: self.connection_transaction = self.transaction( existing_transaction=raw_transaction ) - # we don't need to call __aenter__ of connection_transaction + # we don't need to call __aenter__ of connection_transaction, it is not on the stack except BaseException as e: self._connection_counter -= 1 raise e diff --git a/databasez/core/database.py b/databasez/core/database.py index feb9996..996ccef 100644 --- a/databasez/core/database.py +++ b/databasez/core/database.py @@ -6,6 +6,7 @@ import logging import typing import weakref +from contextvars import ContextVar from functools import lru_cache from types import TracebackType @@ -61,6 +62,61 @@ def init() -> None: ) +ACTIVE_FORCE_ROLLBACKS: ContextVar[ + typing.Optional[weakref.WeakKeyDictionary[ForceRollback, bool]] +] = ContextVar("ACTIVE_FORCE_ROLLBACKS", default=None) + + +class ForceRollback: + default: bool + + def __init__(self, default: bool): + self.default = default + + def set(self, value: typing.Union[bool, None] = None) -> None: + force_rollbacks = ACTIVE_FORCE_ROLLBACKS.get() + if force_rollbacks is None: + # shortcut, we don't need to initialize anything for None (reset) + if value is None: + return + force_rollbacks = weakref.WeakKeyDictionary() + else: + force_rollbacks = force_rollbacks.copy() + if value is None: + force_rollbacks.pop(self, None) + else: + force_rollbacks[self] = value + # it is always a copy required to prevent sideeffects between the contexts + ACTIVE_FORCE_ROLLBACKS.set(force_rollbacks) + + def __bool__(self) -> bool: + force_rollbacks = ACTIVE_FORCE_ROLLBACKS.get() + if force_rollbacks is None: + return self.default + return force_rollbacks.get(self, self.default) + + @contextlib.contextmanager + def __call__(self, force_rollback: bool = True) -> typing.Iterator[None]: + initial = bool(self) + self.set(force_rollback) + try: + yield + finally: + self.set(initial) + + +class ForceRollbackDescriptor: + def __get__(self, obj: Database, objtype: typing.Type[Database]) -> ForceRollback: + return obj._force_rollback + + def __set__(self, obj: Database, value: typing.Union[bool, None]) -> None: + assert value is None or isinstance(value, bool), f"Invalid type: {value!r}." + obj._force_rollback.set(value) + + def __delete__(self, obj: Database) -> None: + obj._force_rollback.set(None) + + class Database: """ An abstraction on the top of the EncodeORM databases.Database object. @@ -89,8 +145,10 @@ class Database: backend: interfaces.DatabaseBackend url: DatabaseURL options: typing.Any - _force_rollback: bool is_connected: bool = False + _force_rollback: ForceRollback + # descriptor + force_rollback = ForceRollbackDescriptor() def __init__( self, @@ -108,9 +166,7 @@ def __init__( self.url = url.url self.options = url.options if force_rollback is None: - self._force_rollback = url._force_rollback - else: - self._force_rollback = force_rollback + force_rollback = bool(url.force_rollback) else: url = DatabaseURL(url) if config and "connection" in config: @@ -121,7 +177,9 @@ def __init__( self.backend, self.url, self.options = self.apply_database_url_and_options( url, **options ) - self._force_rollback = bool(force_rollback) + if force_rollback is None: + force_rollback = False + self._force_rollback = ForceRollback(force_rollback) self.backend.owner = self self._connection_map = weakref.WeakKeyDictionary() @@ -130,6 +188,9 @@ def __init__( self._global_connection: typing.Optional[Connection] = None self._global_transaction: typing.Optional[Transaction] = None + self.ref_counter: int = 0 + self.ref_lock: asyncio.Lock = asyncio.Lock() + def __copy__(self) -> Database: return self.__class__(self) @@ -155,45 +216,65 @@ def _connection(self, connection: typing.Optional[Connection]) -> typing.Optiona return self._connection + async def inc_refcount(self) -> bool: + async with self.ref_lock: + self.ref_counter += 1 + # on the first call is count is 1 because of the former +1 + if self.ref_counter == 1: + return True + return False + + async def decr_refcount(self) -> bool: + async with self.ref_lock: + self.ref_counter -= 1 + # on the last call, the count is 0 + if self.ref_counter == 0: + return True + return False + async def connect(self) -> None: """ Establish the connection pool. """ - if self.is_connected: - logger.debug("Already connected, skipping connection") + if not await self.inc_refcount(): + assert self.is_connected, "ref_count < 0" return None await self.backend.connect(self.url, **self.options) logger.info("Connected to database %s", self.url.obscure_password, extra=CONNECT_EXTRA) self.is_connected = True - if self._force_rollback: - assert self._global_connection is None - assert self._global_transaction is None + assert self._global_connection is None + assert self._global_transaction is None - self._global_connection = Connection(self, self.backend) - self._global_transaction = self._global_connection.transaction(force_rollback=True) + self._global_connection = Connection(self, self.backend) + self._global_transaction = self._global_connection.transaction(force_rollback=True) + await self._global_transaction.__aenter__() - await self._global_transaction.__aenter__() - - async def disconnect(self) -> None: + async def disconnect(self, force: bool = False) -> None: """ Close all connections in the connection pool. """ - if not self.is_connected: - logger.debug("Already disconnected, skipping disconnection") - return None + if not await self.decr_refcount() or force: + if not self.is_connected: + logger.debug("Already disconnected, skipping disconnection") + return None + if force: + logger.warning("Force disconnect, despite refcount not 0") + else: + return None - if self._force_rollback: - assert self._global_connection is not None - assert self._global_transaction is not None + assert self._global_connection is not None + assert self._global_transaction is not None - await self._global_transaction.__aexit__() + await self._global_transaction.__aexit__() + assert ( + self._global_connection._connection_counter == 0 + ), f"global connection active: {self._global_connection._connection_counter}" - self._global_transaction = None - self._global_connection = None - else: - self._connection = None + self._global_transaction = None + self._global_connection = None + self._connection = None await self.backend.disconnect() logger.info( @@ -296,8 +377,8 @@ async def drop_all(self, meta: MetaData, **kwargs: typing.Any) -> None: await connection.drop_all(meta, **kwargs) def connection(self) -> Connection: - if self._global_connection is not None: - return self._global_connection + if self.force_rollback: + return typing.cast(Connection, self._global_connection) if not self._connection: self._connection = Connection(self, self.backend) @@ -307,15 +388,6 @@ def connection(self) -> Connection: def engine(self) -> typing.Optional[AsyncEngine]: return self.backend.engine - @contextlib.contextmanager - def force_rollback(self) -> typing.Iterator[None]: - initial = self._force_rollback - self._force_rollback = True - try: - yield - finally: - self._force_rollback = initial - @classmethod def get_backends( cls, diff --git a/databasez/core/transaction.py b/databasez/core/transaction.py index 36c6953..8874650 100644 --- a/databasez/core/transaction.py +++ b/databasez/core/transaction.py @@ -54,6 +54,9 @@ def _transaction( ) -> typing.Optional[interfaces.TransactionBackend]: transactions = ACTIVE_TRANSACTIONS.get() if transactions is None: + # shortcut, we don't need to initialize anything for None (remove transaction) + if transaction is None: + return None transactions = weakref.WeakKeyDictionary() else: transactions = transactions.copy() @@ -62,8 +65,10 @@ def _transaction( transactions.pop(self, None) else: transactions[self] = transaction - + # It is always a copy required to + # prevent sideeffects between contexts ACTIVE_TRANSACTIONS.set(transactions) + return transactions.get(self, None) async def __aenter__(self) -> Transaction: @@ -107,33 +112,38 @@ async def wrapper(*args: typing.Any, **kwargs: typing.Any) -> typing.Any: return wrapper # type: ignore async def start(self) -> Transaction: - async with self.connection._transaction_lock: - is_root = not self.connection._transaction_stack - _transaction = self.connection._connection.transaction(self._existing_transaction) + connection = self.connection + async with connection._transaction_lock: + is_root = not connection._transaction_stack + _transaction = connection._connection.transaction(self._existing_transaction) _transaction.owner = self - await self.connection.__aenter__() + await connection.__aenter__() if self._existing_transaction is None: await _transaction.start(is_root=is_root, **self._extra_options) self._transaction = _transaction - self.connection._transaction_stack.append(self) + connection._transaction_stack.append(self) return self async def commit(self) -> None: - async with self.connection._transaction_lock: + connection = self.connection + async with connection._transaction_lock: _transaction = self._transaction if _transaction is not None: + # delete transaction from ACTIVE_TRANSACTIONS self._transaction = None - assert self.connection._transaction_stack[-1] is self - self.connection._transaction_stack.pop() + assert connection._transaction_stack[-1] is self + connection._transaction_stack.pop() await _transaction.commit() - await self.connection.__aexit__() + await connection.__aexit__() async def rollback(self) -> None: - async with self.connection._transaction_lock: + connection = self.connection + async with connection._transaction_lock: _transaction = self._transaction if _transaction is not None: + # delete transaction from ACTIVE_TRANSACTIONS self._transaction = None - assert self.connection._transaction_stack[-1] is self - self.connection._transaction_stack.pop() + assert connection._transaction_stack[-1] is self + connection._transaction_stack.pop() await _transaction.rollback() - await self.connection.__aexit__() + await connection.__aexit__() diff --git a/databasez/interfaces.py b/databasez/interfaces.py index b14e8ec..16dac85 100644 --- a/databasez/interfaces.py +++ b/databasez/interfaces.py @@ -32,6 +32,7 @@ def __init__( connection: ConnectionBackend, existing_transaction: typing.Optional[Transaction] = None, ): + # cannot be a weak ref otherwise connections get lost when retrieving them via transactions self.connection = connection self.raw_transaction = existing_transaction @@ -235,14 +236,14 @@ def __copy__(self) -> DatabaseBackend: @property def owner(self) -> typing.Optional[RootDatabase]: - result = self.__dict__.get("root") + result = self.__dict__.get("owner") if result is None: return None return typing.cast("RootDatabase", result()) @owner.setter def owner(self, value: RootDatabase) -> None: - self.__dict__["root"] = weakref.ref(value) + self.__dict__["owner"] = weakref.ref(value) @abstractmethod async def connect(self, database_url: DatabaseURL, **options: typing.Any) -> None: diff --git a/databasez/testclient.py b/databasez/testclient.py index 1838544..0ade7dc 100644 --- a/databasez/testclient.py +++ b/databasez/testclient.py @@ -235,7 +235,7 @@ async def drop_database(self, url: typing.Union[str, "sa.URL", DatabaseURL]) -> await db_client.disconnect() - async def disconnect(self) -> None: + async def disconnect(self, force: bool = False) -> None: if self.drop: await self.drop_database(self.test_db_url) - await super().disconnect() + await super().disconnect(force) diff --git a/docs/connections-and-transactions.md b/docs/connections-and-transactions.md index e28afa9..313398b 100644 --- a/docs/connections-and-transactions.md +++ b/docs/connections-and-transactions.md @@ -15,13 +15,14 @@ from databasez import Database **Parameters** -* **url** - The `url` of the connection string. +* **url** - The `url` of the connection string or a Database object to copy from. Default: `None` -* **force_rollback** - A boolean flag indicating if it should force the rollback. +* **force_rollback** - An optional boolean flag for force_rollback. Overwritable at runtime possible. + Note: when None it copies the value from the provided Database object or sets it to False. - Default: `False` + Default: `None` * **config** - A python like dictionary as alternative to the `url` that contains the information to connect to the database. @@ -35,6 +36,21 @@ to connect to the database. at the same time. +**Attributes*** + +* **force_rollback**: + It evaluates its trueness value to the active value of force_rollback for this context. + You can delete it to reset it (`del database.force_rollback`) (it uses the descriptor magic). + +**Functions** + +* **__copy__** - Either usable directly or via copy from the copy module. A fresh Database object with the same options as the existing is created. + Note: for creating a copy with overwritten initial force_rollback you can use: `Database(database_obj, force_rollback=False)`. + Note: you have to connect it. + +* **force_rollback(force_rollback=True)**: - The magic attribute is also function returning a context-manager for temporary overwrites of force_rollback. + + ## Connecting and disconnecting You can control the database connection, by using it as a async context manager. diff --git a/docs/release-notes.md b/docs/release-notes.md index d83be8f..3ea5c1d 100644 --- a/docs/release-notes.md +++ b/docs/release-notes.md @@ -1,5 +1,25 @@ # Release Notes +## 0.9.0 + +### Added + +- `force_rollback` is now a descriptor returning an extensive ForceRollback object. + - Setting True, False, None is now possible for overwriting the value/resetting to the initial value (None). + - Deleting it resets it to the initial value. + - Its trueness value evaluates to the current value, context-sensitive. + - It still can be used as a contextmanager for temporary overwrites. + +### Fixed + +- Fixed refcount for global connections. + +### Changed + +- `connect`/`disconnect` calls are now refcounted. Nesting is now supported. +- ACTIVE_TRANSACTIONS dict is not replaced anymore when initialized. + + ## 0.8.5 ### Added diff --git a/tests/test_databases.py b/tests/test_databases.py index f6f8111..107d795 100644 --- a/tests/test_databases.py +++ b/tests/test_databases.py @@ -467,39 +467,92 @@ async def test_connect_and_disconnect(database_mixed_url): database = Database(**data) assert not database.is_connected - assert database._force_rollback is False + assert not database.force_rollback assert database.engine is None + assert database.ref_counter == 0 await database.connect() assert database.is_connected assert database.engine is not None + assert database.ref_counter == 1 # copy copied_db = database.__copy__() - assert database._force_rollback is False + assert copied_db.ref_counter == 0 + assert not database.force_rollback assert not copied_db.is_connected assert copied_db.engine is None # second method copied_db = Database(database, force_rollback=True) + assert copied_db.ref_counter == 0 assert not copied_db.is_connected assert copied_db.engine is None - assert copied_db._force_rollback is True + assert copied_db.force_rollback copied_db2 = copied_db.__copy__() - assert copied_db2._force_rollback is True + assert copied_db2.force_rollback old_engine = database.engine + assert database.ref_counter == 1 await database.disconnect() assert not database.is_connected + assert database.ref_counter == 0 - # connect and disconnect idempotence + # connect and disconnect refcounting await database.connect() assert database.engine is not old_engine - await database.connect() assert database.is_connected + old_engine = database.engine + # nest + async with database: + assert database.ref_counter == 2 + assert database.engine is old_engine + assert database.is_connected + assert database.ref_counter == 1 await database.disconnect() - await database.disconnect() + assert database.ref_counter == 0 assert not database.is_connected + assert not database._global_connection + assert not database._global_transaction + + +@pytest.mark.asyncio +async def test_force_rollback(database_url): + async with Database(database_url, force_rollback=False) as database: + assert not database.force_rollback + database.force_rollback = True + assert database.force_rollback + # execute() + data = {"text": "hello", "boolean": True, "int": 2} + values = {"data": data} + query = session.insert() + await database.execute(query, values) + async with database.connection() as connection_1: + assert connection_1 is database._global_connection + with database.force_rollback(False): + assert not database.force_rollback + async with database.connection() as connection_2: + assert connection_2 is not database._global_connection + if database.url.dialect != "sqlite": + # sqlite has locking problems + data = {"text": "hello", "boolean": True, "int": 1} + values = {"data": data} + query = session.insert() + await database.execute(query, values) + + # now reset + del database.force_rollback + assert not database.force_rollback + + async with Database(database_url, force_rollback=True) as database: + # fetch_all() + query = session.select() + results = await database.fetch_all(query=query) + if database.url.dialect == "sqlite": + assert len(results) == 0 + else: + assert len(results) == 1 + assert results[0].data == {"text": "hello", "boolean": True, "int": 1} @pytest.mark.asyncio diff --git a/tests/test_transactions.py b/tests/test_transactions.py index 14e4fc3..300cf36 100644 --- a/tests/test_transactions.py +++ b/tests/test_transactions.py @@ -159,6 +159,8 @@ async def check_child_connection(database: Database): # Should not have a connection for the task anymore assert len(database._connection_map) == 0 + # now cleanup + await database.disconnect() @pytest.mark.asyncio @@ -195,12 +197,17 @@ async def test_transaction_context_cleanup_garbagecollector(database_url): assert ACTIVE_TRANSACTIONS.get() is None async with Database(database_url) as database: - transaction = database.transaction() - await transaction.start() - # Should be tracking the transaction open_transactions = ACTIVE_TRANSACTIONS.get() assert isinstance(open_transactions, MutableMapping) + # the global one is always created + assert len(open_transactions) == 1 + transaction = database.transaction() + await transaction.start() + # is replaced after start() call + open_transactions = ACTIVE_TRANSACTIONS.get() + assert len(open_transactions) == 2 + assert open_transactions.get(transaction) is transaction._transaction # neither .commit, .rollback, nor .__aexit__ are called @@ -209,17 +216,18 @@ async def test_transaction_context_cleanup_garbagecollector(database_url): # A strong reference to the transaction is kept alive by the connection's # ._transaction_stack, so it is still be tracked at this point. - assert len(open_transactions) == 1 + assert len(open_transactions) == 2 # If that were magically cleared, the transaction would be cleaned up, # but as it stands this always causes a hang during teardown at # `Database(...).disconnect()` if the transaction is not closed. transaction = database.connection()._transaction_stack[-1] await transaction.rollback() + assert transaction.connection._connection_counter == 0 del transaction # Now with the transaction rolled-back, it should be cleaned up. - assert len(open_transactions) == 0 + assert len(open_transactions) == 1 @pytest.mark.asyncio