Spark TensorFlow Distributor

This package helps users do distributed training with TensorFlow on their Spark clusters.

Installation

This package requires Python 3.6+, tensorflow>=2.1.0 and pyspark>=3.0.0 to run. To install spark-tensorflow-distributor, run:

pip install spark-tensorflow-distributor

Installing this package does not install PySpark, because most users already have it. If you do not have PySpark installed, you can install it directly:

pip install "pyspark>=3.0.0"
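The version requirements above can be checked from Python before starting a job. The following is a minimal sketch, not part of this package: the helper names `parse_version`, `meets_minimum`, and `report` are hypothetical, and it assumes Python 3.8+ so that `importlib.metadata` is available. It compares only the leading numeric components of a version string, so a pre-release such as `2.1.0rc1` is treated like `2.1.0`.

```python
import re
import sys
from importlib.metadata import PackageNotFoundError, version  # Python 3.8+
from typing import Dict, Tuple

# Minimum versions stated in the Installation section.
MINIMUMS: Dict[str, str] = {"tensorflow": "2.1.0", "pyspark": "3.0.0"}


def parse_version(text: str) -> Tuple[int, ...]:
    """Return the leading numeric components, e.g. '2.4.0rc1' -> (2, 4, 0)."""
    match = re.match(r"\d+(?:\.\d+)*", text)
    if match is None:
        raise ValueError(f"cannot parse version: {text!r}")
    return tuple(int(part) for part in match.group(0).split("."))


def meets_minimum(installed: str, minimum: str) -> bool:
    """Compare two versions after padding the shorter one with zeros."""
    a, b = parse_version(installed), parse_version(minimum)
    width = max(len(a), len(b))
    a += (0,) * (width - len(a))
    b += (0,) * (width - len(b))
    return a >= b


def report(minimums: Dict[str, str]) -> Dict[str, str]:
    """Map each package name to 'ok', 'too old', or 'not installed'."""
    status = {}
    for name, minimum in minimums.items():
        try:
            installed = version(name)
        except PackageNotFoundError:
            status[name] = "not installed"
            continue
        status[name] = "ok" if meets_minimum(installed, minimum) else "too old"
    return status


if __name__ == "__main__":
    print("python:", "ok" if sys.version_info >= (3, 6) else "too old")
    for name, state in report(MINIMUMS).items():
        print(f"{name}: {state}")
```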

Many features of this package also require Spark custom resource scheduling for GPUs to be set up on your cluster. See the Spark documentation on custom resource scheduling for details.
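As a rough illustration of what GPU resource scheduling involves, the sketch below assembles the Spark 3 configuration keys that request GPUs for executors and tasks and prints them as `--conf` flags. The helper names `gpu_scheduling_conf` and `as_submit_flags` are hypothetical, the discovery script path is a placeholder, and the right values depend on your cluster manager, so treat this as a starting point and check the Spark documentation for your deployment.

```python
from typing import Dict, List


def gpu_scheduling_conf(
    gpus_per_executor: int,
    gpus_per_task: int,
    discovery_script: str,
) -> Dict[str, str]:
    """Build Spark settings that request GPUs for executors and tasks.

    This sketch only handles whole-GPU amounts per task.
    """
    if gpus_per_task < 1:
        raise ValueError("gpus_per_task must be at least 1")
    if gpus_per_executor < gpus_per_task:
        raise ValueError("an executor needs at least as many GPUs as one task")
    return {
        "spark.executor.resource.gpu.amount": str(gpus_per_executor),
        "spark.task.resource.gpu.amount": str(gpus_per_task),
        # Script that reports the GPU addresses visible to an executor.
        "spark.executor.resource.gpu.discoveryScript": discovery_script,
    }


def as_submit_flags(conf: Dict[str, str]) -> List[str]:
    """Turn a settings dict into a flat list of --conf arguments."""
    flags: List[str] = []
    for key, value in conf.items():
        flags += ["--conf", f"{key}={value}"]
    return flags


if __name__ == "__main__":
    # Placeholder path: point this at the discovery script on your cluster.
    conf = gpu_scheduling_conf(1, 1, "/path/to/getGpusResources.sh")
    print(" ".join(as_submit_flags(conf)))
```

The same key/value pairs can be passed to `pyspark` or `spark-submit` on the command line, or set in `spark-defaults.conf`.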

Running Tests

For integration tests, first build the master and worker images and then run the test script.

docker-compose build --build-arg PYTHON_INSTALL_VERSION=3.7
./tests/integration/run.sh

For linting, run the following.

./tests/lint.sh

To use the autoformatter, run the following.

yapf --recursive --in-place spark_tensorflow_distributor

Examples

Run the following example code in the pyspark shell:

from spark_tensorflow_distributor import MirroredStrategyRunner


# Taken from https://www.tensorflow.org/tutorials/distribute/multi_worker_with_keras
def train():
    import tensorflow_datasets as tfds
    import tensorflow as tf
    BUFFER_SIZE = 10000
    BATCH_SIZE = 64

    def make_datasets_unbatched():
        # Scale MNIST pixel values from [0, 255] to [0., 1.]
        def scale(image, label):
            image = tf.cast(image, tf.float32)
            image /= 255
            return image, label
        datasets, info = tfds.load(
            name='mnist',
            with_info=True,
            as_supervised=True,
        )
        return datasets['train'].map(scale).cache().shuffle(BUFFER_SIZE)

    def build_and_compile_cnn_model():
        model = tf.keras.Sequential([
            tf.keras.layers.Conv2D(32, 3, activation='relu', input_shape=(28, 28, 1)),
            tf.keras.layers.MaxPooling2D(),
            tf.keras.layers.Flatten(),
            tf.keras.layers.Dense(64, activation='relu'),
            tf.keras.layers.Dense(10, activation='softmax'),
        ])
        model.compile(
            loss=tf.keras.losses.sparse_categorical_crossentropy,
            optimizer=tf.keras.optimizers.SGD(learning_rate=0.001),
            metrics=['accuracy'],
        )
        return model

    GLOBAL_BATCH_SIZE = 64 * 8
    train_datasets = make_datasets_unbatched().batch(GLOBAL_BATCH_SIZE).repeat()
    options = tf.data.Options()
    options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.DATA
    train_datasets = train_datasets.with_options(options)
    multi_worker_model = build_and_compile_cnn_model()
    multi_worker_model.fit(x=train_datasets, epochs=3, steps_per_epoch=5)
    return tf.config.experimental.list_physical_devices('GPU')

MirroredStrategyRunner(num_slots=4).run(train)
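The example batches the dataset with `GLOBAL_BATCH_SIZE = 64 * 8` and runs on `num_slots=4`. Assuming each slot acts as one training replica and that each global batch is split evenly across replicas (as in the TensorFlow tutorial the example is taken from), the arithmetic can be sketched as below. The helper `per_replica_batch_size` is hypothetical and is not part of this package or TensorFlow; the divisibility check is a sanity check added for this sketch.

```python
def per_replica_batch_size(global_batch_size: int, num_replicas: int) -> int:
    """Return the share of one global batch handled by each replica."""
    if num_replicas <= 0:
        raise ValueError("num_replicas must be positive")
    if global_batch_size % num_replicas != 0:
        # Flag uneven splits so the mismatch is noticed early.
        raise ValueError("global batch size is not divisible by num_replicas")
    return global_batch_size // num_replicas


GLOBAL_BATCH_SIZE = 64 * 8  # value used in the example
NUM_SLOTS = 4               # value passed to MirroredStrategyRunner
EPOCHS = 3
STEPS_PER_EPOCH = 5

batch_per_slot = per_replica_batch_size(GLOBAL_BATCH_SIZE, NUM_SLOTS)
examples_seen = GLOBAL_BATCH_SIZE * STEPS_PER_EPOCH * EPOCHS

print("examples per slot per step:", batch_per_slot)
print("examples processed in total:", examples_seen)
```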
