[FSDP] use all_gather for 10X OSD consolidation speedup #595

Merged · 23 commits · Apr 13, 2021
Changes from 13 commits
14 changes: 8 additions & 6 deletions fairscale/nn/data_parallel/fsdp_optim_utils.py
@@ -136,7 +136,10 @@ def _unflatten_optim_state(


def build_unflat_state_dict(
instance_list: List[torch.nn.Module], world_optim_states: List[Dict], uncollected_opt_state: Dict[int, Dict]
instance_list: List[torch.nn.Module],
world_optim_states: List[Dict],
tensor_state: Dict[int, Dict[str, List[torch.Tensor]]],
uncollected_opt_state: Dict[int, Dict],
) -> Dict:
"""Build an unflattened optimizer state dict given a list of flattened optimizer state dicts from each rank."""
world_pad_info: List[List[List[int]]] = [s.pop("num_padded") for s in world_optim_states]
@@ -147,16 +150,15 @@ def build_unflat_state_dict(
assert len(param_groups) == 1

# Aggregate from a list of dictionaries to a dictionary of lists
combined_state = _combine_state([x["state"] for x in world_optim_states])
for local_id, v in uncollected_opt_state.items():
assert local_id not in combined_state
combined_state[local_id] = {}
assert local_id not in tensor_state
tensor_state[local_id] = {}
for buffer_name, tensor in v.items():
combined_state[local_id][buffer_name] = [tensor]
tensor_state[local_id][buffer_name] = [tensor]
del world_optim_states

# local ids are in the current state, global_ids will be in returned state.
unflat_state, global_to_local_id = _unflatten_optim_state(combined_state, instance_list, world_pad_info)
unflat_state, global_to_local_id = _unflatten_optim_state(tensor_state, instance_list, world_pad_info)
num_params = sum([len(m._param_numels) for m in instance_list]) # type: ignore
param_groups[0]["params"] = list(range(num_params))
return {
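For orientation, here is a minimal sketch of the data shapes the new `build_unflat_state_dict` signature expects. The values, parameter ids, and world size are illustrative, not taken from the PR.

```python
# Illustrative shapes only (2 ranks, Adam-style buffers); not the PR's test data.
import torch

# tensor_state: local param id -> buffer name -> one gathered shard per rank.
# Non-tensor per-param entries such as "step" arrive as single-element lists.
tensor_state = {
    0: {
        "exp_avg": [torch.zeros(8), torch.zeros(8)],
        "exp_avg_sq": [torch.zeros(8), torch.zeros(8)],
        "step": [10],
    }
}

# world_optim_states: one dict per rank holding the keys that are still
# consolidated with broadcast_object (param_groups, num_padded).
world_optim_states = [
    {"param_groups": [{"lr": 1e-3}], "num_padded": [[0]]} for _ in range(2)
]

# uncollected_opt_state: rank-local state that is never communicated; the loop
# in the diff folds it into tensor_state as single-element lists.
uncollected_opt_state = {1: {"exp_avg": torch.zeros(4), "step": 10}}
for local_id, buffers in uncollected_opt_state.items():
    assert local_id not in tensor_state
    tensor_state[local_id] = {name: [value] for name, value in buffers.items()}
```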
83 changes: 70 additions & 13 deletions fairscale/nn/data_parallel/fully_sharded_data_parallel.py
@@ -8,6 +8,7 @@
from enum import Enum, auto
import functools
from math import inf
import time
import traceback
from typing import TYPE_CHECKING, Any, Callable, Dict, Generator, List, NamedTuple, Optional, Set, Tuple, Union

@@ -208,6 +209,7 @@ def __init__(
self.gradient_postdivide_factor: float = self.world_size / self.gradient_predivide_factor

self.numel_padded_per_param: List[int] = []
self._tstart = time.time()

if self.fp32_reduce_scatter and not self.mixed_precision:
raise ValueError("fp32_reduce_scatter requires mixed_precision=True")
@@ -1382,7 +1384,7 @@ def assert_state(self, state: Union[TrainingState, List[TrainingState]]) -> None

def _consolidate_optim_state_dict(
self, optim: torch.optim.Optimizer, recipient_rank: Optional[int] = None
) -> List[Dict]:
) -> Tuple[List[Dict], Dict[int, Dict[str, List[Any]]]]:
"""Update the consolidated state_dict list, one per rank.

Args:
@@ -1393,7 +1395,8 @@ def _consolidate_optim_state_dict(
None is a special value, which means that all ranks should have the state

Returns:
all_states (list[dict]) the optimizer state from each rank
non_tensor_state (list[dict]) the non-tensor optimizer state from each rank
tensor_state (dict[list]) tensor state (combined from all ranks using all_gather)


.. warning: This needs to be called on all replicas"""
@@ -1402,20 +1405,58 @@ def _consolidate_optim_state_dict(
# Pull the sharded state from all the other replicas
# Store all the states in order, rank by rank
should_collect_state = recipient_rank is None or (self.rank == recipient_rank)
all_states: List[Dict[str, Any]] = []
non_tensor_state: List[Dict[str, Any]] = [] # this will contain values from the whole world.
dummy_tensor = torch.tensor([0], dtype=torch.uint8, device=self.compute_device)
sd = self._remove_uncollectable_params_from_optim_state_dict(optim.state_dict())
# We use all_gather to consolidate OSD['state'] and broadcast to consolidate the other keys (like param_groups)
tensor_state = self._gather_optim_state(sd.pop("state"))
for rank in range(self.world_size):
# TODO(SS): there is no need to communicate values we don't know how to consolidate, like param_groups
if rank == self.rank:
sd = self._remove_uncollectable_params_from_optim_state_dict(optim.state_dict())
sd["num_padded"] = [m.numel_padded_per_param for m in self._fsdp_instances]
send_data = sd
else:
sd = dummy_tensor # type: ignore
sd = broadcast_object(sd, src_rank=rank, group=self.process_group, dist_device=self.compute_device)
send_data = dummy_tensor # type: ignore
send_data = broadcast_object(
send_data, src_rank=rank, group=self.process_group, dist_device=self.compute_device
)
if should_collect_state:
assert isinstance(sd, dict), f"{self.rank} received {type(sd)} from {rank}, expected dict"
all_states.append(recursive_copy_to_device(sd, non_blocking=False, device=torch.device("cpu")))
assert isinstance(send_data, dict), f"{self.rank} received {type(send_data)} from {rank}, expected dict"
non_tensor_state.append(
recursive_copy_to_device(send_data, non_blocking=False, device=torch.device("cpu"))
)

return all_states
self._print_r0(f"gathered non tensor state from rank {rank}")

return non_tensor_state, tensor_state

def _gather_optim_state(self, sd_state: Dict[int, Dict[str, Any]]) -> Dict[int, Dict[str, List]]:
Comment from the PR author: This is the new _all_gather logic.

"""for each buffer in state[i], if the buffer is a tensor, collect it from the world. Else use rank 0's entry."""
self._print_r0(f"start: n state: {len(sd_state)}")
gathered_state: Dict[int, Dict[str, List[Any]]] = {}
for k, v in sd_state.items():
gathered_state[k] = {}
desired_buffer_size = self._fsdp_instances[k].flat_param._full_param_padded.size() # type: ignore
buffer = None

for buffer_name, t in v.items():
if torch.is_tensor(t):
if buffer is None:
buffer = self._fsdp_instances[k].flat_param.new_zeros(*desired_buffer_size, dtype=t.dtype) # type: ignore
chunks = list(buffer.chunk(self.world_size))
dist.all_gather(chunks, t, group=self.process_group)
# unpad each chunk here
# This is required so the chunks store different data for each buffer name
if self.rank == 0:
gathered_state[k][buffer_name] = [x.cpu() for x in chunks]
elif self.rank == 0:
# Add non tensor state
if buffer_name in gathered_state[k]:
gathered_state[k][buffer_name].append(t)
else:
gathered_state[k][buffer_name] = [t]
self._print_r0(f"gathered tensor state for key {k} from all ranks")
return gathered_state
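To make the new collective explicit: for every tensor buffer, each rank allocates one full-size padded buffer, views it as world_size chunks, and all_gathers its own shard into place, so a single all_gather replaces the per-rank broadcast of that buffer. Below is a self-contained sketch of that pattern; the helper name and sizes are illustrative, and it assumes a process group has already been initialized.

```python
# Minimal sketch of the all_gather pattern above (hypothetical helper; assumes
# dist.init_process_group has been called, e.g. via torchrun).
import torch
import torch.distributed as dist

def gather_sharded_buffer(local_shard: torch.Tensor, group=None):
    """Gather one equally sized optimizer-state shard from every rank."""
    world_size = dist.get_world_size(group)
    # One full-size buffer, viewed as world_size equal chunks.
    full = local_shard.new_zeros(local_shard.numel() * world_size)
    chunks = list(full.chunk(world_size))
    # all_gather writes rank r's shard into chunks[r] on every rank.
    dist.all_gather(chunks, local_shard, group=group)
    # Only rank 0 keeps the result; copy off the GPU so the buffer can be reused.
    return [c.cpu() for c in chunks] if dist.get_rank(group) == 0 else None
```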

def gather_full_optim_state_dict(
self, optim: torch.optim.Optimizer, recipient_rank: Optional[int] = 0
Expand All @@ -1437,16 +1478,20 @@ def gather_full_optim_state_dict(
* param_groups - a dict containing the 1 parameter group

"""
self._tstart = time.time()
if not self.flatten_parameters:
raise NotImplementedError("optim state dict requires flatten_parameters=True")
world_optim_states = self._consolidate_optim_state_dict(optim, recipient_rank)
non_tensor_state, tensor_state = self._consolidate_optim_state_dict(optim, recipient_rank)

if self.rank != recipient_rank and recipient_rank is not None:
return None
# Unify the shard states by concatenating tensors and unflattening params
new_state_dict = ou.build_unflat_state_dict(
self._fsdp_instances, world_optim_states, self.uncollected_opt_state
self._fsdp_instances, non_tensor_state, tensor_state, self.uncollected_opt_state
)

self.uncollected_opt_state = {}
self._print_r0("FSDP: done unflat")
assert "uncollected_local_ids" in new_state_dict
return new_state_dict

Expand All @@ -1459,8 +1504,12 @@ def _remove_uncollectable_params_from_optim_state_dict(self, osd: Dict) -> Dict:
uncollected_ids = [i for i, m in enumerate(self._fsdp_instances) if m.no_broadcast_optim_state]
new_dct = {"state": {k: v for k, v in osd["state"].items() if k not in uncollected_ids}}
if self.rank == 0:
# Save placeholders for uncollected opt state to keep the same unflat OSD format.
self.uncollected_opt_state = {k: v for k, v in osd["state"].items() if k in uncollected_ids}
# Save placeholders for uncollected opt state to keep the same unflat OSD format, and move them to CPU.
self.uncollected_opt_state = {
k: recursive_copy_to_device(v, non_blocking=False, device=torch.device("cpu"))
for k, v in osd["state"].items()
if k in uncollected_ids
}

pg = copy.deepcopy(osd["param_groups"])
new_dct["param_groups"] = pg
@@ -1500,6 +1549,14 @@ def get_shard_from_optim_state_dict(self, full_optim_state_dict: Dict[str, Any])

return full_optim_state_dict

def _print_r0(self, msg: str) -> None:
"""Debugging utility to print memory usage stats nicely on rank 0"""
if self.rank == 0:
gb_denom = 1024 ** 3
print(
f"{msg} cur={torch.cuda.memory_allocated()/gb_denom: .4f} GB, max={torch.cuda.max_memory_allocated()/gb_denom: .4f} GB, t={time.time()-self._tstart: .1f}"
)


def _get_default_cuda_device(module: nn.Module) -> torch.device:
"""Try to infer CUDA device from module parameters."""
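From the caller's perspective the public API is unchanged. A hedged usage sketch, assuming a model already wrapped in FullyShardedDataParallel with flatten_parameters=True and a matching optimizer:

```python
# Usage sketch, not from the PR: consolidate sharded optimizer state on rank 0.
import torch

full_osd = model.gather_full_optim_state_dict(optim, recipient_rank=0)
if full_osd is not None:  # only rank 0 receives the dict; other ranks get None
    torch.save(full_osd, "optim_full.pt")  # hypothetical checkpoint path

# Later, each rank cuts its own shard back out of the consolidated state dict.
shard_osd = model.get_shard_from_optim_state_dict(torch.load("optim_full.pt"))
optim.load_state_dict(shard_osd)
```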
8 changes: 5 additions & 3 deletions tests/nn/data_parallel/test_fsdp.py
@@ -627,14 +627,16 @@ def __init__(self, group, wrapper_config, checkpoint_act=False, delay_before_fre

# "expert" params are different on each rank
torch.manual_seed(42 + group.rank())
expert = nn.Linear(16, 4)
d_expert = 10000
d_shared = 512
expert = nn.Linear(d_expert, d_shared)
self.num_expert_params = sum([p.numel() for p in expert.parameters()])
for p in expert.parameters():
p.expert = True

# everything else is shared
torch.manual_seed(0)
shared = nn.Linear(4, 16)
shared = nn.Linear(d_shared, d_expert)

if checkpoint_act:
expert = checkpoint_wrapper(expert)
Expand All @@ -647,7 +649,7 @@ def __init__(self, group, wrapper_config, checkpoint_act=False, delay_before_fre

shared = FullyShardedDataParallel(shared, group, **wrapper_config)

self.module = nn.Sequential(nn.Linear(8, 4), shared, expert, nn.Linear(4, 8))
self.module = nn.Sequential(nn.Linear(8, d_shared), shared, expert, nn.Linear(d_shared, 8))

def forward(self, x):
if self.delay_before_free_ms > 0:
16 changes: 15 additions & 1 deletion tests/nn/data_parallel/test_fsdp_optimizer_utils.py
@@ -86,16 +86,30 @@ def _test_consolidated_optimizer(self, config, rank, group, optim_fn=torch.optim
no_broadcast_children = [x for x in fsdp._fsdp_instances if x.no_broadcast_optim_state]
assert len(no_broadcast_children) == 1
assert fsdp._fsdp_instances[-1].no_broadcast_optim_state

torch.cuda.empty_cache()
cuda_gb_before = torch.cuda.memory_stats(fsdp.rank)["allocated_bytes.all.current"] / 1024 ** 3
tstart = time()
sd = fsdp.gather_full_optim_state_dict(fsdp_optim, recipient_rank=0)
duration = time() - tstart
# Switching from fairscale.optim.utils.broadcast_object to torch.broadcast_object_list will cause this to raise
assert duration < fsdp.world_size, f"gather optim state took {duration} seconds, suspect change in _consolidate"

cuda_gb_after = torch.cuda.memory_stats(fsdp.rank)["allocated_bytes.all.current"] / 1024 ** 3
mem_usg_gb = cuda_gb_after - cuda_gb_before
assert mem_usg_gb == 0, f"gather_full_optim_state_dict used {mem_usg_gb:.2f} CUDA GB, max allowed is 0"
assert cuda_gb_after > 0, "got 0 memory usage, logging is broken"

if fsdp.rank > 0:
assert sd is None
return

# assert whole state dict on CPU
for k, v in sd["state"].items():
for buffer_name, t in v.items():
if torch.is_tensor(t):
msg = f"got device {t.device} for {k}: {buffer_name}. expected CPU"
assert t.device == torch.device("cpu"), msg

unflat_state = sd["state"]
assert "uncollected_local_ids" in sd
shard_sd = fsdp.get_shard_from_optim_state_dict(sd)