tests/models/mistral3/test_modeling_mistral3.py (125 changes: 83 additions & 42 deletions)
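Note for reviewers unfamiliar with the two helpers this diff leans on: the sketch below is not part of the diff and is simplified from it. The updated tests load the model once in setUp, keep its weights on CPU via accelerate.cpu_offload so they are streamed to the accelerator only during forward passes, and declare per-device reference strings with transformers.testing_utils.Expectations, whose get_expectation() returns the entry matching the current accelerator. The test class name and expected strings here are placeholders, not the real values from the diff.

import unittest

import accelerate
import torch

from transformers import Mistral3ForConditionalGeneration
from transformers.testing_utils import Expectations, cleanup, torch_device


class ExampleOffloadedIntegrationTest(unittest.TestCase):
    def setUp(self):
        cleanup(torch_device, gc_collect=True)
        # Load once per test; weights stay on CPU and are moved to the
        # accelerator only while the corresponding layers execute.
        self.model = Mistral3ForConditionalGeneration.from_pretrained(
            "mistralai/Mistral-Small-3.1-24B-Instruct-2503", torch_dtype=torch.bfloat16
        )
        accelerate.cpu_offload(self.model, execution_device=torch_device)

    def tearDown(self):
        cleanup(torch_device, gc_collect=True)

    def test_example(self):
        # Reference outputs keyed by (device type, major version); the strings
        # below are placeholders illustrating the pattern used in the diff.
        expected_outputs = Expectations(
            {
                ("xpu", 3): "output expected on XPU",
                ("cuda", 7): "output expected on CUDA capability 7.x",
                ("cuda", 8): "output expected on CUDA capability 8.x",
            }
        )
        expected_output = expected_outputs.get_expectation()
        self.assertIsInstance(expected_output, str)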
@@ -15,18 +15,20 @@

import unittest

import accelerate

from transformers import (
AutoProcessor,
Mistral3Config,
is_bitsandbytes_available,
is_torch_available,
)
from transformers.testing_utils import (
Expectations,
cleanup,
require_bitsandbytes,
require_deterministic_for_xpu,
require_read_token,
require_torch,
require_torch_gpu,
require_torch_accelerator,
slow,
torch_device,
)
@@ -46,10 +48,6 @@
)


if is_bitsandbytes_available():
from transformers import BitsAndBytesConfig


class Mistral3VisionText2TextModelTester:
def __init__(
self,
@@ -292,20 +290,23 @@ def test_flex_attention_with_grads(self):


@slow
@require_torch_gpu
@require_torch_accelerator
class Mistral3IntegrationTest(unittest.TestCase):
@require_read_token
def setUp(self):
cleanup(torch_device, gc_collect=True)
self.model_checkpoint = "mistralai/Mistral-Small-3.1-24B-Instruct-2503"
self.model = Mistral3ForConditionalGeneration.from_pretrained(
self.model_checkpoint, torch_dtype=torch.bfloat16
)
accelerate.cpu_offload(self.model, execution_device=torch_device)

def tearDown(self):
cleanup(torch_device, gc_collect=True)

@require_read_token
def test_mistral3_integration_generate_text_only(self):
processor = AutoProcessor.from_pretrained(self.model_checkpoint)
model = Mistral3ForConditionalGeneration.from_pretrained(
self.model_checkpoint, device_map=torch_device, torch_dtype=torch.bfloat16
)

messages = [
{
@@ -321,19 +322,23 @@ def test_mistral3_integration_generate_text_only(self):
).to(torch_device, dtype=torch.bfloat16)

with torch.no_grad():
generate_ids = model.generate(**inputs, max_new_tokens=200, do_sample=False)
generate_ids = self.model.generate(**inputs, max_new_tokens=200, do_sample=False)
decoded_output = processor.decode(
generate_ids[0, inputs["input_ids"].shape[1] :], skip_special_tokens=True
)
expected_output = "Sure, here's a haiku for you:\n\nWhispers of the breeze,\nCherry blossoms softly fall,\nSpring's gentle embrace."
expected_outputs = Expectations(
{
("xpu", 3): "Sure, here is a haiku for you:\n\nWhispers of the breeze,\nCherry blossoms softly fall,\nSpring's gentle embrace.",
("cuda", 7): "Sure, here is a haiku for you:\n\nWhispers of the breeze,\nCherry blossoms softly fall,\nSpring's gentle embrace.",
("cuda", 8): "Sure, here is a haiku for you:\n\nWhispers of the breeze,\nCherry blossoms softly fall,\nSpring's gentle embrace.",
}
) # fmt: skip
expected_output = expected_outputs.get_expectation()
self.assertEqual(decoded_output, expected_output)

@require_read_token
def test_mistral3_integration_generate(self):
processor = AutoProcessor.from_pretrained(self.model_checkpoint)
model = Mistral3ForConditionalGeneration.from_pretrained(
self.model_checkpoint, device_map=torch_device, torch_dtype=torch.bfloat16
)
messages = [
{
"role": "user",
@@ -348,25 +353,32 @@ def test_mistral3_integration_generate(self):
messages, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt"
).to(torch_device, dtype=torch.bfloat16)
with torch.no_grad():
generate_ids = model.generate(**inputs, max_new_tokens=20, do_sample=False)
generate_ids = self.model.generate(**inputs, max_new_tokens=20, do_sample=False)
decoded_output = processor.decode(
generate_ids[0, inputs["input_ids"].shape[1] :], skip_special_tokens=True
)
expected_output = "The image depicts two cats lying on a pink blanket. The larger cat, which appears to be an"

expected_outputs = Expectations(
{
("xpu", 3): "The image features two cats resting on a pink blanket. The cat on the left is a kitten",
("cuda", 7): "The image features two cats resting on a pink blanket. The cat on the left is a kitten",
("cuda", 8): "The image features two cats resting on a pink blanket. The cat on the left is a small kit",
}
) # fmt: skip
expected_output = expected_outputs.get_expectation()

self.assertEqual(decoded_output, expected_output)

@require_read_token
@require_deterministic_for_xpu
def test_mistral3_integration_batched_generate(self):
processor = AutoProcessor.from_pretrained(self.model_checkpoint)
model = Mistral3ForConditionalGeneration.from_pretrained(
self.model_checkpoint, device_map=torch_device, torch_dtype=torch.bfloat16
)
messages = [
[
{
"role": "user",
"content": [
{"type": "image", "url": "https://llava-vl.github.io/static/images/view.jpg"},
{"type": "image", "url": "https://huggingface.co/ydshieh/kosmos-2.5/resolve/main/view.jpg"},
{"type": "text", "text": "Write a haiku for this image"},
],
},
@@ -384,44 +396,57 @@ def test_mistral3_integration_batched_generate(self):

inputs = processor.apply_chat_template(
messages, padding=True, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt"
).to(model.device, dtype=torch.bfloat16)
).to(torch_device, dtype=torch.bfloat16)

output = self.model.generate(**inputs, do_sample=False, max_new_tokens=25)

output = model.generate(**inputs, do_sample=False, max_new_tokens=25)
gen_tokens = output[:, inputs["input_ids"].shape[1] :]

# Check first output
decoded_output = processor.decode(output[0], skip_special_tokens=True)
expected_output = "Write a haiku for this imageSure, here is a haiku inspired by the image:\n\nCalm lake's mirror gleams,\nWhispering pines"
decoded_output = processor.decode(gen_tokens[0], skip_special_tokens=True)

expected_outputs = Expectations(
{
("xpu", 3): "Calm lake's mirror gleams,\nWhispering pines stand in silence,\nPath to peace begins.",
("cuda", 7): "Calm waters reflect\nWhispering pines stand in silence\nPath to peace begins",
("cuda", 8): "Calm waters reflect\nWhispering pines stand in silence\nPath to peace begins",
}
) # fmt: skip
expected_output = expected_outputs.get_expectation()
self.assertEqual(
decoded_output,
expected_output,
f"Decoded output: {decoded_output}\nExpected output: {expected_output}",
)

# Check second output
decoded_output = processor.decode(output[1], skip_special_tokens=True)
expected_output = "Describe this imageThe image depicts a vibrant street scene in what appears to be a Chinatown district. The focal point is a traditional Chinese"
decoded_output = processor.decode(gen_tokens[1], skip_special_tokens=True)
expected_outputs = Expectations(
{
("xpu", 3): "The image depicts a vibrant urban scene in what appears to be Chinatown. The focal point is a traditional Chinese archway",
("cuda", 7): 'The image depicts a vibrant street scene in Chinatown, likely in a major city. The focal point is a traditional Chinese',
("cuda", 8): 'The image depicts a vibrant street scene in what appears to be Chinatown in a major city. The focal point is a',
}
) # fmt: skip
expected_output = expected_outputs.get_expectation()
self.assertEqual(
decoded_output,
expected_output,
f"Decoded output: {decoded_output}\nExpected output: {expected_output}",
)

@require_read_token
@require_bitsandbytes
@require_deterministic_for_xpu
def test_mistral3_integration_batched_generate_multi_image(self):
processor = AutoProcessor.from_pretrained(self.model_checkpoint)
quantization_config = BitsAndBytesConfig(load_in_4bit=True)
model = Mistral3ForConditionalGeneration.from_pretrained(
self.model_checkpoint, quantization_config=quantization_config
)

# Prepare inputs
messages = [
[
{
"role": "user",
"content": [
{"type": "image", "url": "https://llava-vl.github.io/static/images/view.jpg"},
{"type": "image", "url": "https://huggingface.co/ydshieh/kosmos-2.5/resolve/main/view.jpg"},
{"type": "text", "text": "Write a haiku for this image"},
],
},
@@ -432,11 +457,11 @@ def test_mistral3_integration_batched_generate_multi_image(self):
"content": [
{
"type": "image",
"url": "https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg",
"url": "https://huggingface.co/ydshieh/kosmos-2.5/resolve/main/Statue-of-Liberty-Island-New-York-Bay.jpg",
},
{
"type": "image",
"url": "https://thumbs.dreamstime.com/b/golden-gate-bridge-san-francisco-purple-flowers-california-echium-candicans-36805947.jpg",
"url": "https://huggingface.co/ydshieh/kosmos-2.5/resolve/main/golden-gate-bridge-san-francisco-purple-flowers-california-echium-candicans-36805947.jpg",
},
{
"type": "text",
@@ -448,22 +473,38 @@ def test_mistral3_integration_batched_generate_multi_image(self):
]
inputs = processor.apply_chat_template(
messages, padding=True, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt"
).to(model.device, dtype=torch.float16)
).to(torch_device, dtype=torch.bfloat16)

output = model.generate(**inputs, do_sample=False, max_new_tokens=25)
output = self.model.generate(**inputs, do_sample=False, max_new_tokens=25)
gen_tokens = output[:, inputs["input_ids"].shape[1] :]

# Check first output
decoded_output = processor.decode(output[0], skip_special_tokens=True)
expected_output = "Write a haiku for this imageSure, here is a haiku inspired by the image:\n\nCalm lake's wooden path\nSilent forest stands guard\n"
decoded_output = processor.decode(gen_tokens[0], skip_special_tokens=True)
expected_outputs = Expectations(
{
("xpu", 3): "Still lake reflects skies,\nWooden path to nature's heart,\nSilence speaks volumes.",
("cuda", 7): "Calm waters reflect\nWhispering pines stand in silence\nPath to peace begins",
("cuda", 8): "Calm waters reflect\nWhispering pines stand in silence\nPath to peace begins",
}
) # fmt: skip
expected_output = expected_outputs.get_expectation()
self.assertEqual(
decoded_output,
expected_output,
f"Decoded output: {decoded_output}\nExpected output: {expected_output}",
)

# Check second output
decoded_output = processor.decode(output[1], skip_special_tokens=True)
expected_output = "These images depict two different landmarks. Can you identify them?Certainly! The images depict two iconic landmarks:\n\n1. The first image shows the Statue of Liberty in New York City."
decoded_output = processor.decode(gen_tokens[1], skip_special_tokens=True)
expected_outputs = Expectations(
{
("xpu", 3): "Certainly! The images depict two iconic landmarks:\n\n1. The first image shows the Statue of Liberty in New York City.",
("cuda", 7): "Certainly! The images depict the following landmarks:\n\n1. The first image shows the Statue of Liberty and the New York City",
("cuda", 8): "Certainly! The images depict the following landmarks:\n\n1. The first image shows the Statue of Liberty and the New York City",
}
) # fmt: skip
expected_output = expected_outputs.get_expectation()

self.assertEqual(
decoded_output,
expected_output,