Skip to content

Commit

Permalink
Merge pull request #2879 from JRMeyer/batch-transcribe
Browse files Browse the repository at this point in the history
batch transcribe dir (recursively) with transcribe.py
  • Loading branch information
tilmankamp authored Apr 3, 2020
2 parents 0cee75c + cad03d3 commit 510e29f
Showing 1 changed file with 55 additions and 34 deletions.
89 changes: 55 additions & 34 deletions transcribe.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
tflogging.set_verbosity(tflogging.ERROR)
import logging
logging.getLogger('sox').setLevel(logging.ERROR)
import glob

from deepspeech_training.util.audio import AudioFile
from deepspeech_training.util.config import Config, initialize_globals
Expand Down Expand Up @@ -75,13 +76,13 @@ def transcribe_file(audio_path, tlog_path):
json.dump(transcripts, tlog_file, default=float)


def transcribe_many(path_pairs):
pbar = create_progressbar(prefix='Transcribing files | ', max_value=len(path_pairs)).start()
for i, (src_path, dst_path) in enumerate(path_pairs):
p = Process(target=transcribe_file, args=(src_path, dst_path))
def transcribe_many(src_paths,dst_paths):
pbar = create_progressbar(prefix='Transcribing files | ', max_value=len(src_paths)).start()
for i in range(len(src_paths)):
p = Process(target=transcribe_file, args=(src_paths[i], dst_paths[i]))
p.start()
p.join()
log_progress('Transcribed file {} of {} from "{}" to "{}"'.format(i + 1, len(path_pairs), src_path, dst_path))
log_progress('Transcribed file {} of {} from "{}" to "{}"'.format(i + 1, len(src_paths), src_paths[i], dst_paths[i]))
pbar.update(i)
pbar.finish()

Expand All @@ -100,47 +101,67 @@ def resolve(base_path, spec_path):


def main(_):
if not FLAGS.src:
if not FLAGS.src or not os.path.exists(FLAGS.src):
# path not given or non-existant
fail('You have to specify which file or catalog to transcribe via the --src flag.')
src_path = os.path.abspath(FLAGS.src)
if not os.path.isfile(src_path):
fail('Path in --src not existing')
if src_path.endswith('.catalog'):
if FLAGS.dst:
fail('Parameter --dst not supported if --src points to a catalog')
catalog_dir = os.path.dirname(src_path)
with open(src_path, 'r') as catalog_file:
catalog_entries = json.load(catalog_file)
catalog_entries = [(resolve(catalog_dir, e['audio']), resolve(catalog_dir, e['tlog'])) for e in catalog_entries]
if any(map(lambda e: not os.path.isfile(e[0]), catalog_entries)):
fail('Missing source file(s) in catalog')
if not FLAGS.force and any(map(lambda e: os.path.isfile(e[1]), catalog_entries)):
fail('Destination file(s) from catalog already existing, use --force for overwriting')
if any(map(lambda e: not os.path.isdir(os.path.dirname(e[1])), catalog_entries)):
fail('Missing destination directory for at least one catalog entry')
transcribe_many(catalog_entries)
else:
dst_path = os.path.abspath(FLAGS.dst) if FLAGS.dst else os.path.splitext(src_path)[0] + '.tlog'
if os.path.isfile(dst_path):
if FLAGS.force:
transcribe_one(src_path, dst_path)
# path given and exists
src_path = os.path.abspath(FLAGS.src)
if os.path.isfile(src_path):
if src_path.endswith('.catalog'):
# Transcribe batch of files via ".catalog" file (from DSAlign)
if FLAGS.dst:
fail('Parameter --dst not supported if --src points to a catalog')
catalog_dir = os.path.dirname(src_path)
with open(src_path, 'r') as catalog_file:
catalog_entries = json.load(catalog_file)
catalog_entries = [(resolve(catalog_dir, e['audio']), resolve(catalog_dir, e['tlog'])) for e in catalog_entries]
if any(map(lambda e: not os.path.isfile(e[0]), catalog_entries)):
fail('Missing source file(s) in catalog')
if not FLAGS.force and any(map(lambda e: os.path.isfile(e[1]), catalog_entries)):
fail('Destination file(s) from catalog already existing, use --force for overwriting')
if any(map(lambda e: not os.path.isdir(os.path.dirname(e[1])), catalog_entries)):
fail('Missing destination directory for at least one catalog entry')
src_paths,dst_paths = zip(*paths)
transcribe_many(src_paths,dst_paths)
else:
fail('Destination file "{}" already existing - use --force for overwriting'.format(dst_path), code=0)
elif os.path.isdir(os.path.dirname(dst_path)):
transcribe_one(src_path, dst_path)
else:
fail('Missing destination directory')
# Transcribe one file
dst_path = os.path.abspath(FLAGS.dst) if FLAGS.dst else os.path.splitext(src_path)[0] + '.tlog'
if os.path.isfile(dst_path):
if FLAGS.force:
transcribe_one(src_path, dst_path)
else:
fail('Destination file "{}" already existing - use --force for overwriting'.format(dst_path), code=0)
elif os.path.isdir(os.path.dirname(dst_path)):
transcribe_one(src_path, dst_path)
else:
fail('Missing destination directory')
elif os.path.isdir(src_path):
# Transcribe all files in dir
print("Transcribing all WAV files in --src")
if FLAGS.dst:
fail('Destination file not supported for batch decoding jobs.')
else:
if not FLAGS.recursive:
print("If you wish to recursively scan --src, then you must use --recursive")
wav_paths = glob.glob(src_path + "/*.wav")
else:
wav_paths = glob.glob(src_path + "/**/*.wav")
dst_paths = [path.replace('.wav','.tlog') for path in wav_paths]
transcribe_many(wav_paths,dst_paths)


if __name__ == '__main__':
create_flags()
tf.app.flags.DEFINE_string('src', '', 'source path to an audio file or directory to recursively scan '
'for audio files. If --dst not set, transcription logs (.tlog) will be '
tf.app.flags.DEFINE_string('src', '', 'Source path to an audio file or directory or catalog file.'
'Catalog files should be formatted from DSAlign. A directory will'
'be recursively searched for audio. If --dst not set, transcription logs (.tlog) will be '
'written in-place using the source filenames with '
'suffix ".tlog" instead of ".wav".')
tf.app.flags.DEFINE_string('dst', '', 'path for writing the transcription log or logs (.tlog). '
'If --src is a directory, this one also has to be a directory '
'and the required sub-dir tree of --src will get replicated.')
tf.app.flags.DEFINE_boolean('recursive', False, 'scan dir of audio recursively')
tf.app.flags.DEFINE_boolean('force', False, 'Forces re-transcribing and overwriting of already existing '
'transcription logs (.tlog)')
tf.app.flags.DEFINE_integer('vad_aggressiveness', 3, 'How aggressive (0=lowest, 3=highest) the VAD should '
Expand Down

0 comments on commit 510e29f

Please sign in to comment.