diff --git a/.github/lsp.json b/.github/lsp.json index f1bd89796..e58456ac4 100644 --- a/.github/lsp.json +++ b/.github/lsp.json @@ -15,7 +15,7 @@ "rootUri": "dotnet" }, "go": { - "command": "${HOME}/go/bin/gopls", + "command": "gopls", "args": ["serve"], "fileExtensions": { ".go": "go" diff --git a/.github/workflows/dotnet-sdk-tests.yml b/.github/workflows/dotnet-sdk-tests.yml index 0bfa613a7..c86b37920 100644 --- a/.github/workflows/dotnet-sdk-tests.yml +++ b/.github/workflows/dotnet-sdk-tests.yml @@ -1,13 +1,15 @@ name: ".NET SDK Tests" on: + push: + branches: + - main pull_request: paths: - 'dotnet/**' - 'test/**' - 'nodejs/package.json' - '.github/workflows/dotnet-sdk-tests.yml' - - '.github/actions/setup-copilot/**' - '!**/*.md' - '!**/LICENSE*' - '!**/.gitignore' @@ -39,17 +41,16 @@ jobs: working-directory: ./dotnet steps: - uses: actions/checkout@v6.0.2 - - uses: ./.github/actions/setup-copilot - id: setup-copilot - uses: actions/setup-dotnet@v5 with: dotnet-version: "8.0.x" - uses: actions/setup-node@v6 with: + node-version: "24" cache: "npm" cache-dependency-path: "./nodejs/package-lock.json" - - name: Install Node.js dependencies (for CLI) + - name: Install Node.js dependencies (for CLI version extraction) working-directory: ./nodejs run: npm ci --ignore-scripts @@ -80,5 +81,4 @@ jobs: - name: Run .NET SDK tests env: COPILOT_HMAC_KEY: ${{ secrets.COPILOT_DEVELOPER_CLI_INTEGRATION_HMAC_KEY }} - COPILOT_CLI_PATH: ${{ steps.setup-copilot.outputs.cli-path }} run: dotnet test --no-build -v n diff --git a/.github/workflows/go-sdk-tests.yml b/.github/workflows/go-sdk-tests.yml index 70aaebb84..ed75bcb0c 100644 --- a/.github/workflows/go-sdk-tests.yml +++ b/.github/workflows/go-sdk-tests.yml @@ -1,6 +1,9 @@ name: "Go SDK Tests" on: + push: + branches: + - main pull_request: paths: - 'go/**' diff --git a/.github/workflows/nodejs-sdk-tests.yml b/.github/workflows/nodejs-sdk-tests.yml index 9eded8d61..088d94a5b 100644 --- a/.github/workflows/nodejs-sdk-tests.yml +++ b/.github/workflows/nodejs-sdk-tests.yml @@ -4,12 +4,14 @@ env: HUSKY: 0 on: + push: + branches: + - main pull_request: paths: - 'nodejs/**' - 'test/**' - '.github/workflows/nodejs-sdk-tests.yml' - - '.github/actions/setup-copilot/**' - '!**/*.md' - '!**/LICENSE*' - '!**/.gitignore' @@ -45,9 +47,7 @@ jobs: with: cache: "npm" cache-dependency-path: "./nodejs/package-lock.json" - node-version: 22 - - uses: ./.github/actions/setup-copilot - id: setup-copilot + node-version: 24 - name: Install dependencies run: npm ci --ignore-scripts @@ -72,5 +72,4 @@ jobs: - name: Run Node.js SDK tests env: COPILOT_HMAC_KEY: ${{ secrets.COPILOT_DEVELOPER_CLI_INTEGRATION_HMAC_KEY }} - COPILOT_CLI_PATH: ${{ steps.setup-copilot.outputs.cli-path }} run: npm test diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index 749c520dd..a3849d62c 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -106,6 +106,7 @@ jobs: name: nodejs-package path: nodejs/*.tgz - name: Publish to npm + if: github.ref == 'refs/heads/main' run: npm publish --tag ${{ github.event.inputs.dist-tag }} --access public --registry https://registry.npmjs.org publish-dotnet: @@ -130,6 +131,7 @@ jobs: name: dotnet-package path: dotnet/artifacts/*.nupkg - name: NuGet login (OIDC) + if: github.ref == 'refs/heads/main' uses: NuGet/login@v1 id: nuget-login with: @@ -139,6 +141,7 @@ jobs: # are associated with individual maintainers' accounts too. user: stevesanderson - name: Publish to NuGet + if: github.ref == 'refs/heads/main' run: dotnet nuget push ./artifacts/*.nupkg --api-key ${{ steps.nuget-login.outputs.NUGET_API_KEY }} --source https://api.nuget.org/v3/index.json --skip-duplicate publish-python: @@ -153,18 +156,25 @@ jobs: - uses: actions/setup-python@v6 with: python-version: "3.12" + - uses: actions/setup-node@v6 + with: + node-version: "22.x" - name: Set up uv uses: astral-sh/setup-uv@v7 + - name: Install Node.js dependencies (for CLI version) + working-directory: ./nodejs + run: npm ci --ignore-scripts - name: Set version run: sed -i "s/^version = .*/version = \"${{ needs.version.outputs.version }}\"/" pyproject.toml - - name: Build package - run: uv build + - name: Build platform wheels + run: node scripts/build-wheels.mjs --output-dir dist - name: Upload artifact uses: actions/upload-artifact@v6 with: name: python-package path: python/dist/* - name: Publish to PyPI + if: github.ref == 'refs/heads/main' uses: pypa/gh-action-pypi-publish@release/v1 with: packages-dir: python/dist/ diff --git a/.github/workflows/python-sdk-tests.yml b/.github/workflows/python-sdk-tests.yml index 06c62e511..560288d2d 100644 --- a/.github/workflows/python-sdk-tests.yml +++ b/.github/workflows/python-sdk-tests.yml @@ -4,13 +4,15 @@ env: PYTHONUTF8: 1 on: + push: + branches: + - main pull_request: paths: - 'python/**' - 'test/**' - 'nodejs/package.json' - '.github/workflows/python-sdk-tests.yml' - - '.github/actions/setup-copilot/**' - '!**/*.md' - '!**/LICENSE*' - '!**/.gitignore' @@ -42,11 +44,14 @@ jobs: working-directory: ./python steps: - uses: actions/checkout@v6.0.2 - - uses: ./.github/actions/setup-copilot - id: setup-copilot - uses: actions/setup-python@v6 with: python-version: "3.12" + - uses: actions/setup-node@v6 + with: + node-version: "24" + cache: "npm" + cache-dependency-path: "./nodejs/package-lock.json" - name: Set up uv uses: astral-sh/setup-uv@v7 @@ -56,6 +61,10 @@ jobs: - name: Install Python dev dependencies run: uv sync --locked --all-extras --dev + - name: Install Node.js dependencies (for CLI in tests) + working-directory: ./nodejs + run: npm ci --ignore-scripts + - name: Run ruff format check run: uv run ruff format --check . @@ -76,5 +85,4 @@ jobs: - name: Run Python SDK tests env: COPILOT_HMAC_KEY: ${{ secrets.COPILOT_DEVELOPER_CLI_INTEGRATION_HMAC_KEY }} - COPILOT_CLI_PATH: ${{ steps.setup-copilot.outputs.cli-path }} run: uv run pytest -v -s diff --git a/docs/auth/byok.md b/docs/auth/byok.md index 6c8367435..b244c4532 100644 --- a/docs/auth/byok.md +++ b/docs/auth/byok.md @@ -272,19 +272,23 @@ provider: { } ``` +> **Note:** The `bearerToken` option accepts a **static token string** only. The SDK does not refresh this token automatically. If your token expires, requests will fail and you'll need to create a new session with a fresh token. + ## Limitations When using BYOK, be aware of these limitations: ### Identity Limitations -BYOK authentication is **key-based only**. The following identity providers are NOT supported: +BYOK authentication uses **static credentials only**. The following identity providers are NOT supported: - ❌ **Microsoft Entra ID (Azure AD)** - No support for Entra managed identities or service principals - ❌ **Third-party identity providers** - No OIDC, SAML, or other federated identity - ❌ **Managed identities** - Azure Managed Identity is not supported -You must use an API key or bearer token that you manage yourself. +You must use an API key or static bearer token that you manage yourself. + +**Why not Entra ID?** While Entra ID does issue bearer tokens, these tokens are short-lived (typically 1 hour) and require automatic refresh via the Azure Identity SDK. The `bearerToken` option only accepts a static string—there is no callback mechanism for the SDK to request fresh tokens. For long-running workloads requiring Entra authentication, you would need to implement your own token refresh logic and create new sessions with updated tokens. ### Feature Limitations diff --git a/dotnet/.gitignore b/dotnet/.gitignore index fda46a3e3..ef38c1ee2 100644 --- a/dotnet/.gitignore +++ b/dotnet/.gitignore @@ -2,6 +2,9 @@ bin/ obj/ +# Generated build props (contains CLI version) +src/build/GitHub.Copilot.SDK.props + # NuGet packages *.nupkg *.snupkg diff --git a/dotnet/src/Client.cs b/dotnet/src/Client.cs index e2b86b145..74f1c66f2 100644 --- a/dotnet/src/Client.cs +++ b/dotnet/src/Client.cs @@ -873,7 +873,9 @@ private async Task VerifyProtocolVersionAsync(Connection connection, Cancellatio private static async Task<(Process Process, int? DetectedLocalhostTcpPort)> StartCliServerAsync(CopilotClientOptions options, ILogger logger, CancellationToken cancellationToken) { - var cliPath = options.CliPath ?? "copilot"; + // Use explicit path or bundled CLI - no PATH fallback + var cliPath = options.CliPath ?? GetBundledCliPath(out var searchedPath) + ?? throw new InvalidOperationException($"Copilot CLI not found at '{searchedPath}'. Ensure the SDK NuGet package was restored correctly or provide an explicit CliPath."); var args = new List(); if (options.CliArgs != null) @@ -881,7 +883,7 @@ private async Task VerifyProtocolVersionAsync(Connection connection, Cancellatio args.AddRange(options.CliArgs); } - args.AddRange(["--headless", "--log-level", options.LogLevel]); + args.AddRange(["--headless", "--no-auto-update", "--log-level", options.LogLevel]); if (options.UseStdio) { @@ -976,6 +978,14 @@ private async Task VerifyProtocolVersionAsync(Connection connection, Cancellatio return (cliProcess, detectedLocalhostTcpPort); } + private static string? GetBundledCliPath(out string searchedPath) + { + var binaryName = OperatingSystem.IsWindows() ? "copilot.exe" : "copilot"; + var rid = Path.GetFileName(System.Runtime.InteropServices.RuntimeInformation.RuntimeIdentifier); + searchedPath = Path.Combine(AppContext.BaseDirectory, "runtimes", rid, "native", binaryName); + return File.Exists(searchedPath) ? searchedPath : null; + } + private static (string FileName, IEnumerable Args) ResolveCliCommand(string cliPath, IEnumerable args) { var isJsFile = cliPath.EndsWith(".js", StringComparison.OrdinalIgnoreCase); @@ -985,13 +995,6 @@ private static (string FileName, IEnumerable Args) ResolveCliCommand(str return ("node", new[] { cliPath }.Concat(args)); } - // On Windows with UseShellExecute=false, Process.Start doesn't search PATHEXT, - // so use cmd /c to let the shell resolve the executable - if (OperatingSystem.IsWindows() && !Path.IsPathRooted(cliPath)) - { - return ("cmd", new[] { "/c", cliPath }.Concat(args)); - } - return (cliPath, args); } diff --git a/dotnet/src/Generated/SessionEvents.cs b/dotnet/src/Generated/SessionEvents.cs index 04820060f..022588396 100644 --- a/dotnet/src/Generated/SessionEvents.cs +++ b/dotnet/src/Generated/SessionEvents.cs @@ -6,7 +6,7 @@ // // Generated from: @github/copilot/session-events.schema.json // Generated by: scripts/generate-session-types.ts -// Generated at: 2026-02-03T20:40:49.743Z +// Generated at: 2026-02-06T20:38:23.832Z // // To update these types: // 1. Update the schema in copilot-agent-runtime @@ -765,6 +765,10 @@ public partial class SessionCompactionCompleteData [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] [JsonPropertyName("compactionTokensUsed")] public SessionCompactionCompleteDataCompactionTokensUsed? CompactionTokensUsed { get; set; } + + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + [JsonPropertyName("requestId")] + public string? RequestId { get; set; } } public partial class UserMessageData diff --git a/dotnet/src/GitHub.Copilot.SDK.csproj b/dotnet/src/GitHub.Copilot.SDK.csproj index d20364ef2..019788cfa 100644 --- a/dotnet/src/GitHub.Copilot.SDK.csproj +++ b/dotnet/src/GitHub.Copilot.SDK.csproj @@ -1,31 +1,60 @@  + + net8.0 + enable + enable + true + 0.1.0 + SDK for programmatic control of GitHub Copilot CLI + GitHub + GitHub + Copyright (c) Microsoft Corporation. All rights reserved. + MIT + README.md + https://github.com/github/copilot-sdk + github;copilot;sdk;jsonrpc;agent + true + + + + + + + + + + + + + + + + + + + + + <_VersionPropsContent> + - net8.0 - enable - enable - true - 0.1.0 - SDK for programmatic control of GitHub Copilot CLI - GitHub - GitHub - Copyright (c) Microsoft Corporation. All rights reserved. - MIT - README.md - https://github.com/github/copilot-sdk - github;copilot;sdk;jsonrpc;agent - true + $(CopilotCliVersion) +]]> + + + + + + + + - - - - - - - - - - + + + + + + diff --git a/dotnet/src/Types.cs b/dotnet/src/Types.cs index 28f4e2e7c..664b35d9e 100644 --- a/dotnet/src/Types.cs +++ b/dotnet/src/Types.cs @@ -24,6 +24,9 @@ public enum ConnectionState public class CopilotClientOptions { + /// + /// Path to the Copilot CLI executable. If not specified, uses the bundled CLI from the SDK. + /// public string? CliPath { get; set; } public string[]? CliArgs { get; set; } public string? Cwd { get; set; } diff --git a/dotnet/src/build/GitHub.Copilot.SDK.targets b/dotnet/src/build/GitHub.Copilot.SDK.targets new file mode 100644 index 000000000..20afd8156 --- /dev/null +++ b/dotnet/src/build/GitHub.Copilot.SDK.targets @@ -0,0 +1,84 @@ + + + + + + + + <_CopilotRid Condition="'$(RuntimeIdentifier)' != ''">$(RuntimeIdentifier) + <_CopilotRid Condition="'$(_CopilotRid)' == '' And $([MSBuild]::IsOSPlatform('Windows')) And '$([System.Runtime.InteropServices.RuntimeInformation]::OSArchitecture)' == 'X64'">win-x64 + <_CopilotRid Condition="'$(_CopilotRid)' == '' And $([MSBuild]::IsOSPlatform('Windows')) And '$([System.Runtime.InteropServices.RuntimeInformation]::OSArchitecture)' == 'Arm64'">win-arm64 + <_CopilotRid Condition="'$(_CopilotRid)' == '' And $([MSBuild]::IsOSPlatform('Linux')) And '$([System.Runtime.InteropServices.RuntimeInformation]::OSArchitecture)' == 'X64'">linux-x64 + <_CopilotRid Condition="'$(_CopilotRid)' == '' And $([MSBuild]::IsOSPlatform('Linux')) And '$([System.Runtime.InteropServices.RuntimeInformation]::OSArchitecture)' == 'Arm64'">linux-arm64 + <_CopilotRid Condition="'$(_CopilotRid)' == '' And $([MSBuild]::IsOSPlatform('OSX')) And '$([System.Runtime.InteropServices.RuntimeInformation]::OSArchitecture)' == 'X64'">osx-x64 + <_CopilotRid Condition="'$(_CopilotRid)' == '' And $([MSBuild]::IsOSPlatform('OSX')) And '$([System.Runtime.InteropServices.RuntimeInformation]::OSArchitecture)' == 'Arm64'">osx-arm64 + + + + + <_CopilotPlatform Condition="'$(_CopilotRid)' == 'win-x64'">win32-x64 + <_CopilotPlatform Condition="'$(_CopilotRid)' == 'win-arm64'">win32-arm64 + <_CopilotPlatform Condition="'$(_CopilotRid)' == 'linux-x64'">linux-x64 + <_CopilotPlatform Condition="'$(_CopilotRid)' == 'linux-arm64'">linux-arm64 + <_CopilotPlatform Condition="'$(_CopilotRid)' == 'osx-x64'">darwin-x64 + <_CopilotPlatform Condition="'$(_CopilotRid)' == 'osx-arm64'">darwin-arm64 + <_CopilotBinary Condition="$(_CopilotRid.StartsWith('win-'))">copilot.exe + <_CopilotBinary Condition="'$(_CopilotBinary)' == ''">copilot + + + + + + + + + <_CopilotCacheDir>$(IntermediateOutputPath)copilot-cli\$(CopilotCliVersion)\$(_CopilotPlatform) + <_CopilotCliBinaryPath>$(_CopilotCacheDir)\$(_CopilotBinary) + <_CopilotArchivePath>$(_CopilotCacheDir)\copilot.tgz + <_CopilotDownloadUrl>https://registry.npmjs.org/@github/copilot-$(_CopilotPlatform)/-/copilot-$(_CopilotPlatform)-$(CopilotCliVersion).tgz + + + + + + + + + + + + + <_TarCommand Condition="$([MSBuild]::IsOSPlatform('Windows'))">$(SystemRoot)\System32\tar.exe + <_TarCommand Condition="'$(_TarCommand)' == ''">tar + + + + + + + + + + <_CopilotCacheDir>$(IntermediateOutputPath)copilot-cli\$(CopilotCliVersion)\$(_CopilotPlatform) + <_CopilotCliBinaryPath>$(_CopilotCacheDir)\$(_CopilotBinary) + <_CopilotOutputDir>$(OutDir)runtimes\$(_CopilotRid)\native + + + + + + + + + <_CopilotCacheDir>$(IntermediateOutputPath)copilot-cli\$(CopilotCliVersion)\$(_CopilotPlatform) + <_CopilotCliBinaryPath>$(_CopilotCacheDir)\$(_CopilotBinary) + + + + + + diff --git a/dotnet/test/ClientTests.cs b/dotnet/test/ClientTests.cs index f433e677c..e3419f981 100644 --- a/dotnet/test/ClientTests.cs +++ b/dotnet/test/ClientTests.cs @@ -8,37 +8,12 @@ namespace GitHub.Copilot.SDK.Test; // These tests bypass E2ETestBase because they are about how the CLI subprocess is started // Other test classes should instead inherit from E2ETestBase -public class ClientTests : IAsyncLifetime +public class ClientTests { - private string _cliPath = null!; - - public Task InitializeAsync() - { - _cliPath = GetCliPath(); - return Task.CompletedTask; - } - - public Task DisposeAsync() => Task.CompletedTask; - - private static string GetCliPath() - { - var envPath = Environment.GetEnvironmentVariable("COPILOT_CLI_PATH"); - if (!string.IsNullOrEmpty(envPath)) return envPath; - - var dir = new DirectoryInfo(AppContext.BaseDirectory); - while (dir != null) - { - var path = Path.Combine(dir.FullName, "nodejs/node_modules/@github/copilot/index.js"); - if (File.Exists(path)) return path; - dir = dir.Parent; - } - throw new InvalidOperationException("CLI not found. Run 'npm install' in the nodejs directory first."); - } - [Fact] public async Task Should_Start_And_Connect_To_Server_Using_Stdio() { - using var client = new CopilotClient(new CopilotClientOptions { CliPath = _cliPath, UseStdio = true }); + using var client = new CopilotClient(new CopilotClientOptions { UseStdio = true }); try { @@ -61,7 +36,7 @@ public async Task Should_Start_And_Connect_To_Server_Using_Stdio() [Fact] public async Task Should_Start_And_Connect_To_Server_Using_Tcp() { - using var client = new CopilotClient(new CopilotClientOptions { CliPath = _cliPath, UseStdio = false }); + using var client = new CopilotClient(new CopilotClientOptions { UseStdio = false }); try { @@ -82,7 +57,7 @@ public async Task Should_Start_And_Connect_To_Server_Using_Tcp() [Fact] public async Task Should_Force_Stop_Without_Cleanup() { - using var client = new CopilotClient(new CopilotClientOptions { CliPath = _cliPath }); + using var client = new CopilotClient(new CopilotClientOptions()); await client.CreateSessionAsync(); await client.ForceStopAsync(); @@ -93,7 +68,7 @@ public async Task Should_Force_Stop_Without_Cleanup() [Fact] public async Task Should_Get_Status_With_Version_And_Protocol_Info() { - using var client = new CopilotClient(new CopilotClientOptions { CliPath = _cliPath, UseStdio = true }); + using var client = new CopilotClient(new CopilotClientOptions { UseStdio = true }); try { @@ -115,7 +90,7 @@ public async Task Should_Get_Status_With_Version_And_Protocol_Info() [Fact] public async Task Should_Get_Auth_Status() { - using var client = new CopilotClient(new CopilotClientOptions { CliPath = _cliPath, UseStdio = true }); + using var client = new CopilotClient(new CopilotClientOptions { UseStdio = true }); try { @@ -140,7 +115,7 @@ public async Task Should_Get_Auth_Status() [Fact] public async Task Should_List_Models_When_Authenticated() { - using var client = new CopilotClient(new CopilotClientOptions { CliPath = _cliPath, UseStdio = true }); + using var client = new CopilotClient(new CopilotClientOptions { UseStdio = true }); try { @@ -178,7 +153,6 @@ public void Should_Accept_GithubToken_Option() { var options = new CopilotClientOptions { - CliPath = _cliPath, GithubToken = "gho_test_token" }; @@ -188,7 +162,7 @@ public void Should_Accept_GithubToken_Option() [Fact] public void Should_Default_UseLoggedInUser_To_Null() { - var options = new CopilotClientOptions { CliPath = _cliPath }; + var options = new CopilotClientOptions(); Assert.Null(options.UseLoggedInUser); } @@ -198,7 +172,6 @@ public void Should_Allow_Explicit_UseLoggedInUser_False() { var options = new CopilotClientOptions { - CliPath = _cliPath, UseLoggedInUser = false }; @@ -210,7 +183,6 @@ public void Should_Allow_Explicit_UseLoggedInUser_True_With_GithubToken() { var options = new CopilotClientOptions { - CliPath = _cliPath, GithubToken = "gho_test_token", UseLoggedInUser = true }; diff --git a/dotnet/test/Harness/E2ETestContext.cs b/dotnet/test/Harness/E2ETestContext.cs index b8727ed5c..2518ca69e 100644 --- a/dotnet/test/Harness/E2ETestContext.cs +++ b/dotnet/test/Harness/E2ETestContext.cs @@ -9,7 +9,6 @@ namespace GitHub.Copilot.SDK.Test.Harness; public class E2ETestContext : IAsyncDisposable { - public string CliPath { get; } public string HomeDir { get; } public string WorkDir { get; } public string ProxyUrl { get; } @@ -17,9 +16,8 @@ public class E2ETestContext : IAsyncDisposable private readonly CapiProxy _proxy; private readonly string _repoRoot; - private E2ETestContext(string cliPath, string homeDir, string workDir, string proxyUrl, CapiProxy proxy, string repoRoot) + private E2ETestContext(string homeDir, string workDir, string proxyUrl, CapiProxy proxy, string repoRoot) { - CliPath = cliPath; HomeDir = homeDir; WorkDir = workDir; ProxyUrl = proxyUrl; @@ -30,7 +28,6 @@ private E2ETestContext(string cliPath, string homeDir, string workDir, string pr public static async Task CreateAsync() { var repoRoot = FindRepoRoot(); - var cliPath = GetCliPath(repoRoot); var homeDir = Path.Combine(Path.GetTempPath(), $"copilot-test-config-{Guid.NewGuid()}"); var workDir = Path.Combine(Path.GetTempPath(), $"copilot-test-work-{Guid.NewGuid()}"); @@ -41,7 +38,7 @@ public static async Task CreateAsync() var proxy = new CapiProxy(); var proxyUrl = await proxy.StartAsync(); - return new E2ETestContext(cliPath, homeDir, workDir, proxyUrl, proxy, repoRoot); + return new E2ETestContext(homeDir, workDir, proxyUrl, proxy, repoRoot); } private static string FindRepoRoot() @@ -94,7 +91,6 @@ public IReadOnlyDictionary GetEnvironment() public CopilotClient CreateClient() => new(new CopilotClientOptions { - CliPath = CliPath, Cwd = WorkDir, Environment = GetEnvironment(), GithubToken = !string.IsNullOrEmpty(Environment.GetEnvironmentVariable("CI")) ? "fake-token-for-e2e-tests" : null, diff --git a/go/client.go b/go/client.go index 4e680ac2f..319c6588c 100644 --- a/go/client.go +++ b/go/client.go @@ -396,36 +396,6 @@ func (c *Client) ForceStop() { } } -// buildProviderParams converts a ProviderConfig to a map for JSON-RPC params. -func buildProviderParams(p *ProviderConfig) map[string]any { - params := make(map[string]any) - if p.Type != "" { - params["type"] = p.Type - } - if p.WireApi != "" { - params["wireApi"] = p.WireApi - } - if p.BaseURL != "" { - params["baseUrl"] = p.BaseURL - } - if p.APIKey != "" { - params["apiKey"] = p.APIKey - } - if p.BearerToken != "" { - params["bearerToken"] = p.BearerToken - } - if p.Azure != nil { - azure := make(map[string]any) - if p.Azure.APIVersion != "" { - azure["apiVersion"] = p.Azure.APIVersion - } - if len(azure) > 0 { - params["azure"] = azure - } - } - return params -} - func (c *Client) ensureConnected() error { if c.client != nil { return nil @@ -467,166 +437,54 @@ func (c *Client) CreateSession(ctx context.Context, config *SessionConfig) (*Ses return nil, err } - params := make(map[string]any) + req := createSessionRequest{} if config != nil { - if config.Model != "" { - params["model"] = config.Model - } - if config.SessionID != "" { - params["sessionId"] = config.SessionID - } - if config.ReasoningEffort != "" { - params["reasoningEffort"] = config.ReasoningEffort - } - if len(config.Tools) > 0 { - toolDefs := make([]map[string]any, 0, len(config.Tools)) - for _, tool := range config.Tools { - if tool.Name == "" { - continue - } - definition := map[string]any{ - "name": tool.Name, - "description": tool.Description, - } - if tool.Parameters != nil { - definition["parameters"] = tool.Parameters - } - toolDefs = append(toolDefs, definition) - } - if len(toolDefs) > 0 { - params["tools"] = toolDefs - } - } - // Add system message configuration if provided - if config.SystemMessage != nil { - systemMessage := make(map[string]any) + req.Model = config.Model + req.SessionID = config.SessionID + req.ReasoningEffort = config.ReasoningEffort + req.ConfigDir = config.ConfigDir + req.Tools = config.Tools + req.SystemMessage = config.SystemMessage + req.AvailableTools = config.AvailableTools + req.ExcludedTools = config.ExcludedTools + req.Provider = config.Provider + req.WorkingDirectory = config.WorkingDirectory + req.MCPServers = config.MCPServers + req.CustomAgents = config.CustomAgents + req.SkillDirectories = config.SkillDirectories + req.DisabledSkills = config.DisabledSkills + req.InfiniteSessions = config.InfiniteSessions - if config.SystemMessage.Mode != "" { - systemMessage["mode"] = config.SystemMessage.Mode - } - - if config.SystemMessage.Mode == "replace" { - if config.SystemMessage.Content != "" { - systemMessage["content"] = config.SystemMessage.Content - } - } else { - if config.SystemMessage.Content != "" { - systemMessage["content"] = config.SystemMessage.Content - } - } - - if len(systemMessage) > 0 { - params["systemMessage"] = systemMessage - } - } - // Add tool filtering options - if len(config.AvailableTools) > 0 { - params["availableTools"] = config.AvailableTools - } - if len(config.ExcludedTools) > 0 { - params["excludedTools"] = config.ExcludedTools - } - // Add streaming option if config.Streaming { - params["streaming"] = config.Streaming - } - // Add provider configuration - if config.Provider != nil { - params["provider"] = buildProviderParams(config.Provider) + req.Streaming = Bool(true) } - // Add permission request flag if config.OnPermissionRequest != nil { - params["requestPermission"] = true + req.RequestPermission = Bool(true) } - // Add user input request flag if config.OnUserInputRequest != nil { - params["requestUserInput"] = true + req.RequestUserInput = Bool(true) } - // Add hooks flag if config.Hooks != nil && (config.Hooks.OnPreToolUse != nil || config.Hooks.OnPostToolUse != nil || config.Hooks.OnUserPromptSubmitted != nil || config.Hooks.OnSessionStart != nil || config.Hooks.OnSessionEnd != nil || config.Hooks.OnErrorOccurred != nil) { - params["hooks"] = true - } - // Add working directory - if config.WorkingDirectory != "" { - params["workingDirectory"] = config.WorkingDirectory - } - // Add MCP servers configuration - if len(config.MCPServers) > 0 { - params["mcpServers"] = config.MCPServers - } - // Add custom agents configuration - if len(config.CustomAgents) > 0 { - customAgents := make([]map[string]any, 0, len(config.CustomAgents)) - for _, agent := range config.CustomAgents { - agentMap := map[string]any{ - "name": agent.Name, - "prompt": agent.Prompt, - } - if agent.DisplayName != "" { - agentMap["displayName"] = agent.DisplayName - } - if agent.Description != "" { - agentMap["description"] = agent.Description - } - if len(agent.Tools) > 0 { - agentMap["tools"] = agent.Tools - } - if len(agent.MCPServers) > 0 { - agentMap["mcpServers"] = agent.MCPServers - } - if agent.Infer != nil { - agentMap["infer"] = *agent.Infer - } - customAgents = append(customAgents, agentMap) - } - params["customAgents"] = customAgents - } - // Add config directory override - if config.ConfigDir != "" { - params["configDir"] = config.ConfigDir - } - // Add skill directories configuration - if len(config.SkillDirectories) > 0 { - params["skillDirectories"] = config.SkillDirectories - } - // Add disabled skills configuration - if len(config.DisabledSkills) > 0 { - params["disabledSkills"] = config.DisabledSkills - } - // Add infinite sessions configuration - if config.InfiniteSessions != nil { - infiniteSessions := make(map[string]any) - if config.InfiniteSessions.Enabled != nil { - infiniteSessions["enabled"] = *config.InfiniteSessions.Enabled - } - if config.InfiniteSessions.BackgroundCompactionThreshold != nil { - infiniteSessions["backgroundCompactionThreshold"] = *config.InfiniteSessions.BackgroundCompactionThreshold - } - if config.InfiniteSessions.BufferExhaustionThreshold != nil { - infiniteSessions["bufferExhaustionThreshold"] = *config.InfiniteSessions.BufferExhaustionThreshold - } - params["infiniteSessions"] = infiniteSessions + req.Hooks = Bool(true) } } - result, err := c.client.Request("session.create", params) + result, err := c.client.Request("session.create", req) if err != nil { return nil, fmt.Errorf("failed to create session: %w", err) } - sessionID, ok := result["sessionId"].(string) - if !ok { - return nil, fmt.Errorf("invalid response: missing sessionId") + var response createSessionResponse + if err := json.Unmarshal(result, &response); err != nil { + return nil, fmt.Errorf("failed to unmarshal response: %w", err) } - workspacePath, _ := result["workspacePath"].(string) - - session := newSession(sessionID, c.client, workspacePath) + session := newSession(response.SessionID, c.client, response.WorkspacePath) if config != nil { session.registerTools(config.Tools) @@ -644,7 +502,7 @@ func (c *Client) CreateSession(ctx context.Context, config *SessionConfig) (*Ses } c.sessionsMux.Lock() - c.sessions[sessionID] = session + c.sessions[response.SessionID] = session c.sessionsMux.Unlock() return session, nil @@ -676,167 +534,56 @@ func (c *Client) ResumeSessionWithOptions(ctx context.Context, sessionID string, return nil, err } - params := map[string]any{ - "sessionId": sessionID, - } - + var req resumeSessionRequest + req.SessionID = sessionID if config != nil { - if config.Model != "" { - params["model"] = config.Model - } - if config.ReasoningEffort != "" { - params["reasoningEffort"] = config.ReasoningEffort - } - if config.SystemMessage != nil { - systemMessage := make(map[string]any) - - if config.SystemMessage.Mode != "" { - systemMessage["mode"] = config.SystemMessage.Mode - } - - if config.SystemMessage.Mode == "replace" { - if config.SystemMessage.Content != "" { - systemMessage["content"] = config.SystemMessage.Content - } - } else { - if config.SystemMessage.Content != "" { - systemMessage["content"] = config.SystemMessage.Content - } - } - - if len(systemMessage) > 0 { - params["systemMessage"] = systemMessage - } - } - if len(config.AvailableTools) > 0 { - params["availableTools"] = config.AvailableTools - } - if len(config.ExcludedTools) > 0 { - params["excludedTools"] = config.ExcludedTools - } - if len(config.Tools) > 0 { - toolDefs := make([]map[string]any, 0, len(config.Tools)) - for _, tool := range config.Tools { - if tool.Name == "" { - continue - } - definition := map[string]any{ - "name": tool.Name, - "description": tool.Description, - } - if tool.Parameters != nil { - definition["parameters"] = tool.Parameters - } - toolDefs = append(toolDefs, definition) - } - if len(toolDefs) > 0 { - params["tools"] = toolDefs - } - } - if config.Provider != nil { - params["provider"] = buildProviderParams(config.Provider) - } - // Add streaming option + req.Model = config.Model + req.ReasoningEffort = config.ReasoningEffort + req.SystemMessage = config.SystemMessage + req.Tools = config.Tools + req.Provider = config.Provider + req.AvailableTools = config.AvailableTools + req.ExcludedTools = config.ExcludedTools if config.Streaming { - params["streaming"] = config.Streaming + req.Streaming = Bool(true) } - // Add permission request flag if config.OnPermissionRequest != nil { - params["requestPermission"] = true + req.RequestPermission = Bool(true) } - // Add user input request flag if config.OnUserInputRequest != nil { - params["requestUserInput"] = true + req.RequestUserInput = Bool(true) } - // Add hooks flag if config.Hooks != nil && (config.Hooks.OnPreToolUse != nil || config.Hooks.OnPostToolUse != nil || config.Hooks.OnUserPromptSubmitted != nil || config.Hooks.OnSessionStart != nil || config.Hooks.OnSessionEnd != nil || config.Hooks.OnErrorOccurred != nil) { - params["hooks"] = true - } - // Add working directory - if config.WorkingDirectory != "" { - params["workingDirectory"] = config.WorkingDirectory + req.Hooks = Bool(true) } - // Add config directory - if config.ConfigDir != "" { - params["configDir"] = config.ConfigDir - } - // Add disable resume flag + req.WorkingDirectory = config.WorkingDirectory + req.ConfigDir = config.ConfigDir if config.DisableResume { - params["disableResume"] = true - } - // Add MCP servers configuration - if len(config.MCPServers) > 0 { - params["mcpServers"] = config.MCPServers - } - // Add custom agents configuration - if len(config.CustomAgents) > 0 { - customAgents := make([]map[string]any, 0, len(config.CustomAgents)) - for _, agent := range config.CustomAgents { - agentMap := map[string]any{ - "name": agent.Name, - "prompt": agent.Prompt, - } - if agent.DisplayName != "" { - agentMap["displayName"] = agent.DisplayName - } - if agent.Description != "" { - agentMap["description"] = agent.Description - } - if len(agent.Tools) > 0 { - agentMap["tools"] = agent.Tools - } - if len(agent.MCPServers) > 0 { - agentMap["mcpServers"] = agent.MCPServers - } - if agent.Infer != nil { - agentMap["infer"] = *agent.Infer - } - customAgents = append(customAgents, agentMap) - } - params["customAgents"] = customAgents - } - // Add skill directories configuration - if len(config.SkillDirectories) > 0 { - params["skillDirectories"] = config.SkillDirectories - } - // Add disabled skills configuration - if len(config.DisabledSkills) > 0 { - params["disabledSkills"] = config.DisabledSkills - } - // Add infinite sessions configuration - if config.InfiniteSessions != nil { - infiniteSessions := map[string]any{} - if config.InfiniteSessions.Enabled != nil { - infiniteSessions["enabled"] = *config.InfiniteSessions.Enabled - } - if config.InfiniteSessions.BackgroundCompactionThreshold != nil { - infiniteSessions["backgroundCompactionThreshold"] = *config.InfiniteSessions.BackgroundCompactionThreshold - } - if config.InfiniteSessions.BufferExhaustionThreshold != nil { - infiniteSessions["bufferExhaustionThreshold"] = *config.InfiniteSessions.BufferExhaustionThreshold - } - params["infiniteSessions"] = infiniteSessions + req.DisableResume = Bool(true) } + req.MCPServers = config.MCPServers + req.CustomAgents = config.CustomAgents + req.SkillDirectories = config.SkillDirectories + req.DisabledSkills = config.DisabledSkills + req.InfiniteSessions = config.InfiniteSessions } - result, err := c.client.Request("session.resume", params) + result, err := c.client.Request("session.resume", req) if err != nil { return nil, fmt.Errorf("failed to resume session: %w", err) } - resumedSessionID, ok := result["sessionId"].(string) - if !ok { - return nil, fmt.Errorf("invalid response: missing sessionId") + var response resumeSessionResponse + if err := json.Unmarshal(result, &response); err != nil { + return nil, fmt.Errorf("failed to unmarshal response: %w", err) } - workspacePath, _ := result["workspacePath"].(string) - - session := newSession(resumedSessionID, c.client, workspacePath) + session := newSession(response.SessionID, c.client, response.WorkspacePath) if config != nil { session.registerTools(config.Tools) if config.OnPermissionRequest != nil { @@ -853,7 +600,7 @@ func (c *Client) ResumeSessionWithOptions(ctx context.Context, sessionID string, } c.sessionsMux.Lock() - c.sessions[resumedSessionID] = session + c.sessions[response.SessionID] = session c.sessionsMux.Unlock() return session, nil @@ -878,19 +625,13 @@ func (c *Client) ListSessions(ctx context.Context) ([]SessionMetadata, error) { return nil, err } - result, err := c.client.Request("session.list", map[string]any{}) + result, err := c.client.Request("session.list", listSessionsRequest{}) if err != nil { return nil, err } - // Marshal and unmarshal to convert map to struct - jsonBytes, err := json.Marshal(result) - if err != nil { - return nil, fmt.Errorf("failed to marshal sessions response: %w", err) - } - - var response ListSessionsResponse - if err := json.Unmarshal(jsonBytes, &response); err != nil { + var response listSessionsResponse + if err := json.Unmarshal(result, &response); err != nil { return nil, fmt.Errorf("failed to unmarshal sessions response: %w", err) } @@ -912,23 +653,13 @@ func (c *Client) DeleteSession(ctx context.Context, sessionID string) error { return err } - params := map[string]any{ - "sessionId": sessionID, - } - - result, err := c.client.Request("session.delete", params) + result, err := c.client.Request("session.delete", deleteSessionRequest{SessionID: sessionID}) if err != nil { return err } - // Marshal and unmarshal to convert map to struct - jsonBytes, err := json.Marshal(result) - if err != nil { - return fmt.Errorf("failed to marshal delete response: %w", err) - } - - var response DeleteSessionResponse - if err := json.Unmarshal(jsonBytes, &response); err != nil { + var response deleteSessionResponse + if err := json.Unmarshal(result, &response); err != nil { return fmt.Errorf("failed to unmarshal delete response: %w", err) } @@ -973,18 +704,13 @@ func (c *Client) GetForegroundSessionID(ctx context.Context) (*string, error) { } } - result, err := c.client.Request("session.getForeground", map[string]any{}) + result, err := c.client.Request("session.getForeground", getForegroundSessionRequest{}) if err != nil { return nil, err } - jsonBytes, err := json.Marshal(result) - if err != nil { - return nil, fmt.Errorf("failed to marshal getForeground response: %w", err) - } - - var response GetForegroundSessionResponse - if err := json.Unmarshal(jsonBytes, &response); err != nil { + var response getForegroundSessionResponse + if err := json.Unmarshal(result, &response); err != nil { return nil, fmt.Errorf("failed to unmarshal getForeground response: %w", err) } @@ -1012,22 +738,13 @@ func (c *Client) SetForegroundSessionID(ctx context.Context, sessionID string) e } } - params := map[string]any{ - "sessionId": sessionID, - } - - result, err := c.client.Request("session.setForeground", params) + result, err := c.client.Request("session.setForeground", setForegroundSessionRequest{SessionID: sessionID}) if err != nil { return err } - jsonBytes, err := json.Marshal(result) - if err != nil { - return fmt.Errorf("failed to marshal setForeground response: %w", err) - } - - var response SetForegroundSessionResponse - if err := json.Unmarshal(jsonBytes, &response); err != nil { + var response setForegroundSessionResponse + if err := json.Unmarshal(result, &response); err != nil { return fmt.Errorf("failed to unmarshal setForeground response: %w", err) } @@ -1104,8 +821,8 @@ func (c *Client) OnEventType(eventType SessionLifecycleEventType, handler Sessio } } -// dispatchLifecycleEvent dispatches a lifecycle event to all registered handlers -func (c *Client) dispatchLifecycleEvent(event SessionLifecycleEvent) { +// handleLifecycleEvent dispatches a lifecycle event to all registered handlers +func (c *Client) handleLifecycleEvent(event SessionLifecycleEvent) { c.lifecycleHandlersMux.Lock() // Copy handlers to avoid holding lock during callbacks typedHandlers := make([]SessionLifecycleHandler, 0) @@ -1164,29 +881,16 @@ func (c *Client) Ping(ctx context.Context, message string) (*PingResponse, error return nil, fmt.Errorf("client not connected") } - params := map[string]any{} - if message != "" { - params["message"] = message - } - - result, err := c.client.Request("ping", params) + result, err := c.client.Request("ping", pingRequest{Message: message}) if err != nil { return nil, err } - response := &PingResponse{} - if msg, ok := result["message"].(string); ok { - response.Message = msg - } - if ts, ok := result["timestamp"].(float64); ok { - response.Timestamp = int64(ts) - } - if pv, ok := result["protocolVersion"].(float64); ok { - v := int(pv) - response.ProtocolVersion = &v + var response PingResponse + if err := json.Unmarshal(result, &response); err != nil { + return nil, err } - - return response, nil + return &response, nil } // GetStatus returns CLI status including version and protocol information @@ -1195,20 +899,16 @@ func (c *Client) GetStatus(ctx context.Context) (*GetStatusResponse, error) { return nil, fmt.Errorf("client not connected") } - result, err := c.client.Request("status.get", map[string]any{}) + result, err := c.client.Request("status.get", getStatusRequest{}) if err != nil { return nil, err } - response := &GetStatusResponse{} - if v, ok := result["version"].(string); ok { - response.Version = v - } - if pv, ok := result["protocolVersion"].(float64); ok { - response.ProtocolVersion = int(pv) + var response GetStatusResponse + if err := json.Unmarshal(result, &response); err != nil { + return nil, err } - - return response, nil + return &response, nil } // GetAuthStatus returns current authentication status @@ -1217,29 +917,16 @@ func (c *Client) GetAuthStatus(ctx context.Context) (*GetAuthStatusResponse, err return nil, fmt.Errorf("client not connected") } - result, err := c.client.Request("auth.getStatus", map[string]any{}) + result, err := c.client.Request("auth.getStatus", getAuthStatusRequest{}) if err != nil { return nil, err } - response := &GetAuthStatusResponse{} - if v, ok := result["isAuthenticated"].(bool); ok { - response.IsAuthenticated = v - } - if v, ok := result["authType"].(string); ok { - response.AuthType = &v - } - if v, ok := result["host"].(string); ok { - response.Host = &v - } - if v, ok := result["login"].(string); ok { - response.Login = &v - } - if v, ok := result["statusMessage"].(string); ok { - response.StatusMessage = &v + var response GetAuthStatusResponse + if err := json.Unmarshal(result, &response); err != nil { + return nil, err } - - return response, nil + return &response, nil } // ListModels returns available models with their metadata. @@ -1264,19 +951,13 @@ func (c *Client) ListModels(ctx context.Context) ([]ModelInfo, error) { } // Cache miss - fetch from backend while holding lock - result, err := c.client.Request("models.list", map[string]any{}) + result, err := c.client.Request("models.list", listModelsRequest{}) if err != nil { return nil, err } - // Marshal and unmarshal to convert map to struct - jsonBytes, err := json.Marshal(result) - if err != nil { - return nil, fmt.Errorf("failed to marshal models response: %w", err) - } - - var response GetModelsResponse - if err := json.Unmarshal(jsonBytes, &response); err != nil { + var response listModelsResponse + if err := json.Unmarshal(result, &response); err != nil { return nil, fmt.Errorf("failed to unmarshal models response: %w", err) } @@ -1313,7 +994,7 @@ func (c *Client) verifyProtocolVersion(ctx context.Context) error { // This spawns the CLI server as a subprocess using the configured transport // mode (stdio or TCP). func (c *Client) startCLIServer(ctx context.Context) error { - args := []string{"--headless", "--log-level", c.options.LogLevel} + args := []string{"--headless", "--no-auto-update", "--log-level", c.options.LogLevel} // Choose transport mode if c.useStdio { @@ -1470,82 +1151,48 @@ func (c *Client) connectViaTcp(ctx context.Context) error { // setupNotificationHandler configures handlers for session events, tool calls, and permission requests. func (c *Client) setupNotificationHandler() { - c.client.SetNotificationHandler(func(method string, params map[string]any) { - switch method { - case "session.event": - // Extract sessionId and event - sessionID, ok := params["sessionId"].(string) - if !ok { - return - } - - // Marshal the event back to JSON and unmarshal into typed struct - eventJSON, err := json.Marshal(params["event"]) - if err != nil { - return - } - - event, err := UnmarshalSessionEvent(eventJSON) - if err != nil { - return - } - - // Dispatch to session - c.sessionsMux.Lock() - session, ok := c.sessions[sessionID] - c.sessionsMux.Unlock() - - if ok { - session.dispatchEvent(event) - } - case "session.lifecycle": - // Handle session lifecycle events - eventJSON, err := json.Marshal(params) - if err != nil { - return - } - - var event SessionLifecycleEvent - if err := json.Unmarshal(eventJSON, &event); err != nil { - return - } + c.client.SetRequestHandler("session.event", jsonrpc2.NotificationHandlerFor(c.handleSessionEvent)) + c.client.SetRequestHandler("session.lifecycle", jsonrpc2.NotificationHandlerFor(c.handleLifecycleEvent)) + c.client.SetRequestHandler("tool.call", jsonrpc2.RequestHandlerFor(c.handleToolCallRequest)) + c.client.SetRequestHandler("permission.request", jsonrpc2.RequestHandlerFor(c.handlePermissionRequest)) + c.client.SetRequestHandler("userInput.request", jsonrpc2.RequestHandlerFor(c.handleUserInputRequest)) + c.client.SetRequestHandler("hooks.invoke", jsonrpc2.RequestHandlerFor(c.handleHooksInvoke)) +} - c.dispatchLifecycleEvent(event) - } - }) +func (c *Client) handleSessionEvent(req sessionEventRequest) { + if req.SessionID == "" { + return + } + // Dispatch to session + c.sessionsMux.Lock() + session, ok := c.sessions[req.SessionID] + c.sessionsMux.Unlock() - c.client.SetRequestHandler("tool.call", c.handleToolCallRequest) - c.client.SetRequestHandler("permission.request", c.handlePermissionRequest) - c.client.SetRequestHandler("userInput.request", c.handleUserInputRequest) - c.client.SetRequestHandler("hooks.invoke", c.handleHooksInvoke) + if ok { + session.dispatchEvent(req.Event) + } } // handleToolCallRequest handles a tool call request from the CLI server. -func (c *Client) handleToolCallRequest(params map[string]any) (map[string]any, *jsonrpc2.Error) { - sessionID, _ := params["sessionId"].(string) - toolCallID, _ := params["toolCallId"].(string) - toolName, _ := params["toolName"].(string) - - if sessionID == "" || toolCallID == "" || toolName == "" { +func (c *Client) handleToolCallRequest(req toolCallRequest) (*toolCallResponse, *jsonrpc2.Error) { + if req.SessionID == "" || req.ToolCallID == "" || req.ToolName == "" { return nil, &jsonrpc2.Error{Code: -32602, Message: "invalid tool call payload"} } c.sessionsMux.Lock() - session, ok := c.sessions[sessionID] + session, ok := c.sessions[req.SessionID] c.sessionsMux.Unlock() if !ok { - return nil, &jsonrpc2.Error{Code: -32602, Message: fmt.Sprintf("unknown session %s", sessionID)} + return nil, &jsonrpc2.Error{Code: -32602, Message: fmt.Sprintf("unknown session %s", req.SessionID)} } - handler, ok := session.getToolHandler(toolName) + handler, ok := session.getToolHandler(req.ToolName) if !ok { - return map[string]any{"result": buildUnsupportedToolResult(toolName)}, nil + return &toolCallResponse{Result: buildUnsupportedToolResult(req.ToolName)}, nil } - arguments := params["arguments"] - result := c.executeToolCall(sessionID, toolCallID, toolName, arguments, handler) - - return map[string]any{"result": result}, nil + result := c.executeToolCall(req.SessionID, req.ToolCallID, req.ToolName, req.Arguments, handler) + return &toolCallResponse{Result: result}, nil } // executeToolCall executes a tool handler and returns the result. @@ -1579,100 +1226,70 @@ func (c *Client) executeToolCall( } // handlePermissionRequest handles a permission request from the CLI server. -func (c *Client) handlePermissionRequest(params map[string]any) (map[string]any, *jsonrpc2.Error) { - sessionID, _ := params["sessionId"].(string) - permissionRequest, _ := params["permissionRequest"].(map[string]any) - - if sessionID == "" { +func (c *Client) handlePermissionRequest(req permissionRequestRequest) (*permissionRequestResponse, *jsonrpc2.Error) { + if req.SessionID == "" { return nil, &jsonrpc2.Error{Code: -32602, Message: "invalid permission request payload"} } c.sessionsMux.Lock() - session, ok := c.sessions[sessionID] + session, ok := c.sessions[req.SessionID] c.sessionsMux.Unlock() if !ok { - return nil, &jsonrpc2.Error{Code: -32602, Message: fmt.Sprintf("unknown session %s", sessionID)} + return nil, &jsonrpc2.Error{Code: -32602, Message: fmt.Sprintf("unknown session %s", req.SessionID)} } - result, err := session.handlePermissionRequest(permissionRequest) + result, err := session.handlePermissionRequest(req.Request) if err != nil { // Return denial on error - return map[string]any{ - "result": map[string]any{ - "kind": "denied-no-approval-rule-and-could-not-request-from-user", + return &permissionRequestResponse{ + Result: PermissionRequestResult{ + Kind: "denied-no-approval-rule-and-could-not-request-from-user", }, }, nil } - return map[string]any{"result": result}, nil + return &permissionRequestResponse{Result: result}, nil } // handleUserInputRequest handles a user input request from the CLI server. -func (c *Client) handleUserInputRequest(params map[string]any) (map[string]any, *jsonrpc2.Error) { - sessionID, _ := params["sessionId"].(string) - question, _ := params["question"].(string) - - if sessionID == "" || question == "" { +func (c *Client) handleUserInputRequest(req userInputRequest) (*userInputResponse, *jsonrpc2.Error) { + if req.SessionID == "" || req.Question == "" { return nil, &jsonrpc2.Error{Code: -32602, Message: "invalid user input request payload"} } c.sessionsMux.Lock() - session, ok := c.sessions[sessionID] + session, ok := c.sessions[req.SessionID] c.sessionsMux.Unlock() if !ok { - return nil, &jsonrpc2.Error{Code: -32602, Message: fmt.Sprintf("unknown session %s", sessionID)} + return nil, &jsonrpc2.Error{Code: -32602, Message: fmt.Sprintf("unknown session %s", req.SessionID)} } - // Parse choices - var choices []string - if choicesRaw, ok := params["choices"].([]any); ok { - for _, choice := range choicesRaw { - if s, ok := choice.(string); ok { - choices = append(choices, s) - } - } - } - - var allowFreeform *bool - if af, ok := params["allowFreeform"].(bool); ok { - allowFreeform = &af - } - - request := UserInputRequest{ - Question: question, - Choices: choices, - AllowFreeform: allowFreeform, - } - - response, err := session.handleUserInputRequest(request) + response, err := session.handleUserInputRequest(UserInputRequest{ + Question: req.Question, + Choices: req.Choices, + AllowFreeform: req.AllowFreeform, + }) if err != nil { return nil, &jsonrpc2.Error{Code: -32603, Message: err.Error()} } - return map[string]any{ - "answer": response.Answer, - "wasFreeform": response.WasFreeform, - }, nil + return &userInputResponse{Answer: response.Answer, WasFreeform: response.WasFreeform}, nil } // handleHooksInvoke handles a hooks invocation from the CLI server. -func (c *Client) handleHooksInvoke(params map[string]any) (map[string]any, *jsonrpc2.Error) { - sessionID, _ := params["sessionId"].(string) - hookType, _ := params["hookType"].(string) - input, _ := params["input"].(map[string]any) - - if sessionID == "" || hookType == "" { +func (c *Client) handleHooksInvoke(req hooksInvokeRequest) (map[string]any, *jsonrpc2.Error) { + if req.SessionID == "" || req.Type == "" { return nil, &jsonrpc2.Error{Code: -32602, Message: "invalid hooks invoke payload"} } c.sessionsMux.Lock() - session, ok := c.sessions[sessionID] + session, ok := c.sessions[req.SessionID] c.sessionsMux.Unlock() if !ok { - return nil, &jsonrpc2.Error{Code: -32602, Message: fmt.Sprintf("unknown session %s", sessionID)} + return nil, &jsonrpc2.Error{Code: -32602, Message: fmt.Sprintf("unknown session %s", req.SessionID)} } - output, err := session.handleHooksInvoke(hookType, input) + output, err := session.handleHooksInvoke(req.Type, req.Input) if err != nil { return nil, &jsonrpc2.Error{Code: -32603, Message: err.Error()} } diff --git a/go/client_test.go b/go/client_test.go index 185bb4cbc..176dad8c5 100644 --- a/go/client_test.go +++ b/go/client_test.go @@ -25,25 +25,20 @@ func TestClient_HandleToolCallRequest(t *testing.T) { t.Fatalf("Failed to create session: %v", err) } - params := map[string]any{ - "sessionId": session.SessionID, - "toolCallId": "123", - "toolName": "missing_tool", - "arguments": map[string]any{}, + params := toolCallRequest{ + SessionID: session.SessionID, + ToolCallID: "123", + ToolName: "missing_tool", + Arguments: map[string]any{}, } response, _ := client.handleToolCallRequest(params) - result, ok := response["result"].(ToolResult) - if !ok { - t.Fatalf("Expected result to be ToolResult, got %T", response["result"]) + if response.Result.ResultType != "failure" { + t.Errorf("Expected resultType to be 'failure', got %q", response.Result.ResultType) } - if result.ResultType != "failure" { - t.Errorf("Expected resultType to be 'failure', got %q", result.ResultType) - } - - if result.Error != "tool 'missing_tool' not supported" { - t.Errorf("Expected error to be \"tool 'missing_tool' not supported\", got %q", result.Error) + if response.Result.Error != "tool 'missing_tool' not supported" { + t.Errorf("Expected error to be \"tool 'missing_tool' not supported\", got %q", response.Result.Error) } }) } diff --git a/go/generated_session_events.go b/go/generated_session_events.go index ec3e6c17d..ec4de9bea 100644 --- a/go/generated_session_events.go +++ b/go/generated_session_events.go @@ -2,7 +2,7 @@ // // Generated from: @github/copilot/session-events.schema.json // Generated by: scripts/generate-session-types.ts -// Generated at: 2026-02-03T20:40:49.610Z +// Generated at: 2026-02-06T20:38:23.463Z // // To update these types: // 1. Update the schema in copilot-agent-runtime @@ -92,6 +92,7 @@ type Data struct { PostCompactionTokens *float64 `json:"postCompactionTokens,omitempty"` PreCompactionMessagesLength *float64 `json:"preCompactionMessagesLength,omitempty"` PreCompactionTokens *float64 `json:"preCompactionTokens,omitempty"` + RequestID *string `json:"requestId,omitempty"` Success *bool `json:"success,omitempty"` SummaryContent *string `json:"summaryContent,omitempty"` TokensRemoved *float64 `json:"tokensRemoved,omitempty"` diff --git a/go/internal/jsonrpc2/jsonrpc2.go b/go/internal/jsonrpc2/jsonrpc2.go index 8e4a0f6a0..e44e12315 100644 --- a/go/internal/jsonrpc2/jsonrpc2.go +++ b/go/internal/jsonrpc2/jsonrpc2.go @@ -6,6 +6,7 @@ import ( "encoding/json" "fmt" "io" + "reflect" "sync" ) @@ -23,43 +24,39 @@ func (e *Error) Error() string { // Request represents a JSON-RPC 2.0 request type Request struct { JSONRPC string `json:"jsonrpc"` - ID json.RawMessage `json:"id"` + ID json.RawMessage `json:"id"` // nil for notifications Method string `json:"method"` - Params map[string]any `json:"params"` + Params json.RawMessage `json:"params"` +} + +func (r *Request) IsCall() bool { + return len(r.ID) > 0 } // Response represents a JSON-RPC 2.0 response type Response struct { JSONRPC string `json:"jsonrpc"` ID json.RawMessage `json:"id,omitempty"` - Result map[string]any `json:"result,omitempty"` + Result json.RawMessage `json:"result,omitempty"` Error *Error `json:"error,omitempty"` } -// Notification represents a JSON-RPC 2.0 notification -type Notification struct { - JSONRPC string `json:"jsonrpc"` - Method string `json:"method"` - Params map[string]any `json:"params"` -} - // NotificationHandler handles incoming notifications -type NotificationHandler func(method string, params map[string]any) +type NotificationHandler func(method string, params json.RawMessage) // RequestHandler handles incoming server requests and returns a result or error -type RequestHandler func(params map[string]any) (map[string]any, *Error) +type RequestHandler func(params json.RawMessage) (json.RawMessage, *Error) // Client is a minimal JSON-RPC 2.0 client for stdio transport type Client struct { - stdin io.WriteCloser - stdout io.ReadCloser - mu sync.Mutex - pendingRequests map[string]chan *Response - notificationHandler NotificationHandler - requestHandlers map[string]RequestHandler - running bool - stopChan chan struct{} - wg sync.WaitGroup + stdin io.WriteCloser + stdout io.ReadCloser + mu sync.Mutex + pendingRequests map[string]chan *Response + requestHandlers map[string]RequestHandler + running bool + stopChan chan struct{} + wg sync.WaitGroup } // NewClient creates a new JSON-RPC client @@ -96,11 +93,55 @@ func (c *Client) Stop() { c.wg.Wait() } -// SetNotificationHandler sets the handler for incoming notifications -func (c *Client) SetNotificationHandler(handler NotificationHandler) { - c.mu.Lock() - defer c.mu.Unlock() - c.notificationHandler = handler +func NotificationHandlerFor[In any](handler func(params In)) RequestHandler { + return func(params json.RawMessage) (json.RawMessage, *Error) { + var in In + // If In is a pointer type, allocate the underlying value and unmarshal into it directly + var target any = &in + if t := reflect.TypeFor[In](); t.Kind() == reflect.Pointer { + in = reflect.New(t.Elem()).Interface().(In) + target = in + } + if err := json.Unmarshal(params, target); err != nil { + return nil, &Error{ + Code: -32602, + Message: fmt.Sprintf("Invalid params: %v", err), + } + } + handler(in) + return nil, nil + } +} + +// RequestHandlerFor creates a RequestHandler from a typed function +func RequestHandlerFor[In, Out any](handler func(params In) (Out, *Error)) RequestHandler { + return func(params json.RawMessage) (json.RawMessage, *Error) { + var in In + // If In is a pointer type, allocate the underlying value and unmarshal into it directly + var target any = &in + if t := reflect.TypeOf(in); t != nil && t.Kind() == reflect.Pointer { + in = reflect.New(t.Elem()).Interface().(In) + target = in + } + if err := json.Unmarshal(params, target); err != nil { + return nil, &Error{ + Code: -32602, + Message: fmt.Sprintf("Invalid params: %v", err), + } + } + out, errj := handler(in) + if errj != nil { + return nil, errj + } + outData, err := json.Marshal(out) + if err != nil { + return nil, &Error{ + Code: -32603, + Message: fmt.Sprintf("Failed to marshal response: %v", err), + } + } + return outData, nil + } } // SetRequestHandler registers a handler for incoming requests from the server @@ -115,7 +156,7 @@ func (c *Client) SetRequestHandler(method string, handler RequestHandler) { } // Request sends a JSON-RPC request and waits for the response -func (c *Client) Request(method string, params map[string]any) (map[string]any, error) { +func (c *Client) Request(method string, params any) (json.RawMessage, error) { requestID := generateUUID() // Create response channel @@ -131,12 +172,17 @@ func (c *Client) Request(method string, params map[string]any) (map[string]any, c.mu.Unlock() }() + paramsData, err := json.Marshal(params) + if err != nil { + return nil, fmt.Errorf("failed to marshal params: %w", err) + } + // Send request request := Request{ JSONRPC: "2.0", ID: json.RawMessage(`"` + requestID + `"`), Method: method, - Params: params, + Params: json.RawMessage(paramsData), } if err := c.sendMessage(request); err != nil { @@ -156,11 +202,16 @@ func (c *Client) Request(method string, params map[string]any) (map[string]any, } // Notify sends a JSON-RPC notification (no response expected) -func (c *Client) Notify(method string, params map[string]any) error { - notification := Notification{ +func (c *Client) Notify(method string, params any) error { + paramsData, err := json.Marshal(params) + if err != nil { + return fmt.Errorf("failed to marshal params: %w", err) + } + + notification := Request{ JSONRPC: "2.0", Method: method, - Params: params, + Params: json.RawMessage(paramsData), } return c.sendMessage(notification) } @@ -231,7 +282,7 @@ func (c *Client) readLoop() { // Try to parse as request first (has both ID and Method) var request Request - if err := json.Unmarshal(body, &request); err == nil && request.Method != "" && len(request.ID) > 0 { + if err := json.Unmarshal(body, &request); err == nil && request.Method != "" { c.handleRequest(&request) continue } @@ -242,13 +293,6 @@ func (c *Client) readLoop() { c.handleResponse(&response) continue } - - // Try to parse as notification (has Method but no ID) - var notification Notification - if err := json.Unmarshal(body, ¬ification); err == nil && notification.Method != "" { - c.handleNotification(¬ification) - continue - } } } @@ -270,24 +314,21 @@ func (c *Client) handleResponse(response *Response) { } } -// handleNotification dispatches a notification to the handler -func (c *Client) handleNotification(notification *Notification) { - c.mu.Lock() - handler := c.notificationHandler - c.mu.Unlock() - - if handler != nil { - handler(notification.Method, notification.Params) - } -} - func (c *Client) handleRequest(request *Request) { c.mu.Lock() handler := c.requestHandlers[request.Method] c.mu.Unlock() if handler == nil { - c.sendErrorResponse(request.ID, -32601, fmt.Sprintf("Method not found: %s", request.Method), nil) + if request.IsCall() { + c.sendErrorResponse(request.ID, -32601, fmt.Sprintf("Method not found: %s", request.Method), nil) + } + return + } + + // Notifications run synchronously, calls run in a goroutine to avoid blocking + if !request.IsCall() { + handler(request.Params) return } @@ -303,14 +344,11 @@ func (c *Client) handleRequest(request *Request) { c.sendErrorResponse(request.ID, err.Code, err.Message, err.Data) return } - if result == nil { - result = make(map[string]any) - } c.sendResponse(request.ID, result) }() } -func (c *Client) sendResponse(id json.RawMessage, result map[string]any) { +func (c *Client) sendResponse(id json.RawMessage, result json.RawMessage) { response := Response{ JSONRPC: "2.0", ID: id, diff --git a/go/session.go b/go/session.go index e4f1473df..37cfe52f8 100644 --- a/go/session.go +++ b/go/session.go @@ -106,29 +106,23 @@ func newSession(sessionID string, client *jsonrpc2.Client, workspacePath string) // log.Printf("Failed to send message: %v", err) // } func (s *Session) Send(ctx context.Context, options MessageOptions) (string, error) { - params := map[string]any{ - "sessionId": s.SessionID, - "prompt": options.Prompt, + req := sessionSendRequest{ + SessionID: s.SessionID, + Prompt: options.Prompt, + Attachments: options.Attachments, + Mode: options.Mode, } - if options.Attachments != nil { - params["attachments"] = options.Attachments - } - if options.Mode != "" { - params["mode"] = options.Mode - } - - result, err := s.client.Request("session.send", params) + result, err := s.client.Request("session.send", req) if err != nil { return "", fmt.Errorf("failed to send message: %w", err) } - messageID, ok := result["messageId"].(string) - if !ok { - return "", fmt.Errorf("invalid response: missing messageId") + var response sessionSendResponse + if err := json.Unmarshal(result, &response); err != nil { + return "", fmt.Errorf("failed to unmarshal send response: %w", err) } - - return messageID, nil + return response.MessageID, nil } // SendAndWait sends a message to this session and waits until the session becomes idle. @@ -306,7 +300,7 @@ func (s *Session) getPermissionHandler() PermissionHandler { // handlePermissionRequest handles a permission request from the Copilot CLI. // This is an internal method called by the SDK when the CLI requests permission. -func (s *Session) handlePermissionRequest(requestData map[string]any) (PermissionRequestResult, error) { +func (s *Session) handlePermissionRequest(request PermissionRequest) (PermissionRequestResult, error) { handler := s.getPermissionHandler() if handler == nil { @@ -315,16 +309,6 @@ func (s *Session) handlePermissionRequest(requestData map[string]any) (Permissio }, nil } - // Convert map to PermissionRequest struct - kind, _ := requestData["kind"].(string) - toolCallID, _ := requestData["toolCallId"].(string) - - request := PermissionRequest{ - Kind: kind, - ToolCallID: toolCallID, - Extra: requestData, - } - invocation := PermissionInvocation{ SessionID: s.SessionID, } @@ -388,7 +372,7 @@ func (s *Session) getHooks() *SessionHooks { // handleHooksInvoke handles a hook invocation from the Copilot CLI. // This is an internal method called by the SDK when the CLI invokes a hook. -func (s *Session) handleHooksInvoke(hookType string, input map[string]any) (any, error) { +func (s *Session) handleHooksInvoke(hookType string, rawInput json.RawMessage) (any, error) { hooks := s.getHooks() if hooks == nil { @@ -404,153 +388,66 @@ func (s *Session) handleHooksInvoke(hookType string, input map[string]any) (any, if hooks.OnPreToolUse == nil { return nil, nil } - hookInput := parsePreToolUseInput(input) - return hooks.OnPreToolUse(hookInput, invocation) + var input PreToolUseHookInput + if err := json.Unmarshal(rawInput, &input); err != nil { + return nil, fmt.Errorf("invalid hook input: %w", err) + } + return hooks.OnPreToolUse(input, invocation) case "postToolUse": if hooks.OnPostToolUse == nil { return nil, nil } - hookInput := parsePostToolUseInput(input) - return hooks.OnPostToolUse(hookInput, invocation) + var input PostToolUseHookInput + if err := json.Unmarshal(rawInput, &input); err != nil { + return nil, fmt.Errorf("invalid hook input: %w", err) + } + return hooks.OnPostToolUse(input, invocation) case "userPromptSubmitted": if hooks.OnUserPromptSubmitted == nil { return nil, nil } - hookInput := parseUserPromptSubmittedInput(input) - return hooks.OnUserPromptSubmitted(hookInput, invocation) + var input UserPromptSubmittedHookInput + if err := json.Unmarshal(rawInput, &input); err != nil { + return nil, fmt.Errorf("invalid hook input: %w", err) + } + return hooks.OnUserPromptSubmitted(input, invocation) case "sessionStart": if hooks.OnSessionStart == nil { return nil, nil } - hookInput := parseSessionStartInput(input) - return hooks.OnSessionStart(hookInput, invocation) + var input SessionStartHookInput + if err := json.Unmarshal(rawInput, &input); err != nil { + return nil, fmt.Errorf("invalid hook input: %w", err) + } + return hooks.OnSessionStart(input, invocation) case "sessionEnd": if hooks.OnSessionEnd == nil { return nil, nil } - hookInput := parseSessionEndInput(input) - return hooks.OnSessionEnd(hookInput, invocation) + var input SessionEndHookInput + if err := json.Unmarshal(rawInput, &input); err != nil { + return nil, fmt.Errorf("invalid hook input: %w", err) + } + return hooks.OnSessionEnd(input, invocation) case "errorOccurred": if hooks.OnErrorOccurred == nil { return nil, nil } - hookInput := parseErrorOccurredInput(input) - return hooks.OnErrorOccurred(hookInput, invocation) - + var input ErrorOccurredHookInput + if err := json.Unmarshal(rawInput, &input); err != nil { + return nil, fmt.Errorf("invalid hook input: %w", err) + } + return hooks.OnErrorOccurred(input, invocation) default: return nil, fmt.Errorf("unknown hook type: %s", hookType) } } -// Helper functions to parse hook inputs - -func parsePreToolUseInput(input map[string]any) PreToolUseHookInput { - result := PreToolUseHookInput{} - if ts, ok := input["timestamp"].(float64); ok { - result.Timestamp = int64(ts) - } - if cwd, ok := input["cwd"].(string); ok { - result.Cwd = cwd - } - if name, ok := input["toolName"].(string); ok { - result.ToolName = name - } - result.ToolArgs = input["toolArgs"] - return result -} - -func parsePostToolUseInput(input map[string]any) PostToolUseHookInput { - result := PostToolUseHookInput{} - if ts, ok := input["timestamp"].(float64); ok { - result.Timestamp = int64(ts) - } - if cwd, ok := input["cwd"].(string); ok { - result.Cwd = cwd - } - if name, ok := input["toolName"].(string); ok { - result.ToolName = name - } - result.ToolArgs = input["toolArgs"] - result.ToolResult = input["toolResult"] - return result -} - -func parseUserPromptSubmittedInput(input map[string]any) UserPromptSubmittedHookInput { - result := UserPromptSubmittedHookInput{} - if ts, ok := input["timestamp"].(float64); ok { - result.Timestamp = int64(ts) - } - if cwd, ok := input["cwd"].(string); ok { - result.Cwd = cwd - } - if prompt, ok := input["prompt"].(string); ok { - result.Prompt = prompt - } - return result -} - -func parseSessionStartInput(input map[string]any) SessionStartHookInput { - result := SessionStartHookInput{} - if ts, ok := input["timestamp"].(float64); ok { - result.Timestamp = int64(ts) - } - if cwd, ok := input["cwd"].(string); ok { - result.Cwd = cwd - } - if source, ok := input["source"].(string); ok { - result.Source = source - } - if prompt, ok := input["initialPrompt"].(string); ok { - result.InitialPrompt = prompt - } - return result -} - -func parseSessionEndInput(input map[string]any) SessionEndHookInput { - result := SessionEndHookInput{} - if ts, ok := input["timestamp"].(float64); ok { - result.Timestamp = int64(ts) - } - if cwd, ok := input["cwd"].(string); ok { - result.Cwd = cwd - } - if reason, ok := input["reason"].(string); ok { - result.Reason = reason - } - if msg, ok := input["finalMessage"].(string); ok { - result.FinalMessage = msg - } - if errStr, ok := input["error"].(string); ok { - result.Error = errStr - } - return result -} - -func parseErrorOccurredInput(input map[string]any) ErrorOccurredHookInput { - result := ErrorOccurredHookInput{} - if ts, ok := input["timestamp"].(float64); ok { - result.Timestamp = int64(ts) - } - if cwd, ok := input["cwd"].(string); ok { - result.Cwd = cwd - } - if errMsg, ok := input["error"].(string); ok { - result.Error = errMsg - } - if ctx, ok := input["errorContext"].(string); ok { - result.ErrorContext = ctx - } - if rec, ok := input["recoverable"].(bool); ok { - result.Recoverable = rec - } - return result -} - // dispatchEvent dispatches an event to all registered handlers. // This is an internal method; handlers are called synchronously and any panics // are recovered to prevent crashing the event dispatcher. @@ -596,38 +493,17 @@ func (s *Session) dispatchEvent(event SessionEvent) { // } // } func (s *Session) GetMessages(ctx context.Context) ([]SessionEvent, error) { - params := map[string]any{ - "sessionId": s.SessionID, - } - result, err := s.client.Request("session.getMessages", params) + result, err := s.client.Request("session.getMessages", sessionGetMessagesRequest{SessionID: s.SessionID}) if err != nil { return nil, fmt.Errorf("failed to get messages: %w", err) } - eventsRaw, ok := result["events"].([]any) - if !ok { - return nil, fmt.Errorf("invalid response: missing events") - } - - // Convert to SessionEvent structs - events := make([]SessionEvent, 0, len(eventsRaw)) - for _, eventRaw := range eventsRaw { - // Marshal back to JSON and unmarshal into typed struct - eventJSON, err := json.Marshal(eventRaw) - if err != nil { - continue - } - - event, err := UnmarshalSessionEvent(eventJSON) - if err != nil { - continue - } - - events = append(events, event) + var response sessionGetMessagesResponse + if err := json.Unmarshal(result, &response); err != nil { + return nil, fmt.Errorf("failed to unmarshal get messages response: %w", err) } - - return events, nil + return response.Events, nil } // Destroy destroys this session and releases all associated resources. @@ -645,11 +521,7 @@ func (s *Session) GetMessages(ctx context.Context) ([]SessionEvent, error) { // log.Printf("Failed to destroy session: %v", err) // } func (s *Session) Destroy() error { - params := map[string]any{ - "sessionId": s.SessionID, - } - - _, err := s.client.Request("session.destroy", params) + _, err := s.client.Request("session.destroy", sessionDestroyRequest{SessionID: s.SessionID}) if err != nil { return fmt.Errorf("failed to destroy session: %w", err) } @@ -692,11 +564,7 @@ func (s *Session) Destroy() error { // log.Printf("Failed to abort: %v", err) // } func (s *Session) Abort(ctx context.Context) error { - params := map[string]any{ - "sessionId": s.SessionID, - } - - _, err := s.client.Request("session.abort", params) + _, err := s.client.Request("session.abort", sessionAbortRequest{SessionID: s.SessionID}) if err != nil { return fmt.Errorf("failed to abort session: %w", err) } diff --git a/go/types.go b/go/types.go index 599eb35c1..a3b38ee31 100644 --- a/go/types.go +++ b/go/types.go @@ -1,5 +1,7 @@ package copilot +import "encoding/json" + // ConnectionState represents the client connection state type ConnectionState string @@ -113,15 +115,15 @@ type PermissionInvocation struct { // UserInputRequest represents a request for user input from the agent type UserInputRequest struct { - Question string `json:"question"` - Choices []string `json:"choices,omitempty"` - AllowFreeform *bool `json:"allowFreeform,omitempty"` + Question string + Choices []string + AllowFreeform *bool } // UserInputResponse represents the user's response to an input request type UserInputResponse struct { - Answer string `json:"answer"` - WasFreeform bool `json:"wasFreeform"` + Answer string + WasFreeform bool } // UserInputHandler handles user input requests from the agent @@ -307,13 +309,13 @@ type CustomAgentConfig struct { // limits through background compaction and persist state to a workspace directory. type InfiniteSessionConfig struct { // Enabled controls whether infinite sessions are enabled (default: true) - Enabled *bool + Enabled *bool `json:"enabled,omitempty"` // BackgroundCompactionThreshold is the context utilization (0.0-1.0) at which // background compaction starts. Default: 0.80 - BackgroundCompactionThreshold *float64 + BackgroundCompactionThreshold *float64 `json:"backgroundCompactionThreshold,omitempty"` // BufferExhaustionThreshold is the context utilization (0.0-1.0) at which // the session blocks until compaction completes. Default: 0.95 - BufferExhaustionThreshold *float64 + BufferExhaustionThreshold *float64 `json:"bufferExhaustionThreshold,omitempty"` } // SessionConfig configures a new session @@ -369,10 +371,10 @@ type SessionConfig struct { // Tool describes a caller-implemented tool that can be invoked by Copilot type Tool struct { - Name string - Description string // optional - Parameters map[string]any - Handler ToolHandler + Name string `json:"name"` + Description string `json:"description,omitempty"` + Parameters map[string]any `json:"parameters,omitempty"` + Handler ToolHandler `json:"-"` } // ToolInvocation describes a tool call initiated by Copilot @@ -491,43 +493,6 @@ type MessageOptions struct { // SessionEventHandler is a callback for session events type SessionEventHandler func(event SessionEvent) -// PingResponse is the response from a ping request -type PingResponse struct { - Message string `json:"message"` - Timestamp int64 `json:"timestamp"` - ProtocolVersion *int `json:"protocolVersion,omitempty"` -} - -// SessionCreateResponse is the response from session.create -type SessionCreateResponse struct { - SessionID string `json:"sessionId"` -} - -// SessionSendResponse is the response from session.send -type SessionSendResponse struct { - MessageID string `json:"messageId"` -} - -// SessionGetMessagesResponse is the response from session.getMessages -type SessionGetMessagesResponse struct { - Events []SessionEvent `json:"events"` -} - -// GetStatusResponse is the response from status.get -type GetStatusResponse struct { - Version string `json:"version"` - ProtocolVersion int `json:"protocolVersion"` -} - -// GetAuthStatusResponse is the response from auth.getStatus -type GetAuthStatusResponse struct { - IsAuthenticated bool `json:"isAuthenticated"` - AuthType *string `json:"authType,omitempty"` - Host *string `json:"host,omitempty"` - Login *string `json:"login,omitempty"` - StatusMessage *string `json:"statusMessage,omitempty"` -} - // ModelVisionLimits contains vision-specific limits type ModelVisionLimits struct { SupportedMediaTypes []string `json:"supported_media_types"` @@ -576,11 +541,6 @@ type ModelInfo struct { DefaultReasoningEffort string `json:"defaultReasoningEffort,omitempty"` } -// GetModelsResponse is the response from models.list -type GetModelsResponse struct { - Models []ModelInfo `json:"models"` -} - // SessionMetadata contains metadata about a session type SessionMetadata struct { SessionID string `json:"sessionId"` @@ -590,22 +550,6 @@ type SessionMetadata struct { IsRemote bool `json:"isRemote"` } -// ListSessionsResponse is the response from session.list -type ListSessionsResponse struct { - Sessions []SessionMetadata `json:"sessions"` -} - -// DeleteSessionRequest is the request for session.delete -type DeleteSessionRequest struct { - SessionID string `json:"sessionId"` -} - -// DeleteSessionResponse is the response from session.delete -type DeleteSessionResponse struct { - Success bool `json:"success"` - Error *string `json:"error,omitempty"` -} - // SessionLifecycleEventType represents the type of session lifecycle event type SessionLifecycleEventType string @@ -634,19 +578,224 @@ type SessionLifecycleEventMetadata struct { // SessionLifecycleHandler is a callback for session lifecycle events type SessionLifecycleHandler func(event SessionLifecycleEvent) -// GetForegroundSessionResponse is the response from session.getForeground -type GetForegroundSessionResponse struct { +// permissionRequestRequest represents the request data for a permission request +type permissionRequestRequest struct { + SessionID string `json:"sessionId"` + Request PermissionRequest `json:"permissionRequest"` +} + +// permissionRequestResponse represents the response to a permission request +type permissionRequestResponse struct { + Result PermissionRequestResult `json:"result"` +} + +// createSessionRequest is the request for session.create +type createSessionRequest struct { + Model string `json:"model,omitempty"` + SessionID string `json:"sessionId,omitempty"` + ReasoningEffort string `json:"reasoningEffort,omitempty"` + Tools []Tool `json:"tools,omitempty"` + SystemMessage *SystemMessageConfig `json:"systemMessage,omitempty"` + AvailableTools []string `json:"availableTools,omitempty"` + ExcludedTools []string `json:"excludedTools,omitempty"` + Provider *ProviderConfig `json:"provider,omitempty"` + RequestPermission *bool `json:"requestPermission,omitempty"` + RequestUserInput *bool `json:"requestUserInput,omitempty"` + Hooks *bool `json:"hooks,omitempty"` + WorkingDirectory string `json:"workingDirectory,omitempty"` + Streaming *bool `json:"streaming,omitempty"` + MCPServers map[string]MCPServerConfig `json:"mcpServers,omitempty"` + CustomAgents []CustomAgentConfig `json:"customAgents,omitempty"` + ConfigDir string `json:"configDir,omitempty"` + SkillDirectories []string `json:"skillDirectories,omitempty"` + DisabledSkills []string `json:"disabledSkills,omitempty"` + InfiniteSessions *InfiniteSessionConfig `json:"infiniteSessions,omitempty"` +} + +// createSessionResponse is the response from session.create +type createSessionResponse struct { + SessionID string `json:"sessionId"` + WorkspacePath string `json:"workspacePath"` +} + +// resumeSessionRequest is the request for session.resume +type resumeSessionRequest struct { + SessionID string `json:"sessionId"` + Model string `json:"model,omitempty"` + ReasoningEffort string `json:"reasoningEffort,omitempty"` + Tools []Tool `json:"tools,omitempty"` + SystemMessage *SystemMessageConfig `json:"systemMessage,omitempty"` + AvailableTools []string `json:"availableTools,omitempty"` + ExcludedTools []string `json:"excludedTools,omitempty"` + Provider *ProviderConfig `json:"provider,omitempty"` + RequestPermission *bool `json:"requestPermission,omitempty"` + RequestUserInput *bool `json:"requestUserInput,omitempty"` + Hooks *bool `json:"hooks,omitempty"` + WorkingDirectory string `json:"workingDirectory,omitempty"` + ConfigDir string `json:"configDir,omitempty"` + DisableResume *bool `json:"disableResume,omitempty"` + Streaming *bool `json:"streaming,omitempty"` + MCPServers map[string]MCPServerConfig `json:"mcpServers,omitempty"` + CustomAgents []CustomAgentConfig `json:"customAgents,omitempty"` + SkillDirectories []string `json:"skillDirectories,omitempty"` + DisabledSkills []string `json:"disabledSkills,omitempty"` + InfiniteSessions *InfiniteSessionConfig `json:"infiniteSessions,omitempty"` +} + +// resumeSessionResponse is the response from session.resume +type resumeSessionResponse struct { + SessionID string `json:"sessionId"` + WorkspacePath string `json:"workspacePath"` +} + +type hooksInvokeRequest struct { + SessionID string `json:"sessionId"` + Type string `json:"hookType"` + Input json.RawMessage `json:"input"` +} + +// listSessionsRequest is the request for session.list +type listSessionsRequest struct{} + +// listSessionsResponse is the response from session.list +type listSessionsResponse struct { + Sessions []SessionMetadata `json:"sessions"` +} + +// deleteSessionRequest is the request for session.delete +type deleteSessionRequest struct { + SessionID string `json:"sessionId"` +} + +// deleteSessionResponse is the response from session.delete +type deleteSessionResponse struct { + Success bool `json:"success"` + Error *string `json:"error,omitempty"` +} + +// getForegroundSessionRequest is the request for session.getForeground +type getForegroundSessionRequest struct{} + +// getForegroundSessionResponse is the response from session.getForeground +type getForegroundSessionResponse struct { SessionID *string `json:"sessionId,omitempty"` WorkspacePath *string `json:"workspacePath,omitempty"` } -// SetForegroundSessionRequest is the request for session.setForeground -type SetForegroundSessionRequest struct { +// setForegroundSessionRequest is the request for session.setForeground +type setForegroundSessionRequest struct { SessionID string `json:"sessionId"` } -// SetForegroundSessionResponse is the response from session.setForeground -type SetForegroundSessionResponse struct { +// setForegroundSessionResponse is the response from session.setForeground +type setForegroundSessionResponse struct { Success bool `json:"success"` Error *string `json:"error,omitempty"` } + +type pingRequest struct { + Message string `json:"message,omitempty"` +} + +// PingResponse is the response from a ping request +type PingResponse struct { + Message string `json:"message"` + Timestamp int64 `json:"timestamp"` + ProtocolVersion *int `json:"protocolVersion,omitempty"` +} + +// getStatusRequest is the request for status.get +type getStatusRequest struct{} + +// GetStatusResponse is the response from status.get +type GetStatusResponse struct { + Version string `json:"version"` + ProtocolVersion int `json:"protocolVersion"` +} + +// getAuthStatusRequest is the request for auth.getStatus +type getAuthStatusRequest struct{} + +// GetAuthStatusResponse is the response from auth.getStatus +type GetAuthStatusResponse struct { + IsAuthenticated bool `json:"isAuthenticated"` + AuthType *string `json:"authType,omitempty"` + Host *string `json:"host,omitempty"` + Login *string `json:"login,omitempty"` + StatusMessage *string `json:"statusMessage,omitempty"` +} + +// listModelsRequest is the request for models.list +type listModelsRequest struct{} + +// listModelsResponse is the response from models.list +type listModelsResponse struct { + Models []ModelInfo `json:"models"` +} + +// sessionGetMessagesRequest is the request for session.getMessages +type sessionGetMessagesRequest struct { + SessionID string `json:"sessionId"` +} + +// sessionGetMessagesResponse is the response from session.getMessages +type sessionGetMessagesResponse struct { + Events []SessionEvent `json:"events"` +} + +// sessionDestroyRequest is the request for session.destroy +type sessionDestroyRequest struct { + SessionID string `json:"sessionId"` +} + +// sessionAbortRequest is the request for session.abort +type sessionAbortRequest struct { + SessionID string `json:"sessionId"` +} + +type sessionSendRequest struct { + SessionID string `json:"sessionId"` + Prompt string `json:"prompt"` + Attachments []Attachment `json:"attachments,omitempty"` + Mode string `json:"mode,omitempty"` +} + +// sessionSendResponse is the response from session.send +type sessionSendResponse struct { + MessageID string `json:"messageId"` +} + +// sessionEventRequest is the request for session event notifications +type sessionEventRequest struct { + SessionID string `json:"sessionId"` + Event SessionEvent `json:"event"` +} + +// toolCallRequest represents a tool call request from the server +// to the client for execution. +type toolCallRequest struct { + SessionID string `json:"sessionId"` + ToolCallID string `json:"toolCallId"` + ToolName string `json:"toolName"` + Arguments any `json:"arguments"` +} + +// toolCallResponse represents the response to a tool call request +// from the client back to the server. +type toolCallResponse struct { + Result ToolResult `json:"result"` +} + +// userInputRequest represents a request for user input from the agent +type userInputRequest struct { + SessionID string `json:"sessionId"` + Question string `json:"question"` + Choices []string `json:"choices,omitempty"` + AllowFreeform *bool `json:"allowFreeform,omitempty"` +} + +// userInputResponse represents the user's response to an input request +type userInputResponse struct { + Answer string `json:"answer"` + WasFreeform bool `json:"wasFreeform"` +} diff --git a/nodejs/README.md b/nodejs/README.md index 9ad030aa1..3a78f4199 100644 --- a/nodejs/README.md +++ b/nodejs/README.md @@ -250,7 +250,7 @@ Sessions emit various events during processing: - `assistant.message` - Assistant response - `assistant.message_delta` - Streaming response chunk - `tool.execution_start` - Tool execution started -- `tool.execution_end` - Tool execution completed +- `tool.execution_complete` - Tool execution completed - And more... See `SessionEvent` type in the source for full details. diff --git a/nodejs/package-lock.json b/nodejs/package-lock.json index 5f33118be..266d994e3 100644 --- a/nodejs/package-lock.json +++ b/nodejs/package-lock.json @@ -9,7 +9,7 @@ "version": "0.1.8", "license": "MIT", "dependencies": { - "@github/copilot": "^0.0.403", + "@github/copilot": "^0.0.405", "vscode-jsonrpc": "^8.2.1", "zod": "^4.3.6" }, @@ -31,7 +31,7 @@ "vitest": "^4.0.18" }, "engines": { - "node": ">=18.0.0" + "node": ">=24.0.0" } }, "node_modules/@apidevtools/json-schema-ref-parser": { @@ -662,26 +662,26 @@ } }, "node_modules/@github/copilot": { - "version": "0.0.403", - "resolved": "https://registry.npmjs.org/@github/copilot/-/copilot-0.0.403.tgz", - "integrity": "sha512-v5jUdtGJReLmE1rmff/LZf+50nzmYQYAaSRNtVNr9g0j0GkCd/noQExe31i1+PudvWU0ZJjltR0B8pUfDRdA9Q==", + "version": "0.0.405", + "resolved": "https://registry.npmjs.org/@github/copilot/-/copilot-0.0.405.tgz", + "integrity": "sha512-zp0kGSkoKrO4MTWefAxU5w2VEc02QnhPY3FmVxOeduh6ayDIz2V368mXxs46ThremdMnMyZPL1k989BW4NpOVw==", "license": "SEE LICENSE IN LICENSE.md", "bin": { "copilot": "npm-loader.js" }, "optionalDependencies": { - "@github/copilot-darwin-arm64": "0.0.403", - "@github/copilot-darwin-x64": "0.0.403", - "@github/copilot-linux-arm64": "0.0.403", - "@github/copilot-linux-x64": "0.0.403", - "@github/copilot-win32-arm64": "0.0.403", - "@github/copilot-win32-x64": "0.0.403" + "@github/copilot-darwin-arm64": "0.0.405", + "@github/copilot-darwin-x64": "0.0.405", + "@github/copilot-linux-arm64": "0.0.405", + "@github/copilot-linux-x64": "0.0.405", + "@github/copilot-win32-arm64": "0.0.405", + "@github/copilot-win32-x64": "0.0.405" } }, "node_modules/@github/copilot-darwin-arm64": { - "version": "0.0.403", - "resolved": "https://registry.npmjs.org/@github/copilot-darwin-arm64/-/copilot-darwin-arm64-0.0.403.tgz", - "integrity": "sha512-dOw8IleA0d1soHnbr/6wc6vZiYWNTKMgfTe/NET1nCfMzyKDt/0F0I7PT5y+DLujJknTla/ZeEmmBUmliTW4Cg==", + "version": "0.0.405", + "resolved": "https://registry.npmjs.org/@github/copilot-darwin-arm64/-/copilot-darwin-arm64-0.0.405.tgz", + "integrity": "sha512-RVFpU1cEMqjR0rLpwLwbIfT7XzqqVoQX99G6nsj+WrHu3TIeCgfffyd2YShd4QwZYsMRoTfKB+rirQ+0G5Uiig==", "cpu": [ "arm64" ], @@ -695,9 +695,9 @@ } }, "node_modules/@github/copilot-darwin-x64": { - "version": "0.0.403", - "resolved": "https://registry.npmjs.org/@github/copilot-darwin-x64/-/copilot-darwin-x64-0.0.403.tgz", - "integrity": "sha512-aK2jSNWgY8eiZ+TmrvGhssMCPDTKArc0ip6Ul5OaslpytKks8hyXoRbxGD0N9sKioSUSbvKUf+1AqavbDpJO+w==", + "version": "0.0.405", + "resolved": "https://registry.npmjs.org/@github/copilot-darwin-x64/-/copilot-darwin-x64-0.0.405.tgz", + "integrity": "sha512-Xj2FyPzpZlfqPTuMrXtPNEijSmm2ivHvyMWgy5Ijv7Slabxe+2s3WXDaokE3SQHodK6M0Yle2yrx9kxiwWA+qw==", "cpu": [ "x64" ], @@ -711,9 +711,9 @@ } }, "node_modules/@github/copilot-linux-arm64": { - "version": "0.0.403", - "resolved": "https://registry.npmjs.org/@github/copilot-linux-arm64/-/copilot-linux-arm64-0.0.403.tgz", - "integrity": "sha512-KhoR2iR70O6vCkzf0h8/K+p82qAgOvMTgAPm9bVEHvbdGFR7Py9qL5v03bMbPxsA45oNaZAkzDhfTAqWhIAZsQ==", + "version": "0.0.405", + "resolved": "https://registry.npmjs.org/@github/copilot-linux-arm64/-/copilot-linux-arm64-0.0.405.tgz", + "integrity": "sha512-16Wiq8EYB6ghwqZdYytnNkcCN4sT3jyt9XkjfMxI5DDdjLuPc8wbj5VV5pw8S6lZvBL4eAwXGE3+fPqXKxH6GQ==", "cpu": [ "arm64" ], @@ -727,9 +727,9 @@ } }, "node_modules/@github/copilot-linux-x64": { - "version": "0.0.403", - "resolved": "https://registry.npmjs.org/@github/copilot-linux-x64/-/copilot-linux-x64-0.0.403.tgz", - "integrity": "sha512-eoswUc9vo4TB+/9PgFJLVtzI4dPjkpJXdCsAioVuoqPdNxHxlIHFe9HaVcqMRZxUNY1YHEBZozy+IpUEGjgdfQ==", + "version": "0.0.405", + "resolved": "https://registry.npmjs.org/@github/copilot-linux-x64/-/copilot-linux-x64-0.0.405.tgz", + "integrity": "sha512-HXpg7p235//pAuCvcL9m2EeIrL/K6OUEkFeHF3BFHzqUJR4a69gKLsxtUg0ZctypHqo2SehGCRAyVippTVlTyg==", "cpu": [ "x64" ], @@ -743,9 +743,9 @@ } }, "node_modules/@github/copilot-win32-arm64": { - "version": "0.0.403", - "resolved": "https://registry.npmjs.org/@github/copilot-win32-arm64/-/copilot-win32-arm64-0.0.403.tgz", - "integrity": "sha512-djWjzCsp2xPNafMyOZ/ivU328/WvWhdroGie/DugiJBTgQL2SP0quWW1fhTlDwE81a3g9CxfJonaRgOpFTJTcg==", + "version": "0.0.405", + "resolved": "https://registry.npmjs.org/@github/copilot-win32-arm64/-/copilot-win32-arm64-0.0.405.tgz", + "integrity": "sha512-4JCUMiRjP7zB3j1XpEtJq7b7cxTzuwDJ9o76jayAL8HL9NhqKZ6Ys6uxhDA6f/l0N2GVD1TEICxsnPgadz6srg==", "cpu": [ "arm64" ], @@ -759,9 +759,9 @@ } }, "node_modules/@github/copilot-win32-x64": { - "version": "0.0.403", - "resolved": "https://registry.npmjs.org/@github/copilot-win32-x64/-/copilot-win32-x64-0.0.403.tgz", - "integrity": "sha512-lju8cHy2E6Ux7R7tWyLZeksYC2MVZu9i9ocjiBX/qfG2/pNJs7S5OlkwKJ0BSXSbZEHQYq7iHfEWp201bVfk9A==", + "version": "0.0.405", + "resolved": "https://registry.npmjs.org/@github/copilot-win32-x64/-/copilot-win32-x64-0.0.405.tgz", + "integrity": "sha512-uHoJ9N8kZbTLbzgqBE1szHwLElv2f+P2OWlqmRSawQhwPl0s7u55dka7mZYvj2ZoNvIyb0OyShCO56OpmCcy/w==", "cpu": [ "x64" ], diff --git a/nodejs/package.json b/nodejs/package.json index 11dd20b53..b6e23f401 100644 --- a/nodejs/package.json +++ b/nodejs/package.json @@ -40,7 +40,7 @@ "author": "GitHub", "license": "MIT", "dependencies": { - "@github/copilot": "^0.0.403", + "@github/copilot": "^0.0.405", "vscode-jsonrpc": "^8.2.1", "zod": "^4.3.6" }, @@ -62,7 +62,7 @@ "vitest": "^4.0.18" }, "engines": { - "node": ">=18.0.0" + "node": ">=24.0.0" }, "files": [ "dist/**/*", diff --git a/nodejs/src/client.ts b/nodejs/src/client.ts index d50a9bbdd..7ffc57e5e 100644 --- a/nodejs/src/client.ts +++ b/nodejs/src/client.ts @@ -12,7 +12,10 @@ */ import { spawn, type ChildProcess } from "node:child_process"; +import { existsSync } from "node:fs"; import { Socket } from "node:net"; +import { dirname, join } from "node:path"; +import { fileURLToPath } from "node:url"; import { createMessageConnection, MessageConnection, @@ -102,6 +105,20 @@ function toJsonSchema(parameters: Tool["parameters"]): Record | * await client.stop(); * ``` */ + +/** + * Gets the path to the bundled CLI from the @github/copilot package. + * Uses index.js directly rather than npm-loader.js (which spawns the native binary). + */ +function getBundledCliPath(): string { + // Find the actual location of the @github/copilot package by resolving its sdk export + const sdkUrl = import.meta.resolve("@github/copilot/sdk"); + const sdkPath = fileURLToPath(sdkUrl); + // sdkPath is like .../node_modules/@github/copilot/sdk/index.js + // Go up two levels to get the package root, then append index.js + return join(dirname(dirname(sdkPath)), "index.js"); +} + export class CopilotClient { private cliProcess: ChildProcess | null = null; private connection: MessageConnection | null = null; @@ -172,7 +189,7 @@ export class CopilotClient { } this.options = { - cliPath: options.cliPath || "copilot", + cliPath: options.cliPath || getBundledCliPath(), cliArgs: options.cliArgs ?? [], cwd: options.cwd ?? process.cwd(), port: options.port || 0, @@ -1014,6 +1031,7 @@ export class CopilotClient { const args = [ ...this.options.cliArgs, "--headless", + "--no-auto-update", "--log-level", this.options.logLevel, ]; @@ -1042,35 +1060,34 @@ export class CopilotClient { envWithoutNodeDebug.COPILOT_SDK_AUTH_TOKEN = this.options.githubToken; } - // If cliPath is a .js file, spawn it with node - // Note that we can't rely on the shebang as Windows doesn't support it - const isJsFile = this.options.cliPath.endsWith(".js"); - const isAbsolutePath = - this.options.cliPath.startsWith("/") || /^[a-zA-Z]:/.test(this.options.cliPath); + // Verify CLI exists before attempting to spawn + if (!existsSync(this.options.cliPath)) { + throw new Error( + `Copilot CLI not found at ${this.options.cliPath}. Ensure @github/copilot is installed.` + ); + } - let command: string; - let spawnArgs: string[]; + const stdioConfig: ["pipe", "pipe", "pipe"] | ["ignore", "pipe", "pipe"] = this.options + .useStdio + ? ["pipe", "pipe", "pipe"] + : ["ignore", "pipe", "pipe"]; + // For .js files, spawn node explicitly; for executables, spawn directly + const isJsFile = this.options.cliPath.endsWith(".js"); if (isJsFile) { - command = "node"; - spawnArgs = [this.options.cliPath, ...args]; - } else if (process.platform === "win32" && !isAbsolutePath) { - // On Windows, spawn doesn't search PATHEXT, so use cmd /c to resolve the executable. - command = "cmd"; - spawnArgs = ["/c", `${this.options.cliPath}`, ...args]; + this.cliProcess = spawn(process.execPath, [this.options.cliPath, ...args], { + stdio: stdioConfig, + cwd: this.options.cwd, + env: envWithoutNodeDebug, + }); } else { - command = this.options.cliPath; - spawnArgs = args; + this.cliProcess = spawn(this.options.cliPath, args, { + stdio: stdioConfig, + cwd: this.options.cwd, + env: envWithoutNodeDebug, + }); } - this.cliProcess = spawn(command, spawnArgs, { - stdio: this.options.useStdio - ? ["pipe", "pipe", "pipe"] - : ["ignore", "pipe", "pipe"], - cwd: this.options.cwd, - env: envWithoutNodeDebug, - }); - let stdout = ""; let resolved = false; diff --git a/nodejs/src/generated/session-events.ts b/nodejs/src/generated/session-events.ts index e50ecd04d..86783a043 100644 --- a/nodejs/src/generated/session-events.ts +++ b/nodejs/src/generated/session-events.ts @@ -3,7 +3,7 @@ * * Generated from: @github/copilot/session-events.schema.json * Generated by: scripts/generate-session-types.ts - * Generated at: 2026-02-03T20:40:49.167Z + * Generated at: 2026-02-06T20:38:23.139Z * * To update these types: * 1. Update the schema in copilot-agent-runtime @@ -216,6 +216,7 @@ export type SessionEvent = output: number; cachedInput: number; }; + requestId?: string; }; } | { diff --git a/nodejs/src/types.ts b/nodejs/src/types.ts index 3418ef7b0..9f04f895a 100644 --- a/nodejs/src/types.ts +++ b/nodejs/src/types.ts @@ -15,8 +15,8 @@ export type SessionEvent = GeneratedSessionEvent; */ export interface CopilotClientOptions { /** - * Path to the Copilot CLI executable - * @default "copilot" (searches PATH) + * Path to the CLI executable or JavaScript entry point. + * If not specified, uses the bundled CLI from the @github/copilot package. */ cliPath?: string; diff --git a/nodejs/test/client.test.ts b/nodejs/test/client.test.ts index c5c6d49fd..25a8fb87d 100644 --- a/nodejs/test/client.test.ts +++ b/nodejs/test/client.test.ts @@ -1,13 +1,12 @@ /* eslint-disable @typescript-eslint/no-explicit-any */ import { describe, expect, it, onTestFinished } from "vitest"; import { CopilotClient } from "../src/index.js"; -import { CLI_PATH } from "./e2e/harness/sdkTestContext.js"; // This file is for unit tests. Where relevant, prefer to add e2e tests in e2e/*.test.ts instead describe("CopilotClient", () => { it("returns a standardized failure result when a tool is not registered", async () => { - const client = new CopilotClient({ cliPath: CLI_PATH }); + const client = new CopilotClient(); await client.start(); onTestFinished(() => client.forceStop()); diff --git a/nodejs/test/e2e/client.test.ts b/nodejs/test/e2e/client.test.ts index 24992f66f..526e95095 100644 --- a/nodejs/test/e2e/client.test.ts +++ b/nodejs/test/e2e/client.test.ts @@ -1,7 +1,6 @@ import { ChildProcess } from "child_process"; import { describe, expect, it, onTestFinished } from "vitest"; import { CopilotClient } from "../../src/index.js"; -import { CLI_PATH } from "./harness/sdkTestContext.js"; function onTestFinishedForceStop(client: CopilotClient) { onTestFinished(async () => { @@ -15,7 +14,7 @@ function onTestFinishedForceStop(client: CopilotClient) { describe("Client", () => { it("should start and connect to server using stdio", async () => { - const client = new CopilotClient({ cliPath: CLI_PATH, useStdio: true }); + const client = new CopilotClient({ useStdio: true }); onTestFinishedForceStop(client); await client.start(); @@ -30,7 +29,7 @@ describe("Client", () => { }); it("should start and connect to server using tcp", async () => { - const client = new CopilotClient({ cliPath: CLI_PATH, useStdio: false }); + const client = new CopilotClient({ useStdio: false }); onTestFinishedForceStop(client); await client.start(); @@ -50,7 +49,7 @@ describe("Client", () => { // saying "Cannot call write after a stream was destroyed" // because the JSON-RPC logic is still trying to write to stdin after // the process has exited. - const client = new CopilotClient({ cliPath: CLI_PATH, useStdio: false }); + const client = new CopilotClient({ useStdio: false }); await client.createSession(); @@ -67,7 +66,7 @@ describe("Client", () => { }); it("should forceStop without cleanup", async () => { - const client = new CopilotClient({ cliPath: CLI_PATH }); + const client = new CopilotClient({}); onTestFinishedForceStop(client); await client.createSession(); @@ -76,7 +75,7 @@ describe("Client", () => { }); it("should get status with version and protocol info", async () => { - const client = new CopilotClient({ cliPath: CLI_PATH, useStdio: true }); + const client = new CopilotClient({ useStdio: true }); onTestFinishedForceStop(client); await client.start(); @@ -92,7 +91,7 @@ describe("Client", () => { }); it("should get auth status", async () => { - const client = new CopilotClient({ cliPath: CLI_PATH, useStdio: true }); + const client = new CopilotClient({ useStdio: true }); onTestFinishedForceStop(client); await client.start(); @@ -108,7 +107,7 @@ describe("Client", () => { }); it("should list models when authenticated", async () => { - const client = new CopilotClient({ cliPath: CLI_PATH, useStdio: true }); + const client = new CopilotClient({ useStdio: true }); onTestFinishedForceStop(client); await client.start(); diff --git a/nodejs/test/e2e/harness/sdkTestContext.ts b/nodejs/test/e2e/harness/sdkTestContext.ts index 094eaff9c..beabf3812 100644 --- a/nodejs/test/e2e/harness/sdkTestContext.ts +++ b/nodejs/test/e2e/harness/sdkTestContext.ts @@ -17,10 +17,6 @@ const __filename = fileURLToPath(import.meta.url); const __dirname = dirname(__filename); const SNAPSHOTS_DIR = resolve(__dirname, "../../../../test/snapshots"); -export const CLI_PATH = - process.env.COPILOT_CLI_PATH || - resolve(__dirname, "../../../node_modules/@github/copilot/index.js"); - export async function createSdkTestContext({ logLevel, }: { logLevel?: "error" | "none" | "warning" | "info" | "debug" | "all" } = {}) { @@ -41,7 +37,6 @@ export async function createSdkTestContext({ }; const copilotClient = new CopilotClient({ - cliPath: CLI_PATH, cwd: workDir, env, logLevel: logLevel || "error", diff --git a/nodejs/test/e2e/session.test.ts b/nodejs/test/e2e/session.test.ts index b3fba4755..01a3ad0b1 100644 --- a/nodejs/test/e2e/session.test.ts +++ b/nodejs/test/e2e/session.test.ts @@ -1,7 +1,7 @@ import { describe, expect, it, onTestFinished } from "vitest"; import { ParsedHttpExchange } from "../../../test/harness/replayingCapiProxy.js"; import { CopilotClient } from "../../src/index.js"; -import { CLI_PATH, createSdkTestContext } from "./harness/sdkTestContext.js"; +import { createSdkTestContext } from "./harness/sdkTestContext.js"; import { getFinalAssistantMessage, getNextEventOfType } from "./harness/sdkTestHelper.js"; describe("Sessions", async () => { @@ -157,7 +157,6 @@ describe("Sessions", async () => { // Resume using a new client const newClient = new CopilotClient({ - cliPath: CLI_PATH, env, githubToken: process.env.CI === "true" ? "fake-token-for-e2e-tests" : undefined, }); diff --git a/python/.gitignore b/python/.gitignore index 421d7a7dc..b9774ce33 100644 --- a/python/.gitignore +++ b/python/.gitignore @@ -162,3 +162,10 @@ cython_debug/ # Ruff and ty cache .ruff_cache/ .ty_cache/ + +# Build script caches +.cli-cache/ +.build-temp/ + +# Bundled CLI binary (only in platform wheels, not in repo) +copilot/bin/ diff --git a/python/copilot/client.py b/python/copilot/client.py index f2b52eac4..85b728971 100644 --- a/python/copilot/client.py +++ b/python/copilot/client.py @@ -16,10 +16,11 @@ import inspect import os import re -import shutil import subprocess +import sys import threading from dataclasses import asdict, is_dataclass +from pathlib import Path from typing import Any, Callable, Optional, cast from .generated.session_events import session_event_from_dict @@ -48,6 +49,26 @@ ) +def _get_bundled_cli_path() -> Optional[str]: + """Get the path to the bundled CLI binary, if available.""" + # The binary is bundled in copilot/bin/ within the package + bin_dir = Path(__file__).parent / "bin" + if not bin_dir.exists(): + return None + + # Determine binary name based on platform + if sys.platform == "win32": + binary_name = "copilot.exe" + else: + binary_name = "copilot" + + binary_path = bin_dir / binary_name + if binary_path.exists(): + return str(binary_path) + + return None + + class CopilotClient: """ Main client for interacting with the Copilot CLI. @@ -130,8 +151,21 @@ def __init__(self, options: Optional[CopilotClientOptions] = None): else: self._actual_port = None - # Check environment variable for CLI path - default_cli_path = os.environ.get("COPILOT_CLI_PATH", "copilot") + # Determine CLI path: explicit option > bundled binary + # Not needed when connecting to external server via cli_url + if opts.get("cli_url"): + default_cli_path = "" # Not used for external server + elif opts.get("cli_path"): + default_cli_path = opts["cli_path"] + else: + bundled_path = _get_bundled_cli_path() + if bundled_path: + default_cli_path = bundled_path + else: + raise RuntimeError( + "Copilot CLI not found. The bundled CLI binary is not available. " + "Ensure you installed a platform-specific wheel, or provide cli_path." + ) # Default use_logged_in_user to False when github_token is provided github_token = opts.get("github_token") @@ -140,7 +174,7 @@ def __init__(self, options: Optional[CopilotClientOptions] = None): use_logged_in_user = False if github_token else True self.options: CopilotClientOptions = { - "cli_path": opts.get("cli_path", default_cli_path), + "cli_path": default_cli_path, "cwd": opts.get("cwd", os.getcwd()), "port": opts.get("port", 0), "use_stdio": False if opts.get("cli_url") else opts.get("use_stdio", True), @@ -1078,14 +1112,11 @@ async def _start_cli_server(self) -> None: """ cli_path = self.options["cli_path"] - # Resolve the full path on Windows (handles .cmd/.bat files) - # On Windows, subprocess.Popen doesn't use PATHEXT to resolve extensions, - # so we need to use shutil.which() to find the actual executable - resolved_path = shutil.which(cli_path) - if resolved_path: - cli_path = resolved_path + # Verify CLI exists + if not os.path.exists(cli_path): + raise RuntimeError(f"Copilot CLI not found at {cli_path}") - args = ["--headless", "--log-level", self.options["log_level"]] + args = ["--headless", "--no-auto-update", "--log-level", self.options["log_level"]] # Add auth-related flags if self.options.get("github_token"): diff --git a/python/copilot/generated/session_events.py b/python/copilot/generated/session_events.py index 6f0f753d4..84dff82e1 100644 --- a/python/copilot/generated/session_events.py +++ b/python/copilot/generated/session_events.py @@ -3,7 +3,7 @@ Generated from: @github/copilot/session-events.schema.json Generated by: scripts/generate-session-types.ts -Generated at: 2026-02-03T20:40:49.486Z +Generated at: 2026-02-06T20:38:23.376Z To update these types: 1. Update the schema in copilot-agent-runtime @@ -543,6 +543,7 @@ class Data: post_compaction_tokens: Optional[float] = None pre_compaction_messages_length: Optional[float] = None pre_compaction_tokens: Optional[float] = None + request_id: Optional[str] = None success: Optional[bool] = None summary_content: Optional[str] = None tokens_removed: Optional[float] = None @@ -649,6 +650,7 @@ def from_dict(obj: Any) -> 'Data': post_compaction_tokens = from_union([from_float, from_none], obj.get("postCompactionTokens")) pre_compaction_messages_length = from_union([from_float, from_none], obj.get("preCompactionMessagesLength")) pre_compaction_tokens = from_union([from_float, from_none], obj.get("preCompactionTokens")) + request_id = from_union([from_str, from_none], obj.get("requestId")) success = from_union([from_bool, from_none], obj.get("success")) summary_content = from_union([from_str, from_none], obj.get("summaryContent")) tokens_removed = from_union([from_float, from_none], obj.get("tokensRemoved")) @@ -701,7 +703,7 @@ def from_dict(obj: Any) -> 'Data': output = obj.get("output") metadata = from_union([Metadata.from_dict, from_none], obj.get("metadata")) role = from_union([Role, from_none], obj.get("role")) - return Data(context, copilot_version, producer, selected_model, session_id, start_time, version, event_count, resume_time, error_type, message, provider_call_id, stack, status_code, info_type, new_model, previous_model, handoff_time, remote_session_id, repository, source_type, summary, messages_removed_during_truncation, performed_by, post_truncation_messages_length, post_truncation_tokens_in_messages, pre_truncation_messages_length, pre_truncation_tokens_in_messages, token_limit, tokens_removed_during_truncation, events_removed, up_to_event_id, code_changes, current_model, error_reason, model_metrics, session_start_time, shutdown_type, total_api_duration_ms, total_premium_requests, current_tokens, messages_length, checkpoint_number, checkpoint_path, compaction_tokens_used, error, messages_removed, post_compaction_tokens, pre_compaction_messages_length, pre_compaction_tokens, success, summary_content, tokens_removed, attachments, content, source, transformed_content, turn_id, intent, reasoning_id, delta_content, encrypted_content, message_id, parent_tool_call_id, reasoning_opaque, reasoning_text, tool_requests, total_response_size_bytes, api_call_id, cache_read_tokens, cache_write_tokens, cost, duration, initiator, input_tokens, model, output_tokens, quota_snapshots, reason, arguments, tool_call_id, tool_name, mcp_server_name, mcp_tool_name, partial_output, progress_message, is_user_requested, result, tool_telemetry, allowed_tools, name, path, agent_description, agent_display_name, agent_name, tools, hook_invocation_id, hook_type, input, output, metadata, role) + return Data(context, copilot_version, producer, selected_model, session_id, start_time, version, event_count, resume_time, error_type, message, provider_call_id, stack, status_code, info_type, new_model, previous_model, handoff_time, remote_session_id, repository, source_type, summary, messages_removed_during_truncation, performed_by, post_truncation_messages_length, post_truncation_tokens_in_messages, pre_truncation_messages_length, pre_truncation_tokens_in_messages, token_limit, tokens_removed_during_truncation, events_removed, up_to_event_id, code_changes, current_model, error_reason, model_metrics, session_start_time, shutdown_type, total_api_duration_ms, total_premium_requests, current_tokens, messages_length, checkpoint_number, checkpoint_path, compaction_tokens_used, error, messages_removed, post_compaction_tokens, pre_compaction_messages_length, pre_compaction_tokens, request_id, success, summary_content, tokens_removed, attachments, content, source, transformed_content, turn_id, intent, reasoning_id, delta_content, encrypted_content, message_id, parent_tool_call_id, reasoning_opaque, reasoning_text, tool_requests, total_response_size_bytes, api_call_id, cache_read_tokens, cache_write_tokens, cost, duration, initiator, input_tokens, model, output_tokens, quota_snapshots, reason, arguments, tool_call_id, tool_name, mcp_server_name, mcp_tool_name, partial_output, progress_message, is_user_requested, result, tool_telemetry, allowed_tools, name, path, agent_description, agent_display_name, agent_name, tools, hook_invocation_id, hook_type, input, output, metadata, role) def to_dict(self) -> dict: result: dict = {} @@ -805,6 +807,8 @@ def to_dict(self) -> dict: result["preCompactionMessagesLength"] = from_union([to_float, from_none], self.pre_compaction_messages_length) if self.pre_compaction_tokens is not None: result["preCompactionTokens"] = from_union([to_float, from_none], self.pre_compaction_tokens) + if self.request_id is not None: + result["requestId"] = from_union([from_str, from_none], self.request_id) if self.success is not None: result["success"] = from_union([from_bool, from_none], self.success) if self.summary_content is not None: diff --git a/python/e2e/testharness/context.py b/python/e2e/testharness/context.py index e0b8ea4e8..533ee87e7 100644 --- a/python/e2e/testharness/context.py +++ b/python/e2e/testharness/context.py @@ -16,25 +16,18 @@ from .proxy import CapiProxy -def get_cli_path() -> str: - """Get CLI path from environment or try to find it. Raises if not found.""" - # Check environment variable first - cli_path = os.environ.get("COPILOT_CLI_PATH") - if cli_path and os.path.exists(cli_path): - return cli_path - +def get_cli_path_for_tests() -> str: + """Get CLI path for E2E tests. Uses node_modules CLI during development.""" # Look for CLI in sibling nodejs directory's node_modules - base_path = Path(__file__).parents[3] # equivalent to: path.parent.parent.parent.parent + base_path = Path(__file__).parents[3] full_path = base_path / "nodejs" / "node_modules" / "@github" / "copilot" / "index.js" if full_path.exists(): return str(full_path.resolve()) - raise RuntimeError( - "CLI not found. Set COPILOT_CLI_PATH or run 'npm install' in the nodejs directory." - ) + raise RuntimeError("CLI not found for tests. Run 'npm install' in the nodejs directory.") -CLI_PATH = get_cli_path() +CLI_PATH = get_cli_path_for_tests() SNAPSHOTS_DIR = Path(__file__).parents[3] / "test" / "snapshots" @@ -51,12 +44,7 @@ def __init__(self): async def setup(self): """Set up the test context with a shared client.""" - cli_path = get_cli_path() - if not cli_path or not os.path.exists(cli_path): - raise RuntimeError( - f"CLI not found at {cli_path}. Run 'npm install' in the nodejs directory first." - ) - self.cli_path = cli_path + self.cli_path = get_cli_path_for_tests() self.home_dir = tempfile.mkdtemp(prefix="copilot-test-config-") self.work_dir = tempfile.mkdtemp(prefix="copilot-test-work-") @@ -112,16 +100,18 @@ async def configure_for_test(self, test_file: str, test_name: str): await self._proxy.configure(abs_snapshot_path, self.work_dir) # Clear temp directories between tests (but leave them in place) + # Use ignore_errors=True to handle race conditions where files may still + # be written by background processes during cleanup for item in Path(self.home_dir).iterdir(): if item.is_dir(): - shutil.rmtree(item) + shutil.rmtree(item, ignore_errors=True) else: - item.unlink() + item.unlink(missing_ok=True) for item in Path(self.work_dir).iterdir(): if item.is_dir(): - shutil.rmtree(item) + shutil.rmtree(item, ignore_errors=True) else: - item.unlink() + item.unlink(missing_ok=True) def get_env(self) -> dict: """Return environment variables configured for isolated testing.""" diff --git a/python/pyproject.toml b/python/pyproject.toml index d5177af36..b902b050a 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -42,6 +42,7 @@ dev = [ "ty>=0.0.2", "pytest>=7.0.0", "pytest-asyncio>=0.21.0", + "pytest-timeout>=2.0.0", "httpx>=0.24.0", ] diff --git a/python/scripts/build-wheels.mjs b/python/scripts/build-wheels.mjs new file mode 100644 index 000000000..5dac70254 --- /dev/null +++ b/python/scripts/build-wheels.mjs @@ -0,0 +1,364 @@ +#!/usr/bin/env node +/** + * Build platform-specific Python wheels with bundled Copilot CLI binaries. + * + * Downloads the Copilot CLI binary for each platform from the npm registry + * and builds a wheel that includes it. + * + * Usage: + * node scripts/build-wheels.mjs [--platform PLATFORM] [--output-dir DIR] + * + * --platform: Build for specific platform only (linux-x64, linux-arm64, darwin-x64, + * darwin-arm64, win32-x64, win32-arm64). If not specified, builds all. + * --output-dir: Directory for output wheels (default: dist/) + */ + +import { execSync } from "node:child_process"; +import { + createWriteStream, + existsSync, + mkdirSync, + readFileSync, + writeFileSync, + chmodSync, + rmSync, + cpSync, + readdirSync, + statSync, +} from "node:fs"; +import { dirname, join } from "node:path"; +import { pipeline } from "node:stream/promises"; +import { fileURLToPath } from "node:url"; + +const __filename = fileURLToPath(import.meta.url); +const __dirname = dirname(__filename); +const pythonDir = dirname(__dirname); +const repoRoot = dirname(pythonDir); + +// Platform mappings: npm package suffix -> [wheel platform tag, binary name] +const PLATFORMS = { + "linux-x64": ["manylinux_2_17_x86_64", "copilot"], + "linux-arm64": ["manylinux_2_17_aarch64", "copilot"], + "darwin-x64": ["macosx_10_9_x86_64", "copilot"], + "darwin-arm64": ["macosx_11_0_arm64", "copilot"], + "win32-x64": ["win_amd64", "copilot.exe"], + "win32-arm64": ["win_arm64", "copilot.exe"], +}; + +function getCliVersion() { + const packageLockPath = join(repoRoot, "nodejs", "package-lock.json"); + if (!existsSync(packageLockPath)) { + throw new Error( + `package-lock.json not found at ${packageLockPath}. Run 'npm install' in nodejs/ first.` + ); + } + + const packageLock = JSON.parse(readFileSync(packageLockPath, "utf-8")); + const version = packageLock.packages?.["node_modules/@github/copilot"]?.version; + + if (!version) { + throw new Error("Could not find @github/copilot version in package-lock.json"); + } + + return version; +} + +function getPkgVersion() { + const pyprojectPath = join(pythonDir, "pyproject.toml"); + const content = readFileSync(pyprojectPath, "utf-8"); + const match = content.match(/version\s*=\s*"([^"]+)"/); + if (!match) { + throw new Error("Could not find version in pyproject.toml"); + } + return match[1]; +} + +async function downloadCliBinary(platform, cliVersion, cacheDir) { + const [, binaryName] = PLATFORMS[platform]; + const cachedBinary = join(cacheDir, binaryName); + + // Check cache + if (existsSync(cachedBinary)) { + console.log(` Using cached ${binaryName}`); + return cachedBinary; + } + + const tarballUrl = `https://registry.npmjs.org/@github/copilot-${platform}/-/copilot-${platform}-${cliVersion}.tgz`; + console.log(` Downloading from ${tarballUrl}...`); + + // Download tarball + const response = await fetch(tarballUrl); + if (!response.ok) { + throw new Error(`Failed to download: ${response.status} ${response.statusText}`); + } + + // Extract to cache dir + mkdirSync(cacheDir, { recursive: true }); + + const tarballPath = join(cacheDir, `copilot-${platform}-${cliVersion}.tgz`); + const fileStream = createWriteStream(tarballPath); + + await pipeline(response.body, fileStream); + + // Extract binary from tarball using system tar + // On Windows, use the system32 tar to avoid Git Bash tar issues + const tarCmd = process.platform === "win32" + ? `"${process.env.SystemRoot}\\System32\\tar.exe"` + : "tar"; + + try { + execSync(`${tarCmd} -xzf "${tarballPath}" -C "${cacheDir}" --strip-components=1 "package/${binaryName}"`, { + stdio: "inherit", + }); + } catch (e) { + // Clean up on failure + if (existsSync(tarballPath)) { + rmSync(tarballPath); + } + throw new Error(`Failed to extract binary: ${e.message}`); + } + + // Clean up tarball + rmSync(tarballPath); + + // Verify binary exists + if (!existsSync(cachedBinary)) { + throw new Error(`Binary not found after extraction: ${cachedBinary}`); + } + + // Make executable on Unix + if (!binaryName.endsWith(".exe")) { + chmodSync(cachedBinary, 0o755); + } + + const size = statSync(cachedBinary).size / 1024 / 1024; + console.log(` Downloaded ${binaryName} (${size.toFixed(1)} MB)`); + + return cachedBinary; +} + +function getCliLicensePath() { + // Use license from node_modules (requires npm ci in nodejs/ first) + const licensePath = join(repoRoot, "nodejs", "node_modules", "@github", "copilot", "LICENSE.md"); + if (!existsSync(licensePath)) { + throw new Error( + `CLI LICENSE.md not found at ${licensePath}. Run 'npm ci' in nodejs/ first.` + ); + } + return licensePath; +} + +async function buildWheel(platform, pkgVersion, cliVersion, outputDir, licensePath) { + const [wheelTag, binaryName] = PLATFORMS[platform]; + console.log(`\nBuilding wheel for ${platform}...`); + + // Cache directory includes version + const cacheDir = join(pythonDir, ".cli-cache", cliVersion, platform); + + // Download/get cached binary + const binaryPath = await downloadCliBinary(platform, cliVersion, cacheDir); + + // Create temp build directory + const buildDir = join(pythonDir, ".build-temp", platform); + if (existsSync(buildDir)) { + rmSync(buildDir, { recursive: true }); + } + mkdirSync(buildDir, { recursive: true }); + + // Copy package source + const pkgDir = join(buildDir, "copilot"); + cpSync(join(pythonDir, "copilot"), pkgDir, { recursive: true }); + + // Create bin directory and copy binary + const binDir = join(pkgDir, "bin"); + mkdirSync(binDir, { recursive: true }); + cpSync(binaryPath, join(binDir, binaryName)); + + // Create VERSION file + writeFileSync(join(binDir, "VERSION"), cliVersion); + + // Create __init__.py + writeFileSync(join(binDir, "__init__.py"), '"""Bundled Copilot CLI binary."""\n'); + + // Copy and modify pyproject.toml - replace license reference with file + let pyprojectContent = readFileSync(join(pythonDir, "pyproject.toml"), "utf-8"); + + // Replace the license specification with file reference + pyprojectContent = pyprojectContent.replace( + 'license = {text = "MIT"}', + 'license = {file = "CLI-LICENSE.md"}' + ); + + // Add package-data configuration + const packageDataConfig = ` +[tool.setuptools.package-data] +"copilot.bin" = ["*"] +`; + pyprojectContent = pyprojectContent.replace("\n[tool.ruff]", `${packageDataConfig}\n[tool.ruff]`); + writeFileSync(join(buildDir, "pyproject.toml"), pyprojectContent); + + // Copy README + if (existsSync(join(pythonDir, "README.md"))) { + cpSync(join(pythonDir, "README.md"), join(buildDir, "README.md")); + } + + // Copy CLI LICENSE + cpSync(licensePath, join(buildDir, "CLI-LICENSE.md")); + + // Build wheel using uv (faster and doesn't require build package to be installed) + const distDir = join(buildDir, "dist"); + execSync("uv build --wheel", { + cwd: buildDir, + stdio: "inherit", + }); + + // Find built wheel + const wheels = readdirSync(distDir).filter((f) => f.endsWith(".whl")); + if (wheels.length === 0) { + throw new Error("No wheel found after build"); + } + + const srcWheel = join(distDir, wheels[0]); + const newName = wheels[0].replace("-py3-none-any.whl", `-py3-none-${wheelTag}.whl`); + const destWheel = join(outputDir, newName); + + // Repack wheel with correct platform tag + await repackWheelWithPlatform(srcWheel, destWheel, wheelTag); + + // Clean up build dir + rmSync(buildDir, { recursive: true }); + + const size = statSync(destWheel).size / 1024 / 1024; + console.log(` Built ${newName} (${size.toFixed(1)} MB)`); + + return destWheel; +} + +async function repackWheelWithPlatform(srcWheel, destWheel, platformTag) { + // Write Python script to temp file to avoid shell escaping issues + const script = ` +import sys +import zipfile +import tempfile +from pathlib import Path + +src_wheel = Path(sys.argv[1]) +dest_wheel = Path(sys.argv[2]) +platform_tag = sys.argv[3] + +with tempfile.TemporaryDirectory() as tmpdir: + tmpdir = Path(tmpdir) + + # Extract wheel + with zipfile.ZipFile(src_wheel, 'r') as zf: + zf.extractall(tmpdir) + + # Find and update WHEEL file + wheel_info_dirs = list(tmpdir.glob('*.dist-info')) + if not wheel_info_dirs: + raise RuntimeError('No .dist-info directory found in wheel') + + wheel_info_dir = wheel_info_dirs[0] + wheel_file = wheel_info_dir / 'WHEEL' + + with open(wheel_file) as f: + wheel_content = f.read() + + wheel_content = wheel_content.replace('Tag: py3-none-any', f'Tag: py3-none-{platform_tag}') + + with open(wheel_file, 'w') as f: + f.write(wheel_content) + + # Regenerate RECORD file + record_file = wheel_info_dir / 'RECORD' + records = [] + for path in tmpdir.rglob('*'): + if path.is_file() and path.name != 'RECORD': + rel_path = path.relative_to(tmpdir) + records.append(f'{rel_path},,') + records.append(f'{wheel_info_dir.name}/RECORD,,') + + with open(record_file, 'w') as f: + f.write('\\n'.join(records)) + + # Create new wheel + dest_wheel.parent.mkdir(parents=True, exist_ok=True) + if dest_wheel.exists(): + dest_wheel.unlink() + + with zipfile.ZipFile(dest_wheel, 'w', zipfile.ZIP_DEFLATED) as zf: + for path in tmpdir.rglob('*'): + if path.is_file(): + zf.write(path, path.relative_to(tmpdir)) +`; + + // Write script to temp file + const scriptPath = join(pythonDir, ".build-temp", "repack_wheel.py"); + mkdirSync(dirname(scriptPath), { recursive: true }); + writeFileSync(scriptPath, script); + + try { + execSync(`python "${scriptPath}" "${srcWheel}" "${destWheel}" "${platformTag}"`, { + stdio: "inherit", + }); + } finally { + // Clean up script + rmSync(scriptPath); + } +} + +async function main() { + const args = process.argv.slice(2); + let platform = null; + let outputDir = join(pythonDir, "dist"); + + // Parse args + for (let i = 0; i < args.length; i++) { + if (args[i] === "--platform" && args[i + 1]) { + platform = args[++i]; + if (!PLATFORMS[platform]) { + console.error(`Invalid platform: ${platform}`); + console.error(`Valid platforms: ${Object.keys(PLATFORMS).join(", ")}`); + process.exit(1); + } + } else if (args[i] === "--output-dir" && args[i + 1]) { + outputDir = args[++i]; + } + } + + const cliVersion = getCliVersion(); + const pkgVersion = getPkgVersion(); + + console.log(`CLI version: ${cliVersion}`); + console.log(`Package version: ${pkgVersion}`); + + mkdirSync(outputDir, { recursive: true }); + + // Get CLI license from node_modules + const licensePath = getCliLicensePath(); + + const platforms = platform ? [platform] : Object.keys(PLATFORMS); + const wheels = []; + + for (const p of platforms) { + try { + const wheel = await buildWheel(p, pkgVersion, cliVersion, outputDir, licensePath); + wheels.push(wheel); + } catch (e) { + console.error(`Error building wheel for ${p}:`, e.message); + if (platform) { + process.exit(1); + } + } + } + + console.log(`\nBuilt ${wheels.length} wheel(s):`); + for (const wheel of wheels) { + console.log(` ${wheel}`); + } +} + +main().catch((e) => { + console.error(e); + process.exit(1); +}); diff --git a/python/test_client.py b/python/test_client.py index 3823c86b0..7b4af8c0f 100644 --- a/python/test_client.py +++ b/python/test_client.py @@ -4,8 +4,6 @@ This file is for unit tests. Where relevant, prefer to add e2e tests in e2e/*.py instead. """ -from unittest.mock import MagicMock, patch - import pytest from copilot import CopilotClient @@ -98,25 +96,36 @@ def test_is_external_server_true(self): class TestAuthOptions: def test_accepts_github_token(self): - client = CopilotClient({"github_token": "gho_test_token", "log_level": "error"}) + client = CopilotClient( + {"cli_path": CLI_PATH, "github_token": "gho_test_token", "log_level": "error"} + ) assert client.options.get("github_token") == "gho_test_token" def test_default_use_logged_in_user_true_without_token(self): - client = CopilotClient({"log_level": "error"}) + client = CopilotClient({"cli_path": CLI_PATH, "log_level": "error"}) assert client.options.get("use_logged_in_user") is True def test_default_use_logged_in_user_false_with_token(self): - client = CopilotClient({"github_token": "gho_test_token", "log_level": "error"}) + client = CopilotClient( + {"cli_path": CLI_PATH, "github_token": "gho_test_token", "log_level": "error"} + ) assert client.options.get("use_logged_in_user") is False def test_explicit_use_logged_in_user_true_with_token(self): client = CopilotClient( - {"github_token": "gho_test_token", "use_logged_in_user": True, "log_level": "error"} + { + "cli_path": CLI_PATH, + "github_token": "gho_test_token", + "use_logged_in_user": True, + "log_level": "error", + } ) assert client.options.get("use_logged_in_user") is True def test_explicit_use_logged_in_user_false_without_token(self): - client = CopilotClient({"use_logged_in_user": False, "log_level": "error"}) + client = CopilotClient( + {"cli_path": CLI_PATH, "use_logged_in_user": False, "log_level": "error"} + ) assert client.options.get("use_logged_in_user") is False def test_github_token_with_cli_url_raises(self): @@ -138,62 +147,3 @@ def test_use_logged_in_user_with_cli_url_raises(self): CopilotClient( {"cli_url": "localhost:8080", "use_logged_in_user": False, "log_level": "error"} ) - - -class TestCLIPathResolution: - """Test that CLI path resolution works correctly, especially on Windows.""" - - @pytest.mark.asyncio - async def test_cli_path_resolved_with_which(self): - """Test that shutil.which() is used to resolve the CLI path.""" - # Create a mock resolved path - mock_resolved_path = "/usr/local/bin/copilot" - - with patch("copilot.client.shutil.which", return_value=mock_resolved_path): - with patch("copilot.client.subprocess.Popen") as mock_popen: - # Mock the process and its stdout for TCP mode - mock_process = MagicMock() - mock_process.stdout.readline.return_value = b"listening on port 8080\n" - mock_popen.return_value = mock_process - - client = CopilotClient( - {"cli_path": "copilot", "use_stdio": False, "log_level": "error"} - ) - - try: - await client._start_cli_server() - - # Verify that subprocess.Popen was called with the resolved path - mock_popen.assert_called_once() - args = mock_popen.call_args[0][0] - assert args[0] == mock_resolved_path - finally: - if client._process: - client._process = None - - @pytest.mark.asyncio - async def test_cli_path_not_resolved_when_which_returns_none(self): - """Test that original path is used when shutil.which() returns None.""" - original_path = "/custom/path/to/copilot" - - with patch("copilot.client.shutil.which", return_value=None): - with patch("copilot.client.subprocess.Popen") as mock_popen: - # Mock the process and its stdout for TCP mode - mock_process = MagicMock() - mock_process.stdout.readline.return_value = b"listening on port 8080\n" - mock_popen.return_value = mock_process - - client = CopilotClient( - {"cli_path": original_path, "use_stdio": False, "log_level": "error"} - ) - - try: - await client._start_cli_server() - - # Verify that subprocess.Popen was called with the original path - mock_popen.assert_called_once() - args = mock_popen.call_args[0][0] - assert args[0] == original_path - finally: - if client._process: - client._process = None diff --git a/python/uv.lock b/python/uv.lock index 8208e3847..35134a0b0 100644 --- a/python/uv.lock +++ b/python/uv.lock @@ -1,5 +1,5 @@ version = 1 -revision = 3 +revision = 2 requires-python = ">=3.9" resolution-markers = [ "python_full_version >= '3.10'", @@ -85,6 +85,7 @@ dev = [ { name = "pytest", version = "9.0.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.10'" }, { name = "pytest-asyncio", version = "1.2.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.10'" }, { name = "pytest-asyncio", version = "1.3.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.10'" }, + { name = "pytest-timeout" }, { name = "ruff" }, { name = "ty" }, ] @@ -95,6 +96,7 @@ requires-dist = [ { name = "pydantic", specifier = ">=2.0" }, { name = "pytest", marker = "extra == 'dev'", specifier = ">=7.0.0" }, { name = "pytest-asyncio", marker = "extra == 'dev'", specifier = ">=0.21.0" }, + { name = "pytest-timeout", marker = "extra == 'dev'", specifier = ">=2.0.0" }, { name = "python-dateutil", specifier = ">=2.9.0.post0" }, { name = "ruff", marker = "extra == 'dev'", specifier = ">=0.1.0" }, { name = "ty", marker = "extra == 'dev'", specifier = ">=0.0.2" }, @@ -421,6 +423,19 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/e5/35/f8b19922b6a25bc0880171a2f1a003eaeb93657475193ab516fd87cac9da/pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5", size = 15075, upload-time = "2025-11-10T16:07:45.537Z" }, ] +[[package]] +name = "pytest-timeout" +version = "2.4.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pytest", version = "8.4.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.10'" }, + { name = "pytest", version = "9.0.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.10'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/ac/82/4c9ecabab13363e72d880f2fb504c5f750433b2b6f16e99f4ec21ada284c/pytest_timeout-2.4.0.tar.gz", hash = "sha256:7e68e90b01f9eff71332b25001f85c75495fc4e3a836701876183c4bcfd0540a", size = 17973, upload-time = "2025-05-05T19:44:34.99Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/fa/b6/3127540ecdf1464a00e5a01ee60a1b09175f6913f0644ac748494d9c4b21/pytest_timeout-2.4.0-py3-none-any.whl", hash = "sha256:c42667e5cdadb151aeb5b26d114aff6bdf5a907f176a007a30b940d3d865b5c2", size = 14382, upload-time = "2025-05-05T19:44:33.502Z" }, +] + [[package]] name = "python-dateutil" version = "2.9.0.post0"