Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a _update_multi_column_transformer method #758

Merged
merged 6 commits into from
Feb 5, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions rdt/hyper_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,6 +357,30 @@ def _remove_column_in_multi_column_fields(self, column):

self.field_transformers[new_tuple] = self.field_transformers.pop(old_tuple)

def _update_multi_column_transformer(self):
    """Re-validate the multi-column transformers and replace the incompatible ones.

    Every distinct field found in the values of ``self._multi_column_fields`` is looked
    up in ``self.field_transformers``. Fields mapped to ``None`` are skipped. For the
    others, the transformer's ``_validate_sdtypes`` is called with the mapping returned
    by ``self._get_columns_to_sdtypes``. If that call raises ``TransformerInputError``,
    a warning is emitted, the multi-column entry is deleted and each column of the
    mapping gets a deep copy of the default transformer for its sdtype.

    ``self._multi_column_fields`` is rebuilt once all fields have been processed.
    """
    for field in set(self._multi_column_fields.values()):
        transformer = self.field_transformers[field]
        if transformer is None:
            # Nothing to validate when no transformer is assigned to the field.
            continue

        columns_to_sdtypes = self._get_columns_to_sdtypes(field)
        validate_sdtypes = transformer._validate_sdtypes  # pylint: disable=protected-access
        try:
            validate_sdtypes(columns_to_sdtypes)
        except TransformerInputError:
            message = (
                f"Transformer '{transformer.get_name()}' is incompatible with the "
                f"multi-column field '{field}'. Assigning default transformer to the columns."
            )
            warnings.warn(message)

            # Drop the tuple entry and fall back to one default transformer per column.
            del self.field_transformers[field]
            for column, sdtype in columns_to_sdtypes.items():
                default_transformer = get_default_transformer(sdtype)
                self.field_transformers[column] = deepcopy(default_transformer)

    self._multi_column_fields = self._create_multi_column_fields()

def update_transformers_by_sdtype(
self, sdtype, transformer=None, transformer_name=None, transformer_parameters=None):
"""Update the transformers for the specified ``sdtype``.
Expand Down Expand Up @@ -397,6 +421,7 @@ def update_transformers_by_sdtype(
self._remove_column_in_multi_column_fields(field)

self._multi_column_fields = self._create_multi_column_fields()
self._update_multi_column_transformer()
self._modified_config = True

def update_sdtypes(self, column_name_to_sdtype):
Expand Down Expand Up @@ -445,6 +470,7 @@ def update_sdtypes(self, column_name_to_sdtype):
)

self._multi_column_fields = self._create_multi_column_fields()
self._update_multi_column_transformer()
self._modified_config = True
if self._fitted:
warnings.warn(self._REFIT_MESSAGE)
Expand Down Expand Up @@ -485,6 +511,7 @@ def update_transformers(self, column_name_to_transformer):
self.field_transformers[column_name] = transformer

self._multi_column_fields = self._create_multi_column_fields()
self._update_multi_column_transformer()
self._modified_config = True

def remove_transformers(self, column_names):
Expand Down Expand Up @@ -514,6 +541,7 @@ def remove_transformers(self, column_names):

self.field_transformers[column_name] = None

self._update_multi_column_transformer()
if self._fitted:
warnings.warn(self._REFIT_MESSAGE)

Expand All @@ -540,6 +568,7 @@ def remove_transformers_by_sdtype(self, sdtype):

self.field_transformers[column_name] = None

self._update_multi_column_transformer()
if self._fitted:
warnings.warn(self._REFIT_MESSAGE)

Expand Down
4 changes: 4 additions & 0 deletions rdt/transformers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -547,6 +547,10 @@ def _validate_columns_to_sdtypes(self, data, columns_to_sdtypes):
missing_to_print = ', '.join(missing)
raise ValueError(f'Columns ({missing_to_print}) are not present in the data.')

@classmethod
def _validate_sdtypes(cls, columns_to_sdtypes):
    """Validate the sdtypes of the columns handled by the transformer.

    This base version is a placeholder and always raises.

    Args:
        columns_to_sdtypes (dict):
            Mapping of column names to their sdtypes.

    Raises:
        NotImplementedError:
            Always raised by this base implementation.
    """
    raise NotImplementedError

def _fit(self, data):
"""Fit the transformer to the data.

