diff --git a/api/client/webclient/webclient.go b/api/client/webclient/webclient.go index 2ccb27d0cb72b..c501b92fe00d4 100644 --- a/api/client/webclient/webclient.go +++ b/api/client/webclient/webclient.go @@ -297,6 +297,10 @@ type PingResponse struct { ServerVersion string `json:"server_version"` // MinClientVersion is the minimum client version required by the server. MinClientVersion string `json:"min_client_version"` + // ToolsVersion defines the version of {tsh, tctl} for client auto-upgrade. + ToolsVersion string `json:"tools_version"` + // ToolsAutoupdate enables client autoupdate feature. + ToolsAutoupdate bool `json:"tools_autoupdate"` // ClusterName contains the name of the Teleport cluster. ClusterName string `json:"cluster_name"` diff --git a/integrations/terraform/go.mod b/integrations/terraform/go.mod index bc6b5fb8f4df1..0f248045cc81b 100644 --- a/integrations/terraform/go.mod +++ b/integrations/terraform/go.mod @@ -180,6 +180,7 @@ require ( github.com/google/go-tpm-tools v0.4.4 // indirect github.com/google/go-tspi v0.3.0 // indirect github.com/google/gofuzz v1.2.0 // indirect + github.com/google/renameio/v2 v2.0.0 // indirect github.com/google/s2a-go v0.1.8 // indirect github.com/google/safetext v0.0.0-20240104143208-7a7d9b3d812f // indirect github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 // indirect diff --git a/integrations/terraform/go.sum b/integrations/terraform/go.sum index 03c8fc34ff45b..2e396be5c6465 100644 --- a/integrations/terraform/go.sum +++ b/integrations/terraform/go.sum @@ -1259,7 +1259,6 @@ github.com/google/pprof v0.0.0-20210609004039-a478d1d731e9/go.mod h1:kpwsk12EmLe github.com/google/pprof v0.0.0-20210720184732-4bb14d4b1be1/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE= github.com/google/pprof v0.0.0-20240525223248-4bfdf5a9a2af h1:kmjWCqn2qkEml422C2Rrd27c3VGxi6a/6HNq8QmHRKM= github.com/google/pprof v0.0.0-20240525223248-4bfdf5a9a2af/go.mod h1:K1liHPHnj73Fdn/EKuT8nrFqBihUSKXoLYU0BuatOYo= -github.com/google/renameio v0.1.0 h1:GOZbcHa3HfsPKPlmyPyN2KEohoMXOhdMbHrvbpl2QaA= github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI= github.com/google/renameio/v2 v2.0.0 h1:UifI23ZTGY8Tt29JbYFiuyIU3eX+RNFtUwefq9qAhxg= github.com/google/renameio/v2 v2.0.0/go.mod h1:BtmJXm5YlszgC+TD4HOEEUFgkJP3nLxehU6hfe7jRt4= diff --git a/lib/client/api.go b/lib/client/api.go index 74fd055476196..ae9cf5af4bffb 100644 --- a/lib/client/api.go +++ b/lib/client/api.go @@ -93,6 +93,7 @@ import ( "github.com/gravitational/teleport/lib/tlsca" "github.com/gravitational/teleport/lib/utils" "github.com/gravitational/teleport/lib/utils/agentconn" + "github.com/gravitational/teleport/tool/common/update" ) const ( @@ -684,6 +685,31 @@ func RetryWithRelogin(ctx context.Context, tc *TeleportClient, fn func() error, return trace.Wrap(err) } + // The user has typed a command like `tsh ssh ...` without being logged in, + // if the running binary needs to be updated, update and re-exec. + // + // If needed, download the new version of {tsh, tctl} and re-exec. Make + // sure to exit this process with the same exit code as the child process. + // + toolsVersion, reexec, err := update.CheckRemote(ctx, tc.WebProxyAddr) + if err != nil { + return trace.Wrap(err) + } + if reexec { + // Download the version of client tools required by the cluster. + err := update.Download(toolsVersion) + if err != nil { + return trace.Wrap(err) + } + + // Re-execute client tools with the correct version of client tools. + code, err := update.Exec() + if err != nil { + return trace.Wrap(err) + } + os.Exit(code) + } + if opt.afterLoginHook != nil { if err := opt.afterLoginHook(); err != nil { return trace.Wrap(err) diff --git a/tool/common/update/client.go b/tool/common/update/client.go new file mode 100644 index 0000000000000..b36e023fe894e --- /dev/null +++ b/tool/common/update/client.go @@ -0,0 +1,61 @@ +/* + * Teleport + * Copyright (C) 2024 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package update + +import ( + "crypto/tls" + "crypto/x509" + "net/http" + "net/url" + "time" + + "golang.org/x/net/http/httpproxy" + + apidefaults "github.com/gravitational/teleport/api/defaults" + tracehttp "github.com/gravitational/teleport/api/observability/tracing/http" + apiutils "github.com/gravitational/teleport/api/utils" +) + +type downloadConfig struct { + // Insecure turns off TLS certificate verification when enabled. + Insecure bool + // Pool defines the set of root CAs to use when verifying server + // certificates. + Pool *x509.CertPool + // Timeout is a timeout for requests. + Timeout time.Duration +} + +func newClient(cfg *downloadConfig) *http.Client { + rt := apiutils.NewHTTPRoundTripper(&http.Transport{ + TLSClientConfig: &tls.Config{ + InsecureSkipVerify: cfg.Insecure, + RootCAs: cfg.Pool, + }, + Proxy: func(req *http.Request) (*url.URL, error) { + return httpproxy.FromEnvironment().ProxyFunc()(req.URL) + }, + IdleConnTimeout: apidefaults.DefaultIOTimeout, + }, nil) + + return &http.Client{ + Transport: tracehttp.NewTransport(rt), + Timeout: cfg.Timeout, + } +} diff --git a/tool/common/update/feature_ent.go b/tool/common/update/feature_ent.go new file mode 100644 index 0000000000000..9e9b10cab2287 --- /dev/null +++ b/tool/common/update/feature_ent.go @@ -0,0 +1,26 @@ +//go:build webassets_ent +// +build webassets_ent + +/* + * Teleport + * Copyright (C) 2024 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package update + +func init() { + featureFlag |= FlagEnt +} diff --git a/tool/common/update/feature_fips.go b/tool/common/update/feature_fips.go new file mode 100644 index 0000000000000..5fae3c6936933 --- /dev/null +++ b/tool/common/update/feature_fips.go @@ -0,0 +1,26 @@ +//go:build fips +// +build fips + +/* + * Teleport + * Copyright (C) 2024 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package update + +func init() { + featureFlag |= FlagFips +} diff --git a/tool/common/update/integration/main.go b/tool/common/update/integration/main.go new file mode 100644 index 0000000000000..ac67fca48f657 --- /dev/null +++ b/tool/common/update/integration/main.go @@ -0,0 +1,63 @@ +/* + * Teleport + * Copyright (C) 2024 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package main + +import ( + "errors" + "fmt" + "os" + + log "github.com/sirupsen/logrus" + + "github.com/gravitational/teleport/tool/common/update" +) + +var version = "development" + +func main() { + // At process startup, check if a version has already been downloaded to + // $TELEPORT_HOME/bin or if the user has set the TELEPORT_TOOLS_VERSION + // environment variable. If so, re-exec that version of {tsh, tctl}. + toolsVersion, reExec := update.CheckLocal() + if reExec { + // Download the version of client tools required by the cluster. This + // is required if the user passed in the TELEPORT_TOOLS_VERSION + // explicitly. + err := update.Download(toolsVersion) + if errors.Is(err, update.ErrCanceled) { + os.Exit(0) + return + } + if err != nil { + log.Fatalf("Failed to download version (%v): %v", toolsVersion, err) + return + } + + // Re-execute client tools with the correct version of client tools. + code, err := update.Exec() + if err != nil { + log.Fatalf("Failed to re-exec client tool: %v", err) + } else { + os.Exit(code) + } + } + if len(os.Args) > 1 && os.Args[1] == "version" { + fmt.Printf("Teleport v%v git", version) + } +} diff --git a/tool/common/update/progress.go b/tool/common/update/progress.go new file mode 100644 index 0000000000000..a1ce6f66a5d4d --- /dev/null +++ b/tool/common/update/progress.go @@ -0,0 +1,80 @@ +/* + * Teleport + * Copyright (C) 2024 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package update + +import ( + "fmt" + "io" + "os" + "os/signal" + "strings" +) + +var ( + ErrCanceled = fmt.Errorf("canceled") +) + +// cancelableTeeReader is a copy of TeeReader with ability to react on signal notifier +// to cancel reading process. +func cancelableTeeReader(r io.Reader, w io.Writer, signals ...os.Signal) io.Reader { + sigs := make(chan os.Signal, 1) + signal.Notify(sigs, signals...) + + return &teeReader{r, w, sigs} +} + +type teeReader struct { + r io.Reader + w io.Writer + sigs chan os.Signal +} + +func (t *teeReader) Read(p []byte) (n int, err error) { + select { + case <-t.sigs: + return 0, ErrCanceled + default: + n, err = t.r.Read(p) + if n > 0 { + if n, err := t.w.Write(p[:n]); err != nil { + return n, err + } + } + } + return +} + +type progressWriter struct { + n int64 + limit int64 +} + +func (w *progressWriter) Write(p []byte) (int, error) { + w.n = w.n + int64(len(p)) + + n := int((w.n*100)/w.limit) / 10 + bricks := strings.Repeat("▒", n) + strings.Repeat(" ", 10-n) + fmt.Printf("\rUpdate progress: [" + bricks + "] (Ctrl-C to cancel update)") + + if w.n == w.limit { + fmt.Printf("\n") + } + + return len(p), nil +} diff --git a/tool/common/update/update.go b/tool/common/update/update.go new file mode 100644 index 0000000000000..3a4f97ce6800e --- /dev/null +++ b/tool/common/update/update.go @@ -0,0 +1,431 @@ +/* + * Teleport + * Copyright (C) 2024 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package update + +import ( + "bufio" + "bytes" + "context" + "crypto/sha256" + "crypto/x509" + "encoding/hex" + "fmt" + "io" + "net/http" + "os" + "os/exec" + "path/filepath" + "regexp" + "runtime" + "strings" + "syscall" + "time" + + "github.com/coreos/go-semver/semver" + "github.com/gravitational/trace" + log "github.com/sirupsen/logrus" + + "github.com/gravitational/teleport" + "github.com/gravitational/teleport/api/client/webclient" + "github.com/gravitational/teleport/api/types" +) + +const ( + teleportToolsVersion = "TELEPORT_TOOLS_VERSION" + checksumHexLen = 64 + reservedFreeDisk = 10 * 1024 * 1024 + + FlagEnt = 1 << 0 + FlagFips = 1 << 1 +) + +var ( + pattern = regexp.MustCompile(`(?m)Teleport v(.*) git`) + baseUrl = "https://cdn.teleport.dev" + defaultClient = newClient(&downloadConfig{}) + featureFlag int +) + +// CheckLocal is run at client tool startup and will only perform local checks. +func CheckLocal() (string, bool) { + // If a version of client tools has already been downloaded to + // $TELEPORT_HOME/bin, return that. + toolsVersion, err := version() + if err != nil { + return "", false + } + + // Check if the user has requested a specific version of client tools. + requestedVersion := os.Getenv(teleportToolsVersion) + switch { + // The user has turned off any form of automatic updates. + case requestedVersion == "off": + return "", false + // Requested version already the same as client version. + case teleport.Version == requestedVersion: + return requestedVersion, false + // The user has requested a specific version of client tools. + case requestedVersion != "" && requestedVersion != toolsVersion: + return requestedVersion, true + } + + return toolsVersion, false +} + +// CheckRemote will check against Proxy Service if client tools need to be +// updated. +func CheckRemote(ctx context.Context, proxyAddr string) (string, bool, error) { + certPool, err := x509.SystemCertPool() + if err != nil { + return "", false, trace.Wrap(err) + } + resp, err := webclient.Find(&webclient.Config{ + Context: ctx, + ProxyAddr: proxyAddr, + Pool: certPool, + Timeout: 30 * time.Second, + }) + if err != nil { + return "", false, trace.Wrap(err) + } + + // If a version of client tools has already been downloaded to + // $TELEPORT_HOME/bin, return that. + toolsVersion, err := version() + if err != nil { + return "", false, nil + } + + requestedVersion := os.Getenv(teleportToolsVersion) + switch { + // The user has turned off any form of automatic updates. + case requestedVersion == "off": + return "", false, nil + // Requested version already the same as client version. + case teleport.Version == requestedVersion: + return requestedVersion, false, nil + case requestedVersion != "" && requestedVersion != toolsVersion: + return requestedVersion, true, nil + case !resp.ToolsAutoupdate || resp.ToolsVersion == "": + return "", false, nil + case teleport.Version == resp.ToolsVersion: + return resp.ToolsVersion, false, nil + case resp.ToolsVersion != toolsVersion: + return resp.ToolsVersion, true, nil + } + + return toolsVersion, false, nil +} + +// Download downloads requested version package, unarchive and replace existing one. +func Download(toolsVersion string) error { + // If the version of the running binary or the version downloaded to + // $TELEPORT_HOME/bin is the same as the requested version of client tools, + // nothing to be done, exit early. + teleportVersion, err := version() + if err != nil { + if !trace.IsNotFound(err) { + return trace.Wrap(err) + } + } + if toolsVersion == teleport.Version || toolsVersion == teleportVersion { + return nil + } + + // Create $TELEPORT_HOME/bin if it does not exist. + dir, err := toolsDir() + if err != nil { + return trace.Wrap(err) + } + if err := os.MkdirAll(dir, 0755); err != nil { + return trace.Wrap(err) + } + + // Download and update {tsh, tctl} in $TELEPORT_HOME/bin. + if err := update(toolsVersion); err != nil { + return trace.Wrap(err) + } + + return nil +} + +func update(toolsVersion string) error { + dir, err := toolsDir() + if err != nil { + return trace.Wrap(err) + } + // Lock to allow multiple concurrent {tsh, tctl} to run. + unlock, err := lock(dir) + if err != nil { + return trace.Wrap(err) + } + defer unlock() + + // Get platform specific download URLs. + archiveURL, hashURL, err := urls(toolsVersion) + if err != nil { + return trace.Wrap(err) + } + log.Debugf("Archive download path: %v.", archiveURL) + + // Download the archive and validate against the hash. Download to a + // temporary path within $TELEPORT_HOME/bin. + hash, err := downloadHash(hashURL) + if err != nil { + return trace.Wrap(err) + } + path, err := downloadArchive(archiveURL, hash) + if err != nil { + return trace.Wrap(err) + } + defer os.Remove(path) + + // Perform atomic replace so concurrent exec do not fail. + if err := replace(path, hash); err != nil { + return trace.Wrap(err) + } + + return nil +} + +// urls returns the URL for the Teleport archive to download. The format is: +// https://cdn.teleport.dev/teleport-{, ent-}v15.3.0-{linux, darwin, windows}-{amd64,arm64,arm,386}-{fips-}bin.tar.gz +func urls(toolsVersion string) (string, string, error) { + var archive string + + switch runtime.GOOS { + case "darwin": + archive = baseUrl + "/tsh-" + toolsVersion + ".pkg" + case "windows": + archive = baseUrl + "/teleport-v" + toolsVersion + "-windows-amd64-bin.zip" + case "linux": + edition := "" + if featureFlag&FlagEnt != 0 || featureFlag&FlagFips != 0 { + edition = "ent-" + } + fips := "" + if featureFlag&FlagFips != 0 { + fips = "fips-" + } + + var b strings.Builder + b.WriteString(baseUrl + "/teleport-") + if edition != "" { + b.WriteString(edition) + } + b.WriteString("v" + toolsVersion + "-" + runtime.GOOS + "-" + runtime.GOARCH + "-") + if fips != "" { + b.WriteString(fips) + } + b.WriteString("bin.tar.gz") + archive = b.String() + default: + return "", "", trace.BadParameter("unsupported runtime: %v", runtime.GOOS) + } + + return archive, archive + ".sha256", nil +} + +func downloadHash(url string) (string, error) { + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) + if err != nil { + return "", trace.Wrap(err) + } + resp, err := defaultClient.Do(req) + if err != nil { + return "", trace.Wrap(err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return "", trace.BadParameter("request failed with: %v", resp.StatusCode) + } + + var buf bytes.Buffer + _, err = io.CopyN(&buf, resp.Body, checksumHexLen) + if err != nil { + return "", trace.Wrap(err) + } + raw := buf.String() + if _, err = hex.DecodeString(raw); err != nil { + return "", trace.Wrap(err) + } + return raw, nil +} + +func downloadArchive(url string, hash string) (string, error) { + resp, err := defaultClient.Get(url) + if err != nil { + return "", trace.Wrap(err) + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + return "", trace.BadParameter("bad status when downloading archive: %v", resp.StatusCode) + } + + dir, err := toolsDir() + if err != nil { + return "", trace.Wrap(err) + } + + if resp.ContentLength != -1 { + if err := checkFreeSpace(dir, uint64(resp.ContentLength)); err != nil { + return "", trace.Wrap(err) + } + } + + // Caller of this function will remove this file after the atomic swap has + // occurred. + f, err := os.CreateTemp(dir, "tmp-") + if err != nil { + return "", trace.Wrap(err) + } + + h := sha256.New() + pw := &progressWriter{n: 0, limit: resp.ContentLength} + body := cancelableTeeReader(io.TeeReader(resp.Body, h), pw, syscall.SIGINT, syscall.SIGTERM) + + // It is a little inefficient to download the file to disk and then re-load + // it into memory to unarchive later, but this is safer as it allows {tsh, + // tctl} to validate the hash before trying to operate on the archive. + _, err = io.Copy(f, body) + if err != nil { + return "", trace.Wrap(err) + } + if fmt.Sprintf("%x", h.Sum(nil)) != hash { + return "", trace.BadParameter("hash of archive does not match downloaded archive") + } + + return f.Name(), nil +} + +// Exec re-executes tool command with same arguments and environ variables. +func Exec() (int, error) { + path, err := toolName() + if err != nil { + return 0, trace.Wrap(err) + } + + cmd := exec.Command(path, os.Args[1:]...) + cmd.Env = os.Environ() + // To prevent re-execution loop we have to disable update logic for re-execution. + cmd.Env = append(cmd.Env, teleportToolsVersion+"=off") + cmd.Stdin = os.Stdin + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + + if err := cmd.Run(); err != nil { + return 0, trace.Wrap(err) + } + + return cmd.ProcessState.ExitCode(), nil +} + +func version() (string, error) { + // Find the path to the current executable. + path, err := toolName() + if err != nil { + return "", trace.Wrap(err) + } + + // Set a timeout to not let "{tsh, tctl} version" block forever. Allow up + // to 10 seconds because sometimes MDM tools like Jamf cause a lot of + // latency in launching binaries. + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + // Execute "{tsh, tctl} version" and pass in TELEPORT_TOOLS_VERSION=off to + // turn off all automatic updates code paths to prevent any recursion. + command := exec.CommandContext(ctx, path, "version") + command.Env = []string{teleportToolsVersion + "=off"} + output, err := command.Output() + if err != nil { + return "", trace.Wrap(err) + } + + // The output for "{tsh, tctl} version" can be multiple lines. Find the + // actual version line and extract the version. + scanner := bufio.NewScanner(bytes.NewReader(output)) + for scanner.Scan() { + line := scanner.Text() + + if !strings.HasPrefix(line, "Teleport") { + continue + } + + matches := pattern.FindStringSubmatch(line) + if len(matches) != 2 { + return "", trace.BadParameter("invalid version line: %v", line) + } + version, err := semver.NewVersion(matches[1]) + if err != nil { + return "", trace.Wrap(err) + } + return version.String(), nil + } + + return "", trace.BadParameter("unable to determine version") +} + +// toolsDir returns the path to {tsh, tctl} in $TELEPORT_HOME/bin. +func toolsDir() (string, error) { + home := os.Getenv(types.HomeEnvVar) + if home == "" { + var err error + home, err = os.UserHomeDir() + if err != nil { + return "", trace.Wrap(err) + } + } + + return filepath.Join(filepath.Clean(home), ".tsh", "bin"), nil +} + +// toolName returns the path to {tsh, tctl} for the executable that started +// the current process. +func toolName() (string, error) { + base, err := toolsDir() + if err != nil { + return "", trace.Wrap(err) + } + + executablePath, err := os.Executable() + if err != nil { + return "", trace.Wrap(err) + } + toolName := filepath.Base(executablePath) + + return filepath.Join(base, toolName), nil +} + +// checkFreeSpace verifies that we have enough requested space at specific directory. +func checkFreeSpace(path string, requested uint64) error { + free, err := freeDiskWithReserve(path) + if err != nil { + return trace.Errorf("failed to calculate free disk in %q: %v", path, err) + } + // Bail if there's not enough free disk space at the target. + if requested > free { + return trace.Errorf("%q needs %d additional bytes of disk space", path, requested-free) + } + return nil +} diff --git a/tool/common/update/update_darwin.go b/tool/common/update/update_darwin.go new file mode 100644 index 0000000000000..252d40b8c8093 --- /dev/null +++ b/tool/common/update/update_darwin.go @@ -0,0 +1,110 @@ +//go:build darwin + +/* + * Teleport + * Copyright (C) 2024 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package update + +import ( + "os" + "os/exec" + "path/filepath" + "strings" + + "github.com/google/renameio/v2" + "github.com/gravitational/trace" + log "github.com/sirupsen/logrus" +) + +func replace(path string, hash string) error { + dir, err := toolsDir() + if err != nil { + return trace.Wrap(err) + } + + // Use "pkgutil" from the filesystem to expand the archive. In theory .pkg + // files are xz archives, however it's still safer to use "pkgutil" in-case + // Apple makes non-standard changes to the format. + // + // Full command: pkgutil --expand-full NAME.pkg DIRECTORY/ + pkgutil, err := exec.LookPath("pkgutil") + if err != nil { + return trace.Wrap(err) + } + expandPath := filepath.Join(dir, hash+"-pkg") + out, err := exec.Command(pkgutil, "--expand-full", path, expandPath).Output() + if err != nil { + log.Debugf("Failed to run pkgutil: %v.", out) + return trace.Wrap(err) + } + + for _, app := range []string{"tsh", "tctl"} { + // The first time a signed and notarized binary macOS application is run, + // execution is paused while it gets sent to Apple to verify. Once Apple + // approves the binary, the "com.apple.macl" extended attribute is added + // and the process is allowed to execute. This process is not concurrent, any + // other operations (like moving the application) on the application during + // this time will lead to the application being sent SIGKILL. + // + // Since {tsh, tctl} have to be concurrent, execute {tsh, tctl} before + // performing any swap operations. This ensures that the "com.apple.macl" + // extended attribute is set and macOS will not send a SIGKILL to the + // process if multiple processes are trying to operate on it. + expandExecPath := filepath.Join(expandPath, "Payload", app+".app", "Contents", "MacOS", app) + command := exec.Command(expandExecPath, "version", "--client") + command.Env = []string{teleportToolsVersion + "=off"} + if err := command.Run(); err != nil { + return trace.Wrap(err) + } + + // Due to macOS applications not being a single binary (they are a + // directory), atomic operations are not possible. To work around this, use + // a symlink (which can be atomically swapped), then do a cleanup pass + // removing any stale copies of the expanded package. + oldName := filepath.Join(expandPath, "Payload", app+".app", "Contents", "MacOS", app) + newName := filepath.Join(dir, app) + if err := renameio.Symlink(oldName, newName); err != nil { + return trace.Wrap(err) + } + } + + // Perform a cleanup pass to remove any old copies of "{tsh, tctl}.app". + err = filepath.Walk(dir, func(path string, info os.FileInfo, err error) error { + if err != nil { + return err + } + if !info.IsDir() { + return nil + } + if hash+"-pkg" == info.Name() { + return nil + } + if !strings.HasSuffix(info.Name(), "-pkg") { + return nil + } + + // Found a stale expanded package. + if err := os.RemoveAll(filepath.Join(dir, info.Name())); err != nil { + return err + } + + return nil + }) + + return nil +} diff --git a/tool/common/update/update_helper_test.go b/tool/common/update/update_helper_test.go new file mode 100644 index 0000000000000..2eb1b050e0aa9 --- /dev/null +++ b/tool/common/update/update_helper_test.go @@ -0,0 +1,87 @@ +/* + * Teleport + * Copyright (C) 2024 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package update + +import ( + "net/http" + "sync" +) + +type limitRequest struct { + limit int64 + lock chan struct{} +} + +// limitedResponseWriter wraps http.ResponseWriter and enforces a write limit. +type limitedResponseWriter struct { + requests chan limitRequest +} + +// newLimitedResponseWriter creates a new limitedResponseWriter with the lock. +func newLimitedResponseWriter() *limitedResponseWriter { + lw := &limitedResponseWriter{ + requests: make(chan limitRequest, 10), + } + return lw +} + +// Wrap wraps response writer if limit was previously requested, if not, return original one. +func (lw *limitedResponseWriter) Wrap(w http.ResponseWriter) http.ResponseWriter { + select { + case request := <-lw.requests: + return &wrapper{ + ResponseWriter: w, + request: request, + } + default: + return w + } +} + +func (lw *limitedResponseWriter) SetLimitRequest(limit limitRequest) { + lw.requests <- limit +} + +// wrapper wraps the http response writer to control writing operation by blocking it. +type wrapper struct { + http.ResponseWriter + + written int64 + request limitRequest + released bool + + mutex sync.Mutex +} + +// Write writes data to the underlying ResponseWriter but respects the byte limit. +func (lw *wrapper) Write(p []byte) (int, error) { + lw.mutex.Lock() + defer lw.mutex.Unlock() + + if lw.written >= lw.request.limit && !lw.released { + // Send signal that lock is acquired and wait till it was released by response. + lw.request.lock <- struct{}{} + <-lw.request.lock + lw.released = true + } + + n, err := lw.ResponseWriter.Write(p) + lw.written += int64(n) + return n, err +} diff --git a/tool/common/update/update_linux.go b/tool/common/update/update_linux.go new file mode 100644 index 0000000000000..8e02b9bc407e5 --- /dev/null +++ b/tool/common/update/update_linux.go @@ -0,0 +1,93 @@ +//go:build linux + +/* + * Teleport + * Copyright (C) 2024 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package update + +import ( + "archive/tar" + "compress/gzip" + "errors" + "io" + "os" + "path/filepath" + "strings" + + "github.com/google/renameio/v2" + "github.com/gravitational/trace" + log "github.com/sirupsen/logrus" +) + +func replace(path string, _ string) error { + dir, err := toolsDir() + if err != nil { + return trace.Wrap(err) + } + tempDir := renameio.TempDir(dir) + + f, err := os.Open(path) + if err != nil { + return trace.Wrap(err) + } + + gzipReader, err := gzip.NewReader(f) + if err != nil { + return trace.Wrap(err) + } + + tarReader := tar.NewReader(gzipReader) + for { + header, err := tarReader.Next() + if errors.Is(err, io.EOF) { + break + } + // Skip over any files in the archive that are not {tsh, tctl}. + if header.Name != "teleport/tctl" && + header.Name != "teleport/tsh" && + header.Name != "teleport/tbot" { + if _, err := io.Copy(io.Discard, tarReader); err != nil { + log.Debugf("failed to discard %v: %v.", header.Name, err) + } + continue + } + + // Verify that we have enough space for uncompressed file. + if err := checkFreeSpace(tempDir, uint64(header.Size)); err != nil { + return trace.Wrap(err) + } + + dest := filepath.Join(dir, strings.TrimPrefix(header.Name, "teleport/")) + t, err := renameio.TempFile(tempDir, dest) + if err != nil { + return trace.Wrap(err) + } + if err := os.Chmod(t.Name(), 0755); err != nil { + return trace.Wrap(err) + } + defer t.Cleanup() + + if _, err := io.Copy(t, tarReader); err != nil { + return trace.Wrap(err) + } + if err := t.CloseAtomicallyReplace(); err != nil { + return trace.Wrap(err) + } + } + return nil +} diff --git a/tool/common/update/update_test.go b/tool/common/update/update_test.go new file mode 100644 index 0000000000000..ac7cda653a3f5 --- /dev/null +++ b/tool/common/update/update_test.go @@ -0,0 +1,409 @@ +/* + * Teleport + * Copyright (C) 2024 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package update + +import ( + "archive/tar" + "archive/zip" + "bytes" + "compress/gzip" + "crypto/sha256" + "encoding/hex" + "errors" + "fmt" + "io" + "log" + "net" + "net/http" + "os" + "os/exec" + "path/filepath" + "runtime" + "strings" + "sync/atomic" + "testing" + "time" + + "github.com/gravitational/trace" + "github.com/stretchr/testify/require" +) + +const ( + testBinaryName = "updater" +) + +var ( + // testVersions list of the pre-compiled binaries with encoded versions to check. + testVersions = []string{ + "1.2.3", + "3.2.1", + } + limitedWriter = newLimitedResponseWriter() +) + +// TestLock verifies that second lock call is blocked until first is released. +func TestLock(t *testing.T) { + var locked atomic.Bool + + dir := os.TempDir() + + // Acquire first lock should not return any error. + unlock, err := lock(dir) + require.NoError(t, err) + locked.Store(true) + + signal := make(chan struct{}) + errChan := make(chan error) + go func() { + signal <- struct{}{} + unlock, err := lock(dir) + if err != nil { + errChan <- err + } + if locked.Load() { + errChan <- fmt.Errorf("first lock is still acquired, second lock must be blocking") + } + unlock() + signal <- struct{}{} + }() + + <-signal + // We have to wait till next lock is reached to ensure we block execution of goroutine. + // Since this is system call we can't track if the function reach blocking state already. + time.Sleep(100 * time.Millisecond) + locked.Store(false) + unlock() + + select { + case <-signal: + case err := <-errChan: + require.NoError(t, err) + case <-time.After(5 * time.Second): + t.Errorf("second lock is not released") + } +} + +// TestUpdateInterruptSignal verifies the interrupt signal send to the process must stop downloading. +func TestUpdateInterruptSignal(t *testing.T) { + dir, err := toolsDir() + require.NoError(t, err, "failed to find tools directory") + + err = os.MkdirAll(dir, 0755) + require.NoError(t, err, "failed to create tools directory") + + // Initial fetch the updater binary un-archive and replace. + err = update(testVersions[0]) + require.NoError(t, err) + + var output bytes.Buffer + cmd := exec.Command(filepath.Join(dir, "tsh"), "version") + cmd.Stdout = &output + cmd.Stderr = &output + cmd.Env = append( + os.Environ(), + fmt.Sprintf("%s=%s", teleportToolsVersion, testVersions[1]), + ) + err = cmd.Start() + require.NoError(t, err, "failed to start updater") + pid := cmd.Process.Pid + + errChan := make(chan error) + go func() { + errChan <- cmd.Wait() + }() + + // By setting the limit request next test http serving file going blocked until unlock is sent. + lock := make(chan struct{}) + limitedWriter.SetLimitRequest(limitRequest{ + limit: 1024, + lock: lock, + }) + + select { + case err := <-errChan: + require.NoError(t, err) + case <-time.After(5 * time.Second): + t.Errorf("failed to wait till the download is started") + case <-lock: + time.Sleep(100 * time.Millisecond) + require.NoError(t, sendInterrupt(pid)) + lock <- struct{}{} + } + + // Wait till process finished with exit code 0, but we still should get progress + // bar in output content. + select { + case <-time.After(5 * time.Second): + t.Errorf("failed to wait till the process interrupted") + case err := <-errChan: + require.NoError(t, err) + } + require.Contains(t, output.String(), "Update progress:") +} + +func TestUpdate(t *testing.T) { + dir, err := toolsDir() + require.NoError(t, err, "failed to find tools directory") + + err = os.MkdirAll(dir, 0755) + require.NoError(t, err, "failed to create tools directory") + + // Fetch compiled test binary with updater logic and install to $TELEPORT_HOME. + err = update(testVersions[0]) + require.NoError(t, err) + + // Verify that the installed version is equal to requested one. + cmd := exec.Command(filepath.Join(dir, "tsh"), "version") + out, err := cmd.Output() + require.NoError(t, err) + + matches := pattern.FindStringSubmatch(string(out)) + require.Len(t, matches, 2) + require.Equal(t, testVersions[0], matches[1]) + + // Execute version command again with setting the new version which must + // trigger re-execution of the same command after downloading requested version. + cmd = exec.Command(filepath.Join(dir, "tsh"), "version") + cmd.Env = append( + os.Environ(), + fmt.Sprintf("%s=%s", teleportToolsVersion, testVersions[1]), + ) + out, err = cmd.Output() + require.NoError(t, err) + + matches = pattern.FindStringSubmatch(string(out)) + require.Len(t, matches, 2) + require.Equal(t, testVersions[1], matches[1]) +} + +func TestMain(m *testing.M) { + tmp, err := os.MkdirTemp(os.TempDir(), testBinaryName) + if err != nil { + log.Fatalf("failed to create temporary directory: %v", err) + } + + srv, address := startTestHTTPServer(tmp) + baseUrl = fmt.Sprintf("http://%s", address) + + for _, version := range testVersions { + if err := buildAndArchiveApps(tmp, version, address); err != nil { + log.Fatalf("failed to build testing app binary archive: %v", err) + } + } + + // Run tests after binary is built. + code := m.Run() + + if err := srv.Close(); err != nil { + log.Fatalf("failed to shutdown server: %v", err) + } + if err := os.RemoveAll(tmp); err != nil { + log.Fatalf("failed to remove temporary directory: %v", err) + } + + os.Exit(code) +} + +// serve256File calculates sha256 checksum for requested file. +func serve256File(w http.ResponseWriter, _ *http.Request, filePath string) { + log.Printf("Calculating and serving file checksum: %s\n", filePath) + + w.Header().Set("Content-Disposition", "attachment; filename=\""+filepath.Base(filePath)+".sha256\"") + w.Header().Set("Content-Type", "plain/text") + + hash := sha256.New() + file, err := os.Open(filePath) + if err != nil { + http.Error(w, "failed to open file", http.StatusInternalServerError) + return + } + defer file.Close() + + if _, err := io.Copy(hash, file); err != nil { + http.Error(w, "failed to write to hash", http.StatusInternalServerError) + return + } + if _, err := hex.NewEncoder(w).Write(hash.Sum(nil)); err != nil { + http.Error(w, "failed to write checksum", http.StatusInternalServerError) + } +} + +// generateZipFile compresses the file into a `.zip` format. This format intended to be +// used only for windows platform and mocking paths for windows archive. +func generateZipFile(filePath, destPath string) error { + archive, err := os.Create(destPath) + if err != nil { + return trace.Wrap(err) + } + defer archive.Close() + + zipWriter := zip.NewWriter(archive) + defer zipWriter.Close() + + return filepath.Walk(filePath, func(path string, info os.FileInfo, err error) error { + if err != nil { + return err + } + if info.IsDir() { + return nil + } + file, err := os.Open(path) + if err != nil { + return err + } + defer file.Close() + + zipFileWriter, err := zipWriter.Create(filepath.Base(path)) + if err != nil { + return trace.Wrap(err) + } + + _, err = io.Copy(zipFileWriter, file) + return trace.Wrap(err) + }) +} + +// generateTarGzFile compresses files into a `.tar.gz` format specifically in file +// structure related to linux packaging. +func generateTarGzFile(filePath, destPath string) error { + archive, err := os.Create(destPath) + if err != nil { + return trace.Wrap(err) + } + defer archive.Close() + + gzipWriter := gzip.NewWriter(archive) + defer gzipWriter.Close() + + tarWriter := tar.NewWriter(gzipWriter) + defer tarWriter.Close() + + return filepath.Walk(filePath, func(path string, info os.FileInfo, err error) error { + if err != nil { + return err + } + if info.IsDir() { + return nil + } + file, err := os.Open(path) + if err != nil { + return err + } + defer file.Close() + + header, err := tar.FileInfoHeader(info, info.Name()) + if err != nil { + return err + } + header.Name = filepath.Join("teleport", filepath.Base(info.Name())) + if err := tarWriter.WriteHeader(header); err != nil { + return err + } + + _, err = io.Copy(tarWriter, file) + return trace.Wrap(err) + }) +} + +// generatePkgFile runs the macOS `pkgbuild` command to generate a .pkg file from the source. +func generatePkgFile(filePath, destPath string) error { + cmd := exec.Command("pkgbuild", + "--root", filePath, + "--identifier", "com.example.pkgtest", + "--version", "1.0", + destPath, + ) + + output, err := cmd.CombinedOutput() + if err != nil { + log.Printf("failed to generate .pkg: %s\n", output) + return err + } + + return nil +} + +// startTestHTTPServer starts the file-serving HTTP server for testing. +func startTestHTTPServer(baseDir string) (*http.Server, string) { + srv := &http.Server{Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + filePath := filepath.Join(baseDir, r.URL.Path) + switch { + case strings.HasSuffix(r.URL.Path, ".sha256"): + serve256File(w, r, strings.TrimSuffix(filePath, ".sha256")) + default: + http.ServeFile(limitedWriter.Wrap(w), r, filePath) + } + })} + + listener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + log.Fatalf("failed to create listener: %v", err) + } + + go func() { + if err := srv.Serve(listener); err != nil && !errors.Is(err, http.ErrServerClosed) { + log.Printf("failed to start server: %s", err) + } + }() + + return srv, listener.Addr().String() +} + +// buildAndArchiveApps compiles the updater integration and pack it depends on platform is used. +func buildAndArchiveApps(path string, version string, address string) error { + versionPath := filepath.Join(path, version) + for _, app := range []string{"tsh", "tctl"} { + output := filepath.Join(versionPath, app) + switch runtime.GOOS { + case "windows": + output = filepath.Join(versionPath, app+".exe") + case "darwin": + output = filepath.Join(versionPath, app+".app", "Contents", "MacOS", app) + } + if err := buildBinary(output, version, address); err != nil { + return trace.Wrap(err) + } + } + switch runtime.GOOS { + case "darwin": + return trace.Wrap(generatePkgFile(versionPath, path+"/tsh-"+version+".pkg")) + case "windows": + return trace.Wrap(generateZipFile(versionPath, path+"/teleport-v"+version+"-windows-amd64-bin.zip")) + case "linux": + return trace.Wrap(generateTarGzFile(versionPath, path+"/teleport-v"+version+"-linux-"+runtime.GOARCH+"-bin.tar.gz")) + default: + return trace.BadParameter("unsupported platform") + } +} + +// buildBinary executes command to build updater integration logic for testing. +func buildBinary(output string, version string, address string) error { + cmd := exec.Command( + "go", "build", "-o", output, + "-ldflags", strings.Join([]string{ + fmt.Sprintf("-X 'main.version=%s'", version), + fmt.Sprintf("-X 'github.com/gravitational/teleport/tool/common/update.baseUrl=http://%s'", address), + }, " "), + "./integration", + ) + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + + return trace.Wrap(cmd.Run()) +} diff --git a/tool/common/update/update_unix.go b/tool/common/update/update_unix.go new file mode 100644 index 0000000000000..dc51d964baddd --- /dev/null +++ b/tool/common/update/update_unix.go @@ -0,0 +1,82 @@ +//go:build !windows + +package update + +import ( + "context" + "errors" + "log/slog" + "os" + "path/filepath" + "syscall" + + "github.com/gravitational/trace" + "golang.org/x/sys/unix" +) + +func lock(dir string) (func(), error) { + ctx := context.Background() + // Build the path to the lock file that will be used by flock. + lockFile := filepath.Join(dir, ".lock") + + // Create the advisory lock using flock. + lf, err := os.OpenFile(lockFile, os.O_CREATE|os.O_WRONLY, 0666) + if err != nil { + return nil, trace.Wrap(err) + } + + rc, err := lf.SyscallConn() + if err != nil { + _ = lf.Close() + return nil, trace.Wrap(err) + } + if err := rc.Control(func(fd uintptr) { + err = syscall.Flock(int(fd), syscall.LOCK_EX) + }); err != nil { + _ = lf.Close() + return nil, trace.Wrap(err) + } + + return func() { + rc, err := lf.SyscallConn() + if err != nil { + _ = lf.Close() + slog.DebugContext(ctx, "failed to acquire syscall connection", "error", err) + return + } + if err := rc.Control(func(fd uintptr) { + err = syscall.Flock(int(fd), syscall.LOCK_UN) + }); err != nil { + slog.DebugContext(ctx, "failed to unlock file", "file", lockFile, "error", err) + } + if err := lf.Close(); err != nil { + slog.DebugContext(ctx, "failed to close lock file", "file", lockFile, "error", err) + } + }, nil +} + +// sendInterrupt sends a SIGINT to the process. +func sendInterrupt(pid int) error { + err := syscall.Kill(pid, syscall.SIGINT) + if errors.Is(err, syscall.ESRCH) { + return trace.BadParameter("can't find the process: %v", pid) + } + return trace.Wrap(err) +} + +// freeDiskWithReserve returns the available disk space. +func freeDiskWithReserve(dir string) (uint64, error) { + var stat unix.Statfs_t + err := unix.Statfs(dir, &stat) + if err != nil { + return 0, trace.Wrap(err) + } + if stat.Bsize < 0 { + return 0, trace.Errorf("invalid size") + } + avail := stat.Bavail * uint64(stat.Bsize) + if reservedFreeDisk > avail { + return 0, trace.Errorf("no free space left") + } + return avail - reservedFreeDisk, nil +} diff --git a/tool/common/update/update_windows.go b/tool/common/update/update_windows.go new file mode 100644 index 0000000000000..66c2a68093c8b --- /dev/null +++ b/tool/common/update/update_windows.go @@ -0,0 +1,167 @@ +//go:build windows + +/* + * Teleport + * Copyright (C) 2024 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package update + +import ( + "archive/zip" + "context" + "io" + "log/slog" + "os" + "path/filepath" + "syscall" + "time" + "unsafe" + + "golang.org/x/sys/windows" + + "github.com/gravitational/trace" +) + +var ( + kernel = windows.NewLazyDLL("kernel32.dll") + proc = kernel.NewProc("CreateFileW") + ctrlEvent = kernel.NewProc("GenerateConsoleCtrlEvent") +) + +func replace(path string, hash string) error { + f, err := os.Open(path) + if err != nil { + return trace.Wrap(err) + } + fi, err := f.Stat() + if err != nil { + return trace.Wrap(err) + } + zipReader, err := zip.NewReader(f, fi.Size()) + if err != nil { + return trace.Wrap(err) + } + + dir, err := toolsDir() + if err != nil { + return trace.Wrap(err) + } + tempDir, err := os.MkdirTemp(dir, hash) + if err != nil { + return trace.Wrap(err) + } + + for _, r := range zipReader.File { + // Skip over any files in the archive that are not {tsh, tctl}. + if r.Name != "tsh.exe" && r.Name != "tctl.exe" { + continue + } + + // Verify that we have enough space for uncompressed file. + if err := checkFreeSpace(tempDir, r.UncompressedSize64); err != nil { + return trace.Wrap(err) + } + + rr, err := r.Open() + if err != nil { + return trace.Wrap(err) + } + defer rr.Close() + + dest := filepath.Join(tempDir, r.Name) + t, err := os.OpenFile(dest, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0755) + if err != nil { + return trace.Wrap(err) + } + defer t.Close() + + if _, err := io.Copy(t, rr); err != nil { + return trace.Wrap(err) + } + os.Remove(filepath.Join(dir, r.Name)) + if err := os.Symlink(dest, filepath.Join(dir, r.Name)); err != nil { + return trace.Wrap(err) + } + } + + return nil +} + +// lock implements locking mechanism for blocking another process acquire the lock until its released. +func lock(dir string) (func(), error) { + path := filepath.Join(dir, ".lock") + lockPath, err := windows.UTF16PtrFromString(path) + if err != nil { + return nil, trace.Wrap(err) + } + + var lockFile *os.File + for lockFile == nil { + fd, _, err := proc.Call( + uintptr(unsafe.Pointer(lockPath)), + uintptr(windows.GENERIC_READ|windows.GENERIC_WRITE), + // Exclusive lock, for shared must be used: uintptr(windows.FILE_SHARE_READ|windows.FILE_SHARE_WRITE). + uintptr(0), + uintptr(0), + uintptr(windows.OPEN_ALWAYS), + uintptr(windows.FILE_ATTRIBUTE_NORMAL), + 0, + ) + switch err.(windows.Errno) { + case windows.NO_ERROR, windows.ERROR_ALREADY_EXISTS: + lockFile = os.NewFile(fd, path) + case windows.ERROR_SHARING_VIOLATION: + // if the file is locked by another process we have to wait until the next check. + time.Sleep(time.Second) + default: + windows.CloseHandle(windows.Handle(fd)) + return nil, trace.Wrap(err) + } + } + + if err := windows.SetHandleInformation(windows.Handle(lockFile.Fd()), windows.HANDLE_FLAG_INHERIT, 1); err != nil { + return nil, trace.Wrap(err) + } + + return func() { + if err := lockFile.Close(); err != nil { + slog.DebugContext(context.Background(), "failed to close lock file", "file", lockFile, "error", err) + } + }, nil +} + +// sendInterrupt sends a Ctrl-Break event to the process. +func sendInterrupt(pid int) error { + r, _, err := ctrlEvent.Call(uintptr(syscall.CTRL_BREAK_EVENT), uintptr(pid)) + if r == 0 { + return trace.Wrap(err) + } + return nil +} + +// freeDiskWithReserve returns the available disk space. +func freeDiskWithReserve(dir string) (uint64, error) { + var avail uint64 + err := windows.GetDiskFreeSpaceEx(windows.StringToUTF16Ptr(dir), &avail, nil, nil) + if err != nil { + return 0, trace.Wrap(err) + } + if reservedFreeDisk > avail { + return 0, trace.Errorf("no free space left") + } + return avail - reservedFreeDisk, nil +} diff --git a/tool/tctl/common/tctl.go b/tool/tctl/common/tctl.go index 996449f8fc5e6..7f5a79aff37c7 100644 --- a/tool/tctl/common/tctl.go +++ b/tool/tctl/common/tctl.go @@ -54,6 +54,7 @@ import ( "github.com/gravitational/teleport/lib/service/servicecfg" "github.com/gravitational/teleport/lib/utils" "github.com/gravitational/teleport/tool/common" + "github.com/gravitational/teleport/tool/common/update" ) const ( @@ -104,6 +105,27 @@ type CLICommand interface { // // distribution: name of the Teleport distribution func Run(commands []CLICommand) { + // At process startup, check if a version has already been downloaded to + // $TELEPORT_HOME/bin or if the user has set the TELEPORT_TOOLS_VERSION + // environment variable. If so, re-exec that version of {tsh, tctl}. + toolsVersion, reexec := update.CheckLocal() + if reexec { + // Download the version of client tools required by the cluster. This + // is required if the user passed in the TELEPORT_TOOLS_VERSION + // explicitly. + if err := update.Download(toolsVersion); err != nil { + utils.FatalError(err) + } + + // Re-execute client tools with the correct version of client tools. + code, err := update.Exec() + if err != nil { + utils.FatalError(err) + } else { + os.Exit(code) + } + } + err := TryRun(commands, os.Args[1:]) if err != nil { var exitError *common.ExitCodeError diff --git a/tool/todo.txt b/tool/todo.txt new file mode 100644 index 0000000000000..d71e3bc5a78a5 --- /dev/null +++ b/tool/todo.txt @@ -0,0 +1,6 @@ +1. Windows support. File locking. +2. tbot support. +3. Errors and warnings for self-managed updates. +4. Different editions of Teleport. OSS should download OSS, Enterprise should download Enterprise. +5. Test migration path. +6. Testing. diff --git a/tool/tsh/common/tsh.go b/tool/tsh/common/tsh.go index 500dcab53c79f..1abaa6918b286 100644 --- a/tool/tsh/common/tsh.go +++ b/tool/tsh/common/tsh.go @@ -96,6 +96,7 @@ import ( "github.com/gravitational/teleport/tool/common" "github.com/gravitational/teleport/tool/common/fido2" "github.com/gravitational/teleport/tool/common/touchid" + "github.com/gravitational/teleport/tool/common/update" "github.com/gravitational/teleport/tool/common/webauthnwin" ) @@ -652,6 +653,7 @@ const ( proxyKubeConfigEnvVar = "TELEPORT_KUBECONFIG" noResumeEnvVar = "TELEPORT_NO_RESUME" requestModeEnvVar = "TELEPORT_REQUEST_MODE" + toolsVersionEnvVar = "TELEPORT_TOOLS_VERSION" clusterHelp = "Specify the Teleport cluster to connect" browserHelp = "Set to 'none' to suppress browser opening on login" @@ -695,6 +697,29 @@ func initLogger(cf *CLIConf) { // // DO NOT RUN TESTS that call Run() in parallel (unless you taken precautions). func Run(ctx context.Context, args []string, opts ...CliOption) error { + // At process startup, check if a version has already been downloaded to + // $TELEPORT_HOME/bin or if the user has set the TELEPORT_TOOLS_VERSION + // environment variable. If so, re-exec that version of {tsh, tctl}. + toolsVersion, reexec := update.CheckLocal() + if reexec { + // Download the version of client tools required by the cluster. This + // is required if the user passed in the TELEPORT_TOOLS_VERSION + // explicitly. + if err := update.Download(toolsVersion); err != nil { + return trace.Wrap(err) + } + + // Re-execute client tools with the correct version of client tools. + code, err := update.Exec() + if err != nil { + log.Debugf("Failed to re-exec client tool: %v.", err) + // TODO(russjones): Is 255 the correct error code here? + os.Exit(255) + } else { + os.Exit(code) + } + } + cf := CLIConf{ Context: ctx, TracingProvider: tracing.NoopProvider(), @@ -1832,7 +1857,15 @@ func onLogin(cf *CLIConf) error { } tc.HomePath = cf.HomePath - // client is already logged in and profile is not expired + // The user is not logged in and has typed in `tsh --proxy=... login`, if + // the running binary needs to be updated, update and re-exec. + if profile == nil { + if err := updateAndRun(context.Background(), tc.WebProxyAddr); err != nil { + return trace.Wrap(err) + } + } + + // The user is already logged in and the profile is not expired. if profile != nil && !profile.IsExpired(time.Now()) { switch { // in case if nothing is specified, re-fetch kube clusters and print @@ -1842,6 +1875,13 @@ func onLogin(cf *CLIConf) error { // current status case cf.Proxy == "" && cf.SiteName == "" && cf.DesiredRoles == "" && cf.RequestID == "" && cf.IdentityFileOut == "" || host(cf.Proxy) == host(profile.ProxyURL.Host) && cf.SiteName == profile.Cluster && cf.DesiredRoles == "" && cf.RequestID == "": + + // The user has typed `tsh login`, if the running binary needs to + // be updated, update and re-exec. + if err := updateAndRun(context.Background(), tc.WebProxyAddr); err != nil { + return trace.Wrap(err) + } + _, err := tc.PingAndShowMOTD(cf.Context) if err != nil { return trace.Wrap(err) @@ -1855,6 +1895,13 @@ func onLogin(cf *CLIConf) error { // if the proxy names match but nothing else is specified; show motd and update active profile and kube configs case host(cf.Proxy) == host(profile.ProxyURL.Host) && cf.SiteName == "" && cf.DesiredRoles == "" && cf.RequestID == "" && cf.IdentityFileOut == "": + + // The user has typed `tsh login`, if the running binary needs to + // be updated, update and re-exec. + if err := updateAndRun(context.Background(), tc.WebProxyAddr); err != nil { + return trace.Wrap(err) + } + _, err := tc.PingAndShowMOTD(cf.Context) if err != nil { return trace.Wrap(err) @@ -1880,6 +1927,7 @@ func onLogin(cf *CLIConf) error { // but cluster is specified, treat this as selecting a new cluster // for the same proxy case (cf.Proxy == "" || host(cf.Proxy) == host(profile.ProxyURL.Host)) && cf.SiteName != "": + _, err := tc.PingAndShowMOTD(cf.Context) if err != nil { return trace.Wrap(err) @@ -1910,6 +1958,7 @@ func onLogin(cf *CLIConf) error { // but desired roles or request ID is specified, treat this as a // privilege escalation request for the same login session. case (cf.Proxy == "" || host(cf.Proxy) == host(profile.ProxyURL.Host)) && (cf.DesiredRoles != "" || cf.RequestID != "") && cf.IdentityFileOut == "": + _, err := tc.PingAndShowMOTD(cf.Context) if err != nil { return trace.Wrap(err) @@ -1925,7 +1974,11 @@ func onLogin(cf *CLIConf) error { // otherwise just pass through to standard login default: - + // The user is logged in and has typed in `tsh --proxy=... login`, if + // the running binary needs to be updated, update and re-exec. + if err := updateAndRun(context.Background(), tc.WebProxyAddr); err != nil { + return trace.Wrap(err) + } } } @@ -5411,6 +5464,30 @@ const ( "https://goteleport.com/docs/access-controls/guides/headless/#troubleshooting" ) +func updateAndRun(ctx context.Context, proxy string) error { + // If needed, download the new version of {tsh, tctl} and re-exec. Make + // sure to exit this process with the same exit code as the child process. + toolsVersion, reexec, err := update.CheckRemote(ctx, proxy) + if err != nil { + return trace.Wrap(err) + } + if reexec { + // Download the version of client tools required by the cluster. + if err := update.Download(toolsVersion); err != nil { + return trace.Wrap(err) + } + + // Re-execute client tools with the correct version of client tools. + code, err := update.Exec() + if err != nil { + return trace.Wrap(err) + } + os.Exit(code) + } + + return nil +} + // Lock the process memory to prevent rsa keys and certificates in memory from being exposed in a swap. func tryLockMemory(cf *CLIConf) error { if cf.MlockMode == mlockModeAuto {