Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[HybridParallel]Rebuild code for pipeline #36396

Merged
merged 2 commits into from
Oct 14, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 32 additions & 23 deletions python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,37 +77,22 @@ def __init__(self, layers, hcg, strategy):
logger.info("start broadcast dp parameters")
broadcast_dp_parameters(self._layers, self._hcg)

def train_batch(self, data, optimizer, lr_scheduler=None, scaler=None):
assert isinstance(optimizer, HybridParallelOptimizer), (
'optimizer should be HybridParallelOptimizer subclass.')

assert fluid.framework._dygraph_tracer()._has_grad, (
'Please enable the generation of gradients.')

if self.is_first_stage or self.is_last_stage:
assert data is not None, (
"For the first and the last stage, the data must be set.")
else:
data = None
def forward_backward_pipeline(self, data, scaler=None):
# use the 1f1b scheduling strategy.
# this strategy is inspired by:
# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/schedules.py

self.optimizer = optimizer
self.lr_scheduler = lr_scheduler
self.scaler = scaler
self.data = data
self._compute_loss = True

self._layers.train()
# store data for train
self.data = data

# store total loss of entire batch
self.total_loss = None

# store data id for micro_batch
self.micro_batch_id = 0

# Next, use the 1f1b scheduling strategy.
# this strategy is inspired by:
# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/schedules.py

startup_steps = (self.num_stages - self.stage_id - 1)
startup_steps = min(startup_steps, self.accumulate_steps)
steady_steps = self.accumulate_steps - startup_steps
Expand Down Expand Up @@ -161,11 +146,35 @@ def train_batch(self, data, optimizer, lr_scheduler=None, scaler=None):

self._layers.allreduce_shared_weight_gradients()

self.train_loss = self._broadcast_final_loss()
train_loss = self._broadcast_final_loss()

return train_loss

def train_batch(self, data, optimizer, lr_scheduler=None, scaler=None):
assert isinstance(optimizer, HybridParallelOptimizer), (
'optimizer should be HybridParallelOptimizer subclass.')

assert fluid.framework._dygraph_tracer()._has_grad, (
'Please enable the generation of gradients.')

if self.is_first_stage or self.is_last_stage:
assert data is not None, (
"For the first and the last stage, the data must be set.")
else:
data = None

self.optimizer = optimizer
self.lr_scheduler = lr_scheduler

self._layers.train()

# 1f1b for pipeline
train_loss = self.forward_backward_pipeline(data, scaler)

# optimizer
self._optimizer_step()
return self.train_loss

return train_loss

def eval_batch(self, data, compute_loss=False):
self._layers.eval()
Expand Down
10 changes: 8 additions & 2 deletions python/paddle/fluid/dygraph/parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,9 +354,15 @@ def sync_params_buffers(model,
if not isinstance(param, core.VarBase):
raise TypeError("The data type of '%s' must be Varbase" %
param.name)

# Parameters marked is_distributed do not need to be synced in mp (model parallel) mode
if is_model_parallel and isinstance(param, ParamBase):
if param.is_distributed:
if isinstance(param, ParamBase):
if is_model_parallel and param.is_distributed:
continue

# NOTE(shenliang03): Support parameters that do not need to be synchronized,
# such as MoE expert parameters.
if getattr(param, "no_sync", False):
continue

model_vars.append(param.detach())
Expand Down