Efficient Evolutionary Scale Modeling: an efficient and simplified implementation of protein language models for inference and training.


ESM-Efficient

An efficient implementation of the ESM family of protein language models.

Installation

conda install pytorch cudatoolkit=12.5 -c pytorch -c nvidia
pip install flash-attn --no-build-isolation
pip install esm-efficient

Usage

Predict the log probabilities of a sequence of tokens using the model.

import torch
from esme import ESME
from esme.alphabet import tokenize

# load the model from a checkpoint
model = ESME.load_from_checkpoint("{model}.safetensors")

tokens = tokenize(['MEEPQSDPSVEPPLSQETFSDLWK', 'MADQLTEEQIAEFKEAFSLFDKDG'])

# predict logits
logits = model(tokens)
# logits.shape = (2, seq_len, vocab_size)

# predict log probabilities
log_probs = model.predict_log_prob(tokens, pad_output=True)
# log_probs.shape = (2, seq_len, vocab_size)
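
The per-token log probabilities can be reduced to a per-sequence log-likelihood, e.g. for ranking sequences. A minimal sketch, assuming tokens and log_probs are aligned position by position and that the padding token index is 1 (check esme.alphabet for the actual value):

# log probability of each observed token at its position
token_log_probs = log_probs.gather(-1, tokens.unsqueeze(-1)).squeeze(-1)
# ignore padding positions; a padding index of 1 is an assumption here
mask = tokens != 1
seq_log_likelihood = (token_log_probs * mask).sum(dim=-1)
# seq_log_likelihood.shape = (2,)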

from esme.alphabet import tokenize_unpad
# tokenize without padding (more efficient: avoids computing on padding positions)
tokens, indices, (cu_lens, max_len) = tokenize_unpad(['MEEPQSDPSVEPPLSQETFSDLWK', 'MADQLTEEQIAEFKEAFSLFDKDG'])
log_probs = model.predict_log_prob(tokens, (cu_lens, max_len))
# log_probs.shape = (seq_len_protein1 + seq_len_protein2, vocab_size)
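
Since the unpadded call returns a single packed tensor, the cumulative lengths in cu_lens can be used to split the output back into per-protein chunks. A minimal sketch, assuming cu_lens follows the usual convention of offsets starting at 0, i.e. [0, len1, len1 + len2]:

offsets = cu_lens.tolist()
per_protein = [log_probs[start:end]
               for start, end in zip(offsets[:-1], offsets[1:])]
# per_protein[0].shape = (seq_len_protein1, vocab_size)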

Predict effect of variants:

from esme.variant import predict_mask_margin

seq = 'MEEPQSDPSVEPPLSQETFSDLWK'
df = predict_mask_margin(model, seq)
# ... pd.DataFrame({
# ...    'variant': ['M1A', 'M1C', ..., 'P16Y'],
# ...    'score': [0.1, 0.2, ..., -0.3]
# ... }).set_index('variant')
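
Since the returned DataFrame is indexed by variant, individual substitutions and rankings can be read off directly. A usage sketch based on the layout shown above (the sign convention of the scores is an assumption here):

# score for substituting M at position 1 with A
df.loc['M1A', 'score']
# ten lowest-scoring variants
df.nsmallest(10, 'score')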

Fine-tune the model with lora adapters:

# only the added LoRA weights are trainable by default
model.add_lora(rank=16, layers=('query', 'key', 'value'), adapter_names=['adapter1', 'adapter2'])

# mark only the LoRA weights as trainable (called by default when adding LoRA)
model.mark_only_lora_as_trainable()

# save the model with the lora weights
model.save_lora('<path>.safetensors', adapter_names=['adapter1'])

# load the model with the lora weights
model.load_lora('<path>.safetensors', adapter_names=['adapter1'])
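
With the adapters in place, the model can be fine-tuned with an ordinary PyTorch loop in which only the LoRA parameters receive gradient updates. A minimal sketch; the dataloader and the masked-token loss below are illustrative assumptions, not part of the library:

import torch.nn.functional as F

optimizer = torch.optim.AdamW(
    (p for p in model.parameters() if p.requires_grad), lr=1e-4)

for tokens, labels in train_loader:  # hypothetical dataloader of (tokens, labels)
    optimizer.zero_grad()
    logits = model(tokens)  # (batch, seq_len, vocab_size)
    loss = F.cross_entropy(logits.transpose(1, 2), labels, ignore_index=-100)
    loss.backward()
    optimizer.step()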

Quantization of the model:

model = model.from_pretrained('8M.safetensors', quantization='4bit')
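
Loading the weights in 4-bit precision reduces their memory footprint roughly fourfold relative to 16-bit, at some cost in numerical accuracy; LoRA adapter weights are typically kept in higher precision, so quantization combines naturally with the fine-tuning workflow above.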

Activation checkpointing of each transformer layer:

model = model.from_pretrained('8M.safetensors', checkpointing=True)
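
Checkpointing trades compute for memory: the activations of each transformer layer are discarded during the forward pass and recomputed during backpropagation, lowering peak training memory at the cost of roughly one extra forward pass.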

Model Weights

The model weights can be downloaded from HuggingFace: https://huggingface.co/mhcelik/esm-efficient

Evaluation

To perform the evaluation reported in the paper, run the following command:

snakemake -n --use-conda

The -n flag performs a dry run; remove it to execute the workflow, which downloads the data, trains the models, and evaluates them. The results are saved in the results directory. See the workflow/Snakefile for more details.

To generate a specific figure from the paper, run the following command:

snakemake reports/paper_figures/figure-2.pdf -n --use-conda 

Citation
