Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion tests/models/llama/test_modeling_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def test_model_7b_logits_bf16(self):
("xpu", 3): torch.tensor([[-6.5208, -4.1218, -4.9377, -3.2536, 0.8127, -2.9811, 1.2918, -3.3848]]),
("cuda", 7): torch.tensor([[-6.5061, -4.1147, -4.9669, -3.2038, 0.8069, -2.9694, 1.2864, -3.3786]]),
("cuda", 8): torch.tensor([[-6.5208, -4.1218, -4.9377, -3.2536, 0.8127, -2.9811, 1.2918, -3.3848]]),
("rocm", (9, 4)): torch.tensor([[-6.5094, -4.1329, -4.9754, -3.5042, 0.8082, -2.9443, 1.2830, -3.3539]]),
("rocm", (9, 4)): torch.tensor([[-6.5067, -4.1154, -4.9819, -3.1408, 0.8117, -2.9435, 1.2883, -3.3221]]),
})

expected_mean = expected_means.get_expectation().to(torch_device)
Expand Down
18 changes: 13 additions & 5 deletions tests/models/t5/test_modeling_t5.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
from transformers import (
AutoTokenizer,
ByT5Tokenizer,
GenerationConfig,
T5EncoderModel,
T5ForConditionalGeneration,
T5ForQuestionAnswering,
Expand Down Expand Up @@ -932,7 +933,17 @@ def is_pipeline_test_to_skip(


def use_task_specific_params(model, task):
    """Apply a task's parameters from ``model.config.task_specific_params`` in place.

    Keys recognized by ``GenerationConfig`` (e.g. ``max_length``, ``num_beams``,
    ``do_sample``) are set on ``model.generation_config``; every other key is set
    on ``model.config``. Routing generation keys to the generation config avoids
    the deprecated pattern of storing generation settings on the model config.

    Args:
        model: a model exposing ``config`` (with ``task_specific_params``) and
            ``generation_config`` attributes.
        task: key into ``model.config.task_specific_params``.

    Raises:
        KeyError: if ``task`` is not in ``task_specific_params``.
        TypeError: if ``task_specific_params`` is ``None`` — assumed populated
            for the checkpoints used in these tests.
    """
    task_params = model.config.task_specific_params[task]

    # Build the set of attribute names a GenerationConfig carries; anything in
    # this set belongs on model.generation_config rather than model.config.
    generation_config_attrs = set(GenerationConfig().to_dict().keys())

    for key, value in task_params.items():
        if key in generation_config_attrs:
            setattr(model.generation_config, key, value)
        else:
            setattr(model.config, key, value)


@require_torch
Expand Down Expand Up @@ -1032,14 +1043,11 @@ def test_torch_quant(self):
@slow
def test_small_generation(self):
    """Greedy single-beam generation smoke test on google-t5/t5-small.

    Generation settings (max_length, num_beams, do_sample) are passed directly
    to ``generate`` rather than mutated on ``model.config``, which is the
    deprecated location for generation parameters.
    """
    model = T5ForConditionalGeneration.from_pretrained("google-t5/t5-small").to(torch_device)
    tokenizer = T5Tokenizer.from_pretrained("google-t5/t5-small")

    input_ids = tokenizer("summarize: Hello there", return_tensors="pt").input_ids.to(torch_device)

    # Deterministic greedy decode capped at 8 tokens.
    sequences = model.generate(input_ids, max_length=8, num_beams=1, do_sample=False)

    output_str = tokenizer.batch_decode(sequences, skip_special_tokens=True)[0]
    # assertEqual gives a readable diff on failure, unlike assertTrue(a == b).
    self.assertEqual(output_str, "Hello there!")
Expand Down