-
Notifications
You must be signed in to change notification settings - Fork 35
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Prototype] Add param-to-lr interface to distributed Shampoo #22
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -296,6 +296,10 @@ class DistributedShampoo(torch.optim.Optimizer): | |
3. Otherwise, re-uses previous inverse factor matrix when both root inverse computations fail. | ||
track_root_inv_residuals (bool): Track errors and residuals of root inverse. For debugging purposes. | ||
(Default: False) | ||
experimental_param_to_lr (Optional[Callable[[Tensor], float]]): Optional callable mapping each parameter to its learning rate. | ||
If set, this mapping needs to cover all parameters in param_groups. | ||
This setting supersedes the learning rate of each parameter group. | ||
(Default: None) | ||
|
||
""" | ||
|
||
|
@@ -326,6 +330,7 @@ def __init__( | |
precision_config: Optional[PrecisionConfig] = None, | ||
use_protected_eigh: bool = True, | ||
track_root_inv_residuals: bool = False, | ||
experimental_param_to_lr: Optional[Callable[[torch.Tensor], float]] = None, | ||
) -> None: | ||
# Hyperparameter checks. | ||
if not lr >= 0.0: | ||
|
@@ -474,6 +479,7 @@ def __init__( | |
self._shampoo_pt2_compile_config: Optional[ShampooPT2CompileConfig] = ( | ||
shampoo_pt2_compile_config | ||
) | ||
self._experimental_param_to_lr = experimental_param_to_lr | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In order to handle support for this properly, we will need to create a function that constructs this mapping automatically from each parameter (within each parameter group) to its learning rate and modify the parameter groups defined typically by This may need to be moved prior to the |
||
|
||
# Initialize dictionary containing lists of per-group state. | ||
self._per_group_state_lists: List[Dict[str, Any]] = [ | ||
|
@@ -1142,6 +1148,114 @@ def _per_group_step_impl( | |
masked_blocked_search_directions=masked_blocked_search_directions | ||
) | ||
|
||
@torch.no_grad() | ||
def _per_group_step_experimental_lrs( | ||
self, | ||
state_lists: Dict[str, Any], | ||
step: torch.Tensor, | ||
neg_lrs: List[torch.Tensor], | ||
beta1: float, | ||
beta3: float, | ||
weight_decay: float, | ||
momentum_param: float, | ||
dampening: float, | ||
grafting_config_not_none: bool, | ||
compute_root_inverse: bool, | ||
use_decoupled_weight_decay: bool, | ||
use_bias_correction: bool, | ||
use_grafting_method: bool, | ||
use_nesterov: bool, | ||
) -> None: | ||
# Incorporate L2-regularization or (coupled) weight decay if enabled. | ||
# G <- G + weight_decay * W | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Reminder to self that we have a typo here; shouldn't have cc: @tsunghsienlee |
||
self._add_l2_regularization( | ||
state_lists, | ||
weight_decay, | ||
use_decoupled_weight_decay, | ||
) | ||
|
||
with DequantizePreconditionersContext( | ||
preconditioner_list=state_lists[SHAMPOO_PRECONDITIONER_LIST] | ||
), ( | ||
DequantizePreconditionersContext( | ||
preconditioner_list=state_lists[GRAFTING_PRECONDITIONER_LIST] | ||
) | ||
if grafting_config_not_none | ||
else contextlib.nullcontext() | ||
): | ||
# Update Shampoo and grafting preconditioners / factor matrices. | ||
# Example for AdaGrad accumulation: | ||
# L <- L + G * G^T | ||
# R <- R + G^T * G | ||
# V <- V + G^2 (element-wise) | ||
# (and similar) | ||
self._update_preconditioners( | ||
state_lists, | ||
step, | ||
grafting_config_not_none, | ||
) | ||
|
||
# Compute matrix root inverse. | ||
# L_inv <- L ** (-1/4) | ||
# R_inv <- R ** (-1/4) | ||
# (and similar) | ||
self._compute_root_inverse(state_lists, compute_root_inverse) | ||
|
||
# Compute filtered gradient or EMA of the gradients if beta1 > 0 and beta3 > 0. | ||
# Note that we use two beta factors here akin to Lion. | ||
# G_bar <- beta3 * G_tilde + (1 - beta3) * G | ||
# G_tilde <- beta1 * G_tilde + (1 - beta1) * G | ||
masked_filtered_grad_list = self._compute_filtered_grad_list( | ||
state_lists, | ||
step, | ||
beta1, | ||
beta3, | ||
use_bias_correction, | ||
) | ||
|
||
# Precondition and graft filtered gradients. | ||
# PT2 compile is currently disabled for preconditioning and grafting. | ||
# TODO: Resolve preconditioning and grafting PT2 NEX issue and enable them. | ||
# | ||
# P_shampoo <- L_inv * G_bar * R_inv (and similar) | ||
# P_grafting <- G_bar / (sqrt(V) + epsilon) | ||
# P <- P_grafting if step < start_preconditioning_step | ||
# P <- ||P_grafting|| / ||P_shampoo|| * P_shampoo otherwise | ||
masked_blocked_search_directions = self._precondition_and_grafting( | ||
state_lists, | ||
masked_filtered_grad_list, | ||
use_grafting_method, | ||
grafting_config_not_none, | ||
) | ||
|
||
# Incorporate decoupled weight decay into search direction if enabled. | ||
# P <- P + weight_decay * W | ||
self._apply_decoupled_weight_decay( | ||
state_lists, | ||
masked_blocked_search_directions, | ||
weight_decay, | ||
use_decoupled_weight_decay, | ||
) | ||
|
||
# Update momentum optimizer state and use momentum / Nesterov if enabled. | ||
# M <- momentum_param * M + (1 - dampening) * P | ||
# P <- (1 - dampening) * P + momentum_param * M if use_nesterov | ||
# P <- M otherwise. | ||
self._update_momentum( | ||
state_lists, | ||
masked_blocked_search_directions, | ||
momentum_param, | ||
dampening, | ||
use_nesterov, | ||
) | ||
|
||
# Updates parameters in distributed fashion. | ||
# If DDP, executes AllGather communication to ensure all parameters are updated after local updates. | ||
torch._foreach_mul_(masked_blocked_search_directions, neg_lrs) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Note that this is exactly the same as the code for |
||
state_lists[DISTRIBUTOR].update_params( | ||
masked_blocked_search_directions=masked_blocked_search_directions | ||
) | ||
|
||
@torch.no_grad() | ||
def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]: | ||
"""Performs a single optimization step. | ||
|
@@ -1173,12 +1287,6 @@ def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float] | |
|
||
# Iterate group step counter and define Python scalar step. | ||
step = state_lists[STEP].add_(1) | ||
# NOTE: Wrap scalar of group[LR] into a 0D tensor to avoid PT2 recompilation; | ||
# Send 0D tensor to GPU in `non_blocking` to avoid QPS regression. Remove the gpu | ||
# tensor impl once PT2 supports cpu 0D tensor properly. | ||
lr = torch.tensor(group[LR], dtype=torch.float).to( | ||
self._device, non_blocking=True | ||
) | ||
beta1 = group[BETAS][0] | ||
beta3 = group[BETA3] | ||
weight_decay = group[WEIGHT_DECAY] | ||
|
@@ -1200,22 +1308,58 @@ def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float] | |
) | ||
use_nesterov = group[USE_NESTEROV] | ||
|
||
self._per_group_step( | ||
state_lists, | ||
step, | ||
lr, | ||
beta1, | ||
beta3, | ||
weight_decay, | ||
momentum_param, | ||
dampening, | ||
grafting_config_not_none, | ||
compute_root_inverse, | ||
use_decoupled_weight_decay, | ||
use_bias_correction, | ||
use_grafting_method, | ||
use_nesterov, | ||
) | ||
if self._experimental_param_to_lr is None: | ||
# NOTE: Wrap scalar of group[LR] into a 0D tensor to avoid PT2 recompilation; | ||
# Send 0D tensor to GPU in `non_blocking` to avoid QPS regression. Remove the gpu | ||
# tensor impl once PT2 supports cpu 0D tensor properly. | ||
lr = torch.tensor(group[LR], dtype=torch.float).to( | ||
self._device, non_blocking=True | ||
) | ||
self._per_group_step( | ||
state_lists, | ||
step, | ||
lr, | ||
beta1, | ||
beta3, | ||
weight_decay, | ||
momentum_param, | ||
dampening, | ||
grafting_config_not_none, | ||
compute_root_inverse, | ||
use_decoupled_weight_decay, | ||
use_bias_correction, | ||
use_grafting_method, | ||
use_nesterov, | ||
) | ||
else: | ||
local_block_info_list = compress_list( | ||
state_lists[DISTRIBUTOR].global_block_info_list, | ||
state_lists[DISTRIBUTOR].distributor_selector, | ||
) | ||
neg_lr_tersors = [] | ||
for local_block_info in local_block_info_list: | ||
lr_scalar = self._experimental_param_to_lr(local_block_info.param) | ||
lr = torch.tensor(-lr_scalar, dtype=torch.float).to( | ||
self._device, non_blocking=True | ||
) | ||
neg_lr_tersors.append(lr) | ||
|
||
self._per_group_step_experimental_lrs( | ||
state_lists, | ||
step, | ||
neg_lr_tersors, | ||
beta1, | ||
beta3, | ||
weight_decay, | ||
momentum_param, | ||
dampening, | ||
grafting_config_not_none, | ||
compute_root_inverse, | ||
use_decoupled_weight_decay, | ||
use_bias_correction, | ||
use_grafting_method, | ||
use_nesterov, | ||
) | ||
|
||
return loss | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We will not want to expose this directly to the user, but create a flag that merges parameter groups that have different `lr`, `betas`, `beta3`, `epsilon`, `momentum`, `dampening`, or `weight_decay`, but share the same fields everywhere else.