Spark TensorFlow Distributor

This package helps users do distributed training with TensorFlow on their Spark clusters.

Installation

This package requires Python 3.6+, tensorflow>=2.1.0 and pyspark>=3.0.0 to run. To install spark-tensorflow-distributor, run:

pip install spark-tensorflow-distributor

The installation does not install PySpark because, for most users, PySpark is already installed. TensorFlow is likewise not installed automatically, so that users can choose between the regular and CPU-only builds via pip install tensorflow or pip install tensorflow-cpu. If you do not have PySpark installed, you can install it directly:

pip install "pyspark>=3.0.0"
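Because neither TensorFlow nor PySpark is pulled in automatically, it can help to confirm what is installed before running anything. The following is a minimal sketch, not part of this package: the helper names (`parse_version`, `installed_version`, `check_requirements`) are hypothetical, the minimum versions are the ones quoted above, and it assumes Python 3.8+ for `importlib.metadata`.

```python
from importlib import metadata  # standard library from Python 3.8 onward

# Minimum versions quoted in the requirements above. TensorFlow may be
# installed under either distribution name, so both are tried.
REQUIREMENTS = {
    "tensorflow": (("tensorflow", "tensorflow-cpu"), (2, 1, 0)),
    "pyspark": (("pyspark",), (3, 0, 0)),
}


def parse_version(text):
    """Return the first three numeric components of a version string."""
    parts = []
    for piece in text.split(".")[:3]:
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break  # stop at suffixes such as "rc1"
            digits += ch
        parts.append(int(digits) if digits else 0)
    parts += [0] * (3 - len(parts))  # pad "3.0" to (3, 0, 0)
    return tuple(parts)


def installed_version(names):
    """Return the version of the first installed distribution, or None."""
    for name in names:
        try:
            return metadata.version(name)
        except metadata.PackageNotFoundError:
            continue
    return None


def check_requirements(requirements=REQUIREMENTS):
    """Map each requirement label to a short human-readable status."""
    report = {}
    for label, (names, minimum) in requirements.items():
        found = installed_version(names)
        if found is None:
            report[label] = "not installed"
        elif parse_version(found) >= minimum:
            report[label] = f"{found} (ok)"
        else:
            report[label] = f"{found} (older than required)"
    return report


if __name__ == "__main__":
    for label, status in check_requirements().items():
        print(f"{label}: {status}")
```

The comparison only looks at the first three numeric components, which is enough for the minimums listed here but is not a full version parser.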

Note also that in order to use many features of this package, you must set up Spark custom resource scheduling for GPUs on your cluster. See the Spark docs for this.
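As a rough illustration of what that setup involves, the sketch below collects the Spark 3 GPU resource properties and renders them as `--conf` arguments. The property names are Spark configuration keys; the amounts and the discovery script path are placeholders, and the helper `to_submit_args` is hypothetical. Your cluster manager may need additional settings, so treat the Spark documentation as the authority.

```python
# Spark 3 resource-scheduling properties for GPUs. The values and the
# script path below are placeholders, not recommendations.
GPU_CONF = {
    "spark.executor.resource.gpu.amount": "1",
    "spark.task.resource.gpu.amount": "1",
    "spark.executor.resource.gpu.discoveryScript": "/path/to/getGpusResources.sh",
}


def to_submit_args(conf):
    """Render a config mapping as a flat list of --conf key=value arguments."""
    args = []
    for key, value in sorted(conf.items()):
        args += ["--conf", f"{key}={value}"]
    return args


if __name__ == "__main__":
    # Prints the arguments in a form that can be pasted after `pyspark`
    # or `spark-submit`.
    print(" ".join(to_submit_args(GPU_CONF)))
```

The same key/value pairs can instead be set one at a time with `SparkSession.builder.config(key, value)` when the session is created from Python.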

Running Tests

For integration tests, first build the master and worker images and then run the test script.

docker-compose build --build-arg PYTHON_INSTALL_VERSION=3.7
./tests/integration/run.sh

For linting, run the following.

./tests/lint.sh

To use the autoformatter, run the following.

yapf --recursive --in-place spark_tensorflow_distributor

Examples

Run the following example code in the pyspark shell:

from spark_tensorflow_distributor import MirroredStrategyRunner

# Adapted from https://www.tensorflow.org/tutorials/distribute/multi_worker_with_keras
def train():
    import tensorflow as tf
    import uuid

    BUFFER_SIZE = 10000
    BATCH_SIZE = 64

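    def make_datasets():
        # A random prefix gives every call its own cache file name, which keeps
        # concurrent workers on one host from sharing a download file.
        (mnist_images, mnist_labels), _ = \
            tf.keras.datasets.mnist.load_data(path=str(uuid.uuid4())+'mnist.npz')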

        dataset = tf.data.Dataset.from_tensor_slices((
            tf.cast(mnist_images[..., tf.newaxis] / 255.0, tf.float32),
            tf.cast(mnist_labels, tf.int64))
        )
        dataset = dataset.repeat().shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
        return dataset

    def build_and_compile_cnn_model():
        model = tf.keras.Sequential([
            tf.keras.layers.Conv2D(32, 3, activation='relu', input_shape=(28, 28, 1)),
            tf.keras.layers.MaxPooling2D(),
            tf.keras.layers.Flatten(),
            tf.keras.layers.Dense(64, activation='relu'),
            tf.keras.layers.Dense(10, activation='softmax'),
        ])
        model.compile(
            loss=tf.keras.losses.sparse_categorical_crossentropy,
            optimizer=tf.keras.optimizers.SGD(learning_rate=0.001),
            metrics=['accuracy'],
        )
        return model

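    train_datasets = make_datasets()
    # The data comes from in-memory tensors rather than files, so shard by
    # element (DATA) instead of relying on the default AUTO policy, which
    # tries file-based sharding first.
    options = tf.data.Options()
    options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.DATA
    train_datasets = train_datasets.with_options(options)
    multi_worker_model = build_and_compile_cnn_model()
    multi_worker_model.fit(x=train_datasets, epochs=3, steps_per_epoch=5)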

MirroredStrategyRunner(num_slots=8).run(train)
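The call above assumes a cluster with eight slots available. For trying the same `train` function on a smaller setup, a sketch along the following lines may be useful. It assumes that `MirroredStrategyRunner` also accepts `local_mode` and `use_gpu` keyword arguments, which are not shown in the example above, so check the class docstring in your installed version before relying on them. The helpers `runner_options` and `run_training` are hypothetical names introduced here.

```python
def runner_options(num_slots, on_driver=False, cpu_only=False):
    """Build keyword arguments for MirroredStrategyRunner.

    `local_mode` and `use_gpu` are assumed keyword names; verify them
    against the installed package.
    """
    if num_slots < 1:
        raise ValueError("num_slots must be at least 1")
    return {
        "num_slots": num_slots,
        "local_mode": on_driver,   # assumed: run on the driver, not on executors
        "use_gpu": not cpu_only,   # assumed: False requests CPU-only training
    }


def run_training(train_fn, **options):
    """Create a runner from the options and hand it the training function."""
    # Imported lazily so this module can be loaded without the package.
    from spark_tensorflow_distributor import MirroredStrategyRunner

    return MirroredStrategyRunner(**runner_options(**options)).run(train_fn)
```

With the `train` function defined above, `run_training(train, num_slots=2, on_driver=True, cpu_only=True)` would then attempt a small CPU-only run, provided the assumed keywords match the installed version.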
