Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 4 additions & 27 deletions tests/models/gemma/test_modeling_gemma.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
require_torch,
require_torch_accelerator,
require_torch_gpu,
require_torch_sdpa,
slow,
torch_device,
)
Expand Down Expand Up @@ -147,7 +146,7 @@ def test_model_2b_bf16(self):

EXPECTED_TEXTS = [
"Hello I am doing a project on the 1990s and I need to know what the most popular music",
"Hi today I am going to share with you a very easy and simple recipe of <strong><em>Khichdi",
"Hi today I am going to share with you a very easy and simple recipe of <strong><em>Kaju Kat",
]

model = AutoModelForCausalLM.from_pretrained(model_id, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16).to(
Expand All @@ -168,34 +167,12 @@ def test_model_2b_eager(self):

EXPECTED_TEXTS = [
"Hello I am doing a project on the 1990s and I need to know what the most popular music",
"Hi today I am going to share with you a very easy and simple recipe of <strong><em>Khichdi",
]

model = AutoModelForCausalLM.from_pretrained(
model_id, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16, attn_implementation="eager"
)
model.to(torch_device)

tokenizer = AutoTokenizer.from_pretrained(model_id)
inputs = tokenizer(self.input_text, return_tensors="pt", padding=True).to(torch_device)

output = model.generate(**inputs, max_new_tokens=20, do_sample=False)
output_text = tokenizer.batch_decode(output, skip_special_tokens=True)

self.assertEqual(output_text, EXPECTED_TEXTS)

@require_torch_sdpa
@require_read_token
def test_model_2b_sdpa(self):
model_id = "google/gemma-2b"

EXPECTED_TEXTS = [
"Hello I am doing a project on the 1990s and I need to know what the most popular music",
"Hi today I am going to share with you a very easy and simple recipe of <strong><em>Khichdi",
"Hi today I am going to share with you a very easy and simple recipe of <strong><em>Kaju Kat",
]

# bfloat16 gives strange values, likely due to it has lower precision + very short prompts
model = AutoModelForCausalLM.from_pretrained(
model_id, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16, attn_implementation="sdpa"
model_id, low_cpu_mem_usage=True, torch_dtype=torch.float16, attn_implementation="eager"
)
model.to(torch_device)

Expand Down