
ValueError: Variable <tf.Variable 'conv2d_286/kernel:0' shape=(3, 3, 3, 32) dtype=float32> has None for gradient. Please make sure that all of your ops have a gradient defined (i.e. are differentiable). #3

Closed
rishab-sharma opened this issue May 24, 2019 · 8 comments
Labels: bug, technique:pruning

@rishab-sharma

I am pruning an InceptionV3 model.

First, I create an end-to-end model:

from tensorflow.keras import applications
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D
from tensorflow.keras.models import Model
from tensorflow.keras.regularizers import l2

base_model = applications.InceptionV3(weights='imagenet', include_top=False)
x = GlobalAveragePooling2D()(base_model.output)
x = Dense(1024, activation='relu')(x)
predictions = Dense(3, kernel_initializer="glorot_uniform", kernel_regularizer=l2(.0005), activation='softmax')(x)

model = Model(base_model.input, predictions)

Then I prune the model with this command:

ConstantSparsity = pruning_schedule.ConstantSparsity

pruning_params = {
    'pruning_schedule': ConstantSparsity(0.75, begin_step=2000, frequency=100)
}

pruned_model = sparsity.prune_low_magnitude(model, **pruning_params)

opt = SGD(lr=.01, momentum=.9)
pruned_model.compile(loss='categorical_crossentropy', optimizer=opt, metrics=['accuracy'])
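
A quick sanity check (not in the original post) that the pruning wrapper was applied:

# Layer names in the wrapped model should now start with 'prune_low_magnitude_'.
pruned_model.summary()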

But when I try to fit the model on the training data:

callbacks = [
        pruning_callbacks.UpdatePruningStep(),
        pruning_callbacks.PruningSummaries(log_dir = '/content/')#log_dir=FLAGS.output_dir)
    ]

pruned_model.fit_generator(
    train_generator,
    steps_per_epoch=nb_train_samples,
    epochs=epochs,
    validation_data=validation_generator, callbacks=callbacks)

I get the following error:

Epoch 1/10
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-36-06fb096a2de2> in <module>()
      3     steps_per_epoch=nb_train_samples,
      4     epochs=epochs,
----> 5     validation_data=validation_generator, callbacks=callbacks)#[lr_scheduler, csv_logger, checkpointer])
      6 
      7 pruned_model.save_weights("{}/final_{}_{}.h5".format(job_path, job_name, model_name))

5 frames
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/optimizer_v2/optimizer_v2.py in get_gradients(self, loss, params)
    396                            "gradient defined (i.e. are differentiable). "
    397                            "Common ops without gradient: "
--> 398                            "K.argmax, K.round, K.eval.".format(param))
    399       if hasattr(self, "clipnorm"):
    400         grads = [clip_ops.clip_by_norm(g, self.clipnorm) for g in grads]

ValueError: Variable <tf.Variable 'conv2d_286/kernel:0' shape=(3, 3, 3, 32) dtype=float32> has `None` for gradient. Please make sure that all of your ops have a gradient defined (i.e. are differentiable). Common ops without gradient: K.argmax, K.round, K.eval.

Can anybody please help with this? I can't seem to find a solution on Stack Overflow or elsewhere.

Thank you

@raziel

raziel commented May 24, 2019

Hi.

Does your model train if you don't include any of the pruning calls?

Just to try to understand where the problem is.

Thanks

@rishab-sharma
Author

rishab-sharma commented May 24, 2019

Yes, the model trains if I don't use

pruned_model = sparsity.prune_low_magnitude(model, **pruning_params)

and directly fit the model instead.
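
For reference, the unpruned baseline that trains fine looks roughly like this (a sketch reconstructed from the snippets above, not the exact script):

# Compile and fit the original, unwrapped model directly.
model.compile(loss='categorical_crossentropy',
              optimizer=SGD(lr=.01, momentum=.9),
              metrics=['accuracy'])

model.fit_generator(
    train_generator,
    steps_per_epoch=nb_train_samples,
    epochs=epochs,
    validation_data=validation_generator)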

@rishab-sharma
Author

The complete traceback is as follows:

Epoch 1/10
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-36-06fb096a2de2> in <module>()
      3     steps_per_epoch=nb_train_samples,
      4     epochs=epochs,
----> 5     validation_data=validation_generator, callbacks=callbacks)#[lr_scheduler, csv_logger, checkpointer])
      6 
      7 pruned_model.save_weights("{}/final_{}_{}.h5".format(job_path, job_name, model_name))

5 frames
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/training.py in fit_generator(self, generator, steps_per_epoch, epochs, verbose, callbacks, validation_data, validation_steps, validation_freq, class_weight, max_queue_size, workers, use_multiprocessing, shuffle, initial_epoch)
   1175         shuffle=shuffle,
   1176         initial_epoch=initial_epoch,
-> 1177         steps_name='steps_per_epoch')
   1178 
   1179   def evaluate_generator(self,

/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/training_generator.py in model_iteration(model, data, steps_per_epoch, epochs, verbose, callbacks, validation_data, validation_steps, validation_freq, class_weight, max_queue_size, workers, use_multiprocessing, shuffle, initial_epoch, mode, batch_size, steps_name, **kwargs)
    262 
    263       is_deferred = not model._is_compiled
--> 264       batch_outs = batch_function(*batch_data)
    265       if not isinstance(batch_outs, list):
    266         batch_outs = [batch_outs]

/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/training.py in train_on_batch(self, x, y, sample_weight, class_weight, reset_metrics)
    916 
    917       self._update_sample_weight_modes(sample_weights=sample_weights)
--> 918       self._make_train_function()
    919       outputs = self.train_function(ins)  # pylint: disable=not-callable
    920 

/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/training.py in _make_train_function(self)
   1995           # Training updates
   1996           updates = self.optimizer.get_updates(
-> 1997               params=self._collected_trainable_weights, loss=self.total_loss)
   1998       # Unconditional updates
   1999       updates += self.get_updates_for(None)

/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/optimizer_v2/optimizer_v2.py in get_updates(self, loss, params)
    489 
    490   def get_updates(self, loss, params):
--> 491     grads = self.get_gradients(loss, params)
    492     grads_and_vars = list(zip(grads, params))
    493     self._assert_valid_dtypes([

/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/optimizer_v2/optimizer_v2.py in get_gradients(self, loss, params)
    396                            "gradient defined (i.e. are differentiable). "
    397                            "Common ops without gradient: "
--> 398                            "K.argmax, K.round, K.eval.".format(param))
    399       if hasattr(self, "clipnorm"):
    400         grads = [clip_ops.clip_by_norm(g, self.clipnorm) for g in grads]

ValueError: Variable <tf.Variable 'conv2d_286/kernel:0' shape=(3, 3, 3, 32) dtype=float32> has `None` for gradient. Please make sure that all of your ops have a gradient defined (i.e. are differentiable). Common ops without gradient: K.argmax, K.round, K.eval.

@rishab-sharma
Author

Alternatively, if anyone can provide a working example of pruning InceptionV3 or ResNet, that would help a lot.

Thank you

@nutsiepully
Contributor

nutsiepully commented May 30, 2019

Hi Rishab,

Thanks for the bug report.

I tried out the code you pasted (almost exactly the same) and it worked for me; I didn't see the error you are getting. I have a few questions:

  1. Are you running this with a distribution strategy, or on a single device (the default)?
  2. The variable that throws the error in your code is conv2d_286/kernel:0. This seems unlikely, since when I build InceptionV3 locally I only see layers up to conv2d_93. Are additional layers possibly getting constructed in your code? (See the quick check below.)
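
One quick way to check for that (a diagnostic sketch, not part of the original comment; it assumes `import tensorflow as tf` and the `model` variable from the snippet above):

# Count the Conv2D layers and inspect the last name. A single InceptionV3 build
# should not reach names like 'conv2d_286'; such high counters usually mean the
# model was constructed several times in the same session.
conv_names = [l.name for l in model.layers if isinstance(l, tf.keras.layers.Conv2D)]
print(len(model.layers), 'layers total,', len(conv_names), 'Conv2D layers')
print(conv_names[-1])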

The following is the code I used.


import numpy as np
import tensorflow as tf
import tensorflow_model_optimization as tfmot

# create the base pre-trained model
base_model = tf.keras.applications.InceptionV3(
        weights='imagenet', include_top=False)

# add a global spatial average pooling layer
x = base_model.output
x = tf.keras.layers.GlobalAveragePooling2D()(x)
# let's add a fully-connected layer
x = tf.keras.layers.Dense(1024, activation='relu')(x)
# and a logistic layer -- let's say we have 200 classes
predictions = tf.keras.layers.Dense(200, activation='softmax')(x)

# this is the model we will train
model = tf.keras.Model(inputs=base_model.input, outputs=predictions)

# first: train only the top layers (which were randomly initialized)
# i.e. freeze all convolutional InceptionV3 layers
for layer in base_model.layers:
    layer.trainable = False

# compile the model (should be done *after* setting layers to non-trainable)
model.compile(optimizer='rmsprop', loss='categorical_crossentropy')

x_train = np.random.random((1000, 32, 32, 3))
y_train = np.random.randint(200,size=(1000, 1))
labels = tf.keras.utils.to_categorical(y_train, num_classes=200)

# train the model on the new data for a few epochs
# model.fit(x_train, labels)


# Pruning test
import tensorflow_model_optimization.sparsity.keras as sparsity

pruning_params = {
        'pruning_schedule': sparsity.ConstantSparsity(0.75, begin_step=0, frequency=2)
}
pruned_model = sparsity.prune_low_magnitude(model, **pruning_params)

print('Pruned Model - ' + str(pruned_model))
print('Layers ' + str(len(pruned_model.layers)))
for i, layer in enumerate(pruned_model.layers):
      print(i, layer.name, layer.layer.name if not isinstance(layer, tf.keras.layers.InputLayer) else layer.name)


opt = tf.keras.optimizers.SGD(lr=.01, momentum=.9)
pruned_model.compile(loss='categorical_crossentropy', optimizer=opt, metrics=['accuracy'])

callbacks = [
        sparsity.UpdatePruningStep(),
]

pruned_model.fit(
        x_train,
        labels,
        callbacks=callbacks)

This runs training fine for me. Perhaps you can try it out and see if it solves your problem.

If possible, can you share the exact code you are running, from construction to fit? It would help if I can reproduce it locally.

Thanks.

nutsiepully added the bug label on May 30, 2019
@rishab-sharma
Author

rishab-sharma commented May 30, 2019

Yes, you are right.
Some additional layers were being added; removing them solves the issue.
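
Not stated in the thread, but a common way to end up with layer names like conv2d_286 in a notebook is re-running the model-construction cell several times in the same session; a minimal sketch of the usual remedy (an assumption about the setup, not something confirmed here):

import tensorflow as tf

# Hypothetical fix for stale/duplicate layers left over from earlier cell runs:
# clearing the session resets the Keras graph and layer-name counters before the
# model (and the pruned wrapper) is rebuilt from scratch.
tf.keras.backend.clear_session()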

Just one more thing: is low-magnitude pruning of InceptionV3 possible on Google Colab GPU resources? It's just that when I run your/my code there, the runtime disconnects.
Does pruning require more resources?

@nutsiepully
Contributor

nutsiepully commented May 30, 2019

Pruning does not require resources significantly different from training the model. It only adds a small overhead. If your resources are capable of training, they should be capable of pruning as well.

As for Google Colab, I can't answer that. You'll have to estimate your resource requirements and find compute resources accordingly.

@pren1

pren1 commented Sep 20, 2019

For anyone who still cannot solve this problem: if you are using Keras, this error might be caused by a custom loss combined with the model.add_loss function.
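
A minimal sketch of the underlying condition (assumes TF 2.x eager execution; illustrative code, not from this thread): any trainable variable that is not connected to the loss gets a None gradient, which is exactly what the optimizer rejects above.

import tensorflow as tf

w_used = tf.Variable(1.0)
w_unused = tf.Variable(1.0)    # never participates in the loss

with tf.GradientTape() as tape:
    loss = 3.0 * w_used        # only w_used is on the path to the loss

grads = tape.gradient(loss, [w_used, w_unused])
print(grads)   # [tf.Tensor(3.0, ...), None] -- the None is what raises the ValueError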
