From fe18d3d390897d9f4f53e6c8082e747bd4a7e1ef Mon Sep 17 00:00:00 2001 From: gcroci2 Date: Fri, 20 Oct 2023 16:09:32 +0200 Subject: [PATCH] improve logic for handling both a pre-trained model and a dataset_train when train is False in dataset.py --- deeprank2/dataset.py | 42 ++++++++++++++++++++++-------------------- 1 file changed, 22 insertions(+), 20 deletions(-) diff --git a/deeprank2/dataset.py b/deeprank2/dataset.py index 5d7dc7021..48757ef67 100644 --- a/deeprank2/dataset.py +++ b/deeprank2/dataset.py @@ -486,13 +486,15 @@ def __init__( # pylint: disable=too-many-arguments self._check_features() if not train: - if not isinstance(train_data, GridDataset): - raise TypeError(f"""The train dataset provided is type: {type(train_data)} - Please provide a valid training GridDataset.""") - - #check inherited parameter with the ones in the training set - inherited_params = ["features", "target", "target_transform", "task", "classes"] - self._check_inherited_params(inherited_params, train_data) + if isinstance(train_data, str): + pass + elif isinstance(train_data, GridDataset): + #check inherited parameter with the ones in the training set + inherited_params = ["features", "target", "target_transform", "task", "classes"] + self._check_inherited_params(inherited_params, train_data) + else: + raise TypeError(f"""The train data provided is type: {type(train_data)} + Please provide a valid training GridDataset or the path to a valid DeepRank2 pre-trained model.""") elif train and train_data: _log.warning("""`train_data` has been set but train flag was set to True. @@ -725,13 +727,17 @@ def __init__( # noqa: MC0001, pylint: disable=too-many-arguments, too-many-local self._check_features() if not train: - if not isinstance(train_data, GraphDataset): - raise TypeError(f"""The train dataset provided is type: {type(train_data)} - Please provide a valid training GraphDataset.""") - - #check inherited parameter with the ones in the training set - inherited_params = ["node_features", "edge_features", "features_transform", "target", "target_transform", "task", "classes"] - self._check_inherited_params(inherited_params, train_data) + if isinstance(train_data, str): + pass + elif isinstance(train_data, GraphDataset): + #check inherited parameter with the ones in the training set + inherited_params = ["node_features", "edge_features", "features_transform", "target", "target_transform", "task", "classes"] + self._check_inherited_params(inherited_params, train_data) + train_means = train_data.means + train_devs = train_data.devs + else: + raise TypeError(f"""The train data provided is type: {type(train_data)} + Please provide a valid training GraphDataset or the path to a valid DeepRank2 pre-trained model.""") elif train and train_data: _log.warning("""`train_data` has been set but train flag was set to True. @@ -756,12 +762,8 @@ def __init__( # noqa: MC0001, pylint: disable=too-many-arguments, too-many-local self.hdf5_to_pandas() self._compute_mean_std() elif standardize and (not train): - if (train_data.means is None) or (train_data.devs is None): - if train_data.df is None: - train_data.hdf5_to_pandas() - train_data._compute_mean_std() - self.means = train_data.means - self.devs = train_data.devs + self.means = train_means + self.devs = train_devs def get(self, idx: int) -> Data: """Gets one graph item from its unique index.