-
Notifications
You must be signed in to change notification settings - Fork 25
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add a _update_multi_column_transformer method #758
Changes from 4 commits
74d7303
17aa0d7
f237822
96d8141
e6083f7
e15e9d3
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -7,7 +7,8 @@ | |
import pytest | ||
|
||
from rdt import get_demo | ||
from rdt.errors import ConfigNotSetError, InvalidConfigError, InvalidDataError, NotFittedError | ||
from rdt.errors import ( | ||
ConfigNotSetError, InvalidConfigError, InvalidDataError, NotFittedError, TransformerInputError) | ||
from rdt.hyper_transformer import Config, HyperTransformer | ||
from rdt.transformers import ( | ||
AnonymizedFaker, BaseMultiColumnTransformer, BaseTransformer, BinaryEncoder, | ||
|
@@ -67,6 +68,10 @@ def _fit(self, data): | |
} for column in self.columns | ||
} | ||
|
||
@classmethod | ||
def _validate_sdtypes(cls, columns_to_sdtype): | ||
return None | ||
|
||
def _get_prefix(self): | ||
return None | ||
|
||
|
@@ -1846,3 +1851,80 @@ def test_with_tuple_returned_by_faker(self): | |
] | ||
}) | ||
pd.testing.assert_frame_equal(result, expected_results) | ||
|
||
methods_to_inputs = { | ||
'update_transformers': {'column_name_to_transformer': {'C': UniformEncoder()}}, | ||
'update_transformers_by_sdtype': {'sdtype': 'boolean', 'transformer': UniformEncoder()}, | ||
'remove_transformers': {'column_names': ['C']}, | ||
'remove_transformers_by_sdtype': {'sdtype': 'boolean'}, | ||
} | ||
|
||
@pytest.mark.parametrize('method_name', methods_to_inputs.keys()) | ||
def test_unvalid_multi_column(self, method_name): | ||
"""Test that the Hypertransformer handles invalid multi column transformers.""" | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I think we can be more specific here: add a couple of sentences to the docstring saying that this tests the case where the update/remove call leaves the transformer invalid. |
||
# Setup | ||
class BadDummyMultiColumnTransformer(DummyMultiColumnTransformerNumerical): | ||
|
||
@classmethod | ||
def _validate_sdtypes(cls, columns_to_sdtype): | ||
raise TransformerInputError('Invalid sdtype') | ||
|
||
dict_config = { | ||
'sdtypes': { | ||
'A': 'categorical', | ||
'B': 'categorical', | ||
'D': 'categorical', | ||
'E': 'categorical', | ||
'C': 'boolean', | ||
}, | ||
'transformers': { | ||
'A': UniformEncoder(), | ||
('B', 'D', 'C'): BadDummyMultiColumnTransformer(), | ||
'E': UniformEncoder() | ||
} | ||
} | ||
|
||
config = Config(dict_config) | ||
ht = HyperTransformer() | ||
ht.set_config(config) | ||
|
||
# Run | ||
expected_warning = re.escape( | ||
"Transformer 'BadDummyMultiColumnTransformer' is incompatible with the " | ||
"multi-column field '('B', 'D')'. Assigning default transformer to the columns." | ||
) | ||
with pytest.warns(UserWarning, match=expected_warning): | ||
ht.__getattribute__(method_name)(**self.methods_to_inputs[method_name]) | ||
|
||
# Assert | ||
new_config = ht.get_config() | ||
expected_dict_config = { | ||
'sdtypes': { | ||
'A': 'categorical', | ||
'B': 'categorical', | ||
'D': 'categorical', | ||
'E': 'categorical', | ||
'C': 'boolean' | ||
} | ||
} | ||
if method_name.startswith('update'): | ||
expected_dict_config['transformers'] = { | ||
'A': UniformEncoder(), | ||
'E': UniformEncoder(), | ||
'C': UniformEncoder(), | ||
'B': UniformEncoder(), | ||
'D': UniformEncoder() | ||
} | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. You could parametrize this part as well. Basically, instead of `methods_to_inputs` being a dict, make it a list of tuples where each tuple pairs the method name and its inputs with the expected transformers. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. That's a good point, thanks. Done in e6083f7. |
||
else: | ||
expected_dict_config['transformers'] = { | ||
'A': UniformEncoder(), | ||
'E': UniformEncoder(), | ||
'C': None, | ||
'B': UniformEncoder(), | ||
'D': UniformEncoder() | ||
} | ||
|
||
expected_config = Config(expected_dict_config) | ||
expected_multi_columns = {} | ||
assert ht._multi_column_fields == expected_multi_columns | ||
assert repr(new_config) == repr(expected_config) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
minor: `unvalid` should be `invalid` (in the test name `test_unvalid_multi_column`).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes thanks, done in e6083f7