Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
qinxuye committed Aug 9, 2024
1 parent f5cbbd6 commit 6b1a4dd
Show file tree
Hide file tree
Showing 6 changed files with 91 additions and 11 deletions.
4 changes: 4 additions & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ all =
sentence-transformers>=2.7.0
vllm>=0.2.6 ; sys_platform=='linux'
diffusers>=0.25.0 # fix conflict with matcha-tts
imageio-ffmpeg # For video
controlnet_aux
orjson
auto-gptq ; sys_platform!='darwin'
Expand Down Expand Up @@ -158,6 +159,9 @@ rerank =
image =
diffusers>=0.25.0 # fix conflict with matcha-tts
controlnet_aux
video =
diffusers
imageio-ffmpeg
audio =
funasr
omegaconf~=2.3.0
Expand Down
2 changes: 1 addition & 1 deletion xinference/client/restful/restful_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,7 +402,7 @@ def text_to_video(
response = requests.post(url, json=request_body, headers=self.auth_headers)
if response.status_code != 200:
raise RuntimeError(
f"Failed to create the images, detail: {_get_error_string(response)}"
f"Failed to create the video, detail: {_get_error_string(response)}"
)

response_data = response.json()
Expand Down
1 change: 1 addition & 0 deletions xinference/deploy/docker/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ onnxruntime==1.16.0; sys_platform == 'darwin' or sys_platform == 'windows' # Fo
openai-whisper # For CosyVoice
boto3>=1.28.55,<1.28.65 # For tensorizer
tensorizer~=2.9.0
imageio-ffmpeg # For video

# sglang
outlines>=0.0.44
Expand Down
1 change: 1 addition & 0 deletions xinference/deploy/docker/requirements_cpu.txt
Original file line number Diff line number Diff line change
Expand Up @@ -55,3 +55,4 @@ matcha-tts # For CosyVoice
onnxruntime-gpu==1.16.0; sys_platform == 'linux' # For CosyVoice
onnxruntime==1.16.0; sys_platform == 'darwin' or sys_platform == 'windows' # For CosyVoice
openai-whisper # For CosyVoice
imageio-ffmpeg # For video
87 changes: 78 additions & 9 deletions xinference/model/video/diffusers.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,18 +12,25 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import base64
import logging
import os
import sys
import tempfile
import time
import uuid
from typing import TYPE_CHECKING
from concurrent.futures import ThreadPoolExecutor
from functools import partial
from io import BytesIO
from typing import TYPE_CHECKING, List, Optional, Union

import numpy as np
import PIL.Image
import torch

from ...constants import XINFERENCE_VIDEO_DIR
from ...device_utils import move_model_to_available_device
from ...types import VideoList
from ...types import Video, VideoList

if TYPE_CHECKING:
from .core import VideoModelFamilyV1
Expand All @@ -32,6 +39,26 @@
logger = logging.getLogger(__name__)


def export_to_video_imageio(
video_frames: Union[List[np.ndarray], List["PIL.Image.Image"]],
output_video_path: Optional[str] = None,
fps: int = 8,
) -> str:
"""
Export the video frames to a video file using imageio lib to Avoid "green screen" issue (for example CogVideoX)
"""
import imageio

if output_video_path is None:
output_video_path = tempfile.NamedTemporaryFile(suffix=".mp4").name
if isinstance(video_frames[0], PIL.Image.Image):
video_frames = [np.array(frame) for frame in video_frames]
with imageio.get_writer(output_video_path, fps=fps) as writer:
for frame in video_frames:
writer.append_data(frame)
return output_video_path


class DiffUsersVideoModel:
def __init__(
self,
Expand Down Expand Up @@ -66,7 +93,7 @@ def load(self):
from diffusers import CogVideoXPipeline

self._model = CogVideoXPipeline.from_pretrained(
self._model_path, torch_dtype=torch.float16
self._model_path, **self._kwargs
)
else:
raise Exception(
Expand All @@ -86,32 +113,74 @@ def text_to_video(
self,
prompt: str,
n: int = 1,
num_inference_steps: int = 50,
guidance_scale: int = 6,
response_format: str = "b64_json",
**kwargs,
) -> VideoList:
from diffusers.utils import export_to_video
import gc

# cv2 bug will cause the video cannot be normally displayed
# thus we use the imageio one
# from diffusers.utils import export_to_video
from ...device_utils import empty_cache

logger.debug(
"diffusers text_to_video args: %s",
kwargs,
)
assert self._model is not None
if self._kwargs.get("cpu_offload"):
# if enabled cpu offload,
# the model.device would be CPU
device = "cuda"
else:
device = self._model.device
prompt_embeds, _ = self._model.encode_prompt(
prompt=prompt,
do_classifier_free_guidance=True,
num_videos_per_prompt=n,
max_sequence_length=226,
device=self._model.device,
device=device,
dtype=torch.float16,
)
assert callable(self._model)
output = self._model(
num_inference_steps=50,
guidance_scale=6,
num_inference_steps=num_inference_steps,
guidance_scale=guidance_scale,
prompt_embeds=prompt_embeds,
**kwargs,
)

# clean cache
gc.collect()
empty_cache()

os.makedirs(XINFERENCE_VIDEO_DIR, exist_ok=True)
urls = []
for f in output.frames:
path = os.path.join(XINFERENCE_VIDEO_DIR, uuid.uuid4().hex + ".mp4")
p = export_to_video(f, path, fps=8)
p = export_to_video_imageio(f, path, fps=8)
urls.append(p)
return VideoList(created=int(time.time()), data=urls)
if response_format == "url":
return VideoList(
created=int(time.time()),
data=[Video(url=url, b64_json=None) for url in urls],
)
elif response_format == "b64_json":

def _gen_base64_video(_video_url):
try:
with open(_video_url, "rb") as f:
buffered = BytesIO()
buffered.write(f.read())
return base64.b64encode(buffered.getvalue()).decode()
finally:
os.remove(_video_url)

with ThreadPoolExecutor() as executor:
results = list(map(partial(executor.submit, _gen_base64_video), urls)) # type: ignore
video_list = [Video(url=None, b64_json=s.result()) for s in results]
return VideoList(created=int(time.time()), data=video_list)
else:
raise ValueError(f"Unsupported response format: {response_format}")
7 changes: 6 additions & 1 deletion xinference/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,14 @@ class ImageList(TypedDict):
data: List[Image]


class Video(TypedDict):
url: Optional[str]
b64_json: Optional[str]


class VideoList(TypedDict):
created: int
data: List[str]
data: List[Video]


class EmbeddingUsage(TypedDict):
Expand Down

0 comments on commit 6b1a4dd

Please sign in to comment.