Skip to content

Commit

Permalink
python primitive type updates
Browse files Browse the repository at this point in the history
Signed-off-by: Clemens Vasters <[email protected]>
  • Loading branch information
clemensv committed Sep 23, 2024
1 parent d757f45 commit 36ad6ea
Show file tree
Hide file tree
Showing 5 changed files with 456 additions and 21 deletions.
28 changes: 14 additions & 14 deletions avrotize/avrotopython.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,25 +114,25 @@ def convert_logical_type_to_python(self, avro_type: Dict, import_types: Set[str]
"""Converts Avro logical type to Python type"""
if avro_type['logicalType'] == 'decimal':
import_types.add('decimal.Decimal')
return 'Decimal'
return 'decimal.Decimal'
elif avro_type['logicalType'] == 'date':
import_types.add('datetime.date')
return 'date'
return 'datetime.date'
elif avro_type['logicalType'] == 'time-millis':
import_types.add('datetime.time')
return 'time'
return 'datetime.time'
elif avro_type['logicalType'] == 'time-micros':
import_types.add('datetime.time')
return 'time'
return 'datetime.time'
elif avro_type['logicalType'] == 'timestamp-millis':
import_types.add('datetime.datetime')
return 'datetime'
return 'datetime.datetime'
elif avro_type['logicalType'] == 'timestamp-micros':
import_types.add('datetime.datetime')
return 'datetime'
return 'datetime.datetime'
elif avro_type['logicalType'] == 'duration':
import_types.add('datetime.timedelta')
return 'timedelta'
return 'datetime.timedelta'
return 'typing.Any'

def convert_avro_type_to_python(self, avro_type: Union[str, Dict, List], parent_package: str, import_types: set) -> str:
Expand Down Expand Up @@ -180,9 +180,9 @@ def init_field_value(self, field_type: str, field_name: str, field_is_enum: bool
""" Initialize the field value based on its type. """
if field_type == "typing.Any":
return field_ref
elif field_type in ['datetime', 'date', 'time', 'timedelta']:
elif field_type in ['datetime.datetime', 'datetime.date', 'datetime.time', 'datetime.timedelta']:
return f"{field_ref}"
elif field_type in ['int', 'str', 'float', 'bool', 'bytes', 'Decimal', 'datetime', 'date', 'time', 'timedelta']:
elif field_type in ['int', 'str', 'float', 'bool', 'bytes', 'Decimal']:
return f"{field_type}({field_ref})"
elif field_type.startswith("typing.List["):
inner_type = get_typing_args_from_string(field_type)[0]
Expand Down Expand Up @@ -373,11 +373,11 @@ def generate_value(field_type: str):
'float': f'float({random.uniform(0, 100)})',
'bytes': 'b"test_bytes"',
'None': 'None',
'date': random.choice(['datetime.date.today()', 'datetime.date(2021, 1, 1)']),
'datetime': 'datetime.datetime.now()',
'time': 'datetime.datetime.now().time()',
'Decimal': f'Decimal("{random.randint(0, 100)}.{random.randint(0, 100)}")',
'timedelta': 'datetime.timedelta(days=1)',
'datetime.date': random.choice(['datetime.date.today()', 'datetime.date(2021, 1, 1)']),
'datetime.datetime': 'datetime.datetime.now(datetime.timezone.utc)',
'datetime.time': 'datetime.datetime.now(datetime.timezone.utc).time()',
'decimal.Decimal': f'decimal.Decimal("{random.randint(0, 100)}.{random.randint(0, 100)}")',
'datetime.timedelta': 'datetime.timedelta(days=1)',
'typing.Any': '{"test": "test"}'
}

Expand Down
19 changes: 16 additions & 3 deletions avrotize/avrotopython/dataclass_core.jinja
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,24 @@ import dataclasses
{%- if dataclasses_json_annotation %}
import dataclasses_json
import json
{%- for field in fields if field.type == "datetime" or field.type == "typing.Optional[datetime.datetime]" %}
{%- if loop.first %}
from marshmallow import fields
{%- endif %}
{%- endfor %}
{%- endif %}
{%- if avro_annotation %}
import avro.schema
import avro.io
{%- endif %}
{%- for import_type in import_types %}
{%- for import_type in import_types if import_type not in ['datetime.datetime', 'datetime.date', 'datetime.time', 'datetime.timedelta', 'decimal.Decimal'] %}
from {{ '.'.join(import_type.split('.')[:-1]) | lower }} import {{ import_type.split('.')[-1] }}
{%- endfor %}
{%- for import_type in import_types if import_type in ['datetime.datetime', 'datetime.date', 'datetime.time', 'datetime.timedelta'] %}
{%- if loop.first %}
import datetime
{%- endif %}
{%- endfor %}

{% if dataclasses_json_annotation %}
@dataclasses_json.dataclass_json
Expand All @@ -34,8 +44,9 @@ class {{ class_name }}:
{%- endfor -%}
"""
{% for field in fields %}
{{ field.name }}: {{ field.type }}=dataclasses.field(kw_only=True{% if dataclasses_json_annotation %}, metadata=dataclasses_json.config(field_name="{{ field.original_name }}"){%- endif %})
{%- endfor %}
{%- set isdate = field.type == "datetime" or field.type == "typing.Optional[datetime.datetime]" %}
{{ field.name }}: {{ field.type }}=dataclasses.field(kw_only=True{% if dataclasses_json_annotation %}, metadata=dataclasses_json.config(field_name="{{ field.original_name }}"{%- if isdate -%}, encoder=lambda d: datetime.datetime.isoformat(d) if d else None, decoder=lambda d:datetime.datetime.fromisoformat(d) if d else None, mm_field=fields.DateTime(format='iso'){%- endif -%}){%- endif %})
{%- endfor %}
{% if avro_annotation %}
AvroType: typing.ClassVar[avro.schema.Schema] = avro.schema.parse(
"{{ avro_schema_json }}"
Expand Down Expand Up @@ -128,7 +139,9 @@ class {{ class_name }}:

{%- if dataclasses_json_annotation %}
if content_type == 'application/json':
#pylint: disable=no-member
result = self.to_json()
#pylint: enable=no-member
{%- endif %}

if result is not None and content_type.endswith('+gzip'):
Expand Down
8 changes: 7 additions & 1 deletion avrotize/avrotopython/test_class.jinja
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ sys.path.append(os.path.realpath(os.path.join(os.path.dirname(__file__), '../src

from {{ package_name | lower }} import {{ class_name }}

{%- for import_type in import_types %}
{%- for import_type in import_types if import_type not in ['decimal.Decimal', 'datetime.datetime', 'datetime.date', 'datetime.time', 'datetime.timedelta'] %}
{%- set import_type_name = 'Test_'+import_type.split('.')[-1] %}
{%- set import_package_name = 'test_'+'_'.join(import_type.split('.')[:-1]) | lower %}

Expand All @@ -20,6 +20,12 @@ from .{{ import_package_name }} import {{ import_type_name }}
from {{ import_package_name }} import {{ import_type_name }}
{%- endif -%}
{%- endfor %}
{%- for import_type in import_types if import_type in ['datetime.datetime', 'datetime.date', 'datetime.time', 'datetime.timedelta'] %}
{%- if loop.first %}
import datetime
{%- endif %}
{%- endfor %}


class {{ test_class_name }}(unittest.TestCase):
"""
Expand Down
Loading

0 comments on commit 36ad6ea

Please sign in to comment.