From ad56bfc1c763568e918b2f7f3aacf0d425d65517 Mon Sep 17 00:00:00 2001 From: pepperoni21 <pariselias00@gmail.com> Date: Fri, 17 Jan 2025 13:10:40 +0100 Subject: [PATCH 1/2] Support for borrowed strings in generation requests --- ollama-rs/examples/images_to_ollama.rs | 2 +- ollama-rs/src/generation/completion/mod.rs | 8 ++--- .../src/generation/completion/request.rs | 30 ++++++++++--------- ollama-rs/tests/generation.rs | 10 ++----- 4 files changed, 23 insertions(+), 27 deletions(-) diff --git a/ollama-rs/examples/images_to_ollama.rs b/ollama-rs/examples/images_to_ollama.rs index 2b9203f..c4958e7 100644 --- a/ollama-rs/examples/images_to_ollama.rs +++ b/ollama-rs/examples/images_to_ollama.rs @@ -55,7 +55,7 @@ async fn download_image(url: &str) -> Result<Vec<u8>, reqwest::Error> { // Function to send the request to the model async fn send_request( - request: GenerationRequest, + request: GenerationRequest<'_>, ) -> Result<GenerationResponse, Box<dyn std::error::Error>> { let ollama = Ollama::default(); let response = ollama.generate(request).await?; diff --git a/ollama-rs/src/generation/completion/mod.rs b/ollama-rs/src/generation/completion/mod.rs index 4202c85..d01d2d1 100644 --- a/ollama-rs/src/generation/completion/mod.rs +++ b/ollama-rs/src/generation/completion/mod.rs @@ -21,9 +21,9 @@ impl Ollama { #[cfg(feature = "stream")] /// Completion generation with streaming. /// Returns a stream of `GenerationResponse` objects - pub async fn generate_stream( + pub async fn generate_stream<'a>( &self, - request: GenerationRequest, + request: GenerationRequest<'a>, ) -> crate::error::Result<GenerationResponseStream> { use tokio_stream::StreamExt; @@ -66,9 +66,9 @@ impl Ollama { /// Completion generation with a single response. /// Returns a single `GenerationResponse` object - pub async fn generate( + pub async fn generate<'a>( &self, - request: GenerationRequest, + request: GenerationRequest<'a>, ) -> crate::error::Result<GenerationResponse> { let mut request = request; request.stream = false; diff --git a/ollama-rs/src/generation/completion/request.rs b/ollama-rs/src/generation/completion/request.rs index d7605b3..67f0c13 100644 --- a/ollama-rs/src/generation/completion/request.rs +++ b/ollama-rs/src/generation/completion/request.rs @@ -1,3 +1,5 @@ +use std::borrow::Cow; + use serde::Serialize; use crate::generation::{ @@ -10,15 +12,15 @@ use super::GenerationContext; /// A generation request to Ollama. #[derive(Debug, Clone, Serialize)] -pub struct GenerationRequest { +pub struct GenerationRequest<'a> { #[serde(rename = "model")] pub model_name: String, - pub prompt: String, - pub suffix: Option<String>, + pub prompt: Cow<'a, str>, + pub suffix: Option<Cow<'a, str>>, pub images: Vec<Image>, pub options: Option<GenerationOptions>, - pub system: Option<String>, - pub template: Option<String>, + pub system: Option<Cow<'a, str>>, + pub template: Option<Cow<'a, str>>, pub context: Option<GenerationContext>, #[serde(skip_serializing_if = "Option::is_none")] pub format: Option<FormatType>, @@ -26,11 +28,11 @@ pub struct GenerationRequest { pub(crate) stream: bool, } -impl GenerationRequest { - pub fn new(model_name: String, prompt: String) -> Self { +impl<'a> GenerationRequest<'a> { + pub fn new(model_name: String, prompt: impl Into<Cow<'a, str>>) -> Self { Self { model_name, - prompt, + prompt: prompt.into(), suffix: None, images: Vec::new(), options: None, @@ -51,8 +53,8 @@ impl GenerationRequest { } /// Adds a text after the model response - pub fn suffix(mut self, suffix: String) -> Self { - self.suffix = Some(suffix); + pub fn suffix(mut self, suffix: impl Into<Cow<'a, str>>) -> Self { + self.suffix = Some(suffix.into()); self } @@ -75,14 +77,14 @@ impl GenerationRequest { } /// System prompt to (overrides what is defined in the Modelfile) - pub fn system(mut self, system: String) -> Self { - self.system = Some(system); + pub fn system(mut self, system: impl Into<Cow<'a, str>>) -> Self { + self.system = Some(system.into()); self } /// The full prompt or prompt template (overrides what is defined in the Modelfile) - pub fn template(mut self, template: String) -> Self { - self.template = Some(template); + pub fn template(mut self, template: impl Into<Cow<'a, str>>) -> Self { + self.template = Some(template.into()); self } diff --git a/ollama-rs/tests/generation.rs b/ollama-rs/tests/generation.rs index d35e94d..312a549 100644 --- a/ollama-rs/tests/generation.rs +++ b/ollama-rs/tests/generation.rs @@ -18,10 +18,7 @@ async fn test_generation_stream() { let ollama = Ollama::default(); let mut res: GenerationResponseStream = ollama - .generate_stream(GenerationRequest::new( - "llama2:latest".to_string(), - PROMPT.into(), - )) + .generate_stream(GenerationRequest::new("llama2:latest".to_string(), PROMPT)) .await .unwrap(); @@ -45,10 +42,7 @@ async fn test_generation() { let ollama = Ollama::default(); let res = ollama - .generate(GenerationRequest::new( - "llama2:latest".to_string(), - PROMPT.into(), - )) + .generate(GenerationRequest::new("llama2:latest".to_string(), PROMPT)) .await .unwrap(); dbg!(res); From 5f38fc3c27f9e3cf9d971542e6a9c911bc803cb8 Mon Sep 17 00:00:00 2001 From: pepperoni21 <pariselias00@gmail.com> Date: Fri, 17 Jan 2025 13:16:03 +0100 Subject: [PATCH 2/2] Fixed formatting --- ollama-rs/src/generation/completion/mod.rs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/ollama-rs/src/generation/completion/mod.rs b/ollama-rs/src/generation/completion/mod.rs index d01d2d1..e4039ef 100644 --- a/ollama-rs/src/generation/completion/mod.rs +++ b/ollama-rs/src/generation/completion/mod.rs @@ -21,9 +21,9 @@ impl Ollama { #[cfg(feature = "stream")] /// Completion generation with streaming. /// Returns a stream of `GenerationResponse` objects - pub async fn generate_stream<'a>( + pub async fn generate_stream( &self, - request: GenerationRequest<'a>, + request: GenerationRequest<'_>, ) -> crate::error::Result<GenerationResponseStream> { use tokio_stream::StreamExt; @@ -66,9 +66,9 @@ impl Ollama { /// Completion generation with a single response. /// Returns a single `GenerationResponse` object - pub async fn generate<'a>( + pub async fn generate( &self, - request: GenerationRequest<'a>, + request: GenerationRequest<'_>, ) -> crate::error::Result<GenerationResponse> { let mut request = request; request.stream = false;