
esmdiff

Installation

conda create -n py310 python=3.10 -y
conda activate py310
pip install -r requirements.txt
pip install -e .
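As a quick sanity check that the install resolved, the core imports used later in this README should succeed (the torch import is an assumption about the contents of requirements.txt):

# verify core dependencies (a minimal check; torch is assumed to be in requirements.txt)
import torch
import esm
print("torch", torch.__version__)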

When running ESM3 for the first time:

In order to download the weights, you must accept EvolutionaryScale's non-commercial license. The weights are stored on the Hugging Face Hub under EvolutionaryScale/esm3. Please create a Hugging Face account and accept the license.

from huggingface_hub import login
from esm.models.esm3 import ESM3
from esm.sdk.api import ESM3InferenceClient, ESMProtein, GenerationConfig

# login() prompts for a Hugging Face API token; create one with "Read" permission.
login()
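Instantiating the model then triggers the one-time weight download (this line follows the ESM3 quickstart; the checkpoint name is the open ESM3 release):

# Downloads the weights on the first call and instantiates the model locally.
model: ESM3InferenceClient = ESM3.from_pretrained("esm3_sm_open_v1").to("cuda")  # or "cpu"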

Process data

(Optional) Download the entire PDB database and extract/process the structures into pickle (.pkl) files:

# download the whole PDB database (mmCIF format)
bash scripts/download_pdb_mmcif.sh pdb_data/pdb_mmcif

# ad hoc: biopython==1.79 is required by mmcif_parsing
pip install biopython==1.79
python scripts/pdb/preprocess.py --mmcif_dir pdb_data/pdb_mmcif --output_dir pdb_data/processed_chains --per_chain --strip_array

# switch back to this version for the rest of the pipeline
pip install biopython==1.84

For training, the VQ-VAE encodings should be pre-computed:

# convert the processed pickle files (above) into VQ-VAE encodings
python scripts/dump.py pdb_data/processed_chains pdb_data/processed_chains_encoding pkl
# or, if you already have a dataset of raw PDB files at hand
python scripts/dump.py pdb_data/raw_pdb pdb_data/raw_pdb_encoding pdb
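To sanity-check the dump step, a minimal sketch that loads one dumped file and reports its contents (the .pkl extension and the pickled-record layout are assumptions, not part of the documented interface):

import pickle
from pathlib import Path

# load one dumped encoding file and print what it contains
path = next(Path("pdb_data/processed_chains_encoding").glob("*.pkl"))
with open(path, "rb") as f:
    record = pickle.load(f)
print(path.name, type(record).__name__)
if isinstance(record, dict):
    for key, value in record.items():
        print(key, getattr(value, "shape", type(value).__name__))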

Training

sbatch train.sh experiment=jlm paths.data_dir=pdb_data/raw_pdb_encoding data.batch_size=16 logger=csv 
sbatch train.sh experiment=clm paths.data_dir=pdb_data/raw_pdb_encoding data.batch_size=16 logger=csv 
sbatch train.sh experiment=mdlm paths.data_dir=pdb_data/raw_pdb_encoding data.batch_size=16 logger=csv 

Inference

Sample from Hugging Face-based models (T5, GPT-2), for example:

python slm/sample_hf.py ckpt_path=logs/ConditionalLanguageModeling/runs/dev_exp_name/checkpoints/epoch_999.ckpt inference.input=data/targets/bpti inference.output=outputs/inference inference.batch_size=32 inference.n_samples=100
# or, equivalently, specify the target by name
python slm/sample_hf.py ckpt_path=logs/ConditionalLanguageModeling/runs/dev_exp_name/checkpoints/epoch_999.ckpt inference.target=bpti inference.output=outputs/inference inference.batch_size=32 inference.n_samples=100

Sample from ESMDiff (masked diffusion fine-tuned ESM3):

python slm/sample_esmdiff.py --input data/targets/bpti --output outputs/inference_esmdiff --num_steps 25 --num_samples 100 --ckpt logs/MaskedDiffusionLanguageModeling/runs/dev_exp_name/checkpoints/epoch_999.ckpt --mode ddpm
# inpainting: the residue positions listed in --mask_ids are masked and resampled
python slm/sample_esmdiff.py --input data/targets/bpti --output outputs/inference_esmdiff --num_steps 25 --num_samples 100 --mask_ids 1,2,3,4,5
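After sampling, a minimal sketch to enumerate the generated conformations (it assumes the sampler writes one structure file per sample under the output directory; the .pdb extension is an assumption):

from pathlib import Path

# count the sampled structure files written by an inference run
out_dir = Path("outputs/inference_esmdiff")
samples = sorted(out_dir.rglob("*.pdb"))  # assumption: one .pdb file per sample
print(f"found {len(samples)} sampled structures under {out_dir}")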

Evaluation of samples

See ./analysis.

LICENSE

The source code and model may be used for non-commercial purposes only. For any parts related to ESM3, please strictly follow the EvolutionaryScale Community License Agreement: https://www.evolutionaryscale.ai/policies/community-license-agreement.

Remarks

This repo is still a work in progress; more features will be added in the near future.
