diff --git a/crates/llm-chain-llama/examples/stream.rs b/crates/llm-chain-llama/examples/stream.rs index 8ff95148..0274e730 100644 --- a/crates/llm-chain-llama/examples/stream.rs +++ b/crates/llm-chain-llama/examples/stream.rs @@ -1,6 +1,7 @@ +use llm_chain::output::StreamExt; use llm_chain::{executor, parameters, prompt}; -/// This example demonstrates how to use the llm-chain-llama crate to generate text using a +/// This example demonstrates how to use the llm-chain-llama crate to generate streaming text using a /// LLaMA model. /// /// Usage: cargo run --example simple path/to/llama-or-alpaca-model @@ -14,6 +15,10 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> { let res = prompt!("The Colors of the Rainbow are (in order): ") .run(&parameters!(), &exec) .await?; - println!("{}", res.to_immediate().await?); + let mut stream = res.as_stream().await?; + while let Some(v) = stream.next().await { + print!("{}", v); + } + Ok(()) } diff --git a/crates/llm-chain-openai/examples/simple_sequential_generation_stream.rs b/crates/llm-chain-openai/examples/simple_sequential_generation_stream.rs index 6b70522b..7a6ea660 100644 --- a/crates/llm-chain-openai/examples/simple_sequential_generation_stream.rs +++ b/crates/llm-chain-openai/examples/simple_sequential_generation_stream.rs @@ -1,4 +1,4 @@ -use llm_chain::{chains::sequential::Chain, executor, prompt, step::Step}; +use llm_chain::{chains::sequential::Chain, executor, output::StreamExt, prompt, step::Step}; #[tokio::main(flavor = "current_thread")] async fn main() -> Result<(), Box<dyn std::error::Error>> { @@ -31,8 +31,10 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> { // Print the result to the console // Call `res.primary_textual_output()` explictly to get the streamed response. - println!("{:?}", res.to_immediate().await?.as_content()); - + let mut stream = res.as_stream().await?; + while let Some(v) = stream.next().await { + print!("{}", v); + } // Call `res.as_stream()` to access the Stream without polling. 
Ok(()) } diff --git a/crates/llm-chain/src/output/mod.rs b/crates/llm-chain/src/output/mod.rs index adc2717b..fa6698c2 100644 --- a/crates/llm-chain/src/output/mod.rs +++ b/crates/llm-chain/src/output/mod.rs @@ -3,10 +3,11 @@ mod stream; use core::fmt; use crate::{prompt::Data, traits::ExecutorError}; -use futures::Stream; +use thiserror; use tokio::sync::mpsc; pub use stream::{OutputStream, StreamSegment}; +pub use tokio_stream::{Stream, StreamExt}; /// The `Output` enum provides a general interface for outputs of different types. /// The `Immediate` variant represents data that is immediately available, while the `Stream` variant @@ -19,6 +20,10 @@ pub enum Output { Stream(OutputStream), } +#[derive(Debug, thiserror::Error)] +#[error("Trying to return a stream on an Immediate output")] +pub struct NotAStreamError; + impl Output { /// Converts the `Output` to its `Immediate` form. /// If the output is `Stream`, it will be consumed and turned into an `Immediate` output. @@ -30,6 +35,15 @@ impl Output { } } + /// Given that the Output is a stream, return an OutputStream + /// If the output is `Immediate` NotAStreamError will be raised + pub async fn as_stream(self) -> Result<OutputStream, NotAStreamError> { + match self { + Output::Immediate(_) => Err(NotAStreamError), + Output::Stream(x) => Ok(x), + } + } + /// Creates a new `Stream` output along with a sender to produce data. 
pub fn new_stream() -> (mpsc::Sender<StreamSegment>, Self) { let (sender, stream) = OutputStream::new(); diff --git a/crates/llm-chain/src/output/stream.rs b/crates/llm-chain/src/output/stream.rs index 75177a01..fa778805 100644 --- a/crates/llm-chain/src/output/stream.rs +++ b/crates/llm-chain/src/output/stream.rs @@ -1,18 +1,30 @@ use crate::prompt::{ChatRole, Data}; use crate::traits::ExecutorError; use futures::StreamExt; +use std::fmt; use std::pin::Pin; use std::task::{Context, Poll}; use tokio::sync::mpsc::{self, Receiver}; use tokio_stream::Stream; use crate::prompt::{ChatMessage, ChatMessageCollection}; +#[derive(Debug)] pub enum StreamSegment { Role(ChatRole), Content(String), Err(ExecutorError), } +impl fmt::Display for StreamSegment { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + StreamSegment::Role(chat_role) => write!(f, "{}", chat_role), + StreamSegment::Content(content) => write!(f, "{}", content), + StreamSegment::Err(executor_error) => write!(f, "{}", executor_error), + } + } + } + pub struct OutputStream { receiver: Receiver<StreamSegment>, } @@ -85,6 +97,6 @@ impl Stream for OutputStream { type Item = StreamSegment; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> { - Pin::new(&mut self.receiver).poll_recv(cx) + self.receiver.poll_recv(cx) } }