Fix fast-glu activation in change partitions #6909

Merged
merged 2 commits on Jun 23, 2023
12 changes: 6 additions & 6 deletions examples/nlp/language_modeling/megatron_change_num_partitions.py
@@ -199,7 +199,7 @@ def compute_tp_splits(
# alias the global index to idx
idx = global_idx

- swiglu_activation = 'swiglu' in str(model_cfg.get('activation', '')).lower()
+ fast_glu_activation = str(model_cfg.get('activation', '')).lower() in ['fast-geglu', 'fast-swiglu', 'fast-reglu']
Collaborator

Add swiglu to the list

Collaborator Author

swiglu doesn't use the torch chunk trick, so we don't need special partition handling for it. Only fast_glu_activation needs it.
https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/nlp/modules/common/megatron/mlp.py#L230-L231
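For context, here is a minimal sketch (not NeMo's actual MLP module; `fast_swiglu_mlp` and `dense_h_to_4h_weight` are made-up names for illustration) of the behaviour the linked lines rely on: with a fast-* GLU activation, dense_h_to_4h produces one fused projection that is split in half with torch.chunk, which is why the column layout of that weight matters when partitions change.

```python
import torch
import torch.nn.functional as F

def fast_swiglu_mlp(x, dense_h_to_4h_weight):
    # Single fused projection of width 2 * ffn_hidden_size.
    fused = x @ dense_h_to_4h_weight.t()
    # torch.chunk splits it into the gate ('W') half and the value ('V') half.
    gate, value = torch.chunk(fused, 2, dim=-1)
    return F.silu(gate) * value

# Toy shapes: hidden=8, ffn=16, so the fused weight is [2 * 16, 8].
x = torch.randn(4, 8)
w = torch.randn(32, 8)
out = fast_swiglu_mlp(x, w)   # shape [4, 16]
```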

Collaborator Author
@hsiehjackson Jun 23, 2023

In Megatron-LM, they do special handling of swiglu during partition conversion (code) because their implementation uses a chunk operation for swiglu (code). In NeMo, we also have a chunk operation, but it is only used when the activation is a fast_glu_activation (code). Therefore, I changed the partition conversion script to key off fast_glu_activation instead of swiglu.

Here is why the chunk operation needs special handling during partition conversion (a small torch sketch follows the example below):

TP=2: 
    GPU0: tensor A [a1, a2] -> chunk to tensor a1 and tensor a2 -> activation(a1) * a2 
    GPU1: tensor B [b1, b2] -> chunk to tensor b1 and tensor b2 -> activation(b1) * b2 
(Wrong) TP = 1 
    GPU0: tensor C = [a1, a2, b1, b2] (normal TP concatenation) -> chunk to tensor [a1,a2] and tensor [b1,b2] -> activation([a1,a2]) * [b1,b2]
(Correct) TP = 1
    GPU0: tensor C = [a1, b1, a2, b2] (special handling) -> chunk to tensor [a1,b1] and tensor [a2,b2] -> activation([a1,b1]) * [a2,b2]
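A tiny self-contained torch illustration of the example above (toy 1-D tensors only; `glu_forward` is a made-up helper standing in for the chunk-then-activate step):

```python
import torch
import torch.nn.functional as F

def glu_forward(fused):
    # The chunk-then-activate step: first half is the gate, second half the value.
    gate, value = torch.chunk(fused, 2, dim=0)
    return F.silu(gate) * value

a = torch.randn(4)   # rank 0 fused output [a1, a2]
b = torch.randn(4)   # rank 1 fused output [b1, b2]

# What TP=2 computes, gathered: activation(a1) * a2 next to activation(b1) * b2.
reference = torch.cat([glu_forward(a), glu_forward(b)])

# Wrong TP=1 merge: plain concatenation misaligns the gate/value halves.
naive = glu_forward(torch.cat([a, b]))

# Correct TP=1 merge: put all gate halves first, then all value halves.
a1, a2 = torch.chunk(a, 2)
b1, b2 = torch.chunk(b, 2)
merged = glu_forward(torch.cat([a1, b1, a2, b2]))

print(torch.allclose(reference, merged))  # True
print(torch.allclose(reference, naive))   # False (in general)
```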

Collaborator

Ok makes sense.


if param.shape == partitions[0][idx].shape:
split = [partitions[0][idx].data] * tp_size
@@ -230,8 +230,8 @@ def compute_tp_splits(
for i in range(tp_size):
tp_qkv = torch.cat([tp_qkv_splits[item] for item in range(i, tp_size * 2, tp_size)])
split.append(tp_qkv)
- elif 'dense_h_to_4h.weight' in param_name and swiglu_activation:
- # For Megatron GPT model with Swiglu activation
+ elif 'dense_h_to_4h.weight' in param_name and fast_glu_activation:
+ # For Megatron GPT model with Fast Glu activation
# Handle gated linear units
# concat all the first halves ('W's) and all the second halves ('V's)
w_split, k_split = torch.chunk(partitions[0][idx].data, 2, dim=0)
@@ -261,7 +261,7 @@ def compute_tp_merge(idx, name, param, partitions_pp, model_cfg):
Returns:
The concatenated parameter for TP 1 PP 1.
"""
- swiglu_activation = 'swiglu' in str(model_cfg.get('activation', '')).lower()
+ fast_glu_activation = str(model_cfg.get('activation', '')).lower() in ['fast-geglu', 'fast-swiglu', 'fast-reglu']
titu1994 marked this conversation as resolved.

# Logic from original TP rank change
if param.shape == partitions_pp[0][idx].shape:
@@ -271,8 +271,8 @@ def compute_tp_merge(idx, name, param, partitions_pp, model_cfg):
else:
concated = torch.cat([partitions_pp[i][idx].data for i in range(len(partitions_pp))], dim=0)

- # Logic for Swiglu activation
- if 'dense_h_to_4h.weight' in name and swiglu_activation:
+ # Logic for Fast Glu activation
+ if 'dense_h_to_4h.weight' in name and fast_glu_activation:
# concat all the first halves ('W's) and all the second halves ('V's)
wk_splits = []
for tpr in range(len(partitions_pp)):
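For reference, a hedged sketch of the gather pattern the hunk above implements for dense_h_to_4h.weight, using made-up toy shapes rather than the script's real tensors: all 'W' halves are concatenated first, then all 'V' halves.

```python
import torch

tp_size, ffn_per_rank, hidden = 2, 4, 8

# Each TP rank's dense_h_to_4h.weight stacks its 'W' half on top of its 'V' half.
rank_weights = [torch.randn(2 * ffn_per_rank, hidden) for _ in range(tp_size)]

w_halves, v_halves = [], []
for weight in rank_weights:
    w_half, v_half = torch.chunk(weight, 2, dim=0)
    w_halves.append(w_half)
    v_halves.append(v_half)

# TP=1 weight: all 'W' halves first, then all 'V' halves.
merged_weight = torch.cat(w_halves + v_halves, dim=0)
print(merged_weight.shape)  # torch.Size([16, 8])
```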