
Commit

Merge branch 'dev' of github.com:MushroomRL/mushroom-rl into dev
boris-il-forte committed Jan 23, 2024
2 parents 59aa527 + 861451c commit 146ba8e
Showing 3 changed files with 22 additions and 18 deletions.
@@ -155,7 +155,7 @@ def _line_search(self, obs, act, adv, old_log_prob, old_pol_dist, prev_loss, ste
direction = self._fisher_vector_product(stepdir, obs, old_pol_dist).detach()
shs = .5 * stepdir.dot(direction)
lm = torch.sqrt(shs / self._max_kl())
- full_step = (stepdir / lm).detach().cpu().numpy()
+ full_step = (stepdir / lm).detach()
stepsize = 1.

# Save old policy parameters
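The only change in the hunk above is that full_step now stays a torch tensor instead of being copied back to NumPy; the scaling itself is the standard TRPO maximal step under the quadratic KL constraint. As a reference sketch (standard derivation, not part of this patch), with search direction $s$ = stepdir, Fisher-vector product $As$ = direction and KL bound $\delta$ = self._max_kl():

\[
\tfrac{1}{2}\,\beta^{2}\, s^\top A s \le \delta
\;\Rightarrow\;
\beta = \sqrt{\frac{2\delta}{s^\top A s}},
\qquad
\texttt{shs} = \tfrac{1}{2}\, s^\top A s,\quad
\texttt{lm} = \sqrt{\texttt{shs}/\delta} = 1/\beta,\quad
\texttt{full\_step} = s/\texttt{lm} = \beta\, s .
\]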
22 changes: 13 additions & 9 deletions mushroom_rl/policy/torch_policy.py
@@ -20,6 +20,9 @@ class TorchPolicy(Policy):
required.
"""

+ # TODO: remove TorchUtils.to_float_tensor(array) and update the docstring to replace np.ndarray.

def __init__(self, policy_state_shape=None):
"""
Constructor.
@@ -28,14 +31,14 @@ def __init__(self, policy_state_shape=None):
super().__init__(policy_state_shape)

def __call__(self, state, action, policy_state=None):
- s = TorchUtils.to_float_tensor(np.atleast_2d(state))
- a = TorchUtils.to_float_tensor(np.atleast_2d(action))
+ s = TorchUtils.to_float_tensor(torch.atleast_2d(state))
+ a = TorchUtils.to_float_tensor(torch.atleast_2d(action))

- return np.exp(self.log_prob_t(s, a).item())
+ return torch.exp(self.log_prob_t(s, a))

def draw_action(self, state, policy_state=None):
with torch.no_grad():
- s = TorchUtils.to_float_tensor(np.atleast_2d(state))
+ s = TorchUtils.to_float_tensor(torch.atleast_2d(state))
a = self.draw_action_t(s)

return torch.squeeze(a, dim=0).detach(), None
@@ -71,7 +74,7 @@ def entropy(self, state=None):
"""
s = TorchUtils.to_float_tensor(state) if state is not None else None

- return self.entropy_t(s).detach().cpu().numpy().item()
+ return self.entropy_t(s).detach()

def draw_action_t(self, state):
"""
@@ -189,7 +192,7 @@ def __init__(self, network, input_shape, output_shape, std_0=1., policy_state_sh
self._mu = Regressor(TorchApproximator, input_shape, output_shape, network=network, **params)
self._predict_params = dict()

- log_sigma_init = TorchUtils.to_float_tensor(torch.ones(self._action_dim) * np.log(std_0))
+ log_sigma_init = torch.ones(self._action_dim, device=TorchUtils.get_device()) * torch.log(TorchUtils.to_float_tensor(std_0))

self._log_sigma = nn.Parameter(log_sigma_init)

@@ -207,7 +210,8 @@ def log_prob_t(self, state, action):
return self.distribution_t(state).log_prob(action)[:, None]

def entropy_t(self, state=None):
- return self._action_dim / 2 * np.log(2 * np.pi * np.e) + torch.sum(self._log_sigma)
+ return self._action_dim / 2 * torch.log(TorchUtils.to_float_tensor(2 * np.pi * np.e))\
+     + torch.sum(self._log_sigma)

def distribution_t(self, state):
mu, chol_sigma = self.get_mean_and_chol(state)
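The rewritten entropy_t above is the closed-form differential entropy of a diagonal Gaussian, now evaluated entirely in torch rather than mixing in np.log. For reference (a standard result, not introduced by this commit), with action dimension $d$ and $\sigma_i = \exp(\texttt{log\_sigma}_i)$:

\[
H\big(\mathcal{N}(\mu, \operatorname{diag}(\sigma^2))\big)
= \frac{d}{2}\,\log(2\pi e) + \sum_{i=1}^{d} \log \sigma_i .
\]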
@@ -225,9 +229,9 @@ def set_weights(self, weights):

def get_weights(self):
mu_weights = self._mu.get_weights()
- sigma_weights = self._log_sigma.data.detach().cpu().numpy()
+ sigma_weights = self._log_sigma.data.detach()

- return np.concatenate([mu_weights, sigma_weights])
+ return torch.concatenate([mu_weights, sigma_weights])

def parameters(self):
return chain(self._mu.model.network.parameters(), [self._log_sigma])
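With these changes, __call__, entropy() and get_weights() return torch tensors rather than NumPy arrays, so callers that still need NumPy have to convert explicitly. A minimal sketch of that adjustment (hypothetical caller code, mirroring the conversion pattern used in the updated tests below; policy is assumed to be a GaussianTorchPolicy instance and state/action are torch tensors):

    p_sa = policy(state, action)                # now a torch tensor instead of a Python float
    p_sa_np = p_sa.detach().cpu().numpy()       # convert only where NumPy is still required

    weights = policy.get_weights()              # torch tensor instead of np.ndarray
    weights_np = weights.detach().cpu().numpy()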
16 changes: 8 additions & 8 deletions tests/policy/test_torch_policy.py
@@ -59,14 +59,14 @@ def test_gaussian_torch_policy():
torch.manual_seed(88)
pi = GaussianTorchPolicy(Network, (3,), (2,), n_features=50)

- state = np.random.rand(3)
+ state = torch.as_tensor(np.random.rand(3))
action, _ = pi.draw_action(state)
action_test = np.array([-0.21276927, 0.27437747])
- assert np.allclose(action, action_test)
+ assert np.allclose(action.detach().cpu().numpy(), action_test)

- p_sa = pi(state, action)
+ p_sa = pi(state, torch.as_tensor(action))
p_sa_test = 0.07710557966732147
- assert np.allclose(p_sa, p_sa_test)
+ assert np.allclose(p_sa.detach().cpu().numpy(), p_sa_test)

entropy = pi.entropy()
entropy_test = 2.837877
@@ -79,16 +79,16 @@ def test_boltzmann_torch_policy():
beta = Parameter(1.0)
pi = BoltzmannTorchPolicy(Network, (3,), (2,), beta, n_features=50)

- state = np.random.rand(3, 3)
+ state = torch.as_tensor(np.random.rand(3, 3))
action, _ = pi.draw_action(state)
action_test = np.array([1, 0, 0])
- assert np.allclose(action, action_test)
+ assert np.allclose(action.detach().cpu().numpy(), action_test)

p_sa = pi(state[0], action[0])
p_sa_test = 0.24054041611818922
- assert np.allclose(p_sa, p_sa_test)
+ assert np.allclose(p_sa.detach(), p_sa_test)

states = np.random.rand(1000, 3)
entropy = pi.entropy(states)
entropy_test = 0.5428627133369446
- assert np.allclose(entropy, entropy_test)
+ assert np.allclose(entropy.detach().cpu().numpy(), entropy_test)

