Skip to content

Commit

Permalink
enable search across beneficiary columns on /projects endpoint (#138)
Browse files Browse the repository at this point in the history
  • Loading branch information
andersy005 authored Jan 23, 2025
1 parent 6148099 commit cbb9494
Show file tree
Hide file tree
Showing 4 changed files with 77 additions and 19 deletions.
12 changes: 6 additions & 6 deletions offsets_db_api/routers/credits.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,11 @@ async def get_credits(
transaction_date_to: datetime.datetime | datetime.date | None = Query(
default=None, description='Format: YYYY-MM-DD'
),
search: str | None = Query(
beneficiary_search: str | None = Query(
None,
description='Case insensitive search string. Currently searches in fields specified in `search_fileds` parameter',
),
search_fields: list[str] = Query(
beneficiary_search_fields: list[str] = Query(
default=[
'retirement_beneficiary',
'retirement_account',
Expand Down Expand Up @@ -85,11 +85,11 @@ async def get_credits(
operation=operation,
)

if search:
if beneficiary_search:
# Default to case-insensitive partial match
search_term = f'%{search}%'
search_term = f'%{beneficiary_search}%'
search_conditions = []
for field in search_fields:
for field in beneficiary_search_fields:
if field in Credit.__table__.columns:
search_conditions.append(getattr(Credit, field).ilike(search_term))
elif field in Project.__table__.columns:
Expand All @@ -101,7 +101,7 @@ async def get_credits(
if sort:
statement = apply_sorting(statement=statement, sort=sort, model=Credit, primary_key='id')

logger.info(f"SQL Credits Query: {statement.compile(compile_kwargs={'literal_binds': True})}")
logger.info(f'SQL Credits Query: {statement.compile(compile_kwargs={"literal_binds": True})}')

total_entries, current_page, total_pages, next_page, results = handle_pagination(
statement=statement,
Expand Down
68 changes: 55 additions & 13 deletions offsets_db_api/routers/projects.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,16 @@
from fastapi import APIRouter, Depends, HTTPException, Query, Request, status
from fastapi_cache.decorator import cache
from sqlalchemy import or_
from sqlmodel import Session, col, select
from sqlalchemy.orm import aliased
from sqlmodel import Session, col, distinct, select

from offsets_db_api.cache import CACHE_NAMESPACE
from offsets_db_api.database import get_session
from offsets_db_api.log import get_logger
from offsets_db_api.models import (
Clip,
ClipProject,
Credit,
PaginatedProjects,
Project,
ProjectType,
Expand Down Expand Up @@ -51,6 +53,19 @@ async def get_projects(
None,
description='Case insensitive search string. Currently searches on `project_id` and `name` fields only.',
),
beneficiary_search: str | None = Query(
None,
description='Case insensitive search string. Currently searches on specified beneficiary_search_fields only.',
),
beneficiary_search_fields: list[str] = Query(
default=[
'retirement_beneficiary',
'retirement_account',
'retirement_note',
'retirement_reason',
],
description='Beneficiary fields to search in',
),
current_page: int = Query(1, description='Page number', ge=1),
per_page: int = Query(100, description='Items per page', le=200, ge=1),
sort: list[str] = Query(
Expand All @@ -64,6 +79,9 @@ async def get_projects(

logger.info(f'Getting projects: {request.url}')

# Base query without Credit join
matching_projects = select(distinct(Project.project_id))

filters = [
('registry', registry, 'ilike', Project),
('country', country, 'ilike', Project),
Expand All @@ -79,9 +97,42 @@ async def get_projects(
('project_type', project_type, 'ilike', ProjectType),
]

# Modified to include ProjectType in the initial query
statement = select(Project, ProjectType.project_type, ProjectType.source).outerjoin(
ProjectType, Project.project_id == ProjectType.project_id
if search:
search_pattern = f'%{search}%'
matching_projects = matching_projects.where(
or_(
col(Project.project_id).ilike(search_pattern),
col(Project.name).ilike(search_pattern),
)
)

if beneficiary_search:
Credit_alias = aliased(Credit)
matching_projects = matching_projects.outerjoin(
Credit_alias, col(Project.project_id) == col(Credit_alias.project_id)
)
beneficiary_search_pattern = f'%{beneficiary_search}%'
beneficiary_search_conditions = []

for field in beneficiary_search_fields:
if hasattr(Credit_alias, field):
beneficiary_search_conditions.append(
getattr(Credit_alias, field).ilike(beneficiary_search_pattern)
)
elif hasattr(Project, field):
beneficiary_search_conditions.append(
getattr(Project, field).ilike(beneficiary_search_pattern)
)

matching_projects = matching_projects.where(or_(*beneficiary_search_conditions))

matching_projects_select = select(matching_projects.subquery())

# Use the subquery to filter the main query
statement = (
select(Project, ProjectType.project_type, ProjectType.source)
.outerjoin(ProjectType, col(Project.project_id) == col(ProjectType.project_id))
.where(col(Project.project_id).in_(matching_projects_select))
)

for attribute, values, operation, model in filters:
Expand All @@ -93,15 +144,6 @@ async def get_projects(
operation=operation,
)

if search:
search_pattern = f'%{search}%'
statement = statement.where(
or_(
col(Project.project_id).ilike(search_pattern),
col(Project.name).ilike(search_pattern),
)
)

if sort:
statement = apply_sorting(
statement=statement, sort=sort, model=Project, primary_key='project_id'
Expand Down
8 changes: 8 additions & 0 deletions tests/test_credits.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,3 +128,11 @@ def test_get_credits_with_invalid_sort(test_app: TestClient):
response = test_app.get('/credits/?sort=invalid_field')
assert response.status_code == 400
assert 'Invalid sort field' in response.json()['detail']


@pytest.mark.parametrize('beneficiary_search', ['foo'])
def test_credits_beneficiary_search(test_app: TestClient, beneficiary_search):
response = test_app.get(f'/credits?beneficiary_search={beneficiary_search}')
assert response.status_code == 200
data = response.json()['data']
assert isinstance(data, list)
8 changes: 8 additions & 0 deletions tests/test_projects.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,14 @@ def test_get_projects_with_filters(
assert issued_min <= project['issued'] <= issued_max


@pytest.mark.parametrize('beneficiary_search', ['foo'])
def test_projects_beneficiary_search(test_app: TestClient, beneficiary_search):
response = test_app.get(f'/projects?beneficiary_search={beneficiary_search}')
assert response.status_code == 200
data = response.json()['data']
assert isinstance(data, list)


def test_get_projects_with_invalid_sort(test_app: TestClient):
response = test_app.get('/projects?sort=+foo')
assert response.status_code == 400
Expand Down

0 comments on commit cbb9494

Please sign in to comment.