diff --git a/tests/test_connections.py b/tests/test_connections.py index 1dfa373..08fd071 100644 --- a/tests/test_connections.py +++ b/tests/test_connections.py @@ -200,6 +200,34 @@ def _test_bulk_upsert(self, connection: dbapi.Connection) -> None: maybe_await(cursor.execute_scheme("DROP TABLE pet")) + def _test_error_with_interactive_tx( + self, + connection: dbapi.Connection, + ) -> None: + + cur = connection.cursor() + cur.execute_scheme( + """ + DROP TABLE IF EXISTS test; + CREATE TABLE test ( + id Int64 NOT NULL, + val Int64, + PRIMARY KEY(id) + ) + """ + ) + + connection.set_isolation_level(dbapi.IsolationLevel.SERIALIZABLE) + maybe_await(connection.begin()) + + cur = connection.cursor() + maybe_await(cur.execute("INSERT INTO test(id, val) VALUES (1,1)")) + with pytest.raises(dbapi.Error): + maybe_await(cur.execute("INSERT INTO test(id, val) VALUES (1,1)")) + + maybe_await(cur.close()) + maybe_await(connection.rollback()) + class TestConnection(BaseDBApiTestSuit): @pytest.fixture @@ -245,6 +273,11 @@ def test_errors(self, connection: dbapi.Connection) -> None: def test_bulk_upsert(self, connection: dbapi.Connection) -> None: self._test_bulk_upsert(connection) + def test_errors_with_interactive_tx( + self, connection: dbapi.Connection + ) -> None: + self._test_error_with_interactive_tx(connection) + class TestAsyncConnection(BaseDBApiTestSuit): @pytest_asyncio.fixture @@ -304,3 +337,9 @@ async def test_bulk_upsert( self, connection: dbapi.AsyncConnection ) -> None: await greenlet_spawn(self._test_bulk_upsert, connection) + + @pytest.mark.asyncio + async def test_errors_with_interactive_tx( + self, connection: dbapi.AsyncConnection + ) -> None: + await greenlet_spawn(self._test_error_with_interactive_tx, connection) diff --git a/tests/test_cursors.py b/tests/test_cursors.py index 634f0d0..e3f7372 100644 --- a/tests/test_cursors.py +++ b/tests/test_cursors.py @@ -10,6 +10,7 @@ from sqlalchemy.util import greenlet_spawn from ydb_dbapi import AsyncCursor from ydb_dbapi import Cursor +from ydb_dbapi.utils import CursorStatus def maybe_await(obj: callable) -> any: @@ -22,6 +23,14 @@ def maybe_await(obj: callable) -> any: RESULT_SET_COUNT = 3 +class FakeSyncConnection: + def _invalidate_session(): ... + + +class FakeAsyncConnection: + async def _invalidate_session(): ... + + class BaseCursorTestSuit: def _test_cursor_fetch_one(self, cursor: Cursor | AsyncCursor) -> None: yql_text = """ @@ -136,13 +145,24 @@ def _test_cursor_fetch_all_multiple_result_sets( assert maybe_await(cursor.fetchall()) == [] assert not maybe_await(cursor.nextset()) + def _test_cursor_state_after_error( + self, cursor: Cursor | AsyncCursor + ) -> None: + query = "INSERT INTO table (id, val) VALUES (0,0)" + with pytest.raises(ydb.Error): + maybe_await(cursor.execute(query=query)) + + assert cursor._state == CursorStatus.finished + class TestCursor(BaseCursorTestSuit): @pytest.fixture def sync_cursor( self, session_pool_sync: ydb.QuerySessionPool ) -> Generator[Cursor]: + cursor = Cursor( + FakeSyncConnection(), session_pool_sync, ydb.QuerySerializableReadWrite(), request_settings=ydb.BaseRequestSettings(), @@ -174,6 +194,10 @@ def test_cursor_fetch_all_multiple_result_sets( ) -> None: self._test_cursor_fetch_all_multiple_result_sets(sync_cursor) + def test_cursor_state_after_error( + self, sync_cursor: Cursor + ) -> None: + self._test_cursor_state_after_error(sync_cursor) class TestAsyncCursor(BaseCursorTestSuit): @@ -182,6 +206,7 @@ async def async_cursor( self, session_pool: ydb.aio.QuerySessionPool ) -> AsyncGenerator[Cursor]: cursor = AsyncCursor( + FakeAsyncConnection(), session_pool, ydb.QuerySerializableReadWrite(), request_settings=ydb.BaseRequestSettings(), @@ -224,3 +249,11 @@ async def test_cursor_fetch_all_multiple_result_sets( await greenlet_spawn( self._test_cursor_fetch_all_multiple_result_sets, async_cursor ) + + @pytest.mark.asyncio + async def test_cursor_state_after_error( + self, async_cursor: AsyncCursor + ) -> None: + await greenlet_spawn( + self._test_cursor_state_after_error, async_cursor + ) diff --git a/ydb_dbapi/connections.py b/ydb_dbapi/connections.py index 8f45318..9eda5ca 100644 --- a/ydb_dbapi/connections.py +++ b/ydb_dbapi/connections.py @@ -192,6 +192,7 @@ def __init__( def cursor(self) -> Cursor: return self._cursor_cls( + connection=self, session_pool=self._session_pool, tx_mode=self._tx_mode, tx_context=self._tx_context, @@ -326,6 +327,13 @@ def bulk_upsert( settings=settings, ) + def _invalidate_session(self) -> None: + if self._tx_context: + self._tx_context = None + if self._session: + self._session_pool.release(self._session) + self._session = None + class AsyncConnection(BaseConnection): _driver_cls = ydb.aio.Driver @@ -357,6 +365,7 @@ def __init__( def cursor(self) -> AsyncCursor: return self._cursor_cls( + connection=self, session_pool=self._session_pool, tx_mode=self._tx_mode, tx_context=self._tx_context, @@ -492,6 +501,13 @@ async def bulk_upsert( settings=settings, ) + async def _invalidate_session(self) -> None: + if self._tx_context: + self._tx_context = None + if self._session: + await self._session_pool.release(self._session) + self._session = None + def connect(*args: tuple, **kwargs: dict) -> Connection: conn = Connection(*args, **kwargs) # type: ignore diff --git a/ydb_dbapi/cursors.py b/ydb_dbapi/cursors.py index 7d9b911..6e81113 100644 --- a/ydb_dbapi/cursors.py +++ b/ydb_dbapi/cursors.py @@ -1,12 +1,15 @@ from __future__ import annotations +import functools import itertools from collections.abc import AsyncIterator from collections.abc import Generator from collections.abc import Iterator from collections.abc import Sequence +from inspect import iscoroutinefunction from typing import TYPE_CHECKING from typing import Any +from typing import Callable from typing import Union import ydb @@ -20,6 +23,9 @@ from .utils import maybe_get_current_trace_id if TYPE_CHECKING: + from .connections import AsyncConnection + from .connections import Connection + ParametersType = dict[ str, Union[ @@ -34,6 +40,34 @@ def _get_column_type(type_obj: Any) -> str: return str(ydb.convert.type_to_native(type_obj)) +def invalidate_cursor_on_ydb_error(func: Callable) -> Callable: + if iscoroutinefunction(func): + + @functools.wraps(func) + async def awrapper( + self: AsyncCursor, *args: tuple, **kwargs: dict + ) -> Any: + try: + return await func(self, *args, **kwargs) + except ydb.Error: + self._state = CursorStatus.finished + await self._connection._invalidate_session() + raise + + return awrapper + + @functools.wraps(func) + def wrapper(self: Cursor, *args: tuple, **kwargs: dict) -> Any: + try: + return func(self, *args, **kwargs) + except ydb.Error: + self._state = CursorStatus.closed + self._connection._invalidate_session() + raise + + return wrapper + + class BufferedCursor: def __init__(self) -> None: self.arraysize: int = 1 @@ -154,6 +188,7 @@ def _append_table_path_prefix(self, query: str) -> str: class Cursor(BufferedCursor): def __init__( self, + connection: Connection, session_pool: ydb.QuerySessionPool, tx_mode: ydb.BaseQueryTxMode, request_settings: ydb.BaseRequestSettings, @@ -161,6 +196,7 @@ def __init__( table_path_prefix: str = "", ) -> None: super().__init__() + self._connection = connection self._session_pool = session_pool self._tx_mode = tx_mode self._request_settings = request_settings @@ -188,6 +224,7 @@ def _get_request_settings(self) -> ydb.BaseRequestSettings: return settings @handle_ydb_errors + @invalidate_cursor_on_ydb_error def _execute_generic_query( self, query: str, parameters: ParametersType | None = None ) -> Iterator[ydb.convert.ResultSet]: @@ -205,6 +242,7 @@ def callee( return self._session_pool.retry_operation_sync(callee) @handle_ydb_errors + @invalidate_cursor_on_ydb_error def _execute_session_query( self, query: str, @@ -225,6 +263,7 @@ def callee( return self._session_pool.retry_operation_sync(callee) @handle_ydb_errors + @invalidate_cursor_on_ydb_error def _execute_transactional_query( self, tx_context: ydb.QueryTxContext, @@ -283,6 +322,7 @@ def executemany( self.execute(query, parameters) @handle_ydb_errors + @invalidate_cursor_on_ydb_error def nextset(self, replace_current: bool = True) -> bool: if self._stream is None: return False @@ -328,6 +368,7 @@ def __exit__( class AsyncCursor(BufferedCursor): def __init__( self, + connection: AsyncConnection, session_pool: ydb.aio.QuerySessionPool, tx_mode: ydb.BaseQueryTxMode, request_settings: ydb.BaseRequestSettings, @@ -335,6 +376,7 @@ def __init__( table_path_prefix: str = "", ) -> None: super().__init__() + self._connection = connection self._session_pool = session_pool self._tx_mode = tx_mode self._request_settings = request_settings @@ -362,6 +404,7 @@ def _get_request_settings(self) -> ydb.BaseRequestSettings: return settings @handle_ydb_errors + @invalidate_cursor_on_ydb_error async def _execute_generic_query( self, query: str, parameters: ParametersType | None = None ) -> AsyncIterator[ydb.convert.ResultSet]: @@ -379,6 +422,7 @@ async def callee( return await self._session_pool.retry_operation_async(callee) @handle_ydb_errors + @invalidate_cursor_on_ydb_error async def _execute_session_query( self, query: str, @@ -399,6 +443,7 @@ async def callee( return await self._session_pool.retry_operation_async(callee) @handle_ydb_errors + @invalidate_cursor_on_ydb_error async def _execute_transactional_query( self, tx_context: ydb.aio.QueryTxContext, @@ -457,6 +502,7 @@ async def executemany( await self.execute(query, parameters) @handle_ydb_errors + @invalidate_cursor_on_ydb_error async def nextset(self, replace_current: bool = True) -> bool: if self._stream is None: return False