Skip to content

Commit

Permalink
Merge pull request #16 from ydb-platform/invalidate_session_and_tx
Browse files Browse the repository at this point in the history
Invalidate session&tx on YDB errors
  • Loading branch information
vgvoleg authored Dec 19, 2024
2 parents eff6be5 + 47fe1f4 commit 6c4908d
Show file tree
Hide file tree
Showing 5 changed files with 136 additions and 3 deletions.
4 changes: 1 addition & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,7 @@ version = "0.1.5" # AUTOVERSION
description = "YDB Python DBAPI which complies with PEP 249"
authors = ["Yandex LLC <[email protected]>"]
readme = "README.md"

[project.urls]
Homepage = "https://github.com/ydb-platform/ydb-python-dbapi/"
repository = "https://github.com/ydb-platform/ydb-python-dbapi/"

[tool.poetry.dependencies]
python = "^3.8"
Expand Down
39 changes: 39 additions & 0 deletions tests/test_connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,34 @@ def _test_bulk_upsert(self, connection: dbapi.Connection) -> None:

maybe_await(cursor.execute_scheme("DROP TABLE pet"))

def _test_error_with_interactive_tx(
self,
connection: dbapi.Connection,
) -> None:

cur = connection.cursor()
maybe_await(cur.execute_scheme(
"""
DROP TABLE IF EXISTS test;
CREATE TABLE test (
id Int64 NOT NULL,
val Int64,
PRIMARY KEY(id)
)
"""
))

connection.set_isolation_level(dbapi.IsolationLevel.SERIALIZABLE)
maybe_await(connection.begin())

cur = connection.cursor()
maybe_await(cur.execute("INSERT INTO test(id, val) VALUES (1,1)"))
with pytest.raises(dbapi.Error):
maybe_await(cur.execute("INSERT INTO test(id, val) VALUES (1,1)"))

maybe_await(cur.close())
maybe_await(connection.rollback())


class TestConnection(BaseDBApiTestSuit):
@pytest.fixture
Expand Down Expand Up @@ -245,6 +273,11 @@ def test_errors(self, connection: dbapi.Connection) -> None:
def test_bulk_upsert(self, connection: dbapi.Connection) -> None:
self._test_bulk_upsert(connection)

def test_errors_with_interactive_tx(
self, connection: dbapi.Connection
) -> None:
self._test_error_with_interactive_tx(connection)


class TestAsyncConnection(BaseDBApiTestSuit):
@pytest_asyncio.fixture
Expand Down Expand Up @@ -304,3 +337,9 @@ async def test_bulk_upsert(
self, connection: dbapi.AsyncConnection
) -> None:
await greenlet_spawn(self._test_bulk_upsert, connection)

@pytest.mark.asyncio
async def test_errors_with_interactive_tx(
self, connection: dbapi.AsyncConnection
) -> None:
await greenlet_spawn(self._test_error_with_interactive_tx, connection)
34 changes: 34 additions & 0 deletions tests/test_cursors.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,12 @@

import pytest
import ydb
import ydb_dbapi
from sqlalchemy.util import await_only
from sqlalchemy.util import greenlet_spawn
from ydb_dbapi import AsyncCursor
from ydb_dbapi import Cursor
from ydb_dbapi.utils import CursorStatus


def maybe_await(obj: callable) -> any:
Expand All @@ -22,6 +24,14 @@ def maybe_await(obj: callable) -> any:
RESULT_SET_COUNT = 3


class FakeSyncConnection:
def _invalidate_session(self) -> None: ...


class FakeAsyncConnection:
async def _invalidate_session(self) -> None: ...


