-
Notifications
You must be signed in to change notification settings - Fork 10
/
configs.py
executable file
·44 lines (32 loc) · 1.08 KB
/
configs.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import collections
class OnPiTrainConfig(collections.namedtuple(
        # The typename must match the class name: namedtuple bakes it into the
        # generated base class, its docstring and the ``__new__`` qualname, so
        # a mismatched name makes argument errors point at the wrong class.
        "OnPiTrainConfig",
        ("models", "optimizer", "gamma", "init_epi", "total_epi", "device"))):
    """Immutable bundle of settings for the "OnPi" trainer.

    NOTE(review): "OnPi" presumably means on-policy and ``*_epi`` fields
    presumably count episodes -- confirm against the training loop.

    Fields (positional order is part of the interface):
        models, optimizer, gamma, init_epi, total_epi, device
    """
class TrainConfig(collections.namedtuple(
        "TrainConfig",
        ("batch_size", "models", "optimizer", "grad_clip", "gamma",
         "init_step", "total_steps", "init_epi",
         "buffer_device", "batch_device", "max_step_per_epi", "save_f"))):
    """Immutable bundle of hyper-parameters and objects for step-based training.

    Fields (positional order is part of the interface):
        batch_size, models, optimizer, grad_clip, gamma, init_step,
        total_steps, init_epi, buffer_device, batch_device,
        max_step_per_epi, save_f
    """

    @staticmethod
    def _format_optional(value, spec):
        """Return ``spec % value``, or the string "None" when *value* is None."""
        # Optional settings left as None must not make print_info raise
        # TypeError (``"%d" % None`` / ``"%.1f" % None`` fail).
        return "None" if value is None else spec % value

    def print_info(self):
        """Print batch size, grad clip, gamma and max_step_per_epi to stdout.

        Numeric values use the same formats as before (``%d`` / ``%.1f`` /
        ``%.2f``); fields set to None are printed as ``None``.
        """
        print("Batch size: %s\nGrad clip: %s\ngamma: %s" % (
            self._format_optional(self.batch_size, "%d"),
            self._format_optional(self.grad_clip, "%.1f"),
            self._format_optional(self.gamma, "%.2f")))
        print("max_step_per_epi=%s"
              % self._format_optional(self.max_step_per_epi, "%d"))
class EpsilonGreedyConfig(
        collections.namedtuple(
            typename="EpsilonGreedyConfig",
            field_names="eps_schedule_f action_space_f eps_print_diff t_exploration",
        )):
    """Immutable settings for epsilon-greedy action selection.

    NOTE(review): field meanings are inferred from their names -- confirm:
        eps_schedule_f: presumably a callable yielding epsilon over time.
        action_space_f: presumably a callable producing/sampling actions.
        eps_print_diff: presumably the epsilon change that triggers a log line.
        t_exploration: presumably the length of the exploration phase.
    """
# Supervised configs
class DataConfig(collections.namedtuple("DataConfig", [])):
    """Field-less placeholder for data-related configuration.

    Instances are empty tuples; add field names to the namedtuple call
    when data settings are needed.
    """
class DataTrainConfig(collections.namedtuple(
        # Typename matches the class name so the generated base class,
        # docstring and ``__new__`` error messages name DataTrainConfig
        # (it was "TrainConfig", colliding with the step-based config).
        "DataTrainConfig",
        ("optimizer", "criterion", "device", "save_f",
         "train_d_loader", "train_d_size", "test_d_loader", "test_d_size"))):
    """Immutable bundle of settings for supervised training.

    NOTE(review): ``*_d_loader`` / ``*_d_size`` are presumably the data
    loaders and dataset sizes for the train/test splits -- confirm.

    Fields (positional order is part of the interface):
        optimizer, criterion, device, save_f,
        train_d_loader, train_d_size, test_d_loader, test_d_size
    """