From 46e662466927f2b1065eadd4f15f14668475e128 Mon Sep 17 00:00:00 2001 From: Mauko Quiroga Date: Mon, 16 Sep 2024 16:28:24 +0200 Subject: [PATCH] fix(simulations): add input data types --- .../simulations/_build_default_simulation.py | 6 +- .../simulations/_build_from_variables.py | 6 +- openfisca_core/simulations/simulation.py | 25 ++++-- .../simulations/simulation_builder.py | 38 +++++----- openfisca_core/simulations/types.py | 76 ++++++++++++------- 5 files changed, 92 insertions(+), 59 deletions(-) diff --git a/openfisca_core/simulations/_build_default_simulation.py b/openfisca_core/simulations/_build_default_simulation.py index f8828ff59..780dc9d49 100644 --- a/openfisca_core/simulations/_build_default_simulation.py +++ b/openfisca_core/simulations/_build_default_simulation.py @@ -1,13 +1,11 @@ """This module contains the _BuildDefaultSimulation class.""" -from typing_extensions import Self, TypeAlias +from typing_extensions import Self import numpy from .simulation import Simulation -from .types import CoreEntity, GroupPopulation, TaxBenefitSystem - -Populations: TypeAlias = dict[str, GroupPopulation[CoreEntity]] +from .types import Populations, TaxBenefitSystem class _BuildDefaultSimulation: diff --git a/openfisca_core/simulations/_build_from_variables.py b/openfisca_core/simulations/_build_from_variables.py index 4721ebca6..152e73846 100644 --- a/openfisca_core/simulations/_build_from_variables.py +++ b/openfisca_core/simulations/_build_from_variables.py @@ -3,16 +3,14 @@ from __future__ import annotations from collections.abc import Sized -from typing_extensions import Self, TypeAlias +from typing_extensions import Self from openfisca_core import errors from ._build_default_simulation import _BuildDefaultSimulation from ._guards import is_variable_dated from .simulation import Simulation -from .types import CoreEntity, GroupPopulation, TaxBenefitSystem, Variables - -Populations: TypeAlias = dict[str, GroupPopulation[CoreEntity]] +from .types import Populations, TaxBenefitSystem, Variables class _BuildFromVariables: diff --git a/openfisca_core/simulations/simulation.py b/openfisca_core/simulations/simulation.py index 084331644..136b960fd 100644 --- a/openfisca_core/simulations/simulation.py +++ b/openfisca_core/simulations/simulation.py @@ -10,7 +10,14 @@ from openfisca_core import commons, errors, indexed_enums, periods, tracers from openfisca_core import warnings as core_warnings -from .types import GroupPopulation, TaxBenefitSystem, Variable +from .types import ( + EntityPlural, + GroupEntity, + GroupPopulation, + Populations, + TaxBenefitSystem, + Variable, +) class Simulation: @@ -19,13 +26,13 @@ class Simulation: """ tax_benefit_system: TaxBenefitSystem - populations: dict[str, GroupPopulation] + populations: Populations invalidated_caches: Set[Cache] def __init__( self, tax_benefit_system: TaxBenefitSystem, - populations: dict[str, GroupPopulation], + populations: Populations, ): """ This constructor is reserved for internal use; see :any:`SimulationBuilder`, @@ -555,10 +562,14 @@ def get_population(self, plural: Optional[str] = None) -> Optional[GroupPopulati def get_entity( self, - plural: Optional[str] = None, - ) -> Optional[GroupPopulation]: - population = self.get_population(plural) - return population and population.entity + plural: EntityPlural | None = None, + ) -> GroupEntity | None: + population: GroupPopulation | None = self.get_population(plural) + + if population is None: + return None + + return population.entity def describe_entities(self): return { diff --git a/openfisca_core/simulations/simulation_builder.py b/openfisca_core/simulations/simulation_builder.py index 7b5860082..f1a2b71b8 100644 --- a/openfisca_core/simulations/simulation_builder.py +++ b/openfisca_core/simulations/simulation_builder.py @@ -1,15 +1,13 @@ from __future__ import annotations from collections.abc import Iterable, Sequence -from numpy.typing import NDArray as Array -from typing import Dict, List import copy import dpath.util import numpy -from openfisca_core import entities, errors, periods, populations, variables +from openfisca_core import errors, periods from . import helpers from ._build_default_simulation import _BuildDefaultSimulation @@ -22,23 +20,31 @@ ) from .simulation import Simulation from .types import ( + Array, Axis, + EntityCounts, + EntityIds, + EntityRoles, FullySpecifiedEntities, GroupEntities, GroupEntity, ImplicitGroupEntities, + InputBuffer, + Memberships, Params, ParamsWithoutAxes, + Populations, Role, SingleEntity, SinglePopulation, TaxBenefitSystem, + VariableEntity, Variables, ) class SimulationBuilder: - def __init__(self): + def __init__(self) -> None: self.default_period = ( None # Simulation period used for variables when no period is defined ) @@ -47,26 +53,24 @@ def __init__(self): ) # JSON input - Memory of known input values. Indexed by variable or axis name. - self.input_buffer: Dict[ - variables.Variable.name, Dict[str(periods.period), numpy.array] - ] = {} - self.populations: Dict[entities.Entity.key, populations.Population] = {} + self.input_buffer: InputBuffer = {} + self.populations: Populations = {} # JSON input - Number of items of each entity type. Indexed by entities plural names. Should be consistent with ``entity_ids``, including axes. - self.entity_counts: Dict[entities.Entity.plural, int] = {} + self.entity_counts: EntityCounts = {} # JSON input - List of items of each entity type. Indexed by entities plural names. Should be consistent with ``entity_counts``. - self.entity_ids: Dict[entities.Entity.plural, List[int]] = {} + self.entity_ids: EntityIds = {} # Links entities with persons. For each person index in persons ids list, set entity index in entity ids id. E.g.: self.memberships[entity.plural][person_index] = entity_ids.index(instance_id) - self.memberships: Dict[entities.Entity.plural, List[int]] = {} - self.roles: Dict[entities.Entity.plural, List[int]] = {} + self.memberships: Memberships = {} + self.roles: EntityRoles = {} - self.variable_entities: Dict[variables.Variable.name, entities.Entity] = {} + self.variable_entities: VariableEntity = {} self.axes = [[]] - self.axes_entity_counts: Dict[entities.Entity.plural, int] = {} - self.axes_entity_ids: Dict[entities.Entity.plural, List[int]] = {} - self.axes_memberships: Dict[entities.Entity.plural, List[int]] = {} - self.axes_roles: Dict[entities.Entity.plural, List[int]] = {} + self.axes_entity_counts: EntityCounts = {} + self.axes_entity_ids: EntityIds = {} + self.axes_memberships: Memberships = {} + self.axes_roles: EntityRoles = {} def build_from_dict( self, diff --git a/openfisca_core/simulations/types.py b/openfisca_core/simulations/types.py index 8317ad84e..cce6d12b8 100644 --- a/openfisca_core/simulations/types.py +++ b/openfisca_core/simulations/types.py @@ -3,11 +3,10 @@ from __future__ import annotations from collections.abc import Callable, Iterable, Sequence -from typing import Protocol, TypeVar, TypedDict, Union +from typing import NewType, Protocol, TypeVar, TypedDict, Union from typing_extensions import NotRequired, Required, TypeAlias import datetime -from abc import abstractmethod from numpy import bool_ as Bool from numpy import datetime64 as Date @@ -19,32 +18,36 @@ from openfisca_core import types as t # Generic type variables. -D = TypeVar("D") -E = TypeVar("E", covariant=True) G = TypeVar("G", covariant=True) T = TypeVar("T", Bool, Date, Enum, Float, Int, String, covariant=True) U = TypeVar("U", bool, datetime.date, float, str) V = TypeVar("V", covariant=True) +# New types. +PeriodStr = NewType("PeriodStr", str) +EntityKey = NewType("EntityKey", str) +EntityPlural = NewType("EntityPlural", str) +VariableName = NewType("VariableName", str) + +# Type aliases. + #: Type alias for numpy arrays values. Item: TypeAlias = Union[Bool, Date, Enum, Float, Int, String] +#: Type Alias for a numpy Array. +Array: TypeAlias = t.Array # Entities -#: Type alias for a simulation dictionary defining the roles. -Roles: TypeAlias = dict[str, Union[str, Iterable[str]]] - - class CoreEntity(t.CoreEntity, Protocol): - key: str - plural: str | None + key: EntityKey + plural: EntityPlural | None def get_variable( self, - __variable_name: str, - __check_existence: bool = ..., + __variable_name: VariableName, + check_existence: bool = ..., ) -> Variable[T] | None: ... @@ -55,7 +58,6 @@ class SingleEntity(t.SingleEntity, Protocol): class GroupEntity(t.GroupEntity, Protocol): @property - @abstractmethod def flattened_roles(self) -> Iterable[Role[G]]: ... @@ -69,11 +71,10 @@ class Role(t.Role, Protocol[G]): class Holder(t.Holder, Protocol[V]): @property - @abstractmethod def variable(self) -> Variable[T]: ... - def get_array(self, __period: str) -> t.Array[T] | None: + def get_array(self, __period: PeriodStr) -> t.Array[T] | None: ... def set_input( @@ -94,18 +95,19 @@ class Period(t.Period, Protocol): # Populations -class CorePopulation(t.CorePopulation, Protocol[D]): - entity: D +class CorePopulation(t.CorePopulation, Protocol): + entity: CoreEntity - def get_holder(self, __variable_name: str) -> Holder[V]: + def get_holder(self, __variable_name: VariableName) -> Holder[V]: ... -class SinglePopulation(t.SinglePopulation, Protocol[E]): - ... +class SinglePopulation(t.SinglePopulation, Protocol): + entity: SingleEntity -class GroupPopulation(t.GroupPopulation, Protocol[E]): +class GroupPopulation(t.GroupPopulation, Protocol): + entity: GroupEntity members_entity_id: t.Array[String] def nb_persons(self, __role: Role[G] | None = ...) -> int: @@ -114,6 +116,29 @@ def nb_persons(self, __role: Role[G] | None = ...) -> int: # Simulations +#: Dictionary with axes parameters per variable. +InputBuffer: TypeAlias = dict[VariableName, dict[PeriodStr, Array]] + +#: Dictionary with entity/population key/pais. +Populations: TypeAlias = dict[EntityKey, GroupPopulation] + +#: Dictionary with single entity count per group entity. +EntityCounts: TypeAlias = dict[EntityPlural, int] + +#: Dictionary with a list of single entities per group entity. +EntityIds: TypeAlias = dict[EntityPlural, Iterable[int]] + +#: Dictionary with a list of members per group entity. +Memberships: TypeAlias = dict[EntityPlural, Iterable[int]] + +#: Dictionary with a list of roles per group entity. +EntityRoles: TypeAlias = dict[EntityPlural, Iterable[int]] + +#: Dictionary with a map between variables and entities. +VariableEntity: TypeAlias = dict[VariableName, CoreEntity] + +#: Type alias for a simulation dictionary defining the roles. +Roles: TypeAlias = dict[str, Union[str, Iterable[str]]] #: Type alias for a simulation dictionary with undated variables. UndatedVariable: TypeAlias = dict[str, object] @@ -169,21 +194,18 @@ class Simulation(t.Simulation, Protocol): class TaxBenefitSystem(t.TaxBenefitSystem, Protocol): @property - @abstractmethod def person_entity(self) -> SingleEntity: ... @person_entity.setter - @abstractmethod def person_entity(self, person_entity: SingleEntity) -> None: ... @property - @abstractmethod def variables(self) -> dict[str, V]: ... - def entities_by_singular(self) -> dict[str, E]: + def entities_by_singular(self) -> dict[str, CoreEntity]: ... def entities_plural(self) -> Iterable[str]: @@ -198,7 +220,7 @@ def get_variable( def instantiate_entities( self, - ) -> dict[str, GroupPopulation[E]]: + ) -> Populations: ... @@ -209,7 +231,7 @@ class Variable(t.Variable, Protocol[T]): calculate_output: Callable[[Simulation, str, str], t.Array[T]] | None definition_period: str end: str - name: str + name: VariableName def default_array(self, __array_size: int) -> t.Array[T]: ...