diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 8c682da53..9b5e99850 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -12,3 +12,11 @@ repos: rev: 22.10.0 hooks: - id: black +- repo: https://github.com/astral-sh/ruff-pre-commit + # Ruff version. + rev: v0.2.1 + hooks: + # Run the linter. + - id: ruff + args: [ --fix , --ignore , "F403" ] + exclude: "__init__" diff --git a/example_tool/brainmask_tool.py b/example_tool/brainmask_tool.py index 8060d7e22..568e09db2 100644 --- a/example_tool/brainmask_tool.py +++ b/example_tool/brainmask_tool.py @@ -3,14 +3,12 @@ import argparse import pydicom -from pdf2dcm import Pdf2EncapsDCM, Pdf2RgbSC +from pdf2dcm import Pdf2EncapsDCM from subprocess import run -from pipeline_functions import * +from pipeline_functions import dicom_inference_and_conversion, brainmask_inference from pdf_report import generate_report from pydicom import dcmread from pathlib import Path -from enum import Enum, auto - description = "author: Michal Brzus\nBrainmask Tool\n" @@ -51,7 +49,7 @@ output_path = Path(args.output_dir) -try : +try: nifti_path = dicom_inference_and_conversion( session_dir=session_path.as_posix(), output_dir=output_path.as_posix(), @@ -101,7 +99,9 @@ mask_path = list(Path(brainmask_output_dir).glob("*.nii.gz"))[0] stage_name = "report_generation" try: - pdf_fn = generate_report(im_path.as_posix(), mask_path.as_posix(), report_output_dir) + pdf_fn = generate_report( + im_path.as_posix(), mask_path.as_posix(), report_output_dir + ) print(f"Report created: {pdf_fn}") except Exception as e: print(f"Error in stage: {stage_name}") @@ -114,22 +114,23 @@ try: converter = Pdf2EncapsDCM() - converted_dcm = converter.run(path_pdf=pdf_fn, path_template_dcm=template_dcm.as_posix(), suffix =".dcm")[0] + converted_dcm = converter.run( + path_pdf=pdf_fn, path_template_dcm=template_dcm.as_posix(), suffix=".dcm" + )[0] del report_output_dir, brainmask_output_dir, nifti_path print(f"Report created: {converted_dcm}") # Adding needed metadata to the report """""" - pdf_dcm = dcmread(converted_dcm,stop_before_pixels=True) - + pdf_dcm = dcmread(converted_dcm, stop_before_pixels=True) extra_metadata = [ - ( - "SeriesDescription", - "0008,103e", - f"This is a rough brainmask", - ), + ( + "SeriesDescription", + "0008,103e", + "This is a rough brainmask", + ), ] for info in extra_metadata: title = info[0] @@ -148,4 +149,4 @@ print(f"Successfully finished stage: {stage_name}") -# [ 'tests/test_data/test_file.dcm' ] \ No newline at end of file +# [ 'tests/test_data/test_file.dcm' ] diff --git a/example_tool/cnn_transforms.py b/example_tool/cnn_transforms.py index f5c5e166a..7b2efd9e1 100644 --- a/example_tool/cnn_transforms.py +++ b/example_tool/cnn_transforms.py @@ -150,8 +150,8 @@ def __call__(self, data): return d -unsqueze_lambda = lambda x: x.squeeze(dim=0) -shape_lambda = lambda x: x.shape +# unsqueze_lambda = lambda x: x.squeeze(dim=0) +# shape_lambda = lambda x: x.shape class ResampleMaskToOgd(object): diff --git a/example_tool/pdf_report.py b/example_tool/pdf_report.py index 5a1be7720..24eb4d1ae 100644 --- a/example_tool/pdf_report.py +++ b/example_tool/pdf_report.py @@ -6,8 +6,6 @@ import numpy as np from io import BytesIO import base64 -import subprocess -import platform # CSS Content @@ -102,8 +100,7 @@ def generate_image(im_path, mask_path): return image_base64 -def generate_pdf(brain_volume, image_base64, file_path -): +def generate_pdf(brain_volume, image_base64, file_path): # HTML Content html_string = f""" diff --git a/example_tool/pipeline_functions.py b/example_tool/pipeline_functions.py index 711033ec0..aa722834e 100644 --- a/example_tool/pipeline_functions.py +++ b/example_tool/pipeline_functions.py @@ -1,7 +1,14 @@ -from cnn_transforms import * +from cnn_transforms import ( + LoadITKImaged, + ResampleStartRegionBrainMaskd, + ITKImageToNumpyd, + AddChanneld, + ToITKImaged, + ResampleMaskToOgd, + SaveITKImaged, +) import pytorch_lightning as pl from monai.data import CacheDataset - from monai.networks.layers import Norm from monai.networks.nets import UNet from monai.transforms import ( @@ -9,15 +16,8 @@ ScaleIntensityRangePercentilesd, ToTensord, CopyItemsd, - KeepLargestConnectedComponentd, - FillHolesd ) -from torchmetrics.classification import Dice import torch -from monai.losses.dice import GeneralizedDiceFocalLoss - -itk.MultiThreaderBase.SetGlobalDefaultNumberOfThreads(1) - from pathlib import Path from dcm_classifier.study_processing import ProcessOneDicomStudyToVolumesMappingBase from dcm_classifier.image_type_inference import ImageTypeClassifierBase @@ -25,6 +25,10 @@ import re from pydicom import dcmread from subprocess import run +import itk + + +itk.MultiThreaderBase.SetGlobalDefaultNumberOfThreads(1) def validate_subject_id(subject_id: str) -> str: @@ -82,9 +86,13 @@ def dicom_inference_and_conversion( fname = f"{validate_subject_id(sub)}_{validate_session_id(ses)}_acq-{plane}_{modality}" series_vol_list = series.get_volume_list() if len(series_vol_list) > 1: - print(f"Series {series_number} not supported. More than one volume in series.") + print( + f"Series {series_number} not supported. More than one volume in series." + ) else: - itk_im = itk_read_from_dicomfn_list(series_vol_list[0].get_one_volume_dcm_filenames()) + itk_im = itk_read_from_dicomfn_list( + series_vol_list[0].get_one_volume_dcm_filenames() + ) itk.imwrite(itk_im, f"{sub_ses_dir}/{fname}.nii.gz") return sub_ses_dir @@ -108,7 +116,9 @@ def forward(self, x): return self.model(x) -def brainmask_inference(data: list, model_file: str, out_dir: str, postfix='brainmask') -> None: +def brainmask_inference( + data: list, model_file: str, out_dir: str, postfix="brainmask" +) -> None: print("\nDATA: ", data) model = BrainmaskModel.load_from_checkpoint( checkpoint_path=model_file, @@ -155,12 +165,10 @@ def brainmask_inference(data: list, model_file: str, out_dir: str, postfix='brai with torch.no_grad(): # perform the inference test_output = model.model(item["image"].unsqueeze(dim=0).to(device)) # convert from one hot encoding - out_im = ( - torch.argmax(test_output, dim=1).detach().cpu() - ) + out_im = torch.argmax(test_output, dim=1).detach().cpu() print(out_im.shape) - item["inferred_label"] = out_im #.squeeze(dim=0) + item["inferred_label"] = out_im # .squeeze(dim=0) item["inferred_label_meta_dict"] = item["image_meta_dict"] item["inferred_label_meta_dict"]["filename"] = item["image_meta_dict"][ "filename" @@ -176,4 +184,3 @@ def brainmask_inference(data: list, model_file: str, out_dir: str, postfix='brai ] ) out_transforms(item) - diff --git a/job-monitoring-app/backend/app/models/api_key.py b/job-monitoring-app/backend/app/models/api_key.py index f0e4ccaf1..707db5da3 100644 --- a/job-monitoring-app/backend/app/models/api_key.py +++ b/job-monitoring-app/backend/app/models/api_key.py @@ -1,4 +1,4 @@ -from sqlalchemy import Column, Integer, String, ForeignKey, DateTime, func +from sqlalchemy import Column, Integer, String, ForeignKey, DateTime from sqlalchemy.orm import relationship from .base import Base, DateMixin diff --git a/job-monitoring-app/backend/app/models/event.py b/job-monitoring-app/backend/app/models/event.py index 6f89c9542..6103ec943 100644 --- a/job-monitoring-app/backend/app/models/event.py +++ b/job-monitoring-app/backend/app/models/event.py @@ -1,5 +1,5 @@ from sqlalchemy import Column, ForeignKey -from sqlalchemy.sql.sqltypes import String, Integer, Enum, JSON +from sqlalchemy.sql.sqltypes import String, Integer, Enum from sqlalchemy.dialects.postgresql import JSONB from sqlalchemy.orm import relationship diff --git a/job-monitoring-app/backend/app/models/job.py b/job-monitoring-app/backend/app/models/job.py index 9f9c60115..f54a69d2c 100644 --- a/job-monitoring-app/backend/app/models/job.py +++ b/job-monitoring-app/backend/app/models/job.py @@ -1,6 +1,6 @@ from sqlalchemy import Column, ForeignKey, UniqueConstraint from sqlalchemy.sql.sqltypes import String, Integer -from sqlalchemy.orm import relationship, backref +from sqlalchemy.orm import relationship from .base import Base, DateMixin diff --git a/job-monitoring-app/backend/app/models/metadata_configuration.py b/job-monitoring-app/backend/app/models/metadata_configuration.py index 8d44f8e08..208774bb8 100644 --- a/job-monitoring-app/backend/app/models/metadata_configuration.py +++ b/job-monitoring-app/backend/app/models/metadata_configuration.py @@ -1,4 +1,4 @@ -from sqlalchemy import Column, Enum, ForeignKey, UniqueConstraint +from sqlalchemy import Column, Enum, ForeignKey from sqlalchemy.orm import relationship from sqlalchemy.sql.sqltypes import Integer, String diff --git a/job-monitoring-app/backend/app/routers/api_keys.py b/job-monitoring-app/backend/app/routers/api_keys.py index 141120474..886831f82 100644 --- a/job-monitoring-app/backend/app/routers/api_keys.py +++ b/job-monitoring-app/backend/app/routers/api_keys.py @@ -2,11 +2,10 @@ get_db, get_user_from_api_key, API_KEY_HEADER_NAME, - get_current_user_from_token, get_current_provider, ) from app import schemas, services -from fastapi import APIRouter, Depends, Query +from fastapi import APIRouter, Depends from sqlalchemy.orm import Session router = APIRouter() diff --git a/job-monitoring-app/backend/app/routers/events.py b/job-monitoring-app/backend/app/routers/events.py index 80483ef23..064e11b9c 100644 --- a/job-monitoring-app/backend/app/routers/events.py +++ b/job-monitoring-app/backend/app/routers/events.py @@ -1,6 +1,6 @@ from app import schemas, services from app.dependencies import get_db, get_user_from_api_key -from fastapi import APIRouter, Depends, HTTPException +from fastapi import APIRouter, Depends from sqlalchemy.orm import Session router = APIRouter() diff --git a/job-monitoring-app/backend/app/routers/job_configurations.py b/job-monitoring-app/backend/app/routers/job_configurations.py index a82f7d2e3..047fd78b3 100644 --- a/job-monitoring-app/backend/app/routers/job_configurations.py +++ b/job-monitoring-app/backend/app/routers/job_configurations.py @@ -38,7 +38,7 @@ def get_job_configuration_by_id( if job_configuration is None: raise HTTPException(status_code=404, detail="Job not found") - if not (provider.id in [job_configuration.provider_id]): + if provider.id not in [job_configuration.provider_id]: raise HTTPException(status_code=403, detail="Not allowed") return job_configuration @@ -53,7 +53,7 @@ def get_job_configurations_by_tag_and_version( ): # case 1: get specific configuration if both tag and version are provided should_get_specific_version_of_tag = tag and ( - type(version) is str and version != "latest" + isinstance(version, str) and version != "latest" ) if should_get_specific_version_of_tag: diff --git a/job-monitoring-app/backend/app/routers/jobs.py b/job-monitoring-app/backend/app/routers/jobs.py index cd322f2be..78b7f4427 100644 --- a/job-monitoring-app/backend/app/routers/jobs.py +++ b/job-monitoring-app/backend/app/routers/jobs.py @@ -42,7 +42,7 @@ def get_job( if job is None: raise HTTPException(status_code=404, detail="Job not found") - if not (user.id in [job.customer_id, job.provider_id]): + if user.id not in [job.customer_id, job.provider_id]: # TODO: add job.provider_id to the list of allowed users that can # access this once we have api key based access? See above comment raise HTTPException(status_code=403, detail="Not allowed") @@ -62,7 +62,7 @@ def get_job_events( if job is None: raise HTTPException(status_code=404, detail="Job not found") - if not (user.id in [job.customer_id, job.provider_id]): + if user.id not in [job.customer_id, job.provider_id]: raise HTTPException(status_code=403, detail="Not allowed") return job.events diff --git a/job-monitoring-app/backend/app/routers/reporting.py b/job-monitoring-app/backend/app/routers/reporting.py index 46d6b002f..debde88b2 100644 --- a/job-monitoring-app/backend/app/routers/reporting.py +++ b/job-monitoring-app/backend/app/routers/reporting.py @@ -1,10 +1,9 @@ from datetime import datetime, timedelta -from app import schemas, services +from app import services from app.dependencies import get_db, get_current_provider -from fastapi import APIRouter, Depends, HTTPException +from fastapi import APIRouter, Depends from sqlalchemy.orm import Session -from fastapi import FastAPI from fastapi.responses import StreamingResponse import io import pandas as pd diff --git a/job-monitoring-app/backend/app/schemas/event.py b/job-monitoring-app/backend/app/schemas/event.py index d55bcebf4..9bcc5b420 100644 --- a/job-monitoring-app/backend/app/schemas/event.py +++ b/job-monitoring-app/backend/app/schemas/event.py @@ -2,7 +2,7 @@ from enum import Enum from typing import Dict, Optional, Union -from pydantic import BaseModel, Json +from pydantic import BaseModel from .step_configuration import StepConfiguration diff --git a/job-monitoring-app/backend/app/schemas/job.py b/job-monitoring-app/backend/app/schemas/job.py index 3ff1eec8b..c4730e407 100644 --- a/job-monitoring-app/backend/app/schemas/job.py +++ b/job-monitoring-app/backend/app/schemas/job.py @@ -1,7 +1,7 @@ from datetime import datetime from typing import List -from pydantic import BaseModel, StrictStr +from pydantic import BaseModel from . import Event from .job_configuration import JobConfiguration diff --git a/job-monitoring-app/backend/app/schemas/step_configuration.py b/job-monitoring-app/backend/app/schemas/step_configuration.py index b595b2ba7..494de95db 100644 --- a/job-monitoring-app/backend/app/schemas/step_configuration.py +++ b/job-monitoring-app/backend/app/schemas/step_configuration.py @@ -1,7 +1,7 @@ from datetime import datetime from typing import List, Optional -from pydantic import StrictInt, StrictStr, conlist +from pydantic import StrictInt, StrictStr from .metadata_configuration import MetadataConfiguration, MetadataConfigurationCreate from .unique_tag import UniqueTagModel diff --git a/job-monitoring-app/backend/app/services/api_keys.py b/job-monitoring-app/backend/app/services/api_keys.py index 57dc8d674..5d6f2ee32 100644 --- a/job-monitoring-app/backend/app/services/api_keys.py +++ b/job-monitoring-app/backend/app/services/api_keys.py @@ -10,7 +10,6 @@ from datetime import datetime -from .users import get_user pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") diff --git a/job-monitoring-app/backend/app/services/events.py b/job-monitoring-app/backend/app/services/events.py index 9669a7aad..15b195159 100644 --- a/job-monitoring-app/backend/app/services/events.py +++ b/job-monitoring-app/backend/app/services/events.py @@ -1,5 +1,4 @@ from app import models, schemas -from sqlalchemy import cast from sqlalchemy.orm import Session from .job_configuration import get_step_configuration_by_composite_key diff --git a/job-monitoring-app/backend/app/services/reporting.py b/job-monitoring-app/backend/app/services/reporting.py index 92b4de813..505f140bd 100644 --- a/job-monitoring-app/backend/app/services/reporting.py +++ b/job-monitoring-app/backend/app/services/reporting.py @@ -1,5 +1,5 @@ import json -from datetime import datetime, timedelta +from datetime import datetime from pydantic import BaseModel from sqlalchemy.orm import Session diff --git a/job-monitoring-app/backend/conftest.py b/job-monitoring-app/backend/conftest.py index fe4a24882..c345bc4fc 100644 --- a/job-monitoring-app/backend/conftest.py +++ b/job-monitoring-app/backend/conftest.py @@ -1,6 +1,3 @@ -import os -import random - import pytest from app import schemas, services, models from app.models.base import truncate_all_tables diff --git a/job-monitoring-app/backend/tests/routers/test_job_configurations.py b/job-monitoring-app/backend/tests/routers/test_job_configurations.py index 3fc5d210a..237aae1ad 100644 --- a/job-monitoring-app/backend/tests/routers/test_job_configurations.py +++ b/job-monitoring-app/backend/tests/routers/test_job_configurations.py @@ -3,6 +3,28 @@ from starlette import status +def create_data_dict(tag, name, version, step_name, points, step_tag, meta_name, units, kind): + return { + "tag": tag, + "name": name, + "version": version, + "step_configurations": [ + { + "name": step_name, + "points": points, + "tag": step_tag, + "metadata_configurations": [ + { + "name": meta_name, + "units": units, + "kind": kind, + } + ], + } + ], + } + + def test_create_job_configurations(app_client, random_provider_user_with_api_key): data = { "tag": "lung_cancer", @@ -28,7 +50,7 @@ def test_create_job_configurations(app_client, random_provider_user_with_api_key def test_create_job_configurations_with_new_version( db, app_client, random_provider_user_with_api_key ): - result = services.create_job_configuration( + _ = services.create_job_configuration( db, provider_id=random_provider_user_with_api_key.id, job_configuration=schemas.JobConfigurationCreate( @@ -50,25 +72,18 @@ def test_create_job_configurations_with_new_version( ), ) - data = { - "tag": "lung_cancer", - "name": "Lung Cancer Again", - "version": "1.0.1", - "step_configurations": [ - { - "name": "Lung Search", - "points": 10, - "tag": "lung_search", - "metadata_configurations": [ - { - "name": "Protein Density", - "units": "gm/cc", - "kind": "number", - } - ], - } - ], - } + data = create_data_dict( + tag="lung_cancer", + name="Lung Cancer Again", + version="1.0.1", + step_name="Lung Search", + points=10, + step_tag="lung_search", + meta_name="Protein Density", # Specify the unique name here + units="gm/cc", + kind="number", + ) + response = app_client.post( "/job_configurations", json=data, @@ -90,7 +105,7 @@ def test_create_job_configurations_with_new_version( def test_create_job_configuration_with_conflicting_version( db, app_client, random_provider_user_with_api_key ): - result = services.create_job_configuration( + _ = services.create_job_configuration( db, provider_id=random_provider_user_with_api_key.id, job_configuration=schemas.JobConfigurationCreate( @@ -122,7 +137,7 @@ def test_create_job_configuration_with_conflicting_version( def test_create_job_configuration_with_conflicting_version_on_metadata( db, app_client, random_provider_user_with_api_key ): - result = services.create_job_configuration( + _ = services.create_job_configuration( db, provider_id=random_provider_user_with_api_key.id, job_configuration=schemas.JobConfigurationCreate( @@ -144,25 +159,17 @@ def test_create_job_configuration_with_conflicting_version_on_metadata( ), ) - data = { - "tag": "lung_cancer", - "name": "Lung Cancer", - "version": "1.0.0", - "step_configurations": [ - { - "name": "Lung Search", - "points": 10, - "tag": "lung_search", - "metadata_configurations": [ - { - "name": "Protein Density 2", # New field but same version + tag - "units": "gm/cc", - "kind": "number", - } - ], - } - ], - } + data = create_data_dict( + tag="lung_cancer", + name="Lung Cancer", + version="1.0.0", + step_name="Lung Search", + points=10, + step_tag="lung_search", + meta_name="Protein Density 2", # Specify the unique name here + units="gm/cc", + kind="number", + ) response = app_client.post( "/job_configurations", @@ -286,7 +293,7 @@ def test_get_job_configurations_with_specific_tag_and_version( def test_job_configuration_with_tag_and_latest_version( app_client, db, random_provider_user_with_api_key ): - job_configuration1 = ( + _ = ( services.create_job_configuration( db, provider_id=random_provider_user_with_api_key.id, @@ -411,7 +418,7 @@ def test_get_all_configurations_for_tag_with_missing_version( def test_get_list_of_latest_versions_for_all_job_configurations_with_version_latest( app_client, db, random_provider_user_with_api_key ): - job_configuration1 = services.create_job_configuration( + _ = services.create_job_configuration( db, provider_id=random_provider_user_with_api_key.id, job_configuration=schemas.JobConfigurationCreate( @@ -422,7 +429,7 @@ def test_get_list_of_latest_versions_for_all_job_configurations_with_version_lat ), ) - job_configuration2 = services.create_job_configuration( + _ = services.create_job_configuration( db, provider_id=random_provider_user_with_api_key.id, job_configuration=schemas.JobConfigurationCreate( @@ -490,7 +497,7 @@ def test_get_list_of_latest_versions_for_all_job_configurations_with_version_lat def test_get_list_of_latest_versions_for_all_job_configurations_with_empty_query_params( app_client, db, random_provider_user_with_api_key ): - job_configuration1 = services.create_job_configuration( + _ = services.create_job_configuration( db, provider_id=random_provider_user_with_api_key.id, job_configuration=schemas.JobConfigurationCreate( @@ -501,7 +508,7 @@ def test_get_list_of_latest_versions_for_all_job_configurations_with_empty_query ), ) - job_configuration2 = services.create_job_configuration( + _ = services.create_job_configuration( db, provider_id=random_provider_user_with_api_key.id, job_configuration=schemas.JobConfigurationCreate( @@ -544,7 +551,7 @@ def test_get_list_of_latest_versions_for_all_job_configurations_with_empty_query access_token = response.json()["access_token"] response = app_client.get( - f"/job_configurations/", + "/job_configurations/", cookies={"access_token": access_token}, ) diff --git a/job-monitoring-app/backend/tests/routers/test_jobs.py b/job-monitoring-app/backend/tests/routers/test_jobs.py index 7728dbdb5..7fc6d56eb 100644 --- a/job-monitoring-app/backend/tests/routers/test_jobs.py +++ b/job-monitoring-app/backend/tests/routers/test_jobs.py @@ -111,7 +111,7 @@ def test_get_jobs_as_customer( access_token = response.json()["access_token"] # Use access token in the request to get a job - response = app_client.get(f"/jobs", cookies={"access_token": access_token}) + response = app_client.get("/jobs", cookies={"access_token": access_token}) assert response.status_code == 200 assert len(response.json()) == 2 diff --git a/job-monitoring-app/cdk-infra/cdk_infra/cdk_infra_stack.py b/job-monitoring-app/cdk-infra/cdk_infra/cdk_infra_stack.py index 53ed9ff96..b3c98a7ba 100644 --- a/job-monitoring-app/cdk-infra/cdk_infra/cdk_infra_stack.py +++ b/job-monitoring-app/cdk-infra/cdk_infra/cdk_infra_stack.py @@ -1,5 +1,3 @@ -import json - import aws_cdk as cdk import aws_cdk.aws_amplify_alpha as aws_amplify import aws_cdk.aws_apigateway as aws_apigateway @@ -139,7 +137,7 @@ def __init__( tracker_amplify_app.node.default_child.platform = "WEB_COMPUTE" # Amplify App Build Trigger on Create - build_trigger = aws_custom_resources.AwsCustomResource( + _ = aws_custom_resources.AwsCustomResource( self, TRACKER_PREFIX + "AmplifyBuildTrigger" diff --git a/job-monitoring-app/trackerapi/tests/test_api.py b/job-monitoring-app/trackerapi/tests/test_api.py index 0a41f0949..2f2aa34bd 100644 --- a/job-monitoring-app/trackerapi/tests/test_api.py +++ b/job-monitoring-app/trackerapi/tests/test_api.py @@ -13,7 +13,7 @@ def test_init(): def test_init_without_api_key(): with pytest.raises(TypeError): - tracker = TrackerApi() + _ = TrackerApi() def test_requests_made_with_api_key(): diff --git a/job-monitoring-app/trackerapi/tests/test_helpers.py b/job-monitoring-app/trackerapi/tests/test_helpers.py index cfea88abd..881a07f98 100644 --- a/job-monitoring-app/trackerapi/tests/test_helpers.py +++ b/job-monitoring-app/trackerapi/tests/test_helpers.py @@ -3,8 +3,12 @@ import pytest -from trackerapi import JobConfig, JobConfigManager, StepConfig -from trackerapi.helpers import DuplicateJobConfigException, MissingJobConfigException +from trackerapi.schemas import JobConfig, StepConfig +from trackerapi.helpers import ( + DuplicateJobConfigException, + MissingJobConfigException, + JobConfigManager, +) test_job_config = JobConfig( name="Test Job", diff --git a/job-monitoring-app/trackerapi/tests/test_schemas.py b/job-monitoring-app/trackerapi/tests/test_schemas.py index 11e34447c..553afba05 100644 --- a/job-monitoring-app/trackerapi/tests/test_schemas.py +++ b/job-monitoring-app/trackerapi/tests/test_schemas.py @@ -6,7 +6,7 @@ def test_no_duplicate_steps(): with pytest.raises(ValidationError) as exc: - config = JobConfig( + _ = JobConfig( name="Test Job", tag="test_job", step_configurations=[ diff --git a/job-monitoring-app/trackerapi/trackerapi/schemas.py b/job-monitoring-app/trackerapi/trackerapi/schemas.py index 784fead26..6fd58295e 100644 --- a/job-monitoring-app/trackerapi/trackerapi/schemas.py +++ b/job-monitoring-app/trackerapi/trackerapi/schemas.py @@ -1,5 +1,4 @@ -from abc import ABC, abstractmethod -from typing import Dict, List, Optional +from typing import List, Optional from pydantic import BaseModel, StrictInt, StrictStr, conlist from enum import Enum @@ -33,10 +32,24 @@ class StepConfig(UniqueTagModel): metadata_configurations: Optional[List[MetadataConfig]] = [] - def __init__(self, name: str, tag: str, points: int, metadata_configurations: List[MetadataConfig] = None, - **kwargs): - metadata_configurations = metadata_configurations if metadata_configurations else [] - super().__init__(name=name, tag=tag, points=points, metadata_configurations=metadata_configurations, **kwargs) + def __init__( + self, + name: str, + tag: str, + points: int, + metadata_configurations: List[MetadataConfig] = None, + **kwargs + ): + metadata_configurations = ( + metadata_configurations if metadata_configurations else [] + ) + super().__init__( + name=name, + tag=tag, + points=points, + metadata_configurations=metadata_configurations, + **kwargs + ) class JobConfig(UniqueTagModel): @@ -45,12 +58,12 @@ class JobConfig(UniqueTagModel): version: str def __init__( - self, - name: str, - tag: str, - step_configurations: List[StepConfig], - version: str, - **kwargs + self, + name: str, + tag: str, + step_configurations: List[StepConfig], + version: str, + **kwargs ): super().__init__( name=name,