Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 15 additions & 8 deletions src/api/providers/__tests__/gemini-handler.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ describe("GeminiHandler backend support", () => {
groundingMetadata: {
groundingChunks: [
{ web: null }, // Missing URI
{ web: { uri: "https://example.com" } }, // Valid
{ web: { uri: "https://example.com", title: "Example Site" } }, // Valid
{}, // Missing web property entirely
],
},
Expand All @@ -105,13 +105,20 @@ describe("GeminiHandler backend support", () => {
messages.push(chunk)
}

// Should only include valid citations
const sourceMessage = messages.find((m) => m.type === "text" && m.text?.includes("[2]"))
expect(sourceMessage).toBeDefined()
if (sourceMessage && "text" in sourceMessage) {
expect(sourceMessage.text).toContain("https://example.com")
expect(sourceMessage.text).not.toContain("[1]")
expect(sourceMessage.text).not.toContain("[3]")
// Should have the text response
const textMessage = messages.find((m) => m.type === "text")
expect(textMessage).toBeDefined()
if (textMessage && "text" in textMessage) {
expect(textMessage.text).toBe("test response")
}

// Should have grounding chunk with only valid sources
const groundingMessage = messages.find((m) => m.type === "grounding")
expect(groundingMessage).toBeDefined()
if (groundingMessage && "sources" in groundingMessage) {
expect(groundingMessage.sources).toHaveLength(1)
expect(groundingMessage.sources[0].url).toBe("https://example.com")
expect(groundingMessage.sources[0].title).toBe("Example Site")
}
})

Expand Down
36 changes: 23 additions & 13 deletions src/api/providers/gemini.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ import { safeJsonParse } from "../../shared/safeJsonParse"

import { convertAnthropicContentToGemini, convertAnthropicMessageToGemini } from "../transform/gemini-format"
import { t } from "i18next"
import type { ApiStream } from "../transform/stream"
import type { ApiStream, GroundingSource } from "../transform/stream"
import { getModelParams } from "../transform/model-params"

import type { SingleCompletionHandler, ApiHandlerCreateMessageMetadata } from "../index"
Expand Down Expand Up @@ -132,9 +132,9 @@ export class GeminiHandler extends BaseProvider implements SingleCompletionHandl
}

if (pendingGroundingMetadata) {
const citations = this.extractCitationsOnly(pendingGroundingMetadata)
if (citations) {
yield { type: "text", text: `\n\n${t("common:errors.gemini.sources")} ${citations}` }
const sources = this.extractGroundingSources(pendingGroundingMetadata)
if (sources.length > 0) {
yield { type: "grounding", sources }
}
}

Expand Down Expand Up @@ -175,28 +175,38 @@ export class GeminiHandler extends BaseProvider implements SingleCompletionHandl
return { id: id.endsWith(":thinking") ? id.replace(":thinking", "") : id, info, ...params }
}

private extractCitationsOnly(groundingMetadata?: GroundingMetadata): string | null {
private extractGroundingSources(groundingMetadata?: GroundingMetadata): GroundingSource[] {
const chunks = groundingMetadata?.groundingChunks

if (!chunks) {
return null
return []
}

const citationLinks = chunks
.map((chunk, i) => {
return chunks
.map((chunk): GroundingSource | null => {
const uri = chunk.web?.uri
const title = chunk.web?.title || uri || "Unknown Source"

if (uri) {
return `[${i + 1}](${uri})`
return {
title,
url: uri,
}
}
return null
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider adding debug logging when sources are filtered out to help with troubleshooting malformed grounding metadata during development. This would make it easier to identify issues with the grounding chunks.

})
.filter((link): link is string => link !== null)
.filter((source): source is GroundingSource => source !== null)
}

private extractCitationsOnly(groundingMetadata?: GroundingMetadata): string | null {
const sources = this.extractGroundingSources(groundingMetadata)

if (citationLinks.length > 0) {
return citationLinks.join(", ")
if (sources.length === 0) {
return null
}

return null
const citationLinks = sources.map((source, i) => `[${i + 1}](${source.url})`)
return citationLinks.join(", ")
}

async completePrompt(prompt: string): Promise<string> {
Expand Down
18 changes: 17 additions & 1 deletion src/api/transform/stream.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
export type ApiStream = AsyncGenerator<ApiStreamChunk>

export type ApiStreamChunk = ApiStreamTextChunk | ApiStreamUsageChunk | ApiStreamReasoningChunk | ApiStreamError
/**
 * Every kind of chunk an `ApiStream` can yield.
 * Consumers distinguish members by their literal `type` field
 * (for example "text", "grounding", or "error").
 */
export type ApiStreamChunk =
| ApiStreamTextChunk
| ApiStreamUsageChunk
| ApiStreamReasoningChunk
| ApiStreamGroundingChunk
| ApiStreamError

export interface ApiStreamError {
type: "error"
Expand All @@ -27,3 +32,14 @@ export interface ApiStreamUsageChunk {
reasoningTokens?: number
totalCost?: number
}

export interface ApiStreamGroundingChunk {
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider adding JSDoc documentation for the new grounding types to clarify their purpose. For example:

/**
 * Represents grounding metadata from search results or citations.
 * Used to decouple source information from the main content stream.
 */
export interface ApiStreamGroundingChunk {
	type: "grounding"
	sources: GroundingSource[]
}

type: "grounding"
sources: GroundingSource[]
}

/**
 * A single source reference carried by an `ApiStreamGroundingChunk`.
 */
export interface GroundingSource {
// Human-readable label for the source.
title: string
// Link to the source; consumers render it as the citation target.
url: string
// Optional excerpt from the source.
// NOTE(review): no producer in the visible code sets this — confirm intended use.
snippet?: string
}
22 changes: 20 additions & 2 deletions src/core/task/Task.ts
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ import { CloudService, BridgeOrchestrator } from "@roo-code/cloud"

// api
import { ApiHandler, ApiHandlerCreateMessageMetadata, buildApiHandler } from "../../api"
import { ApiStream } from "../../api/transform/stream"
import { ApiStream, GroundingSource } from "../../api/transform/stream"

// shared
import { findLastIndex } from "../../shared/array"
Expand Down Expand Up @@ -1783,7 +1783,7 @@ export class Task extends EventEmitter<TaskEvents> implements TaskLike {
this.didFinishAbortingStream = true
}

// Reset streaming state.
// Reset streaming state for each new API request
this.currentStreamingContentIndex = 0
this.currentStreamingDidCheckpoint = false
this.assistantMessageContent = []
Expand All @@ -1804,6 +1804,7 @@ export class Task extends EventEmitter<TaskEvents> implements TaskLike {
const stream = this.attemptApiRequest()
let assistantMessage = ""
let reasoningMessage = ""
let pendingGroundingSources: GroundingSource[] = []
this.isStreaming = true

try {
Expand All @@ -1830,6 +1831,13 @@ export class Task extends EventEmitter<TaskEvents> implements TaskLike {
cacheReadTokens += chunk.cacheReadTokens ?? 0
totalCost = chunk.totalCost
break
case "grounding":
// Keep grounding sources out of assistantMessage: accumulate them in
// pendingGroundingSources to avoid state persistence issues.
if (chunk.sources && chunk.sources.length > 0) {
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good separation of concerns here! The grounding sources are now properly handled independently from the assistant message content, which should prevent the race condition issues mentioned in #6372.

pendingGroundingSources.push(...chunk.sources)
}
break
case "text": {
assistantMessage += chunk.text

Expand Down Expand Up @@ -2123,6 +2131,16 @@ export class Task extends EventEmitter<TaskEvents> implements TaskLike {
let didEndLoop = false

if (assistantMessage.length > 0) {
// Display grounding sources to the user if they exist
if (pendingGroundingSources.length > 0) {
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice implementation detail - storing grounding sources separately and only displaying them to the user while keeping the API history clean. This prevents the state persistence issues that were causing the race condition.

const citationLinks = pendingGroundingSources.map((source, i) => `[${i + 1}](${source.url})`)
const sourcesText = `${t("common:gemini.sources")} ${citationLinks.join(", ")}`

await this.say("text", sourcesText, undefined, false, undefined, undefined, {
isNonInteractive: true,
})
}

await this.addToApiConversationHistory({
role: "assistant",
content: [{ type: "text", text: assistantMessage }],
Expand Down
Loading
Loading