[BUG] Pipeline Dataloader Sampler: shuffle=False
#5619
Comments
It's unclear to me why this is the default, and confusing that there is no documentation of this behavior. Shuffle should probably be a parameter of the deepspeed config file.
After struggling with a bug all morning, I believe I now understand why shuffle must be set to False here. Let's first take a look at how the engine loads data when pipeline parallelism is enabled:

def _exec_load_micro_batch(self, buffer_id):
    if self.wall_clock_breakdown():
        self.timers(BATCH_INPUT_TIMER).start()

    batch = self._next_batch()

    if self.is_first_stage():
        loaded = None
        if torch.is_tensor(batch[0]):
            loaded = batch[0].clone().to(self.device).detach()
            if self._config.pipeline['activation_checkpoint_interval'] > 0 and self._config.pipeline[
                    'use_reentrant']:
                loaded.requires_grad = loaded.is_floating_point()
        else:
            assert isinstance(batch[0], (tuple, list))
            # Assume list or tuple
            loaded = []
            for x in batch[0]:
                assert torch.is_tensor(x)
                mine = x.clone().detach().to(self.device)
                if self._config.pipeline['activation_checkpoint_interval'] > 0 and self._config.pipeline[
                        'use_reentrant']:
                    mine.requires_grad = mine.is_floating_point()
                loaded.append(mine)
            loaded = tuple(loaded)

        self.pipe_buffers['inputs'][buffer_id] = loaded

    if self.is_last_stage():
        loaded = batch[1]
        if torch.is_tensor(batch[1]):
            loaded = batch[1].to(self.device)
        # XXX: torch 1.6.0 DataLoader will auto convert tuple to list
        elif isinstance(batch[1], (tuple, list)):
            loaded = []
            for x in batch[1]:
                assert torch.is_tensor(x)
                x = x.to(self.device).detach()
                loaded.append(x)
            loaded = tuple(loaded)
        self.pipe_buffers['labels'][buffer_id] = loaded

From the above code snippet, we can observe that both the first and last stages independently fetch data from the DataLoader. The first stage retains the inputs (i.e., batch[0]) while the last stage retains the labels (i.e., batch[1]), so if the two stages drew batches in different orders, inputs and labels would no longer correspond.

P.S. What's the correct behavior if we want to enable data shuffling in DeepSpeed?
@xianshunw No, mismatching will not appear. The
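To make the ordering question concrete, here is a minimal sketch, assuming standard PyTorch DistributedSampler semantics and toy values for the dataset, world size, and rank: two sampler instances constructed independently, the way the first and last stages of one pipeline would construct them, yield the same permutation as long as they share the same rank, seed, and epoch, even with shuffle=True.

from torch.utils.data.distributed import DistributedSampler

dataset = list(range(16))  # toy stand-in dataset

# Built independently, as the first and last pipeline stages would build them:
# same data-parallel rank, same (default) seed, shuffling enabled.
sampler_first = DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=True)
sampler_last = DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=True)
sampler_first.set_epoch(0)
sampler_last.set_epoch(0)

# Identical permutations, so inputs (first stage) and labels (last stage) stay aligned.
assert list(sampler_first) == list(sampler_last)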
@xianshunw @avicooper1 Setting
As we pass a data parallel rank to the loader, I think shuffling should properly work. Feel free to submit a PR to set
@tohtana @Coobiw In addition to the Pipeline Dataloader sampler, the DeepSpeed runtime engine dataloader also has a similar problem; the code is in https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/runtime/engine.py#L1777. Do you think we should submit a PR to set
I'm not sure why it is also set to False, but I agree that
Apologies, it was my fault. The issue arose when I manually configured the dataloader. It turned out that the problem was due to incorrect initialization of the dataloader.
Making it configurable is better.
@xianshunw Yeah, I've used the following sampler for my custom dataloader:

sampler = torch.utils.data.distributed.DistributedSampler(
    datasets['train'],
    num_replicas=engine.dp_world_size,
    rank=engine.mpu.get_data_parallel_rank(),
    shuffle=True
)

There is no problem.
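One caveat worth noting about this workaround, which is standard PyTorch DistributedSampler behavior rather than anything DeepSpeed-specific: with shuffle=True the permutation only changes between epochs if you call sampler.set_epoch(epoch) before each epoch; otherwise every epoch replays the same order. A minimal sketch with toy stand-ins for the dataset, world size, and rank:

import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler

# Toy stand-ins for datasets['train'], engine.dp_world_size and
# engine.mpu.get_data_parallel_rank() from the snippet above.
dataset = TensorDataset(torch.arange(16))
sampler = DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=True)
loader = DataLoader(dataset, batch_size=2, sampler=sampler)

for epoch in range(3):
    sampler.set_epoch(epoch)  # without this, every epoch sees the same permutation
    for (batch,) in loader:
        pass  # training step would go here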
Thanks @xianshunw and @Coobiw - we will work on making it configurable, but at least with the current unit tests, the linked PR seems to hang with
Describe the bug
When I read the source code that builds the dataloader in PipelineEngine, I found shuffle=False in the sampler. I want to know why you set shuffle to False, not True. The code is in deepspeed/runtime/pipe/engine.py, in the PipelineEngine class, def _build_data_iter.
deepspeed version: 0.12.4
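For reference, the method being asked about looks roughly like the sketch below. This is a paraphrase of _build_data_iter in deepspeed/runtime/pipe/engine.py around that version, not a verbatim copy, so treat the details as approximate:

def _build_data_iter(self, dataset):
    sampler = torch.utils.data.distributed.DistributedSampler(
        dataset,
        num_replicas=self.dp_world_size,
        rank=self.mpu.get_data_parallel_rank(),
        shuffle=False)  # the hard-coded value this issue is asking about
    # deepspeed_io builds the DataLoader; RepeatingLoader makes it cycle indefinitely.
    pipe_dataloader = self.deepspeed_io(dataset, data_sampler=sampler)
    pipe_dataloader = RepeatingLoader(pipe_dataloader)
    self.set_dataloader(pipe_dataloader)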
To Reproduce
Steps to reproduce the behavior:
Expected behavior
A clear and concise description of what you expected to happen.
ds_report output
Please run ds_report to give us details about your setup.
Screenshots
If applicable, add screenshots to help explain your problem.
System info (please complete the following information):
Launcher context
Are you launching your experiment with the deepspeed launcher, MPI, or something else?
Docker context
Are you using a specific docker image that you can share?
Additional context
Add any other context about the problem here.