Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reduce the device bubble introduced by heavy loop synchronization in coalesced fetch/release(z3_leaf_module) #6694

Merged
merged 48 commits into from
Jan 6, 2025
Merged
Show file tree
Hide file tree
Changes from 44 commits
Commits
Show all changes
48 commits
Select commit Hold shift + click to select a range
a2610d8
z3 coalesced fetch
inkcherry Oct 21, 2024
4e8be08
fix format
inkcherry Oct 21, 2024
7641994
fix default value
inkcherry Oct 21, 2024
805a820
fix default
inkcherry Oct 21, 2024
ce7dfb7
Merge branch 'master' into z3_coalesced_fetch
delock Oct 23, 2024
810353b
fix ut
inkcherry Oct 23, 2024
a8dd8fe
fix ut
inkcherry Oct 23, 2024
53584ca
Merge branch 'master' into z3_coalesced_fetch
loadams Oct 25, 2024
4d86198
Merge branch 'master' into z3_coalesced_fetch
tjruwase Oct 31, 2024
7b94377
add ut(usage)
inkcherry Nov 4, 2024
cd31a0d
use int type config
inkcherry Nov 4, 2024
ea50964
fix format
inkcherry Nov 4, 2024
b068118
Merge remote-tracking branch 'origin/z3_coalesced_fetch' into z3_coal…
inkcherry Nov 4, 2024
600d9c7
fix note
inkcherry Nov 4, 2024
4477077
Merge branch 'master' into z3_coalesced_fetch
tjruwase Nov 4, 2024
c2c434b
refine code
inkcherry Nov 5, 2024
e5f9430
remove debug code
inkcherry Nov 5, 2024
c2b020a
update
inkcherry Nov 5, 2024
511ace0
Merge remote-tracking branch 'origin/z3_coalesced_fetch' into z3_coal…
inkcherry Nov 5, 2024
3680109
don't set leaf for container module
inkcherry Nov 5, 2024
f2752f8
Merge branch 'master' into z3_coalesced_fetch
inkcherry Nov 5, 2024
22c0f81
update ut
inkcherry Nov 6, 2024
f773258
udpate
inkcherry Nov 6, 2024
c31ad02
change config name, refine doc
inkcherry Nov 6, 2024
40ceeac
fix rjust size
inkcherry Nov 6, 2024
73e5bd5
fix merge
inkcherry Nov 6, 2024
c31c8d2
format
inkcherry Nov 6, 2024
619cbe6
always print info if the config is enabled
inkcherry Nov 7, 2024
3c0a183
Merge branch 'master' into z3_coalesced_fetch
inkcherry Nov 7, 2024
a6e5a39
update
inkcherry Nov 7, 2024
e7e5cdf
Merge branch 'z3_coalesced_fetch' of https://github.com/inkcherry/Dee…
inkcherry Nov 7, 2024
00ac4eb
Merge remote-tracking branch 'upstream/master' into z3_coalesced_fetch
inkcherry Nov 11, 2024
25df962
use mark parametrize for test
inkcherry Nov 11, 2024
663e637
opt loop
inkcherry Nov 11, 2024
13aaf2f
update
inkcherry Nov 13, 2024
063e5b9
Use fast fetch only for the case of z3_leaf_module with fine-grained …
inkcherry Nov 13, 2024
444d586
Merge branch 'master' into reduce_coalesced_fetch_bubble
inkcherry Nov 14, 2024
5e9f2e8
update
inkcherry Nov 14, 2024
21b3823
Merge branch 'master' into reduce_coalesced_fetch_bubble
inkcherry Nov 20, 2024
26ffcf7
Merge branch 'master' into reduce_coalesced_fetch_bubble
inkcherry Dec 23, 2024
931c121
Move the condition outside the loop
inkcherry Dec 25, 2024
8d59565
update condition&comments
inkcherry Dec 25, 2024
59fef78
Swap the order of the conditions
inkcherry Dec 25, 2024
335db94
rename arg name&fix ut
inkcherry Dec 26, 2024
c7b4441
fix ut
inkcherry Dec 27, 2024
4f317b8
Merge branch 'master' into reduce_coalesced_fetch_bubble
inkcherry Dec 27, 2024
99de70e
Merge branch 'master' into reduce_coalesced_fetch_bubble
loadams Jan 2, 2025
1517d71
Merge branch 'master' into reduce_coalesced_fetch_bubble
loadams Jan 6, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 11 additions & 8 deletions deepspeed/runtime/zero/parameter_offload.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,16 @@ def __init__(
module.ds_inflight_param_registry = InflightParamRegistry()
self.__inflight_param_registry = module.ds_inflight_param_registry

self.fast_sharding_for_leaf_module = False

if zero_module_granularity_threshold > 0:
self.min_granularity_value = sys.maxsize
self.min_granularity_layer = None
self.granularity_info = set()
self.z3_leaf_layers = []
self._set_z3_leaf_modules_by_threshold(module, zero_module_granularity_threshold)
self.fast_sharding_for_leaf_module = True

self.param_coordinator = PartitionedParameterCoordinator(
prefetch_bucket_sz=self._prefetch_bucket_sz,
max_reuse_distance_in_numel=self._max_reuse_distance_in_numel,
Expand All @@ -155,14 +165,7 @@ def __init__(
timers=self.timers,
zero_quantized_weights=self.zero_quantized_weights,
zero_quantized_nontrainable_weights=self.zero_quantized_nontrainable_weights,
)

if zero_module_granularity_threshold > 0:
self.min_granularity_value = sys.maxsize
self.min_granularity_layer = None
self.granularity_info = set()
self.z3_leaf_layers = []
self._set_z3_leaf_modules_by_threshold(module, zero_module_granularity_threshold)
fast_sharding_for_leaf_module=self.fast_sharding_for_leaf_module)

self.forward_hooks = []
self.backward_hooks = []
Expand Down
39 changes: 24 additions & 15 deletions deepspeed/runtime/zero/partition_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def __init__(self, param: Parameter) -> None:
non_blocking=True).view(param.ds_shape)
self.__param = param

