Merge pull request #515 from DeepRank/510_testing_pre_trained_gcroci2
feat: improve `Trainer` and `DeeprankDataset` logic for production testing
gcroci2 authored Jan 3, 2024
2 parents 1b8f1f0 + 7f33a68 commit 226ff35
Showing 16 changed files with 867 additions and 352 deletions.
4 changes: 2 additions & 2 deletions .github/actions/install-python-and-package/action.yml
@@ -50,8 +50,8 @@ runs:
conda install -c bioconda msms
## PyTorch, PyG, PyG adds
### Installing for CPU only on the CI
conda install pytorch torchvision torchaudio cpuonly -c pytorch
conda install pyg -c pyg
conda install pytorch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2 -c pytorch
pip install torch_geometric==2.3.1
pip install torch_scatter torch_sparse torch_cluster torch_spline_conv -f https://data.pyg.org/whl/torch-$(python3 -c "import torch; print(torch.__version__)")+cpu.html
- name: Install dependencies on MacOS
shell: bash {0}
54 changes: 43 additions & 11 deletions README.md
@@ -38,14 +38,15 @@ DeepRank2 extensive documentation can be found [here](https://deeprank2.rtfd.io/
- [Table of contents](#table-of-contents)
- [Installation](#installation)
- [Dependencies](#dependencies)
- [Deeprank2 Package](#deeprank2-package)
- [Test installation](#test-installation)
- [Contributing](#contributing)
- [Data generation](#data-generation)
- [Datasets](#datasets)
- [GraphDataset](#graphdataset)
- [GridDataset](#griddataset)
- [Training](#training)
- [Run a pre-trained model on new data](#run-a-pre-trained-model-on-new-data)
- [Computational performances](#computational-performances)
- [Package development](#package-development)

@@ -61,7 +62,8 @@ Before installing deeprank2 you need to install some dependencies. We advise to
* [Here](https://ssbio.readthedocs.io/en/latest/instructions/msms.html) for MacOS with M1 chip users.
* [PyTorch](https://pytorch.org/get-started/locally/)
* We support torch's CPU library as well as CUDA.
* [PyTorch Geometric](https://pytorch-geometric.readthedocs.io/en/latest/install/installation.html) and its optional dependencies: `torch_scatter`, `torch_sparse`, `torch_cluster`, `torch_spline_conv`.
* Currently, the package is tested using [PyTorch 2.0.1](https://pytorch.org/get-started/previous-versions/#v201); a quick check of the installed versions from Python is sketched right after this list.
* [PyG](https://pytorch-geometric.readthedocs.io/en/latest/install/installation.html) and its optional dependencies: `torch_scatter`, `torch_sparse`, `torch_cluster`, `torch_spline_conv`.
* [DSSP 4](https://swift.cmbi.umcn.nl/gv/dssp/)
* Check if `dssp` is installed: `dssp --version`. If this gives an error or shows a version lower than 4:
* on Ubuntu 22.04 or newer: `sudo apt-get install dssp`. If the package cannot be located, first run `sudo apt-get update`.
@@ -70,7 +72,7 @@ Before installing deeprank2 you need to install some dependencies. We advise to
* Check if gcc is installed: `gcc --version`. If this gives an error, run `sudo apt-get install gcc`.
* MacOS users with an M1 chip should install only [the conda version of PyTables](https://www.pytables.org/usersguide/installation.html).

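As a quick sanity check after installing the dependencies above, the pinned PyTorch and PyG versions can be verified from Python; the expected numbers below are the currently tested ones and may change over time:

```python
import torch
import torch_geometric

# Versions the package is currently tested against (see the dependency list above)
print(torch.__version__)            # expected: 2.0.1 (CPU or CUDA build)
print(torch_geometric.__version__)  # expected: 2.3.1
```
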
### Deeprank2 Package
## Deeprank2 Package

Once the dependencies are installed, you can install the latest stable release of deeprank2 from PyPI using pip:

@@ -214,14 +216,12 @@ dataset_train = GraphDataset(
dataset_val = GraphDataset(
hdf5_path = hdf5_paths,
subset = valid_ids,
train = False,
dataset_train = dataset_train
train_source = dataset_train
)
dataset_test = GraphDataset(
hdf5_path = hdf5_paths,
subset = test_ids,
train = False,
dataset_train = dataset_train
train_source = dataset_train
)
```

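For context, a complete train/validation/test setup under the new `train_source` argument could look like the sketch below. The HDF5 paths, the ID lists, and the `binary` target are placeholders; any other dataset parameters (features, task, transforms) would be set on the training dataset only, since the validation and test sets inherit them through `train_source`:

```python
from deeprank2.dataset import GraphDataset

# Training set: all dataset parameters are defined here (placeholder target shown)
dataset_train = GraphDataset(
    hdf5_path = hdf5_paths,
    subset = train_ids,      # hypothetical list of training entry IDs
    target = "binary",
)

# Validation and test sets reuse the training set's configuration via `train_source`
dataset_val = GraphDataset(
    hdf5_path = hdf5_paths,
    subset = valid_ids,
    train_source = dataset_train,
)
dataset_test = GraphDataset(
    hdf5_path = hdf5_paths,
    subset = test_ids,
    train_source = dataset_train,
)
```
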
@@ -248,14 +248,12 @@ dataset_train = GridDataset(
dataset_val = GridDataset(
hdf5_path = hdf5_paths,
subset = valid_ids,
train = False,
dataset_train = dataset_train,
train_source = dataset_train,
)
dataset_test = GridDataset(
hdf5_path = hdf5_paths,
subset = test_ids,
train = False,
dataset_train = dataset_train,
train_source = dataset_train,
)
```

@@ -313,6 +311,40 @@ trainer.test()

```

### Run a pre-trained model on new data

If you want to analyze new PDB files using a pre-trained model, the first step is to process and save them into HDF5 files [as we have done above](#data-generation).

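As a rough sketch of that processing step, assuming the `QueryCollection` API used in the [Data generation](#data-generation) section above (class and parameter names shown here are illustrative and may differ across versions), new PDB files could be converted to HDF5 files along these lines:

```python
from deeprank2.query import QueryCollection, ProteinProteinInterfaceResidueQuery

queries = QueryCollection()

# One query per new PDB file; targets can be omitted when the labels are unknown
queries.add(ProteinProteinInterfaceResidueQuery(
    pdb_path = "new_data/pdb1.pdb",  # illustrative path
    chain_id1 = "A",
    chain_id2 = "B",
))

# Build the graphs and save them as HDF5 files
hdf5_paths = queries.process("<output_folder>/<prefix_for_outputs>")
```
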
Then, the `DeeprankDataset` instance for the newly processed data can be created. Do this by specifying the path of the pre-trained model in `train_source`, together with the path to the HDF5 files just created. Note that there is no need to set the dataset's parameters, since they are inherited from the information saved in the pre-trained model. Let's suppose that the model has been trained with `GraphDataset` objects:

```python
from deeprank2.dataset import GraphDataset

dataset_test = GraphDataset(
hdf5_path = "<output_folder>/<prefix_for_outputs>",
train_source = "<pretrained_model_path>"
)
```

Finally, the `Trainer` instance can be defined and the new data can be tested:

```python
from deeprank2.trainer import Trainer
from deeprank2.neuralnets.gnn.naive_gnn import NaiveNetwork
from deeprank2.utils.exporters import HDF5OutputExporter

trainer = Trainer(
NaiveNetwork,
dataset_test = dataset_test,
pretrained_model = "<pretrained_model_path>",
output_exporters = [HDF5OutputExporter("<output_folder_path>")]
)

trainer.test()
```
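
Once testing has finished, the exported predictions can be inspected with pandas. The sketch below assumes that `HDF5OutputExporter` writes an `output_exporter.hdf5` file containing a `testing` table, as in the DeepRank2 tutorials; check the output folder for the exact file name:

```python
import os

import pandas as pd

# Load the per-entry outputs written during testing (assumed file name and key)
output_path = os.path.join("<output_folder_path>", "output_exporter.hdf5")
df_test = pd.read_hdf(output_path, key="testing")
print(df_test.head())  # entry IDs, model outputs, and targets when available
```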

For more details about how to run a pre-trained model on new data, see the [docs](https://deeprank2.readthedocs.io/en/latest/getstarted.html#run-a-pre-trained-model-on-new-data).

## Computational performances

We measured the efficiency of data generation in DeepRank2 using the tutorials' [PDB files](https://zenodo.org/record/8187806) (~100 data points per data set), averaging results over runs on an Apple M1 Pro using a single CPU.