Skip to content

Commit

Permalink
Pydantic improvements (#901)
Browse files Browse the repository at this point in the history
* add example tests

* Add sub-model hints for pydantic

* Use snake_case naming convention to generate field names

* Detect pydantic model in try..catch block

---------

Co-authored-by: Dave <[email protected]>
  • Loading branch information
sultaniman and sh-rp authored Feb 7, 2024
1 parent f04032d commit 754e508
Show file tree
Hide file tree
Showing 7 changed files with 302 additions and 22 deletions.
6 changes: 3 additions & 3 deletions .github/workflows/lint.yml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ jobs:
defaults:
run:
shell: bash
runs-on: ${{ matrix.os }}
runs-on: ${{ matrix.os }}

steps:

Expand All @@ -42,7 +42,7 @@ jobs:
with:
virtualenvs-create: true
virtualenvs-in-project: true
installer-parallel: true
installer-parallel: true

- name: Load cached venv
id: cached-poetry-dependencies
Expand All @@ -57,7 +57,7 @@ jobs:

- name: Run make lint
run: |
export PATH=$PATH:"/c/Program Files/usr/bin" # needed for Windows
export PATH=$PATH:"/c/Program Files/usr/bin" # needed for Windows
make lint
# - name: print envs
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/test_destination_athena_iceberg.yml
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ jobs:

- name: Install dependencies
# if: steps.cached-poetry-dependencies.outputs.cache-hit != 'true'
run: poetry install --no-interaction -E --with sentry-sdk --with pipeline
run: poetry install --no-interaction -E --with sentry-sdk --with pipeline

- name: create secrets.toml
run: pwd && echo "$DLT_SECRETS_TOML" > tests/.dlt/secrets.toml
Expand Down
56 changes: 42 additions & 14 deletions dlt/common/libs/pydantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,11 @@
)
from typing_extensions import Annotated, get_args, get_origin

