Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 3 additions & 9 deletions crates/goose-server/src/commands/agent.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,8 @@
use std::sync::Arc;

use crate::configuration;
use crate::state;
use anyhow::Result;
use axum::middleware;
use etcetera::{choose_app_strategy, AppStrategy};
use goose::agents::Agent;
use goose::config::APP_STRATEGY;
use goose::scheduler_factory::SchedulerFactory;
use goose_server::auth::check_token;
Expand All @@ -32,10 +29,7 @@ pub async fn run() -> Result<()> {
let secret_key =
std::env::var("GOOSE_SERVER__SECRET_KEY").unwrap_or_else(|_| "test".to_string());

let new_agent = Agent::new();
let agent_ref = Arc::new(new_agent);

let app_state = state::AppState::new(agent_ref.clone());
let app_state = state::AppState::new().await;

let schedule_file_path = choose_app_strategy(APP_STRATEGY.clone())?
.data_dir()
Expand All @@ -44,8 +38,8 @@ pub async fn run() -> Result<()> {
let scheduler_instance = SchedulerFactory::create(schedule_file_path).await?;
app_state.set_scheduler(scheduler_instance.clone()).await;

// NEW: Provide scheduler access to the agent
agent_ref.set_scheduler(scheduler_instance).await;
// TODO: Once we have per-session agents, each agent will need scheduler access
// For now, we'll handle this when agents are created in AgentManager

let cors = CorsLayer::new()
.allow_origin(Any)
Expand Down
117 changes: 96 additions & 21 deletions crates/goose-server/src/routes/agent.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
use crate::state::AppState;
use axum::response::IntoResponse;

use axum::{
extract::{Query, State},
http::StatusCode,
response::IntoResponse,
routing::{get, post},
Json, Router,
};
Expand Down Expand Up @@ -101,22 +102,30 @@ pub struct ErrorResponse {
error: String,
}

#[derive(Serialize, utoipa::ToSchema)]
pub struct AgentStatsResponse {
agents_created: usize,
agents_cleaned: usize,
cache_hits: usize,
cache_misses: usize,
active_agents: usize,
}

#[utoipa::path(
post,
path = "/agent/start",
request_body = StartAgentRequest,
responses(
(status = 200, description = "Agent started successfully", body = StartAgentResponse),
(status = 400, description = "Bad request - invalid working directory"),
(status = 401, description = "Unauthorized - invalid secret key"),
(status = 500, description = "Internal server error")
(status = 401, description = "Unauthorized - invalid secret key")
)
)]
async fn start_agent(
State(state): State<Arc<AppState>>,
Json(payload): Json<StartAgentRequest>,
) -> Result<Json<StartAgentResponse>, StatusCode> {
state.reset().await;
// No longer reset the global agent - each session gets its own

let session_id = session::generate_session_id();
let counter = state.session_counter.fetch_add(1, Ordering::SeqCst) + 1;
Expand Down Expand Up @@ -196,14 +205,18 @@ async fn resume_agent(
responses(
(status = 200, description = "Added sub recipes to agent successfully", body = AddSubRecipesResponse),
(status = 401, description = "Unauthorized - invalid secret key"),
(status = 424, description = "Agent not initialized"),
),
(status = 424, description = "Agent not initialized")
)
)]
async fn add_sub_recipes(
State(state): State<Arc<AppState>>,
Json(payload): Json<AddSubRecipesRequest>,
) -> Result<Json<AddSubRecipesResponse>, StatusCode> {
let agent = state.get_agent().await;
let session_id = session::Identifier::Name(payload.session_id.clone());
let agent = state.get_agent(session_id).await.map_err(|e| {
tracing::error!("Failed to get agent for session: {}", e);
StatusCode::INTERNAL_SERVER_ERROR
})?;
agent.add_sub_recipes(payload.sub_recipes.clone()).await;
Ok(Json(AddSubRecipesResponse { success: true }))
}
Expand All @@ -215,14 +228,18 @@ async fn add_sub_recipes(
responses(
(status = 200, description = "Extended system prompt successfully", body = ExtendPromptResponse),
(status = 401, description = "Unauthorized - invalid secret key"),
(status = 424, description = "Agent not initialized"),
),
(status = 424, description = "Agent not initialized")
)
)]
async fn extend_prompt(
State(state): State<Arc<AppState>>,
Json(payload): Json<ExtendPromptRequest>,
) -> Result<Json<ExtendPromptResponse>, StatusCode> {
let agent = state.get_agent().await;
let session_id = session::Identifier::Name(payload.session_id.clone());
let agent = state.get_agent(session_id).await.map_err(|e| {
tracing::error!("Failed to get agent for session: {}", e);
StatusCode::INTERNAL_SERVER_ERROR
})?;
agent.extend_system_prompt(payload.extension.clone()).await;
Ok(Json(ExtendPromptResponse { success: true }))
}
Expand All @@ -236,9 +253,7 @@ async fn extend_prompt(
),
responses(
(status = 200, description = "Tools retrieved successfully", body = Vec<ToolInfo>),
(status = 401, description = "Unauthorized - invalid secret key"),
(status = 424, description = "Agent not initialized"),
(status = 500, description = "Internal server error")
(status = 401, description = "Unauthorized - invalid secret key")
)
)]
async fn get_tools(
Expand All @@ -247,7 +262,11 @@ async fn get_tools(
) -> Result<Json<Vec<ToolInfo>>, StatusCode> {
let config = Config::global();
let goose_mode = config.get_param("GOOSE_MODE").unwrap_or("auto".to_string());
let agent = state.get_agent().await;
let session_id = session::Identifier::Name(query.session_id.clone());
let agent = state.get_agent(session_id).await.map_err(|e| {
tracing::error!("Failed to get agent for session: {}", e);
StatusCode::INTERNAL_SERVER_ERROR
})?;
let permission_manager = PermissionManager::default();

let mut tools: Vec<ToolInfo> = agent
Expand Down Expand Up @@ -290,16 +309,18 @@ async fn get_tools(
responses(
(status = 200, description = "Provider updated successfully"),
(status = 400, description = "Bad request - missing or invalid parameters"),
(status = 401, description = "Unauthorized - invalid secret key"),
(status = 424, description = "Agent not initialized"),
(status = 500, description = "Internal server error")
(status = 401, description = "Unauthorized - invalid secret key")
)
)]
async fn update_agent_provider(
State(state): State<Arc<AppState>>,
Json(payload): Json<UpdateProviderRequest>,
) -> Result<StatusCode, impl IntoResponse> {
let agent = state.get_agent().await;
let session_id = session::Identifier::Name(payload.session_id.clone());
let agent = state.get_agent(session_id).await.map_err(|e| {
tracing::error!("Failed to get agent for session: {}", e);
(StatusCode::INTERNAL_SERVER_ERROR, String::new())
})?;
let config = Config::global();
let model = match payload
.model
Expand Down Expand Up @@ -344,9 +365,15 @@ async fn update_agent_provider(
)]
async fn update_router_tool_selector(
State(state): State<Arc<AppState>>,
Json(_payload): Json<UpdateRouterToolSelectorRequest>,
Json(payload): Json<UpdateRouterToolSelectorRequest>,
) -> Result<Json<String>, Json<ErrorResponse>> {
let agent = state.get_agent().await;
let session_id = session::Identifier::Name(payload.session_id.clone());
let agent = state.get_agent(session_id).await.map_err(|e| {
tracing::error!("Failed to get agent for session: {}", e);
Json(ErrorResponse {
error: format!("Failed to get agent: {}", e),
})
})?;
agent
.update_router_tool_selector(None, Some(true))
.await
Expand Down Expand Up @@ -377,7 +404,13 @@ async fn update_session_config(
State(state): State<Arc<AppState>>,
Json(payload): Json<SessionConfigRequest>,
) -> Result<Json<String>, Json<ErrorResponse>> {
let agent = state.get_agent().await;
let session_id = session::Identifier::Name(payload.session_id.clone());
let agent = state.get_agent(session_id).await.map_err(|e| {
tracing::error!("Failed to get agent for session: {}", e);
Json(ErrorResponse {
error: format!("Failed to get agent: {}", e),
})
})?;
if let Some(response) = payload.response {
agent.add_final_output_tool(response).await;

Expand All @@ -390,6 +423,46 @@ async fn update_session_config(
}
}

#[utoipa::path(
get,
path = "/agent/stats",
responses(
(status = 200, description = "Agent statistics retrieved successfully", body = AgentStatsResponse),
(status = 401, description = "Unauthorized - invalid secret key"),
(status = 500, description = "Internal server error")
)
)]
async fn get_agent_stats(
State(state): State<Arc<AppState>>,
) -> Result<Json<AgentStatsResponse>, StatusCode> {
let metrics = state.get_agent_metrics().await;

Ok(Json(AgentStatsResponse {
agents_created: metrics.agents_created,
agents_cleaned: metrics.agents_cleaned,
cache_hits: metrics.cache_hits,
cache_misses: metrics.cache_misses,
active_agents: metrics.active_agents,
}))
}

#[utoipa::path(
post,
path = "/agent/cleanup",
responses(
(status = 200, description = "Agent cleanup completed successfully", body = String),
(status = 401, description = "Unauthorized - invalid secret key"),
(status = 500, description = "Internal server error")
)
)]
async fn cleanup_agents(State(state): State<Arc<AppState>>) -> Result<Json<String>, StatusCode> {
let cleaned = state.cleanup_idle_agents().await.map_err(|e| {
tracing::error!("Failed to cleanup agents: {}", e);
StatusCode::INTERNAL_SERVER_ERROR
})?;
Ok(Json(format!("Cleaned up {} idle agents", cleaned)))
}

