diff --git a/lib/vnet/install_service_windows.go b/lib/vnet/install_service_windows.go index fac0d068ddef8..5699b0200c3ec 100644 --- a/lib/vnet/install_service_windows.go +++ b/lib/vnet/install_service_windows.go @@ -18,170 +18,41 @@ package vnet import ( "context" - "errors" - "fmt" "os" "path/filepath" - "strings" "github.com/gravitational/trace" "golang.org/x/sys/windows" - "golang.org/x/sys/windows/svc/eventlog" - "golang.org/x/sys/windows/svc/mgr" - "github.com/gravitational/teleport" - eventlogutils "github.com/gravitational/teleport/lib/utils/log/eventlog" + "github.com/gravitational/teleport/lib/windowsservice" ) -// InstallService installs the VNet windows service. -// -// Windows services are installed by the service manager, which takes a path to -// the service executable. So that regular users are not able to overwrite the -// executable at that path, we use a path under %PROGRAMFILES%, which is not -// writable by regular users by default. -func InstallService(ctx context.Context) (err error) { +const eventSource = "vnet" + +// InstallService installs the VNet Windows service. +func InstallService(ctx context.Context) error { tshPath, err := os.Executable() if err != nil { return trace.Wrap(err, "getting current exe path") } - if err := assertTshInProgramFiles(tshPath); err != nil { - return trace.Wrap(err, "checking if tsh.exe is installed under %%PROGRAMFILES%%") - } if err := assertWintunInstalled(tshPath); err != nil { return trace.Wrap(err, "checking if wintun.dll is installed next to %s", tshPath) } - - svcMgr, err := mgr.Connect() - if err != nil { - return trace.Wrap(err, "connecting to Windows service manager") - } - svc, err := svcMgr.OpenService(serviceName) - if err != nil { - if !errors.Is(err, windows.ERROR_SERVICE_DOES_NOT_EXIST) { - return trace.Wrap(err, "unexpected error checking if Windows service %s exists", serviceName) - } - // The service has not been created yet and must be installed. - svc, err = svcMgr.CreateService( - serviceName, - tshPath, - mgr.Config{ - StartType: mgr.StartManual, - }, - ServiceCommand, - ) - if err != nil { - return trace.Wrap(err, "creating VNet Windows service") - } - } - if err := svc.Close(); err != nil { - return trace.Wrap(err, "closing VNet Windows service") - } - if err := grantServiceRights(); err != nil { - return trace.Wrap(err, "granting authenticated users permission to control the VNet Windows service") - } - if err := installEventSource(); err != nil { - trace.Wrap(err, "creating event source for logging") - } - if err := logInstallationEvent("VNet service installed"); err != nil { - trace.Wrap(err, "logging installation event") - } - return nil -} - -// UninstallService uninstalls the VNet windows service. -func UninstallService(ctx context.Context) (err error) { - svcMgr, err := mgr.Connect() - if err != nil { - return trace.Wrap(err, "connecting to Windows service manager") - } - svc, err := svcMgr.OpenService(serviceName) - if err != nil { - return trace.Wrap(err, "opening Windows service %s", serviceName) - } - if err := svc.Delete(); err != nil { - return trace.Wrap(err, "deleting Windows service %s", serviceName) - } - if err := svc.Close(); err != nil { - return trace.Wrap(err, "closing VNet Windows service") - } - - if err := logInstallationEvent("VNet service uninstalled"); err != nil { - trace.Wrap(err, "logging installation event") - } - if err := eventlogutils.Remove(eventlogutils.LogName, eventSource); err != nil { - return trace.Wrap(err, "removing event source for logging") - } - - return nil -} - -func grantServiceRights() error { - // Get the current security info for the service, requesting only the DACL - // (discretionary access control list). - si, err := windows.GetNamedSecurityInfo(serviceName, windows.SE_SERVICE, windows.DACL_SECURITY_INFORMATION) - if err != nil { - return trace.Wrap(err, "getting current service security information") - } - // Get the DACL from the security info. - dacl, _ /*defaulted*/, err := si.DACL() - if err != nil { - return trace.Wrap(err, "getting current service DACL") - } - // This is the universal well-known SID for "Authenticated Users". - authenticatedUsersSID, err := windows.StringToSid("S-1-5-11") - if err != nil { - return trace.Wrap(err, "parsing authenticated users SID") - } - // Build an explicit access entry allowing authenticated users to start, - // stop, and query the service. - ea := []windows.EXPLICIT_ACCESS{{ + return trace.Wrap(windowsservice.Install(ctx, &windowsservice.InstallConfig{ + Name: serviceName, + Description: serviceDescription, + Command: ServiceCommand, + EventSourceName: eventSource, AccessPermissions: windows.SERVICE_QUERY_STATUS | windows.SERVICE_START | windows.SERVICE_STOP, - AccessMode: windows.GRANT_ACCESS, - Trustee: windows.TRUSTEE{ - TrusteeForm: windows.TRUSTEE_IS_SID, - TrusteeType: windows.TRUSTEE_IS_WELL_KNOWN_GROUP, - TrusteeValue: windows.TrusteeValueFromSID(authenticatedUsersSID), - }, - }} - // Merge the new explicit access entry with the existing DACL. - dacl, err = windows.ACLFromEntries(ea, dacl) - if err != nil { - return trace.Wrap(err, "merging service DACL entries") - } - // Set the DACL on the service security info. - if err := windows.SetNamedSecurityInfo( - serviceName, - windows.SE_SERVICE, - windows.DACL_SECURITY_INFORMATION, - nil, // owner - nil, // group - dacl, // dacl - nil, // sacl - ); err != nil { - return trace.Wrap(err, "setting service DACL") - } - return nil + })) } -// assertTshInProgramFiles asserts that tsh is a regular file installed under -// the program files directory (usually C:\Program Files\). -func assertTshInProgramFiles(tshPath string) error { - if err := assertRegularFile(tshPath); err != nil { - return trace.Wrap(err) - } - programFiles := os.Getenv("PROGRAMFILES") - if programFiles == "" { - return trace.Errorf("PROGRAMFILES env var is not set") - } - // Windows file paths are case-insensitive. - cleanedProgramFiles := strings.ToLower(filepath.Clean(programFiles)) + string(filepath.Separator) - cleanedTshPath := strings.ToLower(filepath.Clean(tshPath)) - if !strings.HasPrefix(cleanedTshPath, cleanedProgramFiles) { - return trace.BadParameter( - "tsh.exe is currently installed at %s, it must be installed under %s in order to install the VNet Windows service", - tshPath, programFiles) - } - return nil +// UninstallService uninstalls the Windows VNet service. +func UninstallService(ctx context.Context) error { + return trace.Wrap(windowsservice.Uninstall(ctx, &windowsservice.UninstallConfig{ + Name: serviceName, + EventSourceName: eventSource, + })) } // asertWintunInstalled returns an error if wintun.dll is not a regular file @@ -203,34 +74,3 @@ func assertRegularFile(path string) error { } return nil } - -const eventSource = "vnet" - -func installEventSource() error { - exe, err := os.Executable() - if err != nil { - return trace.Wrap(err) - } - // Assume that the message file is shipped next to tsh.exe. - msgFilePath := filepath.Join(filepath.Dir(exe), "msgfile.dll") - - // This should create a registry entry under - // SYSTEM\CurrentControlSet\Services\EventLog\Teleport\vnet with an absolute path to msgfile.dll. - // If the user moves Teleport Connect to some other directory, logs will still be captured, but - // they might display a message about missing event ID until the user reinstalls the app. - err = eventlogutils.Install(eventlogutils.LogName, eventSource, msgFilePath, false /* useExpandKey */) - return trace.Wrap(err) -} - -func logInstallationEvent(eventMessage string) error { - log, err := eventlog.Open(eventSource) - if err != nil { - return trace.Wrap(err, "opening logger") - } - - if err := log.Info(eventlogutils.EventID, fmt.Sprintf("%s version:%s", eventMessage, teleport.Version)); err != nil { - return trace.Wrap(err, "writing log message") - } - - return trace.Wrap(log.Close(), "closing logger") -} diff --git a/lib/vnet/service_windows.go b/lib/vnet/service_windows.go index d672cdfd841da..9823ef6279ab7 100644 --- a/lib/vnet/service_windows.go +++ b/lib/vnet/service_windows.go @@ -17,12 +17,8 @@ package vnet import ( - "cmp" "context" - "errors" "log/slog" - "os" - "strconv" "syscall" "time" @@ -33,7 +29,7 @@ import ( "golang.org/x/sys/windows/svc/mgr" "github.com/gravitational/teleport" - logutils "github.com/gravitational/teleport/lib/utils/log" + "github.com/gravitational/teleport/lib/windowsservice" ) const ( @@ -126,83 +122,23 @@ func startService(ctx context.Context, cfg *windowsAdminProcessConfig) (*mgr.Ser // ServiceMain runs the Windows VNet admin service. func ServiceMain() error { - closeFn, err := setupServiceLogger() + closeLogger, err := windowsservice.InitSlogEventLogger(eventSource) if err != nil { - return trace.Wrap(err, "setting up logger for service") - } - - if err := svc.Run(serviceName, &windowsService{}); err != nil { - closeFn() - return trace.Wrap(err, "running Windows service") + return trace.Wrap(err) } - - return trace.Wrap(closeFn(), "closing logger") -} - -// windowsService implements [svc.Handler]. -type windowsService struct{} - -// Execute implements [svc.Handler.Execute], the GoDoc is copied below. -// -// Execute will be called by the package code at the start of the service, and -// the service will exit once Execute completes. Inside Execute you must read -// service change requests from [requests] and act accordingly. You must keep -// service control manager up to date about state of your service by writing -// into [status] as required. args contains service name followed by argument -// strings passed to the service. -// You can provide service exit code in exitCode return parameter, with 0 being -// "no error". You can also indicate if exit code, if any, is service specific -// or not by using svcSpecificEC parameter. -func (s *windowsService) Execute(args []string, requests <-chan svc.ChangeRequest, status chan<- svc.Status) (svcSpecificEC bool, exitCode uint32) { logger := slog.With(teleport.ComponentKey, teleport.Component("vnet", "windows-service")) - const cmdsAccepted = svc.AcceptStop // Interrogate is always accepted and there is no const for it. - status <- svc.Status{State: svc.Running, Accepts: cmdsAccepted} - - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - errCh := make(chan error) - go func() { errCh <- s.run(ctx, args) }() - var terminateTimedOut <-chan time.Time -loop: - for { - select { - case request := <-requests: - switch request.Cmd { - case svc.Interrogate: - state := svc.Running - if ctx.Err() != nil { - state = svc.StopPending - } - status <- svc.Status{State: state, Accepts: cmdsAccepted} - case svc.Stop: - logger.InfoContext(ctx, "Received stop command, shutting down service") - // Cancel the context passed to s.run to terminate the - // networking stack. - cancel() - terminateTimedOut = cmp.Or(terminateTimedOut, time.After(terminateTimeout)) - status <- svc.Status{State: svc.StopPending} - } - case <-terminateTimedOut: - logger.ErrorContext(ctx, "Networking stack failed to terminate within timeout, exiting process", - slog.Duration("timeout", terminateTimeout)) - exitCode = 1 - break loop - case err := <-errCh: - if err == nil || errors.Is(err, context.Canceled) { - logger.InfoContext(ctx, "Service terminated") - } else { - logger.ErrorContext(ctx, "Service terminated", "error", err) - exitCode = 1 - } - break loop - } - } - status <- svc.Status{State: svc.Stopped, Win32ExitCode: exitCode} - return false, exitCode + err = windowsservice.Run(&windowsservice.RunConfig{ + Name: serviceName, + Handler: &handler{}, + Logger: logger, + }) + return trace.NewAggregate(err, closeLogger()) } -func (s *windowsService) run(ctx context.Context, args []string) error { +type handler struct{} + +func (w *handler) Execute(ctx context.Context, args []string) error { var cfg windowsAdminProcessConfig app := kingpin.New(serviceName, "Teleport VNet Windows Service") serviceCmd := app.Command("vnet-service", "Start the VNet service.") @@ -221,23 +157,3 @@ func (s *windowsService) run(ctx context.Context, args []string) error { } return nil } - -func setupServiceLogger() (func() error, error) { - level := slog.LevelInfo - if envVar := os.Getenv(teleport.VerboseLogsEnvVar); envVar != "" { - isDebug, err := strconv.ParseBool(envVar) - if err != nil { - return nil, trace.Wrap(err, "parsing %s", teleport.VerboseLogsEnvVar) - } - if isDebug { - level = slog.LevelDebug - } - } - - handler, close, err := logutils.NewSlogEventLogHandler("vnet", level) - if err != nil { - return nil, trace.Wrap(err, "initializing log handler") - } - slog.SetDefault(slog.New(handler)) - return close, nil -} diff --git a/lib/windowsservice/install_windows.go b/lib/windowsservice/install_windows.go new file mode 100644 index 0000000000000..546791d1b229d --- /dev/null +++ b/lib/windowsservice/install_windows.go @@ -0,0 +1,267 @@ +// Teleport +// Copyright (C) 2025 Gravitational, Inc. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package windowsservice + +import ( + "context" + "errors" + "fmt" + "os" + "path/filepath" + "strings" + + "github.com/gravitational/trace" + "golang.org/x/sys/windows" + "golang.org/x/sys/windows/svc/eventlog" + "golang.org/x/sys/windows/svc/mgr" + + "github.com/gravitational/teleport" + eventlogutils "github.com/gravitational/teleport/lib/utils/log/eventlog" +) + +// InstallConfig defines parameters for installing a Windows service +// that is implemented by tsh.exe. +type InstallConfig struct { + // Name is the service name. + Name string + // Description is the service description. + Description string + // Command is the tsh subcommand that the service manager invokes on start. + Command string + // EventSourceName is the name of an event source that will log service events. + EventSourceName string + // AccessPermissions defines which service control actions are granted to + // authenticated users (e.g., start/stop/query). + AccessPermissions windows.ACCESS_MASK +} + +// Install installs a Windows service implemented by tsh.exe. +// +// Windows services are installed by the service manager, which takes a path to +// the service executable. So that regular users are not able to overwrite the +// executable at that path, we use a path under %PROGRAMFILES%, which is not +// writable by regular users by default. +func Install(ctx context.Context, cfg *InstallConfig) (err error) { + if cfg.Name == "" { + return trace.BadParameter("service name is required") + } + if cfg.Command == "" { + return trace.BadParameter("command is required") + } + if cfg.EventSourceName == "" { + return trace.BadParameter("event source name is required") + } + if cfg.AccessPermissions == 0 { + return trace.BadParameter("access permissions is required") + } + + tshPath, err := os.Executable() + if err != nil { + return trace.Wrap(err, "getting current exe path") + } + if err := assertTshInProgramFiles(tshPath); err != nil { + return trace.Wrap(err, "checking if tsh.exe is installed under %%PROGRAMFILES%%") + } + + svcMgr, err := mgr.Connect() + if err != nil { + return trace.Wrap(err, "connecting to Windows service manager") + } + svc, err := svcMgr.OpenService(cfg.Name) + if err != nil { + if !errors.Is(err, windows.ERROR_SERVICE_DOES_NOT_EXIST) { + return trace.Wrap(err, "unexpected error checking if Windows service %s exists", cfg.Name) + } + // The service has not been created yet and must be installed. + svc, err = svcMgr.CreateService( + cfg.Name, + tshPath, + mgr.Config{ + StartType: mgr.StartManual, + Description: cfg.Description, + }, + cfg.Command, + ) + if err != nil { + return trace.Wrap(err, "creating VNet Windows service") + } + } + if err := svc.Close(); err != nil { + return trace.Wrap(err, "closing VNet Windows service") + } + if err := grantServiceRights(cfg.Name, cfg.AccessPermissions); err != nil { + return trace.Wrap(err, "granting authenticated users permission to control the VNet Windows service") + } + if err := installEventSource(cfg.EventSourceName); err != nil { + return trace.Wrap(err, "creating event source for logging") + } + if err := logInstallationEvent(cfg.EventSourceName, fmt.Sprintf("%s service installed", cfg.Name)); err != nil { + return trace.Wrap(err, "logging installation event") + } + return nil +} + +// UninstallConfig defines parameters for removing a Windows service. +type UninstallConfig struct { + // Name is the service name. + Name string + // EventSourceName is the event source to remove from the Windows Event Log. + EventSourceName string +} + +// Uninstall uninstalls the Windows service. +func Uninstall(ctx context.Context, cfg *UninstallConfig) (err error) { + if cfg.Name == "" { + return trace.BadParameter("service name is required") + } + if cfg.EventSourceName == "" { + return trace.BadParameter("event source name is required") + } + svcMgr, err := mgr.Connect() + if err != nil { + return trace.Wrap(err, "connecting to Windows service manager") + } + svc, err := svcMgr.OpenService(cfg.Name) + if err != nil { + return trace.Wrap(err, "opening Windows service %s", cfg.Name) + } + if err := svc.Delete(); err != nil { + return trace.Wrap(err, "deleting Windows service %s", cfg.Name) + } + if err := svc.Close(); err != nil { + return trace.Wrap(err, "closing VNet Windows service") + } + + if err := logInstallationEvent(cfg.EventSourceName, fmt.Sprintf("%s service uninstalled", cfg.Name)); err != nil { + return trace.Wrap(err, "logging installation event") + } + if err := eventlogutils.Remove(eventlogutils.LogName, cfg.EventSourceName); err != nil { + return trace.Wrap(err, "removing event source for logging") + } + + return nil +} + +func grantServiceRights(name string, accessPermissions windows.ACCESS_MASK) error { + // Get the current security info for the service, requesting only the DACL + // (discretionary access control list). + si, err := windows.GetNamedSecurityInfo(name, windows.SE_SERVICE, windows.DACL_SECURITY_INFORMATION) + if err != nil { + return trace.Wrap(err, "getting current service security information") + } + // Get the DACL from the security info. + dacl, _ /*defaulted*/, err := si.DACL() + if err != nil { + return trace.Wrap(err, "getting current service DACL") + } + // This is the universal well-known SID for "Authenticated Users". + authenticatedUsersSID, err := windows.StringToSid("S-1-5-11") + if err != nil { + return trace.Wrap(err, "parsing authenticated users SID") + } + // Build an explicit access entry allowing authenticated users to start, + // stop, and query the service. + ea := []windows.EXPLICIT_ACCESS{{ + AccessPermissions: accessPermissions, + AccessMode: windows.GRANT_ACCESS, + Trustee: windows.TRUSTEE{ + TrusteeForm: windows.TRUSTEE_IS_SID, + TrusteeType: windows.TRUSTEE_IS_WELL_KNOWN_GROUP, + TrusteeValue: windows.TrusteeValueFromSID(authenticatedUsersSID), + }, + }} + // Merge the new explicit access entry with the existing DACL. + dacl, err = windows.ACLFromEntries(ea, dacl) + if err != nil { + return trace.Wrap(err, "merging service DACL entries") + } + // Set the DACL on the service security info. + if err := windows.SetNamedSecurityInfo( + name, + windows.SE_SERVICE, + windows.DACL_SECURITY_INFORMATION, + nil, // owner + nil, // group + dacl, // dacl + nil, // sacl + ); err != nil { + return trace.Wrap(err, "setting service DACL") + } + return nil +} + +// assertTshInProgramFiles asserts that tsh is a regular file installed under +// the program files directory (usually C:\Program Files\). +func assertTshInProgramFiles(tshPath string) error { + if err := assertRegularFile(tshPath); err != nil { + return trace.Wrap(err) + } + programFiles := os.Getenv("PROGRAMFILES") + if programFiles == "" { + return trace.Errorf("PROGRAMFILES env var is not set") + } + // Windows file paths are case-insensitive. + cleanedProgramFiles := strings.ToLower(filepath.Clean(programFiles)) + string(filepath.Separator) + cleanedTshPath := strings.ToLower(filepath.Clean(tshPath)) + if !strings.HasPrefix(cleanedTshPath, cleanedProgramFiles) { + return trace.BadParameter( + "tsh.exe is currently installed at %s, it must be installed under %s in order to install the VNet Windows service", + tshPath, programFiles) + } + return nil +} + +func assertRegularFile(path string) error { + switch info, err := os.Lstat(path); { + case os.IsNotExist(err): + return trace.Wrap(err, "%s not found", path) + case err != nil: + return trace.Wrap(err, "unexpected error checking %s", path) + case !info.Mode().IsRegular(): + return trace.BadParameter("%s is not a regular file", path) + } + return nil +} + +func installEventSource(name string) error { + exe, err := os.Executable() + if err != nil { + return trace.Wrap(err) + } + // Assume that the message file is shipped next to tsh.exe. + msgFilePath := filepath.Join(filepath.Dir(exe), "msgfile.dll") + + // This should create a registry entry under + // SYSTEM\CurrentControlSet\Services\EventLog\Teleport\ with an absolute path to msgfile.dll. + // If the user moves Teleport Connect to some other directory, logs will still be captured, but + // they might display a message about missing event ID until the user reinstalls the app. + err = eventlogutils.Install(eventlogutils.LogName, name, msgFilePath, false /* useExpandKey */) + return trace.Wrap(err) +} + +func logInstallationEvent(name string, eventMessage string) error { + log, err := eventlog.Open(name) + if err != nil { + return trace.Wrap(err, "opening logger") + } + + if err := log.Info(eventlogutils.EventID, fmt.Sprintf("%s version:%s", eventMessage, teleport.Version)); err != nil { + return trace.Wrap(err, "writing log message") + } + + return trace.Wrap(log.Close(), "closing logger") +} diff --git a/lib/windowsservice/run_windows.go b/lib/windowsservice/run_windows.go new file mode 100644 index 0000000000000..896febba00dbe --- /dev/null +++ b/lib/windowsservice/run_windows.go @@ -0,0 +1,153 @@ +// Teleport +// Copyright (C) 2026 Gravitational, Inc. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package windowsservice + +import ( + "cmp" + "context" + "errors" + "log/slog" + "os" + "strconv" + "time" + + "github.com/gravitational/trace" + "golang.org/x/sys/windows/svc" + + "github.com/gravitational/teleport" + logutils "github.com/gravitational/teleport/lib/utils/log" +) + +const defaultTerminateTimeout = 30 * time.Second + +// ServiceHandler abstracts the core service workload behind a Windows service. +type ServiceHandler interface { + // Execute will be called by the package code at the start of + // the service, and the service will exit once Execute completes. + Execute(ctx context.Context, args []string) error +} + +// RunConfig defines the inputs for running a Windows service. +type RunConfig struct { + // Name is the Windows service name registered with the SCM. + Name string + // Handler runs the service workload. + Handler ServiceHandler + // Logger is logger for the service. + Logger *slog.Logger + // TerminateTimeout bounds how long the service waits for shutdown. + // If zero, a default timeout is used. + TerminateTimeout time.Duration +} + +// runner wires a handler into the Windows service lifecycle. +type runner struct { + handler ServiceHandler + logger *slog.Logger + terminateTimeout time.Duration +} + +// InitSlogEventLogger sets up a new slog handler that writes to the Windows Event Log as source. +func InitSlogEventLogger(source string) (func() error, error) { + level := slog.LevelInfo + if envVar := os.Getenv(teleport.VerboseLogsEnvVar); envVar != "" { + isDebug, err := strconv.ParseBool(envVar) + if err != nil { + return nil, trace.Wrap(err, "parsing %s", teleport.VerboseLogsEnvVar) + } + if isDebug { + level = slog.LevelDebug + } + } + + handler, close, err := logutils.NewSlogEventLogHandler(source, level) + if err != nil { + return nil, trace.Wrap(err, "initializing log handler") + } + slog.SetDefault(slog.New(handler)) + return close, nil +} + +// Run wires logging, runs the service, and closes logging resources. +func Run(cfg *RunConfig) error { + if cfg.Name == "" { + return trace.BadParameter("service name is required") + } + if cfg.Handler == nil { + return trace.BadParameter("handler is required") + } + + terminateTimeout := cfg.TerminateTimeout + if terminateTimeout == 0 { + terminateTimeout = defaultTerminateTimeout + } + + err := svc.Run(cfg.Name, &runner{ + handler: cfg.Handler, + logger: cfg.Logger, + terminateTimeout: terminateTimeout, + }) + return trace.Wrap(err, "running Windows service") +} + +// Execute implements [svc.Handler.Execute]. +func (s *runner) Execute(args []string, requests <-chan svc.ChangeRequest, status chan<- svc.Status) (svcSpecificEC bool, exitCode uint32) { + const cmdsAccepted = svc.AcceptStop // Interrogate is always accepted and there is no const for it. + status <- svc.Status{State: svc.Running, Accepts: cmdsAccepted} + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + errCh := make(chan error) + go func() { errCh <- s.handler.Execute(ctx, args) }() + + var terminateTimedOut <-chan time.Time +loop: + for { + select { + case request := <-requests: + switch request.Cmd { + case svc.Interrogate: + state := svc.Running + if ctx.Err() != nil { + state = svc.StopPending + } + status <- svc.Status{State: state, Accepts: cmdsAccepted} + case svc.Stop: + s.logger.InfoContext(ctx, "Received stop command, shutting down service") + // Cancel the context passed to s.handler.Execute to terminate the service. + cancel() + terminateTimedOut = cmp.Or(terminateTimedOut, time.After(s.terminateTimeout)) + status <- svc.Status{State: svc.StopPending} + } + case <-terminateTimedOut: + s.logger.ErrorContext(ctx, "Service failed to terminate within timeout, exiting process", + slog.Duration("timeout", s.terminateTimeout)) + exitCode = 1 + break loop + case err := <-errCh: + if err == nil || errors.Is(err, context.Canceled) { + s.logger.InfoContext(ctx, "Service terminated") + } else { + s.logger.ErrorContext(ctx, "Service terminated", "error", err) + exitCode = 1 + } + break loop + } + } + status <- svc.Status{State: svc.Stopped, Win32ExitCode: exitCode} + return false, exitCode +}