-
Notifications
You must be signed in to change notification settings - Fork 4.6k
feat(acp): add model selection support for session/new and session/set_model #7112
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -13,6 +13,7 @@ use goose::conversation::Conversation; | |||||
| use goose::mcp_utils::ToolResult; | ||||||
| use goose::permission::permission_confirmation::PrincipalType; | ||||||
| use goose::permission::{Permission, PermissionConfirmation}; | ||||||
| use goose::providers::base::Provider; | ||||||
| use goose::providers::provider_registry::ProviderConstructor; | ||||||
| use goose::session::session_manager::SessionType; | ||||||
| use goose::session::{Session, SessionManager}; | ||||||
|
|
@@ -21,12 +22,13 @@ use sacp::schema::{ | |||||
| AgentCapabilities, AuthMethod, AuthenticateRequest, AuthenticateResponse, BlobResourceContents, | ||||||
| CancelNotification, Content, ContentBlock, ContentChunk, EmbeddedResource, | ||||||
| EmbeddedResourceResource, ImageContent, InitializeRequest, InitializeResponse, | ||||||
| LoadSessionRequest, LoadSessionResponse, McpCapabilities, McpServer, NewSessionRequest, | ||||||
| NewSessionResponse, PermissionOption, PermissionOptionKind, PromptCapabilities, PromptRequest, | ||||||
| PromptResponse, RequestPermissionOutcome, RequestPermissionRequest, ResourceLink, SessionId, | ||||||
| SessionNotification, SessionUpdate, StopReason, TextContent, TextResourceContents, ToolCall, | ||||||
| ToolCallContent, ToolCallId, ToolCallLocation, ToolCallStatus, ToolCallUpdate, | ||||||
| ToolCallUpdateFields, ToolKind, | ||||||
| LoadSessionRequest, LoadSessionResponse, McpCapabilities, McpServer, ModelId, ModelInfo, | ||||||
| NewSessionRequest, NewSessionResponse, PermissionOption, PermissionOptionKind, | ||||||
| PromptCapabilities, PromptRequest, PromptResponse, RequestPermissionOutcome, | ||||||
| RequestPermissionRequest, ResourceLink, SessionId, SessionModelState, SessionNotification, | ||||||
| SessionUpdate, SetSessionModelRequest, SetSessionModelResponse, StopReason, TextContent, | ||||||
| TextResourceContents, ToolCall, ToolCallContent, ToolCallId, ToolCallLocation, ToolCallStatus, | ||||||
| ToolCallUpdate, ToolCallUpdateFields, ToolKind, | ||||||
| }; | ||||||
| use sacp::{AgentToClient, ByteStreams, Handled, JrConnectionCx, JrMessageHandler, MessageCx}; | ||||||
| use std::collections::HashMap; | ||||||
|
|
@@ -48,7 +50,7 @@ pub struct GooseAcpAgent { | |||||
| agent: Arc<Agent>, | ||||||
| provider_factory: ProviderConstructor, | ||||||
| config_dir: std::path::PathBuf, | ||||||
| provider_initialized: tokio::sync::OnceCell<String>, | ||||||
| provider_initialized: tokio::sync::OnceCell<Arc<dyn Provider>>, | ||||||
| } | ||||||
|
|
||||||
| fn mcp_server_to_extension_config(mcp_server: McpServer) -> Result<ExtensionConfig, String> { | ||||||
|
|
@@ -266,6 +268,22 @@ async fn add_extensions(agent: &Agent, extensions: Vec<ExtensionConfig>) { | |||||
| } | ||||||
| } | ||||||
|
|
||||||
| async fn build_model_state( | ||||||
| provider: &dyn Provider, | ||||||
| current_model: &str, | ||||||
| ) -> Result<SessionModelState, sacp::Error> { | ||||||
| let models = provider.fetch_recommended_models().await.map_err(|e| { | ||||||
| sacp::Error::internal_error().data(format!("Failed to fetch models: {}", e)) | ||||||
| })?; | ||||||
| Ok(SessionModelState::new( | ||||||
| ModelId::new(current_model), | ||||||
| models | ||||||
| .iter() | ||||||
| .map(|name| ModelInfo::new(ModelId::new(&**name), &**name)) | ||||||
| .collect(), | ||||||
| )) | ||||||
| } | ||||||
|
|
||||||
| impl GooseAcpAgent { | ||||||
| pub fn permission_manager(&self) -> Arc<PermissionManager> { | ||||||
| Arc::clone(&self.agent.config.permission_manager) | ||||||
|
|
@@ -682,7 +700,7 @@ impl GooseAcpAgent { | |||||
| .map_err(|e| { | ||||||
| sacp::Error::internal_error().data(format!("Failed to create session: {}", e)) | ||||||
| })?; | ||||||
| self.ensure_provider(&goose_session).await.map_err(|e| { | ||||||
| let provider = self.ensure_provider(&goose_session).await.map_err(|e| { | ||||||
| sacp::Error::internal_error().data(format!("Failed to set provider: {}", e)) | ||||||
| })?; | ||||||
|
|
||||||
|
|
@@ -715,25 +733,28 @@ impl GooseAcpAgent { | |||||
| "Session started" | ||||||
| ); | ||||||
|
|
||||||
| Ok(NewSessionResponse::new(SessionId::new(goose_session.id))) | ||||||
| let model_state = | ||||||
| build_model_state(&**provider, &provider.get_model_config().model_name).await?; | ||||||
|
|
||||||
| Ok(NewSessionResponse::new(SessionId::new(goose_session.id)).models(model_state)) | ||||||
| } | ||||||
|
|
||||||
| // Called at most once via OnceCell; returns the model_id used. | ||||||
| async fn create_provider(&self, session: &Session) -> Result<String> { | ||||||
| async fn create_provider(&self, session: &Session) -> Result<Arc<dyn Provider>> { | ||||||
| let config_path = self.config_dir.join(CONFIG_YAML_NAME); | ||||||
| let config = Config::new(&config_path, "goose")?; | ||||||
| let model_id = config.get_goose_model()?; | ||||||
| let model_config = goose::model::ModelConfig::new(&model_id)?; | ||||||
| let provider = (self.provider_factory)(model_config).await?; | ||||||
| self.agent.update_provider(provider, &session.id).await?; | ||||||
| Ok(model_id) | ||||||
| self.agent | ||||||
| .update_provider(provider.clone(), &session.id) | ||||||
| .await?; | ||||||
| Ok(provider) | ||||||
| } | ||||||
|
|
||||||
| async fn ensure_provider(&self, session: &Session) -> Result<()> { | ||||||
| async fn ensure_provider(&self, session: &Session) -> Result<&Arc<dyn Provider>> { | ||||||
| self.provider_initialized | ||||||
| .get_or_try_init(|| self.create_provider(session)) | ||||||
| .await?; | ||||||
| Ok(()) | ||||||
| .await | ||||||
| } | ||||||
|
codefromthecrypt marked this conversation as resolved.
|
||||||
|
|
||||||
| async fn on_load_session( | ||||||
|
|
@@ -750,7 +771,7 @@ impl GooseAcpAgent { | |||||
| sacp::Error::invalid_params() | ||||||
| .data(format!("Failed to load session {}: {}", session_id, e)) | ||||||
| })?; | ||||||
| self.ensure_provider(&goose_session).await.map_err(|e| { | ||||||
| let provider = self.ensure_provider(&goose_session).await.map_err(|e| { | ||||||
| sacp::Error::internal_error().data(format!("Failed to set provider: {}", e)) | ||||||
| })?; | ||||||
|
codefromthecrypt marked this conversation as resolved.
|
||||||
|
|
||||||
|
|
@@ -830,7 +851,10 @@ impl GooseAcpAgent { | |||||
| "Session loaded" | ||||||
| ); | ||||||
|
|
||||||
| Ok(LoadSessionResponse::new()) | ||||||
| let model_state = | ||||||
| build_model_state(&**provider, &provider.get_model_config().model_name).await?; | ||||||
|
|
||||||
| Ok(LoadSessionResponse::new().models(model_state)) | ||||||
| } | ||||||
|
|
||||||
| async fn on_prompt( | ||||||
|
|
@@ -928,6 +952,28 @@ impl GooseAcpAgent { | |||||
|
|
||||||
| Ok(()) | ||||||
| } | ||||||
|
|
||||||
| async fn on_set_model( | ||||||
| &self, | ||||||
| session_id: &str, | ||||||
| model_id: &str, | ||||||
| ) -> Result<SetSessionModelResponse, sacp::Error> { | ||||||
| let model_config = goose::model::ModelConfig::new(model_id).map_err(|e| { | ||||||
| sacp::Error::internal_error().data(format!("Invalid model config: {}", e)) | ||||||
|
||||||
| sacp::Error::internal_error().data(format!("Invalid model config: {}", e)) | |
| sacp::Error::invalid_params().data(format!("Invalid model config: {}", e)) |
Uh oh!
There was an error while loading. Please reload this page.