Commit ca1a91b

fix sparse gradient clipping for torch>=2.0 (amzn#288)
1 parent fbdde4f commit ca1a91b
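
The commit title points at the failure this addresses: on torch>=2.0, the built-in torch.nn.utils.clip_grad_norm_ can reject sparse gradients, such as those produced by an nn.Embedding created with sparse=True. A minimal sketch of the symptom (illustrative, not part of the commit):

    import torch

    # An embedding with sparse=True accumulates a sparse COO gradient on its weight.
    emb = torch.nn.Embedding(10, 4, sparse=True)
    emb(torch.tensor([0, 1, 2])).sum().backward()

    # On torch>=2.0 this call may raise, since the default clipping path does not
    # handle sparse gradient layouts.
    torch.nn.utils.clip_grad_norm_(emb.parameters(), max_norm=1.0)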

File tree

4 files changed: +91 -16 lines


pecos/utils/torch_util.py

+62
@@ -12,6 +12,7 @@

 import numpy as np
 import torch
+from typing import Union, Iterable

 LOGGER = logging.getLogger(__name__)

@@ -72,3 +73,64 @@ def apply_mask(hidden_states, masks):
     hidden_dim = hidden_states.shape[-1]
     hidden_states.view(-1, hidden_dim)[~masks.view(-1).type(torch.ByteTensor), :] = 0
     return hidden_states
+
+
+def clip_grad_norm_(
+    parameters: Union[torch.Tensor, Iterable[torch.Tensor]],
+    max_norm: float,
+    norm_type: float = 2.0,
+    error_if_nonfinite: bool = False,
+) -> torch.Tensor:
+    r"""
+    Implementation of torch.nn.utils.clip_grad_norm_ in torch==1.13
+    This is to support sparse gradient with gradient clipping.
+    REF: https://pytorch.org/docs/1.13/_modules/torch/nn/utils/clip_grad.html#clip_grad_norm_
+
+    Clips gradient norm of an iterable of parameters.
+
+    The norm is computed over all gradients together, as if they were
+    concatenated into a single vector. Gradients are modified in-place.
+
+    Args:
+        parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
+            single Tensor that will have gradients normalized
+        max_norm (float or int): max norm of the gradients
+        norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for
+            infinity norm.
+        error_if_nonfinite (bool): if True, an error is thrown if the total
+            norm of the gradients from :attr:`parameters` is ``nan``,
+            ``inf``, or ``-inf``. Default: False (will switch to True in the future)
+
+    Returns:
+        Total norm of the parameter gradients (viewed as a single vector).
+    """
+    if isinstance(parameters, torch.Tensor):
+        parameters = [parameters]
+    grads = [p.grad for p in parameters if p.grad is not None]
+    max_norm = float(max_norm)
+    norm_type = float(norm_type)
+    if len(grads) == 0:
+        return torch.tensor(0.0)
+    device = grads[0].device
+    if norm_type == float("inf"):
+        norms = [g.detach().abs().max().to(device) for g in grads]
+        total_norm = norms[0] if len(norms) == 1 else torch.max(torch.stack(norms))
+    else:
+        total_norm = torch.norm(
+            torch.stack([torch.norm(g.detach(), norm_type).to(device) for g in grads]), norm_type
+        )
+    if error_if_nonfinite and torch.logical_or(total_norm.isnan(), total_norm.isinf()):
+        raise RuntimeError(
+            f"The total norm of order {norm_type} for gradients from "
+            "`parameters` is non-finite, so it cannot be clipped. To disable "
+            "this error and scale the gradients by the non-finite norm anyway, "
+            "set `error_if_nonfinite=False`"
+        )
+    clip_coef = max_norm / (total_norm + 1e-6)
+    # Note: multiplying by the clamped coef is redundant when the coef is clamped to 1, but doing so
+    # avoids a `if clip_coef < 1:` conditional which can require a CPU <=> device synchronization
+    # when the gradients do not reside in CPU memory.
+    clip_coef_clamped = torch.clamp(clip_coef, max=1.0)
+    for g in grads:
+        g.detach().mul_(clip_coef_clamped.to(g.device))
+    return total_norm
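
A minimal usage sketch of the vendored helper (illustrative, not part of the commit; assumes pecos is installed so that it is importable from pecos.utils.torch_util). It computes per-tensor norms with torch.norm, which also accepts sparse gradients, so clipping works for sparse-embedding parameters:

    import torch
    from pecos.utils import torch_util

    # sparse=True makes the embedding weight accumulate a sparse COO gradient.
    emb = torch.nn.Embedding(10, 4, sparse=True)
    emb(torch.tensor([0, 1, 2])).sum().backward()

    # Clip in-place and get back the pre-clipping total gradient norm.
    total_norm = torch_util.clip_grad_norm_(emb.parameters(), max_norm=1.0)
    print(float(total_norm))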

pecos/xmc/xlinear/model.py

+5 -3
@@ -537,9 +537,11 @@ def predict(
                 Ye = self.predict(
                     X[i : i + max_pred_chunk, :],
                     pred_params=pred_params,
-                    selected_outputs_csr=selected_outputs_csr[i : i + max_pred_chunk, :]
-                    if selected_outputs_csr is not None
-                    else None,
+                    selected_outputs_csr=(
+                        selected_outputs_csr[i : i + max_pred_chunk, :]
+                        if selected_outputs_csr is not None
+                        else None
+                    ),
                     **new_kwargs,
                 )
                 Ys.append(Ye)

pecos/xmc/xtransformer/matcher.py

+22 -12
@@ -784,18 +784,20 @@ def _predict(
             if not only_embeddings:
                 text_model_W_seq, text_model_b_seq = self.text_model(
                     output_indices=inputs["label_indices"],
-                    num_device=len(self.text_encoder.device_ids)
-                    if hasattr(self.text_encoder, "device_ids")
-                    else 1,
+                    num_device=(
+                        len(self.text_encoder.device_ids)
+                        if hasattr(self.text_encoder, "device_ids")
+                        else 1
+                    ),
                 )

             outputs = self.text_encoder(
                 input_ids=inputs["input_ids"],
                 attention_mask=inputs["attention_mask"],
                 token_type_ids=inputs["token_type_ids"],
-                label_embedding=None
-                if only_embeddings
-                else (text_model_W_seq, text_model_b_seq),
+                label_embedding=(
+                    None if only_embeddings else (text_model_W_seq, text_model_b_seq)
+                ),
             )

             if not only_embeddings:
@@ -1088,9 +1090,11 @@ def fine_tune_encoder(self, prob, val_prob=None, val_csr_codes=None):
                 }
                 text_model_W_seq, text_model_b_seq = self.text_model(
                     output_indices=inputs["label_indices"],
-                    num_device=len(self.text_encoder.device_ids)
-                    if hasattr(self.text_encoder, "device_ids")
-                    else 1,
+                    num_device=(
+                        len(self.text_encoder.device_ids)
+                        if hasattr(self.text_encoder, "device_ids")
+                        else 1
+                    ),
                 )
                 outputs = self.text_encoder(
                     input_ids=inputs["input_ids"],
@@ -1119,9 +1123,15 @@ def fine_tune_encoder(self, prob, val_prob=None, val_csr_codes=None):
                     scheduler.step()  # update learning rate schedule
                     optimizer.zero_grad()  # clear gradient accumulation

-                    torch.nn.utils.clip_grad_norm_(
-                        self.text_model.parameters(), train_params.max_grad_norm
-                    )
+                    if self.text_model.is_sparse:
+                        torch_util.clip_grad_norm_(
+                            self.text_model.parameters(), train_params.max_grad_norm
+                        )
+                    else:
+                        torch.nn.utils.clip_grad_norm_(
+                            self.text_model.parameters(), train_params.max_grad_norm
+                        )
+
                     emb_optimizer.step()  # perform gradient update
                     emb_scheduler.step()  # update learning rate schedule
                     emb_optimizer.zero_grad()  # clear gradient accumulation
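
For reference, the same dispatch idea as a standalone helper (hypothetical sketch, not code from this commit): route to the vendored implementation only when sparse gradients are present, and keep the stock torch utility for the dense case:

    import torch
    from pecos.utils import torch_util

    def clip_grads(parameters, max_grad_norm):
        # Hypothetical helper: use the vendored, sparse-safe clip_grad_norm_ when any
        # gradient is sparse; otherwise fall back to the built-in torch implementation.
        params = list(parameters)
        if any(p.grad is not None and p.grad.is_sparse for p in params):
            return torch_util.clip_grad_norm_(params, max_grad_norm)
        return torch.nn.utils.clip_grad_norm_(params, max_grad_norm)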

setup.py

+2 -1
@@ -115,7 +115,8 @@ def get_blas_lib_dir(cls):
     install_requires = numpy_requires + [
         'scipy>=1.4.1',
         'scikit-learn>=0.24.1',
-        'torch>=1.8.0,<2.0.0',
+        'torch==1.13; python_version<"3.8"',
+        'torch>=2.0; python_version>="3.8"',
         'sentencepiece>=0.1.86,!=0.1.92', # 0.1.92 results in error for transformers
         'transformers>=4.1.1; python_version<"3.9"',
         'transformers>=4.4.2; python_version>="3.9"'
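
A small sketch of how the new PEP 508 environment markers resolve at install time (illustrative, not from the commit; assumes the packaging library is available, as pip vendors it):

    from packaging.requirements import Requirement

    specs = [
        'torch==1.13; python_version<"3.8"',
        'torch>=2.0; python_version>="3.8"',
    ]
    for spec in specs:
        req = Requirement(spec)
        # On Python 3.8+ only the torch>=2.0 marker evaluates to True, so pip
        # installs a torch 2.x build; older interpreters keep torch 1.13.
        print(req, "->", req.marker.evaluate())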