Expand Down
84 changes: 83 additions & 1 deletion tests/integration/test_hyper_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
import pytest

from rdt import get_demo
from rdt.errors import ConfigNotSetError, InvalidConfigError, InvalidDataError, NotFittedError
from rdt.errors import (
ConfigNotSetError, InvalidConfigError, InvalidDataError, NotFittedError, TransformerInputError)
from rdt.hyper_transformer import Config, HyperTransformer
from rdt.transformers import (
AnonymizedFaker, BaseMultiColumnTransformer, BaseTransformer, BinaryEncoder,
Expand Down Expand Up @@ -67,6 +68,10 @@ def _fit(self, data):
} for column in self.columns
}

@classmethod
def _validate_sdtypes(cls, columns_to_sdtype):
    """Accept any mapping of columns to sdtypes without raising."""
    return

def _get_prefix(self):
    """Return ``None`` as the prefix of this dummy transformer."""
    return None

Expand Down Expand Up @@ -1846,3 +1851,80 @@ def test_with_tuple_returned_by_faker(self):
]
})
pd.testing.assert_frame_equal(result, expected_results)

methods_to_inputs = {
'update_transformers': {'column_name_to_transformer': {'C': UniformEncoder()}},
'update_transformers_by_sdtype': {'sdtype': 'boolean', 'transformer': UniformEncoder()},
'remove_transformers': {'column_names': ['C']},
'remove_transformers_by_sdtype': {'sdtype': 'boolean'},
}

@pytest.mark.parametrize('method_name', methods_to_inputs.keys())
def test_unvalid_multi_column(self, method_name):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

minor: typo — `unvalid` should be `invalid`

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes thanks, done in e6083f7

"""Test that the Hypertransformer handles invalid multi column transformers."""
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we can be more specific here: add a couple of sentences after the summary line explaining that this tests the case where an update/remove call leaves the multi-column transformer invalid.

# Setup
class BadDummyMultiColumnTransformer(DummyMultiColumnTransformerNumerical):

@classmethod
def _validate_sdtypes(cls, columns_to_sdtype):
raise TransformerInputError('Invalid sdtype')

dict_config = {
'sdtypes': {
'A': 'categorical',
'B': 'categorical',
'D': 'categorical',
'E': 'categorical',
'C': 'boolean',
},
'transformers': {
'A': UniformEncoder(),
('B', 'D', 'C'): BadDummyMultiColumnTransformer(),
'E': UniformEncoder()
}
}

config = Config(dict_config)
ht = HyperTransformer()
ht.set_config(config)

# Run
expected_warning = re.escape(
"Transformer 'BadDummyMultiColumnTransformer' is incompatible with the "
"multi-column field '('B', 'D')'. Assigning default transformer to the columns."
)
with pytest.warns(UserWarning, match=expected_warning):
ht.__getattribute__(method_name)(**self.methods_to_inputs[method_name])

