
Spark TensorFlow Distributor

This package helps users do distributed training with TensorFlow on their Spark clusters.

Installation

This package requires Python 3.6+, tensorflow>=2.1.0 and pyspark>=3.0.0 to run. To install spark-tensorflow-distributor, run:

pip install spark-tensorflow-distributor

Installing this package does not install PySpark, because most users already have it. If you do not have PySpark installed, install it directly:

pip install "pyspark>=3.0.0"
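
To confirm an environment meets the requirements listed above (Python 3.6+, tensorflow>=2.1.0, pyspark>=3.0.0), a quick check like the following can help. This is a minimal sketch, not part of spark-tensorflow-distributor: the helper names are hypothetical, the version parsing is deliberately simplified, and check_environment relies on importlib.metadata, which is only in the standard library from Python 3.8.

```python
import re
import sys

# Minimum versions stated in the Installation section above.
REQUIREMENTS = {"tensorflow": "2.1.0", "pyspark": "3.0.0"}


def version_tuple(version):
    """Parse the leading numeric components, e.g. '3.0.1.dev0' -> (3, 0, 1).

    Simplification: pre-release, dev, and local tags are ignored, so this is
    not a full PEP 440 comparison (packaging.version handles those cases).
    """
    parts = []
    for piece in version.split("."):
        match = re.match(r"\d+", piece)
        if match is None:
            break
        parts.append(int(match.group()))
    return tuple(parts)


def meets_minimum(installed, minimum):
    """Return True if `installed` is at least `minimum`, padding with zeros so '3.0' == '3.0.0'."""
    a, b = version_tuple(installed), version_tuple(minimum)
    width = max(len(a), len(b))
    a += (0,) * (width - len(a))
    b += (0,) * (width - len(b))
    return a >= b


def check_environment():
    """Map 'python' and each required package to True/False; missing packages count as False."""
    from importlib.metadata import PackageNotFoundError, version  # Python 3.8+

    report = {"python": sys.version_info >= (3, 6)}
    for name, minimum in REQUIREMENTS.items():
        try:
            report[name] = meets_minimum(version(name), minimum)
        except PackageNotFoundError:
            report[name] = False
    return report


if __name__ == "__main__":
    print(check_environment())
```

A package installed under a different distribution name (for example a vendor build of TensorFlow) is reported as False by this sketch, so treat a False result as a prompt to look closer rather than a definitive failure.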

Many features of this package require Spark custom resource scheduling for GPUs to be configured on your cluster. See the "Custom Resource Scheduling and Configuration Overview" section of the Spark configuration documentation for details.
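
The GPU scheduling mentioned above is driven by Spark configuration properties. The sketch below collects a few standard Spark 3.0 resource settings and turns them into --conf arguments, which both pyspark and spark-submit accept; the same keys can also go in spark-defaults.conf. The amounts and the discovery script path are assumptions for illustration (Spark's source tree includes an example script, examples/src/main/scripts/getGpusResources.sh), and conf_args is a hypothetical helper.

```python
import shlex

# Hypothetical values for illustration: one GPU per executor, one GPU per task,
# and a discovery script at an assumed path. Adjust all three for your cluster.
GPU_CONF = {
    "spark.executor.resource.gpu.amount": "1",
    "spark.task.resource.gpu.amount": "1",
    "spark.executor.resource.gpu.discoveryScript": "/opt/spark/getGpusResources.sh",
}


def conf_args(conf):
    """Flatten a config dict into ['--conf', 'key=value', ...], sorted by key."""
    args = []
    for key, value in sorted(conf.items()):
        args += ["--conf", f"{key}={value}"]
    return args


if __name__ == "__main__":
    # Print a command line for launching the PySpark shell with these settings.
    print(" ".join(["pyspark"] + [shlex.quote(a) for a in conf_args(GPU_CONF)]))
```

On a standalone cluster, each worker must also advertise its GPUs, via spark.worker.resource.gpu.amount and spark.worker.resource.gpu.discoveryScript.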

Running Tests

For integration tests, first build the master and worker images and then run the test script.

docker-compose build --build-arg PYTHON_INSTALL_VERSION=3.7
./tests/integration/run.sh

For linting, run the following.

./tests/lint.sh

To use the autoformatter, run the following.

yapf --recursive --in-place spark_tensorflow_distributor

Examples

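Run the following example code in the PySpark shell. Because train() imports tensorflow_datasets, the tensorflow-datasets package must be installed on every node that runs it: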

from spark_tensorflow_distributor import MirroredStrategyRunner


# Taken from https://www.tensorflow.org/tutorials/distribute/multi_worker_with_keras
def train():
    import tensorflow_datasets as tfds
    import tensorflow as tf
    BUFFER_SIZE = 10000
    BATCH_SIZE = 64

    def make_datasets_unbatched():
        # Scale MNIST pixel values from [0, 255] to [0., 1.]
        def scale(image, label):
            image = tf.cast(image, tf.float32)
            image /= 255
            return image, label
        datasets, info = tfds.load(
            name='mnist',
            with_info=True,
            as_supervised=True,
        )
        return datasets['train'].map(scale).cache().shuffle(BUFFER_SIZE)

    def build_and_compile_cnn_model():
        model = tf.keras.Sequential([
            tf.keras.layers.Conv2D(32, 3, activation='relu', input_shape=(28, 28, 1)),
            tf.keras.layers.MaxPooling2D(),
            tf.keras.layers.Flatten(),
            tf.keras.layers.Dense(64, activation='relu'),
            tf.keras.layers.Dense(10, activation='softmax'),
        ])
        model.compile(
            loss=tf.keras.losses.sparse_categorical_crossentropy,
            optimizer=tf.keras.optimizers.SGD(learning_rate=0.001),
            metrics=['accuracy'],
        )
        return model

    # Global batch size across all replicas: BATCH_SIZE * 8 = 512.
    GLOBAL_BATCH_SIZE = BATCH_SIZE * 8
    train_datasets = make_datasets_unbatched().batch(GLOBAL_BATCH_SIZE).repeat()
    options = tf.data.Options()
    options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.DATA
    train_datasets = train_datasets.with_options(options)
    multi_worker_model = build_and_compile_cnn_model()
    multi_worker_model.fit(x=train_datasets, epochs=3, steps_per_epoch=5)
    # Return the GPUs visible to this worker.
    return tf.config.list_physical_devices('GPU')


# Train on 4 slots in total across the Spark cluster (GPUs by default).
MirroredStrategyRunner(num_slots=4).run(train)
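
A note on the batch size in this example: with tf.distribute data parallelism, the batch size used on the dataset is the global batch, which is split across all replicas in sync. The arithmetic below assumes each of the num_slots=4 slots runs one replica; per_replica_batch_size is a hypothetical helper, not part of the package.

```python
def per_replica_batch_size(global_batch_size, num_replicas):
    """Split a global batch evenly across replicas.

    This sketch requires even division to keep the arithmetic simple.
    """
    if num_replicas <= 0:
        raise ValueError("num_replicas must be positive")
    if global_batch_size % num_replicas:
        raise ValueError(
            f"global batch {global_batch_size} does not split evenly over {num_replicas} replicas"
        )
    return global_batch_size // num_replicas


GLOBAL_BATCH_SIZE = 64 * 8  # 512, as in the example above

# Assumes one replica per slot, so num_slots=4 gives 4 replicas.
print(per_replica_batch_size(GLOBAL_BATCH_SIZE, num_replicas=4))  # → 128
```

If you want each replica to keep a batch of 64, scale the global batch with the replica count instead (for 4 replicas, 64 * 4 = 256).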

Download files

Source Distribution

spark_tensorflow_distributor-0.0.3.tar.gz (8.4 kB)

Built Distribution

spark_tensorflow_distributor-0.0.3-py3-none-any.whl (8.8 kB)

File details

Details for the file spark_tensorflow_distributor-0.0.3.tar.gz.

File metadata

  • Download URL: spark_tensorflow_distributor-0.0.3.tar.gz
  • Upload date:
  • Size: 8.4 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.1.1 pkginfo/1.5.0.1 requests/2.22.0 setuptools/46.1.3 requests-toolbelt/0.9.1 tqdm/4.43.0 CPython/3.7.6

File hashes

Hashes for spark_tensorflow_distributor-0.0.3.tar.gz
Algorithm     Hash digest
SHA256        50434b9705f3e52817dda64d9f648869ef8f602f55c6d8ba083899e280d44368
MD5           8432ac961e6d98b96c7836568487783b
BLAKE2b-256   fc22fe3ec5196b2d9a2c51841844266bfd866dbe89f2afdf723de8f94447807b

See the pip documentation on hash-checking mode for details on using hashes.

File details

Details for the file spark_tensorflow_distributor-0.0.3-py3-none-any.whl.

File metadata

  • Download URL: spark_tensorflow_distributor-0.0.3-py3-none-any.whl
  • Upload date:
  • Size: 8.8 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.1.1 pkginfo/1.5.0.1 requests/2.22.0 setuptools/46.1.3 requests-toolbelt/0.9.1 tqdm/4.43.0 CPython/3.7.6

File hashes

Hashes for spark_tensorflow_distributor-0.0.3-py3-none-any.whl
Algorithm     Hash digest
SHA256        188317766a0db2a6f61f2af4423241cb9231e57e16a82f60a618b20d5eb6fbd1
MD5           a7d8c6dd8f9940872ee1e70d1724af8f
BLAKE2b-256   c6d797fee5bfba62cac047ca928bee5e0ac1bfd98aed0b74481aadcbd1ef1285
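
To check a downloaded file against the SHA256 digests listed above, a minimal sketch using Python's standard hashlib module follows. The function names are hypothetical, and the digests are copied from the listings on this page.

```python
import hashlib
import os

# SHA256 digests copied from the "File hashes" listings above.
PUBLISHED_SHA256 = {
    "spark_tensorflow_distributor-0.0.3.tar.gz":
        "50434b9705f3e52817dda64d9f648869ef8f602f55c6d8ba083899e280d44368",
    "spark_tensorflow_distributor-0.0.3-py3-none-any.whl":
        "188317766a0db2a6f61f2af4423241cb9231e57e16a82f60a618b20d5eb6fbd1",
}


def sha256_of_file(path, chunk_size=1 << 16):
    """Return the SHA256 hex digest of a file, read in chunks to bound memory use."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify_download(path, expected=None):
    """Compare a file's SHA256 with `expected`, or with the published digest for its file name."""
    if expected is None:
        expected = PUBLISHED_SHA256[os.path.basename(path)]  # KeyError for unknown files
    return sha256_of_file(path) == expected.lower()
```

For example, verify_download("spark_tensorflow_distributor-0.0.3.tar.gz") returns True only if the local copy matches the published digest. pip can enforce the same check at install time through its hash-checking mode, by adding a --hash=sha256:DIGEST option to the requirement line in a requirements file (which then requires hashes for every requirement).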
