diff --git a/tests/generation/test_continuous_batching.py b/tests/generation/test_continuous_batching.py
index 80da7886dccf..76788c5e4224 100644
--- a/tests/generation/test_continuous_batching.py
+++ b/tests/generation/test_continuous_batching.py
@@ -350,9 +350,9 @@ def test_streaming_request(self) -> None:
 
         messages = [{"content": "What is the Transformers library known for?", "role": "user"}]
 
-        inputs = tokenizer.apply_chat_template(messages, return_tensors="pt", add_generation_prompt=True).to(
-            model.device
-        )[0]
+        inputs = tokenizer.apply_chat_template(
+            messages, return_tensors="pt", add_generation_prompt=True, return_dict=False
+        ).to(model.device)[0]
 
         request_id = manager.add_request(inputs, max_new_tokens=max_new_tokens, streaming=True)
 
@@ -382,9 +382,9 @@ def test_non_streaming_request(self) -> None:
 
         messages = [{"content": "What is the Transformers library known for?", "role": "user"}]
 
-        inputs = tokenizer.apply_chat_template(messages, return_tensors="pt", add_generation_prompt=True).to(
-            model.device
-        )[0]
+        inputs = tokenizer.apply_chat_template(
+            messages, return_tensors="pt", add_generation_prompt=True, return_dict=False
+        ).to(model.device)[0]
 
         request_id = manager.add_request(inputs, max_new_tokens=max_new_tokens, streaming=False)
 
@@ -409,9 +409,9 @@ def test_streaming_and_non_streaming_requests_can_alternate(self) -> None:
 
         messages = [{"content": "What is the Transformers library known for?", "role": "user"}]
 
-        inputs = tokenizer.apply_chat_template(messages, return_tensors="pt", add_generation_prompt=True).to(
-            model.device
-        )[0]
+        inputs = tokenizer.apply_chat_template(
+            messages, return_tensors="pt", add_generation_prompt=True, return_dict=False
+        ).to(model.device)[0]
 
         # Non-streaming request
         request_id = manager.add_request(inputs, max_new_tokens=max_new_tokens, streaming=False)
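
For context: `tokenizer.apply_chat_template` returns either a plain tensor of token ids or a `BatchEncoding`-style dict depending on `return_dict`, and the `[0]` indexing in these tests only makes sense on the tensor form. Passing `return_dict=False` explicitly presumably pins that down so the tests stay valid even if the default ever flips. A minimal sketch of the two return shapes, assuming an arbitrary chat-capable checkpoint (the model name below is illustrative, not the one the test suite uses):

```python
from transformers import AutoTokenizer

# Illustrative checkpoint; the actual tests build their tokenizer elsewhere.
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
messages = [{"content": "What is the Transformers library known for?", "role": "user"}]

# return_dict=False -> a (1, seq_len) tensor of token ids, so
# `.to(model.device)[0]` yields the 1-D prompt that add_request() consumes.
ids = tokenizer.apply_chat_template(
    messages, return_tensors="pt", add_generation_prompt=True, return_dict=False
)
print(type(ids), ids.shape)  # <class 'torch.Tensor'> torch.Size([1, seq_len])

# return_dict=True -> a BatchEncoding with input_ids / attention_mask;
# indexing that with [0] would not return the token ids.
enc = tokenizer.apply_chat_template(
    messages, return_tensors="pt", add_generation_prompt=True, return_dict=True
)
print(enc.keys())  # dict_keys(['input_ids', 'attention_mask'])
```

Spelling out `return_dict=False` costs nothing under the current default and keeps the `.to(model.device)[0]` chain unambiguous.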