Support key padding masks for Perceiver AR #25

Merged · 1 commit · Dec 23, 2022
8 changes: 4 additions & 4 deletions docs/training-examples.md
@@ -172,10 +172,10 @@ UTF-8 bytes of the input.
python examples/training/clm/train.py
```

For better generalization to shorter sequences I found random sequence truncation helpful which can be enabled with
`--model.random_truncation=true`. The minimum sequence length can be configured with `--model.random_min_seq_lem=m`.
Random sequence truncation randomly truncates sequences in a batch to length `randint(m, n+1)` where `m < n` and `n`
is the configured `max_seq_len`.
For better generalization to shorter sequences I found random sequence truncation at training time helpful. This can be
enabled with `--data.random_train_truncation=true`. The minimum sequence length can be configured with `--data.random_min_seq_len=m`.
Random sequence truncation randomly truncates sequences in a batch to length `randint(m, n+1)` where `m < n` and `n` is
the configured `max_seq_len`. Sequences are truncated from the right and padded to the left (`--data.padding_side=left`).
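As a minimal illustration (not part of the repository), the truncation rule described above can be written as:

```python
import torch

# Illustrative only: a training sequence of length n = max_seq_len is shortened
# to a random length in [m, n], with m = random_min_seq_len.
m, n = 16, 4096
new_len = torch.randint(m, n + 1, size=(1,)).item()
# Tokens are dropped from the right; the collator then left-pads the shortened
# sequence back to max_seq_len.
```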

With option `--model.validation_sample_record=-1` a sequence is randomly picked from the validation set and used as a
prompt for sequence generation during validation. The prompt and the generated sequence are logged to Tensorboard. You
1 change: 1 addition & 0 deletions examples/training/clm/prep.sh
@@ -1,5 +1,6 @@
python -m perceiver.scripts.text.preproc wikitext \
--tokenizer=deepmind/language-perceiver \
--random_train_shift=true \
--add_special_tokens=false \
--max_seq_len=4096 \
--task=clm
2 changes: 2 additions & 0 deletions examples/training/clm/train.py
@@ -26,6 +26,8 @@ def configure_optimizers(self):
max_seq_len=4096,
batch_size=24,
task=Task.clm,
padding_side="left",
random_train_shift=True,
)

config = CausalLanguageModelConfig(
4 changes: 2 additions & 2 deletions examples/training/clm/train.sh
@@ -2,14 +2,14 @@ python -m perceiver.scripts.text.clm fit \
--model.num_latents=512 \
--model.cross_attention_dropout=0.5 \
--model.post_attention_dropout=0.0 \
--model.random_truncation=false \
--model.random_min_seq_len=16 \
--data=WikiTextDataModule \
--data.tokenizer=deepmind/language-perceiver \
--data.add_special_tokens=false \
--data.max_seq_len=4096 \
--data.task=clm \
--data.batch_size=24 \
--data.padding_side=left \
--data.random_train_shift=true \
--optimizer=Adam \
--optimizer.lr=2e-4 \
--lr_scheduler.warmup_steps=200 \
24 changes: 16 additions & 8 deletions perceiver/data/text/collator.py
@@ -23,7 +23,7 @@ def __call__(self, examples):


class DefaultCollator(Collator):
label_keys = ["label", "labels", "label_ids"]
label_keys = ["label", "labels"]

def __init__(self, tokenizer: PreTrainedTokenizerFast, max_seq_len: Optional[int] = None):
self.collator = DefaultDataCollator()
@@ -42,20 +42,28 @@ def _prepare(self, example, max_length):
# must be preserved though. Setting add_special_tokens=true doesn't
# work either because this would duplicate (some) special tokens
# already contained in the input sequence.
prepared = self.tokenizer.prepare_for_model(
example["input_ids"],
prepared = self._prepare_sequence(example["input_ids"], max_length)

if "label_ids" in example:
prepared_label_ids = self._prepare_sequence(example["label_ids"], max_length)
prepared["label_ids"] = prepared_label_ids["input_ids"]

for label_key in self.label_keys:
if label_key in example:
prepared[label_key] = example[label_key]

return prepared

def _prepare_sequence(self, sequence, max_length):
return self.tokenizer.prepare_for_model(
sequence,
add_special_tokens=False,
return_token_type_ids=False,
padding=PaddingStrategy.MAX_LENGTH,
max_length=max_length,
truncation=True,
)

for label_key in self.label_keys:
if label_key in example:
prepared[label_key] = example[label_key]
return prepared


class WordMaskingCollator(Collator):
def __init__(self, tokenizer: PreTrainedTokenizerFast, mask_prob: float = 0.15):
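For reference, a hedged sketch of what `_prepare_sequence` produces for a single example, assuming the `deepmind/language-perceiver` tokenizer with left padding (the token ids and `max_length` are illustrative only):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("deepmind/language-perceiver")
tokenizer.padding_side = "left"

input_ids = [110, 104, 111, 111, 114]  # pre-tokenized ids, illustrative only
prepared = tokenizer.prepare_for_model(
    input_ids,
    add_special_tokens=False,
    return_token_type_ids=False,
    padding="max_length",  # string form of PaddingStrategy.MAX_LENGTH
    max_length=8,
    truncation=True,
)
# prepared["input_ids"] is left-padded to length 8 and prepared["attention_mask"]
# is 0 at the padded positions. Because label_ids are run through the same call,
# they remain aligned with input_ids.
```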
43 changes: 43 additions & 0 deletions perceiver/data/text/common.py
@@ -58,8 +58,12 @@ def __init__(
mask_words: bool = True,
static_masking: bool = False,
add_special_tokens: bool = False,
padding_side: Optional[str] = None,
random_train_shift: bool = False,
random_valid_shift: bool = False,
random_train_truncation: bool = False,
random_valid_truncation: bool = False,
random_min_seq_len: int = 16,
preproc_batch_size: int = 1000,
preproc_workers: Optional[int] = None,
batch_size: int = 64,
@@ -78,6 +82,14 @@ def __init__(
:param static_masking: Whether to mask at preprocessing time (static) or at data loading time (dynamic).
Ignored if task is not `Task.mlm`.
:param add_special_tokens: Whether to add special tokens to tokenized text.
:param padding_side: If `None`, uses the pre-configured `padding_side` of the tokenizer. Can be overridden by
setting to "left" or "right".
:param random_train_truncation: Randomly truncates sequences in the training set to length
`randint(random_min_seq_len, max_seq_len + 1)`.
:param random_valid_truncation: Randomly truncates sequences in the validation set to length
`randint(random_min_seq_len, max_seq_len + 1)`.
:param random_min_seq_len: Minimum sequence length when using `random_train_truncation` or
`random_valid_truncation`.
:param preproc_batch_size: Preprocessing batch size.
:param preproc_workers: Number of preprocessing processes. If not defined, defaults to `num_workers`.
:param batch_size: Batch size of loaded data.
@@ -91,6 +103,10 @@ def __init__(
raise ValueError("static_masking=true is only supported for mask_words=true")

self.tokenizer = AutoTokenizer.from_pretrained(self.hparams.tokenizer, verbose=False)

if self.hparams.padding_side is not None:
self.tokenizer.padding_side = self.hparams.padding_side

# PerceiverTokenizer needs special support for generating word_ids as it is not a fast tokenizer
self.perceiver_tokenizer_configured = self.hparams.tokenizer == "deepmind/language-perceiver"
if self.perceiver_tokenizer_configured:
@@ -162,6 +178,11 @@ def setup(self, stage=None):
if self.hparams.random_valid_shift:
self.ds_valid = RandomShiftDataset(self.ds_valid)

if self.hparams.random_train_truncation:
self.ds_train = RandomTruncationDataset(self.ds_train, self.hparams.random_min_seq_len)
if self.hparams.random_valid_truncation:
self.ds_valid = RandomTruncationDataset(self.ds_valid, self.hparams.random_min_seq_len)

if self.hparams.task == Task.clm:
self.ds_train = CLMDataset(self.ds_train)
self.ds_valid = CLMDataset(self.ds_valid)
@@ -333,6 +354,28 @@ def __len__(self):
return len(self.dataset) - 1


class RandomTruncationDataset(torch.utils.data.Dataset):
def __init__(self, dataset, random_min_seq_len: int):
self.dataset = dataset
self.random_min_seq_len = random_min_seq_len

def __getitem__(self, idx):
example = self.dataset[idx]
example_seq_len = len(example["input_ids"])

drop_max = example_seq_len - self.random_min_seq_len
if drop_max > 0:
drop = torch.randint(drop_max + 1, size=(1,))
if drop > 0:
for key in example.keys():
example[key] = example[key][:-drop]

return example

def __len__(self):
return len(self.dataset)


class CLMDataset(torch.utils.data.Dataset):
def __init__(self, dataset):
self.dataset = dataset
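A minimal usage sketch of `RandomTruncationDataset` on a toy in-memory dataset (the keys and lengths are made up for illustration):

```python
from perceiver.data.text.common import RandomTruncationDataset

toy = [{"input_ids": list(range(100)), "word_ids": list(range(100))}]
truncated = RandomTruncationDataset(toy, random_min_seq_len=16)

example = truncated[0]
# All keys of an example are truncated from the right by the same random amount,
# so the resulting length lies in [random_min_seq_len, original length].
assert 16 <= len(example["input_ids"]) <= 100
assert len(example["input_ids"]) == len(example["word_ids"])
```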
20 changes: 10 additions & 10 deletions perceiver/model/core/adapter.py
@@ -2,11 +2,11 @@
import torch.nn as nn
from einops import rearrange

from perceiver.model.core.position import FrequencyPositionEncoding
from perceiver.model.core.position import FrequencyPositionEncoding, positions


class InputAdapter(nn.Module):
def __init__(self, num_input_channels: int):
def __init__(self, num_input_channels: int, *args, **kwargs):
"""Transforms and position-encodes task-specific input to generic encoder input.

:param num_input_channels: Number of channels of the generic encoder input produced by this adapter.
@@ -20,16 +20,16 @@ def num_input_channels(self):


class RotarySupport(InputAdapter):
def __init__(self, encoded_channels_per_head: int, *args, **kwargs):
"""An input adapter mixin that additionally generates constructor arguments for
`RotaryPositionEmbedding`."""
def __init__(self, rotated_channels_per_head: int, *args, **kwargs):
"""An input adapter mixin that additionally generates a frequency position encoding for input sequence
`x`."""
super().__init__(*args, **kwargs)
self.frq_pos_encoding = FrequencyPositionEncoding(dim=encoded_channels_per_head)
self.frq_pos_encoding = FrequencyPositionEncoding(dim=rotated_channels_per_head)

def forward(self, x):
"""Transforms and position-encodes sequence `x` and additionally returns a frequency position encoding of
`x` required to create a `RotaryPositionEmbedding` instance."""
return super().forward(x), self.frq_pos_encoding(x.shape[1])
def forward(self, x, abs_pos=None):
if abs_pos is None:
abs_pos = positions(*x.shape, device=x.device)
return super().forward(x, abs_pos), self.frq_pos_encoding(abs_pos)


class OutputAdapter(nn.Module):
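To show how the reworked mixin composes with a concrete adapter, here is a hypothetical toy composition (`ToyEmbeddingAdapter` and `ToyRotaryAdapter` are not part of the repository; they only illustrate the cooperative `forward(x, abs_pos)` contract):

```python
import torch
import torch.nn as nn

from perceiver.model.core.adapter import InputAdapter, RotarySupport


class ToyEmbeddingAdapter(InputAdapter):
    # Hypothetical concrete adapter: embeds token ids, adds no position info itself.
    def __init__(self, vocab_size: int, num_input_channels: int):
        super().__init__(num_input_channels=num_input_channels)
        self.embed = nn.Embedding(vocab_size, num_input_channels)

    def forward(self, x, abs_pos=None):
        return self.embed(x)


class ToyRotaryAdapter(RotarySupport, ToyEmbeddingAdapter):
    # RotarySupport.forward computes abs_pos (unless supplied by the caller) and
    # the frequency position encoding, then delegates to ToyEmbeddingAdapter.
    pass


adapter = ToyRotaryAdapter(rotated_channels_per_head=32, vocab_size=262, num_input_channels=64)
x = torch.randint(0, 262, (2, 16))  # (batch, seq) token ids
x_enc, frq_pos_enc = adapter(x)     # shapes (2, 16, 64) and (2, 16, 32)
```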
48 changes: 33 additions & 15 deletions perceiver/model/core/modules.py
@@ -2,7 +2,7 @@

import torch
import torch.nn as nn
from einops import rearrange, repeat
from einops import rearrange
from fairscale.nn import checkpoint_wrapper

from perceiver.model.core.adapter import (
@@ -12,7 +12,7 @@
RotarySupport,
TrainableQueryProvider,
)
from perceiver.model.core.position import RotaryPositionEmbedding
from perceiver.model.core.position import positions, RotaryPositionEmbedding
from perceiver.model.core.utils import init_parameters, Residual, Sequential


@@ -636,14 +636,25 @@ def _init_parameters(self, init_scale: float):
with torch.no_grad():
init_parameters(self, init_scale)

def forward(self, x):
x, frq_pos_enc = self.input_adapter(x)
def forward(self, x, pad_mask=None):
if pad_mask is None:
shift = None
else:
# caller must ensure that x is left-padded
shift = pad_mask.sum(dim=1, keepdim=True)

frq_pos_enc_q = frq_pos_enc
frq_pos_enc_k = frq_pos_enc
# frq_pos_enc shape is (b, n, f)
x, frq_pos_enc = self.input_adapter(x, abs_pos=positions(*x.shape, shift=shift, device=x.device))

x_latent = x[:, -self.num_latents :]
x_prefix = x[:, : -self.num_latents]

frq_pos_enc_latent = frq_pos_enc[:, -self.num_latents :]
frq_pos_enc_prefix = frq_pos_enc[:, : -self.num_latents]

pad_mask_latent = None if pad_mask is None else pad_mask[:, -self.num_latents :]
pad_mask_prefix = None if pad_mask is None else pad_mask[:, : -self.num_latents]

n_prefix = x_prefix.shape[1]

b, n, _ = x.shape
@@ -657,25 +668,32 @@ def forward(self, x):
keep_indices = rand.topk(keep, dim=-1).indices
# mask of positions in prefix sequence to keep
keep_mask = torch.zeros_like(rand, dtype=torch.bool).scatter_(dim=1, index=keep_indices, value=1)

# drop positions in prefix sequence according to prefix_dropout
x_prefix = rearrange(x_prefix[keep_mask], "(b n) c -> b n c", b=b)
# drop positions in prefix frequency position encoding
frq_pos_enc_prefix = rearrange(frq_pos_enc_prefix[keep_mask], "(b n) f -> b n f", b=b)
# drop positions in prefix pad mask
pad_mask_prefix = None if pad_mask is None else rearrange(pad_mask_prefix[keep_mask], "(b n) -> b n", b=b)

frq_pos_enc_k = repeat(frq_pos_enc_k, "... -> b ...", b=b)
frq_pos_enc_k_latent = frq_pos_enc_k[:, n_prefix:]
frq_pos_enc_prefix = frq_pos_enc_k[:, :n_prefix]
frq_pos_enc_prefix = rearrange(frq_pos_enc_prefix[keep_mask], "(b n) c -> b n c", b=b)
frq_pos_enc_q = frq_pos_enc_latent
frq_pos_enc_k = torch.cat([frq_pos_enc_prefix, frq_pos_enc_latent], dim=1)

frq_pos_enc_k = torch.cat((frq_pos_enc_prefix, frq_pos_enc_k_latent), dim=1)
frq_pos_enc_k = rearrange(frq_pos_enc_k, "b n c -> b 1 n c")
pad_mask_cross_attend = None if pad_mask is None else torch.cat([pad_mask_prefix, pad_mask_latent], dim=1)
pad_mask_self_attend = None if pad_mask is None else pad_mask_latent

x_latent = self.cross_attention(
x_latent,
x_kv_prefix=x_prefix,
pad_mask=pad_mask_cross_attend,
rot_pos_emb_q=RotaryPositionEmbedding(frq_pos_enc_q, right_align=True),
rot_pos_emb_k=RotaryPositionEmbedding(frq_pos_enc_k, right_align=True),
)

x_latent = self.self_attention(x_latent, rot_pos_emb=RotaryPositionEmbedding(frq_pos_enc, right_align=True))
x_logits = self.output_adapter(x_latent)
x_latent = self.self_attention(
x_latent,
pad_mask=pad_mask_self_attend,
rot_pos_emb=RotaryPositionEmbedding(frq_pos_enc_latent, right_align=True),
)

return x_logits
return self.output_adapter(x_latent)
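To make the shift computation above concrete, a small sketch using only `positions` from this diff (the convention that `pad_mask` is truthy at padded positions is inferred from `shift = pad_mask.sum(...)` and the left-padding requirement):

```python
import torch

from perceiver.model.core.position import positions

# Two left-padded sequences: 3 leading pads in the first, none in the second.
pad_mask = torch.tensor([[True, True, True, False, False],
                         [False, False, False, False, False]])
shift = pad_mask.sum(dim=1, keepdim=True)  # tensor([[3], [0]])

abs_pos = positions(*pad_mask.shape, shift=shift, device=pad_mask.device)
# tensor([[0, 0, 0, 0, 1],
#         [0, 1, 2, 3, 4]])
# Padded slots are clamped to position 0 and the first real token of each
# sequence starts at position 0, so the rotary encoding of real tokens does not
# depend on the amount of left padding.
```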
29 changes: 18 additions & 11 deletions perceiver/model/core/position.py
@@ -6,13 +6,24 @@
from einops import rearrange, repeat


def positions(b, n, shift: Optional[torch.Tensor] = None, device: Optional[torch.device] = None):
pos = repeat(torch.arange(n, device=device), "n -> b n", b=b)

if shift is not None:
if shift.shape != (b, 1):
raise ValueError(f"shift must have shape {(1, b)} but has shape {shift.shape}")
pos = pos - shift

return torch.clamp(pos, min=0)


class RotaryPositionEmbedding:
# Specified in https://arxiv.org/abs/2104.09864
# Modified from https://github.com/lucidrains/rotary-embedding-torch
def __init__(self, frq_pos_enc: torch.Tensor, right_align: bool = False):
# frq_pos_enc shape is either (n, c) or (b, 1, n, c).
# frq_pos_enc shape is (b, n, c).
# frq_pos_enc is broadcast to (b, h, n, c).
self.frq_pos_enc = frq_pos_enc
self.frq_pos_enc = rearrange(frq_pos_enc, "b n c -> b 1 n c")
self.rotate_dim = frq_pos_enc.shape[-1]
self.right_align = right_align

@@ -53,15 +64,11 @@ def __init__(self, dim):
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
self.register_buffer("inv_freq", inv_freq)

def forward(self, seq_len):
# positions [0, 1, ..., seq_len -1]
pos = torch.arange(seq_len, dtype=self.inv_freq.dtype, device=self.inv_freq.device)
# outer product of positions and inverse frequencies
pos_enc = torch.einsum("p, f -> p f", pos, self.inv_freq)
# for a single position p: [pf_1, pf_2, ..., pf_dim/2] -> [pf_1, pf1, pf_2, pf_2..., pf_dim/2, pf_dim/2]
pos_enc = repeat(pos_enc, "... pf -> ... (pf r)", r=2)
# pos_enc.shape == (seq_len, dim)
return pos_enc
def forward(self, abs_pos):
# outer product of absolute positions and inverse frequencies
pos_enc = torch.einsum("b n, f -> b n f", abs_pos, self.inv_freq)
# frequency position encodings (per example in batch b)
return repeat(pos_enc, "... pf -> ... (pf r)", r=2)


class FourierPositionEncoding(nn.Module):
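Finally, a small shape check for the updated `FrequencyPositionEncoding` and `RotaryPositionEmbedding` interfaces (dimensions are illustrative):

```python
import torch

from perceiver.model.core.position import FrequencyPositionEncoding, RotaryPositionEmbedding, positions

b, n, channels_per_head = 2, 8, 32
frq_enc = FrequencyPositionEncoding(dim=channels_per_head)

abs_pos = positions(b, n)               # (b, n) absolute positions per example
frq_pos_enc = frq_enc(abs_pos.float())  # (b, n, channels_per_head)

rot_emb = RotaryPositionEmbedding(frq_pos_enc, right_align=True)
# The constructor now rearranges the (b, n, c) encoding to (b, 1, n, c) so it
# broadcasts over attention heads.
```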