Skip to content

Commit

Permalink
coverage 1
Browse files Browse the repository at this point in the history
  • Loading branch information
R-Palazzo committed Jan 22, 2024
1 parent 57d12be commit 29f6c33
Show file tree
Hide file tree
Showing 4 changed files with 56 additions and 49 deletions.
27 changes: 1 addition & 26 deletions rdt/transformers/categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,36 +9,11 @@

from rdt.errors import TransformerInputError
from rdt.transformers.base import BaseTransformer
from rdt.transformers.utils import check_nan_in_transform, fill_nan_with_none
from rdt.transformers.utils import check_nan_in_transform, fill_nan_with_none, try_convert_to_dtype

LOGGER = logging.getLogger(__name__)


def try_convert_to_dtype(data, dtype):
"""Try to convert data to a given dtype.
Args:
data (pd.Series or numpy.ndarray):
Data to convert.
dtype (str):
Data type to convert to.
Returns:
data:
Data converted to the given dtype.
"""
try:
data = data.astype(dtype)
except ValueError as error:
is_integer = pd.api.types.is_integer_dtype(dtype)
if is_integer:
data = data.astype(float)
else:
raise error

return data


class UniformEncoder(BaseTransformer):
"""Transformer for categorical data.
Expand Down
25 changes: 25 additions & 0 deletions rdt/transformers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,3 +209,28 @@ def check_nan_in_transform(data, dtype):
message += '.'

warnings.warn(message)


def try_convert_to_dtype(data, dtype):
"""Try to convert data to a given dtype.
Args:
data (pd.Series or numpy.ndarray):
Data to convert.
dtype (str):
Data type to convert to.
Returns:
data:
Data converted to the given dtype.
"""
try:
data = data.astype(dtype)
except ValueError as error:
is_integer = pd.api.types.is_integer_dtype(dtype)
if is_integer:
data = data.astype(float)
else:
raise error

return data
24 changes: 2 additions & 22 deletions tests/unit/transformers/test_categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,32 +9,11 @@
from rdt.errors import TransformerInputError
from rdt.transformers.categorical import (
CustomLabelEncoder, FrequencyEncoder, LabelEncoder, OneHotEncoder, OrderedLabelEncoder,
OrderedUniformEncoder, UniformEncoder, try_convert_to_dtype)
OrderedUniformEncoder, UniformEncoder)

RE_SSN = re.compile(r'\d\d\d-\d\d-\d\d\d\d')


def test_try_convert_to_dtype():
"""Test ``try_convert_to_dtype`` method.
If the data can be converted to the specified dtype, it should be converted.
If the data cannot be converted, a ValueError should be raised.
Should allow to convert integer with NaNs to float.
"""
# Setup
data_int_with_nan = pd.Series([1.0, 2.0, np.nan, 4.0, 5.0])
data_not_convetible = pd.Series(['a', 'b', 'c', 'd', 'e'])

# Run
output_int_with_nan = try_convert_to_dtype(data_int_with_nan, 'int')
with pytest.raises(ValueError, match="could not convert string to float: 'a'"):
try_convert_to_dtype(data_not_convetible, 'int')

# Assert
expected_data_with_nan = pd.Series([1, 2, np.nan, 4, 5])
pd.testing.assert_series_equal(output_int_with_nan, expected_data_with_nan)


class TestUniformEncoder:
"""Test class for the UniformEncoder."""

Expand Down Expand Up @@ -2359,6 +2338,7 @@ def test__fit(self):
transformer._fit(data)

# Assert
assert transformer.dtype == 'float'
expected_values_to_categories = {0: 2, 1: 3, 2: np.nan, 3: 1}
expected_categories_to_values = {2: 0, 3: 1, 1: 3, np.nan: 2}
for key, value in transformer.values_to_categories.items():
Expand Down
29 changes: 28 additions & 1 deletion tests/unit/transformers/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
import pytest

from rdt.transformers.utils import (
_any, _max_repeat, check_nan_in_transform, flatten_column_list, strings_from_regex)
_any, _max_repeat, check_nan_in_transform, flatten_column_list, strings_from_regex,
try_convert_to_dtype)


def test_strings_from_regex_literal():
Expand Down Expand Up @@ -81,8 +82,10 @@ def test_check_nan_in_transform():
"""
# Setup
transformed = pd.Series([0.1026, 0.1651, np.nan, 0.3116, 0.6546, 0.8541, 0.7041])
data_without_nans = pd.Series([0.1026, 0.1651, 0.3116, 0.6546, 0.8541, 0.7041])

# Run and Assert
check_nan_in_transform(data_without_nans, 'float')
expected_message = (
'There are null values in the transformed data. The reversed '
'transformed data will contain null values'
Expand All @@ -94,3 +97,27 @@ def test_check_nan_in_transform():

with pytest.warns(UserWarning, match=expected_message_integer):
check_nan_in_transform(transformed, 'int')


def test_try_convert_to_dtype():
"""Test ``try_convert_to_dtype`` method.
If the data can be converted to the specified dtype, it should be converted.
If the data cannot be converted, a ValueError should be raised.
Should allow to convert integer with NaNs to float.
"""
# Setup
data_int_with_nan = pd.Series([1.0, 2.0, np.nan, 4.0, 5.0])
data_not_convertible = pd.Series(['a', 'b', 'c', 'd', 'e'])

# Run
output_convertibe = try_convert_to_dtype(data_int_with_nan, 'str')
output_int_with_nan = try_convert_to_dtype(data_int_with_nan, 'int')
with pytest.raises(ValueError, match="could not convert string to float: 'a'"):
try_convert_to_dtype(data_not_convertible, 'int')

# Assert
expected_data_with_nan = pd.Series([1, 2, np.nan, 4, 5])
expected_data_covertibe = pd.Series(['1.0', '2.0', 'nan', '4.0', '5.0'])
pd.testing.assert_series_equal(output_int_with_nan, expected_data_with_nan)
pd.testing.assert_series_equal(output_convertibe, expected_data_covertibe)

0 comments on commit 29f6c33

Please sign in to comment.