-
Notifications
You must be signed in to change notification settings - Fork 16
/
run_cl.py
141 lines (131 loc) · 4.32 KB
/
run_cl.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
from typing import Callable, Iterable, List
from continualworld.envs import get_cl_env, get_single_env
from continualworld.methods.vcl import VclMlpActor
from continualworld.sac.models import MlpActor
from continualworld.sac.utils.logx import EpochLogger
from continualworld.tasks import TASK_SEQS
from continualworld.utils.enums import BufferType
from continualworld.utils.run_utils import get_sac_class
from continualworld.utils.utils import get_activation_from_str
from input_args import cl_parse_args
def main(
    logger: EpochLogger,
    tasks: str,
    task_list: List[str],
    seed: int,
    steps_per_task: int,
    log_every: int,
    replay_size: int,
    batch_size: int,
    hidden_sizes: Iterable[int],
    buffer_type: str,
    reset_buffer_on_task_change: bool,
    reset_optimizer_on_task_change: bool,
    activation: Callable,
    use_layer_norm: bool,
    lr: float,
    gamma: float,
    alpha: str,
    target_output_std: float,
    cl_method: str,
    packnet_retrain_steps: int,
    regularize_critic: bool,
    cl_reg_coef: float,
    vcl_first_task_kl: bool,
    episodic_mem_per_task: int,
    episodic_batch_size: int,
    reset_critic_on_task_change: bool,
    multihead_archs: bool,
    hide_task_id: bool,
    clipnorm: float,
    agent_policy_exploration: bool,
):
    """Build the SAC agent selected by ``cl_method`` and run it on a task sequence.

    The task sequence comes either from ``TASK_SEQS[tasks]`` or directly from
    ``task_list``; exactly one of the two must be given (the other is ``None``).

    Args:
        logger: Logger handed to the SAC class.
        tasks: Key into ``TASK_SEQS`` naming a predefined sequence, or ``None``.
        task_list: Explicit sequence of task names, or ``None``.
        steps_per_task: Passed to ``get_cl_env``; the total number of steps
            given to the agent is ``steps_per_task * number_of_tasks``.
        hidden_sizes, use_layer_norm, hide_task_id: Placed in both the actor
            and the critic keyword arguments.
        activation: Passed through ``get_activation_from_str`` before being
            placed in the actor/critic keyword arguments.
            NOTE(review): the ``Callable`` annotation looks inaccurate given
            that helper's name; presumably this is a string. Confirm.
        multihead_archs: When true the networks get one head per task,
            otherwise a single head.
        buffer_type: Converted with ``BufferType(buffer_type)``.
        cl_method: ``None``, ``"l2"``, ``"ewc"``, ``"mas"``, ``"vcl"``,
            ``"packnet"`` or ``"agem"``. ``"vcl"`` uses ``VclMlpActor`` as the
            actor class; every other value uses ``MlpActor``.
        cl_reg_coef: Forwarded for ``l2``/``ewc``/``mas``/``vcl`` only.
        regularize_critic: Forwarded for ``l2``/``ewc``/``mas``/``vcl``/
            ``packnet`` only.
        vcl_first_task_kl: Forwarded as ``first_task_kl`` for ``vcl`` only.
        packnet_retrain_steps: Forwarded as ``retrain_steps`` for ``packnet``
            only.
        episodic_mem_per_task, episodic_batch_size: Forwarded for ``agem``
            only.
        seed, log_every, replay_size, batch_size, reset_buffer_on_task_change,
        reset_optimizer_on_task_change, lr, gamma, alpha, target_output_std,
        reset_critic_on_task_change, clipnorm, agent_policy_exploration:
            Forwarded unchanged to the SAC class constructor.

    Raises:
        ValueError: If ``tasks`` and ``task_list`` are both ``None`` or both
            provided.
        NotImplementedError: If ``cl_method`` is not one of the values above.
    """
    # Explicit check instead of `assert`: assertions are stripped under
    # `python -O`, which would let an invalid combination slip through.
    if (tasks is None) == (task_list is None):
        raise ValueError("Exactly one of `tasks` and `task_list` must be provided.")

    task_names = TASK_SEQS[tasks] if tasks is not None else task_list

    train_env = get_cl_env(task_names, steps_per_task)
    # Consider normalizing test envs in the future.
    num_tasks = len(task_names)
    test_envs = [
        get_single_env(task, one_hot_idx=i, one_hot_len=num_tasks)
        for i, task in enumerate(task_names)
    ]
    steps = steps_per_task * num_tasks

    num_heads = num_tasks if multihead_archs else 1

    def _network_kwargs():
        # Actor and critic share the same settings; a fresh dict is built for
        # each so the two networks never share one kwargs object.
        return dict(
            hidden_sizes=hidden_sizes,
            activation=get_activation_from_str(activation),
            use_layer_norm=use_layer_norm,
            num_heads=num_heads,
            hide_task_id=hide_task_id,
        )

    actor_kwargs = _network_kwargs()
    critic_kwargs = _network_kwargs()
    actor_cl = VclMlpActor if cl_method == "vcl" else MlpActor

    vanilla_sac_kwargs = {
        "env": train_env,
        "test_envs": test_envs,
        "logger": logger,
        "seed": seed,
        "steps": steps,
        "log_every": log_every,
        "replay_size": replay_size,
        "batch_size": batch_size,
        "actor_cl": actor_cl,
        "actor_kwargs": actor_kwargs,
        "critic_kwargs": critic_kwargs,
        "buffer_type": BufferType(buffer_type),
        "reset_buffer_on_task_change": reset_buffer_on_task_change,
        "reset_optimizer_on_task_change": reset_optimizer_on_task_change,
        "lr": lr,
        "alpha": alpha,
        "reset_critic_on_task_change": reset_critic_on_task_change,
        "clipnorm": clipnorm,
        "gamma": gamma,
        "target_output_std": target_output_std,
        "agent_policy_exploration": agent_policy_exploration,
    }

    sac_class = get_sac_class(cl_method)

    # Pick the method-specific constructor arguments; the common ones above
    # are passed to every method.
    if cl_method is None:
        method_kwargs = {}
    elif cl_method in ("l2", "ewc", "mas"):
        method_kwargs = {
            "cl_reg_coef": cl_reg_coef,
            "regularize_critic": regularize_critic,
        }
    elif cl_method == "vcl":
        method_kwargs = {
            "cl_reg_coef": cl_reg_coef,
            "regularize_critic": regularize_critic,
            "first_task_kl": vcl_first_task_kl,
        }
    elif cl_method == "packnet":
        method_kwargs = {
            "regularize_critic": regularize_critic,
            "retrain_steps": packnet_retrain_steps,
        }
    elif cl_method == "agem":
        method_kwargs = {
            "episodic_mem_per_task": episodic_mem_per_task,
            "episodic_batch_size": episodic_batch_size,
        }
    else:
        raise NotImplementedError("This method is not implemented")

    sac = sac_class(**vanilla_sac_kwargs, **method_kwargs)
    sac.run()
if __name__ == "__main__":
    # Parse the command line into a plain dict of options.
    cli_options = vars(cl_parse_args())
    # The logger receives the full option dict as its config, including the
    # two logger-only entries that are removed just below.
    run_logger = EpochLogger(
        cli_options["logger_output"],
        config=cli_options,
        group_id=cli_options["group_id"],
    )
    # These two options are consumed by the logger and are not parameters of
    # main(), so drop them before unpacking the rest as keyword arguments.
    for logger_only_key in ("group_id", "logger_output"):
        del cli_options[logger_only_key]
    main(run_logger, **cli_options)