67 changes: 45 additions & 22 deletions lib/vnet/opensshconfig.go
@@ -22,6 +22,7 @@ import (
"context"
"encoding/pem"
"io"
"maps"
"os"
"path/filepath"
"slices"
@@ -37,7 +38,6 @@ import (

"github.com/gravitational/teleport/api/profile"
"github.com/gravitational/teleport/api/types"
"github.com/gravitational/teleport/api/utils"
"github.com/gravitational/teleport/api/utils/keypaths"
"github.com/gravitational/teleport/lib/cryptosuites"
)
@@ -190,7 +190,8 @@ func (c *sshConfigurator) updateSSHConfiguration(ctx context.Context) error {
if err != nil {
return trace.Wrap(err, "listing profiles")
}
hostMatchers := make([]string, 0, len(profileNames))
// Build a set of unique cluster names for all active clusters.
clusterNames := make(map[string]struct{})
for _, profileName := range profileNames {
rootClient, err := c.cfg.clientApplication.GetCachedClient(ctx, profileName, "" /*leafClusterName*/)
if err != nil {
@@ -199,7 +200,7 @@ func (c *sshConfigurator) updateSSHConfiguration(ctx context.Context) error {
"profile", profileName, "error", err)
continue
}
hostMatchers = append(hostMatchers, hostMatcher(rootClient.RootClusterName()))
clusterNames[rootClient.RootClusterName()] = struct{}{}
leafClusters, err := c.cfg.leafClusterCache.getLeafClusters(ctx, rootClient)
if err != nil {
log.WarnContext(ctx,
@@ -208,17 +209,10 @@ func (c *sshConfigurator) updateSSHConfiguration(ctx context.Context) error {
continue
}
for _, leafCluster := range leafClusters {
hostMatchers = append(hostMatchers, hostMatcher(leafCluster))
clusterNames[leafCluster] = struct{}{}
}
}
hostMatchers = utils.Deduplicate(hostMatchers)
slices.Sort(hostMatchers)
hostMatchersString := strings.Join(hostMatchers, " ")
return trace.Wrap(writeSSHConfigFile(c.profilePath, hostMatchersString))
}

func hostMatcher(clusterName string) string {
return "*." + strings.Trim(clusterName, ".")
return trace.Wrap(writeSSHConfigFile(c.profilePath, clusterNames))
}

func deleteSSHConfigFile(profilePath string) error {
@@ -233,28 +227,57 @@ func deleteSSHConfigFile(profilePath string) error {
return nil
}

func writeSSHConfigFile(profilePath, hostMatchers string) error {
t := template.Must(template.New("ssh_config").Parse(configFileTemplate))
func writeSSHConfigFile(profilePath string, clusterNames map[string]struct{}) error {
var b bytes.Buffer
if err := t.Execute(&b, configFileTemplateInput{
Hosts: hostMatchers,
PrivateKeyPath: strconv.Quote(keypaths.VNetClientSSHKeyPath(profilePath)),
KnownHostsPath: strconv.Quote(keypaths.VNetKnownHostsPath(profilePath)),
}); err != nil {
return trace.Wrap(err, "generating SSH config file")
b.WriteString(generatedFileHeader)
if len(clusterNames) == 0 {
// Avoid writing the Host block if there are no clusters to match.
b.WriteString("# VNet currently detects no logged-in clusters, log in to start using VNet\n")
} else {
hosts := strings.Join(hostMatchers(clusterNames), " ")
if err := configFileTemplate.Execute(&b, configFileTemplateInput{
Hosts: hosts,
PrivateKeyPath: strconv.Quote(keypaths.VNetClientSSHKeyPath(profilePath)),
KnownHostsPath: strconv.Quote(keypaths.VNetKnownHostsPath(profilePath)),
}); err != nil {
return trace.Wrap(err, "generating SSH config file")
}
}
p := keypaths.VNetSSHConfigPath(profilePath)
err := renameio.WriteFile(p, b.Bytes(), filePerms)
return trace.Wrap(trace.ConvertSystemError(err), "writing SSH config file to %s", p)
}

const configFileTemplate = `Host {{ .Hosts }}
// hostMatchers returns a sorted list of host matchers for a given set of
// cluster names.
func hostMatchers(clusterNames map[string]struct{}) []string {
sortedClusterNames := slices.Sorted(maps.Keys(clusterNames))
matchers := make([]string, 0, len(sortedClusterNames))
for _, clusterName := range sortedClusterNames {
matchers = append(matchers, hostMatcher(clusterName))
}
return matchers
}

func hostMatcher(clusterName string) string {
return "*." + strings.Trim(clusterName, ".")
}
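
The clusterNames set plus slices.Sorted(maps.Keys(...)) is what replaces the removed utils.Deduplicate and slices.Sort calls. A minimal self-contained sketch of that standard-library pattern (illustrative names, not code from this PR):

package main

import (
	"fmt"
	"maps"
	"slices"
)

func main() {
	names := []string{"cluster1", "leaf1", "cluster1"} // duplicates are fine
	set := make(map[string]struct{}, len(names))
	for _, n := range names {
		set[n] = struct{}{} // a map[string]struct{} doubles as a set
	}
	// maps.Keys returns an iterator (Go 1.23+); slices.Sorted collects and
	// sorts it in one step.
	fmt.Println(slices.Sorted(maps.Keys(set))) // [cluster1 leaf1]
}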

const generatedFileHeader = `# ---------------------------------------------------------------------
# THIS FILE IS AUTOMATICALLY GENERATED BY TELEPORT VNET. DO NOT EDIT.
# Your changes will be overwritten the next time the file is generated.
# ---------------------------------------------------------------------

`

var configFileTemplate = template.Must(template.New("vnet_ssh_config").
Parse(`Host {{ .Hosts }}
IdentityFile {{ .PrivateKeyPath }}
GlobalKnownHostsFile {{ .KnownHostsPath }}
UserKnownHostsFile /dev/null
StrictHostKeyChecking yes
IdentitiesOnly yes
`
`))

type configFileTemplateInput struct {
Hosts string
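For illustration, with a single root cluster example.teleport.sh and a hypothetical profile directory /home/alice/.tsh (the id_vnet and vnet_known_hosts file names match the test expectations below), the generated file now looks roughly like:

# ---------------------------------------------------------------------
# THIS FILE IS AUTOMATICALLY GENERATED BY TELEPORT VNET. DO NOT EDIT.
# Your changes will be overwritten the next time the file is generated.
# ---------------------------------------------------------------------

Host *.example.teleport.sh
IdentityFile "/home/alice/.tsh/id_vnet"
GlobalKnownHostsFile "/home/alice/.tsh/vnet_known_hosts"
UserKnownHostsFile /dev/null
StrictHostKeyChecking yes
IdentitiesOnly yes

With no logged-in clusters, only the header and the "# VNet currently detects no logged-in clusters, log in to start using VNet" line are written, which the updated test asserts below.
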
30 changes: 22 additions & 8 deletions lib/vnet/opensshconfig_test.go
@@ -69,7 +69,10 @@ func TestSSHConfigurator(t *testing.T) {
// Intentionally not using the template defined in the production code to
// test that it actually produces output that looks like this.
expectedConfigFile := func(expectedHosts string) string {
return fmt.Sprintf(`Host %s
if expectedHosts == "" {
return generatedFileHeader + "# VNet currently detects no logged-in clusters, log in to start using VNet\n"
}
return generatedFileHeader + fmt.Sprintf(`Host %s
IdentityFile "%s/id_vnet"
GlobalKnownHostsFile "%s/vnet_known_hosts"
UserKnownHostsFile /dev/null
@@ -95,20 +98,31 @@ func TestSSHConfigurator(t *testing.T) {
// fakeClientApp.
assertConfigFile("*.cluster1 *.cluster2 *.leaf1")

// Add a new root and leaf cluster, wait until the configurator is blocked
// in the loop, advance the clock, wait until the configurator is blocked
// again indicating it should have updated the config and made it back into
// the loop, and then assert that the new clusters are in the config file.
// To reliably advance the clock and allow runConfigurationLoop to update
// the config, the test waits until the loop is blocked on the clock, then
// advances the clock, then waits until the loop is blocked again.
advance := func() {
Review comment (Contributor): Perhaps a good candidate for synctest in the future?

Reply (Contributor, Author): yes i would love to, already mentioned that in a comment

// strategy doesn't work. In go 1.25 we can use testing/synctest instead.

fakeClock.BlockUntilContext(ctx, 1)
fakeClock.Advance(sshConfigurationUpdateInterval)
fakeClock.BlockUntilContext(ctx, 1)
}

// Add a new root and leaf cluster, allow the configuration loop to run,
// and then assert that the new clusters are in the config file.
fakeClientApp.cfg.clusters["cluster3"] = testClusterSpec{
leafClusters: map[string]testClusterSpec{
"leaf2": {},
},
}
fakeClock.BlockUntilContext(ctx, 1)
fakeClock.Advance(sshConfigurationUpdateInterval)
fakeClock.BlockUntilContext(ctx, 1)
advance()
assertConfigFile("*.cluster1 *.cluster2 *.cluster3 *.leaf1 *.leaf2")

// Delete all clusters as if the user logged out, allow the configuration
// loop to run, and then assert that the config file is well-formed.
fakeClientApp.cfg.clusters = nil
advance()
assertConfigFile("")

// Kill the configurator, wait for it to return, and assert that the config
// file was deleted.
cancel()
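
Picking up the synctest suggestion from the review thread above: under Go 1.25's testing/synctest (an assumed future migration, not part of this PR), the explicit block/advance/block choreography could collapse into a bubble where time is virtual. A hedged sketch:

package vnet_test

import (
	"testing"
	"testing/synctest"
	"time"
)

func TestConfigLoopSketch(t *testing.T) {
	synctest.Test(t, func(t *testing.T) {
		updated := make(chan struct{})
		go func() {
			// Inside the bubble, time.Sleep uses a fake clock that advances
			// once every goroutine in the bubble is durably blocked, so the
			// update interval elapses instantly in real time.
			time.Sleep(time.Minute) // stand-in for sshConfigurationUpdateInterval
			close(updated)
		}()
		<-updated // unblocks once the fake clock passes the sleep deadline
	})
}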