Skip to content

Commit

Permalink
refactor(logs): revert to standard python logs
Browse files Browse the repository at this point in the history
  • Loading branch information
matteo4diani committed Nov 18, 2023
1 parent af2792c commit dae6c25
Show file tree
Hide file tree
Showing 14 changed files with 65 additions and 88 deletions.
41 changes: 0 additions & 41 deletions auton_survival/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -473,46 +473,5 @@ class is a composite transform that does both Imputing ***and*** Scaling with
from .models.cph import DeepCoxPH, DeepRecurrentCoxPH
from .models.cmhe import DeepCoxMixturesHeterogenousEffects

from loguru import logger
import warnings
from tqdm.auto import tqdm


logger.disable("auton_survival")


def enable_auton_logger(
add_logger=False, capture_warnings=False, log_level="INFO"
):
"""
Enable auton_survival logs
"""
logger.enable("auton_survival")
if add_logger:
_add_auton_logger(log_level)
if capture_warnings:
_enable_warnings_capture()


def _add_auton_logger(log_level):
"""
Enable auton_survival logs and add a minimal tqdm-friendly loguru logger
"""
logger.remove()

logger.add(
lambda msg: tqdm.write(msg, end=""), colorize=True, level=log_level
)


def _enable_warnings_capture():
"""
Defer warnings to loguru
"""
showwarning_ = warnings.showwarning

def showwarning(message, *args, **kwargs):
logger.warning(message)
showwarning_(message, *args, **kwargs)

warnings.showwarning = showwarning
8 changes: 5 additions & 3 deletions auton_survival/experiments.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,9 @@
from sklearn.utils import shuffle

from tqdm import tqdm
from loguru import logger
import logging

logger = logging.getLogger(__name__)


class SurvivalRegressionCV:
Expand Down Expand Up @@ -160,11 +162,11 @@ def fit(self, features, outcomes, horizons, metric="ibs"):

hyper_param_scores = []
for i, hyper_param in enumerate(self.hyperparam_grid):
logger.info("At hyper-param: {}", hyper_param)
logger.info(f"At hyper-param: {hyper_param}")

fold_scores = []
for fold in set(self.folds):
logger.info("At fold: {}", fold)
logger.info(f"At fold: {fold}")
model = SurvivalModel(
self.model, random_seed=self.random_seed, **hyper_param
)
Expand Down
15 changes: 8 additions & 7 deletions auton_survival/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,11 @@
from scipy.optimize import fsolve
from sklearn.metrics import auc
from tqdm import tqdm
from loguru import logger
import logging
import warnings

logger = logging.getLogger(__name__)


