diff --git a/databasez/__init__.py b/databasez/__init__.py index 7f8f867..c1c26b1 100644 --- a/databasez/__init__.py +++ b/databasez/__init__.py @@ -1,5 +1,5 @@ from databasez.core import Database, DatabaseURL -__version__ = "0.9.1" +__version__ = "0.9.2" __all__ = ["Database", "DatabaseURL"] diff --git a/databasez/core/connection.py b/databasez/core/connection.py index 89c27d1..d319e25 100644 --- a/databasez/core/connection.py +++ b/databasez/core/connection.py @@ -19,7 +19,9 @@ class Connection: - def __init__(self, database: Database, backend: interfaces.DatabaseBackend) -> None: + def __init__( + self, database: Database, backend: interfaces.DatabaseBackend, force_rollback: bool = False + ) -> None: self._database = database self._backend = backend @@ -32,6 +34,7 @@ def __init__(self, database: Database, backend: interfaces.DatabaseBackend) -> N self._transaction_stack: typing.List[Transaction] = [] self._query_lock = asyncio.Lock() + self._force_rollback = force_rollback self.connection_transaction: typing.Optional[Transaction] = None async def __aenter__(self) -> Connection: @@ -39,12 +42,22 @@ async def __aenter__(self) -> Connection: self._connection_counter += 1 try: if self._connection_counter == 1: + if self._database._global_connection is self: + # on first init double increase, so it isn't terminated too early + self._connection_counter += 1 raw_transaction = await self._connection.acquire() if raw_transaction is not None: self.connection_transaction = self.transaction( - existing_transaction=raw_transaction + existing_transaction=raw_transaction, + force_rollback=self._force_rollback, ) # we don't need to call __aenter__ of connection_transaction, it is not on the stack + elif self._force_rollback: + self.connection_transaction = self.transaction( + force_rollback=self._force_rollback + ) + # make re-entrant, we have already the connection lock + await self.connection_transaction.start(True) except BaseException as e: self._connection_counter -= 1 raise e @@ -62,7 +75,10 @@ async def __aexit__( if self._connection_counter == 0: try: if self.connection_transaction: + # __aexit__ needs the connection_transaction parameter await self.connection_transaction.__aexit__(exc_type, exc_value, traceback) + # untie, for allowing gc + self.connection_transaction = None finally: await self._connection.release() self._database._connection = None diff --git a/databasez/core/database.py b/databasez/core/database.py index c3d756f..8fa9374 100644 --- a/databasez/core/database.py +++ b/databasez/core/database.py @@ -186,7 +186,6 @@ def __init__( # When `force_rollback=True` is used, we use a single global # connection, within a transaction that always rolls back. self._global_connection: typing.Optional[Connection] = None - self._global_transaction: typing.Optional[Transaction] = None self.ref_counter: int = 0 self.ref_lock: asyncio.Lock = asyncio.Lock() @@ -232,6 +231,9 @@ async def decr_refcount(self) -> bool: return True return False + async def connect_hook(self) -> None: + """Refcount protected connect hook""" + async def connect(self) -> None: """ Establish the connection pool. @@ -245,11 +247,12 @@ async def connect(self) -> None: self.is_connected = True assert self._global_connection is None - assert self._global_transaction is None - self._global_connection = Connection(self, self.backend) - self._global_transaction = self._global_connection.transaction(force_rollback=True) - await self._global_transaction.__aenter__() + self._global_connection = Connection(self, self.backend, force_rollback=True) + await self.connect_hook() + + async def disconnect_hook(self) -> None: + """Refcount protected disconnect hook""" async def disconnect(self, force: bool = False) -> None: """ @@ -265,24 +268,21 @@ async def disconnect(self, force: bool = False) -> None: return None assert self._global_connection is not None - assert self._global_transaction is not None - - await self._global_transaction.__aexit__() - assert ( - self._global_connection._connection_counter == 0 - ), f"global connection active: {self._global_connection._connection_counter}" - - self._global_transaction = None - self._global_connection = None - self._connection = None - - await self.backend.disconnect() - logger.info( - "Disconnected from database %s", - self.url.obscure_password, - extra=DISCONNECT_EXTRA, - ) - self.is_connected = False + try: + await self.disconnect_hook() + finally: + await self._global_connection.__aexit__() + self._global_connection = None + self._connection = None + try: + await self.backend.disconnect() + logger.info( + "Disconnected from database %s", + self.url.obscure_password, + extra=DISCONNECT_EXTRA, + ) + finally: + self.is_connected = False async def __aenter__(self) -> "Database": await self.connect() diff --git a/databasez/core/transaction.py b/databasez/core/transaction.py index 8874650..2c07a70 100644 --- a/databasez/core/transaction.py +++ b/databasez/core/transaction.py @@ -3,6 +3,7 @@ import functools import typing import weakref +from contextlib import AsyncExitStack from contextvars import ContextVar from types import TracebackType @@ -111,13 +112,19 @@ async def wrapper(*args: typing.Any, **kwargs: typing.Any) -> typing.Any: return wrapper # type: ignore - async def start(self) -> Transaction: + async def start(self, without_transaction_lock: bool = False) -> Transaction: connection = self.connection - async with connection._transaction_lock: + + async with AsyncExitStack() as cm: + if not without_transaction_lock: + await cm.enter_async_context(connection._transaction_lock) is_root = not connection._transaction_stack _transaction = connection._connection.transaction(self._existing_transaction) _transaction.owner = self - await connection.__aenter__() + # will be terminated with connection, don't bump + # fixes also a locking issue + if connection.connection_transaction is not self: + await connection.__aenter__() if self._existing_transaction is None: await _transaction.start(is_root=is_root, **self._extra_options) self._transaction = _transaction @@ -128,22 +135,30 @@ async def commit(self) -> None: connection = self.connection async with connection._transaction_lock: _transaction = self._transaction + # some transactions are tied to connections and are not on the transaction stack if _transaction is not None: # delete transaction from ACTIVE_TRANSACTIONS self._transaction = None assert connection._transaction_stack[-1] is self connection._transaction_stack.pop() await _transaction.commit() + # if a connection_transaction, the connetion cleans it up in __aexit__ + # prevent loop + if connection.connection_transaction is not self: await connection.__aexit__() async def rollback(self) -> None: connection = self.connection async with connection._transaction_lock: _transaction = self._transaction + # some transactions are tied to connections and are not on the transaction stack if _transaction is not None: # delete transaction from ACTIVE_TRANSACTIONS self._transaction = None assert connection._transaction_stack[-1] is self connection._transaction_stack.pop() await _transaction.rollback() + # if a connection_transaction, the connetion cleans it up in __aexit__ + # prevent loop + if connection.connection_transaction is not self: await connection.__aexit__() diff --git a/databasez/testclient.py b/databasez/testclient.py index 0ade7dc..d126109 100644 --- a/databasez/testclient.py +++ b/databasez/testclient.py @@ -1,4 +1,3 @@ -import asyncio import os import typing from typing import Any @@ -50,7 +49,6 @@ def __init__( self.test_db_url = str(getattr(url, "test_db_url", test_database_url)) self.use_existing = getattr(url, "use_existing", use_existing) self.drop = getattr(url, "drop", drop_database) - asyncio.get_event_loop().run_until_complete(self.setup()) super().__init__(url, force_rollback=force_rollback) # fix url if str(self.url) != self.test_db_url: @@ -64,22 +62,28 @@ def __init__( self.use_existing = use_existing self.drop = drop_database - asyncio.get_event_loop().run_until_complete(self.setup()) - super().__init__(test_database_url, force_rollback=force_rollback, **options) - async def setup(self) -> None: + async def connect_hook(self) -> None: """ Makes sure the database is created if does not exist or use existing if needed. """ if not self.use_existing: - if await self.database_exists(self.test_db_url): - await self.drop_database(self.test_db_url) - await self.create_database(self.test_db_url) + try: + if await self.database_exists(self.test_db_url): + await self.drop_database(self.test_db_url) + else: + await self.create_database(self.test_db_url) + except (ProgrammingError, OperationalError, TypeError): + self.drop = False else: if not await self.database_exists(self.test_db_url): - await self.create_database(self.test_db_url) + try: + await self.create_database(self.test_db_url) + except (ProgrammingError, OperationalError): + self.drop = False + await super().connect_hook() async def is_database_exist(self) -> Any: """ @@ -153,43 +157,41 @@ async def create_database( db_client = Database(url, isolation_level="AUTOCOMMIT") else: db_client = Database(url) - await db_client.connect() - - if dialect_name == "postgresql": - if not template: - template = "template1" + async with db_client: + if dialect_name == "postgresql": + if not template: + template = "template1" - async with db_client.engine.begin() as conn: # type: ignore - text = "CREATE DATABASE {} ENCODING '{}' TEMPLATE {}".format( - quote(conn, database), encoding, quote(conn, template) - ) - await conn.execute(sa.text(text)) - - elif dialect_name == "mysql": - async with db_client.engine.begin() as conn: # type: ignore - text = "CREATE DATABASE {} CHARACTER SET = '{}'".format( - quote(conn, database), encoding - ) - await conn.execute(sa.text(text)) + async with db_client.engine.begin() as conn: # type: ignore + text = "CREATE DATABASE {} ENCODING '{}' TEMPLATE {}".format( + quote(conn, database), encoding, quote(conn, template) + ) + await conn.execute(sa.text(text)) - elif dialect_name == "sqlite" and database != ":memory:": - if database: + elif dialect_name == "mysql": async with db_client.engine.begin() as conn: # type: ignore - await conn.execute(sa.text("CREATE TABLE DB(id int)")) - await conn.execute(sa.text("DROP TABLE DB")) + text = "CREATE DATABASE {} CHARACTER SET = '{}'".format( + quote(conn, database), encoding + ) + await conn.execute(sa.text(text)) - else: - async with db_client.engine.begin() as conn: # type: ignore - text = f"CREATE DATABASE {quote(conn, database)}" - await conn.execute(sa.text(text)) + elif dialect_name == "sqlite" and database != ":memory:": + if database: + async with db_client.engine.begin() as conn: # type: ignore + await conn.execute(sa.text("CREATE TABLE DB(id int)")) + await conn.execute(sa.text("DROP TABLE DB")) - await db_client.disconnect() + else: + async with db_client.engine.begin() as conn: # type: ignore + text = f"CREATE DATABASE {quote(conn, database)}" + await conn.execute(sa.text(text)) async def drop_database(self, url: typing.Union[str, "sa.URL", DatabaseURL]) -> Any: url = url if isinstance(url, DatabaseURL) else DatabaseURL(url) database = url.database - dialect_name = url.sqla_url.get_dialect(True).name - dialect_driver = url.sqla_url.get_dialect(True).driver + dialect = url.sqla_url.get_dialect(True) + dialect_name = dialect.name + dialect_driver = dialect.driver if dialect_name == "postgresql": url = url.replace(database="postgres") @@ -207,35 +209,37 @@ async def drop_database(self, url: typing.Union[str, "sa.URL", DatabaseURL]) -> db_client = Database(url, isolation_level="AUTOCOMMIT") else: db_client = Database(url) - await db_client.connect() - - if dialect_name == "sqlite" and database != ":memory:": - if database: - os.remove(database) - elif dialect_name == "postgresql": - async with db_client.engine.begin() as conn: # type: ignore - # Disconnect all users from the database we are dropping. - version = conn.dialect.server_version_info - pid_column = "pid" if (version >= (9, 2)) else "procpid" # type: ignore - text = """ - SELECT pg_terminate_backend(pg_stat_activity.{pid_column}) - FROM pg_stat_activity - WHERE pg_stat_activity.datname = '{database}' - AND {pid_column} <> pg_backend_pid(); - """.format(pid_column=pid_column, database=database) - await conn.execute(sa.text(text)) - - # Drop the database. - text = f"DROP DATABASE {quote(conn, database)}" - await conn.execute(sa.text(text)) - else: - async with db_client.engine.begin() as conn: # type: ignore - text = f"DROP DATABASE {quote(conn, database)}" - await conn.execute(sa.text(text)) - - await db_client.disconnect() + async with db_client: + if dialect_name == "sqlite" and database != ":memory:": + if database: + os.remove(database) + elif dialect_name == "postgresql": + async with db_client.connection() as conn1: + async with conn1.async_connection.begin() as conn: + # no connection + if dialect.server_version_info is None: + return + # Disconnect all users from the database we are dropping. + version = dialect.server_version_info + pid_column = "pid" if (version >= (9, 2)) else "procpid" + text = """ + SELECT pg_terminate_backend(pg_stat_activity.{pid_column}) + FROM pg_stat_activity + WHERE pg_stat_activity.datname = '{database}' + AND {pid_column} <> pg_backend_pid(); + """.format(pid_column=pid_column, database=database) + await conn.execute(sa.text(text)) + + # Drop the database. + text = f"DROP DATABASE {quote(conn, database)}" + await conn.execute(sa.text(text)) + else: + async with db_client.connection() as conn1: + async with conn1.async_connection.begin() as conn: + text = f"DROP DATABASE {quote(conn, database)}" + await conn.execute(sa.text(text)) - async def disconnect(self, force: bool = False) -> None: + async def disconnect_hook(self) -> None: if self.drop: await self.drop_database(self.test_db_url) - await super().disconnect(force) + await super().disconnect_hook() diff --git a/docs/release-notes.md b/docs/release-notes.md index 35c03e8..89f23d9 100644 --- a/docs/release-notes.md +++ b/docs/release-notes.md @@ -1,5 +1,16 @@ # Release Notes +## 0.9.2 + +### Added + +- Expose customization hooks for disconnects, connects. + +### Fixed + +- Testclient has issues with missing permissions. +- Lazy global connection. + ## 0.9.1 ### Added diff --git a/pyproject.toml b/pyproject.toml index 3351aa1..fa6b3c0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -141,10 +141,10 @@ ignore_errors = true addopts = ["--strict-config", "--strict-markers", "--pdbcls=IPython.terminal.debugger:Pdb"] xfail_strict = true junit_family = "xunit2" -setup_timeout = 2 -execution_timeout = 20 +setup_timeout = 5 +execution_timeout = 10 teardown_timeout = 5 -reruns = 3 +reruns = 1 reruns_delay = 1 [tool.hatch.build.targets.sdist] diff --git a/tests/shared_db.py b/tests/shared_db.py index 2ca81a6..619b40c 100644 --- a/tests/shared_db.py +++ b/tests/shared_db.py @@ -70,7 +70,9 @@ def process_result_value(self, value, dialect): ) -async def database_client(url: typing.Union[dict, str]) -> DatabaseTestClient: +async def database_client(url: typing.Union[dict, str], meta=None) -> DatabaseTestClient: + if meta is None: + meta = metadata if isinstance(url, str): is_sqlite = url.startswith("sqlite") database = DatabaseTestClient( @@ -79,10 +81,12 @@ async def database_client(url: typing.Union[dict, str]) -> DatabaseTestClient: else: database = Database(config=url) await database.connect() - await database.create_all(metadata) + await database.create_all(meta) return database -async def stop_database_client(database: Database): - await database.drop_all(metadata) +async def stop_database_client(database: Database, meta=None): + if meta is None: + meta = metadata + await database.drop_all(meta) await database.disconnect() diff --git a/tests/test_database_testclient.py b/tests/test_database_testclient.py new file mode 100644 index 0000000..3540012 --- /dev/null +++ b/tests/test_database_testclient.py @@ -0,0 +1,34 @@ +import os + +import pyodbc +import pytest + +from databasez import Database +from databasez.testclient import DatabaseTestClient + +assert "TEST_DATABASE_URLS" in os.environ, "TEST_DATABASE_URLS is not set." + +DATABASE_URLS = [url.strip() for url in os.environ["TEST_DATABASE_URLS"].split(",")] + +if not any((x.endswith(" for SQL Server") for x in pyodbc.drivers())): + DATABASE_URLS = list(filter(lambda x: "mssql" not in x, DATABASE_URLS)) + + +@pytest.mark.parametrize("database_url", DATABASE_URLS) +@pytest.mark.asyncio +async def test_non_existing_normal(database_url): + test_db = DatabaseTestClient(database_url) + async with Database(test_db) as database: + assert database.is_connected + async with Database(test_db) as database: + assert database.is_connected + + +@pytest.mark.parametrize("database_url", DATABASE_URLS) +@pytest.mark.asyncio +async def test_non_existing_client(database_url): + test_db = DatabaseTestClient(database_url) + async with DatabaseTestClient(test_db) as database: + assert database.is_connected + async with DatabaseTestClient(test_db, drop_database=True) as database: + assert database.is_connected diff --git a/tests/test_databases.py b/tests/test_databases.py index 8ffd523..86ac5ed 100644 --- a/tests/test_databases.py +++ b/tests/test_databases.py @@ -515,7 +515,6 @@ async def test_connect_and_disconnect(database_mixed_url): assert database.ref_counter == 0 assert not database.is_connected assert not database._global_connection - assert not database._global_transaction @pytest.mark.asyncio diff --git a/tests/test_integration.py b/tests/test_integration.py index 04108ab..4fd10c0 100644 --- a/tests/test_integration.py +++ b/tests/test_integration.py @@ -1,3 +1,4 @@ +import asyncio import contextlib import pytest @@ -11,7 +12,8 @@ from starlette.routing import Route from starlette.testclient import TestClient -from databasez import Database, DatabaseURL +from databasez import Database +from tests.shared_db import database_client, stop_database_client from tests.test_databases import DATABASE_URLS metadata = sqlalchemy.MetaData() @@ -25,37 +27,13 @@ ) -@pytest.fixture(autouse=True, scope="module") -def create_test_database(): +@pytest.fixture(params=DATABASE_URLS) +def database_url(request): # Create test databases - for url in DATABASE_URLS: - database_url = str(DatabaseURL(url)) - database_url = ( - database_url.replace("sqlite+aiosqlite:", "sqlite:") - .replace("mssql+aioodbc:", "mssql+pyodbc:") - .replace("postgresql+asyncpg:", "postgresql+psycopg:") - .replace("mysql+asyncmy:", "mysql+pymysql:") - .replace("mysql+aiomysql:", "mysql+pymysql:") - ) - - engine = sqlalchemy.create_engine(database_url) - metadata.create_all(engine) - - # Run the test suite - yield - - for url in DATABASE_URLS: - database_url = str(DatabaseURL(url)) - database_url = ( - database_url.replace("sqlite+aiosqlite:", "sqlite:") - .replace("mssql+aioodbc:", "mssql+pyodbc:") - .replace("postgresql+asyncpg:", "postgresql+psycopg:") - .replace("mysql+asyncmy:", "mysql+pymysql:") - .replace("mysql+aiomysql:", "mysql+pymysql:") - ) - - engine = sqlalchemy.create_engine(database_url) - metadata.drop_all(engine) + loop = asyncio.new_event_loop() + database = loop.run_until_complete(database_client(request.param, metadata)) + yield str(database.url) + loop.run_until_complete(stop_database_client(database, metadata)) def get_app(database_url): @@ -119,7 +97,6 @@ async def shutdown(): return app -@pytest.mark.parametrize("database_url", DATABASE_URLS) def test_integration(database_url): app = get_app(database_url) @@ -139,7 +116,6 @@ def test_integration(database_url): assert response.json() == [] -@pytest.mark.parametrize("database_url", DATABASE_URLS) def test_integration_esmerald(database_url): app = get_esmerald_app(database_url) diff --git a/tests/test_transactions.py b/tests/test_transactions.py index 300cf36..5455716 100644 --- a/tests/test_transactions.py +++ b/tests/test_transactions.py @@ -199,14 +199,12 @@ async def test_transaction_context_cleanup_garbagecollector(database_url): async with Database(database_url) as database: # Should be tracking the transaction open_transactions = ACTIVE_TRANSACTIONS.get() - assert isinstance(open_transactions, MutableMapping) - # the global one is always created - assert len(open_transactions) == 1 + assert open_transactions is None transaction = database.transaction() await transaction.start() # is replaced after start() call open_transactions = ACTIVE_TRANSACTIONS.get() - assert len(open_transactions) == 2 + assert len(open_transactions) == 1 assert open_transactions.get(transaction) is transaction._transaction @@ -216,7 +214,7 @@ async def test_transaction_context_cleanup_garbagecollector(database_url): # A strong reference to the transaction is kept alive by the connection's # ._transaction_stack, so it is still be tracked at this point. - assert len(open_transactions) == 2 + assert len(open_transactions) == 1 # If that were magically cleared, the transaction would be cleaned up, # but as it stands this always causes a hang during teardown at @@ -227,7 +225,7 @@ async def test_transaction_context_cleanup_garbagecollector(database_url): del transaction # Now with the transaction rolled-back, it should be cleaned up. - assert len(open_transactions) == 1 + assert len(open_transactions) == 0 @pytest.mark.asyncio