-
Notifications
You must be signed in to change notification settings - Fork 2
/
task_retrain.py
36 lines (28 loc) · 1 KB
/
task_retrain.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
# encoding: utf-8
import os
from tensorflow.python.platform import gfile
import tensorflow as tf
import Network
import time
# UTC wall-clock time captured once, when this module is loaded, formatted as
# "YYYY-MM-DD--HH-MM-SS"; used as the per-run sub-directory name in train().
current_time = time.strftime("%Y-%m-%d--%H-%M-%S", time.gmtime(time.time()))
def train():
    """Retrain the network on the filtered GTA off-road voxel CSV.

    Points the ``Network`` module's file settings at the off-road data set and
    its output-directory settings at per-run paths named after
    ``current_time``. It then makes sure those output directories exist,
    builds a ``Network.Network`` and calls its ``train()`` method.
    """
    Network.TRAIN_FILE = "train-voxel-gta-offroad-filtered.csv"
    # NOTE(review): TEST_FILE is the same CSV as TRAIN_FILE, so any evaluation
    # Network performs is on training data -- confirm this is intentional.
    Network.TEST_FILE = "train-voxel-gta-offroad-filtered.csv"
    Network.PREDICT_DIR = os.path.join('predict-offroad', current_time)
    Network.CHECKPOINT_DIR = os.path.join('checkpoint-offroad', current_time)
    # main() checks/creates PREDICT_DIR and CHECKPOINT_DIR before calling this
    # function, i.e. before they are overridden above, so the per-run
    # directories were not created there. Create them now that they are set.
    for directory in (Network.PREDICT_DIR, Network.CHECKPOINT_DIR):
        if not gfile.Exists(directory):
            gfile.MakeDirs(directory)
    network = Network.Network()
    network.train()
def main(argv=None):
    """Create the working directories if missing, then run ``train()``.

    Args:
        argv: Arguments passed in by ``tf.app.run``; not used here.
    """
    def ensure_dir(path):
        # Create ``path`` (via gfile.MakeDirs) only when it does not exist yet.
        if gfile.Exists(path):
            return
        gfile.MakeDirs(path)

    for local_dir in ("./train", "./test"):
        ensure_dir(local_dir)
    # Each Network setting is read right before its directory is checked.
    # NOTE(review): these are the values Network holds *before* train()
    # overrides PREDICT_DIR / CHECKPOINT_DIR -- confirm that is intended.
    for setting in ("PREDICT_DIR", "CHECKPOINT_DIR", "LOGS_DIR"):
        ensure_dir(getattr(Network, setting))
    train()
if __name__ == '__main__':
    # Hand control to tf.app.run, naming this module's main() explicitly
    # instead of relying on its default lookup of __main__.main.
    tf.app.run(main=main)