Skip to content

Commit

Permalink
Merge pull request #417 from MannLabs/simplify_config
Browse files Browse the repository at this point in the history
Simplify config
  • Loading branch information
mschwoer authored Jan 9, 2025
2 parents 8ffc1c2 + de995b7 commit 730ece7
Show file tree
Hide file tree
Showing 7 changed files with 125 additions and 110 deletions.
98 changes: 20 additions & 78 deletions alphadia/constants/default.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,6 @@ library_prediction:

# define custom alphabase modifications not part of unimod or alphabase
# also used for decoy channels
# TODO make this a list
# - name: Dimethyl:d12@K
# composition: H(-2)2H(8)13C(2)
custom_modifications:
# Dimethyl @K channel decoy
- name: Dimethyl:d12@K
Expand Down Expand Up @@ -130,9 +127,6 @@ calibration:
# the maximum number of times an automatic optimizer can be skipped before it is considered to have converged
max_skips: 1

# TODO: remove this parameter
final_full_calibration: False

# TODO: remove this parameter
norm_rt_mode: 'linear'

Expand Down Expand Up @@ -201,27 +195,27 @@ library_multiplexing:
# channels can be either a number or a string
# for every channel, the library gets copied and the modifications are translated according to the mapping
# the following example shows how to multiplex mTRAQ to three sample channels and a decoy channel
# TODO make this a list
# - channel_name: 0
# channel_modifications:
# mTRAQ@K: mTRAQ@K
# mTRAQ@Any_N-term: mTRAQ@Any_N-term
multiplex_mapping: {}
#0:
# mTRAQ@K: mTRAQ@K
# mTRAQ@Any_N-term: mTRAQ@Any_N-term

#4:
# mTRAQ@K: mTRAQ:13C(3)15N(1)@K
# mTRAQ@Any_N-term: mTRAQ:13C(3)15N(1)@Any_N-term
multiplex_mapping: []
# - channel_name: 0
# modifications:
# mTRAQ@K: mTRAQ@K
# mTRAQ@Any_N-term: mTRAQ@Any_N-term
#
# - channel_name: 4
# modifications:
# mTRAQ@K: mTRAQ:13C(3)15N(1)@K
# mTRAQ@Any_N-term: mTRAQ:13C(3)15N(1)@Any_N-term
#
# - channel_name: 8
# modifications:
# mTRAQ@K: mTRAQ:13C(6)15N(2)@K
# mTRAQ@Any_N-term: mTRAQ:13C(6)15N(2)@Any_N-term
#
# - channel_name: 12
# modifications:
# mTRAQ@K: mTRAQ:d12@K
# mTRAQ@Any_N-term: mTRAQ:d12@Any_N-term

#8:
# mTRAQ@K: mTRAQ:13C(6)15N(2)@K
# mTRAQ@Any_N-term: mTRAQ:13C(6)15N(2)@Any_N-term

#12:
# mTRAQ@K: mTRAQ:d12@K
# mTRAQ@Any_N-term: mTRAQ:d12@Any_N-term



Expand Down Expand Up @@ -388,58 +382,6 @@ transfer_learning:
instrument: 'Lumos'


# configuration for the calibration manager
# the config has to start with the calibration keyword and consists of a list of calibration groups.
# each group consists of datapoints which have multiple properties.
# This can be for example precursors (mz, rt ...), fragments (mz, ...), quadrupole (transfer_efficiency)
calibration_manager: # TODO move to a separate file or hard-code
- name: fragment
estimators:
- name: mz
model: LOESSRegression
model_args:
n_kernels: 2
input_columns:
- mz_library
target_columns:
- mz_observed
output_columns:
- mz_calibrated
transform_deviation: 1e6
- name: precursor
estimators:
- name: mz
model: LOESSRegression
model_args:
n_kernels: 2
input_columns:
- mz_library
target_columns:
- mz_observed
output_columns:
- mz_calibrated
transform_deviation: 1e6
- name: rt
model: LOESSRegression
model_args:
n_kernels: 6
input_columns:
- rt_library
target_columns:
- rt_observed
output_columns:
- rt_calibrated
- name: mobility
model: LOESSRegression
model_args:
n_kernels: 2
input_columns:
- mobility_library
target_columns:
- mobility_observed
output_columns:
- mobility_calibrated

# scope of default yaml should be one search step
multistep_search:
transfer_step_enabled: False
Expand Down
12 changes: 10 additions & 2 deletions alphadia/libtransform.py
Original file line number Diff line number Diff line change
Expand Up @@ -617,12 +617,20 @@ def forward(self, input: SpecLibBase) -> SpecLibBase:


class MultiplexLibrary(ProcessingStep):
def __init__(self, multiplex_mapping: dict, input_channel: str | int | None = None):
def __init__(self, multiplex_mapping: list, input_channel: str | int | None = None):
"""Initialize the MultiplexLibrary step."""

self._multiplex_mapping = multiplex_mapping
self._multiplex_mapping = self._create_multiplex_mapping(multiplex_mapping)
self._input_channel = input_channel

@staticmethod
def _create_multiplex_mapping(multiplex_mapping: list) -> dict:
"""Create a dictionary from the multiplex mapping list."""
mapping = {}
for list_item in multiplex_mapping:
mapping[list_item["channel_name"]] = list_item["modifications"]
return mapping

def validate(self, input: str) -> bool:
"""Validate the input object. It is expected that the input is a path to a file which exists."""
valid = True
Expand Down
6 changes: 3 additions & 3 deletions alphadia/search_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def _init_config(

config_updates = []

if user_config is not None:
if user_config:
logger.info("loading additional config provided via CLI")
# load update config from dict
if isinstance(user_config, dict):
Expand All @@ -108,13 +108,13 @@ def _init_config(
"'config' parameter must be of type 'dict' or 'Config'"
)

if cli_config is not None:
if cli_config:
logger.info("loading additional config provided via CLI parameters")
cli_config_update = Config(cli_config, name=USER_DEFINED_CLI_PARAM)
config_updates.append(cli_config_update)

# this needs to be last
if extra_config is not None:
if extra_config:
extra_config_update = Config(extra_config, name=MULTISTEP_SEARCH)
# need to overwrite user-defined output folder here to have correct value in config dump
extra_config[ConfigKeys.OUTPUT_DIRECTORY] = output_folder
Expand Down
1 change: 0 additions & 1 deletion alphadia/workflow/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,6 @@ def load(

# initialize the calibration manager
self._calibration_manager = manager.CalibrationManager(
self.config["calibration_manager"],
path=os.path.join(self.path, self.CALIBRATION_MANAGER_PKL_NAME),
load_from_file=self.config["general"]["reuse_calibration"],
reporter=self.reporter,
Expand Down
59 changes: 53 additions & 6 deletions alphadia/workflow/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,58 @@

# TODO move all managers to dedicated modules

# Built-in configuration of the calibration manager.
# The top-level list holds the calibration groups ("fragment" and "precursor").
# Each group has a name and a list of estimators, one per calibrated property
# (mz for fragments; mz, rt and mobility for precursors).
# TODO simplify this structure and the config loading


def _loess_estimator(name, n_kernels, transform_deviation=None):
    """Return a fresh config dict for a LOESSRegression estimator of property `name`.

    The column names are derived from `name` ("<name>_library", "<name>_observed",
    "<name>_calibrated"). The "transform_deviation" key is only added when a value
    is given.
    """
    estimator = {
        "input_columns": [f"{name}_library"],
        "model": "LOESSRegression",
        "model_args": {"n_kernels": n_kernels},
        "name": name,
        "output_columns": [f"{name}_calibrated"],
        "target_columns": [f"{name}_observed"],
    }
    if transform_deviation is not None:
        estimator["transform_deviation"] = transform_deviation
    return estimator


# NOTE(review): "transform_deviation" is deliberately kept as the string "1e6" (not a
# float) - presumably the consumer converts it; confirm before changing the type.
CALIBRATION_MANAGER_CONFIG = [
    {
        "estimators": [_loess_estimator("mz", 2, transform_deviation="1e6")],
        "name": "fragment",
    },
    {
        "estimators": [
            _loess_estimator("mz", 2, transform_deviation="1e6"),
            _loess_estimator("rt", 6),
            _loess_estimator("mobility", 2),
        ],
        "name": "precursor",
    },
]


class BaseManager:
def __init__(
Expand Down Expand Up @@ -156,7 +208,6 @@ def fit_predict(self):
class CalibrationManager(BaseManager):
def __init__(
self,
config: None | dict = None,
path: None | str = None,
load_from_file: bool = True,
**kwargs,
Expand All @@ -167,10 +218,6 @@ def __init__(
Parameters
----------
config : typing.Union[None, dict], default=None
Calibration config dict. If None, the default config is used.
path : str, default=None
Path where the current parameter set is saved to and loaded from.
Expand All @@ -186,7 +233,7 @@ def __init__(

if not self.is_loaded_from_file:
self.estimator_groups = []
self.load_config(config)
self.load_config(CALIBRATION_MANAGER_CONFIG)

@property
def estimator_groups(self):
Expand Down
15 changes: 9 additions & 6 deletions tests/unit_tests/test_libtransform.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,11 +101,14 @@ def test_multiplex_library():
test_lib.calc_precursor_mz()
test_lib.calc_fragment_mz_df()

test_multiplex_mapping = {
0: {"mTRAQ@K": "mTRAQ@K"},
"magic_chanel": {"mTRAQ@K": "mTRAQ:13C(3)15N(1)@K"},
1337: {"mTRAQ@K": "mTRAQ:13C(6)15N(2)@K"},
}
test_multiplex_mapping = [
{"channel_name": 0, "modifications": {"mTRAQ@K": "mTRAQ@K"}},
{
"channel_name": "magic_channel",
"modifications": {"mTRAQ@K": "mTRAQ:13C(3)15N(1)@K"},
},
{"channel_name": 1337, "modifications": {"mTRAQ@K": "mTRAQ:13C(6)15N(2)@K"}},
]

# when
multiplexer = libtransform.MultiplexLibrary(test_multiplex_mapping)
Expand All @@ -116,7 +119,7 @@ def test_multiplex_library():
assert result_lib.precursor_df["charge"].nunique() == 2
assert result_lib.precursor_df["frag_stop_idx"].nunique() == 6

for channel in [0, 1337, "magic_chanel"]:
for channel in [0, 1337, "magic_channel"]:
assert (
result_lib.precursor_df[
result_lib.precursor_df["channel"] == channel
Expand Down
44 changes: 30 additions & 14 deletions tests/unit_tests/test_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import tempfile
from copy import deepcopy
from pathlib import Path
from unittest.mock import patch

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -48,7 +49,7 @@ def test_base_manager_load():
os.remove(my_base_manager.path)


TEST_CONFIG = [
TEST_CALIBRATION_MANAGER_CONFIG = [
{
"name": "precursor",
"estimators": [
Expand Down Expand Up @@ -90,9 +91,13 @@ def test_base_manager_load():
def test_calibration_manager_init():
# initialize the calibration manager
temp_path = os.path.join(tempfile.tempdir, "calibration_manager.pkl")
calibration_manager = manager.CalibrationManager(
TEST_CONFIG, path=temp_path, load_from_file=False
)
with patch(
"alphadia.workflow.manager.CALIBRATION_MANAGER_CONFIG",
TEST_CALIBRATION_MANAGER_CONFIG,
):
calibration_manager = manager.CalibrationManager(
path=temp_path, load_from_file=False
)

assert calibration_manager.path == temp_path
assert calibration_manager.is_loaded_from_file is False
Expand Down Expand Up @@ -158,9 +163,13 @@ def calibration_testdata():

def test_calibration_manager_fit_predict():
temp_path = os.path.join(tempfile.tempdir, "calibration_manager.pkl")
calibration_manager = manager.CalibrationManager(
TEST_CONFIG, path=temp_path, load_from_file=False
)
with patch(
"alphadia.workflow.manager.CALIBRATION_MANAGER_CONFIG",
TEST_CALIBRATION_MANAGER_CONFIG,
):
calibration_manager = manager.CalibrationManager(
path=temp_path, load_from_file=False
)

test_df = calibration_testdata()

Expand All @@ -182,9 +191,13 @@ def test_calibration_manager_fit_predict():

def test_calibration_manager_save_load():
temp_path = os.path.join(tempfile.tempdir, "calibration_manager.pkl")
calibration_manager = manager.CalibrationManager(
TEST_CONFIG, path=temp_path, load_from_file=False
)
with patch(
"alphadia.workflow.manager.CALIBRATION_MANAGER_CONFIG",
TEST_CALIBRATION_MANAGER_CONFIG,
):
calibration_manager = manager.CalibrationManager(
path=temp_path, load_from_file=False
)

test_df = calibration_testdata()
calibration_manager.fit(test_df, "precursor", plot=False)
Expand All @@ -195,9 +208,13 @@ def test_calibration_manager_save_load():

calibration_manager.save()

calibration_manager_loaded = manager.CalibrationManager(
TEST_CONFIG, path=temp_path, load_from_file=True
)
with patch(
"alphadia.workflow.manager.CALIBRATION_MANAGER_CONFIG",
TEST_CALIBRATION_MANAGER_CONFIG,
):
calibration_manager_loaded = manager.CalibrationManager(
path=temp_path, load_from_file=True
)
assert calibration_manager_loaded.is_fitted is True
assert calibration_manager_loaded.is_loaded_from_file is True

Expand Down Expand Up @@ -433,7 +450,6 @@ def create_workflow_instance():
]
)
workflow._calibration_manager = manager.CalibrationManager(
workflow.config["calibration_manager"],
path=os.path.join(workflow.path, workflow.CALIBRATION_MANAGER_PKL_NAME),
load_from_file=workflow.config["general"]["reuse_calibration"],
reporter=workflow.reporter,
Expand Down

0 comments on commit 730ece7

Please sign in to comment.