141 changes: 141 additions & 0 deletions docs/source/en/model_doc/chameleon.md
@@ -50,13 +50,18 @@ The original code can be found [here](https://github.com/facebookresearch/chamel

- We advise users to use `padding_side="left"` when running batched generation as it leads to more accurate results. Simply make sure to set `processor.tokenizer.padding_side = "left"` before generating.

- When generating images, we advise users to load the model in `bfloat16` for better results. Simply make sure to set `torch_dtype=torch.bfloat16` when loading the model.

- Note that Chameleon was tuned for safety alignment. If the model is refusing to answer, consider asking a more concrete question instead of an open-ended one.

- Chameleon generates in chat format, which means that the generated text will always be the "assistant's turn". You can enable text completion generation by passing `return_for_text_completion=True` when calling the processor.

> [!NOTE]
> The Chameleon implementation in Transformers uses a special image token to indicate where to merge image embeddings. Instead of adding a new token for this, it reuses one of the reserved tokens: `<reserved08707>`. You have to add `<image>` to your prompt in the place where the image should be embedded for correct generation.

> [!NOTE]
> The official model checkpoint currently only supports text generation. To generate images and interleaved text-image responses, you can use finetuned versions such as [Anole](https://arxiv.org/abs/2407.06135). Note however that Anole has a bias for "empty" or background patches, so it is recommended to use sampling when generating images (i.e. setting `do_sample=True` during generation) to reduce the likelihood of generating a blank image.

## Usage example

### Single image inference
@@ -124,6 +129,142 @@
generate_ids = model.generate(**inputs, max_new_tokens=50)
processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
```

### Text-to-image generation

Chameleon can also generate images. However, the official model checkpoint currently only supports text generation. We need to use finetuned versions such as [Anole](https://arxiv.org/abs/2407.06135) to do image generation. Here is how you can do it:

```python
import torch
from transformers import ChameleonProcessor, ChameleonForConditionalGeneration
from transformers.image_transforms import to_pil_image

processor = ChameleonProcessor.from_pretrained("leloy/Anole-7b-v0.1-hf")
model = ChameleonForConditionalGeneration.from_pretrained(
    "leloy/Anole-7b-v0.1-hf",
    device_map="auto",
    torch_dtype=torch.bfloat16,
)

# Prepare a prompt
prompt = "Generate an image of a snowman."

# Preprocess the prompt
inputs = processor(prompt, padding=True, return_tensors="pt").to(model.device, dtype=model.dtype)

# Generate discrete image tokens
generate_ids = model.generate(
    **inputs,
    multimodal_generation_mode="image-only",
    # Note: We need to set `max_new_tokens` to 1026 since the model generates the `image_start_token` marker token first, then 1024 image tokens, and finally the `image_end_token` marker token.
    max_new_tokens=1026,
    # This is important because most of the image tokens during training were for "empty" patches, so greedy decoding of image tokens will likely result in a blank image.
    do_sample=True,
)

# Only keep the tokens from the response
response_ids = generate_ids[:, inputs["input_ids"].shape[-1]:]

# Decode the generated image tokens (dropping the start/end marker tokens)
pixel_values = model.decode_image_tokens(response_ids[:, 1:-1])
pixel_values = processor.postprocess_pixel_values(pixel_values)

# Save the image
image = to_pil_image(pixel_values[0].detach().cpu())
image.save("snowman.png")
```

**Review discussion** (inline comments on `device_map="auto"`):

**Contributor:** This example failed in my environment with 4 GPUs, complaining about a device mismatch.

**Contributor (author):** @minostauros can you provide the script you used for this? The complete error message would also help. Thank you!

**Contributor (@minostauros, Aug 6, 2024):** I needed to remove the `device_map="auto"` and manually send the model to a specific CUDA device to properly run the code.

```python
>>> import accelerate
>>> accelerate.__version__
'0.30.1'
>>> import torch
>>> from transformers import ChameleonProcessor, ChameleonForConditionalGeneration
>>> from PIL import Image
>>>
>>> processor = ChameleonProcessor.from_pretrained("leloy/Anole-7b-v0.1-hf")
Some kwargs in processor config are unused and will not have any effect: image_token, image_seq_length.
>>> model = ChameleonForConditionalGeneration.from_pretrained(
...     "leloy/Anole-7b-v0.1-hf",
...     device_map="auto",
...     torch_dtype=torch.bfloat16,
... )
Loading checkpoint shards: 100%|██████████| 3/3 [00:02<00:00,  1.11it/s]
>>> model.device
device(type='cuda', index=0)
>>> url = "https://huggingface.co/microsoft/kosmos-2-patch14-224/resolve/main/snowman.jpg"
>>> image_snowman = Image.open(requests.get(url, stream=True).raw)
>>> prompt = "Generate a variation of this image.<image>"
>>> inputs = processor(
...     prompt,
...     images=[image_snowman],
...     padding=True,
...     return_tensors="pt",
... ).to(model.device, dtype=model.dtype)
>>> generate_ids = model.generate(
...     **inputs,
...     multimodal_generation_mode="image-only",
...     max_new_tokens=1026,
...     do_sample=True,
... )
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/workspace/Github/transformers_anole/src/transformers/models/chameleon/modeling_chameleon.py", line 1821, in generate
    return super().generate(
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/workspace/Github/transformers_anole/src/transformers/generation/utils.py", line 1989, in generate
    result = self._sample(
  File "/workspace/Github/transformers_anole/src/transformers/generation/utils.py", line 2932, in _sample
    outputs = self(**model_inputs, return_dict=True)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/accelerate/hooks.py", line 166, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/workspace/Github/transformers_anole/src/transformers/models/chameleon/modeling_chameleon.py", line 1881, in forward
    outputs = self.model(
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/workspace/Github/transformers_anole/src/transformers/models/chameleon/modeling_chameleon.py", line 1491, in forward
    image_tokens = self.get_image_tokens(pixel_values)
  File "/workspace/Github/transformers_anole/src/transformers/models/chameleon/modeling_chameleon.py", line 1427, in get_image_tokens
    return self.img2bpe_mapping_tensor[image_toks]
RuntimeError: indices should be either on cpu or on the same device as the indexed tensor (cpu)
>>> inputs.input_ids.device
device(type='cuda', index=0)
>>> inputs.keys()
dict_keys(['input_ids', 'attention_mask', 'pixel_values'])
>>> inputs.pixel_values.device
device(type='cuda', index=0)
>>> model = model.cuda()
You shouldn't move a model that is dispatched using accelerate hooks.
>>> generate_ids = model.generate(
...     **inputs,
...     multimodal_generation_mode="image-only",
...     max_new_tokens=1026,
...     do_sample=True,
... )
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/workspace/Github/transformers_anole/src/transformers/models/chameleon/modeling_chameleon.py", line 1821, in generate
    return super().generate(
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/workspace/Github/transformers_anole/src/transformers/generation/utils.py", line 1989, in generate
    result = self._sample(
  File "/workspace/Github/transformers_anole/src/transformers/generation/utils.py", line 2932, in _sample
    outputs = self(**model_inputs, return_dict=True)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/accelerate/hooks.py", line 166, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/workspace/Github/transformers_anole/src/transformers/models/chameleon/modeling_chameleon.py", line 1881, in forward
    outputs = self.model(
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/workspace/Github/transformers_anole/src/transformers/models/chameleon/modeling_chameleon.py", line 1491, in forward
    image_tokens = self.get_image_tokens(pixel_values)
  File "/workspace/Github/transformers_anole/src/transformers/models/chameleon/modeling_chameleon.py", line 1426, in get_image_tokens
    _, _, image_toks = self.vqmodel.encode(pixel_values)
  File "/workspace/Github/transformers_anole/src/transformers/models/chameleon/modeling_chameleon.py", line 1159, in encode
    hidden_states = self.encoder(pixel_values)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/accelerate/hooks.py", line 166, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/workspace/Github/transformers_anole/src/transformers/models/chameleon/modeling_chameleon.py", line 979, in forward
    hidden_states = [self.conv_in(pixel_values)]
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/accelerate/hooks.py", line 166, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/conv.py", line 460, in forward
    return self._conv_forward(input, self.weight, self.bias)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/conv.py", line 456, in _conv_forward
    return F.conv2d(input, weight, bias, self.stride,
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:3 and cuda:0! (when checking argument for argument weight in method wrapper_CUDA__cudnn_convolution)
>>> model = ChameleonForConditionalGeneration.from_pretrained(
...     "leloy/Anole-7b-v0.1-hf",
...     torch_dtype=torch.bfloat16,
... ).to(device=0)
Loading checkpoint shards: 100%|██████████| 3/3 [00:01<00:00,  2.02it/s]
>>> generate_ids = model.generate(
...     **inputs,
...     multimodal_generation_mode="image-only",
...     max_new_tokens=1026,
...     do_sample=True,
... )
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
>>> generate_ids.shape
torch.Size([1, 2062])
```

**Contributor:** Updating accelerate to 0.33.0 did not help.

**Contributor (author):** @minostauros does this happen with the base Chameleon model, i.e. without this PR? The issue with `F.conv2d` may be unrelated to this PR, but the issue with `return self.img2bpe_mapping_tensor[image_toks]` definitely is.

**Contributor (author):** Hm, I've never seen this happen before, but I suspect it's because of the `.float()` (IIRC, `to_pil_image` rescales the numpy array if it's of float type). What happens if you remove it or cast the array to `uint8`? BTW, I wouldn't be able to run tests myself for the next few hours as I'm still traveling.

**Contributor:**

> What happens if you remove it or cast the array to uint8?

Great point!

```python
# Decode the generated image tokens
pixel_values = model.decode_image_tokens(response_ids[:, 1:-1])
images = processor.postprocess_pixel_values(pixel_values)

# Save the image
from torchvision.transforms.functional import to_pil_image
images = [to_pil_image(img.detach().cpu()) for img in images]
images[0].save("snowman.png")
```

*(attached: generated snowman image)*

Perhaps just removing the 255 scaling and type casting in `ChameleonImageProcessor.postprocess()` may also support `torchvision.utils.save_image()`.

**Contributor (author):** The output after postprocessing should have the same shape, range, and dtype as the original image, so it's better to keep it this way IMO.

**Contributor (author):** I've also just added a test for model sharding, btw. Please check it out!

**Contributor:** The code now works like a charm! Thanks a lot for your contribution. Besides, the output does not seem as good as the Anole paper states.

Prompt: 'A piece of paper with word like "Anole" written on it, and a drawing of an Anole.'

- from paper: *(image)*
- from "leloy/Anole-7b-v0.1-hf": *(image)*

How may I improve the results?

### Text-image to image generation

We can also interleave text and images in the prompt to generate images. Here is how you can do it:

```python
import requests

import torch
from PIL import Image
from transformers import ChameleonProcessor, ChameleonForConditionalGeneration
from transformers.image_transforms import to_pil_image

processor = ChameleonProcessor.from_pretrained("leloy/Anole-7b-v0.1-hf")
model = ChameleonForConditionalGeneration.from_pretrained(
    "leloy/Anole-7b-v0.1-hf",
    device_map="auto",
    torch_dtype=torch.bfloat16,
)

# Get image of a snowman
url = "https://huggingface.co/microsoft/kosmos-2-patch14-224/resolve/main/snowman.jpg"
image_snowman = Image.open(requests.get(url, stream=True).raw)

# Prepare a prompt
prompt = "Generate a variation of this image.<image>"

# Preprocess the prompt
inputs = processor(
    images=[image_snowman],
    text=prompt,
    padding=True,
    return_tensors="pt",
).to(model.device, dtype=model.dtype)

# Generate discrete image tokens
generate_ids = model.generate(
    **inputs,
    multimodal_generation_mode="image-only",
    # Note: We need to set `max_new_tokens` to 1026 since the model generates the `image_start_token` marker token first, then 1024 image tokens, and finally the `image_end_token` marker token.
    max_new_tokens=1026,
    # This is important because most of the image tokens during training were for "empty" patches, so greedy decoding of image tokens will likely result in a blank image.
    do_sample=True,
)

# Only keep the tokens from the response
response_ids = generate_ids[:, inputs["input_ids"].shape[-1]:]

# The generated image tokens are wrapped by the `image_start_token` and `image_end_token` tokens. We need to remove them before decoding the image tokens.
image_token_ids = response_ids[:, 1:-1]

# Decode the generated image tokens
pixel_values = model.decode_image_tokens(image_token_ids)
pixel_values = processor.postprocess_pixel_values(pixel_values)

# Save the image
image = to_pil_image(pixel_values[0].detach().cpu())
image.save("snowman.png")
```

### Interleaved text-image generation

We can also generate interleaved text and images in the output. Here is how you can do it:

```python
import torch
from transformers import ChameleonProcessor, ChameleonForConditionalGeneration

processor = ChameleonProcessor.from_pretrained("leloy/Anole-7b-v0.1-hf")
model = ChameleonForConditionalGeneration.from_pretrained(
    "leloy/Anole-7b-v0.1-hf",
    device_map="auto",
    torch_dtype=torch.bfloat16,
)

# Prepare a prompt
prompt = "Can you draw a snowman and explain how to build one?"

# Preprocess the prompt
inputs = processor(prompt, padding=True, return_tensors="pt").to(model.device, dtype=model.dtype)

# Generate interleaved text and discrete image tokens
generate_ids = model.generate(
    **inputs,
    multimodal_generation_mode="interleaved-text-image",
    # Note: We will need a larger `max_new_tokens` value since we are generating both text and image tokens.
    max_new_tokens=4096,
    # This is important because most of the image tokens during training were for "empty" patches, so greedy decoding of image tokens will likely result in a blank image.
    do_sample=True,
)

# Only keep the tokens from the response
response_ids = generate_ids[:, inputs["input_ids"].shape[-1]:]
```

From here, you can split the response tokens into text and image token segments, decode them separately as shown in the previous examples, and finally render the resulting text and images together. You can also use [MMSG](https://github.com/leloykun/mmsg) to do this more easily.
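As a rough illustration of that splitting step, here is a hedged sketch. The helper and the marker-token ids are illustrative, not an API from this PR; in practice the begin/end-of-image ids would come from the model's vocabulary mapping:

```python
from typing import List, Tuple


def split_interleaved_ids(
    ids: List[int], boi_token_id: int, eoi_token_id: int
) -> List[Tuple[str, List[int]]]:
    """Partition a flat id sequence into ("text", ids) and ("image", ids) segments.

    Image segments are the ids strictly between each begin-of-image and
    end-of-image marker token; everything else is treated as text.
    """
    segments: List[Tuple[str, List[int]]] = []
    current: List[int] = []
    in_image = False
    for tok in ids:
        if tok == boi_token_id:
            if current:
                segments.append(("text", current))
            current, in_image = [], True
        elif tok == eoi_token_id:
            segments.append(("image", current))
            current, in_image = [], False
        else:
            current.append(tok)
    if current:
        segments.append(("image" if in_image else "text", current))
    return segments


# Toy example with made-up marker ids 8197 (begin) and 8196 (end):
segments = split_interleaved_ids([5, 6, 8197, 100, 101, 8196, 7], 8197, 8196)
print(segments)  # [('text', [5, 6]), ('image', [100, 101]), ('text', [7])]
```

Each `("image", ids)` chunk can then go through `model.decode_image_tokens` and each `("text", ids)` chunk through `processor.batch_decode`, as in the examples above.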
**Contributor:** This should be demonstrated - it's not obvious how this should be done.

**Collaborator:** ➕ on this! Let's show a snippet that uses mmsg maybe? 🤗


## Model optimization

### Quantization using Bitsandbytes
135 changes: 135 additions & 0 deletions src/transformers/generation/logits_process.py
@@ -1778,6 +1778,61 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> to
        return scores_processed


class SuppressTokensInIndexRangeLogitsProcessor(LogitsProcessor):
**Contributor:** The custom logits processors should all be tested.

**Collaborator:** btw why don't we just use `SuppressTokensLogitsProcessor`? It supports passing a list of tokens and does not require adding a new one!

**Contributor:** As per the other logits processors, a code example demo-ing the behaviour should be added here.

    r"""
    [`SuppressTokensInIndexRangeLogitsProcessor`] suppresses a list of tokens from `start_index` to `end_index` (exclusive).

    Args:
        suppress_tokens (`List[int]`):
            List of token ids to suppress during generation.
        start_index (`int`):
            The index at which to start suppressing tokens.
        end_index (`int`, *optional*):
            The index at which to end suppressing tokens. If `None`, it will suppress tokens indefinitely.
        device (`str`, *optional*, defaults to `"cpu"`):
            The device to allocate the tensors.

    Examples:

    ```python
    >>> from transformers import AutoProcessor, ChameleonForConditionalGeneration, LogitsProcessorList
    >>> from transformers.generation.logits_process import SuppressTokensInIndexRangeLogitsProcessor
    >>> import torch

    >>> model = ChameleonForConditionalGeneration.from_pretrained("leloy/Anole-7b-v0.1-hf")
    >>> processor = AutoProcessor.from_pretrained("leloy/Anole-7b-v0.1-hf")

    >>> inputs = processor("Can you draw a snowman?", return_tensors="pt")
    >>> max_length = 1200
    >>> # Don't start generating an image if there isn't enough space for the rest of the image tokens.
    >>> logits_processor = SuppressTokensInIndexRangeLogitsProcessor(
    ...     suppress_tokens=[model.vocabulary_mapping.boi_token_id],
    ...     start_index=max_length - model.model.image_seq_length - 1,
    ...     device=model.device,
    ... )

    >>> outputs = model.generate(**inputs, max_length=max_length, logits_processor=LogitsProcessorList([logits_processor]))
    >>> print(torch.isin(outputs[:, inputs.input_ids.shape[1] + 1 :], model.vocabulary_mapping.image_token_ids).all())
    True
    ```
    """

    def __init__(
        self, suppress_tokens: List[int], start_index: int, end_index: Optional[int] = None, device: str = "cpu"
    ):
        self.suppress_tokens = torch.tensor(suppress_tokens, device=device)
        self.start_index = start_index
        self.end_index = end_index if end_index is not None else math.inf

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        current_index = input_ids.shape[1]
        if self.start_index > current_index or current_index > self.end_index:
            return scores
        suppress_tokens_mask = torch.zeros_like(scores, dtype=torch.bool)
        suppress_tokens_mask[:, self.suppress_tokens] = True
        return scores.masked_fill(suppress_tokens_mask, -float("inf"))
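A standalone sanity check of the suppression behaviour. This sketch reimplements the `__call__` logic above with plain `torch` rather than importing the class, so it runs without this PR:

```python
import torch


def suppress_in_index_range(scores, current_index, suppress_tokens, start_index, end_index):
    # Mirror of the __call__ logic above: outside [start_index, end_index]
    # the scores pass through untouched; inside, the listed token ids get -inf.
    if start_index > current_index or current_index > end_index:
        return scores
    mask = torch.zeros_like(scores, dtype=torch.bool)
    mask[:, suppress_tokens] = True
    return scores.masked_fill(mask, -float("inf"))


scores = torch.zeros(1, 5)
out = suppress_in_index_range(
    scores, current_index=3, suppress_tokens=torch.tensor([1, 4]), start_index=2, end_index=10
)
print(out)  # token ids 1 and 4 are now -inf, the rest are unchanged
```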


class SuppressTokensAtBeginLogitsProcessor(LogitsProcessor):
r"""
[`SuppressTokensAtBeginLogitsProcessor`] supresses a list of tokens as soon as the `generate` function starts
@@ -2953,3 +3008,83 @@ def expected_mean_g_value(self, vocab_size: int, coinflip_prob: float = 0.5) ->
            The expected mean g-value for watermarked text.
        """
        return coinflip_prob + coinflip_prob * (1 - coinflip_prob) * (1 - (1 / vocab_size))


class AllowOnlyTokensInRelativeWindowLogitsProcessor(LogitsProcessor):
    r"""
    [`AllowOnlyTokensInRelativeWindowLogitsProcessor`] suppresses the logits of tokens aside from a specific set of tokens
    that can be generated at a relative window from a trigger token (e.g. begin image token). If `exclusive` is set to
    `True`, the set of tokens allowed at this window will not be allowed anywhere else. This is useful for enforcing
    multimodal generation constraints.
**Collaborator:** can you explain why someone would want to do that?

    Originally created for [Chameleon](https://huggingface.co/docs/transformers/model_doc/chameleon).

    Args:
        trigger_token_id (`int`):
            The token id that triggers the window check.
        allowed_token_ids (`List[int]`):
            The list of token ids that are allowed at the specified relative window.
        window_width (`int`):
            The width of the window relative to the trigger token.
        exclusive (`bool`, *optional*, defaults to `False`):
            If `True`, the set of tokens allowed at this window will not be allowed anywhere else.
        device (`str`, *optional*, defaults to `"cpu"`):
            The device to allocate the util tensor on.

    Examples:

**Collaborator** (suggested change to the import line):
`>>> from transformers import AutoProcessor, ChameleonForConditionalGeneration, LogitsProcessorList`

    ```python
    >>> from transformers import AutoProcessor, ChameleonForConditionalGeneration, LogitsProcessorList
    >>> from transformers.generation.logits_process import AllowOnlyTokensInRelativeWindowLogitsProcessor
    >>> import torch

    >>> model = ChameleonForConditionalGeneration.from_pretrained("leloy/Anole-7b-v0.1-hf")
    >>> processor = AutoProcessor.from_pretrained("leloy/Anole-7b-v0.1-hf")

    >>> inputs = processor("Can you draw a snowman?", return_tensors="pt")
    >>> max_length = 1200
    >>> # Generate only image token ids for `image_seq_length` steps when the boi-token is already generated
    >>> logits_processor = AllowOnlyTokensInRelativeWindowLogitsProcessor(
    ...     trigger_token_id=model.vocabulary_mapping.boi_token_id,
    ...     allowed_token_ids=model.vocabulary_mapping.image_token_ids,
    ...     window_width=model.model.image_seq_length,
    ...     exclusive=True,
    ...     device=model.device,
    ... )

    >>> outputs = model.generate(**inputs, max_length=max_length, logits_processor=LogitsProcessorList([logits_processor]))
    ```
    """
"""

    def __init__(
        self,
        trigger_token_id: int,
        allowed_token_ids: List[int],
        window_width: int,
        exclusive: bool = False,
        device: str = "cpu",
    ):
        self.trigger_token_id = trigger_token_id
        self.allowed_token_ids = torch.tensor(allowed_token_ids, device=device).unsqueeze(0)
        self.window_width = window_width
        self.exclusive = exclusive

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        if input_ids.shape[1] < self.window_width and not self.exclusive:
            return scores

        window_width = min(self.window_width, input_ids.shape[1])
        trigger_positions = (input_ids[:, -window_width:] == self.trigger_token_id).any(dim=1).unsqueeze(-1)

        disallowed_tokens_mask = torch.ones_like(scores, dtype=torch.bool)
        disallowed_tokens_mask[:, self.allowed_token_ids] = False

        if self.exclusive:
            return scores.masked_fill(
                ~(disallowed_tokens_mask ^ trigger_positions),
                -float("inf"),
            )
        return scores.masked_fill(
            disallowed_tokens_mask & trigger_positions,
            -float("inf"),
        )
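The intended effect, allowing only certain token ids within `window_width` steps of the trigger token, can be checked with a standalone sketch of the non-exclusive branch (plain `torch`, not the class itself):

```python
import torch


def allow_only_in_window(scores, input_ids, trigger_token_id, allowed_token_ids, window_width):
    # Non-exclusive branch of the processor above: if the trigger token occurred
    # in the last `window_width` positions, everything but the allowed ids is -inf.
    if input_ids.shape[1] < window_width:
        return scores
    w = min(window_width, input_ids.shape[1])
    triggered = (input_ids[:, -w:] == trigger_token_id).any(dim=1).unsqueeze(-1)
    disallowed = torch.ones_like(scores, dtype=torch.bool)
    disallowed[:, allowed_token_ids] = False
    return scores.masked_fill(disallowed & triggered, -float("inf"))


input_ids = torch.tensor([[7, 9, 3]])  # 9 is a made-up trigger (begin-image) token id
scores = torch.zeros(1, 6)
out = allow_only_in_window(
    scores, input_ids, trigger_token_id=9, allowed_token_ids=torch.tensor([4, 5]), window_width=3
)
print(out)  # only ids 4 and 5 keep finite scores
```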
39 changes: 38 additions & 1 deletion src/transformers/image_processing_utils.py
@@ -18,7 +18,7 @@
import numpy as np

from .image_processing_base import BatchFeature, ImageProcessingMixin
-from .image_transforms import center_crop, normalize, rescale
+from .image_transforms import center_crop, normalize, rescale, unnormalize
from .image_utils import ChannelDimension
from .utils import logging

@@ -112,6 +112,43 @@ def normalize(
            image, mean=mean, std=std, data_format=data_format, input_data_format=input_data_format, **kwargs
        )

    def unnormalize(
        self,
        image: np.ndarray,
        mean: Union[float, Iterable[float]],
        std: Union[float, Iterable[float]],
        data_format: Optional[Union[str, ChannelDimension]] = None,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
        **kwargs,
    ) -> np.ndarray:
        """
        Unnormalize an image. image = (image * image_std) + image_mean.

        Args:
            image (`np.ndarray`):
                Image to unnormalize.
            mean (`float` or `Iterable[float]`):
                Image mean to use for unnormalization.
            std (`float` or `Iterable[float]`):
                Image standard deviation to use for unnormalization.
            data_format (`str` or `ChannelDimension`, *optional*):
                The channel dimension format for the output image. If unset, the channel dimension format of the input
                image is used. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
            input_data_format (`ChannelDimension` or `str`, *optional*):
                The channel dimension format for the input image. If unset, the channel dimension format is inferred
                from the input image. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.

        Returns:
            `np.ndarray`: The unnormalized image.
        """
        return unnormalize(
            image, mean=mean, std=std, data_format=data_format, input_data_format=input_data_format, **kwargs
        )

    def center_crop(
        self,
        image: np.ndarray,
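The `unnormalize` method above inverts `normalize`. A plain-NumPy sketch of the round-trip (simplified: it ignores the channel-format handling that the mixin delegates to `image_transforms.unnormalize`):

```python
import numpy as np

# Per-channel statistics, broadcast over a channels-first (C, H, W) array.
mean = np.array([0.5, 0.5, 0.5]).reshape(-1, 1, 1)
std = np.array([0.25, 0.25, 0.25]).reshape(-1, 1, 1)

image = np.random.rand(3, 4, 4).astype(np.float32)  # toy (C, H, W) image
normalized = (image - mean) / std                   # what `normalize` computes
restored = normalized * std + mean                  # what `unnormalize` undoes

print(np.allclose(restored, image))  # True
```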