Merged (changes from 7 commits)
17 changes: 13 additions & 4 deletions src/transformers/models/janus/modeling_janus.py
@@ -58,7 +58,7 @@ class JanusPreTrainedModel(PreTrainedModel):
     config_class = JanusConfig
     base_model_prefix = "model"
     supports_gradient_checkpointing = True
-    _no_split_modules = ["LlamaDecoderLayer"]
+    _no_split_modules = ["LlamaDecoderLayer", "JanusVisionEncoderLayer"]
     _skip_keys_device_placement = ["past_key_values", "causal_mask"]
     _supports_flash_attn_2 = True
     _supports_sdpa = True
@@ -1054,6 +1054,10 @@ class JanusModel(JanusPreTrainedModel):
     def __init__(self, config: JanusConfig):
         super().__init__(config)
         self.config = config
+
+        # Language model is initialized first to get the correct device map
+        self.language_model = AutoModel.from_config(config=config.text_config)
+
         # This is necessary for backward compatibility, see SiglipModel initialization
         self.vision_model = JanusVisionModel._from_config(config.vision_config)
         self.aligner = JanusVisionAlignerMLP(self.vision_model.config)
@@ -1066,8 +1070,6 @@ def __init__(self, config: JanusConfig):
         self.generation_aligner = JanusVQVAEAlignerMLP(self.vqmodel.config)
         self.generation_head = JanusVQVAEHead(self.vqmodel.config)
 
-        self.language_model = AutoModel.from_config(config=config.text_config)
-
         self.gradient_checkpointing = False
         # Initialize weights and apply final processing.
         self.post_init()
@@ -1133,6 +1135,7 @@ def forward(
         image_features = image_embeds.reshape(-1, embed_dim)
         image_attention_mask = image_attention_mask.unsqueeze(-1).expand(-1, -1, embed_dim)
 
+        image_attention_mask = image_attention_mask.to(inputs_embeds.device)
         image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
         inputs_embeds = inputs_embeds.masked_scatter(image_attention_mask, image_features)
 
@@ -1395,9 +1398,15 @@ def generate(
             attention_mask = attention_mask.repeat(2, 1)
             model_kwargs["attention_mask"] = attention_mask
 
+        # Get BOI token ID
+        if hasattr(generation_config, "generation_kwargs"):
+            boi_token_id = generation_config.generation_kwargs.get("boi_token_id", generation_config.bos_token_id)
+        else:
+            boi_token_id = kwargs.get("boi_token_id", generation_config.bos_token_id)
+
         # Mask all the tokens that are neither BOS nor BOI with pad token in the unconditional logits.
         mask = (input_tokens[batch_size:, :] != generation_config.bos_token_id) & (
-            input_tokens[batch_size:, :] != generation_config.generation_kwargs["boi_token_id"]
+            input_tokens[batch_size:, :] != boi_token_id
         )
         input_tokens[batch_size:, :].masked_fill_(mask, generation_config.pad_token_id)
 
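Why the added image_attention_mask.to(inputs_embeds.device) line above matters: torch.Tensor.masked_scatter requires the destination, mask, and source tensors to be on the same device, and with device_map="auto" the vision tower and the language model's embeddings can land on different devices. A minimal sketch of the failure mode (hypothetical shapes; assumes a CUDA device is available):

import torch

# Destination lives on the GPU, but the mask was built on the CPU.
inputs_embeds = torch.zeros(1, 4, 8, device="cuda:0")
image_attention_mask = torch.ones(1, 4, 8, dtype=torch.bool)
image_features = torch.randn(4 * 8, device="cuda:0")

# Without aligning devices, this raises a RuntimeError (device mismatch):
# inputs_embeds.masked_scatter(image_attention_mask, image_features)

# The fix mirrors the diff: move the mask to the destination device first.
image_attention_mask = image_attention_mask.to(inputs_embeds.device)
out = inputs_embeds.masked_scatter(image_attention_mask, image_features)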
17 changes: 13 additions & 4 deletions src/transformers/models/janus/modular_janus.py
@@ -379,7 +379,7 @@ class JanusPreTrainedModel(PreTrainedModel):
     config_class = JanusConfig
     base_model_prefix = "model"
     supports_gradient_checkpointing = True
-    _no_split_modules = ["LlamaDecoderLayer"]
+    _no_split_modules = ["LlamaDecoderLayer", "JanusVisionEncoderLayer"]
     _skip_keys_device_placement = ["past_key_values", "causal_mask"]
     _supports_flash_attn_2 = True
     _supports_sdpa = True
@@ -892,6 +892,10 @@ class JanusModel(JanusPreTrainedModel):
     def __init__(self, config: JanusConfig):
         super().__init__(config)
         self.config = config
+
+        # Language model is initialized first to get the correct device map
+        self.language_model = AutoModel.from_config(config=config.text_config)
+
         # This is necessary for backward compatibility, see SiglipModel initialization
         self.vision_model = JanusVisionModel._from_config(config.vision_config)
         self.aligner = JanusVisionAlignerMLP(self.vision_model.config)
@@ -904,8 +908,6 @@ def __init__(self, config: JanusConfig):
         self.generation_aligner = JanusVQVAEAlignerMLP(self.vqmodel.config)
         self.generation_head = JanusVQVAEHead(self.vqmodel.config)
 
-        self.language_model = AutoModel.from_config(config=config.text_config)
-
         self.gradient_checkpointing = False
         # Initialize weights and apply final processing.
         self.post_init()
@@ -971,6 +973,7 @@ def forward(
         image_features = image_embeds.reshape(-1, embed_dim)
         image_attention_mask = image_attention_mask.unsqueeze(-1).expand(-1, -1, embed_dim)
 
+        image_attention_mask = image_attention_mask.to(inputs_embeds.device)
         image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
         inputs_embeds = inputs_embeds.masked_scatter(image_attention_mask, image_features)
 
@@ -1233,9 +1236,15 @@ def generate(
             attention_mask = attention_mask.repeat(2, 1)
             model_kwargs["attention_mask"] = attention_mask
 
+        # Get BOI token ID
+        if hasattr(generation_config, "generation_kwargs"):
+            boi_token_id = generation_config.generation_kwargs.get("boi_token_id", generation_config.bos_token_id)
+        else:
+            boi_token_id = kwargs.get("boi_token_id", generation_config.bos_token_id)
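For illustration, the fallback above in isolation. This is a sketch with made-up token IDs; it assumes, as the thread below establishes for deepseek-community/Janus-Pro-1B, that a loaded GenerationConfig may lack the generation_kwargs attribute entirely:

from transformers import GenerationConfig

config = GenerationConfig(bos_token_id=100000, pad_token_id=100002)

# Prefer the checkpoint-provided BOI token ID; fall back to BOS when
# generation_kwargs was never saved (the AttributeError case discussed below).
if hasattr(config, "generation_kwargs"):
    boi_token_id = config.generation_kwargs.get("boi_token_id", config.bos_token_id)
else:
    boi_token_id = config.bos_token_id

print(boi_token_id)  # 100000 here, since no boi_token_id was stored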

Contributor: Hi @remi-or, the other changes look logical to me, thanks for fixing them 🤗. Can you expand on why boi_token_id won't be present in the generation_kwargs? AFAIK I added it explicitly in the conversion file, so it should be present in checkpoints 🤔

Collaborator Author: Hi @yaswanth19, it seems it is missing from the checkpoint used in testing (deepseek-community/Janus-Pro-1B).
Member: Same question.

Collaborator Author: It seems like generation_config is either (1) not loaded, which would be weird because guidance_scale is equal to 5 (though I don't know if that's usual), or (2) the generation_kwargs attribute is dropped at some point. Before model.generate is called in the test, checking model.generation_config gives:

model.generation_config = GenerationConfig {
  "bos_token_id": 100000,
  "eos_token_id": 100001,
  "pad_token_id": 100002
}

Member: @remi-or, do you have a snippet reproducing the issue? It would be nice to add it to the PR body as well 🤗

I wonder if the issue you describe appears only on certain hardware (which is very unlikely) or if the inference script is doing something unexpected.

Collaborator Author: Sure! It's taken from tests/models/janus/test_modeling_janus.py::JanusIntegrationTest::test_model_generate_image:

from transformers import AutoProcessor, JanusForConditionalGeneration

if __name__ == "__main__":
    model_id = "deepseek-community/Janus-Pro-1B"
    model = JanusForConditionalGeneration.from_pretrained(model_id, device_map="auto")
    processor = AutoProcessor.from_pretrained(model_id)

    inputs = processor(
        text=["A portrait of young girl. masterpiece, film grained, best quality."],
        padding=True, generation_mode="image", return_tensors="pt",
    ).to(model.device)

    out = model.generate(**inputs, generation_mode="image", do_sample=False)

I tried changing device_map to CPU and it still crashed with AttributeError: 'GenerationConfig' object has no attribute 'generation_kwargs', so I don't think it's device-related.

Member: Indeed, weird, since I'd assume the config from the hub would be picked up; at least that was true for Whisper in the past. Let me check why this isn't loaded. We'd better make sure the pre-saved config values are used when running inference.

Member: Fixed. The issue was in the config saved on the hub: one of the flags was set to True, thus overwriting config values from scratch.

I think now the only issue is the multi-device inference. @remi-or, can you update the PR so we can merge?

Collaborator Author: Done!

         # Mask all the tokens that are neither BOS nor BOI with pad token in the unconditional logits.
         mask = (input_tokens[batch_size:, :] != generation_config.bos_token_id) & (
-            input_tokens[batch_size:, :] != generation_config.generation_kwargs["boi_token_id"]
+            input_tokens[batch_size:, :] != boi_token_id
         )
         input_tokens[batch_size:, :].masked_fill_(mask, generation_config.pad_token_id)
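Both modeling changes above target multi-device loading: adding JanusVisionEncoderLayer to _no_split_modules tells accelerate never to shard a vision encoder layer across devices, and registering language_model before the vision components lets the auto device map lay out the language model first (the "correct device map" the code comment refers to). A rough way to verify the layout after loading; this is a sketch, not part of the PR, and assumes accelerate plus at least one accelerator device:

from transformers import JanusForConditionalGeneration

model = JanusForConditionalGeneration.from_pretrained(
    "deepseek-community/Janus-Pro-1B", device_map="auto"
)

# hf_device_map records the device each (unsplit) submodule was assigned to.
for name, device in model.hf_device_map.items():
    print(f"{name}: {device}")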
22 changes: 16 additions & 6 deletions tests/models/janus/test_modeling_janus.py
@@ -35,6 +35,7 @@
 from transformers.models.auto import get_values
 from transformers.models.auto.modeling_auto import MODEL_FOR_BACKBONE_MAPPING_NAMES, MODEL_MAPPING_NAMES
 from transformers.testing_utils import (
+    Expectations,
     require_torch,
     slow,
     torch_device,
@@ -538,12 +539,21 @@ def test_model_generate_images(self):
         self.assertTrue(out.shape[1] == 576)
 
         # fmt: off
-        expected_tokens = torch.tensor([4484, 4015, 15750, 506, 3758, 11651, 8597, 5739, 4861, 971,
-                                        14985, 14834, 15438, 7548, 1820, 1465, 13529, 12761, 10503, 12761,
-                                        14303, 6155, 4015, 11766, 705, 15736, 14146, 10417, 1951, 7713,
-                                        14305, 15617, 6169, 2706, 8006, 14893, 3855, 10188, 15652, 6297,
-                                        1097, 12108, 15038, 311, 14998, 15165, 897, 4044, 1762, 4676,
-                                        ]).to(model.device)
+        expected_tokens = Expectations(
+            {
+                ("rocm", None): [10367, 1380, 4841, 15155, 1224, 16361, 15834, 13722, 15258, 8321, 10496, 14532, 8770,
+                                 12353, 5481, 11484, 2585, 8587, 3201, 14292, 3356, 2037, 3077, 6107, 3758, 2572, 9376,
+                                 13219, 6007, 14292, 12696, 10666, 10046, 13483, 8282, 9101, 5208, 4260, 13886, 13335,
+                                 6135, 2316, 15423, 311, 5460, 12218, 14172, 8583, 14577, 3648
+                ],
+                ("cuda", None): [4484, 4015, 15750, 506, 3758, 11651, 8597, 5739, 4861, 971, 14985, 14834, 15438, 7548,
+                                 1820, 1465, 13529, 12761, 10503, 12761, 14303, 6155, 4015, 11766, 705, 15736, 14146,
+                                 10417, 1951, 7713, 14305, 15617, 6169, 2706, 8006, 14893, 3855, 10188, 15652, 6297,
+                                 1097, 12108, 15038, 311, 14998, 15165, 897, 4044, 1762, 4676
+                ],
+            }
+        )
+        expected_tokens = torch.tensor(expected_tokens.get_expectation()).to(model.device)
         # fmt: on
 
         # Compare the first 50 generated tokens.
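For context on the test change: Expectations (from transformers.testing_utils, as used in the diff above) keys reference outputs by (device type, major version), and get_expectation() returns the entry matching the hardware the test is running on, which is how the test now carries separate CUDA and ROCm baselines. A small sketch with made-up values:

from transformers.testing_utils import Expectations

# (device_type, major_version) -> expected value; None matches any version.
expected = Expectations(
    {
        ("cuda", None): [1, 2, 3],
        ("rocm", None): [4, 5, 6],
    }
)

print(expected.get_expectation())  # picks the entry for the current accelerator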