-
Notifications
You must be signed in to change notification settings - Fork 2.4k
Add session to agents #4216
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add session to agents #4216
Changes from all commits
147c17b
6943f23
8fb862c
7e0b2cd
4f7064c
9a8d5e3
9beeba0
1f6a286
a6600d3
60c04fe
bc5fc11
c31b33e
e530f60
4739454
818c8df
fa6fe8a
ba694b2
efcdc17
2bfd409
0b03735
155a483
2778f3e
c94e16b
a772000
ef9a0d8
8a97540
640eb7e
d23ca9a
58c5c2f
2e29c3f
e2339c6
f8dbb44
2e68030
52bdc41
9a0a779
93ad89b
e10d469
6b0b59d
451bd56
8cef0a6
f90dce0
af18adc
64a1a2f
1caf7e0
6deb9d2
e5f4c0f
2d6721f
2877bfc
2cfb363
7f3f004
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,26 +1,36 @@ | ||
| use super::utils::verify_secret_key; | ||
| use crate::state::AppState; | ||
| use axum::response::IntoResponse; | ||
| use axum::{ | ||
| extract::{Query, State}, | ||
| http::{HeaderMap, StatusCode}, | ||
| routing::{get, post}, | ||
| Json, Router, | ||
| }; | ||
| use goose::config::PermissionManager; | ||
| use goose::conversation::message::Message; | ||
| use goose::conversation::Conversation; | ||
| use goose::model::ModelConfig; | ||
| use goose::providers::create; | ||
| use goose::recipe::Response; | ||
| use goose::recipe::{Recipe, Response}; | ||
| use goose::session; | ||
| use goose::session::SessionMetadata; | ||
| use goose::{ | ||
| agents::{extension::ToolInfo, extension_manager::get_parameter_names}, | ||
| config::permission::PermissionLevel, | ||
| }; | ||
| use goose::{config::Config, recipe::SubRecipe}; | ||
| use serde::{Deserialize, Serialize}; | ||
| use std::path::PathBuf; | ||
| use std::sync::atomic::Ordering; | ||
| use std::sync::Arc; | ||
| use tracing::error; | ||
|
|
||
| #[derive(Deserialize, utoipa::ToSchema)] | ||
| pub struct ExtendPromptRequest { | ||
| extension: String, | ||
| #[allow(dead_code)] | ||
| session_id: String, | ||
| } | ||
|
|
||
| #[derive(Serialize, utoipa::ToSchema)] | ||
|
|
@@ -31,6 +41,8 @@ pub struct ExtendPromptResponse { | |
| #[derive(Deserialize, utoipa::ToSchema)] | ||
| pub struct AddSubRecipesRequest { | ||
| sub_recipes: Vec<SubRecipe>, | ||
| #[allow(dead_code)] | ||
| session_id: String, | ||
| } | ||
|
|
||
| #[derive(Serialize, utoipa::ToSchema)] | ||
|
|
@@ -42,23 +54,149 @@ pub struct AddSubRecipesResponse { | |
| pub struct UpdateProviderRequest { | ||
| provider: String, | ||
| model: Option<String>, | ||
| #[allow(dead_code)] | ||
| session_id: String, | ||
| } | ||
|
|
||
| #[derive(Deserialize, utoipa::ToSchema)] | ||
| pub struct SessionConfigRequest { | ||
| response: Option<Response>, | ||
| #[allow(dead_code)] | ||
| session_id: String, | ||
| } | ||
|
|
||
| #[derive(Deserialize, utoipa::ToSchema)] | ||
| pub struct GetToolsQuery { | ||
| extension_name: Option<String>, | ||
| #[allow(dead_code)] | ||
| session_id: String, | ||
| } | ||
|
|
||
| #[derive(Deserialize, utoipa::ToSchema)] | ||
| pub struct UpdateRouterToolSelectorRequest { | ||
| #[allow(dead_code)] | ||
| session_id: String, | ||
| } | ||
|
|
||
| #[derive(Deserialize, utoipa::ToSchema)] | ||
| pub struct StartAgentRequest { | ||
| working_dir: String, | ||
| recipe: Option<Recipe>, | ||
| } | ||
|
|
||
| #[derive(Deserialize, utoipa::ToSchema)] | ||
| pub struct ResumeAgentRequest { | ||
| session_id: String, | ||
| } | ||
|
|
||
| // This is the same as SessionHistoryResponse | ||
| #[derive(Serialize, utoipa::ToSchema)] | ||
| pub struct StartAgentResponse { | ||
| session_id: String, | ||
| metadata: SessionMetadata, | ||
| messages: Vec<Message>, | ||
| } | ||
|
|
||
| #[derive(Serialize, utoipa::ToSchema)] | ||
| pub struct ErrorResponse { | ||
| error: String, | ||
| } | ||
|
|
||
| #[utoipa::path( | ||
| post, | ||
| path = "/agent/start", | ||
| request_body = StartAgentRequest, | ||
| responses( | ||
| (status = 200, description = "Agent started successfully", body = StartAgentResponse), | ||
| (status = 400, description = "Bad request - invalid working directory"), | ||
| (status = 401, description = "Unauthorized - invalid secret key"), | ||
| (status = 500, description = "Internal server error") | ||
| ) | ||
| )] | ||
| async fn start_agent( | ||
| State(state): State<Arc<AppState>>, | ||
| headers: HeaderMap, | ||
| Json(payload): Json<StartAgentRequest>, | ||
| ) -> Result<Json<StartAgentResponse>, StatusCode> { | ||
| verify_secret_key(&headers, &state)?; | ||
|
|
||
| state.reset().await; | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Would expect to append the newly started agent to the state rather then resetting/overriding it? |
||
|
|
||
| let session_id = session::generate_session_id(); | ||
| let counter = state.session_counter.fetch_add(1, Ordering::SeqCst) + 1; | ||
|
|
||
| let metadata = SessionMetadata { | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. does SessionMetadata::new() not do what we want?
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. well, I added a description with a counter to this so we can better see which session is new etc. but you are right, I should just move that to the new method |
||
| working_dir: PathBuf::from(&payload.working_dir), | ||
| description: format!("New session {}", counter), | ||
| schedule_id: None, | ||
| message_count: 0, | ||
| total_tokens: Some(0), | ||
| input_tokens: Some(0), | ||
| output_tokens: Some(0), | ||
| accumulated_total_tokens: Some(0), | ||
| accumulated_input_tokens: Some(0), | ||
| accumulated_output_tokens: Some(0), | ||
| extension_data: Default::default(), | ||
| recipe: payload.recipe, | ||
| }; | ||
|
|
||
| let session_path = match session::get_path(session::Identifier::Name(session_id.clone())) { | ||
| Ok(path) => path, | ||
| Err(_) => return Err(StatusCode::BAD_REQUEST), | ||
| }; | ||
|
|
||
| let conversation = Conversation::empty(); | ||
| session::storage::save_messages_with_metadata(&session_path, &metadata, &conversation) | ||
| .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; | ||
|
|
||
| Ok(Json(StartAgentResponse { | ||
| session_id, | ||
| metadata, | ||
| messages: conversation.messages().clone(), | ||
| })) | ||
| } | ||
|
|
||
| #[utoipa::path( | ||
| post, | ||
| path = "/agent/resume", | ||
| request_body = ResumeAgentRequest, | ||
| responses( | ||
| (status = 200, description = "Agent started successfully", body = StartAgentResponse), | ||
| (status = 400, description = "Bad request - invalid working directory"), | ||
| (status = 401, description = "Unauthorized - invalid secret key"), | ||
| (status = 500, description = "Internal server error") | ||
| ) | ||
| )] | ||
| async fn resume_agent( | ||
| State(state): State<Arc<AppState>>, | ||
| headers: HeaderMap, | ||
| Json(payload): Json<ResumeAgentRequest>, | ||
| ) -> Result<Json<StartAgentResponse>, StatusCode> { | ||
| verify_secret_key(&headers, &state)?; | ||
|
|
||
| let session_path = | ||
| match session::get_path(session::Identifier::Name(payload.session_id.clone())) { | ||
| Ok(path) => path, | ||
| Err(_) => return Err(StatusCode::BAD_REQUEST), | ||
| }; | ||
|
|
||
| let metadata = session::read_metadata(&session_path).map_err(|_| StatusCode::NOT_FOUND)?; | ||
|
|
||
| let conversation = match session::read_messages(&session_path) { | ||
| Ok(messages) => messages, | ||
| Err(e) => { | ||
| error!("Failed to read session messages: {:?}", e); | ||
| return Err(StatusCode::NOT_FOUND); | ||
| } | ||
| }; | ||
|
|
||
| Ok(Json(StartAgentResponse { | ||
| session_id: payload.session_id.clone(), | ||
| metadata, | ||
| messages: conversation.messages().clone(), | ||
| })) | ||
| } | ||
|
|
||
| #[utoipa::path( | ||
| post, | ||
| path = "/agent/add_sub_recipes", | ||
|
|
@@ -76,10 +214,7 @@ async fn add_sub_recipes( | |
| ) -> Result<Json<AddSubRecipesResponse>, StatusCode> { | ||
| verify_secret_key(&headers, &state)?; | ||
|
|
||
| let agent = state | ||
| .get_agent() | ||
| .await | ||
| .map_err(|_| StatusCode::PRECONDITION_FAILED)?; | ||
| let agent = state.get_agent().await; | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. we don't need that PRECONDITION_FAILED? anymore? (not sure what we would do with it anyway) |
||
| agent.add_sub_recipes(payload.sub_recipes.clone()).await; | ||
| Ok(Json(AddSubRecipesResponse { success: true })) | ||
| } | ||
|
|
@@ -101,10 +236,7 @@ async fn extend_prompt( | |
| ) -> Result<Json<ExtendPromptResponse>, StatusCode> { | ||
| verify_secret_key(&headers, &state)?; | ||
|
|
||
| let agent = state | ||
| .get_agent() | ||
| .await | ||
| .map_err(|_| StatusCode::PRECONDITION_FAILED)?; | ||
| let agent = state.get_agent().await; | ||
| agent.extend_system_prompt(payload.extension.clone()).await; | ||
| Ok(Json(ExtendPromptResponse { success: true })) | ||
| } | ||
|
|
@@ -113,7 +245,8 @@ async fn extend_prompt( | |
| get, | ||
| path = "/agent/tools", | ||
| params( | ||
| ("extension_name" = Option<String>, Query, description = "Optional extension name to filter tools") | ||
| ("extension_name" = Option<String>, Query, description = "Optional extension name to filter tools"), | ||
| ("session_id" = String, Query, description = "Required session ID to scope tools to a specific session") | ||
| ), | ||
| responses( | ||
| (status = 200, description = "Tools retrieved successfully", body = Vec<ToolInfo>), | ||
|
|
@@ -131,10 +264,7 @@ async fn get_tools( | |
|
|
||
| let config = Config::global(); | ||
| let goose_mode = config.get_param("GOOSE_MODE").unwrap_or("auto".to_string()); | ||
| let agent = state | ||
| .get_agent() | ||
| .await | ||
| .map_err(|_| StatusCode::PRECONDITION_FAILED)?; | ||
| let agent = state.get_agent().await; | ||
| let permission_manager = PermissionManager::default(); | ||
|
|
||
| let mut tools: Vec<ToolInfo> = agent | ||
|
|
@@ -186,38 +316,45 @@ async fn update_agent_provider( | |
| State(state): State<Arc<AppState>>, | ||
| headers: HeaderMap, | ||
| Json(payload): Json<UpdateProviderRequest>, | ||
| ) -> Result<StatusCode, StatusCode> { | ||
| verify_secret_key(&headers, &state)?; | ||
|
|
||
| let agent = state | ||
| .get_agent() | ||
| .await | ||
| .map_err(|_e| StatusCode::PRECONDITION_FAILED)?; | ||
| ) -> Result<StatusCode, impl IntoResponse> { | ||
| verify_secret_key(&headers, &state).map_err(|e| (e, String::new()))?; | ||
|
|
||
| let agent = state.get_agent().await; | ||
| let config = Config::global(); | ||
| let model = match payload | ||
| .model | ||
| .or_else(|| config.get_param("GOOSE_MODEL").ok()) | ||
| { | ||
| Some(m) => m, | ||
| None => return Err(StatusCode::BAD_REQUEST), | ||
| None => return Err((StatusCode::BAD_REQUEST, "No model specified".to_string())), | ||
| }; | ||
|
|
||
| let model_config = ModelConfig::new(&model).map_err(|_| StatusCode::BAD_REQUEST)?; | ||
| let model_config = ModelConfig::new(&model).map_err(|e| { | ||
| ( | ||
| StatusCode::BAD_REQUEST, | ||
| format!("Invalid model config: {}", e), | ||
| ) | ||
| })?; | ||
|
|
||
| let new_provider = create(&payload.provider, model_config).map_err(|e| { | ||
| ( | ||
| StatusCode::BAD_REQUEST, | ||
| format!("Failed to create provider: {}", e), | ||
| ) | ||
| })?; | ||
|
|
||
| let new_provider = | ||
| create(&payload.provider, model_config).map_err(|_| StatusCode::BAD_REQUEST)?; | ||
| agent | ||
| .update_provider(new_provider) | ||
| .await | ||
| .map_err(|_e| StatusCode::INTERNAL_SERVER_ERROR)?; | ||
| .map_err(|_e| (StatusCode::INTERNAL_SERVER_ERROR, String::new()))?; | ||
|
|
||
| Ok(StatusCode::OK) | ||
| } | ||
|
|
||
| #[utoipa::path( | ||
| post, | ||
| path = "/agent/update_router_tool_selector", | ||
| request_body = UpdateRouterToolSelectorRequest, | ||
| responses( | ||
| (status = 200, description = "Tool selection strategy updated successfully", body = String), | ||
| (status = 401, description = "Unauthorized - invalid secret key"), | ||
|
|
@@ -228,20 +365,15 @@ async fn update_agent_provider( | |
| async fn update_router_tool_selector( | ||
| State(state): State<Arc<AppState>>, | ||
| headers: HeaderMap, | ||
| Json(_payload): Json<UpdateRouterToolSelectorRequest>, | ||
| ) -> Result<Json<String>, Json<ErrorResponse>> { | ||
| verify_secret_key(&headers, &state).map_err(|_| { | ||
| Json(ErrorResponse { | ||
| error: "Unauthorized - Invalid or missing API key".to_string(), | ||
| }) | ||
| })?; | ||
|
|
||
| let agent = state.get_agent().await.map_err(|e| { | ||
| tracing::error!("Failed to get agent: {}", e); | ||
| Json(ErrorResponse { | ||
| error: format!("Failed to get agent: {}", e), | ||
| }) | ||
| })?; | ||
|
|
||
| let agent = state.get_agent().await; | ||
| agent | ||
| .update_router_tool_selector(None, Some(true)) | ||
| .await | ||
|
|
@@ -279,13 +411,7 @@ async fn update_session_config( | |
| }) | ||
| })?; | ||
|
|
||
| let agent = state.get_agent().await.map_err(|e| { | ||
| tracing::error!("Failed to get agent: {}", e); | ||
| Json(ErrorResponse { | ||
| error: format!("Failed to get agent: {}", e), | ||
| }) | ||
| })?; | ||
|
|
||
| let agent = state.get_agent().await; | ||
| if let Some(response) = payload.response { | ||
| agent.add_final_output_tool(response).await; | ||
|
|
||
|
|
@@ -300,6 +426,8 @@ async fn update_session_config( | |
|
|
||
| pub fn routes(state: Arc<AppState>) -> Router { | ||
| Router::new() | ||
| .route("/agent/start", post(start_agent)) | ||
| .route("/agent/resume", post(resume_agent)) | ||
| .route("/agent/prompt", post(extend_prompt)) | ||
| .route("/agent/tools", get(get_tools)) | ||
| .route("/agent/update_provider", post(update_agent_provider)) | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
do you get lint warnings if these attributes are removed? I thought it wouldn't warn because the structs are
pubThere was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I did get a warning, yeah, not sure who said. I think typescript doesn't warn you if something is public, maybe some rust thing alwasy does?