Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 21 additions & 16 deletions lib/autoupdate/agent/installer.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ type LocalInstaller struct {
// ReservedFreeInstallDisk is the amount of disk that must remain free in the install directory.
ReservedFreeInstallDisk uint64
// TransformService transforms the systemd service during copying.
TransformService func(cfg []byte, pathDir string) []byte
TransformService func(cfg []byte, pathDir string, flags autoupdate.InstallFlags) []byte
// ValidateBinary returns true if a file is a linkable binary, or
// false if a file should not be linked.
ValidateBinary func(ctx context.Context, path string) (bool, error)
Expand Down Expand Up @@ -417,8 +417,7 @@ func (li *LocalInstaller) Link(ctx context.Context, rev Revision, pathDir string
revert, err = li.forceLinks(ctx,
filepath.Join(versionDir, "bin"),
filepath.Join(versionDir, serviceDir, serviceName),
pathDir,
force,
pathDir, force, rev.Flags,
)
if err != nil {
return revert, trace.Wrap(err)
Expand All @@ -431,7 +430,9 @@ func (li *LocalInstaller) Link(ctx context.Context, rev Revision, pathDir string
// The revert function restores the previous linking.
// See Installer interface for additional specs.
func (li *LocalInstaller) LinkSystem(ctx context.Context) (revert func(context.Context) bool, err error) {
revert, err = li.forceLinks(ctx, li.SystemBinDir, li.SystemServiceFile, defaultPathDir, false)
// The system package service file is always removed without flags, so pass
// no flags here to match the behavior.
revert, err = li.forceLinks(ctx, li.SystemBinDir, li.SystemServiceFile, defaultPathDir, false, 0)
return revert, trace.Wrap(err)
}

Expand All @@ -446,15 +447,17 @@ func (li *LocalInstaller) TryLink(ctx context.Context, revision Revision, pathDi
return trace.Wrap(li.tryLinks(ctx,
filepath.Join(versionDir, "bin"),
filepath.Join(versionDir, serviceDir, serviceName),
pathDir,
pathDir, revision.Flags,
))
}

// TryLinkSystem links the system installation to defaultPathDir, but only in the case that
// no installation of Teleport is already linked or partially linked.
// See Installer interface for additional specs.
func (li *LocalInstaller) TryLinkSystem(ctx context.Context) error {
return trace.Wrap(li.tryLinks(ctx, li.SystemBinDir, li.SystemServiceFile, defaultPathDir))
// The system package service file is always removed without flags, so pass
// no flags here to match the behavior.
return trace.Wrap(li.tryLinks(ctx, li.SystemBinDir, li.SystemServiceFile, defaultPathDir, 0))
}

// Unlink unlinks a version from pathDir and TargetServiceFile.
Expand All @@ -467,14 +470,16 @@ func (li *LocalInstaller) Unlink(ctx context.Context, rev Revision, pathDir stri
return trace.Wrap(li.removeLinks(ctx,
filepath.Join(versionDir, "bin"),
filepath.Join(versionDir, serviceDir, serviceName),
pathDir,
pathDir, rev.Flags,
))
}

// UnlinkSystem unlinks the system (package) version from defaultPathDir and TargetServiceFile.
// See Installer interface for additional specs.
func (li *LocalInstaller) UnlinkSystem(ctx context.Context) error {
return trace.Wrap(li.removeLinks(ctx, li.SystemBinDir, li.SystemServiceFile, defaultPathDir))
// The system package service file is always linked without flags, so pass
// no flags here to match the behavior.
return trace.Wrap(li.removeLinks(ctx, li.SystemBinDir, li.SystemServiceFile, defaultPathDir, 0))
}

// symlink from oldname to newname
Expand All @@ -495,7 +500,7 @@ type smallFile struct {
// If successful, forceLinks may also be reverted after it returns by calling revert.
// The revert function returns true if reverting succeeds.
// If force is true, non-link files will be overwritten.
func (li *LocalInstaller) forceLinks(ctx context.Context, srcBinDir, srcSvcFile, dstBinDir string, force bool) (revert func(context.Context) bool, err error) {
func (li *LocalInstaller) forceLinks(ctx context.Context, srcBinDir, srcSvcFile, dstBinDir string, force bool, flags autoupdate.InstallFlags) (revert func(context.Context) bool, err error) {
// setup revert function
var (
revertLinks []symlink
Expand Down Expand Up @@ -585,7 +590,7 @@ func (li *LocalInstaller) forceLinks(ctx context.Context, srcBinDir, srcSvcFile,

// create systemd service file

orig, err := li.forceCopyService(li.TargetServiceFile, srcSvcFile, maxServiceFileSize, dstBinDir)
orig, err := li.forceCopyService(li.TargetServiceFile, srcSvcFile, maxServiceFileSize, dstBinDir, flags)
if err != nil && !errors.Is(err, os.ErrExist) {
return revert, trace.Wrap(err, "failed to copy service")
}
Expand All @@ -598,12 +603,12 @@ func (li *LocalInstaller) forceLinks(ctx context.Context, srcBinDir, srcSvcFile,
// forceCopyService uses forceCopy to copy a systemd service file from src to dst.
// The contents of both src and dst must be smaller than n.
// See forceCopy for more details.
func (li *LocalInstaller) forceCopyService(dst, src string, n int64, dstBinDir string) (orig *smallFile, err error) {
func (li *LocalInstaller) forceCopyService(dst, src string, n int64, dstBinDir string, flags autoupdate.InstallFlags) (orig *smallFile, err error) {
srcData, err := readFileAtMost(src, n)
if err != nil {
return nil, trace.Wrap(err)
}
return forceCopy(dst, li.TransformService(srcData, dstBinDir), n)
return forceCopy(dst, li.TransformService(srcData, dstBinDir, flags), n)
}

// forceLink attempts to create a symlink, atomically replacing an existing link if already present.
Expand Down Expand Up @@ -675,7 +680,7 @@ func readFileAtMost(name string, n int64) ([]byte, error) {
return data, trace.Wrap(err)
}

func (li *LocalInstaller) removeLinks(ctx context.Context, srcBinDir, srcSvcFile, dstBinDir string) error {
func (li *LocalInstaller) removeLinks(ctx context.Context, srcBinDir, srcSvcFile, dstBinDir string, flags autoupdate.InstallFlags) error {
removeService := false
entries, err := os.ReadDir(srcBinDir)
if err != nil {
Expand Down Expand Up @@ -726,7 +731,7 @@ func (li *LocalInstaller) removeLinks(ctx context.Context, srcBinDir, srcSvcFile
if err != nil {
return trace.Wrap(err)
}
if !bytes.Equal(li.TransformService(srcBytes, dstBinDir), dstBytes) {
if !bytes.Equal(li.TransformService(srcBytes, dstBinDir, flags), dstBytes) {
li.Log.WarnContext(ctx, "Removed teleport binary link, but skipping removal of custom teleport.service: the service file does not match the reference file for this version. The file might have been manually edited.")
return nil
}
Expand All @@ -740,7 +745,7 @@ func (li *LocalInstaller) removeLinks(ctx context.Context, srcBinDir, srcSvcFile
// Existing links that point to files outside binDir or svcDir, as well as existing non-link files, will error.
// tryLinks will not attempt to create any links if linking could result in an error.
// However, concurrent changes to links may result in an error with partially-complete linking.
func (li *LocalInstaller) tryLinks(ctx context.Context, srcBinDir, srcSvcFile, dstBinDir string) error {
func (li *LocalInstaller) tryLinks(ctx context.Context, srcBinDir, srcSvcFile, dstBinDir string, flags autoupdate.InstallFlags) error {
// ensure source directory exists
entries, err := os.ReadDir(srcBinDir)
if errors.Is(err, os.ErrNotExist) {
Expand Down Expand Up @@ -798,7 +803,7 @@ func (li *LocalInstaller) tryLinks(ctx context.Context, srcBinDir, srcSvcFile, d
}

// if any binaries are linked from srcBinDir, always link the service from svcDir
_, err = li.forceCopyService(li.TargetServiceFile, srcSvcFile, maxServiceFileSize, dstBinDir)
_, err = li.forceCopyService(li.TargetServiceFile, srcSvcFile, maxServiceFileSize, dstBinDir, flags)
if err != nil && !errors.Is(err, os.ErrExist) {
return trace.Wrap(err, "failed to copy service")
}
Expand Down
44 changes: 22 additions & 22 deletions lib/autoupdate/agent/installer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -425,7 +425,7 @@ func TestLocalInstaller_Link(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
versionsDir := t.TempDir()
versionDir := filepath.Join(versionsDir, version)
versionDir := filepath.Join(versionsDir, version+"_ent")
err := os.MkdirAll(versionDir, 0o755)
require.NoError(t, err)

Expand Down Expand Up @@ -459,14 +459,14 @@ func TestLocalInstaller_Link(t *testing.T) {
InstallDir: versionsDir,
TargetServiceFile: filepath.Join(linkDir, serviceDir, serviceName),
Log: slog.Default(),
TransformService: func(b []byte, pathDir string) []byte {
return []byte(fmt.Sprintf("[service=%s][path=%s]", string(b), pathDir))
TransformService: func(b []byte, pathDir string, flags autoupdate.InstallFlags) []byte {
return []byte(fmt.Sprintf("[service=%s][path=%s][flags=%s]", string(b), pathDir, flags.Strings()))
},
ValidateBinary: validator.IsExecutable,
Template: autoupdate.DefaultCDNURITemplate,
}
ctx := context.Background()
revert, err := installer.Link(ctx, NewRevision(version, 0), filepath.Join(linkDir, "bin"), tt.force)
revert, err := installer.Link(ctx, NewRevision(version, autoupdate.FlagEnterprise), filepath.Join(linkDir, "bin"), tt.force)
if tt.errMatch != "" {
require.Error(t, err)
assert.Contains(t, err.Error(), tt.errMatch)
Expand Down Expand Up @@ -499,7 +499,7 @@ func TestLocalInstaller_Link(t *testing.T) {
for _, svc := range tt.resultServices {
v, err := os.ReadFile(filepath.Join(linkDir, svc))
require.NoError(t, err)
require.Equal(t, fmt.Sprintf("[service=%s][path=%s]", filepath.Base(svc), filepath.Join(linkDir, "bin")), string(v))
require.Equal(t, fmt.Sprintf("[service=%s][path=%s][flags=[Enterprise]]", filepath.Base(svc), filepath.Join(linkDir, "bin")), string(v))
}

// verify manual revert
Expand Down Expand Up @@ -680,7 +680,7 @@ func TestLocalInstaller_TryLink(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
versionsDir := t.TempDir()
versionDir := filepath.Join(versionsDir, version)
versionDir := filepath.Join(versionsDir, version+"_ent")
err := os.MkdirAll(versionDir, 0o755)
require.NoError(t, err)

Expand Down Expand Up @@ -714,13 +714,13 @@ func TestLocalInstaller_TryLink(t *testing.T) {
InstallDir: versionsDir,
TargetServiceFile: filepath.Join(linkDir, serviceDir, serviceName),
Log: slog.Default(),
TransformService: func(b []byte, pathDir string) []byte {
return []byte(fmt.Sprintf("[service=%s][path=%s]", string(b), pathDir))
TransformService: func(b []byte, pathDir string, flags autoupdate.InstallFlags) []byte {
return []byte(fmt.Sprintf("[service=%s][path=%s][flags=%s]", string(b), pathDir, flags.Strings()))
},
ValidateBinary: validator.IsExecutable,
}
ctx := context.Background()
err = installer.TryLink(ctx, NewRevision(version, 0), filepath.Join(linkDir, "bin"))
err = installer.TryLink(ctx, NewRevision(version, autoupdate.FlagEnterprise), filepath.Join(linkDir, "bin"))
if tt.errMatch != "" {
require.Error(t, err)
assert.Contains(t, err.Error(), tt.errMatch)
Expand Down Expand Up @@ -749,7 +749,7 @@ func TestLocalInstaller_TryLink(t *testing.T) {
for _, svc := range tt.resultServices {
v, err := os.ReadFile(filepath.Join(linkDir, svc))
require.NoError(t, err)
require.Equal(t, fmt.Sprintf("[service=%s][path=%s]", filepath.Base(svc), filepath.Join(linkDir, "bin")), string(v))
require.Equal(t, fmt.Sprintf("[service=%s][path=%s][flags=[Enterprise]]", filepath.Base(svc), filepath.Join(linkDir, "bin")), string(v))
}

})
Expand Down Expand Up @@ -851,8 +851,8 @@ func TestLocalInstaller_Remove(t *testing.T) {
InstallDir: versionsDir,
TargetServiceFile: filepath.Join(linkDir, serviceDir, serviceName),
Log: slog.Default(),
TransformService: func(b []byte, pathDir string) []byte {
return []byte(fmt.Sprintf("[service=%s][path=%s]", string(b), pathDir))
TransformService: func(b []byte, pathDir string, flags autoupdate.InstallFlags) []byte {
return []byte(fmt.Sprintf("[service=%s][path=%s][flags=%s]", string(b), pathDir, flags.Strings()))
},
ValidateBinary: validator.IsExecutable,
}
Expand All @@ -864,7 +864,7 @@ func TestLocalInstaller_Remove(t *testing.T) {
return
}
require.NoError(t, err)
_, err = os.Stat(filepath.Join(versionDir, "bin", tt.removeVersion))
_, err = os.Stat(filepath.Join(versionsDir, tt.removeVersion))
require.ErrorIs(t, err, os.ErrNotExist)
})
}
Expand Down Expand Up @@ -921,8 +921,8 @@ func TestLocalInstaller_IsLinked(t *testing.T) {
InstallDir: versionsDir,
TargetServiceFile: filepath.Join(linkDir, serviceDir, serviceName),
Log: slog.Default(),
TransformService: func(b []byte, pathDir string) []byte {
return []byte(fmt.Sprintf("[service=%s][path=%s]", string(b), pathDir))
TransformService: func(b []byte, pathDir string, flags autoupdate.InstallFlags) []byte {
return []byte(fmt.Sprintf("[service=%s][path=%s][flags=%s]", string(b), pathDir, flags.Strings()))
},
ValidateBinary: validator.IsExecutable,
}
Expand Down Expand Up @@ -981,7 +981,7 @@ func TestLocalInstaller_Unlink(t *testing.T) {
{oldname: "bin/teleport", newname: "bin/teleport"},
{oldname: "bin/tsh", newname: "bin/tsh"},
},
svcCopy: []byte("[service=orig][path=bin]"),
svcCopy: []byte("[service=orig][path=bin][flags=[]]"),
},
{
name: "different services",
Expand Down Expand Up @@ -1021,7 +1021,7 @@ func TestLocalInstaller_Unlink(t *testing.T) {
links: []symlink{
{oldname: "bin/tsh", newname: "bin/tsh"},
},
svcCopy: []byte("[service=orig][path=bin]"),
svcCopy: []byte("[service=orig][path=bin][flags=[]]"),
remaining: []string{servicePath},
},
{
Expand All @@ -1031,7 +1031,7 @@ func TestLocalInstaller_Unlink(t *testing.T) {
links: []symlink{
{oldname: "bin/teleport", newname: "bin/teleport"},
},
svcCopy: []byte("[service=orig][path=bin]"),
svcCopy: []byte("[service=orig][path=bin][flags=[]]"),
},
{
name: "wrong teleport link",
Expand All @@ -1041,7 +1041,7 @@ func TestLocalInstaller_Unlink(t *testing.T) {
{oldname: "other", newname: "bin/teleport"},
{oldname: "bin/tsh", newname: "bin/tsh"},
},
svcCopy: []byte("[service=orig][path=bin]"),
svcCopy: []byte("[service=orig][path=bin][flags=[]]"),
remaining: []string{servicePath, "bin/teleport"},
},
{
Expand All @@ -1052,7 +1052,7 @@ func TestLocalInstaller_Unlink(t *testing.T) {
{oldname: "bin/teleport", newname: "bin/teleport"},
{oldname: "wrong", newname: "bin/tsh"},
},
svcCopy: []byte("[service=orig][path=bin]"),
svcCopy: []byte("[service=orig][path=bin][flags=[]]"),
remaining: []string{"bin/tsh"},
},
}
Expand Down Expand Up @@ -1107,8 +1107,8 @@ func TestLocalInstaller_Unlink(t *testing.T) {
InstallDir: versionsDir,
TargetServiceFile: filepath.Join(linkDir, serviceDir, serviceName),
Log: slog.Default(),
TransformService: func(b []byte, pathDir string) []byte {
return []byte(fmt.Sprintf("[service=%s][path=%s]", string(b), filepath.Base(pathDir)))
TransformService: func(b []byte, pathDir string, flags autoupdate.InstallFlags) []byte {
return []byte(fmt.Sprintf("[service=%s][path=%s][flags=%s]", string(b), filepath.Base(pathDir), flags.Strings()))
},
}
ctx := context.Background()
Expand Down
19 changes: 18 additions & 1 deletion lib/autoupdate/agent/setup.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ import (
"github.com/gravitational/trace"
"gopkg.in/yaml.v3"

"github.com/gravitational/teleport/lib/autoupdate"
"github.com/gravitational/teleport/lib/defaults"
libutils "github.com/gravitational/teleport/lib/utils"
)
Expand Down Expand Up @@ -421,10 +422,14 @@ func writeSystemTemplate(path, t string, values any) error {
}

// ReplaceTeleportService replaces the default paths in the Teleport service config with namespaced paths.
func (ns *Namespace) ReplaceTeleportService(cfg []byte, pathDir string) []byte {
func (ns *Namespace) ReplaceTeleportService(cfg []byte, pathDir string, flags autoupdate.InstallFlags) []byte {
if pathDir == "" {
pathDir = ns.defaultPathDir
}
var startFlags []string
if flags&autoupdate.FlagFIPS != 0 {
startFlags = append(startFlags, "--fips")
}
for _, rep := range []struct {
old, new string
}{
Expand All @@ -440,12 +445,24 @@ func (ns *Namespace) ReplaceTeleportService(cfg []byte, pathDir string) []byte {
old: "/run/teleport.pid",
new: ns.pidFile,
},
{
old: "/teleport start ",
new: "/teleport start " + joinTerminal(startFlags, " "),
},
} {
cfg = bytes.ReplaceAll(cfg, []byte(rep.old), []byte(rep.new))
}
return cfg
}

func joinTerminal(s []string, sep string) string {
v := strings.Join(s, sep)
if len(v) > 0 {
return v + sep
}
return v
}

func (ns *Namespace) LogWarnings(ctx context.Context, pathDir string) {
if ns.name == "" {
return
Expand Down
Loading
Loading