Skip to content

Commit

Permalink
fix(simulations): add input data types
Browse files Browse the repository at this point in the history
  • Loading branch information
bonjourmauko committed Sep 16, 2024
1 parent c9b35dc commit 46e6624
Show file tree
Hide file tree
Showing 5 changed files with 92 additions and 59 deletions.
6 changes: 2 additions & 4 deletions openfisca_core/simulations/_build_default_simulation.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
"""This module contains the _BuildDefaultSimulation class."""

from typing_extensions import Self, TypeAlias
from typing_extensions import Self

import numpy

from .simulation import Simulation
from .types import CoreEntity, GroupPopulation, TaxBenefitSystem

Populations: TypeAlias = dict[str, GroupPopulation[CoreEntity]]
from .types import Populations, TaxBenefitSystem


class _BuildDefaultSimulation:
Expand Down
6 changes: 2 additions & 4 deletions openfisca_core/simulations/_build_from_variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,14 @@
from __future__ import annotations

from collections.abc import Sized
from typing_extensions import Self, TypeAlias
from typing_extensions import Self

from openfisca_core import errors

from ._build_default_simulation import _BuildDefaultSimulation
from ._guards import is_variable_dated
from .simulation import Simulation
from .types import CoreEntity, GroupPopulation, TaxBenefitSystem, Variables

Populations: TypeAlias = dict[str, GroupPopulation[CoreEntity]]
from .types import Populations, TaxBenefitSystem, Variables


class _BuildFromVariables:
Expand Down
25 changes: 18 additions & 7 deletions openfisca_core/simulations/simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,14 @@
from openfisca_core import commons, errors, indexed_enums, periods, tracers
from openfisca_core import warnings as core_warnings

from .types import GroupPopulation, TaxBenefitSystem, Variable
from .types import (
EntityPlural,
GroupEntity,
GroupPopulation,
Populations,
TaxBenefitSystem,
Variable,
)


class Simulation:
Expand All @@ -19,13 +26,13 @@ class Simulation:
"""

tax_benefit_system: TaxBenefitSystem
populations: dict[str, GroupPopulation]
populations: Populations
invalidated_caches: Set[Cache]

def __init__(
self,
tax_benefit_system: TaxBenefitSystem,
populations: dict[str, GroupPopulation],
populations: Populations,
):
"""
This constructor is reserved for internal use; see :any:`SimulationBuilder`,
Expand Down Expand Up @@ -555,10 +562,14 @@ def get_population(self, plural: Optional[str] = None) -> Optional[GroupPopulati

def get_entity(
self,
plural: Optional[str] = None,
) -> Optional[GroupPopulation]:
population = self.get_population(plural)
return population and population.entity
plural: EntityPlural | None = None,
) -> GroupEntity | None:
population: GroupPopulation | None = self.get_population(plural)

if population is None:
return None

return population.entity

def describe_entities(self):
return {
Expand Down
38 changes: 21 additions & 17 deletions openfisca_core/simulations/simulation_builder.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,13 @@
from __future__ import annotations

from collections.abc import Iterable, Sequence
from numpy.typing import NDArray as Array
from typing import Dict, List

import copy

import dpath.util
import numpy

from openfisca_core import entities, errors, periods, populations, variables
from openfisca_core import errors, periods

from . import helpers
from ._build_default_simulation import _BuildDefaultSimulation
Expand All @@ -22,23 +20,31 @@
)
from .simulation import Simulation
from .types import (
Array,
Axis,
EntityCounts,
EntityIds,
EntityRoles,
FullySpecifiedEntities,
GroupEntities,
GroupEntity,
ImplicitGroupEntities,
InputBuffer,
Memberships,
Params,
ParamsWithoutAxes,
Populations,
Role,
SingleEntity,
SinglePopulation,
TaxBenefitSystem,
VariableEntity,
Variables,
)


