-
Notifications
You must be signed in to change notification settings - Fork 13
/
main.py
42 lines (28 loc) · 1020 Bytes
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
# -*- coding: utf-8 -*-
import argparse
import logging
from hbconfig import Config
import tensorflow as tf
import experiment
def main(mode):
    """Launch a tf.contrib.learn experiment configured from ``Config``.

    Args:
        mode: name of the learn_runner schedule to execute; forwarded
            unchanged as ``schedule`` (the CLI help lists train, test and
            train_and_evaluate).
    """
    # Hyper-parameters are taken verbatim from the ``model`` section of Config.
    hparams = tf.contrib.training.HParams(**Config.model.to_dict())
    # The run configuration only sets model_dir, read from Config.train.
    config = tf.contrib.learn.RunConfig(model_dir=Config.train.model_dir)

    runner_kwargs = dict(
        experiment_fn=experiment.experiment_fn,
        run_config=config,
        schedule=mode,
        hparams=hparams,
    )
    tf.contrib.learn.learn_runner.run(**runner_kwargs)
if __name__ == '__main__':
    # Command-line entry point: parse flags, initialise Config, then run main().
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--config', type=str, default='config',
                        help='config file name')
    parser.add_argument('--mode', type=str, default='train',
                        help='Mode (train/test/train_and_evaluate)')
    args = parser.parse_args()

    # Use TF's public verbosity setter rather than reaching into the private
    # ``tf.logging._logger`` attribute, which may not be exposed on every
    # TF 1.x release. tf.logging.INFO has the same value as logging.INFO.
    tf.logging.set_verbosity(tf.logging.INFO)

    # Initialise the shared ``Config`` with the config name from the command
    # line; main() then reads Config.model and Config.train from it.
    Config(args.config)
    print("Config: ", Config)
    main(args.mode)