Tensor-parallel communication overlap with userbuffer backend #6780

Merged: 3 commits, Jun 1, 2023
7 changes: 7 additions & 0 deletions examples/nlp/language_modeling/conf/megatron_gpt_config.yaml
@@ -166,6 +166,13 @@ model:
fp8_amax_compute_algo: most_recent # 'most_recent' or 'max'. Algorithm for computing amax from history
reduce_amax: True # Perform reduction to sync amax tensors across GPUs after every iteration
use_emha: False # Use fused multi-head attention for large sequence-length. Note this is not yet supported. Please set to False.
ub_tp_comm_overlap: False
# Use the userbuffer backend to overlap tensor-parallel communication with compute.
# This feature requires Transformer Engine and sequence parallelism to be enabled and currently supports only GPT models.
ub_tp_comm_overlap_cfg: null
# A YAML file with userbuffer communicator configurations. The file should provide `method`, `dtype`, `num_sm`, `num_splits`,
# `cga_size`, `set_sm_margin`, and `aggregate` for the communicators that use custom settings.
# If no configuration file is provided, a default setting is used for all communicators.
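For illustration, a minimal sketch of what a parsed ub_tp_comm_overlap_cfg could look like is shown below; the communicator name "qkv_fprop" and all values are assumptions made for this example, while only the key names come from the comment above.

# Minimal sketch of a parsed userbuffer config (e.g., the result of yaml.safe_load on the
# file pointed to by ub_tp_comm_overlap_cfg). The communicator name "qkv_fprop" and every
# value below are illustrative assumptions; only the keys method, dtype, num_sm, num_splits,
# cga_size, set_sm_margin, and aggregate come from the config comment above.
ub_cfgs = {
    "qkv_fprop": {
        "method": "pipeline",    # assumed overlap method for this communicator
        "dtype": "bf16",         # assumed dtype of the communication buffer
        "num_sm": 2,             # assumed number of SMs reserved for communication kernels
        "num_splits": 4,         # assumed number of chunks the overlapped GEMM is split into
        "cga_size": 2,           # assumed cooperative grid array size
        "set_sm_margin": False,  # assumed SM-margin setting
        "aggregate": False,      # assumed chunk-aggregation setting
    },
}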

data:
# Path to data must be specified by the user.
@@ -163,6 +163,7 @@ def __init__(
fp8_amax_compute_algo='most_recent',
reduce_amax=True,
use_emha=False,
ub_tp_comm_overlap=False,
):
super(GPTModel, self).__init__(share_token_embeddings=share_embeddings_and_output_weights)

@@ -243,6 +244,7 @@ def __init__(
fp8_amax_compute_algo=fp8_amax_compute_algo,
reduce_amax=reduce_amax,
use_emha=use_emha,
ub_tp_comm_overlap=ub_tp_comm_overlap,
)

if self.share_embeddings_and_output_weights:
@@ -123,6 +123,7 @@ def __init__(self, cfg: DictConfig, trainer: Trainer, no_lm_init=True):
global_batch_size=cfg.get('global_batch_size'),
rampup_batch_size=cfg.get('rampup_batch_size'),
use_fp8=cfg.get('fp8', False),
init_mpi_proc_group=cfg.get('ub_tp_comm_overlap', False),
seed=self.cfg.get('seed', 1234),
apex_transformer_log_level=self.cfg.get('apex_transformer_log_level', 30),
)
@@ -540,6 +541,14 @@ def _validate_and_override_config(self):
'Make sure the number of model chunks is the same across all pipeline stages.'
)

if self.cfg.get('ub_tp_comm_overlap', False):
if not self.cfg.get('transformer_engine', False) or not self.cfg.get('sequence_parallel', False):
logging.info(
"Userbuffer tensor-parallel communication overlap is available with both Transformer Engine and sequence-parallelism."
)
with open_dict(self.cfg):
self.cfg.ub_tp_comm_overlap = False
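For reference, a hedged sketch of the model-config entries that this check and the later training_step read for userbuffer overlap; the literal values are illustrative, and only the key names appear in this PR.

# Sketch of the cfg keys consulted for userbuffer overlap (values are illustrative).
cfg = {
    "transformer_engine": True,      # required, otherwise ub_tp_comm_overlap is reset to False
    "sequence_parallel": True,       # required, otherwise ub_tp_comm_overlap is reset to False
    "ub_tp_comm_overlap": True,      # enables the userbuffer overlap path
    "ub_tp_comm_overlap_cfg": None,  # optional path to a per-communicator YAML file
}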

def is_data_parallel_rank_zero(self):
if is_global_rank_zero():
return True
@@ -81,6 +81,7 @@

try:
import transformer_engine
from transformer_engine.pytorch import module as te_module

HAVE_TE = True

@@ -179,6 +180,7 @@ def __init__(self, cfg: DictConfig, trainer: Trainer):
self._nsys_profile_end_step *= grad_accum_steps

self.get_attention_mask_from_fusion = self.cfg.get('get_attention_mask_from_fusion', True)
self.initialize_ub = self.cfg.get('ub_tp_comm_overlap', False)

def get_gpt_module_list(self):
if isinstance(self.model, list):
@@ -254,6 +256,7 @@ def model_provider_func(self, pre_process, post_process):
fp8_amax_compute_algo=self.cfg.get('fp8_amax_compute_algo', 'most_recent'),
reduce_amax=self.cfg.get('reduce_amax', True),
use_emha=self.cfg.get('use_emha', False),
ub_tp_comm_overlap=self.cfg.get('ub_tp_comm_overlap', False),
)

return model
@@ -410,6 +413,31 @@ def training_step(self, dataloader_iter, batch_idx):
The input batch to each micro-batch is fetched using the dataloader function
in the micro-batch fwd function.
"""
# Initialize userbuffer communicators. Initialization is done only once at the
# beginning of the first training step.
if self.initialize_ub:
input_shape = [
self.cfg.get('encoder_seq_length') * self.cfg.get('micro_batch_size'),
self.cfg.get('hidden_size'),
]
ub_cfg_file_name = self.cfg.get('ub_tp_comm_overlap_cfg', None)
if ub_cfg_file_name is not None:
try:
import yaml

with open(ub_cfg_file_name, 'r') as ub_cfg_file:
ub_cfgs = yaml.safe_load(ub_cfg_file)
except (ImportError, TypeError):
print("Fail to read ub_tp_comm_overlap config file.")
else:
ub_cfgs = None
te_module.initialize_ub(
shape=input_shape,
tp_size=self.cfg.get('tensor_model_parallel_size'),
use_fp8=self.cfg.get('fp8'),
ub_cfgs=ub_cfgs,
)
self.initialize_ub = False
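As a worked example of the buffer shape computed above (the numbers are hypothetical, not taken from this PR): with an encoder sequence length of 2048, a micro-batch size of 1, and a hidden size of 4096, the userbuffers are registered for a [2048, 4096] activation.

# Hypothetical sizes, for illustration only: userbuffers are sized for the flattened
# (encoder_seq_length * micro_batch_size, hidden_size) activation of one micro-batch.
encoder_seq_length = 2048
micro_batch_size = 1
hidden_size = 4096
input_shape = [encoder_seq_length * micro_batch_size, hidden_size]  # [2048, 4096]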

# we zero grads here because we also call backward in the megatron-core fwd/bwd functions
self._optimizer.zero_grad()
@@ -116,6 +116,7 @@ def get_language_model(
fp8_amax_compute_algo='most_recent',
reduce_amax=True,
use_emha=False,
ub_tp_comm_overlap=False,
):
"""Build language model and return along with the key to save."""

@@ -191,6 +192,7 @@
fp8_amax_compute_algo=fp8_amax_compute_algo,
reduce_amax=reduce_amax,
use_emha=use_emha,
ub_tp_comm_overlap=ub_tp_comm_overlap,
)
# key used for checkpoints.
language_model_key = 'language_model'
@@ -497,6 +499,7 @@ def __init__(
fp8_amax_compute_algo='most_recent',
reduce_amax=True,
use_emha=False,
ub_tp_comm_overlap=False,
):
super(TransformerLanguageModel, self).__init__(share_token_embeddings=share_embeddings_and_output_weights)

@@ -602,6 +605,7 @@ def __init__(
fp8_amax_compute_algo=fp8_amax_compute_algo,
reduce_amax=reduce_amax,
use_emha=use_emha,
ub_tp_comm_overlap=ub_tp_comm_overlap,
)
self._encoder_key = 'encoder'

2 changes: 2 additions & 0 deletions nemo/collections/nlp/modules/common/megatron/megatron_init.py
@@ -67,6 +67,7 @@ def initialize_model_parallel_for_nemo(
global_batch_size=None,
rampup_batch_size=None,
use_fp8=False,
init_mpi_proc_group=False,
seed=1234,
apex_transformer_log_level=30,
):
@@ -83,6 +84,7 @@
app_state.pipeline_model_parallel_size = pipeline_model_parallel_size
app_state.virtual_pipeline_model_parallel_size = virtual_pipeline_model_parallel_size
app_state.use_fp8 = use_fp8
app_state.init_mpi_proc_group = init_mpi_proc_group
(
app_state.tensor_model_parallel_rank,
app_state.pipeline_model_parallel_rank,
4 changes: 4 additions & 0 deletions nemo/collections/nlp/modules/common/megatron/transformer.py
@@ -792,6 +792,7 @@ def __init__(
layer_type: str = "encoder",
drop_path_rate: float = 0,
use_emha: bool = False,
ub_tp_comm_overlap: bool = False,
autocast_dtype: Any = 16,
zero_centered_gamma: bool = False,
) -> None:
@@ -824,6 +825,7 @@
set_parallel_mode=tp_size > 1,
fuse_qkv_params=True,
zero_centered_gamma=zero_centered_gamma,
ub_tp_comm_overlap=ub_tp_comm_overlap,
)
# use_emha=use_emha,

@@ -919,6 +921,7 @@ def __init__(
fp8_amax_compute_algo='most_recent',
reduce_amax=True,
use_emha=False,
ub_tp_comm_overlap=False,
normalize_attention_scores=True,
multi_query_attention=False,
num_moe_experts=1,
@@ -1058,6 +1061,7 @@ def build_layer(layer_number):
apply_residual_connection_post_layernorm=False,
autocast_dtype=precision,
use_emha=use_emha,
ub_tp_comm_overlap=ub_tp_comm_overlap,
zero_centered_gamma=normalization == 'layernorm1p',
)
else:
6 changes: 5 additions & 1 deletion nemo/collections/nlp/parts/nlp_overrides.py
@@ -180,6 +180,10 @@ def init_model_parallel(self, global_rank: int, world_size: int) -> None:
app_state.data_parallel_size = parallel_state.get_data_parallel_world_size()
app_state.pipeline_model_parallel_group = parallel_state.get_pipeline_model_parallel_group()

# create MPI process group for UCX-based communication APIs
if app_state.init_mpi_proc_group:
torch.distributed.new_group(backend='mpi')
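As an aside (an assumption about the runtime environment, not part of this change): the 'mpi' backend is only available when PyTorch was built with MPI support, so a defensive variant of the call above could look like the following sketch.

# Illustrative guard only; the PR itself calls new_group(backend='mpi') unconditionally
# whenever app_state.init_mpi_proc_group is set.
import torch.distributed as dist

if dist.is_mpi_available():
    dist.new_group(backend='mpi')  # MPI process group used by the userbuffer/UCX path
else:
    raise RuntimeError("ub_tp_comm_overlap requires PyTorch built with MPI support")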

def save_checkpoint(
self, checkpoint: Dict[str, Any], filepath: Union[str, Path], storage_options: Optional[Any] = None
) -> None:
@@ -405,7 +409,7 @@ class PEFTSaveRestoreConnector(NLPSaveRestoreConnector):
Args:
peft_model_nemo_path: Used to provide the .nemo file corresponding to a PEFT model (which will only contain a small set of params)
peft_model_ckpt_path: Used to provide the path to .ckpt files of a PEFT model. This is required when no .nemo is available (yet), such as during resumed training.
If both are provided, the peft_model_ckpt_path takes precedence.
If neither are provided, PEFT params are initialized at random (not loaded from any external source).
"""

17 changes: 17 additions & 0 deletions nemo/utils/app_state.py
@@ -55,6 +55,7 @@ def __init__(self):
self._data_parallel_group = None
self._megatron_checkpoint_version = None
self._use_fp8 = False
self._init_mpi_proc_group = False

self._random_seed = None

@@ -363,6 +364,22 @@ def use_fp8(self, use_fp8):
"""
self._use_fp8 = use_fp8

@property
def init_mpi_proc_group(self):
""" Property sets the initialization of mpi process group.
Returns:
Initialize mpi process group.
"""
return self._init_mpi_proc_group

@init_mpi_proc_group.setter
def init_mpi_proc_group(self, init_mpi_proc_group):
""" Property sets the initialization of mpi process group.
Args:
init_mpi_proc_group: Initialize mpi process group.
"""
self._init_mpi_proc_group = init_mpi_proc_group

@property
def random_seed(self):
""" Property returns the random seed.