Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
185 changes: 95 additions & 90 deletions crates/goose-acp/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,18 +39,24 @@ use tokio_util::sync::CancellationToken;
use tracing::{debug, error, info, warn};
use url::Url;

// Agent binds provider, extensions, and permission channels to a single session.
// ACP has no session/close, so sessions accumulate until transport closes.
struct GooseAcpSession {
agent: Arc<Agent>,
messages: Conversation,
tool_requests: HashMap<String, goose::conversation::message::ToolRequest>,
cancel_token: Option<CancellationToken>,
}

pub struct GooseAcpAgent {
sessions: Arc<Mutex<HashMap<String, GooseAcpSession>>>,
agent: Arc<Agent>,
provider_factory: ProviderConstructor,
config_dir: std::path::PathBuf,
provider_initialized: tokio::sync::OnceCell<Arc<dyn Provider>>,
session_manager: Arc<SessionManager>,
permission_manager: Arc<PermissionManager>,
goose_mode: goose::config::GooseMode,
disable_session_naming: bool,
builtins: Vec<String>,
}

fn mcp_server_to_extension_config(mcp_server: McpServer) -> Result<ExtensionConfig, String> {
Expand Down Expand Up @@ -286,7 +292,7 @@ async fn build_model_state(

impl GooseAcpAgent {
pub fn permission_manager(&self) -> Arc<PermissionManager> {
Arc::clone(&self.agent.config.permission_manager)
Arc::clone(&self.permission_manager)
}

pub async fn new(
Expand All @@ -300,60 +306,36 @@ impl GooseAcpAgent {
let session_manager = Arc::new(SessionManager::new(data_dir));
let permission_manager = Arc::new(PermissionManager::new(config_dir.clone()));

let agent = Agent::with_config(AgentConfig::new(
Arc::clone(&session_manager),
permission_manager,
None,
goose_mode,
disable_session_naming,
));

let agent_ptr = Arc::new(agent);

let config_path = config_dir.join(CONFIG_YAML_NAME);
let config_file = Config::new(&config_path, "goose")?;
let extensions = get_enabled_extensions_with_config(&config_file);

add_builtins(&agent_ptr, builtins).await;
add_extensions(&agent_ptr, extensions).await;

Ok(Self {
sessions: Arc::new(Mutex::new(HashMap::new())),
agent: agent_ptr,
provider_factory,
config_dir,
provider_initialized: tokio::sync::OnceCell::new(),
session_manager,
permission_manager,
goose_mode,
disable_session_naming,
builtins,
})
}

pub async fn create_session(&self) -> Result<String> {
let manager = self.agent.config.session_manager.clone();
let goose_session = manager
.create_session(
std::env::current_dir().unwrap_or_default(),
"ACP Session".to_string(),
SessionType::User,
)
.await?;

self.ensure_provider(&goose_session).await?;

let session = GooseAcpSession {
messages: Conversation::new_unvalidated(Vec::new()),
tool_requests: HashMap::new(),
cancel_token: None,
};

let mut sessions = self.sessions.lock().await;
sessions.insert(goose_session.id.clone(), session);
async fn create_agent_for_session(&self) -> Arc<Agent> {
let agent = Agent::with_config(AgentConfig::new(
Arc::clone(&self.session_manager),
Arc::clone(&self.permission_manager),
None,
self.goose_mode,
self.disable_session_naming,
));
let agent = Arc::new(agent);

info!(
session_id = %goose_session.id,
session_type = "acp",
"Session created"
);
let config_path = self.config_dir.join(CONFIG_YAML_NAME);
if let Ok(config_file) = Config::new(&config_path, "goose") {
let extensions = get_enabled_extensions_with_config(&config_file);
add_extensions(&agent, extensions).await;
}
add_builtins(&agent, self.builtins.clone()).await;

Ok(goose_session.id)
agent
}

pub async fn has_session(&self, session_id: &str) -> bool {
Expand Down Expand Up @@ -433,12 +415,13 @@ impl GooseAcpAgent {
} = &action_required.data
{
self.handle_tool_permission_request(
cx,
&session.agent,
session_id,
id.clone(),
tool_name.clone(),
arguments.clone(),
prompt.clone(),
session_id,
cx,
)?;
}
}
Expand Down Expand Up @@ -513,17 +496,19 @@ impl GooseAcpAgent {
Ok(())
}

#[allow(clippy::too_many_arguments)]
fn handle_tool_permission_request(
&self,
cx: &JrConnectionCx<AgentToClient>,
agent: &Arc<Agent>,
session_id: &SessionId,
request_id: String,
tool_name: String,
arguments: serde_json::Map<String, serde_json::Value>,
prompt: Option<String>,
session_id: &SessionId,
cx: &JrConnectionCx<AgentToClient>,
) -> Result<(), sacp::Error> {
let cx = cx.clone();
let agent = self.agent.clone();
let agent = agent.clone();
let session_id = session_id.clone();

let formatted_name = format_tool_name(&tool_name);
Expand Down Expand Up @@ -689,8 +674,8 @@ impl GooseAcpAgent {
) -> Result<NewSessionResponse, sacp::Error> {
debug!(?args, "new session request");

let manager = self.agent.config.session_manager.clone();
let goose_session = manager
let goose_session = self
.session_manager
.create_session(
args.cwd.clone(),
"ACP Session".to_string(),
Expand All @@ -700,9 +685,14 @@ impl GooseAcpAgent {
.map_err(|e| {
sacp::Error::internal_error().data(format!("Failed to create session: {}", e))
})?;
let provider = self.ensure_provider(&goose_session).await.map_err(|e| {
sacp::Error::internal_error().data(format!("Failed to set provider: {}", e))
})?;

let agent = self.create_agent_for_session().await;
let provider = self
.init_provider(&agent, &goose_session)
.await
.map_err(|e| {
sacp::Error::internal_error().data(format!("Failed to set provider: {}", e))
})?;

for mcp_server in args.mcp_servers {
let config = match mcp_server_to_extension_config(mcp_server) {
Expand All @@ -712,13 +702,14 @@ impl GooseAcpAgent {
}
};
let name = config.name().to_string();
if let Err(e) = self.agent.add_extension(config, &goose_session.id).await {
if let Err(e) = agent.add_extension(config, &goose_session.id).await {
return Err(sacp::Error::internal_error()
.data(format!("Failed to add MCP server '{}': {}", name, e)));
}
}

let session = GooseAcpSession {
agent,
messages: Conversation::new_unvalidated(Vec::new()),
tool_requests: HashMap::new(),
cancel_token: None,
Expand All @@ -734,29 +725,26 @@ impl GooseAcpAgent {
);

let model_state =
build_model_state(&**provider, &provider.get_model_config().model_name).await?;
build_model_state(&*provider, &provider.get_model_config().model_name).await?;

Ok(NewSessionResponse::new(SessionId::new(goose_session.id)).models(model_state))
}

async fn create_provider(&self, session: &Session) -> Result<Arc<dyn Provider>> {
let config_path = self.config_dir.join(CONFIG_YAML_NAME);
let config = Config::new(&config_path, "goose")?;
let model_id = config.get_goose_model()?;
let model_config = goose::model::ModelConfig::new(&model_id)?;
async fn init_provider(&self, agent: &Agent, session: &Session) -> Result<Arc<dyn Provider>> {
let model_config = match &session.model_config {
Some(config) => config.clone(),
None => {
let config_path = self.config_dir.join(CONFIG_YAML_NAME);
let config = Config::new(&config_path, "goose")?;
let model_id = config.get_goose_model()?;
goose::model::ModelConfig::new(&model_id)?
}
};
let provider = (self.provider_factory)(model_config).await?;
self.agent
.update_provider(provider.clone(), &session.id)
.await?;
agent.update_provider(provider.clone(), &session.id).await?;
Ok(provider)
}

async fn ensure_provider(&self, session: &Session) -> Result<&Arc<dyn Provider>> {
self.provider_initialized
.get_or_try_init(|| self.create_provider(session))
.await
}

async fn on_load_session(
&self,
args: LoadSessionRequest,
Expand All @@ -766,21 +754,29 @@ impl GooseAcpAgent {

let session_id = args.session_id.0.to_string();

let manager = self.agent.config.session_manager.clone();
let goose_session = manager.get_session(&session_id, true).await.map_err(|e| {
sacp::Error::invalid_params()
.data(format!("Failed to load session {}: {}", session_id, e))
})?;
let provider = self.ensure_provider(&goose_session).await.map_err(|e| {
sacp::Error::internal_error().data(format!("Failed to set provider: {}", e))
})?;
let goose_session = self
.session_manager
.get_session(&session_id, true)
.await
.map_err(|e| {
sacp::Error::invalid_params()
.data(format!("Failed to load session {}: {}", session_id, e))
})?;

let agent = self.create_agent_for_session().await;
let provider = self
.init_provider(&agent, &goose_session)
.await
.map_err(|e| {
sacp::Error::internal_error().data(format!("Failed to set provider: {}", e))
})?;

let conversation = goose_session.conversation.ok_or_else(|| {
sacp::Error::internal_error()
.data(format!("Session {} has no conversation data", session_id))
})?;

manager
self.session_manager
.update(&session_id)
.working_dir(args.cwd.clone())
.apply()
Expand All @@ -791,6 +787,7 @@ impl GooseAcpAgent {
})?;

let mut session = GooseAcpSession {
agent,
messages: conversation.clone(),
tool_requests: HashMap::new(),
cancel_token: None,
Expand Down Expand Up @@ -852,7 +849,7 @@ impl GooseAcpAgent {
);

let model_state =
build_model_state(&**provider, &provider.get_model_config().model_name).await?;
build_model_state(&*provider, &provider.get_model_config().model_name).await?;

Ok(LoadSessionResponse::new().models(model_state))
}
Expand All @@ -865,13 +862,14 @@ impl GooseAcpAgent {
let session_id = args.session_id.0.to_string();
let cancel_token = CancellationToken::new();

{
let agent = {
let mut sessions = self.sessions.lock().await;
let session = sessions.get_mut(&session_id).ok_or_else(|| {
sacp::Error::invalid_params().data(format!("Session not found: {}", session_id))
})?;
session.cancel_token = Some(cancel_token.clone());
}
session.agent.clone()
};

let user_message = self.convert_acp_prompt_to_message(args.prompt);

Expand All @@ -882,8 +880,7 @@ impl GooseAcpAgent {
retry_config: None,
};

let mut stream = self
.agent
let mut stream = agent
.reply(user_message, session_config, Some(cancel_token.clone()))
.await
.map_err(|e| {
Expand Down Expand Up @@ -959,12 +956,20 @@ impl GooseAcpAgent {
model_id: &str,
) -> Result<SetSessionModelResponse, sacp::Error> {
let model_config = goose::model::ModelConfig::new(model_id).map_err(|e| {
sacp::Error::internal_error().data(format!("Invalid model config: {}", e))
sacp::Error::invalid_params().data(format!("Invalid model config: {}", e))
})?;
let provider = (self.provider_factory)(model_config).await.map_err(|e| {
sacp::Error::internal_error().data(format!("Failed to create provider: {}", e))
})?;
self.agent

let agent = {
let sessions = self.sessions.lock().await;
let session = sessions.get(session_id).ok_or_else(|| {
sacp::Error::invalid_params().data(format!("Session not found: {}", session_id))
})?;
session.agent.clone()
};
agent
.update_provider(provider, session_id)
.await
.map_err(|e| {
Expand Down
Loading
Loading