class BaseCursorTestSuit:
def _test_cursor_fetch_one(self, cursor: Cursor | AsyncCursor) -> None:
yql_text = """
Expand Down Expand Up @@ -136,13 +146,24 @@ def _test_cursor_fetch_all_multiple_result_sets(
assert maybe_await(cursor.fetchall()) == []
assert not maybe_await(cursor.nextset())

def _test_cursor_state_after_error(
self, cursor: Cursor | AsyncCursor
) -> None:
query = "INSERT INTO table (id, val) VALUES (0,0)"
with pytest.raises(ydb_dbapi.Error):
maybe_await(cursor.execute(query=query))

assert cursor._state == CursorStatus.finished


class TestCursor(BaseCursorTestSuit):
@pytest.fixture
def sync_cursor(
self, session_pool_sync: ydb.QuerySessionPool
) -> Generator[Cursor]:

cursor = Cursor(
FakeSyncConnection(),
session_pool_sync,
ydb.QuerySerializableReadWrite(),
request_settings=ydb.BaseRequestSettings(),
Expand Down Expand Up @@ -174,6 +195,10 @@ def test_cursor_fetch_all_multiple_result_sets(
) -> None:
self._test_cursor_fetch_all_multiple_result_sets(sync_cursor)

def test_cursor_state_after_error(
self, sync_cursor: Cursor
) -> None:
self._test_cursor_state_after_error(sync_cursor)


class TestAsyncCursor(BaseCursorTestSuit):
Expand All @@ -182,6 +207,7 @@ async def async_cursor(
self, session_pool: ydb.aio.QuerySessionPool
) -> AsyncGenerator[Cursor]:
cursor = AsyncCursor(
FakeAsyncConnection(),
session_pool,
ydb.QuerySerializableReadWrite(),
request_settings=ydb.BaseRequestSettings(),
Expand Down Expand Up @@ -224,3 +250,11 @@ async def test_cursor_fetch_all_multiple_result_sets(
await greenlet_spawn(
self._test_cursor_fetch_all_multiple_result_sets, async_cursor
)

@pytest.mark.asyncio
async def test_cursor_state_after_error(
self, async_cursor: AsyncCursor
) -> None:
await greenlet_spawn(
self._test_cursor_state_after_error, async_cursor
)
16 changes: 16 additions & 0 deletions ydb_dbapi/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,7 @@ def __init__(

def cursor(self) -> Cursor:
return self._cursor_cls(
connection=self,
session_pool=self._session_pool,
tx_mode=self._tx_mode,
tx_context=self._tx_context,
Expand Down Expand Up @@ -326,6 +327,13 @@ def bulk_upsert(
settings=settings,
)

def _invalidate_session(self) -> None:
if self._tx_context:
self._tx_context = None
if self._session:
self._session_pool.release(self._session)
self._session = None


class AsyncConnection(BaseConnection):
_driver_cls = ydb.aio.Driver
Expand Down Expand Up @@ -357,6 +365,7 @@ def __init__(

def cursor(self) -> AsyncCursor:
return self._cursor_cls(
connection=self,
session_pool=self._session_pool,
tx_mode=self._tx_mode,
tx_context=self._tx_context,
Expand Down Expand Up @@ -492,6 +501,13 @@ async def bulk_upsert(
settings=settings,
)

async def _invalidate_session(self) -> None:
if self._tx_context:
self._tx_context = None
if self._session:
await self._session_pool.release(self._session)
self._session = None


def connect(*args: tuple, **kwargs: dict) -> Connection:
conn = Connection(*args, **kwargs) # type: ignore
Expand Down
46 changes: 46 additions & 0 deletions ydb_dbapi/cursors.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
from __future__ import annotations

import functools
import itertools
from collections.abc import AsyncIterator
from collections.abc import Generator
from collections.abc import Iterator
from collections.abc import Sequence
from inspect import iscoroutinefunction
from typing import TYPE_CHECKING
from typing import Any
from typing import Callable
from typing import Union

import ydb
Expand All @@ -20,6 +23,9 @@
from .utils import maybe_get_current_trace_id

if TYPE_CHECKING:
from .connections import AsyncConnection
from .connections import Connection

ParametersType = dict[
str,
Union[
Expand All @@ -34,6 +40,34 @@ def _get_column_type(type_obj: Any) -> str:
return str(ydb.convert.type_to_native(type_obj))


def invalidate_cursor_on_ydb_error(func: Callable) -> Callable:
if iscoroutinefunction(func):

@functools.wraps(func)
async def awrapper(
self: AsyncCursor, *args: tuple, **kwargs: dict
) -> Any:
try:
return await func(self, *args, **kwargs)
except ydb.Error:
self._state = CursorStatus.finished
await self._connection._invalidate_session()
raise

return awrapper

@functools.wraps(func)
def wrapper(self: Cursor, *args: tuple, **kwargs: dict) -> Any:
try:
return func(self, *args, **kwargs)
except ydb.Error:
self._state = CursorStatus.finished
self._connection._invalidate_session()
raise

return wrapper


class BufferedCursor:
def __init__(self) -> None:
self.arraysize: int = 1
Expand Down Expand Up @@ -154,13 +188,15 @@ def _append_table_path_prefix(self, query: str) -> str:
class Cursor(BufferedCursor):
def __init__(
self,
connection: Connection,
session_pool: ydb.QuerySessionPool,
tx_mode: ydb.BaseQueryTxMode,
request_settings: ydb.BaseRequestSettings,
tx_context: ydb.QueryTxContext | None = None,
table_path_prefix: str = "",
) -> None:
super().__init__()
self._connection = connection
self._session_pool = session_pool
self._tx_mode = tx_mode
self._request_settings = request_settings
Expand Down Expand Up @@ -188,6 +224,7 @@ def _get_request_settings(self) -> ydb.BaseRequestSettings:
return settings

@handle_ydb_errors
@invalidate_cursor_on_ydb_error
def _execute_generic_query(
self, query: str, parameters: ParametersType | None = None
) -> Iterator[ydb.convert.ResultSet]:
Expand All @@ -205,6 +242,7 @@ def callee(
return self._session_pool.retry_operation_sync(callee)

@handle_ydb_errors
@invalidate_cursor_on_ydb_error
def _execute_session_query(
self,
query: str,
Expand All @@ -225,6 +263,7 @@ def callee(
return self._session_pool.retry_operation_sync(callee)

@handle_ydb_errors
@invalidate_cursor_on_ydb_error
def _execute_transactional_query(
self,
tx_context: ydb.QueryTxContext,
Expand Down Expand Up @@ -283,6 +322,7 @@ def executemany(
self.execute(query, parameters)

@handle_ydb_errors
@invalidate_cursor_on_ydb_error
def nextset(self, replace_current: bool = True) -> bool:
if self._stream is None:
return False
Expand Down Expand Up @@ -328,13 +368,15 @@ def __exit__(
class AsyncCursor(BufferedCursor):
def __init__(
self,
connection: AsyncConnection,
session_pool: ydb.aio.QuerySessionPool,
tx_mode: ydb.BaseQueryTxMode,
request_settings: ydb.BaseRequestSettings,
tx_context: ydb.aio.QueryTxContext | None = None,
table_path_prefix: str = "",
) -> None:
super().__init__()
self._connection = connection
self._session_pool = session_pool
self._tx_mode = tx_mode
self._request_settings = request_settings
Expand Down Expand Up @@ -362,6 +404,7 @@ def _get_request_settings(self) -> ydb.BaseRequestSettings:
return settings

@handle_ydb_errors
@invalidate_cursor_on_ydb_error
async def _execute_generic_query(
self, query: str, parameters: ParametersType | None = None
) -> AsyncIterator[ydb.convert.ResultSet]:
Expand All @@ -379,6 +422,7 @@ async def callee(
return await self._session_pool.retry_operation_async(callee)

@handle_ydb_errors
@invalidate_cursor_on_ydb_error
async def _execute_session_query(
self,
query: str,
Expand All @@ -399,6 +443,7 @@ async def callee(
return await self._session_pool.retry_operation_async(callee)

@handle_ydb_errors
@invalidate_cursor_on_ydb_error
async def _execute_transactional_query(
self,
tx_context: ydb.aio.QueryTxContext,
Expand Down Expand Up @@ -457,6 +502,7 @@ async def executemany(
await self.execute(query, parameters)

@handle_ydb_errors
@invalidate_cursor_on_ydb_error
async def nextset(self, replace_current: bool = True) -> bool:
if self._stream is None:
return False
Expand Down

0 comments on commit 6c4908d

Please sign in to comment.