Spark TensorFlow Distributor

This package helps users do distributed training with TensorFlow on their Spark clusters.

Installation

This package requires Python 3.6+, tensorflow>=2.1.0 and pyspark>=3.0.0 to run. To install spark-tensorflow-distributor, run:

pip install spark-tensorflow-distributor

Installing this package does not install PySpark, because most users already have it. If you do not have PySpark installed, you can install it directly:

pip install 'pyspark>=3.0.0'
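
To confirm that an environment meets the minimum versions listed above, a check along the following lines may help. This is a minimal sketch, not part of spark-tensorflow-distributor: the helper names (`parse_version`, `meets_minimum`, `check_requirements`) are hypothetical, and the comparison looks only at the leading numeric components of each version string, so a pre-release such as 3.0.0rc1 is treated the same as 3.0.0.

```python
import re

# Minimum versions taken from the Installation section above.
MINIMUMS = {"tensorflow": "2.1.0", "pyspark": "3.0.0"}


def parse_version(version):
    """Return the leading numeric components of a version string as a tuple of ints."""
    parts = []
    for piece in version.split("."):
        match = re.match(r"\d+", piece)
        if match is None:
            break
        parts.append(int(match.group()))
    return tuple(parts)


def meets_minimum(installed, minimum):
    """Return True if `installed` is at least `minimum` (numeric components only)."""
    a = parse_version(installed)
    b = parse_version(minimum)
    # Pad the shorter tuple with zeros so that "3.0" compares equal to "3.0.0".
    width = max(len(a), len(b))
    a += (0,) * (width - len(a))
    b += (0,) * (width - len(b))
    return a >= b


def check_requirements(installed_versions):
    """Map each required package to True/False; a missing package counts as False."""
    return {
        name: name in installed_versions
        and meets_minimum(installed_versions[name], minimum)
        for name, minimum in MINIMUMS.items()
    }
```

In a real environment, the `installed_versions` dictionary could be built from `tensorflow.__version__` and `pyspark.__version__` after importing both packages.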

Running Tests

For integration tests, first build the master and worker images and then run the test script.

docker-compose build --build-arg PYTHON_INSTALL_VERSION=3.7 --build-arg UBUNTU_VERSION=18.04
./tests/integration/run.sh

Examples

Run the following example code in the pyspark shell:

from spark_tensorflow_distributor import MirroredStrategyRunner


# Taken from https://www.tensorflow.org/tutorials/distribute/multi_worker_with_keras
def train():
    import tensorflow_datasets as tfds
    import tensorflow as tf
    BUFFER_SIZE = 10000
    BATCH_SIZE = 64

    def make_datasets_unbatched():
        # Scale MNIST pixel values from [0, 255] to [0., 1.]
        def scale(image, label):
            image = tf.cast(image, tf.float32)
            image /= 255
            return image, label
        datasets, info = tfds.load(
            name='mnist',
            with_info=True,
            as_supervised=True,
        )
        return datasets['train'].map(scale).cache().shuffle(BUFFER_SIZE)

    def build_and_compile_cnn_model():
        model = tf.keras.Sequential([
            tf.keras.layers.Conv2D(32, 3, activation='relu', input_shape=(28, 28, 1)),
            tf.keras.layers.MaxPooling2D(),
            tf.keras.layers.Flatten(),
            tf.keras.layers.Dense(64, activation='relu'),
            tf.keras.layers.Dense(10, activation='softmax'),
        ])
        model.compile(
            loss=tf.keras.losses.sparse_categorical_crossentropy,
            optimizer=tf.keras.optimizers.SGD(learning_rate=0.001),
            metrics=['accuracy'],
        )
        return model

    # Batch with the global batch size, which is split across the workers.
    GLOBAL_BATCH_SIZE = 64 * 8
    train_datasets = make_datasets_unbatched().batch(GLOBAL_BATCH_SIZE).repeat()
    # Shard the input across workers by element rather than by file.
    options = tf.data.Options()
    options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.DATA
    train_datasets = train_datasets.with_options(options)
    multi_worker_model = build_and_compile_cnn_model()
    multi_worker_model.fit(x=train_datasets, epochs=3, steps_per_epoch=5)
    # Return the GPUs visible to this worker so the caller can inspect them.
    return tf.config.experimental.list_physical_devices('GPU')

MirroredStrategyRunner(num_slots=4).run(train)
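
The example batches with a global batch size of 64 * 8 and runs on 4 slots. The arithmetic behind those numbers can be sketched as follows, assuming one replica per slot and an even split of each global batch across replicas. Both helper functions are hypothetical illustrations, not part of spark-tensorflow-distributor or TensorFlow, and the sketch rejects uneven splits purely for simplicity.

```python
def per_replica_batch_size(global_batch_size, num_replicas):
    """Return the number of examples each replica handles per step."""
    if num_replicas <= 0:
        raise ValueError("num_replicas must be positive")
    if global_batch_size % num_replicas != 0:
        raise ValueError("global batch size must divide evenly across replicas")
    return global_batch_size // num_replicas


def examples_per_run(global_batch_size, steps_per_epoch, epochs):
    """Return the total number of examples consumed over the whole fit() call."""
    return global_batch_size * steps_per_epoch * epochs


# Values taken from the example above.
GLOBAL_BATCH_SIZE = 64 * 8
NUM_SLOTS = 4

print(per_replica_batch_size(GLOBAL_BATCH_SIZE, NUM_SLOTS))  # 128
print(examples_per_run(GLOBAL_BATCH_SIZE, steps_per_epoch=5, epochs=3))  # 7680
```

Under these assumptions, changing `num_slots` without changing the global batch size changes how many examples each replica sees per step, not how many examples are consumed in total.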
