Skip to content

Commit

Permalink
Add ub communicator initialization to validation step (#6807)
Browse files Browse the repository at this point in the history
  • Loading branch information
erhoo82 authored and web-flow committed Jun 5, 2023
1 parent 76fc488 commit 16ff138
Showing 1 changed file with 31 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -499,12 +499,39 @@ def fwd_bwd_step(self, dataloader_iter, batch_idx, forward_only):

return loss_mean

def initialize_ub_func(self):
    """Initialize Transformer Engine userbuffer (UB) communicators for
    tensor-parallel communication overlap.

    Builds the flattened activation input shape
    ``[encoder_seq_length * micro_batch_size, hidden_size]`` from the model
    config, optionally loads a UB overlap config from the YAML file named by
    ``ub_tp_comm_overlap_cfg``, and calls ``te_module.initialize_ub``.
    Clears ``self.initialize_ub`` so this runs only once per process.
    """
    input_shape = [
        self.cfg.get('encoder_seq_length') * self.cfg.get('micro_batch_size'),
        self.cfg.get('hidden_size'),
    ]
    # Default to None so a failed config read falls back to TE's built-in
    # UB defaults instead of raising UnboundLocalError at the call below.
    ub_cfgs = None
    ub_cfg_file_name = self.cfg.get('ub_tp_comm_overlap_cfg', None)
    if ub_cfg_file_name is not None:
        try:
            import yaml

            with open(ub_cfg_file_name, 'r') as ub_cfg_file:
                ub_cfgs = yaml.safe_load(ub_cfg_file)
        # OSError covers a missing/unreadable file (FileNotFoundError etc.),
        # which the original except tuple silently let propagate.
        except (ImportError, TypeError, OSError):
            print("Failed to read ub_tp_comm_overlap config file.")
    te_module.initialize_ub(
        shape=input_shape,
        tp_size=self.cfg.get('tensor_model_parallel_size'),
        use_fp8=self.cfg.get('fp8'),
        ub_cfgs=ub_cfgs,
    )
    # Flag checked by training_step/validation_step to avoid re-initialization.
    self.initialize_ub = False

def training_step(self, dataloader_iter, batch_idx):
"""
We pass the dataloader iterator function to the micro-batch scheduler.
The input batch to each micro-batch is fetched using the dataloader function
in the micro-batch fwd function.
"""
# Initialize userbuffer communicators.
if self.initialize_ub:
self.initialize_ub_func()

# we zero grads here because we also call backward in the megatron-core fwd/bwd functions
self._optimizer.zero_grad()
Expand Down Expand Up @@ -829,6 +856,10 @@ def validation_step(self, dataloader_iter, batch_idx):
from the dataloader to produce a list of microbatches.
The list of microbatches is then piped through the pipeline using megatron-core fwd/bwd functions.
"""
# Initialize userbuffer communicators.
if self.initialize_ub:
self.initialize_ub_func()

if isinstance(self.model, list):
for model_module in self.model:
model_module.eval()
Expand Down

0 comments on commit 16ff138

Please sign in to comment.