diff --git a/docs/source/en/model_doc/chameleon.md b/docs/source/en/model_doc/chameleon.md index 6cbbdf398274..524f8b4aaa25 100644 --- a/docs/source/en/model_doc/chameleon.md +++ b/docs/source/en/model_doc/chameleon.md @@ -50,6 +50,8 @@ The original code can be found [here](https://github.com/facebookresearch/chamel - We advise users to use `padding_side="left"` when computing batched generation as it leads to more accurate results. Simply make sure to set `processor.tokenizer.padding_side = "left"` before generating. +- When generating images, we advise users to load the model in `bfloat16` for better results. Simply make sure to set `torch_dtype=torch.bfloat16` when loading the model. + - Note that Chameleon was tuned for safety alignment. If the model is refusing to answer, consider asking a more concrete question, instead of an open question. - Chameleon generates in chat format which means that the generated text will always be the "assistant's turn". You can enable a text completion generation by passing `return_for_text_completion=True` when calling the processor. @@ -57,6 +59,9 @@ The original code can be found [here](https://github.com/facebookresearch/chamel > [!NOTE] > Chameleon implementation in Transformers uses a special image token to indicate where to merge image embeddings. For special image token we didn't add a new one but used one of the reserved tokens: ``. You have to add `` to your prompt in the place where the image should be embedded for correct generation. +> [!NOTE] +> The official model checkpoint currently only supports text generation. To generate images and interleaved text-image responses, you can use finetuned versions such as [Anole](https://arxiv.org/abs/2407.06135). Note however that Anole has a bias for "empty" or background patches, so it is recommended to use sampling when generating images (i.e. setting `do_sample=True` during generation) to reduce the likelihood of generating a blank image. 
+ ## Usage example ### Single image inference @@ -124,6 +129,142 @@ generate_ids = model.generate(**inputs, max_new_tokens=50) processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False) ``` +### Text to image generation + +Chameleon can also generate images. However, the official model checkpoint currently only supports text generation. We need to use finetuned versions such as [Anole](https://arxiv.org/abs/2407.06135) to do image generation. Here is how you can do it: + +```python +import torch +from transformers import ChameleonProcessor, ChameleonForConditionalGeneration + +processor = ChameleonProcessor.from_pretrained("leloy/Anole-7b-v0.1-hf") +model = ChameleonForConditionalGeneration.from_pretrained( + "leloy/Anole-7b-v0.1-hf", + device_map="auto", + torch_dtype=torch.bfloat16, +) + +# Prepare a prompt +prompt = "Generate an image of a snowman." + +# Preprocess the prompt +inputs = processor(prompt, padding=True, return_tensors="pt").to(model.device, dtype=model.dtype) + +# Generate discrete image tokens +generate_ids = model.generate( + **inputs, + multimodal_generation_mode="image-only", + # Note: We need to set `max_new_tokens` to 1026 since the model generates the `image_start_token` marker token first, then 1024 image tokens, and finally the `image_end_token` marker token. + max_new_tokens=1026, + # This is important because most of the image tokens during training were for "empty" patches, so greedy decoding of image tokens will likely result in a blank image. + do_sample=True, +) + +# Only keep the tokens from the response +response_ids = generate_ids[:, inputs["input_ids"].shape[-1]:] + +# Decode the generated image tokens +pixel_values = model.decode_image_tokens(response_ids[:, 1:-1]) +images = processor.postprocess_pixel_values(pixel_values) + +# Save the image +images[0].save("snowman.png") +``` + +### Text-image to image generation + +We can also interleave text and images in the prompt to generate images. 
Here is how you can do it: + +```python +import requests + +import torch +from PIL import Image +from transformers import ChameleonProcessor, ChameleonForConditionalGeneration +from transformers.image_transforms import to_pil_image + +processor = ChameleonProcessor.from_pretrained("leloy/Anole-7b-v0.1-hf") +model = ChameleonForConditionalGeneration.from_pretrained( + "leloy/Anole-7b-v0.1-hf", + device_map="auto", + torch_dtype=torch.bfloat16, +) + +# Get image of a snowman +url = "https://huggingface.co/microsoft/kosmos-2-patch14-224/resolve/main/snowman.jpg" +image_snowman = Image.open(requests.get(url, stream=True).raw) + +# Prepare a prompt +prompt = "Generate a variation of this image." + +# Preprocess the prompt +inputs = processor( + images=[image_snowman], + text=prompt, + padding=True, + return_tensors="pt", +).to(model.device, dtype=model.dtype) + +# Generate discrete image tokens +generate_ids = model.generate( + **inputs, + multimodal_generation_mode="image-only", + # This is important because most of the image tokens during training were for "empty" patches, so greedy decoding of image tokens will likely result in a blank image. + do_sample=True, +) + +# Only keep the tokens from the response +response_ids = generate_ids[:, inputs["input_ids"].shape[-1]:] + +# The generated image tokens are wrapped by the `image_start_token` and `image_end_token` tokens. We need to remove them before decoding the image tokens. +image_token_ids = response_ids[:, 1:-1] + +# Decode the generated image tokens +pixel_values = model.decode_image_tokens(image_token_ids) +pixel_values = processor.postprocess_pixel_values(pixel_values) + +# Save the image +image = to_pil_image(pixel_values[0].detach().cpu()) +image.save("snowman.png") +``` + +### Interleaved text-image generation + +We can also generate interleaved text and images in the output. 
Here is how you can do it: + +```python +import torch +from transformers import ChameleonProcessor, ChameleonForConditionalGeneration + +processor = ChameleonProcessor.from_pretrained("leloy/Anole-7b-v0.1-hf") +model = ChameleonForConditionalGeneration.from_pretrained( + "leloy/Anole-7b-v0.1-hf", + device_map="auto", + torch_dtype=torch.bfloat16, +) + +# Prepare a prompt +prompt = "Can you draw a snowman and explain how to build one?" + +# Preprocess the prompt +inputs = processor(prompt, padding=True, return_tensors="pt").to(model.device, dtype=model.dtype) + +# Generate interleaved text and discrete image tokens +generate_ids = model.generate( + **inputs, + multimodal_generation_mode="interleaved-text-image", + # Note: We will need a larger `max_new_tokens` value since we are generating both text and image tokens. + max_new_tokens=4096, + # This is important because most of the image tokens during training were for "empty" patches, so greedy decoding of image tokens will likely result in a blank image. + do_sample=True, +) + +# Only keep the tokens from the response +response_ids = generate_ids[:, inputs["input_ids"].shape[-1]:] +``` + +From here, you can split the response tokens into text and image token segments, decode them separately as shown in the previous examples, and finally render the resulting text and images together. You can also use [MMSG](https://github.com/leloykun/mmsg) to do this more easily. 
+ ## Model optimization ### Quantization using Bitsandbytes diff --git a/src/transformers/generation/logits_process.py b/src/transformers/generation/logits_process.py index 39a38f9139ec..d2ede0734dd5 100644 --- a/src/transformers/generation/logits_process.py +++ b/src/transformers/generation/logits_process.py @@ -1778,6 +1778,61 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> to return scores_processed +class SuppressTokensInIndexRangeLogitsProcessor(LogitsProcessor): + r""" + [`SuppressTokensInIndexRangeLogitsProcessor`] suppresses a list of tokens from `start_index` to `end_index` (exclusive) + + Args: + suppress_tokens (`List[int]`): + List of token ids to suppress during generation. + start_index (`int`): + The index at which to start suppressing tokens. + end_index (`int`, *optional*): + The index at which to end suppressing tokens. If `None`, it will suppress tokens indefinitely. + device (`str`, *optional*, defaults to `"cpu"`): + The device to allocate the tensors. + + Examples: + + ```python + >>> from transformers import AutoProcessor, ChameleonForConditionalGeneration, LogitsProcessorList + >>> from transformers.generation.logits_process import SuppressTokensInIndexRangeLogitsProcessor + >>> import torch + + >>> model = ChameleonForConditionalGeneration.from_pretrained("leloy/Anole-7b-v0.1-hf") + >>> processor = AutoProcessor.from_pretrained("leloy/Anole-7b-v0.1-hf") + + >>> inputs = processor("Can you draw a snowman?", return_tensors="pt") + >>> max_length = 1200 + >>> # Don't start generating an image if there isn't enough space for the rest of the image tokens. + >>> logits_processor = SuppressTokensInIndexRangeLogitsProcessor( + ... suppress_tokens=[model.vocabulary_mapping.boi_token_id], + ... start_index=max_length - model.model.image_seq_length - 1, + ... device=model.device, + ... 
) + + >>> outputs = model.generate(**inputs, max_length=max_length, logits_processor=LogitsProcessorList([logits_processor])) + >>> print(torch.isin(outputs[inputs.input_ids.shape[1] + 1 : ], model.vocabulary_mapping.image_token_ids).all()) + True + ``` + """ + + def __init__( + self, suppress_tokens: List[int], start_index: int, end_index: Optional[int] = None, device: str = "cpu" + ): + self.suppress_tokens = torch.tensor(suppress_tokens, device=device) + self.start_index = start_index + self.end_index = end_index if end_index is not None else math.inf + + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: + current_index = input_ids.shape[1] + if self.start_index > current_index or current_index > self.end_index: + return scores + suppress_tokens_mask = torch.zeros_like(scores, dtype=torch.bool) + suppress_tokens_mask[:, self.suppress_tokens] = True + return scores.masked_fill(suppress_tokens_mask, -float("inf")) + + class SuppressTokensAtBeginLogitsProcessor(LogitsProcessor): r""" [`SuppressTokensAtBeginLogitsProcessor`] supresses a list of tokens as soon as the `generate` function starts @@ -2953,3 +3008,83 @@ def expected_mean_g_value(self, vocab_size: int, coinflip_prob: float = 0.5) -> The expected mean g-value for watermarked text. """ return coinflip_prob + coinflip_prob * (1 - coinflip_prob) * (1 - (1 / vocab_size)) + + +class AllowOnlyTokensInRelativeWindowLogitsProcessor(LogitsProcessor): + r""" + [`AllowOnlyTokensInRelativeWindowLogitsProcessor`] suppresses the logits of tokens aside from a specific set of tokens + that can be generated at a relative window from a trigger token (e.g. begin image token). If `exclusive` is set to + `True`, the set of tokens allowed at this window will not be allowed anywhere else. This is useful for enforcing + multimodal generation constraints. + + Originally created for [Chameleon](https://huggingface.co/docs/transformers/model_doc/chameleon). 
+ + Args: + trigger_token_id (`int`): + The token id that triggers the window check. + allowed_token_ids (`List[int]`): + The list of token ids that are allowed at the specified relative window. + window_width (`int`): + The width of the window from the trigger token. + exclusive (`bool`, *optional*, defaults to `False`): + If `True`, the set of tokens allowed at this window will not be allowed anywhere else. + device (`str`, *optional*, defaults to `cpu`): + The device to allocate the util tensor on. + + Examples: + + ```python + >>> from transformers import AutoProcessor, ChameleonForConditionalGeneration, LogitsProcessorList + >>> from transformers.generation.logits_process import AllowOnlyTokensInRelativeWindowLogitsProcessor + >>> import torch + + >>> model = ChameleonForConditionalGeneration.from_pretrained("leloy/Anole-7b-v0.1-hf") + >>> processor = AutoProcessor.from_pretrained("leloy/Anole-7b-v0.1-hf") + + >>> inputs = processor("Can you draw a snowman?", return_tensors="pt") + >>> max_length = 1200 + >>> # Generate only image token ids for `image_seq_length` steps when the boi-token is already generated + >>> logits_processor = AllowOnlyTokensInRelativeWindowLogitsProcessor( + ... trigger_token_id=model.vocabulary_mapping.boi_token_id, + ... allowed_token_ids=model.vocabulary_mapping.image_token_ids, + ... window_width=model.model.image_seq_length, + ... exclusive=True, + ... device=model.device, + ... 
) + + >>> outputs = model.generate(**inputs, max_length=max_length, logits_processors=LogitsProcessorList([logits_processor])) + ``` + """ + + def __init__( + self, + trigger_token_id: int, + allowed_token_ids: List[int], + window_width: int, + exclusive: bool = False, + device: str = "cpu", + ): + self.trigger_token_id = trigger_token_id + self.allowed_token_ids = torch.tensor(allowed_token_ids, device=device).unsqueeze(0) + self.window_width = window_width + self.exclusive = exclusive + + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: + if input_ids.shape[1] < self.window_width and not self.exclusive: + return scores + + window_width = min(self.window_width, input_ids.shape[1]) + trigger_positions = (input_ids[:, -window_width:] == self.trigger_token_id).any(dim=1).unsqueeze(-1) + + disallowed_tokens_mask = torch.ones_like(scores, dtype=torch.bool) + disallowed_tokens_mask[:, self.allowed_token_ids] = False + + if self.exclusive: + return scores.masked_fill( + ~(disallowed_tokens_mask ^ trigger_positions), + -float("inf"), + ) + return scores.masked_fill( + disallowed_tokens_mask & trigger_positions, + -float("inf"), + ) diff --git a/src/transformers/image_processing_utils.py b/src/transformers/image_processing_utils.py index 0279f26a963e..a156cc8170a1 100644 --- a/src/transformers/image_processing_utils.py +++ b/src/transformers/image_processing_utils.py @@ -18,7 +18,7 @@ import numpy as np from .image_processing_base import BatchFeature, ImageProcessingMixin -from .image_transforms import center_crop, normalize, rescale +from .image_transforms import center_crop, normalize, rescale, unnormalize from .image_utils import ChannelDimension from .utils import logging @@ -112,6 +112,43 @@ def normalize( image, mean=mean, std=std, data_format=data_format, input_data_format=input_data_format, **kwargs ) + def unnormalize( + self, + image: np.ndarray, + mean: Union[float, Iterable[float]], + std: Union[float, 
Iterable[float]], + data_format: Optional[Union[str, ChannelDimension]] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + **kwargs, + ) -> np.ndarray: + """ + Unnormalize an image. image = (image * image_std) + image_mean. + + Args: + image (`np.ndarray`): + Image to unnormalize. + mean (`float` or `Iterable[float]`): + Image mean to use for unnormalization. + std (`float` or `Iterable[float]`): + Image standard deviation to use for unnormalization. + data_format (`str` or `ChannelDimension`, *optional*): + The channel dimension format for the output image. If unset, the channel dimension format of the input + image is used. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the input image. If unset, the channel dimension format is inferred + from the input image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + + Returns: + `np.ndarray`: The unnormalized image. 
+ """ + return unnormalize( + image, mean=mean, std=std, data_format=data_format, input_data_format=input_data_format, **kwargs + ) + def center_crop( self, image: np.ndarray, diff --git a/src/transformers/image_transforms.py b/src/transformers/image_transforms.py index e7d3a5abb7a8..0b1d7e27d02f 100644 --- a/src/transformers/image_transforms.py +++ b/src/transformers/image_transforms.py @@ -416,6 +416,67 @@ def normalize( return image +def unnormalize( + image: np.ndarray, + mean: Union[float, Iterable[float]], + std: Union[float, Iterable[float]], + data_format: Optional[ChannelDimension] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, +) -> np.ndarray: + """ + Unnormalizes `image` using the mean and standard deviation specified by `mean` and `std`. + + image = (image * std) + mean + + Args: + images (`np.ndarray`): + The image to unnormalize. + mean (`float` or `Iterable[float]`): + The mean to use for unnormalization. + std (`float` or `Iterable[float]`): + The standard deviation to use for unnormalization. + data_format (`ChannelDimension`, *optional*): + The channel dimension format of the output image. If unset, will use the inferred format from the input. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the input image. If unset, the channel dimension format is inferred + from the input image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. 
+ """ + if not isinstance(image, np.ndarray): + raise ValueError("image must be a numpy array") + + if input_data_format is None: + input_data_format = infer_channel_dimension_format(image) + + channel_axis = get_channel_dimension_axis(image, input_data_format=input_data_format) + num_channels = image.shape[channel_axis] + + if isinstance(mean, Iterable): + if len(mean) != num_channels: + raise ValueError(f"mean must have {num_channels} elements if it is an iterable, got {len(mean)}") + else: + mean = [mean] * num_channels + + if isinstance(std, Iterable): + if len(std) != num_channels: + raise ValueError(f"std must have {num_channels} elements if it is an iterable, got {len(std)}") + else: + std = [std] * num_channels + + rev_image_mean = tuple(-mu / stdev for mu, stdev in zip(mean, std)) + rev_image_std = tuple(1 / stdev for stdev in std) + image = normalize( + image=image, + mean=rev_image_mean, + std=rev_image_std, + data_format=data_format, + input_data_format=input_data_format, + ) + return image + + def center_crop( image: np.ndarray, size: Tuple[int, int], diff --git a/src/transformers/models/chameleon/configuration_chameleon.py b/src/transformers/models/chameleon/configuration_chameleon.py index 9842127e7bb4..4f1ee7be02a4 100644 --- a/src/transformers/models/chameleon/configuration_chameleon.py +++ b/src/transformers/models/chameleon/configuration_chameleon.py @@ -45,6 +45,8 @@ class ChameleonVQVAEConfig(PretrainedConfig): Resolution of the input images. in_channels (`int`, *optional*, defaults to 3): Number of input channels. + out_channels (`int`, *optional*, defaults to 3): + Number of output channels. base_channels (`int`, *optional*, defaults to 128): Base channel count. 
channel_multiplier (`List[int]`, *optional*, defaults to `[1, 1, 2, 2, 4]`): @@ -72,6 +74,7 @@ def __init__( latent_channels: int = 256, resolution: int = 512, in_channels: int = 3, + out_channels: int = 3, base_channels: int = 128, channel_multiplier: List[int] = [1, 1, 2, 2, 4], num_res_blocks: int = 2, @@ -88,6 +91,7 @@ def __init__( self.latent_channels = latent_channels self.resolution = resolution self.in_channels = in_channels + self.out_channels = out_channels self.base_channels = base_channels self.channel_multiplier = channel_multiplier self.num_res_blocks = num_res_blocks @@ -170,6 +174,12 @@ class ChameleonConfig(PretrainedConfig): ChameleonVQConfig instance containing the configuration for the VQ-VAE model. vocabulary_map (`dict`, *optional*): A dictionary containing the vocabulary map from the tokenizer. Used to obtain tokens from the image inputs. + image_token_index (`int`, *optional*, defaults to 8711): + The ID for the token used to represent the image in the input sequence. + boi_token_id (`int`, *optional*, defaults to 8197): + Beginning of image token stream id. + eoi_token_id (`int`, *optional*, defaults to 8196): + End of image token stream id. mlp_bias (`bool`, *optional*, defaults to `False`): Whether to use a bias in up_proj, down_proj and gate_proj layers in the MLP layers. 
@@ -216,6 +226,9 @@ def __init__( swin_norm=False, vq_config=None, vocabulary_map=None, + image_token_index=8711, + boi_token_id=8197, + eoi_token_id=8196, mlp_bias=False, **kwargs, ): @@ -247,6 +260,9 @@ def __init__( self.vq_config = ChameleonVQVAEConfig(**vq_config) self.vocabulary_map = vocabulary_map + self.image_token_index = image_token_index + self.boi_token_id = boi_token_id + self.eoi_token_id = eoi_token_id super().__init__( pad_token_id=pad_token_id, diff --git a/src/transformers/models/chameleon/convert_chameleon_weights_to_hf.py b/src/transformers/models/chameleon/convert_chameleon_weights_to_hf.py index ff45c9b597e0..40a747ff1bb5 100644 --- a/src/transformers/models/chameleon/convert_chameleon_weights_to_hf.py +++ b/src/transformers/models/chameleon/convert_chameleon_weights_to_hf.py @@ -81,7 +81,7 @@ def write_json(text, path): json.dump(text, f) -def write_model(model_path, input_base_path, model_size, chameleon_version=1): +def write_model(model_path, input_base_path, model_size, chameleon_version=1, vqvae_path=None): os.makedirs(model_path, exist_ok=True) input_model_path = os.path.join(input_base_path, "models", model_size.lower()) params_path = os.path.join(input_model_path, "params.json") @@ -316,8 +316,6 @@ def permute(w, n_heads, dim1=dim, dim2=dim): vqgan_path = os.path.join(input_base_path, "tokenizer/vqgan.ckpt") vqgan_state_dict = torch.load(vqgan_path, map_location="cpu")["state_dict"] for k, v in vqgan_state_dict.items(): - if "decoder" in k: - continue # we dont do image generation yet state_dict[f"model.vqmodel.{k}"] = v # Write configs @@ -327,9 +325,8 @@ def permute(w, n_heads, dim1=dim, dim2=dim): with open(os.path.join(input_base_path, "tokenizer/text_tokenizer.json")) as tokenizer_file: tokenizer_config = json.load(tokenizer_file) vocabulary_map = tokenizer_config["model"]["vocab"] - vocabulary_map[""] = vocabulary_map[ - "" - ] # use a reserved token instead of adding a new one + # use a reserved token instead of adding a new 
one + vocabulary_map[""] = vocabulary_map[""] del vocabulary_map[""] for token in tokenizer_config["added_tokens"]: @@ -370,6 +367,9 @@ def permute(w, n_heads, dim1=dim, dim2=dim): swin_norm=swin_norm, vq_config=vq_config, vocabulary_map=vocabulary_map, + image_token_id=vocabulary_map[""], + boi_token_id=vocabulary_map[""], + eoi_token_id=vocabulary_map[""], ) with init_empty_weights(): model = ChameleonForConditionalGeneration(config) @@ -377,9 +377,19 @@ def permute(w, n_heads, dim1=dim, dim2=dim): model.load_state_dict(state_dict, assign=True, strict=False) model.save_pretrained(model_path, safe_serialization=True) + if vqvae_path is not None: + model.model.vqmodel.save_pretrained(vqvae_path, safe_serialization=True) + # Load and save the processor + extra_special_tokens = { + "image_token": "", + "boi_token": "", + "eoi_token": "", + } tokenizer = LlamaTokenizerFast( - tokenizer_file=os.path.join(input_base_path, "tokenizer/text_tokenizer_modified.json"), legacy=False + tokenizer_file=os.path.join(input_base_path, "tokenizer/text_tokenizer_modified.json"), + legacy=False, + extra_special_tokens=extra_special_tokens, ) tokenizer.sep_token_id = 8710 # assign to sep so that we can append it after input text tokenizer.pad_token_id = 1 # assing to special pad_token @@ -463,12 +473,18 @@ def main(): type=int, help="Version of the Chameleon model to convert", ) + parser.add_argument( + "--vqvae_path", + default=None, + help="Location to write VQ-VAE model", + ) args = parser.parse_args() write_model( model_path=args.output_dir, input_base_path=args.input_dir, model_size=args.model_size, chameleon_version=args.chameleon_version, + vqvae_path=args.vqvae_path, ) diff --git a/src/transformers/models/chameleon/image_processing_chameleon.py b/src/transformers/models/chameleon/image_processing_chameleon.py index 46d081973bb4..937e3cfff244 100644 --- a/src/transformers/models/chameleon/image_processing_chameleon.py +++ 
b/src/transformers/models/chameleon/image_processing_chameleon.py @@ -31,17 +31,21 @@ infer_channel_dimension_format, is_scaled_image, is_valid_image, + make_list_of_images, to_numpy_array, valid_images, validate_preprocess_arguments, ) -from ...utils import TensorType, filter_out_non_signature_kwargs, is_vision_available, logging +from ...utils import TensorType, filter_out_non_signature_kwargs, is_torch_available, is_vision_available, logging logger = logging.get_logger(__name__) +if is_torch_available(): + import torch if is_vision_available(): import PIL + from PIL import Image def make_batched_images(images) -> List[List[ImageInput]]: @@ -209,7 +213,8 @@ def preprocess( return_tensors: Optional[Union[str, TensorType]] = None, data_format: Optional[ChannelDimension] = ChannelDimension.FIRST, input_data_format: Optional[Union[str, ChannelDimension]] = None, - ) -> PIL.Image.Image: + **kwargs, + ) -> BatchFeature: """ Preprocess an image or batch of images. @@ -362,3 +367,81 @@ def blend_rgba(self, image: ImageInput) -> ImageInput: alpha = img_rgba[:, :, 3] / 255.0 img_rgb = (1 - alpha[:, :, np.newaxis]) * 255 + alpha[:, :, np.newaxis] * img_rgba[:, :, :3] return PIL.Image.fromarray(img_rgb.astype("uint8"), "RGB") + + def postprocess( + self, + images: ImageInput, + do_rescale: Optional[bool] = None, + rescale_factor: Optional[float] = None, + do_unnormalize: Optional[bool] = None, + image_mean: Optional[Union[float, List[float]]] = None, + image_std: Optional[Union[float, List[float]]] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + data_format: Optional[ChannelDimension] = ChannelDimension.FIRST, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + ) -> "torch.Tensor": + """ + Postprocess a batch of pixel values to images. + + Args: + images (`ImageInput`): + Image to postprocess. Expects a single or batch of images with pixel values. 
+ do_rescale (`bool`, *optional*, defaults to `self.do_rescale`): + Whether to rescale the image. + rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`): + Rescale factor to rescale the image by if `do_rescale` is set to `True`. + do_unnormalize (`bool`, *optional*, defaults to `self.do_normalize`): + Whether to unnormalize the image. + image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`): + Image mean to use for unnormalization. Only has an effect if `do_unnormalize` is set to `True`. + image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`): + Image standard deviation to use for unnormalization. Only has an effect if `do_unnormalize` is set to + `True`. + return_tensors (`str` or `TensorType`, *optional*): + The type of tensors to return. Can be one of: + - Unset: Return a list of `np.ndarray`. + - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. + - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. + - `PIL.Image` or `'pil'`: Return a batch of type `PIL.Image`. + data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): + The channel dimension format for the output image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - Unset: Use the channel dimension format of the input image. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the input image. If unset, the channel dimension format is inferred + from the input image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. 
+ + """ + do_rescale = do_rescale if do_rescale is not None else self.do_rescale + rescale_factor = 1.0 / self.rescale_factor if rescale_factor is None else rescale_factor + do_unnormalize = do_unnormalize if do_unnormalize is not None else self.do_normalize + image_mean = image_mean if image_mean is not None else self.image_mean + image_std = image_std if image_std is not None else self.image_std + + images = make_list_of_images(images) + + # All transformations expect numpy arrays. + images = [to_numpy_array(image) for image in images] + + postprocessed_images = [] + for image in images: + if do_unnormalize: + image = self.unnormalize( + image, mean=image_mean, std=image_std, data_format=data_format, input_data_format=input_data_format + ) + + if do_rescale: + image = self.rescale(image, scale=rescale_factor, input_data_format=input_data_format) + image = image.clip(0, 255).astype(np.uint8) + + if do_unnormalize and do_rescale and return_tensors == "pil": + image = to_channel_dimension_format(image, ChannelDimension.LAST, input_channel_dim=input_data_format) + image = Image.fromarray(image) + postprocessed_images.append(image) + + return_tensors = return_tensors if return_tensors != "pil" else None + return BatchFeature(data={"pixel_values": postprocessed_images}, tensor_type=return_tensors) diff --git a/src/transformers/models/chameleon/modeling_chameleon.py b/src/transformers/models/chameleon/modeling_chameleon.py index 3255b6f44c05..626fa8f2b758 100644 --- a/src/transformers/models/chameleon/modeling_chameleon.py +++ b/src/transformers/models/chameleon/modeling_chameleon.py @@ -15,8 +15,10 @@ """PyTorch Chameleon model.""" import math +import warnings +from dataclasses import dataclass from functools import cached_property -from typing import Optional, Tuple, Union +from typing import Dict, Literal, Optional, Tuple, Union import torch import torch.nn.functional as F @@ -27,6 +29,15 @@ from ...activations import ACT2FN from ...cache_utils import Cache, 
DynamicCache, StaticCache from ...generation import GenerationMixin +from ...generation.configuration_utils import GenerationConfig +from ...generation.logits_process import ( + AllowOnlyTokensInRelativeWindowLogitsProcessor, + LogitsProcessorList, + SuppressTokensAtBeginLogitsProcessor, + SuppressTokensInIndexRangeLogitsProcessor, + SuppressTokensLogitsProcessor, +) +from ...generation.utils import GenerateOutput from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_flash_attention_utils import _flash_attention_forward from ...modeling_outputs import ( @@ -36,6 +47,7 @@ from ...modeling_utils import PreTrainedModel from ...pytorch_utils import ALL_LAYERNORM_LAYERS from ...utils import ( + ModelOutput, add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, @@ -60,6 +72,22 @@ _SEQ_CLASS_EXPECTED_OUTPUT = "'LABEL_0'" +@dataclass +class ChameleonVQVAEOutput(ModelOutput): + """ + Base class for Chameleon VQ-VAE model outputs. + + Args: + decoded_pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`): + Reconstructed pixel values after encoding and decoding the input. + emb_loss (`torch.FloatTensor`): + Embedding loss. 
+ """ + + decoded_pixel_values: Optional[torch.FloatTensor] = None + emb_loss: torch.FloatTensor = None + + # Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Chameleon class ChameleonRMSNorm(nn.Module): def __init__(self, hidden_size, eps=1e-6): @@ -750,12 +778,14 @@ def __init__(self, config): super().__init__() self.num_embeddings = config.num_embeddings self.embedding_dim = config.embed_dim + self.quant_state_dims = [config.resolution // 2 ** (len(config.channel_multiplier) - 1)] * 2 self.beta = getattr(config, "beta", 0.25) self.embedding = nn.Embedding(self.num_embeddings, self.embedding_dim) self.re_embed = self.num_embeddings - def forward(self, hidden_state: torch.Tensor): + def forward(self, hidden_state: torch.FloatTensor): + batch_size = hidden_state.shape[0] hidden_state = hidden_state.permute(0, 2, 3, 1).contiguous() hidden_state_flattened = hidden_state.view(-1, self.embedding_dim) @@ -780,7 +810,30 @@ def forward(self, hidden_state: torch.Tensor): # reshape back to match original input shape hidden_state_quant = hidden_state_quant.permute(0, 3, 1, 2).contiguous() - return hidden_state_quant, loss, min_encoding_indices + return hidden_state_quant, loss, min_encoding_indices.view(batch_size, -1) + + def get_codebook_entry(self, image_tokens: torch.LongTensor) -> torch.FloatTensor: + batch_size = image_tokens.shape[0] + emb_dim: int = self.embedding.weight.shape[-1] + # get quantized latent vectors + hidden_state_quant = self.embedding(image_tokens) + + # reshape back to match original input shape + hidden_state_quant = hidden_state_quant.view((batch_size, *self.quant_state_dims, emb_dim)) + hidden_state_quant = hidden_state_quant.permute(0, 3, 1, 2).contiguous() + + return hidden_state_quant + + +class ChameleonVQVAEDecoderConvUpsample(nn.Module): + def __init__(self, in_channels): + super().__init__() + self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1) + + def forward(self, 
hidden_states): + hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest") + hidden_states = self.conv(hidden_states) + return hidden_states class ChameleonVQVAEEncoderConvDownsample(nn.Module): @@ -795,7 +848,7 @@ def forward(self, hidden_states): return hidden_states -class ChameleonVQVAEEncoderResnetBlock(nn.Module): +class ChameleonVQVAEResnetBlock(nn.Module): def __init__( self, config, @@ -839,7 +892,7 @@ def forward(self, hidden_states): return residual + hidden_states -class ChameleonVQVAEEncoderAttnBlock(nn.Module): +class ChameleonVQVAEAttnBlock(nn.Module): def __init__(self, in_channels): super().__init__() self.in_channels = in_channels @@ -850,7 +903,7 @@ def __init__(self, in_channels): self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) - def forward(self, hidden_states): + def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor: residual = hidden_states hidden_states = self.norm(hidden_states) query_states = self.q(hidden_states) @@ -887,7 +940,7 @@ def __init__(self, config): latent_channels = config.latent_channels channel_multiplier = config.channel_multiplier - self.conv_in = torch.nn.Conv2d(in_channels, base_channels, kernel_size=3, stride=1, padding=1) + self.conv_in = nn.Conv2d(in_channels, base_channels, kernel_size=3, stride=1, padding=1) curr_res = resolution in_channel_multiplier = (1,) + tuple(channel_multiplier) @@ -900,7 +953,7 @@ def __init__(self, config): block_out = base_channels * channel_multiplier[i_level] for i_block in range(self.num_res_blocks): block.append( - ChameleonVQVAEEncoderResnetBlock( + ChameleonVQVAEResnetBlock( config=config, in_channels=block_in, out_channels=block_out, @@ -912,7 +965,7 @@ def __init__(self, config): and curr_res in config.attn_resolutions and config.attn_type == "vanilla" ): - attn.append(ChameleonVQVAEEncoderAttnBlock(block_in)) 
+ attn.append(ChameleonVQVAEAttnBlock(block_in)) down = nn.Module() down.block = block @@ -923,13 +976,13 @@ def __init__(self, config): self.down.append(down) self.mid = nn.Module() - self.mid.block_1 = ChameleonVQVAEEncoderResnetBlock( + self.mid.block_1 = ChameleonVQVAEResnetBlock( config=config, in_channels=block_in, out_channels=block_in, ) - self.mid.attn_1 = ChameleonVQVAEEncoderAttnBlock(block_in) if config.attn_type == "vanilla" else nn.Identity() - self.mid.block_2 = ChameleonVQVAEEncoderResnetBlock( + self.mid.attn_1 = ChameleonVQVAEAttnBlock(block_in) if config.attn_type == "vanilla" else nn.Identity() + self.mid.block_2 = ChameleonVQVAEResnetBlock( config=config, in_channels=block_in, out_channels=block_in, @@ -944,7 +997,7 @@ def __init__(self, config): padding=1, ) - def forward(self, pixel_values: torch.LongTensor): + def forward(self, pixel_values: torch.FloatTensor) -> torch.FloatTensor: # downsampling hidden_states = [self.conv_in(pixel_values)] for i_level in range(self.num_resolutions): @@ -971,6 +1024,95 @@ def forward(self, pixel_values: torch.LongTensor): return last_hidden_state +class ChameleonVQVAEDecoder(nn.Module): + def __init__(self, config): + super().__init__() + + self.num_resolutions = len(config.channel_multiplier) + self.num_res_blocks = config.num_res_blocks + base_channels = config.base_channels + resolution = config.resolution + latent_channels = config.latent_channels + out_channels = config.out_channels + + # compute in_ch_mult, block_in and curr_res at lowest res + block_in = base_channels * config.channel_multiplier[self.num_resolutions - 1] + curr_res = resolution // 2 ** (self.num_resolutions - 1) + self.z_shape = (1, latent_channels, curr_res, curr_res) + + # z to block_in + self.conv_in = torch.nn.Conv2d(latent_channels, block_in, kernel_size=3, stride=1, padding=1) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ChameleonVQVAEResnetBlock( + config=config, + in_channels=block_in, + out_channels=block_in, + 
) + self.mid.attn_1 = ChameleonVQVAEAttnBlock(block_in) if config.attn_type == "vanilla" else nn.Identity() + self.mid.block_2 = ChameleonVQVAEResnetBlock( + config=config, + in_channels=block_in, + out_channels=block_in, + ) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = base_channels * config.channel_multiplier[i_level] + for i_block in range(self.num_res_blocks + 1): + block.append( + ChameleonVQVAEResnetBlock( + config=config, + in_channels=block_in, + out_channels=block_out, + ) + ) + block_in = block_out + if ( + config.attn_resolutions is not None + and curr_res in config.attn_resolutions + and config.attn_type == "vanilla" + ): + attn.append(ChameleonVQVAEAttnBlock(block_in)) + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = ChameleonVQVAEDecoderConvUpsample(block_in) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = torch.nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True) + self.conv_out = torch.nn.Conv2d(block_in, out_channels, kernel_size=3, stride=1, padding=1) + + def forward(self, hidden_state: torch.FloatTensor) -> torch.FloatTensor: + hidden_state = self.conv_in(hidden_state) + + # middle + hidden_state = self.mid.block_1(hidden_state) + hidden_state = self.mid.attn_1(hidden_state) + hidden_state = self.mid.block_2(hidden_state) + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + hidden_state = self.up[i_level].block[i_block](hidden_state) + if len(self.up[i_level].attn) > 0: + hidden_state = self.up[i_level].attn[i_block](hidden_state) + if i_level != 0: + hidden_state = self.up[i_level].upsample(hidden_state) + + hidden_state = self.norm_out(hidden_state) + hidden_state *= torch.sigmoid(hidden_state) + hidden_state = 
self.conv_out(hidden_state) + return hidden_state + + CHAMELEON_VQ_START_DOCSTRING = r""" This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads @@ -998,11 +1140,14 @@ def forward(self, pixel_values: torch.LongTensor): class ChameleonVQVAE(PreTrainedModel): config_class = ChameleonVQVAEConfig _no_split_modules = ["ChameleonVQVAEVectorQuantizer"] + main_input_name = "pixel_values" def _init_weights(self, module): std = self.config.initializer_range if isinstance(module, nn.Embedding): module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() elif isinstance(module, nn.GroupNorm): module.bias.data.zero_() module.weight.data.fill_(1.0) @@ -1015,33 +1160,107 @@ def __init__(self, config: ChameleonVQVAEConfig): super().__init__(config) self.encoder = ChameleonVQVAEEncoder(config) + self.decoder = ChameleonVQVAEDecoder(config) self.quantize = ChameleonVQVAEVectorQuantizer(config) self.quant_conv = torch.nn.Conv2d(config.latent_channels, config.embed_dim, 1) self.post_quant_conv = torch.nn.Conv2d(config.embed_dim, config.latent_channels, 1) + self.post_init() self.eval() # Chameleon's VQ model is frozen - def encode(self, pixel_values: torch.LongTensor): + def encode(self, pixel_values: torch.FloatTensor) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.LongTensor]: + """ + Encodes pixel values into quantized tokens. + + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)): + The tensors corresponding to the input images. + + Returns: + quant (`torch.FloatTensor` of shape `(batch_size, embed_dim, quantize.quant_state_dims[0], quantize.quant_state_dims[1])`): + Embeddings of quantized tokens. + emb_loss (`torch.FloatTensor`): + Embedding loss. 
+ indices (`torch.LongTensor` of shape `(batch_size, quantize.quant_state_dims[0] * quantize.quant_state_dims[1])`): + Token IDs + """ hidden_states = self.encoder(pixel_values) hidden_states = self.quant_conv(hidden_states) quant, emb_loss, indices = self.quantize(hidden_states) return quant, emb_loss, indices + def decode(self, image_tokens: torch.LongTensor) -> torch.FloatTensor: + """ + Decodes quantized token IDs into pixel values. + + Args: + image_tokens (`torch.LongTensor` of shape `(batch_size, quantize.quant_state_dims[0] * quantize.quant_state_dims[1])`): + Batch of token IDs. + + Returns: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`): + Pixel values decoded from the token IDs. + """ + if image_tokens.shape[1] != self.quantize.quant_state_dims[0] * self.quantize.quant_state_dims[1]: + raise ValueError( + f"Expected `image_tokens` to have shape `(batch_size, {self.quantize.quant_state_dims[0] * self.quantize.quant_state_dims[1]})`, " + f"but got shape `{image_tokens.shape}`." + ) + codebook_entry = self.quantize.get_codebook_entry(image_tokens) + hidden_states = self.post_quant_conv(codebook_entry) + pixel_values = self.decoder(hidden_states) + return pixel_values + + def forward( + self, pixel_values: torch.FloatTensor, return_dict: bool = None + ) -> Tuple[torch.FloatTensor, torch.FloatTensor]: + """ + Encodes pixel values into quantized tokens and decodes them back. + + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)): + The tensors corresponding to the input images. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + + Returns: + decoded_pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`): + Reconstructed pixel values after encoding and decoding the input. + emb_loss (`torch.FloatTensor`): + Embedding loss. 
+ + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + quant, emb_loss, indices = self.encode(pixel_values) + decoded_pixel_values = self.decode(indices) + if not return_dict: + return (decoded_pixel_values, emb_loss) + return ChameleonVQVAEOutput(decoded_pixel_values, emb_loss) + class ChameleonImageVocabularyMapping: """ A class for mapping discrete image tokens from VQGAN to BPE tokens. """ - def __init__(self, vocab_map): + def __init__( + self, + vocab_map: Dict[str, int], + image_token_index: int, + boi_token_id: int, + eoi_token_id: int, + ): self.vocab_map = vocab_map - self.image_token_id = vocab_map.get("") + self.image_token_index = image_token_index + self.boi_token_id = boi_token_id + self.eoi_token_id = eoi_token_id @cached_property def val2name(self): return {v: k for k, v in self.vocab_map.items()} @cached_property - def image_tokens(self): + def image_token_ids(self): return sorted([val for name, val in self.vocab_map.items() if name.startswith("IMGIMG")]) @cached_property @@ -1051,15 +1270,18 @@ def bpe2img(self): def remap(old_name: str) -> str: return "".join(img_tkn_chr_mapping.get(c, c) for c in old_name[len("IMGIMG") : -1]) - return {tok: int(remap(self.val2name[tok])) for tok in self.image_tokens} + return {tok: int(remap(self.val2name[tok])) for tok in self.image_token_ids} @cached_property def img2bpe(self): return {v: k for k, v in self.bpe2img.items()} @cached_property - def bpe2img_search_tensors(self): - return torch.tensor(sorted(self.bpe2img.keys())), torch.tensor(sorted(self.bpe2img.values())) + def bpe2img_mapping_tensor(self): + mapping = torch.zeros(max(self.bpe2img.keys()) + 1, dtype=torch.int) + for k, v in self.bpe2img.items(): + mapping[k] = v + return mapping @cached_property def img2bpe_mapping_tensor(self): @@ -1068,11 +1290,6 @@ def img2bpe_mapping_tensor(self): mapping[k] = v return mapping - def convert_img2bpe(self, img_batch: torch.Tensor) -> torch.Tensor: - device = 
img_batch.device - img_tokens = self.img2bpe_mapping_tensor[img_batch.to("cpu")] - return img_tokens.to(device) - CHAMELEON_START_DOCSTRING = r""" This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the @@ -1099,7 +1316,7 @@ class ChameleonPreTrainedModel(PreTrainedModel): config_class = ChameleonConfig base_model_prefix = "model" supports_gradient_checkpointing = True - _no_split_modules = ["ChameleonDecoderLayer", "ChameleonSwinDecoderLayer"] + _no_split_modules = ["ChameleonDecoderLayer", "ChameleonSwinDecoderLayer", "ChameleonVQVAE"] _skip_keys_device_placement = ["past_key_values", "causal_mask"] _supports_flash_attn_2 = True _supports_sdpa = True @@ -1209,7 +1426,22 @@ def __init__(self, config: ChameleonConfig): self.vocab_size = config.vocab_size self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) - self.vocabulary_mapping = ChameleonImageVocabularyMapping(config.vocabulary_map) + self.vocabulary_mapping = ChameleonImageVocabularyMapping( + config.vocabulary_map, + config.image_token_index, + config.boi_token_id, + config.eoi_token_id, + ) + self.register_buffer( + "img2bpe_mapping_tensor", + self.vocabulary_mapping.img2bpe_mapping_tensor, + persistent=False, + ) + self.register_buffer( + "bpe2img_mapping_tensor", + self.vocabulary_mapping.bpe2img_mapping_tensor, + persistent=False, + ) decoder_layer = ChameleonDecoderLayer if not self.config.swin_norm else ChameleonSwinDecoderLayer self.layers = nn.ModuleList( [decoder_layer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] @@ -1221,12 +1453,65 @@ def __init__(self, config: ChameleonConfig): # Initialize weights and apply final processing self.post_init() + @property + def image_seq_length(self) -> int: + return self.vqmodel.quantize.quant_state_dims[0] * self.vqmodel.quantize.quant_state_dims[1] + def get_input_embeddings(self): return self.embed_tokens def set_input_embeddings(self, value): 
self.embed_tokens = value + def convert_img2bpe_tokens(self, img_batch: torch.LongTensor) -> torch.LongTensor: + """ + Converts image tokens generated by the VQVAE model into BPE tokens compatible with the text tokenizer. + + Notes: + - It is important to move the `img_batch` tensor to the same device as the `img2bpe_mapping_tensor` buffer + as Accelerate may move the buffer to a different device when loading the model with `device_map="auto"`. + - Accelerate up to version 0.33.0 (and also maybe later versions) has a bug where buffers in downstream modules + may be ignored when inferring the proper device map. See: https://github.com/huggingface/accelerate/blob/79ca85c27df292dbf64cfa2bcc12dbb62fbe9267/src/accelerate/utils/modeling.py#L1273 + This causes the `img2bpe_mapping_tensor` buffer to be placed on the CPU by default, which may cause a performance + loss--especially with prompts that contain many images. No action needs to be done when this bug is fixed. + + Args: + img_batch (`torch.Tensor` of shape `(batch_size, image_seq_length)`): + The image tokens generated by the VQVAE model. + + Returns: + `torch.Tensor` of shape `(batch_size, image_seq_length)`: + The image tokens converted to be compatible with the text tokenizer's BPE tokens. + """ + device = img_batch.device + img_tokens = self.img2bpe_mapping_tensor[img_batch.to(self.img2bpe_mapping_tensor.device)] + return img_tokens.to(device) + + def convert_bpe2img_tokens(self, bpe_batch: torch.LongTensor) -> torch.LongTensor: + """ + Converts image tokens that are compatible with the text tokenizer into image tokens compatible with the VQVAE + model. + + Notes: + - It is important to move the `img_batch` tensor to the same device as the `img2bpe_mapping_tensor` buffer + as Accelerate may move the buffer to a different device when loading the model with `device_map="auto"`. 
+ - Accelerate up to version 0.33.0 (and also maybe later versions) has a bug where buffers in downstream modules + may be ignored when inferring the proper device map. See: https://github.com/huggingface/accelerate/blob/79ca85c27df292dbf64cfa2bcc12dbb62fbe9267/src/accelerate/utils/modeling.py#L1273 + This causes the `img2bpe_mapping_tensor` buffer to be placed on the CPU by default, which may cause a performance + loss--especially when generating interleaved text & images. No action needs to be done when this bug is fixed. + + Args: + bpe_batch (`torch.Tensor` of shape `(batch_size, image_seq_length)`): + The image tokens compatible with the text tokenizer. + + Returns: + `torch.Tensor` of shape `(batch_size, image_seq_length)`: + The image tokens converted to be compatible with the VQVAE model. + """ + device = bpe_batch.device + img_tokens = self.bpe2img_mapping_tensor[bpe_batch.to(self.bpe2img_mapping_tensor.device)] + return img_tokens.to(device) + def get_image_tokens(self, pixel_values: torch.FloatTensor): """ Tokenizes images into discrete tokens with VQGAN module. Converts @@ -1236,12 +1521,30 @@ def get_image_tokens(self, pixel_values: torch.FloatTensor): Args: pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)): The tensors corresponding to the input images. + + Returns: + `torch.Tensor` of shape `(batch_size, image_seq_length)`: + The BPE tokens generated by the model. """ - batch_size = pixel_values.shape[0] _, _, image_toks = self.vqmodel.encode(pixel_values) - bpe_toks = self.vocabulary_mapping.convert_img2bpe(image_toks) - bpe_toks = bpe_toks.view(batch_size, -1) - return bpe_toks + return self.convert_img2bpe_tokens(image_toks) + + def decode_image_tokens(self, bpe_tokens: torch.LongTensor) -> torch.LongTensor: + """ + Converts BPE tokens generated by the model into discrete image tokens + compatible with the VQGAN module, then decodes them into pixel values. 
+ + Args: + bpe_tokens (`torch.tensor` of shape `(batch, image_seq_length)`): + The BPE tokens generated by the model. + + Returns: + `torch.Tensor` of shape `(batch, num_channels, 512, 512)`: + """ + if bpe_tokens.shape[1] != self.image_seq_length: + raise ValueError(f"All batches must have {self.image_seq_length} tokens.") + image_tensor = self.convert_bpe2img_tokens(bpe_tokens) + return self.vqmodel.decode(image_tensor) @add_start_docstrings_to_model_forward(CHAMELEON_INPUTS_DOCSTRING) @add_code_sample_docstrings( @@ -1287,13 +1590,13 @@ def forward( if pixel_values is not None: image_tokens = self.get_image_tokens(pixel_values) - n_image_tokens_in_text = (input_ids == self.vocabulary_mapping.image_token_id).sum().item() + n_image_tokens_in_text = (input_ids == self.vocabulary_mapping.image_token_index).sum().item() n_image_features = image_tokens.shape[0] * image_tokens.shape[1] if n_image_tokens_in_text != n_image_features: raise ValueError( f"Image features and image tokens do not match: tokens: {n_image_tokens_in_text}, features {n_image_features}" ) - special_image_mask = input_ids == self.vocabulary_mapping.image_token_id + special_image_mask = input_ids == self.vocabulary_mapping.image_token_index image_tokens = image_tokens.to(input_ids.device, input_ids.dtype) input_ids = input_ids.masked_scatter(special_image_mask, image_tokens) @@ -1510,9 +1813,10 @@ def _prepare_4d_causal_attention_mask_with_cache_position( class ChameleonForConditionalGeneration(ChameleonPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] - def __init__(self, config): + def __init__(self, config: ChameleonConfig): super().__init__(config) self.model = ChameleonModel(config) + self.vocabulary_mapping = self.model.vocabulary_mapping self.vocab_size = config.vocab_size self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) @@ -1537,6 +1841,149 @@ def set_decoder(self, decoder): def get_decoder(self): return self.model + def 
_prepare_generation_config( + self, + generation_config: Optional[GenerationConfig] = None, + multimodal_generation_mode: Optional[ + Literal["text-only", "image-only", "interleaved-text-image", "unrestricted"] + ] = None, + **kwargs, + ): + if ( + multimodal_generation_mode == "image-only" + and kwargs.get("max_length") is None + and kwargs.get("max_new_tokens") is None + and ( + generation_config is None + or (generation_config.max_length is None and generation_config.max_new_tokens is None) + ) + ): + kwargs["max_new_tokens"] = self.model.image_seq_length + 2 + generation_config, model_kwargs = super()._prepare_generation_config(generation_config, **kwargs) + if multimodal_generation_mode is not None: + generation_config.multimodal_generation_mode = multimodal_generation_mode + if ( + not hasattr(generation_config, "multimodal_generation_mode") + or generation_config.multimodal_generation_mode is None + ): + generation_config.multimodal_generation_mode = "text-only" + return generation_config, model_kwargs + + @torch.no_grad() + def generate( + self, + inputs: Optional[torch.Tensor] = None, + generation_config: Optional[GenerationConfig] = None, + logits_processor: Optional[LogitsProcessorList] = None, + multimodal_generation_mode: Optional[ + Literal["text-only", "image-only", "interleaved-text-image", "unrestricted"] + ] = None, + **kwargs, + ) -> Union[GenerateOutput, torch.LongTensor]: + generation_config, model_kwargs = self._prepare_generation_config( + generation_config, multimodal_generation_mode, **kwargs + ) + + inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs( + inputs, generation_config.bos_token_id, model_kwargs + ) + input_ids = inputs_tensor if model_input_name == "input_ids" else model_kwargs.pop("input_ids") + + # Prepare `max_length` depending on other stopping criteria. 
+ input_ids_length = input_ids.shape[-1] + has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None + has_default_min_length = kwargs.get("min_length") is None and generation_config.min_length is not None + generation_config = self._prepare_generated_length( + generation_config=generation_config, + has_default_max_length=has_default_max_length, + has_default_min_length=has_default_min_length, + model_input_name=model_input_name, + inputs_tensor=inputs_tensor, + input_ids_length=input_ids_length, + ) + + if logits_processor is None: + logits_processor = LogitsProcessorList() + if generation_config.multimodal_generation_mode == "text-only": + # Suppress all image tokens + logits_processor.append( + SuppressTokensLogitsProcessor( + suppress_tokens=self.vocabulary_mapping.image_token_ids + + [ + self.vocabulary_mapping.boi_token_id, + self.vocabulary_mapping.eoi_token_id, + ], + device=self.device, + ) + ) + elif generation_config.multimodal_generation_mode == "image-only": + inferred_max_new_tokens = generation_config.max_length - input_ids_length + if inferred_max_new_tokens < self.model.image_seq_length + 2: + warnings.warn( + f"The VQVAE decoder expects to receive {self.model.image_seq_length} image tokens to generate an image." + "And Chameleon wraps the image tokens with the `beginning-of-image` and `end-of-image` tokens when on image generation mode." + f"Therefore, the `max_new_tokens` must be at least {self.model.image_seq_length + 2}." + f"However, the inferred `max_new_tokens` from the generation config is only {inferred_max_new_tokens}." + "You would need to pad the output tokens with dummy image tokens before passing them to the VQVAE decoder." + f"To avoid this warning, set `max_new_tokens` to at least {self.model.image_seq_length + 2}." 
+ ) + allowed_tokens = self.vocabulary_mapping.image_token_ids + [ + self.config.eos_token_id, + self.vocabulary_mapping.boi_token_id, + self.vocabulary_mapping.eoi_token_id, + ] + suppress_tokens = [token_id for token_id in range(self.vocab_size) if token_id not in allowed_tokens] + logits_processor.extend( + [ + # Don't start generating an image if there aren't enough space for the rest of the image tokens. + SuppressTokensInIndexRangeLogitsProcessor( + suppress_tokens=[self.vocabulary_mapping.boi_token_id], + start_index=generation_config.max_length - self.model.image_seq_length - 1, + device=self.device, + ), + # Allow only image tokens + SuppressTokensLogitsProcessor(suppress_tokens=suppress_tokens, device=self.device), + # Force image generation + SuppressTokensAtBeginLogitsProcessor( + begin_suppress_tokens=[self.config.eos_token_id], + begin_index=input_ids_length, + device=self.device, + ), + ] + ) + elif generation_config.multimodal_generation_mode == "interleaved-text-image": + logits_processor.extend( + [ + # Generate only image token ids for `image_seq_length` steps when the boi-token is already generated + AllowOnlyTokensInRelativeWindowLogitsProcessor( + trigger_token_id=self.vocabulary_mapping.boi_token_id, + allowed_token_ids=self.vocabulary_mapping.image_token_ids, + window_width=self.model.image_seq_length, + exclusive=True, + device=self.device, + ), + # Don't start generating an image if there aren't enough space for the rest of the image tokens. + SuppressTokensInIndexRangeLogitsProcessor( + suppress_tokens=[self.vocabulary_mapping.boi_token_id], + start_index=generation_config.max_length - self.model.image_seq_length - 1, + device=self.device, + ), + ] + ) + elif generation_config.multimodal_generation_mode == "unrestricted": + pass + else: + raise ValueError( + f"Unknown multimodal generation mode: {generation_config.multimodal_generation_mode}. 
" + f"Please choose one of 'unrestricted', 'text-only', 'image-only', or 'interleaved-text-image'." + ) + return super().generate( + inputs=inputs, + generation_config=generation_config, + logits_processor=logits_processor, + **kwargs, + ) + @add_start_docstrings_to_model_forward(CHAMELEON_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def forward( @@ -1607,10 +2054,6 @@ def forward( hidden_states = outputs[0] logits = self.lm_head(hidden_states) - # Disallow image tokens which does not include special begin-image and end-image tokens - image_tokens = self.model.vocabulary_mapping.image_tokens - logits[:, :, image_tokens] = torch.finfo(logits.dtype).min - loss = None if labels is not None: # Upcast to float if we need to compute the loss to avoid potential precision issues @@ -1689,3 +2132,17 @@ def prepare_inputs_for_generation( } ) return model_inputs + + def decode_image_tokens(self, bpe_tokens: torch.Tensor): + """ + Converts BPE tokens generated by the model into discrete image tokens + compatible with the VQGAN module, then decodes them into pixel values. + + Args: + bpe_tokens (`torch.tensor` of shape `(batch, image_seq_length)`): + The BPE tokens generated by the model. 
+ + Returns: + `torch.Tensor` of shape `(batch, num_channels, 512, 512)`: + """ + return self.model.decode_image_tokens(bpe_tokens) diff --git a/src/transformers/models/chameleon/processing_chameleon.py b/src/transformers/models/chameleon/processing_chameleon.py index e2a50d1af51b..3d3e99a22a53 100644 --- a/src/transformers/models/chameleon/processing_chameleon.py +++ b/src/transformers/models/chameleon/processing_chameleon.py @@ -168,3 +168,18 @@ def model_input_names(self): tokenizer_input_names = self.tokenizer.model_input_names image_processor_input_names = self.image_processor.model_input_names return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names)) + + def postprocess(self, images, return_tensors=None, **kwargs): + """ + Postprocess a batch of images. + + Args: + images (`ImageInput`): + A batch of images or a single image to postprocess. + + Returns: + [`BatchFeature`]: A [`BatchFeature`] with the following fields: + + - **pixel_values** -- Pixel values that are postprocessed in the requested return type. 
+ """ + return self.image_processor.postprocess(images, return_tensors=return_tensors, **kwargs) diff --git a/tests/generation/test_logits_process.py b/tests/generation/test_logits_process.py index aeebb5c4c53d..d4a497e358f0 100644 --- a/tests/generation/test_logits_process.py +++ b/tests/generation/test_logits_process.py @@ -57,7 +57,11 @@ UnbatchedClassifierFreeGuidanceLogitsProcessor, WatermarkLogitsProcessor, ) - from transformers.generation.logits_process import BarkEosPrioritizerLogitsProcessor + from transformers.generation.logits_process import ( + AllowOnlyTokensInRelativeWindowLogitsProcessor, + BarkEosPrioritizerLogitsProcessor, + SuppressTokensInIndexRangeLogitsProcessor, + ) @require_torch @@ -709,6 +713,33 @@ def empty_prefix_allowed_tokens_fn(batch_id, inputs_ids): # processor should not change logits in-place self.assertFalse(torch.all(scores == filtered_scores)) + def test_chameleon_processors(self): + batch_size = 4 + sequence_length = 10 + vocab_size = 15 + + # dummy input_ids and scores + input_ids = ids_tensor((batch_size, sequence_length), vocab_size) + scores = self._get_uniform_logits(batch_size, vocab_size) + + suppress_tokens_proc = SuppressTokensInIndexRangeLogitsProcessor( + suppress_tokens=[0], start_index=sequence_length - 2, device=torch_device + ) + scores = suppress_tokens_proc(input_ids, scores) + self.assertTrue(torch.isinf(scores[:, 0]).all()) + + input_ids[:, -1] = 1 + allow_tokens_proc = AllowOnlyTokensInRelativeWindowLogitsProcessor( + trigger_token_id=0, allowed_token_ids=[1, 2, 3], window_width=2, device=torch_device + ) + scores = allow_tokens_proc(input_ids, scores) + self.assertFalse(torch.isinf(scores).all()) + + input_ids[:, -1] = 0 + scores = allow_tokens_proc(input_ids, scores) + self.assertTrue(torch.isinf(scores[:, [4, 5, 6]]).all()) + self.assertFalse(torch.isinf(scores[:, [1, 2, 3]]).all()) + def test_hamming_diversity(self): vocab_size = 4 num_beams = 2 diff --git a/tests/generation/test_utils.py 
b/tests/generation/test_utils.py index 063e9a3da8fd..44373f7dd201 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -1617,7 +1617,7 @@ def test_generate_from_inputs_embeds(self, _, num_beams): # checks without adding test complexity. Ditto for `pixel_values_videos` and `pixel_values_images` pixel_values_is_mutually_exclusive = any( model_name in model_class.__name__.lower() - for model_name in ["llava", "idefics2", "idefics3", "mllama", "paligemma"] + for model_name in ["llava", "idefics2", "idefics3", "mllama", "paligemma", "chameleon"] ) if pixel_values_is_mutually_exclusive: inputs_dict.pop("pixel_values", None) @@ -1691,17 +1691,30 @@ def test_generate_from_inputs_embeds_with_static_cache(self): if "inputs_embeds" not in inspect.signature(model.prepare_inputs_for_generation).parameters.keys(): self.skipTest(reason="This model does not support `inputs_embeds` in generation") + # Some VLMs assume `inputs_embeds` and `pixel_values` are mutually exclusive AND fall in the + # exception above (complex `inputs_embeds` computation). Popping `pixel_values` allow us to run the + # checks without adding test complexity. 
Ditto for `pixel_values_videos` and `pixel_values_images` + pixel_values_is_mutually_exclusive = any( + model_name in model_class.__name__.lower() + for model_name in ["llava", "idefics2", "idefics3", "mllama", "paligemma", "chameleon"] + ) + if pixel_values_is_mutually_exclusive: + inputs_dict.pop("pixel_values", None) + inputs_dict.pop("pixel_values_videos", None) + inputs_dict.pop("pixel_values_images", None) + input_ids = inputs_dict.pop("input_ids") model.config.use_cache = True model.config.is_decoder = True batch_size = input_ids.shape[0] - max_cache_len = 30 + max_new_tokens = 5 + max_cache_len = max_new_tokens + input_ids.shape[1] # here we force to not stop at eos and go until max-length model.generation_config.eos_token_id = model.config.get_text_config().eos_token_id = -1 generation_kwargs = { - "max_length": max_cache_len, + "max_new_tokens": max_new_tokens, "cache_implementation": "static", "return_dict_in_generate": True, # Required to return `past_key_values` } @@ -1935,6 +1948,10 @@ def test_generate_with_static_cache(self): "output_scores": True, "use_cache": True, } + inputs_dict = { + k: v.to(dtype) if isinstance(v, torch.Tensor) and torch.is_floating_point(v) else v + for k, v in inputs_dict.items() + } static_cache_generation = model.generate( **generation_kwargs, **inputs_dict, cache_implementation="static" diff --git a/tests/models/chameleon/test_image_processing_chameleon.py b/tests/models/chameleon/test_image_processing_chameleon.py index 4a5c8c546790..2fdd0e07e607 100644 --- a/tests/models/chameleon/test_image_processing_chameleon.py +++ b/tests/models/chameleon/test_image_processing_chameleon.py @@ -204,3 +204,28 @@ def test_nested_input(self): # Image processor should return same pixel values, independently of input format self.assertTrue((encoded_images_nested == encoded_images).all()) + + def test_postprocessing(self): + """Tests image postprocessing, in other words converting a normalized image back to `PIL` image""" + 
image_processing = self.image_processing_class(**self.image_processor_dict) + # Pixel values for an image with 3 channels and 32x32 resolution + pixel_values_single = torch.zeros((1, 3, 32, 32)) + # Pixel values for a batch of 2 images with 3 channels and 32x32 resolution + pixel_values_batch = torch.zeros((2, 3, 32, 32)) + + for pixel_values in [pixel_values_single, pixel_values_batch]: + unnormalized_pixel_values = image_processing.postprocess(pixel_values, return_tensors="pt").pixel_values + self.assertEqual(unnormalized_pixel_values.shape, pixel_values.shape) + expected_pixel_values = torch.full_like(pixel_values, 128) + self.assertTrue(torch.equal(unnormalized_pixel_values, expected_pixel_values)) + + # Test normalize -> unnormalize back if the arrays match + image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=True, torchify=True) + + # Test not batched input + encoded_images = image_processing(image_inputs[0], do_resize=False, do_center_crop=False, return_tensors="pt") + unnormalized_image = image_processing.postprocess( + encoded_images.pixel_values, return_tensors="pt" + ).pixel_values + # the diff of 1 because the images are in range 0-255 in `uint8` and some precision errors might apply + self.assertTrue((image_inputs[0] - unnormalized_image).abs().max() <= 1) diff --git a/tests/models/chameleon/test_modeling_chameleon.py b/tests/models/chameleon/test_modeling_chameleon.py index bb2ba8b34281..7d71cb107f6d 100644 --- a/tests/models/chameleon/test_modeling_chameleon.py +++ b/tests/models/chameleon/test_modeling_chameleon.py @@ -19,18 +19,19 @@ import requests from parameterized import parameterized -from transformers import ChameleonConfig, is_torch_available, is_vision_available, set_seed +from transformers import ChameleonConfig, ChameleonVQVAEConfig, is_torch_available, is_vision_available, set_seed from transformers.testing_utils import ( require_bitsandbytes, require_read_token, require_torch, + require_torch_multi_gpu, 
slow, torch_device, ) from ...generation.test_utils import GenerationTesterMixin from ...test_configuration_common import ConfigTester -from ...test_modeling_common import ModelTesterMixin, ids_tensor +from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor from ...test_pipeline_mixin import PipelineTesterMixin @@ -44,6 +45,7 @@ ChameleonForConditionalGeneration, ChameleonModel, ChameleonProcessor, + ChameleonVQVAE, ) @@ -57,7 +59,9 @@ def __init__( use_input_mask=True, use_labels=True, vocab_size=99, - image_token_id=98, + image_token_index=1, + boi_token_id=97, + eoi_token_id=96, hidden_size=32, num_hidden_layers=2, num_attention_heads=2, @@ -70,6 +74,7 @@ def __init__( type_vocab_size=16, type_sequence_label_size=2, initializer_range=0.02, + image_size=10, num_labels=3, num_choices=4, pad_token_id=0, @@ -81,12 +86,13 @@ def __init__( ): self.parent = parent self.batch_size = batch_size - self.seq_length = seq_length self.is_training = is_training self.use_input_mask = use_input_mask self.use_labels = use_labels self.vocab_size = vocab_size - self.image_token_id = image_token_id + self.image_token_index = image_token_index + self.boi_token_id = boi_token_id + self.eoi_token_id = eoi_token_id self.hidden_size = hidden_size self.num_hidden_layers = num_hidden_layers self.num_attention_heads = num_attention_heads @@ -103,13 +109,19 @@ def __init__( self.num_choices = num_choices self.pad_token_id = pad_token_id self.scope = scope + self.image_size = image_size self.vq_num_embeds = vq_num_embeds self.vq_embed_dim = vq_embed_dim self.vq_channel_multiplier = vq_channel_multiplier self.vq_img_token_start_id = vq_img_token_start_id + self.image_seq_length = 25 + self.seq_length = seq_length + self.image_seq_length def prepare_config_and_inputs(self): input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) + input_ids[input_ids == self.image_token_index] = self.pad_token_id + input_ids[:, : self.image_seq_length] = 
self.image_token_index + pixel_values = floats_tensor([self.batch_size, 3, self.image_size, self.image_size]) input_mask = None if self.use_input_mask: @@ -125,7 +137,7 @@ def prepare_config_and_inputs(self): config = self.get_config() - return config, input_ids, input_mask, sequence_labels, token_labels, choice_labels + return config, input_ids, input_mask, pixel_values, sequence_labels, token_labels, choice_labels def get_config(self): # create dummy vocab map for image2bpe mapping if it needs remapping @@ -134,11 +146,13 @@ def get_config(self): # we will need "vq_num_embeds" amount of tokens vocab_map = {i: chr(i) for i in range(self.vocab_size)} - vocab_map[self.image_token_id] = "" + vocab_map[self.image_token_index] = "" start = self.vq_img_token_start_id end = self.vq_img_token_start_id + self.vq_num_embeds for i in range(start, end): - vocab_map[i] = f"IMGIMGBS{i}" # dummy str for each token, anything starting with IMGIMG + image_token_infix = "".join(chr(ord("A") + int(c)) for c in str(i)) + # dummy str for each image token, anything starting with IMGIMG + vocab_map[i] = f"IMGIMG{image_token_infix}Z" return ChameleonConfig( vocab_size=self.vocab_size, @@ -157,6 +171,9 @@ def get_config(self): pad_token_id=self.pad_token_id, vocabulary_map={v: k for k, v in vocab_map.items()}, vq_config=self.get_vq_config(), + image_token_index=self.image_token_index, + boi_token_id=self.boi_token_id, + eoi_token_id=self.eoi_token_id, ) def get_vq_config(self): @@ -167,13 +184,16 @@ def get_vq_config(self): "in_channels": 3, "base_channels": 32, # we have a GroupNorm of 32 groups, so can't do less "channel_multiplier": self.vq_channel_multiplier, + "initializer_range": self.initializer_range, } - def create_and_check_model(self, config, input_ids, input_mask, sequence_labels, token_labels, choice_labels): + def create_and_check_model( + self, config, input_ids, input_mask, pixel_values, sequence_labels, token_labels, choice_labels + ): model = ChameleonModel(config=config) 
model.to(torch_device) model.eval() - result = model(input_ids, attention_mask=input_mask) + result = model(input_ids, attention_mask=input_mask, pixel_values=pixel_values) result = model(input_ids) self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size)) @@ -260,11 +280,12 @@ def prepare_config_and_inputs_for_common(self): config, input_ids, input_mask, + pixel_values, sequence_labels, token_labels, choice_labels, ) = config_and_inputs - inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask} + inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask, "pixel_values": pixel_values} return config, inputs_dict @@ -281,7 +302,7 @@ class ChameleonModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester if is_torch_available() else {} ) - test_headmasking = False + test_head_masking = False test_pruning = False fx_compatible = False @@ -327,13 +348,149 @@ def test_model_rope_scaling(self, scaling_type): # The output should be different for long inputs self.assertFalse(torch.allclose(original_long_output, scaled_long_output, atol=1e-5)) + # overwrite inputs_embeds tests because we need to delete "pixel values" for LVLMs + def test_inputs_embeds(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + model = model_class(config) + model.to(torch_device) + model.eval() + + inputs = self._prepare_for_class(inputs_dict, model_class) + + input_ids = inputs["input_ids"] + del inputs["input_ids"] + del inputs["pixel_values"] + + wte = model.get_input_embeddings() + inputs["inputs_embeds"] = wte(input_ids) + + with torch.no_grad(): + model(**inputs) + + # overwrite inputs_embeds tests because we need to delete "pixel values" for LVLMs + # while some other models require pixel_values to be present + def test_inputs_embeds_matches_input_ids(self): + config, inputs_dict = 
self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + model = model_class(config) + model.to(torch_device) + model.eval() + + inputs = self._prepare_for_class(inputs_dict, model_class) + input_ids = inputs["input_ids"] + del inputs["input_ids"] + del inputs["pixel_values"] + + inputs_embeds = model.get_input_embeddings()(input_ids) + + with torch.no_grad(): + out_ids = model(input_ids=input_ids, **inputs)[0] + out_embeds = model(inputs_embeds=inputs_embeds, **inputs)[0] + self.assertTrue(torch.allclose(out_embeds, out_ids)) + @unittest.skip("Chameleon forces some token ids to be -inf!") def test_batching_equivalence(self): pass - # TODO (joao, raushan): fix me -- the problem is in `cache_position[0] == 0`, i.e. dynamic control flow - @unittest.skip("Chameleon is not compatible with end-to-end generation compilation") - def test_generate_compile_fullgraph(self): + +class ChameleonVQModelTester: + def __init__( + self, + parent, + batch_size=5, + is_training=False, + initializer_range=0.02, + image_size=30, + num_embeds=12, + base_channels=32, # we have a GroupNorm of 32 groups, so can't do less + embed_dim=12, + channel_multiplier=[1, 2], + scope=None, + ): + self.parent = parent + self.batch_size = batch_size + self.is_training = is_training + self.initializer_range = initializer_range + self.image_size = image_size + self.base_channels = base_channels + self.num_embeds = num_embeds + self.embed_dim = embed_dim + self.channel_multiplier = channel_multiplier + + def prepare_config_and_inputs(self): + pixel_values = floats_tensor([self.batch_size, 3, self.image_size, self.image_size]) + config = self.get_config() + return config, pixel_values + + def get_config(self): + return ChameleonVQVAEConfig( + embed_dim=self.embed_dim, + num_embeddings=self.num_embeds, + latent_channels=self.embed_dim, + in_channels=3, + base_channels=self.base_channels, + channel_multiplier=self.channel_multiplier, + 
initializer_range=self.initializer_range, + resolution=self.image_size, + ) + + def prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + config, pixel_values = config_and_inputs + inputs_dict = {"pixel_values": pixel_values} + return config, inputs_dict + + +@require_torch +class ChameleonVQModelTest(ModelTesterMixin, unittest.TestCase): + all_model_classes = (ChameleonVQVAE,) if is_torch_available() else () + test_head_masking = False + test_pruning = False + fx_compatible = False + has_attentions = False + test_resize_embeddings = False + + def setUp(self): + self.model_tester = ChameleonVQModelTester(self) + self.config_tester = ConfigTester( + self, + config_class=ChameleonVQVAEConfig, + has_text_modality=False, + common_properties=["embed_dim", "num_embeddings"], + ) + + def test_config(self): + self.config_tester.run_common_tests() + + @unittest.skip("Chameleon VQ module cannot offload due to using `self.weight` directly") + def test_cpu_offload(self): + pass + + @unittest.skip("Chameleon VQ module cannot offload due to using `self.weight` directly") + def test_disk_offload_bin(self): + pass + + @unittest.skip("Chameleon VQ module cannot offload due to using `self.weight` directly") + def test_disk_offload_safetensors(self): + pass + + @unittest.skip("Chameleon VQ module has no hidden states") + def test_hidden_states_output(self): + pass + + @unittest.skip("Chameleon VQ module has no hidden states") + def test_model_outputs_equivalence(self): + pass + + @unittest.skip("Chameleon VQ module has no get/set embeddings method") + def test_model_get_set_embeddings(self): + pass + + @unittest.skip("Chameleon VQ module has no hidden states") + def test_retain_grad_hidden_states_attentions(self): pass @@ -418,3 +575,29 @@ def test_model_7b_multi_image(self): generated_ids = model.generate(**inputs, max_new_tokens=40, do_sample=False) text = processor.batch_decode(generated_ids, skip_special_tokens=True) 
self.assertEqual(EXPECTED_TEXT_COMPLETION, text) + + @slow + @require_bitsandbytes + @require_read_token + @require_torch_multi_gpu + def test_model_7b_multi_gpu(self): + model = ChameleonForConditionalGeneration.from_pretrained( + "facebook/chameleon-7b", + load_in_4bit=True, + device_map="auto", + max_memory={0: "1GB"}, + ) + processor = ChameleonProcessor.from_pretrained("facebook/chameleon-7b") + + image = Image.open( + requests.get("https://nineplanets.org/wp-content/uploads/2020/12/the-big-dipper-1.jpg", stream=True).raw + ) + prompt = "Describe what do you see here and tell me about the history behind it?" + + inputs = processor(prompt, images=image, return_tensors="pt").to(model.device, torch.float16) + + # greedy generation outputs + EXPECTED_TEXT_COMPLETION = ['Describe what do you see here and tell me about the history behind it?The image depicts a star map, with a bright blue line extending across the center of the image. The line is labeled "390 light years" and is accompanied by a small black and'] # fmt: skip + generated_ids = model.generate(**inputs, max_new_tokens=40, do_sample=False) + text = processor.batch_decode(generated_ids, skip_special_tokens=True) + self.assertEqual(EXPECTED_TEXT_COMPLETION, text) diff --git a/tests/models/idefics/test_modeling_idefics.py b/tests/models/idefics/test_modeling_idefics.py index 12004cc3c8ad..0b964bfc67e8 100644 --- a/tests/models/idefics/test_modeling_idefics.py +++ b/tests/models/idefics/test_modeling_idefics.py @@ -789,6 +789,10 @@ def test_custom_4d_attention_mask(self): def test_generate_compile_fullgraph(self): pass + @unittest.skip(reason="IDEFICS cannot compile due to dynamic control flow when checking inputs") + def test_generate_with_static_cache(self): + pass + @unittest.skip(reason="We only test the model that takes in multiple images") def test_model(self): pass diff --git a/tests/test_image_transforms.py b/tests/test_image_transforms.py index 25775d787e49..5dc5eda7318f 100644 --- 
a/tests/test_image_transforms.py +++ b/tests/test_image_transforms.py @@ -48,6 +48,7 @@ rgb_to_id, to_channel_dimension_format, to_pil_image, + unnormalize, ) @@ -346,6 +347,75 @@ def test_normalize(self): self.assertEqual(normalized_image.dtype, np.float32) self.assertTrue(np.allclose(normalized_image, expected_image, atol=1e-6)) + def test_unnormalize(self): + image = np.random.randint(0, 256, (224, 224, 3)) / 255 + + # Test that exception is raised if inputs are incorrect + # Not a numpy array image + with self.assertRaises(ValueError): + unnormalize(5, 5, 5) + + # Number of mean values != number of channels + with self.assertRaises(ValueError): + unnormalize(image, mean=(0.5, 0.6), std=1) + + # Number of std values != number of channels + with self.assertRaises(ValueError): + unnormalize(image, mean=1, std=(0.5, 0.6)) + + # Test result is correct - output data format is channels_first and normalization + # correctly computed + mean = (0.5, 0.6, 0.7) + std = (0.1, 0.2, 0.3) + expected_image = (image * std + mean).transpose((2, 0, 1)) + + unnormalized_image = unnormalize(image, mean=mean, std=std, data_format="channels_first") + self.assertIsInstance(unnormalized_image, np.ndarray) + self.assertEqual(unnormalized_image.shape, (3, 224, 224)) + self.assertTrue(np.allclose(unnormalized_image, expected_image, atol=1e-6)) + + # Test image with 4 channels is unnormalized_image correctly + image = np.random.randint(0, 256, (224, 224, 4)) / 255 + mean = (0.5, 0.6, 0.7, 0.8) + std = (0.1, 0.2, 0.3, 0.4) + expected_image = (image * std) + mean + self.assertTrue( + np.allclose( + unnormalize(image, mean=mean, std=std, input_data_format="channels_last"), expected_image, atol=1e-6 + ) + ) + + # Test float32 image input keeps float32 dtype + image = np.random.randint(0, 256, (224, 224, 3)).astype(np.float32) / 255 + mean = (0.5, 0.6, 0.7) + std = (0.1, 0.2, 0.3) + expected_image = ((image * std) + mean).astype(np.float32) + unnormalized_image = unnormalize(image, mean=mean, 
std=std) + self.assertEqual(unnormalized_image.dtype, np.float32) + self.assertTrue(np.allclose(unnormalized_image, expected_image, atol=1e-6)) + + # Test float16 image input keeps float16 dtype + image = np.random.randint(0, 256, (224, 224, 3)).astype(np.float16) / 255 + mean = (0.5, 0.6, 0.7) + std = (0.1, 0.2, 0.3) + + # The mean and std are cast to match the dtype of the input image + cast_mean = np.array(mean, dtype=np.float16) + cast_std = np.array(std, dtype=np.float16) + expected_image = (image * cast_std) + cast_mean + unnormalized_image = unnormalize(image, mean=mean, std=std) + self.assertEqual(unnormalized_image.dtype, np.float16) + self.assertTrue(np.allclose(unnormalized_image, expected_image, atol=1e-3)) + + # Test int image input is converted to float32 + image = np.random.randint(0, 2, (224, 224, 3), dtype=np.uint8) + mean = (0.5, 0.6, 0.7) + std = (0.1, 0.2, 0.3) + expected_image = (image.astype(np.float32) * std) + mean + unnormalized_image = unnormalize(image, mean=mean, std=std) + self.assertEqual(unnormalized_image.dtype, np.float32) + self.assertTrue(np.allclose(unnormalized_image, expected_image, atol=1e-6)) + def test_center_crop(self): image = np.random.randint(0, 256, (3, 224, 224)) diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 99d0a8058c67..8108b8c02eeb 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -2954,6 +2954,10 @@ def test_inputs_embeds(self): model.to(torch_device) model.eval() + model_forward_args = inspect.signature(model.forward).parameters + if "inputs_embeds" not in model_forward_args: + self.skipTest(reason="This model doesn't use `inputs_embeds`") + inputs = copy.deepcopy(self._prepare_for_class(inputs_dict, model_class)) if not self.is_encoder_decoder: