Skip to content

Commit

Permalink
Move target_column from tables to dataset
Browse files Browse the repository at this point in the history
  • Loading branch information
folmos-at-orange committed Jun 12, 2024
1 parent 69a7636 commit 31ea278
Show file tree
Hide file tree
Showing 5 changed files with 375 additions and 608 deletions.
82 changes: 44 additions & 38 deletions khiops/sklearn/estimators.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,12 @@
is_list_like,
type_error_message,
)
from khiops.utils.dataset import Dataset, FileTable, read_internal_data_table
from khiops.utils.dataset import (
Dataset,
FileTable,
get_khiops_variable_name,
read_internal_data_table,
)

# Disable PEP8 variable names because of scikit-learn X,y conventions
# To capture invalid-names other than X,y run:
Expand Down Expand Up @@ -123,14 +128,14 @@ def _check_dictionary_compatibility(


def _check_categorical_target_type(ds):
if ds.target_column_type is None:
if ds.target_column is None:
raise ValueError("Target vector is not specified.")

if ds.is_in_memory() and not (
isinstance(ds.target_column_type, pd.CategoricalDtype)
or pd.api.types.is_string_dtype(ds.target_column_type)
or pd.api.types.is_integer_dtype(ds.target_column_type)
or pd.api.types.is_float_dtype(ds.target_column_type)
isinstance(ds.target_column_dtype, pd.CategoricalDtype)
or pd.api.types.is_string_dtype(ds.target_column_dtype)
or pd.api.types.is_integer_dtype(ds.target_column_dtype)
or pd.api.types.is_float_dtype(ds.target_column_dtype)
):
raise ValueError(
f"'y' has invalid type '{ds.target_column_type}'. "
Expand All @@ -145,16 +150,16 @@ def _check_categorical_target_type(ds):


def _check_numerical_target_type(ds):
if ds.target_column_type is None:
if ds.target_column is None:
raise ValueError("Target vector is not specified.")
if ds.is_in_memory():
if not pd.api.types.is_numeric_dtype(ds.target_column_type):
if not pd.api.types.is_numeric_dtype(ds.target_column_dtype):
raise ValueError(
f"Unknown label type '{ds.target_column_type}'. "
"Expected a numerical type."
)
if ds.main_table.target_column is not None:
assert_all_finite(ds.main_table.target_column)
if ds.target_column is not None:
assert_all_finite(ds.target_column)
elif not ds.is_in_memory() and ds.target_column_type != "Numerical":
raise ValueError(
f"Target column has invalid type '{ds.target_column_type}'. "
Expand Down Expand Up @@ -335,38 +340,38 @@ def fit(self, X, y=None, **kwargs):

return self

def _fit(self, ds, computation_dir, **kwargs):
    """Template pattern of a fit method

    Runs, in order: the parameter check, the dataset check and the model
    training. The post-processing step is executed only when training left a
    model and a report of the expected types on the estimator.

    Parameters
    ----------
    ds : `Dataset`
        The learning dataset.
    computation_dir : str
        Path or URI where the Khiops computation results will be stored.
    **kwargs
        Forwarded unchanged to ``_fit_check_params`` and ``_fit_train_model``.

    The called methods are reimplemented in concrete sub-classes
    """
    # Check model parameters
    self._fit_check_params(ds, **kwargs)

    # Check the dataset
    self._fit_check_dataset(ds)

    # Train the model
    self._fit_train_model(ds, computation_dir, **kwargs)
    self.n_features_in_ = ds.main_table.n_features()

    # If the main attributes are of the proper type finish the fitting
    # Otherwise it means there was an abort (early return) of the previous steps
    # NOTE(review): the three statements below are assumed to all belong to the
    # `if` body (indentation was lost in the rendered diff) -- confirm.
    if isinstance(self.model_, kh.DictionaryDomain) and isinstance(
        self.model_report_, kh.KhiopsJSONObject
    ):
        self._fit_training_post_process(ds)
        self.is_fitted_ = True
        self.is_multitable_model_ = ds.is_multitable()

def _fit_check_params(self, dataset, **_):
def _fit_check_params(self, ds, **_):
"""Check the model parameters including those data dependent (in kwargs)"""
if (
self.key is not None
Expand All @@ -375,7 +380,7 @@ def _fit_check_params(self, dataset, **_):
):
raise TypeError(type_error_message("key", self.key, str, "list-like"))

if not dataset.is_in_memory() and self.output_dir is None:
if not ds.is_in_memory() and self.output_dir is None:
raise ValueError("'output_dir' is not set but dataset is file-based")

def _fit_check_dataset(self, ds):
Expand Down Expand Up @@ -1456,7 +1461,7 @@ def _fit_prepare_training_function_inputs(self, ds, computation_dir):
ds.create_khiops_dictionary_domain(),
ds.main_table.name,
main_table_path,
ds.main_table.get_khiops_variable_name(ds.main_table.target_column_id),
get_khiops_variable_name(ds.target_column_id),
output_dir,
]

Expand Down Expand Up @@ -1499,9 +1504,7 @@ def _fit_training_post_process(self, ds):
super()._fit_training_post_process(ds)

# Set the target variable name
self.model_target_variable_name_ = ds.main_table.get_khiops_variable_name(
ds.main_table.target_column_id
)
self.model_target_variable_name_ = get_khiops_variable_name(ds.target_column_id)

# Verify it has at least one dictionary and a root dictionary in multi-table
if len(self.model_.dictionaries) == 1:
Expand Down Expand Up @@ -1778,10 +1781,10 @@ def __init__(
self._predicted_target_meta_data_tag = "Prediction"

def _is_real_target_dtype_integer(self):
    """Tell whether the target seen at fit time was integer-valued

    Reads ``self._original_target_dtype`` and returns a truthy value when it
    is an integer dtype, or a pandas categorical dtype whose categories are
    integers.

    Raises
    ------
    AssertionError
        If ``self._original_target_dtype`` is ``None``.
    """
    assert self._original_target_dtype is not None, "Original target type not set"

    # `is_integer_dtype` is false for a categorical dtype itself, so the
    # categories have to be inspected separately.
    return pd.api.types.is_integer_dtype(self._original_target_dtype) or (
        isinstance(self._original_target_dtype, pd.CategoricalDtype)
        and pd.api.types.is_integer_dtype(self._original_target_dtype.categories)
    )

def _sorted_prob_variable_names(self):
Expand Down Expand Up @@ -1843,14 +1846,14 @@ def _fit_check_dataset(self, ds):

# Check that the target is for classification in in_memory_tables
if ds.is_in_memory():
current_type_of_target = type_of_target(ds.main_table.target_column)
current_type_of_target = type_of_target(ds.target_column)
if current_type_of_target not in ["binary", "multiclass"]:
raise ValueError(
f"Unknown label type: '{current_type_of_target}' "
"for classification. Maybe you passed a floating point target?"
)
# Check if the target has more than 1 class
if ds.is_in_memory() and len(np.unique(ds.main_table.target_column)) == 1:
if ds.is_in_memory() and len(np.unique(ds.target_column)) == 1:
raise ValueError(
f"{self.__class__.__name__} can't train when only one class is present."
)
Expand All @@ -1863,7 +1866,10 @@ def _fit_training_post_process(self, ds):
super()._fit_training_post_process(ds)

# Save the target datatype
self._original_target_type = ds.target_column_type
if ds.is_in_memory():
self._original_target_dtype = ds.target_column_dtype
else:
self._original_target_dtype = np.dtype("object")

# Save class values in the order of deployment
self.classes_ = []
Expand Down Expand Up @@ -1929,21 +1935,21 @@ def predict(self, X):
y_pred = y_pred.to_numpy(copy=False).ravel()

# If integer and string just transform
if pd.api.types.is_integer_dtype(self._original_target_type):
y_pred = y_pred.astype(self._original_target_type)
elif pd.api.types.is_string_dtype(self._original_target_type):
if pd.api.types.is_integer_dtype(self._original_target_dtype):
y_pred = y_pred.astype(self._original_target_dtype)
elif pd.api.types.is_string_dtype(self._original_target_dtype):
y_pred = y_pred.astype(str, copy=False)
# If category first coerce the type to the categories' type
else:
assert pd.api.types.is_categorical_dtype(self._original_target_type), (
assert pd.api.types.is_categorical_dtype(self._original_target_dtype), (
"_original_target_dtype is not categorical"
f", it is '{self._original_target_type}'"
f", it is '{self._original_target_dtype}'"
)
if pd.api.types.is_integer_dtype(
self._original_target_type.categories.dtype
self._original_target_dtype.categories.dtype
):
y_pred = y_pred.astype(
self._original_target_type.categories.dtype, copy=False
self._original_target_dtype.categories.dtype, copy=False
)
else:
y_pred = y_pred.astype(str, copy=False)
Expand Down
Loading

0 comments on commit 31ea278

Please sign in to comment.