Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 11 additions & 3 deletions src/transformers/models/cohere/modeling_cohere.py
Original file line number Diff line number Diff line change
Expand Up @@ -899,7 +899,7 @@ def forward(
if position_ids is None:
position_ids = cache_position.unsqueeze(0)

causal_mask = self._update_causal_mask(attention_mask, inputs_embeds)
causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, past_seen_tokens)

# embed positions
hidden_states = inputs_embeds
Expand Down Expand Up @@ -967,7 +967,7 @@ def forward(
# KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
# (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
# `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
def _update_causal_mask(self, attention_mask, input_tensor):
def _update_causal_mask(self, attention_mask, input_tensor, past_seen_tokens):
if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and 0.0 in attention_mask:
return attention_mask
Expand All @@ -993,9 +993,17 @@ def _update_causal_mask(self, attention_mask, input_tensor):
padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0)
causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype)
elif attention_mask.dim() == 4:
# backwards compatibility: we allow passing a 4D attention mask shorter than the input length with
# cache. In that case, the 4D attention mask attends to the newest tokens only.
if attention_mask.shape[-2] < past_seen_tokens + input_tensor.shape[1]:
offset = past_seen_tokens
else:
offset = 0
mask_shape = attention_mask.shape
mask_slice = (attention_mask.eq(0.0)).to(dtype=dtype) * min_dtype
causal_mask[: mask_shape[0], : mask_shape[1], : mask_shape[2], : mask_shape[3]] = mask_slice
causal_mask[
: mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3]
] = mask_slice

if (
self.config._attn_implementation == "sdpa"
Expand Down
14 changes: 11 additions & 3 deletions src/transformers/models/gemma/modeling_gemma.py
Original file line number Diff line number Diff line change
Expand Up @@ -888,7 +888,7 @@ def forward(
if position_ids is None:
position_ids = cache_position.unsqueeze(0)

causal_mask = self._update_causal_mask(attention_mask, inputs_embeds)
causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, past_seen_tokens)

# embed positions
hidden_states = inputs_embeds
Expand Down Expand Up @@ -959,7 +959,7 @@ def forward(
# KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
# (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
# `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
def _update_causal_mask(self, attention_mask, input_tensor):
def _update_causal_mask(self, attention_mask, input_tensor, past_seen_tokens):
if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and 0.0 in attention_mask:
return attention_mask
Expand All @@ -986,9 +986,17 @@ def _update_causal_mask(self, attention_mask, input_tensor):
padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0)
causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype)
elif attention_mask.dim() == 4:
# backwards compatibility: we allow passing a 4D attention mask shorter than the input length with
# cache. In that case, the 4D attention mask attends to the newest tokens only.
if attention_mask.shape[-2] < past_seen_tokens + input_tensor.shape[1]:
offset = past_seen_tokens
else:
offset = 0
mask_shape = attention_mask.shape
mask_slice = (attention_mask.eq(0.0)).to(dtype=dtype) * min_dtype
causal_mask[: mask_shape[0], : mask_shape[1], : mask_shape[2], : mask_shape[3]] = mask_slice
causal_mask[
: mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3]
] = mask_slice

if (
self.config._attn_implementation == "sdpa"
Expand Down
14 changes: 11 additions & 3 deletions src/transformers/models/llama/modeling_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -1000,7 +1000,7 @@ def forward(
if position_ids is None:
position_ids = cache_position.unsqueeze(0)

causal_mask = self._update_causal_mask(attention_mask, inputs_embeds)
causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, past_seen_tokens)

# embed positions
hidden_states = inputs_embeds
Expand Down Expand Up @@ -1068,7 +1068,7 @@ def forward(
# KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
# (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
# `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
def _update_causal_mask(self, attention_mask, input_tensor):
def _update_causal_mask(self, attention_mask, input_tensor, past_seen_tokens):
if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and 0.0 in attention_mask:
return attention_mask
Expand All @@ -1094,9 +1094,17 @@ def _update_causal_mask(self, attention_mask, input_tensor):
padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0)
causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype)
elif attention_mask.dim() == 4:
# backwards compatibility: we allow passing a 4D attention mask shorter than the input length with
# cache. In that case, the 4D attention mask attends to the newest tokens only.
if attention_mask.shape[-2] < past_seen_tokens + input_tensor.shape[1]:
offset = past_seen_tokens
else:
offset = 0
mask_shape = attention_mask.shape
mask_slice = (attention_mask.eq(0.0)).to(dtype=dtype) * min_dtype
causal_mask[: mask_shape[0], : mask_shape[1], : mask_shape[2], : mask_shape[3]] = mask_slice
causal_mask[
: mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3]
] = mask_slice

