Skip to content

Commit f1bf436

Browse files
jrmrubens
andauthored
Add a RCC credit balance display (#9386)
* Add a RCC credit balance display * Replace the provider docs with the balance when logged in * PR feedback --------- Co-authored-by: Matt Rubens <[email protected]>
1 parent e618d88 commit f1bf436

File tree

12 files changed

+499
-9
lines changed

12 files changed

+499
-9
lines changed

packages/cloud/src/CloudAPI.ts

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,4 +134,14 @@ export class CloudAPI {
134134
.parse(data),
135135
})
136136
}
137+
138+
async creditBalance(): Promise<number> {
139+
return this.request("/api/extension/credit-balance", {
140+
method: "GET",
141+
parseResponse: (data) => {
142+
const result = z.object({ balance: z.number() }).parse(data)
143+
return result.balance
144+
},
145+
})
146+
}
137147
}
Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
import { describe, it, expect, vi, beforeEach, type Mock } from "vitest"
2+
import { CloudAPI } from "../CloudAPI.js"
3+
import { AuthenticationError, CloudAPIError } from "../errors.js"
4+
import type { AuthService } from "@roo-code/types"
5+
6+
// Mock the config module
7+
vi.mock("../config.js", () => ({
8+
getRooCodeApiUrl: () => "https://api.test.com",
9+
}))
10+
11+
// Mock the utils module
12+
vi.mock("../utils.js", () => ({
13+
getUserAgent: () => "test-user-agent",
14+
}))
15+
16+
describe("CloudAPI.creditBalance", () => {
17+
let mockAuthService: {
18+
getSessionToken: Mock<() => string | undefined>
19+
}
20+
let cloudAPI: CloudAPI
21+
22+
beforeEach(() => {
23+
mockAuthService = {
24+
getSessionToken: vi.fn(),
25+
}
26+
cloudAPI = new CloudAPI(mockAuthService as unknown as AuthService)
27+
28+
// Reset fetch mock
29+
global.fetch = vi.fn()
30+
})
31+
32+
it("should fetch credit balance successfully", async () => {
33+
const mockBalance = 12.34
34+
mockAuthService.getSessionToken.mockReturnValue("test-session-token")
35+
36+
global.fetch = vi.fn().mockResolvedValue({
37+
ok: true,
38+
json: async () => ({ balance: mockBalance }),
39+
})
40+
41+
const balance = await cloudAPI.creditBalance()
42+
43+
expect(balance).toBe(mockBalance)
44+
expect(global.fetch).toHaveBeenCalledWith(
45+
"https://api.test.com/api/extension/credit-balance",
46+
expect.objectContaining({
47+
method: "GET",
48+
headers: expect.objectContaining({
49+
Authorization: "Bearer test-session-token",
50+
"Content-Type": "application/json",
51+
"User-Agent": "test-user-agent",
52+
}),
53+
}),
54+
)
55+
})
56+
57+
it("should throw AuthenticationError when session token is missing", async () => {
58+
mockAuthService.getSessionToken.mockReturnValue(undefined)
59+
60+
await expect(cloudAPI.creditBalance()).rejects.toThrow(AuthenticationError)
61+
})
62+
63+
it("should handle API errors", async () => {
64+
mockAuthService.getSessionToken.mockReturnValue("test-session-token")
65+
66+
global.fetch = vi.fn().mockResolvedValue({
67+
ok: false,
68+
status: 500,
69+
statusText: "Internal Server Error",
70+
json: async () => ({ error: "Server error" }),
71+
})
72+
73+
await expect(cloudAPI.creditBalance()).rejects.toThrow(CloudAPIError)
74+
})
75+
76+
it("should handle network errors", async () => {
77+
mockAuthService.getSessionToken.mockReturnValue("test-session-token")
78+
79+
global.fetch = vi.fn().mockRejectedValue(new TypeError("fetch failed"))
80+
81+
await expect(cloudAPI.creditBalance()).rejects.toThrow(
82+
"Network error while calling /api/extension/credit-balance",
83+
)
84+
})
85+
86+
it("should handle invalid response format", async () => {
87+
mockAuthService.getSessionToken.mockReturnValue("test-session-token")
88+
89+
global.fetch = vi.fn().mockResolvedValue({
90+
ok: true,
91+
json: async () => ({ invalid: "response" }),
92+
})
93+
94+
await expect(cloudAPI.creditBalance()).rejects.toThrow()
95+
})
96+
})
Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
import { describe, it, expect, vi, beforeEach } from "vitest"
2+
import { webviewMessageHandler } from "../webviewMessageHandler"
3+
import { CloudService } from "@roo-code/cloud"
4+
5+
vi.mock("@roo-code/cloud", () => ({
6+
CloudService: {
7+
hasInstance: vi.fn(),
8+
instance: {
9+
cloudAPI: {
10+
creditBalance: vi.fn(),
11+
},
12+
},
13+
},
14+
}))
15+
16+
describe("webviewMessageHandler - requestRooCreditBalance", () => {
17+
let mockProvider: any
18+
19+
beforeEach(() => {
20+
mockProvider = {
21+
postMessageToWebview: vi.fn(),
22+
contextProxy: {
23+
getValue: vi.fn(),
24+
setValue: vi.fn(),
25+
},
26+
getCurrentTask: vi.fn(),
27+
cwd: "/test/path",
28+
}
29+
30+
vi.clearAllMocks()
31+
})
32+
33+
it("should handle requestRooCreditBalance and return balance", async () => {
34+
const mockBalance = 42.75
35+
const requestId = "test-request-id"
36+
37+
;(CloudService.hasInstance as any).mockReturnValue(true)
38+
;(CloudService.instance.cloudAPI!.creditBalance as any).mockResolvedValue(mockBalance)
39+
40+
await webviewMessageHandler(
41+
mockProvider as any,
42+
{
43+
type: "requestRooCreditBalance",
44+
requestId,
45+
} as any,
46+
)
47+
48+
expect(mockProvider.postMessageToWebview).toHaveBeenCalledWith({
49+
type: "rooCreditBalance",
50+
requestId,
51+
values: { balance: mockBalance },
52+
})
53+
})
54+
55+
it("should handle CloudAPI errors", async () => {
56+
const requestId = "test-request-id"
57+
const errorMessage = "Failed to fetch balance"
58+
59+
;(CloudService.hasInstance as any).mockReturnValue(true)
60+
;(CloudService.instance.cloudAPI!.creditBalance as any).mockRejectedValue(new Error(errorMessage))
61+
62+
await webviewMessageHandler(
63+
mockProvider as any,
64+
{
65+
type: "requestRooCreditBalance",
66+
requestId,
67+
} as any,
68+
)
69+
70+
expect(mockProvider.postMessageToWebview).toHaveBeenCalledWith({
71+
type: "rooCreditBalance",
72+
requestId,
73+
values: { error: errorMessage },
74+
})
75+
})
76+
77+
it("should handle missing CloudService", async () => {
78+
const requestId = "test-request-id"
79+
80+
;(CloudService.hasInstance as any).mockReturnValue(false)
81+
82+
await webviewMessageHandler(
83+
mockProvider as any,
84+
{
85+
type: "requestRooCreditBalance",
86+
requestId,
87+
} as any,
88+
)
89+
90+
expect(mockProvider.postMessageToWebview).toHaveBeenCalledWith({
91+
type: "rooCreditBalance",
92+
requestId,
93+
values: { error: "Cloud service not available" },
94+
})
95+
})
96+
97+
it("should handle missing cloudAPI", async () => {
98+
const requestId = "test-request-id"
99+
100+
;(CloudService.hasInstance as any).mockReturnValue(true)
101+
;(CloudService.instance as any).cloudAPI = null
102+
103+
await webviewMessageHandler(
104+
mockProvider as any,
105+
{
106+
type: "requestRooCreditBalance",
107+
requestId,
108+
} as any,
109+
)
110+
111+
expect(mockProvider.postMessageToWebview).toHaveBeenCalledWith({
112+
type: "rooCreditBalance",
113+
requestId,
114+
values: { error: "Cloud service not available" },
115+
})
116+
})
117+
})

