Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

batch transcribe dir (recursively) with transcribe.py #2879

Merged
merged 2 commits into from
Apr 3, 2020
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 60 additions & 34 deletions transcribe.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
tflogging.set_verbosity(tflogging.ERROR)
import logging
logging.getLogger('sox').setLevel(logging.ERROR)
import glob

from deepspeech_training.util.audio import AudioFile
from deepspeech_training.util.config import Config, initialize_globals
Expand Down Expand Up @@ -75,13 +76,20 @@ def transcribe_file(audio_path, tlog_path):
json.dump(transcripts, tlog_file, default=float)


def transcribe_many(path_pairs):
pbar = create_progressbar(prefix='Transcribing files | ', max_value=len(path_pairs)).start()
for i, (src_path, dst_path) in enumerate(path_pairs):
p = Process(target=transcribe_file, args=(src_path, dst_path))
def transcribe_many(paths, kind):
pbar = create_progressbar(prefix='Transcribing files | ', max_value=len(paths)).start()
if kind == 'dir':
# the user pointed to a dir of files
src_paths = paths
dst_paths = [ path.replace('.wav','.tlog') for path in paths ]
elif kind == 'catalog':
# the user pointed to a catalog dir from DSAlign
src_paths,dst_paths = zip(*paths)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why so complicated? This if block and the kind parameter are not required if all the case-specific preparations would just be done in the corresponding cases in main - just before calling transcribe_many.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

un-complicat-ified the function:)

for i in range(len(paths)):
p = Process(target=transcribe_file, args=(src_paths[i], dst_paths[i]))
p.start()
p.join()
log_progress('Transcribed file {} of {} from "{}" to "{}"'.format(i + 1, len(path_pairs), src_path, dst_path))
log_progress('Transcribed file {} of {} from "{}" to "{}"'.format(i + 1, len(paths), src_paths[i], dst_paths[i]))
pbar.update(i)
pbar.finish()

Expand All @@ -100,47 +108,65 @@ def resolve(base_path, spec_path):


def main(_):
if not FLAGS.src:
if not FLAGS.src or not os.path.exists(FLAGS.src):
# path not given or non-existant
fail('You have to specify which file or catalog to transcribe via the --src flag.')
src_path = os.path.abspath(FLAGS.src)
if not os.path.isfile(src_path):
fail('Path in --src not existing')
if src_path.endswith('.catalog'):
if FLAGS.dst:
fail('Parameter --dst not supported if --src points to a catalog')
catalog_dir = os.path.dirname(src_path)
with open(src_path, 'r') as catalog_file:
catalog_entries = json.load(catalog_file)
catalog_entries = [(resolve(catalog_dir, e['audio']), resolve(catalog_dir, e['tlog'])) for e in catalog_entries]
if any(map(lambda e: not os.path.isfile(e[0]), catalog_entries)):
fail('Missing source file(s) in catalog')
if not FLAGS.force and any(map(lambda e: os.path.isfile(e[1]), catalog_entries)):
fail('Destination file(s) from catalog already existing, use --force for overwriting')
if any(map(lambda e: not os.path.isdir(os.path.dirname(e[1])), catalog_entries)):
fail('Missing destination directory for at least one catalog entry')
transcribe_many(catalog_entries)
else:
dst_path = os.path.abspath(FLAGS.dst) if FLAGS.dst else os.path.splitext(src_path)[0] + '.tlog'
if os.path.isfile(dst_path):
if FLAGS.force:
transcribe_one(src_path, dst_path)
# path given and exists
src_path = os.path.abspath(FLAGS.src)
if os.path.isfile(src_path):
if src_path.endswith('.catalog'):
# Transcribe batch of files via ".catalog" file (from DSAlign)
if FLAGS.dst:
fail('Parameter --dst not supported if --src points to a catalog')
catalog_dir = os.path.dirname(src_path)
with open(src_path, 'r') as catalog_file:
catalog_entries = json.load(catalog_file)
catalog_entries = [(resolve(catalog_dir, e['audio']), resolve(catalog_dir, e['tlog'])) for e in catalog_entries]
if any(map(lambda e: not os.path.isfile(e[0]), catalog_entries)):
fail('Missing source file(s) in catalog')
if not FLAGS.force and any(map(lambda e: os.path.isfile(e[1]), catalog_entries)):
fail('Destination file(s) from catalog already existing, use --force for overwriting')
if any(map(lambda e: not os.path.isdir(os.path.dirname(e[1])), catalog_entries)):
fail('Missing destination directory for at least one catalog entry')
transcribe_many(catalog_entries, 'catalog')
else:
fail('Destination file "{}" already existing - use --force for overwriting'.format(dst_path), code=0)
elif os.path.isdir(os.path.dirname(dst_path)):
transcribe_one(src_path, dst_path)
else:
fail('Missing destination directory')
# Transcribe one file
dst_path = os.path.abspath(FLAGS.dst) if FLAGS.dst else os.path.splitext(src_path)[0] + '.tlog'
if os.path.isfile(dst_path):
if FLAGS.force:
transcribe_one(src_path, dst_path)
else:
fail('Destination file "{}" already existing - use --force for overwriting'.format(dst_path), code=0)
elif os.path.isdir(os.path.dirname(dst_path)):
transcribe_one(src_path, dst_path)
else:
fail('Missing destination directory')
elif os.path.isdir(src_path):
# Transcribe all files in dir
print("Transcribing all WAV files in --src")
if FLAGS.dst:
fail('Destination file not supported for batch decoding jobs.')
else:
if not FLAGS.recursive:
print("If you wish to recursively scan --src, then you must use --recursive")
wav_paths = glob.glob(src_path + "/*.wav")
else:
wav_paths = glob.glob(src_path + "/**/*.wav")
transcribe_many(wav_paths, 'dir')


if __name__ == '__main__':
create_flags()
tf.app.flags.DEFINE_string('src', '', 'source path to an audio file or directory to recursively scan '
'for audio files. If --dst not set, transcription logs (.tlog) will be '
tf.app.flags.DEFINE_string('src', '', 'Source path to an audio file or directory or catalog file.'
'Catalog files should be formatted from DSAlign. A directory will'
'be recursively searched for audio. If --dst not set, transcription logs (.tlog) will be '
'written in-place using the source filenames with '
'suffix ".tlog" instead of ".wav".')
tf.app.flags.DEFINE_string('dst', '', 'path for writing the transcription log or logs (.tlog). '
'If --src is a directory, this one also has to be a directory '
'and the required sub-dir tree of --src will get replicated.')
tf.app.flags.DEFINE_boolean('recursive', False, 'scan dir of audio recursively')
tf.app.flags.DEFINE_boolean('force', False, 'Forces re-transcribing and overwriting of already existing '
'transcription logs (.tlog)')
tf.app.flags.DEFINE_integer('vad_aggressiveness', 3, 'How aggressive (0=lowest, 3=highest) the VAD should '
Expand Down