Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PECO-1263] Implement a .returned_as_direct_result property for AsyncExecution status #325

Open
wants to merge 6 commits into
base: peco-1263-staging
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 15 additions & 1 deletion src/databricks/sql/ae.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from enum import Enum
from typing import Optional, Union, TYPE_CHECKING
from databricks.sql.exc import RequestError
from databricks.sql.results import ResultSet
from databricks.sql.results import ResultSet, execute_response_contains_direct_results

from datetime import datetime

Expand Down Expand Up @@ -81,6 +81,7 @@ class AsyncExecution:
]
_last_sync_timestamp: Optional[datetime] = None
_result_set: Optional["ResultSet"] = None
_returned_as_direct_result: bool = False

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do we need to change the get_result method to also check this flag?


def __init__(
self,
Expand All @@ -101,6 +102,8 @@ def __init__(

if execute_statement_response:
self._execute_statement_response = execute_statement_response
if execute_response_contains_direct_results(execute_statement_response):
self._returned_as_direct_result = True
else:
self._execute_statement_response = FakeExecuteStatementResponse(
directResults=False, operationHandle=self.t_operation_handle
Expand Down Expand Up @@ -225,6 +228,17 @@ def last_sync_timestamp(self) -> Optional[datetime]:
"""The timestamp of the last time self.status was synced with the server"""
return self._last_sync_timestamp

@property
def returned_as_direct_result(self) -> bool:
    """Whether this query delivered its results inline with `execute_async`.

    True only when the server sent the complete result back directly in the
    ExecuteStatement round-trip. A query_id that returned a direct result
    cannot later be picked up via `Connection.get_async_execution()`.
    """
    return self._returned_as_direct_result

@classmethod
def from_thrift_response(
cls,
Expand Down
30 changes: 30 additions & 0 deletions src/databricks/sql/results.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from databricks.sql.exc import (
CursorAlreadyClosedError,
)
from databricks.sql.thrift_api.TCLIService import ttypes
from databricks.sql.types import Row
from databricks.sql.utils import ExecuteResponse

Expand All @@ -17,6 +18,10 @@
from databricks.sql.client import Connection
from databricks.sql.thrift_backend import ThriftBackend

import logging

logger = logging.getLogger(__name__)

# TODO: this is duplicated from client.py to avoid ImportError. Fix this.
DEFAULT_RESULT_BUFFER_SIZE_BYTES = 104857600

Expand Down Expand Up @@ -223,3 +228,28 @@ def map_col_type(type_):
(column.name, map_col_type(column.datatype), None, None, None, None, None)
for column in table_schema_message.columns
]


def execute_response_contains_direct_results(
    execute_response: "ttypes.TExecuteStatementResp",
) -> bool:
    """Return True iff the TExecuteStatementResp carried a complete direct result.

    When directResults is requested, the server batches the follow-up RPCs
    into the ExecuteStatement response if the entire result can be returned
    in a single round-trip:

        struct TSparkDirectResults {
          1: optional TGetOperationStatusResp operationStatus
          2: optional TGetResultSetMetadataResp resultSetMetadata
          3: optional TFetchResultsResp resultSet
          4: optional TCloseOperationResp closeOperation
        }

    :param execute_response: the thrift response from an ExecuteStatement call.
    :returns: True only when directResults is present and all four of its
        optional sub-responses were populated.
    """

    direct_results = execute_response.directResults

    # directResults is optional on the wire (and FakeExecuteStatementResponse
    # sets it to False), so guard before touching its attributes to avoid an
    # AttributeError on responses with no direct result.
    if not direct_results:
        return False

    # all() collapses the four optional sub-responses into a real bool, which
    # keeps the declared `-> bool` contract — a bare `and` chain would leak
    # the last thrift sub-object (or None) to the caller.
    return all(
        getattr(direct_results, field)
        for field in (
            "operationStatus",
            "resultSetMetadata",
            "resultSet",
            "closeOperation",
        )
    )
25 changes: 12 additions & 13 deletions tests/e2e/test_execute_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def test_direct_results_query_canary(self):

with self.connection() as conn:
ae = conn.execute_async(DIRECT_RESULTS_QUERY, {"param": 1})
assert not ae.is_running
assert ae.returned_as_direct_result

def test_cancel_running_query(self, long_running_ae: AsyncExecution):
long_running_ae.cancel()
Expand Down Expand Up @@ -112,21 +112,14 @@ def cancel_query_in_separate_thread(query_id, query_secret):
assert long_running_ae.status == AsyncExecutionStatus.CANCELED

def test_long_ish_query_canary(self, long_ish_ae: AsyncExecution):
"""This test verifies that on the current endpoint, the LONG_ISH_QUERY requires
at least one sync_status call before it is finished. If this test fails, it means
the SQL warehouse got faster at executing this query and we should increment the value
of GT_FIVE_SECONDS_VALUE
"""This test verifies that on the current endpoint, the LONG_ISH_QUERY does not return direct results.

It would be easier to do this if Databricks SQL had a SLEEP() function :/
"""

poll_count = 0
while long_ish_ae.is_running:
time.sleep(1)
long_ish_ae.sync_status()
poll_count += 1
We could achieve something similar by overriding the directResults setting in our ExecuteStatementReq
"""

assert poll_count > 0
assert not long_ish_ae.returned_as_direct_result

def test_get_async_execution_and_get_results_without_direct_results(
self, long_ish_ae: AsyncExecution
Expand Down Expand Up @@ -162,10 +155,13 @@ def test_serialize(self, long_running_ae: AsyncExecution):
assert ae.is_running

def test_get_async_execution_no_results_when_direct_results_were_sent(self):
"""It remains to be seen whether results can be fetched repeatedly from a "picked up" execution."""
"""Queries that return direct results cannot be picked up with `get_async_execution()`."""

with self.connection() as conn:
ae = conn.execute_async(DIRECT_RESULTS_QUERY, {"param": 1})
assert (
ae.returned_as_direct_result
), "Queries that return direct results should not be available"
query_id, query_secret = ae.serialize().split(":")
ae.get_result()

Expand Down Expand Up @@ -193,9 +189,12 @@ def test_get_async_execution_twice(self):
"""
with self.connection() as conn_1, self.connection() as conn_2:
ae_1 = conn_1.execute_async(LONG_ISH_QUERY)
assert not ae_1.returned_as_direct_result


query_id, query_secret = ae_1.serialize().split(":")
ae_2 = conn_2.get_async_execution(query_id, query_secret)
assert not ae_2.returned_as_direct_result

while ae_1.is_running:
time.sleep(1)
Expand Down
Loading