From b796e66890fb73112d78c6b233c37d75371f72e2 Mon Sep 17 00:00:00 2001 From: shadowcz007 Date: Tue, 1 Oct 2024 17:57:54 +0800 Subject: [PATCH 1/6] Create SenseVoice.py --- nodes/SenseVoice.py | 105 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 105 insertions(+) create mode 100644 nodes/SenseVoice.py diff --git a/nodes/SenseVoice.py b/nodes/SenseVoice.py new file mode 100644 index 0000000..1880e80 --- /dev/null +++ b/nodes/SenseVoice.py @@ -0,0 +1,105 @@ +# -*- coding:utf-8 -*- +# @FileName :sense_voice.py.py +# @Time :2024/7/18 15:40 +# @Author :lovemefan +# @Email :lovemefan@outlook.com +import argparse +import logging +import os +import time + +import soundfile as sf +from huggingface_hub import snapshot_download + +from sensevoice.onnx.sense_voice_ort_session import SenseVoiceInferenceSession +from sensevoice.utils.frontend import WavFrontend +from sensevoice.utils.fsmn_vad import FSMNVad + +languages = {"auto": 0, "zh": 3, "en": 4, "yue": 7, "ja": 11, "ko": 12, "nospeech": 13} +formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" +logging.basicConfig(format=formatter, level=logging.INFO) + + +def main(): + arg_parser = argparse.ArgumentParser(description="Sense Voice") + arg_parser.add_argument("-a", "--audio_file", required=True, type=str, help="Model") + download_model_path = os.path.join(os.path.dirname(__file__), "resource") + arg_parser.add_argument( + "-dp", + "--download_path", + default=download_model_path, + type=str, + help="dir path of resource downloaded", + ) + arg_parser.add_argument("-d", "--device", default=-1, type=int, help="Device") + arg_parser.add_argument( + "-n", "--num_threads", default=4, type=int, help="Num threads" + ) + arg_parser.add_argument( + "-l", + "--language", + choices=languages.keys(), + default="auto", + type=str, + help="Language", + ) + arg_parser.add_argument("--use_itn", action="store_true", help="Use ITN") + arg_parser.add_argument( + "--use_int8", action="store_true", help="Use int8 onnx model" + ) + args = arg_parser.parse_args() + + if not os.path.exists(download_model_path): + logging.info( + "Downloading model from huggingface hub from https://huggingface.co/lovemefan/SenseVoice-onnx" + ) + logging.info( + "You can speed up with `export HF_ENDPOINT=https://hf-mirror.com`" + ) + snapshot_download( + repo_id="lovemefan/SenseVoice-onnx", local_dir=download_model_path + ) + + front = WavFrontend(os.path.join(download_model_path, "am.mvn")) + + model = SenseVoiceInferenceSession( + os.path.join(download_model_path, "embedding.npy"), + os.path.join( + download_model_path, + "sense-voice-encoder-int8.onnx" + if args.use_int8 + else "sense-voice-encoder.onnx", + ), + os.path.join(download_model_path, "chn_jpn_yue_eng_ko_spectok.bpe.model"), + args.device, + args.num_threads, + ) + waveform, _sample_rate = sf.read( + args.audio_file, + dtype="float32", + always_2d=True + ) + + logging.info(f"Audio {args.audio_file} is {len(waveform) / _sample_rate} seconds, {waveform.shape[1]} channel") + # load vad model + start = time.time() + vad = FSMNVad(download_model_path) + for channel_id, channel_data in enumerate(waveform.T): + segments = vad.segments_offline(channel_data) + results = "" + for part in segments: + audio_feats = front.get_features(channel_data[part[0] * 16 : part[1] * 16]) + asr_result = model( + audio_feats[None, ...], + language=languages[args.language], + use_itn=args.use_itn, + ) + logging.info(f"[Channel {channel_id}] [{part[0] / 1000}s - {part[1] / 1000}s] {asr_result}") + vad.vad.all_reset_detection() + decoding_time = time.time() - start + logging.info(f"Decoder audio takes {decoding_time} seconds") + logging.info(f"The RTF is {decoding_time/(waveform.shape[1] * len(waveform) / _sample_rate)}.") + + +if __name__ == "__main__": + main() \ No newline at end of file From a70a9b4bb1a656cba87cae216e82dc572b36a6fa Mon Sep 17 00:00:00 2001 From: shadowcz007 Date: Tue, 1 Oct 2024 20:41:05 +0800 Subject: [PATCH 2/6] update --- __init__.py | 9 ++ nodes/SenseVoice.py | 227 ++++++++++++++++++++++++++------------------ 2 files changed, 146 insertions(+), 90 deletions(-) diff --git a/__init__.py b/__init__.py index ef740bc..aedea7b 100644 --- a/__init__.py +++ b/__init__.py @@ -1429,6 +1429,15 @@ def mix_status(request): except Exception as e: logging.info('FishSpeech.available False' ) +try: + from .nodes.SenseVoice import SenseVoiceNode + logging.info('SenseVoice.available') + NODE_CLASS_MAPPINGS['SenseVoiceNode']=SenseVoiceNode + NODE_DISPLAY_NAME_MAPPINGS["SenseVoiceNode"]= "Sense Voice" + +except Exception as e: + logging.info('SenseVoice.available False' ) + logging.info('\033[93m -------------- \033[0m') diff --git a/nodes/SenseVoice.py b/nodes/SenseVoice.py index 1880e80..fa96ebe 100644 --- a/nodes/SenseVoice.py +++ b/nodes/SenseVoice.py @@ -1,105 +1,152 @@ # -*- coding:utf-8 -*- -# @FileName :sense_voice.py.py -# @Time :2024/7/18 15:40 -# @Author :lovemefan -# @Email :lovemefan@outlook.com -import argparse import logging import os import time -import soundfile as sf from huggingface_hub import snapshot_download - +import torch from sensevoice.onnx.sense_voice_ort_session import SenseVoiceInferenceSession from sensevoice.utils.frontend import WavFrontend from sensevoice.utils.fsmn_vad import FSMNVad +import comfy.utils # 假设这个模块包含ProgressBar +from comfy.model_management import get_torch_device # 假设这个函数在这个模块中 +import folder_paths languages = {"auto": 0, "zh": 3, "en": 4, "yue": 7, "ja": 11, "ko": 12, "nospeech": 13} -formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" -logging.basicConfig(format=formatter, level=logging.INFO) - - -def main(): - arg_parser = argparse.ArgumentParser(description="Sense Voice") - arg_parser.add_argument("-a", "--audio_file", required=True, type=str, help="Model") - download_model_path = os.path.join(os.path.dirname(__file__), "resource") - arg_parser.add_argument( - "-dp", - "--download_path", - default=download_model_path, - type=str, - help="dir path of resource downloaded", - ) - arg_parser.add_argument("-d", "--device", default=-1, type=int, help="Device") - arg_parser.add_argument( - "-n", "--num_threads", default=4, type=int, help="Num threads" - ) - arg_parser.add_argument( - "-l", - "--language", - choices=languages.keys(), - default="auto", - type=str, - help="Language", - ) - arg_parser.add_argument("--use_itn", action="store_true", help="Use ITN") - arg_parser.add_argument( - "--use_int8", action="store_true", help="Use int8 onnx model" - ) - args = arg_parser.parse_args() - - if not os.path.exists(download_model_path): - logging.info( - "Downloading model from huggingface hub from https://huggingface.co/lovemefan/SenseVoice-onnx" - ) - logging.info( - "You can speed up with `export HF_ENDPOINT=https://hf-mirror.com`" - ) - snapshot_download( - repo_id="lovemefan/SenseVoice-onnx", local_dir=download_model_path - ) - front = WavFrontend(os.path.join(download_model_path, "am.mvn")) - - model = SenseVoiceInferenceSession( - os.path.join(download_model_path, "embedding.npy"), - os.path.join( - download_model_path, - "sense-voice-encoder-int8.onnx" - if args.use_int8 - else "sense-voice-encoder.onnx", - ), - os.path.join(download_model_path, "chn_jpn_yue_eng_ko_spectok.bpe.model"), - args.device, - args.num_threads, - ) - waveform, _sample_rate = sf.read( - args.audio_file, - dtype="float32", - always_2d=True - ) - - logging.info(f"Audio {args.audio_file} is {len(waveform) / _sample_rate} seconds, {waveform.shape[1]} channel") - # load vad model - start = time.time() - vad = FSMNVad(download_model_path) - for channel_id, channel_data in enumerate(waveform.T): - segments = vad.segments_offline(channel_data) - results = "" - for part in segments: - audio_feats = front.get_features(channel_data[part[0] * 16 : part[1] * 16]) - asr_result = model( - audio_feats[None, ...], - language=languages[args.language], - use_itn=args.use_itn, +# +def get_model_path(): + try: + return folder_paths.get_folder_paths('sense_voice')[0] + except: + return os.path.join(folder_paths.models_dir, "sense_voice") + + +class SenseVoiceProcessor: + def __init__(self, download_model_path, device, num_threads, use_int8): + + if not os.path.exists(download_model_path): + logging.info( + "Downloading model from huggingface hub from https://huggingface.co/lovemefan/SenseVoice-onnx" + ) + logging.info( + "You can speed up with `export HF_ENDPOINT=https://hf-mirror.com`" ) - logging.info(f"[Channel {channel_id}] [{part[0] / 1000}s - {part[1] / 1000}s] {asr_result}") - vad.vad.all_reset_detection() - decoding_time = time.time() - start - logging.info(f"Decoder audio takes {decoding_time} seconds") - logging.info(f"The RTF is {decoding_time/(waveform.shape[1] * len(waveform) / _sample_rate)}.") + snapshot_download( + repo_id="lovemefan/SenseVoice-onnx", local_dir=download_model_path + ) + + self.download_model_path = download_model_path + self.device = device + self.num_threads = num_threads + self.use_int8 = use_int8 + self.front = WavFrontend(os.path.join(download_model_path, "am.mvn")) + self.model = SenseVoiceInferenceSession( + os.path.join(download_model_path, "embedding.npy"), + os.path.join( + download_model_path, + "sense-voice-encoder-int8.onnx" + if use_int8 + else "sense-voice-encoder.onnx", + ), + os.path.join(download_model_path, "chn_jpn_yue_eng_ko_spectok.bpe.model"), + device, + num_threads, + ) + self.vad = FSMNVad(download_model_path) + + def process_audio(self, waveform, _sample_rate, language, use_itn): + + start = time.time() + pbar = comfy.utils.ProgressBar(waveform.shape[1]) # 进度条 + + results = [] + + for channel_id, channel_data in enumerate(waveform.T): + segments = self.vad.segments_offline(channel_data) + + for part in segments: + audio_feats = self.front.get_features(channel_data[part[0] * 16 : part[1] * 16]) + asr_result = self.model( + audio_feats[None, ...], + language=languages[language], + use_itn=use_itn, + ) + logging.info(f"[Channel {channel_id}] [{part[0] / 1000}s - {part[1] / 1000}s] {asr_result}") + + results.append(asr_result) + + self.vad.vad.all_reset_detection() + pbar.update(1) # 更新进度条 + + decoding_time = time.time() - start + logging.info(f"Decoder audio takes {decoding_time} seconds") + logging.info(f"The RTF is {decoding_time/(waveform.shape[1] * len(waveform) / _sample_rate)}.") + return results + + +class SenseVoiceNode: + + def __init__(self): + self.processor = None + self.download_model_path=get_model_path() + self.device="cpu" + + @classmethod + def INPUT_TYPES(s): + + return {"required": { + "audio": ("AUDIO", ), + "device": ( ['auto','cpu'], {"default": 'auto'}), + "language": (languages.keys(), {"default": 'auto'}), + "num_threads":("INT",{ + "default":4, + "min": 1, #Minimum value + "max": 32, #Maximum value + "step": 1, #Slider's step + "display": "number" # Cosmetic only: display as "number" or "slider" + }), + "use_int8":("BOOLEAN", {"default": True},), + "use_itn":("BOOLEAN", {"default": True},), + }, + } + + CATEGORY = "♾️Mixlab/Audio" + + OUTPUT_NODE = True + FUNCTION = "run" + RETURN_TYPES = ("SCENE_VIDEO",) + RETURN_NAMES = ("SCENE_VIDEO",) + + def run(self,audio,device,language,num_threads,use_int8,use_itn ): + + if device!=self.device: + self.device=device + self.processor=None + if num_threads!=self.num_threads: + self.num_threads=num_threads + self.processor=None + if use_int8!=self.use_int8: + self.use_int8=use_int8 + self.processor=None + + if device=='auto' and torch.cuda.is_available(): + self.device='cuda' + + # num_threads=4 + # use_int8=True + + if self.processor==None: + self.processor = SenseVoiceProcessor(self.download_model_path, + self.device, + self.num_threads, + self.use_int8) + + if 'waveform' in audio and 'sample_rate' in audio: + waveform_numpy=audio['waveform'].numpy().transpose(1, 0) # 转换为 (num_samples, num_channels) + _sample_rate=audio['sample_rate'] + + results=self.processor.process_audio(waveform_numpy, _sample_rate, language, use_itn) -if __name__ == "__main__": - main() \ No newline at end of file + return (results,) From 5f7190b08f7e1ea790468e68d2e78768addf9447 Mon Sep 17 00:00:00 2001 From: shadowcz007 Date: Tue, 1 Oct 2024 21:01:27 +0800 Subject: [PATCH 3/6] Update SenseVoice.py --- nodes/SenseVoice.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/nodes/SenseVoice.py b/nodes/SenseVoice.py index fa96ebe..b76e902 100644 --- a/nodes/SenseVoice.py +++ b/nodes/SenseVoice.py @@ -14,6 +14,9 @@ languages = {"auto": 0, "zh": 3, "en": 4, "yue": 7, "ja": 11, "ko": 12, "nospeech": 13} +# 设置环境变量 +os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com' + # def get_model_path(): try: From 8afe6d038351f68724ee06e7cb9762aec35264ec Mon Sep 17 00:00:00 2001 From: shadowcz007 Date: Tue, 1 Oct 2024 22:15:35 +0800 Subject: [PATCH 4/6] update --- nodes/SenseVoice.py | 37 ++++++++++++++++++++++++++++--------- 1 file changed, 28 insertions(+), 9 deletions(-) diff --git a/nodes/SenseVoice.py b/nodes/SenseVoice.py index b76e902..844043c 100644 --- a/nodes/SenseVoice.py +++ b/nodes/SenseVoice.py @@ -8,8 +8,7 @@ from sensevoice.onnx.sense_voice_ort_session import SenseVoiceInferenceSession from sensevoice.utils.frontend import WavFrontend from sensevoice.utils.fsmn_vad import FSMNVad -import comfy.utils # 假设这个模块包含ProgressBar -from comfy.model_management import get_torch_device # 假设这个函数在这个模块中 +import comfy.utils import folder_paths languages = {"auto": 0, "zh": 3, "en": 4, "yue": 7, "ja": 11, "ko": 12, "nospeech": 13} @@ -23,7 +22,14 @@ def get_model_path(): return folder_paths.get_folder_paths('sense_voice')[0] except: return os.path.join(folder_paths.models_dir, "sense_voice") - + +class AnyType(str): + """A special class that is always equal in not equal comparisons. Credit to pythongosssss""" + + def __ne__(self, __value: object) -> bool: + return False + +any_type = AnyType("*") class SenseVoiceProcessor: def __init__(self, download_model_path, device, num_threads, use_int8): @@ -94,6 +100,9 @@ def __init__(self): self.processor = None self.download_model_path=get_model_path() self.device="cpu" + self.num_threads = 4 + self.use_int8 = True + self.language='auto' @classmethod def INPUT_TYPES(s): @@ -101,14 +110,14 @@ def INPUT_TYPES(s): return {"required": { "audio": ("AUDIO", ), "device": ( ['auto','cpu'], {"default": 'auto'}), - "language": (languages.keys(), {"default": 'auto'}), + "language": (list(languages.keys()), {"default": 'auto'}),# 不能直接写 languages.keys(),json.dumps会报错 "num_threads":("INT",{ "default":4, "min": 1, #Minimum value "max": 32, #Maximum value "step": 1, #Slider's step "display": "number" # Cosmetic only: display as "number" or "slider" - }), + },), "use_int8":("BOOLEAN", {"default": True},), "use_itn":("BOOLEAN", {"default": True},), }, @@ -118,14 +127,17 @@ def INPUT_TYPES(s): OUTPUT_NODE = True FUNCTION = "run" - RETURN_TYPES = ("SCENE_VIDEO",) - RETURN_NAMES = ("SCENE_VIDEO",) + RETURN_TYPES = (any_type,) + RETURN_NAMES = ("result",) def run(self,audio,device,language,num_threads,use_int8,use_itn ): if device!=self.device: self.device=device self.processor=None + if language!=self.language: + self.language=language + self.processor=None if num_threads!=self.num_threads: self.num_threads=num_threads self.processor=None @@ -146,8 +158,15 @@ def run(self,audio,device,language,num_threads,use_int8,use_itn ): self.use_int8) if 'waveform' in audio and 'sample_rate' in audio: - waveform_numpy=audio['waveform'].numpy().transpose(1, 0) # 转换为 (num_samples, num_channels) - _sample_rate=audio['sample_rate'] + waveform = audio['waveform'] + # print("Original shape:", waveform.shape) # 打印原始形状 + if waveform.ndim == 3 and waveform.shape[0] == 1: # 检查是否为三维且 batch_size 为 1 + waveform = waveform.squeeze(0) # 移除 batch_size 维度 + waveform_numpy = waveform.numpy().transpose(1, 0) # 转换为 (num_samples, num_channels) + else: + raise ValueError("Unexpected waveform dimensions") + + _sample_rate = audio['sample_rate'] results=self.processor.process_audio(waveform_numpy, _sample_rate, language, use_itn) From 228e5d9183a13ad5d7eb50b16994160932461919 Mon Sep 17 00:00:00 2001 From: shadowcz007 Date: Tue, 1 Oct 2024 23:13:31 +0800 Subject: [PATCH 5/6] Update SenseVoice.py --- nodes/SenseVoice.py | 51 ++++++++++++++++++++++++++++++++++++++++----- 1 file changed, 46 insertions(+), 5 deletions(-) diff --git a/nodes/SenseVoice.py b/nodes/SenseVoice.py index 844043c..8c0d4de 100644 --- a/nodes/SenseVoice.py +++ b/nodes/SenseVoice.py @@ -4,7 +4,7 @@ import time from huggingface_hub import snapshot_download -import torch +import torch,re from sensevoice.onnx.sense_voice_ort_session import SenseVoiceInferenceSession from sensevoice.utils.frontend import WavFrontend from sensevoice.utils.fsmn_vad import FSMNVad @@ -29,7 +29,34 @@ class AnyType(str): def __ne__(self, __value: object) -> bool: return False -any_type = AnyType("*") +any_type = AnyType("*") + +# 字幕 +def format_to_srt(channel_id, start_time_ms, end_time_ms, asr_result): + start_time = start_time_ms / 1000 + end_time = end_time_ms / 1000 + + def format_time(seconds): + hours = int(seconds // 3600) + minutes = int((seconds % 3600) // 60) + seconds = seconds % 60 + milliseconds = int((seconds - int(seconds)) * 1000) + return f"{hours:02}:{minutes:02}:{int(seconds):02},{milliseconds:03}" + + start_time_str = format_time(start_time) + end_time_str = format_time(end_time) + + pattern = r"<\|(.+?)\|><\|(.+?)\|><\|(.+?)\|><\|(.+?)\|>(.+)" + match = re.match(pattern,asr_result) + lang, emotion, audio_type, itn, text = match.groups() + # 😊 表示高兴,😡 表示愤怒,😔 表示悲伤。对于音频事件,🎼 表示音乐,😀 表示笑声,👏 表示掌声 + + srt_content = f"1\n{start_time_str} --> {end_time_str}\n{text}\n" + + logging.info(f"[Channel {channel_id}] [{start_time}s - {end_time}s] [{lang}] [{emotion}] [{audio_type}] [{itn}] {text}") + + return lang, emotion, audio_type, itn,srt_content,start_time,end_time,text + class SenseVoiceProcessor: def __init__(self, download_model_path, device, num_threads, use_int8): @@ -81,9 +108,23 @@ def process_audio(self, waveform, _sample_rate, language, use_itn): language=languages[language], use_itn=use_itn, ) - logging.info(f"[Channel {channel_id}] [{part[0] / 1000}s - {part[1] / 1000}s] {asr_result}") - - results.append(asr_result) + + lang, emotion, audio_type, itn,srt_content,start_time,end_time,text=format_to_srt( + channel_id, + part[0] , + part[1], + asr_result) + + results.append({ + "language":lang, + "emotion":emotion, + "audio_type":audio_type, + "itn":itn, + "srt_content":srt_content, + "start_time":start_time, + "end_time":end_time, + "text":text + }) self.vad.vad.all_reset_detection() pbar.update(1) # 更新进度条 From be8ccc1dc46b7ad45a4dee366d759e85eeac2615 Mon Sep 17 00:00:00 2001 From: shadowcz007 Date: Tue, 1 Oct 2024 23:16:02 +0800 Subject: [PATCH 6/6] =?UTF-8?q?=E6=96=B0=E5=A2=9E=20SenseVoice?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- README.md | 2 ++ pyproject.toml | 2 +- web/javascript/checkVersion_mixlab.js | 2 +- 3 files changed, 4 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 18718c4..c5da4a3 100644 --- a/README.md +++ b/README.md @@ -10,6 +10,8 @@ For business cooperation, please contact email 389570357@qq.com ##### `最新`: +- 新增 SenseVoice + - [新增JS-SDK,方便直接在前端项目中使用comfyui](https://github.com/shadowcz007/comfyui-js-sdk) - 新增API调用图像生成节点 TextToImage Siliconflow,可以直接调用Siliconflow提供的flux生成图像 diff --git a/pyproject.toml b/pyproject.toml index 6b41dac..41924b1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,7 +1,7 @@ [project] name = "comfyui-mixlab-nodes" description = "3D, ScreenShareNode & FloatingVideoNode, SpeechRecognition & SpeechSynthesis, GPT, LoadImagesFromLocal, Layers, Other Nodes, ..." -version = "0.42.0" +version = "0.43.0" license = "MIT" dependencies = ["numpy", "pyOpenSSL", "watchdog", "opencv-python-headless", "matplotlib", "openai", "simple-lama-inpainting", "clip-interrogator==0.6.0", "transformers>=4.36.0", "lark-parser", "imageio-ffmpeg", "rembg[gpu]", "omegaconf==2.3.0", "Pillow>=9.5.0", "einops==0.7.0", "trimesh>=4.0.5", "huggingface-hub", "scikit-image"] diff --git a/web/javascript/checkVersion_mixlab.js b/web/javascript/checkVersion_mixlab.js index fd881b4..111158e 100644 --- a/web/javascript/checkVersion_mixlab.js +++ b/web/javascript/checkVersion_mixlab.js @@ -3,7 +3,7 @@ import { app } from '../../../scripts/app.js' const repoOwner = 'shadowcz007' // 替换为仓库的所有者 const repoName = 'comfyui-mixlab-nodes' // 替换为仓库的名称 -const version = 'v0.42.0' +const version = 'v0.43.0' fetch(`https://api.github.com/repos/${repoOwner}/${repoName}/releases/latest`) .then(response => response.json())