Add random_seed argument to generate #8162

Merged: merged 1 commit on Jan 22, 2024
15 changes: 11 additions & 4 deletions nemo/collections/nlp/modules/common/text_generation_utils.py
@@ -24,6 +24,7 @@
import numpy as np
import torch
import torch.nn.functional as F
from lightning_fabric.utilities.seed import seed_everything

from nemo.collections.common.tokenizers.tabular_tokenizer import TabularTokenizer
from nemo.collections.nlp.modules.common.megatron.utils import get_ltor_masks_and_position_ids
@@ -272,7 +273,7 @@ def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-float('Inf'), started
This function has been mostly taken from huggingface conversational
ai code at
https://medium.com/huggingface/how-to-build-a-state-of-the-art-
conversational-ai-with-transfer-learning-2d818ac26313

@param logits: logits tensor
@param top_k: keep only top k tokens with highest probability
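For context, here is a minimal standalone sketch of the top-k / top-p (nucleus) filtering this docstring describes. It follows the widely used HuggingFace recipe the comment links to, not necessarily the exact NeMo body; the function name and simplified shapes are assumptions:

import torch
import torch.nn.functional as F

def top_k_top_p_filter(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
    # logits: [batch_size, vocab_size]
    if top_k > 0:
        # Mask out everything below the k-th highest logit.
        kth_best = torch.topk(logits, top_k)[0][..., -1, None]
        logits[logits < kth_best] = filter_value
    if top_p > 0.0:
        # Keep the smallest prefix of sorted tokens whose cumulative
        # probability exceeds top_p; mask out the rest.
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
        remove = cum_probs > top_p
        remove[..., 1:] = remove[..., :-1].clone()  # shift right so the first token over the threshold survives
        remove[..., 0] = False
        indices_to_remove = remove.scatter(-1, sorted_indices, remove)
        logits[indices_to_remove] = filter_value
    return logits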
@@ -315,7 +316,7 @@ def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-float('Inf'), started


def repetition_penalty(logits, repetition_penalty, used_tokens):
""" Implement the repetition penalty, check paper
""" Implement the repetition penalty, check paper
https://arxiv.org/pdf/1909.05858.pdf
"""
if used_tokens is not None and repetition_penalty != 1.0:
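The penalty follows CTRL (the paper linked above): for tokens that have already appeared, positive logits are divided by the penalty and negative logits multiplied by it, making repeats less likely. A minimal sketch under that reading (a hypothetical standalone version, not the exact NeMo body):

import torch

def apply_repetition_penalty(logits, penalty, used_tokens):
    # logits: [batch_size, vocab_size]; used_tokens: [batch_size, num_used] token ids
    if used_tokens is None or penalty == 1.0:
        return logits
    prev = torch.gather(logits, 1, used_tokens)
    # Shrink positive logits, amplify negative ones (CTRL, arXiv:1909.05858).
    prev = torch.where(prev < 0, prev * penalty, prev / penalty)
    logits.scatter_(1, used_tokens, prev)
    return logits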
@@ -547,6 +548,7 @@ def generate(
end_strings=['<|endoftext|>'],
image_list=None,
min_tokens_to_generate=0,
random_seed=None,
**strategy_args,
) -> OutputType:
"""
@@ -562,6 +564,8 @@
greedy (bool): Whether or not to use sampling; use greedy decoding otherwise
repetition_penalty (float): The parameter for repetition penalty. 1.0 means no penalty
min_tokens_to_generate (int): The minimum length of the tokens to be generated
random_seed (int): if set, fixes the random seed so generation is reproducible. If None, no seed is set, and
the behavior of generation depends on whether a seed was set earlier.
strategy_args: the extra arguments are treated as inference strategy arguments
end_strings: a list of strings that stop generation when encountered in the output.
Returns:
@@ -573,6 +577,9 @@
token_ids: List[Tensor], output sentence token ids
offsets: List[List[int]] # list of tokens start positions in text
"""
if random_seed is not None:
seed_everything(random_seed)

if 'strategy' in strategy_args:
inference_strategy = strategy_args['strategy']
else:
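With this change, callers can pin the seed to make sampling reproducible. A hypothetical usage sketch (the model object, prompt, and sampling kwargs here are illustrative assumptions, not part of the diff):

from nemo.collections.nlp.modules.common.text_generation_utils import generate

# model is assumed to be an already-loaded NeMo GPT-style model.
# Two sampling calls with the same seed should return identical output,
# because seed_everything resets the Python, NumPy, and torch RNG state.
out_a = generate(model, inputs=["Once upon a time"], tokens_to_generate=32,
                 greedy=False, top_k=50, random_seed=1234)
out_b = generate(model, inputs=["Once upon a time"], tokens_to_generate=32,
                 greedy=False, top_k=50, random_seed=1234)
assert out_a['sentences'] == out_b['sentences']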
@@ -1058,7 +1065,7 @@ def sample_token_greedy(logits):

Args:
logits: [batch_size, vocab_size] - unnormalized log probabilities of the next token

Returns:
log_probs: [batch_size] - log probabilities of the sampled tokens
token_ids: [batch_size] - sampled token ids
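Given the documented shapes, greedy selection reduces to an argmax over the log-softmaxed logits; a minimal sketch consistent with that contract (not necessarily the exact NeMo body):

import torch
import torch.nn.functional as F

def sample_token_greedy_sketch(logits):
    # logits: [batch_size, vocab_size] -> most likely next token per batch row
    log_probs, token_ids = torch.max(F.log_softmax(logits, dim=-1), dim=-1)
    return log_probs, token_ids  # each [batch_size]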
@@ -1078,7 +1085,7 @@ def sample_token_topk(logits, top_k=0, top_p=0.0, temperature=1.0, filter_value=
top_p: float - if > 0.0: only sample from a subset of candidates, where the cumulative probability of the subset exceeds top_p (nucleus sampling)
temperature: float - temperature for sampling
filter_value: float - value to set filtered tokens to

Returns:
log_probs: [batch_size] - log probabilities of the sampled tokens
token_ids: [batch_size] - sampled token ids
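Temperature-scaled sampling over the filtered distribution might then look like the following sketch (it reuses the hypothetical top_k_top_p_filter from earlier; names and shapes are assumptions, not the exact NeMo body):

import torch
import torch.nn.functional as F

def sample_token_topk_sketch(logits, top_k=0, top_p=0.0, temperature=1.0):
    # logits: [batch_size, vocab_size]
    logits = logits / temperature                      # <1 sharpens, >1 flattens the distribution
    logits = top_k_top_p_filter(logits, top_k, top_p)  # mask out filtered tokens
    log_probs = F.log_softmax(logits, dim=-1)
    token_ids = torch.multinomial(log_probs.exp(), num_samples=1).squeeze(-1)
    log_probs = log_probs.gather(-1, token_ids.unsqueeze(-1)).squeeze(-1)
    return log_probs, token_ids  # each [batch_size]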