def wait(self) -> None:
def wait(self, **kwargs) -> None:
if not get_accelerator().resolves_data_dependency():
get_accelerator().current_stream().synchronize()
self.__param.ds_status = ZeroParamStatus.AVAILABLE
Expand All @@ -78,7 +78,7 @@ def __init__(self, params: List[Parameter]) -> None:
non_blocking=True).view(param.ds_shape)

@instrument_w_nvtx
def wait(self) -> None:
def wait(self, **kwargs) -> None:
if self.__complete:
return

Expand Down Expand Up @@ -639,7 +639,7 @@ def __init__(self, handle, param: Parameter, quantization=None) -> None:
self.__param = param
self.__quantization = quantization

def wait(self) -> None:
def wait(self, handle_dependency=True) -> None:
instrument_w_nvtx(self.__handle.wait)()
if self.__quantization:
instrument_w_nvtx(self.__quantization.quant_handle.wait)()
Expand All @@ -650,6 +650,8 @@ def wait(self) -> None:

class AllGatherCoalescedHandle:

data_buffer = []

def __init__(
self,
allgather_handle,
Expand All @@ -672,7 +674,7 @@ def __init__(
raise RuntimeError(f"expected param {param.ds_summary()} to not be available")

@instrument_w_nvtx
def wait(self) -> None:
def wait(self, handle_dependency=True) -> None:
if self.complete:
return

Expand Down Expand Up @@ -704,24 +706,30 @@ def wait(self) -> None:
partitions.append(part_to_copy)
param.data = instrument_w_nvtx(torch.cat)(partitions).view(param.ds_shape)
param.ds_status = ZeroParamStatus.AVAILABLE

for part_to_copy in partitions:
if not get_accelerator().is_synchronized_device():
if not get_accelerator().is_synchronized_device() and handle_dependency:
for part_to_copy in partitions:
part_to_copy.record_stream(get_accelerator().current_stream())

param_offset += ds_tensor_numel

self.complete = True
if not get_accelerator().is_synchronized_device() and not handle_dependency:
# if the device needs to handle dependencies and opts for explicit processing outside the function.
AllGatherCoalescedHandle.data_buffer.append(partitions)

@staticmethod
def free_buffer():
AllGatherCoalescedHandle.data_buffer = []


class MultipleAllGatherHandles:

def __init__(self, handles: List[AllGatherCoalescedHandle]):
self.handles = handles

def wait(self) -> None:
def wait(self, handle_dependency=True) -> None:
for handle in self.handles:
handle.wait()
handle.wait(handle_dependency)


class AllReduceCoalescedHandle:
Expand Down Expand Up @@ -1377,13 +1385,13 @@ def all_gather_coalesced(params: Iterable[Parameter],
quantization=quant_info,
)

def partition(param_list=None, hierarchy=0, has_been_updated=False):
def partition(param_list=None, hierarchy=0, has_been_updated=False, free_data=True):
cls = param
print_rank_0(f"{'--'*hierarchy}----Partitioning param {debug_param2name_id_shape_device(cls)}",
force=False)
if param_list is None:
param_list = [cls]
self._partition(param_list, has_been_updated=has_been_updated)
self._partition(param_list, has_been_updated=has_been_updated, free_data=True)

def reduce_gradients_at_owner(param_list=None, hierarchy=0):
cls = param
Expand Down Expand Up @@ -1527,20 +1535,20 @@ def _all_gather(self, param_list, async_op=False, hierarchy=None):

return handles

def _partition(self, param_list, force=False, has_been_updated=False):
def _partition(self, param_list, force=False, has_been_updated=False, free_data=True):
for param in param_list:
print_rank_0(f"Before Partitioning Param {param.ds_id}", force=False)
if self.zero_param_process_group is not None:
self._partition_param_sec(param)
self._partition_param(param, has_been_updated=has_been_updated)
self._partition_param(param, has_been_updated=has_been_updated, free_data=True)

param.ds_status = ZeroParamStatus.NOT_AVAILABLE
# if param.ds_tensor is not None:
# assert id(param.data) == id(param.ds_tensor.data), \
# "After the parameters are initially partitioned, make sure we are not recreating the partition."
#print_rank_0(f"After Partitioning Param {param.ds_id} {param.ds_tensor.size()} {param.ds_tensor}",force=False)
@instrument_w_nvtx
def _partition_param(self, param, buffer=None, has_been_updated=False):
def _partition_param(self, param, buffer=None, has_been_updated=False, free_data=True):
assert param.ds_status is not ZeroParamStatus.INFLIGHT, f" {param} Cannot partition a param in flight"
global reuse_buffers
print_rank_0(f"Param id {param.ds_id} status is {param.ds_status}", force=False)
Expand All @@ -1565,7 +1573,8 @@ def _partition_param(self, param, buffer=None, has_been_updated=False):

see_memory_usage(f'Before partitioning param {param.ds_id} {param.shape}', force=False)
# param.data does not store anything meaningful in partitioned state
free_param(param)
if free_data:
free_param(param)
see_memory_usage(f'After partitioning param {param.ds_id} {param.shape}', force=False)

if param.ds_tensor.final_location == OffloadDeviceEnum.nvme:
Expand Down
50 changes: 33 additions & 17 deletions deepspeed/runtime/zero/partitioned_param_coordinator.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,18 +76,17 @@ class __ParamInTrace:
param: Parameter
step_id_last_used_at: int

def __init__(
self,
prefetch_bucket_sz: int,
max_reuse_distance_in_numel: int,
max_available_parameters_in_numel: int,
allgather_stream: get_accelerator().Stream,
inflight_param_registry: InflightParamRegistry,
prefetch_nvme: bool = False,
timers=None,
zero_quantized_weights=False,
zero_quantized_nontrainable_weights=False,
) -> None:
def __init__(self,
prefetch_bucket_sz: int,
max_reuse_distance_in_numel: int,
max_available_parameters_in_numel: int,
allgather_stream: get_accelerator().Stream,
inflight_param_registry: InflightParamRegistry,
prefetch_nvme: bool = False,
timers=None,
zero_quantized_weights=False,
zero_quantized_nontrainable_weights=False,
fast_sharding_for_leaf_module=False) -> None:
# mapping of param -> handle for each param that is currently in flight
self.__inflight_param_registry = inflight_param_registry
# keeps track of the number of submodules invoked so far.
Expand Down Expand Up @@ -130,6 +129,10 @@ def __init__(
self.__max_ongoing_fetch_events: int = 2
self.__profiler = PartitionedParameterProfiler(timers if ENABLE_PROFILER else None)

# whether to enable fast fetch for the z3 leaf module.
# this will improve fetch speed but will not break down leaf module parameters to alleviate memory pressure.
self.fast_sharding_for_leaf_module = fast_sharding_for_leaf_module

"""Tracing and Tracking
TODO. consider performing trace before initializing PartitionedParameterCoordinator
and passing trace results into constructor. This way all the code in here can
Expand Down Expand Up @@ -308,6 +311,7 @@ def fetch_sub_module(self, current_submodule: Module, forward: bool) -> None:
wait_numel = 0
wait_event_name = __class__.FORWARD_FETCH_WAIT if forward else __class__.BACKWARD_FETCH_WAIT
self.__profiler.start_event(wait_event_name)
fast_fetch = self.fast_sharding_for_leaf_module and z3_leaf_module(current_submodule)
# wait for parameters in the immediately needed submodule to become available
for param in params_to_fetch:
param.ds_active_sub_modules.add(current_submodule.id)
Expand All @@ -321,16 +325,18 @@ def fetch_sub_module(self, current_submodule: Module, forward: bool) -> None:
if len(self.__ongoing_fetch_events) > self.__max_ongoing_fetch_events:
self.__ongoing_fetch_events.popleft().synchronize()

self.__inflight_param_registry.pop(param).wait()
self.__inflight_param_registry.pop(param).wait(handle_dependency=not fast_fetch)

if not get_accelerator().handles_memory_backpressure():
if not get_accelerator().handles_memory_backpressure() and not fast_fetch:
event = get_accelerator().Event()
event.record()
self.__ongoing_fetch_events.append(event)

assert param.ds_status == ZeroParamStatus.AVAILABLE, param.ds_summary()
if not get_accelerator().resolves_data_dependency():
get_accelerator().current_stream().wait_stream(self.__allgather_stream)
if fast_fetch:
AllGatherCoalescedHandle.free_buffer()
self.__profiler.stop_event(wait_event_name, wait_numel)

# kick off parameter prefetches for upcoming modules
Expand Down Expand Up @@ -412,10 +418,20 @@ def release_sub_module(self, submodule: Module) -> None:
be released."""
params_to_release = (self.__params_to_release(submodule, self.__step_id) if self.is_complete_trace() else set(
p.ds_id for p in iter_params(submodule, recurse=z3_leaf_module(submodule))))

free_data = not z3_leaf_module(submodule) or not self.fast_sharding_for_leaf_module
if not free_data:
# wait for the computation to finish and launch as early as possible.
empty_buffer = torch.empty(1, device=get_accelerator().current_device())

for param in iter_params(submodule, recurse=z3_leaf_module(submodule)):
param.ds_active_sub_modules.discard(submodule.id)
if param.ds_id in params_to_release and not param.is_external_param:
self.__release_param(param)
self.__release_param(param, free_data)
if not free_data:
if param.ds_id in params_to_release and not param.is_external_param:
# empty buffer ensures that all computations are complete
param.data = empty_buffer

@instrument_w_nvtx
@torch.no_grad()
Expand Down Expand Up @@ -490,11 +506,11 @@ def __all_gather_params_(self, params: Set[Parameter], forward: bool, quantize:

@compiler.disable
@instrument_w_nvtx
def __release_param(self, param: Parameter) -> None:
def __release_param(self, param: Parameter, free_data: bool = True) -> None:
if param.ds_status == ZeroParamStatus.AVAILABLE and not param.ds_active_sub_modules:
if logger.isEnabledFor(logging.DEBUG):
debug_rank0(f"-release: {param.ds_summary()}")
param.partition()
param.partition(free_data=free_data)
self.__n_available_params -= param.ds_numel

@instrument_w_nvtx
Expand Down
Loading