[janus] Fix failing tests on mi3XX #38426
Changes from 7 commits: 0808a34, a772bf8, 82109b4, 985b75c, a21b570, 55d84ee, 19977a7, 3a82946
@@ -379,7 +379,7 @@ class JanusPreTrainedModel(PreTrainedModel):
     config_class = JanusConfig
     base_model_prefix = "model"
     supports_gradient_checkpointing = True
-    _no_split_modules = ["LlamaDecoderLayer"]
+    _no_split_modules = ["LlamaDecoderLayer", "JanusVisionEncoderLayer"]
     _skip_keys_device_placement = ["past_key_values", "causal_mask"]
     _supports_flash_attn_2 = True
     _supports_sdpa = True
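Context for this hunk: `_no_split_modules` lists the module classes that an automatic device map must keep whole on a single device. Adding `JanusVisionEncoderLayer` prevents a vision encoder layer from being sharded across GPUs on multi-device machines such as the MI3xx runners. A minimal loading sketch, not taken from the PR; the checkpoint id is assumed for illustration:

```python
# Minimal sketch, assuming the checkpoint id below and enough GPU memory.
# With "JanusVisionEncoderLayer" in _no_split_modules, the auto device map keeps
# each vision encoder layer on one device instead of splitting it across GPUs.
import torch
from transformers import JanusForConditionalGeneration

model = JanusForConditionalGeneration.from_pretrained(
    "deepseek-community/Janus-Pro-1B",  # assumed checkpoint id, for illustration only
    torch_dtype=torch.bfloat16,
    device_map="auto",                  # consults _no_split_modules when sharding
)
print(model.hf_device_map)              # inspect which device each block landed on
```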
@@ -892,6 +892,10 @@ class JanusModel(JanusPreTrainedModel):
     def __init__(self, config: JanusConfig):
         super().__init__(config)
         self.config = config
+
+        # Language model is initialized first to get the correct device map
+        self.language_model = AutoModel.from_config(config=config.text_config)
+
         # This is necessary for backward compatibility, see SiglipModel initialization
         self.vision_model = JanusVisionModel._from_config(config.vision_config)
         self.aligner = JanusVisionAlignerMLP(self.vision_model.config)
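This hunk and the next one move the `language_model` registration ahead of the vision tower, because an automatic device map is built by walking submodules in registration order. A rough sketch of how that ordering can be inspected with accelerate; the default config construction, importable class names, and memory budget are assumptions, not taken from the PR:

```python
# Sketch under assumptions: instantiate Janus on the meta device and ask accelerate
# for an auto device map. Submodules are assigned in the order they were registered
# in __init__, so language_model is now considered before the vision tower.
from accelerate import infer_auto_device_map, init_empty_weights
from transformers import JanusConfig, JanusModel  # class names assumed importable

with init_empty_weights():
    model = JanusModel(JanusConfig())   # default sub-configs, illustration only

device_map = infer_auto_device_map(
    model,
    max_memory={0: "4GiB", 1: "4GiB"},  # hypothetical two-GPU budget
    no_split_module_classes=["LlamaDecoderLayer", "JanusVisionEncoderLayer"],
)
print(device_map)
```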
@@ -904,8 +908,6 @@ def __init__(self, config: JanusConfig):
         self.generation_aligner = JanusVQVAEAlignerMLP(self.vqmodel.config)
         self.generation_head = JanusVQVAEHead(self.vqmodel.config)

-        self.language_model = AutoModel.from_config(config=config.text_config)
-
         self.gradient_checkpointing = False
         # Initialize weights and apply final processing.
         self.post_init()
@@ -971,6 +973,7 @@ def forward(
             image_features = image_embeds.reshape(-1, embed_dim)
             image_attention_mask = image_attention_mask.unsqueeze(-1).expand(-1, -1, embed_dim)

+            image_attention_mask = image_attention_mask.to(inputs_embeds.device)
             image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
             inputs_embeds = inputs_embeds.masked_scatter(image_attention_mask, image_features)
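The added `.to(inputs_embeds.device)` matters once the model is sharded: `masked_scatter` expects the boolean mask on the same device as the tensor it updates, and with the vision tower and language model on different GPUs the mask could otherwise arrive on the wrong one. A self-contained toy reproduction of the alignment; shapes and values are made up:

```python
# Toy sketch of the device/dtype alignment the hunk adds; shapes are made up.
import torch

device = "cuda:0" if torch.cuda.is_available() else "cpu"
inputs_embeds = torch.zeros(2, 4, 8, device=device)
image_attention_mask = torch.zeros(2, 4, 8, dtype=torch.bool)  # may live elsewhere
image_features = torch.randn(6, 8)

# Align device (and dtype for the features) before scattering image features
# into the text embedding stream, mirroring the change in forward().
image_attention_mask = image_attention_mask.to(inputs_embeds.device)
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(image_attention_mask, image_features)
print(inputs_embeds.shape)
```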
@@ -1233,9 +1236,15 @@ def generate(
         attention_mask = attention_mask.repeat(2, 1)
         model_kwargs["attention_mask"] = attention_mask

+        # Get BOI token ID
+        if hasattr(generation_config, "generation_kwargs"):
+            boi_token_id = generation_config.generation_kwargs.get("boi_token_id", generation_config.bos_token_id)
+        else:
+            boi_token_id = kwargs.get("boi_token_id", generation_config.bos_token_id)
+
         # Mask all the tokens that are neither BOS nor BOI with pad token in the unconditional logits.
         mask = (input_tokens[batch_size:, :] != generation_config.bos_token_id) & (
-            input_tokens[batch_size:, :] != generation_config.generation_kwargs["boi_token_id"]
+            input_tokens[batch_size:, :] != boi_token_id
         )
         input_tokens[batch_size:, :].masked_fill_(mask, generation_config.pad_token_id)
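This last hunk replaces the hard lookup `generation_config.generation_kwargs["boi_token_id"]`, which fails when `generation_kwargs` is not set on the config, with a fallback chain: `generation_kwargs` first, then a `boi_token_id` keyword argument, then `bos_token_id`. A standalone sketch of that resolution order; the helper name and token IDs are invented for illustration, only the lookup order mirrors the diff:

```python
# Standalone sketch of the fallback chain above; resolve_boi_token_id and the
# token IDs are invented, only the lookup order mirrors the PR.
from types import SimpleNamespace

def resolve_boi_token_id(generation_config, **kwargs):
    if hasattr(generation_config, "generation_kwargs"):
        return generation_config.generation_kwargs.get("boi_token_id", generation_config.bos_token_id)
    return kwargs.get("boi_token_id", generation_config.bos_token_id)

with_kwargs = SimpleNamespace(bos_token_id=0, generation_kwargs={"boi_token_id": 100016})
without_kwargs = SimpleNamespace(bos_token_id=0)

print(resolve_boi_token_id(with_kwargs))                      # 100016
print(resolve_boi_token_id(without_kwargs, boi_token_id=42))  # 42
print(resolve_boi_token_id(without_kwargs))                   # 0, falls back to bos_token_id
```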