Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 37 additions & 29 deletions lib/client/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -1372,14 +1372,19 @@ func (tc *TeleportClient) RootClusterName(ctx context.Context) (string, error) {
return name, nil
}

type targetNode struct {
hostname string
addr string
}

// getTargetNodes returns a list of node addresses this SSH command needs to
// operate on.
func (tc *TeleportClient) getTargetNodes(ctx context.Context, clt client.ListUnifiedResourcesClient, options SSHOptions) ([]targetNode, error) {
// TargetNode contains information about a resolved host.
type TargetNode struct {
// Hostname of the target host.
Hostname string
// Address of the target host. For hosts resolved
// via auth this will be in the form of UUID:0.
Addr string
}

// GetTargetNodes returns hosts matching the target host provided by users. Host resolution
// honors an explicit host, i.e. tsh ssh user@hostname, label based hosts, i.e. tsh ssh user@foo=bar,
// as well as respecting any proxy templates that are specified.
func (tc *TeleportClient) GetTargetNodes(ctx context.Context, clt client.ListUnifiedResourcesClient, options SSHOptions) ([]TargetNode, error) {
ctx, span := tc.Tracer.Start(
ctx,
"teleportClient/getTargetNodes",
Expand All @@ -1388,16 +1393,17 @@ func (tc *TeleportClient) getTargetNodes(ctx context.Context, clt client.ListUni
defer span.End()

if options.HostAddress != "" {
return []targetNode{
return []TargetNode{
{
hostname: options.HostAddress,
addr: options.HostAddress,
Hostname: options.HostAddress,
Addr: options.HostAddress,
},
}, nil
}

// Query for nodes if labels, fuzzy search, or predicate expressions were provided.
if len(tc.Labels) > 0 || len(tc.SearchKeywords) > 0 || tc.PredicateExpression != "" {
log.Debugf("Attempting to resolve matching hosts from labels=%v|search=%v|predicate=%v", tc.Labels, tc.SearchKeywords, tc.PredicateExpression)
nodes, err := client.GetAllUnifiedResources(ctx, clt, &proto.ListUnifiedResourcesRequest{
Kinds: []string{types.KindNode},
SortBy: types.SortBy{Field: types.ResourceMetadataName},
Expand All @@ -1410,23 +1416,25 @@ func (tc *TeleportClient) getTargetNodes(ctx context.Context, clt client.ListUni
return nil, trace.Wrap(err)
}

retval := make([]targetNode, 0, len(nodes))
retval := make([]TargetNode, 0, len(nodes))
for _, resource := range nodes {
server, ok := resource.ResourceWithLabels.(types.Server)
if !ok {
continue
}

// always dial nodes by UUID
retval = append(retval, targetNode{
hostname: server.GetHostname(),
addr: fmt.Sprintf("%s:0", resource.GetName()),
retval = append(retval, TargetNode{
Hostname: server.GetHostname(),
Addr: fmt.Sprintf("%s:0", resource.GetName()),
})
}

return retval, nil
}

log.Debugf("Using provided host %s", tc.Host)

// detect the common error when users use host:port address format
_, port, err := net.SplitHostPort(tc.Host)
// client has used host:port notation
Expand All @@ -1435,10 +1443,10 @@ func (tc *TeleportClient) getTargetNodes(ctx context.Context, clt client.ListUni
}

addr := net.JoinHostPort(tc.Host, strconv.Itoa(tc.HostPort))
return []targetNode{
return []TargetNode{
{
hostname: tc.Host,
addr: addr,
Hostname: tc.Host,
Addr: addr,
},
}, nil
}
Expand Down Expand Up @@ -1704,7 +1712,7 @@ func (tc *TeleportClient) SSH(ctx context.Context, command []string, opts ...fun
defer clt.Close()

// which nodes are we executing this commands on?
nodeAddrs, err := tc.getTargetNodes(ctx, clt.AuthClient, options)
nodeAddrs, err := tc.GetTargetNodes(ctx, clt.AuthClient, options)
if err != nil {
return trace.Wrap(err)
}
Expand All @@ -1715,7 +1723,7 @@ func (tc *TeleportClient) SSH(ctx context.Context, command []string, opts ...fun
if len(nodeAddrs) > 1 {
return tc.runShellOrCommandOnMultipleNodes(ctx, clt, nodeAddrs, command)
}
return tc.runShellOrCommandOnSingleNode(ctx, clt, nodeAddrs[0].addr, command, options.LocalCommandExecutor)
return tc.runShellOrCommandOnSingleNode(ctx, clt, nodeAddrs[0].Addr, command, options.LocalCommandExecutor)
}

// ConnectToNode attempts to establish a connection to the node resolved to by the provided
Expand All @@ -1724,7 +1732,7 @@ func (tc *TeleportClient) SSH(ctx context.Context, command []string, opts ...fun
// fail the error from the connection attempt with the already provisioned certificates will
// be returned. The client from whichever attempt succeeds first will be returned.
func (tc *TeleportClient) ConnectToNode(ctx context.Context, clt *ClusterClient, nodeDetails NodeDetails, user string) (_ *NodeClient, err error) {
node := nodeName(targetNode{addr: nodeDetails.Addr})
node := nodeName(TargetNode{Addr: nodeDetails.Addr})
ctx, span := tc.Tracer.Start(
ctx,
"teleportClient/ConnectToNode",
Expand Down Expand Up @@ -1878,7 +1886,7 @@ func (m MFARequiredUnknownErr) Is(err error) bool {
// if it is required, then the mfa ceremony is attempted. The target host is dialed once the ceremony
// completes and new certificates are retrieved.
func (tc *TeleportClient) connectToNodeWithMFA(ctx context.Context, clt *ClusterClient, nodeDetails NodeDetails, user string) (*NodeClient, error) {
node := nodeName(targetNode{addr: nodeDetails.Addr})
node := nodeName(TargetNode{Addr: nodeDetails.Addr})
ctx, span := tc.Tracer.Start(
ctx,
"teleportClient/connectToNodeWithMFA",
Expand Down Expand Up @@ -1987,11 +1995,11 @@ func (tc *TeleportClient) runShellOrCommandOnSingleNode(ctx context.Context, clt
return trace.Wrap(nodeClient.RunInteractiveShell(ctx, types.SessionPeerMode, nil, tc.OnChannelRequest, nil))
}

func (tc *TeleportClient) runShellOrCommandOnMultipleNodes(ctx context.Context, clt *ClusterClient, nodes []targetNode, command []string) error {
func (tc *TeleportClient) runShellOrCommandOnMultipleNodes(ctx context.Context, clt *ClusterClient, nodes []TargetNode, command []string) error {
cluster := clt.ClusterName()
nodeAddrs := make([]string, 0, len(nodes))
for _, node := range nodes {
nodeAddrs = append(nodeAddrs, node.addr)
nodeAddrs = append(nodeAddrs, node.Addr)
}
ctx, span := tc.Tracer.Start(
ctx,
Expand Down Expand Up @@ -2687,7 +2695,7 @@ type execResult struct {
}

// runCommandOnNodes executes a given bash command on a bunch of remote nodes.
func (tc *TeleportClient) runCommandOnNodes(ctx context.Context, clt *ClusterClient, nodes []targetNode, command []string) error {
func (tc *TeleportClient) runCommandOnNodes(ctx context.Context, clt *ClusterClient, nodes []TargetNode, command []string) error {
cluster := clt.ClusterName()
ctx, span := tc.Tracer.Start(
ctx,
Expand All @@ -2705,7 +2713,7 @@ func (tc *TeleportClient) runCommandOnNodes(ctx context.Context, clt *ClusterCli
mfaRequiredCheck, err := clt.AuthClient.IsMFARequired(ctx, &proto.IsMFARequiredRequest{
Target: &proto.IsMFARequiredRequest_Node{
Node: &proto.NodeLogin{
Node: nodeName(targetNode{addr: nodes[0].addr}),
Node: nodeName(TargetNode{Addr: nodes[0].Addr}),
Login: tc.Config.HostLogin,
},
},
Expand All @@ -2731,19 +2739,19 @@ func (tc *TeleportClient) runCommandOnNodes(ctx context.Context, clt *ClusterCli
gctx,
"teleportClient/executingCommand",
oteltrace.WithSpanKind(oteltrace.SpanKindClient),
oteltrace.WithAttributes(attribute.String("node", node.addr)),
oteltrace.WithAttributes(attribute.String("node", node.Addr)),
)
defer span.End()

nodeClient, err := tc.ConnectToNode(
ctx,
clt,
NodeDetails{
Addr: node.addr,
Addr: node.Addr,
Namespace: tc.Namespace,
Cluster: cluster,
MFACheck: mfaRequiredCheck,
hostname: node.hostname,
hostname: node.Hostname,
},
tc.Config.HostLogin,
)
Expand Down
14 changes: 7 additions & 7 deletions lib/client/api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1289,37 +1289,37 @@ func TestGetTargetNodes(t *testing.T) {
host string
port int
clt fakeResourceClient
expected []targetNode
expected []TargetNode
}{
{
name: "options override",
options: SSHOptions{
HostAddress: "test:1234",
},
expected: []targetNode{{hostname: "test:1234", addr: "test:1234"}},
expected: []TargetNode{{Hostname: "test:1234", Addr: "test:1234"}},
},
{
name: "explicit target",
host: "test",
port: 1234,
expected: []targetNode{{hostname: "test", addr: "test:1234"}},
expected: []TargetNode{{Hostname: "test", Addr: "test:1234"}},
},
{
name: "labels",
labels: map[string]string{"foo": "bar"},
expected: []targetNode{{hostname: "labels", addr: "abcd:0"}},
expected: []TargetNode{{Hostname: "labels", Addr: "abcd:0"}},
clt: fakeResourceClient{nodes: []*types.ServerV2{{Metadata: types.Metadata{Name: "abcd"}, Spec: types.ServerSpecV2{Hostname: "labels"}}}},
},
{
name: "search",
search: []string{"foo", "bar"},
expected: []targetNode{{hostname: "search", addr: "abcd:0"}},
expected: []TargetNode{{Hostname: "search", Addr: "abcd:0"}},
clt: fakeResourceClient{nodes: []*types.ServerV2{{Metadata: types.Metadata{Name: "abcd"}, Spec: types.ServerSpecV2{Hostname: "search"}}}},
},
{
name: "predicate",
predicate: `resource.spec.hostname == "test"`,
expected: []targetNode{{hostname: "predicate", addr: "abcd:0"}},
expected: []TargetNode{{Hostname: "predicate", Addr: "abcd:0"}},
clt: fakeResourceClient{nodes: []*types.ServerV2{{Metadata: types.Metadata{Name: "abcd"}, Spec: types.ServerSpecV2{Hostname: "predicate"}}}},
},
}
Expand All @@ -1337,7 +1337,7 @@ func TestGetTargetNodes(t *testing.T) {
},
}

match, err := clt.getTargetNodes(context.Background(), test.clt, test.options)
match, err := clt.GetTargetNodes(context.Background(), test.clt, test.options)
require.NoError(t, err)
require.EqualValues(t, test.expected, match)
})
Expand Down
12 changes: 6 additions & 6 deletions lib/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -243,13 +243,13 @@ func (a sharedAuthClient) Close() error {
}

// nodeName removes the port number from the hostname, if present
func nodeName(node targetNode) string {
if node.hostname != "" {
return node.hostname
func nodeName(node TargetNode) string {
if node.Hostname != "" {
return node.Hostname
}
n, _, err := net.SplitHostPort(node.addr)
n, _, err := net.SplitHostPort(node.Addr)
if err != nil {
return node.addr
return node.Addr
}
return n
}
Expand All @@ -273,7 +273,7 @@ type NodeDetails struct {

// String returns a user-friendly name
func (n NodeDetails) String() string {
parts := []string{nodeName(targetNode{addr: n.Addr})}
parts := []string{nodeName(TargetNode{Addr: n.Addr})}
if n.Cluster != "" {
parts = append(parts, "on cluster", n.Cluster)
}
Expand Down
6 changes: 3 additions & 3 deletions lib/client/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,9 @@ import (
)

func TestHelperFunctions(t *testing.T) {
assert.Equal(t, "one", nodeName(targetNode{addr: "one"}))
assert.Equal(t, "one", nodeName(targetNode{addr: "one:22"}))
assert.Equal(t, "example.com", nodeName(targetNode{addr: "one", hostname: "example.com"}))
assert.Equal(t, "one", nodeName(TargetNode{Addr: "one"}))
assert.Equal(t, "one", nodeName(TargetNode{Addr: "one:22"}))
assert.Equal(t, "example.com", nodeName(TargetNode{Addr: "one", Hostname: "example.com"}))
}

func TestNewSession(t *testing.T) {
Expand Down
2 changes: 1 addition & 1 deletion lib/client/cluster_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -302,7 +302,7 @@ func (c *ClusterClient) SessionSSHConfig(ctx context.Context, user string, targe
key, err = c.performMFACeremony(ctx,
mfaClt,
ReissueParams{
NodeName: nodeName(targetNode{addr: target.Addr}),
NodeName: nodeName(TargetNode{Addr: target.Addr}),
RouteToCluster: target.Cluster,
MFACheck: target.MFACheck,
},
Expand Down
32 changes: 24 additions & 8 deletions tool/tsh/common/putty_config_windows.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import (

"github.com/gravitational/teleport/api/profile"
"github.com/gravitational/teleport/api/utils/keypaths"
"github.com/gravitational/teleport/lib/client"
"github.com/gravitational/teleport/lib/puttyhosts"
"github.com/gravitational/teleport/lib/utils/registry"
)
Expand Down Expand Up @@ -254,10 +255,32 @@ func onPuttyConfig(cf *CLIConf) error {
return trace.Wrap(err)
}

// connect to proxy to fetch cluster info
clusterClient, err := tc.ConnectToCluster(cf.Context)
if err != nil {
return trace.Wrap(err)
}
defer clusterClient.Close()

matches, err := tc.GetTargetNodes(cf.Context, clusterClient.AuthClient, client.SSHOptions{})
if err != nil {
return trace.Wrap(err)
}

switch len(matches) {
case 0:
return trace.NotFound("no matching hosts found")
case 1:
log.Debugf("Using host %v", matches[0])
default:
log.Debugf("found multiple matching hosts %v %v", matches[0], matches[1])
return trace.BadParameter("multiple matching hosts found")
}

// remove any spaces from the provided hostname. if the hostname contains a colon, it will be a
// hostname:port combination so we split it. this is useful as shorthand when adding OpenSSH hosts
// with `tsh puttyconfig user@host:22`, rather than using the longer `tsh puttyconfig --port 22 user@host`
hostname := strings.TrimSpace(tc.Config.Host)
hostname := strings.TrimSpace(matches[0].Hostname)
port := tc.Config.HostPort
if splitHost, splitPort, err := net.SplitHostPort(hostname); err == nil {
hostname = splitHost
Expand All @@ -280,13 +303,6 @@ func onPuttyConfig(cf *CLIConf) error {
userHostString = fmt.Sprintf("%v@%v", login, userHostString)
}

// connect to proxy to fetch cluster info
clusterClient, err := tc.ConnectToCluster(cf.Context)
if err != nil {
return trace.Wrap(err)
}
defer clusterClient.Close()

// parse out proxy details
proxyHost, _, err := net.SplitHostPort(tc.Config.SSHProxyAddr)
if err != nil {
Expand Down