from dlt.common.data_types import py_type_to_sc_type
from dlt.common.exceptions import MissingDependencyException
from dlt.common.schema import DataValidationError
from dlt.common.schema.typing import TSchemaEvolutionMode, TTableSchemaColumns
from dlt.common.data_types import py_type_to_sc_type
from dlt.common.normalizers.naming.snake_case import NamingConvention as SnakeCaseNamingConvention
from dlt.common.typing import (
TDataItem,
TDataItems,
Expand Down Expand Up @@ -52,6 +53,9 @@
_TPydanticModel = TypeVar("_TPydanticModel", bound=BaseModel)


snake_case_naming_convention = SnakeCaseNamingConvention()


class ListModel(BaseModel, Generic[_TPydanticModel]):
items: List[_TPydanticModel]

Expand All @@ -71,7 +75,7 @@ class DltConfig(TypedDict, total=False):


def pydantic_to_table_schema_columns(
model: Union[BaseModel, Type[BaseModel]]
model: Union[BaseModel, Type[BaseModel]],
) -> TTableSchemaColumns:
"""Convert a pydantic model to a table schema columns dict
Expand Down Expand Up @@ -111,24 +115,47 @@ def pydantic_to_table_schema_columns(

if is_list_generic_type(inner_type):
inner_type = list
elif is_dict_generic_type(inner_type) or issubclass(inner_type, BaseModel):
elif is_dict_generic_type(inner_type):
inner_type = dict

is_inner_type_pydantic_model = False
name = field.alias or field_name
try:
data_type = py_type_to_sc_type(inner_type)
except TypeError:
# try to coerce unknown type to text
data_type = "text"

if data_type == "complex" and skip_complex_types:
if issubclass(inner_type, BaseModel):
data_type = "complex"
is_inner_type_pydantic_model = True
else:
# try to coerce unknown type to text
data_type = "text"

if is_inner_type_pydantic_model and not skip_complex_types:
result[name] = {
"name": name,
"data_type": "complex",
"nullable": nullable,
}
elif is_inner_type_pydantic_model:
# This case is for a single field schema/model
# we need to generate snake_case field names
# and return flattened field schemas
schema_hints = pydantic_to_table_schema_columns(field.annotation)

for field_name, hints in schema_hints.items():
schema_key = snake_case_naming_convention.make_path(name, field_name)
result[schema_key] = {
**hints,
"name": snake_case_naming_convention.make_path(name, hints["name"]),
}
elif data_type == "complex" and skip_complex_types:
continue

result[name] = {
"name": name,
"data_type": data_type,
"nullable": nullable,
}
else:
result[name] = {
"name": name,
"data_type": data_type,
"nullable": nullable,
}

return result

Expand Down Expand Up @@ -261,7 +288,8 @@ def create_list_model(
# TODO: use LenientList to create list model that automatically discards invalid items
# https://github.com/pydantic/pydantic/issues/2274 and https://gist.github.com/dmontagu/7f0cef76e5e0e04198dd608ad7219573
return create_model(
"List" + __name__, items=(List[model], ...) # type: ignore[return-value,valid-type]
"List" + __name__,
items=(List[model], ...), # type: ignore[return-value,valid-type]
)


Expand Down
1 change: 1 addition & 0 deletions dlt/common/schema/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
Optional,
Sequence,
Set,
Tuple,
Type,
TypedDict,
NewType,
Expand Down
6 changes: 5 additions & 1 deletion dlt/extract/hints.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,11 @@ def compute_table_schema(self, item: TDataItem = None) -> TTableSchema:
if self._table_name_hint_fun and item is None:
raise DataItemRequiredForDynamicTableHints(self.name)
# resolve
resolved_template: TResourceHints = {k: self._resolve_hint(item, v) for k, v in table_template.items() if k not in ["incremental", "validator", "original_columns"]} # type: ignore
resolved_template: TResourceHints = {
k: self._resolve_hint(item, v)
for k, v in table_template.items()
if k not in ["incremental", "validator", "original_columns"]
} # type: ignore
table_schema = self._merge_keys(resolved_template)
table_schema["resource"] = self.name
validate_dict_ignoring_xkeys(
Expand Down
93 changes: 90 additions & 3 deletions tests/libs/test_pydantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
Union,
Optional,
List,
Dict,
Any,
)
from typing_extensions import Annotated, get_args, get_origin
Expand Down Expand Up @@ -261,9 +260,11 @@ class User(BaseModel):
# extra is modified
assert model_freeze.__fields__["address"].annotation.__name__ == "UserAddressExtraAllow" # type: ignore[index]
# annotated is preserved
assert issubclass(get_origin(model_freeze.__fields__["address"].rebuild_annotation()), Annotated) # type: ignore[arg-type, index]
type_origin = get_origin(model_freeze.__fields__["address"].rebuild_annotation()) # type: ignore[index]
assert issubclass(type_origin, Annotated) # type: ignore[arg-type]
# UserAddress is converted to UserAddressAllow only once
assert model_freeze.__fields__["address"].annotation is get_args(model_freeze.__fields__["unity"].annotation)[0] # type: ignore[index]
type_annotation = model_freeze.__fields__["address"].annotation # type: ignore[index]
assert type_annotation is get_args(model_freeze.__fields__["unity"].annotation)[0] # type: ignore[index]

# print(User.__fields__)
# print(User.__fields__["name"].annotation)
Expand Down Expand Up @@ -480,3 +481,89 @@ class ItemModel(BaseModel):
validate_item("items", mixed_model, {"b": False, "a": False}, "discard_row", "evolve")
is None
)


class ChildModel(BaseModel):
child_attribute: str
optional_child_attribute: Optional[str] = None


class Parent(BaseModel):
child: ChildModel
optional_parent_attribute: Optional[str] = None


def test_pydantic_model_flattened_when_skip_complex_types_is_true():
class MyParent(Parent):
dlt_config: ClassVar[DltConfig] = {"skip_complex_types": True}

schema = pydantic_to_table_schema_columns(MyParent)

assert schema == {
"child__child_attribute": {
"data_type": "text",
"name": "child__child_attribute",
"nullable": False,
},
"child__optional_child_attribute": {
"data_type": "text",
"name": "child__optional_child_attribute",
"nullable": True,
},
"optional_parent_attribute": {
"data_type": "text",
"name": "optional_parent_attribute",
"nullable": True,
},
}


def test_considers_model_as_complex_when_skip_complex_types_is_false():
class MyParent(Parent):
data_dictionary: Dict[str, Any] = None
dlt_config: ClassVar[DltConfig] = {"skip_complex_types": False}

schema = pydantic_to_table_schema_columns(MyParent)

assert schema == {
"child": {"data_type": "complex", "name": "child", "nullable": False},
"data_dictionary": {"data_type": "complex", "name": "data_dictionary", "nullable": False},
"optional_parent_attribute": {
"data_type": "text",
"name": "optional_parent_attribute",
"nullable": True,
},
}


def test_considers_dictionary_as_complex_when_skip_complex_types_is_false():
class MyParent(Parent):
data_list: List[str] = []
data_dictionary: Dict[str, Any] = None
dlt_config: ClassVar[DltConfig] = {"skip_complex_types": False}

schema = pydantic_to_table_schema_columns(MyParent)

assert schema["data_dictionary"] == {
"data_type": "complex",
"name": "data_dictionary",
"nullable": False,
}

assert schema["data_list"] == {
"data_type": "complex",
"name": "data_list",
"nullable": False,
}


def test_skip_complex_types_when_skip_complex_types_is_true_and_field_is_not_pydantic_model():
class MyParent(Parent):
data_list: List[str] = []
data_dictionary: Dict[str, Any] = None
dlt_config: ClassVar[DltConfig] = {"skip_complex_types": True}

schema = pydantic_to_table_schema_columns(MyParent)

assert "data_dictionary" not in schema
assert "data_list" not in schema
Loading

0 comments on commit 754e508

Please sign in to comment.