Efficient Evolutionary Scale Modeling: Efficient and simplified implementation of protein language model for inference and training.

ESM-Efficient

Efficient implementation of the ESM family of models.

Installation

conda install pytorch cudatoolkit=12.5 -c pytorch -c nvidia
pip install flash-attn --no-build-isolation
pip install esm-efficient

Usage

Predict the log probabilities of a sequence of tokens using the model.

import torch
from esme import ESM2
from esme.alphabet import tokenize

# load the model
model = ESM2.from_pretrained("{model}.safetensors", device=0)

tokens = tokenize(['MEEPQSDPSVEPPLSQESTFSLDLWK', 'MADQLTEEQIAEFKEAFSLFDKDG'])
tokens = tokens.to(0)

# predict logits
logits = model(tokens)
# logits.shape = (2, seq_len, embed_size)

# predict log probabilities
log_probs = model.predict_log_prob(tokens)
# log_probs.shape = (2, seq_len, embed_size)

from esme.alphabet import tokenize_unpad
# tokenize without padding (more efficient: avoids computing on padding tokens)
tokens, indices, cu_lens, max_len = tokenize_unpad(['MEEPQSDPSVEPPLSQETFSDLWK', 'MADQLTEEQIAEFKEAFSLFDKDG'])
tokens = tokens.to(0)
cu_lens = cu_lens.to(0)
log_probs = model.predict_log_prob(tokens, (cu_lens, max_len))
# log_probs.shape = (seq_len_protein1 + seq_len_protein2, embed_size)
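The unpadded layout concatenates all sequences into one flat tensor and tracks sequence boundaries with cumulative lengths. The following is a minimal pure-Python sketch of that bookkeeping, assuming `cu_lens` holds cumulative token offsets starting at 0 (the convention used by flash-attn's variable-length kernels). The helpers `cumulative_lengths` and `split_flat` are hypothetical and not part of esme, and the lengths here are raw residue counts, whereas a real tokenizer may add special tokens.

```python
from itertools import accumulate


def cumulative_lengths(lengths):
    """Return ([0, l1, l1 + l2, ...], max_len) for per-sequence token counts."""
    cu_lens = [0, *accumulate(lengths)]
    return cu_lens, max(lengths)


def split_flat(flat, cu_lens):
    """Slice a flat, concatenated list back into one chunk per sequence."""
    return [flat[start:end] for start, end in zip(cu_lens, cu_lens[1:])]


seqs = ['MEEPQSDPSVEPPLSQETFSDLWK', 'MADQLTEEQIAEFKEAFSLFDKDG']

# Assumption: one token per residue, no special tokens.
lengths = [len(s) for s in seqs]
cu_lens, max_len = cumulative_lengths(lengths)

# Stand-in for a flat per-token output such as log_probs.
flat = [ch for s in seqs for ch in s]
chunks = split_flat(flat, cu_lens)
```

The same slicing pattern could be applied to a flat per-token output to recover one block per protein, provided the offsets match those returned by the tokenizer.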

Predict the effect of variants:

from esme.variant import predict_mask_margin

seq = 'MEEPQSDPSVEPPLSQETFSDLWK'
df = predict_mask_margin(model, seq)
# ... pd.DataFrame({
# ...    'variant': ['M1A', 'M1C', ..., 'P16Y'],
# ...    'score': [-0.1, -0.2, ..., -0.3]
# ... }).set_index('variant')
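Variant identifiers such as 'M1A' follow the pattern reference residue, 1-based position, alternate residue. The sketch below enumerates and parses such identifiers, assuming the 20 standard amino acids in alphabetical order. `AMINO_ACIDS`, `single_substitutions` and `parse_variant` are hypothetical names for illustration, not part of esme, and the ordering used by `predict_mask_margin` may differ.

```python
AMINO_ACIDS = 'ACDEFGHIKLMNPQRSTVWY'  # assumption: 20 standard residues


def single_substitutions(seq):
    """List every single-residue substitution of seq as '<ref><pos><alt>'."""
    return [
        f'{ref}{pos}{alt}'
        for pos, ref in enumerate(seq, start=1)  # positions are 1-based
        for alt in AMINO_ACIDS
        if alt != ref
    ]


def parse_variant(variant):
    """Split an identifier like 'M1A' into (ref, position, alt)."""
    return variant[0], int(variant[1:-1]), variant[-1]


seq = 'MEEPQSDPSVEPPLSQETFSDLWK'
variants = single_substitutions(seq)
ref, pos, alt = parse_variant(variants[0])
```

Parsing the index this way makes it straightforward to group scores by position or by alternate residue.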

Fine-tune the model with LoRA adapters:

# only the added adapter weights are trained by default
model.add_lora(rank=16, layers=('query', 'key', 'value'), adapter_names=['adapter1', 'adapter2'])

# mark only the LoRA weights as trainable (called by default when adding LoRA)
model.mark_only_lora_as_trainable()

# save the model with the lora weights
model.save_lora('<path>.safetensors', adapter_names=['adapter1'])

# load the model with the lora weights
model.load_lora('<path>.safetensors')
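LoRA keeps the pretrained weight matrix frozen and trains two low-rank factors in its place. A rough sketch of how the trainable parameter count compares for a single projection, assuming a square layer of width 320 (an illustrative value, not read from the model) and ignoring biases. The function names are hypothetical.

```python
def full_param_count(d_in, d_out):
    """Parameters in a dense (d_out, d_in) weight matrix."""
    return d_in * d_out


def lora_param_count(d_in, d_out, rank):
    """Parameters in the two LoRA factors: A is (rank, d_in), B is (d_out, rank)."""
    return rank * d_in + d_out * rank


d = 320    # assumed projection width, for illustration only
rank = 16  # same rank as in the add_lora call above

full = full_param_count(d, d)
lora = lora_param_count(d, d, rank)
ratio = lora / full

print(f'full: {full}, lora: {lora}, ratio: {ratio:.2f}')
```

Under these assumptions the adapter trains about a tenth of the parameters of the projection it wraps. The ratio shrinks further for wider layers at the same rank.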

Quantization of the model:

model = ESM2.from_pretrained('8M.safetensors', quantization='4bit', device=0)
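As a back-of-the-envelope illustration of why 4-bit quantization helps, the sketch below estimates weight storage at different precisions. It assumes a nominal 8 million parameters and ignores quantization metadata (scales, zero points) and any layers kept in higher precision, so actual savings will be smaller. `weight_bytes` is a hypothetical helper.

```python
def weight_bytes(n_params, bits_per_param):
    """Approximate storage for n_params weights at the given bit width."""
    return n_params * bits_per_param // 8


n_params = 8_000_000  # assumption: nominal size of the 8M model

fp16_bytes = weight_bytes(n_params, 16)
int4_bytes = weight_bytes(n_params, 4)

print(f'fp16: {fp16_bytes / 1e6:.0f} MB, 4-bit: {int4_bytes / 1e6:.0f} MB')
```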

Activation checkpointing of each transformer layer:

model = ESM2.from_pretrained('8M.safetensors', checkpointing=True)

Model Weights

The model weights can be downloaded from Hugging Face: https://huggingface.co/mhcelik/esm-efficient/tree/main

Evaluation

To perform the evaluation reported in the paper, run the following command:

snakemake -n --use-conda

The workflow downloads the data, trains the models, and evaluates them. The -n flag performs a dry run that only lists the planned jobs; remove it to execute the workflow. The results are saved in the results directory. See workflow/Snakefile for more details.

To generate a specific figure from the paper, run the following command:

snakemake reports/paper_figures/figure-2.pdf -n --use-conda

Citation



          
