From 6eaa16cdd9927a9cf4c2fae1c9a189aaa34740e0 Mon Sep 17 00:00:00 2001 From: caiodallaqua Date: Mon, 22 Jan 2024 14:54:02 -0300 Subject: [PATCH] add estimator tests --- .github/workflows/checks.yaml | 6 ++++++ tests/test_estimator.py | 34 ++++++++++++++++++++++++++++++++++ tests/test_transformer.py | 14 +++++++------- 3 files changed, 47 insertions(+), 7 deletions(-) create mode 100644 tests/test_estimator.py diff --git a/.github/workflows/checks.yaml b/.github/workflows/checks.yaml index 23d383f..1683a2c 100644 --- a/.github/workflows/checks.yaml +++ b/.github/workflows/checks.yaml @@ -1,6 +1,12 @@ name: Checks on: + push: + branches: [main] + paths-ignore: + - '**/*.md' + - '**/*.png' + - '**/*.json' pull_request: types: [opened, synchronize, reopened, ready_for_review] paths-ignore: diff --git a/tests/test_estimator.py b/tests/test_estimator.py new file mode 100644 index 0000000..cd4ef19 --- /dev/null +++ b/tests/test_estimator.py @@ -0,0 +1,34 @@ +import pier_ds_utils as ds +from sklearn.base import BaseEstimator + + +def test_glm_wrapper(): + wrapper = ds.estimator.GLMWrapper() + + assert wrapper is not None + + # Check attributes + assert hasattr(wrapper, 'os_factor') + assert hasattr(wrapper, 'init_params') + + # Check methods + assert hasattr(wrapper, 'fit') + assert hasattr(wrapper, 'predict') + assert hasattr(wrapper, 'get_params') + + +def test_predict_proba_selector(): + selector = ds.estimator.PredictProbaSelector( + model=BaseEstimator(), + ) + + assert selector is not None + + # Check attributes + assert hasattr(selector, 'model') + assert hasattr(selector, 'column') + + # Check methods + assert hasattr(selector, 'fit') + assert hasattr(selector, 'predict_proba') + assert hasattr(selector, 'get_params') diff --git a/tests/test_transformer.py b/tests/test_transformer.py index 1553d8c..4588ce0 100644 --- a/tests/test_transformer.py +++ b/tests/test_transformer.py @@ -1,9 +1,9 @@ import pandas as pd -from pier_ds_utils import transformer +import pier_ds_utils as ds def test_custom_discrete_categorizer(): - categorizer = transformer.CustomDiscreteCategorizer( + categorizer = ds.transformer.CustomDiscreteCategorizer( column='gender', categories=[ ['M', 'm', 'Masculino', 'masculino'], @@ -53,7 +53,7 @@ def test_custom_discrete_categorizer(): def test_custom_interval_categorizer(): - categorizer = transformer.CustomIntervalCategorizer( + categorizer = ds.transformer.CustomIntervalCategorizer( column='price', intervals=[ (498, 2700), @@ -99,10 +99,10 @@ def test_custom_interval_categorizer(): ] def test_custom_interval_categorizer_by_category(): - categorizer = transformer.CustomIntervalCategorizerByCategory( + categorizer = ds.transformer.CustomIntervalCategorizerByCategory( category_column='brand', interval_categorizers={ - 'apple': transformer.CustomIntervalCategorizer( + 'apple': ds.transformer.CustomIntervalCategorizer( column='price', intervals=[ (498, 2700), @@ -112,7 +112,7 @@ def test_custom_interval_categorizer_by_category(): ], labels=['fx1_apple', 'fx2_apple', 'fx3_apple', 'fx4_apple'], ), - 'samsung': transformer.CustomIntervalCategorizer( + 'samsung': ds.transformer.CustomIntervalCategorizer( column='price', intervals=[ (189, 1500), @@ -121,7 +121,7 @@ def test_custom_interval_categorizer_by_category(): labels=['fx1_samsung', 'fx2_samsung'], ) }, - default_categorizer=transformer.CustomIntervalCategorizer( + default_categorizer=ds.transformer.CustomIntervalCategorizer( column='price', intervals=[(240, 5260)], labels=['fx_outras_marcas'],