feat: add unit tests and PR workflow
patheard committed Jan 18, 2025
1 parent 95af8d7 commit 19eed31
Showing 4 changed files with 296 additions and 11 deletions.
49 changes: 49 additions & 0 deletions .github/workflows/etl-pull-request.yml
@@ -0,0 +1,49 @@
name: ETL pull request tests

on:
  workflow_dispatch:
  pull_request:
    paths:
      - "terragrunt/aws/glue/etl/**"
      - ".github/workflows/etl-pull-request.yml"

env:
  ETL_BASE_PATH: terragrunt/aws/glue/etl

jobs:
  etl-test:
    runs-on: ubuntu-latest
    strategy:
      fail-fast: false
      matrix:
        etl:
          - platform/support/freshdesk/scripts
    steps:
      - name: Audit DNS requests
        uses: cds-snc/dns-proxy-action@main
        env:
          DNS_PROXY_FORWARDTOSENTINEL: "true"
          DNS_PROXY_LOGANALYTICSWORKSPACEID: ${{ secrets.LOG_ANALYTICS_WORKSPACE_ID }}
          DNS_PROXY_LOGANALYTICSSHAREDKEY: ${{ secrets.LOG_ANALYTICS_WORKSPACE_KEY }}

      - name: Checkout
        uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2

      - uses: dorny/paths-filter@de90cc6fb38fc0963ad72b210f1f284cd68cea36 # v3.0.2
        id: changes
        with:
          filters: |
            etl:
              - '${{ env.ETL_BASE_PATH }}/${{ matrix.etl }}/**'
              - '.github/workflows/etl-pull-request.yml'
      - name: Setup python
        if: steps.changes.outputs.etl == 'true'
        uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
        with:
          python-version: "3.13"

      - name: Run PR tests
        if: steps.changes.outputs.etl == 'true'
        working-directory: ${{ env.ETL_BASE_PATH }}/${{ matrix.etl }}
        run: make ARGS=--check pull_request
@@ -13,6 +13,8 @@ install_dev:
lint:
	flake8 --ignore=E501 *.py

pull_request: install install_dev fmt lint test

test:
	python -m pytest -s -vv .

@@ -21,4 +23,5 @@ test:
	install \
	install_dev \
	lint \
	pull_request \
	test
@@ -36,13 +36,10 @@
DATABASE_NAME_TRANSFORMED = args["database_name_transformed"]
TABLE_NAME = args["table_name"]

sparkContext = SparkContext()
glueContext = GlueContext(sparkContext)
job = Job(glueContext)
job.init(JOB_NAME, args)

glueContext = GlueContext(SparkContext.getOrCreate())
logger = glueContext.get_logger()


def validate_schema(dataframe: pd.DataFrame, glue_table_schema: pd.DataFrame) -> bool:
    """
    Validate that the DataFrame conforms to the Glue table schema.
@@ -59,7 +56,6 @@ def validate_schema(dataframe: pd.DataFrame, glue_table_schema: pd.DataFrame) ->
                f"Validation failed: Column '{column_name}' type mismatch. Expected {column_type}"
            )
            return False

    return True


@@ -98,11 +94,13 @@ def get_days_tickets(day: datetime) -> pd.DataFrame:
    logger.info(f"Loading source JSON file: {source_file_path}")
    new_tickets = pd.DataFrame()
    try:
        new_tickets = wr.s3.read_json(source_file_path, dtype = True)
        new_tickets = wr.s3.read_json(source_file_path, dtype=True)

        # Ensure date columns are parsed correctly and all timezones are treated as UTC
        for date_column in ["created_at", "updated_at", "due_by", "fr_due_by"]:
            new_tickets[date_column] = pd.to_datetime(new_tickets[date_column], errors="coerce")
            new_tickets[date_column] = pd.to_datetime(
                new_tickets[date_column], errors="coerce"
            )
            new_tickets[date_column] = new_tickets[date_column].dt.tz_localize(None)

    except wr.exceptions.NoFilesFound:
@@ -125,7 +123,9 @@ def get_existing_tickets(start_date: str) -> pd.DataFrame:
                lambda partition: partition[PARTITION_KEY] >= start_date_formatted
            ),
        )
        existing_tickets["updated_at"] = existing_tickets["updated_at"].dt.tz_localize(None)  # Treat all as UTC
        existing_tickets["updated_at"] = existing_tickets["updated_at"].dt.tz_localize(
            None
        )  # Treat all as UTC
    except wr.exceptions.NoFilesFound:
        logger.warn("No existing data found. Starting fresh.")

@@ -189,6 +189,9 @@ def process_tickets():
    )
    logger.info("ETL process completed successfully.")

process_tickets()

job.commit()
if __name__ == "__main__":
    job = Job(glueContext)
    job.init(JOB_NAME, args)
    process_tickets()
    job.commit()
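
The JOB_NAME and args values used above are resolved near the top of process_tickets.py, outside this hunk. A minimal sketch of that setup, assuming the standard awsglue getResolvedOptions call and the same argument names the test fixtures below mock (illustrative only, not the file's actual lines):

import sys

from awsglue.utils import getResolvedOptions

# Resolve Glue job arguments; these names mirror the mock_args dict used in the tests.
args = getResolvedOptions(
    sys.argv,
    [
        "JOB_NAME",
        "source_bucket",
        "source_prefix",
        "transformed_bucket",
        "transformed_prefix",
        "database_name_raw",
        "database_name_transformed",
        "table_name",
    ],
)
JOB_NAME = args["JOB_NAME"]
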
@@ -0,0 +1,230 @@
import pytest
import sys

from datetime import datetime, UTC
from unittest.mock import Mock, patch

import pandas as pd

from awswrangler.exceptions import NoFilesFound

# Create mock for getResolvedOptions that returns test arguments
mock_args = {
    "JOB_NAME": "test_job",
    "source_bucket": "test-source-bucket",
    "source_prefix": "test-source-prefix/",
    "transformed_bucket": "test-transformed-bucket",
    "transformed_prefix": "test-transformed-prefix/",
    "database_name_raw": "test_raw_db",
    "database_name_transformed": "test_transformed_db",
    "table_name": "test_table",
}

# Mock the AWS Glue and PySpark modules
mock_glue_utils = Mock()
mock_glue_utils.getResolvedOptions.return_value = mock_args
sys.modules["awsglue.utils"] = mock_glue_utils
sys.modules["awsglue.context"] = Mock()
sys.modules["awsglue.job"] = Mock()
sys.modules["pyspark.context"] = Mock()

# flake8: noqa: E402
from process_tickets import (
    validate_schema,
    is_type_compatible,
    merge_tickets,
    process_tickets,
    get_days_tickets,
)


# Mock the AWS Glue and PySpark dependencies
@pytest.fixture
def mock_glue_context():
    mock_logger = Mock()
    mock_logger.info = Mock()
    mock_logger.error = Mock()
    mock_logger.warn = Mock()

    mock_context = Mock()
    mock_context.get_logger.return_value = mock_logger
    return mock_context


@pytest.fixture
def mock_spark_context():
    return Mock()


@pytest.fixture
def mock_job():
    return Mock()


# Sample test data fixtures
@pytest.fixture
def sample_tickets_df():
    return pd.DataFrame(
        {
            "id": ["1", "2", "3"],
            "subject": ["Test 1", "Test 2", "Test 3"],
            "created_at": pd.to_datetime(["2024-01-01", "2024-01-02", "2024-01-03"]),
            "updated_at": pd.to_datetime(["2024-01-01", "2024-01-02", "2024-01-03"]),
            "due_by": pd.to_datetime(["2024-01-05", "2024-01-06", "2024-01-07"]),
            "fr_due_by": pd.to_datetime(["2024-01-04", "2024-01-05", "2024-01-06"]),
            "status": ["open", "pending", "closed"],
            "priority": [1, 2, 3],
        }
    )


@pytest.fixture
def glue_table_schema():
    return pd.DataFrame(
        {
            "Column Name": [
                "id",
                "subject",
                "created_at",
                "updated_at",
                "due_by",
                "fr_due_by",
                "status",
                "priority",
            ],
            "Type": [
                "string",
                "string",
                "timestamp",
                "timestamp",
                "timestamp",
                "timestamp",
                "string",
                "int",
            ],
        }
    )


def test_validate_schema_valid(sample_tickets_df, glue_table_schema):
    assert validate_schema(sample_tickets_df, glue_table_schema) is True


def test_validate_schema_missing_column(sample_tickets_df, glue_table_schema):
    df_missing_column = sample_tickets_df.drop("status", axis=1)
    assert validate_schema(df_missing_column, glue_table_schema) is False


def test_validate_schema_wrong_type(sample_tickets_df, glue_table_schema):
    df_wrong_type = sample_tickets_df.copy()
    df_wrong_type["priority"] = pd.to_datetime(
        ["2024-01-04", "2024-01-05", "2024-01-06"]
    )
    assert validate_schema(df_wrong_type, glue_table_schema) is False


def test_is_type_compatible():
    assert is_type_compatible(pd.Series(["a", "b", "c"]), "string") is True
    assert is_type_compatible(pd.Series([1, 2, 3]), "int") is True
    assert is_type_compatible(pd.Series([1.1, 2.2, 3.3]), "double") is True
    assert is_type_compatible(pd.Series([True, False]), "boolean") is True
    assert is_type_compatible(pd.Series(["a", "b", "c"]), "int") is False
    assert is_type_compatible(pd.Series(["a", "b", "c"]), "foobar") is False


# Test ticket merging functionality
def test_merge_tickets_empty_existing(sample_tickets_df):
    existing_tickets = pd.DataFrame()
    merged = merge_tickets(existing_tickets, sample_tickets_df)

    assert len(merged) == len(sample_tickets_df)
    assert all(merged["id"] == sample_tickets_df["id"])


def test_merge_tickets_with_duplicates():
    # Create existing tickets with some overlap
    existing_tickets = pd.DataFrame(
        {
            "id": ["1", "2"],
            "subject": ["Old 1", "Old 2"],
            "updated_at": pd.to_datetime(["2024-01-01", "2024-01-02"]),
        }
    )

    # Create new tickets with updated information for id=1
    new_tickets = pd.DataFrame(
        {
            "id": ["1", "3"],
            "subject": ["Updated 1", "New 3"],
            "updated_at": pd.to_datetime(["2024-01-03", "2024-01-03"]),
        }
    )

    merged = merge_tickets(existing_tickets, new_tickets)

    assert len(merged) == 3  # Should have 3 unique tickets
    assert (
        merged[merged["id"] == "1"]["subject"].iloc[0] == "Updated 1"
    )  # Should keep newer version


# Test the main process with mocked AWS services
@patch("awswrangler.s3")
@patch("awswrangler.catalog")
def test_process_tickets(
    mock_wr_catalog, mock_wr_s3, sample_tickets_df, glue_table_schema, mock_glue_context
):
    # Mock AWS Wrangler responses
    mock_wr_s3.read_json.return_value = sample_tickets_df
    mock_wr_catalog.table.return_value = glue_table_schema
    mock_wr_s3.read_parquet.return_value = sample_tickets_df

    # Run the process
    with patch("process_tickets.glueContext", mock_glue_context):
        process_tickets()

    # Verify the write operation was called
    mock_wr_s3.to_parquet.assert_called_once()


# Test error handling
@patch("awswrangler.s3")
@patch("awswrangler.catalog")
def test_process_tickets_no_new_data(
    mock_wr_catalog, mock_wr_s3, glue_table_schema, mock_glue_context
):
    # Mock empty response from S3
    mock_wr_s3.read_json.side_effect = NoFilesFound("Simulate no file for read_json")
    mock_wr_catalog.table.return_value = glue_table_schema

    # Run the process
    with patch("process_tickets.glueContext", mock_glue_context):
        process_tickets()

    # Verify no write operation was attempted
    mock_wr_s3.to_parquet.assert_not_called()


# Test date handling
def test_get_days_tickets_date_handling(mock_glue_context):
    test_date = datetime(2024, 1, 1, tzinfo=UTC)

    with patch("awswrangler.s3.read_json") as mock_read_json:
        # Create test data with timezone-aware timestamps
        test_data = pd.DataFrame(
            {
                "created_at": [pd.Timestamp("2024-01-01 10:00:00+00:00")],
                "updated_at": [pd.Timestamp("2024-01-01 11:00:00+00:00")],
                "due_by": [pd.Timestamp("2024-01-02 10:00:00+00:00")],
                "fr_due_by": [pd.Timestamp("2024-01-02 11:00:00+00:00")],
            }
        )
        mock_read_json.return_value = test_data

        result = get_days_tickets(test_date)

        # Verify all datetime columns are timezone-naive
        assert result["created_at"].dt.tz is None
        assert result["updated_at"].dt.tz is None
        assert result["due_by"].dt.tz is None
        assert result["fr_due_by"].dt.tz is None
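
is_type_compatible is imported from process_tickets in the tests above, but its body predates this commit and is not shown in the diff. A minimal sketch of a check consistent with the assertions in test_is_type_compatible, assuming a simple mapping from Glue column types to pandas dtype predicates (the mapping and names are illustrative, not the actual implementation):

import pandas as pd
from pandas.api import types as pdt


def is_type_compatible(series: pd.Series, glue_type: str) -> bool:
    # Map Glue catalog types to pandas dtype predicates; unknown types fail closed.
    checks = {
        "string": pdt.is_string_dtype,
        "int": pdt.is_integer_dtype,
        "bigint": pdt.is_integer_dtype,
        "double": pdt.is_float_dtype,
        "boolean": pdt.is_bool_dtype,
        "timestamp": pdt.is_datetime64_any_dtype,
    }
    check = checks.get(glue_type)
    return bool(check and check(series))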
