Commit

Merge pull request #116 from pepperoni21/borrowed-strings
Support for borrowed strings in generation requests
pepperoni21 authored Jan 21, 2025
2 parents 1e3f7d3 + 5f38fc3 commit 747b19d
Showing 4 changed files with 21 additions and 25 deletions.
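The core of the change: GenerationRequest gains a lifetime parameter and stores its string fields as Cow<'a, str>, so callers can hand over string slices without allocating. A minimal before/after sketch of a call site (the model name and prompt text here are illustrative, not from this diff):

    use ollama_rs::generation::completion::request::GenerationRequest;

    // Before this commit: the prompt had to be an owned String.
    let owned = GenerationRequest::new(
        "llama2:latest".to_string(),
        "Why is the sky blue?".to_string(),
    );

    // After: a &str borrows into Cow::Borrowed with no allocation.
    // Owned Strings still work too, converting into Cow::Owned.
    let borrowed = GenerationRequest::new("llama2:latest".to_string(), "Why is the sky blue?");

Note that model_name stays an owned String; only prompt, suffix, system, and template move to Cow.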
ollama-rs/examples/images_to_ollama.rs (1 addition, 1 deletion)

@@ -55,7 +55,7 @@ async fn download_image(url: &str) -> Result<Vec<u8>, reqwest::Error> {
 
 // Function to send the request to the model
 async fn send_request(
-    request: GenerationRequest,
+    request: GenerationRequest<'_>,
 ) -> Result<GenerationResponse, Box<dyn std::error::Error>> {
     let ollama = Ollama::default();
     let response = ollama.generate(request).await?;
ollama-rs/src/generation/completion/mod.rs (2 additions, 2 deletions)

@@ -23,7 +23,7 @@ impl Ollama {
     /// Returns a stream of `GenerationResponse` objects
     pub async fn generate_stream(
         &self,
-        request: GenerationRequest,
+        request: GenerationRequest<'_>,
     ) -> crate::error::Result<GenerationResponseStream> {
         use tokio_stream::StreamExt;
 
@@ -68,7 +68,7 @@ impl Ollama {
     /// Returns a single `GenerationResponse` object
     pub async fn generate(
         &self,
-        request: GenerationRequest,
+        request: GenerationRequest<'_>,
     ) -> crate::error::Result<GenerationResponse> {
         let mut request = request;
         request.stream = false;
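Because the request type now carries a lifetime, both entry points accept GenerationRequest<'_> and simply borrow the caller's strings for the duration of the call. A sketch of an end-to-end call, assuming the usual tokio runtime and that GenerationResponse exposes the generated text as its response field:

    use ollama_rs::{generation::completion::request::GenerationRequest, Ollama};

    #[tokio::main]
    async fn main() -> Result<(), Box<dyn std::error::Error>> {
        let ollama = Ollama::default();

        // The request borrows `prompt` instead of taking ownership,
        // so the buffer stays usable after the call completes.
        let prompt = String::from("In one sentence, why is the sky blue?");
        let request = GenerationRequest::new("llama2:latest".to_string(), prompt.as_str());

        let response = ollama.generate(request).await?;
        println!("{}", response.response);

        // `prompt` was only borrowed, so it is still available here.
        assert!(!prompt.is_empty());
        Ok(())
    }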
ollama-rs/src/generation/completion/request.rs (16 additions, 14 deletions)

@@ -1,3 +1,5 @@
+use std::borrow::Cow;
+
 use serde::Serialize;
 
 use crate::generation::{
@@ -10,27 +12,27 @@ use super::GenerationContext;
 
 /// A generation request to Ollama.
 #[derive(Debug, Clone, Serialize)]
-pub struct GenerationRequest {
+pub struct GenerationRequest<'a> {
     #[serde(rename = "model")]
     pub model_name: String,
-    pub prompt: String,
-    pub suffix: Option<String>,
+    pub prompt: Cow<'a, str>,
+    pub suffix: Option<Cow<'a, str>>,
     pub images: Vec<Image>,
     pub options: Option<GenerationOptions>,
-    pub system: Option<String>,
-    pub template: Option<String>,
+    pub system: Option<Cow<'a, str>>,
+    pub template: Option<Cow<'a, str>>,
     pub context: Option<GenerationContext>,
     #[serde(skip_serializing_if = "Option::is_none")]
     pub format: Option<FormatType>,
     pub keep_alive: Option<KeepAlive>,
     pub(crate) stream: bool,
 }
 
-impl GenerationRequest {
-    pub fn new(model_name: String, prompt: String) -> Self {
+impl<'a> GenerationRequest<'a> {
+    pub fn new(model_name: String, prompt: impl Into<Cow<'a, str>>) -> Self {
         Self {
             model_name,
-            prompt,
+            prompt: prompt.into(),
             suffix: None,
             images: Vec::new(),
             options: None,
@@ -51,8 +53,8 @@ impl GenerationRequest {
     }
 
     /// Adds a text after the model response
-    pub fn suffix(mut self, suffix: String) -> Self {
-        self.suffix = Some(suffix);
+    pub fn suffix(mut self, suffix: impl Into<Cow<'a, str>>) -> Self {
+        self.suffix = Some(suffix.into());
         self
     }
 
@@ -75,14 +77,14 @@ impl GenerationRequest {
     }
 
     /// System prompt to (overrides what is defined in the Modelfile)
-    pub fn system(mut self, system: String) -> Self {
-        self.system = Some(system);
+    pub fn system(mut self, system: impl Into<Cow<'a, str>>) -> Self {
+        self.system = Some(system.into());
         self
     }
 
     /// The full prompt or prompt template (overrides what is defined in the Modelfile)
-    pub fn template(mut self, template: String) -> Self {
-        self.template = Some(template);
+    pub fn template(mut self, template: impl Into<Cow<'a, str>>) -> Self {
+        self.template = Some(template.into());
         self
     }
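All the builder methods switch from String to impl Into<Cow<'a, str>>, which accepts borrowed and owned strings alike; serde serializes Cow<'a, str> exactly as it does String, so the JSON sent to Ollama is unchanged. A standalone sketch of how the conversion behaves:

    use std::borrow::Cow;

    // Mirrors the bound used by GenerationRequest's constructor and builders.
    fn into_text<'a>(text: impl Into<Cow<'a, str>>) -> Cow<'a, str> {
        text.into()
    }

    fn main() {
        // A &str becomes Cow::Borrowed: just a reference, no copy.
        assert!(matches!(into_text("hello"), Cow::Borrowed(_)));

        // A String becomes Cow::Owned: its allocation is moved, not cloned.
        assert!(matches!(into_text(String::from("hello")), Cow::Owned(_)));
    }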
ollama-rs/tests/generation.rs (2 additions, 8 deletions)

@@ -18,10 +18,7 @@ async fn test_generation_stream() {
     let ollama = Ollama::default();
 
     let mut res: GenerationResponseStream = ollama
-        .generate_stream(GenerationRequest::new(
-            "llama2:latest".to_string(),
-            PROMPT.into(),
-        ))
+        .generate_stream(GenerationRequest::new("llama2:latest".to_string(), PROMPT))
         .await
         .unwrap();
 
@@ -45,10 +42,7 @@ async fn test_generation() {
     let ollama = Ollama::default();
 
     let res = ollama
-        .generate(GenerationRequest::new(
-            "llama2:latest".to_string(),
-            PROMPT.into(),
-        ))
+        .generate(GenerationRequest::new("llama2:latest".to_string(), PROMPT))
         .await
         .unwrap();
     dbg!(res);
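The tests drop their .into() calls because PROMPT is a string constant: a &'static str satisfies impl Into<Cow<'a, str>> for every 'a, converting to Cow::Borrowed. Roughly (the constant's actual text is not shown in this diff):

    const PROMPT: &str = "Why is the sky blue?";

    // No .to_string() or .into() needed; PROMPT is borrowed as-is.
    let request = GenerationRequest::new("llama2:latest".to_string(), PROMPT);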
