Update readme
WenkelF committed Nov 1, 2024
1 parent 8436111 commit 9cfabb7
Showing 4 changed files with 32 additions and 15 deletions.
20 changes: 20 additions & 0 deletions README.md
@@ -148,6 +148,26 @@ graphium-train --config-path [PATH] --config-name [CONFIG]
```
Thanks to the modular nature of `hydra` you can reuse many of our config settings for your own experiments with Graphium.

### Finetuning

After pretraining a model and saving a model checkpoint, the model can be finetuned to a new task:
```bash
graphium-train +finetuning=custom finetuning.pretrained_model=[model_identifier] constants.data_path=[path_to_data] constants.task=[name_of_task] constants.task_type=[cls OR reg]
```

The `[model_identifier]` identifies the pretrained model among those maintained in `GRAPHIUM_PRETRAINED_MODELS_DICT` in `graphium/utils/spaces.py`, where it maps to the location of the pretrained model's checkpoint.

The custom dataset to finetune on consists of two files, `raw.csv` and `split.csv`, provided under `[path_to_data]/[name_of_task]`. `raw.csv` contains two columns: `smiles` with the SMILES strings and `target` with the corresponding targets. In `split.csv`, the three columns `train`, `val`, and `test` contain the indices of the rows in `raw.csv` that belong to each split. Examples can be found under `expts/data/finetuning_example-reg` (regression) and `expts/data/finetuning_example-cls` (binary classification).
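
For illustration, a minimal custom dataset and finetuning run could look roughly as follows. This is a sketch only: the directory name `my_task-cls`, the SMILES strings, and the labels are placeholders, the split layout mirrors the shipped examples but may be padded differently, and `dummy-pretrained-model` is assumed here to be a registered entry of `GRAPHIUM_PRETRAINED_MODELS_DICT` (it is the identifier used in the fingerprinting example config).
```bash
# Sketch: create a tiny classification dataset under [path_to_data]/[name_of_task]
# (placeholder SMILES and labels; see expts/data/finetuning_example-* for the reference format)
mkdir -p expts/data/my_task-cls
cat > expts/data/my_task-cls/raw.csv << 'EOF'
smiles,target
CCO,0
c1ccccc1,1
CC(=O)O,1
CCN,0
EOF
# split.csv holds row indices of raw.csv per split; shorter columns are left empty
cat > expts/data/my_task-cls/split.csv << 'EOF'
train,val,test
0,2,3
1,,
EOF

# Finetune the (assumed) registered checkpoint on the new task
graphium-train +finetuning=custom \
    finetuning.pretrained_model=dummy-pretrained-model \
    constants.data_path=expts/data \
    constants.task=my_task-cls \
    constants.task_type=cls
```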

### Fingerprinting

Alternatively, molecular embeddings (fingerprints) can be obtained from a pretrained model:
```bash
graphium fps create custom pretrained.model=[model_identifier] pretrained.layers=[layer_identifiers] datamodule.df_path=[path_to_data]
```

After specifying the `[model_identifier]`, a list of layers of that model from which to read out embeddings is provided via `[layer_identifiers]`; an example can be found in `expts/hydra-configs/fingerprinting/custom.yaml`. In addition, the location of the SMILES strings to be embedded is passed as `[path_to_data]`. The data can be provided as a CSV file with a `smiles` column, similar to `expts/data/finetuning_example-reg/raw.csv`.
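
As a concrete sketch, reading out embeddings with the example pretrained model referenced in `expts/hydra-configs/fingerprinting/custom.yaml` might look like this (the model name and layer identifiers are taken from that config and are assumptions, not guaranteed defaults; the CSV path is the shipped regression example):
```bash
graphium fps create custom \
    pretrained.model=dummy-pretrained-model \
    "pretrained.layers=[graph_output_nn-graph:0,task_heads-zinc:0]" \
    datamodule.df_path=expts/data/finetuning_example-reg/raw.csv
```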

## License

Under the Apache-2.0 license. See [LICENSE](LICENSE).
20 changes: 8 additions & 12 deletions expts/hydra-configs/finetuning/custom.yaml
@@ -5,18 +5,16 @@ defaults:

constants:
  benchmark: custom
  task: task
  # task_type: reg
  # task_name: adme-fang-HPPB-reg
  task_type: cls
  task_name: CYP2D6_Veith
  task: finetuning_example-cls # finetuning_example-cls OR finetuning_example-reg
  task_type: cls # cls OR reg
  data_path: expts/data
  wandb:
    name: finetune_${constants.task_name}
    name: finetune_${constants.task}
    project: finetuning
    entity: valencelabs
    tags:
      - finetuning
      - ${constants.task_name}
      - ${constants.task}
      - ${finetuning.pretrained_model}
  seed: 42
  max_epochs: 20
@@ -34,10 +32,8 @@ datamodule:
  task_specific_args:
    finetune:
      df: null
      # df_path: expts/data/finetuning_example-reg/raw.csv
      # splits_path: expts/data/finetuning_example-reg/split.csv
      df_path: expts/data/finetuning_example-cls/raw.csv
      splits_path: expts/data/finetuning_example-cls/split.csv
      df_path: ${constants.data_path}/${constants.task}/raw.csv
      splits_path: ${constants.data_path}/${constants.task}/split.csv
      smiles_col: smiles
      label_cols: target
      task_level: graph
@@ -90,7 +86,7 @@ finetuning:

# Optional finetuning head appended to model after finetuning_module
# finetuning_head:
# task: ${constants.task}
# task: finetune
# previous_module: task_heads
# incoming_level: graph
# model_type: mlp
5 changes: 3 additions & 2 deletions expts/hydra-configs/fingerprinting/custom.yaml
@@ -1,5 +1,6 @@
pretrained_models:
  dummy-pretrained-model:
pretrained:
  model: dummy-pretrained-model
  layers:
    - graph_output_nn-graph:0
    - task_heads-zinc:0

2 changes: 1 addition & 1 deletion graphium/cli/fingerprint.py
@@ -31,7 +31,7 @@ def smiles_to_fps(cfg_name: str, overrides: List[str]) -> Dict[str, Any]:
project="graphium-fingerprints",
)

pretrained_models = cfg.get("pretrained_models")
pretrained_models = cfg.get("pretrained")

# Allow alternative definition of `pretrained_models` with the single model specifier and desired layers
if "layers" in pretrained_models.keys():
