diff --git a/.github/workflows/build-cli.yml b/.github/workflows/build-cli.yml index f91360a21f56..b18eb6c02c2a 100644 --- a/.github/workflows/build-cli.yml +++ b/.github/workflows/build-cli.yml @@ -17,7 +17,7 @@ on: ref: type: string required: false - default: 'refs/heads/main' + default: "" name: "Reusable workflow to build CLI" diff --git a/.github/workflows/bundle-desktop-intel.yml b/.github/workflows/bundle-desktop-intel.yml index ab8208d2d812..ee69ed7eb95f 100644 --- a/.github/workflows/bundle-desktop-intel.yml +++ b/.github/workflows/bundle-desktop-intel.yml @@ -82,7 +82,8 @@ jobs: - name: Checkout code uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 with: - ref: ${{ inputs.ref }} + # Only pass ref if it's explicitly set, otherwise let checkout action use its default behavior + ref: ${{ inputs.ref != '' && inputs.ref || '' }} fetch-depth: 0 # Update versions before build diff --git a/.github/workflows/bundle-desktop-linux.yml b/.github/workflows/bundle-desktop-linux.yml index 7bf7a45d80de..f13a15de8846 100644 --- a/.github/workflows/bundle-desktop-linux.yml +++ b/.github/workflows/bundle-desktop-linux.yml @@ -28,7 +28,8 @@ jobs: - name: Checkout repository uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # pin@v4 with: - ref: ${{ inputs.ref }} + # Only pass ref if it's explicitly set, otherwise let checkout action use its default behavior + ref: ${{ inputs.ref != '' && inputs.ref || '' }} fetch-depth: 0 # 2) Update versions before build diff --git a/.github/workflows/bundle-desktop-windows.yml b/.github/workflows/bundle-desktop-windows.yml index 29877fea1dd4..6dd21d93102a 100644 --- a/.github/workflows/bundle-desktop-windows.yml +++ b/.github/workflows/bundle-desktop-windows.yml @@ -45,7 +45,8 @@ jobs: - name: Checkout repository uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # pin@v4 with: - ref: ${{ inputs.ref }} + # Only pass ref if it's explicitly set, otherwise let checkout action use its default behavior + ref: ${{ inputs.ref != '' && inputs.ref || '' }} fetch-depth: 0 # 2) Configure AWS credentials for code signing diff --git a/.gitignore b/.gitignore index caab83d726c7..41f629f09366 100644 --- a/.gitignore +++ b/.gitignore @@ -31,6 +31,9 @@ ui/desktop/src/bin/goose_llm.dll # Hermit .hermit/ +# Claude +.claude + debug_*.txt # Docs diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 15f9f1465b52..5d540b9afd9f 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -67,6 +67,21 @@ you can talk to goose! You can now make changes in the code in ui/desktop to iterate on the GUI half of goose. +### Regenerating the OpenAPI schema + +The file `ui/desktop/openapi.json` is automatically generated during the build. +It is written by the `generate_schema` binary in `crates/goose-server`. +If you need to update the spec without starting the UI, run: + +``` +just generate-openapi +``` + +This command regenerates `ui/desktop/openapi.json` and then runs the UI's +`generate-api` script to rebuild the TypeScript client from that spec. + +Changes to the API should be made in the Rust source under `crates/goose-server/src/`. + ## Creating a fork To fork the repository: diff --git a/Cargo.lock b/Cargo.lock index f53e92031276..0e6c66c81644 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3434,7 +3434,7 @@ dependencies = [ [[package]] name = "goose" -version = "1.0.35" +version = "1.1.0" dependencies = [ "ahash", "anyhow", @@ -3473,6 +3473,7 @@ dependencies = [ "rand 0.8.5", "regex", "reqwest 0.12.12", + "rmcp", "serde", "serde_json", "serde_urlencoded", @@ -3490,6 +3491,7 @@ dependencies = [ "tracing", "tracing-subscriber", "url", + "urlencoding", "utoipa", "uuid", "webbrowser 0.8.15", @@ -3499,7 +3501,7 @@ dependencies = [ [[package]] name = "goose-bench" -version = "1.0.35" +version = "1.1.0" dependencies = [ "anyhow", "async-trait", @@ -3512,6 +3514,7 @@ dependencies = [ "once_cell", "paste", "regex", + "rmcp", "serde", "serde_json", "tokio", @@ -3523,7 +3526,7 @@ dependencies = [ [[package]] name = "goose-cli" -version = "1.0.35" +version = "1.1.0" dependencies = [ "anyhow", "async-trait", @@ -3553,6 +3556,7 @@ dependencies = [ "rand 0.8.5", "regex", "reqwest 0.12.12", + "rmcp", "rustyline", "serde", "serde_json", @@ -3568,14 +3572,13 @@ dependencies = [ "tracing", "tracing-appender", "tracing-subscriber", - "urlencoding", "webbrowser 1.0.4", "winapi", ] [[package]] name = "goose-ffi" -version = "1.0.35" +version = "1.1.0" dependencies = [ "cbindgen", "futures", @@ -3589,7 +3592,7 @@ dependencies = [ [[package]] name = "goose-llm" -version = "1.0.35" +version = "1.1.0" dependencies = [ "anyhow", "async-trait", @@ -3598,6 +3601,7 @@ dependencies = [ "criterion", "ctor", "dotenv", + "goose", "include_dir", "indoc 1.0.9", "lazy_static", @@ -3618,7 +3622,7 @@ dependencies = [ [[package]] name = "goose-mcp" -version = "1.0.35" +version = "1.1.0" dependencies = [ "anyhow", "async-trait", @@ -3626,6 +3630,7 @@ dependencies = [ "chrono", "docx-rs", "etcetera", + "glob", "google-apis-common", "google-docs1", "google-drive3", @@ -3646,6 +3651,7 @@ dependencies = [ "once_cell", "regex", "reqwest 0.11.27", + "rmcp", "serde", "serde_json", "serde_with", @@ -3660,7 +3666,6 @@ dependencies = [ "tracing-subscriber", "umya-spreadsheet", "url", - "urlencoding", "utoipa", "webbrowser 0.8.15", "which 6.0.3", @@ -3669,7 +3674,7 @@ dependencies = [ [[package]] name = "goose-server" -version = "1.0.35" +version = "1.1.0" dependencies = [ "anyhow", "async-trait", @@ -3690,6 +3695,7 @@ dependencies = [ "mcp-server", "once_cell", "reqwest 0.12.12", + "rmcp", "serde", "serde_json", "serde_yaml", @@ -5409,6 +5415,7 @@ dependencies = [ "async-trait", "base64 0.21.7", "chrono", + "rmcp", "schemars", "serde", "serde_json", @@ -5444,6 +5451,7 @@ dependencies = [ "mcp-core", "mcp-macros", "pin-project", + "rmcp", "schemars", "serde", "serde_json", @@ -7116,6 +7124,40 @@ dependencies = [ "windows-sys 0.52.0", ] +[[package]] +name = "rmcp" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "37f2048a81a7ff7e8ef6bc5abced70c3d9114c8f03d85d7aaaafd9fd04f12e9e" +dependencies = [ + "base64 0.22.1", + "chrono", + "futures", + "paste", + "pin-project-lite", + "rmcp-macros", + "schemars", + "serde", + "serde_json", + "thiserror 2.0.12", + "tokio", + "tokio-util", + "tracing", +] + +[[package]] +name = "rmcp-macros" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72398e694b9f6dbb5de960cf158c8699e6a1854cb5bbaac7de0646b2005763c4" +dependencies = [ + "darling", + "proc-macro2", + "quote", + "serde_json", + "syn 2.0.99", +] + [[package]] name = "roaring" version = "0.10.9" @@ -7399,6 +7441,7 @@ version = "0.8.22" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3fbf2ae1b8bc8e02df939598064d22402220cd5bbcca1c76f7d6a310974d5615" dependencies = [ + "chrono", "dyn-clone", "schemars_derive", "serde", diff --git a/Cargo.toml b/Cargo.toml index da1c0cd2b298..cad9d62a672a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -4,7 +4,7 @@ resolver = "2" [workspace.package] edition = "2021" -version = "1.0.35" +version = "1.1.0" authors = ["Block "] license = "Apache-2.0" repository = "https://github.com/block/goose" @@ -13,6 +13,9 @@ description = "An AI agent" [workspace.lints.clippy] uninlined_format_args = "allow" +[workspace.dependencies] +rmcp = { version = "0.2.1", features = ["schemars"] } + # Patch for Windows cross-compilation issue with crunchy [patch.crates-io] crunchy = { git = "https://github.com/nmathewson/crunchy", branch = "cross-compilation-fix" } diff --git a/Justfile b/Justfile index 9641cf285d54..dbe99a53b350 100644 --- a/Justfile +++ b/Justfile @@ -147,6 +147,13 @@ run-ui-only: @echo "Running UI..." cd ui/desktop && npm install && npm run start-gui +debug-ui: + @echo "🚀 Starting Goose frontend in external backend mode" + cd ui/desktop && \ + export GOOSE_EXTERNAL_BACKEND=true && \ + export GOOSE_EXTERNAL_PORT=3000 && \ + npm install && \ + npm run start-gui # Run UI with alpha changes run-ui-alpha temporal="true": @@ -177,6 +184,13 @@ run-server: @echo "Running server..." cargo run -p goose-server +# Generate OpenAPI specification without starting the UI +generate-openapi: + @echo "Generating OpenAPI schema..." + cargo run -p goose-server --bin generate_schema + @echo "Generating frontend API..." + cd ui/desktop && npm run generate-api + # make GUI with latest binary lint-ui: cd ui/desktop && npm run lint:check @@ -274,12 +288,11 @@ install-deps: cd ui/desktop && npm install cd documentation && yarn -# ensure the current branch is "main" or error -ensure-main: +ensure-release-branch: #!/usr/bin/env bash branch=$(git rev-parse --abbrev-ref HEAD); \ - if [ "$branch" != "main" ]; then \ - echo "Error: You are not on the main branch (current: $branch)"; \ + if [[ ! "$branch" == release/* ]]; then \ + echo "Error: You are not on a release branch (current: $branch)"; \ exit 1; \ fi @@ -287,7 +300,7 @@ ensure-main: git fetch # @{u} refers to upstream branch of current branch if [ "$(git rev-parse HEAD)" != "$(git rev-parse @{u})" ]; then \ - echo "Error: Your branch is not up to date with the upstream main branch"; \ + echo "Error: Your branch is not up to date with the upstream branch"; \ echo " ensure your branch is up to date (git pull)"; \ exit 1; \ fi @@ -309,7 +322,7 @@ validate version: fi # set cargo and app versions, must be semver -release version: ensure-main +prepare-release version: @just validate {{ version }} || exit 1 @git switch -c "release/{{ version }}" @@ -327,8 +340,8 @@ release version: ensure-main get-tag-version: @uvx --from=toml-cli toml get --toml-path=Cargo.toml "workspace.package.version" -# create the git tag from Cargo.toml, must be on main -tag: ensure-main +# create the git tag from Cargo.toml, checking we're on a release branch +tag: ensure-release-branch git tag v$(just get-tag-version) # create tag and push to origin (use this when release branch is merged to main) @@ -337,9 +350,9 @@ tag-push: tag git push origin tag v$(just get-tag-version) # generate release notes from git commits -release-notes: +release-notes old: #!/usr/bin/env bash - git log --pretty=format:"- %s" v$(just get-tag-version)..HEAD + git log --pretty=format:"- %s" {{ old }}..v$(just get-tag-version) ### s = file seperator based on OS s := if os() == "windows" { "\\" } else { "/" } @@ -447,3 +460,4 @@ kotlin-example: -Djna.library.path=$HOME/Development/goose/target/debug \ -classpath "example.jar:libs/kotlin-stdlib-1.9.0.jar:libs/kotlinx-coroutines-core-jvm-1.7.3.jar:libs/jna-5.13.0.jar" \ UsageKt + diff --git a/RELEASE.md b/RELEASE.md new file mode 100644 index 000000000000..0911c672cda0 --- /dev/null +++ b/RELEASE.md @@ -0,0 +1,27 @@ +# Making a Release + +You'll generally create one of two release types: a regular feature release (minor version bump) or a bug-fixing patch release (patch version bump). Regular releases start on main, while patch releases start with an existing release tag. + +## Regular release from main + +1. Check out the main branch. +2. Pick the new version. Use a new minor version (e.g. if the current latest release is 1.2.3, use 1.3.0). Save it using `export VERSION=` +3. Run `just prepare-release $VERSION`. This will create a branch `release/`. Push this branch and open a PR into main. The diff should show version updates to Cargo.toml/package.json and their lock files. +4. Test this build. When ready to make the release, proceed to the next step. +5. Tag the release: run `just tag-push` to create the tag and push it. This will start the build process for your new release. +6. Merge the PR you created in step 2. +7. Once the release is created on [Github](https://github.com/block/goose/releases), run `just release-notes ` to generate release notes. Copy these into the release description. + +## Patch release + +Follow the above steps, but rather than starting on main, start on the release tag you're interested in patching. Increment the patch version instead of minor (e.g. 1.2.3 -> 1.2.4). Bug fixes should be merged to main and then cherry-picked onto this branch. + +1. Before proceeding, make sure any fixes you're looking to include in a patch are merged into main, if possible. +2. Check out the release you're patching using the tag (e.g `git checkout v1.3.0`). Set the version by incrementing the patch version (`export VERSION=1.3.1`). +3. Run `just prepare-release $VERSION`. +4. Cherry-pick the relevant fixes from the main branch. +5. Test this build. When ready to make the release, proceed to the next step. +6. Tag the release: run `just tag-push` to create the tag and push it. This will start the build process for your new release. +7. Once the release is created on [Github](https://github.com/block/goose/releases), run `just release-notes ` to generate release notes. Copy these into the release description. + +Note that you won't merge this branch into main. diff --git a/bindings/kotlin/example/RuntimeStats.kt b/bindings/kotlin/example/RuntimeStats.kt new file mode 100644 index 000000000000..688d382fb9c6 --- /dev/null +++ b/bindings/kotlin/example/RuntimeStats.kt @@ -0,0 +1,115 @@ +import kotlin.system.measureNanoTime +import kotlinx.coroutines.runBlocking +import uniffi.goose_llm.* + +import java.net.URI +import java.net.http.HttpClient +import java.net.http.HttpRequest +import java.net.http.HttpResponse + +/* ---------- Goose helpers ---------- */ + +fun buildProviderConfig(host: String, token: String): String = + """{ "host": "$host", "token": "$token" }""" + +suspend fun timeGooseCall( + modelCfg: ModelConfig, + providerName: String, + providerCfg: String +): Pair { + + val req = createCompletionRequest( + providerName, + providerCfg, + modelCfg, + systemPreamble = "You are a helpful assistant.", + messages = listOf( + Message( + Role.USER, + System.currentTimeMillis() / 1000, + listOf(MessageContent.Text(TextContent("Write me a 1000 word chapter about learning Go vs Rust in the world of LLMs and AI."))) + ) + ), + extensions = emptyList() + ) + + lateinit var resp: CompletionResponse + val wallMs = measureNanoTime { resp = completion(req) } / 1_000_000.0 + return wallMs to resp +} + +/* ---------- OpenAI helpers ---------- */ + +fun timeOpenAiCall(client: HttpClient, apiKey: String): Double { + val body = """ + { + "model": "gpt-4.1", + "max_tokens": 500, + "messages": [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "Write me a 1000 word chapter about learning Go vs Rust in the world of LLMs and AI."} + ] + } + """.trimIndent() + + val request = HttpRequest.newBuilder() + .uri(URI.create("https://api.openai.com/v1/chat/completions")) + .header("Authorization", "Bearer $apiKey") + .header("Content-Type", "application/json") + .POST(HttpRequest.BodyPublishers.ofString(body)) + .build() + + val wallMs = measureNanoTime { + client.send(request, HttpResponse.BodyHandlers.ofString()) + } / 1_000_000.0 + + return wallMs +} + +/* ---------- main ---------- */ + +fun main() = runBlocking { + /* Goose provider setup */ + val providerName = "databricks" + val host = System.getenv("DATABRICKS_HOST") ?: error("DATABRICKS_HOST not set") + val token = System.getenv("DATABRICKS_TOKEN") ?: error("DATABRICKS_TOKEN not set") + val providerCfg = buildProviderConfig(host, token) + + /* OpenAI setup */ + val openAiKey = System.getenv("OPENAI_API_KEY") ?: error("OPENAI_API_KEY not set") + val httpClient = HttpClient.newBuilder().build() + + val gooseModels = listOf("goose-claude-4-sonnet", "goose-gpt-4-1") + val runsPerModel = 3 + + /* --- Goose timing --- */ + for (model in gooseModels) { + val maxTokens = 500 + val cfg = ModelConfig(model, 100_000u, 0.0f, maxTokens) + var wallSum = 0.0 + var gooseSum = 0.0 + + println("=== Goose: $model ===") + repeat(runsPerModel) { run -> + val (wall, resp) = timeGooseCall(cfg, providerName, providerCfg) + val gooseMs = resp.runtimeMetrics.totalTimeSec * 1_000 + val overhead = wall - gooseMs + wallSum += wall + gooseSum += gooseMs + println("run ${run + 1}: wall = %.1f ms | goose-llm = %.1f ms | overhead = %.1f ms" + .format(wall, gooseMs, overhead)) + } + println("-- avg wall = %.1f ms | avg overhead = %.1f ms --\n" + .format(wallSum / runsPerModel, (wallSum - gooseSum) / runsPerModel)) + } + + /* --- OpenAI direct timing --- */ + var oaSum = 0.0 + println("=== OpenAI: gpt-4.1 (direct HTTPS) ===") + repeat(runsPerModel) { run -> + val wall = timeOpenAiCall(httpClient, openAiKey) + oaSum += wall + println("run ${run + 1}: wall = %.1f ms".format(wall)) + } + println("-- avg wall = %.1f ms --".format(oaSum / runsPerModel)) +} diff --git a/bindings/kotlin/example/Usage.kt b/bindings/kotlin/example/Usage.kt index cdb06c8211db..90ee002d9e99 100644 --- a/bindings/kotlin/example/Usage.kt +++ b/bindings/kotlin/example/Usage.kt @@ -1,292 +1,228 @@ +import java.io.File +import java.util.Base64 import kotlinx.coroutines.runBlocking import uniffi.goose_llm.* -fun main() = runBlocking { - val now = System.currentTimeMillis() / 1000 - val msgs = listOf( - // 1) User sends a plain-text prompt - Message( - role = Role.USER, - created = now, - content = listOf( - MessageContent.Text( - TextContent("What is 7 x 6?") - ) - ) - ), - - // 2) Assistant makes a tool request (ToolReq) to calculate 7×6 - Message( - role = Role.ASSISTANT, - created = now + 2, - content = listOf( - MessageContent.ToolReq( - ToolRequest( - id = "calc1", - toolCall = """ - { - "status": "success", - "value": { - "name": "calculator_extension__toolname", - "arguments": { - "operation": "doesnotexist", - "numbers": [7, 6] - }, - "needsApproval": false - } - } - """.trimIndent() - ) - ) - ) - ), - - // 3) User (on behalf of the tool) responds with the tool result (ToolResp) - Message( - role = Role.USER, - created = now + 3, - content = listOf( - MessageContent.ToolResp( - ToolResponse( - id = "calc1", - toolResult = """ - { - "status": "error", - "error": "Invalid value for operation: 'doesnotexist'. Valid values are: ['add', 'subtract', 'multiply', 'divide']" - } - """.trimIndent() - ) - ) - ) - ), - - // 4) Assistant makes a tool request (ToolReq) to calculate 7×6 - Message( - role = Role.ASSISTANT, - created = now + 4, - content = listOf( - MessageContent.ToolReq( - ToolRequest( - id = "calc1", - toolCall = """ - { - "status": "success", - "value": { - "name": "calculator_extension__toolname", - "arguments": { - "operation": "multiply", - "numbers": [7, 6] - }, - "needsApproval": false - } - } - """.trimIndent() - ) - ) - ) - ), - - // 5) User (on behalf of the tool) responds with the tool result (ToolResp) - Message( - role = Role.USER, - created = now + 5, - content = listOf( - MessageContent.ToolResp( - ToolResponse( - id = "calc1", - toolResult = """ - { - "status": "success", - "value": [ - {"type": "text", "text": "42"} - ] - } - """.trimIndent() - ) - ) - ) - ), - ) - - printMessages(msgs) - println("---\n") - - // Setup provider - val providerName = "databricks" - val host = System.getenv("DATABRICKS_HOST") ?: error("DATABRICKS_HOST not set") - val token = System.getenv("DATABRICKS_TOKEN") ?: error("DATABRICKS_TOKEN not set") - val providerConfig = """{"host": "$host", "token": "$token"}""" - - println("Provider Name: $providerName") - println("Provider Config: $providerConfig") +/* ---------- shared helpers ---------- */ +fun buildProviderConfig(host: String, token: String, imageFormat: String = "OpenAi"): String = """ +{ + "host": "$host", + "token": "$token", + "image_format": "$imageFormat" +} +""".trimIndent() - val sessionName = generateSessionName(providerName, providerConfig, msgs) - println("\nSession Name: $sessionName") - - val tooltip = generateTooltip(providerName, providerConfig, msgs) - println("\nTooltip: $tooltip") - - // Completion - val modelName = "goose-gpt-4-1" - val modelConfig = ModelConfig( - modelName, - 100000u, // UInt - 0.1f, // Float - 200 // Int - ) - +fun calculatorExtension(): ExtensionConfig { val calculatorTool = createToolConfig( - name = "calculator", + name = "calculator", description = "Perform basic arithmetic operations", inputSchema = """ { - "type": "object", - "required": ["operation", "numbers"], - "properties": { - "operation": { - "type": "string", - "enum": ["add", "subtract", "multiply", "divide"], - "description": "The arithmetic operation to perform" - }, - "numbers": { - "type": "array", - "items": { "type": "number" }, - "description": "List of numbers to operate on in order" - } + "type": "object", + "required": ["operation", "numbers"], + "properties": { + "operation": { + "type": "string", + "enum": ["add", "subtract", "multiply", "divide"], + "description": "The arithmetic operation to perform" + }, + "numbers": { + "type": "array", + "items": { "type": "number" }, + "description": "List of numbers to operate on in order" } + } } """.trimIndent(), approvalMode = ToolApprovalMode.AUTO ) - - val calculator_extension = ExtensionConfig( - name = "calculator_extension", + return ExtensionConfig( + name = "calculator_extension", instructions = "This extension provides a calculator tool.", - tools = listOf(calculatorTool) + tools = listOf(calculatorTool) ) +} - val extensions = listOf(calculator_extension) - val systemPreamble = "You are a helpful assistant." +/* ---------- demos ---------- */ - // Testing with tool calls with an error in tool name - val reqToolErr = createCompletionRequest( - providerName, - providerConfig, - modelConfig, - systemPreamble, - messages = listOf( - Message( - role = Role.USER, - created = now, - content = listOf( - MessageContent.Text( - TextContent("What is 7 x 6?") - ) - ) - )), - extensions = extensions +suspend fun runCalculatorDemo( + modelConfig: ModelConfig, + providerName: String, + providerConfig: String +) { + val now = System.currentTimeMillis() / 1000 + val msgs = listOf( + // same conversation you already had + Message(Role.USER, now, listOf(MessageContent.Text(TextContent("What is 7 x 6?")))), + Message(Role.ASSISTANT, now + 2, listOf(MessageContent.ToolReq( + ToolRequest( + id = "calc1", + toolCall = """ + { + "status": "success", + "value": { + "name": "calculator_extension__toolname", + "arguments": { "operation": "doesnotexist", "numbers": [7,6] }, + "needsApproval": false + } + } + """.trimIndent() + )))), + Message(Role.USER, now + 3, listOf(MessageContent.ToolResp( + ToolResponse( + id = "calc1", + toolResult = """ + { + "status": "error", + "error": "Invalid value for operation: 'doesnotexist'. Valid values are: ['add','subtract','multiply','divide']" + } + """.trimIndent() + )))), + Message(Role.ASSISTANT, now + 4, listOf(MessageContent.ToolReq( + ToolRequest( + id = "calc1", + toolCall = """ + { + "status": "success", + "value": { + "name": "calculator_extension__toolname", + "arguments": { "operation": "multiply", "numbers": [7,6] }, + "needsApproval": false + } + } + """.trimIndent() + )))), + Message(Role.USER, now + 5, listOf(MessageContent.ToolResp( + ToolResponse( + id = "calc1", + toolResult = """ + { + "status": "success", + "value": [ { "type": "text", "text": "42" } ] + } + """.trimIndent() + )))) ) - val respToolErr = completion(reqToolErr) - println("\nCompletion Response (one msg):\n${respToolErr.message}") - println() + /* one-shot prompt with error */ + val reqErr = createCompletionRequest( + providerName, providerConfig, modelConfig, + "You are a helpful assistant.", + messages = listOf(msgs.first()), + extensions = listOf(calculatorExtension()) + ) + println("\n[${modelConfig.modelName}] Calculator (single-msg) → ${completion(reqErr).message}") + /* full conversation */ val reqAll = createCompletionRequest( - providerName, - providerConfig, - modelConfig, - systemPreamble, - messages = msgs, - extensions = extensions + providerName, providerConfig, modelConfig, + "You are a helpful assistant.", + messages = msgs, + extensions = listOf(calculatorExtension()) ) + println("[${modelConfig.modelName}] Calculator (full chat) → ${completion(reqAll).message}") +} - val respAll = completion(reqAll) - println("\nCompletion Response (all msgs):\n${respAll.message}") - println() - - // ---- UI Extraction (custom schema) ---- - runUiExtraction(providerName, providerConfig) +suspend fun runImageExample( + modelConfig: ModelConfig, + providerName: String, + providerConfig: String +) { + val imagePath = "../../crates/goose/examples/test_assets/test_image.png" + val base64Image = Base64.getEncoder().encodeToString(File(imagePath).readBytes()) + val now = System.currentTimeMillis() / 1000 - // --- Prompt Override --- - val prompt_req = createCompletionRequest( - providerName, - providerConfig, - modelConfig, - systemPreamble = null, - systemPromptOverride = "You are a bot named Tile Creator. Your task is to create a tile based on the user's input.", - messages=listOf( - Message( - role = Role.USER, - created = now, - content = listOf( - MessageContent.Text( - TextContent("What's your name?") - ) - ) - ) - ), - extensions=emptyList() + val msgs = listOf( + Message(Role.USER, now, listOf( + MessageContent.Text(TextContent("What is in this image?")), + MessageContent.Image(ImageContent(base64Image, "image/png")) + )), ) - val prompt_resp = completion(prompt_req) + val req = createCompletionRequest( + providerName, providerConfig, modelConfig, + "You are a helpful assistant. Please describe any text you see in the image.", + messages = msgs, + extensions = emptyList() + ) - println("\nPrompt Override Response:\n${prompt_resp.message}") + println("\n[${modelConfig.modelName}] Image example → ${completion(req).message}") } +suspend fun runPromptOverride( + modelConfig: ModelConfig, + providerName: String, + providerConfig: String +) { + val now = System.currentTimeMillis() / 1000 + val req = createCompletionRequest( + providerName, providerConfig, modelConfig, + systemPreamble = null, + systemPromptOverride = "You are a bot named Tile Creator. Your task is to create a tile based on the user's input.", + messages = listOf( + Message(Role.USER, now, listOf(MessageContent.Text(TextContent("What's your name?")))) + ), + extensions = emptyList() + ) + println("\n[${modelConfig.modelName}] Prompt override → ${completion(req).message}") +} suspend fun runUiExtraction(providerName: String, providerConfig: String) { - val systemPrompt = "You are a UI generator AI. Convert the user input into a JSON-driven UI." + val schema = /* same JSON schema as before */ """ + { + "type":"object", + "properties":{ + "type":{"type":"string","enum":["div","button","header","section","field","form"]}, + "label":{"type":"string"}, + "children":{"type":"array","items":{"${'$'}ref":"#"}}, + "attributes":{"type":"array","items":{"type":"object","properties":{"name":{"type":"string"},"value":{"type":"string"}},"required":["name","value"],"additionalProperties":false}} + }, + "required":["type","label","children","attributes"], + "additionalProperties":false + } + """.trimIndent() + val messages = listOf( - Message( - role = Role.USER, - created = System.currentTimeMillis() / 1000, - content = listOf( - MessageContent.Text( - TextContent("Make a User Profile Form") - ) - ) - ) + Message(Role.USER, System.currentTimeMillis()/1000, + listOf(MessageContent.Text(TextContent("Make a User Profile Form")))) ) - val schema = """{ - "type": "object", - "properties": { - "type": { - "type": "string", - "enum": ["div","button","header","section","field","form"] - }, - "label": { "type": "string" }, - "children": { - "type": "array", - "items": { "${'$'}ref": "#" } - }, - "attributes": { - "type": "array", - "items": { - "type": "object", - "properties": { - "name": { "type": "string" }, - "value": { "type": "string" } - }, - "required": ["name","value"], - "additionalProperties": false - } - } - }, - "required": ["type","label","children","attributes"], - "additionalProperties": false - }""".trimIndent(); - try { - val response = generateStructuredOutputs( - providerName = providerName, - providerConfig = providerConfig, - systemPrompt = systemPrompt, - messages = messages, - schema = schema - ) - println("\nUI Extraction Output:\n${response}") - } catch (e: ProviderException) { - println("\nUI Extraction failed:\n${e.message}") + val res = generateStructuredOutputs( + providerName, providerConfig, + systemPrompt = "You are a UI generator AI. Convert the user input into a JSON-driven UI.", + messages = messages, + schema = schema + ) + println("\n[UI-Extraction] → $res") +} + +/* ---------- entry-point ---------- */ + +fun main() = runBlocking { + /* --- provider setup --- */ + val providerName = "databricks" + val host = System.getenv("DATABRICKS_HOST") ?: error("DATABRICKS_HOST not set") + val token = System.getenv("DATABRICKS_TOKEN") ?: error("DATABRICKS_TOKEN not set") + val providerConfig = buildProviderConfig(host, token) + + println("Provider: $providerName") + println("Config : $providerConfig\n") + + /* --- run demos for each model --- */ + // NOTE: `claude-3-5-haiku` does NOT support images + val modelNames = listOf("kgoose-gpt-4o", "goose-claude-4-sonnet") + + for (name in modelNames) { + val modelConfig = ModelConfig(name, 100000u, 0.1f, 200) + println("\n===== Running demos for model: $name =====") + + runCalculatorDemo(modelConfig, providerName, providerConfig) + runImageExample(modelConfig, providerName, providerConfig) + runPromptOverride(modelConfig, providerName, providerConfig) + println("===== End demos for $name =====\n") } + + /* UI extraction is model-agnostic, so run it once */ + runUiExtraction(providerName, providerConfig) } diff --git a/bindings/kotlin/uniffi/goose_llm/goose_llm.kt b/bindings/kotlin/uniffi/goose_llm/goose_llm.kt index 76e60aaf7441..f01956947261 100644 --- a/bindings/kotlin/uniffi/goose_llm/goose_llm.kt +++ b/bindings/kotlin/uniffi/goose_llm/goose_llm.kt @@ -833,6 +833,7 @@ internal interface UniffiLib : Library { `systemPromptOverride`: RustBuffer.ByValue, `messages`: RustBuffer.ByValue, `extensions`: RustBuffer.ByValue, + `requestId`: RustBuffer.ByValue, uniffi_out_err: UniffiRustCallStatus, ): RustBuffer.ByValue @@ -848,6 +849,7 @@ internal interface UniffiLib : Library { `providerName`: RustBuffer.ByValue, `providerConfig`: RustBuffer.ByValue, `messages`: RustBuffer.ByValue, + `requestId`: RustBuffer.ByValue, ): Long fun uniffi_goose_llm_fn_func_generate_structured_outputs( @@ -856,12 +858,14 @@ internal interface UniffiLib : Library { `systemPrompt`: RustBuffer.ByValue, `messages`: RustBuffer.ByValue, `schema`: RustBuffer.ByValue, + `requestId`: RustBuffer.ByValue, ): Long fun uniffi_goose_llm_fn_func_generate_tooltip( `providerName`: RustBuffer.ByValue, `providerConfig`: RustBuffer.ByValue, `messages`: RustBuffer.ByValue, + `requestId`: RustBuffer.ByValue, ): Long fun uniffi_goose_llm_fn_func_print_messages( @@ -1101,19 +1105,19 @@ private fun uniffiCheckApiChecksums(lib: IntegrityCheckingUniffiLib) { if (lib.uniffi_goose_llm_checksum_func_completion() != 47457.toShort()) { throw RuntimeException("UniFFI API checksum mismatch: try cleaning and rebuilding your project") } - if (lib.uniffi_goose_llm_checksum_func_create_completion_request() != 50798.toShort()) { + if (lib.uniffi_goose_llm_checksum_func_create_completion_request() != 15391.toShort()) { throw RuntimeException("UniFFI API checksum mismatch: try cleaning and rebuilding your project") } if (lib.uniffi_goose_llm_checksum_func_create_tool_config() != 49910.toShort()) { throw RuntimeException("UniFFI API checksum mismatch: try cleaning and rebuilding your project") } - if (lib.uniffi_goose_llm_checksum_func_generate_session_name() != 64087.toShort()) { + if (lib.uniffi_goose_llm_checksum_func_generate_session_name() != 34350.toShort()) { throw RuntimeException("UniFFI API checksum mismatch: try cleaning and rebuilding your project") } - if (lib.uniffi_goose_llm_checksum_func_generate_structured_outputs() != 43426.toShort()) { + if (lib.uniffi_goose_llm_checksum_func_generate_structured_outputs() != 4576.toShort()) { throw RuntimeException("UniFFI API checksum mismatch: try cleaning and rebuilding your project") } - if (lib.uniffi_goose_llm_checksum_func_generate_tooltip() != 41121.toShort()) { + if (lib.uniffi_goose_llm_checksum_func_generate_tooltip() != 36439.toShort()) { throw RuntimeException("UniFFI API checksum mismatch: try cleaning and rebuilding your project") } if (lib.uniffi_goose_llm_checksum_func_print_messages() != 30278.toShort()) { @@ -2960,6 +2964,7 @@ fun `createCompletionRequest`( `systemPromptOverride`: kotlin.String? = null, `messages`: List, `extensions`: List, + `requestId`: kotlin.String? = null, ): CompletionRequest = FfiConverterTypeCompletionRequest.lift( uniffiRustCall { _status -> @@ -2971,6 +2976,7 @@ fun `createCompletionRequest`( FfiConverterOptionalString.lower(`systemPromptOverride`), FfiConverterSequenceTypeMessage.lower(`messages`), FfiConverterSequenceTypeExtensionConfig.lower(`extensions`), + FfiConverterOptionalString.lower(`requestId`), _status, ) }, @@ -3003,12 +3009,14 @@ suspend fun `generateSessionName`( `providerName`: kotlin.String, `providerConfig`: Value, `messages`: List, + `requestId`: kotlin.String? = null, ): kotlin.String = uniffiRustCallAsync( UniffiLib.INSTANCE.uniffi_goose_llm_fn_func_generate_session_name( FfiConverterString.lower(`providerName`), FfiConverterTypeValue.lower(`providerConfig`), FfiConverterSequenceTypeMessage.lower(`messages`), + FfiConverterOptionalString.lower(`requestId`), ), { future, callback, continuation -> UniffiLib.INSTANCE.ffi_goose_llm_rust_future_poll_rust_buffer(future, callback, continuation) }, { future, continuation -> UniffiLib.INSTANCE.ffi_goose_llm_rust_future_complete_rust_buffer(future, continuation) }, @@ -3031,6 +3039,7 @@ suspend fun `generateStructuredOutputs`( `systemPrompt`: kotlin.String, `messages`: List, `schema`: Value, + `requestId`: kotlin.String? = null, ): ProviderExtractResponse = uniffiRustCallAsync( UniffiLib.INSTANCE.uniffi_goose_llm_fn_func_generate_structured_outputs( @@ -3039,6 +3048,7 @@ suspend fun `generateStructuredOutputs`( FfiConverterString.lower(`systemPrompt`), FfiConverterSequenceTypeMessage.lower(`messages`), FfiConverterTypeValue.lower(`schema`), + FfiConverterOptionalString.lower(`requestId`), ), { future, callback, continuation -> UniffiLib.INSTANCE.ffi_goose_llm_rust_future_poll_rust_buffer(future, callback, continuation) }, { future, continuation -> UniffiLib.INSTANCE.ffi_goose_llm_rust_future_complete_rust_buffer(future, continuation) }, @@ -3059,12 +3069,14 @@ suspend fun `generateTooltip`( `providerName`: kotlin.String, `providerConfig`: Value, `messages`: List, + `requestId`: kotlin.String? = null, ): kotlin.String = uniffiRustCallAsync( UniffiLib.INSTANCE.uniffi_goose_llm_fn_func_generate_tooltip( FfiConverterString.lower(`providerName`), FfiConverterTypeValue.lower(`providerConfig`), FfiConverterSequenceTypeMessage.lower(`messages`), + FfiConverterOptionalString.lower(`requestId`), ), { future, callback, continuation -> UniffiLib.INSTANCE.ffi_goose_llm_rust_future_poll_rust_buffer(future, callback, continuation) }, { future, continuation -> UniffiLib.INSTANCE.ffi_goose_llm_rust_future_complete_rust_buffer(future, continuation) }, diff --git a/crates/goose-bench/Cargo.toml b/crates/goose-bench/Cargo.toml index ba436de56e60..05bd50fb12e6 100644 --- a/crates/goose-bench/Cargo.toml +++ b/crates/goose-bench/Cargo.toml @@ -10,13 +10,13 @@ description.workspace = true [lints] workspace = true - [dependencies] anyhow = "1.0" paste = "1.0" ctor = "0.2.7" goose = { path = "../goose" } mcp-core = { path = "../mcp-core" } +rmcp = { workspace = true } async-trait = "0.1.86" chrono = { version = "0.4", features = ["serde"] } serde_json = "1.0" diff --git a/crates/goose-bench/src/eval_suites/core/computercontroller/script.rs b/crates/goose-bench/src/eval_suites/core/computercontroller/script.rs index 4b92805fbde4..b42243dfd2d7 100644 --- a/crates/goose-bench/src/eval_suites/core/computercontroller/script.rs +++ b/crates/goose-bench/src/eval_suites/core/computercontroller/script.rs @@ -9,7 +9,7 @@ use crate::eval_suites::{ use crate::register_evaluation; use async_trait::async_trait; use goose::message::MessageContent; -use mcp_core::role::Role; +use rmcp::model::Role; use serde_json::{self, Value}; #[derive(Debug)] diff --git a/crates/goose-bench/src/eval_suites/core/computercontroller/web_scrape.rs b/crates/goose-bench/src/eval_suites/core/computercontroller/web_scrape.rs index 60b3c205200e..04fba60f7586 100644 --- a/crates/goose-bench/src/eval_suites/core/computercontroller/web_scrape.rs +++ b/crates/goose-bench/src/eval_suites/core/computercontroller/web_scrape.rs @@ -9,7 +9,7 @@ use crate::eval_suites::{ use crate::register_evaluation; use async_trait::async_trait; use goose::message::MessageContent; -use mcp_core::role::Role; +use rmcp::model::Role; use serde_json::{self, Value}; #[derive(Debug)] diff --git a/crates/goose-bench/src/eval_suites/core/developer/create_file.rs b/crates/goose-bench/src/eval_suites/core/developer/create_file.rs index 5cfe895d0d66..154319c38316 100644 --- a/crates/goose-bench/src/eval_suites/core/developer/create_file.rs +++ b/crates/goose-bench/src/eval_suites/core/developer/create_file.rs @@ -9,7 +9,7 @@ use crate::eval_suites::{ use crate::register_evaluation; use async_trait::async_trait; use goose::message::MessageContent; -use mcp_core::role::Role; +use rmcp::model::Role; use serde_json::{self, Value}; #[derive(Debug)] diff --git a/crates/goose-bench/src/eval_suites/core/developer/list_files.rs b/crates/goose-bench/src/eval_suites/core/developer/list_files.rs index 8881b690d4b3..8aea32cc5b14 100644 --- a/crates/goose-bench/src/eval_suites/core/developer/list_files.rs +++ b/crates/goose-bench/src/eval_suites/core/developer/list_files.rs @@ -7,7 +7,7 @@ use crate::eval_suites::{ use crate::register_evaluation; use async_trait::async_trait; use goose::message::MessageContent; -use mcp_core::role::Role; +use rmcp::model::Role; use serde_json::{self, Value}; #[derive(Debug)] diff --git a/crates/goose-bench/src/eval_suites/core/developer/simple_repo_clone_test.rs b/crates/goose-bench/src/eval_suites/core/developer/simple_repo_clone_test.rs index dd280fbf380d..ffa8541e6191 100644 --- a/crates/goose-bench/src/eval_suites/core/developer/simple_repo_clone_test.rs +++ b/crates/goose-bench/src/eval_suites/core/developer/simple_repo_clone_test.rs @@ -7,7 +7,7 @@ use crate::eval_suites::{ use crate::register_evaluation; use async_trait::async_trait; use goose::message::MessageContent; -use mcp_core::role::Role; +use rmcp::model::Role; use serde_json::{self, Value}; #[derive(Debug)] diff --git a/crates/goose-bench/src/eval_suites/core/developer_image/image.rs b/crates/goose-bench/src/eval_suites/core/developer_image/image.rs index 34e43496cd51..2ac8a8ce88f8 100644 --- a/crates/goose-bench/src/eval_suites/core/developer_image/image.rs +++ b/crates/goose-bench/src/eval_suites/core/developer_image/image.rs @@ -7,8 +7,7 @@ use crate::eval_suites::{ use crate::register_evaluation; use async_trait::async_trait; use goose::message::MessageContent; -use mcp_core::content::Content; -use mcp_core::role::Role; +use rmcp::model::Role; use serde_json::{self, Value}; #[derive(Debug)] @@ -68,7 +67,7 @@ impl Evaluation for DeveloperImage { if let Ok(result) = &tool_resp.tool_result { // Check each item in the result list for item in result { - if let Content::Image(image) = item { + if let Some(image) = item.as_image() { // Image content already contains mime_type and data if image.mime_type.starts_with("image/") && !image.data.is_empty() diff --git a/crates/goose-bench/src/eval_suites/core/memory/save_fact.rs b/crates/goose-bench/src/eval_suites/core/memory/save_fact.rs index 91bebeea228c..4e3184e42af6 100644 --- a/crates/goose-bench/src/eval_suites/core/memory/save_fact.rs +++ b/crates/goose-bench/src/eval_suites/core/memory/save_fact.rs @@ -9,7 +9,7 @@ use crate::eval_suites::{ use crate::register_evaluation; use async_trait::async_trait; use goose::message::MessageContent; -use mcp_core::role::Role; +use rmcp::model::Role; use serde_json::{self, Value}; #[derive(Debug)] diff --git a/crates/goose-bench/src/eval_suites/vibes/flappy_bird.rs b/crates/goose-bench/src/eval_suites/vibes/flappy_bird.rs index 9cbca88a022f..edd2f4a52424 100644 --- a/crates/goose-bench/src/eval_suites/vibes/flappy_bird.rs +++ b/crates/goose-bench/src/eval_suites/vibes/flappy_bird.rs @@ -7,7 +7,7 @@ use crate::eval_suites::{ use crate::register_evaluation; use async_trait::async_trait; use goose::message::MessageContent; -use mcp_core::role::Role; +use rmcp::model::Role; use serde_json::{self, Value}; use std::fs; diff --git a/crates/goose-bench/src/eval_suites/vibes/goose_wiki.rs b/crates/goose-bench/src/eval_suites/vibes/goose_wiki.rs index 480989e9d0ba..2609584cf890 100644 --- a/crates/goose-bench/src/eval_suites/vibes/goose_wiki.rs +++ b/crates/goose-bench/src/eval_suites/vibes/goose_wiki.rs @@ -7,7 +7,7 @@ use crate::eval_suites::{ use crate::register_evaluation; use async_trait::async_trait; use goose::message::MessageContent; -use mcp_core::role::Role; +use rmcp::model::Role; use serde_json::{self, Value}; use std::fs; diff --git a/crates/goose-bench/src/eval_suites/vibes/squirrel_census.rs b/crates/goose-bench/src/eval_suites/vibes/squirrel_census.rs index 84062b721dce..fd628a6544bb 100644 --- a/crates/goose-bench/src/eval_suites/vibes/squirrel_census.rs +++ b/crates/goose-bench/src/eval_suites/vibes/squirrel_census.rs @@ -7,7 +7,7 @@ use crate::eval_suites::{ use crate::register_evaluation; use async_trait::async_trait; use goose::message::MessageContent; -use mcp_core::role::Role; +use rmcp::model::Role; use serde_json::{self, Value}; pub struct SquirrelCensus {} diff --git a/crates/goose-cli/Cargo.toml b/crates/goose-cli/Cargo.toml index 0bcfc5506b0c..88aa02668582 100644 --- a/crates/goose-cli/Cargo.toml +++ b/crates/goose-cli/Cargo.toml @@ -21,6 +21,7 @@ goose-mcp = { path = "../goose-mcp" } mcp-client = { path = "../mcp-client" } mcp-server = { path = "../mcp-server" } mcp-core = { path = "../mcp-core" } +rmcp = { workspace = true } clap = { version = "4.4", features = ["derive"] } cliclack = "0.3.5" console = "0.15.8" @@ -70,7 +71,6 @@ http = "1.0" webbrowser = "1.0" indicatif = "0.17.11" -urlencoding = "2" [target.'cfg(target_os = "windows")'.dependencies] winapi = { version = "0.3", features = ["wincred"] } diff --git a/crates/goose-cli/src/cli.rs b/crates/goose-cli/src/cli.rs index 44b18eae2c43..2b16521d8b3c 100644 --- a/crates/goose-cli/src/cli.rs +++ b/crates/goose-cli/src/cli.rs @@ -8,7 +8,7 @@ use crate::commands::configure::handle_configure; use crate::commands::info::handle_info; use crate::commands::mcp::run_server; use crate::commands::project::{handle_project_default, handle_projects_interactive}; -use crate::commands::recipe::{handle_deeplink, handle_validate}; +use crate::commands::recipe::{handle_deeplink, handle_list, handle_validate}; // Import the new handlers from commands::schedule use crate::commands::schedule::{ handle_schedule_add, handle_schedule_cron_help, handle_schedule_list, handle_schedule_remove, @@ -18,7 +18,7 @@ use crate::commands::schedule::{ use crate::commands::session::{handle_session_list, handle_session_remove}; use crate::logging::setup_logging; use crate::recipes::extract_from_cli::extract_recipe_info_from_cli; -use crate::recipes::recipe::{explain_recipe_with_parameters, load_recipe_content_as_template}; +use crate::recipes::recipe::{explain_recipe, render_recipe_as_yaml}; use crate::session; use crate::session::{build_session, SessionBuilderConfig, SessionSettings}; use goose_bench::bench_config::BenchRunConfig; @@ -246,6 +246,27 @@ enum RecipeCommand { )] recipe_name: String, }, + + /// List available recipes + #[command(about = "List available recipes")] + List { + /// Output format (text, json) + #[arg( + long = "format", + value_name = "FORMAT", + help = "Output format (text, json)", + default_value = "text" + )] + format: String, + + /// Show verbose information including recipe descriptions + #[arg( + short, + long, + help = "Show verbose information including recipe descriptions" + )] + verbose: bool, + }, } #[derive(Subcommand)] @@ -441,7 +462,7 @@ enum Command { long = "no-session", help = "Run without storing a session file", long_help = "Execute commands without creating or using a session file. Useful for automated runs.", - conflicts_with_all = ["resume", "name", "path"] + conflicts_with_all = ["resume", "name", "path"] )] no_session: bool, @@ -860,16 +881,14 @@ pub async fn cli() -> Result<()> { ), (_, _, Some(recipe_name)) => { if explain { - explain_recipe_with_parameters(&recipe_name, params)?; + explain_recipe(&recipe_name, params)?; return Ok(()); } if render_recipe { - let recipe = load_recipe_content_as_template(&recipe_name, params) - .unwrap_or_else(|err| { - eprintln!("{}: {}", console::style("Error").red().bold(), err); - std::process::exit(1); - }); - println!("{}", recipe); + if let Err(err) = render_recipe_as_yaml(&recipe_name, params) { + eprintln!("{}: {}", console::style("Error").red().bold(), err); + std::process::exit(1); + } return Ok(()); } extract_recipe_info_from_cli(recipe_name, params, additional_sub_recipes)? @@ -995,6 +1014,9 @@ pub async fn cli() -> Result<()> { RecipeCommand::Deeplink { recipe_name } => { handle_deeplink(&recipe_name)?; } + RecipeCommand::List { format, verbose } => { + handle_list(&format, verbose)?; + } } return Ok(()); } diff --git a/crates/goose-cli/src/commands/mcp.rs b/crates/goose-cli/src/commands/mcp.rs index 8fbd399f06d9..f70bd8c6f2ee 100644 --- a/crates/goose-cli/src/commands/mcp.rs +++ b/crates/goose-cli/src/commands/mcp.rs @@ -1,7 +1,6 @@ use anyhow::Result; use goose_mcp::{ - ComputerControllerRouter, DeveloperRouter, GoogleDriveRouter, JetBrainsRouter, MemoryRouter, - TutorialRouter, + ComputerControllerRouter, DeveloperRouter, GoogleDriveRouter, MemoryRouter, TutorialRouter, }; use mcp_server::router::RouterService; use mcp_server::{BoundedService, ByteTransport, Server}; @@ -26,7 +25,6 @@ pub async fn run_server(name: &str) -> Result<()> { let router: Option> = match name { "developer" => Some(Box::new(RouterService(DeveloperRouter::new()))), "computercontroller" => Some(Box::new(RouterService(ComputerControllerRouter::new()))), - "jetbrains" => Some(Box::new(RouterService(JetBrainsRouter::new()))), "google_drive" | "googledrive" => { let router = GoogleDriveRouter::new().await; Some(Box::new(RouterService(router))) diff --git a/crates/goose-cli/src/commands/project.rs b/crates/goose-cli/src/commands/project.rs index 17e63e412baf..e049e66e323e 100644 --- a/crates/goose-cli/src/commands/project.rs +++ b/crates/goose-cli/src/commands/project.rs @@ -4,7 +4,7 @@ use cliclack::{self, intro, outro}; use std::path::Path; use crate::project_tracker::ProjectTracker; -use crate::utils::safe_truncate; +use goose::utils::safe_truncate; /// Format a DateTime for display fn format_date(date: DateTime) -> String { diff --git a/crates/goose-cli/src/commands/recipe.rs b/crates/goose-cli/src/commands/recipe.rs index 332e693e4c9a..3f1db9e2d982 100644 --- a/crates/goose-cli/src/commands/recipe.rs +++ b/crates/goose-cli/src/commands/recipe.rs @@ -1,8 +1,10 @@ use anyhow::Result; -use base64::Engine; use console::style; -use crate::recipes::recipe::load_recipe; +use crate::recipes::github_recipe::RecipeSource; +use crate::recipes::recipe::load_recipe_for_validation; +use crate::recipes::search_recipe::list_available_recipes; +use goose::recipe_deeplink; /// Validates a recipe file /// @@ -15,7 +17,7 @@ use crate::recipes::recipe::load_recipe; /// Result indicating success or failure pub fn handle_validate(recipe_name: &str) -> Result<()> { // Load and validate the recipe file - match load_recipe(recipe_name) { + match load_recipe_for_validation(recipe_name) { Ok(_) => { println!("{} recipe file is valid", style("✓").green().bold()); Ok(()) @@ -38,22 +40,27 @@ pub fn handle_validate(recipe_name: &str) -> Result<()> { /// Result indicating success or failure pub fn handle_deeplink(recipe_name: &str) -> Result { // Load the recipe file first to validate it - match load_recipe(recipe_name) { - Ok(recipe) => { - let mut full_url = String::new(); - if let Ok(recipe_json) = serde_json::to_string(&recipe) { - let deeplink = base64::engine::general_purpose::STANDARD.encode(recipe_json); + match load_recipe_for_validation(recipe_name) { + Ok(recipe) => match recipe_deeplink::encode(&recipe) { + Ok(encoded) => { println!( "{} Generated deeplink for: {}", style("✓").green().bold(), recipe.title ); - let url_safe = urlencoding::encode(&deeplink); - full_url = format!("goose://recipe?config={}", url_safe); + let full_url = format!("goose://recipe?config={}", encoded); println!("{}", full_url); + Ok(full_url) } - Ok(full_url) - } + Err(err) => { + println!( + "{} Failed to encode recipe: {}", + style("✗").red().bold(), + err + ); + Err(anyhow::anyhow!("Failed to encode recipe: {}", err)) + } + }, Err(err) => { println!("{} {}", style("✗").red().bold(), err); Err(err) @@ -61,6 +68,67 @@ pub fn handle_deeplink(recipe_name: &str) -> Result { } } +/// Lists all available recipes from local paths and GitHub repositories +/// +/// # Arguments +/// +/// * `format` - Output format ("text" or "json") +/// * `verbose` - Whether to show detailed information +/// +/// # Returns +/// +/// Result indicating success or failure +pub fn handle_list(format: &str, verbose: bool) -> Result<()> { + let recipes = match list_available_recipes() { + Ok(recipes) => recipes, + Err(e) => { + return Err(anyhow::anyhow!("Failed to list recipes: {}", e)); + } + }; + + match format { + "json" => { + println!("{}", serde_json::to_string(&recipes)?); + } + _ => { + if recipes.is_empty() { + println!("No recipes found"); + return Ok(()); + } else { + println!("Available recipes:"); + for recipe in recipes { + let source_info = match recipe.source { + RecipeSource::Local => format!("local: {}", recipe.path), + RecipeSource::GitHub => format!("github: {}", recipe.path), + }; + + let description = if let Some(desc) = &recipe.description { + if desc.is_empty() { + "(none)" + } else { + desc + } + } else { + "(none)" + }; + + let output = format!("{} - {} - {}", recipe.name, description, source_info); + if verbose { + println!(" {}", output); + if let Some(title) = &recipe.title { + println!(" Title: {}", title); + } + println!(" Path: {}", recipe.path); + } else { + println!("{}", output); + } + } + } + } + } + Ok(()) +} + #[cfg(test)] mod tests { use super::*; @@ -121,7 +189,10 @@ response: let result = handle_deeplink(&recipe_path); assert!(result.is_ok()); - assert!(result.unwrap().contains("goose://recipe?config=eyJ2ZXJzaW9uIjoiMS4wLjAiLCJ0aXRsZSI6IlRlc3QgUmVjaXBlIHdpdGggVmFsaWQgSlNPTiBTY2hlbWEiLCJkZXNjcmlwdGlvbiI6IkEgdGVzdCByZWNpcGUgd2l0aCB2YWxpZCBKU09OIHNjaGVtYSIsImluc3RydWN0aW9ucyI6IlRlc3QgaW5zdHJ1Y3Rpb25zIiwicHJvbXB0IjoiVGVzdCBwcm9tcHQgY29udGVudCIsInJlc3BvbnNlIjp7Impzb25fc2NoZW1hIjp7InByb3BlcnRpZXMiOnsiY291bnQiOnsiZGVzY3JpcHRpb24iOiJBIGNvdW50IHZhbHVlIiwidHlwZSI6Im51bWJlciJ9LCJyZXN1bHQiOnsiZGVzY3JpcHRpb24iOiJUaGUgcmVzdWx0IiwidHlwZSI6InN0cmluZyJ9fSwicmVxdWlyZWQiOlsicmVzdWx0Il0sInR5cGUiOiJvYmplY3QifX19")); + let url = result.unwrap(); + assert!(url.starts_with("goose://recipe?config=")); + let encoded_part = url.strip_prefix("goose://recipe?config=").unwrap(); + assert!(encoded_part.len() > 0); } #[test] diff --git a/crates/goose-cli/src/commands/session.rs b/crates/goose-cli/src/commands/session.rs index fbb862482ad0..20a0f3d2f378 100644 --- a/crates/goose-cli/src/commands/session.rs +++ b/crates/goose-cli/src/commands/session.rs @@ -1,9 +1,9 @@ use crate::session::message_to_markdown; -use crate::utils::safe_truncate; use anyhow::{Context, Result}; use cliclack::{confirm, multiselect, select}; use goose::session::info::{get_valid_sorted_sessions, SessionInfo, SortOrder}; use goose::session::{self, Identifier}; +use goose::utils::safe_truncate; use regex::Regex; use std::fs; use std::path::{Path, PathBuf}; @@ -241,7 +241,7 @@ fn export_session_to_markdown( for message in &messages { // Check if this is a User message containing only ToolResponses - let is_only_tool_response = message.role == mcp_core::role::Role::User + let is_only_tool_response = message.role == rmcp::model::Role::User && message .content .iter() @@ -263,8 +263,8 @@ fn export_session_to_markdown( // Output the role prefix except for tool response-only messages if !is_only_tool_response { let role_prefix = match message.role { - mcp_core::role::Role::User => "### User:\n", - mcp_core::role::Role::Assistant => "### Assistant:\n", + rmcp::model::Role::User => "### User:\n", + rmcp::model::Role::Assistant => "### Assistant:\n", }; markdown_output.push_str(role_prefix); } diff --git a/crates/goose-cli/src/commands/web.rs b/crates/goose-cli/src/commands/web.rs index 72b5d32bda9d..e61975ad1ada 100644 --- a/crates/goose-cli/src/commands/web.rs +++ b/crates/goose-cli/src/commands/web.rs @@ -475,7 +475,14 @@ async fn process_message_streaming( } let provider = provider.unwrap(); - session::persist_messages(&session_file, &messages, Some(provider.clone())).await?; + let working_dir = Some(std::env::current_dir()?); + session::persist_messages( + &session_file, + &messages, + Some(provider.clone()), + working_dir.clone(), + ) + .await?; // Create a session config let session_config = SessionConfig { @@ -503,7 +510,13 @@ async fn process_message_streaming( let session_msgs = session_messages.lock().await; session_msgs.clone() }; - session::persist_messages(&session_file, ¤t_messages, None).await?; + session::persist_messages( + &session_file, + ¤t_messages, + None, + working_dir.clone(), + ) + .await?; // Handle different message content types for content in &message.content { match content { diff --git a/crates/goose-cli/src/lib.rs b/crates/goose-cli/src/lib.rs index 055f38b9033a..68f2357f5ee0 100644 --- a/crates/goose-cli/src/lib.rs +++ b/crates/goose-cli/src/lib.rs @@ -7,7 +7,6 @@ pub mod project_tracker; pub mod recipes; pub mod session; pub mod signal; -pub mod utils; // Re-export commonly used types pub use session::Session; diff --git a/crates/goose-cli/src/recipes/extract_from_cli.rs b/crates/goose-cli/src/recipes/extract_from_cli.rs index 199c7091e3a6..dff70fcadfb7 100644 --- a/crates/goose-cli/src/recipes/extract_from_cli.rs +++ b/crates/goose-cli/src/recipes/extract_from_cli.rs @@ -3,8 +3,10 @@ use std::path::PathBuf; use anyhow::{anyhow, Result}; use goose::recipe::{Response, SubRecipe}; +use crate::recipes::print_recipe::print_recipe_info; +use crate::recipes::recipe::load_recipe; use crate::recipes::search_recipe::retrieve_recipe_file; -use crate::{cli::InputConfig, recipes::recipe::load_recipe_as_template, session::SessionSettings}; +use crate::{cli::InputConfig, session::SessionSettings}; #[allow(clippy::type_complexity)] pub fn extract_recipe_info_from_cli( @@ -17,10 +19,11 @@ pub fn extract_recipe_info_from_cli( Option>, Option, )> { - let recipe = load_recipe_as_template(&recipe_name, params).unwrap_or_else(|err| { + let recipe = load_recipe(&recipe_name, params.clone()).unwrap_or_else(|err| { eprintln!("{}: {}", console::style("Error").red().bold(), err); std::process::exit(1); }); + print_recipe_info(&recipe, params); let mut all_sub_recipes = recipe.sub_recipes.clone().unwrap_or_default(); if !additional_sub_recipes.is_empty() { for sub_recipe_name in additional_sub_recipes { @@ -32,6 +35,7 @@ pub fn extract_recipe_info_from_cli( path: recipe_file_path.to_string_lossy().to_string(), name, values: None, + sequential_when_repeated: true, }; all_sub_recipes.push(additional_sub_recipe); } diff --git a/crates/goose-cli/src/recipes/github_recipe.rs b/crates/goose-cli/src/recipes/github_recipe.rs index 2763d24fedf7..e7f92855845e 100644 --- a/crates/goose-cli/src/recipes/github_recipe.rs +++ b/crates/goose-cli/src/recipes/github_recipe.rs @@ -1,5 +1,10 @@ -use anyhow::Result; +use anyhow::{anyhow, Result}; use console::style; +use goose::recipe::template_recipe::parse_recipe_content; +use serde::{Deserialize, Serialize}; + +use crate::recipes::recipe::RECIPE_FILE_EXTENSIONS; +use goose::recipe::read_recipe_file_content::RecipeFile; use std::env; use std::fs; use std::path::Path; @@ -8,8 +13,20 @@ use std::process::Command; use std::process::Stdio; use tar::Archive; -use crate::recipes::recipe::RECIPE_FILE_EXTENSIONS; -use crate::recipes::search_recipe::RecipeFile; +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct RecipeInfo { + pub name: String, + pub source: RecipeSource, + pub path: String, + pub title: Option, + pub description: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub enum RecipeSource { + Local, + GitHub, +} pub const GOOSE_RECIPE_GITHUB_REPO_CONFIG_KEY: &str = "GOOSE_RECIPE_GITHUB_REPO"; pub fn retrieve_recipe_from_github( @@ -83,7 +100,7 @@ fn clone_and_download_recipe(recipe_name: &str, recipe_repo_full_name: &str) -> get_folder_from_github(&local_repo_path, recipe_name) } -fn ensure_gh_authenticated() -> Result<()> { +pub fn ensure_gh_authenticated() -> Result<()> { // Check authentication status let status = Command::new("gh") .args(["auth", "status"]) @@ -197,3 +214,133 @@ fn list_files(dir: &Path) -> Result<()> { } Ok(()) } + +/// Lists all available recipes from a GitHub repository +pub fn list_github_recipes(repo: &str) -> Result> { + discover_github_recipes(repo) +} + +fn discover_github_recipes(repo: &str) -> Result> { + use serde_json::Value; + use std::process::Command; + + // Ensure GitHub CLI is authenticated + ensure_gh_authenticated()?; + + // Get repository contents using GitHub CLI + let output = Command::new("gh") + .args(["api", &format!("repos/{}/contents", repo)]) + .output() + .map_err(|e| anyhow!("Failed to fetch repository contents using 'gh api' command (executed when GOOSE_RECIPE_GITHUB_REPO is configured). This requires GitHub CLI (gh) to be installed and authenticated. Error: {}", e))?; + + if !output.status.success() { + let error_msg = String::from_utf8_lossy(&output.stderr); + return Err(anyhow!("GitHub API request failed: {}", error_msg)); + } + + let contents: Value = serde_json::from_slice(&output.stdout) + .map_err(|e| anyhow!("Failed to parse GitHub API response: {}", e))?; + + let mut recipes = Vec::new(); + + if let Some(items) = contents.as_array() { + for item in items { + if let (Some(name), Some(item_type)) = ( + item.get("name").and_then(|n| n.as_str()), + item.get("type").and_then(|t| t.as_str()), + ) { + if item_type == "dir" { + // Check if this directory contains a recipe file + if let Ok(recipe_info) = check_github_directory_for_recipe(repo, name) { + recipes.push(recipe_info); + } + } + } + } + } + + Ok(recipes) +} + +fn check_github_directory_for_recipe(repo: &str, dir_name: &str) -> Result { + use serde_json::Value; + use std::process::Command; + + // Check directory contents for recipe files + let output = Command::new("gh") + .args(["api", &format!("repos/{}/contents/{}", repo, dir_name)]) + .output() + .map_err(|e| anyhow!("Failed to check directory contents: {}", e))?; + + if !output.status.success() { + return Err(anyhow!("Failed to access directory: {}", dir_name)); + } + + let contents: Value = serde_json::from_slice(&output.stdout) + .map_err(|e| anyhow!("Failed to parse directory contents: {}", e))?; + + if let Some(items) = contents.as_array() { + for item in items { + if let Some(name) = item.get("name").and_then(|n| n.as_str()) { + if RECIPE_FILE_EXTENSIONS + .iter() + .any(|ext| name == format!("recipe.{}", ext)) + { + // Found a recipe file, get its content + return get_github_recipe_info(repo, dir_name, name); + } + } + } + } + + Err(anyhow!("No recipe file found in directory: {}", dir_name)) +} + +fn get_github_recipe_info(repo: &str, dir_name: &str, recipe_filename: &str) -> Result { + use serde_json::Value; + use std::process::Command; + + // Get the recipe file content + let output = Command::new("gh") + .args([ + "api", + &format!("repos/{}/contents/{}/{}", repo, dir_name, recipe_filename), + ]) + .output() + .map_err(|e| anyhow!("Failed to get recipe file content: {}", e))?; + + if !output.status.success() { + return Err(anyhow!( + "Failed to access recipe file: {}/{}", + dir_name, + recipe_filename + )); + } + + let file_info: Value = serde_json::from_slice(&output.stdout) + .map_err(|e| anyhow!("Failed to parse file info: {}", e))?; + + if let Some(content_b64) = file_info.get("content").and_then(|c| c.as_str()) { + // Decode base64 content + use base64::{engine::general_purpose, Engine as _}; + let content_bytes = general_purpose::STANDARD + .decode(content_b64.replace('\n', "")) + .map_err(|e| anyhow!("Failed to decode base64 content: {}", e))?; + + let content = String::from_utf8(content_bytes) + .map_err(|e| anyhow!("Failed to convert content to string: {}", e))?; + + // Parse the recipe content + let (recipe, _) = parse_recipe_content(&content, format!("{}/{}", repo, dir_name))?; + + return Ok(RecipeInfo { + name: dir_name.to_string(), + source: RecipeSource::GitHub, + path: format!("{}/{}", repo, dir_name), + title: Some(recipe.title), + description: Some(recipe.description), + }); + } + + Err(anyhow!("Failed to get recipe content from GitHub")) +} diff --git a/crates/goose-cli/src/recipes/mod.rs b/crates/goose-cli/src/recipes/mod.rs index a1ceef614213..8dd270e20f54 100644 --- a/crates/goose-cli/src/recipes/mod.rs +++ b/crates/goose-cli/src/recipes/mod.rs @@ -3,4 +3,3 @@ pub mod github_recipe; pub mod print_recipe; pub mod recipe; pub mod search_recipe; -pub mod template_recipe; diff --git a/crates/goose-cli/src/recipes/print_recipe.rs b/crates/goose-cli/src/recipes/print_recipe.rs index bc43842f8db0..be4a2339875f 100644 --- a/crates/goose-cli/src/recipes/print_recipe.rs +++ b/crates/goose-cli/src/recipes/print_recipe.rs @@ -1,9 +1,7 @@ use std::collections::HashMap; use console::style; -use goose::recipe::Recipe; - -use crate::recipes::recipe::BUILT_IN_RECIPE_DIR_PARAM; +use goose::recipe::{Recipe, BUILT_IN_RECIPE_DIR_PARAM}; pub fn print_recipe_explanation(recipe: &Recipe) { println!( @@ -81,3 +79,18 @@ pub fn missing_parameters_command_line(missing_params: Vec) -> String { .collect::>() .join(" ") } + +pub fn print_recipe_info(recipe: &Recipe, params: Vec<(String, String)>) { + println!( + "{} {}", + style("Loading recipe:").green().bold(), + style(&recipe.title).green() + ); + println!("{} {}", style("Description:").bold(), &recipe.description); + + if !params.is_empty() { + println!("{}", style("Parameters used to load this recipe:").bold()); + print_parameters_with_values(params.into_iter().collect()); + } + println!(); +} diff --git a/crates/goose-cli/src/recipes/recipe.rs b/crates/goose-cli/src/recipes/recipe.rs index 49b0e6bf7210..0596f72a63b6 100644 --- a/crates/goose-cli/src/recipes/recipe.rs +++ b/crates/goose-cli/src/recipes/recipe.rs @@ -1,89 +1,69 @@ use crate::recipes::print_recipe::{ - missing_parameters_command_line, print_parameters_with_values, print_recipe_explanation, + missing_parameters_command_line, print_recipe_explanation, print_required_parameters_for_template, }; -use crate::recipes::search_recipe::{retrieve_recipe_file, RecipeFile}; -use crate::recipes::template_recipe::{ - parse_recipe_content, render_recipe_content_with_params, render_recipe_for_preview, -}; +use crate::recipes::search_recipe::retrieve_recipe_file; use anyhow::Result; -use console::style; -use goose::recipe::{Recipe, RecipeParameter, RecipeParameterRequirement}; -use std::collections::{HashMap, HashSet}; +use goose::recipe::build_recipe::{ + apply_values_to_parameters, build_recipe_from_template, validate_recipe_parameters, RecipeError, +}; +use goose::recipe::read_recipe_file_content::RecipeFile; +use goose::recipe::template_recipe::render_recipe_for_preview; +use goose::recipe::Recipe; +use std::collections::HashMap; -pub const BUILT_IN_RECIPE_DIR_PARAM: &str = "recipe_dir"; pub const RECIPE_FILE_EXTENSIONS: &[&str] = &["yaml", "json"]; -pub fn load_recipe_content_as_template( - recipe_name: &str, - params: Vec<(String, String)>, -) -> Result { - let RecipeFile { - content: recipe_file_content, - parent_dir: recipe_parent_dir, - .. - } = retrieve_recipe_file(recipe_name)?; - let recipe_dir_str = recipe_parent_dir - .to_str() - .ok_or_else(|| anyhow::anyhow!("Error getting recipe directory"))?; - let recipe_parameters = validate_recipe_parameters(&recipe_file_content, recipe_dir_str)?; +fn create_user_prompt_callback() -> impl Fn(&str, &str) -> Result { + |key: &str, description: &str| -> Result { + let input_value = + cliclack::input(format!("Please enter {} ({})", key, description)).interact()?; + Ok(input_value) + } +} - let (params_for_template, missing_params) = - apply_values_to_parameters(¶ms, recipe_parameters, recipe_dir_str, true)?; +fn load_recipe_file_with_dir(recipe_name: &str) -> Result<(RecipeFile, String)> { + let recipe_file = retrieve_recipe_file(recipe_name)?; + let recipe_dir_str = recipe_file + .parent_dir + .to_str() + .ok_or_else(|| anyhow::anyhow!("Error getting recipe directory"))? + .to_string(); + Ok((recipe_file, recipe_dir_str)) +} - if !missing_params.is_empty() { - return Err(anyhow::anyhow!( +pub fn load_recipe(recipe_name: &str, params: Vec<(String, String)>) -> Result { + let recipe_file = retrieve_recipe_file(recipe_name)?; + match build_recipe_from_template(recipe_file, params, Some(create_user_prompt_callback())) { + Ok(recipe) => Ok(recipe), + Err(RecipeError::MissingParams { parameters }) => Err(anyhow::anyhow!( "Please provide the following parameters in the command line: {}", - missing_parameters_command_line(missing_params) - )); + missing_parameters_command_line(parameters) + )), + Err(e) => Err(anyhow::anyhow!(e.to_string())), } - render_recipe_content_with_params(&recipe_file_content, ¶ms_for_template) -} - -fn validate_recipe_parameters( - recipe_file_content: &str, - recipe_dir_str: &str, -) -> Result>> { - let (raw_recipe, template_variables) = - parse_recipe_content(recipe_file_content, recipe_dir_str.to_string())?; - let recipe_parameters = raw_recipe.parameters; - validate_optional_parameters(&recipe_parameters)?; - validate_parameters_in_template(&recipe_parameters, &template_variables)?; - Ok(recipe_parameters) } -pub fn load_recipe_as_template(recipe_name: &str, params: Vec<(String, String)>) -> Result { - let rendered_content = load_recipe_content_as_template(recipe_name, params.clone())?; - let recipe = Recipe::from_content(&rendered_content)?; - - // Display information about the loaded recipe - println!( - "{} {}", - style("Loading recipe:").green().bold(), - style(&recipe.title).green() - ); - println!("{} {}", style("Description:").bold(), &recipe.description); - - if !params.is_empty() { - println!("{}", style("Parameters used to load this recipe:").bold()); - print_parameters_with_values(params.into_iter().collect()); +pub fn render_recipe_as_yaml(recipe_name: &str, params: Vec<(String, String)>) -> Result<()> { + let recipe = load_recipe(recipe_name, params)?; + match serde_yaml::to_string(&recipe) { + Ok(yaml_content) => { + println!("{}", yaml_content); + Ok(()) + } + Err(_) => { + eprintln!("Failed to serialize recipe to YAML"); + std::process::exit(1); + } } - println!(); - Ok(recipe) } -pub fn load_recipe(recipe_name: &str) -> Result { - let RecipeFile { - content: recipe_file_content, - parent_dir: recipe_parent_dir, - .. - } = retrieve_recipe_file(recipe_name)?; - let recipe_dir_str = recipe_parent_dir - .to_str() - .ok_or_else(|| anyhow::anyhow!("Error getting recipe directory"))?; - validate_recipe_parameters(&recipe_file_content, recipe_dir_str)?; +pub fn load_recipe_for_validation(recipe_name: &str) -> Result { + let (recipe_file, recipe_dir_str) = load_recipe_file_with_dir(recipe_name)?; + let recipe_file_content = &recipe_file.content; + validate_recipe_parameters(recipe_file_content, &recipe_dir_str)?; let recipe = render_recipe_for_preview( - &recipe_file_content, + recipe_file_content, recipe_dir_str.to_string(), &HashMap::new(), )?; @@ -97,24 +77,19 @@ pub fn load_recipe(recipe_name: &str) -> Result { Ok(recipe) } -pub fn explain_recipe_with_parameters( - recipe_name: &str, - params: Vec<(String, String)>, -) -> Result<()> { - let RecipeFile { - content: recipe_file_content, - parent_dir: recipe_parent_dir, - .. - } = retrieve_recipe_file(recipe_name)?; - let recipe_dir_str = recipe_parent_dir - .to_str() - .ok_or_else(|| anyhow::anyhow!("Error getting recipe directory"))?; - let recipe_parameters = validate_recipe_parameters(&recipe_file_content, recipe_dir_str)?; +pub fn explain_recipe(recipe_name: &str, params: Vec<(String, String)>) -> Result<()> { + let (recipe_file, recipe_dir_str) = load_recipe_file_with_dir(recipe_name)?; + let recipe_file_content = &recipe_file.content; + let recipe_parameters = validate_recipe_parameters(recipe_file_content, &recipe_dir_str)?; - let (params_for_template, missing_params) = - apply_values_to_parameters(¶ms, recipe_parameters, recipe_dir_str, false)?; + let (params_for_template, missing_params) = apply_values_to_parameters( + ¶ms, + recipe_parameters, + &recipe_dir_str, + None:: Result>, + )?; let recipe = render_recipe_for_preview( - &recipe_file_content, + recipe_file_content, recipe_dir_str.to_string(), ¶ms_for_template, )?; @@ -124,110 +99,6 @@ pub fn explain_recipe_with_parameters( Ok(()) } -fn validate_parameters_in_template( - recipe_parameters: &Option>, - template_variables: &HashSet, -) -> Result<()> { - let mut template_variables = template_variables.clone(); - template_variables.remove(BUILT_IN_RECIPE_DIR_PARAM); - - let param_keys: HashSet = recipe_parameters - .as_ref() - .unwrap_or(&vec![]) - .iter() - .map(|p| p.key.clone()) - .collect(); - - let missing_keys = template_variables - .difference(¶m_keys) - .collect::>(); - - let extra_keys = param_keys - .difference(&template_variables) - .collect::>(); - - if missing_keys.is_empty() && extra_keys.is_empty() { - return Ok(()); - } - - let mut message = String::new(); - - if !missing_keys.is_empty() { - message.push_str(&format!( - "Missing definitions for parameters in the recipe file: {}.", - missing_keys - .iter() - .map(|s| s.to_string()) - .collect::>() - .join(", ") - )); - } - - if !extra_keys.is_empty() { - message.push_str(&format!( - "\nUnnecessary parameter definitions: {}.", - extra_keys - .iter() - .map(|s| s.to_string()) - .collect::>() - .join(", ") - )); - } - Err(anyhow::anyhow!("{}", message.trim_end())) -} - -fn validate_optional_parameters(parameters: &Option>) -> Result<()> { - let optional_params_without_default_values: Vec = parameters - .as_ref() - .unwrap_or(&vec![]) - .iter() - .filter(|p| { - matches!(p.requirement, RecipeParameterRequirement::Optional) && p.default.is_none() - }) - .map(|p| p.key.clone()) - .collect(); - - if optional_params_without_default_values.is_empty() { - Ok(()) - } else { - Err(anyhow::anyhow!("Optional parameters missing default values in the recipe: {}. Please provide defaults.", optional_params_without_default_values.join(", "))) - } -} - -fn apply_values_to_parameters( - user_params: &[(String, String)], - recipe_parameters: Option>, - recipe_parent_dir: &str, - enable_user_prompt: bool, -) -> Result<(HashMap, Vec)> { - let mut param_map: HashMap = user_params.iter().cloned().collect(); - param_map.insert( - BUILT_IN_RECIPE_DIR_PARAM.to_string(), - recipe_parent_dir.to_string(), - ); - let mut missing_params: Vec = Vec::new(); - for param in recipe_parameters.unwrap_or_default() { - if !param_map.contains_key(¶m.key) { - match (¶m.default, ¶m.requirement) { - (Some(default), _) => param_map.insert(param.key.clone(), default.clone()), - (None, RecipeParameterRequirement::UserPrompt) if enable_user_prompt => { - let input_value = cliclack::input(format!( - "Please enter {} ({})", - param.key, param.description - )) - .interact()?; - param_map.insert(param.key.clone(), input_value) - } - _ => { - missing_params.push(param.key.clone()); - None - } - }; - } - } - Ok((param_map, missing_params)) -} - fn validate_json_schema(schema: &serde_json::Value) -> Result<()> { match jsonschema::validator_for(schema) { Ok(_) => Ok(()), @@ -236,4 +107,49 @@ fn validate_json_schema(schema: &serde_json::Value) -> Result<()> { } #[cfg(test)] -mod tests; +mod tests { + use goose::recipe::{RecipeParameterInputType, RecipeParameterRequirement}; + + use crate::recipes::recipe::load_recipe; + + mod load_recipe { + use super::*; + #[test] + fn test_load_recipe_success() { + let recipe_content = r#"{ + "version": "1.0.0", + "title": "Test Recipe", + "description": "A test recipe", + "instructions": "Test instructions with {{ my_name }}", + "parameters": [ + { + "key": "my_name", + "input_type": "string", + "requirement": "required", + "description": "A test parameter" + } + ] + }"#; + let temp_dir = tempfile::tempdir().unwrap(); + let recipe_path = temp_dir.path().join("test_recipe.json"); + std::fs::write(&recipe_path, recipe_content).unwrap(); + + let params = vec![("my_name".to_string(), "value".to_string())]; + let recipe = load_recipe(recipe_path.to_str().unwrap(), params).unwrap(); + + assert_eq!(recipe.title, "Test Recipe"); + assert_eq!(recipe.description, "A test recipe"); + assert_eq!(recipe.instructions.unwrap(), "Test instructions with value"); + // Verify parameters match recipe definition + assert_eq!(recipe.parameters.as_ref().unwrap().len(), 1); + let param = &recipe.parameters.as_ref().unwrap()[0]; + assert_eq!(param.key, "my_name"); + assert!(matches!(param.input_type, RecipeParameterInputType::String)); + assert!(matches!( + param.requirement, + RecipeParameterRequirement::Required + )); + assert_eq!(param.description, "A test parameter"); + } + } +} diff --git a/crates/goose-cli/src/recipes/recipe/tests.rs b/crates/goose-cli/src/recipes/recipe/tests.rs deleted file mode 100644 index 4b27601ffdd7..000000000000 --- a/crates/goose-cli/src/recipes/recipe/tests.rs +++ /dev/null @@ -1,315 +0,0 @@ -#[cfg(test)] -mod tests { - use std::path::PathBuf; - - use goose::recipe::{RecipeParameterInputType, RecipeParameterRequirement}; - use tempfile::TempDir; - - use crate::recipes::recipe::load_recipe_as_template; - - fn setup_recipe_file(instructions_and_parameters: &str) -> (TempDir, PathBuf) { - let recipe_content = format!( - r#"{{ - "version": "1.0.0", - "title": "Test Recipe", - "description": "A test recipe", - {} - }}"#, - instructions_and_parameters - ); - let temp_dir = tempfile::tempdir().unwrap(); - let recipe_path: std::path::PathBuf = temp_dir.path().join("test_recipe.json"); - - std::fs::write(&recipe_path, recipe_content).unwrap(); - (temp_dir, recipe_path) - } - - mod load_recipe_as_template_tests { - use super::*; - #[test] - fn test_load_recipe_as_template_success() { - let instructions_and_parameters = r#" - "instructions": "Test instructions with {{ my_name }}", - "parameters": [ - { - "key": "my_name", - "input_type": "string", - "requirement": "required", - "description": "A test parameter" - } - ]"#; - - let (_temp_dir, recipe_path) = setup_recipe_file(instructions_and_parameters); - - let params = vec![("my_name".to_string(), "value".to_string())]; - let recipe = load_recipe_as_template(recipe_path.to_str().unwrap(), params).unwrap(); - - assert_eq!(recipe.title, "Test Recipe"); - assert_eq!(recipe.description, "A test recipe"); - assert_eq!(recipe.instructions.unwrap(), "Test instructions with value"); - // Verify parameters match recipe definition - assert_eq!(recipe.parameters.as_ref().unwrap().len(), 1); - let param = &recipe.parameters.as_ref().unwrap()[0]; - assert_eq!(param.key, "my_name"); - assert!(matches!(param.input_type, RecipeParameterInputType::String)); - assert!(matches!( - param.requirement, - RecipeParameterRequirement::Required - )); - assert_eq!(param.description, "A test parameter"); - } - - #[test] - fn test_load_recipe_as_template_success_variable_in_prompt() { - let instructions_and_parameters = r#" - "instructions": "Test instructions", - "prompt": "My prompt {{ my_name }}", - "parameters": [ - { - "key": "my_name", - "input_type": "string", - "requirement": "required", - "description": "A test parameter" - } - ]"#; - - let (_temp_dir, recipe_path) = setup_recipe_file(instructions_and_parameters); - - let params = vec![("my_name".to_string(), "value".to_string())]; - let recipe = load_recipe_as_template(recipe_path.to_str().unwrap(), params).unwrap(); - - assert_eq!(recipe.title, "Test Recipe"); - assert_eq!(recipe.description, "A test recipe"); - assert_eq!(recipe.instructions.unwrap(), "Test instructions"); - assert_eq!(recipe.prompt.unwrap(), "My prompt value"); - let param = &recipe.parameters.as_ref().unwrap()[0]; - assert_eq!(param.key, "my_name"); - assert!(matches!(param.input_type, RecipeParameterInputType::String)); - assert!(matches!( - param.requirement, - RecipeParameterRequirement::Required - )); - assert_eq!(param.description, "A test parameter"); - } - - #[test] - fn test_load_recipe_as_template_wrong_parameters_in_recipe_file() { - let instructions_and_parameters = r#" - "instructions": "Test instructions with {{ expected_param1 }} {{ expected_param2 }}", - "parameters": [ - { - "key": "wrong_param_key", - "input_type": "string", - "requirement": "required", - "description": "A test parameter" - } - ]"#; - let (_temp_dir, recipe_path) = setup_recipe_file(instructions_and_parameters); - - let load_recipe_result = - load_recipe_as_template(recipe_path.to_str().unwrap(), Vec::new()); - assert!(load_recipe_result.is_err()); - let err = load_recipe_result.unwrap_err(); - println!("{}", err.to_string()); - assert!(err - .to_string() - .contains("Unnecessary parameter definitions: wrong_param_key.")); - assert!(err - .to_string() - .contains("Missing definitions for parameters in the recipe file:")); - assert!(err.to_string().contains("expected_param1")); - assert!(err.to_string().contains("expected_param2")); - } - - #[test] - fn test_load_recipe_as_template_with_default_values_in_recipe_file() { - let instructions_and_parameters = r#" - "instructions": "Test instructions with {{ param_with_default }} {{ param_without_default }}", - "parameters": [ - { - "key": "param_with_default", - "input_type": "string", - "requirement": "optional", - "default": "my_default_value", - "description": "A test parameter" - }, - { - "key": "param_without_default", - "input_type": "string", - "requirement": "required", - "description": "A test parameter" - } - ]"#; - let (_temp_dir, recipe_path) = setup_recipe_file(instructions_and_parameters); - let params = vec![("param_without_default".to_string(), "value1".to_string())]; - - let recipe = load_recipe_as_template(recipe_path.to_str().unwrap(), params).unwrap(); - - assert_eq!(recipe.title, "Test Recipe"); - assert_eq!(recipe.description, "A test recipe"); - assert_eq!( - recipe.instructions.unwrap(), - "Test instructions with my_default_value value1" - ); - } - - #[test] - fn test_load_recipe_as_template_optional_parameters_with_empty_default_values_in_recipe_file( - ) { - let instructions_and_parameters = r#" - "instructions": "Test instructions with {{ optional_param }}", - "parameters": [ - { - "key": "optional_param", - "input_type": "string", - "requirement": "optional", - "description": "A test parameter", - "default": "", - } - ]"#; - let (_temp_dir, recipe_path) = setup_recipe_file(instructions_and_parameters); - - let recipe = - load_recipe_as_template(recipe_path.to_str().unwrap(), Vec::new()).unwrap(); - assert_eq!(recipe.title, "Test Recipe"); - assert_eq!(recipe.description, "A test recipe"); - assert_eq!(recipe.instructions.unwrap(), "Test instructions with "); - } - - #[test] - fn test_load_recipe_as_template_optional_parameters_without_default_values_in_recipe_file() - { - let instructions_and_parameters = r#" - "instructions": "Test instructions with {{ optional_param }}", - "parameters": [ - { - "key": "optional_param", - "input_type": "string", - "requirement": "optional", - "description": "A test parameter" - } - ]"#; - let (_temp_dir, recipe_path) = setup_recipe_file(instructions_and_parameters); - - let load_recipe_result = - load_recipe_as_template(recipe_path.to_str().unwrap(), Vec::new()); - assert!(load_recipe_result.is_err()); - let err = load_recipe_result.unwrap_err(); - println!("{}", err.to_string()); - assert!(err.to_string().to_lowercase().contains("missing")); - } - - #[test] - fn test_load_recipe_as_template_wrong_input_type_in_recipe_file() { - let instructions_and_parameters = r#" - "instructions": "Test instructions with {{ param }}", - "parameters": [ - { - "key": "param", - "input_type": "some_invalid_type", - "requirement": "required", - "description": "A test parameter" - } - ]"#; - let params = vec![("param".to_string(), "value".to_string())]; - let (_temp_dir, recipe_path) = setup_recipe_file(instructions_and_parameters); - - let load_recipe_result = load_recipe_as_template(recipe_path.to_str().unwrap(), params); - assert!(load_recipe_result.is_err()); - let err = load_recipe_result.unwrap_err(); - let err_msg = err.to_string(); - eprint!("Error: {}", err_msg); - assert!(err_msg.contains("unknown variant `some_invalid_type`")); - } - - #[test] - fn test_load_recipe_as_template_success_without_parameters() { - let instructions_and_parameters = r#" - "instructions": "Test instructions" - "#; - let (_temp_dir, recipe_path) = setup_recipe_file(instructions_and_parameters); - - let recipe = - load_recipe_as_template(recipe_path.to_str().unwrap(), Vec::new()).unwrap(); - assert_eq!(recipe.instructions.unwrap(), "Test instructions"); - assert!(recipe.parameters.is_none()); - } - - #[test] - fn test_template_inheritance() { - let temp_dir = tempfile::tempdir().unwrap(); - let temp_path = temp_dir.path(); - let parent_content = r#" - version: 1.0.0 - title: Parent - description: Parent recipe - prompt: | - show me the news for day: {{ date }} - {% block prompt -%} - What is the capital of France? - {%- endblock %} - {% if is_enabled %} - Feature is enabled. - {% else %} - Feature is disabled. - {% endif %} - parameters: - - key: date - input_type: string - requirement: required - description: date specified by the user - - key: is_enabled - input_type: boolean - requirement: required - description: whether the feature is enabled - "#; - - let parent_path = temp_path.join("parent.yaml"); - std::fs::write(&parent_path, parent_content).unwrap(); - let child_content = r#" - {% extends "parent.yaml" -%} - {% block prompt -%} - What is the capital of Germany? - {%- endblock %} - "#; - let child_path = temp_path.join("child.yaml"); - std::fs::write(&child_path, child_content).unwrap(); - - let params = vec![ - ("date".to_string(), "today".to_string()), - ("is_enabled".to_string(), "true".to_string()), - ]; - let parent_result = - load_recipe_as_template(parent_path.to_str().unwrap(), params.clone()); - assert!(parent_result.is_ok()); - let parent_recipe = parent_result.unwrap(); - assert_eq!(parent_recipe.description, "Parent recipe"); - assert_eq!( - parent_recipe.prompt.unwrap(), - "show me the news for day: today\nWhat is the capital of France?\n\n Feature is enabled.\n" - ); - assert_eq!(parent_recipe.parameters.as_ref().unwrap().len(), 2); - assert_eq!(parent_recipe.parameters.as_ref().unwrap()[0].key, "date"); - assert_eq!( - parent_recipe.parameters.as_ref().unwrap()[1].key, - "is_enabled" - ); - - let child_result = load_recipe_as_template(child_path.to_str().unwrap(), params); - assert!(child_result.is_ok()); - let child_recipe = child_result.unwrap(); - assert_eq!(child_recipe.title, "Parent"); - assert_eq!(child_recipe.description, "Parent recipe"); - assert_eq!( - child_recipe.prompt.unwrap().trim(), - "show me the news for day: today\nWhat is the capital of Germany?\n\n Feature is enabled." - ); - assert_eq!(child_recipe.parameters.as_ref().unwrap().len(), 2); - assert_eq!(child_recipe.parameters.as_ref().unwrap()[0].key, "date"); - assert_eq!( - child_recipe.parameters.as_ref().unwrap()[1].key, - "is_enabled" - ); - } - } -} diff --git a/crates/goose-cli/src/recipes/search_recipe.rs b/crates/goose-cli/src/recipes/search_recipe.rs index e78dfe223b36..8854e0dc4824 100644 --- a/crates/goose-cli/src/recipes/search_recipe.rs +++ b/crates/goose-cli/src/recipes/search_recipe.rs @@ -1,20 +1,20 @@ use anyhow::{anyhow, Result}; use goose::config::Config; +use goose::recipe::read_recipe_file_content::{read_recipe_file, RecipeFile}; +use goose::recipe::template_recipe::parse_recipe_content; +use std::env; +use std::fs; use std::path::{Path, PathBuf}; -use std::{env, fs}; use crate::recipes::recipe::RECIPE_FILE_EXTENSIONS; -use super::github_recipe::{retrieve_recipe_from_github, GOOSE_RECIPE_GITHUB_REPO_CONFIG_KEY}; +use super::github_recipe::{ + list_github_recipes, retrieve_recipe_from_github, RecipeInfo, RecipeSource, + GOOSE_RECIPE_GITHUB_REPO_CONFIG_KEY, +}; const GOOSE_RECIPE_PATH_ENV_VAR: &str = "GOOSE_RECIPE_PATH"; -pub struct RecipeFile { - pub content: String, - pub parent_dir: PathBuf, - pub file_path: PathBuf, -} - pub fn retrieve_recipe_file(recipe_name: &str) -> Result { if RECIPE_FILE_EXTENSIONS .iter() @@ -100,39 +100,95 @@ fn configured_github_recipe_repo() -> Option { } } -fn convert_path_with_tilde_expansion(path: &Path) -> PathBuf { - if let Some(path_str) = path.to_str() { - if let Some(stripped) = path_str.strip_prefix("~/") { - if let Some(home_dir) = dirs::home_dir() { - return home_dir.join(stripped); +/// Lists all available recipes from local paths and GitHub repositories +pub fn list_available_recipes() -> Result> { + let mut recipes = Vec::new(); + + // Search local recipes + if let Ok(local_recipes) = discover_local_recipes() { + recipes.extend(local_recipes); + } + + // Search GitHub recipes if configured + if let Some(repo) = configured_github_recipe_repo() { + if let Ok(github_recipes) = list_github_recipes(&repo) { + recipes.extend(github_recipes); + } + } + + Ok(recipes) +} + +fn discover_local_recipes() -> Result> { + let mut recipes = Vec::new(); + let mut search_dirs = vec![PathBuf::from(".")]; + + // Add GOOSE_RECIPE_PATH directories + if let Ok(recipe_path_env) = env::var(GOOSE_RECIPE_PATH_ENV_VAR) { + let path_separator = if cfg!(windows) { ';' } else { ':' }; + let recipe_path_env_dirs: Vec = recipe_path_env + .split(path_separator) + .map(PathBuf::from) + .collect(); + search_dirs.extend(recipe_path_env_dirs); + } + + for dir in search_dirs { + if let Ok(dir_recipes) = scan_directory_for_recipes(&dir) { + recipes.extend(dir_recipes); + } + } + + Ok(recipes) +} + +fn scan_directory_for_recipes(dir: &Path) -> Result> { + let mut recipes = Vec::new(); + + if !dir.exists() || !dir.is_dir() { + return Ok(recipes); + } + + for entry in fs::read_dir(dir)? { + let entry = entry?; + let path = entry.path(); + + if path.is_file() { + if let Some(extension) = path.extension() { + if RECIPE_FILE_EXTENSIONS.contains(&extension.to_string_lossy().as_ref()) { + if let Ok(recipe_info) = create_local_recipe_info(&path) { + recipes.push(recipe_info); + } + } } } } - PathBuf::from(path) + + Ok(recipes) } -fn read_recipe_file>(recipe_path: P) -> Result { - let raw_path = recipe_path.as_ref(); - let path = convert_path_with_tilde_expansion(raw_path); - - let content = fs::read_to_string(&path) - .map_err(|e| anyhow!("Failed to read recipe file {}: {}", path.display(), e))?; - let canonical = path.canonicalize().map_err(|e| { - anyhow!( - "Failed to resolve absolute path for {}: {}", - path.display(), - e - ) - })?; - - let parent_dir = canonical +fn create_local_recipe_info(path: &Path) -> Result { + let content = fs::read_to_string(path)?; + let recipe_dir = path .parent() - .ok_or_else(|| anyhow!("Resolved path has no parent: {}", canonical.display()))? - .to_path_buf(); + .unwrap_or_else(|| Path::new(".")) + .to_string_lossy() + .to_string(); + let (recipe, _) = parse_recipe_content(&content, recipe_dir)?; + + let name = path + .file_stem() + .and_then(|s| s.to_str()) + .unwrap_or("unknown") + .to_string(); + + let path_str = path.to_string_lossy().to_string(); - Ok(RecipeFile { - content, - parent_dir, - file_path: canonical, + Ok(RecipeInfo { + name, + source: RecipeSource::Local, + path: path_str, + title: Some(recipe.title), + description: Some(recipe.description), }) } diff --git a/crates/goose-cli/src/session/builder.rs b/crates/goose-cli/src/session/builder.rs index 9782dea333bc..588325c37368 100644 --- a/crates/goose-cli/src/session/builder.rs +++ b/crates/goose-cli/src/session/builder.rs @@ -203,6 +203,7 @@ pub async fn build_session(session_config: SessionBuilderConfig) -> Session { // Create the agent let agent: Agent = Agent::new(); + if let Some(sub_recipes) = session_config.sub_recipes { agent.add_sub_recipes(sub_recipes).await; } diff --git a/crates/goose-cli/src/session/export.rs b/crates/goose-cli/src/session/export.rs index 57b83b1efa22..5c971402512f 100644 --- a/crates/goose-cli/src/session/export.rs +++ b/crates/goose-cli/src/session/export.rs @@ -1,7 +1,6 @@ use goose::message::{Message, MessageContent, ToolRequest, ToolResponse}; -use mcp_core::content::Content as McpContent; -use mcp_core::resource::ResourceContents; -use mcp_core::role::Role; +use goose::utils::safe_truncate; +use rmcp::model::{RawContent, ResourceContents, Role}; use serde_json::Value; const MAX_STRING_LENGTH_MD_EXPORT: usize = 4096; // Generous limit for export @@ -10,9 +9,9 @@ const REDACTED_PREFIX_LENGTH: usize = 100; // Show first 100 chars before trimmi fn value_to_simple_markdown_string(value: &Value, export_full_strings: bool) -> String { match value { Value::String(s) => { - if !export_full_strings && s.len() > MAX_STRING_LENGTH_MD_EXPORT { - let prefix = &s[..REDACTED_PREFIX_LENGTH.min(s.len())]; - let trimmed_chars = s.len() - prefix.len(); + if !export_full_strings && s.chars().count() > MAX_STRING_LENGTH_MD_EXPORT { + let prefix = safe_truncate(s, REDACTED_PREFIX_LENGTH); + let trimmed_chars = s.chars().count() - prefix.chars().count(); format!("`{}[ ... trimmed : {} chars ... ]`", prefix, trimmed_chars) } else { // Escape backticks and newlines for inline code. @@ -40,7 +39,7 @@ fn value_to_markdown(value: &Value, depth: usize, export_full_strings: bool) -> md_string.push_str(&format!("{}* **{}**: ", base_indent_str, key)); match val { Value::String(s) => { - if s.contains('\n') || s.len() > 80 { + if s.contains('\n') || s.chars().count() > 80 { // Heuristic for block md_string.push_str(&format!( "\n{} ```\n{}{}\n{} ```\n", @@ -74,7 +73,7 @@ fn value_to_markdown(value: &Value, depth: usize, export_full_strings: bool) -> md_string.push_str(&format!("{}* - ", base_indent_str)); match item { Value::String(s) => { - if s.contains('\n') || s.len() > 80 { + if s.contains('\n') || s.chars().count() > 80 { // Heuristic for block md_string.push_str(&format!( "\n{} ```\n{}{}\n{} ```\n", @@ -218,8 +217,8 @@ pub fn tool_response_to_markdown(resp: &ToolResponse, export_all_content: bool) } } - match content { - McpContent::Text(text_content) => { + match &content.raw { + RawContent::Text(text_content) => { let trimmed_text = text_content.text.trim(); if (trimmed_text.starts_with('{') && trimmed_text.ends_with('}')) || (trimmed_text.starts_with('[') && trimmed_text.ends_with(']')) @@ -235,7 +234,7 @@ pub fn tool_response_to_markdown(resp: &ToolResponse, export_all_content: bool) md.push_str("\n\n"); } } - McpContent::Image(image_content) => { + RawContent::Image(image_content) => { if image_content.mime_type.starts_with("image/") { // For actual images, provide a placeholder that indicates it's an image md.push_str(&format!( @@ -251,7 +250,7 @@ pub fn tool_response_to_markdown(resp: &ToolResponse, export_all_content: bool) )); } } - McpContent::Resource(resource) => { + RawContent::Resource(resource) => { match &resource.resource { ResourceContents::TextResourceContents { uri, @@ -298,6 +297,9 @@ pub fn tool_response_to_markdown(resp: &ToolResponse, export_all_content: bool) } } } + RawContent::Audio(_) => { + md.push_str("[audio content not displayed in Markdown export]\n\n") + } } } } @@ -359,8 +361,8 @@ pub fn message_to_markdown(message: &Message, export_all_content: bool) -> Strin mod tests { use super::*; use goose::message::{Message, ToolRequest, ToolResponse}; - use mcp_core::content::{Content as McpContent, TextContent}; use mcp_core::tool::ToolCall; + use rmcp::model::{Content, RawTextContent, TextContent}; use serde_json::json; #[test] @@ -397,7 +399,7 @@ mod tests { assert!(result.starts_with("`")); assert!(result.contains("[ ... trimmed : ")); assert!(result.contains("4900 chars ... ]`")); - assert!(result.contains(&"a".repeat(100))); // Should contain the prefix + assert!(result.contains(&"a".repeat(97))); // Should contain the prefix (100 - 3 for "...") } #[test] @@ -520,12 +522,14 @@ mod tests { #[test] fn test_tool_response_to_markdown_text() { let text_content = TextContent { - text: "Command executed successfully".to_string(), + raw: RawTextContent { + text: "Command executed successfully".to_string(), + }, annotations: None, }; let tool_response = ToolResponse { id: "test-id".to_string(), - tool_result: Ok(vec![McpContent::Text(text_content)]), + tool_result: Ok(vec![Content::text(text_content.raw.text)]), }; let result = tool_response_to_markdown(&tool_response, true); @@ -537,12 +541,14 @@ mod tests { fn test_tool_response_to_markdown_json() { let json_text = r#"{"status": "success", "data": "test"}"#; let text_content = TextContent { - text: json_text.to_string(), + raw: RawTextContent { + text: json_text.to_string(), + }, annotations: None, }; let tool_response = ToolResponse { id: "test-id".to_string(), - tool_result: Ok(vec![McpContent::Text(text_content)]), + tool_result: Ok(vec![Content::text(text_content.raw.text)]), }; let result = tool_response_to_markdown(&tool_response, true); @@ -639,12 +645,14 @@ if __name__ == "__main__": hello_world()"#; let text_content = TextContent { - text: python_code.to_string(), + raw: RawTextContent { + text: python_code.to_string(), + }, annotations: None, }; let tool_response = ToolResponse { id: "shell-cat".to_string(), - tool_result: Ok(vec![McpContent::Text(text_content)]), + tool_result: Ok(vec![Content::text(text_content.raw.text)]), }; let request_result = tool_request_to_markdown(&tool_request, true); @@ -676,12 +684,14 @@ if __name__ == "__main__": let git_output = " M src/main.rs\n?? temp.txt\n A new_feature.rs"; let text_content = TextContent { - text: git_output.to_string(), + raw: RawTextContent { + text: git_output.to_string(), + }, annotations: None, }; let tool_response = ToolResponse { id: "git-status".to_string(), - tool_result: Ok(vec![McpContent::Text(text_content)]), + tool_result: Ok(vec![Content::text(text_content.raw.text)]), }; let request_result = tool_request_to_markdown(&tool_request, true); @@ -721,12 +731,14 @@ warning: unused variable `x` Finished dev [unoptimized + debuginfo] target(s) in 2.45s"#; let text_content = TextContent { - text: build_output.to_string(), + raw: RawTextContent { + text: build_output.to_string(), + }, annotations: None, }; let tool_response = ToolResponse { id: "cargo-build".to_string(), - tool_result: Ok(vec![McpContent::Text(text_content)]), + tool_result: Ok(vec![Content::text(text_content.raw.text)]), }; let response_result = tool_response_to_markdown(&tool_response, true); @@ -764,12 +776,14 @@ warning: unused variable `x` }"#; let text_content = TextContent { - text: api_response.to_string(), + raw: RawTextContent { + text: api_response.to_string(), + }, annotations: None, }; let tool_response = ToolResponse { id: "curl-api".to_string(), - tool_result: Ok(vec![McpContent::Text(text_content)]), + tool_result: Ok(vec![Content::text(text_content.raw.text)]), }; let response_result = tool_response_to_markdown(&tool_response, true); @@ -796,12 +810,14 @@ warning: unused variable `x` }; let text_content = TextContent { - text: "File created successfully".to_string(), + raw: RawTextContent { + text: "File created successfully".to_string(), + }, annotations: None, }; let tool_response = ToolResponse { id: "editor-write".to_string(), - tool_result: Ok(vec![McpContent::Text(text_content)]), + tool_result: Ok(vec![Content::text(text_content.raw.text)]), }; let request_result = tool_request_to_markdown(&tool_request, true); @@ -849,12 +865,14 @@ def process_data(data: List[Dict]) -> List[Dict]: return [item for item in data if item.get('active', False)]"#; let text_content = TextContent { - text: python_code.to_string(), + raw: RawTextContent { + text: python_code.to_string(), + }, annotations: None, }; let tool_response = ToolResponse { id: "editor-view".to_string(), - tool_result: Ok(vec![McpContent::Text(text_content)]), + tool_result: Ok(vec![Content::text(text_content.raw.text)]), }; let response_result = tool_response_to_markdown(&tool_response, true); @@ -882,12 +900,14 @@ def process_data(data: List[Dict]) -> List[Dict]: Command failed with exit code 2"#; let text_content = TextContent { - text: error_output.to_string(), + raw: RawTextContent { + text: error_output.to_string(), + }, annotations: None, }; let tool_response = ToolResponse { id: "shell-error".to_string(), - tool_result: Ok(vec![McpContent::Text(text_content)]), + tool_result: Ok(vec![Content::text(text_content.raw.text)]), }; let response_result = tool_response_to_markdown(&tool_response, true); @@ -918,12 +938,14 @@ Command failed with exit code 2"#; 5^2 = 25"#; let text_content = TextContent { - text: script_output.to_string(), + raw: RawTextContent { + text: script_output.to_string(), + }, annotations: None, }; let tool_response = ToolResponse { id: "script-exec".to_string(), - tool_result: Ok(vec![McpContent::Text(text_content)]), + tool_result: Ok(vec![Content::text(text_content.raw.text)]), }; let request_result = tool_request_to_markdown(&tool_request, true); @@ -961,12 +983,14 @@ drwx------ 3 user staff 96 Dec 6 16:20 com.apple.launchd.abc /tmp"#; let text_content = TextContent { - text: multi_output.to_string(), + raw: RawTextContent { + text: multi_output.to_string(), + }, annotations: None, }; let tool_response = ToolResponse { id: "multi-cmd".to_string(), - tool_result: Ok(vec![McpContent::Text(text_content)]), + tool_result: Ok(vec![Content::text(text_content.raw.text)]), }; let request_result = tool_request_to_markdown(&_tool_request, true); @@ -1000,12 +1024,14 @@ src/database.rs:23:async fn query_users(pool: &Pool) -> Result> { src/middleware.rs:12:async fn auth_middleware(req: Request, next: Next) -> Result {"#; let text_content = TextContent { - text: grep_output.to_string(), + raw: RawTextContent { + text: grep_output.to_string(), + }, annotations: None, }; let tool_response = ToolResponse { id: "grep-search".to_string(), - tool_result: Ok(vec![McpContent::Text(text_content)]), + tool_result: Ok(vec![Content::text(text_content.raw.text)]), }; let request_result = tool_request_to_markdown(&tool_request, true); @@ -1036,12 +1062,14 @@ src/middleware.rs:12:async fn auth_middleware(req: Request, next: Next) -> Resul let json_output = r#"{"status": "success", "data": {"count": 42}}"#; let text_content = TextContent { - text: json_output.to_string(), + raw: RawTextContent { + text: json_output.to_string(), + }, annotations: None, }; let tool_response = ToolResponse { id: "json-test".to_string(), - tool_result: Ok(vec![McpContent::Text(text_content)]), + tool_result: Ok(vec![Content::text(text_content.raw.text)]), }; let response_result = tool_response_to_markdown(&tool_response, true); @@ -1073,12 +1101,14 @@ src/middleware.rs:12:async fn auth_middleware(req: Request, next: Next) -> Resul found 0 vulnerabilities"#; let text_content = TextContent { - text: npm_output.to_string(), + raw: RawTextContent { + text: npm_output.to_string(), + }, annotations: None, }; let tool_response = ToolResponse { id: "npm-install".to_string(), - tool_result: Ok(vec![McpContent::Text(text_content)]), + tool_result: Ok(vec![Content::text(text_content.raw.text)]), }; let request_result = tool_request_to_markdown(&tool_request, true); diff --git a/crates/goose-cli/src/session/mod.rs b/crates/goose-cli/src/session/mod.rs index 231a6cfa8143..a260e8f16774 100644 --- a/crates/goose-cli/src/session/mod.rs +++ b/crates/goose-cli/src/session/mod.rs @@ -4,8 +4,14 @@ mod export; mod input; mod output; mod prompt; +mod task_execution_display; mod thinking; +use crate::session::task_execution_display::{ + format_task_execution_notification, TASK_EXECUTION_NOTIFICATION_TYPE, +}; +use std::io::Write; + pub use self::export::message_to_markdown; pub use builder::{build_session, SessionBuilderConfig, SessionSettings}; use console::Color; @@ -16,6 +22,7 @@ use goose::permission::Permission; use goose::permission::PermissionConfirmation; use goose::providers::base::Provider; pub use goose::session::Identifier; +use goose::utils::safe_truncate; use anyhow::{Context, Result}; use completion::GooseCompleter; @@ -28,9 +35,9 @@ use goose::providers::pricing::initialize_pricing_cache; use goose::session; use input::InputResult; use mcp_core::handler::ToolError; -use mcp_core::prompt::PromptMessage; use mcp_core::protocol::JsonRpcMessage; use mcp_core::protocol::JsonRpcNotification; +use rmcp::model::PromptMessage; use rand::{distributions::Alphanumeric, Rng}; use rustyline::EditMode; @@ -352,7 +359,33 @@ impl Session { pub async fn get_prompt(&mut self, name: &str, arguments: Value) -> Result> { let result = self.agent.get_prompt(name, arguments).await?; - Ok(result.messages) + // Convert mcp_core::prompt::PromptMessage to rmcp::model::PromptMessage + let converted_messages = result + .messages + .into_iter() + .map(|msg| rmcp::model::PromptMessage { + role: match msg.role { + mcp_core::prompt::PromptMessageRole::User => { + rmcp::model::PromptMessageRole::User + } + mcp_core::prompt::PromptMessageRole::Assistant => { + rmcp::model::PromptMessageRole::Assistant + } + }, + content: match msg.content { + mcp_core::prompt::PromptMessageContent::Text { text } => { + rmcp::model::PromptMessageContent::Text { text } + } + mcp_core::prompt::PromptMessageContent::Image { image } => { + rmcp::model::PromptMessageContent::Image { image } + } + mcp_core::prompt::PromptMessageContent::Resource { resource } => { + rmcp::model::PromptMessageContent::Resource { resource } + } + }, + }) + .collect(); + Ok(converted_messages) } /// Process a single message and get the response @@ -363,11 +396,16 @@ impl Session { // Persist messages with provider for automatic description generation if let Some(session_file) = &self.session_file { + let working_dir = Some( + std::env::current_dir().expect("failed to get current session working directory"), + ); + session::persist_messages_with_schedule_id( session_file, &self.messages, Some(provider), self.scheduled_job_id.clone(), + working_dir, ) .await?; } @@ -485,11 +523,17 @@ impl Session { // Persist messages with provider for automatic description generation if let Some(session_file) = &self.session_file { + let working_dir = Some( + std::env::current_dir() + .expect("failed to get current session working directory"), + ); + session::persist_messages_with_schedule_id( session_file, &self.messages, Some(provider), self.scheduled_job_id.clone(), + working_dir, ) .await?; } @@ -626,6 +670,10 @@ impl Session { &Message::assistant().with_text("Chat context cleared."), self.debug, ); + if let Some(file) = self.session_file.as_ref().filter(|f| f.exists()) { + std::fs::remove_file(file)?; + std::fs::File::create(file)?; + } continue; } input::InputResult::PromptCommand(opts) => { @@ -697,11 +745,13 @@ impl Session { // Persist the summarized messages if let Some(session_file) = &self.session_file { + let working_dir = std::env::current_dir().ok(); session::persist_messages_with_schedule_id( session_file, &self.messages, Some(provider), self.scheduled_job_id.clone(), + working_dir, ) .await?; } @@ -881,11 +931,13 @@ impl Session { )); push_message(&mut self.messages, response_message); if let Some(session_file) = &self.session_file { + let working_dir = std::env::current_dir().ok(); session::persist_messages_with_schedule_id( session_file, &self.messages, None, self.scheduled_job_id.clone(), + working_dir, ) .await?; } @@ -980,11 +1032,13 @@ impl Session { // No need to update description on assistant messages if let Some(session_file) = &self.session_file { + let working_dir = std::env::current_dir().ok(); session::persist_messages_with_schedule_id( session_file, &self.messages, None, self.scheduled_job_id.clone(), + working_dir, ) .await?; } @@ -1003,7 +1057,7 @@ impl Session { match method.as_str() { "notifications/message" => { let data = o.get("data").unwrap_or(&Value::Null); - let (formatted_message, subagent_id, _notification_type) = match data { + let (formatted_message, subagent_id, message_notification_type) = match data { Value::String(s) => (s.clone(), None, None), Value::Object(o) => { // Check for subagent notification structure first @@ -1037,11 +1091,7 @@ impl Session { if min_priority > 0.1 && !self.debug { // High/Medium verbosity: show truncated response if let Some(response_content) = msg.strip_prefix("Responded: ") { - if response_content.len() > 100 { - format!("🤖 Responded: {}...", &response_content[..100]) - } else { - format!("🤖 {}", msg) - } + format!("🤖 Responded: {}", safe_truncate(response_content, 100)) } else { format!("🤖 {}", msg) } @@ -1058,6 +1108,8 @@ impl Session { } else if let Some(Value::String(output)) = o.get("output") { // Fallback for other MCP notification types (output.to_owned(), None, None) + } else if let Some(result) = format_task_execution_notification(data) { + result } else { (data.to_string(), None, None) } @@ -1069,14 +1121,26 @@ impl Session { // Handle subagent notifications - show immediately if let Some(_id) = subagent_id { - // Show subagent notifications immediately (no buffering) with compact spacing + // TODO: proper display for subagent notifications if interactive { let _ = progress_bars.hide(); println!("{}", console::style(&formatted_message).green().dim()); } else { progress_bars.log(&formatted_message); } - } else { + } else if let Some(ref notification_type) = message_notification_type { + if notification_type == TASK_EXECUTION_NOTIFICATION_TYPE { + if interactive { + let _ = progress_bars.hide(); + print!("{}", formatted_message); + std::io::stdout().flush().unwrap(); + } else { + print!("{}", formatted_message); + std::io::stdout().flush().unwrap(); + } + } + } + else { // Non-subagent notification, display immediately with compact spacing if interactive { let _ = progress_bars.hide(); @@ -1149,7 +1213,7 @@ impl Session { let tool_requests = self .messages .last() - .filter(|msg| msg.role == mcp_core::role::Role::Assistant) + .filter(|msg| msg.role == rmcp::model::Role::Assistant) .map_or(Vec::new(), |msg| { msg.content .iter() @@ -1187,11 +1251,13 @@ impl Session { // No need for description update here if let Some(session_file) = &self.session_file { + let working_dir = std::env::current_dir().ok(); session::persist_messages_with_schedule_id( session_file, &self.messages, None, self.scheduled_job_id.clone(), + working_dir, ) .await?; } @@ -1204,11 +1270,13 @@ impl Session { // No need for description update here if let Some(session_file) = &self.session_file { + let working_dir = std::env::current_dir().ok(); session::persist_messages_with_schedule_id( session_file, &self.messages, None, self.scheduled_job_id.clone(), + working_dir, ) .await?; } @@ -1217,7 +1285,7 @@ impl Session { } else { // An interruption occurred outside of a tool request-response. if let Some(last_msg) = self.messages.last() { - if last_msg.role == mcp_core::role::Role::User { + if last_msg.role == rmcp::model::Role::User { match last_msg.content.first() { Some(MessageContent::ToolResponse(_)) => { // Interruption occurred after a tool had completed but not assistant reply @@ -1226,11 +1294,13 @@ impl Session { // No need for description update here if let Some(session_file) = &self.session_file { + let working_dir = std::env::current_dir().ok(); session::persist_messages_with_schedule_id( session_file, &self.messages, None, self.scheduled_job_id.clone(), + working_dir, ) .await?; } @@ -1420,9 +1490,9 @@ impl Session { let msg = Message::from(prompt_message); // ensure we get a User - Assistant - User type pattern let expected_role = if i % 2 == 0 { - mcp_core::Role::User + rmcp::model::Role::User } else { - mcp_core::Role::Assistant + rmcp::model::Role::Assistant }; if msg.role != expected_role { @@ -1436,7 +1506,7 @@ impl Session { break; } - if msg.role == mcp_core::Role::User { + if msg.role == rmcp::model::Role::User { output::render_message(&msg, self.debug); } self.push_message(msg); diff --git a/crates/goose-cli/src/session/output.rs b/crates/goose-cli/src/session/output.rs index 06435bd32be3..673cb8c500cc 100644 --- a/crates/goose-cli/src/session/output.rs +++ b/crates/goose-cli/src/session/output.rs @@ -232,7 +232,7 @@ fn render_tool_response(resp: &ToolResponse, theme: Theme, debug: bool) { Ok(contents) => { for content in contents { if let Some(audience) = content.audience() { - if !audience.contains(&mcp_core::role::Role::User) { + if !audience.contains(&rmcp::model::Role::User) { continue; } } @@ -252,7 +252,7 @@ fn render_tool_response(resp: &ToolResponse, theme: Theme, debug: bool) { if debug { println!("{:#?}", content); - } else if let mcp_core::content::Content::Text(text) = content { + } else if let Some(text) = content.as_text() { print_markdown(&text.text, theme); } } @@ -463,8 +463,26 @@ fn print_params(value: &Value, depth: usize, debug: bool) { } } Value::String(s) => { - if !debug && s.len() > get_tool_params_max_length() { - println!("{}{}: {}", indent, style(key).dim(), style("...").dim()); + // Special handling for text_instruction to show more content + let max_length = if key == "text_instruction" { + 200 // Allow longer display for text instructions + } else { + get_tool_params_max_length() + }; + + if !debug && s.len() > max_length { + // For text instructions, show a preview instead of just "..." + if key == "text_instruction" { + let preview = &s[..max_length.saturating_sub(3)]; + println!( + "{}{}: {}", + indent, + style(key).dim(), + style(format!("{}...", preview)).green() + ); + } else { + println!("{}{}: {}", indent, style(key).dim(), style("...").dim()); + } } else { println!("{}{}: {}", indent, style(key).dim(), style(s).green()); } diff --git a/crates/goose-cli/src/session/task_execution_display/mod.rs b/crates/goose-cli/src/session/task_execution_display/mod.rs new file mode 100644 index 000000000000..b0b208ed546e --- /dev/null +++ b/crates/goose-cli/src/session/task_execution_display/mod.rs @@ -0,0 +1,247 @@ +use goose::agents::subagent_execution_tool::lib::TaskStatus; +use goose::agents::subagent_execution_tool::notification_events::{ + TaskExecutionNotificationEvent, TaskInfo, +}; +use serde_json::Value; +use std::sync::atomic::{AtomicBool, Ordering}; + +#[cfg(test)] +mod tests; + +const CLEAR_SCREEN: &str = "\x1b[2J\x1b[H"; +const MOVE_TO_PROGRESS_LINE: &str = "\x1b[4;1H"; +const CLEAR_TO_EOL: &str = "\x1b[K"; +const CLEAR_BELOW: &str = "\x1b[J"; +pub const TASK_EXECUTION_NOTIFICATION_TYPE: &str = "task_execution"; + +static INITIAL_SHOWN: AtomicBool = AtomicBool::new(false); + +fn format_result_data_for_display(result_data: &Value) -> String { + match result_data { + Value::String(s) => strip_ansi_codes(s), + Value::Object(obj) => { + if let Some(partial_output) = obj.get("partial_output").and_then(|v| v.as_str()) { + format!("Partial output: {}", partial_output) + } else { + serde_json::to_string_pretty(obj).unwrap_or_default() + } + } + Value::Array(arr) => serde_json::to_string_pretty(arr).unwrap_or_default(), + Value::Bool(b) => b.to_string(), + Value::Number(n) => n.to_string(), + Value::Null => "null".to_string(), + } +} + +fn process_output_for_display(output: &str) -> String { + const MAX_OUTPUT_LINES: usize = 2; + const OUTPUT_PREVIEW_LENGTH: usize = 100; + + let lines: Vec<&str> = output.lines().collect(); + let recent_lines = if lines.len() > MAX_OUTPUT_LINES { + &lines[lines.len() - MAX_OUTPUT_LINES..] + } else { + &lines + }; + + let clean_output = recent_lines.join(" ... "); + let stripped = strip_ansi_codes(&clean_output); + truncate_with_ellipsis(&stripped, OUTPUT_PREVIEW_LENGTH) +} + +fn truncate_with_ellipsis(text: &str, max_len: usize) -> String { + if text.len() > max_len { + let mut end = max_len.saturating_sub(3); + while end > 0 && !text.is_char_boundary(end) { + end -= 1; + } + format!("{}...", &text[..end]) + } else { + text.to_string() + } +} + +fn strip_ansi_codes(text: &str) -> String { + let mut result = String::new(); + let mut chars = text.chars(); + + while let Some(ch) = chars.next() { + if ch == '\x1b' { + if let Some(next_ch) = chars.next() { + if next_ch == '[' { + // This is an ANSI escape sequence, consume until alphabetic character + loop { + match chars.next() { + Some(c) if c.is_ascii_alphabetic() => break, + Some(_) => continue, + None => break, + } + } + } else { + // Not an ANSI sequence, keep both characters + result.push(ch); + result.push(next_ch); + } + } else { + // End of string after \x1b + result.push(ch); + } + } else { + result.push(ch); + } + } + + result +} + +pub fn format_task_execution_notification( + data: &Value, +) -> Option<(String, Option, Option)> { + if let Ok(event) = serde_json::from_value::(data.clone()) { + return Some(match event { + TaskExecutionNotificationEvent::LineOutput { output, .. } => ( + format!("{}\n", output), + None, + Some(TASK_EXECUTION_NOTIFICATION_TYPE.to_string()), + ), + TaskExecutionNotificationEvent::TasksUpdate { .. } => { + let formatted_display = format_tasks_update_from_event(&event); + ( + formatted_display, + None, + Some(TASK_EXECUTION_NOTIFICATION_TYPE.to_string()), + ) + } + TaskExecutionNotificationEvent::TasksComplete { .. } => { + let formatted_summary = format_tasks_complete_from_event(&event); + ( + formatted_summary, + None, + Some(TASK_EXECUTION_NOTIFICATION_TYPE.to_string()), + ) + } + }); + } + None +} + +fn format_tasks_update_from_event(event: &TaskExecutionNotificationEvent) -> String { + if let TaskExecutionNotificationEvent::TasksUpdate { stats, tasks } = event { + let mut display = String::new(); + + if !INITIAL_SHOWN.swap(true, Ordering::SeqCst) { + display.push_str(CLEAR_SCREEN); + display.push_str("🎯 Task Execution Dashboard\n"); + display.push_str("═══════════════════════════\n\n"); + } else { + display.push_str(MOVE_TO_PROGRESS_LINE); + } + + display.push_str(&format!( + "📊 Progress: {} total | ⏳ {} pending | 🏃 {} running | ✅ {} completed | ❌ {} failed", + stats.total, stats.pending, stats.running, stats.completed, stats.failed + )); + display.push_str(&format!("{}\n\n", CLEAR_TO_EOL)); + + let mut sorted_tasks = tasks.clone(); + sorted_tasks.sort_by(|a, b| a.id.cmp(&b.id)); + + for task in sorted_tasks { + display.push_str(&format_task_display(&task)); + } + + display.push_str(CLEAR_BELOW); + display + } else { + String::new() + } +} + +fn format_tasks_complete_from_event(event: &TaskExecutionNotificationEvent) -> String { + if let TaskExecutionNotificationEvent::TasksComplete { + stats, + failed_tasks, + } = event + { + let mut summary = String::new(); + summary.push_str("Execution Complete!\n"); + summary.push_str("═══════════════════════\n"); + + summary.push_str(&format!("Total Tasks: {}\n", stats.total)); + summary.push_str(&format!("✅ Completed: {}\n", stats.completed)); + summary.push_str(&format!("❌ Failed: {}\n", stats.failed)); + summary.push_str(&format!("📈 Success Rate: {:.1}%\n", stats.success_rate)); + + if !failed_tasks.is_empty() { + summary.push_str("\n❌ Failed Tasks:\n"); + for task in failed_tasks { + summary.push_str(&format!(" • {}\n", task.name)); + if let Some(error) = &task.error { + summary.push_str(&format!(" Error: {}\n", error)); + } + } + } + + summary.push_str("\n📝 Generating summary...\n"); + summary + } else { + String::new() + } +} + +fn format_task_display(task: &TaskInfo) -> String { + let mut task_display = String::new(); + + let status_icon = match task.status { + TaskStatus::Pending => "⏳", + TaskStatus::Running => "🏃", + TaskStatus::Completed => "✅", + TaskStatus::Failed => "❌", + }; + + task_display.push_str(&format!( + "{} {} ({}){}\n", + status_icon, task.task_name, task.task_type, CLEAR_TO_EOL + )); + + if !task.task_metadata.is_empty() { + task_display.push_str(&format!( + " 📋 Parameters: {}{}\n", + task.task_metadata, CLEAR_TO_EOL + )); + } + + if let Some(duration_secs) = task.duration_secs { + task_display.push_str(&format!(" ⏱️ {:.1}s{}\n", duration_secs, CLEAR_TO_EOL)); + } + + if matches!(task.status, TaskStatus::Running) && !task.current_output.trim().is_empty() { + let processed_output = process_output_for_display(&task.current_output); + if !processed_output.is_empty() { + task_display.push_str(&format!(" 💬 {}{}\n", processed_output, CLEAR_TO_EOL)); + } + } + + if matches!(task.status, TaskStatus::Completed) { + if let Some(result_data) = &task.result_data { + let result_preview = format_result_data_for_display(result_data); + if !result_preview.is_empty() { + task_display.push_str(&format!(" 📄 {}{}\n", result_preview, CLEAR_TO_EOL)); + } + } + } + + if matches!(task.status, TaskStatus::Failed) { + if let Some(error) = &task.error { + let error_preview = truncate_with_ellipsis(error, 80); + task_display.push_str(&format!( + " ⚠️ {}{}\n", + error_preview.replace('\n', " "), + CLEAR_TO_EOL + )); + } + } + + task_display.push_str(&format!("{}\n", CLEAR_TO_EOL)); + task_display +} diff --git a/crates/goose-cli/src/session/task_execution_display/tests.rs b/crates/goose-cli/src/session/task_execution_display/tests.rs new file mode 100644 index 000000000000..725d161dff5b --- /dev/null +++ b/crates/goose-cli/src/session/task_execution_display/tests.rs @@ -0,0 +1,337 @@ +use super::*; +use goose::agents::subagent_execution_tool::notification_events::{ + FailedTaskInfo, TaskCompletionStats, TaskExecutionStats, +}; +use serde_json::json; + +#[test] +fn test_strip_ansi_codes() { + assert_eq!(strip_ansi_codes("hello world"), "hello world"); + assert_eq!(strip_ansi_codes("\x1b[31mred text\x1b[0m"), "red text"); + assert_eq!( + strip_ansi_codes("\x1b[1;32mbold green\x1b[0m"), + "bold green" + ); + assert_eq!( + strip_ansi_codes("normal\x1b[33myellow\x1b[0mnormal"), + "normalyellownormal" + ); + assert_eq!(strip_ansi_codes("\x1bhello"), "\x1bhello"); + assert_eq!(strip_ansi_codes("hello\x1b"), "hello\x1b"); + assert_eq!(strip_ansi_codes(""), ""); +} + +#[test] +fn test_truncate_with_ellipsis() { + assert_eq!(truncate_with_ellipsis("hello", 10), "hello"); + assert_eq!(truncate_with_ellipsis("hello", 5), "hello"); + assert_eq!(truncate_with_ellipsis("hello world", 8), "hello..."); + assert_eq!(truncate_with_ellipsis("hello", 3), "..."); + assert_eq!(truncate_with_ellipsis("hello", 2), "..."); + assert_eq!(truncate_with_ellipsis("hello", 1), "..."); + assert_eq!(truncate_with_ellipsis("", 5), ""); +} + +#[test] +fn test_process_output_for_display() { + assert_eq!(process_output_for_display("hello world"), "hello world"); + assert_eq!( + process_output_for_display("line1\nline2"), + "line1 ... line2" + ); + + let input = "line1\nline2\nline3\nline4"; + let result = process_output_for_display(input); + assert_eq!(result, "line3 ... line4"); + + let long_line = "a".repeat(150); + let result = process_output_for_display(&long_line); + assert!(result.len() <= 100); + assert!(result.ends_with("...")); + + let ansi_output = "\x1b[31mred line 1\x1b[0m\n\x1b[32mgreen line 2\x1b[0m"; + let result = process_output_for_display(ansi_output); + assert_eq!(result, "red line 1 ... green line 2"); + + assert_eq!(process_output_for_display(""), ""); +} + +#[test] +fn test_format_result_data_for_display() { + let string_val = json!("hello world"); + assert_eq!(format_result_data_for_display(&string_val), "hello world"); + + let ansi_string = json!("\x1b[31mred text\x1b[0m"); + assert_eq!(format_result_data_for_display(&ansi_string), "red text"); + + assert_eq!(format_result_data_for_display(&json!(true)), "true"); + assert_eq!(format_result_data_for_display(&json!(false)), "false"); + assert_eq!(format_result_data_for_display(&json!(42)), "42"); + assert_eq!(format_result_data_for_display(&json!(3.14)), "3.14"); + assert_eq!(format_result_data_for_display(&json!(null)), "null"); + + let partial_obj = json!({ + "partial_output": "some output", + "other_field": "ignored" + }); + assert_eq!( + format_result_data_for_display(&partial_obj), + "Partial output: some output" + ); + + let obj = json!({"key": "value", "num": 42}); + let result = format_result_data_for_display(&obj); + assert!(result.contains("key")); + assert!(result.contains("value")); + + let arr = json!([1, 2, 3]); + let result = format_result_data_for_display(&arr); + assert!(result.contains("1")); + assert!(result.contains("2")); + assert!(result.contains("3")); +} + +#[test] +fn test_format_task_execution_notification_line_output() { + let _event = TaskExecutionNotificationEvent::LineOutput { + task_id: "task-1".to_string(), + output: "Hello World".to_string(), + }; + + let data = json!({ + "subtype": "line_output", + "task_id": "task-1", + "output": "Hello World" + }); + + let result = format_task_execution_notification(&data); + assert!(result.is_some()); + + let (formatted, second, third) = result.unwrap(); + assert_eq!(formatted, "Hello World\n"); + assert_eq!(second, None); + assert_eq!(third, Some("task_execution".to_string())); +} + +#[test] +fn test_format_task_execution_notification_invalid_data() { + let invalid_data = json!({ + "invalid": "structure" + }); + + let result = format_task_execution_notification(&invalid_data); + assert_eq!(result, None); + + let incomplete_data = json!({ + "subtype": "line_output" + }); + + let result = format_task_execution_notification(&incomplete_data); + assert_eq!(result, None); +} + +#[test] +fn test_format_tasks_update_from_event() { + INITIAL_SHOWN.store(false, Ordering::SeqCst); + + let stats = TaskExecutionStats::new(3, 1, 1, 1, 0); + let tasks = vec![ + TaskInfo { + id: "task-1".to_string(), + status: TaskStatus::Running, + duration_secs: Some(1.5), + current_output: "Processing...".to_string(), + task_type: "sub_recipe".to_string(), + task_name: "test-task".to_string(), + task_metadata: "param=value".to_string(), + error: None, + result_data: None, + }, + TaskInfo { + id: "task-2".to_string(), + status: TaskStatus::Completed, + duration_secs: Some(2.3), + current_output: "".to_string(), + task_type: "text_instruction".to_string(), + task_name: "another-task".to_string(), + task_metadata: "".to_string(), + error: None, + result_data: Some(json!({"result": "success"})), + }, + ]; + + let event = TaskExecutionNotificationEvent::TasksUpdate { stats, tasks }; + let result = format_tasks_update_from_event(&event); + + assert!(result.contains("🎯 Task Execution Dashboard")); + assert!(result.contains("═══════════════════════════")); + assert!(result.contains("📊 Progress: 3 total")); + assert!(result.contains("⏳ 1 pending")); + assert!(result.contains("🏃 1 running")); + assert!(result.contains("✅ 1 completed")); + assert!(result.contains("❌ 0 failed")); + assert!(result.contains("🏃 test-task")); + assert!(result.contains("✅ another-task")); + assert!(result.contains("📋 Parameters: param=value")); + assert!(result.contains("⏱️ 1.5s")); + assert!(result.contains("💬 Processing...")); + + let result2 = format_tasks_update_from_event(&event); + assert!(!result2.contains("🎯 Task Execution Dashboard")); + assert!(result2.contains(MOVE_TO_PROGRESS_LINE)); +} + +#[test] +fn test_format_tasks_complete_from_event() { + let stats = TaskCompletionStats::new(5, 4, 1); + let failed_tasks = vec![FailedTaskInfo { + id: "task-3".to_string(), + name: "failed-task".to_string(), + error: Some("Connection timeout".to_string()), + }]; + + let event = TaskExecutionNotificationEvent::TasksComplete { + stats, + failed_tasks, + }; + let result = format_tasks_complete_from_event(&event); + + assert!(result.contains("Execution Complete!")); + assert!(result.contains("═══════════════════════")); + assert!(result.contains("Total Tasks: 5")); + assert!(result.contains("✅ Completed: 4")); + assert!(result.contains("❌ Failed: 1")); + assert!(result.contains("📈 Success Rate: 80.0%")); + assert!(result.contains("❌ Failed Tasks:")); + assert!(result.contains("• failed-task")); + assert!(result.contains("Error: Connection timeout")); + assert!(result.contains("📝 Generating summary...")); +} + +#[test] +fn test_format_tasks_complete_from_event_no_failures() { + let stats = TaskCompletionStats::new(3, 3, 0); + let failed_tasks = vec![]; + + let event = TaskExecutionNotificationEvent::TasksComplete { + stats, + failed_tasks, + }; + let result = format_tasks_complete_from_event(&event); + + assert!(!result.contains("❌ Failed Tasks:")); + assert!(result.contains("📈 Success Rate: 100.0%")); + assert!(result.contains("❌ Failed: 0")); +} + +#[test] +fn test_format_task_display_running() { + let task = TaskInfo { + id: "task-1".to_string(), + status: TaskStatus::Running, + duration_secs: Some(1.5), + current_output: "Processing data...\nAlmost done...".to_string(), + task_type: "sub_recipe".to_string(), + task_name: "data-processor".to_string(), + task_metadata: "input=file.txt,output=result.json".to_string(), + error: None, + result_data: None, + }; + + let result = format_task_display(&task); + + assert!(result.contains("🏃 data-processor (sub_recipe)")); + assert!(result.contains("📋 Parameters: input=file.txt,output=result.json")); + assert!(result.contains("⏱️ 1.5s")); + assert!(result.contains("💬 Processing data... ... Almost done...")); +} + +#[test] +fn test_format_task_display_completed() { + let task = TaskInfo { + id: "task-2".to_string(), + status: TaskStatus::Completed, + duration_secs: Some(3.2), + current_output: "".to_string(), + task_type: "text_instruction".to_string(), + task_name: "analyzer".to_string(), + task_metadata: "".to_string(), + error: None, + result_data: Some(json!({"status": "success", "count": 42})), + }; + + let result = format_task_display(&task); + + assert!(result.contains("✅ analyzer (text_instruction)")); + assert!(result.contains("⏱️ 3.2s")); + assert!(!result.contains("📋 Parameters")); + assert!(result.contains("📄")); +} + +#[test] +fn test_format_task_display_failed() { + let task = TaskInfo { + id: "task-3".to_string(), + status: TaskStatus::Failed, + duration_secs: None, + current_output: "".to_string(), + task_type: "sub_recipe".to_string(), + task_name: "failing-task".to_string(), + task_metadata: "".to_string(), + error: Some( + "Network connection failed after multiple retries. The server is unreachable." + .to_string(), + ), + result_data: None, + }; + + let result = format_task_display(&task); + + assert!(result.contains("❌ failing-task (sub_recipe)")); + assert!(!result.contains("⏱️")); + assert!(result.contains("⚠️")); + assert!(result.contains("Network connection failed after multiple retries")); +} + +#[test] +fn test_format_task_display_pending() { + let task = TaskInfo { + id: "task-4".to_string(), + status: TaskStatus::Pending, + duration_secs: None, + current_output: "".to_string(), + task_type: "sub_recipe".to_string(), + task_name: "waiting-task".to_string(), + task_metadata: "priority=high".to_string(), + error: None, + result_data: None, + }; + + let result = format_task_display(&task); + + assert!(result.contains("⏳ waiting-task (sub_recipe)")); + assert!(result.contains("📋 Parameters: priority=high")); + assert!(!result.contains("⏱️")); + assert!(!result.contains("💬")); + assert!(!result.contains("📄")); + assert!(!result.contains("⚠️")); +} + +#[test] +fn test_format_task_display_empty_current_output() { + let task = TaskInfo { + id: "task-5".to_string(), + status: TaskStatus::Running, + duration_secs: Some(0.5), + current_output: " \n\t \n ".to_string(), + task_type: "sub_recipe".to_string(), + task_name: "quiet-task".to_string(), + task_metadata: "".to_string(), + error: None, + result_data: None, + }; + + let result = format_task_display(&task); + + assert!(!result.contains("💬")); +} diff --git a/crates/goose-llm/Cargo.toml b/crates/goose-llm/Cargo.toml index 17723e31aac4..9f3dd9ede88b 100644 --- a/crates/goose-llm/Cargo.toml +++ b/crates/goose-llm/Cargo.toml @@ -15,6 +15,7 @@ crate-type = ["lib", "cdylib"] name = "goose_llm" [dependencies] +goose = { path = "../goose" } serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" anyhow = "1.0" @@ -63,6 +64,10 @@ path = "uniffi-bindgen.rs" name = "simple" path = "examples/simple.rs" +[[example]] +name = "image" +path = "examples/image.rs" + [[example]] name = "prompt_override" path = "examples/prompt_override.rs" diff --git a/crates/goose-llm/examples/image.rs b/crates/goose-llm/examples/image.rs new file mode 100644 index 000000000000..7c607713e9cf --- /dev/null +++ b/crates/goose-llm/examples/image.rs @@ -0,0 +1,53 @@ +use anyhow::Result; +use base64::{engine::general_purpose::STANDARD as BASE64, Engine as _}; +use goose_llm::{ + completion, + message::MessageContent, + types::completion::{CompletionRequest, CompletionResponse}, + Message, ModelConfig, +}; +use serde_json::json; +use std::{fs, vec}; + +#[tokio::main] +async fn main() -> Result<()> { + let provider = "databricks"; + let provider_config = json!({ + "host": std::env::var("DATABRICKS_HOST").expect("Missing DATABRICKS_HOST"), + "token": std::env::var("DATABRICKS_TOKEN").expect("Missing DATABRICKS_TOKEN"), + }); + let model_name = "goose-claude-4-sonnet"; // "gpt-4o"; + let model_config = ModelConfig::new(model_name.to_string()); + + let system_preamble = "You are a helpful assistant."; + + // Read and encode test image + let image_data = fs::read("examples/test_assets/test_image.png")?; + let base64_image = BASE64.encode(image_data); + + let user_msg = Message::user() + .with_text("What do you see in this image?") + .with_content(MessageContent::image(base64_image, "image/png")); + + let messages = vec![user_msg]; + + let completion_response: CompletionResponse = completion( + CompletionRequest::new( + provider.to_string(), + provider_config.clone(), + model_config.clone(), + Some(system_preamble.to_string()), + None, + messages, + vec![], + ) + .with_request_id("test-image-1".to_string()), + ) + .await?; + + // Print the response + println!("\nCompletion Response:"); + println!("{}", serde_json::to_string_pretty(&completion_response)?); + + Ok(()) +} diff --git a/crates/goose-llm/examples/simple.rs b/crates/goose-llm/examples/simple.rs index e7d36a7870ed..efab4b0abc57 100644 --- a/crates/goose-llm/examples/simple.rs +++ b/crates/goose-llm/examples/simple.rs @@ -116,7 +116,7 @@ async fn main() -> Result<()> { println!("\nCompletion Response:"); println!("{}", serde_json::to_string_pretty(&completion_response)?); - let tooltip = generate_tooltip(provider, provider_config.clone(), &messages).await?; + let tooltip = generate_tooltip(provider, provider_config.clone(), &messages, None).await?; println!("\nTooltip: {}", tooltip); } diff --git a/crates/goose-llm/examples/test_assets/test_image.png b/crates/goose-llm/examples/test_assets/test_image.png new file mode 100644 index 000000000000..f72b65986d19 Binary files /dev/null and b/crates/goose-llm/examples/test_assets/test_image.png differ diff --git a/crates/goose-llm/src/completion.rs b/crates/goose-llm/src/completion.rs index d39b1b8db830..13f09810b8c4 100644 --- a/crates/goose-llm/src/completion.rs +++ b/crates/goose-llm/src/completion.rs @@ -46,7 +46,12 @@ pub async fn completion(req: CompletionRequest) -> Result String { } /// Generates a short (≤4 words) session name -#[uniffi::export(async_runtime = "tokio")] +#[uniffi::export(async_runtime = "tokio", default(request_id = None))] pub async fn generate_session_name( provider_name: &str, provider_config: JsonValueFfi, messages: &[Message], + request_id: Option, ) -> Result { // Collect up to the first 3 user messages (truncated to 300 chars each) let context: Vec = messages @@ -60,11 +62,7 @@ pub async fn generate_session_name( .take(3) .map(|m| { let text = m.content.concat_text_str(); - if text.len() > 300 { - text.chars().take(300).collect() - } else { - text - } + safe_truncate(&text, 300) }) .collect(); @@ -93,6 +91,7 @@ pub async fn generate_session_name( &system_prompt, &[Message::user().with_text(&user_msg_text)], schema, + request_id, ) .await?; diff --git a/crates/goose-llm/src/extractors/tooltip.rs b/crates/goose-llm/src/extractors/tooltip.rs index 37d83ffe59e1..48336a546ea6 100644 --- a/crates/goose-llm/src/extractors/tooltip.rs +++ b/crates/goose-llm/src/extractors/tooltip.rs @@ -52,11 +52,12 @@ fn build_system_prompt() -> String { /// Generates a tooltip summarizing the last two messages in the session, /// including any tool calls or results. -#[uniffi::export(async_runtime = "tokio")] +#[uniffi::export(async_runtime = "tokio", default(request_id = None))] pub async fn generate_tooltip( provider_name: &str, provider_config: JsonValueFfi, messages: &[Message], + request_id: Option, ) -> Result { // Need at least two messages to generate a tooltip if messages.len() < 2 { @@ -148,6 +149,7 @@ pub async fn generate_tooltip( &system_prompt, &[Message::user().with_text(&user_msg_text)], schema, + request_id, ) .await?; diff --git a/crates/goose-llm/src/providers/base.rs b/crates/goose-llm/src/providers/base.rs index dcfecbd1e7f3..92a3948df28f 100644 --- a/crates/goose-llm/src/providers/base.rs +++ b/crates/goose-llm/src/providers/base.rs @@ -69,6 +69,7 @@ pub trait Provider: Send + Sync { /// * `system` - The system prompt that guides the model's behavior /// * `messages` - The conversation history as a sequence of messages /// * `tools` - Optional list of tools the model can use + /// * `request_id` - Optional request ID (only used by some providers like Databricks) /// /// # Returns /// A tuple containing the model's response message and provider usage statistics @@ -81,6 +82,7 @@ pub trait Provider: Send + Sync { system: &str, messages: &[Message], tools: &[Tool], + request_id: Option<&str>, ) -> Result; /// Structured extraction: always JSON‐Schema @@ -90,6 +92,7 @@ pub trait Provider: Send + Sync { /// * `messages` – conversation history /// * `schema` – a JSON‐Schema for the expected output. /// Will set strict=true for OpenAI & Databricks. + /// * `request_id` - Optional request ID (only used by some providers like Databricks) /// /// # Returns /// A `ProviderExtractResponse` whose `data` is a JSON object matching `schema`. @@ -102,6 +105,7 @@ pub trait Provider: Send + Sync { system: &str, messages: &[Message], schema: &serde_json::Value, + request_id: Option<&str>, ) -> Result; } diff --git a/crates/goose-llm/src/providers/databricks.rs b/crates/goose-llm/src/providers/databricks.rs index 3dd31493c1cd..0bfe2ffef67b 100644 --- a/crates/goose-llm/src/providers/databricks.rs +++ b/crates/goose-llm/src/providers/databricks.rs @@ -210,6 +210,7 @@ impl Provider for DatabricksProvider { system: &str, messages: &[Message], tools: &[Tool], + request_id: Option<&str>, ) -> Result { let mut payload = create_request( &self.model, @@ -224,6 +225,17 @@ impl Provider for DatabricksProvider { .expect("payload should have model key") .remove("model"); + // Add client_request_id if provided + if let Some(req_id) = request_id { + payload + .as_object_mut() + .expect("payload should be an object") + .insert( + "client_request_id".to_string(), + serde_json::Value::String(req_id.to_string()), + ); + } + let response = self.post(payload.clone()).await?; // Parse response @@ -247,6 +259,7 @@ impl Provider for DatabricksProvider { system: &str, messages: &[Message], schema: &Value, + request_id: Option<&str>, ) -> Result { // 1. Build base payload (no tools) let mut payload = create_request(&self.model, system, messages, &[], &ImageFormat::OpenAi)?; @@ -267,6 +280,17 @@ impl Provider for DatabricksProvider { }), ); + // Add client_request_id if provided + if let Some(req_id) = request_id { + payload + .as_object_mut() + .expect("payload should be an object") + .insert( + "client_request_id".to_string(), + serde_json::Value::String(req_id.to_string()), + ); + } + // 3. Call OpenAI let response = self.post(payload.clone()).await?; diff --git a/crates/goose-llm/src/providers/formats/databricks.rs b/crates/goose-llm/src/providers/formats/databricks.rs index d69c31bb9919..37343f2ebe09 100644 --- a/crates/goose-llm/src/providers/formats/databricks.rs +++ b/crates/goose-llm/src/providers/formats/databricks.rs @@ -7,10 +7,7 @@ use crate::{ providers::{ base::Usage, errors::ProviderError, - utils::{ - convert_image, detect_image_path, is_valid_function_name, load_image_file, - sanitize_function_name, ImageFormat, - }, + utils::{convert_image, is_valid_function_name, sanitize_function_name, ImageFormat}, }, types::core::{Content, Role, Tool, ToolCall, ToolError}, }; @@ -34,30 +31,17 @@ pub fn format_messages(messages: &[Message], image_format: &ImageFormat) -> Vec< match content { MessageContent::Text(text) => { if !text.text.is_empty() { - // Check for image paths in the text - if let Some(image_path) = detect_image_path(&text.text) { - has_multiple_content = true; - // Try to load and convert the image - if let Ok(image) = load_image_file(image_path) { - content_array.push(json!({ - "type": "text", - "text": text.text - })); - content_array.push(convert_image(&image, image_format)); - } else { - content_array.push(json!({ - "type": "text", - "text": text.text - })); - } - } else { - content_array.push(json!({ - "type": "text", - "text": text.text - })); - } + content_array.push(json!({ + "type": "text", + "text": text.text + })); } } + MessageContent::Image(image) => { + // Handle direct image content + let converted_image = convert_image(image, image_format); + content_array.push(converted_image); + } MessageContent::Thinking(content) => { has_multiple_content = true; content_array.push(json!({ @@ -166,15 +150,6 @@ pub fn format_messages(messages: &[Message], image_format: &ImageFormat) -> Vec< } } } - MessageContent::Image(image) => { - // Handle direct image content - content_array.push(json!({ - "type": "image_url", - "image_url": { - "url": convert_image(image, image_format) - } - })); - } } } @@ -791,40 +766,6 @@ mod tests { Ok(()) } - #[test] - fn test_format_messages_with_image_path() -> anyhow::Result<()> { - // Create a temporary PNG file with valid PNG magic numbers - let temp_dir = tempfile::tempdir()?; - let png_path = temp_dir.path().join("test.png"); - let png_data = [ - 0x89, 0x50, 0x4E, 0x47, // PNG magic number - 0x0D, 0x0A, 0x1A, 0x0A, // PNG header - 0x00, 0x00, 0x00, 0x0D, // Rest of fake PNG data - ]; - std::fs::write(&png_path, png_data)?; - let png_path_str = png_path.to_str().unwrap(); - - // Create message with image path - let message = Message::user().with_text(format!("Here is an image: {}", png_path_str)); - let spec = format_messages(&[message], &ImageFormat::OpenAi); - - assert_eq!(spec.len(), 1); - assert_eq!(spec[0]["role"], "user"); - - // Content should be an array with text and image - let content = spec[0]["content"].as_array().unwrap(); - assert_eq!(content.len(), 2); - assert_eq!(content[0]["type"], "text"); - assert!(content[0]["text"].as_str().unwrap().contains(png_path_str)); - assert_eq!(content[1]["type"], "image_url"); - assert!(content[1]["image_url"]["url"] - .as_str() - .unwrap() - .starts_with("data:image/png;base64,")); - - Ok(()) - } - #[test] fn test_response_to_message_text() -> anyhow::Result<()> { let response = json!({ diff --git a/crates/goose-llm/src/providers/formats/openai.rs b/crates/goose-llm/src/providers/formats/openai.rs index afc48745cb59..a2eb43b414eb 100644 --- a/crates/goose-llm/src/providers/formats/openai.rs +++ b/crates/goose-llm/src/providers/formats/openai.rs @@ -7,10 +7,7 @@ use crate::{ providers::{ base::Usage, errors::ProviderError, - utils::{ - convert_image, detect_image_path, is_valid_function_name, load_image_file, - sanitize_function_name, ImageFormat, - }, + utils::{convert_image, is_valid_function_name, sanitize_function_name, ImageFormat}, }, types::core::{Content, Role, Tool, ToolCall, ToolError}, }; @@ -31,23 +28,13 @@ pub fn format_messages(messages: &[Message], image_format: &ImageFormat) -> Vec< match content { MessageContent::Text(text) => { if !text.text.is_empty() { - // Check for image paths in the text - if let Some(image_path) = detect_image_path(&text.text) { - // Try to load and convert the image - if let Ok(image) = load_image_file(image_path) { - converted["content"] = json!([ - {"type": "text", "text": text.text}, - convert_image(&image, image_format) - ]); - } else { - // If image loading fails, just use the text - converted["content"] = json!(text.text); - } - } else { - converted["content"] = json!(text.text); - } + converted["content"] = json!(text.text); } } + MessageContent::Image(image) => { + // Handle direct image content + converted["content"] = json!([convert_image(image, image_format)]); + } MessageContent::Thinking(_) => { // Thinking blocks are not directly used in OpenAI format continue; @@ -134,10 +121,6 @@ pub fn format_messages(messages: &[Message], image_format: &ImageFormat) -> Vec< } } } - MessageContent::Image(image) => { - // Handle direct image content - converted["content"] = json!([convert_image(image, image_format)]); - } } } @@ -664,40 +647,6 @@ mod tests { Ok(()) } - #[test] - fn test_format_messages_with_image_path() -> anyhow::Result<()> { - // Create a temporary PNG file with valid PNG magic numbers - let temp_dir = tempfile::tempdir()?; - let png_path = temp_dir.path().join("test.png"); - let png_data = [ - 0x89, 0x50, 0x4E, 0x47, // PNG magic number - 0x0D, 0x0A, 0x1A, 0x0A, // PNG header - 0x00, 0x00, 0x00, 0x0D, // Rest of fake PNG data - ]; - std::fs::write(&png_path, png_data)?; - let png_path_str = png_path.to_str().unwrap(); - - // Create message with image path - let message = Message::user().with_text(format!("Here is an image: {}", png_path_str)); - let spec = format_messages(&[message], &ImageFormat::OpenAi); - - assert_eq!(spec.len(), 1); - assert_eq!(spec[0]["role"], "user"); - - // Content should be an array with text and image - let content = spec[0]["content"].as_array().unwrap(); - assert_eq!(content.len(), 2); - assert_eq!(content[0]["type"], "text"); - assert!(content[0]["text"].as_str().unwrap().contains(png_path_str)); - assert_eq!(content[1]["type"], "image_url"); - assert!(content[1]["image_url"]["url"] - .as_str() - .unwrap() - .starts_with("data:image/png;base64,")); - - Ok(()) - } - #[test] fn test_response_to_message_text() -> anyhow::Result<()> { let response = json!({ diff --git a/crates/goose-llm/src/providers/openai.rs b/crates/goose-llm/src/providers/openai.rs index bc0dc0884823..82d736f366cf 100644 --- a/crates/goose-llm/src/providers/openai.rs +++ b/crates/goose-llm/src/providers/openai.rs @@ -149,6 +149,7 @@ impl Provider for OpenAiProvider { system: &str, messages: &[Message], tools: &[Tool], + _request_id: Option<&str>, // OpenAI doesn't use request_id, so we ignore it ) -> Result { let payload = create_request(&self.model, system, messages, tools, &ImageFormat::OpenAi)?; @@ -175,6 +176,7 @@ impl Provider for OpenAiProvider { system: &str, messages: &[Message], schema: &Value, + _request_id: Option<&str>, // OpenAI doesn't use request_id, so we ignore it ) -> Result { // 1. Build base payload (no tools) let mut payload = create_request(&self.model, system, messages, &[], &ImageFormat::OpenAi)?; diff --git a/crates/goose-llm/src/providers/utils.rs b/crates/goose-llm/src/providers/utils.rs index 1a3945dcb15c..b6c00e7bf237 100644 --- a/crates/goose-llm/src/providers/utils.rs +++ b/crates/goose-llm/src/providers/utils.rs @@ -181,30 +181,6 @@ fn is_image_file(path: &Path) -> bool { false } -/// Detect if a string contains a path to an image file -pub fn detect_image_path(text: &str) -> Option<&str> { - // Basic image file extension check - let extensions = [".png", ".jpg", ".jpeg"]; - - // Find any word that ends with an image extension - for word in text.split_whitespace() { - if extensions - .iter() - .any(|ext| word.to_lowercase().ends_with(ext)) - { - let path = Path::new(word); - // Check if it's an absolute path and file exists - if path.is_absolute() && path.is_file() { - // Verify it's actually an image file - if is_image_file(path) { - return Some(word); - } - } - } - } - None -} - /// Convert a local image file to base64 encoded ImageContent pub fn load_image_file(path: &str) -> Result { let path = Path::new(path); @@ -267,81 +243,6 @@ pub fn emit_debug_trace( mod tests { use super::*; - #[test] - fn test_detect_image_path() { - // Create a temporary PNG file with valid PNG magic numbers - let temp_dir = tempfile::tempdir().unwrap(); - let png_path = temp_dir.path().join("test.png"); - let png_data = [ - 0x89, 0x50, 0x4E, 0x47, // PNG magic number - 0x0D, 0x0A, 0x1A, 0x0A, // PNG header - 0x00, 0x00, 0x00, 0x0D, // Rest of fake PNG data - ]; - std::fs::write(&png_path, png_data).unwrap(); - let png_path_str = png_path.to_str().unwrap(); - - // Create a fake PNG (wrong magic numbers) - let fake_png_path = temp_dir.path().join("fake.png"); - std::fs::write(&fake_png_path, b"not a real png").unwrap(); - - // Test with valid PNG file using absolute path - let text = format!("Here is an image {}", png_path_str); - assert_eq!(detect_image_path(&text), Some(png_path_str)); - - // Test with non-image file that has .png extension - let text = format!("Here is a fake image {}", fake_png_path.to_str().unwrap()); - assert_eq!(detect_image_path(&text), None); - - // Test with non-existent file - let text = "Here is a fake.png that doesn't exist"; - assert_eq!(detect_image_path(text), None); - - // Test with non-image file - let text = "Here is a file.txt"; - assert_eq!(detect_image_path(text), None); - - // Test with relative path (should not match) - let text = "Here is a relative/path/image.png"; - assert_eq!(detect_image_path(text), None); - } - - #[test] - fn test_load_image_file() { - // Create a temporary PNG file with valid PNG magic numbers - let temp_dir = tempfile::tempdir().unwrap(); - let png_path = temp_dir.path().join("test.png"); - let png_data = [ - 0x89, 0x50, 0x4E, 0x47, // PNG magic number - 0x0D, 0x0A, 0x1A, 0x0A, // PNG header - 0x00, 0x00, 0x00, 0x0D, // Rest of fake PNG data - ]; - std::fs::write(&png_path, png_data).unwrap(); - let png_path_str = png_path.to_str().unwrap(); - - // Create a fake PNG (wrong magic numbers) - let fake_png_path = temp_dir.path().join("fake.png"); - std::fs::write(&fake_png_path, b"not a real png").unwrap(); - let fake_png_path_str = fake_png_path.to_str().unwrap(); - - // Test loading valid PNG file - let result = load_image_file(png_path_str); - assert!(result.is_ok()); - let image = result.unwrap(); - assert_eq!(image.mime_type, "image/png"); - - // Test loading fake PNG file - let result = load_image_file(fake_png_path_str); - assert!(result.is_err()); - assert!(result - .unwrap_err() - .to_string() - .contains("not a valid image")); - - // Test non-existent file - let result = load_image_file("nonexistent.png"); - assert!(result.is_err()); - } - #[test] fn test_sanitize_function_name() { assert_eq!(sanitize_function_name("hello-world"), "hello-world"); diff --git a/crates/goose-llm/src/structured_outputs.rs b/crates/goose-llm/src/structured_outputs.rs index 8f478d8aa184..b6690b641e74 100644 --- a/crates/goose-llm/src/structured_outputs.rs +++ b/crates/goose-llm/src/structured_outputs.rs @@ -6,13 +6,14 @@ use crate::{ /// Generates a structured output based on the provided schema, /// system prompt and user messages. -#[uniffi::export(async_runtime = "tokio")] +#[uniffi::export(async_runtime = "tokio", default(request_id = None))] pub async fn generate_structured_outputs( provider_name: &str, provider_config: JsonValueFfi, system_prompt: &str, messages: &[Message], schema: JsonValueFfi, + request_id: Option, ) -> Result { // Use OpenAI models specifically for this task let model_name = if provider_name == "databricks" { @@ -23,7 +24,9 @@ pub async fn generate_structured_outputs( let model_cfg = ModelConfig::new(model_name.to_string()).with_temperature(Some(0.0)); let provider = create(provider_name, provider_config, model_cfg)?; - let resp = provider.extract(system_prompt, messages, &schema).await?; + let resp = provider + .extract(system_prompt, messages, &schema, request_id.as_deref()) + .await?; Ok(resp) } diff --git a/crates/goose-llm/src/types/completion.rs b/crates/goose-llm/src/types/completion.rs index 21e0bcd9ddd3..ce54f6075ed2 100644 --- a/crates/goose-llm/src/types/completion.rs +++ b/crates/goose-llm/src/types/completion.rs @@ -20,6 +20,7 @@ pub struct CompletionRequest { pub system_prompt_override: Option, pub messages: Vec, pub extensions: Vec, + pub request_id: Option, } impl CompletionRequest { @@ -40,10 +41,17 @@ impl CompletionRequest { system_preamble, messages, extensions, + request_id: None, } } + + pub fn with_request_id(mut self, request_id: String) -> Self { + self.request_id = Some(request_id); + self + } } +#[allow(clippy::too_many_arguments)] #[uniffi::export(default(system_preamble = None, system_prompt_override = None))] pub fn create_completion_request( provider_name: &str, @@ -53,8 +61,9 @@ pub fn create_completion_request( system_prompt_override: Option, messages: Vec, extensions: Vec, + request_id: Option, ) -> CompletionRequest { - CompletionRequest::new( + let mut request = CompletionRequest::new( provider_name.to_string(), provider_config, model_config, @@ -62,7 +71,13 @@ pub fn create_completion_request( system_prompt_override, messages, extensions, - ) + ); + + if let Some(req_id) = request_id { + request = request.with_request_id(req_id); + } + + request } uniffi::custom_type!(CompletionRequest, String, { diff --git a/crates/goose-llm/tests/extract_session_name.rs b/crates/goose-llm/tests/extract_session_name.rs index 5326fdbe780d..58d0a6b4921e 100644 --- a/crates/goose-llm/tests/extract_session_name.rs +++ b/crates/goose-llm/tests/extract_session_name.rs @@ -22,7 +22,7 @@ async fn _generate_session_name(messages: &[Message]) -> Result Result {}", provider_type, resp.data); diff --git a/crates/goose-mcp/Cargo.toml b/crates/goose-mcp/Cargo.toml index b6b77d0182ac..de8030fd4951 100644 --- a/crates/goose-mcp/Cargo.toml +++ b/crates/goose-mcp/Cargo.toml @@ -13,13 +13,13 @@ workspace = true [dependencies] mcp-core = { path = "../mcp-core" } mcp-server = { path = "../mcp-server" } +rmcp = { workspace = true } anyhow = "1.0.94" tokio = { version = "1", features = ["full"] } tracing = "0.1" tracing-subscriber = { version = "0.3", features = ["env-filter"] } tracing-appender = "0.2" url = "2.5" -urlencoding = "2.1.3" base64 = "0.21" thiserror = "1.0" serde = { version = "1.0", features = ["derive"] } @@ -62,6 +62,7 @@ utoipa = { version = "4.1", optional = true } hyper = "1" serde_with = "3" which = "6.0" +glob = "0.3" [dev-dependencies] diff --git a/crates/goose-mcp/src/computercontroller/docx_tool.rs b/crates/goose-mcp/src/computercontroller/docx_tool.rs index 5e349d1e24e0..e88564e465b8 100644 --- a/crates/goose-mcp/src/computercontroller/docx_tool.rs +++ b/crates/goose-mcp/src/computercontroller/docx_tool.rs @@ -1,6 +1,7 @@ use docx_rs::*; use image::{self, ImageFormat}; -use mcp_core::{Content, ToolError}; +use mcp_core::ToolError; +use rmcp::model::Content; use std::{fs, io::Cursor}; #[derive(Debug)] @@ -568,9 +569,9 @@ mod tests { let content = result.unwrap(); assert!(!content.is_empty(), "Extracted text should not be empty"); let text = content[0].as_text().unwrap(); - println!("Extracted text:\n{}", text); + println!("Extracted text:\n{}", text.text); assert!( - !text.trim().is_empty(), + !text.text.trim().is_empty(), "Extracted text should not be empty" ); } @@ -609,11 +610,11 @@ mod tests { let content = result.unwrap(); let text = content[0].as_text().unwrap(); assert!( - text.contains("Test Heading"), + text.text.contains("Test Heading"), "Should contain written content" ); assert!( - text.contains("test paragraph"), + text.text.contains("test paragraph"), "Should contain written content" ); @@ -700,15 +701,15 @@ mod tests { let content = result.unwrap(); let text = content[0].as_text().unwrap(); assert!( - text.contains("New content here"), + text.text.contains("New content here"), "Should contain new content" ); assert!( - text.contains("Keep this text"), + text.text.contains("Keep this text"), "Should keep unmodified content" ); assert!( - !text.contains("This should be replaced"), + !text.text.contains("This should be replaced"), "Should not contain replaced text" ); @@ -846,22 +847,25 @@ mod tests { // Check for initial content assert!( - text.contains("Initial content"), + text.text.contains("Initial content"), "Should contain initial content" ); assert!( - text.contains("first paragraph"), + text.text.contains("first paragraph"), "Should contain first paragraph" ); assert!( - text.contains("should stay in the document"), + text.text.contains("should stay in the document"), "Should preserve existing content" ); // Check for new content - assert!(text.contains("New content"), "Should contain new content"); assert!( - text.contains("additional paragraph"), + text.text.contains("New content"), + "Should contain new content" + ); + assert!( + text.text.contains("additional paragraph"), "Should contain appended paragraph" ); diff --git a/crates/goose-mcp/src/computercontroller/mod.rs b/crates/goose-mcp/src/computercontroller/mod.rs index bdc9651eac90..ac0be9ac7548 100644 --- a/crates/goose-mcp/src/computercontroller/mod.rs +++ b/crates/goose-mcp/src/computercontroller/mod.rs @@ -17,10 +17,10 @@ use mcp_core::{ protocol::{JsonRpcMessage, ServerCapabilities}, resource::Resource, tool::{Tool, ToolAnnotations}, - Content, }; use mcp_server::router::CapabilitiesBuilder; use mcp_server::Router; +use rmcp::model::Content; mod docx_tool; mod pdf_tool; diff --git a/crates/goose-mcp/src/computercontroller/pdf_tool.rs b/crates/goose-mcp/src/computercontroller/pdf_tool.rs index f25dde64e3d0..424d1f9e190f 100644 --- a/crates/goose-mcp/src/computercontroller/pdf_tool.rs +++ b/crates/goose-mcp/src/computercontroller/pdf_tool.rs @@ -1,5 +1,6 @@ use lopdf::{content::Content as PdfContent, Document, Object}; -use mcp_core::{Content, ToolError}; +use mcp_core::ToolError; +use rmcp::model::Content; use std::{fs, path::Path}; pub async fn pdf_tool( @@ -341,10 +342,10 @@ mod tests { let content = result.unwrap(); assert!(!content.is_empty(), "Extracted text should not be empty"); let text = content[0].as_text().unwrap(); - println!("Extracted text:\n{}", text); - assert!(text.contains("Page 1"), "Should contain page marker"); + println!("Extracted text:\n{}", text.text); + assert!(text.text.contains("Page 1"), "Should contain page marker"); assert!( - text.contains("This is a test PDF"), + text.text.contains("This is a test PDF"), "Should contain expected test content" ); } @@ -373,18 +374,19 @@ mod tests { "Image extraction result should not be empty" ); let text = content[0].as_text().unwrap(); - println!("Extracted content: {}", text); + println!("Extracted content: {}", text.text); // Should either find images or explicitly state none were found assert!( - text.contains("Saved image to:") || text.contains("No images found"), + text.text.contains("Saved image to:") || text.text.contains("No images found"), "Should either save images or report none found" ); // If we found images, verify they exist - if text.contains("Saved image to:") { + if text.text.contains("Saved image to:") { // Extract the file path from the output let file_path = text + .text .lines() .find(|line| line.contains("Saved image to:")) .and_then(|line| line.split(": ").nth(1)) diff --git a/crates/goose-mcp/src/developer/mod.rs b/crates/goose-mcp/src/developer/mod.rs index 882dd1d81e23..eeb7006bffe3 100644 --- a/crates/goose-mcp/src/developer/mod.rs +++ b/crates/goose-mcp/src/developer/mod.rs @@ -27,7 +27,6 @@ use mcp_core::{ protocol::{JsonRpcMessage, JsonRpcNotification, ServerCapabilities}, resource::Resource, tool::Tool, - Content, }; use mcp_core::{ prompt::{Prompt, PromptArgument, PromptTemplate}, @@ -35,8 +34,9 @@ use mcp_core::{ }; use mcp_server::router::CapabilitiesBuilder; use mcp_server::Router; +use rmcp::model::Content; -use mcp_core::role::Role; +use rmcp::model::Role; use self::editor_models::{create_editor_model, EditorModel}; use self::shell::{expand_path, get_shell_config, is_absolute_path, normalize_line_endings}; @@ -158,10 +158,9 @@ impl DeveloperRouter { sourcing files do not persist between tool calls. So you may need to repeat them each time by stringing together commands, e.g. `cd example && ls` or `source env/bin/activate && pip install numpy` - **Important**: Use ripgrep - `rg` - when you need to locate a file or a code reference, other solutions - may show ignored or hidden files. For example *do not* use `find` or `ls -r` - - List files by name: `rg --files | rg ` - - List files that contain a regex: `rg '' -l` + - Restrictions: Avoid find, grep, cat, head, tail, ls - use dedicated tools instead (Grep, Glob, Read, LS) + - Multiple commands: Use ; or && to chain commands, avoid newlines + - Pathnames: Use absolute paths and avoid cd unless explicitly requested "#}, }; @@ -178,6 +177,88 @@ impl DeveloperRouter { None, ); + let glob_tool = Tool::new( + "glob".to_string(), + indoc! {r#" + Search for files using glob patterns. + + This tool provides fast file pattern matching using glob syntax. + Returns matching file paths sorted by modification time. + Examples: + - `*.rs` - Find all Rust files in current directory + - `src/**/*.py` - Find all Python files recursively in src directory + - `**/test*.js` - Find all JavaScript test files recursively + + **Important**: Use this tool instead of shell commands like `find` or `ls -r` for file searching, + as it properly handles ignored files and is more efficient. This tool respects .gooseignore patterns. + + Use this tool when you need to locate files by name patterns rather than content. + "#}.to_string(), + json!({ + "type": "object", + "required": ["pattern"], + "properties": { + "pattern": {"type": "string", "description": "The glob pattern to search for"}, + "path": {"type": "string", "description": "The directory to search in (defaults to current directory)"} + } + }), + Some(ToolAnnotations { + title: Some("Search files by pattern".to_string()), + read_only_hint: true, + destructive_hint: false, + idempotent_hint: true, + open_world_hint: false, + }), + ); + + let grep_tool = Tool::new( + "grep".to_string(), + indoc! {r#" + Execute file content search commands using ripgrep, grep, or find. + + Use this tool to run search commands that look for content within files. The tool + executes your command directly and filters results to respect .gooseignore patterns. + + **Recommended tools and usage:** + + **ripgrep (rg)** - Fast, recommended for most searches: + - List files containing pattern: `rg -l "pattern"` + - Case-insensitive search: `rg -i "pattern"` + - Search specific file types: `rg "pattern" --glob "*.js"` + - Show matches with context: `rg "pattern" -C 3` + - List files by name: `rg --files | rg ` + - List files that contain a regex: `rg '' -l` + - Sort by modification time: `rg -l "pattern" --sort modified` + + **grep** - Traditional Unix tool: + - Recursive search: `grep -r "pattern" .` + - List files only: `grep -rl "pattern" .` + - Include specific files: `grep -r "pattern" --include="*.py"` + + **find + grep** - When you need complex file filtering: + - `find . -name "*.py" -exec grep -l "pattern" {} \;` + - `find . -type f -newer file.txt -exec grep "pattern" {} \;` + + **Important**: Use this tool instead of the shell tool for search commands, as it + properly filters results to respect ignored files. + "#} + .to_string(), + json!({ + "type": "object", + "required": ["command"], + "properties": { + "command": {"type": "string", "description": "The search command to execute (rg, grep, find, etc.)"} + } + }), + Some(ToolAnnotations { + title: Some("Search file contents".to_string()), + read_only_hint: true, + destructive_hint: false, + idempotent_hint: true, + open_world_hint: false, + }), + ); + // Create text editor tool with different descriptions based on editor API configuration let (text_editor_desc, str_replace_command) = if let Some(ref editor) = editor_model { ( @@ -483,6 +564,8 @@ impl DeveloperRouter { Self { tools: vec![ bash_tool, + glob_tool, + grep_tool, text_editor_tool, list_windows_tool, screen_capture_tool, @@ -672,6 +755,69 @@ impl DeveloperRouter { ]) } + async fn glob(&self, params: Value) -> Result, ToolError> { + let pattern = + params + .get("pattern") + .and_then(|v| v.as_str()) + .ok_or(ToolError::InvalidParameters( + "The pattern string is required".to_string(), + ))?; + + let search_path = params.get("path").and_then(|v| v.as_str()).unwrap_or("."); + + let full_pattern = if search_path == "." { + pattern.to_string() + } else { + format!("{}/{}", search_path.trim_end_matches('/'), pattern) + }; + + let glob_result = glob::glob(&full_pattern) + .map_err(|e| ToolError::InvalidParameters(format!("Invalid glob pattern: {}", e)))?; + + let mut file_paths_with_metadata = Vec::new(); + + for entry in glob_result { + match entry { + Ok(path) => { + // Check if the path should be ignored + if !self.is_ignored(&path) { + // Get file metadata for sorting by modification time + if let Ok(metadata) = std::fs::metadata(&path) { + if metadata.is_file() { + let modified = metadata + .modified() + .unwrap_or(std::time::SystemTime::UNIX_EPOCH); + file_paths_with_metadata.push((path, modified)); + } + } + } + } + Err(e) => { + tracing::warn!("Error reading glob entry: {}", e); + } + } + } + + // Sort by modification time (newest first) + file_paths_with_metadata.sort_by(|a, b| b.1.cmp(&a.1)); + + // Extract just the file paths + let file_paths: Vec = file_paths_with_metadata + .into_iter() + .map(|(path, _)| path.to_string_lossy().to_string()) + .collect(); + + let result = file_paths.join("\n"); + + Ok(vec![ + Content::text(result.clone()).with_audience(vec![Role::Assistant]), + Content::text(result) + .with_audience(vec![Role::User]) + .with_priority(0.0), + ]) + } + async fn text_editor(&self, params: Value) -> Result, ToolError> { let command = params .get("command") @@ -1438,6 +1584,8 @@ impl Router for DeveloperRouter { Box::pin(async move { match tool_name.as_str() { "shell" => this.bash(arguments, notifier).await, + "glob" => this.glob(arguments).await, + "grep" => this.bash(arguments, notifier).await, "text_editor" => this.text_editor(arguments).await, "list_windows" => this.list_windows(arguments).await, "screen_capture" => this.screen_capture(arguments).await, @@ -1738,7 +1886,7 @@ mod tests { .unwrap() .as_text() .unwrap(); - assert!(text.contains("Hello, world!")); + assert!(text.text.contains("Hello, world!")); temp_dir.close().unwrap(); } @@ -1792,7 +1940,9 @@ mod tests { .as_text() .unwrap(); - assert!(text.contains("has been edited, and the section now reads")); + assert!(text + .text + .contains("has been edited, and the section now reads")); // View the file to verify the change let view_result = router @@ -1820,9 +1970,9 @@ mod tests { // Check that the file has been modified and contains some form of "Rust" // The Editor API might transform the content differently than simple string replacement assert!( - text.contains("Rust") || text.contains("Hello, Rust!"), + text.text.contains("Rust") || text.text.contains("Hello, Rust!"), "Expected content to contain 'Rust', but got: {}", - text + text.text ); temp_dir.close().unwrap(); @@ -1881,7 +2031,7 @@ mod tests { .unwrap(); let text = undo_result.first().unwrap().as_text().unwrap(); - assert!(text.contains("Undid the last edit")); + assert!(text.text.contains("Undid the last edit")); // View the file to verify the undo let view_result = router @@ -1905,7 +2055,7 @@ mod tests { .unwrap() .as_text() .unwrap(); - assert!(text.contains("First line")); + assert!(text.text.contains("First line")); temp_dir.close().unwrap(); } @@ -2359,14 +2509,14 @@ mod tests { .unwrap(); // Should contain lines 3-6 with line numbers - assert!(text.contains("3: Line 3")); - assert!(text.contains("4: Line 4")); - assert!(text.contains("5: Line 5")); - assert!(text.contains("6: Line 6")); - assert!(text.contains("(lines 3-6)")); + assert!(text.text.contains("3: Line 3")); + assert!(text.text.contains("4: Line 4")); + assert!(text.text.contains("5: Line 5")); + assert!(text.text.contains("6: Line 6")); + assert!(text.text.contains("(lines 3-6)")); // Should not contain other lines - assert!(!text.contains("1: Line 1")); - assert!(!text.contains("7: Line 7")); + assert!(!text.text.contains("1: Line 1")); + assert!(!text.text.contains("7: Line 7")); temp_dir.close().unwrap(); } @@ -2421,13 +2571,13 @@ mod tests { .unwrap(); // Should contain lines 3 to end - assert!(text.contains("3: Line 3")); - assert!(text.contains("4: Line 4")); - assert!(text.contains("5: Line 5")); - assert!(text.contains("(lines 3-end)")); + assert!(text.text.contains("3: Line 3")); + assert!(text.text.contains("4: Line 4")); + assert!(text.text.contains("5: Line 5")); + assert!(text.text.contains("(lines 3-end)")); // Should not contain earlier lines - assert!(!text.contains("1: Line 1")); - assert!(!text.contains("2: Line 2")); + assert!(!text.text.contains("1: Line 1")); + assert!(!text.text.contains("2: Line 2")); temp_dir.close().unwrap(); } @@ -2547,7 +2697,7 @@ mod tests { .as_text() .unwrap(); - assert!(text.contains("Text has been inserted at line 1")); + assert!(text.text.contains("Text has been inserted at line 1")); // Verify the file content let view_result = router @@ -2572,10 +2722,10 @@ mod tests { .as_text() .unwrap(); - assert!(view_text.contains("1: Line 1")); - assert!(view_text.contains("2: Line 2")); - assert!(view_text.contains("3: Line 3")); - assert!(view_text.contains("4: Line 4")); + assert!(view_text.text.contains("1: Line 1")); + assert!(view_text.text.contains("2: Line 2")); + assert!(view_text.text.contains("3: Line 3")); + assert!(view_text.text.contains("4: Line 4")); temp_dir.close().unwrap(); } @@ -2630,7 +2780,7 @@ mod tests { .as_text() .unwrap(); - assert!(text.contains("Text has been inserted at line 3")); + assert!(text.text.contains("Text has been inserted at line 3")); // Verify the file content let view_result = router @@ -2655,11 +2805,11 @@ mod tests { .as_text() .unwrap(); - assert!(view_text.contains("1: Line 1")); - assert!(view_text.contains("2: Line 2")); - assert!(view_text.contains("3: Line 3")); - assert!(view_text.contains("4: Line 4")); - assert!(view_text.contains("5: Line 5")); + assert!(view_text.text.contains("1: Line 1")); + assert!(view_text.text.contains("2: Line 2")); + assert!(view_text.text.contains("3: Line 3")); + assert!(view_text.text.contains("4: Line 4")); + assert!(view_text.text.contains("5: Line 5")); temp_dir.close().unwrap(); } @@ -2714,7 +2864,7 @@ mod tests { .as_text() .unwrap(); - assert!(text.contains("Text has been inserted at line 4")); + assert!(text.text.contains("Text has been inserted at line 4")); // Verify the file content let view_result = router @@ -2739,10 +2889,10 @@ mod tests { .as_text() .unwrap(); - assert!(view_text.contains("1: Line 1")); - assert!(view_text.contains("2: Line 2")); - assert!(view_text.contains("3: Line 3")); - assert!(view_text.contains("4: Line 4")); + assert!(view_text.text.contains("1: Line 1")); + assert!(view_text.text.contains("2: Line 2")); + assert!(view_text.text.contains("3: Line 3")); + assert!(view_text.text.contains("4: Line 4")); temp_dir.close().unwrap(); } @@ -2911,7 +3061,7 @@ mod tests { .unwrap(); let text = undo_result.first().unwrap().as_text().unwrap(); - assert!(text.contains("Undid the last edit")); + assert!(text.text.contains("Undid the last edit")); // Verify the file is back to original content let view_result = router @@ -2936,9 +3086,9 @@ mod tests { .as_text() .unwrap(); - assert!(view_text.contains("1: Line 1")); - assert!(view_text.contains("2: Line 2")); - assert!(!view_text.contains("Inserted Line")); + assert!(view_text.text.contains("1: Line 1")); + assert!(view_text.text.contains("2: Line 2")); + assert!(!view_text.text.contains("Inserted Line")); temp_dir.close().unwrap(); } diff --git a/crates/goose-mcp/src/google_drive/mod.rs b/crates/goose-mcp/src/google_drive/mod.rs index 1f1aeae70ca2..5e85ecc1f501 100644 --- a/crates/goose-mcp/src/google_drive/mod.rs +++ b/crates/goose-mcp/src/google_drive/mod.rs @@ -11,13 +11,13 @@ use mcp_core::protocol::JsonRpcMessage; use mcp_core::tool::ToolAnnotations; use oauth_pkce::PkceOAuth2Client; use regex::Regex; +use rmcp::model::Content; use serde_json::{json, Value}; use std::io::Cursor; use std::{env, fs, future::Future, path::Path, pin::Pin, sync::Arc}; use storage::CredentialsManager; use tokio::sync::mpsc; -use mcp_core::content::Content; use mcp_core::{ handler::{PromptError, ResourceError, ToolError}, prompt::Prompt, @@ -1845,7 +1845,12 @@ impl GoogleDriveRouter { .map(|contents| { contents .into_iter() - .map(|content| content.as_text().unwrap_or_default().to_string()) + .map(|content| { + content + .as_text() + .map(|text| text.text.clone()) + .unwrap_or_default() + }) .collect::>() .join("\n") }) diff --git a/crates/goose-mcp/src/jetbrains/mod.rs b/crates/goose-mcp/src/jetbrains/mod.rs deleted file mode 100644 index c015b9deae74..000000000000 --- a/crates/goose-mcp/src/jetbrains/mod.rs +++ /dev/null @@ -1,236 +0,0 @@ -mod proxy; - -use anyhow::Result; -use mcp_core::{ - content::Content, - handler::{PromptError, ResourceError, ToolError}, - prompt::Prompt, - protocol::{JsonRpcMessage, ServerCapabilities}, - resource::Resource, - role::Role, - tool::Tool, -}; -use mcp_server::router::CapabilitiesBuilder; -use mcp_server::Router; -use serde_json::Value; -use std::future::Future; -use std::pin::Pin; -use std::sync::Arc; -use tokio::sync::{mpsc, Mutex}; -use tokio::time::{sleep, Duration}; -use tracing::error; - -use self::proxy::JetBrainsProxy; - -pub struct JetBrainsRouter { - tools: Arc>>, - proxy: Arc, - instructions: String, -} - -impl Default for JetBrainsRouter { - fn default() -> Self { - Self::new() - } -} - -impl JetBrainsRouter { - pub fn new() -> Self { - let tools = Arc::new(Mutex::new(Vec::new())); - let proxy = Arc::new(JetBrainsProxy::new()); - let instructions = "JetBrains IDE integration".to_string(); - - // Initialize the proxy - let proxy_clone = Arc::clone(&proxy); - tokio::spawn(async move { - if let Err(e) = proxy_clone.start().await { - error!("Failed to start JetBrains proxy: {}", e); - } - }); - - // Start the background task to update tools - let tools_clone = Arc::clone(&tools); - let proxy_clone = Arc::clone(&proxy); - tokio::spawn(async move { - let mut interval = tokio::time::interval(Duration::from_secs(5)); - loop { - interval.tick().await; - match proxy_clone.list_tools().await { - Ok(new_tools) => { - let mut tools = tools_clone.lock().await; - *tools = new_tools; - } - Err(e) => { - error!("Failed to update tools: {}", e); - } - } - } - }); - - Self { - tools, - proxy, - instructions, - } - } - - async fn call_proxy_tool( - &self, - tool_name: String, - arguments: Value, - ) -> Result, ToolError> { - let result = self - .proxy - .call_tool(&tool_name, arguments) - .await - .map_err(|e| ToolError::ExecutionError(e.to_string()))?; - - // Create a success message for the assistant - let mut contents = vec![ - Content::text(format!("Tool {} executed successfully", tool_name)) - .with_audience(vec![Role::Assistant]), - ]; - - // Add the tool's result contents - contents.extend(result.content); - - Ok(contents) - } - - async fn ensure_tools(&self) -> Result<(), ToolError> { - let mut retry_count = 0; - let max_retries = 50; // 5 second total wait time - let retry_delay = Duration::from_millis(100); - - while retry_count < max_retries { - let tools = self.tools.lock().await; - if !tools.is_empty() { - return Ok(()); - } - drop(tools); // Release the lock before sleeping - - sleep(retry_delay).await; - retry_count += 1; - } - - Err(ToolError::ExecutionError("Failed to get tools list from IDE. Make sure the IDE is running and the plugin is installed.".to_string())) - } -} - -impl Router for JetBrainsRouter { - fn name(&self) -> String { - "jetbrains".to_string() - } - - fn instructions(&self) -> String { - self.instructions.clone() - } - - fn capabilities(&self) -> ServerCapabilities { - CapabilitiesBuilder::new().with_tools(true).build() - } - - fn list_tools(&self) -> Vec { - // Use block_in_place to avoid blocking the runtime - tokio::task::block_in_place(|| { - let rt = tokio::runtime::Builder::new_current_thread() - .enable_all() - .build() - .unwrap(); - rt.block_on(async { - let tools = self.tools.lock().await; - if tools.is_empty() { - drop(tools); - if let Err(e) = self.ensure_tools().await { - error!("Failed to ensure tools: {}", e); - vec![] - } else { - self.tools.lock().await.clone() - } - } else { - tools.clone() - } - }) - }) - } - - fn call_tool( - &self, - tool_name: &str, - arguments: Value, - _notifier: mpsc::Sender, - ) -> Pin, ToolError>> + Send + 'static>> { - let this = self.clone(); - let tool_name = tool_name.to_string(); - Box::pin(async move { - this.ensure_tools().await?; - this.call_proxy_tool(tool_name, arguments).await - }) - } - - fn list_resources(&self) -> Vec { - vec![] - } - - fn read_resource( - &self, - _uri: &str, - ) -> Pin> + Send + 'static>> { - Box::pin(async { Err(ResourceError::NotFound("Resource not found".into())) }) - } - - fn list_prompts(&self) -> Vec { - vec![] - } - - fn get_prompt( - &self, - prompt_name: &str, - ) -> Pin> + Send + 'static>> { - let prompt_name = prompt_name.to_string(); - Box::pin(async move { - Err(PromptError::NotFound(format!( - "Prompt {} not found", - prompt_name - ))) - }) - } -} - -impl Clone for JetBrainsRouter { - fn clone(&self) -> Self { - Self { - tools: Arc::clone(&self.tools), - proxy: Arc::clone(&self.proxy), - instructions: self.instructions.clone(), - } - } -} - -#[cfg(test)] -mod tests { - use super::*; - use tokio::sync::OnceCell; - - static JETBRAINS_ROUTER: OnceCell = OnceCell::const_new(); - - async fn get_router() -> &'static JetBrainsRouter { - JETBRAINS_ROUTER - .get_or_init(|| async { JetBrainsRouter::new() }) - .await - } - - #[tokio::test] - async fn test_router_creation() { - let router = get_router().await; - assert_eq!(router.name(), "jetbrains"); - assert!(!router.instructions().is_empty()); - } - - #[tokio::test] - async fn test_capabilities() { - let router = get_router().await; - let capabilities = router.capabilities(); - assert!(capabilities.tools.is_some()); - } -} diff --git a/crates/goose-mcp/src/jetbrains/proxy.rs b/crates/goose-mcp/src/jetbrains/proxy.rs deleted file mode 100644 index 382f271497f1..000000000000 --- a/crates/goose-mcp/src/jetbrains/proxy.rs +++ /dev/null @@ -1,342 +0,0 @@ -use anyhow::{anyhow, Result}; -use mcp_core::{Content, Tool}; -use reqwest::Client; -use serde::{Deserialize, Serialize}; -use serde_json::Value; -use std::env; -use std::sync::Arc; -use std::time::Duration; -use tokio::sync::RwLock; -use tracing::{debug, error, info}; - -const PORT_RANGE_START: u16 = 63342; -const PORT_RANGE_END: u16 = 63352; -const ENDPOINT_CHECK_INTERVAL: Duration = Duration::from_secs(10); - -#[derive(Debug, Serialize, Deserialize)] -struct IDEResponseOk { - status: String, - error: Option, -} - -#[derive(Debug, Serialize, Deserialize)] -struct IDEResponseErr { - status: Option, - error: String, -} - -#[derive(Debug, Serialize)] -pub struct CallToolResult { - pub content: Vec, - pub is_error: bool, -} - -#[derive(Debug)] -pub struct JetBrainsProxy { - cached_endpoint: Arc>>, - previous_response: Arc>>, - client: Client, -} - -impl JetBrainsProxy { - pub fn new() -> Self { - Self { - cached_endpoint: Arc::new(RwLock::new(None)), - previous_response: Arc::new(RwLock::new(None)), - client: Client::new(), - } - } - - async fn test_list_tools(&self, endpoint: &str) -> Result { - debug!("Sending test request to {}/mcp/list_tools", endpoint); - - let response = match self - .client - .get(format!("{}/mcp/list_tools", endpoint)) - .send() - .await - { - Ok(resp) => { - debug!("Got response with status: {}", resp.status()); - resp - } - Err(e) => { - debug!("Error testing endpoint {}: {}", endpoint, e); - return Ok(false); - } - }; - - if !response.status().is_success() { - debug!("Test request failed with status {}", response.status()); - return Ok(false); - } - - let current_response = response.text().await?; - debug!("Received response: {}", current_response); - - // Try to parse as JSON array to validate format - if serde_json::from_str::>(¤t_response).is_err() { - debug!("Response is not a valid JSON array of tools"); - return Ok(false); - } - - let mut prev_response = self.previous_response.write().await; - if let Some(prev) = prev_response.as_ref() { - if prev != ¤t_response { - debug!("Response changed since last check"); - self.send_tools_changed().await; - } - } - *prev_response = Some(current_response); - - Ok(true) - } - - async fn find_working_ide_endpoint(&self) -> Result { - debug!("Attempting to find working IDE endpoint..."); - - // Check IDE_PORT environment variable first - if let Ok(port) = env::var("IDE_PORT") { - debug!("Found IDE_PORT environment variable: {}", port); - let test_endpoint = format!("http://127.0.0.1:{}/api", port); - if self.test_list_tools(&test_endpoint).await? { - debug!("IDE_PORT {} is working", port); - return Ok(test_endpoint); - } - debug!("IDE_PORT {} is not responding correctly", port); - return Err(anyhow!( - "Specified IDE_PORT={} is not responding correctly", - port - )); - } - - debug!( - "No IDE_PORT environment variable, scanning port range {}-{}", - PORT_RANGE_START, PORT_RANGE_END - ); - - // Scan port range - for port in PORT_RANGE_START..=PORT_RANGE_END { - let candidate_endpoint = format!("http://127.0.0.1:{}/api", port); - debug!("Testing port {}...", port); - - if self.test_list_tools(&candidate_endpoint).await? { - debug!("Found working IDE endpoint at {}", candidate_endpoint); - return Ok(candidate_endpoint); - } - } - - debug!("No working IDE endpoint found in port range"); - Err(anyhow!( - "No working IDE endpoint found in range {}-{}", - PORT_RANGE_START, - PORT_RANGE_END - )) - } - - async fn update_ide_endpoint(&self) { - debug!("Updating IDE endpoint..."); - match self.find_working_ide_endpoint().await { - Ok(endpoint) => { - let mut cached = self.cached_endpoint.write().await; - *cached = Some(endpoint.clone()); - debug!("Updated cached endpoint to: {}", endpoint); - } - Err(e) => { - debug!("Failed to update IDE endpoint: {}", e); - error!("Failed to update IDE endpoint: {}", e); - } - } - } - - pub async fn list_tools(&self) -> Result> { - debug!("Listing tools..."); - let endpoint = { - let cached = self.cached_endpoint.read().await; - match cached.as_ref() { - Some(ep) => { - debug!("Using cached endpoint: {}", ep); - ep.clone() - } - None => { - debug!("No cached endpoint available"); - return Ok(vec![]); - } - } - }; - - debug!("Sending list_tools request to {}/mcp/list_tools", endpoint); - let response = match self - .client - .get(format!("{}/mcp/list_tools", endpoint)) - .send() - .await - { - Ok(resp) => { - debug!("Got response with status: {}", resp.status()); - resp - } - Err(e) => { - debug!("Failed to send request: {}", e); - return Err(anyhow!("Failed to send request: {}", e)); - } - }; - - if !response.status().is_success() { - debug!("Request failed with status: {}", response.status()); - return Err(anyhow!( - "Failed to fetch tools with status {}", - response.status() - )); - } - - let response_text = response.text().await?; - debug!("Got response text: {}", response_text); - - let tools_response: Value = serde_json::from_str(&response_text).map_err(|e| { - debug!("Failed to parse response as JSON: {}", e); - anyhow!("Failed to parse response as JSON: {}", e) - })?; - - debug!("Parsed JSON response: {:?}", tools_response); - - let tools: Vec = tools_response - .as_array() - .ok_or_else(|| { - debug!("Response is not a JSON array"); - anyhow!("Invalid tools response format: not an array") - })? - .iter() - .filter_map(|t| { - if let (Some(name), Some(description)) = - (t["name"].as_str(), t["description"].as_str()) - { - // Get just the first sentence of the description - let first_sentence = description - .split('.') - .next() - .unwrap_or(description) - .trim() - .to_string() - + "."; - - // Handle input_schema as either a string or an object - let input_schema = match &t["inputSchema"] { - Value::String(s) => Value::String(s.clone()), - Value::Object(o) => Value::Object(o.clone()), - _ => { - debug!( - "Invalid inputSchema format for tool {}: {:?}", - name, t["inputSchema"] - ); - return None; - } - }; - - Some(Tool { - name: name.to_string(), - description: first_sentence, - input_schema, - annotations: None, - }) - } else { - debug!("Skipping invalid tool entry: {:?}", t); - None - } - }) - .collect(); - - debug!("Collected {} tools", tools.len()); - Ok(tools) - } - - pub async fn call_tool(&self, name: &str, args: Value) -> Result { - let endpoint = self - .cached_endpoint - .read() - .await - .clone() - .ok_or_else(|| anyhow!("No working IDE endpoint available"))?; - - debug!( - "ENDPOINT: {} | Tool name: {} | args: {}", - endpoint, name, args - ); - - let response = self - .client - .post(format!("{}/mcp/{}", endpoint, name)) - .json(&args) - .send() - .await?; - - if !response.status().is_success() { - debug!("Response failed with status: {}", response.status()); - return Err(anyhow!("Response failed: {}", response.status())); - } - - let ide_response: Value = response.json().await?; - let (is_error, text) = match ide_response { - Value::Object(map) => { - let status = map.get("status").and_then(|v| v.as_str()); - let error = map.get("error").and_then(|v| v.as_str()); - - match (status, error) { - (Some(s), None) => (false, s.to_string()), - (None, Some(e)) => (true, e.to_string()), - _ => { - debug!("Invalid response format from IDE"); - return Err(anyhow!("Invalid response format from IDE")); - } - } - } - _ => { - debug!("Unexpected response type from IDE"); - return Err(anyhow!("Unexpected response type from IDE")); - } - }; - - Ok(CallToolResult { - content: vec![Content::text(text)], - is_error, - }) - } - - async fn send_tools_changed(&self) { - debug!("Sending tools changed notification"); - // TODO: Implement notification mechanism when needed - } - - pub async fn start(&self) -> Result<()> { - debug!("Initializing JetBrains Proxy..."); - info!("Initializing JetBrains Proxy..."); - - // Initial endpoint check - debug!("Performing initial endpoint check..."); - self.update_ide_endpoint().await; - - // Schedule periodic endpoint checks - let proxy = self.clone(); - tokio::spawn(async move { - loop { - tokio::time::sleep(ENDPOINT_CHECK_INTERVAL).await; - debug!("Performing periodic endpoint check..."); - proxy.update_ide_endpoint().await; - } - }); - - debug!("JetBrains Proxy running"); - info!("JetBrains Proxy running"); - Ok(()) - } -} - -impl Clone for JetBrainsProxy { - fn clone(&self) -> Self { - Self { - cached_endpoint: Arc::clone(&self.cached_endpoint), - previous_response: Arc::clone(&self.previous_response), - client: Client::new(), - } - } -} diff --git a/crates/goose-mcp/src/lib.rs b/crates/goose-mcp/src/lib.rs index 472349f571fd..c112c8fee3e9 100644 --- a/crates/goose-mcp/src/lib.rs +++ b/crates/goose-mcp/src/lib.rs @@ -10,13 +10,11 @@ pub static APP_STRATEGY: Lazy = Lazy::new(|| AppStrategyArgs { pub mod computercontroller; mod developer; pub mod google_drive; -mod jetbrains; mod memory; mod tutorial; pub use computercontroller::ComputerControllerRouter; pub use developer::DeveloperRouter; pub use google_drive::GoogleDriveRouter; -pub use jetbrains::JetBrainsRouter; pub use memory::MemoryRouter; pub use tutorial::TutorialRouter; diff --git a/crates/goose-mcp/src/memory/mod.rs b/crates/goose-mcp/src/memory/mod.rs index eb5adaeeaff4..d1fee83c4565 100644 --- a/crates/goose-mcp/src/memory/mod.rs +++ b/crates/goose-mcp/src/memory/mod.rs @@ -18,10 +18,10 @@ use mcp_core::{ protocol::{JsonRpcMessage, ServerCapabilities}, resource::Resource, tool::{Tool, ToolAnnotations, ToolCall}, - Content, }; use mcp_server::router::CapabilitiesBuilder; use mcp_server::Router; +use rmcp::model::Content; // MemoryRouter implementation #[derive(Clone)] diff --git a/crates/goose-mcp/src/tutorial/mod.rs b/crates/goose-mcp/src/tutorial/mod.rs index ea9e32f08589..966ab8fdd6e3 100644 --- a/crates/goose-mcp/src/tutorial/mod.rs +++ b/crates/goose-mcp/src/tutorial/mod.rs @@ -1,6 +1,7 @@ use anyhow::Result; use include_dir::{include_dir, Dir}; use indoc::formatdoc; +use rmcp::model::Content; use serde_json::{json, Value}; use std::{future::Future, pin::Pin}; use tokio::sync::mpsc; @@ -16,8 +17,6 @@ use mcp_core::{ use mcp_server::router::CapabilitiesBuilder; use mcp_server::Router; -use mcp_core::content::Content; - static TUTORIALS_DIR: Dir = include_dir!("$CARGO_MANIFEST_DIR/src/tutorial/tutorials"); pub struct TutorialRouter { diff --git a/crates/goose-server/Cargo.toml b/crates/goose-server/Cargo.toml index eaf6c30b9384..22f2c71157ee 100644 --- a/crates/goose-server/Cargo.toml +++ b/crates/goose-server/Cargo.toml @@ -15,6 +15,7 @@ goose = { path = "../goose" } mcp-core = { path = "../mcp-core" } goose-mcp = { path = "../goose-mcp" } mcp-server = { path = "../mcp-server" } +rmcp = { workspace = true } axum = { version = "0.8.1", features = ["ws", "macros"] } tokio = { version = "1.43", features = ["full"] } chrono = "0.4" @@ -52,4 +53,4 @@ path = "src/bin/generate_schema.rs" [dev-dependencies] tower = "0.5" -async-trait = "0.1" \ No newline at end of file +async-trait = "0.1" diff --git a/crates/goose-server/src/bin/generate_schema.rs b/crates/goose-server/src/bin/generate_schema.rs index 529be54c19c4..8d2588c00659 100644 --- a/crates/goose-server/src/bin/generate_schema.rs +++ b/crates/goose-server/src/bin/generate_schema.rs @@ -1,13 +1,18 @@ use goose_server::openapi; use std::env; use std::fs; +use std::path::PathBuf; fn main() { let schema = openapi::generate_schema(); - // Get the current working directory - let current_dir = env::current_dir().unwrap(); - let output_path = current_dir.join("ui").join("desktop").join("openapi.json"); + let package_dir = env::var("CARGO_MANIFEST_DIR").unwrap(); + let output_path = PathBuf::from(package_dir) + .join("..") + .join("..") + .join("ui") + .join("desktop") + .join("openapi.json"); // Ensure parent directory exists if let Some(parent) = output_path.parent() { diff --git a/crates/goose-server/src/commands/mcp.rs b/crates/goose-server/src/commands/mcp.rs index 5a23339a76fb..85395352db89 100644 --- a/crates/goose-server/src/commands/mcp.rs +++ b/crates/goose-server/src/commands/mcp.rs @@ -1,7 +1,6 @@ use anyhow::Result; use goose_mcp::{ - ComputerControllerRouter, DeveloperRouter, GoogleDriveRouter, JetBrainsRouter, MemoryRouter, - TutorialRouter, + ComputerControllerRouter, DeveloperRouter, GoogleDriveRouter, MemoryRouter, TutorialRouter, }; use mcp_server::router::RouterService; use mcp_server::{BoundedService, ByteTransport, Server}; @@ -15,7 +14,6 @@ pub async fn run(name: &str) -> Result<()> { let router: Option> = match name { "developer" => Some(Box::new(RouterService(DeveloperRouter::new()))), "computercontroller" => Some(Box::new(RouterService(ComputerControllerRouter::new()))), - "jetbrains" => Some(Box::new(RouterService(JetBrainsRouter::new()))), "google_drive" | "googledrive" => { let router = GoogleDriveRouter::new().await; Some(Box::new(RouterService(router))) diff --git a/crates/goose-server/src/openapi.rs b/crates/goose-server/src/openapi.rs index ae77398fcf83..b49f7cd51038 100644 --- a/crates/goose-server/src/openapi.rs +++ b/crates/goose-server/src/openapi.rs @@ -11,12 +11,281 @@ use goose::permission::permission_confirmation::PrincipalType; use goose::providers::base::{ConfigKey, ModelInfo, ProviderMetadata}; use goose::session::info::SessionInfo; use goose::session::SessionMetadata; -use mcp_core::content::{Annotations, Content, EmbeddedResource, ImageContent, TextContent}; use mcp_core::handler::ToolResultSchema; use mcp_core::resource::ResourceContents; -use mcp_core::role::Role; use mcp_core::tool::{Tool, ToolAnnotations}; -use utoipa::OpenApi; +use rmcp::model::{Annotations, Content, EmbeddedResource, ImageContent, Role, TextContent}; +use utoipa::{OpenApi, ToSchema}; + +use rmcp::schemars::schema::{InstanceType, SchemaObject, SingleOrVec}; +use utoipa::openapi::schema::{ + AdditionalProperties, AnyOfBuilder, ArrayBuilder, ObjectBuilder, OneOfBuilder, Schema, + SchemaFormat, SchemaType, +}; +use utoipa::openapi::{AllOfBuilder, Ref, RefOr}; + +macro_rules! derive_utoipa { + ($inner_type:ident as $schema_name:ident) => { + struct $schema_name {} + + impl<'__s> ToSchema<'__s> for $schema_name { + fn schema() -> (&'__s str, utoipa::openapi::RefOr) { + let settings = rmcp::schemars::gen::SchemaSettings::openapi3(); + let generator = settings.into_generator(); + let schema = generator.into_root_schema_for::<$inner_type>(); + let schema = convert_schemars_to_utoipa(schema); + (stringify!($inner_type), schema) + } + + fn aliases() -> Vec<(&'__s str, utoipa::openapi::schema::Schema)> { + Vec::new() + } + } + }; +} + +fn convert_schemars_to_utoipa(schema: rmcp::schemars::schema::RootSchema) -> RefOr { + convert_schema_object(&rmcp::schemars::schema::Schema::Object( + schema.schema.clone(), + )) +} + +fn convert_schema_object(schema: &rmcp::schemars::schema::Schema) -> RefOr { + match schema { + rmcp::schemars::schema::Schema::Object(schema_object) => { + convert_schema_object_inner(schema_object) + } + rmcp::schemars::schema::Schema::Bool(true) => { + RefOr::T(Schema::Object(ObjectBuilder::new().build())) + } + rmcp::schemars::schema::Schema::Bool(false) => { + RefOr::T(Schema::Object(ObjectBuilder::new().build())) + } + } +} + +fn convert_schema_object_inner(schema: &SchemaObject) -> RefOr { + // Handle references first + if let Some(reference) = &schema.reference { + return RefOr::Ref(Ref::new(reference.clone())); + } + + // Handle subschemas (oneOf, allOf, anyOf) + if let Some(subschemas) = &schema.subschemas { + if let Some(one_of) = &subschemas.one_of { + let schemas: Vec> = one_of.iter().map(convert_schema_object).collect(); + let mut builder = OneOfBuilder::new(); + for schema in schemas { + builder = builder.item(schema); + } + return RefOr::T(Schema::OneOf(builder.build())); + } + if let Some(all_of) = &subschemas.all_of { + let schemas: Vec> = all_of.iter().map(convert_schema_object).collect(); + let mut all_of = AllOfBuilder::new(); + for schema in schemas { + all_of = all_of.item(schema); + } + return RefOr::T(Schema::AllOf(all_of.build())); + } + if let Some(any_of) = &subschemas.any_of { + let schemas: Vec> = any_of.iter().map(convert_schema_object).collect(); + let mut any_of = AnyOfBuilder::new(); + for schema in schemas { + any_of = any_of.item(schema); + } + return RefOr::T(Schema::AnyOf(any_of.build())); + } + } + + // Handle based on instance type + match &schema.instance_type { + Some(SingleOrVec::Single(instance_type)) => { + convert_single_instance_type(instance_type, schema) + } + Some(SingleOrVec::Vec(instance_types)) => { + // Multiple types - use AnyOf + let schemas: Vec> = instance_types + .iter() + .map(|instance_type| convert_single_instance_type(instance_type, schema)) + .collect(); + let mut any_of = AnyOfBuilder::new(); + for schema in schemas { + any_of = any_of.item(schema); + } + RefOr::T(Schema::AnyOf(any_of.build())) + } + None => { + // No type specified - create a generic schema + RefOr::T(Schema::Object(ObjectBuilder::new().build())) + } + } +} + +fn convert_single_instance_type( + instance_type: &InstanceType, + schema: &SchemaObject, +) -> RefOr { + match instance_type { + InstanceType::Object => { + let mut object_builder = ObjectBuilder::new(); + + if let Some(object_validation) = &schema.object { + // Add properties + for (name, prop_schema) in &object_validation.properties { + let prop = convert_schema_object(prop_schema); + object_builder = object_builder.property(name, prop); + } + + // Add required fields + for required_field in &object_validation.required { + object_builder = object_builder.required(required_field); + } + + // Handle additional properties + if let Some(additional) = &object_validation.additional_properties { + match &**additional { + rmcp::schemars::schema::Schema::Bool(false) => { + object_builder = object_builder + .additional_properties(Some(AdditionalProperties::FreeForm(false))); + } + rmcp::schemars::schema::Schema::Bool(true) => { + object_builder = object_builder + .additional_properties(Some(AdditionalProperties::FreeForm(true))); + } + rmcp::schemars::schema::Schema::Object(obj) => { + let schema = convert_schema_object( + &rmcp::schemars::schema::Schema::Object(obj.clone()), + ); + object_builder = object_builder + .additional_properties(Some(AdditionalProperties::RefOr(schema))); + } + } + } + } + + RefOr::T(Schema::Object(object_builder.build())) + } + InstanceType::Array => { + let mut array_builder = ArrayBuilder::new(); + + if let Some(array_validation) = &schema.array { + // Add items schema + if let Some(items) = &array_validation.items { + match items { + rmcp::schemars::schema::SingleOrVec::Single(item_schema) => { + let item_schema = convert_schema_object(item_schema); + array_builder = array_builder.items(item_schema); + } + rmcp::schemars::schema::SingleOrVec::Vec(item_schemas) => { + // Multiple item types - use AnyOf + let schemas: Vec> = + item_schemas.iter().map(convert_schema_object).collect(); + let mut any_of = AnyOfBuilder::new(); + for schema in schemas { + any_of = any_of.item(schema); + } + let any_of_schema = RefOr::T(Schema::AnyOf(any_of.build())); + array_builder = array_builder.items(any_of_schema); + } + } + } + + // Add constraints + if let Some(min_items) = array_validation.min_items { + array_builder = array_builder.min_items(Some(min_items as usize)); + } + if let Some(max_items) = array_validation.max_items { + array_builder = array_builder.max_items(Some(max_items as usize)); + } + } + + RefOr::T(Schema::Array(array_builder.build())) + } + InstanceType::String => { + let mut object_builder = ObjectBuilder::new().schema_type(SchemaType::String); + + if let Some(string_validation) = &schema.string { + if let Some(min_length) = string_validation.min_length { + object_builder = object_builder.min_length(Some(min_length as usize)); + } + if let Some(max_length) = string_validation.max_length { + object_builder = object_builder.max_length(Some(max_length as usize)); + } + if let Some(pattern) = &string_validation.pattern { + object_builder = object_builder.pattern(Some(pattern.clone())); + } + } + + if let Some(format) = &schema.format { + object_builder = object_builder.format(Some(SchemaFormat::Custom(format.clone()))); + } + + RefOr::T(Schema::Object(object_builder.build())) + } + InstanceType::Number => { + let mut object_builder = ObjectBuilder::new().schema_type(SchemaType::Number); + + if let Some(number_validation) = &schema.number { + if let Some(minimum) = number_validation.minimum { + object_builder = object_builder.minimum(Some(minimum)); + } + if let Some(maximum) = number_validation.maximum { + object_builder = object_builder.maximum(Some(maximum)); + } + if let Some(exclusive_minimum) = number_validation.exclusive_minimum { + object_builder = object_builder.exclusive_minimum(Some(exclusive_minimum)); + } + if let Some(exclusive_maximum) = number_validation.exclusive_maximum { + object_builder = object_builder.exclusive_maximum(Some(exclusive_maximum)); + } + if let Some(multiple_of) = number_validation.multiple_of { + object_builder = object_builder.multiple_of(Some(multiple_of)); + } + } + + RefOr::T(Schema::Object(object_builder.build())) + } + InstanceType::Integer => { + let mut object_builder = ObjectBuilder::new().schema_type(SchemaType::Integer); + + if let Some(number_validation) = &schema.number { + if let Some(minimum) = number_validation.minimum { + object_builder = object_builder.minimum(Some(minimum)); + } + if let Some(maximum) = number_validation.maximum { + object_builder = object_builder.maximum(Some(maximum)); + } + if let Some(exclusive_minimum) = number_validation.exclusive_minimum { + object_builder = object_builder.exclusive_minimum(Some(exclusive_minimum)); + } + if let Some(exclusive_maximum) = number_validation.exclusive_maximum { + object_builder = object_builder.exclusive_maximum(Some(exclusive_maximum)); + } + if let Some(multiple_of) = number_validation.multiple_of { + object_builder = object_builder.multiple_of(Some(multiple_of)); + } + } + + RefOr::T(Schema::Object(object_builder.build())) + } + InstanceType::Boolean => RefOr::T(Schema::Object( + ObjectBuilder::new() + .schema_type(SchemaType::Boolean) + .build(), + )), + InstanceType::Null => RefOr::T(Schema::Object( + ObjectBuilder::new().schema_type(SchemaType::String).build(), + )), + } +} + +derive_utoipa!(Role as RoleSchema); +derive_utoipa!(Content as ContentSchema); +derive_utoipa!(EmbeddedResource as EmbeddedResourceSchema); +derive_utoipa!(ImageContent as ImageContentSchema); +derive_utoipa!(TextContent as TextContentSchema); +derive_utoipa!(Annotations as AnnotationsSchema); #[allow(dead_code)] // Used by utoipa for OpenAPI generation #[derive(OpenApi)] @@ -68,11 +337,11 @@ use utoipa::OpenApi; super::routes::session::SessionHistoryResponse, Message, MessageContent, - Content, - EmbeddedResource, - ImageContent, - Annotations, - TextContent, + ContentSchema, + EmbeddedResourceSchema, + ImageContentSchema, + AnnotationsSchema, + TextContentSchema, ToolResponse, ToolRequest, ToolResultSchema, @@ -83,7 +352,7 @@ use utoipa::OpenApi; ResourceContents, ContextLengthExceeded, SummarizationRequested, - Role, + RoleSchema, ProviderMetadata, ExtensionEntry, ExtensionConfig, diff --git a/crates/goose-server/src/routes/mod.rs b/crates/goose-server/src/routes/mod.rs index c5e662ec2c16..b757baa0306d 100644 --- a/crates/goose-server/src/routes/mod.rs +++ b/crates/goose-server/src/routes/mod.rs @@ -5,6 +5,7 @@ pub mod config_management; pub mod context; pub mod extension; pub mod health; +pub mod project; pub mod recipe; pub mod reply; pub mod schedule; @@ -27,4 +28,5 @@ pub fn configure(state: Arc) -> Router { .merge(recipe::routes(state.clone())) .merge(session::routes(state.clone())) .merge(schedule::routes(state.clone())) + .merge(project::routes(state.clone())) } diff --git a/crates/goose-server/src/routes/project.rs b/crates/goose-server/src/routes/project.rs new file mode 100644 index 000000000000..a83c1a0101e4 --- /dev/null +++ b/crates/goose-server/src/routes/project.rs @@ -0,0 +1,358 @@ +use super::utils::verify_secret_key; +use std::sync::Arc; + +use crate::state::AppState; +use axum::{ + extract::{Path, State}, + http::{HeaderMap, StatusCode}, + routing::{delete, get, post, put}, + Json, Router, +}; +use goose::project::{Project, ProjectMetadata}; +use serde::{Deserialize, Serialize}; +use utoipa::ToSchema; + +#[derive(Deserialize, ToSchema)] +#[serde(rename_all = "camelCase")] +pub struct CreateProjectRequest { + /// Display name of the project + pub name: String, + /// Optional description of the project + pub description: Option, + /// Default working directory for sessions in this project + #[schema(value_type = String)] + pub default_directory: std::path::PathBuf, +} + +#[derive(Deserialize, ToSchema)] +#[serde(rename_all = "camelCase")] +pub struct UpdateProjectRequest { + /// Display name of the project + pub name: Option, + /// Optional description of the project + pub description: Option>, + /// Default working directory for sessions in this project + #[schema(value_type = String)] + pub default_directory: Option, +} + +#[derive(Serialize, ToSchema)] +#[serde(rename_all = "camelCase")] +pub struct ProjectListResponse { + /// List of available project metadata objects + pub projects: Vec, +} + +#[derive(Serialize, ToSchema)] +#[serde(rename_all = "camelCase")] +pub struct ProjectResponse { + /// Project details + pub project: Project, +} + +#[utoipa::path( + get, + path = "/projects", + responses( + (status = 200, description = "List of available projects retrieved successfully", body = ProjectListResponse), + (status = 401, description = "Unauthorized - Invalid or missing API key"), + (status = 500, description = "Internal server error") + ), + security( + ("api_key" = []) + ), + tag = "Project Management" +)] +// List all available projects +async fn list_projects( + State(state): State>, + headers: HeaderMap, +) -> Result, StatusCode> { + verify_secret_key(&headers, &state)?; + + let projects = + goose::project::list_projects().map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; + + Ok(Json(ProjectListResponse { projects })) +} + +#[utoipa::path( + get, + path = "/projects/{project_id}", + params( + ("project_id" = String, Path, description = "Unique identifier for the project") + ), + responses( + (status = 200, description = "Project details retrieved successfully", body = ProjectResponse), + (status = 401, description = "Unauthorized - Invalid or missing API key"), + (status = 404, description = "Project not found"), + (status = 500, description = "Internal server error") + ), + security( + ("api_key" = []) + ), + tag = "Project Management" +)] +// Get a specific project details +async fn get_project_details( + State(state): State>, + headers: HeaderMap, + Path(project_id): Path, +) -> Result, StatusCode> { + verify_secret_key(&headers, &state)?; + + let project = goose::project::get_project(&project_id).map_err(|e| { + if e.to_string().contains("not found") { + StatusCode::NOT_FOUND + } else { + StatusCode::INTERNAL_SERVER_ERROR + } + })?; + + Ok(Json(ProjectResponse { project })) +} + +#[utoipa::path( + post, + path = "/projects", + request_body = CreateProjectRequest, + responses( + (status = 201, description = "Project created successfully", body = ProjectResponse), + (status = 401, description = "Unauthorized - Invalid or missing API key"), + (status = 400, description = "Invalid request - Bad input parameters"), + (status = 500, description = "Internal server error") + ), + security( + ("api_key" = []) + ), + tag = "Project Management" +)] +// Create a new project +async fn create_project( + State(state): State>, + headers: HeaderMap, + Json(create_req): Json, +) -> Result, StatusCode> { + verify_secret_key(&headers, &state)?; + + // Validate input + if create_req.name.trim().is_empty() { + return Err(StatusCode::BAD_REQUEST); + } + + let project = goose::project::create_project( + create_req.name, + create_req.description, + create_req.default_directory, + ) + .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; + + Ok(Json(ProjectResponse { project })) +} + +#[utoipa::path( + put, + path = "/projects/{project_id}", + params( + ("project_id" = String, Path, description = "Unique identifier for the project") + ), + request_body = UpdateProjectRequest, + responses( + (status = 200, description = "Project updated successfully", body = ProjectResponse), + (status = 401, description = "Unauthorized - Invalid or missing API key"), + (status = 404, description = "Project not found"), + (status = 500, description = "Internal server error") + ), + security( + ("api_key" = []) + ), + tag = "Project Management" +)] +// Update a project +async fn update_project( + State(state): State>, + headers: HeaderMap, + Path(project_id): Path, + Json(update_req): Json, +) -> Result, StatusCode> { + verify_secret_key(&headers, &state)?; + + let project = goose::project::update_project( + &project_id, + update_req.name, + update_req.description, + update_req.default_directory, + ) + .map_err(|e| { + if e.to_string().contains("not found") { + StatusCode::NOT_FOUND + } else { + StatusCode::INTERNAL_SERVER_ERROR + } + })?; + + Ok(Json(ProjectResponse { project })) +} + +#[utoipa::path( + delete, + path = "/projects/{project_id}", + params( + ("project_id" = String, Path, description = "Unique identifier for the project") + ), + responses( + (status = 204, description = "Project deleted successfully"), + (status = 401, description = "Unauthorized - Invalid or missing API key"), + (status = 404, description = "Project not found"), + (status = 500, description = "Internal server error") + ), + security( + ("api_key" = []) + ), + tag = "Project Management" +)] +// Delete a project +async fn delete_project( + State(state): State>, + headers: HeaderMap, + Path(project_id): Path, +) -> Result { + verify_secret_key(&headers, &state)?; + + goose::project::delete_project(&project_id).map_err(|e| { + if e.to_string().contains("not found") { + StatusCode::NOT_FOUND + } else { + StatusCode::INTERNAL_SERVER_ERROR + } + })?; + + Ok(StatusCode::NO_CONTENT) +} + +#[utoipa::path( + post, + path = "/projects/{project_id}/sessions/{session_id}", + params( + ("project_id" = String, Path, description = "Unique identifier for the project"), + ("session_id" = String, Path, description = "Unique identifier for the session to add") + ), + responses( + (status = 204, description = "Session added to project successfully"), + (status = 401, description = "Unauthorized - Invalid or missing API key"), + (status = 404, description = "Project or session not found"), + (status = 500, description = "Internal server error") + ), + security( + ("api_key" = []) + ), + tag = "Project Management" +)] +// Add session to project +async fn add_session_to_project( + State(state): State>, + headers: HeaderMap, + Path((project_id, session_id)): Path<(String, String)>, +) -> Result { + verify_secret_key(&headers, &state)?; + + // Add the session to project + goose::project::add_session_to_project(&project_id, &session_id).map_err(|e| { + if e.to_string().contains("not found") { + StatusCode::NOT_FOUND + } else { + StatusCode::INTERNAL_SERVER_ERROR + } + })?; + + // Also update session metadata to include the project_id + let session_path = + goose::session::get_path(goose::session::Identifier::Name(session_id.clone())) + .map_err(|_| StatusCode::NOT_FOUND)?; + let mut metadata = + goose::session::read_metadata(&session_path).map_err(|_| StatusCode::NOT_FOUND)?; + metadata.project_id = Some(project_id); + + tokio::task::spawn(async move { + if let Err(e) = goose::session::update_metadata(&session_path, &metadata).await { + tracing::error!("Failed to update session metadata: {}", e); + } + }); + + Ok(StatusCode::NO_CONTENT) +} + +#[utoipa::path( + delete, + path = "/projects/{project_id}/sessions/{session_id}", + params( + ("project_id" = String, Path, description = "Unique identifier for the project"), + ("session_id" = String, Path, description = "Unique identifier for the session to remove") + ), + responses( + (status = 204, description = "Session removed from project successfully"), + (status = 401, description = "Unauthorized - Invalid or missing API key"), + (status = 404, description = "Project or session not found"), + (status = 500, description = "Internal server error") + ), + security( + ("api_key" = []) + ), + tag = "Project Management" +)] +// Remove session from project +async fn remove_session_from_project( + State(state): State>, + headers: HeaderMap, + Path((project_id, session_id)): Path<(String, String)>, +) -> Result { + verify_secret_key(&headers, &state)?; + + // Remove from project + goose::project::remove_session_from_project(&project_id, &session_id).map_err(|e| { + if e.to_string().contains("not found") { + StatusCode::NOT_FOUND + } else { + StatusCode::INTERNAL_SERVER_ERROR + } + })?; + + // Also update session metadata to remove the project_id + let session_path = + goose::session::get_path(goose::session::Identifier::Name(session_id.clone())) + .map_err(|_| StatusCode::NOT_FOUND)?; + let mut metadata = + goose::session::read_metadata(&session_path).map_err(|_| StatusCode::NOT_FOUND)?; + + // Only update if this session was actually in this project + if metadata.project_id.as_deref() == Some(&project_id) { + metadata.project_id = None; + + tokio::task::spawn(async move { + if let Err(e) = goose::session::update_metadata(&session_path, &metadata).await { + tracing::error!("Failed to update session metadata: {}", e); + } + }); + } + + Ok(StatusCode::NO_CONTENT) +} + +// Configure routes for this module +pub fn routes(state: Arc) -> Router { + Router::new() + .route("/projects", get(list_projects)) + .route("/projects", post(create_project)) + .route("/projects/{project_id}", get(get_project_details)) + .route("/projects/{project_id}", put(update_project)) + .route("/projects/{project_id}", delete(delete_project)) + .route( + "/projects/{project_id}/sessions/{session_id}", + post(add_session_to_project), + ) + .route( + "/projects/{project_id}/sessions/{session_id}", + delete(remove_session_from_project), + ) + .with_state(state) +} diff --git a/crates/goose-server/src/routes/providers_and_keys.json b/crates/goose-server/src/routes/providers_and_keys.json index b3a9c8f02885..422856a6d5b0 100644 --- a/crates/goose-server/src/routes/providers_and_keys.json +++ b/crates/goose-server/src/routes/providers_and_keys.json @@ -20,7 +20,7 @@ "gcp_vertex_ai": { "name": "GCP Vertex AI", "description": "Use Vertex AI platform models", - "models": ["claude-3-5-haiku@20241022", "claude-3-5-sonnet@20240620", "claude-3-5-sonnet-v2@20241022", "claude-3-7-sonnet@20250219", "gemini-1.5-pro-002", "gemini-2.0-flash-001", "gemini-2.0-pro-exp-02-05", "gemini-2.5-pro-exp-03-25", "gemini-2.5-flash-preview-05-20", "gemini-2.5-pro-preview-05-06"], + "models": ["claude-3-5-haiku@20241022", "claude-3-5-sonnet@20240620", "claude-3-5-sonnet-v2@20241022", "claude-3-7-sonnet@20250219", "claude-sonnet-4@20250514", "claude-opus-4@20250514", "gemini-1.5-pro-002", "gemini-2.0-flash-001", "gemini-2.0-pro-exp-02-05", "gemini-2.5-pro-exp-03-25", "gemini-2.5-flash-preview-05-20", "gemini-2.5-pro-preview-05-06", "gemini-2.5-flash", "gemini-2.5-pro"], "required_keys": ["GCP_PROJECT_ID", "GCP_LOCATION"] }, "google": { diff --git a/crates/goose-server/src/routes/recipe.rs b/crates/goose-server/src/routes/recipe.rs index d2f2df7b5ca3..871df176b574 100644 --- a/crates/goose-server/src/routes/recipe.rs +++ b/crates/goose-server/src/routes/recipe.rs @@ -3,6 +3,7 @@ use std::sync::Arc; use axum::{extract::State, http::StatusCode, routing::post, Json, Router}; use goose::message::Message; use goose::recipe::Recipe; +use goose::recipe_deeplink; use serde::{Deserialize, Serialize}; use crate::state::AppState; @@ -34,6 +35,26 @@ pub struct CreateRecipeResponse { error: Option, } +#[derive(Debug, Deserialize)] +pub struct EncodeRecipeRequest { + recipe: Recipe, +} + +#[derive(Debug, Serialize)] +pub struct EncodeRecipeResponse { + deeplink: String, +} + +#[derive(Debug, Deserialize)] +pub struct DecodeRecipeRequest { + deeplink: String, +} + +#[derive(Debug, Serialize)] +pub struct DecodeRecipeResponse { + recipe: Recipe, +} + /// Create a Recipe configuration from the current state of an agent async fn create_recipe( State(state): State>, @@ -84,8 +105,70 @@ async fn create_recipe( } } +async fn encode_recipe( + Json(request): Json, +) -> Result, StatusCode> { + match recipe_deeplink::encode(&request.recipe) { + Ok(encoded) => Ok(Json(EncodeRecipeResponse { deeplink: encoded })), + Err(err) => { + tracing::error!("Failed to encode recipe: {}", err); + Err(StatusCode::BAD_REQUEST) + } + } +} + +async fn decode_recipe( + Json(request): Json, +) -> Result, StatusCode> { + match recipe_deeplink::decode(&request.deeplink) { + Ok(recipe) => Ok(Json(DecodeRecipeResponse { recipe })), + Err(err) => { + tracing::error!("Failed to decode deeplink: {}", err); + Err(StatusCode::BAD_REQUEST) + } + } +} + pub fn routes(state: Arc) -> Router { Router::new() .route("/recipe/create", post(create_recipe)) + .route("/recipes/encode", post(encode_recipe)) + .route("/recipes/decode", post(decode_recipe)) .with_state(state) } + +#[cfg(test)] +mod tests { + use super::*; + use goose::recipe::Recipe; + + #[tokio::test] + async fn test_decode_and_encode_recipe() { + let original_recipe = Recipe::builder() + .title("Test Recipe") + .description("A test recipe") + .instructions("Test instructions") + .build() + .unwrap(); + let encoded = recipe_deeplink::encode(&original_recipe).unwrap(); + + let request = DecodeRecipeRequest { + deeplink: encoded.clone(), + }; + let response = decode_recipe(Json(request)).await; + + assert!(response.is_ok()); + let decoded = response.unwrap().0.recipe; + assert_eq!(decoded.title, original_recipe.title); + assert_eq!(decoded.description, original_recipe.description); + assert_eq!(decoded.instructions, original_recipe.instructions); + + let encode_request = EncodeRecipeRequest { recipe: decoded }; + let encode_response = encode_recipe(Json(encode_request)).await; + + assert!(encode_response.is_ok()); + let encoded_again = encode_response.unwrap().0.deeplink; + assert!(!encoded_again.is_empty()); + assert_eq!(encoded, encoded_again); + } +} diff --git a/crates/goose-server/src/routes/reply.rs b/crates/goose-server/src/routes/reply.rs index 55d04d955282..705e88c73668 100644 --- a/crates/goose-server/src/routes/reply.rs +++ b/crates/goose-server/src/routes/reply.rs @@ -18,7 +18,8 @@ use goose::{ permission::{Permission, PermissionConfirmation}, session, }; -use mcp_core::{protocol::JsonRpcMessage, role::Role, Content, ToolResult}; +use mcp_core::{protocol::JsonRpcMessage, role::Role, ToolResult}; +use rmcp::model::Content; use serde::{Deserialize, Serialize}; use serde_json::json; use serde_json::Value; @@ -123,7 +124,7 @@ async fn handler( let stream = ReceiverStream::new(rx); let messages = request.messages; - let session_working_dir = request.session_working_dir; + let session_working_dir = request.session_working_dir.clone(); let session_id = request .session_id @@ -181,7 +182,7 @@ async fn handler( &messages, Some(SessionConfig { id: session::Identifier::Name(session_id.clone()), - working_dir: PathBuf::from(session_working_dir), + working_dir: PathBuf::from(&session_working_dir), schedule_id: request.scheduled_job_id.clone(), execution_mode: None, max_turns: None, @@ -297,8 +298,13 @@ async fn handler( if all_messages.len() > saved_message_count { let provider = Arc::clone(provider.as_ref().unwrap()); tokio::spawn(async move { - if let Err(e) = - session::persist_messages(&session_path, &all_messages, Some(provider)).await + if let Err(e) = session::persist_messages( + &session_path, + &all_messages, + Some(provider), + Some(PathBuf::from(&session_working_dir)), + ) + .await { tracing::error!("Failed to store session history: {:?}", e); } @@ -337,7 +343,7 @@ async fn ask_handler( ) -> Result, StatusCode> { verify_secret_key(&headers, &state)?; - let session_working_dir = request.session_working_dir; + let session_working_dir = request.session_working_dir.clone(); let session_id = request .session_id @@ -358,7 +364,7 @@ async fn ask_handler( &messages, Some(SessionConfig { id: session::Identifier::Name(session_id.clone()), - working_dir: PathBuf::from(session_working_dir), + working_dir: PathBuf::from(&session_working_dir), schedule_id: request.scheduled_job_id.clone(), execution_mode: None, max_turns: None, @@ -420,9 +426,15 @@ async fn ask_handler( let session_path_clone = session_path.clone(); let messages = all_messages.clone(); let provider = Arc::clone(provider.as_ref().unwrap()); + let session_working_dir_clone = session_working_dir.clone(); tokio::spawn(async move { - if let Err(e) = - session::persist_messages(&session_path_clone, &messages, Some(provider)).await + if let Err(e) = session::persist_messages( + &session_path_clone, + &messages, + Some(provider), + Some(PathBuf::from(session_working_dir_clone)), + ) + .await { tracing::error!("Failed to store session history: {:?}", e); } diff --git a/crates/goose-server/src/routes/session.rs b/crates/goose-server/src/routes/session.rs index ca0d59703058..8ed509e46f3d 100644 --- a/crates/goose-server/src/routes/session.rs +++ b/crates/goose-server/src/routes/session.rs @@ -1,4 +1,6 @@ use super::utils::verify_secret_key; +use chrono::{DateTime, Datelike}; +use std::collections::HashMap; use std::sync::Arc; use crate::state::AppState; @@ -13,6 +15,7 @@ use goose::session; use goose::session::info::{get_valid_sorted_sessions, SessionInfo, SortOrder}; use goose::session::SessionMetadata; use serde::Serialize; +use tracing::{error, info}; use utoipa::ToSchema; #[derive(Serialize, ToSchema)] @@ -33,6 +36,29 @@ pub struct SessionHistoryResponse { messages: Vec, } +#[derive(Serialize, ToSchema, Debug)] +#[serde(rename_all = "camelCase")] +pub struct SessionInsights { + /// Total number of sessions + total_sessions: usize, + /// Most active working directories with session counts + most_active_dirs: Vec<(String, usize)>, + /// Average session duration in minutes + avg_session_duration: f64, + /// Total tokens used across all sessions + total_tokens: i64, + /// Activity trend for the last 7 days + recent_activity: Vec<(String, usize)>, +} + +#[derive(Serialize, ToSchema, Debug)] +#[serde(rename_all = "camelCase")] +pub struct ActivityHeatmapCell { + pub week: usize, + pub day: usize, + pub count: usize, +} + #[utoipa::path( get, path = "/sessions", @@ -106,10 +132,174 @@ async fn get_session_history( })) } +#[utoipa::path( + get, + path = "/sessions/insights", + responses( + (status = 200, description = "Session insights retrieved successfully", body = SessionInsights), + (status = 401, description = "Unauthorized - Invalid or missing API key"), + (status = 500, description = "Internal server error") + ), + security( + ("api_key" = []) + ), + tag = "Session Management" +)] +async fn get_session_insights( + State(state): State>, + headers: HeaderMap, +) -> Result, StatusCode> { + info!("Received request for session insights"); + + verify_secret_key(&headers, &state)?; + + let sessions = get_valid_sorted_sessions(SortOrder::Descending).map_err(|e| { + error!("Failed to get session info: {:?}", e); + StatusCode::INTERNAL_SERVER_ERROR + })?; + + // Filter out sessions without descriptions + let sessions: Vec = sessions + .into_iter() + .filter(|session| !session.metadata.description.is_empty()) + .collect(); + + info!("Found {} sessions with descriptions", sessions.len()); + + // Calculate insights + let total_sessions = sessions.len(); + + // Debug: Log if we have very few sessions, which might indicate filtering issues + if total_sessions == 0 { + info!("Warning: No sessions found with descriptions"); + } + + // Track directory usage + let mut dir_counts: HashMap = HashMap::new(); + let mut total_duration = 0.0; + let mut total_tokens = 0; + let mut activity_by_date: HashMap = HashMap::new(); + + for session in &sessions { + // Track directory usage + let dir = session.metadata.working_dir.to_string_lossy().to_string(); + *dir_counts.entry(dir).or_insert(0) += 1; + + // Track tokens - only add positive values to prevent negative totals + if let Some(tokens) = session.metadata.accumulated_total_tokens { + if tokens > 0 { + total_tokens += tokens as i64; + } else if tokens < 0 { + // Log negative token values for debugging + info!( + "Warning: Session {} has negative accumulated_total_tokens: {}", + session.id, tokens + ); + } + } + + // Track activity by date + if let Ok(date) = DateTime::parse_from_str(&session.modified, "%Y-%m-%d %H:%M:%S UTC") { + let date_str = date.format("%Y-%m-%d").to_string(); + *activity_by_date.entry(date_str).or_insert(0) += 1; + } + + // Calculate session duration from messages + let session_path = session::get_path(session::Identifier::Name(session.id.clone())); + if let Ok(session_path) = session_path { + if let Ok(messages) = session::read_messages(&session_path) { + if let (Some(first), Some(last)) = (messages.first(), messages.last()) { + let duration = (last.created - first.created) as f64 / 60.0; // Convert to minutes + total_duration += duration; + } + } + } + } + + // Get top 3 most active directories + let mut dir_vec: Vec<(String, usize)> = dir_counts.into_iter().collect(); + dir_vec.sort_by(|a, b| b.1.cmp(&a.1)); + let most_active_dirs = dir_vec.into_iter().take(3).collect(); + + // Calculate average session duration + let avg_session_duration = if total_sessions > 0 { + total_duration / total_sessions as f64 + } else { + 0.0 + }; + + // Get last 7 days of activity + let mut activity_vec: Vec<(String, usize)> = activity_by_date.into_iter().collect(); + activity_vec.sort_by(|a, b| b.0.cmp(&a.0)); // Sort by date descending + let recent_activity = activity_vec.into_iter().take(7).collect(); + + let insights = SessionInsights { + total_sessions, + most_active_dirs, + avg_session_duration, + total_tokens, + recent_activity, + }; + + info!("Returning insights: {:?}", insights); + Ok(Json(insights)) +} + +#[utoipa::path( + get, + path = "/sessions/activity-heatmap", + responses( + (status = 200, description = "Activity heatmap data", body = [ActivityHeatmapCell]), + (status = 401, description = "Unauthorized - Invalid or missing API key"), + (status = 500, description = "Internal server error") + ), + security(("api_key" = [])), + tag = "Session Management" +)] +async fn get_activity_heatmap( + State(state): State>, + headers: HeaderMap, +) -> Result>, StatusCode> { + verify_secret_key(&headers, &state)?; + + let sessions = get_valid_sorted_sessions(SortOrder::Descending) + .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; + + // Only sessions with a description + let sessions: Vec = sessions + .into_iter() + .filter(|session| !session.metadata.description.is_empty()) + .collect(); + + // Map: (week, day) -> count + let mut heatmap: std::collections::HashMap<(usize, usize), usize> = + std::collections::HashMap::new(); + + for session in &sessions { + if let Ok(date) = + chrono::NaiveDateTime::parse_from_str(&session.modified, "%Y-%m-%d %H:%M:%S UTC") + { + let date = date.date(); + let week = date.iso_week().week() as usize - 1; // 0-based week + let day = date.weekday().num_days_from_sunday() as usize; // 0=Sun, 6=Sat + *heatmap.entry((week, day)).or_insert(0) += 1; + } + } + + let mut result = Vec::new(); + for ((week, day), count) in heatmap { + result.push(ActivityHeatmapCell { week, day, count }); + } + + Ok(Json(result)) +} + // Configure routes for this module pub fn routes(state: Arc) -> Router { Router::new() .route("/sessions", get(list_sessions)) .route("/sessions/{session_id}", get(get_session_history)) + .route("/sessions/insights", get(get_session_insights)) + .route("/sessions/activity-heatmap", get(get_activity_heatmap)) .with_state(state) } diff --git a/crates/goose/Cargo.toml b/crates/goose/Cargo.toml index 8761857096a1..7b694cb18723 100644 --- a/crates/goose/Cargo.toml +++ b/crates/goose/Cargo.toml @@ -17,6 +17,7 @@ reqwest = { version = "0.12.9", features = ["json", "rustls-tls-native-roots"], [dependencies] mcp-client = { path = "../mcp-client" } mcp-core = { path = "../mcp-core" } +rmcp = { workspace = true } anyhow = "1.0" thiserror = "1.0" futures = "0.3" @@ -43,7 +44,7 @@ uuid = { version = "1.0", features = ["v4"] } regex = "1.11.1" async-trait = "0.1" async-stream = "0.3" -minijinja = "2.8.0" +minijinja = { version = "2.10.2", features = ["loader"] } include_dir = "0.7.4" tiktoken-rs = "0.6.0" chrono = { version = "0.4.38", features = ["serde"] } @@ -52,6 +53,7 @@ nanoid = "0.4" sha2 = "0.10" base64 = "0.21" url = "2.5" +urlencoding = "2.1" axum = "0.8.1" webbrowser = "0.8" lazy_static = "1.5.0" diff --git a/crates/goose/examples/image_tool.rs b/crates/goose/examples/image_tool.rs index 24a75a745f9e..3c860048a1c9 100644 --- a/crates/goose/examples/image_tool.rs +++ b/crates/goose/examples/image_tool.rs @@ -5,10 +5,8 @@ use goose::{ message::Message, providers::{bedrock::BedrockProvider, databricks::DatabricksProvider, openai::OpenAiProvider}, }; -use mcp_core::{ - content::Content, - tool::{Tool, ToolCall}, -}; +use mcp_core::tool::{Tool, ToolCall}; +use rmcp::model::Content; use serde_json::json; use std::fs; diff --git a/crates/goose/src/agents/agent.rs b/crates/goose/src/agents/agent.rs index 3adee02309a0..f74823473c5e 100644 --- a/crates/goose/src/agents/agent.rs +++ b/crates/goose/src/agents/agent.rs @@ -9,10 +9,14 @@ use futures::{stream, FutureExt, Stream, StreamExt, TryStreamExt}; use mcp_core::protocol::JsonRpcMessage; use crate::agents::final_output_tool::{FINAL_OUTPUT_CONTINUATION_MESSAGE, FINAL_OUTPUT_TOOL_NAME}; -use crate::agents::sub_recipe_execution_tool::sub_recipe_execute_task_tool::{ - self, SUB_RECIPE_EXECUTE_TASK_TOOL_NAME, +use crate::agents::recipe_tools::dynamic_task_tools::{ + create_dynamic_task, create_dynamic_task_tool, DYNAMIC_TASK_TOOL_NAME_PREFIX, }; use crate::agents::sub_recipe_manager::SubRecipeManager; +use crate::agents::subagent_execution_tool::subagent_execute_task_tool::{ + self, SUBAGENT_EXECUTE_TASK_TOOL_NAME, +}; +use crate::agents::subagent_execution_tool::tasks_manager::TasksManager; use crate::config::{Config, ExtensionConfigManager, PermissionManager}; use crate::message::{push_message, Message}; use crate::permission::permission_judge::check_tool_permissions; @@ -43,26 +47,23 @@ use crate::agents::tool_router_index_manager::ToolRouterIndexManager; use crate::agents::tool_vectordb::generate_table_id; use crate::agents::types::SessionConfig; use crate::agents::types::{FrontendTool, ToolResultReceiver}; -use mcp_core::{ - prompt::Prompt, protocol::GetPromptResult, tool::Tool, Content, ToolError, ToolResult, -}; - -use crate::agents::subagent_tools::SUBAGENT_RUN_TASK_TOOL_NAME; +use mcp_core::{prompt::Prompt, protocol::GetPromptResult, tool::Tool, ToolError, ToolResult}; +use rmcp::model::Content; use super::final_output_tool::FinalOutputTool; use super::platform_tools; use super::router_tools; -use super::subagent_manager::SubAgentManager; -use super::subagent_tools; use super::tool_execution::{ToolCallResult, CHAT_MODE_TOOL_SKIPPED_RESPONSE, DECLINED_RESPONSE}; +use crate::agents::subagent_task_config::TaskConfig; const DEFAULT_MAX_TURNS: u32 = 1000; /// The main goose Agent pub struct Agent { pub(super) provider: Mutex>>, - pub(super) extension_manager: RwLock, + pub(super) extension_manager: Arc>, pub(super) sub_recipe_manager: Mutex, + pub(super) tasks_manager: TasksManager, pub(super) final_output_tool: Mutex>, pub(super) frontend_tools: Mutex>, pub(super) frontend_instructions: Mutex>, @@ -74,7 +75,7 @@ pub struct Agent { pub(super) tool_monitor: Mutex>, pub(super) router_tool_selector: Mutex>>>, pub(super) scheduler_service: Mutex>>, - pub(super) subagent_manager: Mutex>, + pub(super) mcp_tx: Mutex>, pub(super) mcp_notification_rx: Arc>>, } @@ -135,8 +136,9 @@ impl Agent { Self { provider: Mutex::new(None), - extension_manager: RwLock::new(ExtensionManager::new()), + extension_manager: Arc::new(RwLock::new(ExtensionManager::new())), sub_recipe_manager: Mutex::new(SubRecipeManager::new()), + tasks_manager: TasksManager::new(), final_output_tool: Mutex::new(None), frontend_tools: Mutex::new(HashMap::new()), frontend_instructions: Mutex::new(None), @@ -149,7 +151,7 @@ impl Agent { router_tool_selector: Mutex::new(None), scheduler_service: Mutex::new(None), // Initialize with MCP notification support - subagent_manager: Mutex::new(Some(SubAgentManager::new(mcp_tx))), + mcp_tx: Mutex::new(mcp_tx), mcp_notification_rx: Arc::new(Mutex::new(mcp_rx)), } } @@ -291,10 +293,26 @@ impl Agent { let sub_recipe_manager = self.sub_recipe_manager.lock().await; let result: ToolCallResult = if sub_recipe_manager.is_sub_recipe_tool(&tool_call.name) { sub_recipe_manager - .dispatch_sub_recipe_tool_call(&tool_call.name, tool_call.arguments.clone()) + .dispatch_sub_recipe_tool_call( + &tool_call.name, + tool_call.arguments.clone(), + &self.tasks_manager, + ) .await - } else if tool_call.name == SUB_RECIPE_EXECUTE_TASK_TOOL_NAME { - sub_recipe_execute_task_tool::run_tasks(tool_call.arguments.clone()).await + } else if tool_call.name == SUBAGENT_EXECUTE_TASK_TOOL_NAME { + let provider = self.provider().await.ok(); + let mcp_tx = self.mcp_tx.lock().await.clone(); + + let task_config = + TaskConfig::new(provider, Some(Arc::clone(&self.extension_manager)), mcp_tx); + subagent_execute_task_tool::run_tasks( + tool_call.arguments.clone(), + task_config, + &self.tasks_manager, + ) + .await + } else if tool_call.name == DYNAMIC_TASK_TOOL_NAME_PREFIX { + create_dynamic_task(tool_call.arguments.clone(), &self.tasks_manager).await } else if tool_call.name == PLATFORM_READ_RESOURCE_TOOL_NAME { // Check if the tool is read_resource and handle it separately ToolCallResult::from( @@ -310,11 +328,6 @@ impl Agent { ) } else if tool_call.name == PLATFORM_SEARCH_AVAILABLE_EXTENSIONS_TOOL_NAME { ToolCallResult::from(extension_manager.search_available_extensions().await) - } else if tool_call.name == SUBAGENT_RUN_TASK_TOOL_NAME { - ToolCallResult::from( - self.handle_run_subagent_task(tool_call.arguments.clone()) - .await, - ) } else if self.is_frontend_tool(&tool_call.name).await { // For frontend tools, return an error indicating we need frontend execution ToolCallResult::from(Err(ToolError::ExecutionError( @@ -556,11 +569,8 @@ impl Agent { platform_tools::manage_schedule_tool(), ]); - // Add subagent tool (only if ALPHA_FEATURES is enabled) - let config = Config::global(); - if config.get_param::("ALPHA_FEATURES").unwrap_or(false) { - prefixed_tools.push(subagent_tools::run_task_subagent_tool()); - } + // Dynamic task tool + prefixed_tools.push(create_dynamic_task_tool()); // Add resource tools if supported if extension_manager.supports_resources() { @@ -578,8 +588,7 @@ impl Agent { if let Some(final_output_tool) = self.final_output_tool.lock().await.as_ref() { prefixed_tools.push(final_output_tool.tool()); } - prefixed_tools - .push(sub_recipe_execute_task_tool::create_sub_recipe_execute_task_tool()); + prefixed_tools.push(subagent_execute_task_tool::create_subagent_execute_task_tool()); } prefixed_tools @@ -837,24 +846,6 @@ impl Agent { let num_tool_requests = frontend_requests.len() + remaining_requests.len(); if num_tool_requests == 0 { - if let Some(final_output_tool) = self.final_output_tool.lock().await.as_ref() { - if final_output_tool.final_output.is_none() { - tracing::warn!("Final output tool has not been called yet. Continuing agent loop."); - let message = Message::assistant().with_text(FINAL_OUTPUT_CONTINUATION_MESSAGE); - messages.push(message.clone()); - yield AgentEvent::Message(message); - continue; - } else { - let message = Message::assistant().with_text(final_output_tool.final_output.clone().unwrap()); - messages.push(message.clone()); - yield AgentEvent::Message(message); - // Set added_message to true and continue to end the current iteration - added_message = true; - push_message(&mut messages, response); - continue; - } - } - // If there's no final output tool and no tool requests, continue the loop continue; } @@ -1039,10 +1030,14 @@ impl Agent { if let Some(final_output_tool) = self.final_output_tool.lock().await.as_ref() { if final_output_tool.final_output.is_none() { tracing::warn!("Final output tool has not been called yet. Continuing agent loop."); - yield AgentEvent::Message(Message::user().with_text(FINAL_OUTPUT_CONTINUATION_MESSAGE)); + let message = Message::user().with_text(FINAL_OUTPUT_CONTINUATION_MESSAGE); + messages.push(message.clone()); + yield AgentEvent::Message(message); continue; } else { - yield AgentEvent::Message(Message::assistant().with_text(final_output_tool.final_output.clone().unwrap())); + let message = Message::assistant().with_text(final_output_tool.final_output.clone().unwrap()); + messages.push(message.clone()); + yield AgentEvent::Message(message); } } break; @@ -1077,15 +1072,6 @@ impl Agent { let mut current_provider = self.provider.lock().await; *current_provider = Some(provider.clone()); - // Initialize subagent manager with MCP notification support - // Need to recreate the MCP channel since we're replacing the manager - let (mcp_tx, mcp_rx) = mpsc::channel(100); - { - let mut rx_guard = self.mcp_notification_rx.lock().await; - *rx_guard = mcp_rx; - } - *self.subagent_manager.lock().await = Some(SubAgentManager::new(mcp_tx)); - self.update_router_tool_selector(Some(provider), None) .await?; Ok(()) diff --git a/crates/goose/src/agents/extension_manager.rs b/crates/goose/src/agents/extension_manager.rs index a418fe112a9e..d03c3bff158f 100644 --- a/crates/goose/src/agents/extension_manager.rs +++ b/crates/goose/src/agents/extension_manager.rs @@ -19,7 +19,8 @@ use crate::config::{Config, ExtensionConfigManager}; use crate::prompt_template; use mcp_client::client::{ClientCapabilities, ClientInfo, McpClient, McpClientTrait}; use mcp_client::transport::{SseTransport, StdioTransport, StreamableHttpTransport, Transport}; -use mcp_core::{prompt::Prompt, Content, Tool, ToolCall, ToolError}; +use mcp_core::{prompt::Prompt, Tool, ToolCall, ToolError}; +use rmcp::model::Content; use serde_json::Value; // By default, we set it to Jan 1, 2020 if the resource does not have a timestamp diff --git a/crates/goose/src/agents/final_output_tool.rs b/crates/goose/src/agents/final_output_tool.rs index 3ada87c86e45..0c2e779b152f 100644 --- a/crates/goose/src/agents/final_output_tool.rs +++ b/crates/goose/src/agents/final_output_tool.rs @@ -3,13 +3,14 @@ use crate::recipe::Response; use indoc::formatdoc; use mcp_core::{ tool::{Tool, ToolAnnotations}, - Content, ToolCall, ToolError, + ToolCall, ToolError, }; +use rmcp::model::Content; use serde_json::Value; pub const FINAL_OUTPUT_TOOL_NAME: &str = "recipe__final_output"; pub const FINAL_OUTPUT_CONTINUATION_MESSAGE: &str = - "I see I MUST call the `final_output` tool NOW with the final output for the user."; + "You MUST call the `final_output` tool NOW with the final output for the user."; pub struct FinalOutputTool { pub response: Response, diff --git a/crates/goose/src/agents/large_response_handler.rs b/crates/goose/src/agents/large_response_handler.rs index e4c0ab105544..ff8066215874 100644 --- a/crates/goose/src/agents/large_response_handler.rs +++ b/crates/goose/src/agents/large_response_handler.rs @@ -1,5 +1,6 @@ use chrono::Utc; -use mcp_core::{Content, ToolError}; +use mcp_core::ToolError; +use rmcp::model::Content; use std::fs::File; use std::io::Write; @@ -14,17 +15,17 @@ pub fn process_tool_response( let mut processed_contents = Vec::new(); for content in contents { - match content { - Content::Text(text_content) => { + match content.as_text() { + Some(text_content) => { // Check if text exceeds threshold - if text_content.text.len() > LARGE_TEXT_THRESHOLD { + if text_content.text.chars().count() > LARGE_TEXT_THRESHOLD { // Write to temp file match write_large_text_to_file(&text_content.text) { Ok(file_path) => { // Create a new text content with reference to the file let message = format!( "The response returned from the tool call was larger ({} characters) and is stored in the file which you can use other tools to examine or search in: {}", - text_content.text.len(), + text_content.text.chars().count(), file_path ); processed_contents.push(Content::text(message)); @@ -41,11 +42,13 @@ pub fn process_tool_response( } } else { // Keep original content for smaller texts - processed_contents.push(Content::Text(text_content)); + processed_contents.push(content); } } - // Pass through other content types unchanged - _ => processed_contents.push(content), + None => { + // Pass through other content types unchanged + processed_contents.push(content); + } } } @@ -76,7 +79,8 @@ fn write_large_text_to_file(content: &str) -> Result { #[cfg(test)] mod tests { use super::*; - use mcp_core::{Content, ImageContent, TextContent, ToolError}; + use mcp_core::ToolError; + use rmcp::model::Content; use std::fs; use std::path::Path; @@ -84,10 +88,7 @@ mod tests { fn test_small_text_response_passes_through() { // Create a small text response let small_text = "This is a small text response"; - let content = Content::Text(TextContent { - text: small_text.to_string(), - annotations: None, - }); + let content = Content::text(small_text.to_string()); let response = Ok(vec![content]); @@ -96,7 +97,7 @@ mod tests { // Verify the response is unchanged assert_eq!(processed.len(), 1); - if let Content::Text(text_content) = &processed[0] { + if let Some(text_content) = processed[0].as_text() { assert_eq!(text_content.text, small_text); } else { panic!("Expected text content"); @@ -107,10 +108,7 @@ mod tests { fn test_large_text_response_redirected_to_file() { // Create a text larger than the threshold let large_text = "a".repeat(LARGE_TEXT_THRESHOLD + 1000); - let content = Content::Text(TextContent { - text: large_text.clone(), - annotations: None, - }); + let content = Content::text(large_text.clone()); let response = Ok(vec![content]); @@ -119,7 +117,7 @@ mod tests { // Verify the response contains a message about the file assert_eq!(processed.len(), 1); - if let Content::Text(text_content) = &processed[0] { + if let Some(text_content) = processed[0].as_text() { assert!(text_content .text .contains("The response returned from the tool call was larger")); @@ -147,11 +145,7 @@ mod tests { #[test] fn test_image_content_passes_through() { // Create an image content - let image_content = Content::Image(ImageContent { - data: "base64data".to_string(), - mime_type: "image/png".to_string(), - annotations: None, - }); + let image_content = Content::image("base64data".to_string(), "image/png".to_string()); let response = Ok(vec![image_content]); @@ -160,12 +154,11 @@ mod tests { // Verify the response is unchanged assert_eq!(processed.len(), 1); - match &processed[0] { - Content::Image(img) => { - assert_eq!(img.data, "base64data"); - assert_eq!(img.mime_type, "image/png"); - } - _ => panic!("Expected image content"), + if let Some(img) = processed[0].as_image() { + assert_eq!(img.data, "base64data"); + assert_eq!(img.mime_type, "image/png"); + } else { + panic!("Expected image content"); } } @@ -173,15 +166,8 @@ mod tests { fn test_mixed_content_handled_correctly() { // Create a response with mixed content types let small_text = Content::text("Small text"); - let large_text = Content::Text(TextContent { - text: "a".repeat(LARGE_TEXT_THRESHOLD + 1000), - annotations: None, - }); - let image = Content::Image(ImageContent { - data: "image_data".to_string(), - mime_type: "image/jpeg".to_string(), - annotations: None, - }); + let large_text = Content::text("a".repeat(LARGE_TEXT_THRESHOLD + 1000)); + let image = Content::image("image_data".to_string(), "image/jpeg".to_string()); let response = Ok(vec![small_text, large_text, image]); @@ -192,14 +178,14 @@ mod tests { assert_eq!(processed.len(), 3); // First item should be unchanged small text - if let Content::Text(text_content) = &processed[0] { + if let Some(text_content) = processed[0].as_text() { assert_eq!(text_content.text, "Small text"); } else { panic!("Expected text content"); } // Second item should be a message about the file - if let Content::Text(text_content) = &processed[1] { + if let Some(text_content) = processed[1].as_text() { assert!(text_content .text .contains("The response returned from the tool call was larger")); @@ -216,12 +202,11 @@ mod tests { } // Third item should be unchanged image - match &processed[2] { - Content::Image(img) => { - assert_eq!(img.data, "image_data"); - assert_eq!(img.mime_type, "image/jpeg"); - } - _ => panic!("Expected image content"), + if let Some(img) = processed[2].as_image() { + assert_eq!(img.data, "image_data"); + assert_eq!(img.mime_type, "image/jpeg"); + } else { + panic!("Expected image content"); } } diff --git a/crates/goose/src/agents/mod.rs b/crates/goose/src/agents/mod.rs index 353e57acde12..ffcc2b9ceffa 100644 --- a/crates/goose/src/agents/mod.rs +++ b/crates/goose/src/agents/mod.rs @@ -11,13 +11,11 @@ mod reply_parts; mod router_tool_selector; mod router_tools; mod schedule_tool; -pub mod sub_recipe_execution_tool; pub mod sub_recipe_manager; pub mod subagent; +pub mod subagent_execution_tool; pub mod subagent_handler; -pub mod subagent_manager; -pub mod subagent_tools; -pub mod subagent_types; +mod subagent_task_config; mod tool_execution; mod tool_router_index_manager; pub(crate) mod tool_vectordb; @@ -27,7 +25,6 @@ pub use agent::{Agent, AgentEvent}; pub use extension::ExtensionConfig; pub use extension_manager::ExtensionManager; pub use prompt_manager::PromptManager; -pub use subagent::{SubAgent, SubAgentConfig, SubAgentProgress, SubAgentStatus}; -pub use subagent_manager::SubAgentManager; -pub use subagent_types::SpawnSubAgentArgs; +pub use subagent::{SubAgent, SubAgentProgress, SubAgentStatus}; +pub use subagent_task_config::TaskConfig; pub use types::{FrontendTool, SessionConfig}; diff --git a/crates/goose/src/agents/recipe_tools/dynamic_task_tools.rs b/crates/goose/src/agents/recipe_tools/dynamic_task_tools.rs new file mode 100644 index 000000000000..e4705e762786 --- /dev/null +++ b/crates/goose/src/agents/recipe_tools/dynamic_task_tools.rs @@ -0,0 +1,148 @@ +// ======================================= +// Module: Dynamic Task Tools +// Handles creation of tasks dynamically without sub-recipes +// ======================================= +use crate::agents::subagent_execution_tool::tasks_manager::TasksManager; +use crate::agents::subagent_execution_tool::{lib::ExecutionMode, task_types::Task}; +use crate::agents::tool_execution::ToolCallResult; +use mcp_core::{tool::ToolAnnotations, Tool, ToolError}; +use rmcp::model::Content; +use serde_json::{json, Value}; + +pub const DYNAMIC_TASK_TOOL_NAME_PREFIX: &str = "dynamic_task__create_task"; + +pub fn create_dynamic_task_tool() -> Tool { + Tool::new( + DYNAMIC_TASK_TOOL_NAME_PREFIX.to_string(), + "Use this tool to create one or more dynamic tasks from a shared text instruction and varying parameters.\ + How it works: + - Provide a single text instruction + - Use the 'task_parameters' field to pass an array of parameter sets + - Each resulting task will use the same instruction with different parameter values + This is useful when performing the same operation across many inputs (e.g., getting weather for multiple cities, searching multiple slack channels, iterating through various linear tickets, etc). + Once created, these tasks should be passed to the 'subagent__execute_task' tool for execution. Tasks can run sequentially or in parallel. + --- + What is a 'subagent'? + A 'subagent' is a stateless sub-process that executes a single task independently. Use subagents when: + - You want to parallelize similar work across different inputs + - You are not sure your search or operation will succeed on the first try + Each subagent receives a task with a defined payload and returns a result, which is not visible to the user unless explicitly summarized by the system. + --- + Examples of 'task_parameters' for a single task: + text_instruction: Search for the config file in the root directory. + Examples of 'task_parameters' for multiple tasks: + text_instruction: Get weather for Melbourne. + timeout_seconds: 300 + text_instruction: Get weather for Los Angeles. + timeout_seconds: 300 + text_instruction: Get weather for San Francisco. + timeout_seconds: 300 + ".to_string(), + json!({ + "type": "object", + "properties": { + "task_parameters": { + "type": "array", + "description": "Array of parameter sets for creating tasks. \ + For a single task, provide an array with one element. \ + For multiple tasks, provide an array with multiple elements, each with different parameter values. \ + If there is no parameter set, provide an empty array.", + "items": { + "type": "object", + "properties": { + "text_instruction": { + "type": "string", + "description": "The text instruction to execute" + }, + "timeout_seconds": { + "type": "integer", + "description": "Optional timeout for the task in seconds (default: 300)", + "minimum": 1 + } + }, + "required": ["text_instruction"] + } + } + } + }), + Some(ToolAnnotations { + title: Some("Dynamic Task Creation".to_string()), + read_only_hint: false, + destructive_hint: true, + idempotent_hint: false, + open_world_hint: true, + }), + ) +} + +fn extract_task_parameters(params: &Value) -> Vec { + params + .get("task_parameters") + .and_then(|v| v.as_array()) + .cloned() + .unwrap_or_default() +} + +fn create_text_instruction_tasks_from_params(task_params: &[Value]) -> Vec { + task_params + .iter() + .map(|task_param| { + let text_instruction = task_param + .get("text_instruction") + .and_then(|v| v.as_str()) + .unwrap_or("") + .to_string(); + + let payload = json!({ + "text_instruction": text_instruction + }); + + Task { + id: uuid::Uuid::new_v4().to_string(), + task_type: "text_instruction".to_string(), + payload, + } + }) + .collect() +} + +fn create_task_execution_payload(tasks: Vec, execution_mode: ExecutionMode) -> Value { + let task_ids: Vec = tasks.iter().map(|task| task.id.clone()).collect(); + json!({ + "task_ids": task_ids, + "execution_mode": execution_mode + }) +} + +pub async fn create_dynamic_task(params: Value, tasks_manager: &TasksManager) -> ToolCallResult { + let task_params_array = extract_task_parameters(¶ms); + + if task_params_array.is_empty() { + return ToolCallResult::from(Err(ToolError::ExecutionError( + "No task parameters provided".to_string(), + ))); + } + + let tasks = create_text_instruction_tasks_from_params(&task_params_array); + + // Use parallel execution if there are multiple tasks, sequential for single task + let execution_mode = if tasks.len() > 1 { + ExecutionMode::Parallel + } else { + ExecutionMode::Sequential + }; + + let task_execution_payload = create_task_execution_payload(tasks.clone(), execution_mode); + + let tasks_json = match serde_json::to_string(&task_execution_payload) { + Ok(json) => json, + Err(e) => { + return ToolCallResult::from(Err(ToolError::ExecutionError(format!( + "Failed to serialize task list: {}", + e + )))) + } + }; + tasks_manager.save_tasks(tasks.clone()).await; + ToolCallResult::from(Ok(vec![Content::text(tasks_json)])) +} diff --git a/crates/goose/src/agents/recipe_tools/mod.rs b/crates/goose/src/agents/recipe_tools/mod.rs index 5f2f95fc8485..6e6f28a80310 100644 --- a/crates/goose/src/agents/recipe_tools/mod.rs +++ b/crates/goose/src/agents/recipe_tools/mod.rs @@ -1 +1,3 @@ +pub mod dynamic_task_tools; +pub mod param_utils; pub mod sub_recipe_tools; diff --git a/crates/goose/src/agents/recipe_tools/param_utils/mod.rs b/crates/goose/src/agents/recipe_tools/param_utils/mod.rs new file mode 100644 index 000000000000..bd8468c032dd --- /dev/null +++ b/crates/goose/src/agents/recipe_tools/param_utils/mod.rs @@ -0,0 +1,38 @@ +use anyhow::Result; +use serde_json::Value; +use std::collections::HashMap; + +use crate::recipe::SubRecipe; + +pub fn prepare_command_params( + sub_recipe: &SubRecipe, + params_from_tool_call: Vec, +) -> Result>> { + let base_params = sub_recipe.values.clone().unwrap_or_default(); + + if params_from_tool_call.is_empty() { + return Ok(vec![base_params]); + } + + let result = params_from_tool_call + .into_iter() + .map(|tool_param| { + let mut param_map = base_params.clone(); + if let Some(param_obj) = tool_param.as_object() { + for (key, value) in param_obj { + let value_str = value + .as_str() + .map(String::from) + .unwrap_or_else(|| value.to_string()); + param_map.entry(key.clone()).or_insert(value_str); + } + } + param_map + }) + .collect(); + + Ok(result) +} + +#[cfg(test)] +mod tests; diff --git a/crates/goose/src/agents/recipe_tools/param_utils/tests.rs b/crates/goose/src/agents/recipe_tools/param_utils/tests.rs new file mode 100644 index 000000000000..583338d644a7 --- /dev/null +++ b/crates/goose/src/agents/recipe_tools/param_utils/tests.rs @@ -0,0 +1,140 @@ +use std::collections::HashMap; + +use crate::recipe::SubRecipe; +use serde_json::json; + +use crate::agents::recipe_tools::param_utils::prepare_command_params; + +fn setup_default_sub_recipe() -> SubRecipe { + let sub_recipe = SubRecipe { + name: "test_sub_recipe".to_string(), + path: "test_sub_recipe.yaml".to_string(), + values: Some(HashMap::from([("key1".to_string(), "value1".to_string())])), + sequential_when_repeated: true, + }; + sub_recipe +} + +mod prepare_command_params_tests { + use super::*; + + #[test] + fn test_return_command_param() { + let parameter_array = vec![json!(HashMap::from([( + "key2".to_string(), + "value2".to_string() + )]))]; + let mut sub_recipe = setup_default_sub_recipe(); + sub_recipe.values = Some(HashMap::from([("key1".to_string(), "value1".to_string())])); + + let result = prepare_command_params(&sub_recipe, parameter_array).unwrap(); + assert_eq!( + vec![HashMap::from([ + ("key1".to_string(), "value1".to_string()), + ("key2".to_string(), "value2".to_string()) + ]),], + result + ); + } + + #[test] + fn test_return_command_param_when_value_override_passed_param_value() { + let parameter_array = vec![json!(HashMap::from([( + "key2".to_string(), + "different_value".to_string() + )]))]; + let mut sub_recipe = setup_default_sub_recipe(); + sub_recipe.values = Some(HashMap::from([ + ("key1".to_string(), "value1".to_string()), + ("key2".to_string(), "value2".to_string()), + ])); + + let result = prepare_command_params(&sub_recipe, parameter_array).unwrap(); + assert_eq!( + vec![HashMap::from([ + ("key1".to_string(), "value1".to_string()), + ("key2".to_string(), "value2".to_string()) + ]),], + result + ); + } + + #[test] + fn test_return_empty_command_param() { + let parameter_array = vec![]; + let mut sub_recipe = setup_default_sub_recipe(); + sub_recipe.values = None; + + let result = prepare_command_params(&sub_recipe, parameter_array).unwrap(); + assert_eq!(result, vec![HashMap::new()]); + } + + mod multiple_tool_parameters { + use super::*; + + #[test] + fn test_return_command_param_when_all_values_from_tool_call_parameters() { + let parameter_array = vec![ + json!(HashMap::from([ + ("key1".to_string(), "key1_value1".to_string()), + ("key2".to_string(), "key2_value1".to_string()) + ])), + json!(HashMap::from([ + ("key1".to_string(), "key1_value2".to_string()), + ("key2".to_string(), "key2_value2".to_string()) + ])), + ]; + let mut sub_recipe = setup_default_sub_recipe(); + sub_recipe.values = None; + + let result = prepare_command_params(&sub_recipe, parameter_array).unwrap(); + assert_eq!( + vec![ + HashMap::from([ + ("key1".to_string(), "key1_value1".to_string()), + ("key2".to_string(), "key2_value1".to_string()), + ]), + HashMap::from([ + ("key1".to_string(), "key1_value2".to_string()), + ("key2".to_string(), "key2_value2".to_string()), + ]), + ], + result + ); + } + + #[test] + fn test_merge_base_values_with_tool_parameters() { + let parameter_array = vec![ + json!(HashMap::from([( + "key2".to_string(), + "override_value1".to_string() + )])), + json!(HashMap::from([( + "key2".to_string(), + "override_value2".to_string() + )])), + ]; + let mut sub_recipe = setup_default_sub_recipe(); + sub_recipe.values = Some(HashMap::from([ + ("key1".to_string(), "base_value".to_string()), + ("key2".to_string(), "original_value".to_string()), + ])); + + let result = prepare_command_params(&sub_recipe, parameter_array).unwrap(); + assert_eq!( + vec![ + HashMap::from([ + ("key1".to_string(), "base_value".to_string()), + ("key2".to_string(), "original_value".to_string()), + ]), + HashMap::from([ + ("key1".to_string(), "base_value".to_string()), + ("key2".to_string(), "original_value".to_string()), + ]), + ], + result + ); + } + } +} diff --git a/crates/goose/src/agents/recipe_tools/sub_recipe_tools.rs b/crates/goose/src/agents/recipe_tools/sub_recipe_tools.rs index 928cf8bd0845..66b89ea39db9 100644 --- a/crates/goose/src/agents/recipe_tools/sub_recipe_tools.rs +++ b/crates/goose/src/agents/recipe_tools/sub_recipe_tools.rs @@ -1,22 +1,35 @@ -use std::{collections::HashMap, fs}; +use std::collections::HashSet; +use std::fs; use anyhow::Result; use mcp_core::tool::{Tool, ToolAnnotations}; use serde_json::{json, Map, Value}; -use crate::agents::sub_recipe_execution_tool::lib::Task; +use crate::agents::subagent_execution_tool::lib::{ExecutionMode, Task}; +use crate::agents::subagent_execution_tool::tasks_manager::TasksManager; use crate::recipe::{Recipe, RecipeParameter, RecipeParameterRequirement, SubRecipe}; +use super::param_utils::prepare_command_params; + pub const SUB_RECIPE_TASK_TOOL_NAME_PREFIX: &str = "subrecipe__create_task"; pub fn create_sub_recipe_task_tool(sub_recipe: &SubRecipe) -> Tool { let input_schema = get_input_schema(sub_recipe).unwrap(); Tool::new( format!("{}_{}", SUB_RECIPE_TASK_TOOL_NAME_PREFIX, sub_recipe.name), - "Before running this sub recipe, you should first create a task with this tool and then pass the task to the task executor".to_string(), + format!( + "Create one or more tasks to run the '{}' sub recipe. \ + Provide an array of parameter sets in the 'task_parameters' field:\n\ + - For a single task: provide an array with one parameter set\n\ + - For multiple tasks: provide an array with multiple parameter sets, each with different values\n\n\ + Each task will run the same sub recipe but with different parameter values. \ + This is useful when you need to execute the same sub recipe multiple times with varying inputs. \ + After creating the tasks and execution_mode is provided, pass them to the task executor to run these tasks", + sub_recipe.name + ), input_schema, Some(ToolAnnotations { - title: Some(format!("create sub recipe task {}", sub_recipe.name)), + title: Some(format!("create multiple sub recipe tasks for {}", sub_recipe.name)), read_only_hint: false, destructive_hint: true, idempotent_hint: false, @@ -25,6 +38,64 @@ pub fn create_sub_recipe_task_tool(sub_recipe: &SubRecipe) -> Tool { ) } +fn extract_task_parameters(params: &Value) -> Vec { + params + .get("task_parameters") + .and_then(|v| v.as_array()) + .cloned() + .unwrap_or_default() +} + +fn create_tasks_from_params( + sub_recipe: &SubRecipe, + command_params: &[std::collections::HashMap], +) -> Vec { + let tasks: Vec = command_params + .iter() + .map(|task_command_param| { + let payload = json!({ + "sub_recipe": { + "name": sub_recipe.name.clone(), + "command_parameters": task_command_param, + "recipe_path": sub_recipe.path.clone(), + "sequential_when_repeated": sub_recipe.sequential_when_repeated + } + }); + Task { + id: uuid::Uuid::new_v4().to_string(), + task_type: "sub_recipe".to_string(), + payload, + } + }) + .collect(); + + tasks +} + +fn create_task_execution_payload(tasks: &[Task], sub_recipe: &SubRecipe) -> Value { + let task_ids: Vec = tasks.iter().map(|task| task.id.clone()).collect(); + json!({ + "task_ids": task_ids, + "execution_mode": if sub_recipe.sequential_when_repeated { ExecutionMode::Sequential } else { ExecutionMode::Parallel }, + }) +} + +pub async fn create_sub_recipe_task( + sub_recipe: &SubRecipe, + params: Value, + tasks_manager: &TasksManager, +) -> Result { + let task_params_array = extract_task_parameters(¶ms); + let command_params = prepare_command_params(sub_recipe, task_params_array.clone())?; + let tasks = create_tasks_from_params(sub_recipe, &command_params); + let task_execution_payload = create_task_execution_payload(&tasks, sub_recipe); + + let tasks_json = serde_json::to_string(&task_execution_payload) + .map_err(|e| anyhow::anyhow!("Failed to serialize task list: {}", e))?; + tasks_manager.save_tasks(tasks.clone()).await; + Ok(tasks_json) +} + fn get_sub_recipe_parameter_definition( sub_recipe: &SubRecipe, ) -> Result>> { @@ -34,22 +105,55 @@ fn get_sub_recipe_parameter_definition( Ok(recipe.parameters) } -fn get_input_schema(sub_recipe: &SubRecipe) -> Result { - let mut sub_recipe_params_map = HashMap::::new(); +fn get_params_with_values(sub_recipe: &SubRecipe) -> HashSet { + let mut sub_recipe_params_with_values = HashSet::::new(); if let Some(params_with_value) = &sub_recipe.values { - for (param_name, param_value) in params_with_value { - sub_recipe_params_map.insert(param_name.clone(), param_value.clone()); + for param_name in params_with_value.keys() { + sub_recipe_params_with_values.insert(param_name.clone()); } } + sub_recipe_params_with_values +} + +fn create_input_schema(param_properties: Map, param_required: Vec) -> Value { + let mut properties = Map::new(); + if !param_properties.is_empty() { + properties.insert( + "task_parameters".to_string(), + json!({ + "type": "array", + "description": "Array of parameter sets for creating tasks. \ + For a single task, provide an array with one element. \ + For multiple tasks, provide an array with multiple elements, each with different parameter values. \ + If there is no parameter set, provide an empty array.", + "items": { + "type": "object", + "properties": param_properties, + "required": param_required + }, + }) + ); + } + json!({ + "type": "object", + "properties": properties, + }) +} + +fn get_input_schema(sub_recipe: &SubRecipe) -> Result { + let sub_recipe_params_with_values = get_params_with_values(sub_recipe); + let parameter_definition = get_sub_recipe_parameter_definition(sub_recipe)?; + + let mut param_properties = Map::new(); + let mut param_required = Vec::new(); + if let Some(parameters) = parameter_definition { - let mut properties = Map::new(); - let mut required = Vec::new(); for param in parameters { - if sub_recipe_params_map.contains_key(¶m.key) { + if sub_recipe_params_with_values.contains(¶m.key.clone()) { continue; } - properties.insert( + param_properties.insert( param.key.clone(), json!({ "type": param.input_type.to_string(), @@ -57,60 +161,11 @@ fn get_input_schema(sub_recipe: &SubRecipe) -> Result { }), ); if !matches!(param.requirement, RecipeParameterRequirement::Optional) { - required.push(param.key); + param_required.push(param.key); } } - Ok(json!({ - "type": "object", - "properties": properties, - "required": required - })) - } else { - Ok(json!({ - "type": "object", - "properties": {} - })) - } -} - -fn prepare_command_params( - sub_recipe: &SubRecipe, - params_from_tool_call: Value, -) -> Result> { - let mut sub_recipe_params = HashMap::::new(); - if let Some(params_with_value) = &sub_recipe.values { - for (param_name, param_value) in params_with_value { - sub_recipe_params.insert(param_name.clone(), param_value.clone()); - } } - if let Some(params_map) = params_from_tool_call.as_object() { - for (key, value) in params_map { - sub_recipe_params.insert( - key.to_string(), - value.as_str().unwrap_or(&value.to_string()).to_string(), - ); - } - } - Ok(sub_recipe_params) -} - -pub async fn create_sub_recipe_task(sub_recipe: &SubRecipe, params: Value) -> Result { - let command_params = prepare_command_params(sub_recipe, params)?; - let payload = json!({ - "sub_recipe": { - "name": sub_recipe.name.clone(), - "command_parameters": command_params, - "recipe_path": sub_recipe.path.clone(), - } - }); - let task = Task { - id: uuid::Uuid::new_v4().to_string(), - task_type: "sub_recipe".to_string(), - payload, - }; - let task_json = serde_json::to_string(&task) - .map_err(|e| anyhow::anyhow!("Failed to serialize Task: {}", e))?; - Ok(task_json) + Ok(create_input_schema(param_properties, param_required)) } #[cfg(test)] diff --git a/crates/goose/src/agents/recipe_tools/sub_recipe_tools/tests.rs b/crates/goose/src/agents/recipe_tools/sub_recipe_tools/tests.rs index 11ce390a6b3b..0b682b0b649b 100644 --- a/crates/goose/src/agents/recipe_tools/sub_recipe_tools/tests.rs +++ b/crates/goose/src/agents/recipe_tools/sub_recipe_tools/tests.rs @@ -3,66 +3,48 @@ mod tests { use std::collections::HashMap; use crate::recipe::SubRecipe; + use serde_json::json; + use serde_json::Value; + use tempfile::TempDir; - fn setup_sub_recipe() -> SubRecipe { + fn setup_default_sub_recipe() -> SubRecipe { let sub_recipe = SubRecipe { name: "test_sub_recipe".to_string(), path: "test_sub_recipe.yaml".to_string(), values: Some(HashMap::from([("key1".to_string(), "value1".to_string())])), + sequential_when_repeated: true, }; sub_recipe } - mod prepare_command_params_tests { - use std::collections::HashMap; - use crate::{ - agents::recipe_tools::sub_recipe_tools::{ - prepare_command_params, tests::tests::setup_sub_recipe, - }, - recipe::SubRecipe, - }; - - #[test] - fn test_prepare_command_params_basic() { - let mut params = HashMap::new(); - params.insert("key2".to_string(), "value2".to_string()); + mod get_input_schema { + use super::*; + use crate::agents::recipe_tools::sub_recipe_tools::get_input_schema; - let sub_recipe = setup_sub_recipe(); - - let params_value = serde_json::to_value(params).unwrap(); - let result = prepare_command_params(&sub_recipe, params_value).unwrap(); - assert_eq!(result.len(), 2); - assert_eq!(result.get("key1"), Some(&"value1".to_string())); - assert_eq!(result.get("key2"), Some(&"value2".to_string())); + fn prepare_sub_recipe(sub_recipe_file_content: &str) -> (SubRecipe, TempDir) { + let mut sub_recipe = setup_default_sub_recipe(); + let temp_dir = tempfile::tempdir().unwrap(); + let temp_file = temp_dir.path().join(sub_recipe.path.clone()); + std::fs::write(&temp_file, sub_recipe_file_content).unwrap(); + sub_recipe.path = temp_file.to_string_lossy().to_string(); + (sub_recipe, temp_dir) } - #[test] - fn test_prepare_command_params_empty() { - let sub_recipe = SubRecipe { - name: "test_sub_recipe".to_string(), - path: "test_sub_recipe.yaml".to_string(), - values: None, - }; - let params: HashMap = HashMap::new(); - let params_value = serde_json::to_value(params).unwrap(); - let result = prepare_command_params(&sub_recipe, params_value).unwrap(); - assert_eq!(result.len(), 0); + fn verify_task_parameters(result: Value, expected_task_parameters_items: Value) { + let task_parameters = result + .get("properties") + .unwrap() + .as_object() + .unwrap() + .get("task_parameters") + .unwrap() + .as_object() + .unwrap(); + let task_parameters_items = task_parameters.get("items").unwrap(); + assert_eq!(&expected_task_parameters_items, task_parameters_items); } - } - - mod get_input_schema_tests { - use crate::{ - agents::recipe_tools::sub_recipe_tools::{ - get_input_schema, tests::tests::setup_sub_recipe, - }, - recipe::SubRecipe, - }; - #[test] - fn test_get_input_schema_with_parameters() { - let sub_recipe = setup_sub_recipe(); - - let sub_recipe_file_content = r#"{ + const SUB_RECIPE_FILE_CONTENT_WITH_TWO_PARAMS: &str = r#"{ "version": "1.0.0", "title": "Test Recipe", "description": "A test recipe", @@ -83,73 +65,67 @@ mod tests { ] }"#; - let temp_dir = tempfile::tempdir().unwrap(); - let temp_file = temp_dir.path().join("test_sub_recipe.yaml"); - std::fs::write(&temp_file, sub_recipe_file_content).unwrap(); - - let mut sub_recipe = sub_recipe; - sub_recipe.path = temp_file.to_string_lossy().to_string(); + #[test] + fn test_with_one_param_in_tool_input() { + let (mut sub_recipe, _temp_dir) = + prepare_sub_recipe(SUB_RECIPE_FILE_CONTENT_WITH_TWO_PARAMS); + sub_recipe.values = Some(HashMap::from([("key1".to_string(), "value1".to_string())])); let result = get_input_schema(&sub_recipe).unwrap(); - // Verify the schema structure - assert_eq!(result["type"], "object"); - assert!(result["properties"].is_object()); - - let properties = result["properties"].as_object().unwrap(); - assert_eq!(properties.len(), 1); - - let key2_prop = &properties["key2"]; - assert_eq!(key2_prop["type"], "number"); - assert_eq!(key2_prop["description"], "An optional parameter"); - - let required = result["required"].as_array().unwrap(); - assert_eq!(required.len(), 0); + verify_task_parameters( + result, + json!({ + "type": "object", + "properties": { + "key2": { "type": "number", "description": "An optional parameter" } + }, + "required": [] + }), + ); } #[test] - fn test_get_input_schema_no_parameters_values() { - let sub_recipe = SubRecipe { - name: "test_sub_recipe".to_string(), - path: "test_sub_recipe.yaml".to_string(), - values: None, - }; + fn test_without_param_in_tool_input() { + let (mut sub_recipe, _temp_dir) = + prepare_sub_recipe(SUB_RECIPE_FILE_CONTENT_WITH_TWO_PARAMS); + sub_recipe.values = Some(HashMap::from([ + ("key1".to_string(), "value1".to_string()), + ("key2".to_string(), "value2".to_string()), + ])); - let sub_recipe_file_content = r#"{ - "version": "1.0.0", - "title": "Test Recipe", - "description": "A test recipe", - "prompt": "Test prompt", - "parameters": [ - { - "key": "key1", - "input_type": "string", - "requirement": "required", - "description": "A test parameter" - } - ] - }"#; + let result = get_input_schema(&sub_recipe).unwrap(); - let temp_dir = tempfile::tempdir().unwrap(); - let temp_file = temp_dir.path().join("test_sub_recipe.yaml"); - std::fs::write(&temp_file, sub_recipe_file_content).unwrap(); + assert_eq!( + None, + result + .get("properties") + .unwrap() + .as_object() + .unwrap() + .get("task_parameters") + ); + } - let mut sub_recipe = sub_recipe; - sub_recipe.path = temp_file.to_string_lossy().to_string(); + #[test] + fn test_with_all_params_in_tool_input() { + let (mut sub_recipe, _temp_dir) = + prepare_sub_recipe(SUB_RECIPE_FILE_CONTENT_WITH_TWO_PARAMS); + sub_recipe.values = None; let result = get_input_schema(&sub_recipe).unwrap(); - assert_eq!(result["type"], "object"); - assert!(result["properties"].is_object()); - - let properties = result["properties"].as_object().unwrap(); - assert_eq!(properties.len(), 1); - - let key1_prop = &properties["key1"]; - assert_eq!(key1_prop["type"], "string"); - assert_eq!(key1_prop["description"], "A test parameter"); - assert_eq!(result["required"].as_array().unwrap().len(), 1); - assert_eq!(result["required"][0], "key1"); + verify_task_parameters( + result, + json!({ + "type": "object", + "properties": { + "key1": { "type": "string", "description": "A test parameter" }, + "key2": { "type": "number", "description": "An optional parameter" } + }, + "required": ["key1"] + }), + ); } } } diff --git a/crates/goose/src/agents/router_tool_selector.rs b/crates/goose/src/agents/router_tool_selector.rs index 933316bcbdd1..52da661a95c4 100644 --- a/crates/goose/src/agents/router_tool_selector.rs +++ b/crates/goose/src/agents/router_tool_selector.rs @@ -1,6 +1,6 @@ -use mcp_core::content::TextContent; use mcp_core::tool::Tool; -use mcp_core::{Content, ToolError}; +use mcp_core::ToolError; +use rmcp::model::Content; use anyhow::{Context, Result}; use async_trait::async_trait; @@ -115,10 +115,7 @@ impl RouterToolSelector for VectorToolSelector { "Tool: {}\nDescription: {}\nSchema: {}", tool.tool_name, tool.description, tool.schema ); - Content::Text(TextContent { - text, - annotations: None, - }) + Content::text(text) }) .collect(); @@ -292,12 +289,7 @@ impl RouterToolSelector for LLMToolSelector { let tool_entries: Vec = text .split("\n\n") .filter(|entry| entry.trim().starts_with("Tool:")) - .map(|entry| { - Content::Text(TextContent { - text: entry.trim().to_string(), - annotations: None, - }) - }) + .map(|entry| Content::text(entry.trim().to_string())) .collect(); Ok(tool_entries) diff --git a/crates/goose/src/agents/schedule_tool.rs b/crates/goose/src/agents/schedule_tool.rs index 043866372596..30210544185e 100644 --- a/crates/goose/src/agents/schedule_tool.rs +++ b/crates/goose/src/agents/schedule_tool.rs @@ -6,7 +6,8 @@ use std::sync::Arc; use chrono::Utc; -use mcp_core::{Content, ToolError, ToolResult}; +use mcp_core::{ToolError, ToolResult}; +use rmcp::model::Content; use crate::recipe::Recipe; use crate::scheduler_trait::SchedulerTrait; diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/executor.rs b/crates/goose/src/agents/sub_recipe_execution_tool/executor.rs deleted file mode 100644 index b796d412984d..000000000000 --- a/crates/goose/src/agents/sub_recipe_execution_tool/executor.rs +++ /dev/null @@ -1,103 +0,0 @@ -use std::sync::atomic::{AtomicBool, AtomicUsize}; -use std::sync::Arc; -use tokio::sync::mpsc; -use tokio::time::Instant; - -use crate::agents::sub_recipe_execution_tool::lib::{ - Config, ExecutionResponse, ExecutionStats, Task, TaskResult, -}; -use crate::agents::sub_recipe_execution_tool::tasks::process_task; -use crate::agents::sub_recipe_execution_tool::workers::{run_scaler, spawn_worker, SharedState}; - -pub async fn execute_single_task(task: &Task, config: Config) -> ExecutionResponse { - let start_time = Instant::now(); - let result = process_task(task, config.timeout_seconds).await; - - let execution_time = start_time.elapsed().as_millis(); - let completed = if result.status == "success" { 1 } else { 0 }; - let failed = if result.status == "failed" { 1 } else { 0 }; - - ExecutionResponse { - status: "completed".to_string(), - results: vec![result], - stats: ExecutionStats { - total_tasks: 1, - completed, - failed, - execution_time_ms: execution_time, - }, - } -} - -// Main parallel execution function -pub async fn parallel_execute(tasks: Vec, config: Config) -> ExecutionResponse { - let start_time = Instant::now(); - let task_count = tasks.len(); - - // Create channels - let (task_tx, task_rx) = mpsc::channel::(task_count); - let (result_tx, mut result_rx) = mpsc::channel::(task_count); - - // Initialize shared state - let shared_state = Arc::new(SharedState { - task_receiver: Arc::new(tokio::sync::Mutex::new(task_rx)), - result_sender: result_tx, - active_workers: Arc::new(AtomicUsize::new(0)), - should_stop: Arc::new(AtomicBool::new(false)), - completed_tasks: Arc::new(AtomicUsize::new(0)), - }); - - // Send all tasks to the queue - for task in tasks.clone() { - let _ = task_tx.send(task).await; - } - // Close sender so workers know when queue is empty - drop(task_tx); - - // Start initial workers - let mut worker_handles = Vec::new(); - for i in 0..config.initial_workers { - let handle = spawn_worker(shared_state.clone(), i, config.timeout_seconds); - worker_handles.push(handle); - } - - // Start the scaler - let scaler_state = shared_state.clone(); - let scaler_handle = tokio::spawn(async move { - run_scaler( - scaler_state, - task_count, - config.max_workers, - config.timeout_seconds, - ) - .await; - }); - - // Collect results - let mut results = Vec::new(); - while let Some(result) = result_rx.recv().await { - results.push(result); - if results.len() >= task_count { - break; - } - } - - // Wait for scaler to finish - let _ = scaler_handle.await; - - // Calculate stats - let execution_time = start_time.elapsed().as_millis(); - let completed = results.iter().filter(|r| r.status == "success").count(); - let failed = results.iter().filter(|r| r.status == "failed").count(); - - ExecutionResponse { - status: "completed".to_string(), - results, - stats: ExecutionStats { - total_tasks: task_count, - completed, - failed, - execution_time_ms: execution_time, - }, - } -} diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/lib.rs b/crates/goose/src/agents/sub_recipe_execution_tool/lib.rs deleted file mode 100644 index 9df784a46be0..000000000000 --- a/crates/goose/src/agents/sub_recipe_execution_tool/lib.rs +++ /dev/null @@ -1,38 +0,0 @@ -use crate::agents::sub_recipe_execution_tool::executor::execute_single_task; -pub use crate::agents::sub_recipe_execution_tool::executor::parallel_execute; -pub use crate::agents::sub_recipe_execution_tool::types::{ - Config, ExecutionResponse, ExecutionStats, Task, TaskResult, -}; - -use serde_json::Value; - -pub async fn execute_tasks(input: Value, execution_mode: &str) -> Result { - let tasks: Vec = - serde_json::from_value(input.get("tasks").ok_or("Missing tasks field")?.clone()) - .map_err(|e| format!("Failed to parse tasks: {}", e))?; - - let config: Config = if let Some(config_value) = input.get("config") { - serde_json::from_value(config_value.clone()) - .map_err(|e| format!("Failed to parse config: {}", e))? - } else { - Config::default() - }; - let task_count = tasks.len(); - match execution_mode { - "sequential" => { - if task_count == 1 { - let response = execute_single_task(&tasks[0], config).await; - serde_json::to_value(response) - .map_err(|e| format!("Failed to serialize response: {}", e)) - } else { - Err("Sequential execution mode requires exactly one task".to_string()) - } - } - "parallel" => { - let response = parallel_execute(tasks, config).await; - serde_json::to_value(response) - .map_err(|e| format!("Failed to serialize response: {}", e)) - } - _ => Err("Invalid execution mode".to_string()), - } -} diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/mod.rs b/crates/goose/src/agents/sub_recipe_execution_tool/mod.rs index a49791e2776f..49fcc194c56a 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/mod.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/mod.rs @@ -1,6 +1,11 @@ mod executor; pub mod lib; +pub mod notification_events; pub mod sub_recipe_execute_task_tool; +mod task_execution_tracker; +mod task_types; mod tasks; -mod types; +pub mod tasks_manager; +pub mod utils; mod workers; + diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/sub_recipe_execute_task_tool.rs b/crates/goose/src/agents/sub_recipe_execution_tool/sub_recipe_execute_task_tool.rs deleted file mode 100644 index 46738b813b13..000000000000 --- a/crates/goose/src/agents/sub_recipe_execution_tool/sub_recipe_execute_task_tool.rs +++ /dev/null @@ -1,124 +0,0 @@ -use mcp_core::{tool::ToolAnnotations, Content, Tool, ToolError}; -use serde_json::Value; - -use crate::agents::{ - sub_recipe_execution_tool::lib::execute_tasks, tool_execution::ToolCallResult, -}; - -pub const SUB_RECIPE_EXECUTE_TASK_TOOL_NAME: &str = "sub_recipe__execute_task"; -pub fn create_sub_recipe_execute_task_tool() -> Tool { - Tool::new( - SUB_RECIPE_EXECUTE_TASK_TOOL_NAME, - "Only use this tool when you execute sub recipe task. -EXECUTION STRATEGY: -- DEFAULT: Execute tasks sequentially (one at a time) unless user explicitly requests parallel execution -- PARALLEL: Only when user explicitly uses keywords like 'parallel', 'simultaneously', 'at the same time', 'concurrently' - -IMPLEMENTATION: -- Sequential execution: Call this tool multiple times, passing exactly ONE task per call -- Parallel execution: Call this tool once, passing an ARRAY of all tasks - -EXAMPLES: -- User: 'get weather and tell me a joke' → Sequential (2 separate tool calls, 1 task each) -- User: 'get weather and joke in parallel' → Parallel (1 tool call with array of 2 tasks) -- User: 'run these simultaneously' → Parallel (1 tool call with task array) -- User: 'do task A then task B' → Sequential (2 separate tool calls)", - serde_json::json!({ - "type": "object", - "properties": { - "execution_mode": { - "type": "string", - "enum": ["sequential", "parallel"], - "default": "sequential", - "description": "Execution strategy for multiple tasks. Use 'sequential' (default) unless user explicitly requests parallel execution with words like 'parallel', 'simultaneously', 'at the same time', or 'concurrently'." - }, - "tasks": { - "type": "array", - "items": { - "type": "object", - "properties": { - "id": { - "type": "string", - "description": "Unique identifier for the task" - }, - "task_type": { - "type": "string", - "enum": ["sub_recipe", "text_instruction"], - "default": "sub_recipe", - "description": "the type of task to execute, can be one of: sub_recipe, text_instruction" - }, - "payload": { - "type": "object", - "properties": { - "sub_recipe": { - "type": "object", - "description": "sub recipe to execute", - "properties": { - "name": { - "type": "string", - "description": "name of the sub recipe to execute" - }, - "recipe_path": { - "type": "string", - "description": "path of the sub recipe file" - }, - "command_parameters": { - "type": "object", - "description": "parameters to pass to run recipe command with sub recipe file" - } - } - }, - "text_instruction": { - "type": "string", - "description": "text instruction to execute" - } - } - } - }, - "required": ["id", "payload"] - }, - "description": "The tasks to run in parallel" - }, - "config": { - "type": "object", - "properties": { - "timeout_seconds": { - "type": "number" - }, - "max_workers": { - "type": "number" - }, - "initial_workers": { - "type": "number" - } - } - } - }, - "required": ["tasks"] - }), - Some(ToolAnnotations { - title: Some("Run tasks in parallel".to_string()), - read_only_hint: false, - destructive_hint: true, - idempotent_hint: false, - open_world_hint: true, - }), - ) -} - -pub async fn run_tasks(execute_data: Value) -> ToolCallResult { - let execute_data_clone = execute_data.clone(); - let default_execution_mode_value = Value::String("sequential".to_string()); - let execution_mode = execute_data_clone - .get("execution_mode") - .unwrap_or(&default_execution_mode_value) - .as_str() - .unwrap_or("sequential"); - match execute_tasks(execute_data, execution_mode).await { - Ok(result) => { - let output = serde_json::to_string(&result).unwrap(); - ToolCallResult::from(Ok(vec![Content::text(output)])) - } - Err(e) => ToolCallResult::from(Err(ToolError::ExecutionError(e.to_string()))), - } -} diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/tasks.rs b/crates/goose/src/agents/sub_recipe_execution_tool/tasks.rs index 4e4584aa0b34..66f67729e69e 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/tasks.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/tasks.rs @@ -1,76 +1,98 @@ use serde_json::Value; use std::process::Stdio; -use std::time::Duration; +use std::sync::Arc; use tokio::io::{AsyncBufReadExt, BufReader}; use tokio::process::Command; -use tokio::time::timeout; -use crate::agents::sub_recipe_execution_tool::types::{Task, TaskResult}; +use crate::agents::sub_recipe_execution_tool::task_execution_tracker::TaskExecutionTracker; +use crate::agents::sub_recipe_execution_tool::task_types::{Task, TaskResult, TaskStatus}; -// Process a single task based on its type -pub async fn process_task(task: &Task, timeout_seconds: u64) -> TaskResult { - let task_clone = task.clone(); - let timeout_duration = Duration::from_secs(timeout_seconds); - - // Execute with timeout - match timeout(timeout_duration, execute_task(task_clone)).await { - Ok(Ok(data)) => TaskResult { +pub async fn process_task( + task: &Task, + task_execution_tracker: Arc, +) -> TaskResult { + match get_task_result(task.clone(), task_execution_tracker).await { + Ok(data) => TaskResult { task_id: task.id.clone(), - status: "success".to_string(), + status: TaskStatus::Completed, data: Some(data), error: None, }, - Ok(Err(error)) => TaskResult { + Err(error) => TaskResult { task_id: task.id.clone(), - status: "failed".to_string(), + status: TaskStatus::Failed, data: None, error: Some(error), }, - Err(_) => TaskResult { - task_id: task.id.clone(), - status: "failed".to_string(), - data: None, - error: Some("Task timeout".to_string()), - }, } } -async fn execute_task(task: Task) -> Result { +async fn get_task_result( + task: Task, + task_execution_tracker: Arc, +) -> Result { + let (command, output_identifier) = build_command(&task)?; + let (stdout_output, stderr_output, success) = run_command( + command, + &output_identifier, + &task.id, + task_execution_tracker, + ) + .await?; + + if success { + process_output(stdout_output) + } else { + Err(format!("Command failed:\n{}", stderr_output)) + } +} + +fn build_command(task: &Task) -> Result<(Command, String), String> { + let task_error = |field: &str| format!("Task {}: Missing {}", task.id, field); + let mut output_identifier = task.id.clone(); let mut command = if task.task_type == "sub_recipe" { - let sub_recipe = task.payload.get("sub_recipe").unwrap(); - let sub_recipe_name = sub_recipe.get("name").unwrap().as_str().unwrap(); - let path = sub_recipe.get("recipe_path").unwrap().as_str().unwrap(); - let command_parameters = sub_recipe.get("command_parameters").unwrap(); + let sub_recipe_name = task + .get_sub_recipe_name() + .ok_or_else(|| task_error("sub_recipe name"))?; + let path = task + .get_sub_recipe_path() + .ok_or_else(|| task_error("sub_recipe path"))?; + let command_parameters = task + .get_command_parameters() + .ok_or_else(|| task_error("command_parameters"))?; + output_identifier = format!("sub-recipe {}", sub_recipe_name); let mut cmd = Command::new("goose"); - cmd.arg("run").arg("--recipe").arg(path); - if let Some(params_map) = command_parameters.as_object() { - for (key, value) in params_map { - let key_str = key.to_string(); - let value_str = value.as_str().unwrap_or(&value.to_string()).to_string(); - cmd.arg("--params") - .arg(format!("{}={}", key_str, value_str)); - } + cmd.arg("run").arg("--recipe").arg(path).arg("--no-session"); + + for (key, value) in command_parameters { + let key_str = key.to_string(); + let value_str = value.as_str().unwrap_or(&value.to_string()).to_string(); + cmd.arg("--params") + .arg(format!("{}={}", key_str, value_str)); } cmd } else { let text = task - .payload - .get("text_instruction") - .unwrap() - .as_str() - .unwrap(); + .get_text_instruction() + .ok_or_else(|| task_error("text_instruction"))?; let mut cmd = Command::new("goose"); cmd.arg("run").arg("--text").arg(text); cmd }; - // Configure to capture stdout command.stdout(Stdio::piped()); command.stderr(Stdio::piped()); + Ok((command, output_identifier)) +} - // Spawn the child process +async fn run_command( + mut command: Command, + output_identifier: &str, + task_id: &str, + task_execution_tracker: Arc, +) -> Result<(String, String, bool), String> { let mut child = command .spawn() .map_err(|e| format!("Failed to spawn goose: {}", e))?; @@ -78,42 +100,87 @@ async fn execute_task(task: Task) -> Result { let stdout = child.stdout.take().expect("Failed to capture stdout"); let stderr = child.stderr.take().expect("Failed to capture stderr"); - let mut stdout_reader = BufReader::new(stdout).lines(); - let mut stderr_reader = BufReader::new(stderr).lines(); + let stdout_task = spawn_output_reader( + stdout, + output_identifier, + false, + task_id, + task_execution_tracker.clone(), + ); + let stderr_task = spawn_output_reader( + stderr, + output_identifier, + true, + task_id, + task_execution_tracker.clone(), + ); - // Spawn background tasks to read from stdout and stderr - let output_identifier_clone = output_identifier.clone(); - let stdout_task = tokio::spawn(async move { - let mut buffer = String::new(); - while let Ok(Some(line)) = stdout_reader.next_line().await { - println!("[{}] {}", output_identifier_clone, line); - buffer.push_str(&line); - buffer.push('\n'); - } - buffer - }); + let status = child + .wait() + .await + .map_err(|e| format!("Failed to wait for process: {}", e))?; - let stderr_task = tokio::spawn(async move { + let stdout_output = stdout_task.await.unwrap(); + let stderr_output = stderr_task.await.unwrap(); + + Ok((stdout_output, stderr_output, status.success())) +} + +fn spawn_output_reader( + reader: impl tokio::io::AsyncRead + Unpin + Send + 'static, + output_identifier: &str, + is_stderr: bool, + task_id: &str, + task_execution_tracker: Arc, +) -> tokio::task::JoinHandle { + let output_identifier = output_identifier.to_string(); + let task_id = task_id.to_string(); + tokio::spawn(async move { let mut buffer = String::new(); - while let Ok(Some(line)) = stderr_reader.next_line().await { - eprintln!("[stderr for {}] {}", output_identifier, line); + let mut lines = BufReader::new(reader).lines(); + while let Ok(Some(line)) = lines.next_line().await { buffer.push_str(&line); buffer.push('\n'); + + if !is_stderr { + task_execution_tracker + .send_live_output(&task_id, &line) + .await; + } else { + tracing::warn!("Task stderr [{}]: {}", output_identifier, line); + } } buffer - }); + }) +} - let status = child - .wait() - .await - .map_err(|e| format!("Failed to wait for process: {}", e))?; +fn extract_json_from_line(line: &str) -> Option { + let start = line.find('{')?; + let end = line.rfind('}')?; - let stdout_output = stdout_task.await.unwrap(); - let stderr_output = stderr_task.await.unwrap(); + if start >= end { + return None; + } - if status.success() { - Ok(Value::String(stdout_output)) + let potential_json = &line[start..=end]; + if serde_json::from_str::(potential_json).is_ok() { + Some(potential_json.to_string()) } else { - Err(format!("Command failed:\n{}", stderr_output)) + None + } +} + +fn process_output(stdout_output: String) -> Result { + let last_line = stdout_output + .lines() + .filter(|line| !line.trim().is_empty()) + .next_back() + .unwrap_or(""); + + if let Some(json_string) = extract_json_from_line(last_line) { + Ok(Value::String(json_string)) + } else { + Ok(Value::String(stdout_output)) } } + diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/types.rs b/crates/goose/src/agents/sub_recipe_execution_tool/types.rs deleted file mode 100644 index ede71dbf40b4..000000000000 --- a/crates/goose/src/agents/sub_recipe_execution_tool/types.rs +++ /dev/null @@ -1,69 +0,0 @@ -use serde::{Deserialize, Serialize}; -use serde_json::Value; - -// Task definition that LLMs will send -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct Task { - pub id: String, - pub task_type: String, - pub payload: Value, -} - -// Result for each task -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct TaskResult { - pub task_id: String, - pub status: String, - #[serde(skip_serializing_if = "Option::is_none")] - pub data: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub error: Option, -} - -// Configuration for the parallel executor -#[derive(Debug, Clone, Deserialize)] -pub struct Config { - #[serde(default = "default_max_workers")] - pub max_workers: usize, - #[serde(default = "default_timeout")] - pub timeout_seconds: u64, - #[serde(default = "default_initial_workers")] - pub initial_workers: usize, -} - -impl Default for Config { - fn default() -> Self { - Self { - max_workers: default_max_workers(), - timeout_seconds: default_timeout(), - initial_workers: default_initial_workers(), - } - } -} - -fn default_max_workers() -> usize { - 10 -} -fn default_timeout() -> u64 { - 300 -} -fn default_initial_workers() -> usize { - 2 -} - -// Stats for the execution -#[derive(Debug, Serialize)] -pub struct ExecutionStats { - pub total_tasks: usize, - pub completed: usize, - pub failed: usize, - pub execution_time_ms: u128, -} - -// Main response structure -#[derive(Debug, Serialize)] -pub struct ExecutionResponse { - pub status: String, - pub results: Vec, - pub stats: ExecutionStats, -} diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/workers.rs b/crates/goose/src/agents/sub_recipe_execution_tool/workers.rs index e48f19c4d360..89473f7c6a65 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/workers.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/workers.rs @@ -1,133 +1,31 @@ +use crate::agents::sub_recipe_execution_tool::task_types::{SharedState, Task}; use crate::agents::sub_recipe_execution_tool::tasks::process_task; -use crate::agents::sub_recipe_execution_tool::types::{Task, TaskResult}; -use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; use std::sync::Arc; -use tokio::sync::mpsc; -use tokio::time::{sleep, Duration}; -#[cfg(test)] -mod tests { - use super::*; - use crate::agents::sub_recipe_execution_tool::types::Task; - - #[tokio::test] - async fn test_spawn_worker_returns_handle() { - // Create a simple shared state for testing - let (task_tx, task_rx) = mpsc::channel::(1); - let (result_tx, _result_rx) = mpsc::channel::(1); - - let shared_state = Arc::new(SharedState { - task_receiver: Arc::new(tokio::sync::Mutex::new(task_rx)), - result_sender: result_tx, - active_workers: Arc::new(AtomicUsize::new(0)), - should_stop: Arc::new(AtomicBool::new(false)), - completed_tasks: Arc::new(AtomicUsize::new(0)), - }); - - // Test that spawn_worker returns a JoinHandle - let handle = spawn_worker(shared_state.clone(), 0, 5); - - // Verify it's a JoinHandle by checking we can abort it - assert!(!handle.is_finished()); - - // Signal stop and close the channel to let the worker exit - shared_state.should_stop.store(true, Ordering::SeqCst); - drop(task_tx); // Close the channel - - // Wait for the worker to finish - let result = handle.await; - assert!(result.is_ok()); - } -} - -pub struct SharedState { - pub task_receiver: Arc>>, - pub result_sender: mpsc::Sender, - pub active_workers: Arc, - pub should_stop: Arc, - pub completed_tasks: Arc, +async fn receive_task(state: &SharedState) -> Option { + let mut receiver = state.task_receiver.lock().await; + receiver.recv().await } -// Spawn a worker task -pub fn spawn_worker( - state: Arc, - worker_id: usize, - timeout_seconds: u64, -) -> tokio::task::JoinHandle<()> { - state.active_workers.fetch_add(1, Ordering::SeqCst); +pub fn spawn_worker(state: Arc, worker_id: usize) -> tokio::task::JoinHandle<()> { + state.increment_active_workers(); tokio::spawn(async move { - worker_loop(state, worker_id, timeout_seconds).await; + worker_loop(state, worker_id).await; }) } -async fn worker_loop(state: Arc, _worker_id: usize, timeout_seconds: u64) { - loop { - // Try to receive a task - let task = { - let mut receiver = state.task_receiver.lock().await; - receiver.recv().await - }; - - match task { - Some(task) => { - // Process the task - let result = process_task(&task, timeout_seconds).await; - - // Send result - let _ = state.result_sender.send(result).await; - - // Update completed count - state.completed_tasks.fetch_add(1, Ordering::SeqCst); - } - None => { - // Channel closed, exit worker - break; - } - } +async fn worker_loop(state: Arc, _worker_id: usize) { + while let Some(task) = receive_task(&state).await { + state.task_execution_tracker.start_task(&task.id).await; + let result = process_task(&task, state.task_execution_tracker.clone()).await; - // Check if we should stop - if state.should_stop.load(Ordering::SeqCst) { + if let Err(e) = state.result_sender.send(result).await { + tracing::error!("Worker failed to send result: {}", e); break; } } - // Worker is exiting - state.active_workers.fetch_sub(1, Ordering::SeqCst); + state.decrement_active_workers(); } -// Scaling controller that monitors queue and spawns workers -pub async fn run_scaler( - state: Arc, - task_count: usize, - max_workers: usize, - timeout_seconds: u64, -) { - let mut worker_count = 0; - - loop { - sleep(Duration::from_millis(100)).await; - - let active = state.active_workers.load(Ordering::SeqCst); - let completed = state.completed_tasks.load(Ordering::SeqCst); - let pending = task_count.saturating_sub(completed); - - // Simple scaling logic: spawn worker if many pending tasks and under limit - if pending > active * 2 && active < max_workers && worker_count < max_workers { - let _handle = spawn_worker(state.clone(), worker_count, timeout_seconds); - worker_count += 1; - } - - // If all tasks completed, signal stop - if completed >= task_count { - state.should_stop.store(true, Ordering::SeqCst); - break; - } - - // If no active workers and tasks remaining, spawn one - if active == 0 && pending > 0 { - let _handle = spawn_worker(state.clone(), worker_count, timeout_seconds); - worker_count += 1; - } - } -} diff --git a/crates/goose/src/agents/sub_recipe_manager.rs b/crates/goose/src/agents/sub_recipe_manager.rs index 2441684b4b0e..98431f5bb11d 100644 --- a/crates/goose/src/agents/sub_recipe_manager.rs +++ b/crates/goose/src/agents/sub_recipe_manager.rs @@ -1,4 +1,5 @@ -use mcp_core::{Content, Tool, ToolError}; +use mcp_core::{Tool, ToolError}; +use rmcp::model::Content; use serde_json::Value; use std::collections::HashMap; @@ -7,6 +8,7 @@ use crate::{ recipe_tools::sub_recipe_tools::{ create_sub_recipe_task, create_sub_recipe_task_tool, SUB_RECIPE_TASK_TOOL_NAME_PREFIX, }, + subagent_execution_tool::tasks_manager::TasksManager, tool_execution::ToolCallResult, }, recipe::SubRecipe, @@ -34,12 +36,6 @@ impl SubRecipeManager { pub fn add_sub_recipe_tools(&mut self, sub_recipes_to_add: Vec) { for sub_recipe in sub_recipes_to_add { - // let sub_recipe_key = format!( - // "{}_{}", - // SUB_RECIPE_TOOL_NAME_PREFIX, - // sub_recipe.name.clone() - // ); - // let tool = create_sub_recipe_tool(&sub_recipe); let sub_recipe_key = format!( "{}_{}", SUB_RECIPE_TASK_TOOL_NAME_PREFIX, @@ -59,43 +55,22 @@ impl SubRecipeManager { &self, tool_name: &str, params: Value, + tasks_manager: &TasksManager, ) -> ToolCallResult { - let result = self.call_sub_recipe_tool(tool_name, params).await; + let result = self + .call_sub_recipe_tool(tool_name, params, tasks_manager) + .await; match result { Ok(call_result) => ToolCallResult::from(Ok(call_result)), Err(e) => ToolCallResult::from(Err(ToolError::ExecutionError(e.to_string()))), } } - // async fn call_sub_recipe_tool( - // &self, - // tool_name: &str, - // params: Value, - // ) -> Result, ToolError> { - // let sub_recipe = self.sub_recipes.get(tool_name).ok_or_else(|| { - // let sub_recipe_name = tool_name - // .strip_prefix(SUB_RECIPE_TOOL_NAME_PREFIX) - // .and_then(|s| s.strip_prefix("_")) - // .ok_or_else(|| { - // ToolError::InvalidParameters(format!( - // "Invalid sub-recipe tool name format: {}", - // tool_name - // )) - // }) - // .unwrap(); - - // ToolError::InvalidParameters(format!("Sub-recipe '{}' not found", sub_recipe_name)) - // })?; - - // let output = run_sub_recipe(sub_recipe, params).await.map_err(|e| { - // ToolError::ExecutionError(format!("Sub-recipe execution failed: {}", e)) - // })?; - // Ok(vec![Content::text(output)]) - // } async fn call_sub_recipe_tool( &self, tool_name: &str, params: Value, + tasks_manager: &TasksManager, ) -> Result, ToolError> { let sub_recipe = self.sub_recipes.get(tool_name).ok_or_else(|| { let sub_recipe_name = tool_name @@ -111,11 +86,10 @@ impl SubRecipeManager { ToolError::InvalidParameters(format!("Sub-recipe '{}' not found", sub_recipe_name)) })?; - - let output = create_sub_recipe_task(sub_recipe, params) + let output = create_sub_recipe_task(sub_recipe, params, tasks_manager) .await .map_err(|e| { - ToolError::ExecutionError(format!("Sub-recipe execution failed: {}", e)) + ToolError::ExecutionError(format!("Sub-recipe task createion failed: {}", e)) })?; Ok(vec![Content::text(output)]) } diff --git a/crates/goose/src/agents/subagent.rs b/crates/goose/src/agents/subagent.rs index 0a02e2d1db73..030c787732b6 100644 --- a/crates/goose/src/agents/subagent.rs +++ b/crates/goose/src/agents/subagent.rs @@ -1,27 +1,18 @@ use crate::{ - agents::{extension_manager::ExtensionManager, Agent}, + agents::{Agent, TaskConfig}, message::{Message, MessageContent, ToolRequest}, prompt_template::render_global_file, - providers::base::Provider, providers::errors::ProviderError, - recipe::Recipe, }; use anyhow::anyhow; use chrono::{DateTime, Utc}; use mcp_core::protocol::{JsonRpcMessage, JsonRpcNotification}; -use mcp_core::{handler::ToolError, role::Role, tool::Tool}; +use mcp_core::{handler::ToolError, tool::Tool}; use serde::{Deserialize, Serialize}; use serde_json::{self, json}; use std::{collections::HashMap, sync::Arc}; -use tokio::sync::{mpsc, Mutex, RwLock}; +use tokio::sync::{Mutex, RwLock}; use tracing::{debug, error, instrument}; -use uuid::Uuid; - -use crate::agents::platform_tools::{ - self, PLATFORM_LIST_RESOURCES_TOOL_NAME, PLATFORM_READ_RESOURCE_TOOL_NAME, - PLATFORM_SEARCH_AVAILABLE_EXTENSIONS_TOOL_NAME, -}; -use crate::agents::subagent_tools::SUBAGENT_RUN_TASK_TOOL_NAME; /// Status of a subagent #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] @@ -32,48 +23,6 @@ pub enum SubAgentStatus { Terminated, // Manually terminated } -/// Configuration for a subagent -#[derive(Debug)] -pub struct SubAgentConfig { - pub id: String, - pub recipe: Option, - pub instructions: Option, - pub max_turns: Option, - pub timeout_seconds: Option, -} - -impl SubAgentConfig { - pub fn new_with_recipe(recipe: Recipe) -> Self { - Self { - id: Uuid::new_v4().to_string(), - recipe: Some(recipe), - instructions: None, - max_turns: None, - timeout_seconds: None, - } - } - - pub fn new_with_instructions(instructions: String) -> Self { - Self { - id: Uuid::new_v4().to_string(), - recipe: None, - instructions: Some(instructions), - max_turns: None, - timeout_seconds: None, - } - } - - pub fn with_max_turns(mut self, max_turns: usize) -> Self { - self.max_turns = Some(max_turns); - self - } - - pub fn with_timeout(mut self, timeout_seconds: u64) -> Self { - self.timeout_seconds = Some(timeout_seconds); - self - } -} - /// Progress information for a subagent #[derive(Debug, Clone, Serialize, Deserialize)] pub struct SubAgentProgress { @@ -90,58 +39,26 @@ pub struct SubAgent { pub id: String, pub conversation: Arc>>, pub status: Arc>, - pub config: SubAgentConfig, + pub config: TaskConfig, pub turn_count: Arc>, pub created_at: DateTime, - pub recipe_extensions: Arc>>, - pub missing_extensions: Arc>>, // Track extensions that weren't enabled - pub mcp_notification_tx: mpsc::Sender, // For MCP notifications } impl SubAgent { /// Create a new subagent with the given configuration and provider - #[instrument(skip(config, _provider, extension_manager, mcp_notification_tx))] + #[instrument(skip(task_config))] pub async fn new( - config: SubAgentConfig, - _provider: Arc, - extension_manager: Arc>, - mcp_notification_tx: mpsc::Sender, + task_config: TaskConfig, ) -> Result<(Arc, tokio::task::JoinHandle<()>), anyhow::Error> { - debug!("Creating new subagent with id: {}", config.id); - - let mut missing_extensions = Vec::new(); - let mut recipe_extensions = Vec::new(); - - // Check if extensions from recipe exist in the extension manager - if let Some(recipe) = &config.recipe { - if let Some(extensions) = &recipe.extensions { - for extension in extensions { - let extension_name = extension.name(); - let existing_extensions = extension_manager.list_extensions().await?; - - if !existing_extensions.contains(&extension_name) { - missing_extensions.push(extension_name); - } else { - recipe_extensions.push(extension_name); - } - } - } - } else { - // If no recipe, inherit all extensions from the parent agent - let existing_extensions = extension_manager.list_extensions().await?; - recipe_extensions = existing_extensions; - } + debug!("Creating new subagent with id: {}", task_config.id); let subagent = Arc::new(SubAgent { - id: config.id.clone(), + id: task_config.id.clone(), conversation: Arc::new(Mutex::new(Vec::new())), status: Arc::new(RwLock::new(SubAgentStatus::Ready)), - config, + config: task_config, turn_count: Arc::new(Mutex::new(0)), created_at: Utc::now(), - recipe_extensions: Arc::new(Mutex::new(recipe_extensions)), - missing_extensions: Arc::new(Mutex::new(missing_extensions)), - mcp_notification_tx, }); // Send initial MCP notification @@ -209,7 +126,7 @@ impl SubAgent { })), }); - if let Err(e) = self.mcp_notification_tx.send(notification).await { + if let Err(e) = self.config.mcp_tx.send(notification).await { error!( "Failed to send MCP notification from subagent {}: {}", self.id, e @@ -238,17 +155,29 @@ impl SubAgent { } /// Process a message and generate a response using the subagent's provider - #[instrument(skip(self, message, provider, extension_manager))] + #[instrument(skip(self, message))] pub async fn reply_subagent( &self, message: String, - provider: Arc, - extension_manager: Arc>, + task_config: TaskConfig, ) -> Result { debug!("Processing message for subagent {}", self.id); self.send_mcp_notification("message_processing", &format!("Processing: {}", message)) .await; + // Get provider and extension manager from task config + let provider = self + .config + .provider + .as_ref() + .ok_or_else(|| anyhow!("No provider configured for subagent"))?; + + let extension_manager = self + .config + .extension_manager + .as_ref() + .ok_or_else(|| anyhow!("No extension manager configured for subagent"))?; + // Check if we've exceeded max turns { let turn_count = *self.turn_count.lock().await; @@ -288,80 +217,12 @@ impl SubAgent { let mut messages = self.get_conversation().await; // Get tools based on whether we're using a recipe or inheriting from parent - let tools: Vec = if self.config.recipe.is_some() { - // Recipe mode: only get tools from the recipe's extensions - let recipe_extensions = self.recipe_extensions.lock().await; - let mut recipe_tools = Vec::new(); - - debug!( - "Subagent {} operating in recipe mode with {} extensions", - self.id, - recipe_extensions.len() - ); - - for extension_name in recipe_extensions.iter() { - match extension_manager - .get_prefixed_tools(Some(extension_name.clone())) - .await - { - Ok(mut ext_tools) => { - debug!( - "Added {} tools from extension {}", - ext_tools.len(), - extension_name - ); - recipe_tools.append(&mut ext_tools); - } - Err(e) => { - debug!( - "Failed to get tools for extension {}: {}", - extension_name, e - ); - } - } - } - - debug!( - "Subagent {} has {} total recipe tools before filtering", - self.id, - recipe_tools.len() - ); - // Filter out subagent tools from recipe tools - let mut filtered_tools = Self::filter_subagent_tools(recipe_tools); - - // Add platform tools (except subagent tools) - Self::add_platform_tools(&mut filtered_tools, &extension_manager).await; - - debug!( - "Subagent {} has {} tools after filtering and adding platform tools", - self.id, - filtered_tools.len() - ); - filtered_tools - } else { - // No recipe: inherit all tools from parent (but filter out subagent tools) - debug!( - "Subagent {} operating in inheritance mode, using all parent tools", - self.id - ); - let parent_tools = extension_manager.get_prefixed_tools(None).await?; - debug!( - "Subagent {} has {} parent tools before filtering", - self.id, - parent_tools.len() - ); - let mut filtered_tools = Self::filter_subagent_tools(parent_tools); - - // Add platform tools (except subagent tools) - Self::add_platform_tools(&mut filtered_tools, &extension_manager).await; - - debug!( - "Subagent {} has {} tools after filtering and adding platform tools", - self.id, - filtered_tools.len() - ); - filtered_tools - }; + let tools: Vec = extension_manager + .read() + .await + .get_prefixed_tools(None) + .await + .unwrap_or_default(); let toolshim_tools: Vec = vec![]; @@ -371,7 +232,7 @@ impl SubAgent { // Generate response from provider loop { match Agent::generate_response_from_provider( - Arc::clone(&provider), + Arc::clone(provider), &system_prompt, &messages, &tools, @@ -427,20 +288,14 @@ impl SubAgent { .await; // Handle platform tools or dispatch to extension manager - let tool_result = if self.is_platform_tool(&tool_call.name) { - self.handle_platform_tool_call( - tool_call.clone(), - &extension_manager, - ) + let tool_result = match extension_manager + .read() .await - } else { - match extension_manager - .dispatch_tool_call(tool_call.clone()) - .await - { - Ok(result) => result.result.await, - Err(e) => Err(ToolError::ExecutionError(e.to_string())), - } + .dispatch_tool_call(tool_call.clone()) + .await + { + Ok(result) => result.result.await, + Err(e) => Err(ToolError::ExecutionError(e.to_string())), }; match tool_result { @@ -529,142 +384,10 @@ impl SubAgent { Ok(()) } - /// Get formatted conversation for display - pub async fn get_formatted_conversation(&self) -> String { - let conversation = self.conversation.lock().await; - - let mut formatted = format!("=== Subagent {} Conversation ===\n", self.id); - - if let Some(recipe) = &self.config.recipe { - formatted.push_str(&format!("Recipe: {}\n", recipe.title)); - } else if let Some(instructions) = &self.config.instructions { - formatted.push_str(&format!("Instructions: {}\n", instructions)); - } else { - formatted.push_str("Mode: Ad-hoc subagent\n"); - } - - formatted.push_str(&format!( - "Created: {}\n", - self.created_at.format("%Y-%m-%d %H:%M:%S UTC") - )); - - let progress = self.get_progress().await; - - formatted.push_str(&format!("Status: {:?}\n", progress.status)); - formatted.push_str(&format!("Turn: {}", progress.turn)); - if let Some(max_turns) = progress.max_turns { - formatted.push_str(&format!("/{}", max_turns)); - } - formatted.push_str("\n\n"); - - for (i, message) in conversation.iter().enumerate() { - formatted.push_str(&format!( - "{}. {}: {}\n", - i + 1, - match message.role { - Role::User => "User", - Role::Assistant => "Assistant", - }, - message.as_concat_text() - )); - } - - formatted.push_str("=== End Conversation ===\n"); - - formatted - } - - /// Get the list of extensions that weren't enabled - pub async fn get_missing_extensions(&self) -> Vec { - self.missing_extensions.lock().await.clone() - } - /// Filter out subagent spawning tools to prevent infinite recursion - fn filter_subagent_tools(tools: Vec) -> Vec { - let original_count = tools.len(); - let filtered_tools: Vec = tools - .into_iter() - .filter(|tool| { - let should_keep = tool.name != SUBAGENT_RUN_TASK_TOOL_NAME; - if !should_keep { - debug!("Filtering out subagent tool: {}", tool.name); - } - should_keep - }) - .collect(); - - let filtered_count = filtered_tools.len(); - if filtered_count < original_count { - debug!( - "Filtered {} subagent tool(s) from {} total tools", - original_count - filtered_count, - original_count - ); - } - - filtered_tools - } - - /// Add platform tools to the subagent's tool list (excluding dangerous tools) - async fn add_platform_tools(tools: &mut Vec, extension_manager: &ExtensionManager) { - debug!("Adding safe platform tools to subagent"); - - // Add safe platform tools - subagents can search for extensions but can't manage them or schedules - tools.push(platform_tools::search_available_extensions_tool()); - debug!("Added search_available_extensions tool"); - - // Add resource tools if supported - these are generally safe for subagents - if extension_manager.supports_resources() { - tools.extend([ - platform_tools::read_resource_tool(), - platform_tools::list_resources_tool(), - ]); - debug!("Added 2 resource platform tools"); - } - - // Note: We explicitly do NOT add these tools for security reasons: - // - manage_extensions (could interfere with parent agent's extensions) - // - manage_schedule (could interfere with parent agent's scheduling) - // - subagent spawning tools (prevent recursion) - debug!("Platform tools added successfully (dangerous tools excluded)"); - } - - /// Check if a tool name is a platform tool that subagents can use - fn is_platform_tool(&self, tool_name: &str) -> bool { - matches!( - tool_name, - PLATFORM_SEARCH_AVAILABLE_EXTENSIONS_TOOL_NAME - | PLATFORM_READ_RESOURCE_TOOL_NAME - | PLATFORM_LIST_RESOURCES_TOOL_NAME - ) - } - - /// Handle platform tool calls that are safe for subagents - async fn handle_platform_tool_call( - &self, - tool_call: mcp_core::tool::ToolCall, - extension_manager: &ExtensionManager, - ) -> Result, ToolError> { - debug!("Handling platform tool: {}", tool_call.name); - - match tool_call.name.as_str() { - PLATFORM_SEARCH_AVAILABLE_EXTENSIONS_TOOL_NAME => extension_manager - .search_available_extensions() - .await - .map_err(|e| ToolError::ExecutionError(e.to_string())), - PLATFORM_READ_RESOURCE_TOOL_NAME => extension_manager - .read_resource(tool_call.arguments) - .await - .map_err(|e| ToolError::ExecutionError(e.to_string())), - PLATFORM_LIST_RESOURCES_TOOL_NAME => extension_manager - .list_resources(tool_call.arguments) - .await - .map_err(|e| ToolError::ExecutionError(e.to_string())), - _ => Err(ToolError::ExecutionError(format!( - "Platform tool '{}' is not available to subagents for security reasons", - tool_call.name - ))), - } + fn _filter_subagent_tools(tools: Vec) -> Vec { + // TODO: add this in subagent loop + tools } /// Build the system prompt for the subagent using the template @@ -678,14 +401,6 @@ impl SubAgent { ); context.insert("subagent_id", serde_json::Value::String(self.id.clone())); - // Add recipe information if available - if let Some(recipe) = &self.config.recipe { - context.insert( - "recipe_title", - serde_json::Value::String(recipe.title.clone()), - ); - } - // Add max turns if configured if let Some(max_turns) = self.config.max_turns { context.insert( @@ -694,33 +409,6 @@ impl SubAgent { ); } - // Add task instructions - let instructions = if let Some(recipe) = &self.config.recipe { - recipe.instructions.as_deref().unwrap_or("") - } else { - self.config.instructions.as_deref().unwrap_or("") - }; - context.insert( - "task_instructions", - serde_json::Value::String(instructions.to_string()), - ); - - // Add available extensions (only if we have a recipe and extensions) - if self.config.recipe.is_some() { - let extensions: Vec = self.recipe_extensions.lock().await.clone(); - if !extensions.is_empty() { - context.insert( - "extensions", - serde_json::Value::Array( - extensions - .into_iter() - .map(serde_json::Value::String) - .collect(), - ), - ); - } - } - // Add available tools with descriptions for better context let tools_with_descriptions: Vec = available_tools .iter() diff --git a/crates/goose/src/agents/subagent_execution_tool/executor/mod.rs b/crates/goose/src/agents/subagent_execution_tool/executor/mod.rs new file mode 100644 index 000000000000..9a71ad4ad7f9 --- /dev/null +++ b/crates/goose/src/agents/subagent_execution_tool/executor/mod.rs @@ -0,0 +1,210 @@ +use crate::agents::subagent_execution_tool::lib::{ + ExecutionResponse, ExecutionStats, SharedState, Task, TaskResult, TaskStatus, +}; +use crate::agents::subagent_execution_tool::task_execution_tracker::{ + DisplayMode, TaskExecutionTracker, +}; +use crate::agents::subagent_execution_tool::tasks::process_task; +use crate::agents::subagent_execution_tool::workers::spawn_worker; +use crate::agents::subagent_task_config::TaskConfig; +use mcp_core::protocol::JsonRpcMessage; +use std::sync::atomic::AtomicUsize; +use std::sync::Arc; +use tokio::sync::mpsc; +use tokio::time::Instant; + +const EXECUTION_STATUS_COMPLETED: &str = "completed"; +const DEFAULT_MAX_WORKERS: usize = 10; + +pub async fn execute_single_task( + task: &Task, + notifier: mpsc::Sender, + task_config: TaskConfig, +) -> ExecutionResponse { + let start_time = Instant::now(); + let task_execution_tracker = Arc::new(TaskExecutionTracker::new( + vec![task.clone()], + DisplayMode::SingleTaskOutput, + notifier, + )); + let result = process_task(task, task_execution_tracker.clone(), task_config).await; + + // Complete the task in the tracker + task_execution_tracker + .complete_task(&result.task_id, result.clone()) + .await; + + let execution_time = start_time.elapsed().as_millis(); + let stats = calculate_stats(&[result.clone()], execution_time); + + ExecutionResponse { + status: EXECUTION_STATUS_COMPLETED.to_string(), + results: vec![result], + stats, + } +} + +pub async fn execute_tasks_in_parallel( + tasks: Vec, + notifier: mpsc::Sender, + task_config: TaskConfig, +) -> ExecutionResponse { + let task_execution_tracker = Arc::new(TaskExecutionTracker::new( + tasks.clone(), + DisplayMode::MultipleTasksOutput, + notifier, + )); + let start_time = Instant::now(); + let task_count = tasks.len(); + + if task_count == 0 { + return create_empty_response(); + } + + task_execution_tracker.refresh_display().await; + + let (task_tx, task_rx, result_tx, mut result_rx) = create_channels(task_count); + + if let Err(e) = send_tasks_to_channel(tasks, task_tx).await { + tracing::error!("Task execution failed: {}", e); + return create_error_response(e); + } + + let shared_state = create_shared_state(task_rx, result_tx, task_execution_tracker.clone()); + + let worker_count = std::cmp::min(task_count, DEFAULT_MAX_WORKERS); + let mut worker_handles = Vec::new(); + for i in 0..worker_count { + let handle = spawn_worker(shared_state.clone(), i, task_config.clone()); + worker_handles.push(handle); + } + + let results = collect_results(&mut result_rx, task_execution_tracker.clone(), task_count).await; + + for handle in worker_handles { + if let Err(e) = handle.await { + tracing::error!("Worker error: {}", e); + } + } + + task_execution_tracker.send_tasks_complete().await; + + let execution_time = start_time.elapsed().as_millis(); + let stats = calculate_stats(&results, execution_time); + + ExecutionResponse { + status: EXECUTION_STATUS_COMPLETED.to_string(), + results, + stats, + } +} + +fn calculate_stats(results: &[TaskResult], execution_time_ms: u128) -> ExecutionStats { + let completed = results + .iter() + .filter(|r| matches!(r.status, TaskStatus::Completed)) + .count(); + let failed = results + .iter() + .filter(|r| matches!(r.status, TaskStatus::Failed)) + .count(); + + ExecutionStats { + total_tasks: results.len(), + completed, + failed, + execution_time_ms, + } +} + +fn create_channels( + task_count: usize, +) -> ( + mpsc::Sender, + mpsc::Receiver, + mpsc::Sender, + mpsc::Receiver, +) { + let (task_tx, task_rx) = mpsc::channel::(task_count); + let (result_tx, result_rx) = mpsc::channel::(task_count); + (task_tx, task_rx, result_tx, result_rx) +} + +fn create_shared_state( + task_rx: mpsc::Receiver, + result_tx: mpsc::Sender, + task_execution_tracker: Arc, +) -> Arc { + Arc::new(SharedState { + task_receiver: Arc::new(tokio::sync::Mutex::new(task_rx)), + result_sender: result_tx, + active_workers: Arc::new(AtomicUsize::new(0)), + task_execution_tracker, + }) +} + +async fn send_tasks_to_channel( + tasks: Vec, + task_tx: mpsc::Sender, +) -> Result<(), String> { + for task in tasks { + task_tx + .send(task) + .await + .map_err(|e| format!("Failed to queue task: {}", e))?; + } + Ok(()) +} + +fn create_empty_response() -> ExecutionResponse { + ExecutionResponse { + status: EXECUTION_STATUS_COMPLETED.to_string(), + results: vec![], + stats: ExecutionStats { + total_tasks: 0, + completed: 0, + failed: 0, + execution_time_ms: 0, + }, + } +} +async fn collect_results( + result_rx: &mut mpsc::Receiver, + task_execution_tracker: Arc, + expected_count: usize, +) -> Vec { + let mut results = Vec::new(); + while let Some(mut result) = result_rx.recv().await { + // Truncate data to 650 chars if needed + if let Some(data) = result.data.as_mut() { + if let Some(data_str) = data.as_str() { + if data_str.len() > 650 { + *data = serde_json::Value::String(format!("{}...", &data_str[..650])); + } + } + } + task_execution_tracker + .complete_task(&result.task_id, result.clone()) + .await; + + results.push(result); + if results.len() >= expected_count { + break; + } + } + results +} + +fn create_error_response(error: String) -> ExecutionResponse { + tracing::error!("Creating error response: {}", error); + ExecutionResponse { + status: "failed".to_string(), + results: vec![], + stats: ExecutionStats { + total_tasks: 0, + completed: 0, + failed: 1, + execution_time_ms: 0, + }, + } +} diff --git a/crates/goose/src/agents/subagent_execution_tool/executor/tests.rs b/crates/goose/src/agents/subagent_execution_tool/executor/tests.rs new file mode 100644 index 000000000000..76385b87ef37 --- /dev/null +++ b/crates/goose/src/agents/subagent_execution_tool/executor/tests.rs @@ -0,0 +1,100 @@ +use super::{calculate_stats, create_empty_response, create_error_response}; +use crate::agents::sub_recipe_execution_tool::lib::{TaskResult, TaskStatus}; +use serde_json::json; + +fn create_test_task_result(task_id: &str, status: TaskStatus) -> TaskResult { + let is_failed = matches!(status, TaskStatus::Failed); + TaskResult { + task_id: task_id.to_string(), + status, + data: Some(json!({"output": "test output"})), + error: if is_failed { + Some("Test error".to_string()) + } else { + None + }, + } +} + +#[test] +fn test_calculate_stats() { + let results = vec![ + create_test_task_result("task1", TaskStatus::Completed), + create_test_task_result("task2", TaskStatus::Completed), + create_test_task_result("task3", TaskStatus::Failed), + create_test_task_result("task4", TaskStatus::Completed), + ]; + + let stats = calculate_stats(&results, 1500); + + assert_eq!(stats.total_tasks, 4); + assert_eq!(stats.completed, 3); + assert_eq!(stats.failed, 1); + assert_eq!(stats.execution_time_ms, 1500); +} + +#[test] +fn test_calculate_stats_empty_results() { + let results = vec![]; + let stats = calculate_stats(&results, 0); + + assert_eq!(stats.total_tasks, 0); + assert_eq!(stats.completed, 0); + assert_eq!(stats.failed, 0); + assert_eq!(stats.execution_time_ms, 0); +} + +#[test] +fn test_calculate_stats_all_completed() { + let results = vec![ + create_test_task_result("task1", TaskStatus::Completed), + create_test_task_result("task2", TaskStatus::Completed), + ]; + + let stats = calculate_stats(&results, 800); + + assert_eq!(stats.total_tasks, 2); + assert_eq!(stats.completed, 2); + assert_eq!(stats.failed, 0); + assert_eq!(stats.execution_time_ms, 800); +} + +#[test] +fn test_calculate_stats_all_failed() { + let results = vec![ + create_test_task_result("task1", TaskStatus::Failed), + create_test_task_result("task2", TaskStatus::Failed), + ]; + + let stats = calculate_stats(&results, 1200); + + assert_eq!(stats.total_tasks, 2); + assert_eq!(stats.completed, 0); + assert_eq!(stats.failed, 2); + assert_eq!(stats.execution_time_ms, 1200); +} + +#[test] +fn test_create_empty_response() { + let response = create_empty_response(); + + assert_eq!(response.status, "completed"); + assert_eq!(response.results.len(), 0); + assert_eq!(response.stats.total_tasks, 0); + assert_eq!(response.stats.completed, 0); + assert_eq!(response.stats.failed, 0); + assert_eq!(response.stats.execution_time_ms, 0); +} + +#[test] +fn test_create_error_response() { + let error_msg = "Test error message"; + let response = create_error_response(error_msg.to_string()); + + assert_eq!(response.status, "failed"); + assert_eq!(response.results.len(), 0); + assert_eq!(response.stats.total_tasks, 0); + assert_eq!(response.stats.completed, 0); + assert_eq!(response.stats.failed, 1); + assert_eq!(response.stats.execution_time_ms, 0); +} diff --git a/crates/goose/src/agents/subagent_execution_tool/lib/mod.rs b/crates/goose/src/agents/subagent_execution_tool/lib/mod.rs new file mode 100644 index 000000000000..81d728886eab --- /dev/null +++ b/crates/goose/src/agents/subagent_execution_tool/lib/mod.rs @@ -0,0 +1,124 @@ +pub use crate::agents::subagent_execution_tool::task_types::{ + ExecutionMode, ExecutionResponse, ExecutionStats, SharedState, Task, TaskResult, TaskStatus, +}; +use crate::agents::subagent_execution_tool::{ + executor::{execute_single_task, execute_tasks_in_parallel}, + tasks_manager::TasksManager, +}; +use crate::agents::subagent_task_config::TaskConfig; +use mcp_core::protocol::JsonRpcMessage; +use serde_json::{json, Value}; +use tokio::sync::mpsc; + +pub async fn execute_tasks( + input: Value, + execution_mode: ExecutionMode, + notifier: mpsc::Sender, + task_config: TaskConfig, + tasks_manager: &TasksManager, +) -> Result { + let task_ids: Vec = serde_json::from_value( + input + .get("task_ids") + .ok_or("Missing task_ids field")? + .clone(), + ) + .map_err(|e| format!("Failed to parse task_ids: {}", e))?; + + let mut tasks = Vec::new(); + for task_id in &task_ids { + match tasks_manager.get_task(task_id).await { + Some(task) => tasks.push(task), + None => { + return Err(format!( + "Task with ID '{}' not found in TasksManager", + task_id + )) + } + } + } + + let task_count = tasks.len(); + match execution_mode { + ExecutionMode::Sequential => { + if task_count == 1 { + let response = execute_single_task(&tasks[0], notifier, task_config).await; + handle_response(response) + } else { + Err("Sequential execution mode requires exactly one task".to_string()) + } + } + ExecutionMode::Parallel => { + if tasks.iter().any(|task| task.get_sequential_when_repeated()) { + Ok(json!( + { + "execution_mode": ExecutionMode::Sequential, + "task_ids": task_ids, + "results": ["the tasks should be executed sequentially, no matter how user requests it. Please use the subrecipe__execute_task tool to execute the tasks sequentially."] + } + )) + } else { + let response: ExecutionResponse = + execute_tasks_in_parallel(tasks, notifier.clone(), task_config).await; + handle_response(response) + } + } + } +} + +fn extract_failed_tasks(results: &[TaskResult]) -> Vec { + results + .iter() + .filter(|r| matches!(r.status, TaskStatus::Failed)) + .map(format_failed_task_error) + .collect() +} + +fn format_failed_task_error(result: &TaskResult) -> String { + let error_msg = result.error.as_deref().unwrap_or("Unknown error"); + let partial_output = result + .data + .as_ref() + .and_then(|d| d.get("partial_output")) + .and_then(|v| v.as_str()) + .filter(|s| !s.trim().is_empty()) + .unwrap_or("No output captured"); + + format!( + "Task '{}' ({}): {}\nOutput: {}", + result.task_id, + get_task_description(result), + error_msg, + partial_output + ) +} + +fn format_error_summary( + failed_count: usize, + total_count: usize, + failed_tasks: Vec, +) -> String { + format!( + "{}/{} tasks failed:\n{}", + failed_count, + total_count, + failed_tasks.join("\n") + ) +} + +fn handle_response(response: ExecutionResponse) -> Result { + if response.stats.failed > 0 { + let failed_tasks = extract_failed_tasks(&response.results); + let error_summary = format_error_summary( + response.stats.failed, + response.stats.total_tasks, + failed_tasks, + ); + return Err(error_summary); + } + serde_json::to_value(response).map_err(|e| format!("Failed to serialize response: {}", e)) +} + +fn get_task_description(result: &TaskResult) -> String { + format!("ID: {}", result.task_id) +} diff --git a/crates/goose/src/agents/subagent_execution_tool/lib/tests.rs b/crates/goose/src/agents/subagent_execution_tool/lib/tests.rs new file mode 100644 index 000000000000..957b11274279 --- /dev/null +++ b/crates/goose/src/agents/subagent_execution_tool/lib/tests.rs @@ -0,0 +1,216 @@ +use super::{ + extract_failed_tasks, format_error_summary, format_failed_task_error, get_task_description, + handle_response, +}; +use crate::agents::sub_recipe_execution_tool::lib::{ + ExecutionResponse, ExecutionStats, TaskResult, TaskStatus, +}; +use serde_json::json; + +fn create_test_task_result(task_id: &str, status: TaskStatus, error: Option) -> TaskResult { + TaskResult { + task_id: task_id.to_string(), + status, + data: Some(json!({"partial_output": "test output"})), + error, + } +} + +fn create_test_execution_response( + results: Vec, + failed_count: usize, +) -> ExecutionResponse { + ExecutionResponse { + status: "completed".to_string(), + results: results.clone(), + stats: ExecutionStats { + total_tasks: results.len(), + completed: results.len() - failed_count, + failed: failed_count, + execution_time_ms: 1000, + }, + } +} + +#[test] +fn test_extract_failed_tasks() { + let results = vec![ + create_test_task_result("task1", TaskStatus::Completed, None), + create_test_task_result( + "task2", + TaskStatus::Failed, + Some("Error message".to_string()), + ), + create_test_task_result("task3", TaskStatus::Completed, None), + create_test_task_result( + "task4", + TaskStatus::Failed, + Some("Another error".to_string()), + ), + ]; + + let failed_tasks = extract_failed_tasks(&results); + + assert_eq!(failed_tasks.len(), 2); + assert!(failed_tasks[0].contains("task2")); + assert!(failed_tasks[0].contains("Error message")); + assert!(failed_tasks[1].contains("task4")); + assert!(failed_tasks[1].contains("Another error")); +} + +#[test] +fn test_extract_failed_tasks_empty() { + let results = vec![ + create_test_task_result("task1", TaskStatus::Completed, None), + create_test_task_result("task2", TaskStatus::Completed, None), + ]; + + let failed_tasks = extract_failed_tasks(&results); + + assert_eq!(failed_tasks.len(), 0); +} + +#[test] +fn test_format_failed_task_error_with_error_message() { + let result = create_test_task_result( + "task1", + TaskStatus::Failed, + Some("Test error message".to_string()), + ); + + let formatted = format_failed_task_error(&result); + + assert!(formatted.contains("task1")); + assert!(formatted.contains("Test error message")); + assert!(formatted.contains("test output")); + assert!(formatted.contains("ID: task1")); +} + +#[test] +fn test_format_failed_task_error_without_error_message() { + let result = create_test_task_result("task2", TaskStatus::Failed, None); + + let formatted = format_failed_task_error(&result); + + assert!(formatted.contains("task2")); + assert!(formatted.contains("Unknown error")); + assert!(formatted.contains("test output")); +} + +#[test] +fn test_format_failed_task_error_empty_partial_output() { + let mut result = + create_test_task_result("task3", TaskStatus::Failed, Some("Error".to_string())); + result.data = Some(json!({"partial_output": ""})); + + let formatted = format_failed_task_error(&result); + + assert!(formatted.contains("No output captured")); +} + +#[test] +fn test_format_failed_task_error_no_partial_output() { + let mut result = + create_test_task_result("task4", TaskStatus::Failed, Some("Error".to_string())); + result.data = Some(json!({})); + + let formatted = format_failed_task_error(&result); + + assert!(formatted.contains("No output captured")); +} + +#[test] +fn test_format_failed_task_error_no_data() { + let mut result = + create_test_task_result("task5", TaskStatus::Failed, Some("Error".to_string())); + result.data = None; + + let formatted = format_failed_task_error(&result); + + assert!(formatted.contains("No output captured")); +} + +#[test] +fn test_format_error_summary() { + let failed_tasks = vec![ + "Task 'task1': Error 1\nOutput: output1".to_string(), + "Task 'task2': Error 2\nOutput: output2".to_string(), + ]; + + let summary = format_error_summary(2, 5, failed_tasks); + + assert_eq!(summary, "2/5 tasks failed:\nTask 'task1': Error 1\nOutput: output1\nTask 'task2': Error 2\nOutput: output2"); +} + +#[test] +fn test_format_error_summary_single_failure() { + let failed_tasks = vec!["Task 'task1': Error\nOutput: output".to_string()]; + + let summary = format_error_summary(1, 3, failed_tasks); + + assert_eq!( + summary, + "1/3 tasks failed:\nTask 'task1': Error\nOutput: output" + ); +} + +#[test] +fn test_handle_response_success() { + let results = vec![ + create_test_task_result("task1", TaskStatus::Completed, None), + create_test_task_result("task2", TaskStatus::Completed, None), + ]; + let response = create_test_execution_response(results, 0); + + let result = handle_response(response); + + assert!(result.is_ok()); + let value = result.unwrap(); + assert_eq!(value["status"], "completed"); + assert_eq!(value["stats"]["failed"], 0); +} + +#[test] +fn test_handle_response_with_failures() { + let results = vec![ + create_test_task_result("task1", TaskStatus::Completed, None), + create_test_task_result("task2", TaskStatus::Failed, Some("Test error".to_string())), + ]; + let response = create_test_execution_response(results, 1); + + let result = handle_response(response); + + assert!(result.is_err()); + let error = result.unwrap_err(); + assert!(error.contains("1/2 tasks failed")); + assert!(error.contains("task2")); + assert!(error.contains("Test error")); +} + +#[test] +fn test_handle_response_all_failures() { + let results = vec![ + create_test_task_result("task1", TaskStatus::Failed, Some("Error 1".to_string())), + create_test_task_result("task2", TaskStatus::Failed, Some("Error 2".to_string())), + ]; + let response = create_test_execution_response(results, 2); + + let result = handle_response(response); + + assert!(result.is_err()); + let error = result.unwrap_err(); + assert!(error.contains("2/2 tasks failed")); + assert!(error.contains("task1")); + assert!(error.contains("task2")); + assert!(error.contains("Error 1")); + assert!(error.contains("Error 2")); +} + +#[test] +fn test_get_task_description() { + let result = create_test_task_result("test_task_123", TaskStatus::Completed, None); + + let description = get_task_description(&result); + + assert_eq!(description, "ID: test_task_123"); +} diff --git a/crates/goose/src/agents/subagent_execution_tool/mod.rs b/crates/goose/src/agents/subagent_execution_tool/mod.rs new file mode 100644 index 000000000000..2226e2d7e58c --- /dev/null +++ b/crates/goose/src/agents/subagent_execution_tool/mod.rs @@ -0,0 +1,10 @@ +mod executor; +pub mod lib; +pub mod notification_events; +pub mod subagent_execute_task_tool; +pub mod task_execution_tracker; +pub mod task_types; +pub mod tasks; +pub mod tasks_manager; +pub mod utils; +pub mod workers; diff --git a/crates/goose/src/agents/subagent_execution_tool/notification_events.rs b/crates/goose/src/agents/subagent_execution_tool/notification_events.rs new file mode 100644 index 000000000000..632cb976b94c --- /dev/null +++ b/crates/goose/src/agents/subagent_execution_tool/notification_events.rs @@ -0,0 +1,204 @@ +use crate::agents::subagent_execution_tool::task_types::TaskStatus; +use serde::{Deserialize, Serialize}; +use serde_json::Value; + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(tag = "subtype")] +pub enum TaskExecutionNotificationEvent { + #[serde(rename = "line_output")] + LineOutput { task_id: String, output: String }, + #[serde(rename = "tasks_update")] + TasksUpdate { + stats: TaskExecutionStats, + tasks: Vec, + }, + #[serde(rename = "tasks_complete")] + TasksComplete { + stats: TaskCompletionStats, + failed_tasks: Vec, + }, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct TaskExecutionStats { + pub total: usize, + pub pending: usize, + pub running: usize, + pub completed: usize, + pub failed: usize, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct TaskCompletionStats { + pub total: usize, + pub completed: usize, + pub failed: usize, + pub success_rate: f64, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct TaskInfo { + pub id: String, + pub status: TaskStatus, + pub duration_secs: Option, + pub current_output: String, + pub task_type: String, + pub task_name: String, + pub task_metadata: String, + pub error: Option, + pub result_data: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct FailedTaskInfo { + pub id: String, + pub name: String, + pub error: Option, +} + +impl TaskExecutionNotificationEvent { + pub fn line_output(task_id: String, output: String) -> Self { + Self::LineOutput { task_id, output } + } + + pub fn tasks_update(stats: TaskExecutionStats, tasks: Vec) -> Self { + Self::TasksUpdate { stats, tasks } + } + + pub fn tasks_complete(stats: TaskCompletionStats, failed_tasks: Vec) -> Self { + Self::TasksComplete { + stats, + failed_tasks, + } + } + + /// Convert event to JSON format for MCP notification + pub fn to_notification_data(&self) -> serde_json::Value { + let mut event_data = serde_json::to_value(self).expect("Failed to serialize event"); + + // Add the type field at the root level + if let serde_json::Value::Object(ref mut map) = event_data { + map.insert( + "type".to_string(), + serde_json::Value::String("task_execution".to_string()), + ); + } + + event_data + } +} + +impl TaskExecutionStats { + pub fn new( + total: usize, + pending: usize, + running: usize, + completed: usize, + failed: usize, + ) -> Self { + Self { + total, + pending, + running, + completed, + failed, + } + } +} + +impl TaskCompletionStats { + pub fn new(total: usize, completed: usize, failed: usize) -> Self { + let success_rate = if total > 0 { + (completed as f64 / total as f64) * 100.0 + } else { + 0.0 + }; + + Self { + total, + completed, + failed, + success_rate, + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_line_output_event_serialization() { + let event = TaskExecutionNotificationEvent::line_output( + "task-1".to_string(), + "Hello World".to_string(), + ); + + let notification_data = event.to_notification_data(); + assert_eq!(notification_data["type"], "task_execution"); + assert_eq!(notification_data["subtype"], "line_output"); + assert_eq!(notification_data["task_id"], "task-1"); + assert_eq!(notification_data["output"], "Hello World"); + } + + #[test] + fn test_tasks_update_event_serialization() { + let stats = TaskExecutionStats::new(5, 2, 1, 1, 1); + let tasks = vec![TaskInfo { + id: "task-1".to_string(), + status: TaskStatus::Running, + duration_secs: Some(1.5), + current_output: "Processing...".to_string(), + task_type: "sub_recipe".to_string(), + task_name: "test-task".to_string(), + task_metadata: "param=value".to_string(), + error: None, + result_data: None, + }]; + + let event = TaskExecutionNotificationEvent::tasks_update(stats, tasks); + let notification_data = event.to_notification_data(); + + assert_eq!(notification_data["type"], "task_execution"); + assert_eq!(notification_data["subtype"], "tasks_update"); + assert_eq!(notification_data["stats"]["total"], 5); + assert_eq!(notification_data["tasks"].as_array().unwrap().len(), 1); + } + + #[test] + fn test_event_roundtrip_serialization() { + let original_event = TaskExecutionNotificationEvent::line_output( + "task-1".to_string(), + "Test output".to_string(), + ); + + // Serialize to JSON + let json_data = original_event.to_notification_data(); + + // Deserialize back to event (excluding the type field) + let mut event_data = json_data.clone(); + if let serde_json::Value::Object(ref mut map) = event_data { + map.remove("type"); + } + + let deserialized_event: TaskExecutionNotificationEvent = + serde_json::from_value(event_data).expect("Failed to deserialize"); + + match (original_event, deserialized_event) { + ( + TaskExecutionNotificationEvent::LineOutput { + task_id: id1, + output: out1, + }, + TaskExecutionNotificationEvent::LineOutput { + task_id: id2, + output: out2, + }, + ) => { + assert_eq!(id1, id2); + assert_eq!(out1, out2); + } + _ => panic!("Event types don't match after roundtrip"), + } + } +} diff --git a/crates/goose/src/agents/subagent_execution_tool/subagent_execute_task_tool.rs b/crates/goose/src/agents/subagent_execution_tool/subagent_execute_task_tool.rs new file mode 100644 index 000000000000..f3860253eb78 --- /dev/null +++ b/crates/goose/src/agents/subagent_execution_tool/subagent_execute_task_tool.rs @@ -0,0 +1,102 @@ +use mcp_core::{tool::ToolAnnotations, Tool, ToolError}; +use rmcp::model::Content; +use serde_json::Value; + +use crate::agents::subagent_task_config::TaskConfig; +use crate::agents::{ + subagent_execution_tool::lib::execute_tasks, + subagent_execution_tool::task_types::ExecutionMode, + subagent_execution_tool::tasks_manager::TasksManager, tool_execution::ToolCallResult, +}; +use mcp_core::protocol::JsonRpcMessage; +use tokio::sync::mpsc; +use tokio_stream; + +pub const SUBAGENT_EXECUTE_TASK_TOOL_NAME: &str = "subagent__execute_task"; +pub fn create_subagent_execute_task_tool() -> Tool { + Tool::new( + SUBAGENT_EXECUTE_TASK_TOOL_NAME, + "Only use the subagent__execute_task tool when you execute sub recipe task or dynamic task. +EXECUTION STRATEGY DECISION: +1. If the tasks are created with execution_mode, use the execution_mode. +2. Execute tasks sequentially unless user explicitly requests parallel execution. PARALLEL: User uses keywords like 'parallel', 'simultaneously', 'at the same time', 'concurrently' + +IMPLEMENTATION: +- Sequential execution: Call this tool multiple times, passing exactly ONE task per call +- Parallel execution: Call this tool once, passing an ARRAY of all tasks + +EXAMPLES: +User Intent Based: +- User: 'get weather and tell me a joke' → Sequential (2 separate tool calls, 1 task each) +- User: 'get weather and joke in parallel' → Parallel (1 tool call with array of 2 tasks) +- User: 'run these simultaneously' → Parallel (1 tool call with task array) +- User: 'do task A then task B' → Sequential (2 separate tool calls)", + serde_json::json!({ + "type": "object", + "properties": { + "execution_mode": { + "type": "string", + "enum": ["sequential", "parallel"], + "default": "sequential", + "description": "Execution strategy for multiple tasks. Use 'sequential' (default) unless user explicitly requests parallel execution with words like 'parallel', 'simultaneously', 'at the same time', or 'concurrently'." + }, + "task_ids": { + "type": "array", + "items": { + "type": "string", + "description": "Unique identifier for the task" + } + } + }, + "required": ["task_ids"] + }), + Some(ToolAnnotations { + title: Some("Run tasks in parallel".to_string()), + read_only_hint: false, + destructive_hint: true, + idempotent_hint: false, + open_world_hint: true, + }), + ) +} + +pub async fn run_tasks( + execute_data: Value, + task_config: TaskConfig, + tasks_manager: &TasksManager, +) -> ToolCallResult { + let (notification_tx, notification_rx) = mpsc::channel::(100); + + let tasks_manager_clone = tasks_manager.clone(); + let result_future = async move { + let execute_data_clone = execute_data.clone(); + let execution_mode = execute_data_clone + .get("execution_mode") + .and_then(|v| serde_json::from_value::(v.clone()).ok()) + .unwrap_or_default(); + + match execute_tasks( + execute_data, + execution_mode, + notification_tx, + task_config, + &tasks_manager_clone, + ) + .await + { + Ok(result) => { + let output = serde_json::to_string(&result).unwrap(); + Ok(vec![Content::text(output)]) + } + Err(e) => Err(ToolError::ExecutionError(e.to_string())), + } + }; + + // Convert receiver to stream + let notification_stream = tokio_stream::wrappers::ReceiverStream::new(notification_rx); + + ToolCallResult { + result: Box::new(Box::pin(result_future)), + notification_stream: Some(Box::new(notification_stream)), + } +} diff --git a/crates/goose/src/agents/subagent_execution_tool/task_execution_tracker.rs b/crates/goose/src/agents/subagent_execution_tool/task_execution_tracker.rs new file mode 100644 index 000000000000..c720459e01ae --- /dev/null +++ b/crates/goose/src/agents/subagent_execution_tool/task_execution_tracker.rs @@ -0,0 +1,305 @@ +use mcp_core::protocol::{JsonRpcMessage, JsonRpcNotification}; +use serde_json::json; +use std::collections::HashMap; +use std::sync::Arc; +use tokio::sync::{mpsc, RwLock}; +use tokio::time::{sleep, Duration, Instant}; + +use crate::agents::subagent_execution_tool::notification_events::{ + FailedTaskInfo, TaskCompletionStats, TaskExecutionNotificationEvent, TaskExecutionStats, + TaskInfo as EventTaskInfo, +}; +use crate::agents::subagent_execution_tool::task_types::{Task, TaskInfo, TaskResult, TaskStatus}; +use crate::agents::subagent_execution_tool::utils::{count_by_status, get_task_name}; +use serde_json::Value; + +#[derive(Debug, Clone, PartialEq)] +pub enum DisplayMode { + MultipleTasksOutput, + SingleTaskOutput, +} + +const THROTTLE_INTERVAL_MS: u64 = 250; +const COMPLETION_NOTIFICATION_DELAY_MS: u64 = 500; + +fn format_task_metadata(task_info: &TaskInfo) -> String { + if let Some(params) = task_info.task.get_command_parameters() { + if params.is_empty() { + return String::new(); + } + + params + .iter() + .map(|(key, value)| { + let value_str = match value { + Value::String(s) => s.clone(), + _ => value.to_string(), + }; + format!("{}={}", key, value_str) + }) + .collect::>() + .join(",") + } else if task_info.task.task_type == "text_instruction" { + // For text_instruction tasks, extract and display the instruction + if let Some(text_instruction) = task_info.task.get_text_instruction() { + // Truncate long instructions to keep the display clean + if text_instruction.len() > 80 { + format!("instruction={}...", &text_instruction[..77]) + } else { + format!("instruction={}", text_instruction) + } + } else { + String::new() + } + } else { + String::new() + } +} + +pub struct TaskExecutionTracker { + tasks: Arc>>, + last_refresh: Arc>, + notifier: mpsc::Sender, + display_mode: DisplayMode, +} + +impl TaskExecutionTracker { + pub fn new( + tasks: Vec, + display_mode: DisplayMode, + notifier: mpsc::Sender, + ) -> Self { + let task_map = tasks + .into_iter() + .map(|task| { + let task_id = task.id.clone(); + ( + task_id, + TaskInfo { + task, + status: TaskStatus::Pending, + start_time: None, + end_time: None, + result: None, + current_output: String::new(), + }, + ) + }) + .collect(); + + Self { + tasks: Arc::new(RwLock::new(task_map)), + last_refresh: Arc::new(RwLock::new(Instant::now())), + notifier, + display_mode, + } + } + + pub async fn start_task(&self, task_id: &str) { + let mut tasks = self.tasks.write().await; + if let Some(task_info) = tasks.get_mut(task_id) { + task_info.status = TaskStatus::Running; + task_info.start_time = Some(Instant::now()); + } + drop(tasks); + self.force_refresh_display().await; + } + + pub async fn complete_task(&self, task_id: &str, result: TaskResult) { + let mut tasks = self.tasks.write().await; + if let Some(task_info) = tasks.get_mut(task_id) { + task_info.status = result.status.clone(); + task_info.end_time = Some(Instant::now()); + task_info.result = Some(result); + } + drop(tasks); + self.force_refresh_display().await; + } + + pub async fn get_current_output(&self, task_id: &str) -> Option { + let tasks = self.tasks.read().await; + tasks + .get(task_id) + .map(|task_info| task_info.current_output.clone()) + } + + async fn format_line(&self, task_info: Option<&TaskInfo>, line: &str) -> String { + if let Some(task_info) = task_info { + let task_name = get_task_name(task_info); + let task_type = task_info.task.task_type.clone(); + let metadata = format_task_metadata(task_info); + + if metadata.is_empty() { + format!("[{} ({})] {}", task_name, task_type, line) + } else { + format!("[{} ({}) {}] {}", task_name, task_type, metadata, line) + } + } else { + line.to_string() + } + } + + pub async fn send_live_output(&self, task_id: &str, line: &str) { + match self.display_mode { + DisplayMode::SingleTaskOutput => { + let tasks = self.tasks.read().await; + let task_info = tasks.get(task_id); + + let formatted_line = self.format_line(task_info, line).await; + drop(tasks); + let event = TaskExecutionNotificationEvent::line_output( + task_id.to_string(), + formatted_line, + ); + + if let Err(e) = + self.notifier + .try_send(JsonRpcMessage::Notification(JsonRpcNotification { + jsonrpc: "2.0".to_string(), + method: "notifications/message".to_string(), + params: Some(json!({ + "data": event.to_notification_data() + })), + })) + { + tracing::warn!("Failed to send live output notification: {}", e); + } + } + DisplayMode::MultipleTasksOutput => { + let mut tasks = self.tasks.write().await; + if let Some(task_info) = tasks.get_mut(task_id) { + task_info.current_output.push_str(line); + task_info.current_output.push('\n'); + } + drop(tasks); + + if !self.should_throttle_refresh().await { + self.refresh_display().await; + } + } + } + } + + async fn should_throttle_refresh(&self) -> bool { + let now = Instant::now(); + let mut last_refresh = self.last_refresh.write().await; + + if now.duration_since(*last_refresh) > Duration::from_millis(THROTTLE_INTERVAL_MS) { + *last_refresh = now; + false + } else { + true + } + } + + async fn send_tasks_update(&self) { + let tasks = self.tasks.read().await; + let task_list: Vec<_> = tasks.values().collect(); + let (total, pending, running, completed, failed) = count_by_status(&tasks); + + let stats = TaskExecutionStats::new(total, pending, running, completed, failed); + + let event_tasks: Vec = task_list + .iter() + .map(|task_info| { + let now = Instant::now(); + EventTaskInfo { + id: task_info.task.id.clone(), + status: task_info.status.clone(), + duration_secs: task_info.start_time.map(|start| { + if let Some(end) = task_info.end_time { + end.duration_since(start).as_secs_f64() + } else { + now.duration_since(start).as_secs_f64() + } + }), + current_output: task_info.current_output.clone(), + task_type: task_info.task.task_type.clone(), + task_name: get_task_name(task_info).to_string(), + task_metadata: format_task_metadata(task_info), + error: task_info.error().cloned(), + result_data: task_info.data().cloned(), + } + }) + .collect(); + + let event = TaskExecutionNotificationEvent::tasks_update(stats, event_tasks); + + if let Err(e) = self + .notifier + .try_send(JsonRpcMessage::Notification(JsonRpcNotification { + jsonrpc: "2.0".to_string(), + method: "notifications/message".to_string(), + params: Some(json!({ + "data": event.to_notification_data() + })), + })) + { + tracing::warn!("Failed to send tasks update notification: {}", e); + } + } + + pub async fn refresh_display(&self) { + match self.display_mode { + DisplayMode::MultipleTasksOutput => { + self.send_tasks_update().await; + } + DisplayMode::SingleTaskOutput => { + // No dashboard display needed for single task output mode + // Live output is handled via send_live_output method + } + } + } + + // Force refresh without throttling - used for important status changes + async fn force_refresh_display(&self) { + match self.display_mode { + DisplayMode::MultipleTasksOutput => { + // Reset throttle timer to allow immediate update + let mut last_refresh = self.last_refresh.write().await; + *last_refresh = Instant::now() - Duration::from_millis(THROTTLE_INTERVAL_MS + 1); + drop(last_refresh); + + self.send_tasks_update().await; + } + DisplayMode::SingleTaskOutput => { + // No dashboard display needed for single task output mode + } + } + } + + pub async fn send_tasks_complete(&self) { + let tasks = self.tasks.read().await; + let (total, _, _, completed, failed) = count_by_status(&tasks); + + let stats = TaskCompletionStats::new(total, completed, failed); + + let failed_tasks: Vec = tasks + .values() + .filter(|task_info| matches!(task_info.status, TaskStatus::Failed)) + .map(|task_info| FailedTaskInfo { + id: task_info.task.id.clone(), + name: get_task_name(task_info).to_string(), + error: task_info.error().cloned(), + }) + .collect(); + + let event = TaskExecutionNotificationEvent::tasks_complete(stats, failed_tasks); + + if let Err(e) = self + .notifier + .try_send(JsonRpcMessage::Notification(JsonRpcNotification { + jsonrpc: "2.0".to_string(), + method: "notifications/message".to_string(), + params: Some(json!({ + "data": event.to_notification_data() + })), + })) + { + tracing::warn!("Failed to send tasks complete notification: {}", e); + } + + // Brief delay to ensure completion notification is processed + sleep(Duration::from_millis(COMPLETION_NOTIFICATION_DELAY_MS)).await; + } +} diff --git a/crates/goose/src/agents/subagent_execution_tool/task_types.rs b/crates/goose/src/agents/subagent_execution_tool/task_types.rs new file mode 100644 index 000000000000..796491f624f2 --- /dev/null +++ b/crates/goose/src/agents/subagent_execution_tool/task_types.rs @@ -0,0 +1,145 @@ +use serde::{Deserialize, Serialize}; +use serde_json::{Map, Value}; +use std::sync::atomic::{AtomicUsize, Ordering}; +use std::sync::Arc; +use tokio::sync::mpsc; + +use crate::agents::subagent_execution_tool::task_execution_tracker::TaskExecutionTracker; + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)] +#[serde(rename_all = "lowercase")] +pub enum ExecutionMode { + #[default] + Sequential, + Parallel, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Task { + pub id: String, + pub task_type: String, + pub payload: Value, +} + +impl Task { + pub fn get_sub_recipe(&self) -> Option<&Map> { + (self.task_type == "sub_recipe") + .then(|| self.payload.get("sub_recipe")?.as_object()) + .flatten() + } + + pub fn get_command_parameters(&self) -> Option<&Map> { + self.get_sub_recipe() + .and_then(|sr| sr.get("command_parameters")) + .and_then(|cp| cp.as_object()) + } + + pub fn get_sequential_when_repeated(&self) -> bool { + self.get_sub_recipe() + .and_then(|sr| sr.get("sequential_when_repeated").and_then(|v| v.as_bool())) + .unwrap_or_default() + } + + pub fn get_sub_recipe_name(&self) -> Option<&str> { + self.get_sub_recipe() + .and_then(|sr| sr.get("name")) + .and_then(|name| name.as_str()) + } + + pub fn get_sub_recipe_path(&self) -> Option<&str> { + self.get_sub_recipe() + .and_then(|sr| sr.get("recipe_path")) + .and_then(|path| path.as_str()) + } + + pub fn get_text_instruction(&self) -> Option<&str> { + if self.task_type != "sub_recipe" { + self.payload + .get("text_instruction") + .and_then(|text| text.as_str()) + } else { + None + } + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct TaskResult { + pub task_id: String, + pub status: TaskStatus, + #[serde(skip_serializing_if = "Option::is_none")] + pub data: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub error: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub enum TaskStatus { + Pending, + Running, + Completed, + Failed, +} + +impl std::fmt::Display for TaskStatus { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + TaskStatus::Pending => write!(f, "Pending"), + TaskStatus::Running => write!(f, "Running"), + TaskStatus::Completed => write!(f, "Completed"), + TaskStatus::Failed => write!(f, "Failed"), + } + } +} + +#[derive(Debug, Clone)] +pub struct TaskInfo { + pub task: Task, + pub status: TaskStatus, + pub start_time: Option, + pub end_time: Option, + pub result: Option, + pub current_output: String, +} + +impl TaskInfo { + pub fn error(&self) -> Option<&String> { + self.result.as_ref().and_then(|r| r.error.as_ref()) + } + + pub fn data(&self) -> Option<&Value> { + self.result.as_ref().and_then(|r| r.data.as_ref()) + } +} + +pub struct SharedState { + pub task_receiver: Arc>>, + pub result_sender: mpsc::Sender, + pub active_workers: Arc, + pub task_execution_tracker: Arc, +} + +impl SharedState { + pub fn increment_active_workers(&self) { + self.active_workers.fetch_add(1, Ordering::SeqCst); + } + + pub fn decrement_active_workers(&self) { + self.active_workers.fetch_sub(1, Ordering::SeqCst); + } +} + +#[derive(Debug, Serialize)] +pub struct ExecutionStats { + pub total_tasks: usize, + pub completed: usize, + pub failed: usize, + pub execution_time_ms: u128, +} + +#[derive(Debug, Serialize)] +pub struct ExecutionResponse { + pub status: String, + pub results: Vec, + pub stats: ExecutionStats, +} diff --git a/crates/goose/src/agents/subagent_execution_tool/tasks.rs b/crates/goose/src/agents/subagent_execution_tool/tasks.rs new file mode 100644 index 000000000000..a330711e0a0a --- /dev/null +++ b/crates/goose/src/agents/subagent_execution_tool/tasks.rs @@ -0,0 +1,234 @@ +use serde_json::Value; +use std::ops::Deref; +use std::process::Stdio; +use std::sync::Arc; +use tokio::io::{AsyncBufReadExt, BufReader}; +use tokio::process::Command; + +use crate::agents::subagent_execution_tool::task_execution_tracker::TaskExecutionTracker; +use crate::agents::subagent_execution_tool::task_types::{Task, TaskResult, TaskStatus}; +use crate::agents::subagent_handler::run_complete_subagent_task; +use crate::agents::subagent_task_config::TaskConfig; + +pub async fn process_task( + task: &Task, + task_execution_tracker: Arc, + task_config: TaskConfig, +) -> TaskResult { + match get_task_result(task.clone(), task_execution_tracker, task_config).await { + Ok(data) => TaskResult { + task_id: task.id.clone(), + status: TaskStatus::Completed, + data: Some(data), + error: None, + }, + Err(error) => TaskResult { + task_id: task.id.clone(), + status: TaskStatus::Failed, + data: None, + error: Some(error), + }, + } +} + +async fn get_task_result( + task: Task, + task_execution_tracker: Arc, + task_config: TaskConfig, +) -> Result { + if task.task_type == "text_instruction" { + // Handle text_instruction tasks using subagent system + handle_text_instruction_task(task, task_execution_tracker, task_config).await + } else { + // Handle sub_recipe tasks using command execution + let (command, output_identifier) = build_command(&task)?; + let (stdout_output, stderr_output, success) = run_command( + command, + &output_identifier, + &task.id, + task_execution_tracker, + ) + .await?; + + if success { + process_output(stdout_output) + } else { + Err(format!("Command failed:\n{}", stderr_output)) + } + } +} + +async fn handle_text_instruction_task( + task: Task, + task_execution_tracker: Arc, + task_config: TaskConfig, +) -> Result { + let text_instruction = task + .get_text_instruction() + .ok_or_else(|| format!("Task {}: Missing text_instruction", task.id))?; + + // Start tracking the task + task_execution_tracker.start_task(&task.id).await; + + // Create arguments for the subagent task + let task_arguments = serde_json::json!({ + "text_instruction": text_instruction, + // "instructions": "You are a helpful assistant. Execute the given task and provide a clear, concise response.", + }); + + match run_complete_subagent_task(task_arguments, task_config).await { + Ok(contents) => { + // Extract the text content from the result + let result_text = contents + .into_iter() + .filter_map(|content| match content.deref() { + rmcp::model::RawContent::Text(raw_text_content) => { + Some(raw_text_content.text.clone()) + } + _ => None, + }) + .collect::>() + .join("\n"); + + Ok(serde_json::json!({ + "result": result_text + })) + } + Err(e) => { + let error_msg = format!("Subagent execution failed: {}", e); + Err(error_msg) + } + } +} + +fn build_command(task: &Task) -> Result<(Command, String), String> { + let task_error = |field: &str| format!("Task {}: Missing {}", task.id, field); + + let (mut command, output_identifier) = if task.task_type == "sub_recipe" { + let sub_recipe_name = task + .get_sub_recipe_name() + .ok_or_else(|| task_error("sub_recipe name"))?; + let path = task + .get_sub_recipe_path() + .ok_or_else(|| task_error("sub_recipe path"))?; + let command_parameters = task + .get_command_parameters() + .ok_or_else(|| task_error("command_parameters"))?; + + let mut cmd = Command::new("goose"); + cmd.arg("run").arg("--recipe").arg(path).arg("--no-session"); + + for (key, value) in command_parameters { + let key_str = key.to_string(); + let value_str = value.as_str().unwrap_or(&value.to_string()).to_string(); + cmd.arg("--params") + .arg(format!("{}={}", key_str, value_str)); + } + (cmd, format!("sub-recipe {}", sub_recipe_name)) + } else { + // This branch should not be reached for text_instruction tasks anymore + // as they are handled in handle_text_instruction_task + return Err("Text instruction tasks are handled separately".to_string()); + }; + + command.stdout(Stdio::piped()); + command.stderr(Stdio::piped()); + Ok((command, output_identifier)) +} + +async fn run_command( + mut command: Command, + output_identifier: &str, + task_id: &str, + task_execution_tracker: Arc, +) -> Result<(String, String, bool), String> { + let mut child = command + .spawn() + .map_err(|e| format!("Failed to spawn goose: {}", e))?; + + let stdout = child.stdout.take().expect("Failed to capture stdout"); + let stderr = child.stderr.take().expect("Failed to capture stderr"); + + let stdout_task = spawn_output_reader( + stdout, + output_identifier, + false, + task_id, + task_execution_tracker.clone(), + ); + let stderr_task = spawn_output_reader( + stderr, + output_identifier, + true, + task_id, + task_execution_tracker.clone(), + ); + + let status = child + .wait() + .await + .map_err(|e| format!("Failed to wait for process: {}", e))?; + + let stdout_output = stdout_task.await.unwrap(); + let stderr_output = stderr_task.await.unwrap(); + + Ok((stdout_output, stderr_output, status.success())) +} + +fn spawn_output_reader( + reader: impl tokio::io::AsyncRead + Unpin + Send + 'static, + output_identifier: &str, + is_stderr: bool, + task_id: &str, + task_execution_tracker: Arc, +) -> tokio::task::JoinHandle { + let output_identifier = output_identifier.to_string(); + let task_id = task_id.to_string(); + tokio::spawn(async move { + let mut buffer = String::new(); + let mut lines = BufReader::new(reader).lines(); + while let Ok(Some(line)) = lines.next_line().await { + buffer.push_str(&line); + buffer.push('\n'); + + if !is_stderr { + task_execution_tracker + .send_live_output(&task_id, &line) + .await; + } else { + tracing::warn!("Task stderr [{}]: {}", output_identifier, line); + } + } + buffer + }) +} + +fn extract_json_from_line(line: &str) -> Option { + let start = line.find('{')?; + let end = line.rfind('}')?; + + if start >= end { + return None; + } + + let potential_json = &line[start..=end]; + if serde_json::from_str::(potential_json).is_ok() { + Some(potential_json.to_string()) + } else { + None + } +} + +fn process_output(stdout_output: String) -> Result { + let last_line = stdout_output + .lines() + .filter(|line| !line.trim().is_empty()) + .next_back() + .unwrap_or(""); + + if let Some(json_string) = extract_json_from_line(last_line) { + Ok(Value::String(json_string)) + } else { + Ok(Value::String(stdout_output)) + } +} diff --git a/crates/goose/src/agents/subagent_execution_tool/tasks_manager.rs b/crates/goose/src/agents/subagent_execution_tool/tasks_manager.rs new file mode 100644 index 000000000000..334379fa4ef5 --- /dev/null +++ b/crates/goose/src/agents/subagent_execution_tool/tasks_manager.rs @@ -0,0 +1,86 @@ +use std::collections::HashMap; +use std::sync::Arc; +use tokio::sync::RwLock; + +use crate::agents::subagent_execution_tool::task_types::Task; + +#[derive(Debug, Clone)] +pub struct TasksManager { + tasks: Arc>>, +} + +impl Default for TasksManager { + fn default() -> Self { + Self::new() + } +} + +impl TasksManager { + pub fn new() -> Self { + Self { + tasks: Arc::new(RwLock::new(HashMap::new())), + } + } + + pub async fn save_tasks(&self, tasks: Vec) { + let mut task_map = self.tasks.write().await; + for task in tasks { + task_map.insert(task.id.clone(), task); + } + } + + pub async fn get_task(&self, task_id: &str) -> Option { + let tasks = self.tasks.read().await; + tasks.get(task_id).cloned() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use serde_json::json; + + fn create_test_task(id: &str, sub_recipe_name: &str) -> Task { + Task { + id: id.to_string(), + task_type: "sub_recipe".to_string(), + payload: json!({ + "sub_recipe": { + "name": sub_recipe_name, + "command_parameters": {}, + "recipe_path": "/test/path" + } + }), + } + } + + #[tokio::test] + async fn test_save_and_get_task() { + let manager = TasksManager::new(); + let tasks = vec![create_test_task("task1", "weather")]; + + manager.save_tasks(tasks).await; + + let retrieved = manager.get_task("task1").await; + assert!(retrieved.is_some()); + assert_eq!(retrieved.unwrap().id, "task1"); + } + + #[tokio::test] + async fn test_save_multiple_tasks() { + let manager = TasksManager::new(); + let tasks = vec![ + create_test_task("task1", "weather"), + create_test_task("task2", "news"), + ]; + + manager.save_tasks(tasks).await; + + let task1 = manager.get_task("task1").await; + let task2 = manager.get_task("task2").await; + assert!(task1.is_some()); + assert!(task2.is_some()); + assert_eq!(task1.unwrap().id, "task1"); + assert_eq!(task2.unwrap().id, "task2"); + } +} diff --git a/crates/goose/src/agents/subagent_execution_tool/utils/mod.rs b/crates/goose/src/agents/subagent_execution_tool/utils/mod.rs new file mode 100644 index 000000000000..5d75675283d3 --- /dev/null +++ b/crates/goose/src/agents/subagent_execution_tool/utils/mod.rs @@ -0,0 +1,27 @@ +use std::collections::HashMap; + +use crate::agents::subagent_execution_tool::task_types::{TaskInfo, TaskStatus}; + +pub fn get_task_name(task_info: &TaskInfo) -> &str { + task_info + .task + .get_sub_recipe_name() + .unwrap_or(&task_info.task.id) +} + +pub fn count_by_status(tasks: &HashMap) -> (usize, usize, usize, usize, usize) { + let total = tasks.len(); + let (pending, running, completed, failed) = tasks.values().fold( + (0, 0, 0, 0), + |(pending, running, completed, failed), task| match task.status { + TaskStatus::Pending => (pending + 1, running, completed, failed), + TaskStatus::Running => (pending, running + 1, completed, failed), + TaskStatus::Completed => (pending, running, completed + 1, failed), + TaskStatus::Failed => (pending, running, completed, failed + 1), + }, + ); + (total, pending, running, completed, failed) +} + +#[cfg(test)] +mod tests; diff --git a/crates/goose/src/agents/subagent_execution_tool/utils/tests.rs b/crates/goose/src/agents/subagent_execution_tool/utils/tests.rs new file mode 100644 index 000000000000..b4e7f757b420 --- /dev/null +++ b/crates/goose/src/agents/subagent_execution_tool/utils/tests.rs @@ -0,0 +1,154 @@ +use crate::agents::subagent_execution_tool::task_types::{Task, TaskInfo, TaskStatus}; +use crate::agents::subagent_execution_tool::utils::{count_by_status, get_task_name}; +use serde_json::json; +use std::collections::HashMap; + +fn create_task_info_with_defaults(task: Task, status: TaskStatus) -> TaskInfo { + TaskInfo { + task, + status, + start_time: None, + end_time: None, + result: None, + current_output: String::new(), + } +} + +mod test_get_task_name { + use super::*; + + #[test] + fn test_extracts_sub_recipe_name() { + let sub_recipe_task = Task { + id: "task_1".to_string(), + task_type: "sub_recipe".to_string(), + payload: json!({ + "sub_recipe": { + "name": "my_recipe", + "recipe_path": "/path/to/recipe" + } + }), + }; + + let task_info = create_task_info_with_defaults(sub_recipe_task, TaskStatus::Pending); + + assert_eq!(get_task_name(&task_info), "my_recipe"); + } + + #[test] + fn falls_back_to_task_id_for_text_instruction() { + let text_task = Task { + id: "task_2".to_string(), + task_type: "text_instruction".to_string(), + payload: json!({"text_instruction": "do something"}), + }; + + let task_info = create_task_info_with_defaults(text_task, TaskStatus::Pending); + + assert_eq!(get_task_name(&task_info), "task_2"); + } + + #[test] + fn falls_back_to_task_id_when_sub_recipe_name_missing() { + let malformed_task = Task { + id: "task_3".to_string(), + task_type: "sub_recipe".to_string(), + payload: json!({ + "sub_recipe": { + "recipe_path": "/path/to/recipe" + // missing "name" field + } + }), + }; + + let task_info = create_task_info_with_defaults(malformed_task, TaskStatus::Pending); + + assert_eq!(get_task_name(&task_info), "task_3"); + } + + #[test] + fn falls_back_to_task_id_when_sub_recipe_missing() { + let malformed_task = Task { + id: "task_4".to_string(), + task_type: "sub_recipe".to_string(), + payload: json!({}), // missing "sub_recipe" field + }; + + let task_info = create_task_info_with_defaults(malformed_task, TaskStatus::Pending); + + assert_eq!(get_task_name(&task_info), "task_4"); + } +} + +mod count_by_status { + use super::*; + + fn create_test_task(id: &str, status: TaskStatus) -> TaskInfo { + let task = Task { + id: id.to_string(), + task_type: "test".to_string(), + payload: json!({}), + }; + create_task_info_with_defaults(task, status) + } + + #[test] + fn counts_empty_map() { + let tasks = HashMap::new(); + let (total, pending, running, completed, failed) = count_by_status(&tasks); + assert_eq!( + (total, pending, running, completed, failed), + (0, 0, 0, 0, 0) + ); + } + + #[test] + fn counts_single_status() { + let mut tasks = HashMap::new(); + tasks.insert( + "task1".to_string(), + create_test_task("task1", TaskStatus::Pending), + ); + tasks.insert( + "task2".to_string(), + create_test_task("task2", TaskStatus::Pending), + ); + + let (total, pending, running, completed, failed) = count_by_status(&tasks); + assert_eq!( + (total, pending, running, completed, failed), + (2, 2, 0, 0, 0) + ); + } + + #[test] + fn counts_mixed_statuses() { + let mut tasks = HashMap::new(); + tasks.insert( + "task1".to_string(), + create_test_task("task1", TaskStatus::Pending), + ); + tasks.insert( + "task2".to_string(), + create_test_task("task2", TaskStatus::Running), + ); + tasks.insert( + "task3".to_string(), + create_test_task("task3", TaskStatus::Completed), + ); + tasks.insert( + "task4".to_string(), + create_test_task("task4", TaskStatus::Failed), + ); + tasks.insert( + "task5".to_string(), + create_test_task("task5", TaskStatus::Completed), + ); + + let (total, pending, running, completed, failed) = count_by_status(&tasks); + assert_eq!( + (total, pending, running, completed, failed), + (5, 1, 1, 2, 1) + ); + } +} diff --git a/crates/goose/src/agents/subagent_execution_tool/workers.rs b/crates/goose/src/agents/subagent_execution_tool/workers.rs new file mode 100644 index 000000000000..4ae0ab250737 --- /dev/null +++ b/crates/goose/src/agents/subagent_execution_tool/workers.rs @@ -0,0 +1,40 @@ +use crate::agents::subagent_execution_tool::task_types::{SharedState, Task}; +use crate::agents::subagent_execution_tool::tasks::process_task; +use crate::agents::subagent_task_config::TaskConfig; +use std::sync::Arc; + +async fn receive_task(state: &SharedState) -> Option { + let mut receiver = state.task_receiver.lock().await; + receiver.recv().await +} + +pub fn spawn_worker( + state: Arc, + worker_id: usize, + task_config: TaskConfig, +) -> tokio::task::JoinHandle<()> { + state.increment_active_workers(); + + tokio::spawn(async move { + worker_loop(state, worker_id, task_config).await; + }) +} + +async fn worker_loop(state: Arc, _worker_id: usize, task_config: TaskConfig) { + while let Some(task) = receive_task(&state).await { + state.task_execution_tracker.start_task(&task.id).await; + let result = process_task( + &task, + state.task_execution_tracker.clone(), + task_config.clone(), + ) + .await; + + if let Err(e) = state.result_sender.send(result).await { + tracing::error!("Worker failed to send result: {}", e); + break; + } + } + + state.decrement_active_workers(); +} diff --git a/crates/goose/src/agents/subagent_handler.rs b/crates/goose/src/agents/subagent_handler.rs index f281f7488e4d..6fadd24760e5 100644 --- a/crates/goose/src/agents/subagent_handler.rs +++ b/crates/goose/src/agents/subagent_handler.rs @@ -1,79 +1,44 @@ +use crate::agents::subagent::SubAgent; +use crate::agents::subagent_task_config::TaskConfig; use anyhow::Result; -use mcp_core::{Content, ToolError}; +use mcp_core::ToolError; +use rmcp::model::Content; use serde_json::Value; -use std::sync::Arc; -use crate::agents::subagent_types::SpawnSubAgentArgs; -use crate::agents::Agent; - -impl Agent { - /// Handle running a complete subagent task (replaces the individual spawn/send/check tools) - pub async fn handle_run_subagent_task( - &self, - arguments: Value, - ) -> Result, ToolError> { - let subagent_manager = self.subagent_manager.lock().await; - let manager = subagent_manager.as_ref().ok_or_else(|| { - ToolError::ExecutionError("Subagent manager not initialized".to_string()) - })?; - - // Parse arguments - using "task" as the main message parameter - let message = arguments - .get("task") - .and_then(|v| v.as_str()) - .ok_or_else(|| ToolError::ExecutionError("Missing task parameter".to_string()))? - .to_string(); - - // Either recipe_name or instructions must be provided - let recipe_name = arguments - .get("recipe_name") - .and_then(|v| v.as_str()) - .map(|s| s.to_string()); - let instructions = arguments - .get("instructions") - .and_then(|v| v.as_str()) - .map(|s| s.to_string()); - - let mut args = if let Some(recipe_name) = recipe_name { - SpawnSubAgentArgs::new_with_recipe(recipe_name, message.clone()) - } else if let Some(instructions) = instructions { - SpawnSubAgentArgs::new_with_instructions(instructions, message.clone()) - } else { - return Err(ToolError::ExecutionError( - "Either recipe_name or instructions parameter must be provided".to_string(), - )); - }; - - // Set max_turns with default of 10 - let max_turns = arguments - .get("max_turns") - .and_then(|v| v.as_u64()) - .unwrap_or(10) as usize; - args = args.with_max_turns(max_turns); - - if let Some(timeout) = arguments.get("timeout_seconds").and_then(|v| v.as_u64()) { - args = args.with_timeout(timeout); - } - - // Get the provider from the parent agent - let provider = self - .provider() - .await - .map_err(|e| ToolError::ExecutionError(format!("Failed to get provider: {}", e)))?; - - // Get the extension manager from the parent agent - let extension_manager = Arc::new(self.extension_manager.read().await); - - // Run the complete subagent task - match manager - .run_complete_subagent_task(args, provider, extension_manager) - .await - { - Ok(result) => Ok(vec![Content::text(result)]), - Err(e) => Err(ToolError::ExecutionError(format!( - "Failed to run subagent task: {}", - e - ))), +/// Standalone function to run a complete subagent task +pub async fn run_complete_subagent_task( + task_arguments: Value, + task_config: TaskConfig, +) -> Result, ToolError> { + // Parse arguments - using "task" as the main message parameter + let text_instruction = task_arguments + .get("text_instruction") + .and_then(|v| v.as_str()) + .ok_or_else(|| ToolError::ExecutionError("Missing text_instruction parameter".to_string()))? + .to_string(); + + // Create the subagent with the parent agent's provider + let (subagent, handle) = SubAgent::new(task_config.clone()) + .await + .map_err(|e| ToolError::ExecutionError(format!("Failed to create subagent: {}", e)))?; + + // Execute the subagent task + let result = match subagent.reply_subagent(text_instruction, task_config).await { + Ok(response) => { + let response_text = response.as_concat_text(); + Ok(vec![Content::text(response_text)]) } + Err(e) => Err(ToolError::ExecutionError(format!( + "Subagent execution failed: {}", + e + ))), + }; + + // Clean up the subagent handle + if let Err(e) = handle.await { + tracing::debug!("Subagent handle cleanup error: {}", e); } + + // Return the result + result } diff --git a/crates/goose/src/agents/subagent_manager.rs b/crates/goose/src/agents/subagent_manager.rs deleted file mode 100644 index 174faceecc1b..000000000000 --- a/crates/goose/src/agents/subagent_manager.rs +++ /dev/null @@ -1,404 +0,0 @@ -use std::collections::HashMap; -use std::path::Path; -use std::sync::Arc; - -use anyhow::{anyhow, Result}; -use mcp_core::protocol::JsonRpcMessage; -use tokio::sync::{mpsc, Mutex, RwLock}; -use tracing::{debug, error, instrument, warn}; - -use crate::agents::extension_manager::ExtensionManager; -use crate::agents::subagent::{SubAgent, SubAgentConfig, SubAgentProgress, SubAgentStatus}; -use crate::agents::subagent_types::SpawnSubAgentArgs; -use crate::providers::base::Provider; -use crate::recipe::Recipe; - -/// Manages the lifecycle of subagents -pub struct SubAgentManager { - subagents: Arc>>>, - handles: Arc>>>, - mcp_notification_tx: mpsc::Sender, -} - -impl SubAgentManager { - /// Create a new subagent manager - pub fn new(mcp_notification_tx: mpsc::Sender) -> Self { - Self { - subagents: Arc::new(RwLock::new(HashMap::new())), - handles: Arc::new(Mutex::new(HashMap::new())), - mcp_notification_tx, - } - } - - /// Spawn a new interactive subagent - #[instrument(skip(self, args, provider, extension_manager))] - pub async fn spawn_interactive_subagent( - &self, - args: SpawnSubAgentArgs, - provider: Arc, - extension_manager: Arc>, - ) -> Result { - debug!("Spawning interactive subagent"); - - // Create subagent config based on whether we have a recipe or instructions - let mut config = if let Some(recipe_name) = args.recipe_name { - debug!("Using recipe: {}", recipe_name); - // Load the recipe - let recipe = self.load_recipe(&recipe_name).await?; - SubAgentConfig::new_with_recipe(recipe) - } else if let Some(instructions) = args.instructions { - debug!("Using direct instructions"); - SubAgentConfig::new_with_instructions(instructions) - } else { - return Err(anyhow!( - "Either recipe_name or instructions must be provided" - )); - }; - - if let Some(max_turns) = args.max_turns { - config = config.with_max_turns(max_turns); - } - if let Some(timeout) = args.timeout_seconds { - config = config.with_timeout(timeout); - } - - // Create the subagent with the parent agent's provider - let (subagent, handle) = SubAgent::new( - config, - Arc::clone(&provider), - Arc::clone(&extension_manager), - self.mcp_notification_tx.clone(), - ) - .await?; - let subagent_id = subagent.id.clone(); - - // Store the subagent and its handle - { - let mut subagents = self.subagents.write().await; - subagents.insert(subagent_id.clone(), Arc::clone(&subagent)); - } - { - let mut handles = self.handles.lock().await; - handles.insert(subagent_id.clone(), handle); - } - - // Return immediately - no initial message processing - Ok(subagent_id) - } - - /// Get a subagent by ID - pub async fn get_subagent(&self, id: &str) -> Option> { - let subagents = self.subagents.read().await; - subagents.get(id).cloned() - } - - /// List all active subagent IDs - pub async fn list_subagents(&self) -> Vec { - let subagents = self.subagents.read().await; - subagents.keys().cloned().collect() - } - - /// Get status of all subagents - pub async fn get_subagent_status(&self) -> HashMap { - let subagents = self.subagents.read().await; - let mut status_map = HashMap::new(); - - for (id, subagent) in subagents.iter() { - status_map.insert(id.clone(), subagent.get_status().await); - } - - status_map - } - - /// Get progress of all subagents - pub async fn get_subagent_progress(&self) -> HashMap { - let subagents = self.subagents.read().await; - let mut progress_map = HashMap::new(); - - for (id, subagent) in subagents.iter() { - progress_map.insert(id.clone(), subagent.get_progress().await); - } - - progress_map - } - - /// Send a message to a specific subagent - #[instrument(skip(self, message, provider, extension_manager))] - pub async fn send_message_to_subagent( - &self, - subagent_id: &str, - message: String, - provider: Arc, - extension_manager: Arc>, - ) -> Result { - let subagent = self - .get_subagent(subagent_id) - .await - .ok_or_else(|| anyhow!("Subagent {} not found", subagent_id))?; - - // Process the message and get a reply - match subagent - .reply_subagent(message, provider, extension_manager) - .await - { - Ok(response) => Ok(format!( - "Message sent to subagent {}. Response:\n{}", - subagent_id, - response.as_concat_text() - )), - Err(e) => Err(anyhow!("Failed to process message in subagent: {}", e)), - } - } - - /// Terminate a specific subagent - #[instrument(skip(self))] - pub async fn terminate_subagent(&self, id: &str) -> Result<()> { - debug!("Terminating subagent {}", id); - - // Get and terminate the subagent - let subagent = { - let mut subagents = self.subagents.write().await; - subagents.remove(id) - }; - - if let Some(subagent) = subagent { - subagent.terminate().await?; - } else { - warn!("Attempted to terminate non-existent subagent {}", id); - return Err(anyhow!("Subagent {} not found", id)); - } - - // Clean up the background handle - let handle = { - let mut handles = self.handles.lock().await; - handles.remove(id) - }; - - if let Some(handle) = handle { - handle.abort(); - } - - debug!("Subagent {} terminated successfully", id); - Ok(()) - } - - /// Terminate all subagents - #[instrument(skip(self))] - pub async fn terminate_all_subagents(&self) -> Result<()> { - debug!("Terminating all subagents"); - - let subagent_ids: Vec = { - let subagents = self.subagents.read().await; - subagents.keys().cloned().collect() - }; - - for id in subagent_ids { - if let Err(e) = self.terminate_subagent(&id).await { - error!("Failed to terminate subagent {}: {}", id, e); - } - } - - debug!("All subagents terminated"); - Ok(()) - } - - /// Get formatted conversation from a subagent - pub async fn get_subagent_conversation(&self, id: &str) -> Result { - let subagent = self - .get_subagent(id) - .await - .ok_or_else(|| anyhow!("Subagent {} not found", id))?; - - Ok(subagent.get_formatted_conversation().await) - } - - /// Clean up completed or failed subagents - pub async fn cleanup_completed_subagents(&self) -> Result { - let mut completed_ids = Vec::new(); - - // Find completed subagents - { - let subagents = self.subagents.read().await; - for (id, subagent) in subagents.iter() { - if subagent.is_completed().await { - completed_ids.push(id.clone()); - } - } - } - - // Remove completed subagents - let count = completed_ids.len(); - for id in completed_ids { - if let Err(e) = self.terminate_subagent(&id).await { - error!("Failed to cleanup completed subagent {}: {}", id, e); - } - } - - debug!("Cleaned up {} completed subagents", count); - Ok(count) - } - - /// Load a recipe from file - async fn load_recipe(&self, recipe_name: &str) -> Result { - // Try to load from current directory first - let recipe_path = if recipe_name.ends_with(".yaml") || recipe_name.ends_with(".yml") { - recipe_name.to_string() - } else { - format!("{}.yaml", recipe_name) - }; - - if Path::new(&recipe_path).exists() { - let content = tokio::fs::read_to_string(&recipe_path).await?; - let recipe: Recipe = serde_yaml::from_str(&content)?; - return Ok(recipe); - } - - // Try some common recipe locations - let common_paths = [ - format!("recipes/{}", recipe_path), - format!("./recipes/{}", recipe_path), - format!("../recipes/{}", recipe_path), - ]; - - for path in &common_paths { - if Path::new(path).exists() { - let content = tokio::fs::read_to_string(path).await?; - let recipe: Recipe = serde_yaml::from_str(&content)?; - return Ok(recipe); - } - } - - Err(anyhow!( - "Recipe file '{}' not found in current directory or common recipe locations", - recipe_name - )) - } - - /// Get count of active subagents - pub async fn get_active_count(&self) -> usize { - let subagents = self.subagents.read().await; - subagents.len() - } - - /// Check if a subagent exists - pub async fn has_subagent(&self, id: &str) -> bool { - let subagents = self.subagents.read().await; - subagents.contains_key(id) - } - - /// Run a complete subagent task (spawn, execute, cleanup) - #[instrument(skip(self, args, provider, extension_manager))] - pub async fn run_complete_subagent_task( - &self, - args: SpawnSubAgentArgs, - provider: Arc, - extension_manager: Arc>, - ) -> Result { - debug!("Running complete subagent task"); - - // Create subagent config based on whether we have a recipe or instructions - let mut config = if let Some(recipe_name) = args.recipe_name { - debug!("Using recipe: {}", recipe_name); - // Load the recipe - let recipe = self.load_recipe(&recipe_name).await?; - SubAgentConfig::new_with_recipe(recipe) - } else if let Some(instructions) = args.instructions { - debug!("Using direct instructions"); - SubAgentConfig::new_with_instructions(instructions) - } else { - return Err(anyhow!( - "Either recipe_name or instructions must be provided" - )); - }; - - // Set default max_turns if not provided - let max_turns = args.max_turns.unwrap_or(10); - config = config.with_max_turns(max_turns); - - if let Some(timeout) = args.timeout_seconds { - config = config.with_timeout(timeout); - } - - // Create the subagent with the parent agent's provider - let (subagent, handle) = SubAgent::new( - config, - Arc::clone(&provider), - Arc::clone(&extension_manager), - self.mcp_notification_tx.clone(), - ) - .await?; - let subagent_id = subagent.id.clone(); - - // Store the subagent and its handle temporarily - { - let mut subagents = self.subagents.write().await; - subagents.insert(subagent_id.clone(), Arc::clone(&subagent)); - } - { - let mut handles = self.handles.lock().await; - handles.insert(subagent_id.clone(), handle); - } - - // Run the complete conversation - let mut conversation_result = String::new(); - let turn_count = 0; - let current_message = args.message.clone(); - - // For now, we just complete after one turn since we don't have a mechanism - // for the subagent to continue autonomously without user input - // In a future iteration, we could add logic for the subagent to continue - // working on multi-step tasks with proper turn management - match subagent - .reply_subagent( - current_message, - Arc::clone(&provider), - Arc::clone(&extension_manager), - ) - .await - { - Ok(response) => { - let response_text = response.as_concat_text(); - conversation_result.push_str(&format!( - "\n--- Turn {} ---\n{}", - turn_count + 1, - response_text - )); - conversation_result.push_str(&format!( - "\n[Task completed after {} turns]", - turn_count + 1 - )); - } - Err(e) => { - conversation_result - .push_str(&format!("\n[Error after {} turns: {}]", turn_count, e)); - } - } - - // Clean up the subagent - if let Err(e) = self.terminate_subagent(&subagent_id).await { - debug!("Failed to cleanup subagent {}: {}", subagent_id, e); - } - - // Return the complete conversation result - Ok(format!("Subagent task completed:\n{}", conversation_result)) - } -} - -impl Default for SubAgentManager { - fn default() -> Self { - // Create a dummy channel for default implementation - // In practice, this should not be used - SubAgentManager should be created - // with a proper MCP notification sender - let (tx, _rx) = mpsc::channel(1); - Self::new(tx) - } -} - -impl Drop for SubAgentManager { - fn drop(&mut self) { - // Note: In a real implementation, you might want to spawn a task to clean up - // subagents gracefully, but for now we'll rely on the Drop implementations - // of the individual components - debug!("SubAgentManager dropped"); - } -} diff --git a/crates/goose/src/agents/subagent_task_config.rs b/crates/goose/src/agents/subagent_task_config.rs new file mode 100644 index 000000000000..261fb82b6f5f --- /dev/null +++ b/crates/goose/src/agents/subagent_task_config.rs @@ -0,0 +1,55 @@ +use crate::agents::extension_manager::ExtensionManager; +use crate::providers::base::Provider; +use mcp_core::protocol::JsonRpcMessage; +use std::fmt; +use std::sync::Arc; +use tokio::sync::{mpsc, RwLock}; +use uuid::Uuid; + +/// Configuration for task execution with all necessary dependencies +#[derive(Clone)] +pub struct TaskConfig { + pub id: String, + pub provider: Option>, + pub extension_manager: Option>>, + pub mcp_tx: mpsc::Sender, + pub max_turns: Option, +} + +impl fmt::Debug for TaskConfig { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("TaskConfig") + .field("id", &self.id) + .field("provider", &"") + .field("extension_manager", &"") + .field("max_turns", &self.max_turns) + .finish() + } +} + +impl TaskConfig { + /// Create a new TaskConfig with all required dependencies + pub fn new( + provider: Option>, + extension_manager: Option>>, + mcp_tx: mpsc::Sender, + ) -> Self { + Self { + id: Uuid::new_v4().to_string(), + provider, + extension_manager, + mcp_tx, + max_turns: Some(10), + } + } + + /// Get a reference to the provider + pub fn provider(&self) -> Option<&Arc> { + self.provider.as_ref() + } + + /// Get a clone of the MCP sender + pub fn mcp_tx(&self) -> mpsc::Sender { + self.mcp_tx.clone() + } +} diff --git a/crates/goose/src/agents/subagent_tools.rs b/crates/goose/src/agents/subagent_tools.rs deleted file mode 100644 index f8f35f1ff0ba..000000000000 --- a/crates/goose/src/agents/subagent_tools.rs +++ /dev/null @@ -1,68 +0,0 @@ -use indoc::indoc; -use mcp_core::tool::{Tool, ToolAnnotations}; -use serde_json::json; - -pub const SUBAGENT_RUN_TASK_TOOL_NAME: &str = "subagent__run_task"; - -pub fn run_task_subagent_tool() -> Tool { - Tool::new( - SUBAGENT_RUN_TASK_TOOL_NAME.to_string(), - indoc! {r#" - Spawn a specialized subagent to handle a specific task completely and automatically. - - This tool creates a subagent, processes your task through a complete conversation, - and returns the final result. The subagent is automatically cleaned up after completion. - - You can configure the subagent in two ways: - 1. Using a recipe file that defines instructions, extensions, and behavior - 2. Providing direct instructions for ad-hoc tasks - - The subagent will work autonomously until the task is complete, it reaches max_turns, - or it encounters an error. You'll get the final result without needing to manage - the subagent lifecycle manually. - - Examples: - - "Convert these unittest files to pytest format: file1.py, file2.py" - - "Research the latest developments in AI and provide a comprehensive summary" - - "Review this code for security vulnerabilities and suggest fixes" - - "Refactor this legacy code to use modern Python patterns" - "#} - .to_string(), - json!({ - "type": "object", - "required": ["task"], - "properties": { - "recipe_name": { - "type": "string", - "description": "Name of the recipe file to configure the subagent (e.g., 'research_assistant_recipe.yaml'). Either this or 'instructions' must be provided." - }, - "instructions": { - "type": "string", - "description": "Direct instructions for the subagent's task. Either this or 'recipe_name' must be provided. Example: 'You are a code refactoring assistant. Help convert unittest tests to pytest format.'" - }, - "task": { - "type": "string", - "description": "The task description or initial message for the subagent to work on" - }, - "max_turns": { - "type": "integer", - "description": "Maximum number of conversation turns before auto-completion (default: 10)", - "minimum": 1, - "default": 10 - }, - "timeout_seconds": { - "type": "integer", - "description": "Optional timeout for the entire task in seconds", - "minimum": 1 - } - } - }), - Some(ToolAnnotations { - title: Some("Run subagent task".to_string()), - read_only_hint: false, - destructive_hint: false, - idempotent_hint: false, - open_world_hint: false, - }), - ) -} diff --git a/crates/goose/src/agents/subagent_types.rs b/crates/goose/src/agents/subagent_types.rs deleted file mode 100644 index 1fbc85563f74..000000000000 --- a/crates/goose/src/agents/subagent_types.rs +++ /dev/null @@ -1,42 +0,0 @@ -use serde::{Deserialize, Serialize}; - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct SpawnSubAgentArgs { - pub recipe_name: Option, - pub instructions: Option, - pub message: String, - pub max_turns: Option, - pub timeout_seconds: Option, -} - -impl SpawnSubAgentArgs { - pub fn new_with_recipe(recipe_name: String, message: String) -> Self { - Self { - recipe_name: Some(recipe_name), - instructions: None, - message, - max_turns: None, - timeout_seconds: None, - } - } - - pub fn new_with_instructions(instructions: String, message: String) -> Self { - Self { - recipe_name: None, - instructions: Some(instructions), - message, - max_turns: None, - timeout_seconds: None, - } - } - - pub fn with_max_turns(mut self, max_turns: usize) -> Self { - self.max_turns = Some(max_turns); - self - } - - pub fn with_timeout(mut self, timeout_seconds: u64) -> Self { - self.timeout_seconds = Some(timeout_seconds); - self - } -} diff --git a/crates/goose/src/agents/tool_execution.rs b/crates/goose/src/agents/tool_execution.rs index 446d1f586004..ea997dfd5ebc 100644 --- a/crates/goose/src/agents/tool_execution.rs +++ b/crates/goose/src/agents/tool_execution.rs @@ -11,7 +11,8 @@ use crate::config::permission::PermissionLevel; use crate::config::PermissionManager; use crate::message::{Message, ToolRequest}; use crate::permission::Permission; -use mcp_core::{Content, ToolResult}; +use mcp_core::ToolResult; +use rmcp::model::Content; // ToolCallResult combines the result of a tool call with an optional notification stream that // can be used to receive notifications from the tool. diff --git a/crates/goose/src/agents/types.rs b/crates/goose/src/agents/types.rs index 41711cd03536..4ad78f7f77f7 100644 --- a/crates/goose/src/agents/types.rs +++ b/crates/goose/src/agents/types.rs @@ -1,5 +1,6 @@ use crate::session; -use mcp_core::{Content, Tool, ToolResult}; +use mcp_core::{Tool, ToolResult}; +use rmcp::model::Content; use serde::{Deserialize, Serialize}; use std::path::PathBuf; use std::sync::Arc; diff --git a/crates/goose/src/context_mgmt/summarize.rs b/crates/goose/src/context_mgmt/summarize.rs index 84ea104bb9bc..5b7d049df171 100644 --- a/crates/goose/src/context_mgmt/summarize.rs +++ b/crates/goose/src/context_mgmt/summarize.rs @@ -3,7 +3,7 @@ use crate::message::{Message, MessageContent}; use crate::providers::base::Provider; use crate::token_counter::{AsyncTokenCounter, TokenCounter}; use anyhow::Result; -use mcp_core::Role; +use rmcp::model::Role; use std::sync::Arc; // Constants for the summarization prompt and a follow-up user message. @@ -220,8 +220,10 @@ mod tests { use crate::providers::base::{Provider, ProviderMetadata, ProviderUsage, Usage}; use crate::providers::errors::ProviderError; use chrono::Utc; - use mcp_core::{tool::Tool, Role}; - use mcp_core::{Content, TextContent, ToolCall}; + use mcp_core::tool::Tool; + use mcp_core::ToolCall; + use rmcp::model::Role; + use rmcp::model::{AnnotateAble, Content, RawTextContent}; use serde_json::json; use std::sync::Arc; @@ -250,10 +252,12 @@ mod tests { Message::new( Role::Assistant, Utc::now().timestamp(), - vec![MessageContent::Text(TextContent { - text: "Summarized content".to_string(), - annotations: None, - })], + vec![MessageContent::Text( + RawTextContent { + text: "Summarized content".to_string(), + } + .no_annotation(), + )], ), ProviderUsage::new("mock".to_string(), Usage::default()), )) @@ -447,10 +451,12 @@ mod tests { let summarized_messages = vec![Message::new( Role::Assistant, Utc::now().timestamp(), - vec![MessageContent::Text(TextContent { - text: "Summary".to_string(), - annotations: None, - })], + vec![MessageContent::Text( + RawTextContent { + text: "Summary".to_string(), + } + .no_annotation(), + )], )]; let arguments = json!({ "param1": "value1" diff --git a/crates/goose/src/context_mgmt/truncate.rs b/crates/goose/src/context_mgmt/truncate.rs index ba2f6490e0bb..cb20500b53ea 100644 --- a/crates/goose/src/context_mgmt/truncate.rs +++ b/crates/goose/src/context_mgmt/truncate.rs @@ -1,7 +1,9 @@ use crate::message::{Message, MessageContent}; +use crate::utils::safe_truncate; use anyhow::{anyhow, Result}; -use mcp_core::{Content, ResourceContents, Role}; +use rmcp::model::{RawContent, ResourceContents, Role}; use std::collections::HashSet; +use std::ops::DerefMut; use tracing::{debug, warn}; /// Maximum size for truncated content in characters @@ -75,11 +77,11 @@ fn truncate_message_content(message: &Message, max_content_size: usize) -> Resul for content in &mut new_message.content { match content { MessageContent::Text(text_content) => { - if text_content.text.len() > max_content_size { + if text_content.text.chars().count() > max_content_size { let truncated = format!( "{}\n\n[... content truncated from {} to {} characters ...]", - &text_content.text[..max_content_size.min(text_content.text.len())], - text_content.text.len(), + safe_truncate(&text_content.text, max_content_size), + text_content.text.chars().count(), max_content_size ); text_content.text = truncated; @@ -88,27 +90,29 @@ fn truncate_message_content(message: &Message, max_content_size: usize) -> Resul MessageContent::ToolResponse(tool_response) => { if let Ok(ref mut result) = tool_response.tool_result { for content_item in result { - if let Content::Text(ref mut text_content) = content_item { - if text_content.text.len() > max_content_size { + if let RawContent::Text(ref mut text_content) = content_item.deref_mut() { + if text_content.text.chars().count() > max_content_size { let truncated = format!( "{}\n\n[... tool response truncated from {} to {} characters ...]", - &text_content.text[..max_content_size.min(text_content.text.len())], - text_content.text.len(), + safe_truncate(&text_content.text, max_content_size), + text_content.text.chars().count(), max_content_size ); text_content.text = truncated; } } // Handle Resource content which might contain large text - else if let Content::Resource(ref mut resource_content) = content_item { + else if let RawContent::Resource(ref mut resource_content) = + content_item.deref_mut() + { if let ResourceContents::TextResourceContents { text, .. } = &mut resource_content.resource { - if text.len() > max_content_size { + if text.chars().count() > max_content_size { let truncated = format!( "{}\n\n[... resource content truncated from {} to {} characters ...]", - &text[..max_content_size.min(text.len())], - text.len(), + safe_truncate(text, max_content_size), + text.chars().count(), max_content_size ); *text = truncated; @@ -138,19 +142,21 @@ fn estimate_message_tokens(message: &Message, estimate_fn: &dyn Fn(&str) -> usiz MessageContent::ToolResponse(tool_response) => { if let Ok(ref result) = tool_response.tool_result { for content_item in result { - match content_item { - Content::Text(text_content) => { + match &content_item.raw { + RawContent::Text(text_content) => { total_tokens += estimate_fn(&text_content.text); } - Content::Resource(resource_content) => { - match &resource_content.resource { + RawContent::Resource(resource) => { + match &resource.resource { ResourceContents::TextResourceContents { text, .. } => { total_tokens += estimate_fn(text); } _ => total_tokens += 5, // Small overhead for other resource types } } - _ => total_tokens += 5, // Small overhead for other content types + _ => { + total_tokens += 5; // Small overhead for other content types + } } } } @@ -374,8 +380,8 @@ mod tests { use super::*; use crate::message::Message; use anyhow::Result; - use mcp_core::content::Content; use mcp_core::tool::ToolCall; + use rmcp::model::Content; use serde_json::json; // Helper function to create a user text message with a specified token count diff --git a/crates/goose/src/lib.rs b/crates/goose/src/lib.rs index 83c4934d76fa..32b8da8027e9 100644 --- a/crates/goose/src/lib.rs +++ b/crates/goose/src/lib.rs @@ -4,9 +4,11 @@ pub mod context_mgmt; pub mod message; pub mod model; pub mod permission; +pub mod project; pub mod prompt_template; pub mod providers; pub mod recipe; +pub mod recipe_deeplink; pub mod scheduler; pub mod scheduler_factory; pub mod scheduler_trait; @@ -15,6 +17,7 @@ pub mod temporal_scheduler; pub mod token_counter; pub mod tool_monitor; pub mod tracing; +pub mod utils; #[cfg(test)] mod cron_test; diff --git a/crates/goose/src/message.rs b/crates/goose/src/message.rs index ef207eb1d01c..699b67aa1ff6 100644 --- a/crates/goose/src/message.rs +++ b/crates/goose/src/message.rs @@ -8,12 +8,14 @@ use std::collections::HashSet; /// The content of the messages uses MCP types to avoid additional conversions /// when interacting with MCP servers. use chrono::Utc; -use mcp_core::content::{Content, ImageContent, TextContent}; use mcp_core::handler::ToolResult; -use mcp_core::prompt::{PromptMessage, PromptMessageContent, PromptMessageRole}; -use mcp_core::resource::ResourceContents; -use mcp_core::role::Role; use mcp_core::tool::ToolCall; +use rmcp::model::ResourceContents; +use rmcp::model::Role; +use rmcp::model::{ + AnnotateAble, Content, ImageContent, PromptMessage, PromptMessageContent, PromptMessageRole, + RawContent, RawImageContent, RawTextContent, TextContent, +}; use serde::{Deserialize, Serialize}; use serde_json::Value; use utoipa::ToSchema; @@ -114,18 +116,17 @@ pub enum MessageContent { impl MessageContent { pub fn text>(text: S) -> Self { - MessageContent::Text(TextContent { - text: text.into(), - annotations: None, - }) + MessageContent::Text(RawTextContent { text: text.into() }.no_annotation()) } pub fn image, T: Into>(data: S, mime_type: T) -> Self { - MessageContent::Image(ImageContent { - data: data.into(), - mime_type: mime_type.into(), - annotations: None, - }) + MessageContent::Image( + RawImageContent { + data: data.into(), + mime_type: mime_type.into(), + } + .no_annotation(), + ) } pub fn tool_request>(id: S, tool_call: ToolResult) -> Self { @@ -220,7 +221,7 @@ impl MessageContent { if let Ok(contents) = &tool_response.tool_result { let texts: Vec = contents .iter() - .filter_map(|content| content.as_text().map(String::from)) + .filter_map(|content| content.as_text().map(|t| t.text.to_string())) .collect(); if !texts.is_empty() { return Some(texts.join("\n")); @@ -257,13 +258,25 @@ impl MessageContent { impl From for MessageContent { fn from(content: Content) -> Self { - match content { - Content::Text(text) => MessageContent::Text(text), - Content::Image(image) => MessageContent::Image(image), - Content::Resource(resource) => MessageContent::Text(TextContent { - text: resource.get_text(), - annotations: None, - }), + match content.raw { + RawContent::Text(text) => { + MessageContent::Text(text.optional_annotate(content.annotations)) + } + RawContent::Image(image) => { + MessageContent::Image(image.optional_annotate(content.annotations)) + } + RawContent::Resource(resource) => { + let text = match &resource.resource { + ResourceContents::TextResourceContents { text, .. } => text.clone(), + ResourceContents::BlobResourceContents { blob, .. } => { + format!("[Binary content: {}]", blob.clone()) + } + }; + MessageContent::text(text) + } + RawContent::Audio(_) => { + MessageContent::text("[Audio content: not supported]".to_string()) + } } } } @@ -280,16 +293,16 @@ impl From for Message { let content = match prompt_message.content { PromptMessageContent::Text { text } => MessageContent::text(text), PromptMessageContent::Image { image } => { - MessageContent::image(image.data, image.mime_type) + MessageContent::image(image.data.clone(), image.mime_type.clone()) } PromptMessageContent::Resource { resource } => { // For resources, convert to text content with the resource text - match resource.resource { + match &resource.resource { ResourceContents::TextResourceContents { text, .. } => { - MessageContent::text(text) + MessageContent::text(text.clone()) } ResourceContents::BlobResourceContents { blob, .. } => { - MessageContent::text(format!("[Binary content: {}]", blob)) + MessageContent::text(format!("[Binary content: {}]", blob.clone())) } } } @@ -512,10 +525,8 @@ impl Message { #[cfg(test)] mod tests { use super::*; - use mcp_core::content::EmbeddedResource; use mcp_core::handler::ToolError; - use mcp_core::prompt::PromptMessageContent; - use mcp_core::resource::ResourceContents; + use rmcp::model::{PromptMessage, PromptMessageContent, RawEmbeddedResource, ResourceContents}; use serde_json::{json, Value}; #[test] @@ -654,11 +665,11 @@ mod tests { #[test] fn test_from_prompt_message_image() { let prompt_content = PromptMessageContent::Image { - image: ImageContent { + image: RawImageContent { data: "base64data".to_string(), mime_type: "image/jpeg".to_string(), - annotations: None, - }, + } + .no_annotation(), }; let prompt_message = PromptMessage { @@ -685,10 +696,7 @@ mod tests { }; let prompt_content = PromptMessageContent::Resource { - resource: EmbeddedResource { - resource, - annotations: None, - }, + resource: RawEmbeddedResource { resource }.no_annotation(), }; let prompt_message = PromptMessage { @@ -714,10 +722,7 @@ mod tests { }; let prompt_content = PromptMessageContent::Resource { - resource: EmbeddedResource { - resource, - annotations: None, - }, + resource: RawEmbeddedResource { resource }.no_annotation(), }; let prompt_message = PromptMessage { diff --git a/crates/goose/src/model.rs b/crates/goose/src/model.rs index 9fe73aecdb74..59f9976370c3 100644 --- a/crates/goose/src/model.rs +++ b/crates/goose/src/model.rs @@ -275,6 +275,7 @@ mod tests { } #[test] + #[serial_test::serial] fn test_model_config_context_limit_env_vars() { use temp_env::with_vars; diff --git a/crates/goose/src/permission/permission_judge.rs b/crates/goose/src/permission/permission_judge.rs index b8b5110e1820..6a452e24a5a4 100644 --- a/crates/goose/src/permission/permission_judge.rs +++ b/crates/goose/src/permission/permission_judge.rs @@ -5,8 +5,8 @@ use crate::message::{Message, MessageContent, ToolRequest}; use crate::providers::base::Provider; use chrono::Utc; use indoc::indoc; +use mcp_core::tool::Tool; use mcp_core::tool::ToolAnnotations; -use mcp_core::{tool::Tool, TextContent}; use serde::{Deserialize, Serialize}; use serde_json::{json, Value}; use std::collections::HashSet; @@ -82,10 +82,9 @@ fn create_check_messages(tool_requests: Vec<&ToolRequest>) -> Vec { .collect(); let mut check_messages = vec![]; check_messages.push(Message::new( - mcp_core::Role::User, + rmcp::model::Role::User, Utc::now().timestamp(), - vec![MessageContent::Text(TextContent { - text: format!( + vec![MessageContent::text(format!( "Here are the tool requests: {:?}\n\nAnalyze the tool requests and list the tools that perform read-only operations. \ \n\nGuidelines for Read-Only Operations: \ \n- Read-only operations do not modify any data or state. \ @@ -93,9 +92,7 @@ fn create_check_messages(tool_requests: Vec<&ToolRequest>) -> Vec { \n- Write operations include INSERT, UPDATE, DELETE, and file writing. \ \n\nPlease provide a list of tool names that qualify as read-only:", tool_names.join(", "), - ), - annotations: None, - })], + ))], )); check_messages } @@ -270,7 +267,8 @@ mod tests { use crate::providers::errors::ProviderError; use chrono::Utc; use mcp_core::ToolCall; - use mcp_core::{tool::Tool, Role, ToolResult}; + use mcp_core::{tool::Tool, ToolResult}; + use rmcp::model::Role; use serde_json::json; use tempfile::NamedTempFile; diff --git a/crates/goose/src/project/mod.rs b/crates/goose/src/project/mod.rs new file mode 100644 index 000000000000..601b47df01c3 --- /dev/null +++ b/crates/goose/src/project/mod.rs @@ -0,0 +1,68 @@ +pub mod storage; + +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; +use std::path::PathBuf; +use utoipa::ToSchema; + +/// Main project structure that holds project metadata and associated sessions +#[derive(Debug, Clone, Serialize, Deserialize, ToSchema)] +#[serde(rename_all = "camelCase")] +pub struct Project { + /// Unique identifier for the project + pub id: String, + /// Display name of the project + pub name: String, + /// Optional description of the project + pub description: Option, + /// Default working directory for sessions in this project + #[schema(value_type = String, example = "/home/user/projects/my-project")] + pub default_directory: PathBuf, + /// When the project was created + pub created_at: DateTime, + /// When the project was last updated + pub updated_at: DateTime, + /// List of session IDs associated with this project + pub session_ids: Vec, +} + +/// Simplified project metadata for listing +#[derive(Debug, Clone, Serialize, Deserialize, ToSchema)] +#[serde(rename_all = "camelCase")] +pub struct ProjectMetadata { + /// Unique identifier for the project + pub id: String, + /// Display name of the project + pub name: String, + /// Optional description of the project + pub description: Option, + /// Default working directory for sessions in this project + #[schema(value_type = String)] + pub default_directory: PathBuf, + /// Number of sessions in this project + pub session_count: usize, + /// When the project was created + pub created_at: DateTime, + /// When the project was last updated + pub updated_at: DateTime, +} + +impl From<&Project> for ProjectMetadata { + fn from(project: &Project) -> Self { + ProjectMetadata { + id: project.id.clone(), + name: project.name.clone(), + description: project.description.clone(), + default_directory: project.default_directory.clone(), + session_count: project.session_ids.len(), + created_at: project.created_at, + updated_at: project.updated_at, + } + } +} + +// Re-export storage functions +pub use storage::{ + add_session_to_project, create_project, delete_project, ensure_project_dir, get_project, + list_projects, remove_session_from_project, update_project, +}; diff --git a/crates/goose/src/project/storage.rs b/crates/goose/src/project/storage.rs new file mode 100644 index 000000000000..ef8e70dc2465 --- /dev/null +++ b/crates/goose/src/project/storage.rs @@ -0,0 +1,239 @@ +use crate::project::{Project, ProjectMetadata}; +use anyhow::{anyhow, Context, Result}; +use chrono::Utc; +use etcetera::{choose_app_strategy, AppStrategy, AppStrategyArgs}; +use serde_json; +use std::fs::{self, File}; +use std::io::Write; +use std::path::PathBuf; +use tracing::{error, info}; + +const APP_NAME: &str = "goose"; + +/// Ensure the project directory exists and return its path +pub fn ensure_project_dir() -> Result { + let app_strategy = AppStrategyArgs { + top_level_domain: "Block".to_string(), + author: "Block".to_string(), + app_name: APP_NAME.to_string(), + }; + + let data_dir = choose_app_strategy(app_strategy) + .context("goose requires a home dir")? + .data_dir() + .join("projects"); + + if !data_dir.exists() { + fs::create_dir_all(&data_dir)?; + } + + Ok(data_dir) +} + +/// Generate a unique project ID +fn generate_project_id() -> String { + use rand::Rng; + let timestamp = Utc::now().timestamp(); + let random: u32 = rand::thread_rng().gen(); + format!("proj_{}_{}", timestamp, random) +} + +/// Get the path for a specific project file +fn get_project_path(project_id: &str) -> Result { + let project_dir = ensure_project_dir()?; + Ok(project_dir.join(format!("{}.json", project_id))) +} + +/// Create a new project +pub fn create_project( + name: String, + description: Option, + default_directory: PathBuf, +) -> Result { + let project_dir = ensure_project_dir()?; + + // Validate the default directory exists + if !default_directory.exists() { + return Err(anyhow!( + "Default directory does not exist: {:?}", + default_directory + )); + } + + let now = Utc::now(); + let project = Project { + id: generate_project_id(), + name, + description, + default_directory, + created_at: now, + updated_at: now, + session_ids: Vec::new(), + }; + + // Save the project + let project_path = project_dir.join(format!("{}.json", project.id)); + let mut file = File::create(&project_path)?; + let json = serde_json::to_string_pretty(&project)?; + file.write_all(json.as_bytes())?; + + info!("Created project {} at {:?}", project.id, project_path); + Ok(project) +} + +/// Update an existing project +pub fn update_project( + project_id: &str, + name: Option, + description: Option>, + default_directory: Option, +) -> Result { + let project_path = get_project_path(project_id)?; + + if !project_path.exists() { + return Err(anyhow!("Project not found: {}", project_id)); + } + + // Read existing project + let mut project: Project = serde_json::from_reader(File::open(&project_path)?)?; + + // Update fields + if let Some(new_name) = name { + project.name = new_name; + } + + if let Some(new_description) = description { + project.description = new_description; + } + + if let Some(new_directory) = default_directory { + if !new_directory.exists() { + return Err(anyhow!( + "Default directory does not exist: {:?}", + new_directory + )); + } + project.default_directory = new_directory; + } + + project.updated_at = Utc::now(); + + // Save updated project + let mut file = File::create(&project_path)?; + let json = serde_json::to_string_pretty(&project)?; + file.write_all(json.as_bytes())?; + + info!("Updated project {}", project_id); + Ok(project) +} + +/// Delete a project (does not delete associated sessions) +pub fn delete_project(project_id: &str) -> Result<()> { + let project_path = get_project_path(project_id)?; + + if !project_path.exists() { + return Err(anyhow!("Project not found: {}", project_id)); + } + + fs::remove_file(&project_path)?; + info!("Deleted project {}", project_id); + Ok(()) +} + +/// List all projects +pub fn list_projects() -> Result> { + let project_dir = ensure_project_dir()?; + let mut projects = Vec::new(); + + if let Ok(entries) = fs::read_dir(&project_dir) { + for entry in entries.flatten() { + let path = entry.path(); + if path.extension().and_then(|s| s.to_str()) == Some("json") { + match serde_json::from_reader::<_, Project>(File::open(&path)?) { + Ok(project) => { + projects.push(ProjectMetadata::from(&project)); + } + Err(e) => { + error!("Failed to read project file {:?}: {}", path, e); + } + } + } + } + } + + // Sort by updated_at descending + projects.sort_by(|a, b| b.updated_at.cmp(&a.updated_at)); + + Ok(projects) +} + +/// Get a specific project +pub fn get_project(project_id: &str) -> Result { + let project_path = get_project_path(project_id)?; + + if !project_path.exists() { + return Err(anyhow!("Project not found: {}", project_id)); + } + + let project: Project = serde_json::from_reader(File::open(&project_path)?)?; + Ok(project) +} + +/// Add a session to a project +pub fn add_session_to_project(project_id: &str, session_id: &str) -> Result<()> { + let project_path = get_project_path(project_id)?; + + if !project_path.exists() { + return Err(anyhow!("Project not found: {}", project_id)); + } + + // Read project + let mut project: Project = serde_json::from_reader(File::open(&project_path)?)?; + + // Check if session already exists in project + if project.session_ids.contains(&session_id.to_string()) { + return Ok(()); // Already added + } + + // Add session and update timestamp + project.session_ids.push(session_id.to_string()); + project.updated_at = Utc::now(); + + // Save updated project + let mut file = File::create(&project_path)?; + let json = serde_json::to_string_pretty(&project)?; + file.write_all(json.as_bytes())?; + + info!("Added session {} to project {}", session_id, project_id); + Ok(()) +} + +/// Remove a session from a project +pub fn remove_session_from_project(project_id: &str, session_id: &str) -> Result<()> { + let project_path = get_project_path(project_id)?; + + if !project_path.exists() { + return Err(anyhow!("Project not found: {}", project_id)); + } + + // Read project + let mut project: Project = serde_json::from_reader(File::open(&project_path)?)?; + + // Remove session + let original_len = project.session_ids.len(); + project.session_ids.retain(|id| id != session_id); + + if project.session_ids.len() == original_len { + return Ok(()); // Session wasn't in project + } + + project.updated_at = Utc::now(); + + // Save updated project + let mut file = File::create(&project_path)?; + let json = serde_json::to_string_pretty(&project)?; + file.write_all(json.as_bytes())?; + + info!("Removed session {} from project {}", session_id, project_id); + Ok(()) +} diff --git a/crates/goose/src/providers/anthropic.rs b/crates/goose/src/providers/anthropic.rs index 88a71b0f145c..6a3c930460ce 100644 --- a/crates/goose/src/providers/anthropic.rs +++ b/crates/goose/src/providers/anthropic.rs @@ -1,13 +1,21 @@ use anyhow::Result; +use async_stream::try_stream; use async_trait::async_trait; use axum::http::HeaderMap; +use futures::TryStreamExt; use reqwest::{Client, StatusCode}; use serde_json::Value; +use std::io; use std::time::Duration; +use tokio::pin; -use super::base::{ConfigKey, ModelInfo, Provider, ProviderMetadata, ProviderUsage}; +use tokio_util::io::StreamReader; + +use super::base::{ConfigKey, MessageStream, ModelInfo, Provider, ProviderMetadata, ProviderUsage}; use super::errors::ProviderError; -use super::formats::anthropic::{create_request, get_usage, response_to_message}; +use super::formats::anthropic::{ + create_request, get_usage, response_to_message, response_to_streaming_message, +}; use super::utils::{emit_debug_trace, get_model}; use crate::message::Message; use crate::model::ModelConfig; @@ -195,10 +203,17 @@ impl Provider for AnthropicProvider { // Parse response let message = response_to_message(response.clone())?; let usage = get_usage(&response)?; + tracing::debug!("🔍 Anthropic non-streaming parsed usage: input_tokens={:?}, output_tokens={:?}, total_tokens={:?}", + usage.input_tokens, usage.output_tokens, usage.total_tokens); let model = get_model(&response); emit_debug_trace(&self.model, &payload, &response, &usage); - Ok((message, ProviderUsage::new(model, usage))) + let provider_usage = ProviderUsage::new(model, usage); + tracing::debug!( + "🔍 Anthropic non-streaming returning ProviderUsage: {:?}", + provider_usage + ); + Ok((message, provider_usage)) } /// Fetch supported models from Anthropic; returns Err on failure, Ok(None) if not present @@ -232,4 +247,82 @@ impl Provider for AnthropicProvider { models.sort(); Ok(Some(models)) } + + async fn stream( + &self, + system: &str, + messages: &[Message], + tools: &[Tool], + ) -> Result { + let mut payload = create_request(&self.model, system, messages, tools)?; + + // Add stream parameter + payload + .as_object_mut() + .unwrap() + .insert("stream".to_string(), Value::Bool(true)); + + let mut headers = reqwest::header::HeaderMap::new(); + headers.insert("x-api-key", self.api_key.parse().unwrap()); + headers.insert("anthropic-version", ANTHROPIC_API_VERSION.parse().unwrap()); + + let is_thinking_enabled = std::env::var("CLAUDE_THINKING_ENABLED").is_ok(); + if self.model.model_name.starts_with("claude-3-7-sonnet-") && is_thinking_enabled { + // https://docs.anthropic.com/en/docs/build-with-claude/extended-thinking#extended-output-capabilities-beta + headers.insert("anthropic-beta", "output-128k-2025-02-19".parse().unwrap()); + } + + if self.model.model_name.starts_with("claude-3-7-sonnet-") { + // https://docs.anthropic.com/en/docs/build-with-claude/tool-use/token-efficient-tool-use + headers.insert( + "anthropic-beta", + "token-efficient-tools-2025-02-19".parse().unwrap(), + ); + } + + let base_url = url::Url::parse(&self.host) + .map_err(|e| ProviderError::RequestFailed(format!("Invalid base URL: {e}")))?; + let url = base_url.join("v1/messages").map_err(|e| { + ProviderError::RequestFailed(format!("Failed to construct endpoint URL: {e}")) + })?; + + let response = self + .client + .post(url) + .headers(headers) + .json(&payload) + .send() + .await?; + + if !response.status().is_success() { + let status = response.status(); + let error_text = response.text().await.unwrap_or_default(); + return Err(ProviderError::RequestFailed(format!( + "Streaming request failed with status: {}. Error: {}", + status, error_text + ))); + } + + // Map reqwest error to io::Error + let stream = response.bytes_stream().map_err(io::Error::other); + + let model_config = self.model.clone(); + // Wrap in a line decoder and yield lines inside the stream + Ok(Box::pin(try_stream! { + let stream_reader = StreamReader::new(stream); + let framed = tokio_util::codec::FramedRead::new(stream_reader, tokio_util::codec::LinesCodec::new()).map_err(anyhow::Error::from); + + let message_stream = response_to_streaming_message(framed); + pin!(message_stream); + while let Some(message) = futures::StreamExt::next(&mut message_stream).await { + let (message, usage) = message.map_err(|e| ProviderError::RequestFailed(format!("Stream decode error: {}", e)))?; + super::utils::emit_debug_trace(&model_config, &payload, &message, &usage.as_ref().map(|f| f.usage).unwrap_or_default()); + yield (message, usage); + } + })) + } + + fn supports_streaming(&self) -> bool { + true + } } diff --git a/crates/goose/src/providers/azure.rs b/crates/goose/src/providers/azure.rs index 3825ef21a2ec..46c1f0062d88 100644 --- a/crates/goose/src/providers/azure.rs +++ b/crates/goose/src/providers/azure.rs @@ -252,14 +252,10 @@ impl Provider for AzureProvider { let response = self.post(payload.clone()).await?; let message = response_to_message(response.clone())?; - let usage = match get_usage(&response) { - Ok(usage) => usage, - Err(ProviderError::UsageError(e)) => { - tracing::debug!("Failed to get usage data: {}", e); - Usage::default() - } - Err(e) => return Err(e), - }; + let usage = response.get("usage").map(get_usage).unwrap_or_else(|| { + tracing::debug!("Failed to get usage data"); + Usage::default() + }); let model = get_model(&response); emit_debug_trace(&self.model, &payload, &response, &usage); Ok((message, ProviderUsage::new(model, usage))) diff --git a/crates/goose/src/providers/base.rs b/crates/goose/src/providers/base.rs index c3510fe88008..260e1f22f581 100644 --- a/crates/goose/src/providers/base.rs +++ b/crates/goose/src/providers/base.rs @@ -41,6 +41,8 @@ pub struct ModelInfo { pub output_token_cost: Option, /// Currency for the costs (default: "$") pub currency: Option, + /// Whether this model supports cache control + pub supports_cache_control: Option, } impl ModelInfo { @@ -52,6 +54,7 @@ impl ModelInfo { input_token_cost: None, output_token_cost: None, currency: None, + supports_cache_control: None, } } @@ -68,6 +71,7 @@ impl ModelInfo { input_token_cost: Some(input_cost), output_token_cost: Some(output_cost), currency: Some("$".to_string()), + supports_cache_control: None, } } } @@ -115,6 +119,7 @@ impl ProviderMetadata { input_token_cost: None, output_token_cost: None, currency: None, + supports_cache_control: None, }) .collect(), model_doc_link: model_doc_link.to_string(), @@ -290,6 +295,11 @@ pub trait Provider: Send + Sync { false } + /// Check if this provider supports cache control + fn supports_cache_control(&self) -> bool { + false + } + /// Create embeddings if supported. Default implementation returns an error. async fn create_embeddings(&self, _texts: Vec) -> Result>, ProviderError> { Err(ProviderError::ExecutionError( @@ -435,6 +445,7 @@ mod tests { input_token_cost: None, output_token_cost: None, currency: None, + supports_cache_control: None, }; assert_eq!(info.context_limit, 1000); @@ -445,6 +456,7 @@ mod tests { input_token_cost: None, output_token_cost: None, currency: None, + supports_cache_control: None, }; assert_eq!(info, info2); @@ -455,6 +467,7 @@ mod tests { input_token_cost: None, output_token_cost: None, currency: None, + supports_cache_control: None, }; assert_ne!(info, info3); } diff --git a/crates/goose/src/providers/claude_code.rs b/crates/goose/src/providers/claude_code.rs index 0ee203360178..563ea89d0d41 100644 --- a/crates/goose/src/providers/claude_code.rs +++ b/crates/goose/src/providers/claude_code.rs @@ -10,9 +10,8 @@ use super::errors::ProviderError; use super::utils::emit_debug_trace; use crate::message::{Message, MessageContent}; use crate::model::ModelConfig; -use mcp_core::content::TextContent; use mcp_core::tool::Tool; -use mcp_core::Role; +use rmcp::model::Role; pub const CLAUDE_CODE_DEFAULT_MODEL: &str = "default"; pub const CLAUDE_CODE_KNOWN_MODELS: &[&str] = &["default"]; @@ -98,7 +97,7 @@ impl ClaudeCodeProvider { // Convert tool result contents to text let content_text = tool_contents .iter() - .filter_map(|content| content.as_text()) + .filter_map(|content| content.as_text().map(|t| t.text.clone())) .collect::>() .join("\n"); @@ -214,10 +213,7 @@ impl ClaudeCodeProvider { )); } - let message_content = vec![MessageContent::Text(TextContent { - text: combined_text, - annotations: None, - })]; + let message_content = vec![MessageContent::text(combined_text)]; let response_message = Message::new( Role::Assistant, @@ -330,7 +326,7 @@ impl ClaudeCodeProvider { // Extract the first user message text let description = messages .iter() - .find(|m| m.role == mcp_core::Role::User) + .find(|m| m.role == rmcp::model::Role::User) .and_then(|m| { m.content.iter().find_map(|c| match c { MessageContent::Text(text_content) => Some(&text_content.text), @@ -354,12 +350,9 @@ impl ClaudeCodeProvider { } let message = Message::new( - mcp_core::Role::Assistant, + rmcp::model::Role::Assistant, chrono::Utc::now().timestamp(), - vec![MessageContent::Text(mcp_core::content::TextContent { - text: description.clone(), - annotations: None, - })], + vec![MessageContent::text(description.clone())], ); let usage = Usage::default(); diff --git a/crates/goose/src/providers/databricks.rs b/crates/goose/src/providers/databricks.rs index e074d9393af8..99768c1972ae 100644 --- a/crates/goose/src/providers/databricks.rs +++ b/crates/goose/src/providers/databricks.rs @@ -13,13 +13,13 @@ use tokio_util::io::StreamReader; use super::base::{ConfigKey, MessageStream, Provider, ProviderMetadata, ProviderUsage, Usage}; use super::embedding::EmbeddingCapable; use super::errors::ProviderError; -use super::formats::databricks::{create_request, get_usage, response_to_message}; +use super::formats::databricks::{create_request, response_to_message}; use super::oauth; use super::utils::{get_model, ImageFormat}; use crate::config::ConfigError; use crate::message::Message; use crate::model::ModelConfig; -use crate::providers::formats::databricks::response_to_streaming_message; +use crate::providers::formats::openai::{get_usage, response_to_streaming_message}; use mcp_core::tool::Tool; use serde_json::json; use tokio::time::sleep; @@ -455,13 +455,10 @@ impl Provider for DatabricksProvider { // Parse response let message = response_to_message(response.clone())?; - let usage = match response.get("usage").map(get_usage) { - Some(usage) => usage, - None => { - tracing::debug!("Failed to get usage data"); - Usage::default() - } - }; + let usage = response.get("usage").map(get_usage).unwrap_or_else(|| { + tracing::debug!("Failed to get usage data"); + Usage::default() + }); let model = get_model(&response); super::utils::emit_debug_trace(&self.model, &payload, &response, &usage); diff --git a/crates/goose/src/providers/factory.rs b/crates/goose/src/providers/factory.rs index e627505b657e..091425530df4 100644 --- a/crates/goose/src/providers/factory.rs +++ b/crates/goose/src/providers/factory.rs @@ -12,6 +12,7 @@ use super::{ google::GoogleProvider, groq::GroqProvider, lead_worker::LeadWorkerProvider, + litellm::LiteLLMProvider, ollama::OllamaProvider, openai::OpenAiProvider, openrouter::OpenRouterProvider, @@ -50,6 +51,7 @@ pub fn providers() -> Vec { // GithubCopilotProvider::metadata(), GoogleProvider::metadata(), GroqProvider::metadata(), + LiteLLMProvider::metadata(), OllamaProvider::metadata(), OpenAiProvider::metadata(), OpenRouterProvider::metadata(), @@ -158,6 +160,7 @@ fn create_provider(name: &str, model: ModelConfig) -> Result> "databricks" => Ok(Arc::new(DatabricksProvider::from_env(model)?)), "gemini-cli" => Ok(Arc::new(GeminiCliProvider::from_env(model)?)), "groq" => Ok(Arc::new(GroqProvider::from_env(model)?)), + "litellm" => Ok(Arc::new(LiteLLMProvider::from_env(model)?)), "ollama" => Ok(Arc::new(OllamaProvider::from_env(model)?)), "openrouter" => Ok(Arc::new(OpenRouterProvider::from_env(model)?)), "gcp_vertex_ai" => Ok(Arc::new(GcpVertexAIProvider::from_env(model)?)), @@ -177,7 +180,7 @@ mod tests { use crate::message::{Message, MessageContent}; use crate::providers::base::{ProviderMetadata, ProviderUsage, Usage}; use chrono::Utc; - use mcp_core::{content::TextContent, Role}; + use rmcp::model::{AnnotateAble, RawTextContent, Role}; use std::env; #[allow(dead_code)] @@ -215,13 +218,15 @@ mod tests { Message::new( Role::Assistant, Utc::now().timestamp(), - vec![MessageContent::Text(TextContent { - text: format!( - "Response from {} with model {}", - self.name, self.model_config.model_name - ), - annotations: None, - })], + vec![MessageContent::Text( + RawTextContent { + text: format!( + "Response from {} with model {}", + self.name, self.model_config.model_name + ), + } + .no_annotation(), + )], ), ProviderUsage::new(self.model_config.model_name.clone(), Usage::default()), )) diff --git a/crates/goose/src/providers/formats/anthropic.rs b/crates/goose/src/providers/formats/anthropic.rs index 661d2bce5156..3001a08d0387 100644 --- a/crates/goose/src/providers/formats/anthropic.rs +++ b/crates/goose/src/providers/formats/anthropic.rs @@ -3,9 +3,8 @@ use crate::model::ModelConfig; use crate::providers::base::Usage; use crate::providers::errors::ProviderError; use anyhow::{anyhow, Result}; -use mcp_core::content::Content; -use mcp_core::role::Role; use mcp_core::tool::{Tool, ToolCall}; +use rmcp::model::Role; use serde_json::{json, Value}; use std::collections::HashSet; @@ -69,10 +68,7 @@ pub fn format_messages(messages: &[Message]) -> Vec { Ok(result) => { let text = result .iter() - .filter_map(|c| match c { - Content::Text(t) => Some(t.text.clone()), - _ => None, - }) + .filter_map(|c| c.as_text().map(|t| t.text.clone())) .collect::>() .join("\n"); @@ -292,40 +288,68 @@ pub fn get_usage(data: &Value) -> Result { .and_then(|v| v.as_u64()) .unwrap_or(0); - // IMPORTANT: Based on the API responses, when caching is used: - // - input_tokens is ONLY the new/fresh tokens (can be very small, like 7) - // - cache_creation_input_tokens and cache_read_input_tokens are the cached content - // - These cached tokens are charged at different rates: - // * Fresh input tokens: 100% of regular price - // * Cache creation tokens: 125% of regular price - // * Cache read tokens: 10% of regular price - // - // Calculate effective input tokens for cost calculation based on Anthropic's pricing: - // - Fresh input tokens: 100% of regular price (1.0x) - // - Cache creation tokens: 125% of regular price (1.25x) - // - Cache read tokens: 10% of regular price (0.10x) - // - // The effective input tokens represent the cost-equivalent tokens when multiplied - // by the regular input price, ensuring accurate cost calculations in the frontend. - let effective_input_tokens = input_tokens as f64 * 1.0 - + cache_creation_tokens as f64 * 1.25 - + cache_read_tokens as f64 * 0.10; - - // For token counting purposes, we still want to show the actual total count - let _total_actual_tokens = input_tokens + cache_creation_tokens + cache_read_tokens; - - // Return the effective input tokens for cost calculation - // This ensures the frontend cost calculation is accurate when multiplying by regular prices - let effective_input_i32 = effective_input_tokens.round().clamp(0.0, i32::MAX as f64) as i32; + // IMPORTANT: For display purposes, we want to show the ACTUAL total tokens consumed + // The cache pricing should only affect cost calculation, not token count display + let total_input_tokens = input_tokens + cache_creation_tokens + cache_read_tokens; + + // Convert to i32 with bounds checking + let total_input_i32 = total_input_tokens.min(i32::MAX as u64) as i32; let output_tokens_i32 = output_tokens.min(i32::MAX as u64) as i32; let total_tokens_i32 = - (effective_input_i32 as i64 + output_tokens_i32 as i64).min(i32::MAX as i64) as i32; + (total_input_i32 as i64 + output_tokens_i32 as i64).min(i32::MAX as i64) as i32; Ok(Usage::new( - Some(effective_input_i32), + Some(total_input_i32), Some(output_tokens_i32), Some(total_tokens_i32), )) + } else if data.as_object().is_some() { + // Check if the data itself is the usage object (for message_delta events that might have usage at top level) + let input_tokens = data + .get("input_tokens") + .and_then(|v| v.as_u64()) + .unwrap_or(0); + + let cache_creation_tokens = data + .get("cache_creation_input_tokens") + .and_then(|v| v.as_u64()) + .unwrap_or(0); + + let cache_read_tokens = data + .get("cache_read_input_tokens") + .and_then(|v| v.as_u64()) + .unwrap_or(0); + + let output_tokens = data + .get("output_tokens") + .and_then(|v| v.as_u64()) + .unwrap_or(0); + + // If we found any token data, process it + if input_tokens > 0 + || cache_creation_tokens > 0 + || cache_read_tokens > 0 + || output_tokens > 0 + { + let total_input_tokens = input_tokens + cache_creation_tokens + cache_read_tokens; + + let total_input_i32 = total_input_tokens.min(i32::MAX as u64) as i32; + let output_tokens_i32 = output_tokens.min(i32::MAX as u64) as i32; + let total_tokens_i32 = + (total_input_i32 as i64 + output_tokens_i32 as i64).min(i32::MAX as i64) as i32; + + tracing::debug!("🔍 Anthropic ACTUAL token counts from direct object: input={}, output={}, total={}", + total_input_i32, output_tokens_i32, total_tokens_i32); + + Ok(Usage::new( + Some(total_input_i32), + Some(output_tokens_i32), + Some(total_tokens_i32), + )) + } else { + tracing::debug!("🔍 Anthropic no token data found in object"); + Ok(Usage::new(None, None, None)) + } } else { tracing::debug!( "Failed to get usage data: {}", @@ -414,6 +438,232 @@ pub fn create_request( Ok(payload) } +/// Process streaming response from Anthropic's API +pub fn response_to_streaming_message( + mut stream: S, +) -> impl futures::Stream< + Item = anyhow::Result<( + Option, + Option, + )>, +> + 'static +where + S: futures::Stream> + Unpin + Send + 'static, +{ + use async_stream::try_stream; + use futures::StreamExt; + use serde::{Deserialize, Serialize}; + + #[derive(Serialize, Deserialize, Debug)] + struct StreamingEvent { + #[serde(rename = "type")] + event_type: String, + #[serde(flatten)] + data: Value, + } + + try_stream! { + let mut accumulated_text = String::new(); + let mut accumulated_tool_calls: std::collections::HashMap = std::collections::HashMap::new(); + let mut current_tool_id: Option = None; + let mut final_usage: Option = None; + + while let Some(line_result) = stream.next().await { + let line = line_result?; + + // Skip empty lines and non-data lines + if line.trim().is_empty() || !line.starts_with("data: ") { + continue; + } + + let data_part = line.strip_prefix("data: ").unwrap_or(&line); + + // Handle end of stream + if data_part.trim() == "[DONE]" { + break; + } + + // Parse the JSON event + let event: StreamingEvent = match serde_json::from_str(data_part) { + Ok(event) => event, + Err(e) => { + tracing::debug!("Failed to parse streaming event: {} - Line: {}", e, data_part); + continue; + } + }; + + match event.event_type.as_str() { + "message_start" => { + // Message started, we can extract initial metadata and usage if needed + if let Some(message_data) = event.data.get("message") { + if let Some(usage_data) = message_data.get("usage") { + let usage = get_usage(usage_data).unwrap_or_default(); + tracing::debug!("🔍 Anthropic message_start parsed usage: input_tokens={:?}, output_tokens={:?}, total_tokens={:?}", + usage.input_tokens, usage.output_tokens, usage.total_tokens); + let model = message_data.get("model") + .and_then(|v| v.as_str()) + .unwrap_or("unknown") + .to_string(); + final_usage = Some(crate::providers::base::ProviderUsage::new(model, usage)); + } else { + tracing::debug!("🔍 Anthropic message_start has no usage data"); + } + } + continue; + } + "content_block_start" => { + // A new content block started + if let Some(content_block) = event.data.get("content_block") { + if content_block.get("type") == Some(&json!("tool_use")) { + if let Some(id) = content_block.get("id").and_then(|v| v.as_str()) { + current_tool_id = Some(id.to_string()); + if let Some(name) = content_block.get("name").and_then(|v| v.as_str()) { + accumulated_tool_calls.insert(id.to_string(), (name.to_string(), String::new())); + } + } + } + } + continue; + } + "content_block_delta" => { + if let Some(delta) = event.data.get("delta") { + if delta.get("type") == Some(&json!("text_delta")) { + // Text content delta + if let Some(text) = delta.get("text").and_then(|v| v.as_str()) { + accumulated_text.push_str(text); + + // Yield partial text message + let message = Message::new( + Role::Assistant, + chrono::Utc::now().timestamp(), + vec![MessageContent::text(text)], + ); + yield (Some(message), None); + } + } else if delta.get("type") == Some(&json!("input_json_delta")) { + // Tool input delta + if let Some(tool_id) = ¤t_tool_id { + if let Some(partial_json) = delta.get("partial_json").and_then(|v| v.as_str()) { + if let Some((_name, args)) = accumulated_tool_calls.get_mut(tool_id) { + args.push_str(partial_json); + } + } + } + } + } + continue; + } + "content_block_stop" => { + // Content block finished + if let Some(tool_id) = current_tool_id.take() { + // Tool call finished, yield complete tool call + if let Some((name, args)) = accumulated_tool_calls.remove(&tool_id) { + let parsed_args = if args.is_empty() { + json!({}) + } else { + match serde_json::from_str::(&args) { + Ok(parsed) => parsed, + Err(_) => { + // If parsing fails, create an error tool request + let error = mcp_core::handler::ToolError::InvalidParameters( + format!("Could not parse tool arguments: {}", args) + ); + let message = Message::new( + Role::Assistant, + chrono::Utc::now().timestamp(), + vec![MessageContent::tool_request(tool_id, Err(error))], + ); + yield (Some(message), None); + continue; + } + } + }; + + let tool_call = ToolCall::new(&name, parsed_args); + let message = Message::new( + rmcp::model::Role::Assistant, + chrono::Utc::now().timestamp(), + vec![MessageContent::tool_request(tool_id, Ok(tool_call))], + ); + yield (Some(message), None); + } + } + continue; + } + "message_delta" => { + // Message metadata delta (like stop_reason) and cumulative usage + tracing::debug!("🔍 Anthropic message_delta event data: {}", serde_json::to_string_pretty(&event.data).unwrap_or_else(|_| format!("{:?}", event.data))); + if let Some(usage_data) = event.data.get("usage") { + tracing::debug!("🔍 Anthropic message_delta usage data (cumulative): {}", serde_json::to_string_pretty(usage_data).unwrap_or_else(|_| format!("{:?}", usage_data))); + let delta_usage = get_usage(usage_data).unwrap_or_default(); + tracing::debug!("🔍 Anthropic message_delta parsed usage: input_tokens={:?}, output_tokens={:?}, total_tokens={:?}", + delta_usage.input_tokens, delta_usage.output_tokens, delta_usage.total_tokens); + + // IMPORTANT: message_delta usage should be MERGED with existing usage, not replace it + // message_start has input tokens, message_delta has output tokens + if let Some(existing_usage) = &final_usage { + let merged_input = existing_usage.usage.input_tokens.or(delta_usage.input_tokens); + let merged_output = delta_usage.output_tokens.or(existing_usage.usage.output_tokens); + let merged_total = match (merged_input, merged_output) { + (Some(input), Some(output)) => Some(input + output), + (Some(input), None) => Some(input), + (None, Some(output)) => Some(output), + (None, None) => None, + }; + + let merged_usage = crate::providers::base::Usage::new(merged_input, merged_output, merged_total); + final_usage = Some(crate::providers::base::ProviderUsage::new(existing_usage.model.clone(), merged_usage)); + tracing::debug!("🔍 Anthropic MERGED usage: input_tokens={:?}, output_tokens={:?}, total_tokens={:?}", + merged_input, merged_output, merged_total); + } else { + // No existing usage, just use delta usage + let model = event.data.get("model") + .and_then(|v| v.as_str()) + .unwrap_or("unknown") + .to_string(); + final_usage = Some(crate::providers::base::ProviderUsage::new(model, delta_usage)); + tracing::debug!("🔍 Anthropic no existing usage, using delta usage"); + } + } else { + tracing::debug!("🔍 Anthropic message_delta event has no usage field"); + } + continue; + } + "message_stop" => { + // Message finished, extract final usage if available + if let Some(usage_data) = event.data.get("usage") { + tracing::debug!("🔍 Anthropic streaming usage data: {}", serde_json::to_string_pretty(usage_data).unwrap_or_else(|_| format!("{:?}", usage_data))); + let usage = get_usage(usage_data).unwrap_or_default(); + tracing::debug!("🔍 Anthropic parsed usage: input_tokens={:?}, output_tokens={:?}, total_tokens={:?}", + usage.input_tokens, usage.output_tokens, usage.total_tokens); + let model = event.data.get("model") + .and_then(|v| v.as_str()) + .unwrap_or("unknown") + .to_string(); + tracing::debug!("🔍 Anthropic final_usage created with model: {}", model); + final_usage = Some(crate::providers::base::ProviderUsage::new(model, usage)); + } else { + tracing::debug!("🔍 Anthropic message_stop event has no usage data"); + } + break; + } + _ => { + // Unknown event type, log and continue + tracing::debug!("Unknown streaming event type: {}", event.event_type); + continue; + } + } + } + + // Yield final usage information if available + if let Some(usage) = final_usage { + yield (None, Some(usage)); + } else { + tracing::debug!("🔍 Anthropic no final usage to yield"); + } + } +} + #[cfg(test)] mod tests { use super::*; @@ -449,9 +699,9 @@ mod tests { panic!("Expected Text content"); } - assert_eq!(usage.input_tokens, Some(27)); // 12 * 1.0 + 12 * 1.25 = 27 effective tokens + assert_eq!(usage.input_tokens, Some(24)); // 12 + 12 = 24 actual tokens assert_eq!(usage.output_tokens, Some(15)); - assert_eq!(usage.total_tokens, Some(42)); // 27 + 15 + assert_eq!(usage.total_tokens, Some(39)); // 24 + 15 Ok(()) } @@ -492,9 +742,9 @@ mod tests { panic!("Expected ToolRequest content"); } - assert_eq!(usage.input_tokens, Some(34)); // 15 * 1.0 + 15 * 1.25 = 33.75 → 34 effective tokens + assert_eq!(usage.input_tokens, Some(30)); // 15 + 15 = 30 actual tokens assert_eq!(usage.output_tokens, Some(20)); - assert_eq!(usage.total_tokens, Some(54)); // 34 + 20 + assert_eq!(usage.total_tokens, Some(50)); // 30 + 20 Ok(()) } @@ -718,11 +968,11 @@ mod tests { let usage = get_usage(&response)?; - // Effective input tokens should be: - // 7 * 1.0 + 10000 * 1.25 + 5000 * 0.10 = 7 + 12500 + 500 = 13007 - assert_eq!(usage.input_tokens, Some(13007)); + // ACTUAL input tokens should be: + // 7 + 10000 + 5000 = 15007 total actual tokens + assert_eq!(usage.input_tokens, Some(15007)); assert_eq!(usage.output_tokens, Some(50)); - assert_eq!(usage.total_tokens, Some(13057)); // 13007 + 50 + assert_eq!(usage.total_tokens, Some(15057)); // 15007 + 50 Ok(()) } diff --git a/crates/goose/src/providers/formats/bedrock.rs b/crates/goose/src/providers/formats/bedrock.rs index 9a1651b57fc3..ae8840f27fbf 100644 --- a/crates/goose/src/providers/formats/bedrock.rs +++ b/crates/goose/src/providers/formats/bedrock.rs @@ -6,12 +6,12 @@ use aws_sdk_bedrockruntime::types as bedrock; use aws_smithy_types::{Document, Number}; use base64::Engine; use chrono::Utc; -use mcp_core::{Content, ResourceContents, Role, Tool, ToolCall, ToolError, ToolResult}; +use mcp_core::{Tool, ToolCall, ToolError, ToolResult}; +use rmcp::model::{Content, RawContent, ResourceContents, Role}; use serde_json::Value; use super::super::base::Usage; use crate::message::{Message, MessageContent}; -use mcp_core::content::ImageContent; pub fn to_bedrock_message(message: &Message) -> Result { bedrock::Message::builder() @@ -33,7 +33,9 @@ pub fn to_bedrock_message_content(content: &MessageContent) -> Result { bedrock::ContentBlock::Text("".to_string()) } - MessageContent::Image(image) => bedrock::ContentBlock::Image(to_bedrock_image(image)?), + MessageContent::Image(image) => { + bedrock::ContentBlock::Image(to_bedrock_image(&image.data, &image.mime_type)?) + } MessageContent::Thinking(_) => { // Thinking blocks are not supported in Bedrock - skip bedrock::ContentBlock::Text("".to_string()) @@ -88,7 +90,7 @@ pub fn to_bedrock_message_content(content: &MessageContent) -> Result>()?, ), Err(_) => None, @@ -114,12 +116,14 @@ pub fn to_bedrock_message_content(content: &MessageContent) -> Result Result { - Ok(match content { - Content::Text(text) => bedrock::ToolResultContentBlock::Text(text.text.to_string()), - Content::Image(image) => bedrock::ToolResultContentBlock::Image(to_bedrock_image(image)?), - Content::Resource(resource) => match &resource.resource { + Ok(match content.raw { + RawContent::Text(text) => bedrock::ToolResultContentBlock::Text(text.text), + RawContent::Image(image) => { + bedrock::ToolResultContentBlock::Image(to_bedrock_image(&image.data, &image.mime_type)?) + } + RawContent::Resource(resource) => match &resource.resource { ResourceContents::TextResourceContents { text, .. } => { match to_bedrock_document(tool_use_id, &resource.resource)? { Some(doc) => bedrock::ToolResultContentBlock::Document(doc), @@ -130,6 +134,7 @@ pub fn to_bedrock_tool_result_content_block( bail!("Blob resource content is not supported by Bedrock provider yet") } }, + RawContent::Audio(..) => bail!("Audio is not not supported by Bedrock provider"), }) } @@ -140,23 +145,23 @@ pub fn to_bedrock_role(role: &Role) -> bedrock::ConversationRole { } } -pub fn to_bedrock_image(image: &ImageContent) -> Result { +pub fn to_bedrock_image(data: &String, mime_type: &String) -> Result { // Extract format from MIME type - let format = match image.mime_type.as_str() { + let format = match mime_type.as_str() { "image/png" => bedrock::ImageFormat::Png, "image/jpeg" | "image/jpg" => bedrock::ImageFormat::Jpeg, "image/gif" => bedrock::ImageFormat::Gif, "image/webp" => bedrock::ImageFormat::Webp, _ => bail!( "Unsupported image format: {}. Bedrock supports png, jpeg, gif, webp", - image.mime_type + mime_type ), }; // Create image source with base64 data let source = bedrock::ImageSource::Bytes(aws_smithy_types::Blob::new( base64::prelude::BASE64_STANDARD - .decode(&image.data) + .decode(data) .map_err(|e| anyhow!("Failed to decode base64 image data: {}", e))?, )); @@ -347,7 +352,7 @@ pub fn from_bedrock_json(document: &Document) -> Result { mod tests { use super::*; use anyhow::Result; - use mcp_core::content::ImageContent; + use rmcp::model::{AnnotateAble, RawImageContent}; // Base64 encoded 1x1 PNG image for testing const TEST_IMAGE_BASE64: &str = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8/5+hHgAHggJ/PchI7wAAAABJRU5ErkJggg=="; @@ -363,13 +368,13 @@ mod tests { ]; for mime_type in supported_formats { - let image = ImageContent { + let image = RawImageContent { data: TEST_IMAGE_BASE64.to_string(), mime_type: mime_type.to_string(), - annotations: None, - }; + } + .no_annotation(); - let result = to_bedrock_image(&image); + let result = to_bedrock_image(&image.data, &image.mime_type); assert!(result.is_ok(), "Failed to convert {} format", mime_type); } @@ -378,13 +383,13 @@ mod tests { #[test] fn test_to_bedrock_image_unsupported_format() { - let image = ImageContent { + let image = RawImageContent { data: TEST_IMAGE_BASE64.to_string(), mime_type: "image/bmp".to_string(), - annotations: None, - }; + } + .no_annotation(); - let result = to_bedrock_image(&image); + let result = to_bedrock_image(&image.data, &image.mime_type); assert!(result.is_err()); let error_msg = result.unwrap_err().to_string(); assert!(error_msg.contains("Unsupported image format: image/bmp")); @@ -393,13 +398,13 @@ mod tests { #[test] fn test_to_bedrock_image_invalid_base64() { - let image = ImageContent { + let image = RawImageContent { data: "invalid_base64_data!!!".to_string(), mime_type: "image/png".to_string(), - annotations: None, - }; + } + .no_annotation(); - let result = to_bedrock_image(&image); + let result = to_bedrock_image(&image.data, &image.mime_type); assert!(result.is_err()); let error_msg = result.unwrap_err().to_string(); assert!(error_msg.contains("Failed to decode base64 image data")); @@ -407,11 +412,11 @@ mod tests { #[test] fn test_to_bedrock_message_content_image() -> Result<()> { - let image = ImageContent { + let image = RawImageContent { data: TEST_IMAGE_BASE64.to_string(), mime_type: "image/png".to_string(), - annotations: None, - }; + } + .no_annotation(); let message_content = MessageContent::Image(image); let result = to_bedrock_message_content(&message_content)?; @@ -424,14 +429,8 @@ mod tests { #[test] fn test_to_bedrock_tool_result_content_block_image() -> Result<()> { - let image = ImageContent { - data: TEST_IMAGE_BASE64.to_string(), - mime_type: "image/png".to_string(), - annotations: None, - }; - - let content = Content::Image(image); - let result = to_bedrock_tool_result_content_block("test_id", &content)?; + let content = Content::image(TEST_IMAGE_BASE64.to_string(), "image/png".to_string()); + let result = to_bedrock_tool_result_content_block("test_id", content)?; // Verify the wrapper correctly converts Content::Image to ToolResultContentBlock::Image assert!(matches!(result, bedrock::ToolResultContentBlock::Image(_))); diff --git a/crates/goose/src/providers/formats/databricks.rs b/crates/goose/src/providers/formats/databricks.rs index 8c462b624e9b..10a59cc1b6a0 100644 --- a/crates/goose/src/providers/formats/databricks.rs +++ b/crates/goose/src/providers/formats/databricks.rs @@ -1,15 +1,14 @@ use crate::message::{Message, MessageContent}; use crate::model::ModelConfig; -use crate::providers::base::{ProviderUsage, Usage}; use crate::providers::utils::{ convert_image, detect_image_path, is_valid_function_name, load_image_file, sanitize_function_name, ImageFormat, }; use anyhow::{anyhow, Error}; -use async_stream::try_stream; -use futures::Stream; use mcp_core::ToolError; -use mcp_core::{Content, Role, Tool, ToolCall}; +use mcp_core::{Tool, ToolCall}; +use rmcp::model::Role; +use rmcp::model::{AnnotateAble, Content, RawContent, ResourceContents}; use serde::{Deserialize, Serialize}; use serde_json::{json, Value}; @@ -129,7 +128,7 @@ pub fn format_messages(messages: &[Message], image_format: &ImageFormat) -> Vec< .audience() .is_none_or(|audience| audience.contains(&Role::Assistant)) }) - .map(|content| content.unannotated()) + .map(|content| content.raw.clone()) .collect(); // Process all content, replacing images with placeholder text @@ -138,30 +137,33 @@ pub fn format_messages(messages: &[Message], image_format: &ImageFormat) -> Vec< for content in abridged { match content { - Content::Image(image) => { + RawContent::Image(image) => { // Add placeholder text in the tool response tool_content.push(Content::text("This tool result included an image that is uploaded in the next message.")); // Create a separate image message image_messages.push(json!({ "role": "user", - "content": [convert_image(&image, image_format)] + "content": [convert_image(&image.no_annotation(), image_format)] })); } - Content::Resource(resource) => { - tool_content.push(Content::text(resource.get_text())); + RawContent::Resource(resource) => { + let text = match &resource.resource { + ResourceContents::TextResourceContents { + text, .. + } => text.clone(), + _ => String::new(), + }; + tool_content.push(Content::text(text)); } _ => { - tool_content.push(content); + tool_content.push(content.no_annotation()); } } } let tool_response_content: Value = json!(tool_content .iter() - .map(|content| match content { - Content::Text(text) => text.text.clone(), - _ => String::new(), - }) + .filter_map(|content| content.as_text().map(|t| t.text.clone())) .collect::>() .join(" ")); @@ -404,140 +406,6 @@ struct StreamingChunk { model: String, } -fn strip_data_prefix(line: &str) -> Option<&str> { - line.strip_prefix("data: ").map(|s| s.trim()) -} - -pub fn response_to_streaming_message( - mut stream: S, -) -> impl Stream, Option)>> + 'static -where - S: Stream> + Unpin + Send + 'static, -{ - try_stream! { - use futures::StreamExt; - - 'outer: while let Some(response) = stream.next().await { - if response.as_ref().is_ok_and(|s| s == "data: [DONE]") { - break 'outer; - } - let response_str = response?; - let line = strip_data_prefix(&response_str); - - if line.is_none() || line.is_some_and(|l| l.is_empty()) { - continue - } - - let chunk: StreamingChunk = serde_json::from_str(line - .ok_or_else(|| anyhow!("unexpected stream format"))?) - .map_err(|e| anyhow!("Failed to parse streaming chunk: {}: {:?}", e, &line))?; - let model = chunk.model.clone(); - - let usage = chunk.usage.as_ref().map(|u| { - ProviderUsage { - usage: get_usage(u), - model, - } - }); - - if chunk.choices.is_empty() { - yield (None, usage) - } else if let Some(tool_calls) = &chunk.choices[0].delta.tool_calls { - let tool_call = &tool_calls[0]; - let id = tool_call.id.clone().ok_or(anyhow!("No tool call ID"))?; - let function_name = tool_call.function.name.clone().ok_or(anyhow!("No function name"))?; - let mut arguments = tool_call.function.arguments.clone(); - - while let Some(response_chunk) = stream.next().await { - if response_chunk.as_ref().is_ok_and(|s| s == "data: [DONE]") { - break 'outer; - } - let response_str = response_chunk?; - if let Some(line) = strip_data_prefix(&response_str) { - let tool_chunk: StreamingChunk = serde_json::from_str(line) - .map_err(|e| anyhow!("Failed to parse streaming chunk: {}: {:?}", e, &line))?; - let more_args = tool_chunk.choices[0].delta.tool_calls.as_ref() - .and_then(|calls| calls.first()) - .map(|call| call.function.arguments.as_str()); - if let Some(more_args) = more_args { - arguments.push_str(more_args); - } else { - break; - } - } - } - - let parsed = if arguments.is_empty() { - Ok(json!({})) - } else { - serde_json::from_str::(&arguments) - }; - - let content = match parsed { - Ok(params) => MessageContent::tool_request( - id, - Ok(ToolCall::new(function_name, params)), - ), - Err(e) => { - let error = ToolError::InvalidParameters(format!( - "Could not interpret tool use parameters for id {}: {}", - id, e - )); - MessageContent::tool_request(id, Err(error)) - } - }; - - yield ( - Some(Message { - id: chunk.id, - role: Role::Assistant, - created: chrono::Utc::now().timestamp(), - content: vec![content], - }), - usage, - ) - } else if let Some(text) = &chunk.choices[0].delta.content { - yield ( - Some(Message { - id: chunk.id, - role: Role::Assistant, - created: chrono::Utc::now().timestamp(), - content: vec![MessageContent::text(text)], - }), - if chunk.choices[0].finish_reason.is_some() { - usage - } else { - None - }, - ) - } - } - } -} - -pub fn get_usage(usage: &Value) -> Usage { - let input_tokens = usage - .get("prompt_tokens") - .and_then(|v| v.as_i64()) - .map(|v| v as i32); - - let output_tokens = usage - .get("completion_tokens") - .and_then(|v| v.as_i64()) - .map(|v| v as i32); - - let total_tokens = usage - .get("total_tokens") - .and_then(|v| v.as_i64()) - .map(|v| v as i32) - .or_else(|| match (input_tokens, output_tokens) { - (Some(input), Some(output)) => Some(input + output), - _ => None, - }); - - Usage::new(input_tokens, output_tokens, total_tokens) -} - /// Validates and fixes tool schemas to ensure they have proper parameter structure. /// If parameters exist, ensures they have properties and required fields, or removes parameters entirely. pub fn validate_tool_schemas(tools: &mut [Value]) { @@ -721,7 +589,6 @@ pub fn create_request( #[cfg(test)] mod tests { use super::*; - use mcp_core::content::Content; use serde_json::json; #[test] diff --git a/crates/goose/src/providers/formats/gcpvertexai.rs b/crates/goose/src/providers/formats/gcpvertexai.rs index d83d193991d7..8b2b4f36f25d 100644 --- a/crates/goose/src/providers/formats/gcpvertexai.rs +++ b/crates/goose/src/providers/formats/gcpvertexai.rs @@ -83,6 +83,8 @@ pub enum ClaudeVersion { Haiku35, /// Claude Sonnet 4 Sonnet4, + /// Claude Opus 4 + Opus4, /// Generic Claude model for custom or new versions Generic(String), } @@ -102,6 +104,10 @@ pub enum GeminiVersion { Flash25Preview, /// Gemini 2.5 Pro Preview version Pro25Preview, + /// Gemini 2.5 Flash version + Flash25, + /// Gemini 2.5 Pro version + Pro25, /// Generic Gemini model for custom or new versions Generic(String), } @@ -115,6 +121,7 @@ impl fmt::Display for GcpVertexAIModel { ClaudeVersion::Sonnet37 => "claude-3-7-sonnet@20250219", ClaudeVersion::Haiku35 => "claude-3-5-haiku@20241022", ClaudeVersion::Sonnet4 => "claude-sonnet-4@20250514", + ClaudeVersion::Opus4 => "claude-opus-4@20250514", ClaudeVersion::Generic(name) => name, }, Self::Gemini(version) => match version { @@ -124,6 +131,8 @@ impl fmt::Display for GcpVertexAIModel { GeminiVersion::Pro25Exp => "gemini-2.5-pro-exp-03-25", GeminiVersion::Flash25Preview => "gemini-2.5-flash-preview-05-20", GeminiVersion::Pro25Preview => "gemini-2.5-pro-preview-05-06", + GeminiVersion::Flash25 => "gemini-2.5-flash", + GeminiVersion::Pro25 => "gemini-2.5-pro", GeminiVersion::Generic(name) => name, }, }; @@ -156,12 +165,15 @@ impl TryFrom<&str> for GcpVertexAIModel { "claude-3-7-sonnet@20250219" => Ok(Self::Claude(ClaudeVersion::Sonnet37)), "claude-3-5-haiku@20241022" => Ok(Self::Claude(ClaudeVersion::Haiku35)), "claude-sonnet-4@20250514" => Ok(Self::Claude(ClaudeVersion::Sonnet4)), + "claude-opus-4@20250514" => Ok(Self::Claude(ClaudeVersion::Opus4)), "gemini-1.5-pro-002" => Ok(Self::Gemini(GeminiVersion::Pro15)), "gemini-2.0-flash-001" => Ok(Self::Gemini(GeminiVersion::Flash20)), "gemini-2.0-pro-exp-02-05" => Ok(Self::Gemini(GeminiVersion::Pro20Exp)), "gemini-2.5-pro-exp-03-25" => Ok(Self::Gemini(GeminiVersion::Pro25Exp)), "gemini-2.5-flash-preview-05-20" => Ok(Self::Gemini(GeminiVersion::Flash25Preview)), "gemini-2.5-pro-preview-05-06" => Ok(Self::Gemini(GeminiVersion::Pro25Preview)), + "gemini-2.5-flash" => Ok(Self::Gemini(GeminiVersion::Flash25)), + "gemini-2.5-pro" => Ok(Self::Gemini(GeminiVersion::Pro25)), // Generic models based on prefix matching _ if s.starts_with("claude-") => { Ok(Self::Claude(ClaudeVersion::Generic(s.to_string()))) @@ -427,7 +439,7 @@ mod tests { } // Test generic Gemini models - let gemini_models = ["gemini-3-pro", "gemini-2.5-flash", "gemini-experimental"]; + let gemini_models = ["gemini-3-pro", "gemini-2.0-flash", "gemini-experimental"]; for model_id in gemini_models { let model = GcpVertexAIModel::try_from(model_id)?; diff --git a/crates/goose/src/providers/formats/google.rs b/crates/goose/src/providers/formats/google.rs index 6c801d504538..2d0c4871aeb3 100644 --- a/crates/goose/src/providers/formats/google.rs +++ b/crates/goose/src/providers/formats/google.rs @@ -4,11 +4,11 @@ use crate::providers::base::Usage; use crate::providers::errors::ProviderError; use crate::providers::utils::{is_valid_function_name, sanitize_function_name}; use anyhow::Result; -use mcp_core::content::Content; -use mcp_core::role::Role; use mcp_core::tool::{Tool, ToolCall}; use rand::{distributions::Alphanumeric, Rng}; +use rmcp::model::{AnnotateAble, RawContent, Role}; use serde_json::{json, Map, Value}; +use std::ops::Deref; /// Convert internal Message format to Google's API message specification pub fn format_messages(messages: &[Message]) -> Vec { @@ -66,13 +66,13 @@ pub fn format_messages(messages: &[Message]) -> Vec { audience.contains(&Role::Assistant) }) }) - .map(|content| content.unannotated()) + .map(|content| content.raw.clone()) .collect(); let mut tool_content = Vec::new(); for content in abridged { match content { - Content::Image(image) => { + RawContent::Image(image) => { parts.push(json!({ "inline_data": { "mime_type": image.mime_type, @@ -81,15 +81,20 @@ pub fn format_messages(messages: &[Message]) -> Vec { })); } _ => { - tool_content.push(content); + tool_content.push(content.no_annotation()); } } } let mut text = tool_content .iter() - .filter_map(|c| match c { - Content::Text(t) => Some(t.text.clone()), - Content::Resource(r) => Some(r.get_text()), + .filter_map(|c| match c.deref() { + RawContent::Text(t) => Some(t.text.clone()), + RawContent::Resource(raw_embedded_resource) => Some( + raw_embedded_resource + .clone() + .no_annotation() + .get_text(), + ), _ => None, }) .collect::>() @@ -313,6 +318,7 @@ pub fn create_request( #[cfg(test)] mod tests { use super::*; + use rmcp::model::Content; use serde_json::json; fn set_up_text_message(text: &str, role: Role) -> Message { diff --git a/crates/goose/src/providers/formats/openai.rs b/crates/goose/src/providers/formats/openai.rs index ce929253405c..d6b62933cd96 100644 --- a/crates/goose/src/providers/formats/openai.rs +++ b/crates/goose/src/providers/formats/openai.rs @@ -1,15 +1,57 @@ use crate::message::{Message, MessageContent}; use crate::model::ModelConfig; -use crate::providers::base::Usage; -use crate::providers::errors::ProviderError; +use crate::providers::base::{ProviderUsage, Usage}; use crate::providers::utils::{ convert_image, detect_image_path, is_valid_function_name, load_image_file, sanitize_function_name, ImageFormat, }; use anyhow::{anyhow, Error}; +use async_stream::try_stream; +use futures::Stream; use mcp_core::ToolError; -use mcp_core::{Content, Role, Tool, ToolCall}; +use mcp_core::{Tool, ToolCall}; +use rmcp::model::Role; +use rmcp::model::{AnnotateAble, Content, RawContent, ResourceContents}; +use serde::{Deserialize, Serialize}; use serde_json::{json, Value}; +use std::ops::Deref; + +#[derive(Serialize, Deserialize, Debug)] +struct DeltaToolCallFunction { + name: Option, + arguments: String, // chunk of encoded JSON, +} + +#[derive(Serialize, Deserialize, Debug)] +struct DeltaToolCall { + id: Option, + function: DeltaToolCallFunction, + index: Option, + r#type: Option, +} + +#[derive(Serialize, Deserialize, Debug)] +struct Delta { + content: Option, + role: Option, + tool_calls: Option>, +} + +#[derive(Serialize, Deserialize, Debug)] +struct StreamingChoice { + delta: Delta, + index: Option, + finish_reason: Option, +} + +#[derive(Serialize, Deserialize, Debug)] +struct StreamingChunk { + choices: Vec, + created: Option, + id: Option, + usage: Option, + model: String, +} /// Convert internal Message format to OpenAI's API message specification /// some openai compatible endpoints use the anthropic image spec at the content level @@ -95,7 +137,7 @@ pub fn format_messages(messages: &[Message], image_format: &ImageFormat) -> Vec< .audience() .is_none_or(|audience| audience.contains(&Role::Assistant)) }) - .map(|content| content.unannotated()) + .cloned() .collect(); // Process all content, replacing images with placeholder text @@ -103,19 +145,25 @@ pub fn format_messages(messages: &[Message], image_format: &ImageFormat) -> Vec< let mut image_messages = Vec::new(); for content in abridged { - match content { - Content::Image(image) => { + match content.deref() { + RawContent::Image(image) => { // Add placeholder text in the tool response tool_content.push(Content::text("This tool result included an image that is uploaded in the next message.")); // Create a separate image message image_messages.push(json!({ "role": "user", - "content": [convert_image(&image, image_format)] + "content": [convert_image(&image.clone().no_annotation(), image_format)] })); } - Content::Resource(resource) => { - tool_content.push(Content::text(resource.get_text())); + RawContent::Resource(resource) => { + let text = match &resource.resource { + ResourceContents::TextResourceContents { + text, .. + } => text.clone(), + _ => String::new(), + }; + tool_content.push(Content::text(text)); } _ => { tool_content.push(content); @@ -124,8 +172,8 @@ pub fn format_messages(messages: &[Message], image_format: &ImageFormat) -> Vec< } let tool_response_content: Value = json!(tool_content .iter() - .map(|content| match content { - Content::Text(text) => text.text.clone(), + .map(|content| match content.deref() { + RawContent::Text(text) => text.text.clone(), _ => String::new(), }) .collect::>() @@ -281,11 +329,7 @@ pub fn response_to_message(response: Value) -> anyhow::Result { )) } -pub fn get_usage(data: &Value) -> Result { - let usage = data - .get("usage") - .ok_or_else(|| ProviderError::UsageError("No usage data in response".to_string()))?; - +pub fn get_usage(usage: &Value) -> Usage { let input_tokens = usage .get("prompt_tokens") .and_then(|v| v.as_i64()) @@ -305,7 +349,7 @@ pub fn get_usage(data: &Value) -> Result { _ => None, }); - Ok(Usage::new(input_tokens, output_tokens, total_tokens)) + Usage::new(input_tokens, output_tokens, total_tokens) } /// Validates and fixes tool schemas to ensure they have proper parameter structure. @@ -354,6 +398,117 @@ fn ensure_valid_json_schema(schema: &mut Value) { } } +fn strip_data_prefix(line: &str) -> Option<&str> { + line.strip_prefix("data: ").map(|s| s.trim()) +} + +pub fn response_to_streaming_message( + mut stream: S, +) -> impl Stream, Option)>> + 'static +where + S: Stream> + Unpin + Send + 'static, +{ + try_stream! { + use futures::StreamExt; + + 'outer: while let Some(response) = stream.next().await { + if response.as_ref().is_ok_and(|s| s == "data: [DONE]") { + break 'outer; + } + let response_str = response?; + let line = strip_data_prefix(&response_str); + + if line.is_none() || line.is_some_and(|l| l.is_empty()) { + continue + } + + let chunk: StreamingChunk = serde_json::from_str(line + .ok_or_else(|| anyhow!("unexpected stream format"))?) + .map_err(|e| anyhow!("Failed to parse streaming chunk: {}: {:?}", e, &line))?; + let model = chunk.model.clone(); + + let usage = chunk.usage.as_ref().map(|u| { + ProviderUsage { + usage: get_usage(u), + model, + } + }); + + if chunk.choices.is_empty() { + yield (None, usage) + } else if let Some(tool_calls) = &chunk.choices[0].delta.tool_calls { + let tool_call = &tool_calls[0]; + let id = tool_call.id.clone().ok_or(anyhow!("No tool call ID"))?; + let function_name = tool_call.function.name.clone().ok_or(anyhow!("No function name"))?; + let mut arguments = tool_call.function.arguments.clone(); + + while let Some(response_chunk) = stream.next().await { + if response_chunk.as_ref().is_ok_and(|s| s == "data: [DONE]") { + break 'outer; + } + let response_str = response_chunk?; + if let Some(line) = strip_data_prefix(&response_str) { + let tool_chunk: StreamingChunk = serde_json::from_str(line) + .map_err(|e| anyhow!("Failed to parse streaming chunk: {}: {:?}", e, &line))?; + let more_args = tool_chunk.choices[0].delta.tool_calls.as_ref() + .and_then(|calls| calls.first()) + .map(|call| call.function.arguments.as_str()); + if let Some(more_args) = more_args { + arguments.push_str(more_args); + } else { + break; + } + } + } + + let parsed = if arguments.is_empty() { + Ok(json!({})) + } else { + serde_json::from_str::(&arguments) + }; + + let content = match parsed { + Ok(params) => MessageContent::tool_request( + id, + Ok(ToolCall::new(function_name, params)), + ), + Err(e) => { + let error = ToolError::InvalidParameters(format!( + "Could not interpret tool use parameters for id {}: {}", + id, e + )); + MessageContent::tool_request(id, Err(error)) + } + }; + + yield ( + Some(Message { + id: chunk.id, + role: Role::Assistant, + created: chrono::Utc::now().timestamp(), + content: vec![content], + }), + usage, + ) + } else if let Some(text) = &chunk.choices[0].delta.content { + yield ( + Some(Message { + id: chunk.id, + role: Role::Assistant, + created: chrono::Utc::now().timestamp(), + content: vec![MessageContent::text(text)], + }), + if chunk.choices[0].finish_reason.is_some() { + usage + } else { + None + }, + ) + } + } + } +} + pub fn create_request( model_config: &ModelConfig, system: &str, @@ -453,7 +608,6 @@ pub fn create_request( #[cfg(test)] mod tests { use super::*; - use mcp_core::content::Content; use serde_json::json; #[test] diff --git a/crates/goose/src/providers/formats/snowflake.rs b/crates/goose/src/providers/formats/snowflake.rs index c7cfe5924bb4..d29e8448cc64 100644 --- a/crates/goose/src/providers/formats/snowflake.rs +++ b/crates/goose/src/providers/formats/snowflake.rs @@ -3,9 +3,8 @@ use crate::model::ModelConfig; use crate::providers::base::Usage; use crate::providers::errors::ProviderError; use anyhow::{anyhow, Result}; -use mcp_core::content::Content; -use mcp_core::role::Role; use mcp_core::tool::{Tool, ToolCall}; +use rmcp::model::Role; use serde_json::{json, Value}; use std::collections::HashSet; @@ -39,10 +38,7 @@ pub fn format_messages(messages: &[Message]) -> Vec { if let Ok(result) = &tool_response.tool_result { let text = result .iter() - .filter_map(|c| match c { - Content::Text(t) => Some(t.text.clone()), - _ => None, - }) + .filter_map(|c| c.as_text().map(|t| t.text.clone())) .collect::>() .join("\n"); diff --git a/crates/goose/src/providers/gcpvertexai.rs b/crates/goose/src/providers/gcpvertexai.rs index 6385ec299abc..9f92dc569bc8 100644 --- a/crates/goose/src/providers/gcpvertexai.rs +++ b/crates/goose/src/providers/gcpvertexai.rs @@ -2,6 +2,7 @@ use std::time::Duration; use anyhow::Result; use async_trait::async_trait; +use once_cell::sync::Lazy; use reqwest::{Client, StatusCode}; use serde_json::Value; use tokio::time::sleep; @@ -34,6 +35,9 @@ const DEFAULT_MAX_RETRIES: usize = 6; const DEFAULT_BACKOFF_MULTIPLIER: f64 = 2.0; /// Default maximum interval for retry (in milliseconds) const DEFAULT_MAX_RETRY_INTERVAL_MS: u64 = 320_000; +/// Status code for Anthropic's API overloaded error (529) +static STATUS_API_OVERLOADED: Lazy = + Lazy::new(|| StatusCode::from_u16(529).expect("Valid status code 529 for API_OVERLOADED")); /// Represents errors specific to GCP Vertex AI operations. #[derive(Debug, thiserror::Error)] @@ -50,8 +54,10 @@ enum GcpVertexAIError { /// Retry configuration for handling rate limit errors #[derive(Debug, Clone)] struct RetryConfig { - /// Maximum number of retry attempts - max_retries: usize, + /// Maximum number of retry attempts for 429 errors + max_rate_limit_retries: usize, + /// Maximum number of retry attempts for 529 errors + max_overloaded_retries: usize, /// Initial interval between retries in milliseconds initial_interval_ms: u64, /// Multiplier for backoff (exponential) @@ -63,7 +69,8 @@ struct RetryConfig { impl Default for RetryConfig { fn default() -> Self { Self { - max_retries: DEFAULT_MAX_RETRIES, + max_rate_limit_retries: DEFAULT_MAX_RETRIES, + max_overloaded_retries: DEFAULT_MAX_RETRIES, initial_interval_ms: DEFAULT_INITIAL_RETRY_INTERVAL_MS, backoff_multiplier: DEFAULT_BACKOFF_MULTIPLIER, max_interval_ms: DEFAULT_MAX_RETRY_INTERVAL_MS, @@ -92,6 +99,19 @@ impl RetryConfig { Duration::from_millis(jittered_delay_ms) } + + /// Get max retries for a specific error type + #[allow(dead_code)] // Used in tests + fn max_retries_for_status(&self, status: StatusCode) -> usize { + if status == StatusCode::TOO_MANY_REQUESTS { + self.max_rate_limit_retries + } else if status == *STATUS_API_OVERLOADED { + self.max_overloaded_retries + } else { + // Default to rate limit retries for any other status code + self.max_rate_limit_retries + } + } } /// Provider implementation for Google Cloud Platform's Vertex AI service. @@ -172,10 +192,32 @@ impl GcpVertexAIProvider { /// Loads retry configuration from environment variables or uses defaults. fn load_retry_config(config: &crate::config::Config) -> RetryConfig { - let max_retries = config - .get_param("GCP_MAX_RETRIES") + // Load max retries for 429 rate limit errors + let max_rate_limit_retries = config + .get_param("GCP_MAX_RATE_LIMIT_RETRIES") + .ok() + .and_then(|v: String| v.parse::().ok()) + .or_else(|| { + // Fall back to generic GCP_MAX_RETRIES if specific one isn't set + config + .get_param("GCP_MAX_RETRIES") + .ok() + .and_then(|v: String| v.parse::().ok()) + }) + .unwrap_or(DEFAULT_MAX_RETRIES); + + // Load max retries for 529 API overloaded errors + let max_overloaded_retries = config + .get_param("GCP_MAX_OVERLOADED_RETRIES") .ok() .and_then(|v: String| v.parse::().ok()) + .or_else(|| { + // Fall back to generic GCP_MAX_RETRIES if specific one isn't set + config + .get_param("GCP_MAX_RETRIES") + .ok() + .and_then(|v: String| v.parse::().ok()) + }) .unwrap_or(DEFAULT_MAX_RETRIES); let initial_interval_ms = config @@ -197,7 +239,8 @@ impl GcpVertexAIProvider { .unwrap_or(DEFAULT_MAX_RETRY_INTERVAL_MS); RetryConfig { - max_retries, + max_rate_limit_retries, + max_overloaded_retries, initial_interval_ms, backoff_multiplier, max_interval_ms, @@ -269,7 +312,7 @@ impl GcpVertexAIProvider { } /// Makes an authenticated POST request to the Vertex AI API at a specific location. - /// Includes retry logic for 429 Too Many Requests errors. + /// Includes retry logic for 429 (Too Many Requests) and 529 (API Overloaded) errors. /// /// # Arguments /// * `payload` - The request payload to send @@ -285,21 +328,12 @@ impl GcpVertexAIProvider { .build_request_url(context.provider(), location) .map_err(|e| ProviderError::RequestFailed(e.to_string()))?; - // Initialize retry counter - let mut attempts = 0; + // Initialize separate counters for different error types + let mut rate_limit_attempts = 0; + let mut overloaded_attempts = 0; let mut last_error = None; loop { - // Check if we've exceeded max retries - if attempts > 0 && attempts > self.retry_config.max_retries { - let error_msg = format!( - "Exceeded maximum retry attempts ({}) for rate limiting (429)", - self.retry_config.max_retries - ); - tracing::error!("{}", error_msg); - return Err(last_error.unwrap_or(ProviderError::RateLimitExceeded(error_msg))); - } - // Get a fresh auth token for each attempt let auth_header = self .get_auth_header() @@ -318,60 +352,116 @@ impl GcpVertexAIProvider { let status = response.status(); - // If not a 429, process normally - if status != StatusCode::TOO_MANY_REQUESTS { - let response_json = response.json::().await.map_err(|e| { - ProviderError::RequestFailed(format!("Failed to parse response: {e}")) - })?; - - return match status { - StatusCode::OK => Ok(response_json), - StatusCode::UNAUTHORIZED | StatusCode::FORBIDDEN => { - tracing::debug!( - "Authentication failed. Status: {status}, Payload: {payload:?}" + // Handle 429 Too Many Requests and 529 API Overloaded errors + match status { + status if status == StatusCode::TOO_MANY_REQUESTS => { + rate_limit_attempts += 1; + + if rate_limit_attempts > self.retry_config.max_rate_limit_retries { + let error_msg = format!( + "Exceeded maximum retry attempts ({}) for rate limiting (429) errors", + self.retry_config.max_rate_limit_retries + ); + tracing::error!("{}", error_msg); + return Err( + last_error.unwrap_or(ProviderError::RateLimitExceeded(error_msg)) ); - Err(ProviderError::Authentication(format!( - "Authentication failed: {response_json:?}" - ))) } - _ => { - tracing::debug!( - "Request failed. Status: {status}, Response: {response_json:?}" + + // Try to parse response for more detailed error info + let cite_gcp_vertex_429 = + "See https://cloud.google.com/vertex-ai/generative-ai/docs/error-code-429"; + let response_text = response.text().await.unwrap_or_default(); + + let error_message = + if response_text.contains("Exceeded the Provisioned Throughput") { + // Handle 429 rate limit due to throughput limits + format!("Exceeded the Provisioned Throughput: {cite_gcp_vertex_429}") + } else { + // Handle generic 429 rate limit + format!("Pay-as-you-go resource exhausted: {cite_gcp_vertex_429}") + }; + + tracing::warn!( + "Rate limit exceeded error (429) (attempt {}/{}): {}. Retrying after backoff...", + rate_limit_attempts, + self.retry_config.max_rate_limit_retries, + error_message + ); + + // Store the error in case we need to return it after max retries + last_error = Some(ProviderError::RateLimitExceeded(error_message)); + + // Calculate and apply the backoff delay + let delay = self.retry_config.delay_for_attempt(rate_limit_attempts); + tracing::info!("Backing off for {:?} before retry (rate limit 429)", delay); + sleep(delay).await; + } + status if status == *STATUS_API_OVERLOADED => { + overloaded_attempts += 1; + + if overloaded_attempts > self.retry_config.max_overloaded_retries { + let error_msg = format!( + "Exceeded maximum retry attempts ({}) for API overloaded (529) errors", + self.retry_config.max_overloaded_retries + ); + tracing::error!("{}", error_msg); + return Err( + last_error.unwrap_or(ProviderError::RateLimitExceeded(error_msg)) ); - Err(ProviderError::RequestFailed(format!( - "Request failed with status {status}: {response_json:?}" - ))) } - }; - } - // Handle 429 Too Many Requests - attempts += 1; - - // Try to parse response for more detailed error info - let cite_gcp_vertex_429 = - "See https://cloud.google.com/vertex-ai/generative-ai/docs/error-code-429"; - let response_text = response.text().await.unwrap_or_default(); - let quota_error = if response_text.contains("Exceeded the Provisioned Throughput") { - format!("Exceeded the Provisioned Throughput: {cite_gcp_vertex_429}.") - } else { - format!("Pay-as-you-go resource exhausted: {cite_gcp_vertex_429}.") - }; - - tracing::warn!( - "Rate limit exceeded (attempt {}/{}): {}. Retrying after backoff...", - attempts, - self.retry_config.max_retries, - quota_error - ); - - // Store the error in case we need to return it after max retries - last_error = Some(ProviderError::RateLimitExceeded(quota_error)); - - // Calculate and apply the backoff delay - let delay = self.retry_config.delay_for_attempt(attempts); - tracing::info!("Backing off for {:?} before retry", delay); - sleep(delay).await; + // Handle 529 Overloaded error (https://docs.anthropic.com/en/api/errors) + let error_message = + "Vertex AI Provider API is temporarily overloaded. This is similar to a rate limit \ + error but indicates backend processing capacity issues." + .to_string(); + + tracing::warn!( + "API overloaded error (529) (attempt {}/{}): {}. Retrying after backoff...", + overloaded_attempts, + self.retry_config.max_overloaded_retries, + error_message + ); + + // Store the error in case we need to return it after max retries + last_error = Some(ProviderError::RateLimitExceeded(error_message)); + + // Calculate and apply the backoff delay + let delay = self.retry_config.delay_for_attempt(overloaded_attempts); + tracing::info!( + "Backing off for {:?} before retry (API overloaded 529)", + delay + ); + sleep(delay).await; + } + // For any other status codes, process normally + _ => { + let response_json = response.json::().await.map_err(|e| { + ProviderError::RequestFailed(format!("Failed to parse response: {e}")) + })?; + + return match status { + StatusCode::OK => Ok(response_json), + StatusCode::UNAUTHORIZED | StatusCode::FORBIDDEN => { + tracing::debug!( + "Authentication failed. Status: {status}, Payload: {payload:?}" + ); + Err(ProviderError::Authentication(format!( + "Authentication failed: {response_json:?}" + ))) + } + _ => { + tracing::debug!( + "Request failed. Status: {status}, Response: {response_json:?}" + ); + Err(ProviderError::RequestFailed(format!( + "Request failed with status {status}: {response_json:?}" + ))) + } + }; + } + } } } @@ -431,12 +521,15 @@ impl Provider for GcpVertexAIProvider { GcpVertexAIModel::Claude(ClaudeVersion::Sonnet37), GcpVertexAIModel::Claude(ClaudeVersion::Haiku35), GcpVertexAIModel::Claude(ClaudeVersion::Sonnet4), + GcpVertexAIModel::Claude(ClaudeVersion::Opus4), GcpVertexAIModel::Gemini(GeminiVersion::Pro15), GcpVertexAIModel::Gemini(GeminiVersion::Flash20), GcpVertexAIModel::Gemini(GeminiVersion::Pro20Exp), GcpVertexAIModel::Gemini(GeminiVersion::Pro25Exp), GcpVertexAIModel::Gemini(GeminiVersion::Flash25Preview), GcpVertexAIModel::Gemini(GeminiVersion::Pro25Preview), + GcpVertexAIModel::Gemini(GeminiVersion::Flash25), + GcpVertexAIModel::Gemini(GeminiVersion::Pro25), ] .iter() .map(|model| model.to_string()) @@ -448,7 +541,7 @@ impl Provider for GcpVertexAIProvider { "gcp_vertex_ai", "GCP Vertex AI", "Access variety of AI models such as Claude, Gemini through Vertex AI", - GcpVertexAIModel::Gemini(GeminiVersion::Flash20) + GcpVertexAIModel::Gemini(GeminiVersion::Flash25) .to_string() .as_str(), known_models, @@ -456,6 +549,18 @@ impl Provider for GcpVertexAIProvider { vec![ ConfigKey::new("GCP_PROJECT_ID", true, false, None), ConfigKey::new("GCP_LOCATION", true, false, Some(Iowa.to_string().as_str())), + ConfigKey::new( + "GCP_MAX_RATE_LIMIT_RETRIES", + false, + false, + Some(&DEFAULT_MAX_RETRIES.to_string()), + ), + ConfigKey::new( + "GCP_MAX_OVERLOADED_RETRIES", + false, + false, + Some(&DEFAULT_MAX_RETRIES.to_string()), + ), ConfigKey::new( "GCP_MAX_RETRIES", false, @@ -525,11 +630,13 @@ impl Provider for GcpVertexAIProvider { #[cfg(test)] mod tests { use super::*; + use reqwest::StatusCode; #[test] fn test_retry_config_delay_calculation() { let config = RetryConfig { - max_retries: 5, + max_rate_limit_retries: 5, + max_overloaded_retries: 5, initial_interval_ms: 1000, backoff_multiplier: 2.0, max_interval_ms: 32000, @@ -552,6 +659,44 @@ mod tests { assert!(delay10.as_millis() <= 38400); // max_interval_ms * 1.2 (max jitter) } + #[test] + fn test_max_retries_for_status() { + let config = RetryConfig { + max_rate_limit_retries: 5, + max_overloaded_retries: 10, + initial_interval_ms: 1000, + backoff_multiplier: 2.0, + max_interval_ms: 32000, + }; + + // Check that we get the right max retries for each error type + assert_eq!( + config.max_retries_for_status(StatusCode::TOO_MANY_REQUESTS), + 5 + ); + assert_eq!(config.max_retries_for_status(*STATUS_API_OVERLOADED), 10); + + // For any other status code, we should get the rate limit retries + assert_eq!(config.max_retries_for_status(StatusCode::BAD_REQUEST), 5); + } + + #[test] + fn test_status_overloaded_code() { + // Test that we correctly handle the 529 status code + + // Verify the custom status code is created correctly + assert_eq!(STATUS_API_OVERLOADED.as_u16(), 529); + + // This is not a standard HTTP status code, so it's classified as server error + assert!(STATUS_API_OVERLOADED.is_server_error()); + + // Should be different from TOO_MANY_REQUESTS (429) + assert_ne!(*STATUS_API_OVERLOADED, StatusCode::TOO_MANY_REQUESTS); + + // Should be different from SERVICE_UNAVAILABLE (503) + assert_ne!(*STATUS_API_OVERLOADED, StatusCode::SERVICE_UNAVAILABLE); + } + #[test] fn test_model_provider_conversion() { assert_eq!(ModelProvider::Anthropic.as_str(), "anthropic"); @@ -596,7 +741,8 @@ mod tests { .collect(); assert!(model_names.contains(&"claude-3-5-sonnet-v2@20241022".to_string())); assert!(model_names.contains(&"gemini-1.5-pro-002".to_string())); - // Should contain the original 2 config keys plus 4 new retry-related ones - assert_eq!(metadata.config_keys.len(), 6); + assert!(model_names.contains(&"gemini-2.5-pro".to_string())); + // Should contain the original 2 config keys plus 6 new retry-related ones + assert_eq!(metadata.config_keys.len(), 8); } } diff --git a/crates/goose/src/providers/gemini_cli.rs b/crates/goose/src/providers/gemini_cli.rs index fc696bab21a8..f3175b7da2ed 100644 --- a/crates/goose/src/providers/gemini_cli.rs +++ b/crates/goose/src/providers/gemini_cli.rs @@ -10,9 +10,8 @@ use super::errors::ProviderError; use super::utils::emit_debug_trace; use crate::message::{Message, MessageContent}; use crate::model::ModelConfig; -use mcp_core::content::TextContent; use mcp_core::tool::Tool; -use mcp_core::Role; +use rmcp::model::Role; pub const GEMINI_CLI_DEFAULT_MODEL: &str = "default"; pub const GEMINI_CLI_KNOWN_MODELS: &[&str] = &["default"]; @@ -172,10 +171,7 @@ impl GeminiCliProvider { let message = Message::new( Role::Assistant, chrono::Utc::now().timestamp(), - vec![MessageContent::Text(TextContent { - text: response_text, - annotations: None, - })], + vec![MessageContent::text(response_text)], ); let usage = Usage::default(); // No usage info available for gemini CLI @@ -217,10 +213,7 @@ impl GeminiCliProvider { let message = Message::new( Role::Assistant, chrono::Utc::now().timestamp(), - vec![MessageContent::Text(TextContent { - text: description.clone(), - annotations: None, - })], + vec![MessageContent::text(description.clone())], ); let usage = Usage::default(); diff --git a/crates/goose/src/providers/githubcopilot.rs b/crates/goose/src/providers/githubcopilot.rs index 97bd3ad589e0..ef7a9fbecc43 100644 --- a/crates/goose/src/providers/githubcopilot.rs +++ b/crates/goose/src/providers/githubcopilot.rs @@ -415,14 +415,10 @@ impl Provider for GithubCopilotProvider { // Parse response let message = response_to_message(response.clone())?; - let usage = match get_usage(&response) { - Ok(usage) => usage, - Err(ProviderError::UsageError(e)) => { - tracing::debug!("Failed to get usage data: {}", e); - Usage::default() - } - Err(e) => return Err(e), - }; + let usage = response.get("usage").map(get_usage).unwrap_or_else(|| { + tracing::debug!("Failed to get usage data"); + Usage::default() + }); let model = get_model(&response); emit_debug_trace(&self.model, &payload, &response, &usage); Ok((message, ProviderUsage::new(model, usage))) diff --git a/crates/goose/src/providers/groq.rs b/crates/goose/src/providers/groq.rs index 3716df0e6dc3..9c8c5af9fecf 100644 --- a/crates/goose/src/providers/groq.rs +++ b/crates/goose/src/providers/groq.rs @@ -139,14 +139,10 @@ impl Provider for GroqProvider { let response = self.post(payload.clone()).await?; let message = response_to_message(response.clone())?; - let usage = match get_usage(&response) { - Ok(usage) => usage, - Err(ProviderError::UsageError(e)) => { - tracing::debug!("Failed to get usage data: {}", e); - Usage::default() - } - Err(e) => return Err(e), - }; + let usage = response.get("usage").map(get_usage).unwrap_or_else(|| { + tracing::debug!("Failed to get usage data"); + Usage::default() + }); let model = get_model(&response); super::utils::emit_debug_trace(&self.model, &payload, &response, &usage); Ok((message, ProviderUsage::new(model, usage))) diff --git a/crates/goose/src/providers/lead_worker.rs b/crates/goose/src/providers/lead_worker.rs index ea892342c8d8..5d993b525b04 100644 --- a/crates/goose/src/providers/lead_worker.rs +++ b/crates/goose/src/providers/lead_worker.rs @@ -1,5 +1,6 @@ use anyhow::Result; use async_trait::async_trait; +use std::ops::Deref; use std::sync::Arc; use tokio::sync::Mutex; @@ -7,7 +8,8 @@ use super::base::{LeadWorkerProviderTrait, Provider, ProviderMetadata, ProviderU use super::errors::ProviderError; use crate::message::{Message, MessageContent}; use crate::model::ModelConfig; -use mcp_core::{tool::Tool, Content}; +use mcp_core::tool::Tool; +use rmcp::model::{Content, RawContent}; /// A provider that switches between a lead model and a worker model based on turn count /// and can fallback to lead model on consecutive failures @@ -239,7 +241,7 @@ impl LeadWorkerProvider { /// Check if tool output contains error indicators fn contains_error_indicators(&self, contents: &[Content]) -> bool { for content in contents { - if let Content::Text(text_content) = content { + if let RawContent::Text(text_content) = content.deref() { let text_lower = text_content.text.to_lowercase(); // Common error patterns in tool outputs @@ -455,7 +457,7 @@ mod tests { use crate::message::MessageContent; use crate::providers::base::{ProviderMetadata, ProviderUsage, Usage}; use chrono::Utc; - use mcp_core::{content::TextContent, Role}; + use rmcp::model::{AnnotateAble, RawTextContent, Role}; #[derive(Clone)] struct MockProvider { @@ -483,10 +485,12 @@ mod tests { Message::new( Role::Assistant, Utc::now().timestamp(), - vec![MessageContent::Text(TextContent { - text: format!("Response from {}", self.name), - annotations: None, - })], + vec![MessageContent::Text( + RawTextContent { + text: format!("Response from {}", self.name), + } + .no_annotation(), + )], ), ProviderUsage::new(self.name.clone(), Usage::default()), )) @@ -646,10 +650,12 @@ mod tests { Message::new( Role::Assistant, Utc::now().timestamp(), - vec![MessageContent::Text(TextContent { - text: format!("Response from {}", self.name), - annotations: None, - })], + vec![MessageContent::Text( + RawTextContent { + text: format!("Response from {}", self.name), + } + .no_annotation(), + )], ), ProviderUsage::new(self.name.clone(), Usage::default()), )) diff --git a/crates/goose/src/providers/litellm.rs b/crates/goose/src/providers/litellm.rs new file mode 100644 index 000000000000..591337bac6e7 --- /dev/null +++ b/crates/goose/src/providers/litellm.rs @@ -0,0 +1,357 @@ +use anyhow::Result; +use async_trait::async_trait; +use reqwest::Client; +use serde_json::{json, Value}; +use std::collections::HashMap; +use std::time::Duration; +use url::Url; + +use super::base::{ConfigKey, ModelInfo, Provider, ProviderMetadata, ProviderUsage}; +use super::embedding::EmbeddingCapable; +use super::errors::ProviderError; +use super::utils::{emit_debug_trace, get_model, handle_response_openai_compat, ImageFormat}; +use crate::message::Message; +use crate::model::ModelConfig; +use mcp_core::tool::Tool; + +pub const LITELLM_DEFAULT_MODEL: &str = "gpt-4o-mini"; +pub const LITELLM_DOC_URL: &str = "https://docs.litellm.ai/docs/"; + +#[derive(Debug, serde::Serialize)] +pub struct LiteLLMProvider { + #[serde(skip)] + client: Client, + host: String, + base_path: String, + api_key: String, + model: ModelConfig, + custom_headers: Option>, +} + +impl Default for LiteLLMProvider { + fn default() -> Self { + let model = ModelConfig::new(LiteLLMProvider::metadata().default_model); + LiteLLMProvider::from_env(model).expect("Failed to initialize LiteLLM provider") + } +} + +impl LiteLLMProvider { + pub fn from_env(model: ModelConfig) -> Result { + let config = crate::config::Config::global(); + let api_key: String = config + .get_secret("LITELLM_API_KEY") + .unwrap_or_else(|_| String::new()); + let host: String = config + .get_param("LITELLM_HOST") + .unwrap_or_else(|_| "https://api.litellm.ai".to_string()); + let base_path: String = config + .get_param("LITELLM_BASE_PATH") + .unwrap_or_else(|_| "v1/chat/completions".to_string()); + let custom_headers: Option> = config + .get_secret("LITELLM_CUSTOM_HEADERS") + .or_else(|_| config.get_param("LITELLM_CUSTOM_HEADERS")) + .ok() + .map(parse_custom_headers); + let timeout_secs: u64 = config.get_param("LITELLM_TIMEOUT").unwrap_or(600); + let client = Client::builder() + .timeout(Duration::from_secs(timeout_secs)) + .build()?; + + Ok(Self { + client, + host, + base_path, + api_key, + model, + custom_headers, + }) + } + + fn add_headers(&self, mut request: reqwest::RequestBuilder) -> reqwest::RequestBuilder { + if let Some(custom_headers) = &self.custom_headers { + for (key, value) in custom_headers { + request = request.header(key, value); + } + } + + request + } + + async fn fetch_models(&self) -> Result, ProviderError> { + let models_url = format!("{}/model/info", self.host); + + let mut req = self + .client + .get(&models_url) + .header("Authorization", format!("Bearer {}", self.api_key)); + + req = self.add_headers(req); + + let response = req + .send() + .await + .map_err(|e| ProviderError::RequestFailed(format!("Failed to fetch models: {}", e)))?; + + if !response.status().is_success() { + return Err(ProviderError::RequestFailed(format!( + "Models endpoint returned status: {}", + response.status() + ))); + } + + let response_json: Value = response.json().await.map_err(|e| { + ProviderError::RequestFailed(format!("Failed to parse models response: {}", e)) + })?; + + let models_data = response_json["data"].as_array().ok_or_else(|| { + ProviderError::RequestFailed("Missing data field in models response".to_string()) + })?; + + let mut models = Vec::new(); + for model_data in models_data { + if let Some(model_name) = model_data["model_name"].as_str() { + if model_name.contains("/*") { + continue; + } + + let model_info = &model_data["model_info"]; + let context_length = + model_info["max_input_tokens"].as_u64().unwrap_or(128000) as usize; + let supports_cache_control = model_info["supports_prompt_caching"].as_bool(); + + let mut model_info_obj = ModelInfo::new(model_name, context_length); + model_info_obj.supports_cache_control = supports_cache_control; + models.push(model_info_obj); + } + } + + Ok(models) + } + + async fn post(&self, payload: Value) -> Result { + let base_url = Url::parse(&self.host) + .map_err(|e| ProviderError::RequestFailed(format!("Invalid base URL: {e}")))?; + let url = base_url.join(&self.base_path).map_err(|e| { + ProviderError::RequestFailed(format!("Failed to construct endpoint URL: {e}")) + })?; + + let request = self + .client + .post(url) + .header("Authorization", format!("Bearer {}", self.api_key)); + + let request = self.add_headers(request); + + let response = request.json(&payload).send().await?; + + handle_response_openai_compat(response).await + } +} + +#[async_trait] +impl Provider for LiteLLMProvider { + fn metadata() -> ProviderMetadata { + ProviderMetadata::new( + "litellm", + "LiteLLM", + "LiteLLM proxy supporting multiple models with automatic prompt caching", + LITELLM_DEFAULT_MODEL, + vec![], + LITELLM_DOC_URL, + vec![ + ConfigKey::new("LITELLM_API_KEY", false, true, None), + ConfigKey::new("LITELLM_HOST", true, false, Some("http://localhost:4000")), + ConfigKey::new( + "LITELLM_BASE_PATH", + true, + false, + Some("v1/chat/completions"), + ), + ConfigKey::new("LITELLM_CUSTOM_HEADERS", false, true, None), + ConfigKey::new("LITELLM_TIMEOUT", false, false, Some("600")), + ], + ) + } + + fn get_model_config(&self) -> ModelConfig { + self.model.clone() + } + + #[tracing::instrument(skip_all, name = "provider_complete")] + async fn complete( + &self, + system: &str, + messages: &[Message], + tools: &[Tool], + ) -> Result<(Message, ProviderUsage), ProviderError> { + let mut payload = super::formats::openai::create_request( + &self.model, + system, + messages, + tools, + &ImageFormat::OpenAi, + )?; + + if self.supports_cache_control() { + payload = update_request_for_cache_control(&payload); + } + + let response = self.post(payload.clone()).await?; + + let message = super::formats::openai::response_to_message(response.clone())?; + let usage = super::formats::openai::get_usage(&response); + let model = get_model(&response); + emit_debug_trace(&self.model, &payload, &response, &usage); + Ok((message, ProviderUsage::new(model, usage))) + } + + fn supports_embeddings(&self) -> bool { + true + } + + fn supports_cache_control(&self) -> bool { + if let Ok(models) = tokio::task::block_in_place(|| { + tokio::runtime::Handle::current().block_on(self.fetch_models()) + }) { + if let Some(model_info) = models.iter().find(|m| m.name == self.model.model_name) { + return model_info.supports_cache_control.unwrap_or(false); + } + } + + self.model.model_name.to_lowercase().contains("claude") + } + + async fn fetch_supported_models_async(&self) -> Result>, ProviderError> { + match self.fetch_models().await { + Ok(models) => { + let model_names: Vec = models.into_iter().map(|m| m.name).collect(); + Ok(Some(model_names)) + } + Err(e) => { + tracing::warn!("Failed to fetch models from LiteLLM: {}", e); + Ok(None) + } + } + } +} + +#[async_trait] +impl EmbeddingCapable for LiteLLMProvider { + async fn create_embeddings(&self, texts: Vec) -> Result>, anyhow::Error> { + let endpoint = format!("{}/v1/embeddings", self.host); + + let embedding_model = std::env::var("GOOSE_EMBEDDING_MODEL") + .unwrap_or_else(|_| "text-embedding-3-small".to_string()); + + let payload = json!({ + "input": texts, + "model": embedding_model, + "encoding_format": "float" + }); + + let mut req = self + .client + .post(&endpoint) + .header("Content-Type", "application/json") + .header("Authorization", format!("Bearer {}", self.api_key)) + .json(&payload); + + req = self.add_headers(req); + + let response = req.send().await?; + let response_text = response.text().await?; + let response_json: Value = serde_json::from_str(&response_text)?; + + let data = response_json["data"] + .as_array() + .ok_or_else(|| anyhow::anyhow!("Missing data field"))?; + + let mut embeddings = Vec::new(); + for item in data { + let embedding: Vec = item["embedding"] + .as_array() + .ok_or_else(|| anyhow::anyhow!("Missing embedding field"))? + .iter() + .map(|v| v.as_f64().unwrap_or(0.0) as f32) + .collect(); + embeddings.push(embedding); + } + + Ok(embeddings) + } +} + +/// Updates the request payload to include cache control headers for automatic prompt caching +/// Adds ephemeral cache control to the last 2 user messages, system message, and last tool +pub fn update_request_for_cache_control(original_payload: &Value) -> Value { + let mut payload = original_payload.clone(); + + if let Some(messages_spec) = payload + .as_object_mut() + .and_then(|obj| obj.get_mut("messages")) + .and_then(|messages| messages.as_array_mut()) + { + let mut user_count = 0; + for message in messages_spec.iter_mut().rev() { + if message.get("role") == Some(&json!("user")) { + if let Some(content) = message.get_mut("content") { + if let Some(content_str) = content.as_str() { + *content = json!([{ + "type": "text", + "text": content_str, + "cache_control": { "type": "ephemeral" } + }]); + } + } + user_count += 1; + if user_count >= 2 { + break; + } + } + } + + if let Some(system_message) = messages_spec + .iter_mut() + .find(|msg| msg.get("role") == Some(&json!("system"))) + { + if let Some(content) = system_message.get_mut("content") { + if let Some(content_str) = content.as_str() { + *system_message = json!({ + "role": "system", + "content": [{ + "type": "text", + "text": content_str, + "cache_control": { "type": "ephemeral" } + }] + }); + } + } + } + } + + if let Some(tools_spec) = payload + .as_object_mut() + .and_then(|obj| obj.get_mut("tools")) + .and_then(|tools| tools.as_array_mut()) + { + if let Some(last_tool) = tools_spec.last_mut() { + if let Some(function) = last_tool.get_mut("function") { + function + .as_object_mut() + .unwrap() + .insert("cache_control".to_string(), json!({ "type": "ephemeral" })); + } + } + } + payload +} + +fn parse_custom_headers(headers_str: String) -> HashMap { + let mut headers = HashMap::new(); + for line in headers_str.lines() { + if let Some((key, value)) = line.split_once(':') { + headers.insert(key.trim().to_string(), value.trim().to_string()); + } + } + headers +} diff --git a/crates/goose/src/providers/mod.rs b/crates/goose/src/providers/mod.rs index a7748044a3bd..4b155c69f277 100644 --- a/crates/goose/src/providers/mod.rs +++ b/crates/goose/src/providers/mod.rs @@ -16,6 +16,7 @@ pub mod githubcopilot; pub mod google; pub mod groq; pub mod lead_worker; +pub mod litellm; pub mod oauth; pub mod ollama; pub mod openai; diff --git a/crates/goose/src/providers/ollama.rs b/crates/goose/src/providers/ollama.rs index 4bbf1c392dae..bd18d593adcc 100644 --- a/crates/goose/src/providers/ollama.rs +++ b/crates/goose/src/providers/ollama.rs @@ -141,14 +141,10 @@ impl Provider for OllamaProvider { let response = self.post(payload.clone()).await?; let message = response_to_message(response.clone())?; - let usage = match get_usage(&response) { - Ok(usage) => usage, - Err(ProviderError::UsageError(e)) => { - tracing::debug!("Failed to get usage data: {}", e); - Usage::default() - } - Err(e) => return Err(e), - }; + let usage = response.get("usage").map(get_usage).unwrap_or_else(|| { + tracing::debug!("Failed to get usage data"); + Usage::default() + }); let model = get_model(&response); super::utils::emit_debug_trace(&self.model, &payload, &response, &usage); Ok((message, ProviderUsage::new(model, usage))) diff --git a/crates/goose/src/providers/openai.rs b/crates/goose/src/providers/openai.rs index 9884d147bffc..767a8f47913a 100644 --- a/crates/goose/src/providers/openai.rs +++ b/crates/goose/src/providers/openai.rs @@ -1,9 +1,16 @@ use anyhow::Result; +use async_stream::try_stream; use async_trait::async_trait; -use reqwest::Client; -use serde_json::Value; +use futures::TryStreamExt; +use reqwest::{Client, Response}; +use serde_json::{json, Value}; use std::collections::HashMap; +use std::io; use std::time::Duration; +use tokio::pin; +use tokio_stream::StreamExt; +use tokio_util::codec::{FramedRead, LinesCodec}; +use tokio_util::io::StreamReader; use super::base::{ConfigKey, ModelInfo, Provider, ProviderMetadata, ProviderUsage, Usage}; use super::embedding::{EmbeddingCapable, EmbeddingRequest, EmbeddingResponse}; @@ -12,6 +19,9 @@ use super::formats::openai::{create_request, get_usage, response_to_message}; use super::utils::{emit_debug_trace, get_model, handle_response_openai_compat, ImageFormat}; use crate::message::Message; use crate::model::ModelConfig; +use crate::providers::base::MessageStream; +use crate::providers::formats::openai::response_to_streaming_message; +use crate::providers::utils::handle_status_openai_compat; use mcp_core::tool::Tool; pub const OPEN_AI_DEFAULT_MODEL: &str = "gpt-4o"; @@ -103,7 +113,7 @@ impl OpenAiProvider { request } - async fn post(&self, payload: Value) -> Result { + async fn post(&self, payload: Value) -> Result { let base_url = url::Url::parse(&self.host) .map_err(|e| ProviderError::RequestFailed(format!("Invalid base URL: {e}")))?; let url = base_url.join(&self.base_path).map_err(|e| { @@ -117,9 +127,7 @@ impl OpenAiProvider { let request = self.add_headers(request); - let response = request.json(&payload).send().await?; - - handle_response_openai_compat(response).await + Ok(request.json(&payload).send().await?) } } @@ -170,18 +178,14 @@ impl Provider for OpenAiProvider { let payload = create_request(&self.model, system, messages, tools, &ImageFormat::OpenAi)?; // Make request - let response = self.post(payload.clone()).await?; + let response = handle_response_openai_compat(self.post(payload.clone()).await?).await?; // Parse response let message = response_to_message(response.clone())?; - let usage = match get_usage(&response) { - Ok(usage) => usage, - Err(ProviderError::UsageError(e)) => { - tracing::debug!("Failed to get usage data: {}", e); - Usage::default() - } - Err(e) => return Err(e), - }; + let usage = response.get("usage").map(get_usage).unwrap_or_else(|| { + tracing::debug!("Failed to get usage data"); + Usage::default() + }); let model = get_model(&response); emit_debug_trace(&self.model, &payload, &response, &usage); Ok((message, ProviderUsage::new(model, usage))) @@ -236,6 +240,43 @@ impl Provider for OpenAiProvider { .await .map_err(|e| ProviderError::ExecutionError(e.to_string())) } + + fn supports_streaming(&self) -> bool { + true + } + + async fn stream( + &self, + system: &str, + messages: &[Message], + tools: &[Tool], + ) -> Result { + let mut payload = + create_request(&self.model, system, messages, tools, &ImageFormat::OpenAi)?; + payload["stream"] = serde_json::Value::Bool(true); + payload["stream_options"] = json!({ + "include_usage": true, + }); + + let response = handle_status_openai_compat(self.post(payload.clone()).await?).await?; + + let stream = response.bytes_stream().map_err(io::Error::other); + + let model_config = self.model.clone(); + // Wrap in a line decoder and yield lines inside the stream + Ok(Box::pin(try_stream! { + let stream_reader = StreamReader::new(stream); + let framed = FramedRead::new(stream_reader, LinesCodec::new()).map_err(anyhow::Error::from); + + let message_stream = response_to_streaming_message(framed); + pin!(message_stream); + while let Some(message) = message_stream.next().await { + let (message, usage) = message.map_err(|e| ProviderError::RequestFailed(format!("Stream decode error: {}", e)))?; + super::utils::emit_debug_trace(&model_config, &payload, &message, &usage.as_ref().map(|f| f.usage).unwrap_or_default()); + yield (message, usage); + } + })) + } } fn parse_custom_headers(s: String) -> HashMap { diff --git a/crates/goose/src/providers/openrouter.rs b/crates/goose/src/providers/openrouter.rs index 0352012fc462..782e5adf9210 100644 --- a/crates/goose/src/providers/openrouter.rs +++ b/crates/goose/src/providers/openrouter.rs @@ -199,23 +199,20 @@ fn update_request_for_anthropic(original_payload: &Value) -> Value { } fn create_request_based_on_model( - model_config: &ModelConfig, + provider: &OpenRouterProvider, system: &str, messages: &[Message], tools: &[Tool], ) -> anyhow::Result { let mut payload = create_request( - model_config, + &provider.model, system, messages, tools, &super::utils::ImageFormat::OpenAi, )?; - if model_config - .model_name - .starts_with(OPENROUTER_MODEL_PREFIX_ANTHROPIC) - { + if provider.supports_cache_control() { payload = update_request_for_anthropic(&payload); } @@ -259,21 +256,17 @@ impl Provider for OpenRouterProvider { tools: &[Tool], ) -> Result<(Message, ProviderUsage), ProviderError> { // Create the base payload - let payload = create_request_based_on_model(&self.model, system, messages, tools)?; + let payload = create_request_based_on_model(self, system, messages, tools)?; // Make request let response = self.post(payload.clone()).await?; // Parse response let message = response_to_message(response.clone())?; - let usage = match get_usage(&response) { - Ok(usage) => usage, - Err(ProviderError::UsageError(e)) => { - tracing::debug!("Failed to get usage data: {}", e); - Usage::default() - } - Err(e) => return Err(e), - }; + let usage = response.get("usage").map(get_usage).unwrap_or_else(|| { + tracing::debug!("Failed to get usage data"); + Usage::default() + }); let model = get_model(&response); emit_debug_trace(&self.model, &payload, &response, &usage); Ok((message, ProviderUsage::new(model, usage))) @@ -369,4 +362,10 @@ impl Provider for OpenRouterProvider { models.sort(); Ok(Some(models)) } + + fn supports_cache_control(&self) -> bool { + self.model + .model_name + .starts_with(OPENROUTER_MODEL_PREFIX_ANTHROPIC) + } } diff --git a/crates/goose/src/providers/sagemaker_tgi.rs b/crates/goose/src/providers/sagemaker_tgi.rs index c2ced2a57fde..d5da10583e89 100644 --- a/crates/goose/src/providers/sagemaker_tgi.rs +++ b/crates/goose/src/providers/sagemaker_tgi.rs @@ -16,8 +16,7 @@ use super::utils::emit_debug_trace; use crate::message::{Message, MessageContent}; use crate::model::ModelConfig; use chrono::Utc; -use mcp_core::content::TextContent; -use mcp_core::role::Role; +use rmcp::model::Role; pub const SAGEMAKER_TGI_DOC_LINK: &str = "https://docs.aws.amazon.com/sagemaker/latest/dg/realtime-endpoints.html"; @@ -206,10 +205,7 @@ impl SageMakerTgiProvider { Ok(Message::new( Role::Assistant, Utc::now().timestamp(), - vec![MessageContent::Text(TextContent { - text: clean_text, - annotations: None, - })], + vec![MessageContent::text(clean_text)], )) } diff --git a/crates/goose/src/providers/toolshim.rs b/crates/goose/src/providers/toolshim.rs index 0647d0a06e74..f07f655667ac 100644 --- a/crates/goose/src/providers/toolshim.rs +++ b/crates/goose/src/providers/toolshim.rs @@ -38,9 +38,10 @@ use crate::model::ModelConfig; use crate::providers::formats::openai::create_request; use anyhow::Result; use mcp_core::tool::{Tool, ToolCall}; -use mcp_core::Content; use reqwest::Client; +use rmcp::model::RawContent; use serde_json::{json, Value}; +use std::ops::Deref; use std::time::Duration; use uuid::Uuid; @@ -340,8 +341,8 @@ pub fn convert_tool_messages_to_text(messages: &[Message]) -> Vec { Ok(contents) => { let text_contents: Vec = contents .iter() - .filter_map(|c| match c { - Content::Text(t) => Some(t.text.clone()), + .filter_map(|c| match c.deref() { + RawContent::Text(t) => Some(t.text.clone()), _ => None, }) .collect(); diff --git a/crates/goose/src/providers/utils.rs b/crates/goose/src/providers/utils.rs index 7bcc172b6458..c76cdfcb894f 100644 --- a/crates/goose/src/providers/utils.rs +++ b/crates/goose/src/providers/utils.rs @@ -5,13 +5,13 @@ use anyhow::Result; use base64::Engine; use regex::Regex; use reqwest::{Response, StatusCode}; +use rmcp::model::{AnnotateAble, ImageContent, RawImageContent}; use serde::{Deserialize, Serialize}; use serde_json::{from_value, json, Map, Value}; use std::io::Read; use std::path::Path; use crate::providers::errors::{OpenAIError, ProviderError}; -use mcp_core::content::ImageContent; #[derive(serde::Deserialize)] struct OpenAIErrorResponse { @@ -47,48 +47,57 @@ pub fn convert_image(image: &ImageContent, image_format: &ImageFormat) -> Value /// Handle response from OpenAI compatible endpoints /// Error codes: https://platform.openai.com/docs/guides/error-codes /// Context window exceeded: https://community.openai.com/t/help-needed-tackling-context-length-limits-in-openai-models/617543 -pub async fn handle_response_openai_compat(response: Response) -> Result { +pub async fn handle_status_openai_compat(response: Response) -> Result { let status = response.status(); - // Try to parse the response body as JSON (if applicable) - let payload = match response.json::().await { - Ok(json) => json, - Err(e) => return Err(ProviderError::RequestFailed(e.to_string())), - }; match status { - StatusCode::OK => Ok(payload), - StatusCode::UNAUTHORIZED | StatusCode::FORBIDDEN => { - Err(ProviderError::Authentication(format!("Authentication failed. Please ensure your API keys are valid and have the required permissions. \ - Status: {}. Response: {:?}", status, payload))) - } - StatusCode::BAD_REQUEST | StatusCode::NOT_FOUND => { - tracing::debug!( - "{}", format!("Provider request failed with status: {}. Payload: {:?}", status, payload) - ); - if let Ok(err_resp) = from_value::(payload) { - let err = err_resp.error; - if err.is_context_length_exceeded() { - return Err(ProviderError::ContextLengthExceeded(err.message.unwrap_or("Unknown error".to_string()))); + StatusCode::OK => Ok(response), + _ => { + let body = response.json::().await; + match (body, status) { + (Err(e), _) => Err(ProviderError::RequestFailed(e.to_string())), + (Ok(body), StatusCode::UNAUTHORIZED | StatusCode::FORBIDDEN) => { + Err(ProviderError::Authentication(format!("Authentication failed. Please ensure your API keys are valid and have the required permissions. \ + Status: {}. Response: {:?}", status, body))) + } + (Ok(body), StatusCode::BAD_REQUEST | StatusCode::NOT_FOUND) => { + tracing::debug!( + "{}", format!("Provider request failed with status: {}. Payload: {:?}", status, body) + ); + if let Ok(err_resp) = from_value::(body) { + let err = err_resp.error; + if err.is_context_length_exceeded() { + return Err(ProviderError::ContextLengthExceeded(err.message.unwrap_or("Unknown error".to_string()))); + } + return Err(ProviderError::RequestFailed(format!("{} (status {})", err, status.as_u16()))); + } + Err(ProviderError::RequestFailed(format!("Unknown error (status {})", status))) + } + (Ok(body), StatusCode::TOO_MANY_REQUESTS) => { + Err(ProviderError::RateLimitExceeded(format!("{:?}", body))) + } + (Ok(body), StatusCode::INTERNAL_SERVER_ERROR | StatusCode::SERVICE_UNAVAILABLE) => { + Err(ProviderError::ServerError(format!("{:?}", body))) + } + (Ok(body), _) => { + tracing::debug!( + "{}", format!("Provider request failed with status: {}. Payload: {:?}", status, body) + ); + Err(ProviderError::RequestFailed(format!("Request failed with status: {}", status))) } - return Err(ProviderError::RequestFailed(format!("{} (status {})", err, status.as_u16()))); } - Err(ProviderError::RequestFailed(format!("Unknown error (status {})", status))) - } - StatusCode::TOO_MANY_REQUESTS => { - Err(ProviderError::RateLimitExceeded(format!("{:?}", payload))) - } - StatusCode::INTERNAL_SERVER_ERROR | StatusCode::SERVICE_UNAVAILABLE => { - Err(ProviderError::ServerError(format!("{:?}", payload))) - } - _ => { - tracing::debug!( - "{}", format!("Provider request failed with status: {}. Payload: {:?}", status, payload) - ); - Err(ProviderError::RequestFailed(format!("Request failed with status: {}", status))) } } } +pub async fn handle_response_openai_compat(response: Response) -> Result { + let response = handle_status_openai_compat(response).await?; + + response.json::().await.map_err(|e| { + ProviderError::RequestFailed(format!("Response body is not valid JSON: {}", e)) + }) +} + /// Check if the model is a Google model based on the "model" field in the payload. /// /// ### Arguments @@ -283,11 +292,11 @@ pub fn load_image_file(path: &str) -> Result { // Convert to base64 let data = base64::prelude::BASE64_STANDARD.encode(&bytes); - Ok(ImageContent { + Ok(RawImageContent { mime_type: mime_type.to_string(), data, - annotations: None, - }) + } + .no_annotation()) } pub fn unescape_json_values(value: &Value) -> Value { diff --git a/crates/goose/src/providers/venice.rs b/crates/goose/src/providers/venice.rs index 9046cdd58c91..5d1eab5eb831 100644 --- a/crates/goose/src/providers/venice.rs +++ b/crates/goose/src/providers/venice.rs @@ -10,7 +10,8 @@ use super::base::{ConfigKey, Provider, ProviderMetadata, ProviderUsage, Usage}; use super::errors::ProviderError; use crate::message::{Message, MessageContent}; use crate::model::ModelConfig; -use mcp_core::{tool::Tool, Role, ToolCall, ToolResult}; +use mcp_core::{tool::Tool, ToolCall, ToolResult}; +use rmcp::model::Role; // ---------- Capability Flags ---------- #[derive(Debug)] diff --git a/crates/goose/src/providers/xai.rs b/crates/goose/src/providers/xai.rs index 7e91a23f8b9e..cdaebdc0b1b8 100644 --- a/crates/goose/src/providers/xai.rs +++ b/crates/goose/src/providers/xai.rs @@ -166,14 +166,10 @@ impl Provider for XaiProvider { let response = self.post(payload.clone()).await?; let message = response_to_message(response.clone())?; - let usage = match get_usage(&response) { - Ok(usage) => usage, - Err(ProviderError::UsageError(e)) => { - tracing::debug!("Failed to get usage data: {}", e); - Usage::default() - } - Err(e) => return Err(e), - }; + let usage = response.get("usage").map(get_usage).unwrap_or_else(|| { + tracing::debug!("Failed to get usage data"); + Usage::default() + }); let model = get_model(&response); super::utils::emit_debug_trace(&self.model, &payload, &response, &usage); Ok((message, ProviderUsage::new(model, usage))) diff --git a/crates/goose/src/recipe/build_recipe/mod.rs b/crates/goose/src/recipe/build_recipe/mod.rs new file mode 100644 index 000000000000..d567f1cad9a6 --- /dev/null +++ b/crates/goose/src/recipe/build_recipe/mod.rs @@ -0,0 +1,189 @@ +use crate::recipe::read_recipe_file_content::RecipeFile; +use crate::recipe::template_recipe::{parse_recipe_content, render_recipe_content_with_params}; +use crate::recipe::{ + Recipe, RecipeParameter, RecipeParameterRequirement, BUILT_IN_RECIPE_DIR_PARAM, +}; +use anyhow::Result; +use std::collections::{HashMap, HashSet}; + +#[derive(Debug, thiserror::Error)] +pub enum RecipeError { + #[error("Missing required parameters: {parameters:?}")] + MissingParams { parameters: Vec }, + #[error("Template rendering failed: {source}")] + TemplateRendering { source: anyhow::Error }, + #[error("Recipe parsing failed: {source}")] + RecipeParsing { source: anyhow::Error }, +} + +pub fn render_recipe_template( + recipe_file: RecipeFile, + params: Vec<(String, String)>, + user_prompt_fn: Option, +) -> Result<(String, Vec)> +where + F: Fn(&str, &str) -> Result, +{ + let RecipeFile { + content: recipe_file_content, + parent_dir: recipe_parent_dir, + .. + } = recipe_file; + let recipe_dir_str = recipe_parent_dir + .to_str() + .ok_or_else(|| anyhow::anyhow!("Error getting recipe directory"))?; + let recipe_parameters = validate_recipe_parameters(&recipe_file_content, recipe_dir_str)?; + + let (params_for_template, missing_params) = + apply_values_to_parameters(¶ms, recipe_parameters, recipe_dir_str, user_prompt_fn)?; + + let rendered_content = if missing_params.is_empty() { + render_recipe_content_with_params(&recipe_file_content, ¶ms_for_template)? + } else { + String::new() + }; + + Ok((rendered_content, missing_params)) +} + +pub fn validate_recipe_parameters( + recipe_file_content: &str, + recipe_dir_str: &str, +) -> Result>> { + let (raw_recipe, template_variables) = + parse_recipe_content(recipe_file_content, recipe_dir_str.to_string())?; + let recipe_parameters = raw_recipe.parameters; + validate_optional_parameters(&recipe_parameters)?; + validate_parameters_in_template(&recipe_parameters, &template_variables)?; + Ok(recipe_parameters) +} + +pub fn build_recipe_from_template( + recipe_file: RecipeFile, + params: Vec<(String, String)>, + user_prompt_fn: Option, +) -> Result +where + F: Fn(&str, &str) -> Result, +{ + let (rendered_content, missing_params) = + render_recipe_template(recipe_file, params.clone(), user_prompt_fn) + .map_err(|source| RecipeError::TemplateRendering { source })?; + + if !missing_params.is_empty() { + return Err(RecipeError::MissingParams { + parameters: missing_params, + }); + } + + let recipe = Recipe::from_content(&rendered_content) + .map_err(|source| RecipeError::RecipeParsing { source })?; + Ok(recipe) +} + +fn validate_parameters_in_template( + recipe_parameters: &Option>, + template_variables: &HashSet, +) -> Result<()> { + let mut template_variables = template_variables.clone(); + template_variables.remove(BUILT_IN_RECIPE_DIR_PARAM); + + let param_keys: HashSet = recipe_parameters + .as_ref() + .unwrap_or(&vec![]) + .iter() + .map(|p| p.key.clone()) + .collect(); + + let missing_keys = template_variables + .difference(¶m_keys) + .collect::>(); + + let extra_keys = param_keys + .difference(&template_variables) + .collect::>(); + + if missing_keys.is_empty() && extra_keys.is_empty() { + return Ok(()); + } + + let mut message = String::new(); + + if !missing_keys.is_empty() { + message.push_str(&format!( + "Missing definitions for parameters in the recipe file: {}.", + missing_keys + .iter() + .map(|s| s.to_string()) + .collect::>() + .join(", ") + )); + } + + if !extra_keys.is_empty() { + message.push_str(&format!( + "\nUnnecessary parameter definitions: {}.", + extra_keys + .iter() + .map(|s| s.to_string()) + .collect::>() + .join(", ") + )); + } + Err(anyhow::anyhow!("{}", message.trim_end())) +} + +fn validate_optional_parameters(parameters: &Option>) -> Result<()> { + let optional_params_without_default_values: Vec = parameters + .as_ref() + .unwrap_or(&vec![]) + .iter() + .filter(|p| { + matches!(p.requirement, RecipeParameterRequirement::Optional) && p.default.is_none() + }) + .map(|p| p.key.clone()) + .collect(); + + if optional_params_without_default_values.is_empty() { + Ok(()) + } else { + Err(anyhow::anyhow!("Optional parameters missing default values in the recipe: {}. Please provide defaults.", optional_params_without_default_values.join(", "))) + } +} + +pub fn apply_values_to_parameters( + user_params: &[(String, String)], + recipe_parameters: Option>, + recipe_parent_dir: &str, + user_prompt_fn: Option, +) -> Result<(HashMap, Vec)> +where + F: Fn(&str, &str) -> Result, +{ + let mut param_map: HashMap = user_params.iter().cloned().collect(); + param_map.insert( + BUILT_IN_RECIPE_DIR_PARAM.to_string(), + recipe_parent_dir.to_string(), + ); + let mut missing_params: Vec = Vec::new(); + for param in recipe_parameters.unwrap_or_default() { + if !param_map.contains_key(¶m.key) { + match (¶m.default, ¶m.requirement) { + (Some(default), _) => param_map.insert(param.key.clone(), default.clone()), + (None, RecipeParameterRequirement::UserPrompt) if user_prompt_fn.is_some() => { + let input_value = + user_prompt_fn.as_ref().unwrap()(¶m.key, ¶m.description)?; + param_map.insert(param.key.clone(), input_value) + } + _ => { + missing_params.push(param.key.clone()); + None + } + }; + } + } + Ok((param_map, missing_params)) +} + +#[cfg(test)] +mod tests; diff --git a/crates/goose/src/recipe/build_recipe/tests.rs b/crates/goose/src/recipe/build_recipe/tests.rs new file mode 100644 index 000000000000..acd6dbbe89a5 --- /dev/null +++ b/crates/goose/src/recipe/build_recipe/tests.rs @@ -0,0 +1,352 @@ +#[cfg(test)] +mod tests { + use crate::recipe::build_recipe::{build_recipe_from_template, RecipeError}; + use crate::recipe::read_recipe_file_content::RecipeFile; + use crate::recipe::{RecipeParameterInputType, RecipeParameterRequirement}; + use tempfile::TempDir; + + const NO_USER_PROMPT: Option Result> = None; + + fn setup_recipe_file(instructions_and_parameters: &str) -> (TempDir, RecipeFile) { + let recipe_content = format!( + r#"{{ + "version": "1.0.0", + "title": "Test Recipe", + "description": "A test recipe", + {} + }}"#, + instructions_and_parameters + ); + let temp_dir = tempfile::tempdir().unwrap(); + let recipe_path = temp_dir.path().join("test_recipe.json"); + + std::fs::write(&recipe_path, recipe_content).unwrap(); + + let recipe_file = RecipeFile { + content: std::fs::read_to_string(&recipe_path).unwrap(), + parent_dir: temp_dir.path().to_path_buf(), + file_path: recipe_path, + }; + + (temp_dir, recipe_file) + } + + fn setup_yaml_recipe_files( + parent_content: &str, + child_content: &str, + ) -> (TempDir, RecipeFile, RecipeFile) { + let temp_dir = tempfile::tempdir().unwrap(); + let temp_path = temp_dir.path(); + + let parent_path = temp_path.join("parent.yaml"); + std::fs::write(&parent_path, parent_content).unwrap(); + + let child_path = temp_path.join("child.yaml"); + std::fs::write(&child_path, child_content).unwrap(); + + let parent_recipe_file = RecipeFile { + content: std::fs::read_to_string(&parent_path).unwrap(), + parent_dir: temp_path.to_path_buf(), + file_path: parent_path, + }; + + let child_recipe_file = RecipeFile { + content: std::fs::read_to_string(&child_path).unwrap(), + parent_dir: temp_path.to_path_buf(), + file_path: child_path, + }; + + (temp_dir, parent_recipe_file, child_recipe_file) + } + + #[test] + fn test_build_recipe_from_template_success() { + let instructions_and_parameters = r#" + "instructions": "Test instructions with {{ my_name }}", + "parameters": [ + { + "key": "my_name", + "input_type": "string", + "requirement": "required", + "description": "A test parameter" + } + ]"#; + + let (_temp_dir, recipe_file) = setup_recipe_file(instructions_and_parameters); + + let params = vec![("my_name".to_string(), "value".to_string())]; + let recipe = build_recipe_from_template(recipe_file, params, NO_USER_PROMPT).unwrap(); + + assert_eq!(recipe.title, "Test Recipe"); + assert_eq!(recipe.description, "A test recipe"); + assert_eq!(recipe.instructions.unwrap(), "Test instructions with value"); + // Verify parameters match recipe definition + assert_eq!(recipe.parameters.as_ref().unwrap().len(), 1); + let param = &recipe.parameters.as_ref().unwrap()[0]; + assert_eq!(param.key, "my_name"); + assert!(matches!(param.input_type, RecipeParameterInputType::String)); + assert!(matches!( + param.requirement, + RecipeParameterRequirement::Required + )); + assert_eq!(param.description, "A test parameter"); + } + + #[test] + fn test_build_recipe_from_template_success_variable_in_prompt() { + let instructions_and_parameters = r#" + "instructions": "Test instructions", + "prompt": "My prompt {{ my_name }}", + "parameters": [ + { + "key": "my_name", + "input_type": "string", + "requirement": "required", + "description": "A test parameter" + } + ]"#; + + let (_temp_dir, recipe_file) = setup_recipe_file(instructions_and_parameters); + + let params = vec![("my_name".to_string(), "value".to_string())]; + let recipe = build_recipe_from_template(recipe_file, params, NO_USER_PROMPT).unwrap(); + + assert_eq!(recipe.title, "Test Recipe"); + assert_eq!(recipe.description, "A test recipe"); + assert_eq!(recipe.instructions.unwrap(), "Test instructions"); + assert_eq!(recipe.prompt.unwrap(), "My prompt value"); + let param = &recipe.parameters.as_ref().unwrap()[0]; + assert_eq!(param.key, "my_name"); + assert!(matches!(param.input_type, RecipeParameterInputType::String)); + assert!(matches!( + param.requirement, + RecipeParameterRequirement::Required + )); + assert_eq!(param.description, "A test parameter"); + } + + #[test] + fn test_build_recipe_from_template_wrong_parameters_in_recipe_file() { + let instructions_and_parameters = r#" + "instructions": "Test instructions with {{ expected_param1 }} {{ expected_param2 }}", + "parameters": [ + { + "key": "wrong_param_key", + "input_type": "string", + "requirement": "required", + "description": "A test parameter" + } + ]"#; + let (_temp_dir, recipe_file) = setup_recipe_file(instructions_and_parameters); + + let build_recipe_result = + build_recipe_from_template(recipe_file, Vec::new(), NO_USER_PROMPT); + assert!(build_recipe_result.is_err()); + let err = build_recipe_result.unwrap_err(); + println!("{}", err.to_string()); + + match err { + RecipeError::TemplateRendering { source } => { + let err_str = source.to_string(); + assert!(err_str.contains("Unnecessary parameter definitions: wrong_param_key.")); + assert!(err_str.contains("Missing definitions for parameters in the recipe file:")); + assert!(err_str.contains("expected_param1")); + assert!(err_str.contains("expected_param2")); + } + _ => panic!("Expected TemplateRendering error"), + } + } + + #[test] + fn test_build_recipe_from_template_with_default_values_in_recipe_file() { + let instructions_and_parameters = r#" + "instructions": "Test instructions with {{ param_with_default }} {{ param_without_default }}", + "parameters": [ + { + "key": "param_with_default", + "input_type": "string", + "requirement": "optional", + "default": "my_default_value", + "description": "A test parameter" + }, + { + "key": "param_without_default", + "input_type": "string", + "requirement": "required", + "description": "A test parameter" + } + ]"#; + let (_temp_dir, recipe_file) = setup_recipe_file(instructions_and_parameters); + let params = vec![("param_without_default".to_string(), "value1".to_string())]; + + let recipe = build_recipe_from_template(recipe_file, params, NO_USER_PROMPT).unwrap(); + + assert_eq!(recipe.title, "Test Recipe"); + assert_eq!(recipe.description, "A test recipe"); + assert_eq!( + recipe.instructions.unwrap(), + "Test instructions with my_default_value value1" + ); + } + + #[test] + fn test_build_recipe_from_template_optional_parameters_with_empty_default_values_in_recipe_file( + ) { + let instructions_and_parameters = r#" + "instructions": "Test instructions with {{ optional_param }}", + "parameters": [ + { + "key": "optional_param", + "input_type": "string", + "requirement": "optional", + "description": "A test parameter", + "default": "" + } + ]"#; + let (_temp_dir, recipe_file) = setup_recipe_file(instructions_and_parameters); + + let recipe = build_recipe_from_template(recipe_file, Vec::new(), NO_USER_PROMPT).unwrap(); + assert_eq!(recipe.title, "Test Recipe"); + assert_eq!(recipe.description, "A test recipe"); + assert_eq!(recipe.instructions.unwrap(), "Test instructions with "); + } + + #[test] + fn test_build_recipe_from_template_optional_parameters_without_default_values_in_recipe_file() { + let instructions_and_parameters = r#" + "instructions": "Test instructions with {{ optional_param }}", + "parameters": [ + { + "key": "optional_param", + "input_type": "string", + "requirement": "optional", + "description": "A test parameter" + } + ]"#; + let (_temp_dir, recipe_file) = setup_recipe_file(instructions_and_parameters); + + let build_recipe_result = + build_recipe_from_template(recipe_file, Vec::new(), NO_USER_PROMPT); + assert!(build_recipe_result.is_err()); + let err = build_recipe_result.unwrap_err(); + println!("{}", err.to_string()); + match err { + RecipeError::TemplateRendering { source } => { + assert!(source.to_string().to_lowercase().contains("missing")); + } + _ => panic!("Expected TemplateRendering error"), + } + } + + #[test] + fn test_build_recipe_from_template_wrong_input_type_in_recipe_file() { + let instructions_and_parameters = r#" + "instructions": "Test instructions with {{ param }}", + "parameters": [ + { + "key": "param", + "input_type": "some_invalid_type", + "requirement": "required", + "description": "A test parameter" + } + ]"#; + let params = vec![("param".to_string(), "value".to_string())]; + let (_temp_dir, recipe_file) = setup_recipe_file(instructions_and_parameters); + + let build_recipe_result = build_recipe_from_template(recipe_file, params, NO_USER_PROMPT); + assert!(build_recipe_result.is_err()); + let err = build_recipe_result.unwrap_err(); + match err { + RecipeError::TemplateRendering { source } => { + let err_msg = source.to_string(); + eprint!("Error: {}", err_msg); + assert!(err_msg.contains("unknown variant `some_invalid_type`")); + } + _ => panic!("Expected TemplateRendering error, got: {:?}", err), + } + } + + #[test] + fn test_build_recipe_from_template_success_without_parameters() { + let instructions_and_parameters = r#" + "instructions": "Test instructions" + "#; + let (_temp_dir, recipe_file) = setup_recipe_file(instructions_and_parameters); + + let recipe = build_recipe_from_template(recipe_file, Vec::new(), NO_USER_PROMPT).unwrap(); + assert_eq!(recipe.instructions.unwrap(), "Test instructions"); + assert!(recipe.parameters.is_none()); + } + + #[test] + fn test_template_inheritance() { + let parent_content = r#" + version: 1.0.0 + title: Parent + description: Parent recipe + prompt: | + show me the news for day: {{ date }} + {% block prompt -%} + What is the capital of France? + {%- endblock %} + {% if is_enabled %} + Feature is enabled. + {% else %} + Feature is disabled. + {% endif %} + parameters: + - key: date + input_type: string + requirement: required + description: date specified by the user + - key: is_enabled + input_type: boolean + requirement: required + description: whether the feature is enabled + "#; + + let child_content = r#" + {% extends "parent.yaml" -%} + {% block prompt -%} + What is the capital of Germany? + {%- endblock %} + "#; + + let (_temp_dir, parent_recipe_file, child_recipe_file) = + setup_yaml_recipe_files(parent_content, child_content); + + let params = vec![ + ("date".to_string(), "today".to_string()), + ("is_enabled".to_string(), "true".to_string()), + ]; + + let parent_recipe = + build_recipe_from_template(parent_recipe_file, params.clone(), NO_USER_PROMPT).unwrap(); + assert_eq!(parent_recipe.description, "Parent recipe"); + assert_eq!( + parent_recipe.prompt.unwrap(), + "show me the news for day: today\nWhat is the capital of France?\n\n Feature is enabled.\n" + ); + assert_eq!(parent_recipe.parameters.as_ref().unwrap().len(), 2); + assert_eq!(parent_recipe.parameters.as_ref().unwrap()[0].key, "date"); + assert_eq!( + parent_recipe.parameters.as_ref().unwrap()[1].key, + "is_enabled" + ); + + let child_recipe = + build_recipe_from_template(child_recipe_file, params, NO_USER_PROMPT).unwrap(); + assert_eq!(child_recipe.title, "Parent"); + assert_eq!(child_recipe.description, "Parent recipe"); + assert_eq!( + child_recipe.prompt.unwrap().trim(), + "show me the news for day: today\nWhat is the capital of Germany?\n\n Feature is enabled." + ); + assert_eq!(child_recipe.parameters.as_ref().unwrap().len(), 2); + assert_eq!(child_recipe.parameters.as_ref().unwrap()[0].key, "date"); + assert_eq!( + child_recipe.parameters.as_ref().unwrap()[1].key, + "is_enabled" + ); + } +} diff --git a/crates/goose/src/recipe/mod.rs b/crates/goose/src/recipe/mod.rs index c38c06971460..2d753b20c9c9 100644 --- a/crates/goose/src/recipe/mod.rs +++ b/crates/goose/src/recipe/mod.rs @@ -7,6 +7,12 @@ use crate::agents::extension::ExtensionConfig; use serde::de::Deserializer; use serde::{Deserialize, Serialize}; +pub mod build_recipe; +pub mod read_recipe_file_content; +pub mod template_recipe; + +pub const BUILT_IN_RECIPE_DIR_PARAM: &str = "recipe_dir"; + fn default_version() -> String { "1.0.0".to_string() } @@ -137,6 +143,8 @@ pub struct SubRecipe { pub path: String, #[serde(default, deserialize_with = "deserialize_value_map_as_string")] pub values: Option>, + #[serde(default)] + pub sequential_when_repeated: bool, } fn deserialize_value_map_as_string<'de, D>( diff --git a/crates/goose/src/recipe/read_recipe_file_content.rs b/crates/goose/src/recipe/read_recipe_file_content.rs new file mode 100644 index 000000000000..1aeac3bc2d57 --- /dev/null +++ b/crates/goose/src/recipe/read_recipe_file_content.rs @@ -0,0 +1,46 @@ +use anyhow::{anyhow, Result}; +use std::fs; +use std::path::{Path, PathBuf}; +pub struct RecipeFile { + pub content: String, + pub parent_dir: PathBuf, + pub file_path: PathBuf, +} + +pub fn read_recipe_file>(recipe_path: P) -> Result { + let raw_path = recipe_path.as_ref(); + let path = convert_path_with_tilde_expansion(raw_path); + + let content = fs::read_to_string(&path) + .map_err(|e| anyhow!("Failed to read recipe file {}: {}", path.display(), e))?; + + let canonical = path.canonicalize().map_err(|e| { + anyhow!( + "Failed to resolve absolute path for {}: {}", + path.display(), + e + ) + })?; + + let parent_dir = canonical + .parent() + .ok_or_else(|| anyhow!("Resolved path has no parent: {}", canonical.display()))? + .to_path_buf(); + + Ok(RecipeFile { + content, + parent_dir, + file_path: canonical, + }) +} + +fn convert_path_with_tilde_expansion(path: &Path) -> PathBuf { + if let Some(path_str) = path.to_str() { + if let Some(stripped) = path_str.strip_prefix("~/") { + if let Some(home_dir) = dirs::home_dir() { + return home_dir.join(stripped); + } + } + } + PathBuf::from(path) +} diff --git a/crates/goose-cli/src/recipes/template_recipe.rs b/crates/goose/src/recipe/template_recipe.rs similarity index 92% rename from crates/goose-cli/src/recipes/template_recipe.rs rename to crates/goose/src/recipe/template_recipe.rs index aed1fe1d03f2..7396bb99f5af 100644 --- a/crates/goose-cli/src/recipes/template_recipe.rs +++ b/crates/goose/src/recipe/template_recipe.rs @@ -3,14 +3,14 @@ use std::{ path::Path, }; +use crate::recipe::{Recipe, BUILT_IN_RECIPE_DIR_PARAM}; use anyhow::Result; -use goose::recipe::Recipe; use minijinja::{Environment, UndefinedBehavior}; use regex::Regex; -use crate::recipes::recipe::BUILT_IN_RECIPE_DIR_PARAM; - const CURRENT_TEMPLATE_NAME: &str = "current_template"; +const OPEN_BRACE: &str = "{{"; +const CLOSE_BRACE: &str = "}}"; fn preprocess_template_variables(content: &str) -> Result { let all_template_variables = extract_template_variables(content); @@ -27,6 +27,7 @@ fn extract_template_variables(content: &str) -> Vec { .collect() } +// filter out variables that are not only alphanumeric and underscores fn filter_complex_variables(template_variables: &[String]) -> Vec { let valid_var_re = Regex::new(r"^\s*[a-zA-Z_][a-zA-Z0-9_]*\s*$").unwrap(); template_variables @@ -40,15 +41,14 @@ fn filter_unparseable_variables(template_variables: &[String]) -> Result, -) -> Result { - let preprocessed_content = preprocess_template_variables(content)?; - - let mut env = minijinja::Environment::new(); - env.set_undefined_behavior(UndefinedBehavior::Lenient); - let template = env.template_from_str(&preprocessed_content)?; - let rendered_content = template.render(params)?; - Ok(rendered_content) -} - fn add_template_in_env( content: &str, recipe_dir: String, @@ -218,7 +205,7 @@ mod tests { mod render_content_with_params_tests { use std::collections::HashMap; - use crate::recipes::template_recipe::render_recipe_content_with_params; + use crate::recipe::template_recipe::render_recipe_content_with_params; #[test] fn test_render_content_with_params() { diff --git a/crates/goose/src/recipe_deeplink.rs b/crates/goose/src/recipe_deeplink.rs new file mode 100644 index 000000000000..7123f4ef3e38 --- /dev/null +++ b/crates/goose/src/recipe_deeplink.rs @@ -0,0 +1,108 @@ +use anyhow::Result; +use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine as _}; +use thiserror::Error; + +use crate::recipe::Recipe; + +#[derive(Error, Debug)] +pub enum DecodeError { + #[error("All decoding methods failed")] + AllMethodsFailed, +} + +pub fn encode(recipe: &Recipe) -> Result { + let recipe_json = serde_json::to_string(recipe)?; + let encoded = URL_SAFE_NO_PAD.encode(recipe_json.as_bytes()); + Ok(encoded) +} + +pub fn decode(link: &str) -> Result { + // Handle the current format: URL-safe Base64 without padding. + if let Ok(decoded_bytes) = URL_SAFE_NO_PAD.decode(link) { + if let Ok(recipe_json) = String::from_utf8(decoded_bytes) { + if let Ok(recipe) = serde_json::from_str::(&recipe_json) { + return Ok(recipe); + } + } + } + + // Handle legacy formats of 'standard base64 encoded' and standard base64 encoded that was then url encoded. + if let Ok(url_decoded) = urlencoding::decode(link) { + if let Ok(decoded_bytes) = + base64::engine::general_purpose::STANDARD.decode(url_decoded.as_bytes()) + { + if let Ok(recipe_json) = String::from_utf8(decoded_bytes) { + if let Ok(recipe) = serde_json::from_str::(&recipe_json) { + return Ok(recipe); + } + } + } + } + + Err(DecodeError::AllMethodsFailed) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::recipe::Recipe; + + fn create_test_recipe() -> Recipe { + Recipe::builder() + .title("Test Recipe") + .description("A test recipe for deeplink encoding/decoding") + .instructions("Act as a helpful assistant") + .build() + .expect("Failed to build test recipe") + } + + #[test] + fn test_encode_decode_round_trip() { + let original_recipe = create_test_recipe(); + + let encoded = encode(&original_recipe).expect("Failed to encode recipe"); + assert!(!encoded.is_empty()); + + let decoded_recipe = decode(&encoded).expect("Failed to decode recipe"); + + assert_eq!(original_recipe.title, decoded_recipe.title); + assert_eq!(original_recipe.description, decoded_recipe.description); + assert_eq!(original_recipe.instructions, decoded_recipe.instructions); + assert_eq!(original_recipe.version, decoded_recipe.version); + } + + #[test] + fn test_decode_legacy_standard_base64() { + let recipe = create_test_recipe(); + let recipe_json = serde_json::to_string(&recipe).unwrap(); + let legacy_encoded = + base64::engine::general_purpose::STANDARD.encode(recipe_json.as_bytes()); + + let decoded_recipe = decode(&legacy_encoded).expect("Failed to decode legacy format"); + assert_eq!(recipe.title, decoded_recipe.title); + assert_eq!(recipe.description, decoded_recipe.description); + assert_eq!(recipe.instructions, decoded_recipe.instructions); + } + + #[test] + fn test_decode_legacy_url_encoded_base64() { + let recipe = create_test_recipe(); + let recipe_json = serde_json::to_string(&recipe).unwrap(); + let base64_encoded = + base64::engine::general_purpose::STANDARD.encode(recipe_json.as_bytes()); + let url_encoded = urlencoding::encode(&base64_encoded); + + let decoded_recipe = + decode(&url_encoded).expect("Failed to decode URL-encoded legacy format"); + assert_eq!(recipe.title, decoded_recipe.title); + assert_eq!(recipe.description, decoded_recipe.description); + assert_eq!(recipe.instructions, decoded_recipe.instructions); + } + + #[test] + fn test_decode_invalid_input() { + let result = decode("invalid_base64!"); + assert!(result.is_err()); + assert!(matches!(result.unwrap_err(), DecodeError::AllMethodsFailed)); + } +} diff --git a/crates/goose/src/scheduler.rs b/crates/goose/src/scheduler.rs index 20c455deb9b6..ef1dee412775 100644 --- a/crates/goose/src/scheduler.rs +++ b/crates/goose/src/scheduler.rs @@ -1219,7 +1219,7 @@ async fn run_scheduled_job_internal( match message_result { Ok(AgentEvent::Message(msg)) => { - if msg.role == mcp_core::role::Role::Assistant { + if msg.role == rmcp::model::Role::Assistant { tracing::info!("[Job {}] Assistant: {:?}", job.id, msg.content); } all_session_messages.push(msg); @@ -1267,6 +1267,7 @@ async fn run_scheduled_job_internal( working_dir: current_dir.clone(), description: String::new(), schedule_id: Some(job.id.clone()), + project_id: None, message_count: all_session_messages.len(), total_tokens: None, input_tokens: None, @@ -1330,7 +1331,8 @@ mod tests { providers::base::{ProviderMetadata, ProviderUsage, Usage}, providers::errors::ProviderError, }; - use mcp_core::{content::TextContent, tool::Tool, Role}; + use mcp_core::tool::Tool; + use rmcp::model::{AnnotateAble, RawTextContent, Role}; // Removed: use crate::session::storage::{get_most_recent_session, read_metadata}; // `read_metadata` is still used by the test itself, so keep it or its module. use crate::session::storage::read_metadata; @@ -1373,10 +1375,12 @@ mod tests { Message::new( Role::Assistant, Utc::now().timestamp(), - vec![MessageContent::Text(TextContent { - text: "Mocked scheduled response".to_string(), - annotations: None, - })], + vec![MessageContent::Text( + RawTextContent { + text: "Mocked scheduled response".to_string(), + } + .no_annotation(), + )], ), ProviderUsage::new("mock-scheduler-test".to_string(), Usage::default()), )) diff --git a/crates/goose/src/session/storage.rs b/crates/goose/src/session/storage.rs index 689006fa9902..4f7557f3ab06 100644 --- a/crates/goose/src/session/storage.rs +++ b/crates/goose/src/session/storage.rs @@ -7,6 +7,7 @@ use crate::message::Message; use crate::providers::base::Provider; +use crate::utils::safe_truncate; use anyhow::Result; use chrono::Local; use etcetera::{choose_app_strategy, AppStrategy, AppStrategyArgs}; @@ -14,6 +15,7 @@ use regex::Regex; use serde::{Deserialize, Serialize}; use std::fs; use std::io::{self, BufRead, Write}; +use std::ops::DerefMut; use std::path::{Path, PathBuf}; use std::sync::Arc; use utoipa::ToSchema; @@ -40,6 +42,8 @@ pub struct SessionMetadata { pub description: String, /// ID of the schedule that triggered this session, if any pub schedule_id: Option, + /// ID of the project this session belongs to, if any + pub project_id: Option, /// Number of messages in the session pub message_count: usize, /// The total number of tokens used in the session. Retrieved from the provider's last usage. @@ -67,6 +71,7 @@ impl<'de> Deserialize<'de> for SessionMetadata { description: String, message_count: usize, schedule_id: Option, // For backward compatibility + project_id: Option, // For backward compatibility total_tokens: Option, input_tokens: Option, output_tokens: Option, @@ -88,6 +93,7 @@ impl<'de> Deserialize<'de> for SessionMetadata { description: helper.description, message_count: helper.message_count, schedule_id: helper.schedule_id, + project_id: helper.project_id, total_tokens: helper.total_tokens, input_tokens: helper.input_tokens, output_tokens: helper.output_tokens, @@ -112,6 +118,7 @@ impl SessionMetadata { working_dir, description: String::new(), schedule_id: None, + project_id: None, message_count: 0, total_tokens: None, input_tokens: None, @@ -605,7 +612,7 @@ pub fn read_messages_with_truncation( // Log details about corrupted lines (with limited detail for security) for (num, line) in &corrupted_lines { let preview = if line.len() > 50 { - format!("{}... (truncated)", &line[..50]) + format!("{}... (truncated)", safe_truncate(line, 50)) } else { line.clone() }; @@ -673,16 +680,16 @@ fn parse_message_with_truncation( /// Truncate content within a message in place fn truncate_message_content_in_place(message: &mut Message, max_content_size: usize) { use crate::message::MessageContent; - use mcp_core::{Content, ResourceContents}; + use rmcp::model::{RawContent, ResourceContents}; for content in &mut message.content { match content { MessageContent::Text(text_content) => { - if text_content.text.len() > max_content_size { + if text_content.text.chars().count() > max_content_size { let truncated = format!( "{}\n\n[... content truncated during session loading from {} to {} characters ...]", - &text_content.text[..max_content_size.min(text_content.text.len())], - text_content.text.len(), + safe_truncate(&text_content.text, max_content_size), + text_content.text.chars().count(), max_content_size ); text_content.text = truncated; @@ -691,27 +698,27 @@ fn truncate_message_content_in_place(message: &mut Message, max_content_size: us MessageContent::ToolResponse(tool_response) => { if let Ok(ref mut result) = tool_response.tool_result { for content_item in result { - match content_item { - Content::Text(ref mut text_content) => { - if text_content.text.len() > max_content_size { + match content_item.deref_mut() { + RawContent::Text(ref mut text_content) => { + if text_content.text.chars().count() > max_content_size { let truncated = format!( "{}\n\n[... tool response truncated during session loading from {} to {} characters ...]", - &text_content.text[..max_content_size.min(text_content.text.len())], - text_content.text.len(), + safe_truncate(&text_content.text, max_content_size), + text_content.text.chars().count(), max_content_size ); text_content.text = truncated; } } - Content::Resource(ref mut resource_content) => { + RawContent::Resource(ref mut resource_content) => { if let ResourceContents::TextResourceContents { text, .. } = &mut resource_content.resource { - if text.len() > max_content_size { + if text.chars().count() > max_content_size { let truncated = format!( "{}\n\n[... resource content truncated during session loading from {} to {} characters ...]", - &text[..max_content_size.min(text.len())], - text.len(), + safe_truncate(text, max_content_size), + text.chars().count(), max_content_size ); *text = truncated; @@ -751,7 +758,7 @@ fn attempt_corruption_recovery(json_str: &str, max_content_size: Option) // Strategy 4: Create a placeholder message with the raw content println!("[SESSION] All recovery strategies failed, creating placeholder message"); let preview = if json_str.len() > 200 { - format!("{}...", &json_str[..200]) + format!("{}...", safe_truncate(json_str, 200)) } else { json_str.to_string() }; @@ -856,11 +863,11 @@ fn try_extract_partial_message(json_str: &str) -> Result { // Try to extract role let role = if json_str.contains("\"role\":\"user\"") { - mcp_core::role::Role::User + rmcp::model::Role::User } else if json_str.contains("\"role\":\"assistant\"") { - mcp_core::role::Role::Assistant + rmcp::model::Role::Assistant } else { - mcp_core::role::Role::User // Default fallback + rmcp::model::Role::User // Default fallback }; // Try to extract text content @@ -895,8 +902,8 @@ fn try_extract_partial_message(json_str: &str) -> Result { if !extracted_text.is_empty() { let message = match role { - mcp_core::role::Role::User => Message::user(), - mcp_core::role::Role::Assistant => Message::assistant(), + rmcp::model::Role::User => Message::user(), + rmcp::model::Role::Assistant => Message::assistant(), }; return Ok(message.with_text(format!("[PARTIALLY RECOVERED] {}", extracted_text))); @@ -968,7 +975,7 @@ fn truncate_json_string(json_str: &str, max_content_size: usize) -> String { if text_content.len() > max_content_size { let truncated_text = format!( "{}\n\n[... content truncated during JSON parsing from {} to {} characters ...]", - &text_content[..max_content_size.min(text_content.len())], + safe_truncate(text_content, max_content_size), text_content.len(), max_content_size ); @@ -1036,13 +1043,13 @@ pub fn read_metadata(session_file: &Path) -> Result { /// /// Security features: /// - Validates file paths to prevent directory traversal -/// - Uses secure file operations via persist_messages_with_schedule_id pub async fn persist_messages( session_file: &Path, messages: &[Message], provider: Option>, + working_dir: Option, ) -> Result<()> { - persist_messages_with_schedule_id(session_file, messages, provider, None).await + persist_messages_with_schedule_id(session_file, messages, provider, None, working_dir).await } /// Write messages to a session file with metadata, including an optional scheduled job ID @@ -1059,6 +1066,7 @@ pub async fn persist_messages_with_schedule_id( messages: &[Message], provider: Option>, schedule_id: Option, + working_dir: Option, ) -> Result<()> { // Validate the session file path for security let secure_path = get_path(Identifier::Path(session_file.to_path_buf()))?; @@ -1072,23 +1080,42 @@ pub async fn persist_messages_with_schedule_id( // Count user messages let user_message_count = messages .iter() - .filter(|m| m.role == mcp_core::role::Role::User && !m.as_concat_text().trim().is_empty()) + .filter(|m| m.role == rmcp::model::Role::User && !m.as_concat_text().trim().is_empty()) .count(); // Check if we need to update the description (after 1st or 3rd user message) match provider { Some(provider) if user_message_count < 4 => { //generate_description is responsible for writing the messages - generate_description_with_schedule_id(&secure_path, messages, provider, schedule_id) - .await + generate_description_with_schedule_id( + &secure_path, + messages, + provider, + schedule_id, + working_dir, + ) + .await } _ => { - // Read existing metadata - let mut metadata = read_metadata(&secure_path)?; + // Read existing metadata or create new with proper working_dir + let mut metadata = if secure_path.exists() { + read_metadata(&secure_path)? + } else { + // Create new metadata with the provided working_dir or fall back to home + let work_dir = working_dir.clone().unwrap_or_else(get_home_dir); + SessionMetadata::new(work_dir) + }; + + // Update the working_dir if provided (even for existing files) + if let Some(work_dir) = working_dir { + metadata.working_dir = work_dir; + } + // Update the schedule_id if provided if schedule_id.is_some() { metadata.schedule_id = schedule_id; } + // Write the file with metadata and messages save_messages_with_metadata(&secure_path, &metadata, messages) } @@ -1226,11 +1253,12 @@ pub async fn generate_description( session_file: &Path, messages: &[Message], provider: Arc, + working_dir: Option, ) -> Result<()> { - generate_description_with_schedule_id(session_file, messages, provider, None).await + generate_description_with_schedule_id(session_file, messages, provider, None, working_dir).await } -/// Generate a description for the session using the provider, including an optional scheduled job ID +/// Generate a description for the session using the provider, including an optional scheduled job ID and working directory /// /// This function is called when appropriate to generate a short description /// of the session based on the conversation history. @@ -1244,6 +1272,7 @@ pub async fn generate_description_with_schedule_id( messages: &[Message], provider: Arc, schedule_id: Option, + working_dir: Option, ) -> Result<()> { // Validate the path for security let secure_path = get_path(Identifier::Path(session_file.to_path_buf()))?; @@ -1265,15 +1294,11 @@ pub async fn generate_description_with_schedule_id( // get context from messages so far, limiting each message to 300 chars for security let context: Vec = messages .iter() - .filter(|m| m.role == mcp_core::role::Role::User) + .filter(|m| m.role == rmcp::model::Role::User) .take(3) // Use up to first 3 user messages for context .map(|m| { let text = m.as_concat_text(); - if text.len() > 300 { - format!("{}...", &text[..300]) - } else { - text - } + safe_truncate(&text, 300) }) .collect(); @@ -1302,14 +1327,21 @@ pub async fn generate_description_with_schedule_id( let description = result.0.as_concat_text(); // Validate description length for security - let sanitized_description = if description.len() > 100 { + let sanitized_description = if description.chars().count() > 100 { tracing::warn!("Generated description too long, truncating"); - format!("{}...", &description[..97]) + safe_truncate(&description, 100) } else { description }; - let mut metadata = read_metadata(&secure_path)?; + // Create metadata with proper working_dir or read existing and update + let mut metadata = if secure_path.exists() { + read_metadata(&secure_path)? + } else { + // Create new metadata with the provided working_dir or fall back to home + let work_dir = working_dir.clone().unwrap_or_else(get_home_dir); + SessionMetadata::new(work_dir) + }; // Update description and schedule_id metadata.description = sanitized_description; @@ -1317,6 +1349,11 @@ pub async fn generate_description_with_schedule_id( metadata.schedule_id = schedule_id; } + // Update the working_dir if provided (even for existing files) + if let Some(work_dir) = working_dir { + metadata.working_dir = work_dir; + } + // Update the file with the new metadata and existing messages save_messages_with_metadata(&secure_path, &metadata, messages) } @@ -1379,9 +1416,9 @@ mod tests { println!( "[TEST] Input: {}", if corrupt_json.len() > 100 { - &corrupt_json[..100] + safe_truncate(corrupt_json, 100) } else { - corrupt_json + corrupt_json.to_string() } ); @@ -1428,7 +1465,7 @@ mod tests { ]; // Write messages - persist_messages(&file_path, &messages, None).await?; + persist_messages(&file_path, &messages, None, None).await?; // Read them back let read_messages = read_messages(&file_path)?; @@ -1536,7 +1573,7 @@ mod tests { } // Write messages with special characters - persist_messages(&file_path, &messages, None).await?; + persist_messages(&file_path, &messages, None, None).await?; // Read them back let read_messages = read_messages(&file_path)?; @@ -1601,7 +1638,7 @@ mod tests { ]; // Write messages - persist_messages(&file_path, &messages, None).await?; + persist_messages(&file_path, &messages, None, None).await?; // Read them back - should be truncated let read_messages = read_messages(&file_path)?; @@ -1692,6 +1729,162 @@ mod tests { Ok(()) } + #[tokio::test] + async fn test_working_dir_preservation() -> Result<()> { + let dir = tempdir()?; + let file_path = dir.path().join("test.jsonl"); + + // Create a temporary working directory + let working_dir = tempdir()?; + let working_dir_path = working_dir.path().to_path_buf(); + + // Create messages + let messages = vec![Message::user().with_text("test message")]; + + // Use persist_messages_with_schedule_id to set working dir + persist_messages_with_schedule_id( + &file_path, + &messages, + None, + None, + Some(working_dir_path.clone()), + ) + .await?; + + // Read back the metadata and verify working_dir is preserved + let metadata = read_metadata(&file_path)?; + assert_eq!(metadata.working_dir, working_dir_path); + + // Verify the messages are also preserved + let read_messages = read_messages(&file_path)?; + assert_eq!(read_messages.len(), 1); + assert_eq!(read_messages[0].role, messages[0].role); + + Ok(()) + } + + #[tokio::test] + async fn test_working_dir_issue_fixed() -> Result<()> { + // This test demonstrates that the working_dir issue in jsonl files is fixed + let dir = tempdir()?; + let file_path = dir.path().join("test.jsonl"); + + // Create a temporary working directory (this simulates the actual working directory) + let working_dir = tempdir()?; + let working_dir_path = working_dir.path().to_path_buf(); + + // Create messages + let messages = vec![Message::user().with_text("test message")]; + + // Get the home directory for comparison + let home_dir = get_home_dir(); + + // Test 1: Using the old persist_messages function (without working_dir) + // This will fall back to home directory since no working_dir is provided + persist_messages(&file_path, &messages, None, None).await?; + + // Read back the metadata - this should now have the home directory as working_dir + let metadata_old = read_metadata(&file_path)?; + assert_eq!( + metadata_old.working_dir, home_dir, + "persist_messages should use home directory when no working_dir is provided" + ); + + // Test 2: Using persist_messages_with_schedule_id function + // This should properly set the working_dir (this is the main fix) + persist_messages_with_schedule_id( + &file_path, + &messages, + None, + None, + Some(working_dir_path.clone()), + ) + .await?; + + // Read back the metadata - this should now have the correct working_dir + let metadata_new = read_metadata(&file_path)?; + assert_eq!( + metadata_new.working_dir, working_dir_path, + "persist_messages_with_schedule_id should use provided working_dir" + ); + assert_ne!( + metadata_new.working_dir, home_dir, + "working_dir should be different from home directory" + ); + + // Test 3: Create a new session file without working_dir (should fall back to home) + let file_path_2 = dir.path().join("test2.jsonl"); + persist_messages_with_schedule_id( + &file_path_2, + &messages, + None, + None, + None, // No working_dir provided + ) + .await?; + + let metadata_fallback = read_metadata(&file_path_2)?; + assert_eq!(metadata_fallback.working_dir, home_dir, "persist_messages_with_schedule_id should fall back to home directory when no working_dir is provided"); + + // Test 4: Test that the fix works for existing files + // Create a session file and then add to it with different working_dir + let file_path_3 = dir.path().join("test3.jsonl"); + + // First, create with home directory + persist_messages(&file_path_3, &messages, None, None).await?; + let metadata_initial = read_metadata(&file_path_3)?; + assert_eq!( + metadata_initial.working_dir, home_dir, + "Initial session should use home directory" + ); + + // Then update with a specific working_dir + persist_messages_with_schedule_id( + &file_path_3, + &messages, + None, + None, + Some(working_dir_path.clone()), + ) + .await?; + + let metadata_updated = read_metadata(&file_path_3)?; + assert_eq!( + metadata_updated.working_dir, working_dir_path, + "Updated session should use new working_dir" + ); + + // Test 5: Most important test - simulate the real-world scenario where + // CLI and web interfaces pass the current directory instead of None + let file_path_4 = dir.path().join("test4.jsonl"); + let current_dir = std::env::current_dir()?; + + // This is what web.rs and session/mod.rs do now after the fix + persist_messages_with_schedule_id( + &file_path_4, + &messages, + None, + None, + Some(current_dir.clone()), + ) + .await?; + + let metadata_current = read_metadata(&file_path_4)?; + assert_eq!( + metadata_current.working_dir, current_dir, + "Session should use current directory when explicitly provided" + ); + // This should NOT be the home directory anymore (unless current_dir == home_dir) + if current_dir != home_dir { + assert_ne!( + metadata_current.working_dir, home_dir, + "working_dir should be different from home directory when current_dir is different" + ); + } + + Ok(()) + } + #[test] fn test_windows_path_validation() -> Result<()> { // Test the Windows path validation logic @@ -1779,12 +1972,13 @@ mod tests { Message::assistant().with_text("Test response"), ]; - // Test persist_messages_with_schedule_id with save_session = true + // Test persist_messages_with_schedule_id with working_dir parameter persist_messages_with_schedule_id( &file_path, &messages, None, Some("test_schedule".to_string()), + None, ) .await?; diff --git a/crates/goose/src/token_counter.rs b/crates/goose/src/token_counter.rs index 6db01dd3cede..6f519beb98d6 100644 --- a/crates/goose/src/token_counter.rs +++ b/crates/goose/src/token_counter.rs @@ -371,8 +371,8 @@ pub async fn create_async_token_counter() -> Result { mod tests { use super::*; use crate::message::{Message, MessageContent}; - use mcp_core::role::Role; use mcp_core::tool::Tool; + use rmcp::model::Role; use serde_json::json; #[test] diff --git a/crates/goose-cli/src/utils.rs b/crates/goose/src/utils.rs similarity index 95% rename from crates/goose-cli/src/utils.rs rename to crates/goose/src/utils.rs index 69daddf1d2a1..60121f1bfe4d 100644 --- a/crates/goose-cli/src/utils.rs +++ b/crates/goose/src/utils.rs @@ -1,4 +1,3 @@ -/// Utility functions for safe string handling and other common operations /// Safely truncate a string at character boundaries, not byte boundaries /// /// This function ensures that multi-byte UTF-8 characters (like Japanese, emoji, etc.) diff --git a/crates/goose/tests/agent.rs b/crates/goose/tests/agent.rs index 3f7cff0c4e20..ad8ffe5dc1b1 100644 --- a/crates/goose/tests/agent.rs +++ b/crates/goose/tests/agent.rs @@ -535,7 +535,11 @@ mod schedule_tool_tests { #[cfg(test)] mod final_output_tool_tests { use super::*; - use goose::agents::final_output_tool::FINAL_OUTPUT_TOOL_NAME; + use futures::stream; + use goose::agents::final_output_tool::{ + FINAL_OUTPUT_CONTINUATION_MESSAGE, FINAL_OUTPUT_TOOL_NAME, + }; + use goose::providers::base::MessageStream; use goose::recipe::Response; #[tokio::test] @@ -609,9 +613,9 @@ mod final_output_tool_tests { let content = final_result.unwrap(); let text = content.first().unwrap().as_text().unwrap(); assert!( - text.contains("Final output successfully collected."), + text.text.contains("Final output successfully collected."), "Tool result missing expected content: {}", - text + text.text ); // Simulate the reply stream continuing after the final output tool call. @@ -631,12 +635,130 @@ mod final_output_tool_tests { let last_message = responses.last().unwrap(); // Check that the last message is an assistant message with our final output - assert_eq!(last_message.role, mcp_core::role::Role::Assistant); + assert_eq!(last_message.role, rmcp::model::Role::Assistant); let message_text = last_message.as_concat_text(); assert_eq!(message_text, r#"{"result":"Test output"}"#); Ok(()) } + + #[tokio::test] + async fn test_when_final_output_not_called_in_reply() -> Result<()> { + use async_trait::async_trait; + use goose::model::ModelConfig; + use goose::providers::base::{Provider, ProviderUsage}; + use goose::providers::errors::ProviderError; + use mcp_core::tool::Tool; + + #[derive(Clone)] + struct MockProvider { + model_config: ModelConfig, + } + + #[async_trait] + impl Provider for MockProvider { + fn metadata() -> goose::providers::base::ProviderMetadata { + goose::providers::base::ProviderMetadata::empty() + } + + fn get_model_config(&self) -> ModelConfig { + self.model_config.clone() + } + + fn supports_streaming(&self) -> bool { + true + } + + async fn stream( + &self, + _system: &str, + _messages: &[Message], + _tools: &[Tool], + ) -> Result { + let deltas = vec![ + Ok((Some(Message::assistant().with_text("Hello")), None)), + Ok((Some(Message::assistant().with_text("Hi!")), None)), + Ok(( + Some(Message::assistant().with_text("What is the final output?")), + None, + )), + ]; + + let stream = stream::iter(deltas.into_iter()); + Ok(Box::pin(stream)) + } + + async fn complete( + &self, + _system: &str, + _messages: &[Message], + _tools: &[Tool], + ) -> Result<(Message, ProviderUsage), ProviderError> { + Err(ProviderError::NotImplemented("Not implemented".to_string())) + } + } + + let agent = Agent::new(); + + let model_config = ModelConfig::new("test-model".to_string()); + let mock_provider = Arc::new(MockProvider { model_config }); + agent.update_provider(mock_provider).await?; + + let response = Response { + json_schema: Some(serde_json::json!({ + "type": "object", + "properties": { + "result": {"type": "string"} + }, + "required": ["result"] + })), + }; + agent.add_final_output_tool(response).await; + + // Simulate the reply stream being called. + let reply_stream = agent.reply(&vec![], None).await?; + tokio::pin!(reply_stream); + + let mut responses = Vec::new(); + let mut count = 0; + while let Some(response_result) = reply_stream.next().await { + match response_result { + Ok(AgentEvent::Message(response)) => { + responses.push(response); + count += 1; + if count >= 4 { + // Limit to 4 messages to avoid infinite loop due to mock provider + break; + } + } + Ok(_) => {} + Err(e) => return Err(e), + } + } + + assert!(!responses.is_empty(), "Should have received responses"); + println!("Responses: {:?}", responses); + let last_message = responses.last().unwrap(); + + // Check that the first 3 messages do not have FINAL_OUTPUT_CONTINUATION_MESSAGE + for (i, response) in responses.iter().take(3).enumerate() { + let message_text = response.as_concat_text(); + assert_ne!( + message_text, + FINAL_OUTPUT_CONTINUATION_MESSAGE, + "Message {} should not be the continuation message, got: '{}'", + i + 1, + message_text + ); + } + + // Check that the last message after the llm stream is the message directing the agent to continue + assert_eq!(last_message.role, rmcp::model::Role::User); + let message_text = last_message.as_concat_text(); + assert_eq!(message_text, FINAL_OUTPUT_CONTINUATION_MESSAGE); + + Ok(()) + } } #[cfg(test)] diff --git a/crates/goose/tests/pricing_integration_test.rs b/crates/goose/tests/pricing_integration_test.rs index 9e4472905f6b..083f96daf74d 100644 --- a/crates/goose/tests/pricing_integration_test.rs +++ b/crates/goose/tests/pricing_integration_test.rs @@ -3,6 +3,10 @@ use std::time::Instant; #[tokio::test] async fn test_pricing_cache_performance() { + // Use a unique cache directory for this test to avoid conflicts + let test_cache_dir = format!("/tmp/goose_test_cache_perf_{}", std::process::id()); + std::env::set_var("GOOSE_CACHE_DIR", &test_cache_dir); + // Initialize the cache let start = Instant::now(); initialize_pricing_cache() @@ -65,10 +69,18 @@ async fn test_pricing_cache_performance() { first_fetch_duration, second_fetch_duration ); + + // Clean up + std::env::remove_var("GOOSE_CACHE_DIR"); + let _ = std::fs::remove_dir_all(&test_cache_dir); } #[tokio::test] async fn test_pricing_refresh() { + // Use a unique cache directory for this test to avoid conflicts + let test_cache_dir = format!("/tmp/goose_test_cache_refresh_{}", std::process::id()); + std::env::set_var("GOOSE_CACHE_DIR", &test_cache_dir); + // Initialize first initialize_pricing_cache() .await @@ -90,10 +102,18 @@ async fn test_pricing_refresh() { refreshed_pricing.is_some(), "Expected pricing after refresh" ); + + // Clean up + std::env::remove_var("GOOSE_CACHE_DIR"); + let _ = std::fs::remove_dir_all(&test_cache_dir); } #[tokio::test] async fn test_model_not_in_openrouter() { + // Use a unique cache directory for this test to avoid conflicts + let test_cache_dir = format!("/tmp/goose_test_cache_model_{}", std::process::id()); + std::env::set_var("GOOSE_CACHE_DIR", &test_cache_dir); + initialize_pricing_cache() .await .expect("Failed to initialize pricing cache"); @@ -104,12 +124,20 @@ async fn test_model_not_in_openrouter() { pricing.is_none(), "Should return None for non-existent model" ); + + // Clean up + std::env::remove_var("GOOSE_CACHE_DIR"); + let _ = std::fs::remove_dir_all(&test_cache_dir); } #[tokio::test] async fn test_concurrent_access() { use tokio::task; + // Use a unique cache directory for this test to avoid conflicts + let test_cache_dir = format!("/tmp/goose_test_cache_concurrent_{}", std::process::id()); + std::env::set_var("GOOSE_CACHE_DIR", &test_cache_dir); + initialize_pricing_cache() .await .expect("Failed to initialize pricing cache"); @@ -133,4 +161,8 @@ async fn test_concurrent_access() { assert!(has_pricing, "Task {} should have gotten pricing", task_id); println!("Task {} took: {:?}", task_id, duration); } + + // Clean up + std::env::remove_var("GOOSE_CACHE_DIR"); + let _ = std::fs::remove_dir_all(&test_cache_dir); } diff --git a/crates/goose/tests/private_tests.rs b/crates/goose/tests/private_tests.rs index 78e7cdd7658c..d2ec7a06e8ae 100644 --- a/crates/goose/tests/private_tests.rs +++ b/crates/goose/tests/private_tests.rs @@ -1,6 +1,6 @@ #![cfg(test)] -use mcp_core::{Content, ToolError}; +use mcp_core::ToolError; use serde_json::json; use goose::agents::platform_tools::PLATFORM_MANAGE_SCHEDULE_TOOL_NAME; @@ -33,7 +33,7 @@ async fn test_schedule_tool_list_action() { let content = result.unwrap(); assert_eq!(content.len(), 1); - if let Content::Text(text_content) = &content[0] { + if let Some(text_content) = content[0].as_text() { assert!(text_content.text.contains("Scheduled Jobs:")); assert!(text_content.text.contains("job1")); assert!(text_content.text.contains("job2")); @@ -63,7 +63,7 @@ async fn test_schedule_tool_list_action_empty() { let content = result.unwrap(); assert_eq!(content.len(), 1); - if let Content::Text(text_content) = &content[0] { + if let Some(text_content) = content[0].as_text() { assert!(text_content.text.contains("Scheduled Jobs:")); } @@ -127,7 +127,7 @@ async fn test_schedule_tool_create_action() { let content = result.unwrap(); assert_eq!(content.len(), 1); - if let Content::Text(text_content) = &content[0] { + if let Some(text_content) = content[0].as_text() { assert!(text_content .text .contains("Successfully created scheduled job")); @@ -286,7 +286,7 @@ async fn test_schedule_tool_run_now_action() { let content = result.unwrap(); assert_eq!(content.len(), 1); - if let Content::Text(text_content) = &content[0] { + if let Some(text_content) = content[0].as_text() { assert!(text_content .text .contains("Successfully started job 'job1'")); @@ -370,7 +370,7 @@ async fn test_schedule_tool_pause_action() { let content = result.unwrap(); assert_eq!(content.len(), 1); - if let Content::Text(text_content) = &content[0] { + if let Some(text_content) = content[0].as_text() { assert!(text_content.text.contains("Successfully paused job 'job1'")); } @@ -455,7 +455,7 @@ async fn test_schedule_tool_unpause_action() { let content = result.unwrap(); assert_eq!(content.len(), 1); - if let Content::Text(text_content) = &content[0] { + if let Some(text_content) = content[0].as_text() { assert!(text_content .text .contains("Successfully unpaused job 'job1'")); @@ -487,7 +487,7 @@ async fn test_schedule_tool_delete_action() { let content = result.unwrap(); assert_eq!(content.len(), 1); - if let Content::Text(text_content) = &content[0] { + if let Some(text_content) = content[0].as_text() { assert!(text_content .text .contains("Successfully deleted job 'job1'")); @@ -521,7 +521,7 @@ async fn test_schedule_tool_kill_action() { let content = result.unwrap(); assert_eq!(content.len(), 1); - if let Content::Text(text_content) = &content[0] { + if let Some(text_content) = content[0].as_text() { assert!(text_content .text .contains("Successfully killed running job 'job1'")); @@ -585,7 +585,7 @@ async fn test_schedule_tool_inspect_action_running() { let content = result.unwrap(); assert_eq!(content.len(), 1); - if let Content::Text(text_content) = &content[0] { + if let Some(text_content) = content[0].as_text() { assert!(text_content .text .contains("Job 'job1' is currently running")); @@ -617,7 +617,7 @@ async fn test_schedule_tool_inspect_action_not_running() { let content = result.unwrap(); assert_eq!(content.len(), 1); - if let Content::Text(text_content) = &content[0] { + if let Some(text_content) = content[0].as_text() { assert!(text_content .text .contains("Job 'job1' is not currently running")); @@ -663,7 +663,7 @@ async fn test_schedule_tool_sessions_action() { let content = result.unwrap(); assert_eq!(content.len(), 1); - if let Content::Text(text_content) = &content[0] { + if let Some(text_content) = content[0].as_text() { assert!(text_content.text.contains("Sessions for job 'job1'")); assert!(text_content.text.contains("session1")); assert!(text_content.text.contains("session2")); @@ -738,7 +738,7 @@ async fn test_schedule_tool_sessions_action_empty() { let content = result.unwrap(); assert_eq!(content.len(), 1); - if let Content::Text(text_content) = &content[0] { + if let Some(text_content) = content[0].as_text() { assert!(text_content .text .contains("No sessions found for job 'job1'")); @@ -809,7 +809,7 @@ async fn test_schedule_tool_session_content_action_with_real_session() { if let Ok(content) = result { assert_eq!(content.len(), 1); - if let mcp_core::Content::Text(text_content) = &content[0] { + if let Some(text_content) = content[0].as_text() { assert!(text_content .text .contains("Session 'test_session_real' Content:")); diff --git a/crates/goose/tests/providers.rs b/crates/goose/tests/providers.rs index c4884b7c437b..c7b5c426eefe 100644 --- a/crates/goose/tests/providers.rs +++ b/crates/goose/tests/providers.rs @@ -4,10 +4,11 @@ use goose::message::{Message, MessageContent}; use goose::providers::base::Provider; use goose::providers::errors::ProviderError; use goose::providers::{ - anthropic, azure, bedrock, databricks, google, groq, ollama, openai, openrouter, snowflake, xai, + anthropic, azure, bedrock, databricks, google, groq, litellm, ollama, openai, openrouter, + snowflake, xai, }; -use mcp_core::content::Content; use mcp_core::tool::Tool; +use rmcp::model::{AnnotateAble, Content, RawImageContent}; use std::collections::HashMap; use std::sync::Arc; use std::sync::Mutex; @@ -158,7 +159,7 @@ impl ProviderTester { .content .iter() .filter_map(|message| message.as_tool_request()) - .last() + .next_back() .expect("got tool request") .id; @@ -256,7 +257,6 @@ impl ProviderTester { async fn test_image_content_support(&self) -> Result<()> { use base64::{engine::general_purpose::STANDARD as BASE64, Engine as _}; - use mcp_core::content::ImageContent; use std::fs; // Try to read the test image @@ -273,11 +273,11 @@ impl ProviderTester { }; let base64_image = BASE64.encode(image_data); - let image_content = ImageContent { + let image_content = RawImageContent { data: base64_image, mime_type: "image/png".to_string(), - annotations: None, - }; + } + .no_annotation(); // Test 1: Direct image message let message_with_image = @@ -324,8 +324,13 @@ impl ProviderTester { serde_json::json!({}), )), ); - let tool_response = - Message::user().with_tool_response("test_id", Ok(vec![Content::Image(image_content)])); + let tool_response = Message::user().with_tool_response( + "test_id", + Ok(vec![Content::image( + image_content.data.clone(), + image_content.mime_type.clone(), + )]), + ); let result2 = self .provider @@ -592,6 +597,28 @@ async fn test_sagemaker_tgi_provider() -> Result<()> { .await } +#[tokio::test] +async fn test_litellm_provider() -> Result<()> { + if std::env::var("LITELLM_HOST").is_err() { + println!("LITELLM_HOST not set, skipping test"); + TEST_REPORT.record_skip("LiteLLM"); + return Ok(()); + } + + let env_mods = HashMap::from_iter([ + ("LITELLM_HOST", Some("http://localhost:4000".to_string())), + ("LITELLM_API_KEY", Some("".to_string())), + ]); + + test_provider( + "LiteLLM", + &[], // No required environment variables + Some(env_mods), + litellm::LiteLLMProvider::default, + ) + .await +} + #[tokio::test] async fn test_xai_provider() -> Result<()> { test_provider("Xai", &["XAI_API_KEY"], None, xai::XaiProvider::default).await diff --git a/crates/goose/tests/test_support.rs b/crates/goose/tests/test_support.rs index cfea855b1085..a2a3a2e5f6e8 100644 --- a/crates/goose/tests/test_support.rs +++ b/crates/goose/tests/test_support.rs @@ -32,6 +32,7 @@ pub struct ConfigurableMockScheduler { sessions_data: Arc>>>, } +#[allow(dead_code)] impl ConfigurableMockScheduler { pub fn new() -> Self { Self { @@ -404,6 +405,7 @@ pub fn create_test_session_metadata(message_count: usize, working_dir: &str) -> working_dir: PathBuf::from(working_dir), description: "Test session".to_string(), schedule_id: Some("test_job".to_string()), + project_id: None, total_tokens: Some(100), input_tokens: Some(50), output_tokens: Some(50), diff --git a/crates/mcp-client/src/transport/stdio.rs b/crates/mcp-client/src/transport/stdio.rs index 2489af1bfbd0..afe10e8dc577 100644 --- a/crates/mcp-client/src/transport/stdio.rs +++ b/crates/mcp-client/src/transport/stdio.rs @@ -225,9 +225,22 @@ impl StdioTransport { #[cfg(windows)] command.creation_flags(0x08000000); // CREATE_NO_WINDOW flag - let mut process = command - .spawn() - .map_err(|e| Error::StdioProcessError(e.to_string()))?; + let mut process = command.spawn().map_err(|e| { + let command = command.into_std(); + Error::StdioProcessError(format!( + "Could not run extension command (`{} {}`): {}", + command + .get_program() + .to_str() + .unwrap_or("[invalid command]"), + command + .get_args() + .map(|arg| arg.to_str().unwrap_or("[invalid arg]")) + .collect::>() + .join(" "), + e + )) + })?; let stdin = process .stdin diff --git a/crates/mcp-core/Cargo.toml b/crates/mcp-core/Cargo.toml index 4f470798d2c0..893e63f5337b 100644 --- a/crates/mcp-core/Cargo.toml +++ b/crates/mcp-core/Cargo.toml @@ -8,6 +8,7 @@ workspace = true [dependencies] async-trait = "0.1" +rmcp = { workspace = true } serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" thiserror = "1.0" diff --git a/crates/mcp-core/src/content.rs b/crates/mcp-core/src/content.rs deleted file mode 100644 index 91ede3d80fe5..000000000000 --- a/crates/mcp-core/src/content.rs +++ /dev/null @@ -1,313 +0,0 @@ -/// Content sent around agents, extensions, and LLMs -/// The various content types can be display to humans but also understood by models -/// They include optional annotations used to help inform agent usage -use super::role::Role; -use crate::resource::ResourceContents; -use chrono::{DateTime, Utc}; -use serde::{Deserialize, Serialize}; -use utoipa::ToSchema; - -#[derive(ToSchema, Debug, Clone, PartialEq, Serialize, Deserialize)] -#[serde(rename_all = "camelCase")] -pub struct Annotations { - #[serde(skip_serializing_if = "Option::is_none")] - pub audience: Option>, - #[serde(skip_serializing_if = "Option::is_none")] - pub priority: Option, - #[serde(skip_serializing_if = "Option::is_none")] - #[schema(value_type = String, format = "date-time", example = "2023-01-01T00:00:00Z")] - // for openapi - pub timestamp: Option>, -} - -impl Annotations { - /// Creates a new Annotations instance specifically for resources - /// optional priority, and a timestamp (defaults to now if None) - pub fn for_resource(priority: f32, timestamp: DateTime) -> Self { - assert!( - (0.0..=1.0).contains(&priority), - "Priority {priority} must be between 0.0 and 1.0" - ); - Annotations { - priority: Some(priority), - timestamp: Some(timestamp), - audience: None, - } - } -} - -#[derive(ToSchema, Debug, Clone, PartialEq, Serialize, Deserialize)] -#[serde(rename_all = "camelCase")] -pub struct TextContent { - pub text: String, - #[serde(skip_serializing_if = "Option::is_none")] - pub annotations: Option, -} - -#[derive(ToSchema, Debug, Clone, PartialEq, Serialize, Deserialize)] -#[serde(rename_all = "camelCase")] -pub struct ImageContent { - pub data: String, - pub mime_type: String, - #[serde(skip_serializing_if = "Option::is_none")] - pub annotations: Option, -} - -#[derive(ToSchema, Debug, Clone, PartialEq, Serialize, Deserialize)] -#[serde(rename_all = "camelCase")] -pub struct EmbeddedResource { - pub resource: ResourceContents, - #[serde(skip_serializing_if = "Option::is_none")] - pub annotations: Option, -} - -impl EmbeddedResource { - pub fn get_text(&self) -> String { - match &self.resource { - ResourceContents::TextResourceContents { text, .. } => text.clone(), - _ => String::new(), - } - } -} - -#[derive(ToSchema, Debug, Clone, PartialEq, Serialize, Deserialize)] -#[serde(tag = "type", rename_all = "camelCase")] -pub enum Content { - Text(TextContent), - Image(ImageContent), - Resource(EmbeddedResource), -} - -impl Content { - pub fn text>(text: S) -> Self { - Content::Text(TextContent { - text: text.into(), - annotations: None, - }) - } - - pub fn image, T: Into>(data: S, mime_type: T) -> Self { - Content::Image(ImageContent { - data: data.into(), - mime_type: mime_type.into(), - annotations: None, - }) - } - - pub fn resource(resource: ResourceContents) -> Self { - Content::Resource(EmbeddedResource { - resource, - annotations: None, - }) - } - - pub fn embedded_text, T: Into>(uri: S, content: T) -> Self { - Content::Resource(EmbeddedResource { - resource: ResourceContents::TextResourceContents { - uri: uri.into(), - mime_type: Some("text".to_string()), - text: content.into(), - }, - annotations: None, - }) - } - - /// Get the text content if this is a TextContent variant - pub fn as_text(&self) -> Option<&str> { - match self { - Content::Text(text) => Some(&text.text), - _ => None, - } - } - - /// Get the image content if this is an ImageContent variant - pub fn as_image(&self) -> Option<(&str, &str)> { - match self { - Content::Image(image) => Some((&image.data, &image.mime_type)), - _ => None, - } - } - - /// Set the audience for the content - pub fn with_audience(mut self, audience: Vec) -> Self { - let annotations = match &mut self { - Content::Text(text) => &mut text.annotations, - Content::Image(image) => &mut image.annotations, - Content::Resource(resource) => &mut resource.annotations, - }; - *annotations = Some(match annotations.take() { - Some(mut a) => { - a.audience = Some(audience); - a - } - None => Annotations { - audience: Some(audience), - priority: None, - timestamp: None, - }, - }); - self - } - - /// Set the priority for the content - /// # Panics - /// Panics if priority is not between 0.0 and 1.0 inclusive - pub fn with_priority(mut self, priority: f32) -> Self { - if !(0.0..=1.0).contains(&priority) { - panic!("Priority must be between 0.0 and 1.0"); - } - let annotations = match &mut self { - Content::Text(text) => &mut text.annotations, - Content::Image(image) => &mut image.annotations, - Content::Resource(resource) => &mut resource.annotations, - }; - *annotations = Some(match annotations.take() { - Some(mut a) => { - a.priority = Some(priority); - a - } - None => Annotations { - audience: None, - priority: Some(priority), - timestamp: None, - }, - }); - self - } - - /// Get the audience if set - pub fn audience(&self) -> Option<&Vec> { - match self { - Content::Text(text) => text.annotations.as_ref().and_then(|a| a.audience.as_ref()), - Content::Image(image) => image.annotations.as_ref().and_then(|a| a.audience.as_ref()), - Content::Resource(resource) => resource - .annotations - .as_ref() - .and_then(|a| a.audience.as_ref()), - } - } - - /// Get the priority if set - pub fn priority(&self) -> Option { - match self { - Content::Text(text) => text.annotations.as_ref().and_then(|a| a.priority), - Content::Image(image) => image.annotations.as_ref().and_then(|a| a.priority), - Content::Resource(resource) => resource.annotations.as_ref().and_then(|a| a.priority), - } - } - - pub fn unannotated(&self) -> Self { - match self { - Content::Text(text) => Content::text(text.text.clone()), - Content::Image(image) => Content::image(image.data.clone(), image.mime_type.clone()), - Content::Resource(resource) => Content::resource(resource.resource.clone()), - } - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_content_text() { - let content = Content::text("hello"); - assert_eq!(content.as_text(), Some("hello")); - assert_eq!(content.as_image(), None); - } - - #[test] - fn test_content_image() { - let content = Content::image("data", "image/png"); - assert_eq!(content.as_text(), None); - assert_eq!(content.as_image(), Some(("data", "image/png"))); - } - - #[test] - fn test_content_annotations_basic() { - let content = Content::text("hello") - .with_audience(vec![Role::User]) - .with_priority(0.5); - assert_eq!(content.audience(), Some(&vec![Role::User])); - assert_eq!(content.priority(), Some(0.5)); - } - - #[test] - fn test_content_annotations_order_independence() { - let content1 = Content::text("hello") - .with_audience(vec![Role::User]) - .with_priority(0.5); - let content2 = Content::text("hello") - .with_priority(0.5) - .with_audience(vec![Role::User]); - - assert_eq!(content1.audience(), content2.audience()); - assert_eq!(content1.priority(), content2.priority()); - } - - #[test] - fn test_content_annotations_overwrite() { - let content = Content::text("hello") - .with_audience(vec![Role::User]) - .with_priority(0.5) - .with_audience(vec![Role::Assistant]) - .with_priority(0.8); - - assert_eq!(content.audience(), Some(&vec![Role::Assistant])); - assert_eq!(content.priority(), Some(0.8)); - } - - #[test] - fn test_content_annotations_image() { - let content = Content::image("data", "image/png") - .with_audience(vec![Role::User]) - .with_priority(0.5); - - assert_eq!(content.audience(), Some(&vec![Role::User])); - assert_eq!(content.priority(), Some(0.5)); - } - - #[test] - fn test_content_annotations_preservation() { - let text_content = Content::text("hello") - .with_audience(vec![Role::User]) - .with_priority(0.5); - - match &text_content { - Content::Text(TextContent { annotations, .. }) => { - assert!(annotations.is_some()); - let ann = annotations.as_ref().unwrap(); - assert_eq!(ann.audience, Some(vec![Role::User])); - assert_eq!(ann.priority, Some(0.5)); - } - _ => panic!("Expected Text content"), - } - } - - #[test] - #[should_panic(expected = "Priority must be between 0.0 and 1.0")] - fn test_invalid_priority() { - Content::text("hello").with_priority(1.5); - } - - #[test] - fn test_unannotated() { - let content = Content::text("hello") - .with_audience(vec![Role::User]) - .with_priority(0.5); - let unannotated = content.unannotated(); - assert_eq!(unannotated.audience(), None); - assert_eq!(unannotated.priority(), None); - } - - #[test] - fn test_partial_annotations() { - let content = Content::text("hello").with_priority(0.5); - assert_eq!(content.audience(), None); - assert_eq!(content.priority(), Some(0.5)); - - let content = Content::text("hello").with_audience(vec![Role::User]); - assert_eq!(content.audience(), Some(&vec![Role::User])); - assert_eq!(content.priority(), None); - } -} diff --git a/crates/mcp-core/src/lib.rs b/crates/mcp-core/src/lib.rs index 5a37ceea4ad2..4bd7d1ad4818 100644 --- a/crates/mcp-core/src/lib.rs +++ b/crates/mcp-core/src/lib.rs @@ -1,5 +1,3 @@ -pub mod content; -pub use content::{Annotations, Content, ImageContent, TextContent}; pub mod handler; pub mod role; pub use role::Role; diff --git a/crates/mcp-core/src/prompt.rs b/crates/mcp-core/src/prompt.rs index 4a0106e349f1..9b50396fa7a6 100644 --- a/crates/mcp-core/src/prompt.rs +++ b/crates/mcp-core/src/prompt.rs @@ -1,7 +1,6 @@ -use crate::content::{Annotations, EmbeddedResource, ImageContent}; use crate::handler::PromptError; -use crate::resource::ResourceContents; use base64::engine::{general_purpose::STANDARD as BASE64_STANDARD, Engine}; +use rmcp::model::{Annotations, EmbeddedResource, ImageContent}; use serde::{Deserialize, Serialize}; /// A prompt that can be used to generate text from a model @@ -113,8 +112,7 @@ impl PromptMessage { role, content: PromptMessageContent::Image { image: ImageContent { - data, - mime_type, + raw: rmcp::model::RawImageContent { data, mime_type }, annotations, }, }, @@ -129,7 +127,7 @@ impl PromptMessage { text: Option, annotations: Option, ) -> Self { - let resource_contents = ResourceContents::TextResourceContents { + let resource_contents = rmcp::model::ResourceContents::TextResourceContents { uri, mime_type: Some(mime_type), text: text.unwrap_or_default(), @@ -139,7 +137,9 @@ impl PromptMessage { role, content: PromptMessageContent::Resource { resource: EmbeddedResource { - resource: resource_contents, + raw: rmcp::model::RawEmbeddedResource { + resource: resource_contents, + }, annotations, }, }, diff --git a/crates/mcp-core/src/protocol.rs b/crates/mcp-core/src/protocol.rs index 202d514dac76..e7d1c12f314b 100644 --- a/crates/mcp-core/src/protocol.rs +++ b/crates/mcp-core/src/protocol.rs @@ -1,11 +1,11 @@ /// The protocol messages exchanged between client and server use crate::{ - content::Content, prompt::{Prompt, PromptMessage}, resource::Resource, resource::ResourceContents, tool::Tool, }; +use rmcp::model::Content; use serde::{Deserialize, Serialize}; use serde_json::Value; diff --git a/crates/mcp-core/src/resource.rs b/crates/mcp-core/src/resource.rs index ae9c06898d6f..7ad35d71d2e7 100644 --- a/crates/mcp-core/src/resource.rs +++ b/crates/mcp-core/src/resource.rs @@ -1,7 +1,7 @@ -use crate::content::Annotations; /// Resources that servers provide to clients use anyhow::{anyhow, Result}; use chrono::{DateTime, Utc}; +use rmcp::model::Annotations; use serde::{Deserialize, Serialize}; use url::Url; use utoipa::ToSchema; diff --git a/crates/mcp-core/src/role.rs b/crates/mcp-core/src/role.rs index e15740d1a98d..2d753e907226 100644 --- a/crates/mcp-core/src/role.rs +++ b/crates/mcp-core/src/role.rs @@ -1,10 +1,3 @@ -/// Roles to describe the origin/ownership of content -use serde::{Deserialize, Serialize}; -use utoipa::ToSchema; - -#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, ToSchema)] -#[serde(rename_all = "lowercase")] -pub enum Role { - User, - Assistant, -} +// passthrough, which will be deleted with the rest of the mcp-core crate after it is no longer used +// needed because it has internal references in this crate which leak out to usages used in goose etc crates +pub use rmcp::model::Role; diff --git a/crates/mcp-server/Cargo.toml b/crates/mcp-server/Cargo.toml index fbaf90ba74da..c6392b757f77 100644 --- a/crates/mcp-server/Cargo.toml +++ b/crates/mcp-server/Cargo.toml @@ -11,6 +11,7 @@ anyhow = "1.0.94" thiserror = "1.0" mcp-core = { path = "../mcp-core" } mcp-macros = { path = "../mcp-macros" } +rmcp = { workspace = true } serde = { version = "1.0.216", features = ["derive"] } serde_json = "1.0.133" schemars = "0.8" diff --git a/crates/mcp-server/src/main.rs b/crates/mcp-server/src/main.rs index ad23cd518b13..2d7edfe030e7 100644 --- a/crates/mcp-server/src/main.rs +++ b/crates/mcp-server/src/main.rs @@ -1,5 +1,4 @@ use anyhow::Result; -use mcp_core::content::Content; use mcp_core::handler::{PromptError, ResourceError}; use mcp_core::prompt::{Prompt, PromptArgument}; use mcp_core::protocol::JsonRpcMessage; @@ -7,6 +6,7 @@ use mcp_core::tool::ToolAnnotations; use mcp_core::{handler::ToolError, protocol::ServerCapabilities, resource::Resource, tool::Tool}; use mcp_server::router::{CapabilitiesBuilder, RouterService}; use mcp_server::{ByteTransport, Router, Server}; +use rmcp::model::Content; use serde_json::Value; use std::{future::Future, pin::Pin, sync::Arc}; use tokio::sync::mpsc; diff --git a/crates/mcp-server/src/router.rs b/crates/mcp-server/src/router.rs index 6370bd1fb528..06c294a621a5 100644 --- a/crates/mcp-server/src/router.rs +++ b/crates/mcp-server/src/router.rs @@ -7,7 +7,6 @@ use std::{ type PromptFuture = Pin> + Send + 'static>>; use mcp_core::{ - content::Content, handler::{PromptError, ResourceError, ToolError}, prompt::{Prompt, PromptMessage, PromptMessageRole}, protocol::{ @@ -18,6 +17,7 @@ use mcp_core::{ }, ResourceContents, }; +use rmcp::model::Content; use serde_json::Value; use tokio::sync::mpsc; use tower_service::Service; diff --git a/documentation/docs/getting-started/providers.md b/documentation/docs/getting-started/providers.md index fd71cadc1132..68431109b6f9 100644 --- a/documentation/docs/getting-started/providers.md +++ b/documentation/docs/getting-started/providers.md @@ -25,7 +25,8 @@ Goose relies heavily on tool calling capabilities and currently works best with | [Azure OpenAI](https://learn.microsoft.com/en-us/azure/ai-services/openai/) | Access Azure-hosted OpenAI models, including GPT-4 and GPT-3.5. Supports both API key and Azure credential chain authentication. | `AZURE_OPENAI_ENDPOINT`, `AZURE_OPENAI_DEPLOYMENT_NAME`, `AZURE_OPENAI_API_KEY` (optional) | | [Databricks](https://www.databricks.com/) | Unified data analytics and AI platform for building and deploying models. | `DATABRICKS_HOST`, `DATABRICKS_TOKEN` | | [Gemini](https://ai.google.dev/gemini-api/docs) | Advanced LLMs by Google with multimodal capabilities (text, images). | `GOOGLE_API_KEY` | -| [GCP Vertex AI](https://cloud.google.com/vertex-ai) | Google Cloud's Vertex AI platform, supporting Gemini and Claude models. **Credentials must be [configured in advance](https://cloud.google.com/vertex-ai/docs/authentication).** | `GCP_PROJECT_ID`, `GCP_LOCATION` and optional `GCP_MAX_RETRIES` (6), `GCP_INITIAL_RETRY_INTERVAL_MS` (5000), `GCP_BACKOFF_MULTIPLIER` (2.0), `GCP_MAX_RETRY_INTERVAL_MS` (320_000). | +| [GCP Vertex AI](https://cloud.google.com/vertex-ai) | Google Cloud's Vertex AI platform, supporting Gemini and Claude models. **Credentials must be [configured in advance](https://cloud.google.com/vertex-ai/docs/authentication).** | `GCP_PROJECT_ID`, `GCP_LOCATION` and optionally `GCP_MAX_RATE_LIMIT_RETRIES` (5), `GCP_MAX_OVERLOADED_RETRIES` (5), `GCP_INITIAL_RETRY_INTERVAL_MS` (5000), `GCP_BACKOFF_MULTIPLIER` (2.0), `GCP_MAX_RETRY_INTERVAL_MS` (320_000). | +| [GitHub Copilot](https://docs.github.com/en/copilot/using-github-copilot/ai-models) | Access to GitHub Copilot's chat models including gpt-4o, o1, o3-mini, and Claude models. Uses device code authentication flow for secure access. | Uses GitHub device code authentication flow (no API key needed) | | [Groq](https://groq.com/) | High-performance inference hardware and tools for LLMs. | `GROQ_API_KEY` | | [Ollama](https://ollama.com/) | Local model runner supporting Qwen, Llama, DeepSeek, and other open-source models. **Because this provider runs locally, you must first [download and run a model](/docs/getting-started/providers#local-llms).** | `OLLAMA_HOST` | | [Ramalama](https://ramalama.ai/) | Local model using native [OCI](https://opencontainers.org/) container runtimes, [CNCF](https://www.cncf.io/) tools, and supporting models as OCI artifacts. Ramalama API an compatible alternative to Ollama and can be used with the Goose Ollama provider. Supports Qwen, Llama, DeepSeek, and other open-source models. **Because this provider runs locally, you must first [download and run a model](/docs/getting-started/providers#local-llms).** | `OLLAMA_HOST` | diff --git a/documentation/docs/guides/environment-variables.md b/documentation/docs/guides/environment-variables.md index 151e12994a20..e125195ae1de 100644 --- a/documentation/docs/guides/environment-variables.md +++ b/documentation/docs/guides/environment-variables.md @@ -108,6 +108,8 @@ These variables control how Goose manages conversation sessions and context. | `GOOSE_CONTEXT_STRATEGY` | Controls how Goose handles context limit exceeded situations | "summarize", "truncate", "clear", "prompt" | "prompt" (interactive), "summarize" (headless) | | `GOOSE_MAX_TURNS` | [Maximum number of turns](/docs/guides/smart-context-management#maximum-turns) allowed without user input | Integer (e.g., 10, 50, 100) | 1000 | | `GOOSE_CLI_THEME` | [Theme](/docs/guides/goose-cli-commands#themes) for CLI response markdown | "light", "dark", "ansi" | "dark" | +| `GOOSE_SCHEDULER_TYPE` | Controls which scheduler Goose uses for [scheduled recipes](/docs/guides/recipes/session-recipes.md#schedule-recipe) | "legacy" or "temporal" | "legacy" (Goose's built-in cron scheduler) | +| `GOOSE_TEMPORAL_BIN` | Optional custom path to your Temporal binary | /path/to/temporal-service | None | **Examples** @@ -129,6 +131,12 @@ export GOOSE_MAX_TURNS=100 # Set the ANSI theme for the session export GOOSE_CLI_THEME=ansi + +# Use Temporal for scheduled recipes +export GOOSE_SCHEDULER_TYPE=temporal + +# Custom Temporal binary (optional) +export GOOSE_TEMPORAL_BIN=/path/to/temporal-service ``` ### Model Context Limit Overrides diff --git a/documentation/docs/guides/goose-cli-commands.md b/documentation/docs/guides/goose-cli-commands.md index 5bb11cf72aeb..94473072769e 100644 --- a/documentation/docs/guides/goose-cli-commands.md +++ b/documentation/docs/guides/goose-cli-commands.md @@ -458,7 +458,7 @@ goose recipe help --- ### schedule -Automate recipes by running them on a schedule using a cron job. +Automate recipes by running them on a [schedule](/docs/guides/recipes/session-recipes.md#schedule-recipe). **Usage:** ```bash @@ -472,9 +472,13 @@ goose schedule - `sessions`: List sessions created by a scheduled recipe - `run-now`: Run a scheduled recipe immediately +Use the following commands if you're scheduling recipes using the [Temporal scheduler](https://docs.temporal.io/evaluate/development-production-features/schedules) (requires the Temporal CLI): +- `services-status`: Check if any Temporal services are running +- `services-stop`: Stop any running Temporal services + **Options:** - `--id `: A unique ID for the scheduled job (e.g. `daily-report`) -- `--cron "* * * * * *"`: Specifies when a job should run using a 6-field [cron expression](https://en.wikipedia.org/wiki/Cron#Cron_expression) represented as a string in the format "seconds minutes hours day-of-month month day-of-week" +- `--cron "* * * * * *"`: Specifies when a job should run using a [cron expression](https://en.wikipedia.org/wiki/Cron#Cron_expression) represented as a string with either 5, 6, or 7 digits in the format "seconds minutes hours day-of-month month day-of-week year" - `--recipe-source `: Path to the recipe YAML file - `--limit `: (Optional) max number of sessions to display when using the `sessions` command diff --git a/documentation/docs/guides/goose-permissions.md b/documentation/docs/guides/goose-permissions.md index 691570c62bd7..6ffbb2bf2405 100644 --- a/documentation/docs/guides/goose-permissions.md +++ b/documentation/docs/guides/goose-permissions.md @@ -6,6 +6,7 @@ sidebar_label: Goose Permissions import Tabs from '@theme/Tabs'; import TabItem from '@theme/TabItem'; +import { PanelLeft } from 'lucide-react'; Goose’s permissions determine how much autonomy it has when modifying files, using extensions, and performing automated actions. By selecting a permission mode, you have full control over how Goose interacts with your development environment. @@ -44,20 +45,22 @@ Here's how to configure: You can change modes before or during a session and it will take effect immediately. - + + Click the Goose Mode option from the bottom menu. - 1. Click `...` in the upper right corner - 2. Click `Settings` - 3. Under `Mode Selection`, choose the mode you'd like + 1. Click the button on the top-left to open the sidebar. + 2. Click the `Settings` button on the sidebar. + 3. Click `Chat`. + 4. Under `Mode`, choose the mode you'd like. - + To change modes mid-session, use the `/mode` command. diff --git a/documentation/docs/guides/managing-goose-sessions.md b/documentation/docs/guides/managing-goose-sessions.md index 19b154c93178..4bf62aa2e971 100644 --- a/documentation/docs/guides/managing-goose-sessions.md +++ b/documentation/docs/guides/managing-goose-sessions.md @@ -5,6 +5,7 @@ sidebar_label: Managing Sessions --- import Tabs from '@theme/Tabs'; import TabItem from '@theme/TabItem'; +import { AppWindow, PanelLeft, FolderDot, Paperclip } from 'lucide-react'; A session is a single, continuous interaction between you and Goose, providing a space to ask questions and prompt action. In this guide, we'll cover how to start, exit, and resume a session. @@ -14,13 +15,34 @@ A session is a single, continuous interaction between you and Goose, providing a - After choosing an LLM provider, you'll see the session interface ready for use. Type your questions, tasks, or instructions directly into the input field, and Goose will immediately get to work. + After choosing an LLM provider, you'll see the session interface ready for use. Type your questions, tasks, or instructions directly into the input field, and Goose will immediately get to work. You can start a new session in the same directory or in a different directory. - To start a new session at any time, click the gear icon `⚙️` in the top-right corner of the application and select **New session** from the dropdown menu. + + - To start a new session in a different directory, click the current directory displayed at the top of the application, or click the gear icon `⚙️` in the top-right corner of the application and select **Open directory** from the dropdown menu. You can select an existing folder or create a new folder. + To start a session in the same window: + 1. Click the button in the top-left to open the sidebar + 2. Click `Home` in the sidebar + 3. Send your first prompt from the chat box + To start a session in a new window: + 1. Click the button in the top-left + 2. In the new Goose window, send your first prompt from the chat box + + + + + 1. Click the directory switcher at the bottom of the app + 2. Navigate to the new directory or create a new folder + 3. Click `Open` to open a new Goose window for the selected directory + 4. Send your first prompt from the chat box + + + + + :::tip On macOS, you can drag and drop a folder onto the Goose icon in the dock to open a new session in that directory. + ::: You can also use keyboard shortcuts to start a new session or bring focus to open Goose windows. @@ -51,7 +73,7 @@ If this is your first session, Goose will prompt you for an API key to access an ## Name Session - Within the Desktop app, sessions are automatically named using the current timestamp in the format `YYYYMMDD_HHMMSS`. Goose also provides a description of the session based on context. + Within the Desktop app, sessions are automatically named based on the context of your initial prompt. By default, Goose names your session using the current timestamp in the format `YYYYMMDD_HHMMSS`. If you'd like to provide a specific name, this is where you'd do so. For example to name your session `react-migration`, you would run: @@ -86,10 +108,17 @@ Note that sessions are automatically saved when you exit. - 1. Click `⚙️` in the upper right corner - 2. Click `Session history` - 3. Click a session - 4. Click `Resume Session` in the upper right corner + 1. Click the button in the top-left to open the sidebar + 2. Click `History` in the sidebar + 3. Click the session you'd like to resume + 4. Choose how to resume: + - Click `Resume` to continue in the current window + - Click `New Window` to open in a new window + + :::tip + You can also quickly resume one of your three most recent sessions by clicking it in the `Recent chats` section on the `Home` page. + ::: + To resume your latest session, you can run the following command: @@ -121,8 +150,8 @@ Note that sessions are automatically saved when you exit. In Goose Desktop, you can search session metadata including the description, filename, and working directory path. The search is text-based and supports case-sensitive matching, but doesn't search session content or support regex patterns. - 1. Click `⚙️` in the upper right corner - 2. Click `Session history` + 1. Click the button in the top-left to open the sidebar + 2. Click `History` in the sidebar 3. Use `Cmd+F` to open the search bar 4. Enter your search term 5. Use search features to refine and navigate results @@ -167,72 +196,46 @@ Note that sessions are automatically saved when you exit. ### Resume Session Across Interfaces -You can resume a CLI session in Desktop and vice versa. +You can resume a CLI session in Desktop. All saved sessions are listed in the Desktop app, even CLI sessions. To resume a CLI session within the Desktop: - 1. Click `⚙️` in the upper right corner - 2. Click `Session history` + 1. Click the button in the top-left to open the sidebar + 2. Click `History` in the sidebar 3. Click the session you'd like to resume + 4. Choose how to resume: + - Click `Resume` to continue in the current window + - Click `New Window` to open in a new window - :::tip - If you named the session, you'll recognize the filename. However, if you don't remember the exact session name, there is a description of the topic. - ::: - - 4. Click `Resume Session` in the upper right corner - - :::note Example - - **CLI Command** - - ```sh - goose session -n react-migration - ``` - - **Desktop Session** - - | Session Description | Session Filename | - |-------------------------|------------------------------| - | Code Migration to React | **react-migration**.jsonl | - - - - ::: - To resume a Desktop session within CLI, get the name of the session from the Desktop app. Note that unless you specifically named the session, its default name is a timestamp in the format `YYYYMMDD_HHMMSS`. - - 1. Open Goose Desktop - 2. Click `⚙️` in the upper right corner - 3. Click `Session history` - 4. Find the session that you want to resume, and copy the basename (without the `.jsonl` extension). - :::note Example - - **Desktop Session** - - | Session Description | Session Filename | - |------------------------|------------------------------| - | GitHub PR Access Issue | **20250305_113223**.jsonl | - - - **CLI Command** - ```sh - goose session -r --name 20250305_113223 - ``` - ::: - + Currently, you cannot resume a Desktop session within the CLI. ## Project-Based Sessions -You can also use the [`project`](/docs/guides/goose-cli-commands#project) and [`projects`](/docs/guides/goose-cli-commands#projects) commands to start or resume sessions from a project, which is a tracked working directory with session metadata. For a complete guide to using Projects, see [Managing Projects Guide](/docs/guides/managing-projects). + + + Project-based sessions are only available through the CLI. + + + You can use the [`project`](/docs/guides/goose-cli-commands#project) and [`projects`](/docs/guides/goose-cli-commands#projects) commands to start or resume sessions from a project, which is a tracked working directory with session metadata. For a complete guide to using Projects, see [Managing Projects Guide](/docs/guides/managing-projects). + + ## Remove Sessions -You can remove sessions using CLI commands. For detailed instructions on session removal, see the [CLI Commands documentation](/docs/guides/goose-cli-commands#session-remove-options). + + + Removing sessions is only available through the CLI. + + + You can remove sessions using CLI commands. For detailed instructions on session removal, see the [CLI Commands documentation](/docs/guides/goose-cli-commands#session-remove-options). + + ## Export Sessions @@ -240,15 +243,7 @@ Export sessions to Markdown to share with your team, create documentation, archi - Session export is currently only available through the CLI. You can export Desktop sessions via their session filename: - - 1. Click `⚙️` in the upper right corner - 2. Click `Session history` - 3. Note the session filename (e.g., `20250305_113223.jsonl`) - 4. In the CLI, run the command: - ```bash - goose session export --name 20250305_113223 --output my-session.md - ``` + Session export is currently only available through the CLI. Export sessions using the `export` subcommand: @@ -258,28 +253,30 @@ Export sessions to Markdown to share with your team, create documentation, archi goose session export ``` + For more details on export options, available flags, and output formats, see the [CLI commands documentation](/docs/guides/goose-cli-commands#session-export-options). -For more details on export options, available flags, and output formats, see the [CLI commands documentation](/docs/guides/goose-cli-commands#session-export-options). - ## Voice Dictation Speak to Goose directly instead of typing your prompts. To enable voice dictation: - 1. Click `⚙️` in the upper right corner - 2. Click **Advanced Settings** - 3. Scroll to **Voice Dictation** - 4. Toggle **Enable Voice Dictation** on - 5. Choose between **OpenAI Whisper** or **ElevenLabs** as your dictation provider + 1. Click the button in the top-left to open the sidebar + 2. Click `Settings` in the sidebar + 3. Click `Chat` + 4. Under `Voice Dictation`, toggle `Enable Voice Dictation` on + 5. Choose between `OpenAI Whisper` or `ElevenLabs` as your dictation provider 6. Enter your API key for the provider you chose - 7. Click the back button at the top of the settings screen to return to the chat. A microphone icon appears to the right of the text input field. - **If you don't see the microphone icon**, check the [models you have configured](/docs/getting-started/providers.md). ElevenLabs can be used as a dictation provider alongside any LLM, but OpenAI Whisper requires that you have an OpenAI model configured in Goose, even if using another LLM provider for chat. + To use voice dictation: + 1. Return to the chat interface (click `Chat` in the sidebar) + 2. Click the microphone on the right of the chat box and begin speaking + + The first time you use voice dictation, Goose will request access to your microphone. While recording, you'll see a live waveform of your audio in the input field, a timer, and the current size of your recording. Click the microphone button again to finish recording. - To start using voice dictation, click the microphone icon and begin speaking. The first time, Goose will request access to your microphone. While recording, you'll see a live waveform of your audio in the input field, a timer, and the current size of your recording. Click the microphone button again to finish recording. + **If you don't see the microphone**, check the [models you have configured](/docs/getting-started/providers.md). ElevenLabs can be used as a dictation provider alongside any LLM, but OpenAI Whisper requires that you have an OpenAI model configured in Goose, even if using another LLM provider for chat. #### Important Notes * You can record up to 10 minutes or 25MB of audio. @@ -331,13 +328,15 @@ Search allows you to find specific content within your current session. The sear - Share files with Goose in three ways: + Share files with Goose in several ways: 1. **Drag and Drop**: Simply drag files from your computer's file explorer/finder and drop them anywhere in the chat window. The file paths will be automatically added to your message. - 2. **File Browser**: Click the paperclip icon (📎) in the bottom left corner of the chat input to open your system's file browser and select files. + 2. **File Browser**: Click the button at the bottom of the app to open your system's file browser and select files. 3. **Manual Path**: Type or paste the file path directly into the chat input. + + 4. **Quick File Search**: Use the [`@` shortcut key](/docs/guides/file-management#quick-file-search-in-goose-desktop) to quickly find and include files. You can reference files by their paths directly in your messages. Since you're already in a terminal, you can use standard shell commands to help with file paths: diff --git a/documentation/docs/guides/managing-tools/adjust-tool-output.md b/documentation/docs/guides/managing-tools/adjust-tool-output.md index 3df0bc65ab86..e7a708665fe4 100644 --- a/documentation/docs/guides/managing-tools/adjust-tool-output.md +++ b/documentation/docs/guides/managing-tools/adjust-tool-output.md @@ -5,15 +5,17 @@ sidebar_label: Adjust Tool Output --- import Tabs from '@theme/Tabs'; import TabItem from '@theme/TabItem'; +import { PanelLeft } from 'lucide-react'; Response Styles customize how tool interactions are displayed in the Goose Desktop chat window. To change this setting: -1. Click the menu icon in the top right corner. -2. Select `Advanced Settings` from the menu. -3. Under `Response Styles`, select either `Detailed` or `Concise`. +1. Click the button on the top-left to open the sidebar. +2. Click the `Settings` button on the sidebar. +3. Click `Chat`. +4. Under `Response Styles`, select either `Detailed` or `Concise`. - **Concise** (Default) - Tool calls are collapsed by default diff --git a/documentation/docs/guides/managing-tools/tool-permissions.md b/documentation/docs/guides/managing-tools/tool-permissions.md index 03c1a9d40abb..7d33211de22e 100644 --- a/documentation/docs/guides/managing-tools/tool-permissions.md +++ b/documentation/docs/guides/managing-tools/tool-permissions.md @@ -6,6 +6,7 @@ sidebar_label: Tool Permissions import Tabs from '@theme/Tabs'; import TabItem from '@theme/TabItem'; +import { PanelLeft, Tornado, Settings } from 'lucide-react'; Tool permissions provide fine-grained control over how Goose uses different tools within extensions. This guide will help you understand and configure these permissions effectively. @@ -27,6 +28,8 @@ Goose performs best with fewer than 25 total tools enabled across all extensions ## Permission Levels +Tool permissions work alongside [Goose permission modes](/docs/guides/goose-permissions). The mode sets the default behavior, while tool permissions let you override the behavior of specific tools. + Each tool can be set to one of three permission levels: | Permission Level | Description | Best For | Examples | @@ -35,36 +38,31 @@ Each tool can be set to one of three permission levels: | **Ask Before** | Requires confirmation | State-changing operations | • File writing/editing

• System commands

• Resource creation | | **Never Allow** | Tool cannot be used | Sensitive operations | • Credential access

• System-critical files

• Resource deletion | -:::info -Tool permissions work alongside [Goose Permission Modes](/docs/guides/goose-permissions). The mode sets default behavior, while tool permissions let you override specific tools. -::: - ## Configuring Tool Permissions - - You can configure tool permissions through either Manual or Smart Approval modes: - - - - 1. Click `...` in the upper right corner - 2. Click `Advanced Settings` - 3. Under `Mode Selection`, choose `Manual Approval` - 4. Click on an extension name - 5. Use the dropdown next to each tool to set its permission level + You can configure fine-grained tool permissions for enabled extensions when using `Manual` or `Smart` approval mode. These rules can be accessed from the mode toggle or `Settings` page. + + + + 1. Click the button at the bottom of the app + 2. Click the button next to your selected `Manual` or `Smart` mode + 3. Click the extension whose tools you want to configure + 4. Use the dropdown next to each tool to set its permission level + 5. Click `Save Changes` - - :::tip - In Smart Approval mode, Goose will automatically detect and allow read-only operations while requiring approval for state-changing actions. - ::: - 1. Click `...` in the upper right corner - 2. Click `Advanced Settings` - 3. Under `Mode Selection`, choose `Smart Approval` - 4. Click on an extension name - 5. Use the dropdown next to each tool to set its permission level + + 1. Click the button in the top-left to open the sidebar + 2. Click the `Settings` button on the sidebar + 3. Click `Chat` + 4. Under `Mode`, click the button next to your selected `Manual` or `Smart` mode + 5. Click the extension whose tools you want to configure + 6. Use the dropdown next to each tool to set its permission level + 7. Click `Save Changes` - + + diff --git a/documentation/docs/guides/managing-tools/tool-router.md b/documentation/docs/guides/managing-tools/tool-router.md index 17d3aef9906c..33b94713ce29 100644 --- a/documentation/docs/guides/managing-tools/tool-router.md +++ b/documentation/docs/guides/managing-tools/tool-router.md @@ -7,6 +7,7 @@ description: Configure smart tool selection to load only relevant tools, improvi import Tabs from '@theme/Tabs'; import TabItem from '@theme/TabItem'; +import { PanelLeft } from 'lucide-react'; :::info Preview Feature The Tool Selection Strategy is currently in preview. The Vector selection strategy is currently limited to Claude models served on Databricks. @@ -70,9 +71,10 @@ The LLM-based strategy leverages natural language understanding to analyze tools - 1. Click the gear icon ⚙️ on the top toolbar - 2. Click `Advanced settings` - 3. Under `Tool Selection Strategy`, select your preferred strategy: + 1. Click the button in the top-left to open the sidebar + 2. Click the `Settings` button on the sidebar + 3. Click `Chat` + 4. Under `Tool Selection Strategy`, select your preferred strategy: - `Default` - `Vector` - `LLM-based` diff --git a/documentation/docs/guides/recipes/session-recipes.md b/documentation/docs/guides/recipes/session-recipes.md index c1d6757c05a5..b9426181825d 100644 --- a/documentation/docs/guides/recipes/session-recipes.md +++ b/documentation/docs/guides/recipes/session-recipes.md @@ -6,6 +6,7 @@ description: "Share a Goose session setup (including tools, goals, and instructi import Tabs from '@theme/Tabs'; import TabItem from '@theme/TabItem'; +import { PanelLeft, Bot } from 'lucide-react'; Sometimes you finish a task in Goose and realize, "Hey, this setup could be useful again." Maybe you have curated a great combination of tools, defined a clear goal, and want to preserve that flow. Or maybe you're trying to help someone else replicate what you just did without walking them through it step by step. @@ -16,15 +17,36 @@ You can turn your current Goose session into a reusable recipe that includes the - 1. While in the session you want to save as a recipe, click the gear icon `⚙️` in the top right corner - 2. Select **Make recipe from this session** - 3. In the dialog that appears: - - Provide a **title** for the recipe - - Provide a **description** - - A set of **instructions** will also be automatically generated. Review and edit as needed. - - Provide an optional initial **prompt** to display in the chat box. - - Some **activities** will be automatically generated. Add or remove as needed. - 4. Copy the recipe URL and use it however you like (e.g., share it with teammates, drop it in documentation, or keep it for yourself) + Create a recipe from the current session or from a template. + + + + 1. While in the session you want to save as a recipe, click the button at the bottom of the app + 2. Click `Create a recipe from this session` + 3. A dialog opens with automatically generated instructions and activities: + - Provide a **title** and **description** for the recipe + - Review the **instructions** and edit them as needed + - Provide an optional **initial prompt** to display in the chat box + - Add or remove optional **activities** to display as buttons + 4. When you're finished, you can: + - Copy the recipe link to share the recipe with others or [open it from the link](#use-recipe) + - Click `Save Recipe` to [save the recipe](/docs/guides/recipes/storing-recipes) locally + - Click `Create Schedule` to [schedule the recipe](#schedule-recipe) + + + 1. Click the button in the top-left to open the sidebar + 2. Click the `Recipes` button in the sidebar + 3. Click `Create Recipe` + 4. A dialog opens with placeholder content: + - Provide a **title** and **description** for the recipe + - Edit the **instructions** as needed + - Provide an optional **initial prompt** to display in the chat box + - Add or remove optional **activities** to display as buttons + - Provide a **recipe name** + - Choose to [save the recipe](/docs/guides/recipes/storing-recipes) with **global** or **directory** availability + 5. Click `Create Recipe` + + :::warning You cannot create a recipe from an existing recipe session, but you can view or [edit the recipe](#edit-recipe). @@ -162,16 +184,18 @@ You can turn your current Goose session into a reusable recipe that includes the - 1. While in the session created from a recipe, click the gear icon `⚙️` in the top right corner - 2. Select **View recipe** - 3. In the dialog that appears, you can edit the: + 1. While in the session that's using the recipe, click the button at the bottom of the app + 2. Click `View recipe` + 3. Edit any of the following: - Title - Description - Instructions - Initial prompt - Activities - 4. Copy the new recipe URL. The original recipe and your current session are not affected by your edits. - 5. Use and share the URL for your new recipe. + 4. When you're finished, you can: + - Copy the recipe link to share the recipe with others or [open it from the link](#use-recipe) + - Click `Save Recipe` to [save the recipe](/docs/guides/recipes/storing-recipes) locally + - Click `Create Schedule` to [schedule the recipe](#schedule-recipe) @@ -186,19 +210,25 @@ You can turn your current Goose session into a reusable recipe that includes the - 1. Open the recipe using a direct link or manual URL entry: + 1. Open the recipe using a direct link or manual URL entry, or from your Recipe library: + + **Direct Link:** + + 1. Click a recipe link shared with you - - **Direct Link** - - Click a recipe link shared with you - - The recipe will automatically open in Goose Desktop + **Manual URL Entry:** + + 1. Paste a recipe link into your browser's address bar + 2. Press `Enter` and click the `Open Goose.app` prompt + + **Recipe Library:** - - **Manual URL Entry** - - Copy a recipe URL - - Paste it into your browser's address bar - - You will see a prompt to "Open Goose" - - Goose Desktop will open with the recipe + 1. Click the button in the top-left to open the sidebar + 2. Click `Recipes` in the sidebar + 3. Find your recipe in the Recipe Library + 4. Click `Use` next to the recipe you want to open - 2. If the recipe contains parameters, enter your values in the **Recipe Parameters** dialog and click **Start Recipe**. + 2. If the recipe contains parameters, enter your values in the `Recipe Parameters` dialog and click `Start Recipe`. Parameters are dynamic values used in the recipe: @@ -223,7 +253,7 @@ You can turn your current Goose session into a reusable recipe that includes the Using a recipe with the Goose CLI might involve the following tasks: - [Configuring your recipe location](#configure-recipe-location) - [Running a recipe](#run-a-recipe) - - [Scheduling a recipe](#schedule-a-recipe) + - [Scheduling a recipe](#schedule-recipe) #### Configure Recipe Location @@ -383,19 +413,6 @@ You can turn your current Goose session into a reusable recipe that includes the - Your session won't affect the original recipe creator's session ::: - #### Schedule a Recipe - Automate Goose recipes by running them on a schedule. - - **Create a schedule** - Create a scheduled cron job that runs the recipe on the specified cadence. - - ```bash - # Add a new scheduled recipe which runs every day at 9 AM - goose schedule add --id daily-report --cron "0 0 9 * * *" --recipe-source ./recipes/daily-report.yaml - ``` - The [cron expression](https://en.wikipedia.org/wiki/Cron#Cron_expression) follows the format "seconds minutes hours day-of-month month day-of-week". - - See the [`schedule` command documentation](/docs/guides/goose-cli-commands#schedule) for detailed examples and options. - @@ -434,6 +451,12 @@ You can turn your current Goose session into a reusable recipe that includes the Share your recipe with Desktop users by copying the recipe URL from the recipe creation dialog. When someone clicks the URL, it will open Goose Desktop with your recipe configuration. + To copy the recipe URL: + 1. [Open the recipe](#use-recipe) + 2. Click the button at the bottom of the app + 3. Click `View recipe` + 4. Scroll down and copy the link + Share your recipe with CLI users by directly sending them the recipe file or converting it to a shareable [deep link](/docs/guides/goose-cli-commands#recipe) for Desktop users: @@ -445,6 +468,49 @@ You can turn your current Goose session into a reusable recipe that includes the +## Schedule Recipe + + +Automate Goose recipes by running them on a schedule. + + 1. Click the button in the top-left to open the sidebar + 2. Click `Scheduler` + 3. Click `Create Schedule` + 3. In the dialog that appears: + - Provide a **name** for the schedule + - Select the **source** of your recipe. This can be either a `yaml` file or link generated by Goose Desktop. + - Select whether you want your recipe to run in the background or foreground **execution mode**. Recipes run in the background don't open a window, but the session results are saved. Recipes run in the foreground will open a window if the Goose Desktop app is running. Otherwise, the recipe runs in the background. + - Choose the **frequency** and **time** to run your recipe. Your selected frequency (e.g. every 20 minutes, weekly at 10 AM on Friday) is converted into a [cron expression](https://en.wikipedia.org/wiki/Cron#Cron_expression) used by Goose. + - Click `Create Schedule` + + Your new scheduled recipe is listed in the `Scheduler` page. Click on the schedule to view details, see when it was last run, and perform actions with the scheduled recipe: + - `Run Schedule Now` to trigger the recipe manually + - `Edit Schedule` to change the scheduled frequency + - `Pause Schedule` to stop the recipe from running automatically. + + At the bottom of the `Schedule Details` page you can view the list of sessions created by the scheduled recipe and open or restore each session. + + + + Automate Goose recipes by scheduling them to run with a [cron expression](https://en.wikipedia.org/wiki/Cron#Cron_expression). + + ```bash + # Add a new scheduled recipe which runs every day at 9 AM + goose schedule add --id daily-report --cron "0 0 9 * * *" --recipe-source ./recipes/daily-report.yaml + ``` + You can use either a 5, 6, or 7-digit cron expression for full scheduling precision, following the format "seconds minutes hours day-of-month month day-of-week year". + + See the [`schedule` command documentation](/docs/guides/goose-cli-commands.md#schedule) for detailed examples and options. + +When scheduling Goose recipes with the CLI, you can use Goose's built-in cron scheduler (default), or the [Temporal scheduler](https://docs.temporal.io/evaluate/development-production-features/schedules) (requires the Temporal CLI). Switch from the default legacy scheduler by setting the `GOOSE_SCHEDULER_TYPE` [environment variable](/docs/guides/environment-variables.md#session-management): + + ```bash + export GOOSE_SCHEDULER_TYPE=temporal + ``` + Use Temporal scheduling if you want an advanced workflow engine with monitoring features. The scheduling engines do not share schedules, so schedules created with the legacy Goose scheduler cannot be run with the Temporal scheduler, and vice-versa. + + + ## Core Components A recipe needs these core components: diff --git a/documentation/docs/guides/using-goosehints.md b/documentation/docs/guides/using-goosehints.md index 89f70e1eedc7..305f5e938a18 100644 --- a/documentation/docs/guides/using-goosehints.md +++ b/documentation/docs/guides/using-goosehints.md @@ -27,7 +27,7 @@ Goose supports two types of hint files: You can use both global and local hints at the same time. When both exist, Goose will consider both your global preferences and project-specific requirements. If the instructions in your local hints file conflict with your global preferences, Goose will prioritize the local hints. ::: - + #### Global hints file @@ -35,14 +35,16 @@ You can use both global and local hints at the same time. When both exist, Goose #### Local hints file - 1. Change the directory to where you'd like to set up the file. You can do this by clicking the directory path in the lower left bottom of the screen - 2. Click `...` in the upper right corner - 3. Click `Configure .goosehints` - 4. Enter your local tips into the textarea - 5. Click `Save` + 1. Change the directory to where you'd like to set up the file. You can do this by clicking the directory path on the bottom of the Goose window. + 2. Click the `Configure goosehints` icon on the bottom right of the Goose window. + 4. Enter your local tips into the text area. + 5. Click `Save`. + 6. Restart your session so Goose can read the updated `.goosehints`. + + If a `.goosehints` file already exists in the given directory, you can edit or add to it from this screen. :::tip - You may have to adjust the screen size to fully see the Save and Cancel buttons. + You may have to scroll or adjust the screen size to fully see the Save and Cancel buttons. ::: diff --git a/documentation/docs/mcp/agentql-mcp.md b/documentation/docs/mcp/agentql-mcp.md index 09c27879074f..a5d5e8a7e866 100644 --- a/documentation/docs/mcp/agentql-mcp.md +++ b/documentation/docs/mcp/agentql-mcp.md @@ -5,6 +5,7 @@ description: Add AgentQL MCP Server as a Goose Extension import Tabs from '@theme/Tabs'; import TabItem from '@theme/TabItem'; +import GooseDesktopInstaller from '@site/src/components/GooseDesktopInstaller'; @@ -37,11 +38,18 @@ Note that you'll need [Node.js](https://nodejs.org/) installed on your system to - 1. [Launch the installer](goose://extension?cmd=npx&arg=-y&arg=agentql-mcp&id=agentql&name=AgentQL&description=Transform%20unstructured%20web%20content%20into%20structured%20data&env=AGENTQL_API_KEY%3DAgentQL%20API%20Key) - 2. Press `Yes` to confirm the installation - 3. Obtain an [AGENTQL_API_KEY](https://dev.agentql.com/api-keys) and paste it in - 4. Click `Save Configuration` - 5. Scroll to the top and click `Exit` from the upper left corner + 1. Run the `configure` command: diff --git a/documentation/docs/mcp/alby-mcp.md b/documentation/docs/mcp/alby-mcp.md index 8dbf1c9fe6aa..42cbfc5219c0 100644 --- a/documentation/docs/mcp/alby-mcp.md +++ b/documentation/docs/mcp/alby-mcp.md @@ -7,6 +7,8 @@ description: Connect Goose to your Bitcoin Lightning Wallet import Tabs from '@theme/Tabs'; import TabItem from '@theme/TabItem'; import CLIExtensionInstructions from '@site/src/components/CLIExtensionInstructions'; +import GooseDesktopInstaller from '@site/src/components/GooseDesktopInstaller'; +import { PanelLeft } from 'lucide-react'; @@ -47,15 +49,21 @@ You'll need [Node.js](https://nodejs.org/) installed on your system to run this - 1. [Launch the installer](goose://extension?cmd=npx&arg=-y&arg=%40getalby%2Fmcp&id=alby&name=Alby&description=Connect%20Goose%20to%20your%20Bitcoin%20Lightning%20Wallet&env=NWC_CONNECTION_STRING%3DNWC%20Connection%20Secret) - 2. Press `Yes` to confirm the installation - 3. Obtain a NWC Connection secret from your lightning wallet (`nostr+walletconnect://...`) and paste it in to the "NWC Connection Secret" field - 4. Click `Add Extension` - 5. Scroll to the top and click `Exit` from the upper left corner + 1. [Launch the installer](goose://extension?cmd=npx&arg=-y&arg=%40getalby%2Fmcp&id=alby&name=Alby&description=Connect%20Goose%20to%20your%20Bitcoin%20Lightning%20Wallet) - 2. Press `Yes` to confirm the installation + 2. Press `OK` to confirm the installation 3. Change the type to "Streamable HTTP" 4. Change the endpoint to `https://mcp.getalby.com/mcp` 5. Add a request header with Header name = `Authorization` and Value: @@ -64,8 +72,9 @@ Bearer nostr+walletconnect://... ``` 6. Press the `+Add` button to finish adding the request header - 7. Click `Add Extension` - 8. Scroll to the top and click `Exit` from the upper left corner + 7. Press `Add Extension` + 8. Click the button in the top-left to open the sidebar + 9. Navigate to the chat diff --git a/documentation/docs/mcp/asana-mcp.md b/documentation/docs/mcp/asana-mcp.md index d4a1c7a7a16c..e28fa7b0f1eb 100644 --- a/documentation/docs/mcp/asana-mcp.md +++ b/documentation/docs/mcp/asana-mcp.md @@ -6,6 +6,7 @@ description: Add Asana MCP Server as a Goose Extension import Tabs from '@theme/Tabs'; import TabItem from '@theme/TabItem'; import YouTubeShortEmbed from '@site/src/components/YouTubeShortEmbed'; +import GooseDesktopInstaller from '@site/src/components/GooseDesktopInstaller'; @@ -39,14 +40,21 @@ Note that you'll need [Node.js](https://nodejs.org/) installed on your system to - 1. [Launch the installer](goose://extension?cmd=npx&arg=-y&arg=%40roychri%2Fmcp-server-asana&id=asana&name=Asana&description=enable%20task%20automation%2C%20project%20tracking%2C%20and%20team%20collaboration&env=ASANA_ACCESS_TOKEN%3DAsana%20Access%20Token) - 2. Press `Yes` to confirm the installation - 3. Obtain a [Asana Access Token](https://app.asana.com/0/my-apps) and paste it in + :::info See [Asana's developer docs](https://developers.asana.com/docs/personal-access-token) if you need detailed instructions on creating an access token. ::: - 4. Click `Save Configuration` - 5. Scroll to the top and click `Exit` from the upper left corner 1. Run the `configure` command: diff --git a/documentation/docs/mcp/blender-mcp.md b/documentation/docs/mcp/blender-mcp.md index 3e154076630e..9cf530726233 100644 --- a/documentation/docs/mcp/blender-mcp.md +++ b/documentation/docs/mcp/blender-mcp.md @@ -6,6 +6,7 @@ description: Add Blender MCP Server as a Goose Extension import Tabs from '@theme/Tabs'; import TabItem from '@theme/TabItem'; import YouTubeShortEmbed from '@site/src/components/YouTubeShortEmbed'; +import GooseDesktopInstaller from '@site/src/components/GooseDesktopInstaller'; @@ -51,10 +52,13 @@ Note that you'll need [uv](https://docs.astral.sh/uv/#installation) installed on - 1. [Launch the installer](goose://extension?cmd=uvx&arg=blender-mcp&id=blender&name=Blender&description=Blender%203D%20scene%20creation%20integration) - 2. Press `Yes` to confirm the installation - 4. Click `Save Configuration` - 5. Scroll to the top and click `Exit` from the upper left corner + 1. Run the `configure` command: diff --git a/documentation/docs/mcp/brave-mcp.md b/documentation/docs/mcp/brave-mcp.md index a8524dac2d15..d502b960034b 100644 --- a/documentation/docs/mcp/brave-mcp.md +++ b/documentation/docs/mcp/brave-mcp.md @@ -6,6 +6,7 @@ description: Add Brave Search API as a Goose Extension import Tabs from '@theme/Tabs'; import TabItem from '@theme/TabItem'; import YouTubeShortEmbed from '@site/src/components/YouTubeShortEmbed'; +import GooseDesktopInstaller from '@site/src/components/GooseDesktopInstaller'; @@ -37,11 +38,16 @@ Note that you'll need [Node.js](https://nodejs.org/) installed on your system to - 1. [Launch the installer](goose://extension?cmd=npx&arg=-y&arg=%40modelcontextprotocol%2Fserver-brave-search&id=brave-search&name=Brave%20Search&description=Brave%20Search%20API&env=BRAVE_API_KEY%3DYour%20API%20Key) - 2. Press `Yes` to confirm the installation - 3. Get your [Brave Search API Key](https://api-dashboard.search.brave.com/app/keys) and paste it in - 4. Click `Save Configuration` - 5. Scroll to the top and click `Exit` from the upper left corner + 1. Run the `configure` command: diff --git a/documentation/docs/mcp/browserbase-mcp.md b/documentation/docs/mcp/browserbase-mcp.md index 1a844edcf4c9..8c62bc745f20 100644 --- a/documentation/docs/mcp/browserbase-mcp.md +++ b/documentation/docs/mcp/browserbase-mcp.md @@ -5,6 +5,7 @@ description: Add Browserbase MCP Server as a Goose Extension for Web Automation import Tabs from '@theme/Tabs'; import TabItem from '@theme/TabItem'; +import GooseDesktopInstaller from '@site/src/components/GooseDesktopInstaller'; This tutorial covers how to add the Browserbase MCP Server as a Goose extension for browser automation, enabling programmatic control over navigation, page interactions, and content capture. @@ -32,11 +33,19 @@ This tutorial covers how to add the Browserbase MCP Server as a Goose extension - 1. [Launch the installer](goose://extension?cmd=npx&arg=@browserbasehq/mcp&id=browserbase&name=Browserbase&description=Automate%20web%20browsing%20and%20data%20extraction&env=BROWSERBASE_PROJECT_ID%3DBrowserbase%20Project%20ID&env=BROWSERBASE_API_KEY%3DBrowserbase%20API%20Key) - 2. Press `Yes` to confirm the installation - 3. Obtain your [Browserbase Project ID and Browserbase API Key](https://browserbase.io/dashboard) and paste them in - 4. Click `Save Configuration` - 5. Scroll to the top and click `Exit` from the upper left corner + 1. Run the `configure` command: diff --git a/documentation/docs/mcp/cloudflare-mcp.md b/documentation/docs/mcp/cloudflare-mcp.md new file mode 100644 index 000000000000..5e05d9482d0a --- /dev/null +++ b/documentation/docs/mcp/cloudflare-mcp.md @@ -0,0 +1,293 @@ +--- +title: Cloudflare MCP Server +description: Add Cloudflare MCP Servers as Goose Extensions +unlisted: true +--- + +import Tabs from '@theme/Tabs'; +import TabItem from '@theme/TabItem'; + +This tutorial covers how to add [Cloudflare's MCP Servers](https://github.com/cloudflare/mcp-server-cloudflare) as Goose extensions to manage your Cloudflare infrastructure, debug applications, analyze traffic, and more using natural language. + +Cloudflare provides multiple specialized MCP servers for different aspects of their platform, allowing you to interact with Workers, DNS, security features, analytics, and development tools. + +:::tip TLDR + + + + [Launch the installer](goose://extension?cmd=npx&arg=mcp-remote&arg=https%3A%2F%2Fobservability.mcp.cloudflare.com%2Fsse&id=cloudflare-observability&name=Cloudflare%20Observability&description=Debug%20and%20get%20insight%20into%20your%20application%27s%20logs%20and%20analytics&env=CLOUDFLARE_API_TOKEN%3DCloudflare%20API%20Token) + + + **Command** + ```sh + npx mcp-remote https://observability.mcp.cloudflare.com/sse + ``` + + + **Environment Variable** + ``` + CLOUDFLARE_API_TOKEN: Your Cloudflare API token with appropriate permissions + ``` +::: + +## Available Cloudflare MCP Servers + +Cloudflare provides multiple specialized MCP servers for different use cases: + +| Server | Description | Use Cases | +|--------|-------------|-----------| +| **Documentation** | Get up-to-date reference information on Cloudflare | API reference, feature documentation, troubleshooting guides | +| **Workers Bindings** | Build Workers applications with storage, AI, and compute primitives | KV storage, R2 buckets, AI models, Durable Objects | +| **Workers Builds** | Get insights and manage your Cloudflare Workers builds | Deployment status, build logs, version management | +| **Observability** | Debug and get insight into your application's logs and analytics | Error tracking, performance monitoring, request analysis | +| **Radar** | Global Internet traffic insights, trends, URL scans, and utilities | Traffic analysis, threat intelligence, URL scanning | +| **Container** | Spin up sandbox development environments | Isolated testing, development containers | +| **Browser Rendering** | Fetch web pages, convert to markdown, take screenshots | Web scraping, content analysis, visual testing | +| **Logpush** | Get quick summaries for Logpush job health | Log management, data pipeline monitoring | +| **AI Gateway** | Search logs, get details about prompts and responses | AI usage analytics, prompt optimization | +| **AutoRAG** | List and search documents on your AutoRAGs | Document retrieval, knowledge base management | +| **Audit Logs** | Query audit logs and generate reports for review | Security monitoring, compliance reporting | +| **DNS Analytics** | Optimize DNS performance and debug issues | DNS troubleshooting, performance optimization | +| **Digital Experience Monitoring** | Get insight on critical applications for your organization | Application performance, user experience monitoring | +| **Cloudflare One CASB** | Identify security misconfigurations for SaaS applications | Security posture, compliance checking | +| **GraphQL** | Get analytics data using Cloudflare's GraphQL API | Custom analytics, data visualization | + +## Prerequisites + +- A [Cloudflare account](https://dash.cloudflare.com/sign-up) +- [Cloudflare API Token](https://dash.cloudflare.com/profile/api-tokens) with appropriate permissions +- Node.js installed (for `npx` command) + +## Configuration + +### Step 1: Create API Token + +1. Go to [Cloudflare API Tokens](https://dash.cloudflare.com/profile/api-tokens) +2. Click **"Create Token"** +3. Choose **"Custom token"** for specific permissions or **"Global API Key"** for full access +4. Configure permissions based on which MCP servers you plan to use: + - **Zone:Read** - For DNS, analytics, and general zone information + - **Zone:Edit** - For making configuration changes + - **Account:Read** - For account-level resources + - **Workers:Read/Edit** - For Workers-related servers + - **Logs:Read** - For observability and audit logs + +### Step 2: Add MCP Server to Goose + +Choose one or more servers based on your needs. Here are the most popular configurations: + +#### Observability Server (Recommended for debugging) + + + + 1. [Launch the installer](goose://extension?cmd=npx&arg=mcp-remote&arg=https%3A%2F%2Fobservability.mcp.cloudflare.com%2Fsse&id=cloudflare-observability&name=Cloudflare%20Observability&description=Debug%20and%20get%20insight%20into%20your%20application%27s%20logs%20and%20analytics&env=CLOUDFLARE_API_TOKEN%3DCloudflare%20API%20Token) + 2. Press `Yes` to confirm the installation + 3. Enter your Cloudflare API Token + 4. Click `Save Configuration` + 5. Scroll to the top and click `Exit` from the upper left corner + + + 1. Run the `configure` command: + ```sh + goose configure + ``` + + 2. Choose to add a `Command-line Extension` + 3. Give your extension a name: `cloudflare-observability` + 4. Enter the command: `npx mcp-remote https://observability.mcp.cloudflare.com/sse` + 5. Set timeout: `300` seconds + 6. Add environment variable: + - Name: `CLOUDFLARE_API_TOKEN` + - Value: Your Cloudflare API token + + + + +#### Workers Bindings Server (For Workers development) + + + + [Launch the installer](goose://extension?cmd=npx&arg=mcp-remote&arg=https%3A%2F%2Fbindings.mcp.cloudflare.com%2Fsse&id=cloudflare-bindings&name=Cloudflare%20Workers%20Bindings&description=Build%20Workers%20applications%20with%20storage%2C%20AI%2C%20and%20compute%20primitives&env=CLOUDFLARE_API_TOKEN%3DCloudflare%20API%20Token) + + + Command: `npx mcp-remote https://bindings.mcp.cloudflare.com/sse` + + + +#### Radar Server (For traffic insights) + + + + [Launch the installer](goose://extension?cmd=npx&arg=mcp-remote&arg=https%3A%2F%2Fradar.mcp.cloudflare.com%2Fsse&id=cloudflare-radar&name=Cloudflare%20Radar&description=Get%20global%20Internet%20traffic%20insights%2C%20trends%2C%20URL%20scans%2C%20and%20other%20utilities&env=CLOUDFLARE_API_TOKEN%3DCloudflare%20API%20Token) + + + Command: `npx mcp-remote https://radar.mcp.cloudflare.com/sse` + + + +## Example Usage + +### Debugging Application Issues + +Let's use the Observability server to debug performance issues with a Workers application: + +:::info LLM +Anthropic's Claude 3.5 Sonnet was used for this task. +::: + +#### Goose Prompt +``` +I'm seeing high error rates on my Workers application "my-api-worker". Can you help me: +1. Check the recent error logs +2. Identify the most common error types +3. Analyze the performance metrics +4. Suggest optimizations based on the data +``` + +#### Expected Goose Output +``` +I'll help you debug the issues with your "my-api-worker" application. Let me gather the observability data to identify the problems. + +First, let me check the recent error logs for your worker: + +[Goose uses the observability MCP server to fetch error logs] + +I found several issues with your worker: + +**Error Analysis:** +- 45% of errors are 500 Internal Server Errors +- 30% are timeout errors (exceeded 30s CPU limit) +- 25% are memory limit exceeded errors + +**Performance Metrics:** +- Average response time: 2.3s (up 150% from last week) +- CPU usage: 85% average (critical threshold) +- Memory usage: 92% average (near limit) + +**Root Causes Identified:** +1. Database connection pooling issues causing timeouts +2. Memory leaks in JSON parsing for large payloads +3. Inefficient regex operations in request validation + +**Recommended Optimizations:** +1. Implement connection pooling with proper cleanup +2. Stream large JSON payloads instead of loading into memory +3. Replace regex with more efficient string operations +4. Add caching for frequently accessed data + +Would you like me to help implement any of these fixes? +``` + +### Managing Workers Bindings + +Using the Workers Bindings server to set up storage and AI capabilities: + +#### Goose Prompt +``` +I need to set up a new Workers project with: +1. KV storage for caching +2. R2 bucket for file uploads +3. AI binding for text analysis +4. Durable Object for real-time features + +Can you help me configure these bindings? +``` + +### Analyzing Traffic with Radar + +Using the Radar server for security and traffic analysis: + +#### Goose Prompt +``` +Can you help me analyze the security posture of my domain example.com? I want to: +1. Check for any security threats or malicious traffic +2. Analyze global traffic patterns +3. Scan for vulnerabilities +4. Get recommendations for improving security +``` + +## Common Use Cases + +### 1. Application Debugging +- **Observability Server**: Monitor errors, performance, and user experience +- **Logpush Server**: Analyze log patterns and data pipeline health +- **DNS Analytics**: Debug DNS resolution issues + +### 2. Development & Deployment +- **Workers Bindings**: Configure storage, AI, and compute resources +- **Workers Builds**: Monitor deployment status and build health +- **Container Server**: Set up isolated development environments + +### 3. Security & Compliance +- **Audit Logs**: Track configuration changes and access patterns +- **Cloudflare One CASB**: Monitor SaaS application security +- **Radar Server**: Threat intelligence and URL scanning + +### 4. Analytics & Insights +- **GraphQL Server**: Custom analytics and reporting +- **Digital Experience Monitoring**: Application performance insights +- **AI Gateway**: AI usage analytics and optimization + +### 5. Content & Web Management +- **Browser Rendering**: Web scraping and content analysis +- **AutoRAG**: Document management and retrieval +- **Documentation Server**: API reference and troubleshooting + +## Best Practices + +### Security +- Use scoped API tokens with minimal required permissions +- Regularly rotate API tokens +- Monitor API usage through audit logs +- Set up alerts for unusual activity + +### Performance +- Use appropriate timeout values for different operations +- Cache frequently accessed data when possible +- Monitor rate limits and usage quotas +- Implement proper error handling and retries + +### Development Workflow +- Start with the Documentation server for API reference +- Use Container server for isolated testing +- Monitor with Observability server during development +- Analyze with Radar server before going live + +## Troubleshooting + +### Common Issues + +**Authentication Errors:** +- Verify API token has correct permissions +- Check token hasn't expired +- Ensure token is properly set in environment variables + +**Rate Limiting:** +- Monitor API usage in Cloudflare dashboard +- Implement exponential backoff for retries +- Consider upgrading plan for higher limits + +**Connection Issues:** +- Verify network connectivity to Cloudflare APIs +- Check firewall settings +- Ensure proper DNS resolution + +### Getting Help + +If you encounter issues: + +1. Check the [Cloudflare MCP Server repository](https://github.com/cloudflare/mcp-server-cloudflare) for documentation +2. Review [Cloudflare API documentation](https://developers.cloudflare.com/api/) +3. Join our [Discord community](https://discord.gg/block-opensource) for support +4. Check [Cloudflare Community](https://community.cloudflare.com/) for platform-specific help + +## Next Steps + +With Cloudflare MCP servers enabled in Goose, you can: + +- **Monitor and debug** your applications with natural language queries +- **Manage infrastructure** through conversational commands +- **Analyze security** and performance data effortlessly +- **Automate workflows** across Cloudflare's entire platform + +Try starting with the Observability server to get insights into your current applications, then expand to other servers based on your specific needs. diff --git a/documentation/docs/mcp/cloudinary-asset-management-mcp.md b/documentation/docs/mcp/cloudinary-asset-management-mcp.md index 5a1ed6e2788e..36cf53474130 100644 --- a/documentation/docs/mcp/cloudinary-asset-management-mcp.md +++ b/documentation/docs/mcp/cloudinary-asset-management-mcp.md @@ -5,6 +5,10 @@ description: Add Cloudinary Asset Management MCP Server as a Goose Extension import Tabs from '@theme/Tabs'; import TabItem from '@theme/TabItem'; +import GooseDesktopInstaller from '@site/src/components/GooseDesktopInstaller'; +import YouTubeShortEmbed from '@site/src/components/YouTubeShortEmbed'; + + This tutorial covers how to add the [Cloudinary Asset Management MCP Server](https://github.com/cloudinary-community/cloudinary-mcp) as a Goose extension to automate complex image processing workflows that would typically require specialized design software or manual editing. @@ -35,11 +39,16 @@ Note that you'll need [Node.js](https://nodejs.org/) installed on your system to - 1. [Launch the installer](goose://extension?cmd=npx&arg=-y&arg=--package&arg=@cloudinary/asset-management&arg=--&arg=mcp&arg=start&id=cloudinary&name=Cloudinary%20Asset%20Management&description=Powerful%20media%20processing%20and%20transformation&env=CLOUDINARY_URL%3DCloudinary%20URL) - 2. Press `Yes` to confirm the installation - 3. Obtain your [CLOUDINARY_URL](https://console.cloudinary.com/settings/api-keys) from your Cloudinary dashboard and paste it in - 4. Click `Save Configuration` - 5. Scroll to the top and click `Exit` from the upper left corner + 1. Run the `configure` command: diff --git a/documentation/docs/mcp/computer-controller-mcp.md b/documentation/docs/mcp/computer-controller-mcp.md index 00b6fc126465..ce974f9ec2aa 100644 --- a/documentation/docs/mcp/computer-controller-mcp.md +++ b/documentation/docs/mcp/computer-controller-mcp.md @@ -6,6 +6,7 @@ description: Use Computer Controller MCP Server as a Goose Extension import Tabs from '@theme/Tabs'; import TabItem from '@theme/TabItem'; import YouTubeShortEmbed from '@site/src/components/YouTubeShortEmbed'; +import GooseBuiltinInstaller from '@site/src/components/GooseBuiltinInstaller'; @@ -19,13 +20,12 @@ Let Goose complete its tasks without interruption - avoid using your mouse or ke ## Configuration -1. Ensure extension is enabled: - - 1. Click `...` in the upper right corner - 2. Click `Advanced Settings` - 3. Under `Extensions`, toggle `Computer Controller` to on. + diff --git a/documentation/docs/mcp/context7-mcp.mdx b/documentation/docs/mcp/context7-mcp.mdx index 06d6a62aeb96..a6093fa90243 100644 --- a/documentation/docs/mcp/context7-mcp.mdx +++ b/documentation/docs/mcp/context7-mcp.mdx @@ -8,6 +8,7 @@ import Tabs from '@theme/Tabs'; import TabItem from '@theme/TabItem'; import YouTubeShortEmbed from '@site/src/components/YouTubeShortEmbed'; import CLIExtensionInstructions from '@site/src/components/CLIExtensionInstructions'; +import GooseDesktopInstaller from '@site/src/components/GooseDesktopInstaller'; @@ -31,27 +32,16 @@ This tutorial covers how to add the [Context7 MCP Server](https://github.com/ups ## Configuration -:::info -Note that you'll need [Node.js](https://nodejs.org/) installed on your system to run this command, as it uses `npx`. -::: - - - - 1. [Launch the installer](goose://extension?cmd=npx&arg=-y&arg=%40upstash%2Fcontext7-mcp&id=context7&name=Context7&description=Context7%20MCP%20server%20for%20up-to-date%20code%20and%20docs) - 2. Press Yes to confirm the installation - 3. Click Save Configuration - 4. Scroll to the top and click Exit from the upper left corner - - - - - - - + ## Example Usage diff --git a/documentation/docs/mcp/developer-mcp.md b/documentation/docs/mcp/developer-mcp.md index 83f2de84b2b7..0c02b5767d05 100644 --- a/documentation/docs/mcp/developer-mcp.md +++ b/documentation/docs/mcp/developer-mcp.md @@ -6,6 +6,7 @@ description: Use Developer MCP Server as a Goose Extension import Tabs from '@theme/Tabs'; import TabItem from '@theme/TabItem'; import YouTubeShortEmbed from '@site/src/components/YouTubeShortEmbed'; +import GooseBuiltinInstaller from '@site/src/components/GooseBuiltinInstaller'; @@ -20,13 +21,12 @@ This tutorial will cover enabling and using the Developer MCP Server, which is a The Developer extension is already enabled by default when Goose is installed. ::: -1. Ensure extension is enabled: - - 1. Click `...` in the upper right corner - 2. Click `Advanced Settings` - 3. Under `Extensions`, toggle `Developer` to on. + diff --git a/documentation/docs/mcp/elevenlabs-mcp.md b/documentation/docs/mcp/elevenlabs-mcp.md index 7029001f7d4e..1b23890b936d 100644 --- a/documentation/docs/mcp/elevenlabs-mcp.md +++ b/documentation/docs/mcp/elevenlabs-mcp.md @@ -6,6 +6,7 @@ description: Add ElevenLabs MCP Server as a Goose Extension import Tabs from '@theme/Tabs'; import TabItem from '@theme/TabItem'; import YouTubeShortEmbed from '@site/src/components/YouTubeShortEmbed'; +import GooseDesktopInstaller from '@site/src/components/GooseDesktopInstaller'; @@ -33,168 +34,18 @@ This tutorial covers how to add the [ElevenLabs MCP Server](https://github.com/y ## Configuration -:::info -Note that you'll need [uv](https://docs.astral.sh/uv/#installation) installed on your system to run this command, as it uses `uvx`. -::: - - - - 1. [Launch the installer](goose://extension?cmd=uvx&arg=elevenlabs-mcp&id=elevenlabs&name=ElevenLabs&description=ElevenLabs%20voice%20synthesis%20server&env=ELEVENLABS_API_KEY) - 2. Press `Yes` to confirm the installation - 3. Click `Save Configuration` - 4. Scroll to the top and click `Exit` from the upper left corner - - - 1. Run the `configure` command: - ```sh - goose configure - ``` - - 2. Choose to add a `Command-line Extension` - ```sh - ┌ goose-configure - │ - ◇ What would you like to configure? - │ Add Extension (Connect to a new extension) - │ - ◆ What type of extension would you like to add? - │ ○ Built-in Extension - // highlight-start - │ ● Command-line Extension (Run a local command or script) - // highlight-end - │ ○ Remote Extension (SSE) - │ ○ Remote Extension (Streaming HTTP) - └ - ``` - - 3. Give your extension a name - ```sh - ┌ goose-configure - │ - ◇ What would you like to configure? - │ Add Extension (Connect to a new extension) - │ - ◇ What type of extension would you like to add? - │ Command-line Extension - │ - // highlight-start - ◆ What would you like to call this extension? - │ elevenlabs - // highlight-end - └ - ``` - - 4. Enter the command - ```sh - ┌ goose-configure - │ - ◇ What would you like to configure? - │ Add Extension (Connect to a new extension) - │ - ◇ What type of extension would you like to add? - │ Command-line Extension - │ - ◇ What would you like to call this extension? - │ elevenlabs - │ - // highlight-start - ◆ What command should be run? - │ uvx elevenlabs-mcp - // highlight-end - └ - ``` - - 5. Enter the number of seconds Goose should wait for actions to complete before timing out. Default is 300s - ```sh - ┌ goose-configure - │ - ◇ What would you like to configure? - │ Add Extension (Connect to a new extension) - │ - ◇ What type of extension would you like to add? - │ Command-line Extension - │ - ◇ What would you like to call this extension? - │ elevenlabs - │ - ◇ What command should be run? - │ uvx elevenlabs-mcp - │ - // highlight-start - ◆ Please set the timeout for this tool (in secs): - │ 300 - // highlight-end - │ - └ - ``` - - 6. Choose to add a description. If you select "Yes" here, you will be prompted to enter a description for the extension. - ```sh - ┌ goose-configure - │ - ◇ What would you like to configure? - │ Add Extension (Connect to a new extension) - │ - ◇ What type of extension would you like to add? - │ Command-line Extension - │ - ◇ What would you like to call this extension? - │ elevenlabs - │ - ◇ What command should be run? - │ uvx elevenlabs-mcp - │ - ◆ Please set the timeout for this tool (in secs): - │ 300 - │ - // highlight-start - ◇ Would you like to add a description? - │ No - // highlight-end - └ - ``` - - 7. Obtain an [ElevenLabs API Key](https://elevenlabs.io/app/settings/api-keys) and paste it in. - - ```sh - ┌ goose-configure - │ - ◇ What would you like to configure? - │ Add Extension (Connect to a new extension) - │ - ◇ What type of extension would you like to add? - │ Command-line Extension - │ - ◇ What would you like to call this extension? - │ elevenlabs - │ - ◇ What command should be run? - │ uvx elevenlabs-mcp - │ - ◇ Please set the timeout for this tool (in secs): - │ 300 - │ - ◇ Would you like to add a description? - │ No - | - // highlight-start - ◆ Would you like to add environment variables? - │ Yes - │ - ◇ Environment variable name: - │ ELEVENLABS_API_KEY - │ - ◇ Environment variable value: - │ ▪▪▪▪▪▪▪▪▪▪▪▪▪▪▪▪▪▪▪▪▪▪▪▪▪▪▪▪▪▪▪▪▪▪▪▪▪▪▪ - │ - ◇ Add another environment variable? - │ No - // highlight-end - └ Added elevenlabs extension - ``` - - - + ## Example Usage diff --git a/documentation/docs/mcp/fetch-mcp.md b/documentation/docs/mcp/fetch-mcp.md index 27f78f1942e1..576faf70f333 100644 --- a/documentation/docs/mcp/fetch-mcp.md +++ b/documentation/docs/mcp/fetch-mcp.md @@ -6,6 +6,7 @@ description: Add Fetch MCP Server as a Goose Extension import Tabs from '@theme/Tabs'; import TabItem from '@theme/TabItem'; import YouTubeShortEmbed from '@site/src/components/YouTubeShortEmbed'; +import GooseDesktopInstaller from '@site/src/components/GooseDesktopInstaller'; @@ -37,8 +38,13 @@ Note that you'll need [uv](https://docs.astral.sh/uv/#installation) installed on - 1. [Launch the installer](goose://extension?cmd=uvx&arg=mcp-server-fetch&id=fetch&name=Fetch&description=Web%20content%20fetching%20and%20processing%20capabilities) - 2. Press `Yes` to confirm the installation + 1. Run the `configure` command: diff --git a/documentation/docs/mcp/figma-mcp.md b/documentation/docs/mcp/figma-mcp.md index d4dee065bccc..f9d853c501eb 100644 --- a/documentation/docs/mcp/figma-mcp.md +++ b/documentation/docs/mcp/figma-mcp.md @@ -6,6 +6,7 @@ description: Add Figma MCP Server as a Goose Extension import Tabs from '@theme/Tabs'; import TabItem from '@theme/TabItem'; import YouTubeShortEmbed from '@site/src/components/YouTubeShortEmbed'; +import GooseDesktopInstaller from '@site/src/components/GooseDesktopInstaller'; @@ -38,11 +39,18 @@ Note that you'll need [Node.js](https://nodejs.org/) installed on your system to - 1. [Launch the installer](goose://extension?cmd=npx&arg=-y&arg=%40hapins%2Ffigma-mcp&id=figma&name=Figma&description=Figma%20design%20tool%20integration&env=FIGMA_ACCESS_TOKEN%3DAccess%20token%20from%20Figma%20user%20settings) - 2. Press `Yes` to confirm the installation - 3. Obtain a [Figma Access Token](https://www.figma.com/developers/api#access-tokens) and paste it in - 4. Click `Save Configuration` - 5. Scroll to the top and click `Exit` from the upper left corner + 1. Run the `configure` command: diff --git a/documentation/docs/mcp/filesystem-mcp.md b/documentation/docs/mcp/filesystem-mcp.md index 51df62f9c4ba..b933e82871d4 100644 --- a/documentation/docs/mcp/filesystem-mcp.md +++ b/documentation/docs/mcp/filesystem-mcp.md @@ -6,6 +6,7 @@ description: Add Filesystem MCP Server as Goose Extension import Tabs from '@theme/Tabs'; import TabItem from '@theme/TabItem'; import YouTubeShortEmbed from '@site/src/components/YouTubeShortEmbed'; +import GooseDesktopInstaller from '@site/src/components/GooseDesktopInstaller'; @@ -28,15 +29,18 @@ Note that you'll need [Node.js](https://nodejs.org/) installed on your system to - 1. Click `...` in the upper right corner - 2. Click `Advanced Settings` - 3. Under `Extensions`, click the `Add Custom Extensions` button - 4. On the `Add Extension Manually` modal, enter the following: - * **Type**: `Standard IO` - * **Extension Name**: `filesystem` (_set this to whatever you want_) - * **Description**: `filesystem MCP Server` (_(optional) set this to whatever you want_) - * **Command**: `npx -y @modelcontextprotocol/server-filesystem
` (_Add additional allowed directories if desired, separated by a space._) - 5. Click `Add Extension` button + diff --git a/documentation/docs/mcp/github-mcp.md b/documentation/docs/mcp/github-mcp.md index 497ebe15ea99..0ad4d2432722 100644 --- a/documentation/docs/mcp/github-mcp.md +++ b/documentation/docs/mcp/github-mcp.md @@ -6,64 +6,65 @@ description: Add GitHub MCP Server as a Goose Extension import Tabs from '@theme/Tabs'; import TabItem from '@theme/TabItem'; import YouTubeShortEmbed from '@site/src/components/YouTubeShortEmbed'; -import CLIExtensionInstructions from '@site/src/components/CLIExtensionInstructions'; +import CLIStreamExtensionInstructions from '@site/src/components/CLIStreamExtensionInstructions'; +import { PanelLeft } from 'lucide-react'; -This tutorial covers how to add the [GitHub MCP Server](https://github.com/modelcontextprotocol/servers/tree/main/src/github) as a Goose extension to enable file operations, repository management, search functionality, and more. +This tutorial covers how to add the [GitHub MCP Server](https://github.com/github/github-mcp-server) as a Goose extension to enable file operations, repository management, search functionality, and more. :::tip TLDR - [Launch the installer](goose://extension?cmd=npx&arg=-y&arg=%40modelcontextprotocol%2Fserver-github&id=github&name=GitHub&description=GitHub%20API&env=GITHUB_PERSONAL_ACCESS_TOKEN%3DGitHub%20Personal%20Access%20Token) + Use `Add custom extension` in Settings → Extensions to add a `Streamable HTTP` extension type with: - **Command** - ```sh - npx -y @modelcontextprotocol/server-github - ``` + Use `goose configure` to add a `Remote Extension (Streaming HTTP)` extension type with: - **Environment Variable** + **Endpoint URL** + ``` + https://api.githubcopilot.com/mcp/ ``` - GITHUB_PERSONAL_ACCESS_TOKEN: + **Custom Request Header** + ``` + Authorization: Bearer ``` ::: ## Configuration - -:::info -Note that you'll need [Node.js](https://nodejs.org/) installed on your system to run this command, as it uses `npx`. -::: +These steps configure the Remote MCP Server. For other deployment options, see the [official GitHub MCP Server documentation](https://github.com/github/github-mcp-server). - 1. [Launch the installer](goose://extension?cmd=npx&arg=-y&arg=%40modelcontextprotocol%2Fserver-github&id=github&name=GitHub&description=GitHub%20API&env=GITHUB_PERSONAL_ACCESS_TOKEN%3DGitHub%20Personal%20Access%20Token) - 2. Press `Yes` to confirm the installation - 3. Obtain a [GitHub Personal Access Token](https://github.com/settings/personal-access-tokens) and paste it in - 4. Click `Save Configuration` - 5. Scroll to the top and click `Exit` from the upper left corner + 1. Obtain a [GitHub Personal Access Token](https://github.com/settings/personal-access-tokens) + 2. Click the button in the top-left to open the sidebar + 3. Click `Extensions` + 4. Click `Add custom extension` + 5. On the `Add custom extension` modal, enter the following: + - **Extension Name**: GitHub + - **Type**: Streamable HTTP + - **Endpoint**: `https://api.githubcopilot.com/mcp/` + - **Request Headers**: + - **Header name**: `Authorization` + - **Value**: `Bearer ` + 6. Click `+ Add` to save the header + 7. Click `Add Extension` to save the extension + 8. Navigate to the chat + - - - When creating your access token, you can specify the repositories and granular permissions you'd like Goose to have access to.{" "} - - Create one here - . + Obtain your GitHub Personal Access Token and paste it in } /> diff --git a/documentation/docs/mcp/google-drive-mcp.md b/documentation/docs/mcp/google-drive-mcp.md index 3406bf5f621f..bd1c7dc16668 100644 --- a/documentation/docs/mcp/google-drive-mcp.md +++ b/documentation/docs/mcp/google-drive-mcp.md @@ -6,6 +6,7 @@ description: Add Google Drive MCP Server as a Goose Extension import Tabs from '@theme/Tabs'; import TabItem from '@theme/TabItem'; import YouTubeShortEmbed from '@site/src/components/YouTubeShortEmbed'; +import GooseDesktopInstaller from '@site/src/components/GooseDesktopInstaller'; @@ -92,24 +93,24 @@ You'll need to re-authenticate once a day when using the Google Drive extension. - 1. [Launch the installer](goose://extension?cmd=npx&arg=-y&arg=%40modelcontextprotocol%2Fserver-gdrive&id=google-drive&name=Google%20Drive&description=Google%20Drive%20integration&env=GDRIVE_CREDENTIALS_PATH%3DPath%20to%20Google%20Drive%20credentials&env=GDRIVE_OAUTH_PATH%3DPath%20to%20OAuth%20token) - 2. Press `Yes` to confirm the installation - 3. For `GDRIVE_CREDENTIALS_PATH`, enter the following: - ```sh - $USER_HOME/.config/.gdrive-server-credentials.json - ``` - :::info - Replace `$USER_HOME` with your home directory. You must specify an absolute path for this extension to work. - ::: - 4. For `GDRIVE_OAUTH_PATH`, enter the following: - ```sh - $USER_HOME/.config/gcp-oauth.keys.json - ``` - :::info - Replace `$USER_HOME` with your home directory. You must specify an absolute path for this extension to work. - ::: - 5. Click `Save Configuration` - 6. Scroll to the top and click `Exit` from the upper left corner + + + :::info + - For `GDRIVE_CREDENTIALS_PATH`, enter `$USER_HOME/.config/.gdrive-server-credentials.json` + - For `GDRIVE_OAUTH_PATH`, enter `$USER_HOME/.config/gcp-oauth.keys.json` + + Replace `$USER_HOME` with your home directory. You must specify an absolute path for this extension to work. + ::: 1. Run the `configure` command: diff --git a/documentation/docs/mcp/google-maps-mcp.md b/documentation/docs/mcp/google-maps-mcp.md index 613a604caefe..4e253a1d46dc 100644 --- a/documentation/docs/mcp/google-maps-mcp.md +++ b/documentation/docs/mcp/google-maps-mcp.md @@ -6,6 +6,7 @@ description: Add Google Maps MCP Server as a Goose Extension import Tabs from '@theme/Tabs'; import TabItem from '@theme/TabItem'; import YouTubeShortEmbed from '@site/src/components/YouTubeShortEmbed'; +import GooseDesktopInstaller from '@site/src/components/GooseDesktopInstaller'; This tutorial covers how to add the [Google Maps MCP Server](https://github.com/modelcontextprotocol/servers/tree/main/src/google-maps) as a Goose extension to enable geocoding, place searching, distance calculations, elevation data retrieval, and directions. @@ -36,11 +37,18 @@ Note that you'll need [Node.js](https://nodejs.org/) installed on your system to - 1. [Launch the installer](goose://extension?cmd=npx&arg=-y&arg=%40modelcontextprotocol%2Fserver-google-maps&id=google-maps&name=Google%20Maps&description=Google%20Maps%20API%20integration&env=GOOGLE_MAPS_API_KEY%3DGoogle%20Maps%20API%20key) - 2. Press `Yes` to confirm the installation - 3. Obtain a [GOOGLE_MAPS_API_KEY](https://developers.google.com/maps/documentation/javascript/get-api-key) and paste it in - 4. Click `Save Configuration` - 5. Scroll to the top and click `Exit` from the upper left corner + 1. Run the `configure` command: diff --git a/documentation/docs/mcp/jetbrains-mcp.md b/documentation/docs/mcp/jetbrains-mcp.md index 031702676f49..7b6bafaf1b28 100644 --- a/documentation/docs/mcp/jetbrains-mcp.md +++ b/documentation/docs/mcp/jetbrains-mcp.md @@ -6,6 +6,7 @@ description: Use JetBrains MCP Server as a Goose Extension import Tabs from '@theme/Tabs'; import TabItem from '@theme/TabItem'; import YouTubeShortEmbed from '@site/src/components/YouTubeShortEmbed'; +import GooseBuiltinInstaller from '@site/src/components/GooseBuiltinInstaller'; @@ -21,9 +22,10 @@ This tutorial covers how to enable and use the JetBrains MCP Server as a built-i - 1. Click `...` in the upper right corner - 2. Click `Advanced Settings` - 3. Under `Extensions`, toggle `Jetbrains` to on. + diff --git a/documentation/docs/mcp/knowledge-graph-mcp.md b/documentation/docs/mcp/knowledge-graph-mcp.md index 7b2a14e490ce..c36c76d07a22 100644 --- a/documentation/docs/mcp/knowledge-graph-mcp.md +++ b/documentation/docs/mcp/knowledge-graph-mcp.md @@ -6,6 +6,7 @@ description: Add Knowledge Graph Memory MCP Server as a Goose Extension import Tabs from '@theme/Tabs'; import TabItem from '@theme/TabItem'; import YouTubeShortEmbed from '@site/src/components/YouTubeShortEmbed'; +import GooseDesktopInstaller from '@site/src/components/GooseDesktopInstaller'; @@ -34,9 +35,13 @@ Note that you'll need [Node.js](https://nodejs.org/) installed on your system to - 1. [Launch the installer](goose://extension?cmd=npx&arg=-y&arg=%40modelcontextprotocol%2Fserver-memory&id=knowledge_graph_memory&name=Knowledge%20Graph%20Memory&description=Graph-based%20memory%20system%20for%20persistent%20knowledge%20storage) - 2. Press `Yes` to confirm the installation - 3. Scroll to the top and click `Exit` from the upper left corner + 1. Run the `configure` command: diff --git a/documentation/docs/mcp/mbot-mcp.md b/documentation/docs/mcp/mbot-mcp.md index 8e23add3cb6b..4a0b07613d62 100644 --- a/documentation/docs/mcp/mbot-mcp.md +++ b/documentation/docs/mcp/mbot-mcp.md @@ -6,6 +6,7 @@ description: Control a MakeBlock mbot2 rover through MQTT and MCP as a Goose Ext import Tabs from '@theme/Tabs'; import TabItem from '@theme/TabItem'; import YouTubeShortEmbed from '@site/src/components/YouTubeShortEmbed'; +import GooseDesktopInstaller from '@site/src/components/GooseDesktopInstaller'; @@ -35,11 +36,18 @@ This tutorial will get you started with [deemkeen's MQTT MCP server](https://git - 1. [Launch the installer](goose://extension?cmd=/path/to/java&arg=-jar&arg=/path/to/mbotmcp-0.0.1-SNAPSHOT.jar&name=mbot2&description=mbot2&env=MQTT_SERVER_URI%3Dtcp://1.2.3.4:1883&env=MQTT_USERNAME%3Dyour_username&env=MQTT_PASSWORD%3Dyour_password) - 2. Press `Yes` to confirm the installation - 3. Get your MQTT server URI, and your username/password if your MQTT uses authentication - 4. Click `Save Configuration` - 5. Scroll to the top and click `Exit` from the upper left corner + 1. Run the `configure` command: @@ -205,13 +213,13 @@ This tutorial will get you started with [deemkeen's MQTT MCP server](https://git │ └ Added mbot2 extension ``` - :::info - MQTT_USERNAME and MQTT_PASSWORD are required to exist, but can be empty strings if your MQTT server does not require authentication. - ::: + :::info + MQTT_USERNAME and MQTT_PASSWORD are required to exist, but can be empty strings if your MQTT server does not require authentication. + ::: ## Example Usage diff --git a/documentation/docs/mcp/memory-mcp.md b/documentation/docs/mcp/memory-mcp.md index 5f772aea3703..e2878998884f 100644 --- a/documentation/docs/mcp/memory-mcp.md +++ b/documentation/docs/mcp/memory-mcp.md @@ -6,6 +6,7 @@ description: Use Memory MCP Server as a Goose Extension import Tabs from '@theme/Tabs'; import TabItem from '@theme/TabItem'; import YouTubeShortEmbed from '@site/src/components/YouTubeShortEmbed'; +import GooseBuiltinInstaller from '@site/src/components/GooseBuiltinInstaller'; @@ -15,14 +16,12 @@ This tutorial covers enabling and using the Memory MCP Server, which is a built- ## Configuration -1. Ensure extension is enabled: - - 1. Click `...` in the upper right corner - 2. Click `Advanced Settings` - 3. Under `Extensions`, toggle `Memory` to on. - 4. Scroll to the top and click `Exit` from the upper left corner + diff --git a/documentation/docs/mcp/netlify-mcp.md b/documentation/docs/mcp/netlify-mcp.md index 91495ef9f2e1..286215f262f9 100644 --- a/documentation/docs/mcp/netlify-mcp.md +++ b/documentation/docs/mcp/netlify-mcp.md @@ -6,6 +6,7 @@ description: Add Netlify MCP Server as a Goose Extension import Tabs from '@theme/Tabs'; import TabItem from '@theme/TabItem'; import YouTubeShortEmbed from '@site/src/components/YouTubeShortEmbed'; +import GooseDesktopInstaller from '@site/src/components/GooseDesktopInstaller'; @@ -41,8 +42,13 @@ netlify login - 1. [Launch the installer](goose://extension?cmd=npx&arg=-y&arg=%40netlify%2Fmcp&id=netlify&name=Netlify&description=Build%2C%20deploy%2C%20and%20manage%20sites%20with%20Netlify%27s%20official%20MCP%20server.) - 2. Press `OK` to confirm the installation + 1. Run the `configure` command: diff --git a/documentation/docs/mcp/nostrbook-mcp.md b/documentation/docs/mcp/nostrbook-mcp.md index 7df8a41d3f93..18ed81dab21b 100644 --- a/documentation/docs/mcp/nostrbook-mcp.md +++ b/documentation/docs/mcp/nostrbook-mcp.md @@ -5,6 +5,7 @@ description: Add Nostrbook MCP Server as a Goose Extension import Tabs from '@theme/Tabs'; import TabItem from '@theme/TabItem'; +import GooseDesktopInstaller from '@site/src/components/GooseDesktopInstaller'; @@ -29,10 +30,13 @@ Note that you'll need [Node.js](https://nodejs.org/) installed on your system to - 1. [Launch the installer](goose://extension?cmd=npx&arg=-y&arg=%40nostrbook%2Fmcp&id=nostrbook&name=Nostrbook&description=A%20comprehensive%20registry%20of%20Nostr%20documentation%20with%20structured%20knowledge) - 2. Press `Yes` to confirm the installation - 3. Click `Save Configuration` - 4. Scroll to the top and click `Exit` from the upper left corner + 1. Run the `configure` command: diff --git a/documentation/docs/mcp/pdf-mcp.md b/documentation/docs/mcp/pdf-mcp.md index 5ee9a41f6b18..4913c4d9f7f6 100644 --- a/documentation/docs/mcp/pdf-mcp.md +++ b/documentation/docs/mcp/pdf-mcp.md @@ -6,6 +6,7 @@ description: Add PDF Reader MCP Server as a Goose Extension import Tabs from '@theme/Tabs'; import TabItem from '@theme/TabItem'; import YouTubeShortEmbed from '@site/src/components/YouTubeShortEmbed'; +import GooseDesktopInstaller from '@site/src/components/GooseDesktopInstaller'; @@ -33,11 +34,14 @@ Note that you'll need [uv](https://docs.astral.sh/uv/#installation) installed on - 1. [Launch the installer](goose://extension?cmd=uvx&arg=mcp-read-pdf&id=pdf_read&name=PDF%20Reader&description=Read%20large%20and%20complex%20PDF%20documents) - 2. Press `Yes` to confirm the installation - 3. Click `Save Configuration` - 4. Scroll to the top and click `Exit` from the upper left corner - + + 1. Run the `configure` command: ```sh diff --git a/documentation/docs/mcp/pieces-mcp.md b/documentation/docs/mcp/pieces-mcp.md index 32f4d8365029..bd4090ce37cd 100644 --- a/documentation/docs/mcp/pieces-mcp.md +++ b/documentation/docs/mcp/pieces-mcp.md @@ -5,8 +5,8 @@ description: Add Pieces for Developers MCP Server as a Goose Extension import Tabs from '@theme/Tabs'; import TabItem from '@theme/TabItem'; - import YouTubeShortEmbed from '@site/src/components/YouTubeShortEmbed'; +import GooseDesktopInstaller from '@site/src/components/GooseDesktopInstaller'; @@ -34,8 +34,12 @@ http://localhost:39300/model_context_protocol/2024-11-05/sse - 1. [Launch the installer](goose://extension?url=http%3A%2F%2Flocalhost%3A39300%2Fmodel_context_protocol%2F2024-11-05%2Fsse&id=pieces&name=Pieces%20for%20Developers&description=Provides%20access%20to%20your%20Pieces%20Long-Term%20Memory.%20You%20need%20to%20have%20Pieces%20installed%20to%20use%20this.) - 2. Press `Yes` to confirm the installation + 1. Run the `configure` command: diff --git a/documentation/docs/mcp/playwright-mcp.md b/documentation/docs/mcp/playwright-mcp.md index 73dd84c94d1e..f30c18610661 100644 --- a/documentation/docs/mcp/playwright-mcp.md +++ b/documentation/docs/mcp/playwright-mcp.md @@ -6,6 +6,7 @@ description: Add Playwright MCP Server as a Goose Extension for Modern Web Testi import Tabs from '@theme/Tabs'; import TabItem from '@theme/TabItem'; import YouTubeShortEmbed from '@site/src/components/YouTubeShortEmbed'; +import GooseDesktopInstaller from '@site/src/components/GooseDesktopInstaller'; @@ -33,9 +34,13 @@ Note that you'll need [Node.js](https://nodejs.org/) installed on your system to - 1. [Launch the installer](goose://extension?cmd=npx&arg=-y&arg=@playwright/mcp@latest&id=playwright&name=Playwright&description=Modern%20web%20testing%20and%20automation) - 2. Press `Yes` to confirm the installation - 3. Scroll to the top and click `Exit` from the upper left corner + 1. Run the `configure` command: diff --git a/documentation/docs/mcp/postgres-mcp.md b/documentation/docs/mcp/postgres-mcp.md index d7e15a023c82..39a3a79aab49 100644 --- a/documentation/docs/mcp/postgres-mcp.md +++ b/documentation/docs/mcp/postgres-mcp.md @@ -6,6 +6,7 @@ description: Add PostgreSQL MCP Server as a Goose Extension import Tabs from '@theme/Tabs'; import TabItem from '@theme/TabItem'; import YouTubeShortEmbed from '@site/src/components/YouTubeShortEmbed'; +import GooseDesktopInstaller from '@site/src/components/GooseDesktopInstaller'; @@ -60,11 +61,18 @@ Note that you'll need [Node.js](https://nodejs.org/) installed on your system to - 1. [Launch the installer](goose://extension?cmd=npx&arg=-y&arg=@modelcontextprotocol/server-postgres&arg=Your%20PostgreSQL%20connection%20URL&id=postgres&name=PostgreSQL&description=PostgreSQL%20database%20integration) - 2. Press `Yes` to confirm the installation - 3. Enter your PostgreSQL connection URL in the format: `postgresql://username:password@hostname:5432/database` - 4. Click `Save Configuration` - 5. Scroll to the top and click `Exit` from the upper left corner + + + :::info + Enter your PostgreSQL connection URL in the format: `postgresql://username:password@hostname:5432/database` + ::: + 1. Run the `configure` command: diff --git a/documentation/docs/mcp/puppeteer-mcp.md b/documentation/docs/mcp/puppeteer-mcp.md index 46deb4b22ffa..97f4c6f17787 100644 --- a/documentation/docs/mcp/puppeteer-mcp.md +++ b/documentation/docs/mcp/puppeteer-mcp.md @@ -6,6 +6,7 @@ description: Add Puppeteer MCP Server as a Goose Extension import Tabs from '@theme/Tabs'; import TabItem from '@theme/TabItem'; import YouTubeShortEmbed from '@site/src/components/YouTubeShortEmbed'; +import GooseDesktopInstaller from '@site/src/components/GooseDesktopInstaller'; @@ -34,10 +35,14 @@ Note that you'll need [Node.js](https://nodejs.org/) installed on your system to - 1. [Launch the installer](goose://extension?cmd=npx&arg=-y&arg=%40modelcontextprotocol%2Fserver-puppeteer&id=puppeteer&name=Puppeteer&description=Headless%20browser%20automation) - 2. Press `Yes` to confirm the installation - 3. Scroll to the top and click `Exit` from the upper left corner - + + 1. Run the `configure` command: ```sh diff --git a/documentation/docs/mcp/reddit-mcp.md b/documentation/docs/mcp/reddit-mcp.md index b10a7cf89ad2..0ac6e9472a29 100644 --- a/documentation/docs/mcp/reddit-mcp.md +++ b/documentation/docs/mcp/reddit-mcp.md @@ -8,6 +8,7 @@ import Tabs from '@theme/Tabs'; import TabItem from '@theme/TabItem'; import YouTubeShortEmbed from '@site/src/components/YouTubeShortEmbed'; import CLIExtensionInstructions from '@site/src/components/CLIExtensionInstructions'; +import GooseDesktopInstaller from '@site/src/components/GooseDesktopInstaller'; @@ -37,10 +38,13 @@ Note that you'll need [uv](https://docs.astral.sh/uv/#installation) installed on - 1. [Launch the installer](goose://extension?cmd=uvx&arg=--from&arg=git%2Bhttps%3A%2F%2Fgithub.com%2Fadhikasp%2Fmcp-reddit.git&arg=mcp-reddit&id=reddit&name=Reddit&description=Fetch%20and%20analyze%20Reddit%20content) - 2. Press Yes to confirm the installation - 3. Click Save Configuration - 4. Scroll to the top and click Exit from the upper left corner + diff --git a/documentation/docs/mcp/repomix-mcp.md b/documentation/docs/mcp/repomix-mcp.md index b02f2027e461..2e34a8c706f2 100644 --- a/documentation/docs/mcp/repomix-mcp.md +++ b/documentation/docs/mcp/repomix-mcp.md @@ -6,6 +6,7 @@ description: Add Repomix MCP Server as a Goose Extension import Tabs from '@theme/Tabs'; import TabItem from '@theme/TabItem'; import YouTubeShortEmbed from '@site/src/components/YouTubeShortEmbed'; +import GooseDesktopInstaller from '@site/src/components/GooseDesktopInstaller'; @@ -34,11 +35,14 @@ Note that you'll need [Node.js](https://nodejs.org/) installed on your system to - 1. [Launch the installer](goose://extension?cmd=npx&arg=-y&arg=repomix&arg=--mcp&id=repomix&name=Repomix&description=Pack%20repositories%20into%20AI-friendly%20formats%20for%20Goose) - 2. Press `Yes` to confirm the installation - 3. Click `Save Configuration` - 4. Scroll to the top and click `Exit` from the upper left corner - + + 1. Run the `configure` command: ```sh diff --git a/documentation/docs/mcp/selenium-mcp.md b/documentation/docs/mcp/selenium-mcp.md index 11e96401daba..56538b5eea6a 100644 --- a/documentation/docs/mcp/selenium-mcp.md +++ b/documentation/docs/mcp/selenium-mcp.md @@ -6,6 +6,7 @@ description: Add Selenium MCP Server as a Goose Extension import Tabs from '@theme/Tabs'; import TabItem from '@theme/TabItem'; import YouTubeShortEmbed from '@site/src/components/YouTubeShortEmbed'; +import GooseDesktopInstaller from '@site/src/components/GooseDesktopInstaller'; @@ -35,11 +36,14 @@ Note that you'll need [Node.js](https://nodejs.org/) installed on your system to - 1. [Launch the installer](goose://extension?cmd=npx&arg=-y&arg=%40angiejones%2Fmcp-selenium&id=selenium-mcp&name=Selenium%20MCP&description=automates%20browser%20interactions) - 2. Press `Yes` to confirm the installation - 3. Click `Save Configuration` - 5. Scroll to the top and click `Exit` from the upper left corner - + + 1. Run the `configure` command: ```sh diff --git a/documentation/docs/mcp/speech-mcp.md b/documentation/docs/mcp/speech-mcp.md index 0a5a07a81b5e..038ad7e5964a 100644 --- a/documentation/docs/mcp/speech-mcp.md +++ b/documentation/docs/mcp/speech-mcp.md @@ -6,6 +6,7 @@ description: Add Speech MCP Server as a Goose Extension import Tabs from '@theme/Tabs'; import TabItem from '@theme/TabItem'; import YouTubeShortEmbed from '@site/src/components/YouTubeShortEmbed'; +import GooseDesktopInstaller from '@site/src/components/GooseDesktopInstaller'; @@ -19,7 +20,7 @@ This tutorial covers how to add the [Speech MCP Server](https://github.com/Kvadr :::tip TLDR - [Launch the installer](goose://extension?cmd=uvx&&arg=-p&arg=3.10.14&arg=speech-mcp@latest&id=speech_mcp&name=Speech%20Interface&description=Voice%20interaction%20with%20audio%20visualization%20for%20Goose) + [Launch the installer](goose://extension?cmd=uvx&arg=-p&arg=3.10.14&arg=speech-mcp@latest&id=speech_mcp&name=Speech%20Interface&description=Voice%20interaction%20with%20audio%20visualization%20for%20Goose) **Command** @@ -40,10 +41,13 @@ Before adding this extension, make sure [PortAudio](https://github.com/GoogleClo - 1. [Launch the installer](goose://extension?cmd=uvx&&arg=-p&arg=3.10.14&arg=speech-mcp@latest&id=speech_mcp&name=Speech%20Interface&description=Voice%20interaction%20with%20audio%20visualization%20for%20Goose) - 2. Press `Yes` to confirm the installation - 3. Click `Save Configuration` - 4. Scroll to the top and click `Exit` from the upper left corner + 1. Run the `configure` command: diff --git a/documentation/docs/mcp/square-mcp.md b/documentation/docs/mcp/square-mcp.md index f3fa08fd12fd..7ea63649c31c 100644 --- a/documentation/docs/mcp/square-mcp.md +++ b/documentation/docs/mcp/square-mcp.md @@ -6,6 +6,7 @@ description: Add the Square API as a Goose Extension import Tabs from '@theme/Tabs'; import TabItem from '@theme/TabItem'; import YouTubeShortEmbed from '@site/src/components/YouTubeShortEmbed'; +import GooseDesktopInstaller from '@site/src/components/GooseDesktopInstaller';
🎥 Square MCP Server Video Walkthrough @@ -53,9 +54,11 @@ Note that you'll need [Node.js](https://nodejs.org/) installed on your system to 1. [Launch the installer](https://mcp.squareup.com/goose) - 2. Goose will open and ask you to confirm installation. - 3. Goose should open a browser tab to an OAuth permissions page. Double-check which permissions you want to allow, and click 'Grant Access'. + 2. Click `OK` to confirm the installation + 3. Goose should open a browser tab to an OAuth permissions page. Double-check which permissions you want to allow, and click `Grant Access`. 4. It will ask you to login or reauthenticate to Square, and may ask you to confirm the permissions you want to allow. + 5. In Goose, navigate to the chat + 1. Run the `configure` command: @@ -207,7 +210,8 @@ Note that you'll need [Node.js](https://nodejs.org/) installed on your system to :::tip TLDR - [Launch the installer](goose://extension?cmd=npx&arg=square-mcp-server%20start&id=square-mcp&name=Square%20MCP%20Server&description=Square%20MCP%20Server&env=ACCESS_TOKEN%3DYour%20Access%20Token&env=SANDBOX%3Dtrue) + [Launch the installer](goose://extension?cmd=npx&arg=square-mcp-server&arg=start&id=mcp_square_api&name=Square%20MCP%20Server&description=Square%20API%20MCP%20Server&env=ACCESS_TOKEN%3DYour%20Access%20Token&env=SANDBOX%3Dtrue) + **Command** @@ -231,13 +235,21 @@ Note that you'll need [Node.js](https://nodejs.org/) installed on your system to - 1. [Launch the installer](goose://extension?cmd=npx&arg=square-mcp-server%20start&id=square-mcp&name=Square%20MCP%20Server&description=Square%20MCP%20Server&env=ACCESS_TOKEN%3DYour%20Access%20Token&env=SANDBOX%3Dtrue) - 2. Press `Yes` to confirm the installation - 3. Get your [Square Access Token](https://developer.squareup.com/apps) and paste it in - 4. Keep `SANDBOX` as the environment variable, or change to `PRODUCTION`, and set its value to `true` - 5. Click `Save Configuration` - 6. Scroll to the top and click `Exit` from the upper left corner - + + 1. Run the `configure` command: diff --git a/documentation/docs/mcp/tavily-mcp.md b/documentation/docs/mcp/tavily-mcp.md index 518237bb2cd1..99b1f1cb8bce 100644 --- a/documentation/docs/mcp/tavily-mcp.md +++ b/documentation/docs/mcp/tavily-mcp.md @@ -6,6 +6,7 @@ description: Add Tavily MCP Server as a Goose Extension import Tabs from '@theme/Tabs'; import TabItem from '@theme/TabItem'; import YouTubeShortEmbed from '@site/src/components/YouTubeShortEmbed'; +import GooseDesktopInstaller from '@site/src/components/GooseDesktopInstaller'; @@ -37,12 +38,19 @@ Note that you'll need [uv](https://docs.astral.sh/uv/#installation) installed on - 1. [Launch the installer](goose://extension?cmd=npx&arg=-y&arg=tavily-mcp&id=tavily&name=Tavily%20Web%20Search&description=Search%20the%20web%20with%20Tavily%20MCP&env=TAVILY_API_KEY%3DTavily%20API%20Key) - 2. Press `Yes` to confirm the installation - 3. Obtain a [TAVILY_API_KEY](https://tavily.com/) and paste it in - 4. Click `Save Configuration` - 5. Scroll to the top and click `Exit` from the upper left corner - + + 1. Run the `configure` command: ```sh diff --git a/documentation/docs/mcp/tutorial-mcp.md b/documentation/docs/mcp/tutorial-mcp.md index ad4e1d4ba00d..560135c31165 100644 --- a/documentation/docs/mcp/tutorial-mcp.md +++ b/documentation/docs/mcp/tutorial-mcp.md @@ -5,6 +5,7 @@ description: Learn how to use Goose's built-in Tutorial extension for guided lea import Tabs from '@theme/Tabs'; import TabItem from '@theme/TabItem'; +import GooseBuiltinInstaller from '@site/src/components/GooseBuiltinInstaller'; The Tutorial extension is a built-in feature of Goose that provides interactive, step-by-step guidance for learning various aspects of Goose and its capabilities. It's designed to help users get comfortable with Goose's features through hands-on practice. @@ -19,9 +20,9 @@ The Tutorial extension serves as an interactive learning tool that: - 1. Click `...` in the upper right corner - 2. Click `Advanced Settings` - 3. Under `Extensions`, toggle `Tutorial` to on. + diff --git a/documentation/docs/mcp/youtube-transcript-mcp.md b/documentation/docs/mcp/youtube-transcript-mcp.md index 7a591792fa1f..77c5c52b15dd 100644 --- a/documentation/docs/mcp/youtube-transcript-mcp.md +++ b/documentation/docs/mcp/youtube-transcript-mcp.md @@ -6,6 +6,7 @@ description: Add YouTube Transcript MCP Server as a Goose Extension for accessin import Tabs from '@theme/Tabs'; import TabItem from '@theme/TabItem'; import YouTubeShortEmbed from '@site/src/components/YouTubeShortEmbed'; +import GooseDesktopInstaller from '@site/src/components/GooseDesktopInstaller'; @@ -34,10 +35,13 @@ Note that you'll need [uv](https://docs.astral.sh/uv/#installation) installed on - 1. [Launch the installer](goose://extension?cmd=uvx&arg=--from&arg=git%2Bhttps%3A%2F%2Fgithub.com%2Fjkawamoto%2Fmcp-youtube-transcript&arg=mcp-youtube-transcript&id=youtube-transcript&name=YouTube%20Transcript&description=Access%20YouTube%20video%20transcripts) - 2. Press `Yes` to confirm the installation - 3. Click `Save Configuration` - 4. Scroll to the top and click `Exit` from the upper left corner + 1. Run the `configure` command: diff --git a/documentation/docs/quickstart.md b/documentation/docs/quickstart.md index 24ab20f88b5a..42306856f397 100644 --- a/documentation/docs/quickstart.md +++ b/documentation/docs/quickstart.md @@ -11,6 +11,7 @@ import YouTubeShortEmbed from '@site/src/components/YouTubeShortEmbed'; import MacDesktopInstallButtons from '@site/src/components/MacDesktopInstallButtons'; import WindowsDesktopInstallButtons from '@site/src/components/WindowsDesktopInstallButtons'; import LinuxDesktopInstallButtons from '@site/src/components/LinuxDesktopInstallButtons'; +import { PanelLeft } from 'lucide-react'; # Goose in 5 minutes @@ -183,10 +184,10 @@ While you're able to manually navigate to your working directory and open the HT - 1. Locate the menu (`...`) in the top right corner of the Goose Desktop. - 2. Select `Advanced settings` from the menu. - 3. Under the `Extensions` section, toggle the `Computer Controller` extension to enable it. This [extension](https://block.github.io/goose/v1/extensions/detail/nondeveloper) enables webscraping, file caching, and automations. - 4. Scroll back to the top and click `<- Back` in the upper left corner to return to your session. + 1. Click the button in the top-left to open the sidebar. + 2. Click `Extensions` in the sidebar menu. + 3. Toggle the `Computer Controller` extension to enable it. This [extension](https://block.github.io/goose/v1/extensions/detail/nondeveloper) enables webscraping, file caching, and automations. + 4. Return to your session to continue. 5. Now that Goose has browser capabilities, let's ask it to launch your game in a browser: diff --git a/documentation/docs/troubleshooting.md b/documentation/docs/troubleshooting.md index cf99e9b4265a..8e6efcb6475f 100644 --- a/documentation/docs/troubleshooting.md +++ b/documentation/docs/troubleshooting.md @@ -172,7 +172,7 @@ When the keyring is disabled, secrets are stored here: Many of the external extensions require a package runner. For example, if you run into an error like this one: ``` -Failed to start extension: {extension name}, "No such file or directory (os error 2)" +Failed to start extension `{extension name}`: Could not run extension command (`{extension command}`): No such file or directory (os error 2) Please check extension configuration for {extension name}. ``` diff --git a/documentation/docs/tutorials/_template_.mdx b/documentation/docs/tutorials/_template_.mdx index d94f749aff7f..ec2f6e62b2dd 100644 --- a/documentation/docs/tutorials/_template_.mdx +++ b/documentation/docs/tutorials/_template_.mdx @@ -8,6 +8,8 @@ import Tabs from '@theme/Tabs'; import TabItem from '@theme/TabItem'; import YouTubeShortEmbed from '@site/src/components/YouTubeShortEmbed'; import CLIExtensionInstructions from '@site/src/components/CLIExtensionInstructions'; +import GooseDesktopInstaller from '@site/src/components/GooseDesktopInstaller'; +import GooseBuiltinInstaller from '@site/src/components/GooseBuiltinInstaller'; @@ -35,87 +37,47 @@ This tutorial covers how to add the [{name} MCP Server](/) as a Goose extension ## Configuration -:::info -Note that you'll need [Node.js](https://nodejs.org/) installed on your system to run this command, as it uses `npx`. -::: - -:::info -Note that you'll need [uv](https://docs.astral.sh/uv/#installation) installed on your system to run this command, as it uses `uvx`. -::: - -:::info -Note that you'll need [JBang](https://www.jbang.dev/download) installed on your system to run this command, as it uses `jbang`. -::: - - - - 1. [Launch the installer]({goose_url}) - 2. Press Yes to confirm the installation - 3. Obtain a [XYZ Access Token](/) and paste it in - 4. Click Save Configuration - 5. Scroll to the top and click Exit from the upper left corner - - - - - - - - - - When creating your access token, you can specify the repositories and granular permissions you'd like Goose to have access to.{" "} - - Create one here - . - - } - /> - - - - - - When creating your access token, you can specify the repositories and granular permissions you'd like Goose to have access to.{" "} - - Create one here - . - - } - /> - - - - - + + + + + + + + + Get your API key from{" "} + + example.com + . + + } + note="Note that you'll need Node.js installed on your system to run this command, as it uses npx." +/> ## Example Usage diff --git a/documentation/docs/tutorials/lead-worker.md b/documentation/docs/tutorials/lead-worker.md index 74aab2feb84d..0c89be1d7ae2 100644 --- a/documentation/docs/tutorials/lead-worker.md +++ b/documentation/docs/tutorials/lead-worker.md @@ -1,5 +1,5 @@ --- -description: Enable multi-modal functionality by pairing LLMs to complete your tasks +description: Enable multi-model functionality by pairing LLMs to complete your tasks --- import Tabs from '@theme/Tabs'; diff --git a/documentation/docs/tutorials/mongodb-mcp.md b/documentation/docs/tutorials/mongodb-mcp.md new file mode 100644 index 000000000000..03c78cd1143c --- /dev/null +++ b/documentation/docs/tutorials/mongodb-mcp.md @@ -0,0 +1,260 @@ +--- +title: MongoDB Extension +description: Add MongoDB MCP Server as a Goose Extension +--- + +import Tabs from '@theme/Tabs'; +import TabItem from '@theme/TabItem'; + +The MongoDB MCP Server extension allows Goose to interact directly with your MongoDB databases, enabling comprehensive database operations including querying, document manipulation, collection management, and database administration. This makes it easy to work with your MongoDB databases through natural language interactions. + +:::tip TLDR + + + [Launch the installer](goose://extension?cmd=npx&arg=-y&arg=mongodb-mcp-server&arg=--connection-string&arg=mongodb://localhost:27017&id=mongodb&name=MongoDB&description=MongoDB%20database%20integration) + + + **Command** + ```sh + npx -y mongodb-mcp-server --connection-string mongodb://localhost:27017 + ``` + + +::: + +## Customizing Your Connection + +The MongoDB MCP server connects to a single MongoDB database instance using a connection string. The connection string must be specified using the `--connection-string` flag. We're using `mongodb://localhost:27017` as an example here to access a local MongoDB instance, but you can configure this for your own environment. + +The MongoDB connection string follows this format: +``` +mongodb://username:password@hostname:27017/database +``` + +Where: +- `username`: Your MongoDB user (optional for local development) +- `password`: Your MongoDB password (optional for local development) +- `hostname`: The host where MongoDB is running (e.g., localhost, IP address, or domain) +- `27017`: The default MongoDB port (change if using a different port) +- `database`: The name of your database (optional, will connect to default) + +Examples: +- Local database: `mongodb://localhost:27017` +- Local with credentials: `mongodb://myuser:mypass@localhost:27017/mydb` +- MongoDB Atlas: `mongodb+srv://user:pass@cluster.mongodb.net/database` + +:::caution +Never commit connection strings with credentials to version control! Use environment variables or secure configuration management. For MongoDB Atlas, ensure your IP address is whitelisted and use strong passwords. +::: + +## Configuration + +:::info +Note that you'll need [Node.js](https://nodejs.org/) installed on your system to run this command, as it uses `npx`. You'll also need a running MongoDB instance or access to MongoDB Atlas. +::: + + + + 1. [Launch the installer](goose://extension?cmd=npx&arg=-y&arg=mongodb-mcp-server&arg=--connection-string&arg=mongodb://localhost:27017&id=mongodb&name=MongoDB&description=MongoDB%20database%20integration) + 2. Press `Yes` to confirm the installation + 3. Enter your MongoDB connection string in the format: `mongodb://username:password@hostname:27017/database` + 4. Click `Save Configuration` + 5. Scroll to the top and click `Exit` from the upper left corner + + + 1. Run the `configure` command: + ```sh + goose configure + ``` + + 2. Choose to add a `Command-line Extension` + ```sh + ┌ goose-configure + │ + ◇ What would you like to configure? + │ Add Extension (Connect to a new extension) + │ + ◆ What type of extension would you like to add? + │ ○ Built-in Extension + // highlight-start + │ ● Command-line Extension (Run a local command or script) + // highlight-end + │ ○ Remote Extension + └ + ``` + + 3. Name your extension + ```sh + ┌ goose-configure + │ + ◇ What would you like to configure? + │ Add Extension (Connect to a new extension) + │ + ◇ What type of extension would you like to add? + │ Command-line Extension + │ + // highlight-start + ◆ What would you like to call this extension? + │ MongoDB + // highlight-end + └ + ``` + + 4. Enter the command with your database connection string + ```sh + ┌ goose-configure + │ + ◇ What would you like to configure? + │ Add Extension (Connect to a new extension) + │ + ◇ What would you like to call this extension? + │ MongoDB + │ + // highlight-start + ◆ What command should be run? + │ npx -y mongodb-mcp-server --connection-string mongodb://localhost:27017 + // highlight-end + └ + ``` + + 5. Set the timeout (default 300s is usually sufficient) + ```sh + ┌ goose-configure + │ + ◇ What would you like to configure? + │ Add Extension (Connect to a new extension) + │ + ◇ What would you like to call this extension? + │ MongoDB + │ + ◇ What command should be run? + │ npx -y mongodb-mcp-server --connection-string mongodb://localhost:27017 + │ + // highlight-start + ◆ Please set the timeout for this tool (in secs): + │ 300 + // highlight-end + └ + ``` + + 6. Choose to add a description. If you select "Yes" here, you will be prompted to enter a description for the extension. + ```sh + ┌ goose-configure + │ + ◇ What would you like to configure? + │ Add Extension (Connect to a new extension) + │ + ◇ What would you like to call this extension? + │ MongoDB + │ + ◇ What command should be run? + │ npx -y mongodb-mcp-server --connection-string mongodb://localhost:27017 + │ + ◇ Please set the timeout for this tool (in secs): + │ 300 + │ + // highlight-start + ◆ Would you like to add a description? + │ No + // highlight-end + └ + ``` + + 7. Complete the configuration + ```sh + ┌ goose-configure + │ + ◇ What would you like to configure? + │ Add Extension (Connect to a new extension) + │ + ◇ What would you like to call this extension? + │ MongoDB + │ + ◇ What command should be run? + │ npx -y mongodb-mcp-server --connection-string mongodb://localhost:27017 + │ + ◇ Please set the timeout for this tool (in secs): + │ 300 + │ + ◇ Would you like to add a description? + │ No + │ + // highlight-start + ◆ Would you like to add environment variables? + │ No + // highlight-end + └ Added MongoDB extension + ``` + + + + +## Available Operations + +The MongoDB extension provides comprehensive database management capabilities through natural language interactions. You can perform a wide range of operations including: + +### **Query Operations** +- **Find documents** with complex filters and sorting +- **Aggregate data** using MongoDB's powerful aggregation pipeline +- **Count documents** in collections with optional filtering + +### **Document Management** +- **Insert** single or multiple documents +- **Update** one or many documents with specified criteria +- **Delete** specific documents or bulk deletions + +### **Collection & Database Administration** +- **Create indexes** to optimize query performance +- **Rename collections** for better organization +- **Drop collections** when no longer needed +- **Drop entire databases** (use with caution!) + +### **Schema & Analytics** +- **Analyze collection schemas** to understand document structure +- **List databases and collections** for exploration +- **View collection indexes** and their configurations +- **Check database statistics** and collection storage sizes + +Just describe what you want to accomplish in natural language, and Goose will help you execute the appropriate MongoDB operations safely and efficiently. + +## Example Usage + +#### Goose Prompt +``` +Update all users in the premium tier who haven't logged in for more than 90 days - set their status to 'inactive' and add a note about when this change was made. +``` + +#### Goose Output +``` +I've successfully updated inactive premium users using updateMany: + +db.users.updateMany( + { + subscriptionTier: "premium", + lastLogin: { $lt: new Date(Date.now() - 90 * 24 * 60 * 60 * 1000) } + }, + { + $set: { + status: "inactive", + inactiveNote: "Auto-updated to inactive on 2025-05-25 due to 90+ days of inactivity", + lastModified: new Date() + } + } +) + +Operation Result: +{ + acknowledged: true, + matchedCount: 23, + modifiedCount: 23, + upsertedId: null, + upsertedCount: 0 +} + +Updated 23 premium tier users who hadn't logged in since February 24, 2025. Each document now has: +• status: "inactive" +• inactiveNote: explains the automatic status change +• lastModified: timestamp of this operation + +Consider sending re-engagement emails to these users or reviewing their subscription status. +``` diff --git a/documentation/src/components/CLIStreamExtensionInstructions.js b/documentation/src/components/CLIStreamExtensionInstructions.js new file mode 100644 index 000000000000..286f7044904a --- /dev/null +++ b/documentation/src/components/CLIStreamExtensionInstructions.js @@ -0,0 +1,244 @@ +import React from 'react'; +import CodeBlock from '@theme/CodeBlock'; +import Admonition from '@theme/Admonition'; + +export default function CLIStreamExtensionInstructions({ + name, + endpointUri, + timeout = 300, + headers = [], + infoNote, +}) { + const hasHeaders = headers.length > 0; + const headerStepText = hasHeaders + ? `Choose Yes when asked to add custom headers` + : 'Choose No when asked to add custom headers'; + + return ( +
+
    +
  1. Run the configure command:
  2. +
+ {`goose configure`} + +
    +
  1. Choose to add a Remote Extension (Streaming HTTP)
  2. +
+ {`┌ goose-configure +│ +◇ What would you like to configure? +│ Add Extension +│ +◆ What type of extension would you like to add? +│ ○ Built-in Extension +│ ○ Command-line Extension +│ ○ Remote Extension (SSE) +// highlight-start +│ ● Remote Extension (Streaming HTTP) (Connect to a remote extension via MCP Streaming HTTP) +// highlight-end +└`} + +
    +
  1. Give your extension a name
  2. +
+ {`┌ goose-configure +│ +◇ What would you like to configure? +│ Add Extension +│ +◇ What type of extension would you like to add? +│ Remote Extension (Streaming HTTP) +│ +// highlight-start +◆ What would you like to call this extension? +│ ${name} +// highlight-end +└`} + +
    +
  1. Enter the endpoint URI
  2. +
+ {`┌ goose-configure +│ +◇ What would you like to configure? +│ Add Extension +│ +◇ What type of extension would you like to add? +│ Remote Extension (Streaming HTTP) +│ +◇ What would you like to call this extension? +│ ${name} +│ +// highlight-start +◆ What is the Streaming HTTP endpoint URI? +│ ${endpointUri} +// highlight-end +└`} + +
    +
  1. + Enter the number of seconds Goose should wait for actions to complete before timing out. Default is 300 seconds +
  2. +
+ {`┌ goose-configure +│ +◇ What would you like to configure? +│ Add Extension +│ +◇ What type of extension would you like to add? +│ Remote Extension (Streaming HTTP) +│ +◇ What would you like to call this extension? +│ ${name} +│ +◇ What is the Streaming HTTP endpoint URI? +│ ${endpointUri} +│ +// highlight-start +◆ Please set the timeout for this tool (in secs): +│ ${timeout} +// highlight-end +└`} + +
    +
  1. Choose to add a description. If you select Yes, you'll be prompted to enter a description for the extension
  2. +
+ {`┌ goose-configure +│ +◇ What would you like to configure? +│ Add Extension +│ +◇ What type of extension would you like to add? +│ Remote Extension (Streaming HTTP) +│ +◇ What would you like to call this extension? +│ ${name} +│ +◇ What is the Streaming HTTP endpoint URI? +│ ${endpointUri} +│ +◇ Please set the timeout for this tool (in secs): +│ ${timeout} +│ +// highlight-start +◆ Would you like to add a description? +│ No +// highlight-end +└`} + +
    +
  1. {headerStepText}
  2. +
+ + {!hasHeaders && ( + {`┌ goose-configure +│ +◇ What would you like to configure? +│ Add Extension +│ +◇ What type of extension would you like to add? +│ Remote Extension (Streaming HTTP) +│ +◇ What would you like to call this extension? +│ ${name} +│ +◇ What is the Streaming HTTP endpoint URI? +│ ${endpointUri} +│ +◇ Please set the timeout for this tool (in secs): +│ ${timeout} +│ +◇ Would you like to add a description? +│ No +│ +// highlight-start +◆ Would you like to add custom headers? +│ No +// highlight-end +└ Added ${name} extension`} + )} + + {hasHeaders && ( + <> + {`┌ goose-configure +│ +◇ What would you like to configure? +│ Add Extension +│ +◇ What type of extension would you like to add? +│ Remote Extension (Streaming HTTP) +│ +◇ What would you like to call this extension? +│ ${name} +│ +◇ What is the Streaming HTTP endpoint URI? +│ ${endpointUri} +│ +◇ Please set the timeout for this tool (in secs): +│ ${timeout} +│ +◇ Would you like to add a description? +│ No +│ +// highlight-start +◆ Would you like to add custom headers? +│ Yes +// highlight-end +└`} + +
    +
  1. Add your custom header{headers.length > 1 ? 's' : ''}
  2. +
+ + {infoNote && ( + <> + + {infoNote} + +
+ + )} + + {`┌ goose-configure +│ +◇ What would you like to configure? +│ Add Extension +│ +◇ What type of extension would you like to add? +│ Remote Extension (Streaming HTTP) +│ +◇ What would you like to call this extension? +│ ${name} +│ +◇ What is the Streaming HTTP endpoint URI? +│ ${endpointUri} +│ +◇ Please set the timeout for this tool (in secs): +│ ${timeout} +│ +◇ Would you like to add a description? +│ No +│ +◇ Would you like to add custom headers? +│ Yes +│ +// highlight-start +${headers + .map( + ({ key, value }, i) => `◇ Header name: +│ ${key} +│ +◇ Header value: +│ ${value} +│ +◇ Add another header? +│ ${i === headers.length - 1 ? 'No' : 'Yes'}` + ) + .join('\n│\n')} +// highlight-end +└ Added ${name} extension`} + + )} +
+ ); +} diff --git a/documentation/src/components/GooseBuiltinInstaller.tsx b/documentation/src/components/GooseBuiltinInstaller.tsx new file mode 100644 index 000000000000..3bed475dee21 --- /dev/null +++ b/documentation/src/components/GooseBuiltinInstaller.tsx @@ -0,0 +1,24 @@ +import React from 'react'; +import { PanelLeft } from 'lucide-react'; + +interface GooseBuiltinInstallerProps { + extensionName: string; + description?: string; +} + +const GooseBuiltinInstaller: React.FC = ({ + extensionName, + description +}) => { + return ( +
+
    +
  1. Click the button in the top-left to open the sidebar
  2. +
  3. Click Extensions in the sidebar
  4. +
  5. Toggle {extensionName} on
  6. +
+
+ ); +}; + +export default GooseBuiltinInstaller; diff --git a/documentation/src/components/GooseDesktopInstaller.tsx b/documentation/src/components/GooseDesktopInstaller.tsx new file mode 100644 index 000000000000..cf9f113ca3b5 --- /dev/null +++ b/documentation/src/components/GooseDesktopInstaller.tsx @@ -0,0 +1,122 @@ +import React from 'react'; +import { PanelLeft } from 'lucide-react'; + +interface EnvVar { + name: string; + label: string; +} + +interface GooseDesktopInstallerProps { + extensionId: string; + extensionName: string; + description: string; + // Command-line extension props (optional when using url) + command?: string; + args?: string[]; + // SSE extension prop (optional when using command+args) + url?: string; + envVars?: EnvVar[]; + apiKeyLink?: string; + apiKeyLinkText?: string; + customStep3?: string; + hasEnvVars?: boolean; // Explicit control over configuration steps + appendToStep3?: string; +} + +export default function GooseDesktopInstaller({ + extensionId, + extensionName, + description, + command, + args, + url, + envVars = [], + apiKeyLink, + apiKeyLinkText, + customStep3, + hasEnvVars, + appendToStep3 +}: GooseDesktopInstallerProps) { + + // Build the goose:// URL + const buildGooseUrl = () => { + let urlParts = []; + + // Add SSE extension URL or command-line extension command+args first + if (url) { + urlParts.push(`url=${encodeURIComponent(url)}`); + } else if (command && args) { + urlParts.push(`cmd=${encodeURIComponent(command)}`); + urlParts.push(...args.map(arg => `arg=${encodeURIComponent(arg)}`)); + } + + // Add common parameters + urlParts.push( + `id=${encodeURIComponent(extensionId)}`, + `name=${encodeURIComponent(extensionName)}`, + `description=${encodeURIComponent(description)}` + ); + + // Add environment variables (matching TLDR sections encoding) + urlParts.push(...envVars.map(envVar => + `env=${encodeURIComponent(`${envVar.name}=${envVar.label}`)}` + )); + + return `goose://extension?${urlParts.join('&')}`; + }; + + // Generate step 3 content (only if needed) + const getStep3Content = () => { + if (customStep3) { + return customStep3; + } + + if (apiKeyLink && apiKeyLinkText) { + return ( + <> + Get your {apiKeyLinkText} and paste it in + + ); + } + + if (envVars.length > 0) { + const envVarNames = envVars.map(env => env.name).join(', '); + return `Obtain your ${envVarNames} and paste it in`; + } + + return null; // No configuration needed + }; + + const content = getStep3Content(); + const step3Content = appendToStep3 + ? ( + <> + {content} + {content ?
: null} + {appendToStep3} + + ) + : content; + + const hasConfigurationContent = step3Content !== null; + const shouldShowConfigurationSteps = hasEnvVars ?? hasConfigurationContent; + + return ( +
+
    +
  1. + Launch the installer +
  2. +
  3. Click OK to confirm the installation
  4. + {shouldShowConfigurationSteps && ( + <> +
  5. {step3Content}
  6. +
  7. Click Add Extension
  8. + + )} +
  9. Click the button in the top-left to open the sidebar
  10. +
  11. Navigate to the chat
  12. +
+
+ ); +} diff --git a/download_cli.ps1 b/download_cli.ps1 new file mode 100644 index 000000000000..57eca20f498b --- /dev/null +++ b/download_cli.ps1 @@ -0,0 +1,233 @@ +############################################################################## +# Goose CLI Install Script for Windows PowerShell +# +# This script downloads the latest stable 'goose' CLI binary from GitHub releases +# and installs it to your system. +# +# Supported OS: Windows +# Supported Architectures: x86_64 +# +# Usage: +# Invoke-WebRequest -Uri "https://github.com/block/goose/releases/download/stable/download_cli.ps1" -OutFile "download_cli.ps1"; .\download_cli.ps1 +# Or simply: .\download_cli.ps1 +# +# Environment variables: +# $env:GOOSE_BIN_DIR - Directory to which Goose will be installed (default: $env:USERPROFILE\.local\bin) +# $env:GOOSE_VERSION - Optional: specific version to install (e.g., "v1.0.25"). Can be in the format vX.Y.Z, vX.Y.Z-suffix, or X.Y.Z +# $env:GOOSE_PROVIDER - Optional: provider for goose +# $env:GOOSE_MODEL - Optional: model for goose +# $env:CANARY - Optional: if set to "true", downloads from canary release instead of stable +# $env:CONFIGURE - Optional: if set to "false", disables running goose configure interactively +############################################################################## + +# Set error action preference to stop on errors +$ErrorActionPreference = "Stop" + +# --- 1) Variables --- +$REPO = "block/goose" +$OUT_FILE = "goose.exe" + +# Set default bin directory if not specified +if (-not $env:GOOSE_BIN_DIR) { + $env:GOOSE_BIN_DIR = Join-Path $env:USERPROFILE ".local\bin" +} + +# Determine release type +$RELEASE = if ($env:CANARY -eq "true") { "true" } else { "false" } +$CONFIGURE = if ($env:CONFIGURE -eq "false") { "false" } else { "true" } + +# Determine release tag +if ($env:GOOSE_VERSION) { + # Validate version format + if ($env:GOOSE_VERSION -notmatch '^v?[0-9]+\.[0-9]+\.[0-9]+(-.*)?$') { + Write-Error "Invalid version '$env:GOOSE_VERSION'. Expected: semver format vX.Y.Z, vX.Y.Z-suffix, or X.Y.Z" + exit 1 + } + # Ensure version starts with 'v' + $RELEASE_TAG = if ($env:GOOSE_VERSION.StartsWith("v")) { $env:GOOSE_VERSION } else { "v$env:GOOSE_VERSION" } +} else { + # Use canary or stable based on RELEASE variable + $RELEASE_TAG = if ($RELEASE -eq "true") { "canary" } else { "stable" } +} + +# --- 2) Detect Architecture --- +$ARCH = $env:PROCESSOR_ARCHITECTURE +if ($ARCH -eq "AMD64") { + $ARCH = "x86_64" +} elseif ($ARCH -eq "ARM64") { + Write-Error "Windows ARM64 is not currently supported." + exit 1 +} else { + Write-Error "Unsupported architecture '$ARCH'. Only x86_64 is supported on Windows." + exit 1 +} + +# --- 3) Build download URL --- +$FILE = "goose-$ARCH-pc-windows-gnu.zip" +$DOWNLOAD_URL = "https://github.com/$REPO/releases/download/$RELEASE_TAG/$FILE" + +Write-Host "Downloading $RELEASE_TAG release: $FILE..." -ForegroundColor Green + +# --- 4) Download the file --- +try { + Invoke-WebRequest -Uri $DOWNLOAD_URL -OutFile $FILE -UseBasicParsing + Write-Host "Download completed successfully." -ForegroundColor Green +} catch { + Write-Error "Failed to download $DOWNLOAD_URL. Error: $($_.Exception.Message)" + exit 1 +} + +# --- 5) Create temporary directory for extraction --- +$TMP_DIR = Join-Path $env:TEMP "goose_install_$(Get-Random)" +try { + New-Item -ItemType Directory -Path $TMP_DIR -Force | Out-Null + Write-Host "Created temporary directory: $TMP_DIR" -ForegroundColor Yellow +} catch { + Write-Error "Could not create temporary extraction directory: $TMP_DIR" + exit 1 +} + +# --- 6) Extract the archive --- +Write-Host "Extracting $FILE to temporary directory..." -ForegroundColor Green +try { + Expand-Archive -Path $FILE -DestinationPath $TMP_DIR -Force + Write-Host "Extraction completed successfully." -ForegroundColor Green +} catch { + Write-Error "Failed to extract $FILE. Error: $($_.Exception.Message)" + Remove-Item -Path $TMP_DIR -Recurse -Force -ErrorAction SilentlyContinue + exit 1 +} + +# Clean up the downloaded archive +Remove-Item -Path $FILE -Force + +# --- 7) Determine extraction directory --- +$EXTRACT_DIR = $TMP_DIR +if (Test-Path (Join-Path $TMP_DIR "goose-package")) { + Write-Host "Found goose-package subdirectory, using that as extraction directory" -ForegroundColor Yellow + $EXTRACT_DIR = Join-Path $TMP_DIR "goose-package" +} + +# --- 8) Create bin directory if it doesn't exist --- +if (-not (Test-Path $env:GOOSE_BIN_DIR)) { + Write-Host "Creating directory: $env:GOOSE_BIN_DIR" -ForegroundColor Yellow + try { + New-Item -ItemType Directory -Path $env:GOOSE_BIN_DIR -Force | Out-Null + } catch { + Write-Error "Could not create directory: $env:GOOSE_BIN_DIR" + Remove-Item -Path $TMP_DIR -Recurse -Force -ErrorAction SilentlyContinue + exit 1 + } +} + +# --- 9) Install goose binary --- +$SOURCE_GOOSE = Join-Path $EXTRACT_DIR "goose.exe" +$DEST_GOOSE = Join-Path $env:GOOSE_BIN_DIR $OUT_FILE + +if (Test-Path $SOURCE_GOOSE) { + Write-Host "Moving goose to $DEST_GOOSE" -ForegroundColor Green + try { + # Remove existing file if it exists to avoid conflicts + if (Test-Path $DEST_GOOSE) { + Remove-Item -Path $DEST_GOOSE -Force + } + Move-Item -Path $SOURCE_GOOSE -Destination $DEST_GOOSE -Force + } catch { + Write-Error "Failed to move goose.exe to $DEST_GOOSE. Error: $($_.Exception.Message)" + Remove-Item -Path $TMP_DIR -Recurse -Force -ErrorAction SilentlyContinue + exit 1 + } +} else { + Write-Error "goose.exe not found in extracted files" + Remove-Item -Path $TMP_DIR -Recurse -Force -ErrorAction SilentlyContinue + exit 1 +} + +# --- 10) Install temporal-service if it exists --- +$SOURCE_TEMPORAL_SERVICE = Join-Path $EXTRACT_DIR "temporal-service.exe" +if (Test-Path $SOURCE_TEMPORAL_SERVICE) { + $DEST_TEMPORAL_SERVICE = Join-Path $env:GOOSE_BIN_DIR "temporal-service.exe" + Write-Host "Moving temporal-service to $DEST_TEMPORAL_SERVICE" -ForegroundColor Green + try { + # Remove existing file if it exists to avoid conflicts + if (Test-Path $DEST_TEMPORAL_SERVICE) { + Remove-Item -Path $DEST_TEMPORAL_SERVICE -Force + } + Move-Item -Path $SOURCE_TEMPORAL_SERVICE -Destination $DEST_TEMPORAL_SERVICE -Force + } catch { + Write-Warning "Failed to move temporal-service.exe: $($_.Exception.Message)" + } +} + +# --- 11) Install temporal CLI if it exists --- +$SOURCE_TEMPORAL = Join-Path $EXTRACT_DIR "temporal.exe" +if (Test-Path $SOURCE_TEMPORAL) { + $DEST_TEMPORAL = Join-Path $env:GOOSE_BIN_DIR "temporal.exe" + Write-Host "Moving temporal CLI to $DEST_TEMPORAL" -ForegroundColor Green + try { + # Remove existing file if it exists to avoid conflicts + if (Test-Path $DEST_TEMPORAL) { + Remove-Item -Path $DEST_TEMPORAL -Force + } + Move-Item -Path $SOURCE_TEMPORAL -Destination $DEST_TEMPORAL -Force + } catch { + Write-Warning "Failed to move temporal.exe: $($_.Exception.Message)" + } +} + +# --- 12) Copy Windows runtime DLLs if they exist --- +$DLL_FILES = Get-ChildItem -Path $EXTRACT_DIR -Filter "*.dll" -ErrorAction SilentlyContinue +foreach ($dll in $DLL_FILES) { + $DEST_DLL = Join-Path $env:GOOSE_BIN_DIR $dll.Name + Write-Host "Moving Windows runtime DLL: $($dll.Name)" -ForegroundColor Green + try { + # Remove existing file if it exists to avoid conflicts + if (Test-Path $DEST_DLL) { + Remove-Item -Path $DEST_DLL -Force + } + Move-Item -Path $dll.FullName -Destination $DEST_DLL -Force + } catch { + Write-Warning "Failed to move $($dll.Name): $($_.Exception.Message)" + } +} + +# --- 13) Clean up temporary directory --- +try { + Remove-Item -Path $TMP_DIR -Recurse -Force + Write-Host "Cleaned up temporary directory." -ForegroundColor Yellow +} catch { + Write-Warning "Could not clean up temporary directory: $TMP_DIR" +} + +# --- 14) Configure Goose (Optional) --- +if ($CONFIGURE -eq "true") { + Write-Host "" + Write-Host "Configuring Goose" -ForegroundColor Green + Write-Host "" + try { + & $DEST_GOOSE configure + } catch { + Write-Warning "Failed to run goose configure. You may need to run it manually later." + } +} else { + Write-Host "Skipping 'goose configure', you may need to run this manually later" -ForegroundColor Yellow +} + +# --- 15) Check PATH and give instructions if needed --- +$CURRENT_PATH = $env:PATH +if ($CURRENT_PATH -notlike "*$env:GOOSE_BIN_DIR*") { + Write-Host "" + Write-Host "Warning: Goose installed, but $env:GOOSE_BIN_DIR is not in your PATH." -ForegroundColor Yellow + Write-Host "To add it to your PATH permanently, run the following command as Administrator:" -ForegroundColor Yellow + Write-Host " [Environment]::SetEnvironmentVariable('PATH', `$env:PATH + ';$env:GOOSE_BIN_DIR', 'Machine')" -ForegroundColor Cyan + Write-Host "" + Write-Host "Or add it to your user PATH (no admin required):" -ForegroundColor Yellow + Write-Host " [Environment]::SetEnvironmentVariable('PATH', `$env:PATH + ';$env:GOOSE_BIN_DIR', 'User')" -ForegroundColor Cyan + Write-Host "" + Write-Host "For this session only, you can run:" -ForegroundColor Yellow + Write-Host " `$env:PATH += ';$env:GOOSE_BIN_DIR'" -ForegroundColor Cyan + Write-Host "" +} + +Write-Host "Goose CLI installation completed successfully!" -ForegroundColor Green +Write-Host "Goose is installed at: $DEST_GOOSE" -ForegroundColor Green diff --git a/ui/desktop/.goosehints b/ui/desktop/.goosehints index ab102a4c257a..08f8e40a4492 100644 --- a/ui/desktop/.goosehints +++ b/ui/desktop/.goosehints @@ -52,7 +52,7 @@ The Goose Desktop App is an Electron application built with TypeScript, React, a 2. Create a main component file (e.g., `YourFeatureView.tsx`) 3. Add your view type to the `View` type in `App.tsx` 4. Import and add your component to the render section in `App.tsx` -5. Add navigation to your view from other components (e.g., adding a button in `BottomMenu.tsx` or `MoreMenu.tsx`) +5. Add navigation to your view from other components (e.g., adding a new route or button in `App.tsx`) ## State Management diff --git a/ui/desktop/forge.config.ts b/ui/desktop/forge.config.ts index c32ce48f2f07..827e124ff52e 100644 --- a/ui/desktop/forge.config.ts +++ b/ui/desktop/forge.config.ts @@ -12,14 +12,14 @@ let cfg = { certificateFile: process.env.WINDOWS_CERTIFICATE_FILE, signingRole: process.env.WINDOW_SIGNING_ROLE, rfc3161TimeStampServer: 'http://timestamp.digicert.com', - signWithParams: '/fd sha256 /tr http://timestamp.digicert.com /td sha256' + signWithParams: '/fd sha256 /tr http://timestamp.digicert.com /td sha256', }, // Protocol registration protocols: [ { - name: "GooseProtocol", - schemes: ["goose"] - } + name: 'GooseProtocol', + schemes: ['goose'], + }, ], // macOS Info.plist extensions for drag-and-drop support extendInfo: { @@ -44,9 +44,9 @@ let cfg = { osxNotarize: { appleId: process.env['APPLE_ID'], appleIdPassword: process.env['APPLE_ID_PASSWORD'], - teamId: process.env['APPLE_TEAM_ID'] + teamId: process.env['APPLE_TEAM_ID'], }, -} +}; if (process.env['APPLE_ID'] === undefined) { delete cfg.osxNotarize; @@ -62,12 +62,12 @@ module.exports = { config: { repository: { owner: 'block', - name: 'goose' + name: 'goose', }, prerelease: false, - draft: true - } - } + draft: true, + }, + }, ], makers: [ { @@ -76,22 +76,22 @@ module.exports = { config: { arch: process.env.ELECTRON_ARCH === 'x64' ? ['x64'] : ['arm64'], options: { - icon: 'src/images/icon.ico' - } - } + icon: 'src/images/icon.ico', + }, + }, }, { name: '@electron-forge/maker-deb', config: { name: 'Goose', - bin: 'Goose' + bin: 'Goose', }, }, { name: '@electron-forge/maker-rpm', config: { name: 'Goose', - bin: 'Goose' + bin: 'Goose', }, }, ], @@ -102,17 +102,17 @@ module.exports = { build: [ { entry: 'src/main.ts', - config: 'vite.main.config.ts', + config: 'vite.main.config.mts', }, { entry: 'src/preload.ts', - config: 'vite.preload.config.ts', + config: 'vite.preload.config.mts', }, ], renderer: [ { name: 'main_window', - config: 'vite.renderer.config.ts', + config: 'vite.renderer.config.mts', }, ], }, diff --git a/ui/desktop/index.html b/ui/desktop/index.html index 564e1f0e60e3..f606653e47eb 100644 --- a/ui/desktop/index.html +++ b/ui/desktop/index.html @@ -2,6 +2,7 @@ + Goose