
Commit

Change structtype_column.dtype from class to object, update dependency pyright to v1.1.348 (#281)

* Update dependency pyright to v1.1.348

* suppress the Annotated error, since it's only affecting internal function calls

* fix line length issues

* fix ambiguous error

* style

---------

Co-authored-by: renovate[bot] <29139614+renovate[bot]@users.noreply.github.com>
Co-authored-by: nanne-aben <[email protected]>
renovate[bot] and nanne-aben authored Jan 24, 2024
1 parent 3790c61 commit d68286b
Showing 5 changed files with 31 additions and 15 deletions.
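
In short, Column.dtype now returns a DataType instance rather than the DataType class itself. Below is a minimal before/after sketch with a hypothetical schema (the MySchema name and its field are illustrative, not part of this commit):

    from pyspark.sql.types import LongType

    from typedspark import Column, Schema


    class MySchema(Schema):  # hypothetical schema for illustration
        id: Column[LongType]


    # Before this commit: MySchema.id.dtype == LongType (the class itself).
    # After this commit, dtype is an instance, so compare with isinstance:
    assert isinstance(MySchema.id.dtype, LongType)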
2 changes: 1 addition & 1 deletion requirements-dev.txt
@@ -8,7 +8,7 @@ black[jupyter]==23.12.1
 isort==5.13.2
 docformatter==1.7.5
 mypy==1.8.0
-pyright==1.1.347
+pyright==1.1.348
 autoflake==2.2.1
 # stubs
 pandas-stubs==2.1.4.231227
6 changes: 4 additions & 2 deletions tests/_schema/test_schema.py
@@ -160,9 +160,11 @@ def test_get_schema(schema: Type[Schema], expected_schema_definition: str):


 def test_dtype_attributes(spark: SparkSession):
-    assert ComplexDatatypes.value.dtype == typedspark.StructType[Values]
+    assert isinstance(A.a.dtype, LongType)
+    assert isinstance(ComplexDatatypes.items.dtype, typedspark.ArrayType)
+    assert isinstance(ComplexDatatypes.value.dtype, typedspark.StructType)
     assert ComplexDatatypes.value.dtype.schema == Values
-    assert ComplexDatatypes.value.dtype.schema.b.dtype == StringType
+    assert isinstance(ComplexDatatypes.value.dtype.schema.b.dtype, StringType)
 
     df = create_partially_filled_dataset(
         spark,
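
The nested case follows the same rule, sketched here with this test's fixtures (ComplexDatatypes and Values): a struct column's dtype is a typedspark.StructType instance that carries its inner schema, and the leaf dtypes inside that schema are instances too.

    # Sketch using the fixtures above; every .dtype in the chain is an instance.
    struct_dtype = ComplexDatatypes.value.dtype  # typedspark.StructType instance
    inner_schema = struct_dtype.schema           # the Values schema class
    assert isinstance(inner_schema.b.dtype, StringType)  # leaf dtype: instance, not class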
12 changes: 8 additions & 4 deletions tests/_schema/test_structfield.py
@@ -27,7 +27,10 @@ def test_get_structfield_dtype(type_hints):
     assert _get_structfield_dtype(Column[LongType], "a") == LongType()
     assert _get_structfield_dtype(type_hints["b"], "b") == StringType()
     assert (
-        _get_structfield_dtype(Annotated[Column[StringType], ColumnMeta(comment="comment")], "c")
+        _get_structfield_dtype(
+            Annotated[Column[StringType], ColumnMeta(comment="comment")],  # type: ignore
+            "c",
+        )
         == StringType()
     )
     assert _get_structfield_dtype(type_hints["d"], "d") == BooleanType()
@@ -37,16 +40,17 @@ def test_get_structfield_metadata(type_hints):
     assert get_structfield_meta(Column[LongType]) == ColumnMeta()
     assert get_structfield_meta(type_hints["b"]) == ColumnMeta()
     assert get_structfield_meta(
-        Annotated[Column[StringType], ColumnMeta(comment="comment")]
+        Annotated[Column[StringType], ColumnMeta(comment="comment")]  # type: ignore
     ) == ColumnMeta(comment="comment")
     assert get_structfield_meta(type_hints["d"]) == ColumnMeta(comment="comment2")
 
 
 def test_get_structfield(type_hints):
     assert get_structfield("a", Column[LongType]) == StructField(name="a", dataType=LongType())
     assert get_structfield("b", type_hints["b"]) == StructField(name="b", dataType=StringType())
-    assert get_structfield(
-        "c", Annotated[Column[StringType], ColumnMeta(comment="comment")]
+    assert get_structfield(  # type: ignore
+        "c",
+        Annotated[Column[StringType], ColumnMeta(comment="comment")],  # type: ignore
     ) == StructField(name="c", dataType=StringType(), metadata={"comment": "comment"})
     assert get_structfield("d", type_hints["d"]) == StructField(
         name="d", dataType=BooleanType(), metadata={"comment": "comment2"}
13 changes: 8 additions & 5 deletions typedspark/_core/column.py
@@ -1,7 +1,7 @@
"""Module containing classes and functions related to TypedSpark Columns."""

from logging import warn
from typing import Generic, Optional, Type, TypeVar, Union, get_args, get_origin
from typing import Generic, Optional, TypeVar, Union, get_args, get_origin

from pyspark.sql import Column as SparkColumn
from pyspark.sql import DataFrame, SparkSession
@@ -84,11 +84,14 @@ def __hash__(self) -> int:
         return hash((self.str, self._curid))
 
     @property
-    def dtype(self) -> Type[T]:
+    def dtype(self) -> T:
         """Get the datatype of the column, e.g. Column[IntegerType] -> IntegerType."""
         dtype = self._dtype
 
         if get_origin(dtype) == StructType:
-            dtype.schema = get_args(dtype)[0]  # type: ignore
-            dtype.schema._parent = self  # type: ignore
+            return StructType(
+                schema=get_args(dtype)[0],
+                parent=self,
+            )  # type: ignore
 
-        return dtype  # type: ignore
+        return dtype()  # type: ignore
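
The practical effect, sketched below with the fixtures from this commit's tests: struct columns now get a fresh StructType object on every dtype access, instead of a generic alias whose class attributes (schema, _parent) were mutated in place, and plain dtypes are instantiated.

    # Assumed usage, mirroring tests/_schema/test_schema.py in this commit:
    first = ComplexDatatypes.value.dtype
    second = ComplexDatatypes.value.dtype
    assert first is not second                      # a new StructType per access
    assert first.schema == second.schema == Values  # both refer to the same Schema

    assert isinstance(A.a.dtype, LongType)          # plain dtype: instance, not class

Returning a fresh object also avoids writing shared state onto the generic alias itself, which the two deleted assignments above used to do.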
13 changes: 10 additions & 3 deletions typedspark/_core/datatypes.py
@@ -1,12 +1,13 @@
"""Here, we make our own definitions of ``MapType``, ``ArrayType`` and
``StructType`` in order to allow e.g. for ``ArrayType[StringType]``."""
"""Here, we make our own definitions of ``MapType``, ``ArrayType`` and ``StructType`` in
order to allow e.g. for ``ArrayType[StringType]``."""
from __future__ import annotations

from typing import TYPE_CHECKING, Any, Dict, Generic, Type, TypeVar

from pyspark.sql.types import DataType

if TYPE_CHECKING: # pragma: no cover
from typedspark._core.column import Column
from typedspark._schema.schema import Schema

_Schema = TypeVar("_Schema", bound=Schema)
@@ -55,7 +56,13 @@ class Person(Schema):
         job: Column[StructType[Job]]
     """
 
-    schema: Type[_Schema]
+    def __init__(
+        self,
+        schema: Type[_Schema],
+        parent: Column,
+    ) -> None:
+        self.schema = schema
+        self.schema._parent = parent
 
 
 class MapType(Generic[_KeyType, _ValueType], TypedSparkDataType):
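
A sketch of constructing the new StructType directly, reusing the Values and ComplexDatatypes fixtures from the test suite; in practice Column.dtype builds this object internally, so direct construction here is purely illustrative.

    # Hypothetical direct construction; Column.dtype is the intended entry point.
    struct_dtype = StructType(
        schema=Values,                  # a Schema subclass
        parent=ComplexDatatypes.value,  # the Column that owns this struct
    )
    assert struct_dtype.schema == Values
    # __init__ also sets Values._parent to the parent Column (see the diff above).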
