From 6f063bf87ffefaee23e9f461d1b0b747b3c6d77a Mon Sep 17 00:00:00 2001 From: efugier Date: Thu, 9 Nov 2023 10:50:53 +0100 Subject: [PATCH] feat(config): more configurable services --- src/config.rs | 28 +++++++++++++++++----------- src/input_processing.rs | 6 +++--- src/request.rs | 13 +++++++++---- 3 files changed, 29 insertions(+), 18 deletions(-) diff --git a/src/config.rs b/src/config.rs index 1467c3c..90e38f7 100644 --- a/src/config.rs +++ b/src/config.rs @@ -4,6 +4,13 @@ use std::fs; use std::path::PathBuf; use toml::Value; +#[derive(Debug, Deserialize)] +pub struct ServiceConfig { + #[serde(skip_serializing)] // internal use only + pub api_key: String, + pub url: String, +} + #[derive(Debug, Deserialize, Serialize)] pub struct Prompt { #[serde(skip_serializing)] // internal use only @@ -32,20 +39,19 @@ fn resolve_config_path() -> PathBuf { } } -pub fn get_api_key(service: &str) -> String { +pub fn get_service_config(service: &str) -> ServiceConfig { let api_keys_path = resolve_config_path().join(API_KEYS_FILE); let content = fs::read_to_string(&api_keys_path) .unwrap_or_else(|error| panic!("Could not read file {:?}, {:?}", api_keys_path, error)); - let value: Value = content.parse().expect("Failed to parse TOML"); - - // Extract the API key from the TOML table. - let api_key = value - .get("API_KEYS") - .expect("API_KEYS section not found") - .get(service) - .unwrap_or_else(|| panic!("No api key found for service {}.", &service)); - - api_key.to_string() + let mut service_configs: HashMap = toml::from_str(&content).unwrap(); + + service_configs.remove(service).unwrap_or_else(|| { + panic!( + "Prompt {} not found, availables ones are: {:?}", + service, + service_configs.keys().collect::>() + ) + }) } pub fn get_prompts() -> HashMap { diff --git a/src/input_processing.rs b/src/input_processing.rs index af0eff6..78add66 100644 --- a/src/input_processing.rs +++ b/src/input_processing.rs @@ -1,4 +1,4 @@ -use crate::config::{get_api_key, Prompt, PLACEHOLDER_TOKEN}; +use crate::config::{get_service_config, Prompt, PLACEHOLDER_TOKEN}; use crate::request::{make_authenticated_request, OpenAiResponse}; use std::io::{Read, Result, Write}; @@ -51,8 +51,8 @@ pub fn process_input_with_request( for message in prompt.messages.iter_mut() { message.content = message.content.replace(PLACEHOLDER_TOKEN, &input) } - let api_key = get_api_key(&prompt.service); - let response: OpenAiResponse = make_authenticated_request(&api_key, prompt) + let service_config = get_service_config(&prompt.service); + let response: OpenAiResponse = make_authenticated_request(service_config, prompt) .unwrap() .into_json()?; diff --git a/src/request.rs b/src/request.rs index b53a1f3..0e50d13 100644 --- a/src/request.rs +++ b/src/request.rs @@ -1,5 +1,7 @@ use serde::{Deserialize, Serialize}; +use crate::config::ServiceConfig; + #[derive(Debug, Deserialize)] pub struct Message { pub role: String, @@ -32,13 +34,16 @@ pub struct OpenAiResponse { } pub fn make_authenticated_request( - api_key: &str, + service_config: ServiceConfig, data: impl Serialize, ) -> Result { - println!("Trying to reach openai with {}", &api_key); - ureq::post("https://api.openai.com/v1/chat/completions") + println!("Trying to reach openai with {}", service_config.api_key); + ureq::post(&service_config.url) .set("Content-Type", "application/json") - .set("Authorization", &format!("Bearer {}", api_key)) + .set( + "Authorization", + &format!("Bearer {}", service_config.api_key), + ) .send_json(data) // .send_json(ureq::json!( // {