From a8dfeebbbb9d8d61c66abc64572d7a1ee3185e96 Mon Sep 17 00:00:00 2001 From: ashors1 Date: Mon, 13 Feb 2023 11:58:34 -0800 Subject: [PATCH 1/6] perform gradient clipping on global batch when using ShardedStaticAccumulator --- AUTHORS | 8 +++++++ praxis/optimizers.py | 53 ++++++++++++++++++++++++++++++++++++++++++-- 2 files changed, 59 insertions(+), 2 deletions(-) create mode 100644 AUTHORS diff --git a/AUTHORS b/AUTHORS new file mode 100644 index 00000000..f0ebbcd0 --- /dev/null +++ b/AUTHORS @@ -0,0 +1,8 @@ +# This is the list of Pax's significant contributors. +# +# This does not necessarily list everyone who has contributed code, +# especially since many employees of one corporation may be contributing. +# To see the full list of contributors, see the revision history in +# source control. +Google LLC +NVIDIA Corporation diff --git a/praxis/optimizers.py b/praxis/optimizers.py index 7cb76461..67aaaa70 100644 --- a/praxis/optimizers.py +++ b/praxis/optimizers.py @@ -1,5 +1,5 @@ # coding=utf-8 -# Copyright 2022 Google LLC. +# Copyright 2022 The Pax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -2707,6 +2707,8 @@ def _get_raw_grad_transformation(self, lr: optax.Schedule): def sharded_static_accumulation( num_sub_batches: int, + clip_gradient_norm_to_value: float, + clip_gradient_single_norm_to_value: float, base_tx: ShardedGradientTransformation, ) -> ShardedGradientTransformation: """Gradient transformation for ShardedStaticAccumulator optimizer.""" @@ -2775,8 +2777,54 @@ def update_fn(updates: NestedJTensor, lambda: new_count) def _run_base_tx(): + + def _compute_grad_norm(grads: NestedMap) -> JTensor: + """Computes total grad norm.""" + grad_norms_squared = jax.tree_map(lambda x: jnp.sum(x * x), grads) + grad_norms_squared, _ = jax.tree_util.tree_flatten(grad_norms_squared) + return jnp.sqrt(jnp.sum(jnp.stack(grad_norms_squared))) + + + def scale_gradients( + raw_grads: NestedMap, + clip_grad_norm_to_value: Optional[float] = None, + clip_grad_single_norm_to_value: Optional[float] = None): + + def clip_grads(grads, grad_norm): + if clip_grad_norm_to_value: + assert clip_grad_single_norm_to_value == 0. + grad_scale = jnp.minimum( + jnp.array(1, grad_norm.dtype), + jnp.array(clip_grad_norm_to_value, grad_norm.dtype) + / grad_norm) + grads = jax.tree_map(lambda g: g * grad_scale, grads) + elif clip_grad_single_norm_to_value: + assert clip_grad_norm_to_value == 0. + grad_single_norm = jax.tree_map(lambda x: jnp.sqrt(jnp.sum(x * x)), + grads) + + def scale_gradient(grad, norm): + return grad * jnp.minimum( + jnp.array(1, norm.dtype), + jnp.array(clip_grad_single_norm_to_value, + norm.dtype) / norm) + grads = jax.tree_map(scale_gradient, grads, grad_single_norm) + grad_scale = jnp.array(1.0) + else: + # no clipping is needed. 
+ grad_scale = jnp.array(1.0) + return grads, grad_scale + + raw_grad_norm = _compute_grad_norm(raw_grads) + + grads, grad_scale = clip_grads(raw_grads, raw_grad_norm) + return grads + averaged_updated = jax.tree_map(lambda acc: acc / num_sub_batches, new_accumulated_update) + averaged_updated = scale_gradients(averaged_updated, + clip_gradient_norm_to_value, + clip_gradient_single_norm_to_value) emission_updates, emission_base_state = base_tx.update( averaged_updated, state.base_state, params) return (emission_updates, @@ -2849,4 +2897,5 @@ def _get_raw_grad_transformation( self, lr: optax.Schedule) -> GeneralGradientTransformation: p = self._hparams base_tx = self.base_optimizer._get_raw_grad_transformation(lr) # pylint: disable=protected-access - return sharded_static_accumulation(p.num_sub_batches, base_tx) + return sharded_static_accumulation(p.num_sub_batches, p.clip_gradient_norm_to_value, + p.clip_gradient_single_norm_to_value, base_tx) From 4380135d4db94d9eb86d8cdcd56383d2a1d90b91 Mon Sep 17 00:00:00 2001 From: ashors1 Date: Thu, 23 Feb 2023 12:14:12 -0800 Subject: [PATCH 2/6] remove AUTHORS file --- AUTHORS | 8 -------- 1 file changed, 8 deletions(-) delete mode 100644 AUTHORS diff --git a/AUTHORS b/AUTHORS deleted file mode 100644 index f0ebbcd0..00000000 --- a/AUTHORS +++ /dev/null @@ -1,8 +0,0 @@ -# This is the list of Pax's significant contributors. -# -# This does not necessarily list everyone who has contributed code, -# especially since many employees of one corporation may be contributing. -# To see the full list of contributors, see the revision history in -# source control. -Google LLC -NVIDIA Corporation From 08e4292ba7baa407eacae0559a5dc992cc593f5b Mon Sep 17 00:00:00 2001 From: ashors1 Date: Mon, 6 Mar 2023 09:36:29 -0800 Subject: [PATCH 3/6] minor refactor, do not return grad_scale --- praxis/optimizers.py | 20 +++++++++----------- 1 file changed, 9 insertions(+), 11 deletions(-) diff --git a/praxis/optimizers.py b/praxis/optimizers.py index 67aaaa70..df244304 100644 --- a/praxis/optimizers.py +++ b/praxis/optimizers.py @@ -2790,9 +2790,12 @@ def scale_gradients( clip_grad_norm_to_value: Optional[float] = None, clip_grad_single_norm_to_value: Optional[float] = None): - def clip_grads(grads, grad_norm): + def clip_grads(grads): if clip_grad_norm_to_value: assert clip_grad_single_norm_to_value == 0. + + grad_norm = _compute_grad_norm(raw_grads) + grad_scale = jnp.minimum( jnp.array(1, grad_norm.dtype), jnp.array(clip_grad_norm_to_value, grad_norm.dtype) @@ -2801,7 +2804,7 @@ def clip_grads(grads, grad_norm): elif clip_grad_single_norm_to_value: assert clip_grad_norm_to_value == 0. grad_single_norm = jax.tree_map(lambda x: jnp.sqrt(jnp.sum(x * x)), - grads) + grads) def scale_gradient(grad, norm): return grad * jnp.minimum( @@ -2809,20 +2812,15 @@ def scale_gradient(grad, norm): jnp.array(clip_grad_single_norm_to_value, norm.dtype) / norm) grads = jax.tree_map(scale_gradient, grads, grad_single_norm) - grad_scale = jnp.array(1.0) - else: - # no clipping is needed. 
- grad_scale = jnp.array(1.0) - return grads, grad_scale - - raw_grad_norm = _compute_grad_norm(raw_grads) + + return grads - grads, grad_scale = clip_grads(raw_grads, raw_grad_norm) + grads = clip_grads(raw_grads) return grads averaged_updated = jax.tree_map(lambda acc: acc / num_sub_batches, new_accumulated_update) - averaged_updated = scale_gradients(averaged_updated, + scaled_updated = scale_gradients(averaged_updated, clip_gradient_norm_to_value, clip_gradient_single_norm_to_value) emission_updates, emission_base_state = base_tx.update( From 42932eadf5383ec38185aa6a86effd9876a9f8a8 Mon Sep 17 00:00:00 2001 From: ashors1 Date: Mon, 6 Mar 2023 11:28:03 -0800 Subject: [PATCH 4/6] fix indent --- praxis/optimizers.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/praxis/optimizers.py b/praxis/optimizers.py index e997b0ab..18932c9c 100644 --- a/praxis/optimizers.py +++ b/praxis/optimizers.py @@ -2827,8 +2827,8 @@ def scale_gradient(grad, norm): averaged_updated = jax.tree_map(lambda acc: acc / num_sub_batches, new_accumulated_update) scaled_updated = scale_gradients(averaged_updated, - clip_gradient_norm_to_value, - clip_gradient_single_norm_to_value) + clip_gradient_norm_to_value, + clip_gradient_single_norm_to_value) emission_updates, emission_base_state = base_tx.update( averaged_updated, state.base_state, params) return (emission_updates, From 54bdc12c74a92e51c9943b91bccd99505dadaa49 Mon Sep 17 00:00:00 2001 From: ashors1 Date: Wed, 8 Mar 2023 10:19:48 -0800 Subject: [PATCH 5/6] fix formatting, small ga bug fix --- praxis/optimizers.py | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/praxis/optimizers.py b/praxis/optimizers.py index 18932c9c..1343d6fe 100644 --- a/praxis/optimizers.py +++ b/praxis/optimizers.py @@ -2714,7 +2714,7 @@ def _get_raw_grad_transformation(self, lr: optax.Schedule): def sharded_static_accumulation( num_sub_batches: int, clip_gradient_norm_to_value: float, - clip_gradient_single_norm_to_value: float, + clip_gradient_single_norm_to_value: float, base_tx: ShardedGradientTransformation, ) -> ShardedGradientTransformation: """Gradient transformation for ShardedStaticAccumulator optimizer.""" @@ -2783,25 +2783,25 @@ def update_fn(updates: NestedJTensor, lambda: new_count) def _run_base_tx(): - + def _compute_grad_norm(grads: NestedMap) -> JTensor: """Computes total grad norm.""" grad_norms_squared = jax.tree_map(lambda x: jnp.sum(x * x), grads) grad_norms_squared, _ = jax.tree_util.tree_flatten(grad_norms_squared) return jnp.sqrt(jnp.sum(jnp.stack(grad_norms_squared))) - - + + def scale_gradients( raw_grads: NestedMap, clip_grad_norm_to_value: Optional[float] = None, clip_grad_single_norm_to_value: Optional[float] = None): - + def clip_grads(grads): if clip_grad_norm_to_value: assert clip_grad_single_norm_to_value == 0. - + grad_norm = _compute_grad_norm(raw_grads) - + grad_scale = jnp.minimum( jnp.array(1, grad_norm.dtype), jnp.array(clip_grad_norm_to_value, grad_norm.dtype) @@ -2811,7 +2811,7 @@ def clip_grads(grads): assert clip_grad_norm_to_value == 0. 
grad_single_norm = jax.tree_map(lambda x: jnp.sqrt(jnp.sum(x * x)), grads) - + def scale_gradient(grad, norm): return grad * jnp.minimum( jnp.array(1, norm.dtype), @@ -2820,17 +2820,17 @@ def scale_gradient(grad, norm): grads = jax.tree_map(scale_gradient, grads, grad_single_norm) return grads - + grads = clip_grads(raw_grads) return grads - + averaged_updated = jax.tree_map(lambda acc: acc / num_sub_batches, new_accumulated_update) - scaled_updated = scale_gradients(averaged_updated, + scaled_updated = scale_gradients(averaged_updated, clip_gradient_norm_to_value, clip_gradient_single_norm_to_value) emission_updates, emission_base_state = base_tx.update( - averaged_updated, state.base_state, params) + scaled_updated, state.base_state, params) return (emission_updates, jax.tree_map(lambda u: jnp.zeros_like(u, dtype=jnp.float32), updates), emission_base_state) From 40a6d805248b7c2e6c4d1c8fb802b05020a16360 Mon Sep 17 00:00:00 2001 From: ashors1 Date: Sat, 18 Mar 2023 16:04:39 -0700 Subject: [PATCH 6/6] address PR comments --- praxis/optimizers.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/praxis/optimizers.py b/praxis/optimizers.py index c8deaf25..d7e25d73 100644 --- a/praxis/optimizers.py +++ b/praxis/optimizers.py @@ -2771,13 +2771,12 @@ def _compute_grad_norm(grads: NestedMap) -> JTensor: def scale_gradients( raw_grads: NestedMap, - clip_grad_norm_to_value: Optional[float] = None, - clip_grad_single_norm_to_value: Optional[float] = None): + clip_grad_norm_to_value: float = 0.0, + clip_grad_single_norm_to_value: float = 0.0): def clip_grads(grads): + assert not (clip_grad_norm_to_value and clip_grad_single_norm_to_value) if clip_grad_norm_to_value: - assert clip_grad_single_norm_to_value == 0. - grad_norm = _compute_grad_norm(raw_grads) grad_scale = jnp.minimum( @@ -2786,7 +2785,6 @@ def clip_grads(grads): / grad_norm) grads = jax.tree_map(lambda g: g * grad_scale, grads) elif clip_grad_single_norm_to_value: - assert clip_grad_norm_to_value == 0. grad_single_norm = jax.tree_map(lambda x: jnp.sqrt(jnp.sum(x * x)), grads)
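
For reference, below is a minimal standalone sketch of the behavior this series implements: accumulate per-sub-batch gradients, average them over num_sub_batches, and apply global-norm clipping only to the averaged (global-batch) gradients before they reach the base transformation. It assumes only jax and jax.numpy; the helper names (global_norm, clip_by_global_norm) and the toy gradient pytree are illustrative and are not Praxis or Optax APIs.

import jax
import jax.numpy as jnp


def global_norm(grads):
  """Total L2 norm over a pytree of gradients."""
  sums = jax.tree_util.tree_map(lambda g: jnp.sum(g * g), grads)
  return jnp.sqrt(sum(jax.tree_util.tree_leaves(sums)))


def clip_by_global_norm(grads, max_norm):
  """Scales every leaf so the pytree's global norm is at most max_norm."""
  scale = jnp.minimum(1.0, max_norm / global_norm(grads))
  return jax.tree_util.tree_map(lambda g: g * scale, grads)


# Accumulate per-sub-batch gradients, then average and clip the global batch,
# mirroring the order of operations added to sharded_static_accumulation above.
num_sub_batches = 4
sub_batch_grads = [{'w': jnp.ones((2, 2)) * (i + 1)}
                   for i in range(num_sub_batches)]
accumulated = jax.tree_util.tree_map(lambda *gs: sum(gs), *sub_batch_grads)
averaged = jax.tree_util.tree_map(lambda g: g / num_sub_batches, accumulated)
clipped = clip_by_global_norm(averaged, max_norm=1.0)
print(global_norm(clipped))  # <= 1.0

Clipping after averaging (rather than per sub-batch) is the point of the change: the clip threshold is applied once to the effective global-batch gradient, so results match what a single large-batch step with the same clip_gradient_norm_to_value would produce.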