31 changes: 29 additions & 2 deletions src/transformers/modeling_attn_mask_utils.py
@@ -302,10 +302,22 @@ def _prepare_4d_causal_attention_mask(
     key_value_length = input_shape[-1] + past_key_values_length

     # 4d mask is passed through the layers
-    if attention_mask is not None:
+    if attention_mask is not None and len(attention_mask.shape) == 2:
         attention_mask = attn_mask_converter.to_4d(
             attention_mask, input_shape[-1], key_value_length=key_value_length, dtype=inputs_embeds.dtype
         )
+    elif attention_mask is not None and len(attention_mask.shape) == 4:
+        expected_shape = (input_shape[0], 1, input_shape[1], key_value_length)
+        if tuple(attention_mask.shape) != expected_shape:
+            raise ValueError(
+                f"Incorrect 4D attention_mask shape: {tuple(attention_mask.shape)}; expected: {expected_shape}."
+            )
+        else:
+            # if the 4D mask has correct shape - invert it and fill with negative infinity
+            inverted_mask = 1.0 - attention_mask
+            attention_mask = inverted_mask.masked_fill(
+                inverted_mask.to(torch.bool), torch.finfo(inputs_embeds.dtype).min
+            )
     else:
         attention_mask = attn_mask_converter.to_causal_4d(
             input_shape[0], input_shape[-1], key_value_length, dtype=inputs_embeds.dtype, device=inputs_embeds.device
@@ -340,7 +352,22 @@ def _prepare_4d_causal_attention_mask_for_sdpa(
     is_tracing = torch.jit.is_tracing()

     if attention_mask is not None:
-        if torch.all(attention_mask == 1):
+        # 4d mask is passed through
+        if len(attention_mask.shape) == 4:
+            expected_shape = (input_shape[0], 1, input_shape[1], key_value_length)
+            if tuple(attention_mask.shape) != expected_shape:
+                raise ValueError(
+                    f"Incorrect 4D attention_mask shape: {tuple(attention_mask.shape)}; expected: {expected_shape}."
+                )
+            else:
+                # if the 4D mask has correct shape - invert it and fill with negative infinity
+                inverted_mask = 1.0 - attention_mask.to(inputs_embeds.dtype)
+                attention_mask = inverted_mask.masked_fill(
+                    inverted_mask.to(torch.bool), torch.finfo(inputs_embeds.dtype).min
+                )
+            return attention_mask
+
+        elif torch.all(attention_mask == 1):
             if is_tracing:
                 pass
             elif query_length == 1:
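Aside (not part of the diff): a minimal standalone sketch of what the new 4D-mask branches accept and produce, assuming only PyTorch. A caller-supplied mask must have shape (batch_size, 1, query_length, key_value_length), with 1 marking positions that may be attended to and 0 marking masked positions; it is then inverted and filled so that masked positions carry the dtype minimum and the result can be added directly to the attention scores.

import torch

batch_size, query_length, past_length = 1, 3, 2
key_value_length = query_length + past_length
dtype = torch.float32

# caller-supplied mask: 1 = attend, 0 = masked
mask_4d = torch.ones(batch_size, 1, query_length, key_value_length, dtype=torch.int64)
mask_4d[0, 0, 0, -1] = 0  # example: hide the last key from the first query

# the shape check performed by the new branches
expected_shape = (batch_size, 1, query_length, key_value_length)
assert tuple(mask_4d.shape) == expected_shape

# the inversion mirrors the diff: 0.0 where attending, dtype minimum where masked
inverted = 1.0 - mask_4d
additive = inverted.masked_fill(inverted.to(torch.bool), torch.finfo(dtype).min)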
98 changes: 98 additions & 0 deletions tests/test_modeling_utils.py
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import copy
+import gc
 import glob
 import json
 import os
@@ -1850,3 +1851,100 @@ def test_not_available_sdpa(self):
             )

         self.assertTrue("PyTorch SDPA requirements in Transformers are not met" in str(cm.exception))
+
+
+@require_torch
+@slow
+class Mask4DTest(unittest.TestCase):
+    def setUp(self):
+        self.device = torch.device("cuda:0")
+        model_name = "JackFram/llama-160m"  # small Llama-like model from FlexFlow
+        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
+        self.model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float32).to(self.device)
Collaborator:

Suggested change:
-        self.model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float32).to(self.device)
+        self.model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16).to(self.device)

the smaller the better for our CI

Contributor Author:

I observed that fp16 tests are noisier, so what I did is:

  • retained fp32 tests but used an even smaller model
  • added an fp16 test with relaxed tolerances
  • added an fp16 testing option for the top tokens order (see the sketch below)
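Regarding the last bullet, comparing the order of the top predicted tokens sidesteps fp16 tolerance tuning altogether. The helper below is an illustrative sketch, not code from this PR:

import torch

def assert_same_top_tokens(logits_a, logits_b, k=10):
    """Check that two logit tensors rank the same top-k tokens in the same order."""
    top_a = torch.topk(logits_a.float(), k, dim=-1).indices
    top_b = torch.topk(logits_b.float(), k, dim=-1).indices
    assert torch.equal(top_a, top_b), "top-k token order differs"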


+    def tearDown(self):
+        r"""
+        TearDown function needs to be called at the end of each test to free the GPU memory and cache, also to
+        avoid unexpected behaviors. Please see: https://discuss.pytorch.org/t/how-can-we-release-gpu-memory-cache/14530/27
+        """
+        gc.collect()
+        torch.cuda.empty_cache()
+
+    def get_test_data(self):
+        texts = ["the cat sat", "the cat had", "the cat is"]
+        encoded = [self.tokenizer.encode(t) for t in texts]
+        input_0 = torch.tensor(encoded, device=self.device)
+        # tensor([[   1,  278, 6635, 3290],
+        #         [   1,  278, 6635,  750],
+        #         [   1,  278, 6635,  338]], device='cuda:0')
+
+        # Combining common prefix with the unique ending tokens:
+        input_1 = torch.cat([input_0[0][:-1], input_0[:, -1]]).unsqueeze(0)
+        # tensor([[   1,  278, 6635, 3290,  750,  338]], device='cuda:0')
+
+        # Creating a 4D mask where each of the last 3 tokens do not attend to each other.
+        mask_1 = torch.tensor(
+            [
+                [
+                    [
+                        [1, 0, 0, 0, 0, 0],
+                        [1, 1, 0, 0, 0, 0],
+                        [1, 1, 1, 0, 0, 0],
+                        [1, 1, 1, 1, 0, 0],
+                        [1, 1, 1, 0, 1, 0],
+                        [1, 1, 1, 0, 0, 1],
+                    ]
+                ]
+            ],
+            device="cuda:0",
+            dtype=torch.int64,
+        )
+
+        # Creating a position_ids tensor. note the repeating figures in the end.
+        position_ids_1 = torch.tensor([[0, 1, 2, 3, 3, 3]], device=self.device, dtype=torch.int64)
+
+        return input_0, input_1, mask_1, position_ids_1
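Aside (not part of the diff): the hard-coded mask_1 and position_ids_1 above can also be built programmatically for any shared-prefix length and number of branches. build_packed_mask is a hypothetical helper shown only for illustration:

import torch

def build_packed_mask(prefix_len, num_branches, device=None):
    """Causal mask over a shared prefix; branch tokens see the prefix and themselves only."""
    total = prefix_len + num_branches
    mask = torch.tril(torch.ones(total, total, dtype=torch.int64, device=device))
    branch = torch.arange(prefix_len, total, device=device)
    # zero out attention between distinct branch tokens
    mask[branch.unsqueeze(1), branch.unsqueeze(0)] = torch.eye(num_branches, dtype=torch.int64, device=device)
    # all branch tokens occupy the same position, right after the prefix
    position_ids = torch.cat(
        [torch.arange(prefix_len, device=device), torch.full((num_branches,), prefix_len, device=device)]
    ).unsqueeze(0)
    return mask[None, None], position_ids  # shapes (1, 1, total, total) and (1, total)

# build_packed_mask(3, 3) reproduces the values of mask_1 and position_ids_1 above.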

+    def test_attention(self):
+        """comparing outputs of attention layer"""
+        input_0, input_1, mask_1, position_ids_1 = self.get_test_data()
+
+        hid_0 = self.model.model.embed_tokens(input_0)
+        outs_0 = self.model.model.layers[0].self_attn.forward(hid_0)[0]
+        # outs_0.shape == torch.Size([3, 4, 768])
+
+        hid_1 = self.model.model.embed_tokens(input_1)
+        outs_1 = self.model.model.layers[0].self_attn.forward(
+            hid_1, attention_mask=mask_1.bool(), position_ids=position_ids_1
+        )[0]
+        # outs_1.shape == torch.Size([1, 6, 768])
+
+        outs_0_last_tokens = outs_0[:, -1, :]  # last tokens in each batch line
+        outs_1_last_tokens = outs_1[0, -3:, :]  # last three tokens
+        assert torch.allclose(outs_0_last_tokens, outs_1_last_tokens, atol=1e-8)

+    def test_model(self):
+        """comparing hidden outputs of whole inner model"""
+        input_0, input_1, mask_1, position_ids_1 = self.get_test_data()
+
+        hidden_0 = self.model.model.forward(input_0).last_hidden_state
+        hidden_1 = self.model.model.forward(input_1, attention_mask=mask_1.bool(), position_ids=position_ids_1).last_hidden_state
+
+        hidden_0_last_tokens = hidden_0[:, -1, :]  # last tokens in each batch line
+        hidden_1_last_tokens = hidden_1[0, -3:, :]  # last three tokens
+        assert torch.allclose(
+            hidden_0_last_tokens, hidden_1_last_tokens, atol=1e-5
+        )  # note higher atol set to deal with noise

+    def test_causal_model_logits(self):
+        """comparing logits outputs of the whole causal model"""
+        input_0, input_1, mask_1, position_ids_1 = self.get_test_data()
+
+        logits_0 = self.model.forward(input_0).logits
+        logits_1 = self.model.forward(input_1, attention_mask=mask_1.bool(), position_ids=position_ids_1).logits
+
+        logits_0_last_tokens = logits_0[:, -1, :]  # last tokens in each batch line
+        logits_1_last_tokens = logits_1[0, -3:, :]  # last three tokens
+        assert torch.allclose(
+            logits_0_last_tokens, logits_1_last_tokens, atol=1e-5
+        )  # note higher atol set to deal with noise
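
Usage note (illustrative, outside the PR): the packed layout exercised by these tests lets a single forward pass serve several continuations of a shared prefix. A sketch, assuming model, tokenizer, input_1, mask_1 and position_ids_1 as prepared in the test class above (without the self. prefixes):

import torch

with torch.no_grad():
    logits = model(input_1, attention_mask=mask_1.bool(), position_ids=position_ids_1).logits

# greedy next token after each of the three packed endings ("sat", "had", "is")
next_ids = logits[0, -3:].argmax(dim=-1)
print(tokenizer.batch_decode(next_ids.unsqueeze(-1)))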