class SimulationBuilder:
def __init__(self):
def __init__(self) -> None:
self.default_period = (
None # Simulation period used for variables when no period is defined
)
Expand All @@ -47,26 +53,24 @@ def __init__(self):
)

# JSON input - Memory of known input values. Indexed by variable or axis name.
self.input_buffer: Dict[
variables.Variable.name, Dict[str(periods.period), numpy.array]
] = {}
self.populations: Dict[entities.Entity.key, populations.Population] = {}
self.input_buffer: InputBuffer = {}
self.populations: Populations = {}
# JSON input - Number of items of each entity type. Indexed by entities plural names. Should be consistent with ``entity_ids``, including axes.
self.entity_counts: Dict[entities.Entity.plural, int] = {}
self.entity_counts: EntityCounts = {}
# JSON input - List of items of each entity type. Indexed by entities plural names. Should be consistent with ``entity_counts``.
self.entity_ids: Dict[entities.Entity.plural, List[int]] = {}
self.entity_ids: EntityIds = {}

# Links entities with persons. For each person index in persons ids list, set entity index in entity ids id. E.g.: self.memberships[entity.plural][person_index] = entity_ids.index(instance_id)
self.memberships: Dict[entities.Entity.plural, List[int]] = {}
self.roles: Dict[entities.Entity.plural, List[int]] = {}
self.memberships: Memberships = {}
self.roles: EntityRoles = {}

self.variable_entities: Dict[variables.Variable.name, entities.Entity] = {}
self.variable_entities: VariableEntity = {}

self.axes = [[]]
self.axes_entity_counts: Dict[entities.Entity.plural, int] = {}
self.axes_entity_ids: Dict[entities.Entity.plural, List[int]] = {}
self.axes_memberships: Dict[entities.Entity.plural, List[int]] = {}
self.axes_roles: Dict[entities.Entity.plural, List[int]] = {}
self.axes_entity_counts: EntityCounts = {}
self.axes_entity_ids: EntityIds = {}
self.axes_memberships: Memberships = {}
self.axes_roles: EntityRoles = {}