# Assert
new_config = ht.get_config()
expected_dict_config = {
'sdtypes': {
'A': 'categorical',
'B': 'categorical',
'D': 'categorical',
'E': 'categorical',
'C': 'boolean'
}
}
if method_name.startswith('update'):
expected_dict_config['transformers'] = {
'A': UniformEncoder(),
'E': UniformEncoder(),
'C': UniformEncoder(),
'B': UniformEncoder(),
'D': UniformEncoder()
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You could parametrize this part as well. Instead of `methods_to_inputs` being a dict, make it a list of tuples, each with the structure `(method_name, input, expected)`, and then use those variables here accordingly.

Copy link
Contributor Author

@R-Palazzo R-Palazzo Feb 2, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's a good point, thanks, done in e6083f7

else:
expected_dict_config['transformers'] = {
'A': UniformEncoder(),
'E': UniformEncoder(),
'C': None,
'B': UniformEncoder(),
'D': UniformEncoder()
}

expected_config = Config(expected_dict_config)
expected_multi_columns = {}
assert ht._multi_column_fields == expected_multi_columns
assert repr(new_config) == repr(expected_config)
127 changes: 117 additions & 10 deletions tests/unit/test_hyper_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2317,7 +2317,7 @@ def test_update_transformers_by_sdtype_with_multi_column_transformer(self):
ht.field_transformers = {
'A': LabelEncoder(),
'B': UniformEncoder(),
"('C', 'D')": None,
('C', 'D'): None,
}
ht.field_sdtypes = {
'A': 'categorical',
Expand All @@ -2330,9 +2330,6 @@ def test_update_transformers_by_sdtype_with_multi_column_transformer(self):
'C': ('C', 'D'),
'D': ('C', 'D')
}
mock__remove_column_in_multi_column_fields = Mock()
ht._remove_column_in_multi_column_fields = mock__remove_column_in_multi_column_fields
ht._create_multi_column_fields = Mock()

# Run
ht.update_transformers_by_sdtype(
Expand All @@ -2341,9 +2338,15 @@ def test_update_transformers_by_sdtype_with_multi_column_transformer(self):
)

# Assert
assert len(ht.field_transformers) == 4
assert mock__remove_column_in_multi_column_fields.call_count == 1
ht._create_multi_column_fields.assert_called_once()
expected_field_transformers = {
'A': LabelEncoder(),
'B': UniformEncoder(),
'C': LabelEncoder(),
'D': None,
}
expected_multi_column_fieds = {}
assert str(ht.field_transformers) == str(expected_field_transformers)
assert ht._multi_column_fields == expected_multi_column_fieds

@patch('rdt.hyper_transformer.warnings')
def test_update_transformers_fitted(self, mock_warnings):
Expand Down Expand Up @@ -2391,7 +2394,99 @@ def test_update_transformers_fitted(self, mock_warnings):
assert instance.field_transformers['my_column'] == transformer
instance._validate_transformers.assert_called_once_with(column_name_to_transformer)

def test_update_transformers_multi_column(self):
def test__update_transformers_multi_column_valid(self):
    """Test ``_update_multi_column_transformer`` with a valid multi-column transformer.

    When ``_validate_sdtypes`` does not raise, both ``field_transformers`` and
    ``_multi_column_fields`` are expected to be left as they were.
    """
    # Setup
    class ValidMultiColumnTransformer(BaseMultiColumnTransformer):
        @classmethod
        def _validate_sdtypes(cls, columns_to_sdtypes):
            return

    multi_column_field = ('B', 'C')

    ht = HyperTransformer()
    ht.field_sdtypes = {
        'A': 'categorical',
        'B': 'boolean',
        'C': 'numerical',
    }
    ht.field_transformers = {
        'A': LabelEncoder(),
        multi_column_field: ValidMultiColumnTransformer(),
    }
    ht._multi_column_fields = {
        'B': multi_column_field,
        'C': multi_column_field,
    }

    # Run
    ht._update_multi_column_transformer()

    # Assert
    expected_field_transformers = {
        'A': LabelEncoder(),
        multi_column_field: ValidMultiColumnTransformer(),
    }
    expected_multi_column_fields = {
        'B': multi_column_field,
        'C': multi_column_field,
    }
    assert str(ht.field_transformers) == str(expected_field_transformers)
    assert ht._multi_column_fields == expected_multi_column_fields

def test__update_transformers_multi_column_invalid(self):
    """Test ``_update_multi_column_transformer`` with an invalid multi-column transformer.

    The transformer whose ``_validate_sdtypes`` raises is expected to be removed, with a
    warning, and its columns assigned their default transformers. The multi-column
    field whose transformer is ``None`` is expected to stay untouched.
    """
    # Setup
    class InvalidMultiColumnTransformer(BaseMultiColumnTransformer):
        @classmethod
        def _validate_sdtypes(cls, columns_to_sdtypes):
            raise TransformerInputError('Invalid columns.')

    invalid_field = ('B', 'C')
    untouched_field = ('D', 'E')

    ht = HyperTransformer()
    ht.field_sdtypes = {
        'A': 'categorical',
        'B': 'boolean',
        'C': 'numerical',
        'D': 'categorical',
        'E': 'categorical',
    }
    ht.field_transformers = {
        'A': LabelEncoder(),
        invalid_field: InvalidMultiColumnTransformer(),
        untouched_field: None,
    }
    ht._multi_column_fields = {
        'B': invalid_field,
        'C': invalid_field,
        'D': untouched_field,
        'E': untouched_field,
    }
    warning_pattern = re.escape(
        "Transformer 'InvalidMultiColumnTransformer' is incompatible with the "
        "multi-column field '('B', 'C')'. Assigning default transformer to the columns."
    )

    # Run
    with pytest.warns(UserWarning, match=warning_pattern):
        ht._update_multi_column_transformer()

    # Assert
    expected_field_transformers = {
        'A': LabelEncoder(),
        untouched_field: None,
        'B': UniformEncoder(),
        'C': FloatFormatter(),
    }
    expected_multi_column_fields = {
        'D': untouched_field,
        'E': untouched_field,
    }
    assert str(ht.field_transformers) == str(expected_field_transformers)
    assert ht._multi_column_fields == expected_multi_column_fields

def test_update_transformers_with_multi_column(self):
"""Test ``update_transformers`` with a multi-column transformer."""
# Setup
ht = HyperTransformer()
Expand All @@ -2411,6 +2506,7 @@ def test_update_transformers_multi_column(self):
'C': None,
}
ht._create_multi_column_fields = Mock()
ht._update_multi_column_transformer = Mock()