if (
self.config._attn_implementation == "sdpa"
Expand Down
110 changes: 109 additions & 1 deletion tests/test_modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1956,7 +1956,6 @@ def test_not_available_sdpa(self):
self.assertTrue("PyTorch SDPA requirements in Transformers are not met" in str(cm.exception))


@slow
@require_torch_gpu
class Mask4DTestBase(unittest.TestCase):
def tearDown(self):
Expand Down Expand Up @@ -2011,6 +2010,7 @@ def setUp(self):

def test_attention(self):
"""comparing outputs of attention layer"""
# Input 0: one row per sentence; Input 1: same data, but stacked into a single row with custom attention
input_0, position_ids_0, input_1, mask_1, position_ids_1 = self.get_test_data()
causal_mask_1 = (1 - mask_1).to(self.model_dtype) * torch.finfo(self.model_dtype).min

Expand All @@ -2030,6 +2030,7 @@ def test_attention(self):

def test_causal_model_logits(self):
"""comparing logits outputs of whole inner model"""
# Input 0: one row per sentence; Input 1: same data, but stacked into a single row with custom attention
input_0, position_ids_0, input_1, mask_1, position_ids_1 = self.get_test_data()

logits_0 = self.model.forward(input_0, position_ids=position_ids_0).logits
Expand All @@ -2052,6 +2053,7 @@ def setUp(self):

def test_causal_model_logits(self):
"""comparing logits outputs of whole inner model"""
# Input 0: one row per sentence; Input 1: same data, but stacked into a single row with custom attention
input_0, position_ids_0, input_1, mask_1, position_ids_1 = self.get_test_data()

logits_0 = self.model.forward(input_0, position_ids=position_ids_0).logits
Expand All @@ -2069,3 +2071,109 @@ def test_causal_model_logits(self):
# checking tokens order for the top tokens
for token_ids_0, token_ids_1 in zip(indices_0, indices_1):
self.assertTrue(torch.equal(token_ids_0[:128], token_ids_1[:128]))


@slow
@require_torch_gpu
class Mask4DTestHard(unittest.TestCase):
def tearDown(self):
gc.collect()
torch.cuda.empty_cache()

def setUp(self):
model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
self.model_dtype = torch.float32
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
self.model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=self.model_dtype).to(torch_device)

def get_test_data(self):
template = "my favorite {}"
items = ("pet is a", "artist plays a", "name is L") # same number of tokens in each item

batch_0 = [template.format(x) for x in items] # 3 separate lines
batch_1 = template.format(" ".join(items)) # 1 line with options concatenated

input_0 = self.tokenizer(batch_0, return_tensors="pt").input_ids.to(torch_device)
input_1 = self.tokenizer(batch_1, return_tensors="pt").input_ids.to(torch_device)

mask_1 = torch.tensor(
[
[
[
[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0],
[1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0],
[1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0],
[1, 1, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0],
[1, 1, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0],
[1, 1, 1, 0, 0, 0, 1, 1, 1, 0, 0, 0],
[1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0],
[1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0],
[1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1],
]
]
],
device=torch_device,
dtype=torch.int64,
)

position_ids_0 = torch.arange(input_0.shape[1]).tile(input_0.shape[0], 1).to(torch_device)
# equivalent: position_ids_1 = torch.tensor([[0, 1, 2, 3, 4, 5, 3, 4, 5, 3, 4, 5]]).to(device)
position_ids_1 = (mask_1.sum(dim=-1) - 1).reshape(1, -1) # same but nicer

return input_0, position_ids_0, input_1, mask_1, position_ids_1

def test_stacked_causal_mask(self):
# Input 0: one row per sentence; Input 1: same data, but stacked into a single row with custom attention
input_0, position_ids_0, input_1, mask_1, position_ids_1 = self.get_test_data()

# regular batch
logits_0 = self.model.forward(input_0, position_ids=position_ids_0).logits
logits_0_last = logits_0[:, -1, :] # last tokens in each batch line
decoded_0 = [self.tokenizer.decode(t) for t in logits_0_last.argmax(dim=-1)]

# single forward run with 4D custom mask
logits_1 = self.model.forward(input_1, attention_mask=mask_1.bool(), position_ids=position_ids_1).logits
logits_1_last = logits_1[0, torch.where(position_ids_1 == position_ids_1.max())[1], :] # last three tokens
decoded_1 = [self.tokenizer.decode(t) for t in logits_1_last.argmax(dim=-1)]

self.assertEqual(decoded_0, decoded_1)

def test_partial_stacked_causal_mask(self):
# Same as the test above, but the input is passed in two groups. It tests that we can pass partial 4D attention
# masks

# Input 0: one row per sentence; Input 1: same data, but stacked into a single row with custom attention
input_0, position_ids_0, input_1, mask_1, position_ids_1 = self.get_test_data()

# regular batch
logits_0 = self.model.forward(input_0, position_ids=position_ids_0).logits
logits_0_last = logits_0[:, -1, :] # last tokens in each batch line
decoded_0 = [self.tokenizer.decode(t) for t in logits_0_last.argmax(dim=-1)]

# 2 forward runs with custom 4D masks
part_a = 3 # split point

input_1a = input_1[:, :part_a]
position_ids_1a = position_ids_1[:, :part_a]
mask_1a = mask_1[:, :, :part_a, :part_a]

outs_1a = self.model.forward(input_1a, attention_mask=mask_1a.bool(), position_ids=position_ids_1a)
past_key_values_a = outs_1a["past_key_values"]

input_1b = input_1[:, part_a:]
position_ids_1b = position_ids_1[:, part_a:]
mask_1b = mask_1[:, :, part_a:, :]

outs_1b = self.model.forward(
input_1b, attention_mask=mask_1b.bool(), position_ids=position_ids_1b, past_key_values=past_key_values_a
)

decoded_1b = [
self.tokenizer.decode(t)
for t in outs_1b.logits.argmax(-1)[0, torch.where(position_ids_1 == position_ids_1.max())[1] - part_a]
]

self.assertEqual(decoded_0, decoded_1b)