Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
194 changes: 17 additions & 177 deletions lib/vnet/install_service_windows.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,170 +18,41 @@ package vnet

import (
"context"
"errors"
"fmt"
"os"
"path/filepath"
"strings"

"github.com/gravitational/trace"
"golang.org/x/sys/windows"
"golang.org/x/sys/windows/svc/eventlog"
"golang.org/x/sys/windows/svc/mgr"

"github.com/gravitational/teleport"
eventlogutils "github.com/gravitational/teleport/lib/utils/log/eventlog"
"github.com/gravitational/teleport/lib/windowsservice"
)

// InstallService installs the VNet windows service.
//
// Windows services are installed by the service manager, which takes a path to
// the service executable. So that regular users are not able to overwrite the
// executable at that path, we use a path under %PROGRAMFILES%, which is not
// writable by regular users by default.
func InstallService(ctx context.Context) (err error) {
const eventSource = "vnet"

// InstallService installs the VNet Windows service.
func InstallService(ctx context.Context) error {
tshPath, err := os.Executable()
if err != nil {
return trace.Wrap(err, "getting current exe path")
}
if err := assertTshInProgramFiles(tshPath); err != nil {
return trace.Wrap(err, "checking if tsh.exe is installed under %%PROGRAMFILES%%")
}
if err := assertWintunInstalled(tshPath); err != nil {
return trace.Wrap(err, "checking if wintun.dll is installed next to %s", tshPath)
}

svcMgr, err := mgr.Connect()
if err != nil {
return trace.Wrap(err, "connecting to Windows service manager")
}
svc, err := svcMgr.OpenService(serviceName)
if err != nil {
if !errors.Is(err, windows.ERROR_SERVICE_DOES_NOT_EXIST) {
return trace.Wrap(err, "unexpected error checking if Windows service %s exists", serviceName)
}
// The service has not been created yet and must be installed.
svc, err = svcMgr.CreateService(
serviceName,
tshPath,
mgr.Config{
StartType: mgr.StartManual,
},
ServiceCommand,
)
if err != nil {
return trace.Wrap(err, "creating VNet Windows service")
}
}
if err := svc.Close(); err != nil {
return trace.Wrap(err, "closing VNet Windows service")
}
if err := grantServiceRights(); err != nil {
return trace.Wrap(err, "granting authenticated users permission to control the VNet Windows service")
}
if err := installEventSource(); err != nil {
trace.Wrap(err, "creating event source for logging")
}
if err := logInstallationEvent("VNet service installed"); err != nil {
trace.Wrap(err, "logging installation event")
}
return nil
}

// UninstallService uninstalls the VNet windows service.
func UninstallService(ctx context.Context) (err error) {
svcMgr, err := mgr.Connect()
if err != nil {
return trace.Wrap(err, "connecting to Windows service manager")
}
svc, err := svcMgr.OpenService(serviceName)
if err != nil {
return trace.Wrap(err, "opening Windows service %s", serviceName)
}
if err := svc.Delete(); err != nil {
return trace.Wrap(err, "deleting Windows service %s", serviceName)
}
if err := svc.Close(); err != nil {
return trace.Wrap(err, "closing VNet Windows service")
}

if err := logInstallationEvent("VNet service uninstalled"); err != nil {
trace.Wrap(err, "logging installation event")
}
if err := eventlogutils.Remove(eventlogutils.LogName, eventSource); err != nil {
return trace.Wrap(err, "removing event source for logging")
}

return nil
}

func grantServiceRights() error {
// Get the current security info for the service, requesting only the DACL
// (discretionary access control list).
si, err := windows.GetNamedSecurityInfo(serviceName, windows.SE_SERVICE, windows.DACL_SECURITY_INFORMATION)
if err != nil {
return trace.Wrap(err, "getting current service security information")
}
// Get the DACL from the security info.
dacl, _ /*defaulted*/, err := si.DACL()
if err != nil {
return trace.Wrap(err, "getting current service DACL")
}
// This is the universal well-known SID for "Authenticated Users".
authenticatedUsersSID, err := windows.StringToSid("S-1-5-11")
if err != nil {
return trace.Wrap(err, "parsing authenticated users SID")
}
// Build an explicit access entry allowing authenticated users to start,
// stop, and query the service.
ea := []windows.EXPLICIT_ACCESS{{
return trace.Wrap(windowsservice.Install(ctx, &windowsservice.InstallConfig{
Name: serviceName,
Description: serviceDescription,
Command: ServiceCommand,
EventSourceName: eventSource,
AccessPermissions: windows.SERVICE_QUERY_STATUS | windows.SERVICE_START | windows.SERVICE_STOP,
AccessMode: windows.GRANT_ACCESS,
Trustee: windows.TRUSTEE{
TrusteeForm: windows.TRUSTEE_IS_SID,
TrusteeType: windows.TRUSTEE_IS_WELL_KNOWN_GROUP,
TrusteeValue: windows.TrusteeValueFromSID(authenticatedUsersSID),
},
}}
// Merge the new explicit access entry with the existing DACL.
dacl, err = windows.ACLFromEntries(ea, dacl)
if err != nil {
return trace.Wrap(err, "merging service DACL entries")
}
// Set the DACL on the service security info.
if err := windows.SetNamedSecurityInfo(
serviceName,
windows.SE_SERVICE,
windows.DACL_SECURITY_INFORMATION,
nil, // owner
nil, // group
dacl, // dacl
nil, // sacl
); err != nil {
return trace.Wrap(err, "setting service DACL")
}
return nil
}))
}

// assertTshInProgramFiles asserts that tsh is a regular file installed under
// the program files directory (usually C:\Program Files\).
func assertTshInProgramFiles(tshPath string) error {
if err := assertRegularFile(tshPath); err != nil {
return trace.Wrap(err)
}
programFiles := os.Getenv("PROGRAMFILES")
if programFiles == "" {
return trace.Errorf("PROGRAMFILES env var is not set")
}
// Windows file paths are case-insensitive.
cleanedProgramFiles := strings.ToLower(filepath.Clean(programFiles)) + string(filepath.Separator)
cleanedTshPath := strings.ToLower(filepath.Clean(tshPath))
if !strings.HasPrefix(cleanedTshPath, cleanedProgramFiles) {
return trace.BadParameter(
"tsh.exe is currently installed at %s, it must be installed under %s in order to install the VNet Windows service",
tshPath, programFiles)
}
return nil
// UninstallService uninstalls the Windows VNet service.
func UninstallService(ctx context.Context) error {
return trace.Wrap(windowsservice.Uninstall(ctx, &windowsservice.UninstallConfig{
Name: serviceName,
EventSourceName: eventSource,
}))
}

// asertWintunInstalled returns an error if wintun.dll is not a regular file
Expand All @@ -203,34 +74,3 @@ func assertRegularFile(path string) error {
}
return nil
}

const eventSource = "vnet"

func installEventSource() error {
exe, err := os.Executable()
if err != nil {
return trace.Wrap(err)
}
// Assume that the message file is shipped next to tsh.exe.
msgFilePath := filepath.Join(filepath.Dir(exe), "msgfile.dll")

// This should create a registry entry under
// SYSTEM\CurrentControlSet\Services\EventLog\Teleport\vnet with an absolute path to msgfile.dll.
// If the user moves Teleport Connect to some other directory, logs will still be captured, but
// they might display a message about missing event ID until the user reinstalls the app.
err = eventlogutils.Install(eventlogutils.LogName, eventSource, msgFilePath, false /* useExpandKey */)
return trace.Wrap(err)
}

func logInstallationEvent(eventMessage string) error {
log, err := eventlog.Open(eventSource)
if err != nil {
return trace.Wrap(err, "opening logger")
}

if err := log.Info(eventlogutils.EventID, fmt.Sprintf("%s version:%s", eventMessage, teleport.Version)); err != nil {
return trace.Wrap(err, "writing log message")
}

return trace.Wrap(log.Close(), "closing logger")
}
108 changes: 12 additions & 96 deletions lib/vnet/service_windows.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,8 @@
package vnet

import (
"cmp"
"context"
"errors"
"log/slog"
"os"
"strconv"
"syscall"
"time"

Expand All @@ -33,7 +29,7 @@ import (
"golang.org/x/sys/windows/svc/mgr"

"github.com/gravitational/teleport"
logutils "github.com/gravitational/teleport/lib/utils/log"
"github.com/gravitational/teleport/lib/windowsservice"
)

const (
Expand Down Expand Up @@ -126,83 +122,23 @@ func startService(ctx context.Context, cfg *windowsAdminProcessConfig) (*mgr.Ser

// ServiceMain runs the Windows VNet admin service.
func ServiceMain() error {
closeFn, err := setupServiceLogger()
closeLogger, err := windowsservice.InitSlogEventLogger(eventSource)
if err != nil {
return trace.Wrap(err, "setting up logger for service")
}

if err := svc.Run(serviceName, &windowsService{}); err != nil {
closeFn()
return trace.Wrap(err, "running Windows service")
return trace.Wrap(err)
}

return trace.Wrap(closeFn(), "closing logger")
}

// windowsService implements [svc.Handler].
type windowsService struct{}

// Execute implements [svc.Handler.Execute], the GoDoc is copied below.
//
// Execute will be called by the package code at the start of the service, and
// the service will exit once Execute completes. Inside Execute you must read
// service change requests from [requests] and act accordingly. You must keep
// service control manager up to date about state of your service by writing
// into [status] as required. args contains service name followed by argument
// strings passed to the service.
// You can provide service exit code in exitCode return parameter, with 0 being
// "no error". You can also indicate if exit code, if any, is service specific
// or not by using svcSpecificEC parameter.
func (s *windowsService) Execute(args []string, requests <-chan svc.ChangeRequest, status chan<- svc.Status) (svcSpecificEC bool, exitCode uint32) {
logger := slog.With(teleport.ComponentKey, teleport.Component("vnet", "windows-service"))
const cmdsAccepted = svc.AcceptStop // Interrogate is always accepted and there is no const for it.
status <- svc.Status{State: svc.Running, Accepts: cmdsAccepted}

ctx, cancel := context.WithCancel(context.Background())
defer cancel()
errCh := make(chan error)
go func() { errCh <- s.run(ctx, args) }()

var terminateTimedOut <-chan time.Time
loop:
for {
select {
case request := <-requests:
switch request.Cmd {
case svc.Interrogate:
state := svc.Running
if ctx.Err() != nil {
state = svc.StopPending
}
status <- svc.Status{State: state, Accepts: cmdsAccepted}
case svc.Stop:
logger.InfoContext(ctx, "Received stop command, shutting down service")
// Cancel the context passed to s.run to terminate the
// networking stack.
cancel()
terminateTimedOut = cmp.Or(terminateTimedOut, time.After(terminateTimeout))
status <- svc.Status{State: svc.StopPending}
}
case <-terminateTimedOut:
logger.ErrorContext(ctx, "Networking stack failed to terminate within timeout, exiting process",
slog.Duration("timeout", terminateTimeout))
exitCode = 1
break loop
case err := <-errCh:
if err == nil || errors.Is(err, context.Canceled) {
logger.InfoContext(ctx, "Service terminated")
} else {
logger.ErrorContext(ctx, "Service terminated", "error", err)
exitCode = 1
}
break loop
}
}
status <- svc.Status{State: svc.Stopped, Win32ExitCode: exitCode}
return false, exitCode
err = windowsservice.Run(&windowsservice.RunConfig{
Name: serviceName,
Handler: &handler{},
Logger: logger,
})
return trace.NewAggregate(err, closeLogger())
}

func (s *windowsService) run(ctx context.Context, args []string) error {
type handler struct{}

func (w *handler) Execute(ctx context.Context, args []string) error {
var cfg windowsAdminProcessConfig
app := kingpin.New(serviceName, "Teleport VNet Windows Service")
serviceCmd := app.Command("vnet-service", "Start the VNet service.")
Expand All @@ -221,23 +157,3 @@ func (s *windowsService) run(ctx context.Context, args []string) error {
}
return nil
}

func setupServiceLogger() (func() error, error) {
level := slog.LevelInfo
if envVar := os.Getenv(teleport.VerboseLogsEnvVar); envVar != "" {
isDebug, err := strconv.ParseBool(envVar)
if err != nil {
return nil, trace.Wrap(err, "parsing %s", teleport.VerboseLogsEnvVar)
}
if isDebug {
level = slog.LevelDebug
}
}

handler, close, err := logutils.NewSlogEventLogHandler("vnet", level)
if err != nil {
return nil, trace.Wrap(err, "initializing log handler")
}
slog.SetDefault(slog.New(handler))
return close, nil
}
Loading
Loading