Skip to content

Commit

Permalink
Merge pull request #58 from javoweb/feat/tensorboard
Browse files Browse the repository at this point in the history
feat: TensorBoard + Pre-processing for TFOD workflow
  • Loading branch information
rushtehrani authored Jan 18, 2021
2 parents dece0ed + 1f415b1 commit 034a7b5
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 11 deletions.
35 changes: 26 additions & 9 deletions workflows/tf-object-detection-training/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,29 +29,45 @@ def main(params):
for f in files:
shutil.move(model_dir+'/'+f,'/mnt/data/models')

if params['from_preprocessing']:
train_set = '/tfrecord/train.tfrecord*'
eval_set = '/tfrecord/eval.tfrecord*'
else:
train_set = '/*.tfrecord'
eval_set = '/default.tfrecord'

params = create_pipeline('/mnt/data/models/pipeline.config',
'/mnt/data/models/model.ckpt',
params['dataset']+'/label_map.pbtxt',
params['dataset']+'/*.tfrecord',
params['dataset']+'/default.tfrecord',
params['dataset']+train_set,
params['dataset']+eval_set,
'/mnt/output/pipeline.config',
params)

os.chdir('/mnt/output')
os.mkdir('eval/')
subprocess.call(['python',
'/mnt/src/tf/research/object_detection/legacy/train.py',
'--train_dir=/mnt/output/',
directory = 'eval/'
try:
os.stat(directory)
except:
os.mkdir(directory)
return_code = subprocess.call(['python',
'/mnt/src/tf/research/object_detection/model_main.py',
'--alsologtostderr',
'--model_dir=/mnt/output/',
'--pipeline_config_path=/mnt/output/pipeline.config',
'--num_clones={}'.format(params['num_clones']),
'--num_train_steps={}'.format(params['epochs'])
])
subprocess.call(['python',
if return_code != 0:
raise RuntimeError('Training process failed')
return_code = subprocess.call(['python',
'/mnt/src/tf/research/object_detection/export_inference_graph.py',
'--input_type=image_tensor',
'--pipeline_config_path=/mnt/output/pipeline.config',
'--trained_checkpoint_prefix=/mnt/output/model.ckpt-{}'.format(params['epochs']),
'--output_directory=/mnt/output'
])
if return_code != 0:
raise RuntimeError('Model export process failed')

# generate label map
convert_labels_to_csv(params['dataset'])
Expand All @@ -60,10 +76,11 @@ def main(params):
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='Train TFOD.')
parser.add_argument('--dataset', default='/mnt/data/datasets')
parser.add_argument('--extras', help='hyperparameters or other configs')
parser.add_argument('--extras', default='', help='hyperparameters or other configs')
parser.add_argument('--sys_finetune_checkpoint', default=' ', help='path to checkpoint')
parser.add_argument('--model', default='frcnn-res50-coco', help='which model to train')
parser.add_argument('--num_classes', default=81, type=int, help='number of classes')
parser.add_argument('--from_preprocessing', default=False, type=bool)
args = parser.parse_args()
# parse parameters
# sample: epochs=100;num_classes=1
Expand Down
17 changes: 16 additions & 1 deletion workflows/tf-object-detection-training/template.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ arguments:

- name: hyperparameters
value: |-
num-steps=10000
num_steps=10000
displayName: Hyperparameters
visibility: public
type: textarea.textarea
Expand Down Expand Up @@ -145,6 +145,21 @@ templates:
path: /mnt/output
s3:
key: '{{workflow.namespace}}/{{workflow.parameters.cvat-output-path}}/{{workflow.name}}'
sidecars:
- name: tensorboard
image: tensorflow/tensorflow:2.3.0
command:
- sh
- '-c'
env:
- name: ONEPANEL_INTERACTIVE_SIDECAR
value: 'true'
args:
# Read logs from /mnt/output - this directory is auto-mounted from volumeMounts
- tensorboard --logdir /mnt/output/train/
ports:
- containerPort: 6006
name: tensorboard
volumeClaimTemplates:
- metadata:
name: data
Expand Down
4 changes: 3 additions & 1 deletion workflows/tf-object-detection-training/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,8 @@ def process_params(params):
model_architecture = 'ssd'
else:
model_architecture = 'frcnn'


model_params['eval_interval_secs'] = 3000
model_params['epochs'] = params.pop('num_steps')

for key in params.keys():
Expand Down Expand Up @@ -181,6 +182,7 @@ def create_pipeline(pipeline_path, model_path, label_path,

pipeline_config.eval_input_reader[0].label_map_path=label_path
pipeline_config.eval_input_reader[0].tf_record_input_reader.input_path[0]=eval_tfrecord_path
pipeline_config.eval_config.eval_interval_secs = int(model_params['eval_interval_secs'])

config_text = text_format.MessageToString(pipeline_config)
with tf.gfile.Open(out_pipeline_path, 'wb') as f:
Expand Down

0 comments on commit 034a7b5

Please sign in to comment.