diff --git a/src/transformers/generation/logits_process.py b/src/transformers/generation/logits_process.py index 67b6719325c8..4f5f7f6b5b55 100644 --- a/src/transformers/generation/logits_process.py +++ b/src/transformers/generation/logits_process.py @@ -266,7 +266,13 @@ class TemperatureLogitsWarper(LogitsWarper): def __init__(self, temperature: float): if not isinstance(temperature, float) or not (temperature > 0): - raise ValueError(f"`temperature` has to be a strictly positive float, but is {temperature}") + except_msg = ( + f"`temperature` (={temperature}) has to be a strictly positive float, otherwise your next token " + "scores will be invalid." + ) + if isinstance(temperature, float) and temperature == 0.0: + except_msg += " If you're looking for greedy decoding strategies, set `do_sample=False`." + raise ValueError(except_msg) self.temperature = temperature