
RoRF - Routing on Random Forests

RoRF is a framework for training and serving random forest-based LLM routers. Our experiments show that:

  • Routing between a pair of strong and weak models can reduce costs while maintaining the strong model's performance.
  • Routing between a pair of strong models can reduce costs while outperforming both individual models.

Our core features include:

  • 12 pre-trained routers across 6 model pairs and 2 embedding models (jinaai/jina-embeddings-v3, voyageai/voyage-large-2-instruct) that reduce costs while maintaining or improving performance.
  • Routing quality that outperforms existing routing solutions, including open-source and commercial offerings.

Installation

PyPI

pip install rorf

Source

git clone https://github.com/Not-Diamond/RoRF
cd RoRF
pip install -e .

Quickstart

We adopt RouteLLM's Controller so that users can drop RoRF into their existing routing setups. Our Controller requires a router (available either locally or on the Hugging Face Hub) that routes between model_a (usually the stronger model) and model_b (usually the weaker one). Our release includes 6 model pairs across different models and providers.

from rorf.controller import Controller

router = Controller(
    router="notdiamond/rorf-jina-llama31405b-llama3170b",
    model_a="llama-3.1-405b-instruct",
    model_b="llama-3.1-70b-instruct",
    threshold=0.3,
)

recommended_model = router.route("What is the meaning of life?")
print(f"Recommended model: {recommended_model}")

We also provide a threshold parameter that controls the fraction of calls routed to each model, letting users choose their own cost-vs-performance tradeoff.
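Conceptually, the threshold turns the router's per-prompt score into a binary routing decision. A minimal sketch of that idea (the function below is illustrative, not RoRF's actual internals):

```python
def choose_model(score_a: float, threshold: float, model_a: str, model_b: str) -> str:
    """Route to model_a when the router's score for model_a clears the
    threshold; otherwise fall back to model_b. In this sketch, raising
    the threshold sends fewer calls to the stronger model (lower cost),
    and lowering it sends more (higher quality)."""
    return model_a if score_a >= threshold else model_b

# With threshold=0.3, any prompt scored >= 0.3 goes to the stronger model.
choice = choose_model(0.42, 0.3, "llama-3.1-405b-instruct", "llama-3.1-70b-instruct")
print(choice)  # → llama-3.1-405b-instruct
```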

Threshold calibration

The threshold parameter controls the percentage of calls made to each model, but the right value depends on the distribution of queries you receive, so you should calibrate it on your own data. As an example, we can use the dataset notdiamond/rorf-llama31405b-llama3170b-battles to calibrate a threshold that sends 50% of calls to llama-3.1-405b-instruct.

python -m rorf.calibrate_threshold --calibration-dataset "notdiamond/rorf-llama31405b-llama3170b-battles" --router "notdiamond/rorf-jina-llama31405b-llama3170b" --model-a-pct 0.5 --task generate
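Intuitively, calibrating a threshold for a target traffic split amounts to taking a quantile of router scores over the calibration set. A toy illustration in pure Python (not the rorf.calibrate_threshold implementation):

```python
def calibrate_threshold(scores: list[float], model_a_pct: float) -> float:
    """Pick a threshold such that roughly `model_a_pct` of calibration
    prompts score at or above it, so that fraction of traffic would be
    routed to model_a under a `score >= threshold` rule."""
    ranked = sorted(scores, reverse=True)
    # Index of the last prompt we still want routed to model_a.
    k = max(0, min(len(ranked) - 1, int(len(ranked) * model_a_pct) - 1))
    return ranked[k]

# Hypothetical router scores for 10 calibration prompts.
scores = [0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1, 0.05]
t = calibrate_threshold(scores, 0.5)  # target: ~50% of calls to model_a
print(t)  # → 0.5
```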

Pre-trained routers

We provide 12 pre-trained routers using 2 different embedding models. 6 routers are based on the open-source jinaai/jina-embeddings-v3 embedding model, giving developers a completely free experience. The other 6 are based on the closed-source voyageai/voyage-large-2-instruct embedding model, letting developers use the routers with less local compute.

The notation rorf-<embed>-<model_a>-<model_b> indicates the embedding model <embed> used and the two models <model_a> and <model_b> that it routes between.

jina-embeddings-v3

voyage-large-2-instruct

To use the voyage-large-2-instruct routers, set the environment variable VOYAGE_API_KEY=....

Training RoRF

We include our training framework for RoRF so that users can train custom routers on their own data and model pairs. trainer.py is the entry point for training, and run_trainer.sh provides an example command to train a model router for llama-3.1-405b-instruct vs llama-3.1-70b-instruct on top of Jina AI's embeddings. The key arguments are:

  • --model_a: This is the first LLM in the router, usually the stronger model.
  • --model_b: This is the second LLM in the router, usually the weaker model.
  • --dataset_path: This is the HF dataset you want to use to train the router. The dataset must have the column Input, containing the input prompt, and a <model>/score column for each model, containing the score achieved by <model> (in this case, either model_a or model_b). See our calibration dataset for an example.
  • --eval_dataset: This is the evaluation dataset for evaluating the router after training. The format should be the same as --dataset_path.
  • --embedding_provider: This is the embedding model to use for the router. We have implemented "voyage", "openai", and "jina" as embedding providers. To use the "voyage" or "openai" embedding models, make sure to set the VOYAGE_API_KEY or OPENAI_API_KEY environment variable accordingly.
  • --max_depth: This is the max depth of the random forest estimator. Defaults to 20.
  • --n_estimators: This is the number of trees in the random forest. Defaults to 100.
  • --model_id: This is the name of the model that will be pushed to the Hugging Face Hub.
  • --model_org: This is the name of the Hugging Face organization that the model will be pushed to.
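The expected training-data layout can be sketched as follows. This is a toy illustration of the column naming described above (the prompts and scores are made up), along with one plausible way a binary training label could be derived per row; the label rule is our assumption, not necessarily RoRF's exact recipe:

```python
# Each row pairs a prompt (Input) with the score each candidate model
# achieved on it, stored under "<model>/score" columns.
rows = [
    {"Input": "What is the capital of France?",
     "llama-3.1-405b-instruct/score": 1.0,
     "llama-3.1-70b-instruct/score": 1.0},
    {"Input": "Prove that sqrt(2) is irrational.",
     "llama-3.1-405b-instruct/score": 1.0,
     "llama-3.1-70b-instruct/score": 0.0},
]

# Hypothetical label: 1 if model_a strictly beats model_b on this prompt.
labels = [
    int(r["llama-3.1-405b-instruct/score"] > r["llama-3.1-70b-instruct/score"])
    for r in rows
]
print(labels)  # → [0, 1]
```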
