Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rework record/replay to record at the database connection level. #244

Merged
merged 8 commits into from
Jul 16, 2024
241 changes: 193 additions & 48 deletions dbt/adapters/record.py
Original file line number Diff line number Diff line change
@@ -1,67 +1,212 @@
import dataclasses
from io import StringIO
import json
import re
from typing import Any, Optional, Mapping
import datetime
from typing import Any, Dict, Optional, Mapping, List, Union, Iterable

from agate import Table
from dbt.adapters.contracts.connection import Connection

from dbt_common.events.contextvars import get_node_info
from dbt_common.record import Record, Recorder
from dbt_common.record import Record, Recorder, record_function

from dbt.adapters.contracts.connection import AdapterResponse

class RecordReplayHandle:
def __init__(self, native_handle: Any, connection: Connection) -> None:
self.native_handle = native_handle
self.connection = connection

def cursor(self):
# The native handle could be None if we are in replay mode, because no
# actual database access should be performed in that mode.
cursor = None if self.native_handle is None else self.native_handle.cursor()
return RecordReplayCursor(cursor, self.connection)


@dataclasses.dataclass
class CursorExecuteParams:
colin-rogers-dbt marked this conversation as resolved.
Show resolved Hide resolved
connection_name: str
operation: str
parameters: Union[Iterable[Any], Mapping[str, Any]]


class CursorExecuteRecord(Record):
colin-rogers-dbt marked this conversation as resolved.
Show resolved Hide resolved
params_cls = CursorExecuteParams
result_cls = None


Recorder.register_record_type(CursorExecuteRecord)


@dataclasses.dataclass
class CursorFetchOneParams:
connection_name: str


@dataclasses.dataclass
class CursorFetchOneResult:
result: Any


class CursorFetchOneRecord(Record):
params_cls = CursorFetchOneParams
result_cls = CursorFetchOneResult


Recorder.register_record_type(CursorFetchOneRecord)


@dataclasses.dataclass
class CursorFetchManyParams:
connection_name: str


@dataclasses.dataclass
class QueryRecordParams:
sql: str
auto_begin: bool = False
fetch: bool = False
limit: Optional[int] = None
node_unique_id: Optional[str] = None

def __post_init__(self) -> None:
if self.node_unique_id is None:
node_info = get_node_info()
self.node_unique_id = node_info["unique_id"] if node_info else ""

@staticmethod
def _clean_up_sql(sql: str) -> str:
sql = re.sub(r"--.*?\n", "", sql) # Remove single-line comments (--)
sql = re.sub(r"/\*.*?\*/", "", sql, flags=re.DOTALL) # Remove multi-line comments (/* */)
return sql.replace(" ", "").replace("\n", "")

def _matches(self, other: "QueryRecordParams") -> bool:
return self.node_unique_id == other.node_unique_id and self._clean_up_sql(
self.sql
) == self._clean_up_sql(other.sql)
class CursorFetchManyResult:
results: List[Any]


class CursorFetchManyRecord(Record):
params_cls = CursorFetchManyParams
result_cls = CursorFetchManyResult


Recorder.register_record_type(CursorFetchManyRecord)


@dataclasses.dataclass
class QueryRecordResult:
adapter_response: Optional["AdapterResponse"]
table: Optional[Table]
class CursorFetchAllParams:
connection_name: str


@dataclasses.dataclass
class CursorFetchAllResult:
results: List[Any]

def _to_dict(self) -> Dict[str, Any]:
processed_results = []
for result in self.results:
result = tuple(map(self._process_value, result))
processed_results.append(result)

return {"results": processed_results}

@classmethod
def _from_dict(cls, dct: Mapping) -> "CursorFetchAllResult":
unprocessed_results = []
for result in dct["results"]:
result = tuple(map(cls._unprocess_value, result))
unprocessed_results.append(result)

return CursorFetchAllResult(unprocessed_results)

@classmethod
def _process_value(self, value: Any) -> Any:
if type(value) is datetime.date:
return {"type": "date", "value": value.isoformat()}
elif type(value) is datetime.datetime:
return {"type": "datetime", "value": value.isoformat()}
else:
return value