def build_from_dict(
self,
Expand Down
76 changes: 49 additions & 27 deletions openfisca_core/simulations/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,10 @@
from __future__ import annotations

from collections.abc import Callable, Iterable, Sequence
from typing import Protocol, TypeVar, TypedDict, Union
from typing import NewType, Protocol, TypeVar, TypedDict, Union
from typing_extensions import NotRequired, Required, TypeAlias

import datetime
from abc import abstractmethod

from numpy import bool_ as Bool
from numpy import datetime64 as Date
Expand All @@ -19,32 +18,36 @@
from openfisca_core import types as t

# Generic type variables.
D = TypeVar("D")
E = TypeVar("E", covariant=True)
G = TypeVar("G", covariant=True)
T = TypeVar("T", Bool, Date, Enum, Float, Int, String, covariant=True)
U = TypeVar("U", bool, datetime.date, float, str)
V = TypeVar("V", covariant=True)

# New types.
PeriodStr = NewType("PeriodStr", str)
EntityKey = NewType("EntityKey", str)
EntityPlural = NewType("EntityPlural", str)
VariableName = NewType("VariableName", str)

# Type aliases.

#: Type alias for numpy arrays values.
Item: TypeAlias = Union[Bool, Date, Enum, Float, Int, String]

#: Type Alias for a numpy Array.
Array: TypeAlias = t.Array

# Entities


#: Type alias for a simulation dictionary defining the roles.
Roles: TypeAlias = dict[str, Union[str, Iterable[str]]]


class CoreEntity(t.CoreEntity, Protocol):
key: str
plural: str | None
key: EntityKey
plural: EntityPlural | None

def get_variable(
self,
__variable_name: str,
__check_existence: bool = ...,
__variable_name: VariableName,
check_existence: bool = ...,
) -> Variable[T] | None:
...

Expand All @@ -55,7 +58,6 @@ class SingleEntity(t.SingleEntity, Protocol):

class GroupEntity(t.GroupEntity, Protocol):
@property
@abstractmethod
def flattened_roles(self) -> Iterable[Role[G]]:
...

Expand All @@ -69,11 +71,10 @@ class Role(t.Role, Protocol[G]):

class Holder(t.Holder, Protocol[V]):
@property
@abstractmethod
def variable(self) -> Variable[T]:
...

def get_array(self, __period: str) -> t.Array[T] | None:
def get_array(self, __period: PeriodStr) -> t.Array[T] | None:
...

def set_input(
Expand All @@ -94,18 +95,19 @@ class Period(t.Period, Protocol):
# Populations


class CorePopulation(t.CorePopulation, Protocol[D]):
entity: D
class CorePopulation(t.CorePopulation, Protocol):
entity: CoreEntity

def get_holder(self, __variable_name: str) -> Holder[V]:
def get_holder(self, __variable_name: VariableName) -> Holder[V]:
...


class SinglePopulation(t.SinglePopulation, Protocol[E]):
...
class SinglePopulation(t.SinglePopulation, Protocol):
entity: SingleEntity


class GroupPopulation(t.GroupPopulation, Protocol[E]):
class GroupPopulation(t.GroupPopulation, Protocol):
entity: GroupEntity
members_entity_id: t.Array[String]

def nb_persons(self, __role: Role[G] | None = ...) -> int:
Expand All @@ -114,6 +116,29 @@ def nb_persons(self, __role: Role[G] | None = ...) -> int:

# Simulations

#: Dictionary with axes parameters per variable.
InputBuffer: TypeAlias = dict[VariableName, dict[PeriodStr, Array]]

#: Dictionary with entity/population key/pais.
Populations: TypeAlias = dict[EntityKey, GroupPopulation]

#: Dictionary with single entity count per group entity.
EntityCounts: TypeAlias = dict[EntityPlural, int]

#: Dictionary with a list of single entities per group entity.
EntityIds: TypeAlias = dict[EntityPlural, Iterable[int]]

#: Dictionary with a list of members per group entity.
Memberships: TypeAlias = dict[EntityPlural, Iterable[int]]

#: Dictionary with a list of roles per group entity.
EntityRoles: TypeAlias = dict[EntityPlural, Iterable[int]]

#: Dictionary with a map between variables and entities.
VariableEntity: TypeAlias = dict[VariableName, CoreEntity]

#: Type alias for a simulation dictionary defining the roles.
Roles: TypeAlias = dict[str, Union[str, Iterable[str]]]

#: Type alias for a simulation dictionary with undated variables.
UndatedVariable: TypeAlias = dict[str, object]
Expand Down Expand Up @@ -169,21 +194,18 @@ class Simulation(t.Simulation, Protocol):

class TaxBenefitSystem(t.TaxBenefitSystem, Protocol):
@property
@abstractmethod
def person_entity(self) -> SingleEntity:
...

@person_entity.setter
@abstractmethod
def person_entity(self, person_entity: SingleEntity) -> None:
...

@property
@abstractmethod
def variables(self) -> dict[str, V]:
...

def entities_by_singular(self) -> dict[str, E]:
def entities_by_singular(self) -> dict[str, CoreEntity]:
...

def entities_plural(self) -> Iterable[str]:
Expand All @@ -198,7 +220,7 @@ def get_variable(

def instantiate_entities(
self,
) -> dict[str, GroupPopulation[E]]:
) -> Populations:
...


Expand All @@ -209,7 +231,7 @@ class Variable(t.Variable, Protocol[T]):
calculate_output: Callable[[Simulation, str, str], t.Array[T]] | None
definition_period: str
end: str
name: str
name: VariableName

def default_array(self, __array_size: int) -> t.Array[T]:
...

0 comments on commit 46e6624

Please sign in to comment.