Skip to content

Commit 2f5309b

Browse files
committed
Actor Critic finished
1 parent 1d5d739 commit 2f5309b

File tree

5 files changed

+749
-196
lines changed

5 files changed

+749
-196
lines changed

.gitignore

+1
Original file line numberDiff line numberDiff line change
@@ -146,3 +146,4 @@ ActorCritic/out
146146
ActorCritic/runs
147147

148148
runs/
149+
thrash/

ActorCritic/Actor Critic Notes.ipynb

+734-60
Large diffs are not rendered by default.

ActorCritic/actor_critic.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -98,8 +98,9 @@ def critic_update(self, V, V_target):
9898

9999
def actor_update(self, advantages, actions, mus):
100100
self.actor_optimizer.zero_grad()
101-
actor_loss = advantages * self.actor_loss(actions, mus)
102-
actor_loss.backward()
101+
actor_loss = self.actor_loss(actions, mus)
102+
gradient_term = advantages * actor_loss
103+
gradient_term.backward()
103104
self.actor_optimizer.step()
104105

105106
return actor_loss.item()

ActorCritic/main.py

-129
This file was deleted.

ActorCritic/policy_gradient.py

+11-5
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,15 @@ def load_checkpoint(self, filepath):
6868
self.policy_net.load_state_dict(checkpoint['policy'])
6969
self.optimizer.load_state_dict(checkpoint['optimizer'])
7070

71+
def actor_update(self, Q, action, mu):
72+
self.actor_optimizer.zero_grad()
73+
actor_loss = self.actor_loss(action, mu)
74+
gradient_term = Q * actor_loss
75+
gradient_term.backward()
76+
self.actor_optimizer.step()
77+
78+
return actor_loss.item()
79+
7180
def update(self, sample, prepare_state=None):
7281
"""
7382
prepare_state is a function that does feature engineering on the plain state
@@ -85,11 +94,8 @@ def update(self, sample, prepare_state=None):
8594

8695
mu = self.policy_net(state)
8796

88-
self.optimizer.zero_grad()
89-
loss = Q * self.loss(action, mu)
90-
loss.backward()
91-
self.optimizer.step()
92-
episode_running_loss.append(loss.item())
97+
loss = self.actor_update(Q, action, mu)
98+
episode_running_loss.append(loss)
9399

94100
return episode_running_loss
95101

0 commit comments

Comments
 (0)