@classmethod
def _unprocess_value(self, value: Any) -> Any:
if type(value) is dict:
value_type = value.get("type")
if value_type == "date":
return datetime.date.fromisoformat(value.get("value"))
elif value_type == "datetime":
return datetime.datetime.fromisoformat(value.get("value"))
return value
else:
return value


class CursorFetchAllRecord(Record):
params_cls = CursorFetchAllParams
result_cls = CursorFetchAllResult


Recorder.register_record_type(CursorFetchAllRecord)


@dataclasses.dataclass
class CursorGetRowCountParams:
connection_name: str


@dataclasses.dataclass
class CursorGetRowCountResult:
rowcount: Optional[int]


class CursorGetRowCountRecord(Record):
params_cls = CursorGetRowCountParams
result_cls = CursorGetRowCountResult


Recorder.register_record_type(CursorGetRowCountRecord)


@dataclasses.dataclass
class CursorGetDescriptionParams:
connection_name: str


@dataclasses.dataclass
class CursorGetDescriptionResult:
columns: Iterable[Any]

def _to_dict(self) -> Any:
buf = StringIO()
self.table.to_json(buf) # type: ignore
column_dicts = []
for c in self.columns:
# This captures the mandatory column information, but we might need
# more for some adapters.
# See https://peps.python.org/pep-0249/#description
column_dicts.append((c[0], c[1]))

return {
"adapter_response": self.adapter_response.to_dict(), # type: ignore
"table": buf.getvalue(),
}
return {"columns": column_dicts}

@classmethod
def _from_dict(cls, dct: Mapping) -> "QueryRecordResult":
return QueryRecordResult(
adapter_response=AdapterResponse.from_dict(dct["adapter_response"]),
table=Table.from_object(json.loads(dct["table"])),
)
def _from_dict(cls, dct: Mapping) -> "CursorGetDescriptionResult":
return CursorGetDescriptionResult(columns=dct["columns"])


class CursorGetDescriptionRecord(Record):
params_cls = CursorGetDescriptionParams
result_cls = CursorGetDescriptionResult


Recorder.register_record_type(CursorGetDescriptionRecord)


class RecordReplayCursor:
def __init__(self, native_cursor: Any, connection: Connection) -> None:
self.native_cursor = native_cursor
self.connection = connection

@record_function(CursorExecuteRecord, method=True, id_field_name="connection_name")
def execute(self, operation, parameters=None) -> None:
self.native_cursor.execute(operation, parameters)

@record_function(CursorFetchOneRecord, method=True, id_field_name="connection_name")
def fetchone(self) -> Any:
return self.native_cursor.fetchone()

@record_function(CursorFetchManyRecord, method=True, id_field_name="connection_name")
def fetchmany(self, size: int) -> Any:
return self.native_cursor.fetchmany(size)

@record_function(CursorFetchAllRecord, method=True, id_field_name="connection_name")
def fetchall(self) -> Any:
return self.native_cursor.fetchall()

class QueryRecord(Record):
params_cls = QueryRecordParams
result_cls = QueryRecordResult
@property
def connection_name(self) -> Optional[str]:
return self.connection.name

@property
@record_function(CursorGetRowCountRecord, method=True, id_field_name="connection_name")
def rowcount(self) -> int:
return self.native_cursor.rowcount

Recorder.register_record_type(QueryRecord)
@property
@record_function(CursorGetDescriptionRecord, method=True, id_field_name="connection_name")
def description(self) -> str:
return self.native_cursor.description
3 changes: 0 additions & 3 deletions dbt/adapters/sql/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from dbt_common.events.contextvars import get_node_info
from dbt_common.events.functions import fire_event
from dbt_common.exceptions import DbtInternalError, NotImplementedError
from dbt_common.record import record_function
from dbt_common.utils import cast_to_str

from dbt.adapters.base import BaseConnectionManager
Expand All @@ -20,7 +19,6 @@
SQLQuery,
SQLQueryStatus,
)
from dbt.adapters.record import QueryRecord

if TYPE_CHECKING:
import agate
Expand Down Expand Up @@ -143,7 +141,6 @@ def get_result_from_cursor(cls, cursor: Any, limit: Optional[int]) -> "agate.Tab

return table_from_data_flat(data, column_names)

@record_function(QueryRecord, method=True, tuple_result=True)
def execute(
self,
sql: str,
Expand Down
Loading