Commit 540aa44

Apply black==23.3.0 and isort==5.11.5
1 parent ee0f363 commit 540aa44

Note: this is a large commit, so some file diffs are collapsed by default and are not reproduced below.

53 files changed (+8, -113 lines)
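
The commit only re-runs the two formatters at pinned versions; it does not change runtime behavior. As a rough sketch (not part of the commit) of how such a pass could be reproduced locally, assuming black==23.3.0 and isort==5.11.5 are installed on PATH; the format_tree helper and the default target path are illustrative only:

import subprocess

def format_tree(root: str = ".") -> None:
    # Run isort first, then black, so the final layout is whatever black==23.3.0 settles on.
    subprocess.run(["isort", root], check=True)
    subprocess.run(["black", root], check=True)

if __name__ == "__main__":
    format_tree()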

examples/atari/reproduction/a3c/train_a3c.py (-2)

@@ -16,7 +16,6 @@


 def main():
-
     parser = argparse.ArgumentParser()
     parser.add_argument("--processes", type=int, default=16)
     parser.add_argument("--env", type=str, default="BreakoutNoFrameskip-v4")
@@ -176,7 +175,6 @@ def phi(x):
             )
         )
     else:
-
         # Linearly decay the learning rate to zero
         def lr_setter(env, agent, value):
             for pg in agent.optimizer.param_groups:

examples/atari/train_acer_ale.py (-2)

@@ -19,7 +19,6 @@


 def main():
-
     parser = argparse.ArgumentParser()
     parser.add_argument("processes", type=int)
     parser.add_argument("--env", type=str, default="BreakoutNoFrameskip-v4")
@@ -185,7 +184,6 @@ def make_env(process_idx, test):
             )
         )
     else:
-
         # Linearly decay the learning rate to zero
         def lr_setter(env, agent, value):
             for pg in agent.optimizer.param_groups:

examples/atlas/train_soft_actor_critic_atlas.py (-1)

@@ -45,7 +45,6 @@ def make_env(args, seed, test):


 def main():
-
     parser = argparse.ArgumentParser()
     parser.add_argument(
         "--outdir",

examples/gym/train_dqn_gym.py (-1)

@@ -210,7 +210,6 @@ def make_env(idx=0, test=False):
         )

     elif not args.actor_learner:
-
         print(
             "WARNING: Since https://github.com/pfnet/pfrl/pull/112 we have started"
             " setting `eval_during_episode=True` in this script, which affects the"

examples/mujoco/reproduction/ddpg/train_ddpg.py (-1)

@@ -22,7 +22,6 @@


 def main():
-
     parser = argparse.ArgumentParser()
     parser.add_argument(
         "--outdir",

examples/mujoco/reproduction/soft_actor_critic/train_soft_actor_critic.py (-1)

@@ -21,7 +21,6 @@


 def main():
-
     parser = argparse.ArgumentParser()
     parser.add_argument(
         "--outdir",

examples/mujoco/reproduction/td3/train_td3.py (-1)

@@ -19,7 +19,6 @@


 def main():
-
     parser = argparse.ArgumentParser()
     parser.add_argument(
         "--outdir",

examples/mujoco/reproduction/trpo/train_trpo.py (-2)

@@ -16,7 +16,6 @@


 def main():
-
     parser = argparse.ArgumentParser()
     parser.add_argument(
         "--gpu", type=int, default=0, help="GPU device ID. Set to -1 to use CPUs only."
@@ -215,7 +214,6 @@ def ortho_init(layer, gain):
         with open(os.path.join(args.outdir, "demo_scores.json"), "w") as f:
             json.dump(eval_stats, f)
     else:
-
         pfrl.experiments.train_agent_with_evaluation(
             agent=agent,
             env=env,

pfrl/agents/a2c.py (-1)

@@ -71,7 +71,6 @@ def __init__(
         average_value_decay=0.999,
         batch_states=batch_states,
     ):
-
         self.model = model
         if gpu is not None and gpu >= 0:
             assert torch.cuda.is_available()

pfrl/agents/a3c.py (-2)

@@ -64,7 +64,6 @@ def __init__(
         average_value_decay=0.999,
         batch_states=batch_states,
     ):
-
         # Globally shared model
         self.shared_model = model

@@ -241,7 +240,6 @@ def observe(self, obs, reward, done, reset):
             self._observe_eval(obs, reward, done, reset)

     def _act_train(self, obs):
-
         self.past_obs[self.t] = obs

         with torch.no_grad():

pfrl/agents/acer.py (-5)

@@ -332,7 +332,6 @@ def __init__(
         average_kl_decay=0.999,
         logger=None,
     ):
-
         # Globally shared model
         self.shared_model = model

@@ -472,7 +471,6 @@ def compute_loss(
         action_distribs_mu,
         avg_action_distribs,
     ):
-
         assert np.isscalar(R)
         pi_loss = 0
         Q_loss = 0
@@ -566,7 +564,6 @@ def update(
         action_distribs_mu,
         avg_action_distribs,
     ):
-
         assert np.isscalar(R)
         self.assert_shared_memory()

@@ -595,7 +592,6 @@ def update(
         self.sync_parameters()

     def update_from_replay(self):
-
         if self.replay_buffer is None:
             return

@@ -715,7 +711,6 @@ def observe(self, obs, reward, done, reset):
             self._observe_eval(obs, reward, done, reset)

     def _act_train(self, obs):
-
         statevar = batch_states([obs], self.device, self.phi)

         if self.recurrent:

pfrl/agents/al.py (-1)

@@ -21,7 +21,6 @@ def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)

     def _compute_y_and_t(self, exp_batch):
-
         batch_state = exp_batch["state"]
         batch_size = len(exp_batch["reward"])

pfrl/agents/ddpg.py (-3)

@@ -81,7 +81,6 @@ def __init__(
         batch_states=batch_states,
         burnin_action_func=None,
     ):
-
         self.model = nn.ModuleList([policy, q_func])
         if gpu is not None and gpu >= 0:
             assert torch.cuda.is_available()
@@ -223,7 +222,6 @@ def update_from_episodes(self, episodes, errors_out=None):
             batches.append(batch)

         with self.model.state_reset(), self.target_model.state_reset():
-
             # Since the target model is evaluated one-step ahead,
             # its internal states need to be updated
             self.target_q_function.update_state(
@@ -238,7 +236,6 @@ def update_from_episodes(self, episodes, errors_out=None):
             self.critic_optimizer.update(lambda: critic_loss / max_epi_len)

         with self.model.state_reset():
-
             # Update actor through time
             actor_loss = 0
             for batch in batches:

pfrl/agents/double_dqn.py (-1)

@@ -10,7 +10,6 @@ class DoubleDQN(dqn.DQN):
     """

     def _compute_target_values(self, exp_batch):
-
         batch_next_state = exp_batch["next_state"]

         with evaluating(self.model):

pfrl/agents/double_pal.py (-1)

@@ -6,7 +6,6 @@


 class DoublePAL(pal.PAL):
     def _compute_y_and_t(self, exp_batch):
-
         batch_state = exp_batch["state"]
         batch_size = len(exp_batch["reward"])

pfrl/agents/dpp.py (-2)

@@ -17,7 +17,6 @@ def _l_operator(self, qout):
         raise NotImplementedError()

     def _compute_target_values(self, exp_batch):
-
         batch_next_state = exp_batch["next_state"]

         if self.recurrent:
@@ -38,7 +37,6 @@ def _compute_target_values(self, exp_batch):
         )

     def _compute_y_and_t(self, exp_batch):
-
         batch_state = exp_batch["state"]
         batch_size = len(exp_batch["reward"])

pfrl/agents/dqn.py (+7 -21)

@@ -3,8 +3,8 @@
 import ctypes
 import multiprocessing as mp
 import multiprocessing.synchronize
-import time
 import os
+import time
 from logging import Logger, getLogger
 from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple

@@ -512,7 +512,6 @@ def _batch_observe_train(
         batch_done: Sequence[bool],
         batch_reset: Sequence[bool],
     ) -> None:
-
         for i in range(len(batch_obs)):
             self.t += 1
             self._cumulative_steps += 1
@@ -793,34 +792,21 @@ def stop_episode(self) -> None:

     def save_snapshot(self, dirname: str) -> None:
         self.save(dirname)
-        torch.save(
-            self.t, os.path.join(dirname, "t.pt")
-        )
-        torch.save(
-            self.optim_t, os.path.join(dirname, "optim_t.pt")
-        )
+        torch.save(self.t, os.path.join(dirname, "t.pt"))
+        torch.save(self.optim_t, os.path.join(dirname, "optim_t.pt"))
         torch.save(
             self._cumulative_steps, os.path.join(dirname, "_cumulative_steps.pt")
         )
-        self.replay_buffer.save(
-            os.path.join(dirname, "replay_buffer.pkl")
-        )
-
+        self.replay_buffer.save(os.path.join(dirname, "replay_buffer.pkl"))

     def load_snapshot(self, dirname: str) -> None:
         self.load(dirname)
-        self.t = torch.load(
-            os.path.join(dirname, "t.pt")
-        )
-        self.optim_t = torch.load(
-            os.path.join(dirname, "optim_t.pt")
-        )
+        self.t = torch.load(os.path.join(dirname, "t.pt"))
+        self.optim_t = torch.load(os.path.join(dirname, "optim_t.pt"))
         self._cumulative_steps = torch.load(
             os.path.join(dirname, "_cumulative_steps.pt")
         )
-        self.replay_buffer.load(
-            os.path.join(dirname, "replay_buffer.pkl")
-        )
+        self.replay_buffer.load(os.path.join(dirname, "replay_buffer.pkl"))

     def get_statistics(self):
         return [

pfrl/agents/pal.py (-1)

@@ -21,7 +21,6 @@ def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)

     def _compute_y_and_t(self, exp_batch):
-
         batch_state = exp_batch["state"]
         batch_size = len(exp_batch["reward"])

pfrl/agents/ppo.py (-3)

@@ -115,7 +115,6 @@ def _add_log_prob_and_value_to_episodes(
     obs_normalizer,
     device,
 ):
-
     dataset = list(itertools.chain.from_iterable(episodes))

     # Compute v_pred and next_v_pred
@@ -533,7 +532,6 @@ def _update(self, dataset):
             self.n_updates += 1

     def _update_once_recurrent(self, episodes, mean_advs, std_advs):
-
         assert std_advs is None or std_advs > 0

         device = self.device
@@ -636,7 +634,6 @@ def _update_recurrent(self, dataset):
     def _lossfun(
         self, entropy, vs_pred, log_probs, vs_pred_old, log_probs_old, advs, vs_teacher
     ):
-
         prob_ratio = torch.exp(log_probs - log_probs_old)

         loss_policy = -torch.mean(

pfrl/agents/reinforce.py (-2)

@@ -57,7 +57,6 @@ def __init__(
         max_grad_norm=None,
         logger=None,
     ):
-
         self.model = model
         if gpu is not None and gpu >= 0:
             assert torch.cuda.is_available()
@@ -103,7 +102,6 @@ def observe(self, obs, reward, done, reset):
             self._observe_eval(obs, reward, done, reset)

     def _act_train(self, obs):
-
         batch_obs = self.batch_states([obs], self.device, self.phi)
         if self.recurrent:
             action_distrib, self.train_recurrent_states = one_step_forward(

pfrl/agents/soft_actor_critic.py (-1)

@@ -119,7 +119,6 @@ def __init__(
         temperature_optimizer_lr=None,
         act_deterministically=True,
     ):
-
         self.policy = policy
         self.q_func1 = q_func1
         self.q_func2 = q_func2

pfrl/agents/td3.py (-1)

@@ -101,7 +101,6 @@ def __init__(
         policy_update_delay=2,
         target_policy_smoothing_func=default_target_policy_smoothing_func,
     ):
-
         self.policy = policy
         self.q_func1 = q_func1
         self.q_func2 = q_func2

pfrl/agents/trpo.py (-3)

@@ -193,7 +193,6 @@ def __init__(
         policy_step_size_stats_window=100,
         logger=getLogger(__name__),
     ):
-
         self.policy = policy
         self.vf = vf
         self.vf_optimizer = vf_optimizer
@@ -335,7 +334,6 @@ def _update_recurrent(self, dataset):
         self._update_vf_recurrent(dataset)

     def _update_vf_recurrent(self, dataset):
-
         for epoch in range(self.vf_epochs):
             random.shuffle(dataset)
             for (
@@ -346,7 +344,6 @@ def _update_vf_recurrent(self, dataset):
                 self._update_vf_once_recurrent(minibatch)

     def _update_vf_once_recurrent(self, episodes):
-
         # Sort episodes desc by length for pack_sequence
         episodes = sorted(episodes, key=len, reverse=True)

pfrl/experiments/train_agent.py (-2)

@@ -35,7 +35,6 @@ def train_agent(
     eval_during_episode=False,
     logger=None,
 ):
-
     logger = logger or logging.getLogger(__name__)

     episode_r = 0
@@ -52,7 +51,6 @@ def train_agent(
     episode_len = 0
     try:
         while t < steps:
-
             # a_t
             action = agent.act(obs)
             # o_{t+1}, r_{t+1}
