Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Return deferred db init #278

Merged
merged 3 commits into from
Aug 3, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
118 changes: 74 additions & 44 deletions peewee_async/databases.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import contextlib
import logging
from typing import Type, Optional, Any, AsyncIterator, Iterator
from typing import Type, Optional, Any, AsyncIterator, Iterator, Dict

import peewee
from playhouse import postgres_ext as ext
Expand All @@ -11,15 +11,18 @@
from .utils import psycopg2, aiopg, pymysql, aiomysql, __log__


class AioDatabase:
class AioDatabase(peewee.Database):
_allow_sync = True # whether sync queries are allowed

pool_backend_cls: Type[PoolBackend]
pool_backend: PoolBackend

def __init__(self, database: Optional[str], **kwargs: Any) -> None:
super().__init__(database, **kwargs)
if not database:
raise Exception("Deferred initialization is not supported")
@property
def connect_params_async(self) -> Dict[str, Any]:
...
kalombos marked this conversation as resolved.
Show resolved Hide resolved

def init(self, database: Optional[str], **kwargs: Any) -> None:
super().init(database, **kwargs)
self.pool_backend = self.pool_backend_cls(
database=self.database,
**self.connect_params_async
Expand All @@ -28,6 +31,8 @@ def __init__(self, database: Optional[str], **kwargs: Any) -> None:
async def aio_connect(self) -> None:
"""Creates a connection pool
"""
if self.deferred:
raise Exception('Cannot connect before db inited')
kalombos marked this conversation as resolved.
Show resolved Hide resolved
await self.pool_backend.connect()

@property
kalombos marked this conversation as resolved.
Show resolved Hide resolved
Expand All @@ -39,6 +44,9 @@ def is_connected(self) -> bool:
async def aio_close(self) -> None:
"""Terminate pool backend. The pool is closed until you run aio_connect manually
"""
if self.deferred:
raise Exception('Cannot connect before db inited')
kalombos marked this conversation as resolved.
Show resolved Hide resolved

await self.pool_backend.terminate()

@contextlib.asynccontextmanager
Expand Down Expand Up @@ -91,12 +99,17 @@ def execute_sql(self, *args, **kwargs):
"Error, sync query is not allowed! Call the `.set_allow_sync()` "
"or use the `.allow_sync()` context manager.")
if self._allow_sync in (logging.ERROR, logging.WARNING):
logging.log(self._allow_sync,
"Error, sync query is not allowed: %s %s" %
(str(args), str(kwargs)))
logging.log(
self._allow_sync,
"Error, sync query is not allowed: %s %s" %
(str(args), str(kwargs))
)
return super().execute_sql(*args, **kwargs)

def aio_connection(self) -> ConnectionContextManager:
if self.deferred:
raise Exception('Cannot create connection before db inited')
kalombos marked this conversation as resolved.
Show resolved Hide resolved

return ConnectionContextManager(self.pool_backend)

async def aio_execute_sql(self, sql: str, params=None, fetch_results=None):
Expand All @@ -123,38 +136,28 @@ async def aio_execute(self, query, fetch_results=None):
return await self.aio_execute_sql(sql, params, fetch_results=fetch_results)


class AioPostgresqlMixin(AioDatabase):
"""Mixin for `peewee.PostgresqlDatabase` providing extra methods
class AioPostgresqlMixin(AioDatabase, peewee.PostgresqlDatabase):
"""Extension for `peewee.PostgresqlDatabase` providing extra methods
for managing async connection.
"""

_enable_json: bool
_enable_hstore: bool

pool_backend_cls = PostgresqlPoolBackend

if psycopg2:
Error = psycopg2.Error

def init_async(self, enable_json: bool = False, enable_hstore: bool =False) -> None:
def init_async(self, enable_json: bool = False, enable_hstore: bool = False) -> None:
if not aiopg:
raise Exception("Error, aiopg is not installed!")
self._enable_json = enable_json
self._enable_hstore = enable_hstore

@property
def connect_params_async(self):
"""Connection parameters for `aiopg.Connection`
"""
kwargs = self.connect_params.copy()
kwargs.update({
'minsize': self.min_connections,
'maxsize': self.max_connections,
'enable_json': self._enable_json,
'enable_hstore': self._enable_hstore,
})
return kwargs


class PooledPostgresqlDatabase(AioPostgresqlMixin, peewee.PostgresqlDatabase):
"""PosgreSQL database driver providing **single drop-in sync**
"""PostgreSQL database driver providing **single drop-in sync**
connection and **async connections pool** interface.

:param max_connections: connections pool size
Expand All @@ -166,15 +169,37 @@ class PooledPostgresqlDatabase(AioPostgresqlMixin, peewee.PostgresqlDatabase):
See also:
http://peewee.readthedocs.io/en/latest/peewee/api.html#PostgresqlDatabase
"""
min_connections: int = 1
max_connections: int = 20

def init(self, database: Optional[str], **kwargs: Any) -> None:
self.min_connections = kwargs.pop('min_connections', 1)
self.max_connections = kwargs.pop('max_connections', 20)
super().init(database, **kwargs)
if min_connections := kwargs.pop('min_connections', False):
self.min_connections = min_connections

