diff --git a/go.mod b/go.mod index fcb3d98588b54..ed0687352720a 100644 --- a/go.mod +++ b/go.mod @@ -203,6 +203,8 @@ require ( github.com/spiffe/aws-spiffe-workload-helper v0.0.1-rc.8 github.com/spiffe/go-spiffe/v2 v2.5.0 github.com/stretchr/testify v1.10.0 + github.com/tidwall/pretty v1.2.0 + github.com/tidwall/sjson v1.2.5 github.com/ucarion/urlpath v0.0.0-20200424170820-7ccc79b76bbb github.com/vulcand/predicate v1.2.0 // replaced github.com/yusufpapurcu/wmi v1.2.4 @@ -531,6 +533,8 @@ require ( github.com/thales-e-security/pool v0.0.2 // indirect github.com/theupdateframework/go-tuf v0.7.0 // indirect github.com/theupdateframework/go-tuf/v2 v2.0.2 // indirect + github.com/tidwall/gjson v1.14.2 // indirect + github.com/tidwall/match v1.1.1 // indirect github.com/titanous/rocacheck v0.0.0-20171023193734-afe73141d399 // indirect github.com/tklauser/go-sysconf v0.3.12 // indirect github.com/tklauser/numcpus v0.6.1 // indirect diff --git a/go.sum b/go.sum index 642e1da25ed92..63de0834330c9 100644 --- a/go.sum +++ b/go.sum @@ -2227,7 +2227,15 @@ github.com/theupdateframework/go-tuf v0.7.0 h1:CqbQFrWo1ae3/I0UCblSbczevCCbS31Qv github.com/theupdateframework/go-tuf v0.7.0/go.mod h1:uEB7WSY+7ZIugK6R1hiBMBjQftaFzn7ZCDJcp1tCUug= github.com/theupdateframework/go-tuf/v2 v2.0.2 h1:PyNnjV9BJNzN1ZE6BcWK+5JbF+if370jjzO84SS+Ebo= github.com/theupdateframework/go-tuf/v2 v2.0.2/go.mod h1:baB22nBHeHBCeuGZcIlctNq4P61PcOdyARlplg5xmLA= +github.com/tidwall/gjson v1.14.2 h1:6BBkirS0rAHjumnjHF6qgy5d2YAJ1TLIaFE2lzfOLqo= +github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= +github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA= +github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM= github.com/tidwall/pretty v1.0.0/go.mod h1:XNkn88O1ChpSDQmQeStsy+sBenx6DDtFZJxhVysOjyk= +github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs= +github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= +github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= +github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= github.com/tink-crypto/tink-go-awskms/v2 v2.1.0 h1:N9UxlsOzu5mttdjhxkDLbzwtEecuXmlxZVo/ds7JKJI= github.com/tink-crypto/tink-go-awskms/v2 v2.1.0/go.mod h1:PxSp9GlOkKL9rlybW804uspnHuO9nbD98V/fDX4uSis= github.com/tink-crypto/tink-go-gcpkms/v2 v2.2.0 h1:3B9i6XBXNTRspfkTC0asN5W0K6GhOSgcujNiECNRNb0= diff --git a/lib/client/mcp/claude/config.go b/lib/client/mcp/claude/config.go new file mode 100644 index 0000000000000..4e546deb13e0c --- /dev/null +++ b/lib/client/mcp/claude/config.go @@ -0,0 +1,273 @@ +/* + * Teleport + * Copyright (C) 2025 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package claude + +import ( + "bytes" + "encoding/json" + "io" + "maps" + "os" + "path/filepath" + "runtime" + + "github.com/gravitational/trace" + "github.com/tidwall/pretty" + "github.com/tidwall/sjson" +) + +// DefaultConfigPath returns the default path for the Claude Desktop config. +// +// https://modelcontextprotocol.io/quickstart/user +// +// macOS: ~/Library/Application Support/Claude/claude_desktop_config.json +// Windows: %APPDATA%\Claude\claude_desktop_config.json +func DefaultConfigPath() (string, error) { + switch runtime.GOOS { + case "darwin", "windows": + // os.UserConfigDir: + // On Darwin, it returns $HOME/Library/Application Support. + // On Windows, it returns %AppData%. + configDir, err := os.UserConfigDir() + if err != nil { + return "", trace.ConvertSystemError(err) + } + return filepath.Join(configDir, "Claude", "claude_desktop_config.json"), nil + + default: + // TODO(greedy52) there is no official Claude Desktop for linux yet. The + // unofficial one uses the same path as above. + return "", trace.NotImplemented("Claude Desktop is not supported on OS %s", runtime.GOOS) + } +} + +// MCPServer contains details to launch an MCP server. +// +// https://modelcontextprotocol.io/quickstart/user +type MCPServer struct { + // Command specifies the command to execute. + Command string `json:"command"` + // Args specifies the arguments for the command. + Args []string `json:"args,omitempty"` + // Envs specifies extra environment variable. + Envs map[string]string `json:"env,omitempty"` +} + +// Config represents a Claude Desktop config. +// +// Config preserves unknown fields and ordering from the original JSON when +// saving, by using the sjson lib. +// +// Config functions are not thread-safe. +type Config struct { + mcpServers map[string]MCPServer + configData []byte + isOriginalJSONCompact bool +} + +// NewConfig creates an empty config. +func NewConfig() *Config { + return &Config{ + mcpServers: make(map[string]MCPServer), + configData: []byte("{}"), + isOriginalJSONCompact: false, + } +} + +// NewConfigFromJSON creates a config from JSON. +func NewConfigFromJSON(data []byte) (*Config, error) { + config := struct { + MCPServers map[string]MCPServer `json:"mcpServers"` + }{} + if err := json.Unmarshal(data, &config); err != nil { + return nil, trace.Wrap(err, "parsing Claude Desktop config") + } + + if config.MCPServers == nil { + config.MCPServers = map[string]MCPServer{} + } + isOriginalJSONCompact, err := isJSONCompact(data) + if err != nil { + return nil, trace.Wrap(err, "parsing Claude Desktop config") + } + + return &Config{ + mcpServers: config.MCPServers, + configData: data, + isOriginalJSONCompact: isOriginalJSONCompact, + }, nil +} + +// GetMCPServers returns a shallow copy of the MCP servers. +func (c *Config) GetMCPServers() map[string]MCPServer { + return maps.Clone(c.mcpServers) +} + +// PutMCPServer adds a new MCP server or replace an existing one. +func (c *Config) PutMCPServer(serverName string, server MCPServer) (err error) { + c.mcpServers[serverName] = server + c.configData, err = sjson.SetBytes(c.configData, c.mcpServerJSONPath(serverName), server) + return trace.Wrap(err) +} + +// RemoveMCPServer removes an MCP server by name. +func (c *Config) RemoveMCPServer(serverName string) (err error) { + if _, ok := c.mcpServers[serverName]; !ok { + return trace.NotFound("mcp server %v not found", serverName) + } + + delete(c.mcpServers, serverName) + c.configData, err = sjson.DeleteBytes(c.configData, c.mcpServerJSONPath(serverName)) + return trace.Wrap(err) +} + +// FormatJSONOption specifies the option on how to format the JSON output. +type FormatJSONOption string + +const ( + // FormatJSONPretty prettifies the JSON output. + FormatJSONPretty FormatJSONOption = "pretty" + // FormatJSONCompact minifies the JSON output. + FormatJSONCompact FormatJSONOption = "compact" + // FormatJSONNone skips formatting. + FormatJSONNone FormatJSONOption = "none" + // FormatJSONAuto minifies the JSON output if the original JSON is already + // minified. Otherwise, the JSON output is prettified. If the original JSON + // is "{}", the JSON output is also prettified. + FormatJSONAuto FormatJSONOption = "auto" +) + +// Write writes the config to provided writer. +func (c *Config) Write(w io.Writer, format FormatJSONOption) error { + data, err := c.formatConfigData(format) + if err != nil { + return trace.Wrap(err) + } + _, err = w.Write(data) + return trace.Wrap(err) +} + +// FileConfig represents a Config read from a file. +// +// Note that outside changes to the config file after LoadConfigFromFile will be +// ignored when saving. +type FileConfig struct { + *Config + configPath string + configExists bool +} + +// LoadConfigFromFile loads the Claude Desktop's config from the provided path. +func LoadConfigFromFile(configPath string) (*FileConfig, error) { + data, err := os.ReadFile(configPath) + switch { + case os.IsNotExist(err): + return &FileConfig{ + Config: NewConfig(), + configPath: configPath, + configExists: false, + }, nil + + case err != nil: + return nil, trace.Wrap(trace.ConvertSystemError(err), "reading Claude Desktop config") + + default: + config, err := NewConfigFromJSON(data) + if err != nil { + return nil, trace.Wrap(err) + } + + return &FileConfig{ + Config: config, + configPath: configPath, + configExists: true, + }, nil + } +} + +// LoadConfigFromDefaultPath loads the Claude Desktop's config from the default +// path. +func LoadConfigFromDefaultPath() (*FileConfig, error) { + configPath, err := DefaultConfigPath() + if err != nil { + return nil, trace.Wrap(err, "finding Claude Desktop config path") + } + config, err := LoadConfigFromFile(configPath) + return config, trace.Wrap(err) +} + +// Exists returns true if config file exists. +func (c *FileConfig) Exists() bool { + return c.configExists +} + +// Save saves the updated config to the config path. Format defaults to "auto" +// if empty. +func (c *FileConfig) Save(format FormatJSONOption) error { + // Claude Desktop creates the config with 0644. + file, err := os.OpenFile(c.configPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644) + if err != nil { + return trace.ConvertSystemError(err) + } + defer file.Close() + return trace.Wrap(c.Write(file, format)) +} + +func (c *Config) mcpServerJSONPath(serverName string) string { + return "mcpServers." + serverName +} + +func (c *Config) formatConfigData(format FormatJSONOption) ([]byte, error) { + return formatJSON(c.configData, format, c.isOriginalJSONCompact) +} + +func formatJSON(data []byte, format FormatJSONOption, isOriginalCompact bool) ([]byte, error) { + switch format { + case FormatJSONPretty: + // pretty.Pretty is more human-readable than json.Indent. + return pretty.Pretty(data), nil + case FormatJSONCompact: + return pretty.Ugly(data), nil + case FormatJSONNone: + return data, nil + case FormatJSONAuto, "": + if isOriginalCompact { + return pretty.Ugly(data), nil + } + return pretty.Pretty(data), nil + default: + return nil, trace.BadParameter("invalid JSON format option %q", format) + } +} + +func isJSONCompact(data []byte) (bool, error) { + data = bytes.TrimSpace(data) + + // Do not treat empty object as compact. + if bytes.Equal(data, []byte("{}")) { + return false, nil + } + + var buf bytes.Buffer + err := json.Compact(&buf, data) + if err != nil { + return false, trace.Wrap(err) + } + return bytes.Equal(buf.Bytes(), data), nil +} diff --git a/lib/client/mcp/claude/config_test.go b/lib/client/mcp/claude/config_test.go new file mode 100644 index 0000000000000..c6647d46d7713 --- /dev/null +++ b/lib/client/mcp/claude/config_test.go @@ -0,0 +1,253 @@ +/* + * Teleport + * Copyright (C) 2025 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package claude + +import ( + "bytes" + "os" + "path/filepath" + "testing" + + "github.com/gravitational/trace" + "github.com/stretchr/testify/require" +) + +func TestFileConfig_fileNotExists(t *testing.T) { + dir := t.TempDir() + configPath := filepath.Join(dir, "config.json") + + config, err := LoadConfigFromFile(configPath) + require.NoError(t, err) + require.NotNil(t, config) + require.False(t, config.Exists()) + + require.NoError(t, config.PutMCPServer("test", MCPServer{ + Command: "command", + })) + require.NoError(t, config.Save(FormatJSONCompact)) + requireFileWithData(t, configPath, `{"mcpServers":{"test":{"command":"command"}}}`) +} + +func TestFileConfig_sampleFile(t *testing.T) { + const sampleConfigJSON = `{ + "someUnknownField": "someUnknownValue", + "mcpServers": { + "Puppeteer": { + "command": "npx", + "args": ["-y", "@modelcontextprotocol/server-puppeteer"], + "someUnknownField": "someUnknownValue" + }, + "teleport-my-mcp": { + "command": "tsh", + "args": ["mcp", "connect", "my-mcp"], + "env": { + "TELEPORT_HOME": "/tsh-home/" + } + } + } +} +` + var sampleMCPServers = map[string]MCPServer{ + "Puppeteer": { + Command: "npx", + Args: []string{"-y", "@modelcontextprotocol/server-puppeteer"}, + }, + "teleport-my-mcp": { + Command: "tsh", + Args: []string{"mcp", "connect", "my-mcp"}, + Envs: map[string]string{ + "TELEPORT_HOME": "/tsh-home/", + }, + }, + } + + dir := t.TempDir() + configPath := filepath.Join(dir, "config.json") + require.NoError(t, os.WriteFile(configPath, []byte(sampleConfigJSON), 0600)) + + // load + config, err := LoadConfigFromFile(configPath) + require.NoError(t, err) + require.NotNil(t, config) + require.True(t, config.Exists()) + require.Equal(t, sampleMCPServers, config.GetMCPServers()) + + // remove + require.True(t, trace.IsNotFound(config.RemoveMCPServer("not-found"))) + require.NoError(t, config.RemoveMCPServer("teleport-my-mcp")) + require.NoError(t, config.Save(FormatJSONPretty)) + requireFileWithData(t, configPath, `{ + "someUnknownField": "someUnknownValue", + "mcpServers": { + "Puppeteer": { + "command": "npx", + "args": ["-y", "@modelcontextprotocol/server-puppeteer"], + "someUnknownField": "someUnknownValue" + } + } +} +`) + + // add it back + require.NoError(t, config.PutMCPServer("teleport-my-mcp", sampleMCPServers["teleport-my-mcp"])) + require.NoError(t, config.Save(FormatJSONAuto)) + requireFileWithData(t, configPath, sampleConfigJSON) + + // replace + require.NoError(t, config.PutMCPServer("Puppeteer", MCPServer{ + Command: "custom-script", + })) + require.NoError(t, config.Save("")) + requireFileWithData(t, configPath, `{ + "someUnknownField": "someUnknownValue", + "mcpServers": { + "Puppeteer": { + "command": "custom-script" + }, + "teleport-my-mcp": { + "command": "tsh", + "args": ["mcp", "connect", "my-mcp"], + "env": { + "TELEPORT_HOME": "/tsh-home/" + } + } + } +} +`) +} + +func TestConfig_Write(t *testing.T) { + config := NewConfig() + + require.NoError(t, config.PutMCPServer("test", MCPServer{ + Command: "command", + })) + var buf bytes.Buffer + + require.NoError(t, config.Write(&buf, FormatJSONCompact)) + require.Equal(t, `{"mcpServers":{"test":{"command":"command"}}}`, buf.String()) +} + +func Test_isJSONCompact(t *testing.T) { + tests := []struct { + name string + in string + checkError require.ErrorAssertionFunc + checkIsCompact require.BoolAssertionFunc + }{ + { + name: "bad JSON", + in: "{", + checkError: require.Error, + checkIsCompact: require.False, + }, + { + name: "empty object", + in: "{}", + checkError: require.NoError, + checkIsCompact: require.False, + }, + { + name: "compact", + in: `{"a":"b"}`, + checkError: require.NoError, + checkIsCompact: require.True, + }, + { + name: "not compact", + in: `{ + "a": "b" +}`, + checkError: require.NoError, + checkIsCompact: require.False, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + isCompact, err := isJSONCompact([]byte(tt.in)) + tt.checkError(t, err) + tt.checkIsCompact(t, isCompact) + }) + } +} + +func Test_formatJSON(t *testing.T) { + notFormatted := `{"a": "b"}` + compact := `{"a":"b"}` + pretty := `{ + "a": "b" +} +` + tests := []struct { + name string + in string + format FormatJSONOption + isOriginalCompact bool + out string + }{ + { + name: "to compact", + in: notFormatted, + format: FormatJSONCompact, + out: compact, + }, + { + name: "to pretty", + in: notFormatted, + format: FormatJSONPretty, + out: pretty, + }, + { + name: "none", + in: notFormatted, + format: FormatJSONNone, + out: notFormatted, + }, + { + name: "auto compact", + in: notFormatted, + format: FormatJSONAuto, + isOriginalCompact: true, + out: compact, + }, + { + name: "auto pretty", + in: notFormatted, + format: FormatJSONAuto, + isOriginalCompact: false, + out: pretty, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + formatted, err := formatJSON([]byte(tt.in), tt.format, tt.isOriginalCompact) + require.NoError(t, err) + require.Equal(t, tt.out, string(formatted)) + }) + } +} + +func requireFileWithData(t *testing.T, path string, want string) { + t.Helper() + read, err := os.ReadFile(path) + require.NoError(t, err) + require.Equal(t, want, string(read)) +}