src/core/webview/webviewMessageHandler.ts

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1006,6 +1006,31 @@ export const webviewMessageHandler = async (
10061006
}
10071007
break
10081008
}
1009+
case "requestRooCreditBalance": {
1010+
// Fetch Roo credit balance using CloudAPI
1011+
const requestId = message.requestId
1012+
try {
1013+
if (!CloudService.hasInstance() || !CloudService.instance.cloudAPI) {
1014+
throw new Error("Cloud service not available")
1015+
}
1016+
1017+
const balance = await CloudService.instance.cloudAPI.creditBalance()
1018+
1019+
provider.postMessageToWebview({
1020+
type: "rooCreditBalance",
1021+
requestId,
1022+
values: { balance },
1023+
})
1024+
} catch (error) {
1025+
const errorMessage = error instanceof Error ? error.message : String(error)
1026+
provider.postMessageToWebview({
1027+
type: "rooCreditBalance",
1028+
requestId,
1029+
values: { error: errorMessage },
1030+
})
1031+
}
1032+
break
1033+
}
10091034
case "requestOpenAiModels":
10101035
if (message?.values?.baseUrl && message?.values?.apiKey) {
10111036
const openAiModels = await getOpenAiModels(

src/shared/ExtensionMessage.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,7 @@ export interface ExtensionMessage {
112112
| "authenticatedUser"
113113
| "condenseTaskContextResponse"
114114
| "singleRouterModelFetchResponse"
115+
| "rooCreditBalance"
115116
| "indexingStatusUpdate"
116117
| "indexCleared"
117118
| "codebaseIndexConfig"

src/shared/WebviewMessage.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ export interface WebviewMessage {
6060
| "requestOllamaModels"
6161
| "requestLmStudioModels"
6262
| "requestRooModels"
63+
| "requestRooCreditBalance"
6364
| "requestVsCodeLmModels"
6465
| "requestHuggingFaceModels"
6566
| "openImage"

webview-ui/src/components/settings/ApiOptions.tsx

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,7 @@ import { TemperatureControl } from "./TemperatureControl"
115115
import { RateLimitSecondsControl } from "./RateLimitSecondsControl"
116116
import { ConsecutiveMistakeLimitControl } from "./ConsecutiveMistakeLimitControl"
117117
import { BedrockCustomArn } from "./providers/BedrockCustomArn"
118+
import { RooBalanceDisplay } from "./providers/RooBalanceDisplay"
118119
import { buildDocLink } from "@src/utils/docLinks"
119120

120121
export interface ApiOptionsProps {
@@ -460,12 +461,16 @@ const ApiOptions = ({
460461
<div className="flex flex-col gap-1 relative">
461462
<div className="flex justify-between items-center">
462463
<label className="block font-medium mb-1">{t("settings:providers.apiProvider")}</label>
463-
{docs && (
464-
<div className="text-xs text-vscode-descriptionForeground">
465-
<VSCodeLink href={docs.url} className="hover:text-vscode-foreground" target="_blank">
466-
{t("settings:providers.providerDocumentation", { provider: docs.name })}
467-
</VSCodeLink>
468-
</div>
464+
{selectedProvider === "roo" && cloudIsAuthenticated ? (
465+
<RooBalanceDisplay />
466+
) : (
467+
docs && (
468+
<div className="text-xs text-vscode-descriptionForeground">
469+
<VSCodeLink href={docs.url} className="hover:text-vscode-foreground" target="_blank">
470+
{t("settings:providers.providerDocumentation", { provider: docs.name })}
471+
</VSCodeLink>
472+
</div>
473+
)
469474
)}
470475
</div>
471476
<SearchableSelect

webview-ui/src/components/settings/__tests__/ApiOptions.spec.tsx

Lines changed: 50 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@ import { QueryClient, QueryClientProvider } from "@tanstack/react-query"
55

66
import { type ModelInfo, type ProviderSettings, openAiModelInfoSaneDefaults } from "@roo-code/types"
77

8-
import { ExtensionStateContextProvider } from "@src/context/ExtensionStateContext"
8+
import * as ExtensionStateContext from "@src/context/ExtensionStateContext"
9+
const { ExtensionStateContextProvider } = ExtensionStateContext
910

1011
import ApiOptions, { ApiOptionsProps } from "../ApiOptions"
1112

@@ -238,6 +239,18 @@ vi.mock("../providers/LiteLLM", () => ({
238239
),
239240
}))
240241

242+
// Mock Roo provider for tests
243+
vi.mock("../providers/Roo", () => ({
244+
Roo: ({ cloudIsAuthenticated }: any) => (
245+
<div data-testid="roo-provider">{cloudIsAuthenticated ? "Authenticated" : "Not Authenticated"}</div>
246+
),
247+
}))
248+
249+
// Mock RooBalanceDisplay for tests
250+
vi.mock("../providers/RooBalanceDisplay", () => ({
251+
RooBalanceDisplay: () => <div data-testid="roo-balance-display">Balance: $10.00</div>,
252+
}))
253+
241254
vi.mock("@src/components/ui/hooks/useSelectedModel", () => ({
242255
useSelectedModel: vi.fn((apiConfiguration: ProviderSettings) => {
243256
if (apiConfiguration.apiModelId?.includes("thinking")) {
@@ -563,4 +576,40 @@ describe("ApiOptions", () => {
563576
expect(screen.queryByTestId("litellm-provider")).not.toBeInTheDocument()
564577
})
565578
})
579+
580+
describe("Roo provider tests", () => {
581+
it("shows balance display when authenticated", () => {
582+
// Mock useExtensionState to return authenticated state
583+
const useExtensionStateMock = vi.spyOn(ExtensionStateContext, "useExtensionState")
584+
useExtensionStateMock.mockReturnValue({
585+
cloudIsAuthenticated: true,
586+
organizationAllowList: { providers: {} },
587+
} as any)
588+
589+
renderApiOptions({
590+
apiConfiguration: {
591+
apiProvider: "roo",
592+
},
593+
})
594+
595+
expect(screen.getByTestId("roo-balance-display")).toBeInTheDocument()
596+
})
597+
598+
it("does not show balance display when not authenticated", () => {
599+
// Mock useExtensionState to return unauthenticated state
600+
const useExtensionStateMock = vi.spyOn(ExtensionStateContext, "useExtensionState")
601+
useExtensionStateMock.mockReturnValue({
602+
cloudIsAuthenticated: false,
603+
organizationAllowList: { providers: {} },
604+
} as any)
605+
606+
renderApiOptions({
607+
apiConfiguration: {
608+
apiProvider: "roo",
609+
},
610+
})
611+
612+
expect(screen.queryByTestId("roo-balance-display")).not.toBeInTheDocument()
613+
})
614+
})
566615
})

webview-ui/src/components/settings/providers/Roo.tsx

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,10 @@ export const Roo = ({
3030
return (
3131
<>
3232
{cloudIsAuthenticated ? (
33-
<div className="text-sm text-vscode-descriptionForeground">
34-
{t("settings:providers.roo.authenticatedMessage")}
33+
<div className="flex justify-between items-center mb-2">
34+
<div className="text-sm text-vscode-descriptionForeground">
35+
{t("settings:providers.roo.authenticatedMessage")}
36+
</div>
3537
</div>
3638
) : (
3739
<div className="flex flex-col gap-2">

0 commit comments

Comments
 (0)