Improve support for image generation with Chameleon & Anole #32013
@@ -50,13 +50,18 @@ The original code can be found [here](https://github.com/facebookresearch/chamel

- We advise users to use `padding_side="left"` when computing batched generation, as it leads to more accurate results. Simply set `processor.tokenizer.padding_side = "left"` before generating.
- When generating images, we advise users to load the model in `bfloat16` for better results. Simply set `torch_dtype=torch.bfloat16` when loading the model.
- Note that Chameleon was tuned for safety alignment. If the model is refusing to answer, consider asking a more concrete question instead of an open-ended one.
- Chameleon generates in chat format, which means the generated text will always be the "assistant's turn". You can enable text-completion generation by passing `return_for_text_completion=True` when calling the processor.

> [!NOTE]
> The Chameleon implementation in Transformers uses a special image token to indicate where to merge image embeddings. Rather than adding a new token, one of the reserved tokens is reused: `<reserved08707>`. You have to add `<image>` to your prompt in the place where the image should be embedded for correct generation.

> [!NOTE]
> The official model checkpoint currently only supports text generation. To generate images and interleaved text-image responses, you can use finetuned versions such as [Anole](https://arxiv.org/abs/2407.06135). Note, however, that Anole has a bias toward "empty" or background patches, so it is recommended to use sampling when generating images (i.e. setting `do_sample=True` during generation) to reduce the likelihood of generating a blank image.

## Usage example

### Single image inference
|
@@ -124,6 +129,142 @@ generate_ids = model.generate(**inputs, max_new_tokens=50)
processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
```

### Text to image generation

Chameleon can also generate images. However, the official model checkpoint currently only supports text generation. We need to use finetuned versions such as [Anole](https://arxiv.org/abs/2407.06135) to do image generation. Here is how you can do it:

```python
import torch
from transformers import ChameleonProcessor, ChameleonForConditionalGeneration

processor = ChameleonProcessor.from_pretrained("leloy/Anole-7b-v0.1-hf")
model = ChameleonForConditionalGeneration.from_pretrained(
    "leloy/Anole-7b-v0.1-hf",
    device_map="auto",
```

> **Contributor:** This example failed in my environment with 4 GPUs, complaining about a device mismatch.

> **Author:** @minostauros can you provide the script you used for this? The complete error message would also help. Thank you!

> **Contributor:** I needed to remove the …

```python
>>> import accelerate
>>> accelerate.__version__
'0.30.1'
>>> import torch
>>> from transformers import ChameleonProcessor, ChameleonForConditionalGeneration
>>> from PIL import Image
>>>
>>> processor = ChameleonProcessor.from_pretrained("leloy/Anole-7b-v0.1-hf")
Some kwargs in processor config are unused and will not have any effect: image_token, image_seq_length.
>>> model = ChameleonForConditionalGeneration.from_pretrained(
...     "leloy/Anole-7b-v0.1-hf",
...     device_map="auto",
...     torch_dtype=torch.bfloat16,
... )
Loading checkpoint shards: 100%|██████████| 3/3 [00:02<00:00, 1.11it/s]
>>> model.device
device(type='cuda', index=0)
>>> url = "https://huggingface.co/microsoft/kosmos-2-patch14-224/resolve/main/snowman.jpg"
>>> image_snowman = Image.open(requests.get(url, stream=True).raw)
>>> prompt = "Generate a variation of this image.<image>"
>>> inputs = processor(
...     prompt,
...     images=[image_snowman],
...     padding=True,
...     return_tensors="pt",
... ).to(model.device, dtype=model.dtype)
>>> generate_ids = model.generate(
...     **inputs,
...     multimodal_generation_mode="image-only",
...     # Note: We need to set `max_new_tokens` to 1026 since the model generates the `image_start_token` marker token first, then 1024 image tokens, and finally the `image_end_token` marker token.
...     max_new_tokens=1026,
...     # This is important because most of the image tokens during training were for "empty" patches, so greedy decoding of image tokens will likely result in a blank image.
...     do_sample=True,
... )
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/workspace/Github/transformers_anole/src/transformers/models/chameleon/modeling_chameleon.py", line 1821, in generate
    return super().generate(
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/workspace/Github/transformers_anole/src/transformers/generation/utils.py", line 1989, in generate
    result = self._sample(
  File "/workspace/Github/transformers_anole/src/transformers/generation/utils.py", line 2932, in _sample
    outputs = self(**model_inputs, return_dict=True)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/accelerate/hooks.py", line 166, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/workspace/Github/transformers_anole/src/transformers/models/chameleon/modeling_chameleon.py", line 1881, in forward
    outputs = self.model(
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/workspace/Github/transformers_anole/src/transformers/models/chameleon/modeling_chameleon.py", line 1491, in forward
    image_tokens = self.get_image_tokens(pixel_values)
  File "/workspace/Github/transformers_anole/src/transformers/models/chameleon/modeling_chameleon.py", line 1427, in get_image_tokens
    return self.img2bpe_mapping_tensor[image_toks]
RuntimeError: indices should be either on cpu or on the same device as the indexed tensor (cpu)
>>> inputs.input_ids.device
device(type='cuda', index=0)
>>> inputs.keys()
dict_keys(['input_ids', 'attention_mask', 'pixel_values'])
>>> inputs.pixel_values.device
device(type='cuda', index=0)
>>> model = model.cuda()
You shouldn't move a model that is dispatched using accelerate hooks.
>>> generate_ids = model.generate(
...     **inputs,
...     multimodal_generation_mode="image-only",
...     max_new_tokens=1026,
...     do_sample=True,
... )
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
Traceback (most recent call last):
  (same call stack as above, continuing down into the VQ encoder)
  File "/workspace/Github/transformers_anole/src/transformers/models/chameleon/modeling_chameleon.py", line 1426, in get_image_tokens
    _, _, image_toks = self.vqmodel.encode(pixel_values)
  File "/workspace/Github/transformers_anole/src/transformers/models/chameleon/modeling_chameleon.py", line 1159, in encode
    hidden_states = self.encoder(pixel_values)
  File "/workspace/Github/transformers_anole/src/transformers/models/chameleon/modeling_chameleon.py", line 979, in forward
    hidden_states = [self.conv_in(pixel_values)]
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/conv.py", line 460, in forward
    return self._conv_forward(input, self.weight, self.bias)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/conv.py", line 456, in _conv_forward
    return F.conv2d(input, weight, bias, self.stride,
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:3 and cuda:0! (when checking argument for argument weight in method wrapper_CUDA__cudnn_convolution)
>>> model = ChameleonForConditionalGeneration.from_pretrained(
...     "leloy/Anole-7b-v0.1-hf",
...     torch_dtype=torch.bfloat16,
... ).to(device=0)
Loading checkpoint shards: 100%|██████████| 3/3 [00:01<00:00, 2.02it/s]
>>> generate_ids = model.generate(
...     **inputs,
...     multimodal_generation_mode="image-only",
...     max_new_tokens=1026,
...     do_sample=True,
... )
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
>>> generate_ids.shape
torch.Size([1, 2062])
```

> **Contributor:** Updating accelerate to …

> **Author:** @minostauros does this happen with the base Chameleon model? I.e. without this PR? The issue with …

> **Author:** hm, I've never seen this happen before but I suspect it's because of the … btw, I wouldn't be able to run tests myself for the next few hours as I'm still traveling.

> **Contributor:** Great point!

```python
# Decode the generated image tokens
pixel_values = model.decode_image_tokens(response_ids[:, 1:-1])
images = processor.postprocess_pixel_values(pixel_values)

# Save the image
from torchvision.transforms.functional import to_pil_image
images = [to_pil_image(img.detach().cpu()) for img in images]
images[0].save("snowman.png")
```

> Perhaps just removing the 255 scaling and type casting in …

> **Author:** The output after postprocessing should have the same shape, range, and dtype as the original image, so it's better to keep it this way IMO.

> **Author:** I've also just added a test for model sharding btw, pls check it out!

> **Contributor:** The code now works like a charm! Thanks a lot for your contribution. Prompt: 'A piece of paper with word like "Anole" written on it, and a drawing of an Anole.' How may I improve the results?

```python
    torch_dtype=torch.bfloat16,
)

# Prepare a prompt
prompt = "Generate an image of a snowman."

# Preprocess the prompt
inputs = processor(prompt, padding=True, return_tensors="pt").to(model.device, dtype=model.dtype)

# Generate discrete image tokens
generate_ids = model.generate(
    **inputs,
    multimodal_generation_mode="image-only",
    # Note: We need to set `max_new_tokens` to 1026 since the model generates the `image_start_token` marker token first, then 1024 image tokens, and finally the `image_end_token` marker token.
    max_new_tokens=1026,
    # This is important because most of the image tokens during training were for "empty" patches, so greedy decoding of image tokens will likely result in a blank image.
    do_sample=True,
)

# Only keep the tokens from the response
response_ids = generate_ids[:, inputs["input_ids"].shape[-1]:]

# Decode the generated image tokens
pixel_values = model.decode_image_tokens(response_ids[:, 1:-1])
images = processor.postprocess_pixel_values(pixel_values)

# Save the image (convert the pixel-value tensor to a PIL image first)
from transformers.image_transforms import to_pil_image
image = to_pil_image(images[0].detach().cpu())
image.save("snowman.png")
```

### Text-image to image generation

We can also interleave text and images in the prompt to generate images. Here is how you can do it:

```python
import requests

import torch
from PIL import Image
from transformers import ChameleonProcessor, ChameleonForConditionalGeneration
from transformers.image_transforms import to_pil_image

processor = ChameleonProcessor.from_pretrained("leloy/Anole-7b-v0.1-hf")
model = ChameleonForConditionalGeneration.from_pretrained(
    "leloy/Anole-7b-v0.1-hf",
    device_map="auto",
    torch_dtype=torch.bfloat16,
)

# Get an image of a snowman
url = "https://huggingface.co/microsoft/kosmos-2-patch14-224/resolve/main/snowman.jpg"
image_snowman = Image.open(requests.get(url, stream=True).raw)

# Prepare a prompt
prompt = "Generate a variation of this image.<image>"

# Preprocess the prompt
inputs = processor(
    images=[image_snowman],
    text=prompt,
    padding=True,
    return_tensors="pt",
).to(model.device, dtype=model.dtype)

# Generate discrete image tokens
generate_ids = model.generate(
    **inputs,
    multimodal_generation_mode="image-only",
    # Note: We need to set `max_new_tokens` to 1026 since the model generates the `image_start_token` marker token first, then 1024 image tokens, and finally the `image_end_token` marker token.
    max_new_tokens=1026,
    # This is important because most of the image tokens during training were for "empty" patches, so greedy decoding of image tokens will likely result in a blank image.
    do_sample=True,
)

# Only keep the tokens from the response
response_ids = generate_ids[:, inputs["input_ids"].shape[-1]:]

# The generated image tokens are wrapped by the `image_start_token` and `image_end_token` tokens. We need to remove them before decoding the image tokens.
image_token_ids = response_ids[:, 1:-1]

# Decode the generated image tokens
pixel_values = model.decode_image_tokens(image_token_ids)
pixel_values = processor.postprocess_pixel_values(pixel_values)

# Save the image
image = to_pil_image(pixel_values[0].detach().cpu())
image.save("snowman.png")
```

### Interleaved text-image generation

We can also generate interleaved text and images in the output. Here is how you can do it:

```python
import torch
from transformers import ChameleonProcessor, ChameleonForConditionalGeneration

processor = ChameleonProcessor.from_pretrained("leloy/Anole-7b-v0.1-hf")
model = ChameleonForConditionalGeneration.from_pretrained(
    "leloy/Anole-7b-v0.1-hf",
    device_map="auto",
    torch_dtype=torch.bfloat16,
)

# Prepare a prompt
prompt = "Can you draw a snowman and explain how to build one?"

# Preprocess the prompt
inputs = processor(prompt, padding=True, return_tensors="pt").to(model.device, dtype=model.dtype)

# Generate interleaved text and discrete image tokens
generate_ids = model.generate(
    **inputs,
    multimodal_generation_mode="interleaved-text-image",
    # Note: We need a larger `max_new_tokens` value since we are generating both text and image tokens.
    max_new_tokens=4096,
    # This is important because most of the image tokens during training were for "empty" patches, so greedy decoding of image tokens will likely result in a blank image.
    do_sample=True,
)

# Only keep the tokens from the response
response_ids = generate_ids[:, inputs["input_ids"].shape[-1]:]
```

From here, you can split the response tokens into text and image token segments, decode them separately as shown in the previous examples, and finally render the resulting text and images together. You can also use [MMSG](https://github.com/leloykun/mmsg) to do this more easily.

> **Contributor:** This should be demonstrated; it's not obvious how this should be done.

> **Collaborator:** ➕ on this! Let's show a snippet that uses mmsg maybe? 🤗
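As a hedged sketch of the splitting step (this is not the MMSG API; the helper below and the toy marker-token ids are illustrative assumptions), one way to segment the response ids is to scan for the begin-image and end-image marker tokens:

```python
from typing import List, Tuple

def split_text_image_segments(
    token_ids: List[int], boi_token_id: int, eoi_token_id: int
) -> List[Tuple[str, List[int]]]:
    """Split a flat token sequence into ("text", ...) and ("image", ...) segments.

    Image segments are the tokens strictly between a begin-image and an
    end-image marker token; everything else is treated as text.
    """
    segments: List[Tuple[str, List[int]]] = []
    current: List[int] = []
    in_image = False
    for tok in token_ids:
        if tok == boi_token_id:
            if current:
                segments.append(("text", current))
            current, in_image = [], True
        elif tok == eoi_token_id:
            segments.append(("image", current))
            current, in_image = [], False
        else:
            current.append(tok)
    if current:
        segments.append(("image" if in_image else "text", current))
    return segments

# Toy example with made-up ids: 90 = begin-image marker, 91 = end-image marker
segments = split_text_image_segments([1, 2, 90, 7, 8, 9, 91, 3], boi_token_id=90, eoi_token_id=91)
print(segments)  # [('text', [1, 2]), ('image', [7, 8, 9]), ('text', [3])]
```

In the real pipeline you would pass `response_ids[0].tolist()` with the model's actual marker ids (e.g. `model.vocabulary_mapping.boi_token_id`), then decode `"text"` segments with the processor and `"image"` segments with `model.decode_image_tokens` as shown in the previous examples.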

## Model optimization

### Quantization using Bitsandbytes

@@ -1778,6 +1778,61 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> to
        return scores_processed

class SuppressTokensInIndexRangeLogitsProcessor(LogitsProcessor):

> **Contributor:** The custom logits processors should all be tested.

> **Collaborator:** btw why don't we just use …

    r"""
    [`SuppressTokensInIndexRangeLogitsProcessor`] suppresses a list of tokens from `start_index` to `end_index` (exclusive).

> **Contributor:** As per the other logits processors, a code example demo-ing the behaviour should be added here.

    Args:
        suppress_tokens (`List[int]`):
            List of token ids to suppress during generation.
        start_index (`int`):
            The index at which to start suppressing tokens.
        end_index (`int`, *optional*):
            The index at which to end suppressing tokens. If `None`, tokens are suppressed indefinitely.
        device (`str`, *optional*, defaults to `"cpu"`):
            The device on which to allocate the tensors.

    Examples:

    ```python
    >>> from transformers import AutoProcessor, ChameleonForConditionalGeneration, LogitsProcessorList
    >>> from transformers.generation.logits_process import SuppressTokensInIndexRangeLogitsProcessor
    >>> import torch

    >>> model = ChameleonForConditionalGeneration.from_pretrained("leloy/Anole-7b-v0.1-hf")
    >>> processor = AutoProcessor.from_pretrained("leloy/Anole-7b-v0.1-hf")

    >>> inputs = processor("Can you draw a snowman?", return_tensors="pt")
    >>> max_length = 1200
    >>> # Don't start generating an image if there isn't enough space for the rest of the image tokens.
    >>> logits_processor = SuppressTokensInIndexRangeLogitsProcessor(
    ...     suppress_tokens=[model.vocabulary_mapping.boi_token_id],
    ...     start_index=max_length - model.model.image_seq_length - 1,
    ...     device=model.device,
    ... )

    >>> outputs = model.generate(**inputs, max_length=max_length, logits_processor=LogitsProcessorList([logits_processor]))
    >>> print(torch.isin(outputs[:, inputs.input_ids.shape[1] + 1 :], model.vocabulary_mapping.image_token_ids).all())
    True
    ```
    """
    def __init__(
        self, suppress_tokens: List[int], start_index: int, end_index: Optional[int] = None, device: str = "cpu"
    ):
        self.suppress_tokens = torch.tensor(suppress_tokens, device=device)
        self.start_index = start_index
        self.end_index = end_index if end_index is not None else math.inf

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        current_index = input_ids.shape[1]
        if self.start_index > current_index or current_index > self.end_index:
            return scores
        suppress_tokens_mask = torch.zeros_like(scores, dtype=torch.bool)
        suppress_tokens_mask[:, self.suppress_tokens] = True
        return scores.masked_fill(suppress_tokens_mask, -float("inf"))
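The index gate above is easy to check in isolation. Here is a minimal pure-Python sketch of the same start/end-index rule (the real processor operates on batched torch tensors; the helper name, scores, and ids below are illustrative assumptions):

```python
import math

def suppress_in_index_range(scores, current_index, suppress_tokens, start_index, end_index=None):
    """Return a copy of `scores` with `suppress_tokens` set to -inf
    whenever start_index <= current_index <= end_index (end defaults to infinity)."""
    end = math.inf if end_index is None else end_index
    if start_index > current_index or current_index > end:
        return list(scores)  # outside the window: scores pass through untouched
    return [-math.inf if tok in suppress_tokens else s for tok, s in enumerate(scores)]

# Suppress token 2 from step 3 onward (toy vocab of 4 tokens)
print(suppress_in_index_range([0.1, 0.2, 0.3, 0.4], current_index=2, suppress_tokens={2}, start_index=3))
# [0.1, 0.2, 0.3, 0.4]  (before the window, nothing is suppressed)
print(suppress_in_index_range([0.1, 0.2, 0.3, 0.4], current_index=3, suppress_tokens={2}, start_index=3))
# [0.1, 0.2, -inf, 0.4]  (inside the window, token 2 is masked)
```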
class SuppressTokensAtBeginLogitsProcessor(LogitsProcessor):
    r"""
    [`SuppressTokensAtBeginLogitsProcessor`] suppresses a list of tokens as soon as the `generate` function starts

@@ -2953,3 +3008,83 @@ def expected_mean_g_value(self, vocab_size: int, coinflip_prob: float = 0.5) ->
            The expected mean g-value for watermarked text.
        """
        return coinflip_prob + coinflip_prob * (1 - coinflip_prob) * (1 - (1 / vocab_size))

class AllowOnlyTokensInRelativeWindowLogitsProcessor(LogitsProcessor):
    r"""
    [`AllowOnlyTokensInRelativeWindowLogitsProcessor`] suppresses the logits of all tokens aside from a specific set of
    tokens that can be generated in a window of fixed width following a trigger token (e.g. the begin-image token). If
    `exclusive` is set to `True`, the set of tokens allowed in this window will not be allowed anywhere else. This is
    useful for enforcing multimodal generation constraints, e.g. ensuring that exactly `image_seq_length` image tokens
    are generated right after a begin-image token.

    Originally created for [Chameleon](https://huggingface.co/docs/transformers/model_doc/chameleon).

    Args:
        trigger_token_id (`int`):
            The token id that triggers the window check.
        allowed_token_ids (`List[int]`):
            The list of token ids that are allowed in the specified relative window.
        window_width (`int`):
            The width of the window following the trigger token.
        exclusive (`bool`, *optional*, defaults to `False`):
            If `True`, the set of tokens allowed in this window will not be allowed anywhere else.
        device (`str`, *optional*, defaults to `"cpu"`):
            The device on which to allocate the util tensor.

    Examples:

    ```python
    >>> from transformers import AutoProcessor, ChameleonForConditionalGeneration, LogitsProcessorList
    >>> from transformers.generation.logits_process import AllowOnlyTokensInRelativeWindowLogitsProcessor
    >>> import torch

    >>> model = ChameleonForConditionalGeneration.from_pretrained("leloy/Anole-7b-v0.1-hf")
    >>> processor = AutoProcessor.from_pretrained("leloy/Anole-7b-v0.1-hf")

    >>> inputs = processor("Can you draw a snowman?", return_tensors="pt")
    >>> max_length = 1200
    >>> # Generate only image token ids for `image_seq_length` steps once the boi token has been generated
    >>> logits_processor = AllowOnlyTokensInRelativeWindowLogitsProcessor(
    ...     trigger_token_id=model.vocabulary_mapping.boi_token_id,
    ...     allowed_token_ids=model.vocabulary_mapping.image_token_ids,
    ...     window_width=model.model.image_seq_length,
    ...     exclusive=True,
    ...     device=model.device,
    ... )

    >>> outputs = model.generate(**inputs, max_length=max_length, logits_processor=LogitsProcessorList([logits_processor]))
    ```
    """
    def __init__(
        self,
        trigger_token_id: int,
        allowed_token_ids: List[int],
        window_width: int,
        exclusive: bool = False,
        device: str = "cpu",
    ):
        self.trigger_token_id = trigger_token_id
        self.allowed_token_ids = torch.tensor(allowed_token_ids, device=device).unsqueeze(0)
        self.window_width = window_width
        self.exclusive = exclusive

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        if input_ids.shape[1] < self.window_width and not self.exclusive:
            return scores

        window_width = min(self.window_width, input_ids.shape[1])
        trigger_positions = (input_ids[:, -window_width:] == self.trigger_token_id).any(dim=1).unsqueeze(-1)

        disallowed_tokens_mask = torch.ones_like(scores, dtype=torch.bool)
        disallowed_tokens_mask[:, self.allowed_token_ids] = False

        if self.exclusive:
            return scores.masked_fill(
                ~(disallowed_tokens_mask ^ trigger_positions),
                -float("inf"),
            )
        return scores.masked_fill(
            disallowed_tokens_mask & trigger_positions,
            -float("inf"),
        )
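To make the masking rule concrete, here is a hedged pure-Python sketch of the same logic for a single sequence (the real processor works on batched torch tensors; the helper name, toy ids, and scores below are illustrative assumptions):

```python
import math

def relative_window_scores(prev_tokens, scores, trigger_id, allowed_ids, window_width, exclusive=False):
    """Mask `scores` so only `allowed_ids` survive within `window_width` steps
    of `trigger_id`; with `exclusive=True`, `allowed_ids` are masked everywhere else."""
    if len(prev_tokens) < window_width and not exclusive:
        return list(scores)
    width = min(window_width, len(prev_tokens))
    triggered = trigger_id in prev_tokens[-width:]
    out = []
    for tok, s in enumerate(scores):
        allowed = tok in allowed_ids
        if triggered and not allowed:
            out.append(-math.inf)  # inside the window: only allowed ids survive
        elif exclusive and not triggered and allowed:
            out.append(-math.inf)  # exclusive mode: allowed ids are forbidden elsewhere
        else:
            out.append(s)
    return out

# Toy vocab of 5 tokens; token 4 is the trigger, tokens 2 and 3 are "image" tokens
print(relative_window_scores([0, 4], [1.0] * 5, trigger_id=4, allowed_ids={2, 3}, window_width=2, exclusive=True))
# [-inf, -inf, 1.0, 1.0, -inf]  (trigger in window: only tokens 2 and 3 remain)
```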


