From ea296266ce6e01709d0050c96d4df908ceb4bd1a Mon Sep 17 00:00:00 2001
From: oliverhu
Date: Sat, 10 Jul 2021 09:25:50 -0700
Subject: [PATCH] TensorFlow 2.4+ compatibility

---
 byteps/_keras/__init__.py                     | 32 +++++++++++++++----
 example/tensorflow/tensorflow2_keras_mnist.py |  3 +-
 2 files changed, 27 insertions(+), 8 deletions(-)

diff --git a/byteps/_keras/__init__.py b/byteps/_keras/__init__.py
index 3626e4036d..098d3c44be 100644
--- a/byteps/_keras/__init__.py
+++ b/byteps/_keras/__init__.py
@@ -40,9 +40,23 @@ def get_gradients(self, loss, params):
             gradients = super(self.__class__, self).get_gradients(loss, params)
             return self._push_pull(gradients)
 
-        def _aggregate_gradients(self, grads_and_vars):
-            gradients = [grad for grad, var in grads_and_vars]
-            return self._push_pull(gradients)
+        def _compute_gradients(self, loss, var_list, grad_loss=None, tape=None):
+            """
+            Compute gradients of all trainable variables.
+            See Optimizer.get_gradients() for more info.
+            In DistributedOptimizer, _compute_gradients() is overridden to also
+            push-pull the gradients before returning them.
+            """
+            tape = backprop.GradientTape() if tape is None else tape
+            grads_and_vars = super(self.__class__, self)._compute_gradients(
+                # pylint: disable=protected-access
+                loss,
+                var_list,
+                grad_loss,
+                tape=tape)
+            grads, weights = list(zip(*grads_and_vars))
+            grads = self._push_pull(grads)
+            return list(zip(grads, weights))
 
         def _push_pull(self, gradients):
             self._aggregated_gradients = True
@@ -51,9 +65,15 @@ def _push_pull(self, gradients):
                 with tf.name_scope(self._name + "_Push_Pull") as scope:
                     for grad in gradients:
                         if grad is not None:
-                            if self._sparse_as_dense and \
-                                    isinstance(grad, tf.IndexedSlices):
-                                grad = tf.convert_to_tensor(grad)
+                            if isinstance(grad, tf.IndexedSlices):
+                                if self._sparse_as_dense:
+                                    grad = tf.convert_to_tensor(grad)
+                                else:
+                                    raise ValueError(
+                                        "IndexedSlices are not supported when "
+                                        "`backward_passes_per_step` > 1 and "
+                                        "`sparse_as_dense` is False."
+                                    )
                             avg_grad = bps.push_pull(grad, scope,
                                                      device_dense=self._device_dense,
                                                      device_sparse=self._device_sparse,
diff --git a/example/tensorflow/tensorflow2_keras_mnist.py b/example/tensorflow/tensorflow2_keras_mnist.py
index 82ded21641..147e40b4c9 100644
--- a/example/tensorflow/tensorflow2_keras_mnist.py
+++ b/example/tensorflow/tensorflow2_keras_mnist.py
@@ -61,8 +61,7 @@
 # uses bps.DistributedOptimizer() to compute gradients.
 mnist_model.compile(loss=tf.losses.SparseCategoricalCrossentropy(),
                     optimizer=opt,
-                    metrics=['accuracy'],
-                    experimental_run_tf_function=False)
+                    metrics=['accuracy'])
 
 callbacks = [
     # byteps: broadcast initial variable states from rank 0 to all other processes.
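
For reference, a minimal usage sketch of the patched Keras path under TensorFlow 2.4+, mirroring example/tensorflow/tensorflow2_keras_mnist.py. It assumes the workers are started with a byteps launcher, and that byteps/_keras/__init__.py has `from tensorflow.python.eager import backprop` available for the new `_compute_gradients()` override (the hunk above calls `backprop.GradientTape()` but does not show that import). The toy model here is illustrative only.

import tensorflow as tf
import byteps.tensorflow.keras as bps

bps.init()

# Toy model; any tf.keras model works the same way.
mnist_model = tf.keras.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28)),
    tf.keras.layers.Dense(10, activation='softmax'),
])

# byteps: scale the learning rate by the number of workers.
opt = tf.optimizers.Adam(0.001 * bps.size())

# byteps: wrap the optimizer. With the _compute_gradients() override above,
# gradients are push-pulled inside Keras' regular training step, so
# `experimental_run_tf_function=False` is no longer passed to compile().
opt = bps.DistributedOptimizer(opt)

mnist_model.compile(loss=tf.losses.SparseCategoricalCrossentropy(),
                    optimizer=opt,
                    metrics=['accuracy'])

callbacks = [
    # byteps: broadcast initial variable states from rank 0 to all other processes.
    bps.callbacks.BroadcastGlobalVariablesCallback(0),
]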