
Commit 49e1d92

Parameterize generate method with top_k directly instead of `threshold`.

- closes #38
krasserm committed Feb 24, 2023
1 parent f807b2e commit 49e1d92
Showing 3 changed files with 14 additions and 15 deletions.
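
For context, a hedged sketch of how a call to `generate` changes with this commit; the model, prompt and argument values below are illustrative assumptions, not part of the diff.

# before: the candidate set was derived indirectly from a vocabulary fraction
# generated = model.generate(prompt, num_tokens=64, threshold=0.9, temperature=1.0)

# after: the number of candidate tokens is passed directly
generated = model.generate(prompt, num_tokens=64, top_k=5, temperature=1.0)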
1 change: 0 additions & 1 deletion perceiver/model/core/config.py
@@ -64,7 +64,6 @@ class PerceiverARConfig:
cross_attention_widening_factor: int = 4
cross_attention_dropout: float = 0.5
post_attention_dropout: float = 0.0
init_scale: float = 0.02
activation_checkpointing: bool = False
activation_offloading: bool = False

8 changes: 1 addition & 7 deletions perceiver/model/core/modules.py
@@ -597,11 +597,10 @@ def __init__(
cross_attention_widening_factor: int = 4,
cross_attention_dropout: float = 0.5,
post_attention_dropout: float = 0.0,
init_scale: float = 0.02,
activation_checkpointing: bool = False,
activation_offloading: bool = False,
):
"""Experimental implementation of Perceiver AR (https://arxiv.org/abs/2202.07765).
"""Implementation of Perceiver AR (https://arxiv.org/abs/2202.07765).
:param input_adapter: Transforms an input sequence to generic Perceiver AR input. An input adapter may choose
to add (absolute) position information to transformed inputs while `PerceiverAR` additionally computes a
@@ -614,7 +613,6 @@ def __init__(
:param cross_attention_dropout: Probability of dropping positions in the prefix sequence.
:param post_attention_dropout: Probability of dropping cross- and self-attention scores (same as `dropout`
in Perceiver IO encoder and decoder).
:param init_scale: Standard deviation for random normal initialization of parameters.
:param activation_checkpointing: If True, implements an activation checkpoint for each self-attention
layer and cross-attention layer.
:param activation_offloading: If True, offloads checkpointed activations to CPU.
@@ -659,10 +657,6 @@ def self_attn():
self.cross_attention = cross_attn()
self.self_attention = self_attn()

def _init_parameters(self, init_scale: float):
with torch.no_grad():
init_parameters(self, init_scale)

def forward(self, x, prefix_len, pad_mask=None):
b, n = x.shape

20 changes: 13 additions & 7 deletions perceiver/model/text/clm.py
@@ -11,6 +11,7 @@
from tqdm import tqdm

from perceiver.model.core import OutputAdapter, PerceiverAR, PerceiverARConfig, RotarySupport
from perceiver.model.core.utils import init_parameters
from perceiver.model.text import common


@@ -21,6 +22,7 @@ class CausalLanguageModelConfig(PerceiverARConfig):
max_latents: int = 512
num_channels: int = 512
output_norm: bool = False
init_scale: float = 0.02

@classmethod
def create(cls, **kwargs):
@@ -62,6 +64,10 @@ def __init__(self, config: CausalLanguageModelConfig):
self.output_adapter = common.TiedTextOutputAdapter(vocab_size=config.vocab_size)
self._init_parameters(config.init_scale)

def _init_parameters(self, init_scale: float):
with torch.no_grad():
init_parameters(self, init_scale)

@property
def max_seq_len(self):
return self.input_adapter.max_seq_len
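
The other half of the change moves parameter initialization from `PerceiverAR` into the language model. A minimal sketch, assuming the containing class is `CausalLanguageModel` and using illustrative field values:

# `init_scale` is now a field of CausalLanguageModelConfig and is applied by
# CausalLanguageModel._init_parameters once the model is built.
config = CausalLanguageModelConfig.create(vocab_size=262, init_scale=0.02)  # vocab_size assumed
model = CausalLanguageModel(config)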
@@ -92,17 +98,19 @@ def generate(
pad_mask: Optional[torch.Tensor] = None,
num_tokens: int = 512,
num_latents: int = 1,
threshold: float = 0.9,
top_k: int = 5,
temperature: float = 1.0,
pbar: bool = True,
):
"""Generate sequence from `prompt` via top-k sampling (with k determined by `threshold`) at given
`temperature`.
"""Generate sequence from `prompt` via `top-k` sampling at given `temperature`.
:param prompt: Prompt of shape (B, N). If sequences have different length they must be left-padded.
:param pad_mask: Prompt pad mask of shape (B, N). Must be supplied if prompt contains pad tokens.
:param num_tokens: Number of tokens to generate.
:param num_latents: Initial number of latent positions.
:param top_k: Number of most likely tokens to sample from.
:param temperature: "Controls the entropy of next token probabilities."
:param pbar: If `True`, uses a progress bar during generation.
"""

n_init = prompt.shape[1]
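
Since the docstring above requires left-padded prompts for batched generation, here is a hedged usage sketch; the pad id and the mask convention (True marking pad positions) are assumptions for illustration:

prompt = torch.tensor([[0, 0, 5, 17, 42],   # shorter sequence, left-padded with pad id 0 (assumed)
                       [9, 3, 5, 17, 42]])
pad_mask = prompt == 0                      # True at pad positions (convention assumed)
out = model.generate(prompt, pad_mask=pad_mask, num_tokens=32, top_k=5)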
@@ -127,7 +135,7 @@
prefix_len += 1

logits = self(result[:, -self.max_seq_len :], prefix_len=prefix_len, pad_mask=result_pad_mask)[:, -1]
logits = self.top_f(logits, fraction=1 - threshold)
logits = self.top_k(logits, top_k)

probs = F.softmax(logits / temperature, dim=-1)
sample = torch.multinomial(probs, 1)
@@ -141,9 +149,7 @@
return result[:, n_init:]

@staticmethod
def top_f(logits: torch.Tensor, fraction: float = 0.1):
"""Keep the highest `fraction` of `logits` and set others to `-inf`."""
k = int(fraction * logits.shape[-1])
def top_k(logits: torch.Tensor, k: int):
val, idx = torch.topk(logits, k)
logits_top = torch.full_like(logits, float("-inf"))
logits_top.scatter_(1, idx, val)
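
As a quick, self-contained illustration of what the renamed `top_k` helper does (tensor values are made up):

import torch

logits = torch.tensor([[1.0, 3.0, 2.0, 0.5]])
val, idx = torch.topk(logits, 2)                 # keep the two largest logits
masked = torch.full_like(logits, float("-inf"))  # everything else becomes -inf
masked.scatter_(1, idx, val)                     # masked == [[-inf, 3.0, 2.0, -inf]]
probs = torch.softmax(masked, dim=-1)            # tokens outside the top-k set get zero probability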