Skip to content

Commit

Permalink
feat(config): configurable prompts
Browse files Browse the repository at this point in the history
  • Loading branch information
efugier committed Nov 8, 2023
1 parent 5feb684 commit 752f1e7
Show file tree
Hide file tree
Showing 6 changed files with 126 additions and 48 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,6 @@ edition = "2021"

[dependencies]
toml = "*"
clap = { version = "*", features = ["derive"] }
ureq = { version="*", features = ["json"] }
serde = { version = "*", features = ["derive"] }
59 changes: 46 additions & 13 deletions src/config.rs
Original file line number Diff line number Diff line change
@@ -1,23 +1,56 @@
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::fs;
use std::path::PathBuf;
use toml::Value;

pub fn get_api_key() -> String {
let config_path = format!(
"{}/.config/pipelm/.api_configs.toml",
std::env::var("HOME").unwrap()
);
let content = fs::read_to_string(config_path).expect("Failed to read the TOML file");
#[derive(Debug, Deserialize, Serialize)]
pub struct Prompt {
#[serde(skip_serializing)] // internal use only
pub service: String,
pub model: String,
pub messages: Vec<Message>,
}

#[derive(Debug, Deserialize, Serialize)]
pub struct Message {
pub role: String,
pub content: String,
}

pub const PLACEHOLDER_TOKEN: &str = "#[<input>]";

const DEFAULT_CONFIG_PATH: &str = ".config/pipelm/";
const CUSTOM_CONFIG_ENV_VAR: &str = "PIPLE_CONFIG_PATH";
const API_KEYS_FILE: &str = ".api_keys.toml";
const PROMPT_FILE: &str = "prompts.toml";

fn resolve_config_path() -> PathBuf {
match std::env::var(CUSTOM_CONFIG_ENV_VAR) {
Ok(p) => PathBuf::new().join(p),
Err(_) => PathBuf::new().join(env!("HOME")).join(DEFAULT_CONFIG_PATH),
}
}

pub fn get_api_key(service: &str) -> String {
let api_keys_path = resolve_config_path().join(API_KEYS_FILE);
let content = fs::read_to_string(&api_keys_path)
.unwrap_or_else(|error| panic!("Could not read file {:?}, {:?}", api_keys_path, error));
let value: Value = content.parse().expect("Failed to parse TOML");

// Extract the API key from the TOML table.
let api_key = value
.get("openai")
.and_then(|table| table.get("API_KEY"))
.and_then(|api_key| api_key.as_str())
.unwrap_or_else(|| {
eprintln!("API_KEY not found in the TOML file.");
std::process::exit(1);
});
.get("API_KEYS")
.expect("API_KEYS section not found")
.get(service)
.unwrap_or_else(|| panic!("No api key found for service {}.", &service));

api_key.to_string()
}

pub fn get_prompts() -> HashMap<String, Prompt> {
let prompts_path = resolve_config_path().join(PROMPT_FILE);
let content = fs::read_to_string(&prompts_path)
.unwrap_or_else(|error| panic!("Could not read file {:?}, {:?}", prompts_path, error));
toml::from_str(&content).unwrap()
}
20 changes: 11 additions & 9 deletions src/input_processing.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
use crate::config::{get_api_key, Prompt, PLACEHOLDER_TOKEN};
use crate::request::{make_authenticated_request, OpenAiResponse};
use std::io::{Read, Result, Write};

// [tmp] mostly template to write tests
pub fn chunk_process_input<R: Read, W: Write>(
input: &mut R,
output: &mut W,
Expand All @@ -9,7 +11,6 @@ pub fn chunk_process_input<R: Read, W: Write>(
) -> Result<()> {
let mut first_chunk = true;
let mut buffer = [0; 1024];

loop {
match input.read(&mut buffer) {
Ok(0) => break, // end of input
Expand All @@ -33,10 +34,9 @@ pub fn chunk_process_input<R: Read, W: Write>(
}

pub fn process_input_with_request<R: Read, W: Write>(
prompt: &mut Prompt,
input: &mut R,
output: &mut W,
prefix: &str,
suffix: &str,
) -> Result<()> {
let mut buffer = Vec::new();
input.read_to_end(&mut buffer)?;
Expand All @@ -48,15 +48,17 @@ pub fn process_input_with_request<R: Read, W: Write>(

let input = String::from_utf8(buffer).unwrap();

let mut result = String::from(prefix);
result.push_str(&input);
result.push_str(suffix);

let response: OpenAiResponse = make_authenticated_request(&result).unwrap().into_json()?;
for message in prompt.messages.iter_mut() {
message.content = message.content.replace(PLACEHOLDER_TOKEN, &input)
}
let api_key = get_api_key(&prompt.service);
let response: OpenAiResponse = make_authenticated_request(&api_key, prompt)
.unwrap()
.into_json()?;

println!("{}", response.choices.first().unwrap().message.content);

output.write_all(suffix.as_bytes())?;
output.write_all(input.as_bytes())?;

Ok(())
}
Expand Down
52 changes: 45 additions & 7 deletions src/main.rs
Original file line number Diff line number Diff line change
@@ -1,18 +1,56 @@
use clap::Parser;
use std::io;
mod config;
mod input_processing;
mod request;

#[allow(dead_code)]
mod config;

#[derive(Debug, Parser)]
#[command(author, version, about, long_about = None)]
struct Cli {
#[arg(default_value_t = String::from("default"))]
prompt: String,
#[arg(short, long, default_value_t = String::from("openai"))]
service: String,
}

fn main() {
let args = Cli::parse();

let mut output = io::stdout();
let mut input = io::stdin();

if let Err(e) = input_processing::chunk_process_input(
&mut input,
&mut output,
"Hello, World!\n```\n",
"\n```\n",
) {
let mut prompts = config::get_prompts();

// case for testing IO
if args.prompt == "test" {
if let Err(e) = input_processing::chunk_process_input(
&mut input,
&mut output,
"Hello, World!\n```\n",
"\n```\n",
) {
eprintln!("Error: {}", e);
std::process::exit(1);
} else {
std::process::exit(0);
}
}

let available_prompts: Vec<&String> = prompts.keys().collect();
let prompt_not_found_error = format!(
"Prompt {} not found, availables ones are: {:?}",
&args.prompt, &available_prompts
);

let prompt = prompts
.get_mut(&args.prompt)
.expect(&prompt_not_found_error);

println!("{:?}", prompt);

if let Err(e) = input_processing::process_input_with_request(prompt, &mut input, &mut output) {
eprintln!("Error: {}", e);
std::process::exit(1);
}
Expand Down
39 changes: 21 additions & 18 deletions src/request.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
use crate::config::get_api_key;
use serde::Deserialize;
use serde::{Deserialize, Serialize};

#[derive(Debug, Deserialize)]
pub struct Message {
Expand Down Expand Up @@ -32,24 +31,28 @@ pub struct OpenAiResponse {
pub system_fingerprint: String,
}

pub fn make_authenticated_request(text: &str) -> Result<ureq::Response, ureq::Error> {
let api_key = get_api_key();
pub fn make_authenticated_request(
api_key: &str,
data: impl Serialize,
) -> Result<ureq::Response, ureq::Error> {
println!("Trying to reach openai with {}", &api_key);
ureq::post("https://api.openai.com/v1/chat/completions")
.set("Content-Type", "application/json")
.set("Authorization", &format!("Bearer {}", api_key))
.send_json(ureq::json!({
"model": "gpt-4-1106-preview",
"messages": [
{
"role": "system",
"content": "You are a poetic assistant, skilled in explaining complex programming concepts with creative flair."
},
{
"role": "user",
"content": text
}
]
})
)
.send_json(data)
// .send_json(ureq::json!(
// {
// "model": "gpt-4-1106-preview",
// "messages": [
// {
// "role": "system",
// "content": "You are a poetic assistant, skilled in explaining complex programming concepts with creative flair."
// },
// {
// "role": "user",
// "content": data.messages.last().unwrap().content
// }
// ]
// })
// )
}
3 changes: 2 additions & 1 deletion tests/integration_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,15 @@ use std::io::{Read, Write};
use std::process::{Command, Stdio};

#[test]
fn test_program_integration() {
fn test_io() {
let hardcoded_prefix = "Hello, World!\n```\n";
let hardcoded_suffix = "\n```\n";
let input_data = "Input data";

// launch the program and get the streams
let mut child = Command::new("cargo")
.arg("run")
.arg("test")
.stdin(Stdio::piped())
.stdout(Stdio::piped())
.spawn()
Expand Down

0 comments on commit 752f1e7

Please sign in to comment.