Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: hyperbolic inference API #238

Merged
merged 3 commits into from
Jan 24, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions rig-core/examples/agent_with_hyperbolic.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
use std::env;

use rig::{completion::Prompt, providers};

#[tokio::main]
async fn main() -> Result<(), anyhow::Error> {
// Create OpenAI client
let client = providers::hyperbolic::Client::new(
&env::var("HYPERBOLIC_API_KEY").expect("HYPERBOLIC_API_KEY not set"),
);

// Create agent with a single context prompt
let comedian_agent = client
.agent(rig::providers::hyperbolic::DEEPSEEK_R1)
.preamble("You are a comedian here to entertain the user using humour and jokes.")
.build();

// Prompt the agent and print the response
let response = comedian_agent.prompt("Entertain me!").await?;
println!("{}", response);

Ok(())
}
311 changes: 311 additions & 0 deletions rig-core/src/providers/hyperbolic.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,311 @@
//! Hyperbolic Inference API client and Rig integration
//!
//! # Example
//! ```
//! use rig::providers::hyperbolic;
//!
//! let client = hyperbolic::Client::new("YOUR_API_KEY");
//!
//! let llama_3_1_8b = client.completion_model(hyperbolic::LLAMA_3_1_8B);
//! ```

use crate::{
agent::AgentBuilder,
completion::{self, CompletionError, CompletionRequest},
extractor::ExtractorBuilder,
json_utils,
};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use serde_json::json;

// ================================================================
// Main Hyperbolic Client
// ================================================================
const HYPERBOLIC_API_BASE_URL: &str = "https://api.hyperbolic.xyz/v1";

#[derive(Clone)]
pub struct Client {
base_url: String,
http_client: reqwest::Client,
}

impl Client {
/// Create a new Hyperbolic client with the given API key.
pub fn new(api_key: &str) -> Self {
Self::from_url(api_key, HYPERBOLIC_API_BASE_URL)
}

/// Create a new OpenAI client with the given API key and base API URL.
pub fn from_url(api_key: &str, base_url: &str) -> Self {
Self {
base_url: base_url.to_string(),
http_client: reqwest::Client::builder()
.default_headers({
let mut headers = reqwest::header::HeaderMap::new();
headers.insert(
"Authorization",
format!("Bearer {}", api_key)
.parse()
.expect("Bearer token should parse"),
);
headers
})
.build()
.expect("OpenAI reqwest client should build"),
}
}

/// Create a new Hyperbolic client from the `HYPERBOLIC_API_KEY` environment variable.
/// Panics if the environment variable is not set.
pub fn from_env() -> Self {
let api_key = std::env::var("HYPERBOLIC_API_KEY").expect("HYPERBOLIC_API_KEY not set");
Self::new(&api_key)
}

fn post(&self, path: &str) -> reqwest::RequestBuilder {
let url = format!("{}/{}", self.base_url, path).replace("//", "/");
self.http_client.post(url)
}

/// Create a completion model with the given name.
///
/// # Example
/// ```
/// use rig::providers::hyperbolic::{Client, self};
///
/// // Initialize the Hyperbolic client
/// let hyperbolic = Client::new("your-hyperbolic-api-key");
///
/// let llama_3_1_8b = hyperbolic.completion_model(hyperbolic::LLAMA_3_1_8B);
/// ```
pub fn completion_model(&self, model: &str) -> CompletionModel {
CompletionModel::new(self.clone(), model)
}

/// Create an agent builder with the given completion model.
///
/// # Example
/// ```
/// use rig::providers::hyperbolic::{Client, self};
///
/// // Initialize the Eternal client
/// let hyperbolic = Client::new("your-hyperbolic-api-key");
///
/// let agent = hyperbolic.agent(hyperbolic::LLAMA_3_1_8B)
/// .preamble("You are comedian AI with a mission to make people laugh.")
/// .temperature(0.0)
/// .build();
/// ```
pub fn agent(&self, model: &str) -> AgentBuilder<CompletionModel> {
AgentBuilder::new(self.completion_model(model))
}

/// Create an extractor builder with the given completion model.
pub fn extractor<T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync>(
&self,
model: &str,
) -> ExtractorBuilder<T, CompletionModel> {
ExtractorBuilder::new(self.completion_model(model))
}
}

#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
message: String,
}

#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
Ok(T),
Err(ApiErrorResponse),
}

#[derive(Debug, Deserialize)]
pub struct EmbeddingData {
pub object: String,
pub embedding: Vec<f64>,
pub index: usize,
}

#[derive(Clone, Debug, Deserialize)]
pub struct Usage {
pub prompt_tokens: usize,
pub total_tokens: usize,
}

impl std::fmt::Display for Usage {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
f,
"Prompt tokens: {} Total tokens: {}",
self.prompt_tokens, self.total_tokens
)
}
}

