Skip to content

Commit

Permalink
improve multiloop callforwarding (#46)
Browse files Browse the repository at this point in the history
Changes:

- improve call forwarding
- allow now using connect instead of the async contextmanager
- unbreak idioms which does not use the returned Database of async
  contextmanager but the parent database in multi loop contexts
- unbreak force_rollback in multi loop contexts
  • Loading branch information
devkral authored Aug 21, 2024
1 parent 46d5d2a commit 0893af7
Show file tree
Hide file tree
Showing 7 changed files with 242 additions and 76 deletions.
2 changes: 1 addition & 1 deletion databasez/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from databasez.core import Database, DatabaseURL

__version__ = "0.9.6"
__version__ = "0.9.7"

__all__ = ["Database", "DatabaseURL"]
27 changes: 14 additions & 13 deletions databasez/core/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def __init__(
self._force_rollback = force_rollback
self.connection_transaction: typing.Optional[Transaction] = None

@multiloop_protector(True)
@multiloop_protector(False)
async def __aenter__(self) -> Connection:
async with self._connection_lock:
self._connection_counter += 1
Expand All @@ -65,6 +65,7 @@ async def __aenter__(self) -> Connection:
raise e
return self

@multiloop_protector(False)
async def __aexit__(
self,
exc_type: typing.Optional[typing.Type[BaseException]] = None,
Expand All @@ -89,7 +90,7 @@ async def __aexit__(
def _loop(self) -> typing.Any:
return self._database._loop

@multiloop_protector(True)
@multiloop_protector(False)
async def fetch_all(
self,
query: typing.Union[ClauseElement, str],
Expand All @@ -99,7 +100,7 @@ async def fetch_all(
async with self._query_lock:
return await self._connection.fetch_all(built_query)

@multiloop_protector(True)
@multiloop_protector(False)
async def fetch_one(
self,
query: typing.Union[ClauseElement, str],
Expand All @@ -110,7 +111,7 @@ async def fetch_one(
async with self._query_lock:
return await self._connection.fetch_one(built_query, pos=pos)

@multiloop_protector(True)
@multiloop_protector(False)
async def fetch_val(
self,
query: typing.Union[ClauseElement, str],
Expand All @@ -122,7 +123,7 @@ async def fetch_val(
async with self._query_lock:
return await self._connection.fetch_val(built_query, column, pos=pos)

@multiloop_protector(True)
@multiloop_protector(False)
async def execute(
self,
query: typing.Union[ClauseElement, str],
Expand All @@ -136,7 +137,7 @@ async def execute(
async with self._query_lock:
return await self._connection.execute(query, values)

@multiloop_protector(True)
@multiloop_protector(False)
async def execute_many(
self, query: typing.Union[ClauseElement, str], values: typing.Any = None
) -> typing.Union[typing.Sequence[interfaces.Record], int]:
Expand All @@ -148,7 +149,7 @@ async def execute_many(
async with self._query_lock:
return await self._connection.execute_many(query, values)

@multiloop_protector(True)
@multiloop_protector(False)
async def iterate(
self,
query: typing.Union[ClauseElement, str],
Expand All @@ -160,7 +161,7 @@ async def iterate(
async for record in self._connection.iterate(built_query, batch_size):
yield record

@multiloop_protector(True)
@multiloop_protector(False)
async def batched_iterate(
self,
query: typing.Union[ClauseElement, str],
Expand All @@ -172,7 +173,7 @@ async def batched_iterate(
async for records in self._connection.batched_iterate(built_query, batch_size):
yield records

@multiloop_protector(True)
@multiloop_protector(False)
async def run_sync(
self,
fn: typing.Callable[..., typing.Any],
Expand All @@ -182,15 +183,15 @@ async def run_sync(
async with self._query_lock:
return await self._connection.run_sync(fn, *args, **kwargs)

@multiloop_protector(True)
@multiloop_protector(False)
async def create_all(self, meta: MetaData, **kwargs: typing.Any) -> None:
await self.run_sync(meta.create_all, **kwargs)

@multiloop_protector(True)
@multiloop_protector(False)
async def drop_all(self, meta: MetaData, **kwargs: typing.Any) -> None:
await self.run_sync(meta.drop_all, **kwargs)

@multiloop_protector(True)
@multiloop_protector(False)
def transaction(self, *, force_rollback: bool = False, **kwargs: typing.Any) -> "Transaction":
return Transaction(weakref.ref(self), force_rollback, **kwargs)

Expand All @@ -200,7 +201,7 @@ def async_connection(self) -> typing.Any:
"""The first layer (sqlalchemy)."""
return self._connection.async_connection

@multiloop_protector(True)
@multiloop_protector(False)
async def get_raw_connection(self) -> typing.Any:
"""The real raw connection (driver)."""
return await self.async_connection.get_raw_connection()
Expand Down
99 changes: 61 additions & 38 deletions databasez/core/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,12 +63,6 @@ def init() -> None:
)


# we need a dict to ensure the references are kept
ACTIVE_DATABASES: ContextVar[typing.Optional[typing.Dict[typing.Any, Database]]] = ContextVar(
"ACTIVE_DATABASES", default=None
)


ACTIVE_FORCE_ROLLBACKS: ContextVar[
typing.Optional[weakref.WeakKeyDictionary[ForceRollback, bool]]
] = ContextVar("ACTIVE_FORCE_ROLLBACKS", default=None)
Expand Down Expand Up @@ -149,11 +143,13 @@ class Database:
"""

_connection_map: weakref.WeakKeyDictionary[asyncio.Task, Connection]
_databases_map: typing.Dict[typing.Any, Database]
_loop: typing.Any = None
backend: interfaces.DatabaseBackend
url: DatabaseURL
options: typing.Any
is_connected: bool = False
_call_hooks: bool = True
_force_rollback: ForceRollback
# descriptor
force_rollback = ForceRollbackDescriptor()
Expand All @@ -173,6 +169,7 @@ def __init__(
self.backend = url.backend.__copy__()
self.url = url.url
self.options = url.options
self._call_hooks = url._call_hooks
if force_rollback is None:
force_rollback = bool(url.force_rollback)
else:
Expand All @@ -190,6 +187,7 @@ def __init__(
self._force_rollback = ForceRollback(force_rollback)
self.backend.owner = self
self._connection_map = weakref.WeakKeyDictionary()
self._databases_map = {}

# When `force_rollback=True` is used, we use a single global
# connection, within a transaction that always rolls back.
Expand Down Expand Up @@ -224,6 +222,14 @@ def _connection(self, connection: typing.Optional[Connection]) -> typing.Optiona
return self._connection

async def inc_refcount(self) -> bool:
"""
Internal method to bump the ref_count.
Return True if ref_count is 0, False otherwise.
Should not be used outside of tests. Use connect and hooks instead.
Not multithreading safe!
"""
async with self.ref_lock:
self.ref_counter += 1
# on the first call is count is 1 because of the former +1
Expand All @@ -232,6 +238,14 @@ async def inc_refcount(self) -> bool:
return False

async def decr_refcount(self) -> bool:
"""
Internal method to decrease the ref_count.
Return True if ref_count drops to 0, False otherwise.
Should not be used outside of tests. Use disconnect and hooks instead.
Not multithreading safe!
"""
async with self.ref_lock:
self.ref_counter -= 1
# on the last call, the count is 0
Expand All @@ -242,38 +256,53 @@ async def decr_refcount(self) -> bool:
async def connect_hook(self) -> None:
"""Refcount protected connect hook. Executed begore engine and global connection setup."""

@multiloop_protector(True)
async def connect(self) -> bool:
"""
Establish the connection pool.
"""
loop = asyncio.get_running_loop()
if self._loop is not None and loop != self._loop:
# copy when not in map
if loop not in self._databases_map:
# prevent side effects of connect_hook
database = self.__copy__()
database._call_hooks = False
assert self._global_connection
database._global_connection = await self._global_connection.__aenter__()
self._databases_map[loop] = database
# forward call
return await self._databases_map[loop].connect()

if not await self.inc_refcount():
assert self.is_connected, "ref_count < 0"
return False
try:
await self.connect_hook()
except BaseException as exc:
await self.decr_refcount()
raise exc
if self._call_hooks:
try:
await self.connect_hook()
except BaseException as exc:
await self.decr_refcount()
raise exc
self._loop = asyncio.get_event_loop()

await self.backend.connect(self.url, **self.options)
logger.info("Connected to database %s", self.url.obscure_password, extra=CONNECT_EXTRA)
self.is_connected = True

assert self._global_connection is None

self._global_connection = Connection(self, self.backend, force_rollback=True)
if self._global_connection is None:
self._global_connection = Connection(self, self.backend, force_rollback=True)
return True

async def disconnect_hook(self) -> None:
"""Refcount protected disconnect hook. Executed after connection, engine cleanup."""

@multiloop_protector(True)
async def disconnect(self, force: bool = False) -> bool:
@multiloop_protector(True, inject_parent=True)
async def disconnect(
self, force: bool = False, *, parent_database: typing.Optional[Database] = None
) -> bool:
"""
Close all connections in the connection pool.
"""
# parent_database is injected and should not be specified manually
if not await self.decr_refcount() or force:
if not self.is_connected:
logger.debug("Already disconnected, skipping disconnection")
Expand All @@ -282,6 +311,14 @@ async def disconnect(self, force: bool = False) -> bool:
logger.warning("Force disconnect, despite refcount not 0")
else:
return False
if parent_database is not None:
loop = asyncio.get_running_loop()
del parent_database._databases_map[loop]
if force:
for sub_database in self._databases_map.values():
await sub_database.disconnect(True)
self._databases_map = {}
assert not self._databases_map, "sub databases still active"

try:
assert self._global_connection is not None
Expand All @@ -297,23 +334,15 @@ async def disconnect(self, force: bool = False) -> bool:
self.is_connected = False
await self.backend.disconnect()
self._loop = None
await self.disconnect_hook()
if self._call_hooks:
await self.disconnect_hook()
return True

async def __aenter__(self) -> "Database":
await self.connect()
# get right database
loop = asyncio.get_running_loop()
database = self
if self._loop is not None and loop != self._loop:
dbs = ACTIVE_DATABASES.get()
if dbs is None:
dbs = {}
else:
dbs = dbs.copy()
database = self.__copy__()
dbs[loop] = database
# it is always a copy required to prevent sideeffects between the contexts
ACTIVE_DATABASES.set(dbs)
await database.connect()
database = self._databases_map.get(loop, self)
return database

async def __aexit__(
Expand All @@ -322,13 +351,7 @@ async def __aexit__(
exc_value: typing.Optional[BaseException] = None,
traceback: typing.Optional[TracebackType] = None,
) -> None:
loop = asyncio.get_running_loop()
database = self
if self._loop is not None and loop != self._loop:
dbs = ACTIVE_DATABASES.get()
if dbs is not None:
database = dbs.pop(loop, database)
await database.disconnect()
await self.disconnect()

@multiloop_protector(False)
async def fetch_all(
Expand Down Expand Up @@ -424,7 +447,7 @@ async def drop_all(self, meta: MetaData, **kwargs: typing.Any) -> None:
async with self.connection() as connection:
await connection.drop_all(meta, **kwargs)

@multiloop_protector(False, wrap_context_manager=True)
@multiloop_protector(False)
def connection(self) -> Connection:
if self.force_rollback:
return typing.cast(Connection, self._global_connection)
Expand Down
2 changes: 2 additions & 0 deletions databasez/overwrites/jdbc.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
import typing
from pathlib import Path

# ensure jpype.dbapi2 is initialized. Prevent race condition.
import jpype.dbapi2 # noqa
from jpype import addClassPath, isJVMStarted, startJVM

from databasez.sqlalchemy import SQLAlchemyDatabase, SQLAlchemyTransaction
Expand Down
Loading

0 comments on commit 0893af7

Please sign in to comment.