From 03b62b0391f1885ce73ad32d062dfed1a5b0703b Mon Sep 17 00:00:00 2001
From: Mayur Singal <39544459+ulixius9@users.noreply.github.com>
Date: Tue, 23 Jan 2024 11:28:02 +0530
Subject: [PATCH] Minor: Optimise Databricks Client (#14776)

---
 .../source/database/databricks/client.py     | 138 ++++++------
 .../database/databricks/legacy/lineage.py    |   2 +-
 .../source/database/databricks/usage.py      |   2 +-
 .../pipeline/databrickspipeline/metadata.py  |   4 +-
 4 files changed, 47 insertions(+), 99 deletions(-)

diff --git a/ingestion/src/metadata/ingestion/source/database/databricks/client.py b/ingestion/src/metadata/ingestion/source/database/databricks/client.py
index 843454b25908..8f86634e10e1 100644
--- a/ingestion/src/metadata/ingestion/source/database/databricks/client.py
+++ b/ingestion/src/metadata/ingestion/source/database/databricks/client.py
@@ -14,7 +14,7 @@
 import json
 import traceback
 from datetime import timedelta
-from typing import List
+from typing import Iterable, List
 
 import requests
 
@@ -37,6 +37,12 @@
 COLUMN_LINEAGE_PATH = "/lineage-tracking/column-lineage/get"
 
 
+class DatabricksClientException(Exception):
+    """
+    Class to throw auth and other databricks api exceptions.
+    """
+
+
 class DatabricksClient:
     """
     DatabricksClient creates a Databricks connection based on DatabricksCredentials.
@@ -66,14 +72,33 @@ def test_query_api_access(self) -> None:
         if res.status_code != 200:
             raise APIError(res.json)
 
+    def _run_query_paginator(self, data, result, end_time, response):
+        while True:
+            if response:
+                next_page_token = response.get("next_page_token", None)
+                has_next_page = response.get("has_next_page", None)
+                if next_page_token:
+                    data["page_token"] = next_page_token
+                if not has_next_page:
+                    data = {}
+                    break
+            else:
+                break
+
+            if result[-1]["execution_end_time_ms"] <= end_time:
+                response = self.client.get(
+                    self.base_query_url,
+                    data=json.dumps(data),
+                    headers=self.headers,
+                    timeout=API_TIMEOUT,
+                ).json()
+                yield from response.get("res") or []
+
     def list_query_history(self, start_date=None, end_date=None) -> List[dict]:
         """
         Method returns List the history of queries through SQL warehouses
         """
-        query_details = []
         try:
-            next_page_token = None
-            has_next_page = None
             data = {}
             daydiff = end_date - start_date
 
@@ -104,36 +129,15 @@ def list_query_history(self, start_date=None, end_date=None) -> List[dict]:
             result = response.get("res") or []
             data = {}
 
-            while True:
-                if result:
-                    query_details.extend(result)
-
-                    next_page_token = response.get("next_page_token", None)
-                    has_next_page = response.get("has_next_page", None)
-                    if next_page_token:
-                        data["page_token"] = next_page_token
-
-                    if not has_next_page:
-                        data = {}
-                        break
-                else:
-                    break
-
-                if result[-1]["execution_end_time_ms"] <= end_time:
-                    response = self.client.get(
-                        self.base_query_url,
-                        data=json.dumps(data),
-                        headers=self.headers,
-                        timeout=API_TIMEOUT,
-                    ).json()
-                    result = response.get("res")
+            yield from result
+            yield from self._run_query_paginator(
+                data=data, result=result, end_time=end_time, response=response
+            ) or []
 
         except Exception as exc:
             logger.debug(traceback.format_exc())
             logger.error(exc)
 
-        return query_details
-
     def is_query_valid(self, row) -> bool:
         query_text = row.get("query_text")
         return not (
@@ -143,18 +147,19 @@ def is_query_valid(self, row) -> bool:
 
     def list_jobs_test_connection(self) -> None:
         data = {"limit": 1, "expand_tasks": True, "offset": 0}
-        self.client.get(
+        response = self.client.get(
             self.jobs_list_url,
             data=json.dumps(data),
             headers=self.headers,
             timeout=API_TIMEOUT,
-        ).json()
+        )
+        if response.status_code != 200:
+            raise DatabricksClientException(response.text)
 
-    def list_jobs(self) -> List[dict]:
+    def list_jobs(self) -> Iterable[dict]:
         """
         Method returns List all the created jobs in a Databricks Workspace
         """
-        job_list = []
         try:
             data = {"limit": 25, "expand_tasks": True, "offset": 0}
 
@@ -165,9 +170,9 @@ def list_jobs(self) -> List[dict]:
                 timeout=API_TIMEOUT,
             ).json()
 
-            job_list.extend(response.get("jobs") or [])
+            yield from response.get("jobs") or []
 
-            while response["has_more"]:
+            while response and response.get("has_more"):
                 data["offset"] = len(response.get("jobs") or [])
 
                 response = self.client.get(
@@ -177,19 +182,16 @@ def list_jobs(self) -> List[dict]:
                     timeout=API_TIMEOUT,
                 ).json()
 
-                job_list.extend(response.get("jobs") or [])
+                yield from response.get("jobs") or []
 
         except Exception as exc:
             logger.debug(traceback.format_exc())
             logger.error(exc)
 
-        return job_list
-
     def get_job_runs(self, job_id) -> List[dict]:
         """
         Method returns List of all runs for a job by the specified job_id
         """
-        job_runs = []
         try:
             params = {
                 "job_id": job_id,
@@ -206,7 +208,7 @@ def get_job_runs(self, job_id) -> List[dict]:
                 timeout=API_TIMEOUT,
             ).json()
 
-            job_runs.extend(response.get("runs") or [])
+            yield from response.get("runs") or []
 
             while response["has_more"]:
                 params.update({"start_time_to": response["runs"][-1]["start_time"]})
@@ -218,62 +220,8 @@ def get_job_runs(self, job_id) -> List[dict]:
                     timeout=API_TIMEOUT,
                 ).json()
 
-                job_runs.extend(response.get("runs" or []))
+                yield from response.get("runs") or []
 
         except Exception as exc:
             logger.debug(traceback.format_exc())
             logger.error(exc)
-
-        return job_runs
-
-    def get_table_lineage(self, table_name: str) -> LineageTableStreams:
-        """
-        Method returns table lineage details
-        """
-        try:
-            data = {
-                "table_name": table_name,
-            }
-
-            response = self.client.get(
-                f"{self.base_url}{TABLE_LINEAGE_PATH}",
-                headers=self.headers,
-                data=json.dumps(data),
-                timeout=API_TIMEOUT,
-            ).json()
-            if response:
-                return LineageTableStreams(**response)
-
-        except Exception as exc:
-            logger.debug(traceback.format_exc())
-            logger.error(exc)
-
-        return LineageTableStreams()
-
-    def get_column_lineage(
-        self, table_name: str, column_name: str
-    ) -> LineageColumnStreams:
-        """
-        Method returns table lineage details
-        """
-        try:
-            data = {
-                "table_name": table_name,
-                "column_name": column_name,
-            }
-
-            response = self.client.get(
-                f"{self.base_url}{COLUMN_LINEAGE_PATH}",
-                headers=self.headers,
-                data=json.dumps(data),
-                timeout=API_TIMEOUT,
-            ).json()
-
-            if response:
-                return LineageColumnStreams(**response)
-
-        except Exception as exc:
-            logger.debug(traceback.format_exc())
-            logger.error(exc)
-
-        return LineageColumnStreams()
diff --git a/ingestion/src/metadata/ingestion/source/database/databricks/legacy/lineage.py b/ingestion/src/metadata/ingestion/source/database/databricks/legacy/lineage.py
index 6ff88db4e5cf..0e9388d0e574 100644
--- a/ingestion/src/metadata/ingestion/source/database/databricks/legacy/lineage.py
+++ b/ingestion/src/metadata/ingestion/source/database/databricks/legacy/lineage.py
@@ -35,7 +35,7 @@ def yield_table_query(self) -> Iterator[TableQuery]:
             start_date=self.start,
             end_date=self.end,
         )
-        for row in data:
+        for row in data or []:
             try:
                 if self.client.is_query_valid(row):
                     yield TableQuery(
diff --git a/ingestion/src/metadata/ingestion/source/database/databricks/usage.py b/ingestion/src/metadata/ingestion/source/database/databricks/usage.py
index 22cc194c950b..be71d21bf11f 100644
--- a/ingestion/src/metadata/ingestion/source/database/databricks/usage.py
+++ b/ingestion/src/metadata/ingestion/source/database/databricks/usage.py
@@ -39,7 +39,7 @@ def yield_table_queries(self) -> Optional[Iterable[TableQuery]]:
             start_date=self.start,
             end_date=self.end,
         )
-        for row in data:
+        for row in data or []:
             try:
                 if self.client.is_query_valid(row):
                     queries.append(
diff --git a/ingestion/src/metadata/ingestion/source/pipeline/databrickspipeline/metadata.py b/ingestion/src/metadata/ingestion/source/pipeline/databrickspipeline/metadata.py
index 0562c196d85d..43ac611c07a6 100644
--- a/ingestion/src/metadata/ingestion/source/pipeline/databrickspipeline/metadata.py
+++ b/ingestion/src/metadata/ingestion/source/pipeline/databrickspipeline/metadata.py
@@ -78,7 +78,7 @@ def create(cls, config_dict, metadata: OpenMetadata):
         return cls(config, metadata)
 
     def get_pipelines_list(self) -> Iterable[dict]:
-        for workflow in self.client.list_jobs():
+        for workflow in self.client.list_jobs() or []:
             yield workflow
 
     def get_pipeline_name(self, pipeline_details: dict) -> str:
@@ -192,7 +192,7 @@ def yield_pipeline_status(self, pipeline_details) -> Iterable[OMetaPipelineStatu
         for job_id in self.context.job_id_list:
             try:
                 runs = self.client.get_job_runs(job_id=job_id)
-                for attempt in runs:
+                for attempt in runs or []:
                     for task_run in attempt["tasks"]:
                         task_status = []
                         task_status.append(