Skip to content

Commit

Permalink
one more test
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Dec 10, 2024
1 parent 70c59b6 commit f5d3907
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions tests/test_x_transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -558,8 +558,10 @@ def test_laser():

model(x)

@pytest.mark.parametrize('self_attn_custom_pos', (True, False))
@pytest.mark.parametrize('cross_attn_rotary', (True, False))
def test_cross_attn_rotary(
self_attn_custom_pos: bool,
cross_attn_rotary: bool
):

Expand All @@ -577,12 +579,14 @@ def test_cross_attn_rotary(
cross_attn_dim_context = 512
)

context_pos = torch.arange(128)
pos = torch.arange(64) if self_attn_custom_pos else None
context_pos = torch.arange(128) if cross_attn_rotary else None

embed = model(
x = x,
mask = mask,
context = context,
context_pos = context_pos if cross_attn_rotary else None,
pos = pos,
context_pos = context_pos,
context_mask = context_mask
)

0 comments on commit f5d3907

Please sign in to comment.