# Run
ht.update_transformers(column_name_to_transformer)
Expand All @@ -2423,6 +2519,7 @@ def test_update_transformers_multi_column(self):

assert ht.field_transformers == expected_field_transformers
ht._create_multi_column_fields.assert_called_once()
ht._update_multi_column_transformer.assert_called_once()

def test_update_transformers_changing_multi_column_transformer(self):
"""Test ``update_transformers`` when changing a multi column transformer."""
Expand Down Expand Up @@ -2969,9 +3066,12 @@ def test_update_sdtypes_multi_column_with_supported_sdtypes(self):
# Setup
class DummyMultiColumnTransformer(BaseMultiColumnTransformer):
"""Dummy multi column transformer."""

SUPPORTED_SDTYPES = ['categorical', 'boolean']

@classmethod
def _validate_sdtypes(cls, columns_to_sdtypes):
return

ht = HyperTransformer()
ht.field_sdtypes = {
'column1': 'categorical',
Expand All @@ -2988,7 +3088,13 @@ class DummyMultiColumnTransformer(BaseMultiColumnTransformer):
'column2': ('column2', 'column3'),
'column3': ('column2', 'column3')
}
ht._create_multi_column_fields = Mock()
ht._create_multi_column_fields = Mock(
return_value={
'column2': ('column2', 'column3'),
'column3': ('column2', 'column3')
}
)
ht._update_multi_column_transformer = Mock()

# Run
ht.update_sdtypes(column_name_to_sdtype={
Expand All @@ -3012,6 +3118,7 @@ class DummyMultiColumnTransformer(BaseMultiColumnTransformer):
assert ht.field_sdtypes == expected_field_sdtypes
assert str(ht.field_transformers) == str(expected_field_transformers)
ht._create_multi_column_fields.assert_called_once()
ht._update_multi_column_transformer.assert_called_once()

def test_update_sdtypes_multi_column_with_unsupported_sdtypes(self):
"""Test the ``update_sdtypes`` method.
Expand Down
6 changes: 6 additions & 0 deletions tests/unit/transformers/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1422,6 +1422,12 @@ def test__validate_columns_to_sdtypes(self):
with pytest.raises(ValueError, match=expected_error_msg):
transformer._validate_columns_to_sdtypes(data, wrong_columns_to_sdtypes)

def test__validate_sdtypes(self):
    """Test that the base ``_validate_sdtypes`` raises ``NotImplementedError``."""
    # Setup
    columns_to_sdtypes = {}

    # Run and Assert
    with pytest.raises(NotImplementedError):
        BaseMultiColumnTransformer._validate_sdtypes(columns_to_sdtypes)

def test__fit(self):
"""Test the ``_fit`` method.

Expand Down
Loading