Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: rewrite transaction context manager #243

Merged
merged 1 commit into from
May 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 12 additions & 24 deletions peewee_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,28 +236,6 @@ async def __aexit__(self, *args):
connection_context.set(None)


class TransactionContextManager(ConnectionContextManager):
async def __aenter__(self):
connection = await super().__aenter__()
begin_transaction = self.connection_context.transaction_is_opened is False

self.transaction = Transaction(connection, is_savepoint=begin_transaction is False)
await self.transaction.__aenter__()

if begin_transaction is True:
self.connection_context.transaction_is_opened = True
return connection

async def __aexit__(self, exc_type, exc_value, exc_tb):
await self.transaction.__aexit__(exc_type, exc_value, exc_tb)

end_transaction = self.transaction.is_savepoint is False
if end_transaction is True:
self.connection_context.transaction_is_opened = False

await super().__aexit__()


############
# Database #
############
Expand Down Expand Up @@ -286,11 +264,21 @@ async def aio_close(self):
"""
await self.aio_pool.terminate()

def aio_atomic(self):
@contextlib.asynccontextmanager
async def aio_atomic(self):
"""Similar to peewee `Database.atomic()` method, but returns
asynchronous context manager.
"""
return TransactionContextManager(self.aio_pool)
async with self.aio_connection() as connection:
_connection_context = connection_context.get()
begin_transaction = _connection_context.transaction_is_opened is False
try:
async with Transaction(connection, is_savepoint=begin_transaction is False):
_connection_context.transaction_is_opened = True
yield
finally:
if begin_transaction is True:
_connection_context.transaction_is_opened = False

def set_allow_sync(self, value):
"""Allow or forbid sync queries for the database. See also
Expand Down
5 changes: 3 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "peewee-async"
version = "0.9.1"
version = "0.10.1-beta"
description = "Asynchronous interface for peewee ORM powered by asyncio."
authors = ["Alexey Kinev <[email protected]>", "Gorshkov Nikolay(contributor) <[email protected]>"]
readme = "README.md"
Expand All @@ -19,13 +19,14 @@ aiomysql = { version = "^0.2.0", optional = true }
cryptography = { version = "^41.0.3", optional = true }
pytest = { version = "^7.4.1", optional = true }
pytest-asyncio = { version = "^0.21.1", optional = true }
pytest-mock = { version = "^3.14.0", optional = true }
sphinx = { version = "^7.1.2", optional = true }
sphinx-rtd-theme = { version = "^1.3.0rc1", optional = true }

[tool.poetry.extras]
postgresql = ["aiopg"]
mysql = ["aiomysql", "cryptography"]
develop = ["aiopg", "aiomysql", "cryptography", "pytest", "pytest-asyncio"]
develop = ["aiopg", "aiomysql", "cryptography", "pytest", "pytest-asyncio", "pytest-mock"]
docs = ["aiopg", "aiomysql", "cryptography", "sphinx", "sphinx-rtd-theme"]

[build-system]
Expand Down
85 changes: 58 additions & 27 deletions tests/test_transaction.py
Original file line number Diff line number Diff line change
@@ -1,60 +1,65 @@
import asyncio

import pytest
from peewee import IntegrityError

from peewee_async import Transaction
from tests.conftest import dbs_all
from tests.models import TestModel


@dbs_all
async def test_savepoint_success(db):
class FakeConnectionError(Exception):
pass

async with db.aio_atomic():
await TestModel.aio_create(text='FOO')

@dbs_all
async def test_transaction_error_on_begin(db, mocker):
mocker.patch.object(Transaction, "begin", side_effect=FakeConnectionError)
with pytest.raises(FakeConnectionError):
async with db.aio_atomic():
await TestModel.update(text="BAR").aio_execute()
await TestModel.aio_create(text='FOO')
assert db.aio_pool.has_acquired_connections() is False

assert await TestModel.aio_get_or_none(text="BAR") is not None
@dbs_all
async def test_transaction_error_on_commit(db, mocker):
mocker.patch.object(Transaction, "commit", side_effect=FakeConnectionError)
with pytest.raises(FakeConnectionError):
async with db.aio_atomic():
await TestModel.aio_create(text='FOO')
assert db.aio_pool.has_acquired_connections() is False


@dbs_all
async def test_transaction_success(db):
async with db.aio_atomic():
await TestModel.aio_create(text='FOO')
async def test_transaction_error_on_rollback(db, mocker):
await TestModel.aio_create(text='FOO', data="")
mocker.patch.object(Transaction, "rollback", side_effect=FakeConnectionError)
with pytest.raises(FakeConnectionError):
async with db.aio_atomic():
await TestModel.update(data="BAR").aio_execute()
assert await TestModel.aio_get_or_none(data="BAR") is not None
await TestModel.aio_create(text='FOO')

assert await TestModel.aio_get_or_none(text="FOO") is not None
assert db.aio_pool.has_acquired_connections() is False


@dbs_all
async def test_savepoint_rollback(db):
await TestModel.aio_create(text='FOO', data="")

async def test_transaction_success(db):
async with db.aio_atomic():
await TestModel.update(data="BAR").aio_execute()

try:
async with db.aio_atomic():
await TestModel.aio_create(text='FOO')
except:
pass
await TestModel.aio_create(text='FOO')

assert await TestModel.aio_get_or_none(data="BAR") is not None
assert await TestModel.aio_get_or_none(text="FOO") is not None
assert db.aio_pool.has_acquired_connections() is False


@dbs_all
async def test_transaction_rollback(db):
await TestModel.aio_create(text='FOO', data="")

try:
with pytest.raises(IntegrityError):
async with db.aio_atomic():
await TestModel.update(data="BAR").aio_execute()
assert await TestModel.aio_get_or_none(data="BAR") is not None
await TestModel.aio_create(text='FOO')
except:
pass

assert await TestModel.aio_get_or_none(data="BAR") is None
assert db.aio_pool.has_acquired_connections() is False
Expand All @@ -72,11 +77,9 @@ async def t1():
async def t2():
async with db.aio_atomic():
await TestModel.aio_create(text='FOO2', data="")
try:
with pytest.raises(IntegrityError):
async with db.aio_atomic():
await TestModel.aio_create(text='FOO2', data="not_created")
except:
pass

async def t3():
async with db.aio_atomic():
Expand Down Expand Up @@ -110,6 +113,33 @@ async def test_transaction_manual_work(db):
assert db.aio_pool.has_acquired_connections() is False


@dbs_all
async def test_savepoint_success(db):
async with db.aio_atomic():
await TestModel.aio_create(text='FOO')

async with db.aio_atomic():
await TestModel.update(text="BAR").aio_execute()

assert await TestModel.aio_get_or_none(text="BAR") is not None
assert db.aio_pool.has_acquired_connections() is False


@dbs_all
async def test_savepoint_rollback(db):
await TestModel.aio_create(text='FOO', data="")

async with db.aio_atomic():
await TestModel.update(data="BAR").aio_execute()

with pytest.raises(IntegrityError):
async with db.aio_atomic():
await TestModel.aio_create(text='FOO')

assert await TestModel.aio_get_or_none(data="BAR") is not None
assert db.aio_pool.has_acquired_connections() is False


@dbs_all
async def test_savepoint_manual_work(db):
async with db.aio_connection() as connection:
Expand Down Expand Up @@ -178,3 +208,4 @@ async def insert_records(event_for_wait: asyncio.Event):
# The transaction has not been committed
assert len(list(await TestModel.select().aio_execute())) == 0
assert db.aio_pool.has_acquired_connections() is False

Loading