From 94e0fe7f3019f5824081e8452881114ce1efa07c Mon Sep 17 00:00:00 2001
From: seungwoos
Date: Wed, 5 Feb 2025 18:31:11 +0900
Subject: [PATCH 1/3] Add computed position embedding external

---
 awq/quantize/quantizer.py | 10 ++++++++++
 1 file changed, 10 insertions(+)

diff --git a/awq/quantize/quantizer.py b/awq/quantize/quantizer.py
index 9823ad19..02a7047c 100644
--- a/awq/quantize/quantizer.py
+++ b/awq/quantize/quantizer.py
@@ -1,3 +1,4 @@
+import transformers
 import torch
 import inspect
 import logging
@@ -153,6 +154,15 @@ def quantize(self):
             # https://github.com/huggingface/transformers/pull/32617
             self.awq_model.move_embed(self.model, common_device)
 
+            # Transformers >= 4.48.0 requires position embeddings to be computed before the forward pass
+            if (
+                transformers.__version__ >= "4.48.0"
+                and self.module_kwargs.get("position_embeddings") is None
+            ):
+                self.module_kwargs["position_embeddings"] = self.model.model.rotary_emb(
+                    self.inps, self.module_kwargs["position_ids"]
+                )
+
             for k, v in self.module_kwargs.items():
                 # position embeddings found in tuple
                 if isinstance(v, tuple):

From ea724a17e226054a359db25f357768a27549abce Mon Sep 17 00:00:00 2001
From: seungwoos
Date: Wed, 5 Feb 2025 18:48:50 +0900
Subject: [PATCH 2/3] Add Qwen2.5 VL model support

---
 awq/models/__init__.py   |  1 +
 awq/models/auto.py       |  1 +
 awq/models/base.py       |  3 +-
 awq/models/qwen2_5_vl.py | 81 ++++++++++++++++++++++++++++++++++++++++
 4 files changed, 85 insertions(+), 1 deletion(-)
 create mode 100644 awq/models/qwen2_5_vl.py

diff --git a/awq/models/__init__.py b/awq/models/__init__.py
index 9be2e8bd..52e42d9b 100644
--- a/awq/models/__init__.py
+++ b/awq/models/__init__.py
@@ -29,3 +29,4 @@
 from .internlm2 import InternLM2AWQForCausalLM
 from .minicpm3 import MiniCPM3AWQForCausalLM
 from .qwen2vl import Qwen2VLAWQForCausalLM
+from .qwen2_5_vl import Qwen2_5_VLAWQForCausalLM
\ No newline at end of file
diff --git a/awq/models/auto.py b/awq/models/auto.py
index e88ba584..ea90a344 100644
--- a/awq/models/auto.py
+++ b/awq/models/auto.py
@@ -40,6 +40,7 @@
     "internlm2": InternLM2AWQForCausalLM,
     "minicpm3": MiniCPM3AWQForCausalLM,
     "qwen2_vl": Qwen2VLAWQForCausalLM,
+    "qwen2_5_vl": Qwen2_5_VLAWQForCausalLM,
 }
diff --git a/awq/models/base.py b/awq/models/base.py
index e63e903e..ac0baabf 100644
--- a/awq/models/base.py
+++ b/awq/models/base.py
@@ -84,9 +84,10 @@
     "deepseek_v2": "AutoModelForCausalLM",
     "deepseek_v3": "AutoModelForCausalLM",
     "minicpm": "AutoModelForCausalLM",
-    "minicpm3":"AutoModelForCausalLM",
+    "minicpm3": "AutoModelForCausalLM",
     "internlm2": "AutoModelForCausalLM",
     "qwen2_vl": "AutoModelForVision2Seq",
+    "qwen2_5_vl": "AutoModelForVision2Seq",
 }
diff --git a/awq/models/qwen2_5_vl.py b/awq/models/qwen2_5_vl.py
new file mode 100644
index 00000000..849103b3
--- /dev/null
+++ b/awq/models/qwen2_5_vl.py
@@ -0,0 +1,81 @@
+from .base import BaseAWQForCausalLM
+from typing_extensions import TYPE_CHECKING
+
+if TYPE_CHECKING:
+    from transformers import Qwen2_5_VLForConditionalGeneration
+    from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
+        Qwen2_5_VLDecoderLayer,
+    )
+
+
+class Qwen2_5_VLAWQForCausalLM(BaseAWQForCausalLM):
+    layer_type = "Qwen2_5_VLDecoderLayer"
+    max_seq_len_key = "max_position_embeddings"
+    modules_to_not_convert = ["visual"]
+
+    @staticmethod
+    def get_model_layers(model: "Qwen2_5_VLForConditionalGeneration"):
+        return model.model.layers
+
+    @staticmethod
+    def get_act_for_scaling(module: "Qwen2_5_VLForConditionalGeneration"):
+        return dict(is_scalable=False)
+
+    @staticmethod
+    def move_embed(model: "Qwen2_5_VLForConditionalGeneration", device: str):
+        model.model.embed_tokens = model.model.embed_tokens.to(device)
+        model.visual = model.visual.to(device)
+        model.model.rotary_emb = model.model.rotary_emb.to(device)
+
+    @staticmethod
+    def get_layers_for_scaling(
+        module: "Qwen2_5_VLDecoderLayer", input_feat, module_kwargs
+    ):
+        layers = []
+
+        # attention input
+        layers.append(
+            dict(
+                prev_op=module.input_layernorm,
+                layers=[
+                    module.self_attn.q_proj,
+                    module.self_attn.k_proj,
+                    module.self_attn.v_proj,
+                ],
+                inp=input_feat["self_attn.q_proj"],
+                module2inspect=module.self_attn,
+                kwargs=module_kwargs,
+            )
+        )
+
+        # attention out
+        # Please refer to https://github.com/mit-han-lab/llm-awq/pull/67#issue-1850622696
+        if module.self_attn.v_proj.weight.shape == module.self_attn.o_proj.weight.shape:
+            layers.append(
+                dict(
+                    prev_op=module.self_attn.v_proj,
+                    layers=[module.self_attn.o_proj],
+                    inp=input_feat["self_attn.o_proj"],
+                )
+            )
+
+        # linear 1
+        layers.append(
+            dict(
+                prev_op=module.post_attention_layernorm,
+                layers=[module.mlp.gate_proj, module.mlp.up_proj],
+                inp=input_feat["mlp.gate_proj"],
+                module2inspect=module.mlp,
+            )
+        )
+
+        # linear 2
+        layers.append(
+            dict(
+                prev_op=module.mlp.up_proj,
+                layers=[module.mlp.down_proj],
+                inp=input_feat["mlp.down_proj"],
+            )
+        )
+
+        return layers

From 4448165791296d843fbe7588c1d5973a251a9ff1 Mon Sep 17 00:00:00 2001
From: seungwoos
Date: Thu, 6 Feb 2025 17:11:35 +0900
Subject: [PATCH 3/3] Update Qwen VL Utils

---
 awq/utils/qwen_vl_utils.py | 141 +++++++++++++++++++++++++++++--------
 1 file changed, 112 insertions(+), 29 deletions(-)

diff --git a/awq/utils/qwen_vl_utils.py b/awq/utils/qwen_vl_utils.py
index 08ba02f7..083ce266 100644
--- a/awq/utils/qwen_vl_utils.py
+++ b/awq/utils/qwen_vl_utils.py
@@ -17,6 +17,7 @@
 from PIL import Image
 from torchvision import io, transforms
 from torchvision.transforms import InterpolationMode
+from typing import Optional
 
 logger = logging.getLogger(__name__)
 
@@ -28,12 +29,19 @@
 VIDEO_MIN_PIXELS = 128 * 28 * 28
 VIDEO_MAX_PIXELS = 768 * 28 * 28
-VIDEO_TOTAL_PIXELS = 24576 * 28 * 28
 FRAME_FACTOR = 2
 FPS = 2.0
 FPS_MIN_FRAMES = 4
 FPS_MAX_FRAMES = 768
 
+# Set the maximum number of video token inputs.
+# Here, 128K represents the maximum number of input tokens for the VLLM model.
+# Remember to adjust it according to your own configuration.
+VIDEO_TOTAL_PIXELS = int(
+    float(os.environ.get("VIDEO_MAX_PIXELS", 128000 * 28 * 28 * 0.9))
+)
+logger.info(f"set VIDEO_TOTAL_PIXELS: {VIDEO_TOTAL_PIXELS}")
+
 
 def round_by_factor(number: int, factor: int) -> int:
     """Returns the closest integer to 'number' that is divisible by 'factor'."""
@@ -51,7 +59,11 @@ def floor_by_factor(number: int, factor: int) -> int:
 
 
 def smart_resize(
-    height: int, width: int, factor: int = IMAGE_FACTOR, min_pixels: int = MIN_PIXELS, max_pixels: int = MAX_PIXELS
+    height: int,
+    width: int,
+    factor: int = IMAGE_FACTOR,
+    min_pixels: int = MIN_PIXELS,
+    max_pixels: int = MAX_PIXELS,
 ) -> tuple[int, int]:
     """
     Rescales the image so that the following conditions are met:
@@ -79,7 +91,20 @@
     return h_bar, w_bar
 
 
+def to_rgb(pil_image: Image.Image) -> Image.Image:
+    if pil_image.mode == "RGBA":
+        white_background = Image.new("RGB", pil_image.size, (255, 255, 255))
+        white_background.paste(
+            pil_image, mask=pil_image.split()[3]
+        )  # Use alpha channel as mask
+        return white_background
+    else:
+        return pil_image.convert("RGB")
+
+
-def fetch_image(ele: dict[str, str | Image.Image], size_factor: int = IMAGE_FACTOR) -> Image.Image:
+def fetch_image(
+    ele: dict[str, str | Image.Image], size_factor: int = IMAGE_FACTOR
+) -> Image.Image:
     if "image" in ele:
         image = ele["image"]
     else:
@@ -88,7 +113,8 @@ def fetch_image(ele: dict[str, str | Image.Image], size_factor: int = IMAGE_FACT
     if isinstance(image, Image.Image):
         image_obj = image
     elif image.startswith("http://") or image.startswith("https://"):
-        image_obj = Image.open(requests.get(image, stream=True).raw)
+        response = requests.get(image, stream=True)
+        image_obj = Image.open(BytesIO(response.content))
     elif image.startswith("file://"):
         image_obj = Image.open(image[7:])
     elif image.startswith("data:image"):
@@ -99,8 +125,10 @@ def fetch_image(ele: dict[str, str | Image.Image], size_factor: int = IMAGE_FACT
     else:
         image_obj = Image.open(image)
     if image_obj is None:
-        raise ValueError(f"Unrecognized image input, support local path, http url, base64 and PIL.Image, got {image}")
-    image = image_obj.convert("RGB")
+        raise ValueError(
+            f"Unrecognized image input, support local path, http url, base64 and PIL.Image, got {image}"
+        )
+    image = to_rgb(image_obj)
     ## resize
     if "resized_height" in ele and "resized_width" in ele:
         resized_height, resized_width = smart_resize(
@@ -147,24 +175,34 @@ def smart_nframes(
     Returns:
         int: the number of frames for video used for model inputs.
""" - assert not ("fps" in ele and "nframes" in ele), "Only accept either `fps` or `nframes`" + assert not ("fps" in ele and "nframes" in ele), ( + "Only accept either `fps` or `nframes`" + ) if "nframes" in ele: nframes = round_by_factor(ele["nframes"], FRAME_FACTOR) else: fps = ele.get("fps", FPS) min_frames = ceil_by_factor(ele.get("min_frames", FPS_MIN_FRAMES), FRAME_FACTOR) - max_frames = floor_by_factor(ele.get("max_frames", min(FPS_MAX_FRAMES, total_frames)), FRAME_FACTOR) + max_frames = floor_by_factor( + ele.get("max_frames", min(FPS_MAX_FRAMES, total_frames)), FRAME_FACTOR + ) nframes = total_frames / video_fps * fps - nframes = min(max(nframes, min_frames), max_frames) - nframes = round_by_factor(nframes, FRAME_FACTOR) + if nframes > total_frames: + logger.warning( + f"smart_nframes: nframes[{nframes}] > total_frames[{total_frames}]" + ) + nframes = min(min(max(nframes, min_frames), max_frames), total_frames) + nframes = floor_by_factor(nframes, FRAME_FACTOR) if not (FRAME_FACTOR <= nframes and nframes <= total_frames): - raise ValueError(f"nframes should in interval [{FRAME_FACTOR}, {total_frames}], but got {nframes}.") + raise ValueError( + f"nframes should in interval [{FRAME_FACTOR}, {total_frames}], but got {nframes}." + ) return nframes def _read_video_torchvision( ele: dict, -) -> torch.Tensor: +) -> (torch.Tensor, float): """read video using torchvision.io.read_video Args: @@ -179,7 +217,9 @@ def _read_video_torchvision( video_path = ele["video"] if version.parse(torchvision.__version__) < version.parse("0.19.0"): if "http://" in video_path or "https://" in video_path: - warnings.warn("torchvision < 0.19.0 does not support http/https video path, please upgrade to 0.19.0.") + warnings.warn( + "torchvision < 0.19.0 does not support http/https video path, please upgrade to 0.19.0." + ) if "file://" in video_path: video_path = video_path[7:] st = time.time() @@ -191,11 +231,14 @@ def _read_video_torchvision( output_format="TCHW", ) total_frames, video_fps = video.size(0), info["video_fps"] - logger.info(f"torchvision: {video_path=}, {total_frames=}, {video_fps=}, time={time.time() - st:.3f}s") + logger.info( + f"torchvision: {video_path=}, {total_frames=}, {video_fps=}, time={time.time() - st:.3f}s" + ) nframes = smart_nframes(ele, total_frames=total_frames, video_fps=video_fps) idx = torch.linspace(0, total_frames - 1, nframes).round().long() + sample_fps = nframes / max(total_frames, 1e-6) * video_fps video = video[idx] - return video + return video, sample_fps def is_decord_available() -> bool: @@ -206,7 +249,7 @@ def is_decord_available() -> bool: def _read_video_decord( ele: dict, -) -> torch.Tensor: +) -> (torch.Tensor, float): """read video using decord.VideoReader Args: @@ -219,19 +262,25 @@ def _read_video_decord( torch.Tensor: the video tensor with shape (T, C, H, W). """ import decord + video_path = ele["video"] st = time.time() vr = decord.VideoReader(video_path) # TODO: support start_pts and end_pts - if 'video_start' in ele or 'video_end' in ele: - raise NotImplementedError("not support start_pts and end_pts in decord for now.") + if "video_start" in ele or "video_end" in ele: + raise NotImplementedError( + "not support start_pts and end_pts in decord for now." 
+        )
     total_frames, video_fps = len(vr), vr.get_avg_fps()
-    logger.info(f"decord: {video_path=}, {total_frames=}, {video_fps=}, time={time.time() - st:.3f}s")
+    logger.info(
+        f"decord: {video_path=}, {total_frames=}, {video_fps=}, time={time.time() - st:.3f}s"
+    )
     nframes = smart_nframes(ele, total_frames=total_frames, video_fps=video_fps)
     idx = torch.linspace(0, total_frames - 1, nframes).round().long().tolist()
     video = vr.get_batch(idx).asnumpy()
     video = torch.tensor(video).permute(0, 3, 1, 2)  # Convert to TCHW format
-    return video
+    sample_fps = nframes / max(total_frames, 1e-6) * video_fps
+    return video, sample_fps
 
 
 VIDEO_READER_BACKENDS = {
@@ -254,16 +303,32 @@ def get_video_reader_backend() -> str:
     return video_reader_backend
 
 
-def fetch_video(ele: dict, image_factor: int = IMAGE_FACTOR) -> torch.Tensor | list[Image.Image]:
+def fetch_video(
+    ele: dict, image_factor: int = IMAGE_FACTOR, return_video_sample_fps: bool = False
+) -> torch.Tensor | list[Image.Image]:
     if isinstance(ele["video"], str):
         video_reader_backend = get_video_reader_backend()
-        video = VIDEO_READER_BACKENDS[video_reader_backend](ele)
-        nframes, _, height, width = video.shape
+        try:
+            video, sample_fps = VIDEO_READER_BACKENDS[video_reader_backend](ele)
+        except Exception as e:
+            logger.warning(
+                f"video_reader_backend {video_reader_backend} error, use torchvision as default, msg: {e}"
+            )
+            video, sample_fps = VIDEO_READER_BACKENDS["torchvision"](ele)
+        nframes, _, height, width = video.shape
         min_pixels = ele.get("min_pixels", VIDEO_MIN_PIXELS)
         total_pixels = ele.get("total_pixels", VIDEO_TOTAL_PIXELS)
-        max_pixels = max(min(VIDEO_MAX_PIXELS, total_pixels / nframes * FRAME_FACTOR), int(min_pixels * 1.05))
-        max_pixels = ele.get("max_pixels", max_pixels)
+        max_pixels = max(
+            min(VIDEO_MAX_PIXELS, total_pixels / nframes * FRAME_FACTOR),
+            int(min_pixels * 1.05),
+        )
+        max_pixels_supposed = ele.get("max_pixels", max_pixels)
+        if max_pixels_supposed > max_pixels:
+            logger.warning(
+                f"The given max_pixels[{max_pixels_supposed}] exceeds limit[{max_pixels}]."
+            )
+        max_pixels = min(max_pixels_supposed, max_pixels)
         if "resized_height" in ele and "resized_width" in ele:
             resized_height, resized_width = smart_resize(
                 ele["resized_height"],
@@ -284,6 +349,8 @@
             interpolation=InterpolationMode.BICUBIC,
             antialias=True,
         ).float()
+        if return_video_sample_fps:
+            return video, sample_fps
         return video
     else:
         assert isinstance(ele["video"], (list, tuple))
         process_info.pop("type", None)
         process_info.pop("video", None)
         images = [
-            fetch_image({"image": video_element, **process_info}, size_factor=image_factor)
+            fetch_image(
+                {"image": video_element, **process_info}, size_factor=image_factor
+            )
             for video_element in ele["video"]
         ]
         nframes = ceil_by_factor(len(images), FRAME_FACTOR)
         if len(images) < nframes:
             images.extend([images[-1]] * (nframes - len(images)))
+        if return_video_sample_fps:
+            return images, process_info.pop("fps", 2.0)
         return images
 
 
@@ -320,20 +391,32 @@ def extract_vision_info(conversations: list[dict] | list[list[dict]]) -> list[di
 
 def process_vision_info(
     conversations: list[dict] | list[list[dict]],
-) -> tuple[list[Image.Image] | None, list[torch.Tensor | list[Image.Image]] | None]:
+    return_video_kwargs: bool = False,
+) -> tuple[
+    list[Image.Image] | None,
+    list[torch.Tensor | list[Image.Image]] | None,
+    Optional[dict],
+]:
     vision_infos = extract_vision_info(conversations)
     ## Read images or videos
     image_inputs = []
     video_inputs = []
+    video_sample_fps_list = []
    for vision_info in vision_infos:
         if "image" in vision_info or "image_url" in vision_info:
             image_inputs.append(fetch_image(vision_info))
         elif "video" in vision_info:
-            video_inputs.append(fetch_video(vision_info))
+            video_input, video_sample_fps = fetch_video(
+                vision_info, return_video_sample_fps=True
+            )
+            video_sample_fps_list.append(video_sample_fps)
+            video_inputs.append(video_input)
         else:
             raise ValueError("image, image_url or video should in content.")
     if len(image_inputs) == 0:
         image_inputs = None
     if len(video_inputs) == 0:
         video_inputs = None
-    return image_inputs, video_inputs
\ No newline at end of file
+    if return_video_kwargs:
+        return image_inputs, video_inputs, {"fps": video_sample_fps_list}
+    return image_inputs, video_inputs
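Usage note (illustrative, not part of the patches above): with the qwen2_5_vl mapping registered in awq/models/auto.py, a Qwen2.5-VL checkpoint should be quantizable through the generic AutoAWQ entry point. The checkpoint id, output path, quant settings, and the default text-only calibration below are assumptions made for this sketch, not something the patches prescribe.

    # Rough sketch: exercise the new Qwen2_5_VLAWQForCausalLM mapping.
    # "Qwen/Qwen2.5-VL-7B-Instruct" and quant_config values are illustrative placeholders.
    from awq import AutoAWQForCausalLM
    from transformers import AutoTokenizer

    model_path = "Qwen/Qwen2.5-VL-7B-Instruct"
    quant_path = "qwen2_5_vl-awq-w4-g128"
    quant_config = {"zero_point": True, "q_group_size": 128, "w_bit": 4, "version": "GEMM"}

    # model_type "qwen2_5_vl" resolves to Qwen2_5_VLAWQForCausalLM (patch 2), and
    # modules_to_not_convert=["visual"] keeps the vision tower in full precision.
    model = AutoAWQForCausalLM.from_pretrained(model_path)
    tokenizer = AutoTokenizer.from_pretrained(model_path)

    model.quantize(tokenizer, quant_config=quant_config)  # default text calibration data (assumption)
    model.save_quantized(quant_path)
    tokenizer.save_pretrained(quant_path)

    # The updated awq/utils/qwen_vl_utils.py (patch 3) can also hand back per-video sampling fps
    # for multimodal inputs, e.g.:
    #   images, videos, video_kwargs = process_vision_info(messages, return_video_kwargs=True)
    # where "messages" is a hypothetical chat-format conversation list.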