Feature: MERGE/Upsert Support #1534

Open · wants to merge 10 commits into main
130 changes: 130 additions & 0 deletions pyiceberg/table/__init__.py
@@ -136,6 +136,8 @@
from pyiceberg.utils.config import Config
from pyiceberg.utils.properties import property_as_bool

from pyiceberg.table import merge_rows_util

if TYPE_CHECKING:
    import daft
    import pandas as pd
@@ -1064,6 +1066,134 @@ def name_mapping(self) -> Optional[NameMapping]:
        """Return the table's field-id NameMapping."""
        return self.metadata.name_mapping()

    def merge_rows(
        self,
        df: pa.Table,
        join_cols: list,
        merge_options: Optional[dict] = None,
    ) -> Dict:
        """
        Shorthand API for performing an upsert/merge into an Iceberg table.

        Args:
            df: The input dataframe to merge with the table's data.
            join_cols: The columns to join on.
            merge_options: A dictionary of merge actions to perform. Currently supported options:
                when_matched_update_all: default is True
                when_not_matched_insert_all: default is True

        Returns:
            A dictionary containing the number of rows updated and inserted.
        """
        try:
            from datafusion import SessionContext
        except ModuleNotFoundError as e:
            raise ModuleNotFoundError("For merge_rows, DataFusion needs to be installed") from e

        try:
            from pyarrow import dataset as ds
        except ModuleNotFoundError as e:
            raise ModuleNotFoundError("For merge_rows, PyArrow needs to be installed") from e

        source_table_name = "source"
        target_table_name = "target"

        if not merge_options:
            merge_options = {"when_matched_update_all": True, "when_not_matched_insert_all": True}

        when_matched_update_all = merge_options.get("when_matched_update_all", False)
        when_not_matched_insert_all = merge_options.get("when_not_matched_insert_all", False)

        if not when_matched_update_all and not when_not_matched_insert_all:
            return {"rows_updated": 0, "rows_inserted": 0, "msg": "no merge options selected...exiting"}

        ctx = SessionContext()

        # Register both the source and target tables so we can find the deltas to update/append
        ctx.register_dataset(source_table_name, ds.dataset(df))
        ctx.register_dataset(target_table_name, ds.dataset(self.scan().to_arrow()))

        source_col_list = merge_rows_util.get_table_column_list(ctx, source_table_name)
        target_col_list = merge_rows_util.get_table_column_list(ctx, target_table_name)

        source_col_names = {col[0] for col in source_col_list}
        target_col_names = {col[0] for col in target_col_list}

        source_col_types = {col[0]: col[1] for col in source_col_list}

        missing_columns = merge_rows_util.do_join_columns_exist(source_col_names, target_col_names, join_cols)

        if missing_columns['source'] or missing_columns['target']:
            return {
                'error_msgs': f"Join columns missing in tables: Source table columns missing: {missing_columns['source']}, Target table columns missing: {missing_columns['target']}"
            }

        # Check for duplicate rows in the source based on the key columns
        if merge_rows_util.dups_check_in_source(source_table_name, join_cols, ctx):
            return {'error_msgs': 'Duplicate rows found in source dataset based on the key columns. No Merge executed'}

        update_row_cnt = 0
        insert_row_cnt = 0

        txn = self.transaction()

        try:
            if when_matched_update_all:
                # Get the matched rows whose non-key column values differ from the target
                update_recs_sql = merge_rows_util.get_rows_to_update_sql(
                    source_table_name, target_table_name, join_cols, source_col_names, target_col_names
                )
                update_recs = ctx.sql(update_recs_sql).to_arrow_table()

                update_row_cnt = len(update_recs)

                if len(join_cols) == 1:
                    join_col = join_cols[0]
                    col_type = source_col_types[join_col]
                    values = [row[join_col] for row in update_recs.to_pylist()]
                    # String values in the filter must be wrapped in single quotes
                    formatted_values = [f"'{value}'" if col_type == 'string' else str(value) for value in values]
                    overwrite_filter = f"{join_col} IN ({', '.join(formatted_values)})"
                else:
                    # For composite keys, OR together one AND-predicate per row;
                    # repr() quotes strings and leaves other literals bare
                    overwrite_filter = " OR ".join(
                        f"({' AND '.join([f'{col} = {repr(row[col])}' for col in join_cols])})"
                        for row in update_recs.to_pylist()
                    )

                txn.overwrite(update_recs, overwrite_filter)

            if when_not_matched_insert_all:
                # Insert the source rows whose keys do not exist in the target
                insert_recs_sql = merge_rows_util.get_rows_to_insert_sql(
                    source_table_name, target_table_name, join_cols, source_col_names, target_col_names
                )
                insert_recs = ctx.sql(insert_recs_sql).to_arrow_table()

                insert_row_cnt = len(insert_recs)

                txn.append(insert_recs)

            if when_matched_update_all or when_not_matched_insert_all:
                txn.commit_transaction()

            return {
                "rows_updated": update_row_cnt,
                "rows_inserted": insert_row_cnt,
            }

        except Exception as e:
            print(f"Error: {e}")
            raise
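For reviewers, a minimal usage sketch of the new API; the catalog name "default", the table identifier "demo.target", and the counts in the final comment are hypothetical:

import pyarrow as pa
from pyiceberg.catalog import load_catalog

catalog = load_catalog("default")          # hypothetical catalog config
table = catalog.load_table("demo.target")  # hypothetical existing table

source = pa.table({"id": [1, 2, 3], "name": ["a", "b", "c"]})
result = table.merge_rows(source, join_cols=["id"])
print(result)  # e.g. {'rows_updated': 2, 'rows_inserted': 1}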

    def append(self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT) -> None:
        """
        Shorthand API for appending a PyArrow table to the table.
117 changes: 117 additions & 0 deletions pyiceberg/table/merge_rows_util.py
@@ -0,0 +1,117 @@

from datafusion import SessionContext

def get_table_column_list(connection: SessionContext, table_name: str) -> list:
    """
    Retrieve the column names and data types for the specified table.

    Args:
        connection: DataFusion SessionContext.
        table_name: The name of the table for which to retrieve column information.

    Returns:
        A list of tuples containing column names and their corresponding data types.
    """
    res = connection.sql(f"""
        SELECT column_name, data_type
        FROM information_schema.columns
        WHERE table_name = '{table_name}'
    """).collect()

    column_names = res[0][0].to_pylist()  # Extract the first column (column names)
    data_types = res[0][1].to_pylist()  # Extract the second column (data types)

    return list(zip(column_names, data_types))  # Combine into a list of tuples
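As a quick illustration of the helper, assuming a dataset has been registered on the context the same way merge_rows does it:

import pyarrow as pa
from pyarrow import dataset as ds
from datafusion import SessionContext

ctx = SessionContext()
ctx.register_dataset("source", ds.dataset(pa.table({"id": [1], "name": ["a"]})))
cols = get_table_column_list(ctx, "source")
# -> list of (column_name, data_type) tuples, e.g. [("id", <type>), ("name", <type>)]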

def dups_check_in_source(source_table: str, join_cols: list, connection: SessionContext) -> bool:
    """
    Check if there are duplicate rows in the source table based on the join columns.

    Returns True if any duplicate rows are found, otherwise False.
    """
    source_dup_sql = f"""
        SELECT {', '.join(join_cols)}, COUNT(*)
        FROM {source_table}
        GROUP BY {', '.join(join_cols)}
        HAVING COUNT(*) > 1
        LIMIT 1
    """
    source_dup_df = connection.sql(source_dup_sql).collect()

    return len(source_dup_df) > 0
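A hedged sketch of the duplicate check in action, using an in-memory source with a repeated key:

import pyarrow as pa
from pyarrow import dataset as ds
from datafusion import SessionContext

ctx = SessionContext()
dup_source = pa.table({"id": [1, 1, 2], "val": ["a", "b", "c"]})
ctx.register_dataset("source", ds.dataset(dup_source))
assert dups_check_in_source("source", ["id"], ctx)  # id=1 appears twice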

def do_join_columns_exist(source_col_list: set, target_col_list: set, join_cols: list) -> dict:
    """
    Check if the join columns exist in both the source and target tables.

    Returns a dictionary indicating which join columns are missing from each table.
    """
    missing_columns = {
        'source': [],
        'target': []
    }

    for col in join_cols:
        if col not in source_col_list:
            missing_columns['source'].append(col)
        if col not in target_col_list:
            missing_columns['target'].append(col)

    return missing_columns
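For example, with a join column that exists in the source but not the target:

missing = do_join_columns_exist({"id", "name"}, {"id"}, ["id", "name"])
# -> {'source': [], 'target': ['name']}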



def get_rows_to_update_sql(
    source_table_name: str,
    target_table_name: str,
    join_cols: list,
    source_cols_list: set,
    target_cols_list: set,
) -> str:
    """
    Return SQL selecting the source rows that need to be updated in the target table.

    The source and target are joined on the join columns; EXCEPT DISTINCT then keeps
    only the joined rows whose non-join column values differ from the target's.
    """
    # Determine non-join columns that exist in both tables
    non_join_cols = source_cols_list.intersection(target_cols_list) - set(join_cols)

    sql = f"""
        SELECT {', '.join([f"src.{col}" for col in join_cols])},
               {', '.join([f"src.{col}" for col in non_join_cols])}
        FROM {source_table_name} as src
        INNER JOIN {target_table_name} as tgt
            ON {' AND '.join([f"src.{col} = tgt.{col}" for col in join_cols])}
        EXCEPT DISTINCT
        SELECT {', '.join([f"tgt.{col}" for col in join_cols])},
               {', '.join([f"tgt.{col}" for col in non_join_cols])}
        FROM {target_table_name} as tgt
    """
    return sql
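For a single key column this produces roughly the following shape (whitespace aside; column order within each SELECT can vary because the non-join columns come from a set):

sql = get_rows_to_update_sql("source", "target", ["id"], {"id", "name"}, {"id", "name"})
# SELECT src.id, src.name
# FROM source as src
# INNER JOIN target as tgt ON src.id = tgt.id
# EXCEPT DISTINCT
# SELECT tgt.id, tgt.name
# FROM target as tgt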


def get_rows_to_insert_sql(
    source_table_name: str,
    target_table_name: str,
    join_cols: list,
    source_cols_list: set,
    target_cols_list: set,
) -> str:
    """
    Return SQL selecting the source rows whose keys do not exist in the target table.

    A LEFT JOIN on the join columns with a NULL check on the target key acts as an
    anti-join, keeping only the rows to insert.
    """
    # Determine the insertable non-join columns that exist in both tables
    insert_cols = source_cols_list.intersection(target_cols_list) - set(join_cols)

    # Build the SQL query
    sql = f"""
        SELECT {', '.join([f"src.{col}" for col in join_cols])},
               {', '.join([f"src.{col}" for col in insert_cols])}
        FROM {source_table_name} as src
        LEFT JOIN {target_table_name} as tgt
            ON {' AND '.join([f"src.{col} = tgt.{col}" for col in join_cols])}
        WHERE tgt.{join_cols[0]} IS NULL
    """
    return sql
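Again for a single key column, the generated anti-join looks roughly like:

sql = get_rows_to_insert_sql("source", "target", ["id"], {"id", "name"}, {"id", "name"})
# SELECT src.id, src.name
# FROM source as src
# LEFT JOIN target as tgt ON src.id = tgt.id
# WHERE tgt.id IS NULL   -- keep only source keys absent from the target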
1 change: 1 addition & 0 deletions pyiceberg/table/test.py
@@ -0,0 +1 @@
print('test')
34 changes: 34 additions & 0 deletions pyproject.toml
@@ -64,6 +64,7 @@ tenacity = ">=8.2.3,<10.0.0"
pyarrow = { version = ">=14.0.0,<19.0.0", optional = true }
pandas = { version = ">=1.0.0,<3.0.0", optional = true }
duckdb = { version = ">=0.5.0,<2.0.0", optional = true }
datafusion = { version = "43.1.0", optional = true }
ray = [
{ version = "==2.10.0", python = "<3.9", optional = true },
{ version = ">=2.10.0,<3.0.0", python = ">=3.9", optional = true },
@@ -226,6 +227,10 @@ ignore_missing_imports = true
module = "duckdb.*"
ignore_missing_imports = true

[[tool.mypy.overrides]]
module = "datafusion.*"
ignore_missing_imports = true

[[tool.mypy.overrides]]
module = "ray.*"
ignore_missing_imports = true
@@ -378,6 +383,10 @@ ignore_missing_imports = true
module = "duckdb.*"
ignore_missing_imports = true

[[tool.mypy.overrides]]
module = "datafusion.*"
ignore_missing_imports = true

[[tool.mypy.overrides]]
module = "ray.*"
ignore_missing_imports = true
@@ -530,6 +539,10 @@ ignore_missing_imports = true
module = "duckdb.*"
ignore_missing_imports = true

[[tool.mypy.overrides]]
module = "datafusion.*"
ignore_missing_imports = true

[[tool.mypy.overrides]]
module = "ray.*"
ignore_missing_imports = true
@@ -682,6 +695,10 @@ ignore_missing_imports = true
module = "duckdb.*"
ignore_missing_imports = true

[[tool.mypy.overrides]]
module = "datafusion.*"
ignore_missing_imports = true

[[tool.mypy.overrides]]
module = "ray.*"
ignore_missing_imports = true
@@ -834,6 +851,10 @@ ignore_missing_imports = true
module = "duckdb.*"
ignore_missing_imports = true

[[tool.mypy.overrides]]
module = "datafusion.*"
ignore_missing_imports = true

[[tool.mypy.overrides]]
module = "ray.*"
ignore_missing_imports = true
@@ -986,6 +1007,10 @@ ignore_missing_imports = true
module = "duckdb.*"
ignore_missing_imports = true

[[tool.mypy.overrides]]
module = "datafusion.*"
ignore_missing_imports = true

[[tool.mypy.overrides]]
module = "ray.*"
ignore_missing_imports = true
@@ -1138,6 +1163,10 @@ ignore_missing_imports = true
module = "duckdb.*"
ignore_missing_imports = true

[[tool.mypy.overrides]]
module = "datafusion.*"
ignore_missing_imports = true

[[tool.mypy.overrides]]
module = "ray.*"
ignore_missing_imports = true
@@ -1193,6 +1222,7 @@ script = "build-module.py"
pyarrow = ["pyarrow"]
pandas = ["pandas", "pyarrow"]
duckdb = ["duckdb", "pyarrow"]
datafusion = ["datafusion", "pyarrow"]
ray = ["ray", "pyarrow", "pandas"]
daft = ["getdaft"]
snappy = ["python-snappy"]
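Assuming the extras wiring above ships as-is, users would presumably enable the new path with pip install "pyiceberg[datafusion]" (hypothetical until released).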
@@ -1361,6 +1391,10 @@ ignore_missing_imports = true
module = "duckdb.*"
ignore_missing_imports = true

[[tool.mypy.overrides]]
module = "datafusion.*"
ignore_missing_imports = true

[[tool.mypy.overrides]]
module = "ray.*"
ignore_missing_imports = true