I am running the Llama-2-7b-chat-hf model from Hugging Face.
When I set temperature=0.0 or temperature=0, I get
ValueError: `temperature` has to be a strictly positive float, but is 0.0.
Until a week ago, it was working with the same code and environment.
My code and error message:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

model_name = "meta-llama/Llama-2-7b-chat-hf"

# 4-bit NF4 quantization
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)

model_4bit = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=bnb_config,
    trust_remote_code=True,
)
model_4bit.config.use_cache = False
model = model_4bit

tokenizer = AutoTokenizer.from_pretrained(model_name)

device = "cuda" if torch.cuda.is_available() else "cpu"  # device for the input tensors

def generate(text):
    prompt = f"""Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
### Instruction:
Summarize following sentence in three lines.
### Input:
{text}
### Response:"""
    input_ids = tokenizer.encode(prompt, return_tensors="pt")
    input_ids = input_ids.to(device)
    with torch.no_grad():
        outputs = model.generate(inputs=input_ids,
                                 temperature=0.0,
                                 max_new_tokens=500)
    print(tokenizer.decode(outputs[0], skip_special_tokens=True))
text = """FC Barcelona's Spanish defender Jordi Alba and Turkish midfielder Arda Turan have returned to full training, according to the Spanish newspaper Marca on March 28. J. Alba returned to full training after suffering an injury in the Copa del Rey match against Athletic Bilbao on March 17. Arda, who missed the match against Atletico Madrid on March 27 due to a high fever, has also returned to the squad and is now in good shape for the match against Atletico Madrid."""
generate(text)
>>
ValueError Traceback (most recent call last)
Cell In[12], line 5
2 input_ids.to(device)
3 with torch.no_grad():
----> 5 outputs = model.generate(inputs=input_ids,
6 temperature=0.0,
7 max_new_tokens=500)
8 print(tokenizer.decode(outputs[0], skip_special_tokens=True))
File ~/anaconda3/envs/llama2/lib/python3.9/site-packages/torch/utils/_contextlib.py:115, in context_decorator.<locals>.decorate_context(*args, **kwargs)
112 @functools.wraps(func)
113 def decorate_context(*args, **kwargs):
114 with ctx_factory():
--> 115 return func(*args, **kwargs)
File ~/anaconda3/envs/llama2/lib/python3.9/site-packages/transformers/generation/utils.py:1604, in GenerationMixin.generate(self, inputs, generation_config, logits_processor, stopping_criteria, prefix_allowed_tokens_fn, synced_gpus, assistant_model, streamer, **kwargs)
1586 return self.contrastive_search(
1587 input_ids,
1588 top_k=generation_config.top_k,
(...)
1599 **model_kwargs,
1600 )
1602 elif is_sample_gen_mode:
1603 # 11. prepare logits warper
-> 1604 logits_warper = self._get_logits_warper(generation_config)
1606 # 12. expand input_ids with `num_return_sequences` additional sequences per batch
1607 input_ids, model_kwargs = self._expand_inputs_for_generation(
1608 input_ids=input_ids,
1609 expand_size=generation_config.num_return_sequences,
1610 is_encoder_decoder=self.config.is_encoder_decoder,
1611 **model_kwargs,
1612 )
File ~/anaconda3/envs/llama2/lib/python3.9/site-packages/transformers/generation/utils.py:809, in GenerationMixin._get_logits_warper(self, generation_config)
806 # the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files
807 # all samplers can be found in `generation_utils_samplers.py`
808 if generation_config.temperature is not None and generation_config.temperature != 1.0:
--> 809 warpers.append(TemperatureLogitsWarper(generation_config.temperature))
810 min_tokens_to_keep = 2 if generation_config.num_beams > 1 else 1
811 if generation_config.top_k is not None and generation_config.top_k != 0:
File ~/anaconda3/envs/llama2/lib/python3.9/site-packages/transformers/generation/logits_process.py:231, in TemperatureLogitsWarper.__init__(self, temperature)
229 def __init__(self, temperature: float):
230 if not isinstance(temperature, float) or not (temperature > 0):
--> 231 raise ValueError(f"`temperature` has to be a strictly positive float, but is {temperature}")
233 self.temperature = temperature
ValueError: `temperature` has to be a strictly positive float, but is 0.0
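For context, the check at the bottom of the traceback rejects temperature=0.0 outright, and the Llama-2-chat checkpoints ship a generation_config with do_sample=True, which is presumably why generate() takes the sampling path here. A possible workaround (my own sketch, not something confirmed in this thread) is to request greedy decoding explicitly and drop the temperature argument, or to keep sampling with a small strictly positive temperature:

# Workaround sketch, assuming deterministic output is the goal:
with torch.no_grad():
    outputs = model.generate(inputs=input_ids,
                             do_sample=False,       # greedy decoding, no temperature needed
                             max_new_tokens=500)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))

# Or keep sampling but pass a small positive value instead of 0.0:
# outputs = model.generate(inputs=input_ids, do_sample=True, temperature=1e-5, max_new_tokens=500)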