Skip to content

Commit

Permalink
Fix db drops during tests (#36)
Browse files Browse the repository at this point in the history
Changes:

- add connect, disconnect hook
- fix testclient
- lazy init global connections
- replace global transaction by a connection_transaction of the global_connection
  • Loading branch information
devkral authored Aug 9, 2024
1 parent c1e69f0 commit 016fa29
Show file tree
Hide file tree
Showing 12 changed files with 199 additions and 142 deletions.
2 changes: 1 addition & 1 deletion databasez/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from databasez.core import Database, DatabaseURL

__version__ = "0.9.1"
__version__ = "0.9.2"

__all__ = ["Database", "DatabaseURL"]
20 changes: 18 additions & 2 deletions databasez/core/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@


class Connection:
def __init__(self, database: Database, backend: interfaces.DatabaseBackend) -> None:
def __init__(
self, database: Database, backend: interfaces.DatabaseBackend, force_rollback: bool = False
) -> None:
self._database = database
self._backend = backend

Expand All @@ -32,19 +34,30 @@ def __init__(self, database: Database, backend: interfaces.DatabaseBackend) -> N
self._transaction_stack: typing.List[Transaction] = []

self._query_lock = asyncio.Lock()
self._force_rollback = force_rollback
self.connection_transaction: typing.Optional[Transaction] = None

async def __aenter__(self) -> Connection:
async with self._connection_lock:
self._connection_counter += 1
try:
if self._connection_counter == 1:
if self._database._global_connection is self:
# on first init double increase, so it isn't terminated too early
self._connection_counter += 1
raw_transaction = await self._connection.acquire()
if raw_transaction is not None:
self.connection_transaction = self.transaction(
existing_transaction=raw_transaction
existing_transaction=raw_transaction,
force_rollback=self._force_rollback,
)
# we don't need to call __aenter__ of connection_transaction, it is not on the stack
elif self._force_rollback:
self.connection_transaction = self.transaction(
force_rollback=self._force_rollback
)
# make re-entrant, we have already the connection lock
await self.connection_transaction.start(True)
except BaseException as e:
self._connection_counter -= 1
raise e
Expand All @@ -62,7 +75,10 @@ async def __aexit__(
if self._connection_counter == 0:
try:
if self.connection_transaction:
# __aexit__ needs the connection_transaction parameter
await self.connection_transaction.__aexit__(exc_type, exc_value, traceback)
# untie, for allowing gc
self.connection_transaction = None
finally:
await self._connection.release()
self._database._connection = None
Expand Down
46 changes: 23 additions & 23 deletions databasez/core/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,6 @@ def __init__(
# When `force_rollback=True` is used, we use a single global
# connection, within a transaction that always rolls back.
self._global_connection: typing.Optional[Connection] = None
self._global_transaction: typing.Optional[Transaction] = None

self.ref_counter: int = 0
self.ref_lock: asyncio.Lock = asyncio.Lock()
Expand Down Expand Up @@ -232,6 +231,9 @@ async def decr_refcount(self) -> bool:
return True
return False

async def connect_hook(self) -> None:
"""Refcount protected connect hook"""

async def connect(self) -> None:
"""
Establish the connection pool.
Expand All @@ -245,11 +247,12 @@ async def connect(self) -> None:
self.is_connected = True

assert self._global_connection is None
assert self._global_transaction is None

self._global_connection = Connection(self, self.backend)
self._global_transaction = self._global_connection.transaction(force_rollback=True)
await self._global_transaction.__aenter__()
self._global_connection = Connection(self, self.backend, force_rollback=True)
await self.connect_hook()

async def disconnect_hook(self) -> None:
"""Refcount protected disconnect hook"""

async def disconnect(self, force: bool = False) -> None:
"""
Expand All @@ -265,24 +268,21 @@ async def disconnect(self, force: bool = False) -> None:
return None

assert self._global_connection is not None
assert self._global_transaction is not None

await self._global_transaction.__aexit__()
assert (
self._global_connection._connection_counter == 0
), f"global connection active: {self._global_connection._connection_counter}"

self._global_transaction = None
self._global_connection = None
self._connection = None

await self.backend.disconnect()
logger.info(
"Disconnected from database %s",
self.url.obscure_password,
extra=DISCONNECT_EXTRA,
)
self.is_connected = False
try:
await self.disconnect_hook()
finally:
await self._global_connection.__aexit__()
self._global_connection = None
self._connection = None
try:
await self.backend.disconnect()
logger.info(
"Disconnected from database %s",
self.url.obscure_password,
extra=DISCONNECT_EXTRA,
)
finally:
self.is_connected = False

async def __aenter__(self) -> "Database":
await self.connect()
Expand Down
21 changes: 18 additions & 3 deletions databasez/core/transaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import functools
import typing
import weakref
from contextlib import AsyncExitStack
from contextvars import ContextVar
from types import TracebackType

Expand Down Expand Up @@ -111,13 +112,19 @@ async def wrapper(*args: typing.Any, **kwargs: typing.Any) -> typing.Any:

return wrapper # type: ignore

async def start(self) -> Transaction:
async def start(self, without_transaction_lock: bool = False) -> Transaction:
connection = self.connection
async with connection._transaction_lock:

async with AsyncExitStack() as cm:
if not without_transaction_lock:
await cm.enter_async_context(connection._transaction_lock)
is_root = not connection._transaction_stack
_transaction = connection._connection.transaction(self._existing_transaction)
_transaction.owner = self
await connection.__aenter__()
# will be terminated with connection, don't bump
# fixes also a locking issue
if connection.connection_transaction is not self:
await connection.__aenter__()
if self._existing_transaction is None:
await _transaction.start(is_root=is_root, **self._extra_options)
self._transaction = _transaction
Expand All @@ -128,22 +135,30 @@ async def commit(self) -> None:
connection = self.connection
async with connection._transaction_lock:
_transaction = self._transaction
# some transactions are tied to connections and are not on the transaction stack
if _transaction is not None:
# delete transaction from ACTIVE_TRANSACTIONS
self._transaction = None
assert connection._transaction_stack[-1] is self
connection._transaction_stack.pop()
await _transaction.commit()
# if a connection_transaction, the connetion cleans it up in __aexit__
# prevent loop
if connection.connection_transaction is not self:
await connection.__aexit__()

async def rollback(self) -> None:
connection = self.connection
async with connection._transaction_lock:
_transaction = self._transaction
# some transactions are tied to connections and are not on the transaction stack
if _transaction is not None:
# delete transaction from ACTIVE_TRANSACTIONS
self._transaction = None
assert connection._transaction_stack[-1] is self
connection._transaction_stack.pop()
await _transaction.rollback()
# if a connection_transaction, the connetion cleans it up in __aexit__
# prevent loop
if connection.connection_transaction is not self:
await connection.__aexit__()
136 changes: 70 additions & 66 deletions databasez/testclient.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import asyncio
import os
import typing
from typing import Any
Expand Down Expand Up @@ -50,7 +49,6 @@ def __init__(
self.test_db_url = str(getattr(url, "test_db_url", test_database_url))
self.use_existing = getattr(url, "use_existing", use_existing)
self.drop = getattr(url, "drop", drop_database)
asyncio.get_event_loop().run_until_complete(self.setup())
super().__init__(url, force_rollback=force_rollback)
# fix url
if str(self.url) != self.test_db_url:
Expand All @@ -64,22 +62,28 @@ def __init__(
self.use_existing = use_existing
self.drop = drop_database

asyncio.get_event_loop().run_until_complete(self.setup())

super().__init__(test_database_url, force_rollback=force_rollback, **options)

async def setup(self) -> None:
async def connect_hook(self) -> None:
"""
Makes sure the database is created if does not exist or use existing
if needed.
"""
if not self.use_existing:
if await self.database_exists(self.test_db_url):
await self.drop_database(self.test_db_url)
await self.create_database(self.test_db_url)
try:
if await self.database_exists(self.test_db_url):
await self.drop_database(self.test_db_url)
else:
await self.create_database(self.test_db_url)
except (ProgrammingError, OperationalError, TypeError):
self.drop = False
else:
if not await self.database_exists(self.test_db_url):
await self.create_database(self.test_db_url)
try:
await self.create_database(self.test_db_url)
except (ProgrammingError, OperationalError):
self.drop = False
await super().connect_hook()

async def is_database_exist(self) -> Any:
"""
Expand Down Expand Up @@ -153,43 +157,41 @@ async def create_database(
db_client = Database(url, isolation_level="AUTOCOMMIT")
else:
db_client = Database(url)
await db_client.connect()

if dialect_name == "postgresql":
if not template:
template = "template1"
async with db_client:
if dialect_name == "postgresql":
if not template:
template = "template1"

async with db_client.engine.begin() as conn: # type: ignore
text = "CREATE DATABASE {} ENCODING '{}' TEMPLATE {}".format(
quote(conn, database), encoding, quote(conn, template)
)
await conn.execute(sa.text(text))

elif dialect_name == "mysql":
async with db_client.engine.begin() as conn: # type: ignore
text = "CREATE DATABASE {} CHARACTER SET = '{}'".format(
quote(conn, database), encoding
)
await conn.execute(sa.text(text))
async with db_client.engine.begin() as conn: # type: ignore
text = "CREATE DATABASE {} ENCODING '{}' TEMPLATE {}".format(
quote(conn, database), encoding, quote(conn, template)
)
await conn.execute(sa.text(text))

elif dialect_name == "sqlite" and database != ":memory:":
if database:
elif dialect_name == "mysql":
async with db_client.engine.begin() as conn: # type: ignore
await conn.execute(sa.text("CREATE TABLE DB(id int)"))
await conn.execute(sa.text("DROP TABLE DB"))
text = "CREATE DATABASE {} CHARACTER SET = '{}'".format(
quote(conn, database), encoding
)
await conn.execute(sa.text(text))

else:
async with db_client.engine.begin() as conn: # type: ignore
text = f"CREATE DATABASE {quote(conn, database)}"
await conn.execute(sa.text(text))
elif dialect_name == "sqlite" and database != ":memory:":
if database:
async with db_client.engine.begin() as conn: # type: ignore
await conn.execute(sa.text("CREATE TABLE DB(id int)"))
await conn.execute(sa.text("DROP TABLE DB"))

await db_client.disconnect()
else:
async with db_client.engine.begin() as conn: # type: ignore
text = f"CREATE DATABASE {quote(conn, database)}"
await conn.execute(sa.text(text))

async def drop_database(self, url: typing.Union[str, "sa.URL", DatabaseURL]) -> Any:
url = url if isinstance(url, DatabaseURL) else DatabaseURL(url)
database = url.database
dialect_name = url.sqla_url.get_dialect(True).name
dialect_driver = url.sqla_url.get_dialect(True).driver
dialect = url.sqla_url.get_dialect(True)
dialect_name = dialect.name
dialect_driver = dialect.driver

if dialect_name == "postgresql":
url = url.replace(database="postgres")
Expand All @@ -207,35 +209,37 @@ async def drop_database(self, url: typing.Union[str, "sa.URL", DatabaseURL]) ->
db_client = Database(url, isolation_level="AUTOCOMMIT")
else:
db_client = Database(url)
await db_client.connect()

if dialect_name == "sqlite" and database != ":memory:":
if database:
os.remove(database)
elif dialect_name == "postgresql":
async with db_client.engine.begin() as conn: # type: ignore
# Disconnect all users from the database we are dropping.
version = conn.dialect.server_version_info
pid_column = "pid" if (version >= (9, 2)) else "procpid" # type: ignore
text = """
SELECT pg_terminate_backend(pg_stat_activity.{pid_column})
FROM pg_stat_activity
WHERE pg_stat_activity.datname = '{database}'
AND {pid_column} <> pg_backend_pid();
""".format(pid_column=pid_column, database=database)
await conn.execute(sa.text(text))

# Drop the database.
text = f"DROP DATABASE {quote(conn, database)}"
await conn.execute(sa.text(text))
else:
async with db_client.engine.begin() as conn: # type: ignore
text = f"DROP DATABASE {quote(conn, database)}"
await conn.execute(sa.text(text))

await db_client.disconnect()
async with db_client:
if dialect_name == "sqlite" and database != ":memory:":
if database:
os.remove(database)
elif dialect_name == "postgresql":
async with db_client.connection() as conn1:
async with conn1.async_connection.begin() as conn:
# no connection
if dialect.server_version_info is None:
return
# Disconnect all users from the database we are dropping.
version = dialect.server_version_info
pid_column = "pid" if (version >= (9, 2)) else "procpid"
text = """
SELECT pg_terminate_backend(pg_stat_activity.{pid_column})
FROM pg_stat_activity
WHERE pg_stat_activity.datname = '{database}'
AND {pid_column} <> pg_backend_pid();
""".format(pid_column=pid_column, database=database)
await conn.execute(sa.text(text))

# Drop the database.
text = f"DROP DATABASE {quote(conn, database)}"
await conn.execute(sa.text(text))
else:
async with db_client.connection() as conn1:
async with conn1.async_connection.begin() as conn:
text = f"DROP DATABASE {quote(conn, database)}"
await conn.execute(sa.text(text))

async def disconnect(self, force: bool = False) -> None:
async def disconnect_hook(self) -> None:
if self.drop:
await self.drop_database(self.test_db_url)
await super().disconnect(force)
await super().disconnect_hook()
Loading

0 comments on commit 016fa29

Please sign in to comment.