diff --git a/integration/autoupdate/tools/connect_privileged_updater_windows_test.go b/integration/autoupdate/tools/connect_privileged_updater_windows_test.go new file mode 100644 index 0000000000000..a1930c77f4982 --- /dev/null +++ b/integration/autoupdate/tools/connect_privileged_updater_windows_test.go @@ -0,0 +1,453 @@ +/* + * Teleport + * Copyright (C) 2026 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package tools_test + +import ( + "context" + "crypto/sha256" + "encoding/binary" + "encoding/hex" + "fmt" + "net" + "net/http" + "net/http/httptest" + "os" + "os/exec" + "path/filepath" + "regexp" + "strings" + "testing" + "time" + + "github.com/Microsoft/go-winio" + "github.com/gravitational/trace" + "github.com/stretchr/testify/require" + "golang.org/x/sys/windows" + + "github.com/gravitational/teleport" + "github.com/gravitational/teleport/api/utils/retryutils" + "github.com/gravitational/teleport/lib/teleterm/autoupdate/privilegedupdater" +) + +func TestPrivilegedUpdateServiceSuccess(t *testing.T) { + up := update{ + version: "999.0.0", + binary: []byte("payload"), + } + err := runPrivilegedUpdaterFlow(t, up) + require.NoError(t, err) +} + +func TestPrivilegedUpdateServiceRejectsDowngrade(t *testing.T) { + up := update{ + // The version is a downgrade compared to the current api.Version. + version: "0.0.1", + binary: []byte("payload"), + } + err := runPrivilegedUpdaterFlow(t, up) + require.ErrorIs(t, err, trace.BadParameter("update version 0.0.1 is not newer than current version %s", teleport.SemVer())) +} + +func TestPrivilegedUpdateServiceRejectsChecksumMismatch(t *testing.T) { + up := update{ + version: "999.0.0", + binary: []byte("payload"), + } + + otherHash := sha256.Sum256([]byte("different-payload")) + err := runPrivilegedUpdaterFlow(t, up, withChecksumServerResponseWriter(func(w http.ResponseWriter) { + _, err := w.Write([]byte(hex.EncodeToString(otherHash[:]))) + require.NoError(t, err) + })) + require.ErrorIs(t, err, trace.BadParameter("hash of the update does not match downloaded checksum")) +} + +func TestPrivilegedUpdateServiceRejectsInvalidVersionFormat(t *testing.T) { + up := update{ + version: "not-a-semver", + binary: []byte("payload"), + } + err := runPrivilegedUpdaterFlow(t, up) + require.Error(t, err) + require.Contains(t, err.Error(), `invalid update version "not-a-semver"`) +} + +func TestPrivilegedUpdateServiceRejectsChecksumRequestFailure(t *testing.T) { + up := update{ + version: "999.0.0", + binary: []byte("payload"), + } + + err := runPrivilegedUpdaterFlow(t, up, withChecksumServerResponseWriter(func(w http.ResponseWriter) { + http.Error(w, "failure", http.StatusInternalServerError) + })) + + require.Error(t, err) + require.Contains(t, err.Error(), "downloading update checksum") +} + +func TestPrivilegedUpdateServicePolicyOffRejectsUpdate(t *testing.T) { + up := update{ + version: "999.0.0", + binary: []byte("payload"), + } + err := runPrivilegedUpdaterFlow(t, up, withServiceTestPolicyToolsVersion("off")) + require.Error(t, err) + require.ErrorIs(t, err, trace.AccessDenied(`ToolsVersion in HKLM\SOFTWARE\Policies\Teleport\TeleportConnect is "off", automatic updates are disabled by system policy`)) +} + +func TestPrivilegedUpdateServicePolicyVersionMismatch(t *testing.T) { + up := update{ + version: "999.0.0", + binary: []byte("payload"), + } + err := runPrivilegedUpdaterFlow(t, up, withServiceTestPolicyToolsVersion("999.0.1")) + require.ErrorIs(t, err, trace.BadParameter("update version 999.0.0 does not match policy version 999.0.1")) +} + +func TestPrivilegedUpdateServiceRejectsMalformedMetadata(t *testing.T) { + cfg := getDefaultConfig(t) + + serviceErr := make(chan error, 1) + go func() { + serviceErr <- privilegedupdater.RunServiceTest(t.Context(), cfg) + }() + + conn := dialUpdaterPipe(t, 5*time.Second) + defer conn.Close() + + // Send malformed JSON metadata. + malformedMetadata := []byte("{") + require.NoError(t, binary.Write(conn, binary.LittleEndian, uint32(len(malformedMetadata)))) + n, err := conn.Write(malformedMetadata) + require.NoError(t, err) + require.Len(t, malformedMetadata, n) + require.NoError(t, conn.Close()) + + ctx, cancel := context.WithTimeout(t.Context(), time.Second) + defer cancel() + select { + case err := <-serviceErr: + require.Error(t, err) + require.Contains(t, err.Error(), "failed to unmarshal update metadata") + case <-ctx.Done(): + t.Fatal("timed out") + } +} + +func TestPrivilegedUpdateServiceRejectsUpdateBaseDirFile(t *testing.T) { + up := update{ + version: "999.0.0", + binary: []byte("payload"), + } + + baseDir := filepath.Join(t.TempDir(), "not-a-dir") + require.NoError(t, os.WriteFile(baseDir, []byte("x"), 0o600)) + + err := runPrivilegedUpdaterFlow(t, up, withServiceTestUpdateBaseDir(baseDir)) + require.ErrorIs(t, err, trace.BadParameter("security violation: %s exists but is not a directory", baseDir)) +} + +func TestPrivilegedUpdateServiceRejectsUpdateBaseDirReparsePoint(t *testing.T) { + up := update{ + version: "999.0.0", + binary: []byte("payload"), + } + + targetDir := t.TempDir() + baseDir := filepath.Join(t.TempDir(), "junction-base") + createJunction(t, baseDir, targetDir) + + err := runPrivilegedUpdaterFlow(t, up, withServiceTestUpdateBaseDir(baseDir)) + require.ErrorIs(t, err, trace.BadParameter("security violation: %s is a reparse point", baseDir)) +} + +func TestPrivilegedUpdateServiceSafelyCleanupOldUpdates(t *testing.T) { + updateBaseDir := t.TempDir() + outsideDir := t.TempDir() + outsideFile := filepath.Join(outsideDir, "must-stay.txt") + require.NoError(t, os.WriteFile(outsideFile, []byte("outside"), 0o600)) + + staleDir := filepath.Join(updateBaseDir, "stale-update") + require.NoError(t, os.MkdirAll(staleDir, 0o700)) + require.NoError(t, os.WriteFile(filepath.Join(staleDir, "update.exe"), []byte("stale"), 0o600)) + + junctionPath := filepath.Join(updateBaseDir, "outside-junction") + createJunction(t, junctionPath, outsideDir) + + updateBinary := []byte("payload") + up := update{ + version: "999.0.0", + binary: updateBinary, + } + err := runPrivilegedUpdaterFlow(t, up, withServiceTestUpdateBaseDir(updateBaseDir)) + require.NoError(t, err) + + _, err = os.Stat(staleDir) + require.ErrorIs(t, err, os.ErrNotExist, "stale update directory should be removed") + + _, err = os.Lstat(junctionPath) + require.ErrorIs(t, err, os.ErrNotExist, "junction entry should be removed") + + _, err = os.Stat(outsideFile) + require.NoError(t, err, "cleanup must not remove files outside base dir via junction traversal") +} + +func TestPrivilegedUpdateServiceCorrectsUpdateBaseDirACL(t *testing.T) { + up := update{ + version: "999.0.0", + binary: []byte("payload"), + } + + defaultConfig := getDefaultConfig(t) + baseDir := filepath.Join(t.TempDir(), "new-dir") + require.NoError(t, os.MkdirAll(baseDir, 0o777)) + // Everyone has Full Control over this object, + // and the permission is inherited by all subfolders and files. + // This access will be corrected by the service. + setDirectoryDACL(t, baseDir, "D:(A;OICI;GA;;;WD)") + + err := runPrivilegedUpdaterFlow(t, up, withServiceTestUpdateBaseDir(baseDir)) + require.NoError(t, err) + + assertDirectorySecurityDescriptor(t, baseDir, defaultConfig.UpdateDirSecurityDescriptor) +} + +func TestPrivilegedUpdateServiceAllowOnlyOneClientConnection(t *testing.T) { + serviceErr := make(chan error, 1) + cfg := getDefaultConfig(t) + + go func() { + serviceErr <- privilegedupdater.RunServiceTest(t.Context(), cfg) + }() + + // First client connects and keeps the pipe open. This blocks the service in readUpdate. + firstConn := dialUpdaterPipe(t, 2*time.Second) + + // Second client should fail because waitForSingleClient closes the listener after first accept. + clientCtx2, cancel2 := context.WithTimeout(t.Context(), 2*time.Second) + t.Cleanup(cancel2) + secondConn, err := winio.DialPipeAccess(clientCtx2, privilegedupdater.PipePath, privilegedupdater.SafePipeReadWriteAccess) + if secondConn != nil { + _ = secondConn.Close() + } + require.Error(t, err, "second client unexpectedly connected") + + // Let the service exit cleanly from the blocked read path. + require.NoError(t, firstConn.Close()) + ctx, cancel := context.WithTimeout(t.Context(), time.Second) + defer cancel() + select { + case err := <-serviceErr: + require.Error(t, err) + case <-ctx.Done(): + t.Fatal("timed out") + } +} + +type serviceConfig struct { + privilegedupdater.ServiceTestConfig + checksumServerResponseWriter func(http.ResponseWriter) +} + +type privilegedServiceMainConfigOption func(*serviceConfig) + +func withServiceTestUpdateBaseDir(path string) privilegedServiceMainConfigOption { + return func(cfg *serviceConfig) { + cfg.UpdateBaseDir = path + } +} + +func withChecksumServerResponseWriter(checksumResponseWriter func(w http.ResponseWriter)) privilegedServiceMainConfigOption { + return func(cfg *serviceConfig) { + cfg.checksumServerResponseWriter = checksumResponseWriter + } +} + +func withServiceTestPolicyToolsVersion(version string) privilegedServiceMainConfigOption { + return func(cfg *serviceConfig) { + cfg.PolicyToolsVersion = version + } +} + +type update struct { + version string + binary []byte +} + +// runPrivilegedUpdaterFlow runs the service implementation and sends the update via the named pipe. +func runPrivilegedUpdaterFlow(t *testing.T, update update, opts ...privilegedServiceMainConfigOption) error { + t.Helper() + + defaultCfg := getDefaultConfig(t) + cfg := &serviceConfig{ + ServiceTestConfig: *defaultCfg, + } + for _, opt := range opts { + opt(cfg) + } + + checksumPath := "/Teleport Connect Setup-" + update.version + ".exe.sha256" + server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != checksumPath { + http.NotFound(w, r) + return + } + if cfg.checksumServerResponseWriter != nil { + cfg.checksumServerResponseWriter(w) + } else { + hash := sha256.Sum256(update.binary) + // By default, return a checksum for the passed file. + _, _ = w.Write([]byte(hex.EncodeToString(hash[:]))) + } + })) + t.Cleanup(server.Close) + + payloadPath := filepath.Join(t.TempDir(), "client-update.exe") + require.NoError(t, os.WriteFile(payloadPath, update.binary, 0o600)) + + serviceErr := make(chan error, 1) + installUpdateFromClientErr := make(chan error, 1) + go func() { + err := privilegedupdater.RunServiceTest(t.Context(), &privilegedupdater.ServiceTestConfig{ + UpdateDirSecurityDescriptor: cfg.UpdateDirSecurityDescriptor, + UpdateBaseDir: cfg.UpdateBaseDir, + PolicyToolsVersion: cfg.PolicyToolsVersion, + PolicyCDNBaseURL: server.URL, + HTTPClient: server.Client(), + PipeAuthenticatedUsersAccess: cfg.PipeAuthenticatedUsersAccess, + }) + // We are attempting to run a non-exe file. + // It will fail, so we check if we ran the correct file. + // The pattern should match: \\update.exe. + // In the production code, base-update-dir is %ProgramData%\TeleportConnectUpdater. + if err != nil && strings.Contains(err.Error(), "running installer") { + pattern := fmt.Sprintf( + `.*starting installer path=%s\\[0-9a-fA-F-]{36}\\update\.exe`, + regexp.QuoteMeta(cfg.UpdateBaseDir), + ) + require.Regexp(t, pattern, err.Error()) + require.Contains(t, err.Error(), "args=\"--updated /S /allusers\"") + serviceErr <- nil + return + } + serviceErr <- err + }() + go func() { + installUpdateFromClientErr <- privilegedupdater.InstallUpdateFromClient(t.Context(), payloadPath, false, update.version) + }() + + for i := 0; i < 2; i++ { + select { + case err := <-serviceErr: + return err + case err := <-installUpdateFromClientErr: + if err != nil { + return err + } + case <-t.Context().Done(): + t.Fatal("timed out") + return nil + } + } + return nil +} + +func dialUpdaterPipe(t *testing.T, timeout time.Duration) net.Conn { + t.Helper() + + var conn net.Conn + err := retryutils.RetryStaticFor(timeout, 25*time.Millisecond, func() error { + c, err := winio.DialPipeAccess(t.Context(), privilegedupdater.PipePath, privilegedupdater.SafePipeReadWriteAccess) + if err != nil { + return err + } + conn = c + return nil + }) + require.NoError(t, err, "failed to connect to updater pipe before timeout") + return conn +} + +// getDefaultConfig returns a base dir and a security descriptor. +func getDefaultConfig(t *testing.T) *privilegedupdater.ServiceTestConfig { + t.Helper() + + token := windows.GetCurrentProcessToken() + tokenUser, err := token.GetTokenUser() + require.NoError(t, err) + require.NotNil(t, tokenUser.User.Sid) + + ownerSID := tokenUser.User.Sid.String() + + // We can't use the production security descriptor as it requires the process to run with elevated privileges. + // Here we create a descriptor that restrict a bit the regular rights for authenticated users. + descriptor := "O:" + ownerSID + + "D:P" + + "(A;;FA;;;SY)" + + "(A;;FA;;;BA)" + + "(A;OICI;0x1301bf;;;AU)" // 0x1301bf - modify rights for AU (authenticated users) for dir and sub dirs (OICI) + + return &privilegedupdater.ServiceTestConfig{ + UpdateDirSecurityDescriptor: descriptor, + UpdateBaseDir: t.TempDir(), + // Allow Authenticated Users to create the pipe in tests. + PipeAuthenticatedUsersAccess: windows.GENERIC_READ | windows.GENERIC_WRITE, + } +} + +func createJunction(t *testing.T, linkPath, targetPath string) { + t.Helper() + + cmd := exec.Command("cmd", "/c", "mklink", "/J", linkPath, targetPath) + _, err := cmd.CombinedOutput() + require.NoError(t, err) +} + +func assertDirectorySecurityDescriptor(t *testing.T, path string, expectedDescriptor string) { + t.Helper() + + actualSD, err := windows.GetNamedSecurityInfo(path, windows.SE_FILE_OBJECT, windows.OWNER_SECURITY_INFORMATION|windows.DACL_SECURITY_INFORMATION|windows.PROTECTED_DACL_SECURITY_INFORMATION) + require.NoError(t, err) + + expectedSD, err := windows.SecurityDescriptorFromString(expectedDescriptor) + require.NoError(t, err) + + // Comparing ACLs is non-trivial. + // + // In SDDL, "D:" starts the DACL section. + // "D:P" means the DACL is protected (no inheritance). + // After ACL changes, Windows may apply "D:PAI", where "AI" indicates + // auto-inherited ACEs. The descriptors are functionally equivalent + // for our purposes, so normalize before comparison. + expectedSDString := strings.Replace(expectedSD.String(), "D:P", "D:PAI", 1) + require.Equal(t, expectedSDString, actualSD.String(), "directory DACL does not match expected descriptor") +} + +func setDirectoryDACL(t *testing.T, path string, descriptor string) { + t.Helper() + + sd, err := windows.SecurityDescriptorFromString(descriptor) + require.NoError(t, err) + dacl, _, err := sd.DACL() + require.NoError(t, err) + + err = windows.SetNamedSecurityInfo(path, windows.SE_FILE_OBJECT, windows.DACL_SECURITY_INFORMATION, nil, nil, dacl, nil) + require.NoError(t, err) +} diff --git a/lib/teleterm/autoupdate/common/config.go b/lib/teleterm/autoupdate/common/config.go new file mode 100644 index 0000000000000..11c74968ae2cf --- /dev/null +++ b/lib/teleterm/autoupdate/common/config.go @@ -0,0 +1,35 @@ +// Teleport +// Copyright (C) 2026 Gravitational, Inc. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package common + +import ( + "github.com/gravitational/teleport/lib/autoupdate" + "github.com/gravitational/teleport/lib/modules" +) + +// TeleportToolsVersionOff indicates that managed updates are disabled ("off"). +const TeleportToolsVersionOff = "off" + +// GetDefaultBaseURL returns the default base URL used to download artifacts. +func GetDefaultBaseURL() string { + m := modules.GetModules() + // Uses the same logic as the teleport/lib/autoupdate/tools package. + if m.BuildType() != modules.BuildOSS { + return autoupdate.DefaultBaseURL + } + return "" +} diff --git a/lib/teleterm/autoupdate/common/registry_windows.go b/lib/teleterm/autoupdate/common/registry_windows.go new file mode 100644 index 0000000000000..ab853fa59c9c3 --- /dev/null +++ b/lib/teleterm/autoupdate/common/registry_windows.go @@ -0,0 +1,86 @@ +// Teleport +// Copyright (C) 2026 Gravitational, Inc. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package common + +import ( + "errors" + + "github.com/gravitational/trace" + "golang.org/x/sys/windows/registry" +) + +const ( + // TeleportConnectPoliciesKeyPath is the Windows registry path for Teleport Connect policy settings. + TeleportConnectPoliciesKeyPath = `SOFTWARE\Policies\Teleport\TeleportConnect` + // RegistryValueToolsVersion is the policy value name that pins the managed tools version. + RegistryValueToolsVersion = "ToolsVersion" + // RegistryValueCDNBaseURL is the policy value name that configures the managed update CDN base URL. + RegistryValueCDNBaseURL = "CdnBaseUrl" +) + +// PolicyValues defines the managed update policy configuration. +type PolicyValues struct { + // CDNBaseURL is the base URL used to download artifacts. + CDNBaseURL string + // Version specifies the enforced application version. + Version string +} + +// ReadRegistryPolicyValues reads system policy values (tools version and CDN base URL) for Teleport Connect. +func ReadRegistryPolicyValues(key registry.Key) (*PolicyValues, error) { + version, err := ReadRegistryValue(key, TeleportConnectPoliciesKeyPath, RegistryValueToolsVersion) + if err != nil && !trace.IsNotFound(err) { + return nil, trace.Wrap(err) + } + + url, err := ReadRegistryValue(key, TeleportConnectPoliciesKeyPath, RegistryValueCDNBaseURL) + if err != nil && !trace.IsNotFound(err) { + return nil, trace.Wrap(err) + } + + return &PolicyValues{ + CDNBaseURL: url, + Version: version, + }, nil +} + +// ReadRegistryValue reads a registry value. +func ReadRegistryValue(hive registry.Key, pathName string, valueName string) (path string, err error) { + key, err := registry.OpenKey(hive, pathName, registry.READ) + if err != nil { + if errors.Is(err, registry.ErrNotExist) { + return "", trace.NotFound("registry key %s not found", pathName) + } + return "", trace.Wrap(err, "opening registry key %s", pathName) + } + + defer func() { + if closeErr := key.Close(); closeErr != nil && err == nil { + err = trace.Wrap(closeErr, "closing registry key %s", pathName) + } + }() + + path, _, err = key.GetStringValue(valueName) + if err != nil { + if errors.Is(err, registry.ErrNotExist) { + return "", trace.NotFound("registry value %s not found in %s", valueName, pathName) + } + return "", trace.Wrap(err, "reading registry value %s from %s", valueName, pathName) + } + + return path, nil +} diff --git a/lib/teleterm/autoupdate/privilegedupdater/client_windows.go b/lib/teleterm/autoupdate/privilegedupdater/client_windows.go new file mode 100644 index 0000000000000..0877499075fa2 --- /dev/null +++ b/lib/teleterm/autoupdate/privilegedupdater/client_windows.go @@ -0,0 +1,164 @@ +// Teleport +// Copyright (C) 2026 Gravitational, Inc. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package privilegedupdater + +import ( + "context" + "errors" + "net" + "os" + "syscall" + "time" + + "github.com/Microsoft/go-winio" + "github.com/gravitational/trace" + "golang.org/x/sys/windows" + "golang.org/x/sys/windows/svc" + "golang.org/x/sys/windows/svc/mgr" + + "github.com/gravitational/teleport/api/utils/retryutils" +) + +const ( + serviceStartTimeout = 5 * time.Second + serviceStartRetryStep = 500 * time.Millisecond + serviceStartRetryMax = 500 * time.Millisecond + pipeDialTimeout = 3 * time.Second + pipeDialRetryStep = 100 * time.Millisecond + pipeDialRetryMax = 300 * time.Millisecond +) + +// RunServiceAndInstallUpdateFromClient is called by the client. +// It starts the update service, sends update metadata, and transfers the binary for validation and installation. +func RunServiceAndInstallUpdateFromClient(ctx context.Context, path string, forceRun bool, version string) error { + if err := ensureServiceRunning(ctx); err != nil { + // Service failed to start; fall back to client-side install (UAC). + if installErr := runInstaller(path, forceRun); installErr != nil { + return trace.Wrap(installErr, "fallback install failed after service start error: %v", err) + } + return nil + } + + err := InstallUpdateFromClient(ctx, path, forceRun, version) + return trace.Wrap(err) +} + +// InstallUpdateFromClient sends update metadata, and transfers the binary for validation and installation. +func InstallUpdateFromClient(ctx context.Context, path string, forceRun bool, version string) error { + conn, err := dialPipeWithRetry(ctx, PipePath) + if err != nil { + return trace.Wrap(err) + } + defer conn.Close() + + // The update must be read by the client running as a standard user. + // Passing the path directly to the SYSTEM service could cause it to read + // files the user is not permitted to access. + file, err := os.Open(path) + if err != nil { + return trace.Wrap(err) + } + defer file.Close() + + meta := updateMetadata{ForceRun: forceRun, Version: version} + return trace.Wrap(writeUpdate(conn, meta, file)) +} + +func dialPipeWithRetry(ctx context.Context, path string) (net.Conn, error) { + ctx, cancel := context.WithTimeout(ctx, pipeDialTimeout) + defer cancel() + linearRetry, err := retryutils.NewLinear(retryutils.LinearConfig{ + Step: pipeDialRetryStep, + Max: pipeDialRetryMax, + }) + if err != nil { + return nil, trace.Wrap(err) + } + + isRetryError := func(err error) bool { + return errors.Is(err, windows.ERROR_FILE_NOT_FOUND) + } + + var conn net.Conn + err = linearRetry.For(ctx, func() error { + conn, err = winio.DialPipeAccess(ctx, path, uint32(SafePipeReadWriteAccess)) + if err != nil && !isRetryError(err) { + return retryutils.PermanentRetryError(trace.Wrap(err)) + } + return trace.Wrap(err) + }) + if err != nil { + return nil, trace.Wrap(err) + } + return conn, nil +} + +func ensureServiceRunning(ctx context.Context) error { + ctx, cancel := context.WithTimeout(ctx, serviceStartTimeout) + defer cancel() + // Avoid [mgr.Connect] because it requests elevated permissions. + scManager, err := windows.OpenSCManager(nil /*machine*/, nil /*database*/, windows.SC_MANAGER_CONNECT) + if err != nil { + return trace.Wrap(err, "opening Windows service manager") + } + defer windows.CloseServiceHandle(scManager) + serviceNamePtr, err := syscall.UTF16PtrFromString(serviceName) + if err != nil { + return trace.Wrap(err, "converting service name to UTF16") + } + serviceHandle, err := windows.OpenService(scManager, serviceNamePtr, serviceAccessFlags) + if err != nil { + return trace.Wrap(err, "opening Windows service %v", serviceName) + } + service := &mgr.Service{ + Name: serviceName, + Handle: serviceHandle, + } + defer service.Close() + + status, err := service.Query() + if err != nil { + return trace.Wrap(err, "querying service status") + } + if status.State == svc.Running { + return nil + } + + if err = service.Start(ServiceCommand); err != nil { + return trace.Wrap(err, "starting Windows service %s", serviceName) + } + + linearRetry, err := retryutils.NewLinear(retryutils.LinearConfig{ + Step: serviceStartRetryStep, + Max: serviceStartRetryMax, + }) + if err != nil { + return trace.Wrap(err) + } + + err = linearRetry.For(ctx, func() error { + status, err = service.Query() + if err != nil { + return retryutils.PermanentRetryError(trace.Wrap(err)) + } + if status.State != svc.Running { + return trace.Errorf("service not running yet") + } + return nil + }) + return trace.Wrap(err) +} diff --git a/lib/teleterm/autoupdate/privilegedupdater/protocol_windows.go b/lib/teleterm/autoupdate/privilegedupdater/protocol_windows.go new file mode 100644 index 0000000000000..324e0124e18a9 --- /dev/null +++ b/lib/teleterm/autoupdate/privilegedupdater/protocol_windows.go @@ -0,0 +1,107 @@ +// Teleport +// Copyright (C) 2026 Gravitational, Inc. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package privilegedupdater + +import ( + "encoding/binary" + "encoding/json" + "io" + "os" + + "github.com/gravitational/trace" + + "github.com/gravitational/teleport/lib/utils" +) + +const ( + PipePath = `\\.\pipe\TeleportConnectUpdaterPipe` + maxUpdateMetadataSize = 1 * 1024 * 1024 // 1 MiB + maxUpdatePayloadSize = 1 * 1024 * 1024 * 1024 // 1 GiB +) + +type updateMetadata struct { + // ForceRun determines whether to run the app after installing the update. + ForceRun bool `json:"force_run"` + // Version is update version. + Version string `json:"version"` +} + +// writeUpdate writes an update stream in the following order: +// 1. An uint32 specifying the length of the updateMetadata header. +// 2. The updateMetadata header of the specified length. +// 3. The update binary, read until EOF. +func writeUpdate(conn io.Writer, meta updateMetadata, file io.Reader) error { + if meta.Version == "" { + return trace.BadParameter("update version is required") + } + + metaBytes, err := json.Marshal(meta) + if err != nil { + return trace.Wrap(err) + } + if len(metaBytes) > maxUpdateMetadataSize { + return trace.BadParameter("update metadata payload too large") + } + + if err = binary.Write(conn, binary.LittleEndian, uint32(len(metaBytes))); err != nil { + return trace.Wrap(err) + } + if _, err = conn.Write(metaBytes); err != nil { + return trace.Wrap(err) + } + + _, err = io.Copy(conn, file) + return trace.Wrap(err) +} + +// readUpdate reads an update stream in the following order: +// 1. An uint32 specifying the length of the updateMetadata header. +// 2. The updateMetadata header of the specified length. +// 3. The update binary, read until EOF. +// +// It writes the installer to destinationPath and returns the parsed metadata. +func readUpdate(conn io.Reader, destinationPath string) (*updateMetadata, error) { + var jsonLen uint32 + if err := binary.Read(conn, binary.LittleEndian, &jsonLen); err != nil { + return nil, trace.Wrap(err) + } + if jsonLen > maxUpdateMetadataSize { + return nil, trace.BadParameter("update metadata payload too large") + } + + buf := make([]byte, jsonLen) + _, err := io.ReadFull(conn, buf) + if err != nil { + return nil, trace.Wrap(err) + } + meta := &updateMetadata{} + if err = json.Unmarshal(buf, meta); err != nil { + return nil, trace.Wrap(err, "failed to unmarshal update metadata") + } + if meta.Version == "" { + return nil, trace.BadParameter("update version is required") + } + + outFile, err := os.OpenFile(destinationPath, os.O_CREATE|os.O_WRONLY|os.O_EXCL, 0600) + if err != nil { + return nil, trace.Wrap(err) + } + + payloadReader := utils.LimitReader(conn, maxUpdatePayloadSize) + _, err = io.Copy(outFile, payloadReader) + return meta, trace.NewAggregate(err, outFile.Close()) +} diff --git a/lib/teleterm/autoupdate/privilegedupdater/service_windows.go b/lib/teleterm/autoupdate/privilegedupdater/service_windows.go new file mode 100644 index 0000000000000..d0bef6bc4079c --- /dev/null +++ b/lib/teleterm/autoupdate/privilegedupdater/service_windows.go @@ -0,0 +1,531 @@ +// Teleport +// Copyright (C) 2026 Gravitational, Inc. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package privilegedupdater + +import ( + "bytes" + "context" + "crypto/sha256" + "encoding/hex" + "errors" + "fmt" + "io" + "net" + "net/http" + "net/url" + "os" + "os/exec" + "path/filepath" + "strings" + "sync" + "time" + "unsafe" + + "github.com/Microsoft/go-winio" + "github.com/coreos/go-semver/semver" + "github.com/google/uuid" + "github.com/gravitational/trace" + "golang.org/x/sys/windows" + "golang.org/x/sys/windows/registry" + + "github.com/gravitational/teleport" + "github.com/gravitational/teleport/lib/teleterm/autoupdate/common" + logutils "github.com/gravitational/teleport/lib/utils/log" + "github.com/gravitational/teleport/lib/windowsservice" +) + +const ( + ServiceCommand = "connect-updater-service" + serviceName = "TeleportConnectUpdater" + serviceDescription = "Installs Teleport Connect updates without requiring administrator privileges." + eventSource = "connect-updater" + serviceAccessFlags = windows.SERVICE_START | windows.SERVICE_QUERY_STATUS + serviceRunTimeout = 30 * time.Second + + // SafePipeReadWriteAccess defines access for Authenticated Users (AU). + //According to https://learn.microsoft.com/en-us/windows/win32/ipc/named-pipe-security-and-access-rights + // and https://stackoverflow.com/questions/29947524/c-let-user-process-write-to-local-system-named-pipe-custom-security-descrip + // the pipe should not set GENERIC_WRITE for standard users as it would allow them to create the pipe. + SafePipeReadWriteAccess = windows.GENERIC_READ | windows.FILE_WRITE_DATA + + updateDirSecurityDescriptor = "O:SY" + // Owner SYSTEM + "D:P" + // 'P' blocks permissions inheritance from the parent directory + "(A;OICI;GA;;;SY)" + // Allow System Full Access + "(A;OICI;GA;;;BA)" // Allow Built-in Administrators Full Access +) + +// makePipeServerSecurityDescriptor allows SYSTEM/Admins Full Control and grants Authenticated Users the passed access mask. +func makePipeServerSecurityDescriptor(authenticatedUsersAccess uint32) string { + return "D:" + // DACL + "(A;;GA;;;SY)" + // Allow (A);; Generic All (GA);;; SYSTEM (SY) + "(A;;GA;;;BA)" + // Allow (A);; Generic All (GA);;; Built-in Admins (BA) + fmt.Sprintf("(A;;%#x;;;AU)", authenticatedUsersAccess) // Allow (A);; authenticatedUsersAccess ;;; Authenticated Users (AU) +} + +var log = logutils.NewPackageLogger(teleport.ComponentKey, "autoupdate") + +// ServiceTestConfig allows overriding certain updater config properties. +// For test use only. +type ServiceTestConfig struct { + // UpdateDirSecurityDescriptor overrides updateDirSecurityDescriptor. + UpdateDirSecurityDescriptor string + // UpdateBaseDir overrides the default %ProgramData%\TeleportConnectUpdater update path. + UpdateBaseDir string + // PolicyToolsVersion overrides ToolsVersion in HKLM\SOFTWARE\Policies\Teleport\TeleportConnect. + PolicyToolsVersion string + // PolicyCDNBaseURL overrides CdnBaseUrl in HKLM\SOFTWARE\Policies\Teleport\TeleportConnect. + PolicyCDNBaseURL string + // HTTPClient overrides the client used for checksum download. + HTTPClient *http.Client + // PipeAuthenticatedUsersAccess overrides Authenticated Users access mask in + // the named pipe DACL. If zero, SafePipeReadWriteAccess is used. + PipeAuthenticatedUsersAccess uint32 +} + +// InstallService installs the Teleport Connect privileged update service. +// This service enables installing updates without prompting the user for administrator permissions. +func InstallService(ctx context.Context) (err error) { + return trace.Wrap(windowsservice.Install(ctx, &windowsservice.InstallConfig{ + Name: serviceName, + Command: ServiceCommand, + Description: serviceDescription, + EventSourceName: eventSource, + AccessPermissions: serviceAccessFlags, + })) +} + +// UninstallService uninstalls Teleport Connect privileged update service. +func UninstallService(ctx context.Context) (err error) { + return trace.Wrap(windowsservice.Uninstall(ctx, &windowsservice.UninstallConfig{ + Name: serviceName, + EventSourceName: eventSource, + })) +} + +// RunService implements Teleport Connect privileged update service. +// This service enables installing updates without prompting the user for administrator permissions. +func RunService() error { + h := &handler{ + testCfg: &ServiceTestConfig{}, + } + + closeLogger, err := windowsservice.InitSlogEventLogger(eventSource) + if err != nil { + return trace.Wrap(err) + } + + err = windowsservice.Run(&windowsservice.RunConfig{ + Name: serviceName, + Handler: h, + Logger: log, + }) + return trace.NewAggregate(err, closeLogger()) +} + +// RunServiceTest implements Teleport Connect privileged update service. +// It runs the service implementation directly. +// For test use only. +func RunServiceTest(ctx context.Context, cfg *ServiceTestConfig) error { + h := &handler{ + testCfg: cfg, + } + return trace.Wrap(h.Execute(ctx, nil)) +} + +type handler struct { + testCfg *ServiceTestConfig +} + +func (h *handler) Execute(ctx context.Context, _ []string) (err error) { + ctx, cancel := context.WithTimeout(ctx, serviceRunTimeout) + defer cancel() + + updaterConfig, err := h.getUpdaterConfig() + if err != nil { + return trace.Wrap(err, "getting updater config") + } + + updateMeta, updatePath, err := h.readUpdateMeta(ctx) + if err != nil { + return trace.Wrap(err, "reading update metadata") + } + + if updaterConfig.Version != "" && updateMeta.Version != updaterConfig.Version { + return trace.BadParameter("update version %s does not match policy version %s", updateMeta.Version, updaterConfig.Version) + } + + if err = ensureIsUpgrade(updateMeta.Version); err != nil { + return trace.Wrap(err, "checking if update is upgrade") + } + + // TODO(gzdunek): Add signature verification. + + hash, err := h.downloadChecksum(ctx, updaterConfig.CDNBaseURL, updateMeta.Version) + if err != nil { + return trace.Wrap(err, "downloading update checksum") + } + + if err = verifyUpdateChecksum(updatePath, hash); err != nil { + return trace.Wrap(err, "verifying update checksum") + } + + return trace.Wrap(runInstaller(updatePath, updateMeta.ForceRun), "running installer") +} + +// getUpdaterConfig reads the per-machine config. +func (h *handler) getUpdaterConfig() (*common.PolicyValues, error) { + policyValues, err := common.ReadRegistryPolicyValues(registry.LOCAL_MACHINE) + if err != nil { + return nil, trace.Wrap(err) + } + + versionFromPolicy := policyValues.Version + if h.testCfg.PolicyToolsVersion != "" { + versionFromPolicy = h.testCfg.PolicyToolsVersion + } + if versionFromPolicy == common.TeleportToolsVersionOff { + return nil, trace.AccessDenied("%s in HKLM\\%s is %q, automatic updates are disabled by system policy", common.RegistryValueToolsVersion, common.TeleportConnectPoliciesKeyPath, common.TeleportToolsVersionOff) + } + + cdnBaseURL := policyValues.CDNBaseURL + if h.testCfg.PolicyCDNBaseURL != "" { + cdnBaseURL = h.testCfg.PolicyCDNBaseURL + } + if cdnBaseURL == "" { + cdnBaseURL = common.GetDefaultBaseURL() + } + if cdnBaseURL == "" { + return nil, trace.AccessDenied("client tools updates are disabled as they are licensed under AGPL. To use Community Edition builds or custom binaries, set %s in HKLM\\%s", common.RegistryValueCDNBaseURL, common.TeleportConnectPoliciesKeyPath) + } + + return &common.PolicyValues{ + CDNBaseURL: cdnBaseURL, + Version: versionFromPolicy, + }, nil +} + +type acceptResult struct { + conn net.Conn + err error +} + +func (h *handler) readUpdateMeta(ctx context.Context) (_ *updateMetadata, _ string, err error) { + pipeAuthenticatedUsersAccess := uint32(SafePipeReadWriteAccess) + if h.testCfg.PipeAuthenticatedUsersAccess != 0 { + pipeAuthenticatedUsersAccess = h.testCfg.PipeAuthenticatedUsersAccess + } + + conn, err := waitForSingleClient(ctx, pipeAuthenticatedUsersAccess) + if err != nil { + return nil, "", trace.Wrap(err, "waiting for client") + } + closeConnOnce := sync.OnceValue(conn.Close) + // Always defer conn.Close and return the error. + defer func() { + err = trace.NewAggregate(err, trace.Wrap(closeConnOnce(), "closing conn")) + }() + // Close conn early to unblock reads if ctx is canceled. + defer context.AfterFunc(ctx, func() { _ = closeConnOnce() })() + + dir, err := h.getSecureUpdateDir() + if err != nil { + return nil, "", trace.Wrap(err) + } + + updatePath := filepath.Join(dir, "update.exe") + updateMeta, err := readUpdate(conn, updatePath) + if err != nil { + return nil, "", trace.Wrap(err) + } + return updateMeta, updatePath, nil +} + +// waitForSingleClient waits for the first client and then closes the listener. +func waitForSingleClient(ctx context.Context, authenticatedUsersAccess uint32) (net.Conn, error) { + l, err := winio.ListenPipe(PipePath, &winio.PipeConfig{ + SecurityDescriptor: makePipeServerSecurityDescriptor(authenticatedUsersAccess), + }) + if err != nil { + return nil, trace.Wrap(err) + } + + resCh := make(chan acceptResult, 1) + + go func() { + conn, acceptErr := l.Accept() + resCh <- acceptResult{conn: conn, err: acceptErr} + }() + + select { + case <-ctx.Done(): + err = l.Close() + // Drain the goroutine — l.Close() unblocks Accept(). + res := <-resCh + if res.conn != nil { + _ = res.conn.Close() + } + return nil, trace.NewAggregate(ctx.Err(), err) + case res := <-resCh: + if res.err != nil { + return nil, trace.Wrap(res.err) + } + if err = l.Close(); err != nil { + return nil, trace.NewAggregate(err, res.conn.Close()) + } + return res.conn, nil + } +} + +// getSecureUpdateDir secures %ProgramData%\TeleportConnectUpdater directory and then returns +// a unique %ProgramData%\TeleportConnectUpdater\ path. +func (h *handler) getSecureUpdateDir() (string, error) { + updateRoot := h.testCfg.UpdateBaseDir + if updateRoot == "" { + programData, err := windows.KnownFolderPath(windows.FOLDERID_ProgramData, 0) + if err != nil { + return "", trace.Wrap(err, "reading ProgramData path") + } + updateRoot = filepath.Join(programData, "TeleportConnectUpdater") + } + + descriptor := updateDirSecurityDescriptor + if h.testCfg.UpdateDirSecurityDescriptor != "" { + descriptor = h.testCfg.UpdateDirSecurityDescriptor + } + sd, err := windows.SecurityDescriptorFromString(descriptor) + if err != nil { + return "", trace.Wrap(err, "creating security descriptor") + } + + sa := &windows.SecurityAttributes{ + Length: uint32(unsafe.Sizeof(windows.SecurityAttributes{})), + SecurityDescriptor: sd, + InheritHandle: 0, + } + + if err = ensureDirIsSecure(updateRoot, sa); err != nil { + return "", trace.Wrap(err, "securing TeleportConnectUpdater directory") + } + + err = cleanupOldUpdates(updateRoot) + if err != nil { + return "", trace.Wrap(err, "cleaning up old updates") + } + + // Create a per-update random directory. This prevents DLL planting attacks, as the update is executed from its own directory. + newGUID := uuid.New().String() + updateDir := filepath.Join(updateRoot, newGUID) + updateDirPtr, err := windows.UTF16PtrFromString(updateDir) + if err != nil { + return "", trace.Wrap(err) + } + + if err = windows.CreateDirectory(updateDirPtr, sa); err != nil { + return "", trace.Wrap(err, "failed to create update dir") + } + + return updateDir, nil +} + +// ensureDirIsSecure guarantees that the directory exists and is locked down to SYSTEM/Admins only. +func ensureDirIsSecure(dir string, sa *windows.SecurityAttributes) error { + namePtr, err := windows.UTF16PtrFromString(dir) + if err != nil { + return trace.Wrap(err) + } + + // Try to create the directory with the secure ACLs immediately. + err = windows.CreateDirectory(namePtr, sa) + // If the directory exists, continue with verification and reapply the ACLs. + if err != nil && !errors.Is(err, windows.ERROR_ALREADY_EXISTS) { + return trace.Wrap(err, "creating directory") + } + + // If the directory exists, open a handle with DACL modification rights + // We use FILE_FLAG_OPEN_REPARSE_POINT to ensure we open the directory itself, + // not a target it might point to (it could be a junction). + dirHandle, err := windows.CreateFile( + namePtr, + windows.READ_CONTROL|windows.WRITE_DAC|windows.WRITE_OWNER, + windows.FILE_SHARE_READ|windows.FILE_SHARE_WRITE, + nil, + windows.OPEN_EXISTING, + windows.FILE_FLAG_OPEN_REPARSE_POINT|windows.FILE_FLAG_BACKUP_SEMANTICS, + 0, + ) + if err != nil { + return trace.Wrap(err, "failed to open handle to existing directory") + } + defer windows.CloseHandle(dirHandle) + + // Verify it is a real directory (not a symlink/junction) + // This prevents redirection attacks where we might unexpectedly secure a system folder. + var info windows.ByHandleFileInformation + if err = windows.GetFileInformationByHandle(dirHandle, &info); err != nil { + return trace.Wrap(err, "getting file information") + } + + if info.FileAttributes&windows.FILE_ATTRIBUTE_REPARSE_POINT != 0 { + return trace.BadParameter("security violation: %s is a reparse point", dir) + } + + if info.FileAttributes&windows.FILE_ATTRIBUTE_DIRECTORY == 0 { + return trace.BadParameter("security violation: %s exists but is not a directory", dir) + } + + owner, _, err := sa.SecurityDescriptor.Owner() + if err != nil { + return trace.Wrap(err, "reading owner from security descriptor") + } + dacl, _, err := sa.SecurityDescriptor.DACL() + if err != nil { + return trace.Wrap(err, "reading DACL from security descriptor") + } + + // Reapply directory ACLs. + err = windows.SetSecurityInfo( + dirHandle, + windows.SE_FILE_OBJECT, + // PROTECTED_DACL_SECURITY_INFORMATION stops the directory from inheriting + // "User Write" permissions from the parent (%ProgramData%). + windows.OWNER_SECURITY_INFORMATION|windows.DACL_SECURITY_INFORMATION|windows.PROTECTED_DACL_SECURITY_INFORMATION, + owner, + nil, + dacl, + nil, + ) + + return trace.Wrap(err, "resetting directory security") +} + +// cleanupOldUpdates removes stale update directories and files from the cache. +// Failures to remove individual entries are logged and ignored so cleanup can continue. +// +// This is fine, as updates are always stored in freshly generated, random subdirectories. +// This saves us from accidentally executing attacker-controlled files (e.g., planted DLLs), +// +// Important: +// This function runs with SYSTEM privileges and relies on the Go standard library’s +// os.RemoveAll implementation on Windows. It detects reparse points (symlinks and +// junctions) and removes the link itself without ever recursing into the target, +// mitigating junction/symlink crossing attacks. +func cleanupOldUpdates(baseDir string) error { + entries, err := os.ReadDir(baseDir) + if err != nil { + return trace.Wrap(err) + } + for _, entry := range entries { + fullPath := filepath.Join(baseDir, entry.Name()) + + err = os.RemoveAll(fullPath) + if err != nil { + log.Error("Failed to remove old update file", "path", fullPath, "error", err) + } + } + return nil +} + +func ensureIsUpgrade(updateVersion string) error { + updateSemver, err := semver.NewVersion(updateVersion) + if err != nil { + return trace.Wrap(err, "invalid update version %q", updateVersion) + } + current := teleport.SemVer() + if current == nil { + return trace.BadParameter("current version is not available") + } + if updateSemver.Compare(*current) <= 0 { + return trace.BadParameter("update version %s is not newer than current version %s", updateSemver, current) + } + return nil +} + +func (h *handler) downloadChecksum(ctx context.Context, baseUrl string, version string) ([]byte, error) { + parsedBaseURL, err := url.Parse(baseUrl) + if err != nil { + return nil, trace.Wrap(err, "parsing base URL") + } + // Keep updater policy aligned with Service.GetConfig RPC validation and reject non-TLS CDNs even if this path is called outside the UI flow. + if parsedBaseURL.Scheme != "https" { + return nil, trace.BadParameter("CDN base URL must be https") + } + filename := fmt.Sprintf("Teleport Connect Setup-%s.exe.sha256", version) + downloadURL := parsedBaseURL.JoinPath(filename) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, downloadURL.String(), nil) + if err != nil { + return nil, trace.Wrap(err) + } + client := http.DefaultClient + if h.testCfg.HTTPClient != nil { + client = h.testCfg.HTTPClient + } + resp, err := client.Do(req) + if err != nil { + return nil, trace.Wrap(err) + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + return nil, trace.BadParameter("update hash request failed with status %s", resp.Status) + } + + var buf bytes.Buffer + _, err = io.CopyN(&buf, resp.Body, sha256.Size*2) // SHA bytes to hex + if err != nil { + return nil, trace.Wrap(err) + } + hexBytes, err := hex.DecodeString(buf.String()) + if err != nil { + return nil, trace.Wrap(err) + } + return hexBytes, nil +} + +func verifyUpdateChecksum(updatePath string, expectedHash []byte) error { + file, err := os.Open(updatePath) + if err != nil { + return trace.Wrap(err) + } + defer file.Close() + + hasher := sha256.New() + if _, err = io.Copy(hasher, file); err != nil { + return trace.Wrap(err) + } + actual := hasher.Sum(nil) + if !bytes.Equal(actual, expectedHash) { + return trace.BadParameter("hash of the update does not match downloaded checksum") + } + return nil +} + +func runInstaller(updatePath string, forceRun bool) error { + args := []string{"--updated", "/S", "/allusers"} + if forceRun { + args = append(args, "--force-run") + } + cmd := exec.Command(updatePath, args...) + + log.Info("Running command", "command", cmd.String()) + + err := cmd.Start() + if err != nil { + return trace.Wrap(err, "starting installer path=%s args=%q", updatePath, strings.Join(args, " ")) + } + + // Release the handle so the parent process can exit and the installer will continue. + return trace.Wrap(cmd.Process.Release()) +} diff --git a/lib/teleterm/autoupdate/service.go b/lib/teleterm/autoupdate/service.go index 093d94c8a4fbe..9177419c5e8d8 100644 --- a/lib/teleterm/autoupdate/service.go +++ b/lib/teleterm/autoupdate/service.go @@ -31,7 +31,7 @@ import ( "github.com/gravitational/teleport/api/client/webclient" api "github.com/gravitational/teleport/gen/proto/go/teleport/lib/teleterm/auto_update/v1" "github.com/gravitational/teleport/lib/autoupdate" - "github.com/gravitational/teleport/lib/modules" + "github.com/gravitational/teleport/lib/teleterm/autoupdate/common" "github.com/gravitational/teleport/lib/teleterm/clusters" ) @@ -42,7 +42,6 @@ const ( // launching tsh with TELEPORT_TOOLS_VERSION=off, and forwards the real value via // FORWARDED_TELEPORT_TOOLS_VERSION. forwardedTeleportToolsEnvVar = "FORWARDED_TELEPORT_TOOLS_VERSION" - teleportToolsVersionOff = "off" ) // Service implements gRPC service for autoupdate. @@ -154,7 +153,7 @@ func (s *Service) GetConfig(_ context.Context, _ *api.GetConfigRequest) (*api.Ge switch toolsVersionValue { case "": toolsVersionSource = api.ConfigSource_CONFIG_SOURCE_UNSPECIFIED - case teleportToolsVersionOff: + case common.TeleportToolsVersionOff: break default: if _, err = semver.NewVersion(toolsVersionValue); err != nil { @@ -172,10 +171,9 @@ func (s *Service) GetConfig(_ context.Context, _ *api.GetConfigRequest) (*api.Ge } } - m := modules.GetModules() - // Uses the same logic as the teleport/lib/autoupdate/tools package. - if cdnBaseUrlValue == "" && m.BuildType() != modules.BuildOSS { - cdnBaseUrlValue = autoupdate.DefaultBaseURL + defaultBaseUrlValue := common.GetDefaultBaseURL() + if cdnBaseUrlValue == "" && defaultBaseUrlValue != "" { + cdnBaseUrlValue = defaultBaseUrlValue cdnBaseUrlSource = api.ConfigSource_CONFIG_SOURCE_DEFAULT } diff --git a/lib/teleterm/autoupdate/service_windows.go b/lib/teleterm/autoupdate/service_windows.go index 09eae1e6d9e2f..49d59b157120c 100644 --- a/lib/teleterm/autoupdate/service_windows.go +++ b/lib/teleterm/autoupdate/service_windows.go @@ -18,7 +18,6 @@ package autoupdate import ( "context" - "errors" "os" "path/filepath" @@ -26,6 +25,7 @@ import ( "golang.org/x/sys/windows/registry" api "github.com/gravitational/teleport/gen/proto/go/teleport/lib/teleterm/auto_update/v1" + "github.com/gravitational/teleport/lib/teleterm/autoupdate/common" ) const ( @@ -33,10 +33,6 @@ const ( teleportConnectGUID = "22539266-67e8-54a3-83b9-dfdca7b33ee1" teleportConnectKeyPath = `SOFTWARE\` + teleportConnectGUID registryValueInstallLocation = "InstallLocation" - - teleportConnectPoliciesKeyPath = `SOFTWARE\Policies\Teleport\TeleportConnect` - registryValueToolsVersion = "ToolsVersion" - registryValueCDNBaseURL = "CdnBaseUrl" ) // GetInstallationMetadata returns installation metadata of the currently running app instance. @@ -56,40 +52,40 @@ func platformGetConfig() (*api.GetConfigResponse, error) { return nil, trace.Wrap(err) } - machineValues, err := readRegistryPolicyValues(registry.LOCAL_MACHINE) + machineValues, err := common.ReadRegistryPolicyValues(registry.LOCAL_MACHINE) if err != nil { return nil, trace.Wrap(err) } config := &api.GetConfigResponse{ CdnBaseUrl: &api.ConfigValue{ - Value: machineValues.cdnBaseURL, + Value: machineValues.CDNBaseURL, Source: api.ConfigSource_CONFIG_SOURCE_POLICY, }, ToolsVersion: &api.ConfigValue{ - Value: machineValues.version, + Value: machineValues.Version, Source: api.ConfigSource_CONFIG_SOURCE_POLICY, }, } // If per-machine config is fully set, there's no need to check other sources. - perMachineConfigFullySet := machineValues.cdnBaseURL != "" && machineValues.version != "" + perMachineConfigFullySet := machineValues.CDNBaseURL != "" && machineValues.Version != "" if perMachineConfigFullySet { return config, nil } if !perMachine { - userValues, err := readRegistryPolicyValues(registry.CURRENT_USER) + userValues, err := common.ReadRegistryPolicyValues(registry.CURRENT_USER) if err != nil { return nil, trace.Wrap(err) } - if machineValues.cdnBaseURL == "" { - config.CdnBaseUrl.Value = userValues.cdnBaseURL + if machineValues.CDNBaseURL == "" { + config.CdnBaseUrl.Value = userValues.CDNBaseURL } - if machineValues.version == "" { - config.ToolsVersion.Value = userValues.version + if machineValues.Version == "" { + config.ToolsVersion.Value = userValues.Version } } @@ -111,7 +107,7 @@ func platformGetConfig() (*api.GetConfigResponse, error) { } func isPerMachineInstall() (bool, error) { - perMachineLocation, err := readRegistryValue(registry.LOCAL_MACHINE, teleportConnectKeyPath, registryValueInstallLocation) + perMachineLocation, err := common.ReadRegistryValue(registry.LOCAL_MACHINE, teleportConnectKeyPath, registryValueInstallLocation) if err != nil { if trace.IsNotFound(err) { return false, nil @@ -129,51 +125,3 @@ func isPerMachineInstall() (bool, error) { return exePath == exePathInPerMachineLocation, nil } - -type policyValue struct { - cdnBaseURL string - version string -} - -func readRegistryPolicyValues(key registry.Key) (*policyValue, error) { - version, err := readRegistryValue(key, teleportConnectPoliciesKeyPath, registryValueToolsVersion) - if err != nil && !trace.IsNotFound(err) { - return nil, trace.Wrap(err) - } - - url, err := readRegistryValue(key, teleportConnectPoliciesKeyPath, registryValueCDNBaseURL) - if err != nil && !trace.IsNotFound(err) { - return nil, trace.Wrap(err) - } - - return &policyValue{ - cdnBaseURL: url, - version: version, - }, nil -} - -func readRegistryValue(hive registry.Key, pathName string, valueName string) (path string, err error) { - key, err := registry.OpenKey(hive, pathName, registry.READ) - if err != nil { - if errors.Is(err, registry.ErrNotExist) { - return "", trace.NotFound("registry key %s not found", pathName) - } - return "", trace.Wrap(err, "opening registry key %s", pathName) - } - - defer func() { - if closeErr := key.Close(); closeErr != nil && err == nil { - err = trace.Wrap(closeErr, "closing registry key %s", pathName) - } - }() - - path, _, err = key.GetStringValue(valueName) - if err != nil { - if errors.Is(err, registry.ErrNotExist) { - return "", trace.NotFound("registry value %s not found in %s", valueName, pathName) - } - return "", trace.Wrap(err, "reading registry value %s from %s", valueName, pathName) - } - - return path, nil -}