ppg
wenzhangliu committed Sep 1, 2024
1 parent 5679b6b commit a28313e
Showing 3 changed files with 39 additions and 25 deletions.
21 changes: 3 additions & 18 deletions xuance/mindspore/agents/policy_gradient/ppg_agent.py
@@ -74,27 +74,12 @@ def action(self, observations: np.ndarray,
             dists: The policy distributions.
             log_pi: Log of stochastic actions.
         """
-        shape_obs = observations.shape
-        if len(shape_obs) > 2:
-            observations = observations.reshape(-1, shape_obs[-1])
-            if self.continuous_control:
-                _, policy_mean, values, _ = self.policy(observations)
-                policy_mean = policy_mean.numpy().reshape(shape_obs[:-1] + self.action_space.shape)
-                policy_std = ms.exp(self.policy.actor.logstd).numpy()
-                self.policy.actor.dist.set_param(policy_mean, policy_std)
-            else:
-                _, policy_logits, values, _ = self.policy(observations)
-                policy_logits = policy_logits.numpy().reshape(shape_obs[:-1] + (self.action_space.n, ))
-                self.policy.actor.dist.set_param(logits=policy_logits)
-            values = ms.reshape(values, shape_obs[:-1])
-        else:
-            _, _, values, _ = self.policy(observations)
-        policy_dists = self.policy.actor.dist
+        _, policy_dists, values, _ = self.policy(observations)
         actions = policy_dists.stochastic_sample()
         log_pi = policy_dists.log_prob(actions) if return_logpi else None
         dists = split_distributions(policy_dists) if return_dists else None
-        actions = actions.numpy()
-        values = values.numpy()
+        actions = actions.asnumpy()
+        values = values.asnumpy()
         return {"actions": actions, "values": values, "dists": dists, "log_pi": log_pi}

     def get_aux_info(self, policy_output: dict = None):
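With this change, action() no longer rebuilds the distribution from raw means or logits by hand: the policy itself now returns a distribution object, and tensors are converted with MindSpore's Tensor.asnumpy() rather than the Torch-style .numpy(). A minimal sketch of how the returned dictionary might be consumed in a rollout step; agent and envs are hypothetical stand-ins, not identifiers from this commit:

    # Hypothetical rollout step built on the refactored action() output.
    obs = envs.reset()                                      # (n_envs, obs_dim) observations
    out = agent.action(obs, return_dists=True, return_logpi=True)
    acts, vals = out["actions"], out["values"]              # NumPy arrays via Tensor.asnumpy()
    next_obs, rewards, terminals, infos = envs.step(acts)
    # out["dists"] holds per-sample distributions from split_distributions(); PPG keeps
    # them so the auxiliary phase can later measure a KL term against the current policy.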
9 changes: 6 additions & 3 deletions xuance/mindspore/learners/policy_gradient/ppg_learner.py
@@ -51,10 +51,10 @@ def forward_fn_critic(self, obs_batch, ret_batch):
         loss = self.mse_loss(v_pred, ret_batch)
         return loss, v_pred

-    def forward_fn_auxiliary(self, obs_batch, ret_batch, old_dist):
+    def forward_fn_auxiliary(self, obs_batch, ret_batch, old_probs):
         _, a_dist, v, aux_v = self.policy(obs_batch)
         aux_loss = self.mse_loss(v, aux_v)
-        kl_loss = self._categorical.kl_loss('Categorical', a_dist, old_dist).mean()
+        kl_loss = a_dist.distribution.kl_loss('Categorical', old_probs).mean()
         value_loss = self.mse_loss(v, ret_batch)
         loss = aux_loss + self.kl_beta * kl_loss + value_loss
         return loss, v
@@ -99,9 +99,12 @@ def update_critic(self, **samples):
     def update_auxiliary(self, **samples):
         obs_batch = samples['obs']
         ret_batch = Tensor(samples['returns'])
+        act_batch = Tensor(samples['actions'])
         old_dist = merge_distributions(samples['aux_batch']['old_dist'])
+        old_logp_batch = old_dist.log_prob(act_batch)
+        old_probs = self._exp(old_logp_batch)

-        (loss, v), grads = self.grad_fn_auxiliary(obs_batch, ret_batch, old_dist)
+        (loss, v), grads = self.grad_fn_auxiliary(obs_batch, ret_batch, old_probs)
         self.optimizer(grads)

         info = {
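The auxiliary update now recovers the old action probabilities from the stored distributions (exp of log_prob of the taken actions) and passes them to the distribution's kl_loss, rather than handing a whole distribution object into the graph. The objective being optimized is PPG's joint loss: an auxiliary value term, a KL behavioral-cloning term weighted by kl_beta, and the real value loss. A rough NumPy sketch of that objective, assuming full per-action probability vectors for the old and new policies (illustrative only, not the xuance implementation):

    import numpy as np

    def ppg_joint_loss(probs_old, probs_new, v_pred, aux_v_pred, returns, kl_beta=1.0):
        # KL(old || new) per sample, summed over the discrete action dimension.
        eps = 1e-8
        kl = np.sum(probs_old * (np.log(probs_old + eps) - np.log(probs_new + eps)), axis=-1)
        aux_loss = np.mean((v_pred - aux_v_pred) ** 2)      # mse_loss(v, aux_v)
        value_loss = np.mean((v_pred - returns) ** 2)       # mse_loss(v, ret_batch)
        return aux_loss + kl_beta * np.mean(kl) + value_loss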
34 changes: 30 additions & 4 deletions xuance/mindspore/policies/categorical.py
@@ -71,6 +71,19 @@ def construct(self, observation: Tensor):


 class PPGActorCritic(Module):
+    """
+    Actor-Critic for PPG with categorical distributions. (Discrete action space)
+
+    Args:
+        action_space (Discrete): The discrete action space.
+        representation (Module): The representation module.
+        actor_hidden_size (Sequence[int]): A list of hidden layer sizes for actor network.
+        critic_hidden_size (Sequence[int]): A list of hidden layer sizes for critic network.
+        normalize (Optional[ModuleType]): The layer normalization over a minibatch of inputs.
+        initialize (Optional[Callable[..., Tensor]]): The parameters initializer.
+        activation (Optional[ModuleType]): The activation function for each layer.
+        device (Optional[Union[str, int, torch.device]]): The calculating device.
+    """
     def __init__(self,
                  action_space: Discrete,
                  representation: Module,
@@ -95,12 +95,25 @@ def __init__(self,
                                     normalize, initialize, activation)

     def construct(self, observation: Tensor):
+        """
+        Returns the actors representation output, action distribution, values, and auxiliary values.
+
+        Parameters:
+            observation: The original observation of agent.
+
+        Returns:
+            policy_outputs: The outputs of actor representation.
+            a_dist: The distribution of actions output by actor.
+            value: The state values output by critic.
+            aux_value: The auxiliary values output by aux_critic.
+        """
         policy_outputs = self.actor_representation(observation)
         critic_outputs = self.critic_representation(observation)
-        a = self.actor(policy_outputs['state'])
-        v = self.critic(critic_outputs['state'])
-        aux_v = self.aux_critic(policy_outputs['state'])
-        return policy_outputs, a, v, aux_v
+        aux_critic_outputs = self.aux_critic_representation(observation)
+        a_dist = self.actor(policy_outputs['state'])
+        value = self.critic(critic_outputs['state'])
+        aux_value = self.aux_critic(aux_critic_outputs['state'])
+        return policy_outputs, a_dist, value, aux_value


 # class SACDISPolicy(Module):
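Besides the added docstrings, the substantive change here is that PPGActorCritic now feeds the auxiliary value head from a dedicated aux_critic_representation instead of reusing the actor's features. A compact, self-contained MindSpore sketch of the same three-head layout (the class name, layer sizes, and plain Dense/ReLU stacks are made up for illustration; this is not the xuance class):

    import numpy as np
    import mindspore.nn as nn
    from mindspore import Tensor

    class PPGSketch(nn.Cell):
        # Toy PPG network: separate representations for actor, critic, and aux critic,
        # mirroring actor_representation / critic_representation / aux_critic_representation.
        def __init__(self, obs_dim=8, n_actions=4, hidden=64):
            super().__init__()
            self.actor_repr = nn.SequentialCell(nn.Dense(obs_dim, hidden), nn.ReLU())
            self.critic_repr = nn.SequentialCell(nn.Dense(obs_dim, hidden), nn.ReLU())
            self.aux_critic_repr = nn.SequentialCell(nn.Dense(obs_dim, hidden), nn.ReLU())
            self.actor = nn.Dense(hidden, n_actions)   # categorical logits
            self.critic = nn.Dense(hidden, 1)          # value head
            self.aux_critic = nn.Dense(hidden, 1)      # auxiliary value head

        def construct(self, obs):
            logits = self.actor(self.actor_repr(obs))
            value = self.critic(self.critic_repr(obs))
            aux_value = self.aux_critic(self.aux_critic_repr(obs))
            return logits, value, aux_value

    net = PPGSketch()
    logits, value, aux_value = net(Tensor(np.zeros((2, 8), dtype=np.float32)))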

