Skip to content

Commit

Permalink
Invalidate session&tx on YDB errors
Browse files Browse the repository at this point in the history
  • Loading branch information
vgvoleg committed Dec 19, 2024
1 parent eff6be5 commit 3c83113
Show file tree
Hide file tree
Showing 4 changed files with 134 additions and 0 deletions.
39 changes: 39 additions & 0 deletions tests/test_connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,34 @@ def _test_bulk_upsert(self, connection: dbapi.Connection) -> None:

maybe_await(cursor.execute_scheme("DROP TABLE pet"))

def _test_error_with_interactive_tx(
self,
connection: dbapi.Connection,
) -> None:

cur = connection.cursor()
cur.execute_scheme(
"""
DROP TABLE IF EXISTS test;
CREATE TABLE test (
id Int64 NOT NULL,
val Int64,
PRIMARY KEY(id)
)
"""
)

connection.set_isolation_level(dbapi.IsolationLevel.SERIALIZABLE)
maybe_await(connection.begin())

cur = connection.cursor()
maybe_await(cur.execute("INSERT INTO test(id, val) VALUES (1,1)"))
with pytest.raises(dbapi.Error):
maybe_await(cur.execute("INSERT INTO test(id, val) VALUES (1,1)"))

maybe_await(cur.close())
maybe_await(connection.rollback())


class TestConnection(BaseDBApiTestSuit):
@pytest.fixture
Expand Down Expand Up @@ -245,6 +273,11 @@ def test_errors(self, connection: dbapi.Connection) -> None:
def test_bulk_upsert(self, connection: dbapi.Connection) -> None:
self._test_bulk_upsert(connection)

def test_errors_with_interactive_tx(
self, connection: dbapi.Connection
) -> None:
self._test_error_with_interactive_tx(connection)


class TestAsyncConnection(BaseDBApiTestSuit):
@pytest_asyncio.fixture
Expand Down Expand Up @@ -304,3 +337,9 @@ async def test_bulk_upsert(
self, connection: dbapi.AsyncConnection
) -> None:
await greenlet_spawn(self._test_bulk_upsert, connection)

@pytest.mark.asyncio
async def test_errors_with_interactive_tx(
self, connection: dbapi.AsyncConnection
) -> None:
await greenlet_spawn(self._test_error_with_interactive_tx, connection)
33 changes: 33 additions & 0 deletions tests/test_cursors.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from sqlalchemy.util import greenlet_spawn
from ydb_dbapi import AsyncCursor
from ydb_dbapi import Cursor
from ydb_dbapi.utils import CursorStatus


def maybe_await(obj: callable) -> any:
Expand All @@ -22,6 +23,14 @@ def maybe_await(obj: callable) -> any:
RESULT_SET_COUNT = 3


class FakeSyncConnection:
def _invalidate_session(): ...


class FakeAsyncConnection:
async def _invalidate_session(): ...


