diff --git a/tests/diffusion/models/qwen_image/test_qwen_image_max_sequence_length.py b/tests/diffusion/models/qwen_image/test_qwen_image_max_sequence_length.py new file mode 100644 index 00000000000..f5676a0056f --- /dev/null +++ b/tests/diffusion/models/qwen_image/test_qwen_image_max_sequence_length.py @@ -0,0 +1,260 @@ +import inspect +from types import SimpleNamespace + +import pytest +import torch +from torch import nn + +from vllm_omni.diffusion.models.qwen_image.pipeline_qwen_image import ( + QwenImagePipeline, +) +from vllm_omni.diffusion.models.qwen_image.pipeline_qwen_image_edit import ( + QwenImageEditPipeline, +) +from vllm_omni.diffusion.models.qwen_image.pipeline_qwen_image_edit_plus import ( + QwenImageEditPlusPipeline, +) +from vllm_omni.diffusion.models.qwen_image.pipeline_qwen_image_layered import ( + QwenImageLayeredPipeline, +) + +pytestmark = [pytest.mark.core_model, pytest.mark.cpu] + + +class _RejectingTextEncoder: + dtype = torch.float32 + + def __call__(self, *args, **kwargs): + raise AssertionError("text encoder should not run for prompts that exceed max_sequence_length") + + +class _FakeModelInputs: + def __init__(self, total_sequence_length: int): + attention_mask = torch.ones((1, total_sequence_length), dtype=torch.long) + self.input_ids = attention_mask.clone() + self.attention_mask = attention_mask + self.pixel_values = None + self.image_grid_thw = None + + def to(self, device): + return self + + +class _FakeTokenizer: + def __init__(self, total_sequence_length: int | list[int]): + if isinstance(total_sequence_length, list): + self.total_sequence_lengths = list(total_sequence_length) + else: + self.total_sequence_lengths = [total_sequence_length] + + def __call__(self, *args, **kwargs): + if len(self.total_sequence_lengths) > 1: + total_sequence_length = self.total_sequence_lengths.pop(0) + else: + total_sequence_length = self.total_sequence_lengths[0] + return _FakeModelInputs(total_sequence_length) + + +class _FakeProcessor(_FakeTokenizer): + pass + + +class _FakeScheduler: + def __init__(self): + self.begin_index = None + + def set_begin_index(self, begin_index: int): + self.begin_index = begin_index + + +PIPELINE_CASES = [ + pytest.param(QwenImagePipeline, 34, "tokenizer", id="qwen-image"), + pytest.param(QwenImageLayeredPipeline, 34, "tokenizer", id="qwen-image-layered"), + pytest.param(QwenImageEditPipeline, 64, "processor", id="qwen-image-edit"), + pytest.param(QwenImageEditPlusPipeline, 64, "processor", id="qwen-image-edit-plus"), +] + + +def _make_pipeline( + pipeline_class: type, + *, + total_sequence_length: int, + drop_idx: int, + input_kind: str, +): + pipeline = object.__new__(pipeline_class) + nn.Module.__init__(pipeline) + pipeline.device = torch.device("cpu") + pipeline.text_encoder = _RejectingTextEncoder() + pipeline.tokenizer_max_length = 1024 + pipeline.prompt_template_encode = "{}" + pipeline.prompt_template_encode_start_idx = drop_idx + pipeline.tokenizer = _FakeTokenizer([total_sequence_length, 0]) + if input_kind == "processor": + pipeline.processor = _FakeProcessor(total_sequence_length) + return pipeline + + +@pytest.mark.parametrize(("pipeline_class", "drop_idx", "input_kind"), PIPELINE_CASES) +def test_encode_prompt_rejects_prompt_longer_than_default_max_sequence_length( + pipeline_class: type, + drop_idx: int, + input_kind: str, +): + pipeline = _make_pipeline( + pipeline_class, + total_sequence_length=1025, + drop_idx=drop_idx, + input_kind=input_kind, + ) + + with pytest.raises(ValueError, match=r"got 1025 tokens, but `max_sequence_length` is 1024"): + pipeline.encode_prompt(prompt="prompt") + + +@pytest.mark.parametrize(("pipeline_class", "drop_idx", "input_kind"), PIPELINE_CASES) +def test_encode_prompt_rejects_prompt_longer_than_explicit_max_sequence_length( + pipeline_class: type, + drop_idx: int, + input_kind: str, +): + pipeline = _make_pipeline( + pipeline_class, + total_sequence_length=17, + drop_idx=drop_idx, + input_kind=input_kind, + ) + + with pytest.raises(ValueError, match=r"got 17 tokens, but `max_sequence_length` is 16"): + pipeline.encode_prompt(prompt="prompt", max_sequence_length=16) + + +def test_prepare_encode_defaults_to_tokenizer_max_length(): + pipeline = object.__new__(QwenImagePipeline) + nn.Module.__init__(pipeline) + pipeline.tokenizer_max_length = 1024 + pipeline.vae_scale_factor = 8 + pipeline.default_sample_size = 128 + pipeline.scheduler = _FakeScheduler() + pipeline._extract_prompts = lambda prompts: (["prompt"], None) + + captured = {} + + def _fake_prepare_generation_context(**kwargs): + captured["max_sequence_length"] = kwargs["max_sequence_length"] + embeds = torch.ones((1, 1, 1)) + mask = torch.ones((1, 1), dtype=torch.long) + return { + "prompt_embeds": embeds, + "prompt_embeds_mask": mask, + "negative_prompt_embeds": None, + "negative_prompt_embeds_mask": None, + "latents": embeds, + "timesteps": torch.tensor([1]), + "do_true_cfg": False, + "guidance": None, + "img_shapes": [[(1, 1, 1)]], + "txt_seq_lens": [1], + "negative_txt_seq_lens": None, + } + + pipeline._prepare_generation_context = _fake_prepare_generation_context + state = SimpleNamespace( + prompts=["prompt"], + sampling=SimpleNamespace( + height=None, + width=None, + num_inference_steps=None, + sigmas=None, + guidance_scale_provided=False, + num_outputs_per_prompt=0, + generator=None, + true_cfg_scale=None, + max_sequence_length=None, + ), + ) + + pipeline.prepare_encode(state) + + assert captured["max_sequence_length"] == 1024 + + +@pytest.mark.parametrize( + ("pipeline_class", "drop_idx"), + [ + pytest.param(QwenImageEditPipeline, 64, id="qwen-image-edit"), + pytest.param(QwenImageEditPlusPipeline, 64, id="qwen-image-edit-plus"), + ], +) +def test_edit_pipelines_validate_text_prompt_length_before_image_token_expansion( + pipeline_class: type, + drop_idx: int, +): + pipeline = object.__new__(pipeline_class) + nn.Module.__init__(pipeline) + pipeline.device = torch.device("cpu") + pipeline.text_encoder = _RejectingTextEncoder() + pipeline.tokenizer_max_length = 1024 + pipeline.prompt_template_encode = "{}" + pipeline.prompt_template_encode_start_idx = drop_idx + pipeline.tokenizer = _FakeTokenizer([8, 0]) + pipeline.processor = _FakeProcessor(drop_idx + 1500) + + with pytest.raises(AssertionError, match="text encoder should not run"): + pipeline.encode_prompt(prompt="short prompt") + + +@pytest.mark.parametrize( + "pipeline_class", + [ + pytest.param(QwenImagePipeline, id="qwen-image"), + pytest.param(QwenImageLayeredPipeline, id="qwen-image-layered"), + ], +) +def test_qwen_generation_validator_excludes_template_suffix_from_budget(pipeline_class: type): + pipeline = object.__new__(pipeline_class) + nn.Module.__init__(pipeline) + pipeline.device = torch.device("cpu") + pipeline.text_encoder = _RejectingTextEncoder() + pipeline.tokenizer_max_length = 1024 + pipeline.prompt_template_encode = "{}" + pipeline.prompt_template_encode_start_idx = 34 + pipeline.tokenizer = _FakeTokenizer([1029, 5]) + + with pytest.raises(AssertionError, match="text encoder should not run"): + pipeline.encode_prompt(prompt="boundary prompt") + + +@pytest.mark.parametrize( + "pipeline_class", + [ + pytest.param(QwenImageEditPipeline, id="qwen-image-edit"), + pytest.param(QwenImageEditPlusPipeline, id="qwen-image-edit-plus"), + ], +) +def test_qwen_edit_validator_excludes_image_placeholders_from_budget(pipeline_class: type): + pipeline = object.__new__(pipeline_class) + nn.Module.__init__(pipeline) + pipeline.device = torch.device("cpu") + pipeline.text_encoder = _RejectingTextEncoder() + pipeline.tokenizer_max_length = 1024 + pipeline.prompt_template_encode = "{}" + pipeline.prompt_template_encode_start_idx = 64 + pipeline.tokenizer = _FakeTokenizer([30, 20]) + pipeline.processor = _FakeProcessor(1500) + + with pytest.raises(AssertionError, match="text encoder should not run"): + pipeline.encode_prompt(prompt="short prompt") + + +@pytest.mark.parametrize( + "pipeline_class", + [ + QwenImagePipeline, + QwenImageLayeredPipeline, + QwenImageEditPipeline, + QwenImageEditPlusPipeline, + ], +) +def test_forward_max_sequence_length_default_is_1024(pipeline_class: type): + assert inspect.signature(pipeline_class.forward).parameters["max_sequence_length"].default == 1024 diff --git a/tests/diffusion/models/wan2_2/test_wan22_max_sequence_length.py b/tests/diffusion/models/wan2_2/test_wan22_max_sequence_length.py new file mode 100644 index 00000000000..64c2b271c9c --- /dev/null +++ b/tests/diffusion/models/wan2_2/test_wan22_max_sequence_length.py @@ -0,0 +1,150 @@ +from types import SimpleNamespace + +import PIL.Image +import pytest +import torch +from torch import nn + +from vllm_omni.diffusion.models.wan2_2.pipeline_wan2_2 import ( + WAN22_MAX_SEQUENCE_LENGTH, + Wan22Pipeline, +) +from vllm_omni.diffusion.models.wan2_2.pipeline_wan2_2_i2v import ( + Wan22I2VPipeline, +) +from vllm_omni.diffusion.models.wan2_2.pipeline_wan2_2_ti2v import ( + Wan22TI2VPipeline, +) +from vllm_omni.diffusion.models.wan2_2.pipeline_wan2_2_vace import ( + Wan22VACEPipeline, +) + +pytestmark = [pytest.mark.core_model, pytest.mark.cpu] + + +class _RejectingTextEncoder: + dtype = torch.float32 + + def __call__(self, *args, **kwargs): + raise AssertionError("text encoder should not run for prompts that exceed max_sequence_length") + + +class _FakeTokenBatch: + def __init__(self, total_sequence_length: int): + attention_mask = torch.ones((1, total_sequence_length), dtype=torch.long) + self.input_ids = attention_mask.clone() + self.attention_mask = attention_mask + + +class _FakeTokenizer: + def __init__(self, total_sequence_length: int): + self.total_sequence_length = total_sequence_length + + def __call__(self, *args, **kwargs): + return _FakeTokenBatch(self.total_sequence_length) + + +PIPELINE_CASES = [ + pytest.param(Wan22Pipeline, id="wan22-t2v"), + pytest.param(Wan22I2VPipeline, id="wan22-i2v"), + pytest.param(Wan22TI2VPipeline, id="wan22-ti2v"), + pytest.param(Wan22VACEPipeline, id="wan22-vace"), +] + + +def _make_pipeline(pipeline_class: type, *, total_sequence_length: int): + pipeline = object.__new__(pipeline_class) + nn.Module.__init__(pipeline) + pipeline.device = torch.device("cpu") + pipeline.text_encoder = _RejectingTextEncoder() + pipeline.tokenizer = _FakeTokenizer(total_sequence_length) + pipeline.tokenizer_max_length = WAN22_MAX_SEQUENCE_LENGTH + return pipeline + + +@pytest.mark.parametrize("pipeline_class", PIPELINE_CASES) +def test_encode_prompt_rejects_prompt_longer_than_default_max_sequence_length(pipeline_class: type): + pipeline = _make_pipeline(pipeline_class, total_sequence_length=WAN22_MAX_SEQUENCE_LENGTH + 1) + + with pytest.raises(ValueError, match=r"got 513 tokens, but `max_sequence_length` is 512"): + pipeline.encode_prompt(prompt="prompt") + + +@pytest.mark.parametrize("pipeline_class", PIPELINE_CASES) +def test_encode_prompt_rejects_prompt_longer_than_explicit_max_sequence_length(pipeline_class: type): + pipeline = _make_pipeline(pipeline_class, total_sequence_length=17) + + with pytest.raises(ValueError, match=r"got 17 tokens, but `max_sequence_length` is 16"): + pipeline.encode_prompt(prompt="prompt", max_sequence_length=16) + + +def _sampling_params(**overrides): + defaults = dict( + height=None, + width=None, + num_frames=None, + num_inference_steps=None, + generator=None, + guidance_scale_provided=False, + guidance_scale_2=None, + boundary_ratio=None, + num_outputs_per_prompt=0, + max_sequence_length=None, + seed=None, + extra_args={}, + prompt_embeds=None, + negative_prompt_embeds=None, + ) + defaults.update(overrides) + return SimpleNamespace(**defaults) + + +@pytest.mark.parametrize( + ("pipeline_class", "prompt_value", "forward_kwargs"), + [ + pytest.param(Wan22Pipeline, "prompt", {}, id="wan22-t2v"), + pytest.param( + Wan22I2VPipeline, + {"prompt": "prompt", "multi_modal_data": {"image": PIL.Image.new("RGB", (64, 64))}}, + {"image": PIL.Image.new("RGB", (64, 64))}, + id="wan22-i2v", + ), + pytest.param( + Wan22TI2VPipeline, + {"prompt": "prompt", "multi_modal_data": {"image": PIL.Image.new("RGB", (64, 64))}}, + {"image": PIL.Image.new("RGB", (64, 64))}, + id="wan22-ti2v", + ), + pytest.param(Wan22VACEPipeline, "prompt", {}, id="wan22-vace"), + ], +) +def test_forward_defaults_to_wan22_tokenizer_max_length( + pipeline_class: type, + prompt_value, + forward_kwargs, +): + pipeline = object.__new__(pipeline_class) + nn.Module.__init__(pipeline) + pipeline.tokenizer_max_length = WAN22_MAX_SEQUENCE_LENGTH + pipeline.boundary_ratio = None + pipeline.vae_scale_factor_temporal = 4 + pipeline.vae_scale_factor_spatial = 8 + pipeline.transformer_config = SimpleNamespace(patch_size=(1, 2, 2)) + + captured = {} + + def _fake_check_inputs(*args, **kwargs): + captured["max_sequence_length"] = kwargs["max_sequence_length"] + raise RuntimeError("stop after capture") + + pipeline.check_inputs = _fake_check_inputs + + req = SimpleNamespace( + prompts=[prompt_value], + sampling_params=_sampling_params(), + ) + + with pytest.raises(RuntimeError, match="stop after capture"): + pipeline.forward(req, **forward_kwargs) + + assert captured["max_sequence_length"] == WAN22_MAX_SEQUENCE_LENGTH diff --git a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image.py b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image.py index 9f75c84538e..9ef0cacd5a0 100644 --- a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image.py +++ b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image.py @@ -34,6 +34,9 @@ ) from vllm_omni.diffusion.profiler.diffusion_pipeline_profiler import DiffusionPipelineProfilerMixin from vllm_omni.diffusion.request import OmniDiffusionRequest +from vllm_omni.diffusion.utils.prompt_utils import ( + validate_prompt_sequence_lengths, +) from vllm_omni.diffusion.utils.size_utils import ( normalize_min_aligned_size, ) @@ -363,8 +366,10 @@ def check_inputs( "that was used to generate `negative_prompt_embeds`." ) - if max_sequence_length is not None and max_sequence_length > 1024: - raise ValueError(f"`max_sequence_length` cannot be greater than 1024 but is {max_sequence_length}") + if max_sequence_length is not None and max_sequence_length > self.tokenizer_max_length: + raise ValueError( + f"`max_sequence_length` cannot be greater than {self.tokenizer_max_length} but is {max_sequence_length}" + ) def _extract_masked_hidden(self, hidden_states: torch.Tensor, mask: torch.Tensor): bool_mask = mask.bool() @@ -378,6 +383,8 @@ def _get_qwen_prompt_embeds( self, prompt: str | list[str] = None, dtype: torch.dtype | None = None, + max_sequence_length: int | None = None, + prompt_name: str = "prompt", ): dtype = dtype or self.text_encoder.dtype @@ -388,12 +395,27 @@ def _get_qwen_prompt_embeds( txt = [template.format(e) for e in prompt] txt_tokens = self.tokenizer( txt, - max_length=self.tokenizer_max_length + drop_idx, padding=True, - truncation=True, + truncation=False, + return_tensors="pt", + ).to(self.device) + # Validate only the user prompt contribution. The Qwen template also + # adds a fixed suffix after the user text, so subtracting only + # prompt_template_encode_start_idx would overcount near-limit prompts. + template_tokens = self.tokenizer( + [template.format("")], + padding=True, + truncation=False, return_tensors="pt", ).to(self.device) - # print(f"attention mask: {txt_tokens.attention_mask}") + validate_prompt_sequence_lengths( + txt_tokens.attention_mask, + max_sequence_length=max_sequence_length or self.tokenizer_max_length, + supported_max_sequence_length=self.tokenizer_max_length, + prompt_name=prompt_name, + baseline_attention_mask=template_tokens.attention_mask, + error_context="after applying the Qwen prompt template", + ) encoder_hidden_states = self.text_encoder( input_ids=txt_tokens.input_ids, attention_mask=txt_tokens.attention_mask, @@ -422,6 +444,7 @@ def encode_prompt( prompt_embeds: torch.Tensor | None = None, prompt_embeds_mask: torch.Tensor | None = None, max_sequence_length: int = 1024, + prompt_name: str = "prompt", ): r""" @@ -439,7 +462,11 @@ def encode_prompt( batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0] if prompt_embeds is None: - prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt) + prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds( + prompt, + max_sequence_length=max_sequence_length, + prompt_name=prompt_name, + ) prompt_embeds = prompt_embeds[:, :max_sequence_length] prompt_embeds_mask = prompt_embeds_mask[:, :max_sequence_length] @@ -632,6 +659,7 @@ def _prepare_generation_context( prompt_embeds_mask=negative_prompt_embeds_mask, num_images_per_prompt=num_images_per_prompt, max_sequence_length=max_sequence_length, + prompt_name="negative_prompt", ) else: negative_prompt_embeds = None @@ -703,7 +731,7 @@ def prepare_encode( num_images_per_prompt=sampling.num_outputs_per_prompt if sampling.num_outputs_per_prompt > 0 else 1, generator=sampling.generator, true_cfg_scale=sampling.true_cfg_scale or 4.0, - max_sequence_length=sampling.max_sequence_length or 512, + max_sequence_length=sampling.max_sequence_length or self.tokenizer_max_length, attention_kwargs=kwargs.get("attention_kwargs"), ) @@ -934,7 +962,7 @@ def forward( output_type: str | None = "pil", attention_kwargs: dict[str, Any] | None = None, callback_on_step_end_tensor_inputs: list[str] = ["latents"], - max_sequence_length: int = 512, + max_sequence_length: int = 1024, ) -> DiffusionOutput: extracted_prompt, negative_prompt = self._extract_prompts(req.prompts) prompt = extracted_prompt or prompt diff --git a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit.py b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit.py index dd77d71b1ea..cef7fe473a8 100644 --- a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit.py +++ b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit.py @@ -37,6 +37,9 @@ ) from vllm_omni.diffusion.profiler.diffusion_pipeline_profiler import DiffusionPipelineProfilerMixin from vllm_omni.diffusion.request import OmniDiffusionRequest +from vllm_omni.diffusion.utils.prompt_utils import ( + validate_prompt_sequence_lengths, +) from vllm_omni.diffusion.utils.size_utils import ( normalize_min_aligned_size, ) @@ -323,8 +326,10 @@ def check_inputs( "that was used to generate `negative_prompt_embeds`." ) - if max_sequence_length is not None and max_sequence_length > 1024: - raise ValueError(f"`max_sequence_length` cannot be greater than 1024 but is {max_sequence_length}") + if max_sequence_length is not None and max_sequence_length > self.tokenizer_max_length: + raise ValueError( + f"`max_sequence_length` cannot be greater than {self.tokenizer_max_length} but is {max_sequence_length}" + ) def _extract_masked_hidden(self, hidden_states: torch.Tensor, mask: torch.Tensor): bool_mask = mask.bool() @@ -384,6 +389,8 @@ def _get_qwen_prompt_embeds( prompt: str | list[str] = None, image: PIL.Image.Image | torch.Tensor | None = None, dtype: torch.dtype | None = None, + max_sequence_length: int | None = None, + prompt_name: str = "prompt", ): """Get prompt embeddings with image support for editing.""" dtype = dtype or self.text_encoder.dtype @@ -393,6 +400,33 @@ def _get_qwen_prompt_embeds( template = self.prompt_template_encode drop_idx = self.prompt_template_encode_start_idx txt = [template.format(e) for e in prompt] + txt_tokens = self.tokenizer( + txt, + padding=True, + truncation=False, + return_tensors="pt", + ).to(self.device) + # The edit template contains fixed multimodal scaffolding around the + # instruction. Validate against the empty-template baseline so image + # placeholder text does not consume the user's text budget. + template_tokens = self.tokenizer( + [template.format("")], + padding=True, + truncation=False, + return_tensors="pt", + ).to(self.device) + # Qwen-Image-Edit expands image placeholders into many vision tokens + # inside the processor. `max_sequence_length` is meant to constrain the + # prompt text length, so validate on the text template before image + # token expansion. + validate_prompt_sequence_lengths( + txt_tokens.attention_mask, + max_sequence_length=max_sequence_length or self.tokenizer_max_length, + supported_max_sequence_length=self.tokenizer_max_length, + prompt_name=prompt_name, + baseline_attention_mask=template_tokens.attention_mask, + error_context="after applying the Qwen prompt template", + ) # Use processor to handle both text and image inputs model_inputs = self.processor( @@ -434,6 +468,7 @@ def encode_prompt( prompt_embeds: torch.Tensor | None = None, prompt_embeds_mask: torch.Tensor | None = None, max_sequence_length: int = 1024, + prompt_name: str = "prompt", ): r""" @@ -453,7 +488,12 @@ def encode_prompt( batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0] if prompt_embeds is None: - prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, image) + prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds( + prompt, + image, + max_sequence_length=max_sequence_length, + prompt_name=prompt_name, + ) _, seq_len, _ = prompt_embeds.shape prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) @@ -624,7 +664,7 @@ def forward( output_type: str | None = "pil", attention_kwargs: dict[str, Any] | None = None, callback_on_step_end_tensor_inputs: list[str] = ["latents"], - max_sequence_length: int = 512, + max_sequence_length: int = 1024, ) -> DiffusionOutput: """Forward pass for image editing.""" # TODO: In online mode, sometimes it receives [{"negative_prompt": None}, {...}], so cannot use .get("...", "") @@ -739,6 +779,7 @@ def forward( prompt_embeds_mask=negative_prompt_embeds_mask, num_images_per_prompt=num_images_per_prompt, max_sequence_length=max_sequence_length, + prompt_name="negative_prompt", ) num_channels_latents = self.transformer.in_channels // 4 diff --git a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit_plus.py b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit_plus.py index 6f6c9d2ba38..a2702f3d295 100644 --- a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit_plus.py +++ b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit_plus.py @@ -40,6 +40,9 @@ ) from vllm_omni.diffusion.profiler.diffusion_pipeline_profiler import DiffusionPipelineProfilerMixin from vllm_omni.diffusion.request import OmniDiffusionRequest +from vllm_omni.diffusion.utils.prompt_utils import ( + validate_prompt_sequence_lengths, +) from vllm_omni.diffusion.utils.size_utils import ( normalize_min_aligned_size, ) @@ -283,8 +286,10 @@ def check_inputs( "that was used to generate `negative_prompt_embeds`." ) - if max_sequence_length is not None and max_sequence_length > 1024: - raise ValueError(f"`max_sequence_length` cannot be greater than 1024 but is {max_sequence_length}") + if max_sequence_length is not None and max_sequence_length > self.tokenizer_max_length: + raise ValueError( + f"`max_sequence_length` cannot be greater than {self.tokenizer_max_length} but is {max_sequence_length}" + ) def _extract_masked_hidden(self, hidden_states: torch.Tensor, mask: torch.Tensor): bool_mask = mask.bool() @@ -299,6 +304,8 @@ def _get_qwen_prompt_embeds( prompt: str | list[str], image: list[torch.Tensor] | torch.Tensor | None = None, dtype: torch.dtype | None = None, + max_sequence_length: int | None = None, + prompt_name: str = "prompt", ): """Get prompt embeddings with support for multiple images.""" dtype = dtype or self.text_encoder.dtype @@ -319,6 +326,32 @@ def _get_qwen_prompt_embeds( template = self.prompt_template_encode drop_idx = self.prompt_template_encode_start_idx txt = [template.format(base_img_prompt + e) for e in prompt] + txt_tokens = self.tokenizer( + txt, + padding=True, + truncation=False, + return_tensors="pt", + ).to(self.device) + # Multi-image edit prepends "Picture N" placeholders before the user + # instruction. Subtract the placeholder-aware baseline so attached + # images do not reduce the remaining prompt budget. + template_tokens = self.tokenizer( + [template.format(base_img_prompt)], + padding=True, + truncation=False, + return_tensors="pt", + ).to(self.device) + # The processor expands image placeholders into many vision tokens. + # `max_sequence_length` should guard the prompt text length before that + # multimodal expansion happens. + validate_prompt_sequence_lengths( + txt_tokens.attention_mask, + max_sequence_length=max_sequence_length or self.tokenizer_max_length, + supported_max_sequence_length=self.tokenizer_max_length, + prompt_name=prompt_name, + baseline_attention_mask=template_tokens.attention_mask, + error_context="after applying the Qwen prompt template", + ) # Use processor to handle both text and image inputs model_inputs = self.processor( @@ -360,6 +393,7 @@ def encode_prompt( prompt_embeds: torch.Tensor | None = None, prompt_embeds_mask: torch.Tensor | None = None, max_sequence_length: int = 1024, + prompt_name: str = "prompt", ): r""" @@ -379,7 +413,12 @@ def encode_prompt( batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0] if prompt_embeds is None: - prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, image) + prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds( + prompt, + image, + max_sequence_length=max_sequence_length, + prompt_name=prompt_name, + ) _, seq_len, _ = prompt_embeds.shape prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) @@ -557,7 +596,7 @@ def forward( output_type: str | None = "pil", attention_kwargs: dict[str, Any] | None = None, callback_on_step_end_tensor_inputs: list[str] = ["latents"], - max_sequence_length: int = 512, + max_sequence_length: int = 1024, ) -> DiffusionOutput: """Forward pass for image editing with support for multiple images.""" # TODO: In online mode, sometimes it receives [{"negative_prompt": None}, {...}], so cannot use .get("...", "") @@ -692,6 +731,7 @@ def forward( prompt_embeds_mask=negative_prompt_embeds_mask, num_images_per_prompt=num_images_per_prompt, max_sequence_length=max_sequence_length, + prompt_name="negative_prompt", ) num_channels_latents = self.transformer.in_channels // 4 diff --git a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_layered.py b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_layered.py index 38866d89c58..905ef5b4243 100644 --- a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_layered.py +++ b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_layered.py @@ -36,6 +36,9 @@ ) from vllm_omni.diffusion.profiler.diffusion_pipeline_profiler import DiffusionPipelineProfilerMixin from vllm_omni.diffusion.request import OmniDiffusionRequest +from vllm_omni.diffusion.utils.prompt_utils import ( + validate_prompt_sequence_lengths, +) from vllm_omni.diffusion.utils.size_utils import ( normalize_min_aligned_size, ) @@ -340,8 +343,10 @@ def check_inputs( "generate `negative_prompt_embeds`." ) - if max_sequence_length is not None and max_sequence_length > 1024: - raise ValueError(f"`max_sequence_length` cannot be greater than 1024 but is {max_sequence_length}") + if max_sequence_length is not None and max_sequence_length > self.tokenizer_max_length: + raise ValueError( + f"`max_sequence_length` cannot be greater than {self.tokenizer_max_length} but is {max_sequence_length}" + ) def _extract_masked_hidden(self, hidden_states: torch.Tensor, mask: torch.Tensor): bool_mask = mask.bool() @@ -356,6 +361,8 @@ def _get_qwen_prompt_embeds( prompt: str | list[str] | None = None, device: torch.device | None = None, dtype: torch.dtype | None = None, + max_sequence_length: int | None = None, + prompt_name: str = "prompt", ): device = device or self.device dtype = dtype or self.text_encoder.dtype @@ -368,8 +375,26 @@ def _get_qwen_prompt_embeds( txt_tokens = self.tokenizer( txt, padding=True, + truncation=False, + return_tensors="pt", + ).to(device) + # The layered template also appends fixed non-user tokens after the + # editable text, so use the empty-template tokenized baseline instead of + # counting everything after prompt_template_encode_start_idx. + template_tokens = self.tokenizer( + [template.format("")], + padding=True, + truncation=False, return_tensors="pt", ).to(device) + validate_prompt_sequence_lengths( + txt_tokens.attention_mask, + max_sequence_length=max_sequence_length or self.tokenizer_max_length, + supported_max_sequence_length=self.tokenizer_max_length, + prompt_name=prompt_name, + baseline_attention_mask=template_tokens.attention_mask, + error_context="after applying the Qwen prompt template", + ) encoder_hidden_states = self.text_encoder( input_ids=txt_tokens.input_ids, attention_mask=txt_tokens.attention_mask, @@ -399,6 +424,7 @@ def encode_prompt( prompt_embeds: torch.Tensor | None = None, prompt_embeds_mask: torch.Tensor | None = None, max_sequence_length: int = 1024, + prompt_name: str = "prompt", ): r""" @@ -419,7 +445,11 @@ def encode_prompt( batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0] if prompt_embeds is None: - prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt) + prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds( + prompt, + max_sequence_length=max_sequence_length, + prompt_name=prompt_name, + ) prompt_embeds = prompt_embeds[:, :max_sequence_length] prompt_embeds_mask = prompt_embeds_mask[:, :max_sequence_length] @@ -603,7 +633,7 @@ def forward( negative_prompt_embeds_mask: torch.Tensor | None = None, output_type: str | None = "pil", attention_kwargs: dict[str, Any] | None = None, - max_sequence_length: int = 512, + max_sequence_length: int = 1024, resolution: int = 640, cfg_normalize: bool = False, use_en_prompt: bool = False, @@ -736,6 +766,7 @@ def forward( device=self.device, num_images_per_prompt=num_images_per_prompt, max_sequence_length=max_sequence_length, + prompt_name="negative_prompt", ) # 4. Prepare latent variables diff --git a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py index a1b10439c85..e1249e889c7 100644 --- a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py +++ b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py @@ -29,12 +29,16 @@ from vllm_omni.diffusion.postprocess import interpolate_video_tensor from vllm_omni.diffusion.profiler.diffusion_pipeline_profiler import DiffusionPipelineProfilerMixin from vllm_omni.diffusion.request import OmniDiffusionRequest +from vllm_omni.diffusion.utils.prompt_utils import ( + validate_prompt_sequence_lengths, +) from vllm_omni.inputs.data import OmniTextPrompt from vllm_omni.platforms import current_omni_platform logger = logging.getLogger(__name__) DEBUG_PERF = False WAN_SAMPLE_SOLVER_CHOICES = {"unipc", "euler"} +WAN22_MAX_SEQUENCE_LENGTH = 512 def build_wan_scheduler(sample_solver: str, flow_shift: float) -> Any: @@ -289,6 +293,7 @@ def __init__( pass self.boundary_ratio = od_config.boundary_ratio + self.tokenizer_max_length = WAN22_MAX_SEQUENCE_LENGTH # Determine which transformers to load based on boundary_ratio # boundary_ratio=1.0: only load transformer_2 (low-noise stage only) @@ -464,6 +469,7 @@ def forward( negative_prompt_embeds=negative_prompt_embeds, guidance_scale_2=guidance_high if boundary_ratio is not None else None, boundary_ratio=boundary_ratio, + max_sequence_length=req.sampling_params.max_sequence_length or self.tokenizer_max_length, ) if num_frames % self.vae_scale_factor_temporal != 1: @@ -497,7 +503,7 @@ def forward( negative_prompt=negative_prompt, do_classifier_free_guidance=guidance_low > 1.0 or guidance_high > 1.0, num_videos_per_prompt=req.sampling_params.num_outputs_per_prompt or 1, - max_sequence_length=req.sampling_params.max_sequence_length or 512, + max_sequence_length=req.sampling_params.max_sequence_length or self.tokenizer_max_length, device=device, dtype=dtype, ) @@ -801,6 +807,20 @@ def encode_prompt( prompt = [prompt] if isinstance(prompt, str) else prompt prompt_clean = [self._prompt_clean(p) for p in prompt] batch_size = len(prompt_clean) + text_inputs_untruncated = self.tokenizer( + prompt_clean, + padding=True, + truncation=False, + add_special_tokens=True, + return_attention_mask=True, + return_tensors="pt", + ) + validate_prompt_sequence_lengths( + text_inputs_untruncated.attention_mask, + max_sequence_length=max_sequence_length, + supported_max_sequence_length=self.tokenizer_max_length, + error_context="for Wan2.2 text encoding", + ) text_inputs = self.tokenizer( prompt_clean, @@ -829,8 +849,24 @@ def encode_prompt( if do_classifier_free_guidance: negative_prompt = negative_prompt or "" negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + negative_prompt_clean = [self._prompt_clean(p) for p in negative_prompt] + neg_text_inputs_untruncated = self.tokenizer( + negative_prompt_clean, + padding=True, + truncation=False, + add_special_tokens=True, + return_attention_mask=True, + return_tensors="pt", + ) + validate_prompt_sequence_lengths( + neg_text_inputs_untruncated.attention_mask, + max_sequence_length=max_sequence_length, + supported_max_sequence_length=self.tokenizer_max_length, + prompt_name="negative_prompt", + error_context="for Wan2.2 text encoding", + ) neg_text_inputs = self.tokenizer( - [self._prompt_clean(p) for p in negative_prompt], + negative_prompt_clean, padding="max_length", max_length=max_sequence_length, truncation=True, @@ -902,6 +938,7 @@ def check_inputs( negative_prompt_embeds=None, guidance_scale_2=None, boundary_ratio=None, + max_sequence_length=None, ): if height % 16 != 0 or width % 16 != 0: raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.") @@ -928,5 +965,10 @@ def check_inputs( ): raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}") + if max_sequence_length is not None and max_sequence_length > self.tokenizer_max_length: + raise ValueError( + f"`max_sequence_length` cannot be greater than {self.tokenizer_max_length} but is {max_sequence_length}" + ) + if boundary_ratio is None and guidance_scale_2 is not None: raise ValueError("`guidance_scale_2` is only supported when `boundary_ratio` is set.") diff --git a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_i2v.py b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_i2v.py index ddc6e0bc2b9..ca042ca228b 100644 --- a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_i2v.py +++ b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_i2v.py @@ -25,6 +25,7 @@ from vllm_omni.diffusion.models.interface import SupportImageInput from vllm_omni.diffusion.models.progress_bar import ProgressBarMixin, _is_rank_zero from vllm_omni.diffusion.models.wan2_2.pipeline_wan2_2 import ( + WAN22_MAX_SEQUENCE_LENGTH, build_wan_scheduler, create_transformer_from_config, load_transformer_config, @@ -35,6 +36,9 @@ from vllm_omni.diffusion.postprocess import interpolate_video_tensor from vllm_omni.diffusion.profiler.diffusion_pipeline_profiler import DiffusionPipelineProfilerMixin from vllm_omni.diffusion.request import OmniDiffusionRequest +from vllm_omni.diffusion.utils.prompt_utils import ( + validate_prompt_sequence_lengths, +) from vllm_omni.inputs.data import OmniTextPrompt from vllm_omni.platforms import current_omni_platform @@ -213,6 +217,7 @@ def __init__( # Text encoder self.tokenizer = AutoTokenizer.from_pretrained(model, subfolder="tokenizer", local_files_only=local_files_only) + self.tokenizer_max_length = WAN22_MAX_SEQUENCE_LENGTH self.text_encoder = UMT5EncoderModel.from_pretrained( model, subfolder="text_encoder", torch_dtype=dtype, local_files_only=local_files_only ).to(self.device) @@ -392,6 +397,7 @@ def forward( image_embeds=image_embeds, guidance_scale_2=guidance_high if boundary_ratio is not None else None, boundary_ratio=boundary_ratio, + max_sequence_length=req.sampling_params.max_sequence_length or self.tokenizer_max_length, ) # Adjust num_frames to be compatible with VAE temporal scaling @@ -420,7 +426,7 @@ def forward( negative_prompt=negative_prompt, do_classifier_free_guidance=guidance_low > 1.0 or guidance_high > 1.0, num_videos_per_prompt=req.sampling_params.num_outputs_per_prompt or 1, - max_sequence_length=req.sampling_params.max_sequence_length or 512, + max_sequence_length=req.sampling_params.max_sequence_length or self.tokenizer_max_length, device=device, dtype=dtype, ) @@ -671,6 +677,20 @@ def encode_prompt( prompt = [prompt] if isinstance(prompt, str) else prompt prompt_clean = [self._prompt_clean(p) for p in prompt] batch_size = len(prompt_clean) + text_inputs_untruncated = self.tokenizer( + prompt_clean, + padding=True, + truncation=False, + add_special_tokens=True, + return_attention_mask=True, + return_tensors="pt", + ) + validate_prompt_sequence_lengths( + text_inputs_untruncated.attention_mask, + max_sequence_length=max_sequence_length, + supported_max_sequence_length=self.tokenizer_max_length, + error_context="for Wan2.2 text encoding", + ) text_inputs = self.tokenizer( prompt_clean, @@ -699,8 +719,24 @@ def encode_prompt( if do_classifier_free_guidance: negative_prompt = negative_prompt or "" negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + negative_prompt_clean = [self._prompt_clean(p) for p in negative_prompt] + neg_text_inputs_untruncated = self.tokenizer( + negative_prompt_clean, + padding=True, + truncation=False, + add_special_tokens=True, + return_attention_mask=True, + return_tensors="pt", + ) + validate_prompt_sequence_lengths( + neg_text_inputs_untruncated.attention_mask, + max_sequence_length=max_sequence_length, + supported_max_sequence_length=self.tokenizer_max_length, + prompt_name="negative_prompt", + error_context="for Wan2.2 text encoding", + ) neg_text_inputs = self.tokenizer( - [self._prompt_clean(p) for p in negative_prompt], + negative_prompt_clean, padding="max_length", max_length=max_sequence_length, truncation=True, @@ -842,6 +878,7 @@ def check_inputs( image_embeds=None, guidance_scale_2=None, boundary_ratio=None, + max_sequence_length=None, ): if image is None and image_embeds is None: raise ValueError("Provide either `image` or `image_embeds`. Cannot leave both undefined.") @@ -863,6 +900,11 @@ def check_inputs( if prompt is None and prompt_embeds is None: raise ValueError("Provide either `prompt` or `prompt_embeds`.") + if max_sequence_length is not None and max_sequence_length > self.tokenizer_max_length: + raise ValueError( + f"`max_sequence_length` cannot be greater than {self.tokenizer_max_length} but is {max_sequence_length}" + ) + if boundary_ratio is None and guidance_scale_2 is not None: raise ValueError("`guidance_scale_2` is only supported when `boundary_ratio` is set.") diff --git a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_ti2v.py b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_ti2v.py index 62df13cbdea..6cbd6d2d6bd 100644 --- a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_ti2v.py +++ b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_ti2v.py @@ -37,6 +37,7 @@ from vllm_omni.diffusion.models.interface import SupportImageInput from vllm_omni.diffusion.models.progress_bar import ProgressBarMixin from vllm_omni.diffusion.models.wan2_2.pipeline_wan2_2 import ( + WAN22_MAX_SEQUENCE_LENGTH, build_wan_scheduler, create_transformer_from_config, load_transformer_config, @@ -46,6 +47,9 @@ ) from vllm_omni.diffusion.postprocess import interpolate_video_tensor from vllm_omni.diffusion.request import OmniDiffusionRequest +from vllm_omni.diffusion.utils.prompt_utils import ( + validate_prompt_sequence_lengths, +) from vllm_omni.inputs.data import OmniTextPrompt from vllm_omni.platforms import current_omni_platform @@ -185,6 +189,7 @@ def __init__( # Text encoder self.tokenizer = AutoTokenizer.from_pretrained(model, subfolder="tokenizer", local_files_only=local_files_only) + self.tokenizer_max_length = WAN22_MAX_SEQUENCE_LENGTH self.text_encoder = UMT5EncoderModel.from_pretrained( model, subfolder="text_encoder", torch_dtype=dtype, local_files_only=local_files_only ).to(self.device) @@ -301,6 +306,7 @@ def forward( width=width, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, + max_sequence_length=req.sampling_params.max_sequence_length or self.tokenizer_max_length, ) # Adjust num_frames to be compatible with VAE temporal scaling @@ -324,7 +330,7 @@ def forward( negative_prompt=negative_prompt, do_classifier_free_guidance=guidance_scale > 1.0, num_videos_per_prompt=req.sampling_params.num_outputs_per_prompt or 1, - max_sequence_length=req.sampling_params.max_sequence_length or 512, + max_sequence_length=req.sampling_params.max_sequence_length or self.tokenizer_max_length, device=device, dtype=dtype, ) @@ -518,6 +524,20 @@ def encode_prompt( prompt = [prompt] if isinstance(prompt, str) else prompt prompt_clean = [self._prompt_clean(p) for p in prompt] batch_size = len(prompt_clean) + text_inputs_untruncated = self.tokenizer( + prompt_clean, + padding=True, + truncation=False, + add_special_tokens=True, + return_attention_mask=True, + return_tensors="pt", + ) + validate_prompt_sequence_lengths( + text_inputs_untruncated.attention_mask, + max_sequence_length=max_sequence_length, + supported_max_sequence_length=self.tokenizer_max_length, + error_context="for Wan2.2 text encoding", + ) text_inputs = self.tokenizer( prompt_clean, @@ -546,8 +566,24 @@ def encode_prompt( if do_classifier_free_guidance: negative_prompt = negative_prompt or "" negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + negative_prompt_clean = [self._prompt_clean(p) for p in negative_prompt] + neg_text_inputs_untruncated = self.tokenizer( + negative_prompt_clean, + padding=True, + truncation=False, + add_special_tokens=True, + return_attention_mask=True, + return_tensors="pt", + ) + validate_prompt_sequence_lengths( + neg_text_inputs_untruncated.attention_mask, + max_sequence_length=max_sequence_length, + supported_max_sequence_length=self.tokenizer_max_length, + prompt_name="negative_prompt", + error_context="for Wan2.2 text encoding", + ) neg_text_inputs = self.tokenizer( - [self._prompt_clean(p) for p in negative_prompt], + negative_prompt_clean, padding="max_length", max_length=max_sequence_length, truncation=True, @@ -672,6 +708,7 @@ def check_inputs( width, prompt_embeds=None, negative_prompt_embeds=None, + max_sequence_length=None, ): if height % 16 != 0 or width % 16 != 0: raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.") @@ -687,6 +724,11 @@ def check_inputs( if prompt is None and prompt_embeds is None: raise ValueError("Provide either `prompt` or `prompt_embeds`.") + if max_sequence_length is not None and max_sequence_length > self.tokenizer_max_length: + raise ValueError( + f"`max_sequence_length` cannot be greater than {self.tokenizer_max_length} but is {max_sequence_length}" + ) + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: """Load weights using AutoWeightsLoader for vLLM integration.""" loader = AutoWeightsLoader(self) diff --git a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_vace.py b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_vace.py index ea523363111..b661108cc6f 100644 --- a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_vace.py +++ b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_vace.py @@ -187,6 +187,7 @@ def check_inputs( video=None, mask=None, reference_images=None, + max_sequence_length=None, ): super().check_inputs( prompt=prompt, @@ -195,6 +196,7 @@ def check_inputs( width=width, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, + max_sequence_length=max_sequence_length, ) # VACE-specific: validate video/mask/reference_images consistency @@ -491,6 +493,7 @@ def forward( video=source_video, mask=source_mask, reference_images=reference_images, + max_sequence_length=req.sampling_params.max_sequence_length or self.tokenizer_max_length, ) device = self.device @@ -509,7 +512,7 @@ def forward( negative_prompt=negative_prompt, do_classifier_free_guidance=guidance_scale > 1.0, num_videos_per_prompt=req.sampling_params.num_outputs_per_prompt or 1, - max_sequence_length=req.sampling_params.max_sequence_length or 512, + max_sequence_length=req.sampling_params.max_sequence_length or self.tokenizer_max_length, device=device, dtype=dtype, ) diff --git a/vllm_omni/diffusion/utils/prompt_utils.py b/vllm_omni/diffusion/utils/prompt_utils.py new file mode 100644 index 00000000000..fc1769f4d54 --- /dev/null +++ b/vllm_omni/diffusion/utils/prompt_utils.py @@ -0,0 +1,38 @@ +import torch + + +def validate_prompt_sequence_lengths( + attention_mask: torch.Tensor, + *, + max_sequence_length: int, + supported_max_sequence_length: int, + prompt_name: str = "prompt", + length_offset: int = 0, + baseline_attention_mask: torch.Tensor | None = None, + error_context: str, +) -> None: + sequence_lengths = attention_mask.sum(dim=1) + if baseline_attention_mask is not None: + # Some callers need to validate only the user-controlled portion of a + # templated prompt. In those cases we subtract the fully-tokenized + # template baseline instead of only removing a fixed prefix length, + # because the template may also contribute a suffix or image markers. + baseline_lengths = baseline_attention_mask.sum(dim=1) + if baseline_lengths.shape[0] == 1 and sequence_lengths.shape[0] > 1: + baseline_lengths = baseline_lengths.expand(sequence_lengths.shape[0]) + sequence_lengths = sequence_lengths - baseline_lengths + if length_offset: + sequence_lengths = sequence_lengths - length_offset + sequence_lengths = torch.clamp(sequence_lengths, min=0) + too_long = torch.nonzero(sequence_lengths > max_sequence_length, as_tuple=False) + if too_long.numel() == 0: + return + + batch_idx = int(too_long[0].item()) + actual_length = int(sequence_lengths[batch_idx].item()) + prompt_ref = f"`{prompt_name}` at batch index {batch_idx}" if attention_mask.shape[0] > 1 else f"`{prompt_name}`" + raise ValueError( + f"{prompt_ref} is too long {error_context}: got {actual_length} tokens, but " + f"`max_sequence_length` is {max_sequence_length}. Shorten the prompt or increase " + f"`max_sequence_length` up to {supported_max_sequence_length}." + )