diff --git a/client/embed/embed.go b/client/embed/embed.go index 88f7e541c81..16b65781081 100644 --- a/client/embed/embed.go +++ b/client/embed/embed.go @@ -507,3 +507,14 @@ func (c *Client) getNet() (*wgnetstack.Net, netip.Addr, error) { return nsnet, addr, nil } + +// GetDNSAddrPort returns the address of the NetBird DNS resolver for this client. +// Returns the zero AddrPort and false if the client is not started or the DNS +// server is not yet initialized. +func (c *Client) GetDNSAddrPort() (netip.AddrPort, bool) { + engine, err := c.getEngine() + if err != nil { + return netip.AddrPort{}, false + } + return engine.GetDNSAddrPort() +} diff --git a/client/internal/dns/mock_server.go b/client/internal/dns/mock_server.go index 548b1f54f9f..9cb56d51a11 100644 --- a/client/internal/dns/mock_server.go +++ b/client/internal/dns/mock_server.go @@ -53,6 +53,10 @@ func (m *MockServer) DnsIP() netip.Addr { return netip.MustParseAddr("100.10.254.255") } +func (m *MockServer) DnsAddrPort() netip.AddrPort { + return netip.AddrPortFrom(netip.MustParseAddr("100.10.254.255"), 53) +} + func (m *MockServer) OnUpdatedHostDNSServer(addrs []netip.AddrPort) { // TODO implement me panic("implement me") diff --git a/client/internal/dns/server.go b/client/internal/dns/server.go index f7865047b50..9adb1bc980f 100644 --- a/client/internal/dns/server.go +++ b/client/internal/dns/server.go @@ -51,6 +51,8 @@ type Server interface { Initialize() error Stop() DnsIP() netip.Addr + // DnsAddrPort returns the full address (IP + port) of the DNS resolver. + DnsAddrPort() netip.AddrPort UpdateDNSServer(serial uint64, update nbdns.Config) error OnUpdatedHostDNSServer(addrs []netip.AddrPort) SearchDomains() []string @@ -380,6 +382,11 @@ func (s *DefaultServer) DnsIP() netip.Addr { return s.service.RuntimeIP() } +// DnsAddrPort returns the full address (IP + port) of the DNS resolver. +func (s *DefaultServer) DnsAddrPort() netip.AddrPort { + return netip.AddrPortFrom(s.service.RuntimeIP(), uint16(s.service.RuntimePort())) +} + // SetFirewall sets the firewall used for DNS port DNAT rules. // This must be called before Initialize when using the listener-based service, // because the firewall is typically not available at construction time. diff --git a/client/internal/engine.go b/client/internal/engine.go index be2d8bbf353..fda51b3d2e5 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -2070,6 +2070,22 @@ func (e *Engine) GetWgAddr() netip.Addr { return e.wgInterface.Address().IP } +// GetDNSAddrPort returns the DNS server address (IP + port) used by this engine. +// Returns the zero AddrPort and false if the DNS server is not yet initialized. +func (e *Engine) GetDNSAddrPort() (netip.AddrPort, bool) { + e.syncMsgMux.Lock() + dnsServer := e.dnsServer + e.syncMsgMux.Unlock() + if dnsServer == nil { + return netip.AddrPort{}, false + } + addr := dnsServer.DnsAddrPort() + if !addr.IsValid() { + return netip.AddrPort{}, false + } + return addr, true +} + func (e *Engine) RenewTun(fd int) error { e.syncMsgMux.Lock() wgInterface := e.wgInterface diff --git a/proxy/internal/roundtrip/dns.go b/proxy/internal/roundtrip/dns.go new file mode 100644 index 00000000000..6a03bb7222e --- /dev/null +++ b/proxy/internal/roundtrip/dns.go @@ -0,0 +1,91 @@ +package roundtrip + +import ( + "context" + "fmt" + "net" + "net/netip" +) + +// dialWithDNSResolution wraps a DialContext function so that target addresses +// containing hostnames (rather than IPs) are resolved through NetBird's own +// DNS infrastructure before the connection is dialed. +// +// getDNSAddr is called on every dial that requires hostname resolution; it +// should return the current NetBird DNS server address (IP + port) and true. +// When the DNS server is not yet available it should return false, in which +// case resolution falls back to the process-level default resolver. +// +// The resolver dials the DNS server using the same underlying dial function +// (i.e. through the WireGuard netstack), because in userspace / netstack mode +// the DNS server is reachable only via the virtual WireGuard interface. +func dialWithDNSResolution( + getDNSAddr func() (netip.AddrPort, bool), + dial func(ctx context.Context, network, addr string) (net.Conn, error), +) func(ctx context.Context, network, addr string) (net.Conn, error) { + return func(ctx context.Context, network, addr string) (net.Conn, error) { + host, port, err := net.SplitHostPort(addr) + if err != nil { + // Malformed address — let the underlying dialer handle or fail it. + return dial(ctx, network, addr) + } + + // If the host is already an IP literal, skip resolution entirely. + if _, err := netip.ParseAddr(host); err == nil { + return dial(ctx, network, addr) + } + + resolved, err := resolveHost(ctx, host, getDNSAddr, dial) + if err != nil { + return nil, err + } + + return dial(ctx, network, net.JoinHostPort(resolved, port)) + } +} + +// resolveHost resolves a hostname to its first IPv4/IPv6 address using a +// custom net.Resolver backed by the NetBird DNS server (when available) or +// the process-level default resolver as a fallback. +func resolveHost( + ctx context.Context, + host string, + getDNSAddr func() (netip.AddrPort, bool), + dial func(ctx context.Context, network, addr string) (net.Conn, error), +) (string, error) { + resolver := buildResolver(getDNSAddr, dial) + + addrs, err := resolver.LookupHost(ctx, host) + if err != nil { + return "", fmt.Errorf("dns: resolve %q: %w", host, err) + } + if len(addrs) == 0 { + return "", fmt.Errorf("dns: no addresses returned for %q", host) + } + return addrs[0], nil +} + +// buildResolver returns a *net.Resolver configured to query the NetBird DNS +// server via the provided dial function. If the DNS server address is not +// yet available, the default system resolver is returned so that the caller +// can still attempt resolution (useful during client startup). +func buildResolver( + getDNSAddr func() (netip.AddrPort, bool), + dial func(ctx context.Context, network, addr string) (net.Conn, error), +) *net.Resolver { + dnsAddr, ok := getDNSAddr() + if !ok { + return net.DefaultResolver + } + + addrStr := dnsAddr.String() + return &net.Resolver{ + PreferGo: true, + Dial: func(ctx context.Context, _, _ string) (net.Conn, error) { + // Always use UDP toward the DNS server. The network and address + // arguments passed by net.Resolver are intentionally ignored; + // we route through the WireGuard netstack instead. + return dial(ctx, "udp", addrStr) + }, + } +} diff --git a/proxy/internal/roundtrip/netbird.go b/proxy/internal/roundtrip/netbird.go index e38e3dc4ef5..6a8b74380bb 100644 --- a/proxy/internal/roundtrip/netbird.go +++ b/proxy/internal/roundtrip/netbird.go @@ -276,7 +276,7 @@ func (n *NetBird) createClientEntry(ctx context.Context, accountID types.Account // the client's HTTPClient to avoid issues with request validation that do // not work with reverse proxied requests. transport := &http.Transport{ - DialContext: dialWithTimeout(client.DialContext), + DialContext: dialWithDNSResolution(client.GetDNSAddrPort, dialWithTimeout(client.DialContext)), ForceAttemptHTTP2: true, MaxIdleConns: n.transportCfg.maxIdleConns, MaxIdleConnsPerHost: n.transportCfg.maxIdleConnsPerHost,