Skip to content

Commit

Permalink
add estimator tests
Browse files Browse the repository at this point in the history
  • Loading branch information
caiodallaqua committed Jan 22, 2024
1 parent db11231 commit 6eaa16c
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 7 deletions.
6 changes: 6 additions & 0 deletions .github/workflows/checks.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
name: Checks

on:
push:
branches: [main]
paths-ignore:
- '**/*.md'
- '**/*.png'
- '**/*.json'
pull_request:
types: [opened, synchronize, reopened, ready_for_review]
paths-ignore:
Expand Down
34 changes: 34 additions & 0 deletions tests/test_estimator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import pier_ds_utils as ds
from sklearn.base import BaseEstimator


def test_glm_wrapper():
wrapper = ds.estimator.GLMWrapper()

assert wrapper is not None

# Check attributes
assert hasattr(wrapper, 'os_factor')
assert hasattr(wrapper, 'init_params')

# Check methods
assert hasattr(wrapper, 'fit')
assert hasattr(wrapper, 'predict')
assert hasattr(wrapper, 'get_params')


def test_predict_proba_selector():
selector = ds.estimator.PredictProbaSelector(
model=BaseEstimator(),
)

assert selector is not None

# Check attributes
assert hasattr(selector, 'model')
assert hasattr(selector, 'column')

# Check methods
assert hasattr(selector, 'fit')
assert hasattr(selector, 'predict_proba')
assert hasattr(selector, 'get_params')
14 changes: 7 additions & 7 deletions tests/test_transformer.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import pandas as pd
from pier_ds_utils import transformer
import pier_ds_utils as ds


def test_custom_discrete_categorizer():
categorizer = transformer.CustomDiscreteCategorizer(
categorizer = ds.transformer.CustomDiscreteCategorizer(
column='gender',
categories=[
['M', 'm', 'Masculino', 'masculino'],
Expand Down Expand Up @@ -53,7 +53,7 @@ def test_custom_discrete_categorizer():


def test_custom_interval_categorizer():
categorizer = transformer.CustomIntervalCategorizer(
categorizer = ds.transformer.CustomIntervalCategorizer(
column='price',
intervals=[
(498, 2700),
Expand Down Expand Up @@ -99,10 +99,10 @@ def test_custom_interval_categorizer():
]

def test_custom_interval_categorizer_by_category():
categorizer = transformer.CustomIntervalCategorizerByCategory(
categorizer = ds.transformer.CustomIntervalCategorizerByCategory(
category_column='brand',
interval_categorizers={
'apple': transformer.CustomIntervalCategorizer(
'apple': ds.transformer.CustomIntervalCategorizer(
column='price',
intervals=[
(498, 2700),
Expand All @@ -112,7 +112,7 @@ def test_custom_interval_categorizer_by_category():
],
labels=['fx1_apple', 'fx2_apple', 'fx3_apple', 'fx4_apple'],
),
'samsung': transformer.CustomIntervalCategorizer(
'samsung': ds.transformer.CustomIntervalCategorizer(
column='price',
intervals=[
(189, 1500),
Expand All @@ -121,7 +121,7 @@ def test_custom_interval_categorizer_by_category():
labels=['fx1_samsung', 'fx2_samsung'],
)
},
default_categorizer=transformer.CustomIntervalCategorizer(
default_categorizer=ds.transformer.CustomIntervalCategorizer(
column='price',
intervals=[(240, 5260)],
labels=['fx_outras_marcas'],
Expand Down

0 comments on commit 6eaa16c

Please sign in to comment.