209 changes: 172 additions & 37 deletions fairseq/search.py
@@ -10,7 +10,11 @@
import torch.nn as nn
from torch import Tensor

from fairseq.token_generation_constraints import ConstraintState, UnorderedConstraintState, OrderedConstraintState
from fairseq.token_generation_constraints import (
ConstraintState,
UnorderedConstraintState,
OrderedConstraintState,
)


class Search(nn.Module):
@@ -22,8 +26,11 @@ def __init__(self, tgt_dict):
self.vocab_size = len(tgt_dict)
self.src_lengths = torch.tensor(-1)
self.supports_constraints = False
self.stop_on_max_len = False

def step(self, step, lprobs, scores):
def step(
self, step, lprobs, scores, prev_output_tokens=None, original_batch_idxs=None
):
"""Take a single search step.

Args:
@@ -32,6 +39,12 @@ def step(self, step, lprobs, scores):
the model's log-probabilities over the vocabulary at the current step
scores: (bsz x input_beam_size x step)
the historical model scores of each hypothesis up to this point
prev_output_tokens: (bsz x step)
    the previously generated output tokens
original_batch_idxs: (bsz)
    the tensor with the batch indices, in the range [0, bsz);
    useful when the batch has been re-ordered and we need to
    recover the original indices

Return: A tuple of (scores, indices, beams) where:
scores: (bsz x output_beam_size)
@@ -94,7 +107,14 @@ def __init__(self, tgt_dict):
self.constraint_states = None

@torch.jit.export
def step(self, step: int, lprobs, scores: Optional[Tensor]):
def step(
self,
step: int,
lprobs,
scores: Optional[Tensor],
prev_output_tokens: Optional[Tensor] = None,
original_batch_idxs: Optional[Tensor] = None,
):
bsz, beam_size, vocab_size = lprobs.size()

if step == 0:
Expand Down Expand Up @@ -125,6 +145,69 @@ def step(self, step: int, lprobs, scores: Optional[Tensor]):
return scores_buf, indices_buf, beams_buf


class PrefixConstrainedBeamSearch(Search):
def __init__(self, tgt_dict, prefix_allowed_tokens_fn):
super().__init__(tgt_dict)
self.prefix_allowed_tokens_fn = prefix_allowed_tokens_fn
self.stop_on_max_len = True

@torch.jit.export
def apply_mask(self, x, prev_output_tokens, original_batch_idxs):
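        # x is lprobs flattened to (bsz * beam_size, 1, vocab_size), so the
        # beam size can be recovered from the leading dimension; the original
        # batch indices are then expanded to one entry per hypothesis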
beam_size = x.shape[0] // original_batch_idxs.shape[0]
original_batch_idxs = (
original_batch_idxs.unsqueeze(-1).repeat((1, beam_size)).flatten().tolist()
)

mask = torch.full_like(x, -math.inf)
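        # allowed positions are zeroed out below, so adding the mask to
        # lprobs suppresses every token the prefix function does not permit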
for sent_i, (sent, batch_i) in enumerate(
zip(prev_output_tokens, original_batch_idxs)
):
mask[sent_i, :, self.prefix_allowed_tokens_fn(batch_i, sent)] = 0

return mask

@torch.jit.export
def step(
self,
step: int,
lprobs: Tensor,
scores: Tensor,
prev_output_tokens: Tensor,
original_batch_idxs: Tensor,
):
bsz, beam_size, vocab_size = lprobs.size()

lprobs += self.apply_mask(
lprobs.view(bsz * beam_size, 1, vocab_size),
prev_output_tokens,
original_batch_idxs,
).view(bsz, beam_size, vocab_size)
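        # disallowed tokens now carry -inf and can never be selected by topk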

if step == 0:
# at the first step all hypotheses are equally likely, so use
# only the first beam
lprobs = lprobs[:, ::beam_size, :].contiguous()
else:
# make probs contain cumulative scores for each hypothesis
assert scores is not None
lprobs = lprobs + scores[:, :, step - 1].unsqueeze(-1)

top_prediction = torch.topk(
lprobs.view(bsz, -1),
k=min(
# Take the best beam_size predictions. We'll choose the first
# beam_size of these which don't predict eos to continue with.
beam_size,
lprobs.view(bsz, -1).size(1) - 1, # -1 so we never select pad
),
)
scores_buf = top_prediction[0]
indices_buf = top_prediction[1]
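        # topk ran over the flattened (beam_size * vocab_size) axis; split
        # each flat index back into a beam index and a within-vocab token id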
beams_buf = indices_buf // vocab_size
indices_buf = indices_buf.fmod(vocab_size)
return scores_buf, indices_buf, beams_buf
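
# Illustrative sketch, not part of this diff: one way to satisfy the
# prefix_allowed_tokens_fn contract. The whitelist below is hypothetical;
# the only requirement is (batch_id, prev_output_tokens) -> allowed ids.
#
#     allowed = {0: [4, 5, 6], 1: [7, 8]}
#
#     def prefix_allowed_tokens_fn(batch_id: int, sent: Tensor):
#         # fall back to a small default range for unlisted sentences
#         return allowed.get(batch_id, list(range(10)))
#
#     search = PrefixConstrainedBeamSearch(tgt_dict, prefix_allowed_tokens_fn)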


class LexicallyConstrainedBeamSearch(Search):
"""Implements lexically constrained beam search as described in

@@ -143,6 +226,7 @@ class LexicallyConstrainedBeamSearch(Search):
constraints have been generated and using this information to
shape the beam for each input sentence.
"""

def __init__(self, tgt_dict, representation):
super().__init__(tgt_dict)
self.representation = representation
@@ -163,17 +247,28 @@ def init_constraints(self, batch_constraints: Optional[Tensor], beam_size: int):

@torch.jit.export
def prune_sentences(self, batch_idxs: Tensor):
self.constraint_states = [self.constraint_states[i] for i in batch_idxs.tolist()]
self.constraint_states = [
self.constraint_states[i] for i in batch_idxs.tolist()
]

@torch.jit.export
def update_constraints(self, active_hypos: Tensor):
if self.constraint_states:
batch_size = active_hypos.size(0)
for sentid in range(batch_size):
self.constraint_states[sentid] = [self.constraint_states[sentid][i] for i in active_hypos[sentid]]
self.constraint_states[sentid] = [
self.constraint_states[sentid][i] for i in active_hypos[sentid]
]

@torch.jit.export
def step(self, step: int, lprobs: Tensor, scores: Optional[Tensor]):
def step(
self,
step: int,
lprobs: Tensor,
scores: Optional[Tensor],
prev_output_tokens: Optional[Tensor] = None,
original_batch_idxs: Optional[Tensor] = None,
):
"""
A constrained step builds a large candidates list from the following:
- the top 2 * {beam_size} items over the whole beam
@@ -222,7 +317,9 @@ def step(self, step: int, lprobs: Tensor, scores: Optional[Tensor]):
not_finished_indices.append(index)
not_finished_indices = torch.tensor(not_finished_indices)
if not_finished_indices.numel() > 0:
lprobs.view(batch_size * beam_size, -1)[not_finished_indices, self.eos] = -math.inf
lprobs.view(batch_size * beam_size, -1)[
not_finished_indices, self.eos
] = -math.inf

if step == 0:
# at the first step all hypotheses are equally likely, so use
@@ -265,13 +362,15 @@ def step(self, step: int, lprobs: Tensor, scores: Optional[Tensor]):
new_indices_buf = torch.zeros((batch_size, 2 * beam_size), device=device).long()
new_beams_buf = torch.zeros((batch_size, 2 * beam_size), device=device).long()
for sentno, states in enumerate(constraint_states):
scores, indices, beams, new_states = self.step_sentence(step,
sentno,
lprobs[sentno],
constraint_states[sentno],
beams_buf[sentno].clone(),
indices_buf[sentno].clone(),
scores_buf[sentno].clone())
scores, indices, beams, new_states = self.step_sentence(
step,
sentno,
lprobs[sentno],
constraint_states[sentno],
beams_buf[sentno].clone(),
indices_buf[sentno].clone(),
scores_buf[sentno].clone(),
)
new_scores_buf[sentno] = scores
new_indices_buf[sentno] = indices
new_beams_buf[sentno] = beams
@@ -280,14 +379,16 @@ def step(self, step: int, lprobs: Tensor, scores: Optional[Tensor]):
return new_scores_buf, new_indices_buf, new_beams_buf

@torch.jit.export
def step_sentence(self,
step: int,
sentno: int,
lprobs: Tensor,
constraint_states: List[List[ConstraintState]],
beams_buf: Tensor,
indices_buf: Tensor,
scores_buf: Tensor):
def step_sentence(
self,
step: int,
sentno: int,
lprobs: Tensor,
constraint_states: List[List[ConstraintState]],
beams_buf: Tensor,
indices_buf: Tensor,
scores_buf: Tensor,
):
"""Does per-sentence processing. Adds all constraints for each
hypothesis to the list of candidates; then removes duplicates,
sorts, and dynamically stripes across the banks. All tensor inputs
@@ -300,7 +401,11 @@ def step_sentence(self,
next_tokens = torch.tensor(list(state.next_tokens()), device=device).long()
if next_tokens.numel() != 0:
indices_buf = torch.cat((indices_buf, next_tokens))
next_beams = torch.tensor(beamno, device=device).repeat(next_tokens.size(0)).long()
next_beams = (
torch.tensor(beamno, device=device)
.repeat(next_tokens.size(0))
.long()
)
beams_buf = torch.cat((beams_buf, next_beams))
next_values = lprobs[beamno].take(next_tokens.view(-1))
scores_buf = torch.cat((scores_buf, next_values))
@@ -320,8 +425,10 @@ def step_sentence(self,

# Compute the new states for all candidates
cands_size = indices_buf.size(0)
constraint_states = [constraint_states[beams_buf[i]].advance(indices_buf[i])
for i in range(cands_size)]
constraint_states = [
constraint_states[beams_buf[i]].advance(indices_buf[i])
for i in range(cands_size)
]

banks = torch.tensor([state.bank for state in constraint_states], device=device)
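        # state.bank counts how many constraint tokens that candidate has
        # already satisfied; striping across banks reserves beam slots for
        # hypotheses at every level of constraint completion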

@@ -357,7 +464,7 @@ def roll(t):
# This is then shifted by 1. We can then easily identify
# duplicates and create a mask that identifies unique
# extensions.
uniques_mask = (beams_buf * (self.vocab_size + 1) + indices_buf)
uniques_mask = beams_buf * (self.vocab_size + 1) + indices_buf
uniques_mask = roll(uniques_mask) != uniques_mask

# Use the mask to pare down the data structures
@@ -410,9 +517,9 @@ def roll(t):
constraint_states = [constraint_states[i] for i in sort_indices]

# STEP 8: Truncate to the candidates size!
scores_buf = scores_buf[:self.num_cands]
indices_buf = indices_buf[:self.num_cands]
beams_buf = beams_buf[:self.num_cands]
scores_buf = scores_buf[: self.num_cands]
indices_buf = indices_buf[: self.num_cands]
beams_buf = beams_buf[: self.num_cands]

return scores_buf, indices_buf, beams_buf, constraint_states

@@ -427,7 +534,14 @@ def __init__(self, tgt_dict, min_len_a, min_len_b, max_len_a, max_len_b):
self.beam = BeamSearch(tgt_dict)
self.needs_src_lengths = True

def step(self, step: int, lprobs, scores):
def step(
self,
step: int,
lprobs,
scores,
prev_output_tokens: Optional[Tensor] = None,
original_batch_idxs: Optional[Tensor] = None,
):
min_lens = self.min_len_a * self.src_lengths + self.min_len_b
max_lens = self.max_len_a * self.src_lengths + self.max_len_b
lprobs[step < min_lens, :, self.eos] = -math.inf
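        # e.g. with min_len_a=0.5, min_len_b=2 and a source of length 10,
        # min_lens = 7, so <eos> is suppressed until step 7 is reached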
@@ -452,7 +566,14 @@ def __init__(self, tgt_dict, num_groups, diversity_strength):
self.beam = BeamSearch(tgt_dict)

@torch.jit.export
def step(self, step: int, lprobs, scores):
def step(
self,
step: int,
lprobs,
scores,
prev_output_tokens: Optional[Tensor] = None,
original_batch_idxs: Optional[Tensor] = None,
):
bsz, beam_size, vocab_size = lprobs.size()
if beam_size % self.num_groups != 0:
raise ValueError(
@@ -553,7 +674,14 @@ def _sample_topp(self, lprobs):
return trimed_probs, truncated_indices

@torch.jit.export
def step(self, step: int, lprobs, scores):
def step(
self,
step: int,
lprobs,
scores,
prev_output_tokens: Optional[Tensor] = None,
original_batch_idxs: Optional[Tensor] = None,
):
bsz, beam_size, vocab_size = lprobs.size()

if step == 0:
@@ -576,7 +704,9 @@ def step(self, step, lprobs, scores):
# sample
if step == 0:
indices_buf = torch.multinomial(
probs.view(bsz, -1), beam_size, replacement=True,
probs.view(bsz, -1),
beam_size,
replacement=True,
).view(bsz, beam_size)
else:
indices_buf = torch.multinomial(
@@ -590,9 +720,7 @@ def step(self, step, lprobs, scores):
probs = probs.expand(bsz, beam_size, -1)

# gather scores
scores_buf = torch.gather(
probs, dim=2, index=indices_buf.unsqueeze(-1)
)
scores_buf = torch.gather(probs, dim=2, index=indices_buf.unsqueeze(-1))
scores_buf = scores_buf.log_().view(bsz, -1)

# remap indices if using top-k or top-P sampling
@@ -635,7 +763,14 @@ def __init__(self, tgt_dict, diversity_rate):
self.diversity_rate = diversity_rate
self.beam = BeamSearch(tgt_dict)

def step(self, step: int, lprobs, scores):
def step(
self,
step: int,
lprobs,
scores,
prev_output_tokens: Optional[Tensor] = None,
original_batch_idxs: Optional[Tensor] = None,
):
bsz, beam_size, vocab_size = lprobs.size()
k = min(
# Take the best 2 x beam_size predictions. We'll choose the first