Make tensor split contiguous (#6580)
Signed-off-by: Abhinav Khattar <[email protected]>
aklife97 committed May 8, 2023
1 parent 680cdac commit f2103da
Showing 1 changed file with 6 additions and 2 deletions.
nemo/collections/nlp/modules/common/megatron/attention.py  (6 additions, 2 deletions)

@@ -371,7 +371,9 @@ def forward(
             mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)
 
             # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
-            (query_layer, key_layer, value_layer) = tensor_parallel.split_tensor_along_last_dim(mixed_x_layer, 3)
+            (query_layer, key_layer, value_layer) = tensor_parallel.split_tensor_along_last_dim(
+                mixed_x_layer, 3, contiguous_split_chunks=True
+            )
         else:
             # Attention heads [sk, b, h] --> [sk, b, (np * 2 * hn)]
             mixed_kv_layer, _ = self.key_value(encoder_output)
@@ -386,7 +388,9 @@ def forward(
             mixed_kv_layer = mixed_kv_layer.view(*new_tensor_shape)
 
             # [sk, b, np, 2 * hn] --> 2 [sk, b, np, hn]
-            (key_layer, value_layer) = tensor_parallel.split_tensor_along_last_dim(mixed_kv_layer, 2)
+            (key_layer, value_layer) = tensor_parallel.split_tensor_along_last_dim(
+                mixed_kv_layer, 2, contiguous_split_chunks=True
+            )
 
             # Attention head [sq, b, h] --> [sq, b, hp]
             query_layer, _ = self.query(hidden_states)

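For reference, a minimal sketch of what splitting a fused tensor along its last dimension with contiguous chunks roughly amounts to. The helper name and signature below are illustrative only, not the actual NeMo/Megatron tensor_parallel implementation: the idea is that contiguous_split_chunks=True copies each slice into its own densely packed buffer instead of returning strided views of the fused QKV tensor.

# Illustrative sketch (not the NeMo/Megatron source).
import torch


def split_along_last_dim_sketch(tensor, num_partitions, contiguous_split_chunks=False):
    """Split `tensor` into `num_partitions` equal slices along its last dimension."""
    last_dim = tensor.dim() - 1
    last_dim_size = tensor.size(last_dim) // num_partitions
    chunks = torch.split(tensor, last_dim_size, dim=last_dim)
    if contiguous_split_chunks:
        # torch.split returns views into the original buffer; .contiguous()
        # copies each slice into its own densely packed tensor, which is the
        # behavior this commit enables via contiguous_split_chunks=True.
        return tuple(chunk.contiguous() for chunk in chunks)
    return chunks


# Usage: [sq, b, np, 3 * hn] --> 3 x [sq, b, np, hn], each chunk contiguous.
mixed = torch.randn(4, 2, 8, 3 * 16)
q, k, v = split_along_last_dim_sketch(mixed, 3, contiguous_split_chunks=True)
assert q.is_contiguous() and q.shape == (4, 2, 8, 16)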