diff --git a/src/transformers/models/cohere/modeling_cohere.py b/src/transformers/models/cohere/modeling_cohere.py index 41bb4c051692..220d18dddfb9 100644 --- a/src/transformers/models/cohere/modeling_cohere.py +++ b/src/transformers/models/cohere/modeling_cohere.py @@ -1003,38 +1003,34 @@ def _update_causal_mask( dtype, device = input_tensor.dtype, input_tensor.device min_dtype = torch.finfo(dtype).min sequence_length = input_tensor.shape[1] - if hasattr(getattr(self.layers[0], "self_attn", {}), "past_key_value"): # static cache - target_length = self.config.max_position_embeddings - else: # dynamic cache - target_length = ( - attention_mask.shape[-1] - if isinstance(attention_mask, torch.Tensor) - else past_seen_tokens + sequence_length + 1 - ) - causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device) - if sequence_length != 1: - causal_mask = torch.triu(causal_mask, diagonal=1) - causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) - causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1) - if attention_mask is not None: - causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit - if attention_mask.dim() == 2: + if attention_mask is not None and attention_mask.dim() == 4: + # we can pass both the full 4D mask (i.e. [..., full_len, full_len]) and a 4D mask with the same shape + # as the causal mask (i.e. [..., seq_len, full_len]) + mask_slice = (attention_mask.eq(0.0)).to(dtype=dtype) * min_dtype + offset = cache_position[0] + if attention_mask.shape[-2] == offset + sequence_length: + mask_slice = mask_slice[..., offset:, :] + causal_mask = mask_slice + else: + if hasattr(self.layers[0].self_attn, "past_key_value"): # static cache + target_length = self.config.max_position_embeddings + else: # dynamic cache + target_length = ( + attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else cache_position[-1] + 1 + ) + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit mask_length = attention_mask.shape[-1] padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0) causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype) - elif attention_mask.dim() == 4: - # backwards compatibility: we allow passing a 4D attention mask shorter than the input length with - # cache. In that case, the 4D attention mask attends to the newest tokens only. - if attention_mask.shape[-2] < cache_position[0] + sequence_length: - offset = cache_position[0] - else: - offset = 0 - mask_shape = attention_mask.shape - mask_slice = (attention_mask.eq(0.0)).to(dtype=dtype) * min_dtype - causal_mask[ - : mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3] - ] = mask_slice if ( self.config._attn_implementation == "sdpa" diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py index e5b6b207748a..60b5ff112039 100644 --- a/src/transformers/models/gemma/modeling_gemma.py +++ b/src/transformers/models/gemma/modeling_gemma.py @@ -989,38 +989,34 @@ def _update_causal_mask( dtype, device = input_tensor.dtype, input_tensor.device min_dtype = torch.finfo(dtype).min sequence_length = input_tensor.shape[1] - if hasattr(getattr(self.layers[0], "self_attn", {}), "past_key_value"): # static cache - target_length = self.config.max_position_embeddings - else: # dynamic cache - target_length = ( - attention_mask.shape[-1] - if isinstance(attention_mask, torch.Tensor) - else past_seen_tokens + sequence_length + 1 - ) - causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device) - if sequence_length != 1: - causal_mask = torch.triu(causal_mask, diagonal=1) - causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) - causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1) - if attention_mask is not None: - causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit - if attention_mask.dim() == 2: + if attention_mask is not None and attention_mask.dim() == 4: + # we can pass both the full 4D mask (i.e. [..., full_len, full_len]) and a 4D mask with the same shape + # as the causal mask (i.e. [..., seq_len, full_len]) + mask_slice = (attention_mask.eq(0.0)).to(dtype=dtype) * min_dtype + offset = cache_position[0] + if attention_mask.shape[-2] == offset + sequence_length: + mask_slice = mask_slice[..., offset:, :] + causal_mask = mask_slice + else: + if hasattr(self.layers[0].self_attn, "past_key_value"): # static cache + target_length = self.config.max_position_embeddings + else: # dynamic cache + target_length = ( + attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else cache_position[-1] + 1 + ) + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit mask_length = attention_mask.shape[-1] padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0) causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype) - elif attention_mask.dim() == 4: - # backwards compatibility: we allow passing a 4D attention mask shorter than the input length with - # cache. In that case, the 4D attention mask attends to the newest tokens only. - if attention_mask.shape[-2] < cache_position[0] + sequence_length: - offset = cache_position[0] - else: - offset = 0 - mask_shape = attention_mask.shape - mask_slice = (attention_mask.eq(0.0)).to(dtype=dtype) * min_dtype - causal_mask[ - : mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3] - ] = mask_slice if ( self.config._attn_implementation == "sdpa" diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 905edf5f71a6..cb4aac11bffd 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -1081,38 +1081,34 @@ def _update_causal_mask( dtype, device = input_tensor.dtype, input_tensor.device min_dtype = torch.finfo(dtype).min sequence_length = input_tensor.shape[1] - if hasattr(getattr(self.layers[0], "self_attn", {}), "past_key_value"): # static cache - target_length = self.config.max_position_embeddings - else: # dynamic cache - target_length = ( - attention_mask.shape[-1] - if isinstance(attention_mask, torch.Tensor) - else past_seen_tokens + sequence_length + 1 - ) - causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device) - if sequence_length != 1: - causal_mask = torch.triu(causal_mask, diagonal=1) - causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) - causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1) - if attention_mask is not None: - causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit - if attention_mask.dim() == 2: + if attention_mask is not None and attention_mask.dim() == 4: + # we can pass both the full 4D mask (i.e. [..., full_len, full_len]) and a 4D mask with the same shape + # as the causal mask (i.e. [..., seq_len, full_len]) + mask_slice = (attention_mask.eq(0.0)).to(dtype=dtype) * min_dtype + offset = cache_position[0] + if attention_mask.shape[-2] == offset + sequence_length: + mask_slice = mask_slice[..., offset:, :] + causal_mask = mask_slice + else: + if hasattr(self.layers[0].self_attn, "past_key_value"): # static cache + target_length = self.config.max_position_embeddings + else: # dynamic cache + target_length = ( + attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else cache_position[-1] + 1 + ) + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit mask_length = attention_mask.shape[-1] padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0) causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype) - elif attention_mask.dim() == 4: - # backwards compatibility: we allow passing a 4D attention mask shorter than the input length with - # cache. In that case, the 4D attention mask attends to the newest tokens only. - if attention_mask.shape[-2] < cache_position[0] + sequence_length: - offset = cache_position[0] - else: - offset = 0 - mask_shape = attention_mask.shape - mask_slice = (attention_mask.eq(0.0)).to(dtype=dtype) * min_dtype - causal_mask[ - : mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3] - ] = mask_slice if ( self.config._attn_implementation == "sdpa" diff --git a/tests/models/llama/test_modeling_llama.py b/tests/models/llama/test_modeling_llama.py index dc24fd848c81..84c97c6c5211 100644 --- a/tests/models/llama/test_modeling_llama.py +++ b/tests/models/llama/test_modeling_llama.py @@ -14,6 +14,7 @@ # limitations under the License. """ Testing suite for the PyTorch LLaMA model. """ +import gc import tempfile import unittest @@ -821,3 +822,138 @@ def test_model_7b_logits(self): ] infilling = tokenizer.batch_decode(generated_ids) self.assertEqual(infilling, EXPECTED_INFILLING) + + +@slow +@require_torch_gpu +class Mask4DTestHard(unittest.TestCase): + def tearDown(self): + gc.collect() + torch.cuda.empty_cache() + + def setUp(self): + model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0" + self.model_dtype = torch.float32 + self.tokenizer = LlamaTokenizer.from_pretrained(model_name) + self.model = LlamaForCausalLM.from_pretrained(model_name, torch_dtype=self.model_dtype).to(torch_device) + + def get_test_data(self): + template = "my favorite {}" + items = ("pet is a", "artist plays a", "name is L") # same number of tokens in each item + + batch_separate = [template.format(x) for x in items] # 3 separate lines + batch_shared_prefix = template.format(" ".join(items)) # 1 line with options concatenated + + input_ids = self.tokenizer(batch_separate, return_tensors="pt").input_ids.to(torch_device) + input_ids_shared_prefix = self.tokenizer(batch_shared_prefix, return_tensors="pt").input_ids.to(torch_device) + + mask_shared_prefix = torch.tensor( + [ + [ + [ + [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0], + [1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0], + [1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0], + [1, 1, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0], + [1, 1, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0], + [1, 1, 1, 0, 0, 0, 1, 1, 1, 0, 0, 0], + [1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0], + [1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0], + [1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1], + ] + ] + ], + device=torch_device, + dtype=torch.int64, + ) + + position_ids = torch.arange(input_ids.shape[1]).tile(input_ids.shape[0], 1).to(torch_device) + # equivalent: position_ids_1 = torch.tensor([[0, 1, 2, 3, 4, 5, 3, 4, 5, 3, 4, 5]]).to(device) + position_ids_shared_prefix = (mask_shared_prefix.sum(dim=-1) - 1).reshape(1, -1) # same but nicer + + return input_ids, position_ids, input_ids_shared_prefix, mask_shared_prefix, position_ids_shared_prefix + + def test_stacked_causal_mask(self): + ( + input_ids, + position_ids, + input_ids_shared_prefix, + mask_shared_prefix, + position_ids_shared_prefix, + ) = self.get_test_data() + + # regular batch + logits = self.model.forward(input_ids, position_ids=position_ids).logits + logits_last = logits[:, -1, :] # last tokens in each batch line + decoded = [self.tokenizer.decode(t) for t in logits_last.argmax(dim=-1)] + + # single forward run with 4D custom mask + logits_shared_prefix = self.model.forward( + input_ids_shared_prefix, attention_mask=mask_shared_prefix.bool(), position_ids=position_ids_shared_prefix + ).logits + logits_shared_prefix_last = logits_shared_prefix[ + 0, torch.where(position_ids_shared_prefix == position_ids_shared_prefix.max())[1], : + ] # last three tokens + decoded_shared_prefix = [self.tokenizer.decode(t) for t in logits_shared_prefix_last.argmax(dim=-1)] + + self.assertEqual(decoded, decoded_shared_prefix) + + def test_partial_stacked_causal_mask(self): + # Same as the test above, but the input is passed in two groups. It tests that we can pass partial 4D attention + # masks + + ( + input_ids, + position_ids, + input_ids_shared_prefix, + mask_shared_prefix, + position_ids_shared_prefix, + ) = self.get_test_data() + + # regular batch + logits = self.model.forward(input_ids, position_ids=position_ids).logits + logits_last = logits[:, -1, :] # last tokens in each batch line + decoded = [self.tokenizer.decode(t) for t in logits_last.argmax(dim=-1)] + + # 2 forward runs with custom 4D masks + part_a = 3 # split point + + input_1a = input_ids_shared_prefix[:, :part_a] + position_ids_1a = position_ids_shared_prefix[:, :part_a] + mask_1a = mask_shared_prefix[:, :, :part_a, :part_a] + + outs_1a = self.model.forward(input_1a, attention_mask=mask_1a.bool(), position_ids=position_ids_1a) + past_key_values_a = outs_1a["past_key_values"] + + # Case 1: we pass a 4D attention mask regarding the current sequence length (i.e. [..., seq_len, full_len]) + input_1b = input_ids_shared_prefix[:, part_a:] + position_ids_1b = position_ids_shared_prefix[:, part_a:] + mask_1b = mask_shared_prefix[:, :, part_a:, :] + outs_1b = self.model.forward( + input_1b, attention_mask=mask_1b.bool(), position_ids=position_ids_1b, past_key_values=past_key_values_a + ) + decoded_1b = [ + self.tokenizer.decode(t) + for t in outs_1b.logits.argmax(-1)[ + 0, torch.where(position_ids_shared_prefix == position_ids_shared_prefix.max())[1] - part_a + ] + ] + self.assertEqual(decoded, decoded_1b) + + # Case 2: we pass a 4D attention mask regarding the full sequence length (i.e. [..., full_len, full_len]) + input_1c = input_ids_shared_prefix[:, part_a:] + position_ids_1c = position_ids_shared_prefix[:, part_a:] + mask_1c = mask_shared_prefix + outs_1c = self.model.forward( + input_1c, attention_mask=mask_1c.bool(), position_ids=position_ids_1c, past_key_values=past_key_values_a + ) + decoded_1c = [ + self.tokenizer.decode(t) + for t in outs_1c.logits.argmax(-1)[ + 0, torch.where(position_ids_shared_prefix == position_ids_shared_prefix.max())[1] - part_a + ] + ] + self.assertEqual(decoded, decoded_1c) diff --git a/tests/models/mixtral/test_modeling_mixtral.py b/tests/models/mixtral/test_modeling_mixtral.py index 0d92595d8cfa..49d6affd61e0 100644 --- a/tests/models/mixtral/test_modeling_mixtral.py +++ b/tests/models/mixtral/test_modeling_mixtral.py @@ -505,6 +505,12 @@ def test_load_balancing_loss(self): # This is to mimic torch.testing.assert_not_close self.assertNotAlmostEqual(include_padding_result.aux_loss.item(), result.aux_loss.item()) + # TODO: fix me + @unittest.skip("Test is failing on Mixtral, needs to be fixed") + # Ignore copy + def test_custom_4d_attention_mask_logits(self): + pass + @require_torch class MixtralIntegrationTest(unittest.TestCase): diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 1c099a4035b4..486d83252aef 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -4132,6 +4132,101 @@ def test_flash_attn_2_from_config(self): self.assertFalse(fa2_correctly_converted) + def _get_custom_4d_mask_test_data(self): + # Sequence in which all but the last token is the same + input_ids = torch.tensor( + [[10, 11, 12, 13], [10, 11, 12, 14], [10, 11, 12, 15]], device=torch_device, dtype=torch.int64 + ) + position_ids = torch.tensor([[0, 1, 2, 3]] * 3, device=torch_device, dtype=torch.int64) + + # Combining common prefix with the unique ending tokens: + input_ids_shared_prefix = torch.cat([input_ids[0][:-1], input_ids[:, -1]]).unsqueeze(0) + + # Creating a 4D mask where each of the last 3 tokens do not attend to each other. + mask_shared_prefix = torch.tensor( + [ + [ + [ + [1, 0, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0], + [1, 1, 1, 0, 0, 0], + [1, 1, 1, 1, 0, 0], + [1, 1, 1, 0, 1, 0], + [1, 1, 1, 0, 0, 1], + ] + ] + ], + device=torch_device, + dtype=torch.int64, + ) + + # Creating a position_ids tensor. note the repeating figures in the end. + position_ids_shared_prefix = torch.tensor([[0, 1, 2, 3, 3, 3]], device=torch_device, dtype=torch.int64) + + return input_ids, position_ids, input_ids_shared_prefix, mask_shared_prefix, position_ids_shared_prefix + + def test_custom_4d_attention_mask(self): + if len(self.all_generative_model_classes) == 0: + self.skipTest("Model architecture has no generative classes, and thus not necessarily supporting 4D masks") + + for model_class in self.all_generative_model_classes: + if not model_class._supports_cache_class: + self.skipTest(f"{model_class.__name__} is not guaranteed to work with custom 4D attention masks") + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + model = model_class(config).to(device=torch_device, dtype=torch.float32) + + ( + input_ids, + position_ids, + input_ids_shared_prefix, + mask_shared_prefix, + position_ids_shared_prefix, + ) = self._get_custom_4d_mask_test_data() + causal_mask_shared_prefix = (1 - mask_shared_prefix).to(model.dtype) * torch.finfo(model.dtype).min + + input_embeds = model.model.embed_tokens(input_ids) + model_output = model.model.layers[0].self_attn.forward(input_embeds, position_ids=position_ids)[0] + # model_output.shape == torch.Size([3, 4, ...]) + + input_embeds_shared_prefix = model.model.embed_tokens(input_ids_shared_prefix) + model_output_shared_prefix = model.model.layers[0].self_attn.forward( + input_embeds_shared_prefix, + attention_mask=causal_mask_shared_prefix, + position_ids=position_ids_shared_prefix, + )[0] + # model_output_shared_prefix.shape == torch.Size([1, 6, ...]) + + out_last_tokens = model_output[:, -1, :] # last tokens in each batch line + out_shared_prefix_last_tokens = model_output_shared_prefix[0, -3:, :] # last three tokens + torch.testing.assert_close(out_last_tokens, out_shared_prefix_last_tokens) + + def test_custom_4d_attention_mask_logits(self): + if len(self.all_generative_model_classes) == 0: + self.skipTest("Model architecture has no generative classes, and thus not necessarily supporting 4D masks") + + for model_class in self.all_generative_model_classes: + if not model_class._supports_cache_class: + self.skipTest(f"{model_class.__name__} is not guaranteed to work with custom 4D attention masks") + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + model = model_class(config).to(device=torch_device, dtype=torch.float32) + + ( + input_ids, + position_ids, + input_ids_shared_prefix, + mask_shared_prefix, + position_ids_shared_prefix, + ) = self._get_custom_4d_mask_test_data() + + logits = model.forward(input_ids, position_ids=position_ids).logits + logits_shared_prefix = model.forward( + input_ids_shared_prefix, attention_mask=mask_shared_prefix, position_ids=position_ids_shared_prefix + ).logits + + logits_last_tokens = logits[:, -1, :] # last tokens in each batch line + logits_shared_prefix_last_tokens = logits_shared_prefix[0, -3:, :] # last three tokens + torch.testing.assert_close(logits_last_tokens, logits_shared_prefix_last_tokens) + global_rng = random.Random() diff --git a/tests/test_modeling_utils.py b/tests/test_modeling_utils.py index 37ae919a448c..c0fa7e43a817 100755 --- a/tests/test_modeling_utils.py +++ b/tests/test_modeling_utils.py @@ -13,7 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. import copy -import gc import glob import json import os @@ -2058,230 +2057,6 @@ def test_not_available_sdpa(self): self.assertTrue("PyTorch SDPA requirements in Transformers are not met" in str(cm.exception)) - -@require_torch_gpu -class Mask4DTestBase(unittest.TestCase): - def tearDown(self): - gc.collect() - torch.cuda.empty_cache() - - def get_test_data(self): - texts = ["the cat sat", "the cat had", "the cat is"] - encoded = [self.tokenizer.encode(t) for t in texts] - input_0 = torch.tensor(encoded, device=torch_device) - # tensor([[ 1, 278, 6635, 3290], - # [ 1, 278, 6635, 750], - # [ 1, 278, 6635, 338]], device='cuda:0') - - position_ids_0 = torch.tensor([[0, 1, 2, 3]] * 3, device=torch_device, dtype=torch.int64) - - # Combining common prefix with the unique ending tokens: - input_1 = torch.cat([input_0[0][:-1], input_0[:, -1]]).unsqueeze(0) - # tensor([[ 1, 278, 6635, 3290, 750, 338]], device='cuda:0') - - # Creating a 4D mask where each of the last 3 tokens do not attend to each other. - mask_1 = torch.tensor( - [ - [ - [ - [1, 0, 0, 0, 0, 0], - [1, 1, 0, 0, 0, 0], - [1, 1, 1, 0, 0, 0], - [1, 1, 1, 1, 0, 0], - [1, 1, 1, 0, 1, 0], - [1, 1, 1, 0, 0, 1], - ] - ] - ], - device="cuda:0", - dtype=torch.int64, - ) - - # Creating a position_ids tensor. note the repeating figures in the end. - position_ids_1 = torch.tensor([[0, 1, 2, 3, 3, 3]], device=torch_device, dtype=torch.int64) - - return input_0, position_ids_0, input_1, mask_1, position_ids_1 - - -@require_torch_gpu -class Mask4DTestFP32(Mask4DTestBase): - def setUp(self): - model_name = "JackFram/llama-68m" # small Llama-like model from FlexFlow - self.model_dtype = torch.float32 - self.tokenizer = AutoTokenizer.from_pretrained(model_name) - self.model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=self.model_dtype).to(torch_device) - - def test_attention(self): - """comparing outputs of attention layer""" - # Input 0: one row per sentence; Input 1: same data, but stacked into a single row with custom attention - input_0, position_ids_0, input_1, mask_1, position_ids_1 = self.get_test_data() - causal_mask_1 = (1 - mask_1).to(self.model_dtype) * torch.finfo(self.model_dtype).min - - hid_0 = self.model.model.embed_tokens(input_0) - outs_0 = self.model.model.layers[0].self_attn.forward(hid_0, position_ids=position_ids_0)[0] - # outs_0.shape == torch.Size([3, 4, 768]) - - hid_1 = self.model.model.embed_tokens(input_1) - outs_1 = self.model.model.layers[0].self_attn.forward( - hid_1, attention_mask=causal_mask_1, position_ids=position_ids_1 - )[0] - # outs_1.shape == torch.Size([1, 6, 768]) - - outs_0_last_tokens = outs_0[:, -1, :] # last tokens in each batch line - outs_1_last_tokens = outs_1[0, -3:, :] # last three tokens - torch.testing.assert_close(outs_0_last_tokens, outs_1_last_tokens) - - def test_causal_model_logits(self): - """comparing logits outputs of whole inner model""" - # Input 0: one row per sentence; Input 1: same data, but stacked into a single row with custom attention - input_0, position_ids_0, input_1, mask_1, position_ids_1 = self.get_test_data() - - logits_0 = self.model.forward(input_0, position_ids=position_ids_0).logits - logits_1 = self.model.forward(input_1, attention_mask=mask_1.bool(), position_ids=position_ids_1).logits - - logits_0_last_tokens = logits_0[:, -1, :] # last tokens in each batch line - logits_1_last_tokens = logits_1[0, -3:, :] # last three tokens - torch.testing.assert_close(logits_0_last_tokens, logits_1_last_tokens) - - -@require_torch_gpu -class Mask4DTestFP16(Mask4DTestBase): - test_attention = Mask4DTestFP32.test_attention - - def setUp(self): - model_name = "JackFram/llama-68m" # small Llama-like model from FlexFlow - self.model_dtype = torch.float16 - self.tokenizer = AutoTokenizer.from_pretrained(model_name) - self.model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=self.model_dtype).to(torch_device) - - def test_causal_model_logits(self): - """comparing logits outputs of whole inner model""" - # Input 0: one row per sentence; Input 1: same data, but stacked into a single row with custom attention - input_0, position_ids_0, input_1, mask_1, position_ids_1 = self.get_test_data() - - logits_0 = self.model.forward(input_0, position_ids=position_ids_0).logits - logits_1 = self.model.forward(input_1, attention_mask=mask_1.bool(), position_ids=position_ids_1).logits - - logits_0_last_tokens = logits_0[:, -1, :] # last tokens in each batch line - logits_1_last_tokens = logits_1[0, -3:, :] # last three tokens - - indices_0 = logits_0_last_tokens.sort(descending=True).indices - indices_1 = logits_1_last_tokens.sort(descending=True).indices - - # checking logits, but note relaxed tolerances for FP16 - torch.testing.assert_close(logits_0_last_tokens, logits_1_last_tokens, atol=0.02, rtol=0.001) - - # checking tokens order for the top tokens - for token_ids_0, token_ids_1 in zip(indices_0, indices_1): - self.assertTrue(torch.equal(token_ids_0[:128], token_ids_1[:128])) - - -@slow -@require_torch_gpu -class Mask4DTestHard(unittest.TestCase): - def tearDown(self): - gc.collect() - torch.cuda.empty_cache() - - def setUp(self): - model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0" - self.model_dtype = torch.float32 - self.tokenizer = AutoTokenizer.from_pretrained(model_name) - self.model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=self.model_dtype).to(torch_device) - - def get_test_data(self): - template = "my favorite {}" - items = ("pet is a", "artist plays a", "name is L") # same number of tokens in each item - - batch_0 = [template.format(x) for x in items] # 3 separate lines - batch_1 = template.format(" ".join(items)) # 1 line with options concatenated - - input_0 = self.tokenizer(batch_0, return_tensors="pt").input_ids.to(torch_device) - input_1 = self.tokenizer(batch_1, return_tensors="pt").input_ids.to(torch_device) - - mask_1 = torch.tensor( - [ - [ - [ - [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0], - [1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0], - [1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0], - [1, 1, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0], - [1, 1, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0], - [1, 1, 1, 0, 0, 0, 1, 1, 1, 0, 0, 0], - [1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0], - [1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0], - [1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1], - ] - ] - ], - device=torch_device, - dtype=torch.int64, - ) - - position_ids_0 = torch.arange(input_0.shape[1]).tile(input_0.shape[0], 1).to(torch_device) - # equivalent: position_ids_1 = torch.tensor([[0, 1, 2, 3, 4, 5, 3, 4, 5, 3, 4, 5]]).to(device) - position_ids_1 = (mask_1.sum(dim=-1) - 1).reshape(1, -1) # same but nicer - - return input_0, position_ids_0, input_1, mask_1, position_ids_1 - - def test_stacked_causal_mask(self): - # Input 0: one row per sentence; Input 1: same data, but stacked into a single row with custom attention - input_0, position_ids_0, input_1, mask_1, position_ids_1 = self.get_test_data() - - # regular batch - logits_0 = self.model.forward(input_0, position_ids=position_ids_0).logits - logits_0_last = logits_0[:, -1, :] # last tokens in each batch line - decoded_0 = [self.tokenizer.decode(t) for t in logits_0_last.argmax(dim=-1)] - - # single forward run with 4D custom mask - logits_1 = self.model.forward(input_1, attention_mask=mask_1.bool(), position_ids=position_ids_1).logits - logits_1_last = logits_1[0, torch.where(position_ids_1 == position_ids_1.max())[1], :] # last three tokens - decoded_1 = [self.tokenizer.decode(t) for t in logits_1_last.argmax(dim=-1)] - - self.assertEqual(decoded_0, decoded_1) - - def test_partial_stacked_causal_mask(self): - # Same as the test above, but the input is passed in two groups. It tests that we can pass partial 4D attention - # masks - - # Input 0: one row per sentence; Input 1: same data, but stacked into a single row with custom attention - input_0, position_ids_0, input_1, mask_1, position_ids_1 = self.get_test_data() - - # regular batch - logits_0 = self.model.forward(input_0, position_ids=position_ids_0).logits - logits_0_last = logits_0[:, -1, :] # last tokens in each batch line - decoded_0 = [self.tokenizer.decode(t) for t in logits_0_last.argmax(dim=-1)] - - # 2 forward runs with custom 4D masks - part_a = 3 # split point - - input_1a = input_1[:, :part_a] - position_ids_1a = position_ids_1[:, :part_a] - mask_1a = mask_1[:, :, :part_a, :part_a] - - outs_1a = self.model.forward(input_1a, attention_mask=mask_1a.bool(), position_ids=position_ids_1a) - past_key_values_a = outs_1a["past_key_values"] - - input_1b = input_1[:, part_a:] - position_ids_1b = position_ids_1[:, part_a:] - mask_1b = mask_1[:, :, part_a:, :] - - outs_1b = self.model.forward( - input_1b, attention_mask=mask_1b.bool(), position_ids=position_ids_1b, past_key_values=past_key_values_a - ) - - decoded_1b = [ - self.tokenizer.decode(t) - for t in outs_1b.logits.argmax(-1)[0, torch.where(position_ids_1 == position_ids_1.max())[1] - part_a] - ] - - self.assertEqual(decoded_0, decoded_1b) - - @require_torch class TestTensorSharing(TestCasePlus): def test_disjoint(self):