Skip to content

Commit

Permalink
Project import generated by Copybara.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 631065779
  • Loading branch information
psc-g committed May 6, 2024
1 parent 0a5b581 commit 4624114
Show file tree
Hide file tree
Showing 5 changed files with 11 additions and 7 deletions.
5 changes: 4 additions & 1 deletion dopamine/labs/moes/agents/full_rainbow_moe_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
"""Compact implementation of the full Rainbow agent in JAX with MoE modules."""

import functools

from absl import logging
from dopamine.jax import losses
from dopamine.jax.agents.full_rainbow import full_rainbow_agent
Expand Down Expand Up @@ -109,7 +110,7 @@ def loss_fn(params, target, loss_multipliers):
aux_losses = []
if isinstance(net_outputs, arch_types.MoENetworkReturn):
# We may be running a BASELINE agent, which would not contain any MoE
# statistics, so we condition this code on *not* being a BASELINE..
# statistics, so we condition this code on *not* being a BASELINE.
aux_losses = moe_losses.aux_loss(
types.MoELossParameters(
moe_out=net_outputs.moe_out,
Expand Down Expand Up @@ -157,6 +158,7 @@ def loss_fn(params, target, loss_multipliers):
grad, optimizer_state, params=online_params
)
online_params = optax.apply_updates(online_params, updates)

train_returns = {
'optimizer_state': optimizer_state,
'online_params': online_params,
Expand All @@ -171,6 +173,7 @@ def loss_fn(params, target, loss_multipliers):
):
moe_statistics = {}
experts_prob = jnp.mean(jnp.mean(aux_vars['experts_prob'], axis=0), axis=0)

# TODO(gsokar) revisit this if we explore multiple routers.
if networks.MoEType[network_def.moe_type] == networks.MoEType.SOFTMOE:
grads_router = [
Expand Down
2 changes: 1 addition & 1 deletion dopamine/labs/moes/architectures/moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,8 +102,8 @@ def __call__(self, x: jax.Array, *, key: jax.Array) -> types.MoEModuleReturn:
axis_size=self.num_experts,
# TODO(jfarebro): Supply logical sharding axes
)(self.module, mixture_inputs)

expert_output_dims = experts.shape[-1]
experts_hidden = experts_hidden.reshape(-1, experts_hidden.shape[-1])

# Step 4: Reverse permutation
#
Expand Down
7 changes: 4 additions & 3 deletions dopamine/labs/moes/architectures/softmoe.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ def __call__(self, x: jax.Array, *, key: jax.Array) -> types.MoEModuleReturn:
x_normalized,
phi_weights,
)

dispatch_weights = jax.nn.softmax(logits, axis=0)
combine_weights = jax.nn.softmax(logits, axis=(1, 2))

Expand Down Expand Up @@ -125,8 +126,6 @@ def __call__(self, x: jax.Array, *, key: jax.Array) -> types.MoEModuleReturn:
if self.expert_type == "BIG":
experts = experts.reshape(self.num_experts, num_slots, token_length)

experts_hidden = experts_hidden.reshape(-1, experts_hidden.shape[-1])

# The output tokens are weighted average of all slots.
outputs = jnp.einsum("npd,mnp->md", experts, combine_weights)

Expand All @@ -140,5 +139,7 @@ def __call__(self, x: jax.Array, *, key: jax.Array) -> types.MoEModuleReturn:
),
)
return types.MoEModuleReturn(
values=outputs, router_out=router_out, experts_hidden=experts_hidden
values=outputs,
router_out=router_out,
experts_hidden=experts_hidden,
)
2 changes: 1 addition & 1 deletion dopamine/labs/redo/weight_recyclers.py
Original file line number Diff line number Diff line change
Expand Up @@ -719,7 +719,7 @@ def _score2mask(self, activation, param, next_param, key):
pass
elif self.score_type == 'random':
new_key = random.fold_in(key, self._last_update_step)
score = random.shuffle(new_key, score)
score = random.permutation(new_key, score, independent=True)
elif self.score_type == 'redo_inverted':
score = -score
# Metric used in Continual Backprop paper.
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@

setup(
name='dopamine_rl',
version='4.0.6',
version='4.0.7',
description=dopamine_description,
long_description=long_description,
long_description_content_type='text/markdown',
Expand Down

0 comments on commit 4624114

Please sign in to comment.