diff --git a/lib/vnet/install_service_windows.go b/lib/vnet/install_service_windows.go
index fac0d068ddef8..5699b0200c3ec 100644
--- a/lib/vnet/install_service_windows.go
+++ b/lib/vnet/install_service_windows.go
@@ -18,170 +18,41 @@ package vnet
import (
"context"
- "errors"
- "fmt"
"os"
"path/filepath"
- "strings"
"github.com/gravitational/trace"
"golang.org/x/sys/windows"
- "golang.org/x/sys/windows/svc/eventlog"
- "golang.org/x/sys/windows/svc/mgr"
- "github.com/gravitational/teleport"
- eventlogutils "github.com/gravitational/teleport/lib/utils/log/eventlog"
+ "github.com/gravitational/teleport/lib/windowsservice"
)
-// InstallService installs the VNet windows service.
-//
-// Windows services are installed by the service manager, which takes a path to
-// the service executable. So that regular users are not able to overwrite the
-// executable at that path, we use a path under %PROGRAMFILES%, which is not
-// writable by regular users by default.
-func InstallService(ctx context.Context) (err error) {
+const eventSource = "vnet"
+
+// InstallService installs the VNet Windows service.
+func InstallService(ctx context.Context) error {
tshPath, err := os.Executable()
if err != nil {
return trace.Wrap(err, "getting current exe path")
}
- if err := assertTshInProgramFiles(tshPath); err != nil {
- return trace.Wrap(err, "checking if tsh.exe is installed under %%PROGRAMFILES%%")
- }
if err := assertWintunInstalled(tshPath); err != nil {
return trace.Wrap(err, "checking if wintun.dll is installed next to %s", tshPath)
}
-
- svcMgr, err := mgr.Connect()
- if err != nil {
- return trace.Wrap(err, "connecting to Windows service manager")
- }
- svc, err := svcMgr.OpenService(serviceName)
- if err != nil {
- if !errors.Is(err, windows.ERROR_SERVICE_DOES_NOT_EXIST) {
- return trace.Wrap(err, "unexpected error checking if Windows service %s exists", serviceName)
- }
- // The service has not been created yet and must be installed.
- svc, err = svcMgr.CreateService(
- serviceName,
- tshPath,
- mgr.Config{
- StartType: mgr.StartManual,
- },
- ServiceCommand,
- )
- if err != nil {
- return trace.Wrap(err, "creating VNet Windows service")
- }
- }
- if err := svc.Close(); err != nil {
- return trace.Wrap(err, "closing VNet Windows service")
- }
- if err := grantServiceRights(); err != nil {
- return trace.Wrap(err, "granting authenticated users permission to control the VNet Windows service")
- }
- if err := installEventSource(); err != nil {
- trace.Wrap(err, "creating event source for logging")
- }
- if err := logInstallationEvent("VNet service installed"); err != nil {
- trace.Wrap(err, "logging installation event")
- }
- return nil
-}
-
-// UninstallService uninstalls the VNet windows service.
-func UninstallService(ctx context.Context) (err error) {
- svcMgr, err := mgr.Connect()
- if err != nil {
- return trace.Wrap(err, "connecting to Windows service manager")
- }
- svc, err := svcMgr.OpenService(serviceName)
- if err != nil {
- return trace.Wrap(err, "opening Windows service %s", serviceName)
- }
- if err := svc.Delete(); err != nil {
- return trace.Wrap(err, "deleting Windows service %s", serviceName)
- }
- if err := svc.Close(); err != nil {
- return trace.Wrap(err, "closing VNet Windows service")
- }
-
- if err := logInstallationEvent("VNet service uninstalled"); err != nil {
- trace.Wrap(err, "logging installation event")
- }
- if err := eventlogutils.Remove(eventlogutils.LogName, eventSource); err != nil {
- return trace.Wrap(err, "removing event source for logging")
- }
-
- return nil
-}
-
-func grantServiceRights() error {
- // Get the current security info for the service, requesting only the DACL
- // (discretionary access control list).
- si, err := windows.GetNamedSecurityInfo(serviceName, windows.SE_SERVICE, windows.DACL_SECURITY_INFORMATION)
- if err != nil {
- return trace.Wrap(err, "getting current service security information")
- }
- // Get the DACL from the security info.
- dacl, _ /*defaulted*/, err := si.DACL()
- if err != nil {
- return trace.Wrap(err, "getting current service DACL")
- }
- // This is the universal well-known SID for "Authenticated Users".
- authenticatedUsersSID, err := windows.StringToSid("S-1-5-11")
- if err != nil {
- return trace.Wrap(err, "parsing authenticated users SID")
- }
- // Build an explicit access entry allowing authenticated users to start,
- // stop, and query the service.
- ea := []windows.EXPLICIT_ACCESS{{
+ return trace.Wrap(windowsservice.Install(ctx, &windowsservice.InstallConfig{
+ Name: serviceName,
+ Description: serviceDescription,
+ Command: ServiceCommand,
+ EventSourceName: eventSource,
AccessPermissions: windows.SERVICE_QUERY_STATUS | windows.SERVICE_START | windows.SERVICE_STOP,
- AccessMode: windows.GRANT_ACCESS,
- Trustee: windows.TRUSTEE{
- TrusteeForm: windows.TRUSTEE_IS_SID,
- TrusteeType: windows.TRUSTEE_IS_WELL_KNOWN_GROUP,
- TrusteeValue: windows.TrusteeValueFromSID(authenticatedUsersSID),
- },
- }}
- // Merge the new explicit access entry with the existing DACL.
- dacl, err = windows.ACLFromEntries(ea, dacl)
- if err != nil {
- return trace.Wrap(err, "merging service DACL entries")
- }
- // Set the DACL on the service security info.
- if err := windows.SetNamedSecurityInfo(
- serviceName,
- windows.SE_SERVICE,
- windows.DACL_SECURITY_INFORMATION,
- nil, // owner
- nil, // group
- dacl, // dacl
- nil, // sacl
- ); err != nil {
- return trace.Wrap(err, "setting service DACL")
- }
- return nil
+ }))
}
-// assertTshInProgramFiles asserts that tsh is a regular file installed under
-// the program files directory (usually C:\Program Files\).
-func assertTshInProgramFiles(tshPath string) error {
- if err := assertRegularFile(tshPath); err != nil {
- return trace.Wrap(err)
- }
- programFiles := os.Getenv("PROGRAMFILES")
- if programFiles == "" {
- return trace.Errorf("PROGRAMFILES env var is not set")
- }
- // Windows file paths are case-insensitive.
- cleanedProgramFiles := strings.ToLower(filepath.Clean(programFiles)) + string(filepath.Separator)
- cleanedTshPath := strings.ToLower(filepath.Clean(tshPath))
- if !strings.HasPrefix(cleanedTshPath, cleanedProgramFiles) {
- return trace.BadParameter(
- "tsh.exe is currently installed at %s, it must be installed under %s in order to install the VNet Windows service",
- tshPath, programFiles)
- }
- return nil
+// UninstallService uninstalls the Windows VNet service.
+func UninstallService(ctx context.Context) error {
+ return trace.Wrap(windowsservice.Uninstall(ctx, &windowsservice.UninstallConfig{
+ Name: serviceName,
+ EventSourceName: eventSource,
+ }))
}
// asertWintunInstalled returns an error if wintun.dll is not a regular file
@@ -203,34 +74,3 @@ func assertRegularFile(path string) error {
}
return nil
}
-
-const eventSource = "vnet"
-
-func installEventSource() error {
- exe, err := os.Executable()
- if err != nil {
- return trace.Wrap(err)
- }
- // Assume that the message file is shipped next to tsh.exe.
- msgFilePath := filepath.Join(filepath.Dir(exe), "msgfile.dll")
-
- // This should create a registry entry under
- // SYSTEM\CurrentControlSet\Services\EventLog\Teleport\vnet with an absolute path to msgfile.dll.
- // If the user moves Teleport Connect to some other directory, logs will still be captured, but
- // they might display a message about missing event ID until the user reinstalls the app.
- err = eventlogutils.Install(eventlogutils.LogName, eventSource, msgFilePath, false /* useExpandKey */)
- return trace.Wrap(err)
-}
-
-func logInstallationEvent(eventMessage string) error {
- log, err := eventlog.Open(eventSource)
- if err != nil {
- return trace.Wrap(err, "opening logger")
- }
-
- if err := log.Info(eventlogutils.EventID, fmt.Sprintf("%s version:%s", eventMessage, teleport.Version)); err != nil {
- return trace.Wrap(err, "writing log message")
- }
-
- return trace.Wrap(log.Close(), "closing logger")
-}
diff --git a/lib/vnet/service_windows.go b/lib/vnet/service_windows.go
index d672cdfd841da..9823ef6279ab7 100644
--- a/lib/vnet/service_windows.go
+++ b/lib/vnet/service_windows.go
@@ -17,12 +17,8 @@
package vnet
import (
- "cmp"
"context"
- "errors"
"log/slog"
- "os"
- "strconv"
"syscall"
"time"
@@ -33,7 +29,7 @@ import (
"golang.org/x/sys/windows/svc/mgr"
"github.com/gravitational/teleport"
- logutils "github.com/gravitational/teleport/lib/utils/log"
+ "github.com/gravitational/teleport/lib/windowsservice"
)
const (
@@ -126,83 +122,23 @@ func startService(ctx context.Context, cfg *windowsAdminProcessConfig) (*mgr.Ser
// ServiceMain runs the Windows VNet admin service.
func ServiceMain() error {
- closeFn, err := setupServiceLogger()
+ closeLogger, err := windowsservice.InitSlogEventLogger(eventSource)
if err != nil {
- return trace.Wrap(err, "setting up logger for service")
- }
-
- if err := svc.Run(serviceName, &windowsService{}); err != nil {
- closeFn()
- return trace.Wrap(err, "running Windows service")
+ return trace.Wrap(err)
}
-
- return trace.Wrap(closeFn(), "closing logger")
-}
-
-// windowsService implements [svc.Handler].
-type windowsService struct{}
-
-// Execute implements [svc.Handler.Execute], the GoDoc is copied below.
-//
-// Execute will be called by the package code at the start of the service, and
-// the service will exit once Execute completes. Inside Execute you must read
-// service change requests from [requests] and act accordingly. You must keep
-// service control manager up to date about state of your service by writing
-// into [status] as required. args contains service name followed by argument
-// strings passed to the service.
-// You can provide service exit code in exitCode return parameter, with 0 being
-// "no error". You can also indicate if exit code, if any, is service specific
-// or not by using svcSpecificEC parameter.
-func (s *windowsService) Execute(args []string, requests <-chan svc.ChangeRequest, status chan<- svc.Status) (svcSpecificEC bool, exitCode uint32) {
logger := slog.With(teleport.ComponentKey, teleport.Component("vnet", "windows-service"))
- const cmdsAccepted = svc.AcceptStop // Interrogate is always accepted and there is no const for it.
- status <- svc.Status{State: svc.Running, Accepts: cmdsAccepted}
-
- ctx, cancel := context.WithCancel(context.Background())
- defer cancel()
- errCh := make(chan error)
- go func() { errCh <- s.run(ctx, args) }()
- var terminateTimedOut <-chan time.Time
-loop:
- for {
- select {
- case request := <-requests:
- switch request.Cmd {
- case svc.Interrogate:
- state := svc.Running
- if ctx.Err() != nil {
- state = svc.StopPending
- }
- status <- svc.Status{State: state, Accepts: cmdsAccepted}
- case svc.Stop:
- logger.InfoContext(ctx, "Received stop command, shutting down service")
- // Cancel the context passed to s.run to terminate the
- // networking stack.
- cancel()
- terminateTimedOut = cmp.Or(terminateTimedOut, time.After(terminateTimeout))
- status <- svc.Status{State: svc.StopPending}
- }
- case <-terminateTimedOut:
- logger.ErrorContext(ctx, "Networking stack failed to terminate within timeout, exiting process",
- slog.Duration("timeout", terminateTimeout))
- exitCode = 1
- break loop
- case err := <-errCh:
- if err == nil || errors.Is(err, context.Canceled) {
- logger.InfoContext(ctx, "Service terminated")
- } else {
- logger.ErrorContext(ctx, "Service terminated", "error", err)
- exitCode = 1
- }
- break loop
- }
- }
- status <- svc.Status{State: svc.Stopped, Win32ExitCode: exitCode}
- return false, exitCode
+ err = windowsservice.Run(&windowsservice.RunConfig{
+ Name: serviceName,
+ Handler: &handler{},
+ Logger: logger,
+ })
+ return trace.NewAggregate(err, closeLogger())
}
-func (s *windowsService) run(ctx context.Context, args []string) error {
+type handler struct{}
+
+func (w *handler) Execute(ctx context.Context, args []string) error {
var cfg windowsAdminProcessConfig
app := kingpin.New(serviceName, "Teleport VNet Windows Service")
serviceCmd := app.Command("vnet-service", "Start the VNet service.")
@@ -221,23 +157,3 @@ func (s *windowsService) run(ctx context.Context, args []string) error {
}
return nil
}
-
-func setupServiceLogger() (func() error, error) {
- level := slog.LevelInfo
- if envVar := os.Getenv(teleport.VerboseLogsEnvVar); envVar != "" {
- isDebug, err := strconv.ParseBool(envVar)
- if err != nil {
- return nil, trace.Wrap(err, "parsing %s", teleport.VerboseLogsEnvVar)
- }
- if isDebug {
- level = slog.LevelDebug
- }
- }
-
- handler, close, err := logutils.NewSlogEventLogHandler("vnet", level)
- if err != nil {
- return nil, trace.Wrap(err, "initializing log handler")
- }
- slog.SetDefault(slog.New(handler))
- return close, nil
-}
diff --git a/lib/windowsservice/install_windows.go b/lib/windowsservice/install_windows.go
new file mode 100644
index 0000000000000..546791d1b229d
--- /dev/null
+++ b/lib/windowsservice/install_windows.go
@@ -0,0 +1,267 @@
+// Teleport
+// Copyright (C) 2025 Gravitational, Inc.
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see .
+
+package windowsservice
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "os"
+ "path/filepath"
+ "strings"
+
+ "github.com/gravitational/trace"
+ "golang.org/x/sys/windows"
+ "golang.org/x/sys/windows/svc/eventlog"
+ "golang.org/x/sys/windows/svc/mgr"
+
+ "github.com/gravitational/teleport"
+ eventlogutils "github.com/gravitational/teleport/lib/utils/log/eventlog"
+)
+
+// InstallConfig defines parameters for installing a Windows service
+// that is implemented by tsh.exe.
+type InstallConfig struct {
+ // Name is the service name.
+ Name string
+ // Description is the service description.
+ Description string
+ // Command is the tsh subcommand that the service manager invokes on start.
+ Command string
+ // EventSourceName is the name of an event source that will log service events.
+ EventSourceName string
+ // AccessPermissions defines which service control actions are granted to
+ // authenticated users (e.g., start/stop/query).
+ AccessPermissions windows.ACCESS_MASK
+}
+
+// Install installs a Windows service implemented by tsh.exe.
+//
+// Windows services are installed by the service manager, which takes a path to
+// the service executable. So that regular users are not able to overwrite the
+// executable at that path, we use a path under %PROGRAMFILES%, which is not
+// writable by regular users by default.
+func Install(ctx context.Context, cfg *InstallConfig) (err error) {
+ if cfg.Name == "" {
+ return trace.BadParameter("service name is required")
+ }
+ if cfg.Command == "" {
+ return trace.BadParameter("command is required")
+ }
+ if cfg.EventSourceName == "" {
+ return trace.BadParameter("event source name is required")
+ }
+ if cfg.AccessPermissions == 0 {
+ return trace.BadParameter("access permissions is required")
+ }
+
+ tshPath, err := os.Executable()
+ if err != nil {
+ return trace.Wrap(err, "getting current exe path")
+ }
+ if err := assertTshInProgramFiles(tshPath); err != nil {
+ return trace.Wrap(err, "checking if tsh.exe is installed under %%PROGRAMFILES%%")
+ }
+
+ svcMgr, err := mgr.Connect()
+ if err != nil {
+ return trace.Wrap(err, "connecting to Windows service manager")
+ }
+ svc, err := svcMgr.OpenService(cfg.Name)
+ if err != nil {
+ if !errors.Is(err, windows.ERROR_SERVICE_DOES_NOT_EXIST) {
+ return trace.Wrap(err, "unexpected error checking if Windows service %s exists", cfg.Name)
+ }
+ // The service has not been created yet and must be installed.
+ svc, err = svcMgr.CreateService(
+ cfg.Name,
+ tshPath,
+ mgr.Config{
+ StartType: mgr.StartManual,
+ Description: cfg.Description,
+ },
+ cfg.Command,
+ )
+ if err != nil {
+ return trace.Wrap(err, "creating VNet Windows service")
+ }
+ }
+ if err := svc.Close(); err != nil {
+ return trace.Wrap(err, "closing VNet Windows service")
+ }
+ if err := grantServiceRights(cfg.Name, cfg.AccessPermissions); err != nil {
+ return trace.Wrap(err, "granting authenticated users permission to control the VNet Windows service")
+ }
+ if err := installEventSource(cfg.EventSourceName); err != nil {
+ return trace.Wrap(err, "creating event source for logging")
+ }
+ if err := logInstallationEvent(cfg.EventSourceName, fmt.Sprintf("%s service installed", cfg.Name)); err != nil {
+ return trace.Wrap(err, "logging installation event")
+ }
+ return nil
+}
+
+// UninstallConfig defines parameters for removing a Windows service.
+type UninstallConfig struct {
+ // Name is the service name.
+ Name string
+ // EventSourceName is the event source to remove from the Windows Event Log.
+ EventSourceName string
+}
+
+// Uninstall uninstalls the Windows service.
+func Uninstall(ctx context.Context, cfg *UninstallConfig) (err error) {
+ if cfg.Name == "" {
+ return trace.BadParameter("service name is required")
+ }
+ if cfg.EventSourceName == "" {
+ return trace.BadParameter("event source name is required")
+ }
+ svcMgr, err := mgr.Connect()
+ if err != nil {
+ return trace.Wrap(err, "connecting to Windows service manager")
+ }
+ svc, err := svcMgr.OpenService(cfg.Name)
+ if err != nil {
+ return trace.Wrap(err, "opening Windows service %s", cfg.Name)
+ }
+ if err := svc.Delete(); err != nil {
+ return trace.Wrap(err, "deleting Windows service %s", cfg.Name)
+ }
+ if err := svc.Close(); err != nil {
+ return trace.Wrap(err, "closing VNet Windows service")
+ }
+
+ if err := logInstallationEvent(cfg.EventSourceName, fmt.Sprintf("%s service uninstalled", cfg.Name)); err != nil {
+ return trace.Wrap(err, "logging installation event")
+ }
+ if err := eventlogutils.Remove(eventlogutils.LogName, cfg.EventSourceName); err != nil {
+ return trace.Wrap(err, "removing event source for logging")
+ }
+
+ return nil
+}
+
+func grantServiceRights(name string, accessPermissions windows.ACCESS_MASK) error {
+ // Get the current security info for the service, requesting only the DACL
+ // (discretionary access control list).
+ si, err := windows.GetNamedSecurityInfo(name, windows.SE_SERVICE, windows.DACL_SECURITY_INFORMATION)
+ if err != nil {
+ return trace.Wrap(err, "getting current service security information")
+ }
+ // Get the DACL from the security info.
+ dacl, _ /*defaulted*/, err := si.DACL()
+ if err != nil {
+ return trace.Wrap(err, "getting current service DACL")
+ }
+ // This is the universal well-known SID for "Authenticated Users".
+ authenticatedUsersSID, err := windows.StringToSid("S-1-5-11")
+ if err != nil {
+ return trace.Wrap(err, "parsing authenticated users SID")
+ }
+ // Build an explicit access entry allowing authenticated users to start,
+ // stop, and query the service.
+ ea := []windows.EXPLICIT_ACCESS{{
+ AccessPermissions: accessPermissions,
+ AccessMode: windows.GRANT_ACCESS,
+ Trustee: windows.TRUSTEE{
+ TrusteeForm: windows.TRUSTEE_IS_SID,
+ TrusteeType: windows.TRUSTEE_IS_WELL_KNOWN_GROUP,
+ TrusteeValue: windows.TrusteeValueFromSID(authenticatedUsersSID),
+ },
+ }}
+ // Merge the new explicit access entry with the existing DACL.
+ dacl, err = windows.ACLFromEntries(ea, dacl)
+ if err != nil {
+ return trace.Wrap(err, "merging service DACL entries")
+ }
+ // Set the DACL on the service security info.
+ if err := windows.SetNamedSecurityInfo(
+ name,
+ windows.SE_SERVICE,
+ windows.DACL_SECURITY_INFORMATION,
+ nil, // owner
+ nil, // group
+ dacl, // dacl
+ nil, // sacl
+ ); err != nil {
+ return trace.Wrap(err, "setting service DACL")
+ }
+ return nil
+}
+
+// assertTshInProgramFiles asserts that tsh is a regular file installed under
+// the program files directory (usually C:\Program Files\).
+func assertTshInProgramFiles(tshPath string) error {
+ if err := assertRegularFile(tshPath); err != nil {
+ return trace.Wrap(err)
+ }
+ programFiles := os.Getenv("PROGRAMFILES")
+ if programFiles == "" {
+ return trace.Errorf("PROGRAMFILES env var is not set")
+ }
+ // Windows file paths are case-insensitive.
+ cleanedProgramFiles := strings.ToLower(filepath.Clean(programFiles)) + string(filepath.Separator)
+ cleanedTshPath := strings.ToLower(filepath.Clean(tshPath))
+ if !strings.HasPrefix(cleanedTshPath, cleanedProgramFiles) {
+ return trace.BadParameter(
+ "tsh.exe is currently installed at %s, it must be installed under %s in order to install the VNet Windows service",
+ tshPath, programFiles)
+ }
+ return nil
+}
+
+func assertRegularFile(path string) error {
+ switch info, err := os.Lstat(path); {
+ case os.IsNotExist(err):
+ return trace.Wrap(err, "%s not found", path)
+ case err != nil:
+ return trace.Wrap(err, "unexpected error checking %s", path)
+ case !info.Mode().IsRegular():
+ return trace.BadParameter("%s is not a regular file", path)
+ }
+ return nil
+}
+
+func installEventSource(name string) error {
+ exe, err := os.Executable()
+ if err != nil {
+ return trace.Wrap(err)
+ }
+ // Assume that the message file is shipped next to tsh.exe.
+ msgFilePath := filepath.Join(filepath.Dir(exe), "msgfile.dll")
+
+ // This should create a registry entry under
+ // SYSTEM\CurrentControlSet\Services\EventLog\Teleport\ with an absolute path to msgfile.dll.
+ // If the user moves Teleport Connect to some other directory, logs will still be captured, but
+ // they might display a message about missing event ID until the user reinstalls the app.
+ err = eventlogutils.Install(eventlogutils.LogName, name, msgFilePath, false /* useExpandKey */)
+ return trace.Wrap(err)
+}
+
+func logInstallationEvent(name string, eventMessage string) error {
+ log, err := eventlog.Open(name)
+ if err != nil {
+ return trace.Wrap(err, "opening logger")
+ }
+
+ if err := log.Info(eventlogutils.EventID, fmt.Sprintf("%s version:%s", eventMessage, teleport.Version)); err != nil {
+ return trace.Wrap(err, "writing log message")
+ }
+
+ return trace.Wrap(log.Close(), "closing logger")
+}
diff --git a/lib/windowsservice/run_windows.go b/lib/windowsservice/run_windows.go
new file mode 100644
index 0000000000000..896febba00dbe
--- /dev/null
+++ b/lib/windowsservice/run_windows.go
@@ -0,0 +1,153 @@
+// Teleport
+// Copyright (C) 2026 Gravitational, Inc.
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see .
+
+package windowsservice
+
+import (
+ "cmp"
+ "context"
+ "errors"
+ "log/slog"
+ "os"
+ "strconv"
+ "time"
+
+ "github.com/gravitational/trace"
+ "golang.org/x/sys/windows/svc"
+
+ "github.com/gravitational/teleport"
+ logutils "github.com/gravitational/teleport/lib/utils/log"
+)
+
+const defaultTerminateTimeout = 30 * time.Second
+
+// ServiceHandler abstracts the core service workload behind a Windows service.
+type ServiceHandler interface {
+ // Execute will be called by the package code at the start of
+ // the service, and the service will exit once Execute completes.
+ Execute(ctx context.Context, args []string) error
+}
+
+// RunConfig defines the inputs for running a Windows service.
+type RunConfig struct {
+ // Name is the Windows service name registered with the SCM.
+ Name string
+ // Handler runs the service workload.
+ Handler ServiceHandler
+ // Logger is logger for the service.
+ Logger *slog.Logger
+ // TerminateTimeout bounds how long the service waits for shutdown.
+ // If zero, a default timeout is used.
+ TerminateTimeout time.Duration
+}
+
+// runner wires a handler into the Windows service lifecycle.
+type runner struct {
+ handler ServiceHandler
+ logger *slog.Logger
+ terminateTimeout time.Duration
+}
+
+// InitSlogEventLogger sets up a new slog handler that writes to the Windows Event Log as source.
+func InitSlogEventLogger(source string) (func() error, error) {
+ level := slog.LevelInfo
+ if envVar := os.Getenv(teleport.VerboseLogsEnvVar); envVar != "" {
+ isDebug, err := strconv.ParseBool(envVar)
+ if err != nil {
+ return nil, trace.Wrap(err, "parsing %s", teleport.VerboseLogsEnvVar)
+ }
+ if isDebug {
+ level = slog.LevelDebug
+ }
+ }
+
+ handler, close, err := logutils.NewSlogEventLogHandler(source, level)
+ if err != nil {
+ return nil, trace.Wrap(err, "initializing log handler")
+ }
+ slog.SetDefault(slog.New(handler))
+ return close, nil
+}
+
+// Run wires logging, runs the service, and closes logging resources.
+func Run(cfg *RunConfig) error {
+ if cfg.Name == "" {
+ return trace.BadParameter("service name is required")
+ }
+ if cfg.Handler == nil {
+ return trace.BadParameter("handler is required")
+ }
+
+ terminateTimeout := cfg.TerminateTimeout
+ if terminateTimeout == 0 {
+ terminateTimeout = defaultTerminateTimeout
+ }
+
+ err := svc.Run(cfg.Name, &runner{
+ handler: cfg.Handler,
+ logger: cfg.Logger,
+ terminateTimeout: terminateTimeout,
+ })
+ return trace.Wrap(err, "running Windows service")
+}
+
+// Execute implements [svc.Handler.Execute].
+func (s *runner) Execute(args []string, requests <-chan svc.ChangeRequest, status chan<- svc.Status) (svcSpecificEC bool, exitCode uint32) {
+ const cmdsAccepted = svc.AcceptStop // Interrogate is always accepted and there is no const for it.
+ status <- svc.Status{State: svc.Running, Accepts: cmdsAccepted}
+
+ ctx, cancel := context.WithCancel(context.Background())
+ defer cancel()
+ errCh := make(chan error)
+ go func() { errCh <- s.handler.Execute(ctx, args) }()
+
+ var terminateTimedOut <-chan time.Time
+loop:
+ for {
+ select {
+ case request := <-requests:
+ switch request.Cmd {
+ case svc.Interrogate:
+ state := svc.Running
+ if ctx.Err() != nil {
+ state = svc.StopPending
+ }
+ status <- svc.Status{State: state, Accepts: cmdsAccepted}
+ case svc.Stop:
+ s.logger.InfoContext(ctx, "Received stop command, shutting down service")
+ // Cancel the context passed to s.handler.Execute to terminate the service.
+ cancel()
+ terminateTimedOut = cmp.Or(terminateTimedOut, time.After(s.terminateTimeout))
+ status <- svc.Status{State: svc.StopPending}
+ }
+ case <-terminateTimedOut:
+ s.logger.ErrorContext(ctx, "Service failed to terminate within timeout, exiting process",
+ slog.Duration("timeout", s.terminateTimeout))
+ exitCode = 1
+ break loop
+ case err := <-errCh:
+ if err == nil || errors.Is(err, context.Canceled) {
+ s.logger.InfoContext(ctx, "Service terminated")
+ } else {
+ s.logger.ErrorContext(ctx, "Service terminated", "error", err)
+ exitCode = 1
+ }
+ break loop
+ }
+ }
+ status <- svc.Status{State: svc.Stopped, Win32ExitCode: exitCode}
+ return false, exitCode
+}