From 6b1a4dd3b620fa53d7ef9c8283785c6f2b72ab5d Mon Sep 17 00:00:00 2001
From: qinxuye
Date: Fri, 9 Aug 2024 07:10:56 +0000
Subject: [PATCH] fix video generation: export via imageio-ffmpeg, support
 response_format, add Video type

---
 setup.cfg                                     |  4 +
 xinference/client/restful/restful_client.py   |  2 +-
 xinference/deploy/docker/requirements.txt     |  1 +
 xinference/deploy/docker/requirements_cpu.txt |  1 +
 xinference/model/video/diffusers.py           | 87 +++++++++++++++++--
 xinference/types.py                           |  7 +-
 6 files changed, 91 insertions(+), 11 deletions(-)

diff --git a/setup.cfg b/setup.cfg
index 05e33f9b5a..8664c59fdd 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -95,6 +95,7 @@ all =
     sentence-transformers>=2.7.0
     vllm>=0.2.6 ; sys_platform=='linux'
     diffusers>=0.25.0 # fix conflict with matcha-tts
+    imageio-ffmpeg # For video
     controlnet_aux
     orjson
     auto-gptq ; sys_platform!='darwin'
@@ -158,6 +159,9 @@ rerank =
 image =
     diffusers>=0.25.0 # fix conflict with matcha-tts
     controlnet_aux
+video =
+    diffusers
+    imageio-ffmpeg
 audio =
     funasr
     omegaconf~=2.3.0
diff --git a/xinference/client/restful/restful_client.py b/xinference/client/restful/restful_client.py
index 24ba0b7038..c11c30c29f 100644
--- a/xinference/client/restful/restful_client.py
+++ b/xinference/client/restful/restful_client.py
@@ -402,7 +402,7 @@ def text_to_video(
         response = requests.post(url, json=request_body, headers=self.auth_headers)
         if response.status_code != 200:
             raise RuntimeError(
-                f"Failed to create the images, detail: {_get_error_string(response)}"
+                f"Failed to create the video, detail: {_get_error_string(response)}"
             )
 
         response_data = response.json()
diff --git a/xinference/deploy/docker/requirements.txt b/xinference/deploy/docker/requirements.txt
index 66f6d650af..1830a7de25 100644
--- a/xinference/deploy/docker/requirements.txt
+++ b/xinference/deploy/docker/requirements.txt
@@ -60,6 +60,7 @@ onnxruntime==1.16.0; sys_platform == 'darwin' or sys_platform == 'windows' # For CosyVoice
 openai-whisper # For CosyVoice
 boto3>=1.28.55,<1.28.65 # For tensorizer
 tensorizer~=2.9.0
+imageio-ffmpeg # For video
 
 # sglang
 outlines>=0.0.44
diff --git a/xinference/deploy/docker/requirements_cpu.txt b/xinference/deploy/docker/requirements_cpu.txt
index a117e0c549..7ae0a2544d 100644
--- a/xinference/deploy/docker/requirements_cpu.txt
+++ b/xinference/deploy/docker/requirements_cpu.txt
@@ -55,3 +55,4 @@ matcha-tts # For CosyVoice
 onnxruntime-gpu==1.16.0; sys_platform == 'linux' # For CosyVoice
 onnxruntime==1.16.0; sys_platform == 'darwin' or sys_platform == 'windows' # For CosyVoice
 openai-whisper # For CosyVoice
+imageio-ffmpeg # For video
diff --git a/xinference/model/video/diffusers.py b/xinference/model/video/diffusers.py
index 930cefa09e..0cd951aa62 100644
--- a/xinference/model/video/diffusers.py
+++ b/xinference/model/video/diffusers.py
@@ -12,18 +12,25 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import base64
 import logging
 import os
 import sys
+import tempfile
 import time
 import uuid
-from typing import TYPE_CHECKING
+from concurrent.futures import ThreadPoolExecutor
+from functools import partial
+from io import BytesIO
+from typing import TYPE_CHECKING, List, Optional, Union
 
+import numpy as np
+import PIL.Image
 import torch
 
 from ...constants import XINFERENCE_VIDEO_DIR
 from ...device_utils import move_model_to_available_device
-from ...types import VideoList
+from ...types import Video, VideoList
 
 if TYPE_CHECKING:
     from .core import VideoModelFamilyV1
 
@@ -32,6 +39,26 @@
 logger = logging.getLogger(__name__)
 
 
+def export_to_video_imageio(
+    video_frames: Union[List[np.ndarray], List["PIL.Image.Image"]],
+    output_video_path: Optional[str] = None,
+    fps: int = 8,
+) -> str:
+    """
+    Export video frames to a video file via the imageio library, to avoid the "green screen" issue (seen, for example, with CogVideoX)
+    """
+    import imageio
+
+    if output_video_path is None:
+        output_video_path = tempfile.NamedTemporaryFile(suffix=".mp4").name
+    if isinstance(video_frames[0], PIL.Image.Image):
+        video_frames = [np.array(frame) for frame in video_frames]
+    with imageio.get_writer(output_video_path, fps=fps) as writer:
+        for frame in video_frames:
+            writer.append_data(frame)
+    return output_video_path
+
+
 class DiffUsersVideoModel:
     def __init__(
         self,
@@ -66,7 +93,7 @@ def load(self):
             from diffusers import CogVideoXPipeline
 
             self._model = CogVideoXPipeline.from_pretrained(
-                self._model_path, torch_dtype=torch.float16
+                self._model_path, **self._kwargs
             )
         else:
             raise Exception(
@@ -86,32 +113,74 @@ def text_to_video(
         self,
         prompt: str,
         n: int = 1,
+        num_inference_steps: int = 50,
+        guidance_scale: int = 6,
+        response_format: str = "b64_json",
         **kwargs,
     ) -> VideoList:
-        from diffusers.utils import export_to_video
+        import gc
+
+        # a cv2 bug can make the exported video display incorrectly,
+        # so we use the imageio-based exporter defined above instead
+        # from diffusers.utils import export_to_video
+        from ...device_utils import empty_cache
 
         logger.debug(
             "diffusers text_to_video args: %s",
             kwargs,
         )
         assert self._model is not None
+        if self._kwargs.get("cpu_offload"):
+            # when cpu offload is enabled,
+            # model.device would report CPU, so target cuda explicitly
+            device = "cuda"
+        else:
+            device = self._model.device
         prompt_embeds, _ = self._model.encode_prompt(
             prompt=prompt,
             do_classifier_free_guidance=True,
             num_videos_per_prompt=n,
             max_sequence_length=226,
-            device=self._model.device,
+            device=device,
             dtype=torch.float16,
         )
         assert callable(self._model)
         output = self._model(
-            num_inference_steps=50,
-            guidance_scale=6,
+            num_inference_steps=num_inference_steps,
+            guidance_scale=guidance_scale,
             prompt_embeds=prompt_embeds,
+            **kwargs,
         )
+
+        # release cached GPU memory after generation
+        gc.collect()
+        empty_cache()
+
+        os.makedirs(XINFERENCE_VIDEO_DIR, exist_ok=True)
         urls = []
         for f in output.frames:
             path = os.path.join(XINFERENCE_VIDEO_DIR, uuid.uuid4().hex + ".mp4")
-            p = export_to_video(f, path, fps=8)
+            p = export_to_video_imageio(f, path, fps=8)
             urls.append(p)
-        return VideoList(created=int(time.time()), data=urls)
+        if response_format == "url":
+            return VideoList(
+                created=int(time.time()),
+                data=[Video(url=url, b64_json=None) for url in urls],
+            )
+        elif response_format == "b64_json":
+
+            def _gen_base64_video(_video_url):
+                try:
+                    with open(_video_url, "rb") as f:
+                        buffered = BytesIO()
+                        buffered.write(f.read())
+                        return base64.b64encode(buffered.getvalue()).decode()
+                finally:
+                    os.remove(_video_url)
+
+            with ThreadPoolExecutor() as executor:
+                results = list(map(partial(executor.submit, _gen_base64_video), urls))  # type: ignore
+            video_list = [Video(url=None, b64_json=s.result()) for s in results]
+            return VideoList(created=int(time.time()), data=video_list)
+        else:
+            raise ValueError(f"Unsupported response format: {response_format}")
diff --git a/xinference/types.py b/xinference/types.py
index 4e1106ae51..3f636d94c3 100644
--- a/xinference/types.py
+++ b/xinference/types.py
@@ -52,9 +52,14 @@ class ImageList(TypedDict):
     data: List[Image]
 
 
+class Video(TypedDict):
+    url: Optional[str]
+    b64_json: Optional[str]
+
+
 class VideoList(TypedDict):
     created: int
-    data: List[str]
+    data: List[Video]
 
 
 class EmbeddingUsage(TypedDict):
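
Note for reviewers (usage sketch, not part of the patch): the snippet below exercises the updated text_to_video path end to end. The endpoint URL and model UID are hypothetical placeholders, and it assumes an Xinference server with a CogVideoX-style video model already launched.

    import base64

    from xinference.client import RESTfulClient

    # Placeholder endpoint and model UID; adjust to your deployment.
    client = RESTfulClient("http://127.0.0.1:9997")
    model = client.get_model("my-video-model")

    # The server-side default response_format is "b64_json" (see
    # DiffUsersVideoModel.text_to_video above), so each Video entry carries
    # base64-encoded MP4 bytes; with response_format="url" it carries a path.
    result = model.text_to_video(prompt="a panda eating bamboo by a river")
    video = result["data"][0]
    if video["b64_json"] is not None:
        with open("output.mp4", "wb") as f:
            f.write(base64.b64decode(video["b64_json"]))
    else:
        print("video file:", video["url"])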