Skip to content

Commit

Permalink
AlationSink conn improvements (#18091)
Browse files Browse the repository at this point in the history
  • Loading branch information
OnkarVO7 committed Oct 3, 2024
1 parent 84441c4 commit b1dcb11
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,17 @@
"""

import traceback
from typing import Iterable, Optional
from typing import Iterable, List, Optional

from metadata.generated.schema.entity.data.database import Database
from metadata.generated.schema.entity.data.databaseSchema import DatabaseSchema
from metadata.generated.schema.entity.data.table import Column, Constraint, Table
from metadata.generated.schema.entity.data.table import (
Column,
Constraint,
ConstraintType,
Table,
TableConstraint,
)
from metadata.generated.schema.entity.services.connections.metadata.alationSinkConnection import (
AlationSinkConnection,
)
Expand Down Expand Up @@ -108,7 +114,9 @@ def create_datasource_request(
),
),
db_username="Test",
title=model_str(om_database.name),
title=om_database.displayName
if om_database.displayName
else model_str(om_database.name),
description=model_str(om_database.description),
)
except Exception as exc:
Expand All @@ -129,7 +137,9 @@ def create_schema_request(
key=fqn._build( # pylint: disable=protected-access
str(alation_datasource_id), model_str(om_schema.name)
),
title=model_str(om_schema.name),
title=om_schema.displayName
if om_schema.displayName
else model_str(om_schema.name),
description=model_str(om_schema.description),
)
except Exception as exc:
Expand All @@ -150,7 +160,9 @@ def create_table_request(
key=fqn._build( # pylint: disable=protected-access
str(alation_datasource_id), schema_name, model_str(om_table.name)
),
title=model_str(om_table.name),
title=om_table.displayName
if om_table.displayName
else model_str(om_table.name),
description=model_str(om_table.description),
table_type=TABLE_TYPE_MAPPER.get(om_table.tableType, "TABLE"),
sql=om_table.schemaDefinition,
Expand All @@ -162,14 +174,60 @@ def create_table_request(
)
return None

def _get_column_index(self, om_column: Column) -> Optional[ColumnIndex]:
def _update_foreign_key(
self,
alation_datasource_id: int,
om_column: Column,
table_constraints: Optional[List[TableConstraint]],
column_index: ColumnIndex,
):
"""
Method to update the foreign key metadata in columns index
"""
try:
for table_constraint in table_constraints or []:
if table_constraint.constraintType == ConstraintType.FOREIGN_KEY:
for i, constraint_column in enumerate(
table_constraint.columns or []
):
if constraint_column == model_str(om_column.name):
column_index.isForeignKey = True
# update the service name of OM with the alation datasource id in the column FQN
splitted_col_fqn = fqn.split(
model_str(table_constraint.referredColumns[i])
)
splitted_col_fqn[0] = str(alation_datasource_id)
column_index.referencedColumnId = (
fqn._build( # pylint: disable=protected-access
*splitted_col_fqn
)
)
break
except Exception as exc:
logger.debug(traceback.format_exc())
logger.warning(
f"Failed to update foreign key for {model_str(om_column.name)}: {exc}"
)

def _get_column_index(
self,
alation_datasource_id: int,
om_column: Column,
table_constraints: Optional[List[TableConstraint]],
) -> Optional[ColumnIndex]:
"""
Method to get the alation column index
"""
column_index = ColumnIndex()
try:
# Attach the primary key
if om_column.constraint == Constraint.PRIMARY_KEY:
return ColumnIndex(isPrimaryKey=True)
column_index.isPrimaryKey = True

# Attach the foreign key
self._update_foreign_key(
alation_datasource_id, om_column, table_constraints, column_index
)
except Exception as exc:
logger.debug(traceback.format_exc())
logger.warning(
Expand Down Expand Up @@ -199,6 +257,7 @@ def create_column_request(
schema_name: str,
table_name: str,
om_column: Column,
table_constraints: Optional[List[TableConstraint]],
) -> Optional[CreateColumnRequest]:
"""
Method to form the CreateColumnRequest object
Expand All @@ -211,13 +270,19 @@ def create_column_request(
table_name,
model_str(om_column.name),
),
column_type=om_column.dataType.value.lower(),
title=model_str(om_column.name),
column_type=om_column.dataTypeDisplay.lower()
if om_column.dataTypeDisplay
else om_column.dataType.value.lower(),
title=om_column.displayName
if om_column.displayName
else model_str(om_column.name),
description=model_str(om_column.description),
position=str(om_column.ordinalPosition)
if om_column.ordinalPosition
else None,
index=self._get_column_index(om_column),
index=self._get_column_index(
alation_datasource_id, om_column, table_constraints
),
nullable=self._check_nullable_column(om_column),
)
except Exception as exc:
Expand All @@ -241,6 +306,7 @@ def ingest_columns(
schema_name=schema_name,
table_name=model_str(om_table.name),
om_column=om_column,
table_constraints=om_table.tableConstraints,
)
if create_column_request:
create_requests.root.append(create_column_request)
Expand All @@ -266,6 +332,7 @@ def ingest_tables(self, alation_datasource_id: int, om_schema: DatabaseSchema):
entity=Table,
skip_on_failure=True,
params={"database": model_str(om_schema.fullyQualifiedName)},
fields=["tableConstraints, columns"],
)
)
create_requests = CreateTableRequestList(root=[])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -506,6 +506,7 @@ def test_columns(self):
schema_name="shopify",
table_name=model_str(om_table.name),
om_column=om_column,
table_constraints=om_table.tableConstraints,
)
)
for _, (expected, original) in enumerate(
Expand Down

0 comments on commit b1dcb11

Please sign in to comment.