Objax is a machine learning framework that provides an object-oriented layer for JAX.
Objax
This is not an officially supported Google product.
Objax is an open source machine learning framework that accelerates research and learning thanks to a minimalist object-oriented design and a readable code base. Its name is a contraction of Object and JAX, a popular high-performance framework. Objax is designed by researchers for researchers with a focus on simplicity and understandability. Its users should be able to easily read, understand, extend, and modify it to fit their needs.
This is the developer repository of Objax; it contains very little user documentation. For the full documentation, go to objax.readthedocs.io.
You can find additional READMEs in the subdirectories of this project.
User installation guide
You can install Objax using pip as follows:
pip install --upgrade objax
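To confirm that the install succeeded without importing the package, you can query the installed distribution's metadata. This is a minimal sketch using the standard library's `importlib.metadata` (Python 3.8+); `installed_version` is a hypothetical helper name, not part of Objax.

```python
from importlib import metadata
from typing import Optional


def installed_version(dist_name: str) -> Optional[str]:
    """Return the installed version of a distribution, or None if it is missing."""
    try:
        return metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        return None


if __name__ == "__main__":
    # Prints the installed objax version, or None if the pip install did not succeed.
    print("objax", installed_version("objax"))
```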
Objax supports GPUs but assumes that you already have a version of CUDA installed. For GPU support, take these extra steps:
# Set this to your installed CUDA version
CUDA_VERSION=11.0
# Install the jaxlib wheel built for that CUDA version (11.0 -> cuda110), for the
# already-installed jaxlib version, and for your Python version (e.g. 3.8 -> cp38)
pip install --upgrade https://storage.googleapis.com/jax-releases/cuda`echo $CUDA_VERSION | sed s:\\\.::g`/jaxlib-`python3 -c 'import jaxlib; print(jaxlib.__version__)'`-`python3 -V | sed -En "s/Python ([0-9]*)\.([0-9]*).*/cp\1\2/p"`-none-manylinux2010_x86_64.whl
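The one-line command assembles the wheel URL from three pieces: the CUDA version with its dots removed (11.0 becomes cuda110), the version of the jaxlib that is already installed, and a CPython tag derived from `python3 -V` (3.8 becomes cp38). The sketch below rebuilds the same string in Python so you can inspect it before installing. `jaxlib_wheel_url` and `python_tag` are hypothetical helpers, and the URL layout is copied from the command above rather than from current JAX release documentation.

```python
import sys

# URL layout copied from the shell command above (assumption: the bucket still uses it).
BASE_URL = "https://storage.googleapis.com/jax-releases"


def python_tag() -> str:
    """CPython tag for the running interpreter, e.g. 'cp38' for Python 3.8."""
    return f"cp{sys.version_info.major}{sys.version_info.minor}"


def jaxlib_wheel_url(cuda_version: str, jaxlib_version: str, py_tag: str) -> str:
    """Mirror of the shell pipeline: strip dots from the CUDA version, then fill in the tags."""
    cuda_tag = cuda_version.replace(".", "")  # '11.0' -> '110', like the sed call
    return (f"{BASE_URL}/cuda{cuda_tag}/"
            f"jaxlib-{jaxlib_version}-{py_tag}-none-manylinux2010_x86_64.whl")


if __name__ == "__main__":
    try:
        import jaxlib  # the shell command also reads the installed jaxlib version
        print(jaxlib_wheel_url("11.0", jaxlib.__version__, python_tag()))
    except ImportError:
        print("jaxlib is not installed")
```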
Useful environment configurations
Here is a useful option:
# Prevent JAX from taking the whole GPU memory
# (useful if you want to run several programs on a single GPU)
export XLA_PYTHON_CLIENT_PREALLOCATE=false
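If you launch your program from a Python script or notebook rather than a shell, you can set the same variable with `os.environ`. A minimal sketch: JAX reads the variable when it initializes its GPU backend, so the safe choice is to set it before `jax` or `objax` is imported.

```python
import os

# Disable XLA's up-front GPU memory preallocation for this process.
# Set it before importing jax/objax so the setting is in place when the GPU backend starts.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

# import objax  # import only after the variable has been set
```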
Testing your installation
You can test your installation by running the code below:
import jax
import objax
print(f'Number of GPUs {jax.device_count()}')
x = objax.random.normal((100, 4))
m = objax.nn.Linear(4, 5)
print('Matrix product shape', m(x).shape) # (100, 5)
x = objax.random.normal((100, 3, 32, 32))
m = objax.nn.Conv2D(3, 4, k=3)
print('Conv2D return shape', m(x).shape) # (100, 4, 32, 32)
If you get errors running this with CUDA, your CUDA or cuDNN installation most likely has issues.
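The expected shapes in the comments of the test script follow from simple arithmetic: Linear(4, 5) replaces the last dimension 4 with 5, and Conv2D(3, 4, k=3) maps 3 input channels to 4 output channels in an NCHW layout while keeping the 32x32 spatial size. The sketch below reproduces that arithmetic without JAX. It assumes stride 1 and "SAME" padding, which is what the (100, 4, 32, 32) comment implies; `linear_output_shape` and `conv2d_output_shape` are hypothetical helpers, not Objax functions.

```python
import math
from typing import Tuple


def linear_output_shape(x_shape: Tuple[int, ...], nout: int) -> Tuple[int, ...]:
    """A dense layer replaces the last (feature) dimension with nout."""
    return tuple(x_shape[:-1]) + (nout,)


def conv2d_output_shape(x_shape: Tuple[int, int, int, int], nout: int, k: int,
                        stride: int = 1, padding: str = "SAME") -> Tuple[int, int, int, int]:
    """Output shape of a 2D convolution on an NCHW input (dilation 1)."""
    n, _, h, w = x_shape
    if padding == "SAME":
        # SAME pads the input so that out = ceil(in / stride).
        oh, ow = math.ceil(h / stride), math.ceil(w / stride)
    elif padding == "VALID":
        # VALID uses no padding: out = floor((in - k) / stride) + 1.
        oh, ow = (h - k) // stride + 1, (w - k) // stride + 1
    else:
        raise ValueError(f"unsupported padding: {padding}")
    return (n, nout, oh, ow)


print(linear_output_shape((100, 4), 5))             # (100, 5)
print(conv2d_output_shape((100, 3, 32, 32), 4, 3))  # (100, 4, 32, 32)
```

With "VALID" padding the same layer would shrink each spatial side by k - 1, giving (100, 4, 30, 30).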
Running code examples
Clone the code repository:
git clone https://github.com/google/objax.git
cd objax/examples
Developer installation guide
We recommend using virtualenv if you want to develop in Objax. The setup for Ubuntu or a similar Linux distribution is as follows:
# Install virtualenv and other system dependencies if you haven't done so already
sudo apt install python3-dev python3-virtualenv python3-tk imagemagick virtualenv pandoc
# Create a virtual environment (for example ~/jax3; you can use any name here)
virtualenv -p python3 --system-site-packages ~/jax3
# Start the virtual environment
. ~/jax3/bin/activate
# Clone the objax git repository if you haven't already.
git clone https://github.com/google/objax.git
cd objax
# Install python dependencies.
pip install --upgrade -r requirements.txt
pip install --upgrade -r docs/requirements.txt
pip install --upgrade -r examples/requirements.txt
# If you have CUDA installed, specify your installed CUDA version.
CUDA_VERSION=11.0
pip install --upgrade https://storage.googleapis.com/jax-releases/cuda`echo $CUDA_VERSION | sed s:\\\.::g`/jaxlib-`python3 -c 'import jaxlib; print(jaxlib.__version__)'`-`python3 -V | sed -En "s/Python ([0-9]*)\.([0-9]*).*/cp\1\2/p"`-none-manylinux2010_x86_64.whl
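To confirm that the packages above went into the virtual environment rather than the system Python, you can check from inside `python3`. A minimal sketch using only the standard library: inside a venv or a recent virtualenv, `sys.prefix` differs from `sys.base_prefix`, while older virtualenv releases set `sys.real_prefix` instead. `in_virtualenv` is a hypothetical helper name.

```python
import sys


def in_virtualenv() -> bool:
    """True when running inside a venv/virtualenv rather than the base interpreter."""
    # Older virtualenv releases (before 20.0) set sys.real_prefix.
    if hasattr(sys, "real_prefix"):
        return True
    # venv and virtualenv >= 20 make sys.prefix differ from sys.base_prefix.
    return sys.prefix != getattr(sys, "base_prefix", sys.prefix)


if __name__ == "__main__":
    print("virtualenv active:", in_virtualenv(), "prefix:", sys.prefix)
```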
The current folder (the root of the objax repository) must be in PYTHONPATH. You can do this with the following command:
export PYTHONPATH=$PYTHONPATH:.
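With the repository root on PYTHONPATH, `import objax` should resolve to your checkout rather than to a copy installed with pip. A minimal check that locates the module without importing it, using the standard library's `importlib.util.find_spec`; `module_origin` is a hypothetical helper name.

```python
import importlib.util
from typing import Optional


def module_origin(name: str) -> Optional[str]:
    """File a top-level module would be loaded from, or None if it cannot be found."""
    spec = importlib.util.find_spec(name)
    return spec.origin if spec is not None else None


if __name__ == "__main__":
    # Run from the repository root: expect a path ending in objax/__init__.py inside your checkout.
    print(module_origin("objax"))
```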
Running tests
You can run all tests as follows:
./tests/run.sh
Download files
Source Distribution
File details
Details for the file objax-1.0.2.tar.gz.
File metadata
- Download URL: objax-1.0.2.tar.gz
- Upload date:
- Size: 35.7 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/3.2.0 pkginfo/1.5.0.1 requests/2.23.0 setuptools/49.6.0 requests-toolbelt/0.9.1 tqdm/4.48.2 CPython/3.8.5
File hashes
Algorithm | Hash digest
---|---
SHA256 | 80621e36425e793214b8bccc104704f39bc6018793e003ffabf12c22ae840028
MD5 | 36fe3d650e7b18b707d481f21ad2ece4
BLAKE2b-256 | 88bdb105855a6093bb0c05d42723ce5680d6149ed36f407f8f444917cf3e458b
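To check a downloaded objax-1.0.2.tar.gz against the digests in the table, you can hash it locally with Python's `hashlib`. A minimal sketch: `file_digest_hex` is a hypothetical helper, and "BLAKE2b-256" is taken to mean BLAKE2b with a 32-byte digest.

```python
import hashlib
import os


def file_digest_hex(path: str, algorithm: str = "sha256") -> str:
    """Hex digest of a file, read in chunks so large archives need little memory."""
    if algorithm.lower() == "blake2b-256":
        h = hashlib.blake2b(digest_size=32)  # BLAKE2b truncated to 256 bits
    else:
        h = hashlib.new(algorithm)  # e.g. "sha256" or "md5"
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(1 << 16), b""):
            h.update(chunk)
    return h.hexdigest()


EXPECTED_SHA256 = "80621e36425e793214b8bccc104704f39bc6018793e003ffabf12c22ae840028"

if __name__ == "__main__":
    path = "objax-1.0.2.tar.gz"  # adjust to where you saved the download
    if os.path.exists(path):
        ok = file_digest_hex(path) == EXPECTED_SHA256
        print("SHA256 matches" if ok else "SHA256 MISMATCH")
    else:
        print(f"{path} not found")
```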