Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 31 additions & 28 deletions lib/srv/desktop/discovery.go
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,8 @@ func (s *WindowsService) applyLabelsFromLDAP(entry *ldap.Entry, labels map[strin
}
}

const dnsQueryTimeout = 5 * time.Second

// lookupDesktop does a DNS lookup for the provided hostname.
// It checks using the default system resolver first, and falls
// back to the configured LDAP server if the system resolver fails.
Expand All @@ -194,56 +196,57 @@ func (s *WindowsService) lookupDesktop(ctx context.Context, hostname string) ([]
return result
}

const queryTimeout = 5 * time.Second

queryResolver := func(resolver *net.Resolver, resolverName string) chan []netip.Addr {
ch := make(chan []netip.Addr, 1)
if resolver != nil {
go func() {
tctx, cancel := context.WithTimeout(ctx, queryTimeout)
defer cancel()

addrs, err := resolver.LookupNetIP(tctx, "ip4", hostname)
if err != nil {
s.cfg.Log.Debugf("DNS lookup for %v failed with %s resolver: %v",
hostname, resolverName, err)
}
if len(addrs) > 0 {
ch <- addrs
}
}()
}
go func() {
tctx, cancel := context.WithTimeout(ctx, dnsQueryTimeout)
defer cancel()

addrs, err := resolver.LookupNetIP(tctx, "ip4", hostname)
if err != nil {
s.cfg.Log.Debugf("DNS lookup for %v failed with %s resolver: %v",
hostname, resolverName, err)
}

ch <- addrs

}()
return ch
}

// kick off both DNS queries in parallel
defaultResult := queryResolver(net.DefaultResolver, "default")
ldapResult := queryResolver(s.dnsResolver, "LDAP")

// wait 5 seconds for the default resolver to return
select {
case addrs := <-defaultResult:
// wait for the default resolver to return (or time out)
addrs := <-defaultResult
if len(addrs) > 0 {
return stringAddrs(addrs), nil
case <-s.cfg.Clock.After(5 * time.Second):
}

// If we didn't get a result from the default resolver,
// the result from the LDAP resolver is either available
// now or we're done. There's no more waiting.
select {
case addrs := <-ldapResult:
// use the result from the LDAP resolver.
// This shouldn't block for very long, since both operations
// started at the same time with the same timeout.
addrs = <-ldapResult
if len(addrs) > 0 {
return stringAddrs(addrs), nil
default:
return nil, trace.Errorf("could not resolve %v in time", hostname)
}

return nil, trace.Errorf("could not resolve %v in time", hostname)
}

// ldapEntryToWindowsDesktop generates the Windows Desktop resource
// from an LDAP search result
func (s *WindowsService) ldapEntryToWindowsDesktop(ctx context.Context, entry *ldap.Entry, getHostLabels func(string) map[string]string) (types.ResourceWithLabels, error) {
hostname := entry.GetAttributeValue(windows.AttrDNSHostName)
if hostname == "" {
return nil, trace.BadParameter("LDAP entry missing hostname, has attributes: %v", entry.Attributes)
attrs := make([]string, len(entry.Attributes))
for _, a := range entry.Attributes {
attrs = append(attrs, fmt.Sprintf("%v=%v", a.Name, a.Values))
}
s.cfg.Log.Debugf("LDAP entry %v is missing hostname, has attributes %v", entry.DN, strings.Join(attrs, ","))
return nil, trace.BadParameter("LDAP entry %v missing hostname", entry.DN)
}
labels := getHostLabels(hostname)
labels[types.DiscoveryLabelWindowsDomain] = s.cfg.Domain
Expand Down
34 changes: 34 additions & 0 deletions lib/srv/desktop/discovery_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,22 @@
package desktop

import (
"context"
"errors"
"io"
"net"
"strconv"
"testing"
"time"

"github.com/go-ldap/ldap/v3"
"github.com/jonboulle/clockwork"
"github.com/sirupsen/logrus"
"github.com/stretchr/testify/require"

"github.com/gravitational/teleport/api/types"
"github.com/gravitational/teleport/lib/auth/windows"
"github.com/gravitational/teleport/lib/utils"
)

// TestDiscoveryLDAPFilter verifies that WindowsService produces a valid
Expand Down Expand Up @@ -135,3 +143,29 @@ func TestLabelsDomainControllers(t *testing.T) {
})
}
}

// TestDNSErrors verifies that errors are handled quickly
// and do not block discovery for too long.
func TestDNSErrors(t *testing.T) {
logger := utils.NewLoggerForTests()
logger.SetLevel(logrus.PanicLevel)
logger.SetOutput(io.Discard)

s := &WindowsService{
cfg: WindowsServiceConfig{
Log: logger,
Clock: clockwork.NewRealClock(),
},
dnsResolver: &net.Resolver{
PreferGo: true,
Dial: func(ctx context.Context, network, address string) (net.Conn, error) {
return nil, errors.New("this resolver always fails")
},
},
}

start := time.Now()
_, err := s.lookupDesktop(context.Background(), "$invalid hostname")
require.Less(t, time.Since(start), dnsQueryTimeout-1*time.Second)
require.Error(t, err)
}