diff --git a/lib/client/api.go b/lib/client/api.go index 9af2a4e8a8c20..7269dc90ab323 100644 --- a/lib/client/api.go +++ b/lib/client/api.go @@ -1372,14 +1372,19 @@ func (tc *TeleportClient) RootClusterName(ctx context.Context) (string, error) { return name, nil } -type targetNode struct { - hostname string - addr string -} - -// getTargetNodes returns a list of node addresses this SSH command needs to -// operate on. -func (tc *TeleportClient) getTargetNodes(ctx context.Context, clt client.ListUnifiedResourcesClient, options SSHOptions) ([]targetNode, error) { +// TargetNode contains information about a resolved host. +type TargetNode struct { + // Hostname of the target host. + Hostname string + // Address of the target host. For hosts resolved + // via auth this will be in the form of UUID:0. + Addr string +} + +// GetTargetNodes returns hosts matching the target host provided by users. Host resolution +// honors an explicit host, i.e. tsh ssh user@hostname, label based hosts, i.e. tsh ssh user@foo=bar, +// as well as respecting any proxy templates that are specified. +func (tc *TeleportClient) GetTargetNodes(ctx context.Context, clt client.ListUnifiedResourcesClient, options SSHOptions) ([]TargetNode, error) { ctx, span := tc.Tracer.Start( ctx, "teleportClient/getTargetNodes", @@ -1388,16 +1393,17 @@ func (tc *TeleportClient) getTargetNodes(ctx context.Context, clt client.ListUni defer span.End() if options.HostAddress != "" { - return []targetNode{ + return []TargetNode{ { - hostname: options.HostAddress, - addr: options.HostAddress, + Hostname: options.HostAddress, + Addr: options.HostAddress, }, }, nil } // Query for nodes if labels, fuzzy search, or predicate expressions were provided. if len(tc.Labels) > 0 || len(tc.SearchKeywords) > 0 || tc.PredicateExpression != "" { + log.Debugf("Attempting to resolve matching hosts from labels=%v|search=%v|predicate=%v", tc.Labels, tc.SearchKeywords, tc.PredicateExpression) nodes, err := client.GetAllUnifiedResources(ctx, clt, &proto.ListUnifiedResourcesRequest{ Kinds: []string{types.KindNode}, SortBy: types.SortBy{Field: types.ResourceMetadataName}, @@ -1410,7 +1416,7 @@ func (tc *TeleportClient) getTargetNodes(ctx context.Context, clt client.ListUni return nil, trace.Wrap(err) } - retval := make([]targetNode, 0, len(nodes)) + retval := make([]TargetNode, 0, len(nodes)) for _, resource := range nodes { server, ok := resource.ResourceWithLabels.(types.Server) if !ok { @@ -1418,15 +1424,17 @@ func (tc *TeleportClient) getTargetNodes(ctx context.Context, clt client.ListUni } // always dial nodes by UUID - retval = append(retval, targetNode{ - hostname: server.GetHostname(), - addr: fmt.Sprintf("%s:0", resource.GetName()), + retval = append(retval, TargetNode{ + Hostname: server.GetHostname(), + Addr: fmt.Sprintf("%s:0", resource.GetName()), }) } return retval, nil } + log.Debugf("Using provided host %s", tc.Host) + // detect the common error when users use host:port address format _, port, err := net.SplitHostPort(tc.Host) // client has used host:port notation @@ -1435,10 +1443,10 @@ func (tc *TeleportClient) getTargetNodes(ctx context.Context, clt client.ListUni } addr := net.JoinHostPort(tc.Host, strconv.Itoa(tc.HostPort)) - return []targetNode{ + return []TargetNode{ { - hostname: tc.Host, - addr: addr, + Hostname: tc.Host, + Addr: addr, }, }, nil } @@ -1704,7 +1712,7 @@ func (tc *TeleportClient) SSH(ctx context.Context, command []string, opts ...fun defer clt.Close() // which nodes are we executing this commands on? - nodeAddrs, err := tc.getTargetNodes(ctx, clt.AuthClient, options) + nodeAddrs, err := tc.GetTargetNodes(ctx, clt.AuthClient, options) if err != nil { return trace.Wrap(err) } @@ -1715,7 +1723,7 @@ func (tc *TeleportClient) SSH(ctx context.Context, command []string, opts ...fun if len(nodeAddrs) > 1 { return tc.runShellOrCommandOnMultipleNodes(ctx, clt, nodeAddrs, command) } - return tc.runShellOrCommandOnSingleNode(ctx, clt, nodeAddrs[0].addr, command, options.LocalCommandExecutor) + return tc.runShellOrCommandOnSingleNode(ctx, clt, nodeAddrs[0].Addr, command, options.LocalCommandExecutor) } // ConnectToNode attempts to establish a connection to the node resolved to by the provided @@ -1724,7 +1732,7 @@ func (tc *TeleportClient) SSH(ctx context.Context, command []string, opts ...fun // fail the error from the connection attempt with the already provisioned certificates will // be returned. The client from whichever attempt succeeds first will be returned. func (tc *TeleportClient) ConnectToNode(ctx context.Context, clt *ClusterClient, nodeDetails NodeDetails, user string) (_ *NodeClient, err error) { - node := nodeName(targetNode{addr: nodeDetails.Addr}) + node := nodeName(TargetNode{Addr: nodeDetails.Addr}) ctx, span := tc.Tracer.Start( ctx, "teleportClient/ConnectToNode", @@ -1878,7 +1886,7 @@ func (m MFARequiredUnknownErr) Is(err error) bool { // if it is required, then the mfa ceremony is attempted. The target host is dialed once the ceremony // completes and new certificates are retrieved. func (tc *TeleportClient) connectToNodeWithMFA(ctx context.Context, clt *ClusterClient, nodeDetails NodeDetails, user string) (*NodeClient, error) { - node := nodeName(targetNode{addr: nodeDetails.Addr}) + node := nodeName(TargetNode{Addr: nodeDetails.Addr}) ctx, span := tc.Tracer.Start( ctx, "teleportClient/connectToNodeWithMFA", @@ -1987,11 +1995,11 @@ func (tc *TeleportClient) runShellOrCommandOnSingleNode(ctx context.Context, clt return trace.Wrap(nodeClient.RunInteractiveShell(ctx, types.SessionPeerMode, nil, tc.OnChannelRequest, nil)) } -func (tc *TeleportClient) runShellOrCommandOnMultipleNodes(ctx context.Context, clt *ClusterClient, nodes []targetNode, command []string) error { +func (tc *TeleportClient) runShellOrCommandOnMultipleNodes(ctx context.Context, clt *ClusterClient, nodes []TargetNode, command []string) error { cluster := clt.ClusterName() nodeAddrs := make([]string, 0, len(nodes)) for _, node := range nodes { - nodeAddrs = append(nodeAddrs, node.addr) + nodeAddrs = append(nodeAddrs, node.Addr) } ctx, span := tc.Tracer.Start( ctx, @@ -2687,7 +2695,7 @@ type execResult struct { } // runCommandOnNodes executes a given bash command on a bunch of remote nodes. -func (tc *TeleportClient) runCommandOnNodes(ctx context.Context, clt *ClusterClient, nodes []targetNode, command []string) error { +func (tc *TeleportClient) runCommandOnNodes(ctx context.Context, clt *ClusterClient, nodes []TargetNode, command []string) error { cluster := clt.ClusterName() ctx, span := tc.Tracer.Start( ctx, @@ -2705,7 +2713,7 @@ func (tc *TeleportClient) runCommandOnNodes(ctx context.Context, clt *ClusterCli mfaRequiredCheck, err := clt.AuthClient.IsMFARequired(ctx, &proto.IsMFARequiredRequest{ Target: &proto.IsMFARequiredRequest_Node{ Node: &proto.NodeLogin{ - Node: nodeName(targetNode{addr: nodes[0].addr}), + Node: nodeName(TargetNode{Addr: nodes[0].Addr}), Login: tc.Config.HostLogin, }, }, @@ -2731,7 +2739,7 @@ func (tc *TeleportClient) runCommandOnNodes(ctx context.Context, clt *ClusterCli gctx, "teleportClient/executingCommand", oteltrace.WithSpanKind(oteltrace.SpanKindClient), - oteltrace.WithAttributes(attribute.String("node", node.addr)), + oteltrace.WithAttributes(attribute.String("node", node.Addr)), ) defer span.End() @@ -2739,11 +2747,11 @@ func (tc *TeleportClient) runCommandOnNodes(ctx context.Context, clt *ClusterCli ctx, clt, NodeDetails{ - Addr: node.addr, + Addr: node.Addr, Namespace: tc.Namespace, Cluster: cluster, MFACheck: mfaRequiredCheck, - hostname: node.hostname, + hostname: node.Hostname, }, tc.Config.HostLogin, ) diff --git a/lib/client/api_test.go b/lib/client/api_test.go index 3dcf18b2aa15b..670c3398d8c42 100644 --- a/lib/client/api_test.go +++ b/lib/client/api_test.go @@ -1289,37 +1289,37 @@ func TestGetTargetNodes(t *testing.T) { host string port int clt fakeResourceClient - expected []targetNode + expected []TargetNode }{ { name: "options override", options: SSHOptions{ HostAddress: "test:1234", }, - expected: []targetNode{{hostname: "test:1234", addr: "test:1234"}}, + expected: []TargetNode{{Hostname: "test:1234", Addr: "test:1234"}}, }, { name: "explicit target", host: "test", port: 1234, - expected: []targetNode{{hostname: "test", addr: "test:1234"}}, + expected: []TargetNode{{Hostname: "test", Addr: "test:1234"}}, }, { name: "labels", labels: map[string]string{"foo": "bar"}, - expected: []targetNode{{hostname: "labels", addr: "abcd:0"}}, + expected: []TargetNode{{Hostname: "labels", Addr: "abcd:0"}}, clt: fakeResourceClient{nodes: []*types.ServerV2{{Metadata: types.Metadata{Name: "abcd"}, Spec: types.ServerSpecV2{Hostname: "labels"}}}}, }, { name: "search", search: []string{"foo", "bar"}, - expected: []targetNode{{hostname: "search", addr: "abcd:0"}}, + expected: []TargetNode{{Hostname: "search", Addr: "abcd:0"}}, clt: fakeResourceClient{nodes: []*types.ServerV2{{Metadata: types.Metadata{Name: "abcd"}, Spec: types.ServerSpecV2{Hostname: "search"}}}}, }, { name: "predicate", predicate: `resource.spec.hostname == "test"`, - expected: []targetNode{{hostname: "predicate", addr: "abcd:0"}}, + expected: []TargetNode{{Hostname: "predicate", Addr: "abcd:0"}}, clt: fakeResourceClient{nodes: []*types.ServerV2{{Metadata: types.Metadata{Name: "abcd"}, Spec: types.ServerSpecV2{Hostname: "predicate"}}}}, }, } @@ -1337,7 +1337,7 @@ func TestGetTargetNodes(t *testing.T) { }, } - match, err := clt.getTargetNodes(context.Background(), test.clt, test.options) + match, err := clt.GetTargetNodes(context.Background(), test.clt, test.options) require.NoError(t, err) require.EqualValues(t, test.expected, match) }) diff --git a/lib/client/client.go b/lib/client/client.go index b3defdd593ba4..825b0ebfa9115 100644 --- a/lib/client/client.go +++ b/lib/client/client.go @@ -243,13 +243,13 @@ func (a sharedAuthClient) Close() error { } // nodeName removes the port number from the hostname, if present -func nodeName(node targetNode) string { - if node.hostname != "" { - return node.hostname +func nodeName(node TargetNode) string { + if node.Hostname != "" { + return node.Hostname } - n, _, err := net.SplitHostPort(node.addr) + n, _, err := net.SplitHostPort(node.Addr) if err != nil { - return node.addr + return node.Addr } return n } @@ -273,7 +273,7 @@ type NodeDetails struct { // String returns a user-friendly name func (n NodeDetails) String() string { - parts := []string{nodeName(targetNode{addr: n.Addr})} + parts := []string{nodeName(TargetNode{Addr: n.Addr})} if n.Cluster != "" { parts = append(parts, "on cluster", n.Cluster) } diff --git a/lib/client/client_test.go b/lib/client/client_test.go index 8a7273eb04d95..432f51a47de17 100644 --- a/lib/client/client_test.go +++ b/lib/client/client_test.go @@ -40,9 +40,9 @@ import ( ) func TestHelperFunctions(t *testing.T) { - assert.Equal(t, "one", nodeName(targetNode{addr: "one"})) - assert.Equal(t, "one", nodeName(targetNode{addr: "one:22"})) - assert.Equal(t, "example.com", nodeName(targetNode{addr: "one", hostname: "example.com"})) + assert.Equal(t, "one", nodeName(TargetNode{Addr: "one"})) + assert.Equal(t, "one", nodeName(TargetNode{Addr: "one:22"})) + assert.Equal(t, "example.com", nodeName(TargetNode{Addr: "one", Hostname: "example.com"})) } func TestNewSession(t *testing.T) { diff --git a/lib/client/cluster_client.go b/lib/client/cluster_client.go index 71754507a847d..024e646ffb2d9 100644 --- a/lib/client/cluster_client.go +++ b/lib/client/cluster_client.go @@ -302,7 +302,7 @@ func (c *ClusterClient) SessionSSHConfig(ctx context.Context, user string, targe key, err = c.performMFACeremony(ctx, mfaClt, ReissueParams{ - NodeName: nodeName(targetNode{addr: target.Addr}), + NodeName: nodeName(TargetNode{Addr: target.Addr}), RouteToCluster: target.Cluster, MFACheck: target.MFACheck, }, diff --git a/tool/tsh/common/putty_config_windows.go b/tool/tsh/common/putty_config_windows.go index 0c5fb35f2ae0e..4df114c1259a5 100644 --- a/tool/tsh/common/putty_config_windows.go +++ b/tool/tsh/common/putty_config_windows.go @@ -29,6 +29,7 @@ import ( "github.com/gravitational/teleport/api/profile" "github.com/gravitational/teleport/api/utils/keypaths" + "github.com/gravitational/teleport/lib/client" "github.com/gravitational/teleport/lib/puttyhosts" "github.com/gravitational/teleport/lib/utils/registry" ) @@ -254,10 +255,32 @@ func onPuttyConfig(cf *CLIConf) error { return trace.Wrap(err) } + // connect to proxy to fetch cluster info + clusterClient, err := tc.ConnectToCluster(cf.Context) + if err != nil { + return trace.Wrap(err) + } + defer clusterClient.Close() + + matches, err := tc.GetTargetNodes(cf.Context, clusterClient.AuthClient, client.SSHOptions{}) + if err != nil { + return trace.Wrap(err) + } + + switch len(matches) { + case 0: + return trace.NotFound("no matching hosts found") + case 1: + log.Debugf("Using host %v", matches[0]) + default: + log.Debugf("found multiple matching hosts %v %v", matches[0], matches[1]) + return trace.BadParameter("multiple matching hosts found") + } + // remove any spaces from the provided hostname. if the hostname contains a colon, it will be a // hostname:port combination so we split it. this is useful as shorthand when adding OpenSSH hosts // with `tsh puttyconfig user@host:22`, rather than using the longer `tsh puttyconfig --port 22 user@host` - hostname := strings.TrimSpace(tc.Config.Host) + hostname := strings.TrimSpace(matches[0].Hostname) port := tc.Config.HostPort if splitHost, splitPort, err := net.SplitHostPort(hostname); err == nil { hostname = splitHost @@ -280,13 +303,6 @@ func onPuttyConfig(cf *CLIConf) error { userHostString = fmt.Sprintf("%v@%v", login, userHostString) } - // connect to proxy to fetch cluster info - clusterClient, err := tc.ConnectToCluster(cf.Context) - if err != nil { - return trace.Wrap(err) - } - defer clusterClient.Close() - // parse out proxy details proxyHost, _, err := net.SplitHostPort(tc.Config.SSHProxyAddr) if err != nil {