Skip to content

Commit

Permalink
Merge pull request #79 from javoweb/fix/tfod/checkpoint/matcher
Browse files Browse the repository at this point in the history
Fix/tfod/checkpoint/matcher
  • Loading branch information
rushtehrani authored Jun 28, 2021
2 parents 8892cbb + 5b13867 commit 96b15db
Showing 1 changed file with 26 additions and 4 deletions.
30 changes: 26 additions & 4 deletions workflows/tf-object-detection-training/main.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
import glob
import yaml
import shutil
import tarfile
Expand All @@ -10,6 +11,20 @@

from utils import convert_labels_to_csv, create_pipeline

def is_checkpoint_directory(dir):
matched_files = get_checkpoint_files(dir)
return True if len(matched_files)>0 else False

def get_last_checkpoint_filename(dir):
matched_files = get_checkpoint_files(dir)
matched_files.sort()
match_file = matched_files[-1].split('/')[-1]
return '.'.join(match_file.split('.')[0:2])

def get_checkpoint_files(dir):
search_expression = os.path.join(dir, 'model.ckpt*')
return [f for f in glob.glob(search_expression) if os.path.isfile(f)]


def main(params):

Expand Down Expand Up @@ -43,13 +58,20 @@ def main(params):
files = os.listdir(files_dir)
for f in files:
shutil.move(os.path.join(files_dir , f),model_dir)
elif os.path.isfile(os.path.join(model_dir , 'output/model/model.ckpt.index')):
elif is_checkpoint_directory(os.path.join(model_dir , 'output/model')):
model_dir = os.path.join(model_dir , 'output/model')
elif os.path.isfile(os.path.join(model_dir , 'model/model.ckpt.index')):
elif is_checkpoint_directory(os.path.join(model_dir , 'output/checkpoint')):
model_dir = os.path.join(model_dir , 'output/checkpoint')
elif is_checkpoint_directory(os.path.join(model_dir , 'model')):
model_dir = os.path.join(model_dir , 'model')
elif not os.path.isfile(os.path.join(model_dir , 'model.ckpt.index')):
elif is_checkpoint_directory(os.path.join(model_dir , 'checkpoint')):
model_dir = os.path.join(model_dir , 'checkpoint')
elif not is_checkpoint_directory(model_dir):
raise ValueError("No valid checkpoint found")

checkpoint_name = get_last_checkpoint_filename(model_dir)
print(checkpoint_name)

if params['from_preprocessing']:
train_set = 'tfrecord/train.tfrecord*'
eval_set = 'tfrecord/eval.tfrecord*'
Expand All @@ -58,7 +80,7 @@ def main(params):
eval_set = 'default.tfrecord'

params = create_pipeline(os.path.join(model_dir , 'pipeline.config'),
os.path.join(model_dir , 'model.ckpt'),
os.path.join(model_dir , checkpoint_name),
os.path.join(data_dir, 'label_map.pbtxt'),
os.path.join(data_dir, train_set),
os.path.join(data_dir, eval_set),
Expand Down

0 comments on commit 96b15db

Please sign in to comment.