Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion megatron/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,7 +380,7 @@ def train_step(forward_step_func, data_iterator,
assert isinstance(model[0], deepspeed.PipelineEngine), model
loss = model[0].train_batch(data_iter=data_iterator)
skipped_iter = 0
grad_norm = 0.
grad_norm = model[0].get_global_grad_norm()
num_zeros_in_grad = 0
return {'lm loss' : loss}, skipped_iter, grad_norm, num_zeros_in_grad

Expand Down
12 changes: 9 additions & 3 deletions pretrain_gpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
import os
import subprocess


def model_provider(pre_process=True, post_process=True):
"""Build the model."""

Expand All @@ -41,9 +42,10 @@ def model_provider(pre_process=True, post_process=True):

args = get_args()
with deepspeed.zero.Init(data_parallel_group=mpu.get_data_parallel_group(),
remote_device=None if args.remote_device=='none' else args.remote_device,
config=args.deepspeed_config,
enabled=args.zero_stage==3):
remote_device=None if args.remote_device == 'none' else args.remote_device,
config_dict_or_path=args.deepspeed_config,
enabled=args.zero_stage == 3,
mpu=mpu):
if args.deepspeed:
model = GPTModelPipe(
num_tokentypes=0,
Expand Down Expand Up @@ -112,6 +114,7 @@ def get_batch(data_iterator):

return tokens, labels, loss_mask, attention_mask, position_ids


def get_batch_pipe(data):
"""Modification of `get_batch` to work on `next(data_iterator)` instead of `data_iterator`"""
args = get_args()
Expand Down Expand Up @@ -139,6 +142,7 @@ def get_batch_pipe(data):

return (tokens, position_ids, attention_mask), (labels, loss_mask)


def loss_func(loss_mask, output_tensor):
losses = output_tensor.float()
loss_mask = loss_mask.view(-1).float()
Expand Down Expand Up @@ -185,10 +189,12 @@ def train_valid_test_datasets_provider(train_val_test_num_samples):

return train_ds, valid_ds, test_ds


def command_exists(cmd):
result = subprocess.Popen(f'type {cmd}', stdout=subprocess.PIPE, shell=True)
return result.wait() == 0


def git_ds_info():
from deepspeed.env_report import main as ds_report
ds_report()
Expand Down