Skip to content

Commit

Permalink
fix sparse gradient clipping for torch>=2.0 (#288)
Browse files Browse the repository at this point in the history
  • Loading branch information
jiong-zhang authored May 28, 2024
1 parent fbdde4f commit ca1a91b
Show file tree
Hide file tree
Showing 4 changed files with 91 additions and 16 deletions.
62 changes: 62 additions & 0 deletions pecos/utils/torch_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

import numpy as np
import torch
from typing import Union, Iterable

LOGGER = logging.getLogger(__name__)

Expand Down Expand Up @@ -72,3 +73,64 @@ def apply_mask(hidden_states, masks):
hidden_dim = hidden_states.shape[-1]
hidden_states.view(-1, hidden_dim)[~masks.view(-1).type(torch.ByteTensor), :] = 0
return hidden_states


def clip_grad_norm_(
parameters: Union[torch.Tensor, Iterable[torch.Tensor]],
max_norm: float,
norm_type: float = 2.0,
error_if_nonfinite: bool = False,
) -> torch.Tensor:
r"""
Implementation of torch.nn.utils.clip_grad_norm_ in torch==1.13
This is to support sparse gradient with gradient clipping.
REF: https://pytorch.org/docs/1.13/_modules/torch/nn/utils/clip_grad.html#clip_grad_norm_
Clips gradient norm of an iterable of parameters.
The norm is computed over all gradients together, as if they were
concatenated into a single vector. Gradients are modified in-place.
Args:
parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
single Tensor that will have gradients normalized
max_norm (float or int): max norm of the gradients
norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for
infinity norm.
error_if_nonfinite (bool): if True, an error is thrown if the total
norm of the gradients from :attr:`parameters` is ``nan``,
``inf``, or ``-inf``. Default: False (will switch to True in the future)
Returns:
Total norm of the parameter gradients (viewed as a single vector).
"""
if isinstance(parameters, torch.Tensor):
parameters = [parameters]
grads = [p.grad for p in parameters if p.grad is not None]
max_norm = float(max_norm)
norm_type = float(norm_type)
if len(grads) == 0:
return torch.tensor(0.0)
device = grads[0].device
if norm_type == "inf":
norms = [g.detach().abs().max().to(device) for g in grads]
total_norm = norms[0] if len(norms) == 1 else torch.max(torch.stack(norms))
else:
total_norm = torch.norm(
torch.stack([torch.norm(g.detach(), norm_type).to(device) for g in grads]), norm_type
)
if error_if_nonfinite and torch.logical_or(total_norm.isnan(), total_norm.isinf()):
raise RuntimeError(
f"The total norm of order {norm_type} for gradients from "
"`parameters` is non-finite, so it cannot be clipped. To disable "
"this error and scale the gradients by the non-finite norm anyway, "
"set `error_if_nonfinite=False`"
)
clip_coef = max_norm / (total_norm + 1e-6)
# Note: multiplying by the clamped coef is redundant when the coef is clamped to 1, but doing so
# avoids a `if clip_coef < 1:` conditional which can require a CPU <=> device synchronization
# when the gradients do not reside in CPU memory.
clip_coef_clamped = torch.clamp(clip_coef, max=1.0)
for g in grads:
g.detach().mul_(clip_coef_clamped.to(g.device))
return total_norm
8 changes: 5 additions & 3 deletions pecos/xmc/xlinear/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -537,9 +537,11 @@ def predict(
Ye = self.predict(
X[i : i + max_pred_chunk, :],
pred_params=pred_params,
selected_outputs_csr=selected_outputs_csr[i : i + max_pred_chunk, :]
if selected_outputs_csr is not None
else None,
selected_outputs_csr=(
selected_outputs_csr[i : i + max_pred_chunk, :]
if selected_outputs_csr is not None
else None
),
**new_kwargs,
)
Ys.append(Ye)
Expand Down
34 changes: 22 additions & 12 deletions pecos/xmc/xtransformer/matcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -784,18 +784,20 @@ def _predict(
if not only_embeddings:
text_model_W_seq, text_model_b_seq = self.text_model(
output_indices=inputs["label_indices"],
num_device=len(self.text_encoder.device_ids)
if hasattr(self.text_encoder, "device_ids")
else 1,
num_device=(
len(self.text_encoder.device_ids)
if hasattr(self.text_encoder, "device_ids")
else 1
),
)

outputs = self.text_encoder(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
token_type_ids=inputs["token_type_ids"],
label_embedding=None
if only_embeddings
else (text_model_W_seq, text_model_b_seq),
label_embedding=(
None if only_embeddings else (text_model_W_seq, text_model_b_seq)
),
)

if not only_embeddings:
Expand Down Expand Up @@ -1088,9 +1090,11 @@ def fine_tune_encoder(self, prob, val_prob=None, val_csr_codes=None):
}
text_model_W_seq, text_model_b_seq = self.text_model(
output_indices=inputs["label_indices"],
num_device=len(self.text_encoder.device_ids)
if hasattr(self.text_encoder, "device_ids")
else 1,
num_device=(
len(self.text_encoder.device_ids)
if hasattr(self.text_encoder, "device_ids")
else 1
),
)
outputs = self.text_encoder(
input_ids=inputs["input_ids"],
Expand Down Expand Up @@ -1119,9 +1123,15 @@ def fine_tune_encoder(self, prob, val_prob=None, val_csr_codes=None):
scheduler.step() # update learning rate schedule
optimizer.zero_grad() # clear gradient accumulation

torch.nn.utils.clip_grad_norm_(
self.text_model.parameters(), train_params.max_grad_norm
)
if self.text_model.is_sparse:
torch_util.clip_grad_norm_(
self.text_model.parameters(), train_params.max_grad_norm
)
else:
torch.nn.utils.clip_grad_norm_(
self.text_model.parameters(), train_params.max_grad_norm
)

emb_optimizer.step() # perform gradient update
emb_scheduler.step() # update learning rate schedule
emb_optimizer.zero_grad() # clear gradient accumulation
Expand Down
3 changes: 2 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,8 @@ def get_blas_lib_dir(cls):
install_requires = numpy_requires + [
'scipy>=1.4.1',
'scikit-learn>=0.24.1',
'torch>=1.8.0,<2.0.0',
'torch==1.13; python_version<"3.8"',
'torch>=2.0; python_version>="3.8"',
'sentencepiece>=0.1.86,!=0.1.92', # 0.1.92 results in error for transformers
'transformers>=4.1.1; python_version<"3.9"',
'transformers>=4.4.2; python_version>="3.9"'
Expand Down

0 comments on commit ca1a91b

Please sign in to comment.