Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 16 additions & 21 deletions src/transformers/models/blip_2/modeling_blip_2.py
Original file line number Diff line number Diff line change
Expand Up @@ -1774,8 +1774,12 @@ def forward(
return_dict=return_dict,
labels=labels,
)
loss = outputs.loss if return_dict else outputs[0]
logits = outputs.logits if return_dict else outputs[1]
if labels is not None:
loss = outputs.loss if return_dict else outputs[0]
logits = outputs.logits if return_dict else outputs[1]
else:
loss = None
logits = outputs.logits if return_dict else outputs[0]

if not return_dict:
output = (logits, vision_outputs, query_outputs, outputs)
Expand Down Expand Up @@ -2243,8 +2247,12 @@ def forward(
return_dict=return_dict,
labels=labels,
)
loss = outputs.loss if return_dict else outputs[0]
logits = outputs.logits if return_dict else outputs[1]
if labels is not None:
loss = outputs.loss if return_dict else outputs[0]
logits = outputs.logits if return_dict else outputs[1]
else:
loss = None
logits = outputs.logits if return_dict else outputs[0]

if not return_dict:
output = (logits, vision_outputs, query_outputs, outputs)
Expand Down Expand Up @@ -2341,24 +2349,11 @@ def generate(
)
generate_kwargs["min_length"] = generate_kwargs.get("min_length", 0) + language_model_inputs.shape[1]

outputs = self.language_model.generate(
inputs_embeds=inputs_embeds,
attention_mask=attention_mask,
**generate_kwargs,
)

# this is a temporary workaround to be consistent with other generation models and
# have BOS as the first token, even though under the hood we are calling LM with embeds
inputs = {"inputs_embeds": inputs_embeds, "attention_mask": attention_mask}
if not self.language_model.config.is_encoder_decoder:
bos_tokens = (
torch.LongTensor([[self.config.text_config.bos_token_id]])
.repeat(batch_size, 1)
.to(image_embeds.device)
)
if not isinstance(outputs, torch.Tensor):
outputs.sequences = torch.cat([bos_tokens, outputs.sequences], dim=-1)
else:
outputs = torch.cat([bos_tokens, outputs], dim=-1)
inputs["input_ids"] = input_ids

outputs = self.language_model.generate(**inputs, **generate_kwargs)
return outputs


Expand Down
25 changes: 4 additions & 21 deletions src/transformers/models/instructblip/modeling_instructblip.py
Original file line number Diff line number Diff line change
Expand Up @@ -1628,27 +1628,10 @@ def generate(
)
generate_kwargs["min_length"] = generate_kwargs.get("min_length", 0) + language_model_inputs.shape[1]

outputs = self.language_model.generate(
inputs_embeds=inputs_embeds,
attention_mask=attention_mask,
**generate_kwargs,
)

# this is a temporary workaround to be consistent with other generation models and
# have BOS as the first token, even though under the hood we are calling LM with embeds
inputs = {"inputs_embeds": inputs_embeds, "attention_mask": attention_mask}
if not self.language_model.config.is_encoder_decoder:
# the InstructBLIP authors used inconsistent tokenizer/model files during training,
# with the tokenizer's bos token being set to </s> which has ID=2,
# whereas the model's text config has bos token id = 0
bos_token_id = (
2
if self.config.text_config.architectures[0] == "LLaMAForCausalLM"
else self.config.text_config.bos_token_id
)
bos_tokens = torch.LongTensor([[bos_token_id]]).repeat(batch_size, 1).to(image_embeds.device)
if not isinstance(outputs, torch.Tensor):
outputs.sequences = torch.cat([bos_tokens, outputs.sequences], dim=-1)
else:
outputs = torch.cat([bos_tokens, outputs], dim=-1)
inputs["input_ids"] = input_ids

outputs = self.language_model.generate(**inputs, **generate_kwargs)

return outputs
Original file line number Diff line number Diff line change
Expand Up @@ -1663,27 +1663,10 @@ def generate(
)
generate_kwargs["min_length"] = generate_kwargs.get("min_length", 0) + language_model_inputs.shape[1]

