Skip to content

Llama-2-7b-chat-hf will not allow temperature to be 0.0 #687

@SleepingSkipper

Description

@SleepingSkipper

I am running the Llama-2-7b-chat-hf model on Huggingface.
When I set temperature=0.0 or temperature=0, I get
ValueError: temperature has to be a strictly positive float, but is 0.0.
Until a week ago, it was working with the same code and environment.

My code and the error message:

from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

# Hugging Face Hub identifier of the model to load.
model_name="meta-llama/Llama-2-7b-chat-hf"
# 4-bit NF4 quantization settings with float16 as the compute dtype.
# NOTE(review): `torch` is referenced here, but this snippet shows no
# `import torch` -- presumably it was imported in an earlier cell; confirm.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)
# Load the causal LM using the quantization config defined above.
model_4bit = AutoModelForCausalLM.from_pretrained(
    model_name, 
    quantization_config=bnb_config, 
    trust_remote_code=True
)
# Turn off the `use_cache` flag on the loaded model's config.
model_4bit.config.use_cache = False
# `model` is the name `generate` uses; it refers to the same object as `model_4bit`.
model = model_4bit 
# Tokenizer loaded from the same model identifier.
tokenizer = AutoTokenizer.from_pretrained(model_name)

def generate(text):
    """Print a model-generated three-line summary of ``text``.

    Wraps ``text`` in an instruction-style prompt, encodes it with the
    module-level ``tokenizer``, calls ``model.generate`` and prints the decoded
    first output sequence. Nothing is returned.

    Raises:
        ValueError: with ``temperature=0.0`` as written, per the traceback in
            this report (``TemperatureLogitsWarper`` rejects it).
    """
    prompt = f"""Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
### Instruction:
Summarize following sentence in three lines.
### Input:
{text}
### Response:"""
    input_ids = tokenizer.encode(prompt, return_tensors="pt")
    # NOTE(review): the result of `.to(device)` is discarded. `torch.Tensor.to`
    # returns a tensor rather than moving it in place, so `input_ids` is
    # probably still on its original device -- confirm and reassign if needed.
    # `device` is also not defined anywhere in this snippet.
    input_ids.to(device)
    # Gradients are not needed for generation.
    with torch.no_grad():
        outputs = model.generate(inputs=input_ids,
                                # This value triggers the reported error: the
                                # traceback shows TemperatureLogitsWarper raising
                                # unless the temperature is a float greater than 0.
                                temperature=0.0,
                                max_new_tokens=500)
    # Only the first returned sequence is decoded and printed.
    print(tokenizer.decode(outputs[0], skip_special_tokens=True))
    
# Sample passage used as the input to summarize.
text = """FC Barcelona's Spanish defender Jordi Alba and Turkish midfielder Arda Turan have returned to full training, according to the Spanish newspaper Marca on March 28. J. Alba returned to full training after suffering an injury in the Copa del Rey match against Athletic Bilbao on March 17. Arda, who missed the match against Atletico Madrid on March 27 due to a high fever, has also returned to the squad and is now in good shape for the match against Atletico Madrid."""

# Run the reproducer; per the traceback in this report, the `model.generate`
# call inside `generate` raises the ValueError.
generate(text)

>> 
ValueError                                Traceback (most recent call last)
Cell In[12], line 5
      2 input_ids.to(device)
      3 with torch.no_grad():
----> 5     outputs = model.generate(inputs=input_ids,
      6                               temperature=0.0,
      7                                 max_new_tokens=500)
      8 print(tokenizer.decode(outputs[0], skip_special_tokens=True))

File ~/anaconda3/envs/llama2/lib/python3.9/site-packages/torch/utils/_contextlib.py:115, in context_decorator.<locals>.decorate_context(*args, **kwargs)
    112 @functools.wraps(func)
    113 def decorate_context(*args, **kwargs):
    114     with ctx_factory():
--> 115         return func(*args, **kwargs)

File ~/anaconda3/envs/llama2/lib/python3.9/site-packages/transformers/generation/utils.py:1604, in GenerationMixin.generate(self, inputs, generation_config, logits_processor, stopping_criteria, prefix_allowed_tokens_fn, synced_gpus, assistant_model, streamer, **kwargs)
   1586     return self.contrastive_search(
   1587         input_ids,
   1588         top_k=generation_config.top_k,
   (...)
   1599         **model_kwargs,
   1600     )
   1602 elif is_sample_gen_mode:
   1603     # 11. prepare logits warper
-> 1604     logits_warper = self._get_logits_warper(generation_config)
   1606     # 12. expand input_ids with `num_return_sequences` additional sequences per batch
   1607     input_ids, model_kwargs = self._expand_inputs_for_generation(
   1608         input_ids=input_ids,
   1609         expand_size=generation_config.num_return_sequences,
   1610         is_encoder_decoder=self.config.is_encoder_decoder,
   1611         **model_kwargs,
   1612     )

File ~/anaconda3/envs/llama2/lib/python3.9/site-packages/transformers/generation/utils.py:809, in GenerationMixin._get_logits_warper(self, generation_config)
    806 # the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files
    807 # all samplers can be found in `generation_utils_samplers.py`
    808 if generation_config.temperature is not None and generation_config.temperature != 1.0:
--> 809     warpers.append(TemperatureLogitsWarper(generation_config.temperature))
    810 min_tokens_to_keep = 2 if generation_config.num_beams > 1 else 1
    811 if generation_config.top_k is not None and generation_config.top_k != 0:

File ~/anaconda3/envs/llama2/lib/python3.9/site-packages/transformers/generation/logits_process.py:231, in TemperatureLogitsWarper.__init__(self, temperature)
    229 def __init__(self, temperature: float):
    230     if not isinstance(temperature, float) or not (temperature > 0):
--> 231         raise ValueError(f"`temperature` has to be a strictly positive float, but is {temperature}")
    233     self.temperature = temperature

ValueError: `temperature` has to be a strictly positive float, but is 0.0

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions