diff --git a/fms/modules/attention.py b/fms/modules/attention.py index 141e37457..79df3a983 100644 --- a/fms/modules/attention.py +++ b/fms/modules/attention.py @@ -594,6 +594,17 @@ def forward( queries = queries.view(batch_size, q_len, self.nheads, self.head_dim) keys = keys.view(batch_size, k_len, self.kvheads, self.head_dim) + if torch._dynamo.is_compiling(): + queries = ( + queries.transpose(-1, -2) + .contiguous() + .transpose(-1, -2) + .contiguous() + ) + keys = ( + keys.transpose(-1, -2).contiguous().transpose(-1, -2).contiguous() + ) + # Apply normalization per head queries = self.q_norm(queries) keys = self.k_norm(keys)