4
4
from typing import Optional
5
5
from .models .Subtitles import Subtitles , SegmentsIterable
6
6
from .utils .files import filename , write_srt
7
- from .utils .ffmpeg import get_audio , add_subtitles , preprocess_audio
7
+ from .utils .ffmpeg import get_audio , add_subtitles , preprocess_audio , file_has_audio
8
8
from .utils .whisper import WhisperAI
9
9
from .translation .easynmt_utils import EasyNMTWrapper
10
10
@@ -39,7 +39,7 @@ def process(args: dict):
39
39
"subtitle_type" : args .pop ("subtitle_type" )
40
40
}
41
41
42
- videos = args .pop ('video' )
42
+ paths_to_process = args .pop ('video' )
43
43
audio_channel = args .pop ('audio_channel' )
44
44
model_args = {
45
45
"model_size_or_path" : model_name ,
@@ -51,16 +51,42 @@ def process(args: dict):
51
51
device = model_args ['device' ]) if target_language != 'en' else None
52
52
53
53
os .makedirs (output_args ["output_dir" ], exist_ok = True )
54
- for video in videos :
55
- if video .endswith ('.wav' ):
56
- audio = preprocess_audio (video , audio_channel , sample_interval )
57
- else :
58
- audio = get_audio (video , audio_channel , sample_interval )
54
+ for path_to_process in paths_to_process :
55
+ process_path (audio_channel , language , output_args , path_to_process , sample_interval ,
56
+ target_language , transcribe_model , translate_model )
59
57
60
- transcribed , translated = perform_task (video , audio , language , target_language ,
61
- transcribe_model , translate_model )
62
58
63
- save_result (video , transcribed , translated , sample_interval , output_args )
59
+ def process_path (audio_channel , language , output_args , path_to_process , sample_interval ,
60
+ target_language , transcribe_model , translate_model ):
61
+ if not os .path .exists (path_to_process ):
62
+ logger .error ("File %s does not exist." , path_to_process )
63
+ return
64
+
65
+ if not os .path .isdir (path_to_process ):
66
+ process_file (audio_channel , language , output_args , sample_interval , target_language ,
67
+ transcribe_model , translate_model , path_to_process )
68
+ return
69
+
70
+ logger .info ("Processing all files in directory %s" , path_to_process )
71
+ for file_name in os .listdir (path_to_process ):
72
+ process_file (audio_channel , language , output_args , sample_interval , target_language ,
73
+ transcribe_model , translate_model , os .path .join (path_to_process , file_name ))
74
+
75
+
76
+ def process_file (audio_channel , language , output_args , sample_interval , target_language ,
77
+ transcribe_model , translate_model , file_name ):
78
+ if not file_has_audio (file_name ):
79
+ logger .info ("File %s has no audio, skipping." , file_name )
80
+ return
81
+
82
+ if file_name .endswith ('.wav' ):
83
+ audio = preprocess_audio (file_name , audio_channel , sample_interval )
84
+ else :
85
+ audio = get_audio (file_name , audio_channel , sample_interval )
86
+
87
+ transcribed , translated = perform_task (file_name , audio , language , target_language ,
88
+ transcribe_model , translate_model )
89
+ save_result (file_name , transcribed , translated , sample_interval , output_args )
64
90
65
91
66
92
def save_result (video : str , transcribed : Subtitles , translated : Subtitles , sample_interval : list ,
@@ -84,9 +110,7 @@ def perform_task(video: str, audio: str, language: str, target_language: str,
84
110
transcribed = get_subtitles (video , audio , transcribe_model )
85
111
translated = None
86
112
87
- logger .info ('Subtitles generated.' )
88
113
if target_language != 'en' :
89
- logger .info ('Translating subtitles... This might take a while.' )
90
114
translated = translate_subtitles (
91
115
transcribed , language , target_language , translate_model )
92
116
@@ -99,8 +123,11 @@ def translate_subtitles(subtitles: Subtitles, source_lang: str, target_lang: str
99
123
if src_lang == '' or src_lang is None :
100
124
src_lang = subtitles .language
101
125
126
+ segments = list (subtitles .segments )
127
+ logger .info ('Subtitles generated.' )
128
+ logger .info ('Translating subtitles... This might take a while.' )
102
129
translated_segments = model .translate (
103
- list ( subtitles . segments ) , src_lang , target_lang )
130
+ segments , src_lang , target_lang )
104
131
105
132
return Subtitles (SegmentsIterable (translated_segments ), target_lang )
106
133
0 commit comments