Skip to content

Commit 20929b0

Browse files
authored
Mode and provider profile selector (#7545)
1 parent a1f9b7d commit 20929b0

File tree

14 files changed

+486
-344
lines changed

14 files changed

+486
-344
lines changed

packages/cloud/src/bridge/BaseChannel.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ export abstract class BaseChannel<TCommand = unknown, TEventName extends string
8383
/**
8484
* Handle incoming commands - must be implemented by subclasses.
8585
*/
86-
public abstract handleCommand(command: TCommand): void
86+
public abstract handleCommand(command: TCommand): Promise<void>
8787

8888
/**
8989
* Handle connection-specific logic.

packages/cloud/src/bridge/ExtensionChannel.ts

Lines changed: 30 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -53,10 +53,7 @@ export class ExtensionChannel extends BaseChannel<
5353
this.setupListeners()
5454
}
5555

56-
/**
57-
* Handle extension-specific commands from the web app
58-
*/
59-
public handleCommand(command: ExtensionBridgeCommand): void {
56+
public async handleCommand(command: ExtensionBridgeCommand): Promise<void> {
6057
if (command.instanceId !== this.instanceId) {
6158
console.log(`[ExtensionChannel] command -> instance id mismatch | ${this.instanceId}`, {
6259
messageInstanceId: command.instanceId,
@@ -69,13 +66,22 @@ export class ExtensionChannel extends BaseChannel<
6966
console.log(`[ExtensionChannel] command -> createTask() | ${command.instanceId}`, {
7067
text: command.payload.text?.substring(0, 100) + "...",
7168
hasImages: !!command.payload.images,
69+
mode: command.payload.mode,
70+
providerProfile: command.payload.providerProfile,
7271
})
7372

74-
this.provider.createTask(command.payload.text, command.payload.images)
73+
this.provider.createTask(
74+
command.payload.text,
75+
command.payload.images,
76+
undefined, // parentTask
77+
undefined, // options
78+
{ mode: command.payload.mode, currentApiConfigName: command.payload.providerProfile },
79+
)
80+
7581
break
7682
}
7783
case ExtensionBridgeCommandName.StopTask: {
78-
const instance = this.updateInstance()
84+
const instance = await this.updateInstance()
7985

8086
if (instance.task.taskStatus === TaskStatus.Running) {
8187
console.log(`[ExtensionChannel] command -> cancelTask() | ${command.instanceId}`)
@@ -86,14 +92,14 @@ export class ExtensionChannel extends BaseChannel<
8692
this.provider.clearTask()
8793
this.provider.postStateToWebview()
8894
}
95+
8996
break
9097
}
9198
case ExtensionBridgeCommandName.ResumeTask: {
9299
console.log(`[ExtensionChannel] command -> resumeTask() | ${command.instanceId}`, {
93100
taskId: command.payload.taskId,
94101
})
95102

96-
// Resume the task from history by taskId
97103
this.provider.resumeTask(command.payload.taskId)
98104
this.provider.postStateToWebview()
99105
break
@@ -122,20 +128,20 @@ export class ExtensionChannel extends BaseChannel<
122128
}
123129

124130
private async registerInstance(_socket: Socket): Promise<void> {
125-
const instance = this.updateInstance()
131+
const instance = await this.updateInstance()
126132
await this.publish(ExtensionSocketEvents.REGISTER, instance)
127133
}
128134

129135
private async unregisterInstance(_socket: Socket): Promise<void> {
130-
const instance = this.updateInstance()
136+
const instance = await this.updateInstance()
131137
await this.publish(ExtensionSocketEvents.UNREGISTER, instance)
132138
}
133139

134140
private startHeartbeat(socket: Socket): void {
135141
this.stopHeartbeat()
136142

137143
this.heartbeatInterval = setInterval(async () => {
138-
const instance = this.updateInstance()
144+
const instance = await this.updateInstance()
139145

140146
try {
141147
socket.emit(ExtensionSocketEvents.HEARTBEAT, instance)
@@ -172,11 +178,11 @@ export class ExtensionChannel extends BaseChannel<
172178
] as const
173179

174180
eventMapping.forEach(({ from, to }) => {
175-
// Create and store the listener function for cleanup/
176-
const listener = (..._args: unknown[]) => {
181+
// Create and store the listener function for cleanup.
182+
const listener = async (..._args: unknown[]) => {
177183
this.publish(ExtensionSocketEvents.EVENT, {
178184
type: to,
179-
instance: this.updateInstance(),
185+
instance: await this.updateInstance(),
180186
timestamp: Date.now(),
181187
})
182188
}
@@ -195,10 +201,16 @@ export class ExtensionChannel extends BaseChannel<
195201
this.eventListeners.clear()
196202
}
197203

198-
private updateInstance(): ExtensionInstance {
204+
private async updateInstance(): Promise<ExtensionInstance> {
199205
const task = this.provider?.getCurrentTask()
200206
const taskHistory = this.provider?.getRecentTasks() ?? []
201207

208+
const mode = await this.provider?.getMode()
209+
const modes = (await this.provider?.getModes()) ?? []
210+
211+
const providerProfile = await this.provider?.getProviderProfile()
212+
const providerProfiles = (await this.provider?.getProviderProfiles()) ?? []
213+
202214
this.extensionInstance = {
203215
...this.extensionInstance,
204216
appProperties: this.extensionInstance.appProperties ?? this.provider.appProperties,
@@ -213,6 +225,10 @@ export class ExtensionChannel extends BaseChannel<
213225
: { taskId: "", taskStatus: TaskStatus.None },
214226
taskAsk: task?.taskAsk,
215227
taskHistory,
228+
mode,
229+
providerProfile,
230+
modes,
231+
providerProfiles,
216232
}
217233

218234
return this.extensionInstance

packages/cloud/src/bridge/TaskChannel.ts

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ export class TaskChannel extends BaseChannel<
7373
super(instanceId)
7474
}
7575

76-
public handleCommand(command: TaskBridgeCommand): void {
76+
public async handleCommand(command: TaskBridgeCommand): Promise<void> {
7777
const task = this.subscribedTasks.get(command.taskId)
7878

7979
if (!task) {
@@ -87,14 +87,22 @@ export class TaskChannel extends BaseChannel<
8787
`[TaskChannel] ${TaskBridgeCommandName.Message} ${command.taskId} -> submitUserMessage()`,
8888
command,
8989
)
90-
task.submitUserMessage(command.payload.text, command.payload.images)
90+
91+
await task.submitUserMessage(
92+
command.payload.text,
93+
command.payload.images,
94+
command.payload.mode,
95+
command.payload.providerProfile,
96+
)
97+
9198
break
9299

93100
case TaskBridgeCommandName.ApproveAsk:
94101
console.log(
95102
`[TaskChannel] ${TaskBridgeCommandName.ApproveAsk} ${command.taskId} -> approveAsk()`,
96103
command,
97104
)
105+
98106
task.approveAsk(command.payload)
99107
break
100108

packages/cloud/src/bridge/__tests__/ExtensionChannel.test.ts

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,13 @@ describe("ExtensionChannel", () => {
5353
postStateToWebview: vi.fn(),
5454
postMessageToWebview: vi.fn(),
5555
getTelemetryProperties: vi.fn(),
56+
getMode: vi.fn().mockResolvedValue("code"),
57+
getModes: vi.fn().mockResolvedValue([
58+
{ slug: "code", name: "Code", description: "Code mode" },
59+
{ slug: "architect", name: "Architect", description: "Architect mode" },
60+
]),
61+
getProviderProfile: vi.fn().mockResolvedValue("default"),
62+
getProviderProfiles: vi.fn().mockResolvedValue([{ name: "default", description: "Default profile" }]),
5663
on: vi.fn((event: keyof TaskProviderEvents, listener: (...args: unknown[]) => unknown) => {
5764
if (!eventListeners.has(event)) {
5865
eventListeners.set(event, new Set())
@@ -184,6 +191,9 @@ describe("ExtensionChannel", () => {
184191
// Connect the socket to enable publishing
185192
await extensionChannel.onConnect(mockSocket)
186193

194+
// Clear the mock calls from the connection (which emits a register event)
195+
;(mockSocket.emit as any).mockClear()
196+
187197
// Get a listener that was registered for TaskStarted
188198
const taskStartedListeners = eventListeners.get(RooCodeEventName.TaskStarted)
189199
expect(taskStartedListeners).toBeDefined()
@@ -192,7 +202,7 @@ describe("ExtensionChannel", () => {
192202
// Trigger the listener
193203
const listener = Array.from(taskStartedListeners!)[0]
194204
if (listener) {
195-
listener("test-task-id")
205+
await listener("test-task-id")
196206
}
197207

198208
// Verify the event was published to the socket

packages/cloud/src/bridge/__tests__/TaskChannel.test.ts

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -333,7 +333,12 @@ describe("TaskChannel", () => {
333333

334334
taskChannel.handleCommand(command)
335335

336-
expect(mockTask.submitUserMessage).toHaveBeenCalledWith(command.payload.text, command.payload.images)
336+
expect(mockTask.submitUserMessage).toHaveBeenCalledWith(
337+
command.payload.text,
338+
command.payload.images,
339+
undefined,
340+
undefined,
341+
)
337342
})
338343

339344
it("should handle ApproveAsk command", () => {

packages/types/npm/package.metadata.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
{
22
"name": "@roo-code/types",
3-
"version": "1.65.0",
3+
"version": "1.66.0",
44
"description": "TypeScript type definitions for Roo Code.",
55
"publishConfig": {
66
"access": "public",

packages/types/src/cloud.ts

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -378,6 +378,10 @@ export const extensionInstanceSchema = z.object({
378378
task: extensionTaskSchema,
379379
taskAsk: clineMessageSchema.optional(),
380380
taskHistory: z.array(z.string()),
381+
mode: z.string().optional(),
382+
modes: z.array(z.object({ slug: z.string(), name: z.string() })).optional(),
383+
providerProfile: z.string().optional(),
384+
providerProfiles: z.array(z.object({ name: z.string(), provider: z.string().optional() })).optional(),
381385
})
382386

383387
export type ExtensionInstance = z.infer<typeof extensionInstanceSchema>
@@ -398,6 +402,9 @@ export enum ExtensionBridgeEventName {
398402
TaskResumable = RooCodeEventName.TaskResumable,
399403
TaskIdle = RooCodeEventName.TaskIdle,
400404

405+
ModeChanged = RooCodeEventName.ModeChanged,
406+
ProviderProfileChanged = RooCodeEventName.ProviderProfileChanged,
407+
401408
InstanceRegistered = "instance_registered",
402409
InstanceUnregistered = "instance_unregistered",
403410
HeartbeatUpdated = "heartbeat_updated",
@@ -469,6 +476,18 @@ export const extensionBridgeEventSchema = z.discriminatedUnion("type", [
469476
instance: extensionInstanceSchema,
470477
timestamp: z.number(),
471478
}),
479+
z.object({
480+
type: z.literal(ExtensionBridgeEventName.ModeChanged),
481+
instance: extensionInstanceSchema,
482+
mode: z.string(),
483+
timestamp: z.number(),
484+
}),
485+
z.object({
486+
type: z.literal(ExtensionBridgeEventName.ProviderProfileChanged),
487+
instance: extensionInstanceSchema,
488+
providerProfile: z.object({ name: z.string(), provider: z.string().optional() }),
489+
timestamp: z.number(),
490+
}),
472491
])
473492

474493
export type ExtensionBridgeEvent = z.infer<typeof extensionBridgeEventSchema>
@@ -490,6 +509,8 @@ export const extensionBridgeCommandSchema = z.discriminatedUnion("type", [
490509
payload: z.object({
491510
text: z.string(),
492511
images: z.array(z.string()).optional(),
512+
mode: z.string().optional(),
513+
providerProfile: z.string().optional(),
493514
}),
494515
timestamp: z.number(),
495516
}),
@@ -502,9 +523,7 @@ export const extensionBridgeCommandSchema = z.discriminatedUnion("type", [
502523
z.object({
503524
type: z.literal(ExtensionBridgeCommandName.ResumeTask),
504525
instanceId: z.string(),
505-
payload: z.object({
506-
taskId: z.string(),
507-
}),
526+
payload: z.object({ taskId: z.string() }),
508527
timestamp: z.number(),
509528
}),
510529
])
@@ -558,6 +577,8 @@ export const taskBridgeCommandSchema = z.discriminatedUnion("type", [
558577
payload: z.object({
559578
text: z.string(),
560579
images: z.array(z.string()).optional(),
580+
mode: z.string().optional(),
581+
providerProfile: z.string().optional(),
561582
}),
562583
timestamp: z.number(),
563584
}),

packages/types/src/events.ts

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,10 @@ export enum RooCodeEventName {
3636
TaskTokenUsageUpdated = "taskTokenUsageUpdated",
3737
TaskToolFailed = "taskToolFailed",
3838

39+
// Configuration Changes
40+
ModeChanged = "modeChanged",
41+
ProviderProfileChanged = "providerProfileChanged",
42+
3943
// Evals
4044
EvalPass = "evalPass",
4145
EvalFail = "evalFail",
@@ -81,6 +85,9 @@ export const rooCodeEventsSchema = z.object({
8185

8286
[RooCodeEventName.TaskToolFailed]: z.tuple([z.string(), toolNamesSchema, z.string()]),
8387
[RooCodeEventName.TaskTokenUsageUpdated]: z.tuple([z.string(), tokenUsageSchema]),
88+
89+
[RooCodeEventName.ModeChanged]: z.tuple([z.string()]),
90+
[RooCodeEventName.ProviderProfileChanged]: z.tuple([z.object({ name: z.string(), provider: z.string() })]),
8491
})
8592

8693
export type RooCodeEvents = z.infer<typeof rooCodeEventsSchema>

packages/types/src/provider-settings.ts

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -414,9 +414,11 @@ export const providerSettingsSchema = z.object({
414414
export type ProviderSettings = z.infer<typeof providerSettingsSchema>
415415

416416
export const providerSettingsWithIdSchema = providerSettingsSchema.extend({ id: z.string().optional() })
417+
417418
export const discriminatedProviderSettingsWithIdSchema = providerSettingsSchemaDiscriminated.and(
418419
z.object({ id: z.string().optional() }),
419420
)
421+
420422
export type ProviderSettingsWithId = z.infer<typeof providerSettingsWithIdSchema>
421423

422424
export const PROVIDER_SETTINGS_KEYS = providerSettingsSchema.keyof().options
@@ -454,7 +456,7 @@ export const getApiProtocol = (provider: ProviderName | undefined, modelId?: str
454456
return "anthropic"
455457
}
456458

457-
// Vercel AI Gateway uses anthropic protocol for anthropic models
459+
// Vercel AI Gateway uses anthropic protocol for anthropic models.
458460
if (provider && provider === "vercel-ai-gateway" && modelId && modelId.toLowerCase().startsWith("anthropic/")) {
459461
return "anthropic"
460462
}

0 commit comments

Comments
 (0)