diff --git a/tests/models/phimoe/test_modeling_phimoe.py b/tests/models/phimoe/test_modeling_phimoe.py
index 46714244a14b..ac6fa3c2672a 100644
--- a/tests/models/phimoe/test_modeling_phimoe.py
+++ b/tests/models/phimoe/test_modeling_phimoe.py
@@ -14,12 +14,14 @@
 """Testing suite for the PyTorch PhiMoE model."""
 
+import copy
 import unittest
 
 from parameterized import parameterized
 
 from transformers import PhimoeConfig, StaticCache, is_torch_available
 from transformers.testing_utils import (
+    cleanup,
     require_torch,
     slow,
     torch_device,
 )
@@ -130,31 +132,47 @@ def test_model_rope_scaling_from_config(self, scaling_type):
 
 @slow
 @require_torch
 class PhimoeIntegrationTest(unittest.TestCase):
-    def test_model_phimoe_instruct_logits(self):
-        input_ids = {
-            "input_ids": torch.tensor(
-                [[1212, 318, 281, 1672, 2643, 290, 428, 318, 257, 1332]], dtype=torch.long, device=torch_device
+    model = None
+
+    @classmethod
+    def get_model(cls):
+        if cls.model is None:
+            cls.model = PhimoeForCausalLM.from_pretrained(
+                "microsoft/Phi-3.5-MoE-instruct", dtype="auto", device_map="auto"
             )
-        }
+        return cls.model
+
+    @classmethod
+    def tearDownClass(cls):
+        del cls.model
+        cleanup(torch_device, gc_collect=True)
+
+    def setUp(self):
+        cleanup(torch_device, gc_collect=True)
+
+    def tearDown(self):
+        cleanup(torch_device, gc_collect=True)
 
-        model = PhimoeForCausalLM.from_pretrained("microsoft/Phi-3.5-MoE-instruct").to(torch_device)
+    def test_model_phimoe_instruct_logits(self):
+        input_ids = {"input_ids": torch.tensor([[1212, 318, 281, 1672]], dtype=torch.long, device=torch_device)}
+
+        model = self.get_model()
         model.eval()
-        output = model(**input_ids).logits
+        with torch.no_grad():
+            output = model(**input_ids).logits
 
-        EXPECTED_OUTPUT = torch.tensor([[-3.5312, -2.5000, -1.2734, 0.3555, -0.7578, -0.4727, 0.5977, -0.4316,
-                                         0.2256, -1.2188, -1.6797, 0.9961, 3.7656, 11.3125, -1.3828, -4.8438,
-                                         -5.7500, -1.9375, 0.7227, -0.3438, -0.2100, -0.4277, -0.0444, -0.5352,
-                                         -0.6406, -0.1016, -0.4258, -1.0234, 0.4297, -0.6250],
-                                        [-0.9883, 0.1455, -0.4902, 2.3594, 0.7031, 3.1406, 0.4375, 0.2559,
-                                         0.6172, -2.1094, -1.3359, 2.5938, 4.9062, 10.8125, -0.1094, 1.5781,
-                                         -4.9375, 0.7148, -0.0972, 1.7656, -0.0801, 0.2217, 0.1875, -0.4629,
-                                         1.5781, 0.3535, 0.0874, 0.6836, -0.0518, -1.2969]]).to(torch_device)  # fmt: skip
+        EXPECTED_OUTPUT = torch.tensor(
+            [
+                [-3.4844, -2.4531, -1.1719, 0.6055, -0.4922, -0.1001, 0.8086, -0.2422, 0.3477, -1.0078],
+                [-0.9766, 0.1631, -0.5508, 2.3594, 0.7031, 3.1719, 0.4141, 0.2305, 0.6055, -2.1250],
+            ]
+        ).to(device=torch_device, dtype=output.dtype)  # fmt: skip
 
-        torch.testing.assert_close(EXPECTED_OUTPUT, output[0, :2, :30], rtol=1e-4, atol=1e-4)
+        torch.testing.assert_close(output[0, :2, :10], EXPECTED_OUTPUT, rtol=1e-4, atol=1e-4)
 
     def test_phimoe_instruct_generation(self):
-        model = PhimoeForCausalLM.from_pretrained("microsoft/Phi-3.5-MoE-instruct")
+        model = self.get_model()
         tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3.5-MoE-instruct")
 
         messages = [
@@ -166,17 +184,29 @@ def test_phimoe_instruct_generation(self):
         ]
         inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt")
 
-        outputs = model.generate(inputs, max_new_tokens=32)
+        outputs = model.generate(inputs, max_new_tokens=30)
         output_text = tokenizer.batch_decode(outputs)
 
         EXPECTED_OUTPUT = [
-            "<|system|> You are a helpful digital assistant. Please provide safe, ethical and accurate information to the user.<|end|><|user|> Can you provide ways to eat combinations of bananas and dragonfruits?<|end|><|assistant|> Certainly! Bananas and dragonfruits are both delicious and nutritious fruits that can be combined in various ways to create tast"
+            "<|system|> You are a helpful digital assistant. Please provide safe, ethical and accurate information to the user.<|end|><|user|> Can you provide ways to eat combinations of bananas and dragonfruits?<|end|><|assistant|> Certainly! Bananas and dragonfruits are both delicious and nutritious fruits that can be combined in various ways to create",
         ]
-
         self.assertListEqual(output_text, EXPECTED_OUTPUT)
 
     def test_phimoe_instruct_with_static_cache(self):
-        model = PhimoeForCausalLM.from_pretrained("microsoft/Phi-3.5-MoE-instruct")
+        model = self.get_model()
+        # Can't run with the real checkpoint, even if offloaded. Let's just use a tiny dummy one.
+        config = copy.deepcopy(model.config)
+        config.num_hidden_layers = 2
+        # make `head_dim = 128`
+        config.hidden_size = 512
+        config.num_attention_heads = 4
+        config.num_key_value_heads = 1
+        config.intermediate_size = 512
+        config.max_position_embeddings = 64
+        config.num_local_experts = 4
+        torch.manual_seed(42)
+        model = PhimoeForCausalLM(config).to(torch_device)
+        model.eval()
         tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3.5-MoE-instruct")
 
         messages = [
@@ -186,14 +216,17 @@ def test_phimoe_instruct_with_static_cache(self):
             },
             {"role": "user", "content": "Can you provide ways to eat combinations of bananas and dragonfruits?"},
         ]
-        inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt")
+        inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt").to(
+            torch_device
+        )
 
-        response_tokens = PhimoeMiniWithStaticCache.generate(model, inputs, 64)
+        response_tokens = PhimoeMiniWithStaticCache.generate(model, inputs, max_seq_len=30)
 
         output_text = tokenizer.batch_decode(torch.tensor([response_tokens], dtype=torch.long, device=torch_device))
 
+        # These are dummy outputs; we only check that generation runs with a static cache, not the output quality.
         EXPECTED_OUTPUT = [
-            "<|system|> You are a helpful digital assistant. Please provide safe, ethical and accurate information to the user.<|end|><|user|> Can you provide ways to eat combinations of bananas and dragonfruits?<|end|><|assistant|> Certainly! Bananas and dragonfruits are both delicious and nutritious fruits that can"
+            "<|system|> You are a helpful digital assistant. Please provide safe, ethical and accurate information to the user.<|end|><|user|> Can you provide ways to eat combinations of bananas and dragonfruits?<|end|><|assistant|> awards"
         ]
 
         self.assertListEqual(output_text, EXPECTED_OUTPUT)
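
For reviewers, a minimal sketch of how the updated tests might be exercised locally. It assumes a dev install of transformers with pytest available and enough memory to load microsoft/Phi-3.5-MoE-instruct through device_map="auto"; RUN_SLOW is the standard transformers gate for @slow tests, while the -k selector below is only illustrative:

    # Enable @slow-gated tests, then run just the PhiMoE integration class.
    # RUN_SLOW must be set before transformers.testing_utils is imported,
    # because the @slow decorator reads the flag at import time.
    import os

    os.environ["RUN_SLOW"] = "1"

    import pytest

    pytest.main(["tests/models/phimoe/test_modeling_phimoe.py", "-k", "PhimoeIntegrationTest", "-v"])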