diff --git a/api/client/webclient/webclient.go b/api/client/webclient/webclient.go
index 2ccb27d0cb72b..c501b92fe00d4 100644
--- a/api/client/webclient/webclient.go
+++ b/api/client/webclient/webclient.go
@@ -297,6 +297,10 @@ type PingResponse struct {
ServerVersion string `json:"server_version"`
// MinClientVersion is the minimum client version required by the server.
MinClientVersion string `json:"min_client_version"`
+ // ToolsVersion defines the version of {tsh, tctl} for client auto-upgrade.
+ ToolsVersion string `json:"tools_version"`
+ // ToolsAutoupdate enables client autoupdate feature.
+ ToolsAutoupdate bool `json:"tools_autoupdate"`
// ClusterName contains the name of the Teleport cluster.
ClusterName string `json:"cluster_name"`
diff --git a/integrations/terraform/go.mod b/integrations/terraform/go.mod
index bc6b5fb8f4df1..0f248045cc81b 100644
--- a/integrations/terraform/go.mod
+++ b/integrations/terraform/go.mod
@@ -180,6 +180,7 @@ require (
github.com/google/go-tpm-tools v0.4.4 // indirect
github.com/google/go-tspi v0.3.0 // indirect
github.com/google/gofuzz v1.2.0 // indirect
+ github.com/google/renameio/v2 v2.0.0 // indirect
github.com/google/s2a-go v0.1.8 // indirect
github.com/google/safetext v0.0.0-20240104143208-7a7d9b3d812f // indirect
github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 // indirect
diff --git a/integrations/terraform/go.sum b/integrations/terraform/go.sum
index 03c8fc34ff45b..2e396be5c6465 100644
--- a/integrations/terraform/go.sum
+++ b/integrations/terraform/go.sum
@@ -1259,7 +1259,6 @@ github.com/google/pprof v0.0.0-20210609004039-a478d1d731e9/go.mod h1:kpwsk12EmLe
github.com/google/pprof v0.0.0-20210720184732-4bb14d4b1be1/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE=
github.com/google/pprof v0.0.0-20240525223248-4bfdf5a9a2af h1:kmjWCqn2qkEml422C2Rrd27c3VGxi6a/6HNq8QmHRKM=
github.com/google/pprof v0.0.0-20240525223248-4bfdf5a9a2af/go.mod h1:K1liHPHnj73Fdn/EKuT8nrFqBihUSKXoLYU0BuatOYo=
-github.com/google/renameio v0.1.0 h1:GOZbcHa3HfsPKPlmyPyN2KEohoMXOhdMbHrvbpl2QaA=
github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI=
github.com/google/renameio/v2 v2.0.0 h1:UifI23ZTGY8Tt29JbYFiuyIU3eX+RNFtUwefq9qAhxg=
github.com/google/renameio/v2 v2.0.0/go.mod h1:BtmJXm5YlszgC+TD4HOEEUFgkJP3nLxehU6hfe7jRt4=
diff --git a/lib/client/api.go b/lib/client/api.go
index 74fd055476196..ae9cf5af4bffb 100644
--- a/lib/client/api.go
+++ b/lib/client/api.go
@@ -93,6 +93,7 @@ import (
"github.com/gravitational/teleport/lib/tlsca"
"github.com/gravitational/teleport/lib/utils"
"github.com/gravitational/teleport/lib/utils/agentconn"
+ "github.com/gravitational/teleport/tool/common/update"
)
const (
@@ -684,6 +685,31 @@ func RetryWithRelogin(ctx context.Context, tc *TeleportClient, fn func() error,
return trace.Wrap(err)
}
+ // The user has typed a command like `tsh ssh ...` without being logged in,
+ // if the running binary needs to be updated, update and re-exec.
+ //
+ // If needed, download the new version of {tsh, tctl} and re-exec. Make
+ // sure to exit this process with the same exit code as the child process.
+ //
+ toolsVersion, reexec, err := update.CheckRemote(ctx, tc.WebProxyAddr)
+ if err != nil {
+ return trace.Wrap(err)
+ }
+ if reexec {
+ // Download the version of client tools required by the cluster.
+ err := update.Download(toolsVersion)
+ if err != nil {
+ return trace.Wrap(err)
+ }
+
+ // Re-execute client tools with the correct version of client tools.
+ code, err := update.Exec()
+ if err != nil {
+ return trace.Wrap(err)
+ }
+ os.Exit(code)
+ }
+
if opt.afterLoginHook != nil {
if err := opt.afterLoginHook(); err != nil {
return trace.Wrap(err)
diff --git a/tool/common/update/client.go b/tool/common/update/client.go
new file mode 100644
index 0000000000000..b36e023fe894e
--- /dev/null
+++ b/tool/common/update/client.go
@@ -0,0 +1,61 @@
+/*
+ * Teleport
+ * Copyright (C) 2024 Gravitational, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU Affero General Public License for more details.
+ *
+ * You should have received a copy of the GNU Affero General Public License
+ * along with this program. If not, see .
+ */
+
+package update
+
+import (
+ "crypto/tls"
+ "crypto/x509"
+ "net/http"
+ "net/url"
+ "time"
+
+ "golang.org/x/net/http/httpproxy"
+
+ apidefaults "github.com/gravitational/teleport/api/defaults"
+ tracehttp "github.com/gravitational/teleport/api/observability/tracing/http"
+ apiutils "github.com/gravitational/teleport/api/utils"
+)
+
+type downloadConfig struct {
+ // Insecure turns off TLS certificate verification when enabled.
+ Insecure bool
+ // Pool defines the set of root CAs to use when verifying server
+ // certificates.
+ Pool *x509.CertPool
+ // Timeout is a timeout for requests.
+ Timeout time.Duration
+}
+
+func newClient(cfg *downloadConfig) *http.Client {
+ rt := apiutils.NewHTTPRoundTripper(&http.Transport{
+ TLSClientConfig: &tls.Config{
+ InsecureSkipVerify: cfg.Insecure,
+ RootCAs: cfg.Pool,
+ },
+ Proxy: func(req *http.Request) (*url.URL, error) {
+ return httpproxy.FromEnvironment().ProxyFunc()(req.URL)
+ },
+ IdleConnTimeout: apidefaults.DefaultIOTimeout,
+ }, nil)
+
+ return &http.Client{
+ Transport: tracehttp.NewTransport(rt),
+ Timeout: cfg.Timeout,
+ }
+}
diff --git a/tool/common/update/feature_ent.go b/tool/common/update/feature_ent.go
new file mode 100644
index 0000000000000..9e9b10cab2287
--- /dev/null
+++ b/tool/common/update/feature_ent.go
@@ -0,0 +1,26 @@
+//go:build webassets_ent
+// +build webassets_ent
+
+/*
+ * Teleport
+ * Copyright (C) 2024 Gravitational, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU Affero General Public License for more details.
+ *
+ * You should have received a copy of the GNU Affero General Public License
+ * along with this program. If not, see .
+ */
+
+package update
+
+func init() {
+ featureFlag |= FlagEnt
+}
diff --git a/tool/common/update/feature_fips.go b/tool/common/update/feature_fips.go
new file mode 100644
index 0000000000000..5fae3c6936933
--- /dev/null
+++ b/tool/common/update/feature_fips.go
@@ -0,0 +1,26 @@
+//go:build fips
+// +build fips
+
+/*
+ * Teleport
+ * Copyright (C) 2024 Gravitational, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU Affero General Public License for more details.
+ *
+ * You should have received a copy of the GNU Affero General Public License
+ * along with this program. If not, see .
+ */
+
+package update
+
+func init() {
+ featureFlag |= FlagFips
+}
diff --git a/tool/common/update/integration/main.go b/tool/common/update/integration/main.go
new file mode 100644
index 0000000000000..ac67fca48f657
--- /dev/null
+++ b/tool/common/update/integration/main.go
@@ -0,0 +1,63 @@
+/*
+ * Teleport
+ * Copyright (C) 2024 Gravitational, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU Affero General Public License for more details.
+ *
+ * You should have received a copy of the GNU Affero General Public License
+ * along with this program. If not, see .
+ */
+
+package main
+
+import (
+ "errors"
+ "fmt"
+ "os"
+
+ log "github.com/sirupsen/logrus"
+
+ "github.com/gravitational/teleport/tool/common/update"
+)
+
+var version = "development"
+
+func main() {
+ // At process startup, check if a version has already been downloaded to
+ // $TELEPORT_HOME/bin or if the user has set the TELEPORT_TOOLS_VERSION
+ // environment variable. If so, re-exec that version of {tsh, tctl}.
+ toolsVersion, reExec := update.CheckLocal()
+ if reExec {
+ // Download the version of client tools required by the cluster. This
+ // is required if the user passed in the TELEPORT_TOOLS_VERSION
+ // explicitly.
+ err := update.Download(toolsVersion)
+ if errors.Is(err, update.ErrCanceled) {
+ os.Exit(0)
+ return
+ }
+ if err != nil {
+ log.Fatalf("Failed to download version (%v): %v", toolsVersion, err)
+ return
+ }
+
+ // Re-execute client tools with the correct version of client tools.
+ code, err := update.Exec()
+ if err != nil {
+ log.Fatalf("Failed to re-exec client tool: %v", err)
+ } else {
+ os.Exit(code)
+ }
+ }
+ if len(os.Args) > 1 && os.Args[1] == "version" {
+ fmt.Printf("Teleport v%v git", version)
+ }
+}
diff --git a/tool/common/update/progress.go b/tool/common/update/progress.go
new file mode 100644
index 0000000000000..a1ce6f66a5d4d
--- /dev/null
+++ b/tool/common/update/progress.go
@@ -0,0 +1,80 @@
+/*
+ * Teleport
+ * Copyright (C) 2024 Gravitational, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU Affero General Public License for more details.
+ *
+ * You should have received a copy of the GNU Affero General Public License
+ * along with this program. If not, see .
+ */
+
+package update
+
+import (
+ "fmt"
+ "io"
+ "os"
+ "os/signal"
+ "strings"
+)
+
+var (
+ ErrCanceled = fmt.Errorf("canceled")
+)
+
+// cancelableTeeReader is a copy of TeeReader with ability to react on signal notifier
+// to cancel reading process.
+func cancelableTeeReader(r io.Reader, w io.Writer, signals ...os.Signal) io.Reader {
+ sigs := make(chan os.Signal, 1)
+ signal.Notify(sigs, signals...)
+
+ return &teeReader{r, w, sigs}
+}
+
+type teeReader struct {
+ r io.Reader
+ w io.Writer
+ sigs chan os.Signal
+}
+
+func (t *teeReader) Read(p []byte) (n int, err error) {
+ select {
+ case <-t.sigs:
+ return 0, ErrCanceled
+ default:
+ n, err = t.r.Read(p)
+ if n > 0 {
+ if n, err := t.w.Write(p[:n]); err != nil {
+ return n, err
+ }
+ }
+ }
+ return
+}
+
+type progressWriter struct {
+ n int64
+ limit int64
+}
+
+func (w *progressWriter) Write(p []byte) (int, error) {
+ w.n = w.n + int64(len(p))
+
+ n := int((w.n*100)/w.limit) / 10
+ bricks := strings.Repeat("▒", n) + strings.Repeat(" ", 10-n)
+ fmt.Printf("\rUpdate progress: [" + bricks + "] (Ctrl-C to cancel update)")
+
+ if w.n == w.limit {
+ fmt.Printf("\n")
+ }
+
+ return len(p), nil
+}
diff --git a/tool/common/update/update.go b/tool/common/update/update.go
new file mode 100644
index 0000000000000..3a4f97ce6800e
--- /dev/null
+++ b/tool/common/update/update.go
@@ -0,0 +1,431 @@
+/*
+ * Teleport
+ * Copyright (C) 2024 Gravitational, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU Affero General Public License for more details.
+ *
+ * You should have received a copy of the GNU Affero General Public License
+ * along with this program. If not, see .
+ */
+
+package update
+
+import (
+ "bufio"
+ "bytes"
+ "context"
+ "crypto/sha256"
+ "crypto/x509"
+ "encoding/hex"
+ "fmt"
+ "io"
+ "net/http"
+ "os"
+ "os/exec"
+ "path/filepath"
+ "regexp"
+ "runtime"
+ "strings"
+ "syscall"
+ "time"
+
+ "github.com/coreos/go-semver/semver"
+ "github.com/gravitational/trace"
+ log "github.com/sirupsen/logrus"
+
+ "github.com/gravitational/teleport"
+ "github.com/gravitational/teleport/api/client/webclient"
+ "github.com/gravitational/teleport/api/types"
+)
+
+const (
+ teleportToolsVersion = "TELEPORT_TOOLS_VERSION"
+ checksumHexLen = 64
+ reservedFreeDisk = 10 * 1024 * 1024
+
+ FlagEnt = 1 << 0
+ FlagFips = 1 << 1
+)
+
+var (
+ pattern = regexp.MustCompile(`(?m)Teleport v(.*) git`)
+ baseUrl = "https://cdn.teleport.dev"
+ defaultClient = newClient(&downloadConfig{})
+ featureFlag int
+)
+
+// CheckLocal is run at client tool startup and will only perform local checks.
+func CheckLocal() (string, bool) {
+ // If a version of client tools has already been downloaded to
+ // $TELEPORT_HOME/bin, return that.
+ toolsVersion, err := version()
+ if err != nil {
+ return "", false
+ }
+
+ // Check if the user has requested a specific version of client tools.
+ requestedVersion := os.Getenv(teleportToolsVersion)
+ switch {
+ // The user has turned off any form of automatic updates.
+ case requestedVersion == "off":
+ return "", false
+ // Requested version already the same as client version.
+ case teleport.Version == requestedVersion:
+ return requestedVersion, false
+ // The user has requested a specific version of client tools.
+ case requestedVersion != "" && requestedVersion != toolsVersion:
+ return requestedVersion, true
+ }
+
+ return toolsVersion, false
+}
+
+// CheckRemote will check against Proxy Service if client tools need to be
+// updated.
+func CheckRemote(ctx context.Context, proxyAddr string) (string, bool, error) {
+ certPool, err := x509.SystemCertPool()
+ if err != nil {
+ return "", false, trace.Wrap(err)
+ }
+ resp, err := webclient.Find(&webclient.Config{
+ Context: ctx,
+ ProxyAddr: proxyAddr,
+ Pool: certPool,
+ Timeout: 30 * time.Second,
+ })
+ if err != nil {
+ return "", false, trace.Wrap(err)
+ }
+
+ // If a version of client tools has already been downloaded to
+ // $TELEPORT_HOME/bin, return that.
+ toolsVersion, err := version()
+ if err != nil {
+ return "", false, nil
+ }
+
+ requestedVersion := os.Getenv(teleportToolsVersion)
+ switch {
+ // The user has turned off any form of automatic updates.
+ case requestedVersion == "off":
+ return "", false, nil
+ // Requested version already the same as client version.
+ case teleport.Version == requestedVersion:
+ return requestedVersion, false, nil
+ case requestedVersion != "" && requestedVersion != toolsVersion:
+ return requestedVersion, true, nil
+ case !resp.ToolsAutoupdate || resp.ToolsVersion == "":
+ return "", false, nil
+ case teleport.Version == resp.ToolsVersion:
+ return resp.ToolsVersion, false, nil
+ case resp.ToolsVersion != toolsVersion:
+ return resp.ToolsVersion, true, nil
+ }
+
+ return toolsVersion, false, nil
+}
+
+// Download downloads requested version package, unarchive and replace existing one.
+func Download(toolsVersion string) error {
+ // If the version of the running binary or the version downloaded to
+ // $TELEPORT_HOME/bin is the same as the requested version of client tools,
+ // nothing to be done, exit early.
+ teleportVersion, err := version()
+ if err != nil {
+ if !trace.IsNotFound(err) {
+ return trace.Wrap(err)
+ }
+ }
+ if toolsVersion == teleport.Version || toolsVersion == teleportVersion {
+ return nil
+ }
+
+ // Create $TELEPORT_HOME/bin if it does not exist.
+ dir, err := toolsDir()
+ if err != nil {
+ return trace.Wrap(err)
+ }
+ if err := os.MkdirAll(dir, 0755); err != nil {
+ return trace.Wrap(err)
+ }
+
+ // Download and update {tsh, tctl} in $TELEPORT_HOME/bin.
+ if err := update(toolsVersion); err != nil {
+ return trace.Wrap(err)
+ }
+
+ return nil
+}
+
+func update(toolsVersion string) error {
+ dir, err := toolsDir()
+ if err != nil {
+ return trace.Wrap(err)
+ }
+ // Lock to allow multiple concurrent {tsh, tctl} to run.
+ unlock, err := lock(dir)
+ if err != nil {
+ return trace.Wrap(err)
+ }
+ defer unlock()
+
+ // Get platform specific download URLs.
+ archiveURL, hashURL, err := urls(toolsVersion)
+ if err != nil {
+ return trace.Wrap(err)
+ }
+ log.Debugf("Archive download path: %v.", archiveURL)
+
+ // Download the archive and validate against the hash. Download to a
+ // temporary path within $TELEPORT_HOME/bin.
+ hash, err := downloadHash(hashURL)
+ if err != nil {
+ return trace.Wrap(err)
+ }
+ path, err := downloadArchive(archiveURL, hash)
+ if err != nil {
+ return trace.Wrap(err)
+ }
+ defer os.Remove(path)
+
+ // Perform atomic replace so concurrent exec do not fail.
+ if err := replace(path, hash); err != nil {
+ return trace.Wrap(err)
+ }
+
+ return nil
+}
+
+// urls returns the URL for the Teleport archive to download. The format is:
+// https://cdn.teleport.dev/teleport-{, ent-}v15.3.0-{linux, darwin, windows}-{amd64,arm64,arm,386}-{fips-}bin.tar.gz
+func urls(toolsVersion string) (string, string, error) {
+ var archive string
+
+ switch runtime.GOOS {
+ case "darwin":
+ archive = baseUrl + "/tsh-" + toolsVersion + ".pkg"
+ case "windows":
+ archive = baseUrl + "/teleport-v" + toolsVersion + "-windows-amd64-bin.zip"
+ case "linux":
+ edition := ""
+ if featureFlag&FlagEnt != 0 || featureFlag&FlagFips != 0 {
+ edition = "ent-"
+ }
+ fips := ""
+ if featureFlag&FlagFips != 0 {
+ fips = "fips-"
+ }
+
+ var b strings.Builder
+ b.WriteString(baseUrl + "/teleport-")
+ if edition != "" {
+ b.WriteString(edition)
+ }
+ b.WriteString("v" + toolsVersion + "-" + runtime.GOOS + "-" + runtime.GOARCH + "-")
+ if fips != "" {
+ b.WriteString(fips)
+ }
+ b.WriteString("bin.tar.gz")
+ archive = b.String()
+ default:
+ return "", "", trace.BadParameter("unsupported runtime: %v", runtime.GOOS)
+ }
+
+ return archive, archive + ".sha256", nil
+}
+
+func downloadHash(url string) (string, error) {
+ ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
+ defer cancel()
+
+ req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
+ if err != nil {
+ return "", trace.Wrap(err)
+ }
+ resp, err := defaultClient.Do(req)
+ if err != nil {
+ return "", trace.Wrap(err)
+ }
+ defer resp.Body.Close()
+
+ if resp.StatusCode != http.StatusOK {
+ return "", trace.BadParameter("request failed with: %v", resp.StatusCode)
+ }
+
+ var buf bytes.Buffer
+ _, err = io.CopyN(&buf, resp.Body, checksumHexLen)
+ if err != nil {
+ return "", trace.Wrap(err)
+ }
+ raw := buf.String()
+ if _, err = hex.DecodeString(raw); err != nil {
+ return "", trace.Wrap(err)
+ }
+ return raw, nil
+}
+
+func downloadArchive(url string, hash string) (string, error) {
+ resp, err := defaultClient.Get(url)
+ if err != nil {
+ return "", trace.Wrap(err)
+ }
+ defer resp.Body.Close()
+ if resp.StatusCode != http.StatusOK {
+ return "", trace.BadParameter("bad status when downloading archive: %v", resp.StatusCode)
+ }
+
+ dir, err := toolsDir()
+ if err != nil {
+ return "", trace.Wrap(err)
+ }
+
+ if resp.ContentLength != -1 {
+ if err := checkFreeSpace(dir, uint64(resp.ContentLength)); err != nil {
+ return "", trace.Wrap(err)
+ }
+ }
+
+ // Caller of this function will remove this file after the atomic swap has
+ // occurred.
+ f, err := os.CreateTemp(dir, "tmp-")
+ if err != nil {
+ return "", trace.Wrap(err)
+ }
+
+ h := sha256.New()
+ pw := &progressWriter{n: 0, limit: resp.ContentLength}
+ body := cancelableTeeReader(io.TeeReader(resp.Body, h), pw, syscall.SIGINT, syscall.SIGTERM)
+
+ // It is a little inefficient to download the file to disk and then re-load
+ // it into memory to unarchive later, but this is safer as it allows {tsh,
+ // tctl} to validate the hash before trying to operate on the archive.
+ _, err = io.Copy(f, body)
+ if err != nil {
+ return "", trace.Wrap(err)
+ }
+ if fmt.Sprintf("%x", h.Sum(nil)) != hash {
+ return "", trace.BadParameter("hash of archive does not match downloaded archive")
+ }
+
+ return f.Name(), nil
+}
+
+// Exec re-executes tool command with same arguments and environ variables.
+func Exec() (int, error) {
+ path, err := toolName()
+ if err != nil {
+ return 0, trace.Wrap(err)
+ }
+
+ cmd := exec.Command(path, os.Args[1:]...)
+ cmd.Env = os.Environ()
+ // To prevent re-execution loop we have to disable update logic for re-execution.
+ cmd.Env = append(cmd.Env, teleportToolsVersion+"=off")
+ cmd.Stdin = os.Stdin
+ cmd.Stdout = os.Stdout
+ cmd.Stderr = os.Stderr
+
+ if err := cmd.Run(); err != nil {
+ return 0, trace.Wrap(err)
+ }
+
+ return cmd.ProcessState.ExitCode(), nil
+}
+
+func version() (string, error) {
+ // Find the path to the current executable.
+ path, err := toolName()
+ if err != nil {
+ return "", trace.Wrap(err)
+ }
+
+ // Set a timeout to not let "{tsh, tctl} version" block forever. Allow up
+ // to 10 seconds because sometimes MDM tools like Jamf cause a lot of
+ // latency in launching binaries.
+ ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
+ defer cancel()
+
+ // Execute "{tsh, tctl} version" and pass in TELEPORT_TOOLS_VERSION=off to
+ // turn off all automatic updates code paths to prevent any recursion.
+ command := exec.CommandContext(ctx, path, "version")
+ command.Env = []string{teleportToolsVersion + "=off"}
+ output, err := command.Output()
+ if err != nil {
+ return "", trace.Wrap(err)
+ }
+
+ // The output for "{tsh, tctl} version" can be multiple lines. Find the
+ // actual version line and extract the version.
+ scanner := bufio.NewScanner(bytes.NewReader(output))
+ for scanner.Scan() {
+ line := scanner.Text()
+
+ if !strings.HasPrefix(line, "Teleport") {
+ continue
+ }
+
+ matches := pattern.FindStringSubmatch(line)
+ if len(matches) != 2 {
+ return "", trace.BadParameter("invalid version line: %v", line)
+ }
+ version, err := semver.NewVersion(matches[1])
+ if err != nil {
+ return "", trace.Wrap(err)
+ }
+ return version.String(), nil
+ }
+
+ return "", trace.BadParameter("unable to determine version")
+}
+
+// toolsDir returns the path to {tsh, tctl} in $TELEPORT_HOME/bin.
+func toolsDir() (string, error) {
+ home := os.Getenv(types.HomeEnvVar)
+ if home == "" {
+ var err error
+ home, err = os.UserHomeDir()
+ if err != nil {
+ return "", trace.Wrap(err)
+ }
+ }
+
+ return filepath.Join(filepath.Clean(home), ".tsh", "bin"), nil
+}
+
+// toolName returns the path to {tsh, tctl} for the executable that started
+// the current process.
+func toolName() (string, error) {
+ base, err := toolsDir()
+ if err != nil {
+ return "", trace.Wrap(err)
+ }
+
+ executablePath, err := os.Executable()
+ if err != nil {
+ return "", trace.Wrap(err)
+ }
+ toolName := filepath.Base(executablePath)
+
+ return filepath.Join(base, toolName), nil
+}
+
+// checkFreeSpace verifies that we have enough requested space at specific directory.
+func checkFreeSpace(path string, requested uint64) error {
+ free, err := freeDiskWithReserve(path)
+ if err != nil {
+ return trace.Errorf("failed to calculate free disk in %q: %v", path, err)
+ }
+ // Bail if there's not enough free disk space at the target.
+ if requested > free {
+ return trace.Errorf("%q needs %d additional bytes of disk space", path, requested-free)
+ }
+ return nil
+}
diff --git a/tool/common/update/update_darwin.go b/tool/common/update/update_darwin.go
new file mode 100644
index 0000000000000..252d40b8c8093
--- /dev/null
+++ b/tool/common/update/update_darwin.go
@@ -0,0 +1,110 @@
+//go:build darwin
+
+/*
+ * Teleport
+ * Copyright (C) 2024 Gravitational, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU Affero General Public License for more details.
+ *
+ * You should have received a copy of the GNU Affero General Public License
+ * along with this program. If not, see .
+ */
+
+package update
+
+import (
+ "os"
+ "os/exec"
+ "path/filepath"
+ "strings"
+
+ "github.com/google/renameio/v2"
+ "github.com/gravitational/trace"
+ log "github.com/sirupsen/logrus"
+)
+
+func replace(path string, hash string) error {
+ dir, err := toolsDir()
+ if err != nil {
+ return trace.Wrap(err)
+ }
+
+ // Use "pkgutil" from the filesystem to expand the archive. In theory .pkg
+ // files are xz archives, however it's still safer to use "pkgutil" in-case
+ // Apple makes non-standard changes to the format.
+ //
+ // Full command: pkgutil --expand-full NAME.pkg DIRECTORY/
+ pkgutil, err := exec.LookPath("pkgutil")
+ if err != nil {
+ return trace.Wrap(err)
+ }
+ expandPath := filepath.Join(dir, hash+"-pkg")
+ out, err := exec.Command(pkgutil, "--expand-full", path, expandPath).Output()
+ if err != nil {
+ log.Debugf("Failed to run pkgutil: %v.", out)
+ return trace.Wrap(err)
+ }
+
+ for _, app := range []string{"tsh", "tctl"} {
+ // The first time a signed and notarized binary macOS application is run,
+ // execution is paused while it gets sent to Apple to verify. Once Apple
+ // approves the binary, the "com.apple.macl" extended attribute is added
+ // and the process is allowed to execute. This process is not concurrent, any
+ // other operations (like moving the application) on the application during
+ // this time will lead to the application being sent SIGKILL.
+ //
+ // Since {tsh, tctl} have to be concurrent, execute {tsh, tctl} before
+ // performing any swap operations. This ensures that the "com.apple.macl"
+ // extended attribute is set and macOS will not send a SIGKILL to the
+ // process if multiple processes are trying to operate on it.
+ expandExecPath := filepath.Join(expandPath, "Payload", app+".app", "Contents", "MacOS", app)
+ command := exec.Command(expandExecPath, "version", "--client")
+ command.Env = []string{teleportToolsVersion + "=off"}
+ if err := command.Run(); err != nil {
+ return trace.Wrap(err)
+ }
+
+ // Due to macOS applications not being a single binary (they are a
+ // directory), atomic operations are not possible. To work around this, use
+ // a symlink (which can be atomically swapped), then do a cleanup pass
+ // removing any stale copies of the expanded package.
+ oldName := filepath.Join(expandPath, "Payload", app+".app", "Contents", "MacOS", app)
+ newName := filepath.Join(dir, app)
+ if err := renameio.Symlink(oldName, newName); err != nil {
+ return trace.Wrap(err)
+ }
+ }
+
+ // Perform a cleanup pass to remove any old copies of "{tsh, tctl}.app".
+ err = filepath.Walk(dir, func(path string, info os.FileInfo, err error) error {
+ if err != nil {
+ return err
+ }
+ if !info.IsDir() {
+ return nil
+ }
+ if hash+"-pkg" == info.Name() {
+ return nil
+ }
+ if !strings.HasSuffix(info.Name(), "-pkg") {
+ return nil
+ }
+
+ // Found a stale expanded package.
+ if err := os.RemoveAll(filepath.Join(dir, info.Name())); err != nil {
+ return err
+ }
+
+ return nil
+ })
+
+ return nil
+}
diff --git a/tool/common/update/update_helper_test.go b/tool/common/update/update_helper_test.go
new file mode 100644
index 0000000000000..2eb1b050e0aa9
--- /dev/null
+++ b/tool/common/update/update_helper_test.go
@@ -0,0 +1,87 @@
+/*
+ * Teleport
+ * Copyright (C) 2024 Gravitational, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU Affero General Public License for more details.
+ *
+ * You should have received a copy of the GNU Affero General Public License
+ * along with this program. If not, see .
+ */
+
+package update
+
+import (
+ "net/http"
+ "sync"
+)
+
+type limitRequest struct {
+ limit int64
+ lock chan struct{}
+}
+
+// limitedResponseWriter wraps http.ResponseWriter and enforces a write limit.
+type limitedResponseWriter struct {
+ requests chan limitRequest
+}
+
+// newLimitedResponseWriter creates a new limitedResponseWriter with the lock.
+func newLimitedResponseWriter() *limitedResponseWriter {
+ lw := &limitedResponseWriter{
+ requests: make(chan limitRequest, 10),
+ }
+ return lw
+}
+
+// Wrap wraps response writer if limit was previously requested, if not, return original one.
+func (lw *limitedResponseWriter) Wrap(w http.ResponseWriter) http.ResponseWriter {
+ select {
+ case request := <-lw.requests:
+ return &wrapper{
+ ResponseWriter: w,
+ request: request,
+ }
+ default:
+ return w
+ }
+}
+
+func (lw *limitedResponseWriter) SetLimitRequest(limit limitRequest) {
+ lw.requests <- limit
+}
+
+// wrapper wraps the http response writer to control writing operation by blocking it.
+type wrapper struct {
+ http.ResponseWriter
+
+ written int64
+ request limitRequest
+ released bool
+
+ mutex sync.Mutex
+}
+
+// Write writes data to the underlying ResponseWriter but respects the byte limit.
+func (lw *wrapper) Write(p []byte) (int, error) {
+ lw.mutex.Lock()
+ defer lw.mutex.Unlock()
+
+ if lw.written >= lw.request.limit && !lw.released {
+ // Send signal that lock is acquired and wait till it was released by response.
+ lw.request.lock <- struct{}{}
+ <-lw.request.lock
+ lw.released = true
+ }
+
+ n, err := lw.ResponseWriter.Write(p)
+ lw.written += int64(n)
+ return n, err
+}
diff --git a/tool/common/update/update_linux.go b/tool/common/update/update_linux.go
new file mode 100644
index 0000000000000..8e02b9bc407e5
--- /dev/null
+++ b/tool/common/update/update_linux.go
@@ -0,0 +1,93 @@
+//go:build linux
+
+/*
+ * Teleport
+ * Copyright (C) 2024 Gravitational, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU Affero General Public License for more details.
+ *
+ * You should have received a copy of the GNU Affero General Public License
+ * along with this program. If not, see .
+ */
+
+package update
+
+import (
+ "archive/tar"
+ "compress/gzip"
+ "errors"
+ "io"
+ "os"
+ "path/filepath"
+ "strings"
+
+ "github.com/google/renameio/v2"
+ "github.com/gravitational/trace"
+ log "github.com/sirupsen/logrus"
+)
+
+func replace(path string, _ string) error {
+ dir, err := toolsDir()
+ if err != nil {
+ return trace.Wrap(err)
+ }
+ tempDir := renameio.TempDir(dir)
+
+ f, err := os.Open(path)
+ if err != nil {
+ return trace.Wrap(err)
+ }
+
+ gzipReader, err := gzip.NewReader(f)
+ if err != nil {
+ return trace.Wrap(err)
+ }
+
+ tarReader := tar.NewReader(gzipReader)
+ for {
+ header, err := tarReader.Next()
+ if errors.Is(err, io.EOF) {
+ break
+ }
+ // Skip over any files in the archive that are not {tsh, tctl}.
+ if header.Name != "teleport/tctl" &&
+ header.Name != "teleport/tsh" &&
+ header.Name != "teleport/tbot" {
+ if _, err := io.Copy(io.Discard, tarReader); err != nil {
+ log.Debugf("failed to discard %v: %v.", header.Name, err)
+ }
+ continue
+ }
+
+ // Verify that we have enough space for uncompressed file.
+ if err := checkFreeSpace(tempDir, uint64(header.Size)); err != nil {
+ return trace.Wrap(err)
+ }
+
+ dest := filepath.Join(dir, strings.TrimPrefix(header.Name, "teleport/"))
+ t, err := renameio.TempFile(tempDir, dest)
+ if err != nil {
+ return trace.Wrap(err)
+ }
+ if err := os.Chmod(t.Name(), 0755); err != nil {
+ return trace.Wrap(err)
+ }
+ defer t.Cleanup()
+
+ if _, err := io.Copy(t, tarReader); err != nil {
+ return trace.Wrap(err)
+ }
+ if err := t.CloseAtomicallyReplace(); err != nil {
+ return trace.Wrap(err)
+ }
+ }
+ return nil
+}
diff --git a/tool/common/update/update_test.go b/tool/common/update/update_test.go
new file mode 100644
index 0000000000000..ac7cda653a3f5
--- /dev/null
+++ b/tool/common/update/update_test.go
@@ -0,0 +1,409 @@
+/*
+ * Teleport
+ * Copyright (C) 2024 Gravitational, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU Affero General Public License for more details.
+ *
+ * You should have received a copy of the GNU Affero General Public License
+ * along with this program. If not, see .
+ */
+
+package update
+
+import (
+ "archive/tar"
+ "archive/zip"
+ "bytes"
+ "compress/gzip"
+ "crypto/sha256"
+ "encoding/hex"
+ "errors"
+ "fmt"
+ "io"
+ "log"
+ "net"
+ "net/http"
+ "os"
+ "os/exec"
+ "path/filepath"
+ "runtime"
+ "strings"
+ "sync/atomic"
+ "testing"
+ "time"
+
+ "github.com/gravitational/trace"
+ "github.com/stretchr/testify/require"
+)
+
+const (
+ testBinaryName = "updater"
+)
+
+var (
+ // testVersions list of the pre-compiled binaries with encoded versions to check.
+ testVersions = []string{
+ "1.2.3",
+ "3.2.1",
+ }
+ limitedWriter = newLimitedResponseWriter()
+)
+
+// TestLock verifies that second lock call is blocked until first is released.
+func TestLock(t *testing.T) {
+ var locked atomic.Bool
+
+ dir := os.TempDir()
+
+ // Acquire first lock should not return any error.
+ unlock, err := lock(dir)
+ require.NoError(t, err)
+ locked.Store(true)
+
+ signal := make(chan struct{})
+ errChan := make(chan error)
+ go func() {
+ signal <- struct{}{}
+ unlock, err := lock(dir)
+ if err != nil {
+ errChan <- err
+ }
+ if locked.Load() {
+ errChan <- fmt.Errorf("first lock is still acquired, second lock must be blocking")
+ }
+ unlock()
+ signal <- struct{}{}
+ }()
+
+ <-signal
+ // We have to wait till next lock is reached to ensure we block execution of goroutine.
+ // Since this is system call we can't track if the function reach blocking state already.
+ time.Sleep(100 * time.Millisecond)
+ locked.Store(false)
+ unlock()
+
+ select {
+ case <-signal:
+ case err := <-errChan:
+ require.NoError(t, err)
+ case <-time.After(5 * time.Second):
+ t.Errorf("second lock is not released")
+ }
+}
+
+// TestUpdateInterruptSignal verifies the interrupt signal send to the process must stop downloading.
+func TestUpdateInterruptSignal(t *testing.T) {
+ dir, err := toolsDir()
+ require.NoError(t, err, "failed to find tools directory")
+
+ err = os.MkdirAll(dir, 0755)
+ require.NoError(t, err, "failed to create tools directory")
+
+ // Initial fetch the updater binary un-archive and replace.
+ err = update(testVersions[0])
+ require.NoError(t, err)
+
+ var output bytes.Buffer
+ cmd := exec.Command(filepath.Join(dir, "tsh"), "version")
+ cmd.Stdout = &output
+ cmd.Stderr = &output
+ cmd.Env = append(
+ os.Environ(),
+ fmt.Sprintf("%s=%s", teleportToolsVersion, testVersions[1]),
+ )
+ err = cmd.Start()
+ require.NoError(t, err, "failed to start updater")
+ pid := cmd.Process.Pid
+
+ errChan := make(chan error)
+ go func() {
+ errChan <- cmd.Wait()
+ }()
+
+ // By setting the limit request next test http serving file going blocked until unlock is sent.
+ lock := make(chan struct{})
+ limitedWriter.SetLimitRequest(limitRequest{
+ limit: 1024,
+ lock: lock,
+ })
+
+ select {
+ case err := <-errChan:
+ require.NoError(t, err)
+ case <-time.After(5 * time.Second):
+ t.Errorf("failed to wait till the download is started")
+ case <-lock:
+ time.Sleep(100 * time.Millisecond)
+ require.NoError(t, sendInterrupt(pid))
+ lock <- struct{}{}
+ }
+
+ // Wait till process finished with exit code 0, but we still should get progress
+ // bar in output content.
+ select {
+ case <-time.After(5 * time.Second):
+ t.Errorf("failed to wait till the process interrupted")
+ case err := <-errChan:
+ require.NoError(t, err)
+ }
+ require.Contains(t, output.String(), "Update progress:")
+}
+
+func TestUpdate(t *testing.T) {
+ dir, err := toolsDir()
+ require.NoError(t, err, "failed to find tools directory")
+
+ err = os.MkdirAll(dir, 0755)
+ require.NoError(t, err, "failed to create tools directory")
+
+ // Fetch compiled test binary with updater logic and install to $TELEPORT_HOME.
+ err = update(testVersions[0])
+ require.NoError(t, err)
+
+ // Verify that the installed version is equal to requested one.
+ cmd := exec.Command(filepath.Join(dir, "tsh"), "version")
+ out, err := cmd.Output()
+ require.NoError(t, err)
+
+ matches := pattern.FindStringSubmatch(string(out))
+ require.Len(t, matches, 2)
+ require.Equal(t, testVersions[0], matches[1])
+
+ // Execute version command again with setting the new version which must
+ // trigger re-execution of the same command after downloading requested version.
+ cmd = exec.Command(filepath.Join(dir, "tsh"), "version")
+ cmd.Env = append(
+ os.Environ(),
+ fmt.Sprintf("%s=%s", teleportToolsVersion, testVersions[1]),
+ )
+ out, err = cmd.Output()
+ require.NoError(t, err)
+
+ matches = pattern.FindStringSubmatch(string(out))
+ require.Len(t, matches, 2)
+ require.Equal(t, testVersions[1], matches[1])
+}
+
+func TestMain(m *testing.M) {
+ tmp, err := os.MkdirTemp(os.TempDir(), testBinaryName)
+ if err != nil {
+ log.Fatalf("failed to create temporary directory: %v", err)
+ }
+
+ srv, address := startTestHTTPServer(tmp)
+ baseUrl = fmt.Sprintf("http://%s", address)
+
+ for _, version := range testVersions {
+ if err := buildAndArchiveApps(tmp, version, address); err != nil {
+ log.Fatalf("failed to build testing app binary archive: %v", err)
+ }
+ }
+
+ // Run tests after binary is built.
+ code := m.Run()
+
+ if err := srv.Close(); err != nil {
+ log.Fatalf("failed to shutdown server: %v", err)
+ }
+ if err := os.RemoveAll(tmp); err != nil {
+ log.Fatalf("failed to remove temporary directory: %v", err)
+ }
+
+ os.Exit(code)
+}
+
+// serve256File calculates sha256 checksum for requested file.
+func serve256File(w http.ResponseWriter, _ *http.Request, filePath string) {
+ log.Printf("Calculating and serving file checksum: %s\n", filePath)
+
+ w.Header().Set("Content-Disposition", "attachment; filename=\""+filepath.Base(filePath)+".sha256\"")
+ w.Header().Set("Content-Type", "plain/text")
+
+ hash := sha256.New()
+ file, err := os.Open(filePath)
+ if err != nil {
+ http.Error(w, "failed to open file", http.StatusInternalServerError)
+ return
+ }
+ defer file.Close()
+
+ if _, err := io.Copy(hash, file); err != nil {
+ http.Error(w, "failed to write to hash", http.StatusInternalServerError)
+ return
+ }
+ if _, err := hex.NewEncoder(w).Write(hash.Sum(nil)); err != nil {
+ http.Error(w, "failed to write checksum", http.StatusInternalServerError)
+ }
+}
+
+// generateZipFile compresses the file into a `.zip` format. This format intended to be
+// used only for windows platform and mocking paths for windows archive.
+func generateZipFile(filePath, destPath string) error {
+ archive, err := os.Create(destPath)
+ if err != nil {
+ return trace.Wrap(err)
+ }
+ defer archive.Close()
+
+ zipWriter := zip.NewWriter(archive)
+ defer zipWriter.Close()
+
+ return filepath.Walk(filePath, func(path string, info os.FileInfo, err error) error {
+ if err != nil {
+ return err
+ }
+ if info.IsDir() {
+ return nil
+ }
+ file, err := os.Open(path)
+ if err != nil {
+ return err
+ }
+ defer file.Close()
+
+ zipFileWriter, err := zipWriter.Create(filepath.Base(path))
+ if err != nil {
+ return trace.Wrap(err)
+ }
+
+ _, err = io.Copy(zipFileWriter, file)
+ return trace.Wrap(err)
+ })
+}
+
+// generateTarGzFile compresses files into a `.tar.gz` format specifically in file
+// structure related to linux packaging.
+func generateTarGzFile(filePath, destPath string) error {
+ archive, err := os.Create(destPath)
+ if err != nil {
+ return trace.Wrap(err)
+ }
+ defer archive.Close()
+
+ gzipWriter := gzip.NewWriter(archive)
+ defer gzipWriter.Close()
+
+ tarWriter := tar.NewWriter(gzipWriter)
+ defer tarWriter.Close()
+
+ return filepath.Walk(filePath, func(path string, info os.FileInfo, err error) error {
+ if err != nil {
+ return err
+ }
+ if info.IsDir() {
+ return nil
+ }
+ file, err := os.Open(path)
+ if err != nil {
+ return err
+ }
+ defer file.Close()
+
+ header, err := tar.FileInfoHeader(info, info.Name())
+ if err != nil {
+ return err
+ }
+ header.Name = filepath.Join("teleport", filepath.Base(info.Name()))
+ if err := tarWriter.WriteHeader(header); err != nil {
+ return err
+ }
+
+ _, err = io.Copy(tarWriter, file)
+ return trace.Wrap(err)
+ })
+}
+
+// generatePkgFile runs the macOS `pkgbuild` command to generate a .pkg file from the source.
+func generatePkgFile(filePath, destPath string) error {
+ cmd := exec.Command("pkgbuild",
+ "--root", filePath,
+ "--identifier", "com.example.pkgtest",
+ "--version", "1.0",
+ destPath,
+ )
+
+ output, err := cmd.CombinedOutput()
+ if err != nil {
+ log.Printf("failed to generate .pkg: %s\n", output)
+ return err
+ }
+
+ return nil
+}
+
+// startTestHTTPServer starts the file-serving HTTP server for testing.
+func startTestHTTPServer(baseDir string) (*http.Server, string) {
+ srv := &http.Server{Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ filePath := filepath.Join(baseDir, r.URL.Path)
+ switch {
+ case strings.HasSuffix(r.URL.Path, ".sha256"):
+ serve256File(w, r, strings.TrimSuffix(filePath, ".sha256"))
+ default:
+ http.ServeFile(limitedWriter.Wrap(w), r, filePath)
+ }
+ })}
+
+ listener, err := net.Listen("tcp", "127.0.0.1:0")
+ if err != nil {
+ log.Fatalf("failed to create listener: %v", err)
+ }
+
+ go func() {
+ if err := srv.Serve(listener); err != nil && !errors.Is(err, http.ErrServerClosed) {
+ log.Printf("failed to start server: %s", err)
+ }
+ }()
+
+ return srv, listener.Addr().String()
+}
+
+// buildAndArchiveApps compiles the updater integration and pack it depends on platform is used.
+func buildAndArchiveApps(path string, version string, address string) error {
+ versionPath := filepath.Join(path, version)
+ for _, app := range []string{"tsh", "tctl"} {
+ output := filepath.Join(versionPath, app)
+ switch runtime.GOOS {
+ case "windows":
+ output = filepath.Join(versionPath, app+".exe")
+ case "darwin":
+ output = filepath.Join(versionPath, app+".app", "Contents", "MacOS", app)
+ }
+ if err := buildBinary(output, version, address); err != nil {
+ return trace.Wrap(err)
+ }
+ }
+ switch runtime.GOOS {
+ case "darwin":
+ return trace.Wrap(generatePkgFile(versionPath, path+"/tsh-"+version+".pkg"))
+ case "windows":
+ return trace.Wrap(generateZipFile(versionPath, path+"/teleport-v"+version+"-windows-amd64-bin.zip"))
+ case "linux":
+ return trace.Wrap(generateTarGzFile(versionPath, path+"/teleport-v"+version+"-linux-"+runtime.GOARCH+"-bin.tar.gz"))
+ default:
+ return trace.BadParameter("unsupported platform")
+ }
+}
+
+// buildBinary executes command to build updater integration logic for testing.
+func buildBinary(output string, version string, address string) error {
+ cmd := exec.Command(
+ "go", "build", "-o", output,
+ "-ldflags", strings.Join([]string{
+ fmt.Sprintf("-X 'main.version=%s'", version),
+ fmt.Sprintf("-X 'github.com/gravitational/teleport/tool/common/update.baseUrl=http://%s'", address),
+ }, " "),
+ "./integration",
+ )
+ cmd.Stdout = os.Stdout
+ cmd.Stderr = os.Stderr
+
+ return trace.Wrap(cmd.Run())
+}
diff --git a/tool/common/update/update_unix.go b/tool/common/update/update_unix.go
new file mode 100644
index 0000000000000..dc51d964baddd
--- /dev/null
+++ b/tool/common/update/update_unix.go
@@ -0,0 +1,82 @@
+//go:build !windows
+
+package update
+
+import (
+ "context"
+ "errors"
+ "log/slog"
+ "os"
+ "path/filepath"
+ "syscall"
+
+ "github.com/gravitational/trace"
+ "golang.org/x/sys/unix"
+)
+
+func lock(dir string) (func(), error) {
+ ctx := context.Background()
+ // Build the path to the lock file that will be used by flock.
+ lockFile := filepath.Join(dir, ".lock")
+
+ // Create the advisory lock using flock.
+ lf, err := os.OpenFile(lockFile, os.O_CREATE|os.O_WRONLY, 0666)
+ if err != nil {
+ return nil, trace.Wrap(err)
+ }
+
+ rc, err := lf.SyscallConn()
+ if err != nil {
+ _ = lf.Close()
+ return nil, trace.Wrap(err)
+ }
+ if err := rc.Control(func(fd uintptr) {
+ err = syscall.Flock(int(fd), syscall.LOCK_EX)
+ }); err != nil {
+ _ = lf.Close()
+ return nil, trace.Wrap(err)
+ }
+
+ return func() {
+ rc, err := lf.SyscallConn()
+ if err != nil {
+ _ = lf.Close()
+ slog.DebugContext(ctx, "failed to acquire syscall connection", "error", err)
+ return
+ }
+ if err := rc.Control(func(fd uintptr) {
+ err = syscall.Flock(int(fd), syscall.LOCK_UN)
+ }); err != nil {
+ slog.DebugContext(ctx, "failed to unlock file", "file", lockFile, "error", err)
+ }
+ if err := lf.Close(); err != nil {
+ slog.DebugContext(ctx, "failed to close lock file", "file", lockFile, "error", err)
+ }
+ }, nil
+}
+
+// sendInterrupt sends a SIGINT to the process.
+func sendInterrupt(pid int) error {
+ err := syscall.Kill(pid, syscall.SIGINT)
+ if errors.Is(err, syscall.ESRCH) {
+ return trace.BadParameter("can't find the process: %v", pid)
+ }
+ return trace.Wrap(err)
+}
+
+// freeDiskWithReserve returns the available disk space.
+func freeDiskWithReserve(dir string) (uint64, error) {
+ var stat unix.Statfs_t
+ err := unix.Statfs(dir, &stat)
+ if err != nil {
+ return 0, trace.Wrap(err)
+ }
+ if stat.Bsize < 0 {
+ return 0, trace.Errorf("invalid size")
+ }
+ avail := stat.Bavail * uint64(stat.Bsize)
+ if reservedFreeDisk > avail {
+ return 0, trace.Errorf("no free space left")
+ }
+ return avail - reservedFreeDisk, nil
+}
diff --git a/tool/common/update/update_windows.go b/tool/common/update/update_windows.go
new file mode 100644
index 0000000000000..66c2a68093c8b
--- /dev/null
+++ b/tool/common/update/update_windows.go
@@ -0,0 +1,167 @@
+//go:build windows
+
+/*
+ * Teleport
+ * Copyright (C) 2024 Gravitational, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU Affero General Public License for more details.
+ *
+ * You should have received a copy of the GNU Affero General Public License
+ * along with this program. If not, see .
+ */
+
+package update
+
+import (
+ "archive/zip"
+ "context"
+ "io"
+ "log/slog"
+ "os"
+ "path/filepath"
+ "syscall"
+ "time"
+ "unsafe"
+
+ "golang.org/x/sys/windows"
+
+ "github.com/gravitational/trace"
+)
+
+var (
+ kernel = windows.NewLazyDLL("kernel32.dll")
+ proc = kernel.NewProc("CreateFileW")
+ ctrlEvent = kernel.NewProc("GenerateConsoleCtrlEvent")
+)
+
+func replace(path string, hash string) error {
+ f, err := os.Open(path)
+ if err != nil {
+ return trace.Wrap(err)
+ }
+ fi, err := f.Stat()
+ if err != nil {
+ return trace.Wrap(err)
+ }
+ zipReader, err := zip.NewReader(f, fi.Size())
+ if err != nil {
+ return trace.Wrap(err)
+ }
+
+ dir, err := toolsDir()
+ if err != nil {
+ return trace.Wrap(err)
+ }
+ tempDir, err := os.MkdirTemp(dir, hash)
+ if err != nil {
+ return trace.Wrap(err)
+ }
+
+ for _, r := range zipReader.File {
+ // Skip over any files in the archive that are not {tsh, tctl}.
+ if r.Name != "tsh.exe" && r.Name != "tctl.exe" {
+ continue
+ }
+
+ // Verify that we have enough space for uncompressed file.
+ if err := checkFreeSpace(tempDir, r.UncompressedSize64); err != nil {
+ return trace.Wrap(err)
+ }
+
+ rr, err := r.Open()
+ if err != nil {
+ return trace.Wrap(err)
+ }
+ defer rr.Close()
+
+ dest := filepath.Join(tempDir, r.Name)
+ t, err := os.OpenFile(dest, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0755)
+ if err != nil {
+ return trace.Wrap(err)
+ }
+ defer t.Close()
+
+ if _, err := io.Copy(t, rr); err != nil {
+ return trace.Wrap(err)
+ }
+ os.Remove(filepath.Join(dir, r.Name))
+ if err := os.Symlink(dest, filepath.Join(dir, r.Name)); err != nil {
+ return trace.Wrap(err)
+ }
+ }
+
+ return nil
+}
+
+// lock implements locking mechanism for blocking another process acquire the lock until its released.
+func lock(dir string) (func(), error) {
+ path := filepath.Join(dir, ".lock")
+ lockPath, err := windows.UTF16PtrFromString(path)
+ if err != nil {
+ return nil, trace.Wrap(err)
+ }
+
+ var lockFile *os.File
+ for lockFile == nil {
+ fd, _, err := proc.Call(
+ uintptr(unsafe.Pointer(lockPath)),
+ uintptr(windows.GENERIC_READ|windows.GENERIC_WRITE),
+ // Exclusive lock, for shared must be used: uintptr(windows.FILE_SHARE_READ|windows.FILE_SHARE_WRITE).
+ uintptr(0),
+ uintptr(0),
+ uintptr(windows.OPEN_ALWAYS),
+ uintptr(windows.FILE_ATTRIBUTE_NORMAL),
+ 0,
+ )
+ switch err.(windows.Errno) {
+ case windows.NO_ERROR, windows.ERROR_ALREADY_EXISTS:
+ lockFile = os.NewFile(fd, path)
+ case windows.ERROR_SHARING_VIOLATION:
+ // if the file is locked by another process we have to wait until the next check.
+ time.Sleep(time.Second)
+ default:
+ windows.CloseHandle(windows.Handle(fd))
+ return nil, trace.Wrap(err)
+ }
+ }
+
+ if err := windows.SetHandleInformation(windows.Handle(lockFile.Fd()), windows.HANDLE_FLAG_INHERIT, 1); err != nil {
+ return nil, trace.Wrap(err)
+ }
+
+ return func() {
+ if err := lockFile.Close(); err != nil {
+ slog.DebugContext(context.Background(), "failed to close lock file", "file", lockFile, "error", err)
+ }
+ }, nil
+}
+
+// sendInterrupt sends a Ctrl-Break event to the process.
+func sendInterrupt(pid int) error {
+ r, _, err := ctrlEvent.Call(uintptr(syscall.CTRL_BREAK_EVENT), uintptr(pid))
+ if r == 0 {
+ return trace.Wrap(err)
+ }
+ return nil
+}
+
+// freeDiskWithReserve returns the available disk space.
+func freeDiskWithReserve(dir string) (uint64, error) {
+ var avail uint64
+ err := windows.GetDiskFreeSpaceEx(windows.StringToUTF16Ptr(dir), &avail, nil, nil)
+ if err != nil {
+ return 0, trace.Wrap(err)
+ }
+ if reservedFreeDisk > avail {
+ return 0, trace.Errorf("no free space left")
+ }
+ return avail - reservedFreeDisk, nil
+}
diff --git a/tool/tctl/common/tctl.go b/tool/tctl/common/tctl.go
index 996449f8fc5e6..7f5a79aff37c7 100644
--- a/tool/tctl/common/tctl.go
+++ b/tool/tctl/common/tctl.go
@@ -54,6 +54,7 @@ import (
"github.com/gravitational/teleport/lib/service/servicecfg"
"github.com/gravitational/teleport/lib/utils"
"github.com/gravitational/teleport/tool/common"
+ "github.com/gravitational/teleport/tool/common/update"
)
const (
@@ -104,6 +105,27 @@ type CLICommand interface {
//
// distribution: name of the Teleport distribution
func Run(commands []CLICommand) {
+ // At process startup, check if a version has already been downloaded to
+ // $TELEPORT_HOME/bin or if the user has set the TELEPORT_TOOLS_VERSION
+ // environment variable. If so, re-exec that version of {tsh, tctl}.
+ toolsVersion, reexec := update.CheckLocal()
+ if reexec {
+ // Download the version of client tools required by the cluster. This
+ // is required if the user passed in the TELEPORT_TOOLS_VERSION
+ // explicitly.
+ if err := update.Download(toolsVersion); err != nil {
+ utils.FatalError(err)
+ }
+
+ // Re-execute client tools with the correct version of client tools.
+ code, err := update.Exec()
+ if err != nil {
+ utils.FatalError(err)
+ } else {
+ os.Exit(code)
+ }
+ }
+
err := TryRun(commands, os.Args[1:])
if err != nil {
var exitError *common.ExitCodeError
diff --git a/tool/todo.txt b/tool/todo.txt
new file mode 100644
index 0000000000000..d71e3bc5a78a5
--- /dev/null
+++ b/tool/todo.txt
@@ -0,0 +1,6 @@
+1. Windows support. File locking.
+2. tbot support.
+3. Errors and warnings for self-managed updates.
+4. Different editions of Teleport. OSS should download OSS, Enterprise should download Enterprise.
+5. Test migration path.
+6. Testing.
diff --git a/tool/tsh/common/tsh.go b/tool/tsh/common/tsh.go
index 500dcab53c79f..1abaa6918b286 100644
--- a/tool/tsh/common/tsh.go
+++ b/tool/tsh/common/tsh.go
@@ -96,6 +96,7 @@ import (
"github.com/gravitational/teleport/tool/common"
"github.com/gravitational/teleport/tool/common/fido2"
"github.com/gravitational/teleport/tool/common/touchid"
+ "github.com/gravitational/teleport/tool/common/update"
"github.com/gravitational/teleport/tool/common/webauthnwin"
)
@@ -652,6 +653,7 @@ const (
proxyKubeConfigEnvVar = "TELEPORT_KUBECONFIG"
noResumeEnvVar = "TELEPORT_NO_RESUME"
requestModeEnvVar = "TELEPORT_REQUEST_MODE"
+ toolsVersionEnvVar = "TELEPORT_TOOLS_VERSION"
clusterHelp = "Specify the Teleport cluster to connect"
browserHelp = "Set to 'none' to suppress browser opening on login"
@@ -695,6 +697,29 @@ func initLogger(cf *CLIConf) {
//
// DO NOT RUN TESTS that call Run() in parallel (unless you taken precautions).
func Run(ctx context.Context, args []string, opts ...CliOption) error {
+ // At process startup, check if a version has already been downloaded to
+ // $TELEPORT_HOME/bin or if the user has set the TELEPORT_TOOLS_VERSION
+ // environment variable. If so, re-exec that version of {tsh, tctl}.
+ toolsVersion, reexec := update.CheckLocal()
+ if reexec {
+ // Download the version of client tools required by the cluster. This
+ // is required if the user passed in the TELEPORT_TOOLS_VERSION
+ // explicitly.
+ if err := update.Download(toolsVersion); err != nil {
+ return trace.Wrap(err)
+ }
+
+ // Re-execute client tools with the correct version of client tools.
+ code, err := update.Exec()
+ if err != nil {
+ log.Debugf("Failed to re-exec client tool: %v.", err)
+ // TODO(russjones): Is 255 the correct error code here?
+ os.Exit(255)
+ } else {
+ os.Exit(code)
+ }
+ }
+
cf := CLIConf{
Context: ctx,
TracingProvider: tracing.NoopProvider(),
@@ -1832,7 +1857,15 @@ func onLogin(cf *CLIConf) error {
}
tc.HomePath = cf.HomePath
- // client is already logged in and profile is not expired
+ // The user is not logged in and has typed in `tsh --proxy=... login`, if
+ // the running binary needs to be updated, update and re-exec.
+ if profile == nil {
+ if err := updateAndRun(context.Background(), tc.WebProxyAddr); err != nil {
+ return trace.Wrap(err)
+ }
+ }
+
+ // The user is already logged in and the profile is not expired.
if profile != nil && !profile.IsExpired(time.Now()) {
switch {
// in case if nothing is specified, re-fetch kube clusters and print
@@ -1842,6 +1875,13 @@ func onLogin(cf *CLIConf) error {
// current status
case cf.Proxy == "" && cf.SiteName == "" && cf.DesiredRoles == "" && cf.RequestID == "" && cf.IdentityFileOut == "" ||
host(cf.Proxy) == host(profile.ProxyURL.Host) && cf.SiteName == profile.Cluster && cf.DesiredRoles == "" && cf.RequestID == "":
+
+ // The user has typed `tsh login`, if the running binary needs to
+ // be updated, update and re-exec.
+ if err := updateAndRun(context.Background(), tc.WebProxyAddr); err != nil {
+ return trace.Wrap(err)
+ }
+
_, err := tc.PingAndShowMOTD(cf.Context)
if err != nil {
return trace.Wrap(err)
@@ -1855,6 +1895,13 @@ func onLogin(cf *CLIConf) error {
// if the proxy names match but nothing else is specified; show motd and update active profile and kube configs
case host(cf.Proxy) == host(profile.ProxyURL.Host) &&
cf.SiteName == "" && cf.DesiredRoles == "" && cf.RequestID == "" && cf.IdentityFileOut == "":
+
+ // The user has typed `tsh login`, if the running binary needs to
+ // be updated, update and re-exec.
+ if err := updateAndRun(context.Background(), tc.WebProxyAddr); err != nil {
+ return trace.Wrap(err)
+ }
+
_, err := tc.PingAndShowMOTD(cf.Context)
if err != nil {
return trace.Wrap(err)
@@ -1880,6 +1927,7 @@ func onLogin(cf *CLIConf) error {
// but cluster is specified, treat this as selecting a new cluster
// for the same proxy
case (cf.Proxy == "" || host(cf.Proxy) == host(profile.ProxyURL.Host)) && cf.SiteName != "":
+
_, err := tc.PingAndShowMOTD(cf.Context)
if err != nil {
return trace.Wrap(err)
@@ -1910,6 +1958,7 @@ func onLogin(cf *CLIConf) error {
// but desired roles or request ID is specified, treat this as a
// privilege escalation request for the same login session.
case (cf.Proxy == "" || host(cf.Proxy) == host(profile.ProxyURL.Host)) && (cf.DesiredRoles != "" || cf.RequestID != "") && cf.IdentityFileOut == "":
+
_, err := tc.PingAndShowMOTD(cf.Context)
if err != nil {
return trace.Wrap(err)
@@ -1925,7 +1974,11 @@ func onLogin(cf *CLIConf) error {
// otherwise just pass through to standard login
default:
-
+ // The user is logged in and has typed in `tsh --proxy=... login`, if
+ // the running binary needs to be updated, update and re-exec.
+ if err := updateAndRun(context.Background(), tc.WebProxyAddr); err != nil {
+ return trace.Wrap(err)
+ }
}
}
@@ -5411,6 +5464,30 @@ const (
"https://goteleport.com/docs/access-controls/guides/headless/#troubleshooting"
)
+func updateAndRun(ctx context.Context, proxy string) error {
+ // If needed, download the new version of {tsh, tctl} and re-exec. Make
+ // sure to exit this process with the same exit code as the child process.
+ toolsVersion, reexec, err := update.CheckRemote(ctx, proxy)
+ if err != nil {
+ return trace.Wrap(err)
+ }
+ if reexec {
+ // Download the version of client tools required by the cluster.
+ if err := update.Download(toolsVersion); err != nil {
+ return trace.Wrap(err)
+ }
+
+ // Re-execute client tools with the correct version of client tools.
+ code, err := update.Exec()
+ if err != nil {
+ return trace.Wrap(err)
+ }
+ os.Exit(code)
+ }
+
+ return nil
+}
+
// Lock the process memory to prevent rsa keys and certificates in memory from being exposed in a swap.
func tryLockMemory(cf *CLIConf) error {
if cf.MlockMode == mlockModeAuto {