Spark TensorFlow Distributor
This package helps users do distributed training with TensorFlow on their Spark clusters.
Installation
This package requires Python 3.6+, tensorflow>=2.1.0, and pyspark>=3.0.0 to run.
To install spark-tensorflow-distributor, run:
pip install spark-tensorflow-distributor
The installation does not install PySpark because most users already have it installed. TensorFlow is not installed either, so that users may choose between the regular and CPU-only builds via pip install tensorflow or pip install tensorflow-cpu.
If you do not have PySpark installed, you can install it directly:
pip install "pyspark>=3.0.0"
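To confirm that an existing environment meets the minimum versions stated above, a small check along the following lines may help. This is only a sketch: the helper names (`version_tuple`, `meets_minimum`, `check_environment`) are hypothetical and not part of this package, and the version parsing is deliberately simplistic (it compares only the leading numeric components).

```python
import sys
from typing import Dict, Optional, Tuple

# Minimum versions taken from the installation notes above.
MINIMUMS = {"tensorflow": "2.1.0", "pyspark": "3.0.0"}


def version_tuple(version: str) -> Tuple[int, ...]:
    """Turn '3.0.1' or '2.4.0rc1' into a comparable 3-tuple of integers."""
    parts = []
    for piece in version.split(".")[:3]:
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        parts.append(int(digits) if digits else 0)
    while len(parts) < 3:
        parts.append(0)  # pad so that '2.1' compares equal to '2.1.0'
    return tuple(parts)


def meets_minimum(installed: str, minimum: str) -> bool:
    return version_tuple(installed) >= version_tuple(minimum)


def check_environment() -> Dict[str, Optional[bool]]:
    """Map each requirement to True/False, or None if it is not installed."""
    report: Dict[str, Optional[bool]] = {"python": sys.version_info >= (3, 6)}
    try:
        # importlib.metadata is in the standard library from Python 3.8.
        from importlib.metadata import PackageNotFoundError, version
    except ImportError:
        return report
    for name, minimum in MINIMUMS.items():
        try:
            report[name] = meets_minimum(version(name), minimum)
        except PackageNotFoundError:
            report[name] = None
    return report


if __name__ == "__main__":
    for name, ok in check_environment().items():
        print(name, "->", ok)
```

A CPU-only install is registered under the distribution name tensorflow-cpu, so the check above would report tensorflow as not installed in that case; extending `MINIMUMS` with that name is one way to cover it.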
Note also that many features of this package require Spark custom resource scheduling for GPUs to be set up on your cluster. See the Spark documentation on custom resource scheduling for details.
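As a rough illustration of what that setup involves, the sketch below assembles the standard Spark 3 GPU resource properties into a dictionary and into spark-submit flags. The helper names (`gpu_scheduling_conf`, `as_submit_flags`) and the discovery script path are hypothetical; the right values depend on your cluster manager and hardware, so treat this as a starting point rather than a recommended configuration.

```python
from typing import Dict, List, Optional


def gpu_scheduling_conf(gpus_per_executor: int,
                        gpus_per_task: int = 1,
                        discovery_script: Optional[str] = None) -> Dict[str, str]:
    """Build Spark properties that request GPUs per executor and per task."""
    if gpus_per_executor < 1 or gpus_per_task < 1:
        raise ValueError("GPU amounts must be at least 1")
    if gpus_per_task > gpus_per_executor:
        raise ValueError("a task cannot request more GPUs than its executor has")
    conf = {
        "spark.executor.resource.gpu.amount": str(gpus_per_executor),
        "spark.task.resource.gpu.amount": str(gpus_per_task),
    }
    if discovery_script is not None:
        # Script that reports the GPU addresses visible to an executor.
        conf["spark.executor.resource.gpu.discoveryScript"] = discovery_script
    return conf


def as_submit_flags(conf: Dict[str, str]) -> List[str]:
    """Render the properties as '--conf key=value' flags for spark-submit."""
    return [f"--conf {key}={value}" for key, value in sorted(conf.items())]


if __name__ == "__main__":
    # Hypothetical path; point this at the discovery script on your cluster.
    conf = gpu_scheduling_conf(4, 1, "/opt/spark/scripts/getGpusResources.sh")
    print(" \\\n  ".join(["spark-submit"] + as_submit_flags(conf)))
```

The same dictionary can be applied when building a session, for example by calling `SparkSession.builder.config(key, value)` for each entry.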
Running Tests
For integration tests, first build the master and worker images and then run the test script.
docker-compose build --build-arg PYTHON_INSTALL_VERSION=3.7
./tests/integration/run.sh
For linting, run the following.
./tests/lint.sh
To use the autoformatter, run the following.
yapf --recursive --in-place spark_tensorflow_distributor
Examples
Run the following example code in the pyspark shell:
from spark_tensorflow_distributor import MirroredStrategyRunner

# Adapted from https://www.tensorflow.org/tutorials/distribute/multi_worker_with_keras
def train():
    import tensorflow as tf
    import uuid

    BUFFER_SIZE = 10000
    BATCH_SIZE = 64

    def make_datasets():
        (mnist_images, mnist_labels), _ = \
            tf.keras.datasets.mnist.load_data(path=str(uuid.uuid4())+'mnist.npz')

        dataset = tf.data.Dataset.from_tensor_slices((
            tf.cast(mnist_images[..., tf.newaxis] / 255.0, tf.float32),
            tf.cast(mnist_labels, tf.int64))
        )
        dataset = dataset.repeat().shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
        return dataset

    def build_and_compile_cnn_model():
        model = tf.keras.Sequential([
            tf.keras.layers.Conv2D(32, 3, activation='relu', input_shape=(28, 28, 1)),
            tf.keras.layers.MaxPooling2D(),
            tf.keras.layers.Flatten(),
            tf.keras.layers.Dense(64, activation='relu'),
            tf.keras.layers.Dense(10, activation='softmax'),
        ])
        model.compile(
            loss=tf.keras.losses.sparse_categorical_crossentropy,
            optimizer=tf.keras.optimizers.SGD(learning_rate=0.001),
            metrics=['accuracy'],
        )
        return model

    train_datasets = make_datasets()
    options = tf.data.Options()
    options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.DATA
    train_datasets = train_datasets.with_options(options)
    multi_worker_model = build_and_compile_cnn_model()
    multi_worker_model.fit(x=train_datasets, epochs=3, steps_per_epoch=5)

MirroredStrategyRunner(num_slots=8).run(train)
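To get a feel for how much data the example above processes, the arithmetic can be sketched as follows. This assumes, as the TensorFlow distributed training guides describe for Keras, that the batch size given to the dataset is the global batch size and is split evenly across the training slots; `batch_plan` is a hypothetical helper written for this illustration, not part of the package.

```python
from typing import Dict


def batch_plan(global_batch_size: int, num_slots: int,
               steps_per_epoch: int, epochs: int) -> Dict[str, int]:
    """Summarize how a global batch is divided and how many examples are used."""
    if min(global_batch_size, num_slots, steps_per_epoch, epochs) < 1:
        raise ValueError("all arguments must be positive")
    return {
        # Examples each slot handles per step (assumes an even split).
        "per_slot_batch": global_batch_size // num_slots,
        # Examples left over if the global batch does not divide evenly.
        "remainder": global_batch_size % num_slots,
        "examples_per_epoch": global_batch_size * steps_per_epoch,
        "examples_total": global_batch_size * steps_per_epoch * epochs,
    }


if __name__ == "__main__":
    # Values from the example: BATCH_SIZE=64, num_slots=8,
    # steps_per_epoch=5, epochs=3.
    print(batch_plan(64, 8, 5, 3))
```

With the values from the example, each of the 8 slots would handle 8 examples per step, and the whole run would consume 960 examples, which is why it finishes quickly and is best seen as a smoke test rather than a full training run.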