Merged
15 changes: 11 additions & 4 deletions docs/source/en/perplexity.md
@@ -107,7 +107,8 @@ max_length = model.config.n_positions
 stride = 512
 seq_len = encodings.input_ids.size(1)
 
-nlls = []
+nll_sum = 0.0
+n_tokens = 0
 prev_end_loc = 0
 for begin_loc in tqdm(range(0, seq_len, stride)):
     end_loc = min(begin_loc + max_length, seq_len)
@@ -124,13 +125,19 @@ for begin_loc in tqdm(range(0, seq_len, stride)):
         # to the left by 1.
         neg_log_likelihood = outputs.loss
 
-    nlls.append(neg_log_likelihood)
+    # Accumulate the total negative log-likelihood and the total number of tokens
+    num_valid_tokens = (target_ids != -100).sum().item()  # number of valid tokens in target_ids
+    batch_size = target_ids.size(0)
+    num_loss_tokens = num_valid_tokens - batch_size  # subtract batch_size due to internal label shift
+    nll_sum += neg_log_likelihood * num_loss_tokens
+    n_tokens += num_loss_tokens
 
     prev_end_loc = end_loc
     if end_loc == seq_len:
         break
 
-ppl = torch.exp(torch.stack(nlls).mean())
+avg_nll = nll_sum / n_tokens  # average negative log-likelihood per token
+ppl = torch.exp(avg_nll)
 ```
 
 Running this with the stride length equal to the max input length is equivalent to the suboptimal, non-sliding-window
@@ -139,5 +146,5 @@ and the better the reported perplexity will typically be.
 
 When we run the above with `stride = 1024`, i.e. no overlap, the resulting PPL is `19.44`, which is about the same
 as the `19.93` reported in the GPT-2 paper. By using `stride = 512` and thereby employing our striding window
-strategy, this jumps down to `16.45`. This is not only a more favorable score, but is calculated in a way that is
+strategy, this jumps down to `16.44`. This is not only a more favorable score, but is calculated in a way that is
 closer to the true autoregressive decomposition of a sequence likelihood.
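
The diff above replaces a plain mean of per-window losses with a mean weighted by the number of tokens each window actually scores. The following is a minimal, dependency-free sketch of why the two can differ when the final window is shorter than the others. The helper names (`count_scored_tokens`, `ppl_mean_of_window_means`, `ppl_token_weighted`) and all numbers are hypothetical and chosen only for illustration; they are not part of the changed file. The counting helper assumes the loss drops the first label of each row when it shifts labels, as the comment in the diff describes.

```python
import math

IGNORE_INDEX = -100  # label value assumed to be skipped by the loss


def count_scored_tokens(target_rows):
    """Count labels that still contribute once the first label of each row is dropped."""
    return sum(
        1 for row in target_rows for label in row[1:] if label != IGNORE_INDEX
    )


def ppl_mean_of_window_means(window_losses):
    """Pre-change style: average the per-window mean losses, ignoring window size."""
    return math.exp(sum(window_losses) / len(window_losses))


def ppl_token_weighted(window_losses, window_token_counts):
    """Post-change style: weight each window's mean loss by its scored-token count."""
    nll_sum = sum(loss * n for loss, n in zip(window_losses, window_token_counts))
    n_tokens = sum(window_token_counts)
    return math.exp(nll_sum / n_tokens)


# Made-up per-window mean losses and scored-token counts; the last window is short.
window_losses = [3.0, 3.0, 1.0]
window_token_counts = [512, 512, 16]

old_style = ppl_mean_of_window_means(window_losses)
new_style = ppl_token_weighted(window_losses, window_token_counts)
print(f"mean of window means: {old_style:.2f}")
print(f"token-weighted mean:  {new_style:.2f}")
```

With equal token counts in every window the two functions agree, so the difference only appears when window sizes vary, typically on the last iteration. Note that `count_scored_tokens` matches `num_valid_tokens - batch_size` only when the first label of every row is unmasked; if that label is already `-100`, dropping it removes no scored token.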