// ================================================================
// Hyperbolic Completion API
// ================================================================
/// Meta Llama 3.1b Instruct model with 8B parameters.
pub const LLAMA_3_1_8B: &str = "meta-llama/Meta-Llama-3.1-8B-Instruct";
/// Meta Llama 3.3b Instruct model with 70B parameters.
pub const LLAMA_3_3_70B: &str = "meta-llama/Llama-3.3-70B-Instruct";
/// Meta Llama 3.1b Instruct model with 70B parameters.
pub const LLAMA_3_1_70B: &str = "meta-llama/Meta-Llama-3.1-70B-Instruct";
/// Meta Llama 3 Instruct model with 70B parameters.
pub const LLAMA_3_70B: &str = "meta-llama/Meta-Llama-3-70B-Instruct";
/// Hermes 3 Instruct model with 70B parameters.
pub const HERMES_3_70B: &str = "NousResearch/Hermes-3-Llama-3.1-70b";
/// Deepseek v2.5 model.
pub const DEEPSEEK_2_5: &str = "deepseek-ai/DeepSeek-V2.5";
/// Qwen 2.5 model with 72B parameters.
pub const QWEN_2_5_72B: &str = "Qwen/Qwen2.5-72B-Instruct";
/// Meta Llama 3.2b Instruct model with 3B parameters.
pub const LLAMA_3_2_3B: &str = "meta-llama/Llama-3.2-3B-Instruct";
/// Qwen 2.5 Coder Instruct model with 32B parameters.
pub const QWEN_2_5_CODER_32B: &str = "Qwen/Qwen2.5-Coder-32B-Instruct";
/// Preview (latest) version of Qwen model with 32B parameters.
pub const QWEN_QWQ_PREVIEW_32B: &str = "Qwen/QwQ-32B-Preview";
/// Deepseek R1 Zero model.
pub const DEEPSEEK_R1_ZERO: &str = "deepseek-ai/DeepSeek-R1-Zero";
/// Deepseek R1 model.
pub const DEEPSEEK_R1: &str = "deepseek-ai/DeepSeek-R1";

/// A Hyperbolic completion object.
///
/// For more information, see this link: <https://docs.hyperbolic.xyz/reference/create_chat_completion_v1_chat_completions_post>
#[derive(Debug, Deserialize)]
pub struct CompletionResponse {
pub id: String,
pub object: String,
pub created: u64,
pub model: String,
pub choices: Vec<Choice>,
pub usage: Option<Usage>,
}

impl From<ApiErrorResponse> for CompletionError {
fn from(err: ApiErrorResponse) -> Self {
CompletionError::ProviderError(err.message)
}
}

impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
type Error = CompletionError;

fn try_from(value: CompletionResponse) -> std::prelude::v1::Result<Self, Self::Error> {
match value.choices.as_slice() {
[Choice {
message:
Message {
content: Some(content),
..
},
..
}, ..] => Ok(completion::CompletionResponse {
choice: completion::ModelChoice::Message(content.to_string()),
raw_response: value,
}),
_ => Err(CompletionError::ResponseError(
"Response did not contain a message".into(),
)),
}
}
}

#[derive(Debug, Deserialize)]
pub struct Choice {
pub index: usize,
pub message: Message,
pub finish_reason: String,
}

#[derive(Debug, Deserialize)]
pub struct Message {
pub role: String,
pub content: Option<String>,
}

#[derive(Clone)]
pub struct CompletionModel {
client: Client,
/// Name of the model (e.g.: deepseek-ai/DeepSeek-R1)
pub model: String,
}

impl CompletionModel {
pub fn new(client: Client, model: &str) -> Self {
Self {
client,
model: model.to_string(),
}
}
}

impl completion::CompletionModel for CompletionModel {
type Response = CompletionResponse;

async fn completion(
&self,
mut completion_request: CompletionRequest,
) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
// Add preamble to chat history (if available)
let mut full_history = if let Some(preamble) = &completion_request.preamble {
vec![completion::Message {
role: "system".into(),
content: preamble.clone(),
}]
} else {
vec![]
};

// Extend existing chat history
full_history.append(&mut completion_request.chat_history);

// Add context documents to chat history
let prompt_with_context = completion_request.prompt_with_context();

// Add context documents to chat history
full_history.push(completion::Message {
role: "user".into(),
content: prompt_with_context,
});

let request = json!({
"model": self.model,
"messages": full_history,
"temperature": completion_request.temperature,
});

let response = self
.client
.post("/chat/completions")
.json(
&if let Some(params) = completion_request.additional_params {
json_utils::merge(request, params)
} else {
request
},
)
.send()
.await?;

if response.status().is_success() {
match response.json::<ApiResponse<CompletionResponse>>().await? {
ApiResponse::Ok(response) => {
tracing::info!(target: "rig",
"Hyperbolic completion token usage: {:?}",
response.usage.clone().map(|usage| format!("{usage}")).unwrap_or("N/A".to_string())
);

response.try_into()
}
ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
}
} else {
Err(CompletionError::ProviderError(response.text().await?))
}
}
}
1 change: 1 addition & 0 deletions rig-core/src/providers/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ pub mod anthropic;
pub mod cohere;
pub mod eternalai;
pub mod gemini;
pub mod hyperbolic;
pub mod openai;
pub mod perplexity;
pub mod xai;
Loading