Commit
Gil Ramot committed Feb 15, 2024
0 parents, commit 81d12ca
Showing 16 changed files with 1,008 additions and 0 deletions.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
MIT License | ||
|
||
Copyright (c) 2024 Gil Ramot, Tel Aviv University | ||
|
||
Permission is hereby granted, free of charge, to any person obtaining a copy | ||
of this software and associated documentation files (the "Software"), to deal | ||
in the Software without restriction, including without limitation the rights | ||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell | ||
copies of the Software, and to permit persons to whom the Software is | ||
furnished to do so, subject to the following conditions: | ||
|
||
The above copyright notice and this permission notice shall be included in all | ||
copies or substantial portions of the Software. | ||
|
||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR | ||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, | ||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE | ||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER | ||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, | ||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE | ||
SOFTWARE. |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
# modelcomp: comparison tool for machine learning models | ||
|
||
[![PyPI](https://img.shields.io/pypi/v/modelcomp)](https://pypi.org/project/modelcomp/) | ||
![License](https://img.shields.io/github/license/gilramot/modelcomp) | ||
![Downloads](https://img.shields.io/pypi/dm/modelcomp) | ||
[![PyPI pyversions](https://img.shields.io/pypi/pyversions/modelcomp)](https://pypi.org/pypi/modelcomp/) | ||
|
||
modelcomp is a Python package that helps you compare the performance of machine learning models on your dataset. It was originally developed for a microbiome research project at [Borenstein Lab](http://borensteinlab.com/), Tel Aviv University. | ||
|
||
Thanks to the [Alpha program](https://www.madaney.net/en/alpha) for making this happen. | ||
|
||
|
||
## Dependencies | ||
|
||
modelcomp currently supports Python 3.6+. | ||
|
||
See the [requirements file](https://github.com/gilramot/modelcomp/blob/master/requirements.txt) for additional package requirements. | ||
|
||
## Installation | ||
|
||
The latest stable release (and required dependencies) can be installed from [PyPI](https://pypi.org/project/modelcomp/): | ||
|
||
pip install modelcomp | ||
|
||
The current version (0.0.1a1) is the version used in the research. All pre-0.0.1 versions are unstable and may be subject to backwards-incompatible changes. | ||
|
||
Anaconda support coming soon! | ||
|
||
|
||
## Contributing | ||
|
||
Feel free to [report an issue](https://github.com/gilramot/modelcomp/issues/new) in the package repository. | ||
|
||
|
||
## Research | ||
|
||
An appendix containing the dataset, results and module used in the research is available at [modelcomp-appendix](https://github.com/gilramot/modelcomp-appendix). |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
import modelcomp.models | ||
from modelcomp.analysis import ( | ||
cross_val_models, | ||
std_validation_models, | ||
get_fprtpr, | ||
get_pr | ||
) | ||
from modelcomp.plotter import ( | ||
individual_plots, | ||
general_plots | ||
) | ||
from modelcomp.read import read_data | ||
from modelcomp.utilities import ( | ||
model_names, | ||
model_names_short, | ||
model_names_dict, | ||
remove_falsy_columns, | ||
split_data, | ||
split_array, | ||
get_feature_importance, | ||
merge_dfs, | ||
relative_abundance, | ||
get_label_indexes, | ||
get_k_fold, | ||
get_models, | ||
filter_data, | ||
remove_rare_species, | ||
remove_string_columns, | ||
make_dir, | ||
join_save, | ||
data_to_filename, | ||
filename_to_data, | ||
) | ||
from modelcomp.write import write_data, write_plot | ||
|
||
modelcomp.model_names_dict = dict(zip(modelcomp.utilities.model_names, modelcomp.utilities.model_names_short)) |
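The final line of this file builds a lookup from full model names to short names by zipping two parallel lists. A minimal sketch of that idiom follows; the list contents are placeholders chosen for illustration, since the actual values in `modelcomp.utilities` are not shown in this part of the diff.

```python
# Placeholder lists (assumption): the real values live in modelcomp.utilities.
model_names = ["Random Forest", "Logistic Regression", "Support Vector Machine"]
model_names_short = ["RF", "LR", "SVM"]

# zip pairs the items by position; dict turns the pairs into a lookup table.
model_names_dict = dict(zip(model_names, model_names_short))

print(model_names_dict["Logistic Regression"])  # LR
```

Note that `zip` stops at the shorter input without raising, so the two lists need to stay the same length for every model to get a short name.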
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,190 @@ | ||
import numpy as np | ||
import pandas as pd | ||
import shap | ||
from sklearn import metrics | ||
from sklearn.metrics import roc_curve, precision_recall_curve | ||
from sklearn.neighbors import KNeighborsClassifier | ||
from sklearn.preprocessing import StandardScaler | ||
from sklearn.svm import SVC | ||
|
||
import modelcomp as mcp | ||
|
||
|
||
def get_fprtpr(model, X, y, pos_num): | ||
""" | ||
Calculates the false positive rate (fpr) and true positive rate (tpr) of a trained model on the test data | ||
:param model: The trained model | ||
:param X: Testing data | ||
:param y: Binary labels | ||
:param pos_num: Positive class label | ||
:return: fpr & tpr values | ||
""" | ||
fpr, tpr, _ = roc_curve(y, model.predict_proba(X)[:, 1], pos_label=pos_num, drop_intermediate=False) | ||
return fpr, tpr | ||
# returning fpr, tpr of roc_curve | ||
|
||
|
||
def get_pr(model, X, y, pos_num): | ||
""" | ||
Calculates the precision & recall of a trained model on the test data | ||
:param model: The trained model | ||
:param X: Testing data | ||
:param y: Binary labels | ||
:param pos_num: Positive class label | ||
:return: precision & recall values | ||
""" | ||
precision, recall, _ = precision_recall_curve(y, model.predict_proba(X)[:, 1], pos_label=pos_num) | ||
return precision, recall | ||
# returning precision, recall of precision_recall_curve | ||
|
||
|
||
def std_validation_models(models, X_train, X_test, y_train, y_test, tested_on, trained_on, feature_names, validate=True, | ||
explain=True, plot=True): | ||
""" | ||
Standard validation of models (used when the positive class label differs between the training and testing data) | ||
:param models: A list of models to evaluate | ||
:param X_train: Training data | ||
:param X_test: Testing data | ||
:param y_train: Training data labels | ||
:param y_test: Testing data labels | ||
:param tested_on: Tested label | ||
:param trained_on: Trained label | ||
:param feature_names: List of feature names | ||
:param validate: Validate the models (default: True) | ||
:param explain: Explain the models (default: True) | ||
:param plot: Plot the results (default: True) | ||
:return: Model results, exported to the filesystem (default: "export") | ||
""" | ||
X_train, X_test, y_train, y_test = X_train.to_numpy(), X_test.to_numpy(), y_train.to_numpy().ravel(), y_test.to_numpy().ravel() | ||
mean_fpr = np.linspace(0, 1, 100) | ||
# init | ||
for model_index, model in enumerate(models): | ||
X_train_temp, X_test_temp = X_train, X_test | ||
interp_tpr, interp_recall, aucs, pr_aucs = None, None, None, None | ||
feature_importances, shap_values = None, None | ||
save_to_unjoined = mcp.data_to_filename(tested_on, mcp.model_names[model_index], | ||
trained_on=trained_on) | ||
# in-loop init | ||
if validate: | ||
aucs = [] | ||
pr_aucs = [] | ||
if type(model) is SVC: | ||
sc = StandardScaler() | ||
X_train = sc.fit_transform(X_train) | ||
X_test = sc.transform(X_test) | ||
# scaling if svm | ||
model.fit(X_train, y_train) | ||
# fitting model | ||
fpr, tpr = get_fprtpr(model, X_test, y_test, 1) | ||
interp_tpr = np.interp(mean_fpr, fpr, tpr) | ||
interp_tpr[0] = 0.0 | ||
aucs.append(metrics.auc(fpr, tpr)) | ||
# roc curve variables | ||
precision, recall = get_pr(model, X_test, y_test.ravel(), 1) | ||
interp_recall = np.interp(mean_fpr, recall[::-1].ravel(), precision[::-1].ravel()) | ||
interp_recall[0] = 1.0 | ||
pr_aucs.append(metrics.auc(recall, precision)) | ||
# pr curve variables | ||
if explain: | ||
feature_importances = pd.DataFrame(mcp.get_feature_importance(model).T, index=feature_names, | ||
columns=['Importance']) if type( | ||
model) is not KNeighborsClassifier else None | ||
# feature importance | ||
shap_values = (shap.explainers.Permutation(model.predict, X_test, max_evals=1000).shap_values(X_test)) | ||
# shap values | ||
mcp.write_data(save_to_unjoined, feature_names, interp_tpr, interp_recall, aucs, pr_aucs, | ||
feature_importances=feature_importances, shap_values=shap_values) | ||
# exporting data | ||
X_train, X_test = X_train_temp, X_test_temp | ||
if plot: | ||
mcp.individual_plots(save_to_unjoined) | ||
# plotting data | ||
|
||
|
||
def cross_val_models(models, validation_model, X, y, positive_label, feature_names, validate=True, | ||
explain=True, plot=True): | ||
""" | ||
Cross validation of multiple models | ||
:param models: List of models to evaluate | ||
:param validation_model: Validation model to evaluate with | ||
:param X: Features | ||
:param y: Labels | ||
:param positive_label: Positive label | ||
:param feature_names: List of feature names | ||
:param validate: Validate the models (default: True) | ||
:param explain: Explain the models (default: True) | ||
:param plot: Plot the results (default: True) | ||
:return: Model results, exported to the filesystem (default: "export") | ||
""" | ||
mean_fpr = np.linspace(0, 1, 100) | ||
# init | ||
for model_index, model in enumerate(models): | ||
feature_importances, shap_values = None, None | ||
if validate: | ||
feature_importances_per_fold = [] | ||
interp_tpr_per_fold = [] | ||
aucs = [] | ||
interp_recall_per_fold = [] | ||
pr_aucs = [] | ||
fprs = [] | ||
tprs = [] | ||
precisions = [] | ||
recalls = [] | ||
shap_values = None | ||
for split_index, (train, test) in enumerate(validation_model.split(X, y)): | ||
X_train_temp, X_test_temp = X[train], X[test] | ||
if type(model) is SVC: | ||
sc = StandardScaler() | ||
X[train] = sc.fit_transform(X[train]) | ||
X[test] = sc.transform(X[test]) | ||
# scaling if svm | ||
model.fit(X[train], y[train]) | ||
# fitting model | ||
fpr, tpr = get_fprtpr(model, X[test], y[test], 1) | ||
fprs.append(fpr) | ||
tprs.append(tpr) | ||
interp_tpr = np.interp(mean_fpr, fpr, tpr) | ||
interp_tpr[0] = 0.0 | ||
interp_tpr_per_fold.append(interp_tpr) | ||
aucs.append(metrics.auc(fpr, tpr)) | ||
# roc curve variables | ||
precision, recall = get_pr(model, X[test], y[test].ravel(), 1) | ||
precisions.append(precision) | ||
recalls.append(recall) | ||
interp_recall = np.interp(mean_fpr, recall[::-1].ravel(), precision[::-1].ravel()) | ||
interp_recall[0] = 1.0 | ||
interp_recall_per_fold.append(interp_recall) | ||
pr_aucs.append(metrics.auc(recall, precision)) | ||
# pr curve variables | ||
if explain: | ||
feature_importances_per_fold.append( | ||
pd.DataFrame(mcp.get_feature_importance(model).T, index=feature_names, | ||
columns=['Importance']) if type(model) is not KNeighborsClassifier else None) | ||
if feature_importances_per_fold[0] is None: | ||
feature_importances = None | ||
else: | ||
feature_importances = feature_importances_per_fold[0] | ||
for feature_importances_in_fold in feature_importances_per_fold[1:]: | ||
feature_importances = feature_importances.add(feature_importances_in_fold, fill_value=0) | ||
feature_importances['Importance'] = feature_importances['Importance'].map( | ||
lambda old_value: old_value / len(feature_importances_per_fold)) | ||
# feature importance | ||
shap_values_temp = shap.explainers.Permutation(model.predict, X[test], max_evals=1000).shap_values( | ||
X[test]) | ||
if shap_values is None: | ||
shap_values = shap_values_temp | ||
else: | ||
shap_values = np.append(shap_values, shap_values_temp, axis=0) | ||
# shap values | ||
X[train], X[test] = X_train_temp, X_test_temp | ||
mcp.write_data( | ||
mcp.data_to_filename(positive_label, mcp.model_names[model_index]), | ||
feature_names, | ||
interp_tpr_per_fold, interp_recall_per_fold, aucs, pr_aucs, fprs=fprs, tprs=tprs, | ||
precisions=precisions, recalls=recalls, feature_importances=feature_importances, | ||
shap_values=shap_values) | ||
# exporting data | ||
if plot: | ||
mcp.individual_plots( | ||
mcp.data_to_filename(positive_label, mcp.model_names[model_index])) | ||
# plotting data |
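The `explain` branch of `cross_val_models` collects one feature-importance table per fold and reduces them to a single mean. As a dependency-free sketch of that reduction, assuming every fold should contribute exactly once: plain dictionaries stand in for the pandas DataFrames, and `mean_importance` and the `taxon_*` names are hypothetical, not part of modelcomp.

```python
from typing import Dict, List


def mean_importance(per_fold: List[Dict[str, float]]) -> Dict[str, float]:
    """Average per-feature importance over folds; a feature missing from a fold counts as 0."""
    if not per_fold:
        raise ValueError("need at least one fold")
    totals: Dict[str, float] = {}
    for fold in per_fold:  # every fold is visited exactly once
        for feature, value in fold.items():
            totals[feature] = totals.get(feature, 0.0) + value
    # Divide the summed importances by the number of folds.
    return {feature: total / len(per_fold) for feature, total in totals.items()}


folds = [
    {"taxon_a": 0.5, "taxon_b": 0.25},
    {"taxon_a": 0.25, "taxon_b": 0.5},
    {"taxon_a": 0.75, "taxon_b": 0.75},
]
print(mean_importance(folds))  # {'taxon_a': 0.5, 'taxon_b': 0.5}
```

Treating a missing feature as 0 mirrors the `fill_value=0` argument passed to `DataFrame.add` in the function above.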
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
from modelcomp.models import knn  # each submodule exposes its own get() | ||
from modelcomp.models import logistic_regression | ||
from modelcomp.models import random_forest | ||
from modelcomp.models import svm | ||
from modelcomp.models import xgboost |
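Every model submodule in this package defines a factory with the same name, `get`. When one name is imported from several modules in turn, each import rebinds it and only the last binding survives, whereas module-qualified access keeps the factories distinct. A small sketch of that behavior, using `SimpleNamespace` objects as stand-ins for the submodules (the returned strings are placeholders, not real estimators):

```python
from types import SimpleNamespace

# Stand-ins for two submodules that each define a function named `get`.
knn = SimpleNamespace(get=lambda: "k-NN model")
svm = SimpleNamespace(get=lambda: "SVM model")

# Equivalent of `from knn import get` followed by `from svm import get`:
get = knn.get
get = svm.get  # rebinds the name; the k-NN factory is no longer reachable as `get`

print(get())      # SVM model
print(knn.get())  # k-NN model (module-qualified access is unaffected)
```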
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
from sklearn.neighbors import KNeighborsClassifier | ||
|
||
|
||
def get(): | ||
""" | ||
Load a k-NN model from the scikit-learn library | ||
:return: a k-NN model | ||
""" | ||
return KNeighborsClassifier() |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
from sklearn.linear_model import LogisticRegression | ||
|
||
|
||
def get(seed=None): | ||
""" | ||
Load a logistic regression model from the scikit-learn library | ||
:param seed: set seed for model randomness (default: None) | ||
:return: a logistic regression model | ||
""" | ||
return LogisticRegression(random_state=seed, max_iter=10000) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
from sklearn.ensemble import RandomForestClassifier | ||
|
||
|
||
def get(seed=None): | ||
""" | ||
Load a random forest from the scikit-learn library | ||
:param seed: set seed for model randomness (default: None) | ||
:return: a random forest model | ||
""" | ||
return RandomForestClassifier(random_state=seed) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
from sklearn.svm import SVC | ||
|
||
|
||
def get(seed=None): | ||
""" | ||
Load an svm model from the scikit-learn library | ||
:param seed: set seed for model randomness (default: None) | ||
:return: an svm model | ||
""" | ||
return SVC(random_state=seed, kernel='linear', probability=True) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
import xgboost as xgb | ||
|
||
|
||
def get(seed=None): | ||
""" | ||
Load an XGBoost model from the xgboost library | ||
:param seed: set seed for model randomness (default: None) | ||
:return: an xgboost model | ||
""" | ||
return xgb.XGBClassifier(random_state=seed) |
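The model factories share one pattern: an optional `seed` is handed to the estimator's `random_state`, where `None` means "unseeded". The sketch below illustrates why a fixed seed makes runs repeatable. `ToyModel` and the `factories` table are hypothetical stand-ins (a `random.Random` instance plays the role of the estimator), so this shows the pattern only, not the behavior of any particular library.

```python
import random
from typing import Callable, Dict, Optional


class ToyModel:
    """Hypothetical stand-in for an estimator that accepts random_state."""

    def __init__(self, random_state: Optional[int] = None):
        # random.Random(None) seeds itself from an OS source, i.e. "unseeded".
        self._rng = random.Random(random_state)

    def draw(self) -> float:
        return self._rng.random()


def get(seed: Optional[int] = None) -> ToyModel:
    # Passing seed straight through is enough: None already means "unseeded".
    return ToyModel(random_state=seed)


# A name-to-factory table makes it easy to build every model with one seed.
factories: Dict[str, Callable[[Optional[int]], ToyModel]] = {"toy_a": get, "toy_b": get}
first = {name: make(42).draw() for name, make in factories.items()}
second = {name: make(42).draw() for name, make in factories.items()}
print(first == second)  # True: same seed, same draws
```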