diff --git a/databasez/__init__.py b/databasez/__init__.py index 9ce5f6c..d69f2a5 100644 --- a/databasez/__init__.py +++ b/databasez/__init__.py @@ -1,5 +1,5 @@ from databasez.core import Database, DatabaseURL -__version__ = "0.9.6" +__version__ = "0.9.7" __all__ = ["Database", "DatabaseURL"] diff --git a/databasez/core/connection.py b/databasez/core/connection.py index 17c6657..94b3962 100644 --- a/databasez/core/connection.py +++ b/databasez/core/connection.py @@ -38,7 +38,7 @@ def __init__( self._force_rollback = force_rollback self.connection_transaction: typing.Optional[Transaction] = None - @multiloop_protector(True) + @multiloop_protector(False) async def __aenter__(self) -> Connection: async with self._connection_lock: self._connection_counter += 1 @@ -65,6 +65,7 @@ async def __aenter__(self) -> Connection: raise e return self + @multiloop_protector(False) async def __aexit__( self, exc_type: typing.Optional[typing.Type[BaseException]] = None, @@ -89,7 +90,7 @@ async def __aexit__( def _loop(self) -> typing.Any: return self._database._loop - @multiloop_protector(True) + @multiloop_protector(False) async def fetch_all( self, query: typing.Union[ClauseElement, str], @@ -99,7 +100,7 @@ async def fetch_all( async with self._query_lock: return await self._connection.fetch_all(built_query) - @multiloop_protector(True) + @multiloop_protector(False) async def fetch_one( self, query: typing.Union[ClauseElement, str], @@ -110,7 +111,7 @@ async def fetch_one( async with self._query_lock: return await self._connection.fetch_one(built_query, pos=pos) - @multiloop_protector(True) + @multiloop_protector(False) async def fetch_val( self, query: typing.Union[ClauseElement, str], @@ -122,7 +123,7 @@ async def fetch_val( async with self._query_lock: return await self._connection.fetch_val(built_query, column, pos=pos) - @multiloop_protector(True) + @multiloop_protector(False) async def execute( self, query: typing.Union[ClauseElement, str], @@ -136,7 +137,7 @@ async def execute( async with self._query_lock: return await self._connection.execute(query, values) - @multiloop_protector(True) + @multiloop_protector(False) async def execute_many( self, query: typing.Union[ClauseElement, str], values: typing.Any = None ) -> typing.Union[typing.Sequence[interfaces.Record], int]: @@ -148,7 +149,7 @@ async def execute_many( async with self._query_lock: return await self._connection.execute_many(query, values) - @multiloop_protector(True) + @multiloop_protector(False) async def iterate( self, query: typing.Union[ClauseElement, str], @@ -160,7 +161,7 @@ async def iterate( async for record in self._connection.iterate(built_query, batch_size): yield record - @multiloop_protector(True) + @multiloop_protector(False) async def batched_iterate( self, query: typing.Union[ClauseElement, str], @@ -172,7 +173,7 @@ async def batched_iterate( async for records in self._connection.batched_iterate(built_query, batch_size): yield records - @multiloop_protector(True) + @multiloop_protector(False) async def run_sync( self, fn: typing.Callable[..., typing.Any], @@ -182,15 +183,15 @@ async def run_sync( async with self._query_lock: return await self._connection.run_sync(fn, *args, **kwargs) - @multiloop_protector(True) + @multiloop_protector(False) async def create_all(self, meta: MetaData, **kwargs: typing.Any) -> None: await self.run_sync(meta.create_all, **kwargs) - @multiloop_protector(True) + @multiloop_protector(False) async def drop_all(self, meta: MetaData, **kwargs: typing.Any) -> None: await self.run_sync(meta.drop_all, **kwargs) - @multiloop_protector(True) + @multiloop_protector(False) def transaction(self, *, force_rollback: bool = False, **kwargs: typing.Any) -> "Transaction": return Transaction(weakref.ref(self), force_rollback, **kwargs) @@ -200,7 +201,7 @@ def async_connection(self) -> typing.Any: """The first layer (sqlalchemy).""" return self._connection.async_connection - @multiloop_protector(True) + @multiloop_protector(False) async def get_raw_connection(self) -> typing.Any: """The real raw connection (driver).""" return await self.async_connection.get_raw_connection() diff --git a/databasez/core/database.py b/databasez/core/database.py index c9a4cbc..1f564e7 100644 --- a/databasez/core/database.py +++ b/databasez/core/database.py @@ -63,12 +63,6 @@ def init() -> None: ) -# we need a dict to ensure the references are kept -ACTIVE_DATABASES: ContextVar[typing.Optional[typing.Dict[typing.Any, Database]]] = ContextVar( - "ACTIVE_DATABASES", default=None -) - - ACTIVE_FORCE_ROLLBACKS: ContextVar[ typing.Optional[weakref.WeakKeyDictionary[ForceRollback, bool]] ] = ContextVar("ACTIVE_FORCE_ROLLBACKS", default=None) @@ -149,11 +143,13 @@ class Database: """ _connection_map: weakref.WeakKeyDictionary[asyncio.Task, Connection] + _databases_map: typing.Dict[typing.Any, Database] _loop: typing.Any = None backend: interfaces.DatabaseBackend url: DatabaseURL options: typing.Any is_connected: bool = False + _call_hooks: bool = True _force_rollback: ForceRollback # descriptor force_rollback = ForceRollbackDescriptor() @@ -173,6 +169,7 @@ def __init__( self.backend = url.backend.__copy__() self.url = url.url self.options = url.options + self._call_hooks = url._call_hooks if force_rollback is None: force_rollback = bool(url.force_rollback) else: @@ -190,6 +187,7 @@ def __init__( self._force_rollback = ForceRollback(force_rollback) self.backend.owner = self self._connection_map = weakref.WeakKeyDictionary() + self._databases_map = {} # When `force_rollback=True` is used, we use a single global # connection, within a transaction that always rolls back. @@ -224,6 +222,14 @@ def _connection(self, connection: typing.Optional[Connection]) -> typing.Optiona return self._connection async def inc_refcount(self) -> bool: + """ + Internal method to bump the ref_count. + + Return True if ref_count is 0, False otherwise. + + Should not be used outside of tests. Use connect and hooks instead. + Not multithreading safe! + """ async with self.ref_lock: self.ref_counter += 1 # on the first call is count is 1 because of the former +1 @@ -232,6 +238,14 @@ async def inc_refcount(self) -> bool: return False async def decr_refcount(self) -> bool: + """ + Internal method to decrease the ref_count. + + Return True if ref_count drops to 0, False otherwise. + + Should not be used outside of tests. Use disconnect and hooks instead. + Not multithreading safe! + """ async with self.ref_lock: self.ref_counter -= 1 # on the last call, the count is 0 @@ -242,38 +256,53 @@ async def decr_refcount(self) -> bool: async def connect_hook(self) -> None: """Refcount protected connect hook. Executed begore engine and global connection setup.""" - @multiloop_protector(True) async def connect(self) -> bool: """ Establish the connection pool. """ + loop = asyncio.get_running_loop() + if self._loop is not None and loop != self._loop: + # copy when not in map + if loop not in self._databases_map: + # prevent side effects of connect_hook + database = self.__copy__() + database._call_hooks = False + assert self._global_connection + database._global_connection = await self._global_connection.__aenter__() + self._databases_map[loop] = database + # forward call + return await self._databases_map[loop].connect() + if not await self.inc_refcount(): assert self.is_connected, "ref_count < 0" return False - try: - await self.connect_hook() - except BaseException as exc: - await self.decr_refcount() - raise exc + if self._call_hooks: + try: + await self.connect_hook() + except BaseException as exc: + await self.decr_refcount() + raise exc self._loop = asyncio.get_event_loop() await self.backend.connect(self.url, **self.options) logger.info("Connected to database %s", self.url.obscure_password, extra=CONNECT_EXTRA) self.is_connected = True - assert self._global_connection is None - - self._global_connection = Connection(self, self.backend, force_rollback=True) + if self._global_connection is None: + self._global_connection = Connection(self, self.backend, force_rollback=True) return True async def disconnect_hook(self) -> None: """Refcount protected disconnect hook. Executed after connection, engine cleanup.""" - @multiloop_protector(True) - async def disconnect(self, force: bool = False) -> bool: + @multiloop_protector(True, inject_parent=True) + async def disconnect( + self, force: bool = False, *, parent_database: typing.Optional[Database] = None + ) -> bool: """ Close all connections in the connection pool. """ + # parent_database is injected and should not be specified manually if not await self.decr_refcount() or force: if not self.is_connected: logger.debug("Already disconnected, skipping disconnection") @@ -282,6 +311,14 @@ async def disconnect(self, force: bool = False) -> bool: logger.warning("Force disconnect, despite refcount not 0") else: return False + if parent_database is not None: + loop = asyncio.get_running_loop() + del parent_database._databases_map[loop] + if force: + for sub_database in self._databases_map.values(): + await sub_database.disconnect(True) + self._databases_map = {} + assert not self._databases_map, "sub databases still active" try: assert self._global_connection is not None @@ -297,23 +334,15 @@ async def disconnect(self, force: bool = False) -> bool: self.is_connected = False await self.backend.disconnect() self._loop = None - await self.disconnect_hook() + if self._call_hooks: + await self.disconnect_hook() return True async def __aenter__(self) -> "Database": + await self.connect() + # get right database loop = asyncio.get_running_loop() - database = self - if self._loop is not None and loop != self._loop: - dbs = ACTIVE_DATABASES.get() - if dbs is None: - dbs = {} - else: - dbs = dbs.copy() - database = self.__copy__() - dbs[loop] = database - # it is always a copy required to prevent sideeffects between the contexts - ACTIVE_DATABASES.set(dbs) - await database.connect() + database = self._databases_map.get(loop, self) return database async def __aexit__( @@ -322,13 +351,7 @@ async def __aexit__( exc_value: typing.Optional[BaseException] = None, traceback: typing.Optional[TracebackType] = None, ) -> None: - loop = asyncio.get_running_loop() - database = self - if self._loop is not None and loop != self._loop: - dbs = ACTIVE_DATABASES.get() - if dbs is not None: - database = dbs.pop(loop, database) - await database.disconnect() + await self.disconnect() @multiloop_protector(False) async def fetch_all( @@ -424,7 +447,7 @@ async def drop_all(self, meta: MetaData, **kwargs: typing.Any) -> None: async with self.connection() as connection: await connection.drop_all(meta, **kwargs) - @multiloop_protector(False, wrap_context_manager=True) + @multiloop_protector(False) def connection(self) -> Connection: if self.force_rollback: return typing.cast(Connection, self._global_connection) diff --git a/databasez/overwrites/jdbc.py b/databasez/overwrites/jdbc.py index 4f02859..394680c 100644 --- a/databasez/overwrites/jdbc.py +++ b/databasez/overwrites/jdbc.py @@ -2,6 +2,8 @@ import typing from pathlib import Path +# ensure jpype.dbapi2 is initialized. Prevent race condition. +import jpype.dbapi2 # noqa from jpype import addClassPath, isJVMStarted, startJVM from databasez.sqlalchemy import SQLAlchemyDatabase, SQLAlchemyTransaction diff --git a/databasez/utils.py b/databasez/utils.py index 4ea3a97..07262f5 100644 --- a/databasez/utils.py +++ b/databasez/utils.py @@ -1,9 +1,9 @@ import asyncio import inspect import typing -from contextlib import asynccontextmanager from functools import partial, wraps from threading import Thread +from types import TracebackType async_wrapper_slots = ( "_async_wrapped", @@ -140,25 +140,98 @@ def join(self, timeout: typing.Union[float, int, None] = None) -> None: MultiloopProtectorCallable = typing.TypeVar("MultiloopProtectorCallable", bound=typing.Callable) -async def _async_helper( - database: typing.Any, fn: MultiloopProtectorCallable, *args: typing.Any, **kwargs: typing.Any -) -> typing.Any: - # copy - async with database.__class__(database) as new_database: - return await fn(new_database, *args, **kwargs) +class AsyncHelperDatabase: + def __init__( + self, + database: typing.Any, + fn: typing.Callable, + *args: typing.Any, + **kwargs: typing.Any, + ) -> None: + self.database = database.__copy__() + self.fn = partial(fn, self.database, *args, **kwargs) + self.ctm = None + + async def call(self) -> typing.Any: + async with self.database: + return await self.fn() + + def __await__(self) -> typing.Any: + return self.call().__await__() + + async def __aenter__(self) -> typing.Any: + await self.database.__aenter__() + self.ctm = self.fn() + return await self.ctm.__aenter__() + + async def __aexit__( + self, + exc_type: typing.Optional[typing.Type[BaseException]] = None, + exc_value: typing.Optional[BaseException] = None, + traceback: typing.Optional[TracebackType] = None, + ) -> None: + assert self.ctm is not None + try: + await self.ctm.__aexit__(exc_type, exc_value, traceback) + finally: + await self.database.__aexit__() + + +class AsyncHelperConnection: + def __init__( + self, + connection: typing.Any, + fn: typing.Callable, + *args: typing.Any, + **kwargs: typing.Any, + ) -> None: + self.connection = connection + self.fn = partial(fn, self.connection, *args, **kwargs) + self.ctm = None + + async def call(self) -> typing.Any: + async with self.connection: + result = self.fn() + if inspect.isawaitable(result): + result = await result + return result + + async def acall(self) -> typing.Any: + return asyncio.run_coroutine_threadsafe(self.call(), self.connection._loop).result() + + def __await__(self) -> typing.Any: + return self.acall().__await__() + async def enter_intern(self) -> typing.Any: + await self.connection.__aenter__() + self.ctm = await self.call() + return await self.ctm.__aenter__() -@asynccontextmanager -async def _contextmanager_helper( - database: typing.Any, fn: MultiloopProtectorCallable, *args: typing.Any, **kwargs: typing.Any -) -> typing.Any: - async with database.__copy__() as new_database: - async with fn(new_database, *args, **kwargs) as result: - yield result + async def exit_intern(self) -> typing.Any: + assert self.ctm is not None + try: + await self.ctm.__aexit__() + finally: + self.ctm = None + await self.connection.__aexit__() + + async def __aenter__(self) -> typing.Any: + return asyncio.run_coroutine_threadsafe( + self.enter_intern(), self.connection._loop + ).result() + + async def __aexit__( + self, + exc_type: typing.Optional[typing.Type[BaseException]] = None, + exc_value: typing.Optional[BaseException] = None, + traceback: typing.Optional[TracebackType] = None, + ) -> None: + assert self.ctm is not None + asyncio.run_coroutine_threadsafe(self.exit_intern(), self.connection._loop).result() def multiloop_protector( - fail_with_different_loop: bool, wrap_context_manager: bool = False + fail_with_different_loop: bool, inject_parent: bool = False ) -> typing.Callable[[MultiloopProtectorCallable], MultiloopProtectorCallable]: """For multiple threads or other reasons why the loop changes""" @@ -167,16 +240,29 @@ def multiloop_protector( def _decorator(fn: MultiloopProtectorCallable) -> MultiloopProtectorCallable: @wraps(fn) def wrapper(self: typing.Any, *args: typing.Any, **kwargs: typing.Any) -> typing.Any: + if inject_parent: + assert "parent_database" not in kwargs, '"parent_database" is a reserved keyword' try: loop = asyncio.get_running_loop() except RuntimeError: loop = None if loop is not None and self._loop is not None and loop != self._loop: - if fail_with_different_loop: - raise RuntimeError("Different loop used") - if wrap_context_manager: - return _contextmanager_helper(self, fn, *args, **kwargs) - return _async_helper(self, fn, *args, **kwargs) + # redirect call if self is Database and loop is in sub databases referenced + # afaik we can careless continue use the old database object from a subloop and all protected + # methods are forwarded + if hasattr(self, "_databases_map") and loop in self._databases_map: + if inject_parent: + kwargs["parent_database"] = self + self = self._databases_map[loop] + else: + if fail_with_different_loop: + raise RuntimeError("Different loop used") + helper = ( + AsyncHelperDatabase + if hasattr(self, "_databases_map") + else AsyncHelperConnection + ) + return helper(self, fn, *args, **kwargs) return fn(self, *args, **kwargs) return typing.cast(MultiloopProtectorCallable, wrapper) diff --git a/docs/release-notes.md b/docs/release-notes.md index bea57ec..c099b24 100644 --- a/docs/release-notes.md +++ b/docs/release-notes.md @@ -1,6 +1,17 @@ # Release Notes +## 0.9.7 + +### Added + +- It is now possible to use connect(), disconnect() instead of a async contextmanager in multi-loop calls (multithreading). + +### Fixed + +- Database calls are forwarded to subdatabase when possible. This unbreaks using not the returned database object. +- `force_rollback` works also in multi-loop call (multithreading). + ## 0.9.6 ### Fixed diff --git a/tests/test_databases.py b/tests/test_databases.py index c3cbe6c..c851e94 100644 --- a/tests/test_databases.py +++ b/tests/test_databases.py @@ -889,14 +889,27 @@ async def wrap_in_thread(): await asyncio.gather(db_lookup(False), wrap_in_thread(), wrap_in_thread()) +@pytest.mark.parametrize("force_rollback", [True, False]) @pytest.mark.asyncio -async def test_multi_thread_db_contextmanager(database_url): - async with Database(database_url, force_rollback=False) as database: +async def test_multi_thread_db_contextmanager(database_url, force_rollback): + async with Database(database_url, force_rollback=force_rollback) as database: + query = notes.insert().values(text="examplecontext", completed=True) + await database.execute(query) async def db_connect(depth=3): # many parallel and nested threads async with database as new_database: - await new_database.fetch_one("SELECT 1") + query = notes.select() + result = await database.fetch_one(query) + assert result.text == "examplecontext" + assert result.completed is True + # test delegate to sub database + assert database.engine is new_database.engine + # also this shouldn't fail because redirected + old_refcount = new_database.ref_counter + await database.connect() + assert new_database.ref_counter == old_refcount + 1 + await database.disconnect() ops = [] while depth >= 0: depth -= 1 @@ -906,19 +919,49 @@ async def db_connect(depth=3): await to_thread(asyncio.run, db_connect()) assert database.ref_counter == 0 + if force_rollback: + async with database: + query = notes.select() + result = await database.fetch_one(query) + assert result is None @pytest.mark.asyncio -async def test_multi_thread_db_connect_fails(database_url): +async def test_multi_thread_db_connect(database_url): async with Database(database_url, force_rollback=True) as database: async def db_connect(): await database.connect() + await database.fetch_one("SELECT 1") + await database.disconnect() + + await to_thread(asyncio.run, db_connect()) + + +@pytest.mark.asyncio +async def test_multi_thread_db_fails(database_url): + async with Database(database_url, force_rollback=True) as database: + + async def db_connect(): + # not in same loop + database.disconnect() with pytest.raises(RuntimeError): await to_thread(asyncio.run, db_connect()) +@pytest.mark.asyncio +async def test_error_on_passed_parent_database(database_url): + database = Database(database_url) + # don't allow specifying parent_database + with pytest.raises(AssertionError): + await database.disconnect(parent_database=None) + with pytest.raises(AssertionError): + await database.disconnect(parent_database="") + with pytest.raises(TypeError): + await database.disconnect(False, None) + + @pytest.mark.asyncio async def test_global_connection_is_initialized_lazily(database_url): """