Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Load ub_cfg from hydra config #7003

Merged
merged 12 commits into from
Aug 12, 2023
3 changes: 3 additions & 0 deletions examples/nlp/language_modeling/conf/megatron_gpt_config.yaml
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
defaults:
- optional tp_overlap@model.ub_tp_comm_overlap_cfg:

name: megatron_gpt
restore_from_path: null # used when starting from a .nemo file

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
# UB communicator configurations
# Model configs: A100/175B/TP4/MBS1/SeqLen2K/BF16

# Bulk overlap with AllGather
qkv_dgrad:
method: bulk
num_sm: 2
set_sm_margin: 0

qkv_wgrad:
method: bulk
num_sm: 2
set_sm_margin: 0

fc1_dgrad:
method: bulk
num_sm: 2
set_sm_margin: 0

fc1_wgrad:
method: bulk
num_sm: 2
set_sm_margin: 0

## Ring-exchange overlap with AllGather
qkv_fprop:
method: ring_exchange
aggregate: 0

proj_dgrad:
method: ring_exchange
aggregate: 0

fc1_fprop:
method: ring_exchange
aggregate: 0

fc2_dgrad:
method: ring_exchange
aggregate: 0

# Chunked-collective overlap with ReduceScatter
proj_fprop:
method: pipeline
num_sm: 4
num_splits: 4
set_sm_margin: 0

fc2_fprop:
method: pipeline
num_sm: 4
num_splits: 4
set_sm_margin: 0
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
# UB communicator configurations
# Model configs: A100/175B/TP4/MBS2/SeqLen2K/BF16

# Bulk overlap with AllGather
qkv_dgrad:
method: bulk
num_sm: 2
set_sm_margin: 0

qkv_wgrad:
method: bulk
num_sm: 2
set_sm_margin: 0

fc1_dgrad:
method: bulk
num_sm: 2
set_sm_margin: 0

fc1_wgrad:
method: bulk
num_sm: 2
set_sm_margin: 0

## Ring-exchange overlap with AllGather
qkv_fprop:
method: ring_exchange
aggregate: 0

proj_dgrad:
method: ring_exchange
aggregate: 0

fc1_fprop:
method: ring_exchange
aggregate: 0

fc2_dgrad:
method: ring_exchange
aggregate: 0

# Chunked-collective overlap with ReduceScatter
proj_fprop:
method: pipeline
num_sm: 8
num_splits: 4
set_sm_margin: 0

fc2_fprop:
method: pipeline
num_sm: 4
num_splits: 4
set_sm_margin: 0
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
# UB communicator configurations
# Model configs: H100/175B/TP4/MBS1/SeqLen2K/FP8

# Bulk overlap with AllGather / ReduceScatter
qkv_dgrad:
method: bulk
num_sm: 4
cga_size: 2
set_sm_margin: 0

qkv_wgrad:
method: bulk
num_sm: 8
cga_size: 2
set_sm_margin: 0

fc1_dgrad:
method: bulk
num_sm: 2
cga_size: 2
set_sm_margin: 0

fc1_wgrad:
method: bulk
num_sm: 4
cga_size: 2
set_sm_margin: 0

## Ring-exchange overlap with AllGather
qkv_fprop:
method: ring_exchange
aggregate: 0

proj_dgrad:
method: ring_exchange
aggregate: 0

fc1_fprop:
method: ring_exchange
aggregate: 0

fc2_dgrad:
method: ring_exchange
aggregate: 1

# Chunked-collective overlap with ReduceScatter
proj_fprop:
method: pipeline
num_sm: 24
cga_size: 2
num_splits: 4
set_sm_margin: 1

fc2_fprop:
method: pipeline
num_sm: 20
cga_size: 2
num_splits: 4
set_sm_margin: 1
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
# UB communicator configurations
# Model configs: H100/175B/TP8/MBS2/SeqLen2K/FP8

# Bulk overlap with AllGather
qkv_dgrad:
method: bulk
num_sm: 8
cga_size: 2
set_sm_margin: 0

qkv_wgrad:
method: bulk
num_sm: 16
cga_size: 2
set_sm_margin: 0

fc1_dgrad:
method: bulk
num_sm: 4
cga_size: 2
set_sm_margin: 0

fc1_wgrad:
method: bulk
num_sm: 16
cga_size: 2
set_sm_margin: 0

## Ring-exchange overlap with AllGather
qkv_fprop:
method: ring_exchange
aggregate: 0

proj_dgrad:
method: ring_exchange
aggregate: 1

fc1_fprop:
method: ring_exchange
aggregate: 0

fc2_dgrad:
method: ring_exchange
aggregate: 0

# Chunked-collective overlap with ReduceScatter
proj_fprop:
method: pipeline
num_sm: 16
cga_size: 2
num_splits: 4
set_sm_margin: 1

fc2_fprop:
method: pipeline
num_sm: 24
cga_size: 2
num_splits: 4
set_sm_margin: 1
Original file line number Diff line number Diff line change
Expand Up @@ -511,20 +511,16 @@ def fwd_bwd_step(self, dataloader_iter, batch_idx, forward_only):
return loss_mean

def initialize_ub_func(self):
ub_cfgs = self.cfg.get('ub_tp_comm_overlap_cfg', None)
if ub_cfgs is None:
warnings.warn(
"Couldn't find TP config. Please check the path correctness. Initializing TP comm overlap with the default config."
)

input_shape = [
self.cfg.get('encoder_seq_length') * self.cfg.get('micro_batch_size'),
self.cfg.get('hidden_size'),
]
ub_cfg_file_name = self.cfg.get('ub_tp_comm_overlap_cfg', None)
ub_cfgs = None
if ub_cfg_file_name is not None:
try:
import yaml

with open(ub_cfg_file_name, 'r') as ub_cfg_file:
ub_cfgs = yaml.safe_load(ub_cfg_file)
except (ImportError, TypeError):
logging.error(f"Fail to read ub_tp_comm_overlap config file: {ub_cfg_file_name}.")
te_module.initialize_ub(
shape=input_shape,
tp_size=self.cfg.get('tensor_model_parallel_size'),
Expand Down