Support Swiglu in TP PP Conversion (#6437)
* Support Swiglu in TP PP Conversion

Signed-off-by: smajumdar <[email protected]>

* Guard activation

Signed-off-by: smajumdar <[email protected]>

* Guard activation

Signed-off-by: smajumdar <[email protected]>

---------

Signed-off-by: smajumdar <[email protected]>
titu1994 authored and web-flow committed Apr 19, 2023
1 parent 499a3b2 commit 58ed9bb
Showing 1 changed file with 61 additions and 12 deletions.
73 changes: 61 additions & 12 deletions examples/nlp/language_modeling/megatron_change_num_partitions.py
@@ -54,7 +54,7 @@
--target_pipeline_model_parallel_size=1 \
--target_pipeline_model_parallel_split_rank=0 \
--precision=bf16
### Only Tensor Parallelism conversion ###
To the above commands, add the following argument: `--tp_conversion_only`
@@ -99,13 +99,14 @@
"""


#################
### Utilities ###
#################


def compute_tp_splits(
param_name, param, partitions, global_idx, tp_size, pp_size, pp_rank, pp_split_rank, megatron_legacy
param_name, param, partitions, global_idx, tp_size, pp_size, pp_rank, pp_split_rank, megatron_legacy, model_cfg
):
"""
Function to compute the splits required for tensor-parallelism.
@@ -120,13 +121,16 @@ def compute_tp_splits(
pp_rank: Int, pipeline-parallelism rank.
pp_split_rank: Int, pipeline-parallelism split rank. This should be > 1 if TP is being used with EncDec models (T5)
megatron_legacy: Bool, whether the model is a legacy Megatron model or not.
model_cfg: The model config as an OmegaConf DictConfig.
Returns:
List of torch tensors, each of which is a split of the current parameter.
"""
# alias the global index to idx
idx = global_idx

swiglu_activation = 'swiglu' in str(model_cfg.get('activation', '')).lower()

if param.shape == partitions[0][idx].shape:
split = [partitions[0][idx].data] * tp_size
logging.debug(">> Perfect match, no splitting needed")
@@ -156,14 +160,23 @@ def compute_tp_splits(
for i in range(tp_size):
tp_qkv = torch.cat([tp_qkv_splits[item] for item in range(i, tp_size * 2, tp_size)])
split.append(tp_qkv)
elif 'dense_h_to_4h.weight' in param_name and swiglu_activation:
# For Megatron GPT models with the swiglu activation, dense_h_to_4h fuses the
# gated linear unit's gate ('W') and value ('V') weights along dim 0.
# Split each half per TP rank and re-pair W_i with V_i for every rank.
w_split, k_split = torch.chunk(partitions[0][idx].data, 2, dim=0)
w_split = torch.chunk(w_split, tp_size, dim=0)
k_split = torch.chunk(k_split, tp_size, dim=0)
split = [torch.cat(weights, dim=0) for weights in zip(w_split, k_split)] # split per tp rank

# Regular split for Megatron and NeMo-Megatron models.
else:
split = torch.split(partitions[0][idx].data, param.shape[0], dim=0)

return split
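For reference, a minimal standalone sketch of the gated-linear-unit split above (not part of the diff; tp_size, ffn_hidden, and hidden are illustrative placeholders). The fused dense_h_to_4h weight stacks the gate half ('W') on top of the value half ('V') along dim 0, so each TP rank must receive a matched [W_i; V_i] pair rather than a plain contiguous slice. Note the guard `'swiglu' in str(model_cfg.get('activation', '')).lower()` enables this path for any activation name containing 'swiglu'.

import torch

tp_size = 2
ffn_hidden, hidden = 8, 4  # toy sizes; the real weight is [2 * ffn_hidden, hidden]
fused = torch.randn(2 * ffn_hidden, hidden)

w_half, v_half = torch.chunk(fused, 2, dim=0)        # gate ('W') half, value ('V') half
w_parts = torch.chunk(w_half, tp_size, dim=0)        # per-rank gate slices
v_parts = torch.chunk(v_half, tp_size, dim=0)        # per-rank value slices
split = [torch.cat(pair, dim=0) for pair in zip(w_parts, v_parts)]  # one [W_i; V_i] per rank

assert all(s.shape == (2 * ffn_hidden // tp_size, hidden) for s in split)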


def compute_tp_merge(idx, name, param, partitions_pp):
def compute_tp_merge(idx, name, param, partitions_pp, model_cfg):
"""
Function to compute the partition merge required for tensor-parallelism.
@@ -173,17 +186,33 @@ def compute_tp_merge(idx, name, param, partitions_pp):
param: The parameter to be merged under TP 1 PP 1.
partitions_pp: List of all TP partitions of the flattened parameter of the current model for a given PP rank
(TP X PP Y). Indexed as partitions_pp[tp_rank][idx].
model_cfg: The model config as an OmegaConf DictConfig.
Returns:
The concatenated parameter for TP 1 PP 1.
"""
swiglu_activation = 'swiglu' in str(model_cfg.get('activation', '')).lower()

# Logic from original TP rank change
if param.shape == partitions_pp[0][idx].shape:
concated = partitions_pp[0][idx].data
elif param.shape[0] == partitions_pp[0][idx].shape[0]:
concated = torch.cat([partitions_pp[i][idx].data for i in range(len(partitions_pp))], dim=-1)
else:
concated = torch.cat([partitions_pp[i][idx].data for i in range(len(partitions_pp))], dim=0)

# Logic for Swiglu activation
if 'dense_h_to_4h.weight' in name and swiglu_activation:
# concat all the first halves ('W's) and all the second halves ('V's)
wk_splits = []
for tpr in range(len(partitions_pp)):
wk_splits.append(torch.chunk(partitions_pp[tpr][idx].data, 2, dim=0))

w_split = torch.cat([w[0] for w in wk_splits], dim=0)
k_split = torch.cat([w[1] for w in wk_splits], dim=0)
concated = torch.cat([w_split, k_split], dim=0)
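As a companion sketch (again not part of the diff, with the same illustrative placeholder sizes), merging back to TP 1 is the inverse operation: split each rank's [W_i; V_i] block into its halves, then concatenate all gate halves followed by all value halves so the merged layout matches the unpartitioned weight.

import torch

tp_size, ffn_hidden, hidden = 2, 8, 4
per_rank = [torch.randn(2 * ffn_hidden // tp_size, hidden) for _ in range(tp_size)]  # one [W_i; V_i] per rank

halves = [torch.chunk(r, 2, dim=0) for r in per_rank]   # (W_i, V_i) for each rank
w_merged = torch.cat([h[0] for h in halves], dim=0)     # all gate halves, in rank order
v_merged = torch.cat([h[1] for h in halves], dim=0)     # all value halves, in rank order
concated = torch.cat([w_merged, v_merged], dim=0)       # [W; V], shape [2 * ffn_hidden, hidden]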

# Trim padding
if concated.shape != param.shape:
logging.info(
f"Warning: Shape mismatch for parameter {name} required shape: {param.shape}, merged shape: {concated.shape}. Narrowing to match required size."
@@ -301,7 +330,16 @@ def compute_splits(self, model, partitions, idx, tp_rank, pp_rank, pp_split_rank

# Tensor Parallel Splitting
split = compute_tp_splits(
param_name, param, partitions, idx, tp_size, pp_size, pp_rank, pp_split_rank, self.megatron_legacy
param_name,
param,
partitions,
idx,
tp_size,
pp_size,
pp_rank,
pp_split_rank,
self.megatron_legacy,
model.cfg,
)

splits.append(split)
@@ -419,7 +457,16 @@ def compute_splits(self, model, partitions, idx, tp_rank, pp_rank, pp_split_rank

# Tensor Parallel Splitting
split = compute_tp_splits(
param_name, param, partitions, idx, tp_size, pp_size, pp_rank, pp_split_rank, self.megatron_legacy
param_name,
param,
partitions,
idx,
tp_size,
pp_size,
pp_rank,
pp_split_rank,
self.megatron_legacy,
model.cfg,
)

splits.append(split)
@@ -445,12 +492,13 @@ def compute_splits(self, model, partitions, idx, tp_rank, pp_rank, pp_split_rank
param_name,
param,
partitions,
0,
tp_size,
pp_size,
pp_rank,
pp_split_rank,
self.megatron_legacy,
global_idx=0,
tp_size=tp_size,
pp_size=pp_size,
pp_rank=pp_rank,
pp_split_rank=pp_split_rank,
megatron_legacy=self.megatron_legacy,
model_cfg=model.cfg,
)
splits.insert(self.intermediate_shared_embedding_location, split)
break
@@ -534,7 +582,7 @@ def merge_partition(model, partitions: Dict[int, List[List[torch.Tensor]]], writ
)

# Original TP rank change logic
concated = compute_tp_merge(idx, name, param, partitions_pp)
concated = compute_tp_merge(idx, name, param, partitions_pp, model.cfg)

# Update the model parameter with the merged tensor
param.data = concated
@@ -656,6 +704,7 @@ def split_tp_partition_only(model, partitions, tp_size, write_path=None, megatro
pp_rank=0,
pp_split_rank=0,
megatron_legacy=megatron_legacy,
model_cfg=model.cfg,
)
splits.append(split)
idx += 1
