diff --git a/packages/cloud/src/WebAuthService.ts b/packages/cloud/src/WebAuthService.ts index cb0e08754774..e7c886ddcdcd 100644 --- a/packages/cloud/src/WebAuthService.ts +++ b/packages/cloud/src/WebAuthService.ts @@ -129,6 +129,7 @@ export class WebAuthService extends EventEmitter implements A private changeState(newState: AuthState): void { const previousState = this.state this.state = newState + this.log(`[auth] changeState: ${previousState} -> ${newState}`) this.emit("auth-state-changed", { state: newState, previousState }) } @@ -162,8 +163,6 @@ export class WebAuthService extends EventEmitter implements A this.userInfo = null this.changeState("logged-out") - - this.log("[auth] Transitioned to logged-out state") } private transitionToAttemptingSession(credentials: AuthCredentials): void { @@ -176,8 +175,6 @@ export class WebAuthService extends EventEmitter implements A this.changeState("attempting-session") this.timer.start() - - this.log("[auth] Transitioned to attempting-session state") } private transitionToInactiveSession(): void { @@ -185,8 +182,6 @@ export class WebAuthService extends EventEmitter implements A this.userInfo = null this.changeState("inactive-session") - - this.log("[auth] Transitioned to inactive-session state") } /** @@ -422,7 +417,6 @@ export class WebAuthService extends EventEmitter implements A if (previousState !== "active-session") { this.changeState("active-session") - this.log("[auth] Transitioned to active-session state") this.fetchUserInfo() } else { this.state = "active-session" diff --git a/packages/cloud/src/bridge/BaseChannel.ts b/packages/cloud/src/bridge/BaseChannel.ts new file mode 100644 index 000000000000..45b9b525f6c8 --- /dev/null +++ b/packages/cloud/src/bridge/BaseChannel.ts @@ -0,0 +1,109 @@ +import type { Socket } from "socket.io-client" + +/** + * Abstract base class for communication channels in the bridge system. + * Provides common functionality for bidirectional communication between + * the VSCode extension and web application. + * + * @template TCommand - Type of commands this channel can receive. + * @template TEvent - Type of events this channel can publish. + */ +export abstract class BaseChannel { + protected socket: Socket | null = null + protected readonly instanceId: string + + constructor(instanceId: string) { + this.instanceId = instanceId + } + + /** + * Called when socket connects. + */ + public async onConnect(socket: Socket): Promise { + this.socket = socket + await this.handleConnect(socket) + } + + /** + * Called when socket disconnects. + */ + public onDisconnect(): void { + this.socket = null + this.handleDisconnect() + } + + /** + * Called when socket reconnects. + */ + public async onReconnect(socket: Socket): Promise { + this.socket = socket + await this.handleReconnect(socket) + } + + /** + * Cleanup resources. + */ + public async cleanup(socket: Socket | null): Promise { + if (socket) { + await this.handleCleanup(socket) + } + + this.socket = null + } + + /** + * Emit a socket event with error handling. + */ + protected publish( + eventName: TEventName, + data: TEventData, + callback?: (params: Params) => void, + ): boolean { + if (!this.socket) { + console.error(`[${this.constructor.name}#emit] socket not available for ${eventName}`) + return false + } + + try { + // console.log(`[${this.constructor.name}#emit] emit() -> ${eventName}`, data) + this.socket.emit(eventName, data, callback) + + return true + } catch (error) { + console.error( + `[${this.constructor.name}#emit] emit() failed -> ${eventName}: ${ + error instanceof Error ? error.message : String(error) + }`, + ) + + return false + } + } + + /** + * Handle incoming commands - must be implemented by subclasses. + */ + public abstract handleCommand(command: TCommand): void + + /** + * Handle connection-specific logic. + */ + protected abstract handleConnect(socket: Socket): Promise + + /** + * Handle disconnection-specific logic. + */ + protected handleDisconnect(): void { + // Default implementation - can be overridden. + } + + /** + * Handle reconnection-specific logic. + */ + protected abstract handleReconnect(socket: Socket): Promise + + /** + * Handle cleanup-specific logic. + */ + protected abstract handleCleanup(socket: Socket): Promise +} diff --git a/packages/cloud/src/bridge/BridgeOrchestrator.ts b/packages/cloud/src/bridge/BridgeOrchestrator.ts new file mode 100644 index 000000000000..73b757e5c8fb --- /dev/null +++ b/packages/cloud/src/bridge/BridgeOrchestrator.ts @@ -0,0 +1,273 @@ +import crypto from "crypto" + +import { + type TaskProviderLike, + type TaskLike, + type CloudUserInfo, + type ExtensionBridgeCommand, + type TaskBridgeCommand, + ConnectionState, + ExtensionSocketEvents, + TaskSocketEvents, +} from "@roo-code/types" + +import { SocketTransport } from "./SocketTransport.js" +import { ExtensionChannel } from "./ExtensionChannel.js" +import { TaskChannel } from "./TaskChannel.js" + +export interface BridgeOrchestratorOptions { + userId: string + socketBridgeUrl: string + token: string + provider: TaskProviderLike + sessionId?: string +} + +/** + * Central orchestrator for the extension bridge system. + * Coordinates communication between the VSCode extension and web application + * through WebSocket connections and manages extension/task channels. + */ +export class BridgeOrchestrator { + private static instance: BridgeOrchestrator | null = null + + // Core + private readonly userId: string + private readonly socketBridgeUrl: string + private readonly token: string + private readonly provider: TaskProviderLike + private readonly instanceId: string + + // Components + private socketTransport: SocketTransport + private extensionChannel: ExtensionChannel + private taskChannel: TaskChannel + + // Reconnection + private readonly MAX_RECONNECT_ATTEMPTS = Infinity + private readonly RECONNECT_DELAY = 1_000 + private readonly RECONNECT_DELAY_MAX = 30_000 + + public static getInstance(): BridgeOrchestrator | null { + return BridgeOrchestrator.instance + } + + public static isEnabled(user?: CloudUserInfo | null, remoteControlEnabled?: boolean): boolean { + return !!(user?.id && user.extensionBridgeEnabled && remoteControlEnabled) + } + + public static async connectOrDisconnect( + userInfo: CloudUserInfo | null, + remoteControlEnabled: boolean | undefined, + options: BridgeOrchestratorOptions, + ): Promise { + const isEnabled = BridgeOrchestrator.isEnabled(userInfo, remoteControlEnabled) + const instance = BridgeOrchestrator.instance + + if (isEnabled) { + if (!instance) { + try { + console.log(`[BridgeOrchestrator#connectOrDisconnect] Connecting...`) + BridgeOrchestrator.instance = new BridgeOrchestrator(options) + await BridgeOrchestrator.instance.connect() + } catch (error) { + console.error( + `[BridgeOrchestrator#connectOrDisconnect] connect() failed: ${error instanceof Error ? error.message : String(error)}`, + ) + } + } else { + if ( + instance.connectionState === ConnectionState.FAILED || + instance.connectionState === ConnectionState.DISCONNECTED + ) { + console.log( + `[BridgeOrchestrator#connectOrDisconnect] Re-connecting... (state: ${instance.connectionState})`, + ) + + instance.reconnect().catch((error) => { + console.error( + `[BridgeOrchestrator#connectOrDisconnect] reconnect() failed: ${error instanceof Error ? error.message : String(error)}`, + ) + }) + } else { + console.log( + `[BridgeOrchestrator#connectOrDisconnect] Already connected or connecting (state: ${instance.connectionState})`, + ) + } + } + } else { + if (instance) { + try { + console.log( + `[BridgeOrchestrator#connectOrDisconnect] Disconnecting... (state: ${instance.connectionState})`, + ) + + await instance.disconnect() + } catch (error) { + console.error( + `[BridgeOrchestrator#connectOrDisconnect] disconnect() failed: ${error instanceof Error ? error.message : String(error)}`, + ) + } finally { + BridgeOrchestrator.instance = null + } + } else { + console.log(`[BridgeOrchestrator#connectOrDisconnect] Already disconnected`) + } + } + } + + private constructor(options: BridgeOrchestratorOptions) { + this.userId = options.userId + this.socketBridgeUrl = options.socketBridgeUrl + this.token = options.token + this.provider = options.provider + this.instanceId = options.sessionId || crypto.randomUUID() + + this.socketTransport = new SocketTransport({ + url: this.socketBridgeUrl, + socketOptions: { + query: { + token: this.token, + clientType: "extension", + instanceId: this.instanceId, + }, + transports: ["websocket", "polling"], + reconnection: true, + reconnectionAttempts: this.MAX_RECONNECT_ATTEMPTS, + reconnectionDelay: this.RECONNECT_DELAY, + reconnectionDelayMax: this.RECONNECT_DELAY_MAX, + }, + onConnect: () => this.handleConnect(), + onDisconnect: () => this.handleDisconnect(), + onReconnect: () => this.handleReconnect(), + }) + + this.extensionChannel = new ExtensionChannel(this.instanceId, this.userId, this.provider) + this.taskChannel = new TaskChannel(this.instanceId) + } + + private setupSocketListeners() { + const socket = this.socketTransport.getSocket() + + if (!socket) { + console.error("[BridgeOrchestrator] Socket not available") + return + } + + // Remove any existing listeners first to prevent duplicates. + socket.off(ExtensionSocketEvents.RELAYED_COMMAND) + socket.off(TaskSocketEvents.RELAYED_COMMAND) + socket.off("connected") + + socket.on(ExtensionSocketEvents.RELAYED_COMMAND, (message: ExtensionBridgeCommand) => { + console.log( + `[BridgeOrchestrator] on(${ExtensionSocketEvents.RELAYED_COMMAND}) -> ${message.type} for ${message.instanceId}`, + ) + + this.extensionChannel?.handleCommand(message) + }) + + socket.on(TaskSocketEvents.RELAYED_COMMAND, (message: TaskBridgeCommand) => { + console.log( + `[BridgeOrchestrator] on(${TaskSocketEvents.RELAYED_COMMAND}) -> ${message.type} for ${message.taskId}`, + ) + + this.taskChannel.handleCommand(message) + }) + } + + private async handleConnect() { + const socket = this.socketTransport.getSocket() + + if (!socket) { + console.error("[BridgeOrchestrator] Socket not available after connect") + return + } + + await this.extensionChannel.onConnect(socket) + await this.taskChannel.onConnect(socket) + } + + private handleDisconnect() { + this.extensionChannel.onDisconnect() + this.taskChannel.onDisconnect() + } + + private async handleReconnect() { + const socket = this.socketTransport.getSocket() + + if (!socket) { + console.error("[BridgeOrchestrator] Socket not available after reconnect") + return + } + + // Re-setup socket listeners to ensure they're properly configured + // after automatic reconnection (Socket.IO's built-in reconnection) + // The socket.off() calls in setupSocketListeners prevent duplicates + this.setupSocketListeners() + + await this.extensionChannel.onReconnect(socket) + await this.taskChannel.onReconnect(socket) + } + + // Task API + + public async subscribeToTask(task: TaskLike): Promise { + const socket = this.socketTransport.getSocket() + + if (!socket || !this.socketTransport.isConnected()) { + console.warn("[BridgeOrchestrator] Cannot subscribe to task: not connected. Will retry when connected.") + this.taskChannel.addPendingTask(task) + + if ( + this.connectionState === ConnectionState.DISCONNECTED || + this.connectionState === ConnectionState.FAILED + ) { + await this.connect() + } + + return + } + + await this.taskChannel.subscribeToTask(task, socket) + } + + public async unsubscribeFromTask(taskId: string): Promise { + const socket = this.socketTransport.getSocket() + + if (!socket) { + return + } + + await this.taskChannel.unsubscribeFromTask(taskId, socket) + } + + // Shared API + + public get connectionState(): ConnectionState { + return this.socketTransport.getConnectionState() + } + + private async connect(): Promise { + // Populate the app and git properties before registering the instance. + await this.provider.getTelemetryProperties() + + await this.socketTransport.connect() + this.setupSocketListeners() + } + + public async disconnect(): Promise { + await this.extensionChannel.cleanup(this.socketTransport.getSocket()) + await this.taskChannel.cleanup(this.socketTransport.getSocket()) + await this.socketTransport.disconnect() + BridgeOrchestrator.instance = null + } + + public async reconnect(): Promise { + await this.socketTransport.reconnect() + + // After a manual reconnect, we have a new socket instance + // so we need to set up listeners again. + this.setupSocketListeners() + } +} diff --git a/packages/cloud/src/bridge/ExtensionBridgeService.ts b/packages/cloud/src/bridge/ExtensionBridgeService.ts deleted file mode 100644 index 0ab7e304f206..000000000000 --- a/packages/cloud/src/bridge/ExtensionBridgeService.ts +++ /dev/null @@ -1,290 +0,0 @@ -import crypto from "crypto" - -import { - type TaskProviderLike, - type TaskLike, - type CloudUserInfo, - type ExtensionBridgeCommand, - type TaskBridgeCommand, - ConnectionState, - ExtensionSocketEvents, - TaskSocketEvents, -} from "@roo-code/types" - -import { SocketConnectionManager } from "./SocketConnectionManager.js" -import { ExtensionManager } from "./ExtensionManager.js" -import { TaskManager } from "./TaskManager.js" - -export interface ExtensionBridgeServiceOptions { - userId: string - socketBridgeUrl: string - token: string - provider: TaskProviderLike - sessionId?: string -} - -export class ExtensionBridgeService { - private static instance: ExtensionBridgeService | null = null - - // Core - private readonly userId: string - private readonly socketBridgeUrl: string - private readonly token: string - private readonly provider: TaskProviderLike - private readonly instanceId: string - - // Managers - private connectionManager: SocketConnectionManager - private extensionManager: ExtensionManager - private taskManager: TaskManager - - // Reconnection - private readonly MAX_RECONNECT_ATTEMPTS = Infinity - private readonly RECONNECT_DELAY = 1_000 - private readonly RECONNECT_DELAY_MAX = 30_000 - - public static getInstance(): ExtensionBridgeService | null { - return ExtensionBridgeService.instance - } - - public static async createInstance(options: ExtensionBridgeServiceOptions) { - console.log("[ExtensionBridgeService] createInstance") - ExtensionBridgeService.instance = new ExtensionBridgeService(options) - await ExtensionBridgeService.instance.initialize() - return ExtensionBridgeService.instance - } - - public static resetInstance() { - if (ExtensionBridgeService.instance) { - console.log("[ExtensionBridgeService] resetInstance") - ExtensionBridgeService.instance.disconnect().catch(() => {}) - ExtensionBridgeService.instance = null - } - } - - public static async handleRemoteControlState( - userInfo: CloudUserInfo | null, - remoteControlEnabled: boolean | undefined, - options: ExtensionBridgeServiceOptions, - logger?: (message: string) => void, - ) { - if (userInfo?.extensionBridgeEnabled && remoteControlEnabled) { - const existingService = ExtensionBridgeService.getInstance() - - if (!existingService) { - try { - const service = await ExtensionBridgeService.createInstance(options) - const state = service.getConnectionState() - - logger?.(`[ExtensionBridgeService#handleRemoteControlState] Instance created (state: ${state})`) - - if (state !== ConnectionState.CONNECTED) { - logger?.( - `[ExtensionBridgeService#handleRemoteControlState] Service is not connected yet, will retry in background`, - ) - } - } catch (error) { - const message = `[ExtensionBridgeService#handleRemoteControlState] Failed to create instance: ${ - error instanceof Error ? error.message : String(error) - }` - - logger?.(message) - console.error(message) - } - } else { - const state = existingService.getConnectionState() - - if (state === ConnectionState.FAILED || state === ConnectionState.DISCONNECTED) { - logger?.( - `[ExtensionBridgeService#handleRemoteControlState] Existing service is ${state}, attempting reconnection`, - ) - - existingService.reconnect().catch((error) => { - const message = `[ExtensionBridgeService#handleRemoteControlState] Reconnection failed: ${ - error instanceof Error ? error.message : String(error) - }` - - logger?.(message) - console.error(message) - }) - } - } - } else { - const existingService = ExtensionBridgeService.getInstance() - - if (existingService) { - try { - await existingService.disconnect() - ExtensionBridgeService.resetInstance() - - logger?.(`[ExtensionBridgeService#handleRemoteControlState] Service disconnected and reset`) - } catch (error) { - const message = `[ExtensionBridgeService#handleRemoteControlState] Failed to disconnect and reset instance: ${ - error instanceof Error ? error.message : String(error) - }` - - logger?.(message) - console.error(message) - } - } - } - } - - private constructor(options: ExtensionBridgeServiceOptions) { - this.userId = options.userId - this.socketBridgeUrl = options.socketBridgeUrl - this.token = options.token - this.provider = options.provider - this.instanceId = options.sessionId || crypto.randomUUID() - - this.connectionManager = new SocketConnectionManager({ - url: this.socketBridgeUrl, - socketOptions: { - query: { - token: this.token, - clientType: "extension", - instanceId: this.instanceId, - }, - transports: ["websocket", "polling"], - reconnection: true, - reconnectionAttempts: this.MAX_RECONNECT_ATTEMPTS, - reconnectionDelay: this.RECONNECT_DELAY, - reconnectionDelayMax: this.RECONNECT_DELAY_MAX, - }, - onConnect: () => this.handleConnect(), - onDisconnect: () => this.handleDisconnect(), - onReconnect: () => this.handleReconnect(), - }) - - this.extensionManager = new ExtensionManager(this.instanceId, this.userId, this.provider) - - this.taskManager = new TaskManager() - } - - private async initialize() { - // Populate the app and git properties before registering the instance. - await this.provider.getTelemetryProperties() - - await this.connectionManager.connect() - this.setupSocketListeners() - } - - private setupSocketListeners() { - const socket = this.connectionManager.getSocket() - - if (!socket) { - console.error("[ExtensionBridgeService] Socket not available") - return - } - - // Remove any existing listeners first to prevent duplicates. - socket.off(ExtensionSocketEvents.RELAYED_COMMAND) - socket.off(TaskSocketEvents.RELAYED_COMMAND) - socket.off("connected") - - socket.on(ExtensionSocketEvents.RELAYED_COMMAND, (message: ExtensionBridgeCommand) => { - console.log( - `[ExtensionBridgeService] on(${ExtensionSocketEvents.RELAYED_COMMAND}) -> ${message.type} for ${message.instanceId}`, - ) - - this.extensionManager?.handleExtensionCommand(message) - }) - - socket.on(TaskSocketEvents.RELAYED_COMMAND, (message: TaskBridgeCommand) => { - console.log( - `[ExtensionBridgeService] on(${TaskSocketEvents.RELAYED_COMMAND}) -> ${message.type} for ${message.taskId}`, - ) - - this.taskManager.handleTaskCommand(message) - }) - } - - private async handleConnect() { - const socket = this.connectionManager.getSocket() - - if (!socket) { - console.error("[ExtensionBridgeService] Socket not available after connect") - - return - } - - await this.extensionManager.onConnect(socket) - await this.taskManager.onConnect(socket) - } - - private handleDisconnect() { - this.extensionManager.onDisconnect() - this.taskManager.onDisconnect() - } - - private async handleReconnect() { - const socket = this.connectionManager.getSocket() - - if (!socket) { - console.error("[ExtensionBridgeService] Socket not available after reconnect") - - return - } - - // Re-setup socket listeners to ensure they're properly configured - // after automatic reconnection (Socket.IO's built-in reconnection) - // The socket.off() calls in setupSocketListeners prevent duplicates - this.setupSocketListeners() - - await this.extensionManager.onReconnect(socket) - await this.taskManager.onReconnect(socket) - } - - // Task API - - public async subscribeToTask(task: TaskLike): Promise { - const socket = this.connectionManager.getSocket() - - if (!socket || !this.connectionManager.isConnected()) { - console.warn("[ExtensionBridgeService] Cannot subscribe to task: not connected. Will retry when connected.") - - this.taskManager.addPendingTask(task) - - const state = this.connectionManager.getConnectionState() - - if (state === ConnectionState.DISCONNECTED || state === ConnectionState.FAILED) { - this.initialize() - } - - return - } - - await this.taskManager.subscribeToTask(task, socket) - } - - public async unsubscribeFromTask(taskId: string): Promise { - const socket = this.connectionManager.getSocket() - - if (!socket) { - return - } - - await this.taskManager.unsubscribeFromTask(taskId, socket) - } - - // Shared API - - public getConnectionState(): ConnectionState { - return this.connectionManager.getConnectionState() - } - - public async disconnect(): Promise { - await this.extensionManager.cleanup(this.connectionManager.getSocket()) - await this.taskManager.cleanup(this.connectionManager.getSocket()) - await this.connectionManager.disconnect() - ExtensionBridgeService.instance = null - } - - public async reconnect(): Promise { - await this.connectionManager.reconnect() - - // After a manual reconnect, we have a new socket instance - // so we need to set up listeners again. - this.setupSocketListeners() - } -} diff --git a/packages/cloud/src/bridge/ExtensionChannel.ts b/packages/cloud/src/bridge/ExtensionChannel.ts new file mode 100644 index 000000000000..99649f76f4c3 --- /dev/null +++ b/packages/cloud/src/bridge/ExtensionChannel.ts @@ -0,0 +1,220 @@ +import type { Socket } from "socket.io-client" + +import { + type TaskProviderLike, + type TaskProviderEvents, + type ExtensionInstance, + type ExtensionBridgeCommand, + type ExtensionBridgeEvent, + RooCodeEventName, + TaskStatus, + ExtensionBridgeCommandName, + ExtensionBridgeEventName, + ExtensionSocketEvents, + HEARTBEAT_INTERVAL_MS, +} from "@roo-code/types" + +import { BaseChannel } from "./BaseChannel.js" + +/** + * Manages the extension-level communication channel. + * Handles extension registration, heartbeat, and extension-specific commands. + */ +export class ExtensionChannel extends BaseChannel< + ExtensionBridgeCommand, + ExtensionSocketEvents, + ExtensionBridgeEvent | ExtensionInstance +> { + private userId: string + private provider: TaskProviderLike + private extensionInstance: ExtensionInstance + private heartbeatInterval: NodeJS.Timeout | null = null + private eventListeners: Map void> = new Map() + + constructor(instanceId: string, userId: string, provider: TaskProviderLike) { + super(instanceId) + this.userId = userId + this.provider = provider + + this.extensionInstance = { + instanceId: this.instanceId, + userId: this.userId, + workspacePath: this.provider.cwd, + appProperties: this.provider.appProperties, + gitProperties: this.provider.gitProperties, + lastHeartbeat: Date.now(), + task: { + taskId: "", + taskStatus: TaskStatus.None, + }, + taskHistory: [], + } + + this.setupListeners() + } + + /** + * Handle extension-specific commands from the web app + */ + public handleCommand(command: ExtensionBridgeCommand): void { + if (command.instanceId !== this.instanceId) { + console.log(`[ExtensionChannel] command -> instance id mismatch | ${this.instanceId}`, { + messageInstanceId: command.instanceId, + }) + return + } + + switch (command.type) { + case ExtensionBridgeCommandName.StartTask: { + console.log(`[ExtensionChannel] command -> createTask() | ${command.instanceId}`, { + text: command.payload.text?.substring(0, 100) + "...", + hasImages: !!command.payload.images, + }) + + this.provider.createTask(command.payload.text, command.payload.images) + break + } + case ExtensionBridgeCommandName.StopTask: { + const instance = this.updateInstance() + + if (instance.task.taskStatus === TaskStatus.Running) { + console.log(`[ExtensionChannel] command -> cancelTask() | ${command.instanceId}`) + this.provider.cancelTask() + this.provider.postStateToWebview() + } else if (instance.task.taskId) { + console.log(`[ExtensionChannel] command -> clearTask() | ${command.instanceId}`) + this.provider.clearTask() + this.provider.postStateToWebview() + } + break + } + case ExtensionBridgeCommandName.ResumeTask: { + console.log(`[ExtensionChannel] command -> resumeTask() | ${command.instanceId}`, { + taskId: command.payload.taskId, + }) + + // Resume the task from history by taskId + this.provider.resumeTask(command.payload.taskId) + this.provider.postStateToWebview() + break + } + } + } + + protected async handleConnect(socket: Socket): Promise { + await this.registerInstance(socket) + this.startHeartbeat(socket) + } + + protected async handleReconnect(socket: Socket): Promise { + await this.registerInstance(socket) + this.startHeartbeat(socket) + } + + protected override handleDisconnect(): void { + this.stopHeartbeat() + } + + protected async handleCleanup(socket: Socket): Promise { + this.stopHeartbeat() + this.cleanupListeners() + await this.unregisterInstance(socket) + } + + private async registerInstance(_socket: Socket): Promise { + const instance = this.updateInstance() + await this.publish(ExtensionSocketEvents.REGISTER, instance) + } + + private async unregisterInstance(_socket: Socket): Promise { + const instance = this.updateInstance() + await this.publish(ExtensionSocketEvents.UNREGISTER, instance) + } + + private startHeartbeat(socket: Socket): void { + this.stopHeartbeat() + + this.heartbeatInterval = setInterval(async () => { + const instance = this.updateInstance() + + try { + socket.emit(ExtensionSocketEvents.HEARTBEAT, instance) + // Heartbeat is too frequent to log + } catch (error) { + console.error( + `[ExtensionChannel] emit() failed -> ${ExtensionSocketEvents.HEARTBEAT}: ${ + error instanceof Error ? error.message : String(error) + }`, + ) + } + }, HEARTBEAT_INTERVAL_MS) + } + + private stopHeartbeat(): void { + if (this.heartbeatInterval) { + clearInterval(this.heartbeatInterval) + this.heartbeatInterval = null + } + } + + private setupListeners(): void { + const eventMapping = [ + { from: RooCodeEventName.TaskCreated, to: ExtensionBridgeEventName.TaskCreated }, + { from: RooCodeEventName.TaskStarted, to: ExtensionBridgeEventName.TaskStarted }, + { from: RooCodeEventName.TaskCompleted, to: ExtensionBridgeEventName.TaskCompleted }, + { from: RooCodeEventName.TaskAborted, to: ExtensionBridgeEventName.TaskAborted }, + { from: RooCodeEventName.TaskFocused, to: ExtensionBridgeEventName.TaskFocused }, + { from: RooCodeEventName.TaskUnfocused, to: ExtensionBridgeEventName.TaskUnfocused }, + { from: RooCodeEventName.TaskActive, to: ExtensionBridgeEventName.TaskActive }, + { from: RooCodeEventName.TaskInteractive, to: ExtensionBridgeEventName.TaskInteractive }, + { from: RooCodeEventName.TaskResumable, to: ExtensionBridgeEventName.TaskResumable }, + { from: RooCodeEventName.TaskIdle, to: ExtensionBridgeEventName.TaskIdle }, + ] as const + + eventMapping.forEach(({ from, to }) => { + // Create and store the listener function for cleanup/ + const listener = (..._args: unknown[]) => { + this.publish(ExtensionSocketEvents.EVENT, { + type: to, + instance: this.updateInstance(), + timestamp: Date.now(), + }) + } + + this.eventListeners.set(from, listener) + this.provider.on(from, listener) + }) + } + + private cleanupListeners(): void { + this.eventListeners.forEach((listener, eventName) => { + // Cast is safe because we only store valid event names from eventMapping. + this.provider.off(eventName as keyof TaskProviderEvents, listener) + }) + + this.eventListeners.clear() + } + + private updateInstance(): ExtensionInstance { + const task = this.provider?.getCurrentTask() + const taskHistory = this.provider?.getRecentTasks() ?? [] + + this.extensionInstance = { + ...this.extensionInstance, + appProperties: this.extensionInstance.appProperties ?? this.provider.appProperties, + gitProperties: this.extensionInstance.gitProperties ?? this.provider.gitProperties, + lastHeartbeat: Date.now(), + task: task + ? { + taskId: task.taskId, + taskStatus: task.taskStatus, + ...task.metadata, + } + : { taskId: "", taskStatus: TaskStatus.None }, + taskAsk: task?.taskAsk, + taskHistory, + } + + return this.extensionInstance + } +} diff --git a/packages/cloud/src/bridge/ExtensionManager.ts b/packages/cloud/src/bridge/ExtensionManager.ts deleted file mode 100644 index 335245e24c81..000000000000 --- a/packages/cloud/src/bridge/ExtensionManager.ts +++ /dev/null @@ -1,297 +0,0 @@ -import type { Socket } from "socket.io-client" - -import { - type TaskProviderLike, - type ExtensionInstance, - type ExtensionBridgeCommand, - type ExtensionBridgeEvent, - RooCodeEventName, - TaskStatus, - ExtensionBridgeCommandName, - ExtensionBridgeEventName, - ExtensionSocketEvents, - HEARTBEAT_INTERVAL_MS, -} from "@roo-code/types" - -export class ExtensionManager { - private instanceId: string - private userId: string - private provider: TaskProviderLike - private extensionInstance: ExtensionInstance - private heartbeatInterval: NodeJS.Timeout | null = null - private socket: Socket | null = null - - constructor(instanceId: string, userId: string, provider: TaskProviderLike) { - this.instanceId = instanceId - this.userId = userId - this.provider = provider - - this.extensionInstance = { - instanceId: this.instanceId, - userId: this.userId, - workspacePath: this.provider.cwd, - appProperties: this.provider.appProperties, - gitProperties: this.provider.gitProperties, - lastHeartbeat: Date.now(), - task: { - taskId: "", - taskStatus: TaskStatus.None, - }, - taskHistory: [], - } - - this.setupListeners() - } - - public async onConnect(socket: Socket): Promise { - this.socket = socket - await this.registerInstance(socket) - this.startHeartbeat(socket) - } - - public onDisconnect(): void { - this.stopHeartbeat() - this.socket = null - } - - public async onReconnect(socket: Socket): Promise { - this.socket = socket - await this.registerInstance(socket) - this.startHeartbeat(socket) - } - - public async cleanup(socket: Socket | null): Promise { - this.stopHeartbeat() - - if (socket) { - await this.unregisterInstance(socket) - } - - this.socket = null - } - - public handleExtensionCommand(message: ExtensionBridgeCommand): void { - if (message.instanceId !== this.instanceId) { - console.log(`[ExtensionManager] command -> instance id mismatch | ${this.instanceId}`, { - messageInstanceId: message.instanceId, - }) - - return - } - - switch (message.type) { - case ExtensionBridgeCommandName.StartTask: { - console.log(`[ExtensionManager] command -> createTask() | ${message.instanceId}`, { - text: message.payload.text?.substring(0, 100) + "...", - hasImages: !!message.payload.images, - }) - - this.provider.createTask(message.payload.text, message.payload.images) - - break - } - case ExtensionBridgeCommandName.StopTask: { - const instance = this.updateInstance() - - if (instance.task.taskStatus === TaskStatus.Running) { - console.log(`[ExtensionManager] command -> cancelTask() | ${message.instanceId}`) - - this.provider.cancelTask() - this.provider.postStateToWebview() - } else if (instance.task.taskId) { - console.log(`[ExtensionManager] command -> clearTask() | ${message.instanceId}`) - - this.provider.clearTask() - this.provider.postStateToWebview() - } - - break - } - case ExtensionBridgeCommandName.ResumeTask: { - console.log(`[ExtensionManager] command -> resumeTask() | ${message.instanceId}`, { - taskId: message.payload.taskId, - }) - - // Resume the task from history by taskId - this.provider.resumeTask(message.payload.taskId) - - this.provider.postStateToWebview() - - break - } - } - } - - private async registerInstance(socket: Socket): Promise { - const instance = this.updateInstance() - - try { - socket.emit(ExtensionSocketEvents.REGISTER, instance) - - console.log( - `[ExtensionManager] emit() -> ${ExtensionSocketEvents.REGISTER}`, - // instance, - ) - } catch (error) { - console.error( - `[ExtensionManager] emit() failed -> ${ExtensionSocketEvents.REGISTER}: ${ - error instanceof Error ? error.message : String(error) - }`, - ) - - return - } - } - - private async unregisterInstance(socket: Socket): Promise { - const instance = this.updateInstance() - - try { - socket.emit(ExtensionSocketEvents.UNREGISTER, instance) - - console.log( - `[ExtensionManager] emit() -> ${ExtensionSocketEvents.UNREGISTER}`, - // instance, - ) - } catch (error) { - console.error( - `[ExtensionManager] emit() failed -> ${ExtensionSocketEvents.UNREGISTER}: ${ - error instanceof Error ? error.message : String(error) - }`, - ) - } - } - - private startHeartbeat(socket: Socket): void { - this.stopHeartbeat() - - this.heartbeatInterval = setInterval(async () => { - const instance = this.updateInstance() - - try { - socket.emit(ExtensionSocketEvents.HEARTBEAT, instance) - - // console.log( - // `[ExtensionManager] emit() -> ${ExtensionSocketEvents.HEARTBEAT}`, - // instance, - // ); - } catch (error) { - console.error( - `[ExtensionManager] emit() failed -> ${ExtensionSocketEvents.HEARTBEAT}: ${ - error instanceof Error ? error.message : String(error) - }`, - ) - } - }, HEARTBEAT_INTERVAL_MS) - } - - private stopHeartbeat(): void { - if (this.heartbeatInterval) { - clearInterval(this.heartbeatInterval) - this.heartbeatInterval = null - } - } - - private setupListeners(): void { - const eventMapping = [ - { - from: RooCodeEventName.TaskCreated, - to: ExtensionBridgeEventName.TaskCreated, - }, - { - from: RooCodeEventName.TaskStarted, - to: ExtensionBridgeEventName.TaskStarted, - }, - { - from: RooCodeEventName.TaskCompleted, - to: ExtensionBridgeEventName.TaskCompleted, - }, - { - from: RooCodeEventName.TaskAborted, - to: ExtensionBridgeEventName.TaskAborted, - }, - { - from: RooCodeEventName.TaskFocused, - to: ExtensionBridgeEventName.TaskFocused, - }, - { - from: RooCodeEventName.TaskUnfocused, - to: ExtensionBridgeEventName.TaskUnfocused, - }, - { - from: RooCodeEventName.TaskActive, - to: ExtensionBridgeEventName.TaskActive, - }, - { - from: RooCodeEventName.TaskInteractive, - to: ExtensionBridgeEventName.TaskInteractive, - }, - { - from: RooCodeEventName.TaskResumable, - to: ExtensionBridgeEventName.TaskResumable, - }, - { - from: RooCodeEventName.TaskIdle, - to: ExtensionBridgeEventName.TaskIdle, - }, - ] as const - - const addListener = - (type: ExtensionBridgeEventName) => - async (..._args: unknown[]) => { - this.publishEvent({ - type, - instance: this.updateInstance(), - timestamp: Date.now(), - }) - } - - eventMapping.forEach(({ from, to }) => this.provider.on(from, addListener(to))) - } - - private async publishEvent(message: ExtensionBridgeEvent): Promise { - if (!this.socket) { - console.error("[ExtensionManager] publishEvent -> socket not available") - return false - } - - try { - this.socket.emit(ExtensionSocketEvents.EVENT, message) - - console.log(`[ExtensionManager] emit() -> ${ExtensionSocketEvents.EVENT} ${message.type}`, message) - - return true - } catch (error) { - console.error( - `[ExtensionManager] emit() failed -> ${ExtensionSocketEvents.EVENT}: ${ - error instanceof Error ? error.message : String(error) - }`, - ) - - return false - } - } - - private updateInstance(): ExtensionInstance { - const task = this.provider?.getCurrentTask() - const taskHistory = this.provider?.getRecentTasks() ?? [] - - this.extensionInstance = { - ...this.extensionInstance, - appProperties: this.extensionInstance.appProperties ?? this.provider.appProperties, - gitProperties: this.extensionInstance.gitProperties ?? this.provider.gitProperties, - lastHeartbeat: Date.now(), - task: task - ? { - taskId: task.taskId, - taskStatus: task.taskStatus, - ...task.metadata, - } - : { taskId: "", taskStatus: TaskStatus.None }, - taskAsk: task?.taskAsk, - taskHistory, - } - - return this.extensionInstance - } -} diff --git a/packages/cloud/src/bridge/SocketConnectionManager.ts b/packages/cloud/src/bridge/SocketTransport.ts similarity index 72% rename from packages/cloud/src/bridge/SocketConnectionManager.ts rename to packages/cloud/src/bridge/SocketTransport.ts index 3ba9631fec21..5fb40e989c8a 100644 --- a/packages/cloud/src/bridge/SocketConnectionManager.ts +++ b/packages/cloud/src/bridge/SocketTransport.ts @@ -1,10 +1,10 @@ -import { io, type Socket } from "socket.io-client" +import { io, type Socket, type SocketOptions, type ManagerOptions } from "socket.io-client" import { ConnectionState, type RetryConfig } from "@roo-code/types" -export interface SocketConnectionOptions { +export interface SocketTransportOptions { url: string - socketOptions: Record + socketOptions: Partial onConnect?: () => void | Promise onDisconnect?: (reason: string) => void onReconnect?: (attemptNumber: number) => void | Promise @@ -16,7 +16,11 @@ export interface SocketConnectionOptions { } } -export class SocketConnectionManager { +/** + * Manages the WebSocket transport layer for the bridge system. + * Handles connection lifecycle, retries, and reconnection logic. + */ +export class SocketTransport { private socket: Socket | null = null private connectionState: ConnectionState = ConnectionState.DISCONNECTED private retryAttempt: number = 0 @@ -31,9 +35,9 @@ export class SocketConnectionManager { } private readonly CONNECTION_TIMEOUT = 2_000 - private readonly options: SocketConnectionOptions + private readonly options: SocketTransportOptions - constructor(options: SocketConnectionOptions, retryConfig?: Partial) { + constructor(options: SocketTransportOptions, retryConfig?: Partial) { this.options = options if (retryConfig) { @@ -43,13 +47,12 @@ export class SocketConnectionManager { public async connect(): Promise { if (this.connectionState === ConnectionState.CONNECTED) { - console.log(`[SocketConnectionManager] Already connected`) + console.log(`[SocketTransport] Already connected`) return } if (this.connectionState === ConnectionState.CONNECTING || this.connectionState === ConnectionState.RETRYING) { - console.log(`[SocketConnectionManager] Connection attempt already in progress`) - + console.log(`[SocketTransport] Connection attempt already in progress`) return } @@ -63,7 +66,9 @@ export class SocketConnectionManager { try { await this.connectWithRetry() } catch (error) { - console.error(`[SocketConnectionManager] Initial connection attempts failed:`, error) + console.error( + `[SocketTransport] Initial connection attempts failed: ${error instanceof Error ? error.message : String(error)}`, + ) // If we've never connected successfully, we've exhausted our retry attempts // The user will need to manually retry or fix the issue @@ -79,12 +84,12 @@ export class SocketConnectionManager { this.connectionState = this.retryAttempt === 0 ? ConnectionState.CONNECTING : ConnectionState.RETRYING console.log( - `[SocketConnectionManager] Connection attempt ${this.retryAttempt + 1} / ${this.retryConfig.maxInitialAttempts}`, + `[SocketTransport] Connection attempt ${this.retryAttempt + 1} / ${this.retryConfig.maxInitialAttempts}`, ) await this.connectSocket() - console.log(`[SocketConnectionManager] Connected to ${this.options.url}`) + console.log(`[SocketTransport] Connected to ${this.options.url}`) this.connectionState = ConnectionState.CONNECTED this.retryAttempt = 0 @@ -99,7 +104,7 @@ export class SocketConnectionManager { } catch (error) { this.retryAttempt++ - console.error(`[SocketConnectionManager] Connection attempt ${this.retryAttempt} failed:`, error) + console.error(`[SocketTransport] Connection attempt ${this.retryAttempt} failed:`, error) if (this.socket) { this.socket.disconnect() @@ -112,7 +117,7 @@ export class SocketConnectionManager { throw new Error(`Failed to connect after ${this.retryConfig.maxInitialAttempts} attempts`) } - console.log(`[SocketConnectionManager] Waiting ${delay}ms before retry...`) + console.log(`[SocketTransport] Waiting ${delay}ms before retry...`) await this.delay(delay) @@ -126,7 +131,7 @@ export class SocketConnectionManager { this.socket = io(this.options.url, this.options.socketOptions) const connectionTimeout = setTimeout(() => { - console.error(`[SocketConnectionManager] Connection timeout`) + console.error(`[SocketTransport] Connection timeout`) if (this.connectionState !== ConnectionState.CONNECTED) { this.socket?.disconnect() @@ -140,12 +145,9 @@ export class SocketConnectionManager { const isReconnection = this.hasConnectedOnce // If this is a reconnection (not the first connect), treat it as a - // reconnect. - // This handles server restarts where 'reconnect' event might not fire. + // reconnect. This handles server restarts where 'reconnect' event might not fire. if (isReconnection) { - console.log( - `[SocketConnectionManager] Treating connect as reconnection (server may have restarted)`, - ) + console.log(`[SocketTransport] Treating connect as reconnection (server may have restarted)`) this.connectionState = ConnectionState.CONNECTED @@ -160,7 +162,7 @@ export class SocketConnectionManager { }) this.socket.on("disconnect", (reason: string) => { - console.log(`[SocketConnectionManager] Disconnected (reason: ${reason})`) + console.log(`[SocketTransport] Disconnected (reason: ${reason})`) this.connectionState = ConnectionState.DISCONNECTED @@ -174,19 +176,19 @@ export class SocketConnectionManager { if (!isManualDisconnect && this.hasConnectedOnce) { // After successful initial connection, rely entirely on Socket.IO's // reconnection. - console.log(`[SocketConnectionManager] Socket.IO will handle reconnection (reason: ${reason})`) + console.log(`[SocketTransport] Socket.IO will handle reconnection (reason: ${reason})`) } }) // Listen for reconnection attempts. this.socket.on("reconnect_attempt", (attemptNumber: number) => { - console.log(`[SocketConnectionManager] Socket.IO reconnect attempt:`, { + console.log(`[SocketTransport] Socket.IO reconnect attempt:`, { attemptNumber, }) }) this.socket.on("reconnect", (attemptNumber: number) => { - console.log(`[SocketConnectionManager] Socket reconnected (attempt: ${attemptNumber})`) + console.log(`[SocketTransport] Socket reconnected (attempt: ${attemptNumber})`) this.connectionState = ConnectionState.CONNECTED @@ -196,11 +198,11 @@ export class SocketConnectionManager { }) this.socket.on("reconnect_error", (error: Error) => { - console.error(`[SocketConnectionManager] Socket.IO reconnect error:`, error) + console.error(`[SocketTransport] Socket.IO reconnect error:`, error) }) this.socket.on("reconnect_failed", () => { - console.error(`[SocketConnectionManager] Socket.IO reconnection failed after all attempts`) + console.error(`[SocketTransport] Socket.IO reconnection failed after all attempts`) this.connectionState = ConnectionState.FAILED @@ -209,7 +211,7 @@ export class SocketConnectionManager { }) this.socket.on("error", (error) => { - console.error(`[SocketConnectionManager] Socket error:`, error) + console.error(`[SocketTransport] Socket error:`, error) if (this.connectionState !== ConnectionState.CONNECTED) { clearTimeout(connectionTimeout) @@ -222,7 +224,7 @@ export class SocketConnectionManager { }) this.socket.on("auth_error", (error) => { - console.error(`[SocketConnectionManager] Authentication error:`, error) + console.error(`[SocketTransport] Authentication error:`, error) clearTimeout(connectionTimeout) reject(new Error(error.message || "Authentication failed")) }) @@ -235,9 +237,6 @@ export class SocketConnectionManager { }) } - // 1. Custom retry for initial connection attempts. - // 2. Socket.IO's built-in reconnection after successful initial connection. - private clearRetryTimeouts() { if (this.retryTimeout) { clearTimeout(this.retryTimeout) @@ -246,7 +245,7 @@ export class SocketConnectionManager { } public async disconnect(): Promise { - console.log(`[SocketConnectionManager] Disconnecting...`) + console.log(`[SocketTransport] Disconnecting...`) this.clearRetryTimeouts() @@ -258,7 +257,7 @@ export class SocketConnectionManager { this.connectionState = ConnectionState.DISCONNECTED - console.log(`[SocketConnectionManager] Disconnected`) + console.log(`[SocketTransport] Disconnected`) } public getSocket(): Socket | null { @@ -275,11 +274,11 @@ export class SocketConnectionManager { public async reconnect(): Promise { if (this.connectionState === ConnectionState.CONNECTED) { - console.log(`[SocketConnectionManager] Already connected`) + console.log(`[SocketTransport] Already connected`) return } - console.log(`[SocketConnectionManager] Manual reconnection requested`) + console.log(`[SocketTransport] Manual reconnection requested`) this.hasConnectedOnce = false diff --git a/packages/cloud/src/bridge/TaskChannel.ts b/packages/cloud/src/bridge/TaskChannel.ts new file mode 100644 index 000000000000..f4656dc6d2f6 --- /dev/null +++ b/packages/cloud/src/bridge/TaskChannel.ts @@ -0,0 +1,228 @@ +import type { Socket } from "socket.io-client" + +import { + type ClineMessage, + type TaskEvents, + type TaskLike, + type TaskBridgeCommand, + type TaskBridgeEvent, + type JoinResponse, + type LeaveResponse, + RooCodeEventName, + TaskBridgeEventName, + TaskBridgeCommandName, + TaskSocketEvents, +} from "@roo-code/types" + +import { BaseChannel } from "./BaseChannel.js" + +type TaskEventListener = { + [K in keyof TaskEvents]: (...args: TaskEvents[K]) => void | Promise +}[keyof TaskEvents] + +type TaskEventMapping = { + from: keyof TaskEvents + to: TaskBridgeEventName + createPayload: (task: TaskLike, ...args: any[]) => any // eslint-disable-line @typescript-eslint/no-explicit-any +} + +/** + * Manages task-level communication channels. + * Handles task subscriptions, messaging, and task-specific commands. + */ +export class TaskChannel extends BaseChannel< + TaskBridgeCommand, + TaskSocketEvents, + TaskBridgeEvent | { taskId: string } +> { + private subscribedTasks: Map = new Map() + private pendingTasks: Map = new Map() + private taskListeners: Map> = new Map() + + private readonly eventMapping: readonly TaskEventMapping[] = [ + { + from: RooCodeEventName.Message, + to: TaskBridgeEventName.Message, + createPayload: (task: TaskLike, data: { action: string; message: ClineMessage }) => ({ + type: TaskBridgeEventName.Message, + taskId: task.taskId, + action: data.action, + message: data.message, + }), + }, + { + from: RooCodeEventName.TaskModeSwitched, + to: TaskBridgeEventName.TaskModeSwitched, + createPayload: (task: TaskLike, mode: string) => ({ + type: TaskBridgeEventName.TaskModeSwitched, + taskId: task.taskId, + mode, + }), + }, + { + from: RooCodeEventName.TaskInteractive, + to: TaskBridgeEventName.TaskInteractive, + createPayload: (task: TaskLike, _taskId: string) => ({ + type: TaskBridgeEventName.TaskInteractive, + taskId: task.taskId, + }), + }, + ] as const + + constructor(instanceId: string) { + super(instanceId) + } + + public handleCommand(command: TaskBridgeCommand): void { + const task = this.subscribedTasks.get(command.taskId) + + if (!task) { + console.error(`[TaskChannel] Unable to find task ${command.taskId}`) + return + } + + switch (command.type) { + case TaskBridgeCommandName.Message: + console.log( + `[TaskChannel] ${TaskBridgeCommandName.Message} ${command.taskId} -> submitUserMessage()`, + command, + ) + task.submitUserMessage(command.payload.text, command.payload.images) + break + + case TaskBridgeCommandName.ApproveAsk: + console.log( + `[TaskChannel] ${TaskBridgeCommandName.ApproveAsk} ${command.taskId} -> approveAsk()`, + command, + ) + task.approveAsk(command.payload) + break + + case TaskBridgeCommandName.DenyAsk: + console.log(`[TaskChannel] ${TaskBridgeCommandName.DenyAsk} ${command.taskId} -> denyAsk()`, command) + task.denyAsk(command.payload) + break + } + } + + protected async handleConnect(socket: Socket): Promise { + // Rejoin all subscribed tasks. + for (const taskId of this.subscribedTasks.keys()) { + await this.publish(TaskSocketEvents.JOIN, { taskId }) + } + + // Subscribe to any pending tasks. + for (const task of this.pendingTasks.values()) { + await this.subscribeToTask(task, socket) + } + + this.pendingTasks.clear() + } + + protected async handleReconnect(_socket: Socket): Promise { + // Rejoin all subscribed tasks. + for (const taskId of this.subscribedTasks.keys()) { + await this.publish(TaskSocketEvents.JOIN, { taskId }) + } + } + + protected async handleCleanup(socket: Socket): Promise { + const unsubscribePromises = [] + + for (const taskId of this.subscribedTasks.keys()) { + unsubscribePromises.push(this.unsubscribeFromTask(taskId, socket)) + } + + await Promise.allSettled(unsubscribePromises) + this.subscribedTasks.clear() + this.taskListeners.clear() + this.pendingTasks.clear() + } + + /** + * Add a task to the pending queue (will be subscribed when connected). + */ + public addPendingTask(task: TaskLike): void { + this.pendingTasks.set(task.taskId, task) + } + + public async subscribeToTask(task: TaskLike, _socket: Socket): Promise { + const taskId = task.taskId + + await this.publish(TaskSocketEvents.JOIN, { taskId }, (response: JoinResponse) => { + if (response.success) { + console.log(`[TaskChannel#subscribeToTask] subscribed to ${taskId}`) + this.subscribedTasks.set(taskId, task) + this.setupTaskListeners(task) + } else { + console.error(`[TaskChannel#subscribeToTask] failed to subscribe to ${taskId}: ${response.error}`) + } + }) + } + + public async unsubscribeFromTask(taskId: string, _socket: Socket): Promise { + const task = this.subscribedTasks.get(taskId) + + await this.publish(TaskSocketEvents.LEAVE, { taskId }, (response: LeaveResponse) => { + if (response.success) { + console.log(`[TaskChannel#unsubscribeFromTask] unsubscribed from ${taskId}`, response) + } else { + console.error(`[TaskChannel#unsubscribeFromTask] failed to unsubscribe from ${taskId}`) + } + + // If we failed to unsubscribe then something is probably wrong and + // we should still discard this task from `subscribedTasks`. + if (task) { + this.removeTaskListeners(task) + this.subscribedTasks.delete(taskId) + } + }) + } + + private setupTaskListeners(task: TaskLike): void { + if (this.taskListeners.has(task.taskId)) { + console.warn("[TaskChannel] Listeners already exist for task, removing old listeners:", task.taskId) + this.removeTaskListeners(task) + } + + const listeners = new Map() + + this.eventMapping.forEach(({ from, to, createPayload }) => { + const listener = (...args: unknown[]) => { + const payload = createPayload(task, ...args) + this.publish(TaskSocketEvents.EVENT, payload) + } + + // eslint-disable-next-line @typescript-eslint/no-explicit-any + task.on(from, listener as any) + listeners.set(to, listener) + }) + + this.taskListeners.set(task.taskId, listeners) + } + + private removeTaskListeners(task: TaskLike): void { + const listeners = this.taskListeners.get(task.taskId) + + if (!listeners) { + return + } + + this.eventMapping.forEach(({ from, to }) => { + const listener = listeners.get(to) + if (listener) { + try { + task.off(from, listener as any) // eslint-disable-line @typescript-eslint/no-explicit-any + } catch (error) { + console.error( + `[TaskChannel] task.off(${from}) failed for task ${task.taskId}: ${ + error instanceof Error ? error.message : String(error) + }`, + ) + } + } + }) + + this.taskListeners.delete(task.taskId) + } +} diff --git a/packages/cloud/src/bridge/TaskManager.ts b/packages/cloud/src/bridge/TaskManager.ts deleted file mode 100644 index 3940d59f259b..000000000000 --- a/packages/cloud/src/bridge/TaskManager.ts +++ /dev/null @@ -1,279 +0,0 @@ -import type { Socket } from "socket.io-client" - -import { - type ClineMessage, - type TaskEvents, - type TaskLike, - type TaskBridgeCommand, - type TaskBridgeEvent, - RooCodeEventName, - TaskBridgeEventName, - TaskBridgeCommandName, - TaskSocketEvents, -} from "@roo-code/types" - -type TaskEventListener = { - [K in keyof TaskEvents]: (...args: TaskEvents[K]) => void | Promise -}[keyof TaskEvents] - -const TASK_EVENT_MAPPING: Record = { - [TaskBridgeEventName.Message]: RooCodeEventName.Message, - [TaskBridgeEventName.TaskModeSwitched]: RooCodeEventName.TaskModeSwitched, - [TaskBridgeEventName.TaskInteractive]: RooCodeEventName.TaskInteractive, -} - -export class TaskManager { - private subscribedTasks: Map = new Map() - private pendingTasks: Map = new Map() - private socket: Socket | null = null - - private taskListeners: Map> = new Map() - - constructor() {} - - public async onConnect(socket: Socket): Promise { - this.socket = socket - - // Rejoin all subscribed tasks. - for (const taskId of this.subscribedTasks.keys()) { - try { - socket.emit(TaskSocketEvents.JOIN, { taskId }) - - console.log(`[TaskManager] emit() -> ${TaskSocketEvents.JOIN} ${taskId}`) - } catch (error) { - console.error( - `[TaskManager] emit() failed -> ${TaskSocketEvents.JOIN}: ${ - error instanceof Error ? error.message : String(error) - }`, - ) - } - } - - // Subscribe to any pending tasks. - for (const task of this.pendingTasks.values()) { - await this.subscribeToTask(task, socket) - } - - this.pendingTasks.clear() - } - - public onDisconnect(): void { - this.socket = null - } - - public async onReconnect(socket: Socket): Promise { - this.socket = socket - - // Rejoin all subscribed tasks. - for (const taskId of this.subscribedTasks.keys()) { - try { - socket.emit(TaskSocketEvents.JOIN, { taskId }) - - console.log(`[TaskManager] emit() -> ${TaskSocketEvents.JOIN} ${taskId}`) - } catch (error) { - console.error( - `[TaskManager] emit() failed -> ${TaskSocketEvents.JOIN}: ${ - error instanceof Error ? error.message : String(error) - }`, - ) - } - } - } - - public async cleanup(socket: Socket | null): Promise { - if (!socket) { - return - } - - const unsubscribePromises = [] - - for (const taskId of this.subscribedTasks.keys()) { - unsubscribePromises.push(this.unsubscribeFromTask(taskId, socket)) - } - - await Promise.allSettled(unsubscribePromises) - this.subscribedTasks.clear() - this.taskListeners.clear() - this.pendingTasks.clear() - this.socket = null - } - - public addPendingTask(task: TaskLike): void { - this.pendingTasks.set(task.taskId, task) - } - - public async subscribeToTask(task: TaskLike, socket: Socket): Promise { - const taskId = task.taskId - this.subscribedTasks.set(taskId, task) - this.setupListeners(task) - - try { - socket.emit(TaskSocketEvents.JOIN, { taskId }) - console.log(`[TaskManager] emit() -> ${TaskSocketEvents.JOIN} ${taskId}`) - } catch (error) { - console.error( - `[TaskManager] emit() failed -> ${TaskSocketEvents.JOIN}: ${ - error instanceof Error ? error.message : String(error) - }`, - ) - } - } - - public async unsubscribeFromTask(taskId: string, socket: Socket): Promise { - const task = this.subscribedTasks.get(taskId) - - if (task) { - this.removeListeners(task) - this.subscribedTasks.delete(taskId) - } - - try { - socket.emit(TaskSocketEvents.LEAVE, { taskId }) - - console.log(`[TaskManager] emit() -> ${TaskSocketEvents.LEAVE} ${taskId}`) - } catch (error) { - console.error( - `[TaskManager] emit() failed -> ${TaskSocketEvents.LEAVE}: ${ - error instanceof Error ? error.message : String(error) - }`, - ) - } - } - - public handleTaskCommand(message: TaskBridgeCommand): void { - const task = this.subscribedTasks.get(message.taskId) - - if (!task) { - console.error(`[TaskManager#handleTaskCommand] Unable to find task ${message.taskId}`) - - return - } - - switch (message.type) { - case TaskBridgeCommandName.Message: - console.log( - `[TaskManager#handleTaskCommand] ${TaskBridgeCommandName.Message} ${message.taskId} -> submitUserMessage()`, - message, - ) - - task.submitUserMessage(message.payload.text, message.payload.images) - break - case TaskBridgeCommandName.ApproveAsk: - console.log( - `[TaskManager#handleTaskCommand] ${TaskBridgeCommandName.ApproveAsk} ${message.taskId} -> approveAsk()`, - message, - ) - - task.approveAsk(message.payload) - break - case TaskBridgeCommandName.DenyAsk: - console.log( - `[TaskManager#handleTaskCommand] ${TaskBridgeCommandName.DenyAsk} ${message.taskId} -> denyAsk()`, - message, - ) - - task.denyAsk(message.payload) - break - } - } - - private setupListeners(task: TaskLike): void { - if (this.taskListeners.has(task.taskId)) { - console.warn("[TaskManager] Listeners already exist for task, removing old listeners:", task.taskId) - - this.removeListeners(task) - } - - const listeners = new Map() - - const onMessage = ({ action, message }: { action: string; message: ClineMessage }) => { - this.publishEvent({ - type: TaskBridgeEventName.Message, - taskId: task.taskId, - action, - message, - }) - } - - task.on(RooCodeEventName.Message, onMessage) - listeners.set(TaskBridgeEventName.Message, onMessage) - - const onTaskModeSwitched = (mode: string) => { - this.publishEvent({ - type: TaskBridgeEventName.TaskModeSwitched, - taskId: task.taskId, - mode, - }) - } - - task.on(RooCodeEventName.TaskModeSwitched, onTaskModeSwitched) - listeners.set(TaskBridgeEventName.TaskModeSwitched, onTaskModeSwitched) - - const onTaskInteractive = (_taskId: string) => { - this.publishEvent({ - type: TaskBridgeEventName.TaskInteractive, - taskId: task.taskId, - }) - } - - task.on(RooCodeEventName.TaskInteractive, onTaskInteractive) - - listeners.set(TaskBridgeEventName.TaskInteractive, onTaskInteractive) - - this.taskListeners.set(task.taskId, listeners) - - console.log("[TaskManager] Task listeners setup complete for:", task.taskId) - } - - private removeListeners(task: TaskLike): void { - const listeners = this.taskListeners.get(task.taskId) - - if (!listeners) { - return - } - - console.log("[TaskManager] Removing task listeners for:", task.taskId) - - listeners.forEach((listener, eventName) => { - try { - // eslint-disable-next-line @typescript-eslint/no-explicit-any - task.off(TASK_EVENT_MAPPING[eventName], listener as any) - } catch (error) { - console.error( - `[TaskManager] Error removing listener for ${String(eventName)} on task ${task.taskId}:`, - error, - ) - } - }) - - this.taskListeners.delete(task.taskId) - } - - private async publishEvent(message: TaskBridgeEvent): Promise { - if (!this.socket) { - console.error("[TaskManager] publishEvent -> socket not available") - return false - } - - try { - this.socket.emit(TaskSocketEvents.EVENT, message) - - if (message.type !== TaskBridgeEventName.Message) { - console.log( - `[TaskManager] emit() -> ${TaskSocketEvents.EVENT} ${message.taskId} ${message.type}`, - message, - ) - } - - return true - } catch (error) { - console.error( - `[TaskManager] emit() failed -> ${TaskSocketEvents.EVENT}: ${ - error instanceof Error ? error.message : String(error) - }`, - ) - - return false - } - } -} diff --git a/packages/cloud/src/bridge/__tests__/ExtensionChannel.test.ts b/packages/cloud/src/bridge/__tests__/ExtensionChannel.test.ts new file mode 100644 index 000000000000..89979c9a66e9 --- /dev/null +++ b/packages/cloud/src/bridge/__tests__/ExtensionChannel.test.ts @@ -0,0 +1,252 @@ +/* eslint-disable @typescript-eslint/no-explicit-any */ + +import type { Socket } from "socket.io-client" + +import { + type TaskProviderLike, + type TaskProviderEvents, + RooCodeEventName, + ExtensionBridgeEventName, + ExtensionSocketEvents, +} from "@roo-code/types" + +import { ExtensionChannel } from "../ExtensionChannel.js" + +describe("ExtensionChannel", () => { + let mockSocket: Socket + let mockProvider: TaskProviderLike + let extensionChannel: ExtensionChannel + const instanceId = "test-instance-123" + const userId = "test-user-456" + + // Track registered event listeners + const eventListeners = new Map unknown>>() + + beforeEach(() => { + // Reset the event listeners tracker + eventListeners.clear() + + // Create mock socket + mockSocket = { + emit: vi.fn(), + on: vi.fn(), + off: vi.fn(), + disconnect: vi.fn(), + } as unknown as Socket + + // Create mock provider with event listener tracking + mockProvider = { + cwd: "/test/workspace", + appProperties: { + version: "1.0.0", + extensionVersion: "1.0.0", + }, + gitProperties: undefined, + getCurrentTask: vi.fn().mockReturnValue(undefined), + getCurrentTaskStack: vi.fn().mockReturnValue([]), + getRecentTasks: vi.fn().mockReturnValue([]), + createTask: vi.fn(), + cancelTask: vi.fn(), + clearTask: vi.fn(), + resumeTask: vi.fn(), + getState: vi.fn(), + postStateToWebview: vi.fn(), + postMessageToWebview: vi.fn(), + getTelemetryProperties: vi.fn(), + on: vi.fn((event: keyof TaskProviderEvents, listener: (...args: unknown[]) => unknown) => { + if (!eventListeners.has(event)) { + eventListeners.set(event, new Set()) + } + eventListeners.get(event)!.add(listener) + return mockProvider + }), + off: vi.fn((event: keyof TaskProviderEvents, listener: (...args: unknown[]) => unknown) => { + const listeners = eventListeners.get(event) + if (listeners) { + listeners.delete(listener) + if (listeners.size === 0) { + eventListeners.delete(event) + } + } + return mockProvider + }), + } as unknown as TaskProviderLike + + // Create extension channel instance + extensionChannel = new ExtensionChannel(instanceId, userId, mockProvider) + }) + + afterEach(() => { + vi.clearAllMocks() + }) + + describe("Event Listener Management", () => { + it("should register event listeners on initialization", () => { + // Verify that listeners were registered for all expected events + const expectedEvents: RooCodeEventName[] = [ + RooCodeEventName.TaskCreated, + RooCodeEventName.TaskStarted, + RooCodeEventName.TaskCompleted, + RooCodeEventName.TaskAborted, + RooCodeEventName.TaskFocused, + RooCodeEventName.TaskUnfocused, + RooCodeEventName.TaskActive, + RooCodeEventName.TaskInteractive, + RooCodeEventName.TaskResumable, + RooCodeEventName.TaskIdle, + ] + + // Check that on() was called for each event + expect(mockProvider.on).toHaveBeenCalledTimes(expectedEvents.length) + + // Verify each event was registered + expectedEvents.forEach((eventName) => { + expect(mockProvider.on).toHaveBeenCalledWith(eventName, expect.any(Function)) + }) + + // Verify listeners are tracked in our Map + expect(eventListeners.size).toBe(expectedEvents.length) + }) + + it("should remove all event listeners during cleanup", async () => { + // Verify initial state - listeners are registered + const initialListenerCount = eventListeners.size + expect(initialListenerCount).toBeGreaterThan(0) + + // Get the count of listeners for each event before cleanup + const listenerCountsBefore = new Map() + eventListeners.forEach((listeners, event) => { + listenerCountsBefore.set(event, listeners.size) + }) + + // Perform cleanup + await extensionChannel.cleanup(mockSocket) + + // Verify that off() was called for each registered event + expect(mockProvider.off).toHaveBeenCalledTimes(initialListenerCount) + + // Verify all listeners were removed from our tracking Map + expect(eventListeners.size).toBe(0) + + // Verify that the same listener functions that were added were removed + const onCalls = (mockProvider.on as any).mock.calls + const offCalls = (mockProvider.off as any).mock.calls + + // Each on() call should have a corresponding off() call with the same listener + onCalls.forEach(([eventName, listener]: [keyof TaskProviderEvents, any]) => { + const hasMatchingOff = offCalls.some( + ([offEvent, offListener]: [keyof TaskProviderEvents, any]) => + offEvent === eventName && offListener === listener, + ) + expect(hasMatchingOff).toBe(true) + }) + }) + + it("should not have duplicate listeners after multiple channel creations", () => { + // Create a second channel with the same provider + const secondChannel = new ExtensionChannel("instance-2", userId, mockProvider) + + // Each event should have exactly 2 listeners (one from each channel) + eventListeners.forEach((listeners) => { + expect(listeners.size).toBe(2) + }) + + // Clean up the first channel + extensionChannel.cleanup(mockSocket) + + // Each event should now have exactly 1 listener (from the second channel) + eventListeners.forEach((listeners) => { + expect(listeners.size).toBe(1) + }) + + // Clean up the second channel + secondChannel.cleanup(mockSocket) + + // All listeners should be removed + expect(eventListeners.size).toBe(0) + }) + + it("should handle cleanup even if called multiple times", async () => { + // First cleanup + await extensionChannel.cleanup(mockSocket) + const firstOffCallCount = (mockProvider.off as any).mock.calls.length + + // Second cleanup (should be safe to call again) + await extensionChannel.cleanup(mockSocket) + const secondOffCallCount = (mockProvider.off as any).mock.calls.length + + // The second cleanup shouldn't try to remove listeners again + // since the internal Map was cleared + expect(secondOffCallCount).toBe(firstOffCallCount) + }) + + it("should properly forward events to socket when listeners are triggered", async () => { + // Connect the socket to enable publishing + await extensionChannel.onConnect(mockSocket) + + // Get a listener that was registered for TaskStarted + const taskStartedListeners = eventListeners.get(RooCodeEventName.TaskStarted) + expect(taskStartedListeners).toBeDefined() + expect(taskStartedListeners!.size).toBe(1) + + // Trigger the listener + const listener = Array.from(taskStartedListeners!)[0] + if (listener) { + listener("test-task-id") + } + + // Verify the event was published to the socket + expect(mockSocket.emit).toHaveBeenCalledWith( + ExtensionSocketEvents.EVENT, + expect.objectContaining({ + type: ExtensionBridgeEventName.TaskStarted, + instance: expect.objectContaining({ + instanceId, + userId, + }), + timestamp: expect.any(Number), + }), + undefined, + ) + }) + }) + + describe("Memory Leak Prevention", () => { + it("should not accumulate listeners over multiple connect/disconnect cycles", async () => { + // Simulate multiple connect/disconnect cycles + for (let i = 0; i < 5; i++) { + await extensionChannel.onConnect(mockSocket) + extensionChannel.onDisconnect() + } + + // Listeners should still be the same count (not accumulated) + const expectedEventCount = 10 // Number of events we listen to + expect(eventListeners.size).toBe(expectedEventCount) + + // Each event should have exactly 1 listener + eventListeners.forEach((listeners) => { + expect(listeners.size).toBe(1) + }) + }) + + it("should properly clean up heartbeat interval", async () => { + // Spy on setInterval and clearInterval + const setIntervalSpy = vi.spyOn(global, "setInterval") + const clearIntervalSpy = vi.spyOn(global, "clearInterval") + + // Connect to start heartbeat + await extensionChannel.onConnect(mockSocket) + expect(setIntervalSpy).toHaveBeenCalled() + + // Get the interval ID + const intervalId = setIntervalSpy.mock.results[0]?.value + + // Cleanup should stop the heartbeat + await extensionChannel.cleanup(mockSocket) + expect(clearIntervalSpy).toHaveBeenCalledWith(intervalId) + + setIntervalSpy.mockRestore() + clearIntervalSpy.mockRestore() + }) + }) +}) diff --git a/packages/cloud/src/bridge/__tests__/TaskChannel.test.ts b/packages/cloud/src/bridge/__tests__/TaskChannel.test.ts new file mode 100644 index 000000000000..2809ca78f8c9 --- /dev/null +++ b/packages/cloud/src/bridge/__tests__/TaskChannel.test.ts @@ -0,0 +1,389 @@ +/* eslint-disable @typescript-eslint/no-unsafe-function-type */ +/* eslint-disable @typescript-eslint/no-explicit-any */ + +import type { Socket } from "socket.io-client" + +import { + type TaskLike, + type ClineMessage, + RooCodeEventName, + TaskBridgeEventName, + TaskBridgeCommandName, + TaskSocketEvents, + TaskStatus, +} from "@roo-code/types" + +import { TaskChannel } from "../TaskChannel.js" + +describe("TaskChannel", () => { + let mockSocket: Socket + let taskChannel: TaskChannel + let mockTask: TaskLike + const instanceId = "test-instance-123" + const taskId = "test-task-456" + + beforeEach(() => { + // Create mock socket + mockSocket = { + emit: vi.fn(), + on: vi.fn(), + off: vi.fn(), + disconnect: vi.fn(), + } as unknown as Socket + + // Create mock task with event emitter functionality + const listeners = new Map unknown>>() + mockTask = { + taskId, + taskStatus: TaskStatus.Running, + taskAsk: undefined, + metadata: {}, + on: vi.fn((event: string, listener: (...args: unknown[]) => unknown) => { + if (!listeners.has(event)) { + listeners.set(event, new Set()) + } + listeners.get(event)!.add(listener) + return mockTask + }), + off: vi.fn((event: string, listener: (...args: unknown[]) => unknown) => { + const eventListeners = listeners.get(event) + if (eventListeners) { + eventListeners.delete(listener) + if (eventListeners.size === 0) { + listeners.delete(event) + } + } + return mockTask + }), + approveAsk: vi.fn(), + denyAsk: vi.fn(), + submitUserMessage: vi.fn(), + abortTask: vi.fn(), + // Helper to trigger events in tests + _triggerEvent: (event: string, ...args: any[]) => { + const eventListeners = listeners.get(event) + if (eventListeners) { + eventListeners.forEach((listener) => listener(...args)) + } + }, + _getListenerCount: (event: string) => { + return listeners.get(event)?.size || 0 + }, + } as unknown as TaskLike & { + _triggerEvent: (event: string, ...args: any[]) => void + _getListenerCount: (event: string) => number + } + + // Create task channel instance + taskChannel = new TaskChannel(instanceId) + }) + + afterEach(() => { + vi.clearAllMocks() + }) + + describe("Event Mapping Refactoring", () => { + it("should use the unified event mapping approach", () => { + // Access the private eventMapping through type assertion + const channel = taskChannel as any + + // Verify eventMapping exists and has the correct structure + expect(channel.eventMapping).toBeDefined() + expect(Array.isArray(channel.eventMapping)).toBe(true) + expect(channel.eventMapping.length).toBe(3) + + // Verify each mapping has the required properties + channel.eventMapping.forEach((mapping: any) => { + expect(mapping).toHaveProperty("from") + expect(mapping).toHaveProperty("to") + expect(mapping).toHaveProperty("createPayload") + expect(typeof mapping.createPayload).toBe("function") + }) + + // Verify specific mappings + expect(channel.eventMapping[0].from).toBe(RooCodeEventName.Message) + expect(channel.eventMapping[0].to).toBe(TaskBridgeEventName.Message) + + expect(channel.eventMapping[1].from).toBe(RooCodeEventName.TaskModeSwitched) + expect(channel.eventMapping[1].to).toBe(TaskBridgeEventName.TaskModeSwitched) + + expect(channel.eventMapping[2].from).toBe(RooCodeEventName.TaskInteractive) + expect(channel.eventMapping[2].to).toBe(TaskBridgeEventName.TaskInteractive) + }) + + it("should setup listeners using the event mapping", async () => { + // Mock the publish method to simulate successful subscription + const channel = taskChannel as any + channel.publish = vi.fn((event: string, data: any, callback?: Function) => { + if (event === TaskSocketEvents.JOIN && callback) { + // Simulate successful join response + callback({ success: true }) + } + return true + }) + + // Connect and subscribe to task + await taskChannel.onConnect(mockSocket) + await channel.subscribeToTask(mockTask, mockSocket) + + // Wait for async operations + await new Promise((resolve) => setTimeout(resolve, 0)) + + // Verify listeners were registered for all mapped events + const task = mockTask as any + expect(task._getListenerCount(RooCodeEventName.Message)).toBe(1) + expect(task._getListenerCount(RooCodeEventName.TaskModeSwitched)).toBe(1) + expect(task._getListenerCount(RooCodeEventName.TaskInteractive)).toBe(1) + }) + + it("should correctly transform Message event payloads", async () => { + // Setup channel with task + const channel = taskChannel as any + let publishCalls: any[] = [] + + channel.publish = vi.fn((event: string, data: any, callback?: Function) => { + publishCalls.push({ event, data }) + + if (event === TaskSocketEvents.JOIN && callback) { + callback({ success: true }) + } + + return true + }) + + await taskChannel.onConnect(mockSocket) + await channel.subscribeToTask(mockTask, mockSocket) + await new Promise((resolve) => setTimeout(resolve, 0)) + + // Clear previous calls + publishCalls = [] + + // Trigger Message event + const messageData = { + action: "test-action", + message: { type: "say", text: "Hello" } as ClineMessage, + } + + ;(mockTask as any)._triggerEvent(RooCodeEventName.Message, messageData) + + // Verify the event was published with correct payload + expect(publishCalls.length).toBe(1) + expect(publishCalls[0]).toEqual({ + event: TaskSocketEvents.EVENT, + data: { + type: TaskBridgeEventName.Message, + taskId: taskId, + action: messageData.action, + message: messageData.message, + }, + }) + }) + + it("should correctly transform TaskModeSwitched event payloads", async () => { + // Setup channel with task + const channel = taskChannel as any + let publishCalls: any[] = [] + + channel.publish = vi.fn((event: string, data: any, callback?: Function) => { + publishCalls.push({ event, data }) + + if (event === TaskSocketEvents.JOIN && callback) { + callback({ success: true }) + } + + return true + }) + + await taskChannel.onConnect(mockSocket) + await channel.subscribeToTask(mockTask, mockSocket) + await new Promise((resolve) => setTimeout(resolve, 0)) + + // Clear previous calls + publishCalls = [] + + // Trigger TaskModeSwitched event + const mode = "architect" + ;(mockTask as any)._triggerEvent(RooCodeEventName.TaskModeSwitched, mode) + + // Verify the event was published with correct payload + expect(publishCalls.length).toBe(1) + expect(publishCalls[0]).toEqual({ + event: TaskSocketEvents.EVENT, + data: { + type: TaskBridgeEventName.TaskModeSwitched, + taskId: taskId, + mode: mode, + }, + }) + }) + + it("should correctly transform TaskInteractive event payloads", async () => { + // Setup channel with task + const channel = taskChannel as any + let publishCalls: any[] = [] + + channel.publish = vi.fn((event: string, data: any, callback?: Function) => { + publishCalls.push({ event, data }) + if (event === TaskSocketEvents.JOIN && callback) { + callback({ success: true }) + } + return true + }) + + await taskChannel.onConnect(mockSocket) + await channel.subscribeToTask(mockTask, mockSocket) + await new Promise((resolve) => setTimeout(resolve, 0)) + + // Clear previous calls + publishCalls = [] + + // Trigger TaskInteractive event + ;(mockTask as any)._triggerEvent(RooCodeEventName.TaskInteractive, taskId) + + // Verify the event was published with correct payload + expect(publishCalls.length).toBe(1) + expect(publishCalls[0]).toEqual({ + event: TaskSocketEvents.EVENT, + data: { + type: TaskBridgeEventName.TaskInteractive, + taskId: taskId, + }, + }) + }) + + it("should properly clean up listeners using event mapping", async () => { + // Setup channel with task + const channel = taskChannel as any + + channel.publish = vi.fn((event: string, data: any, callback?: Function) => { + if (event === TaskSocketEvents.JOIN && callback) { + callback({ success: true }) + } + if (event === TaskSocketEvents.LEAVE && callback) { + callback({ success: true }) + } + return true + }) + + await taskChannel.onConnect(mockSocket) + await channel.subscribeToTask(mockTask, mockSocket) + await new Promise((resolve) => setTimeout(resolve, 0)) + + // Verify listeners are registered + const task = mockTask as any + expect(task._getListenerCount(RooCodeEventName.Message)).toBe(1) + expect(task._getListenerCount(RooCodeEventName.TaskModeSwitched)).toBe(1) + expect(task._getListenerCount(RooCodeEventName.TaskInteractive)).toBe(1) + + // Clean up + await taskChannel.cleanup(mockSocket) + + // Verify all listeners were removed + expect(task._getListenerCount(RooCodeEventName.Message)).toBe(0) + expect(task._getListenerCount(RooCodeEventName.TaskModeSwitched)).toBe(0) + expect(task._getListenerCount(RooCodeEventName.TaskInteractive)).toBe(0) + }) + + it("should handle duplicate listener prevention", async () => { + // Setup channel with task + await taskChannel.onConnect(mockSocket) + + // Subscribe to the same task twice + const channel = taskChannel as any + channel.subscribedTasks.set(taskId, mockTask) + channel.setupTaskListeners(mockTask) + + // Try to setup listeners again (should remove old ones first) + const warnSpy = vi.spyOn(console, "warn") + channel.setupTaskListeners(mockTask) + + // Verify warning was logged + expect(warnSpy).toHaveBeenCalledWith( + "[TaskChannel] Listeners already exist for task, removing old listeners:", + taskId, + ) + + // Verify only one set of listeners exists + const task = mockTask as any + expect(task._getListenerCount(RooCodeEventName.Message)).toBe(1) + expect(task._getListenerCount(RooCodeEventName.TaskModeSwitched)).toBe(1) + expect(task._getListenerCount(RooCodeEventName.TaskInteractive)).toBe(1) + + warnSpy.mockRestore() + }) + }) + + describe("Command Handling", () => { + beforeEach(async () => { + // Setup channel with a subscribed task + await taskChannel.onConnect(mockSocket) + const channel = taskChannel as any + channel.subscribedTasks.set(taskId, mockTask) + }) + + it("should handle Message command", () => { + const command = { + type: TaskBridgeCommandName.Message, + taskId, + timestamp: Date.now(), + payload: { + text: "Hello, world!", + images: ["image1.png"], + }, + } + + taskChannel.handleCommand(command) + + expect(mockTask.submitUserMessage).toHaveBeenCalledWith(command.payload.text, command.payload.images) + }) + + it("should handle ApproveAsk command", () => { + const command = { + type: TaskBridgeCommandName.ApproveAsk, + taskId, + timestamp: Date.now(), + payload: { + text: "Approved", + }, + } + + taskChannel.handleCommand(command) + + expect(mockTask.approveAsk).toHaveBeenCalledWith(command.payload) + }) + + it("should handle DenyAsk command", () => { + const command = { + type: TaskBridgeCommandName.DenyAsk, + taskId, + timestamp: Date.now(), + payload: { + text: "Denied", + }, + } + + taskChannel.handleCommand(command) + + expect(mockTask.denyAsk).toHaveBeenCalledWith(command.payload) + }) + + it("should log error for unknown task", () => { + const errorSpy = vi.spyOn(console, "error") + + const command = { + type: TaskBridgeCommandName.Message, + taskId: "unknown-task", + timestamp: Date.now(), + payload: { + text: "Hello", + }, + } + + taskChannel.handleCommand(command) + + expect(errorSpy).toHaveBeenCalledWith(`[TaskChannel] Unable to find task unknown-task`) + + errorSpy.mockRestore() + }) + }) +}) diff --git a/packages/cloud/src/bridge/index.ts b/packages/cloud/src/bridge/index.ts new file mode 100644 index 000000000000..94873c09fdfa --- /dev/null +++ b/packages/cloud/src/bridge/index.ts @@ -0,0 +1,6 @@ +export { type BridgeOrchestratorOptions, BridgeOrchestrator } from "./BridgeOrchestrator.js" +export { type SocketTransportOptions, SocketTransport } from "./SocketTransport.js" + +export { BaseChannel } from "./BaseChannel.js" +export { ExtensionChannel } from "./ExtensionChannel.js" +export { TaskChannel } from "./TaskChannel.js" diff --git a/packages/cloud/src/importVscode.ts b/packages/cloud/src/importVscode.ts index b3c3c94150d6..f389555afa17 100644 --- a/packages/cloud/src/importVscode.ts +++ b/packages/cloud/src/importVscode.ts @@ -7,43 +7,38 @@ let vscodeModule: typeof import("vscode") | undefined /** - * Attempts to dynamically import the VS Code module. - * Returns undefined if not running in a VS Code/Cursor extension context. + * Attempts to dynamically import the `vscode` module. + * Returns undefined if not running in a VSCode extension context. */ export async function importVscode(): Promise { - // Check if already loaded if (vscodeModule) { return vscodeModule } try { - // Method 1: Check if vscode is available in global scope (common in extension hosts). - if (typeof globalThis !== "undefined" && "acquireVsCodeApi" in globalThis) { - // We're in a webview context, vscode module won't be available. - return undefined - } - - // Method 2: Try to require the module (works in most extension contexts). if (typeof require !== "undefined") { try { // eslint-disable-next-line @typescript-eslint/no-require-imports vscodeModule = require("vscode") if (vscodeModule) { + console.log("VS Code module loaded from require") return vscodeModule } } catch (error) { - console.error("Error loading VS Code module:", error) + console.error(`Error loading VS Code module: ${error instanceof Error ? error.message : String(error)}`) // Fall through to dynamic import. } } - // Method 3: Dynamic import (original approach, works in VSCode). vscodeModule = await import("vscode") + console.log("VS Code module loaded from dynamic import") return vscodeModule } catch (error) { - // Log the original error for debugging. - console.warn("VS Code module not available in this environment:", error) + console.warn( + `VS Code module not available in this environment: ${error instanceof Error ? error.message : String(error)}`, + ) + return undefined } } diff --git a/packages/cloud/src/index.ts b/packages/cloud/src/index.ts index 6ba2d3e61e6a..dd40e6fc5279 100644 --- a/packages/cloud/src/index.ts +++ b/packages/cloud/src/index.ts @@ -1,5 +1,5 @@ export * from "./config.js" -export * from "./CloudAPI.js" -export * from "./CloudService.js" -export * from "./bridge/ExtensionBridgeService.js" +export { CloudService } from "./CloudService.js" + +export { BridgeOrchestrator } from "./bridge/index.js" diff --git a/packages/types/src/cloud.ts b/packages/types/src/cloud.ts index b80c562fa3f9..dbf79b6bfac9 100644 --- a/packages/types/src/cloud.ts +++ b/packages/types/src/cloud.ts @@ -587,32 +587,49 @@ export type TaskBridgeCommand = z.infer * ExtensionSocketEvents */ -export const ExtensionSocketEvents = { - CONNECTED: "extension:connected", +export enum ExtensionSocketEvents { + CONNECTED = "extension:connected", - REGISTER: "extension:register", - UNREGISTER: "extension:unregister", + REGISTER = "extension:register", + UNREGISTER = "extension:unregister", - HEARTBEAT: "extension:heartbeat", + HEARTBEAT = "extension:heartbeat", - EVENT: "extension:event", // event from extension instance - RELAYED_EVENT: "extension:relayed_event", // relay from server + EVENT = "extension:event", // event from extension instance + RELAYED_EVENT = "extension:relayed_event", // relay from server - COMMAND: "extension:command", // command from user - RELAYED_COMMAND: "extension:relayed_command", // relay from server -} as const + COMMAND = "extension:command", // command from user + RELAYED_COMMAND = "extension:relayed_command", // relay from server +} /** * TaskSocketEvents */ -export const TaskSocketEvents = { - JOIN: "task:join", - LEAVE: "task:leave", +export enum TaskSocketEvents { + JOIN = "task:join", + LEAVE = "task:leave", - EVENT: "task:event", // event from extension task - RELAYED_EVENT: "task:relayed_event", // relay from server + EVENT = "task:event", // event from extension task + RELAYED_EVENT = "task:relayed_event", // relay from server - COMMAND: "task:command", // command from user - RELAYED_COMMAND: "task:relayed_command", // relay from server -} as const + COMMAND = "task:command", // command from user + RELAYED_COMMAND = "task:relayed_command", // relay from server +} + +/** + * `emit()` Response Types + */ + +export type JoinResponse = { + success: boolean + error?: string + taskId?: string + timestamp?: string +} + +export type LeaveResponse = { + success: boolean + taskId?: string + timestamp?: string +} diff --git a/src/core/task/Task.ts b/src/core/task/Task.ts index 60fceb2bb80d..b192440779ab 100644 --- a/src/core/task/Task.ts +++ b/src/core/task/Task.ts @@ -35,7 +35,7 @@ import { isResumableAsk, } from "@roo-code/types" import { TelemetryService } from "@roo-code/telemetry" -import { CloudService, ExtensionBridgeService } from "@roo-code/cloud" +import { CloudService, BridgeOrchestrator } from "@roo-code/cloud" // api import { ApiHandler, ApiHandlerCreateMessageMetadata, buildApiHandler } from "../../api" @@ -115,7 +115,7 @@ export type TaskOptions = { apiConfiguration: ProviderSettings enableDiff?: boolean enableCheckpoints?: boolean - enableTaskBridge?: boolean + enableBridge?: boolean fuzzyMatchThreshold?: number consecutiveMistakeLimit?: number task?: string @@ -255,8 +255,8 @@ export class Task extends EventEmitter implements TaskLike { checkpointServiceInitializing = false // Task Bridge - enableTaskBridge: boolean - bridgeService: ExtensionBridgeService | null = null + enableBridge: boolean + bridge: BridgeOrchestrator | null = null // Streaming isWaitingForFirstChunk = false @@ -280,7 +280,7 @@ export class Task extends EventEmitter implements TaskLike { apiConfiguration, enableDiff = false, enableCheckpoints = true, - enableTaskBridge = false, + enableBridge = false, fuzzyMatchThreshold = 1.0, consecutiveMistakeLimit = DEFAULT_CONSECUTIVE_MISTAKE_LIMIT, task, @@ -335,7 +335,7 @@ export class Task extends EventEmitter implements TaskLike { this.globalStoragePath = provider.context.globalStorageUri.fsPath this.diffViewProvider = new DiffViewProvider(this.cwd, this) this.enableCheckpoints = enableCheckpoints - this.enableTaskBridge = enableTaskBridge + this.enableBridge = enableBridge this.rootTask = rootTask this.parentTask = parentTask @@ -1082,12 +1082,12 @@ export class Task extends EventEmitter implements TaskLike { // Start / Abort / Resume private async startTask(task?: string, images?: string[]): Promise { - if (this.enableTaskBridge) { + if (this.enableBridge) { try { - this.bridgeService = this.bridgeService || ExtensionBridgeService.getInstance() + this.bridge = this.bridge || BridgeOrchestrator.getInstance() - if (this.bridgeService) { - await this.bridgeService.subscribeToTask(this) + if (this.bridge) { + await this.bridge.subscribeToTask(this) } } catch (error) { console.error( @@ -1154,14 +1154,12 @@ export class Task extends EventEmitter implements TaskLike { } private async resumeTaskFromHistory() { - // Resuming task from history - - if (this.enableTaskBridge) { + if (this.enableBridge) { try { - this.bridgeService = this.bridgeService || ExtensionBridgeService.getInstance() + this.bridge = this.bridge || BridgeOrchestrator.getInstance() - if (this.bridgeService) { - await this.bridgeService.subscribeToTask(this) + if (this.bridge) { + await this.bridge.subscribeToTask(this) } } catch (error) { console.error( @@ -1436,11 +1434,12 @@ export class Task extends EventEmitter implements TaskLike { } // Unsubscribe from TaskBridge service. - if (this.bridgeService) { - this.bridgeService + if (this.bridge) { + this.bridge .unsubscribeFromTask(this.taskId) .catch((error: unknown) => console.error("Error unsubscribing from task bridge:", error)) - this.bridgeService = null + + this.bridge = null } // Release any terminals associated with this task. diff --git a/src/core/webview/ClineProvider.ts b/src/core/webview/ClineProvider.ts index 0c26858b613d..c37288f20d2f 100644 --- a/src/core/webview/ClineProvider.ts +++ b/src/core/webview/ClineProvider.ts @@ -39,7 +39,7 @@ import { ORGANIZATION_ALLOW_ALL, } from "@roo-code/types" import { TelemetryService } from "@roo-code/telemetry" -import { CloudService, getRooCodeApiUrl } from "@roo-code/cloud" +import { CloudService, BridgeOrchestrator, getRooCodeApiUrl } from "@roo-code/cloud" import { Package } from "../../shared/package" import { findLast } from "../../shared/array" @@ -70,7 +70,6 @@ import { fileExistsAtPath } from "../../utils/fs" import { setTtsEnabled, setTtsSpeed } from "../../utils/tts" import { getWorkspaceGitInfo } from "../../utils/git" import { getWorkspacePath } from "../../utils/path" -import { isRemoteControlEnabled } from "../../utils/remoteControl" import { setPanel } from "../../activate/registerCommands" @@ -134,7 +133,6 @@ export class ClineProvider ) { super() - this.log("ClineProvider instantiated") ClineProvider.activeInstances.add(this) this.mdmService = mdmService @@ -300,11 +298,11 @@ export class ClineProvider // Adds a new Task instance to clineStack, marking the start of a new task. // The instance is pushed to the top of the stack (LIFO order). - // When the task is completed, the top instance is removed, reactivating the previous task. + // When the task is completed, the top instance is removed, reactivating the + // previous task. async addClineToStack(task: Task) { - console.log(`[subtasks] adding task ${task.taskId}.${task.instanceId} to stack`) - - // Add this cline instance into the stack that represents the order of all the called tasks. + // Add this cline instance into the stack that represents the order of + // all the called tasks. this.clineStack.push(task) task.emit(RooCodeEventName.TaskFocused) @@ -348,15 +346,13 @@ export class ClineProvider let task = this.clineStack.pop() if (task) { - console.log(`[subtasks] removing task ${task.taskId}.${task.instanceId} from stack`) - try { // Abort the running task and set isAbandoned to true so // all running promises will exit as well. await task.abortTask(true) } catch (e) { this.log( - `[subtasks] encountered error while aborting task ${task.taskId}.${task.instanceId}: ${e.message}`, + `[removeClineFromStack] encountered error while aborting task ${task.taskId}.${task.instanceId}: ${e.message}`, ) } @@ -382,6 +378,7 @@ export class ClineProvider if (this.clineStack.length === 0) { return undefined } + return this.clineStack[this.clineStack.length - 1] } @@ -394,19 +391,22 @@ export class ClineProvider return this.clineStack.map((cline) => cline.taskId) } - // remove the current task/cline instance (at the top of the stack), so this task is finished - // and resume the previous task/cline instance (if it exists) - // this is used when a sub task is finished and the parent task needs to be resumed + // Remove the current task/cline instance (at the top of the stack), so this + // task is finished and resume the previous task/cline instance (if it + // exists). + // This is used when a subtask is finished and the parent task needs to be + // resumed. async finishSubTask(lastMessage: string) { - console.log(`[subtasks] finishing subtask ${lastMessage}`) - // remove the last cline instance from the stack (this is the finished sub task) + // Remove the last cline instance from the stack (this is the finished + // subtask). await this.removeClineFromStack() - // resume the last cline instance in the stack (if it exists - this is the 'parent' calling task) + // Resume the last cline instance in the stack (if it exists - this is + // the 'parent' calling task). await this.getCurrentTask()?.resumePausedTask(lastMessage) } - // Clear the current task without treating it as a subtask - // This is used when the user cancels a task that is not a subtask + // Clear the current task without treating it as a subtask. + // This is used when the user cancels a task that is not a subtask. async clearTask() { await this.removeClineFromStack() } @@ -621,8 +621,6 @@ export class ClineProvider } async resolveWebviewView(webviewView: vscode.WebviewView | vscode.WebviewPanel) { - this.log("Resolving webview view") - this.view = webviewView const inTabMode = "onDidChangeViewState" in webviewView @@ -741,8 +739,6 @@ export class ClineProvider // If the extension is starting a new session, clear previous task state. await this.removeClineFromStack() - - this.log("Webview view resolved") } // When initializing a new task, (not from history but from a tool command @@ -796,7 +792,7 @@ export class ClineProvider parentTask, taskNumber: this.clineStack.length + 1, onCreated: this.taskCreationCallback, - enableTaskBridge: isRemoteControlEnabled(cloudUserInfo, remoteControlEnabled), + enableBridge: BridgeOrchestrator.isEnabled(cloudUserInfo, remoteControlEnabled), initialTodos: options.initialTodos, ...options, }) @@ -804,7 +800,7 @@ export class ClineProvider await this.addClineToStack(task) this.log( - `[subtasks] ${task.parentTask ? "child" : "parent"} task ${task.taskId}.${task.instanceId} instantiated`, + `[createTask] ${task.parentTask ? "child" : "parent"} task ${task.taskId}.${task.instanceId} instantiated`, ) return task @@ -866,9 +862,6 @@ export class ClineProvider remoteControlEnabled, } = await this.getState() - // Determine if TaskBridge should be enabled - const enableTaskBridge = isRemoteControlEnabled(cloudUserInfo, remoteControlEnabled) - const task = new Task({ provider: this, apiConfiguration, @@ -882,13 +875,13 @@ export class ClineProvider parentTask: historyItem.parentTask, taskNumber: historyItem.number, onCreated: this.taskCreationCallback, - enableTaskBridge, + enableBridge: BridgeOrchestrator.isEnabled(cloudUserInfo, remoteControlEnabled), }) await this.addClineToStack(task) this.log( - `[subtasks] ${task.parentTask ? "child" : "parent"} task ${task.taskId}.${task.instanceId} instantiated`, + `[createTaskWithHistoryItem] ${task.parentTask ? "child" : "parent"} task ${task.taskId}.${task.instanceId} instantiated`, ) return task @@ -1278,7 +1271,7 @@ export class ClineProvider return } - console.log(`[subtasks] cancelling task ${cline.taskId}.${cline.instanceId}`) + console.log(`[cancelTask] cancelling task ${cline.taskId}.${cline.instanceId}`) const { historyItem } = await this.getTaskWithId(cline.taskId) // Preserve parent and root task information for history item. @@ -2199,56 +2192,50 @@ export class ClineProvider return true } - public async handleRemoteControlToggle(enabled: boolean) { - const { CloudService: CloudServiceImport, ExtensionBridgeService } = await import("@roo-code/cloud") - - const userInfo = CloudServiceImport.instance.getUserInfo() + public async remoteControlEnabled(enabled: boolean) { + const userInfo = CloudService.instance.getUserInfo() - const bridgeConfig = await CloudServiceImport.instance.cloudAPI?.bridgeConfig().catch(() => undefined) + const config = await CloudService.instance.cloudAPI?.bridgeConfig().catch(() => undefined) - if (!bridgeConfig) { - this.log("[ClineProvider#handleRemoteControlToggle] Failed to get bridge config") + if (!config) { + this.log("[ClineProvider#remoteControlEnabled] Failed to get bridge config") return } - await ExtensionBridgeService.handleRemoteControlState( - userInfo, - enabled, - { ...bridgeConfig, provider: this, sessionId: vscode.env.sessionId }, - (message: string) => this.log(message), - ) + await BridgeOrchestrator.connectOrDisconnect(userInfo, enabled, { + ...config, + provider: this, + sessionId: vscode.env.sessionId, + }) + + const bridge = BridgeOrchestrator.getInstance() - if (isRemoteControlEnabled(userInfo, enabled)) { + if (bridge) { const currentTask = this.getCurrentTask() - if (currentTask && !currentTask.bridgeService) { + if (currentTask && !currentTask.bridge) { try { - currentTask.bridgeService = ExtensionBridgeService.getInstance() - - if (currentTask.bridgeService) { - await currentTask.bridgeService.subscribeToTask(currentTask) - } + currentTask.bridge = bridge + await currentTask.bridge.subscribeToTask(currentTask) } catch (error) { - const message = `[ClineProvider#handleRemoteControlToggle] subscribeToTask failed - ${error instanceof Error ? error.message : String(error)}` + const message = `[ClineProvider#remoteControlEnabled] subscribeToTask failed - ${error instanceof Error ? error.message : String(error)}` this.log(message) console.error(message) } } } else { for (const task of this.clineStack) { - if (task.bridgeService) { + if (task.bridge) { try { - await task.bridgeService.unsubscribeFromTask(task.taskId) - task.bridgeService = null + await task.bridge.unsubscribeFromTask(task.taskId) + task.bridge = null } catch (error) { - const message = `[ClineProvider#handleRemoteControlToggle] unsubscribeFromTask failed - ${error instanceof Error ? error.message : String(error)}` + const message = `[ClineProvider#remoteControlEnabled] unsubscribeFromTask failed - ${error instanceof Error ? error.message : String(error)}` this.log(message) console.error(message) } } } - - ExtensionBridgeService.resetInstance() } } diff --git a/src/core/webview/__tests__/ClineProvider.spec.ts b/src/core/webview/__tests__/ClineProvider.spec.ts index e5e7e85da5cf..a09ca1a85105 100644 --- a/src/core/webview/__tests__/ClineProvider.spec.ts +++ b/src/core/webview/__tests__/ClineProvider.spec.ts @@ -320,11 +320,10 @@ vi.mock("@roo-code/cloud", () => ({ } }, }, - getRooCodeApiUrl: vi.fn().mockReturnValue("https://app.roocode.com"), - ORGANIZATION_ALLOW_ALL: { - allowAll: true, - providers: {}, + BridgeOrchestrator: { + isEnabled: vi.fn().mockReturnValue(false), }, + getRooCodeApiUrl: vi.fn().mockReturnValue("https://app.roocode.com"), })) afterAll(() => { diff --git a/src/core/webview/__tests__/ClineProvider.sticky-mode.spec.ts b/src/core/webview/__tests__/ClineProvider.sticky-mode.spec.ts index 0ae7b38b8167..29aefcaeba43 100644 --- a/src/core/webview/__tests__/ClineProvider.sticky-mode.spec.ts +++ b/src/core/webview/__tests__/ClineProvider.sticky-mode.spec.ts @@ -113,11 +113,10 @@ vi.mock("@roo-code/cloud", () => ({ } }, }, - getRooCodeApiUrl: vi.fn().mockReturnValue("https://app.roocode.com"), - ORGANIZATION_ALLOW_ALL: { - allowAll: true, - providers: {}, + BridgeOrchestrator: { + isEnabled: vi.fn().mockReturnValue(false), }, + getRooCodeApiUrl: vi.fn().mockReturnValue("https://app.roocode.com"), })) vi.mock("../../../shared/modes", () => ({ diff --git a/src/core/webview/webviewMessageHandler.ts b/src/core/webview/webviewMessageHandler.ts index 6bf1320ccf05..970c9d0712c9 100644 --- a/src/core/webview/webviewMessageHandler.ts +++ b/src/core/webview/webviewMessageHandler.ts @@ -937,8 +937,8 @@ export const webviewMessageHandler = async ( const mcpEnabled = message.bool ?? true await updateGlobalState("mcpEnabled", mcpEnabled) - // Delegate MCP enable/disable logic to McpHub const mcpHubInstance = provider.getMcpHub() + if (mcpHubInstance) { await mcpHubInstance.handleMcpEnabledChange(mcpEnabled) } @@ -951,17 +951,18 @@ export const webviewMessageHandler = async ( break case "remoteControlEnabled": await updateGlobalState("remoteControlEnabled", message.bool ?? false) - await provider.handleRemoteControlToggle(message.bool ?? false) + await provider.remoteControlEnabled(message.bool ?? false) await provider.postStateToWebview() break case "refreshAllMcpServers": { const mcpHub = provider.getMcpHub() + if (mcpHub) { await mcpHub.refreshAllConnections() } + break } - // playSound handler removed - now handled directly in the webview case "soundEnabled": const soundEnabled = message.bool ?? true await updateGlobalState("soundEnabled", soundEnabled) @@ -975,7 +976,7 @@ export const webviewMessageHandler = async ( case "ttsEnabled": const ttsEnabled = message.bool ?? true await updateGlobalState("ttsEnabled", ttsEnabled) - setTtsEnabled(ttsEnabled) // Add this line to update the tts utility + setTtsEnabled(ttsEnabled) await provider.postStateToWebview() break case "ttsSpeed": @@ -991,6 +992,7 @@ export const webviewMessageHandler = async ( onStop: () => provider.postMessageToWebview({ type: "ttsStop", text: message.text }), }) } + break case "stopTts": stopTts() diff --git a/src/extension.ts b/src/extension.ts index 6060bb341f59..4aa5ff11133a 100644 --- a/src/extension.ts +++ b/src/extension.ts @@ -13,7 +13,7 @@ try { } import type { CloudUserInfo } from "@roo-code/types" -import { CloudService, ExtensionBridgeService } from "@roo-code/cloud" +import { CloudService, BridgeOrchestrator } from "@roo-code/cloud" import { TelemetryService, PostHogTelemetryClient } from "@roo-code/telemetry" import "./utils/path" // Necessary to have access to String.prototype.toPosix. @@ -30,7 +30,6 @@ import { CodeIndexManager } from "./services/code-index/manager" import { MdmService } from "./services/mdm/MdmService" import { migrateSettings } from "./utils/migrateSettings" import { autoImportSettings } from "./utils/autoImportSettings" -import { isRemoteControlEnabled } from "./utils/remoteControl" import { API } from "./extension/api" import { @@ -147,15 +146,10 @@ export async function activate(context: vscode.ExtensionContext) { cloudLogger(`[CloudService] isCloudAgent = ${isCloudAgent}, socketBridgeUrl = ${config.socketBridgeUrl}`) - ExtensionBridgeService.handleRemoteControlState( + await BridgeOrchestrator.connectOrDisconnect( userInfo, isCloudAgent ? true : contextProxy.getValue("remoteControlEnabled"), - { - ...config, - provider, - sessionId: vscode.env.sessionId, - }, - cloudLogger, + { ...config, provider, sessionId: vscode.env.sessionId }, ) } catch (error) { cloudLogger( @@ -333,10 +327,10 @@ export async function deactivate() { } } - const bridgeService = ExtensionBridgeService.getInstance() + const bridge = BridgeOrchestrator.getInstance() - if (bridgeService) { - await bridgeService.disconnect() + if (bridge) { + await bridge.disconnect() } await McpServerManager.cleanup(extensionContext) diff --git a/src/services/mcp/McpHub.ts b/src/services/mcp/McpHub.ts index 271c6e1fb3fd..6ec5b839e8cf 100644 --- a/src/services/mcp/McpHub.ts +++ b/src/services/mcp/McpHub.ts @@ -166,7 +166,7 @@ export class McpHub { */ public registerClient(): void { this.refCount++ - console.log(`McpHub: Client registered. Ref count: ${this.refCount}`) + // console.log(`McpHub: Client registered. Ref count: ${this.refCount}`) } /** @@ -175,7 +175,9 @@ export class McpHub { */ public async unregisterClient(): Promise { this.refCount-- - console.log(`McpHub: Client unregistered. Ref count: ${this.refCount}`) + + // console.log(`McpHub: Client unregistered. Ref count: ${this.refCount}`) + if (this.refCount <= 0) { console.log("McpHub: Last client unregistered. Disposing hub.") await this.dispose() diff --git a/src/services/mdm/MdmService.ts b/src/services/mdm/MdmService.ts index 91c407514f27..63bdbe29fcad 100644 --- a/src/services/mdm/MdmService.ts +++ b/src/services/mdm/MdmService.ts @@ -32,14 +32,13 @@ export class MdmService { public async initialize(): Promise { try { this.mdmConfig = await this.loadMdmConfig() + if (this.mdmConfig) { - this.log("[MDM] Loaded MDM configuration:", this.mdmConfig) - } else { - this.log("[MDM] No MDM configuration found") + this.log(`[MDM] Loaded MDM configuration: ${JSON.stringify(this.mdmConfig)}`) } } catch (error) { - this.log("[MDM] Error loading MDM configuration:", error) - // Don't throw - extension should work without MDM config + this.log(`[MDM] Error loading MDM configuration: ${error instanceof Error ? error.message : String(error)}`) + // Don't throw - extension should work without MDM config. } } diff --git a/src/utils/remoteControl.ts b/src/utils/remoteControl.ts deleted file mode 100644 index f003b522d1d6..000000000000 --- a/src/utils/remoteControl.ts +++ /dev/null @@ -1,11 +0,0 @@ -import type { CloudUserInfo } from "@roo-code/types" - -/** - * Determines if remote control features should be enabled - * @param cloudUserInfo - User information from cloud service - * @param remoteControlEnabled - User's remote control setting - * @returns true if remote control should be enabled - */ -export function isRemoteControlEnabled(cloudUserInfo?: CloudUserInfo | null, remoteControlEnabled?: boolean): boolean { - return !!(cloudUserInfo?.id && cloudUserInfo.extensionBridgeEnabled && remoteControlEnabled) -}