Skip to content

Commit

Permalink
feat(apis): add anthropic
Browse files Browse the repository at this point in the history
  • Loading branch information
efugier committed Apr 3, 2024
1 parent a996728 commit bf643ad
Show file tree
Hide file tree
Showing 5 changed files with 143 additions and 49 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "smartcat"
version = "0.6.1"
version = "0.7.0"
authors = ["Emilien Fugier <[email protected]>"]
description = '''
Putting a brain behind `cat`.
Expand Down
14 changes: 7 additions & 7 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ select the generated impl block
:'<,'>!sc -e -i "can you make it more concise?"
```

put the cursor at the bottom of the file
put the cursor at the bottom of the file and give example usage as input

```
:'<,'>!sc -e -p "now write tests for it knowing it's used like this" -f src/main.rs
Expand Down Expand Up @@ -238,13 +238,18 @@ stores the latest chat if you need to continue it
```toml
[openai] # each supported api has their own config section with api and url
api_key = "<your_api_key>"
default_model = "gpt-4"
default_model = "gpt-4-turbo-preview"
url = "https://api.openai.com/v1/chat/completions"

[mistral]
api_key_command = "pass mistral/api_key" # you can use a command to grab the key
default_model = "mistral-medium"
url = "https://api.mistral.ai/v1/chat/completions"

[anthropic]
api_key = "<yet_another_api_key>"
url = "https://api.anthropic.com/v1/messages"
default_model = "claude-3-opus-20240229"
```

`prompts.toml`
Expand Down Expand Up @@ -310,9 +315,4 @@ Smartcat has reached an acceptable feature set. The focus is now on upgrading th
#### TODO

- [ ] make it available on homebrew
- [ ] refactor the prompt parameters into a struct

#### Ideas:

- interactive mode to have conversations and make the model iterate on the last answer (e.g. a flag `--start-conversation` to start and `--end-conversation` to end the current one, by default no conversation)
- fetch more context from the codebase
24 changes: 22 additions & 2 deletions src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ const CONVERSATION_FILE: &str = "conversation.toml";
pub enum Api {
Openai,
Mistral,
Anthropic,
AnotherApiForTests,
}

Expand All @@ -30,6 +31,7 @@ impl FromStr for Api {
match s.to_lowercase().as_str() {
"openai" => Ok(Api::Openai),
"mistral" => Ok(Api::Mistral),
"anthropic" => Ok(Api::Anthropic),
_ => Err(()),
}
}
Expand All @@ -40,6 +42,7 @@ impl ToString for Api {
match self {
Api::Openai => "openai".to_string(),
Api::Mistral => "mistral".to_string(),
Api::Anthropic => "anthropic".to_string(),
v => panic!(
"{:?} is not implemented, use one among {:?}",
v,
Expand Down Expand Up @@ -103,6 +106,15 @@ impl ApiConfig {
}
}

fn anthropic() -> Self {
ApiConfig {
api_key_command: None,
api_key: None,
url: String::from("https://api.anthropic.com/v1/messages"),
default_model: Some(String::from("claude-3-opus-20240229")),
}
}

fn default_with_api_key(api_key: Option<String>) -> Self {
ApiConfig {
api_key_command: None,
Expand Down Expand Up @@ -135,7 +147,7 @@ impl Default for Prompt {
Sometimes you will be asked to implement or extend some input code. Same thing goes here, write only what was asked because what you write will \
be directly added to the user's editor. \
Never ever write ``` around the code. \
Now let's make something great together! \
Make sure to keep the indentation and formatting. \
")
];
Prompt {
Expand Down Expand Up @@ -301,11 +313,14 @@ pub fn ensure_config_files(interactive: bool) -> std::io::Result<()> {

pub fn generate_api_keys_file(api_key: Option<String>) -> std::io::Result<()> {
let mut api_config = HashMap::new();
api_config.insert(Api::Openai.to_string(), ApiConfig::openai());
api_config.insert(Api::Mistral.to_string(), ApiConfig::mistral());
api_config.insert(Api::Anthropic.to_string(), ApiConfig::anthropic());
// Default, should override one of the above
api_config.insert(
Prompt::default().api.to_string(),
ApiConfig::default_with_api_key(api_key),
);
api_config.insert(Api::Mistral.to_string(), ApiConfig::mistral());

std::fs::create_dir_all(api_keys_path().parent().unwrap())?;

Expand Down Expand Up @@ -506,6 +521,11 @@ mod tests {
Some(&ApiConfig::mistral())
);

assert_eq!(
api_config.get(&Api::Anthropic.to_string()),
Some(&ApiConfig::anthropic())
);

let default_prompt = Prompt::default();
assert_eq!(prompt_config.get("default"), Some(&default_prompt));

Expand Down
28 changes: 7 additions & 21 deletions src/input_processing.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
use log::debug;
use std::io::{self, Read, Result, Write};
use std::io::{Read, Result, Write};

use crate::config::{get_api_config, Message, Prompt, PLACEHOLDER_TOKEN};
use crate::request::{make_authenticated_request, OpenAiResponse};
use crate::config::{get_api_config, Prompt, PLACEHOLDER_TOKEN};
use crate::request::make_api_request;

// [tmp] mostly template to write tests
pub fn chunk_process_input<R: Read, W: Write>(
Expand Down Expand Up @@ -65,32 +65,18 @@ pub fn process_input_with_request<R: Read, W: Write>(
let api_config = get_api_config(&prompt.api.to_string());

// make the request
let response: OpenAiResponse = make_authenticated_request(api_config, &prompt)
.map_err(|e| match e {
ureq::Error::Status(status, response) => {
let body = match response.into_string() {
Ok(body) => body,
Err(_) => "(non-UTF-8 response)".to_owned(),
};
io::Error::other(format!(
"API call failed with status code {status} and body: {body}"
))
}
ureq::Error::Transport(transport) => io::Error::other(transport),
})?
.into_json()?;
let response_message = make_api_request(api_config, &prompt)?;

let response_text = response.choices.first().unwrap().message.content.as_str();
debug!("{}", &response_text);
debug!("{}", &response_message.content);

prompt.messages.push(Message::assistant(response_text));
prompt.messages.push(response_message.clone());

if repeat_input {
input.push('\n');
output.write_all(input.as_bytes())?;
}

output.write_all(response_text.as_bytes())?;
output.write_all(response_message.content.as_bytes())?;

Ok(prompt)
}
Expand Down
124 changes: 106 additions & 18 deletions src/request.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,14 @@ use crate::config::{Api, ApiConfig, Message, Prompt};
use log::debug;
use serde::{Deserialize, Serialize};
use std::fmt::Debug;
use std::io;

/// One content block of an Anthropic `/v1/messages` response.
#[derive(Debug, Deserialize)]
pub struct AnthropicMessage {
    /// Text payload of the content block.
    pub text: String,
    /// Content block kind as reported by the API (e.g. "text").
    /// `type` is a Rust keyword, hence the field rename; `rename = "type"`
    /// covers both directions and is simpler than the serialize/deserialize
    /// pair (only `Deserialize` is derived anyway).
    #[serde(rename = "type")]
    pub _type: String,
}

#[derive(Debug, Deserialize)]
pub struct MessageWrapper {
Expand All @@ -13,6 +21,23 @@ pub struct OpenAiResponse {
pub choices: Vec<MessageWrapper>,
}

/// Top-level body of an Anthropic `/v1/messages` response.
#[derive(Debug, Deserialize)]
pub struct AnthropicResponse {
// Content blocks returned by the model; the first block's text is used
// as the answer (see `From<AnthropicResponse> for String`).
pub content: Vec<AnthropicMessage>,
}

impl From<AnthropicResponse> for String {
    /// Extracts the text of the first content block.
    ///
    /// Panics with a diagnostic message if the response contains no content
    /// blocks (previously a bare `unwrap`, which gave no context on failure).
    fn from(value: AnthropicResponse) -> Self {
        value
            .content
            .first()
            .expect("Anthropic response contained no content blocks")
            .text
            .to_owned()
    }
}

impl From<OpenAiResponse> for String {
    /// Extracts the message content of the first choice.
    ///
    /// Panics with a diagnostic message if the response contains no choices
    /// (previously a bare `unwrap`, which gave no context on failure).
    fn from(value: OpenAiResponse) -> Self {
        value
            .choices
            .first()
            .expect("OpenAI-style response contained no choices")
            .message
            .content
            .to_owned()
    }
}

#[derive(Debug, Deserialize, Serialize)]
pub struct OpenAiPrompt {
pub model: String,
Expand All @@ -21,6 +46,15 @@ pub struct OpenAiPrompt {
pub temperature: Option<f32>,
}

/// Request body for the Anthropic `/v1/messages` endpoint.
#[derive(Debug, Deserialize, Serialize)]
pub struct AnthropicPrompt {
// Model identifier, e.g. "claude-3-opus-20240229".
pub model: String,
// Conversation history; see `From<Prompt>` which merges consecutive
// same-role messages before sending.
pub messages: Vec<Message>,
// Optional sampling temperature; omitted from the JSON when unset.
#[serde(skip_serializing_if = "Option::is_none")]
pub temperature: Option<f32>,
// Required by the Anthropic API; hard-coded to 4096 in `From<Prompt>`.
pub max_tokens: i32,
}

impl From<Prompt> for OpenAiPrompt {
fn from(prompt: Prompt) -> OpenAiPrompt {
OpenAiPrompt {
Expand All @@ -33,10 +67,48 @@ impl From<Prompt> for OpenAiPrompt {
}
}

pub fn make_authenticated_request(
api_config: ApiConfig,
prompt: &Prompt,
) -> Result<ureq::Response, ureq::Error> {
impl From<Prompt> for AnthropicPrompt {
    /// Converts a generic prompt into Anthropic's request shape.
    ///
    /// Consecutive messages sharing a role are concatenated into a single
    /// message separated by a blank line — presumably because the messages
    /// API expects alternating roles (NOTE(review): confirm against the
    /// Anthropic API docs).
    ///
    /// # Panics
    /// Panics if `prompt.model` is `None`.
    fn from(prompt: Prompt) -> Self {
        let mut merged_messages: Vec<Message> = Vec::new();
        for message in prompt.messages {
            match merged_messages.last_mut() {
                Some(previous) if previous.role == message.role => {
                    previous.content.push_str("\n\n");
                    previous.content.push_str(&message.content);
                }
                _ => merged_messages.push(message),
            }
        }

        AnthropicPrompt {
            model: prompt.model.expect("model must be specified"),
            messages: merged_messages,
            temperature: prompt.temperature,
            max_tokens: 4096,
        }
    }
}

fn parse_response(response: Result<ureq::Response, ureq::Error>) -> io::Result<ureq::Response> {
response.map_err(|e| match e {
ureq::Error::Status(status, response) => {
let body = match response.into_string() {
Ok(body) => body,
Err(_) => "(non-UTF-8 response)".to_owned(),
};
io::Error::other(format!(
"API call failed with status code {status} and body: {body}"
))
}
ureq::Error::Transport(transport) => io::Error::other(transport),
})
}

pub fn make_api_request(api_config: ApiConfig, prompt: &Prompt) -> io::Result<Message> {
debug!(
"Trying to reach {:?} with {:?}",
api_config.url, api_config.api_key
Expand All @@ -48,19 +120,35 @@ pub fn make_authenticated_request(
prompt.model = api_config.default_model.clone()
}

let request = ureq::post(&api_config.url)
.set("Content-Type", "application/json")
.set(
"Authorization",
&format!("Bearer {}", &api_config.get_api_key()),
);
match prompt.api {
Api::Openai => request.send_json(OpenAiPrompt::from(prompt)),
Api::Mistral => request.send_json(OpenAiPrompt::from(prompt)),
v => panic!(
"{:?} is not implemented, use on among {:?}",
v,
vec![Api::Openai]
let request = ureq::post(&api_config.url);
let response_text = match prompt.api {
Api::Openai | Api::Mistral => {
let request = request.set("Content-Type", "application/json").set(
"Authorization",
&format!("Bearer {}", &api_config.get_api_key()),
);
let response: OpenAiResponse =
parse_response(request.send_json(OpenAiPrompt::from(prompt)))?.into_json()?;
response.into()
}
Api::Anthropic => {
let request = request
.set("Content-Type", "application/json")
.set("x-api-key", &api_config.get_api_key())
.set("anthropic-version", "2023-06-01");
let response: AnthropicResponse =
parse_response(request.send_json(AnthropicPrompt::from(prompt)))?.into_json()?;
response.into()
}
unknown_api => panic!(
"{:?} is not implemented, use one among {:?}",
unknown_api,
vec![Api::Openai, Api::Mistral, Api::Anthropic]
),
}
};

Ok(Message {
content: response_text,
role: "assistant".to_string(),
})
}

0 comments on commit bf643ad

Please sign in to comment.