Merged
35 changes: 12 additions & 23 deletions src/transformers/models/blip_2/modeling_blip_2.py
@@ -1768,11 +1768,12 @@ def forward(
             decoder_attention_mask=decoder_attention_mask,
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
+            return_dict=True,  # toggle for easier access to loss/logits below
             labels=labels,
         )
-        loss = outputs.loss if return_dict else outputs[0]
-        logits = outputs.logits if return_dict else outputs[1]
+        loss = outputs.loss
+        logits = outputs.logits
+        outputs = outputs.to_tuple() if not return_dict else outputs
Collaborator:
outputs will have loss and logits twice there

Collaborator:
(this was probably already a bug)

Member Author:
yep, in general I don't like that we return it like this and would rather return unwrapped LM outputs. But we probably can't just delete it, for BC (backward-compatibility) reasons.


         if not return_dict:
             output = (logits, vision_outputs, query_outputs, outputs)
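To make the reviewer's point concrete, here is a minimal, self-contained sketch of the tuple layout the backward-compatibility path produces (dummy placeholder values, not the real tensors; the `(loss,) + output` tail is the usual pattern in these models and is assumed here, since the diff does not show it):

```python
# Minimal sketch with dummy values; illustrates the duplication only.
loss, logits, past = 0.5, "logits", "past_kv"
vision_outputs, query_outputs = "vision", "query"

lm_tuple = (loss, logits, past)  # what outputs.to_tuple() yields
output = (logits, vision_outputs, query_outputs, lm_tuple)
output = (loss,) + output  # assumed tail: ((loss,) + output) if loss is not None
print(output)
# (0.5, 'logits', 'vision', 'query', (0.5, 'logits', 'past_kv'))
# loss and logits appear both at the top level and inside the nested LM tuple,
# which is the duplication flagged in the thread above.
```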
@@ -2233,11 +2234,12 @@ def forward(
             decoder_attention_mask=decoder_attention_mask,
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
+            return_dict=True,  # toggle for easier access to loss/logits below
             labels=labels,
         )
-        loss = outputs.loss if return_dict else outputs[0]
-        logits = outputs.logits if return_dict else outputs[1]
+        loss = outputs.loss
+        logits = outputs.logits
+        outputs = outputs.to_tuple() if not return_dict else outputs
 
         if not return_dict:
             output = (logits, vision_outputs, query_outputs, outputs)
@@ -2334,24 +2336,11 @@ def generate(
         )
         generate_kwargs["min_length"] = generate_kwargs.get("min_length", 0) + language_model_inputs.shape[1]
 
-        outputs = self.language_model.generate(
-            inputs_embeds=inputs_embeds,
-            attention_mask=attention_mask,
-            **generate_kwargs,
-        )
-
-        # this is a temporary workaround to be consistent with other generation models and
-        # have BOS as the first token, even though under the hood we are calling LM with embeds
+        inputs = {"inputs_embeds": inputs_embeds, "attention_mask": attention_mask}
         if not self.language_model.config.is_encoder_decoder:
-            bos_tokens = (
-                torch.LongTensor([[self.config.text_config.bos_token_id]])
-                .repeat(batch_size, 1)
-                .to(image_embeds.device)
-            )
-            if not isinstance(outputs, torch.Tensor):
-                outputs.sequences = torch.cat([bos_tokens, outputs.sequences], dim=-1)
-            else:
-                outputs = torch.cat([bos_tokens, outputs], dim=-1)
+            inputs["input_ids"] = input_ids
+
+        outputs = self.language_model.generate(**inputs, **generate_kwargs)
         return outputs
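As a usage-level sanity check, a hedged sketch of how the new path behaves from the outside (standard BLIP-2 checkpoint; the image path and prompt are placeholders): `generate` now receives `input_ids` alongside `inputs_embeds`, so the returned sequences start with the prompt tokens (BOS first) without any post-hoc concatenation.

```python
# Illustrative sketch, not part of the PR: exercise BLIP-2 generate() end to end.
import torch
from PIL import Image
from transformers import Blip2Processor, Blip2ForConditionalGeneration

processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-2.7b")

image = Image.open("example.jpg")  # placeholder path
inputs = processor(images=image, text="a photo of", return_tensors="pt")
out = model.generate(**inputs, max_new_tokens=20)

# With this change, out[:, 0] is expected to be the first prompt token (BOS),
# because input_ids is forwarded into generate() rather than BOS being
# concatenated onto the sequences after the fact.
print(processor.batch_decode(out, skip_special_tokens=True))
```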


25 changes: 4 additions & 21 deletions src/transformers/models/instructblip/modeling_instructblip.py
@@ -1625,27 +1625,10 @@ def generate(
         )
         generate_kwargs["min_length"] = generate_kwargs.get("min_length", 0) + language_model_inputs.shape[1]
 
-        outputs = self.language_model.generate(
-            inputs_embeds=inputs_embeds,
-            attention_mask=attention_mask,
-            **generate_kwargs,
-        )
-
-        # this is a temporary workaround to be consistent with other generation models and
-        # have BOS as the first token, even though under the hood we are calling LM with embeds
+        inputs = {"inputs_embeds": inputs_embeds, "attention_mask": attention_mask}
         if not self.language_model.config.is_encoder_decoder:
-            # the InstructBLIP authors used inconsistent tokenizer/model files during training,
-            # with the tokenizer's bos token being set to </s> which has ID=2,
-            # whereas the model's text config has bos token id = 0
-            bos_token_id = (
-                2
-                if self.config.text_config.architectures[0] == "LLaMAForCausalLM"
-                else self.config.text_config.bos_token_id
-            )
-            bos_tokens = torch.LongTensor([[bos_token_id]]).repeat(batch_size, 1).to(image_embeds.device)
-            if not isinstance(outputs, torch.Tensor):
-                outputs.sequences = torch.cat([bos_tokens, outputs.sequences], dim=-1)
-            else:
-                outputs = torch.cat([bos_tokens, outputs], dim=-1)
+            inputs["input_ids"] = input_ids
+
+        outputs = self.language_model.generate(**inputs, **generate_kwargs)
 
         return outputs
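One behavioral subtlety worth flagging: the deleted branch hard-coded BOS id 2 for the LLaMA-based InstructBLIP checkpoints (the tokenizer's BOS `</s>` vs. the text config's `bos_token_id = 0`). Under the new path that correction has to come from the tokenizer itself, since whatever `input_ids` the processor produces is what `generate` prepends. A small hedged helper (hypothetical, e.g. for a regression test) to assert the invariant:

```python
import torch

def starts_with_bos(sequences: torch.Tensor, bos_token_id: int) -> bool:
    """Hypothetical test helper: check that every sequence returned by
    generate() begins with the expected BOS token id."""
    # sequences: LongTensor of shape (batch, seq_len)
    return bool((sequences[:, 0] == bos_token_id).all())

# e.g. assert starts_with_bos(out, processor.tokenizer.bos_token_id)
```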
@@ -1660,27 +1660,10 @@ def generate(
         )
         generate_kwargs["min_length"] = generate_kwargs.get("min_length", 0) + language_model_inputs.shape[1]
 
-        outputs = self.language_model.generate(
-            inputs_embeds=inputs_embeds,
-            attention_mask=attention_mask,
-            **generate_kwargs,
-        )
-
-        # this is a temporary workaround to be consistent with other generation models and
-        # have BOS as the first token, even though under the hood we are calling LM with embeds
+        inputs = {"inputs_embeds": inputs_embeds, "attention_mask": attention_mask}
         if not self.language_model.config.is_encoder_decoder:
-            # the InstructBLIP authors used inconsistent tokenizer/model files during training,
-            # with the tokenizer's bos token being set to </s> which has ID=2,
-            # whereas the model's text config has bos token id = 0
-            bos_token_id = (
-                2
-                if self.config.text_config.architectures[0] == "LLaMAForCausalLM"
-                else self.config.text_config.bos_token_id
-            )
-            bos_tokens = torch.LongTensor([[bos_token_id]]).repeat(batch_size, 1).to(image_embeds.device)
-            if not isinstance(outputs, torch.Tensor):
-                outputs.sequences = torch.cat([bos_tokens, outputs.sequences], dim=-1)
-            else:
-                outputs = torch.cat([bos_tokens, outputs], dim=-1)
+            inputs["input_ids"] = input_ids
+
+        outputs = self.language_model.generate(**inputs, **generate_kwargs)
 
         return outputs
@@ -468,27 +468,10 @@ def generate(
         )
         generate_kwargs["min_length"] = generate_kwargs.get("min_length", 0) + language_model_inputs.shape[1]
 
-        outputs = self.language_model.generate(
-            inputs_embeds=inputs_embeds,
-            attention_mask=attention_mask,
-            **generate_kwargs,
-        )
-
-        # this is a temporary workaround to be consistent with other generation models and
-        # have BOS as the first token, even though under the hood we are calling LM with embeds
+        inputs = {"inputs_embeds": inputs_embeds, "attention_mask": attention_mask}
         if not self.language_model.config.is_encoder_decoder:
-            # the InstructBLIP authors used inconsistent tokenizer/model files during training,
-            # with the tokenizer's bos token being set to </s> which has ID=2,
-            # whereas the model's text config has bos token id = 0
-            bos_token_id = (
-                2
-                if self.config.text_config.architectures[0] == "LLaMAForCausalLM"
-                else self.config.text_config.bos_token_id
-            )
-            bos_tokens = torch.LongTensor([[bos_token_id]]).repeat(batch_size, 1).to(image_embeds.device)
-            if not isinstance(outputs, torch.Tensor):
-                outputs.sequences = torch.cat([bos_tokens, outputs.sequences], dim=-1)
-            else:
-                outputs = torch.cat([bos_tokens, outputs], dim=-1)
+            inputs["input_ids"] = input_ids
+
+        outputs = self.language_model.generate(**inputs, **generate_kwargs)
 
        return outputs
1 change: 1 addition & 0 deletions tests/generation/test_utils.py
@@ -96,6 +96,7 @@


 class GenerationTesterMixin:
+    input_name = "input_ids"
     model_tester = None
     all_generative_model_classes = ()
     max_new_tokens = 3
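The new `input_name` class attribute gives the mixin a configurable handle on the model's main generation input. A hedged sketch of how a model-specific tester might override it (class name and value hypothetical, assuming it lives alongside the mixin in the tests):

```python
# Hypothetical subclass: a speech model whose generate() consumes
# "input_features" instead of "input_ids".
class SpeechModelGenerationTesterMixin(GenerationTesterMixin):
    input_name = "input_features"
```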