Skip to content

MNIST with TensorFlow using SlurmClusterResolver

This toy example shows how to run multi-node distributed training with TensorFlow 2, using `tf.distribute.MultiWorkerMirroredStrategy` together with `SlurmClusterResolver`.

Submit the job

cd jean-zay-doc/docs/examples/tf/tf_distributed
sbatch mnist_example_distributed.slurm

Code

The distributed training code (`mnist_example.py`):

# Model definition adapted from https://www.tensorflow.org/guide/keras/functional
import click


@click.command()
@click.option('--batch-size', default=64, show_default=True,
              help='Batch size passed to model.fit.')
def train_dense_model_click(batch_size):
    """Train the dense MNIST model from the command line."""
    # The default of 64 keeps `python ./mnist_example.py` (no arguments)
    # behaving exactly as before.
    return train_dense_model(batch_size=batch_size)


def train_dense_model(batch_size, *, epochs=5, port_base=15000):
    """Train and evaluate a small dense classifier on random MNIST-shaped data.

    The model is replicated with ``tf.distribute.MultiWorkerMirroredStrategy``.
    ``SlurmClusterResolver`` reads the cluster layout from the Slurm
    environment, and collective communication uses NCCL.

    Args:
        batch_size: Batch size passed to ``model.fit``. Under a
            ``tf.distribute`` strategy, Keras treats it as the global batch
            size and splits it across all replicas.
        epochs: Number of training epochs. Defaults to 5.
        port_base: First port handed to ``SlurmClusterResolver`` for the
            worker processes. Defaults to 15000.

    Returns:
        True once training and evaluation have completed.
    """
    # Import TensorFlow lazily, so that importing this module (e.g. to build
    # the command-line interface) stays fast.
    import tensorflow as tf
    from tensorflow import keras
    from tensorflow.keras import layers

    tf.keras.backend.clear_session()  # For easy reset of notebook state.

    # Distribution setup: discover the workers from Slurm, then mirror the
    # model across the replicas of all workers.
    slurm_resolver = tf.distribute.cluster_resolver.SlurmClusterResolver(
        port_base=port_base)
    # `communication_options` expects a CommunicationOptions object, not a
    # bare CommunicationImplementation enum value.
    nccl = tf.distribute.experimental.CommunicationImplementation.NCCL
    communication_options = tf.distribute.experimental.CommunicationOptions(
        implementation=nccl)
    mirrored_strategy = tf.distribute.MultiWorkerMirroredStrategy(
        cluster_resolver=slurm_resolver,
        communication_options=communication_options)
    print('Number of replicas:', mirrored_strategy.num_replicas_in_sync)

    # Model building: variables must be created inside the strategy scope so
    # that they are mirrored across replicas.
    with mirrored_strategy.scope():
        inputs = keras.Input(shape=(784,), name='img')
        x = layers.Dense(64, activation='relu')(inputs)
        x = layers.Dense(64, activation='relu')(x)
        outputs = layers.Dense(10)(x)

        model = keras.Model(inputs=inputs, outputs=outputs, name='mnist_model')

        # The last layer has no softmax, so the loss is computed from logits.
        model.compile(loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                      optimizer=keras.optimizers.RMSprop(),
                      metrics=['accuracy'])

    # Training and inference. The network is not reachable, so random tensors
    # with MNIST shapes (784 features, 10 classes) replace the real dataset.
    # NOTE: no seed is set, so each worker draws different random data. This
    # is fine for a demo, but not for a real dataset.
    x_train = tf.random.normal((60000, 784), dtype='float32')
    x_test = tf.random.normal((10000, 784), dtype='float32')
    y_train = tf.random.uniform((60000,), minval=0, maxval=10, dtype='int32')
    y_test = tf.random.uniform((10000,), minval=0, maxval=10, dtype='int32')

    model.fit(x_train, y_train,
              batch_size=batch_size,
              epochs=epochs,
              validation_split=0.2)
    # evaluate() returns [loss, accuracy] because of metrics=['accuracy'].
    test_scores = model.evaluate(x_test, y_test, verbose=2)
    print('Test loss:', test_scores[0])
    print('Test accuracy:', test_scores[1])
    return True

if __name__ == "__main__":
    # Every Slurm task started by `srun` runs this entry point once.
    train_dense_model_click()

And the Slurm batch script used to launch the job (`mnist_example_distributed.slurm`):

#!/bin/bash
#SBATCH --job-name=mnist_tf_distributed     # job name
#SBATCH --nodes=2                    # number of nodes
#SBATCH --ntasks-per-node=1          # one task (one TensorFlow worker) per node
#SBATCH --gres=gpu:4                 # number of GPUs per node
#SBATCH --cpus-per-task=40           # since nodes have 40 cpus
#SBATCH --hint=nomultithread         # we get physical cores not logical
#SBATCH --distribution=block:block   # distribution, might be better to have contiguous blocks
#SBATCH --time=00:10:00              # job length
#SBATCH --output=mnist_tf_distr_log_%j.out  # std out
#SBATCH --error=mnist_tf_distr_log_%j.out   # std err, same file as std out
#SBATCH --exclusive                  # we reserve the entire node for our job
#SBATCH --qos=qos_gpu-dev            # we are submitting a test job
#SBATCH -A changeme@gpu              # replace changeme with your project account

# Disable HTTP(S) proxies for this job. Presumably this keeps the traffic
# between TensorFlow workers from being routed through a proxy - confirm.
unset http_proxy https_proxy HTTP_PROXY HTTPS_PROXY

set -x
# Quote the path and stop if it cannot be entered, instead of silently
# running from the wrong directory.
cd "${SLURM_SUBMIT_DIR}" || exit 1

module purge
module load tensorflow-gpu/py3/2.4.0

# srun starts one python process per Slurm task; SlurmClusterResolver in
# mnist_example.py uses the Slurm environment to connect them as workers.
srun python ./mnist_example.py