Generate: basic token streaming#22449
Conversation
sgugger
left a comment
There was a problem hiding this comment.
I am not a fan of this API at all. Why is there a need for the TextStreamer to spawn a new process? The put method could directly call the print statement.
|
The documentation is not available anymore as the PR was closed or merged. |
|
@sgugger revised with the simpler implementation (no context manager nor multiprocessing) 🤗 |
sgugger
left a comment
There was a problem hiding this comment.
Better this way, thanks!
|
Just a FYI: I have been doing this using class Stream(transformers.StoppingCriteria):
def __init__(self, callback_func=None):
self.callback_func = callback_func
def __call__(self, input_ids, scores) -> bool:
if self.callback_func is not None:
self.callback_func(input_ids[0])
return FalseThe callback is then used to create an iterator with the Iteratorize class here: https://github.com/oobabooga/text-generation-webui/blob/main/modules/callbacks.py#L42 Usage becomes: def generate_with_callback(callback=None, **kwargs):
kwargs['stopping_criteria'].append(Stream(callback_func=callback))
with torch.no_grad():
shared.model.generate(**kwargs)
def generate_with_streaming(**kwargs):
return Iteratorize(generate_with_callback, kwargs, callback=None)
with generate_with_streaming(**generate_params) as generator:
for output in generator: |
|
@oobabooga 🧠 That's a smart (and unexpected!) use of the stopping criteria. I'm going to work on a standardized Gradio solution today, and a Queue+iterator was indeed my plan. If you don't mind, I will take inspiration in your code 💛 A question regarding your implementation -- you use a separate thread in the |
Feel free to copy anything you want.
Honestly, I have no specific reason to give. I just spent several days trying to get the text generation to run in the background independently of where the |
* haha tokens go brrrr
* haha tokens go brrrr
What does this PR do?
Adds token streaming to
.generate()🎉Why now?
I want to showcase and communicate how much faster assisted generation can be... and for that, I need token streaming :D Non-image/video results have a much lower impact.
What's being added
This PR adds a
streamerinput to generate. If it is non-None, generate will callstreamer.put(new_tokens)as they are being generated.streamercan, therefore, be a wide array of things. This PR adds the simplest case: print tokens as they are generated.At first, I thought of adding a simpler
stream=Trueoption. However, the tokenizer would have to be passed into.generate(), which we have been avoiding, and it wouldn't be nearly as flexible. I've made the call to make streaming+.generate()flexible, and to keep it simple at apipelinelevel.If this PR gets accepted
The plan is to:
stream=Trueflag to startHow does it look
Here's an example. Note that it is running on CPU, so we can actually see the streaming effect (3090 is too fast 😅 ). On GPU it also streams, but much faster 🔥
Screen.Recording.2023-03-29.at.16.39.55.mov