Skip to content

Commit

Permalink
Implement train/test split dataset function
Browse files Browse the repository at this point in the history
  • Loading branch information
folmos-at-orange committed Jul 4, 2024
1 parent ab44c7e commit 7d2603c
Show file tree
Hide file tree
Showing 2 changed files with 575 additions and 37 deletions.
189 changes: 188 additions & 1 deletion khiops/utils/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

import os

from sklearn.model_selection import train_test_split

from khiops import core as kh
from khiops.core.internals.common import is_dict_like, type_error_message
from khiops.utils.dataset import Dataset, FileTable, PandasTable
Expand Down Expand Up @@ -62,7 +64,6 @@ def _sort_df_table(table):

def _sort_file_table(table, sep, header, output_dir):
assert isinstance(table, FileTable), type_error_message("table", table, FileTable)

domain = kh.DictionaryDomain()
dictionary = table.create_khiops_dictionary()
domain.add_dictionary(dictionary)
Expand All @@ -79,3 +80,189 @@ def _sort_file_table(table, sep, header, output_dir):
)

return out_data_source


def train_test_split_dataset(
ds_spec, target_column=None, test_size=0.25, output_dir=None, **kwargs
):
# Check the types
if not is_dict_like(ds_spec):
raise TypeError(type_error_message("ds_spec", ds_spec, "dict-like"))

# Build the dataset for the feature table
ds = Dataset(ds_spec)

# Check the parameter coherence
if not ds.is_in_memory():
if target_column is not None:
raise ValueError("'target_column' cannot be used with file path datasets")
if output_dir is None:
raise ValueError("'output_dir' must be specified for file path datasets")
if not isinstance(output_dir, str):
raise TypeError(type_error_message("output_dir", output_dir, str))

# Perform the split for each type of dataset
if ds.is_in_memory():
# Obtain the keys for the other test_train_split function
sklearn_split_params = {}
for param in ("train_size", "random_state", "shuffle", "stratify"):
if param in kwargs:
sklearn_split_params[param] = kwargs[param]

if target_column is None:
train_ds, test_ds = _train_test_split_in_memory_dataset(
ds,
target_column,
test_size=test_size,
split_params=sklearn_split_params,
)
else:
train_ds, test_ds, train_target_column, test_target_column = (
_train_test_split_in_memory_dataset(
ds,
target_column,
test_size=test_size,
split_params=sklearn_split_params,
)
)
else:
train_ds, test_ds = _train_test_split_file_dataset(ds, test_size, output_dir)

# Create the return tuple
# Note: We use `tuple` to avoid pylint warning about unbalanced-tuple-unpacking
if target_column is None:
split = tuple([train_ds.to_spec(), test_ds.to_spec()])
else:
split = tuple(
[
train_ds.to_spec(),
test_ds.to_spec(),
train_target_column,
test_target_column,
]
)

return split


def _train_test_split_in_memory_dataset(
ds, target_column, test_size, split_params=None
):
# Create shallow copies of the feature dataset
train_ds = ds.copy()
test_ds = ds.copy()

# Split the main table and the target (if any)
if target_column is None:
train_ds.main_table.data_source, test_ds.main_table.data_source = (
train_test_split(
ds.main_table.data_source, test_size=test_size, **split_params
)
)
else:
(
train_ds.main_table.data_source,
test_ds.main_table.data_source,
train_target_column,
test_target_column,
) = train_test_split(
ds.main_table.data_source,
target_column,
test_size=test_size,
**split_params,
)

# Split the secondary tables tables
# Note: The tables are traversed in BFS
todo_relations = [
relation for relation in ds.relations if relation[0] == ds.main_table.name
]
while todo_relations:
current_parent_table_name, current_child_table_name, _ = todo_relations.pop(0)
for relation in ds.relations:
parent_table_name, _, _ = relation
if parent_table_name == current_child_table_name:
todo_relations.append(relation)

for new_ds in (train_ds, test_ds):
origin_child_table = ds.get_table(current_child_table_name)
new_child_table = new_ds.get_table(current_child_table_name)
new_parent_table = new_ds.get_table(current_parent_table_name)
new_parent_key_cols_df = new_parent_table.data_source[new_parent_table.key]
new_child_table.data_source = new_parent_key_cols_df.merge(
origin_child_table.data_source, on=new_parent_table.key
)

# Build the return value
# Note: We use `tuple` to avoid pylint warning about unbalanced-tuple-unpacking
if target_column is None:
return_tuple = tuple([train_ds, test_ds])
else:
return_tuple = tuple(
[train_ds, test_ds, train_target_column, test_target_column]
)

return return_tuple


def _train_test_split_file_dataset(ds, test_size, output_dir):
domain = ds.create_khiops_dictionary_domain()
secondary_data_paths = domain.extract_data_paths(ds.main_table.name)
additional_data_tables = {}
output_additional_data_tables = {
"train": {},
"test": {},
}
# Initialize the split datasets as copies of the original one
split_dss = {
"train": ds.copy(),
"test": ds.copy(),
}
for split, split_ds in split_dss.items():
split_ds.main_table.data_source = os.path.join(
output_dir, split, f"{split_ds.main_table.name}.txt"
)

for data_path in secondary_data_paths:
dictionary = domain.get_dictionary_at_data_path(data_path)
table = ds.get_table(dictionary.name)
additional_data_tables[data_path] = table.data_source
for (
split,
split_output_additional_data_tables,
) in output_additional_data_tables.items():
data_table_path = os.path.join(output_dir, split, f"{table.name}.txt")
split_output_additional_data_tables[data_path] = data_table_path
split_dss[split].get_table(table.name).data_source = data_table_path

kh.deploy_model(
domain,
ds.main_table.name,
ds.main_table.data_source,
split_dss["train"].main_table.data_source,
additional_data_tables=additional_data_tables,
output_additional_data_tables=output_additional_data_tables["train"],
header_line=ds.header,
field_separator=ds.sep,
output_header_line=ds.header,
output_field_separator=ds.sep,
sample_percentage=100.0 * (1 - test_size),
sampling_mode="Include sample",
)
kh.deploy_model(
domain,
ds.main_table.name,
ds.main_table.data_source,
split_dss["test"].main_table.data_source,
additional_data_tables=additional_data_tables,
output_additional_data_tables=output_additional_data_tables["test"],
header_line=ds.header,
field_separator=ds.sep,
output_header_line=ds.header,
output_field_separator=ds.sep,
sample_percentage=100.0 * (1 - test_size),
sampling_mode="Exclude sample",
)

# Note: We use `tuple` to avoid pylint warning about unbalanced-tuple-unpacking
return tuple([split_dss["train"], split_dss["test"]])
Loading

0 comments on commit 7d2603c

Please sign in to comment.