Skip to content

Commit

Permalink
refactor: consolidate tracers module types
Browse files Browse the repository at this point in the history
  • Loading branch information
bonjourmauko committed Nov 19, 2024
1 parent dbbd16e commit e56d32d
Show file tree
Hide file tree
Showing 8 changed files with 98 additions and 118 deletions.
2 changes: 0 additions & 2 deletions openfisca_core/tracers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
#
# See: https://www.python.org/dev/peps/pep-0008/#imports

from . import types
from .computation_log import ComputationLog
from .flat_trace import FlatTrace
from .full_tracer import FullTracer
Expand All @@ -38,5 +37,4 @@
"SimpleTracer",
"TraceNode",
"TracingParameterNodeAtInstant",
"types",
]
3 changes: 1 addition & 2 deletions openfisca_core/tracers/computation_log.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,9 @@

import numpy

from openfisca_core import types as t
from openfisca_core.indexed_enums import EnumArray

from . import types as t


class ComputationLog:
_full_tracer: t.FullTracer
Expand Down
3 changes: 1 addition & 2 deletions openfisca_core/tracers/flat_trace.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,9 @@

import numpy

from openfisca_core import types as t
from openfisca_core.indexed_enums import EnumArray

from . import types as t


class FlatTrace:
_full_tracer: t.FullTracer
Expand Down
3 changes: 2 additions & 1 deletion openfisca_core/tracers/full_tracer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
import sys
import time

from . import types as t
from openfisca_core import types as t

from .computation_log import ComputationLog
from .flat_trace import FlatTrace
from .performance_log import PerformanceLog
Expand Down
2 changes: 1 addition & 1 deletion openfisca_core/tracers/simple_tracer.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from . import types as t
from openfisca_core import types as t


class SimpleTracer:
Expand Down
2 changes: 1 addition & 1 deletion openfisca_core/tracers/trace_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import dataclasses

from . import types as t
from openfisca_core import types as t


@dataclasses.dataclass
Expand Down
108 changes: 0 additions & 108 deletions openfisca_core/tracers/types.py

This file was deleted.

93 changes: 92 additions & 1 deletion openfisca_core/types.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from collections.abc import Iterable, Sequence, Sized
from collections.abc import Iterable, Iterator, Sequence, Sized
from numpy.typing import DTypeLike, NDArray
from typing import NewType, TypeVar, Union
from typing_extensions import Protocol, Required, Self, TypeAlias, TypedDict
Expand Down Expand Up @@ -309,6 +309,97 @@ def get_variable(
) -> None | Variable: ...


# Tracers

#: A type representing a unit time.
Time: TypeAlias = float

#: A type representing a mapping of flat traces.
FlatNodeMap: TypeAlias = dict["NodeKey", "FlatTraceMap"]

#: A type representing a mapping of serialized traces.
SerializedNodeMap: TypeAlias = dict["NodeKey", "SerializedTraceMap"]

#: Key of a trace.
NodeKey = NewType("NodeKey", str)


class FlatTraceMap(TypedDict, total=True):
dependencies: list[NodeKey]
parameters: dict[NodeKey, None | ArrayLike[object]]
value: None | VarArray
calculation_time: Time
formula_time: Time


class SerializedTraceMap(TypedDict, total=True):
dependencies: list[NodeKey]
parameters: dict[NodeKey, None | ArrayLike[object]]
value: None | ArrayLike[object]
calculation_time: Time
formula_time: Time


class SimpleTraceMap(TypedDict, total=True):
name: VariableName
period: int | Period


class ComputationLog(Protocol):
def print_log(self, __aggregate: bool = ..., __max_depth: int = ..., /) -> None: ...


class FlatTrace(Protocol):
def get_trace(self, /) -> FlatNodeMap: ...
def get_serialized_trace(self, /) -> SerializedNodeMap: ...


class FullTracer(Protocol):
@property
def trees(self, /) -> list[TraceNode]: ...
def browse_trace(self, /) -> Iterator[TraceNode]: ...
def get_nb_requests(self, __name: VariableName, /) -> int: ...


class PerformanceLog(Protocol):
def generate_graph(self, __dir_path: str, /) -> None: ...
def generate_performance_tables(self, __dir_path: str, /) -> None: ...


class SimpleTracer(Protocol):
@property
def stack(self, /) -> SimpleStack: ...
def record_calculation_start(
self, __name: VariableName, __period: PeriodInt | Period, /
) -> None: ...
def record_calculation_end(self, /) -> None: ...


class TraceNode(Protocol):
@property
def children(self, /) -> list[TraceNode]: ...
@property
def end(self, /) -> Time: ...
@property
def name(self, /) -> str: ...
@property
def parameters(self, /) -> list[TraceNode]: ...
@property
def parent(self, /) -> None | TraceNode: ...
@property
def period(self, /) -> PeriodInt | Period: ...
@property
def start(self, /) -> Time: ...
@property
def value(self, /) -> None | VarArray: ...
def calculation_time(self, *, __round: bool = ...) -> Time: ...
def formula_time(self, /) -> Time: ...
def append_child(self, __node: TraceNode, /) -> None: ...


#: A stack of simple traces.
SimpleStack: TypeAlias = list[SimpleTraceMap]

# Variables

#: For example "salary".
Expand Down

0 comments on commit e56d32d

Please sign in to comment.