def treatment_effect(
metric,
Expand Down Expand Up @@ -143,8 +145,7 @@ def treatment_effect(

if isinstance(n_bootstrap, int):
logger.info(
"Bootstrapping... {} number of times. This may take a while. Please be Patient...",
n_bootstrap,
f"Bootstrapping... {n_bootstrap} number of times. This may take a while. Please be Patient...",
)

is_treated = treatment_indicator.astype(float)
Expand Down Expand Up @@ -528,7 +529,7 @@ def _restricted_mean_diff(
control_weights,
size_bootstrap=1.0,
random_seed=None,
**kwargs
**kwargs,
):
"""Compute the difference in the area under the Kaplan Meier curve
(mean survival time) between control and treatment groups.
Expand Down Expand Up @@ -602,7 +603,7 @@ def _survival_at_diff(
interpolate=True,
size_bootstrap=1.0,
random_seed=None,
**kwargs
**kwargs,
):
"""Compute the difference in Kaplan Meier survival function estimates
between the control and treatment groups at a specified time horizon.
Expand Down Expand Up @@ -680,7 +681,7 @@ def _tar(
interpolate=True,
size_bootstrap=1.0,
random_seed=None,
**kwargs
**kwargs,
):
"""Time at Risk (TaR) measures time-to-event at a specified level
of risk.
Expand Down Expand Up @@ -749,7 +750,7 @@ def _hazard_ratio(
control_weights,
size_bootstrap=1.0,
random_seed=None,
**kwargs
**kwargs,
):
"""Train an instance of the Cox Proportional Hazards model and return the
exp(coefficients) (hazard ratios) of the model.
Expand Down
6 changes: 4 additions & 2 deletions auton_survival/models/cmhe/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@
"""

from loguru import logger
import logging
import numpy as np
import torch

Expand All @@ -85,6 +85,8 @@

from auton_survival.utils import _dataframe_to_array

logger = logging.getLogger(__name__)


class DeepCoxMixturesHeterogenousEffects:
"""A Deep Cox Mixtures with Heterogenous Effects model.
Expand Down Expand Up @@ -149,7 +151,7 @@ def __call__(self):
else:
logger.info("An unfitted instance of the CMHE model")

logger.info("Hidden Layers: {}", self.layers)
logger.info(f"Hidden Layers: {self.layers}")

def _preprocess_test_data(self, x, a=None):
x = _dataframe_to_array(x)
Expand Down
8 changes: 4 additions & 4 deletions auton_survival/models/cmhe/cmhe_utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

from loguru import logger
import logging
import torch
import numpy as np

Expand All @@ -40,6 +40,8 @@
smooth_bl_survival,
)

logger = logging.getLogger(__name__)


def get_likelihood(model, breslow_splines, x, t, e, a):
# Function requires numpy/torch
Expand Down Expand Up @@ -300,9 +302,7 @@ def train_cmhe(

losses.append(valcn)

logger.debug(
"Patience: {} | Epoch: {} | Loss: {}", patience_, epoch, valcn
)
logger.debug(f"Patience: {patience_} | Epoch: {epoch} | Loss: {valcn}")

if valcn > valc:
patience_ += 1
Expand Down
13 changes: 8 additions & 5 deletions auton_survival/models/cph/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
r""" Deep Cox Proportional Hazards Model"""

from collections import namedtuple
from loguru import logger
import logging
import torch
import numpy as np
import random
Expand All @@ -37,6 +37,9 @@
from auton_survival.models.utils.recurrent_nn_utils import _get_padded_targets


logger = logging.getLogger(__name__)


DcphModel = namedtuple("DcphModel", ["module", "breslow"])


Expand Down Expand Up @@ -127,7 +130,7 @@ def __call__(self):
else:
logger.info("An unfitted instance of the Deep Cox PH model")

logger.info("Hidden Layers: {}", self.layers)
logger.info(f"Hidden Layers: {self.layers}")

def _preprocess_test_data(self, x):
x = _dataframe_to_array(x)
Expand Down Expand Up @@ -209,7 +212,7 @@ def init_torch_model(
self.initialized = True
else:
logger.info(
f"""Early initialization selected. Model-specific `fit` parameters will be ignored."""
"Early initialization selected. Model-specific `fit` parameters will be ignored."
)

def fit(
Expand Down Expand Up @@ -324,7 +327,7 @@ def predict_time_independent_survival(
@torch.inference_mode()
def predict_time_independent_risk(self, x: torch.Tensor) -> torch.Tensor:
if not self.initialized:
logger.warning(
raise Exception(
"The PyTorch module has not been initialized yet. Please init the "
+ "model using the `init_torch_model` method on some training data "
+ "before calling `predict_time_independent_risk`."
Expand Down Expand Up @@ -418,7 +421,7 @@ def __call__(self):
"An unfitted instance of the Recurrent Deep Cox PH model"
)

logger.info("Hidden Layers: {}", self.layers)
logger.info(f"Hidden Layers: {self.layers}")

def _gen_torch_model(self, inputdim, optimizer):
"""Helper function to return a torch model."""
Expand Down
8 changes: 4 additions & 4 deletions auton_survival/models/cph/dcph_utilities.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import torch
import numpy as np
import pandas as pd
from loguru import logger
import logging

from sksurv.linear_model.coxph import BreslowEstimator

Expand All @@ -19,6 +19,8 @@
partial_ll_loss,
)

logger = logging.getLogger(__name__)


def fit_breslow(model, x, t, e):
return BreslowEstimator().fit(
Expand Down Expand Up @@ -119,9 +121,7 @@ def train_dcph(

dics.append(deepcopy(model.state_dict()))

logger.debug(
"Patience: {} | Epoch: {} | Loss: {}", patience_, epoch, valcn
)
logger.debug(f"Patience: {patience_} | Epoch: {epoch} | Loss: {valcn}")

if valcn > valc:
patience_ += 1
Expand Down
8 changes: 5 additions & 3 deletions auton_survival/models/dcm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@
"""

from loguru import logger
import logging
import torch
import numpy as np

Expand All @@ -59,6 +59,8 @@

from auton_survival.utils import _dataframe_to_array

logger = logging.getLogger(__name__)


class DeepCoxMixtures:
"""A Deep Cox Mixture model.
Expand Down Expand Up @@ -116,8 +118,8 @@ def __call__(self):
else:
logger.info("An unfitted instance of the Deep Cox Mixtures model")

logger.info("Number of underlying cox distributions (k): {}", self.k)
logger.info("Hidden Layers: {}", self.layers)
logger.info(f"Number of underlying cox distributions (k): {self.k}")
logger.info(f"Hidden Layers: {self.layers}")

def _preprocess_test_data(self, x):
x = _dataframe_to_array(x)
Expand Down
8 changes: 4 additions & 4 deletions auton_survival/models/dcm/dcm_utilities.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import warnings
from loguru import logger
import logging
from matplotlib.pyplot import get

import torch
Expand All @@ -24,6 +24,8 @@
smooth_bl_survival,
)

logger = logging.getLogger(__name__)


def get_likelihood(model, breslow_splines, x, t, e):
# Function requires numpy/torch
Expand Down Expand Up @@ -263,9 +265,7 @@ def train_dcm(

losses.append(valcn)

logger.debug(
"Patience: {} | Epoch: {} | Loss: {}", patience_, epoch, valcn
)
logger.debug(f"Patience: {patience_} | Epoch: {epoch} | Loss: {valcn}")

if valcn > valc:
patience_ += 1
Expand Down
11 changes: 7 additions & 4 deletions auton_survival/models/dsm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@
"""

from loguru import logger
import logging
import torch
import numpy as np

Expand All @@ -183,6 +183,9 @@
__pdoc__["DSMBase"] = False


logger = logging.getLogger(__name__)


class DSMBase:
"""Base Class for all DSM models"""

Expand Down Expand Up @@ -547,9 +550,9 @@ def __call__(self):
"An unfitted instance of the Deep Survival Machines model"
)

logger.info("Number of underlying distributions (k): {}", self.k)
logger.info("Hidden Layers: {}", self.layers)
logger.info("Distribution Choice: {}", self.dist)
logger.info(f"Number of underlying distributions (k): {self.k}")
logger.info(f"Hidden Layers: {self.layers}")
logger.info(f"Distribution Choice: {self.dist}")


class DeepRecurrentSurvivalMachines(DSMBase):
Expand Down
5 changes: 4 additions & 1 deletion auton_survival/models/dsm/dsm_utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,10 @@
import numpy as np

import gc
from loguru import logger
import logging


logger = logging.getLogger(__name__)


def pretrain_dsm(
Expand Down
Loading

0 comments on commit dae6c25

Please sign in to comment.