-
-
Notifications
You must be signed in to change notification settings - Fork 15.7k
[Model] EVS support for nano_nemotron_vl #26267
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
c9388c7
86502dc
219bc0b
23a205f
e8fd68a
859e9f1
69ea5b8
0adec4b
d1a4d41
a7417d0
20fbfd7
46a2847
f9bf392
8003828
d4dc907
0261d11
55327c7
c693625
daaf453
671b93c
577110c
21face0
2e38ecf
776946a
1476b1c
7d60078
7baeed5
4021798
1572926
4256477
ba97f4f
bff1764
2a2c0b5
65944e5
eae25d9
1fdef63
d2195ab
733e515
893c7f8
1bdb001
0d64369
fd3f60f
6767f8c
f540576
62b3535
23fcf23
74f323b
175c835
c058872
f960f1e
2432e04
1f9d23d
b85d33b
e6681b4
680223f
055680f
9be6890
369f144
8e00d2e
8071c5a
edf0b6e
4ba3705
670382a
1212587
51ee4c9
73e138a
c71f8ef
d26bae4
5dd79da
5a8a8fc
e313609
54b8e41
7e71da5
5d22264
2add6d5
4ef812a
45b3629
efc7a1b
4b427d8
4d7c7eb
04d85e2
f882803
56c7852
0a212d5
4fff719
ac1dec8
abc3966
79c8bed
ec625a7
6137ac0
ee5f2ad
eeb4b15
ff8945d
635f277
a77f694
37c2551
2876b00
16c4ce6
6379eae
a26a1d3
3389e2a
63825a2
cc258ed
268ef21
001a19c
13dcdb5
5508cce
0b5de21
298e730
c2a2acd
88fb7b4
bf6ddfa
3064f88
63c869d
4e88df0
a2079d6
82e112e
9fba170
0e13c0a
a9e50dd
f7502c5
f021439
09ffe07
0a3b75c
60e9d4f
a05aa92
48f7031
2b4eadc
4023428
d0b6bef
7a9f450
16414a0
ec37d88
cf7f947
e0ad480
03386b2
9c3d84f
5a0dcdf
4298159
fe4577d
55e7e7e
6d4463d
7cf2f77
c9ae940
17546d5
3642f77
674a6cd
a0862bf
35ea5af
0ee1039
f68af14
43b6959
07f7a9a
7fe088c
c47afb0
1e50901
0505a94
46e1130
cf78827
505ce80
151293b
85b632d
371651b
abe8a61
030ccbf
4aa7dd6
2bbd103
516f106
a63a36a
e3b1d98
0e8da6c
f82a350
334ca27
5bdc29b
70d9843
b950e54
e33893e
d7ccd65
668ba11
2aa85d7
bad8d59
318f3eb
652a359
1b2424f
29a4b3d
d6e4f05
e32e5e3
c63b1fe
3a8bfdb
bea94fb
dc12348
bf02a1e
9d589a0
4ba5334
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -44,6 +44,10 @@ | |
| maybe_prefix, | ||
| ) | ||
| from vllm.multimodal import MULTIMODAL_REGISTRY | ||
| from vllm.multimodal.evs import ( | ||
| compute_retained_tokens_count, | ||
| compute_retention_mask, | ||
| ) | ||
| from vllm.multimodal.inputs import ( | ||
| MultiModalDataDict, | ||
| MultiModalFieldConfig, | ||
|
|
@@ -62,13 +66,20 @@ | |
| PromptReplacement, | ||
| PromptUpdate, | ||
| PromptUpdateDetails, | ||
| _seq2tokens, | ||
| ) | ||
| from vllm.multimodal.profiling import BaseDummyInputsBuilder | ||
| from vllm.sequence import IntermediateTensors | ||
| from vllm.transformers_utils.configs.radio import RadioConfig | ||
| from vllm.transformers_utils.tokenizer import AnyTokenizer | ||
| from vllm.transformers_utils.tokenizer import ( | ||
| AnyTokenizer, | ||
| cached_tokenizer_from_config, | ||
| encode_tokens, | ||
| ) | ||
| from vllm.utils.tensor_schema import TensorSchema, TensorShape | ||
|
|
||
| from .utils import _merge_multimodal_embeddings | ||
|
|
||
| # Configure PIL to handle large images without warnings | ||
| # This prevents DecompressionBombWarning for legitimate large images | ||
| Image.MAX_IMAGE_PIXELS = None # Disable the limit entirely | ||
|
|
@@ -382,6 +393,7 @@ def __init__( | |
| max_dynamic_patch: Optional[int] = None, | ||
| dynamic_image_size: Optional[bool] = None, | ||
| video_token: Optional[str] = None, | ||
| video_pruning_rate: Optional[float] = None, | ||
| ) -> None: | ||
| super().__init__( | ||
| config=config, | ||
|
|
@@ -392,6 +404,7 @@ def __init__( | |
| ) | ||
| # add extra video token for video processing | ||
| self.video_token = video_token | ||
| self.video_pruning_rate = video_pruning_rate | ||
|
|
||
| @property | ||
| def supports_video(self) -> bool: | ||
|
|
@@ -446,12 +459,38 @@ def _preprocess_video( | |
| ), | ||
| } | ||
|
|
||
| image_size: int = self.config.force_image_size | ||
| patch_size: int = self.config.patch_size | ||
| downsample_ratio = self.config.downsample_ratio | ||
| tokens_per_frame = int( | ||
| (image_size * image_size // patch_size**2) * (downsample_ratio**2) | ||
| ) | ||
|
|
||
| for pixel_values in pixel_values_lst_video: | ||
| num_patches = pixel_values.shape[0] | ||
| num_frames = pixel_values.shape[0] | ||
|
|
||
| if ( | ||
| self.video_pruning_rate is not None | ||
| and self.video_pruning_rate > 0.0 | ||
| ): | ||
| # Start of EVS-specific code | ||
| num_tokens = compute_retained_tokens_count( | ||
| tokens_per_frame=tokens_per_frame, | ||
| num_frames=num_frames, | ||
| q=self.video_pruning_rate, | ||
| ) | ||
|
|
||
| # Here we just need placeholders that won't actually be replaced - | ||
| # we just need to make sure the total number of tokens is correct | ||
| # assign all tokens to the first frame | ||
| tokens_per_frame = [num_tokens] + [0] * (num_frames - 1) | ||
|
|
||
| # End of EVS-specific code | ||
| else: | ||
| tokens_per_frame = [tokens_per_frame] * num_frames | ||
|
|
||
| video_repl = self.get_video_repl(tokens_per_frame, self.video_token) | ||
|
|
||
| video_repl = self.get_video_repl( | ||
| self.num_image_token, num_patches, self.video_token | ||
| ) | ||
| text = [t.replace("<video>", video_repl.full, 1) for t in text] | ||
| return text, video_inputs | ||
|
|
||
|
|
@@ -501,20 +540,40 @@ def get_image_repl( | |
|
|
||
| return PromptUpdateDetails.select_text(repl_full, IMG_CONTEXT) | ||
|
|
||
| @classmethod | ||
| def get_video_repl( | ||
| self, | ||
| feature_size: int, | ||
| num_patches: Optional[int] = None, | ||
| cls, | ||
| tokens_per_frame: list[int], | ||
| video_context_token: str = IMG_CONTEXT, | ||
| ) -> PromptUpdateDetails[str]: | ||
| repl_features = video_context_token * self.num_image_token | ||
| repl_features_with_sep = IMG_START + repl_features + IMG_END | ||
| # num_patches is equal to num_frames | ||
| """ | ||
| Build prompt replacement for a video. | ||
| The replacement returned is not actually used to replace the placeholder | ||
| tokens - it's just used to make sure we allocate the correct number | ||
| of tokens. | ||
| Actual replacement is done in get_multimodal_embeddings of | ||
| NemotronH_Nano_VL_V2 | ||
| (specifically in _process_video_input -> _create_final_video_embeddings). | ||
| There, we create the final embeddings with text embeddings for indicator tokens | ||
| and video embeddings for video tokens. | ||
| This is a single function that handles all cases - non EVS, EVS dummy, EVS real. | ||
| The differentiation is done via tokens_per_frame parameter. | ||
| - non EVS case - constant value same value across all frames | ||
| - EVS dummy - Doesn't matter how tokens are distributed between frames - just | ||
| make sure the total number of tokens is correct. | ||
| - EVS real (called from get_real_video_repl_for_evs) - different value per frame | ||
| Args: | ||
| tokens_per_frame (list[int]): number of tokens per frame | ||
| video_context_token (str): the token to use for the video context | ||
| """ | ||
| repl_full = "".join( | ||
| [f"Frame{i + 1}: {repl_features_with_sep}" for i in range(num_patches)] | ||
| [ | ||
| f"Frame{i + 1}: {IMG_START}{video_context_token * num_tokens}{IMG_END}" | ||
| for i, num_tokens in enumerate(tokens_per_frame) | ||
| ] | ||
| ) | ||
|
|
||
| return PromptUpdateDetails.select_text(repl_full, video_context_token) | ||
| return PromptUpdateDetails.select_text(repl_full, repl_full) | ||
|
|
||
|
|
||
| class BaseNanoNemotronVLProcessingInfo(BaseProcessingInfo): | ||
|
|
@@ -605,6 +664,9 @@ def get_supported_mm_limits(self): | |
| def get_video_token(self) -> Optional[str]: | ||
| return IMG_CONTEXT | ||
|
|
||
| def get_video_pruning_rate(self) -> Optional[float]: | ||
| return self.ctx.get_mm_config().video_pruning_rate | ||
|
|
||
| def get_num_frames_with_most_features( | ||
| self, | ||
| seq_len: int, | ||
|
|
@@ -628,6 +690,7 @@ def get_hf_processor(self, **kwargs: object) -> NanoNemotronVLProcessor: | |
| config=self.get_hf_config(), | ||
| tokenizer=self.get_tokenizer(), | ||
| video_token=self.get_video_token(), | ||
| video_pruning_rate=self.get_video_pruning_rate(), | ||
| **kwargs, | ||
| ) | ||
|
|
||
|
|
@@ -805,8 +868,26 @@ def get_video_replacement_internvl(item_idx: int): | |
| if num_patches is not None: | ||
| assert isinstance(num_patches, int) | ||
|
|
||
| video_pruning_rate = self.info.ctx.get_mm_config().video_pruning_rate | ||
| if video_pruning_rate is not None and video_pruning_rate > 0.0: | ||
| # Start of EVS-specific code | ||
| num_tokens = compute_retained_tokens_count( | ||
| tokens_per_frame=feature_size, | ||
| num_frames=num_patches, | ||
| q=video_pruning_rate, | ||
| ) | ||
| # Here we just need placeholders that won't actually be replaced - | ||
| # we just need to make sure the total number of tokens is correct | ||
| # assign all tokens to the first frame | ||
| tokens_per_frame = [num_tokens] + [0] * (num_patches - 1) | ||
|
|
||
| # End of EVS-specific code | ||
| else: | ||
| tokens_per_frame = [feature_size] * num_patches | ||
|
|
||
| return hf_processor.get_video_repl( | ||
| feature_size, num_patches, video_context_token=hf_processor.video_token | ||
| tokens_per_frame, | ||
| video_context_token=hf_processor.video_token, | ||
| ) | ||
|
|
||
| if self.info.supports_video: | ||
|
|
@@ -913,7 +994,7 @@ def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: | |
| def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): | ||
| super().__init__() | ||
| config = vllm_config.model_config.hf_config | ||
|
|
||
| multimodal_config = vllm_config.model_config.multimodal_config | ||
| image_size = config.force_image_size | ||
| patch_size = config.patch_size | ||
| self.patch_size = patch_size | ||
|
|
@@ -924,7 +1005,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): | |
| self.downsample_ratio = config.downsample_ratio | ||
| self.ps_version = config.ps_version | ||
| self.image_tag_type = config.image_tag_type | ||
|
|
||
| self.video_pruning_rate = multimodal_config.video_pruning_rate | ||
| self.language_model = init_vllm_registered_model( | ||
| vllm_config=vllm_config, | ||
| hf_config=config.text_config, | ||
|
|
@@ -957,6 +1038,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): | |
| self.img_context_token_id = None | ||
| self.video_context_token_id = None | ||
| self.config = config | ||
| self.model_config = vllm_config.model_config | ||
|
|
||
| def pixel_shuffle(self, x, scale_factor=0.5): | ||
| n, w, h, c = x.size() | ||
|
|
@@ -1049,7 +1131,7 @@ def _parse_and_validate_image_input( | |
|
|
||
| def _process_image_input( | ||
| self, image_input: NanoNemotronVLImageInputs | ||
| ) -> torch.Tensor: | ||
| ) -> tuple[torch.Tensor, ...]: | ||
| if image_input["type"] == "image_embeds": | ||
| return image_input["data"] | ||
|
|
||
|
|
@@ -1071,6 +1153,109 @@ def _process_image_input( | |
| ] | ||
| return image_embeds.split(image_feature_sizes) | ||
|
|
||
| def _process_video_input( | ||
| self, video_input: NanoNemotronVLVideoPixelInputs | ||
| ) -> tuple[torch.Tensor, ...]: | ||
| """Process video input and create final embeddings with video content | ||
| and indicator tokens.""" | ||
| # Get video embeddings using the same processing as images | ||
| video_embeddings = self._process_image_input(video_input) | ||
|
|
||
| final_video_embeddings: tuple[torch.Tensor, ...] = () | ||
|
|
||
| image_rows = image_cols = self.config.force_image_size | ||
| downsample_ratio = self.config.downsample_ratio | ||
| patch_size = self.config.patch_size | ||
| rows = int(image_rows * downsample_ratio // patch_size) | ||
| cols = int(image_cols * downsample_ratio // patch_size) | ||
| video_pruning_rate = self.video_pruning_rate | ||
|
|
||
| # Calculate video feature dimensions (number of frames and | ||
| # their feature size (AKA tokens per frame)) | ||
| # TODO: Maybe this can be optimized to avoid the loop? | ||
| for i, single_video_embeddings in enumerate(video_embeddings): | ||
| num_frames = video_input["num_patches"][i].item() | ||
| assert single_video_embeddings.shape[0] % num_frames == 0 | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Suggestion: use an explicit check instead of the `assert`, e.g. if single_video_embeddings.shape[0] % num_frames != 0:
raise ValueError(
f"The number of video embeddings ({single_video_embeddings.shape[0]}) "
f"is not divisible by the number of frames ({num_frames})."
) |
||
|
|
||
| if video_pruning_rate is not None and video_pruning_rate > 0.0: | ||
| # Start of EVS-specific code | ||
| retention_mask = compute_retention_mask( | ||
| single_video_embeddings, | ||
| video_size_thw=torch.tensor([num_frames, rows, cols]), | ||
| spatial_merge_size=1, | ||
| q=video_pruning_rate, | ||
| ) | ||
|
|
||
| # apply retention mask | ||
| single_video_embeddings = single_video_embeddings[retention_mask] | ||
|
|
||
| # calculate the actual number of retained tokens per frame | ||
| retention_mask_thw = retention_mask.reshape(num_frames, rows, cols) | ||
| num_tokens_per_frame = ( | ||
| retention_mask_thw.sum(dim=(1, 2)).long().tolist() | ||
| ) | ||
| # End of EVS-specific code | ||
| else: | ||
| feature_size = single_video_embeddings.shape[0] // num_frames | ||
| num_tokens_per_frame = [feature_size] * num_frames | ||
|
|
||
| final_video_embeddings += ( | ||
| self._create_final_video_embeddings( | ||
| single_video_embeddings, | ||
| num_tokens_per_frame, | ||
| ), | ||
| ) | ||
|
|
||
| return final_video_embeddings | ||
|
|
||
| def _create_final_video_embeddings( | ||
| self, | ||
| video_embeddings: torch.Tensor, | ||
| num_tokens_per_frame: list[int], | ||
| ) -> torch.Tensor: | ||
| """Create final embeddings that combine video embeddings with | ||
| text embeddings of indicator tokens. | ||
|
|
||
| These final embeddings contain: | ||
| - Actual video embeddings in positions corresponding to video content | ||
| - Text embeddings for indicator tokens (<img>, </img>, and | ||
| frame separation text) in their respective positions | ||
|
|
||
| These embeddings will replace the placeholder embeddings to create | ||
| input_embeds for the LLM. | ||
| """ | ||
| device = video_embeddings.device | ||
|
|
||
| # Generate video replacement text and convert to token IDs | ||
| video_repl_text = NanoNemotronVLProcessor.get_video_repl( | ||
| num_tokens_per_frame, | ||
| IMG_CONTEXT, | ||
| ).full | ||
|
|
||
| tokenizer = cached_tokenizer_from_config(self.model_config) | ||
| repl_token_ids = torch.tensor( | ||
| _seq2tokens(tokenizer, video_repl_text), device=device | ||
| ) | ||
|
|
||
| # Get embedding token IDs for image context | ||
| embed_token_ids = torch.tensor( | ||
| encode_tokens(tokenizer, IMG_CONTEXT), device=device | ||
| ) | ||
|
|
||
| # Create mask for video embedding positions | ||
| is_video_embed = torch.isin(repl_token_ids, embed_token_ids) | ||
|
|
||
| # Create final video embeddings, merging text embeddings for indicator | ||
| # tokens with video embeddings | ||
| text_embeddings = self.get_language_model().get_input_embeddings(repl_token_ids) | ||
| final_video_embeddings = _merge_multimodal_embeddings( | ||
| inputs_embeds=text_embeddings, | ||
| multimodal_embeddings=video_embeddings, | ||
| is_multimodal=is_video_embed, | ||
| ) | ||
|
|
||
| return final_video_embeddings | ||
|
|
||
| def _parse_and_validate_video_input( | ||
| self, **kwargs: object | ||
| ) -> Optional[NanoNemotronVLVideoPixelInputs]: | ||
|
|
@@ -1152,7 +1337,7 @@ def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings: | |
| multimodal_embeddings += vision_embeddings | ||
| if modality == "videos": | ||
| video_input = modalities["videos"] | ||
| video_embeddings = self._process_image_input(video_input) | ||
| video_embeddings = self._process_video_input(video_input) | ||
| multimodal_embeddings += video_embeddings | ||
|
|
||
| return multimodal_embeddings | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Inside
`_preprocess_video` the variable `tokens_per_frame` is initialized once outside the loop as an integer per frame, but the loop overwrites it with a list (`[num_tokens] + [0] * …` or `[tokens_per_frame] * num_frames`). On the next iteration the same variable is passed back into `compute_retained_tokens_count`, which now receives a list instead of an `int` and will raise or generate malformed placeholders whenever more than one video is present. Multi-video batches with EVS enabled will crash, and those without EVS will produce nested lists and incorrect token counts. Useful? React with 👍 / 👎.