Commit: final alibi fix

VarunGumma committed Jul 4, 2024
1 parent 3e6d50b commit c80b4d3
Showing 8 changed files with 144 additions and 115 deletions.
64 changes: 32 additions & 32 deletions fairseq/criterions/label_smoothed_cross_entropy.py
@@ -7,8 +7,6 @@
from dataclasses import dataclass, field

import torch
import torch.nn.functional as F

from fairseq import utils
from fairseq.logging import metrics
from fairseq.criterions import FairseqCriterion, register_criterion
@@ -33,6 +31,26 @@ class LabelSmoothedCrossEntropyCriterionConfig(FairseqDataclass):
sentence_avg: bool = II("optimization.sentence_avg")


def label_smoothed_nll_loss(lprobs, target, epsilon, ignore_index=None, reduce=True):
if target.dim() == lprobs.dim() - 1:
target = target.unsqueeze(-1)
nll_loss = -lprobs.gather(dim=-1, index=target)
smooth_loss = -lprobs.sum(dim=-1, keepdim=True)
if ignore_index is not None:
pad_mask = target.eq(ignore_index)
nll_loss.masked_fill_(pad_mask, 0.0)
smooth_loss.masked_fill_(pad_mask, 0.0)
else:
nll_loss = nll_loss.squeeze(-1)
smooth_loss = smooth_loss.squeeze(-1)
if reduce:
nll_loss = nll_loss.sum()
smooth_loss = smooth_loss.sum()
eps_i = epsilon / (lprobs.size(-1) - 1)
loss = (1.0 - epsilon - eps_i) * nll_loss + eps_i * smooth_loss
return loss, nll_loss


@register_criterion(
"label_smoothed_cross_entropy", dataclass=LabelSmoothedCrossEntropyCriterionConfig
)
@@ -53,6 +71,7 @@ def __init__(

def forward(self, model, sample, reduce=True):
"""Compute the loss for the given sample.
Returns a tuple with three elements:
1) the loss
2) the sample size, which is used as the denominator for the gradient
@@ -76,43 +95,24 @@ def forward(self, model, sample, reduce=True):
logging_output["total"] = utils.item(total.data)
return loss, sample_size, logging_output

def get_lprobs_and_target(self, model, net_output, sample, log_probs=True):
lprobs = model.get_normalized_probs(net_output, log_probs=log_probs)
def get_lprobs_and_target(self, model, net_output, sample):
lprobs = model.get_normalized_probs(net_output, log_probs=True)
target = model.get_targets(sample, net_output)
if self.ignore_prefix_size > 0:
# lprobs: B x T x C
lprobs = lprobs[:, self.ignore_prefix_size :, :].contiguous()
target = target[:, self.ignore_prefix_size :].contiguous()
return lprobs.view(-1, lprobs.size(-1)), target.view(-1)

def get_logits_and_target(self, model, net_output, sample):
logits = net_output[0]
target = model.get_targets(sample, net_output)
if self.ignore_prefix_size > 0:
# logits: B x T x C
logits = logits[:, self.ignore_prefix_size :, :].contiguous()
target = target[:, self.ignore_prefix_size :].contiguous()
return logits.view(-1, logits.size(-1)), target.view(-1)

def compute_loss(self, model, net_output, sample, reduce=True):
logits, target = self.get_logits_and_target(model, net_output, sample)

loss = F.cross_entropy(
logits,
target,
reduction="sum" if reduce else "none",
ignore_index=self.padding_idx,
label_smoothing=self.eps,
)

nll_loss = F.cross_entropy(
logits,
lprobs, target = self.get_lprobs_and_target(model, net_output, sample)
loss, nll_loss = label_smoothed_nll_loss(
lprobs,
target,
reduction="sum" if reduce else "none",
self.eps,
ignore_index=self.padding_idx,
label_smoothing=0.0
reduce=reduce,
)

return loss, nll_loss

def compute_accuracy(self, model, net_output, sample):
@@ -127,16 +127,16 @@ def compute_accuracy(self, model, net_output, sample):
@classmethod
def reduce_metrics(cls, logging_outputs) -> None:
"""Aggregate logging outputs from data parallel training."""
loss = sum(log.get("loss", 0) for log in logging_outputs)
nll_loss = sum(log.get("nll_loss", 0) for log in logging_outputs)
loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
nll_loss_sum = sum(log.get("nll_loss", 0) for log in logging_outputs)
ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)

metrics.log_scalar(
"loss", loss / sample_size / math.log(2), sample_size, round=3
"loss", loss_sum / sample_size / math.log(2), sample_size, round=3
)
metrics.log_scalar(
"nll_loss", nll_loss / ntokens / math.log(2), ntokens, round=3
"nll_loss", nll_loss_sum / ntokens / math.log(2), ntokens, round=3
)
metrics.log_derived(
"ppl", lambda meters: utils.get_perplexity(meters["nll_loss"].avg)
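
Note: this file replaces the torch.nn.functional.cross_entropy-based loss with the module-level label_smoothed_nll_loss helper, which keeps 1 - epsilon of the probability mass on the target token and spreads epsilon uniformly over the remaining vocabulary entries (eps_i = epsilon / (V - 1)). A minimal toy check of the restored helper, with hypothetical tensors that are not part of the commit or its tests:

import torch
from fairseq.criterions.label_smoothed_cross_entropy import label_smoothed_nll_loss

# 4 target tokens over a vocabulary of 6; index 0 stands in for padding.
lprobs = torch.log_softmax(torch.randn(4, 6), dim=-1)
target = torch.tensor([1, 2, 3, 0])

loss, nll_loss = label_smoothed_nll_loss(
    lprobs, target, epsilon=0.1, ignore_index=0, reduce=True
)
# nll_loss is the plain negative log-likelihood (padded positions zeroed out);
# loss blends it with the uniform smoothing term described above.
print(loss.item(), nll_loss.item())
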
28 changes: 10 additions & 18 deletions fairseq/criterions/label_smoothed_cross_entropy_with_rdrop.py
@@ -7,14 +7,14 @@
from dataclasses import dataclass, field

import torch
import torch.nn.functional as F

from fairseq import utils
from fairseq.logging import metrics
from fairseq.criterions import register_criterion
from fairseq.criterions.label_smoothed_cross_entropy import (
LabelSmoothedCrossEntropyCriterion,
LabelSmoothedCrossEntropyCriterionConfig,
label_smoothed_nll_loss,
)


@@ -69,7 +69,6 @@ def forward(self, model, sample, reduce=True, net_output=None):
) == sample["target"].size(0):
sample = duplicate_input(sample)
net_output = model(**sample["net_input"])

loss, nll_loss, rdrop_kl_loss = self.compute_loss(
model, net_output, sample, reduce=reduce
)
@@ -104,20 +103,13 @@ def get_lprobs_and_target(self, model, net_output, sample):
return lprobs.view(-1, lprobs.size(-1)), target.view(-1)

def compute_loss(self, model, net_output, sample, reduce=True):
logits, target = self.get_logits_and_target(model, net_output, sample)
loss = F.cross_entropy(
logits,
target,
reduction="sum" if reduce else "none",
ignore_index=self.padding_idx,
label_smoothing=self.eps,
)

nll_loss = F.cross_entropy(
logits,
lprobs, target = self.get_lprobs_and_target(model, net_output, sample)
loss, nll_loss = label_smoothed_nll_loss(
lprobs,
target,
reduction="sum" if reduce else "none",
self.eps,
ignore_index=self.padding_idx,
reduce=reduce,
)

if self.rdrop_alpha > 0:
@@ -162,16 +154,16 @@ def duplicate_input(sample):

def compute_kl_loss(model, net_output, pad_mask=None, reduce=True):
net_prob = model.get_normalized_probs(net_output, log_probs=True)
net_prob_tec = model.get_normalized_probs(net_output, log_probs=True)
net_prob_tec = model.get_normalized_probs(net_output, log_probs=False)

net_prob = net_prob.view(-1, net_prob.size(-1))
net_prob_tec = net_prob_tec.view(-1, net_prob_tec.size(-1))

p, q = torch.split(net_prob, net_prob.size(0) // 2, dim=0)
p_tec, q_tec = torch.split(net_prob_tec, net_prob_tec.size(0) // 2, dim=0)

p_loss = F.kl_div(p, q_tec, reduction="none", log_target=True)
q_loss = F.kl_div(q, p_tec, reduction="none", log_target=True)
p_loss = torch.nn.functional.kl_div(p, q_tec, reduction="none")
q_loss = torch.nn.functional.kl_div(q, p_tec, reduction="none")

if pad_mask is not None:
p_loss.masked_fill_(pad_mask, 0.0)
@@ -182,4 +174,4 @@ def compute_kl_loss(model, net_output, pad_mask=None, reduce=True):
q_loss = q_loss.sum()

loss = (p_loss + q_loss) / 2
return loss
return loss
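
Note: compute_kl_loss now computes the kl_div target as plain probabilities (log_probs=False) and drops log_target=True, matching the default convention of torch.nn.functional.kl_div: the first argument must be log-probabilities and the second, with log_target=False, plain probabilities. A standalone illustration of that convention on toy tensors (not code from the repository):

import torch
import torch.nn.functional as F

logits_a = torch.randn(2, 5)
logits_b = torch.randn(2, 5)

p_log = F.log_softmax(logits_a, dim=-1)  # kl_div input: log-probabilities
q = F.softmax(logits_b, dim=-1)          # kl_div target: probabilities

# Pointwise q * (log q - p_log); summing over the vocab dim gives KL(q || p).
kl = F.kl_div(p_log, q, reduction="none").sum(dim=-1)
print(kl)
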
16 changes: 5 additions & 11 deletions fairseq/models/transformer/transformer_config.py
@@ -246,21 +246,15 @@ class TransformerConfig(FairseqDataclass):
"help": "LoRA arguments (rank, alpha, dropout, target_modules, rank_scaled)"
},
)
rope_args: Optional[str] = field(
default=None,
metadata={
"help": "RoPE arguments (max_position_embeddings, base, type ['vanilla', 'linear', 'dynamic'], scale)"
},
)
yarn_args: Optional[str] = field(
default=None,
use_rope: Optional[bool] = field(
default=False,
metadata={
"help": "YaRN arguments (max_position_embeddings, base, type ['vanilla', 'dynamic'], scale, original_max_position_embeddings, extrapolation_factor, attn_factor, beta_fast, beta_slow, finetuned)"
"help": "use RoPE based attention."
},
)
alibi_args: Optional[str] = field(
use_alibi: Optional[str] = field(
default=None,
metadata={"help": "ALiBi arguments (alibi_asymmetrical)"},
metadata={"help": "use ALiBi positional encoding (symmetrical/asymmetrical)"},
)

# args for Training with Quantization Noise for Extreme Model Compression ({Fan*, Stock*} et al., 2020)
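
Note: the JSON-string arguments rope_args, yarn_args, and alibi_args are collapsed into two plain fields: use_rope (a boolean) and use_alibi (a string selecting the symmetrical or asymmetrical bias, with None leaving ALiBi disabled). A hypothetical programmatic sketch of the new fields, assuming TransformerConfig is importable as in upstream fairseq and that positional embeddings must be turned off when ALiBi is enabled (the encoder and decoder assert embed_positions is None):

from fairseq.models.transformer import TransformerConfig

cfg = TransformerConfig()
cfg.use_rope = False
cfg.use_alibi = "symmetrical"              # or "asymmetrical"; None keeps ALiBi off
cfg.no_token_positional_embeddings = True  # ALiBi replaces positional embeddings
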
50 changes: 35 additions & 15 deletions fairseq/models/transformer/transformer_decoder.py
@@ -158,18 +158,21 @@ def __init__(
else:
self.project_out_activation_fn = None

self.adaptive_softmax = None
self.output_projection = output_projection
if self.output_projection is None:
self.build_output_projection(cfg, dictionary, embed_tokens)

if getattr(cfg, "alibi_args", False) and self.embed_positions is None:
if cfg.use_alibi is not None:
assert (
self.embed_positions is None
), "ALiBi shouldn't be used with positional embedding"
self.alibi = utils.alibi(
cfg.decoder.attention_heads, self.max_target_positions
)
else:
self.alibi = None

self.adaptive_softmax = None
self.output_projection = output_projection
if self.output_projection is None:
self.build_output_projection(cfg, dictionary, embed_tokens)

def normalization(self, dim, rms=False):
return LayerNorm(dim, export=self.cfg.export) if not rms else RMSNorm(dim)

@@ -363,11 +366,14 @@ def extract_features_scriptable(

x = self.dropout_module(x)

        # We move the mask construction here because it's slightly more efficient.
if incremental_state is None and not full_context_alignment:
self_attn_mask = self.buffered_future_mask(x)
else:
self_attn_mask = None
self_attn_mask = (
self.buffered_future_mask(x)
if incremental_state is None and not full_context_alignment
else None
)

if self_attn_mask is None and incremental_state is not None:
self_attn_mask = self._bias_attn_mask(x, incremental_state)

# B x T x C -> T x B x C
x = x.transpose(0, 1)
@@ -428,6 +434,19 @@ def max_positions(self):
return self.max_target_positions
return min(self.max_target_positions, self.embed_positions.max_positions)

def _bias_attn_mask(self, x, incremental_state):
if incremental_state is None:
return None

saved_state = self.layers[0].self_attn._get_input_buffer(incremental_state)

if saved_state is None or ("prev_key" not in saved_state):
return None

src_len = saved_state["prev_key"].shape[2]

return self.alibi[:, src_len, : src_len + 1].unsqueeze(1).to(x.device)

def buffered_future_mask(self, tensor):
B, T, _ = tensor.size()
# self._future_mask.device != tensor.device is not working in TorchScript. This is a workaround.
@@ -445,10 +464,11 @@ def buffered_future_mask(self, tensor):
if self.alibi is not None:
self._future_mask = self._future_mask.unsqueeze(0) + self.alibi
self._future_mask = self._future_mask.to(tensor)
if self.alibi is not None:
return self._future_mask[:, :T, :T].repeat(B, 1, 1)
else:
return self._future_mask[:T, :T]
return (
self._future_mask[:, :T, :T].repeat(B, 1, 1)
if self.alibi is not None
else self._future_mask[:T, :T]
)

def upgrade_state_dict_named(self, state_dict, name):
"""Upgrade a (possibly old) state dict for new versions of fairseq."""
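
Note: the decoder now folds the ALiBi bias directly into the attention mask: buffered_future_mask adds the per-head bias to the causal -inf mask and repeats the result across the batch so every sequence carries the same per-head bias, while _bias_attn_mask slices a single row of the bias for the current position during incremental decoding. A rough standalone sketch of that combination, assuming the standard ALiBi slope schedule (the repository's utils.alibi may differ, e.g. for the asymmetrical variant):

import torch

def alibi_bias(num_heads, max_len):
    # Geometric per-head slopes m_h = 2^(-8h / H); bias(i, j) = -m_h * |i - j|.
    slopes = torch.tensor([2.0 ** (-8.0 * (h + 1) / num_heads) for h in range(num_heads)])
    rel = torch.arange(max_len).view(1, -1) - torch.arange(max_len).view(-1, 1)
    return -slopes.view(-1, 1, 1) * rel.abs()                        # (H, T, T)

H, T, B = 4, 8, 2
causal = torch.triu(torch.full((T, T), float("-inf")), diagonal=1)   # -inf above the diagonal
mask = causal.unsqueeze(0) + alibi_bias(H, T)                        # (H, T, T)
attn_mask = mask[:, :T, :T].repeat(B, 1, 1)                          # (B * H, T, T)
print(attn_mask.shape)
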
20 changes: 11 additions & 9 deletions fairseq/models/transformer/transformer_encoder.py
@@ -119,12 +119,14 @@ def __init__(self, cfg, dictionary, embed_tokens, return_fc=False):
else None
)

if getattr(cfg, "alibi_args", False) and self.embed_positions is None:
alibi_args = json.loads(cfg.alibi_args)
if cfg.use_alibi is not None:
assert (
self.embed_positions is None
), "ALiBi shouldn't be used with positional embedding"
self.alibi = utils.alibi(
cfg.encoder.attention_heads,
self.max_source_positions,
alibi_args.get("alibi_asymmetrical", False),
asymmetrical=cfg.use_alibi,
)
else:
self.alibi = None
@@ -247,13 +249,11 @@ def forward_scriptable(
)

if self.alibi is not None:
shape = x.size()
B, T, _ = x.size()
self.alibi = self.alibi.to(x)
self_attn_mask = self.alibi[:, : shape[1], : shape[1]].repeat(
shape[0], 1, 1
)
self_attn_mask = self.alibi[:, :T, :T].repeat(B, 1, 1)
else:
pass
self_attn_mask = None

# B x T x C -> T x B x C
x = x.transpose(0, 1)
Expand All @@ -267,7 +267,9 @@ def forward_scriptable(
# encoder layers
for idx, layer in enumerate(self.layers):
lr = layer(
x, encoder_padding_mask=encoder_padding_mask if has_pads else None
x,
encoder_padding_mask=encoder_padding_mask if has_pads else None,
self_attn_mask=self_attn_mask,
)

if isinstance(lr, tuple) and len(lr) == 2:
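
Note: the encoder builds a bidirectional ALiBi mask the same way and now passes it to every layer through self_attn_mask. Because ALiBi is a purely additive bias on the attention logits, nothing inside the attention computation needs to change: the mask is simply added to the scores before the softmax. A self-contained illustration of that additive convention using PyTorch's built-in scaled dot-product attention (hypothetical shapes, PyTorch >= 2.0 assumed; the repository itself uses fairseq's MultiheadAttention):

import torch
import torch.nn.functional as F

B, H, T, D = 2, 4, 8, 16
q, k, v = (torch.randn(B, H, T, D) for _ in range(3))

# Bidirectional ALiBi-style bias shared across the batch: (1, H, T, T).
slopes = torch.tensor([2.0 ** (-8.0 * (h + 1) / H) for h in range(H)])
rel = torch.arange(T).view(1, -1) - torch.arange(T).view(-1, 1)
bias = (-slopes.view(H, 1, 1) * rel.abs()).unsqueeze(0)

# A float attn_mask is added to q @ k^T / sqrt(D) before the softmax.
out = F.scaled_dot_product_attention(q, k, v, attn_mask=bias)
print(out.shape)  # torch.Size([2, 4, 8, 16])
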
(Diffs for the remaining 3 changed files are not shown.)