@@ -1,4 +1,4 @@
// npx vitest src/core/sliding-window/__tests__/sliding-window.spec.ts
// cd src && npx vitest run core/context-management/__tests__/context-management.spec.ts

import { Anthropic } from "@anthropic-ai/sdk"

@@ -9,12 +9,7 @@ import { BaseProvider } from "../../../api/providers/base-provider"
import { ApiMessage } from "../../task-persistence/apiMessages"
import * as condenseModule from "../../condense"

import {
TOKEN_BUFFER_PERCENTAGE,
estimateTokenCount,
truncateConversation,
truncateConversationIfNeeded,
} from "../index"
import { TOKEN_BUFFER_PERCENTAGE, estimateTokenCount, truncateConversation, manageContext } from "../index"

// Create a mock ApiHandler for testing
class MockApiHandler extends BaseProvider {
@@ -49,7 +44,7 @@ class MockApiHandler extends BaseProvider {
const mockApiHandler = new MockApiHandler()
const taskId = "test-task-id"

describe("Sliding Window", () => {
describe("Context Management", () => {
beforeEach(() => {
if (!TelemetryService.hasInstance()) {
TelemetryService.createInstance([])
@@ -234,9 +229,9 @@ describe("Sliding Window", () => {
})

/**
* Tests for the truncateConversationIfNeeded function
* Tests for the manageContext function
*/
describe("truncateConversationIfNeeded", () => {
describe("manageContext", () => {
const createModelInfo = (contextWindow: number, maxTokens?: number): ModelInfo => ({
contextWindow,
supportsPromptCache: true,
@@ -261,7 +256,7 @@ describe("Sliding Window", () => {
{ ...messages[messages.length - 1], content: "" },
]

const result = await truncateConversationIfNeeded({
const result = await manageContext({
messages: messagesWithSmallContent,
totalTokens,
contextWindow: modelInfo.contextWindow,
@@ -302,7 +297,7 @@ describe("Sliding Window", () => {
messagesWithSmallContent[4],
]

const result = await truncateConversationIfNeeded({
const result = await manageContext({
messages: messagesWithSmallContent,
totalTokens,
contextWindow: modelInfo.contextWindow,
@@ -337,7 +332,7 @@ describe("Sliding Window", () => {

// Test below threshold
const belowThreshold = 69999
const result1 = await truncateConversationIfNeeded({
const result1 = await manageContext({
messages: messagesWithSmallContent,
totalTokens: belowThreshold,
contextWindow: modelInfo1.contextWindow,
@@ -351,7 +346,7 @@
currentProfileId: "default",
})

const result2 = await truncateConversationIfNeeded({
const result2 = await manageContext({
messages: messagesWithSmallContent,
totalTokens: belowThreshold,
contextWindow: modelInfo2.contextWindow,
@@ -372,7 +367,7 @@

// Test above threshold
const aboveThreshold = 70001
const result3 = await truncateConversationIfNeeded({
const result3 = await manageContext({
messages: messagesWithSmallContent,
totalTokens: aboveThreshold,
contextWindow: modelInfo1.contextWindow,
@@ -386,7 +381,7 @@
currentProfileId: "default",
})

const result4 = await truncateConversationIfNeeded({
const result4 = await manageContext({
messages: messagesWithSmallContent,
totalTokens: aboveThreshold,
contextWindow: modelInfo2.contextWindow,
@@ -422,7 +417,7 @@
// Set base tokens so total is well below threshold + buffer even with small content added
const dynamicBuffer = modelInfo.contextWindow * TOKEN_BUFFER_PERCENTAGE
const baseTokensForSmall = availableTokens - smallContentTokens - dynamicBuffer - 10
const resultWithSmall = await truncateConversationIfNeeded({
const resultWithSmall = await manageContext({
messages: messagesWithSmallContent,
totalTokens: baseTokensForSmall,
contextWindow: modelInfo.contextWindow,
@@ -457,7 +452,7 @@

// Set base tokens so we're just below threshold without content, but over with content
const baseTokensForLarge = availableTokens - Math.floor(largeContentTokens / 2)
const resultWithLarge = await truncateConversationIfNeeded({
const resultWithLarge = await manageContext({
messages: messagesWithLargeContent,
totalTokens: baseTokensForLarge,
contextWindow: modelInfo.contextWindow,
@@ -485,7 +480,7 @@

// Set base tokens so we're just below threshold without content
const baseTokensForVeryLarge = availableTokens - Math.floor(veryLargeContentTokens / 2)
const resultWithVeryLarge = await truncateConversationIfNeeded({
const resultWithVeryLarge = await manageContext({
messages: messagesWithVeryLargeContent,
totalTokens: baseTokensForVeryLarge,
contextWindow: modelInfo.contextWindow,
@@ -523,7 +518,7 @@
messagesWithSmallContent[4],
]

const result = await truncateConversationIfNeeded({
const result = await manageContext({
messages: messagesWithSmallContent,
totalTokens,
contextWindow: modelInfo.contextWindow,
@@ -570,7 +565,7 @@
{ ...messages[messages.length - 1], content: "" },
]

const result = await truncateConversationIfNeeded({
const result = await manageContext({
messages: messagesWithSmallContent,
totalTokens,
contextWindow: modelInfo.contextWindow,
@@ -637,7 +632,7 @@
messagesWithSmallContent[4],
]

const result = await truncateConversationIfNeeded({
const result = await manageContext({
messages: messagesWithSmallContent,
totalTokens,
contextWindow: modelInfo.contextWindow,
@@ -684,7 +679,7 @@
messagesWithSmallContent[4],
]

const result = await truncateConversationIfNeeded({
const result = await manageContext({
messages: messagesWithSmallContent,
totalTokens,
contextWindow: modelInfo.contextWindow,
@@ -741,7 +736,7 @@
{ ...messages[messages.length - 1], content: "" },
]

const result = await truncateConversationIfNeeded({
const result = await manageContext({
messages: messagesWithSmallContent,
totalTokens,
contextWindow,
@@ -793,7 +788,7 @@
{ ...messages[messages.length - 1], content: "" },
]

const result = await truncateConversationIfNeeded({
const result = await manageContext({
messages: messagesWithSmallContent,
totalTokens,
contextWindow,
@@ -880,7 +875,7 @@
.spyOn(condenseModule, "summarizeConversation")
.mockResolvedValue(mockSummarizeResponse)

const result = await truncateConversationIfNeeded({
const result = await manageContext({
messages: messagesWithSmallContent,
totalTokens,
contextWindow,
@@ -946,7 +941,7 @@
.spyOn(condenseModule, "summarizeConversation")
.mockResolvedValue(mockSummarizeResponse)

const result = await truncateConversationIfNeeded({
const result = await manageContext({
messages: messagesWithSmallContent,
totalTokens,
contextWindow,
@@ -1000,7 +995,7 @@
vi.clearAllMocks()
const summarizeSpy = vi.spyOn(condenseModule, "summarizeConversation")

const result = await truncateConversationIfNeeded({
const result = await manageContext({
messages: messagesWithSmallContent,
totalTokens,
contextWindow,
@@ -1030,10 +1025,10 @@
})

/**
* Tests for the getMaxTokens function (private but tested through truncateConversationIfNeeded)
* Tests for the getMaxTokens function (private but tested through manageContext)
*/
describe("getMaxTokens", () => {
// We'll test this indirectly through truncateConversationIfNeeded
// We'll test this indirectly through manageContext
const createModelInfo = (contextWindow: number, maxTokens?: number): ModelInfo => ({
contextWindow,
supportsPromptCache: true, // Not relevant for getMaxTokens
@@ -1061,7 +1056,7 @@

// Account for the dynamic buffer which is 10% of context window (10,000 tokens)
// Below max tokens and buffer - no truncation
const result1 = await truncateConversationIfNeeded({
const result1 = await manageContext({
messages: messagesWithSmallContent,
totalTokens: 39999, // Well below threshold + dynamic buffer
contextWindow: modelInfo.contextWindow,
@@ -1082,7 +1077,7 @@
})

// Above max tokens - truncate
const result2 = await truncateConversationIfNeeded({
const result2 = await manageContext({
messages: messagesWithSmallContent,
totalTokens: 50001, // Above threshold
contextWindow: modelInfo.contextWindow,
@@ -1114,7 +1109,7 @@

// Account for the dynamic buffer which is 10% of context window (10,000 tokens)
// Below max tokens and buffer - no truncation
const result1 = await truncateConversationIfNeeded({
const result1 = await manageContext({
messages: messagesWithSmallContent,
totalTokens: 81807, // Well below threshold + dynamic buffer (91808 - 10000 = 81808)
contextWindow: modelInfo.contextWindow,
@@ -1135,7 +1130,7 @@
})

// Above max tokens - truncate
const result2 = await truncateConversationIfNeeded({
const result2 = await manageContext({
messages: messagesWithSmallContent,
totalTokens: 81809, // Above threshold (81808)
contextWindow: modelInfo.contextWindow,
@@ -1166,7 +1161,7 @@
]

// Below max tokens and buffer - no truncation
const result1 = await truncateConversationIfNeeded({
const result1 = await manageContext({
messages: messagesWithSmallContent,
totalTokens: 34999, // Well below threshold + buffer
contextWindow: modelInfo.contextWindow,
@@ -1182,7 +1177,7 @@
expect(result1.messages).toEqual(messagesWithSmallContent)

// Above max tokens - truncate
const result2 = await truncateConversationIfNeeded({
const result2 = await manageContext({
messages: messagesWithSmallContent,
totalTokens: 40001, // Above threshold
contextWindow: modelInfo.contextWindow,
@@ -1211,7 +1206,7 @@

// Account for the dynamic buffer which is 10% of context window (20,000 tokens for this test)
// Below max tokens and buffer - no truncation
const result1 = await truncateConversationIfNeeded({
const result1 = await manageContext({
messages: messagesWithSmallContent,
totalTokens: 149999, // Well below threshold + dynamic buffer
contextWindow: modelInfo.contextWindow,
@@ -1227,7 +1222,7 @@
expect(result1.messages).toEqual(messagesWithSmallContent)

// Above max tokens - truncate
const result2 = await truncateConversationIfNeeded({
const result2 = await manageContext({
messages: messagesWithSmallContent,
totalTokens: 170001, // Above threshold
contextWindow: modelInfo.contextWindow,

@@ -8,7 +8,18 @@ import { ApiMessage } from "../task-persistence/apiMessages"
import { ANTHROPIC_DEFAULT_MAX_TOKENS } from "@roo-code/types"

/**
* Default percentage of the context window to use as a buffer when deciding when to truncate
* Context Management
*
* This module provides Context Management for conversations, combining:
* - Intelligent condensation of prior messages when approaching configured thresholds
* - Sliding window truncation as a fallback when necessary
*
* Behavior is preserved exactly from the previous sliding-window implementation; the primary entry point is now exported as manageContext (formerly truncateConversationIfNeeded).
*/

/**
* Default percentage of the context window to use as a buffer when deciding when to truncate.
* Used by Context Management to decide when to trigger condensation or, as a fallback, sliding-window truncation.
*/
export const TOKEN_BUFFER_PERCENTAGE = 0.1
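
To make the buffer concrete, here is a minimal sketch (not part of this diff) of how a dynamic buffer of TOKEN_BUFFER_PERCENTAGE of the context window can feed into the "is there room left?" check. The helper name and the exact threshold arithmetic are illustrative assumptions reconstructed from the test comments above:

// Hypothetical helper, not from the PR: shows how a buffer equal to
// TOKEN_BUFFER_PERCENTAGE of the context window tightens the limit.
function wouldExceedLimit(totalTokens: number, contextWindow: number, reservedForResponse: number): boolean {
	// The buffer scales with the window: 10,000 tokens for a 100,000-token window.
	const dynamicBuffer = contextWindow * TOKEN_BUFFER_PERCENTAGE
	const allowedTokens = contextWindow - reservedForResponse - dynamicBuffer
	return totalTokens > allowedTokens
}

// With a 100,000-token window and 8,192 tokens reserved for the response,
// anything above 100000 - 8192 - 10000 = 81,808 tokens would trigger context
// management, matching the 81807/81809 boundary exercised in the tests above.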

@@ -33,6 +44,8 @@ export async function estimateTokenCount(
* The first message is always retained, and a specified fraction (rounded to an even number)
* of messages from the beginning (excluding the first) is removed.
*
* This implements the sliding window truncation behavior.
*
* @param {ApiMessage[]} messages - The conversation messages.
* @param {number} fracToRemove - The fraction (between 0 and 1) of messages (excluding the first) to remove.
* @param {string} taskId - The task ID for the conversation, used for telemetry
@@ -50,20 +63,16 @@ export function truncateConversation(messages: ApiMessage[], fracToRemove: numbe
}

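The function body above is collapsed in this view, but the docblock fully specifies its behavior. A hedged sketch of that sliding-window step, reconstructed from the documentation rather than copied from the diff:

// Sketch only: keep the first message, then drop a fraction of the
// remaining messages from the front, rounded down to an even count
// (an even count keeps user/assistant pairs aligned).
function truncateSketch(messages: ApiMessage[], fracToRemove: number): ApiMessage[] {
	const rawCount = Math.floor((messages.length - 1) * fracToRemove)
	const messagesToRemove = rawCount - (rawCount % 2)
	return [messages[0], ...messages.slice(messagesToRemove + 1)]
}

For example, with five messages and fracToRemove = 0.5, rawCount is 2, so messages[1] and messages[2] are removed and the first message plus the last two survive.
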
/**
* Conditionally truncates the conversation messages if the total token count
* exceeds the model's limit, considering the size of incoming content.
* Context Management: Conditionally manages the conversation context when approaching limits.
*
* @param {ApiMessage[]} messages - The conversation messages.
* @param {number} totalTokens - The total number of tokens in the conversation (excluding the last user message).
* @param {number} contextWindow - The context window size.
* @param {number} maxTokens - The maximum number of tokens allowed.
* @param {ApiHandler} apiHandler - The API handler to use for token counting.
* @param {boolean} autoCondenseContext - Whether to use LLM summarization or sliding window implementation
* @param {string} systemPrompt - The system prompt, used for estimating the new context size after summarizing.
* @returns {ApiMessage[]} The original or truncated conversation messages.
* Attempts intelligent condensation of prior messages when thresholds are reached.
* Falls back to sliding window truncation if condensation is unavailable or fails.
*
* @param {ContextManagementOptions} options - The options for truncation/condensation
* @returns {Promise<ContextManagementResult>} The original, condensed, or truncated conversation messages.
*/

type TruncateOptions = {
export type ContextManagementOptions = {
messages: ApiMessage[]
totalTokens: number
contextWindow: number
@@ -79,16 +88,15 @@ type TruncateOptions = {
currentProfileId: string
}

type TruncateResponse = SummarizeResponse & { prevContextTokens: number }
export type ContextManagementResult = SummarizeResponse & { prevContextTokens: number }
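
SummarizeResponse is imported from ../condense and its definition is not part of this diff, so the flattened result shape below is an assumption, inferred from fields referenced elsewhere (error and cost in the function body, messages in the tests):

// Assumed shape, for illustration only; the authoritative definition of
// SummarizeResponse lives in ../condense.
type AssumedContextManagementResult = {
	messages: ApiMessage[] // possibly condensed or truncated history
	summary?: string // assumed: present when condensation produced a summary
	cost: number // assumed: cost of the condensation request, 0 when none was made
	error?: string // assumed: set when condensation fails
	prevContextTokens: number // context size before any reduction
}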

/**
* Conditionally truncates the conversation messages if the total token count
* exceeds the model's limit, considering the size of incoming content.
* Conditionally manages conversation context (condense and fallback truncation).
*
* @param {TruncateOptions} options - The options for truncation
* @returns {Promise<ApiMessage[]>} The original or truncated conversation messages.
* @param {ContextManagementOptions} options - The options for truncation/condensation
* @returns {Promise<ContextManagementResult>} The original, condensed, or truncated conversation messages.
*/
export async function truncateConversationIfNeeded({
export async function manageContext({
messages,
totalTokens,
contextWindow,
@@ -102,7 +110,7 @@ export async function truncateConversationIfNeeded({
condensingApiHandler,
profileThresholds,
currentProfileId,
}: TruncateOptions): Promise<TruncateResponse> {
}: ContextManagementOptions): Promise<ContextManagementResult> {
let error: string | undefined
let cost = 0
// Calculate the maximum tokens reserved for response
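
For reference, a call site mirroring the shape used throughout the updated tests. Values are illustrative, conversationMessages is a stand-in, and the option list reflects only the fields visible in this diff, so the real option set may be larger:

// Illustrative usage of manageContext, modeled on the test file.
const result = await manageContext({
	messages: conversationMessages,
	totalTokens: 70001, // tokens used so far, excluding the last user message
	contextWindow: 100_000,
	maxTokens: 20_000, // reserved for the model's response
	apiHandler: mockApiHandler,
	autoCondenseContext: true, // try condensation first, sliding window as fallback
	systemPrompt: "test system prompt", // used to estimate post-summary context size
	taskId,
	profileThresholds: {},
	currentProfileId: "default",
})
// result.messages holds the (possibly condensed or truncated) history;
// result.prevContextTokens reports the context size before reduction.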