From 972c1ce9b810e178ff9c6de303aa10b56517e40f Mon Sep 17 00:00:00 2001 From: Max de Bayser Date: Tue, 24 Mar 2026 10:58:27 -0300 Subject: [PATCH 1/2] add workaround for aiu compilation issue Signed-off-by: Max de Bayser --- fms/modules/attention.py | 26 ++++++++++++++++++-------- 1 file changed, 18 insertions(+), 8 deletions(-) diff --git a/fms/modules/attention.py b/fms/modules/attention.py index 141e37457..b5e182d33 100644 --- a/fms/modules/attention.py +++ b/fms/modules/attention.py @@ -594,6 +594,17 @@ def forward( queries = queries.view(batch_size, q_len, self.nheads, self.head_dim) keys = keys.view(batch_size, k_len, self.kvheads, self.head_dim) + if torch._dynamo.is_compiling(): + queries = ( + queries.transpose(-1, -2) + .contiguous() + .transpose(-1, -2) + .contiguous() + ) + keys = ( + keys.transpose(-1, -2).contiguous().transpose(-1, -2).contiguous() + ) + # Apply normalization per head queries = self.q_norm(queries) keys = self.k_norm(keys) @@ -949,14 +960,13 @@ def __init__( assert torch.distributed.is_initialized() rank, world_size = distributed.rank_and_world(group) - assert nheads % world_size == 0, ( - "The number of heads must be divisible by world size" - ) - assert (kvheads >= world_size and kvheads % world_size == 0) or ( - kvheads < world_size and world_size % kvheads == 0 - ), ( - "the kv heads must be divisible by the world size or the world size must be divisible by kv heads" - ) + assert ( + nheads % world_size == 0 + ), "The number of heads must be divisible by world size" + assert ( + (kvheads >= world_size and kvheads % world_size == 0) + or (kvheads < world_size and world_size % kvheads == 0) + ), "the kv heads must be divisible by the world size or the world size must be divisible by kv heads" MultiHeadAttention.__init__( self, emb_dim, From 30d747df1d3d7874cb13c65a7c6d81e9d622e461 Mon Sep 17 00:00:00 2001 From: Max de Bayser Date: Tue, 24 Mar 2026 11:12:46 -0300 Subject: [PATCH 2/2] appease ruff Signed-off-by: Max de Bayser --- fms/modules/attention.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/fms/modules/attention.py b/fms/modules/attention.py index b5e182d33..79df3a983 100644 --- a/fms/modules/attention.py +++ b/fms/modules/attention.py @@ -960,13 +960,14 @@ def __init__( assert torch.distributed.is_initialized() rank, world_size = distributed.rank_and_world(group) - assert ( - nheads % world_size == 0 - ), "The number of heads must be divisible by world size" - assert ( - (kvheads >= world_size and kvheads % world_size == 0) - or (kvheads < world_size and world_size % kvheads == 0) - ), "the kv heads must be divisible by the world size or the world size must be divisible by kv heads" + assert nheads % world_size == 0, ( + "The number of heads must be divisible by world size" + ) + assert (kvheads >= world_size and kvheads % world_size == 0) or ( + kvheads < world_size and world_size % kvheads == 0 + ), ( + "the kv heads must be divisible by the world size or the world size must be divisible by kv heads" + ) MultiHeadAttention.__init__( self, emb_dim,