Efficient Evolutionary Scale Modeling: an efficient and simplified implementation of protein language models for inference and training.
ESM-Efficient
Efficient implementation of the ESM family of models.
Installation
conda install pytorch cudatoolkit=12.5 -c pytorch -c nvidia
pip install flash-attn --no-build-isolation
pip install esm-efficient
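As a minimal sanity check (not part of the package docs), import the package and confirm that a CUDA device is visible before running the examples below:
import torch
from esme import ESM2
print(torch.cuda.is_available())  # should print True if the GPU setup is correct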
Usage
Predict the log probabilities of a sequence of tokens using the model.
import torch
from esme import ESM2
from esme.alphabet import tokenize
# load the model from a safetensors weights file
model = ESM2.from_pretrained("{model}.safetensors", device=0)
tokens = tokenize(['MEEPQSDPSVEPPLSQESTFSLDLWK', 'MADQLTEEQIAEFKEAFSLFDKDG'])
tokens = tokens.to(0)
# predict logits
logits = model(tokens)
# logits.shape = (2, seq_len, vocab_size)
# predict log probabilities
log_probs = model.predict_log_prob(tokens)
# log_probs.shape = (2, seq_len, vocab_size)
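As a small usage sketch (not from the package docs), the returned log probabilities can be reduced with plain PyTorch; mapping token indices back to residues would go through esme.alphabet and is not shown here:
# assuming log_probs has shape (batch, seq_len, vocab_size)
best_token = log_probs.argmax(dim=-1)     # most likely token id per position
best_logp = log_probs.max(dim=-1).values  # its log probability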
from esme.alphabet import tokenize_unpad
# tokenize without padding (more efficient: avoids computing on padding tokens)
tokens, indices, cu_lens, max_len = tokenize_unpad(['MEEPQSDPSVEPPLSQETFSDLWK', 'MADQLTEEQIAEFKEAFSLFDKDG'])
tokens = tokens.to(0)
cu_lens = cu_lens.to(0)
log_probs = model.predict_log_prob(tokens, (cu_lens, max_len))
# log_probs.shape = (seq_len_protein1 + seq_len_protein2, vocab_size)
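A sketch for recovering per-protein outputs from the unpadded result, assuming cu_lens is the usual cumulative-length offset vector starting at 0 (flash-attention convention):
per_protein = [log_probs[cu_lens[i]:cu_lens[i + 1]] for i in range(len(cu_lens) - 1)]
# per_protein[0].shape = (seq_len_protein1, vocab_size)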
Predict the effect of variants:
from esme.variant import predict_mask_margin
seq = 'MEEPQSDPSVEPPLSQETFSDLWK'
df = predict_mask_margin(model, seq)
# ... pd.DataFrame({
# ... 'variant': ['M1A', 'M1C', ..., 'P16Y'],
# ... 'score': [-0.1, -0.2, ..., -0.3]
# ... }).set_index('variant')
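Since the result is a plain pandas DataFrame indexed by variant, individual substitutions can be looked up or ranked directly; under the usual mask-margin convention, lower (more negative) scores suggest less tolerated substitutions:
score_m1a = df.loc['M1A', 'score']                  # score of a single substitution
least_tolerated = df.sort_values('score').head(10)  # lowest-scoring variants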
Fine-tune the model with LoRA adapters:
# only the newly added LoRA weights will be trainable by default
model.add_lora(rank=16, layers=('query', 'key', 'value'), adapter_names=['adapter1', 'adapter2'])
# mark only the LoRA weights as trainable (called by default when adding LoRA)
model.mark_only_lora_as_trainable()
# save the model with the lora weights
model.save_lora('<path>.safetensors', adapter_names=['adapter1'])
# load the model with the lora weights
model.load_lora('<path>.safetensors')
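A minimal fine-tuning sketch, not taken from the package docs: it assumes a token batch built with tokenize as above and uses a plain cross-entropy loss over all positions, whereas proper masked-LM fine-tuning would mask a subset of tokens, ignore padding, and compute the loss only at masked positions. Only the LoRA parameters left trainable by mark_only_lora_as_trainable are updated.
import torch
from esme.alphabet import tokenize

tokens = tokenize(['MEEPQSDPSVEPPLSQETFSDLWK', 'MADQLTEEQIAEFKEAFSLFDKDG']).to(0)

# optimize only the parameters that are still trainable (the LoRA weights)
trainable = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.AdamW(trainable, lr=1e-4)

for _ in range(10):  # toy number of steps
    logits = model(tokens)  # (batch, seq_len, vocab_size)
    loss = torch.nn.functional.cross_entropy(
        logits.reshape(-1, logits.size(-1)), tokens.reshape(-1)
    )
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()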
Quantization of the model:
model = ESM2.from_pretrained('8M.safetensors', quantization='4bit', device=0)
Activation checkpointing of each transformer layer:
model = ESM2.from_pretrained('8M.safetensors', checkpointing=True)
Model Weights
The model weights can be downloaded from HuggingFace: https://huggingface.co/mhcelik/esm-efficient/tree/main
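The weights can also be fetched programmatically with huggingface_hub; the filename below ('8M.safetensors') matches the examples above but is an assumption about the repository layout, so check the repo page for the exact file names:
from huggingface_hub import hf_hub_download
from esme import ESM2

# filename is assumed from the examples above
weights = hf_hub_download(repo_id='mhcelik/esm-efficient', filename='8M.safetensors')
model = ESM2.from_pretrained(weights, device=0)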
Evaluation
To perform the evaluation reported in the paper, run the following command:
snakemake -n --use-conda
This will download the data, train the models, and evaluate them; the results will be saved in the results directory. Note that the -n flag performs a dry run that only lists the jobs to be executed; remove it to actually run the workflow.
See the workflow/Snakefile for more details.
To generate a specific figure from the paper, run the following command:
snakemake reports/paper_figures/figure-2.pdf -n --use-conda
Citation