diff --git a/.gitignore b/.gitignore index 945a3d4a..2883d27d 100644 --- a/.gitignore +++ b/.gitignore @@ -162,6 +162,7 @@ main-dev.py tmp.* *.wav *.mp3 +*.mp4 *.flv *.vtt @@ -206,4 +207,4 @@ runs .ruff_cache/ # xml -*.xml \ No newline at end of file +*.xml diff --git a/deploy/modal/README.md b/deploy/modal/README.md index 0eaabcd6..bec23f94 100644 --- a/deploy/modal/README.md +++ b/deploy/modal/README.md @@ -174,6 +174,9 @@ IMAGE_NAME=minicpmo IMAGE_CONCURRENT_CN=1 IMAGE_GPU=L4 modal serve -e achatbot s # moonshotai/Kimi-VL-A3B-Instruct (or Thinking) use 2xL4 like deepseek-ai/deepseek-vl2-small IMAGE_NAME=kimi IMAGE_CONCURRENT_CN=1 IMAGE_GPU=L4:2 modal serve -e achatbot src/fastapi_webrtc_vision_bot_serve.py + +# webrtc_vision_bot serve on qwen2.5omni vision llm pip image +IMAGE_NAME=qwen2.5omni IMAGE_CONCURRENT_CN=1 IMAGE_GPU=L4 modal serve -e achatbot src/fastapi_webrtc_vision_bot_serve.py ``` - curl api to run chat room bot with webrtc (daily/livekit/agora) ```shell @@ -384,7 +387,7 @@ curl --location 'https://weedge-achatbot--fastapi-webrtc-freeze-omni-voice-bo-4b ``` ### webrtc_minicpmo_vision_voice_bot -- run webrtc_minicpmo_vision_voice_bot serve with task queue(redis) +- run webrtc_minicpmo_vision_voice_bot serve ```shell # webrtc_audio_bot serve on default pip image # need create .env.example to modal Secrets for webrtc key @@ -560,6 +563,168 @@ curl --location 'https://weedge-achatbot--fastapi-webrtc-minicpmo-omni-bot-srv-a "config_list": [] }' ``` +### webrtc_qwen2_5omni_vision_voice_bot +- run webrtc_qwen2_5omni_vision_voice_bot serve with webrtc +```shell +# webrtc_audio_bot serve on default pip image +# need create .env.example to modal Secrets for webrtc key +IMAGE_CONCURRENT_CN=1 IMAGE_GPU=L40s modal serve -e achatbot src/fastapi_webrtc_qwen2_5omni_vision_voice_bot_serve.py +``` +- curl api to run chat room bot with webrtc (livekit_room) +```shell +# thinker gen chunk token and hidden states -> talker gen vq codes token -> code2wav gen chunk wav | don't use_sliding_window_code2wav +curl --location 'https://weedge-achatbot--fastapi-webrtc-qwen2-5omni-bot-srv-app-dev.modal.run/bot_join/chat-room/LivekitQwen2_5OmniVisionVoiceBot' \ +--header 'Content-Type: application/json' \ +--data '{ + "chat_bot_name": "LivekitQwen2_5OmniVisionVoiceBot", + "room_name": "chat-room", + "room_url": "", + "token": "", + "room_manager": { + "tag": "livekit_room", + "args": { + "bot_name": "LivekitQwen2_5OmniVisionVoiceBot", + "is_common_session": false + } + }, + "services": { + "pipeline": "achatbot", + "vad": "silero", + "omni_llm": "llm_transformers_manual_qwen2_5omni_vision_voice" + }, + "config": { + "vad": { + "tag": "silero_vad_analyzer", + "args": { "stop_secs": 0.7 } + }, + "omni_llm": { + "tag": "llm_transformers_manual_qwen2_5omni_vision_voice", + "args": { + "lm_device": "cuda", + "lm_torch_dtype": "bfloat16", + "lm_attn_impl": "flash_attention_2", + "warmup_steps": 1, + "chat_history_size": 0, + "thinker_eos_token_ids": [151644, 151645], + "thinker_args": { + "lm_gen_temperature": 0.95, + "lm_gen_top_k": 20, + "lm_gen_top_p": 0.9, + "lm_gen_min_new_tokens": 1, + "lm_gen_max_new_tokens": 1024, + "lm_gen_max_tokens_per_step": 10, + "lm_gen_repetition_penalty": 1.1 + }, + "talker_args": { + "lm_gen_temperature": 0.95, + "lm_gen_top_k": 20, + "lm_gen_top_p": 0.9, + "lm_gen_min_new_tokens": 1, + "lm_gen_max_new_tokens": 2048, + "lm_gen_repetition_penalty": 1.1 + }, + "talker_skip_thinker_token_ids": [], + "talker_eos_token_ids": [8292, 8294], + "code2wav_args": { + "model_path": 
"/root/.achatbot/models/Qwen/Qwen2.5-Omni-7B", + "enable_torch_compile": false, + "enable_torch_compile_first_chunk": false, + "odeint_method": "euler", + "odeint_method_relaxed": false, + "batched_chunk": 3, + "frequency": "50hz", + "device": "cuda", + "num_steps": 10, + "guidance_scale": 0.5, + "sway_coefficient": -1.0, + "code2wav_dynamic_batch": false + }, + "speaker": "Chelsie", + "is_use_sliding_window_code2wav": false, + "lm_model_name_or_path": "/root/.achatbot/models/Qwen/Qwen2.5-Omni-7B" + } + } + }, + "config_list": [] +} +' +# thinker gen chunk token and hidden states -> talker gen vq codes token -> code2wav gen chunk wav | use_sliding_window_code2wav | no torch.compile +curl --location 'https://weedge-achatbot--fastapi-webrtc-qwen2-5omni-bot-srv-app-dev.modal.run/bot_join/chat-room/LivekitQwen2_5OmniVisionVoiceBot' \ +--header 'Content-Type: application/json' \ +--data '{ + "chat_bot_name": "LivekitQwen2_5OmniVisionVoiceBot", + "room_name": "chat-room", + "room_url": "", + "token": "", + "room_manager": { + "tag": "livekit_room", + "args": { + "bot_name": "LivekitQwen2_5OmniVisionVoiceBot", + "is_common_session": false + } + }, + "services": { + "pipeline": "achatbot", + "vad": "silero", + "omni_llm": "llm_transformers_manual_qwen2_5omni_vision_voice" + }, + "config": { + "vad": { + "tag": "silero_vad_analyzer", + "args": { "stop_secs": 0.7 } + }, + "omni_llm": { + "tag": "llm_transformers_manual_qwen2_5omni_vision_voice", + "args": { + "lm_device": "cuda", + "lm_torch_dtype": "bfloat16", + "lm_attn_impl": "flash_attention_2", + "warmup_steps": 1, + "chat_history_size": 0, + "thinker_eos_token_ids": [151644, 151645], + "thinker_args": { + "lm_gen_temperature": 0.95, + "lm_gen_top_k": 20, + "lm_gen_top_p": 0.9, + "lm_gen_min_new_tokens": 1, + "lm_gen_max_new_tokens": 1024, + "lm_gen_max_tokens_per_step": 10, + "lm_gen_repetition_penalty": 1.1 + }, + "talker_args": { + "lm_gen_temperature": 0.95, + "lm_gen_top_k": 20, + "lm_gen_top_p": 0.9, + "lm_gen_min_new_tokens": 1, + "lm_gen_max_new_tokens": 2048, + "lm_gen_repetition_penalty": 1.1 + }, + "talker_skip_thinker_token_ids": [], + "talker_eos_token_ids": [8292, 8294], + "code2wav_args": { + "model_path": "/root/.achatbot/models/Qwen/Qwen2.5-Omni-7B", + "enable_torch_compile": false, + "enable_torch_compile_first_chunk": false, + "odeint_method": "euler", + "odeint_method_relaxed": false, + "batched_chunk": 3, + "frequency": "50hz", + "device": "cuda", + "num_steps": 10, + "guidance_scale": 0.5, + "sway_coefficient": -1.0, + "code2wav_dynamic_batch": false + }, + "speaker": "Chelsie", + "is_use_sliding_window_code2wav": true, + "lm_model_name_or_path": "/root/.achatbot/models/Qwen/Qwen2.5-Omni-7B" + } + } + }, + "config_list": [] +} +' +``` ### webrtc_step_voice_bot - run webrtc_step_voice_bot serve with task queue(redis) diff --git a/deploy/modal/src/download_models.py b/deploy/modal/src/download_models.py index 14521596..76844241 100644 --- a/deploy/modal/src/download_models.py +++ b/deploy/modal/src/download_models.py @@ -27,7 +27,7 @@ retries=0, cpu=8.0, image=download_image, - secrets=[modal.Secret.from_name("achatbot")], + # secrets=[modal.Secret.from_name("achatbot")], volumes={HF_MODEL_DIR: hf_model_vol}, timeout=1200, scaledown_window=1200, diff --git a/deploy/modal/src/fastapi_webrtc_minicpmo_vision_voice_bot_serve.py b/deploy/modal/src/fastapi_webrtc_minicpmo_vision_voice_bot_serve.py index 8f3bb3a6..9e906a00 100644 --- a/deploy/modal/src/fastapi_webrtc_minicpmo_vision_voice_bot_serve.py +++ 
b/deploy/modal/src/fastapi_webrtc_minicpmo_vision_voice_bot_serve.py @@ -16,9 +16,6 @@ class ContainerRuntimeConfig: "ACHATBOT_PKG": "1", "LOG_LEVEL": os.getenv("LOG_LEVEL", "info"), "IMAGE_NAME": os.getenv("IMAGE_NAME", "default"), - "ASR_TAG": "sense_voice_asr", - "ASR_LANG": "zn", - "ASR_MODEL_NAME_OR_PATH": "/root/.achatbot/models/FunAudioLLM/SenseVoiceSmall", "USE_GPTQ_CKPT": os.getenv("USE_GPTQ_CKPT", ""), "LLM_MODEL_NAME_OR_PATH": f'/root/.achatbot/models/{os.getenv("LLM_MODEL_NAME_OR_PATH", "openbmb/MiniCPM-o-2_6")}', # https://docs.nvidia.com/cuda/cuda-compiler-driver-nvcc/index.html#gpu-feature-list @@ -42,10 +39,9 @@ class ContainerRuntimeConfig: "fastapi_bot_server," "livekit,livekit-api,daily,agora," "silero_vad_analyzer," - "sense_voice_asr,deepgram_asr_processor," "llm_transformers_manual_vision_voice_minicpmo," "queue" - "]~=0.0.8.12", + "]~=0.0.9.post10", "huggingface_hub[hf_transfer]==0.26.0", "wget", ], @@ -170,8 +166,20 @@ def setup(self): @modal.enter() def enter(self): - print("enter done") - # volume.reload() + # run container runtime to enter when container is starting + import subprocess + import torch + + subprocess.run("nvidia-smi --version", shell=True) + gpu_prop = None + if torch.cuda.is_available(): + gpu_prop = torch.cuda.get_device_properties("cuda:0") + print(gpu_prop) + torch.multiprocessing.set_start_method("spawn", force=True) + else: + print("CUDA is not available.") + + # todo: init model to load, now use api to load model to run bot with config @modal.asgi_app() def app(self): diff --git a/deploy/modal/src/fastapi_webrtc_qwen2_5omni_vision_voice_bot_serve.py b/deploy/modal/src/fastapi_webrtc_qwen2_5omni_vision_voice_bot_serve.py new file mode 100644 index 00000000..10dfbe48 --- /dev/null +++ b/deploy/modal/src/fastapi_webrtc_qwen2_5omni_vision_voice_bot_serve.py @@ -0,0 +1,86 @@ +import modal +import os + +achatbot_version = os.getenv("ACHATBOT_VERSION", "0.0.9.post10") +qwen2_5omni_img = ( + # https://catalog.ngc.nvidia.com/orgs/nvidia/containers/cuda/tags + modal.Image.from_registry( + "nvidia/cuda:12.6.1-cudnn-devel-ubuntu22.04", + add_python="3.10", + ) + .apt_install("git", "git-lfs", "ffmpeg", "clang", "cmake") + .pip_install("wheel") + .pip_install( + [ + "achatbot[" + "fastapi_bot_server," + "livekit,livekit-api,daily,agora," + "silero_vad_analyzer,asr_processor," + "llm_transformers_manual_vision_voice_qwen," + "queue" + f"]=={achatbot_version}", + ], + extra_index_url=os.getenv("EXTRA_INDEX_URL", "https://pypi.org/simple/"), + ) + .run_commands( + "pip install git+https://github.com/huggingface/transformers@v4.51.3-Qwen2.5-Omni-preview" + ) + .pip_install("flash-attn", extra_options="--no-build-isolation") + .env( + { + "ACHATBOT_PKG": "1", + "LOG_LEVEL": os.getenv("LOG_LEVEL", "info"), + "LLM_MODEL_NAME_OR_PATH": f'/root/.achatbot/models/{os.getenv("LLM_MODEL_NAME_OR_PATH", "Qwen/Qwen2.5-Omni-7B")}', + # https://docs.nvidia.com/cuda/cuda-compiler-driver-nvcc/index.html#gpu-feature-list + } + ) +) + + +# ----------------------- app ------------------------------- +app = modal.App("fastapi_webrtc_qwen2_5omni_bot") + +HF_MODEL_DIR = "/root/.achatbot/models" +hf_model_vol = modal.Volume.from_name("models", create_if_missing=True) +ASSETS_DIR = "/root/.achatbot/assets" +assets_dir = modal.Volume.from_name("assets", create_if_missing=True) + + +# 128 MiB of memory and 0.125 CPU cores by default container runtime +@app.cls( + image=qwen2_5omni_img, + gpu=os.getenv("IMAGE_GPU", None), + secrets=[modal.Secret.from_name("achatbot")], + volumes={ + 
HF_MODEL_DIR: hf_model_vol, + ASSETS_DIR: assets_dir, + }, + cpu=2.0, + timeout=1200, # default 300s + scaledown_window=1200, + max_containers=1, + allow_concurrent_inputs=int(os.getenv("IMAGE_CONCURRENT_CN", "1")), +) +class Srv: + @modal.enter() + def enter(self): + # run container runtime to enter when container is starting + import subprocess + import torch + + subprocess.run("nvidia-smi --version", shell=True) + gpu_prop = None + if torch.cuda.is_available(): + gpu_prop = torch.cuda.get_device_properties("cuda:0") + print(gpu_prop) + torch.multiprocessing.set_start_method("spawn", force=True) + else: + print("CUDA is not available.") + + # todo: init model to load, now use api to load model to run bot with config + + @modal.asgi_app() + def app(self): + from achatbot.cmd.http.server.fastapi_daily_bot_serve import app as fastapi_app + + return fastapi_app diff --git a/deploy/modal/src/fastapi_webrtc_vision_bot_serve.py b/deploy/modal/src/fastapi_webrtc_vision_bot_serve.py index b4b1fd92..9aaa5cff 100644 --- a/deploy/modal/src/fastapi_webrtc_vision_bot_serve.py +++ b/deploy/modal/src/fastapi_webrtc_vision_bot_serve.py @@ -1,7 +1,7 @@ import modal import os -achatbot_version = os.getenv("ACHATBOT_VERSION", "0.0.9.post8") +achatbot_version = os.getenv("ACHATBOT_VERSION", "0.0.9.post10") vision_bot_img = ( # https://catalog.ngc.nvidia.com/orgs/nvidia/containers/cuda/tags @@ -10,6 +10,7 @@ add_python="3.10", ) .apt_install("git", "git-lfs", "ffmpeg", "cmake") + .pip_install("wheel") .pip_install( [ "achatbot[" @@ -25,7 +26,6 @@ ], extra_index_url=os.getenv("EXTRA_INDEX_URL", "https://pypi.org/simple/"), ) - .pip_install("wheel") .pip_install("flash-attn", extra_options="--no-build-isolation") .env( { @@ -123,6 +123,21 @@ class ContainerRuntimeConfig: } ) ), + "qwen2_5omni": ( + vision_bot_img.pip_install( + [ + f"achatbot[llm_transformers_manual_vision_voice_qwen]=={achatbot_version}", + ], + extra_index_url=os.getenv("EXTRA_INDEX_URL", "https://pypi.org/simple/"), + ) + .run_commands("pip install git+https://github.com/huggingface/transformers@v4.51.3-Qwen2.5-Omni-preview") + .env( + { + "PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True", + "LLM_MODEL_NAME_OR_PATH": f'/root/.achatbot/models/{os.getenv("LLM_MODEL_NAME_OR_PATH", "Qwen/Qwen2.5-Omni-7B")}', + } + ) + ), } @staticmethod diff --git a/deploy/modal/src/llm/transformers/qwen2_5omni.py b/deploy/modal/src/llm/transformers/qwen2_5omni.py new file mode 100644 index 00000000..1c201b64 --- /dev/null +++ b/deploy/modal/src/llm/transformers/qwen2_5omni.py @@ -0,0 +1,2651 @@ +from time import perf_counter +import time +from typing import Optional +import modal +import os + +app = modal.App("qwen2_5_omni") +omni_img = ( + # https://catalog.ngc.nvidia.com/orgs/nvidia/containers/cuda/tags + modal.Image.from_registry( + "nvidia/cuda:12.6.1-cudnn-devel-ubuntu22.04", + add_python="3.10", + ) + .apt_install("git", "git-lfs", "ffmpeg", "clang", "cmake") + .pip_install("wheel", "openai", "qwen-omni-utils[decord]") + .pip_install( + "accelerate", + "torch==2.6.0", + "torchaudio==2.6.0", + "torchvision==0.21.0", + "soundfile==0.13.0", + "librosa==0.11.0", + ) + .run_commands( + "pip install git+https://github.com/huggingface/transformers@v4.51.3-Qwen2.5-Omni-preview" + ) + .pip_install("flash-attn", extra_options="--no-build-isolation") + .env( + { + "PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True", + } + ) +) + +achatbot_version = os.getenv("ACHATBOT_VERSION", "") +if achatbot_version: + omni_img = ( + omni_img.pip_install( + 
f"achatbot[llm_transformers_manual_vision_voice_qwen]=={achatbot_version}", + extra_index_url=os.getenv("EXTRA_INDEX_URL", "https://pypi.org/simple/"), + ) + # .pip_install("flash-attn==2.5.8", extra_options="--no-build-isolation") + .env( + { + "ACHATBOT_PKG": "1", + "LOG_LEVEL": os.getenv("LOG_LEVEL", "info"), + } + ) + ) + +HF_MODEL_DIR = "/root/.achatbot/models" +hf_model_vol = modal.Volume.from_name("models", create_if_missing=True) +ASSETS_DIR = "/root/.achatbot/assets" +assets_dir = modal.Volume.from_name("assets", create_if_missing=True) + +# NOTE: if want to generate speech, need use this system prompt to generate speech +SPEECH_SYS_PROMPT = "You are Qwen, a virtual human developed by the Qwen Team, Alibaba Group, capable of perceiving auditory and visual inputs, as well as generating text and speech." +# Voice settings +SPEAKER_LIST = ["Chelsie", "Ethan"] +DEFAULT_SPEAKER = "Ethan" + +with omni_img.imports(): + import subprocess + from threading import Thread + from queue import Queue + import numpy as np + from transformers.generation.streamers import BaseStreamer + + import torch + from transformers import ( + Qwen2_5OmniForConditionalGeneration, + Qwen2_5OmniProcessor, + TextIteratorStreamer, + AutoConfig, + AutoProcessor, + ) + from qwen_omni_utils import process_mm_info + + def print_model_params(model: torch.nn.Module, extra_info=""): + # print the number of parameters in the model + model_million_params = sum(p.numel() for p in model.parameters()) / 1e6 + # print(model) + print(f"{extra_info} {model_million_params} M parameters") + + class Qwen2_5OmniForConditionalGenerationNew(Qwen2_5OmniForConditionalGeneration): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + print_model_params(self.thinker, "qwen2.5omni_thinker") + print_model_params(self.talker, "qwen2.5omni_talker") + print_model_params(self.token2wav, "qwen2.5omni_token2wav") + + @torch.no_grad() + # TODO: raushan, defaults should be saved in generation config + def generate( + self, + input_ids: Optional[torch.Tensor] = None, + speaker: str = "Chelsie", + use_audio_in_video: bool = False, + return_audio: Optional[bool] = None, + thinker_max_new_tokens: int = 1024, + talker_max_new_tokens: int = 4096, + talker_do_sample: bool = True, + talker_top_k: int = 40, + talker_top_p: float = 0.8, + talker_temperature: float = 0.9, + talker_eos_token_id: list[int] = [8292, 8294], + talker_repetition_penalty: float = 1.05, + **kwargs, + ): + r""" + Generate text response and audio from input. + + Args: + input_ids (`Optional[torch.Tensor]`, *optional*): + Input ids, should obtain from processor. + speaker (`str` , defaults to "Chelsie"): + Which speaker should be used in audio response. + use_audio_in_video (`bool`, defaults to False): + Whether or not use audio track in video, should same as the parameter in `process_audio_info`. + return_audio (`Optional[bool]`, *optional*): + Whether or not return response in audio format. When `return_audio=None`, this parameter is same as `config.enable_audio_output`. + kwargs (*optional*): + - Without a prefix, they will be entered as `**kwargs` for the `generate` method of each sub-model. + - With a *thinker_*, *talker_*, *token2wav_* prefix, they will be input for the `generate` method of the + thinker, talker and token2wav respectively. It has the priority over the keywords without a prefix. + Returns: + When `return_audio=False`: + - **Text** (`torch.Tensor`): Generated text token sequence. 
+ When `return_audio=True`: + - **Text** (`torch.Tensor`): Generated text token sequence. + - **Audio waveform** (`torch.Tensor`): Generated audio waveform. + """ + if speaker not in self.speaker_map: + raise ValueError( + f"{speaker} is not availible, availible speakers: {self.speaker_map.keys()}" + ) + if return_audio and not self.has_talker: + raise ValueError( + "Cannot use talker when talker module not initalized. Use `enable_talker` method or set enable_talker in config to enable talker." + ) + if return_audio is None: + return_audio = self.has_talker + if input_ids.shape[0] != 1 and return_audio: + raise NotImplementedError( + "Qwen2.5-Omni currently does not support batched inference with audio output" + ) + + shared_kwargs = {"use_audio_in_video": use_audio_in_video} + thinker_kwargs = { + "max_new_tokens": thinker_max_new_tokens, + } + talker_kwargs = { + "max_new_tokens": talker_max_new_tokens, + "do_sample": talker_do_sample, + "top_k": talker_top_k, + "top_p": talker_top_p, + "temperature": talker_temperature, + "eos_token_id": talker_eos_token_id, + "repetition_penalty": talker_repetition_penalty, + } + token2wav_kwargs = {} + + for key, value in kwargs.items(): + if key.startswith("thinker_"): + thinker_kwargs[key[len("thinker_") :]] = value + elif key.startswith("talker_"): + talker_kwargs[key[len("talker_") :]] = value + elif key.startswith("token2wav_"): + token2wav_kwargs[key[len("token2wav_") :]] = value + # Process special input values + elif key == "feature_attention_mask": + thinker_kwargs[key] = value + talker_kwargs["audio_feature_lengths"] = torch.sum(value, dim=1) + elif key == "input_features" or key == "attention_mask": + thinker_kwargs[key] = value + # Put other key to shared kwargs + else: + shared_kwargs[key] = value + + # Merge kwargs + for key, value in shared_kwargs.items(): + if key not in thinker_kwargs: + thinker_kwargs[key] = value + if key not in talker_kwargs: + talker_kwargs[key] = value + if key not in token2wav_kwargs: + token2wav_kwargs[key] = value + speaker_params = self.speaker_map[speaker] + + # 1. Generate from thinker module + generate_audio = return_audio and self.has_talker + if generate_audio: + thinker_kwargs["output_hidden_states"] = True + thinker_kwargs["return_dict_in_generate"] = True + + thinker_result = self.thinker.generate(input_ids=input_ids, **thinker_kwargs) + + if not generate_audio: + return thinker_result + + # 2. 
Generate speech tokens from talker module + embeds_to_talker = thinker_result.hidden_states[0][0].clone().to(self.talker.device) + if thinker_kwargs.get("input_features", None) is not None: + audio_ids_mask = input_ids == self.config.thinker_config.audio_token_index + audio_mask = ( + audio_ids_mask.unsqueeze(-1) + .expand_as(embeds_to_talker) + .to(embeds_to_talker.device) + ) + audio_mask_tensor = torch.zeros( + [audio_ids_mask.sum(), embeds_to_talker.shape[-1]], + dtype=embeds_to_talker.dtype, + device=self.talker.device, + ) + embeds_to_talker.masked_scatter_(audio_mask, audio_mask_tensor) + if thinker_kwargs.get("pixel_values", None) is not None: + image_ids_mask = input_ids == self.config.thinker_config.image_token_index + image_mask = ( + image_ids_mask.unsqueeze(-1) + .expand_as(embeds_to_talker) + .to(embeds_to_talker.device) + ) + image_mask_tensor = torch.zeros( + [image_ids_mask.sum(), embeds_to_talker.shape[-1]], + dtype=embeds_to_talker.dtype, + device=self.talker.device, + ) + embeds_to_talker.masked_scatter_(image_mask, image_mask_tensor) + if thinker_kwargs.get("pixel_values_videos", None) is not None: + video_ids_mask = input_ids == self.config.thinker_config.video_token_index + video_mask = ( + video_ids_mask.unsqueeze(-1) + .expand_as(embeds_to_talker) + .to(embeds_to_talker.device) + ) + video_mask_tensor = torch.zeros( + [video_ids_mask.sum(), embeds_to_talker.shape[-1]], + dtype=embeds_to_talker.dtype, + device=self.talker.device, + ) + embeds_to_talker.masked_scatter_(video_mask, video_mask_tensor) + + processed_thinker_hidden = ( + (embeds_to_talker,) + thinker_result.hidden_states[0][1:], + ) + thinker_result.hidden_states[1:] + thinker_generate_ids = thinker_result.sequences[:, input_ids.size(1) :].to( + self.talker.device + ) + thinker_token_embeds = [ + token_hidden_states[0].to(self.talker.device) + for token_hidden_states in processed_thinker_hidden + ] + thinker_hidden_states = [ + token_hidden_states[-1].to(self.talker.device) + for token_hidden_states in processed_thinker_hidden + ] + + talker_text_bos_token = speaker_params["bos_token"] + talker_input_text_ids = torch.cat( + [ + input_ids.to(self.talker.device), + torch.tensor( + [[talker_text_bos_token]], dtype=torch.long, device=self.talker.device + ), + thinker_generate_ids[:, :1], + ], + dim=-1, + ) + + talker_input_ids = torch.cat( + [ + torch.full_like( + input_ids, + fill_value=self.talker.codec_mask_token, + device=self.talker.device, + ), + torch.tensor( + [[self.talker.codec_pad_token]], dtype=torch.long, device=self.talker.device + ), + torch.tensor( + [[self.talker.codec_bos_token]], dtype=torch.long, device=self.talker.device + ), + ], + dim=1, + ) + + thinker_embed_tokens = self.thinker.get_input_embeddings() + thinker_reply_part = torch.cat(thinker_hidden_states[1:], dim=1) + torch.cat( + thinker_token_embeds[1:], dim=1 + ) + talker_inputs_embeds = thinker_hidden_states[0] + thinker_token_embeds[0] + talker_text_bos_token = torch.tensor( + [[talker_text_bos_token]], dtype=torch.long, device=self.thinker.device + ) + talker_text_bos_embed = thinker_embed_tokens(talker_text_bos_token).to( + self.talker.device + ) + talker_inputs_embeds = torch.cat( + [ + talker_inputs_embeds, + talker_text_bos_embed, + thinker_reply_part[:, :1, :], + ], + dim=1, + ) + + eos_embedding = thinker_embed_tokens( + torch.tensor( + [[self.talker.text_eos_token]], dtype=torch.long, device=self.thinker.device + ) + ).to(self.talker.device) + + pad_embedding = thinker_embed_tokens( + torch.tensor( + 
[[self.talker.text_pad_token]], dtype=torch.long, device=self.thinker.device + ) + ).to(self.talker.device) + + thinker_reply_part = torch.cat( + [ + thinker_reply_part[:, 1:, :], + eos_embedding, + pad_embedding, + ], + dim=1, + ) + talker_attention_mask = None + if "attention_mask" in kwargs: + talker_attention_mask = torch.cat( + [kwargs["attention_mask"], kwargs["attention_mask"].new_ones((1, 2))], dim=1 + ).to(self.talker.device) + + # stream + skip_prompt = kwargs.get("skip_prompt", True) + streamer = TokenStreamer(skip_prompt=skip_prompt) + talker_kwargs = dict( + input_ids=talker_input_ids, + streamer=streamer, + input_text_ids=talker_input_text_ids, + thinker_reply_part=thinker_reply_part, + inputs_embeds=talker_inputs_embeds, + attention_mask=talker_attention_mask, + suppress_tokens=[self.talker.codec_bos_token], + **{ + k: (v.to(self.talker.device) if torch.is_tensor(v) else v) + for k, v in talker_kwargs.items() + }, + ) + # print(talker_kwargs.keys()) + thread = Thread(target=self.talker.generate, kwargs=talker_kwargs) + thread.start() + talker_generate_codes = [] + times = [] + start_time = perf_counter() + for token_id in streamer: + # print(token_id) + times.append(perf_counter() - start_time) + start_time = perf_counter() + talker_generate_codes.append(token_id) + print( + f"generate first token cost time: {times[0]} s, {len(times)} tokens cost time: {sum(times)} s" + ) + offset = 0 + if skip_prompt is False: + offset = talker_input_ids.shape[1] + # print( + # talker_input_ids.shape[1], + # # talker_generate_codes, + # talker_generate_codes[:offset], + # talker_generate_codes[offset:-1], + # ) + talker_generate_codes = torch.tensor( + [talker_generate_codes[offset:-1]], + dtype=torch.long, + device=self.talker.device, + ) + + # no stream + # talker_result = self.talker.generate( + # input_ids=talker_input_ids, + # input_text_ids=talker_input_text_ids, + # thinker_reply_part=thinker_reply_part, + # inputs_embeds=talker_inputs_embeds, + # attention_mask=talker_attention_mask, + # suppress_tokens=[self.talker.codec_bos_token], + # **{ + # k: (v.to(self.talker.device) if torch.is_tensor(v) else v) + # for k, v in talker_kwargs.items() + # }, + # ) + # print(talker_result.shape, talker_result) + # talker_generate_codes = talker_result[:, talker_input_ids.shape[1] : -1] + + # print(f"talker_generate_codes:{talker_generate_codes.shape} {talker_generate_codes}") + + # 3. 
Generate wavs from code + if self.token2wav.dtype != torch.float: + self.token2wav.float() + + # print(self.token2wav.device, speaker_params, token2wav_kwargs) + + wav = self.token2wav( + talker_generate_codes.to(self.token2wav.device), + conditioning=speaker_params["cond"].to(self.token2wav.device).float(), + reference_mel=speaker_params["ref_mel"].to(self.token2wav.device).float(), + **token2wav_kwargs, + ) + + return thinker_result.sequences, wav.float() + + subprocess.run("nvidia-smi --version", shell=True) + subprocess.run("nvcc --version", shell=True) + gpu_prop = torch.cuda.get_device_properties("cuda") + print(gpu_prop) + + if not os.getenv("ACHATBOT_PKG"): + model_path = os.path.join(HF_MODEL_DIR, "Qwen/Qwen2.5-Omni-7B") + config = AutoConfig.from_pretrained(model_path) + model = Qwen2_5OmniForConditionalGenerationNew.from_pretrained( + model_path, + torch_dtype=torch.bfloat16, + device_map={"": 0}, + attn_implementation="flash_attention_2", + config=config, + ).eval() + + # NOTE: when disable talker, generate must set return_audio=False + # model.disable_talker() + + print_model_params(model, "Qwen2.5Omni") + + # processor = Qwen2_5OmniProcessor.from_pretrained(model_path) + processor = AutoProcessor.from_pretrained( + model_path, + min_pixels=256 * 28 * 28, + max_pixels=1280 * 28 * 28, + trust_remote_code=True, + ) + + # subprocess.run("nvidia-smi", shell=True) + + def inference( + messages, + return_audio=False, + use_audio_in_video=False, + thinker_do_sample=False, + speaker=DEFAULT_SPEAKER, + ): + text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) + # image_inputs, video_inputs = process_vision_info([messages]) + audios, images, videos = process_mm_info(messages, use_audio_in_video=use_audio_in_video) + inputs = processor( + text=text, + audio=audios, + images=images, + videos=videos, + return_tensors="pt", + padding=True, + use_audio_in_video=use_audio_in_video, + ) + inputs = inputs.to(model.device).to(model.dtype) + for k, v in inputs.items(): + print(k, v.shape) + + output = model.generate( + **inputs, + use_audio_in_video=use_audio_in_video, + return_audio=return_audio, + speaker=speaker, + thinker_do_sample=thinker_do_sample, + ) + print("\n====generate use memory=====\n") + subprocess.run( + """nvidia-smi --query-gpu=index,memory.used,memory.total --format=csv,noheader,nounits | awk -F',' '{print "GPU "$1": "$2"/"$3" MiB\\n"}'""", + shell=True, + ) + print("\n=========\n") + # print(output) + text_token_ids = output + audio = None + if return_audio and len(output) > 1: + text_token_ids = output[0].detach() + audio = output[1].unsqueeze(0).detach() + + text = processor.batch_decode( + text_token_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False + ) + + torch.cuda.empty_cache() + + return text, audio + + def thinker_inference_stream( + messages, + use_audio_in_video=False, + ): + print(messages) + text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) + # image_inputs, video_inputs = process_vision_info([messages]) + audios, images, videos = process_mm_info(messages, use_audio_in_video=use_audio_in_video) + print(text) + {print(f"audios[{i}]: {item.shape}") for i, item in enumerate(audios)} if audios else print( + audios + ) + {print(f"images[{i}]: {item}") for i, item in enumerate(images)} if images else print( + images + ) + {print(f"videos[{i}]: {item.shape}") for i, item in enumerate(videos)} if videos else print( + videos + ) + + inputs = processor( + text=text, + 
audio=audios, + images=images, + videos=videos, + return_tensors="pt", + padding=True, + use_audio_in_video=use_audio_in_video, + ) + inputs = inputs.to(model.device).to(model.dtype) + for k, v in inputs.items(): + print(k, v.shape) + + streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True) + + generation_kwargs = dict( + **inputs, + streamer=streamer, + use_audio_in_video=use_audio_in_video, + return_audio=False, + thinker_do_sample=True, + # do_sample=True, + top_k=20, + top_p=0.8, + temperature=0.1, + repetition_penalty=1.0, + min_new_tokens=0, + max_new_tokens=1024, + ) + thread = Thread(target=model.generate, kwargs=generation_kwargs) + thread.start() + + generated_text = "" + times = [] + start_time = perf_counter() + for new_text in streamer: + times.append(perf_counter() - start_time) + start_time = perf_counter() + generated_text += new_text + yield new_text + print( + f"generate [{generated_text}] first token cost time: {times[0]} s, {len(times)} tokens cost time: {sum(times)} s" + ) + torch.cuda.empty_cache() + + def thinker_inference_chunk_stream( + messages, + use_audio_in_video=False, + max_new_tokens=2048, + max_tokens_per_step=3, # Controls how many tokens to generate *per step* + eos_token_ids=[151644, 151645], # Define EOS tokens + output_hidden_states=False, + ): + print(messages) + text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) + # image_inputs, video_inputs = process_vision_info([messages]) + audios, images, videos = process_mm_info(messages, use_audio_in_video=use_audio_in_video) + print(text) + {print(f"audios[{i}]: {item.shape}") for i, item in enumerate(audios)} if audios else print( + audios + ) + {print(f"images[{i}]: {item}") for i, item in enumerate(images)} if images else print( + images + ) + {print(f"videos[{i}]: {item.shape}") for i, item in enumerate(videos)} if videos else print( + videos + ) + + inputs = processor( + text=text, + audio=audios, + images=images, + videos=videos, + return_tensors="pt", + padding=True, + use_audio_in_video=use_audio_in_video, + ) + inputs = inputs.to(model.device).to(model.dtype) + for k, v in inputs.items(): + print(k, v.shape) + """ + e.g.: + input_ids torch.Size([1, 20174]) + attention_mask torch.Size([1, 20174]) + + pixel_values_videos torch.Size([77760, 1176]) + image_grid_thw torch.Size([1, 3]) # just image only + + video_grid_thw torch.Size([1, 3]) + video_second_per_grid torch.Size([1]) + + feature_attention_mask torch.Size([1, 30000]) + input_features torch.Size([1, 128, 30000]) + """ + return thinker_generate_chunk( + inputs, + max_new_tokens=max_new_tokens, + max_tokens_per_step=max_tokens_per_step, + use_audio_in_video=use_audio_in_video, + eos_token_ids=eos_token_ids, + output_hidden_states=output_hidden_states, + ) + + @torch.no_grad() + def thinker_generate_chunk( + inputs: dict, + max_new_tokens=2048, + max_tokens_per_step=10, # Controls how many tokens to generate *per step* + use_audio_in_video=False, + eos_token_ids=[151644, 151645], # Define EOS tokens + output_hidden_states=False, + stop_strings_per_step=[".", "。"], + ): + input_ids = inputs.get("input_ids") + attention_mask = inputs.get("attention_mask", None) + + if max_tokens_per_step > max_new_tokens: + max_tokens_per_step = max_new_tokens + + # Keep track of the full generated sequence full_generated_ids = input_ids.clone() + # Ensure full_attention_mask is correctly initialized and expanded + full_attention_mask = ( + attention_mask.clone() + if attention_mask is not None + 
else torch.ones_like(input_ids, device=input_ids.device) + ) + full_generated_ids = input_ids.clone() + + # KV cache + # past_key_values = None + + # Inputs for the current step + current_input_ids = full_generated_ids + # The attention mask passed to generate should cover the sequence length for the current step + current_attention_mask = full_attention_mask + + total_new_tokens_generated = 0 + generated_text = "" + hidden_states = None + hidden_states_len = 0 + + times = [] + while total_new_tokens_generated < max_new_tokens: + # Prepare inputs for generate call + # print(current_input_ids, current_attention_mask.shape) + # https://huggingface.co/docs/transformers/v4.51.3/en/main_classes/text_generation#transformers.GenerationMixin.generate + model_inputs = { + "input_ids": current_input_ids, + "attention_mask": current_attention_mask, + # "past_key_values": past_key_values, + "use_cache": True, + "use_audio_in_video": use_audio_in_video, + "do_sample": True, + "top_k": 10, + "top_p": 0.9, + "temperature": 0.95, + "repetition_penalty": 1.1, + "min_new_tokens": 1, # Ensure at least one token is generated if possible + "max_new_tokens": max_tokens_per_step, # Generate in smaller steps + # output_hidden_states/scores can consume memory, + # enable if needed downstream(talker) + "output_hidden_states": output_hidden_states, + "return_dict_in_generate": True, + # "output_scores": True, + "eos_token_id": eos_token_ids, + "pad_token_id": processor.tokenizer.pad_token_id, + } + model_inputs = {**inputs, **model_inputs} + for k, v in model_inputs.items(): + if isinstance(v, torch.Tensor): + print(k, v.shape) + else: + print(k, v) + if len(stop_strings_per_step) > 0: + model_inputs["stop_strings"] = stop_strings_per_step + model_inputs["tokenizer"] = processor.tokenizer + + start_time = perf_counter() + outputs = model.thinker.generate(**model_inputs) + times.append(perf_counter() - start_time) + + # Extract newly generated token IDs *for this step* + # `outputs.sequences` contains the input_ids for this step + new tokens generated in this step + step_new_ids = outputs.sequences[:, current_input_ids.shape[1] :] + num_step_new_tokens = step_new_ids.shape[1] + + if num_step_new_tokens == 0: # Handle case where generate stops early + print("Warning: generate produced 0 new tokens in this step.") + break + + if output_hidden_states is True: + hidden_states = outputs.hidden_states + print(hidden_states[0][0].shape) + hidden_states_len = ( + hidden_states_len if hidden_states_len > 0 else hidden_states[0][0].shape[1] + ) + print(f"hidden_states_len: {hidden_states_len}") + # new generate thinker_token_embeds + thinker_new_token_embeds = hidden_states[0][0][:, :hidden_states_len, :] + hidden_states = ( + (thinker_new_token_embeds,) + hidden_states[0][1:], + ) + hidden_states[1:] + # new generate thinker_hidden_states + thinker_new_hidden_states = hidden_states[0][-1][:, :hidden_states_len, :] + hidden_states = ( + hidden_states[0][:-1] + (thinker_new_hidden_states,), + ) + hidden_states[1:] + + # Decode and print only the text generated in this step + step_new_text = processor.decode(step_new_ids[0], skip_special_tokens=True) + yield { + "thinker_generate_text": step_new_text, + "thinker_generate_ids": step_new_ids, + "thinker_generate_hidden_states": hidden_states, + } # TODO: put async queue here + generated_text += step_new_text + total_new_tokens_generated += num_step_new_tokens + + # Update the full sequence + full_generated_ids = torch.cat([full_generated_ids, step_new_ids], dim=1) + + # Prepare 
for the next iteration: + # Input is only the last generated token + # NOTE: need use past_key_values to keep the context by manually, + # current_input_ids = step_new_ids[:, -1:] + # so we can't use the last generated token, use cache instead + # input ids need to be the full sequence for next generation + current_input_ids = full_generated_ids + + # Update past_key_values + # past_key_values = outputs.past_key_values + + # Update attention mask by appending 1s for the new tokens + full_attention_mask = torch.cat( + [full_attention_mask, torch.ones_like(step_new_ids)], dim=1 + ) + current_attention_mask = full_attention_mask + + # torch.cuda.empty_cache() + + # Check if EOS token was generated in this step + if step_new_ids[0, -1].item() in eos_token_ids: + print("EOS token generated.") + break + + # Check if max_new_tokens limit is reached (after processing the step) + if total_new_tokens_generated >= max_new_tokens: + print("Max new tokens limit reached.") + break + + print(f"Total generated text: {generated_text}") + print(f"Total new tokens generated: {total_new_tokens_generated}") + print( + f"max_tokens_per_step: {max_tokens_per_step} | first chunk generated cost: {times[0]} s | total cost: {sum(times)} s" + ) + + def talker_generate_chunk( + inputs: dict, + thinker_chunk_stream, + speaker=DEFAULT_SPEAKER, + talker_eos_token_id: list[int] = [8292, 8294], + mask_embedding: bool = True, + ): + input_ids = inputs.get("input_ids") + attention_mask = inputs.get("attention_mask", None) + + for chunk in thinker_chunk_stream: + thinker_generate_text = chunk["thinker_generate_text"] + if thinker_generate_text in " \n\r,;.?!,;。?!": + yield (thinker_generate_text, torch.empty([1, 0])) + continue + thinker_generate_hidden_states = chunk["thinker_generate_hidden_states"] + if thinker_generate_hidden_states is None: + yield (thinker_generate_text, torch.empty([1, 0])) + continue + + processed_thinker_hidden = thinker_generate_hidden_states + if mask_embedding is True: + print(f"mask embedding") + embeds_to_talker = ( + thinker_generate_hidden_states[0][0].clone().to(model.talker.device) + ) + if inputs.get("input_features", None) is not None: + audio_ids_mask = input_ids == model.config.thinker_config.audio_token_index + audio_mask = ( + audio_ids_mask.unsqueeze(-1) + .expand_as(embeds_to_talker) + .to(embeds_to_talker.device) + ) + audio_mask_tensor = torch.zeros( + [audio_ids_mask.sum(), embeds_to_talker.shape[-1]], + dtype=embeds_to_talker.dtype, + device=model.talker.device, + ) + embeds_to_talker.masked_scatter_(audio_mask, audio_mask_tensor) + if inputs.get("pixel_values", None) is not None: + image_ids_mask = input_ids == model.config.thinker_config.image_token_index + image_mask = ( + image_ids_mask.unsqueeze(-1) + .expand_as(embeds_to_talker) + .to(embeds_to_talker.device) + ) + image_mask_tensor = torch.zeros( + [image_ids_mask.sum(), embeds_to_talker.shape[-1]], + dtype=embeds_to_talker.dtype, + device=model.talker.device, + ) + embeds_to_talker.masked_scatter_(image_mask, image_mask_tensor) + if inputs.get("pixel_values_videos", None) is not None: + video_ids_mask = input_ids == model.config.thinker_config.video_token_index + video_mask = ( + video_ids_mask.unsqueeze(-1) + .expand_as(embeds_to_talker) + .to(embeds_to_talker.device) + ) + video_mask_tensor = torch.zeros( + [video_ids_mask.sum(), embeds_to_talker.shape[-1]], + dtype=embeds_to_talker.dtype, + device=model.talker.device, + ) + embeds_to_talker.masked_scatter_(video_mask, video_mask_tensor) + + processed_thinker_hidden = ( 
+ (embeds_to_talker,) + thinker_generate_hidden_states[0][1:], + ) + thinker_generate_hidden_states[1:] + + thinker_generate_ids = chunk["thinker_generate_ids"].to(model.talker.device) + thinker_token_embeds = [ + token_hidden_states[0].to(model.talker.device) + for token_hidden_states in processed_thinker_hidden + ] + thinker_hidden_states = [ + token_hidden_states[-1].to(model.talker.device) + for token_hidden_states in processed_thinker_hidden + ] + print( + f"[{thinker_generate_text}] len(thinker_generate_hidden_states):{len(processed_thinker_hidden)}" + ) + for i in range(len(processed_thinker_hidden)): + print( + f"[{thinker_generate_text}] thinker_generate_hidden_states[{i}]:{processed_thinker_hidden[i][0].shape}, {processed_thinker_hidden[i][-1].shape}" + ) + # print( + # f"[{thinker_generate_text}] thinker_generate_hidden_states[0]:{thinker_generate_hidden_states[0][0][:,:5,:]}, {thinker_generate_hidden_states[0][-1][:,:5,:]}" + # ) + + talker_text_bos_token = model.speaker_map[speaker]["bos_token"] + talker_input_text_ids = torch.cat( + [ + input_ids.to(model.talker.device), + torch.tensor( + [[talker_text_bos_token]], dtype=torch.long, device=model.talker.device + ), + thinker_generate_ids[:, :1], + ], + dim=-1, + ) + print( + f"[{thinker_generate_text}] talker_input_text_ids.shape:{talker_input_text_ids.shape}" + ) + + talker_input_ids = torch.cat( + [ + torch.full_like( + input_ids, + fill_value=model.talker.codec_mask_token, + device=model.talker.device, + ), + torch.tensor( + [[model.talker.codec_pad_token]], + dtype=torch.long, + device=model.talker.device, + ), + torch.tensor( + [[model.talker.codec_bos_token]], + dtype=torch.long, + device=model.talker.device, + ), + ], + dim=1, + ) + print(f"[{thinker_generate_text}] talker_input_ids.shape:{talker_input_ids.shape}") + + thinker_embed_tokens = model.thinker.get_input_embeddings() + thinker_reply_part = torch.cat(thinker_hidden_states[1:], dim=1) + torch.cat( + thinker_token_embeds[1:], dim=1 + ) + talker_inputs_embeds = thinker_hidden_states[0] + thinker_token_embeds[0] + talker_text_bos_token = torch.tensor( + [[talker_text_bos_token]], dtype=torch.long, device=model.thinker.device + ) + talker_text_bos_embed = thinker_embed_tokens(talker_text_bos_token).to( + model.talker.device + ) + print( + f"[{thinker_generate_text}] talker_inputs_embeds.shape {talker_inputs_embeds.shape} talker_text_bos_embed.shape {talker_text_bos_embed.shape} thinker_reply_part.shape {thinker_reply_part.shape}" + ) + talker_inputs_embeds = torch.cat( + [ + talker_inputs_embeds, + talker_text_bos_embed, + thinker_reply_part[:, :1, :], + ], + dim=1, + ) + print( + f"[{thinker_generate_text}] talker_inputs_embeds.shape {talker_inputs_embeds.shape} talker_text_bos_embed.shape {talker_text_bos_embed.shape}" + ) + + eos_embedding = thinker_embed_tokens( + torch.tensor( + [[model.talker.text_eos_token]], dtype=torch.long, device=model.thinker.device + ) + ).to(model.talker.device) + + pad_embedding = thinker_embed_tokens( + torch.tensor( + [[model.talker.text_pad_token]], dtype=torch.long, device=model.thinker.device + ) + ).to(model.talker.device) + thinker_reply_part = torch.cat( + [ + thinker_reply_part[:, 1:, :], + eos_embedding, + pad_embedding, + ], + dim=1, + ) + print(f"[{thinker_generate_text}] thinker_reply_part.shape:{thinker_reply_part.shape}") + + talker_attention_mask = None + if attention_mask is not None: + talker_attention_mask = torch.cat( + [attention_mask, attention_mask.new_ones((1, 2))], dim=1 + ).to(model.talker.device) + + 
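+            # Launch talker.generate in a background thread and stream its codec tokens
+            # through TokenStreamer; the loop below buffers the streamed codes and decodes
+            # them with token2wav in fixed-size chunks, so audio is yielded incrementally
+            # instead of waiting for the full code sequence.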
streamer = TokenStreamer(skip_prompt=True) + talker_kwargs = dict( + input_ids=talker_input_ids, + streamer=streamer, + input_text_ids=talker_input_text_ids, + thinker_reply_part=thinker_reply_part, + inputs_embeds=talker_inputs_embeds, + attention_mask=talker_attention_mask, + suppress_tokens=[model.talker.codec_bos_token], + eos_token_id=talker_eos_token_id, + pad_token_id=8292, + do_sample=True, + top_k=10, + top_p=0.9, + temperature=0.95, + repetition_penalty=1.1, + min_new_tokens=0, + max_new_tokens=8192, + ) + # print(talker_kwargs.keys()) + thread = Thread(target=model.talker.generate, kwargs=talker_kwargs) + thread.start() + + # 3. Generate wavs from code + if model.token2wav.dtype != torch.float: + model.token2wav.float() + + code2wav_times = [] + talker_generate_codes = [] + times = [] + start_time = perf_counter() + pre_offset = 0 + for token_id in streamer: + times.append(perf_counter() - start_time) + start_time = perf_counter() + if token_id in talker_eos_token_id: + break + talker_generate_codes.append(token_id) + chunk_code_length = len(talker_generate_codes) * 2 - 24 + if chunk_code_length > 0 and chunk_code_length % 48 == 0: + codes_tensor = torch.tensor( + [talker_generate_codes[pre_offset:]], + dtype=torch.long, + device=model.talker.device, + ) + pre_offset = len(talker_generate_codes) + wav = ( + model.token2wav( + codes_tensor.to(model.token2wav.device), + conditioning=model.speaker_map[speaker]["cond"] + .to(model.token2wav.device) + .float(), + reference_mel=model.speaker_map[speaker]["ref_mel"] + .to(model.token2wav.device) + .float(), + num_steps=10, + guidance_scale=0.5, + sway_coefficient=-1.0, + ) + .unsqueeze(0) + .detach() + ) + code2wav_times.append(perf_counter() - start_time) + yield (thinker_generate_text, wav) + start_time = perf_counter() + + print( + f"[{thinker_generate_text}] generate first token cost time: {times[0]} s, {len(times)} tokens cost time: {sum(times)} s" + ) + + if len(talker_generate_codes) > pre_offset: + codes_tensor = torch.tensor( + [talker_generate_codes[pre_offset:]], + dtype=torch.long, + device=model.talker.device, + ) + wav = ( + model.token2wav( + codes_tensor.to(model.token2wav.device), + conditioning=model.speaker_map[speaker]["cond"] + .to(model.token2wav.device) + .float(), + reference_mel=model.speaker_map[speaker]["ref_mel"] + .to(model.token2wav.device) + .float(), + num_steps=10, + guidance_scale=0.5, + sway_coefficient=-1.0, + ) + .unsqueeze(0) + .detach() + ) + code2wav_times.append(perf_counter() - start_time) + yield (thinker_generate_text, wav) + + print( + f"[{thinker_generate_text}] code2wav streaming first chunk time: {code2wav_times[0]} s | cost: {sum(code2wav_times)} s" + ) + + torch.cuda.empty_cache() + + def generate_stream( + messages, + use_audio_in_video=False, + speaker=DEFAULT_SPEAKER, + thinker_max_new_tokens=2048, + thinker_max_tokens_per_step=10, # Controls how many tokens to generate *per step* + thinker_stop_strings_per_step=[".", "。"], + thinker_eos_token_ids=[151644, 151645], # Define EOS tokens + mask_embedding: bool = False, + ): + print(messages) + text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) + # image_inputs, video_inputs = process_vision_info([messages]) + audios, images, videos = process_mm_info(messages, use_audio_in_video=use_audio_in_video) + print(text) + {print(f"audios[{i}]: {item.shape}") for i, item in enumerate(audios)} if audios else print( + audios + ) + {print(f"images[{i}]: {item}") for i, item in enumerate(images)} if images else 
print( + images + ) + {print(f"videos[{i}]: {item.shape}") for i, item in enumerate(videos)} if videos else print( + videos + ) + + inputs = processor( + text=text, + audio=audios, + images=images, + videos=videos, + return_tensors="pt", + padding=True, + use_audio_in_video=use_audio_in_video, + ) + inputs = inputs.to(model.device).to(model.dtype) + for k, v in inputs.items(): + print(k, v.shape) + thinker_chunk_stream = thinker_generate_chunk( + inputs, + max_new_tokens=thinker_max_new_tokens, + max_tokens_per_step=thinker_max_tokens_per_step, + stop_strings_per_step=thinker_stop_strings_per_step, + use_audio_in_video=use_audio_in_video, + eos_token_ids=thinker_eos_token_ids, + output_hidden_states=True, + ) + return talker_generate_chunk( + inputs=inputs, + thinker_chunk_stream=thinker_chunk_stream, + speaker=speaker, + mask_embedding=mask_embedding, + ) + + @torch.no_grad() + def thinker_talker_inference_stream( + messages, + use_audio_in_video=False, + speaker=DEFAULT_SPEAKER, + talker_eos_token_id: list[int] = [8292, 8294], + ): + print(messages) + text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) + # image_inputs, video_inputs = process_vision_info([messages]) + audios, images, videos = process_mm_info(messages, use_audio_in_video=use_audio_in_video) + print(text) + {print(f"audios[{i}]: {item.shape}") for i, item in enumerate(audios)} if audios else print( + audios + ) + {print(f"images[{i}]: {item}") for i, item in enumerate(images)} if images else print( + images + ) + {print(f"videos[{i}]: {item.shape}") for i, item in enumerate(videos)} if videos else print( + videos + ) + inputs = processor( + text=text, + audio=audios, + images=images, + videos=videos, + return_tensors="pt", + padding=True, + use_audio_in_video=use_audio_in_video, + ) + inputs = inputs.to(model.device).to(model.dtype) + + thinker_result = model.thinker.generate( + **inputs, + use_audio_in_video=use_audio_in_video, + do_sample=True, + top_k=10, + top_p=0.9, + temperature=0.95, + repetition_penalty=1.1, + min_new_tokens=0, + max_new_tokens=2048, + output_hidden_states=True, + return_dict_in_generate=True, + ) + print(f" len(thinker_generate_hidden_states):{len(thinker_result.hidden_states)}") + for i in range(len(thinker_result.hidden_states)): + print( + f"thinker_generate_hidden_states[{i}]:{thinker_result.hidden_states[i][0].shape}, {thinker_result.hidden_states[i][-1].shape}" + ) + # 2. 
Generate speech tokens from talker module + input_ids = inputs["input_ids"] + + embeds_to_talker = thinker_result.hidden_states[0][0].clone().to(model.talker.device) + if inputs.get("input_features", None) is not None: + audio_ids_mask = input_ids == model.config.thinker_config.audio_token_index + audio_mask = ( + audio_ids_mask.unsqueeze(-1).expand_as(embeds_to_talker).to(embeds_to_talker.device) + ) + audio_mask_tensor = torch.zeros( + [audio_ids_mask.sum(), embeds_to_talker.shape[-1]], + dtype=embeds_to_talker.dtype, + device=model.talker.device, + ) + embeds_to_talker.masked_scatter_(audio_mask, audio_mask_tensor) + if inputs.get("pixel_values", None) is not None: + image_ids_mask = input_ids == model.config.thinker_config.image_token_index + image_mask = ( + image_ids_mask.unsqueeze(-1).expand_as(embeds_to_talker).to(embeds_to_talker.device) + ) + image_mask_tensor = torch.zeros( + [image_ids_mask.sum(), embeds_to_talker.shape[-1]], + dtype=embeds_to_talker.dtype, + device=model.talker.device, + ) + embeds_to_talker.masked_scatter_(image_mask, image_mask_tensor) + if inputs.get("pixel_values_videos", None) is not None: + video_ids_mask = input_ids == model.config.thinker_config.video_token_index + video_mask = ( + video_ids_mask.unsqueeze(-1).expand_as(embeds_to_talker).to(embeds_to_talker.device) + ) + video_mask_tensor = torch.zeros( + [video_ids_mask.sum(), embeds_to_talker.shape[-1]], + dtype=embeds_to_talker.dtype, + device=model.talker.device, + ) + embeds_to_talker.masked_scatter_(video_mask, video_mask_tensor) + + processed_thinker_hidden = ( + (embeds_to_talker,) + thinker_result.hidden_states[0][1:], + ) + thinker_result.hidden_states[1:] + + thinker_generate_ids = thinker_result.sequences[:, input_ids.size(1) :].to( + model.talker.device + ) + thinker_token_embeds = [ + token_hidden_states[0].to(model.talker.device) + for token_hidden_states in processed_thinker_hidden + ] + thinker_hidden_states = [ + token_hidden_states[-1].to(model.talker.device) + for token_hidden_states in processed_thinker_hidden + ] + gen_text = processor.batch_decode( + thinker_result.sequences, + skip_special_tokens=True, + clean_up_tokenization_spaces=False, + ) + # print(gen_text) + + talker_text_bos_token = model.speaker_map[speaker]["bos_token"] + talker_input_text_ids = torch.cat( + [ + input_ids.to(model.talker.device), + torch.tensor( + [[talker_text_bos_token]], dtype=torch.long, device=model.talker.device + ), + thinker_generate_ids[:, :1], + ], + dim=-1, + ) + print(f"talker_input_text_ids.shape:{talker_input_text_ids.shape}") + + talker_input_ids = torch.cat( + [ + torch.full_like( + input_ids, fill_value=model.talker.codec_mask_token, device=model.talker.device + ), + torch.tensor( + [[model.talker.codec_pad_token]], dtype=torch.long, device=model.talker.device + ), + torch.tensor( + [[model.talker.codec_bos_token]], dtype=torch.long, device=model.talker.device + ), + ], + dim=1, + ) + print(f"talker_input_ids.shape:{talker_input_ids.shape}") + + thinker_embed_tokens = model.thinker.get_input_embeddings() + thinker_reply_part = torch.cat(thinker_hidden_states[1:], dim=1) + torch.cat( + thinker_token_embeds[1:], dim=1 + ) + talker_inputs_embeds = thinker_hidden_states[0] + thinker_token_embeds[0] + talker_text_bos_token = torch.tensor( + [[talker_text_bos_token]], dtype=torch.long, device=model.thinker.device + ) + talker_text_bos_embed = thinker_embed_tokens(talker_text_bos_token).to(model.talker.device) + print( + f"talker_inputs_embeds.shape {talker_inputs_embeds.shape} 
talker_text_bos_embed.shape {talker_text_bos_embed.shape} thinker_reply_part.shape {thinker_reply_part.shape}" + ) + talker_inputs_embeds = torch.cat( + [ + talker_inputs_embeds, + talker_text_bos_embed, + thinker_reply_part[:, :1, :], + ], + dim=1, + ) + print( + f"talker_inputs_embeds.shape {talker_inputs_embeds.shape} talker_text_bos_embed.shape {talker_text_bos_embed.shape}" + ) + + eos_embedding = thinker_embed_tokens( + torch.tensor( + [[model.talker.text_eos_token]], dtype=torch.long, device=model.thinker.device + ) + ).to(model.talker.device) + + pad_embedding = thinker_embed_tokens( + torch.tensor( + [[model.talker.text_pad_token]], dtype=torch.long, device=model.thinker.device + ) + ).to(model.talker.device) + + thinker_reply_part = torch.cat( + [ + thinker_reply_part[:, 1:, :], + eos_embedding, + pad_embedding, + ], + dim=1, + ) + print(f"thinker_reply_part.shape:{thinker_reply_part.shape}") + + talker_attention_mask = None + if "attention_mask" in inputs: + talker_attention_mask = torch.cat( + [inputs["attention_mask"], inputs["attention_mask"].new_ones((1, 2))], dim=1 + ).to(model.talker.device) + + # talker_result = model.talker.generate( + # input_ids=talker_input_ids, + # input_text_ids=talker_input_text_ids, + # thinker_reply_part=thinker_reply_part, + # inputs_embeds=talker_inputs_embeds, + # attention_mask=talker_attention_mask, + # suppress_tokens=[model.talker.codec_bos_token], + # do_sample=True, + # top_k=10, + # top_p=0.9, + # temperature=0.95, + # repetition_penalty=1.1, + # min_new_tokens=0, + # max_new_tokens=8192, + # ) + # talker_generate_codes = talker_result[:, talker_input_ids.shape[1] : -1] + # print(talker_generate_codes) + + streamer = TokenStreamer(skip_prompt=True) + talker_kwargs = dict( + input_ids=talker_input_ids, + streamer=streamer, + input_text_ids=talker_input_text_ids, + thinker_reply_part=thinker_reply_part, + inputs_embeds=talker_inputs_embeds, + attention_mask=talker_attention_mask, + suppress_tokens=[model.talker.codec_bos_token], + eos_token_id=talker_eos_token_id, + pad_token_id=8292, + do_sample=True, + top_k=10, + top_p=0.9, + temperature=0.95, + repetition_penalty=1.1, + min_new_tokens=0, + max_new_tokens=8192, + ) + # print(talker_kwargs.keys()) + thread = Thread(target=model.talker.generate, kwargs=talker_kwargs) + thread.start() + + # 3. 
Generate wavs from code + if model.token2wav.dtype != torch.float: + model.token2wav.float() + + code2wav_times = [] + talker_generate_codes = [] + times = [] + start_time = perf_counter() + pre_offset = 0 + for token_id in streamer: + times.append(perf_counter() - start_time) + start_time = perf_counter() + if token_id in talker_eos_token_id: + break + talker_generate_codes.append(token_id) + chunk_code_length = len(talker_generate_codes) * 2 - 24 + if chunk_code_length > 0 and chunk_code_length % 48 == 0: + codes_tensor = torch.tensor( + [talker_generate_codes[pre_offset:]], + dtype=torch.long, + device=model.talker.device, + ) + pre_offset = len(talker_generate_codes) + wav = ( + model.token2wav( + codes_tensor.to(model.token2wav.device), + conditioning=model.speaker_map[speaker]["cond"] + .to(model.token2wav.device) + .float(), + reference_mel=model.speaker_map[speaker]["ref_mel"] + .to(model.token2wav.device) + .float(), + num_steps=10, + guidance_scale=0.5, + sway_coefficient=-1.0, + ) + .unsqueeze(0) + .detach() + ) + code2wav_times.append(perf_counter() - start_time) + yield (gen_text, wav) + start_time = perf_counter() + + print( + f"generate first token cost time: {times[0]} s, {len(times)} tokens cost time: {sum(times)} s" + ) + + if len(talker_generate_codes) > pre_offset: + codes_tensor = torch.tensor( + [talker_generate_codes[pre_offset:]], + dtype=torch.long, + device=model.talker.device, + ) + wav = ( + model.token2wav( + codes_tensor.to(model.token2wav.device), + conditioning=model.speaker_map[speaker]["cond"] + .to(model.token2wav.device) + .float(), + reference_mel=model.speaker_map[speaker]["ref_mel"] + .to(model.token2wav.device) + .float(), + num_steps=10, + guidance_scale=0.5, + sway_coefficient=-1.0, + ) + .unsqueeze(0) + .detach() + ) + code2wav_times.append(perf_counter() - start_time) + yield (gen_text, wav) + + print( + f"code2wav streaming first chunk time: {code2wav_times[0]} s | cost: {sum(code2wav_times)} s" + ) + + torch.cuda.empty_cache() + + +@app.function( + gpu=os.getenv("IMAGE_GPU", "L40s"), + cpu=2.0, + retries=0, + image=omni_img, + volumes={ + HF_MODEL_DIR: hf_model_vol, + ASSETS_DIR: assets_dir, + }, + timeout=1200, # default 300s + scaledown_window=1200, + max_containers=1, +) +def run(func): + func() + + +def voice_chatting(): + import torchaudio + + sys_msg = { + "role": "system", + "content": [{"type": "text", "text": SPEECH_SYS_PROMPT}], + } + + for audio_path in ["guess_age_gender.wav", "translate_to_chinese.wav"]: + audio_path = os.path.join(ASSETS_DIR, audio_path) + audio_msg = [ + sys_msg, + { + "role": "user", + "content": [ + {"type": "audio", "audio": audio_path}, + ], + }, + ] + + texts, audio = inference(audio_msg, return_audio=True, use_audio_in_video=True) + print(texts[0], audio.shape) + + save_audio_path = os.path.join(ASSETS_DIR, f"generated_{os.path.basename(audio_path)}") + torchaudio.save(save_audio_path, audio, sample_rate=24000) + print(f"Audio saved to {save_audio_path}") + + +def multi_round_omni_chatting(): + import torchaudio + + conversations = [ + { + "role": "system", + "content": [{"type": "text", "text": SPEECH_SYS_PROMPT}], + }, + ] + for video_path in ["draw1.mp4", "draw2.mp4", "draw3.mp4"]: + video_path = os.path.join(ASSETS_DIR, video_path) + conversations.append({"role": "user", "content": [{"type": "video", "video": video_path}]}) + texts, audio = inference(conversations, return_audio=True, use_audio_in_video=True) + print(texts[0], audio.shape) + save_audio_path = os.path.join(ASSETS_DIR, 
f"generated_{os.path.basename(video_path)}") + torchaudio.save(save_audio_path, audio, sample_rate=24000) + print(f"Audio saved to {save_audio_path}") + + +def omni_chatting_for_math(): + import torchaudio + + video_path = os.path.join(ASSETS_DIR, "math.mp4") + messages = [ + { + "role": "system", + "content": [{"type": "text", "text": SPEECH_SYS_PROMPT}], + }, + { + "role": "user", + "content": [ + {"type": "video", "video": video_path}, + ], + }, + ] + + response, audio = inference(messages, return_audio=True, use_audio_in_video=True) + print(response[0], audio.shape) + + save_audio_path = os.path.join(ASSETS_DIR, f"generated_{os.path.basename(video_path)}") + torchaudio.save(save_audio_path, audio, sample_rate=24000) + print(f"Audio saved to {save_audio_path}") + + +def omni_chatting_for_math_stream(): + import torchaudio + import soundfile as sf + + video_path = os.path.join(ASSETS_DIR, "math.mp4") + messages = [ + { + "role": "system", + "content": [{"type": "text", "text": SPEECH_SYS_PROMPT}], + }, + { + "role": "user", + "content": [ + {"type": "video", "video": video_path}, + ], + }, + ] + + for _ in range(1): # warmup and test + streamer = thinker_talker_inference_stream(messages, use_audio_in_video=True) + audios = [] + times = [] + start_time = time.perf_counter() + for i, (texts, audio) in enumerate(streamer): + if i == 0: + print(texts[0]) + times.append(time.perf_counter() - start_time) + audios.append(audio.squeeze().cpu().numpy()) + # save_audio_path = os.path.join(ASSETS_DIR, f"omni_chatting_for_math_stream_{i}.wav") + # torchaudio.save(save_audio_path, audio, sample_rate=24000) + # print(f"Audio saved to {save_audio_path}") + start_time = time.perf_counter() + + save_audio_path = os.path.join(ASSETS_DIR, f"omni_chatting_for_math_stream.wav") + sf.write(save_audio_path, np.concatenate(audios), samplerate=24000) + print(f"Audio saved to {save_audio_path}") + info = sf.info(save_audio_path, verbose=True) + print( + f"thinker->talker->code2wav streaming first chunk time: {times[0]} s | wav duration: {info.duration} s | cost: {sum(times)} s | RTF: {sum(times)/info.duration}" + ) + + +def omni_chatting_for_math_chunk_stream(): + import torchaudio + import soundfile as sf + + video_path = os.path.join(ASSETS_DIR, "math.mp4") + messages = [ + { + "role": "system", + "content": [{"type": "text", "text": SPEECH_SYS_PROMPT}], + }, + { + "role": "user", + "content": [ + {"type": "video", "video": video_path}, + ], + }, + ] + + thinker_eos_token_ids = [151644, 151645] + print(thinker_eos_token_ids) + for _ in range(1): # warmup and test + streamer = generate_stream( + messages, + use_audio_in_video=True, + thinker_max_new_tokens=100, + thinker_max_tokens_per_step=15, + thinker_stop_strings_per_step=[".", "。"], + thinker_eos_token_ids=thinker_eos_token_ids, + ) + gen_text = "" + gen_all_text = "" + audios = [] + times = [] + start_time = time.perf_counter() + for i, (text, wav) in enumerate(streamer): + times.append(time.perf_counter() - start_time) + if gen_text != text: + gen_text = text + gen_all_text += gen_text + print(text, wav.shape) + audios.append(wav.squeeze().cpu().numpy()) + # save_audio_path = os.path.join(ASSETS_DIR, f"omni_chatting_for_math_chunk_stream-{i}-{text}.wav") + # torchaudio.save(save_audio_path, wav, sample_rate=24000) + # print(f"Audio saved to {save_audio_path}") + start_time = time.perf_counter() + + print(f"gen all text: {gen_all_text}") + save_audio_path = os.path.join(ASSETS_DIR, f"omni_chatting_for_math_chunk_stream.wav") + sf.write(save_audio_path, 
np.concatenate(audios), samplerate=24000) + print(f"All Audio saved to {save_audio_path}") + info = sf.info(save_audio_path, verbose=True) + print( + f"thinker->talker->code2wav chunk streaming first chunk time: {times[0]} s | wav duration: {info.duration} s | cost: {sum(times)} s | RTF: {sum(times)/info.duration}" + ) + + +def omni_chatting_for_music(): + import torchaudio + + video_path = os.path.join(ASSETS_DIR, "music.mp4") + messages = [ + { + "role": "system", + "content": [{"type": "text", "text": SPEECH_SYS_PROMPT}], + }, + { + "role": "user", + "content": [ + {"type": "video", "video": video_path}, + ], + }, + ] + response, audio = inference(messages, return_audio=True, use_audio_in_video=True) + print(response[0]) + + save_audio_path = os.path.join(ASSETS_DIR, f"generated_{os.path.basename(video_path)}") + torchaudio.save(save_audio_path, audio, sample_rate=24000) + print(f"Audio saved to {save_audio_path}") + + +def omni_chatting_for_music_stream(): + import torchaudio + import soundfile as sf + + video_path = os.path.join(ASSETS_DIR, "music.mp4") + messages = [ + { + "role": "system", + "content": [{"type": "text", "text": SPEECH_SYS_PROMPT}], + }, + { + "role": "user", + "content": [ + {"type": "video", "video": video_path}, + ], + }, + ] + for _ in range(1): # warmup and test + streamer = thinker_talker_inference_stream(messages, use_audio_in_video=True) + audios = [] + times = [] + start_time = time.perf_counter() + for i, (texts, audio) in enumerate(streamer): + if i == 0: + print(texts[0]) + times.append(time.perf_counter() - start_time) + audios.append(audio.squeeze().cpu().numpy()) + # save_audio_path = os.path.join(ASSETS_DIR, f"omni_chatting_for_math_stream_{i}.wav") + # torchaudio.save(save_audio_path, audio, sample_rate=24000) + # print(f"Audio saved to {save_audio_path}") + start_time = time.perf_counter() + + save_audio_path = os.path.join(ASSETS_DIR, f"omni_chatting_for_music_stream.wav") + sf.write(save_audio_path, np.concatenate(audios), samplerate=24000) + print(f"Audio saved to {save_audio_path}") + info = sf.info(save_audio_path, verbose=True) + print( + f"thinker->talker->code2wav streaming first chunk time: {times[0]} s | wav duration: {info.duration} s | cost: {sum(times)} s | RTF: {sum(times)/info.duration}" + ) + + +def omni_chatting_for_music_chunk_stream(): + import torchaudio + import soundfile as sf + + video_path = os.path.join(ASSETS_DIR, "music.mp4") + messages = [ + { + "role": "system", + "content": [{"type": "text", "text": SPEECH_SYS_PROMPT}], + }, + { + "role": "user", + "content": [ + {"type": "video", "video": video_path}, + ], + }, + ] + thinker_eos_token_ids = [151644, 151645] + print(thinker_eos_token_ids) + for _ in range(1): # warmup and test + streamer = generate_stream( + messages, + use_audio_in_video=True, + thinker_max_new_tokens=100, + thinker_max_tokens_per_step=15, + thinker_stop_strings_per_step=[".", "。"], + thinker_eos_token_ids=thinker_eos_token_ids, + ) + gen_text = "" + gen_all_text = "" + audios = [] + times = [] + start_time = time.perf_counter() + for i, (text, wav) in enumerate(streamer): + times.append(time.perf_counter() - start_time) + if gen_text != text: + gen_text = text + gen_all_text += gen_text + print(text, wav.shape) + audios.append(wav.squeeze().cpu().numpy()) + # save_audio_path = os.path.join(ASSETS_DIR, f"omni_chatting_for_music_chunk_stream-{i}-{text}.wav") + # torchaudio.save(save_audio_path, wav, sample_rate=24000) + # print(f"Audio saved to {save_audio_path}") + start_time = time.perf_counter() + + 
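+        # RTF (real-time factor) reported below = total streaming time / generated audio duration; < 1.0 means generation runs faster than real time.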
print(f"gen all text: {gen_all_text}") + save_audio_path = os.path.join(ASSETS_DIR, f"omni_chatting_for_music_chunk_stream.wav") + sf.write(save_audio_path, np.concatenate(audios), samplerate=24000) + print(f"All Audio saved to {save_audio_path}") + info = sf.info(save_audio_path, verbose=True) + print( + f"thinker->talker->code2wav chunk streaming first chunk time: {times[0]} s | wav duration: {info.duration} s | cost: {sum(times)} s | RTF: {sum(times)/info.duration}" + ) + + +def screen_recording_interaction(): + video_path = os.path.join(ASSETS_DIR, "screen.mp4") + for prompt in [ + "What the browser is used in this video?", + "浏览器中的论文叫什么名字?", + "这篇论文主要解决什么问题呢?", + ]: + messages = [ + { + "role": "system", + "content": [{"type": "text", "text": "You are a helpful assistant."}], + }, + { + "role": "user", + "content": [ + {"type": "text", "text": prompt}, + {"type": "video", "video": video_path}, + ], + }, + ] + + response, _ = inference(messages, return_audio=False, use_audio_in_video=False) + print(response[0]) + + +def screen_recording_interaction_stream(): + video_path = os.path.join(ASSETS_DIR, "screen.mp4") + for prompt in [ + "What the browser is used in this video?", + "浏览器中的论文叫什么名字?", + "这篇论文主要解决什么问题呢?", + ]: + messages = [ + { + "role": "system", + "content": [{"type": "text", "text": "You are a helpful assistant."}], + }, + { + "role": "user", + "content": [ + {"type": "text", "text": prompt}, + {"type": "video", "video": video_path}, + ], + }, + ] + + text_stream = thinker_inference_stream(messages, use_audio_in_video=False) + for text in text_stream: + print(text) + + +def screen_recording_interaction_chunk_stream(): + video_path = os.path.join(ASSETS_DIR, "screen.mp4") + for prompt in [ + "What the browser is used in this video?", + "浏览器中的论文叫什么名字?", + "这篇论文主要解决什么问题呢?", + ]: + messages = [ + { + "role": "system", + "content": [{"type": "text", "text": "You are a helpful assistant."}], + }, + { + "role": "user", + "content": [ + {"type": "text", "text": prompt}, + {"type": "video", "video": video_path}, + ], + }, + ] + + text_stream = thinker_inference_chunk_stream(messages, use_audio_in_video=False) + for text in text_stream: + print(text) + + +def universal_audio_understanding(): + for case in [ + { + "audio_path": "1272-128104-0000.flac", + "prompt": "Transcribe the English audio into text without any punctuation marks.", + "sys_prompt": "You are a speech recognition model.", + }, + { + "audio_path": "BAC009S0764W0121.wav", + "prompt": "请将这段中文语音转换为纯文本,去掉标点符号。", + "sys_prompt": "You are a speech recognition model.", + }, + { + "audio_path": "10000611681338527501.wav", + "prompt": "Transcribe the Russian audio into text without including any punctuation marks.", + "sys_prompt": "You are a speech recognition model.", + }, + { + "audio_path": "7105431834829365765.wav", + "prompt": "Transcribe the French audio into text without including any punctuation marks.", + "sys_prompt": "You are a speech recognition model.", + }, + { + "audio_path": "1272-128104-0000.flac", + "prompt": "Listen to the provided English speech and produce a translation in Chinese text.", + "sys_prompt": "You are a speech translation model.", + }, + { + "audio_path": "cough.wav", + "prompt": "Classify the given human vocal sound in English.", + "sys_prompt": "You are a voice classification model.", + }, + ]: + audio_path = os.path.join(ASSETS_DIR, case["audio_path"]) + messages = [ + {"role": "system", "content": [{"type": "text", "text": case["sys_prompt"]}]}, + { + "role": "user", + "content": [ + 
{"type": "text", "text": case["prompt"]}, + {"type": "audio", "audio": audio_path}, + ], + }, + ] + texts, _ = inference(messages, use_audio_in_video=True, return_audio=False) + print(texts[0]) + + +def video_information_extracting(): + video_path = os.path.join(ASSETS_DIR, "shopping.mp4") + sys_msg = { + "role": "system", + "content": [{"type": "text", "text": "You are a helpful assistant."}], + } + for prompt in [ + "How many kind of drinks can you see in the video?", + "How many bottles of drinks have I picked up?", + "How many milliliters are there in the bottle I picked up second time?", + "视屏中的饮料叫什么名字呢?", + "跑步🏃🏻累了,适合喝什么饮料补充体力呢?", + ]: + messages = [ + sys_msg, + { + "role": "user", + "content": [ + {"type": "text", "text": prompt}, + {"type": "video", "video": video_path}, + ], + }, + ] + texts, _ = inference(messages, return_audio=False, use_audio_in_video=False) + print(texts[0]) + + +def video_information_extracting_stream(): + video_path = os.path.join(ASSETS_DIR, "shopping.mp4") + sys_msg = { + "role": "system", + "content": [{"type": "text", "text": "You are a helpful assistant."}], + } + for prompt in [ + "How many kind of drinks can you see in the video?", + "How many bottles of drinks have I picked up?", + "How many milliliters are there in the bottle I picked up second time?", + "视屏中的饮料叫什么名字呢?", + "跑步🏃🏻累了,适合喝什么饮料补充体力呢?", + ]: + messages = [ + sys_msg, + { + "role": "user", + "content": [ + {"type": "text", "text": prompt}, + {"type": "video", "video": video_path}, + ], + }, + ] + text_stream = thinker_inference_stream(messages, use_audio_in_video=False) + for text in text_stream: + print(text) + + +def video_information_extracting_chunk_stream(): + video_path = os.path.join(ASSETS_DIR, "shopping.mp4") + sys_msg = { + "role": "system", + "content": [{"type": "text", "text": "You are a helpful assistant."}], + } + for prompt in [ + "How many kind of drinks can you see in the video?", + "How many bottles of drinks have I picked up?", + "How many milliliters are there in the bottle I picked up second time?", + "视屏中的饮料叫什么名字呢?", + "跑步🏃🏻累了,适合喝什么饮料补充体力呢?", + ]: + messages = [ + sys_msg, + { + "role": "user", + "content": [ + {"type": "text", "text": prompt}, + {"type": "video", "video": video_path}, + ], + }, + ] + text_stream = thinker_inference_chunk_stream(messages, use_audio_in_video=False) + for text in text_stream: + print(text) + + +def batch_requests(): + """need return_audio=False""" + # Conversation with video only + conversation1 = [ + {"role": "system", "content": [{"type": "text", "text": SPEECH_SYS_PROMPT}]}, + { + "role": "user", + "content": [ + {"type": "video", "video": os.path.join(ASSETS_DIR, "draw1.mp4")}, + ], + }, + ] + + # Conversation with audio only + conversation2 = [ + {"role": "system", "content": [{"type": "text", "text": SPEECH_SYS_PROMPT}]}, + { + "role": "user", + "content": [ + {"type": "audio", "audio": os.path.join(ASSETS_DIR, "1272-128104-0000.flac")}, + ], + }, + ] + + # Conversation with pure text + conversation3 = [ + {"role": "system", "content": [{"type": "text", "text": SPEECH_SYS_PROMPT}]}, + {"role": "user", "content": [{"type": "text", "text": "who are you?"}]}, + ] + + # Conversation with mixed media + conversation4 = [ + {"role": "system", "content": [{"type": "text", "text": SPEECH_SYS_PROMPT}]}, + { + "role": "user", + "content": [ + {"type": "image", "image": os.path.join(ASSETS_DIR, "03-Confusing-Pictures.jpg")}, + {"type": "video", "video": os.path.join(ASSETS_DIR, "music.mp4")}, + {"type": "audio", "audio": 
os.path.join(ASSETS_DIR, "1272-128104-0000.flac")}, + { + "type": "text", + "text": "What are the elements can you see and hear in these medias?", + }, + ], + }, + ] + + # Combine messages for batch processing + conversations = [conversation1, conversation2, conversation3, conversation4] + texts, _ = inference(conversations, return_audio=False, use_audio_in_video=True) + print(texts) + + +def image_stream(): + for case in [ + { + "image_path": "03-Confusing-Pictures.jpg", + "prompt": "请描述一下图片中的内容", + "sys_prompt": "You are a vision recognition model.", + }, + ]: + image_path = os.path.join(ASSETS_DIR, case["image_path"]) + messages = [ + {"role": "system", "content": [{"type": "text", "text": case["sys_prompt"]}]}, + { + "role": "user", + "content": [ + {"type": "text", "text": case["prompt"]}, + {"type": "image", "image": image_path}, + ], + }, + ] + text_streamer = thinker_inference_stream(messages, use_audio_in_video=False) + for text in text_streamer: + print(text) + + +def image_chunk_stream(): + for case in [ + { + "image_path": "03-Confusing-Pictures.jpg", + "prompt": "请描述一下图片中的内容", + "sys_prompt": "You are a vision recognition model.", + }, + ]: + image_path = os.path.join(ASSETS_DIR, case["image_path"]) + messages = [ + {"role": "system", "content": [{"type": "text", "text": case["sys_prompt"]}]}, + { + "role": "user", + "content": [ + {"type": "text", "text": case["prompt"]}, + {"type": "image", "image": image_path}, + ], + }, + ] + text_streamer = thinker_inference_chunk_stream( + messages, + use_audio_in_video=False, + output_hidden_states=False, + max_new_tokens=1024, + ) + for text in text_streamer: + print(text) + + +def asr_stream(): + for case in [ + { + "audio_path": "1272-128104-0000.flac", + "prompt": "Listen to the provided English speech and produce a translation in Chinese text.", + "sys_prompt": "You are a speech translation model.", + }, + { + "audio_path": "BAC009S0764W0121.wav", + "prompt": "请将这段中文语音转换为纯文本", + "sys_prompt": "You are a speech recognition model.", + }, + { + "audio_path": "asr_example_zh.wav", + "prompt": "请将这段中文语音转换为纯文本", + "sys_prompt": "You are a speech recognition model.", + }, + ]: + audio_path = os.path.join(ASSETS_DIR, case["audio_path"]) + messages = [ + {"role": "system", "content": [{"type": "text", "text": case["sys_prompt"]}]}, + { + "role": "user", + "content": [ + {"type": "text", "text": case["prompt"]}, + {"type": "audio", "audio": audio_path}, + ], + }, + ] + text_streamer = thinker_inference_stream(messages, use_audio_in_video=True) + for text in text_streamer: + print(text) + + +def asr_chunk_stream(): + for case in [ + { + "audio_path": "asr_example_zh.wav", + "prompt": "请将这段中文语音转换为纯文本", + "sys_prompt": "You are a speech recognition model.", + }, + ]: + audio_path = os.path.join(ASSETS_DIR, case["audio_path"]) + messages = [ + {"role": "system", "content": [{"type": "text", "text": case["sys_prompt"]}]}, + { + "role": "user", + "content": [ + {"type": "text", "text": case["prompt"]}, + {"type": "audio", "audio": audio_path}, + ], + }, + ] + chunk_streamer = thinker_inference_chunk_stream(messages, use_audio_in_video=False) + for chunk in chunk_streamer: + print(chunk) + + +def thinker_stream(): + messages = [ + {"role": "system", "content": [{"type": "text", "text": SPEECH_SYS_PROMPT}]}, + {"role": "user", "content": [{"type": "text", "text": "who are you?"}]}, + ] + chunk_stream = thinker_inference_stream(messages) + for chunk in chunk_stream: + print(chunk) + + +def thinker_chunk_stream(): + messages = [ + {"role": "system", 
"content": [{"type": "text", "text": SPEECH_SYS_PROMPT}]}, + {"role": "user", "content": [{"type": "text", "text": "who are you?"}]}, + ] + chunk_stream = thinker_inference_chunk_stream( + messages, + use_audio_in_video=False, + output_hidden_states=False, + max_new_tokens=100, + ) + for chunk in chunk_stream: + print(chunk) + + +def omni_chatting_stream(): + import torchaudio + import soundfile as sf + + messages = [ + {"role": "system", "content": [{"type": "text", "text": SPEECH_SYS_PROMPT}]}, + {"role": "user", "content": [{"type": "text", "text": "who are you?"}]}, + ] + # response, audio = inference( + # messages, return_audio=True, use_audio_in_video=False, thinker_do_sample=True + # ) + # print(response[0]) + # save_audio_path = os.path.join(ASSETS_DIR, f"generated_omni_chatting_stream.wav") + # torchaudio.save(save_audio_path, audio, sample_rate=24000) + # print(f"Audio saved to {save_audio_path}") + # return + + for _ in range(1): # warmup and test + streamer = thinker_talker_inference_stream(messages, use_audio_in_video=False) + audios = [] + times = [] + start_time = time.perf_counter() + for i, (texts, audio) in enumerate(streamer): + if i == 0: + print(texts[0]) + times.append(time.perf_counter() - start_time) + audios.append(audio.squeeze().cpu().numpy()) + # save_audio_path = os.path.join(ASSETS_DIR, f"generated_omni_chatting_stream_{i}.wav") + # torchaudio.save(save_audio_path, audio, sample_rate=24000) + # print(f"Audio saved to {save_audio_path}") + start_time = time.perf_counter() + + save_audio_path = os.path.join(ASSETS_DIR, f"generated_omni_chatting_stream.wav") + sf.write(save_audio_path, np.concatenate(audios), samplerate=24000) + info = sf.info(save_audio_path, verbose=True) + print( + f"thinker->talker->code2wav streaming first chunk time: {times[0]} s | wav duration: {info.duration} s | cost: {sum(times)} s | RTF: {sum(times)/info.duration}" + ) + + +def omni_chatting_segment_stream(): + import torchaudio + import soundfile as sf + + messages = [ + {"role": "system", "content": [{"type": "text", "text": SPEECH_SYS_PROMPT}]}, + {"role": "user", "content": [{"type": "text", "text": "who are you?"}]}, + ] + thinker_eos_token_ids = [151644, 151645] + print(thinker_eos_token_ids) + chunk_stream = generate_stream( + messages, + use_audio_in_video=False, + thinker_max_new_tokens=100, + thinker_max_tokens_per_step=15, + thinker_stop_strings_per_step=[".", "。"], + thinker_eos_token_ids=thinker_eos_token_ids, + ) + + gen_text = "" + gen_all_text = "" + audios = [] + times = [] + start_time = time.perf_counter() + for i, (text, wav) in enumerate(chunk_stream): + times.append(time.perf_counter() - start_time) + if gen_text != text: + gen_text = text + gen_all_text += gen_text + print(text, wav.shape) + audios.append(wav.squeeze().cpu().numpy()) + save_audio_path = os.path.join(ASSETS_DIR, f"omni_chatting_segment_stream-{i}-{text}.wav") + # torchaudio.save(save_audio_path, wav, sample_rate=24000) + # print(f"Audio saved to {save_audio_path}") + start_time = time.perf_counter() + + print(f"gen all text: {gen_all_text}") + save_audio_path = os.path.join(ASSETS_DIR, f"omni_chatting_segment_stream.wav") + sf.write(save_audio_path, np.concatenate(audios), samplerate=24000) + print(f"All Audio saved to {save_audio_path}") + info = sf.info(save_audio_path, verbose=True) + print( + f"thinker->talker->code2wav chunk streaming first chunk time: {times[0]} s | wav duration: {info.duration} s | cost: {sum(times)} s | RTF: {sum(times)/info.duration}" + ) + + +def image_chatting_stream(): 
+ import torchaudio + import soundfile as sf + + messages = [ + {"role": "system", "content": [{"type": "text", "text": SPEECH_SYS_PROMPT}]}, + { + "role": "user", + "content": [ + {"type": "text", "text": "请描述一下图片中的内容"}, + {"type": "image", "image": os.path.join(ASSETS_DIR, "03-Confusing-Pictures.jpg")}, + ], + }, + ] + # response, audio = inference( + # messages, return_audio=True, use_audio_in_video=False, thinker_do_sample=True + # ) + # print(response[0]) + # save_audio_path = os.path.join(ASSETS_DIR, f"generated_image_chatting_stream.wav") + # torchaudio.save(save_audio_path, audio, sample_rate=24000) + # print(f"Audio saved to {save_audio_path}") + # return + + for _ in range(1): # warmup and test + streamer = thinker_talker_inference_stream(messages, use_audio_in_video=False) + audios = [] + times = [] + start_time = time.perf_counter() + for i, (texts, audio) in enumerate(streamer): + if i == 0: + print(texts[0]) + times.append(time.perf_counter() - start_time) + audios.append(audio.squeeze().cpu().numpy()) + # save_audio_path = os.path.join(ASSETS_DIR, f"generated_image_chatting_stream_{i}.wav") + # torchaudio.save(save_audio_path, audio, sample_rate=24000) + # print(f"Audio saved to {save_audio_path}") + start_time = time.perf_counter() + + save_audio_path = os.path.join(ASSETS_DIR, f"generated_image_chatting_stream.wav") + sf.write(save_audio_path, np.concatenate(audios), samplerate=24000) + info = sf.info(save_audio_path, verbose=True) + print( + f"thinker->talker->code2wav streaming first chunk time: {times[0]} s | wav duration: {info.duration} s | cost: {sum(times)} s | RTF: {sum(times)/info.duration}" + ) + + +def image_chatting_segment_stream(): + import torchaudio + import soundfile as sf + + messages = [ + {"role": "system", "content": [{"type": "text", "text": SPEECH_SYS_PROMPT}]}, + { + "role": "user", + "content": [ + {"type": "text", "text": "请描述一下图片中的内容"}, + {"type": "image", "image": os.path.join(ASSETS_DIR, "03-Confusing-Pictures.jpg")}, + ], + }, + ] + thinker_eos_token_ids = [151644, 151645] + print(thinker_eos_token_ids) + mask_embedding = True + chunk_stream = generate_stream( + messages, + use_audio_in_video=False, + thinker_max_new_tokens=100, + thinker_max_tokens_per_step=15, + thinker_stop_strings_per_step=[",", ".", ",", "。"], + thinker_eos_token_ids=thinker_eos_token_ids, + mask_embedding=mask_embedding, + ) + + gen_text = "" + gen_all_text = "" + audios = [] + times = [] + start_time = time.perf_counter() + for i, (text, wav) in enumerate(chunk_stream): + times.append(time.perf_counter() - start_time) + if gen_text != text: + gen_text = text + gen_all_text += gen_text + print(text, wav.shape) + audios.append(wav.squeeze().cpu().numpy()) + # save_audio_path = os.path.join(ASSETS_DIR, f"image_chatting_segment_stream-{i}-{text}.wav") + # torchaudio.save(save_audio_path, wav, sample_rate=24000) + # print(f"Audio saved to {save_audio_path}") + start_time = time.perf_counter() + + print(f"gen all text: {gen_all_text}") + save_audio_path = os.path.join( + ASSETS_DIR, f"image_chatting_segment_stream_{mask_embedding}.wav" + ) + sf.write(save_audio_path, np.concatenate(audios), samplerate=24000) + print(f"All Audio saved to {save_audio_path}") + info = sf.info(save_audio_path, verbose=True) + print( + f"thinker->talker->code2wav chunk streaming first chunk time: {times[0]} s | wav duration: {info.duration} s | cost: {sum(times)} s | RTF: {sum(times)/info.duration}" + ) + + +def audio_image_chatting_stream(): + import torchaudio + import soundfile as sf + + 
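+    # Mixed-modality prompt: a standalone audio clip plus an image; use_audio_in_video stays False here because the audio is not a video soundtrack.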
messages = [ + {"role": "system", "content": [{"type": "text", "text": SPEECH_SYS_PROMPT}]}, + { + "role": "user", + "content": [ + {"type": "audio", "audio": os.path.join(ASSETS_DIR, "image.mp3")}, + {"type": "image", "image": os.path.join(ASSETS_DIR, "03-Confusing-Pictures.jpg")}, + ], + }, + ] + # response, audio = inference( + # messages, return_audio=True, use_audio_in_video=False, thinker_do_sample=True + # ) + # print(response[0]) + # save_audio_path = os.path.join(ASSETS_DIR, f"generated_audio_image_chatting_stream.wav") + # torchaudio.save(save_audio_path, audio, sample_rate=24000) + # print(f"Audio saved to {save_audio_path}") + # return + + for _ in range(1): # warmup and test + streamer = thinker_talker_inference_stream(messages, use_audio_in_video=False) + audios = [] + times = [] + start_time = time.perf_counter() + for i, (texts, audio) in enumerate(streamer): + if i == 0: + print(texts[0]) + times.append(time.perf_counter() - start_time) + audios.append(audio.squeeze().cpu().numpy()) + # save_audio_path = os.path.join(ASSETS_DIR, f"generated_audio_image_chatting_stream_{i}.wav") + # torchaudio.save(save_audio_path, audio, sample_rate=24000) + # print(f"Audio saved to {save_audio_path}") + start_time = time.perf_counter() + + save_audio_path = os.path.join(ASSETS_DIR, f"generated_audio_image_chatting_stream.wav") + sf.write(save_audio_path, np.concatenate(audios), samplerate=24000) + info = sf.info(save_audio_path, verbose=True) + print( + f"thinker->talker->code2wav streaming first chunk time: {times[0]} s | wav duration: {info.duration} s | cost: {sum(times)} s | RTF: {sum(times)/info.duration}" + ) + + +def audio_image_chatting_segment_stream(): + import torchaudio + import soundfile as sf + + messages = [ + {"role": "system", "content": [{"type": "text", "text": SPEECH_SYS_PROMPT}]}, + { + "role": "user", + "content": [ + {"type": "audio", "audio": os.path.join(ASSETS_DIR, "image.mp3")}, + {"type": "image", "image": os.path.join(ASSETS_DIR, "03-Confusing-Pictures.jpg")}, + ], + }, + ] + thinker_eos_token_ids = [151644, 151645] + print(thinker_eos_token_ids) + chunk_stream = generate_stream( + messages, + use_audio_in_video=False, + thinker_max_new_tokens=100, + thinker_max_tokens_per_step=15, + thinker_stop_strings_per_step=[".", "。"], + thinker_eos_token_ids=thinker_eos_token_ids, + ) + + gen_text = "" + gen_all_text = "" + audios = [] + times = [] + start_time = time.perf_counter() + for i, (text, wav) in enumerate(chunk_stream): + times.append(time.perf_counter() - start_time) + if gen_text != text: + gen_text = text + gen_all_text += gen_text + print(text, wav.shape) + audios.append(wav.squeeze().cpu().numpy()) + # save_audio_path = os.path.join(ASSETS_DIR, f"audio_image_chatting_segment_stream-{i}-{text}.wav") + # torchaudio.save(save_audio_path, wav, sample_rate=24000) + # print(f"Audio saved to {save_audio_path}") + start_time = time.perf_counter() + + print(f"gen all text: {gen_all_text}") + save_audio_path = os.path.join(ASSETS_DIR, f"audio_image_chatting_segment_stream.wav") + sf.write(save_audio_path, np.concatenate(audios), samplerate=24000) + print(f"All Audio saved to {save_audio_path}") + info = sf.info(save_audio_path, verbose=True) + print( + f"thinker->talker->code2wav chunk streaming first chunk time: {times[0]} s | wav duration: {info.duration} s | cost: {sum(times)} s | RTF: {sum(times)/info.duration}" + ) + + +def tokenizer(): + print(processor.tokenizer.pad_token_id) + print(",", processor.tokenizer.encode(",")) + print(".", 
processor.tokenizer.encode(".")) + print("?", processor.tokenizer.encode("?")) + print(",", processor.tokenizer.encode(",")) + print("。", processor.tokenizer.encode("。")) + print("?", processor.tokenizer.encode("?")) + print("!", processor.tokenizer.encode("!")) + + +def tokenizer_sentences(): + # return processor.tokenizer.encode(";.?!;。?!") + token_ids = [] + for i in ",;.?,;。?!": + # for i in ",.": + token_id = processor.tokenizer.encode(i) + token_ids.extend(token_id) + return token_ids + + +class TokenStreamer(BaseStreamer): + def __init__(self, skip_prompt: bool = False, timeout=None): + self.skip_prompt = skip_prompt + + # variables used in the streaming process + self.token_queue = Queue() + self.stop_signal = None + self.next_tokens_are_prompt = True + self.timeout = timeout + + def put(self, value): + if len(value.shape) > 1 and value.shape[0] > 1: + raise ValueError("TextStreamer only supports batch size 1") + elif len(value.shape) > 1: + value = value[0] + + if self.skip_prompt and self.next_tokens_are_prompt: + self.next_tokens_are_prompt = False + return + + for token in value.tolist(): + self.token_queue.put(token) + + def end(self): + self.token_queue.put(self.stop_signal) + + def __iter__(self): + return self + + def __next__(self): + value = self.token_queue.get(timeout=self.timeout) + if value == self.stop_signal: + raise StopIteration() + else: + return value + + +def achatbot_generate(): + import torchaudio + import soundfile as sf + from achatbot.core.llm.transformers.manual_vision_voice_qwen import ( + TransformersManualQwen2_5OmniLLM, + ) + from achatbot.common.session import Session, SessionCtx + from achatbot.core.llm import LLMEnvInit + from achatbot.common.logger import Logger + + Logger.init(os.getenv("LOG_LEVEL", "info").upper(), is_file=False, is_console=True) + + session = Session(**SessionCtx("test_client_id", 16000, 2).__dict__) + args = LLMEnvInit.get_qwen2_5omni_transformers_args() + args["speaker"] = "Ethan" + args["lm_attn_impl"] = "flash_attention_2" + args["warmup_steps"] = 1 + args["warnup_prompt"] = "你叫什么名字?" 
+ args["is_use_sliding_window_code2wav"] = True + args["thinker_all_talker_stream"] = False + args["code2wav_args"]["enable_torch_compile"] = False + args["code2wav_args"]["enable_torch_compile_first_chunk"] = False + llm = TransformersManualQwen2_5OmniLLM(**args) + + print("----start generate stream----") + + session.ctx.state["prompt"] = [ + {"type": "text", "text": "请描述一下图片中的内容"}, + {"type": "image", "image": os.path.join(ASSETS_DIR, "03-Confusing-Pictures.jpg")}, + # {"type": "video", "video": ""}, + # {"type": "audio", "audio": ""}, + ] + kwargs = { + "use_audio_in_video": False, + "thinker_top_k": 10, + "thinker_top_p": 0.9, + "thinker_temperature": 0.95, + "thinker_repetition_penalty": 1.1, + "thinker_min_new_tokens": 1, + "thinker_max_tokens_per_step": 15, + "thinker_stop_strings_per_step": [",", ".", ",", "。"], + "thinker_max_new_tokens": 150, + "thinker_eos_token_ids": [ + 151644, + 151645, + ], + "thinker_pad_token_id": 151643, + } + chunk_stream = llm.generate(session, **kwargs) + gen_text = "" + gen_all_text = "" + audios = [] + times = [] + start_time = time.perf_counter() + for i, chunk in enumerate(chunk_stream): + times.append(time.perf_counter() - start_time) + text = chunk["text"] + if gen_text != text: + gen_text = text + gen_all_text += gen_text + if "audio_wav" in chunk: + wav = chunk["audio_wav"] + print(text, wav.shape) + audios.append(wav.squeeze().cpu().numpy()) + # save_audio_path = os.path.join(ASSETS_DIR, f"achatbot_generate_stream-{i}-{text}.wav") + # torchaudio.save(save_audio_path, wav, sample_rate=24000) + # print(f"Audio saved to {save_audio_path}") + else: + print(text) + start_time = time.perf_counter() + + print(f"gen all text: {gen_all_text}") + if len(audios) > 0: + save_audio_path = os.path.join(ASSETS_DIR, f"achatbot_generate_stream.wav") + sf.write(save_audio_path, np.concatenate(audios), samplerate=24000) + print(f"All Audio saved to {save_audio_path}") + info = sf.info(save_audio_path, verbose=True) + print( + f"thinker->talker->code2wav chunk streaming first chunk time: {times[0]} s | wav duration: {info.duration} s | cost: {sum(times)} s | RTF: {sum(times)/info.duration}" + ) + + +""" +# NOTE: if want to generate speech, need use SPEECH_SYS_PROMPT to generate speech + +# asr (audio understanding) +IMAGE_GPU=L4 modal run src/llm/transformers/qwen2_5omni.py --task universal_audio_understanding + +# audio to text and speech +IMAGE_GPU=L40s modal run src/llm/transformers/qwen2_5omni.py --task voice_chatting + +# vision(video no audio) to text +IMAGE_GPU=L40s modal run src/llm/transformers/qwen2_5omni.py --task video_information_extracting +IMAGE_GPU=L40s modal run src/llm/transformers/qwen2_5omni.py --task screen_recording_interaction + +# vision(video with audio) to text and speech +IMAGE_GPU=L40s modal run src/llm/transformers/qwen2_5omni.py --task omni_chatting_for_math +IMAGE_GPU=L40s modal run src/llm/transformers/qwen2_5omni.py --task omni_chatting_for_music + +# vision(video with audio) to text and speech with multi rounds chat, but need more GPU memory +IMAGE_GPU=A100-80GB modal run src/llm/transformers/qwen2_5omni.py --task multi_round_omni_chatting + +# batch requests +IMAGE_GPU=A100-80GB modal run src/llm/transformers/qwen2_5omni.py --task batch_requests + +# stream +# text -> text stream +IMAGE_GPU=L4 modal run src/llm/transformers/qwen2_5omni.py --task thinker_stream +IMAGE_GPU=L4 modal run src/llm/transformers/qwen2_5omni.py --task thinker_chunk_stream +# image -> text stream +IMAGE_GPU=L4 modal run src/llm/transformers/qwen2_5omni.py 
--task image_stream +IMAGE_GPU=L4 modal run src/llm/transformers/qwen2_5omni.py --task image_chunk_stream +# audio -> text stream +IMAGE_GPU=L4 modal run src/llm/transformers/qwen2_5omni.py --task asr_stream +IMAGE_GPU=L4 modal run src/llm/transformers/qwen2_5omni.py --task asr_chunk_stream +# video -> text stream +IMAGE_GPU=L40s modal run src/llm/transformers/qwen2_5omni.py --task screen_recording_interaction_stream +IMAGE_GPU=L40s modal run src/llm/transformers/qwen2_5omni.py --task video_information_extracting_stream +IMAGE_GPU=L40s modal run src/llm/transformers/qwen2_5omni.py --task video_information_extracting_chunk_stream +IMAGE_GPU=L40s modal run src/llm/transformers/qwen2_5omni.py --task screen_recording_interaction_chunk_stream + +# text -> text + chunk speech stream +IMAGE_GPU=L4 modal run src/llm/transformers/qwen2_5omni.py --task omni_chatting_stream + +# text -> chunk text+speech stream +IMAGE_GPU=L40s modal run src/llm/transformers/qwen2_5omni.py --task omni_chatting_segment_stream + +# text+image -> text + chunk speech stream +IMAGE_GPU=L4 modal run src/llm/transformers/qwen2_5omni.py --task image_chatting_stream +# text+image -> chunk text+speech stream +IMAGE_GPU=L40s modal run src/llm/transformers/qwen2_5omni.py --task image_chatting_segment_stream + +# audio+image -> text + chunk speech stream +IMAGE_GPU=L4 modal run src/llm/transformers/qwen2_5omni.py --task audio_image_chatting_stream +# audio+image -> chunk text+speech stream +IMAGE_GPU=L40s modal run src/llm/transformers/qwen2_5omni.py --task audio_image_chatting_segment_stream + +# vision(video with audio) -> text + chunk speech stream +IMAGE_GPU=L40s modal run src/llm/transformers/qwen2_5omni.py --task omni_chatting_for_math_stream +IMAGE_GPU=L40s modal run src/llm/transformers/qwen2_5omni.py --task omni_chatting_for_music_stream + +# vision(video with audio) -> chunk text+speech stream +IMAGE_GPU=L40s modal run src/llm/transformers/qwen2_5omni.py --task omni_chatting_for_math_chunk_stream +IMAGE_GPU=L40s modal run src/llm/transformers/qwen2_5omni.py --task omni_chatting_for_music_chunk_stream + + +# text/vision/audio -> chunk text+speech stream use sliding window code2wav +ACHATBOT_VERSION=0.0.9.post10 IMAGE_GPU=L40s modal run src/llm/transformers/qwen2_5omni.py --task achatbot_generate + +IMAGE_GPU=L4 modal run src/llm/transformers/qwen2_5omni.py --task tokenizer +""" + + +@app.local_entrypoint() +def main(task: str = "universal_audio_understanding"): + tasks = { + "tokenizer": tokenizer, + "achatbot_generate": achatbot_generate, + "universal_audio_understanding": universal_audio_understanding, + "voice_chatting": voice_chatting, + "video_information_extracting": video_information_extracting, + "video_information_extracting_stream": video_information_extracting_stream, + "video_information_extracting_chunk_stream": video_information_extracting_chunk_stream, + "screen_recording_interaction": screen_recording_interaction, + "screen_recording_interaction_stream": screen_recording_interaction_stream, + "screen_recording_interaction_chunk_stream": screen_recording_interaction_chunk_stream, + "omni_chatting_for_math": omni_chatting_for_math, + "omni_chatting_for_math_stream": omni_chatting_for_math_stream, + "omni_chatting_for_math_chunk_stream": omni_chatting_for_math_chunk_stream, + "omni_chatting_for_music": omni_chatting_for_music, + "omni_chatting_for_music_stream": omni_chatting_for_music_stream, + "omni_chatting_for_music_chunk_stream": omni_chatting_for_music_chunk_stream, + "image_chatting_stream": 
image_chatting_stream, + "image_chatting_segment_stream": image_chatting_segment_stream, + "audio_image_chatting_stream": audio_image_chatting_stream, + "audio_image_chatting_segment_stream": audio_image_chatting_segment_stream, + "multi_round_omni_chatting": multi_round_omni_chatting, + "batch_requests": batch_requests, + "thinker_stream": thinker_stream, + "thinker_chunk_stream": thinker_chunk_stream, + "image_stream": image_stream, + "image_chunk_stream": image_chunk_stream, + "asr_stream": asr_stream, + "asr_chunk_stream": asr_chunk_stream, + "omni_chatting_stream": omni_chatting_stream, + "omni_chatting_segment_stream": omni_chatting_segment_stream, + } + if task not in tasks: + raise ValueError(f"task {task} not found") + print(f"running task {task}") + run.remote(tasks[task]) diff --git a/deploy/modal/src/llm/transformers/qwen2_5omni_web_demo.py b/deploy/modal/src/llm/transformers/qwen2_5omni_web_demo.py new file mode 100644 index 00000000..8a1694a5 --- /dev/null +++ b/deploy/modal/src/llm/transformers/qwen2_5omni_web_demo.py @@ -0,0 +1,460 @@ +import modal +import os +import io + +app = modal.App("qwen2_5_omni_web_demo") +omni_img = ( + # https://catalog.ngc.nvidia.com/orgs/nvidia/containers/cuda/tags + modal.Image.from_registry( + "nvidia/cuda:12.6.1-cudnn-devel-ubuntu22.04", + add_python="3.10", + ) + .apt_install("git", "git-lfs", "ffmpeg", "cmake") + .pip_install("wheel", "openai", "qwen-omni-utils[decord]") + .run_commands( + f"pip install git+https://github.com/huggingface/transformers", + ) + .pip_install("accelerate", "torch", "torchvision", "torchaudio") + .pip_install("flash-attn", extra_options="--no-build-isolation") + .env( + { + "PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True", + } + ) + .pip_install( + "gradio==5.23.1", + "gradio_client==1.8.0", + "ffmpeg==1.4", + "ffmpeg-python==0.2.0", + "soundfile==0.13.0", + "librosa==0.11.0", + "modelscope_studio==1.2.2", + "av", + ) +) + +HF_MODEL_DIR = "/root/models" +hf_model_vol = modal.Volume.from_name("models", create_if_missing=True) +ASSETS_DIR = "/root/assets" +assets_dir = modal.Volume.from_name("assets", create_if_missing=True) + +# NOTE: if want to generate speech, need use this system prompt to generate speech +SPEECH_SYS_PROMPT = "You are Qwen, a virtual human developed by the Qwen Team, Alibaba Group, capable of perceiving auditory and visual inputs, as well as generating text and speech." 
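+# Example (sketch, mirroring the conversations used elsewhere in this repo): a speech-generating
+# request should open with SPEECH_SYS_PROMPT as the system turn, e.g.
+#   messages = [
+#       {"role": "system", "content": [{"type": "text", "text": SPEECH_SYS_PROMPT}]},
+#       {"role": "user", "content": [{"type": "text", "text": "who are you?"}]},
+#   ]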
+ +with omni_img.imports(): + import subprocess + import ffmpeg + import numpy as np + import soundfile as sf + import torch, torchaudio + + import modelscope_studio.components.base as ms + import modelscope_studio.components.antd as antd + + import gradio as gr + import gradio.processing_utils as processing_utils + from gradio_client import utils as client_utils + + from transformers import Qwen2_5OmniForConditionalGeneration, Qwen2_5OmniProcessor + from qwen_omni_utils import process_mm_info + + subprocess.run("nvidia-smi --version", shell=True) + subprocess.run("nvcc --version", shell=True) + gpu_prop = torch.cuda.get_device_properties("cuda") + print(gpu_prop) + + model_path = os.path.join(HF_MODEL_DIR, "Qwen/Qwen2.5-Omni-7B") + model = Qwen2_5OmniForConditionalGeneration.from_pretrained( + model_path, + torch_dtype=torch.bfloat16, + device_map="auto", + attn_implementation="flash_attention_2", + ).eval() + model_million_params = sum(p.numel() for p in model.parameters()) / 1e6 + # print(model) + print(f"{model_million_params} M parameters") + + processor = Qwen2_5OmniProcessor.from_pretrained(model_path) + + subprocess.run("nvidia-smi", shell=True) + + +def _launch_demo(model, processor, ui_language, share, inbrowser, server_port, server_name): + # Voice settings + VOICE_LIST = ["Chelsie", "Ethan"] + DEFAULT_VOICE = "Chelsie" + + default_system_prompt = "You are Qwen, a virtual human developed by the Qwen Team, Alibaba Group, capable of perceiving auditory and visual inputs, as well as generating text and speech." + + language = ui_language + + def get_text(text: str, cn_text: str): + if language == "en": + return text + if language == "zh": + return cn_text + return text + + def convert_webm_to_mp4(input_file, output_file): + try: + ( + ffmpeg.input(input_file) + .output(output_file, acodec="aac", ar="16000", audio_bitrate="192k") + .run(quiet=True, overwrite_output=True) + ) + print(f"Conversion successful: {output_file}") + except ffmpeg.Error as e: + print("An error occurred during conversion.") + print(e.stderr.decode("utf-8")) + + def format_history(history: list, system_prompt: str): + messages = [] + messages.append({"role": "system", "content": [{"type": "text", "text": system_prompt}]}) + for item in history: + if isinstance(item["content"], str): + messages.append({"role": item["role"], "content": item["content"]}) + elif item["role"] == "user" and ( + isinstance(item["content"], list) or isinstance(item["content"], tuple) + ): + file_path = item["content"][0] + + mime_type = client_utils.get_mimetype(file_path) + if mime_type.startswith("image"): + messages.append( + {"role": item["role"], "content": [{"type": "image", "image": file_path}]} + ) + elif mime_type.startswith("video"): + messages.append( + {"role": item["role"], "content": [{"type": "video", "video": file_path}]} + ) + elif mime_type.startswith("audio"): + messages.append( + { + "role": item["role"], + "content": [ + { + "type": "audio", + "audio": file_path, + } + ], + } + ) + return messages + + def predict(messages, voice=DEFAULT_VOICE): + print("predict history: ", messages) + + text = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=False) + + audios, images, videos = process_mm_info(messages, use_audio_in_video=True) + + inputs = processor( + text=text, + audio=audios, + images=images, + videos=videos, + return_tensors="pt", + padding=True, + use_audio_in_video=True, + ) + inputs = inputs.to(model.device).to(model.dtype) + + text_ids, audio = model.generate( + **inputs, 
speaker=voice, use_audio_in_video=True, return_audio=True + ) + + response = processor.batch_decode( + text_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False + ) + response = response[0].split("\n")[-1] + yield {"type": "text", "data": response} + + audio = np.array(audio * 32767).astype(np.int16) + wav_io = io.BytesIO() + sf.write(wav_io, audio, samplerate=24000, format="WAV") + wav_io.seek(0) + wav_bytes = wav_io.getvalue() + audio_path = processing_utils.save_bytes_to_cache( + wav_bytes, "audio.wav", cache_dir=demo.GRADIO_CACHE + ) + yield {"type": "audio", "data": audio_path} + + def media_predict(audio, video, history, system_prompt, voice_choice): + # First yield + yield ( + None, # microphone + None, # webcam + history, # media_chatbot + gr.update(visible=False), # submit_btn + gr.update(visible=True), # stop_btn + ) + + if video is not None: + convert_webm_to_mp4(video, video.replace(".webm", ".mp4")) + video = video.replace(".webm", ".mp4") + files = [audio, video] + + for f in files: + if f: + history.append({"role": "user", "content": (f,)}) + + formatted_history = format_history( + history=history, + system_prompt=system_prompt, + ) + + history.append({"role": "assistant", "content": ""}) + + for chunk in predict(formatted_history, voice_choice): + if chunk["type"] == "text": + history[-1]["content"] = chunk["data"] + yield ( + None, # microphone + None, # webcam + history, # media_chatbot + gr.update(visible=False), # submit_btn + gr.update(visible=True), # stop_btn + ) + if chunk["type"] == "audio": + history.append({"role": "assistant", "content": gr.Audio(chunk["data"])}) + + # Final yield + yield ( + None, # microphone + None, # webcam + history, # media_chatbot + gr.update(visible=True), # submit_btn + gr.update(visible=False), # stop_btn + ) + + def chat_predict(text, audio, image, video, history, system_prompt, voice_choice): + # Process text input + if text: + history.append({"role": "user", "content": text}) + + # Process audio input + if audio: + history.append({"role": "user", "content": (audio,)}) + + # Process image input + if image: + history.append({"role": "user", "content": (image,)}) + + # Process video input + if video: + history.append({"role": "user", "content": (video,)}) + + formatted_history = format_history(history=history, system_prompt=system_prompt) + + yield None, None, None, None, history + + history.append({"role": "assistant", "content": ""}) + for chunk in predict(formatted_history, voice_choice): + if chunk["type"] == "text": + history[-1]["content"] = chunk["data"] + yield gr.skip(), gr.skip(), gr.skip(), gr.skip(), history + if chunk["type"] == "audio": + history.append({"role": "assistant", "content": gr.Audio(chunk["data"])}) + yield gr.skip(), gr.skip(), gr.skip(), gr.skip(), history + + with gr.Blocks() as demo, ms.Application(), antd.ConfigProvider(): + with gr.Sidebar(open=False): + system_prompt_textbox = gr.Textbox(label="System Prompt", value=default_system_prompt) + with antd.Flex(gap="small", justify="center", align="center"): + with antd.Flex(vertical=True, gap="small", align="center"): + antd.Typography.Title( + "Qwen2.5-Omni Demo", level=1, elem_style=dict(margin=0, fontSize=28) + ) + with antd.Flex(vertical=True, gap="small"): + antd.Typography.Text( + get_text("🎯 Instructions for use:", "🎯 使用说明:"), strong=True + ) + antd.Typography.Text( + get_text( + "1️⃣ Click the Audio Record button or the Camera Record button.", + "1️⃣ 点击音频录制按钮,或摄像头-录制按钮", + ) + ) + antd.Typography.Text(get_text("2️⃣ Input audio or 
video.", "2️⃣ 输入音频或者视频")) + antd.Typography.Text( + get_text( + "3️⃣ Click the submit button and wait for the model's response.", + "3️⃣ 点击提交并等待模型的回答", + ) + ) + voice_choice = gr.Dropdown(label="Voice Choice", choices=VOICE_LIST, value=DEFAULT_VOICE) + with gr.Tabs(): + with gr.Tab("Online"): + with gr.Row(): + with gr.Column(scale=1): + microphone = gr.Audio(sources=["microphone"], type="filepath") + webcam = gr.Video(sources=["webcam"], height=400, include_audio=True) + submit_btn = gr.Button(get_text("Submit", "提交"), variant="primary") + stop_btn = gr.Button(get_text("Stop", "停止"), visible=False) + clear_btn = gr.Button(get_text("Clear History", "清除历史")) + with gr.Column(scale=2): + media_chatbot = gr.Chatbot(height=650, type="messages") + + def clear_history(): + return [], gr.update(value=None), gr.update(value=None) + + submit_event = submit_btn.click( + fn=media_predict, + inputs=[ + microphone, + webcam, + media_chatbot, + system_prompt_textbox, + voice_choice, + ], + outputs=[microphone, webcam, media_chatbot, submit_btn, stop_btn], + ) + stop_btn.click( + fn=lambda: (gr.update(visible=True), gr.update(visible=False)), + inputs=None, + outputs=[submit_btn, stop_btn], + cancels=[submit_event], + queue=False, + ) + clear_btn.click( + fn=clear_history, inputs=None, outputs=[media_chatbot, microphone, webcam] + ) + + with gr.Tab("Offline"): + chatbot = gr.Chatbot(type="messages", height=650) + + # Media upload section in one row + with gr.Row(equal_height=True): + audio_input = gr.Audio( + sources=["upload"], + type="filepath", + label="Upload Audio", + elem_classes="media-upload", + scale=1, + ) + image_input = gr.Image( + sources=["upload"], + type="filepath", + label="Upload Image", + elem_classes="media-upload", + scale=1, + ) + video_input = gr.Video( + sources=["upload"], + label="Upload Video", + elem_classes="media-upload", + scale=1, + ) + + # Text input section + text_input = gr.Textbox(show_label=False, placeholder="Enter text here...") + + # Control buttons + with gr.Row(): + submit_btn = gr.Button(get_text("Submit", "提交"), variant="primary", size="lg") + stop_btn = gr.Button(get_text("Stop", "停止"), visible=False, size="lg") + clear_btn = gr.Button(get_text("Clear History", "清除历史"), size="lg") + + def clear_chat_history(): + return ( + [], + gr.update(value=None), + gr.update(value=None), + gr.update(value=None), + gr.update(value=None), + ) + + submit_event = gr.on( + triggers=[submit_btn.click, text_input.submit], + fn=chat_predict, + inputs=[ + text_input, + audio_input, + image_input, + video_input, + chatbot, + system_prompt_textbox, + voice_choice, + ], + outputs=[text_input, audio_input, image_input, video_input, chatbot], + ) + + stop_btn.click( + fn=lambda: (gr.update(visible=True), gr.update(visible=False)), + inputs=None, + outputs=[submit_btn, stop_btn], + cancels=[submit_event], + queue=False, + ) + + clear_btn.click( + fn=clear_chat_history, + inputs=None, + outputs=[chatbot, text_input, audio_input, image_input, video_input], + ) + + # Add some custom CSS to improve the layout + gr.HTML(""" + + """) + + demo.queue(default_concurrency_limit=100, max_size=100).launch( + max_threads=100, + ssr_mode=False, + share=share, + inbrowser=inbrowser, + server_port=server_port, + server_name=server_name, + ) + + +@app.function( + gpu=os.getenv("IMAGE_GPU", "L40s"), + cpu=2.0, + image=omni_img, + volumes={ + HF_MODEL_DIR: hf_model_vol, + ASSETS_DIR: assets_dir, + }, + timeout=1200, # default 300s + scaledown_window=1200, + max_containers=100, +) +def server( + 
ui_language="en", + server_port="7860", + server_name="127.0.0.1", +): + _launch_demo( + model, + processor, + ui_language, + True, + False, + int(server_port), + server_name, + ) + + +""" +modal run src/llm/transformers/qwen2_5omni_web_demo.py +""" diff --git a/deploy/modal/src/llm/transformers/run_omni_cases.sh b/deploy/modal/src/llm/transformers/run_omni_cases.sh new file mode 100644 index 00000000..8ee1e23d --- /dev/null +++ b/deploy/modal/src/llm/transformers/run_omni_cases.sh @@ -0,0 +1,236 @@ +#!/bin/bash + +# copyright 2025 by weedge (weege007@gmail.com) +# bash src/llm/transformers/run_omni_cases.sh + +# https://modal.com/docs/guide +if command -v modal &> /dev/null +then + echo "modal command found. start to run ..." +else + echo "pip install modal ..." + pip install -q modal + modal setup +fi +modal --version + +set -e + +#---- default values ---- + +IMAGE_GPU="L40S" +STAGE="all" +CASE="all" +MODEL_TYPE="qwen2_5omni" + + +#----- function ------- + +usage() { + echo "Usage: $0 [-h] [-s STAGE] [-d IMAGE_GPU] [-m MODEL_TYPE] [-c CASE]" + echo " -h Show this help message and exit." + echo " -s STAGE Set the stage (default: all)." + echo " Valid options: download, run, run_all, all" + echo " -m MODEL_TYPE model type (default: qwen2_5omni)." + echo " -c CASE run case (default: all)." + echo " Valid options e.g.: all" + echo " universal_audio_understanding" + echo " voice_chatting" + echo " video_information_extracting, screen_recording_interaction" + echo " omni_chatting_for_math, omni_chatting_for_music, multi_round_omni_chatting,asr_stream" + echo " -d IMAGE_GPU Set the GPU image (default: L40S)." + echo " Valid options: A10G A100 A100-80GB L4 L40S H100 https://fullstackdeeplearning.com/cloud-gpus/" + echo "e.g.: " + echo "bash run_omni_cases.sh -s all" + echo "bash run_omni_cases.sh -s download " + echo "bash run_omni_cases.sh -s run_all" + echo "bash run_omni_cases.sh -s run -c all" + echo "bash run_omni_cases.sh -s run -m qwen2_5omni -c universal_audio_understanding" + echo "bash run_omni_cases.sh -s run -m qwen2_5omni -c voice_chatting" + echo "bash run_omni_cases.sh -s run -m qwen2_5omni -c video_information_extracting" + echo "bash run_omni_cases.sh -s run -m qwen2_5omni -c screen_recording_interaction" + echo "bash run_omni_cases.sh -s run -m qwen2_5omni -c omni_chatting_for_math" + echo "bash run_omni_cases.sh -s run -m qwen2_5omni -c omni_chatting_for_music" + echo "bash run_omni_cases.sh -s run -m qwen2_5omni -c multi_round_omni_chatting -d A100-80G" + echo "bash run_omni_cases.sh -s run -m qwen2_5omni -c image_stream -d L4" + echo "bash run_omni_cases.sh -s run -m qwen2_5omni -c image_chunk_stream -d L4" + echo "bash run_omni_cases.sh -s run -m qwen2_5omni -c asr_stream -d L4" + echo "bash run_omni_cases.sh -s run -m qwen2_5omni -c asr_chunk_stream -d L4" + echo "bash run_omni_cases.sh -s run -m qwen2_5omni -c thinker_chunk_stream -d L4" + echo "bash run_omni_cases.sh -s run -m qwen2_5omni -c omni_chatting_stream -d L4" + echo "bash run_omni_cases.sh -s run -m qwen2_5omni -c omni_chatting_stream -d L4" + echo "bash run_omni_cases.sh -s run -m qwen2_5omni -c screen_recording_interaction_stream -d L40s" + echo "bash run_omni_cases.sh -s run -m qwen2_5omni -c screen_recording_interaction_chunk_stream -d L40s" + echo "bash run_omni_cases.sh -s run -m qwen2_5omni -c video_information_extracting_stream -d L40s" + echo "bash run_omni_cases.sh -s run -m qwen2_5omni -c video_information_extracting_chunk_stream -d L40s" + echo "bash run_omni_cases.sh -s run -m qwen2_5omni -c 
omni_chatting_for_math_stream -d L40s" + echo "bash run_omni_cases.sh -s run -m qwen2_5omni -c omni_chatting_for_music_stream -d L40s" + echo "bash run_omni_cases.sh -s run -m qwen2_5omni -c omni_chatting_for_math_chunk_stream -d L40s" + echo "bash run_omni_cases.sh -s run -m qwen2_5omni -c omni_chatting_for_music_chunk_stream -d L40s" +} + +run() { + #local CASE=$1 + echo "run $MODEL_TYPE $CASE $IMAGE_GPU $TAG_OR_COMMIT" + if [ -e "src/llm/transformers/$MODEL_TYPE.py" ]; then + echo "src/llm/transformers/$MODEL_TYPE.py exists" + cd src/llm/transformers/ + else + cd $SCRIPT_DIR + wget -q https://raw.githubusercontent.com/ai-bot-pro/achatbot/refs/heads/feat/vision_voice/deploy/modal/src/llm/transformers/$MODEL_TYPE.py -O $MODEL_TYPE.py + fi + all_cases=( + "universal_audio_understanding" + "voice_chatting" + "video_information_extracting" + "video_information_extracting_stream" + "video_information_extracting_chunk_stream" + "screen_recording_interaction" + "screen_recording_interaction_stream" + "screen_recording_interaction_chunk_stream" + "omni_chatting_for_math" + "omni_chatting_for_math_stream" + "omni_chatting_for_math_chunk_stream" + "omni_chatting_for_music" + "omni_chatting_for_music_stream" + "omni_chatting_for_music_chunk_stream" + "multi_round_omni_chatting" + "thinker_chunk_stream" + "image_stream" + "image_chunk_stream" + "asr_stream" + "asr_chunk_stream" + "omni_chatting_stream" + "omni_chatting_segment_stream" + ) + #return + case $CASE in + all) + for CASE in "${all_cases[@]}"; do + [[ $CASE == "multi_round_omni_chatting" ]] && IMAGE_GPU="A100-80GB" + echo "IMAGE_GPU=$IMAGE_GPU modal run $MODEL_TYPE.py --task $CASE" + IMAGE_GPU=$IMAGE_GPU modal run $MODEL_TYPE.py --task $CASE + done + ;; + *) + if [[ " ${all_cases[@]} " =~ " ${CASE} " ]]; then + [[ $CASE == "multi_round_omni_chatting" ]] && IMAGE_GPU="A100-80GB" + echo "IMAGE_GPU=$IMAGE_GPU modal run $MODEL_TYPE.py --task $CASE" + IMAGE_GPU=$IMAGE_GPU modal run $MODEL_TYPE.py --task $CASE + else + echo "$CASE not in ${all_cases[*]}" + usage + exit 1 + fi + ;; + esac +} + +download_models() { + if [ -e "src/download_models.py" ]; then + echo "src/download_models.py exists" + cd src + else + cd $SCRIPT_DIR + wget -q https://raw.githubusercontent.com/ai-bot-pro/achatbot/refs/heads/main/deploy/modal/src/download_models.py -O download_models.py + fi + + modal run download_models.py --repo-ids "Qwen/Qwen2.5-Omni-7B" + cd - +} + +download_assets() { + if [ -e "src/download_assets.py" ]; then + echo "src/download_assets.py exists" + cd src + else + cd $SCRIPT_DIR + wget -q https://raw.githubusercontent.com/ai-bot-pro/achatbot/refs/heads/main/deploy/modal/src/download_assets.py -O download_assets.py + fi + + modal run download_assets.py --asset-urls 
"https://raw.githubusercontent.com/ai-bot-pro/achatbot/refs/heads/main/test/img_files/03-Confusing-Pictures.jpg,https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav,https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2-Audio/audio/guess_age_gender.wav,https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2-Audio/audio/translate_to_chinese.wav,https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2.5-Omni/draw1.mp4,https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2.5-Omni/draw2.mp4,https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2.5-Omni/draw3.mp4,https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2.5-Omni/screen.mp4,https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2.5-Omni/music.mp4,https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2.5-Omni/math.mp4,https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2-Audio/audio/1272-128104-0000.flac,https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2.5-Omni/BAC009S0764W0121.wav,https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2.5-Omni/10000611681338527501.wav,https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2.5-Omni/7105431834829365765.wav,https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2.5-Omni/cough.wav,https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2.5-Omni/shopping.mp4" + cd - +} + + +#----- let's go ------ + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +cd "$SCRIPT_DIR/../../.." || exit + +export TAG_OR_COMMIT=$TAG_OR_COMMIT + +# 处理命令行参数 +while getopts ":d:s:m:c:h" opt; do + case ${opt} in + d ) + IMAGE_GPU=$OPTARG + ;; + s ) + STAGE=$OPTARG + ;; + c ) + CASE=$OPTARG + ;; + m ) + MODEL_TYPE=$OPTARG + ;; + h ) + usage + exit 0 + ;; + \? ) + echo "Invalid option: $OPTARG" 1>&2 + usage + exit 1 + ;; + : ) + echo "Invalid option: $OPTARG requires an argument" 1>&2 + usage + exit 1 + ;; + esac +done +shift $((OPTIND -1)) + + +ALLOWED_MODEL_TYPE=("qwen2_5omni" "minicpmo") +if [[ ! " ${ALLOWED_MODEL_TYPE[@]} " =~ " ${MODEL_TYPE} " ]]; then + echo "Invalid model type: $MODEL_TYPE" 1>&2 + usage + exit 1 +fi + +ALLOWED_GPUS=("A100" "A100-80GB" "A10G" "L4" "L40S" "H100") +if [[ ! 
" ${ALLOWED_GPUS[@]} " =~ " ${IMAGE_GPU} " ]]; then + echo "if use flash attention, need gpu arch >= 8.0 e.g.: A100 A100-80G A10G L4 L40S H100" +fi + +#export EXTRA_INDEX_URL="https://pypi.org/simple/" +#export ACHATBOT_VERSION="0.0.9.post8" + + +case $STAGE in + run) + run + ;; + download) + download_models + download_assets + ;; + run_all) + CASE=all + run + ;; + all) + download_models + download_assets + CASE=all + run + ;; + *) + echo "Invalid stage: $STAGE" 1>&2 + usage + exit 1 + ;; +esac \ No newline at end of file diff --git a/deploy/modal/src/llm/vllm/qwen2_5omni.py b/deploy/modal/src/llm/vllm/qwen2_5omni.py new file mode 100644 index 00000000..f286073f --- /dev/null +++ b/deploy/modal/src/llm/vllm/qwen2_5omni.py @@ -0,0 +1,229 @@ +import math +import os +import modal + +app = modal.App("vllm-generate") + +vllm_image = ( + # https://catalog.ngc.nvidia.com/orgs/nvidia/containers/cuda/tags + modal.Image.from_registry( + "nvidia/cuda:12.1.0-cudnn8-devel-ubuntu22.04", + add_python="3.10", + ) + .apt_install( + "git", + "git-lfs", + "ffmpeg", + "software-properties-common", + "libsndfile1", + "wget", + ) + .pip_install( + "torch", + "torchvision", + "torchaudio", + "accelerate", + "torchdiffeq", + "x_transformers", + "setuptools_scm", + "resampy", + "qwen-omni-utils", + ) + .run_commands( + "wget https://github.com/Kitware/CMake/releases/download/v3.26.1/cmake-3.26.1-Linux-x86_64.sh \ + -q -O /tmp/cmake-install.sh \ + && chmod u+x /tmp/cmake-install.sh \ + && mkdir /opt/cmake-3.26.1 \ + && /tmp/cmake-install.sh --skip-license --prefix=/opt/cmake-3.26.1 \ + && rm /tmp/cmake-install.sh \ + && ln -s /opt/cmake-3.26.1/bin/* /usr/local/bin" + ) + .run_commands( + "git lfs install", + "git clone -b new_qwen2_omni_public https://github.com/ai-bot-pro/vllm.git", + "cd /vllm && git checkout 50952d6e2b954063a7cfee9cb436aa57db065738", + "cd /vllm && pip install -r requirements/cuda.txt", + "cd /vllm && pip install . 
--no-build-isolation", # u can see a little film + ) + .run_commands( + "pip install git+https://github.com/BakerBunker/transformers@21dbefaa54e5bf180464696aa70af0bfc7a61d53", + ) + .pip_install( + "flashinfer-python==0.2.0.post2", + extra_index_url="https://flashinfer.ai/whl/cu121/torch2.6/", + ) + .pip_install("flash-attn", extra_options="--no-build-isolation") + .env( + { + "HF_HUB_ENABLE_HF_TRANSFER": "1", + "TORCH_CUDA_ARCH_LIST": "7.5 8.0 8.6 8.7 8.9 9.0", + } + ) # faster model transfers +) + + +# PP need close v1 +vllm_image = vllm_image.env( + { + "VLLM_USE_V1": os.getenv("VLLM_USE_V1", "0"), + } +).run_commands( + "rm -rf /vllm && git clone -b new_qwen2_omni_public https://github.com/ai-bot-pro/vllm.git", + "cd /vllm && git checkout 84b00e332c5005f59215865120822480b6c0fa2d", +) + +HF_MODEL_DIR = "/root/models" +hf_model_vol = modal.Volume.from_name("models", create_if_missing=True) +VLLM_CACHE_DIR = "/root/.cache/vllm" +vllm_cache_vol = modal.Volume.from_name("vllm-cache", create_if_missing=True) +ASSETS_DIR = "/root/assets" +assets_dir = modal.Volume.from_name("assets", create_if_missing=True) + + +with vllm_image.imports(): + import subprocess + import torch + + device_count = torch.cuda.device_count() + devices = ",".join([f"{i}" for i in range(device_count)]) + os.environ["CUDA_VISIBLE_DEVICES"] = devices + if device_count > 1: + subprocess.run("nvidia-smi topo -m", shell=True, env=os.environ) + + +@app.function( + gpu=os.getenv("IMAGE_GPU", "L4"), + cpu=2.0, + retries=0, + image=vllm_image, + volumes={ + HF_MODEL_DIR: hf_model_vol, + VLLM_CACHE_DIR: vllm_cache_vol, + ASSETS_DIR: assets_dir, + }, + timeout=1200, # default 300s + scaledown_window=1200, + max_containers=100, +) +def run(func, thinker_gpu_memory_utilization, talker_gpu_memory_utilization, other_cmd_args): + subprocess.run("nvidia-smi --version", shell=True) + subprocess.run("nvcc --version", shell=True) + gpu_prop = None + if torch.cuda.is_available(): + device_count = torch.cuda.device_count() + for i in range(device_count): + gpu_prop = torch.cuda.get_device_properties(f"cuda:{i}") + print(gpu_prop) + else: + print("CUDA is not available.") + + func( + thinker_gpu_memory_utilization=thinker_gpu_memory_utilization, + talker_gpu_memory_utilization=talker_gpu_memory_utilization, + other_cmd_args=other_cmd_args, + ) + + +def thinker_only(**kwargs): + """ + Only use the Thinker model to generate text. 
+ + multi-modal data text/image/audio/video -> thinker -> text + """ + thinker_gpu_memory_utilization = kwargs.get("thinker_gpu_memory_utilization", 0.8) + device_count = torch.cuda.device_count() + print(f"CUDA device count: {device_count}") + thinker_devices = ",".join([f"{i}" for i in range(device_count)]) + model_dir = os.path.join(HF_MODEL_DIR, "Qwen/Qwen2.5-Omni-7B") + cmd = f"python end2end.py --model {model_dir} --prompt audio-in-video-v2 --enforce-eager --thinker-only --thinker-gpu-memory-utilization {thinker_gpu_memory_utilization} --thinker-devices [{thinker_devices}]" + print(cmd) + subprocess.run( + cmd, shell=True, cwd="/vllm/examples/offline_inference/qwen2_5_omni/", env=os.environ + ) + + +def thinker2talker2wav(**kwargs): + """ + use thinker to generate text and then use talker to generate audio vq indices code, finally convert audio vq indices code to audio waveform + + multi-modal data text/image/audio/video --> thinker -> text | talker -> audio vq indices code -> code2wav -> audio waveform + """ + # default l40s + thinker_gpu_memory_utilization = kwargs.get("thinker_gpu_memory_utilization", 0.6) + talker_gpu_memory_utilization = kwargs.get("talker_gpu_memory_utilization", 0.3) + device_count = torch.cuda.device_count() + print(f"CUDA device count: {device_count}") + thinker_devices = talker_devices = code2wav_devices = "0" + if device_count > 1: + thinker_device_count = math.ceil(device_count / 2) + thinker_devices = ",".join([f"{i}" for i in range(thinker_device_count)]) + talker_devices = ",".join([f"{i}" for i in range(thinker_device_count, device_count)]) + code2wav_devices = f"{device_count-1}" + if device_count == 2: + thinker_devices = "0,1" + model_dir = os.path.join(HF_MODEL_DIR, "Qwen/Qwen2.5-Omni-7B") + cmd = f"python end2end.py --model {model_dir} --prompt audio-in-video-v2 --enforce-eager --do-wave --voice-type Chelsie --warmup-voice-type Chelsie --thinker-devices [{thinker_devices}] --thinker-gpu-memory-utilization {thinker_gpu_memory_utilization} --talker-devices [{talker_devices}] --talker-gpu-memory-utilization {talker_gpu_memory_utilization} --code2wav-devices [{code2wav_devices}] --output-dir {ASSETS_DIR}" + print(cmd) + subprocess.run( + cmd, shell=True, cwd="/vllm/examples/offline_inference/qwen2_5_omni/", env=os.environ + ) + + +def code2wav(**kwargs): + """ + vq code --> cfm dit -> mel --> bigvgan -> waveforms streaming + """ + + other_cmd_args = kwargs.get("other_cmd_args", "") + model_dir = os.path.join(HF_MODEL_DIR, "Qwen/Qwen2.5-Omni-7B") + code_file = os.path.join(ASSETS_DIR, "code2wav.json") + cmd = f"python code2wav.py --code2wav-model {model_dir} --input-json {code_file} --output-dir {ASSETS_DIR} {other_cmd_args}" + print(cmd) + subprocess.run( + cmd, shell=True, cwd="/vllm/examples/offline_inference/qwen2_5_omni/", env=os.environ + ) + + +""" +# NOTE: +# - thinker LM: model weights take 16.73GiB; non_torch_memory takes 0.09GiB; PyTorch activation peak memory takes 5.48GiB; the rest of the memory reserved for KV Cache, so the total memory reserved for the model is 22.3 GiB. must thinker-gpu-memory-utilization * total_gpu_memory > 22.3 GiB +# - talker LM: model weights take 2.55GiB; non_torch_memory takes 0.08GiB; PyTorch activation peak memory takes 4.36GiB; the rest of the memory reserved for KV Cache, so the total memory reserved for the model is 6.9 GiB. 
must talker-gpu-memory-utilization * total_gpu_memory > 6.9 GiB + +IMAGE_GPU=L40s modal run src/llm/vllm/qwen2_5omni.py --task thinker_only +# use tp +IMAGE_GPU=L4:2 modal run src/llm/vllm/qwen2_5omni.py --task thinker_only + + +IMAGE_GPU=L40s modal run src/llm/vllm/qwen2_5omni.py --task thinker2talker2wav +IMAGE_GPU=L40s:2 modal run src/llm/vllm/qwen2_5omni.py --task thinker2talker2wav --thinker-gpu-memory-utilization 0.9 --talker-gpu-memory-utilization 0.7 + +# slow with no torch compile +IMAGE_GPU=T4 modal run src/llm/vllm/qwen2_5omni.py --task code2wav +IMAGE_GPU=L4 modal run src/llm/vllm/qwen2_5omni.py --task code2wav + +# fast with torch compile +IMAGE_GPU=L4 modal run src/llm/vllm/qwen2_5omni.py --task code2wav --other-cmd-args "--enable-torch-compile" +IMAGE_GPU=L40s modal run src/llm/vllm/qwen2_5omni.py --task code2wav --other-cmd-args "--enable-torch-compile" +IMAGE_GPU=L4 modal run src/llm/vllm/qwen2_5omni.py --task code2wav --other-cmd-args "--enable-torch-compile --odeint-method euler" +IMAGE_GPU=L4 modal run src/llm/vllm/qwen2_5omni.py --task code2wav --other-cmd-args "--enable-torch-compile --multi-waveforms" +""" + + +@app.local_entrypoint() +def main( + task: str = "thinker_only", + thinker_gpu_memory_utilization: str = "0.6", # thinker-gpu-memory-utilization * total_gpu_memory > 22.3GB + talker_gpu_memory_utilization: str = "0.3", # talker-gpu-memory-utilization * total_gpu_memory > 6.9 GB + other_cmd_args: str = "", +): + tasks = { + "thinker_only": thinker_only, + "thinker2talker2wav": thinker2talker2wav, + "code2wav": code2wav, + } + if task not in tasks: + raise ValueError(f"task {task} not found") + print(f"running task {task}") + run.remote( + tasks[task], thinker_gpu_memory_utilization, talker_gpu_memory_utilization, other_cmd_args + ) diff --git a/pyproject.toml b/pyproject.toml index 0dfcaf4e..5c92abaa 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,7 +17,7 @@ build-backend = "setuptools.build_meta" name = "achatbot" #dynamic = ["version"] # todo -version = "0.0.9.post8" +version = "0.0.9.post10" authors = [{ name = "weedge", email = "weege007@gmail.com" }] maintainers = [{ name = "weedge", email = "weege007@gmail.com" }] description = "An open source chat bot for voice (and multimodal) assistants" @@ -86,7 +86,6 @@ accelerate = ["accelerate~=0.28.0"] opencv = ["opencv-python~=4.10.0.84"] librosa = ["librosa~=0.10.2.post1"] soundfile = ["soundfile~=0.12.1"] -silero-vad = ["silero-vad~=5.1.2"] # diffusers DiT with quantizing model @@ -166,7 +165,7 @@ speech_waker = ["achatbot[porcupine_wakeword]"] # vad module tag -> pkgs pyannote_vad = ["pyannote.audio~=3.2.0"] webrtcvad = ["webrtcvad~=2.0.10"] -silero_vad = ["torch~=2.2.2", "torchaudio~=2.2.2"] +silero_vad = ["torch", "torchaudio"] webrtc_silero_vad = ["achatbot[webrtcvad,silero_vad]"] speech_vad = ["achatbot[pyannote_vad,webrtcvad,silero_vad]"] @@ -280,23 +279,38 @@ llm_transformers_manual_vision_voice_minicpmo = [ "decord", "moviepy", ] +llm_transformers_manual_vision_voice_qwen =[ + # https://github.com/huggingface/transformers/releases/tag/v4.51.3-Qwen2.5-Omni-preview + #"git+https://github.com/huggingface/transformers@v4.51.3-Qwen2.5-Omni-preview", + #"transformers==4.52.0", + "achatbot[accelerate,librosa,soundfile]", + "torch~=2.6.0", + "torchaudio~=2.6.0", + "torchvision~=0.21.0", + "numpy==1.26.2", + "qwen-omni-utils[decord]", + # code2wav + "torchdiffeq", + "x_transformers", +] # core llms core_llm = ["achatbot[llama_cpp,llm_personalai_proxy]"] # ----------------- asr ------------------ # asr 
module tag -> pkgs -whisper_asr = ["openai-whisper==20231117"] -whisper_timestamped_asr = ["whisper-timestamped~=1.14.2"] -whisper_faster_asr = ["faster-whisper~=1.0.2"] +whisper_asr = ["openai-whisper"] +whisper_timestamped_asr = ["whisper-timestamped"] +whisper_faster_asr = ["faster-whisper"] whisper_transformers_asr = ["transformers[torch]>=4.40.2"] whisper_mlx_asr = [ "mlx_whisper~=0.2.0; sys_platform == 'darwin' and platform_machine == 'arm64'", ] whisper_groq_asr = ["groq~=0.9.0"] sense_voice_asr = [ - "torch~=2.2.2", - "funasr~=1.1.8", + "torch", + "torchaudio", + "funasr", "onnx", "onnxconverter-common", ] @@ -624,7 +638,7 @@ line-length = 100 [tool.ruff.format] exclude = ["*.pyi", "*.ipynb"] [tool.ruff.lint] -ignore = ["F401", "F403", "F405", "F541"] +ignore = ["F401", "F403", "F405", "F541", "E741"] # uv [tool.uv] diff --git a/src/cmd/bots/__init__.py b/src/cmd/bots/__init__.py index 02949111..c9a0ec1a 100644 --- a/src/cmd/bots/__init__.py +++ b/src/cmd/bots/__init__.py @@ -166,6 +166,18 @@ def import_bots(bot_name: str = "DummyBot"): if "LivekitDescribeVisionToolsBot" in bot_name: from .vision import livekit_describe_vision_tools_bot + return True + if "LivekitQwen2_5OmniVoiceBot" in bot_name: + from .voice import livekit_qwen2_5omni_voice_bot + + return True + if "LivekitAsrQwen2_5OmniVoiceBot" in bot_name: + from .voice import livekit_asr_qwen2_5omni_voice_bot + + return True + if "LivekitQwen2_5OmniVisionVoiceBot" in bot_name: + from .omni import livekit_qwen2_5omni_vision_voice_bot + return True # if "LivekitMoshiVoiceBot" in bot_name: # from .voice import livekit_moshi_bot diff --git a/src/cmd/bots/base.py b/src/cmd/bots/base.py index 64acec3c..a21f32ef 100644 --- a/src/cmd/bots/base.py +++ b/src/cmd/bots/base.py @@ -335,6 +335,19 @@ def get_text_minicpmo_voice_processor(self, llm: LLMConfig | None = None) -> Voi llm_processor = MiniCPMoTextVoiceProcessor() return llm_processor + def get_text_qwen2_5omni_voice_processor( + self, llm: LLMConfig | None = None + ) -> VoiceProcessorBase: + from src.processors.voice.qwen2_5omni_voice_processor import Qwen2_5OmniTextVoiceProcessor + + if not llm: + llm = self._bot_config.voice_llm + if llm.args: + llm_processor = Qwen2_5OmniTextVoiceProcessor(**llm.args) + else: + llm_processor = Qwen2_5OmniTextVoiceProcessor() + return llm_processor + def get_audio_minicpmo_voice_processor( self, llm: LLMConfig | None = None ) -> VoiceProcessorBase: @@ -348,19 +361,51 @@ def get_audio_minicpmo_voice_processor( llm_processor = MiniCPMoAudioVoiceProcessor() return llm_processor + def get_audio_qwen2_5omni_voice_processor( + self, llm: LLMConfig | None = None + ) -> VoiceProcessorBase: + from src.processors.voice.qwen2_5omni_voice_processor import Qwen2_5OmniAudioVoiceProcessor + + if not llm: + llm = self._bot_config.voice_llm + if llm.args: + llm_processor = Qwen2_5OmniAudioVoiceProcessor(**llm.args) + else: + llm_processor = Qwen2_5OmniAudioVoiceProcessor() + return llm_processor + def get_minicpmo_vision_voice_processor( self, llm: LLMConfig | None = None ) -> VisionVoiceProcessorBase: from src.processors.omni.minicpmo_vision_voice import MiniCPMoVisionVoiceProcessor + from src.processors.omni.base import MockVisionVoiceProcessor if not llm: llm = self._bot_config.omni_llm + if "mock" in llm.tag: + return MockVisionVoiceProcessor() if llm.args: llm_processor = MiniCPMoVisionVoiceProcessor(**llm.args) else: llm_processor = MiniCPMoVisionVoiceProcessor() return llm_processor + def get_qwen2_5omni_vision_voice_processor( + self, llm: LLMConfig | 
None = None + ) -> VisionVoiceProcessorBase: + from src.processors.omni.qwen2_5omni_vision_voice import Qwen2_5OmnVisionVoiceProcessor + from src.processors.omni.base import MockVisionVoiceProcessor + + if not llm: + llm = self._bot_config.omni_llm + if "mock" in llm.tag: + return MockVisionVoiceProcessor() + if llm.args: + llm_processor = Qwen2_5OmnVisionVoiceProcessor(**llm.args) + else: + llm_processor = Qwen2_5OmnVisionVoiceProcessor() + return llm_processor + def get_text_glm_voice_processor(self, llm: LLMConfig | None = None) -> VoiceProcessorBase: from src.processors.voice.glm_voice_processor import GLMTextVoiceProcessor diff --git a/src/cmd/bots/omni/daily_minicpmo_vision_voice_bot.py b/src/cmd/bots/omni/daily_minicpmo_vision_voice_bot.py index eb2f0c63..79724ced 100644 --- a/src/cmd/bots/omni/daily_minicpmo_vision_voice_bot.py +++ b/src/cmd/bots/omni/daily_minicpmo_vision_voice_bot.py @@ -98,6 +98,6 @@ async def arun(self): async def on_first_participant_say_hi(self, transport: DailyTransport, participant): transport.capture_participant_video(participant["id"], framerate=0) self.image_requester.set_participant_id(participant["id"]) - self._vision_voice_processor.say( + await self._vision_voice_processor.say( "你好,欢迎使用 Vision Voice Omni Bot. 我是一名虚拟助手,可以结合视频进行提问。" ) diff --git a/src/cmd/bots/omni/livekit_qwen2_5omni_vision_voice_bot.py b/src/cmd/bots/omni/livekit_qwen2_5omni_vision_voice_bot.py new file mode 100644 index 00000000..ad9d9b6d --- /dev/null +++ b/src/cmd/bots/omni/livekit_qwen2_5omni_vision_voice_bot.py @@ -0,0 +1,106 @@ +import logging + +from apipeline.pipeline.pipeline import Pipeline +from apipeline.pipeline.task import PipelineParams, PipelineTask +from apipeline.pipeline.runner import PipelineRunner +from apipeline.processors.logger import FrameLogger +from apipeline.frames import AudioRawFrame, TextFrame + +from src.processors.user_image_request_processor import UserImageRequestProcessor +from src.processors.aggregators.vision_image_audio_frame import VisionImageAudioFrameAggregator +from src.processors.speech.audio_save_processor import AudioSaveProcessor +from src.processors.aggregators.user_audio_response import UserAudioResponseAggregator +from src.cmd.bots.base_livekit import LivekitRoomBot, rtc +from src.modules.speech.vad_analyzer import VADAnalyzerEnvInit +from src.common.types import LivekitParams +from src.transports.livekit import LivekitTransport +from src.cmd.bots import register_ai_room_bots +from src.types.frames import * + +from dotenv import load_dotenv + +load_dotenv(override=True) + + +@register_ai_room_bots.register +class LivekitQwen2_5OmniVisionVoiceBot(LivekitRoomBot): + """ + use livekit images + audio stream(bytes) --> Qwen2_5Omni vision voice processor -->text/audio_bytes + """ + + def __init__(self, **args) -> None: + super().__init__(**args) + self.init_bot_config() + + async def arun(self): + self._vad_analyzer = VADAnalyzerEnvInit.initVADAnalyzerEngine() + self.params = LivekitParams( + audio_in_enabled=True, + audio_out_enabled=True, + vad_enabled=True, + vad_analyzer=self._vad_analyzer, + vad_audio_passthrough=True, + camera_in_enabled=True, + ) + + self._vision_voice_processor = self.get_qwen2_5omni_vision_voice_processor() + stream_info = self._vision_voice_processor.stream_info + self.params.audio_out_sample_rate = stream_info["sample_rate"] + self.params.audio_out_channels = stream_info["channels"] + + transport = LivekitTransport( + self.args.token, + params=self.params, + ) + self.regisiter_room_event(transport) + + 
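# Frame flow in the pipeline built below: VAD-gated user audio is aggregated per utterance, UserImageRequestProcessor pairs each aggregated AudioRawFrame with the latest captured camera image, VisionImageAudioFrameAggregator merges them into a VisionImageVoiceRawFrame, and the Qwen2.5-Omni vision-voice processor streams TextFrame/AudioRawFrame results back out through the livekit output transport. +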
in_audio_aggr = UserAudioResponseAggregator() + self.image_requester = UserImageRequestProcessor(request_frame_cls=AudioRawFrame) + image_audio_aggr = VisionImageAudioFrameAggregator() + + self.task = PipelineTask( + Pipeline( + [ + transport.input_processor(), + in_audio_aggr, + FrameLogger(include_frame_types=[AudioRawFrame]), + # AudioSaveProcessor(prefix_name="user_audio_aggr"), + # FrameLogger(include_frame_types=[PathAudioRawFrame]), + self.image_requester, + image_audio_aggr, + FrameLogger(include_frame_types=[VisionImageVoiceRawFrame]), + self._vision_voice_processor, + FrameLogger(include_frame_types=[AudioRawFrame, TextFrame]), + # AudioSaveProcessor(prefix_name="bot_speak"), + transport.output_processor(), + ] + ), + params=PipelineParams( + allow_interruptions=False, + enable_metrics=True, + send_initial_empty_metrics=False, + ), + ) + + await PipelineRunner().run(self.task) + + async def on_first_participant_joined( + self, + transport: LivekitTransport, + participant: rtc.RemoteParticipant, + ): + # subscribed the first participant + transport.capture_participant_video(participant.sid, framerate=0) + self.image_requester.set_participant_id(participant.sid) + + participant_name = participant.name if participant.name else participant.identity + await self._vision_voice_processor.say( + f"你好,{participant_name} 欢迎使用 Vision Voice Omni Bot. 我是一名虚拟助手,可以结合视频进行提问。" + ) + + async def on_video_track_subscribed( + self, + transport: LivekitTransport, + participant: rtc.RemoteParticipant, + ): + transport.capture_participant_video(participant.sid, framerate=0) diff --git a/src/cmd/bots/voice/livekit_asr_qwen2_5omni_voice_bot.py b/src/cmd/bots/voice/livekit_asr_qwen2_5omni_voice_bot.py new file mode 100644 index 00000000..a5ae8f48 --- /dev/null +++ b/src/cmd/bots/voice/livekit_asr_qwen2_5omni_voice_bot.py @@ -0,0 +1,84 @@ +import logging + +from apipeline.pipeline.pipeline import Pipeline +from apipeline.pipeline.task import PipelineParams, PipelineTask +from apipeline.pipeline.runner import PipelineRunner +from apipeline.processors.logger import FrameLogger +from apipeline.frames import AudioRawFrame, TextFrame + +from src.processors.aggregators.user_response import UserResponseAggregator +from src.cmd.bots.base_livekit import LivekitRoomBot +from src.modules.speech.vad_analyzer import VADAnalyzerEnvInit +from src.common.types import LivekitParams +from src.transports.livekit import LivekitTransport +from src.cmd.bots import register_ai_room_bots + +from dotenv import load_dotenv + +load_dotenv(override=True) + + +@register_ai_room_bots.register +class LivekitAsrQwen2_5OmniVoiceBot(LivekitRoomBot): + """ + use livekit audio stream(bytes) --> asr --> text ---> qwen2.5omni voice processor -->text/audio_bytes + - don't support tools call, need sft + """ + + def __init__(self, **args) -> None: + super().__init__(**args) + self.init_bot_config() + + async def arun(self): + self._vad_analyzer = VADAnalyzerEnvInit.initVADAnalyzerEngine() + self.params = LivekitParams( + audio_in_enabled=True, + audio_out_enabled=True, + vad_enabled=True, + vad_analyzer=self._vad_analyzer, + vad_audio_passthrough=True, + ) + asr_processor = self.get_asr_processor() + + self._voice_processor = self.get_text_qwen2_5omni_voice_processor() + stream_info = self._voice_processor.stream_info + self.params.audio_out_sample_rate = stream_info["sample_rate"] + self.params.audio_out_channels = stream_info["channels"] + + transport = LivekitTransport( + self.args.token, + params=self.params, + ) + 
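# This variant adds an explicit ASR stage: transport audio -> asr_processor -> UserResponseAggregator (one full text utterance) -> Qwen2.5-Omni text-voice processor -> TextFrame/AudioRawFrame -> transport output, so the omni model only receives text input here. +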
self.regisiter_room_event(transport) + + # messages = [] + # if self._bot_config.llm.messages: + # messages = self._bot_config.llm.messages + + self.task = PipelineTask( + Pipeline( + [ + transport.input_processor(), + asr_processor, + UserResponseAggregator(), + FrameLogger(include_frame_types=[TextFrame]), + self._voice_processor, + FrameLogger(include_frame_types=[AudioRawFrame, TextFrame]), + transport.output_processor(), + ] + ), + params=PipelineParams( + allow_interruptions=False, + enable_metrics=True, + send_initial_empty_metrics=False, + ), + ) + + transport.add_event_handlers( + "on_first_participant_joined", + [self.on_first_participant_joined, self.on_first_participant_say_hi], + ) + await PipelineRunner().run(self.task) + + async def on_first_participant_say_hi(self, transport: LivekitTransport, participant): + pass diff --git a/src/cmd/bots/voice/livekit_qwen2_5omni_voice_bot.py b/src/cmd/bots/voice/livekit_qwen2_5omni_voice_bot.py new file mode 100644 index 00000000..cf76d8a1 --- /dev/null +++ b/src/cmd/bots/voice/livekit_qwen2_5omni_voice_bot.py @@ -0,0 +1,88 @@ +import logging + +from apipeline.pipeline.pipeline import Pipeline +from apipeline.pipeline.task import PipelineParams, PipelineTask +from apipeline.pipeline.runner import PipelineRunner +from apipeline.processors.logger import FrameLogger +from apipeline.frames import AudioRawFrame, TextFrame + +from src.processors.speech.audio_save_processor import AudioSaveProcessor +from src.processors.aggregators.user_audio_response import UserAudioResponseAggregator +from src.cmd.bots.base_livekit import LivekitRoomBot +from src.modules.speech.vad_analyzer import VADAnalyzerEnvInit +from src.common.types import LivekitParams +from src.transports.livekit import LivekitTransport +from src.cmd.bots import register_ai_room_bots +from src.types.frames import * + +from dotenv import load_dotenv + +load_dotenv(override=True) + + +@register_ai_room_bots.register +class LivekitQwen2_5OmniVoiceBot(LivekitRoomBot): + """ + use livekit audio stream(bytes) --> Qwen2_5Omni voice processor -->text/audio_bytes + - unsupport tools call, need sft + """ + + def __init__(self, **args) -> None: + super().__init__(**args) + self.init_bot_config() + + async def arun(self): + self._vad_analyzer = VADAnalyzerEnvInit.initVADAnalyzerEngine() + self.params = LivekitParams( + audio_in_enabled=True, + audio_out_enabled=True, + vad_enabled=True, + vad_analyzer=self._vad_analyzer, + vad_audio_passthrough=True, + ) + + self._voice_processor = self.get_audio_qwen2_5omni_voice_processor() + stream_info = self._voice_processor.stream_info + self.params.audio_out_sample_rate = stream_info["sample_rate"] + self.params.audio_out_channels = stream_info["channels"] + + transport = LivekitTransport( + self.args.token, + params=self.params, + ) + self.regisiter_room_event(transport) + + # messages = [] + # if self._bot_config.llm.messages: + # messages = self._bot_config.llm.messages + + self.task = PipelineTask( + Pipeline( + [ + transport.input_processor(), + UserAudioResponseAggregator(), + FrameLogger(include_frame_types=[AudioRawFrame]), + # AudioSaveProcessor(prefix_name="user_audio_aggr"), + FrameLogger(include_frame_types=[PathAudioRawFrame]), + self._voice_processor, + FrameLogger(include_frame_types=[AudioRawFrame, TextFrame]), + # AudioSaveProcessor(prefix_name="bot_speak"), + transport.output_processor(), + ] + ), + params=PipelineParams( + allow_interruptions=False, + enable_metrics=True, + send_initial_empty_metrics=False, + ), + ) + + 
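# Unlike the ASR variant, the pipeline above feeds the aggregated user audio directly to the omni audio-voice processor; both on_first_participant_joined and on_first_participant_say_hi are registered below for the first participant join event before the runner starts the task. +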
transport.add_event_handlers( + "on_first_participant_joined", + [self.on_first_participant_joined, self.on_first_participant_say_hi], + ) + + await PipelineRunner().run(self.task) + + async def on_first_participant_say_hi(self, transport: LivekitTransport, participant): + pass diff --git a/src/common/task_manager/sync_generator.py b/src/common/task_manager/sync_generator.py new file mode 100644 index 00000000..517787f7 --- /dev/null +++ b/src/common/task_manager/sync_generator.py @@ -0,0 +1,34 @@ +import asyncio +from typing import Any, AsyncGenerator, Generator + +import concurrent + + +class SynchronizedGenerator(Generator[Any, None, None]): + def __init__( + self, + generator: AsyncGenerator[Any, None], + loop: asyncio.AbstractEventLoop, + ): + self._generator = generator + self._loop = loop + + def __iter__(self): + return self + + def __next__(self): + try: + return asyncio.run_coroutine_threadsafe( + self._generator.__anext__(), + self._loop, + ).result() + except StopAsyncIteration as e: + raise StopIteration from e + except concurrent.futures._base.CancelledError as e: + raise StopIteration from e + + def send(self, value): + return self.__next__() + + def throw(self, type, value=None, traceback=None): + raise StopIteration diff --git a/src/common/utils/helper.py b/src/common/utils/helper.py index e2fefab0..9080c85e 100644 --- a/src/common/utils/helper.py +++ b/src/common/utils/helper.py @@ -2,6 +2,7 @@ import json import logging import platform +import threading import pyloudnorm as pyln import numpy as np @@ -71,3 +72,21 @@ def file_md5_hash(file_path: str): data = _file.read() data_hash = hashlib.md5(data).hexdigest() return data_hash + + +class ThreadSafeDict: + def __init__(self): + self._dict = {} + self._lock = threading.RLock() + + def get(self, key, default=None): + with self._lock: + return self._dict.get(key, default) + + def set(self, key, value): + with self._lock: + self._dict[key] = value + + def pop(self, key): + with self._lock: + return self._dict.pop(key, None) diff --git a/src/core/llm/__init__.py b/src/core/llm/__init__.py index e348a01c..7864b459 100644 --- a/src/core/llm/__init__.py +++ b/src/core/llm/__init__.py @@ -53,6 +53,8 @@ def getEngine(tag, **kwargs) -> interface.ILlm | EngineClass: from .transformers import manual_vision_img_janus_pro elif "llm_transformers_manual_vision_minicpmo" == tag: from .transformers import manual_vision_voice_minicpmo + elif "llm_transformers_manual_qwen2_5omni" in tag: + from .transformers import manual_vision_voice_qwen elif "llm_transformers_manual" == tag: from .transformers import manual elif "llm_transformers_pipeline" == tag: @@ -150,19 +152,22 @@ def get_llm_personal_ai_proxy_args() -> dict: return kwargs @staticmethod - def _get_llm_generate_args() -> dict: + def _get_llm_generate_args(prefix: str = "") -> dict: from src.types.llm.sampling import LMGenerateArgs return LMGenerateArgs( - lm_gen_seed=int(os.getenv("LLM_GEN_SEED", "42")), - lm_gen_do_sample=bool(os.getenv("LLM_GEN_DO_SAMPLE", "1")), - lm_gen_max_new_tokens=int(os.getenv("LLM_GEN_MAX_NEW_TOKENS", "1024")), - lm_gen_temperature=float(os.getenv("LLM_GEN_TEMPERATURE", "0.8")), - lm_gen_top_k=int(os.getenv("LLM_GEN_TOP_K", "50")), - lm_gen_top_p=float(os.getenv("LLM_GEN_TOP_P", "0.95")), - lm_gen_min_p=float(os.getenv("LLM_GEN_MIN_P", "0.0")), - lm_gen_repetition_penalty=float(os.getenv("LLM_GEN_REPETITION_PENALTY", "1.1")), - lm_gen_min_new_tokens=int(os.getenv("LLM_GEN_MIN_NEW_TOKENS", "0")), + lm_gen_seed=int(os.getenv(f"{prefix}LLM_GEN_SEED", "42")), + 
lm_gen_do_sample=bool(os.getenv(f"{prefix}LLM_GEN_DO_SAMPLE", "1")), + lm_gen_max_tokens_per_step=int(os.getenv(f"{prefix}LLM_GEN_MAX_TOKENS_PER_STEP", "3")), + lm_gen_max_new_tokens=int(os.getenv(f"{prefix}LLM_GEN_MAX_NEW_TOKENS", "1024")), + lm_gen_temperature=float(os.getenv(f"{prefix}LLM_GEN_TEMPERATURE", "0.8")), + lm_gen_top_k=int(os.getenv(f"{prefix}LLM_GEN_TOP_K", "50")), + lm_gen_top_p=float(os.getenv(f"{prefix}LLM_GEN_TOP_P", "0.95")), + lm_gen_min_p=float(os.getenv(f"{prefix}LLM_GEN_MIN_P", "0.0")), + lm_gen_repetition_penalty=float( + os.getenv(f"{prefix}LLM_GEN_REPETITION_PENALTY", "1.1") + ), + lm_gen_min_new_tokens=int(os.getenv(f"{prefix}LLM_GEN_MIN_NEW_TOKENS", "1")), ).__dict__ @staticmethod @@ -286,6 +291,66 @@ def get_llm_trtllm_runner_generator_args() -> dict: ) return kwargs + @staticmethod + def get_qwen2_5omni_transformers_args() -> dict: + from src.types.llm.transformers import TransformersLMArgs + from src.thirdparty.qwen2_code2wav import Code2WavEngineConfig, Code2WavGenerationConfig + from src.types.omni.qwen2_vision_voice import Qwen2_5TransformersVisionVoiceLMArgs + + kwargs = Qwen2_5TransformersVisionVoiceLMArgs( + lm_model_name_or_path=os.getenv( + "LLM_MODEL_NAME_OR_PATH", os.path.join(MODELS_DIR, "Qwen/Qwen2.5-Omni-7B") + ), + lm_attn_impl=os.getenv("LLM_ATTN_IMPL", None), + lm_device=os.getenv("LLM_DEVICE", None), + lm_device_map=os.getenv("LLM_DEVICE_MAP", None), + lm_torch_dtype=os.getenv("LLM_TORCH_DTYPE", "auto"), + lm_stream=bool(os.getenv("LLM_STREAM", "1")), + init_chat_prompt=os.getenv("LLM_INIT_CHAT_PROMPT", ""), + chat_history_size=int(os.getenv("LLM_CHAT_HISTORY_SIZE", "10")), # cache 10 round + model_type=os.getenv("LLM_MODEL_TYPE", "chat_completion"), + warmup_steps=int(os.getenv("LLM_WARMUP_STEPS", "1")), + **LLMEnvInit._get_llm_generate_args(), + thinker_eos_token_ids=[ + int(i) for i in os.getenv("THINKER_EOS_TOKEN_IDS", "151644,151645").split(",") + ], + thinker_stop_strings_per_step=[ + i for i in os.getenv("THINKER_STOP_STRINGS_PER_STEP", ".。") + ], + thinker_args=TransformersLMArgs( + **LLMEnvInit._get_llm_generate_args(prefix="THINKER_"), + ).__dict__, + speaker=os.getenv("SPEAKER", "Chelsie"), + talker_eos_token_ids=[ + int(i) for i in os.getenv("TALKER_EOS_TOKEN_IDS", "8292, 8294").split(",") + ], + talker_args=TransformersLMArgs( + **LLMEnvInit._get_llm_generate_args(prefix="TALKER_"), + ).__dict__, + code2wav_args=Code2WavEngineConfig( + model_path=os.getenv( + "CODE2WAV_MODEL_PATH", os.path.join(MODELS_DIR, "Qwen/Qwen2.5-Omni-7B") + ), + enable_torch_compile=bool(os.getenv("CODE2WAV_ENABLE_TORCH_COMPILE", "1")), + enable_torch_compile_first_chunk=bool( + os.getenv("CODE2WAV_ENABLE_TORCH_COMPILE_FIRST_CHUNK", "") + ), + odeint_method=os.getenv("CODE2WAV_ODEINT_METHOD", "euler"), + odeint_method_relaxed=bool(os.getenv("CODE2WAV_ODEINT_METHOD_RELAXED", "")), + batched_chunk=int(os.getenv("CODE2WAV_BATCHED_CHUNK", "3")), + frequency=os.getenv("CODE2WAV_FREQUENCY", "50hz"), + device=os.getenv("CODE2WAV_DEVICE", "cuda"), + code2wav_dynamic_batch=bool(os.getenv("CODE2WAV_DYNAMIC_BATCHING", "")), + num_steps=int(os.getenv("CODE2WAV_NUM_STEPS", "10")), + guidance_scale=float(os.getenv("CODE2WAV_GUIDANCE_SCALE", "0.5")), + sway_coefficient=float(os.getenv("CODE2WAV_SWAY_COEFFICIENT", "-1.0")), + ).__dict__, + is_use_sliding_window_code2wav=bool(os.getenv("IS_USE_SLIDING_WINDOW_CODE2WAV", "")), + disable_talker=bool(os.getenv("DISABLE_TALKER", "")), + ).__dict__ + + return kwargs + # TAG : config map_config_func = { "llm_llamacpp": 
get_llm_llamacpp_args, @@ -303,6 +368,15 @@ def get_llm_trtllm_runner_generator_args() -> dict: "llm_transformers_manual_vision_janus_flow": get_llm_transformers_args, "llm_transformers_manual_image_janus_flow": get_llm_transformers_manual_image_janus_flow_args, "llm_transformers_manual_vision_minicpmo": get_llm_transformers_args, + "llm_transformers_manual_qwen2_5omni": get_qwen2_5omni_transformers_args, + "llm_transformers_manual_qwen2_5omni_vision": get_qwen2_5omni_transformers_args, + "llm_transformers_manual_qwen2_5omni_audio": get_qwen2_5omni_transformers_args, + "llm_transformers_manual_qwen2_5omni_audio_asr": get_qwen2_5omni_transformers_args, + "llm_transformers_manual_qwen2_5omni_audio_translation": get_qwen2_5omni_transformers_args, + "llm_transformers_manual_qwen2_5omni_audio_classification": get_qwen2_5omni_transformers_args, + "llm_transformers_manual_qwen2_5omni_vision_voice": get_qwen2_5omni_transformers_args, + "llm_transformers_manual_qwen2_5omni_text_voice": get_qwen2_5omni_transformers_args, + "llm_transformers_manual_qwen2_5omni_audio_voice": get_qwen2_5omni_transformers_args, "llm_transformers_generator": get_llm_transformers_args, "llm_llamacpp_generator": get_llm_llamacpp_generator_args, "llm_vllm_generator": get_llm_vllm_generator_args, diff --git a/src/core/llm/transformers/manual_vision_qwen.py b/src/core/llm/transformers/manual_vision_qwen.py index 69077efd..ba82bfaa 100644 --- a/src/core/llm/transformers/manual_vision_qwen.py +++ b/src/core/llm/transformers/manual_vision_qwen.py @@ -31,7 +31,7 @@ def __init__(self, **args) -> None: from transformers import ( Qwen2_5_VLForConditionalGeneration as QwenVLForConditionalGeneration, ) - else: + if self.TAG == "llm_transformers_manual_vision_qwen": from transformers import ( Qwen2VLForConditionalGeneration as QwenVLForConditionalGeneration, ) diff --git a/src/core/llm/transformers/manual_vision_voice_minicpmo.py b/src/core/llm/transformers/manual_vision_voice_minicpmo.py index 55792d7c..300f7932 100644 --- a/src/core/llm/transformers/manual_vision_voice_minicpmo.py +++ b/src/core/llm/transformers/manual_vision_voice_minicpmo.py @@ -197,11 +197,6 @@ def warmup(self): if self.args.warmup_steps < 0: return logging.info(f"Warming up {self.__class__.__name__} device: {self._model.device}") - if "cuda" in str(self._model.device): - start_event = torch.cuda.Event(enable_timing=True) - end_event = torch.cuda.Event(enable_timing=True) - torch.cuda.synchronize() - start_event.record() dummy_input_text = self.args.warnup_prompt content = [dummy_input_text] @@ -215,7 +210,13 @@ def warmup(self): } ] - for i in range(self.args.warmup_steps): + if "cuda" in str(self._model.device): + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + torch.cuda.synchronize() + start_event.record() + + for step in range(self.args.warmup_steps): self._sys_msg and self._model.streaming_prefill( session_id="", msgs=[self._sys_msg], tokenizer=self._tokenizer ) @@ -232,15 +233,13 @@ def warmup(self): repetition_penalty=self.args.lm_gen_repetition_penalty, generate_audio=False, ) - for step in range(self.args.warmup_steps): - for _ in streamer: - times = [] - start_time = time.perf_counter() - for _ in streamer: - times.append(time.perf_counter() - start_time) - start_time = time.perf_counter() - logging.info(f"step {step} warnup TTFT time: {times[0]} s") - step += 1 + times = [] + start_time = time.perf_counter() + for _ in streamer: + times.append(time.perf_counter() - start_time) + start_time = 
time.perf_counter() + logging.info(f"step {step} warnup TTFT time: {times[0]} s") + step += 1 if "cuda" in str(self._model.device): end_event.record() diff --git a/src/core/llm/transformers/manual_vision_voice_qwen.py b/src/core/llm/transformers/manual_vision_voice_qwen.py new file mode 100644 index 00000000..99733700 --- /dev/null +++ b/src/core/llm/transformers/manual_vision_voice_qwen.py @@ -0,0 +1,641 @@ +import logging +from threading import Thread +from time import perf_counter +import time +import traceback +from typing import Generator, Optional + +import numpy as np +import torch + + +try: + from qwen_omni_utils import process_mm_info + from transformers import ( + AutoConfig, + AutoProcessor, + TextIteratorStreamer, + ) + from src.thirdparty.qwen2_code2wav.engine import Code2WavEngine + from src.thirdparty.qwen2_code2wav import Code2WavEngineConfig, Code2WavGenerationConfig + from src.core.llm.transformers.models.qwen2_5_omni import ( + Qwen2_5OmniForConditionalGenerationStreaming, + ) +except ModuleNotFoundError as e: + logging.error(f"Exception: {e}") + logging.error( + "In order to use Qwen2.5Omni, you need to `pip install achatbot[llm_transformers_manual_vision_voice_qwen]`" + ) + raise Exception(f"Missing module: {e}") + + +from src.common.utils.helper import ThreadSafeDict, get_device +from src.core.llm.transformers.streamer import TokenStreamer +from src.common.random import set_all_random_seed +from src.common.chat_history import ChatHistory +from src.common.session import Session +from src.types.omni.qwen2_vision_voice import Qwen2_5TransformersVisionVoiceLMArgs +from src.types.llm.transformers import TransformersLMArgs +from .base import TransformersBaseLLM + + +class TransformersManualQwen2_5OmniLLM(TransformersBaseLLM): + TAG = "llm_transformers_manual_qwen2_5omni" + + # NOTE: if want to generate speech, need use this system prompt to generate speech + SPEECH_SYS_PROMPT = "You are Qwen, a virtual human developed by the Qwen Team, Alibaba Group, capable of perceiving auditory and visual inputs, as well as generating text and speech." 
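+    # chat_history() falls back to SPEECH_SYS_PROMPT when no init_chat_prompt is configured, since speech output is only generated with this system prompt.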
+ # Voice settings + SPEAKER_LIST = ["Chelsie", "Ethan"] + DEFAULT_SPEAKER = "Chelsie" + RATE = 24000 + + def __init__(self, **args) -> None: + self.args = Qwen2_5TransformersVisionVoiceLMArgs(**args) + self.args.lm_device = self.args.lm_device or get_device() + self.thinker_args = TransformersLMArgs(**self.args.thinker_args) + self.talker_args = TransformersLMArgs(**self.args.talker_args) + self.code2wav_args = Code2WavEngineConfig(**self.args.code2wav_args) + logging.info(f"Model args: {args}") + logging.info(f"Model thinker_args: {self.thinker_args}") + logging.info(f"Model talker_args: {self.talker_args}") + logging.info(f"Model code2wav_args: {self.code2wav_args}") + config = AutoConfig.from_pretrained(self.args.lm_model_name_or_path) + config.enable_audio_output = True + if self.args.disable_talker is True: + config.enable_audio_output = False + + if self.args.lm_device_map: + self._model: Qwen2_5OmniForConditionalGenerationStreaming = ( + Qwen2_5OmniForConditionalGenerationStreaming.from_pretrained( + self.args.lm_model_name_or_path, + torch_dtype=self.args.lm_torch_dtype, + #!NOTE: https://github.com/huggingface/transformers/issues/20896 + # device_map for multi cpu/gpu with accelerate + device_map=self.args.lm_device_map, + attn_implementation=self.args.lm_attn_impl, + trust_remote_code=True, + config=config, + ).eval() + ) + else: + self._model: Qwen2_5OmniForConditionalGenerationStreaming = ( + Qwen2_5OmniForConditionalGenerationStreaming.from_pretrained( + self.args.lm_model_name_or_path, + torch_dtype=self.args.lm_torch_dtype, + attn_implementation=self.args.lm_attn_impl, + trust_remote_code=True, + config=config, + ) + .eval() + .to(self.args.lm_device) + ) + + # The default range for the number of visual tokens per image in the model is 4-16384. + # You can set min_pixels and max_pixels according to your needs, such as a + # token count range of 256-1280, to balance speed and memory usage. 
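+        # For example, min_pixels=256*28*28 (=200704) and max_pixels=1280*28*28 (=1003520) bound each image to roughly 256-1280 visual tokens (one visual token per 28x28-pixel patch).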
+ self._tokenizer = AutoProcessor.from_pretrained( + self.args.lm_model_name_or_path, + min_pixels=256 * 28 * 28, + max_pixels=1280 * 28 * 28, + trust_remote_code=True, + ) + + # use sliding window code2wav + self.code2wav_engine: Code2WavEngine = None + if self.args.is_use_sliding_window_code2wav is True: + self.code2wav_engine = Code2WavEngine(**self.args.code2wav_args) + if hasattr(self._model, "token2wav"): + logging.info("use Code2WavEngine, delete _model.token2wav") + del self._model.token2wav + torch.cuda.empty_cache() + + self.chat_history_dict = ThreadSafeDict() + + self.warmup() + + def chat_history(self, session: Session, **kwargs) -> ChatHistory: + session_id = session.ctx.client_id + if not self.chat_history_dict.get(session_id): + chat_history = ChatHistory( + kwargs.get("chat_history_size", None) or self.args.chat_history_size + ) + init_chat_role = kwargs.get("init_chat_role", None) or self.args.init_chat_role + init_chat_prompt = ( + kwargs.get("init_chat_prompt", self.args.init_chat_prompt) or self.SPEECH_SYS_PROMPT + ) + if init_chat_role: + sys_msg = { + "role": init_chat_role, + "content": [ + { + "type": "text", + "text": init_chat_prompt, + } + ], + } + chat_history.init(sys_msg) + self.chat_history_dict.set(session_id, chat_history) + + return self.chat_history_dict.get(session_id) + + def warmup(self): + if self.args.warmup_steps <= 0: + return + logging.info( + f"Warming up {self.__class__.__name__} device: {self._model.device} with {self.args.warmup_steps} steps" + ) + + dummy_msgs = [ + { + "role": self.args.init_chat_role, + "content": [ + { + "type": "text", + "text": self.args.init_chat_prompt or self.SPEECH_SYS_PROMPT, + } + ], + }, + { + "role": self.args.user_role, + "content": [ + {"type": "text", "text": self.args.warnup_prompt or "请简单介绍下自己"} + ], + }, + ] + # Preparation for inference + text = self._tokenizer.apply_chat_template( + dummy_msgs, tokenize=False, add_generation_prompt=True + ) + logging.info(f"Warmup text: {text}") + audios, images, videos = process_mm_info(dummy_msgs, use_audio_in_video=False) + inputs = self._tokenizer( + text=text, + audio=audios, + images=images, + videos=videos, + return_tensors="pt", + padding=True, + use_audio_in_video=False, + ) + inputs = inputs.to(self._model.device).to(self._model.dtype) + + if "cuda" in str(self._model.device): + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + torch.cuda.synchronize() + start_event.record() + + for step in range(self.args.warmup_steps): + streamer = self._model.generate_stream( + inputs, + use_audio_in_video=False, + thinker_max_tokens_per_step=self.thinker_args.lm_gen_max_tokens_per_step, + thinker_max_new_tokens=15, + thinker_top_k=self.thinker_args.lm_gen_top_k, + thinker_top_p=self.thinker_args.lm_gen_top_p, + thinker_temperature=self.thinker_args.lm_gen_temperature, + thinker_repetition_penalty=self.thinker_args.lm_gen_repetition_penalty, + thinker_eos_token_ids=self.args.thinker_eos_token_ids, + thinker_stop_strings_per_step=self.args.thinker_stop_strings_per_step, + tokenizer=self._tokenizer.tokenizer, + return_audio=self._model.has_talker, + speaker=self.args.speaker, + talker_top_k=self.talker_args.lm_gen_top_k, + talker_top_p=self.talker_args.lm_gen_top_p, + talker_temperature=self.talker_args.lm_gen_temperature, + talker_repetition_penalty=self.talker_args.lm_gen_repetition_penalty, + talker_min_new_tokens=self.talker_args.lm_gen_min_new_tokens, + talker_max_new_tokens=self.talker_args.lm_gen_max_new_tokens, + 
talker_eos_token_ids=self.args.talker_eos_token_ids, + code2wav_num_steps=self.code2wav_args.num_steps, + code2wav_guidance_scale=self.code2wav_args.guidance_scale, + code2wav_sway_coefficient=self.code2wav_args.sway_coefficient, + code2wav_chunk_stream_func=self.code2wav_sliding_window_chunk_stream + if self.args.is_use_sliding_window_code2wav + else None, + ) + times = [] + start_time = time.perf_counter() + for i, chunk in enumerate(streamer): + times.append(time.perf_counter() - start_time) + text = self._tokenizer.decode(chunk["thinker_ids"][0], skip_special_tokens=True) + if "talker_wav" in chunk: + logging.info( + f"{i} chunk: {text} | {chunk['talker_wav'].shape} , warmup time: {times[i]} s" + ) + else: + logging.info(f"{i} chunk: {text} , warmup time: {times[i]} s") + # if ( + # self.args.is_use_sliding_window_code2wav + # and self.code2wav_args.enable_torch_compile is True + # ): + # logging.info(f"torch.compile code2wav warmup finish") + # break + start_time = time.perf_counter() + if len(times) > 0: + logging.info( + f"step {step} warmup TTFT(chunk) time: {times[0]} s | total: {sum(times)} s" + ) + else: + logging.warning(f"step {step} warmup no generate stream") + step += 1 + + if "cuda" in str(self._model.device): + end_event.record() + torch.cuda.synchronize() + logging.info( + f"{self.__class__.__name__}: warmed up! time: {start_event.elapsed_time(end_event) * 1e-3:.3f} s" + ) + + def skip_token_ids(self): + token_ids = [] + for i in ",;.?,;。?!": + # for i in ",.": + token_id = self._tokenizer.tokenizer.encode(i) + token_ids.extend(token_id) + return token_ids + + def code2wav_sliding_window_chunk_stream( + self, + talker_streamer: TokenStreamer, + speaker: str = DEFAULT_SPEAKER, + talker_eos_token_ids: list[int] = [8292, 8294], + **kwargs, + ) -> Generator[torch.Tensor, None, None]: + """ + code2wav sliding window streaming + """ + talker_eos_token_ids = talker_eos_token_ids or self.args.talker_eos_token_ids + prev_generated = None + progress = 0 + finished = False + code2wav_times = [] + talker_generate_codes = [] + times = [] + start_time = perf_counter() + for token_id in talker_streamer: + times.append(perf_counter() - start_time) + start_time = perf_counter() + if token_id in talker_eos_token_ids: + finished = True + talker_generate_codes.append(token_id) + prev_generated, wav = self.code2wav_engine.step_generate_waveform( + talker_generate_codes, + voice_type=speaker, + prev_generated=prev_generated, + progress=progress, + finished=finished, + gen_args=Code2WavGenerationConfig( + num_steps=kwargs.get("code2wav_num_steps") or self.code2wav_args.num_steps, + guidance_scale=kwargs.get("code2wav_guidance_scale") + or self.code2wav_args.guidance_scale, + sway_coefficient=kwargs.get("code2wav_sway_coefficient") + or self.code2wav_args.sway_coefficient, + ), + ) + if wav is not None: + progress += 1 + code2wav_times.append(perf_counter() - start_time) + yield wav.detach() # (T,) + + start_time = perf_counter() + + logging.info( + f"talker generate first token cost time: {times[0]} s, {len(times)} tokens cost time: {sum(times)} s" + ) + logging.info( + f"code2wav sliding window streaming first chunk time: {code2wav_times[0]} s | cost: {sum(code2wav_times)} s" + ) + + def get_prompt(self, session: Session) -> list: + prompt = [] + if isinstance(session.ctx.state["prompt"], list): + prompt = session.ctx.state["prompt"] + return prompt + + @torch.no_grad() + def generate(self, session: Session, **kwargs): + """ + - prompt: + [ + {"type": "text", "text": str}, + {"type": "image", 
"image": url / path / base64 / nparray}, + {"type": "video", "video": url / path / base64 / nparray}, + {"type": "audio", "audio": url / path / base64 / nparray}, + ] + + - return Generator[dict, None, None]: + { + "text": str, + "audio_wav": torch.Tensor,# (T,) + } + """ + seed = kwargs.get("seed", self.args.lm_gen_seed) + set_all_random_seed(seed) + + prompt = self.get_prompt(session) + + message = {"role": "user", "content": prompt} + session_chat_history = self.chat_history(session, **kwargs) + session_chat_history.append(message) + messages = session_chat_history.to_list() + logging.debug(f"messages: {messages}") + + text = self._tokenizer.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True + ) + audios, images, videos = process_mm_info( + messages, use_audio_in_video=kwargs.get("use_audio_in_video", False) + ) + logging.debug(text) + { + logging.debug(f"audios[{i}]: {item.shape}") for i, item in enumerate(audios) + } if audios else logging.debug(audios) + { + logging.debug(f"images[{i}]: {item}") for i, item in enumerate(images) + } if images else logging.debug(images) + { + logging.debug(f"videos[{i}]: {item.shape}") for i, item in enumerate(videos) + } if videos else logging.debug(videos) + + inputs = self._tokenizer( + text=text, + audio=audios, + images=images, + videos=videos, + return_tensors="pt", + padding=True, + use_audio_in_video=kwargs.get("use_audio_in_video", False), + ) + inputs = inputs.to(self._model.device).to(self._model.dtype) + for k, v in inputs.items(): + logging.debug(f"{k}: {v.shape}") + + return_audio = kwargs.get("return_audio", self._model.has_talker) + thinker_all_talker_stream = kwargs.get( + "thinker_all_talker_stream", + self._model.has_talker and self.args.thinker_all_talker_stream, + ) + gen_assistant_text = "" + try: + if not return_audio: # text / vision(image/video) / audio / text + image -> text + for item in self.thinker_stream( + inputs, + use_audio_in_video=kwargs.get("use_audio_in_video", False), + thinker_top_k=kwargs.get("thinker_top_k", None) + or self.thinker_args.lm_gen_top_k, + thinker_top_p=kwargs.get("thinker_top_p", None) + or self.thinker_args.lm_gen_top_p, + thinker_temperature=kwargs.get("thinker_temperature", None) + or self.thinker_args.lm_gen_temperature, + thinker_repetition_penalty=kwargs.get("thinker_repetition_penalty", None) + or self.thinker_args.lm_gen_repetition_penalty, + thinker_min_new_tokens=kwargs.get("thinker_min_new_tokens", None) + or self.thinker_args.lm_gen_min_new_tokens, + thinker_max_new_tokens=kwargs.get("thinker_max_new_tokens", None) + or self.thinker_args.lm_gen_max_new_tokens, + ): + gen_assistant_text += item["text"] + yield item + else: # text / vision(image/video) / audio / text + image -> text + audio + gen_stream_func = self._model.generate_stream + if ( + thinker_all_talker_stream is True + ): # text / vision(image/video) / audio / text + image -> all text + chunk audio + logging.info("use thinker_all_talker_stream to generate") + gen_stream_func = self._model.thinker_all_talker_stream + + stream = gen_stream_func( + inputs, + use_audio_in_video=kwargs.get("use_audio_in_video", False), + thinker_max_tokens_per_step=kwargs.get("thinker_max_tokens_per_step", None) + or self.thinker_args.lm_gen_max_tokens_per_step, + thinker_max_new_tokens=kwargs.get("thinker_max_new_tokens", None) + or self.thinker_args.lm_gen_max_new_tokens, + thinker_top_k=kwargs.get("thinker_top_k", None) + or self.thinker_args.lm_gen_top_k, + thinker_top_p=kwargs.get("thinker_top_p", None) + or 
self.thinker_args.lm_gen_top_p, + thinker_temperature=kwargs.get("thinker_temperature", None) + or self.thinker_args.lm_gen_temperature, + thinker_repetition_penalty=kwargs.get("thinker_repetition_penalty", None) + or self.thinker_args.lm_gen_repetition_penalty, + thinker_eos_token_ids=kwargs.get("thinker_eos_token_ids", None) + or self.args.thinker_eos_token_ids, + thinker_stop_strings_per_step=kwargs.get("thinker_stop_strings_per_step", None) + or self.args.thinker_stop_strings_per_step, + tokenizer=self._tokenizer.tokenizer, + return_audio=kwargs.get("return_audio", self._model.has_talker), + speaker=kwargs.get("speaker", None) or self.args.speaker, + talker_top_k=kwargs.get("talker_top_k", None) or self.talker_args.lm_gen_top_k, + talker_top_p=kwargs.get("talker_top_p", None) or self.talker_args.lm_gen_top_p, + talker_temperature=kwargs.get("talker_temperature", None) + or self.talker_args.lm_gen_temperature, + talker_repetition_penalty=kwargs.get("talker_repetition_penalty", None) + or self.talker_args.lm_gen_repetition_penalty, + talker_min_new_tokens=kwargs.get("talker_min_new_tokens", None) + or self.talker_args.lm_gen_min_new_tokens, + talker_max_new_tokens=kwargs.get("talker_max_new_tokens", None) + or self.talker_args.lm_gen_max_new_tokens, + talker_eos_token_ids=kwargs.get("talker_eos_token_ids", None) + or self.args.talker_eos_token_ids, + talker_skip_thinker_token_ids=kwargs.get("talker_skip_thinker_token_ids", None) + or self.args.talker_skip_thinker_token_ids + or self.skip_token_ids(), + code2wav_num_steps=kwargs.get("code2wav_num_steps", None) + or self.code2wav_args.num_steps, + code2wav_guidance_scale=kwargs.get("code2wav_guidance_scale", None) + or self.code2wav_args.guidance_scale, + code2wav_sway_coefficient=kwargs.get("code2wav_sway_coefficient", None) + or self.code2wav_args.sway_coefficient, + code2wav_chunk_stream_func=self.code2wav_sliding_window_chunk_stream + if self.args.is_use_sliding_window_code2wav + else None, + mask_embedding=kwargs.get("mask_embedding", None) or self.args.mask_embedding, + ) + + gen_text = "" + for chunk in stream: + text = self._tokenizer.decode(chunk["thinker_ids"][0], skip_special_tokens=True) + if gen_text != text: + gen_text = text + gen_assistant_text += text + + if "talker_wav" not in chunk: + yield {"text": text} + else: + # audio_bytes = ( + # (chunk["talker_wav"].float().detach().cpu().numpy() * 32768) + # .astype(np.int16) + # .tobytes() + # ) + yield {"text": text, "audio_wav": chunk["talker_wav"]} + except Exception as e: + tb_str = traceback.format_exc() + logging.error(f"Exception: {e}; traceback: {tb_str}") + + session_chat_history.append( + {"role": "assistant", "content": [{"text": gen_assistant_text}]} + ) + self.chat_history_dict.set(session.ctx.client_id, session_chat_history) + + @torch.no_grad() + def thinker_stream( + self, + inputs: dict, + use_audio_in_video: bool = False, + thinker_top_k: int = 40, + thinker_top_p: float = 0.8, + thinker_temperature: float = 0.9, + thinker_repetition_penalty: float = 1.05, + thinker_min_new_tokens: int = 1, + thinker_max_new_tokens: int = 1024, + ): + streamer = TextIteratorStreamer(self._tokenizer, skip_prompt=True, skip_special_tokens=True) + + generation_kwargs = dict( + **inputs, + streamer=streamer, + use_audio_in_video=use_audio_in_video, + return_audio=False, + thinker_do_sample=True if thinker_temperature > 0.0 else False, + temperature=thinker_temperature, + top_k=thinker_top_k, + top_p=thinker_top_p, + repetition_penalty=thinker_repetition_penalty, + 
min_new_tokens=thinker_min_new_tokens, + max_new_tokens=thinker_max_new_tokens, + ) + thread = Thread(target=self._model.generate, kwargs=generation_kwargs) + thread.start() + + generated_text = "" + times = [] + start_time = perf_counter() + for new_text in streamer: + times.append(perf_counter() - start_time) + generated_text += new_text + if new_text == "": + continue + yield {"text": new_text} + start_time = perf_counter() + + logging.info( + f"thinker generate [{generated_text}] TTFT: {times[0]} s, {len(times)} tokens cost time: {sum(times)} s" + ) + + +class TransformersManualAudioQwen2_5OmniLLM(TransformersManualQwen2_5OmniLLM): + """ + audio understanding + + - speech -> text + """ + + TAG = [ + "llm_transformers_manual_qwen2_5omni_audio_asr", + "llm_transformers_manual_qwen2_5omni_audio_translation", + "llm_transformers_manual_qwen2_5omni_audio_classification", + ] + + def __init__(self, **args) -> None: + args["disable_talker"] = True + args["init_chat_prompt"] = args.get( + "init_chat_prompt", "You are a speech recognition model" + ) + if self.SELECTED_TAG == "llm_transformers_manual_qwen2_5omni_audio_asr": + args["init_chat_prompt"] = "You are a speech recognition model" + if self.SELECTED_TAG == "llm_transformers_manual_qwen2_5omni_audio_translation": + args["init_chat_prompt"] = "You are a speech translation model" + if self.SELECTED_TAG == "llm_transformers_manual_qwen2_5omni_audio_classification": + args["init_chat_prompt"] = "You are a voice classification model." + + super().__init__(**args) + + @torch.inference_mode() + def generate(self, session: Session, **kwargs): + for item in super().generate(session, **kwargs): + text = item.pop("text", "") + if text == "": + continue + yield text + + +class TransformersManualVisionQwen2_5OmniLLM(TransformersManualQwen2_5OmniLLM): + """ + vision only, vision understanding + + - vision -> text + """ + + TAG = "llm_transformers_manual_qwen2_5omni_vision" + + def __init__(self, **args) -> None: + args["disable_talker"] = True + args["init_chat_prompt"] = args.get("init_chat_prompt", "You are a helpful assistant.") + super().__init__(**args) + + @torch.inference_mode() + def generate(self, session: Session, **kwargs): + for item in super().generate(session, **kwargs): + yield item["text"] + + +class TransformersManualInstructSpeechQwen2_5OmniLLM(TransformersManualQwen2_5OmniLLM): + """ + text --> thinker lm -> gen hidden stats --> talker lm -> vq codes --> flow -> mel --> bigvgan -> speech + """ + + TAG = "llm_transformers_manual_qwen2_5omni_speech" + + def __init__(self, **args) -> None: + args["disable_talker"] = False + super().__init__(**args) + + @torch.inference_mode() + def generate(self, session: Session, **kwargs): + for item in super().generate(session, **kwargs): + audio_wav = item.pop("audio_wav", None) + yield audio_wav + + +class TransformersManualVisionVoiceQwen2_5OmniLLM(TransformersManualQwen2_5OmniLLM): + """ + vision + speech to speech voice chat + + - vision + speech -> text + speech + """ + + TAG = "llm_transformers_manual_qwen2_5omni_vision_voice" + + def __init__(self, **args) -> None: + args["disable_talker"] = False + super().__init__(**args) + + +class TransformersManualTextVoiceQwen2_5OmniLLM(TransformersManualQwen2_5OmniLLM): + """ + text to speech voice chat + + - text -> text + speech + """ + + TAG = "llm_transformers_manual_qwen2_5omni_text_voice" + + def __init__(self, **args) -> None: + args["disable_talker"] = False + super().__init__(**args) + + +class 
TransformersManualVoiceQwen2_5OmniLLM(TransformersManualQwen2_5OmniLLM): + """ + speech to speech voice chat + + - speech -> text + speech + """ + + TAG = "llm_transformers_manual_qwen2_5omni_audio_voice" + + def __init__(self, **args) -> None: + args["disable_talker"] = False + super().__init__(**args) diff --git a/src/core/llm/transformers/models/qwen2_5_omni.py b/src/core/llm/transformers/models/qwen2_5_omni.py new file mode 100644 index 00000000..69e107bd --- /dev/null +++ b/src/core/llm/transformers/models/qwen2_5_omni.py @@ -0,0 +1,941 @@ +import logging +from threading import Thread +from time import perf_counter +from typing import Generator, Optional, Callable + +import torch + + +try: + from transformers import ( + Qwen2_5OmniForConditionalGeneration, + ) +except ModuleNotFoundError as e: + logging.error(f"Exception: {e}") + logging.error( + "In order to use Qwen2.5Omni, you need to `pip install git+https://github.com/huggingface/transformers`" + ) + raise Exception(f"Missing module: {e}") + + +from src.common.utils.helper import print_model_params +from src.core.llm.transformers.streamer import TokenStreamer + + +class Qwen2_5OmniForConditionalGenerationStreaming(Qwen2_5OmniForConditionalGeneration): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + print_model_params(self.thinker, "qwen2.5omni_thinker") + if self.has_talker: + print_model_params(self.talker, "qwen2.5omni_talker") + print_model_params(self.token2wav, "qwen2.5omni_token2wav") + self.token2wav.float() + + @torch.no_grad() + def generate( + self, + input_ids: Optional[torch.Tensor] = None, + speaker: str = "Chelsie", + use_audio_in_video: bool = False, + return_audio: Optional[bool] = None, + thinker_max_new_tokens: int = 1024, + talker_max_new_tokens: int = 4096, + talker_do_sample: bool = True, + talker_top_k: int = 40, + talker_top_p: float = 0.8, + talker_temperature: float = 0.9, + talker_eos_token_ids: list[int] = [8292, 8294], + talker_repetition_penalty: float = 1.05, + **kwargs, + ): + r""" + Generate a text response and audio from the input. + + Args: + input_ids (`Optional[torch.Tensor]`, *optional*): + Input ids, should be obtained from self._tokenizer. + speaker (`str`, defaults to "Chelsie"): + Which speaker should be used in the audio response. + use_audio_in_video (`bool`, defaults to False): + Whether or not to use the audio track in the video; should be the same as the parameter passed to `process_audio_info`. + return_audio (`Optional[bool]`, *optional*): + Whether or not to return the response in audio format. When `return_audio=None`, this defaults to `config.enable_audio_output`. + kwargs (*optional*): + - Without a prefix, they are passed as `**kwargs` to the `generate` method of each sub-module. + - With a *thinker_*, *talker_*, *token2wav_* prefix, they will be input for the `generate` method of the + thinker, talker and token2wav respectively. Prefixed keywords take priority over un-prefixed ones. + Returns: + When `return_audio=False`: + - **Text** (`torch.Tensor`): Generated text token sequence. + When `return_audio=True`: + - **Text** (`torch.Tensor`): Generated text token sequence. + - **Audio waveform** (`torch.Tensor`): Generated audio waveform. + """ + if speaker not in self.speaker_map: + raise ValueError( + f"{speaker} is not available, available speakers: {self.speaker_map.keys()}" + ) + if return_audio and not self.has_talker: + raise ValueError( + "Cannot use talker when talker module not initialized. Use the `enable_talker` method or set `enable_talker` in the config to enable it."
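+                # NOTE (kwargs routing, see the loop just below): keys prefixed with "thinker_",
+                # "talker_" or "token2wav_" are stripped of their prefix and forwarded to that
+                # sub-module's generate call; un-prefixed keys are shared across all three
+                # (a few audio/attention keys are special-cased)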
+ ) + if return_audio is None: + return_audio = self.has_talker + if input_ids.shape[0] != 1 and return_audio: + raise NotImplementedError( + "Qwen2.5-Omni currently does not support batched inference with audio output" + ) + + shared_kwargs = {"use_audio_in_video": use_audio_in_video} + thinker_kwargs = { + "max_new_tokens": thinker_max_new_tokens, + } + talker_kwargs = { + "max_new_tokens": talker_max_new_tokens, + "do_sample": talker_do_sample, + "top_k": talker_top_k, + "top_p": talker_top_p, + "temperature": talker_temperature, + "eos_token_id": talker_eos_token_ids, + "repetition_penalty": talker_repetition_penalty, + } + token2wav_kwargs = {} + + for key, value in kwargs.items(): + if key.startswith("thinker_"): + thinker_kwargs[key[len("thinker_") :]] = value + elif key.startswith("talker_"): + talker_kwargs[key[len("talker_") :]] = value + elif key.startswith("token2wav_"): + token2wav_kwargs[key[len("token2wav_") :]] = value + # Process special input values + elif key == "feature_attention_mask": + thinker_kwargs[key] = value + talker_kwargs["audio_feature_lengths"] = torch.sum(value, dim=1) + elif key == "input_features" or key == "attention_mask": + thinker_kwargs[key] = value + # Put other key to shared kwargs + else: + shared_kwargs[key] = value + + # Merge kwargs + for key, value in shared_kwargs.items(): + if key not in thinker_kwargs: + thinker_kwargs[key] = value + if key not in talker_kwargs: + talker_kwargs[key] = value + if key not in token2wav_kwargs: + token2wav_kwargs[key] = value + speaker_params = self.speaker_map[speaker] + + # 1. Generate from thinker module + generate_audio = return_audio and self.has_talker + if generate_audio: + thinker_kwargs["output_hidden_states"] = True + thinker_kwargs["return_dict_in_generate"] = True + + thinker_result = self.thinker.generate(input_ids=input_ids, **thinker_kwargs) + + if not generate_audio: + return thinker_result + + # 2. 
Generate speech tokens from talker module + embeds_to_talker = thinker_result.hidden_states[0][0].clone().to(self.talker.device) + if thinker_kwargs.get("input_features", None) is not None: + audio_ids_mask = input_ids == self.config.thinker_config.audio_token_index + audio_mask = ( + audio_ids_mask.unsqueeze(-1).expand_as(embeds_to_talker).to(embeds_to_talker.device) + ) + audio_mask_tensor = torch.zeros( + [audio_ids_mask.sum(), embeds_to_talker.shape[-1]], + dtype=embeds_to_talker.dtype, + device=self.talker.device, + ) + embeds_to_talker.masked_scatter_(audio_mask, audio_mask_tensor) + if thinker_kwargs.get("pixel_values", None) is not None: + image_ids_mask = input_ids == self.config.thinker_config.image_token_index + image_mask = ( + image_ids_mask.unsqueeze(-1).expand_as(embeds_to_talker).to(embeds_to_talker.device) + ) + image_mask_tensor = torch.zeros( + [image_ids_mask.sum(), embeds_to_talker.shape[-1]], + dtype=embeds_to_talker.dtype, + device=self.talker.device, + ) + embeds_to_talker.masked_scatter_(image_mask, image_mask_tensor) + if thinker_kwargs.get("pixel_values_videos", None) is not None: + video_ids_mask = input_ids == self.config.thinker_config.video_token_index + video_mask = ( + video_ids_mask.unsqueeze(-1).expand_as(embeds_to_talker).to(embeds_to_talker.device) + ) + video_mask_tensor = torch.zeros( + [video_ids_mask.sum(), embeds_to_talker.shape[-1]], + dtype=embeds_to_talker.dtype, + device=self.talker.device, + ) + embeds_to_talker.masked_scatter_(video_mask, video_mask_tensor) + + processed_thinker_hidden = ( + (embeds_to_talker,) + thinker_result.hidden_states[0][1:], + ) + thinker_result.hidden_states[1:] + + thinker_generate_ids = thinker_result.sequences[:, input_ids.size(1) :].to( + self.talker.device + ) + thinker_token_embeds = [ + token_hidden_states[0].to(self.talker.device) + for token_hidden_states in processed_thinker_hidden + ] + thinker_hidden_states = [ + token_hidden_states[-1].to(self.talker.device) + for token_hidden_states in processed_thinker_hidden + ] + + talker_text_bos_token = speaker_params["bos_token"] + talker_input_text_ids = torch.cat( + [ + input_ids.to(self.talker.device), + torch.tensor( + [[talker_text_bos_token]], dtype=torch.long, device=self.talker.device + ), + thinker_generate_ids[:, :1], + ], + dim=-1, + ) + + talker_input_ids = torch.cat( + [ + torch.full_like( + input_ids, fill_value=self.talker.codec_mask_token, device=self.talker.device + ), + torch.tensor( + [[self.talker.codec_pad_token]], dtype=torch.long, device=self.talker.device + ), + torch.tensor( + [[self.talker.codec_bos_token]], dtype=torch.long, device=self.talker.device + ), + ], + dim=1, + ) + + thinker_embed_tokens = self.thinker.get_input_embeddings() + thinker_reply_part = torch.cat(thinker_hidden_states[1:], dim=1) + torch.cat( + thinker_token_embeds[1:], dim=1 + ) + talker_inputs_embeds = thinker_hidden_states[0] + thinker_token_embeds[0] + talker_text_bos_token = torch.tensor( + [[talker_text_bos_token]], dtype=torch.long, device=self.thinker.device + ) + talker_text_bos_embed = thinker_embed_tokens(talker_text_bos_token).to(self.talker.device) + talker_inputs_embeds = torch.cat( + [ + talker_inputs_embeds, + talker_text_bos_embed, + thinker_reply_part[:, :1, :], + ], + dim=1, + ) + + eos_embedding = thinker_embed_tokens( + torch.tensor( + [[self.talker.text_eos_token]], dtype=torch.long, device=self.thinker.device + ) + ).to(self.talker.device) + + pad_embedding = thinker_embed_tokens( + torch.tensor( + [[self.talker.text_pad_token]], 
dtype=torch.long, device=self.thinker.device + ) + ).to(self.talker.device) + + thinker_reply_part = torch.cat( + [ + thinker_reply_part[:, 1:, :], + eos_embedding, + pad_embedding, + ], + dim=1, + ) + talker_attention_mask = None + if "attention_mask" in kwargs: + talker_attention_mask = torch.cat( + [kwargs["attention_mask"], kwargs["attention_mask"].new_ones((1, 2))], dim=1 + ).to(self.talker.device) + + # stream + skip_prompt = kwargs.get("skip_prompt", True) + streamer = TokenStreamer(skip_prompt=skip_prompt) + talker_kwargs = dict( + input_ids=talker_input_ids, + streamer=streamer, + input_text_ids=talker_input_text_ids, + thinker_reply_part=thinker_reply_part, + inputs_embeds=talker_inputs_embeds, + attention_mask=talker_attention_mask, + suppress_tokens=[self.talker.codec_bos_token], + **{ + k: (v.to(self.talker.device) if torch.is_tensor(v) else v) + for k, v in talker_kwargs.items() + }, + ) + # logging.debug(talker_kwargs.keys()) + thread = Thread(target=self.talker.generate, kwargs=talker_kwargs) + thread.start() + talker_generate_codes = [] + times = [] + start_time = perf_counter() + for token_id in streamer: + times.append(perf_counter() - start_time) + start_time = perf_counter() + talker_generate_codes.append(token_id) + logging.info( + f"generate first token cost time: {times[0]} s, {len(times)} tokens cost time: {sum(times)} s" + ) + offset = 0 + if skip_prompt is False: + offset = talker_input_ids.shape[1] + # logging.debug( + # talker_input_ids.shape[1], + # # talker_generate_codes, + # talker_generate_codes[:offset], + # talker_generate_codes[offset:-1], + # ) + talker_generate_codes = torch.tensor( + [talker_generate_codes[offset:-1]], + dtype=torch.long, + device=self.talker.device, + ) + + # logging.debug(f"talker_generate_codes:{talker_generate_codes.shape} {talker_generate_codes}") + + # 3. 
Generate wavs from code + # logging.debug(self.token2wav.device, speaker_params, token2wav_kwargs) + wav = self.token2wav( + talker_generate_codes.to(self.token2wav.device), + conditioning=speaker_params["cond"].to(self.token2wav.device).float(), + reference_mel=speaker_params["ref_mel"].to(self.token2wav.device).float(), + **token2wav_kwargs, + ) + + return thinker_result.sequences, wav.float() + + @torch.no_grad() + def thinker_generate_chunk( + self, + inputs: dict, + use_audio_in_video: bool = False, + thinker_max_tokens_per_step=10, # Controls how many tokens to generate *per step* + thinker_max_new_tokens: int = 1024, + thinker_top_k: int = 40, + thinker_top_p: float = 0.8, + thinker_temperature: float = 0.9, + thinker_eos_token_ids=[151644, 151645], # Define EOS tokens + thinker_repetition_penalty: float = 1.05, + thinker_output_hidden_states=False, + thinker_stop_strings_per_step=[], + tokenizer=None, + **kwargs, + ): + input_ids = inputs.pop("input_ids") + attention_mask = inputs.pop("attention_mask", None) + + if thinker_max_tokens_per_step > thinker_max_new_tokens: + thinker_max_tokens_per_step = thinker_max_new_tokens + + # Keep track of the full generated sequence full_generated_ids = input_ids.clone() + # Ensure full_attention_mask is correctly initialized and expanded + full_attention_mask = ( + attention_mask.clone() + if attention_mask is not None + else torch.ones_like(input_ids, device=input_ids.device) + ) + full_generated_ids = input_ids.clone() + + # KV cache + # past_key_values = None + + # Inputs for the current step + current_input_ids = full_generated_ids + # The attention mask passed to generate should cover the sequence length for the current step + current_attention_mask = full_attention_mask + + total_new_tokens_generated = 0 + hidden_states = None + hidden_states_len = 0 + + times = [] + while total_new_tokens_generated < thinker_max_new_tokens: + # Prepare inputs for generate call + # logging.debug(current_input_ids, current_attention_mask.shape) + model_inputs = { + "input_ids": current_input_ids, + "attention_mask": current_attention_mask, + # "past_key_values": past_key_values, + "use_cache": True, + "use_audio_in_video": use_audio_in_video, + "do_sample": True if thinker_temperature > 0 else False, + "top_k": thinker_top_k, + "top_p": thinker_top_p, + "temperature": thinker_temperature, + "repetition_penalty": thinker_repetition_penalty, + "min_new_tokens": 1, # Ensure at least one token is generated if possible + "max_new_tokens": thinker_max_tokens_per_step, # Generate in smaller steps + # output_hidden_states/scores can consume memory, + # enable if needed downstream(talker) + "output_hidden_states": thinker_output_hidden_states, + "return_dict_in_generate": True, + # "output_scores": True, + "eos_token_id": thinker_eos_token_ids, + "pad_token_id": kwargs.get("thinker_pad_token_id", 151643), + } + model_inputs = {**inputs, **model_inputs} + if len(thinker_stop_strings_per_step) > 0: + model_inputs["stop_strings"] = thinker_stop_strings_per_step + model_inputs["tokenizer"] = tokenizer + + start_time = perf_counter() + outputs = self.thinker.generate(**model_inputs) + times.append(perf_counter() - start_time) + + # Extract newly generated token IDs *for this step* + # `outputs.sequences` contains the input_ids for this step + new tokens generated in this step + step_new_ids = outputs.sequences[:, current_input_ids.shape[1] :] + num_step_new_tokens = step_new_ids.shape[1] + + if num_step_new_tokens == 0: # Handle case where generate stops early + 
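+                # defensive guard: if generate() somehow returns no new tokens (the sequence is
+                # already at a stopping condition), break out instead of looping forever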
logging.warning("Warning: generate produced 0 new tokens in this step.") + break + + if thinker_output_hidden_states is True: + hidden_states = outputs.hidden_states + hidden_states_len = ( + hidden_states_len if hidden_states_len > 0 else hidden_states[0][0].shape[1] + ) + logging.debug(f"hidden_states_len: {hidden_states_len}") + # new generate thinker_token_embeds + thinker_new_token_embeds = hidden_states[0][0][:, :hidden_states_len, :] + hidden_states = ( + (thinker_new_token_embeds,) + hidden_states[0][1:], + ) + hidden_states[1:] + # new generate thinker_hidden_states + thinker_new_hidden_states = hidden_states[0][-1][:, :hidden_states_len, :] + hidden_states = ( + hidden_states[0][:-1] + (thinker_new_hidden_states,), + ) + hidden_states[1:] + + yield { + "thinker_generate_ids": step_new_ids, + "thinker_generate_hidden_states": hidden_states, + } + total_new_tokens_generated += num_step_new_tokens + + # Update the full sequence + full_generated_ids = torch.cat([full_generated_ids, step_new_ids], dim=1) + + # Prepare for the next iteration: + # Input is only the last generated token + # NOTE: need use past_key_values to keep the context by manually, + # current_input_ids = step_new_ids[:, -1:] + # so we can't use the last generated token, use cache instead + # input ids need to be the full sequence for next generation + current_input_ids = full_generated_ids + + # Update past_key_values + # past_key_values = outputs.past_key_values + + # Update attention mask by appending 1s for the new tokens + full_attention_mask = torch.cat( + [full_attention_mask, torch.ones_like(step_new_ids)], dim=1 + ) + current_attention_mask = full_attention_mask + + # Check if EOS token was generated in this step + if step_new_ids[0, -1].item() in thinker_eos_token_ids: + logging.info("EOS token generated.") + break + + # Check if max_new_tokens limit is reached (after processing the step) + if total_new_tokens_generated >= thinker_max_new_tokens: + logging.info("Max new tokens limit reached.") + break + + logging.info( + f"Total new tokens generated: {total_new_tokens_generated} | thinker_max_tokens_per_step: {thinker_max_tokens_per_step} | first chunk generated cost: {times[0]} s | total cost: {sum(times)} s" + ) + + @torch.no_grad() + def talker_generate_chunk( + self, + inputs: dict, + thinker_chunk_stream, + speaker: str = "Chelsie", + talker_eos_token_ids: list[int] = [8292, 8294], + talker_top_k: int = 10, + talker_top_p: float = 0.9, + talker_temperature: float = 0.95, + talker_repetition_penalty: float = 1.1, + talker_min_new_tokens: int = 0, + talker_max_new_tokens: int = 8192, + talker_skip_thinker_token_ids: list[int] = [], # skip tokens don't to talk + code2wav_num_steps: int = 10, + code2wav_guidance_scale: float = 0.5, + code2wav_sway_coefficient: float = -1.0, + code2wav_chunk_stream_func: Callable = None, + mask_embedding: bool = True, + ) -> Generator[dict, None, None]: + input_ids = inputs.get("input_ids") + attention_mask = inputs.get("attention_mask", None) + + for chunk in thinker_chunk_stream: + thinker_generate_ids = chunk["thinker_generate_ids"].to(self.talker.device) + # skip talk + if ( + thinker_generate_ids.shape[1] == 1 + and thinker_generate_ids[0, -1].item() in talker_skip_thinker_token_ids + ): + logging.info(f"skip token {thinker_generate_ids} to talk") + yield {"thinker_ids": thinker_generate_ids, "talker_wav": torch.empty([0])} + continue + thinker_generate_hidden_states = chunk["thinker_generate_hidden_states"] + if thinker_generate_hidden_states is None or 
len(thinker_generate_hidden_states) < 2: + if len(thinker_generate_hidden_states) < 2: + logging.warning( + f"thinker_generate_ids: {thinker_generate_ids} | len(thinker_generate_hidden_states): {len(thinker_generate_hidden_states)} < 2" + ) + yield {"thinker_ids": thinker_generate_ids, "talker_wav": torch.empty([0])} + continue + + processed_thinker_hidden = thinker_generate_hidden_states + if mask_embedding is True: + logging.info("mask embedding") + embeds_to_talker = ( + thinker_generate_hidden_states[0][0].clone().to(self.talker.device) + ) + if inputs.get("input_features", None) is not None: + audio_ids_mask = input_ids == self.config.thinker_config.audio_token_index + audio_mask = ( + audio_ids_mask.unsqueeze(-1) + .expand_as(embeds_to_talker) + .to(embeds_to_talker.device) + ) + audio_mask_tensor = torch.zeros( + [audio_ids_mask.sum(), embeds_to_talker.shape[-1]], + dtype=embeds_to_talker.dtype, + device=self.talker.device, + ) + embeds_to_talker.masked_scatter_(audio_mask, audio_mask_tensor) + if inputs.get("pixel_values", None) is not None: + image_ids_mask = input_ids == self.config.thinker_config.image_token_index + image_mask = ( + image_ids_mask.unsqueeze(-1) + .expand_as(embeds_to_talker) + .to(embeds_to_talker.device) + ) + image_mask_tensor = torch.zeros( + [image_ids_mask.sum(), embeds_to_talker.shape[-1]], + dtype=embeds_to_talker.dtype, + device=self.talker.device, + ) + embeds_to_talker.masked_scatter_(image_mask, image_mask_tensor) + if inputs.get("pixel_values_videos", None) is not None: + video_ids_mask = input_ids == self.config.thinker_config.video_token_index + video_mask = ( + video_ids_mask.unsqueeze(-1) + .expand_as(embeds_to_talker) + .to(embeds_to_talker.device) + ) + video_mask_tensor = torch.zeros( + [video_ids_mask.sum(), embeds_to_talker.shape[-1]], + dtype=embeds_to_talker.dtype, + device=self.talker.device, + ) + embeds_to_talker.masked_scatter_(video_mask, video_mask_tensor) + + processed_thinker_hidden = ( + (embeds_to_talker,) + thinker_generate_hidden_states[0][1:], + ) + thinker_generate_hidden_states[1:] + + thinker_token_embeds = [ + token_hidden_states[0].to(self.talker.device) + for token_hidden_states in processed_thinker_hidden + ] + thinker_hidden_states = [ + token_hidden_states[-1].to(self.talker.device) + for token_hidden_states in processed_thinker_hidden + ] + logging.debug( + f"len(thinker_generate_hidden_states):{len(thinker_generate_hidden_states)}" + ) + for i in range(len(thinker_generate_hidden_states)): + logging.debug( + f"thinker_generate_hidden_states[{i}]:{thinker_generate_hidden_states[i][0].shape}, {thinker_generate_hidden_states[i][-1].shape}" + ) + + talker_text_bos_token = self.speaker_map[speaker]["bos_token"] + talker_input_text_ids = torch.cat( + [ + input_ids.to(self.talker.device), + torch.tensor( + [[talker_text_bos_token]], dtype=torch.long, device=self.talker.device + ), + thinker_generate_ids[:, :1], + ], + dim=-1, + ) + logging.debug(f"talker_input_text_ids.shape:{talker_input_text_ids.shape}") + + talker_input_ids = torch.cat( + [ + torch.full_like( + input_ids, + fill_value=self.talker.codec_mask_token, + device=self.talker.device, + ), + torch.tensor( + [[self.talker.codec_pad_token]], + dtype=torch.long, + device=self.talker.device, + ), + torch.tensor( + [[self.talker.codec_bos_token]], + dtype=torch.long, + device=self.talker.device, + ), + ], + dim=1, + ) + logging.debug(f"talker_input_ids.shape:{talker_input_ids.shape}") + + thinker_embed_tokens = self.thinker.get_input_embeddings() + thinker_reply_part 
= torch.cat(thinker_hidden_states[1:], dim=1) + torch.cat( + thinker_token_embeds[1:], dim=1 + ) + talker_inputs_embeds = thinker_hidden_states[0] + thinker_token_embeds[0] + talker_text_bos_token = torch.tensor( + [[talker_text_bos_token]], dtype=torch.long, device=self.thinker.device + ) + talker_text_bos_embed = thinker_embed_tokens(talker_text_bos_token).to( + self.talker.device + ) + logging.debug( + f"talker_inputs_embeds.shape {talker_inputs_embeds.shape} talker_text_bos_embed.shape {talker_text_bos_embed.shape} thinker_reply_part.shape {thinker_reply_part.shape}" + ) + talker_inputs_embeds = torch.cat( + [ + talker_inputs_embeds, + talker_text_bos_embed, + thinker_reply_part[:, :1, :], + ], + dim=1, + ) + logging.debug( + f"talker_inputs_embeds.shape {talker_inputs_embeds.shape} talker_text_bos_embed.shape {talker_text_bos_embed.shape}" + ) + + eos_embedding = thinker_embed_tokens( + torch.tensor( + [[self.talker.text_eos_token]], dtype=torch.long, device=self.thinker.device + ) + ).to(self.talker.device) + + pad_embedding = thinker_embed_tokens( + torch.tensor( + [[self.talker.text_pad_token]], dtype=torch.long, device=self.thinker.device + ) + ).to(self.talker.device) + thinker_reply_part = torch.cat( + [ + thinker_reply_part[:, 1:, :], + eos_embedding, + pad_embedding, + ], + dim=1, + ) + logging.debug(f"thinker_reply_part.shape:{thinker_reply_part.shape}") + + talker_attention_mask = None + if attention_mask is not None: + talker_attention_mask = torch.cat( + [attention_mask, attention_mask.new_ones((1, 2))], dim=1 + ).to(self.talker.device) + + streamer = TokenStreamer(skip_prompt=True) + talker_kwargs = dict( + input_ids=talker_input_ids, + streamer=streamer, + input_text_ids=talker_input_text_ids, + thinker_reply_part=thinker_reply_part, + inputs_embeds=talker_inputs_embeds, + attention_mask=talker_attention_mask, + suppress_tokens=[self.talker.codec_bos_token], + eos_token_id=talker_eos_token_ids, + pad_token_id=8292, + do_sample=True if talker_temperature > 0.0 else False, + top_k=talker_top_k, + top_p=talker_top_p, + temperature=talker_temperature, + repetition_penalty=talker_repetition_penalty, + min_new_tokens=talker_min_new_tokens, + max_new_tokens=talker_max_new_tokens, + ) + # logging.debug(talker_kwargs.keys()) + thread = Thread(target=self.talker.generate, kwargs=talker_kwargs) + thread.start() + + code2wav_chunk_stream_func = code2wav_chunk_stream_func or self.code2wav_chunk_stream + # Generate wavs from code + code2wav_stream = code2wav_chunk_stream_func( + talker_streamer=streamer, + speaker=speaker, + talker_eos_token_ids=talker_eos_token_ids, + code2wav_num_steps=code2wav_num_steps, + code2wav_guidance_scale=code2wav_guidance_scale, + code2wav_sway_coefficient=code2wav_sway_coefficient, + ) + + for wav in code2wav_stream: + yield {"thinker_ids": thinker_generate_ids, "talker_wav": wav} + + @torch.no_grad() + def code2wav_chunk_stream( + self, + talker_streamer: TokenStreamer, + speaker: str = "Chelsie", + talker_eos_token_ids: list[int] = [8292, 8294], + code2wav_num_steps: int = 10, + code2wav_guidance_scale: float = 0.5, + code2wav_sway_coefficient: float = -1.0, + ) -> Generator[torch.Tensor, None, None]: + """ + fixed chunk stream + """ + if self.token2wav.dtype != torch.float: + self.token2wav.float() + + code2wav_times = [] + talker_generate_codes = [] + times = [] + start_time = perf_counter() + pre_offset = 0 + for token_id in talker_streamer: + times.append(perf_counter() - start_time) + start_time = perf_counter() + if token_id in 
talker_eos_token_ids: + break + talker_generate_codes.append(token_id) + chunk_code_length = len(talker_generate_codes) * 2 - 24 + if chunk_code_length > 0 and chunk_code_length % 48 == 0: + codes_tensor = torch.tensor( + [talker_generate_codes[pre_offset:]], + dtype=torch.long, + device=self.talker.device, + ) + pre_offset = len(talker_generate_codes) + wav = self.token2wav( + codes_tensor.to(self.token2wav.device), + conditioning=self.speaker_map[speaker]["cond"] + .to(self.token2wav.device) + .float(), + reference_mel=self.speaker_map[speaker]["ref_mel"] + .to(self.token2wav.device) + .float(), + num_steps=10, + guidance_scale=0.5, + sway_coefficient=-1.0, + ).detach() + code2wav_times.append(perf_counter() - start_time) + yield wav # (T,) + start_time = perf_counter() + + logging.info( + f"talker generate first token cost time: {times[0]} s, {len(times)} tokens cost time: {sum(times)} s" + ) + + if len(talker_generate_codes) > pre_offset: + codes_tensor = torch.tensor( + [talker_generate_codes[pre_offset:]], + dtype=torch.long, + device=self.talker.device, + ) + wav = self.token2wav( + codes_tensor.to(self.token2wav.device), + conditioning=self.speaker_map[speaker]["cond"].to(self.token2wav.device).float(), + reference_mel=self.speaker_map[speaker]["ref_mel"] + .to(self.token2wav.device) + .float(), + num_steps=code2wav_num_steps, + guidance_scale=code2wav_guidance_scale, + sway_coefficient=code2wav_sway_coefficient, + ).detach() + code2wav_times.append(perf_counter() - start_time) + yield wav # (T,) + + logging.info( + f"code2wav streaming first chunk time: {code2wav_times[0]} s | cost: {sum(code2wav_times)} s" + ) + + @torch.no_grad() + def generate_stream( + self, + inputs: dict, + use_audio_in_video: bool = False, + thinker_max_tokens_per_step=10, # Controls how many tokens to generate *per step* + thinker_max_new_tokens: int = 1024, + thinker_top_k: int = 40, + thinker_top_p: float = 0.8, + thinker_temperature: float = 0.9, + thinker_repetition_penalty: float = 1.05, + thinker_eos_token_ids=[151644, 151645], + thinker_stop_strings_per_step=[], + tokenizer=None, + return_audio=True, + speaker="Chelsie", + talker_top_k: int = 10, + talker_top_p: float = 0.9, + talker_temperature: float = 0.95, + talker_repetition_penalty: float = 1.1, + talker_min_new_tokens: int = 0, + talker_max_new_tokens: int = 8192, + talker_eos_token_ids: list[int] = [8292, 8294], + talker_skip_thinker_token_ids: list[int] = [], + code2wav_num_steps: int = 10, + code2wav_guidance_scale: float = 0.5, + code2wav_sway_coefficient: float = -1.0, + code2wav_chunk_stream_func: Callable = None, + mask_embedding: bool = True, + **kwargs, + ) -> Generator[dict, None, None]: + """ + - return Generator[dict, None, None] + { + "thinker_ids": torch.Tensor, # (1,T) + "talker_wav": torch.Tensor, # (T,) + } + """ + thinker_chunk_stream = self.thinker_generate_chunk( + inputs, + use_audio_in_video=use_audio_in_video, + thinker_max_tokens_per_step=thinker_max_tokens_per_step, + thinker_max_new_tokens=thinker_max_new_tokens, + thinker_top_k=thinker_top_k, + thinker_top_p=thinker_top_p, + thinker_temperature=thinker_temperature, + thinker_eos_token_ids=thinker_eos_token_ids, + thinker_repetition_penalty=thinker_repetition_penalty, + thinker_output_hidden_states=return_audio, + thinker_stop_strings_per_step=thinker_stop_strings_per_step, + tokenizer=tokenizer, + **kwargs, + ) + if not return_audio: + for thinker_chunk in thinker_chunk_stream: + yield {"thinker_ids": thinker_chunk["thinker_generate_ids"]} + else: + talker_streamer 
= self.talker_generate_chunk( + inputs, + thinker_chunk_stream=thinker_chunk_stream, + speaker=speaker, + talker_eos_token_ids=talker_eos_token_ids, + talker_top_k=talker_top_k, + talker_top_p=talker_top_p, + talker_temperature=talker_temperature, + talker_repetition_penalty=talker_repetition_penalty, + talker_min_new_tokens=talker_min_new_tokens, + talker_max_new_tokens=talker_max_new_tokens, + talker_skip_thinker_token_ids=talker_skip_thinker_token_ids, + code2wav_num_steps=code2wav_num_steps, + code2wav_guidance_scale=code2wav_guidance_scale, + code2wav_sway_coefficient=code2wav_sway_coefficient, + code2wav_chunk_stream_func=code2wav_chunk_stream_func, + mask_embedding=mask_embedding, + ) + + for talker_chunk in talker_streamer: + yield talker_chunk + + @torch.no_grad() + def thinker_all_talker_stream( + self, + inputs: dict, + use_audio_in_video: bool = False, + thinker_max_new_tokens: int = 1024, + thinker_top_k: int = 40, + thinker_top_p: float = 0.8, + thinker_temperature: float = 0.9, + thinker_repetition_penalty: float = 1.05, + thinker_eos_token_ids=[151644, 151645], + thinker_stop_strings_per_step=[], + tokenizer=None, + return_audio=True, + speaker="Chelsie", + talker_top_k: int = 10, + talker_top_p: float = 0.9, + talker_temperature: float = 0.95, + talker_repetition_penalty: float = 1.1, + talker_min_new_tokens: int = 0, + talker_max_new_tokens: int = 8192, + talker_eos_token_ids: list[int] = [8292, 8294], + talker_skip_thinker_token_ids: list[int] = [], + code2wav_num_steps: int = 10, + code2wav_guidance_scale: float = 0.5, + code2wav_sway_coefficient: float = -1.0, + code2wav_chunk_stream_func: Callable = None, + mask_embedding: bool = True, + **kwargs, + ) -> Generator[dict, None, None]: + """ + - return Generator[dict, None, None] + { + "thinker_ids": torch.Tensor, # (1,T) + "talker_wav": torch.Tensor, # (T,) + } + """ + + def to_generator(item): + yield item + + thinker_result = self.thinker.generate( + **inputs, + use_audio_in_video=use_audio_in_video, + do_sample=True if thinker_temperature > 0 else False, + top_k=thinker_top_k, + top_p=thinker_top_p, + temperature=thinker_temperature, + repetition_penalty=thinker_repetition_penalty, + min_new_tokens=1, + max_new_tokens=thinker_max_new_tokens, + eos_token_id=thinker_eos_token_ids, + output_hidden_states=return_audio, + return_dict_in_generate=True, + ) + input_ids = inputs["input_ids"] + thinker_generate_ids = thinker_result.sequences[:, input_ids.size(1) :] + if not return_audio: + yield {"thinker_ids": thinker_generate_ids} + else: + talker_streamer = self.talker_generate_chunk( + inputs, + thinker_chunk_stream=to_generator( + { + "thinker_generate_ids": thinker_generate_ids, + "thinker_generate_hidden_states": thinker_result.hidden_states, + } + ), + speaker=speaker, + talker_eos_token_ids=talker_eos_token_ids, + talker_top_k=talker_top_k, + talker_top_p=talker_top_p, + talker_temperature=talker_temperature, + talker_repetition_penalty=talker_repetition_penalty, + talker_min_new_tokens=talker_min_new_tokens, + talker_max_new_tokens=talker_max_new_tokens, + talker_skip_thinker_token_ids=talker_skip_thinker_token_ids, + code2wav_num_steps=code2wav_num_steps, + code2wav_guidance_scale=code2wav_guidance_scale, + code2wav_sway_coefficient=code2wav_sway_coefficient, + code2wav_chunk_stream_func=code2wav_chunk_stream_func, + mask_embedding=mask_embedding, + ) + + for talker_chunk in talker_streamer: + yield talker_chunk diff --git a/src/modules/speech/asr/__init__.py b/src/modules/speech/asr/__init__.py index 
ffafa5c2..f6610896 100644 --- a/src/modules/speech/asr/__init__.py +++ b/src/modules/speech/asr/__init__.py @@ -14,6 +14,8 @@ class ASREnvInit: @staticmethod def getEngine(tag, **kwargs) -> interface.IAsr | EngineClass: + if "qwen2_5omni_asr" in tag: + from . import qwen2_5omni_asr if "minicpmo_asr" in tag: from . import minicpmo_asr if "sense_voice" in tag: @@ -57,6 +59,12 @@ def get_asr_minicpmo_args() -> dict: kwargs["use_gptq_ckpt"] = bool(os.getenv("USE_GPTQ_CKPT", "")) return kwargs + @staticmethod + def get_asr_qwen2_5omni_args() -> dict: + kwargs = LLMEnvInit.get_qwen2_5omni_transformers_args() + return kwargs + map_config_func = { "minicpmo_asr": get_asr_minicpmo_args, + "qwen2_5omni_asr": get_asr_qwen2_5omni_args, } diff --git a/src/modules/speech/asr/qwen2_5omni_asr.py b/src/modules/speech/asr/qwen2_5omni_asr.py new file mode 100644 index 00000000..738eed29 --- /dev/null +++ b/src/modules/speech/asr/qwen2_5omni_asr.py @@ -0,0 +1,56 @@ +import os +from typing import AsyncGenerator +import asyncio +import re + +import librosa + +from src.core.llm.transformers.manual_vision_voice_qwen import TransformersManualAudioQwen2_5OmniLLM +from src.common.utils.audio_utils import bytes2NpArrayWith16 +from src.common.session import Session +from src.modules.speech.asr.base import ASRBase + + +class Qwen2_5OmniAsr(ASRBase): + TAG = "qwen2_5omni_asr" + + @classmethod + def get_args(cls, **kwargs) -> dict: + return kwargs + + def __init__(self, **args) -> None: + args["init_chat_prompt"] = "You are a speech recognition model" + self.model = TransformersManualAudioQwen2_5OmniLLM(**args) + self.args = args + + def set_audio_data(self, audio_data): + if isinstance(audio_data, (bytes, bytearray)): + self.asr_audio = bytes2NpArrayWith16(audio_data) + if isinstance(audio_data, str): # path + audio_nparr, _ = librosa.load(audio_data, sr=16000, mono=True) + self.asr_audio = audio_nparr + + async def transcribe_stream(self, session: Session) -> AsyncGenerator[str, None]: + prompt = [ + {"type": "text", "text": "请将这段中文语音转换为纯文本"}, + {"type": "audio", "audio": self.asr_audio}, + ] + session.ctx.state["prompt"] = session.ctx.state.get("prompt", prompt) + transcription = self.model.generate(session) + for item in transcription: + if "text" in item: + clean_text = re.sub(r"<\|.*?\|>", "", item["text"]) + yield clean_text + + async def transcribe(self, session: Session) -> dict: + res = "" + async for text in self.transcribe_stream(session): + res += text + + res = { + "language": self.args.get("language", "zh"), + "language_probability": None, + "text": res, + "words": [], + } + return res diff --git a/src/processors/audio_camera_output_processor.py b/src/processors/audio_camera_output_processor.py index 72a58adc..c7613582 100644 --- a/src/processors/audio_camera_output_processor.py +++ b/src/processors/audio_camera_output_processor.py @@ -51,6 +51,10 @@ def __init__( # away and we would lose the TTSStoppedFrame. 
self._bot_speaking = False + # Audio queue and task + self._audio_out_queue = None + self._audio_out_task = None + async def start(self, frame: StartFrame): await super().start(frame) # Create media threads queues and task @@ -59,6 +63,9 @@ async def start(self, frame: StartFrame): self._camera_out_task = self.get_event_loop().create_task( self._camera_out_task_handler() ) + if self._params.audio_out_enabled: + self._audio_out_queue = asyncio.Queue() + self._audio_out_task = self.get_event_loop().create_task(self._audio_out_task_handler()) async def stop(self, frame: EndFrame): await super().stop(frame) @@ -66,12 +73,18 @@ async def stop(self, frame: EndFrame): if self._params.camera_out_enabled: self._camera_out_task.cancel() await self._camera_out_task + if self._params.audio_out_enabled: + self._audio_out_task.cancel() + await self._audio_out_task async def cancel(self, frame: CancelFrame): await super().cancel(frame) if self._params.camera_out_enabled: self._camera_out_task.cancel() await self._camera_out_task + if self._params.audio_out_enabled: + self._audio_out_task.cancel() + await self._audio_out_task async def process_frame(self, frame: Frame, direction: FrameDirection): await super().process_frame(frame, direction) @@ -149,15 +162,26 @@ async def _handle_audio(self, frame: AudioRawFrame): chunk = audio[i : i + self._audio_chunk_size] # if len(chunk) % 2 != 0: don't do that, need subclass to do # chunk = chunk[:len(chunk) - 1] - await self.write_raw_audio_frames(chunk) + await self._audio_out_queue.put(chunk) await self.push_frame(BotSpeakingFrame(), FrameDirection.UPSTREAM) # self._audio_out_buff.clear() async def write_raw_audio_frames(self, frames: bytes): """ - subcalss audio output stream transport to impl + Subclass audio output stream transport to implement. """ - pass + logging.info(f"no subclass implement, dump len(frames): {len(frames)}") + + async def _audio_out_task_handler(self): + while True: + try: + chunk = await self._audio_out_queue.get() + await self.write_raw_audio_frames(chunk) + self._audio_out_queue.task_done() + except asyncio.CancelledError: + break + except Exception as e: + logging.exception(f"{self} error writing audio: {e}") # # Camera out @@ -183,9 +207,9 @@ async def _set_camera_images(self, images: List[ImageRawFrame]): async def write_frame_to_camera(self, frame: ImageRawFrame): """ - subcalss camera output stream transport to impl + Subclass camera output stream transport to implement. """ - pass + logging.info(f"no subclass implement, dump frame: {frame}") async def _draw_image(self, frame: ImageRawFrame): desired_size = (self._params.camera_out_width, self._params.camera_out_height) diff --git a/src/processors/omni/base.py b/src/processors/omni/base.py index dc4c4c59..998aab77 100644 --- a/src/processors/omni/base.py +++ b/src/processors/omni/base.py @@ -1,5 +1,8 @@ +import asyncio import logging from abc import abstractmethod +import queue +import threading from typing import AsyncGenerator import uuid @@ -11,14 +14,13 @@ from src.common.interface import ILlm from src.common.session import Session from src.common.types import CHANNELS, RATE, SessionCtx -from src.types.frames.data_frames import Frame, VisionImageVoiceRawFrame +from src.types.frames.data_frames import Frame, VisionImageVoiceRawFrame, PathAudioRawFrame from src.processors.ai_processor import AsyncAIProcessor class VisionVoiceProcessorBase(AsyncAIProcessor): """ VisionVoiceProcessorBase is a base class for vision+voice processors. 
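+    Generation runs on a background thread: send_input() enqueues a Session, _generate()
+    drives llm.generate() and pushes each item onto an internal queue, and gen() drains that
+    queue into TextFrame / AudioRawFrame output, sleeping no_stream_sleep_time when it is empty.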
- j input: vision + voice frame use omni lm to process vision + voice frames output: text+audio frame @@ -28,25 +30,47 @@ def __init__( self, llm: ILlm | EngineClass | None = None, session: Session | None = None, + no_stream_sleep_time: float = 0.5, **kwargs, ): super().__init__(**kwargs) + assert llm is not None, "llm is required" self._llm = llm self._session = session if self._session is None: self._session = Session(**SessionCtx(uuid.uuid4()).__dict__) + self._queue = queue.Queue() + self._input_queue = queue.Queue() + self._generate_thread = None + self._sleep_time = no_stream_sleep_time @property def stream_info(self) -> dict: """Return dict out stream info""" return {"sample_rate": RATE, "channels": CHANNELS} + def _generate(self): + while True: + try: + session = self._input_queue.get() + if session is None: + self._queue.put(None) # Signal the end of the stream + break # Signal to stop the thread + tensor_audio_stream = self._llm.generate(session) + for item in tensor_audio_stream: + self._queue.put(item) + self._queue.put(None) # Signal the end of the stream + except Exception as e: + logging.error(f"Exception generate: {e}", exc_info=True) + self._queue.put(None) # Signal the end of the stream + break + @abstractmethod async def run(self, frame: Frame) -> AsyncGenerator[Frame, None]: yield frame async def say(self, text: str): - pass + logging.info(f"say: {text}") async def process_frame(self, frame: Frame, direction: FrameDirection): await super().process_frame(frame, direction) @@ -61,41 +85,76 @@ async def process_frame(self, frame: Frame, direction: FrameDirection): async def start(self, frame: StartFrame): await super().start(frame) self._create_push_task() + self._generate_thread = threading.Thread(target=self._generate) + self._generate_thread.start() logging.info("start done") async def stop(self, frame: EndFrame): await super().stop(frame) + self._input_queue.put(None) # Signal the thread to stop + self._generate_thread.join() # Wait for the thread to finish logging.info("stop done") async def cancel(self, frame: CancelFrame): await super().cancel(frame) + self._input_queue.put(None) # Signal the thread to stop + self._generate_thread.join() # Wait for the thread to finish logging.info("cancel done") + def send_input(self, session: Session): + self._input_queue.put(session) + async def gen(self) -> AsyncGenerator[Frame, None]: - """ - - gen tensor audio streamer - = push text, audio frame - """ - tensor_audio_stream = self._llm.generate(self._session) - for item in tensor_audio_stream: - logging.debug(f"generate data: {item}") - tensor_audio = item.pop("audio_wav", None) - rate = item.pop("sampling_rate", RATE) - text = item.pop("text", "").strip() - if text != "": - await self.push_frame(TextFrame(text=text)) - - if tensor_audio is not None: # don't use if tensor_audio to check - audio_bytes = ( - (tensor_audio.float().detach().cpu().numpy() * 32768).astype(np.int16).tobytes() - ) - logging.debug( - f"audio tensor:{tensor_audio.shape},push audio len:{len(audio_bytes)}" - ) - await self.queue_frame( - AudioRawFrame( - audio=audio_bytes, - sample_rate=rate, + while True: + try: + item = self._queue.get_nowait() + if item is None: + break # End of the stream + logging.debug(f"generate data: {item}") + tensor_audio = item.pop("audio_wav", None) + text = item.pop("text", "").strip() + if text != "": + await self.push_frame(TextFrame(text=text)) + + if tensor_audio is not None: # don't use if tensor_audio to check + audio_bytes = ( + 
(tensor_audio.float().detach().cpu().numpy() * 32768) + .astype(np.int16) + .tobytes() ) - ) - yield None + logging.info( + f"audio tensor:{tensor_audio.shape},push audio len:{len(audio_bytes)}" + ) + await self.push_frame( + AudioRawFrame( + audio=audio_bytes, + sample_rate=self.stream_info["sample_rate"], + num_channels=self.stream_info["channels"], + ) + ) + yield None + except queue.Empty: + # yield asysncio.sleep to allow other tasks to run, e.g.: sink task (write audio) + await asyncio.sleep(self._sleep_time) + continue + + +class MockVisionVoiceProcessor(VisionVoiceProcessorBase): + async def run(self, frame: VisionImageVoiceRawFrame) -> AsyncGenerator[Frame, None]: + logging.debug(f"VisionImageVoiceRawFrame: {frame}") + self._session.ctx.state["prompt"] = [] + + # frame.text and self._session.ctx.state["prompt"].append(frame.text) + + if frame.text: + yield TextFrame(text=f"{frame.text}") + + if frame.images: + for image_frame in frame.images: + yield TextFrame(text=f"{image_frame}") + + if frame.audio and frame.audio.audio: + yield AudioRawFrame( + audio=frame.audio.audio, + sample_rate=frame.audio.sample_rate, + ) diff --git a/src/processors/omni/minicpmo_vision_voice.py b/src/processors/omni/minicpmo_vision_voice.py index ef0397ad..ef1b411b 100644 --- a/src/processors/omni/minicpmo_vision_voice.py +++ b/src/processors/omni/minicpmo_vision_voice.py @@ -22,10 +22,14 @@ def __init__( self, *, session: Session | None = None, + no_stream_sleep_time: float = 0.5, **kwargs, ): super().__init__( - llm=TransformersManualVisionVoiceMiniCPMO(**kwargs), session=session, **kwargs + llm=TransformersManualVisionVoiceMiniCPMO(**kwargs), + session=session, + no_stream_sleep_time=no_stream_sleep_time, + **kwargs, ) @property @@ -59,5 +63,6 @@ async def run(self, frame: VisionImageVoiceRawFrame) -> AsyncGenerator[Frame, No audio_nparr = bytes2NpArrayWith16(frame.audio.audio) self._session.ctx.state["prompt"].append(audio_nparr) + self.send_input(self._session) async for item in self.gen(): yield item diff --git a/src/processors/omni/qwen2_5omni_vision_voice.py b/src/processors/omni/qwen2_5omni_vision_voice.py new file mode 100644 index 00000000..d0da7853 --- /dev/null +++ b/src/processors/omni/qwen2_5omni_vision_voice.py @@ -0,0 +1,68 @@ +import logging +from typing import AsyncGenerator + +from PIL import Image +from apipeline.frames import * +import librosa +import numpy as np + +from src.common.session import Session +from src.core.llm.transformers.manual_vision_voice_qwen import ( + TransformersManualVisionVoiceQwen2_5OmniLLM, +) +from src.common.utils.audio_utils import bytes2NpArrayWith16 +from src.processors.omni.base import VisionVoiceProcessorBase +from src.types.frames.data_frames import PathAudioRawFrame, VisionImageVoiceRawFrame + + +class Qwen2_5OmnVisionVoiceProcessor(VisionVoiceProcessorBase): + """ """ + + def __init__( + self, + *, + session: Session | None = None, + no_stream_sleep_time: float = 0.5, + **kwargs, + ): + super().__init__( + llm=TransformersManualVisionVoiceQwen2_5OmniLLM(**kwargs), + session=session, + no_stream_sleep_time=no_stream_sleep_time, + **kwargs, + ) + + @property + def stream_info(self) -> dict: + """Return dict out stream info""" + return { + "sample_rate": TransformersManualVisionVoiceQwen2_5OmniLLM.RATE, + "channels": 1, + } + + async def run(self, frame: VisionImageVoiceRawFrame) -> AsyncGenerator[Frame, None]: + if not self._llm: + logging.error(f"{self} error: llm not available") + yield ErrorFrame("llm not available") + return + + 
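+        # build a Qwen2.5-Omni style multimodal prompt: a list of {"type": "image"|"audio", ...}
+        # entries; images are decoded from the raw frame bytes, audio becomes a 16 kHz mono numpy
+        # array (librosa for path frames, bytes conversion otherwise), then the session is handed
+        # to the background generate thread via send_input()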
logging.debug(f"VisionImageVoiceRawFrame: {frame}") + self._session.ctx.state["prompt"] = [] + + # frame.text and self._session.ctx.state["prompt"].append(frame.text) + + if frame.images: + for image_frame in frame.images: + image = Image.frombytes(image_frame.mode, image_frame.size, image_frame.image) + self._session.ctx.state["prompt"].append({"type": "image", "image": image}) + + if frame.audio and frame.audio.audio: + if isinstance(frame.audio, PathAudioRawFrame): + audio_nparr, _ = librosa.load(frame.audio.path, sr=16000, mono=True) + else: + audio_nparr = bytes2NpArrayWith16(frame.audio.audio) + self._session.ctx.state["prompt"].append({"type": "audio", "audio": audio_nparr}) + + self.send_input(self._session) + async for item in self.gen(): + yield item diff --git a/src/processors/voice/qwen2_5omni_voice_processor.py b/src/processors/voice/qwen2_5omni_voice_processor.py new file mode 100644 index 00000000..b3ed8902 --- /dev/null +++ b/src/processors/voice/qwen2_5omni_voice_processor.py @@ -0,0 +1,176 @@ +import asyncio +import io +import uuid +import logging +import threading +import queue +from typing import AsyncGenerator + +import librosa +import numpy as np +from apipeline.frames import * + +from src.core.llm.transformers.manual_vision_voice_qwen import ( + TransformersManualQwen2_5OmniLLM, + TransformersManualVoiceQwen2_5OmniLLM, + TransformersManualTextVoiceQwen2_5OmniLLM, +) + +from src.processors.voice.base import VoiceProcessorBase +from src.common.session import Session +from src.common.types import SessionCtx +from src.common.utils.audio_utils import ( + bytes2NpArrayWith16, +) +from src.types.frames import PathAudioRawFrame + + +class Qwen2_5OmniVoiceProcessor(VoiceProcessorBase): + def __init__( + self, + *, + session: Session | None = None, + no_stream_sleep_time: float = 0.5, + **kwargs, + ): + super().__init__(**kwargs) + self._session = session or Session(**SessionCtx(uuid.uuid4()).__dict__) + self._model: TransformersManualQwen2_5OmniLLM = None + self._queue = queue.Queue() + self._input_queue = queue.Queue() + self._generate_thread = None + self._sleep_time = no_stream_sleep_time + + @property + def stream_info(self) -> dict: + """Return dict out stream info""" + return { + "sample_rate": TransformersManualQwen2_5OmniLLM.RATE, + "channels": 1, + } + + async def say(self, text: str): + logging.info(f"say: {text}") + + def _generate(self): + while True: + try: + session = self._input_queue.get() + if session is None: + self._queue.put(None) # Signal the end of the stream + break # Signal to stop the thread + tensor_audio_stream = self._model.generate(session) + for item in tensor_audio_stream: + self._queue.put(item) + self._queue.put(None) # Signal the end of the stream + except Exception as e: + logging.error(f"Exception generate: {e}", exc_info=True) + self._queue.put(None) # Signal the end of the stream + break + + async def start(self, frame: StartFrame): + await super().start(frame) + self._generate_thread = threading.Thread(target=self._generate) + self._generate_thread.start() + logging.info("start done") + + async def stop(self, frame: EndFrame): + await super().stop(frame) + self._input_queue.put(None) # Signal the thread to stop + self._generate_thread.join() # Wait for the thread to finish + logging.info("stop done") + + async def cancel(self, frame: CancelFrame): + await super().cancel(frame) + self._input_queue.put(None) # Signal the thread to stop + self._generate_thread.join() # Wait for the thread to finish + logging.info("cancel done") + + async 
def gen(self) -> AsyncGenerator[Frame, None]: + while True: + try: + item = self._queue.get_nowait() + if item is None: + break # End of the stream + logging.debug(f"generate data: {item}") + tensor_audio = item.pop("audio_wav", None) + text = item.pop("text", "").strip() + if text != "": + await self.push_frame(TextFrame(text=text)) + + if tensor_audio is not None: # don't use if tensor_audio to check + audio_bytes = ( + (tensor_audio.float().detach().cpu().numpy() * 32768) + .astype(np.int16) + .tobytes() + ) + logging.debug( + f"audio tensor:{tensor_audio.shape},push audio len:{len(audio_bytes)}" + ) + await self.push_frame( + AudioRawFrame( + audio=audio_bytes, + sample_rate=self._model.RATE, + ) + ) + yield None + except queue.Empty: + # yield asysncio.sleep to allow other tasks to run, e.g.: sink task (write audio) + await asyncio.sleep(self._sleep_time) + continue + + def send_input(self, session: Session): + self._input_queue.put(session) + + +class Qwen2_5OmniAudioVoiceProcessor(Qwen2_5OmniVoiceProcessor): + """ + qwen2.5omni voice + - A1->T2A2 + """ + + def __init__( + self, + *, + session: Session | None = None, + no_stream_sleep_time: float = 0.05, + **kwargs, + ): + super().__init__(session=session, no_stream_sleep_time=no_stream_sleep_time, **kwargs) + + self._model = TransformersManualVoiceQwen2_5OmniLLM(**kwargs) + + async def run_voice(self, frame: AudioRawFrame) -> AsyncGenerator[Frame, None]: + if isinstance(frame, PathAudioRawFrame): + audio_nparr, _ = librosa.load(frame.path, sr=16000, mono=True) + else: + audio_nparr = bytes2NpArrayWith16(frame.audio) + + self._session.ctx.state["prompt"] = [{"type": "audio", "audio": audio_nparr}] + self.send_input(self._session) + async for item in self.gen(): + yield item + + +class Qwen2_5OmniTextVoiceProcessor(Qwen2_5OmniVoiceProcessor): + """ + - T1->T2A2 + """ + + def __init__( + self, + *, + session: Session | None = None, + no_stream_sleep_time: float = 0.05, + **kwargs, + ): + super().__init__(session=session, no_stream_sleep_time=no_stream_sleep_time, **kwargs) + + self._model = TransformersManualTextVoiceQwen2_5OmniLLM(**kwargs) + + async def run_text(self, frame: TextFrame) -> AsyncGenerator[Frame, None]: + user_input = frame.text.strip() + self._session.ctx.state["prompt"] = [{"type": "text", "text": user_input}] + self.send_input(self._session) + async for item in self.gen(): + yield item diff --git a/src/thirdparty/model_loader/__init__.py b/src/thirdparty/model_loader/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/thirdparty/model_loader/weight_utils.py b/src/thirdparty/model_loader/weight_utils.py new file mode 100644 index 00000000..89f687c9 --- /dev/null +++ b/src/thirdparty/model_loader/weight_utils.py @@ -0,0 +1,34 @@ +from typing import Generator, List, Tuple + +from safetensors.torch import safe_open +import torch +from tqdm.auto import tqdm + +# explicitly use pure text format, with a newline at the end +# this makes it impossible to see the animation in the progress bar +# but will avoid messing up with ray or multiprocessing, which wraps +# each line of output with some prefix. 
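+# (the trailing "\n" in _BAR_FORMAT below prints each progress update on its own line instead
+#  of redrawing the bar in place)
+#
+# Rough usage sketch for safetensors_weights_iterator (illustrative; mirrors how
+# src/thirdparty/qwen2_code2wav/engine.py consumes it, `model_path` being whatever directory
+# holds the shards):
+#   files = sorted(glob.glob(os.path.join(model_path, "*.safetensors")))
+#   state_dict = {name: tensor for name, tensor in safetensors_weights_iterator(files, use_tqdm_on_load=True)}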
+_BAR_FORMAT = "{desc}: {percentage:3.0f}% Completed | {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]\n" # noqa: E501 + + +def enable_tqdm(use_tqdm_on_load: bool): + return use_tqdm_on_load and ( + not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0 + ) + + +def safetensors_weights_iterator( + hf_weights_files: List[str], + use_tqdm_on_load: bool, +) -> Generator[Tuple[str, torch.Tensor], None, None]: + """Iterate over the weights in the model safetensor files.""" + for st_file in tqdm( + hf_weights_files, + desc="Loading safetensors checkpoint shards", + disable=not enable_tqdm(use_tqdm_on_load), + bar_format=_BAR_FORMAT, + ): + with safe_open(st_file, framework="pt") as f: + for name in f.keys(): # noqa: SIM118 + param = f.get_tensor(name) + yield name, param diff --git a/src/thirdparty/qwen2_code2wav/__init__.py b/src/thirdparty/qwen2_code2wav/__init__.py new file mode 100644 index 00000000..7707b25f --- /dev/null +++ b/src/thirdparty/qwen2_code2wav/__init__.py @@ -0,0 +1,23 @@ +from dataclasses import dataclass +from typing import Union + + +@dataclass +class Code2WavGenerationConfig: + # dit cfm + num_steps: int = 10 + guidance_scale: float = 0.5 + sway_coefficient: float = -1.0 + + +@dataclass +class Code2WavEngineConfig(Code2WavGenerationConfig): + model_path: str = "" + enable_torch_compile: bool = False + enable_torch_compile_first_chunk: bool = False + odeint_method: str = "euler" + odeint_method_relaxed: bool = False + batched_chunk: int = 3 + frequency: str = "50hz" + device: Union[int, str] = "cuda" + code2wav_dynamic_batch: bool = False diff --git a/src/thirdparty/qwen2_code2wav/engine.py b/src/thirdparty/qwen2_code2wav/engine.py new file mode 100644 index 00000000..eb7c54b2 --- /dev/null +++ b/src/thirdparty/qwen2_code2wav/engine.py @@ -0,0 +1,188 @@ +import glob +import logging +import os +from typing import List, Union, Tuple + +import numpy as np +import torch + +from . 
import Code2WavEngineConfig, Code2WavGenerationConfig + +from ..model_loader.weight_utils import safetensors_weights_iterator +from .modeling_fast import Qwen2Code2wav + + +class Code2WavEngine: + def __init__(self, **kwargs) -> None: + self.args = Code2WavEngineConfig(**kwargs) + model_path = self.args.model_path + enable_torch_compile = self.args.enable_torch_compile + enable_torch_compile_first_chunk = self.args.enable_torch_compile_first_chunk + odeint_method = self.args.odeint_method + odeint_method_relaxed = self.args.odeint_method_relaxed + batched_chunk = self.args.batched_chunk + frequency: str = self.args.frequency + device: Union[int, str] = self.args.device + code2wav_dynamic_batch: bool = self.args.code2wav_dynamic_batch # todo batch chunk + + if isinstance(device, int): + device = f"cuda:{device}" + self.device = torch.device(device) + + logging.info( + f"Code2WavEngine starting up on device {self.device}, with model {model_path}, method: {odeint_method}, relaxed: {odeint_method_relaxed}" + ) + + # load spk_dict ["Ethan", "Chelsie"] + if os.path.exists(os.path.join(model_path, "spk_dict.pt")): + self.code2wav_conds, self.code2wav_ref_mels = self.load_spk_dict(model_path) + assert len(self.code2wav_conds) > 0 and len(self.code2wav_ref_mels) > 0, "No speakers found" + if "default" not in self.code2wav_conds: + self.code2wav_conds["default"] = self.code2wav_conds[ + list(self.code2wav_conds.keys())[0] + ] + if "default" not in self.code2wav_ref_mels: + self.code2wav_ref_mels["default"] = self.code2wav_ref_mels[ + list(self.code2wav_ref_mels.keys())[0] + ] + + self.frequency = frequency + self.code2wav_steps: int = 10 + self.code2wav_bs_mel: int = 24 if frequency == "50hz" else 32 + self.factor: int = 2 if frequency == "50hz" else 4 + + dit_model, bigvgan_model = self.load_code2wav(model_path) + self.code2wav = Qwen2Code2wav( + dit_ckpt=dit_model, + bigvgan_ckpt=bigvgan_model, + steps=self.code2wav_steps, + bs_mel=self.code2wav_bs_mel, + odeint_method=odeint_method, + odeint_method_relaxed=odeint_method_relaxed, + batched_chunk=batched_chunk, + frequency=frequency, + device=self.device, + with_weight_norm=False, + ) + self.torch_compile_model(enable_torch_compile, enable_torch_compile_first_chunk) + + self.code2wav_y_all = torch.randn( + 1, 32768, 80, device=self.device, dtype=list(self.code2wav_ref_mels.values())[0].dtype + ) + + def get_voice(self, voice_type: str = "default"): + if voice_type not in self.code2wav_conds: + logging.warning(f"voice type {voice_type} not found, using default") + voice_type = "default" + code2wav_cond = self.code2wav_conds[voice_type] + code2wav_ref_mel = self.code2wav_ref_mels[voice_type] + return code2wav_cond, code2wav_ref_mel + + def load_spk_dict(self, model_path): + code2wav_conds, code2wav_ref_mels = {}, {} + + if not os.path.exists(os.path.join(model_path, "spk_dict.pt")): + return code2wav_conds, code2wav_ref_mels + + for key, value in torch.load(os.path.join(model_path, "spk_dict.pt")).items(): + code2wav_conds[key] = value["cond"].to(self.device) + code2wav_ref_mels[key] = value["ref_mel"].to(self.device) + return code2wav_conds, code2wav_ref_mels + + def load_code2wav(self, model_path): + dit_model, bigvgan_model = {}, {} + safetensors = sorted(glob.glob(os.path.join(model_path, "*.safetensors"))) + legacy_weights = False + for key, value in safetensors_weights_iterator(safetensors, use_tqdm_on_load=True): + legacy_weights = legacy_weights or "input_embed.spk_encoder.fc.conv.weight" in key + if legacy_weights: + break + for key, value 
in safetensors_weights_iterator(safetensors, use_tqdm_on_load=True): + if key.startswith("token2wav.code2wav_bigvgan_model."): + if "generator" not in bigvgan_model: + bigvgan_model["generator"] = {} + bigvgan_model["generator"][key.replace("token2wav.code2wav_bigvgan_model.", "")] = ( + value + ) + if key.startswith("token2wav.code2wav_dit_model."): + key = key.replace("token2wav.code2wav_dit_model.", "transformer.") + if key.startswith("transformer.input_embed.spk_encoder"): + if legacy_weights: + dit_model[key] = value + else: + dit_model[ + key.replace(".bias", ".conv.bias").replace(".weight", ".conv.weight") + ] = value + elif ".ff.ff.0.weight" in key or ".ff.ff.0.bias" in key: + dit_model[ + key.replace(".ff.ff.0.weight", ".ff.ff.0.0.weight").replace( + ".ff.ff.0.bias", ".ff.ff.0.0.bias" + ) + ] = value + elif ".ff.ff.3.weight" in key or ".ff.ff.3.bias" in key: + dit_model[ + key.replace(".ff.ff.3.weight", ".ff.ff.2.weight").replace( + ".ff.ff.3.bias", ".ff.ff.2.bias" + ) + ] = value + else: + dit_model[key] = value + return dit_model, bigvgan_model + + def torch_compile_model( + self, + enable_torch_compile, + enable_torch_compile_first_chunk, + ): + if not enable_torch_compile: + return + + # compile the bigvgan + self.code2wav.code2wav_bigvgan_model.vocoder.forward = torch.compile( + self.code2wav.code2wav_bigvgan_model.vocoder.forward, + ) + # compile the dit + if hasattr(self.code2wav, "enable_torch_compile"): + self.code2wav.enable_torch_compile(enable_torch_compile_first_chunk) + + logging.info("Code2Wav model torch compiled") + + @torch.inference_mode() + def step_generate_waveform( + self, + code: List[int], + prev_generated: Union[torch.Tensor, List[torch.Tensor]], + progress: int, + finished: bool = False, + y_all: torch.Tensor = None, + voice_type: str = "default", + gen_args: Code2WavGenerationConfig = Code2WavGenerationConfig(), + ) -> Tuple[Union[torch.Tensor, List[torch.Tensor]], torch.Tensor]: + """ + Generate waveform from code list step by step. 
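+
+        A chunk is decoded only when the accumulated codec length,
+        len(code) * factor - future_cache_size, is a positive multiple of the
+        model chunk_size, or when `finished` is True; an utterance that finishes
+        before the first full chunk (progress == 0) goes through
+        process_little_chunk. Otherwise the previous generation state is
+        returned unchanged together with a None waveform.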
+ """ + cond, ref_mel = self.get_voice(voice_type) + chunk_code_length = len(code) * self.factor - self.code2wav.future_cache_size + if ( + chunk_code_length > 0 and chunk_code_length % self.code2wav.chunk_size == 0 + ) or finished: + code = torch.tensor(code, dtype=torch.long, device=self.device).reshape(1, -1) + if progress == 0 and finished: + process_chunk = self.code2wav.process_little_chunk + else: + process_chunk = self.code2wav.process_chunk + + return process_chunk( + cond, + ref_mel, + codec_all=code, + y_all=self.code2wav_y_all if y_all is None else y_all, + i=progress, + steps=gen_args.num_steps, + prev_generated=prev_generated, + finished=finished, + cfg_strength=gen_args.guidance_scale, + sway_sampling_coef=gen_args.sway_coefficient, + ) + else: + return prev_generated, None diff --git a/src/thirdparty/qwen2_code2wav/model/__init__.py b/src/thirdparty/qwen2_code2wav/model/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/thirdparty/qwen2_code2wav/model/dit.py b/src/thirdparty/qwen2_code2wav/model/dit.py new file mode 100644 index 00000000..312d23aa --- /dev/null +++ b/src/thirdparty/qwen2_code2wav/model/dit.py @@ -0,0 +1,298 @@ +""" +ein notation: +b - batch +n - sequence +nt - text sequence +nw - raw wave length +d - dimension +""" + +from __future__ import annotations + +import torch +from torch import nn +import torch.nn.functional as F + +from x_transformers.x_transformers import RotaryEmbedding + +from .spk_encoder import ECAPA_TDNN +from .dit_modules import ( + TimestepEmbedding, + ConvNeXtV2Block, + ConvPositionEmbedding, + DiTBlock, + AdaLayerNormZero_Final, + precompute_freqs_cis, + get_pos_embed_indices, +) + +# Text embedding +class CodecEmbedding(nn.Module): + def __init__(self, codec_num_embeds, codec_dim, repeats): + super().__init__() + self.repeats = repeats + self.codec_embed = nn.Embedding(codec_num_embeds + 1, codec_dim) + + def forward(self, codec: int["b nc"], seq_len, drop_text=False): + if drop_text: + codec = torch.zeros_like(codec) + codec = self.codec_embed(codec) + codec = torch.repeat_interleave(codec, repeats=self.repeats, dim=1) + return codec + +class TextEmbedding(nn.Module): + def __init__(self, text_num_embeds, text_dim, conv_layers=0, conv_mult=2): + super().__init__() + self.text_embed = nn.Embedding(text_num_embeds + 1, text_dim) # use 0 as filler token + + if conv_layers > 0: + self.extra_modeling = True + self.precompute_max_pos = 4096 # ~44s of 24khz audio + self.register_buffer("freqs_cis", precompute_freqs_cis(text_dim, self.precompute_max_pos), persistent=False) + self.text_blocks = nn.Sequential( + *[ConvNeXtV2Block(text_dim, text_dim * conv_mult) for _ in range(conv_layers)] + ) + else: + self.extra_modeling = False + + def forward(self, text: int["b nt"], seq_len, drop_text=False): # noqa: F722 + text = text + 1 # use 0 as filler token. 
preprocess of batch pad -1, see list_str_to_idx() + text = text[:, :seq_len] # curtail if character tokens are more than the mel spec tokens + batch, text_len = text.shape[0], text.shape[1] + text = F.pad(text, (0, seq_len - text_len), value=0) + + if drop_text: # cfg for text + text = torch.zeros_like(text) + + text = self.text_embed(text) # b n -> b n d + + # possible extra modeling + if self.extra_modeling: + # sinus pos emb + batch_start = torch.zeros((batch,), dtype=torch.long) + pos_idx = get_pos_embed_indices(batch_start, seq_len, max_pos=self.precompute_max_pos) + text_pos_embed = self.freqs_cis[pos_idx] + text = text + text_pos_embed + + # convnextv2 blocks + text = self.text_blocks(text) + + return text + + +# noised input audio and context mixing embedding + + +class InputEmbedding(nn.Module): + def __init__(self, mel_dim, text_dim, out_dim): + super().__init__() + self.proj = nn.Linear(mel_dim+ 128 + 192 + text_dim, out_dim) # 192 for x-vector + self.spk_encoder = ECAPA_TDNN(80, 128, + channels=[256, 256, 256, 256, 768], + kernel_sizes=[5, 3, 3, 3, 1], + dilations=[1, 2, 3, 4, 1], + attention_channels=64, + res2net_scale=2, + se_channels=64, + global_context=True, + batch_norm=False) + # remove convposembedding for causal or block causal + # self.conv_pos_embed = ConvPositionEmbedding(dim=out_dim) + + def forward(self, x: float["b n d"],spk: float["b n d"], cond: float["b n d"], text_embed: float["b n d"], drop_audio_cond=False): # noqa: F722 + + if drop_audio_cond: # cfg for cond audio + cond = torch.zeros_like(cond) + spk = torch.zeros_like(spk) + cond = self.spk_encoder(cond).unsqueeze(1).repeat(1, x.size(1), 1) + # import pdb; pdb.set_trace() + # print(x.shape,cond.shape,text_embed.shape) + x = self.proj(torch.cat((x, cond, text_embed, spk), dim=-1)) + # x = self.conv_pos_embed(x) + x + return x + + def fast_forward(self, x, spk, cond, text_embed,text_embed_uncond): + x = torch.cat([x,x],dim=0) + spk = torch.cat([spk,torch.zeros_like(spk)],dim=0) + cond = torch.cat([cond,torch.rand_like(cond)],dim=0) + cond = self.spk_encoder(cond).unsqueeze(1).repeat(1, x.size(1), 1) + text_emb = torch.cat([text_embed,text_embed_uncond],dim=0) + x = self.proj(torch.cat((x, cond, text_emb, spk), dim=-1)) + + return x + + + +# Transformer backbone using DiT blocks + + +class DiT(nn.Module): + def __init__( + self, + *, + dim, + depth=8, + heads=8, + dim_head=64, + dropout=0.1, + ff_mult=4, + mel_dim=100, + text_num_embeds=256, + text_dim=None, + conv_layers=0, + long_skip_connection=False, + use_codec=False, + attn_processor="", + repeats=2 + ): + super().__init__() + self.repeats = repeats + self.time_embed = TimestepEmbedding(dim) + if text_dim is None: + text_dim = mel_dim + if not use_codec: + self.text_embed = TextEmbedding(text_num_embeds, text_dim, conv_layers=conv_layers) + else: + self.text_embed = CodecEmbedding(text_num_embeds, text_dim, repeats=repeats) + self.input_embed = InputEmbedding(mel_dim, text_dim, dim) + + self.rotary_embed = RotaryEmbedding(dim_head) + + self.dim = dim + self.depth = depth + if attn_processor == "stream_block_sr": + attn_processor_0 = 'stream_block_sr_00' + attn_processor_1 = 'stream_block_sr_10' + attn_processor_2 = 'stream_block_sr_01' + self.transformer_blocks = nn.ModuleList() + for i in range(depth): + if i == 0 or i == 20: + attn_processor_in = attn_processor_1 + elif i == 10: + attn_processor_in = attn_processor_2 + else: + attn_processor_in = attn_processor_0 + self.transformer_blocks.append( + DiTBlock(dim=dim, heads=heads, dim_head=dim_head, 
ff_mult=ff_mult, dropout=dropout, attn_processor=attn_processor_in) + ) + elif attn_processor == "stream_block_sr_low": + attn_processor_0 = 'stream_block_sr_00' + attn_processor_1 = 'stream_block_sr_11' + attn_processor_2 = 'stream_block_sr_10' + self.transformer_blocks = nn.ModuleList() + for i in range(depth): + if i == 0: + attn_processor_in = attn_processor_1 + elif i == 10 or i == 20: + attn_processor_in = attn_processor_2 + else: + attn_processor_in = attn_processor_0 + self.transformer_blocks.append( + DiTBlock(dim=dim, heads=heads, dim_head=dim_head, ff_mult=ff_mult, dropout=dropout, attn_processor=attn_processor_in) + ) + elif attn_processor == 'stream_block': + attn_processor_0 = 'stream_block_0' + attn_processor_1 = 'stream_block_1' + self.transformer_blocks = nn.ModuleList( + [DiTBlock(dim=dim, heads=heads, dim_head=dim_head, ff_mult=ff_mult, dropout=dropout, attn_processor=attn_processor_1 if i%5==0 else attn_processor_0) for i in range(depth)] + ) + elif attn_processor == "stream_block_8_L_4": + attn_processor_0 = 'stream_block_8_00' + attn_processor_1 = 'stream_block_8_10' + attn_processor_2 = 'stream_block_8_01' + self.transformer_blocks = nn.ModuleList() + for i in range(depth): + if i == 0 or i == 8 or i == 24 or i == 30: + attn_processor_in = attn_processor_1 + elif i == 15: + attn_processor_in = attn_processor_2 + else: + attn_processor_in = attn_processor_0 + self.transformer_blocks.append( + DiTBlock(dim=dim, heads=heads, dim_head=dim_head, ff_mult=ff_mult, dropout=dropout, attn_processor=attn_processor_in) + ) + else: + self.transformer_blocks = nn.ModuleList( + [DiTBlock(dim=dim, heads=heads, dim_head=dim_head, ff_mult=ff_mult, dropout=dropout, attn_processor=attn_processor) for i in range(depth)] + ) + self.long_skip_connection = nn.Linear(dim * 2, dim, bias=False) if long_skip_connection else None + + self.norm_out = AdaLayerNormZero_Final(dim) # final modulation + self.proj_out = nn.Linear(dim, mel_dim) + + def forward( + self, + x: float["b n d"], # nosied input audio # noqa: F722 + cond: float["b n d"], # masked cond audio # noqa: F722 + spk: float["b n d"], # spk embedding # noqa: F722 + text: int["b nt"], # text # noqa: F722 + time: float["b"] | float[""], # time step # noqa: F821 F722 + drop_audio_cond, # cfg for cond audio + drop_text, # cfg for text + mask: bool["b n"] | None = None, # noqa: F722 + ): + batch, seq_len = x.shape[0], x.shape[1] + if time.ndim == 0: + time = time.repeat(batch) + + # t: conditioning time, c: context (text + masked cond audio), x: noised input audio + t = self.time_embed(time) + text_embed = self.text_embed(text, seq_len, drop_text=drop_text) + # print(spk.dtype, cond.dtype, text_embed.dtype) + # import pdb; pdb.set_trace() + x = self.input_embed(x, spk, cond, text_embed, drop_audio_cond=drop_audio_cond) + + rope = self.rotary_embed.forward_from_seq_len(seq_len) + + if self.long_skip_connection is not None: + residual = x + + for block in self.transformer_blocks: + x = block(x, t, mask=mask, rope=rope) + + if self.long_skip_connection is not None: + x = self.long_skip_connection(torch.cat((x, residual), dim=-1)) + + x = self.norm_out(x, t) + output = self.proj_out(x) + + return output + + @torch.no_grad() + def fast_forward( + self, + x: float["b n d"], # nosied input audio # noqa: F722 + cond: float["b n d"], # masked cond audio # noqa: F722 + spk: float["b n d"], # spk embedding # noqa: F722 + text: int["b nt"], # text # noqa: F722 + time: float["b"] | float[""], # time step # noqa: F821 F722 + mask: bool["b n"] | None = 
None, # noqa: F722 + ) -> float["b n d"]: + batch, seq_len = x.shape[0]*2, x.shape[1] + if time.ndim == 0: + time = time.repeat(batch) + + # t: conditioning time, c: context (text + masked cond audio), x: noised input audio + t = self.time_embed(time) + text_embed = self.text_embed(text, seq_len, drop_text=False) + text_embed_uncond = self.text_embed(text, seq_len, drop_text=True) + # print(spk.dtype, cond.dtype, text_embed.dtype) + # import pdb; pdb.set_trace() + x = self.input_embed.fast_forward(x, spk, cond, text_embed, text_embed_uncond) + + rope = self.rotary_embed.forward_from_seq_len(seq_len) + + if self.long_skip_connection is not None: + residual = x + + for block in self.transformer_blocks: + x = block(x, t, mask=mask, rope=rope) + + if self.long_skip_connection is not None: + x = self.long_skip_connection(torch.cat((x, residual), dim=-1)) + + x = self.norm_out(x, t) + output = self.proj_out(x) + + return output + \ No newline at end of file diff --git a/src/thirdparty/qwen2_code2wav/model/dit_modules.py b/src/thirdparty/qwen2_code2wav/model/dit_modules.py new file mode 100644 index 00000000..f7d3e32a --- /dev/null +++ b/src/thirdparty/qwen2_code2wav/model/dit_modules.py @@ -0,0 +1,977 @@ +""" +ein notation: +b - batch +n - sequence +nt - text sequence +nw - raw wave length +d - dimension +""" + +from __future__ import annotations +from typing import Optional +import math + +import torch +from torch import nn +import torch.nn.functional as F +import torchaudio + +from x_transformers.x_transformers import apply_rotary_pos_emb + + +# raw wav to mel spec + + +class MelSpec(nn.Module): + def __init__( + self, + filter_length=1024, + hop_length=256, + win_length=1024, + n_mel_channels=100, + target_sample_rate=24_000, + normalize=False, + power=1, + norm=None, + center=True, + ): + super().__init__() + self.n_mel_channels = n_mel_channels + + self.mel_stft = torchaudio.transforms.MelSpectrogram( + sample_rate=target_sample_rate, + n_fft=filter_length, + win_length=win_length, + hop_length=hop_length, + n_mels=n_mel_channels, + power=power, + center=center, + normalized=normalize, + norm=norm, + ) + + self.register_buffer("dummy", torch.tensor(0), persistent=False) + + def forward(self, inp): + if len(inp.shape) == 3: + inp = inp.squeeze(1) # 'b 1 nw -> b nw' + + assert len(inp.shape) == 2 + + if self.dummy.device != inp.device: + self.to(inp.device) + + mel = self.mel_stft(inp) + mel = mel.clamp(min=1e-5).log() + return mel + + +# sinusoidal position embedding + + +class SinusPositionEmbedding(nn.Module): + def __init__(self, dim): + super().__init__() + self.dim = dim + + def forward(self, x, scale=1000): + device = x.device + half_dim = self.dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, device=device).float() * -emb) + emb = scale * x.unsqueeze(1) * emb.unsqueeze(0) + emb = torch.cat((emb.sin(), emb.cos()), dim=-1) + return emb + + +# convolutional position embedding + + +class ConvPositionEmbedding(nn.Module): + def __init__(self, dim, kernel_size=31, groups=16): + super().__init__() + assert kernel_size % 2 != 0 + self.conv1d = nn.Sequential( + nn.Conv1d(dim, dim, kernel_size, groups=groups, padding=kernel_size // 2), + nn.Mish(), + nn.Conv1d(dim, dim, kernel_size, groups=groups, padding=kernel_size // 2), + nn.Mish(), + ) + + def forward(self, x: float["b n d"], mask: bool["b n"] | None = None): # noqa: F722 + if mask is not None: + mask = mask[..., None] + x = x.masked_fill(~mask, 0.0) + + x = x.permute(0, 2, 1) + x = self.conv1d(x) + 
out = x.permute(0, 2, 1) + + if mask is not None: + out = out.masked_fill(~mask, 0.0) + + return out + + +# rotary positional embedding related + + +def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, theta_rescale_factor=1.0): + # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning + # has some connection to NTK literature + # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/ + # https://github.com/lucidrains/rotary-embedding-torch/blob/main/rotary_embedding_torch/rotary_embedding_torch.py + theta *= theta_rescale_factor ** (dim / (dim - 2)) + freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) + t = torch.arange(end, device=freqs.device) # type: ignore + freqs = torch.outer(t, freqs).float() # type: ignore + freqs_cos = torch.cos(freqs) # real part + freqs_sin = torch.sin(freqs) # imaginary part + return torch.cat([freqs_cos, freqs_sin], dim=-1) + + +def get_pos_embed_indices(start, length, max_pos, scale=1.0): + # length = length if isinstance(length, int) else length.max() + scale = scale * torch.ones_like(start, dtype=torch.float32) # in case scale is a scalar + pos = ( + start.unsqueeze(1) + + ( + torch.arange(length, device=start.device, dtype=torch.float32).unsqueeze(0) + * scale.unsqueeze(1) + ).long() + ) + # avoid extra long error. + pos = torch.where(pos < max_pos, pos, max_pos - 1) + return pos + + +# Global Response Normalization layer (Instance Normalization ?) + + +class GRN(nn.Module): + def __init__(self, dim): + super().__init__() + self.gamma = nn.Parameter(torch.zeros(1, 1, dim)) + self.beta = nn.Parameter(torch.zeros(1, 1, dim)) + + def forward(self, x): + Gx = torch.norm(x, p=2, dim=1, keepdim=True) + Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6) + return self.gamma * (x * Nx) + self.beta + x + + +# ConvNeXt-V2 Block https://github.com/facebookresearch/ConvNeXt-V2/blob/main/models/convnextv2.py +# ref: https://github.com/bfs18/e2_tts/blob/main/rfwave/modules.py#L108 + + +class ConvNeXtV2Block(nn.Module): + def __init__( + self, + dim: int, + intermediate_dim: int, + dilation: int = 1, + ): + super().__init__() + padding = (dilation * (7 - 1)) // 2 + self.dwconv = nn.Conv1d( + dim, dim, kernel_size=7, padding=padding, groups=dim, dilation=dilation + ) # depthwise conv + self.norm = nn.LayerNorm(dim, eps=1e-6) + self.pwconv1 = nn.Linear( + dim, intermediate_dim + ) # pointwise/1x1 convs, implemented with linear layers + self.act = nn.GELU() + self.grn = GRN(intermediate_dim) + self.pwconv2 = nn.Linear(intermediate_dim, dim) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + residual = x + x = x.transpose(1, 2) # b n d -> b d n + x = self.dwconv(x) + x = x.transpose(1, 2) # b d n -> b n d + x = self.norm(x) + x = self.pwconv1(x) + x = self.act(x) + x = self.grn(x) + x = self.pwconv2(x) + return residual + x + + +# AdaLayerNormZero +# return with modulated x for attn input, and params for later mlp modulation + + +class AdaLayerNormZero(nn.Module): + def __init__(self, dim): + super().__init__() + + self.silu = nn.SiLU() + self.linear = nn.Linear(dim, dim * 6) + + self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) + + def forward(self, x, emb=None): + emb = self.linear(self.silu(emb)) + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = torch.chunk(emb, 6, dim=1) + + x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None] + return x, gate_msa, shift_mlp, scale_mlp, gate_mlp 
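+
+# Illustrative usage sketch for AdaLayerNormZero above (it mirrors DiTBlock.forward
+# defined later in this file); the five outputs drive adaLN-Zero modulation of the
+# attention and feed-forward branches:
+#   norm, gate_msa, shift_mlp, scale_mlp, gate_mlp = attn_norm(x, emb=t)
+#   x = x + gate_msa.unsqueeze(1) * attn(x=norm, mask=mask, rope=rope)
+#   norm = ff_norm(x) * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
+#   x = x + gate_mlp.unsqueeze(1) * ff(norm)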
+ + +# AdaLayerNormZero for final layer +# return only with modulated x for attn input, cuz no more mlp modulation + + +class AdaLayerNormZero_Final(nn.Module): + def __init__(self, dim): + super().__init__() + + self.silu = nn.SiLU() + self.linear = nn.Linear(dim, dim * 2) + + self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) + + def forward(self, x, emb): + emb = self.linear(self.silu(emb)) + scale, shift = torch.chunk(emb, 2, dim=1) + + x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :] + return x + + +# FeedForward + + +class FeedForward(nn.Module): + def __init__(self, dim, dim_out=None, mult=4, dropout=0.0, approximate: str = "none"): + super().__init__() + inner_dim = int(dim * mult) + dim_out = dim_out if dim_out is not None else dim + + activation = nn.GELU(approximate=approximate) + project_in = nn.Sequential(nn.Linear(dim, inner_dim), activation) + self.ff = nn.Sequential(project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)) + + def forward(self, x): + for i in range(len(self.ff)): + tmp = self.ff[i](x) + x = tmp + return x + # return self.ff(x) + + +# Attention with possible joint part +# modified from diffusers/src/diffusers/models/attention_processor.py + + +class Attention(nn.Module): + def __init__( + self, + processor: JointAttnProcessor + | AttnProcessor + | BlockSelfAttnProcessor + | SteamingAttnProcessor, + dim: int, + heads: int = 8, + dim_head: int = 64, + dropout: float = 0.0, + context_dim: Optional[int] = None, # if not None -> joint attention + context_pre_only=None, + ): + super().__init__() + + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError( + "Attention equires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0." + ) + # print('use attn processor:',processor) + self.processor = processor + + self.dim = dim + self.heads = heads + self.inner_dim = dim_head * heads + self.dropout = dropout + + self.context_dim = context_dim + self.context_pre_only = context_pre_only + + self.to_q = nn.Linear(dim, self.inner_dim) + self.to_k = nn.Linear(dim, self.inner_dim) + self.to_v = nn.Linear(dim, self.inner_dim) + + if self.context_dim is not None: + self.to_k_c = nn.Linear(context_dim, self.inner_dim) + self.to_v_c = nn.Linear(context_dim, self.inner_dim) + if self.context_pre_only is not None: + self.to_q_c = nn.Linear(context_dim, self.inner_dim) + + self.to_out = nn.ModuleList([]) + self.to_out.append(nn.Linear(self.inner_dim, dim)) + self.to_out.append(nn.Dropout(dropout)) + + if self.context_pre_only is not None and not self.context_pre_only: + self.to_out_c = nn.Linear(self.inner_dim, dim) + + def forward( + self, + x: float["b n d"], # noised input x # noqa: F722 + c: float["b n d"] = None, # context c # noqa: F722 + mask: bool["b n"] | None = None, # noqa: F722 + rope=None, # rotary position embedding for x + c_rope=None, # rotary position embedding for c + ) -> torch.Tensor: + if c is not None: + return self.processor(self, x, c=c, mask=mask, rope=rope, c_rope=c_rope) + else: + return self.processor(self, x, mask=mask, rope=rope) + + +# Attention processor + + +class AttnProcessor: + def __init__(self): + pass + + def __call__( + self, + attn: Attention, + x: float["b n d"], # noised input x # noqa: F722 + mask: bool["b n"] | None = None, # noqa: F722 + rope=None, # rotary position embedding + ) -> torch.FloatTensor: + batch_size = x.shape[0] + + # `sample` projections. 
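+        # project the noised input x into query/key/value; this processor applies
+        # plain full self-attention, while the processors below restrict it with
+        # block-causal masks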
+ query = attn.to_q(x) + key = attn.to_k(x) + value = attn.to_v(x) + + # apply rotary position embedding + if rope is not None: + freqs, xpos_scale = rope + q_xpos_scale, k_xpos_scale = ( + (xpos_scale, xpos_scale**-1.0) if xpos_scale is not None else (1.0, 1.0) + ) + + query = apply_rotary_pos_emb(query, freqs, q_xpos_scale) + key = apply_rotary_pos_emb(key, freqs, k_xpos_scale) + + # attention + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + # mask. e.g. inference got a batch with different target durations, mask out the padding + if mask is not None: + attn_mask = mask + attn_mask = attn_mask.unsqueeze(1).unsqueeze(1) # 'b n -> b 1 1 n' + attn_mask = attn_mask.expand(batch_size, attn.heads, query.shape[-2], key.shape[-2]) + else: + attn_mask = None + + x = F.scaled_dot_product_attention( + query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False + ) + x = x.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + x = x.to(query.dtype) + + # linear proj + x = attn.to_out[0](x) + # dropout + x = attn.to_out[1](x) + + if mask is not None: + mask = mask.unsqueeze(-1) + x = x.masked_fill(~mask, 0.0) + + return x + + +class GlobalStreamSelfAttnProcessor: + def __init__(self, block_size=50, global_size=200, look_ahead_block=0, look_backward_block=0): + self.block_size = block_size + self.global_size = global_size + self.look_ahead_block = look_ahead_block + self.look_backward_block = look_backward_block + + def create_global_block_causal_mask(self, seq_length, device): + # 为每个 token 分配块索引 + positions = torch.arange(seq_length, device=device) # [seq_length] + + if self.global_size > 0: + # 分配块索引:全局块为 0,其余块从 1 开始 + block_indices = torch.where( + positions < self.global_size, + torch.zeros_like(positions), + 1 + (positions - self.global_size) // self.block_size, + ) + else: + # 没有全局块,所有块按普通块分配 + block_indices = positions // self.block_size # [seq_length] + + # 拓展维度以进行块间比较 + # block_i: [seq_length, 1], block_j: [1, seq_length] + block_i = block_indices.unsqueeze(1) # [seq_length, 1] + block_j = block_indices.unsqueeze(0) # [1, seq_length] + + # 计算块索引的差值 + block_diff = block_j - block_i # 形状为 (seq_length, seq_length) + + # 定义允许的块范围 + mask = (block_diff >= -float(self.look_backward_block)) & ( + block_diff <= float(self.look_ahead_block) + ) + + if self.global_size > 0: + # 确定哪些位置属于全局块 + is_global_i = block_i == 0 # [seq_length, 1] + is_global_j = block_j == 0 # [1, seq_length] + + # 所有块都可以看到全局块 + # 这意味着如果 j 是全局块,mask[i, j] = True + mask = (mask | is_global_j) & (~is_global_i | is_global_j) + + # 全局块可以看到所有块 + # 这意味着如果 i 是全局块,mask[i, j] = True + # mask = mask | is_global_i + + # 调整掩码形状为 [1, 1, seq_length, seq_length] 以适配多头注意力 + mask = mask.unsqueeze(0).unsqueeze(0) # [1, 1, seq_length, seq_length] + + return mask + + def __call__( + self, + attn: Attention, + x: float["b n d"], # noised input x # noqa: F722 + mask: bool["b n"] | None = None, # noqa: F722 + rope=None, # rotary position embedding + ) -> torch.FloatTensor: + batch_size = x.shape[0] + + # `sample` projections. 
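+        # q/k/v projections; the mask built below gives every non-global frame
+        # attention over the first global_size frames plus a window of
+        # look_backward_block previous and look_ahead_block following blocks,
+        # while frames inside the global prefix attend only to that prefix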
+ query = attn.to_q(x) + key = attn.to_k(x) + value = attn.to_v(x) + + # apply rotary position embedding + if rope is not None: + freqs, xpos_scale = rope + q_xpos_scale, k_xpos_scale = ( + (xpos_scale, xpos_scale**-1.0) if xpos_scale is not None else (1.0, 1.0) + ) + + query = apply_rotary_pos_emb(query, freqs, q_xpos_scale) + key = apply_rotary_pos_emb(key, freqs, k_xpos_scale) + + # attention + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + assert ( + query.shape[2] == key.shape[2] and query.shape[2] == value.shape[2] + ), f"block self attn require self attn but get query_shape: {query.shape}, key_shape: {key.shape}, value_shape: {value.shape}" + # mask. e.g. inference got a batch with different target durations, mask out the padding + block_attn_mask = self.create_global_block_causal_mask( + seq_length=query.shape[2], device=query.device + ) + block_attn_mask = block_attn_mask.expand( + batch_size, attn.heads, query.shape[-2], key.shape[-2] + ) + if mask is not None: + attn_mask = mask + attn_mask = attn_mask.unsqueeze(1).unsqueeze(1) # 'b n -> b 1 1 n' + attn_mask = attn_mask.expand(batch_size, attn.heads, query.shape[-2], key.shape[-2]) + attn_mask = torch.logical_and(block_attn_mask, attn_mask) + else: + attn_mask = block_attn_mask + x = F.scaled_dot_product_attention( + query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False + ) + x = x.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + x = x.to(query.dtype) + + # linear proj + x = attn.to_out[0](x) + # dropout + x = attn.to_out[1](x) + + if mask is not None: + mask = mask.unsqueeze(-1) + x = x.masked_fill(~mask, 0.0) + + return x + + +class BlockSelfAttnProcessor: + def __init__(self, block_size=100): + self.block_size = block_size + + def create_block_causal_mask(self, seq_length, device=None): + """ + 创建 Block Causal Attention 的 attention mask。 + + 参数: + - seq_length (int): 序列的长度。 + - device (torch.device, optional): 设备类型。 + + 返回: + - mask (torch.Tensor): attention mask,形状为 [1, 1, seq_length, seq_length]。 + """ + # 为每个 token 分配块索引 + block_indices = torch.arange(seq_length, device=device) // self.block_size # [seq_length] + # 拓展维度以进行块间比较 + # block_i: [seq_length, 1], block_j: [1, seq_length] + block_i = block_indices.unsqueeze(1) # [seq_length, 1] + block_j = block_indices.unsqueeze(0) # [1, seq_length] + + # 创建 mask:如果 j 所属块 <= i 所属块,则允许关注 + mask = block_j <= block_i # [seq_length, seq_length] + mask = mask.unsqueeze(0).unsqueeze(0) # [1, 1, seq_length, seq_length] + return mask + + def __call__( + self, + attn: Attention, + x: float["b n d"], # noised input x # noqa: F722 + mask: bool["b n"] | None = None, # noqa: F722 + rope=None, # rotary position embedding + ) -> torch.FloatTensor: + batch_size = x.shape[0] + + # `sample` projections. 
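+        # q/k/v projections; create_block_causal_mask below allows full attention
+        # inside each block_size-frame block and causal attention across blocks
+        # (a frame may attend to any frame whose block index is <= its own)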
+ query = attn.to_q(x) + key = attn.to_k(x) + value = attn.to_v(x) + + # apply rotary position embedding + if rope is not None: + freqs, xpos_scale = rope + q_xpos_scale, k_xpos_scale = ( + (xpos_scale, xpos_scale**-1.0) if xpos_scale is not None else (1.0, 1.0) + ) + + query = apply_rotary_pos_emb(query, freqs, q_xpos_scale) + key = apply_rotary_pos_emb(key, freqs, k_xpos_scale) + + # attention + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + assert ( + query.shape[2] == key.shape[2] and query.shape[2] == value.shape[2] + ), f"block self attn require self attn but get query_shape: {query.shape}, key_shape: {key.shape}, value_shape: {value.shape}" + # mask. e.g. inference got a batch with different target durations, mask out the padding + block_attn_mask = self.create_block_causal_mask( + seq_length=query.shape[2], device=query.device + ) + block_attn_mask = block_attn_mask.expand( + batch_size, attn.heads, query.shape[-2], key.shape[-2] + ) + if mask is not None: + attn_mask = mask + attn_mask = attn_mask.unsqueeze(1).unsqueeze(1) # 'b n -> b 1 1 n' + attn_mask = attn_mask.expand(batch_size, attn.heads, query.shape[-2], key.shape[-2]) + attn_mask = torch.logical_and(block_attn_mask, attn_mask) + else: + attn_mask = block_attn_mask + x = F.scaled_dot_product_attention( + query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False + ) + x = x.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + x = x.to(query.dtype) + + # linear proj + x = attn.to_out[0](x) + # dropout + x = attn.to_out[1](x) + + if mask is not None: + mask = mask.unsqueeze(-1) + x = x.masked_fill(~mask, 0.0) + + return x + + +class SteamingAttnProcessor: + def __init__(self, block_size=40, look_ahead_block=1, look_backward_block=1): + self.block_size = block_size + self.look_ahead_block = look_ahead_block + self.look_backward_block = look_backward_block + + def create_block_mask(self, seq_length, device=None): + """ + 创建 Block Causal Attention 的 attention mask。 + + 参数: + - seq_length (int): 序列的长度。 + - block_size (int): 每个块的大小。 + - device (torch.device, optional): 设备类型。 + + 返回: + - mask (torch.Tensor): attention mask,形状为 [1, 1, seq_length, seq_length]。 + """ + # 为每个 token 分配块索引 + block_indices = torch.arange(seq_length, device=device) // self.block_size # [seq_length] + # 拓展维度以进行块间比较 + # block_i: [seq_length, 1], block_j: [1, seq_length] + block_i = block_indices.unsqueeze(1) # [seq_length, 1] + block_j = block_indices.unsqueeze(0) # [1, seq_length] + # 计算块索引的差值 + block_diff = block_j - block_i # 形状为 (n, n) + + # 定义允许的块范围 + mask = (block_diff >= -float(self.look_backward_block)) & ( + block_diff <= float(self.look_ahead_block) + ) + # 创建 mask:如果 j 所属块 <= i 所属块,则允许关注 + # mask = block_j <= block_i # [seq_length, seq_length] + mask = mask.unsqueeze(0).unsqueeze(0) # [1, 1, seq_length, seq_length] + return mask + + def __call__( + self, + attn: Attention, + x: float["b n d"], # noised input x # noqa: F722 + mask: bool["b n"] | None = None, # noqa: F722 + rope=None, # rotary position embedding + ) -> torch.FloatTensor: + batch_size = x.shape[0] + + # `sample` projections. 
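+        # q/k/v projections; create_block_mask below yields a banded block mask:
+        # each block_size-frame block attends to itself plus look_backward_block
+        # previous and look_ahead_block following blocks, bounding the look-ahead
+        # needed for chunked streaming decoding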
+ query = attn.to_q(x) + key = attn.to_k(x) + value = attn.to_v(x) + + # apply rotary position embedding + if rope is not None: + freqs, xpos_scale = rope + q_xpos_scale, k_xpos_scale = ( + (xpos_scale, xpos_scale**-1.0) if xpos_scale is not None else (1.0, 1.0) + ) + + query = apply_rotary_pos_emb(query, freqs, q_xpos_scale) + key = apply_rotary_pos_emb(key, freqs, k_xpos_scale) + + # attention + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + assert ( + query.shape[2] == key.shape[2] and query.shape[2] == value.shape[2] + ), f"block self attn require self attn but get query_shape: {query.shape}, key_shape: {key.shape}, value_shape: {value.shape}" + # mask. e.g. inference got a batch with different target durations, mask out the padding + block_attn_mask = self.create_block_mask(seq_length=query.shape[2], device=query.device) + block_attn_mask = block_attn_mask.expand( + batch_size, attn.heads, query.shape[-2], key.shape[-2] + ) + if mask is not None: + attn_mask = mask + attn_mask = attn_mask.unsqueeze(1).unsqueeze(1) # 'b n -> b 1 1 n' + attn_mask = attn_mask.expand(batch_size, attn.heads, query.shape[-2], key.shape[-2]) + attn_mask = torch.logical_and(block_attn_mask, attn_mask) + else: + attn_mask = block_attn_mask + x = F.scaled_dot_product_attention( + query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False + ) + x = x.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + x = x.to(query.dtype) + + # linear proj + x = attn.to_out[0](x) + # dropout + x = attn.to_out[1](x) + + if mask is not None: + mask = mask.unsqueeze(-1) + x = x.masked_fill(~mask, 0.0) + + return x + + +# Joint Attention processor for MM-DiT +# modified from diffusers/src/diffusers/models/attention_processor.py + + +class JointAttnProcessor: + def __init__(self): + pass + + def __call__( + self, + attn: Attention, + x: float["b n d"], # noised input x # noqa: F722 + c: float["b nt d"] = None, # context c, here text # noqa: F722 + mask: bool["b n"] | None = None, # noqa: F722 + rope=None, # rotary position embedding for x + c_rope=None, # rotary position embedding for c + ) -> torch.FloatTensor: + residual = x + + batch_size = c.shape[0] + + # `sample` projections. + query = attn.to_q(x) + key = attn.to_k(x) + value = attn.to_v(x) + + # `context` projections. 
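+        # project the text/context stream with its own to_*_c heads; the q/k/v of
+        # both streams are concatenated along the sequence dimension below for
+        # joint (MM-DiT style) attention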
+ c_query = attn.to_q_c(c) + c_key = attn.to_k_c(c) + c_value = attn.to_v_c(c) + + # apply rope for context and noised input independently + if rope is not None: + freqs, xpos_scale = rope + q_xpos_scale, k_xpos_scale = ( + (xpos_scale, xpos_scale**-1.0) if xpos_scale is not None else (1.0, 1.0) + ) + query = apply_rotary_pos_emb(query, freqs, q_xpos_scale) + key = apply_rotary_pos_emb(key, freqs, k_xpos_scale) + if c_rope is not None: + freqs, xpos_scale = c_rope + q_xpos_scale, k_xpos_scale = ( + (xpos_scale, xpos_scale**-1.0) if xpos_scale is not None else (1.0, 1.0) + ) + c_query = apply_rotary_pos_emb(c_query, freqs, q_xpos_scale) + c_key = apply_rotary_pos_emb(c_key, freqs, k_xpos_scale) + + # attention + query = torch.cat([query, c_query], dim=1) + key = torch.cat([key, c_key], dim=1) + value = torch.cat([value, c_value], dim=1) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + # mask. e.g. inference got a batch with different target durations, mask out the padding + if mask is not None: + attn_mask = F.pad(mask, (0, c.shape[1]), value=True) # no mask for c (text) + attn_mask = attn_mask.unsqueeze(1).unsqueeze(1) # 'b n -> b 1 1 n' + attn_mask = attn_mask.expand(batch_size, attn.heads, query.shape[-2], key.shape[-2]) + else: + attn_mask = None + + x = F.scaled_dot_product_attention( + query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False + ) + x = x.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + x = x.to(query.dtype) + + # Split the attention outputs. + x, c = ( + x[:, : residual.shape[1]], + x[:, residual.shape[1] :], + ) + + # linear proj + x = attn.to_out[0](x) + # dropout + x = attn.to_out[1](x) + if not attn.context_pre_only: + c = attn.to_out_c(c) + + if mask is not None: + mask = mask.unsqueeze(-1) + x = x.masked_fill(~mask, 0.0) + # c = c.masked_fill(~mask, 0.) 
# no mask for c (text) + + return x, c + + +# DiT Block + + +class DiTBlock(nn.Module): + def __init__(self, dim, heads, dim_head, ff_mult=4, dropout=0.1, attn_processor=""): + super().__init__() + self.attn_norm = AdaLayerNormZero(dim) + # import pdb;pdb.set_trace() + # print(attn_processor) + if attn_processor == "block_attn": + processor = BlockSelfAttnProcessor() + elif attn_processor == "block_attn_50": + processor = BlockSelfAttnProcessor(block_size=50) + elif attn_processor == "stream_block_0": + processor = SteamingAttnProcessor(look_ahead_block=0, look_backward_block=0) + elif attn_processor == "stream_block_1": + processor = SteamingAttnProcessor(look_ahead_block=1, look_backward_block=1) + elif attn_processor == "stream_block_sr_00": + processor = SteamingAttnProcessor( + block_size=24, look_ahead_block=0, look_backward_block=0 + ) + elif attn_processor == "stream_block_sr_01": + processor = SteamingAttnProcessor( + block_size=24, look_ahead_block=1, look_backward_block=0 + ) + elif attn_processor == "stream_block_sr_10": + processor = SteamingAttnProcessor( + block_size=24, look_ahead_block=0, look_backward_block=1 + ) + elif attn_processor == "stream_block_sr_11": + processor = SteamingAttnProcessor( + block_size=24, look_ahead_block=1, look_backward_block=1 + ) + elif attn_processor == "g_stream_block_sr_00": + processor = GlobalStreamSelfAttnProcessor( + block_size=24, look_ahead_block=0, look_backward_block=0 + ) + elif attn_processor == "g_stream_block_sr_01": + processor = GlobalStreamSelfAttnProcessor( + block_size=24, look_ahead_block=1, look_backward_block=0 + ) + elif attn_processor == "g_stream_block_sr_10": + processor = GlobalStreamSelfAttnProcessor( + block_size=24, look_ahead_block=0, look_backward_block=1 + ) + elif attn_processor == "stream_block_8_00": + processor = SteamingAttnProcessor( + block_size=32, look_ahead_block=0, look_backward_block=0 + ) + elif attn_processor == "stream_block_8_01": + processor = SteamingAttnProcessor( + block_size=32, look_ahead_block=1, look_backward_block=0 + ) + elif attn_processor == "stream_block_8_10": + processor = SteamingAttnProcessor( + block_size=32, look_ahead_block=0, look_backward_block=1 + ) + + else: + processor = AttnProcessor() + self.attn = Attention( + processor=processor, + dim=dim, + heads=heads, + dim_head=dim_head, + dropout=dropout, + ) + + self.ff_norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) + self.ff = FeedForward(dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh") + + def forward(self, x, t, mask=None, rope=None): # x: noised input, t: time embedding + # pre-norm & modulation for attention input + norm, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.attn_norm(x, emb=t) + + # attention + attn_output = self.attn(x=norm, mask=mask, rope=rope) + + # process attention output for input x + x = x + gate_msa.unsqueeze(1) * attn_output + + norm = self.ff_norm(x) * (1 + scale_mlp[:, None]) + shift_mlp[:, None] + ff_output = self.ff(norm) + x = x + gate_mlp.unsqueeze(1) * ff_output + + return x + + +# MMDiT Block https://arxiv.org/abs/2403.03206 + + +class MMDiTBlock(nn.Module): + r""" + modified from diffusers/src/diffusers/models/attention.py + + notes. + _c: context related. text, cond, etc. (left part in sd3 fig2.b) + _x: noised input related. 
(right part) + context_pre_only: last layer only do prenorm + modulation cuz no more ffn + """ + + def __init__(self, dim, heads, dim_head, ff_mult=4, dropout=0.1, context_pre_only=False): + super().__init__() + + self.context_pre_only = context_pre_only + + self.attn_norm_c = ( + AdaLayerNormZero_Final(dim) if context_pre_only else AdaLayerNormZero(dim) + ) + self.attn_norm_x = AdaLayerNormZero(dim) + self.attn = Attention( + processor=JointAttnProcessor(), + dim=dim, + heads=heads, + dim_head=dim_head, + dropout=dropout, + context_dim=dim, + context_pre_only=context_pre_only, + ) + + if not context_pre_only: + self.ff_norm_c = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) + self.ff_c = FeedForward(dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh") + else: + self.ff_norm_c = None + self.ff_c = None + self.ff_norm_x = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) + self.ff_x = FeedForward(dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh") + + def forward( + self, x, c, t, mask=None, rope=None, c_rope=None + ): # x: noised input, c: context, t: time embedding + # pre-norm & modulation for attention input + if self.context_pre_only: + norm_c = self.attn_norm_c(c, t) + else: + norm_c, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.attn_norm_c(c, emb=t) + norm_x, x_gate_msa, x_shift_mlp, x_scale_mlp, x_gate_mlp = self.attn_norm_x(x, emb=t) + + # attention + x_attn_output, c_attn_output = self.attn( + x=norm_x, c=norm_c, mask=mask, rope=rope, c_rope=c_rope + ) + + # process attention output for context c + if self.context_pre_only: + c = None + else: # if not last layer + c = c + c_gate_msa.unsqueeze(1) * c_attn_output + + norm_c = self.ff_norm_c(c) * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None] + c_ff_output = self.ff_c(norm_c) + c = c + c_gate_mlp.unsqueeze(1) * c_ff_output + + # process attention output for input x + x = x + x_gate_msa.unsqueeze(1) * x_attn_output + + norm_x = self.ff_norm_x(x) * (1 + x_scale_mlp[:, None]) + x_shift_mlp[:, None] + x_ff_output = self.ff_x(norm_x) + x = x + x_gate_mlp.unsqueeze(1) * x_ff_output + + return c, x + + +# time step conditioning embedding + + +class TimestepEmbedding(nn.Module): + def __init__(self, dim, freq_embed_dim=256): + super().__init__() + self.time_embed = SinusPositionEmbedding(freq_embed_dim) + self.time_mlp = nn.Sequential( + nn.Linear(freq_embed_dim, dim), nn.SiLU(), nn.Linear(dim, dim) + ) + + def forward(self, timestep: float["b"]): # noqa: F821 + time_hidden = self.time_embed(timestep) + time_hidden = time_hidden.to(timestep.dtype) + time = self.time_mlp(time_hidden) # b d + return time diff --git a/src/thirdparty/qwen2_code2wav/model/modules.py b/src/thirdparty/qwen2_code2wav/model/modules.py new file mode 100644 index 00000000..a860c91b --- /dev/null +++ b/src/thirdparty/qwen2_code2wav/model/modules.py @@ -0,0 +1,962 @@ +""" +ein notation: +b - batch +n - sequence +nt - text sequence +nw - raw wave length +d - dimension +""" + +from __future__ import annotations +from typing import Optional +import math + +import torch +from torch import nn +import torch.nn.functional as F +import torchaudio + +from x_transformers.x_transformers import apply_rotary_pos_emb + + +# raw wav to mel spec + + +class MelSpec(nn.Module): + def __init__( + self, + filter_length=1024, + hop_length=256, + win_length=1024, + n_mel_channels=100, + target_sample_rate=24_000, + normalize=False, + power=1, + norm=None, + center=True, + ): + super().__init__() + self.n_mel_channels = n_mel_channels + + self.mel_stft = 
torchaudio.transforms.MelSpectrogram( + sample_rate=target_sample_rate, + n_fft=filter_length, + win_length=win_length, + hop_length=hop_length, + n_mels=n_mel_channels, + power=power, + center=center, + normalized=normalize, + norm=norm, + ) + + self.register_buffer("dummy", torch.tensor(0), persistent=False) + + def forward(self, inp): + if len(inp.shape) == 3: + inp = inp.squeeze(1) # 'b 1 nw -> b nw' + + assert len(inp.shape) == 2 + + if self.dummy.device != inp.device: + self.to(inp.device) + + mel = self.mel_stft(inp) + mel = mel.clamp(min=1e-5).log() + return mel + + +# sinusoidal position embedding + + +class SinusPositionEmbedding(nn.Module): + def __init__(self, dim): + super().__init__() + self.dim = dim + + def forward(self, x, scale=1000): + device = x.device + half_dim = self.dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, device=device).float() * -emb) + emb = scale * x.unsqueeze(1) * emb.unsqueeze(0) + emb = torch.cat((emb.sin(), emb.cos()), dim=-1) + return emb + + +# convolutional position embedding + + +class ConvPositionEmbedding(nn.Module): + def __init__(self, dim, kernel_size=31, groups=16): + super().__init__() + assert kernel_size % 2 != 0 + self.conv1d = nn.Sequential( + nn.Conv1d(dim, dim, kernel_size, groups=groups, padding=kernel_size // 2), + nn.Mish(), + nn.Conv1d(dim, dim, kernel_size, groups=groups, padding=kernel_size // 2), + nn.Mish(), + ) + + def forward(self, x: float["b n d"], mask: bool["b n"] | None = None): # noqa: F722 + if mask is not None: + mask = mask[..., None] + x = x.masked_fill(~mask, 0.0) + + x = x.permute(0, 2, 1) + x = self.conv1d(x) + out = x.permute(0, 2, 1) + + if mask is not None: + out = out.masked_fill(~mask, 0.0) + + return out + + +# rotary positional embedding related + + +def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, theta_rescale_factor=1.0): + # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning + # has some connection to NTK literature + # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/ + # https://github.com/lucidrains/rotary-embedding-torch/blob/main/rotary_embedding_torch/rotary_embedding_torch.py + theta *= theta_rescale_factor ** (dim / (dim - 2)) + freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) + t = torch.arange(end, device=freqs.device) # type: ignore + freqs = torch.outer(t, freqs).float() # type: ignore + freqs_cos = torch.cos(freqs) # real part + freqs_sin = torch.sin(freqs) # imaginary part + return torch.cat([freqs_cos, freqs_sin], dim=-1) + + +def get_pos_embed_indices(start, length, max_pos, scale=1.0): + # length = length if isinstance(length, int) else length.max() + scale = scale * torch.ones_like(start, dtype=torch.float32) # in case scale is a scalar + pos = ( + start.unsqueeze(1) + + ( + torch.arange(length, device=start.device, dtype=torch.float32).unsqueeze(0) + * scale.unsqueeze(1) + ).long() + ) + # avoid extra long error. + pos = torch.where(pos < max_pos, pos, max_pos - 1) + return pos + + +# Global Response Normalization layer (Instance Normalization ?) 
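+
+# A minimal sketch of the computation implemented by GRN below (ConvNeXt-V2 style
+# global response normalization), for input x of shape (b, n, d):
+#   Gx  = ||x||_2 over the sequence dim n                   -> (b, 1, d)
+#   Nx  = Gx / (mean_d(Gx) + 1e-6)                           # normalize across features
+#   out = gamma * (x * Nx) + beta + x                        # learnable affine + residual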
+ + +class GRN(nn.Module): + def __init__(self, dim): + super().__init__() + self.gamma = nn.Parameter(torch.zeros(1, 1, dim)) + self.beta = nn.Parameter(torch.zeros(1, 1, dim)) + + def forward(self, x): + Gx = torch.norm(x, p=2, dim=1, keepdim=True) + Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6) + return self.gamma * (x * Nx) + self.beta + x + + +# ConvNeXt-V2 Block https://github.com/facebookresearch/ConvNeXt-V2/blob/main/models/convnextv2.py +# ref: https://github.com/bfs18/e2_tts/blob/main/rfwave/modules.py#L108 + + +class ConvNeXtV2Block(nn.Module): + def __init__( + self, + dim: int, + intermediate_dim: int, + dilation: int = 1, + ): + super().__init__() + padding = (dilation * (7 - 1)) // 2 + self.dwconv = nn.Conv1d( + dim, dim, kernel_size=7, padding=padding, groups=dim, dilation=dilation + ) # depthwise conv + self.norm = nn.LayerNorm(dim, eps=1e-6) + self.pwconv1 = nn.Linear( + dim, intermediate_dim + ) # pointwise/1x1 convs, implemented with linear layers + self.act = nn.GELU() + self.grn = GRN(intermediate_dim) + self.pwconv2 = nn.Linear(intermediate_dim, dim) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + residual = x + x = x.transpose(1, 2) # b n d -> b d n + x = self.dwconv(x) + x = x.transpose(1, 2) # b d n -> b n d + x = self.norm(x) + x = self.pwconv1(x) + x = self.act(x) + x = self.grn(x) + x = self.pwconv2(x) + return residual + x + + +# AdaLayerNormZero +# return with modulated x for attn input, and params for later mlp modulation + + +class AdaLayerNormZero(nn.Module): + def __init__(self, dim): + super().__init__() + + self.silu = nn.SiLU() + self.linear = nn.Linear(dim, dim * 6) + + self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) + + def forward(self, x, emb=None): + emb = self.linear(self.silu(emb)) + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = torch.chunk(emb, 6, dim=1) + + x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None] + return x, gate_msa, shift_mlp, scale_mlp, gate_mlp + + +# AdaLayerNormZero for final layer +# return only with modulated x for attn input, cuz no more mlp modulation + + +class AdaLayerNormZero_Final(nn.Module): + def __init__(self, dim): + super().__init__() + + self.silu = nn.SiLU() + self.linear = nn.Linear(dim, dim * 2) + + self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) + + def forward(self, x, emb): + emb = self.linear(self.silu(emb)) + scale, shift = torch.chunk(emb, 2, dim=1) + + x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :] + return x + + +# FeedForward + + +class FeedForward(nn.Module): + def __init__(self, dim, dim_out=None, mult=4, dropout=0.0, approximate: str = "none"): + super().__init__() + inner_dim = int(dim * mult) + dim_out = dim_out if dim_out is not None else dim + + activation = nn.GELU(approximate=approximate) + project_in = nn.Sequential(nn.Linear(dim, inner_dim), activation) + self.ff = nn.Sequential(project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)) + + def forward(self, x): + for i in range(len(self.ff)): + tmp = self.ff[i](x) + x = tmp + return x + # return self.ff(x) + + +# Attention with possible joint part +# modified from diffusers/src/diffusers/models/attention_processor.py + + +class Attention(nn.Module): + def __init__( + self, + processor: JointAttnProcessor + | AttnProcessor + | BlockSelfAttnProcessor + | SteamingAttnProcessor, + dim: int, + heads: int = 8, + dim_head: int = 64, + dropout: float = 0.0, + context_dim: Optional[int] = None, # if not None -> joint attention + 
context_pre_only=None, + ): + super().__init__() + + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError( + "Attention equires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0." + ) + self.processor = processor + + self.dim = dim + self.heads = heads + self.inner_dim = dim_head * heads + self.dropout = dropout + + self.context_dim = context_dim + self.context_pre_only = context_pre_only + + self.to_q = nn.Linear(dim, self.inner_dim) + self.to_k = nn.Linear(dim, self.inner_dim) + self.to_v = nn.Linear(dim, self.inner_dim) + + if self.context_dim is not None: + self.to_k_c = nn.Linear(context_dim, self.inner_dim) + self.to_v_c = nn.Linear(context_dim, self.inner_dim) + if self.context_pre_only is not None: + self.to_q_c = nn.Linear(context_dim, self.inner_dim) + + self.to_out = nn.ModuleList([]) + self.to_out.append(nn.Linear(self.inner_dim, dim)) + self.to_out.append(nn.Dropout(dropout)) + + if self.context_pre_only is not None and not self.context_pre_only: + self.to_out_c = nn.Linear(self.inner_dim, dim) + + def forward( + self, + x: float["b n d"], # noised input x # noqa: F722 + c: float["b n d"] = None, # context c # noqa: F722 + mask: bool["b n"] | None = None, # noqa: F722 + rope=None, # rotary position embedding for x + c_rope=None, # rotary position embedding for c + ) -> torch.Tensor: + if c is not None: + return self.processor(self, x, c=c, mask=mask, rope=rope, c_rope=c_rope) + else: + return self.processor(self, x, mask=mask, rope=rope) + + +# Attention processor + + +class AttnProcessor: + def __init__(self): + pass + + def __call__( + self, + attn: Attention, + x: float["b n d"], # noised input x # noqa: F722 + mask: bool["b n"] | None = None, # noqa: F722 + rope=None, # rotary position embedding + ) -> torch.FloatTensor: + batch_size = x.shape[0] + + # `sample` projections. + query = attn.to_q(x) + key = attn.to_k(x) + value = attn.to_v(x) + + # apply rotary position embedding + if rope is not None: + freqs, xpos_scale = rope + q_xpos_scale, k_xpos_scale = ( + (xpos_scale, xpos_scale**-1.0) if xpos_scale is not None else (1.0, 1.0) + ) + + query = apply_rotary_pos_emb(query, freqs, q_xpos_scale) + key = apply_rotary_pos_emb(key, freqs, k_xpos_scale) + + # attention + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + # mask. e.g. 
inference got a batch with different target durations, mask out the padding + if mask is not None: + attn_mask = mask + attn_mask = attn_mask.unsqueeze(1).unsqueeze(1) # 'b n -> b 1 1 n' + attn_mask = attn_mask.expand(batch_size, attn.heads, query.shape[-2], key.shape[-2]) + else: + attn_mask = None + + x = F.scaled_dot_product_attention( + query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False + ) + x = x.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + x = x.to(query.dtype) + + # linear proj + x = attn.to_out[0](x) + # dropout + x = attn.to_out[1](x) + + if mask is not None: + mask = mask.unsqueeze(-1) + x = x.masked_fill(~mask, 0.0) + + return x + + +class GlobalStreamSelfAttnProcessor: + def __init__(self, block_size=50, global_size=200, look_ahead_block=0, look_backward_block=0): + self.block_size = block_size + self.global_size = global_size + self.look_ahead_block = look_ahead_block + self.look_backward_block = look_backward_block + + def create_global_block_causal_mask(self, seq_length, device): + # 为每个 token 分配块索引 + positions = torch.arange(seq_length, device=device) # [seq_length] + + if self.global_size > 0: + # 分配块索引:全局块为 0,其余块从 1 开始 + block_indices = torch.where( + positions < self.global_size, + torch.zeros_like(positions), + 1 + (positions - self.global_size) // self.block_size, + ) + else: + # 没有全局块,所有块按普通块分配 + block_indices = positions // self.block_size # [seq_length] + + # 拓展维度以进行块间比较 + # block_i: [seq_length, 1], block_j: [1, seq_length] + block_i = block_indices.unsqueeze(1) # [seq_length, 1] + block_j = block_indices.unsqueeze(0) # [1, seq_length] + + # 计算块索引的差值 + block_diff = block_j - block_i # 形状为 (seq_length, seq_length) + + # 定义允许的块范围 + mask = (block_diff >= -float(self.look_backward_block)) & ( + block_diff <= float(self.look_ahead_block) + ) + + if self.global_size > 0: + # 确定哪些位置属于全局块 + is_global_i = block_i == 0 # [seq_length, 1] + is_global_j = block_j == 0 # [1, seq_length] + + # 所有块都可以看到全局块 + # 这意味着如果 j 是全局块,mask[i, j] = True + mask = (mask | is_global_j) & (~is_global_i | is_global_j) + + # 全局块可以看到所有块 + # 这意味着如果 i 是全局块,mask[i, j] = True + # mask = mask | is_global_i + + # 调整掩码形状为 [1, 1, seq_length, seq_length] 以适配多头注意力 + mask = mask.unsqueeze(0).unsqueeze(0) # [1, 1, seq_length, seq_length] + + return mask + + def __call__( + self, + attn: Attention, + x: float["b n d"], # noised input x # noqa: F722 + mask: bool["b n"] | None = None, # noqa: F722 + rope=None, # rotary position embedding + ) -> torch.FloatTensor: + batch_size = x.shape[0] + + # `sample` projections. + query = attn.to_q(x) + key = attn.to_k(x) + value = attn.to_v(x) + + # apply rotary position embedding + if rope is not None: + freqs, xpos_scale = rope + q_xpos_scale, k_xpos_scale = ( + (xpos_scale, xpos_scale**-1.0) if xpos_scale is not None else (1.0, 1.0) + ) + + query = apply_rotary_pos_emb(query, freqs, q_xpos_scale) + key = apply_rotary_pos_emb(key, freqs, k_xpos_scale) + + # attention + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + assert ( + query.shape[2] == key.shape[2] and query.shape[2] == value.shape[2] + ), f"block self attn require self attn but get query_shape: {query.shape}, key_shape: {key.shape}, value_shape: {value.shape}" + # mask. e.g. 
inference got a batch with different target durations, mask out the padding + block_attn_mask = self.create_global_block_causal_mask( + seq_length=query.shape[2], device=query.device + ) + block_attn_mask = block_attn_mask.expand( + batch_size, attn.heads, query.shape[-2], key.shape[-2] + ) + if mask is not None: + attn_mask = mask + attn_mask = attn_mask.unsqueeze(1).unsqueeze(1) # 'b n -> b 1 1 n' + attn_mask = attn_mask.expand(batch_size, attn.heads, query.shape[-2], key.shape[-2]) + attn_mask = torch.logical_and(block_attn_mask, attn_mask) + else: + attn_mask = block_attn_mask + x = F.scaled_dot_product_attention( + query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False + ) + x = x.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + x = x.to(query.dtype) + + # linear proj + x = attn.to_out[0](x) + # dropout + x = attn.to_out[1](x) + + if mask is not None: + mask = mask.unsqueeze(-1) + x = x.masked_fill(~mask, 0.0) + + return x + + +class BlockSelfAttnProcessor: + def __init__(self, block_size=100): + self.block_size = block_size + + def create_block_causal_mask(self, seq_length, device=None): + """ + 创建 Block Causal Attention 的 attention mask。 + + 参数: + - seq_length (int): 序列的长度。 + - device (torch.device, optional): 设备类型。 + + 返回: + - mask (torch.Tensor): attention mask,形状为 [1, 1, seq_length, seq_length]。 + """ + # 为每个 token 分配块索引 + block_indices = torch.arange(seq_length, device=device) // self.block_size # [seq_length] + # 拓展维度以进行块间比较 + # block_i: [seq_length, 1], block_j: [1, seq_length] + block_i = block_indices.unsqueeze(1) # [seq_length, 1] + block_j = block_indices.unsqueeze(0) # [1, seq_length] + + # 创建 mask:如果 j 所属块 <= i 所属块,则允许关注 + mask = block_j <= block_i # [seq_length, seq_length] + mask = mask.unsqueeze(0).unsqueeze(0) # [1, 1, seq_length, seq_length] + return mask + + def __call__( + self, + attn: Attention, + x: float["b n d"], # noised input x # noqa: F722 + mask: bool["b n"] | None = None, # noqa: F722 + rope=None, # rotary position embedding + ) -> torch.FloatTensor: + batch_size = x.shape[0] + + # `sample` projections. + query = attn.to_q(x) + key = attn.to_k(x) + value = attn.to_v(x) + + # apply rotary position embedding + if rope is not None: + freqs, xpos_scale = rope + q_xpos_scale, k_xpos_scale = ( + (xpos_scale, xpos_scale**-1.0) if xpos_scale is not None else (1.0, 1.0) + ) + + query = apply_rotary_pos_emb(query, freqs, q_xpos_scale) + key = apply_rotary_pos_emb(key, freqs, k_xpos_scale) + + # attention + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + assert ( + query.shape[2] == key.shape[2] and query.shape[2] == value.shape[2] + ), f"block self attn require self attn but get query_shape: {query.shape}, key_shape: {key.shape}, value_shape: {value.shape}" + # mask. e.g. 
+        block_attn_mask = self.create_block_causal_mask(
+            seq_length=query.shape[2], device=query.device
+        )
+        block_attn_mask = block_attn_mask.expand(
+            batch_size, attn.heads, query.shape[-2], key.shape[-2]
+        )
+        if mask is not None:
+            attn_mask = mask
+            attn_mask = attn_mask.unsqueeze(1).unsqueeze(1)  # 'b n -> b 1 1 n'
+            attn_mask = attn_mask.expand(batch_size, attn.heads, query.shape[-2], key.shape[-2])
+            attn_mask = torch.logical_and(block_attn_mask, attn_mask)
+        else:
+            attn_mask = block_attn_mask
+        x = F.scaled_dot_product_attention(
+            query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False
+        )
+        x = x.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+        x = x.to(query.dtype)
+
+        # linear proj
+        x = attn.to_out[0](x)
+        # dropout
+        x = attn.to_out[1](x)
+
+        if mask is not None:
+            mask = mask.unsqueeze(-1)
+            x = x.masked_fill(~mask, 0.0)
+
+        return x
+
+
+class SteamingAttnProcessor:
+    def __init__(self, block_size=40, look_ahead_block=1, look_backward_block=1):
+        self.block_size = block_size
+        self.look_ahead_block = look_ahead_block
+        self.look_backward_block = look_backward_block
+
+    def create_block_mask(self, seq_length, device=None):
+        """
+        Create the block-wise (streaming) attention mask.
+
+        Arguments:
+        - seq_length (int): length of the sequence.
+        - device (torch.device, optional): device on which to create the mask.
+
+        The block size and the look-ahead / look-backward window sizes are taken from
+        the processor's attributes.
+
+        Returns:
+        - mask (torch.Tensor): attention mask of shape [1, 1, seq_length, seq_length].
+        """
+        # assign a block index to each token
+        block_indices = torch.arange(seq_length, device=device) // self.block_size  # [seq_length]
+        # expand dims for block-to-block comparison
+        # block_i: [seq_length, 1], block_j: [1, seq_length]
+        block_i = block_indices.unsqueeze(1)  # [seq_length, 1]
+        block_j = block_indices.unsqueeze(0)  # [1, seq_length]
+        # difference between block indices
+        block_diff = block_j - block_i  # shape: (n, n)
+
+        # allowed block range (look-backward / look-ahead window)
+        mask = (block_diff >= -float(self.look_backward_block)) & (
+            block_diff <= float(self.look_ahead_block)
+        )
+        # block-causal variant: allow attention only if j's block <= i's block
+        # mask = block_j <= block_i  # [seq_length, seq_length]
+        mask = mask.unsqueeze(0).unsqueeze(0)  # [1, 1, seq_length, seq_length]
+        return mask
+
+    def __call__(
+        self,
+        attn: Attention,
+        x: float["b n d"],  # noised input x  # noqa: F722
+        mask: bool["b n"] | None = None,  # noqa: F722
+        rope=None,  # rotary position embedding
+    ) -> torch.FloatTensor:
+        batch_size = x.shape[0]
+
+        # `sample` projections.
+        query = attn.to_q(x)
+        key = attn.to_k(x)
+        value = attn.to_v(x)
+
+        # apply rotary position embedding
+        if rope is not None:
+            freqs, xpos_scale = rope
+            q_xpos_scale, k_xpos_scale = (
+                (xpos_scale, xpos_scale**-1.0) if xpos_scale is not None else (1.0, 1.0)
+            )
+
+            query = apply_rotary_pos_emb(query, freqs, q_xpos_scale)
+            key = apply_rotary_pos_emb(key, freqs, k_xpos_scale)
+
+        # attention
+        inner_dim = key.shape[-1]
+        head_dim = inner_dim // attn.heads
+        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        assert (
+            query.shape[2] == key.shape[2] and query.shape[2] == value.shape[2]
+        ), f"block self-attention requires query/key/value with the same sequence length, but got query_shape: {query.shape}, key_shape: {key.shape}, value_shape: {value.shape}"
+        # mask. e.g.
inference got a batch with different target durations, mask out the padding + block_attn_mask = self.create_block_mask(seq_length=query.shape[2], device=query.device) + block_attn_mask = block_attn_mask.expand( + batch_size, attn.heads, query.shape[-2], key.shape[-2] + ) + if mask is not None: + attn_mask = mask + attn_mask = attn_mask.unsqueeze(1).unsqueeze(1) # 'b n -> b 1 1 n' + attn_mask = attn_mask.expand(batch_size, attn.heads, query.shape[-2], key.shape[-2]) + attn_mask = torch.logical_and(block_attn_mask, attn_mask) + else: + attn_mask = block_attn_mask + x = F.scaled_dot_product_attention( + query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False + ) + x = x.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + x = x.to(query.dtype) + + # linear proj + x = attn.to_out[0](x) + # dropout + x = attn.to_out[1](x) + + if mask is not None: + mask = mask.unsqueeze(-1) + x = x.masked_fill(~mask, 0.0) + + return x + + +# Joint Attention processor for MM-DiT +# modified from diffusers/src/diffusers/models/attention_processor.py + + +class JointAttnProcessor: + def __init__(self): + pass + + def __call__( + self, + attn: Attention, + x: float["b n d"], # noised input x # noqa: F722 + c: float["b nt d"] = None, # context c, here text # noqa: F722 + mask: bool["b n"] | None = None, # noqa: F722 + rope=None, # rotary position embedding for x + c_rope=None, # rotary position embedding for c + ) -> torch.FloatTensor: + residual = x + + batch_size = c.shape[0] + + # `sample` projections. + query = attn.to_q(x) + key = attn.to_k(x) + value = attn.to_v(x) + + # `context` projections. + c_query = attn.to_q_c(c) + c_key = attn.to_k_c(c) + c_value = attn.to_v_c(c) + + # apply rope for context and noised input independently + if rope is not None: + freqs, xpos_scale = rope + q_xpos_scale, k_xpos_scale = ( + (xpos_scale, xpos_scale**-1.0) if xpos_scale is not None else (1.0, 1.0) + ) + query = apply_rotary_pos_emb(query, freqs, q_xpos_scale) + key = apply_rotary_pos_emb(key, freqs, k_xpos_scale) + if c_rope is not None: + freqs, xpos_scale = c_rope + q_xpos_scale, k_xpos_scale = ( + (xpos_scale, xpos_scale**-1.0) if xpos_scale is not None else (1.0, 1.0) + ) + c_query = apply_rotary_pos_emb(c_query, freqs, q_xpos_scale) + c_key = apply_rotary_pos_emb(c_key, freqs, k_xpos_scale) + + # attention + query = torch.cat([query, c_query], dim=1) + key = torch.cat([key, c_key], dim=1) + value = torch.cat([value, c_value], dim=1) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + # mask. e.g. inference got a batch with different target durations, mask out the padding + if mask is not None: + attn_mask = F.pad(mask, (0, c.shape[1]), value=True) # no mask for c (text) + attn_mask = attn_mask.unsqueeze(1).unsqueeze(1) # 'b n -> b 1 1 n' + attn_mask = attn_mask.expand(batch_size, attn.heads, query.shape[-2], key.shape[-2]) + else: + attn_mask = None + + x = F.scaled_dot_product_attention( + query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False + ) + x = x.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + x = x.to(query.dtype) + + # Split the attention outputs. 
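+        # the attention ran over the concatenation [x ; c] along the sequence dim, so the
+        # first residual.shape[1] positions belong to the noised input x and the rest to the context c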
+ x, c = ( + x[:, : residual.shape[1]], + x[:, residual.shape[1] :], + ) + + # linear proj + x = attn.to_out[0](x) + # dropout + x = attn.to_out[1](x) + if not attn.context_pre_only: + c = attn.to_out_c(c) + + if mask is not None: + mask = mask.unsqueeze(-1) + x = x.masked_fill(~mask, 0.0) + # c = c.masked_fill(~mask, 0.) # no mask for c (text) + + return x, c + + +# DiT Block + + +class DiTBlock(nn.Module): + def __init__(self, dim, heads, dim_head, ff_mult=4, dropout=0.1, attn_processor=""): + super().__init__() + self.attn_norm = AdaLayerNormZero(dim) + if attn_processor == "block_attn": + processor = BlockSelfAttnProcessor() + elif attn_processor == "block_attn_50": + processor = BlockSelfAttnProcessor(block_size=50) + elif attn_processor == "stream_block_0": + processor = SteamingAttnProcessor(look_ahead_block=0, look_backward_block=0) + elif attn_processor == "stream_block_1": + processor = SteamingAttnProcessor(look_ahead_block=1, look_backward_block=1) + elif attn_processor == "stream_block_sr_00": + processor = SteamingAttnProcessor( + block_size=24, look_ahead_block=0, look_backward_block=0 + ) + elif attn_processor == "stream_block_sr_01": + processor = SteamingAttnProcessor( + block_size=24, look_ahead_block=1, look_backward_block=0 + ) + elif attn_processor == "stream_block_sr_10": + processor = SteamingAttnProcessor( + block_size=24, look_ahead_block=0, look_backward_block=1 + ) + elif attn_processor == "stream_block_sr_11": + processor = SteamingAttnProcessor( + block_size=24, look_ahead_block=1, look_backward_block=1 + ) + elif attn_processor == "g_stream_block_sr_00": + processor = GlobalStreamSelfAttnProcessor( + block_size=24, look_ahead_block=0, look_backward_block=0 + ) + elif attn_processor == "g_stream_block_sr_01": + processor = GlobalStreamSelfAttnProcessor( + block_size=24, look_ahead_block=1, look_backward_block=0 + ) + elif attn_processor == "g_stream_block_sr_10": + processor = GlobalStreamSelfAttnProcessor( + block_size=24, look_ahead_block=0, look_backward_block=1 + ) + + else: + processor = AttnProcessor() + self.attn = Attention( + processor=processor, + dim=dim, + heads=heads, + dim_head=dim_head, + dropout=dropout, + ) + + self.ff_norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) + self.ff = FeedForward(dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh") + + def forward(self, x, t, mask=None, rope=None): # x: noised input, t: time embedding + # pre-norm & modulation for attention input + norm, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.attn_norm(x, emb=t) + + # attention + attn_output = self.attn(x=norm, mask=mask, rope=rope) + + # process attention output for input x + x = x + gate_msa.unsqueeze(1) * attn_output + + norm = self.ff_norm(x) * (1 + scale_mlp[:, None]) + shift_mlp[:, None] + ff_output = self.ff(norm) + x = x + gate_mlp.unsqueeze(1) * ff_output + + return x + + +# MMDiT Block https://arxiv.org/abs/2403.03206 + + +class MMDiTBlock(nn.Module): + r""" + modified from diffusers/src/diffusers/models/attention.py + + notes. + _c: context related. text, cond, etc. (left part in sd3 fig2.b) + _x: noised input related. 
(right part) + context_pre_only: last layer only do prenorm + modulation cuz no more ffn + """ + + def __init__(self, dim, heads, dim_head, ff_mult=4, dropout=0.1, context_pre_only=False): + super().__init__() + + self.context_pre_only = context_pre_only + + self.attn_norm_c = ( + AdaLayerNormZero_Final(dim) if context_pre_only else AdaLayerNormZero(dim) + ) + self.attn_norm_x = AdaLayerNormZero(dim) + self.attn = Attention( + processor=JointAttnProcessor(), + dim=dim, + heads=heads, + dim_head=dim_head, + dropout=dropout, + context_dim=dim, + context_pre_only=context_pre_only, + ) + + if not context_pre_only: + self.ff_norm_c = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) + self.ff_c = FeedForward(dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh") + else: + self.ff_norm_c = None + self.ff_c = None + self.ff_norm_x = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) + self.ff_x = FeedForward(dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh") + + def forward( + self, x, c, t, mask=None, rope=None, c_rope=None + ): # x: noised input, c: context, t: time embedding + # pre-norm & modulation for attention input + if self.context_pre_only: + norm_c = self.attn_norm_c(c, t) + else: + norm_c, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.attn_norm_c(c, emb=t) + norm_x, x_gate_msa, x_shift_mlp, x_scale_mlp, x_gate_mlp = self.attn_norm_x(x, emb=t) + + # attention + x_attn_output, c_attn_output = self.attn( + x=norm_x, c=norm_c, mask=mask, rope=rope, c_rope=c_rope + ) + + # process attention output for context c + if self.context_pre_only: + c = None + else: # if not last layer + c = c + c_gate_msa.unsqueeze(1) * c_attn_output + + norm_c = self.ff_norm_c(c) * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None] + c_ff_output = self.ff_c(norm_c) + c = c + c_gate_mlp.unsqueeze(1) * c_ff_output + + # process attention output for input x + x = x + x_gate_msa.unsqueeze(1) * x_attn_output + + norm_x = self.ff_norm_x(x) * (1 + x_scale_mlp[:, None]) + x_shift_mlp[:, None] + x_ff_output = self.ff_x(norm_x) + x = x + x_gate_mlp.unsqueeze(1) * x_ff_output + + return c, x + + +# time step conditioning embedding + + +class TimestepEmbedding(nn.Module): + def __init__(self, dim, freq_embed_dim=256): + super().__init__() + self.time_embed = SinusPositionEmbedding(freq_embed_dim) + self.time_mlp = nn.Sequential( + nn.Linear(freq_embed_dim, dim), nn.SiLU(), nn.Linear(dim, dim) + ) + + def forward(self, timestep: float["b"]): # noqa: F821 + time_hidden = self.time_embed(timestep) + time_hidden = time_hidden.to(timestep.dtype) + time = self.time_mlp(time_hidden) # b d + return time diff --git a/src/thirdparty/qwen2_code2wav/model/spk_encoder.py b/src/thirdparty/qwen2_code2wav/model/spk_encoder.py new file mode 100644 index 00000000..f8889181 --- /dev/null +++ b/src/thirdparty/qwen2_code2wav/model/spk_encoder.py @@ -0,0 +1,923 @@ +"""A popular speaker recognition and diarization model. + +Authors + * Hwidong Na 2020 +""" + +import math +import os +import torch # noqa: F401 +import torch.nn as nn +import torch.nn.functional as F + + +def length_to_mask(length, max_len=None, dtype=None, device=None): + """Creates a binary mask for each sequence. + + Reference: https://discuss.pytorch.org/t/how-to-generate-variable-length-mask/23397/3 + + Arguments + --------- + length : torch.LongTensor + Containing the length of each sequence in the batch. Must be 1D. + max_len : int + Max length for the mask, also the size of the second dimension. 
+ dtype : torch.dtype, default: None + The dtype of the generated mask. + device: torch.device, default: None + The device to put the mask variable. + + Returns + ------- + mask : tensor + The binary mask. + + Example + ------- + >>> length=torch.Tensor([1,2,3]) + >>> mask=length_to_mask(length) + >>> mask + tensor([[1., 0., 0.], + [1., 1., 0.], + [1., 1., 1.]]) + """ + assert len(length.shape) == 1 + + if max_len is None: + max_len = length.max().long().item() # using arange to generate mask + mask = torch.arange(max_len, device=length.device, dtype=length.dtype).expand( + len(length), max_len + ) < length.unsqueeze(1) + + if dtype is None: + dtype = length.dtype + + if device is None: + device = length.device + + mask = torch.as_tensor(mask, dtype=dtype, device=device) + return mask + + +def get_padding_elem(L_in: int, stride: int, kernel_size: int, dilation: int): + """This function computes the number of elements to add for zero-padding. + + Arguments + --------- + L_in : int + stride: int + kernel_size : int + dilation : int + """ + if stride > 1: + n_steps = math.ceil(((L_in - kernel_size * dilation) / stride) + 1) + L_out = stride * (n_steps - 1) + kernel_size * dilation + padding = [kernel_size // 2, kernel_size // 2] + + else: + L_out = (L_in - dilation * (kernel_size - 1) - 1) // stride + 1 + + padding = [(L_in - L_out) // 2, (L_in - L_out) // 2] + return padding + + +class Conv1d(nn.Module): + """This function implements 1d convolution. + + Arguments + --------- + out_channels : int + It is the number of output channels. + kernel_size : int + Kernel size of the convolutional filters. + input_shape : tuple + The shape of the input. Alternatively use ``in_channels``. + in_channels : int + The number of input channels. Alternatively use ``input_shape``. + stride : int + Stride factor of the convolutional filters. When the stride factor > 1, + a decimation in time is performed. + dilation : int + Dilation factor of the convolutional filters. + padding : str + (same, valid, causal). If "valid", no padding is performed. + If "same" and stride is 1, output shape is the same as the input shape. + "causal" results in causal (dilated) convolutions. + padding_mode : str + This flag specifies the type of padding. See torch.nn documentation + for more information. + skip_transpose : bool + If False, uses batch x time x channel convention of speechbrain. + If True, uses batch x channel x time convention. + + Example + ------- + >>> inp_tensor = torch.rand([10, 40, 16]) + >>> cnn_1d = Conv1d( + ... input_shape=inp_tensor.shape, out_channels=8, kernel_size=5 + ... 
) + >>> out_tensor = cnn_1d(inp_tensor) + >>> out_tensor.shape + torch.Size([10, 40, 8]) + """ + + def __init__( + self, + out_channels, + kernel_size, + input_shape=None, + in_channels=None, + stride=1, + dilation=1, + padding="same", + groups=1, + bias=True, + padding_mode="reflect", + skip_transpose=True, + ): + super().__init__() + self.kernel_size = kernel_size + self.stride = stride + self.dilation = dilation + self.padding = padding + self.padding_mode = padding_mode + self.unsqueeze = False + self.skip_transpose = skip_transpose + + if input_shape is None and in_channels is None: + raise ValueError("Must provide one of input_shape or in_channels") + + if in_channels is None: + in_channels = self._check_input_shape(input_shape) + + self.conv = nn.Conv1d( + in_channels, + out_channels, + self.kernel_size, + stride=self.stride, + dilation=self.dilation, + padding=0, + groups=groups, + bias=bias, + ) + + def forward(self, x): + """Returns the output of the convolution. + + Arguments + --------- + x : torch.Tensor (batch, time, channel) + input to convolve. 2d or 4d tensors are expected. + """ + + if not self.skip_transpose: + x = x.transpose(1, -1) + + if self.unsqueeze: + x = x.unsqueeze(1) + + if self.padding == "same": + x = self._manage_padding(x, self.kernel_size, self.dilation, self.stride) + + elif self.padding == "causal": + num_pad = (self.kernel_size - 1) * self.dilation + x = F.pad(x, (num_pad, 0)) + + elif self.padding == "valid": + pass + + else: + raise ValueError("Padding must be 'same', 'valid' or 'causal'. Got " + self.padding) + # print(x.shape) + wx = self.conv(x) + + if self.unsqueeze: + wx = wx.squeeze(1) + + if not self.skip_transpose: + wx = wx.transpose(1, -1) + + return wx + + def _manage_padding( + self, + x, + kernel_size: int, + dilation: int, + stride: int, + ): + """This function performs zero-padding on the time axis + such that their lengths is unchanged after the convolution. + + Arguments + --------- + x : torch.Tensor + Input tensor. + kernel_size : int + Size of kernel. + dilation : int + Dilation used. + stride : int + Stride. + """ + + # Detecting input shape + L_in = x.shape[-1] + + # Time padding + padding = get_padding_elem(L_in, stride, kernel_size, dilation) + + # Applying padding + x = F.pad(x, padding, mode=self.padding_mode) + + return x + + def _check_input_shape(self, shape): + """Checks the input shape and returns the number of input channels.""" + + if len(shape) == 2: + self.unsqueeze = True + in_channels = 1 + elif self.skip_transpose: + in_channels = shape[1] + elif len(shape) == 3: + in_channels = shape[2] + else: + raise ValueError("conv1d expects 2d, 3d inputs. Got " + str(len(shape))) + + # Kernel size must be odd + if self.kernel_size % 2 == 0: + raise ValueError( + "The field kernel size must be an odd number. Got %s." 
% (self.kernel_size) + ) + return in_channels + + +class Fp32BatchNorm(nn.Module): + def __init__(self, sync=True, *args, **kwargs): + super().__init__() + + if not torch.distributed.is_initialized() or torch.distributed.get_world_size() == 1: + sync = False + + if sync: + self.bn = nn.SyncBatchNorm(*args, **kwargs) + else: + self.bn = nn.BatchNorm1d(*args, **kwargs) + + self.sync = sync + + def forward(self, input): + if self.bn.running_mean.dtype != torch.float: + if self.sync: + self.bn.running_mean = self.bn.running_mean.float() + self.bn.running_var = self.bn.running_var.float() + if self.bn.affine: + try: + self.bn.weight = self.bn.weight.float() + self.bn.bias = self.bn.bias.float() + except: + self.bn.float() + else: + self.bn.float() + + output = self.bn(input.float()) + return output.type_as(input) + + +class BatchNorm1d(nn.Module): + """Applies 1d batch normalization to the input tensor. + + Arguments + --------- + input_shape : tuple + The expected shape of the input. Alternatively, use ``input_size``. + input_size : int + The expected size of the input. Alternatively, use ``input_shape``. + eps : float + This value is added to std deviation estimation to improve the numerical + stability. + momentum : float + It is a value used for the running_mean and running_var computation. + affine : bool + When set to True, the affine parameters are learned. + track_running_stats : bool + When set to True, this module tracks the running mean and variance, + and when set to False, this module does not track such statistics. + combine_batch_time : bool + When true, it combines batch an time axis. + + + Example + ------- + >>> input = torch.randn(100, 10) + >>> norm = BatchNorm1d(input_shape=input.shape) + >>> output = norm(input) + >>> output.shape + torch.Size([100, 10]) + """ + + def __init__( + self, + input_shape=None, + input_size=None, + eps=1e-05, + momentum=0.1, + affine=True, + track_running_stats=True, + combine_batch_time=False, + skip_transpose=True, + enabled=True, + ): + super().__init__() + self.combine_batch_time = combine_batch_time + self.skip_transpose = skip_transpose + + if input_size is None and skip_transpose: + input_size = input_shape[1] + elif input_size is None: + input_size = input_shape[-1] + + if enabled: + self.norm = Fp32BatchNorm( + num_features=input_size, + eps=eps, + momentum=momentum, + affine=affine, + track_running_stats=track_running_stats, + ) + else: + self.norm = nn.Identity() + + def forward(self, x): + """Returns the normalized input tensor. + + Arguments + --------- + x : torch.Tensor (batch, time, [channels]) + input to normalize. 2d or 3d tensors are expected in input + 4d tensors can be used when combine_dims=True. + """ + shape_or = x.shape + if self.combine_batch_time: + if x.ndim == 3: + x = x.reshape(shape_or[0] * shape_or[1], shape_or[2]) + else: + x = x.reshape(shape_or[0] * shape_or[1], shape_or[3], shape_or[2]) + + elif not self.skip_transpose: + x = x.transpose(-1, 1) + + x_n = self.norm(x) + + if self.combine_batch_time: + x_n = x_n.reshape(shape_or) + elif not self.skip_transpose: + x_n = x_n.transpose(1, -1) + + return x_n + + +class Linear(torch.nn.Module): + """Computes a linear transformation y = wx + b. + + Arguments + --------- + n_neurons : int + It is the number of output neurons (i.e, the dimensionality of the + output). + bias : bool + If True, the additive bias b is adopted. + combine_dims : bool + If True and the input is 4D, combine 3rd and 4th dimensions of input. 
+ + Example + ------- + >>> inputs = torch.rand(10, 50, 40) + >>> lin_t = Linear(input_shape=(10, 50, 40), n_neurons=100) + >>> output = lin_t(inputs) + >>> output.shape + torch.Size([10, 50, 100]) + """ + + def __init__( + self, + n_neurons, + input_shape=None, + input_size=None, + bias=True, + combine_dims=False, + ): + super().__init__() + self.combine_dims = combine_dims + + if input_shape is None and input_size is None: + raise ValueError("Expected one of input_shape or input_size") + + if input_size is None: + input_size = input_shape[-1] + if len(input_shape) == 4 and self.combine_dims: + input_size = input_shape[2] * input_shape[3] + + # Weights are initialized following pytorch approach + self.w = nn.Linear(input_size, n_neurons, bias=bias) + + def forward(self, x): + """Returns the linear transformation of input tensor. + + Arguments + --------- + x : torch.Tensor + Input to transform linearly. + """ + if x.ndim == 4 and self.combine_dims: + x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]) + + wx = self.w(x) + + return wx + + +class TDNNBlock(nn.Module): + """An implementation of TDNN. + + Arguments + ---------- + in_channels : int + Number of input channels. + out_channels : int + The number of output channels. + kernel_size : int + The kernel size of the TDNN blocks. + dilation : int + The dilation of the Res2Net block. + activation : torch class + A class for constructing the activation layers. + + Example + ------- + >>> inp_tensor = torch.rand([8, 120, 64]).transpose(1, 2) + >>> layer = TDNNBlock(64, 64, kernel_size=3, dilation=1) + >>> out_tensor = layer(inp_tensor).transpose(1, 2) + >>> out_tensor.shape + torch.Size([8, 120, 64]) + """ + + def __init__( + self, + in_channels, + out_channels, + kernel_size, + dilation, + activation=nn.ReLU, + batch_norm=True, + ): + super(TDNNBlock, self).__init__() + self.conv = Conv1d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + dilation=dilation, + ) + self.activation = activation() + self.norm = BatchNorm1d(input_size=out_channels, enabled=batch_norm) + + def forward(self, x): + return self.norm(self.activation(self.conv(x))) + + +class Res2NetBlock(torch.nn.Module): + """An implementation of Res2NetBlock w/ dilation. + + Arguments + --------- + in_channels : int + The number of channels expected in the input. + out_channels : int + The number of output channels. + scale : int + The scale of the Res2Net block. + kernel_size: int + The kernel size of the Res2Net block. + dilation : int + The dilation of the Res2Net block. 
+ + Example + ------- + >>> inp_tensor = torch.rand([8, 120, 64]).transpose(1, 2) + >>> layer = Res2NetBlock(64, 64, scale=4, dilation=3) + >>> out_tensor = layer(inp_tensor).transpose(1, 2) + >>> out_tensor.shape + torch.Size([8, 120, 64]) + """ + + def __init__( + self, + in_channels, + out_channels, + scale=8, + kernel_size=3, + dilation=1, + batch_norm=True, + ): + super(Res2NetBlock, self).__init__() + assert in_channels % scale == 0 + assert out_channels % scale == 0 + + in_channel = in_channels // scale + hidden_channel = out_channels // scale + + self.blocks = nn.ModuleList( + [ + TDNNBlock( + in_channel, + hidden_channel, + kernel_size=kernel_size, + dilation=dilation, + batch_norm=batch_norm, + ) + for i in range(scale - 1) + ] + ) + self.scale = scale + + def forward(self, x): + y = [] + for i, x_i in enumerate(torch.chunk(x, self.scale, dim=1)): + if i == 0: + y_i = x_i + elif i == 1: + y_i = self.blocks[i - 1](x_i) + else: + y_i = self.blocks[i - 1](x_i + y_i) + y.append(y_i) + y = torch.cat(y, dim=1) + return y + + +class SEBlock(nn.Module): + """An implementation of squeeze-and-excitation block. + + Arguments + --------- + in_channels : int + The number of input channels. + se_channels : int + The number of output channels after squeeze. + out_channels : int + The number of output channels. + + Example + ------- + >>> inp_tensor = torch.rand([8, 120, 64]).transpose(1, 2) + >>> se_layer = SEBlock(64, 16, 64) + >>> lengths = torch.rand((8,)) + >>> out_tensor = se_layer(inp_tensor, lengths).transpose(1, 2) + >>> out_tensor.shape + torch.Size([8, 120, 64]) + """ + + def __init__(self, in_channels, se_channels, out_channels): + super(SEBlock, self).__init__() + + self.conv1 = Conv1d(in_channels=in_channels, out_channels=se_channels, kernel_size=1) + self.relu = torch.nn.ReLU(inplace=True) + self.conv2 = Conv1d(in_channels=se_channels, out_channels=out_channels, kernel_size=1) + self.sigmoid = torch.nn.Sigmoid() + + def forward(self, x, lengths=None): + L = x.shape[-1] + if lengths is not None: + mask = length_to_mask(lengths * L, max_len=L, device=x.device) + mask = mask.unsqueeze(1) + total = mask.sum(dim=2, keepdim=True) + s = (x * mask).sum(dim=2, keepdim=True) / total + else: + s = x.mean(dim=2, keepdim=True) + + s = self.relu(self.conv1(s)) + s = self.sigmoid(self.conv2(s)) + + return s * x + + +class AttentiveStatisticsPooling(nn.Module): + """This class implements an attentive statistic pooling layer for each channel. + It returns the concatenated mean and std of the input tensor. + + Arguments + --------- + channels: int + The number of input channels. + attention_channels: int + The number of attention channels. + + Example + ------- + >>> inp_tensor = torch.rand([8, 120, 64]).transpose(1, 2) + >>> asp_layer = AttentiveStatisticsPooling(64) + >>> lengths = torch.rand((8,)) + >>> out_tensor = asp_layer(inp_tensor, lengths).transpose(1, 2) + >>> out_tensor.shape + torch.Size([8, 1, 128]) + """ + + def __init__(self, channels, attention_channels=128, global_context=True, batch_norm=True): + super().__init__() + + self.eps = 1e-12 + self.global_context = global_context + if global_context: + self.tdnn = TDNNBlock(channels * 3, attention_channels, 1, 1, batch_norm=batch_norm) + else: + self.tdnn = TDNNBlock(channels, attention_channels, 1, 1, batch_norm, batch_norm) + self.tanh = nn.Tanh() + self.conv = Conv1d(in_channels=attention_channels, out_channels=channels, kernel_size=1) + + def forward(self, x, lengths=None): + """Calculates mean and std for a batch (input tensor). 
+ + Arguments + --------- + x : torch.Tensor + Tensor of shape [N, C, L]. + """ + L = x.shape[-1] + + def _compute_statistics(x, m, dim=2, eps=self.eps): + mean = (m * x).sum(dim) + std = torch.sqrt((m * (x - mean.unsqueeze(dim)).pow(2)).sum(dim).clamp(eps)) + return mean, std + + if lengths is None: + lengths = torch.ones(x.shape[0], device=x.device) + + # Make binary mask of shape [N, 1, L] + mask = length_to_mask(lengths * L, max_len=L, device=x.device) + mask = mask.unsqueeze(1) + + # Expand the temporal context of the pooling layer by allowing the + # self-attention to look at global properties of the utterance. + if self.global_context: + # torch.std is unstable for backward computation + # https://github.com/pytorch/pytorch/issues/4320 + total = mask.sum(dim=2, keepdim=True) + + mean, std = _compute_statistics(x, mask / total) + if x.dtype == torch.float16: + mean = mean.half() + std = std.half() + mean = mean.unsqueeze(2).repeat(1, 1, L) + std = std.unsqueeze(2).repeat(1, 1, L) + attn = torch.cat([x, mean, std], dim=1) + else: + attn = x + + # Apply layers + attn = self.conv(self.tanh(self.tdnn(attn))) + + # Filter out zero-paddings + attn = attn.masked_fill(mask == 0, float("-inf")) + + attn = F.softmax(attn, dim=2) + mean, std = _compute_statistics(x, attn) + # Append mean and std of the batch + pooled_stats = torch.cat((mean, std), dim=1) + pooled_stats = pooled_stats.unsqueeze(2) + + return pooled_stats + + +class SERes2NetBlock(nn.Module): + """An implementation of building block in ECAPA-TDNN, i.e., + TDNN-Res2Net-TDNN-SEBlock. + + Arguments + ---------- + out_channels: int + The number of output channels. + res2net_scale: int + The scale of the Res2Net block. + kernel_size: int + The kernel size of the TDNN blocks. + dilation: int + The dilation of the Res2Net block. + activation : torch class + A class for constructing the activation layers. + + Example + ------- + >>> x = torch.rand(8, 120, 64).transpose(1, 2) + >>> conv = SERes2NetBlock(64, 64, res2net_scale=4) + >>> out = conv(x).transpose(1, 2) + >>> out.shape + torch.Size([8, 120, 64]) + """ + + def __init__( + self, + in_channels, + out_channels, + res2net_scale=8, + se_channels=128, + kernel_size=1, + dilation=1, + activation=torch.nn.ReLU, + batch_norm=True, + ): + super().__init__() + self.out_channels = out_channels + self.tdnn1 = TDNNBlock( + in_channels, + out_channels, + kernel_size=1, + dilation=1, + activation=activation, + batch_norm=batch_norm, + ) + self.res2net_block = Res2NetBlock( + out_channels, out_channels, res2net_scale, kernel_size, dilation, batch_norm=batch_norm + ) + self.tdnn2 = TDNNBlock( + out_channels, + out_channels, + kernel_size=1, + dilation=1, + activation=activation, + batch_norm=batch_norm, + ) + self.se_block = SEBlock(out_channels, se_channels, out_channels) + + self.shortcut = None + if in_channels != out_channels: + self.shortcut = Conv1d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=1, + ) + + def forward(self, x, lengths=None): + residual = x + if self.shortcut: + residual = self.shortcut(x) + + x = self.tdnn1(x) + x = self.res2net_block(x) + x = self.tdnn2(x) + x = self.se_block(x, lengths) + + return x + residual + + +class ECAPA_TDNN(torch.nn.Module): + """An implementation of the speaker embedding model in a paper. + "ECAPA-TDNN: Emphasized Channel Attention, Propagation and Aggregation in + TDNN Based Speaker Verification" (https://arxiv.org/abs/2005.07143). + + Arguments + --------- + device : str + Device used, e.g., "cpu" or "cuda". 
+ activation : torch class + A class for constructing the activation layers. + channels : list of ints + Output channels for TDNN/SERes2Net layer. + kernel_sizes : list of ints + List of kernel sizes for each layer. + dilations : list of ints + List of dilations for kernels in each layer. + lin_neurons : int + Number of neurons in linear layers. + + Example + ------- + >>> input_feats = torch.rand([5, 120, 80]) + >>> compute_embedding = ECAPA_TDNN(80, lin_neurons=192) + >>> outputs = compute_embedding(input_feats) + >>> outputs.shape + torch.Size([5, 1, 192]) + """ + + def __init__( + self, + input_size, + lin_neurons=192, + activation=torch.nn.ReLU, + channels=[512, 512, 512, 512, 1536], + kernel_sizes=[5, 3, 3, 3, 1], + dilations=[1, 2, 3, 4, 1], + attention_channels=128, + res2net_scale=8, + se_channels=128, + global_context=True, + batch_norm=True, + ): + super().__init__() + assert len(channels) == len(kernel_sizes) + assert len(channels) == len(dilations) + self.channels = channels + self.blocks = nn.ModuleList() + + # The initial TDNN layer + self.blocks.append( + TDNNBlock( + input_size, + channels[0], + kernel_sizes[0], + dilations[0], + activation, + batch_norm=batch_norm, + ) + ) + + # SE-Res2Net layers + for i in range(1, len(channels) - 1): + self.blocks.append( + SERes2NetBlock( + channels[i - 1], + channels[i], + res2net_scale=res2net_scale, + se_channels=se_channels, + kernel_size=kernel_sizes[i], + dilation=dilations[i], + activation=activation, + batch_norm=batch_norm, + ) + ) + + # Multi-layer feature aggregation + self.mfa = TDNNBlock( + channels[-1], + channels[-1], + kernel_sizes[-1], + dilations[-1], + activation, + batch_norm=batch_norm, + ) + + # Attentive Statistical Pooling + self.asp = AttentiveStatisticsPooling( + channels[-1], + attention_channels=attention_channels, + global_context=global_context, + batch_norm=batch_norm, + ) + self.asp_bn = BatchNorm1d(input_size=channels[-1] * 2, enabled=batch_norm) + + # Final linear transformation + self.fc = Conv1d( + in_channels=channels[-1] * 2, + out_channels=lin_neurons, + kernel_size=1, + ) + + # @torch.cuda.amp.autocast(enabled=True, dtype=torch.float32) + def forward(self, x, lengths=None): + """Returns the embedding vector. + + Arguments + --------- + x : torch.Tensor + Tensor of shape (batch, time, channel). 
+ """ + # Minimize transpose for efficiency + x = x.transpose(1, 2) + + xl = [] + for layer in self.blocks: + try: + x = layer(x, lengths=lengths) + except TypeError: + x = layer(x) + xl.append(x) + + # Multi-layer feature aggregation + x = torch.cat(xl[1:], dim=1) + x = self.mfa(x) + + # Attentive Statistical Pooling + x = self.asp(x, lengths=lengths) + x = self.asp_bn(x) + + # Final linear transformation + x = self.fc(x) + + x = x.squeeze(-1) + return x + + +if __name__ == "__main__": + model = ECAPA_TDNN( + 80, + 512, + channels=[256, 256, 256, 256, 768], + kernel_sizes=[5, 3, 3, 3, 1], + dilations=[1, 2, 3, 4, 1], + attention_channels=64, + res2net_scale=2, + se_channels=64, + global_context=True, + batch_norm=False, + ) + + # print(model) diff --git a/src/thirdparty/qwen2_code2wav/model/t2w_cfm.py b/src/thirdparty/qwen2_code2wav/model/t2w_cfm.py new file mode 100644 index 00000000..03830941 --- /dev/null +++ b/src/thirdparty/qwen2_code2wav/model/t2w_cfm.py @@ -0,0 +1,474 @@ +""" +ein notation: +b - batch +n - sequence +nt - text sequence +nw - raw wave length +d - dimension +""" + +from __future__ import annotations +from typing import Callable +import random + +import torch +from torch import nn +import torch.nn.functional as F +from torch.nn.utils.rnn import pad_sequence + +from torchdiffeq import odeint + +from .modules import MelSpec +from .utils import ( + default, + exists, + list_str_to_idx, + list_str_to_tensor, + lens_to_mask, + mask_from_frac_lengths, +) + + +class CodecCFM(nn.Module): + def __init__( + self, + transformer: nn.Module, + sigma=0.0, + odeint_kwargs: dict = dict( + # atol = 1e-5, + # rtol = 1e-5, + method="euler" # 'midpoint' + ), + audio_drop_prob=0.3, + cond_drop_prob=0.2, + num_channels=None, + mel_spec_module: nn.Module | None = None, + mel_spec_kwargs: dict = dict(), + frac_lengths_mask: tuple[float, float] = (0.7, 1.0), + upsample_rate: int = 2, + ): + super().__init__() + + # mel spec + self.mel_spec = default(mel_spec_module, MelSpec(**mel_spec_kwargs)) + num_channels = default(num_channels, self.mel_spec.n_mel_channels) + self.num_channels = num_channels + + # classifier-free guidance + self.audio_drop_prob = audio_drop_prob + self.cond_drop_prob = cond_drop_prob + + # transformer + self.transformer = transformer + dim = transformer.dim + self.dim = dim + + # self.spk_proj = nn.Linear(192, 80) + + # conditional flow related + self.sigma = sigma + + # sampling related + self.odeint_kwargs = odeint_kwargs + + self.upsample_rate = upsample_rate + + def logit_normal_sample(self, batch, dtype, device, m=0.0, s=1.0): + u = torch.randn((batch,), dtype=dtype, device=device) * s + m # u ~ N(m, s^2) + samples = torch.sigmoid(u) # logistic(u) = 1 / (1 + exp(-u)) + + return samples + + @property + def device(self): + return next(self.parameters()).device + + @torch.no_grad() + def sample( + self, + cond: float["b n d"] | float["b nw"], # noqa: F722 + codec: int["b nc dc"], + ref_mel: float["b n d"], # noqa: F722 + *, + lens: int["b"] | None = None, # noqa: F821 + steps=32, + cfg_strength=1.0, + sway_sampling_coef=None, + seed: int | None = None, + max_duration=4096, + vocoder: Callable[[float["b d n"]], float["b nw"]] | None = None, # noqa: F722 + no_ref_audio=False, + y0: float["b n d"] | None = None, + duplicate_test=False, + t_inter=0.1, + edit_mask=None, + ): + self.eval() + + max_duration = codec.shape[1] * self.transformer.repeats + if next(self.parameters()).dtype == torch.float16: + cond = cond.half() + ref_mel = ref_mel.half() + if y0 is not None: + y0 = 
y0.half() + # raw wave + cond = cond.unsqueeze(1).repeat(1, max_duration, 1) + # if cond.ndim == 2: + # cond = self.mel_spec(cond) + # cond = cond.permute(0, 2, 1) + # assert cond.shape[-1] == self.num_channels + + batch, cond_seq_len, device = *ref_mel.shape[:2], codec.device + assert batch == 1, "only support batch size = 1 currently" + if not exists(lens): + lens = torch.full((batch,), cond_seq_len, device=device, dtype=torch.long) + + # cond_mask = lens_to_mask(lens) + # ref_mel = F.pad(ref_mel, (0, 0, 0, cond_seq_len), value=0.0) + # cond_mask = F.pad(cond_mask, (0, max_duration - cond_mask.shape[-1]), value=False) + # cond_mask = cond_mask.unsqueeze(-1) + # step_cond = torch.where( + # cond_mask, ref_mel, torch.zeros_like(ref_mel) + # ) # allow direct control (cut cond audio) with lens passed in + + mask = None + + # test for no ref audio + if no_ref_audio: + ref_mel = torch.zeros_like(ref_mel) + cond = torch.zeros_like(cond) + # neural ode + + def fn(t, x): + # at each step, conditioning is fixed + # step_cond = torch.where(cond_mask, ref_mel, torch.zeros_like(ref_mel)) + + # predict flow + # print(x.dtype,cond.dtype,ref_mel.dtype) + pred = self.transformer( + x=x, + spk=cond, + cond=ref_mel, + text=codec, + time=t, + mask=mask, + drop_audio_cond=False, + drop_text=False, + ) + if cfg_strength < 1e-5: + return pred + + null_pred = self.transformer( + x=x, + spk=cond, + cond=ref_mel, + text=codec, + time=t, + mask=mask, + drop_audio_cond=True, + drop_text=True, + ) + return pred + (pred - null_pred) * cfg_strength + + # noise input + if y0 is None: + y0 = torch.randn( + [1, max_duration, self.num_channels], device=self.device, dtype=cond.dtype + ) + + t_start = 0 + t = torch.linspace(t_start, 1, steps, device=self.device, dtype=cond.dtype) + if sway_sampling_coef is not None: + t = t + sway_sampling_coef * (torch.cos(torch.pi / 2 * t) - 1 + t) + + trajectory = odeint(fn, y0, t, **self.odeint_kwargs) + + sampled = trajectory[-1] + out = sampled + # out = torch.where(cond_mask, cond, out) + return out, trajectory + + @torch.no_grad() + def block_sample( + self, + cond: float["b n d"] | float["b nw"], # noqa: F722 + codec: int["b nc dc"], + ref_mel: float["b n d"], # noqa: F722 + y0: float["b n d"], + lens: int["b"] | None = None, # noqa: F821 + steps=32, + cfg_strength=1.0, + sway_sampling_coef=None, + seed: int | None = None, + max_duration=4096, + vocoder: Callable[[float["b d n"]], float["b nw"]] | None = None, # noqa: F722 + no_ref_audio=False, + duplicate_test=False, + t_inter=0.1, + edit_mask=None, + ): + self.eval() + + max_duration = y0.shape[1] + if next(self.parameters()).dtype == torch.float16: + cond = cond.half() + ref_mel = ref_mel.half() + y0 = y0.half() + # print(next(self.parameters()).dtype) + + # raw wave + + # if cond.ndim == 2: + # cond = self.mel_spec(cond) + # cond = cond.permute(0, 2, 1) + # assert cond.shape[-1] == self.num_channels + cond = cond.unsqueeze(1).repeat(1, max_duration, 1) + batch, cond_seq_len, device = *ref_mel.shape[:2], cond.device + assert batch == 1, "only support batch size = 1 currently" + if not exists(lens): + lens = torch.full((batch,), cond_seq_len, device=device, dtype=torch.long) + + # cond_mask = lens_to_mask(lens) + # ref_mel = F.pad(ref_mel, (0, 0, 0, max_duration - cond_seq_len), value=0.0) + # cond_mask = F.pad(cond_mask, (0, max_duration - cond_mask.shape[-1]), value=False) + # cond_mask = cond_mask.unsqueeze(-1) + # step_cond = torch.where( + # cond_mask, ref_mel, torch.zeros_like(ref_mel) + # ) # allow direct control (cut cond 
audio) with lens passed in + # ref_mel = F.pad(ref_mel, (0, 0, 0, cond_seq_len), value=0.0) + mask = None + + # test for no ref audio + if no_ref_audio: + cond = torch.zeros_like(cond) + + # neural ode + + def fn(t, x): + # at each step, conditioning is fixed + # step_cond = torch.where(cond_mask, ref_mel, torch.zeros_like(ref_mel)) + + # predict flow + # import pdb;pdb.set_trace() + pred = self.transformer( + x=x, + cond=ref_mel, + spk=cond, + text=codec, + time=t, + mask=mask, + drop_audio_cond=False, + drop_text=False, + ) + if cfg_strength < 1e-5: + return pred + + null_pred = self.transformer( + x=x, + cond=ref_mel, + spk=cond, + text=codec, + time=t, + mask=mask, + drop_audio_cond=True, + drop_text=True, + ) + return pred + (pred - null_pred) * cfg_strength + + # noise input + # y0 = torch.randn([1, max_duration, self.num_channels], device=self.device, dtype=step_cond.dtype) + + t_start = 0 + t = torch.linspace(t_start, 1, steps, device=self.device, dtype=ref_mel.dtype) + if sway_sampling_coef is not None: + t = t + sway_sampling_coef * (torch.cos(torch.pi / 2 * t) - 1 + t) + + trajectory = odeint(fn, y0, t, **self.odeint_kwargs) + + sampled = trajectory[-1] + out = sampled + # out = torch.where(cond_mask, ref_mel, out) + return out, trajectory + + @torch.no_grad() + def fast_block_sample( + self, + cond: float["b n d"] | float["b nw"], # noqa: F722 + codec: int["b nc dc"], + ref_mel: float["b n d"], # noqa: F722 + y0: float["b n d"], + lens: int["b"] | None = None, # noqa: F821 + steps=32, + cfg_strength=1.0, + sway_sampling_coef=None, + seed: int | None = None, + max_duration=4096, + vocoder: Callable[[float["b d n"]], float["b nw"]] | None = None, # noqa: F722 + no_ref_audio=False, + duplicate_test=False, + t_inter=0.1, + edit_mask=None, + ): + self.eval() + + max_duration = y0.shape[1] + if next(self.parameters()).dtype == torch.float16: + cond = cond.half() + ref_mel = ref_mel.half() + y0 = y0.half() + # print(next(self.parameters()).dtype) + + # raw wave + + # if cond.ndim == 2: + # cond = self.mel_spec(cond) + # cond = cond.permute(0, 2, 1) + # assert cond.shape[-1] == self.num_channels + cond = cond.unsqueeze(1).repeat(1, max_duration, 1) + batch, cond_seq_len, device = *ref_mel.shape[:2], cond.device + assert batch == 1, "only support batch size = 1 currently" + if not exists(lens): + lens = torch.full((batch,), cond_seq_len, device=device, dtype=torch.long) + + # cond_mask = lens_to_mask(lens) + # ref_mel = F.pad(ref_mel, (0, 0, 0, max_duration - cond_seq_len), value=0.0) + # cond_mask = F.pad(cond_mask, (0, max_duration - cond_mask.shape[-1]), value=False) + # cond_mask = cond_mask.unsqueeze(-1) + # step_cond = torch.where( + # cond_mask, ref_mel, torch.zeros_like(ref_mel) + # ) # allow direct control (cut cond audio) with lens passed in + # ref_mel = F.pad(ref_mel, (0, 0, 0, cond_seq_len), value=0.0) + mask = None + + # test for no ref audio + if no_ref_audio: + cond = torch.zeros_like(cond) + + # neural ode + + def fn(t, x): + # at each step, conditioning is fixed + # step_cond = torch.where(cond_mask, ref_mel, torch.zeros_like(ref_mel)) + + # predict flow + # print(x.dtype,cond.dtype,ref_mel.dtype) + out_put = self.transformer.fast_forward( + x=x, + text=codec, + spk=cond, + cond=ref_mel, + time=t, + mask=mask, + ) + pred, null_pred = torch.chunk(out_put, 2, dim=0) + # pred = self.transformer( + # x=x, spk=cond,cond=ref_mel, text=codec, time=t, mask=mask, drop_audio_cond=False, drop_text=False + # ) + # if cfg_strength < 1e-5: + # return pred + + # null_pred = 
self.transformer( + # x=x, spk=cond,cond=ref_mel , text=codec, time=t, mask=mask, drop_audio_cond=True, drop_text=True + # ) + return pred + (pred - null_pred) * cfg_strength + + # noise input + # y0 = torch.randn([1, max_duration, self.num_channels], device=self.device, dtype=step_cond.dtype) + + t_start = 0 + t = torch.linspace(t_start, 1, steps, device=self.device, dtype=ref_mel.dtype) + if sway_sampling_coef is not None: + t = t + sway_sampling_coef * (torch.cos(torch.pi / 2 * t) - 1 + t) + + trajectory = odeint(fn, y0, t, **self.odeint_kwargs) + + sampled = trajectory[-1] + out = sampled + # out = torch.where(cond_mask, ref_mel, out) + return out, trajectory + + def forward( + self, + inp: float["b n d"] | float["b nw"], # mel or raw wave # noqa: F722 + codec: int["b nc dc"], + lens: int["b"] | None = None, # noqa: F821 + spk: int["b nc"] | None = None, + ref_mel: float["b n d"] | None = None, + noise_scheduler: str | None = None, + use_log_norm: bool = True, + ): + batch, seq_len, dtype, device, sigma = *inp.shape[:2], inp.dtype, self.device, self.sigma + + # lens and mask + if not exists(lens): + lens = torch.full((batch,), seq_len, device=device) + + mask = lens_to_mask( + lens, length=seq_len + ) # useless here, as collate_fn will pad to max length in batch + + # # get a random span to mask out for training conditionally + # frac_lengths = torch.zeros((batch,), device=self.device).float().uniform_(*self.frac_lengths_mask) + # rand_span_mask = mask_from_frac_lengths(lens, 0.2) + + # if exists(mask): + # rand_span_mask &= mask + + # mel is x1 + x1 = inp + # cond = self.spk_proj(cond) + spk = spk.unsqueeze(1).repeat(1, inp.size(1), 1) + # x0 is gaussian noise + x0 = torch.randn_like(x1) + # cond = torch.zeros_like(x1) + # cond_mask = torch.zeros_like(cond,dtype=torch.bool) + # for i,j in enumerate(lens): + # if random.random() < 0.6: + # continue + # index = random.randint(0,int(0.6*j)) + # length = random.randint(0,int(0.3*j)) + # cond[i,index:index+length,:] = x1[i,index:index+length,:] + # cond_mask[i,index:index+length,:] = True + + # import pdb;pdb.set_trace() + # time step + if use_log_norm: + time = self.logit_normal_sample(batch, dtype=dtype, device=self.device) + else: + time = torch.rand((batch,), dtype=dtype, device=self.device) + # TODO. 
noise_scheduler + + # sample xt (φ_t(x) in the paper) + t = time.unsqueeze(-1).unsqueeze(-1) + phi = (1 - t) * x0 + t * x1 + flow = x1 - x0 + + # only predict what is within the random mask span for infilling + # cond = torch.where(rand_span_mask[..., None], torch.zeros_like(x1), x1) + + # transformer and cfg training with a drop rate + drop_audio_cond = random.random() < self.audio_drop_prob # p_drop in voicebox paper + if random.random() < self.cond_drop_prob: # p_uncond in voicebox paper + drop_audio_cond = True + drop_text = True + else: + drop_text = False + + # if want rigourously mask out padding, record in collate_fn in dataset.py, and pass in here + # adding mask will use more memory, thus also need to adjust batchsampler with scaled down threshold for long sequences + pred = self.transformer( + x=phi, + cond=ref_mel, + spk=spk, + text=codec, + time=time, + drop_audio_cond=drop_audio_cond, + drop_text=drop_text, + ) + + # flow matching loss + loss = F.mse_loss(pred, flow, reduction="none") + # mask = cond_mask&mask.unsqueeze(-1) + loss = loss[mask] + + return loss.mean(), ref_mel, pred diff --git a/src/thirdparty/qwen2_code2wav/model/utils.py b/src/thirdparty/qwen2_code2wav/model/utils.py new file mode 100644 index 00000000..539bbe19 --- /dev/null +++ b/src/thirdparty/qwen2_code2wav/model/utils.py @@ -0,0 +1,141 @@ +from __future__ import annotations + +import os +import math +import random +import string +from tqdm import tqdm +from collections import defaultdict + +import torch +import torch.nn.functional as F +from torch.nn.utils.rnn import pad_sequence +import torchaudio + + +def seed_everything(seed=0): + random.seed(seed) + os.environ["PYTHONHASHSEED"] = str(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + + +def exists(v): + return v is not None + + +def default(v, d): + return v if exists(v) else d + + +def lens_to_mask(t: int["b"], length: int | None = None) -> bool["b n"]: # noqa: F722 F821 + if not exists(length): + length = t.amax() + + seq = torch.arange(length, device=t.device) + return seq[None, :] < t[:, None] + + +def mask_from_start_end_indices(seq_len: int["b"], start: int["b"], end: int["b"]): # noqa: F722 F821 + max_seq_len = seq_len.max().item() + seq = torch.arange(max_seq_len, device=start.device).long() + start_mask = seq[None, :] >= start[:, None] + end_mask = seq[None, :] < end[:, None] + return start_mask & end_mask + + +def mask_from_frac_lengths(seq_len: int["b"], frac_lengths: float["b"]): # noqa: F722 F821 + lengths = (frac_lengths * seq_len).long() + max_start = seq_len - lengths + + rand = torch.rand_like(frac_lengths) + start = (max_start * rand).long().clamp(min=0) + end = start + lengths + + return mask_from_start_end_indices(seq_len, start, end) + + +def maybe_masked_mean(t: float["b n d"], mask: bool["b n"] = None) -> float["b d"]: # noqa: F722 + if not exists(mask): + return t.mean(dim=1) + + t = torch.where(mask[:, :, None], t, torch.tensor(0.0, device=t.device)) + num = t.sum(dim=1) + den = mask.float().sum(dim=1) + + return num / den.clamp(min=1.0) + + +# simple utf-8 tokenizer, since paper went character based +def list_str_to_tensor(text: list[str], padding_value=-1) -> int["b nt"]: # noqa: F722 + list_tensors = [torch.tensor([*bytes(t, "UTF-8")]) for t in text] # ByT5 style + text = pad_sequence(list_tensors, padding_value=padding_value, batch_first=True) + return text + + +# char tokenizer, based on 
custom dataset's extracted .txt file +def list_str_to_idx( + text: list[str] | list[list[str]], + vocab_char_map: dict[str, int], # {char: idx} + padding_value=-1, +) -> int["b nt"]: # noqa: F722 + list_idx_tensors = [ + torch.tensor([vocab_char_map.get(c, 0) for c in t]) for t in text + ] # pinyin or char style + text = pad_sequence(list_idx_tensors, padding_value=padding_value, batch_first=True) + return text + + +# padded to max length mel batch +def padded_mel_batch(ref_mels): + max_mel_length = torch.LongTensor([mel.shape[-1] for mel in ref_mels]).amax() + padded_ref_mels = [] + for mel in ref_mels: + padded_ref_mel = F.pad(mel, (0, max_mel_length - mel.shape[-1]), value=0) + padded_ref_mels.append(padded_ref_mel) + padded_ref_mels = torch.stack(padded_ref_mels) + padded_ref_mels = padded_ref_mels.permute(0, 2, 1) + return padded_ref_mels + + +def load_checkpoint(model, ckpt_path_or_ckpt, device, use_ema=True): + if not isinstance(ckpt_path_or_ckpt, str): + checkpoint = ckpt_path_or_ckpt + ckpt_type = "safetensors" + else: + ckpt_type = ckpt_path_or_ckpt.split(".")[-1] + if ckpt_type == "safetensors": + from safetensors.torch import load_file + + checkpoint = load_file(ckpt_path_or_ckpt) + else: + checkpoint = torch.load(ckpt_path_or_ckpt, weights_only=True) + + if use_ema: + if ckpt_type == "safetensors": + checkpoint = {"ema_model_state_dict": checkpoint} + checkpoint["model_state_dict"] = { + k.replace("ema_model.", ""): v + for k, v in checkpoint["ema_model_state_dict"].items() + if k not in ["initted", "step"] + } + missing_keys, unexpected_keys = model.load_state_dict( + checkpoint["model_state_dict"], strict=False + ) + else: + if ckpt_type == "safetensors": + checkpoint = {"model_state_dict": checkpoint} + missing_keys, unexpected_keys = model.load_state_dict( + checkpoint["model_state_dict"], strict=False + ) + + assert unexpected_keys == [], f"Unexpected keys: {unexpected_keys}" + assert missing_keys == [] or missing_keys == [ + "mel_spec.mel_stft.spectrogram.window", + "mel_spec.mel_stft.mel_scale.fb", + ], f"Missing keys: {missing_keys}" + + return model.to(device) diff --git a/src/thirdparty/qwen2_code2wav/modeling.py b/src/thirdparty/qwen2_code2wav/modeling.py new file mode 100644 index 00000000..81bc15c8 --- /dev/null +++ b/src/thirdparty/qwen2_code2wav/modeling.py @@ -0,0 +1,1171 @@ +# SPDX-License-Identifier: Apache-2.0 +import math + +import numpy as np +import torch +from torch import nn, pow, sin +from torch.nn import Conv1d, ConvTranspose1d, Parameter +from torch.nn import functional as F +from torch.nn.utils import remove_weight_norm, weight_norm + +from .model.dit import DiT +from .model.t2w_cfm import CodecCFM +from .model.utils import load_checkpoint + + +class CausalConv1d(nn.Conv1d): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.causal_padding = self.dilation[0] * (self.kernel_size[0] - 1) + + def forward(self, x): + return self._conv_forward(F.pad(x, [self.causal_padding, 0]), self.weight, self.bias) + + +class Snake(nn.Module): + """ + Implementation of a sine-based periodic activation function + Shape: + - Input: (B, C, T) + - Output: (B, C, T), same shape as the input + Parameters: + - alpha - trainable parameter + References: + - This activation function is from this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda: + https://arxiv.org/abs/2006.08195 + Examples: + >>> a1 = snake(256) + >>> x = torch.randn(256) + >>> x = a1(x) + """ + + def __init__(self, in_features, alpha=1.0, alpha_trainable=True, 
alpha_logscale=False): + """ + Initialization. + INPUT: + - in_features: shape of the input + - alpha: trainable parameter + alpha is initialized to 1 by default, higher values = higher-frequency. + alpha will be trained along with the rest of your model. + """ + super(Snake, self).__init__() + self.in_features = in_features + + # initialize alpha + self.alpha_logscale = alpha_logscale + if self.alpha_logscale: # log scale alphas initialized to zeros + self.alpha = Parameter(torch.zeros(in_features) * alpha) + else: # linear scale alphas initialized to ones + self.alpha = Parameter(torch.ones(in_features) * alpha) + + self.alpha.requires_grad = alpha_trainable + + self.no_div_by_zero = 0.000000001 + + def forward(self, x): + """ + Forward pass of the function. + Applies the function to the input elementwise. + Snake ∶= x + 1/a * sin^2 (xa) + """ + alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T] + if self.alpha_logscale: + alpha = torch.exp(alpha) + x = x + (1.0 / (alpha + self.no_div_by_zero)) * pow(sin(x * alpha), 2) + + return x + + +class SnakeBeta(nn.Module): + """ + A modified Snake function which uses separate parameters for the magnitude of the periodic components + Shape: + - Input: (B, C, T) + - Output: (B, C, T), same shape as the input + Parameters: + - alpha - trainable parameter that controls frequency + - beta - trainable parameter that controls magnitude + References: + - This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda: + https://arxiv.org/abs/2006.08195 + Examples: + >>> a1 = snakebeta(256) + >>> x = torch.randn(256) + >>> x = a1(x) + """ + + def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False): + """ + Initialization. + INPUT: + - in_features: shape of the input + - alpha - trainable parameter that controls frequency + - beta - trainable parameter that controls magnitude + alpha is initialized to 1 by default, higher values = higher-frequency. + beta is initialized to 1 by default, higher values = higher-magnitude. + alpha will be trained along with the rest of your model. + """ + super(SnakeBeta, self).__init__() + self.in_features = in_features + + # initialize alpha + self.alpha_logscale = alpha_logscale + if self.alpha_logscale: # log scale alphas initialized to zeros + self.alpha = Parameter(torch.zeros(in_features) * alpha) + self.beta = Parameter(torch.zeros(in_features) * alpha) + else: # linear scale alphas initialized to ones + self.alpha = Parameter(torch.ones(in_features) * alpha) + self.beta = Parameter(torch.ones(in_features) * alpha) + + self.alpha.requires_grad = alpha_trainable + self.beta.requires_grad = alpha_trainable + + self.no_div_by_zero = 0.000000001 + + def forward(self, x): + """ + Forward pass of the function. + Applies the function to the input elementwise. + SnakeBeta ∶= x + 1/b * sin^2 (xa) + """ + alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T] + beta = self.beta.unsqueeze(0).unsqueeze(-1) + if self.alpha_logscale: + alpha = torch.exp(alpha) + beta = torch.exp(beta) + x = x + (1.0 / (beta + self.no_div_by_zero)) * pow(sin(x * alpha), 2) + + return x + + +if "sinc" in dir(torch): + sinc = torch.sinc +else: + # This code is adopted from adefossez's julius.core.sinc under the MIT License + # https://adefossez.github.io/julius/julius/core.html + # LICENSE is in incl_licenses directory. + def sinc(x: torch.Tensor): + """ + Implementation of sinc, i.e. 
sin(pi * x) / (pi * x) + __Warning__: Different to julius.sinc, the input is multiplied by `pi`! + """ + return torch.where( + x == 0, + torch.tensor(1.0, device=x.device, dtype=x.dtype), + torch.sin(math.pi * x) / math.pi / x, + ) + + +# This code is adopted from adefossez's julius.lowpass.LowPassFilters under the MIT License +# https://adefossez.github.io/julius/julius/lowpass.html +# LICENSE is in incl_licenses directory. +def kaiser_sinc_filter1d(cutoff, half_width, kernel_size): # return filter [1,1,kernel_size] + even = kernel_size % 2 == 0 + half_size = kernel_size // 2 + + # For kaiser window + delta_f = 4 * half_width + A = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95 + if A > 50.0: + beta = 0.1102 * (A - 8.7) + elif A >= 21.0: + beta = 0.5842 * (A - 21) ** 0.4 + 0.07886 * (A - 21.0) + else: + beta = 0.0 + window = torch.kaiser_window(kernel_size, beta=beta, periodic=False) + + # ratio = 0.5/cutoff -> 2 * cutoff = 1 / ratio + if even: + time = torch.arange(-half_size, half_size) + 0.5 + else: + time = torch.arange(kernel_size) - half_size + if cutoff == 0: + filter_ = torch.zeros_like(time) + else: + filter_ = 2 * cutoff * window * sinc(2 * cutoff * time) + # Normalize filter to have sum = 1, otherwise we will have a small leakage + # of the constant component in the input signal. + filter_ /= filter_.sum() + filter = filter_.view(1, 1, kernel_size) + + return filter + + +class LowPassFilter1d(nn.Module): + def __init__( + self, + cutoff=0.5, + half_width=0.6, + stride: int = 1, + padding: bool = True, + padding_mode: str = "replicate", + kernel_size: int = 12, + ): + # kernel_size should be even number for stylegan3 setup, + # in this implementation, odd number is also possible. + super().__init__() + if cutoff < -0.0: + raise ValueError("Minimum cutoff must be larger than zero.") + if cutoff > 0.5: + raise ValueError("A cutoff above 0.5 does not make sense.") + self.kernel_size = kernel_size + self.even = kernel_size % 2 == 0 + self.pad_left = kernel_size // 2 - int(self.even) + self.pad_right = kernel_size // 2 + self.stride = stride + self.padding = padding + self.padding_mode = padding_mode + filter = kaiser_sinc_filter1d(cutoff, half_width, kernel_size) + self.register_buffer("filter", filter) + + # input [B, C, T] + def forward(self, x): + _, C, _ = x.shape + + if self.padding: + x = F.pad(x, (self.pad_left, self.pad_right), mode=self.padding_mode) + out = F.conv1d(x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C) + + return out + + +class UpSample1d(nn.Module): + def __init__(self, ratio=2, kernel_size=None): + super().__init__() + self.ratio = ratio + self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size + self.stride = ratio + self.pad = self.kernel_size // ratio - 1 + self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2 + self.pad_right = self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2 + filter = kaiser_sinc_filter1d( + cutoff=0.5 / ratio, half_width=0.6 / ratio, kernel_size=self.kernel_size + ) + self.register_buffer("filter", filter) + + # x: [B, C, T] + def forward(self, x): + _, C, _ = x.shape + + x = F.pad(x, (self.pad, self.pad), mode="replicate") + x = self.ratio * F.conv_transpose1d( + x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C + ) + x = x[..., self.pad_left : -self.pad_right] + + return x + + +class DownSample1d(nn.Module): + def __init__(self, ratio=2, kernel_size=None): + super().__init__() + cutoff = 0.5 / ratio + half_width = 0.6 / ratio + if 
cutoff < -0.0: + raise ValueError("Minimum cutoff must be larger than zero.") + if cutoff > 0.5: + raise ValueError("A cutoff above 0.5 does not make sense.") + self.kernel_size = kernel_size + self.even = kernel_size % 2 == 0 + self.pad_left = kernel_size // 2 - int(self.even) + self.pad_right = kernel_size // 2 + self.stride = ratio + filter = kaiser_sinc_filter1d(cutoff, half_width, kernel_size) + self.register_buffer("filter", filter, persistent=False) + + def forward(self, x): + _, C, _ = x.shape + + x = F.pad(x, (self.pad_left, self.pad_right), mode="replicate") + out = F.conv1d(x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C) + + return out + + +class TorchActivation1d(nn.Module): + def __init__( + self, + activation, + up_ratio: int = 2, + down_ratio: int = 2, + up_kernel_size: int = 12, + down_kernel_size: int = 12, + ): + super().__init__() + self.up_ratio = up_ratio + self.down_ratio = down_ratio + self.act = activation + self.upsample = UpSample1d(up_ratio, up_kernel_size) + self.downsample = DownSample1d(down_ratio, down_kernel_size) + + # x: [B,C,T] + def forward(self, x): + x = self.upsample(x) + x = self.act(x) + x = self.downsample(x) + + return x + + +def init_weights(m, mean=0.0, std=0.01): + classname = m.__class__.__name__ + if classname.find("Conv") != -1: + m.weight.data.normal_(mean, std) + + +def get_padding(kernel_size, dilation=1): + return int((kernel_size * dilation - dilation) / 2) + + +class AMPBlock1(torch.nn.Module): + def __init__( + self, + channels, + kernel_size=3, + dilation=(1, 3, 5), + activation=None, + snake_logscale=True, + frequency="50hz", + causal_type="1", + ): + super(AMPBlock1, self).__init__() + + self.frequency = frequency + if self.frequency == "50hz": + self.convs1 = nn.ModuleList( + [ + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1]), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[2], + padding=get_padding(kernel_size, dilation[2]), + ) + ), + ] + ) + else: + self.convs1 = nn.ModuleList( + [ + weight_norm( + CausalConv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[0], + # padding=get_padding(kernel_size, dilation[0]), + ) + ), + weight_norm( + CausalConv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[1], + # padding=get_padding(kernel_size, dilation[1]), + ) + ), + weight_norm( + CausalConv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[2], + # padding=get_padding(kernel_size, dilation[2]), + ) + ), + ] + ) + self.convs1.apply(init_weights) + + if causal_type == "1": + self.convs2 = nn.ModuleList( + [ + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1), + ) + ), + ] + ) + else: + self.convs2 = nn.ModuleList( + [ + weight_norm( + CausalConv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + # padding=get_padding(kernel_size, 1), + ) + ), + weight_norm( + CausalConv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + # 
padding=get_padding(kernel_size, 1), + ) + ), + weight_norm( + CausalConv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + # padding=get_padding(kernel_size, 1), + ) + ), + ] + ) + self.convs2.apply(init_weights) + + self.num_layers = len(self.convs1) + len(self.convs2) # total number of conv layers + + Activation1d = TorchActivation1d + + if activation == "snake": # periodic nonlinearity with snake function and anti-aliasing + self.activations = nn.ModuleList( + [ + Activation1d(activation=Snake(channels, alpha_logscale=snake_logscale)) + for _ in range(self.num_layers) + ] + ) + elif ( + activation == "snakebeta" + ): # periodic nonlinearity with snakebeta function and anti-aliasing + self.activations = nn.ModuleList( + [ + Activation1d(activation=SnakeBeta(channels, alpha_logscale=snake_logscale)) + for _ in range(self.num_layers) + ] + ) + else: + raise NotImplementedError( + "activation incorrectly specified. check the config file and look for 'activation'." + ) + + if causal_type == "1": + self.pre_conv = nn.Identity() + self.pre_act = nn.Identity() + else: + self.pre_conv = weight_norm( + Conv1d( + channels, + channels, + kernel_size, + stride=1, + padding=get_padding(kernel_size, 1), + ) + ) + self.pre_conv.apply(init_weights) + if activation == "snake": + self.pre_act = Activation1d( + activation=Snake(channels, alpha_logscale=snake_logscale) + ) + elif activation == "snakebeta": + self.pre_act = Activation1d( + activation=SnakeBeta(channels, alpha_logscale=snake_logscale) + ) + else: + raise NotImplementedError( + "activation incorrectly specified. check the config file and look for 'activation'." + ) + + def forward(self, x): + if self.frequency == "50hz": + return self.forward_50hz(x) + else: + raise ValueError(f"Unsupported frequency: {self.frequency}") + + def forward_50hz(self, x): + x = self.pre_conv(x) + x = self.pre_act(x) + acts1, acts2 = self.activations[::2], self.activations[1::2] + for c1, c2, a1, a2 in zip(self.convs1, self.convs2, acts1, acts2): + xt = a1(x) + xt = c1(xt) + xt = a2(xt) + xt = c2(xt) + x = xt + x + + return x + + def remove_weight_norm(self): + for l in self.convs1: + remove_weight_norm(l) + for l in self.convs2: + remove_weight_norm(l) + + +class AMPBlock2(torch.nn.Module): + def __init__( + self, + channels, + kernel_size=3, + dilation=(1, 3), + activation=None, + snake_logscale=True, + ): + super(AMPBlock2, self).__init__() + + self.convs = nn.ModuleList( + [ + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1]), + ) + ), + ] + ) + self.convs.apply(init_weights) + + self.num_layers = len(self.convs) # total number of conv layers + + Activation1d = TorchActivation1d + + if activation == "snake": # periodic nonlinearity with snake function and anti-aliasing + self.activations = nn.ModuleList( + [ + Activation1d(activation=Snake(channels, alpha_logscale=snake_logscale)) + for _ in range(self.num_layers) + ] + ) + elif ( + activation == "snakebeta" + ): # periodic nonlinearity with snakebeta function and anti-aliasing + self.activations = nn.ModuleList( + [ + Activation1d(activation=SnakeBeta(channels, alpha_logscale=snake_logscale)) + for _ in range(self.num_layers) + ] + ) + else: + raise NotImplementedError( + "activation incorrectly specified. check the config file and look for 'activation'." 
+ ) + + def forward(self, x): + for c, a in zip(self.convs, self.activations): + xt = a(x) + xt = c(xt) + x = xt + x + + return x + + def remove_weight_norm(self): + for l in self.convs: + remove_weight_norm(l) + + +class BigVGAN(torch.nn.Module): + # this is our main BigVGAN model. Applies anti-aliased periodic activation for resblocks. + def __init__( + self, + frequency: str = "50hz", # 50hz or 25 hz + num_mels=80, + initial_kernel=5, + upsample_initial_channel=1536, + resblock_kernel_sizes=[3, 7, 11], + resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]], + upsample_rates=[5, 3, 2, 2, 2, 2], + upsample_kernel_sizes=[11, 7, 4, 4, 4, 4], + resblock_type="1", + snake_logscale=True, + activation="snakebeta", + use_tanh_at_final=False, + use_bias_at_final=False, + ): + super(BigVGAN, self).__init__() + + self.num_kernels = len(resblock_kernel_sizes) + self.num_upsamples = len(upsample_rates) + + # pre conv + self.conv_pre = weight_norm( + Conv1d( + num_mels, upsample_initial_channel, initial_kernel, 1, padding=initial_kernel // 2 + ) + ) + + # define which AMPBlock to use. BigVGAN uses AMPBlock1 as default + resblock = AMPBlock1 if resblock_type == "1" else AMPBlock2 + + # transposed conv-based upsamplers. does not apply anti-aliasing + self.ups = nn.ModuleList() + for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)): + self.ups.append( + nn.ModuleList( + [ + weight_norm( + ConvTranspose1d( + upsample_initial_channel // (2**i), + upsample_initial_channel // (2 ** (i + 1)), + k, + u, + padding=(k - u) // 2, + ) + ) + ] + ) + ) + + # residual blocks using anti-aliased multi-periodicity composition modules (AMP) + self.resblocks = nn.ModuleList() + for i in range(len(self.ups)): + if frequency == "50hz": + causal_type = "1" + else: + if i > 1: + causal_type = "1" + else: + causal_type = "2" + ch = upsample_initial_channel // (2 ** (i + 1)) + for j, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)): + self.resblocks.append( + resblock( + ch, + k, + d, + activation=activation, + snake_logscale=snake_logscale, + frequency=frequency, + causal_type=causal_type, + ) + ) + + Activation1d = TorchActivation1d + + # post conv + if activation == "snake": # periodic nonlinearity with snake function and anti-aliasing + activation_post = Snake(ch, alpha_logscale=snake_logscale) + self.activation_post = Activation1d(activation=activation_post) + elif ( + activation == "snakebeta" + ): # periodic nonlinearity with snakebeta function and anti-aliasing + activation_post = SnakeBeta(ch, alpha_logscale=snake_logscale) + self.activation_post = Activation1d(activation=activation_post) + else: + raise NotImplementedError( + "activation incorrectly specified. check the config file and look for 'activation'." + ) + + # whether to use bias for the final conv_post. Defaults to True for backward compatibility + self.use_bias_at_final = use_bias_at_final + self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3, bias=self.use_bias_at_final)) + + # weight initialization + for i in range(len(self.ups)): + self.ups[i].apply(init_weights) + self.conv_post.apply(init_weights) + + # final tanh activation. 
Defaults to True for backward compatibility + self.use_tanh_at_final = use_tanh_at_final + + def _normalize(self, S, max_abs_value, min_db): + return torch.clamp( + (2 * max_abs_value) * ((S - min_db) / (-min_db)) - max_abs_value, + -max_abs_value, + max_abs_value, + ) + + def _amp_to_db(self, x, min_level_db): + min_level = np.exp(min_level_db / 20 * np.log(10)) + min_level = torch.ones_like(x) * min_level + return 20 * torch.log10(torch.maximum(min_level, x)) + + def apm_to_db(self, apm_mel): + mel_spec = torch.exp(apm_mel) + + mel_spec = self._amp_to_db(mel_spec, -115) - 20 + mel_spec = self._normalize(mel_spec, 1, -115) + + return mel_spec + + def forward(self, x, is_db=False): + if not is_db: + x = self.apm_to_db(x) + # pre conv + x = self.conv_pre(x) + + for i in range(self.num_upsamples): + # upsampling + for i_up in range(len(self.ups[i])): + x = self.ups[i][i_up](x) + # AMP blocks + xs = None + for j in range(self.num_kernels): + if xs is None: + xs = self.resblocks[i * self.num_kernels + j](x) + else: + xs += self.resblocks[i * self.num_kernels + j](x) + x = xs / self.num_kernels + + # post conv + x = self.activation_post(x) + x = self.conv_post(x) + # final tanh activation + if self.use_tanh_at_final: + x = torch.tanh(x) + else: + x = torch.clamp(x, min=-1.0, max=1.0) # bound the output to [-1, 1] + + return x + + def remove_weight_norm(self): + print("Removing weight norm...") + for l in self.ups: + for l_i in l: + remove_weight_norm(l_i) + for l in self.resblocks: + l.remove_weight_norm() + remove_weight_norm(self.conv_pre) + remove_weight_norm(self.conv_post) + + +class Qwen2Code2wavBigvgan(torch.nn.Module): + def __init__( + self, + ckpt, + frequency: str = "50hz", # 50hz or 25 hz + device="cpu", + with_weight_norm: bool = True, + ): + super().__init__() + self.frequency = frequency + initial_kernel = 7 if frequency == "50hz" else 5 + resblock_kernel_sizes = [3, 7, 11] if frequency == "50hz" else [3, 5, 9, 11] + resblock_dilation_sizes = ( + [[1, 3, 5], [1, 3, 5], [1, 3, 5]] + if frequency == "50hz" + else [[1, 3, 5], [1, 3, 5], [1, 3, 5], [1, 3, 5]] + ) + self.vocoder = BigVGAN( + num_mels=80, + frequency=frequency, + initial_kernel=initial_kernel, + upsample_initial_channel=1536, + resblock_kernel_sizes=resblock_kernel_sizes, + resblock_dilation_sizes=resblock_dilation_sizes, + upsample_rates=[5, 3, 2, 2, 2, 2], + upsample_kernel_sizes=[11, 7, 4, 4, 4, 4], + resblock_type="1", + snake_logscale=True, + activation="snakebeta", + use_tanh_at_final=False, + use_bias_at_final=False, + ) + if isinstance(ckpt, str): + state_dict = torch.load(ckpt) + else: + state_dict = ckpt + + if with_weight_norm: + loaded_keys = self.vocoder.load_state_dict(state_dict["generator"], strict=False) + self.vocoder.remove_weight_norm() + else: + self.vocoder.remove_weight_norm() + loaded_keys = self.vocoder.load_state_dict(state_dict["generator"], strict=False) + unexpected_keys = [ + k for k in loaded_keys.unexpected_keys if "downsample" not in k and "upsample" not in k + ] + assert ( + unexpected_keys == [] + ), f"Unexpected keys (except downsample/upsample): {loaded_keys.unexpected_keys}" + missing_keys = [ + k for k in loaded_keys.missing_keys if "downsample" not in k and "upsample" not in k + ] + assert missing_keys == [], f"Missing keys (except downsample/upsample): {missing_keys}" + self.vocoder.eval() + self.use_f0 = False + self.mel_bin = 80 + self.device = device + self.vocoder = self.vocoder.to(device) + + @torch.no_grad() + def forward(self, mel, wav=None): + if len(mel.shape) != 3: + 
mel = mel.unsqueeze(0) + + if mel.shape[-1] == self.mel_bin: + mel = mel.transpose(1, 2) + + mel = mel.to(self.device) + y_g_hat = self.vocoder(mel) + audio = y_g_hat.squeeze().cpu() + return audio + + def cache_forward(self, mel, future_cache_size, past_cache_size): + if len(mel.shape) != 3: + mel = mel.unsqueeze(0) + + if mel.shape[-1] == self.mel_bin: + mel = mel.transpose(1, 2) + + mel = mel.to(self.device) + y_g_hat = self.vocoder( + mel, past_cache_size=past_cache_size, future_cache_size=future_cache_size + ) + audio = y_g_hat.squeeze().detach().cpu() + return audio + + +class Qwen2Code2wavDit(torch.nn.Module): + def __init__( + self, + ckpt, + frequency: str = "50hz", # 50hz or 25 hz + device="cpu", + ): + super().__init__() + self.freqnecy = frequency + self.device = device + self.dit = DiT( + dim=1024, + depth=22 if frequency == "50hz" else 32, + heads=16, + ff_mult=2, + text_dim=512, + conv_layers=4, + use_codec=True, + repeats=2 if frequency == "50hz" else 4, + attn_processor="stream_block_sr" if frequency == "50hz" else "stream_block_8_L_4", + text_num_embeds=8193 if frequency == "50hz" else 32769, + mel_dim=80, + ) + self.mel_spec_kwargs = dict( + target_sample_rate=16000, + n_mel_channels=80, + hop_length=160, + ) + self.odeint_kwargs = dict( + method="rk4" if frequency == "50hz" else "euler", + ) + self.cfm_model = CodecCFM( + transformer=self.dit, + mel_spec_kwargs=self.mel_spec_kwargs, + odeint_kwargs=self.odeint_kwargs, + ).to(device) + self.cfm_model = load_checkpoint(self.cfm_model, ckpt, device, use_ema=True) + + def sample(self, cond, ref_mel, codec, steps=10, cfg_strength=0.5, sway_sampling_coef=-1.0): + y_all = torch.randn([1, 30000, 80], device=self.device, dtype=ref_mel.dtype) + expect_y_len = codec.shape[1] * (2 if self.freqnecy == "50hz" else 4) + y0 = y_all[:, :expect_y_len] + with torch.inference_mode(): + generated, _ = self.cfm_model.sample( + cond=cond, + ref_mel=ref_mel, + codec=codec, + steps=steps, + cfg_strength=cfg_strength, + sway_sampling_coef=sway_sampling_coef, + y0=y0, + ) + generated = generated.to(torch.float32) + generated_mel_spec = generated.permute(0, 2, 1) + return generated_mel_spec + + def fast_block_sample( + self, + cond, + codec, + ref_mel, + y0, + steps=10, + cfg_strength=0.5, + sway_sampling_coef=-1.0, + ): + return self.cfm_model.fast_block_sample( + cond=cond, + codec=codec, + ref_mel=ref_mel, + y0=y0, + steps=steps, + cfg_strength=cfg_strength, + sway_sampling_coef=sway_sampling_coef, + ) + + +class Qwen2Code2wav(torch.nn.Module): + def __init__( + self, + dit_ckpt, + bigvgan_ckpt, + device="cpu", + with_weight_norm: bool = True, + frequency: str = "50hz", # 50hz or 25 hz + ): + super().__init__() + self.freqnecy = frequency + self.code2wav_dit_model = Qwen2Code2wavDit( + ckpt=dit_ckpt, frequency=frequency, device=device + ) + self.code2wav_bigvgan_model = Qwen2Code2wavBigvgan( + ckpt=bigvgan_ckpt, frequency=frequency, device=device, with_weight_norm=with_weight_norm + ) + self.device = device + + def forward(self, cond, ref_mel, codec): + generated_mel = self.code2wav_dit_model.sample(cond, ref_mel, codec) + generated_mel = generated_mel.permute(0, 2, 1) + waveform = self.code2wav_bigvgan_model(generated_mel) + return waveform + + def init_variables(self, cond, ref_mel, codec_all, bs_mel): + self.bs_codec = bs_mel // (2 if self.freqnecy == "50hz" else 4) + self.past_cache_size = bs_mel * (2 if self.freqnecy == "50hz" else 4) + self.future_cache_size = bs_mel * 1 + self.chunk_size = bs_mel * (3 if self.freqnecy == "50hz" else 1) + 
self.gt_codec_len = codec_all.shape[1] + self.gt_mel_len = (2 if self.frequency == "50hz" else 4) * self.gt_codec_len + if 0 < self.gt_mel_len <= bs_mel * 4: + self.n_iter = 1 + else: + self.n_iter = math.ceil((self.gt_mel_len - self.future_cache_size) / self.chunk_size) + self.future_size = 20 if self.freqnecy == "50hz" else 13 + self.past_size = 20 if self.freqnecy == "50hz" else 51 + self.generated_list = [] + self.audio_list3 = [] + self.y_all = torch.randn([1, 30000, 80], device=self.device, dtype=ref_mel.dtype) + + def process_initial_chunk(self, cond, ref_mel, codec_all, y_all, steps): + factor = 2 if self.freqnecy == "50hz" else 4 + y0 = y_all[:, : self.chunk_size + self.future_cache_size] + codec = codec_all[:, : (self.chunk_size + self.future_cache_size) // factor] + generated, _ = self.code2wav_dit_model.fast_block_sample( + cond=cond, + codec=codec, + ref_mel=ref_mel, + y0=y0, + steps=steps, + cfg_strength=0.5, + sway_sampling_coef=-1.0, + ) + self.generated_list.append(generated.to(torch.float32)[:, : self.chunk_size, :]) + mel = self.generated_list[0] + audio = self.code2wav_bigvgan_model(mel) + audio_output = audio[: -self.future_size * 240] + self.audio_list3.append(audio_output) + + def process_little_chunk(self, cond, ref_mel, codec_all, y_all, steps): + y0 = y_all[:, : self.gt_mel_len] + codec = codec_all + generated, _ = self.code2wav_dit_model.fast_block_sample( + cond=cond, + codec=codec, + ref_mel=ref_mel, + y0=y0, + steps=steps, + cfg_strength=0.5, + sway_sampling_coef=-1.0, + ) + self.generated_list.append(generated.to(torch.float32)[:, :, :]) + mel = self.generated_list[0] + audio = self.code2wav_bigvgan_model(mel) + audio_output = audio + self.audio_list3.append(audio_output) + + def process_subsequent_chunks(self, cond, ref_mel, codec_all, y_all, i, steps): + factor = 2 if self.freqnecy == "50hz" else 4 + start_index = max(i * self.chunk_size - self.past_cache_size, 0) + end_index = min((i + 1) * self.chunk_size + self.future_cache_size, self.gt_mel_len) + y0 = y_all[:, start_index:end_index] + codec = codec_all[:, start_index // factor : end_index // factor] + generated, _ = self.code2wav_dit_model.fast_block_sample( + cond=cond, + codec=codec, + ref_mel=ref_mel, + y0=y0, + steps=steps, + cfg_strength=0.5, + sway_sampling_coef=-1.0, + ) + + if self.freqnecy == "50hz": + if start_index == 0: + mel = self.generated_list[0] + self.generated_list.append( + generated.to(torch.float32)[:, i * self.chunk_size : -self.future_cache_size, :] + ) + else: + self.generated_list.append( + generated.to(torch.float32)[ + :, self.past_cache_size : -self.future_cache_size, : + ] + ) + mel = torch.cat( + [ + self.generated_list[i - 1][:, -self.future_size * 2 :, :], + self.generated_list[i], + ], + dim=1, + ) + else: + if start_index == 0: + mel = self.generated_list[0] + self.generated_list.append( + generated.to(torch.float32)[:, i * self.chunk_size : -self.future_cache_size, :] + ) + else: + self.generated_list.append( + generated.to(torch.float32)[ + :, self.past_cache_size : -self.future_cache_size, : + ] + ) + if len(self.generated_list) <= 2: + mel = torch.cat(self.generated_list, dim=1) + else: # all past mel length >= self.past_size + self.future_size + mel = torch.cat( + [ + self.generated_list[i - 2], + self.generated_list[i - 1], + self.generated_list[i], + ], + dim=1, + ) + + audio = self.code2wav_bigvgan_model(mel) + + if self.freqnecy == "50hz": + audio_output = audio[self.future_size * 240 : -self.future_size * 240] + else: + if len(self.generated_list) <= 2: + 
audio_output = audio[ + (self.past_size - self.chunk_size) * 240 : -self.future_size * 240 + ] + else: # all past mel length >= self.past_size + self.future_size + audio_output = audio[self.past_size * 240 : -self.future_size * 240] + self.audio_list3.append(audio_output) + + def process_final_chunk(self, cond, ref_mel, codec_all, y_all, steps): + factor = 2 if self.freqnecy == "50hz" else 4 + start_index = max((self.n_iter - 1) * self.chunk_size - self.past_cache_size, 0) + end_index = self.gt_codec_len * factor + y0 = y_all[:, start_index:end_index] + codec = codec_all[:, start_index // factor : self.gt_codec_len] + generated, _ = self.code2wav_dit_model.fast_block_sample( + cond=cond, + codec=codec, + ref_mel=ref_mel, + y0=y0, + steps=steps, + cfg_strength=0.5, + sway_sampling_coef=-1.0, + ) + self.generated_list.append(generated.to(torch.float32)[:, self.past_cache_size :, :]) + if self.freqnecy == "50hz": + mel = torch.cat( + [self.generated_list[-2][:, -self.future_size * 2 :, :], self.generated_list[-1]], + dim=1, + ) + else: + if len(self.generated_list) <= 2: + mel = torch.cat(self.generated_list, dim=1) + else: + mel = torch.cat( + [self.generated_list[-3], self.generated_list[-2], self.generated_list[-1]], + dim=1, + ) + audio = self.code2wav_bigvgan_model(mel) + if self.freqnecy == "50hz": + audio_output = audio[self.future_size * 240 :] + else: + if len(self.generated_list) <= 2: + audio_output = audio[(self.past_size - self.chunk_size) * 240 :] + else: + audio_output = audio[self.past_size * 240 :] + self.audio_list3.append(audio_output) + + def get_full_audio(self): + audio = torch.cat(self.audio_list3, dim=0) + return audio + + def fast_forward(self, cond, ref_mel, codec, steps=10, bs_mel=24): + if self.freqnecy == "50hz": + assert self.bs_mel == 24 + else: + assert self.bs_mel == 32 + self.init_variables(cond, ref_mel, codec, bs_mel) + with torch.inference_mode(): + if self.n_iter <= 0: + return + if self.n_iter == 1: + self.process_little_chunk(cond, ref_mel, codec, self.y_all, steps) + else: + self.process_initial_chunk(cond, ref_mel, codec, self.y_all, steps) + for i in range(1, self.n_iter - 1): + self.process_subsequent_chunks(cond, ref_mel, codec, self.y_all, i, steps) + + self.process_final_chunk(cond, ref_mel, codec, self.y_all, steps) + return self.get_full_audio() diff --git a/src/thirdparty/qwen2_code2wav/modeling_fast.py b/src/thirdparty/qwen2_code2wav/modeling_fast.py new file mode 100644 index 00000000..4585d0f6 --- /dev/null +++ b/src/thirdparty/qwen2_code2wav/modeling_fast.py @@ -0,0 +1,548 @@ +# SPDX-License-Identifier: Apache-2.0 +from __future__ import annotations + +import contextlib +from typing import Callable, Dict, List, Tuple, Union + +import torch +from torchdiffeq import odeint + +from .model.dit import DiT +from .model.t2w_cfm import CodecCFM +from .model.utils import exists, load_checkpoint +from .modeling import Qwen2Code2wavBigvgan + + +class CudaGraphRunner: + def __init__(self, fn, device): + """ + initialize CUDA Graph Wrapper. + + Args: + original_fast_forward (callable): original fast_forward method. + device (torch.device): CUDA device. 
+ """ + torch._dynamo.config.cache_size_limit = 64 + self.fn_compile = torch.compile( + fn, + mode="default", + fullgraph=False, + ) + self.cuda_graph: Dict[Tuple[int, int], torch.cuda.CUDAGraph] = dict() + self.input_buffers: Dict[Tuple[int, int], Dict[str, torch.Tensor]] = dict() + self.output_buffers: Dict[Tuple[int, int], torch.Tensor] = dict() + + self.device = device + + # Create customized CUDA stream for Cuda Graph capture + self.capture_stream = torch.cuda.Stream(device=self.device) + + def capture_cuda_graph(self, x, cond, spk, text, time, mask): + """ + Capture CUDA Graph。 + + Args: + x (torch.Tensor): nosied input audio. + cond (torch.Tensor): masked cond audio. + spk (torch.Tensor): spk embedding. + text (torch.Tensor): text. + time (torch.Tensor): time step. + mask (torch.Tensor or None): mask. + """ + # Move the input data to the specified device and detach it from the computation graph. + + size = (text.size(0), text.size(1)) + + with torch.no_grad(): + if size not in self.input_buffers: + self.input_buffers[size] = { + "x": x.to(self.device, non_blocking=True).clone().detach(), + "cond": cond.to(self.device, non_blocking=True).clone().detach(), + "spk": spk.to(self.device, non_blocking=True).clone().detach(), + "text": text.to(self.device, non_blocking=True).clone().detach(), + "time": time.to(self.device, non_blocking=True).clone().detach(), + "mask": mask.to(self.device, non_blocking=True).clone().detach() + if mask is not None + else None, + } + + # Determine the output shape through a single forward pass and pre-allocate the output buffer. + with torch.no_grad(): + if size not in self.output_buffers: + generated = self.fn_compile( + x=self.input_buffers[size]["x"], + cond=self.input_buffers[size]["cond"], + spk=self.input_buffers[size]["spk"], + text=self.input_buffers[size]["text"], + time=self.input_buffers[size]["time"], + mask=self.input_buffers[size]["mask"], + ) + self.output_buffers[size] = torch.empty_like(generated, device=self.device) + + # Ensure that all previous operations are complete. + torch.cuda.synchronize(self.device) + + # Begin to capture CUDA Graph + self.cuda_graph[size] = torch.cuda.CUDAGraph() + with torch.cuda.graph(self.cuda_graph[size], stream=self.capture_stream): + # Perform the forward pass and copy the results to the pre-allocated buffer. + generated = self.fn_compile( + x=self.input_buffers[size]["x"], + cond=self.input_buffers[size]["cond"], + spk=self.input_buffers[size]["spk"], + text=self.input_buffers[size]["text"], + time=self.input_buffers[size]["time"], + mask=self.input_buffers[size]["mask"], + ) + self.output_buffers[size].copy_(generated) + + # Make sure capture complete + torch.cuda.synchronize(self.device) + + def __call__(self, x, cond, spk, text, time, mask=None): + """ + Args: + x (torch.Tensor): nosied input audio. + cond (torch.Tensor): masked cond audio. + spk (torch.Tensor): spk embedding. + text (torch.Tensor): text. + time (torch.Tensor): time step. + mask (torch.Tensor or None): mask. + + Returns: + torch.Tensor: generated + """ + + size = (text.size(0), text.size(1)) + + if size not in self.cuda_graph: + # Capture the CUDA Graph on the first call. + self.capture_cuda_graph(x, cond, spk, text, time, mask) + + # Update input to buffer. 
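+ # A captured CUDA graph replays fixed kernels and memory addresses, so new
+ # inputs cannot be passed in directly: they are copied into the buffers that
+ # were captured, and replay writes the result into the pre-allocated output
+ # buffer for this input size.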
+ self.input_buffers[size]["x"].copy_(x.to(self.device, non_blocking=True)) + self.input_buffers[size]["cond"].copy_(cond.to(self.device, non_blocking=True)) + self.input_buffers[size]["spk"].copy_(spk.to(self.device, non_blocking=True)) + self.input_buffers[size]["text"].copy_(text.to(self.device, non_blocking=True)) + self.input_buffers[size]["time"].copy_(time.to(self.device, non_blocking=True)) + if self.input_buffers[size]["mask"] is not None and mask is not None: + self.input_buffers[size]["mask"].copy_(mask.to(self.device, non_blocking=True)) + elif mask is None: + self.input_buffers[size]["mask"] = None + # Replay CUDA Graph + self.cuda_graph[size].replay() + return self.output_buffers[size] + + +class BatchCodecCFM(CodecCFM): + @torch.no_grad() + def fast_block_sample( + self, + cond: float["b n d"] | float["b nw"], # noqa: F722 + codec: int["b nc dc"], + ref_mel: float["b n d"], # noqa: F722 + y0: float["b n d"], + lens: int[b] | None = None, # noqa: F821 + steps=32, + cfg_strength=1.0, + sway_sampling_coef=None, + seed: int | None = None, + max_duration=4096, + vocoder: Callable[[float["b d n"]], float["b nw"]] | None = None, # noqa: F722 + no_ref_audio=False, + duplicate_test=False, + t_inter=0.1, + edit_mask=None, + ): + self.eval() + + max_duration = y0.shape[1] + if next(self.parameters()).dtype == torch.float16: + cond = cond.half() + ref_mel = ref_mel.half() + y0 = y0.half() + # print(next(self.parameters()).dtype) + + # raw wave + + cond = cond.unsqueeze(1).repeat(1, max_duration, 1) + batch, cond_seq_len, device = *ref_mel.shape[:2], cond.device + if not exists(lens): + lens = torch.full((batch,), cond_seq_len, device=device, dtype=torch.long) + + mask = None + + # test for no ref audio + if no_ref_audio: + cond = torch.zeros_like(cond) + + # neural ode + + def fn(t, x): + out_put = self.transformer.fast_forward( + x=x, + text=codec, + spk=cond, + cond=ref_mel, + time=t, + mask=mask, + ) + pred, null_pred = torch.chunk(out_put, 2, dim=0) + return pred + (pred - null_pred) * cfg_strength + + t_start = 0 + t = torch.linspace(t_start, 1, steps, device=self.device, dtype=ref_mel.dtype) + if sway_sampling_coef is not None: + t = t + sway_sampling_coef * (torch.cos(torch.pi / 2 * t) - 1 + t) + + trajectory = odeint(fn, y0, t, **self.odeint_kwargs) + + sampled = trajectory[-1] + out = sampled + # out = torch.where(cond_mask, ref_mel, out) + return out, trajectory + + +class Qwen2Code2wavDit(torch.nn.Module): + def __init__(self, ckpt, frequency: str = "50hz", device="cpu"): + super().__init__() + self.frequency = frequency + self.device = device + self.dit = DiT( + dim=1024, + depth=22 if frequency == "50hz" else 32, + heads=16, + ff_mult=2, + text_dim=512, + conv_layers=4, + use_codec=True, + repeats=2 if frequency == "50hz" else 4, + attn_processor="stream_block_sr" if frequency == "50hz" else "stream_block_8_L_4", + text_num_embeds=8193 if frequency == "50hz" else 32769, + mel_dim=80, + ) + self.mel_spec_kwargs = dict( + target_sample_rate=16000, + n_mel_channels=80, + hop_length=160, + ) + self.odeint_kwargs = dict( + method="euler", + ) + self.cfm_model = BatchCodecCFM( + transformer=self.dit, + mel_spec_kwargs=self.mel_spec_kwargs, + odeint_kwargs=self.odeint_kwargs, + ).to(device) + self.cfm_model = load_checkpoint(self.cfm_model, ckpt, device, use_ema=True) + + def sample(self, cond, ref_mel, codec, steps=10, cfg_strength=0.5, sway_sampling_coef=-1.0): + y_all = torch.randn([1, 30000, 80], device=self.device, dtype=ref_mel.dtype) + expect_y_len = codec.shape[1] * (2 if 
self.frequency == "50hz" else 4) + y0 = y_all[:, :expect_y_len] + with torch.inference_mode(): + generated, _ = self.cfm_model.sample( + cond=cond, + ref_mel=ref_mel, + codec=codec, + steps=steps, + cfg_strength=cfg_strength, + sway_sampling_coef=sway_sampling_coef, + y0=y0, + ) + generated = generated.to(torch.float32) + generated_mel_spec = generated.permute(0, 2, 1) + return generated_mel_spec + + def fast_block_sample( + self, + cond, + codec, + ref_mel, + y0, + steps=10, + cfg_strength=0.5, + sway_sampling_coef=-1.0, + ): + return self.cfm_model.fast_block_sample( + cond=cond, + codec=codec, + ref_mel=ref_mel, + y0=y0, + steps=steps, + cfg_strength=cfg_strength, + sway_sampling_coef=sway_sampling_coef, + ) + + +class Qwen2Code2wav(torch.nn.Module): + def __init__( + self, + dit_ckpt, + bigvgan_ckpt, + steps: int = 10, + bs_mel: int = 24, + odeint_method: str = "euler", + odeint_method_relaxed: bool = False, + batched_chunk: int = 3, + frequency: str = "50hz", + device="cpu", + with_weight_norm: bool = True, + ): + super().__init__() + self.frequency = frequency + self.code2wav_dit_model = Qwen2Code2wavDit( + ckpt=dit_ckpt, frequency=frequency, device=device + ) + self.code2wav_bigvgan_model = Qwen2Code2wavBigvgan( + ckpt=bigvgan_ckpt, frequency=frequency, device=device, with_weight_norm=with_weight_norm + ) + self.device = device + + # odeint method: use ruler for first and last chunk to optimize performance + self.odeint_method_relaxed = odeint_method_relaxed + + # cfm model: override the odeint method + self.code2wav_dit_model.cfm_model.odeint_kwargs["method"] = odeint_method + + # dit autocast + self.code2wav_dit_model.dit.fast_forward = torch.autocast( + device_type="cuda", + dtype=torch.bfloat16, + )(self.code2wav_dit_model.dit.fast_forward) + + self.dit_forward = self.code2wav_dit_model.dit.fast_forward + self.dit_forward_compiled = self.code2wav_dit_model.dit.fast_forward + self.dit_forward_compiled_first_chunk = self.code2wav_dit_model.dit.fast_forward + self.dit_forward_cudagraph_first = self.code2wav_dit_model.dit.fast_forward + self.dit_forward_cudagraph_intermediate = self.code2wav_dit_model.dit.fast_forward + self.dit_forward_cudagraph_last = self.code2wav_dit_model.dit.fast_forward + + self.torch_compile_first_chunk = False + + self.factor = 2 if frequency == "50hz" else 4 + self.steps = steps + self.bs_mel = bs_mel + self.bs_codec = bs_mel // self.factor + self.past_cache_size = bs_mel * self.factor + self.future_cache_size = bs_mel * 1 + self.chunk_size = bs_mel * batched_chunk + self.future_size = 20 if self.frequency == "50hz" else 13 + self.past_size = 20 if self.frequency == "50hz" else 51 + + text_embed = self.code2wav_dit_model.dit.text_embed + if hasattr(text_embed, "codec_embed"): + self.codec_embed_size = text_embed.codec_embed.weight.size(0) + elif hasattr(text_embed, "text_embed"): + self.codec_embed_size = text_embed.text_embed.weight.size(0) + else: + self.codec_embed_size = -1 + + @contextlib.contextmanager + def relax_odeint_method(self, relax: bool = False): + if relax and self.odeint_method_relaxed: + odeint_method = self.code2wav_dit_model.cfm_model.odeint_kwargs["method"] + self.code2wav_dit_model.cfm_model.odeint_kwargs["method"] = "euler" + yield + if relax and self.odeint_method_relaxed: + self.code2wav_dit_model.cfm_model.odeint_kwargs["method"] = odeint_method + + def enable_torch_compile(self, compile_first_chunk: bool = False): + self.torch_compile_first_chunk = compile_first_chunk + + self.dit_forward_compiled = torch.compile( + 
self.code2wav_dit_model.dit.fast_forward, + # mode="default", + mode="reduce-overhead", + fullgraph=False, + ) + self.dit_forward_cudagraph_first = CudaGraphRunner( + self.code2wav_dit_model.dit.fast_forward, self.device + ) + self.dit_forward_cudagraph_intermediate = CudaGraphRunner( + self.code2wav_dit_model.dit.fast_forward, self.device + ) + self.dit_forward_cudagraph_last = CudaGraphRunner( + self.code2wav_dit_model.dit.fast_forward, self.device + ) + + @torch.inference_mode() + def forward( + self, + cond, + ref_mel, + codec, + steps=10, + cfg_strength=0.5, + sway_sampling_coef=-1.0, + ): + generated_mel = self.code2wav_dit_model.sample( + cond, + ref_mel, + codec, + steps=steps, + cfg_strength=cfg_strength, + sway_sampling_coef=sway_sampling_coef, + ) + generated_mel = generated_mel.permute(0, 2, 1) + waveform = self.code2wav_bigvgan_model(generated_mel) + return waveform + + @torch.inference_mode() + def process_chunk_dit_batch( + self, + cond, + ref_mel, + codec, + y0, + steps, + ) -> Tuple[torch.Tensor, torch.Tensor]: + if self.codec_embed_size > 0: + codec[codec >= self.codec_embed_size] = 0 + + self.code2wav_dit_model.dit.fast_forward = self.dit_forward_cudagraph_intermediate + generated, _ = self.code2wav_dit_model.fast_block_sample( + cond=cond, + codec=codec, + ref_mel=ref_mel, + y0=y0, + steps=steps, + cfg_strength=0.5, + sway_sampling_coef=-1.0, + ) + return generated + + @torch.inference_mode() + def process_chunk_bigvgan_batch(self, mel_batch): + return self.code2wav_bigvgan_model(mel_batch) + + @torch.inference_mode() + def process_little_chunk( + self, + cond, + ref_mel, + codec_all, + y_all, + i, + steps, + prev_generated: torch.Tensor, + finished: bool = False, + cfg_strength=0.5, + sway_sampling_coef=-1.0, + ) -> Tuple[torch.Tensor, torch.Tensor]: + # mask to prevent codec from being out of range (the eos token) + if self.codec_embed_size > 0: + codec_all[codec_all >= self.codec_embed_size] = 0 + + return None, self.forward( + cond, + ref_mel, + codec_all, + steps=steps, + cfg_strength=cfg_strength, + sway_sampling_coef=sway_sampling_coef, + ) + + @torch.inference_mode() + def process_chunk( + self, + cond, + ref_mel, + codec_all, + y_all, + i, + steps, + prev_generated: Union[torch.Tensor, List[torch.Tensor]], + finished: bool = False, + cfg_strength=0.5, + sway_sampling_coef=-1.0, + ) -> Tuple[Union[torch.Tensor, List[torch.Tensor]], torch.Tensor]: + start_index = max(i * self.chunk_size - self.past_cache_size, 0) + end_index = min( + (i + 1) * self.chunk_size + self.future_cache_size, codec_all.shape[1] * self.factor + ) + y0 = y_all[:, start_index:end_index].reshape(1, -1, 80).contiguous() + codec = ( + codec_all[:, start_index // self.factor : end_index // self.factor] + .reshape(1, -1) + .contiguous() + ) + + # mask to prevent codec from being out of range (the eos token) + if self.codec_embed_size > 0: + codec[codec >= self.codec_embed_size] = 0 + + # N.B. when using cuda graph ("reduce-overhead" mode), don't compile + # shape for the first and the last chunk, as it will affect the performance + # for normal chunks. + # + # The reason is not clear yet. The default torch.compile() mode is not affected. 
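+ # The first and last chunks see different sequence lengths than the
+ # steady-state chunks (no past cache for the first, a truncated tail for the
+ # last), and CudaGraphRunner captures one graph per input size, so separate
+ # runners are kept for the first, intermediate, and final chunk positions.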
+ if i == 0: + if self.torch_compile_first_chunk: + self.code2wav_dit_model.dit.fast_forward = self.dit_forward_compiled + else: + self.code2wav_dit_model.dit.fast_forward = self.dit_forward_cudagraph_first + elif finished: + self.code2wav_dit_model.dit.fast_forward = self.dit_forward_cudagraph_last + else: + self.code2wav_dit_model.dit.fast_forward = self.dit_forward_cudagraph_intermediate + + with self.relax_odeint_method(relax=i == 0 or finished): + generated, _ = self.code2wav_dit_model.fast_block_sample( + cond=cond, + codec=codec, + ref_mel=ref_mel, + y0=y0, + steps=steps, + cfg_strength=cfg_strength, + sway_sampling_coef=sway_sampling_coef, + ) + + if self.frequency == "50hz": + return self.process_chunk_for_50hz( + i, + start_index, + end_index, + finished, + prev_generated, + generated, + ) + else: + raise ValueError(f"Unsupported frequency: {self.frequency}") + + def process_chunk_for_50hz( + self, + i: int, + start_index: int, + end_index: int, + finished: bool, + prev_generated: torch.Tensor, + generated: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + if i == 0: + generated = generated.to(torch.float32)[:, : self.chunk_size, :] + mel = generated + elif finished: + generated = generated.to(torch.float32)[:, self.past_cache_size :, :] + mel = torch.cat([prev_generated[:, -self.future_size * 2 :, :], generated], dim=1) + else: + # Note that self.chunk_size == self.past_cache_size, so the following branch + # can be simplified. But for clearness, we keep it as it is in transformers. + if start_index == 0: + generated = generated.to(torch.float32)[ + :, i * self.chunk_size : -self.future_cache_size, : + ] + else: + generated = generated.to(torch.float32)[ + :, self.past_cache_size : -self.future_cache_size, : + ] + mel = torch.cat([prev_generated[:, -self.future_size * 2 :, :], generated], dim=1) + + audio = self.code2wav_bigvgan_model(mel) + if i == 0: + audio_output = audio[: -self.future_size * 240] + elif finished: + audio_output = audio[self.future_size * 240 :] + else: + audio_output = audio[self.future_size * 240 : -self.future_size * 240] + return generated, audio_output diff --git a/src/types/llm/sampling.py b/src/types/llm/sampling.py index a1915e85..cf16d896 100644 --- a/src/types/llm/sampling.py +++ b/src/types/llm/sampling.py @@ -22,9 +22,9 @@ class LMGenerateArgs: }, ) lm_gen_min_new_tokens: int = field( - default=0, + default=1, metadata={ - "help": "Minimum number of new tokens to generate in a single completion. Default is 0." + "help": "Minimum number of new tokens to generate in a single completion. Default is 1." }, ) lm_gen_do_sample: bool = field( @@ -88,6 +88,10 @@ class LMGenerateArgs: "help": "The pad token id. Default is 0. If the pad id is a substring token id of the generated text, the generation will stop." }, ) + lm_gen_max_tokens_per_step: int = field( + default=3, + metadata={"help": "The maximum number of tokens to generate per step. Default is 3."}, + ) def update(self, **kwargs): unused_kwargs = dict() diff --git a/src/types/llm/transformers.py b/src/types/llm/transformers.py index 9108aaaa..deddc344 100644 --- a/src/types/llm/transformers.py +++ b/src/types/llm/transformers.py @@ -58,7 +58,7 @@ class TransformersLMArgs(LMGenerateArgs): metadata={"help": "Initial role for setting up the chat context. Default is 'system'."}, ) init_chat_prompt: str = field( - default="You are a helpful and friendly AI assistant. 
You are polite, respectful, and aim to provide concise responses of less than 20 words.", + default="", metadata={ "help": "The initial chat prompt to establish context for the language model. Default is 'You are a helpful AI assistant.'" }, diff --git a/src/types/omni/qwen2_vision_voice.py b/src/types/omni/qwen2_vision_voice.py new file mode 100644 index 00000000..600db60d --- /dev/null +++ b/src/types/omni/qwen2_vision_voice.py @@ -0,0 +1,25 @@ +from dataclasses import dataclass, field + +from src.types.llm.transformers import TransformersLMArgs +from src.thirdparty.qwen2_code2wav import Code2WavEngineConfig + + +@dataclass +class Qwen2_5TransformersVisionVoiceLMArgs(TransformersLMArgs): + """ + text+vision(Image/video)+voice(audio+speech) lm args + token2wav(dit cfm + vocoder) args + """ + + thinker_eos_token_ids: list = field(default_factory=lambda: [151644, 151645]) + thinker_stop_strings_per_step: list = field(default_factory=lambda: [".", "。"]) + thinker_args: dict = field(default_factory=lambda: TransformersLMArgs().__dict__) + talker_args: dict = field(default_factory=lambda: TransformersLMArgs().__dict__) + talker_skip_thinker_token_ids: list[int] = field(default_factory=lambda: []) + talker_eos_token_ids: list[int] = field(default_factory=lambda: [8292, 8294]) + code2wav_args: dict = field(default_factory=lambda: Code2WavEngineConfig().__dict__) + speaker: str = "Chelsie" + is_use_sliding_window_code2wav: bool = True + save_wav: bool = False + disable_talker: bool = False + thinker_all_talker_stream: bool = False + mask_embedding: bool = True diff --git a/test/modules/speech/asr/test_qwen2_5omni_asr.py b/test/modules/speech/asr/test_qwen2_5omni_asr.py new file mode 100644 index 00000000..b2f99890 --- /dev/null +++ b/test/modules/speech/asr/test_qwen2_5omni_asr.py @@ -0,0 +1,150 @@ +import logging +import unittest +import os +import asyncio + + +from src.common.logger import Logger +from src.common.utils.helper import load_json, get_audio_segment +from src.common.utils.wav import save_audio_to_file +from src.common.session import Session +from src.common.interface import IAsr +from src.common.types import SessionCtx, TEST_DIR, MODELS_DIR, RECORDS_DIR +from src.modules.speech.asr import ASREnvInit + + +r""" +LLM_MODEL_NAME_OR_PATH=./models/Qwen/Qwen2.5-Omni-7B \ + THINKER_LLM_GEN_TEMPERATURE=0.9 \ + LLM_DEVICE=cuda LLM_TORCH_DTYPE=bfloat16 \ + python -m unittest test.modules.speech.asr.test_qwen2_5omni_asr.TestQwen2_5OmniASR.test_transcribe_stream + +LLM_MODEL_NAME_OR_PATH=./models/Qwen/Qwen2.5-Omni-7B \ + THINKER_LLM_GEN_TEMPERATURE=0.9 \ + LLM_DEVICE=cuda LLM_TORCH_DTYPE=bfloat16 \ + python -m unittest test.modules.speech.asr.test_qwen2_5omni_asr.TestQwen2_5OmniASR.test_transcribe + +LLM_MODEL_NAME_OR_PATH=./models/Qwen/Qwen2.5-Omni-7B \ + THINKER_LLM_GEN_TEMPERATURE=0.9 \ + LLM_DEVICE=cuda LLM_TORCH_DTYPE=bfloat16 \ + python -m unittest test.modules.speech.asr.test_qwen2_5omni_asr.TestQwen2_5OmniASR.test_transcribe_with_bytes +""" + + +class TestQwen2_5OmniASR(unittest.TestCase): + @classmethod + def setUpClass(cls): + # wget + # https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav + # -O records/asr_example_zh.wav + audio_file = os.path.join(TEST_DIR, "audio_files/asr_example_zh.wav") + # Use an environment variable to get the ASR model TAG + cls.asr_tag = os.getenv("ASR_TAG", "qwen2_5omni_asr") + cls.audio_file = os.getenv("AUDIO_FILE", audio_file) + + Logger.init(os.getenv("LOG_LEVEL", "info").upper(), is_file=False) + + @classmethod + def 
tearDownClass(cls): + pass + + def setUp(self): + self.asr: IAsr = ASREnvInit.initASREngine(self.asr_tag) + + self.annotations_path = os.path.join(TEST_DIR, "audio_files/annotations.json") + + self.session = Session(**SessionCtx("test_client_id", 16000, 2).__dict__) + + def tearDown(self): + pass + + def test_transcribe_stream(self): + self.asr.set_audio_data(self.audio_file) + res = self.asr.transcribe_stream_sync(self.session) + for word in res: + print(word) + self.assertGreater(len(word), 0) + + def test_transcribe(self): + self.asr.set_audio_data(self.audio_file) + res = asyncio.run(self.asr.transcribe(self.session)) + print(res) + + def test_transcribe_with_bytes(self): + with open(self.audio_file, "rb") as file: + self.asr.set_audio_data(file.read()) + res = asyncio.run(self.asr.transcribe(self.session)) + print(res) + + def test_transcribe_with_record(self): + import pyaudio + + paud = pyaudio.PyAudio() + audio_stream = paud.open( + rate=16000, channels=1, format=pyaudio.paInt16, input=True, frames_per_buffer=1024 + ) + + audio_stream.start_stream() + logging.debug("start recording") + while True: + # empty, need use vad + read_audio_frames = audio_stream.read(512) + self.asr.set_audio_data(read_audio_frames) + res = asyncio.run(self.asr.transcribe(self.session)) + logging.info(res) + if len(res) > 0: + break + + audio_stream.stop_stream() + audio_stream.close() + paud.terminate() + + def test_transcribe_segments(self): + from sentence_transformers import SentenceTransformer, util + + self.similarity_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2") + annotations = asyncio.run(load_json(self.annotations_path)) + + for audio_file, data in annotations.items(): + audio_file_path = os.path.join(TEST_DIR, f"audio_files/{audio_file}") + + similarities = [] + for segment in data["segments"]: + audio_segment = asyncio.run( + get_audio_segment(audio_file_path, segment["start"], segment["end"]) + ) + audio_frames = bytearray(audio_segment.raw_data) + + file_path = asyncio.run( + save_audio_to_file( + audio_frames, self.session.get_record_audio_name(), audio_dir=RECORDS_DIR + ) + ) + self.asr.set_audio_data(file_path) + # self.asr.set_audio_data(audio_frames) + + transcription = asyncio.run(self.asr.transcribe(self.session))["text"] + + os.remove(file_path) + + embedding_1 = self.similarity_model.encode( + transcription.lower().strip(), convert_to_tensor=True + ) + embedding_2 = self.similarity_model.encode( + segment["transcription"].lower().strip(), convert_to_tensor=True + ) + similarity = util.pytorch_cos_sim(embedding_1, embedding_2).item() + similarities.append(similarity) + + print(f"\nSegment from '{audio_file}' ({segment['start']}-{segment['end']}s):") + print(f"Expected: {segment['transcription']}") + print(f"Actual: {transcription}") + # self.assertGreater(len(transcription), 0) + + # Calculate average similarity for the file + avg_similarity = sum(similarities) / len(similarities) + print(f"\nAverage similarity for '{audio_file}': {avg_similarity}") + + # Assert that the average similarity is above the threshold + # Adjust the threshold as needed + self.assertGreaterEqual(avg_similarity, 0.7)
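For reference, the following is a minimal sketch (not part of the patch) of how the streaming `Qwen2Code2wav` pipeline in `src/thirdparty/qwen2_code2wav/modeling_fast.py` can be driven chunk by chunk. The checkpoint paths, the speaker-conditioning tensors (`cond`, `ref_mel`) and their dimensions, and the chunk count are placeholders or simplifications; in the bot, the codec tokens come from the Qwen2.5-Omni talker, the conditioning comes from the model's speaker dictionary, and the engine's own scheduler accounts for the past/future cache sizes.

```python
import math

import torch

from src.thirdparty.qwen2_code2wav.modeling_fast import Qwen2Code2wav

device = "cuda"

# Placeholder checkpoint paths: the DiT (CFM) and BigVGAN weights shipped with
# the Qwen2.5-Omni talker; exact filenames depend on the model download.
code2wav = Qwen2Code2wav(
    dit_ckpt="/path/to/code2wav_dit.pt",
    bigvgan_ckpt="/path/to/code2wav_bigvgan.pt",
    steps=10,
    bs_mel=24,
    batched_chunk=3,
    frequency="50hz",
    device=device,
)
code2wav.enable_torch_compile()  # optional: torch.compile + per-shape CUDA graphs

# Placeholder inputs: in the bot, `codec` is the talker's VQ token stream and
# `cond` / `ref_mel` come from the speaker-conditioning dictionary; the shapes
# and the embedding dimension below are illustrative only.
codec = torch.randint(0, 8192, (1, 120), device=device)  # [1, n_codec_tokens]
cond = torch.zeros(1, 192, device=device)                 # speaker embedding (dim is checkpoint-specific)
ref_mel = torch.zeros(1, 300, 80, device=device)          # reference mel frames

total_mel = codec.shape[1] * code2wav.factor              # 2 mel frames per 50 Hz code
y_all = torch.randn(1, total_mel, 80, device=device)      # shared noise for all chunks

n_chunks = max(1, math.ceil(total_mel / code2wav.chunk_size))  # simplified scheduling
if n_chunks == 1:
    # Very short utterances are decoded in one shot.
    _, waveform = code2wav.process_little_chunk(
        cond, ref_mel, codec, y_all, 0, steps=10, prev_generated=None, finished=True
    )
else:
    prev, audio_chunks = None, []
    for i in range(n_chunks):
        prev, audio = code2wav.process_chunk(
            cond, ref_mel, codec, y_all, i, steps=10,
            prev_generated=prev, finished=(i == n_chunks - 1),
        )
        audio_chunks.append(audio)  # waveform pieces, 240 samples per mel frame
    waveform = torch.cat(audio_chunks, dim=0)
```

Each `process_chunk` call feeds the DiT only `past_cache + chunk + future_cache` mel frames, and the `future_size` overlap shared with the previous chunk keeps the BigVGAN output seamless across chunk boundaries.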