From 754e5085156563866ece93a30c6297aa801985c7 Mon Sep 17 00:00:00 2001
From: Sultan Iman <354868+sultaniman@users.noreply.github.com>
Date: Wed, 7 Feb 2024 17:08:21 +0100
Subject: [PATCH] Pydantic improvements (#901)

* add example tests

* Add sub-model hints for pydantic

* Use snake_case naming convention to generate field names

* Detect pydantic model in try..catch block

---------

Co-authored-by: Dave
---
 .github/workflows/lint.yml                  |   6 +-
 .../test_destination_athena_iceberg.yml     |   2 +-
 dlt/common/libs/pydantic.py                 |  56 ++++--
 dlt/common/schema/typing.py                 |   1 +
 dlt/extract/hints.py                        |   6 +-
 tests/libs/test_pydantic.py                 |  93 +++++++++-
 tests/pipeline/test_pipeline_extra.py       | 160 ++++++++++++++++++
 7 files changed, 302 insertions(+), 22 deletions(-)

diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml
index 96dae8044c..35ccb71ab5 100644
--- a/.github/workflows/lint.yml
+++ b/.github/workflows/lint.yml
@@ -25,7 +25,7 @@ jobs:
     defaults:
       run:
         shell: bash
-    runs-on: ${{ matrix.os }}
+    runs-on: ${{ matrix.os }}
 
     steps:
 
@@ -42,7 +42,7 @@ jobs:
         with:
           virtualenvs-create: true
           virtualenvs-in-project: true
-          installer-parallel: true
+          installer-parallel: true
 
       - name: Load cached venv
         id: cached-poetry-dependencies
@@ -57,7 +57,7 @@ jobs:
       - name: Run make lint
         run: |
-          export PATH=$PATH:"/c/Program Files/usr/bin" # needed for Windows
+          export PATH=$PATH:"/c/Program Files/usr/bin" # needed for Windows
           make lint
 
 #      - name: print envs
diff --git a/.github/workflows/test_destination_athena_iceberg.yml b/.github/workflows/test_destination_athena_iceberg.yml
index 92b73d5a9b..fa45b1b49b 100644
--- a/.github/workflows/test_destination_athena_iceberg.yml
+++ b/.github/workflows/test_destination_athena_iceberg.yml
@@ -65,7 +65,7 @@ jobs:
 
       - name: Install dependencies
         # if: steps.cached-poetry-dependencies.outputs.cache-hit != 'true'
-        run: poetry install --no-interaction -E --with sentry-sdk --with pipeline
+        run: poetry install --no-interaction -E --with sentry-sdk --with pipeline
 
       - name: create secrets.toml
         run: pwd && echo "$DLT_SECRETS_TOML" > tests/.dlt/secrets.toml
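The core change follows in dlt/common/libs/pydantic.py: field types are probed with py_type_to_sc_type, and the TypeError branch now distinguishes pydantic models from genuinely unknown types instead of coercing everything to text. A minimal runnable sketch of that dispatch, using the same imports the patch adds (the Coordinates model and classify helper are illustrative, not part of the patch):

```python
from pydantic import BaseModel

# the same import the patch moves to the top of dlt/common/libs/pydantic.py
from dlt.common.data_types import py_type_to_sc_type


class Coordinates(BaseModel):  # illustrative model, not part of the patch
    lat: float
    lon: float


def classify(inner_type: type) -> str:
    """Mirror of the new try/except dispatch in pydantic_to_table_schema_columns."""
    try:
        return py_type_to_sc_type(inner_type)
    except TypeError:
        # a pydantic model is a "complex" type; anything else is coerced to text
        if issubclass(inner_type, BaseModel):
            return "complex"
        return "text"


assert classify(str) == "text"
assert classify(Coordinates) == "complex"
```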
diff --git a/dlt/common/libs/pydantic.py b/dlt/common/libs/pydantic.py
index 58829f0592..872f352178 100644
--- a/dlt/common/libs/pydantic.py
+++ b/dlt/common/libs/pydantic.py
@@ -14,10 +14,11 @@
 )
 from typing_extensions import Annotated, get_args, get_origin
 
+from dlt.common.data_types import py_type_to_sc_type
 from dlt.common.exceptions import MissingDependencyException
 from dlt.common.schema import DataValidationError
 from dlt.common.schema.typing import TSchemaEvolutionMode, TTableSchemaColumns
-from dlt.common.data_types import py_type_to_sc_type
+from dlt.common.normalizers.naming.snake_case import NamingConvention as SnakeCaseNamingConvention
 from dlt.common.typing import (
     TDataItem,
     TDataItems,
@@ -52,6 +53,9 @@
 
 _TPydanticModel = TypeVar("_TPydanticModel", bound=BaseModel)
 
+snake_case_naming_convention = SnakeCaseNamingConvention()
+
+
 class ListModel(BaseModel, Generic[_TPydanticModel]):
     items: List[_TPydanticModel]
 
@@ -71,7 +75,7 @@ class DltConfig(TypedDict, total=False):
 
 
 def pydantic_to_table_schema_columns(
-    model: Union[BaseModel, Type[BaseModel]]
+    model: Union[BaseModel, Type[BaseModel]],
 ) -> TTableSchemaColumns:
     """Convert a pydantic model to a table schema columns dict
 
@@ -111,24 +115,47 @@ def pydantic_to_table_schema_columns(
 
         if is_list_generic_type(inner_type):
             inner_type = list
-        elif is_dict_generic_type(inner_type) or issubclass(inner_type, BaseModel):
+        elif is_dict_generic_type(inner_type):
             inner_type = dict
 
+        is_inner_type_pydantic_model = False
         name = field.alias or field_name
         try:
             data_type = py_type_to_sc_type(inner_type)
         except TypeError:
-            # try to coerce unknown type to text
-            data_type = "text"
-
-        if data_type == "complex" and skip_complex_types:
+            if issubclass(inner_type, BaseModel):
+                data_type = "complex"
+                is_inner_type_pydantic_model = True
+            else:
+                # try to coerce unknown type to text
+                data_type = "text"
+
+        if is_inner_type_pydantic_model and not skip_complex_types:
+            result[name] = {
+                "name": name,
+                "data_type": "complex",
+                "nullable": nullable,
+            }
+        elif is_inner_type_pydantic_model:
+            # This case is for a single field schema/model
+            # we need to generate snake_case field names
+            # and return flattened field schemas
+            schema_hints = pydantic_to_table_schema_columns(field.annotation)
+
+            for field_name, hints in schema_hints.items():
+                schema_key = snake_case_naming_convention.make_path(name, field_name)
+                result[schema_key] = {
+                    **hints,
+                    "name": snake_case_naming_convention.make_path(name, hints["name"]),
+                }
+        elif data_type == "complex" and skip_complex_types:
             continue
-
-        result[name] = {
-            "name": name,
-            "data_type": data_type,
-            "nullable": nullable,
-        }
+        else:
+            result[name] = {
+                "name": name,
+                "data_type": data_type,
+                "nullable": nullable,
+            }
 
     return result
 
@@ -261,7 +288,8 @@ def create_list_model(
     # TODO: use LenientList to create list model that automatically discards invalid items
     # https://github.com/pydantic/pydantic/issues/2274 and https://gist.github.com/dmontagu/7f0cef76e5e0e04198dd608ad7219573
     return create_model(
-        "List" + __name__, items=(List[model], ...)  # type: ignore[return-value,valid-type]
+        "List" + __name__,
+        items=(List[model], ...),  # type: ignore[return-value,valid-type]
     )
diff --git a/dlt/common/schema/typing.py b/dlt/common/schema/typing.py
index 9a27cbe4bb..e1ff17115d 100644
--- a/dlt/common/schema/typing.py
+++ b/dlt/common/schema/typing.py
@@ -7,6 +7,7 @@
     Optional,
     Sequence,
     Set,
+    Tuple,
     Type,
     TypedDict,
     NewType,
diff --git a/dlt/extract/hints.py b/dlt/extract/hints.py
index c9f6327d3c..ec4bd56021 100644
--- a/dlt/extract/hints.py
+++ b/dlt/extract/hints.py
@@ -159,7 +159,11 @@ def compute_table_schema(self, item: TDataItem = None) -> TTableSchema:
         if self._table_name_hint_fun and item is None:
             raise DataItemRequiredForDynamicTableHints(self.name)
         # resolve
-        resolved_template: TResourceHints = {k: self._resolve_hint(item, v) for k, v in table_template.items() if k not in ["incremental", "validator", "original_columns"]}  # type: ignore
+        resolved_template: TResourceHints = {
+            k: self._resolve_hint(item, v)
+            for k, v in table_template.items()
+            if k not in ["incremental", "validator", "original_columns"]
+        }  # type: ignore
         table_schema = self._merge_keys(resolved_template)
         table_schema["resource"] = self.name
         validate_dict_ignoring_xkeys(
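The flattened column names asserted in the tests below come from the module-level snake_case convention added above: make_path joins path segments with a double underscore. A quick sketch of that composition (the expected strings mirror the new tests):

```python
# the same import the patch adds to dlt/common/libs/pydantic.py
from dlt.common.normalizers.naming.snake_case import NamingConvention

naming = NamingConvention()

# nested model fields become "<parent>__<child>" column names
assert naming.make_path("child", "child_attribute") == "child__child_attribute"
assert naming.make_path("child", "optional_child_attribute") == "child__optional_child_attribute"
```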
diff --git a/tests/libs/test_pydantic.py b/tests/libs/test_pydantic.py
index b7ca44c595..3420c08382 100644
--- a/tests/libs/test_pydantic.py
+++ b/tests/libs/test_pydantic.py
@@ -10,7 +10,6 @@
     Union,
     Optional,
     List,
-    Dict,
     Any,
 )
 from typing_extensions import Annotated, get_args, get_origin
@@ -261,9 +260,11 @@ class User(BaseModel):
     # extra is modified
     assert model_freeze.__fields__["address"].annotation.__name__ == "UserAddressExtraAllow"  # type: ignore[index]
     # annotated is preserved
-    assert issubclass(get_origin(model_freeze.__fields__["address"].rebuild_annotation()), Annotated)  # type: ignore[arg-type, index]
+    type_origin = get_origin(model_freeze.__fields__["address"].rebuild_annotation())  # type: ignore[index]
+    assert issubclass(type_origin, Annotated)  # type: ignore[arg-type]
     # UserAddress is converted to UserAddressAllow only once
-    assert model_freeze.__fields__["address"].annotation is get_args(model_freeze.__fields__["unity"].annotation)[0]  # type: ignore[index]
+    type_annotation = model_freeze.__fields__["address"].annotation  # type: ignore[index]
+    assert type_annotation is get_args(model_freeze.__fields__["unity"].annotation)[0]  # type: ignore[index]
     # print(User.__fields__)
     # print(User.__fields__["name"].annotation)
@@ -480,3 +481,89 @@ class ItemModel(BaseModel):
         validate_item("items", mixed_model, {"b": False, "a": False}, "discard_row", "evolve")
         is None
     )
+
+
+class ChildModel(BaseModel):
+    child_attribute: str
+    optional_child_attribute: Optional[str] = None
+
+
+class Parent(BaseModel):
+    child: ChildModel
+    optional_parent_attribute: Optional[str] = None
+
+
+def test_pydantic_model_flattened_when_skip_complex_types_is_true():
+    class MyParent(Parent):
+        dlt_config: ClassVar[DltConfig] = {"skip_complex_types": True}
+
+    schema = pydantic_to_table_schema_columns(MyParent)
+
+    assert schema == {
+        "child__child_attribute": {
+            "data_type": "text",
+            "name": "child__child_attribute",
+            "nullable": False,
+        },
+        "child__optional_child_attribute": {
+            "data_type": "text",
+            "name": "child__optional_child_attribute",
+            "nullable": True,
+        },
+        "optional_parent_attribute": {
+            "data_type": "text",
+            "name": "optional_parent_attribute",
+            "nullable": True,
+        },
+    }
+
+
+def test_considers_model_as_complex_when_skip_complex_types_is_false():
+    class MyParent(Parent):
+        data_dictionary: Dict[str, Any] = None
+        dlt_config: ClassVar[DltConfig] = {"skip_complex_types": False}
+
+    schema = pydantic_to_table_schema_columns(MyParent)
+
+    assert schema == {
+        "child": {"data_type": "complex", "name": "child", "nullable": False},
+        "data_dictionary": {"data_type": "complex", "name": "data_dictionary", "nullable": False},
+        "optional_parent_attribute": {
+            "data_type": "text",
+            "name": "optional_parent_attribute",
+            "nullable": True,
+        },
+    }
+
+
+def test_considers_dictionary_as_complex_when_skip_complex_types_is_false():
+    class MyParent(Parent):
+        data_list: List[str] = []
+        data_dictionary: Dict[str, Any] = None
+        dlt_config: ClassVar[DltConfig] = {"skip_complex_types": False}
+
+    schema = pydantic_to_table_schema_columns(MyParent)
+
+    assert schema["data_dictionary"] == {
+        "data_type": "complex",
+        "name": "data_dictionary",
+        "nullable": False,
+    }
+
+    assert schema["data_list"] == {
+        "data_type": "complex",
+        "name": "data_list",
+        "nullable": False,
+    }
+
+
+def test_skip_complex_types_when_skip_complex_types_is_true_and_field_is_not_pydantic_model():
+    class MyParent(Parent):
+        data_list: List[str] = []
+        data_dictionary: Dict[str, Any] = None
+        dlt_config: ClassVar[DltConfig] = {"skip_complex_types": True}
+
+    schema = pydantic_to_table_schema_columns(MyParent)
+
+    assert "data_dictionary" not in schema
+    assert "data_list" not in schema
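To recap what the unit tests above pin down: with skip_complex_types enabled, a sub-model field is flattened into snake_case paths; with it disabled, the same field stays a single complex column. A condensed sketch using the public helpers (Location, FlattenedEvent, and ComplexEvent are illustrative names, not from the patch):

```python
from typing import ClassVar, Optional

from pydantic import BaseModel

from dlt.common.libs.pydantic import DltConfig, pydantic_to_table_schema_columns


class Location(BaseModel):  # illustrative models, not part of the patch
    city: str
    country: Optional[str] = None


class FlattenedEvent(BaseModel):
    location: Location
    dlt_config: ClassVar[DltConfig] = {"skip_complex_types": True}


class ComplexEvent(BaseModel):
    location: Location
    dlt_config: ClassVar[DltConfig] = {"skip_complex_types": False}


# sub-model fields are flattened into snake_case paths ...
assert "location__city" in pydantic_to_table_schema_columns(FlattenedEvent)
# ... or kept as one "complex" column when complex types are not skipped
assert pydantic_to_table_schema_columns(ComplexEvent)["location"]["data_type"] == "complex"
```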
diff --git a/tests/pipeline/test_pipeline_extra.py b/tests/pipeline/test_pipeline_extra.py
index 856e716134..81c883c273 100644
--- a/tests/pipeline/test_pipeline_extra.py
+++ b/tests/pipeline/test_pipeline_extra.py
@@ -226,3 +226,163 @@ def generic(start=8):
 
     pipeline = dlt.pipeline(destination="duckdb")
     pipeline.run(generic(), loader_file_format=file_format)
+
+
+class Child(BaseModel):
+    child_attribute: str
+    optional_child_attribute: Optional[str] = None
+
+
+def test_flattens_model_when_skip_complex_types_is_set() -> None:
+    class Parent(BaseModel):
+        child: Child
+        optional_parent_attribute: Optional[str] = None
+        dlt_config: ClassVar[DltConfig] = {"skip_complex_types": True}
+
+    example_data = {
+        "optional_parent_attribute": None,
+        "child": {
+            "child_attribute": "any string",
+            "optional_child_attribute": None,
+        },
+    }
+
+    p = dlt.pipeline("example", destination="duckdb")
+    p.run([example_data], table_name="items", columns=Parent)
+
+    with p.sql_client() as client:
+        with client.execute_query("SELECT * FROM items") as cursor:
+            loaded_values = {
+                col[0]: val
+                for val, col in zip(cursor.fetchall()[0], cursor.description)
+                if col[0] not in ("_dlt_id", "_dlt_load_id")
+            }
+
+    # Check if child dictionary is flattened and added to schema
+    assert loaded_values == {
+        "child__child_attribute": "any string",
+        "child__optional_child_attribute": None,
+        "optional_parent_attribute": None,
+    }
+
+    keys = p.default_schema.tables["items"]["columns"].keys()
+    columns = p.default_schema.tables["items"]["columns"]
+
+    assert keys == {
+        "child__child_attribute",
+        "child__optional_child_attribute",
+        "optional_parent_attribute",
+        "_dlt_load_id",
+        "_dlt_id",
+    }
+
+    assert columns["child__child_attribute"] == {
+        "name": "child__child_attribute",
+        "data_type": "text",
+        "nullable": False,
+    }
+
+    assert columns["child__optional_child_attribute"] == {
+        "name": "child__optional_child_attribute",
+        "data_type": "text",
+        "nullable": True,
+    }
+
+    assert columns["optional_parent_attribute"] == {
+        "name": "optional_parent_attribute",
+        "data_type": "text",
+        "nullable": True,
+    }
+
+
+def test_considers_model_as_complex_when_skip_complex_types_is_not_set():
+    class Parent(BaseModel):
+        child: Child
+        optional_parent_attribute: Optional[str] = None
+        data_dictionary: Dict[str, Any] = None
+        dlt_config: ClassVar[DltConfig] = {"skip_complex_types": False}
+
+    example_data = {
+        "optional_parent_attribute": None,
+        "data_dictionary": {
+            "child_attribute": "any string",
+        },
+        "child": {
+            "child_attribute": "any string",
+            "optional_child_attribute": None,
+        },
+    }
+
+    p = dlt.pipeline("example", destination="duckdb")
+    p.run([example_data], table_name="items", columns=Parent)
+
+    with p.sql_client() as client:
+        with client.execute_query("SELECT * FROM items") as cursor:
+            loaded_values = {
+                col[0]: val
+                for val, col in zip(cursor.fetchall()[0], cursor.description)
+                if col[0] not in ("_dlt_id", "_dlt_load_id")
+            }
+
+    # Check if complex fields preserved
+    # their contents and were not flattened
+    assert loaded_values == {
+        "child": '{"child_attribute":"any string","optional_child_attribute":null}',
+        "optional_parent_attribute": None,
+        "data_dictionary": '{"child_attribute":"any string"}',
+    }
+
+    keys = p.default_schema.tables["items"]["columns"].keys()
+    assert keys == {
+        "child",
+        "optional_parent_attribute",
+        "data_dictionary",
+        "_dlt_load_id",
+        "_dlt_id",
+    }
+
+    columns = p.default_schema.tables["items"]["columns"]
+
+    assert columns["optional_parent_attribute"] == {
+        "name": "optional_parent_attribute",
+        "data_type": "text",
+        "nullable": True,
+    }
+
+    assert columns["data_dictionary"] == {
+        "name": "data_dictionary",
+        "data_type": "complex",
+        "nullable": False,
+    }
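The final test below exercises a subtle interaction: with skip_complex_types enabled, plain list and dict fields are dropped from the column hints, yet the relational normalizer still unpacks list items into a child table (items__data_list). A hypothetical follow-up query against that child table (treating the "value" column name for scalar list items as an assumption, since the patch does not assert it):

```python
# assumes `p` is the pipeline from the test below, after p.run(...) completed
with p.sql_client() as client:
    # "value" is assumed to be the column dlt uses for lists of scalars
    with client.execute_query("SELECT value FROM items__data_list") as cursor:
        values = [row[0] for row in cursor.fetchall()]

# expect the five integers from example_data["data_list"]
print(values)
```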
+
+
+def test_skips_complex_fields_when_skip_complex_types_is_true_and_field_is_not_a_pydantic_model():
+    class Parent(BaseModel):
+        data_list: List[int] = []
+        data_dictionary: Dict[str, Any] = None
+        dlt_config: ClassVar[DltConfig] = {"skip_complex_types": True}
+
+    example_data = {
+        "optional_parent_attribute": None,
+        "data_list": [12, 12, 23, 23, 45],
+        "data_dictionary": {
+            "child_attribute": "any string",
+        },
+    }
+
+    p = dlt.pipeline("example", destination="duckdb")
+    p.run([example_data], table_name="items", columns=Parent)
+
+    table_names = [item["name"] for item in p.default_schema.data_tables()]
+    assert "items__data_list" in table_names
+
+    # But `data_list` and `data_dictionary` will be loaded
+    with p.sql_client() as client:
+        with client.execute_query("SELECT * FROM items") as cursor:
+            loaded_values = {
+                col[0]: val
+                for val, col in zip(cursor.fetchall()[0], cursor.description)
+                if col[0] not in ("_dlt_id", "_dlt_load_id")
+            }
+
+    assert loaded_values == {"data_dictionary__child_attribute": "any string"}
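Taken together, the pipeline tests reduce to this end-to-end behavior. A condensed, runnable sketch assuming the duckdb destination is available (pipeline and table names are illustrative):

```python
from typing import ClassVar, Optional

import dlt
from pydantic import BaseModel

from dlt.common.libs.pydantic import DltConfig


class Child(BaseModel):
    child_attribute: str
    optional_child_attribute: Optional[str] = None


class Parent(BaseModel):
    child: Child
    dlt_config: ClassVar[DltConfig] = {"skip_complex_types": True}


pipeline = dlt.pipeline("pydantic_demo", destination="duckdb")
pipeline.run([{"child": {"child_attribute": "any string"}}], table_name="items", columns=Parent)

# the nested model arrives flattened, e.g. child__child_attribute
print(pipeline.default_schema.tables["items"]["columns"].keys())
```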