class BaseCursorTestSuit:
def _test_cursor_fetch_one(self, cursor: Cursor | AsyncCursor) -> None:
yql_text = """
Expand Down Expand Up @@ -136,13 +145,24 @@ def _test_cursor_fetch_all_multiple_result_sets(
assert maybe_await(cursor.fetchall()) == []
assert not maybe_await(cursor.nextset())

def _test_cursor_state_after_error(
self, cursor: Cursor | AsyncCursor
) -> None:
query = "INSERT INTO table (id, val) VALUES (0,0)"
with pytest.raises(ydb.Error):
maybe_await(cursor.execute(query=query))

assert cursor._state == CursorStatus.finished


class TestCursor(BaseCursorTestSuit):
@pytest.fixture
def sync_cursor(
self, session_pool_sync: ydb.QuerySessionPool
) -> Generator[Cursor]:

cursor = Cursor(
FakeSyncConnection(),
session_pool_sync,
ydb.QuerySerializableReadWrite(),
request_settings=ydb.BaseRequestSettings(),
Expand Down Expand Up @@ -174,6 +194,10 @@ def test_cursor_fetch_all_multiple_result_sets(
) -> None:
self._test_cursor_fetch_all_multiple_result_sets(sync_cursor)

def test_cursor_state_after_error(
self, sync_cursor: Cursor
) -> None:
self._test_cursor_state_after_error(sync_cursor)


class TestAsyncCursor(BaseCursorTestSuit):
Expand All @@ -182,6 +206,7 @@ async def async_cursor(
self, session_pool: ydb.aio.QuerySessionPool
) -> AsyncGenerator[Cursor]:
cursor = AsyncCursor(
FakeAsyncConnection(),
session_pool,
ydb.QuerySerializableReadWrite(),
request_settings=ydb.BaseRequestSettings(),
Expand Down Expand Up @@ -224,3 +249,11 @@ async def test_cursor_fetch_all_multiple_result_sets(
await greenlet_spawn(
self._test_cursor_fetch_all_multiple_result_sets, async_cursor
)

@pytest.mark.asyncio
async def test_cursor_state_after_error(
self, async_cursor: AsyncCursor
) -> None:
await greenlet_spawn(
self._test_cursor_state_after_error, async_cursor
)
16 changes: 16 additions & 0 deletions ydb_dbapi/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,7 @@ def __init__(

def cursor(self) -> Cursor:
return self._cursor_cls(
connection=self,
session_pool=self._session_pool,
tx_mode=self._tx_mode,
tx_context=self._tx_context,
Expand Down Expand Up @@ -326,6 +327,13 @@ def bulk_upsert(
settings=settings,
)

def _invalidate_session(self) -> None:
if self._tx_context:
self._tx_context = None
if self._session:
self._session_pool.release(self._session)
self._session = None


class AsyncConnection(BaseConnection):
_driver_cls = ydb.aio.Driver
Expand Down Expand Up @@ -357,6 +365,7 @@ def __init__(

def cursor(self) -> AsyncCursor:
return self._cursor_cls(
connection=self,
session_pool=self._session_pool,
tx_mode=self._tx_mode,
tx_context=self._tx_context,
Expand Down Expand Up @@ -492,6 +501,13 @@ async def bulk_upsert(
settings=settings,
)

async def _invalidate_session(self) -> None:
if self._tx_context:
self._tx_context = None
if self._session:
await self._session_pool.release(self._session)
self._session = None


def connect(*args: tuple, **kwargs: dict) -> Connection:
conn = Connection(*args, **kwargs) # type: ignore
Expand Down
46 changes: 46 additions & 0 deletions ydb_dbapi/cursors.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
from __future__ import annotations

import functools
import itertools
from collections.abc import AsyncIterator
from collections.abc import Generator
from collections.abc import Iterator
from collections.abc import Sequence
from inspect import iscoroutinefunction
from typing import TYPE_CHECKING
from typing import Any
from typing import Callable
from typing import Union

import ydb
Expand All @@ -20,6 +23,9 @@
from .utils import maybe_get_current_trace_id

if TYPE_CHECKING:
from .connections import AsyncConnection
from .connections import Connection

ParametersType = dict[
str,
Union[
Expand All @@ -34,6 +40,34 @@ def _get_column_type(type_obj: Any) -> str:
return str(ydb.convert.type_to_native(type_obj))


def invalidate_cursor_on_ydb_error(func: Callable) -> Callable:
if iscoroutinefunction(func):

@functools.wraps(func)
async def awrapper(
self: AsyncCursor, *args: tuple, **kwargs: dict
) -> Any:
try:
return await func(self, *args, **kwargs)
except ydb.Error:
self._state = CursorStatus.finished
await self._connection._invalidate_session()
raise

return awrapper

@functools.wraps(func)
def wrapper(self: Cursor, *args: tuple, **kwargs: dict) -> Any:
try:
return func(self, *args, **kwargs)
except ydb.Error:
self._state = CursorStatus.closed
self._connection._invalidate_session()
raise

return wrapper


class BufferedCursor:
def __init__(self) -> None:
self.arraysize: int = 1
Expand Down Expand Up @@ -154,13 +188,15 @@ def _append_table_path_prefix(self, query: str) -> str:
class Cursor(BufferedCursor):
def __init__(
self,
connection: Connection,
session_pool: ydb.QuerySessionPool,
tx_mode: ydb.BaseQueryTxMode,
request_settings: ydb.BaseRequestSettings,
tx_context: ydb.QueryTxContext | None = None,
table_path_prefix: str = "",
) -> None:
super().__init__()
self._connection = connection
self._session_pool = session_pool
self._tx_mode = tx_mode
self._request_settings = request_settings
Expand Down Expand Up @@ -188,6 +224,7 @@ def _get_request_settings(self) -> ydb.BaseRequestSettings:
return settings

@handle_ydb_errors
@invalidate_cursor_on_ydb_error
def _execute_generic_query(
self, query: str, parameters: ParametersType | None = None
) -> Iterator[ydb.convert.ResultSet]:
Expand All @@ -205,6 +242,7 @@ def callee(
return self._session_pool.retry_operation_sync(callee)

@handle_ydb_errors
@invalidate_cursor_on_ydb_error
def _execute_session_query(
self,
query: str,
Expand All @@ -225,6 +263,7 @@ def callee(
return self._session_pool.retry_operation_sync(callee)

@handle_ydb_errors
@invalidate_cursor_on_ydb_error
def _execute_transactional_query(
self,
tx_context: ydb.QueryTxContext,
Expand Down Expand Up @@ -283,6 +322,7 @@ def executemany(
self.execute(query, parameters)

@handle_ydb_errors
@invalidate_cursor_on_ydb_error
def nextset(self, replace_current: bool = True) -> bool:
if self._stream is None:
return False
Expand Down Expand Up @@ -328,13 +368,15 @@ def __exit__(
class AsyncCursor(BufferedCursor):
def __init__(
self,
connection: AsyncConnection,
session_pool: ydb.aio.QuerySessionPool,
tx_mode: ydb.BaseQueryTxMode,
request_settings: ydb.BaseRequestSettings,
tx_context: ydb.aio.QueryTxContext | None = None,
table_path_prefix: str = "",
) -> None:
super().__init__()
self._connection = connection
self._session_pool = session_pool
self._tx_mode = tx_mode
self._request_settings = request_settings
Expand Down Expand Up @@ -362,6 +404,7 @@ def _get_request_settings(self) -> ydb.BaseRequestSettings:
return settings

@handle_ydb_errors
@invalidate_cursor_on_ydb_error
async def _execute_generic_query(
self, query: str, parameters: ParametersType | None = None
) -> AsyncIterator[ydb.convert.ResultSet]:
Expand All @@ -379,6 +422,7 @@ async def callee(
return await self._session_pool.retry_operation_async(callee)

@handle_ydb_errors
@invalidate_cursor_on_ydb_error
async def _execute_session_query(
self,
query: str,
Expand All @@ -399,6 +443,7 @@ async def callee(
return await self._session_pool.retry_operation_async(callee)

@handle_ydb_errors
@invalidate_cursor_on_ydb_error
async def _execute_transactional_query(
self,
tx_context: ydb.aio.QueryTxContext,
Expand Down Expand Up @@ -457,6 +502,7 @@ async def executemany(
await self.execute(query, parameters)

@handle_ydb_errors
@invalidate_cursor_on_ydb_error
async def nextset(self, replace_current: bool = True) -> bool:
if self._stream is None:
return False
Expand Down

0 comments on commit 3c83113

Please sign in to comment.