[FSDP] use all_gather for 10X OSD consolidation speedup #595

Merged: 23 commits, Apr 13, 2021

Changes from all commits

80 changes: 43 additions & 37 deletions fairscale/nn/data_parallel/fsdp_optim_utils.py
@@ -4,10 +4,12 @@
# LICENSE file in the root directory of this source tree.
"""These functions are used by FullyShardedDataParallel to help consolidate and shard optimizer states."""
import copy
from typing import Dict, Generator, List, Tuple
from typing import Any, Dict, Generator, List, Tuple

import torch

# These return keys are used by fairseq. To change, add @sshleifer as a reviewer.
UNFLAT_RETURN_KEYS = {"state", "param_groups", "uncollected_local_ids", "param_id_map"}

# This function helps shard a full optimizer state dict
def flatten_optim_state_dict(sd: Dict) -> Dict:
@@ -16,6 +18,7 @@ def flatten_optim_state_dict(sd: Dict) -> Dict:
num_local_params = len(set(param_id_map.values()))
if sd["state"]:
new_state: Dict = {local_id: {} for local_id in range(num_local_params)}
singleton_state: Dict = copy.deepcopy(new_state)
else:
new_state = {}
non_tensor_state = {}
@@ -24,19 +27,26 @@ def flatten_optim_state_dict(sd: Dict) -> Dict:
for global_id, buffers in sd["state"].items():
local_id = param_id_map[global_id]
for buffer_name, p in buffers.items():
if torch.is_tensor(p):
if is_singleton_tensor(p):
singleton_state[local_id][buffer_name] = p
elif torch.is_tensor(p):
if buffer_name not in new_state[local_id]:
new_state[local_id][buffer_name] = []
new_state[local_id][buffer_name].append(p.reshape(-1))
elif isinstance(p, list):
singleton_state[local_id][buffer_name] = p
else:
non_tensor_state[buffer_name] = p

# Now combine all tensors in each buffer using torch.cat().
for local_id, state in new_state.items():
for buffer_name, tensors in state.items():
new_state[local_id][buffer_name] = torch.cat(tensors)
new_state[local_id].update(non_tensor_state)
new_state[local_id].update(singleton_state[local_id])
new_sd = {"state": new_state, "param_groups": copy.deepcopy(sd["param_groups"])}
for k in sd.keys(): # if there are extra keys, like loss_scale, don't delete them
if k not in UNFLAT_RETURN_KEYS:
new_sd[k] = copy.deepcopy(sd[k])

# add pointers from the `params` dict.
for pg_id, _ in enumerate(sd["param_groups"]):
@@ -70,22 +80,11 @@ def _extract_non_tensor_state(combined_state: Dict[int, Dict[str, List]], param_
return non_tensor_state


def _combine_state(states: List[Dict]) -> Dict[int, Dict]:
combined_state = states[0]
for param_id in combined_state:
combined_state[param_id] = {k: [v] for k, v in combined_state[param_id].items()}
if len(states) == 1:
return combined_state

for rank, s in enumerate(states[1:]):
for param_id, param_state in s.items():
for k, tensor in param_state.items():
combined_state[param_id][k].append(tensor)
return combined_state


def _unflatten_optim_state(
combined_state: Dict[int, Dict], instance_list: List[torch.nn.Module], world_pad_info: List[List[List[int]]],
combined_state: Dict[int, Dict],
instance_list: List[torch.nn.Module],
world_pad_info: List[List[List[int]]],
singleton_state: Dict[int, Dict],
) -> Tuple[Dict[int, Dict], Dict[int, int]]:
# local ids are the keys in the current state (combined_state), (usually fewer)
# global ids will be the keys in the unflattened state
@@ -98,17 +97,17 @@ def _unflatten_optim_state(
non_tensor_state = [_extract_non_tensor_state(combined_state, id) for id in combined_state]

# local corresponds to flattened, global corresponds to unflattened
num_unflat_params = [len(m._param_numels) for m in instance_list] # type: ignore
num_global_params = [len(m._param_numels) for m in instance_list] # type: ignore
global_to_local_id = {}
for local_id, num_unflat in enumerate(num_unflat_params):
for local_id, num_unflat in enumerate(num_global_params):
for _ in range(num_unflat):
global_to_local_id[next_global_id] = local_id
next_global_id += 1
if not combined_state:
return {}, global_to_local_id

# If the constant state is the same as the combined state, copy it N times, no unflattening needed.
unflat_state = {i: copy.deepcopy(non_tensor_state[0]) for i in range(sum(num_unflat_params))}
# copy non tensor state to all global entries
unflat_state = {i: copy.deepcopy(non_tensor_state[0]) for i in range(sum(num_global_params))}

if non_tensor_state[0].keys() == combined_state[0].keys():
return unflat_state, global_to_local_id
@@ -131,37 +130,44 @@ def _unflatten_optim_state(
for global_id, param_view in zip(sorted(local_to_global[local_id]), param_views):
assert k not in unflat_state[global_id], f"already added {k} to {global_id} {local_id}"
unflat_state[global_id][k] = param_view
unflat_state[global_id].update(singleton_state[local_id])

return unflat_state, global_to_local_id


def build_unflat_state_dict(
instance_list: List[torch.nn.Module], world_optim_states: List[Dict], uncollected_opt_state: Dict[int, Dict]
instance_list: List[torch.nn.Module],
world_pad_info: List[List[List[int]]],
state: Dict[int, Dict[str, List[torch.Tensor]]],
singleton_state: Dict[int, Dict[str, List[torch.Tensor]]],
uncollected_opt_state: Dict[int, Dict],
param_groups: List[Dict],
) -> Dict:
"""Build an unflattened optimizer state dict given a list of flattened optimizer state dicts from each rank."""
world_pad_info: List[List[List[int]]] = [s.pop("num_padded") for s in world_optim_states]
assert all(len(s) == len(instance_list) for s in world_pad_info)
assert all(len(s[0]) == 1 for s in world_pad_info)
# Since there are no tensors in param_groups, deepcopy is fine
param_groups = copy.deepcopy(world_optim_states[0]["param_groups"])
assert len(param_groups) == 1

# Aggregate from a list of dictionaries to a dictionary of lists
combined_state = _combine_state([x["state"] for x in world_optim_states])
# Use uncollected_opt_state to update tensor_state, singleton_state
for local_id, v in uncollected_opt_state.items():
assert local_id not in combined_state
combined_state[local_id] = {}
for buffer_name, tensor in v.items():
combined_state[local_id][buffer_name] = [tensor]
del world_optim_states

assert local_id not in state
state[local_id] = {buffer_name: [x] for buffer_name, x in v.items() if not is_singleton_tensor(x)}
singleton_state[local_id] = {buffer_name: [x] for buffer_name, x in v.items() if is_singleton_tensor(x)}
# local ids are in the current state, global_ids will be in returned state.
unflat_state, global_to_local_id = _unflatten_optim_state(combined_state, instance_list, world_pad_info)
unflat_state, global_to_local_id = _unflatten_optim_state(state, instance_list, world_pad_info, singleton_state)
# Since there are no tensors in param_groups, deepcopy is fine
param_groups = copy.deepcopy(param_groups)
num_params = sum([len(m._param_numels) for m in instance_list]) # type: ignore
param_groups[0]["params"] = list(range(num_params))
return {
unflat_optim_state_dict = {
"state": dict(sorted(unflat_state.items())), # NOTE: this is probably already sorted
"param_id_map": global_to_local_id,
"param_groups": param_groups,
"uncollected_local_ids": list(uncollected_opt_state.keys()),
}
assert set(unflat_optim_state_dict.keys()) == UNFLAT_RETURN_KEYS
return unflat_optim_state_dict


def is_singleton_tensor(x: Any) -> bool:
"""Is x a dimensionless tensor?"""
return torch.is_tensor(x) and x.dim() == 0
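
The hunks above route dimensionless ("singleton") tensors around the torch.cat() path, since torch.cat() cannot concatenate 0-dim tensors. A minimal standalone sketch of that idea, with a made-up buffers dict for illustration (not code from the PR):

import torch

def is_singleton_tensor(x) -> bool:
    """Is x a dimensionless tensor?"""
    return torch.is_tensor(x) and x.dim() == 0

# Hypothetical per-parameter optimizer buffers: a sharded moment plus a 0-dim step counter.
buffers = {"exp_avg": torch.zeros(10), "step": torch.tensor(3.0)}

flat_parts, singleton_state = [], {}
for name, value in buffers.items():
    if is_singleton_tensor(value):
        singleton_state[name] = value         # kept whole, merged back after flattening
    elif torch.is_tensor(value):
        flat_parts.append(value.reshape(-1))  # later concatenated with torch.cat()

assert torch.cat(flat_parts).shape == (10,)
assert singleton_state["step"].item() == 3.0
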
112 changes: 68 additions & 44 deletions fairscale/nn/data_parallel/fully_sharded_data_parallel.py
@@ -1382,70 +1382,88 @@ def assert_state(self, state: Union[TrainingState, List[TrainingState]]) -> None
traceback.print_stack()
raise ValueError(msg)

def _consolidate_optim_state_dict(
self, optim: torch.optim.Optimizer, recipient_rank: Optional[int] = None
) -> List[Dict]:
"""Update the consolidated state_dict list, one per rank.

Args:

optim (Optimizer): an optimizer instance for this FSDP rank. Its state is
used in the consolidation. However, its state is not modified.
recipient_rank (int): on which rank to materialize the full state dict.
None is a special value, which means that all ranks should have the state

Returns:
all_states (list[dict]) the optimizer state from each rank


.. warning: This needs to be called on all replicas"""
self._lazy_init()
# NOTE(SS): we do not support param groups yet, as they seem to break FSDP
# Pull the sharded state from all the other replicas
# Store all the states in order, rank by rank
should_collect_state = recipient_rank is None or (self.rank == recipient_rank)
all_states: List[Dict[str, Any]] = []
def _broadcast_pad_info_to_r0(self) -> List[List[List[int]]]:
"""Collect [x.numel_padded_per_param for x in self._fsdp_instances] from teach rank."""
dummy_tensor = torch.tensor([0], dtype=torch.uint8, device=self.compute_device)
world_pad_info: List[List[List[int]]] = [] # this will contain values from the whole world.
for rank in range(self.world_size):
if rank == self.rank:
sd = self._remove_uncollectable_params_from_optim_state_dict(optim.state_dict())
sd["num_padded"] = [m.numel_padded_per_param for m in self._fsdp_instances]
pad_info = [m.numel_padded_per_param for m in self._fsdp_instances]
else:
sd = dummy_tensor # type: ignore
sd = broadcast_object(sd, src_rank=rank, group=self.process_group, dist_device=self.compute_device)
if should_collect_state:
assert isinstance(sd, dict), f"{self.rank} received {type(sd)} from {rank}, expected dict"
all_states.append(recursive_copy_to_device(sd, non_blocking=False, device=torch.device("cpu")))
return all_states

def gather_full_optim_state_dict(
self, optim: torch.optim.Optimizer, recipient_rank: Optional[int] = 0
) -> Optional[Dict[str, Any]]:
pad_info = dummy_tensor # type: ignore
pad_info = broadcast_object(
pad_info, src_rank=rank, group=self.process_group, dist_device=self.compute_device
)
if self.rank == 0:
world_pad_info.append(pad_info) # type: ignore
return world_pad_info

def _gather_optim_state(
self, sd_state: Dict[int, Dict[str, Any]]
) -> Tuple[Dict[int, Dict[str, List]], Dict[int, Dict[str, List]]]:
"""For each value in state[i], if the value is a tensor, collect it from the world. Else use rank 0's entry."""
gathered_state: Dict[int, Dict[str, List[Any]]] = {}
singleton_state: Dict[int, Dict[str, List[Any]]] = {} # Dimensionless tensor
for k, v in sd_state.items():
gathered_state[k] = {}
singleton_state[k] = {}
desired_buffer_size = self._fsdp_instances[k].flat_param._full_param_padded.size() # type: ignore
buffer = None # for sharded tensors
singleton_buffer = None # for singleton tensors
for buffer_name, t in v.items():
if ou.is_singleton_tensor(t):
if singleton_buffer is None:
singleton_buffer = list(t.new_zeros(self.world_size).chunk(self.world_size))
dist.all_gather(singleton_buffer, t, group=self.process_group)
if self.rank == 0:
singleton_state[k][buffer_name] = [x.cpu().squeeze() for x in singleton_buffer]
assert ou.is_singleton_tensor(singleton_state[k][buffer_name][0])
elif torch.is_tensor(t):
if buffer is None:
buffer = list(t.new_zeros(*desired_buffer_size).chunk(self.world_size))
dist.all_gather(buffer, t, group=self.process_group)
if self.rank == 0:
gathered_state[k][buffer_name] = [x.cpu() for x in buffer]
elif self.rank == 0: # Add non tensor state
gathered_state[k][buffer_name] = [t]

return gathered_state, singleton_state
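
The loop above is the core of the speedup: sharded buffers move through dist.all_gather into pre-sized chunks instead of whole optimizer state dicts being pickled and broadcast rank by rank. A minimal sketch of the pattern, assuming an already-initialized process group and equal-sized 1-D shards on every rank (gather_flat_shard is a hypothetical helper, not part of the PR):

import torch
import torch.distributed as dist

def gather_flat_shard(shard: torch.Tensor) -> list:
    """All-gather one flat optimizer-state shard from every rank."""
    world_size = dist.get_world_size()
    # Pre-allocate the full padded buffer and hand all_gather one chunk per rank.
    buffer = list(shard.new_zeros(world_size * shard.numel()).chunk(world_size))
    dist.all_gather(buffer, shard)
    # Only rank 0 needs to keep the gathered copies; other ranks can discard them.
    return [chunk.cpu() for chunk in buffer] if dist.get_rank() == 0 else []

Each tensor crosses the network once through the collective, with no CPU-side pickling, which is where the consolidation speedup in the PR title comes from.
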

def gather_full_optim_state_dict(self, optim: torch.optim.Optimizer, **ignored: Dict) -> Optional[Dict[str, Any]]:
"""Return the last known global optimizer state. The returned state is compatible with Pytorch, in that the
sharded properties are not exposed. Multiple parameter groups are not yet supported.

This should be called only on the root FSDP instance.
Nested FSDP instances are supported as long as they have the same world_size as the parent or world_size=1.
Different world_size groups in nested FSDP instances is not supported.
Args:
optim (Optimizer): an optimizer instance for this FSDP rank. Its state is
used in the consolidation. However, its state is not modified.
recipient_rank (int): on which rank to materialize the full state dict.
optim (Optimizer): an optimizer instance for this FSDP rank. Its state_dict is
used in the consolidation. However, its state is not modified.

Returns:
a dict with two entries

* A dict with four entries (On rank zero, other workers return ``None``)
* state - a dict holding gathered optimization state, 1 entry per unflat parameter
* param_groups - a dict containing the 1 parameter group
* param_id_map - global (unflat) to local (flat) id mapping
* uncollected_local_ids - keys in the state dict that were not broadcast

"""
if not self.flatten_parameters:
raise NotImplementedError("optim state dict requires flatten_parameters=True")
world_optim_states = self._consolidate_optim_state_dict(optim, recipient_rank)
if self.rank != recipient_rank and recipient_rank is not None:

self._lazy_init()
sd = self._remove_uncollectable_params_from_optim_state_dict(optim.state_dict())
assert set(sd.keys()) == {"param_groups", "state"}, f'{set(sd.keys())} != {"param_groups", "state"}'
assert len(sd["param_groups"]) == 1, "Param groups are not supported"
# We use all_gather to consolidate OSD['state'] and broadcast to consolidate the other keys (like param_groups)
state, singleton_state = self._gather_optim_state(sd.pop("state"))
pad_info = self._broadcast_pad_info_to_r0()
if self.rank != 0:
return None
# Unify the shard states by concatenating tensors and unflattening params
new_state_dict = ou.build_unflat_state_dict(
self._fsdp_instances, world_optim_states, self.uncollected_opt_state
self._fsdp_instances, pad_info, state, singleton_state, self.uncollected_opt_state, sd["param_groups"]
)
self.uncollected_opt_state = {}
assert "uncollected_local_ids" in new_state_dict
@@ -1499,14 +1517,20 @@ def get_shard_from_optim_state_dict(self, full_optim_state_dict: Dict[str, Any])
for k, v in s.items():
if torch.is_tensor(v) and id not in ids_not_to_shard:
v_shard, _ = self._get_shard(v)
elif isinstance(v, list) and ou.is_singleton_tensor(v[0]):
# if we are resuming on larger world size, take first entry
v_shard = v[0] if self.rank >= len(v) else v[self.rank]
assert ou.is_singleton_tensor(v_shard)
else:
v_shard = v # dont shard entries that are not tensors
full_optim_state_dict["state"][id][k] = v_shard

return full_optim_state_dict
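
The singleton branch above re-shards gathered 0-dim entries when a checkpoint is loaded on a different world size. A toy illustration of that rule, with world sizes assumed purely for the example:

import torch

# Singleton state gathered from an old 2-rank run: one 0-dim tensor per old rank.
gathered_steps = [torch.tensor(10.0), torch.tensor(10.0)]

# Resuming on 4 ranks: ranks beyond the old world size fall back to entry 0.
for new_rank in range(4):
    v_shard = gathered_steps[0] if new_rank >= len(gathered_steps) else gathered_steps[new_rank]
    assert v_shard.dim() == 0
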

def _print_r0(self, msg: str) -> None:
def _print_r0(self, msg: str, restart: bool = False) -> None:
"""Debugging utility to print memory usage stats nicely on rank 0"""
if restart:
self._tstart = time.time()
if self.rank == 0:
gb_denom = 1024 ** 3
print(
12 changes: 8 additions & 4 deletions tests/nn/data_parallel/test_fsdp.py
@@ -627,15 +627,19 @@ def __init__(self, group, wrapper_config, checkpoint_act=False, delay_before_fre

# "expert" params are different on each rank
torch.manual_seed(42 + group.rank())
d_expert = 16
expert = nn.Linear(d_expert, 4)
d_expert = 23
Contributor Author: make sure we unpad expert params correctly.
d_shared = 12
d_input = 8
expert = nn.Linear(d_expert, d_shared)

self.num_expert_params = sum([p.numel() for p in expert.parameters()])
for p in expert.parameters():
p.expert = True

# everything else is shared
torch.manual_seed(0)
shared = nn.Linear(4, d_expert)

shared = nn.Linear(d_shared, d_expert)

if checkpoint_act:
expert = checkpoint_wrapper(expert)
@@ -648,7 +652,7 @@ def __init__(self, group, wrapper_config, checkpoint_act=False, delay_before_fre

shared = FullyShardedDataParallel(shared, group, **wrapper_config)

self.module = nn.Sequential(nn.Linear(8, 4), shared, expert, nn.Linear(4, 8))
self.module = nn.Sequential(nn.Linear(d_input, d_shared), shared, expert, nn.Linear(d_shared, d_input))

def forward(self, x):
if self.delay_before_free_ms > 0:
8 changes: 8 additions & 0 deletions tests/nn/data_parallel/test_fsdp_optimizer_utils.py
@@ -10,6 +10,7 @@
from torch.optim import SGD, Adadelta, Adam # type: ignore

from fairscale.nn import FullyShardedDataParallel
from fairscale.nn.data_parallel.fsdp_optim_utils import is_singleton_tensor
from fairscale.optim.utils import recursive_copy_to_device
from fairscale.utils.testing import objects_are_equal

@@ -147,3 +148,10 @@ def test_named_params_ordering(self):
named_pars = [p for n, p in model.named_parameters()]
for i, p in enumerate(model.parameters()):
assert objects_are_equal(p, named_pars[i])

def test_is_singleton_tensor(self):
assert is_singleton_tensor(torch.tensor(4.0))
assert not is_singleton_tensor(torch.tensor([4.0]))
assert not is_singleton_tensor(torch.tensor([4.0, 5.0]))
assert not is_singleton_tensor([4.0])
assert not is_singleton_tensor(4.0)
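
For context, a minimal end-to-end sketch of the API these tests exercise, assuming an initialized process group, GPUs, a shared filesystem, and the default flatten_parameters=True; shapes, learning rate, and the checkpoint path are illustrative, not taken from the test suite:

import torch
from fairscale.nn import FullyShardedDataParallel as FSDP

model = FSDP(torch.nn.Linear(8, 8).cuda())
optim = torch.optim.Adam(model.parameters(), lr=1e-3)
model(torch.rand(4, 8).cuda()).sum().backward()
optim.step()

# Must run on every rank (it performs collectives); returns a dict on rank 0, None elsewhere.
full_osd = model.gather_full_optim_state_dict(optim)
if full_osd is not None:
    torch.save(full_osd, "full_osd.pt")

# On resume (synchronize ranks first in real code), possibly on a different world size:
full_osd = torch.load("full_osd.pt", map_location="cpu")
shard_osd = model.get_shard_from_optim_state_dict(full_osd)
optim.load_state_dict(shard_osd)
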