pub fn routes(state: Arc<AppState>) -> Router {
Router::new()
.route("/agent/start", post(start_agent))
Expand All @@ -403,5 +476,7 @@ pub fn routes(state: Arc<AppState>) -> Router {
)
.route("/agent/session_config", post(update_session_config))
.route("/agent/add_sub_recipes", post(add_sub_recipes))
.route("/agent/stats", get(get_agent_stats))
.route("/agent/cleanup", post(cleanup_agents))
.with_state(state)
}
8 changes: 4 additions & 4 deletions crates/goose-server/src/routes/audio.rs
Original file line number Diff line number Diff line change
Expand Up @@ -395,7 +395,7 @@ mod tests {

#[tokio::test]
async fn test_transcribe_endpoint_requires_auth() {
let state = AppState::new(Arc::new(goose::agents::Agent::new()));
let state = AppState::new().await;
let app = routes(state);

// Test without auth header
Expand All @@ -418,7 +418,7 @@ mod tests {

#[tokio::test]
async fn test_transcribe_endpoint_validates_size() {
let state = AppState::new(Arc::new(goose::agents::Agent::new()));
let state = AppState::new().await;
let app = routes(state);

// Create a large base64 string (simulating > 25MB audio)
Expand All @@ -444,7 +444,7 @@ mod tests {

#[tokio::test]
async fn test_transcribe_endpoint_validates_mime_type() {
let state = AppState::new(Arc::new(goose::agents::Agent::new()));
let state = AppState::new().await;
let app = routes(state);

let request = Request::builder()
Expand All @@ -470,7 +470,7 @@ mod tests {

#[tokio::test]
async fn test_transcribe_endpoint_handles_invalid_base64() {
let state = AppState::new(Arc::new(goose::agents::Agent::new()));
let state = AppState::new().await;
let app = routes(state);

let request = Request::builder()
Expand Down
3 changes: 1 addition & 2 deletions crates/goose-server/src/routes/config_management.rs
Original file line number Diff line number Diff line change
Expand Up @@ -767,9 +767,8 @@ pub fn routes(state: Arc<AppState>) -> Router {

#[cfg(test)]
mod tests {
use http::HeaderMap;

use super::*;
use http::HeaderMap;

#[tokio::test]
async fn test_read_model_limits() {
Expand Down
9 changes: 8 additions & 1 deletion crates/goose-server/src/routes/context.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use crate::state::AppState;
use axum::{extract::State, http::StatusCode, routing::post, Json, Router};
use goose::conversation::{message::Message, Conversation};
use goose::session;
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use utoipa::ToSchema;
Expand All @@ -13,6 +14,8 @@ pub struct ContextManageRequest {
pub messages: Vec<Message>,
/// Operation to perform: "truncation" or "summarize"
pub manage_action: String,
/// Session ID for the context management
pub session_id: String,
}

/// Response from context management operations
Expand Down Expand Up @@ -44,7 +47,11 @@ async fn manage_context(
State(state): State<Arc<AppState>>,
Json(request): Json<ContextManageRequest>,
) -> Result<Json<ContextManageResponse>, StatusCode> {
let agent = state.get_agent().await;
let session_id = session::Identifier::Name(request.session_id.clone());
let agent = state.get_agent(session_id).await.map_err(|e| {
tracing::error!("Failed to get agent for session: {}", e);
StatusCode::INTERNAL_SERVER_ERROR
})?;

let mut processed_messages = Conversation::new_unvalidated(vec![]);
let mut token_counts: Vec<usize> = vec![];
Expand Down
Loading
Loading