From 2d8b8cebb35ba5b194672211473116a82caf0bb3 Mon Sep 17 00:00:00 2001
From: John McBride
Date: Wed, 18 Oct 2023 13:56:48 -0600
Subject: [PATCH] feat: Upgrade to openai-api-rs 2.0.0

Signed-off-by: John McBride
---
 Cargo.lock                  |  4 ++--
 Cargo.toml                  |  2 +-
 src/conversation/mod.rs     |  3 ++-
 src/conversation/prompts.rs | 25 ++++++-------------------
 4 files changed, 11 insertions(+), 23 deletions(-)

diff --git a/Cargo.lock b/Cargo.lock
index 870b505..8fa1540 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -1711,9 +1711,9 @@ dependencies = [
 
 [[package]]
 name = "openai-api-rs"
-version = "1.0.4"
+version = "2.0.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "eeab072070ad2fdba1b1f1a7de48d8a499871e5777bc855fcf81e9f8b2ef7a40"
+checksum = "3696225585180dd710afc91eb91048464e8ea7b4154f30f33fdea5393a9289ea"
 dependencies = [
  "minreq",
  "serde",
diff --git a/Cargo.toml b/Cargo.toml
index adef7b0..557573b 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -18,7 +18,7 @@ rayon = "1"
 reqwest = { version = "0.11", default-features = false, features = ["rustls-tls", "json"] }
 serde = "1"
 tokenizers = "0.14"
-openai-api-rs = "1.0"
+openai-api-rs = "2.0"
 zip = "0.6"
 rust-fuzzy-search = "0.1"
 text-splitter = "0.4"
diff --git a/src/conversation/mod.rs b/src/conversation/mod.rs
index d6b25a9..c014064 100644
--- a/src/conversation/mod.rs
+++ b/src/conversation/mod.rs
@@ -93,7 +93,8 @@ impl Conversation {
         #[allow(unused_labels)]
         'conversation: loop {
             //Generate a request with the message history and functions
-            let request = generate_completion_request(self.messages.clone(), FunctionCallType::Auto);
+            let request =
+                generate_completion_request(self.messages.clone(), FunctionCallType::Auto);
 
             match self.send_request(request) {
                 Ok(response) => {
diff --git a/src/conversation/prompts.rs b/src/conversation/prompts.rs
index b3e4091..745bc1f 100644
--- a/src/conversation/prompts.rs
+++ b/src/conversation/prompts.rs
@@ -1,6 +1,6 @@
 use openai_api_rs::v1::chat_completion::{
-    ChatCompletionMessage, ChatCompletionRequest, Function as F, FunctionParameters,
-    JSONSchemaDefine, JSONSchemaType, FunctionCallType,
+    ChatCompletionMessage, ChatCompletionRequest, Function as F, FunctionCallType,
+    FunctionParameters, JSONSchemaDefine, JSONSchemaType,
 };
 use std::collections::HashMap;
 
@@ -17,23 +17,10 @@ pub fn generate_completion_request(
     messages: Vec<ChatCompletionMessage>,
     function_call: FunctionCallType,
 ) -> ChatCompletionRequest {
-
-    ChatCompletionRequest {
-        model: CHAT_COMPLETION_MODEL.into(),
-        messages,
-        functions: Some(functions()),
-        function_call: Some(function_call),
-        temperature: Some(CHAT_COMPLETION_TEMPERATURE),
-        top_p: None,
-        n: None,
-        stream: None,
-        stop: None,
-        max_tokens: None,
-        presence_penalty: None,
-        frequency_penalty: None,
-        logit_bias: None,
-        user: None,
-    }
+    ChatCompletionRequest::new(CHAT_COMPLETION_MODEL.to_string(), messages)
+        .functions(functions())
+        .function_call(function_call)
+        .temperature(CHAT_COMPLETION_TEMPERATURE)
 }
 
 pub fn functions() -> Vec<F> {
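
The substantive API change is in generate_completion_request: openai-api-rs 2.0 builds ChatCompletionRequest from a constructor for the required fields (model and messages) plus chained setters for the optional ones, instead of a struct literal that spells out every unused Option as None. Below is a minimal sketch of that pattern, assuming placeholder values for the model and temperature constants (the real ones are defined in src/conversation/prompts.rs and are not shown in this patch) and omitting the .functions(functions()) call so the example stays self-contained.

use openai_api_rs::v1::chat_completion::{
    ChatCompletionMessage, ChatCompletionRequest, FunctionCallType,
};

// Placeholder values for illustration only; the actual constants live in
// src/conversation/prompts.rs and are not part of this patch.
const CHAT_COMPLETION_MODEL: &str = "gpt-3.5-turbo";
const CHAT_COMPLETION_TEMPERATURE: f64 = 0.2;

// Build a request the 2.0 way: required fields go through `new`, optional
// fields are chained setters, and anything left unset stays `None`.
fn build_request(messages: Vec<ChatCompletionMessage>) -> ChatCompletionRequest {
    ChatCompletionRequest::new(CHAT_COMPLETION_MODEL.to_string(), messages)
        .function_call(FunctionCallType::Auto)
        .temperature(CHAT_COMPLETION_TEMPERATURE)
}

Each setter consumes the request and returns it by value (the chain itself is the function's return expression above), which is what lets the old struct literal with its run of explicit None fields collapse to a few chained calls.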