diff --git a/build.assets/charts/Dockerfile-distroless b/build.assets/charts/Dockerfile-distroless index 08f26ffb0505c..75be2b86a437d 100644 --- a/build.assets/charts/Dockerfile-distroless +++ b/build.assets/charts/Dockerfile-distroless @@ -33,4 +33,5 @@ FROM $BASE_IMAGE COPY --from=teleport /opt/staging / COPY --from=staging /opt/staging/root / COPY --from=staging /opt/staging/status /var/lib/dpkg/status.d +ENV TELEPORT_TOOLS_VERSION=off ENTRYPOINT ["/usr/bin/dumb-init", "/usr/local/bin/teleport", "start", "-c", "/etc/teleport/teleport.yaml"] diff --git a/integration/autoupdate/tools/updater_test.go b/integration/autoupdate/tools/updater_test.go index 21f2962d75768..0ef2eabb4aa41 100644 --- a/integration/autoupdate/tools/updater_test.go +++ b/integration/autoupdate/tools/updater_test.go @@ -24,7 +24,6 @@ import ( "fmt" "os" "os/exec" - "path/filepath" "regexp" "strings" "testing" @@ -60,8 +59,13 @@ func TestUpdate(t *testing.T) { err := updater.Update(ctx, testVersions[0]) require.NoError(t, err) + tshPath, err := updater.ToolPath("tsh", testVersions[0]) + require.NoError(t, err) + tctlPath, err := updater.ToolPath("tctl", testVersions[0]) + require.NoError(t, err) + // Verify that the installed version is equal to requested one. - cmd := exec.CommandContext(ctx, filepath.Join(toolsDir, "tctl"), "version") + cmd := exec.CommandContext(ctx, tctlPath, "version") out, err := cmd.Output() require.NoError(t, err) @@ -71,7 +75,7 @@ func TestUpdate(t *testing.T) { // Execute version command again with setting the new version which must // trigger re-execution of the same command after downloading requested version. - cmd = exec.CommandContext(ctx, filepath.Join(toolsDir, "tsh"), "version") + cmd = exec.CommandContext(ctx, tshPath, "version") cmd.Env = append( os.Environ(), fmt.Sprintf("%s=%s", teleportToolsVersion, testVersions[1]), @@ -101,6 +105,9 @@ func TestParallelUpdate(t *testing.T) { err := updater.Update(ctx, testVersions[0]) require.NoError(t, err) + tshPath, err := updater.ToolPath("tsh", testVersions[0]) + require.NoError(t, err) + // By setting the limit request next test http serving file going blocked until unlock is sent. lock := make(chan struct{}) limitedWriter.SetLimitRequest(limitRequest{ @@ -110,8 +117,8 @@ func TestParallelUpdate(t *testing.T) { outputs := make([]bytes.Buffer, 3) errChan := make(chan error, 3) - for i := 0; i < len(outputs); i++ { - cmd := exec.Command(filepath.Join(toolsDir, "tsh"), "version") + for i := range outputs { + cmd := exec.Command(tshPath, "version") cmd.Stdout = &outputs[i] cmd.Stderr = &outputs[i] cmd.Env = append( @@ -139,7 +146,7 @@ func TestParallelUpdate(t *testing.T) { // Wait till process finished with exit code 0, but we still should get progress // bar in output content. - for i := 0; i < cap(outputs); i++ { + for range cap(outputs) { select { case <-time.After(5 * time.Second): require.Fail(t, "failed to wait till the process is finished") @@ -149,7 +156,7 @@ func TestParallelUpdate(t *testing.T) { } var progressCount int - for i := 0; i < cap(outputs); i++ { + for i := range cap(outputs) { matches := pattern.FindStringSubmatch(outputs[i].String()) require.Len(t, matches, 2) assert.Equal(t, testVersions[1], matches[1]) @@ -173,9 +180,11 @@ func TestUpdateInterruptSignal(t *testing.T) { ) err := updater.Update(ctx, testVersions[0]) require.NoError(t, err) + tshPath, err := updater.ToolPath("tsh", testVersions[0]) + require.NoError(t, err) var output bytes.Buffer - cmd := exec.Command(filepath.Join(toolsDir, "tsh"), "version") + cmd := exec.Command(tshPath, "version") cmd.Stdout = &output cmd.Stderr = &output cmd.Env = append( @@ -237,9 +246,11 @@ func TestUpdateForOSSBuild(t *testing.T) { ) err := updater.Update(ctx, testVersions[0]) require.NoError(t, err) + tshPath, err := updater.ToolPath("tsh", testVersions[0]) + require.NoError(t, err) // Verify that requested update is ignored by OSS build and version wasn't updated. - cmd := exec.CommandContext(ctx, filepath.Join(toolsDir, "tsh"), "version") + cmd := exec.CommandContext(ctx, tshPath, "version") cmd.Env = append( os.Environ(), fmt.Sprintf("%s=%s", teleportToolsVersion, testVersions[1]), @@ -253,7 +264,7 @@ func TestUpdateForOSSBuild(t *testing.T) { // Next update is set with the base URL env variable, must download new version. t.Setenv(autoupdate.BaseURLEnvVar, baseURL) - cmd = exec.CommandContext(ctx, filepath.Join(toolsDir, "tsh"), "version") + cmd = exec.CommandContext(ctx, tshPath, "version") cmd.Env = append( os.Environ(), fmt.Sprintf("%s=%s", teleportToolsVersion, testVersions[1]), diff --git a/integration/autoupdate/tools/updater_tsh_test.go b/integration/autoupdate/tools/updater_tsh_test.go index 17816e34e4458..6e56cf168dc7a 100644 --- a/integration/autoupdate/tools/updater_tsh_test.go +++ b/integration/autoupdate/tools/updater_tsh_test.go @@ -36,8 +36,10 @@ import ( "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/api/types/autoupdate" "github.com/gravitational/teleport/integration/autoupdate/tools/updater" + "github.com/gravitational/teleport/lib/auth" "github.com/gravitational/teleport/lib/autoupdate/tools" "github.com/gravitational/teleport/lib/client" + "github.com/gravitational/teleport/lib/service" "github.com/gravitational/teleport/lib/utils" testserver "github.com/gravitational/teleport/tool/teleport/testenv" ) @@ -45,92 +47,54 @@ import ( // TestAliasLoginWithUpdater runs test cluster with enabled auto updates for client tools, // checks that defined alias in tsh configuration is replaced to the proper login command // and after auto update this not leads to recursive alias re-execution. +// +// # Managed updates: enabled. +// $ tsh loginByAlias +// $ tctl status +// $ tsh version +// Teleport v3.2.1 func TestAliasLoginWithUpdater(t *testing.T) { ctx := context.Background() - homeDir := filepath.Join(t.TempDir(), "home") - require.NoError(t, os.MkdirAll(homeDir, 0700)) - installDir := filepath.Join(t.TempDir(), "local") - require.NoError(t, os.MkdirAll(installDir, 0700)) - - t.Setenv(types.HomeEnvVar, homeDir) - - alice, err := types.NewUser("alice") - require.NoError(t, err) - alice.SetRoles([]string{"access"}) - - // Enable client tools auto updates and set the target version. - config, err := autoupdate.NewAutoUpdateConfig(&autoupdatev1pb.AutoUpdateConfigSpec{ - Tools: &autoupdatev1pb.AutoUpdateConfigSpecTools{ - Mode: autoupdate.ToolsUpdateModeEnabled, - }, - }) - require.NoError(t, err) - version, err := autoupdate.NewAutoUpdateVersion(&autoupdatev1pb.AutoUpdateVersionSpec{ - Tools: &autoupdatev1pb.AutoUpdateVersionSpecTools{ - TargetVersion: testVersions[1], // [v3.2.1] - }, - }) - require.NoError(t, err) + rootServer, homeDir, installDir := bootstrapTestServer(t) + setupManagedUpdates(t, rootServer.GetAuthServer(), autoupdate.ToolsUpdateModeEnabled, testVersions[1]) - // Disable 2fa to simplify login for test. - ap, err := types.NewAuthPreferenceFromConfigFile(types.AuthPreferenceSpecV2{ - Type: constants.Local, - SecondFactor: constants.SecondFactorOff, - Webauthn: &types.Webauthn{ - RPID: "localhost", - }, - }) + // Assign alias to the login command for test cluster. + proxyAddr, err := rootServer.ProxyWebAddr() require.NoError(t, err) - rootServer := testserver.MakeTestServer(t, - testserver.WithBootstrap(alice), - testserver.WithClusterName(t, "root"), - testserver.WithAuthPreference(ap), - ) - authService := rootServer.GetAuthServer() - _, err = authService.UpsertAutoUpdateConfig(ctx, config) - require.NoError(t, err) - _, err = authService.UpsertAutoUpdateVersion(ctx, version) - require.NoError(t, err) - password, err := utils.CryptoRandomHex(6) + // Fetch compiled test binary and install to tools dir [v1.2.3]. + updater := tools.NewUpdater(installDir, testVersions[0], tools.WithBaseURL(baseURL)) + require.NoError(t, updater.Update(ctx, testVersions[0])) + tshPath, err := updater.ToolPath("tsh", testVersions[0]) require.NoError(t, err) - t.Setenv(updater.TestPassword, password) - err = authService.UpsertPassword("alice", []byte(password)) + tctlPath, err := updater.ToolPath("tctl", testVersions[0]) require.NoError(t, err) - // Assign alias to the login command for test cluster. - proxyAddr, err := rootServer.ProxyWebAddr() - require.NoError(t, err) configPath := filepath.Join(homeDir, client.TSHConfigPath) require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0700)) - executable := filepath.Join(installDir, "tsh") out, err := yaml.Marshal(client.TSHConfig{ Aliases: map[string]string{ "loginalice": fmt.Sprintf( "%s login --insecure --proxy %s --user alice --auth %s", - executable, proxyAddr, constants.LocalConnector, + tshPath, proxyAddr, constants.LocalConnector, ), }, }) require.NoError(t, err) require.NoError(t, os.WriteFile(configPath, out, 0600)) - // Fetch compiled test binary and install to tools dir [v1.2.3]. - err = tools.NewUpdater(installDir, testVersions[0], tools.WithBaseURL(baseURL)).Update(ctx, testVersions[0]) - require.NoError(t, err) - // Execute alias command which must be transformed to the login command. // Since client tools autoupdates is enabled and target version is set // in the test cluster, we have to update client tools to new version. - cmd := exec.CommandContext(ctx, executable, "loginalice") + cmd := exec.CommandContext(ctx, tshPath, "loginalice") cmd.Env = os.Environ() cmd.Stdout = os.Stdout cmd.Stderr = os.Stderr require.NoError(t, cmd.Run()) // Verify tctl status after login. - cmd = exec.CommandContext(ctx, filepath.Join(installDir, "tctl"), "status", "--insecure") + cmd = exec.CommandContext(ctx, tctlPath, "status", "--insecure") cmd.Env = os.Environ() cmd.Stdout = os.Stdout cmd.Stderr = os.Stderr @@ -138,10 +102,9 @@ func TestAliasLoginWithUpdater(t *testing.T) { // Run version command to verify that login command executed auto update and // tsh was upgraded to [v3.2.1]. - cmd = exec.CommandContext(ctx, executable, "version") + cmd = exec.CommandContext(ctx, tshPath, "version") out, err = cmd.Output() require.NoError(t, err) - matches := pattern.FindStringSubmatch(string(out)) require.Len(t, matches, 2) require.Equal(t, testVersions[1], matches[1]) @@ -149,3 +112,223 @@ func TestAliasLoginWithUpdater(t *testing.T) { // Verifies that version commands shows version re-executed from. require.Contains(t, string(out), fmt.Sprintf("Re-executed from version: %s", testVersions[0])) } + +// TestLoginWithUpdaterAndProfile runs test cluster with disabled managed updates for client tools, +// verifies that if we set env variable during login we keep using updated version. +// +// # Managed updates: disabled. +// $ TELEPORT_TOOLS_VERSION=3.2.1 tsh login --proxy proxy.example.com +// # Check that created profile after login has enabled autoupdates flag. +// $ tsh version +// Teleport v3.2.1 +func TestLoginWithUpdaterAndProfile(t *testing.T) { + ctx := context.Background() + + rootServer, _, installDir := bootstrapTestServer(t) + setupManagedUpdates(t, rootServer.GetAuthServer(), autoupdate.ToolsUpdateModeDisabled, testVersions[1]) + + proxyAddr, err := rootServer.ProxyWebAddr() + require.NoError(t, err) + + // Fetch compiled test binary and install to tools dir [v1.2.3]. + updater := tools.NewUpdater(installDir, testVersions[0], tools.WithBaseURL(baseURL)) + require.NoError(t, updater.Update(ctx, testVersions[0])) + tshPath, err := updater.ToolPath("tsh", testVersions[0]) + require.NoError(t, err) + + // First login with set version during login process + t.Setenv("TELEPORT_TOOLS_VERSION", testVersions[1]) + cmd := exec.CommandContext(ctx, tshPath, + "login", "--proxy", proxyAddr.String(), "--insecure", "--user", "alice", "--auth", constants.LocalConnector) + cmd.Env = os.Environ() + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + require.NoError(t, cmd.Run()) + // Unset the version after update process. + require.NoError(t, os.Unsetenv("TELEPORT_TOOLS_VERSION")) + + // Run version command to verify that login command executed auto update and + // tsh was upgraded to [v3.2.1]. + cmd = exec.CommandContext(ctx, tshPath, "version") + out, err := cmd.Output() + require.NoError(t, err) + matches := pattern.FindStringSubmatch(string(out)) + require.Len(t, matches, 2) + require.Equal(t, testVersions[1], matches[1]) +} + +// TestLoginWithDisabledUpdateInProfile runs test cluster with enabled managed updates for client tools, +// verifies that after first update and disabling. +// +// # Managed updates: disabled. +// $ TELEPORT_TOOLS_VERSION=3.2.1 tsh version +// Teleport v3.2.1 +// $ tsh login --proxy proxy.example.com +// $ tsh version +// Teleport v1.2.3 +func TestLoginWithDisabledUpdateInProfile(t *testing.T) { + ctx := context.Background() + + rootServer, _, installDir := bootstrapTestServer(t) + setupManagedUpdates(t, rootServer.GetAuthServer(), autoupdate.ToolsUpdateModeDisabled, testVersions[1]) + + proxyAddr, err := rootServer.ProxyWebAddr() + require.NoError(t, err) + + // Fetch compiled test binary and install to tools dir [v1.2.3]. + updater := tools.NewUpdater(installDir, testVersions[0], tools.WithBaseURL(baseURL)) + require.NoError(t, updater.Update(ctx, testVersions[0])) + tshPath, err := updater.ToolPath("tsh", testVersions[0]) + require.NoError(t, err) + + // Set env variable to forcibly request update on version command. + t.Setenv("TELEPORT_TOOLS_VERSION", testVersions[1]) + cmd := exec.CommandContext(ctx, tshPath, "version") + cmd.Env = os.Environ() + out, err := cmd.Output() + require.NoError(t, err) + // Check the version. + matches := pattern.FindStringSubmatch(string(out)) + require.Len(t, matches, 2) + require.Equal(t, testVersions[1], matches[1]) + // Unset the version after update process. + require.NoError(t, os.Unsetenv("TELEPORT_TOOLS_VERSION")) + + // Second login has to update profile and disable further managed updates. + cmd = exec.CommandContext(ctx, tshPath, + "login", "--proxy", proxyAddr.String(), "--insecure", "--user", "alice", "--auth", constants.LocalConnector) + cmd.Env = os.Environ() + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + require.NoError(t, cmd.Run()) + + // Run version command to verify that login command executed auto update and + // tsh was upgraded to [v3.2.1]. + cmd = exec.CommandContext(ctx, tshPath, "version") + out, err = cmd.Output() + require.NoError(t, err) + // Check the version. + matches = pattern.FindStringSubmatch(string(out)) + require.Len(t, matches, 2) + fmt.Println(string(out)) + require.Equal(t, testVersions[0], matches[1]) +} + +// TestLoginWithDisabledUpdateForcedByEnv verifies that on disabled cluster we are still +// able to update client tools by always setting the environment variable. +// +// # Managed updates: disabled. +// $ tsh login --proxy proxy.example.com +// $ TELEPORT_TOOLS_VERSION=3.2.1 tsh version +// Teleport v3.2.1 +// $ tsh version +// Teleport v1.2.3 +func TestLoginWithDisabledUpdateForcedByEnv(t *testing.T) { + ctx := context.Background() + + rootServer, _, installDir := bootstrapTestServer(t) + setupManagedUpdates(t, rootServer.GetAuthServer(), autoupdate.ToolsUpdateModeDisabled, testVersions[1]) + + proxyAddr, err := rootServer.ProxyWebAddr() + require.NoError(t, err) + + // Fetch compiled test binary and install to tools dir [v1.2.3]. + updater := tools.NewUpdater(installDir, testVersions[0], tools.WithBaseURL(baseURL)) + require.NoError(t, updater.Update(ctx, testVersions[0])) + tshPath, err := updater.ToolPath("tsh", testVersions[0]) + require.NoError(t, err) + + // Second login has to update profile and disable further managed updates. + cmd := exec.CommandContext(ctx, tshPath, + "login", "--proxy", proxyAddr.String(), "--insecure", "--user", "alice", "--auth", constants.LocalConnector) + cmd.Env = os.Environ() + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + require.NoError(t, cmd.Run()) + + // Trying to forcibly use specific version not during login. + t.Setenv("TELEPORT_TOOLS_VERSION", testVersions[1]) + cmd = exec.CommandContext(ctx, tshPath, "version") + cmd.Env = os.Environ() + out, err := cmd.Output() + require.NoError(t, err) + // Check that version is used that requested from env variable. + matches := pattern.FindStringSubmatch(string(out)) + require.Len(t, matches, 2) + fmt.Println(string(out)) + require.Equal(t, testVersions[1], matches[1]) + // Unset the version after update process. + require.NoError(t, os.Unsetenv("TELEPORT_TOOLS_VERSION")) + + // Run version command to verify that login command executed auto update and + // tsh is version [v1.2.3] since it was requested not during login and cluster + // has disabled managed updates. + cmd = exec.CommandContext(ctx, tshPath, "version") + out, err = cmd.Output() + require.NoError(t, err) + matches = pattern.FindStringSubmatch(string(out)) + require.Len(t, matches, 2) + fmt.Println(string(out)) + require.Equal(t, testVersions[0], matches[1]) +} + +func bootstrapTestServer(t *testing.T) (*service.TeleportProcess, string, string) { + t.Helper() + homeDir := filepath.Join(t.TempDir(), "home") + require.NoError(t, os.MkdirAll(homeDir, 0700)) + installDir := filepath.Join(t.TempDir(), "local") + require.NoError(t, os.MkdirAll(installDir, 0700)) + + t.Setenv(types.HomeEnvVar, homeDir) + + alice, err := types.NewUser("alice") + require.NoError(t, err) + alice.SetRoles([]string{"access"}) + + // Disable 2fa to simplify login for test. + ap, err := types.NewAuthPreferenceFromConfigFile(types.AuthPreferenceSpecV2{ + Type: constants.Local, + SecondFactor: constants.SecondFactorOff, + Webauthn: &types.Webauthn{ + RPID: "localhost", + }, + }) + require.NoError(t, err) + + rootServer := testserver.MakeTestServer(t, + testserver.WithBootstrap(alice), + testserver.WithClusterName(t, "root"), + testserver.WithAuthPreference(ap), + ) + authService := rootServer.GetAuthServer() + + // Set password for the cluster login. + password, err := utils.CryptoRandomHex(6) + require.NoError(t, err) + t.Setenv(updater.TestPassword, password) + err = authService.UpsertPassword("alice", []byte(password)) + require.NoError(t, err) + + return rootServer, homeDir, installDir +} + +func setupManagedUpdates(t *testing.T, server *auth.Server, muMode string, muVersion string) { + t.Helper() + ctx := context.Background() + config, err := autoupdate.NewAutoUpdateConfig(&autoupdatev1pb.AutoUpdateConfigSpec{ + Tools: &autoupdatev1pb.AutoUpdateConfigSpecTools{ + Mode: muMode, + }, + }) + require.NoError(t, err) + version, err := autoupdate.NewAutoUpdateVersion(&autoupdatev1pb.AutoUpdateVersionSpec{ + Tools: &autoupdatev1pb.AutoUpdateVersionSpecTools{ + TargetVersion: muVersion, + }, + }) + require.NoError(t, err) + _, err = server.UpsertAutoUpdateConfig(ctx, config) + require.NoError(t, err) + _, err = server.UpsertAutoUpdateVersion(ctx, version) + require.NoError(t, err) +} diff --git a/lib/autoupdate/tools/config.go b/lib/autoupdate/tools/config.go new file mode 100644 index 0000000000000..8bf1deefa6549 --- /dev/null +++ b/lib/autoupdate/tools/config.go @@ -0,0 +1,209 @@ +/* + * Teleport + * Copyright (C) 2025 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package tools + +import ( + "context" + "encoding/json" + "errors" + "log/slog" + "maps" + "os" + "path/filepath" + "slices" + "strings" + + "github.com/gravitational/trace" + + "github.com/gravitational/teleport/lib/utils" +) + +const ( + // configFileVersion identifies the version of the configuration file + // might be used for future migrations. + configFileVersion = "v1" + // lockFileName is file used for locking update process in parallel. + lockFileName = ".lock" + // configFileName is the configuration file used to store versions for known hosts + // and the installed versions of client tools. + configFileName = ".config.json" + configFilePerms = 0o644 + // defaultSizeStoredVersion defines how many versions will be stored in the tools + // directory. Older versions will be cleaned up based on least recently used. + defaultSizeStoredVersion = 3 +) + +// ClientToolsConfig is configuration structure for client tools managed updates. +type ClientToolsConfig struct { + // Version determines version of configuration file (to support future extensions). + Version string `json:"version"` + // Configs stores information about profile and cluster version and mode: + // `{"profile-name":{"version": "1.2.3", "disabled":false}}`. + Configs map[string]*ClusterConfig `json:"configs"` + // Tools stores information about tools directories per versions: + // `[{"tool_name": "tsh", "path": "tool-path", "version": "tool-version"}]`. + Tools []Tool `json:"tools"` + // MaxTools defines the maximum number of tools allowed in the tools directory. + // Any tools exceeding this limit will be removed during the next installation. + MaxTools int `json:"max_tools"` +} + +// AddTool adds a tool to the collection in the configuration, always placing it at the top. +// The collection size is limited by the `defaultSizeStoredVersion` constant. +func (ctc *ClientToolsConfig) AddTool(tool Tool) { + for _, t := range ctc.Tools { + if t.Version == tool.Version { + maps.Copy(t.PathMap, tool.PathMap) + return + } + } + if ctc.MaxTools <= 0 { + ctc.MaxTools = defaultSizeStoredVersion + } + if len(ctc.Tools) >= ctc.MaxTools { + ctc.Tools = append([]Tool{tool}, ctc.Tools[:ctc.MaxTools-1]...) + } else { + ctc.Tools = append([]Tool{tool}, ctc.Tools...) + } +} + +// SetConfig sets the version and mode flag for a specific host. +func (ctc *ClientToolsConfig) SetConfig(proxy string, version string, disabled bool) { + if config, ok := ctc.Configs[proxy]; ok { + config.Disabled = disabled + config.Version = version + } else { + ctc.Configs[proxy] = &ClusterConfig{Version: version, Disabled: disabled} + } +} + +// SelectVersion lookups the version and re-order by last recently used. +func (ctc *ClientToolsConfig) SelectVersion(version string) *Tool { + for i, tool := range ctc.Tools { + if tool.Version == version { + ctc.Tools = append([]Tool{tool}, append(ctc.Tools[:i], ctc.Tools[i+1:]...)...) + return &tool + } + } + return nil +} + +// HasVersion check that specific version present in collection. +func (ctc *ClientToolsConfig) HasVersion(version string) bool { + return slices.ContainsFunc(ctc.Tools, func(s Tool) bool { + return version == s.Version + }) +} + +// ClusterConfig stores required version and mode for specific cluster. +type ClusterConfig struct { + Version string `json:"version"` + Disabled bool `json:"disabled"` +} + +// Tool stores tools path per version, each tool might be stored in different path. +type Tool struct { + // Version is the version of the tools (tsh, tctl) as defined in the PathMap. + Version string `json:"version"` + // PathMap stores the relative path (within the tools directory) for each tool binary. + // For example: {"tctl": "package-id/tctl"}. + PathMap map[string]string `json:"path"` +} + +// PackageNames returns the package names extracted from the tool path map. +func (c *Tool) PackageNames() []string { + var packageNames []string + for _, path := range c.PathMap { + dir := strings.SplitN(path, string(filepath.Separator), 2) + if len(dir) > 0 { + packageNames = append(packageNames, dir[0]) + } + } + return packageNames +} + +// getToolsConfig reads the configuration file for client tools managed updates, +// and acquires a filesystem lock until the configuration is read and deserialized. +func getToolsConfig(toolsDir string) (ctc *ClientToolsConfig, err error) { + unlock, err := utils.FSWriteLock(filepath.Join(toolsDir, lockFileName)) + if err != nil { + return nil, trace.Wrap(err) + } + defer func() { + err = trace.NewAggregate(err, unlock()) + }() + + ctc = &ClientToolsConfig{ + Configs: make(map[string]*ClusterConfig), + } + data, err := os.ReadFile(filepath.Join(toolsDir, configFileName)) + if err != nil && !errors.Is(err, os.ErrNotExist) { + return nil, trace.Wrap(err) + } + if data != nil { + if err := json.Unmarshal(data, ctc); err != nil { + // If the configuration file content is corrupted, tools execution should not fail. + // Instead, we should proceed and re-install the required version. + slog.WarnContext(context.Background(), "failed to unmarshal config file", "error", err) + } + } + + return ctc, nil +} + +// updateToolsConfig creates or opens the configuration file for client tools managed updates, +// and acquires a filesystem lock until the configuration is written and closed. +func updateToolsConfig(toolsDir string, update func(ctc *ClientToolsConfig) error) (err error) { + unlock, err := utils.FSWriteLock(filepath.Join(toolsDir, lockFileName)) + if err != nil { + return trace.Wrap(err) + } + defer func() { + err = trace.NewAggregate(err, unlock()) + }() + + ctc := &ClientToolsConfig{ + Version: configFileVersion, + Configs: make(map[string]*ClusterConfig), + } + data, err := os.ReadFile(filepath.Join(toolsDir, configFileName)) + if err != nil && !errors.Is(err, os.ErrNotExist) { + return trace.Wrap(err) + } + if data != nil { + if err := json.Unmarshal(data, ctc); err != nil { + // If the configuration file content is corrupted, tools execution should not fail. + // Instead, we should proceed and re-install the required version. + slog.WarnContext(context.Background(), "failed to unmarshal config file", "error", err) + } + } + + // Perform update values before configuration file is going to be written. + if err := update(ctc); err != nil { + return trace.Wrap(err) + } + + jsonData, err := json.Marshal(ctc) + if err != nil { + return trace.Wrap(err) + } + return trace.Wrap( + os.WriteFile(filepath.Join(toolsDir, configFileName), jsonData, configFilePerms), + ) +} diff --git a/lib/autoupdate/tools/helper.go b/lib/autoupdate/tools/helper.go index 499a7c5d377c5..bad8752e5123f 100644 --- a/lib/autoupdate/tools/helper.go +++ b/lib/autoupdate/tools/helper.go @@ -23,10 +23,13 @@ import ( "errors" "log/slog" "os" + "path/filepath" "github.com/gravitational/trace" "github.com/gravitational/teleport" + "github.com/gravitational/teleport/api/profile" + "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/lib/autoupdate" stacksignal "github.com/gravitational/teleport/lib/utils/signal" ) @@ -39,66 +42,147 @@ var ( baseURL = autoupdate.DefaultBaseURL ) +// newUpdater inits the updater with default base URL and creates directory +// if it doesn't exist. +func newUpdater(toolsDir string) (*Updater, error) { + // Overrides default base URL for custom CDN for downloading updates. + if envBaseURL := os.Getenv(autoupdate.BaseURLEnvVar); envBaseURL != "" { + baseURL = envBaseURL + } + + // Create tools directory if it does not exist. + if err := os.MkdirAll(toolsDir, 0o755); err != nil { + return nil, trace.Wrap(err) + } + + return NewUpdater(toolsDir, version, WithBaseURL(baseURL)), nil +} + // CheckAndUpdateLocal verifies if the TELEPORT_TOOLS_VERSION environment variable -// is set and a version is defined (or disabled by setting it to "off"). The requested -// version is compared with the current client tools version. If they differ, the version -// package is downloaded, extracted to the client tools directory, and re-executed -// with the updated version. -// If $TELEPORT_HOME/bin contains downloaded client tools, it always re-executes -// using the version from the home directory. -func CheckAndUpdateLocal(ctx context.Context, reExecArgs []string) error { +// is set and whether a version is defined (or explicitly disabled by setting it to "off"). +// The environment variable always takes precedence over other version settings. +// +// If `currentProfileName` is specified, the function attempts to find the tools version +// required for the specified cluster in configuration file and re-execute it. +// +// The requested version is compared to the currently running client tools version. +// If they differ, the requested version is downloaded and extracted into the client tools directory, +// the installation is recorded in the configuration file, and the tool is re-executed with the updated version. +func CheckAndUpdateLocal(ctx context.Context, currentProfileName string, reExecArgs []string) error { + // If client tools updates are explicitly disabled, we want to catch this as soon as possible + // so we don't try to read te user home directory, fail, and log warnings. + if os.Getenv(teleportToolsVersionEnv) == teleportToolsVersionEnvDisabled { + return nil + } + + var err error + if currentProfileName == "" { + home := os.Getenv(types.HomeEnvVar) + if home != "" { + home = filepath.Clean(home) + } + profilePath := profile.FullProfilePath(home) + currentProfileName, err = profile.GetCurrentProfileName(profilePath) + if err != nil && !trace.IsNotFound(err) { + return trace.Wrap(err) + } + } + toolsDir, err := Dir() if err != nil { - slog.WarnContext(ctx, "Client tools update is disabled", "error", err) + slog.WarnContext(ctx, "Failed to detect the teleport home directory, client tools updates are disabled", "error", err) return nil } - - // Overrides default base URL for custom CDN for downloading updates. - if envBaseURL := os.Getenv(autoupdate.BaseURLEnvVar); envBaseURL != "" { - baseURL = envBaseURL + updater, err := newUpdater(toolsDir) + if err != nil { + slog.WarnContext(ctx, "Failed to create the updater, client tools updates are disabled", "error", err) + return nil } - updater := NewUpdater(toolsDir, version, WithBaseURL(baseURL)) - // At process startup, check if a version has already been downloaded to - // $TELEPORT_HOME/bin or if the user has set the TELEPORT_TOOLS_VERSION - // environment variable. If so, re-exec that version of client tools. - toolsVersion, reExec, err := updater.CheckLocal() + slog.DebugContext(ctx, "Attempting to local update", "current_profile_name", currentProfileName) + resp, err := updater.CheckLocal(ctx, currentProfileName) if err != nil { - return trace.Wrap(err) + slog.WarnContext(ctx, "Failed to check local teleport versions, client tools updates are disabled", "error", err) + return nil } - if reExec { - return trace.Wrap(updateAndReExec(ctx, updater, toolsVersion, reExecArgs)) + + if resp.ReExec { + return trace.Wrap(updateAndReExec(ctx, updater, resp.Version, reExecArgs)) } return nil } -// CheckAndUpdateRemote verifies client tools version is set for update in cluster -// configuration by making the http request to `webapi/find` endpoint. The requested -// version is compared with the current client tools version. If they differ, the version -// package is downloaded, extracted to the client tools directory, and re-executed +// CheckAndUpdateRemote verifies the client tools version configured for updates in the cluster +// by making an HTTP request to the `webapi/find` endpoint. +// +// If the TELEPORT_TOOLS_VERSION environment variable is set during the remote check, +// the version specified in the environment variable takes precedence over the version +// provided by the cluster. This version will also be recorded in the configuration for the cluster. +// +// The requested version is compared with the current client tools version. +// If they differ, the requested version is downloaded, extracted into the client tools directory, +// the installed version is recorded in the configuration, and the tool is re-executed // with the updated version. -// If $TELEPORT_HOME/bin contains downloaded client tools, it always re-executes -// using the version from the home directory. -func CheckAndUpdateRemote(ctx context.Context, proxy string, insecure bool, reExecArgs []string) error { +func CheckAndUpdateRemote(ctx context.Context, currentProfileName string, insecure bool, reExecArgs []string) error { + // If client tools updates are explicitly disabled, we want to catch this as soon as possible + // so we don't try to read te user home directory, fail, and log warnings. + if os.Getenv(teleportToolsVersionEnv) == teleportToolsVersionEnvDisabled { + return nil + } + toolsDir, err := Dir() if err != nil { - slog.WarnContext(ctx, "Client tools update is disabled", "error", err) + slog.WarnContext(ctx, "Failed to detect the teleport home directory, client tools updates are disabled", "error", err) + return nil + } + updater, err := newUpdater(toolsDir) + if err != nil { + slog.WarnContext(ctx, "Failed to create the updater, client tools updates are disabled", "error", err) return nil } - // Overrides default base URL for custom CDN for downloading updates. - if envBaseURL := os.Getenv(autoupdate.BaseURLEnvVar); envBaseURL != "" { - baseURL = envBaseURL + slog.DebugContext(ctx, "Attempting to remote update", "current_profile_name", currentProfileName, "insecure", insecure) + resp, err := updater.CheckRemote(ctx, currentProfileName, insecure) + if err != nil { + slog.WarnContext(ctx, "Failed to check remote teleport versions, client tools updates are disabled", "error", err) + return nil } - updater := NewUpdater(toolsDir, version, WithBaseURL(baseURL)) - toolsVersion, reExec, err := updater.CheckRemote(ctx, proxy, insecure) + if !resp.Disabled && resp.ReExec { + return trace.Wrap(updateAndReExec(ctx, updater, resp.Version, reExecArgs)) + } + + return nil +} + +// DownloadUpdate checks if a client tools version is set for update in the cluster +// configuration by making an HTTP request to the `webapi/find` endpoint. +// Downloads the new version if it is not already installed without re-execution. +func DownloadUpdate(ctx context.Context, name string, insecure bool) error { + toolsDir, err := Dir() + if err != nil { + slog.WarnContext(ctx, "Client tools update is disabled", "error", err) + return nil + } + updater, err := newUpdater(toolsDir) if err != nil { return trace.Wrap(err) } - if reExec { - return trace.Wrap(updateAndReExec(ctx, updater, toolsVersion, reExecArgs)) + + slog.DebugContext(ctx, "Attempting to remote update", "name", name, "insecure", insecure) + resp, err := updater.CheckRemote(ctx, name, insecure) + if err != nil { + return trace.Wrap(err) + } + + if !resp.Disabled && resp.ReExec { + ctxUpdate, cancel := stacksignal.GetSignalHandler().NotifyContext(ctx) + defer cancel() + err := updater.Update(ctxUpdate, resp.Version) + if err != nil && !errors.Is(err, context.Canceled) && !errors.Is(err, ErrNoBaseURL) { + return trace.Wrap(err) + } } return nil @@ -110,15 +194,18 @@ func updateAndReExec(ctx context.Context, updater *Updater, toolsVersion string, // Download the version of client tools required by the cluster. This // is required if the user passed in the TELEPORT_TOOLS_VERSION // explicitly. - err := updater.UpdateWithLock(ctxUpdate, toolsVersion) - if err != nil && !errors.Is(err, context.Canceled) && !errors.Is(err, errNoBaseURL) { - return trace.Wrap(err) + err := updater.Update(ctxUpdate, toolsVersion) + if err != nil && !errors.Is(err, context.Canceled) && !errors.Is(err, ErrNoBaseURL) { + slog.ErrorContext(ctx, "Failed to update tools version", "error", err, "version", toolsVersion) + // Continue executing the current version of the client tools (tsh, tctl) + // to avoid potential issues with update process (timeout, missing version). + return nil } // Re-execute client tools with the correct version of client tools. - code, err := updater.Exec(args) + code, err := updater.Exec(ctx, toolsVersion, args) if err != nil && !errors.Is(err, os.ErrNotExist) { - slog.DebugContext(ctx, "Failed to re-exec client tool", "error", err) + slog.DebugContext(ctx, "Failed to re-exec client tool", "error", err, "code", code) os.Exit(code) } else if err == nil { os.Exit(code) diff --git a/lib/autoupdate/tools/migration.go b/lib/autoupdate/tools/migration.go new file mode 100644 index 0000000000000..d0b9fe94c3116 --- /dev/null +++ b/lib/autoupdate/tools/migration.go @@ -0,0 +1,152 @@ +/* + * Teleport + * Copyright (C) 2025 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package tools + +import ( + "context" + "fmt" + "log/slog" + "os" + "path/filepath" + "strings" + + "github.com/google/uuid" + "github.com/gravitational/trace" + + "github.com/gravitational/teleport/lib/utils" +) + +const ( + // migrationFilePerms defines the permissions for files created during the migration process. + migrationFilePerms = 0o755 +) + +// migrateV1AndUpdateConfig launches migration process and add migrated +// tools to configuration file. +func migrateV1AndUpdateConfig(toolsDir string, tools []string) error { + if err := updateToolsConfig(toolsDir, func(ctc *ClientToolsConfig) error { + migratedTools, err := migrateV1(toolsDir, tools) + if err != nil { + return trace.Wrap(err) + } + if len(migratedTools) == 0 { + return nil + } + + for _, tool := range migratedTools { + ctc.AddTool(tool) + } + return nil + }); err != nil { + return trace.Wrap(err) + } + + return nil +} + +// migrateV1 verifies the tool binary located in the tool's directory. +// If it is a symlink, it reads the target location and generates a tool object +// to be saved in the configuration for backward compatibility. +// If it is a regular binary, a new package folder should be created, +// and the binary should be copied to the new location. +// TODO(vapopov): DELETE in v21.0.0 - the version without caching will no longer be supported. +func migrateV1(toolsDir string, tools []string) (map[string]Tool, error) { + newPkg := fmt.Sprint(uuid.New().String(), updatePackageSuffixV2) + migratedTools := make(map[string]Tool) + for _, tool := range tools { + path := filepath.Join(toolsDir, tool) + info, err := os.Lstat(path) + if os.IsNotExist(err) { + continue + } + if err != nil { + return nil, trace.Wrap(err, "failed to retrieve information for tool %q", tool) + } + + toolVersion, err := CheckToolVersion(path) + if trace.IsBadParameter(err) { + // If we can't identify toolVersion, it is blocked by EDR software or binary + // is damaged we should continue migration process. + slog.ErrorContext(context.Background(), "failed to check the toolVersion", "error", err) + continue + } else if err != nil { + return nil, trace.Wrap(err) + } + + if info.Mode().Type()&os.ModeSymlink != 0 { + fullPath, err := os.Readlink(path) + if err != nil { + return nil, trace.Wrap(err, "failed to read symlink %q", path) + } + pkg, relPath, err := extractPackageName(toolsDir, fullPath) + if err != nil { + return nil, trace.Wrap(err) + } + if err := utils.RecursiveCopy(filepath.Join(toolsDir, pkg), filepath.Join(toolsDir, newPkg), nil); err != nil { + return nil, trace.Wrap(err) + } + if t, ok := migratedTools[toolVersion]; ok { + t.PathMap[tool] = filepath.Join(newPkg, relPath) + } else { + migratedTools[toolVersion] = Tool{ + Version: toolVersion, + PathMap: map[string]string{tool: filepath.Join(newPkg, relPath)}, + } + } + continue + } + + // Create new toolVersion of the package and move tools to new destination. + if t, ok := migratedTools[toolVersion]; ok { + newPath := filepath.Join(toolsDir, newPkg, tool) + if err := utils.CopyFile(path, newPath, migrationFilePerms); err != nil { + return nil, trace.Wrap(err) + } + t.PathMap[tool] = filepath.Join(newPkg, tool) + } else { + if err := os.Mkdir(filepath.Join(toolsDir, newPkg), migrationFilePerms); err != nil { + return nil, trace.Wrap(err) + } + newPath := filepath.Join(toolsDir, newPkg, tool) + if err := utils.CopyFile(path, newPath, migrationFilePerms); err != nil { + return nil, trace.Wrap(err) + } + migratedTools[toolVersion] = Tool{ + Version: toolVersion, + PathMap: map[string]string{tool: filepath.Join(newPkg, tool)}, + } + } + + } + + return migratedTools, nil +} + +func extractPackageName(toolsDir string, fullPath string) (string, string, error) { + rel, err := filepath.Rel(toolsDir, fullPath) + if err != nil { + return "", "", trace.Wrap(err) + } + dir := strings.SplitN(rel, string(filepath.Separator), 2) + if len(dir) == 2 && strings.HasSuffix(dir[0], updatePackageSuffix) { + return dir[0], dir[1], nil + } + + return "", fullPath, nil +} diff --git a/lib/autoupdate/tools/updater.go b/lib/autoupdate/tools/updater.go index b236324c81d88..1d470e586d131 100644 --- a/lib/autoupdate/tools/updater.go +++ b/lib/autoupdate/tools/updater.go @@ -24,6 +24,7 @@ import ( "crypto/sha256" "crypto/x509" "encoding/hex" + "errors" "fmt" "io" "log/slog" @@ -33,6 +34,7 @@ import ( "path/filepath" "regexp" "runtime" + "strings" "syscall" "time" @@ -50,16 +52,22 @@ import ( const ( // teleportToolsVersionEnv is environment name for requesting specific version for update. teleportToolsVersionEnv = "TELEPORT_TOOLS_VERSION" + // teleportToolsVersionEnvDisabled is a special value that disables teleport tools updates + // when assigned to the teleportToolsVersionEnv environment variable. + teleportToolsVersionEnvDisabled = "off" // teleportToolsVersionReExecEnv is internal environment name for transferring original // version to re-executed ones. teleportToolsVersionReExecEnv = "TELEPORT_TOOLS_VERSION_REEXEC" + // teleportToolsDirsEnv overrides Teleport tools directory for saving updated + // versions. + teleportToolsDirsEnv = "TELEPORT_TOOLS_DIR" // reservedFreeDisk is the predefined amount of free disk space (in bytes) required // to remain available after downloading archives. reservedFreeDisk = 10 * 1024 * 1024 // 10 Mb - // lockFileName is file used for locking update process in parallel. - lockFileName = ".lock" - // updatePackageSuffix is directory suffix used for package extraction in tools directory. + // updatePackageSuffix is directory suffix used for package extraction in tools directory for v1. updatePackageSuffix = "-update-pkg" + // updatePackageSuffix is directory suffix used for package extraction in tools directory for v2. + updatePackageSuffixV2 = "-update-pkg-v2" ) var ( @@ -67,6 +75,13 @@ var ( pattern = regexp.MustCompile(`(?m)Teleport v(.*) git`) ) +// UpdateResponse contains information about after update process. +type UpdateResponse struct { + Version string `json:"version,omitempty"` + ReExec bool `json:"reExec,omitempty"` + Disabled bool `json:"disabled,omitempty"` +} + // Option applies an option value for the Updater. type Option func(u *Updater) @@ -133,37 +148,61 @@ func NewUpdater(toolsDir, localVersion string, options ...Option) *Updater { // CheckLocal is run at client tool startup and will only perform local checks. // Returns the version needs to be updated and re-executed, by re-execution flag we // understand that update and re-execute is required. -func (u *Updater) CheckLocal() (version string, reExec bool, err error) { +func (u *Updater) CheckLocal(ctx context.Context, profileName string) (resp *UpdateResponse, err error) { // Check if the user has requested a specific version of client tools. requestedVersion := os.Getenv(teleportToolsVersionEnv) switch requestedVersion { // The user has turned off any form of automatic updates. - case "off": - return "", false, nil + case teleportToolsVersionEnvDisabled: + return &UpdateResponse{Version: "", ReExec: false}, nil // Requested version already the same as client version. case u.localVersion: - return u.localVersion, false, nil + return &UpdateResponse{Version: u.localVersion, ReExec: false}, nil // No requested version, we continue. case "": // Requested version that is not the local one. default: if _, err := semver.NewVersion(requestedVersion); err != nil { - return "", false, trace.Wrap(err, "checking that request version is semantic") + return nil, trace.Wrap(err, "checking that request version is semantic") + } + return &UpdateResponse{Version: requestedVersion, ReExec: true}, nil + } + + // We should acquire and release the lock before checking the version + // by executing the binary, as it might block tool execution until the version + // check is completed, which can take several seconds. + ctc, err := getToolsConfig(u.toolsDir) + if err != nil { + return nil, trace.Wrap(err) + } + if config, ok := ctc.Configs[profileName]; ok { + if config.Disabled || config.Version == u.localVersion { + return &UpdateResponse{Version: config.Version, ReExec: false}, nil + } else { + return &UpdateResponse{Version: config.Version, ReExec: true}, nil } - return requestedVersion, true, nil } - // If a version of client tools has already been downloaded to - // tools directory, return that. - toolsVersion, err := CheckToolVersion(u.toolsDir) - if trace.IsNotFound(err) || toolsVersion == u.localVersion { - return u.localVersion, false, nil + // Backward compatibility check. If a version of the client tools has already been downloaded + // to the tools directory, return it. Version check failures should be ignored, as EDR software + // might block execution or a broken version may already exist in the tools' directory. + toolsVersion, err := CheckExecutedToolVersion(u.toolsDir) + if trace.IsNotFound(err) || errors.Is(err, ErrVersionCheck) || toolsVersion == u.localVersion { + return &UpdateResponse{Version: u.localVersion, ReExec: false}, nil } if err != nil { - return "", false, trace.Wrap(err) + return nil, trace.Wrap(err) } - return toolsVersion, true, nil + if !ctc.HasVersion(toolsVersion) { + if err := migrateV1AndUpdateConfig(u.toolsDir, u.tools); err != nil { + // Execution should not be interrupted if migration fails. Instead, it's better to + // re-download the version that was supposed to be migrated but failed for some reason. + slog.WarnContext(ctx, "Failed to migrate client tools", "error", err) + } + } + + return &UpdateResponse{Version: toolsVersion, ReExec: true}, nil } // CheckRemote first checks the version set by the environment variable. If not set or disabled, @@ -171,29 +210,45 @@ func (u *Updater) CheckLocal() (version string, reExec bool, err error) { // the `webapi/find` handler, which stores information about the required client tools version to // operate with this cluster. It returns the semantic version that needs updating and whether // re-execution is necessary, by re-execution flag we understand that update and re-execute is required. -func (u *Updater) CheckRemote(ctx context.Context, proxyAddr string, insecure bool) (version string, reExec bool, err error) { +func (u *Updater) CheckRemote(ctx context.Context, proxyAddr string, insecure bool) (response *UpdateResponse, err error) { + proxyHost := utils.TryHost(proxyAddr) // Check if the user has requested a specific version of client tools. requestedVersion := os.Getenv(teleportToolsVersionEnv) switch requestedVersion { // The user has turned off any form of automatic updates. - case "off": - return "", false, nil + case teleportToolsVersionEnvDisabled: + return &UpdateResponse{Version: "", ReExec: false}, nil // Requested version already the same as client version. case u.localVersion: - return u.localVersion, false, nil + if err := updateToolsConfig(u.toolsDir, func(ctc *ClientToolsConfig) error { + ctc.SetConfig(proxyHost, requestedVersion, false) + return nil + }); err != nil { + return nil, trace.Wrap(err) + } + return &UpdateResponse{Version: u.localVersion, ReExec: false}, nil // No requested version, we continue. case "": // Requested version that is not the local one. default: if _, err := semver.NewVersion(requestedVersion); err != nil { - return "", false, trace.Wrap(err, "checking that request version is semantic") + return nil, trace.Wrap(err, "checking that request version is semantic") } - return requestedVersion, true, nil + // If the environment variable is set during a remote check, + // prioritize this version for the current host and use it as the default + // for all commands under the current profile. + if err := updateToolsConfig(u.toolsDir, func(ctc *ClientToolsConfig) error { + ctc.SetConfig(proxyHost, requestedVersion, false) + return nil + }); err != nil { + return nil, trace.Wrap(err) + } + return &UpdateResponse{Version: requestedVersion, ReExec: true}, nil } certPool, err := x509.SystemCertPool() if err != nil { - return "", false, trace.Wrap(err) + return nil, trace.Wrap(err) } resp, err := webclient.Find(&webclient.Config{ Context: ctx, @@ -203,94 +258,77 @@ func (u *Updater) CheckRemote(ctx context.Context, proxyAddr string, insecure bo Insecure: insecure, }) if err != nil { - return "", false, trace.Wrap(err) + return nil, trace.Wrap(err) } - // If a version of client tools has already been downloaded to - // tools directory, return that. - toolsVersion, err := CheckToolVersion(u.toolsDir) - if err != nil && !trace.IsNotFound(err) { - return "", false, trace.Wrap(err) - } + updateResp := &UpdateResponse{Version: u.localVersion, ReExec: false} switch { case !resp.AutoUpdate.ToolsAutoUpdate || resp.AutoUpdate.ToolsVersion == "": - if toolsVersion == "" { - return u.localVersion, false, nil - } + updateResp = &UpdateResponse{Version: u.localVersion, ReExec: false, Disabled: true} case u.localVersion == resp.AutoUpdate.ToolsVersion: - return u.localVersion, false, nil - case resp.AutoUpdate.ToolsVersion != toolsVersion: - return resp.AutoUpdate.ToolsVersion, true, nil + updateResp = &UpdateResponse{Version: u.localVersion, ReExec: false} + default: + updateResp = &UpdateResponse{Version: resp.AutoUpdate.ToolsVersion, ReExec: true} } - return toolsVersion, true, nil -} - -// UpdateWithLock acquires filesystem lock, downloads requested version package, -// unarchive and replace existing one. -func (u *Updater) UpdateWithLock(ctx context.Context, updateToolsVersion string) (err error) { - // Create tools directory if it does not exist. - if err := os.MkdirAll(u.toolsDir, 0o755); err != nil { - return trace.Wrap(err) - } - // Lock concurrent client tools execution util requested version is updated. - unlock, err := utils.FSWriteLock(filepath.Join(u.toolsDir, lockFileName)) - if err != nil { - return trace.Wrap(err) - } - defer func() { - err = trace.NewAggregate(err, unlock()) - }() - - // If the version of the running binary or the version downloaded to - // tools directory is the same as the requested version of client tools, - // nothing to be done, exit early. - toolsVersion, err := CheckToolVersion(u.toolsDir) - if err != nil && !trace.IsNotFound(err) { - return trace.Wrap(err) - } - if updateToolsVersion == toolsVersion { + if err := updateToolsConfig(u.toolsDir, func(ctc *ClientToolsConfig) error { + ctc.SetConfig(proxyHost, updateResp.Version, updateResp.Disabled) return nil + }); err != nil { + return nil, trace.Wrap(err) } - // Download and update client tools in tools directory. - if err := u.Update(ctx, updateToolsVersion); err != nil { - return trace.Wrap(err) - } - - return nil + return updateResp, nil } -// Update downloads requested version and replace it with existing one and cleanups the previous downloads -// with defined updater directory suffix. +// Update acquires filesystem lock, downloads requested version package, unarchive, replace +// existing one and cleanups the previous downloads with defined updater directory suffix. func (u *Updater) Update(ctx context.Context, toolsVersion string) error { - // Get platform specific download URLs. - packages, err := teleportPackageURLs(ctx, u.uriTemplate, u.baseURL, toolsVersion) - if err != nil { - return trace.Wrap(err) - } + err := updateToolsConfig(u.toolsDir, func(ctc *ClientToolsConfig) error { + // ignoreTools is the list of tools installed and tracked by the config. + // They should be preserved during cleanup. If we have more than [defaultSizeStoredVersion] + // versions, the updater will forget about the least used version. + var ignoreTools []string + for _, tool := range ctc.Tools { + // If the version of the running binary or the version downloaded to + // tools directory is the same as the requested version of client tools, + // nothing to be done, exit early. + if tool.Version == toolsVersion { + return nil + } + ignoreTools = append(ignoreTools, tool.PackageNames()...) + } - var pkgNames []string - for _, pkg := range packages { - pkgName := fmt.Sprint(uuid.New().String(), updatePackageSuffix) - if err := u.update(ctx, pkg, pkgName); err != nil { + // Get platform specific download URLs. + packages, err := teleportPackageURLs(ctx, u.uriTemplate, u.baseURL, toolsVersion) + if err != nil { return trace.Wrap(err) } - pkgNames = append(pkgNames, pkgName) - } - // Cleanup the tools directory with previously downloaded and un-archived versions. - if err := packaging.RemoveWithSuffix(u.toolsDir, updatePackageSuffix, pkgNames); err != nil { - slog.WarnContext(ctx, "failed to clean up tools directory", "error", err) - } + var pkgNames []string + for _, pkg := range packages { + pkgName := fmt.Sprint(uuid.New().String(), updatePackageSuffixV2) + if err := u.update(ctx, ctc, pkg, pkgName); err != nil { + return trace.Wrap(err) + } + pkgNames = append(pkgNames, pkgName) + } + // Cleanup all tools in directory with the specific prefix by ignoring tools + // that are currently recorded in the configuration. + if err := packaging.RemoveWithSuffix(u.toolsDir, updatePackageSuffixV2, append(ignoreTools, pkgNames...)); err != nil { + slog.WarnContext(ctx, "failed to clean up tools directory", "error", err) + } - return nil + return nil + }) + + return trace.Wrap(err) } // update downloads the archive and validate against the hash. Download to a // temporary path within tools directory. -func (u *Updater) update(ctx context.Context, pkg packageURL, pkgName string) error { +func (u *Updater) update(ctx context.Context, ctc *ClientToolsConfig, pkg packageURL, pkgName string) error { f, err := os.CreateTemp("", "teleport-") if err != nil { return trace.Wrap(err) @@ -330,38 +368,74 @@ func (u *Updater) update(ctx context.Context, pkg packageURL, pkgName string) er } // Perform atomic replace so concurrent exec do not fail. - if err := packaging.ReplaceToolsBinaries(u.toolsDir, f.Name(), extractDir, u.tools); err != nil { + toolsMap, err := packaging.ReplaceToolsBinaries(f.Name(), extractDir, u.tools) + if err != nil { return trace.Wrap(err) } + for key, val := range toolsMap { + toolsMap[key] = filepath.Join(pkgName, val) + } + ctc.AddTool(Tool{Version: pkg.Version, PathMap: toolsMap}) + return nil } -// Exec re-executes tool command with same arguments and environ variables. -func (u *Updater) Exec(args []string) (int, error) { - path, err := toolName(u.toolsDir) - if err != nil { - return 0, trace.Wrap(err) +// ToolPath loads full path from config file to specific tool and version. +func (u *Updater) ToolPath(toolName, toolVersion string) (path string, err error) { + var tool *Tool + if err := updateToolsConfig(u.toolsDir, func(ctc *ClientToolsConfig) error { + tool = ctc.SelectVersion(toolVersion) + return nil + }); err != nil { + return "", trace.Wrap(err) } - // To prevent re-execution loop we have to disable update logic for re-execution, - // by unsetting current tools version env variable and setting it to "off". - if err := os.Unsetenv(teleportToolsVersionEnv); err != nil { - return 0, trace.Wrap(err) + if tool == nil { + return "", trace.NotFound("tool version %q not found", toolVersion) } - if err := os.Unsetenv(teleportToolsVersionReExecEnv); err != nil { - return 0, trace.Wrap(err) + relPath, ok := tool.PathMap[toolName] + if !ok { + return "", trace.NotFound("tool %q not found", toolName) } - env := os.Environ() + return filepath.Join(u.toolsDir, relPath), nil +} + +// Exec re-executes tool command with same arguments and environ variables. +func (u *Updater) Exec(ctx context.Context, toolsVersion string, args []string) (int, error) { executablePath, err := os.Executable() if err != nil { return 0, trace.Wrap(err) } - if path == executablePath { - env = append(env, teleportToolsVersionEnv+"=off") + path, err := u.ToolPath(filepath.Base(executablePath), toolsVersion) + if err != nil { + return 0, trace.Wrap(err) + } + + for _, unset := range []string{ + teleportToolsVersionReExecEnv, + teleportToolsDirsEnv, + } { + if err := os.Unsetenv(unset); err != nil { + return 0, trace.Wrap(err) + } + } + env := append(os.Environ(), fmt.Sprintf("%s=%s", teleportToolsDirsEnv, u.toolsDir)) + // To prevent re-execution loop we have to disable update logic for re-execution, + // by unsetting current tools version env variable and setting it to "off". + // The re-execution path and tools directory are absolute. Since the v2 logic + // no longer uses a static path, any re-execution from the tools directory + // must disable further re-execution. + if path == executablePath || strings.HasPrefix(path, u.toolsDir) { + if err := os.Unsetenv(teleportToolsVersionEnv); err != nil { + return 0, trace.Wrap(err) + } + env = append(env, teleportToolsVersionEnv+"="+teleportToolsVersionEnvDisabled) + slog.DebugContext(ctx, "Disable next re-execution") } env = append(env, fmt.Sprintf("%s=%s", teleportToolsVersionReExecEnv, u.localVersion)) + slog.DebugContext(ctx, "Re-execute updated version", "execute", path, "from", executablePath) if runtime.GOOS == constants.WindowsOS { cmd := exec.Command(path, args...) cmd.Env = env diff --git a/lib/autoupdate/tools/utils.go b/lib/autoupdate/tools/utils.go index e158ef534169a..0a32aee531ddd 100644 --- a/lib/autoupdate/tools/utils.go +++ b/lib/autoupdate/tools/utils.go @@ -39,22 +39,41 @@ import ( "github.com/gravitational/teleport/lib/autoupdate" "github.com/gravitational/teleport/lib/modules" "github.com/gravitational/teleport/lib/utils" + "github.com/gravitational/teleport/lib/utils/packaging" ) -var errNoBaseURL = errors.New("baseURL is not defined") +var ( + // ErrNoBaseURL is returned when `TELEPORT_CDN_BASE_URL` must be set + // in order to proceed with managed updates. + ErrNoBaseURL = errors.New("baseURL is not defined") + // ErrVersionCheck is returned when the downloaded version fails + // to execute for version identification. + ErrVersionCheck = errors.New("version check failed") +) -// Dir returns the path to client tools in $TELEPORT_HOME/bin. +// Dir returns the client tools installation directory path, using the following fallback order: +// $TELEPORT_TOOLS_DIR, $TELEPORT_HOME/bin, and $HOME/.tsh/bin. func Dir() (string, error) { - home := os.Getenv(types.HomeEnvVar) - if home == "" { - var err error - home, err = os.UserHomeDir() - if err != nil { - return "", trace.Wrap(err) + toolsDir := os.Getenv(teleportToolsDirsEnv) + if toolsDir == "" { + toolsDir = os.Getenv(types.HomeEnvVar) + if toolsDir == "" { + var err error + toolsDir, err = os.UserHomeDir() + if err != nil { + return "", trace.Wrap(err) + } + toolsDir = filepath.Join(toolsDir, ".tsh", "bin") + } else { + toolsDir = filepath.Join(toolsDir, "bin") } } - return filepath.Join(home, ".tsh", "bin"), nil + toolsDir, err := filepath.Abs(toolsDir) + if err != nil { + return "", trace.Wrap(err) + } + return toolsDir, nil } // DefaultClientTools list of the client tools needs to be updated by default. @@ -67,16 +86,20 @@ func DefaultClientTools() []string { } } -// CheckToolVersion returns current installed client tools version, must return NotFoundError if -// the client tools is not found in tools directory. -func CheckToolVersion(toolsDir string) (string, error) { - // Find the path to the current executable. +// CheckExecutedToolVersion invokes the exact executable from the tools directory to retrieve its version. +func CheckExecutedToolVersion(toolsDir string) (string, error) { path, err := toolName(toolsDir) if err != nil { return "", trace.Wrap(err) } - if _, err := os.Stat(path); errors.Is(err, os.ErrNotExist) { - return "", trace.NotFound("autoupdate tool not found in %q", toolsDir) + return CheckToolVersion(path) +} + +// CheckToolVersion returns client tools version, must return NotFoundError if +// the client tools is not found in specified path. +func CheckToolVersion(toolPath string) (string, error) { + if _, err := os.Stat(toolPath); errors.Is(err, os.ErrNotExist) { + return "", trace.NotFound("autoupdate tool not found in %q", toolPath) } else if err != nil { return "", trace.Wrap(err) } @@ -89,11 +112,13 @@ func CheckToolVersion(toolsDir string) (string, error) { // Execute "{tsh, tctl} version" and pass in TELEPORT_TOOLS_VERSION=off to // turn off all automatic updates code paths to prevent any recursion. - command := exec.CommandContext(ctx, path, "version") - command.Env = []string{teleportToolsVersionEnv + "=off"} + command := exec.CommandContext(ctx, toolPath, "version") + command.Env = []string{teleportToolsVersionEnv + "=" + teleportToolsVersionEnvDisabled} output, err := command.Output() if err != nil { - return "", trace.WrapWithMessage(err, "failed to determine version of %q tool", path) + slog.DebugContext(context.Background(), "failed to determine version", + "tool", toolPath, "error", err, "output", string(output)) + return "", ErrVersionCheck } // The output for "{tsh, tctl} version" can be multiple lines. Find the @@ -133,10 +158,35 @@ func GetReExecFromVersion(ctx context.Context) string { return reExecFromVersion } +// CleanUp cleans the tools directory with downloaded versions. +func CleanUp(toolsDir string, tools []string) error { + var aggErr []error + for _, tool := range tools { + if err := os.Remove(filepath.Join(toolsDir, tool)); err != nil && !os.IsNotExist(err) { + aggErr = append(aggErr, err) + } + } + if err := os.Remove(filepath.Join(toolsDir, lockFileName)); err != nil && !os.IsNotExist(err) { + aggErr = append(aggErr, err) + } + if err := os.Remove(filepath.Join(toolsDir, configFileName)); err != nil && !os.IsNotExist(err) { + aggErr = append(aggErr, err) + } + if err := packaging.RemoveWithSuffix(toolsDir, updatePackageSuffix, nil); err != nil { + aggErr = append(aggErr, err) + } + if err := packaging.RemoveWithSuffix(toolsDir, updatePackageSuffixV2, nil); err != nil { + aggErr = append(aggErr, err) + } + + return trace.NewAggregate(aggErr...) +} + // packageURL defines URLs to the archive and their archive sha256 hash file, and marks // if this package is optional, for such case download needs to be ignored if package // not found in CDN. type packageURL struct { + Version string Archive string Hash string Optional bool @@ -148,7 +198,7 @@ func teleportPackageURLs(ctx context.Context, uriTmpl string, baseURL, version s envBaseURL := os.Getenv(autoupdate.BaseURLEnvVar) if m.BuildType() == modules.BuildOSS && envBaseURL == "" { slog.WarnContext(ctx, "Client tools updates are disabled as they are licensed under AGPL. To use Community Edition builds or custom binaries, set the 'TELEPORT_CDN_BASE_URL' environment variable.") - return nil, errNoBaseURL + return nil, ErrNoBaseURL } var flags autoupdate.InstallFlags @@ -170,13 +220,13 @@ func teleportPackageURLs(ctx context.Context, uriTmpl string, baseURL, version s } return []packageURL{ - {Archive: teleportURL, Hash: teleportURL + ".sha256"}, - {Archive: tshURL, Hash: tshURL + ".sha256", Optional: true}, + {Version: version, Archive: teleportURL, Hash: teleportURL + ".sha256"}, + {Version: version, Archive: tshURL, Hash: tshURL + ".sha256", Optional: true}, }, nil } return []packageURL{ - {Archive: teleportURL, Hash: teleportURL + ".sha256"}, + {Version: version, Archive: teleportURL, Hash: teleportURL + ".sha256"}, }, nil } diff --git a/lib/client/api.go b/lib/client/api.go index 0c86b1eb19e7c..f565910852083 100644 --- a/lib/client/api.go +++ b/lib/client/api.go @@ -726,7 +726,7 @@ func RetryWithRelogin(ctx context.Context, tc *TeleportClient, fn func() error, return trace.Wrap(err) } - // Save profile to record proxy credentials + // Save profile to record proxy credentials. if err := tc.SaveProfile(opt.makeCurrentProfile); err != nil { log.Warningf("Failed to save profile: %v", err) return trace.Wrap(err) diff --git a/lib/utils/cli.go b/lib/utils/cli.go index 6417b932bd4d9..bac290805a2f5 100644 --- a/lib/utils/cli.go +++ b/lib/utils/cli.go @@ -377,6 +377,18 @@ func InitCLIParser(appName, appHelp string) (app *kingpin.Application) { return app.UsageTemplate(createUsageTemplate()) } +// InitHiddenCLIParser initializes a `kingpin.Application` that does not terminate the application +// or write any usage information to os.Stdout. Can be used in scenarios where multiple `kingpin.Application` +// instances are needed without interfering with subsequent parsing. Usage output is completely suppressed, +// and the default global `--help` flag is ignored to prevent the application from exiting. +func InitHiddenCLIParser() (app *kingpin.Application) { + app = kingpin.New("", "") + app.UsageWriter(io.Discard) + app.Terminate(func(i int) {}) + + return app +} + // createUsageTemplate creates an usage template for kingpin applications. func createUsageTemplate(opts ...func(*usageTemplateOptions)) string { opt := &usageTemplateOptions{ @@ -655,3 +667,27 @@ func FormatAlert(alert types.ClusterAlert) string { } return buf.String() } + +// FilterArguments filters the input arguments, keeping only those defined in the provided `kingpin.ApplicationModel`. +// For example, if the model defines only one boolean flag `--insecure`, all other arguments in `args` +// will be excluded, and only the `--insecure` flag will remain. +func FilterArguments(args []string, model *kingpin.ApplicationModel) []string { + var result []string + for _, flag := range model.Flags { + for i := range args { + if strings.HasPrefix(args[i], fmt.Sprint("--", flag.Name, "=")) { + result = append(result, args[i]) + break + } + if args[i] == fmt.Sprint("--", flag.Name) { + if flag.IsBoolFlag() { + result = append(result, args[i]) + } else if i+2 <= len(args) { + result = append(result, args[i], args[i+1]) + } + break + } + } + } + return result +} diff --git a/lib/utils/cli_test.go b/lib/utils/cli_test.go index bf1200b9334c5..5ee8ba6ebaea6 100644 --- a/lib/utils/cli_test.go +++ b/lib/utils/cli_test.go @@ -244,3 +244,54 @@ Commands: }) } } + +// TestFilterArguments tests filtering command arguments. +func TestFilterArguments(t *testing.T) { + t.Parallel() + + app := kingpin.New("tsh", "") + app.Flag("proxy", "").String() + app.Flag("check-update", "").Bool() + + tests := []struct { + args []string + expected []string + }{ + { + args: []string{"--insecure", "--proxy", "localhost", "--check-update", "test"}, + expected: []string{"--proxy", "localhost", "--check-update"}, + }, + { + args: []string{"--insecure", "--proxy=localhost", "--check-update", "test"}, + expected: []string{"--proxy=localhost", "--check-update"}, + }, + { + args: []string{"--proxy", "localhost", "test"}, + expected: []string{"--proxy", "localhost"}, + }, + { + args: []string{"--proxy"}, + expected: []string(nil), + }, + { + args: []string{"--insecure", "--check-update", "test", "--proxy=localhost"}, + expected: []string{"--proxy=localhost", "--check-update"}, + }, + { + args: []string{"--insecure", "--check-update", "test", "--proxy1=localhost"}, + expected: []string{"--check-update"}, + }, + { + args: []string{"--check-update", "test", "--proxy1", "localhost"}, + expected: []string{"--check-update"}, + }, + { + args: []string{"--insecure", "test", "--proxy1", "localhost", "--check-update"}, + expected: []string{"--check-update"}, + }, + } + + for i, tt := range tests { + require.Equal(t, tt.expected, FilterArguments(tt.args, app.Model()), fmt.Sprintf("test case %v", i)) + } +} diff --git a/lib/utils/packaging/unarchive.go b/lib/utils/packaging/unarchive.go index ed81a4b4b2d24..ddcab0da5c5c2 100644 --- a/lib/utils/packaging/unarchive.go +++ b/lib/utils/packaging/unarchive.go @@ -70,22 +70,22 @@ func RemoveWithSuffix(dir, suffix string, skipNames []string) error { // replaceZip un-archives the Teleport package in .zip format, iterates through // the compressed content, and ignores everything not matching the binaries specified -// in the execNames argument. The data is extracted to extractDir, and symlinks are created -// in toolsDir pointing to the extractDir path with binaries. -func replaceZip(toolsDir string, archivePath string, extractDir string, execNames []string) error { +// in the execNames argument. The data is extracted to extractDir, and copies are created in toolsDir. +func replaceZip(archivePath string, extractDir string, execNames []string) (map[string]string, error) { + execPaths := make(map[string]string, len(execNames)) f, err := os.Open(archivePath) if err != nil { - return trace.Wrap(err) + return nil, trace.Wrap(err) } defer f.Close() fi, err := f.Stat() if err != nil { - return trace.Wrap(err) + return nil, trace.Wrap(err) } zipReader, err := zip.NewReader(f, fi.Size()) if err != nil { - return trace.Wrap(err) + return nil, trace.Wrap(err) } var totalSize uint64 = 0 @@ -101,7 +101,7 @@ func replaceZip(toolsDir string, archivePath string, extractDir string, execName } // Verify that we have enough space for uncompressed zipFile. if err := checkFreeSpace(extractDir, totalSize); err != nil { - return trace.Wrap(err) + return nil, trace.Wrap(err) } for _, zipFile := range zipReader.File { @@ -123,26 +123,17 @@ func replaceZip(toolsDir string, archivePath string, extractDir string, execName if err != nil { return trace.Wrap(err) } - defer destFile.Close() - if _, err := io.Copy(destFile, file); err != nil { - return trace.Wrap(err) - } - appPath := filepath.Join(toolsDir, baseName) - // For the Windows build, we need to copy the binary to perform updates without requiring - // administrative access, which would otherwise be needed for creating symlinks. - // Since symlinks are not used on the Windows platform, there's no need to remove appPath - // before copying the new binary — it will simply be replaced. - if err := utils.CopyFile(dest, appPath, 0o755); err != nil { - return trace.Wrap(err) + return trace.NewAggregate(err, destFile.Close()) } return trace.Wrap(destFile.Close()) }(zipFile); err != nil { - return trace.Wrap(err) + return nil, trace.Wrap(err) } + execPaths[baseName] = baseName } - return nil + return execPaths, nil } // checkFreeSpace verifies that we have enough requested space (in bytes) at specific directory. diff --git a/lib/utils/packaging/unarchive_test.go b/lib/utils/packaging/unarchive_test.go index b124b603b0fd5..4b64083aec364 100644 --- a/lib/utils/packaging/unarchive_test.go +++ b/lib/utils/packaging/unarchive_test.go @@ -22,6 +22,7 @@ package packaging import ( "context" + "fmt" "os" "path/filepath" "runtime" @@ -36,21 +37,7 @@ import ( // TestPackaging verifies un-archiving of all supported teleport package formats. func TestPackaging(t *testing.T) { script := "#!/bin/sh\necho test" - - sourceDir, err := os.MkdirTemp(os.TempDir(), "source") - require.NoError(t, err) - - toolsDir, err := os.MkdirTemp(os.TempDir(), "dest") - require.NoError(t, err) - - extractDir, err := os.MkdirTemp(toolsDir, "extract") - require.NoError(t, err) - - t.Cleanup(func() { - require.NoError(t, os.RemoveAll(extractDir)) - require.NoError(t, os.RemoveAll(sourceDir)) - require.NoError(t, os.RemoveAll(toolsDir)) - }) + sourceDir := t.TempDir() // Create test script for packaging in relative path `teleport\bin` to ensure that // binaries going to be identified and extracted flatten to `extractDir`. @@ -62,20 +49,23 @@ func TestPackaging(t *testing.T) { ctx := context.Background() t.Run("tar.gz", func(t *testing.T) { + toolsDir := t.TempDir() + extractDir, err := os.MkdirTemp(toolsDir, "extract") + require.NoError(t, err) + archivePath := filepath.Join(toolsDir, "tsh.tar.gz") err = archive.CompressDirToTarGzFile(ctx, sourceDir, archivePath) require.NoError(t, err) require.FileExists(t, archivePath, "archive not created") // For the .tar.gz format we extract app by app to check that content discard is not required. - err = replaceTarGz(toolsDir, archivePath, extractDir, []string{"tctl"}) - require.NoError(t, err) - err = replaceTarGz(toolsDir, archivePath, extractDir, []string{"tsh"}) + toolsMap, err := replaceTarGz(archivePath, extractDir, []string{"tsh", "tctl"}) require.NoError(t, err) - assert.FileExists(t, filepath.Join(toolsDir, "tsh"), "script not created") - assert.FileExists(t, filepath.Join(toolsDir, "tctl"), "script not created") + for tool, path := range toolsMap { + assert.FileExists(t, filepath.Join(extractDir, path), fmt.Sprintf("script: %q not found", tool)) + } - data, err := os.ReadFile(filepath.Join(toolsDir, "tsh")) + data, err := os.ReadFile(filepath.Join(extractDir, "tsh")) require.NoError(t, err) assert.Equal(t, script, string(data)) }) @@ -84,33 +74,42 @@ func TestPackaging(t *testing.T) { if runtime.GOOS != "darwin" { t.Skip("unsupported platform") } + toolsDir := t.TempDir() + extractDir, err := os.MkdirTemp(toolsDir, "extract") + require.NoError(t, err) + archivePath := filepath.Join(toolsDir, "tsh.pkg") err = archive.CompressDirToPkgFile(ctx, sourceDir, archivePath, "com.example.pkgtest") require.NoError(t, err) require.FileExists(t, archivePath, "archive not created") - err = replacePkg(toolsDir, archivePath, filepath.Join(extractDir, "apps"), []string{"tsh", "tctl"}) - require.NoError(t, err) - assert.FileExists(t, filepath.Join(toolsDir, "tsh"), "script not created") - assert.FileExists(t, filepath.Join(toolsDir, "tctl"), "script not created") - - data, err := os.ReadFile(filepath.Join(toolsDir, "tsh")) + toolsMap, err := replacePkg(archivePath, filepath.Join(extractDir, "apps"), []string{"tsh", "tctl"}) require.NoError(t, err) - assert.Equal(t, script, string(data)) + for tool, path := range toolsMap { + assert.FileExists(t, filepath.Join(extractDir, "apps", path), fmt.Sprintf("script: %q not found", tool)) + data, err := os.ReadFile(filepath.Join(extractDir, "apps", path)) + require.NoError(t, err) + assert.Equal(t, script, string(data)) + } }) t.Run("zip", func(t *testing.T) { + toolsDir := t.TempDir() + extractDir, err := os.MkdirTemp(toolsDir, "extract") + require.NoError(t, err) + archivePath := filepath.Join(toolsDir, "tsh.zip") err = archive.CompressDirToZipFile(ctx, sourceDir, archivePath) require.NoError(t, err) require.FileExists(t, archivePath, "archive not created") - err = replaceZip(toolsDir, archivePath, extractDir, []string{"tsh", "tctl"}) + toolsMap, err := replaceZip(archivePath, extractDir, []string{"tsh", "tctl"}) require.NoError(t, err) - assert.FileExists(t, filepath.Join(toolsDir, "tsh"), "script not created") - assert.FileExists(t, filepath.Join(toolsDir, "tctl"), "script not created") + for tool, path := range toolsMap { + assert.FileExists(t, filepath.Join(extractDir, path), fmt.Sprintf("script: %q not found", tool)) + } - data, err := os.ReadFile(filepath.Join(toolsDir, "tsh")) + data, err := os.ReadFile(filepath.Join(extractDir, "tsh")) require.NoError(t, err) assert.Equal(t, script, string(data)) }) diff --git a/lib/utils/packaging/unarchive_unix.go b/lib/utils/packaging/unarchive_unix.go index 8daf1b3aa5525..a525fb21c0499 100644 --- a/lib/utils/packaging/unarchive_unix.go +++ b/lib/utils/packaging/unarchive_unix.go @@ -31,7 +31,6 @@ import ( "runtime" "slices" - "github.com/google/renameio/v2" "github.com/gravitational/trace" "github.com/gravitational/teleport/api/constants" @@ -43,12 +42,12 @@ import ( // // For Darwin, archivePath must be a .pkg file. // For other POSIX, archivePath must be a gzipped tarball. -func ReplaceToolsBinaries(toolsDir string, archivePath string, extractDir string, execNames []string) error { +func ReplaceToolsBinaries(archivePath string, extractDir string, execNames []string) (map[string]string, error) { switch runtime.GOOS { case constants.DarwinOS: - return replacePkg(toolsDir, archivePath, extractDir, execNames) + return replacePkg(archivePath, extractDir, execNames) default: - return replaceTarGz(toolsDir, archivePath, extractDir, execNames) + return replaceTarGz(archivePath, extractDir, execNames) } } @@ -56,19 +55,20 @@ func ReplaceToolsBinaries(toolsDir string, archivePath string, extractDir string // the compressed content, and ignores everything not matching the app binaries specified // in the apps argument. The data is extracted to extractDir, and symlinks are created // in toolsDir pointing to the extractDir path with binaries. -func replaceTarGz(toolsDir string, archivePath string, extractDir string, execNames []string) error { +func replaceTarGz(archivePath string, extractDir string, execNames []string) (map[string]string, error) { + execPaths := make(map[string]string, len(execNames)) if err := validateFreeSpaceTarGz(archivePath, extractDir, execNames); err != nil { - return trace.Wrap(err) + return nil, trace.Wrap(err) } f, err := os.Open(archivePath) if err != nil { - return trace.Wrap(err) + return nil, trace.Wrap(err) } defer f.Close() gzipReader, err := gzip.NewReader(f) if err != nil { - return trace.Wrap(err) + return nil, trace.Wrap(err) } tarReader := tar.NewReader(gzipReader) for { @@ -77,7 +77,7 @@ func replaceTarGz(toolsDir string, archivePath string, extractDir string, execNa break } if err != nil { - return trace.Wrap(err) + return nil, trace.Wrap(err) } baseName := filepath.Base(header.Name) // Skip over any files in the archive that are not in execNames. @@ -86,27 +86,22 @@ func replaceTarGz(toolsDir string, archivePath string, extractDir string, execNa } if err = func(header *tar.Header) error { - tempFile, err := renameio.TempFile(extractDir, filepath.Join(toolsDir, baseName)) + dest := filepath.Join(extractDir, baseName) + destFile, err := os.OpenFile(dest, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0o755) if err != nil { return trace.Wrap(err) } - defer tempFile.Cleanup() - if err := os.Chmod(tempFile.Name(), 0o755); err != nil { - return trace.Wrap(err) - } - if _, err := io.Copy(tempFile, tarReader); err != nil { - return trace.Wrap(err) - } - if err := tempFile.CloseAtomicallyReplace(); err != nil { - return trace.Wrap(err) + if _, err := io.Copy(destFile, tarReader); err != nil { + return trace.NewAggregate(err, destFile.Close()) } - return trace.Wrap(tempFile.Cleanup()) + return trace.Wrap(destFile.Close()) }(header); err != nil { - return trace.Wrap(err) + return nil, trace.Wrap(err) } + execPaths[baseName] = baseName } - return trace.Wrap(gzipReader.Close()) + return execPaths, trace.Wrap(gzipReader.Close()) } // validateFreeSpaceTarGz validates that extraction size match available disk space in `extractDir`. @@ -146,7 +141,8 @@ func validateFreeSpaceTarGz(archivePath string, extractDir string, execNames []s // The data is extracted to extractDir, and symlinks are created in toolsDir pointing to the binaries // in extractDir. Before creating the symlinks, each binary must be executed at least once to pass // OS signature verification. -func replacePkg(toolsDir string, archivePath string, extractDir string, execNames []string) error { +func replacePkg(archivePath string, extractDir string, execNames []string) (map[string]string, error) { + execPaths := make(map[string]string, len(execNames)) // Use "pkgutil" from the filesystem to expand the archive. In theory .pkg // files are xz archives, however it's still safer to use "pkgutil" in-case // Apple makes non-standard changes to the format. @@ -154,11 +150,11 @@ func replacePkg(toolsDir string, archivePath string, extractDir string, execName // Full command: pkgutil --expand-full NAME.pkg DIRECTORY/ pkgutil, err := exec.LookPath("pkgutil") if err != nil { - return trace.Wrap(err) + return nil, trace.Wrap(err) } if err = exec.Command(pkgutil, "--expand-full", archivePath, extractDir).Run(); err != nil { - return trace.Wrap(err) + return nil, trace.Wrap(err) } err = filepath.Walk(extractDir, func(path string, info os.FileInfo, err error) error { @@ -191,18 +187,14 @@ func replacePkg(toolsDir string, archivePath string, extractDir string, execName if err := command.Run(); err != nil { return trace.WrapWithMessage(err, "failed to validate binary") } - - // Due to macOS applications not being a single binary (they are a - // directory), atomic operations are not possible. To work around this, use - // a symlink (which can be atomically swapped), then do a cleanup pass - // removing any stale copies of the expanded package. - newName := filepath.Join(toolsDir, filepath.Base(path)) - if err := renameio.Symlink(path, newName); err != nil { + relPath, err := filepath.Rel(extractDir, path) + if err != nil { return trace.Wrap(err) } + execPaths[info.Name()] = relPath return nil }) - return trace.Wrap(err) + return execPaths, trace.Wrap(err) } diff --git a/lib/utils/packaging/unarchive_windows.go b/lib/utils/packaging/unarchive_windows.go index c07471adce83c..f17b6c7e4bed7 100644 --- a/lib/utils/packaging/unarchive_windows.go +++ b/lib/utils/packaging/unarchive_windows.go @@ -25,6 +25,6 @@ package packaging // toolsDir/[name]. // // For Windows, archivePath must be a .zip file. -func ReplaceToolsBinaries(toolsDir string, archivePath string, extractPath string, execNames []string) error { - return replaceZip(toolsDir, archivePath, extractPath, execNames) +func ReplaceToolsBinaries(archivePath string, extractPath string, execNames []string) (map[string]string, error) { + return replaceZip(archivePath, extractPath, execNames) } diff --git a/lib/utils/utils.go b/lib/utils/utils.go index 5da5b39d05685..752fa5d79bf41 100644 --- a/lib/utils/utils.go +++ b/lib/utils/utils.go @@ -300,6 +300,16 @@ func Host(hostname string) (string, error) { return host, nil } +// TryHost is a utility function that extracts host from the host:port pair, +// in case of any error returns the original value. +func TryHost(in string) string { + out, err := Host(in) + if err != nil { + return in + } + return out +} + // SplitHostPort splits host and port and checks that host is not empty func SplitHostPort(hostname string) (string, string, error) { host, port, err := net.SplitHostPort(hostname) diff --git a/tool/tctl/common/tctl.go b/tool/tctl/common/tctl.go index c1a47e0acdd28..cb3afd2fc9bc8 100644 --- a/tool/tctl/common/tctl.go +++ b/tool/tctl/common/tctl.go @@ -26,6 +26,7 @@ import ( "os" "path/filepath" "runtime" + "strings" "github.com/alecthomas/kingpin/v2" "github.com/gravitational/trace" @@ -73,11 +74,7 @@ type CLICommand interface { // // distribution: name of the Teleport distribution func Run(ctx context.Context, commands []CLICommand) { - if err := tools.CheckAndUpdateLocal(ctx, os.Args[1:]); err != nil { - utils.FatalError(err) - } - - err := TryRun(commands, os.Args[1:]) + err := TryRun(ctx, commands, os.Args[1:]) if err != nil { var exitError *common.ExitCodeError if errors.As(err, &exitError) { @@ -89,18 +86,42 @@ func Run(ctx context.Context, commands []CLICommand) { // TryRun is a helper function for Run to call - it runs a tctl command and returns an error. // This is useful for testing tctl, because we can capture the returned error in tests. -func TryRun(commands []CLICommand, args []string) error { +func TryRun(ctx context.Context, commands []CLICommand, args []string) error { utils.InitLogger(utils.LoggingForCLI, slog.LevelWarn) - // app is the command line parser - app := utils.InitCLIParser("tctl", GlobalHelpString) + var ccf tctlcfg.GlobalCLIFlags + muApp := utils.InitHiddenCLIParser() + muApp.Flag("auth-server", + fmt.Sprintf("Attempts to connect to specific auth/proxy address(es) instead of local auth [%v]", defaults.AuthConnectAddr().Addr)). + Envar(authAddrEnvVar). + StringsVar(&ccf.AuthServerAddr) + // We need to parse the arguments before executing managed updates to identify + // the profile name and the required version for the current cluster. + // All other commands and flags may change between versions, so full parsing + // should be performed only after managed updates are applied. + if _, err := muApp.Parse(utils.FilterArguments(args, muApp.Model())); err != nil { + slog.WarnContext(ctx, "can't identify current profile", "error", err) + } // cfg (teleport auth server configuration) is going to be shared by all // commands cfg := servicecfg.MakeDefaultConfig() cfg.CircuitBreakerConfig = breaker.NoopBreakerConfig() + cfg.TeleportHome = os.Getenv(types.HomeEnvVar) + if cfg.TeleportHome != "" { + cfg.TeleportHome = filepath.Clean(cfg.TeleportHome) + } - var ccf tctlcfg.GlobalCLIFlags + var name string + if len(ccf.AuthServerAddr) != 0 { + name = utils.TryHost(strings.TrimPrefix(strings.ToLower(ccf.AuthServerAddr[0]), "https://")) + } + if err := tools.CheckAndUpdateLocal(ctx, name, args); err != nil { + return trace.Wrap(err) + } + + // app is the command line parser + app := utils.InitCLIParser("tctl", GlobalHelpString) // Each command will add itself to the CLI parser. for i := range commands { @@ -157,14 +178,8 @@ func TryRun(commands []CLICommand, args []string) error { return trace.BadParameter("tctl --identity also requires --auth-server") } - cfg.TeleportHome = os.Getenv(types.HomeEnvVar) - if cfg.TeleportHome != "" { - cfg.TeleportHome = filepath.Clean(cfg.TeleportHome) - } - cfg.Debug = ccf.Debug - ctx := context.Background() clientFunc := commonclient.GetInitFunc(ccf, cfg) // Execute whatever is selected. for _, c := range commands { diff --git a/tool/tsh/common/autoupdate.go b/tool/tsh/common/autoupdate.go new file mode 100644 index 0000000000000..62f6dca8e5b0a --- /dev/null +++ b/tool/tsh/common/autoupdate.go @@ -0,0 +1,72 @@ +/* + * Teleport + * Copyright (C) 2025 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package common + +import ( + "github.com/alecthomas/kingpin/v2" + "github.com/gravitational/trace" + + "github.com/gravitational/teleport/lib/autoupdate/tools" +) + +type autoUpdateCommand struct { + update *managedUpdatesUpdateCommand +} + +func newUpdateCommand(app *kingpin.Application) *autoUpdateCommand { + root := &autoUpdateCommand{ + update: &managedUpdatesUpdateCommand{}, + } + + root.update.CmdClause = app.Command("update", + "Update client tools (tsh, tctl) to the latest version defined by the cluster configuration.") + root.update.CmdClause.Flag("clear", "Removes locally installed client tools updates from the Teleport home directory.").BoolVar(&root.update.clear) + + return root +} + +// managedUpdatesUpdateCommand additionally check for the latest available client tools +// version in cluster and runs update. +type managedUpdatesUpdateCommand struct { + *kingpin.CmdClause + clear bool +} + +func (c *managedUpdatesUpdateCommand) run(cf *CLIConf) error { + if c.clear { + toolsDir, err := tools.Dir() + if err != nil { + return trace.Wrap(err) + } + if err := tools.CleanUp(toolsDir, tools.DefaultClientTools()); err != nil { + return trace.Wrap(err) + } + return nil + } + + tc, err := makeClient(cf) + if err != nil { + return trace.Wrap(err) + } + if err := tools.DownloadUpdate(cf.Context, tc.WebProxyAddr, tc.InsecureSkipVerify); err != nil { + return trace.Wrap(err) + } + + return nil +} diff --git a/tool/tsh/common/proxy.go b/tool/tsh/common/proxy.go index 9ef9e4c4eb610..18403d5c75c27 100644 --- a/tool/tsh/common/proxy.go +++ b/tool/tsh/common/proxy.go @@ -49,8 +49,8 @@ import ( // onProxyCommandSSH creates a local ssh proxy, dialing a node and transferring // data through stdin and stdout, to be used as an OpenSSH and PuTTY proxy // command. -func onProxyCommandSSH(cf *CLIConf) error { - tc, err := makeClient(cf) +func onProxyCommandSSH(cf *CLIConf, initFunc ClientInitFunc) error { + tc, err := initFunc(cf) if err != nil { return trace.Wrap(err) } diff --git a/tool/tsh/common/tctl_test.go b/tool/tsh/common/tctl_test.go index 94958414f93fb..13c9a763386c3 100644 --- a/tool/tsh/common/tctl_test.go +++ b/tool/tsh/common/tctl_test.go @@ -140,7 +140,7 @@ func TestRemoteTctlWithProfile(t *testing.T) { for _, tt := range tests { t.Run(tt.desc, func(t *testing.T) { - err := common.TryRun(tt.commands, tt.args) + err := common.TryRun(context.Background(), tt.commands, tt.args) if tt.wantErrContains != "" { var exitError *toolcommon.ExitCodeError require.ErrorAs(t, err, &exitError) diff --git a/tool/tsh/common/tsh.go b/tool/tsh/common/tsh.go index 88e147e50f065..d26ffc41b6390 100644 --- a/tool/tsh/common/tsh.go +++ b/tool/tsh/common/tsh.go @@ -139,6 +139,10 @@ var accessRequestModes = []string{ accessRequestModeRole, } +// ClientInitFunc defines a function that initiates a connection to +// the Teleport cluster using the CLI configuration. +type ClientInitFunc func(cf *CLIConf) (*client.TeleportClient, error) + // CLIConf stores command line arguments and flags: type CLIConf struct { // UserHost contains "[login]@hostname" argument to SSH command @@ -586,6 +590,8 @@ type CLIConf struct { // atomic here is overkill as the CLIConf is generally consumed sequentially. However, occasionally // we need concurrency safety, such as for [forEachProfileParallel]. clientStoreSet int32 + // checkManagedUpdates initiates check of managed update after client connects to cluster. + checkManagedUpdates bool } // Stdout returns the stdout writer. @@ -684,6 +690,7 @@ const ( proxyKubeConfigEnvVar = "TELEPORT_KUBECONFIG" noResumeEnvVar = "TELEPORT_NO_RESUME" requestModeEnvVar = "TELEPORT_REQUEST_MODE" + toolsCheckUpdateEnvVar = "TELEPORT_TOOLS_CHECK_UPDATE" clusterHelp = "Specify the Teleport cluster to connect" browserHelp = "Set to 'none' to suppress browser opening on login" @@ -727,10 +734,6 @@ func initLogger(cf *CLIConf) { // // DO NOT RUN TESTS that call Run() in parallel (unless you taken precautions). func Run(ctx context.Context, args []string, opts ...CliOption) error { - if err := tools.CheckAndUpdateLocal(ctx, args); err != nil { - return trace.Wrap(err) - } - cf := CLIConf{ Context: ctx, TracingProvider: tracing.NoopProvider(), @@ -738,6 +741,23 @@ func Run(ctx context.Context, args []string, opts ...CliOption) error { DTAutoEnroll: dtenroll.AutoEnroll, } + // We need to parse the arguments before executing managed updates to identify + // the profile name and the required version for the current cluster. + // All other commands and flags may change between versions, so full parsing + // should be performed only after managed updates are applied. + var proxyArg string + muApp := utils.InitHiddenCLIParser() + muApp.Flag("proxy", "Teleport proxy address").Envar(proxyEnvVar).Hidden().StringVar(&proxyArg) + muApp.Flag("check-update", "Check for availability of managed update.").Envar(toolsCheckUpdateEnvVar).Hidden().BoolVar(&cf.checkManagedUpdates) + if _, err := muApp.Parse(utils.FilterArguments(args, muApp.Model())); err != nil { + slog.WarnContext(ctx, "can't identify current profile", "error", err) + } + // Check local update for specific proxy from configuration. + name := utils.TryHost(strings.TrimPrefix(strings.ToLower(proxyArg), "https://")) + if err := tools.CheckAndUpdateLocal(ctx, name, args); err != nil { + return trace.Wrap(err) + } + // run early to enable debug logging if env var is set. // this makes it possible to debug early startup functionality, particularly command aliases. initLogger(&cf) @@ -804,6 +824,7 @@ func Run(ctx context.Context, args []string, opts ...CliOption) error { StringVar(&cf.MlockMode) app.HelpFlag.Short('h') app.Flag("piv-slot", "Specify a PIV slot key to use for Hardware Key support instead of the default. Ex: \"9d\"").Envar("TELEPORT_PIV_SLOT").StringVar(&cf.PIVSlot) + app.Flag("check-update", "Check for availability of managed update.").Envar(toolsCheckUpdateEnvVar).Hidden().BoolVar(&cf.checkManagedUpdates) ver := app.Command("version", "Print the tsh client and Proxy server versions for the current context.") ver.Flag("format", defaults.FormatFlagDescription(defaults.DefaultFormats...)).Short('f').Default(teleport.Text).EnumVar(&cf.Format, defaults.DefaultFormats...) @@ -1248,6 +1269,9 @@ func Run(ctx context.Context, args []string, opts ...CliOption) error { puttyConfig.Hidden() } + // Client-tools managed updates commands. + updateCommand := newUpdateCommand(app) + // FIDO2, TouchID and WebAuthnWin commands. f2 := fido2.NewCommand(app) tid := touchid.NewCommand(app) @@ -1414,7 +1438,7 @@ func Run(ctx context.Context, args []string, opts ...CliOption) error { case ver.FullCommand(): err = onVersion(&cf) case ssh.FullCommand(): - err = onSSH(&cf) + err = onSSH(&cf, wrapInitClientWithUpdateCheck(makeClient, args)) case resolve.FullCommand(): err = onResolve(&cf) // If quiet was specified for this command and @@ -1547,7 +1571,7 @@ func Run(ctx context.Context, args []string, opts ...CliOption) error { case scan.keys.FullCommand(): err = scan.keys.run(&cf) case proxySSH.FullCommand(): - err = onProxyCommandSSH(&cf) + err = onProxyCommandSSH(&cf, wrapInitClientWithUpdateCheck(makeClient, args)) case proxyDB.FullCommand(): err = onProxyCommandDB(&cf) case proxyApp.FullCommand(): @@ -1646,6 +1670,8 @@ func Run(ctx context.Context, args []string, opts ...CliOption) error { err = vnetDaemonCmd.run(&cf) case pivCmd.agent.FullCommand(): err = pivCmd.agent.run(&cf) + case updateCommand.update.FullCommand(): + err = updateCommand.update.run(&cf) default: // Handle commands that might not be available. switch { @@ -1925,7 +1951,7 @@ func onLogin(cf *CLIConf, reExecArgs ...string) error { // in case if parameters match, re-fetch kube clusters and print // current status case cf.Proxy == "" && cf.SiteName == "" && cf.DesiredRoles == "" && cf.RequestID == "" && cf.IdentityFileOut == "" || - host(cf.Proxy) == host(profile.ProxyURL.Host) && cf.SiteName == profile.Cluster && cf.DesiredRoles == "" && cf.RequestID == "": + utils.TryHost(cf.Proxy) == utils.TryHost(profile.ProxyURL.Host) && cf.SiteName == profile.Cluster && cf.DesiredRoles == "" && cf.RequestID == "": // The user has typed `tsh login`, if the running binary needs to // be updated, update and re-exec. @@ -1944,7 +1970,7 @@ func onLogin(cf *CLIConf, reExecArgs ...string) error { return trace.Wrap(printLoginInformation(cf, profile, profiles, cf.getAccessListsToReview(tc))) // if the proxy names match but nothing else is specified; show motd and update active profile and kube configs - case host(cf.Proxy) == host(profile.ProxyURL.Host) && + case utils.TryHost(cf.Proxy) == utils.TryHost(profile.ProxyURL.Host) && cf.SiteName == "" && cf.DesiredRoles == "" && cf.RequestID == "" && cf.IdentityFileOut == "": // The user has typed `tsh login`, if the running binary needs to @@ -1977,7 +2003,7 @@ func onLogin(cf *CLIConf, reExecArgs ...string) error { // proxy is unspecified or the same as the currently provided proxy, // but cluster is specified, treat this as selecting a new cluster // for the same proxy - case (cf.Proxy == "" || host(cf.Proxy) == host(profile.ProxyURL.Host)) && cf.SiteName != "": + case (cf.Proxy == "" || utils.TryHost(cf.Proxy) == utils.TryHost(profile.ProxyURL.Host)) && cf.SiteName != "": _, err := tc.PingAndShowMOTD(cf.Context) if err != nil { return trace.Wrap(err) @@ -2007,7 +2033,7 @@ func onLogin(cf *CLIConf, reExecArgs ...string) error { // proxy is unspecified or the same as the currently provided proxy, // but desired roles or request ID is specified, treat this as a // privilege escalation request for the same login session. - case (cf.Proxy == "" || host(cf.Proxy) == host(profile.ProxyURL.Host)) && (cf.DesiredRoles != "" || cf.RequestID != "") && cf.IdentityFileOut == "": + case (cf.Proxy == "" || utils.TryHost(cf.Proxy) == utils.TryHost(profile.ProxyURL.Host)) && (cf.DesiredRoles != "" || cf.RequestID != "") && cf.IdentityFileOut == "": _, err := tc.PingAndShowMOTD(cf.Context) if err != nil { return trace.Wrap(err) @@ -3875,7 +3901,7 @@ func onResolve(cf *CLIConf) error { } // onSSH executes 'tsh ssh' command -func onSSH(cf *CLIConf) error { +func onSSH(cf *CLIConf, initFunc ClientInitFunc) error { // If "tsh ssh -V" is invoked, tsh is in OpenSSH compatibility mode, show // the version and exit. if cf.ShowVersion { @@ -3901,7 +3927,7 @@ func onSSH(cf *CLIConf) error { return trace.BadParameter("required argument '[user@]host' not provided") } - tc, err := makeClient(cf) + tc, err := initFunc(cf) if err != nil { return trace.Wrap(err) } @@ -4100,6 +4126,23 @@ func makeClient(cf *CLIConf) (*client.TeleportClient, error) { return tc, trace.Wrap(err) } +// wrapInitClientWithUpdateCheck wraps the client initialization function to the Teleport cluster, +// adding a managed update check immediately after the connection is established. +func wrapInitClientWithUpdateCheck(clientInitFunc ClientInitFunc, reExecArgs []string) ClientInitFunc { + return func(cf *CLIConf) (*client.TeleportClient, error) { + tc, err := clientInitFunc(cf) + if err != nil { + return nil, trace.Wrap(err) + } + if cf.checkManagedUpdates { + if err := tools.CheckAndUpdateRemote(cf.Context, tc.WebProxyAddr, tc.InsecureSkipVerify, reExecArgs); err != nil { + return nil, trace.Wrap(err) + } + } + return tc, nil + } +} + // makeClient takes the command-line configuration and a proxy address and constructs & returns // a fully configured TeleportClient object func makeClientForProxy(cf *CLIConf, proxy string) (*client.TeleportClient, error) { @@ -5191,17 +5234,6 @@ func getTshEnv() map[string]string { return env } -// host is a utility function that extracts -// host from the host:port pair, in case of any error -// returns the original value -func host(in string) string { - out, err := utils.Host(in) - if err != nil { - return in - } - return out -} - func awaitRequestResolution(ctx context.Context, clt authclient.ClientI, req types.AccessRequest) (types.AccessRequest, error) { filter := types.AccessRequestFilter{ User: req.GetUser(), diff --git a/tool/tsh/common/tsh_test.go b/tool/tsh/common/tsh_test.go index 33a9d4011125b..e8a4eecc50b47 100644 --- a/tool/tsh/common/tsh_test.go +++ b/tool/tsh/common/tsh_test.go @@ -439,6 +439,29 @@ func TestNoEnvVars(t *testing.T) { require.NoError(t, trace.NewAggregate(err, ctx.Err())) } +// TestDefaultPrintUsage verifies that the main `kingpin.Application` parser has not been +// previously terminated, and that it correctly prints the usage message when using the +// global `--help` flag or the `help` command, and both are identical. +func TestDefaultPrintUsage(t *testing.T) { + t.Parallel() + testExecutable, err := os.Executable() + require.NoError(t, err) + + ctx := context.Background() + + cmd := exec.CommandContext(ctx, testExecutable, "version", "--help") + cmd.Env = []string{fmt.Sprintf("%s=1", tshBinMainTestEnv), "TELEPORT_TOOLS_VERSION=off"} + flagOutput, err := cmd.CombinedOutput() + require.NoError(t, err) + require.Contains(t, string(flagOutput), "Print the tsh client and Proxy server versions for the current context") + + cmd = exec.CommandContext(ctx, testExecutable, "help", "version") + cmd.Env = []string{fmt.Sprintf("%s=1", tshBinMainTestEnv), "TELEPORT_TOOLS_VERSION=off"} + commandOutput, err := cmd.CombinedOutput() + require.NoError(t, err) + require.Equal(t, string(flagOutput), string(commandOutput)) +} + func TestFailedLogin(t *testing.T) { tmpHomePath := t.TempDir()