-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathtraining_run.py
110 lines (85 loc) · 3.79 KB
/
training_run.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import json
from argparse import Namespace
from pathlib import Path
import numpy as np
import yaml
from sample_factory.algorithms.utils.algo_utils import EXTRA_PER_POLICY_SUMMARIES
from sample_factory.utils.utils import log
import wandb
from sample_factory.envs.env_registry import global_env_registry
from sample_factory.run_algorithm import run_algorithm
import sys
from sample_factory.algorithms.appo.model_utils import register_custom_encoder
from models import ResnetEncoderWithTarget
from create_env import make_iglu
from utils.config_validation import Experiment
from torch.multiprocessing import Pool, Process, set_start_method
#sys.path.append('../')
#from agents.mhb_baseline.generator import DialogueFigure
def iglu_extra_summaries(policy_id, policy_avg_stats, env_steps, summary_writer, cfg):
    """Write averaged IGLU-specific per-policy stats to tensorboard.

    Registered via EXTRA_PER_POLICY_SUMMARIES; sample-factory calls it with the
    per-policy rolling stats. Any stat whose key contains one of the tracked
    metric names is averaged with np.mean and logged under its original key.

    Args:
        policy_id: id of the policy being summarized (unused here).
        policy_avg_stats: mapping of stat key -> sequence of recent values.
        env_steps: current env step count, used as the tensorboard x-axis.
        summary_writer: tensorboard SummaryWriter-like object (needs add_scalar).
        cfg: run configuration (unused here).
    """
    # NOTE(review): 'CoplitedRate' looks like a typo for 'CompletedRate', but it
    # must match the stat key emitted by the env — confirm before renaming.
    tracked = ('SuccessRate', 'steps_do', 'wins', 'CoplitedRate')
    for key in policy_avg_stats:
        # The original code branched on the metric name but computed the exact
        # same np.mean in both branches; a single membership test is equivalent
        # and also avoids writing a key twice if it matched two metric names.
        if any(metric in key for metric in tracked):
            avg = np.mean(policy_avg_stats[key])
            summary_writer.add_scalar(key, avg, env_steps)
            log.debug(f'{key}: {round(float(avg), 3)}')
def make_env(full_env_name, cfg=None, env_config=None):
    """Env factory hook for the sample-factory registry.

    The three parameters are required by the registry's factory signature but
    are not used here — the IGLU env is built with its own defaults.
    """
    env = make_iglu()
    return env
def register_custom_components():
    """Wire the project-specific pieces into sample-factory.

    Registers the IGLU env factory under the 'IGLUSilentBuilder-v0' prefix,
    registers the ResNet-with-target encoder under 'custom_env_encoder', and
    hooks the IGLU extra summaries into per-policy tensorboard logging.
    """
    registry = global_env_registry()
    registry.register_env(
        env_name_prefix='IGLUSilentBuilder-v0',
        make_env_func=make_env,
    )
    register_custom_encoder('custom_env_encoder', ResnetEncoderWithTarget)
    EXTRA_PER_POLICY_SUMMARIES.append(iglu_extra_summaries)
def validate_config(config):
    """Validate a raw config dict against the Experiment schema.

    Returns a tuple (exp, flat_config): the validated Experiment object and a
    flat argparse-style Namespace merging the async_ppo, experiment_settings,
    global_settings and evaluation sections, with the full nested config kept
    under the 'full_config' key (the shape run_algorithm expects).
    """
    exp = Experiment(**config)
    merged = {
        **exp.async_ppo.dict(),
        **exp.experiment_settings.dict(),
        **exp.global_settings.dict(),
        **exp.evaluation.dict(),
        'full_config': exp.dict(),
    }
    return exp, Namespace(**merged)
def main():
    """Entry point: parse CLI args, resolve the config, run training.

    Config priority is: --raw_config (JSON string) > --config_path (YAML file)
    > Experiment schema defaults. Optionally logs to wandb and, after training,
    zips the experiment directory and uploads it to wandb.

    Returns:
        The status code from run_algorithm (propagated to sys.exit by the
        __main__ guard).
    """
    register_custom_components()
    import argparse

    def _str2bool(value):
        # BUG FIX: the original used type=bool, but bool('False') is True, so
        # any non-empty value enabled thread mode. This converter accepts the
        # usual textual spellings while keeping the `--wandb_thread_mode VALUE`
        # CLI shape backward-compatible.
        if isinstance(value, bool):
            return value
        return str(value).strip().lower() in ('1', 'true', 'yes', 'y', 'on')

    parser = argparse.ArgumentParser(description='Process training config.')
    parser.add_argument('--config_path', type=str, action="store",
                        help='path to yaml file with single run configuration', required=False)
    parser.add_argument('--raw_config', type=str, action='store',
                        help='raw json config', required=False)
    parser.add_argument('--wandb_thread_mode', type=_str2bool, action='store', default=False,
                        help='Run wandb in thread mode. Useful for some setups.', required=False)
    params = parser.parse_args()

    if params.raw_config:
        config = json.loads(params.raw_config)
    elif params.config_path is None:
        # No config given: fall back to the Experiment schema defaults.
        config = Experiment().dict()
    else:
        with open(params.config_path, "r") as f:
            config = yaml.safe_load(f)

    exp, flat_config = validate_config(config)
    log.debug(exp.global_settings.experiments_root)

    if exp.global_settings.use_wandb:
        import os
        if params.wandb_thread_mode:
            os.environ["WANDB_START_METHOD"] = "thread"
        wandb.login(key=os.environ.get('WANDB_APIKEY'))
        wandb.init(project=exp.name, config=exp.dict(), save_code=False, sync_tensorboard=True)

    status = run_algorithm(flat_config)

    if exp.global_settings.use_wandb:
        import shutil
        # Zip the whole experiment dir and push the archive to wandb so the
        # run's artifacts survive the machine.
        path = Path(exp.global_settings.train_dir) / exp.global_settings.experiments_root
        zip_name = str(path)
        shutil.make_archive(zip_name, 'zip', path)
        wandb.save(zip_name + '.zip')
    return status
if __name__ == '__main__':
    # Propagate the training status code returned by main() to the shell.
    sys.exit(main())