[FSDP] use all_gather for 10X OSD consolidation speedup #595

Merged Apr 13, 2021 · 23 commits · showing changes from 15 commits
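In broad strokes, the previous consolidation path broadcast each rank's entire optimizer state dict as a pickled object; this PR instead all-gathers the sharded tensor buffers directly and uses broadcast only for the small non-tensor pieces (such as the padding info). A minimal sketch of the gather pattern, assuming an initialized process group; the helper name `gather_sharded_buffer` is illustrative and not part of fairscale:

```python
import torch
import torch.distributed as dist

def gather_sharded_buffer(local_shard: torch.Tensor, world_size: int, group=None):
    """All-gather one sharded optimizer buffer (e.g. Adam's exp_avg) from every rank."""
    chunks = [torch.zeros_like(local_shard) for _ in range(world_size)]
    dist.all_gather(chunks, local_shard, group=group)
    # Only rank 0 needs to keep the full set; move it off-GPU for consolidation.
    if dist.get_rank(group=group) == 0:
        return [c.cpu() for c in chunks]
    return None
```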
43 changes: 14 additions & 29 deletions fairscale/nn/data_parallel/fsdp_optim_utils.py
@@ -37,6 +37,7 @@ def flatten_optim_state_dict(sd: Dict) -> Dict:
new_state[local_id][buffer_name] = torch.cat(tensors)
new_state[local_id].update(non_tensor_state)
new_sd = {"state": new_state, "param_groups": copy.deepcopy(sd["param_groups"])}
# TODO(SS): if there are other keys, like loss_scale, don't delete them

# add pointers from the `params` dict.
for pg_id, _ in enumerate(sd["param_groups"]):
@@ -70,28 +71,14 @@ def _extract_non_tensor_state(combined_state: Dict[int, Dict[str, List]], param_
return non_tensor_state


def _combine_state(states: List[Dict]) -> Dict[int, Dict]:
combined_state = states[0]
for param_id in combined_state:
combined_state[param_id] = {k: [v] for k, v in combined_state[param_id].items()}
if len(states) == 1:
return combined_state

for rank, s in enumerate(states[1:]):
for param_id, param_state in s.items():
for k, tensor in param_state.items():
combined_state[param_id][k].append(tensor)
return combined_state


def _unflatten_optim_state(
combined_state: Dict[int, Dict], instance_list: List[torch.nn.Module], world_pad_info: List[List[List[int]]],
) -> Tuple[Dict[int, Dict], Dict[int, int]]:
# local ids are the keys in the current state (combined_state), (usually fewer)
# global ids will be the keys in the unflattened state
next_global_id = 0 # gets incremented
pad_info = {id: [s[id][0] for s in world_pad_info] for id in combined_state}
local_ids = [id for id in sorted(combined_state.keys())]
local_ids = list(sorted(combined_state.keys()))

# non_tensor_state refers to entries in sd[state][param_id] that are not tensors, like "step".
# we check that these are identical across workers and then take the first
@@ -108,6 +95,7 @@ def _unflatten_optim_state(
return {}, global_to_local_id

# If the constant state is the same as the combined state, copy it N times, no unflattening needed.
# deepcopy is OK because there are no tensors.
unflat_state = {i: copy.deepcopy(non_tensor_state[0]) for i in range(sum(num_unflat_params))}

if non_tensor_state[0].keys() == combined_state[0].keys():
@@ -136,28 +124,25 @@


def build_unflat_state_dict(
instance_list: List[torch.nn.Module], world_optim_states: List[Dict], uncollected_opt_state: Dict[int, Dict]
instance_list: List[torch.nn.Module],
world_pad_info: List[List[List[int]]],
state: Dict[int, Dict[str, List[torch.Tensor]]],
uncollected_opt_state: Dict[int, Dict],
param_groups: List[Dict],
) -> Dict:
"""Build an unflattened optimizer state dict given a list of flattened optimizer state dicts from each rank."""
world_pad_info: List[List[List[int]]] = [s.pop("num_padded") for s in world_optim_states]
assert all(len(s) == len(instance_list) for s in world_pad_info)
assert all(len(s[0]) == 1 for s in world_pad_info)
# Since there are no tensors in param_groups, deepcopy is fine
param_groups = copy.deepcopy(world_optim_states[0]["param_groups"])
assert len(param_groups) == 1

# Aggregate from a list of dictionaries to a dictionary of lists
combined_state = _combine_state([x["state"] for x in world_optim_states])
# Add uncollected state to tensor_state
for local_id, v in uncollected_opt_state.items():
assert local_id not in combined_state
combined_state[local_id] = {}
for buffer_name, tensor in v.items():
combined_state[local_id][buffer_name] = [tensor]
del world_optim_states

assert local_id not in state
state[local_id] = {buffer_name: [x] for buffer_name, x in v.items()}
# local ids are in the current state, global_ids will be in returned state.
unflat_state, global_to_local_id = _unflatten_optim_state(combined_state, instance_list, world_pad_info)
unflat_state, global_to_local_id = _unflatten_optim_state(state, instance_list, world_pad_info)
num_params = sum([len(m._param_numels) for m in instance_list]) # type: ignore
# Since there are no tensors in param_groups, deepcopy is fine
param_groups = copy.deepcopy(param_groups)
param_groups[0]["params"] = list(range(num_params))
return {
"state": dict(sorted(unflat_state.items())), # NOTE: this is probably already sorted
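For context on the new `build_unflat_state_dict` inputs above: each entry of `state` now holds a list of per-rank shards for every tensor buffer, and `world_pad_info` records how many padding elements each rank appended. A hedged sketch of the stitch-back step, assuming padding sits at the end of each rank's shard; `unpad_and_cat` is an illustrative helper, not the fairscale function:

```python
import torch
from typing import List

def unpad_and_cat(shards: List[torch.Tensor], num_padded: List[int]) -> torch.Tensor:
    """Drop each rank's trailing padding, then concatenate the shards into one flat buffer."""
    pieces = []
    for shard, pad in zip(shards, num_padded):
        pieces.append(shard[:-pad] if pad > 0 else shard)
    return torch.cat(pieces)
```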
96 changes: 56 additions & 40 deletions fairscale/nn/data_parallel/fully_sharded_data_parallel.py
@@ -1382,45 +1382,48 @@ def assert_state(self, state: Union[TrainingState, List[TrainingState]]) -> None
traceback.print_stack()
raise ValueError(msg)

def _consolidate_optim_state_dict(
self, optim: torch.optim.Optimizer, recipient_rank: Optional[int] = None
) -> List[Dict]:
"""Update the consolidated state_dict list, one per rank.

Args:

optim (Optimizer): an optimizer instance for this FSDP rank. Its state is
used in the consolidation. However, its state is not modified.
recipient_rank (int): on which rank to materialize the full state dict.
None is a special value, which means that all ranks should have the state

Returns:
all_states (list[dict]) the optimizer state from each rank


.. warning: This needs to be called on all replicas"""
self._lazy_init()
# NOTE(SS): we do not support param groups yet, as they seem to break FSDP
# Pull the sharded state from all the other replicas
# Store all the states in order, rank by rank
should_collect_state = recipient_rank is None or (self.rank == recipient_rank)
all_states: List[Dict[str, Any]] = []
def _broadcast_pad_info_to_r0(self) -> List[List[List[int]]]:
dummy_tensor = torch.tensor([0], dtype=torch.uint8, device=self.compute_device)
world_pad_info: List[List[List[int]]] = [] # this will contain values from the whole world.
for rank in range(self.world_size):
if rank == self.rank:
sd = self._remove_uncollectable_params_from_optim_state_dict(optim.state_dict())
sd["num_padded"] = [m.numel_padded_per_param for m in self._fsdp_instances]
pad_info = [m.numel_padded_per_param for m in self._fsdp_instances]
else:
sd = dummy_tensor # type: ignore
sd = broadcast_object(sd, src_rank=rank, group=self.process_group, dist_device=self.compute_device)
if should_collect_state:
assert isinstance(sd, dict), f"{self.rank} received {type(sd)} from {rank}, expected dict"
all_states.append(recursive_copy_to_device(sd, non_blocking=False, device=torch.device("cpu")))
return all_states

def gather_full_optim_state_dict(
self, optim: torch.optim.Optimizer, recipient_rank: Optional[int] = 0
) -> Optional[Dict[str, Any]]:
pad_info = dummy_tensor # type: ignore
pad_info = broadcast_object(
pad_info, src_rank=rank, group=self.process_group, dist_device=self.compute_device
)
if self.rank == 0:
world_pad_info.append(pad_info) # type: ignore
return world_pad_info

def _gather_optim_state(self, sd_state: Dict[int, Dict[str, Any]]) -> Dict[int, Dict[str, List]]:
Comment from the PR author: This is the new _all_gather logic.
"""for each buffer in state[i], if the buffer is a tensor, collect it from the world. Else use rank 0's entry."""
self._print_r0(f"start: n state: {len(sd_state)}")
gathered_state: Dict[int, Dict[str, List[Any]]] = {}
for k, v in sd_state.items():
gathered_state[k] = {}
desired_buffer_size = self._fsdp_instances[k].flat_param._full_param_padded.size() # type: ignore
buffer = None
for buffer_name, t in v.items():
if torch.is_tensor(t):
if buffer is None:
buffer = t.new_zeros(*desired_buffer_size, dtype=t.dtype)
chunks = list(buffer.chunk(self.world_size))
dist.all_gather(chunks, t, group=self.process_group)

if self.rank == 0:
gathered_state[k][buffer_name] = [x.cpu() for x in chunks]
elif self.rank == 0:
# Add non tensor state
if buffer_name in gathered_state[k]:
gathered_state[k][buffer_name].append(t)
else:
gathered_state[k][buffer_name] = [t]
self._print_r0(f"gathered tensor state for key {k} from all ranks")
return gathered_state

def gather_full_optim_state_dict(self, optim: torch.optim.Optimizer, **ignored: Dict) -> Optional[Dict[str, Any]]:
"""Return the last known global optimizer state. The returned state is compatible with Pytorch, in that the
sharded properties are not exposed. Multiple parameter groups are not yet supported.

@@ -1430,24 +1433,37 @@ def gather_full_optim_state_dict(
Args:
optim (Optimizer): an optimizer instance for this FSDP rank. Its state is
used in the consolidation. However, its state is not modified.
recipient_rank (int): on which rank to materialize the full state dict.

Returns:
a dict with two entries
a dict with four entries
* state - a dict holding gathered optimization state, 1 entry per unflat parameter
* param_groups - a dict containing the 1 parameter group
* param_id_map - global (unflat) to local (flat) id mapping
* uncollected_local_ids - keys in the state dict that were not broadcast

"""

self._tstart = time.time()
if not self.flatten_parameters:
raise NotImplementedError("optim state dict requires flatten_parameters=True")
world_optim_states = self._consolidate_optim_state_dict(optim, recipient_rank)
if self.rank != recipient_rank and recipient_rank is not None:

self._lazy_init()
sd = self._remove_uncollectable_params_from_optim_state_dict(optim.state_dict())
assert set(sd.keys()) == {"param_groups", "state"}
assert len(sd["param_groups"]) == 1, "Param groups are not supported"
# We use all_gather to consolidate OSD['state'] and broadcast to consolidate the other keys (like param_groups)
state = self._gather_optim_state(sd.pop("state"))
pad_info = self._broadcast_pad_info_to_r0()
if self.rank != 0:
return None
assert set(sd.keys()) == {"param_groups"}
# Unify the shard states by concatenating tensors and unflattening params
new_state_dict = ou.build_unflat_state_dict(
self._fsdp_instances, world_optim_states, self.uncollected_opt_state
self._fsdp_instances, pad_info, state, self.uncollected_opt_state, sd["param_groups"]
)

self.uncollected_opt_state = {}
self._print_r0("FSDP: done unflat")
assert "uncollected_local_ids" in new_state_dict
return new_state_dict

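For reference, a minimal usage sketch of the public entry point changed above, assuming `torch.distributed` has been initialized on every rank; the module, hyperparameters, and file name are placeholders:

```python
import torch
from fairscale.nn.data_parallel import FullyShardedDataParallel

# Assumes dist.init_process_group(...) has already run on every rank.
model = FullyShardedDataParallel(torch.nn.Linear(8, 8), flatten_parameters=True)
optim = torch.optim.Adam(model.parameters(), lr=1e-3)
# ... run some training steps so the optimizer has state to consolidate ...
full_osd = model.gather_full_optim_state_dict(optim)  # full dict on rank 0, None elsewhere
if full_osd is not None:
    torch.save(full_osd, "consolidated_optim_state.pt")
```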
11 changes: 7 additions & 4 deletions tests/nn/data_parallel/test_fsdp.py
@@ -627,15 +627,18 @@ def __init__(self, group, wrapper_config, checkpoint_act=False, delay_before_fre

# "expert" params are different on each rank
torch.manual_seed(42 + group.rank())
d_expert = 16
expert = nn.Linear(d_expert, 4)
d_expert = 23
Comment from the PR author: make sure we unpad expert params correctly.
d_shared = 12
expert = nn.Linear(d_expert, d_shared)

self.num_expert_params = sum([p.numel() for p in expert.parameters()])
for p in expert.parameters():
p.expert = True

# everything else is shared
torch.manual_seed(0)
shared = nn.Linear(4, d_expert)

shared = nn.Linear(d_shared, d_expert)

if checkpoint_act:
expert = checkpoint_wrapper(expert)
@@ -648,7 +651,7 @@ def __init__(self, group, wrapper_config, checkpoint_act=False, delay_before_fre

shared = FullyShardedDataParallel(shared, group, **wrapper_config)

self.module = nn.Sequential(nn.Linear(8, 4), shared, expert, nn.Linear(4, 8))
self.module = nn.Sequential(nn.Linear(8, d_shared), shared, expert, nn.Linear(d_shared, 8))

def forward(self, x):
if self.delay_before_free_ms > 0:
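One note on the `d_expert = 23` test change above ("make sure we unpad expert params correctly"): an odd layer width produces flattened parameters whose element counts are not always divisible by the world size, so FSDP has to pad some shards and the consolidation path has to strip that padding again. A quick illustration of the arithmetic; the helper is hypothetical and the world size of 2 is just an example:

```python
def padding_needed(numel: int, world_size: int) -> int:
    """Elements FSDP must append so a flat parameter splits evenly across ranks."""
    return (world_size - numel % world_size) % world_size

# nn.Linear(23, 12): 23*12 weights + 12 biases = 288 elements -> no padding for 2 ranks
print(padding_needed(23 * 12 + 12, 2))  # 0
# nn.Linear(12, 23): 12*23 weights + 23 biases = 299 elements -> 1 padded element for 2 ranks
print(padding_needed(12 * 23 + 23, 2))  # 1
```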