Llama: fix custom 4D masks #29930
Changes from all commits
Changes to `_update_causal_mask` (Llama modeling code):

@@ -1003,38 +1003,34 @@ def _update_causal_mask(
         dtype, device = input_tensor.dtype, input_tensor.device
         min_dtype = torch.finfo(dtype).min
         sequence_length = input_tensor.shape[1]
-        if hasattr(getattr(self.layers[0], "self_attn", {}), "past_key_value"):  # static cache
-            target_length = self.config.max_position_embeddings
-        else:  # dynamic cache
-            target_length = (
-                attention_mask.shape[-1]
-                if isinstance(attention_mask, torch.Tensor)
-                else past_seen_tokens + sequence_length + 1
-            )
-
-        causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
-        if sequence_length != 1:
-            causal_mask = torch.triu(causal_mask, diagonal=1)
-        causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
-        causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
-        if attention_mask is not None:
-            causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
-            if attention_mask.dim() == 2:
+        if attention_mask is not None and attention_mask.dim() == 4:
+            # we can pass both the full 4D mask (i.e. [..., full_len, full_len]) and a 4D mask with the same shape
+            # as the causal mask (i.e. [..., seq_len, full_len])
+            mask_slice = (attention_mask.eq(0.0)).to(dtype=dtype) * min_dtype
+            offset = cache_position[0]
+            if attention_mask.shape[-2] == offset + sequence_length:
+                mask_slice = mask_slice[..., offset:, :]
+            causal_mask = mask_slice
+        else:
+            if hasattr(self.layers[0].self_attn, "past_key_value"):  # static cache
+                target_length = self.config.max_position_embeddings
+            else:  # dynamic cache
+                target_length = (
+                    attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else cache_position[-1] + 1
+                )
+
+            causal_mask = torch.full(
+                (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
+            )
+            if sequence_length != 1:
+                causal_mask = torch.triu(causal_mask, diagonal=1)
+            causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
+            causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
+            if attention_mask is not None:
+                causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
                 mask_length = attention_mask.shape[-1]
                 padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0)
                 causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype)
-            elif attention_mask.dim() == 4:
-                # backwards compatibility: we allow passing a 4D attention mask shorter than the input length with
-                # cache. In that case, the 4D attention mask attends to the newest tokens only.
-                if attention_mask.shape[-2] < cache_position[0] + sequence_length:
-                    offset = cache_position[0]
-                else:
-                    offset = 0
-                mask_shape = attention_mask.shape
-                mask_slice = (attention_mask.eq(0.0)).to(dtype=dtype) * min_dtype
-                causal_mask[
-                    : mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3]
-                ] = mask_slice
 
         if (
             self.config._attn_implementation == "sdpa"
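To make the reordering easier to follow, here is a minimal standalone sketch of the new 4D-mask branch. The names mirror the diff, but this is an illustration only, not the library's API; it assumes the mask uses 1 = attend / 0 = masked and that `cache_position` and `sequence_length` have the meaning they have inside `_update_causal_mask`.

```python
import torch


def causal_mask_from_custom_4d(attention_mask, cache_position, sequence_length, dtype):
    """Illustrative re-implementation of the new 4D branch (not the library API)."""
    min_dtype = torch.finfo(dtype).min
    # invert the 1/0 "attend" mask into an additive float mask: 0 keeps a position, min_dtype blocks it
    mask_slice = attention_mask.eq(0.0).to(dtype=dtype) * min_dtype
    offset = cache_position[0]
    # a full-length mask ([..., full_len, full_len]) also has rows for cached tokens;
    # keep only the rows that correspond to the tokens of the current forward pass
    if attention_mask.shape[-2] == offset + sequence_length:
        mask_slice = mask_slice[..., offset:, :]
    return mask_slice


# toy check: 2 new tokens after 3 cached ones, full 5x5 causal mask passed in
full_mask = torch.tril(torch.ones(1, 1, 5, 5, dtype=torch.int64))
sliced = causal_mask_from_custom_4d(full_mask, torch.arange(3, 5), sequence_length=2, dtype=torch.float32)
print(sliced.shape)  # torch.Size([1, 1, 2, 5])
```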
Changes to the LLaMA test suite:

@@ -14,6 +14,7 @@
 # limitations under the License.
 """ Testing suite for the PyTorch LLaMA model. """
 
+import gc
 import tempfile
 import unittest
 

@@ -821,3 +822,138 @@ def test_model_7b_logits(self):
         ]
         infilling = tokenizer.batch_decode(generated_ids)
         self.assertEqual(infilling, EXPECTED_INFILLING)
+
+
+@slow
+@require_torch_gpu
+class Mask4DTestHard(unittest.TestCase):
+    def tearDown(self):
+        gc.collect()
+        torch.cuda.empty_cache()
+
+    def setUp(self):
+        model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
+        self.model_dtype = torch.float32
+        self.tokenizer = LlamaTokenizer.from_pretrained(model_name)
+        self.model = LlamaForCausalLM.from_pretrained(model_name, torch_dtype=self.model_dtype).to(torch_device)
+
+    def get_test_data(self):
+        template = "my favorite {}"
+        items = ("pet is a", "artist plays a", "name is L")  # same number of tokens in each item
+
+        batch_separate = [template.format(x) for x in items]  # 3 separate lines
+        batch_shared_prefix = template.format(" ".join(items))  # 1 line with options concatenated
+
+        input_ids = self.tokenizer(batch_separate, return_tensors="pt").input_ids.to(torch_device)
+        input_ids_shared_prefix = self.tokenizer(batch_shared_prefix, return_tensors="pt").input_ids.to(torch_device)
+
+        mask_shared_prefix = torch.tensor(
+            [
+                [
+                    [
+                        [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+                        [1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+                        [1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+                        [1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0],
+                        [1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0],
+                        [1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0],
+                        [1, 1, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0],
+                        [1, 1, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0],
+                        [1, 1, 1, 0, 0, 0, 1, 1, 1, 0, 0, 0],
+                        [1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0],
+                        [1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0],
+                        [1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1],
+                    ]
+                ]
+            ],
+            device=torch_device,
+            dtype=torch.int64,
+        )
+
+        position_ids = torch.arange(input_ids.shape[1]).tile(input_ids.shape[0], 1).to(torch_device)
+        # equivalent: position_ids_1 = torch.tensor([[0, 1, 2, 3, 4, 5, 3, 4, 5, 3, 4, 5]]).to(device)
+        position_ids_shared_prefix = (mask_shared_prefix.sum(dim=-1) - 1).reshape(1, -1)  # same but nicer
+
+        return input_ids, position_ids, input_ids_shared_prefix, mask_shared_prefix, position_ids_shared_prefix
+
+    def test_stacked_causal_mask(self):
+        (
+            input_ids,
+            position_ids,
+            input_ids_shared_prefix,
+            mask_shared_prefix,
+            position_ids_shared_prefix,
+        ) = self.get_test_data()
+
+        # regular batch
+        logits = self.model.forward(input_ids, position_ids=position_ids).logits
+        logits_last = logits[:, -1, :]  # last tokens in each batch line
+        decoded = [self.tokenizer.decode(t) for t in logits_last.argmax(dim=-1)]
+
+        # single forward run with 4D custom mask
+        logits_shared_prefix = self.model.forward(
+            input_ids_shared_prefix, attention_mask=mask_shared_prefix.bool(), position_ids=position_ids_shared_prefix
+        ).logits
+        logits_shared_prefix_last = logits_shared_prefix[
+            0, torch.where(position_ids_shared_prefix == position_ids_shared_prefix.max())[1], :
+        ]  # last three tokens
+        decoded_shared_prefix = [self.tokenizer.decode(t) for t in logits_shared_prefix_last.argmax(dim=-1)]
+
+        self.assertEqual(decoded, decoded_shared_prefix)
+
+    def test_partial_stacked_causal_mask(self):
+        # Same as the test above, but the input is passed in two groups. It tests that we can pass partial 4D attention
+        # masks
+
+        (
+            input_ids,
+            position_ids,
+            input_ids_shared_prefix,
+            mask_shared_prefix,
+            position_ids_shared_prefix,
+        ) = self.get_test_data()
+
+        # regular batch
+        logits = self.model.forward(input_ids, position_ids=position_ids).logits
+        logits_last = logits[:, -1, :]  # last tokens in each batch line
+        decoded = [self.tokenizer.decode(t) for t in logits_last.argmax(dim=-1)]
+
+        # 2 forward runs with custom 4D masks
+        part_a = 3  # split point
+
+        input_1a = input_ids_shared_prefix[:, :part_a]
+        position_ids_1a = position_ids_shared_prefix[:, :part_a]
+        mask_1a = mask_shared_prefix[:, :, :part_a, :part_a]
+
+        outs_1a = self.model.forward(input_1a, attention_mask=mask_1a.bool(), position_ids=position_ids_1a)
+        past_key_values_a = outs_1a["past_key_values"]
+
+        # Case 1: we pass a 4D attention mask regarding the current sequence length (i.e. [..., seq_len, full_len])
+        input_1b = input_ids_shared_prefix[:, part_a:]
+        position_ids_1b = position_ids_shared_prefix[:, part_a:]
+        mask_1b = mask_shared_prefix[:, :, part_a:, :]
+        outs_1b = self.model.forward(
+            input_1b, attention_mask=mask_1b.bool(), position_ids=position_ids_1b, past_key_values=past_key_values_a
+        )
+        decoded_1b = [
+            self.tokenizer.decode(t)
+            for t in outs_1b.logits.argmax(-1)[
+                0, torch.where(position_ids_shared_prefix == position_ids_shared_prefix.max())[1] - part_a
+            ]
+        ]
+        self.assertEqual(decoded, decoded_1b)
+
+        # Case 2: we pass a 4D attention mask regarding the full sequence length (i.e. [..., full_len, full_len])
+        input_1c = input_ids_shared_prefix[:, part_a:]
+        position_ids_1c = position_ids_shared_prefix[:, part_a:]
+        mask_1c = mask_shared_prefix
+        outs_1c = self.model.forward(
+            input_1c, attention_mask=mask_1c.bool(), position_ids=position_ids_1c, past_key_values=past_key_values_a
+        )
+        decoded_1c = [
+            self.tokenizer.decode(t)
+            for t in outs_1c.logits.argmax(-1)[
+                0, torch.where(position_ids_shared_prefix == position_ids_shared_prefix.max())[1] - part_a
+            ]
+        ]
+        self.assertEqual(decoded, decoded_1c)
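The 12x12 mask in `get_test_data` is hand-written for clarity. For readers adapting the pattern, a hypothetical helper (not part of the PR or the test) could build the same shared-prefix block mask programmatically; under the assumption that the prefix and each item tokenize to 3 tokens apiece, it reproduces the hard-coded tensor above.

```python
import torch


def shared_prefix_mask(prefix_len: int, branch_len: int, num_branches: int) -> torch.Tensor:
    """Hypothetical helper: block mask for several continuations sharing one prefix."""
    total = prefix_len + branch_len * num_branches
    mask = torch.zeros(total, total, dtype=torch.int64)
    mask[:, :prefix_len] = 1  # every token may attend to the shared prefix
    for b in range(num_branches):
        start = prefix_len + b * branch_len
        end = start + branch_len
        mask[start:end, start:end] = 1  # tokens attend within their own branch only
    mask = torch.tril(mask)  # keep it causal
    return mask[None, None]  # shape [1, 1, total, total]


# matches the hard-coded 12x12 mask above (3-token prefix, three 3-token items)
print(shared_prefix_mask(3, 3, 3)[0, 0])
```

In the tests above, such a mask is passed to the model as `attention_mask=mask.bool()` together with explicit `position_ids` that restart after the shared prefix.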
reordered the logic: custom 4D masks are now a superset of the default mask, so we don't need to create the default mask first :)
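A quick way to see the "superset" point made in the comment above: the default causal mask is just one particular 4D mask, so building it from scratch or converting the equivalent explicit 4D pattern yields the same additive mask. The snippet below is a self-contained sketch with illustrative names (no cache, no padding), not code from the PR.

```python
import torch

dtype = torch.float32
min_dtype = torch.finfo(dtype).min
batch, seq_len = 2, 5
cache_position = torch.arange(seq_len)

# default branch: build the causal mask from scratch (mirrors the else-path in the diff)
default_mask = torch.full((seq_len, seq_len), fill_value=min_dtype, dtype=dtype)
default_mask = torch.triu(default_mask, diagonal=1)
default_mask *= torch.arange(seq_len) > cache_position.reshape(-1, 1)
default_mask = default_mask[None, None, :, :].expand(batch, 1, -1, -1)

# custom-4D branch: express the same causality pattern as an explicit mask (1 = attend)
custom_4d = torch.tril(torch.ones(batch, 1, seq_len, seq_len, dtype=torch.int64))
from_custom = custom_4d.eq(0.0).to(dtype=dtype) * min_dtype

assert torch.equal(default_mask, from_custom)  # identical additive masks
```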