Skip to content

Commit

Permalink
fix: Use section header from command line
Browse files Browse the repository at this point in the history
  • Loading branch information
shahrukhqasim committed Dec 5, 2017
1 parent 1f905a3 commit c9cbffb
Showing 1 changed file with 12 additions and 11 deletions.
23 changes: 12 additions & 11 deletions python/table_parse_2d/parser_2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,18 +41,18 @@ def load_datum(self, full_path):
return input_tensor, left_class_one_hot, top_class_one_hot, loss_mask

class Parser2d:
def __init__(self):
def __init__(self, config_section):
config = cp.ConfigParser()
config.read('config.ini')
self.train_path = config['zone_segment']['train_data_path']
self.test_path = config['zone_segment']['test_data_path']
self.validation_data_path = config['zone_segment']['validation_data_path']
self.learning_rate = float(config['zone_segment']['learning_rate'])
self.save_after = int(config['zone_segment']['save_after'])
self.model_path = config['zone_segment']['model_path']
self.from_scratch = int(config['zone_segment']['from_scratch']) == 1
self.batch_size = int(config['zone_segment']['batch_size'])
self.summary_path = config['zone_segment']['summary_path']
self.train_path = config[config_section]['train_data_path']
self.test_path = config[config_section]['test_data_path']
self.validation_data_path = config[config_section]['validation_data_path']
self.learning_rate = float(config[config_section]['learning_rate'])
self.save_after = int(config[config_section]['save_after'])
self.model_path = config[config_section]['model_path']
self.from_scratch = int(config[config_section]['from_scratch']) == 1
self.batch_size = int(config[config_section]['batch_size'])
self.summary_path = config[config_section]['summary_path']

self.alpha_left = float(config['zone_segment']['alpha_left'])
self.alpha_top = float(config['zone_segment']['alpha_top'])
Expand Down Expand Up @@ -172,6 +172,7 @@ def train(self):

if not self.from_scratch:
self.saver_all.restore(sess, self.model_path)
print("\n\nINFO: Saving model\n\n")
with open(self.model_path+'.txt', 'r') as f:
iteration = int(f.read())
else:
Expand Down Expand Up @@ -220,6 +221,6 @@ def train(self):
print("\tFor example, if the section name in the config file is `zone_segment` you can issue:\n"
"\tpython table_parse_2d/parser_2d.py zone_segment")
sys.exit(-1)
parser = Parser2d()
parser = Parser2d(sys.argv[1])
parser.construct_graphs()
parser.train()

0 comments on commit c9cbffb

Please sign in to comment.