diff --git a/docs/mkdocs/en/tool.md b/docs/mkdocs/en/tool.md
index f0a696dc5..e8c908803 100644
--- a/docs/mkdocs/en/tool.md
+++ b/docs/mkdocs/en/tool.md
@@ -541,6 +541,70 @@ searchTool := duckduckgo.NewTool(
 )
 ```
 
+### Claude Code ToolSet
+
+`tool/claudecode` provides a code-oriented ToolSet that exposes a Claude Code-style tool surface inside the framework. It covers file editing, repository search, command execution, and web retrieval, and can be attached directly to `LLMAgent` or other runtimes. If your goal is to invoke the local Claude Code CLI and consume its execution trace and tool events, see the [Claude Code Agent guide](claudecode.md) instead.
+
+`claudecode` always exposes a core set of workflow tools: `Bash`, `TaskStop`, `TaskOutput`, `Read`, `Glob`, `Grep`, `WebFetch`, and `WebSearch`. When read-only mode is disabled (the default), it also exposes `Write`, `Edit`, and `NotebookEdit`.
+
+The following table lists the tools currently exposed by `claudecode`:
+
+| Tool | Description |
+| --- | --- |
+| `Bash` | Executes local shell commands. |
+| `TaskStop` | Stops a background task started by `Bash`. |
+| `TaskOutput` | Reads the current or final output of a background task. |
+| `Read` | Reads file contents. |
+| `Glob` | Finds files by path pattern. |
+| `Grep` | Searches repository content. |
+| `WebFetch` | Fetches the content of a specific URL. |
+| `WebSearch` | Performs an open web search. |
+| `Write` | Creates a file or overwrites a file with complete contents. Only exposed when read-only mode is disabled. |
+| `Edit` | Performs targeted replacement in an existing text file. Only exposed when read-only mode is disabled. |
+| `NotebookEdit` | Edits `.ipynb` files at the cell level. Only exposed when read-only mode is disabled. |
+
+#### Basic Usage
+
+```go
+import (
+	"log"
+
+	"trpc.group/trpc-go/trpc-agent-go/agent/llmagent"
+	"trpc.group/trpc-go/trpc-agent-go/tool"
+	"trpc.group/trpc-go/trpc-agent-go/tool/claudecode"
+)
+
+toolSet, err := claudecode.NewToolSet(
+	claudecode.WithBaseDir("."),
+)
+if err != nil {
+	log.Fatal(err)
+}
+defer toolSet.Close()
+
+agent := llmagent.New(
+	"claude-style-agent",
+	llmagent.WithToolSets([]tool.ToolSet{toolSet}),
+)
+```
+
+`llmagent.WithToolSets(...)` attaches these tools as a ToolSet. Calling `Tools()` returns the flattened list of individual tools.
+
+#### Common Options
+
+The main `tool/claudecode` options cover the working directory, read-only mode, and web behavior:
+
+| Option | Description |
+| --- | --- |
+| `WithName(name)` | Overrides the ToolSet name. The default name is `claudecode`. |
+| `WithBaseDir(dir)` | Sets the base directory used by the file, search, and command execution tools. |
+| `WithReadOnly(readOnly)` | Removes `Write`, `Edit`, and `NotebookEdit` when enabled. |
+| `WithMaxFileSize(size)` | Limits the maximum readable file size. |
+| `WithWebFetchOptions(opts)` | Configures domain policy, timeout, and content handling for `WebFetch`. |
+| `WithWebSearchOptions(opts)` | Configures backend, paging, and request options for `WebSearch`. |
+
+`WithBaseDir` defines the working scope for `Read`, `Write`, `Edit`, `Glob`, and `Grep`, and also sets the default working directory for `Bash`. When read-only mode is enabled, the toolset keeps only the read, search, command execution, and web tools; when it is disabled, `Write`, `Edit`, and `NotebookEdit` are exposed as well.
+
 ## MCP Tools
 
 MCP (Model Context Protocol) is an open protocol that standardizes how applications provide context to LLMs. MCP tools are based on JSON-RPC 2.0 and provide standardized integration with external services for Agents.
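The `WithReadOnly` behavior described in the options table can be sketched in isolation. The following is a minimal, self-contained illustration of how the tool surface shrinks in read-only mode; the helper names here are hypothetical, and the real filtering lives inside `tool/claudecode`, not in this snippet:

```go
package main

import "fmt"

// coreTools mirrors the always-exposed tools from the table above.
var coreTools = []string{
	"Bash", "TaskStop", "TaskOutput", "Read",
	"Glob", "Grep", "WebFetch", "WebSearch",
}

// writeTools are only exposed when read-only mode is disabled.
var writeTools = []string{"Write", "Edit", "NotebookEdit"}

// exposedTools is a hypothetical sketch of the WithReadOnly filtering:
// write-capable tools are appended only when readOnly is false.
func exposedTools(readOnly bool) []string {
	tools := append([]string{}, coreTools...)
	if !readOnly {
		tools = append(tools, writeTools...)
	}
	return tools
}

func main() {
	fmt.Println(len(exposedTools(true)))  // 8: core tools only
	fmt.Println(len(exposedTools(false))) // 11: core plus Write, Edit, NotebookEdit
}
```

This matches what `Tools()` returns in each mode: the default toolset includes all eleven tools, while `WithReadOnly(true)` drops the three write-capable ones.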
diff --git a/docs/mkdocs/zh/tool.md b/docs/mkdocs/zh/tool.md
index 3997ecfcb..edd9c6c71 100644
--- a/docs/mkdocs/zh/tool.md
+++ b/docs/mkdocs/zh/tool.md
@@ -535,6 +535,69 @@ searchTool := duckduckgo.NewTool(
 )
 ```
 
+### Claude Code ToolSet
+
+`tool/claudecode` 提供了一组面向代码工作的 ToolSet,用于在框架内部暴露与 Claude Code 接近的工具接口。它覆盖文件读写、代码检索、命令执行和网页获取等能力,可以直接挂接到 `LLMAgent` 或其他运行时。如果你的目标是调用本地 Claude Code CLI,并消费 CLI 的执行轨迹与工具事件,请参考 [Claude Code Agent 使用指南](claudecode.md)。
+
+从能力组成上看,`claudecode` 默认会提供一组代码工作流工具,包括 `Bash`、`TaskStop`、`TaskOutput`、`Read`、`Glob`、`Grep`、`WebFetch` 和 `WebSearch`。在非只读模式下,还会额外提供 `Write`、`Edit` 和 `NotebookEdit`。
+
+下表列出了当前 `claudecode` 工具集中的主要工具及其用途:
+
+| 工具名 | 说明 |
+| --- | --- |
+| `Bash` | 执行本地 Shell 命令。 |
+| `TaskStop` | 停止由 `Bash` 以后台模式启动的任务。 |
+| `TaskOutput` | 读取后台任务的当前输出或最终输出。 |
+| `Read` | 读取文件内容。 |
+| `Glob` | 按路径模式查找文件。 |
+| `Grep` | 按内容搜索仓库。 |
+| `WebFetch` | 抓取指定 URL 的页面内容。 |
+| `WebSearch` | 进行开放式网页搜索。 |
+| `Write` | 创建文件或用完整内容覆盖文件,仅在非只读模式下暴露。 |
+| `Edit` | 对已有文本文件做局部替换,仅在非只读模式下暴露。 |
+| `NotebookEdit` | 按 cell 粒度编辑 `.ipynb` 文件,仅在非只读模式下暴露。 |
+
+#### 基本用法
+
+```go
+import (
+	"log"
+	"trpc.group/trpc-go/trpc-agent-go/agent/llmagent"
+	"trpc.group/trpc-go/trpc-agent-go/tool"
+	"trpc.group/trpc-go/trpc-agent-go/tool/claudecode"
+)
+
+toolSet, err := claudecode.NewToolSet(
+	claudecode.WithBaseDir("."),
+)
+if err != nil {
+	log.Fatal(err)
+}
+defer toolSet.Close()
+
+agent := llmagent.New(
+	"claude-style-agent",
+	llmagent.WithToolSets([]tool.ToolSet{toolSet}),
+)
+```
+
+`llmagent.WithToolSets(...)` 会以 ToolSet 形式接入这组工具;如果调用 `Tools()`,则会得到展开后的单个工具列表。
+
+#### 常用配置
+
+`tool/claudecode` 的配置重点围绕工作目录、只读模式和 Web 能力展开:
+
+| Option | 说明 |
+| --- | --- |
+| `WithName(name)` | 覆盖 ToolSet 名称,默认值为 `claudecode`。 |
+| `WithBaseDir(dir)` | 指定工具集的基础目录。文件、检索和命令执行都会以此为基准。 |
+| `WithReadOnly(readOnly)` | 启用只读模式后,不再暴露 `Write`、`Edit`、`NotebookEdit`。 |
+| `WithMaxFileSize(size)` | 限制单个文件可读取的最大尺寸。 |
+| `WithWebFetchOptions(opts)` | 配置 `WebFetch` 的域名策略、超时与内容处理方式。 |
+| `WithWebSearchOptions(opts)` | 配置 `WebSearch` 的后端、分页参数与请求选项。 |
+
+`WithBaseDir` 定义了 `Read`、`Write`、`Edit`、`Glob`、`Grep` 等文件相关工具的工作范围,也决定了 `Bash` 的默认执行目录。启用只读模式后,工具集只保留读取、检索、命令执行和 Web 相关能力;关闭只读模式后,会额外暴露 `Write`、`Edit` 与 `NotebookEdit`。
+
 ## MCP Tools 协议工具
 
 MCP(Model Context Protocol)是一个开放协议,标准化了应用程序向 LLM 提供上下文的方式。MCP 工具基于 JSON-RPC 2.0 协议,为 Agent 提供了与外部服务的标准化集成能力。
diff --git a/tool/claudecode/bash.go b/tool/claudecode/bash.go
new file mode 100644
index 000000000..0211b05e1
--- /dev/null
+++ b/tool/claudecode/bash.go
@@ -0,0 +1,167 @@
+//
+// Tencent is pleased to support the open source community by making
+// trpc-agent-go available.
+//
+// Copyright (C) 2025 Tencent. All rights reserved.
+//
+// trpc-agent-go is licensed under the Apache License Version 2.0.
+//
+
+package claudecode
+
+import (
+	"context"
+	"fmt"
+	"os"
+	"path/filepath"
+	"strconv"
+	"time"
+
+	"github.com/google/uuid"
+	"trpc.group/trpc-go/trpc-agent-go/tool"
+	"trpc.group/trpc-go/trpc-agent-go/tool/function"
+)
+
+func newBashTool(runtime *runtime) (tool.Tool, error) {
+	return function.NewFunctionTool(
+		func(ctx context.Context, in bashInput) (bashOutput, error) {
+			if in.RunInBackground {
+				return runBackgroundCommand(runtime, in.Command)
+			}
+			return runForegroundCommand(ctx, runtime, in)
+		},
+		function.WithName(toolBash),
+		function.WithDescription(bashDescription()),
+	), nil
+}
+
+func runForegroundCommand(ctx context.Context, runtime *runtime, in bashInput) (bashOutput, error) {
+	timeoutMs := bashTimeout(in.Timeout)
+	start := time.Now()
+	runCtx, cancel := context.WithTimeout(ctx, time.Duration(timeoutMs)*time.Millisecond)
+	defer cancel()
+	result, err := runCapturedProcess(runCtx, runtime.currentBaseDir(), nil, "bash", "-lc", in.Command)
+	durationMs := time.Since(start).Milliseconds()
+	timedOut := errorsIsDeadlineExceeded(runCtx.Err())
+	exitCode := result.ExitCode
+	if err != nil && exitCode == 0 {
+		if timedOut {
+			exitCode = 124
+		} else {
+			exitCode = 1
+		}
+	}
+	stdout := string(result.Stdout)
+	stderr := string(result.Stderr)
+	return bashOutput{
+		Command:    in.Command,
+		ExitCode:   exitCode,
+		Stdout:     stdout,
+		Stderr:     stderr,
+		Output:     joinOutput(stdout, stderr),
+		DurationMs: durationMs,
+		TimedOut:   timedOut,
+	}, nil
+}
+
+func bashTimeout(timeout *int) int {
+	timeoutMs := defaultBashTimeoutMs
+	if raw := os.Getenv("BASH_DEFAULT_TIMEOUT_MS"); raw != "" {
+		if parsed, err := strconv.Atoi(raw); err == nil && parsed > 0 {
+			timeoutMs = parsed
+		}
+	}
+	if timeout != nil {
+		timeoutMs = *timeout
+	}
+	if timeoutMs <= 0 {
+		timeoutMs = defaultBashTimeoutMs
+	}
+	if timeoutMs > maxBashTimeoutMs {
+		timeoutMs = maxBashTimeoutMs
+	}
+	return timeoutMs
+}
+
+func runBackgroundCommand(runtime *runtime, command string) (bashOutput, error) {
+	taskID := uuid.NewString()
+	outputDir := filepath.Join(os.TempDir(), "trpc-agent-go-claudecode")
+	if err := os.MkdirAll(outputDir, 0o755); err != nil {
+		return bashOutput{}, err
+	}
+	outputPath := filepath.Join(outputDir, taskID+".log")
+	outputFile, err := os.Create(outputPath)
+	if err != nil {
+		return bashOutput{}, err
+	}
+	process, err := startProcess(runtime.currentBaseDir(), nil, outputFile, outputFile, "bash", "-lc", command)
+	if err != nil {
+		_ = outputFile.Close()
+		return bashOutput{}, err
+	}
+	runtime.taskState.mu.Lock()
+	runtime.taskState.tasks[taskID] = &backgroundTask{
+		ID:         taskID,
+		Command:    command,
+		Type:       toolBash,
+		OutputPath: outputPath,
+		Process:    process,
+		Status:     "running",
+	}
+	runtime.taskState.mu.Unlock()
+	go func() {
+		state, waitErr := process.Wait()
+		_ = outputFile.Close()
+		runtime.taskState.mu.Lock()
+		task := runtime.taskState.tasks[taskID]
+		if task != nil && task.Status == "running" {
+			task.Status = backgroundTaskStatus(waitErr, state)
+			exitCode := backgroundTaskExitCode(waitErr, state)
+			task.ExitCode = &exitCode
+		}
+		runtime.taskState.mu.Unlock()
+	}()
+	return bashOutput{
+		Command:          command,
+		ExitCode:         0,
+		Output:           fmt.Sprintf("Command is running in the background. Read %s later to inspect the output.", outputPath),
+		BackgroundTaskID: taskID,
+		OutputPath:       outputPath,
+	}, nil
+}
+
+func backgroundTaskStatus(waitErr error, state *os.ProcessState) string {
+	if waitErr != nil {
+		return "exited"
+	}
+	if state == nil || !state.Success() {
+		return "exited"
+	}
+	return "completed"
+}
+
+func backgroundTaskExitCode(waitErr error, state *os.ProcessState) int {
+	if state != nil {
+		return state.ExitCode()
+	}
+	if waitErr != nil {
+		return 1
+	}
+	return 0
+}
+
+func errorsIsDeadlineExceeded(err error) bool {
+	return err == context.DeadlineExceeded
+}
+
+func bashDescription() string {
+	return fmt.Sprintf(`Execute a local shell command.
+
+Usage:
+- Use %s for shell-native tasks such as git, build, test, lint, package managers, and project scripts.
+- Prefer dedicated tools when they fit better: use %s to read files, %s to create or overwrite files, %s for targeted text replacements, %s for notebook cell edits, %s for filename search, %s for repository content search, %s to fetch a specific URL, and %s for broad web discovery.
+- NEVER use bash grep or rg for repository search when %s can answer the question.
+- Commands run from the current workspace base directory.
+- Use run_in_background for long-running commands that do not need an immediate result. Inspect them later with %s or stop them with %s.
+- timeout is measured in milliseconds and is capped at %d ms.`, toolBash, toolRead, toolWrite, toolEdit, toolNotebookEdit, toolGlob, toolGrep, toolWebFetch, toolWebSearch, toolGrep, toolTaskOutput, toolTaskStop, maxBashTimeoutMs)
+}
diff --git a/tool/claudecode/claudecode_test.go b/tool/claudecode/claudecode_test.go
new file mode 100644
index 000000000..44c229a25
--- /dev/null
+++ b/tool/claudecode/claudecode_test.go
@@ -0,0 +1,2483 @@
+//
+// Tencent is pleased to support the open source community by making
+// trpc-agent-go available.
+//
+// Copyright (C) 2025 Tencent. All rights reserved.
+// +// trpc-agent-go is licensed under the Apache License Version 2.0. +// + +package claudecode + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "io" + "io/fs" + "net/http" + "net/http/httptest" + "os" + "os/exec" + "path/filepath" + "regexp" + "strconv" + "strings" + "sync" + "testing" + "time" + + "github.com/go-pdf/fpdf" + "github.com/stretchr/testify/require" + "trpc.group/trpc-go/trpc-agent-go/tool" +) + +func TestNewToolSet_DefaultTools(t *testing.T) { + t.Parallel() + dir := t.TempDir() + ts, err := NewToolSet(WithBaseDir(dir)) + require.NoError(t, err) + require.NotNil(t, ts) + require.Equal(t, defaultToolSetName, ts.Name()) + names := toolNames(ts.Tools(context.Background())) + require.Contains(t, names, toolBash) + require.Contains(t, names, toolRead) + require.Contains(t, names, toolWrite) + require.Contains(t, names, toolEdit) + require.Contains(t, names, toolNotebookEdit) + require.Contains(t, names, toolGlob) + require.Contains(t, names, toolGrep) + require.Contains(t, names, toolTaskStop) + require.Contains(t, names, toolTaskOutput) + require.Contains(t, names, toolWebFetch) + require.Contains(t, names, toolWebSearch) + require.NotContains(t, names, "EnterWorktree") + require.NotContains(t, names, "ExitWorktree") + require.NotContains(t, names, "WebBrowser") + require.NotContains(t, names, "browser") + require.NotContains(t, names, "LSP") +} + +func TestNewToolSet_UsesRichToolDescriptions(t *testing.T) { + t.Parallel() + dir := t.TempDir() + ts, err := NewToolSet(WithBaseDir(dir)) + require.NoError(t, err) + descriptions := map[string]string{} + for _, candidate := range ts.Tools(context.Background()) { + if candidate == nil || candidate.Declaration() == nil { + continue + } + descriptions[candidate.Declaration().Name] = candidate.Declaration().Description + } + require.Contains(t, descriptions[toolBash], "NEVER use bash grep or rg") + require.Contains(t, descriptions[toolRead], "PDF") + require.Contains(t, 
descriptions[toolWrite], "read it with Read first") + require.Contains(t, descriptions[toolEdit], "old_string must match") + require.Contains(t, descriptions[toolGlob], "doublestar-style globs") + require.Contains(t, descriptions[toolGrep], "ALWAYS use Grep") + require.Contains(t, descriptions[toolWebFetch], "prompt is required") + require.Contains(t, descriptions[toolWebSearch], "allowed_domains and blocked_domains") + require.NotContains(t, descriptions, "LSP") +} + +func TestNewToolSet_ReadOnlyOmitsWriteAndEdit(t *testing.T) { + t.Parallel() + dir := t.TempDir() + ts, err := NewToolSet(WithBaseDir(dir), WithReadOnly(true)) + require.NoError(t, err) + require.NotNil(t, ts) + names := toolNames(ts.Tools(context.Background())) + require.Contains(t, names, toolRead) + require.NotContains(t, names, toolWrite) + require.NotContains(t, names, toolEdit) + require.NotContains(t, names, toolNotebookEdit) +} + +func TestNewToolSet_BlankNameFallsBackAndInvalidWebSearchFails(t *testing.T) { + t.Parallel() + dir := t.TempDir() + ts, err := NewToolSet(WithBaseDir(dir), WithName(" \t ")) + require.NoError(t, err) + require.Equal(t, defaultToolSetName, ts.Name()) + require.NoError(t, ts.Close()) + _, err = NewToolSet(WithBaseDir(dir), WithWebSearchOptions(WebSearchOptions{ + Provider: "bing", + })) + require.EqualError(t, err, "unsupported web search provider: bing") +} + +func TestToolSet_BashToolRunsCommand(t *testing.T) { + t.Parallel() + dir := t.TempDir() + ts, err := NewToolSet(WithBaseDir(dir)) + require.NoError(t, err) + t.Cleanup(func() { + require.NoError(t, ts.Close()) + }) + bashTool := mustCallableTool(t, ts.Tools(context.Background()), toolBash) + out := callToolAs[bashOutput](t, bashTool, bashInput{ + Command: "printf 'hello'", + }) + require.Equal(t, 0, out.ExitCode) + require.Equal(t, "hello", out.Stdout) + require.Equal(t, "hello", out.Output) + require.False(t, out.TimedOut) +} + +func TestToolSet_BashToolTimeoutsLongCommand(t *testing.T) { + t.Parallel() + 
dir := t.TempDir() + ts, err := NewToolSet(WithBaseDir(dir)) + require.NoError(t, err) + t.Cleanup(func() { + require.NoError(t, ts.Close()) + }) + bashTool := mustCallableTool(t, ts.Tools(context.Background()), toolBash) + out := callToolAs[bashOutput](t, bashTool, bashInput{ + Command: "sleep 1", + Timeout: intPtr(1), + }) + require.True(t, out.TimedOut) + require.NotEqual(t, 0, out.ExitCode) +} + +func TestToolSet_TaskStopStopsBackgroundBashTask(t *testing.T) { + t.Parallel() + dir := t.TempDir() + ts, err := NewToolSet(WithBaseDir(dir)) + require.NoError(t, err) + t.Cleanup(func() { + require.NoError(t, ts.Close()) + }) + bashTool := mustCallableTool(t, ts.Tools(context.Background()), toolBash) + stopTool := mustCallableTool(t, ts.Tools(context.Background()), toolTaskStop) + bgOut := callToolAs[bashOutput](t, bashTool, bashInput{ + Command: "sleep 30", + RunInBackground: true, + }) + require.NotEmpty(t, bgOut.BackgroundTaskID) + stopOut := callToolAs[taskStopOutput](t, stopTool, taskStopInput{ + TaskID: bgOut.BackgroundTaskID, + }) + require.Equal(t, bgOut.BackgroundTaskID, stopOut.TaskID) + require.Equal(t, toolBash, stopOut.TaskType) + require.Contains(t, stopOut.Message, "Successfully stopped task") +} + +func TestToolSet_TaskOutputReadsBackgroundBashTask(t *testing.T) { + t.Parallel() + dir := t.TempDir() + ts, err := NewToolSet(WithBaseDir(dir)) + require.NoError(t, err) + t.Cleanup(func() { + require.NoError(t, ts.Close()) + }) + bashTool := mustCallableTool(t, ts.Tools(context.Background()), toolBash) + taskOutputTool := mustCallableTool(t, ts.Tools(context.Background()), toolTaskOutput) + bgOut := callToolAs[bashOutput](t, bashTool, bashInput{ + Command: "printf 'a'; sleep 0.3; printf 'b'", + RunInBackground: true, + }) + require.NotEmpty(t, bgOut.BackgroundTaskID) + nonBlocking := callToolAs[taskOutputOutput](t, taskOutputTool, taskOutputInput{ + TaskID: bgOut.BackgroundTaskID, + Block: boolPtr(false), + }) + require.NotNil(t, nonBlocking.Task) + 
require.Contains(t, []string{"not_ready", "success"}, nonBlocking.RetrievalStatus) + blocking := callToolAs[taskOutputOutput](t, taskOutputTool, taskOutputInput{ + TaskID: bgOut.BackgroundTaskID, + Timeout: intPtr(5_000), + }) + require.Equal(t, "success", blocking.RetrievalStatus) + require.NotNil(t, blocking.Task) + require.Equal(t, toolBash, blocking.Task.TaskType) + require.Contains(t, blocking.Task.Output, "ab") + require.Equal(t, "completed", blocking.Task.Status) +} + +func TestToolSet_TaskOutputCoversPollingBranches(t *testing.T) { + t.Parallel() + runtime := newToolRuntime(t.TempDir(), maxEditableFileSize) + taskOutputTool, err := newTaskOutputTool(runtime) + require.NoError(t, err) + callable, ok := taskOutputTool.(tool.CallableTool) + require.True(t, ok) + _, err = callToolRaw(callable, taskOutputInput{}) + require.EqualError(t, err, "task_id is required") + runningLog := filepath.Join(t.TempDir(), "running.log") + require.NoError(t, os.WriteFile(runningLog, []byte("partial"), 0o644)) + runtime.taskState.tasks["running"] = &backgroundTask{ + ID: "running", + Command: "sleep 10", + Type: toolBash, + OutputPath: runningLog, + Status: "running", + } + nonBlocking := callToolAs[taskOutputOutput](t, callable, taskOutputInput{ + TaskID: "running", + Block: boolPtr(false), + }) + require.Equal(t, "not_ready", nonBlocking.RetrievalStatus) + require.NotNil(t, nonBlocking.Task) + require.Equal(t, "partial", nonBlocking.Task.Output) + blockingTimeout := callToolAs[taskOutputOutput](t, callable, taskOutputInput{ + TaskID: "running", + Timeout: intPtr(0), + }) + require.Equal(t, "timeout", blockingTimeout.RetrievalStatus) + finishedLog := filepath.Join(t.TempDir(), "done.log") + require.NoError(t, os.WriteFile(finishedLog, []byte("done"), 0o644)) + exitCode := 0 + runtime.taskState.tasks["done"] = &backgroundTask{ + ID: "done", + Command: "echo done", + Type: toolBash, + OutputPath: finishedLog, + Status: "completed", + ExitCode: &exitCode, + } + completed := 
callToolAs[taskOutputOutput](t, callable, taskOutputInput{ + TaskID: "done", + Block: boolPtr(false), + }) + require.Equal(t, "success", completed.RetrievalStatus) + require.NotNil(t, completed.Task) + require.Equal(t, "done", completed.Task.Output) + ctx, cancel := context.WithCancel(context.Background()) + cancel() + _, err = callToolRawWithContext(callable, ctx, taskOutputInput{ + TaskID: "running", + Timeout: intPtr(1000), + }) + require.ErrorIs(t, err, context.Canceled) + _, err = snapshotBackgroundTask(runtime, "missing") + require.EqualError(t, err, "no task found with ID: missing") +} + +func TestToolSet_TaskOutputClampsNegativeTimeoutToZero(t *testing.T) { + t.Parallel() + runtime := newToolRuntime(t.TempDir(), maxEditableFileSize) + taskOutputTool, err := newTaskOutputTool(runtime) + require.NoError(t, err) + callable, ok := taskOutputTool.(tool.CallableTool) + require.True(t, ok) + runtime.taskState.tasks["running"] = &backgroundTask{ + ID: "running", + Command: "sleep 10", + Type: toolBash, + OutputPath: filepath.Join(t.TempDir(), "running.log"), + Status: "running", + } + out := callToolAs[taskOutputOutput](t, callable, taskOutputInput{ + TaskID: "running", + Timeout: intPtr(-1), + }) + require.Equal(t, "timeout", out.RetrievalStatus) + require.NotNil(t, out.Task) + require.Equal(t, "running", out.Task.Status) +} + +func TestReadTaskSnapshotHandlesMissingOutputAndCopiesExitCode(t *testing.T) { + t.Parallel() + runtime := newToolRuntime(t.TempDir(), maxEditableFileSize) + exitCode := 7 + runtime.taskState.tasks["done"] = &backgroundTask{ + ID: "done", + Command: "echo done", + Type: toolBash, + OutputPath: filepath.Join(t.TempDir(), "missing.log"), + Status: "completed", + ExitCode: &exitCode, + } + snapshot, err := snapshotBackgroundTask(runtime, "done") + require.NoError(t, err) + require.NotNil(t, snapshot.ExitCode) + require.Equal(t, 7, *snapshot.ExitCode) + exitCode = 9 + require.Equal(t, 7, *snapshot.ExitCode) + out, err := 
readTaskSnapshot(runtime, "done") + require.NoError(t, err) + require.Equal(t, "", out.Output) + require.NotNil(t, out.ExitCode) + require.Equal(t, 9, *out.ExitCode) +} + +func TestToolSet_ReadWriteEditFlow(t *testing.T) { + t.Parallel() + dir := t.TempDir() + ts, err := NewToolSet(WithBaseDir(dir)) + require.NoError(t, err) + t.Cleanup(func() { + require.NoError(t, ts.Close()) + }) + filePath := filepath.Join(dir, "notes.txt") + writeTool := mustCallableTool(t, ts.Tools(context.Background()), toolWrite) + readTool := mustCallableTool(t, ts.Tools(context.Background()), toolRead) + editTool := mustCallableTool(t, ts.Tools(context.Background()), toolEdit) + writeOut := callToolAs[writeOutput](t, writeTool, writeInput{ + FilePath: filePath, + Content: "hello\nworld\n", + }) + require.Equal(t, "create", writeOut.Type) + require.Equal(t, filePath, writeOut.FilePath) + require.Nil(t, writeOut.OriginalFile) + require.NotEmpty(t, writeOut.StructuredPatch) + readOut := callToolAs[readOutput](t, readTool, readInput{ + FilePath: filePath, + }) + require.Equal(t, "text", readOut.Type) + require.NotNil(t, readOut.File) + require.Equal(t, "hello\nworld\n", readOut.File.Content) + require.Equal(t, 2, readOut.File.TotalLines) + readDedup := callToolAs[readOutput](t, readTool, readInput{ + FilePath: filePath, + }) + require.Equal(t, "file_unchanged", readDedup.Type) + editOut := callToolAs[editOutput](t, editTool, editInput{ + FilePath: filePath, + OldString: "world", + NewString: "claude", + }) + require.Equal(t, filePath, editOut.FilePath) + require.Equal(t, "world", editOut.OldString) + require.Equal(t, "claude", editOut.NewString) + require.Equal(t, "hello\nworld\n", editOut.OriginalFile) + require.NotEmpty(t, editOut.StructuredPatch) + updated := callToolAs[readOutput](t, readTool, readInput{ + FilePath: filePath, + }) + require.Equal(t, "text", updated.Type) + require.Equal(t, "hello\nclaude\n", updated.File.Content) +} + +func 
TestToolSet_WriteRequiresFullReadAndRejectsStaleFile(t *testing.T) { + t.Parallel() + dir := t.TempDir() + filePath := filepath.Join(dir, "notes.txt") + require.NoError(t, os.WriteFile(filePath, []byte("hello\n"), 0o644)) + ts, err := NewToolSet(WithBaseDir(dir)) + require.NoError(t, err) + t.Cleanup(func() { + require.NoError(t, ts.Close()) + }) + writeTool := mustCallableTool(t, ts.Tools(context.Background()), toolWrite) + readTool := mustCallableTool(t, ts.Tools(context.Background()), toolRead) + _, err = callToolRaw(writeTool, writeInput{ + FilePath: filePath, + Content: "rewrite\n", + }) + require.ErrorContains(t, err, "File has not been read yet") + _ = callToolAs[readOutput](t, readTool, readInput{ + FilePath: filePath, + Limit: intPtr(1), + }) + _, err = callToolRaw(writeTool, writeInput{ + FilePath: filePath, + Content: "rewrite\n", + }) + require.ErrorContains(t, err, "File has not been read yet") + _ = callToolAs[readOutput](t, readTool, readInput{ + FilePath: filePath, + }) + future := time.Now().Add(2 * time.Second) + require.NoError(t, os.WriteFile(filePath, []byte("hello\n"), 0o644)) + require.NoError(t, os.Chtimes(filePath, future, future)) + out := callToolAs[writeOutput](t, writeTool, writeInput{ + FilePath: filePath, + Content: "rewrite\n", + }) + require.Equal(t, "update", out.Type) + require.Equal(t, "hello\n", derefString(out.OriginalFile)) + _ = callToolAs[readOutput](t, readTool, readInput{ + FilePath: filePath, + }) + require.NoError(t, os.WriteFile(filePath, []byte("user update\n"), 0o644)) + require.NoError(t, os.Chtimes(filePath, future.Add(2*time.Second), future.Add(2*time.Second))) + _, err = callToolRaw(writeTool, writeInput{ + FilePath: filePath, + Content: "rewrite again\n", + }) + require.ErrorContains(t, err, "File has been modified since read") +} + +func TestToolSet_WriteRejectsPathsOutsideBaseDirAndDirectories(t *testing.T) { + t.Parallel() + dir := t.TempDir() + ts, err := NewToolSet(WithBaseDir(dir)) + require.NoError(t, err) 
+ t.Cleanup(func() { + require.NoError(t, ts.Close()) + }) + writeTool := mustCallableTool(t, ts.Tools(context.Background()), toolWrite) + _, err = callToolRaw(writeTool, writeInput{ + FilePath: "../outside.txt", + Content: "blocked", + }) + require.ErrorContains(t, err, "path is outside base_dir") + require.NoError(t, os.Mkdir(filepath.Join(dir, "nested"), 0o755)) + _, err = callToolRaw(writeTool, writeInput{ + FilePath: filepath.Join(dir, "nested"), + Content: "blocked", + }) + require.ErrorContains(t, err, "is a directory") +} + +func TestToolSet_EditRejectsNotebookAndPreservesCurlyQuotes(t *testing.T) { + t.Parallel() + dir := t.TempDir() + ts, err := NewToolSet(WithBaseDir(dir)) + require.NoError(t, err) + t.Cleanup(func() { + require.NoError(t, ts.Close()) + }) + editTool := mustCallableTool(t, ts.Tools(context.Background()), toolEdit) + readTool := mustCallableTool(t, ts.Tools(context.Background()), toolRead) + notebookPath := filepath.Join(dir, "book.ipynb") + require.NoError(t, os.WriteFile(notebookPath, []byte(`{"cells":[]}`), 0o644)) + _, err = callToolRaw(editTool, editInput{ + FilePath: notebookPath, + OldString: "{}", + NewString: "[]", + }) + require.ErrorContains(t, err, "NotebookEdit") + filePath := filepath.Join(dir, "quotes.txt") + require.NoError(t, os.WriteFile(filePath, []byte("“hello”\n"), 0o644)) + _ = callToolAs[readOutput](t, readTool, readInput{ + FilePath: filePath, + }) + out := callToolAs[editOutput](t, editTool, editInput{ + FilePath: filePath, + OldString: "\"hello\"", + NewString: "\"world\"", + }) + require.Equal(t, "\"hello\"", out.OldString) + updated, err := os.ReadFile(filePath) + require.NoError(t, err) + require.Equal(t, "“world”\n", string(updated)) +} + +func TestToolSet_ReadSupportsNotebookImageAndDedup(t *testing.T) { + t.Parallel() + dir := t.TempDir() + ts, err := NewToolSet(WithBaseDir(dir)) + require.NoError(t, err) + t.Cleanup(func() { + require.NoError(t, ts.Close()) + }) + readTool := mustCallableTool(t, 
ts.Tools(context.Background()), toolRead) + notebookPath := filepath.Join(dir, "book.ipynb") + require.NoError(t, os.WriteFile(notebookPath, []byte(`{"cells":[{"cell_type":"markdown","source":["hello"]}]}`), 0o644)) + notebookOut := callToolAs[readOutput](t, readTool, readInput{ + FilePath: notebookPath, + }) + require.Equal(t, "notebook", notebookOut.Type) + require.Len(t, notebookOut.File.Cells, 1) + imagePath := filepath.Join(dir, "tiny.png") + require.NoError(t, os.WriteFile(imagePath, tinyPNGBytes, 0o644)) + imageOut := callToolAs[readOutput](t, readTool, readInput{ + FilePath: imagePath, + }) + require.Equal(t, "image", imageOut.Type) + require.NotEmpty(t, imageOut.File.Base64) + require.Contains(t, imageOut.File.MediaType, "image/png") + imageDedup := callToolAs[readOutput](t, readTool, readInput{ + FilePath: imagePath, + }) + require.Equal(t, "file_unchanged", imageDedup.Type) +} + +func TestToolSet_ReadSupportsPDFAndPageRanges(t *testing.T) { + t.Parallel() + pdftoppmTestMu.Lock() + t.Cleanup(func() { + pdftoppmTestMu.Unlock() + }) + dir := t.TempDir() + pdfPath := filepath.Join(dir, "paper.pdf") + require.NoError(t, os.WriteFile(pdfPath, newTestPDF(t, []string{"Page 1", "Page 2"}), 0o644)) + ts, err := NewToolSet(WithBaseDir(dir)) + require.NoError(t, err) + t.Cleanup(func() { + require.NoError(t, ts.Close()) + }) + readTool := mustCallableTool(t, ts.Tools(context.Background()), toolRead) + fullOut := callToolAs[readOutput](t, readTool, readInput{FilePath: pdfPath}) + require.Equal(t, "pdf", fullOut.Type) + require.NotEmpty(t, fullOut.File.Base64) + if _, err := exec.LookPath("pdftoppm"); err != nil { + _, err = callToolRaw(readTool, readInput{ + FilePath: pdfPath, + Pages: "1", + }) + require.ErrorContains(t, err, "pdftoppm is not installed") + return + } + pageOut := callToolAs[readOutput](t, readTool, readInput{ + FilePath: pdfPath, + Pages: "1", + }) + require.Equal(t, "parts", pageOut.Type) + require.Equal(t, 1, pageOut.File.Count) + 
require.NotEmpty(t, pageOut.File.OutputDir) + rendered, err := filepath.Glob(filepath.Join(pageOut.File.OutputDir, "*.jpg")) + require.NoError(t, err) + require.Len(t, rendered, 1) +} + +func TestToolSet_ReadRejectsLargePDFWithoutPages(t *testing.T) { + t.Parallel() + dir := t.TempDir() + pdfPath := filepath.Join(dir, "large.pdf") + pages := make([]string, 0, pdfMaxPagesPerRead+5) + for idx := 0; idx < pdfMaxPagesPerRead+5; idx++ { + pages = append(pages, "Page "+strconvString(idx+1)) + } + require.NoError(t, os.WriteFile(pdfPath, newTestPDF(t, pages), 0o644)) + ts, err := NewToolSet(WithBaseDir(dir)) + require.NoError(t, err) + t.Cleanup(func() { + require.NoError(t, ts.Close()) + }) + readTool := mustCallableTool(t, ts.Tools(context.Background()), toolRead) + _, err = callToolRaw(readTool, readInput{FilePath: pdfPath}) + require.ErrorContains(t, err, "too many to read at once") + _, err = callToolRaw(readTool, readInput{ + FilePath: pdfPath, + Pages: "1-21", + }) + require.ErrorContains(t, err, "exceeds maximum") +} + +func TestToolSet_ReadCoversOffsetAndErrorBranches(t *testing.T) { + t.Parallel() + dir := t.TempDir() + ts, err := NewToolSet(WithBaseDir(dir)) + require.NoError(t, err) + t.Cleanup(func() { + require.NoError(t, ts.Close()) + }) + readTool := mustCallableTool(t, ts.Tools(context.Background()), toolRead) + _, err = callToolRaw(readTool, readInput{FilePath: filepath.Join(dir, "missing.txt")}) + require.EqualError(t, err, fmt.Sprintf("File does not exist: %s", filepath.Join(dir, "missing.txt"))) + binaryPath := filepath.Join(dir, "data.bin") + require.NoError(t, os.WriteFile(binaryPath, []byte{0x00, 0x01, 0x02}, 0o644)) + _, err = callToolRaw(readTool, readInput{FilePath: binaryPath}) + require.EqualError(t, err, "This tool cannot read binary files.") + textPath := filepath.Join(dir, "notes.txt") + require.NoError(t, os.WriteFile(textPath, []byte("alpha\nbeta\ngamma\n"), 0o644)) + out := callToolAs[readOutput](t, readTool, readInput{ + FilePath: 
textPath, + Offset: intPtr(2), + Limit: intPtr(1), + }) + require.Equal(t, "text", out.Type) + require.NotNil(t, out.File) + require.Equal(t, 2, out.File.StartLine) + require.Equal(t, 3, out.File.TotalLines) + require.Equal(t, 1, out.File.NumLines) + require.Equal(t, "beta", out.File.Content) +} + +func TestToolSet_NotebookEditFlow(t *testing.T) { + t.Parallel() + dir := t.TempDir() + notebookPath := filepath.Join(dir, "book.ipynb") + require.NoError(t, os.WriteFile(notebookPath, []byte(`{"cells":[{"id":"intro","cell_type":"markdown","source":"hello","metadata":{}}],"metadata":{"language_info":{"name":"python"}},"nbformat":4,"nbformat_minor":5}`), 0o644)) + ts, err := NewToolSet(WithBaseDir(dir)) + require.NoError(t, err) + t.Cleanup(func() { + require.NoError(t, ts.Close()) + }) + readTool := mustCallableTool(t, ts.Tools(context.Background()), toolRead) + notebookEditTool := mustCallableTool(t, ts.Tools(context.Background()), toolNotebookEdit) + _ = callToolAs[readOutput](t, readTool, readInput{ + FilePath: notebookPath, + }) + replaceOut := callToolAs[notebookEditOutput](t, notebookEditTool, notebookEditInput{ + NotebookPath: notebookPath, + CellID: "intro", + NewSource: "updated", + EditMode: "replace", + }) + require.Equal(t, "replace", replaceOut.EditMode) + require.Equal(t, "intro", replaceOut.CellID) + require.Equal(t, "markdown", replaceOut.CellType) + require.Equal(t, "python", replaceOut.Language) + insertOut := callToolAs[notebookEditOutput](t, notebookEditTool, notebookEditInput{ + NotebookPath: notebookPath, + CellID: "intro", + NewSource: "print(1)", + CellType: "code", + EditMode: "insert", + }) + require.Equal(t, "insert", insertOut.EditMode) + require.NotEmpty(t, insertOut.CellID) + deleteOut := callToolAs[notebookEditOutput](t, notebookEditTool, notebookEditInput{ + NotebookPath: notebookPath, + CellID: insertOut.CellID, + NewSource: "", + EditMode: "delete", + }) + require.Equal(t, "delete", deleteOut.EditMode) + rawNotebook, err := 
os.ReadFile(notebookPath) + require.NoError(t, err) + var decoded struct { + Cells []map[string]any `json:"cells"` + } + require.NoError(t, json.Unmarshal(rawNotebook, &decoded)) + require.Len(t, decoded.Cells, 1) + require.Equal(t, "updated", decoded.Cells[0]["source"]) +} + +func TestToolSet_NotebookEditRejectsUnreadAndStaleNotebook(t *testing.T) { + t.Parallel() + dir := t.TempDir() + notebookPath := filepath.Join(dir, "book.ipynb") + require.NoError(t, os.WriteFile(notebookPath, []byte(`{"cells":[{"id":"intro","cell_type":"markdown","source":"hello","metadata":{}}],"metadata":{"language_info":{"name":"python"}},"nbformat":4,"nbformat_minor":5}`), 0o644)) + ts, err := NewToolSet(WithBaseDir(dir)) + require.NoError(t, err) + t.Cleanup(func() { + require.NoError(t, ts.Close()) + }) + readTool := mustCallableTool(t, ts.Tools(context.Background()), toolRead) + notebookEditTool := mustCallableTool(t, ts.Tools(context.Background()), toolNotebookEdit) + _, err = callToolRaw(notebookEditTool, notebookEditInput{ + NotebookPath: notebookPath, + CellID: "intro", + NewSource: "updated", + EditMode: "replace", + }) + require.ErrorContains(t, err, "File has not been read yet") + _ = callToolAs[readOutput](t, readTool, readInput{ + FilePath: notebookPath, + }) + future := time.Now().Add(2 * time.Second) + require.NoError(t, os.WriteFile(notebookPath, []byte(`{"cells":[{"id":"intro","cell_type":"markdown","source":"external","metadata":{}}],"metadata":{"language_info":{"name":"python"}},"nbformat":4,"nbformat_minor":5}`), 0o644)) + require.NoError(t, os.Chtimes(notebookPath, future, future)) + _, err = callToolRaw(notebookEditTool, notebookEditInput{ + NotebookPath: notebookPath, + CellID: "intro", + NewSource: "updated", + EditMode: "replace", + }) + require.ErrorContains(t, err, "File has been modified since read") +} + +func TestToolSet_GlobStandalone(t *testing.T) { + t.Parallel() + dir := t.TempDir() + for idx := 0; idx < 120; idx++ { + require.NoError(t, 
os.WriteFile(filepath.Join(dir, "f"+strconvString(idx)+".txt"), []byte("x"), 0o644)) + } + ts, err := NewToolSet(WithBaseDir(dir)) + require.NoError(t, err) + t.Cleanup(func() { + require.NoError(t, ts.Close()) + }) + globTool := mustCallableTool(t, ts.Tools(context.Background()), toolGlob) + out := callToolAs[globOutput](t, globTool, globInput{ + Pattern: "*.txt", + }) + require.Equal(t, defaultGlobHeadLimit, out.NumFiles) + require.True(t, out.Truncated) + require.NotZero(t, out.DurationMs) +} + +func TestToolSet_GlobPathValidationErrors(t *testing.T) { + t.Parallel() + dir := t.TempDir() + require.NoError(t, os.WriteFile(filepath.Join(dir, "note.txt"), []byte("hello"), 0o644)) + ts, err := NewToolSet(WithBaseDir(dir)) + require.NoError(t, err) + t.Cleanup(func() { + require.NoError(t, ts.Close()) + }) + globTool := mustCallableTool(t, ts.Tools(context.Background()), toolGlob) + _, err = callToolRaw(globTool, globInput{ + Pattern: "*.txt", + Path: "missing", + }) + require.ErrorContains(t, err, "Directory does not exist: missing") + _, err = callToolRaw(globTool, globInput{ + Pattern: "*.txt", + Path: "note.txt", + }) + require.ErrorContains(t, err, "Path is not a directory: note.txt") +} + +func TestToolSet_GrepFallbackAndRipgrepModes(t *testing.T) { + dir := t.TempDir() + require.NoError(t, os.WriteFile(filepath.Join(dir, "a.txt"), []byte("alpha\nhello\nbeta\n"), 0o644)) + require.NoError(t, os.WriteFile(filepath.Join(dir, "b.txt"), []byte("hello again\n"), 0o644)) + require.NoError(t, os.WriteFile(filepath.Join(dir, "main.go"), []byte("package main\n\nfunc main() {\n\tprintln(\"alpha\")\n\tprintln(\"beta\")\n}\n"), 0o644)) + restore := withRipgrepForTest(func(string) (string, error) { + return "", errors.New("not found") + }) + defer restore() + ts, err := NewToolSet(WithBaseDir(dir)) + require.NoError(t, err) + t.Cleanup(func() { + require.NoError(t, ts.Close()) + }) + grepTool := mustCallableTool(t, ts.Tools(context.Background()), toolGrep) + filesOut := 
callToolAs[grepOutput](t, grepTool, grepInput{ + Pattern: "hello", + Glob: "*.txt", + OutputMode: "files_with_matches", + }) + require.ElementsMatch(t, []string{"a.txt", "b.txt"}, filesOut.Filenames) + countOut := callToolAs[grepOutput](t, grepTool, grepInput{ + Pattern: "hello", + Glob: "*.txt", + OutputMode: "count", + }) + require.Equal(t, 2, countOut.NumMatches) + contentOut := callToolAs[grepOutput](t, grepTool, grepInput{ + Pattern: "hello", + Glob: "*.txt", + OutputMode: "content", + Context: intPtr(1), + }) + require.Contains(t, contentOut.Content, "a.txt:1:alpha") + require.Contains(t, contentOut.Content, "a.txt:2:hello") + require.Contains(t, contentOut.Content, "a.txt:3:beta") + multilineOut := callToolAs[grepOutput](t, grepTool, grepInput{ + Pattern: "alpha.*beta", + OutputMode: "content", + Type: "go", + Multiline: true, + ContextAlt: intPtr(1), + ShowLineNum: boolPtr(true), + }) + require.Contains(t, multilineOut.Content, "main.go:4:\tprintln(\"alpha\")") + require.Contains(t, multilineOut.Content, "main.go:5:\tprintln(\"beta\")") +} + +func TestToolSet_GrepRipgrepAdvancedOptions(t *testing.T) { + if _, err := exec.LookPath("rg"); err != nil { + t.Skip("ripgrep is not available") + } + restore := withRipgrepForTest(exec.LookPath) + defer restore() + dir := t.TempDir() + require.NoError(t, os.WriteFile(filepath.Join(dir, "main.go"), []byte("package main\n\nfunc main() {\n\tprintln(\"alpha\")\n\tprintln(\"beta\")\n}\n"), 0o644)) + require.NoError(t, os.WriteFile(filepath.Join(dir, "notes.txt"), []byte("alpha\nbeta\n"), 0o644)) + ts, err := NewToolSet(WithBaseDir(dir)) + require.NoError(t, err) + t.Cleanup(func() { + require.NoError(t, ts.Close()) + }) + grepTool := mustCallableTool(t, ts.Tools(context.Background()), toolGrep) + out := callToolAs[grepOutput](t, grepTool, grepInput{ + Pattern: "alpha.*beta", + OutputMode: "content", + Type: "go", + Multiline: true, + ContextAlt: intPtr(1), + ShowLineNum: boolPtr(true), + }) + require.Contains(t, 
out.Content, "main.go:4:\tprintln(\"alpha\")")
+	require.Contains(t, out.Content, "main.go:5:\tprintln(\"beta\")")
+}
+
+func TestToolSet_WebFetchTool(t *testing.T) {
+	t.Parallel()
+	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
+		w.Header().Set("Content-Type", "text/html; charset=utf-8")
+		_, _ = w.Write([]byte("<html><body><h1>Hello</h1><p>world</p></body></html>"))
+	}))
+	defer server.Close()
+	ts, err := NewToolSet()
+	require.NoError(t, err)
+	t.Cleanup(func() {
+		require.NoError(t, ts.Close())
+	})
+	fetchTool := mustCallableTool(t, ts.Tools(context.Background()), toolWebFetch)
+	out := callToolAs[webFetchOutput](t, fetchTool, webFetchInput{
+		URL:    server.URL,
+		Prompt: "Summarize the page.",
+	})
+	require.Equal(t, 200, out.Code)
+	require.Contains(t, out.Result, "Hello")
+	require.Contains(t, out.Result, "world")
+	require.NotZero(t, out.DurationMs)
+}
+
+func TestToolSet_WebFetchReturnsExtractedContentByDefault(t *testing.T) {
+	t.Parallel()
+	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
+		w.Header().Set("Content-Type", "text/html; charset=utf-8")
+		_, _ = w.Write([]byte("

<html><body><p>Alpha release date is April 1.</p><p>Beta secret is 42.</p><p>Gamma is unrelated.</p></body></html>"))
+	}))
+	defer server.Close()
+	ts, err := NewToolSet()
+	require.NoError(t, err)
+	t.Cleanup(func() {
+		require.NoError(t, ts.Close())
+	})
+	fetchTool := mustCallableTool(t, ts.Tools(context.Background()), toolWebFetch)
+	out := callToolAs[webFetchOutput](t, fetchTool, webFetchInput{
+		URL:    server.URL,
+		Prompt: "What is the beta secret?",
+	})
+	require.Contains(t, out.Result, "Alpha release date is April 1.")
+	require.Contains(t, out.Result, "Beta secret is 42.")
+	require.Contains(t, out.Result, "Gamma is unrelated.")
+}
+
+func TestToolSet_WebFetchDetectsCrossHostRedirect(t *testing.T) {
+	t.Parallel()
+	redirectTarget := "http://localhost/target"
+	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		if r.URL.Path == "/start" {
+			http.Redirect(w, r, redirectTarget, http.StatusFound)
+			return
+		}
+		w.WriteHeader(http.StatusOK)
+	}))
+	defer server.Close()
+	ts, err := NewToolSet()
+	require.NoError(t, err)
+	t.Cleanup(func() {
+		require.NoError(t, ts.Close())
+	})
+	fetchTool := mustCallableTool(t, ts.Tools(context.Background()), toolWebFetch)
+	out := callToolAs[webFetchOutput](t, fetchTool, webFetchInput{
+		URL:    server.URL + "/start",
+		Prompt: "Summarize the page.",
+	})
+	require.Contains(t, out.Result, "REDIRECT DETECTED")
+}
+
+func TestToolSet_WebFetchUsesPromptProcessorWhenConfigured(t *testing.T) {
+	t.Parallel()
+	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
+		w.Header().Set("Content-Type", "text/html; charset=utf-8")
+		_, _ = w.Write([]byte(`
+			<html><body>
+			<a href="/docs">Documentation</a>
+			<a href="/blog">Blog</a>
+			</body></html>
+		`))
+	}))
+	defer server.Close()
+	ts, err := NewToolSet(WithWebFetchOptions(WebFetchOptions{
+		AllowAll: true,
+		PromptProcessor: func(_ context.Context, in WebFetchPromptInput) (string, error) {
+			return "prompt=" + in.Prompt + "\ncontent=" + in.Content, nil
+		},
+	}))
+	require.NoError(t, err)
+	t.Cleanup(func() {
+		require.NoError(t, ts.Close())
+	})
+	fetchTool := mustCallableTool(t,
ts.Tools(context.Background()), toolWebFetch) + out := callToolAs[webFetchOutput](t, fetchTool, webFetchInput{ + URL: server.URL, + Prompt: "List the links on the page.", + }) + require.Contains(t, out.Result, "prompt=List the links on the page.") + require.Contains(t, out.Result, "Documentation") + require.Contains(t, out.Result, "Blog") +} + +func TestToolSet_WebFetchRejectsMissingPromptAndBlockedDomain(t *testing.T) { + t.Parallel() + ts, err := NewToolSet(WithWebFetchOptions(WebFetchOptions{ + BlockedDomains: []string{"example.com"}, + })) + require.NoError(t, err) + t.Cleanup(func() { + require.NoError(t, ts.Close()) + }) + fetchTool := mustCallableTool(t, ts.Tools(context.Background()), toolWebFetch) + _, err = callToolRaw(fetchTool, webFetchInput{ + URL: "https://allowed.example.com/page", + }) + require.EqualError(t, err, "prompt is required") + _, err = callToolRaw(fetchTool, webFetchInput{ + URL: "https://example.com/page", + Prompt: "Summarize the page.", + }) + require.EqualError(t, err, "url is blocked by domain policy: https://example.com/page") +} + +func TestToolSet_WebFetchPropagatesFetchAndPromptProcessorErrors(t *testing.T) { + t.Parallel() + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "text/plain; charset=utf-8") + _, _ = w.Write([]byte("alpha")) + })) + defer server.Close() + ts, err := NewToolSet(WithWebFetchOptions(WebFetchOptions{ + Timeout: time.Second, + PromptProcessor: func(context.Context, WebFetchPromptInput) (string, error) { + return "", fs.ErrInvalid + }, + })) + require.NoError(t, err) + t.Cleanup(func() { + require.NoError(t, ts.Close()) + }) + fetchTool := mustCallableTool(t, ts.Tools(context.Background()), toolWebFetch) + _, err = callToolRaw(fetchTool, webFetchInput{ + URL: "://bad", + Prompt: "Summarize the page.", + }) + require.Error(t, err) + _, err = callToolRaw(fetchTool, webFetchInput{ + URL: server.URL, + Prompt: "Summarize the page.", + 
}) + require.ErrorIs(t, err, fs.ErrInvalid) +} + +func TestFetchURLHandlesRedirectAndBodyErrorBranches(t *testing.T) { + t.Parallel() + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/start": + http.Redirect(w, r, "/final", http.StatusFound) + case "/final": + w.Header().Set("Content-Type", "text/plain; charset=utf-8") + _, _ = w.Write([]byte("redirected content")) + case "/missing-location": + w.WriteHeader(http.StatusFound) + default: + http.NotFound(w, r) + } + })) + defer server.Close() + client := &http.Client{ + Timeout: defaultHTTPTimeout, + CheckRedirect: func(_ *http.Request, _ []*http.Request) error { + return http.ErrUseLastResponse + }, + } + finalURL, statusCode, statusText, body, contentType, err := fetchURL(context.Background(), client, server.URL+"/start", WebFetchOptions{}) + require.NoError(t, err) + require.Equal(t, server.URL+"/final", finalURL) + require.Equal(t, http.StatusOK, statusCode) + require.Equal(t, "200 OK", statusText) + require.Equal(t, []byte("redirected content"), body) + require.Equal(t, "text/plain; charset=utf-8", contentType) + finalURL, statusCode, statusText, body, contentType, err = fetchURL(context.Background(), client, server.URL+"/missing-location", WebFetchOptions{}) + require.NoError(t, err) + require.Equal(t, server.URL+"/missing-location", finalURL) + require.Equal(t, http.StatusFound, statusCode) + require.Equal(t, "302 Found", statusText) + require.Nil(t, body) + require.Empty(t, contentType) + errorClient := &http.Client{ + Transport: roundTripFunc(func(req *http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: http.StatusOK, + Status: "200 OK", + Body: io.NopCloser(errReader{err: fs.ErrInvalid}), + Header: make(http.Header), + Request: req, + }, nil + }), + } + _, _, _, _, _, err = fetchURL(context.Background(), errorClient, server.URL+"/body-error", WebFetchOptions{}) + require.ErrorIs(t, err, fs.ErrInvalid) +} + 
+func TestFetchURLPropagatesRequestAndRedirectParseErrors(t *testing.T) { + t.Parallel() + requestFailedClient := &http.Client{ + Transport: roundTripFunc(func(*http.Request) (*http.Response, error) { + return nil, fs.ErrInvalid + }), + } + _, _, _, _, _, err := fetchURL(context.Background(), requestFailedClient, "https://example.com/start", WebFetchOptions{}) + require.ErrorIs(t, err, fs.ErrInvalid) + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Location", "://bad") + w.WriteHeader(http.StatusFound) + })) + defer server.Close() + client := &http.Client{ + Timeout: defaultHTTPTimeout, + CheckRedirect: func(_ *http.Request, _ []*http.Request) error { + return http.ErrUseLastResponse + }, + } + _, _, _, _, _, err = fetchURL(context.Background(), client, server.URL, WebFetchOptions{}) + require.Error(t, err) +} + +func TestFetchURLReturnsTooManyRedirects(t *testing.T) { + t.Parallel() + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + step, err := strconv.Atoi(strings.TrimPrefix(r.URL.Path, "/")) + require.NoError(t, err) + http.Redirect(w, r, fmt.Sprintf("/%d", step+1), http.StatusFound) + })) + defer server.Close() + client := &http.Client{ + Timeout: defaultHTTPTimeout, + CheckRedirect: func(_ *http.Request, _ []*http.Request) error { + return http.ErrUseLastResponse + }, + } + _, _, _, _, _, err := fetchURL(context.Background(), client, server.URL+"/0", WebFetchOptions{}) + require.EqualError(t, err, "too many redirects") +} + +func TestProcessFetchedContentAndTrimFetchResultCoverRemainingBranches(t *testing.T) { + t.Parallel() + processed, err := processFetchedContent(context.Background(), WebFetchOptions{}, webFetchInput{ + URL: "https://example.com/page", + Prompt: "Summarize the page.", + }, " alpha beta gamma ", "text/plain") + require.NoError(t, err) + require.Equal(t, "alpha beta gamma", processed) + _, err = 
processFetchedContent(context.Background(), WebFetchOptions{
+		PromptProcessor: func(context.Context, WebFetchPromptInput) (string, error) {
+			return "", fs.ErrInvalid
+		},
+	}, webFetchInput{
+		URL:    "https://example.com/page",
+		Prompt: "Summarize the page.",
+	}, "content", "text/plain")
+	require.ErrorIs(t, err, fs.ErrInvalid)
+	require.Equal(t, "abc\n\n[Content truncated due to length.]", trimFetchResult(" abcdef ", 3))
+}
+
+func TestToolSet_WebSearchDuckDuckGoLikeHTML(t *testing.T) {
+	t.Parallel()
+	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
+		w.Header().Set("Content-Type", "text/html; charset=utf-8")
+		_, _ = w.Write([]byte(`
+<html><body>
+<div class="result">
+<a class="result__a" href="https://golang.org/doc/">The Go Programming Language</a>
+<a class="result__snippet" href="https://golang.org/doc/">Go documentation.</a>
+</div>
+<div class="result">
+<a class="result__a" href="https://example.com/">Example</a>
+<a class="result__snippet" href="https://example.com/">Example snippet.</a>
+</div>
+</body></html>
+`))
+	}))
+	defer server.Close()
+	ts, err := NewToolSet(WithWebSearchOptions(WebSearchOptions{
+		Provider: "duckduckgo",
+		BaseURL:  server.URL,
+	}))
+	require.NoError(t, err)
+	t.Cleanup(func() {
+		require.NoError(t, ts.Close())
+	})
+	searchTool := mustCallableTool(t, ts.Tools(context.Background()), toolWebSearch)
+	out := callToolAs[webSearchOutput](t, searchTool, webSearchInput{
+		Query:          "golang",
+		AllowedDomains: []string{"golang.org"},
+	})
+	require.Equal(t, "golang", out.Query)
+	require.Len(t, out.Results, 1)
+	require.Len(t, out.Results[0].Content, 1)
+	require.Equal(t, "https://golang.org/doc/", out.Results[0].Content[0].URL)
+	require.NotZero(t, out.DurationSeconds)
+}
+
+func TestToolSet_WebSearchDuckDuckGoNormalizesWrappedLinksAndDeduplicates(t *testing.T) {
+	t.Parallel()
+	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
+		w.Header().Set("Content-Type", "text/html; charset=utf-8")
+		_, _ = w.Write([]byte(`
+<html><body>
+<div class="result">
+<a class="result__a" href="//duckduckgo.com/l/?uddg=https%3A%2F%2Fgolang.org%2Fdoc%2F">The Go Programming Language</a>
+<a class="result__snippet" href="//duckduckgo.com/l/?uddg=https%3A%2F%2Fgolang.org%2Fdoc%2F">Go documentation.</a>
+</div>
+<div class="result">
+<a class="result__a" href="https://golang.org/doc/">Go Docs Duplicate</a>
+<a class="result__snippet" href="https://golang.org/doc/">Duplicate hit.</a>
+</div>
+</body></html>
+`))
+	}))
+	defer server.Close()
+	ts, err := NewToolSet(WithWebSearchOptions(WebSearchOptions{
+		Provider: "duckduckgo",
+		BaseURL:  server.URL,
+	}))
+	require.NoError(t, err)
+	t.Cleanup(func() {
+		require.NoError(t, ts.Close())
+	})
+	searchTool := mustCallableTool(t, ts.Tools(context.Background()), toolWebSearch)
+	out := callToolAs[webSearchOutput](t, searchTool, webSearchInput{
+		Query:          "golang",
+		AllowedDomains: []string{"golang.org"},
+	})
+	require.Len(t, out.Results, 1)
+	require.Len(t, out.Results[0].Content, 1)
+	require.Equal(t, "https://golang.org/doc/", out.Results[0].Content[0].URL)
+}
+
+func TestToolSet_WebSearchDuckDuckGoAppliesConfiguredWindowAndReturnsNoResultBlocksWhenEmpty(t *testing.T) {
+	t.Parallel()
+	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
+		w.Header().Set("Content-Type", "text/html; charset=utf-8")
+		_, _ = w.Write([]byte(`
+<html><body>
+<div class="result">
+<a class="result__a" href="https://one.example.com/">One</a>
+<a class="result__snippet" href="https://one.example.com/">First.</a>
+</div>
+<div class="result">
+<a class="result__a" href="https://two.example.com/">Two</a>
+<a class="result__snippet" href="https://two.example.com/">Second.</a>
+</div>
+<div class="result">
+<a class="result__a" href="https://three.example.com/">Three</a>
+<a class="result__snippet" href="https://three.example.com/">Third.</a>
+</div>
+</body></html>
+`))
+	}))
+	defer server.Close()
+	ts, err := NewToolSet(WithWebSearchOptions(WebSearchOptions{
+		Provider: "duckduckgo",
+		BaseURL:  server.URL,
+		Size:     1,
+		Offset:   1,
+	}))
+	require.NoError(t, err)
+	t.Cleanup(func() {
+		require.NoError(t, ts.Close())
+	})
+	searchTool := mustCallableTool(t, ts.Tools(context.Background()), toolWebSearch)
+	out := callToolAs[webSearchOutput](t, searchTool, webSearchInput{
+		Query: "example",
+	})
+	require.Len(t, out.Results, 1)
+	require.Len(t, out.Results[0].Content, 1)
+	require.Equal(t, "https://two.example.com/", out.Results[0].Content[0].URL)
+	emptyOut := callToolAs[webSearchOutput](t, searchTool, webSearchInput{
+		Query:          "example",
+		AllowedDomains: []string{"missing.example.com"},
+	})
+	require.Empty(t, emptyOut.Results)
+}
+
+func TestParseDuckDuckGoHTMLLeavesSnippetEmptyWithoutDedicatedSnippetNode(t *testing.T) {
+	t.Parallel()
+	hits := parseDuckDuckGoHTML([]byte(`
+<html><body>
+<div class="result">
+<div class="result__body">
+<a class="result__a" href="https://example.com/doc">Example Title</a>
+<span class="result__url">example.com/doc</span>
+</div>
+</div>
+</body></html>
+`), webSearchInput{Query: "example"}, 0, 0)
+	require.Len(t, hits, 1)
+	require.Equal(t, "Example Title", hits[0].Title)
+	require.Equal(t, "https://example.com/doc", hits[0].URL)
+	require.Empty(t, hits[0].Snippet)
+}
+
+func TestGoogleSearchBackendSearchUsesConfiguredOptions(t *testing.T) {
+	t.Parallel()
+	requestCount := 0
+	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		requestCount++
+		require.Equal(t, "api-key", r.URL.Query().Get("key"))
+		require.Equal(t, "engine-id", r.URL.Query().Get("cx"))
+		require.Equal(t, "golang", r.URL.Query().Get("q"))
+		if requestCount == 1 {
+			require.Equal(t, "2", r.URL.Query().Get("num"))
+			require.Equal(t, "2", r.URL.Query().Get("start"))
+			require.Equal(t, "lang_en", r.URL.Query().Get("lr"))
+		} else {
+			require.Empty(t, r.URL.Query().Get("num"))
+			require.Empty(t, r.URL.Query().Get("start"))
+			require.Empty(t, r.URL.Query().Get("lr"))
+		}
+		w.Header().Set("Content-Type", "application/json")
+		_, _ = w.Write([]byte(`{
+			"items": [
+				{"link": "https://golang.org/doc/", "title": "Go", "snippet": "Official docs."},
+				{"link": "https://golang.org/doc/", "title": "Go duplicate", "snippet": "Duplicate."},
+				{"link": "https://example.com/", "title": "Example", "snippet": "Filtered out."}
+			]
+		}`))
+	}))
+	defer server.Close()
+	backend := &googleSearchBackend{
+		client: server.Client(),
+		options: &WebSearchOptions{
+			BaseURL:  server.URL,
+			APIKey:   "api-key",
+			EngineID: "engine-id",
+			Size:     2,
+			Offset:   1,
+			Lang:     "en",
+		},
+	}
+	hits, err := backend.search(context.Background(), webSearchInput{
+		Query:          "golang",
+		AllowedDomains: []string{"golang.org"},
+	})
+	require.NoError(t, err)
+	require.Len(t, hits, 1)
+	require.Equal(t, "https://golang.org/doc/", hits[0].URL)
+	unwindowedBackend := &googleSearchBackend{
+		client: server.Client(),
+		options: &WebSearchOptions{
+			BaseURL:  server.URL,
+			APIKey:   "api-key",
+			EngineID: "engine-id",
+		},
+	}
+	unwindowedHits, err :=
unwindowedBackend.search(context.Background(), webSearchInput{ + Query: "golang", + }) + require.NoError(t, err) + require.Len(t, unwindowedHits, 2) + require.Equal(t, "https://golang.org/doc/", unwindowedHits[0].URL) +} + +func TestGoogleSearchBackendSearchUsesEnvironmentFallback(t *testing.T) { + t.Setenv(envGoogleAPIKey, "env-api-key") + t.Setenv(envGoogleEngineID, "env-engine-id") + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + require.Equal(t, "env-api-key", r.URL.Query().Get("key")) + require.Equal(t, "env-engine-id", r.URL.Query().Get("cx")) + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"items":[{"link":"https://example.com/","title":"Example","snippet":"Snippet"}]}`)) + })) + defer server.Close() + backend := &googleSearchBackend{ + client: server.Client(), + options: &WebSearchOptions{BaseURL: server.URL}, + } + hits, err := backend.search(context.Background(), webSearchInput{Query: "example"}) + require.NoError(t, err) + require.Len(t, hits, 1) + require.Equal(t, "https://example.com/", hits[0].URL) +} + +func TestGoogleSearchBackendSearchRejectsMissingConfig(t *testing.T) { + t.Parallel() + backend := &googleSearchBackend{client: http.DefaultClient} + _, err := backend.search(context.Background(), webSearchInput{Query: "example"}) + require.EqualError(t, err, "google search config is required") + backend = &googleSearchBackend{client: http.DefaultClient, options: &WebSearchOptions{}} + _, err = backend.search(context.Background(), webSearchInput{Query: "example"}) + require.EqualError(t, err, "google search requires api_key and engine_id") +} + +func TestWebSearchToolCoversValidationAndProviderBranches(t *testing.T) { + t.Parallel() + backend, err := newSearchBackend(&WebSearchOptions{ + Provider: "google", + Timeout: 2 * time.Second, + }) + require.NoError(t, err) + googleBackend, ok := backend.(*googleSearchBackend) + require.True(t, ok) + require.Equal(t, 2*time.Second, 
googleBackend.client.Timeout) + _, err = newSearchBackend(&WebSearchOptions{Provider: "bing"}) + require.EqualError(t, err, "unsupported web search provider: bing") + searchTool, err := newWebSearchTool(&WebSearchOptions{ + BaseURL: "http://127.0.0.1", + }) + require.NoError(t, err) + callable, ok := searchTool.(tool.CallableTool) + require.True(t, ok) + _, err = callToolRaw(callable, webSearchInput{}) + require.EqualError(t, err, "query is required") + _, err = callToolRaw(callable, webSearchInput{ + Query: "example", + AllowedDomains: []string{"example.com"}, + BlockedDomains: []string{"example.org"}, + }) + require.EqualError(t, err, "cannot specify both allowed_domains and blocked_domains") +} + +func TestEncodingHelpersPreserveUTF16AndLineEndings(t *testing.T) { + t.Parallel() + encoded, err := encodeTextBytes("alpha\nbeta\n", "utf16le", "\r\n") + require.NoError(t, err) + decoded, encoding, err := decodeTextBytes(encoded) + require.NoError(t, err) + require.Equal(t, "utf16le", encoding) + require.Equal(t, "alpha\nbeta\n", decoded) + utf8Decoded, utf8Encoding, err := decodeTextBytes([]byte("one\r\ntwo\r\n")) + require.NoError(t, err) + require.Equal(t, "utf8", utf8Encoding) + require.Equal(t, "one\ntwo\n", utf8Decoded) +} + +func TestQuoteHelpersNormalizeAndPreserveSingleQuotes(t *testing.T) { + t.Parallel() + require.Equal(t, "\"quote\" and 'apostrophe'", normalizeQuotes("“quote” and ‘apostrophe’")) + require.Equal(t, "‘quoted text’ and don’t", applyCurlySingleQuotes("'quoted text' and don't")) + actual := findActualString("‘quoted text’ and don’t", "'quoted text' and don't") + require.Equal(t, "‘quoted text’ and don’t", actual) + require.Equal(t, "‘new text’ and can’t", preserveQuoteStyle("'quoted text' and don't", actual, "'new text' and can't")) +} + +func TestWriteOutputToEditOutputPreservesOriginalFileAndPatch(t *testing.T) { + t.Parallel() + out := writeOutputToEditOutput("/tmp/file.txt", editInput{ + FilePath: "/tmp/file.txt", + OldString: "before", + 
NewString: "after", + ReplaceAll: true, + }, strPtr("before"), "after") + require.Equal(t, "/tmp/file.txt", out.FilePath) + require.Equal(t, "before", out.OriginalFile) + require.True(t, out.ReplaceAll) + require.NotEmpty(t, out.StructuredPatch) +} + +func TestOptionHelpersApplyCustomValues(t *testing.T) { + t.Parallel() + options := &toolSetOptions{} + WithName("custom")(options) + WithMaxFileSize(2048)(options) + require.Equal(t, "custom", options.name) + require.EqualValues(t, 2048, options.maxFileSize) + require.True(t, options.hasMaxSize) +} + +func TestEditLocalFileCreatesMissingFileWithoutReadState(t *testing.T) { + t.Parallel() + dir := t.TempDir() + runtime := newToolRuntime(dir, maxEditableFileSize) + absPath := filepath.Join(dir, "created.txt") + out, err := editLocalFile(absPath, editInput{ + FilePath: absPath, + OldString: "", + NewString: "created content\n", + }, runtime) + require.NoError(t, err) + require.Equal(t, "", out.OriginalFile) + require.NotEmpty(t, out.StructuredPatch) + raw, readErr := os.ReadFile(absPath) + require.NoError(t, readErr) + require.Equal(t, "created content\n", string(raw)) +} + +func TestEditLocalFileRejectsBinaryAndNoopChanges(t *testing.T) { + t.Parallel() + dir := t.TempDir() + runtime := newToolRuntime(dir, maxEditableFileSize) + binaryPath := filepath.Join(dir, "data.bin") + require.NoError(t, os.WriteFile(binaryPath, []byte{0x00, 0x01, 0x02}, 0o644)) + _, err := editLocalFile(binaryPath, editInput{ + FilePath: binaryPath, + OldString: "a", + NewString: "b", + }, runtime) + require.EqualError(t, err, "This tool cannot edit binary files.") + textPath := filepath.Join(dir, "notes.txt") + require.NoError(t, os.WriteFile(textPath, []byte("same"), 0o644)) + _, err = editLocalFile(textPath, editInput{ + FilePath: textPath, + OldString: "same", + NewString: "same", + }, runtime) + require.EqualError(t, err, "No changes to make: old_string and new_string are exactly the same.") +} + +func 
TestEditLocalFileCoversInsertAndReplaceErrorBranches(t *testing.T) { + t.Parallel() + dir := t.TempDir() + runtime := newToolRuntime(dir, maxEditableFileSize) + emptyPath := filepath.Join(dir, "empty.txt") + require.NoError(t, os.WriteFile(emptyPath, nil, 0o644)) + snapshot, err := readLocalFileSnapshot(emptyPath, maxEditableFileSize) + require.NoError(t, err) + storeReadView(runtime.fileState, emptyPath, snapshot.Content, snapshot.Timestamp, nil, nil, "", false, true) + out, err := editLocalFile(emptyPath, editInput{ + FilePath: emptyPath, + OldString: "", + NewString: "inserted\n", + }, runtime) + require.NoError(t, err) + require.Equal(t, "", out.OriginalFile) + require.NotEmpty(t, out.StructuredPatch) + missingPath := filepath.Join(dir, "missing.txt") + require.NoError(t, os.WriteFile(missingPath, []byte("alpha\nbeta\n"), 0o644)) + snapshot, err = readLocalFileSnapshot(missingPath, maxEditableFileSize) + require.NoError(t, err) + storeReadView(runtime.fileState, missingPath, snapshot.Content, snapshot.Timestamp, nil, nil, "", false, true) + _, err = editLocalFile(missingPath, editInput{ + FilePath: missingPath, + OldString: "gamma", + NewString: "delta", + }, runtime) + require.EqualError(t, err, "String to replace not found in file.\nString: gamma") + duplicatePath := filepath.Join(dir, "duplicate.txt") + require.NoError(t, os.WriteFile(duplicatePath, []byte("alpha\nalpha\n"), 0o644)) + snapshot, err = readLocalFileSnapshot(duplicatePath, maxEditableFileSize) + require.NoError(t, err) + storeReadView(runtime.fileState, duplicatePath, snapshot.Content, snapshot.Timestamp, nil, nil, "", false, true) + _, err = editLocalFile(duplicatePath, editInput{ + FilePath: duplicatePath, + OldString: "alpha", + NewString: "omega", + }, runtime) + require.EqualError(t, err, "Found 2 matches of the string to replace, but replace_all is false. To replace all occurrences, set replace_all to true. 
To replace only one occurrence, please provide more context to uniquely identify the instance.\nString: alpha")
+	replaceAllOut, err := editLocalFile(duplicatePath, editInput{
+		FilePath:   duplicatePath,
+		OldString:  "alpha",
+		NewString:  "omega",
+		ReplaceAll: true,
+	}, runtime)
+	require.NoError(t, err)
+	require.True(t, replaceAllOut.ReplaceAll)
+	raw, readErr := os.ReadFile(duplicatePath)
+	require.NoError(t, readErr)
+	require.Equal(t, "omega\nomega\n", string(raw))
+}
+
+func TestLoadNotebookEditStateRejectsInvalidInputs(t *testing.T) {
+	t.Parallel()
+	dir := t.TempDir()
+	runtime := newToolRuntime(dir, maxEditableFileSize)
+	_, err := loadNotebookEditState(filepath.Join(dir, "plain.txt"), notebookEditInput{}, runtime)
+	require.EqualError(t, err, "File must be a Jupyter notebook (.ipynb file).")
+	notebookPath := filepath.Join(dir, "test.ipynb")
+	require.NoError(t, os.WriteFile(notebookPath, []byte(`{"cells":[],"metadata":{},"nbformat":4,"nbformat_minor":5}`), 0o644))
+	snapshot, err := readLocalFileSnapshot(notebookPath, maxEditableFileSize)
+	require.NoError(t, err)
+	storeReadView(runtime.fileState, notebookPath, snapshot.Content, snapshot.Timestamp, nil, nil, "", false, true)
+	_, err = loadNotebookEditState(notebookPath, notebookEditInput{EditMode: "bad"}, runtime)
+	require.EqualError(t, err, "Edit mode must be replace, insert, or delete.")
+	_, err = loadNotebookEditState(notebookPath, notebookEditInput{EditMode: "insert"}, runtime)
+	require.EqualError(t, err, "Cell type is required when using edit_mode=insert.")
+	_, err = loadNotebookEditState(notebookPath, notebookEditInput{EditMode: "replace"}, runtime)
+	require.EqualError(t, err, "Cell ID must be specified when not inserting a new cell.")
+}
+
+func TestLoadNotebookEditStateConvertsReplacePastEndIntoInsert(t *testing.T) {
+	t.Parallel()
+	dir := t.TempDir()
+	runtime := newToolRuntime(dir, maxEditableFileSize)
+	notebookPath := filepath.Join(dir, "test.ipynb")
+	require.NoError(t, os.WriteFile(notebookPath, []byte(`{
+		"cells":[{"id":"cell-0","cell_type":"markdown","metadata":{},"source":"hello"}],
+		"metadata":{"language_info":{"name":"python"}},
+		"nbformat":4,
+		"nbformat_minor":5
+	}`), 0o644))
+	snapshot, err := readLocalFileSnapshot(notebookPath, maxEditableFileSize)
+	require.NoError(t, err)
+	storeReadView(runtime.fileState, notebookPath, snapshot.Content, snapshot.Timestamp, nil, nil, "", false, true)
+	state, err := loadNotebookEditState(notebookPath, notebookEditInput{
+		EditMode: "replace",
+		CellID:   "cell-1",
+	}, runtime)
+	require.NoError(t, err)
+	require.Equal(t, "insert", state.editMode)
+	require.Equal(t, "code", state.cellType)
+	require.Equal(t, "python", state.language)
+	require.Equal(t, 1, state.cellIndex)
+}
+
+func TestFileStateHelpersCoverRemainingBranches(t *testing.T) {
+	t.Parallel()
+	dir := t.TempDir()
+	_, err := readLocalFileSnapshot(dir, maxEditableFileSize)
+	require.ErrorContains(t, err, "is a directory")
+	missingSnapshot, err := readLocalFileSnapshot(filepath.Join(dir, "missing.txt"), maxEditableFileSize)
+	require.NoError(t, err)
+	require.False(t, missingSnapshot.Exists)
+	largePath := filepath.Join(dir, "large.txt")
+	require.NoError(t, os.WriteFile(largePath, []byte("0123456789"), 0o644))
+	_, err = readLocalFileSnapshot(largePath, 4)
+	require.ErrorContains(t, err, "exceeds max size")
+	writtenPath := filepath.Join(dir, "nested", "written.txt")
+	require.NoError(t, writeLocalFile(writtenPath, "alpha\nbeta\n", 0, "utf8", "\n"))
+	writtenSnapshot, err := readLocalFileSnapshot(writtenPath, maxEditableFileSize)
+	require.NoError(t, err)
+	require.True(t, writtenSnapshot.Exists)
+	require.Equal(t, "alpha\nbeta\n", writtenSnapshot.Content)
+	state := &fileState{views: map[string]fileView{}}
+	err = ensureWriteAllowed(writtenPath, writtenSnapshot, state)
+	require.EqualError(t, err, "File has not been read yet. Read it first before writing to it.")
+	state.views[writtenPath] = fileView{IsPartialView: true}
+	err = ensureWriteAllowed(writtenPath, writtenSnapshot, state)
+	require.EqualError(t, err, "File has not been read yet. Read it first before writing to it.")
+	state.views[writtenPath] = fileView{
+		Content:   writtenSnapshot.Content,
+		Timestamp: writtenSnapshot.Timestamp - 1,
+		FromRead:  true,
+	}
+	require.NoError(t, ensureWriteAllowed(writtenPath, writtenSnapshot, state))
+	state.views[writtenPath] = fileView{
+		Content:   "different",
+		Timestamp: writtenSnapshot.Timestamp - 1,
+		FromRead:  true,
+	}
+	err = ensureWriteAllowed(writtenPath, writtenSnapshot, state)
+	require.EqualError(t, err, "File has been modified since read, either by the user or by a linter. Read it again before attempting to write it.")
+	require.True(t, matchesReadView(fileView{
+		FromRead: true,
+		Offset:   intPtr(1),
+		Limit:    intPtr(2),
+		Pages:    "1-2",
+	}, intPtr(1), intPtr(2), "1-2"))
+	require.False(t, matchesReadView(fileView{FromRead: false}, nil, nil, ""))
+	require.True(t, intPtrsEqual(nil, nil))
+	require.False(t, intPtrsEqual(intPtr(1), nil))
+	require.False(t, intPtrsEqual(nil, intPtr(1)))
+	require.True(t, intPtrsEqual(intPtr(2), intPtr(2)))
+	actualDouble := findActualString("“quoted text”", "\"quoted text\"")
+	require.Equal(t, "“quoted text”", actualDouble)
+	require.Equal(t, "“new text”", preserveQuoteStyle("\"quoted text\"", actualDouble, "\"new text\""))
+	require.Equal(t, "fallback", notebookCellType(map[string]any{}, "fallback"))
+}
+
+func TestNotebookHelpersCoverRemainingBranches(t *testing.T) {
+	t.Parallel()
+	_, _, err := parseNotebook([]byte(`{"cells":{}}`))
+	require.ErrorContains(t, err, "notebook cells are invalid")
+	_, _, err = parseNotebook([]byte(`{"cells":[1]}`))
+	require.ErrorContains(t, err, "notebook cell is invalid")
+	cellType, err := normalizeNotebookCellType("")
+	require.NoError(t, err)
+	require.Empty(t, cellType)
+	cellType, err = normalizeNotebookCellType(" markdown ")
+	require.NoError(t, err)
+	require.Equal(t, "markdown", cellType)
+	_, err = normalizeNotebookCellType("raw")
+	require.EqualError(t, err, "Cell type must be code or markdown.")
+	require.Equal(t, "python", notebookLanguage(map[string]any{}))
+	require.Equal(t, "python", notebookLanguage(map[string]any{"metadata": map[string]any{}}))
+	require.Equal(t, "python", notebookLanguage(map[string]any{
+		"metadata": map[string]any{"language_info": map[string]any{"name": " "}},
+	}))
+	require.Equal(t, "go", notebookLanguage(map[string]any{
+		"metadata": map[string]any{"language_info": map[string]any{"name": "go"}},
+	}))
+	require.False(t, notebookSupportsCellIDs(map[string]any{"nbformat": 4, "nbformat_minor": 4}))
+	require.True(t, notebookSupportsCellIDs(map[string]any{"nbformat": 5, "nbformat_minor": 0}))
+	value, ok := notebookInt(float64(4))
+	require.True(t, ok)
+	require.Equal(t, 4, value)
+	value, ok = notebookInt(3)
+	require.True(t, ok)
+	require.Equal(t, 3, value)
+	_, ok = notebookInt("bad")
+	require.False(t, ok)
+	deleteState := notebookEditState{cells: []map[string]any{}, cellIndex: 0}
+	_, _, err = deleteNotebookCell(&deleteState, notebookEditInput{CellID: "missing"})
+	require.EqualError(t, err, `Cell with ID "missing" not found in notebook.`)
+	replaceState := notebookEditState{
+		cells: []map[string]any{{
+			"id":              "cell-1",
+			"cell_type":       "markdown",
+			"execution_count": 1,
+			"outputs":         []any{"old"},
+		}},
+		cellIndex: 0,
+		cellType:  "markdown",
+	}
+	resultCellID, resultCellType, err := replaceNotebookCell(&replaceState, notebookEditInput{
+		CellID:    "cell-1",
+		NewSource: "updated",
+	})
+	require.NoError(t, err)
+	require.Equal(t, "cell-1", resultCellID)
+	require.Equal(t, "markdown", resultCellType)
+	require.NotContains(t, replaceState.cells[0], "execution_count")
+	require.NotContains(t, replaceState.cells[0], "outputs")
+	replaceState = notebookEditState{
+		cells:     []map[string]any{{"id": "cell-1"}},
+		cellIndex: 1,
+	}
+	_, _, err = replaceNotebookCell(&replaceState, notebookEditInput{CellID: "missing"})
+	require.EqualError(t, err, `Cell with ID "missing" not found in notebook.`)
+	_, err = marshalNotebook(map[string]any{"bad": func() {}})
+	require.Error(t, err)
+	dir := t.TempDir()
+	runtime := newToolRuntime(dir, maxEditableFileSize)
+	missingPath := filepath.Join(dir, "missing.ipynb")
+	_, err = loadNotebookEditState(missingPath, notebookEditInput{
+		EditMode: "insert",
+		CellType: "code",
+	}, runtime)
+	require.EqualError(t, err, "Notebook file does not exist.")
+	invalidPath := filepath.Join(dir, "invalid.ipynb")
+	require.NoError(t, os.WriteFile(invalidPath, []byte("not-json"), 0o644))
+	invalidSnapshot, err := readLocalFileSnapshot(invalidPath, maxEditableFileSize)
+	require.NoError(t, err)
+	storeReadView(runtime.fileState, invalidPath, invalidSnapshot.Content, invalidSnapshot.Timestamp, nil, nil, "", false, true)
+	_, err = loadNotebookEditState(invalidPath, notebookEditInput{
+		EditMode: "insert",
+		CellType: "code",
+	}, runtime)
+	require.EqualError(t, err, "Notebook is not valid JSON.")
+	outOfRangePath := filepath.Join(dir, "out-of-range.ipynb")
+	require.NoError(t, os.WriteFile(outOfRangePath, []byte(`{"cells":[{"id":"cell-0","cell_type":"code","metadata":{},"source":"x"}],"metadata":{},"nbformat":4,"nbformat_minor":5}`), 0o644))
+	outOfRangeSnapshot, err := readLocalFileSnapshot(outOfRangePath, maxEditableFileSize)
+	require.NoError(t, err)
+	storeReadView(runtime.fileState, outOfRangePath, outOfRangeSnapshot.Content, outOfRangeSnapshot.Timestamp, nil, nil, "", false, true)
+	_, err = loadNotebookEditState(outOfRangePath, notebookEditInput{
+		EditMode: "replace",
+		CellID:   "2",
+	}, runtime)
+	require.EqualError(t, err, "Cell with index 2 does not exist in notebook.")
+	insertState := notebookEditState{
+		notebook: map[string]any{"nbformat": 4, "nbformat_minor": 4},
+		cells: []map[string]any{
+			{"id": "cell-0", "source": "x = 1"},
+			{"id": "cell-1", "source": "x = 2"},
+		},
+		cellType:  "code",
+		cellIndex: 2,
+	}
+	var insertedCellID string
+	insertedCellID = insertNotebookCell(&insertState, notebookEditInput{
+		CellID:    "cell-2",
+		NewSource: "x = 3",
+	})
+	require.Equal(t, "cell-2", insertedCellID)
+	require.Len(t, insertState.cells, 3)
+	require.Equal(t, "x = 3", insertState.cells[2]["source"])
+}
+
+func TestCommonHelpersCoverPathAndHTTPBranches(t *testing.T) {
+	t.Parallel()
+	dir := t.TempDir()
+	relPath, absPath, err := normalizePath(dir, "nested/file.txt")
+	require.NoError(t, err)
+	require.Equal(t, "nested/file.txt", relPath)
+	require.Equal(t, filepath.Join(dir, "nested/file.txt"), absPath)
+	_, _, err = normalizePath(dir, "../outside.txt")
+	require.EqualError(t, err, "path is outside base_dir: ../outside.txt")
+	runtime := newToolRuntime(dir, maxEditableFileSize)
+	runtime.setBaseDir(filepath.Join(dir, "other"))
+	require.Equal(t, filepath.Join(dir, "other"), runtime.currentBaseDir())
+	require.Equal(t, "file.txt", relativePath(dir, filepath.Join(dir, "file.txt")))
+	resp := &http.Response{Body: io.NopCloser(strings.NewReader("abcdef"))}
+	body, err := readHTTPBody(resp, 3, 0)
+	require.EqualError(t, err, "response body exceeded limit of 3 bytes")
+	require.Nil(t, body)
+	require.Equal(t, 2, countLines("alpha\nbeta"))
+	require.True(t, matchDomainRule("docs.example.com", "*.example.com"))
+	require.True(t, matchSearchDomainFilters("https://docs.example.com/path", []string{"example.com"}, nil))
+	require.False(t, matchSearchDomainFilters("https://docs.example.com/path", nil, []string{"docs.example.com"}))
+	require.Equal(t, "docs.example.com", searchURLHost("https://docs.example.com/path"))
+}
+
+func TestCommonHelpersCoverTextAndPatchBranches(t *testing.T) {
+	t.Parallel()
+	dir := t.TempDir()
+	absPath := filepath.Join(dir, "nested", "file.txt")
+	relPath, normalizedAbs, err := normalizePath(dir, absPath)
+	require.NoError(t, err)
+	require.Equal(t, "nested/file.txt", relPath)
+	require.Equal(t, filepath.Clean(absPath), normalizedAbs)
+	_, _, err = normalizePath(dir, "")
+	require.EqualError(t, err, "path is required")
+	_, _, err = normalizePath(dir, filepath.Join(filepath.Dir(dir), "outside.txt"))
+	require.ErrorContains(t, err, "path is outside base_dir")
+	require.Equal(t, filepath.ToSlash(filepath.Clean("\x00")), relativePath(dir, "\x00"))
+	body, err := readHTTPBody(&http.Response{}, 8, 0)
+	require.NoError(t, err)
+	require.Nil(t, body)
+	body, err = readHTTPBody(&http.Response{Body: io.NopCloser(strings.NewReader("abcd"))}, 0, 4)
+	require.NoError(t, err)
+	require.Equal(t, []byte("abcd"), body)
+	require.Zero(t, countLines(""))
+	require.Equal(t, 2, countLines("alpha\nbeta\n"))
+	require.Empty(t, splitTextLines(""))
+	require.Equal(t, []string{"alpha", "beta"}, splitTextLines("alpha\nbeta\n"))
+	sliced, startLine, totalLines := sliceLines("alpha\nbeta\ngamma\n", 0, intPtr(2))
+	require.Equal(t, "alpha\nbeta", sliced)
+	require.Equal(t, 1, startLine)
+	require.Equal(t, 3, totalLines)
+	require.Equal(t, "alpha\nbeta\ngamma\n", normalizeNewlines("alpha\r\nbeta\rgamma\n"))
+	require.Equal(t, "\r\n", detectLineEnding([]byte("alpha\r\nbeta")))
+	require.Equal(t, "alpha\r\nbeta", applyLineEnding("alpha\nbeta", "\r\n"))
+	utf8Encoded, err := encodeTextBytes("alpha\nbeta", "utf8", "\n")
+	require.NoError(t, err)
+	require.Equal(t, []byte("alpha\nbeta"), utf8Encoded)
+	require.Equal(t, "YQ==", fileBase64([]byte("a")))
+	require.True(t, isProbablyBinary([]byte("a\x00b")))
+	require.False(t, isProbablyBinary([]byte{0xff, 0xfe, 'a', 0x00}))
+	patch := buildStructuredPatch("alpha\nbeta\n", "alpha\ngamma\n")
+	require.Len(t, patch, 1)
+	require.Equal(t, 2, patch[0].OldStart)
+	require.Equal(t, []string{"-beta", "+gamma"}, patch[0].Lines)
+	require.Nil(t, buildStructuredPatch("same\n", "same\n"))
+}
+
+func TestCommonHelpersCoverRemainingErrorBranches(t *testing.T) {
+	t.Parallel()
+	require.Equal(t, filepath.ToSlash(filepath.Clean("file.txt")), relativePath("\x00", "file.txt"))
+	_, err := readHTTPBody(&http.Response{Body: io.NopCloser(errReader{err: fs.ErrInvalid})}, 8, 0)
+	require.ErrorIs(t, err, fs.ErrInvalid)
+	sliced, startLine, totalLines := sliceLines("alpha\nbeta\n", 10, nil)
+	require.Empty(t, sliced)
+	require.Equal(t, 10, startLine)
+	require.Equal(t, 2, totalLines)
+}
+
+func TestGrepTypePatternsExposeKnownAliases(t *testing.T) {
+	t.Parallel()
+	require.Contains(t, typePatterns("go"), "**/*.go")
+	require.Contains(t, typePatterns("js"), "**/*.js")
+	require.Contains(t, typePatterns("ts"), "**/*.tsx")
+	require.Contains(t, typePatterns("py"), "**/*.py")
+	require.Contains(t, typePatterns("java"), "**/*.java")
+	require.Contains(t, typePatterns("rs"), "**/*.rs")
+	require.Contains(t, typePatterns("json"), "**/*.json")
+	require.Contains(t, typePatterns("md"), "**/*.md")
+	require.Contains(t, typePatterns("yaml"), "**/*.yaml")
+	require.Contains(t, typePatterns("txt"), "**/*.txt")
+	require.Equal(t, []string{"**/*.unknown"}, typePatterns("unknown"))
+}
+
+func TestGrepAndPDFHelpersCoverRemainingBranches(t *testing.T) {
+	t.Parallel()
+	items, limit := sliceStrings([]string{"a", "b", "c"}, 1, 1)
+	require.Equal(t, []string{"b"}, items)
+	require.NotNil(t, limit)
+	require.Equal(t, 1, *limit)
+	items, limit = sliceStrings([]string{"a", "b", "c"}, 5, 1)
+	require.Empty(t, items)
+	require.Nil(t, limit)
+	require.Equal(t, []string{"-e", "-pattern"}, appendRipgrepPattern(nil, "-pattern"))
+	require.Equal(t, []string{"pattern"}, appendRipgrepPattern(nil, "pattern"))
+	require.Equal(t, 0, grepOffset(grepInput{}))
+	require.Equal(t, 3, grepOffset(grepInput{Offset: intPtr(3)}))
+	require.Equal(t, defaultGrepHeadLimit, grepLimit(grepInput{}))
+	require.Equal(t, 5, grepLimit(grepInput{HeadLimit: intPtr(5)}))
+	rangeAll, err := resolvePDFPageRange("2-", 4)
+	require.NoError(t, err)
+	require.Equal(t, pdfPageRange{FirstPage: 2, LastPage: 4, Count: 3}, rangeAll)
+	_, err = resolvePDFPageRange("", 4)
+	require.EqualError(t, err, `Invalid pages parameter: "". Use formats like "1-5", "3", or "10-20". Pages are 1-indexed.`)
+	_, err = resolvePDFPageRange("3-1", 4)
+	require.EqualError(t, err, `Invalid pages parameter: "3-1". Use formats like "1-5", "3", or "10-20". Pages are 1-indexed.`)
+	_, err = resolvePDFPageRange("1-21", 30)
+	require.EqualError(t, err, `Page range "1-21" exceeds maximum of 20 pages per request. Please use a smaller range.`)
+}
+
+func TestWebSearchHelpersNormalizeWrappedDuckDuckGoURLs(t *testing.T) {
+	t.Parallel()
+	require.Equal(t, "https://golang.org/doc/", normalizeDuckDuckGoResultURL("https://duckduckgo.com/l/?uddg=https%3A%2F%2Fgolang.org%2Fdoc%2F"))
+	require.Equal(t, "https://example.com", normalizeDuckDuckGoResultURL("https://duckduckgo.com/l/?uddg=https%3A%2F%2Fexample.com"))
+	require.Equal(t, "https://example.com/a%20b", normalizeDuckDuckGoResultURL("https://duckduckgo.com/l/?uddg=https%3A%2F%2Fexample.com%2Fa%2520b"))
+	require.Equal(t, "%ZZ", normalizeDuckDuckGoResultURL("https://duckduckgo.com/l/?uddg=%25ZZ"))
+	require.Equal(t, "not a url", normalizeDuckDuckGoResultURL("not a url"))
+	require.Empty(t, normalizeDuckDuckGoResultURL(" "))
+}
+
+func TestResolveRedirectURLAndDomainFiltersHandleEdgeCases(t *testing.T) {
+	t.Parallel()
+	nextURL, err := resolveRedirectURL("https://example.com/start", "https://example.com/next")
+	require.NoError(t, err)
+	require.Equal(t, "https://example.com/next", nextURL)
+	_, err = resolveRedirectURL("://bad", "/next")
+	require.Error(t, err)
+	require.False(t, matchSearchDomainFilters("not a url", []string{"example.com"}, nil))
+}
+
+func TestWebSearchBackendsCoverRemainingErrorBranches(t *testing.T) {
+	t.Parallel()
+	duckBackend := &duckDuckGoSearchBackend{
+		client:  http.DefaultClient,
+		baseURL: "://bad",
+	}
+	_, err := duckBackend.search(context.Background(), webSearchInput{Query: "example"})
+	require.Error(t, err)
+	requestFailedClient := &http.Client{
+		Transport: roundTripFunc(func(*http.Request) (*http.Response, error) {
+			return nil, fs.ErrInvalid
+		}),
+	}
+	duckBackend = &duckDuckGoSearchBackend{
+		client:  requestFailedClient,
+		baseURL: "https://example.com/search",
+	}
+	_, err = duckBackend.search(context.Background(), webSearchInput{Query: "example"})
+	require.ErrorIs(t, err, fs.ErrInvalid)
+	bodyFailedClient := &http.Client{
+		Transport: roundTripFunc(func(req *http.Request) (*http.Response, error) {
+			require.Equal(t, "tester", req.Header.Get("User-Agent"))
+			return &http.Response{
+				StatusCode: http.StatusOK,
+				Status:     "200 OK",
+				Header:     http.Header{},
+				Body:       io.NopCloser(errReader{err: fs.ErrInvalid}),
+				Request:    req,
+			}, nil
+		}),
+	}
+	duckBackend = &duckDuckGoSearchBackend{
+		client:    bodyFailedClient,
+		baseURL:   "https://example.com/search",
+		userAgent: "tester",
+	}
+	_, err = duckBackend.search(context.Background(), webSearchInput{Query: "example"})
+	require.ErrorIs(t, err, fs.ErrInvalid)
+	statusFailedClient := &http.Client{
+		Transport: roundTripFunc(func(req *http.Request) (*http.Response, error) {
+			return &http.Response{
+				StatusCode: http.StatusTooManyRequests,
+				Status:     "429 Too Many Requests",
+				Header:     http.Header{},
+				Body:       io.NopCloser(strings.NewReader("rate limited")),
+				Request:    req,
+			}, nil
+		}),
+	}
+	duckBackend = &duckDuckGoSearchBackend{
+		client:  statusFailedClient,
+		baseURL: "https://example.com/search",
+	}
+	_, err = duckBackend.search(context.Background(), webSearchInput{Query: "example"})
+	require.EqualError(t, err, "duckduckgo search request failed: status=429 body=rate limited")
+	googleBackend := &googleSearchBackend{
+		client:  http.DefaultClient,
+		options: &WebSearchOptions{APIKey: "key", EngineID: "engine", BaseURL: "://bad"},
+	}
+	_, err = googleBackend.search(context.Background(), webSearchInput{Query: "example"})
+	require.Error(t, err)
+	googleBackend = &googleSearchBackend{
+		client: requestFailedClient,
+		options: &WebSearchOptions{
+			APIKey:   "key",
+			EngineID: "engine",
+			BaseURL:  "https://example.com/search",
+		},
+	}
+	_, err = googleBackend.search(context.Background(), webSearchInput{Query: "example"})
+	require.ErrorIs(t, err, fs.ErrInvalid)
+}
+
+func TestExecRipgrepReturnsErrorForInvalidArguments(t *testing.T) {
+	if _, err := exec.LookPath("rg"); err != nil {
+		t.Skip("ripgrep is not available")
+	}
+	dir := t.TempDir()
+	require.NoError(t, os.WriteFile(filepath.Join(dir, "notes.txt"), []byte("alpha\n"), 0o644))
+	_, err := execRipgrep(context.Background(), dir, "--definitely-invalid-flag", "alpha")
+	require.Error(t, err)
+	require.Contains(t, err.Error(), "ripgrep failed")
+}
+
+func TestNotebookCellHelpersResolveIDsAndIndexes(t *testing.T) {
+	t.Parallel()
+	idx, ok := parseNotebookCellID(" cell-3 ")
+	require.True(t, ok)
+	require.Equal(t, 3, idx)
+	_, ok = parseNotebookCellID("-1")
+	require.False(t, ok)
+	cells := []map[string]any{
+		{"id": "first"},
+		{"id": "second"},
+	}
+	byID, err := notebookCellIndex(cells, "second")
+	require.NoError(t, err)
+	require.Equal(t, 1, byID)
+	byNumericID, err := notebookCellIndex(cells, "cell-1")
+	require.NoError(t, err)
+	require.Equal(t, 1, byNumericID)
+	_, err = notebookCellIndex(cells, "missing")
+	require.EqualError(t, err, `Cell with ID "missing" not found in notebook.`)
+}
+
+func TestGrepHelpersHandleContextAndCountOutput(t *testing.T) {
+	t.Parallel()
+	require.Equal(t, []string{"-C", "2"}, appendRipgrepContext(nil, grepInput{ContextAlt: intPtr(2)}))
+	require.Equal(t, []string{"-C", "1"}, appendRipgrepContext(nil, grepInput{Context: intPtr(1)}))
+	require.Equal(t, []string{"-B", "3", "-A", "4"}, appendRipgrepContext(nil, grepInput{
+		Before: intPtr(3),
+		After:  intPtr(4),
+	}))
+	out := formatRipgrepCountOutput(0, 0, []string{"a.txt:2", "b.txt:3", "raw"})
+	require.Equal(t, "count", out.Mode)
+	require.Equal(t, 3, out.NumFiles)
+	require.Equal(t, 5, out.NumMatches)
+	require.Equal(t, "a.txt:2\nb.txt:3\nraw", out.Content)
+}
+
+func TestGrepHelpersCoverRemainingFallbackAndRipgrepBranches(t *testing.T) {
+	t.Parallel()
+	dir := t.TempDir()
+	require.NoError(t, os.WriteFile(filepath.Join(dir, "a.txt"), []byte("alpha\nbeta\nalpha\n"), 0o644))
+	require.NoError(t, os.WriteFile(filepath.Join(dir, "main.go"), []byte("package main\nfunc main() {\nprintln(\"alpha\")\nprintln(\"beta\")\n}\n"), 0o644))
+	require.NoError(t, os.WriteFile(filepath.Join(dir, "note.md"), []byte("alpha beta"), 0o644))
+	require.NoError(t, os.WriteFile(filepath.Join(dir, "data.bin"), []byte{0x00, 'a', 'l', 'p', 'h', 'a'}, 0o644))
+	runtime := newToolRuntime(dir, maxEditableFileSize)
+	_, err := runFallbackGrep(runtime, dir, grepInput{Pattern: "("})
+	require.Error(t, err)
+	fileCandidates, err := collectGrepCandidates(dir, "a.txt", "", "")
+	require.NoError(t, err)
+	require.Equal(t, []string{filepath.Join(dir, "a.txt")}, fileCandidates)
+	typeCandidates, err := collectGrepCandidates(dir, "", "", "go")
+	require.NoError(t, err)
+	require.Equal(t, []string{filepath.Join(dir, "main.go")}, typeCandidates)
+	globCandidates, err := collectGrepCandidates(dir, "", "*.md", "")
+	require.NoError(t, err)
+	require.Equal(t, []string{filepath.Join(dir, "note.md")}, globCandidates)
+	_, err = collectGrepCandidates(dir, "missing", "", "")
+	require.Error(t, err)
+	require.False(t, matchesAnyPattern("main.go", []string{"["}))
+	require.Equal(t, []string{"*.go", "*.md", "{foo,bar}.txt"}, splitGlobPatterns("*.go,*.md {foo,bar}.txt"))
+	contentCollector := newFallbackGrepCollector("content")
+	err = collectFallbackLineMatch("alpha\nbeta\n", "a.txt", regexp.MustCompile("alpha"), grepInput{
+		ShowLineNum: boolPtr(false),
+	}, contentCollector)
+	require.NoError(t, err)
+	require.Equal(t, []string{"a.txt:alpha"}, contentCollector.contentLines)
+	countCollector := newFallbackGrepCollector("count")
+	err = collectFallbackMultilineMatch("alpha\nbeta\nalpha\n", "a.txt", regexp.MustCompile("alpha(?s).*beta"), grepInput{}, countCollector)
+	require.NoError(t, err)
+	require.Equal(t, []string{"a.txt:1"}, countCollector.countLines)
+	fileCollector := newFallbackGrepCollector("files_with_matches")
+	err = collectFallbackMultilineMatch("alpha\nbeta\n", "a.txt", regexp.MustCompile("alpha(?s).*beta"), grepInput{}, fileCollector)
+	require.NoError(t, err)
+	require.Equal(t, []string{"a.txt"}, fileCollector.fileMatches)
+	fallbackOut, err := runFallbackGrep(runtime, dir, grepInput{
+		Pattern:   "alpha",
+		HeadLimit: intPtr(1),
+	})
+	require.NoError(t, err)
+	require.Equal(t, "files_with_matches", fallbackOut.Mode)
+	require.Len(t, fallbackOut.Filenames, 1)
+	restore := withRipgrepForTest(func(string) (string, error) {
+		return writeExecutableFile(t, dir, "fake-rg.sh", "#!/bin/sh\nprintf 'main.go:1:alpha\\nmain.go:2:beta\\n'\n"), nil
+	})
+	ripgrepContentOut, handled, err := runRipgrepCommand(context.Background(), dir, ".", grepInput{
+		Pattern:    "alpha",
+		OutputMode: "content",
+	})
+	restore()
+	require.True(t, handled)
+	require.NoError(t, err)
+	require.Equal(t, "content", ripgrepContentOut.Mode)
+	require.Contains(t, ripgrepContentOut.Content, "main.go:1:alpha")
+	restore = withRipgrepForTest(func(string) (string, error) {
+		return writeExecutableFile(t, dir, "fake-rg-empty.sh", "#!/bin/sh\nexit 1\n"), nil
+	})
+	lines, err := execRipgrep(context.Background(), dir, "alpha")
+	restore()
+	require.NoError(t, err)
+	require.Empty(t, lines)
+	restore = withRipgrepForTest(func(string) (string, error) {
+		return writeExecutableFile(t, dir, "fake-rg-error.sh", "#!/bin/sh\nexit 2\n"), nil
+	})
+	_, err = execRipgrep(context.Background(), dir, "alpha")
+	restore()
+	require.EqualError(t, err, "ripgrep failed: ripgrep exited with code 2")
+}
+
+func TestExecRipgrepReturnsEmptyOnNoMatches(t *testing.T) {
+	if _, err := exec.LookPath("rg"); err != nil {
+		t.Skip("ripgrep is not available")
+	}
+	dir := t.TempDir()
+	require.NoError(t, os.WriteFile(filepath.Join(dir, "notes.txt"), []byte("alpha\nbeta\n"), 0o644))
+	lines, err := execRipgrep(context.Background(), dir, "--files-with-matches", "missing-pattern")
+	require.NoError(t, err)
+	require.Empty(t, lines)
+}
+
+func TestRunLocalRipgrepReturnsFalseWhenRipgrepIsUnavailable(t *testing.T) {
+	t.Parallel()
+	restore := withRipgrepForTest(func(string) (string, error) {
+		return "", errors.New("not found")
+	})
+	defer restore()
+	out, ok, err := runLocalRipgrep(context.Background(), t.TempDir(), grepInput{Pattern: "alpha"})
+	require.NoError(t, err)
+	require.False(t, ok)
+	require.Equal(t, grepOutput{}, out)
+}
+
+func TestRunLocalRipgrepRejectsPathsOutsideBaseDir(t *testing.T) {
+	t.Parallel()
+	restore := withRipgrepForTest(func(string) (string, error) {
+		return "/bin/true", nil
+	})
+	defer restore()
+	_, ok, err := runLocalRipgrep(context.Background(), t.TempDir(), grepInput{
+		Pattern: "alpha",
+		Path:    "../outside.txt",
+	})
+	require.Error(t, err)
+	require.True(t, ok)
+	require.Contains(t, err.Error(), "path is outside base_dir")
+}
+
+func TestRunLocalRipgrepSupportsFilesContentAndCountModes(t *testing.T) {
+	if _, err := exec.LookPath("rg"); err != nil {
+		t.Skip("ripgrep is not available")
+	}
+	restore := withRipgrepForTest(exec.LookPath)
+	defer restore()
+	dir := t.TempDir()
+	require.NoError(t, os.WriteFile(filepath.Join(dir, "a.txt"), []byte("alpha\nbeta\n"), 0o644))
+	require.NoError(t, os.WriteFile(filepath.Join(dir, "b.txt"), []byte("alpha\n"), 0o644))
+	filesOut, ok, err := runLocalRipgrep(context.Background(), dir, grepInput{
+		Pattern:    "alpha",
+		OutputMode: "files_with_matches",
+	})
+	require.NoError(t, err)
+	require.True(t, ok)
+	require.ElementsMatch(t, []string{"a.txt", "b.txt"}, filesOut.Filenames)
+	contentOut, ok, err := runLocalRipgrep(context.Background(), dir, grepInput{
+		Pattern:     "alpha",
+		OutputMode:  "content",
+		ContextAlt:  intPtr(1),
+		ShowLineNum: boolPtr(true),
+	})
+	require.NoError(t, err)
+	require.True(t, ok)
+	require.Contains(t, contentOut.Content, "a.txt:1:alpha")
+	countOut, ok, err := runLocalRipgrep(context.Background(), dir, grepInput{
+		Pattern:    "alpha",
+		OutputMode: "count",
+	})
+	require.NoError(t, err)
+	require.True(t, ok)
+	require.Equal(t, 2, countOut.NumMatches)
+}
+
+func TestBashAndProcessHelpersCoverTimeoutAndExitState(t *testing.T) {
+	t.Setenv("BASH_DEFAULT_TIMEOUT_MS", "50")
+	require.Equal(t, 50, bashTimeout(nil))
+	require.Equal(t, defaultBashTimeoutMs, bashTimeout(intPtr(0)))
+	require.Equal(t, maxBashTimeoutMs, bashTimeout(intPtr(maxBashTimeoutMs+1)))
+	proc, err := os.StartProcess("/bin/true", []string{"true"}, &os.ProcAttr{
+		Env:   processEnv(nil),
+		Files: []*os.File{os.Stdin, os.Stdout, os.Stderr},
+	})
+	require.NoError(t, err)
+	state, waitErr := proc.Wait()
+	require.NoError(t, waitErr)
+	require.Equal(t, "completed", backgroundTaskStatus(nil, state))
+	require.Equal(t, 0, backgroundTaskExitCode(nil, state))
+	require.Equal(t, "exited", backgroundTaskStatus(errors.New("wait failed"), nil))
+	require.Equal(t, 1, backgroundTaskExitCode(errors.New("wait failed"), nil))
+	require.Equal(t, 0, backgroundTaskExitCode(nil, nil))
+	require.Equal(t, "exited", backgroundTaskStatus(nil, nil))
+	require.Equal(t, "stdout\nstderr", joinOutput("stdout", "stderr"))
+	require.Equal(t, "stderr", joinOutput("", "stderr"))
+}
+
+func TestBashHelpersCoverForegroundAndBackgroundErrorBranches(t *testing.T) {
+	missingBaseDir := filepath.Join(t.TempDir(), "missing")
+	runtime := newToolRuntime(missingBaseDir, maxEditableFileSize)
+	out, err := runForegroundCommand(context.Background(), runtime, bashInput{Command: "printf 'hello'"})
+	require.NoError(t, err)
+	require.Equal(t, 1, out.ExitCode)
+	require.False(t, out.TimedOut)
+	require.Empty(t, out.Stdout)
+	require.Empty(t, out.Stderr)
+	tmpRoot := t.TempDir()
+	tmpFile := filepath.Join(tmpRoot, "tmp-file")
+	require.NoError(t, os.WriteFile(tmpFile, []byte("x"), 0o644))
+	t.Setenv("TMPDIR", tmpFile)
+	_, err = runBackgroundCommand(newToolRuntime(t.TempDir(), maxEditableFileSize), "printf 'hello'")
+	require.Error(t, err)
+	_, err = runBackgroundCommand(runtime, "printf 'hello'")
+	require.Error(t, err)
+}
+
+func TestRunCapturedProcessAndWaitForProcess(t *testing.T) {
+	t.Parallel()
+	dir := t.TempDir()
+	result, err := runCapturedProcess(context.Background(), dir, []string{"TEST_KEY=VALUE"}, "bash", "-lc", "printf \"$TEST_KEY\"")
+	require.NoError(t, err)
+	require.Equal(t, "VALUE", string(result.Stdout))
+	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond)
+	defer cancel()
+	proc, err := os.StartProcess("/bin/sleep", []string{"sleep", "1"}, &os.ProcAttr{
+		Dir:   dir,
+		Env:   processEnv(nil),
+		Files: []*os.File{os.Stdin, os.Stdout, os.Stderr},
+	})
+	require.NoError(t, err)
+	state := waitForProcess(ctx, proc)
+	require.ErrorIs(t, state.Err, context.DeadlineExceeded)
+	require.NotNil(t, state.State)
+}
+
+func TestProcessPipeHelpersAndTaskStopErrors(t *testing.T) {
+	t.Parallel()
+	stdin, stdoutReader, stdoutWriter, stderrReader, stderrWriter, closeAll, err := processPipes()
+	require.NoError(t, err)
+	require.NotNil(t, stdin)
+	require.NotNil(t, stdoutReader)
+	require.NotNil(t, stdoutWriter)
+	require.NotNil(t, stderrReader)
+	require.NotNil(t, stderrWriter)
+	require.NoError(t, closeAll())
+	runtime := newToolRuntime(t.TempDir(), maxEditableFileSize)
+	stopTool, err := newTaskStopTool(runtime)
+	require.NoError(t, err)
+	callable, ok := stopTool.(tool.CallableTool)
+	require.True(t, ok)
+	_, err = callToolRaw(callable, taskStopInput{})
+	require.EqualError(t, err, "Missing required parameter: task_id")
+	runtime.taskState.tasks["done"] = &backgroundTask{
+		ID:      "done",
+		Command: "echo done",
+		Type:    toolBash,
+		Status:  "completed",
+	}
+	_, err = callToolRaw(callable, taskStopInput{TaskID: "done"})
+	require.EqualError(t, err, "Task done is not running (status: completed)")
+	runtime.taskState.tasks["running"] = &backgroundTask{
+		ID:      "running",
+		Command: "sleep 30",
+		Type:    toolBash,
+		Status:  "running",
+	}
+	_, err = callToolRaw(callable, taskStopInput{TaskID: "running"})
+	require.EqualError(t, err, "Task running has no running process")
+	require.Equal(t, "running", runtime.taskState.tasks["running"].Status)
+}
+
+func TestTaskStopAcceptsShellIDAndPropagatesKillErrors(t *testing.T) {
+	t.Parallel()
+	runtime := newToolRuntime(t.TempDir(), maxEditableFileSize)
+	stopTool, err := newTaskStopTool(runtime)
+	require.NoError(t, err)
+	callable, ok := stopTool.(tool.CallableTool)
+	require.True(t, ok)
+	_, err = callToolRaw(callable, taskStopInput{ShellID: "missing"})
+	require.EqualError(t, err, "No task found with ID: missing")
+	proc, err := os.StartProcess("/bin/true", []string{"true"}, &os.ProcAttr{
+		Env:   processEnv(nil),
+		Files: []*os.File{os.Stdin, os.Stdout, os.Stderr},
+	})
+	require.NoError(t, err)
+	_, err = proc.Wait()
+	require.NoError(t, err)
+	runtime.taskState.tasks["finished"] = &backgroundTask{
+		ID:      "finished",
+		Command: "true",
+		Type:    toolBash,
+		Status:  "running",
+		Process: proc,
+	}
+	_, err = callToolRaw(callable, taskStopInput{ShellID: "finished"})
+	require.Error(t, err)
+	require.Equal(t, "running", runtime.taskState.tasks["finished"].Status)
+}
+
+func TestPDFAndNotebookHelpersCoverFallbackBranches(t *testing.T) {
+	t.Parallel()
+	pdfBytes := newTestPDF(t, []string{"page one", "page two", "page three"})
+	pageCount, err := pdfPageCount(pdfBytes)
+	require.NoError(t, err)
+	require.Equal(t, 3, pageCount)
+	pages, err := resolvePDFPageRange("1-3", 3)
+	require.NoError(t, err)
+	require.Equal(t, pdfPageRange{FirstPage: 1, LastPage: 3, Count: 3}, pages)
+	_, err = resolvePDFPageRange("4", 3)
+	require.EqualError(t, err, "Page 4 exceeds the PDF page count of 3.")
+	require.Equal(t, "cell-2", notebookResultCellID(map[string]any{}, 2, ""))
+	require.Equal(t, "fallback", notebookResultCellID(map[string]any{}, 2, "fallback"))
+	value, ok := notebookInt(float64(3))
+	require.True(t, ok)
+	require.Equal(t, 3, value)
+	_, ok = notebookInt("bad")
+	require.False(t, ok)
+	require.Equal(t, "python", notebookLanguage(map[string]any{}))
+	require.Equal(t, "go", notebookLanguage(map[string]any{
+		"metadata": map[string]any{
+			"language_info": map[string]any{"name": "go"},
+		},
+	}))
+}
+
+func TestPDFHelpersCoverRemainingBranches(t *testing.T) {
+	t.Parallel()
+	pdftoppmTestMu.Lock()
+	t.Cleanup(func() {
+		pdftoppmTestMu.Unlock()
+	})
+	_, err := pdfPageCount([]byte("not-a-pdf"))
+	require.ErrorContains(t, err, "failed to create PDF reader")
+	rangeOne, err := resolvePDFPageRange("2", 4)
+	require.NoError(t, err)
+	require.Equal(t, pdfPageRange{FirstPage: 2, LastPage: 2, Count: 1}, rangeOne)
+	_, err = resolvePDFPageRange("bad", 4)
+	require.EqualError(t, err, `Invalid pages parameter: "bad". Use formats like "1-5", "3", or "10-20". Pages are 1-indexed.`)
+	_, err = resolvePDFPageRange("6-", 4)
+	require.EqualError(t, err, `Page range "6-" is outside the PDF page count of 4.`)
+	_, err = resolvePDFPageRange("5-6", 4)
+	require.EqualError(t, err, `Page range "5-6" exceeds the PDF page count of 4.`)
+	scriptDir := t.TempDir()
+	successScript := filepath.Join(scriptDir, "pdftoppm-success")
+	require.NoError(t, os.WriteFile(successScript, []byte("#!/bin/bash\nprefix=\"${@: -1}\"\ntouch \"${prefix}-1.jpg\" \"${prefix}-2.jpg\"\n"), 0o755))
+	oldLookPath := pdftoppmLookPath
+	oldPath := pdftoppmPath
+	oldOnce := pdftoppmOnce
+	pdftoppmLookPath = func(string) (string, error) {
+		return successScript, nil
+	}
+	pdftoppmPath = ""
+	pdftoppmOnce = sync.Once{}
+	t.Cleanup(func() {
+		pdftoppmLookPath = oldLookPath
+		pdftoppmPath = oldPath
+		pdftoppmOnce = oldOnce
+	})
+	path, err := pdftoppmBinary()
+	require.NoError(t, err)
+	require.Equal(t, successScript, path)
+	outputDir, count, err := extractPDFPages(filepath.Join(t.TempDir(), "fake.pdf"), pdfPageRange{
+		FirstPage: 1,
+		LastPage:  2,
+		Count:     2,
+	})
+	require.NoError(t, err)
+	require.Equal(t, 2, count)
+	defer os.RemoveAll(outputDir)
+	_, statErr := os.Stat(filepath.Join(outputDir, "page-1.jpg"))
+	require.NoError(t, statErr)
+	noImageScript := filepath.Join(scriptDir, "pdftoppm-empty")
+	require.NoError(t, os.WriteFile(noImageScript, []byte("#!/bin/bash\nexit 0\n"), 0o755))
+	pdftoppmPath = noImageScript
+	_, _, err = extractPDFPages(filepath.Join(t.TempDir(), "fake.pdf"), pdfPageRange{
+		FirstPage: 1,
+		LastPage:  1,
+		Count:     1,
+	})
+	require.EqualError(t, err, "failed to extract PDF pages: no rendered page images were produced")
+	failScript := filepath.Join(scriptDir, "pdftoppm-fail")
+	require.NoError(t, os.WriteFile(failScript, []byte("#!/bin/bash\necho render failed >&2\nexit 1\n"), 0o755))
+	pdftoppmPath = failScript
+	_, _, err = extractPDFPages(filepath.Join(t.TempDir(), "fake.pdf"), pdfPageRange{
+		FirstPage: 1,
+		LastPage:  1,
+		Count:     1,
+	})
+	require.EqualError(t, err, "failed to extract PDF pages: render failed")
+}
+
+func TestExtractPDFPagesFailsWhenPdftoppmIsUnavailable(t *testing.T) {
+	t.Parallel()
+	pdftoppmTestMu.Lock()
+	t.Cleanup(func() {
+		pdftoppmTestMu.Unlock()
+	})
+	oldLookPath := pdftoppmLookPath
+	oldPath := pdftoppmPath
+	oldOnce := pdftoppmOnce
+	pdftoppmLookPath = func(string) (string, error) {
+		return "", errors.New("not found")
+	}
+	pdftoppmPath = ""
+	pdftoppmOnce = sync.Once{}
+	t.Cleanup(func() {
+		pdftoppmLookPath = oldLookPath
+		pdftoppmPath = oldPath
+		pdftoppmOnce = oldOnce
+	})
+	_, _, err := extractPDFPages(filepath.Join(t.TempDir(), "missing.pdf"), pdfPageRange{
+		FirstPage: 1,
+		LastPage:  1,
+		Count:     1,
+	})
+	require.EqualError(t, err, "pdftoppm is not installed. Install poppler-utils (e.g. `brew install poppler` or `apt-get install poppler-utils`) to enable PDF page rendering.")
+}
+
+func TestReadTaskSnapshotHandlesMissingOutputFile(t *testing.T) {
+	t.Parallel()
+	runtime := newToolRuntime(t.TempDir(), maxEditableFileSize)
+	exitCode := 17
+	runtime.taskState.tasks["task-1"] = &backgroundTask{
+		ID:         "task-1",
+		Command:    "echo hi",
+		Type:       toolBash,
+		OutputPath: filepath.Join(t.TempDir(), "missing.log"),
+		Status:     "completed",
+		ExitCode:   &exitCode,
+	}
+	snapshot, err := readTaskSnapshot(runtime, "task-1")
+	require.NoError(t, err)
+	require.Equal(t, "task-1", snapshot.TaskID)
+	require.Equal(t, toolBash, snapshot.TaskType)
+	require.Equal(t, "completed", snapshot.Status)
+	require.Equal(t, "echo hi", snapshot.Description)
+	require.Empty(t, snapshot.Output)
+	require.NotNil(t, snapshot.ExitCode)
+	require.Equal(t, 17, *snapshot.ExitCode)
+}
+
+func TestReadTaskSnapshotReturnsReadErrorForDirectoryOutputPath(t *testing.T) {
+	t.Parallel()
+	outputDir := t.TempDir()
+	runtime := newToolRuntime(t.TempDir(), maxEditableFileSize)
+	runtime.taskState.tasks["task-1"] = &backgroundTask{
+		ID:         "task-1",
+		Command:    "echo hi",
+		Type:       toolBash,
+		OutputPath: outputDir,
+		Status:     "completed",
+	}
+	_, err := readTaskSnapshot(runtime, "task-1")
+	require.Error(t, err)
+}
+
+type errReader struct {
+	err error
+}
+
+func (r errReader) Read(_ []byte) (int, error) {
+	return 0, r.err
+}
+
+type stubTool struct {
+	decl *tool.Declaration
+}
+
+func (s stubTool) Declaration() *tool.Declaration {
+	return s.decl
+}
+
+type roundTripFunc func(*http.Request) (*http.Response, error)
+
+func (f roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) {
+	return f(req)
+}
+
+func newTestPDF(t *testing.T, pages []string) []byte {
+	t.Helper()
+	pdfDoc := fpdf.New("P", "mm", "A4", "")
+	pdfDoc.SetFont("Helvetica", "", 12)
+	for _, pageText := range pages {
+		pdfDoc.AddPage()
+		pdfDoc.Cell(40, 10, pageText)
+	}
+	var buf bytes.Buffer
+	require.NoError(t, pdfDoc.Output(&buf))
+	return buf.Bytes()
+}
+
+func writeExecutableFile(t *testing.T, dir string, name string, content string) string {
+	t.Helper()
+	path := filepath.Join(dir, name)
+	require.NoError(t, os.WriteFile(path, []byte(content), 0o755))
+	return path
+}
+
+var tinyPNGBytes = []byte{
+	0x89, 0x50, 0x4e, 0x47, 0x0d, 0x0a, 0x1a, 0x0a,
+	0x00, 0x00, 0x00, 0x0d, 0x49, 0x48, 0x44, 0x52,
+	0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01,
+	0x08, 0x06, 0x00, 0x00, 0x00, 0x1f, 0x15, 0xc4,
+	0x89, 0x00, 0x00, 0x00, 0x0d, 0x49, 0x44, 0x41,
+	0x54, 0x78, 0x9c, 0x63, 0xf8, 0xcf, 0xc0, 0x00,
+	0x00, 0x03, 0x01, 0x01, 0x00, 0xc9, 0xfe, 0x92,
+	0xef, 0x00, 0x00, 0x00, 0x00, 0x49, 0x45, 0x4e,
+	0x44, 0xae, 0x42, 0x60, 0x82,
+}
+
+var ripgrepTestMu sync.Mutex
+var pdftoppmTestMu sync.Mutex
+
+func mustCallableTool(t *testing.T, tools []tool.Tool, name string) tool.CallableTool {
+	t.Helper()
+	for _, candidate := range tools {
+		if candidate.Declaration() == nil || candidate.Declaration().Name != name {
+			continue
+		}
+		callable, ok := candidate.(tool.CallableTool)
+		require.True(t, ok)
+		return callable
+	}
+	t.Fatalf("tool %s not found", name)
+	return nil
+}
+
+func callToolRaw(target tool.CallableTool, input any) (any, error) {
+	args, err := json.Marshal(input)
+	if err != nil {
+		return nil, err
+	}
+	return target.Call(context.Background(), args)
+}
+
+func callToolRawWithContext(target tool.CallableTool, ctx context.Context, input any) (any, error) {
+	args, err := json.Marshal(input)
+	if err != nil {
+		return nil, err
+	}
+	return target.Call(ctx, args)
+}
+
+func callToolAs[T any](t *testing.T, target tool.CallableTool, input any) T {
+	t.Helper()
+	out, err := callToolRaw(target, input)
+	require.NoError(t, err)
+	data, err := json.Marshal(out)
+	require.NoError(t, err)
+	var decoded T
+	require.NoError(t, json.Unmarshal(data, &decoded))
+	return decoded
+}
+
+func toolNames(tools []tool.Tool) []string {
+	names := make([]string, 0, len(tools))
+	for _,
candidate := range tools { + if candidate == nil || candidate.Declaration() == nil { + continue + } + names = append(names, candidate.Declaration().Name) + } + return names +} + +func intPtr(value int) *int { + return &value +} + +func boolPtr(value bool) *bool { + return &value +} + +func strPtr(value string) *string { + return &value +} + +func derefString(value *string) string { + if value == nil { + return "" + } + return *value +} + +func strconvString(value int) string { + return strconv.Itoa(value) +} + +func withRipgrepForTest(lookPath func(string) (string, error)) func() { + ripgrepTestMu.Lock() + oldLookPath := ripgrepLookPath + oldPath := ripgrepPath + ripgrepLookPath = lookPath + ripgrepPath = "" + ripgrepOnce = sync.Once{} + return func() { + ripgrepLookPath = oldLookPath + ripgrepPath = oldPath + ripgrepOnce = sync.Once{} + ripgrepTestMu.Unlock() + } +} diff --git a/tool/claudecode/common.go b/tool/claudecode/common.go new file mode 100644 index 000000000..0d7e4b20c --- /dev/null +++ b/tool/claudecode/common.go @@ -0,0 +1,353 @@ +// +// Tencent is pleased to support the open source community by making +// trpc-agent-go available. +// +// Copyright (C) 2025 Tencent. All rights reserved. +// +// trpc-agent-go is licensed under the Apache License Version 2.0. 
+// + +package claudecode + +import ( + "bytes" + "encoding/base64" + "fmt" + "io" + "net/http" + "net/url" + "path/filepath" + "slices" + "strings" + stdunicode "unicode" + + "golang.org/x/net/html" + textunicode "golang.org/x/text/encoding/unicode" + "golang.org/x/text/transform" +) + +func normalizePath(baseDir string, raw string) (string, string, error) { + pathValue := strings.TrimSpace(raw) + if pathValue == "" { + return "", "", fmt.Errorf("path is required") + } + cleanBase, err := filepath.Abs(baseDir) + if err != nil { + return "", "", err + } + if filepath.IsAbs(pathValue) { + cleanPath := filepath.Clean(pathValue) + rel, err := filepath.Rel(cleanBase, cleanPath) + if err != nil { + return "", "", err + } + if rel == ".." || strings.HasPrefix(rel, ".."+string(filepath.Separator)) { + return "", "", fmt.Errorf("path is outside base_dir: %s", raw) + } + return filepath.ToSlash(filepath.Clean(rel)), cleanPath, nil + } + cleanPath := filepath.Clean(pathValue) + if cleanPath == ".." 
|| strings.HasPrefix(cleanPath, ".."+string(filepath.Separator)) { + return "", "", fmt.Errorf("path is outside base_dir: %s", raw) + } + absPath := filepath.Join(cleanBase, cleanPath) + return filepath.ToSlash(filepath.Clean(cleanPath)), absPath, nil +} + +func (r *runtime) currentBaseDir() string { + r.mu.RLock() + defer r.mu.RUnlock() + return r.baseDir +} + +func (r *runtime) setBaseDir(baseDir string) { + r.mu.Lock() + r.baseDir = baseDir + r.mu.Unlock() +} + +func relativePath(baseDir string, absPath string) string { + baseAbs, err := filepath.Abs(baseDir) + if err != nil { + return filepath.ToSlash(filepath.Clean(absPath)) + } + rel, err := filepath.Rel(baseAbs, absPath) + if err != nil { + return filepath.ToSlash(filepath.Clean(absPath)) + } + return filepath.ToSlash(filepath.Clean(rel)) +} + +func readHTTPBody( + resp *http.Response, + maxContentLength int, + maxTotalContentLength int, +) ([]byte, error) { + if resp == nil || resp.Body == nil { + return nil, nil + } + limit := maxContentLength + if maxTotalContentLength > 0 && (limit == 0 || maxTotalContentLength < limit) { + limit = maxTotalContentLength + } + if limit <= 0 { + limit = 1 << 20 + } + reader := io.LimitReader(resp.Body, int64(limit)+1) + body, err := io.ReadAll(reader) + if err != nil { + return nil, err + } + if len(body) > limit { + return nil, fmt.Errorf("response body exceeded limit of %d bytes", limit) + } + return body, nil +} + +func countLines(content string) int { + if content == "" { + return 0 + } + parts := strings.Split(content, "\n") + if strings.HasSuffix(content, "\n") { + return len(parts) - 1 + } + return len(parts) +} + +func splitTextLines(content string) []string { + if content == "" { + return []string{} + } + lines := strings.Split(content, "\n") + if strings.HasSuffix(content, "\n") { + return lines[:len(lines)-1] + } + return lines +} + +func sliceLines(content string, offset int, limit *int) (string, int, int) { + lines := splitTextLines(content) + totalLines := 
len(lines) + startLine := offset + if startLine <= 0 { + startLine = 1 + } + startIdx := startLine - 1 + if startIdx > totalLines { + startIdx = totalLines + } + endIdx := totalLines + if limit != nil && *limit >= 0 && startIdx+*limit < endIdx { + endIdx = startIdx + *limit + } + sliced := lines[startIdx:endIdx] + result := strings.Join(sliced, "\n") + if len(sliced) > 0 && strings.HasSuffix(content, "\n") && endIdx == totalLines { + result += "\n" + } + return result, startLine, totalLines +} + +func normalizeNewlines(content string) string { + content = strings.ReplaceAll(content, "\r\n", "\n") + content = strings.ReplaceAll(content, "\r", "\n") + return content +} + +func detectLineEnding(raw []byte) string { + if bytes.Contains(raw, []byte("\r\n")) { + return "\r\n" + } + return "\n" +} + +func applyLineEnding(content string, lineEnding string) string { + if lineEnding == "\r\n" { + return strings.ReplaceAll(content, "\n", "\r\n") + } + return content +} + +func decodeTextBytes(raw []byte) (string, string, error) { + if len(raw) >= 2 && raw[0] == 0xff && raw[1] == 0xfe { + decoder := textunicode.UTF16(textunicode.LittleEndian, textunicode.ExpectBOM).NewDecoder() + decoded, _, err := transform.String(decoder, string(raw)) + if err != nil { + return "", "", err + } + return normalizeNewlines(decoded), "utf16le", nil + } + return normalizeNewlines(string(raw)), "utf8", nil +} + +func encodeTextBytes(content string, encoding string, lineEnding string) ([]byte, error) { + normalized := applyLineEnding(content, lineEnding) + if encoding == "utf16le" { + encoder := textunicode.UTF16(textunicode.LittleEndian, textunicode.UseBOM).NewEncoder() + encoded, _, err := transform.String(encoder, normalized) + if err != nil { + return nil, err + } + return []byte(encoded), nil + } + return []byte(normalized), nil +} + +func fileBase64(raw []byte) string { + return base64.StdEncoding.EncodeToString(raw) +} + +func isProbablyBinary(raw []byte) bool { + if len(raw) >= 2 && raw[0] 
== 0xff && raw[1] == 0xfe { + return false + } + for _, b := range raw { + if b == 0 { + return true + } + } + return false +} + +func buildStructuredPatch(oldContent string, newContent string) []patchHunk { + if oldContent == newContent { + return nil + } + oldLines := splitTextLines(oldContent) + newLines := splitTextLines(newContent) + prefix := 0 + for prefix < len(oldLines) && prefix < len(newLines) && oldLines[prefix] == newLines[prefix] { + prefix++ + } + oldSuffixLimit := len(oldLines) - prefix + newSuffixLimit := len(newLines) - prefix + suffix := 0 + for suffix < oldSuffixLimit && suffix < newSuffixLimit { + if oldLines[len(oldLines)-1-suffix] != newLines[len(newLines)-1-suffix] { + break + } + suffix++ + } + oldMid := oldLines[prefix : len(oldLines)-suffix] + newMid := newLines[prefix : len(newLines)-suffix] + lines := make([]string, 0, len(oldMid)+len(newMid)) + for _, line := range oldMid { + lines = append(lines, "-"+line) + } + for _, line := range newMid { + lines = append(lines, "+"+line) + } + oldStart := prefix + 1 + newStart := prefix + 1 + if len(oldLines) == 0 { + oldStart = 0 + } + if len(newLines) == 0 { + newStart = 0 + } + return []patchHunk{{ + OldStart: oldStart, + OldLines: len(oldMid), + NewStart: newStart, + NewLines: len(newMid), + Lines: lines, + }} +} + +func matchSearchDomainFilters( + rawURL string, + allowed []string, + blocked []string, +) bool { + host := searchURLHost(rawURL) + if host == "" { + return len(allowed) == 0 + } + for _, rule := range blocked { + if matchDomainRule(host, rule) { + return false + } + } + if len(allowed) == 0 { + return true + } + for _, rule := range allowed { + if matchDomainRule(host, rule) { + return true + } + } + return false +} + +func searchURLHost(rawURL string) string { + parsed, err := url.Parse(strings.TrimSpace(rawURL)) + if err != nil { + return "" + } + return strings.ToLower(parsed.Hostname()) +} + +func matchDomainRule(host string, rule string) bool { + cleanRule := 
strings.ToLower(strings.TrimSpace(rule)) + if cleanRule == "" { + return false + } + if strings.HasPrefix(cleanRule, "*.") { + suffix := strings.TrimPrefix(cleanRule, "*.") + return host == suffix || strings.HasSuffix(host, "."+suffix) + } + return host == cleanRule || strings.HasSuffix(host, "."+cleanRule) +} + +func extractHTMLText(raw []byte) string { + doc, err := html.Parse(bytes.NewReader(raw)) + if err != nil { + return strings.TrimSpace(string(raw)) + } + parts := make([]string, 0, 32) + var visit func(*html.Node) + visit = func(node *html.Node) { + if node.Type == html.ElementNode { + name := strings.ToLower(node.Data) + if name == "script" || name == "style" || name == "noscript" { + return + } + } + if node.Type == html.TextNode { + text := strings.TrimSpace(node.Data) + if text != "" { + parts = append(parts, collapseWhitespace(text)) + } + } + for child := node.FirstChild; child != nil; child = child.NextSibling { + visit(child) + } + } + visit(doc) + return strings.TrimSpace(strings.Join(parts, "\n")) +} + +func collapseWhitespace(raw string) string { + fields := strings.FieldsFunc(raw, func(r rune) bool { + return stdunicode.IsSpace(r) + }) + return strings.Join(fields, " ") +} + +func joinOutput(stdout string, stderr string) string { + switch { + case stdout == "": + return stderr + case stderr == "": + return stdout + default: + return stdout + "\n" + stderr + } +} + +func sortedCopy(items []string) []string { + out := append([]string{}, items...) + slices.Sort(out) + return out +} diff --git a/tool/claudecode/constants.go b/tool/claudecode/constants.go new file mode 100644 index 000000000..5129ecc7d --- /dev/null +++ b/tool/claudecode/constants.go @@ -0,0 +1,63 @@ +// +// Tencent is pleased to support the open source community by making +// trpc-agent-go available. +// +// Copyright (C) 2025 Tencent. All rights reserved. +// +// trpc-agent-go is licensed under the Apache License Version 2.0. 
+// + +package claudecode + +import ( + "os/exec" + "strings" + "sync" + "time" +) + +const ( + toolBash = "Bash" + toolRead = "Read" + toolWrite = "Write" + toolEdit = "Edit" + toolGlob = "Glob" + toolGrep = "Grep" + toolWebFetch = "WebFetch" + toolWebSearch = "WebSearch" + toolTaskStop = "TaskStop" + toolTaskOutput = "TaskOutput" + + defaultToolSetName = "claudecode" + defaultGrepHeadLimit = 250 + defaultGlobHeadLimit = 100 + toolNotebookEdit = "NotebookEdit" + defaultHTTPTimeout = 30 * time.Second + defaultBashTimeoutMs = 120_000 + maxBashTimeoutMs = 600_000 + maxEditableFileSize = 1024 * 1024 * 1024 + pdfInlineReadThreshold = 10 + pdfMaxPagesPerRead = 20 +) + +var ( + envGoogleAPIKey = strings.Join([]string{"GOOGLE", "API", "KEY"}, "_") + envGoogleEngineID = strings.Join([]string{"GOOGLE", "SEARCH", "ENGINE", "ID"}, "_") + ripgrepOnce sync.Once + ripgrepPath string + ripgrepLookPath = func(file string) (string, error) { return exec.LookPath(file) } + pdftoppmPath string + pdftoppmOnce sync.Once + pdftoppmLookPath = func(file string) (string, error) { + return exec.LookPath(file) + } +) + +var grepExcludedDirs = []string{ + ".git", + ".svn", + ".hg", + ".bzr", + ".jj", + ".sl", +} diff --git a/tool/claudecode/edit.go b/tool/claudecode/edit.go new file mode 100644 index 000000000..e2436b5b2 --- /dev/null +++ b/tool/claudecode/edit.go @@ -0,0 +1,45 @@ +// +// Tencent is pleased to support the open source community by making +// trpc-agent-go available. +// +// Copyright (C) 2025 Tencent. All rights reserved. +// +// trpc-agent-go is licensed under the Apache License Version 2.0. 
+// + +package claudecode + +import ( + "context" + "fmt" + + "trpc.group/trpc-go/trpc-agent-go/tool" + "trpc.group/trpc-go/trpc-agent-go/tool/function" +) + +func newEditTool(runtime *runtime) (tool.Tool, error) { + return function.NewFunctionTool( + func(_ context.Context, in editInput) (editOutput, error) { + baseDir := runtime.currentBaseDir() + _, absPath, err := normalizePath(baseDir, in.FilePath) + if err != nil { + return editOutput{}, err + } + runtime.fileState.mu.Lock() + defer runtime.fileState.mu.Unlock() + return editLocalFile(absPath, in, runtime) + }, + function.WithName(toolEdit), + function.WithDescription(editDescription()), + ), nil +} + +func editDescription() string { + return fmt.Sprintf(`Replace text inside an existing file. + +Usage: +- Always read the file with %s before editing it. +- Use this tool for targeted string replacements, not whole-file rewrites. Use %s when you want to replace the entire file. +- old_string must match the current file contents exactly. Missing matches, stale reads, or ambiguous matches are rejected. +- This tool does not edit notebooks. Use %s for .ipynb files.`, toolRead, toolWrite, toolNotebookEdit) +} diff --git a/tool/claudecode/file_state.go b/tool/claudecode/file_state.go new file mode 100644 index 000000000..7a1f11d2e --- /dev/null +++ b/tool/claudecode/file_state.go @@ -0,0 +1,349 @@ +// +// Tencent is pleased to support the open source community by making +// trpc-agent-go available. +// +// Copyright (C) 2025 Tencent. All rights reserved. +// +// trpc-agent-go is licensed under the Apache License Version 2.0. 
+// + +package claudecode + +import ( + "fmt" + "mime" + "net/http" + "os" + "path/filepath" + "regexp" + "strings" + "unicode" +) + +func readLocalFileSnapshot(absPath string, maxFileSize int64) (localFileSnapshot, error) { + info, err := os.Stat(absPath) + if err != nil { + if os.IsNotExist(err) { + return localFileSnapshot{Path: absPath}, nil + } + return localFileSnapshot{}, fmt.Errorf("stat file %q: %w", absPath, err) + } + if info.IsDir() { + return localFileSnapshot{}, fmt.Errorf("target path %q is a directory", absPath) + } + if maxFileSize > 0 && info.Size() > maxFileSize { + return localFileSnapshot{}, fmt.Errorf("file %q exceeds max size of %d bytes", absPath, maxFileSize) + } + raw, err := os.ReadFile(absPath) + if err != nil { + return localFileSnapshot{}, fmt.Errorf("read file %q: %w", absPath, err) + } + content, encoding, err := decodeTextBytes(raw) + if err != nil { + return localFileSnapshot{}, fmt.Errorf("decode file %q: %w", absPath, err) + } + mediaType := mime.TypeByExtension(strings.ToLower(filepath.Ext(absPath))) + if mediaType == "" { + mediaType = http.DetectContentType(raw) + } + return localFileSnapshot{ + Exists: true, + Path: absPath, + Raw: raw, + Content: content, + Mode: info.Mode(), + Timestamp: info.ModTime().UnixMilli(), + Encoding: encoding, + LineEnding: detectLineEnding(raw), + MediaType: mediaType, + OriginalSize: info.Size(), + }, nil +} + +func ensureWriteAllowed( + absPath string, + snapshot localFileSnapshot, + state *fileState, +) error { + view, ok := state.views[absPath] + if !ok || view.IsPartialView { + return fmt.Errorf("File has not been read yet. Read it first before writing to it.") + } + if snapshot.Timestamp > view.Timestamp { + isFullView := view.Offset == nil && view.Limit == nil && strings.TrimSpace(view.Pages) == "" + if !isFullView || snapshot.Content != view.Content { + return fmt.Errorf("File has been modified since read, either by the user or by a linter. 
Read it again before attempting to write it.") + } + } + return nil +} + +func writeLocalFile( + absPath string, + content string, + mode os.FileMode, + encoding string, + lineEnding string, +) error { + parentDir := filepath.Dir(absPath) + if err := os.MkdirAll(parentDir, 0o755); err != nil { + return fmt.Errorf("create directory for %q: %w", absPath, err) + } + fileMode := mode + if fileMode == 0 { + fileMode = 0o644 + } + encoded, err := encodeTextBytes(content, encoding, lineEnding) + if err != nil { + return err + } + if err := os.WriteFile(absPath, encoded, fileMode); err != nil { + return fmt.Errorf("write file %q: %w", absPath, err) + } + return nil +} + +func storeReadView( + state *fileState, + absPath string, + content string, + timestamp int64, + offset *int, + limit *int, + pages string, + isPartial bool, + fromRead bool, +) { + state.views[absPath] = fileView{ + Content: content, + Timestamp: timestamp, + Offset: offset, + Limit: limit, + Pages: pages, + IsPartialView: isPartial, + FromRead: fromRead, + } +} + +func matchesReadView( + view fileView, + offset *int, + limit *int, + pages string, +) bool { + if !view.FromRead { + return false + } + if !intPtrsEqual(view.Offset, offset) { + return false + } + if !intPtrsEqual(view.Limit, limit) { + return false + } + return strings.TrimSpace(view.Pages) == strings.TrimSpace(pages) +} + +func intPtrsEqual(left *int, right *int) bool { + switch { + case left == nil && right == nil: + return true + case left == nil || right == nil: + return false + default: + return *left == *right + } +} + +func normalizeQuotes(raw string) string { + replacer := strings.NewReplacer( + "‘", "'", + "’", "'", + "“", "\"", + "”", "\"", + ) + return replacer.Replace(raw) +} + +func findActualString(fileContent string, searchString string) string { + if strings.Contains(fileContent, searchString) { + return searchString + } + var builder strings.Builder + for _, r := range searchString { + switch r { + case '\'': + 
builder.WriteString("['‘’]") + case '"': + builder.WriteString("[\"“”]") + default: + builder.WriteString(regexp.QuoteMeta(string(r))) + } + } + re, err := regexp.Compile(builder.String()) + if err != nil { + return "" + } + return re.FindString(fileContent) +} + +func preserveQuoteStyle(oldString string, actualOldString string, newString string) string { + if oldString == actualOldString { + return newString + } + hasDoubleQuotes := strings.Contains(actualOldString, "“") || strings.Contains(actualOldString, "”") + hasSingleQuotes := strings.Contains(actualOldString, "‘") || strings.Contains(actualOldString, "’") + result := newString + if hasDoubleQuotes { + result = applyCurlyDoubleQuotes(result) + } + if hasSingleQuotes { + result = applyCurlySingleQuotes(result) + } + return result +} + +func applyCurlyDoubleQuotes(raw string) string { + chars := []rune(raw) + out := make([]rune, 0, len(chars)) + for idx, r := range chars { + if r != '"' { + out = append(out, r) + continue + } + if isOpeningQuote(chars, idx) { + out = append(out, '“') + continue + } + out = append(out, '”') + } + return string(out) +} + +func applyCurlySingleQuotes(raw string) string { + chars := []rune(raw) + out := make([]rune, 0, len(chars)) + for idx, r := range chars { + if r != '\'' { + out = append(out, r) + continue + } + prevIsLetter := idx > 0 && unicode.IsLetter(chars[idx-1]) + nextIsLetter := idx+1 < len(chars) && unicode.IsLetter(chars[idx+1]) + if prevIsLetter && nextIsLetter { + out = append(out, '’') + continue + } + if isOpeningQuote(chars, idx) { + out = append(out, '‘') + continue + } + out = append(out, '’') + } + return string(out) +} + +func isOpeningQuote(chars []rune, idx int) bool { + if idx == 0 { + return true + } + prev := chars[idx-1] + return unicode.IsSpace(prev) || strings.ContainsRune("([{", prev) +} + +func editLocalFile( + absPath string, + in editInput, + runtime *runtime, +) (editOutput, error) { + snapshot, err := readLocalFileSnapshot(absPath, 
maxEditableFileSize) + if err != nil { + return editOutput{}, err + } + if !snapshot.Exists { + if in.OldString != "" { + return editOutput{}, fmt.Errorf("File does not exist: %s", relativePath(runtime.currentBaseDir(), absPath)) + } + if err := writeLocalFile(absPath, in.NewString, 0, "utf8", "\n"); err != nil { + return editOutput{}, err + } + current, err := readLocalFileSnapshot(absPath, runtime.maxFileSize) + if err != nil { + return editOutput{}, err + } + storeReadView(runtime.fileState, absPath, current.Content, current.Timestamp, nil, nil, "", false, false) + return writeOutputToEditOutput(absPath, in, nil, in.NewString), nil + } + if strings.HasSuffix(strings.ToLower(absPath), ".ipynb") { + return editOutput{}, fmt.Errorf("File is a Jupyter Notebook. Use the %s tool to edit this file.", toolNotebookEdit) + } + if isProbablyBinary(snapshot.Raw) { + return editOutput{}, fmt.Errorf("This tool cannot edit binary files.") + } + if in.OldString == in.NewString { + return editOutput{}, fmt.Errorf("No changes to make: old_string and new_string are exactly the same.") + } + if err := ensureWriteAllowed(absPath, snapshot, runtime.fileState); err != nil { + return editOutput{}, err + } + if in.OldString == "" { + if strings.TrimSpace(snapshot.Content) != "" { + return editOutput{}, fmt.Errorf("Cannot create new file - file already exists.") + } + if err := writeLocalFile(absPath, in.NewString, snapshot.Mode, snapshot.Encoding, snapshot.LineEnding); err != nil { + return editOutput{}, err + } + current, err := readLocalFileSnapshot(absPath, runtime.maxFileSize) + if err != nil { + return editOutput{}, err + } + storeReadView(runtime.fileState, absPath, current.Content, current.Timestamp, nil, nil, "", false, false) + return writeOutputToEditOutput(absPath, in, &snapshot.Content, in.NewString), nil + } + actualOldString := findActualString(snapshot.Content, in.OldString) + if actualOldString == "" { + return editOutput{}, fmt.Errorf("String to replace not found in 
file.\nString: %s", in.OldString) + } + actualNewString := preserveQuoteStyle(in.OldString, actualOldString, in.NewString) + matchCount := strings.Count(snapshot.Content, actualOldString) + if matchCount > 1 && !in.ReplaceAll { + return editOutput{}, fmt.Errorf("Found %d matches of the string to replace, but replace_all is false. To replace all occurrences, set replace_all to true. To replace only one occurrence, please provide more context to uniquely identify the instance.\nString: %s", matchCount, in.OldString) + } + replacements := 1 + if in.ReplaceAll { + replacements = -1 + } + updated := strings.Replace(snapshot.Content, actualOldString, actualNewString, replacements) + if err := writeLocalFile(absPath, updated, snapshot.Mode, snapshot.Encoding, snapshot.LineEnding); err != nil { + return editOutput{}, err + } + current, err := readLocalFileSnapshot(absPath, runtime.maxFileSize) + if err != nil { + return editOutput{}, err + } + storeReadView(runtime.fileState, absPath, current.Content, current.Timestamp, nil, nil, "", false, false) + return editOutput{ + FilePath: absPath, + OldString: in.OldString, + NewString: in.NewString, + OriginalFile: snapshot.Content, + StructuredPatch: buildStructuredPatch(snapshot.Content, updated), + UserModified: false, + ReplaceAll: in.ReplaceAll, + }, nil +} + +func writeOutputToEditOutput(absPath string, in editInput, oldContent *string, newContent string) editOutput { + originalFile := "" + if oldContent != nil { + originalFile = *oldContent + } + return editOutput{ + FilePath: absPath, + OldString: in.OldString, + NewString: in.NewString, + OriginalFile: originalFile, + StructuredPatch: buildStructuredPatch(originalFile, newContent), + UserModified: false, + ReplaceAll: in.ReplaceAll, + } +} diff --git a/tool/claudecode/glob.go b/tool/claudecode/glob.go new file mode 100644 index 000000000..ccec0bb71 --- /dev/null +++ b/tool/claudecode/glob.go @@ -0,0 +1,88 @@ +// +// Tencent is pleased to support the open source community 
by making +// trpc-agent-go available. +// +// Copyright (C) 2025 Tencent. All rights reserved. +// +// trpc-agent-go is licensed under the Apache License Version 2.0. +// + +package claudecode + +import ( + "context" + "fmt" + "os" + "path/filepath" + "time" + + "github.com/bmatcuk/doublestar/v4" + "trpc.group/trpc-go/trpc-agent-go/tool" + "trpc.group/trpc-go/trpc-agent-go/tool/function" +) + +func newGlobTool(runtime *runtime) (tool.Tool, error) { + return function.NewFunctionTool( + func(_ context.Context, in globInput) (globOutput, error) { + baseDir := runtime.currentBaseDir() + start := time.Now() + searchDir := baseDir + searchRel := "" + if in.Path != "" { + relPath, absPath, err := normalizePath(baseDir, in.Path) + if err != nil { + return globOutput{}, err + } + info, err := os.Stat(absPath) + if err != nil { + if os.IsNotExist(err) { + return globOutput{}, fmt.Errorf("Directory does not exist: %s", in.Path) + } + return globOutput{}, err + } + if !info.IsDir() { + return globOutput{}, fmt.Errorf("Path is not a directory: %s", in.Path) + } + searchDir = absPath + searchRel = relPath + } + matches, err := doublestar.Glob(os.DirFS(searchDir), in.Pattern, doublestar.WithCaseInsensitive()) + if err != nil { + return globOutput{}, err + } + sorted := sortedCopy(matches) + truncated := false + if len(sorted) > defaultGlobHeadLimit { + sorted = sorted[:defaultGlobHeadLimit] + truncated = true + } + filenames := make([]string, 0, len(sorted)) + for _, match := range sorted { + fullRel := match + if searchRel != "" { + fullRel = filepath.ToSlash(filepath.Join(searchRel, match)) + } + filenames = append(filenames, filepath.ToSlash(filepath.Clean(fullRel))) + } + return globOutput{ + DurationMs: max(time.Since(start).Milliseconds(), 1), + NumFiles: len(filenames), + Filenames: filenames, + Truncated: truncated, + }, nil + }, + function.WithName(toolGlob), + function.WithDescription(globDescription()), + ), nil +} + +func globDescription() string { + return 
fmt.Sprintf(`Fast file pattern matching for workspace paths. + +Usage: +- Use %s to find files by name or path pattern. +- pattern uses doublestar-style globs such as "*.go" or "**/*.ts". +- path optionally narrows the search to a specific directory. +- Results are sorted, and the tool returns at most %d filenames per call. +- Use %s instead when you need to search file contents rather than file names.`, toolGlob, defaultGlobHeadLimit, toolGrep) +} diff --git a/tool/claudecode/go.mod b/tool/claudecode/go.mod new file mode 100644 index 000000000..84dbdd422 --- /dev/null +++ b/tool/claudecode/go.mod @@ -0,0 +1,28 @@ +module trpc.group/trpc-go/trpc-agent-go/tool/claudecode + +go 1.24.1 + +replace trpc.group/trpc-go/trpc-agent-go => ../.. + +require ( + github.com/bmatcuk/doublestar/v4 v4.9.1 + github.com/go-pdf/fpdf v0.9.0 + github.com/google/uuid v1.6.0 + github.com/ledongthuc/pdf v0.0.0-20250511090121-5959a4027728 + github.com/stretchr/testify v1.11.1 + golang.org/x/net v0.34.0 + golang.org/x/text v0.21.0 + trpc.group/trpc-go/trpc-agent-go v0.5.0 +) + +require ( + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/kr/pretty v0.3.1 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/rogpeppe/go-internal v1.12.0 // indirect + go.uber.org/multierr v1.10.0 // indirect + go.uber.org/zap v1.27.0 // indirect + gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect + trpc.group/trpc-go/trpc-a2a-go v0.2.5 // indirect +) diff --git a/tool/claudecode/go.sum b/tool/claudecode/go.sum new file mode 100644 index 000000000..a57ca9e9d --- /dev/null +++ b/tool/claudecode/go.sum @@ -0,0 +1,43 @@ +github.com/bmatcuk/doublestar/v4 v4.9.1 h1:X8jg9rRZmJd4yRy7ZeNDRnM+T3ZfHv15JiBJ/avrEXE= +github.com/bmatcuk/doublestar/v4 v4.9.1/go.mod h1:xBQ8jztBU6kakFMg+8WGxn0c6z1fTSPVIjEY1Wr7jzc= +github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= +github.com/davecgh/go-spew v1.1.1 
h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/go-pdf/fpdf v0.9.0 h1:PPvSaUuo1iMi9KkaAn90NuKi+P4gwMedWPHhj8YlJQw= +github.com/go-pdf/fpdf v0.9.0/go.mod h1:oO8N111TkmKb9D7VvWGLvLJlaZUQVPM+6V42pp3iV4Y= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= +github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= +github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= +github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= +github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/ledongthuc/pdf v0.0.0-20250511090121-5959a4027728 h1:QwWKgMY28TAXaDl+ExRDqGQltzXqN/xypdKP86niVn8= +github.com/ledongthuc/pdf v0.0.0-20250511090121-5959a4027728/go.mod h1:1fEHWurg7pvf5SG6XNE5Q8UZmOwex51Mkx3SLhrW5B4= +github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= +github.com/rogpeppe/go-internal v1.12.0 h1:exVL4IDcn6na9z1rAb56Vxr+CgyK3nn3O+epU5NdKM8= +github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99yedzYV+kq4uf4= +github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= +github.com/stretchr/testify v1.11.1/go.mod 
h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= +go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= +go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= +go.uber.org/multierr v1.10.0 h1:S0h4aNzvfcFsC3dRF1jLoaov7oRaKqRGC/pUEJ2yvPQ= +go.uber.org/multierr v1.10.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y= +go.uber.org/zap v1.27.0 h1:aJMhYGrd5QSmlpLMr2MftRKl7t8J8PTZPA732ud/XR8= +go.uber.org/zap v1.27.0/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E= +golang.org/x/net v0.34.0 h1:Mb7Mrk043xzHgnRM88suvJFwzVrRfHEHJEl5/71CKw0= +golang.org/x/net v0.34.0/go.mod h1:di0qlW3YNM5oh6GqDGQr92MyTozJPmybPK4Ev/Gm31k= +golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo= +golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +trpc.group/trpc-go/trpc-a2a-go v0.2.5 h1:X3pAlWD128LaS9TtXsUDZoJWPVuPZDkZKUecKRxmWn4= +trpc.group/trpc-go/trpc-a2a-go v0.2.5/go.mod h1:Gtytau9Uoc3oPo/dpHvKit+tQn9Qlk5XFG1RiZTGqfk= diff --git a/tool/claudecode/grep.go b/tool/claudecode/grep.go new file mode 100644 index 000000000..da252f97c --- /dev/null +++ b/tool/claudecode/grep.go @@ -0,0 +1,732 @@ +// +// Tencent is pleased to support the open source community by making +// trpc-agent-go available. +// +// Copyright (C) 2025 Tencent. All rights reserved. +// +// trpc-agent-go is licensed under the Apache License Version 2.0. 
+// + +package claudecode + +import ( + "context" + "fmt" + "os" + "path/filepath" + "regexp" + "slices" + "strconv" + "strings" + + "github.com/bmatcuk/doublestar/v4" + "trpc.group/trpc-go/trpc-agent-go/tool" + "trpc.group/trpc-go/trpc-agent-go/tool/function" +) + +func newGrepTool(runtime *runtime) (tool.Tool, error) { + return function.NewFunctionTool( + func(ctx context.Context, in grepInput) (grepOutput, error) { + baseDir := runtime.currentBaseDir() + if out, handled, err := runLocalRipgrep(ctx, baseDir, in); handled { + return out, err + } + return runFallbackGrep(runtime, baseDir, in) + }, + function.WithName(toolGrep), + function.WithDescription(grepDescription()), + ), nil +} + +func grepDescription() string { + return fmt.Sprintf(`Powerful repository search built on ripgrep when available, with a workspace-safe fallback. + +Usage: +- ALWAYS use %s for repository search tasks. NEVER invoke grep or rg through %s for normal code search. +- Supports regular expressions. +- Use glob or type to narrow the file set. +- output_mode supports "content", "files_with_matches", and "count". +- Use multiline=true for patterns that must match across line boundaries. +- Use A, B, C, or context to request surrounding lines in content mode. 
+- Use head_limit and offset to page through large result sets.`, toolGrep, toolBash) +} + +func runFallbackGrep(runtime *runtime, baseDir string, in grepInput) (grepOutput, error) { + mode := strings.TrimSpace(in.OutputMode) + if mode == "" { + mode = "files_with_matches" + } + re, err := compileGrepPattern(in) + if err != nil { + return grepOutput{}, err + } + candidates, err := collectGrepCandidates(baseDir, in.Path, in.Glob, in.Type) + if err != nil { + return grepOutput{}, err + } + collector := newFallbackGrepCollector(mode) + for _, absPath := range candidates { + if err := collectFallbackGrepMatch(runtime, baseDir, absPath, re, in, collector); err != nil { + return grepOutput{}, err + } + } + return finalizeFallbackGrepOutput(baseDir, in, collector), nil +} + +type fallbackGrepCollector struct { + mode string + contentLines []string + countLines []string + fileMatches []string +} + +func newFallbackGrepCollector(mode string) *fallbackGrepCollector { + return &fallbackGrepCollector{ + mode: mode, + contentLines: make([]string, 0), + countLines: make([]string, 0), + fileMatches: make([]string, 0), + } +} + +func compileGrepPattern(in grepInput) (*regexp.Regexp, error) { + pattern := in.Pattern + if in.IgnoreCase != nil && *in.IgnoreCase { + pattern = "(?i)" + pattern + } + if in.Multiline { + pattern = "(?s)" + pattern + } + return regexp.Compile(pattern) +} + +func collectFallbackGrepMatch( + runtime *runtime, + baseDir string, + absPath string, + re *regexp.Regexp, + in grepInput, + collector *fallbackGrepCollector, +) error { + snapshot, err := readLocalFileSnapshot(absPath, runtime.maxFileSize) + if err != nil || isProbablyBinary(snapshot.Raw) { + return nil + } + relPath := relativePath(baseDir, absPath) + if in.Multiline { + return collectFallbackMultilineMatch(snapshot.Content, relPath, re, in, collector) + } + return collectFallbackLineMatch(snapshot.Content, relPath, re, in, collector) +} + +func collectFallbackMultilineMatch( + content string, + 
relPath string, + re *regexp.Regexp, + in grepInput, + collector *fallbackGrepCollector, +) error { + matches := re.FindAllStringIndex(content, -1) + if len(matches) == 0 { + return nil + } + lines := splitTextLines(content) + switch collector.mode { + case "count": + collector.countLines = append(collector.countLines, fmt.Sprintf("%s:%d", relPath, len(matches))) + case "content": + matchedIndexes := multilineMatchLineIndexes(content, matches, len(lines)) + appendGrepContentLines(&collector.contentLines, relPath, lines, expandContextLines(matchedIndexes, len(lines), in), showGrepLineNumbers(in)) + default: + collector.fileMatches = append(collector.fileMatches, relPath) + } + return nil +} + +func collectFallbackLineMatch( + content string, + relPath string, + re *regexp.Regexp, + in grepInput, + collector *fallbackGrepCollector, +) error { + lines := splitTextLines(content) + matchedIndexes := make([]int, 0) + for idx, line := range lines { + if re.MatchString(line) { + matchedIndexes = append(matchedIndexes, idx) + } + } + if len(matchedIndexes) == 0 { + return nil + } + switch collector.mode { + case "count": + collector.countLines = append(collector.countLines, fmt.Sprintf("%s:%d", relPath, len(matchedIndexes))) + case "content": + appendGrepContentLines(&collector.contentLines, relPath, lines, expandContextLines(matchedIndexes, len(lines), in), showGrepLineNumbers(in)) + default: + collector.fileMatches = append(collector.fileMatches, relPath) + } + return nil +} + +func appendGrepContentLines(out *[]string, relPath string, lines []string, indexes []int, showLineNumbers bool) { + for _, idx := range indexes { + if showLineNumbers { + *out = append(*out, fmt.Sprintf("%s:%d:%s", relPath, idx+1, lines[idx])) + continue + } + *out = append(*out, fmt.Sprintf("%s:%s", relPath, lines[idx])) + } +} + +func showGrepLineNumbers(in grepInput) bool { + if in.ShowLineNum != nil { + return *in.ShowLineNum + } + return true +} + +func finalizeFallbackGrepOutput(baseDir 
string, in grepInput, collector *fallbackGrepCollector) grepOutput { + offset := grepOffset(in) + limit := grepLimit(in) + switch collector.mode { + case "count": + return finalizeFallbackCountOutput(offset, limit, collector.countLines) + case "content": + sliced, appliedLimit := sliceStrings(collector.contentLines, offset, limit) + return grepOutput{ + Mode: "content", + NumFiles: 0, + Filenames: []string{}, + Content: strings.Join(sliced, "\n"), + NumLines: len(sliced), + AppliedLimit: appliedLimit, + AppliedOffset: offset, + } + default: + sortedFiles := sortGrepPathsByMtime(baseDir, collector.fileMatches) + sliced, appliedLimit := sliceStrings(sortedFiles, offset, limit) + return grepOutput{ + Mode: "files_with_matches", + NumFiles: len(sliced), + Filenames: sliced, + AppliedLimit: appliedLimit, + AppliedOffset: offset, + } + } +} + +func finalizeFallbackCountOutput(offset int, limit int, countLines []string) grepOutput { + sortedCounts := sortedCopy(countLines) + sliced, appliedLimit := sliceStrings(sortedCounts, offset, limit) + totalMatches := 0 + for _, line := range sliced { + _, count, ok := parseGrepCountLine(line) + if ok { + totalMatches += count + } + } + return grepOutput{ + Mode: "count", + NumFiles: len(sliced), + Content: strings.Join(sliced, "\n"), + NumMatches: totalMatches, + AppliedLimit: appliedLimit, + AppliedOffset: offset, + } +} + +func multilineMatchLineIndexes(content string, matches [][]int, totalLines int) []int { + if len(matches) == 0 || totalLines <= 0 { + return nil + } + lineStarts := make([]int, 1, totalLines) + lineStarts[0] = 0 + for idx, r := range content { + if r != '\n' || len(lineStarts) == totalLines { + continue + } + lineStarts = append(lineStarts, idx+1) + } + seen := make(map[int]struct{}, len(matches)) + indexes := make([]int, 0, len(matches)*2) + for _, match := range matches { + if len(match) != 2 { + continue + } + startLine := lineIndexForOffset(lineStarts, match[0]) + endOffset := match[1] - 1 + if endOffset < 
match[0] { + endOffset = match[0] + } + endLine := lineIndexForOffset(lineStarts, endOffset) + for line := startLine; line <= endLine; line++ { + if _, ok := seen[line]; ok { + continue + } + seen[line] = struct{}{} + indexes = append(indexes, line) + } + } + slices.Sort(indexes) + return indexes +} + +func lineIndexForOffset(lineStarts []int, offset int) int { + if len(lineStarts) == 0 || offset <= 0 { + return 0 + } + lineIndex := 0 + for idx := 1; idx < len(lineStarts); idx++ { + if lineStarts[idx] > offset { + break + } + lineIndex = idx + } + return lineIndex +} + +func collectGrepCandidates(baseDir string, pathValue string, globValue string, typeValue string) ([]string, error) { + searchRoot := baseDir + if strings.TrimSpace(pathValue) != "" { + _, absPath, err := normalizePath(baseDir, pathValue) + if err != nil { + return nil, err + } + searchRoot = absPath + } + info, err := os.Stat(searchRoot) + if err != nil { + return nil, err + } + if !info.IsDir() { + return []string{searchRoot}, nil + } + globPatterns := splitGlobPatterns(globValue) + if len(globPatterns) == 0 { + globPatterns = []string{"**/*"} + } + typePatterns := typePatterns(typeValue) + candidates := make([]string, 0, 64) + err = filepath.WalkDir(searchRoot, func(path string, d os.DirEntry, walkErr error) error { + if walkErr != nil { + return walkErr + } + if d.IsDir() { + for _, excluded := range grepExcludedDirs { + if d.Name() == excluded { + return filepath.SkipDir + } + } + return nil + } + relToRoot, err := filepath.Rel(searchRoot, path) + if err != nil { + return err + } + relToRoot = filepath.ToSlash(relToRoot) + if !matchesAnyPattern(relToRoot, globPatterns) { + return nil + } + if len(typePatterns) > 0 && !matchesAnyPattern(relToRoot, typePatterns) { + return nil + } + candidates = append(candidates, path) + return nil + }) + if err != nil { + return nil, err + } + slices.Sort(candidates) + return candidates, nil +} + +func matchesAnyPattern(value string, patterns []string) bool { + 
for _, pattern := range patterns { + ok, err := doublestar.PathMatch(pattern, value) + if err == nil && ok { + return true + } + } + return false +} + +func expandContextLines(matches []int, total int, in grepInput) []int { + before := 0 + after := 0 + switch { + case in.ContextAlt != nil: + before = *in.ContextAlt + after = *in.ContextAlt + case in.Context != nil: + before = *in.Context + after = *in.Context + default: + if in.Before != nil { + before = *in.Before + } + if in.After != nil { + after = *in.After + } + } + seen := make(map[int]struct{}, len(matches)) + indexes := make([]int, 0, len(matches)) + for _, idx := range matches { + start := idx - before + if start < 0 { + start = 0 + } + end := idx + after + if end >= total { + end = total - 1 + } + for line := start; line <= end; line++ { + if _, ok := seen[line]; ok { + continue + } + seen[line] = struct{}{} + indexes = append(indexes, line) + } + } + slices.Sort(indexes) + return indexes +} + +func typePatterns(typeValue string) []string { + switch strings.TrimSpace(strings.ToLower(typeValue)) { + case "": + return nil + case "go": + return []string{"**/*.go"} + case "js": + return []string{"**/*.js", "**/*.cjs", "**/*.mjs"} + case "ts": + return []string{"**/*.ts", "**/*.tsx", "**/*.mts", "**/*.cts"} + case "py": + return []string{"**/*.py"} + case "java": + return []string{"**/*.java"} + case "rust", "rs": + return []string{"**/*.rs"} + case "json": + return []string{"**/*.json"} + case "md", "markdown": + return []string{"**/*.md"} + case "yaml", "yml": + return []string{"**/*.yaml", "**/*.yml"} + case "txt": + return []string{"**/*.txt"} + default: + return []string{"**/*." 
+ strings.TrimPrefix(strings.ToLower(typeValue), ".")} + } +} + +func sliceStrings(items []string, offset int, limit int) ([]string, *int) { + if offset < 0 { + offset = 0 + } + if offset >= len(items) { + return []string{}, nil + } + remaining := items[offset:] + if limit == 0 { + return remaining, nil + } + if limit < 0 || limit >= len(remaining) { + return remaining, nil + } + appliedLimit := limit + return remaining[:limit], &appliedLimit +} + +func runLocalRipgrep( + ctx context.Context, + baseDir string, + in grepInput, +) (grepOutput, bool, error) { + if ripgrepCommand() == "" { + return grepOutput{}, false, nil + } + baseAbs, err := filepath.Abs(baseDir) + if err != nil { + return grepOutput{}, true, err + } + targetPath := "." + if strings.TrimSpace(in.Path) != "" { + relPath, _, err := normalizePath(baseDir, in.Path) + if err != nil { + return grepOutput{}, true, err + } + if relPath != "" { + targetPath = relPath + } + } + return runRipgrepCommand(ctx, baseAbs, targetPath, in) +} + +func runRipgrepCommand( + ctx context.Context, + baseAbs string, + targetPath string, + in grepInput, +) (grepOutput, bool, error) { + mode := strings.TrimSpace(in.OutputMode) + if mode == "" { + mode = "files_with_matches" + } + lines, err := execRipgrep(ctx, baseAbs, buildRipgrepArgs(mode, targetPath, in)...) 
+ if err != nil { + return grepOutput{}, true, err + } + return formatRipgrepOutput(baseAbs, mode, in, lines), true, nil +} + +func buildRipgrepArgs(mode string, targetPath string, in grepInput) []string { + args := []string{"--hidden", "--max-columns", "500"} + args = appendRipgrepExcludes(args) + args = appendRipgrepMode(args, mode, in) + args = appendRipgrepPattern(args, in.Pattern) + if strings.TrimSpace(in.Type) != "" { + args = append(args, "--type", strings.TrimSpace(in.Type)) + } + for _, pattern := range splitGlobPatterns(in.Glob) { + args = append(args, "--glob", pattern) + } + return append(args, targetPath) +} + +func appendRipgrepExcludes(args []string) []string { + for _, dir := range grepExcludedDirs { + args = append(args, "--glob", "!"+dir) + } + return args +} + +func appendRipgrepMode(args []string, mode string, in grepInput) []string { + if in.Multiline { + args = append(args, "-U", "--multiline-dotall") + } + if in.IgnoreCase != nil && *in.IgnoreCase { + args = append(args, "-i") + } + switch mode { + case "files_with_matches": + args = append(args, "-l") + case "count": + args = append(args, "-c") + case "content": + if showGrepLineNumbers(in) { + args = append(args, "-n") + } + args = appendRipgrepContext(args, in) + } + return args +} + +func appendRipgrepContext(args []string, in grepInput) []string { + switch { + case in.ContextAlt != nil: + return append(args, "-C", strconv.Itoa(*in.ContextAlt)) + case in.Context != nil: + return append(args, "-C", strconv.Itoa(*in.Context)) + default: + if in.Before != nil { + args = append(args, "-B", strconv.Itoa(*in.Before)) + } + if in.After != nil { + args = append(args, "-A", strconv.Itoa(*in.After)) + } + return args + } +} + +func appendRipgrepPattern(args []string, pattern string) []string { + if strings.HasPrefix(pattern, "-") { + return append(args, "-e", pattern) + } + return append(args, pattern) +} + +func formatRipgrepOutput(baseAbs string, mode string, in grepInput, lines []string) 
grepOutput { + offset := grepOffset(in) + limit := grepLimit(in) + switch mode { + case "content": + sliced, appliedLimit := sliceStrings(lines, offset, limit) + return grepOutput{ + Mode: "content", + Content: strings.Join(sliced, "\n"), + NumLines: len(sliced), + AppliedLimit: appliedLimit, + AppliedOffset: offset, + } + case "count": + return formatRipgrepCountOutput(offset, limit, lines) + default: + sorted := sortGrepPathsByMtime(baseAbs, lines) + sliced, appliedLimit := sliceStrings(sorted, offset, limit) + return grepOutput{ + Mode: "files_with_matches", + NumFiles: len(sliced), + Filenames: sliced, + AppliedLimit: appliedLimit, + AppliedOffset: offset, + } + } +} + +func formatRipgrepCountOutput(offset int, limit int, lines []string) grepOutput { + formatted := make([]string, 0, len(lines)) + for _, line := range lines { + path, count, ok := parseGrepCountLine(line) + if ok { + formatted = append(formatted, fmt.Sprintf("%s:%d", path, count)) + continue + } + formatted = append(formatted, line) + } + sliced, appliedLimit := sliceStrings(formatted, offset, limit) + totalMatches := 0 + for _, line := range sliced { + _, count, ok := parseGrepCountLine(line) + if ok { + totalMatches += count + } + } + return grepOutput{ + Mode: "count", + NumFiles: len(sliced), + Content: strings.Join(sliced, "\n"), + NumMatches: totalMatches, + AppliedLimit: appliedLimit, + AppliedOffset: offset, + } +} + +func grepOffset(in grepInput) int { + if in.Offset != nil && *in.Offset > 0 { + return *in.Offset + } + return 0 +} + +func grepLimit(in grepInput) int { + if in.HeadLimit != nil { + return *in.HeadLimit + } + return defaultGrepHeadLimit +} + +func execRipgrep( + ctx context.Context, + baseAbs string, + args ...string, +) ([]string, error) { + result, err := runCapturedProcess(ctx, baseAbs, nil, ripgrepCommand(), args...) 
+ if err != nil || result.ExitCode != 0 { + if result.ExitCode == 1 { + return []string{}, nil + } + msg := strings.TrimSpace(string(result.Stderr)) + if msg == "" { + if err != nil { + msg = err.Error() + } else { + msg = fmt.Sprintf("ripgrep exited with code %d", result.ExitCode) + } + } + return nil, fmt.Errorf("ripgrep failed: %s", msg) + } + return splitRipgrepLines(string(result.Stdout)), nil +} + +func splitRipgrepLines(raw string) []string { + lines := strings.Split(raw, "\n") + out := make([]string, 0, len(lines)) + for _, line := range lines { + line = strings.TrimRight(line, "\r") + if strings.TrimSpace(line) == "" { + continue + } + line = strings.TrimPrefix(line, "./") + line = strings.TrimPrefix(line, ".\\") + out = append(out, filepath.ToSlash(line)) + } + return out +} + +func splitGlobPatterns(raw string) []string { + clean := strings.TrimSpace(raw) + if clean == "" { + return nil + } + rawPatterns := strings.Fields(clean) + out := make([]string, 0, len(rawPatterns)) + for _, rawPattern := range rawPatterns { + if strings.Contains(rawPattern, "{") && strings.Contains(rawPattern, "}") { + out = append(out, rawPattern) + continue + } + for _, part := range strings.Split(rawPattern, ",") { + part = strings.TrimSpace(part) + if part == "" { + continue + } + out = append(out, part) + } + } + return out +} + +func sortGrepPathsByMtime(baseAbs string, paths []string) []string { + type fileEntry struct { + path string + mtime int64 + } + entries := make([]fileEntry, 0, len(paths)) + for _, path := range paths { + mtime := int64(0) + if info, err := os.Stat(filepath.Join(baseAbs, filepath.FromSlash(path))); err == nil { + mtime = info.ModTime().UnixMilli() + } + entries = append(entries, fileEntry{path: filepath.ToSlash(path), mtime: mtime}) + } + slices.SortFunc(entries, func(a, b fileEntry) int { + if a.mtime == b.mtime { + return strings.Compare(a.path, b.path) + } + if a.mtime > b.mtime { + return -1 + } + return 1 + }) + out := make([]string, 0, 
len(entries)) + for _, entry := range entries { + out = append(out, entry.path) + } + return out +} + +func parseGrepCountLine(line string) (string, int, bool) { + idx := strings.LastIndex(line, ":") + if idx <= 0 || idx >= len(line)-1 { + return "", 0, false + } + count, err := strconv.Atoi(strings.TrimSpace(line[idx+1:])) + if err != nil { + return "", 0, false + } + return line[:idx], count, true +} + +func ripgrepCommand() string { + ripgrepOnce.Do(func() { + path, err := ripgrepLookPath("rg") + if err == nil { + ripgrepPath = path + } + }) + return ripgrepPath +} diff --git a/tool/claudecode/notebook_edit.go b/tool/claudecode/notebook_edit.go new file mode 100644 index 000000000..eee573329 --- /dev/null +++ b/tool/claudecode/notebook_edit.go @@ -0,0 +1,379 @@ +// +// Tencent is pleased to support the open source community by making +// trpc-agent-go available. +// +// Copyright (C) 2025 Tencent. All rights reserved. +// +// trpc-agent-go is licensed under the Apache License Version 2.0. 
+// + +package claudecode + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "strconv" + "strings" + + "github.com/google/uuid" + "trpc.group/trpc-go/trpc-agent-go/tool" + "trpc.group/trpc-go/trpc-agent-go/tool/function" +) + +type notebookEditInput struct { + NotebookPath string `json:"notebook_path"` + CellID string `json:"cell_id,omitempty"` + NewSource string `json:"new_source"` + CellType string `json:"cell_type,omitempty"` + EditMode string `json:"edit_mode,omitempty"` +} + +type notebookEditOutput struct { + NewSource string `json:"new_source"` + CellID string `json:"cell_id,omitempty"` + CellType string `json:"cell_type"` + Language string `json:"language"` + EditMode string `json:"edit_mode"` + NotebookPath string `json:"notebook_path"` + OriginalFile string `json:"original_file"` + UpdatedFile string `json:"updated_file"` +} + +type notebookEditState struct { + snapshot localFileSnapshot + notebook map[string]any + cells []map[string]any + editMode string + cellType string + language string + cellIndex int +} + +func newNotebookEditTool(runtime *runtime) (tool.Tool, error) { + return function.NewFunctionTool( + func(_ context.Context, in notebookEditInput) (notebookEditOutput, error) { + baseDir := runtime.currentBaseDir() + _, absPath, err := normalizePath(baseDir, in.NotebookPath) + if err != nil { + return notebookEditOutput{}, err + } + runtime.fileState.mu.Lock() + defer runtime.fileState.mu.Unlock() + return editNotebook(absPath, in, runtime) + }, + function.WithName(toolNotebookEdit), + function.WithDescription(notebookEditDescription()), + ), nil +} + +func editNotebook( + absPath string, + in notebookEditInput, + runtime *runtime, +) (notebookEditOutput, error) { + state, err := loadNotebookEditState(absPath, in, runtime) + if err != nil { + return notebookEditOutput{}, err + } + resultCellID, resultCellType, err := applyNotebookEdit(&state, in) + if err != nil { + return notebookEditOutput{}, err + } + state.notebook["cells"] = 
notebookCellsAny(state.cells) + updatedContent, err := marshalNotebook(state.notebook) + if err != nil { + return notebookEditOutput{}, err + } + if err := writeLocalFile(absPath, updatedContent, state.snapshot.Mode, state.snapshot.Encoding, state.snapshot.LineEnding); err != nil { + return notebookEditOutput{}, err + } + current, err := readLocalFileSnapshot(absPath, runtime.maxFileSize) + if err != nil { + return notebookEditOutput{}, err + } + storeReadView(runtime.fileState, absPath, current.Content, current.Timestamp, nil, nil, "", false, false) + return notebookEditOutput{ + NewSource: in.NewSource, + CellID: resultCellID, + CellType: resultCellType, + Language: state.language, + EditMode: state.editMode, + NotebookPath: absPath, + OriginalFile: state.snapshot.Content, + UpdatedFile: updatedContent, + }, nil +} + +func loadNotebookEditState(absPath string, in notebookEditInput, runtime *runtime) (notebookEditState, error) { + if !strings.HasSuffix(strings.ToLower(absPath), ".ipynb") { + return notebookEditState{}, fmt.Errorf("File must be a Jupyter notebook (.ipynb file).") + } + editMode := strings.TrimSpace(strings.ToLower(in.EditMode)) + if editMode == "" { + editMode = "replace" + } + if editMode != "replace" && editMode != "insert" && editMode != "delete" { + return notebookEditState{}, fmt.Errorf("Edit mode must be replace, insert, or delete.") + } + cellType, err := normalizeNotebookCellType(in.CellType) + if err != nil { + return notebookEditState{}, err + } + if editMode == "insert" && cellType == "" { + return notebookEditState{}, fmt.Errorf("Cell type is required when using edit_mode=insert.") + } + snapshot, err := readLocalFileSnapshot(absPath, maxEditableFileSize) + if err != nil { + return notebookEditState{}, err + } + if !snapshot.Exists { + return notebookEditState{}, fmt.Errorf("Notebook file does not exist.") + } + if err := ensureWriteAllowed(absPath, snapshot, runtime.fileState); err != nil { + return notebookEditState{}, err + } + 
notebook, cells, err := parseNotebook(snapshot.Raw) + if err != nil { + return notebookEditState{}, fmt.Errorf("Notebook is not valid JSON.") + } + cellIndex, err := notebookCellIndex(cells, in.CellID) + if err != nil { + return notebookEditState{}, err + } + if in.CellID == "" && editMode != "insert" { + return notebookEditState{}, fmt.Errorf("Cell ID must be specified when not inserting a new cell.") + } + if cellIndex > len(cells) { + return notebookEditState{}, fmt.Errorf("Cell with index %d does not exist in notebook.", cellIndex) + } + if editMode == "replace" && cellIndex == len(cells) { + editMode = "insert" + if cellType == "" { + cellType = "code" + } + } + return notebookEditState{ + snapshot: snapshot, + notebook: notebook, + cells: cells, + editMode: editMode, + cellType: cellType, + language: notebookLanguage(notebook), + cellIndex: cellIndex, + }, nil +} + +func applyNotebookEdit(state *notebookEditState, in notebookEditInput) (string, string, error) { + switch state.editMode { + case "delete": + return deleteNotebookCell(state, in) + case "insert": + return insertNotebookCell(state, in), state.cellType, nil + default: + return replaceNotebookCell(state, in) + } +} + +func deleteNotebookCell(state *notebookEditState, in notebookEditInput) (string, string, error) { + if state.cellIndex >= len(state.cells) { + return "", "", fmt.Errorf("Cell with ID %q not found in notebook.", strings.TrimSpace(in.CellID)) + } + resultCellID := notebookResultCellID(state.cells[state.cellIndex], state.cellIndex, in.CellID) + resultCellType := notebookCellType(state.cells[state.cellIndex], "code") + state.cells = append(state.cells[:state.cellIndex], state.cells[state.cellIndex+1:]...) 
+	return resultCellID, resultCellType, nil
+}
+
+func insertNotebookCell(state *notebookEditState, in notebookEditInput) string {
+	insertAt := 0
+	if strings.TrimSpace(in.CellID) != "" {
+		insertAt = state.cellIndex + 1
+		if insertAt > len(state.cells) {
+			insertAt = len(state.cells)
+		}
+	}
+	newCell := newNotebookCell(state.cellType, in.NewSource, notebookSupportsCellIDs(state.notebook))
+	resultCellID := notebookResultCellID(newCell, insertAt, "")
+	state.cells = append(state.cells[:insertAt], append([]map[string]any{newCell}, state.cells[insertAt:]...)...)
+	return resultCellID
+}
+
+func replaceNotebookCell(state *notebookEditState, in notebookEditInput) (string, string, error) {
+	if state.cellIndex >= len(state.cells) {
+		return "", "", fmt.Errorf("Cell with ID %q not found in notebook.", strings.TrimSpace(in.CellID))
+	}
+	target := state.cells[state.cellIndex]
+	resultCellID := notebookResultCellID(target, state.cellIndex, in.CellID)
+	if state.cellType == "" {
+		state.cellType = notebookCellType(target, "code")
+	}
+	target["cell_type"] = state.cellType
+	target["source"] = in.NewSource
+	if state.cellType == "code" {
+		target["execution_count"] = nil
+		target["outputs"] = []any{}
+		return resultCellID, state.cellType, nil
+	}
+	delete(target, "execution_count")
+	delete(target, "outputs")
+	return resultCellID, state.cellType, nil
+}
+
+func parseNotebook(raw []byte) (map[string]any, []map[string]any, error) {
+	var notebook map[string]any
+	if err := json.Unmarshal(raw, &notebook); err != nil {
+		return nil, nil, err
+	}
+	rawCells, ok := notebook["cells"].([]any)
+	if !ok {
+		return nil, nil, fmt.Errorf("notebook cells are invalid")
+	}
+	cells := make([]map[string]any, 0, len(rawCells))
+	for _, rawCell := range rawCells {
+		cell, ok := rawCell.(map[string]any)
+		if !ok {
+			return nil, nil, fmt.Errorf("notebook cell is invalid")
+		}
+		cells = append(cells, cell)
+	}
+	return notebook, cells, nil
+}
+
+func notebookCellIndex(cells []map[string]any,
cellID string) (int, error) { + trimmed := strings.TrimSpace(cellID) + if trimmed == "" { + return 0, nil + } + for idx, cell := range cells { + if value, ok := cell["id"].(string); ok && value == trimmed { + return idx, nil + } + } + if parsed, ok := parseNotebookCellID(trimmed); ok { + return parsed, nil + } + return -1, fmt.Errorf("Cell with ID %q not found in notebook.", trimmed) +} + +func parseNotebookCellID(raw string) (int, bool) { + trimmed := strings.TrimSpace(raw) + if strings.HasPrefix(trimmed, "cell-") { + trimmed = strings.TrimPrefix(trimmed, "cell-") + } + value, err := strconv.Atoi(trimmed) + if err != nil || value < 0 { + return 0, false + } + return value, true +} + +func normalizeNotebookCellType(raw string) (string, error) { + trimmed := strings.TrimSpace(strings.ToLower(raw)) + if trimmed == "" { + return "", nil + } + if trimmed != "code" && trimmed != "markdown" { + return "", fmt.Errorf("Cell type must be code or markdown.") + } + return trimmed, nil +} + +func notebookLanguage(notebook map[string]any) string { + metadata, ok := notebook["metadata"].(map[string]any) + if !ok { + return "python" + } + languageInfo, ok := metadata["language_info"].(map[string]any) + if !ok { + return "python" + } + name, _ := languageInfo["name"].(string) + if strings.TrimSpace(name) == "" { + return "python" + } + return name +} + +func notebookSupportsCellIDs(notebook map[string]any) bool { + nbformat, _ := notebookInt(notebook["nbformat"]) + nbformatMinor, _ := notebookInt(notebook["nbformat_minor"]) + return nbformat > 4 || (nbformat == 4 && nbformatMinor >= 5) +} + +func notebookInt(raw any) (int, bool) { + switch value := raw.(type) { + case float64: + return int(value), true + case int: + return value, true + default: + return 0, false + } +} + +func newNotebookCell(cellType string, source string, includeID bool) map[string]any { + cell := map[string]any{ + "cell_type": cellType, + "metadata": map[string]any{}, + "source": source, + } + if includeID { + 
cell["id"] = uuid.NewString()[:12] + } + if cellType == "code" { + cell["execution_count"] = nil + cell["outputs"] = []any{} + } + return cell +} + +func notebookResultCellID(cell map[string]any, cellIndex int, fallback string) string { + if value, ok := cell["id"].(string); ok && strings.TrimSpace(value) != "" { + return value + } + if strings.TrimSpace(fallback) != "" { + return strings.TrimSpace(fallback) + } + return fmt.Sprintf("cell-%d", cellIndex) +} + +func notebookCellType(cell map[string]any, fallback string) string { + value, _ := cell["cell_type"].(string) + value = strings.TrimSpace(strings.ToLower(value)) + if value == "" { + return fallback + } + return value +} + +func notebookCellsAny(cells []map[string]any) []any { + out := make([]any, 0, len(cells)) + for _, cell := range cells { + out = append(out, cell) + } + return out +} + +func marshalNotebook(notebook map[string]any) (string, error) { + var buf bytes.Buffer + encoder := json.NewEncoder(&buf) + encoder.SetEscapeHTML(false) + encoder.SetIndent("", " ") + if err := encoder.Encode(notebook); err != nil { + return "", err + } + return strings.TrimSuffix(buf.String(), "\n"), nil +} + +func notebookEditDescription() string { + return fmt.Sprintf(`Edit cells inside a Jupyter notebook. + +Usage: +- Use this tool for .ipynb files instead of %s or %s when you want cell-aware edits. +- Always read the notebook with %s before editing it. +- Supports replace, insert, and delete operations on notebook cells. +- Use cell_id to target an existing cell. Insert operations can create a new cell and return the resulting cell ID. 
+- Notebook edits participate in the same stale-check protections as other file-writing tools.`, toolEdit, toolWrite, toolRead) +} diff --git a/tool/claudecode/options.go b/tool/claudecode/options.go new file mode 100644 index 000000000..3909d7697 --- /dev/null +++ b/tool/claudecode/options.go @@ -0,0 +1,112 @@ +// +// Tencent is pleased to support the open source community by making +// trpc-agent-go available. +// +// Copyright (C) 2025 Tencent. All rights reserved. +// +// trpc-agent-go is licensed under the Apache License Version 2.0. +// + +// Package claudecode implements a Claude Code-compatible toolset. +package claudecode + +import ( + "context" + "time" +) + +// Option mutates toolset construction options. +type Option func(*toolSetOptions) + +// WebFetchOptions configures the WebFetch tool behavior. +type WebFetchOptions struct { + AllowedDomains []string + BlockedDomains []string + AllowAll bool + Timeout time.Duration + MaxContentLength int + MaxTotalContentLength int + PromptProcessor WebFetchPromptProcessor +} + +// WebSearchOptions configures the WebSearch tool behavior. +type WebSearchOptions struct { + Provider string + BaseURL string + UserAgent string + Timeout time.Duration + APIKey string + EngineID string + Size int + Offset int + Lang string +} + +// WebFetchPromptInput describes the extracted page content passed to a prompt processor. +type WebFetchPromptInput struct { + URL string + Prompt string + Content string + ContentType string +} + +// WebFetchPromptProcessor transforms fetched content into a prompt-aware result. +type WebFetchPromptProcessor func(context.Context, WebFetchPromptInput) (string, error) + +type toolSetOptions struct { + name string + baseDir string + readOnly bool + maxFileSize int64 + webFetch WebFetchOptions + webSearch *WebSearchOptions + hasMaxSize bool + hasWebFetch bool + hasWebSearch bool +} + +// WithBaseDir sets the base directory used by the toolset. 
+func WithBaseDir(baseDir string) Option { + return func(options *toolSetOptions) { + options.baseDir = baseDir + } +} + +// WithName overrides the toolset name. +func WithName(name string) Option { + return func(options *toolSetOptions) { + options.name = name + } +} + +// WithReadOnly disables mutating tools when set to true. +func WithReadOnly(readOnly bool) Option { + return func(options *toolSetOptions) { + options.readOnly = readOnly + } +} + +// WithMaxFileSize sets the maximum readable file size in bytes. +func WithMaxFileSize(maxFileSize int64) Option { + return func(options *toolSetOptions) { + options.maxFileSize = maxFileSize + options.hasMaxSize = true + } +} + +// WithWebFetchOptions overrides WebFetch options. +func WithWebFetchOptions(webFetch WebFetchOptions) Option { + return func(options *toolSetOptions) { + options.webFetch = webFetch + options.hasWebFetch = true + } +} + +// WithWebSearchOptions overrides WebSearch options. +func WithWebSearchOptions(webSearch WebSearchOptions) Option { + return func(options *toolSetOptions) { + webSearchCopy := webSearch + options.webSearch = &webSearchCopy + options.hasWebSearch = true + } +} diff --git a/tool/claudecode/pdf.go b/tool/claudecode/pdf.go new file mode 100644 index 000000000..603c8b00f --- /dev/null +++ b/tool/claudecode/pdf.go @@ -0,0 +1,144 @@ +// +// Tencent is pleased to support the open source community by making +// trpc-agent-go available. +// +// Copyright (C) 2025 Tencent. All rights reserved. +// +// trpc-agent-go is licensed under the Apache License Version 2.0. 
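The option functions above all follow Go's functional-options pattern: each `With*` helper returns a closure that mutates a private options struct, and defaults are filled in before any option runs. A minimal standalone sketch of the pattern (the struct and defaults here are hypothetical stand-ins, not the package's real `toolSetOptions`):

```go
package main

import "fmt"

// options is a hypothetical stand-in for the package's private options struct.
type options struct {
	name     string
	readOnly bool
}

// Option mutates the options struct, mirroring claudecode.Option.
type Option func(*options)

// WithName and WithReadOnly return closures that set one field each.
func WithName(name string) Option { return func(o *options) { o.name = name } }

func WithReadOnly(on bool) Option { return func(o *options) { o.readOnly = on } }

// newOptions applies defaults first, then each caller-supplied option in order.
func newOptions(opts ...Option) options {
	o := options{name: "claudecode"} // defaults are set before options run
	for _, opt := range opts {
		opt(&o)
	}
	return o
}

func main() {
	o := newOptions(WithName("my-tools"), WithReadOnly(true))
	fmt.Println(o.name, o.readOnly)
}
```

Because later options overwrite earlier ones, callers can layer a base configuration and selectively override fields, which is why `NewToolSet` can accept any subset of options without a config struct in its signature.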
+// + +package claudecode + +import ( + "bytes" + "fmt" + "os" + "os/exec" + "path/filepath" + "sort" + "strconv" + "strings" + + "github.com/ledongthuc/pdf" +) + +type pdfPageRange struct { + FirstPage int + LastPage int + Count int +} + +func pdfPageCount(raw []byte) (int, error) { + reader, err := pdf.NewReader(bytes.NewReader(raw), int64(len(raw))) + if err != nil { + return 0, fmt.Errorf("failed to create PDF reader: %w", err) + } + return reader.NumPage(), nil +} + +func resolvePDFPageRange(raw string, totalPages int) (pdfPageRange, error) { + trimmed := strings.TrimSpace(raw) + if trimmed == "" { + return pdfPageRange{}, fmt.Errorf("Invalid pages parameter: %q. Use formats like \"1-5\", \"3\", or \"10-20\". Pages are 1-indexed.", raw) + } + if strings.HasSuffix(trimmed, "-") { + firstPage, err := strconv.Atoi(strings.TrimSpace(strings.TrimSuffix(trimmed, "-"))) + if err != nil || firstPage < 1 { + return pdfPageRange{}, fmt.Errorf("Invalid pages parameter: %q. Use formats like \"1-5\", \"3\", or \"10-20\". Pages are 1-indexed.", raw) + } + if totalPages > 0 && firstPage > totalPages { + return pdfPageRange{}, fmt.Errorf("Page range %q is outside the PDF page count of %d.", raw, totalPages) + } + lastPage := totalPages + if lastPage < firstPage { + lastPage = firstPage + } + return validatedPDFPageRange(raw, firstPage, lastPage) + } + if dashIndex := strings.Index(trimmed, "-"); dashIndex >= 0 { + firstPage, firstErr := strconv.Atoi(strings.TrimSpace(trimmed[:dashIndex])) + lastPage, lastErr := strconv.Atoi(strings.TrimSpace(trimmed[dashIndex+1:])) + if firstErr != nil || lastErr != nil || firstPage < 1 || lastPage < firstPage { + return pdfPageRange{}, fmt.Errorf("Invalid pages parameter: %q. Use formats like \"1-5\", \"3\", or \"10-20\". 
Pages are 1-indexed.", raw) + } + if totalPages > 0 && lastPage > totalPages { + return pdfPageRange{}, fmt.Errorf("Page range %q exceeds the PDF page count of %d.", raw, totalPages) + } + return validatedPDFPageRange(raw, firstPage, lastPage) + } + page, err := strconv.Atoi(trimmed) + if err != nil || page < 1 { + return pdfPageRange{}, fmt.Errorf("Invalid pages parameter: %q. Use formats like \"1-5\", \"3\", or \"10-20\". Pages are 1-indexed.", raw) + } + if totalPages > 0 && page > totalPages { + return pdfPageRange{}, fmt.Errorf("Page %d exceeds the PDF page count of %d.", page, totalPages) + } + return validatedPDFPageRange(raw, page, page) +} + +func validatedPDFPageRange(raw string, firstPage int, lastPage int) (pdfPageRange, error) { + count := lastPage - firstPage + 1 + if count > pdfMaxPagesPerRead { + return pdfPageRange{}, fmt.Errorf("Page range %q exceeds maximum of %d pages per request. Please use a smaller range.", raw, pdfMaxPagesPerRead) + } + return pdfPageRange{ + FirstPage: firstPage, + LastPage: lastPage, + Count: count, + }, nil +} + +func pdftoppmBinary() (string, error) { + pdftoppmOnce.Do(func() { + path, err := pdftoppmLookPath("pdftoppm") + if err == nil { + pdftoppmPath = path + } + }) + if strings.TrimSpace(pdftoppmPath) == "" { + return "", fmt.Errorf("pdftoppm is not installed. Install poppler-utils (e.g. 
`brew install poppler` or `apt-get install poppler-utils`) to enable PDF page rendering.") + } + return pdftoppmPath, nil +} + +func extractPDFPages( + filePath string, + pageRange pdfPageRange, +) (string, int, error) { + pdftoppmPath, err := pdftoppmBinary() + if err != nil { + return "", 0, err + } + outputDir, err := os.MkdirTemp("", "claudecode-pdf-*") + if err != nil { + return "", 0, err + } + outputPrefix := filepath.Join(outputDir, "page") + args := []string{ + "-jpeg", + "-f", strconv.Itoa(pageRange.FirstPage), + "-l", strconv.Itoa(pageRange.LastPage), + filePath, + outputPrefix, + } + output, err := exec.Command(pdftoppmPath, args...).CombinedOutput() + if err != nil { + _ = os.RemoveAll(outputDir) + message := strings.TrimSpace(string(output)) + if message == "" { + return "", 0, fmt.Errorf("failed to extract PDF pages: %w", err) + } + return "", 0, fmt.Errorf("failed to extract PDF pages: %s", message) + } + imageFiles, err := filepath.Glob(outputPrefix + "-*.jpg") + if err != nil { + _ = os.RemoveAll(outputDir) + return "", 0, err + } + sort.Strings(imageFiles) + if len(imageFiles) == 0 { + _ = os.RemoveAll(outputDir) + return "", 0, fmt.Errorf("failed to extract PDF pages: no rendered page images were produced") + } + return outputDir, len(imageFiles), nil +} diff --git a/tool/claudecode/process.go b/tool/claudecode/process.go new file mode 100644 index 000000000..399010fbf --- /dev/null +++ b/tool/claudecode/process.go @@ -0,0 +1,193 @@ +// +// Tencent is pleased to support the open source community by making +// trpc-agent-go available. +// +// Copyright (C) 2025 Tencent. All rights reserved. +// +// trpc-agent-go is licensed under the Apache License Version 2.0. 
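The page-range resolver above accepts three spellings: a single page (`"3"`), a closed range (`"1-5"`), and an open-ended range (`"10-"`) that runs to the end of the document. A standalone sketch of that parse, under the assumption that the caller enforces the total-page and max-pages-per-read bounds separately (this sketch omits them):

```go
package main

import (
	"fmt"
	"strconv"
	"strings"
)

// parsePageRange is a hypothetical standalone reduction of
// resolvePDFPageRange: it handles "3", "1-5", and open-ended "10-" forms,
// clamping the open form to totalPages. Pages are 1-indexed.
func parsePageRange(raw string, totalPages int) (first, last int, err error) {
	trimmed := strings.TrimSpace(raw)
	switch {
	case strings.HasSuffix(trimmed, "-"):
		// Open-ended form: "10-" means from page 10 to the last page.
		first, err = strconv.Atoi(strings.TrimSuffix(trimmed, "-"))
		if err != nil || first < 1 {
			return 0, 0, fmt.Errorf("invalid pages parameter: %q", raw)
		}
		last = totalPages
		if last < first {
			last = first
		}
	case strings.Contains(trimmed, "-"):
		// Closed form: "1-5".
		parts := strings.SplitN(trimmed, "-", 2)
		first, _ = strconv.Atoi(strings.TrimSpace(parts[0]))
		last, err = strconv.Atoi(strings.TrimSpace(parts[1]))
		if err != nil || first < 1 || last < first {
			return 0, 0, fmt.Errorf("invalid pages parameter: %q", raw)
		}
	default:
		// Single page: "3" reads exactly one page.
		first, err = strconv.Atoi(trimmed)
		if err != nil || first < 1 {
			return 0, 0, fmt.Errorf("invalid pages parameter: %q", raw)
		}
		last = first
	}
	return first, last, nil
}

func main() {
	f, l, _ := parsePageRange("10-", 25)
	fmt.Println(f, l) // open-ended range clamps to the page count
}
```

Checking the open-ended suffix before the generic dash case matters: `"10-"` contains a dash too, so the order of the switch cases determines which branch wins.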
+// + +package claudecode + +import ( + "bytes" + "context" + "io" + "os" + "os/exec" + "sync" +) + +type capturedProcessResult struct { + Stdout []byte + Stderr []byte + ExitCode int +} + +type waitedProcessState struct { + State *os.ProcessState + Err error +} + +type processCapture struct { + stdout bytes.Buffer + stderr bytes.Buffer + wg sync.WaitGroup +} + +func runCapturedProcess( + ctx context.Context, + dir string, + env []string, + bin string, + args ...string, +) (capturedProcessResult, error) { + processPath, err := exec.LookPath(bin) + if err != nil { + return capturedProcessResult{}, err + } + stdin, stdoutReader, stdoutWriter, stderrReader, stderrWriter, closeErr, err := processPipes() + if err != nil { + return capturedProcessResult{}, err + } + defer closeErr() + proc, err := os.StartProcess( + processPath, + append([]string{processPath}, args...), + &os.ProcAttr{ + Dir: dir, + Env: processEnv(env), + Files: []*os.File{stdin, stdoutWriter, stderrWriter}, + }, + ) + if err != nil { + return capturedProcessResult{}, err + } + _ = stdoutWriter.Close() + _ = stderrWriter.Close() + capture := startProcessCapture(stdoutReader, stderrReader) + state := waitForProcess(ctx, proc) + stdout, stderr := capture.wait() + result := capturedProcessResult{ + Stdout: stdout, + Stderr: stderr, + } + if state.State != nil { + result.ExitCode = state.State.ExitCode() + } + return result, state.Err +} + +func startProcess( + dir string, + env []string, + stdoutFile *os.File, + stderrFile *os.File, + bin string, + args ...string, +) (*os.Process, error) { + processPath, err := exec.LookPath(bin) + if err != nil { + return nil, err + } + stdin, err := os.Open(os.DevNull) + if err != nil { + return nil, err + } + defer stdin.Close() + return os.StartProcess( + processPath, + append([]string{processPath}, args...), + &os.ProcAttr{ + Dir: dir, + Env: processEnv(env), + Files: []*os.File{stdin, stdoutFile, stderrFile}, + }, + ) +} + +func processPipes() ( + *os.File, + 
*os.File, + *os.File, + *os.File, + *os.File, + func() error, + error, +) { + stdin, err := os.Open(os.DevNull) + if err != nil { + return nil, nil, nil, nil, nil, nil, err + } + stdoutReader, stdoutWriter, err := os.Pipe() + if err != nil { + _ = stdin.Close() + return nil, nil, nil, nil, nil, nil, err + } + stderrReader, stderrWriter, err := os.Pipe() + if err != nil { + _ = stdin.Close() + _ = stdoutReader.Close() + _ = stdoutWriter.Close() + return nil, nil, nil, nil, nil, nil, err + } + closeAll := func() error { + var firstErr error + for _, closer := range []io.Closer{stdin, stdoutReader, stdoutWriter, stderrReader, stderrWriter} { + if closer == nil { + continue + } + if err := closer.Close(); err != nil && firstErr == nil { + firstErr = err + } + } + return firstErr + } + return stdin, stdoutReader, stdoutWriter, stderrReader, stderrWriter, closeAll, nil +} + +func processEnv(extra []string) []string { + env := os.Environ() + if len(extra) == 0 { + return env + } + return append(env, extra...) 
+} + +func startProcessCapture(stdoutReader *os.File, stderrReader *os.File) *processCapture { + capture := &processCapture{} + capture.wg.Add(2) + go func() { + defer capture.wg.Done() + _, _ = io.Copy(&capture.stdout, stdoutReader) + }() + go func() { + defer capture.wg.Done() + _, _ = io.Copy(&capture.stderr, stderrReader) + }() + return capture +} + +func (c *processCapture) wait() ([]byte, []byte) { + c.wg.Wait() + return c.stdout.Bytes(), c.stderr.Bytes() +} + +func waitForProcess(ctx context.Context, proc *os.Process) waitedProcessState { + waitCh := make(chan waitedProcessState, 1) + go func() { + state, err := proc.Wait() + waitCh <- waitedProcessState{State: state, Err: err} + }() + select { + case state := <-waitCh: + return state + case <-ctx.Done(): + _ = proc.Kill() + state := <-waitCh + if state.Err == nil { + state.Err = ctx.Err() + } + return state + } +} diff --git a/tool/claudecode/read.go b/tool/claudecode/read.go new file mode 100644 index 000000000..2171a37cb --- /dev/null +++ b/tool/claudecode/read.go @@ -0,0 +1,179 @@ +// +// Tencent is pleased to support the open source community by making +// trpc-agent-go available. +// +// Copyright (C) 2025 Tencent. All rights reserved. +// +// trpc-agent-go is licensed under the Apache License Version 2.0. 
+// + +package claudecode + +import ( + "context" + "encoding/json" + "fmt" + "strings" + + "trpc.group/trpc-go/trpc-agent-go/tool" + "trpc.group/trpc-go/trpc-agent-go/tool/function" +) + +func newReadTool(runtime *runtime) (tool.Tool, error) { + return function.NewFunctionTool( + func(_ context.Context, in readInput) (readOutput, error) { + baseDir := runtime.currentBaseDir() + _, absPath, err := normalizePath(baseDir, in.FilePath) + if err != nil { + return readOutput{}, err + } + snapshot, err := readLocalFileSnapshot(absPath, runtime.maxFileSize) + if err != nil { + return readOutput{}, err + } + if !snapshot.Exists { + return readOutput{}, fmt.Errorf("File does not exist: %s", in.FilePath) + } + runtime.fileState.mu.Lock() + existing, ok := runtime.fileState.views[absPath] + if ok && existing.Timestamp == snapshot.Timestamp && matchesReadView(existing, in.Offset, in.Limit, in.Pages) { + runtime.fileState.mu.Unlock() + return readOutput{ + Type: "file_unchanged", + File: &readFile{ + FilePath: absPath, + }, + }, nil + } + runtime.fileState.mu.Unlock() + ext := strings.ToLower(absPath) + switch { + case strings.HasSuffix(ext, ".ipynb"): + return readNotebook(runtime, snapshot, in) + case strings.HasSuffix(ext, ".pdf"): + return readPDF(runtime, snapshot, in) + case strings.HasPrefix(snapshot.MediaType, "image/"): + return readImage(runtime, snapshot, in) + default: + if isProbablyBinary(snapshot.Raw) { + return readOutput{}, fmt.Errorf("This tool cannot read binary files.") + } + return readText(runtime, snapshot, in) + } + }, + function.WithName(toolRead), + function.WithDescription(readDescription()), + ), nil +} + +func readText(runtime *runtime, snapshot localFileSnapshot, in readInput) (readOutput, error) { + startLine := 1 + if in.Offset != nil && *in.Offset > 0 { + startLine = *in.Offset + } + content, actualStartLine, totalLines := sliceLines(snapshot.Content, startLine, in.Limit) + runtime.fileState.mu.Lock() + storeReadView(runtime.fileState, 
snapshot.Path, content, snapshot.Timestamp, in.Offset, in.Limit, in.Pages, in.Limit != nil || startLine > 1, true) + runtime.fileState.mu.Unlock() + return readOutput{ + Type: "text", + File: &readFile{ + FilePath: snapshot.Path, + Content: content, + NumLines: countLines(content), + StartLine: actualStartLine, + TotalLines: totalLines, + }, + }, nil +} + +func readNotebook(runtime *runtime, snapshot localFileSnapshot, in readInput) (readOutput, error) { + var notebook struct { + Cells []map[string]any `json:"cells"` + } + if err := json.Unmarshal(snapshot.Raw, ¬ebook); err != nil { + return readOutput{}, err + } + runtime.fileState.mu.Lock() + storeReadView(runtime.fileState, snapshot.Path, snapshot.Content, snapshot.Timestamp, in.Offset, in.Limit, in.Pages, false, true) + runtime.fileState.mu.Unlock() + return readOutput{ + Type: "notebook", + File: &readFile{ + FilePath: snapshot.Path, + Cells: notebook.Cells, + }, + }, nil +} + +func readPDF(runtime *runtime, snapshot localFileSnapshot, in readInput) (readOutput, error) { + pageCount, err := pdfPageCount(snapshot.Raw) + if err != nil { + return readOutput{}, err + } + if strings.TrimSpace(in.Pages) != "" { + pageRange, rangeErr := resolvePDFPageRange(in.Pages, pageCount) + if rangeErr != nil { + return readOutput{}, rangeErr + } + outputDir, renderedCount, extractErr := extractPDFPages(snapshot.Path, pageRange) + if extractErr != nil { + return readOutput{}, extractErr + } + runtime.fileState.mu.Lock() + storeReadView(runtime.fileState, snapshot.Path, snapshot.Content, snapshot.Timestamp, in.Offset, in.Limit, in.Pages, true, true) + runtime.fileState.mu.Unlock() + return readOutput{ + Type: "parts", + File: &readFile{ + FilePath: snapshot.Path, + OriginalSize: snapshot.OriginalSize, + Count: renderedCount, + OutputDir: outputDir, + }, + }, nil + } + if pageCount > pdfInlineReadThreshold { + return readOutput{}, fmt.Errorf("This PDF has %d pages, which is too many to read at once. 
Use the pages parameter to read specific page ranges (e.g., pages: \"1-5\"). Maximum %d pages per request.", pageCount, pdfMaxPagesPerRead) + } + runtime.fileState.mu.Lock() + storeReadView(runtime.fileState, snapshot.Path, snapshot.Content, snapshot.Timestamp, in.Offset, in.Limit, in.Pages, false, true) + runtime.fileState.mu.Unlock() + return readOutput{ + Type: "pdf", + File: &readFile{ + FilePath: snapshot.Path, + Base64: fileBase64(snapshot.Raw), + OriginalSize: snapshot.OriginalSize, + }, + }, nil +} + +func readImage(runtime *runtime, snapshot localFileSnapshot, in readInput) (readOutput, error) { + runtime.fileState.mu.Lock() + storeReadView(runtime.fileState, snapshot.Path, snapshot.Content, snapshot.Timestamp, in.Offset, in.Limit, in.Pages, false, true) + runtime.fileState.mu.Unlock() + return readOutput{ + Type: "image", + File: &readFile{ + FilePath: snapshot.Path, + Base64: fileBase64(snapshot.Raw), + Type: snapshot.MediaType, + MediaType: snapshot.MediaType, + OriginalSize: snapshot.OriginalSize, + }, + }, nil +} + +func readDescription() string { + return fmt.Sprintf(`Read one file from the workspace. + +Usage: +- Use %s for reading text files, screenshots, other images, PDF files, and Jupyter notebooks. +- file_path may be workspace-relative or absolute. +- By default the tool reads from the beginning of the file. Use offset and limit for targeted text reads when you already know the region you need. +- Re-reading the same unchanged file slice may return type=file_unchanged instead of repeating the content. +- For PDFs larger than %d pages, you MUST provide the pages parameter. A single request can read at most %d pages. +- This tool reads files only. Use %s or %s for directory exploration. +- This tool does not read arbitrary binary files. 
Images, PDFs, and notebooks are handled as structured formats instead.`, toolRead, pdfInlineReadThreshold, pdfMaxPagesPerRead, toolBash, toolGlob) +} diff --git a/tool/claudecode/task_output.go b/tool/claudecode/task_output.go new file mode 100644 index 000000000..068a44261 --- /dev/null +++ b/tool/claudecode/task_output.go @@ -0,0 +1,125 @@ +// +// Tencent is pleased to support the open source community by making +// trpc-agent-go available. +// +// Copyright (C) 2025 Tencent. All rights reserved. +// +// trpc-agent-go is licensed under the Apache License Version 2.0. +// + +package claudecode + +import ( + "context" + "fmt" + "os" + "time" + + "trpc.group/trpc-go/trpc-agent-go/tool" + "trpc.group/trpc-go/trpc-agent-go/tool/function" +) + +func newTaskOutputTool(runtime *runtime) (tool.Tool, error) { + return function.NewFunctionTool( + func(ctx context.Context, in taskOutputInput) (taskOutputOutput, error) { + if in.TaskID == "" { + return taskOutputOutput{}, fmt.Errorf("task_id is required") + } + block := true + if in.Block != nil { + block = *in.Block + } + timeoutMs := 30_000 + if in.Timeout != nil { + timeoutMs = *in.Timeout + } + if timeoutMs < 0 { + timeoutMs = 0 + } + if !block { + task, err := readTaskSnapshot(runtime, in.TaskID) + if err != nil { + return taskOutputOutput{}, err + } + status := "success" + if task.Status == "running" { + status = "not_ready" + } + return taskOutputOutput{ + RetrievalStatus: status, + Task: task, + }, nil + } + deadline := time.Now().Add(time.Duration(timeoutMs) * time.Millisecond) + for { + task, err := readTaskSnapshot(runtime, in.TaskID) + if err != nil { + return taskOutputOutput{}, err + } + if task.Status != "running" { + return taskOutputOutput{ + RetrievalStatus: "success", + Task: task, + }, nil + } + if timeoutMs == 0 || time.Now().After(deadline) { + return taskOutputOutput{ + RetrievalStatus: "timeout", + Task: task, + }, nil + } + select { + case <-ctx.Done(): + return taskOutputOutput{}, ctx.Err() + case 
<-time.After(100 * time.Millisecond): + } + } + }, + function.WithName(toolTaskOutput), + function.WithDescription(taskOutputDescription()), + ), nil +} + +func readTaskSnapshot(runtime *runtime, taskID string) (*taskOutputTask, error) { + task, err := snapshotBackgroundTask(runtime, taskID) + if err != nil { + return nil, err + } + outputBytes, err := os.ReadFile(task.OutputPath) + if err != nil && !os.IsNotExist(err) { + return nil, err + } + return &taskOutputTask{ + TaskID: task.ID, + TaskType: task.Type, + Status: task.Status, + Description: task.Command, + Output: string(outputBytes), + ExitCode: task.ExitCode, + }, nil +} + +func snapshotBackgroundTask(runtime *runtime, taskID string) (*backgroundTask, error) { + runtime.taskState.mu.Lock() + defer runtime.taskState.mu.Unlock() + task := runtime.taskState.tasks[taskID] + if task == nil { + return nil, fmt.Errorf("no task found with ID: %s", taskID) + } + taskCopy := *task + if task.ExitCode != nil { + exitCode := *task.ExitCode + taskCopy.ExitCode = &exitCode + } + return &taskCopy, nil +} + +func taskOutputDescription() string { + return fmt.Sprintf(`Read output from a running or completed background task. + +Usage: +- Use this tool with a task ID returned by %s. +- By default the tool blocks until the task finishes or the timeout expires. +- Set block=false to poll without waiting. +- The response includes the task status, captured output, and exit code when available.`, toolBash) +} diff --git a/tool/claudecode/task_stop.go b/tool/claudecode/task_stop.go new file mode 100644 index 000000000..b531778af --- /dev/null +++ b/tool/claudecode/task_stop.go @@ -0,0 +1,76 @@ +// +// Tencent is pleased to support the open source community by making +// trpc-agent-go available. +// +// Copyright (C) 2025 Tencent. All rights reserved. +// +// trpc-agent-go is licensed under the Apache License Version 2.0. 
+// + +package claudecode + +import ( + "context" + "fmt" + "strings" + + "trpc.group/trpc-go/trpc-agent-go/tool" + "trpc.group/trpc-go/trpc-agent-go/tool/function" +) + +func newTaskStopTool(runtime *runtime) (tool.Tool, error) { + return function.NewFunctionTool( + func(_ context.Context, in taskStopInput) (taskStopOutput, error) { + taskID := strings.TrimSpace(in.TaskID) + if taskID == "" { + taskID = strings.TrimSpace(in.ShellID) + } + if taskID == "" { + return taskStopOutput{}, fmt.Errorf("Missing required parameter: task_id") + } + runtime.taskState.mu.Lock() + task := runtime.taskState.tasks[taskID] + if task == nil { + runtime.taskState.mu.Unlock() + return taskStopOutput{}, fmt.Errorf("No task found with ID: %s", taskID) + } + if task.Status != "running" { + status := task.Status + runtime.taskState.mu.Unlock() + return taskStopOutput{}, fmt.Errorf("Task %s is not running (status: %s)", taskID, status) + } + process := task.Process + command := task.Command + taskType := task.Type + runtime.taskState.mu.Unlock() + if process == nil { + return taskStopOutput{}, fmt.Errorf("Task %s has no running process", taskID) + } + if err := process.Kill(); err != nil { + return taskStopOutput{}, err + } + runtime.taskState.mu.Lock() + if current := runtime.taskState.tasks[taskID]; current != nil { + current.Status = "killed" + } + runtime.taskState.mu.Unlock() + return taskStopOutput{ + Message: fmt.Sprintf("Successfully stopped task: %s (%s)", taskID, command), + TaskID: taskID, + TaskType: taskType, + Command: command, + }, nil + }, + function.WithName(toolTaskStop), + function.WithDescription(taskStopDescription()), + ), nil +} + +func taskStopDescription() string { + return fmt.Sprintf(`Stop a running background task by ID. + +Usage: +- Use this tool to terminate a task started by %s with run_in_background=true. +- Pass task_id directly. shell_id is accepted as a compatibility alias. 
+- The tool only succeeds for tasks that are still running.`, toolBash) +} diff --git a/tool/claudecode/toolset.go b/tool/claudecode/toolset.go new file mode 100644 index 000000000..b1c84bf59 --- /dev/null +++ b/tool/claudecode/toolset.go @@ -0,0 +1,113 @@ +// +// Tencent is pleased to support the open source community by making +// trpc-agent-go available. +// +// Copyright (C) 2025 Tencent. All rights reserved. +// +// trpc-agent-go is licensed under the Apache License Version 2.0. +// + +package claudecode + +import ( + "path/filepath" + "strings" + + "trpc.group/trpc-go/trpc-agent-go/tool" +) + +// NewToolSet constructs a Claude Code-compatible toolset. +func NewToolSet(opts ...Option) (tool.ToolSet, error) { + options := toolSetOptions{ + name: defaultToolSetName, + webFetch: WebFetchOptions{ + AllowAll: true, + }, + } + for _, opt := range opts { + opt(&options) + } + baseDir := strings.TrimSpace(options.baseDir) + if baseDir == "" { + baseDir = "." + } + baseAbs, err := filepath.Abs(baseDir) + if err != nil { + return nil, err + } + runtime := newToolRuntime(baseAbs, options.maxFileSize) + cc := &compositeToolSet{ + name: options.name, + } + if strings.TrimSpace(cc.name) == "" { + cc.name = defaultToolSetName + } + if err := appendCoreTools(cc, runtime, options.readOnly); err != nil { + return nil, err + } + if err := appendWebTools(cc, options); err != nil { + return nil, err + } + return cc, nil +} + +func newToolRuntime(baseAbs string, maxFileSize int64) *runtime { + return &runtime{ + baseDir: baseAbs, + maxFileSize: maxFileSize, + fileState: &fileState{ + views: map[string]fileView{}, + }, + taskState: &taskState{ + tasks: map[string]*backgroundTask{}, + }, + } +} + +func appendCoreTools(cc *compositeToolSet, rt *runtime, readOnly bool) error { + coreTools := []func(*runtime) (tool.Tool, error){ + newBashTool, + newTaskStopTool, + newTaskOutputTool, + newReadTool, + newGlobTool, + newGrepTool, + } + for _, buildTool := range coreTools { + builtTool, err 
:= buildTool(rt) + if err != nil { + return err + } + cc.tools = append(cc.tools, builtTool) + } + if readOnly { + return nil + } + writeTool, err := newWriteTool(rt) + if err != nil { + return err + } + editTool, err := newEditTool(rt) + if err != nil { + return err + } + notebookEditTool, err := newNotebookEditTool(rt) + if err != nil { + return err + } + cc.tools = append(cc.tools, writeTool, editTool, notebookEditTool) + return nil +} + +func appendWebTools(cc *compositeToolSet, options toolSetOptions) error { + webFetchTool, err := newWebFetchTool(options.webFetch) + if err != nil { + return err + } + webSearchTool, err := newWebSearchTool(options.webSearch) + if err != nil { + return err + } + cc.tools = append(cc.tools, webFetchTool, webSearchTool) + return nil +} diff --git a/tool/claudecode/types.go b/tool/claudecode/types.go new file mode 100644 index 000000000..d946701ca --- /dev/null +++ b/tool/claudecode/types.go @@ -0,0 +1,282 @@ +// +// Tencent is pleased to support the open source community by making +// trpc-agent-go available. +// +// Copyright (C) 2025 Tencent. All rights reserved. +// +// trpc-agent-go is licensed under the Apache License Version 2.0. 
+// + +package claudecode + +import ( + "context" + "os" + "sync" + + "trpc.group/trpc-go/trpc-agent-go/tool" +) + +type compositeToolSet struct { + name string + tools []tool.Tool +} + +type runtime struct { + mu sync.RWMutex + baseDir string + maxFileSize int64 + fileState *fileState + taskState *taskState +} + +type fileView struct { + Content string + Timestamp int64 + Offset *int + Limit *int + Pages string + IsPartialView bool + FromRead bool +} + +type fileState struct { + mu sync.Mutex + views map[string]fileView +} + +type taskState struct { + mu sync.Mutex + tasks map[string]*backgroundTask +} + +type backgroundTask struct { + ID string + Command string + Type string + OutputPath string + Process *os.Process + Status string + ExitCode *int +} + +type localFileSnapshot struct { + Exists bool + Path string + Raw []byte + Content string + Mode os.FileMode + Timestamp int64 + Encoding string + LineEnding string + MediaType string + OriginalSize int64 +} + +type patchHunk struct { + OldStart int `json:"oldStart"` + OldLines int `json:"oldLines"` + NewStart int `json:"newStart"` + NewLines int `json:"newLines"` + Lines []string `json:"lines"` +} + +type bashInput struct { + Command string `json:"command"` + Timeout *int `json:"timeout,omitempty"` + RunInBackground bool `json:"run_in_background,omitempty"` +} + +type bashOutput struct { + Command string `json:"command"` + ExitCode int `json:"exitCode"` + Stdout string `json:"stdout,omitempty"` + Stderr string `json:"stderr,omitempty"` + Output string `json:"output,omitempty"` + DurationMs int64 `json:"durationMs"` + TimedOut bool `json:"timedOut,omitempty"` + BackgroundTaskID string `json:"taskId,omitempty"` + OutputPath string `json:"outputPath,omitempty"` +} + +type taskStopInput struct { + TaskID string `json:"task_id,omitempty"` + ShellID string `json:"shell_id,omitempty"` +} + +type taskStopOutput struct { + Message string `json:"message"` + TaskID string `json:"task_id"` + TaskType string 
`json:"task_type"`
+	Command  string `json:"command,omitempty"`
+}
+
+type taskOutputInput struct {
+	TaskID  string `json:"task_id"`
+	Block   *bool  `json:"block,omitempty"`
+	Timeout *int   `json:"timeout,omitempty"`
+}
+
+type taskOutputTask struct {
+	TaskID      string `json:"task_id"`
+	TaskType    string `json:"task_type"`
+	Status      string `json:"status"`
+	Description string `json:"description"`
+	Output      string `json:"output"`
+	ExitCode    *int   `json:"exitCode,omitempty"`
+	Error       string `json:"error,omitempty"`
+}
+
+type taskOutputOutput struct {
+	RetrievalStatus string          `json:"retrieval_status"`
+	Task            *taskOutputTask `json:"task"`
+}
+
+type readInput struct {
+	FilePath string `json:"file_path"`
+	Offset   *int   `json:"offset,omitempty"`
+	Limit    *int   `json:"limit,omitempty"`
+	Pages    string `json:"pages,omitempty"`
+}
+
+type readFile struct {
+	FilePath     string           `json:"filePath,omitempty"`
+	Content      string           `json:"content,omitempty"`
+	NumLines     int              `json:"numLines,omitempty"`
+	StartLine    int              `json:"startLine,omitempty"`
+	TotalLines   int              `json:"totalLines,omitempty"`
+	Base64       string           `json:"base64,omitempty"`
+	Type         string           `json:"type,omitempty"`
+	MediaType    string           `json:"mediaType,omitempty"`
+	OriginalSize int64            `json:"originalSize,omitempty"`
+	Count        int              `json:"count,omitempty"`
+	OutputDir    string           `json:"outputDir,omitempty"`
+	Cells        []map[string]any `json:"cells,omitempty"`
+}
+
+type readOutput struct {
+	Type string    `json:"type"`
+	File *readFile `json:"file,omitempty"`
+}
+
+type writeInput struct {
+	FilePath string `json:"file_path"`
+	Content  string `json:"content"`
+}
+
+type writeOutput struct {
+	Type            string      `json:"type"`
+	FilePath        string      `json:"filePath"`
+	Content         string      `json:"content"`
+	StructuredPatch []patchHunk `json:"structuredPatch"`
+	OriginalFile    *string     `json:"originalFile"`
+}
+
+type editInput struct {
+	FilePath   string `json:"file_path"`
+	OldString  string `json:"old_string"`
+	NewString  string `json:"new_string"`
+	ReplaceAll bool   `json:"replace_all,omitempty"`
+}
+
+type editOutput struct {
+	FilePath        string      `json:"filePath"`
+	OldString       string      `json:"oldString"`
+	NewString       string      `json:"newString"`
+	OriginalFile    string      `json:"originalFile"`
+	StructuredPatch []patchHunk `json:"structuredPatch"`
+	UserModified    bool        `json:"userModified"`
+	ReplaceAll      bool        `json:"replaceAll"`
+}
+
+type globInput struct {
+	Pattern string `json:"pattern"`
+	Path    string `json:"path,omitempty"`
+}
+
+type globOutput struct {
+	DurationMs int64    `json:"durationMs"`
+	NumFiles   int      `json:"numFiles"`
+	Filenames  []string `json:"filenames"`
+	Truncated  bool     `json:"truncated"`
+}
+
+type grepInput struct {
+	Pattern     string `json:"pattern"`
+	Path        string `json:"path,omitempty"`
+	Glob        string `json:"glob,omitempty"`
+	OutputMode  string `json:"output_mode,omitempty"`
+	Before      *int   `json:"-B,omitempty"`
+	After       *int   `json:"-A,omitempty"`
+	Context     *int   `json:"-C,omitempty"`
+	ContextAlt  *int   `json:"context,omitempty"`
+	ShowLineNum *bool  `json:"-n,omitempty"`
+	IgnoreCase  *bool  `json:"-i,omitempty"`
+	Type        string `json:"type,omitempty"`
+	HeadLimit   *int   `json:"head_limit,omitempty"`
+	Offset      *int   `json:"offset,omitempty"`
+	Multiline   bool   `json:"multiline,omitempty"`
+}
+
+type grepOutput struct {
+	Mode          string   `json:"mode,omitempty"`
+	NumFiles      int      `json:"numFiles"`
+	Filenames     []string `json:"filenames"`
+	Content       string   `json:"content,omitempty"`
+	NumLines      int      `json:"numLines,omitempty"`
+	NumMatches    int      `json:"numMatches,omitempty"`
+	AppliedLimit  *int     `json:"appliedLimit,omitempty"`
+	AppliedOffset int      `json:"appliedOffset,omitempty"`
+}
+
+type webFetchInput struct {
+	URL    string `json:"url"`
+	Prompt string `json:"prompt"`
+}
+
+type webFetchOutput struct {
+	Bytes      int    `json:"bytes"`
+	Code       int    `json:"code"`
+	CodeText   string `json:"codeText"`
+	Result     string `json:"result"`
+	DurationMs int64  `json:"durationMs"`
+	URL        string `json:"url"`
+}
+
+type webSearchInput struct {
+	Query          string   `json:"query"`
+	AllowedDomains []string `json:"allowed_domains,omitempty"`
+	BlockedDomains []string `json:"blocked_domains,omitempty"`
+}
+
+type webSearchHit struct {
+	Title   string `json:"title"`
+	URL     string `json:"url"`
+	Snippet string `json:"snippet,omitempty"`
+}
+
+type webSearchResult struct {
+	ToolUseID string         `json:"tool_use_id,omitempty"`
+	Content   []webSearchHit `json:"content,omitempty"`
+	Text      string         `json:"text,omitempty"`
+}
+
+type webSearchOutput struct {
+	Query           string            `json:"query"`
+	Results         []webSearchResult `json:"results"`
+	DurationSeconds float64           `json:"durationSeconds"`
+}
+
+func (s *compositeToolSet) Tools(ctx context.Context) []tool.Tool {
+	out := make([]tool.Tool, 0, len(s.tools))
+	out = append(out, s.tools...)
+	return out
+}
+
+func (s *compositeToolSet) Close() error {
+	return nil
+}
+
+func (s *compositeToolSet) Name() string {
+	return s.name
+}
diff --git a/tool/claudecode/web_fetch.go b/tool/claudecode/web_fetch.go
new file mode 100644
index 000000000..db32c9d9c
--- /dev/null
+++ b/tool/claudecode/web_fetch.go
@@ -0,0 +1,167 @@
+//
+// Tencent is pleased to support the open source community by making
+// trpc-agent-go available.
+//
+// Copyright (C) 2025 Tencent. All rights reserved.
+//
+// trpc-agent-go is licensed under the Apache License Version 2.0.
+//
+
+package claudecode
+
+import (
+	"context"
+	"fmt"
+	"net/http"
+	"net/url"
+	"strings"
+	"time"
+
+	"trpc.group/trpc-go/trpc-agent-go/tool"
+	"trpc.group/trpc-go/trpc-agent-go/tool/function"
+)
+
+func newWebFetchTool(options WebFetchOptions) (tool.Tool, error) {
+	toolOptions := options
+	client := &http.Client{
+		Timeout: defaultHTTPTimeout,
+		CheckRedirect: func(_ *http.Request, _ []*http.Request) error {
+			return http.ErrUseLastResponse
+		},
+	}
+	if toolOptions.Timeout > 0 {
+		client.Timeout = toolOptions.Timeout
+	}
+	return function.NewFunctionTool(
+		func(ctx context.Context, in webFetchInput) (webFetchOutput, error) {
+			if strings.TrimSpace(in.Prompt) == "" {
+				return webFetchOutput{}, fmt.Errorf("prompt is required")
+			}
+			if !matchSearchDomainFilters(in.URL, toolOptions.AllowedDomains, toolOptions.BlockedDomains) {
+				return webFetchOutput{}, fmt.Errorf("url is blocked by domain policy: %s", in.URL)
+			}
+			start := time.Now()
+			finalURL, statusCode, statusText, body, contentType, err := fetchURL(ctx, client, in.URL, toolOptions)
+			if err != nil {
+				return webFetchOutput{}, err
+			}
+			content := string(body)
+			if strings.Contains(strings.ToLower(contentType), "html") {
+				content = extractHTMLText(body)
+			}
+			result, err := processFetchedContent(ctx, toolOptions, in, content, contentType)
+			if err != nil {
+				return webFetchOutput{}, err
+			}
+			return webFetchOutput{
+				Bytes:      len(body),
+				Code:       statusCode,
+				CodeText:   statusText,
+				Result:     result,
+				DurationMs: max(time.Since(start).Milliseconds(), 1),
+				URL:        finalURL,
+			}, nil
+		},
+		function.WithName(toolWebFetch),
+		function.WithDescription(webFetchDescription()),
+	), nil
+}
+
+func fetchURL(
+	ctx context.Context,
+	client *http.Client,
+	rawURL string,
+	options WebFetchOptions,
+) (string, int, string, []byte, string, error) {
+	currentURL := rawURL
+	originalHost := searchURLHost(rawURL)
+	for redirects := 0; redirects < 5; redirects++ {
+		req, err := http.NewRequestWithContext(ctx, http.MethodGet, currentURL, nil)
+		if err != nil {
+			return "", 0, "", nil, "", err
+		}
+		resp, err := client.Do(req)
+		if err != nil {
+			return "", 0, "", nil, "", err
+		}
+		if resp.StatusCode >= http.StatusMultipleChoices && resp.StatusCode < http.StatusBadRequest {
+			location := resp.Header.Get("Location")
+			_ = resp.Body.Close()
+			if location == "" {
+				return currentURL, resp.StatusCode, resp.Status, nil, "", nil
+			}
+			nextURL, err := resolveRedirectURL(currentURL, location)
+			if err != nil {
+				return "", 0, "", nil, "", err
+			}
+			if searchURLHost(nextURL) != originalHost {
+				message := fmt.Sprintf("REDIRECT DETECTED: The URL redirects to a different host.\n\nOriginal URL: %s\nRedirect URL: %s\nStatus: %d %s\n\nTo complete your request, fetch the redirected URL directly.", rawURL, nextURL, resp.StatusCode, resp.Status)
+				return rawURL, resp.StatusCode, resp.Status, []byte(message), "text/plain; charset=utf-8", nil
+			}
+			currentURL = nextURL
+			continue
+		}
+		body, err := readHTTPBody(resp, options.MaxContentLength, options.MaxTotalContentLength)
+		contentType := resp.Header.Get("Content-Type")
+		statusCode := resp.StatusCode
+		statusText := resp.Status
+		finalURL := resp.Request.URL.String()
+		_ = resp.Body.Close()
+		if err != nil {
+			return "", 0, "", nil, "", err
+		}
+		return finalURL, statusCode, statusText, body, contentType, nil
+	}
+	return "", 0, "", nil, "", fmt.Errorf("too many redirects")
+}
+
+func resolveRedirectURL(baseURL string, location string) (string, error) {
+	baseParsed, err := url.Parse(baseURL)
+	if err != nil {
+		return "", err
+	}
+	locationParsed, err := url.Parse(location)
+	if err != nil {
+		return "", err
+	}
+	return baseParsed.ResolveReference(locationParsed).String(), nil
+}
+
+func processFetchedContent(
+	ctx context.Context,
+	options WebFetchOptions,
+	in webFetchInput,
+	content string,
+	contentType string,
+) (string, error) {
+	if options.PromptProcessor != nil {
+		return options.PromptProcessor(ctx, WebFetchPromptInput{
+			URL:         in.URL,
+			Prompt:      in.Prompt,
+			Content:     content,
+			ContentType: contentType,
+		})
+	}
+	return trimFetchResult(content, 4000), nil
+}
+
+func webFetchDescription() string {
+	return fmt.Sprintf(`Fetch one URL and process the fetched content according to a prompt.
+
+Usage:
+- url must be a fully formed URL.
+- prompt is required and should describe the information to extract or summarize.
+- HTML pages are converted into extracted text before prompt processing.
+- When PromptProcessor is configured, the fetched content is passed to it for prompt-aware post-processing.
+- Same-host redirects are followed automatically. Cross-host redirects return a redirect notice and require a second %s call with the redirected URL.
+- Prefer %s first when you need to discover relevant pages, then use %s to read a selected page in depth.
+- For GitHub repository, issue, or pull request metadata, prefer %s with gh when possible.`, toolWebFetch, toolWebSearch, toolWebFetch, toolBash)
+}
+
+func trimFetchResult(content string, limit int) string {
+	trimmed := strings.TrimSpace(content)
+	if limit <= 0 || len(trimmed) <= limit {
+		return trimmed
+	}
+	return strings.TrimSpace(trimmed[:limit]) + "\n\n[Content truncated due to length.]"
+}
diff --git a/tool/claudecode/web_search.go b/tool/claudecode/web_search.go
new file mode 100644
index 000000000..1d8a2d579
--- /dev/null
+++ b/tool/claudecode/web_search.go
@@ -0,0 +1,371 @@
+//
+// Tencent is pleased to support the open source community by making
+// trpc-agent-go available.
+//
+// Copyright (C) 2025 Tencent. All rights reserved.
+//
+// trpc-agent-go is licensed under the Apache License Version 2.0.
+//
+
+package claudecode
+
+import (
+	"bytes"
+	"context"
+	"encoding/json"
+	"fmt"
+	"net/http"
+	"net/url"
+	"os"
+	"strconv"
+	"strings"
+	"time"
+
+	"github.com/google/uuid"
+	"golang.org/x/net/html"
+	"trpc.group/trpc-go/trpc-agent-go/tool"
+	"trpc.group/trpc-go/trpc-agent-go/tool/function"
+)
+
+type codeSearchBackend interface {
+	search(context.Context, webSearchInput) ([]webSearchHit, error)
+}
+
+type duckDuckGoSearchBackend struct {
+	client    *http.Client
+	baseURL   string
+	userAgent string
+	size      int
+	offset    int
+}
+
+type googleSearchBackend struct {
+	client  *http.Client
+	options *WebSearchOptions
+}
+
+func newWebSearchTool(options *WebSearchOptions) (tool.Tool, error) {
+	target, err := newSearchBackend(options)
+	if err != nil {
+		return nil, err
+	}
+	return function.NewFunctionTool(
+		func(ctx context.Context, in webSearchInput) (webSearchOutput, error) {
+			if strings.TrimSpace(in.Query) == "" {
+				return webSearchOutput{}, fmt.Errorf("query is required")
+			}
+			if len(in.AllowedDomains) > 0 && len(in.BlockedDomains) > 0 {
+				return webSearchOutput{}, fmt.Errorf("cannot specify both allowed_domains and blocked_domains")
+			}
+			start := time.Now()
+			hits, err := target.search(ctx, in)
+			results := make([]webSearchResult, 0, 1)
+			if len(hits) > 0 {
+				results = append(results, webSearchResult{
+					ToolUseID: uuid.NewString(),
+					Content:   hits,
+				})
+			}
+			return webSearchOutput{
+				Query:           in.Query,
+				Results:         results,
+				DurationSeconds: max(time.Since(start).Seconds(), 0.001),
+			}, err
+		},
+		function.WithName(toolWebSearch),
+		function.WithDescription(webSearchDescription()),
+	), nil
+}
+
+func newSearchBackend(options *WebSearchOptions) (codeSearchBackend, error) {
+	provider := "duckduckgo"
+	if options != nil && strings.TrimSpace(options.Provider) != "" {
+		provider = strings.ToLower(strings.TrimSpace(options.Provider))
+	}
+	client := &http.Client{Timeout: defaultHTTPTimeout}
+	if options != nil && options.Timeout > 0 {
+		client.Timeout = options.Timeout
+	}
+	switch provider {
+	case "duckduckgo":
+		baseURL := "https://html.duckduckgo.com/html/"
+		userAgent := ""
+		if options != nil {
+			if strings.TrimSpace(options.BaseURL) != "" {
+				baseURL = strings.TrimSpace(options.BaseURL)
+			}
+			userAgent = strings.TrimSpace(options.UserAgent)
+		}
+		return &duckDuckGoSearchBackend{
+			client:    client,
+			baseURL:   baseURL,
+			userAgent: userAgent,
+			size:      max(0, webSearchSize(options)),
+			offset:    max(0, webSearchOffset(options)),
+		}, nil
+	case "google":
+		return &googleSearchBackend{client: client, options: options}, nil
+	default:
+		return nil, fmt.Errorf("unsupported web search provider: %s", provider)
+	}
+}
+
+func webSearchDescription() string {
+	return fmt.Sprintf(`Search the web for current information.
+
+Usage:
+- Use %s for open-ended discovery, current events, recent documentation, or when you do not yet know the exact page to fetch.
+- query is required.
+- allowed_domains and blocked_domains may constrain the search, but you must not set both at the same time.
+- Results contain titles, URLs, and snippets, grouped into Claude-style search result blocks.
+- After choosing a relevant result, use %s to read the destination page in detail.`, toolWebSearch, toolWebFetch)
+}
+
+func (b *duckDuckGoSearchBackend) search(
+	ctx context.Context,
+	in webSearchInput,
+) ([]webSearchHit, error) {
+	u, err := url.Parse(b.baseURL)
+	if err != nil {
+		return nil, err
+	}
+	query := u.Query()
+	query.Set("q", in.Query)
+	u.RawQuery = query.Encode()
+	req, err := http.NewRequestWithContext(ctx, http.MethodGet, u.String(), nil)
+	if err != nil {
+		return nil, err
+	}
+	if b.userAgent != "" {
+		req.Header.Set("User-Agent", b.userAgent)
+	}
+	resp, err := b.client.Do(req)
+	if err != nil {
+		return nil, err
+	}
+	defer resp.Body.Close()
+	if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
+		body, _ := readHTTPBody(resp, 16*1024, 16*1024)
+		return nil, fmt.Errorf("duckduckgo search request failed: status=%d body=%s", resp.StatusCode, body)
+	}
+	body, err := readHTTPBody(resp, 512*1024, 512*1024)
+	if err != nil {
+		return nil, err
+	}
+	return parseDuckDuckGoHTML(body, in, b.offset, b.size), nil
+}
+
+func parseDuckDuckGoHTML(body []byte, in webSearchInput, offset int, limit int) []webSearchHit {
+	doc, err := html.Parse(bytes.NewReader(body))
+	if err != nil {
+		return nil
+	}
+	type partialResult struct {
+		Title   string
+		URL     string
+		Snippet string
+	}
+	results := make([]partialResult, 0, 10)
+	var visit func(*html.Node)
+	visit = func(node *html.Node) {
+		if node.Type == html.ElementNode && node.Data == "a" && htmlHasClass(node, "result__a") {
+			title := strings.TrimSpace(htmlNodeText(node))
+			link := ""
+			for _, attr := range node.Attr {
+				if attr.Key == "href" {
+					link = strings.TrimSpace(attr.Val)
+					break
+				}
+			}
+			results = append(results, partialResult{Title: title, URL: link})
+		}
+		if node.Type == html.ElementNode && htmlHasClass(node, "result__snippet") {
+			if len(results) > 0 && results[len(results)-1].Snippet == "" {
+				results[len(results)-1].Snippet = strings.TrimSpace(htmlNodeText(node))
+			}
+		}
+		for child := node.FirstChild; child != nil; child = child.NextSibling {
+			visit(child)
+		}
+	}
+	visit(doc)
+	hits := make([]webSearchHit, 0, len(results))
+	seen := map[string]struct{}{}
+	for _, item := range results {
+		normalizedURL := normalizeDuckDuckGoResultURL(item.URL)
+		if normalizedURL == "" || !matchSearchDomainFilters(normalizedURL, in.AllowedDomains, in.BlockedDomains) {
+			continue
+		}
+		if _, ok := seen[normalizedURL]; ok {
+			continue
+		}
+		seen[normalizedURL] = struct{}{}
+		hits = append(hits, webSearchHit{
+			Title:   item.Title,
+			URL:     normalizedURL,
+			Snippet: collapseWhitespace(item.Snippet),
+		})
+	}
+	return applySearchWindow(hits, offset, limit)
+}
+
+func normalizeDuckDuckGoResultURL(rawURL string) string {
+	trimmed := strings.TrimSpace(rawURL)
+	if trimmed == "" {
+		return ""
+	}
+	parsed, err := url.Parse(trimmed)
+	if err != nil {
+		return trimmed
+	}
+	if uddg := strings.TrimSpace(parsed.Query().Get("uddg")); uddg != "" {
+		return uddg
+	}
+	return trimmed
+}
+
+func htmlHasClass(node *html.Node, className string) bool {
+	for _, attr := range node.Attr {
+		if attr.Key != "class" {
+			continue
+		}
+		for _, candidate := range strings.Fields(attr.Val) {
+			if candidate == className {
+				return true
+			}
+		}
+	}
+	return false
+}
+
+func htmlNodeText(node *html.Node) string {
+	parts := make([]string, 0, 8)
+	var visit func(*html.Node)
+	visit = func(current *html.Node) {
+		if current.Type == html.TextNode {
+			text := strings.TrimSpace(current.Data)
+			if text != "" {
+				parts = append(parts, text)
+			}
+		}
+		for child := current.FirstChild; child != nil; child = child.NextSibling {
+			visit(child)
+		}
+	}
+	visit(node)
+	return strings.Join(parts, " ")
+}
+
+func (b *googleSearchBackend) search(
+	ctx context.Context,
+	in webSearchInput,
+) ([]webSearchHit, error) {
+	if b.options == nil {
+		return nil, fmt.Errorf("google search config is required")
+	}
+	apiKey := strings.TrimSpace(b.options.APIKey)
+	if apiKey == "" {
+		apiKey = strings.TrimSpace(os.Getenv(envGoogleAPIKey))
+	}
+	engineID := strings.TrimSpace(b.options.EngineID)
+	if engineID == "" {
+		engineID = strings.TrimSpace(os.Getenv(envGoogleEngineID))
+	}
+	if apiKey == "" || engineID == "" {
+		return nil, fmt.Errorf("google search requires api_key and engine_id")
+	}
+	baseURL := strings.TrimSpace(b.options.BaseURL)
+	if baseURL == "" {
+		baseURL = "https://www.googleapis.com/customsearch/v1"
+	}
+	u, err := url.Parse(baseURL)
+	if err != nil {
+		return nil, err
+	}
+	query := u.Query()
+	query.Set("key", apiKey)
+	query.Set("cx", engineID)
+	query.Set("q", in.Query)
+	if b.options.Size > 0 {
+		query.Set("num", strconv.Itoa(b.options.Size))
+	}
+	if b.options.Offset > 0 {
+		query.Set("start", strconv.Itoa(b.options.Offset+1))
+	}
+	if strings.TrimSpace(b.options.Lang) != "" {
+		query.Set("lr", "lang_"+strings.TrimSpace(b.options.Lang))
+	}
+	u.RawQuery = query.Encode()
+	req, err := http.NewRequestWithContext(ctx, http.MethodGet, u.String(), nil)
+	if err != nil {
+		return nil, err
+	}
+	resp, err := b.client.Do(req)
+	if err != nil {
+		return nil, err
+	}
+	defer resp.Body.Close()
+	if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
+		body, _ := readHTTPBody(resp, 16*1024, 16*1024)
+		return nil, fmt.Errorf("google search request failed: status=%d body=%s", resp.StatusCode, body)
+	}
+	var decoded struct {
+		Items []struct {
+			Link    string `json:"link"`
+			Title   string `json:"title"`
+			Snippet string `json:"snippet"`
+		} `json:"items"`
+	}
+	if err := json.NewDecoder(resp.Body).Decode(&decoded); err != nil {
+		return nil, err
+	}
+	hits := make([]webSearchHit, 0, len(decoded.Items))
+	seen := map[string]struct{}{}
+	for _, item := range decoded.Items {
+		link := strings.TrimSpace(item.Link)
+		if link == "" || !matchSearchDomainFilters(link, in.AllowedDomains, in.BlockedDomains) {
+			continue
+		}
+		if _, ok := seen[link]; ok {
+			continue
+		}
+		seen[link] = struct{}{}
+		hits = append(hits, webSearchHit{
+			Title:   item.Title,
+			URL:     link,
+			Snippet: item.Snippet,
+		})
+	}
+	return hits, nil
+}
+
+func applySearchWindow(hits []webSearchHit, offset int, limit int) []webSearchHit {
+	if len(hits) == 0 {
+		return nil
+	}
+	if offset < 0 {
+		offset = 0
+	}
+	if offset >= len(hits) {
+		return nil
+	}
+	hits = hits[offset:]
+	if limit > 0 && limit < len(hits) {
+		return hits[:limit]
+	}
+	return hits
+}
+
+func webSearchSize(options *WebSearchOptions) int {
+	if options == nil || options.Size <= 0 {
+		return 0
+	}
+	return options.Size
+}
+
+func webSearchOffset(options *WebSearchOptions) int {
+	if options == nil || options.Offset <= 0 {
+		return 0
+	}
+	return options.Offset
+}
diff --git a/tool/claudecode/write.go b/tool/claudecode/write.go
new file mode 100644
index 000000000..07231ac91
--- /dev/null
+++ b/tool/claudecode/write.go
@@ -0,0 +1,82 @@
+//
+// Tencent is pleased to support the open source community by making
+// trpc-agent-go available.
+//
+// Copyright (C) 2025 Tencent. All rights reserved.
+//
+// trpc-agent-go is licensed under the Apache License Version 2.0.
+//
+
+package claudecode
+
+import (
+	"context"
+	"fmt"
+
+	"trpc.group/trpc-go/trpc-agent-go/tool"
+	"trpc.group/trpc-go/trpc-agent-go/tool/function"
+)
+
+func newWriteTool(runtime *runtime) (tool.Tool, error) {
+	return function.NewFunctionTool(
+		func(_ context.Context, in writeInput) (writeOutput, error) {
+			baseDir := runtime.currentBaseDir()
+			_, absPath, err := normalizePath(baseDir, in.FilePath)
+			if err != nil {
+				return writeOutput{}, err
+			}
+			runtime.fileState.mu.Lock()
+			defer runtime.fileState.mu.Unlock()
+			snapshot, err := readLocalFileSnapshot(absPath, runtime.maxFileSize)
+			if err != nil {
+				return writeOutput{}, err
+			}
+			writeType := "create"
+			var originalFile *string
+			encoding := "utf8"
+			lineEnding := "\n"
+			mode := snapshot.Mode
+			if snapshot.Exists {
+				writeType = "update"
+				originalFile = &snapshot.Content
+				encoding = snapshot.Encoding
+				lineEnding = snapshot.LineEnding
+				if err := ensureWriteAllowed(absPath, snapshot, runtime.fileState); err != nil {
+					return writeOutput{}, err
+				}
+			}
+			if err := writeLocalFile(absPath, in.Content, mode, encoding, lineEnding); err != nil {
+				return writeOutput{}, err
+			}
+			current, err := readLocalFileSnapshot(absPath, runtime.maxFileSize)
+			if err != nil {
+				return writeOutput{}, err
+			}
+			storeReadView(runtime.fileState, absPath, current.Content, current.Timestamp, nil, nil, "", false, false)
+			previous := ""
+			if originalFile != nil {
+				previous = *originalFile
+			}
+			return writeOutput{
+				Type:            writeType,
+				FilePath:        absPath,
+				Content:         in.Content,
+				StructuredPatch: buildStructuredPatch(previous, in.Content),
+				OriginalFile:    originalFile,
+			}, nil
+		},
+		function.WithName(toolWrite),
+		function.WithDescription(writeDescription()),
+	), nil
}
+}
+
+func writeDescription() string {
+	return fmt.Sprintf(`Create or overwrite a file.
+
+Usage:
+- Use %s when you want to replace the whole file or create a brand-new file.
+- When overwriting an existing file, read it with %s first. A partial read is not enough to authorize a full overwrite.
+- Prefer %s when you only need a targeted replacement inside an existing text file.
+- Prefer %s when editing .ipynb files instead of rewriting notebook JSON manually.
+- Updates preserve the existing file mode, encoding, and line endings when possible.`, toolWrite, toolRead, toolEdit, toolNotebookEdit)
+}
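Reviewer note: the offset/limit windowing that `applySearchWindow` applies to both search backends can be sanity-checked in isolation. The sketch below is a standalone, stdlib-only copy of that function against a minimal hit type (the `hit` struct and the sample values are shortened stand-ins for `webSearchHit`, introduced only for this example).

```go
package main

import "fmt"

// hit is a minimal stand-in for webSearchHit in this standalone sketch.
type hit struct{ URL string }

// applySearchWindow mirrors the windowing in web_search.go: clamp a negative
// offset to zero, drop everything before the offset, then cap at limit when
// limit is positive and smaller than what remains.
func applySearchWindow(hits []hit, offset int, limit int) []hit {
	if len(hits) == 0 {
		return nil
	}
	if offset < 0 {
		offset = 0
	}
	if offset >= len(hits) {
		return nil
	}
	hits = hits[offset:]
	if limit > 0 && limit < len(hits) {
		return hits[:limit]
	}
	return hits
}

func main() {
	hits := []hit{{"a"}, {"b"}, {"c"}, {"d"}}
	window := applySearchWindow(hits, 1, 2)
	fmt.Println(len(window))         // window of size 2 starting at index 1
	fmt.Println(window[0].URL)       // first element after skipping one hit
	fmt.Println(applySearchWindow(hits, 10, 2) == nil) // offset past the end yields nil
}
```

Note that an offset at or past the end returns `nil` rather than an empty slice, and a non-positive limit means "no cap"; both behaviors are relied on by the DuckDuckGo backend's `size`/`offset` defaults of zero.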