Skip to content

Commit

Permalink
Merge branch 'main' into fix_tx_modes
Browse files Browse the repository at this point in the history
  • Loading branch information
vgvoleg committed Oct 31, 2024
2 parents 5870e24 + 5f50b2c commit 98dc11c
Show file tree
Hide file tree
Showing 7 changed files with 96 additions and 43 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
## 0.0.1b1 ##
* YDB DBAPI based on QueryService
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "ydb-dbapi"
version = "0.0.0" # AUTOVERSION
version = "0.0.1b1" # AUTOVERSION
description = "YDB Python DBAPI which complies with PEP 249"
authors = ["Yandex LLC <[email protected]>"]
readme = "README.md"
Expand Down
28 changes: 12 additions & 16 deletions tests/test_connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,10 @@ def _test_isolation_level_read_only(
isolation_level: str,
read_only: bool,
) -> None:
connection.set_isolation_level("AUTOCOMMIT")
cursor = connection.cursor()
with suppress(dbapi.DatabaseError):
maybe_await(cursor.execute("DROP TABLE foo"))
cursor = connection.cursor()
maybe_await(cursor.execute(
maybe_await(cursor.execute_scheme("DROP TABLE foo"))
maybe_await(cursor.execute_scheme(
"CREATE TABLE foo(id Int64 NOT NULL, PRIMARY KEY (id))"
))

Expand All @@ -46,23 +44,21 @@ def _test_isolation_level_read_only(

maybe_await(connection.rollback())

connection.set_isolation_level("AUTOCOMMIT")
cursor = connection.cursor()
maybe_await(cursor.execute("DROP TABLE foo"))
maybe_await(cursor.execute_scheme("DROP TABLE foo"))

def _test_connection(self, connection: dbapi.Connection) -> None:
maybe_await(connection.commit())
maybe_await(connection.rollback())

cur = connection.cursor()
with suppress(dbapi.DatabaseError):
maybe_await(cur.execute("DROP TABLE foo"))
maybe_await(cur.execute_scheme("DROP TABLE foo"))

assert not maybe_await(connection.check_exists("/local/foo"))
with pytest.raises(dbapi.ProgrammingError):
maybe_await(connection.describe("/local/foo"))

maybe_await(cur.execute(
maybe_await(cur.execute_scheme(
"CREATE TABLE foo(id Int64 NOT NULL, PRIMARY KEY (id))"
))

Expand All @@ -72,17 +68,17 @@ def _test_connection(self, connection: dbapi.Connection) -> None:
assert col.name == "id"
assert col.type == ydb.PrimitiveType.Int64

maybe_await(cur.execute("DROP TABLE foo"))
maybe_await(cur.execute_scheme("DROP TABLE foo"))
maybe_await(cur.close())

def _test_cursor_raw_query(self, connection: dbapi.Connection) -> None:
cur = connection.cursor()
assert cur

with suppress(dbapi.DatabaseError):
maybe_await(cur.execute("DROP TABLE test"))
maybe_await(cur.execute_scheme("DROP TABLE test"))

maybe_await(cur.execute(
maybe_await(cur.execute_scheme(
"CREATE TABLE test(id Int64 NOT NULL, text Utf8, PRIMARY KEY (id))"
))

Expand All @@ -107,7 +103,7 @@ def _test_cursor_raw_query(self, connection: dbapi.Connection) -> None:
},
))

maybe_await(cur.execute("DROP TABLE test"))
maybe_await(cur.execute_scheme("DROP TABLE test"))

maybe_await(cur.close())

Expand All @@ -125,7 +121,7 @@ def _test_errors(
cur = connection.cursor()

with suppress(dbapi.DatabaseError):
maybe_await(cur.execute("DROP TABLE test"))
maybe_await(cur.execute_scheme("DROP TABLE test"))

with pytest.raises(dbapi.DataError):
maybe_await(cur.execute("SELECT 18446744073709551616"))
Expand All @@ -139,7 +135,7 @@ def _test_errors(
with pytest.raises(dbapi.ProgrammingError):
maybe_await(cur.execute("SELECT * FROM test"))

maybe_await(cur.execute(
maybe_await(cur.execute_scheme(
"CREATE TABLE test(id Int64, PRIMARY KEY (id))"
))

Expand All @@ -148,7 +144,7 @@ def _test_errors(
with pytest.raises(dbapi.IntegrityError):
maybe_await(cur.execute("INSERT INTO test(id) VALUES(1)"))

maybe_await(cur.execute("DROP TABLE test"))
maybe_await(cur.execute_scheme("DROP TABLE test"))
maybe_await(cur.close())


Expand Down
4 changes: 2 additions & 2 deletions tests/test_cursors.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ def _test_cursor_fetch_all_multiple_result_sets(
class TestCursor(BaseCursorTestSuit):
@pytest.fixture
def sync_cursor(self, session_sync: ydb.QuerySession) -> Generator[Cursor]:
cursor = Cursor(session_sync)
cursor = Cursor(session_sync, ydb.QuerySerializableReadWrite())
yield cursor
cursor.close()

Expand Down Expand Up @@ -175,7 +175,7 @@ class TestAsyncCursor(BaseCursorTestSuit):
async def async_cursor(
self, session: ydb.aio.QuerySession
) -> AsyncGenerator[Cursor]:
cursor = AsyncCursor(session)
cursor = AsyncCursor(session, ydb.QuerySerializableReadWrite())
yield cursor
await greenlet_spawn(cursor.close)

Expand Down
19 changes: 10 additions & 9 deletions ydb_dbapi/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,9 @@ class _IsolationSettings(NamedTuple):


_ydb_isolation_settings_map = {
IsolationLevel.AUTOCOMMIT: _IsolationSettings(None, interactive=False),
IsolationLevel.AUTOCOMMIT: _IsolationSettings(
ydb.QuerySerializableReadWrite(), interactive=False
),
IsolationLevel.SERIALIZABLE: _IsolationSettings(
ydb.QuerySerializableReadWrite(), interactive=True
),
Expand Down Expand Up @@ -79,7 +81,7 @@ def __init__(
self._shared_session_pool: bool = False

self._tx_context: TxContext | AsyncTxContext | None = None
self._tx_mode: ydb.BaseQueryTxMode | None = None
self._tx_mode: ydb.BaseQueryTxMode = ydb.QuerySerializableReadWrite()
self.interactive_transaction: bool = False

if ydb_session_pool is not None:
Expand All @@ -105,15 +107,14 @@ def set_isolation_level(self, isolation_level: IsolationLevel) -> None:

ydb_isolation_settings = _ydb_isolation_settings_map[isolation_level]

self._tx_context = None
self._tx_mode = ydb_isolation_settings.ydb_mode
self.interactive_transaction = ydb_isolation_settings.interactive

def get_isolation_level(self) -> str:
if self._tx_mode is None:
return IsolationLevel.AUTOCOMMIT
if self._tx_mode.name == ydb.QuerySerializableReadWrite().name:
return IsolationLevel.SERIALIZABLE
if self.interactive_transaction:
return IsolationLevel.SERIALIZABLE
return IsolationLevel.AUTOCOMMIT
if self._tx_mode.name == ydb.QueryOnlineReadOnly().name:
if self._tx_mode.allow_inconsistent_reads:
return IsolationLevel.ONLINE_READONLY_INCONSISTENT
Expand All @@ -128,7 +129,7 @@ def get_isolation_level(self) -> str:
def _maybe_init_tx(
self, session: ydb.QuerySession | ydb.aio.QuerySession
) -> None:
if self._tx_context is None and self._tx_mode is not None:
if self._tx_context is None and self.interactive_transaction:
self._tx_context = session.transaction(self._tx_mode)


Expand Down Expand Up @@ -166,8 +167,8 @@ def cursor(self) -> Cursor:

self._current_cursor = self._cursor_cls(
session=self._session,
tx_mode=self._tx_mode,
tx_context=self._tx_context,
autocommit=(not self.interactive_transaction),
)
return self._current_cursor

Expand Down Expand Up @@ -290,8 +291,8 @@ def cursor(self) -> AsyncCursor:

self._current_cursor = self._cursor_cls(
session=self._session,
tx_mode=self._tx_mode,
tx_context=self._tx_context,
autocommit=(not self.interactive_transaction),
)
return self._current_cursor

Expand Down
83 changes: 69 additions & 14 deletions ydb_dbapi/cursors.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from collections.abc import AsyncIterator
from collections.abc import Generator
from collections.abc import Iterator
from collections.abc import Sequence
from typing import Any
from typing import Union

Expand Down Expand Up @@ -138,15 +139,15 @@ class Cursor(BufferedCursor):
def __init__(
self,
session: ydb.QuerySession,
tx_mode: ydb.BaseQueryTxMode,
tx_context: ydb.QueryTxContext | None = None,
table_path_prefix: str = "",
autocommit: bool = True,
) -> None:
super().__init__()
self._session = session
self._tx_mode = tx_mode
self._tx_context = tx_context
self._table_path_prefix = table_path_prefix
self._autocommit = autocommit

self._stream: Iterator | None = None

Expand All @@ -166,6 +167,18 @@ def _execute_generic_query(
) -> Iterator[ydb.convert.ResultSet]:
return self._session.execute(query=query, parameters=parameters)

@handle_ydb_errors
def _execute_session_query(
self,
query: str,
parameters: ParametersType | None = None,
) -> Iterator[ydb.convert.ResultSet]:
return self._session.transaction(self._tx_mode).execute(
query=query,
parameters=parameters,
commit_tx=True,
)

@handle_ydb_errors
def _execute_transactional_query(
self,
Expand All @@ -176,8 +189,21 @@ def _execute_transactional_query(
return tx_context.execute(
query=query,
parameters=parameters,
commit_tx=self._autocommit,
commit_tx=False,
)

def execute_scheme(
self,
query: str,
parameters: ParametersType | None = None,
) -> None:
self._raise_if_closed()

self._stream = self._execute_generic_query(
query=query, parameters=parameters
)
self._begin_query()
self._scroll_stream(replace_current=False)

def execute(
self,
Expand All @@ -191,16 +217,18 @@ def execute(
tx_context=self._tx_context, query=query, parameters=parameters
)
else:
self._stream = self._execute_generic_query(
self._stream = self._execute_session_query(
query=query, parameters=parameters
)

self._begin_query()

self._scroll_stream(replace_current=False)

async def executemany(self) -> None:
pass
def executemany(
self, query: str, seq_of_parameters: Sequence[ParametersType]
) -> None:
for parameters in seq_of_parameters:
self.execute(query, parameters)

@handle_ydb_errors
def nextset(self, replace_current: bool = True) -> bool:
Expand Down Expand Up @@ -249,15 +277,15 @@ class AsyncCursor(BufferedCursor):
def __init__(
self,
session: ydb.aio.QuerySession,
tx_mode: ydb.BaseQueryTxMode,
tx_context: ydb.aio.QueryTxContext | None = None,
table_path_prefix: str = "",
autocommit: bool = True,
) -> None:
super().__init__()
self._session = session
self._tx_mode = tx_mode
self._tx_context = tx_context
self._table_path_prefix = table_path_prefix
self._autocommit = autocommit

self._stream: AsyncIterator | None = None

Expand All @@ -277,6 +305,18 @@ async def _execute_generic_query(
) -> AsyncIterator[ydb.convert.ResultSet]:
return await self._session.execute(query=query, parameters=parameters)

@handle_ydb_errors
async def _execute_session_query(
self,
query: str,
parameters: ParametersType | None = None,
) -> AsyncIterator[ydb.convert.ResultSet]:
return await self._session.transaction(self._tx_mode).execute(
query=query,
parameters=parameters,
commit_tx=True,
)

@handle_ydb_errors
async def _execute_transactional_query(
self,
Expand All @@ -287,8 +327,21 @@ async def _execute_transactional_query(
return await tx_context.execute(
query=query,
parameters=parameters,
commit_tx=self._autocommit,
commit_tx=False,
)

async def execute_scheme(
self,
query: str,
parameters: ParametersType | None = None,
) -> None:
self._raise_if_closed()

self._stream = await self._execute_generic_query(
query=query, parameters=parameters
)
self._begin_query()
await self._scroll_stream(replace_current=False)

async def execute(
self,
Expand All @@ -302,16 +355,18 @@ async def execute(
tx_context=self._tx_context, query=query, parameters=parameters
)
else:
self._stream = await self._execute_generic_query(
self._stream = await self._execute_session_query(
query=query, parameters=parameters
)

self._begin_query()

await self._scroll_stream(replace_current=False)

async def executemany(self) -> None:
pass
async def executemany(
self, query: str, seq_of_parameters: Sequence[ParametersType]
) -> None:
for parameters in seq_of_parameters:
await self.execute(query, parameters)

@handle_ydb_errors
async def nextset(self, replace_current: bool = True) -> bool:
Expand Down
2 changes: 1 addition & 1 deletion ydb_dbapi/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
VERSION = "0.0.1"
VERSION = "0.0.1b1"

0 comments on commit 98dc11c

Please sign in to comment.