outputs = self.language_model.generate(
inputs_embeds=inputs_embeds,
attention_mask=attention_mask,
**generate_kwargs,
)

# this is a temporary workaround to be consistent with other generation models and
# have BOS as the first token, even though under the hood we are calling LM with embeds
inputs = {"inputs_embeds": inputs_embeds, "attention_mask": attention_mask}
if not self.language_model.config.is_encoder_decoder:
# the InstructBLIP authors used inconsistent tokenizer/model files during training,
# with the tokenizer's bos token being set to </s> which has ID=2,
# whereas the model's text config has bos token id = 0
bos_token_id = (
2
if self.config.text_config.architectures[0] == "LLaMAForCausalLM"
else self.config.text_config.bos_token_id
)
bos_tokens = torch.LongTensor([[bos_token_id]]).repeat(batch_size, 1).to(image_embeds.device)
if not isinstance(outputs, torch.Tensor):
outputs.sequences = torch.cat([bos_tokens, outputs.sequences], dim=-1)
else:
outputs = torch.cat([bos_tokens, outputs], dim=-1)
inputs["input_ids"] = input_ids

outputs = self.language_model.generate(**inputs, **generate_kwargs)

return outputs
Original file line number Diff line number Diff line change
Expand Up @@ -468,27 +468,10 @@ def generate(
)
generate_kwargs["min_length"] = generate_kwargs.get("min_length", 0) + language_model_inputs.shape[1]

outputs = self.language_model.generate(
inputs_embeds=inputs_embeds,
attention_mask=attention_mask,
**generate_kwargs,
)

# this is a temporary workaround to be consistent with other generation models and
# have BOS as the first token, even though under the hood we are calling LM with embeds
inputs = {"inputs_embeds": inputs_embeds, "attention_mask": attention_mask}
if not self.language_model.config.is_encoder_decoder:
# the InstructBLIP authors used inconsistent tokenizer/model files during training,
# with the tokenizer's bos token being set to </s> which has ID=2,
# whereas the model's text config has bos token id = 0
bos_token_id = (
2
if self.config.text_config.architectures[0] == "LLaMAForCausalLM"
else self.config.text_config.bos_token_id
)
bos_tokens = torch.LongTensor([[bos_token_id]]).repeat(batch_size, 1).to(image_embeds.device)
if not isinstance(outputs, torch.Tensor):
outputs.sequences = torch.cat([bos_tokens, outputs.sequences], dim=-1)
else:
outputs = torch.cat([bos_tokens, outputs], dim=-1)
inputs["input_ids"] = input_ids

outputs = self.language_model.generate(**inputs, **generate_kwargs)

return outputs
43 changes: 22 additions & 21 deletions tests/generation/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@


