Skip to content

Commit

Permalink
Add type hints to the SKLearnRegressor protocol
Browse files Browse the repository at this point in the history
  • Loading branch information
ViktorC committed Feb 23, 2022
1 parent d5d2077 commit 7b0b6a3
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 19 deletions.
4 changes: 2 additions & 2 deletions pararealml/operators/ml/auto_regression/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@
from pararealml.operators.ml.auto_regression.auto_regression_operator \
import AutoRegressionOperator
from pararealml.operators.ml.auto_regression.auto_regression_operator \
import RegressionModel
import SKLearnRegressor
from pararealml.operators.ml.auto_regression.sklearn_keras_regressor \
import SKLearnKerasRegressor

__all__ = [
'AutoRegressionOperator',
'RegressionModel',
'SKLearnRegressor',
'SKLearnKerasRegressor'
]
48 changes: 31 additions & 17 deletions pararealml/operators/ml/auto_regression/auto_regression_operator.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from typing import Union, Tuple, Callable, Optional, Protocol, List
from __future__ import annotations

from typing import Tuple, Callable, Optional, Protocol, List

import numpy as np
from sklearn.metrics import mean_squared_error
Expand All @@ -8,20 +10,9 @@
from pararealml.initial_condition import DiscreteInitialCondition
from pararealml.initial_value_problem import InitialValueProblem
from pararealml.operator import Operator, discretize_time_domain
from pararealml.operators.ml.auto_regression.sklearn_keras_regressor \
import SKLearnKerasRegressor
from pararealml.solution import Solution


class SKLearnRegressor(Protocol):
def fit(self, x, y, sample_weight=None): ...
def predict(self, x): ...
def score(self, x, y, sample_weight=None): ...


RegressionModel = Union[SKLearnRegressor, SKLearnKerasRegressor]


class AutoRegressionOperator(Operator):
"""
A supervised machine learning operator that uses auto regression to model
Expand All @@ -40,17 +31,17 @@ def __init__(
"""
super(AutoRegressionOperator, self).__init__(d_t, vertex_oriented)

self._model: Optional[RegressionModel] = None
self._model: Optional[SKLearnRegressor] = None

@property
def model(self) -> Optional[RegressionModel]:
def model(self) -> Optional[SKLearnRegressor]:
"""
The regression model behind the operator.
"""
return self._model

@model.setter
def model(self, model: Optional[RegressionModel]):
def model(self, model: Optional[SKLearnRegressor]):
self._model = model

def solve(
Expand Down Expand Up @@ -173,7 +164,7 @@ def generate_data(

def fit_model(
self,
model: RegressionModel,
model: SKLearnRegressor,
data: Tuple[np.ndarray, np.ndarray],
test_size: float = .2,
score_func: Callable[[np.ndarray, np.ndarray], float] =
Expand Down Expand Up @@ -211,7 +202,7 @@ def train(
self,
ivp: InitialValueProblem,
oracle: Operator,
model: RegressionModel,
model: SKLearnRegressor,
iterations: int,
perturbation_function: Callable[[float, np.ndarray], np.ndarray],
isolate_perturbations: bool = False,
Expand Down Expand Up @@ -274,3 +265,26 @@ def _create_input_placeholder(
t = np.empty((len(x), 1))
y = np.empty((len(x), diff_eq.y_dimension * len(x)))
return np.hstack([y, t, x])


class SKLearnRegressor(Protocol):
"""A protocol class for scikit-learn regression models."""

def fit(
self,
x: np.typing.ArrayLike,
y: np.typing.ArrayLike,
sample_weight: Optional[np.typing.ArrayLike] = None
) -> SKLearnRegressor:
...

def predict(self, x: np.typing.ArrayLike) -> np.ndarray:
...

def score(
self,
x: np.typing.ArrayLike,
y: np.typing.ArrayLike,
sample_weight: Optional[np.typing.ArrayLike] = None
) -> float:
...

0 comments on commit 7b0b6a3

Please sign in to comment.