if max_connections := kwargs.pop('max_connections', False):
self.max_connections = max_connections

self.init_async()
super().init(database, **kwargs)

@property
def connect_params_async(self):
"""Connection parameters for `aiopg.Connection`
"""
kwargs = self.connect_params.copy()
kwargs.update(
{
'minsize': self.min_connections,
'maxsize': self.max_connections,
'enable_json': self._enable_json,
'enable_hstore': self._enable_hstore,
}
)
return kwargs


class PooledPostgresqlExtDatabase(
AioPostgresqlMixin,
PooledPostgresqlDatabase,
ext.PostgresqlExtDatabase
):
"""PosgreSQL database extended driver providing **single drop-in sync**
Expand All @@ -183,8 +208,6 @@ class PooledPostgresqlExtDatabase(
JSON fields support is always enabled, HStore supports is enabled by
default, but can be disabled with ``register_hstore=False`` argument.

:param max_connections: connections pool size

Example::

database = PooledPostgresqlExtDatabase('test', register_hstore=False,
Expand All @@ -195,14 +218,11 @@ class PooledPostgresqlExtDatabase(
"""

def init(self, database: Optional[str], **kwargs: Any) -> None:
self.min_connections = kwargs.pop('min_connections', 1)
self.max_connections = kwargs.pop('max_connections', 20)
connection_timeout = kwargs.pop('connection_timeout', None)
super().init(database, **kwargs)
self.init_async(
enable_json=True,
enable_hstore=self._register_hstore
)
super().init(database, **kwargs)


class PooledMySQLDatabase(AioDatabase, peewee.MySQLDatabase):
Expand All @@ -218,6 +238,9 @@ class PooledMySQLDatabase(AioDatabase, peewee.MySQLDatabase):
See also:
http://peewee.readthedocs.io/en/latest/peewee/api.html#MySQLDatabase
"""
min_connections: int = 1
max_connections: int = 20

pool_backend_cls = MysqlPoolBackend

if pymysql:
Expand All @@ -226,18 +249,25 @@ class PooledMySQLDatabase(AioDatabase, peewee.MySQLDatabase):
def init(self, database: Optional[str], **kwargs: Any) -> None:
if not aiomysql:
raise Exception("Error, aiomysql is not installed!")
self.min_connections = kwargs.pop('min_connections', 1)
self.max_connections = kwargs.pop('max_connections', 20)

if min_connections := kwargs.pop('min_connections', False):
self.min_connections = min_connections

if max_connections := kwargs.pop('max_connections', False):
self.max_connections = max_connections

super().init(database, **kwargs)

@property
def connect_params_async(self):
def connect_params_async(self) -> Dict[str, Any]:
"""Connection parameters for `aiomysql.Connection`
"""
kwargs = self.connect_params.copy()
kwargs.update({
'minsize': self.min_connections,
'maxsize': self.max_connections,
'autocommit': True,
})
kwargs.update(
{
'minsize': self.min_connections,
'maxsize': self.max_connections,
'autocommit': True,
}
)
return kwargs
3 changes: 2 additions & 1 deletion peewee_async/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
class PoolBackend(metaclass=abc.ABCMeta):
"""Asynchronous database connection pool.
"""
def __init__(self, *, database: str, **kwargs: Any) -> None:

def __init__(self, *, database: Optional[str] = None, **kwargs: Any) -> None:
kalombos marked this conversation as resolved.
Show resolved Hide resolved
self.pool: Optional[PoolProtocol] = None
self.database = database
self.connect_params = kwargs
Expand Down
1 change: 1 addition & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ async def db(request):

params = DB_DEFAULTS[db]
database = DB_CLASSES[db](**params)

database._allow_sync = False
with database.allow_sync():
for model in ALL_MODELS:
Expand Down
21 changes: 20 additions & 1 deletion tests/test_database.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import pytest

from peewee_async import connection_context
from tests.conftest import dbs_all
from peewee_async.databases import AioDatabase
from tests.conftest import dbs_all, MYSQL_DBS, PG_DBS
from tests.db_config import DB_DEFAULTS, DB_CLASSES
from tests.models import TestModel


Expand Down Expand Up @@ -53,3 +55,20 @@ async def test_aio_close_idempotent(db):

await db.aio_close()
assert db.is_connected is False


@pytest.mark.parametrize('db_name', PG_DBS + MYSQL_DBS)
async def test_deferred_init(db_name):
init_params = DB_DEFAULTS[db_name]
database_host = init_params.pop('database')
kalombos marked this conversation as resolved.
Show resolved Hide resolved
init_params['database'] = None

database: AioDatabase = DB_CLASSES[db_name](**init_params)

with pytest.raises(Exception, match='Cannot create connection before db inited'):
await database.aio_execute_sql(sql='SELECT 1;')

init_params['database'] = database_host
database.init(**init_params)

await database.aio_connect()
kalombos marked this conversation as resolved.
Show resolved Hide resolved
Loading