Simplify category filtering in the charts router and add unit tests.

* filter_valid_projects now returns only the rows whose category is in the
  requested list, instead of re-selecting every row of the matching projects.
* generate_date_bins annotates num_bins as `int | None` to match its default.
* tests/test_charts.py gains a sample_projects fixture plus parametrized tests
  for filter_valid_projects and projects_by_category.

diff --git a/offsets_db_api/routers/charts.py b/offsets_db_api/routers/charts.py
index c5262cf..ea8ab8e 100644
--- a/offsets_db_api/routers/charts.py
+++ b/offsets_db_api/routers/charts.py
@@ -30,16 +30,7 @@ def filter_valid_projects(df: pd.DataFrame, categories: list | None = None) -> p
         return df
     # Filter the dataframe to include only rows with the specified categories
     valid_projects = df[df['category'].isin(categories)]
-
-    # Group by project and filter out projects that have any categories outside the given list
-    def all_categories_valid(group):
-        return all(category in categories for category in group['category'].unique())
-
-    valid_project_ids = (
-        valid_projects.groupby('project_id').filter(all_categories_valid).project_id.unique()
-    )
-
-    return df[df['project_id'].isin(valid_project_ids)]
+    return valid_projects
 
 
 def projects_by_category(
@@ -85,7 +76,7 @@ def generate_date_bins(
     min_value,
     max_value,
     freq: typing.Literal['D', 'W', 'M', 'Y'] | None = None,
-    num_bins: int = None,
+    num_bins: int | None = None,
 ):
     """
     Generate date bins with the specified frequency.
diff --git a/tests/test_charts.py b/tests/test_charts.py
index 98f5047..92b236c 100644
--- a/tests/test_charts.py
+++ b/tests/test_charts.py
@@ -1,5 +1,71 @@
+import pandas as pd
 import pytest
 
+from offsets_db_api.routers.charts import filter_valid_projects, projects_by_category
+
+
+@pytest.fixture
+def sample_projects():
+    data = pd.DataFrame(
+        [
+            {'category': 'ghg-management', 'project_id': 'ACR123', 'issued': 100, 'retired': 50},
+            {'category': 'renewable-energy', 'project_id': 'ACR456', 'issued': 200, 'retired': 150},
+            {'category': 'biodiversity', 'project_id': 'VER789', 'issued': 300, 'retired': 250},
+            {
+                'category': 'water-management',
+                'project_id': 'CDM1011',
+                'issued': 400,
+                'retired': 350,
+            },
+            {'category': 'ghg-management', 'project_id': 'ACR111', 'issued': 500, 'retired': 450},
+        ]
+    )
+
+    return data
+
+
+@pytest.mark.parametrize(
+    'categories, expected',
+    [
+        (None, ['ghg-management', 'renewable-energy', 'biodiversity', 'water-management']),
+        (['renewable-energy'], ['renewable-energy']),
+    ],
+)
+def test_filter_valid_projects(sample_projects, categories, expected):
+    result = filter_valid_projects(sample_projects, categories=categories)
+    assert list(result['category'].unique()) == expected
+
+
+@pytest.mark.parametrize(
+    'categories,expected',
+    [
+        (
+            None,
+            [
+                {'category': 'ghg-management', 'value': 2},
+                {'category': 'renewable-energy', 'value': 1},
+                {'category': 'biodiversity', 'value': 1},
+                {'category': 'water-management', 'value': 1},
+            ],
+        ),
+        (['ghg-management'], [{'category': 'ghg-management', 'value': 2}]),
+        (
+            ['renewable-energy', 'biodiversity'],
+            [
+                {'category': 'renewable-energy', 'value': 1},
+                {'category': 'biodiversity', 'value': 1},
+            ],
+        ),
+        (['non-existent-category'], []),
+        ([], []),
+    ],
+)
+def test_projects_by_category(categories, expected, sample_projects):
+    result = projects_by_category(df=sample_projects, categories=categories)
+    sorted_result = sorted(result, key=lambda x: x['category'])
+    sorted_expected = sorted(expected, key=lambda x: x['category'])
+    assert sorted_result == sorted_expected
+
 
 @pytest.mark.parametrize('freq', ['D', 'M', 'Y', 'W'])
 @pytest.mark.parametrize('registry', ['american-carbon-registry', 'climate-action-reserve'])