Skip to content

Commit

Permalink
Minor: Optimise Databricks Client (#14776)
Browse files Browse the repository at this point in the history
  • Loading branch information
ulixius9 committed Jan 23, 2024
1 parent 52c40ab commit 03b62b0
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 99 deletions.
138 changes: 43 additions & 95 deletions ingestion/src/metadata/ingestion/source/database/databricks/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import json
import traceback
from datetime import timedelta
from typing import List
from typing import Iterable, List

import requests

Expand All @@ -37,6 +37,12 @@
COLUMN_LINEAGE_PATH = "/lineage-tracking/column-lineage/get"


class DatabricksClientException(Exception):
    """
    Raised for authentication failures and any other error reported by
    the Databricks REST API.
    """

class DatabricksClient:
"""
DatabricksClient creates a Databricks connection based on DatabricksCredentials.
Expand Down Expand Up @@ -66,14 +72,33 @@ def test_query_api_access(self) -> None:
if res.status_code != 200:
raise APIError(res.json)

    def _run_query_paginator(self, data, result, end_time, response):
        """
        Generator that follows Databricks query-history pagination and
        yields raw query records from the follow-up pages.

        Each iteration reads ``next_page_token`` / ``has_next_page`` from
        the previous ``response``; while more pages exist it requests the
        next page and yields every entry of its ``res`` list.

        :param data: request payload dict; mutated in place to carry the
            ``page_token`` of the next request
        :param result: records of the page the caller already fetched; only
            its last entry's ``execution_end_time_ms`` is consulted here
        :param end_time: epoch-millis upper bound used to decide whether to
            keep advancing through pages
        :param response: parsed JSON body (dict) of the previous API call

        NOTE(review): ``result`` is never refreshed inside the loop, so the
        time check always looks at the first page's last record — and when
        that check is False neither ``response`` nor the loop state changes;
        confirm this cannot spin without terminating.
        """
        while True:
            if response:
                next_page_token = response.get("next_page_token", None)
                has_next_page = response.get("has_next_page", None)
                if next_page_token:
                    # Carry the cursor forward for the next request
                    data["page_token"] = next_page_token
                if not has_next_page:
                    data = {}
                    break
            else:
                break

            if result[-1]["execution_end_time_ms"] <= end_time:
                response = self.client.get(
                    self.base_query_url,
                    data=json.dumps(data),
                    headers=self.headers,
                    timeout=API_TIMEOUT,
                ).json()
                # "res" may be absent/None on the last page — yield nothing then
                yield from response.get("res") or []

def list_query_history(self, start_date=None, end_date=None) -> List[dict]:
"""
Method returns List the history of queries through SQL warehouses
"""
query_details = []
try:
next_page_token = None
has_next_page = None

data = {}
daydiff = end_date - start_date
Expand Down Expand Up @@ -104,36 +129,15 @@ def list_query_history(self, start_date=None, end_date=None) -> List[dict]:
result = response.get("res") or []
data = {}

while True:
if result:
query_details.extend(result)

next_page_token = response.get("next_page_token", None)
has_next_page = response.get("has_next_page", None)
if next_page_token:
data["page_token"] = next_page_token

if not has_next_page:
data = {}
break
else:
break

if result[-1]["execution_end_time_ms"] <= end_time:
response = self.client.get(
self.base_query_url,
data=json.dumps(data),
headers=self.headers,
timeout=API_TIMEOUT,
).json()
result = response.get("res")
yield from result
yield from self._run_query_paginator(
data=data, result=result, end_time=end_time, response=response
) or []

except Exception as exc:
logger.debug(traceback.format_exc())
logger.error(exc)

return query_details

def is_query_valid(self, row) -> bool:
query_text = row.get("query_text")
return not (
Expand All @@ -143,18 +147,19 @@ def is_query_valid(self, row) -> bool:

def list_jobs_test_connection(self) -> None:
data = {"limit": 1, "expand_tasks": True, "offset": 0}
self.client.get(
response = self.client.get(
self.jobs_list_url,
data=json.dumps(data),
headers=self.headers,
timeout=API_TIMEOUT,
).json()
)
if response.status_code != 200:
raise DatabricksClientException(response.text)

def list_jobs(self) -> List[dict]:
def list_jobs(self) -> Iterable[dict]:
"""
Method returns List all the created jobs in a Databricks Workspace
"""
job_list = []
try:
data = {"limit": 25, "expand_tasks": True, "offset": 0}

Expand All @@ -165,9 +170,9 @@ def list_jobs(self) -> List[dict]:
timeout=API_TIMEOUT,
).json()

job_list.extend(response.get("jobs") or [])
yield from response.get("jobs") or []

while response["has_more"]:
while response and response.get("has_more"):
data["offset"] = len(response.get("jobs") or [])

response = self.client.get(
Expand All @@ -177,19 +182,16 @@ def list_jobs(self) -> List[dict]:
timeout=API_TIMEOUT,
).json()

job_list.extend(response.get("jobs") or [])
yield from response.get("jobs") or []

except Exception as exc:
logger.debug(traceback.format_exc())
logger.error(exc)

return job_list

def get_job_runs(self, job_id) -> List[dict]:
"""
Method returns List of all runs for a job by the specified job_id
"""
job_runs = []
try:
params = {
"job_id": job_id,
Expand All @@ -206,7 +208,7 @@ def get_job_runs(self, job_id) -> List[dict]:
timeout=API_TIMEOUT,
).json()

job_runs.extend(response.get("runs") or [])
yield from response.get("runs") or []

while response["has_more"]:
params.update({"start_time_to": response["runs"][-1]["start_time"]})
Expand All @@ -218,62 +220,8 @@ def get_job_runs(self, job_id) -> List[dict]:
timeout=API_TIMEOUT,
).json()

job_runs.extend(response.get("runs" or []))
yield from response.get("runs") or []

except Exception as exc:
logger.debug(traceback.format_exc())
logger.error(exc)

return job_runs

def get_table_lineage(self, table_name: str) -> LineageTableStreams:
"""
Method returns table lineage details
"""
try:
data = {
"table_name": table_name,
}

response = self.client.get(
f"{self.base_url}{TABLE_LINEAGE_PATH}",
headers=self.headers,
data=json.dumps(data),
timeout=API_TIMEOUT,
).json()
if response:
return LineageTableStreams(**response)

except Exception as exc:
logger.debug(traceback.format_exc())
logger.error(exc)

return LineageTableStreams()

def get_column_lineage(
self, table_name: str, column_name: str
) -> LineageColumnStreams:
"""
Method returns table lineage details
"""
try:
data = {
"table_name": table_name,
"column_name": column_name,
}

response = self.client.get(
f"{self.base_url}{COLUMN_LINEAGE_PATH}",
headers=self.headers,
data=json.dumps(data),
timeout=API_TIMEOUT,
).json()

if response:
return LineageColumnStreams(**response)

except Exception as exc:
logger.debug(traceback.format_exc())
logger.error(exc)

return LineageColumnStreams()
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def yield_table_query(self) -> Iterator[TableQuery]:
start_date=self.start,
end_date=self.end,
)
for row in data:
for row in data or []:
try:
if self.client.is_query_valid(row):
yield TableQuery(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def yield_table_queries(self) -> Optional[Iterable[TableQuery]]:
start_date=self.start,
end_date=self.end,
)
for row in data:
for row in data or []:
try:
if self.client.is_query_valid(row):
queries.append(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def create(cls, config_dict, metadata: OpenMetadata):
return cls(config, metadata)

def get_pipelines_list(self) -> Iterable[dict]:
for workflow in self.client.list_jobs():
for workflow in self.client.list_jobs() or []:
yield workflow

def get_pipeline_name(self, pipeline_details: dict) -> str:
Expand Down Expand Up @@ -192,7 +192,7 @@ def yield_pipeline_status(self, pipeline_details) -> Iterable[OMetaPipelineStatu
for job_id in self.context.job_id_list:
try:
runs = self.client.get_job_runs(job_id=job_id)
for attempt in runs:
for attempt in runs or []:
for task_run in attempt["tasks"]:
task_status = []
task_status.append(
Expand Down

0 comments on commit 03b62b0

Please sign in to comment.