Skip to content

loading pre-trained weights for keras model is not supported in distributed training #264

@WuyangLI

Description

@WuyangLI

System Information

  • Framework (e.g. TensorFlow) / Algorithm (e.g. KMeans): Tensorflow
  • Framework Version: 1.8
  • Python Version: 2
  • CPU or GPU: GPU
  • Python SDK Version: 1.5.1
  • Are you using a custom image: No

Describe the problem

I created a distributed training job which trains a transfer learning model using VGG16.
The job would succeed if I don't load pre-trained weights when creating VGG16 backbone model.

    backend = tf.keras.applications.vgg16.VGG16(weights=None, 
                                                include_top=False,
                                                input_shape=(224, 224, 3))

However, exception would throw when I try to load pre-trained weights as done in the following code snippet:

    backend = tf.keras.applications.vgg16.VGG16(weights='imagenet', 
                                                include_top=False,
                                                input_shape=(224, 224, 3))

Note that, for non-distributed training, loading pre-trained weights would not cause any exception.

Minimal repro / logs

InvalidArgumentError (see above for traceback): Cannot assign a device for operation 'block5_conv3/bias': Operation was explicitly assigned to /job:ps/task:0 but available devices are [ /job:localhost/replica:0/task:0/device:CPU:0 ]. Make sure the device specification refers to a valid device.
#011 [[Node: block5_conv3/bias = VariableV2[_class=["loc:@block5_conv3/bias"], container="", dtype=DT_FLOAT, shape=[512], shared_name="", _device="/job:ps/task:0"]()]]

Traceback (most recent call last):
  File "/usr/local/lib/python2.7/dist-packages/container_support/training.py", line 36, in start
    fw.train()
  File "/usr/local/lib/python2.7/dist-packages/tf_container/train_entry_point.py", line 164, in train
    train_wrapper.train()
  File "/usr/local/lib/python2.7/dist-packages/tf_container/trainer.py", line 73, in train
    tf.estimator.train_and_evaluate(estimator=estimator, train_spec=train_spec, eval_spec=eval_spec)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/estimator/training.py", line 439, in train_and_evaluate
    executor.run()
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/estimator/training.py", line 546, in run
    getattr(self, task_to_run)()
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/estimator/training.py", line 601, in run_master
    self._start_distributed_training(saving_listeners=saving_listeners)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/estimator/training.py", line 739, in _start_distributed_training
    saving_listeners=saving_listeners)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/estimator/estimator.py", line 363, in train
    loss = self._train_model(input_fn, hooks, saving_listeners)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/estimator/estimator.py", line 843, in _train_model
    return self._train_model_default(input_fn, hooks, saving_listeners)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/estimator/estimator.py", line 856, in _train_model_default
    features, labels, model_fn_lib.ModeKeys.TRAIN, self.config)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/estimator/estimator.py", line 831, in _call_model_fn
    model_fn_results = self._model_fn(features=features, **kwargs)
  File "/usr/local/lib/python2.7/dist-packages/tf_container/trainer.py", line 108, in _model_fn
    return self.customer_script.model_fn(features, labels, mode, params)
  File "/opt/ml/code/keras_distributed_transfer_learning.py", line 14, in model_fn
    input_shape=(224, 224, 3))
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/keras/_impl/keras/applications/vgg16.py", line 225, in VGG16
    model.load_weights(weights_path)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/keras/_impl/keras/engine/network.py", line 1190, in load_weights
    saving.load_weights_from_hdf5_group(f, self.layers)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/keras/_impl/keras/engine/saving.py", line 719, in load_weights_from_hdf5_group
    K.batch_set_value(weight_value_tuples)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/keras/_impl/keras/backend.py", line 2707, in batch_set_value
    get_session().run(assign_ops, feed_dict=feed_dict)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/keras/_impl/keras/backend.py", line 442, in get_session
    _initialize_variables(session)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/keras/_impl/keras/backend.py", line 666, in _initialize_variables
    [variables_module.is_variable_initialized(v) for v in candidate_vars])
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/client/session.py", line 900, in run
    run_metadata_ptr)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/client/session.py", line 1135, in _run
    feed_dict_tensor, options, run_metadata)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/client/session.py", line 1316, in _do_run
    run_metadata)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/client/session.py", line 1335, in _do_call
    raise type(e)(node_def, op, message)
  • Exact command to reproduce:
    code for creating the job