class GenerationTesterMixin:
input_name = "input_ids"
model_tester = None
all_generative_model_classes = ()
max_new_tokens = 3
Expand Down Expand Up @@ -407,7 +408,7 @@ def _contrastive_generate(
def test_greedy_generate(self):
for model_class in self.all_generative_model_classes:
config, inputs_dict = self.prepare_config_and_inputs_for_generate()
main_input = inputs_dict[model_class.main_input_name]
main_input = inputs_dict[self.input_name]

model = model_class(config).to(torch_device).eval()
output_generate = self._greedy_generate(model=model, inputs_dict=inputs_dict)
Expand All @@ -421,7 +422,7 @@ def test_greedy_generate(self):
def test_greedy_generate_dict_outputs(self):
for model_class in self.all_generative_model_classes:
config, inputs_dict = self.prepare_config_and_inputs_for_generate()
main_input = inputs_dict[model_class.main_input_name]
main_input = inputs_dict[self.input_name]

model = model_class(config).to(torch_device).eval()
output_generate = self._greedy_generate(
Expand Down Expand Up @@ -452,7 +453,7 @@ def test_greedy_generate_dict_outputs(self):
def test_greedy_generate_dict_outputs_use_cache(self):
for model_class in self.all_generative_model_classes:
config, inputs_dict = self.prepare_config_and_inputs_for_generate()
main_input = inputs_dict[model_class.main_input_name]
main_input = inputs_dict[self.input_name]

if not hasattr(config, "use_cache"):
self.skipTest(reason=f"{model_class.__name__} doesn't support caching")
Expand Down Expand Up @@ -483,7 +484,7 @@ def test_greedy_generate_dict_outputs_use_cache(self):
def test_sample_generate(self):
for model_class in self.all_generative_model_classes:
config, inputs_dict = self.prepare_config_and_inputs_for_generate()
main_input = inputs_dict[model_class.main_input_name]
main_input = inputs_dict[self.input_name]

model = model_class(config).to(torch_device).eval()
output_generate = self._sample_generate(model=model, inputs_dict=inputs_dict, num_return_sequences=1)
Expand All @@ -497,7 +498,7 @@ def test_sample_generate(self):
def test_sample_generate_dict_output(self):
for model_class in self.all_generative_model_classes:
config, inputs_dict = self.prepare_config_and_inputs_for_generate()
main_input = inputs_dict[model_class.main_input_name]
main_input = inputs_dict[self.input_name]

model = model_class(config).to(torch_device).eval()
output_generate = self._sample_generate(
Expand Down Expand Up @@ -529,7 +530,7 @@ def test_sample_generate_dict_output(self):
def test_beam_search_generate(self):
for model_class in self.all_generative_model_classes:
config, inputs_dict = self.prepare_config_and_inputs_for_generate()
main_input = inputs_dict[model_class.main_input_name]
main_input = inputs_dict[self.input_name]

model = model_class(config).to(torch_device).eval()

Expand All @@ -545,7 +546,7 @@ def test_beam_search_generate(self):
def test_beam_search_generate_dict_output(self):
for model_class in self.all_generative_model_classes:
config, inputs_dict = self.prepare_config_and_inputs_for_generate()
main_input = inputs_dict[model_class.main_input_name]
main_input = inputs_dict[self.input_name]

model = model_class(config).to(torch_device).eval()
beam_kwargs = self._get_beam_kwargs()
Expand Down Expand Up @@ -579,7 +580,7 @@ def test_beam_search_generate_dict_output(self):
def test_beam_search_generate_dict_outputs_use_cache(self):
for model_class in self.all_generative_model_classes:
config, inputs_dict = self.prepare_config_and_inputs_for_generate()
main_input = inputs_dict[model_class.main_input_name]
main_input = inputs_dict[self.input_name]

if not hasattr(config, "use_cache"):
self.skipTest(reason=f"{model_class.__name__} doesn't support caching")
Expand Down Expand Up @@ -644,7 +645,7 @@ def test_model_parallel_beam_search(self):
def test_beam_sample_generate(self):
for model_class in self.all_generative_model_classes:
config, inputs_dict = self.prepare_config_and_inputs_for_generate()
main_input = inputs_dict[model_class.main_input_name]
main_input = inputs_dict[self.input_name]

model = model_class(config).to(torch_device).eval()
beam_kwargs = self._get_beam_kwargs()
Expand Down Expand Up @@ -686,7 +687,7 @@ def test_beam_sample_generate(self):
def test_beam_sample_generate_dict_output(self):
for model_class in self.all_generative_model_classes:
config, inputs_dict = self.prepare_config_and_inputs_for_generate()
main_input = inputs_dict[model_class.main_input_name]
main_input = inputs_dict[self.input_name]

model = model_class(config).to(torch_device).eval()
beam_kwargs = self._get_beam_kwargs()
Expand Down Expand Up @@ -743,7 +744,7 @@ def test_generate_without_input_ids(self):
def test_group_beam_search_generate(self):
for model_class in self.all_generative_model_classes:
config, inputs_dict = self.prepare_config_and_inputs_for_generate()
main_input = inputs_dict[model_class.main_input_name]
main_input = inputs_dict[self.input_name]

model = model_class(config).to(torch_device).eval()
# check `generate()` and `group_beam_search()` are equal
Expand Down Expand Up @@ -775,7 +776,7 @@ def test_group_beam_search_generate(self):
def test_group_beam_search_generate_dict_output(self):
for model_class in self.all_generative_model_classes:
config, inputs_dict = self.prepare_config_and_inputs_for_generate()
main_input = inputs_dict[model_class.main_input_name]
main_input = inputs_dict[self.input_name]

model = model_class(config).to(torch_device).eval()
beam_kwargs = self._get_diverse_beam_kwargs()
Expand Down Expand Up @@ -811,7 +812,7 @@ def test_group_beam_search_generate_dict_output(self):
def test_constrained_beam_search_generate(self):
for model_class in self.all_generative_model_classes:
config, inputs_dict = self.prepare_config_and_inputs_for_generate()
main_input = inputs_dict[model_class.main_input_name]
main_input = inputs_dict[self.input_name]

model = model_class(config).to(torch_device).eval()

Expand Down Expand Up @@ -868,7 +869,7 @@ def test_constrained_beam_search_generate(self):
def test_constrained_beam_search_generate_dict_output(self):
for model_class in self.all_generative_model_classes:
config, inputs_dict = self.prepare_config_and_inputs_for_generate()
main_input = inputs_dict[model_class.main_input_name]
main_input = inputs_dict[self.input_name]

model = model_class(config).to(torch_device).eval()

Expand Down Expand Up @@ -920,7 +921,7 @@ def test_contrastive_generate(self):
self.skipTest(reason="Won't fix: old model with different cache format")

config, inputs_dict = self.prepare_config_and_inputs_for_generate()
main_input = inputs_dict[model_class.main_input_name]
main_input = inputs_dict[self.input_name]

# NOTE: contrastive search only works with cache on at the moment.
if not hasattr(config, "use_cache"):
Expand Down Expand Up @@ -950,7 +951,7 @@ def test_contrastive_generate_dict_outputs_use_cache(self):
self.skipTest(reason="Won't fix: old model with different cache format")

config, inputs_dict = self.prepare_config_and_inputs_for_generate()
main_input = inputs_dict[model_class.main_input_name]
main_input = inputs_dict[self.input_name]

# NOTE: contrastive search only works with cache on at the moment.
if not hasattr(config, "use_cache"):
Expand Down Expand Up @@ -1105,7 +1106,7 @@ def test_assisted_decoding_matches_greedy_search(self, assistant_type):

# enable cache
config, inputs_dict = self.prepare_config_and_inputs_for_generate(batch_size=1)
main_input = inputs_dict[model_class.main_input_name]
main_input = inputs_dict[self.input_name]

# NOTE: assisted generation only works with cache on at the moment.
if not hasattr(config, "use_cache"):
Expand Down Expand Up @@ -1179,7 +1180,7 @@ def test_prompt_lookup_decoding_matches_greedy_search(self):

# enable cache
config, inputs_dict = self.prepare_config_and_inputs_for_generate(batch_size=1)
main_input = inputs_dict[model_class.main_input_name]
main_input = inputs_dict[self.input_name]

# NOTE: assisted generation only works with cache on at the moment.
if not hasattr(config, "use_cache"):
Expand Down Expand Up @@ -1235,7 +1236,7 @@ def test_dola_decoding_sample(self):

# enable cache if the model is not openai-gpt, xlnet, cpm, or xlm
config, inputs_dict = self.prepare_config_and_inputs_for_generate()
main_input = inputs_dict[model_class.main_input_name]
main_input = inputs_dict[self.input_name]

# Encoder-decoder models are not supported
if config.is_encoder_decoder:
Expand Down Expand Up @@ -1292,7 +1293,7 @@ def test_assisted_decoding_sample(self):

# enable cache
config, inputs_dict = self.prepare_config_and_inputs_for_generate(batch_size=1)
main_input = inputs_dict[model_class.main_input_name]
main_input = inputs_dict[self.input_name]

# NOTE: assisted generation only works with cache on at the moment.
if not hasattr(config, "use_cache"):
Expand Down Expand Up @@ -1838,7 +1839,7 @@ def test_generate_with_static_cache(self):
self.skipTest(reason="This model does not support the static cache format")

config, inputs_dict = self.prepare_config_and_inputs_for_generate()
main_input = inputs_dict[model_class.main_input_name]
main_input = inputs_dict[self.input_name]

if config.is_encoder_decoder:
self.skipTest(reason="This model is encoder-decoder and has Encoder-Decoder Cache")
Expand Down
Loading