added autogluon support, more models, more preprocessing strategies #81

Merged · 65 commits · Sep 10, 2024
Changes shown below are from 2 of the 65 commits.

Commits (65)
5fde57a
added autogluon support
Oufattole Aug 19, 2024
d6832cb
updates for autogluon
teyaberg Aug 19, 2024
0612730
[wip] filtering features
teyaberg Aug 20, 2024
2feee79
[wip] filtering features
teyaberg Aug 20, 2024
f3c985a
[wip] sharing for updates only
teyaberg Aug 20, 2024
b65754c
[wip] sharing for updates only
teyaberg Aug 20, 2024
a8d8417
[wip] doctests
teyaberg Aug 20, 2024
d07f6a2
autogluon
teyaberg Aug 20, 2024
2aebd70
added logged warning for static data being empty and added support fo…
Oufattole Aug 20, 2024
8c54317
Merge branch 'generalized_load_model' into dev
Oufattole Aug 20, 2024
ecf9292
Added support via hydra for selecting among four imputation methods (…
Oufattole Aug 21, 2024
e6cf085
fixed xgboost model yaml to load imputer and normalization from the m…
Oufattole Aug 21, 2024
94dfde2
added autogluon test and cli support
Oufattole Aug 21, 2024
527eda5
added three more sklearn models and fixed bug with normalization and i…
Oufattole Aug 21, 2024
0d7ed27
fixed bugs so correlation code filters work now
Oufattole Aug 21, 2024
9c542ea
sweeper
teyaberg Aug 21, 2024
1a519ff
logging
teyaberg Aug 21, 2024
8fc8863
made task caching parallelizable and updated tests for configs
Oufattole Aug 21, 2024
5724d9b
Merge branch 'dev' of github.com:mmcdermott/MEDS_Tabular_AutoML into dev
Oufattole Aug 21, 2024
3e223bb
added more thorough tests for output file paths of task caching and …
Oufattole Aug 22, 2024
926732b
Merge branch 'main' into dev
Oufattole Aug 25, 2024
299bf6f
setup dynamic versioning
Oufattole Aug 25, 2024
8a7692a
version updates
teyaberg Sep 5, 2024
158b8fa
version updates
teyaberg Sep 6, 2024
e92049f
fix hydra-core version for experimental callback support
teyaberg Sep 6, 2024
0623aaa
eval callback logging
teyaberg Sep 6, 2024
e1be850
added script input args checks, reduced redundancy in model launcher …
Oufattole Sep 7, 2024
0e985ee
eval callback
teyaberg Sep 7, 2024
139870f
eval callback
teyaberg Sep 7, 2024
0d5e9e8
Updated pre-commit config too.
mmcdermott Sep 8, 2024
2563aaf
Removed a function that was not yet implemented.
mmcdermott Sep 8, 2024
2d80905
Removing unused function in evaluation callback.
mmcdermott Sep 8, 2024
d29ece9
eval callback
teyaberg Sep 8, 2024
81b022f
added yaml hierarchy for model_launcher
Oufattole Sep 8, 2024
57a4a81
updated configs, fixed most tests
Oufattole Sep 9, 2024
b704bba
Merged
mmcdermott Sep 9, 2024
2f564e6
Removed unused pass block.
mmcdermott Sep 9, 2024
6f68a4b
Removing unnecessary keys call
mmcdermott Sep 9, 2024
6c2ba9a
Fixed workflow files
mmcdermott Sep 9, 2024
e678145
fixed tabularize tests
Oufattole Sep 9, 2024
d64e237
added integration tests covering multirun for all launch_model models…
Oufattole Sep 9, 2024
8d12aed
merged dev
Oufattole Sep 9, 2024
c631e93
fixed tests
Oufattole Sep 9, 2024
2601fca
Merge pull request #90 from mmcdermott/configs
Oufattole Sep 9, 2024
a4ad03c
resolved review feedback. Added a base_model docstring. Added versio…
Oufattole Sep 9, 2024
0db7bd6
fixed min_code_inclusion_frequency kwarg
Oufattole Sep 9, 2024
b289033
added mimic iv tutorial
Oufattole Sep 9, 2024
9294920
updated tabularization script to fix bugs
Oufattole Sep 9, 2024
d71f9dc
reduced the number of workers for resharding
Oufattole Sep 9, 2024
aed27f1
Merged.
mmcdermott Sep 9, 2024
0dc2bc6
updated tabularize meds to take string input for tasks
Oufattole Sep 9, 2024
c981534
Merge pull request #91 from mmcdermott/improve_test_coverage
mmcdermott Sep 9, 2024
2aa4feb
Improved error handling per https://github.com/mmcdermott/MEDS_Tabula…
mmcdermott Sep 9, 2024
a6d9103
Update README.md
mmcdermott Sep 9, 2024
23eb4d4
added try except around loading 0 codes
Oufattole Sep 10, 2024
be5f723
fixed job name config bug where we were missing the $ so it was not …
Oufattole Sep 10, 2024
4c87e94
Merge branch 'dev' into MIMICIV
Oufattole Sep 10, 2024
d390658
fixed precommit issues
Oufattole Sep 10, 2024
b82ee6d
Merge branch 'dev' into MIMICIV
Oufattole Sep 10, 2024
a564886
fix paths for eval_callback and add check to test_integration
teyaberg Sep 10, 2024
430afba
fixing tests for delete_below_top_k
teyaberg Sep 10, 2024
6a89a9f
Merge pull request #92 from mmcdermott/MIMICIV
Oufattole Sep 10, 2024
9e6d99a
fix out of memory xgboost training and added test
teyaberg Sep 10, 2024
8316365
simplified pathing for results and evaluation callback
Oufattole Sep 10, 2024
f7e03dd
fixed doctest for deleting below top k models
Oufattole Sep 10, 2024
4 changes: 2 additions & 2 deletions pyproject.toml
@@ -8,7 +8,7 @@ authors = [
]
description = "Scalable Tabularization of MEDS format Time-Series data"
readme = "README.md"
-requires-python = ">=3.12"
+requires-python = ">=3.11"
classifiers = [
"Programming Language :: Python :: 3",
"License :: OSI Approved :: MIT License",
@@ -17,7 +17,6 @@ classifiers = [
dependencies = [
"polars", "pyarrow", "loguru", "hydra-core", "numpy", "scipy<1.14.0", "pandas", "tqdm", "xgboost",
"scikit-learn", "hydra-optuna-sweeper", "hydra-joblib-launcher", "ml-mixins", "meds==0.3",
"MEDS-transforms==0.0.5",
]

[project.scripts]
@@ -33,6 +32,7 @@ generate-subsets = "MEDS_tabular_automl.scripts.generate_subsets:main"
dev = ["pre-commit"]
tests = ["pytest", "pytest-cov", "rootutils"]
profiling = ["mprofile", "matplotlib"]
autogluon = ["autogluon; python_version=='3.11.*'"] # Environment marker to restrict AutoGluon to Python 3.11

[build-system]
requires = ["setuptools>=61.0", "setuptools-scm>=8.0", "wheel"]
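Because of the environment marker above, the autogluon extra installs nothing on Python versions other than 3.11, so downstream code has to treat AutoGluon as optional. A minimal sketch of the guarded-import pattern this implies (the helper name is illustrative, not part of this PR):

# Hypothetical guard: AutoGluon is only present when the `autogluon` extra
# was installed on a Python 3.11 interpreter, so import it defensively.
try:
    from autogluon.tabular import TabularPredictor
except ImportError:  # extra not installed, or unsupported Python version
    TabularPredictor = None


def autogluon_available() -> bool:
    """Reports whether the optional AutoGluon extra is usable."""
    return TabularPredictor is not None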
236 changes: 236 additions & 0 deletions src/MEDS_tabular_automl/base_model.py
@@ -0,0 +1,236 @@
from pathlib import Path

import hydra
import numpy as np
import scipy.sparse as sp
from loguru import logger
from mixins import TimeableMixin
from omegaconf import DictConfig
from sklearn.metrics import roc_auc_score

from .tabular_dataset import TabularDataset


class BaseIterator(TabularDataset, TimeableMixin):
"""BaseIterator class for loading and processing data shards for use in SciKit-Learn models.

This class provides functionality for iterating through data shards, loading
feature data and labels, and processing them based on the provided configuration.

Args:
cfg: A configuration dictionary containing parameters for
data processing, feature selection, and other settings.
split: The data split to use, which can be one of "train", "tuning",
or "held_out". This determines which subset of the data is loaded and processed.

Attributes:
cfg: Configuration dictionary containing parameters for
data processing, feature selection, and other settings.
file_name_resolver: Object for resolving file names and paths based on the configuration.
split: The data split being used for loading and processing data shards.
_data_shards: List of data shard names.
valid_event_ids: Dictionary mapping shard number to a list of valid event IDs.
labels: Dictionary mapping shard number to a list of labels for the corresponding event IDs.
codes_set: Set of codes to include in the data.
code_masks: Dictionary of code masks for filtering features based on aggregation.
num_features: Total number of features in the data.
"""

def __init__(self, cfg: DictConfig, split: str):
"""Initializes the BaseIterator with the provided configuration and data split.

Args:
cfg: The configuration dictionary.
split: The data split to use.
"""
TabularDataset.__init__(self, cfg=cfg, split=split)
TimeableMixin.__init__(self)
self.valid_event_ids, self.labels = self._load_ids_and_labels()
# check if the labels are empty
if len(self.labels) == 0:
raise ValueError("No labels found.")


class BaseMatrix(TimeableMixin):
"""BaseMatrix class for loading and processing data shards for use in SciKit-Learn models."""

def __init__(self, data: sp.csr_matrix, labels: np.ndarray):
"""Initializes the BaseMatrix with the provided configuration and data split.

Args:
data
"""
super().__init__()
self.data = data
self.labels = labels

def get_data(self):
return self.data

def get_label(self):
return self.labels


class BaseModel(TimeableMixin):
"""Class for configuring, training, and evaluating an SciKit-Learn model.

This class utilizes the configuration settings provided to manage the training and evaluation
process of an XGBoost model, ensuring the model is trained and validated using specified parameters
and data splits. It supports training with in-memory data handling as well as direct streaming from
disk using iterators.

Args:
cfg: The configuration settings for the model, including data paths, model parameters,
and flags for data handling.

Attributes:
cfg: Configuration object containing all settings required for model operation.
        model: The scikit-learn model after being trained.
        dtrain: The training dataset as an in-memory BaseMatrix.
        dtuning: The tuning (validation) dataset as an in-memory BaseMatrix.
        dheld_out: The held-out (test) dataset as an in-memory BaseMatrix.
itrain: Iterator for the training dataset.
ituning: Iterator for the tuning dataset.
iheld_out: Iterator for the held-out dataset.
keep_data_in_memory: Flag indicating whether to keep all data in memory or stream from disk.
"""

def __init__(self, cfg: DictConfig):
"""Initializes the XGBoostClassifier with the provided configuration.

Args:
cfg: The configuration dictionary.
"""
self.cfg = cfg
self.keep_data_in_memory = cfg.model_params.iterator.keep_data_in_memory

self.itrain = None
self.ituning = None
self.iheld_out = None

self.dtrain = None
self.dtuning = None
self.dheld_out = None

self.model = hydra.utils.call(cfg.model_params.model)
# check that self.model is a valid model
if not hasattr(self.model, "fit"):
raise ValueError("Model does not have a fit method.")

@TimeableMixin.TimeAs
def _build_data(self):
"""Builds necessary data structures for training."""
if self.keep_data_in_memory:
self._build_iterators()
self._build_matrix_in_memory()
else:
self._build_iterators()

def _fit_from_partial(self):
"""Fits model until convergence or maximum epochs."""
if not hasattr(self.model, "partial_fit"):
raise ValueError(
f"Data is loaded in shards, but {self.model.__class__.__name__} does not support partial_fit."
)
classes = self.itrain.get_classes()
best_auc = 0
best_epoch = 0
for epoch in range(self.cfg.model_params.epochs):
            # train each epoch over all data shards
for shard_idx in range(len(self.itrain._data_shards)):
data, labels = self.itrain.get_data_shards(shard_idx)
                # TODO: consider shuffling each shard before partial_fit (check speed impact)
self.model.partial_fit(data, labels, classes=classes)
# evaluate on tuning set
auc = self.evaluate()
# early stopping
if auc > best_auc:
best_auc = auc
best_epoch = epoch
if epoch - best_epoch > self.cfg.model_params.early_stopping_rounds:
break

@TimeableMixin.TimeAs
def _train(self):
"""Trains the model."""
# two cases: data is in memory or data is streamed
if self.keep_data_in_memory:
self.model.fit(self.dtrain.get_data(), self.dtrain.get_label())
else:
self._fit_from_partial()

@TimeableMixin.TimeAs
def train(self):
"""Trains the model."""
self._build_data()
self._train()

@TimeableMixin.TimeAs
def _build_matrix_in_memory(self):
"""Builds the DMatrix from the data in memory."""
self.dtrain = BaseMatrix(*self.itrain.get_data())
self.dtuning = BaseMatrix(*self.ituning.get_data())
self.dheld_out = BaseMatrix(*self.iheld_out.get_data())

@TimeableMixin.TimeAs
def _build_iterators(self):
"""Builds the iterators for training, validation, and testing."""
self.itrain = BaseIterator(self.cfg, split="train")
self.ituning = BaseIterator(self.cfg, split="tuning")
self.iheld_out = BaseIterator(self.cfg, split="held_out")

@TimeableMixin.TimeAs
def evaluate(self) -> float:
"""Evaluates the model on the tuning set.

Returns:
The evaluation metric as the ROC AUC score.
"""
# check if model has predict_proba method
if not hasattr(self.model, "predict_proba"):
raise ValueError(f"Model {self.model.__class__.__name__} does not have a predict_proba method.")
# two cases: data is in memory or data is streamed
if self.keep_data_in_memory:
y_pred = self.model.predict_proba(self.dtuning.get_data())[:, 1]
y_true = self.dtuning.get_label()
else:
y_pred = []
y_true = []
for shard_idx in range(len(self.ituning._data_shards)):
data, labels = self.ituning.get_data_shards(shard_idx)
y_pred.extend(self.model.predict_proba(data)[:, 1])
y_true.extend(labels)
y_pred = np.array(y_pred)
y_true = np.array(y_true)
# check if y_pred and y_true are not empty
if len(y_pred) == 0 or len(y_true) == 0:
raise ValueError("Predictions or true labels are empty.")
return roc_auc_score(y_true, y_pred)

def save_model(self, output_fp: str):
"""Saves the model to the specified file path.

Args:
output_fp: The file path to save the model to.
"""
output_fp = Path(output_fp)
# check if model has save method
if not hasattr(self.model, "save_model"):
logger.info(f"Model {self.model.__class__.__name__} does not have a save_model method.")
logger.info("Model will be saved using pickle dump.")
from pickle import dump

with open(output_fp.parent / "model.pkl", "wb") as f:
dump(self.model, f, protocol=5)
else:
self.model.save_model(output_fp)
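For reference, the streaming path above amounts to incremental fitting with partial_fit plus AUC-based early stopping. A minimal, self-contained sketch of that pattern on synthetic shards (shard sizes, patience, and hyperparameters are illustrative, not taken from this PR):

import numpy as np
from sklearn.linear_model import SGDClassifier
from sklearn.metrics import roc_auc_score

rng = np.random.default_rng(0)
# Three synthetic (features, labels) "shards", standing in for data streamed from disk.
shards = [(rng.normal(size=(200, 10)), rng.integers(0, 2, size=200)) for _ in range(3)]
X_tune, y_tune = rng.normal(size=(100, 10)), rng.integers(0, 2, size=100)

model = SGDClassifier(loss="log_loss")
classes = np.array([0, 1])
best_auc, best_epoch = 0.0, 0
for epoch in range(20):
    for X, y in shards:  # one partial_fit call per shard, as in _fit_from_partial
        model.partial_fit(X, y, classes=classes)
    auc = roc_auc_score(y_tune, model.predict_proba(X_tune)[:, 1])
    if auc > best_auc:
        best_auc, best_epoch = auc, epoch
    if epoch - best_epoch > 5:  # early-stopping patience
        break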
33 changes: 33 additions & 0 deletions src/MEDS_tabular_automl/configs/launch_basemodel.yaml
@@ -0,0 +1,33 @@
defaults:
- default
- tabularization: default
- override hydra/sweeper: optuna
- override hydra/sweeper/sampler: tpe
- override hydra/launcher: joblib
- _self_

task_name: task

# Task cached data dir
input_dir: ${output_cohort_dir}/${task_name}/task_cache
# Directory with task labels
input_label_dir: ${output_cohort_dir}/${task_name}/labels/
# Where to output the model and cached data
model_dir: ${output_cohort_dir}/model/model_${now:%Y-%m-%d_%H-%M-%S}
output_filepath: ${model_dir}/model_metadata.json

# Model parameters
model_params:
epochs: 20
early_stopping_rounds: 5
model:
_target_: sklearn.linear_model.SGDClassifier
loss: log_loss
# n_iter: ${model_params.epochs} # not sure if we want this behaviour
iterator:
keep_data_in_memory: True
binarize_task: True

log_dir: ${model_dir}/.logs/

name: launch_basemodel
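The model node uses hydra's _target_ convention, so swapping estimators is a config or CLI-override change rather than a code change. A standalone sketch of how that node becomes an estimator (mirroring what BaseModel.__init__ does with cfg.model_params.model):

import hydra
from omegaconf import OmegaConf

# Any class exposing a fit method can be targeted here.
model_cfg = OmegaConf.create({"_target_": "sklearn.linear_model.SGDClassifier", "loss": "log_loss"})
model = hydra.utils.call(model_cfg)
print(type(model).__name__)  # -> SGDClassifier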
6 changes: 3 additions & 3 deletions src/MEDS_tabular_automl/mapper.py
Review comment from mmcdermott (Owner):

Do you want to just see if you can import these functions from MEDS-Transforms? I guess that would conflict with the python version change, though...

@@ -5,10 +5,12 @@
from collections.abc import Callable
from datetime import datetime
from pathlib import Path
from typing import TypeVar

from loguru import logger

LOCK_TIME_FMT = "%Y-%m-%dT%H:%M:%S.%f"
DF_T = TypeVar("DF_T")


def get_earliest_lock(cache_directory: Path) -> datetime | None:
@@ -82,9 +84,7 @@ def register_lock(cache_directory: Path) -> tuple[datetime, Path]:
return lock_time, lock_fp


-def wrap[
-    DF_T
-](
+def wrap(
in_fp: Path,
out_fp: Path,
read_fn: Callable[[Path], DF_T],
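This diff replaces the PEP 695 generic syntax (def wrap[DF_T](...)), which requires Python 3.12, with a module-level TypeVar, which also works on the newly supported Python 3.11. A toy illustration of the 3.11-compatible spelling (the function is illustrative, not from this PR):

from collections.abc import Callable
from typing import TypeVar

DF_T = TypeVar("DF_T")


# On Python 3.12+ this could equivalently be written as
# `def apply[DF_T](obj: DF_T, fn: Callable[[DF_T], DF_T]) -> DF_T:`.
def apply(obj: DF_T, fn: Callable[[DF_T], DF_T]) -> DF_T:
    return fn(obj)


print(apply(3, lambda x: x + 1))  # -> 4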
58 changes: 58 additions & 0 deletions src/MEDS_tabular_automl/scripts/launch_basemodel.py
@@ -0,0 +1,58 @@
from importlib.resources import files
from pathlib import Path

import hydra
from loguru import logger
from omegaconf import DictConfig

from ..base_model import BaseModel
from ..utils import hydra_loguru_init

# locate the packaged default config rather than a repo-relative path
config_yaml = files("MEDS_tabular_automl").joinpath("configs/launch_basemodel.yaml")
if not config_yaml.is_file():
    raise FileNotFoundError("Core configuration not successfully installed!")


@hydra.main(version_base=None, config_path=str(config_yaml.parent.resolve()), config_name=config_yaml.stem)
def main(cfg: DictConfig) -> float:
"""Optimizes the model based on the provided configuration.

Args:
cfg: The configuration dictionary specifying model and training parameters.

Returns:
The evaluation result as the ROC AUC score on the held-out test set.
"""

if not cfg.loguru_init:
hydra_loguru_init()
try:
model = BaseModel(cfg)
model.train()
auc = model.evaluate()
logger.info(f"AUC: {auc}")

# save model
output_fp = Path(cfg.output_filepath)
output_fp.parent.mkdir(parents=True, exist_ok=True)

model.save_model(output_fp)
except Exception as e:
logger.error(f"Error occurred: {e}")
auc = 0.0
mmcdermott marked this conversation as resolved.
Show resolved Hide resolved
return auc


if __name__ == "__main__":
main()