import sagemaker

from sagemaker.tensorflow import TensorFlow
from sagemaker.session import s3_input
from sagemaker import get_execution_role

sagemaker_session = sagemaker.Session()
role = get_execution_role()
training_steps = 100
evaluation_steps = 10

estimator = TensorFlow(
    entry_point='keras_distributed_transfer_learning.py',
    source_dir='./',
    role=role,
    training_steps=100,
    evaluation_steps=10,
    train_instance_count=2,
    train_instance_type='ml.p2.xlarge',
    input_mode='File')

input_dataset = s3_input('s3://xxxx/cats_and_dogs/')
estimator.fit(input_dataset)

keras_distributed_transfer_learning.py

import tensorflow as tf
from tensorflow.python.estimator.model_fn import ModeKeys as Modes

INPUT_TENSOR_NAME = "input_1"
NUM_CLASSES = 2
BATCH_SIZE = 10


def model_fn(features, labels, mode, params):
    """The model_fn argument for creating an Estimator."""
    backend = tf.keras.applications.vgg16.VGG16(weights='imagenet', 
                                                include_top=False,
                                                input_shape=(224, 224, 3))
    x = backend.output
    x = tf.keras.layers.Flatten()(x)
    x = tf.keras.layers.Dense(NUM_CLASSES, activation='softmax')(x)
    model = tf.keras.models.Model(inputs=backend.input, outputs=x)
    image = tf.keras.layers.Input(tensor=features[INPUT_TENSOR_NAME])
    # Define operations
    if mode in (Modes.PREDICT, Modes.EVAL):
        logits = model(image, training=False)
        predicted_indices = tf.argmax(input=logits, axis=1)
        probabilities = tf.nn.softmax(logits, name='softmax_tensor')

    if mode in (Modes.TRAIN):
        logits = model(image, training=True)
        global_step = tf.train.get_or_create_global_step()
        loss = tf.losses.softmax_cross_entropy(
            onehot_labels=labels, logits=logits)
        tf.summary.scalar('OptimizeLoss', loss)
        
    if mode in (Modes.EVAL):
        logits = model(image, training=False)
        global_step = tf.train.get_or_create_global_step()
        loss = tf.losses.softmax_cross_entropy(
            onehot_labels=labels, logits=logits)
        tf.summary.scalar('OptimizeLoss', loss)

    if mode == Modes.PREDICT:
        predictions = {
            'classes': predicted_indices,
            'probabilities': probabilities
        }
        export_outputs = {
            'predictions': tf.estimator.export.PredictOutput(predictions)
        }
        return tf.estimator.EstimatorSpec(
            mode, predictions=predictions, export_outputs=export_outputs)

    if mode == Modes.TRAIN:
        optimizer = tf.train.AdamOptimizer(learning_rate=0.001)
        train_op = optimizer.minimize(loss, global_step=global_step)
        return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)

    if mode == Modes.EVAL:
        eval_metric_ops = {
            'accuracy': tf.metrics.accuracy(tf.argmax(labels, 1), predicted_indices)
        }
        return tf.estimator.EstimatorSpec(
            mode, loss=loss, eval_metric_ops=eval_metric_ops)


def _input_fn(training_dir, input_shape, batch_size):
    generator = tf.keras.preprocessing.image.ImageDataGenerator().flow_from_directory(training_dir, target_size=input_shape, batch_size=batch_size)

    tensor_shapes = (tf.TensorShape([None, input_shape[0], input_shape[1], 3]), tf.TensorShape([None, NUM_CLASSES]))
    tensor_types = (tf.float32, tf.float32)
    dataset = tf.data.Dataset.from_generator(lambda: generator, tensor_types, tensor_shapes)
    features, labels = dataset.make_one_shot_iterator().get_next()
    return {INPUT_TENSOR_NAME: features}, labels


def train_input_fn(training_dir, hyperparameters):
    return _input_fn(training_dir + '/train/', (224, 224), BATCH_SIZE)


def eval_input_fn(training_dir, hyperparameters):
    return _input_fn(training_dir + '/test/', (224, 224), BATCH_SIZE)


def serving_input_fn(hyperparameters):
    inputs = {INPUT_TENSOR_NAME: tf.placeholder(tf.float32, [None, 224, 224, 3])}
    return tf.estimator.export.ServingInputReceiver(inputs, inputs)

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type
    No fields configured for issues without a type.

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions