Skip to content

Commit

Permalink
Return deferred db init (#278) (#280)
Browse files Browse the repository at this point in the history
* Return deferred db init (#278)

* fix: return deferred db init
---------

Co-authored-by: Nikita <[email protected]>
  • Loading branch information
kalombos and F1int0m authored Aug 5, 2024
1 parent a41ebef commit d20b6c4
Show file tree
Hide file tree
Showing 5 changed files with 113 additions and 73 deletions.
164 changes: 94 additions & 70 deletions peewee_async/databases.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import contextlib
import logging
import warnings
from typing import Type, Optional, Any, AsyncIterator, Iterator
from typing import Dict, Type, Optional, Any, AsyncIterator, Iterator

import peewee
from playhouse import postgres_ext as ext
Expand All @@ -13,15 +13,25 @@
from .utils import psycopg2, aiopg, pymysql, aiomysql, __log__


class AioDatabase:
class AioDatabase(peewee.Database):
_allow_sync = True # whether sync queries are allowed

pool_backend_cls: Type[PoolBackend]
pool_backend: PoolBackend

def __init__(self, database: Optional[str], **kwargs: Any) -> None:
super().__init__(database, **kwargs)
if not database:
raise Exception("Deferred initialization is not supported")
@property
def connect_params_async(self) -> Dict[str, Any]:
...

def init(self, database: Optional[str], **kwargs: Any) -> None:
connection_timeout = kwargs.pop('connection_timeout', None)
if connection_timeout is not None:
warnings.warn(
"`connection_timeout` is deprecated, use `connect_timeout` instead.",
DeprecationWarning
)
kwargs['connect_timeout'] = connection_timeout
super().init(database, **kwargs)
self.pool_backend = self.pool_backend_cls(
database=self.database,
**self.connect_params_async
Expand All @@ -30,6 +40,8 @@ def __init__(self, database: Optional[str], **kwargs: Any) -> None:
async def aio_connect(self) -> None:
"""Creates a connection pool
"""
if self.deferred:
raise Exception('Error, database must be initialized before creating a connection pool')
await self.pool_backend.connect()

@property
Expand All @@ -41,6 +53,9 @@ def is_connected(self) -> bool:
async def aio_close(self) -> None:
"""Terminate pool backend. The pool is closed until you run aio_connect manually
"""
if self.deferred:
raise Exception('Error, database must be initialized before creating a connection pool')

await self.pool_backend.terminate()

@contextlib.asynccontextmanager
Expand Down Expand Up @@ -93,12 +108,17 @@ def execute_sql(self, *args, **kwargs):
"Error, sync query is not allowed! Call the `.set_allow_sync()` "
"or use the `.allow_sync()` context manager.")
if self._allow_sync in (logging.ERROR, logging.WARNING):
logging.log(self._allow_sync,
"Error, sync query is not allowed: %s %s" %
(str(args), str(kwargs)))
logging.log(
self._allow_sync,
"Error, sync query is not allowed: %s %s" %
(str(args), str(kwargs))
)
return super().execute_sql(*args, **kwargs)

def aio_connection(self) -> ConnectionContextManager:
if self.deferred:
raise Exception('Error, database must be initialized before creating a connection pool')

return ConnectionContextManager(self.pool_backend)

async def aio_execute_sql(self, sql: str, params=None, fetch_results=None):
Expand Down Expand Up @@ -183,38 +203,28 @@ def transaction_async(self) -> Any:
return self.aio_atomic()


class AioPostgresqlMixin(AioDatabase):
"""Mixin for `peewee.PostgresqlDatabase` providing extra methods
class AioPostgresqlMixin(AioDatabase, peewee.PostgresqlDatabase):
"""Extension for `peewee.PostgresqlDatabase` providing extra methods
for managing async connection.
"""

_enable_json: bool
_enable_hstore: bool

pool_backend_cls = PostgresqlPoolBackend

if psycopg2:
Error = psycopg2.Error

def init_async(self, enable_json: bool = False, enable_hstore: bool =False) -> None:
def init_async(self, enable_json: bool = False, enable_hstore: bool = False) -> None:
if not aiopg:
raise Exception("Error, aiopg is not installed!")
self._enable_json = enable_json
self._enable_hstore = enable_hstore

@property
def connect_params_async(self):
"""Connection parameters for `aiopg.Connection`
"""
kwargs = self.connect_params.copy()
kwargs.update({
'minsize': self.min_connections,
'maxsize': self.max_connections,
'enable_json': self._enable_json,
'enable_hstore': self._enable_hstore,
})
return kwargs


class PooledPostgresqlDatabase(AioPostgresqlMixin, peewee.PostgresqlDatabase):
"""PosgreSQL database driver providing **single drop-in sync**
"""PostgreSQL database driver providing **single drop-in sync**
connection and **async connections pool** interface.
:param max_connections: connections pool size
Expand All @@ -226,22 +236,37 @@ class PooledPostgresqlDatabase(AioPostgresqlMixin, peewee.PostgresqlDatabase):
See also:
http://peewee.readthedocs.io/en/latest/peewee/api.html#PostgresqlDatabase
"""
min_connections: int = 1
max_connections: int = 20

def init(self, database: Optional[str], **kwargs: Any) -> None:
self.min_connections = kwargs.pop('min_connections', 1)
self.max_connections = kwargs.pop('max_connections', 20)
connection_timeout = kwargs.pop('connection_timeout', None)
if connection_timeout is not None:
warnings.warn(
"`connection_timeout` is deprecated, use `connect_timeout` instead.",
DeprecationWarning
)
kwargs['connect_timeout'] = connection_timeout
super().init(database, **kwargs)
if min_connections := kwargs.pop('min_connections', False):
self.min_connections = min_connections

if max_connections := kwargs.pop('max_connections', False):
self.max_connections = max_connections

self.init_async()
super().init(database, **kwargs)

@property
def connect_params_async(self):
"""Connection parameters for `aiopg.Connection`
"""
kwargs = self.connect_params.copy()
kwargs.update(
{
'minsize': self.min_connections,
'maxsize': self.max_connections,
'enable_json': self._enable_json,
'enable_hstore': self._enable_hstore,
}
)
return kwargs


class PooledPostgresqlExtDatabase(
AioPostgresqlMixin,
PooledPostgresqlDatabase,
ext.PostgresqlExtDatabase
):
"""PosgreSQL database extended driver providing **single drop-in sync**
Expand All @@ -250,8 +275,6 @@ class PooledPostgresqlExtDatabase(
JSON fields support is always enabled, HStore supports is enabled by
default, but can be disabled with ``register_hstore=False`` argument.
:param max_connections: connections pool size
Example::
database = PooledPostgresqlExtDatabase('test', register_hstore=False,
Expand All @@ -262,20 +285,11 @@ class PooledPostgresqlExtDatabase(
"""

def init(self, database: Optional[str], **kwargs: Any) -> None:
self.min_connections = kwargs.pop('min_connections', 1)
self.max_connections = kwargs.pop('max_connections', 20)
connection_timeout = kwargs.pop('connection_timeout', None)
if connection_timeout is not None:
warnings.warn(
"`connection_timeout` is deprecated, use `connect_timeout` instead.",
DeprecationWarning
)
kwargs['connect_timeout'] = connection_timeout
super().init(database, **kwargs)
self.init_async(
enable_json=True,
enable_hstore=self._register_hstore
)
super().init(database, **kwargs)


class PooledMySQLDatabase(AioDatabase, peewee.MySQLDatabase):
Expand All @@ -291,6 +305,9 @@ class PooledMySQLDatabase(AioDatabase, peewee.MySQLDatabase):
See also:
http://peewee.readthedocs.io/en/latest/peewee/api.html#MySQLDatabase
"""
min_connections: int = 1
max_connections: int = 20

pool_backend_cls = MysqlPoolBackend

if pymysql:
Expand All @@ -299,27 +316,34 @@ class PooledMySQLDatabase(AioDatabase, peewee.MySQLDatabase):
def init(self, database: Optional[str], **kwargs: Any) -> None:
if not aiomysql:
raise Exception("Error, aiomysql is not installed!")
self.min_connections = kwargs.pop('min_connections', 1)
self.max_connections = kwargs.pop('max_connections', 20)

if min_connections := kwargs.pop('min_connections', False):
self.min_connections = min_connections

if max_connections := kwargs.pop('max_connections', False):
self.max_connections = max_connections

super().init(database, **kwargs)

@property
def connect_params_async(self):
def connect_params_async(self) -> Dict[str, Any]:
"""Connection parameters for `aiomysql.Connection`
"""
kwargs = self.connect_params.copy()
kwargs.update({
'minsize': self.min_connections,
'maxsize': self.max_connections,
'autocommit': True,
})
kwargs.update(
{
'minsize': self.min_connections,
'maxsize': self.max_connections,
'autocommit': True,
}
)
return kwargs


# DEPRECATED Databases


class PostgresqlDatabase(AioPostgresqlMixin, peewee.PostgresqlDatabase):
class PostgresqlDatabase(PooledPostgresqlDatabase):
"""PosgreSQL database driver providing **single drop-in sync** connection
and **single async connection** interface.
Expand All @@ -330,15 +354,16 @@ class PostgresqlDatabase(AioPostgresqlMixin, peewee.PostgresqlDatabase):
See also:
http://peewee.readthedocs.io/en/latest/peewee/api.html#PostgresqlDatabase
"""

min_connections: int = 1
max_connections: int = 1

def init(self, database: Optional[str], **kwargs: Any) -> None:
warnings.warn(
"`PostgresqlDatabase` is deprecated, use `PooledPostgresqlDatabase` instead.",
DeprecationWarning
)
self.min_connections = 1
self.max_connections = 1
super().init(database, **kwargs)
self.init_async()


class MySQLDatabase(PooledMySQLDatabase):
Expand All @@ -352,17 +377,19 @@ class MySQLDatabase(PooledMySQLDatabase):
See also:
http://peewee.readthedocs.io/en/latest/peewee/api.html#MySQLDatabase
"""

min_connections: int = 1
max_connections: int = 1

def init(self, database: Optional[str], **kwargs: Any) -> None:
warnings.warn(
"`MySQLDatabase` is deprecated, use `PooledMySQLDatabase` instead.",
DeprecationWarning
)
super().init(database, **kwargs)
self.min_connections = 1
self.max_connections = 1


class PostgresqlExtDatabase(AioPostgresqlMixin, ext.PostgresqlExtDatabase):
class PostgresqlExtDatabase(PooledPostgresqlExtDatabase):
"""PosgreSQL database extended driver providing **single drop-in sync**
connection and **single async connection** interface.
Expand All @@ -377,15 +404,12 @@ class PostgresqlExtDatabase(AioPostgresqlMixin, ext.PostgresqlExtDatabase):
https://peewee.readthedocs.io/en/latest/peewee/playhouse.html#PostgresqlExtDatabase
"""

min_connections: int = 1
max_connections: int = 1

def init(self, database: Optional[str], **kwargs: Any) -> None:
warnings.warn(
"`PostgresqlExtDatabase` is deprecated, use `PooledPostgresqlExtDatabase` instead.",
DeprecationWarning
)
self.min_connections = 1
self.max_connections = 1
super().init(database, **kwargs)
self.init_async(
enable_json=True,
enable_hstore=self._register_hstore
)
1 change: 1 addition & 0 deletions peewee_async/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
class PoolBackend(metaclass=abc.ABCMeta):
"""Asynchronous database connection pool.
"""

def __init__(self, *, database: str, **kwargs: Any) -> None:
self.pool: Optional[PoolProtocol] = None
self.database = database
Expand Down
1 change: 1 addition & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ async def db(request):

params = DB_DEFAULTS[db]
database = DB_CLASSES[db](**params)

database._allow_sync = False
with database.allow_sync():
for model in ALL_MODELS:
Expand Down
4 changes: 2 additions & 2 deletions tests/db_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
'port': int(os.environ.get('POSTGRES_PORT', 5432)),
'password': 'postgres',
'user': 'postgres',
'connect_timeout': 30
'connection_timeout': 30
}

MYSQL_DEFAULTS = {
Expand All @@ -18,7 +18,7 @@
'port': int(os.environ.get('MYSQL_PORT', 3306)),
'user': 'root',
'password': 'mysql',
'connect_timeout': 30
'connection_timeout': 30
}

DB_DEFAULTS = {
Expand Down
16 changes: 15 additions & 1 deletion tests/test_database.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import pytest

from peewee_async import connection_context
from tests.conftest import dbs_all
from peewee_async.databases import AioDatabase
from tests.conftest import dbs_all, MYSQL_DBS, PG_DBS
from tests.db_config import DB_DEFAULTS, DB_CLASSES
from tests.models import TestModel


Expand Down Expand Up @@ -53,3 +55,15 @@ async def test_aio_close_idempotent(db):

await db.aio_close()
assert db.is_connected is False


@pytest.mark.parametrize('db_name', PG_DBS + MYSQL_DBS)
async def test_deferred_init(db_name):
database: AioDatabase = DB_CLASSES[db_name](None)

with pytest.raises(Exception, match='Error, database must be initialized before creating a connection pool'):
await database.aio_execute_sql(sql='SELECT 1;')

database.init(**DB_DEFAULTS[db_name])

await database.aio_execute_sql(sql='SELECT 1;')

0 comments on commit d20b6c4

Please sign in to comment.