Skip to content

Commit

Permalink
improve logic for handling both a pre-trained model and a dataset_tra…
Browse files Browse the repository at this point in the history
…in when train is False in dataset.py
  • Loading branch information
gcroci2 committed Oct 20, 2023
1 parent fa45c60 commit fe18d3d
Showing 1 changed file with 22 additions and 20 deletions.
42 changes: 22 additions & 20 deletions deeprank2/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -486,13 +486,15 @@ def __init__( # pylint: disable=too-many-arguments
self._check_features()

if not train:
if not isinstance(train_data, GridDataset):
raise TypeError(f"""The train dataset provided is type: {type(train_data)}
Please provide a valid training GridDataset.""")

#check inherited parameter with the ones in the training set
inherited_params = ["features", "target", "target_transform", "task", "classes"]
self._check_inherited_params(inherited_params, train_data)
if isinstance(train_data, str):
pass
elif isinstance(train_data, GridDataset):
#check inherited parameter with the ones in the training set
inherited_params = ["features", "target", "target_transform", "task", "classes"]
self._check_inherited_params(inherited_params, train_data)
else:
raise TypeError(f"""The train data provided is type: {type(train_data)}
Please provide a valid training GridDataset or the path to a valid DeepRank2 pre-trained model.""")

elif train and train_data:
_log.warning("""`train_data` has been set but train flag was set to True.
Expand Down Expand Up @@ -725,13 +727,17 @@ def __init__( # noqa: MC0001, pylint: disable=too-many-arguments, too-many-local
self._check_features()

if not train:
if not isinstance(train_data, GraphDataset):
raise TypeError(f"""The train dataset provided is type: {type(train_data)}
Please provide a valid training GraphDataset.""")

#check inherited parameter with the ones in the training set
inherited_params = ["node_features", "edge_features", "features_transform", "target", "target_transform", "task", "classes"]
self._check_inherited_params(inherited_params, train_data)
if isinstance(train_data, str):
pass
elif isinstance(train_data, GraphDataset):
#check inherited parameter with the ones in the training set
inherited_params = ["node_features", "edge_features", "features_transform", "target", "target_transform", "task", "classes"]
self._check_inherited_params(inherited_params, train_data)
train_means = train_data.means
train_devs = train_data.devs
else:
raise TypeError(f"""The train data provided is type: {type(train_data)}
Please provide a valid training GraphDataset or the path to a valid DeepRank2 pre-trained model.""")

elif train and train_data:
_log.warning("""`train_data` has been set but train flag was set to True.
Expand All @@ -756,12 +762,8 @@ def __init__( # noqa: MC0001, pylint: disable=too-many-arguments, too-many-local
self.hdf5_to_pandas()
self._compute_mean_std()
elif standardize and (not train):
if (train_data.means is None) or (train_data.devs is None):
if train_data.df is None:
train_data.hdf5_to_pandas()
train_data._compute_mean_std()
self.means = train_data.means
self.devs = train_data.devs
self.means = train_means
self.devs = train_devs

def get(self, idx: int) -> Data:
"""Gets one graph item from its unique index.
Expand Down

0 comments on commit fe18d3d

Please sign in to comment.