Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion crates/goose-server/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ axum = { version = "0.8.1", features = ["ws", "macros"] }
tokio = { version = "1.43", features = ["full"] }
chrono = "0.4"
tokio-cron-scheduler = "0.14.0"
tower-http = { version = "0.5", features = ["cors"] }
tower-http = { version = "0.5", features = ["cors", "compression-gzip", "compression-br"] }
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
futures = "0.3"
Expand Down
8 changes: 7 additions & 1 deletion crates/goose-server/src/commands/agent.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use etcetera::{choose_app_strategy, AppStrategy};
use goose::agents::Agent;
use goose::config::APP_STRATEGY;
use goose::scheduler_factory::SchedulerFactory;
use tower_http::compression::CompressionLayer;
use tower_http::cors::{Any, CorsLayer};
use tracing::info;

Expand Down Expand Up @@ -50,7 +51,12 @@ pub async fn run() -> Result<()> {
.allow_methods(Any)
.allow_headers(Any);

let app = crate::routes::configure(app_state).layer(cors);
// Add compression middleware for gzip and brotli
let compression = CompressionLayer::new().gzip(true).br(true);

let app = crate::routes::configure(app_state)
.layer(cors)
.layer(compression);

let listener = tokio::net::TcpListener::bind(settings.socket_addr()).await?;
info!("listening on {}", listener.local_addr()?);
Expand Down
57 changes: 55 additions & 2 deletions crates/goose-server/src/routes/config_management.rs
Original file line number Diff line number Diff line change
Expand Up @@ -333,10 +333,18 @@ pub struct PricingResponse {
pub source: String,
}

/// A single provider/model pair for which pricing is being requested.
#[derive(Deserialize, ToSchema)]
pub struct ModelRequest {
/// Provider identifier — matched against provider metadata names. // NOTE(review): presumably the same name used by `get_providers()`; verify against caller
pub provider: String,
/// Model name as known to that provider.
pub model: String,
}

/// Request body/query for the pricing endpoint, controlling which models'
/// pricing data is returned.
#[derive(Deserialize, ToSchema)]
pub struct PricingQuery {
/// If true, only return pricing for configured providers. If false, return all.
pub configured_only: Option<bool>,
/// Specific models to fetch pricing for. If provided, only these models will be returned.
pub models: Option<Vec<ModelRequest>>,
}

#[utoipa::path(
Expand All @@ -355,6 +363,7 @@ pub async fn get_pricing(
verify_secret_key(&headers, &state)?;

let configured_only = query.configured_only.unwrap_or(true);
let has_specific_models = query.models.is_some();

// If refresh requested (configured_only = false), refresh the cache
if !configured_only {
Expand All @@ -365,7 +374,49 @@ pub async fn get_pricing(

let mut pricing_data = Vec::new();

if !configured_only {
// If specific models are requested, fetch only those
if let Some(requested_models) = query.models {
for model_req in requested_models {
// Try to get pricing from cache
if let Some(pricing) = get_model_pricing(&model_req.provider, &model_req.model).await {
pricing_data.push(PricingData {
provider: model_req.provider,
model: model_req.model,
input_token_cost: pricing.input_cost,
output_token_cost: pricing.output_cost,
currency: "$".to_string(),
context_length: pricing.context_length,
});
}
// Check if the model has embedded pricing data from provider metadata
else if let Some(metadata) = get_providers()
.iter()
.find(|p| p.name == model_req.provider)
{
if let Some(model_info) = metadata
.known_models
.iter()
.find(|m| m.name == model_req.model)
{
if let (Some(input_cost), Some(output_cost)) =
(model_info.input_token_cost, model_info.output_token_cost)
{
pricing_data.push(PricingData {
provider: model_req.provider,
model: model_req.model,
input_token_cost: input_cost,
output_token_cost: output_cost,
currency: model_info
.currency
.clone()
.unwrap_or_else(|| "$".to_string()),
context_length: Some(model_info.context_limit as u32),
});
}
}
}
}
} else if !configured_only {
// Get ALL pricing data from the cache
let all_pricing = get_all_pricing().await;

Expand Down Expand Up @@ -425,7 +476,9 @@ pub async fn get_pricing(
tracing::debug!(
"Returning pricing for {} models{}",
pricing_data.len(),
if configured_only {
if has_specific_models {
" (specific models requested)"
} else if configured_only {
" (configured providers only)"
} else {
" (all cached models)"
Expand Down
4 changes: 4 additions & 0 deletions crates/goose/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -115,3 +115,7 @@ path = "examples/async_token_counter_demo.rs"
[[bench]]
name = "tokenization_benchmark"
harness = false

[[bench]]
name = "connection_pooling"
harness = false
102 changes: 102 additions & 0 deletions crates/goose/benches/connection_pooling.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criterion};
use goose::providers::provider_common::{create_provider_client, get_shared_client};
use reqwest::Client;
use std::sync::Arc;
use tokio::runtime::Runtime;

/// Measures the cost of constructing a brand-new provider HTTP client
/// on every iteration (the "no pooling" baseline).
fn create_new_clients(c: &mut Criterion) {
    // Dedicated runtime so runtime construction is excluded from the timing.
    let runtime = Runtime::new().expect("failed to build tokio runtime");

    c.bench_function("create_new_client", |b| {
        b.iter(|| {
            runtime.block_on(async {
                // black_box keeps the optimizer from eliding the construction.
                black_box(create_provider_client(Some(600)).unwrap());
            });
        });
    });
}

/// Measures the cost of fetching the process-wide shared client,
/// the pooled counterpart to `create_new_clients`.
fn reuse_shared_client(c: &mut Criterion) {
    let runtime = Runtime::new().expect("failed to build tokio runtime");

    c.bench_function("get_shared_client", |b| {
        b.iter(|| {
            runtime.block_on(async {
                // black_box prevents the lookup from being optimized away.
                black_box(get_shared_client());
            });
        });
    });
}

/// Benchmarks N concurrent tasks where each task builds its own client,
/// for N in {10, 50, 100}.
fn concurrent_requests_new_clients(c: &mut Criterion) {
    let runtime = Runtime::new().expect("failed to build tokio runtime");

    let mut group = c.benchmark_group("concurrent_requests_new");
    for &num_requests in &[10, 50, 100] {
        group.bench_with_input(
            BenchmarkId::from_parameter(num_requests),
            &num_requests,
            |b, &n| {
                b.iter(|| {
                    runtime.block_on(async {
                        // Spawn all tasks first, then await them, so the
                        // client constructions actually overlap.
                        let mut handles = Vec::new();
                        for _ in 0..n {
                            handles.push(tokio::spawn(async move {
                                let client = create_provider_client(Some(600)).unwrap();
                                // Simulate a request (without actually making one)
                                black_box(&client);
                            }));
                        }

                        for handle in handles {
                            handle.await.unwrap();
                        }
                    });
                });
            },
        );
    }
    group.finish();
}

/// Benchmarks N concurrent tasks that all reuse the shared client,
/// for N in {10, 50, 100} — the pooled counterpart to
/// `concurrent_requests_new_clients`.
fn concurrent_requests_shared_client(c: &mut Criterion) {
    let runtime = Runtime::new().expect("failed to build tokio runtime");

    let mut group = c.benchmark_group("concurrent_requests_shared");
    for &num_requests in &[10, 50, 100] {
        group.bench_with_input(
            BenchmarkId::from_parameter(num_requests),
            &num_requests,
            |b, &n| {
                b.iter(|| {
                    runtime.block_on(async {
                        // Spawn first, await second, so the lookups overlap.
                        let mut handles = Vec::new();
                        for _ in 0..n {
                            handles.push(tokio::spawn(async move {
                                let client = get_shared_client();
                                // Simulate a request (without actually making one)
                                black_box(&client);
                            }));
                        }

                        for handle in handles {
                            handle.await.unwrap();
                        }
                    });
                });
            },
        );
    }
    group.finish();
}

// Register every connection-pooling benchmark with the criterion harness
// and emit the harness entry point (`harness = false` in Cargo.toml).
criterion_group!(
benches,
create_new_clients,
reuse_shared_client,
concurrent_requests_new_clients,
concurrent_requests_shared_client
);
criterion_main!(benches);
Loading
Loading