Skip to content

Commit

Permalink
feat: removed last_insert_id_async typing
Browse files Browse the repository at this point in the history
  • Loading branch information
kalombos committed Jul 19, 2024
1 parent 765583c commit 610545b
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 29 deletions.
15 changes: 8 additions & 7 deletions peewee_async/aio_model.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import peewee

from .result_wrappers import AsyncQueryWrapper
from .utils import CursorProtocol


async def aio_prefetch(sq, *subqueries, prefetch_type):
Expand Down Expand Up @@ -40,45 +41,45 @@ class AioQueryMixin:
async def aio_execute(self, database):
return await database.aio_execute(self)

async def make_async_query_wrapper(self, cursor):
async def make_async_query_wrapper(self, cursor: CursorProtocol):
return await AsyncQueryWrapper.make_for_all_rows(cursor, self)


class AioModelDelete(peewee.ModelDelete, AioQueryMixin):
async def fetch_results(self, cursor):
async def fetch_results(self, cursor: CursorProtocol):
if self._returning:
return await self.make_async_query_wrapper(cursor)
return cursor.rowcount


class AioModelUpdate(peewee.ModelUpdate, AioQueryMixin):

async def fetch_results(self, cursor):
async def fetch_results(self, cursor: CursorProtocol):
if self._returning:
return await self.make_async_query_wrapper(cursor)
return cursor.rowcount


class AioModelInsert(peewee.ModelInsert, AioQueryMixin):
async def fetch_results(self, cursor):
async def fetch_results(self, cursor: CursorProtocol):
if self._returning is not None and len(self._returning) > 1:
return await self.make_async_query_wrapper(cursor)

if self._returning:
row = await cursor.fetchone()
return row[0] if row else None
else:
return await self._database.last_insert_id_async(cursor)
return cursor.lastrowid


class AioModelRaw(peewee.ModelRaw, AioQueryMixin):
async def fetch_results(self, cursor):
async def fetch_results(self, cursor: CursorProtocol):
return await self.make_async_query_wrapper(cursor)


class AioSelectMixin(AioQueryMixin):

async def fetch_results(self, cursor):
async def fetch_results(self, cursor: CursorProtocol):
return await self.make_async_query_wrapper(cursor)

@peewee.database_required
Expand Down
16 changes: 0 additions & 16 deletions peewee_async/databases.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,17 +213,6 @@ def connect_params_async(self):
})
return kwargs

async def last_insert_id_async(self, cursor):
"""Get ID of last inserted row.
NOTE: it's not clear, when this code is executed?
"""
# try:
# return cursor if query_type else cursor[0][0]
# except (IndexError, KeyError, TypeError):
# pass
return cursor.lastrowid


class PooledPostgresqlDatabase(AioPostgresqlMixin, peewee.PostgresqlDatabase):
"""PosgreSQL database driver providing **single drop-in sync**
Expand Down Expand Up @@ -327,11 +316,6 @@ def connect_params_async(self):
})
return kwargs

async def last_insert_id_async(self, cursor):
"""Get ID of last inserted row.
"""
return cursor.lastrowid


# DEPRECATED Databases

Expand Down
18 changes: 12 additions & 6 deletions peewee_async/pool.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import abc
import asyncio
from typing import Any, Optional
from typing import Any, Optional, cast

from .utils import aiopg, aiomysql, PoolProtocol, ConnectionProtocol

Expand Down Expand Up @@ -67,9 +67,12 @@ async def create(self) -> None:
"""
if "connect_timeout" in self.connect_params:
self.connect_params['timeout'] = self.connect_params.pop("connect_timeout")
self.pool = await aiopg.create_pool(
database=self.database,
**self.connect_params
self.pool = cast(
PoolProtocol,
await aiopg.create_pool(
database=self.database,
**self.connect_params
)
)


Expand All @@ -80,6 +83,9 @@ class MysqlPoolBackend(PoolBackend):
async def create(self) -> None:
"""Create connection pool asynchronously.
"""
self.pool = await aiomysql.create_pool(
db=self.database, **self.connect_params
self.pool = cast(
PoolProtocol,
await aiomysql.create_pool(
db=self.database, **self.connect_params
),
)
8 changes: 8 additions & 0 deletions peewee_async/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,18 @@ class CursorProtocol(Protocol):
async def fetchone(self) -> Any:
...

@property
def lastrowid(self) -> int:
...

@property
def description(self) -> Optional[Sequence[Any]]:
...

@property
def rowcount(self) -> int:
...

async def execute(self, query: str, *args: Any, **kwargs: Any) -> None:
...

Expand Down

0 comments on commit 610545b

Please sign in to comment.