diff --git a/fairseq/search.py b/fairseq/search.py
index ecb4764a82..2c21b66bbd 100644
--- a/fairseq/search.py
+++ b/fairseq/search.py
@@ -10,7 +10,11 @@
 import torch.nn as nn
 from torch import Tensor
 
-from fairseq.token_generation_constraints import ConstraintState, UnorderedConstraintState, OrderedConstraintState
+from fairseq.token_generation_constraints import (
+    ConstraintState,
+    UnorderedConstraintState,
+    OrderedConstraintState,
+)
 
 
 class Search(nn.Module):
@@ -22,8 +26,11 @@ def __init__(self, tgt_dict):
         self.vocab_size = len(tgt_dict)
         self.src_lengths = torch.tensor(-1)
         self.supports_constraints = False
+        self.stop_on_max_len = False
 
-    def step(self, step, lprobs, scores):
+    def step(
+        self, step, lprobs, scores, prev_output_tokens=None, original_batch_idxs=None
+    ):
         """Take a single search step.
 
         Args:
@@ -32,6 +39,12 @@ def step(self, step, lprobs, scores):
                 the model's log-probabilities over the vocabulary at the current step
             scores: (bsz x input_beam_size x step)
                 the historical model scores of each hypothesis up to this point
+            prev_output_tokens: (bsz x step)
+                the previously generated output tokens
+            original_batch_idxs: (bsz)
+                the tensor with the batch indices, in the range [0, bsz)
+                this is useful in case a re-ordering has been applied
+                and we need to know the original indices
 
         Return: A tuple of (scores, indices, beams) where:
             scores: (bsz x output_beam_size)
@@ -94,7 +107,14 @@ def __init__(self, tgt_dict):
         self.constraint_states = None
 
     @torch.jit.export
-    def step(self, step: int, lprobs, scores: Optional[Tensor]):
+    def step(
+        self,
+        step: int,
+        lprobs,
+        scores: Optional[Tensor],
+        prev_output_tokens: Optional[Tensor] = None,
+        original_batch_idxs: Optional[Tensor] = None,
+    ):
         bsz, beam_size, vocab_size = lprobs.size()
 
         if step == 0:
@@ -125,6 +145,69 @@ def step(self, step: int, lprobs, scores: Optional[Tensor]):
         return scores_buf, indices_buf, beams_buf
 
 
+class PrefixConstrainedBeamSearch(Search):
+    def __init__(self, tgt_dict, prefix_allowed_tokens_fn):
+        super().__init__(tgt_dict)
+        self.prefix_allowed_tokens_fn = prefix_allowed_tokens_fn
+        self.stop_on_max_len = True
+
+    @torch.jit.export
+    def apply_mask(self, x, prev_output_tokens, original_batch_idxs):
+        beam_size = x.shape[0] // original_batch_idxs.shape[0]
+        original_batch_idxs = (
+            original_batch_idxs.unsqueeze(-1).repeat((1, beam_size)).flatten().tolist()
+        )
+
+        mask = torch.full_like(x, -math.inf)
+        for sent_i, (sent, batch_i) in enumerate(
+            zip(prev_output_tokens, original_batch_idxs)
+        ):
+            mask[sent_i, :, self.prefix_allowed_tokens_fn(batch_i, sent)] = 0
+
+        return mask
+
+    @torch.jit.export
+    def step(
+        self,
+        step: int,
+        lprobs: Tensor,
+        scores: Tensor,
+        prev_output_tokens: Tensor,
+        original_batch_idxs: Tensor,
+    ):
+        bsz, beam_size, vocab_size = lprobs.size()
+
+        lprobs += self.apply_mask(
+            lprobs.view(bsz * beam_size, 1, vocab_size),
+            prev_output_tokens,
+            original_batch_idxs,
+        ).view(bsz, beam_size, vocab_size)
+
+        if step == 0:
+            # at the first step all hypotheses are equally likely, so use
+            # only the first beam
+            lprobs = lprobs[:, ::beam_size, :].contiguous()
+        else:
+            # make probs contain cumulative scores for each hypothesis
+            assert scores is not None
+            lprobs = lprobs + scores[:, :, step - 1].unsqueeze(-1)
+
+        top_prediction = torch.topk(
+            lprobs.view(bsz, -1),
+            k=min(
+                # Take the best beam_size predictions. We'll choose the first
+                # beam_size of these which don't predict eos to continue with.
+                beam_size,
+                lprobs.view(bsz, -1).size(1) - 1,  # -1 so we never select pad
+            ),
+        )
+        scores_buf = top_prediction[0]
+        indices_buf = top_prediction[1]
+        beams_buf = indices_buf // vocab_size
+        indices_buf = indices_buf.fmod(vocab_size)
+        return scores_buf, indices_buf, beams_buf
+
+
 class LexicallyConstrainedBeamSearch(Search):
     """Implements lexically constrained beam search as described in
 
@@ -143,6 +226,7 @@ class LexicallyConstrainedBeamSearch(Search):
     constraints have been generated and using this information to shape
     the beam for each input sentence.
     """
+
     def __init__(self, tgt_dict, representation):
         super().__init__(tgt_dict)
         self.representation = representation
@@ -163,17 +247,28 @@ def init_constraints(self, batch_constraints: Optional[Tensor], beam_size: int):
 
     @torch.jit.export
     def prune_sentences(self, batch_idxs: Tensor):
-        self.constraint_states = [self.constraint_states[i] for i in batch_idxs.tolist()]
+        self.constraint_states = [
+            self.constraint_states[i] for i in batch_idxs.tolist()
+        ]
 
     @torch.jit.export
     def update_constraints(self, active_hypos: Tensor):
         if self.constraint_states:
             batch_size = active_hypos.size(0)
             for sentid in range(batch_size):
-                self.constraint_states[sentid] = [self.constraint_states[sentid][i] for i in active_hypos[sentid]]
+                self.constraint_states[sentid] = [
+                    self.constraint_states[sentid][i] for i in active_hypos[sentid]
+                ]
 
     @torch.jit.export
-    def step(self, step: int, lprobs: Tensor, scores: Optional[Tensor]):
+    def step(
+        self,
+        step: int,
+        lprobs: Tensor,
+        scores: Optional[Tensor],
+        prev_output_tokens: Optional[Tensor] = None,
+        original_batch_idxs: Optional[Tensor] = None,
+    ):
         """
         A constrained step builds a large candidates list from the following:
         - the top 2 * {beam_size} items over the whole beam
@@ -222,7 +317,9 @@ def step(self, step: int, lprobs: Tensor, scores: Optional[Tensor]):
                         not_finished_indices.append(index)
             not_finished_indices = torch.tensor(not_finished_indices)
             if not_finished_indices.numel() > 0:
-                lprobs.view(batch_size * beam_size, -1)[not_finished_indices, self.eos] = -math.inf
+                lprobs.view(batch_size * beam_size, -1)[
+                    not_finished_indices, self.eos
+                ] = -math.inf
 
         if step == 0:
             # at the first step all hypotheses are equally likely, so use
@@ -265,13 +362,15 @@ def step(self, step: int, lprobs: Tensor, scores: Optional[Tensor]):
         new_indices_buf = torch.zeros((batch_size, 2 * beam_size), device=device).long()
         new_beams_buf = torch.zeros((batch_size, 2 * beam_size), device=device).long()
         for sentno, states in enumerate(constraint_states):
-            scores, indices, beams, new_states = self.step_sentence(step,
-                                                                    sentno,
-                                                                    lprobs[sentno],
-                                                                    constraint_states[sentno],
-                                                                    beams_buf[sentno].clone(),
-                                                                    indices_buf[sentno].clone(),
-                                                                    scores_buf[sentno].clone())
+            scores, indices, beams, new_states = self.step_sentence(
+                step,
+                sentno,
+                lprobs[sentno],
+                constraint_states[sentno],
+                beams_buf[sentno].clone(),
+                indices_buf[sentno].clone(),
+                scores_buf[sentno].clone(),
+            )
             new_scores_buf[sentno] = scores
             new_indices_buf[sentno] = indices
             new_beams_buf[sentno] = beams
@@ -280,14 +379,16 @@ def step(self, step: int, lprobs: Tensor, scores: Optional[Tensor]):
         return new_scores_buf, new_indices_buf, new_beams_buf
 
     @torch.jit.export
-    def step_sentence(self,
-                      step: int,
-                      sentno: int,
-                      lprobs: Tensor,
-                      constraint_states: List[List[ConstraintState]],
-                      beams_buf: Tensor,
-                      indices_buf: Tensor,
-                      scores_buf: Tensor):
+    def step_sentence(
+        self,
+        step: int,
+        sentno: int,
+        lprobs: Tensor,
+        constraint_states: List[List[ConstraintState]],
+        beams_buf: Tensor,
+        indices_buf: Tensor,
+        scores_buf: Tensor,
+    ):
         """Does per-sentence processing. Adds all constraints for each
         hypothesis to the list of candidates; then removes duplicates,
         sorts, and dynamically stripes across the banks. All tensor inputs
@@ -300,7 +401,11 @@ def step_sentence(self,
             next_tokens = torch.tensor(list(state.next_tokens()), device=device).long()
             if next_tokens.numel() != 0:
                 indices_buf = torch.cat((indices_buf, next_tokens))
-                next_beams = torch.tensor(beamno, device=device).repeat(next_tokens.size(0)).long()
+                next_beams = (
+                    torch.tensor(beamno, device=device)
+                    .repeat(next_tokens.size(0))
+                    .long()
+                )
                 beams_buf = torch.cat((beams_buf, next_beams))
                 next_values = lprobs[beamno].take(next_tokens.view(-1))
                 scores_buf = torch.cat((scores_buf, next_values))
@@ -320,8 +425,10 @@ def step_sentence(self,
 
         # Compute the new states for all candidates
         cands_size = indices_buf.size(0)
-        constraint_states = [constraint_states[beams_buf[i]].advance(indices_buf[i])
-                             for i in range(cands_size)]
+        constraint_states = [
+            constraint_states[beams_buf[i]].advance(indices_buf[i])
+            for i in range(cands_size)
+        ]
 
         banks = torch.tensor([state.bank for state in constraint_states], device=device)
 
@@ -357,7 +464,7 @@ def roll(t):
         # This is then shifted by 1. We can then easily identify
         # duplicates and create a mask that identifies unique
        # extensions.
-        uniques_mask = (beams_buf * (self.vocab_size + 1) + indices_buf)
+        uniques_mask = beams_buf * (self.vocab_size + 1) + indices_buf
         uniques_mask = roll(uniques_mask) != uniques_mask
 
         # Use the mask to pare down the data structures
@@ -410,9 +517,9 @@ def roll(t):
         constraint_states = [constraint_states[i] for i in sort_indices]
 
         # STEP 8: Truncate to the candidates size!
-        scores_buf = scores_buf[:self.num_cands]
-        indices_buf = indices_buf[:self.num_cands]
-        beams_buf = beams_buf[:self.num_cands]
+        scores_buf = scores_buf[: self.num_cands]
+        indices_buf = indices_buf[: self.num_cands]
+        beams_buf = beams_buf[: self.num_cands]
 
         return scores_buf, indices_buf, beams_buf, constraint_states
 
@@ -427,7 +534,14 @@ def __init__(self, tgt_dict, min_len_a, min_len_b, max_len_a, max_len_b):
         self.beam = BeamSearch(tgt_dict)
         self.needs_src_lengths = True
 
-    def step(self, step: int, lprobs, scores):
+    def step(
+        self,
+        step: int,
+        lprobs,
+        scores,
+        prev_output_tokens: Optional[Tensor] = None,
+        original_batch_idxs: Optional[Tensor] = None,
+    ):
         min_lens = self.min_len_a * self.src_lengths + self.min_len_b
         max_lens = self.max_len_a * self.src_lengths + self.max_len_b
         lprobs[step < min_lens, :, self.eos] = -math.inf
@@ -452,7 +566,14 @@ def __init__(self, tgt_dict, num_groups, diversity_strength):
         self.beam = BeamSearch(tgt_dict)
 
     @torch.jit.export
-    def step(self, step: int, lprobs, scores):
+    def step(
+        self,
+        step: int,
+        lprobs,
+        scores,
+        prev_output_tokens: Optional[Tensor] = None,
+        original_batch_idxs: Optional[Tensor] = None,
+    ):
         bsz, beam_size, vocab_size = lprobs.size()
         if beam_size % self.num_groups != 0:
             raise ValueError(
@@ -553,7 +674,14 @@ def _sample_topp(self, lprobs):
         return trimed_probs, truncated_indices
 
     @torch.jit.export
-    def step(self, step: int, lprobs, scores):
+    def step(
+        self,
+        step: int,
+        lprobs,
+        scores,
+        prev_output_tokens: Optional[Tensor] = None,
+        original_batch_idxs: Optional[Tensor] = None,
+    ):
         bsz, beam_size, vocab_size = lprobs.size()
 
         if step == 0:
@@ -576,7 +704,9 @@ def step(self, step: int, lprobs, scores):
         # sample
         if step == 0:
             indices_buf = torch.multinomial(
-                probs.view(bsz, -1), beam_size, replacement=True,
+                probs.view(bsz, -1),
+                beam_size,
+                replacement=True,
             ).view(bsz, beam_size)
         else:
             indices_buf = torch.multinomial(
@@ -590,9 +720,7 @@ def step(self, step: int, lprobs, scores):
             probs = probs.expand(bsz, beam_size, -1)
 
         # gather scores
-        scores_buf = torch.gather(
-            probs, dim=2, index=indices_buf.unsqueeze(-1)
-        )
+        scores_buf = torch.gather(probs, dim=2, index=indices_buf.unsqueeze(-1))
         scores_buf = scores_buf.log_().view(bsz, -1)
 
         # remap indices if using top-k or top-P sampling
@@ -635,7 +763,14 @@ def __init__(self, tgt_dict, diversity_rate):
         self.diversity_rate = diversity_rate
         self.beam = BeamSearch(tgt_dict)
 
-    def step(self, step: int, lprobs, scores):
+    def step(
+        self,
+        step: int,
+        lprobs,
+        scores,
+        prev_output_tokens: Optional[Tensor] = None,
+        original_batch_idxs: Optional[Tensor] = None,
+    ):
         bsz, beam_size, vocab_size = lprobs.size()
         k = min(
             # Take the best 2 x beam_size predictions. We'll choose the first
diff --git a/fairseq/sequence_generator.py b/fairseq/sequence_generator.py
index 965594cd6e..9ef90b1490 100644
--- a/fairseq/sequence_generator.py
+++ b/fairseq/sequence_generator.py
@@ -269,6 +269,13 @@ def _generate(
 
         reorder_state: Optional[Tensor] = None
         batch_idxs: Optional[Tensor] = None
+
+        original_batch_idxs: Optional[Tensor] = None
+        if "id" in sample and isinstance(sample["id"], Tensor):
+            original_batch_idxs = sample["id"]
+        else:
+            original_batch_idxs = torch.arange(0, bsz).type_as(tokens)
+
         for step in range(max_len + 1):  # one extra step for EOS marker
             # reorder decoder internal states based on the prev choice of beams
             # print(f'step: {step}')
@@ -281,6 +288,7 @@ def _generate(
                     reorder_state.view(-1, beam_size).add_(
                         corr.unsqueeze(-1) * beam_size
                     )
+                    original_batch_idxs = original_batch_idxs[batch_idxs]
                 self.model.reorder_incremental_state(incremental_states, reorder_state)
                 encoder_outs = self.model.reorder_encoder_out(
                     encoder_outs, reorder_state
@@ -342,8 +350,10 @@ def _generate(
                 step,
                 lprobs.view(bsz, -1, self.vocab_size),
                 scores.view(bsz, beam_size, -1)[:, :, :step],
+                tokens[:, : step + 1],
+                original_batch_idxs,
             )
-
+
             # cand_bbsz_idx contains beam indices for the top candidate
             # hypotheses, with a range of values: [0, bsz*beam_size),
             # and dimensions: [bsz, cand_size]
@@ -385,8 +395,10 @@ def _generate(
             assert num_remaining_sent >= 0
             if num_remaining_sent == 0:
                 break
+            if self.search.stop_on_max_len and step >= max_len:
+                break
             assert step < max_len
-
+
             # Remove finalized sentences (ones for which {beam_size}
             # finished hypotheses have been generated) from the batch.
             if len(finalized_sents) > 0:
diff --git a/fairseq/tasks/fairseq_task.py b/fairseq/tasks/fairseq_task.py
index 2aa6c8ff28..1ce4ab1921 100644
--- a/fairseq/tasks/fairseq_task.py
+++ b/fairseq/tasks/fairseq_task.py
@@ -313,6 +313,7 @@ def build_generator(
         match_source_len = getattr(args, "match_source_len", False)
         diversity_rate = getattr(args, "diversity_rate", -1)
         constrained = getattr(args, "constraints", False)
+        prefix_allowed_tokens_fn = getattr(args, "prefix_allowed_tokens_fn", None)
         if (
             sum(
                 int(cond)
@@ -356,6 +357,10 @@ def build_generator(
             search_strategy = search.LexicallyConstrainedBeamSearch(
                 self.target_dictionary, args.constraints
             )
+        elif prefix_allowed_tokens_fn:
+            search_strategy = search.PrefixConstrainedBeamSearch(
+                self.target_dictionary, prefix_allowed_tokens_fn
+            )
         else:
             search_strategy = search.BeamSearch(self.target_dictionary)
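
Usage sketch (illustrative only, not part of the diff): the new hook is a callable taking (original batch index, prefix of previously generated tokens) and returning the token ids allowed at the next position; PrefixConstrainedBeamSearch masks every other entry of lprobs to -inf. A minimal sketch, assuming a hypothetical per-sentence whitelist `allowed` and an existing `task`/`models`/`args` setup:

    from fairseq import search

    # hypothetical whitelist: original batch index -> token ids allowed at every step
    allowed = {0: [4, 5, 6], 1: [7, 8]}

    def prefix_allowed_tokens_fn(batch_id, prefix_tokens):
        # prefix_tokens is the 1-D tensor of tokens generated so far for one
        # hypothesis; return the vocabulary ids permitted at the next position
        return allowed[batch_id]

    # either construct the strategy directly ...
    strategy = search.PrefixConstrainedBeamSearch(
        task.target_dictionary, prefix_allowed_tokens_fn
    )

    # ... or let build_generator select it, as wired up in fairseq_task.py above
    args.prefix_allowed_tokens_fn = prefix_allowed_tokens_fn
    generator = task.build_generator(models, args)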