From 299f0a384e5dd12b6f99f28d14656920bec1aa8f Mon Sep 17 00:00:00 2001 From: Pavel Evdokimov Date: Mon, 9 Feb 2026 20:24:21 +0000 Subject: [PATCH 1/2] stop forwarding unmatched dns queries to upstream servers on linux, delegate it to systemd-resolved --- lib/vnet/dns/dns.go | 119 ++++++++++++++++--- lib/vnet/dns/dns_test.go | 134 +++++++++++++++++++++- lib/vnet/dns/osnameservers_darwin_test.go | 2 +- lib/vnet/dns/osnameservers_linux.go | 33 ++++++ 4 files changed, 272 insertions(+), 16 deletions(-) create mode 100644 lib/vnet/dns/osnameservers_linux.go diff --git a/lib/vnet/dns/dns.go b/lib/vnet/dns/dns.go index fac646f231256..f4083d0bf0f54 100644 --- a/lib/vnet/dns/dns.go +++ b/lib/vnet/dns/dns.go @@ -22,6 +22,7 @@ import ( "io" "log/slog" "net" + "runtime" "sync" "time" @@ -77,6 +78,16 @@ type UpstreamNameserverSource interface { UpstreamNameservers(context.Context) ([]string, error) } +// DNSMode controls how the server handles unmatched queries. +type DNSMode int + +const ( + // DNSModeRecursive forwards unmatched queries to upstream nameservers. + DNSModeRecursive DNSMode = iota + // DNSModeAuthoritative returns negative responses for unmatched queries. + DNSModeAuthoritative +) + // Server is a DNS server. type Server struct { resolver Resolver @@ -127,13 +138,13 @@ func (s *Server) HandleUDP(ctx context.Context, conn net.Conn) error { } buf = buf[:n] - return trace.Wrap(s.handleDNSMessage(ctx, conn.RemoteAddr().String(), buf, conn)) + return trace.Wrap(s.handleDNSMessage(ctx, conn.RemoteAddr().String(), buf, conn, s.platformDNSMode())) } // ListendAndServeUDP reads all incoming UDP messages from [conn], handles DNS questions, and writes the // responses back to [conn]. // This is not called by VNet code and basically exists so we can test the resolver outside of VNet. -func (s *Server) ListenAndServeUDP(ctx context.Context, conn *net.UDPConn) error { +func (s *Server) ListenAndServeUDP(ctx context.Context, conn *net.UDPConn, mode DNSMode) error { buf, returnBuf := s.getMessageBuffer() defer returnBuf() @@ -152,7 +163,7 @@ func (s *Server) ListenAndServeUDP(ctx context.Context, conn *net.UDPConn) error conn: conn, remoteAddr: remoteAddr, } - if err := s.handleDNSMessage(ctx, remoteAddr.String(), buf, responseWriter); err != nil { + if err := s.handleDNSMessage(ctx, remoteAddr.String(), buf, responseWriter, mode); err != nil { s.slog.DebugContext(ctx, "Error handling DNS message.", "error", err) } } @@ -168,9 +179,18 @@ func (u *udpWriter) Write(b []byte) (int, error) { return n, err } +func (s *Server) platformDNSMode() DNSMode { + switch runtime.GOOS { + case "linux": + return DNSModeAuthoritative + default: + return DNSModeRecursive + } +} + // handleDNSMessage handles the DNS message held in [buf] and writes the answer to [responseWriter]. // This could handle DNS messages arriving over UDP or TCP. -func (s *Server) handleDNSMessage(ctx context.Context, remoteAddr string, buf []byte, responseWriter io.Writer) error { +func (s *Server) handleDNSMessage(ctx context.Context, remoteAddr string, buf []byte, responseWriter io.Writer, mode DNSMode) error { slog := s.slog.With("remote_addr", remoteAddr) slog.DebugContext(ctx, "Handling DNS message.") defer slog.DebugContext(ctx, "Done handling DNS message.") @@ -181,8 +201,21 @@ func (s *Server) handleDNSMessage(ctx context.Context, remoteAddr string, buf [] return trace.Wrap(err, "parsing DNS message") } if requestHeader.OpCode != 0 { - slog.DebugContext(ctx, "OpCode is not QUERY (0), forwarding.", "opcode", requestHeader.OpCode) - return trace.Wrap(s.forward(ctx, slog, buf, responseWriter), "forwarding non-Query DNS message") + switch mode { + case DNSModeRecursive: + slog.DebugContext(ctx, "OpCode is not QUERY (0), forwarding.", "opcode", requestHeader.OpCode) + return trace.Wrap(s.forward(ctx, slog, buf, responseWriter), "forwarding non-Query DNS message") + case DNSModeAuthoritative: + // RFC 8906 section 3.1.4 recommends NOTIMP for unsupported opcodes. + slog.DebugContext(ctx, "OpCode is not QUERY (0), responding with NOTIMP.", "opcode", requestHeader.OpCode) + question, qerr := parser.Question() + if qerr != nil { + return trace.Wrap(qerr, "parsing DNS question") + } + return trace.Wrap(writeNotImpl(buf, &requestHeader, &question, responseWriter), "authoritative non-Query DNS message") + default: + return trace.BadParameter("unknown DNS mode %v", mode) + } } question, err := parser.Question() if err != nil { @@ -192,8 +225,17 @@ func (s *Server) handleDNSMessage(ctx context.Context, remoteAddr string, buf [] slog = slog.With("fqdn", fqdn, "type", question.Type.String()) slog.DebugContext(ctx, "Received DNS question.", "question", question) if question.Class != dnsmessage.ClassINET { - slog.DebugContext(ctx, "Query class is not INET, forwarding.", "class", question.Class) - return trace.Wrap(s.forward(ctx, slog, buf, responseWriter), "forwarding non-INET DNS query") + switch mode { + case DNSModeRecursive: + slog.DebugContext(ctx, "Query class is not INET, forwarding.", "class", question.Class) + return trace.Wrap(s.forward(ctx, slog, buf, responseWriter), "forwarding non-INET DNS query") + case DNSModeAuthoritative: + // RFC 8906 section 3.1.2 recommends NXDOMAIN or NOERROR for unsupported type queries. + slog.DebugContext(ctx, "Query class is not INET, responding with NXDOMAIN.", "class", question.Class) + return trace.Wrap(writeNXDomain(buf, &requestHeader, &question, responseWriter), "authoritative non-INET DNS query") + default: + return trace.BadParameter("unknown DNS mode %v", mode) + } } var result Result @@ -209,8 +251,17 @@ func (s *Server) handleDNSMessage(ctx context.Context, remoteAddr string, buf [] return trace.Wrap(err, "resolving AAAA request for %q", fqdn) } default: - slog.DebugContext(ctx, "Question type is not A or AAAA, forwarding.", "type", question.Type) - return trace.Wrap(s.forward(ctx, slog, buf, responseWriter), "forwarding %s DNS query", question.Type) + switch mode { + case DNSModeRecursive: + slog.DebugContext(ctx, "Question type is not A or AAAA, forwarding.", "type", question.Type) + return trace.Wrap(s.forward(ctx, slog, buf, responseWriter), "forwarding %s DNS query", question.Type) + case DNSModeAuthoritative: + // RFC 8906 section 3.1.2 recommends NXDOMAIN or NOERROR for unsupported type queries. + slog.DebugContext(ctx, "Question type is not A or AAAA, responding with NXDOMAIN.", "type", question.Type) + return trace.Wrap(writeNXDomain(buf, &requestHeader, &question, responseWriter), "authoritative %s DNS query", question.Type) + default: + return trace.BadParameter("unknown DNS mode %v", mode) + } } var response []byte @@ -228,8 +279,16 @@ func (s *Server) handleDNSMessage(ctx context.Context, remoteAddr string, buf [] slog.DebugContext(ctx, "Matched DNS AAAA.", "aaaa", tcpip.AddrFrom16(result.AAAA)) response, err = buildAAAAResponse(buf, &requestHeader, &question, result.AAAA) default: - slog.DebugContext(ctx, "Forwarding unmatched query.") - return trace.Wrap(s.forward(ctx, slog, buf, responseWriter), "forwarding unmatched DNS query") + switch mode { + case DNSModeRecursive: + slog.DebugContext(ctx, "Forwarding unmatched query.") + return trace.Wrap(s.forward(ctx, slog, buf, responseWriter), "forwarding unmatched DNS query") + case DNSModeAuthoritative: + slog.DebugContext(ctx, "Unmatched query, responding with NXDOMAIN.") + return trace.Wrap(writeNXDomain(buf, &requestHeader, &question, responseWriter), "authoritative unmatched DNS query") + default: + return trace.BadParameter("unknown DNS mode %v", mode) + } } if err != nil { return trace.Wrap(err) @@ -343,6 +402,16 @@ func buildNXDomainResponse(buf []byte, requestHeader *dnsmessage.Header, questio return buf, trace.Wrap(err, "serializing DNS response") } +func buildNotImplResponse(buf []byte, requestHeader *dnsmessage.Header, question *dnsmessage.Question) ([]byte, error) { + responseBuilder, err := prepDNSResponse(buf, requestHeader, question, dnsmessage.RCodeNotImplemented) + if err != nil { + return buf, trace.Wrap(err) + } + // TODO(nklaassen): TTL in SOA record? + buf, err = responseBuilder.Finish() + return buf, trace.Wrap(err, "serializing DNS response") +} + func buildAResponse(buf []byte, requestHeader *dnsmessage.Header, question *dnsmessage.Question, addr [4]byte) ([]byte, error) { responseBuilder, err := prepDNSResponse(buf, requestHeader, question, dnsmessage.RCodeSuccess) if err != nil { @@ -383,13 +452,37 @@ func buildAAAAResponse(buf []byte, requestHeader *dnsmessage.Header, question *d return buf, trace.Wrap(err, "serializing DNS response") } +func writeNXDomain(buf []byte, requestHeader *dnsmessage.Header, question *dnsmessage.Question, responseWriter io.Writer) error { + if question == nil { + return trace.Errorf("missing DNS question for NXDOMAIN response") + } + response, err := buildNXDomainResponse(buf, requestHeader, question) + if err != nil { + return trace.Wrap(err) + } + _, err = responseWriter.Write(response) + return trace.Wrap(err, "writing DNS response") +} + +func writeNotImpl(buf []byte, requestHeader *dnsmessage.Header, question *dnsmessage.Question, responseWriter io.Writer) error { + if question == nil { + return trace.Errorf("missing DNS question for NOTIMP response") + } + response, err := buildNotImplResponse(buf, requestHeader, question) + if err != nil { + return trace.Wrap(err) + } + _, err = responseWriter.Write(response) + return trace.Wrap(err, "writing DNS response") +} + func prepDNSResponse(buf []byte, requestHeader *dnsmessage.Header, question *dnsmessage.Question, rcode dnsmessage.RCode) (*dnsmessage.Builder, error) { buf = buf[:0] responseBuilder := dnsmessage.NewBuilder(buf, dnsmessage.Header{ ID: requestHeader.ID, Response: true, Authoritative: true, - RCode: dnsmessage.RCodeSuccess, + RCode: rcode, }) responseBuilder.EnableCompression() if err := responseBuilder.StartQuestions(); err != nil { diff --git a/lib/vnet/dns/dns_test.go b/lib/vnet/dns/dns_test.go index b89418b5df646..ca5fd4a89ded5 100644 --- a/lib/vnet/dns/dns_test.go +++ b/lib/vnet/dns/dns_test.go @@ -67,7 +67,7 @@ func TestServer(t *testing.T) { testutils.RunTestBackgroundTask(ctx, t, &testutils.TestBackgroundTask{ Name: fmt.Sprintf("upstream nameserver %d", i), Task: func(ctx context.Context) error { - err := upstreamServer.ListenAndServeUDP(ctx, conn) + err := upstreamServer.ListenAndServeUDP(ctx, conn, DNSModeRecursive) if err == nil || utils.IsOKNetworkError(err) { return nil } @@ -117,7 +117,7 @@ func TestServer(t *testing.T) { testutils.RunTestBackgroundTask(ctx, t, &testutils.TestBackgroundTask{ Name: "nameserver under test", Task: func(ctx context.Context) error { - err := server.ListenAndServeUDP(ctx, conn) + err := server.ListenAndServeUDP(ctx, conn, DNSModeRecursive) if err == nil || utils.IsOKNetworkError(err) { return nil } @@ -173,6 +173,136 @@ func TestServer(t *testing.T) { } } +func TestServerAuthoritativeMode(t *testing.T) { + t.Parallel() + ctx := context.Background() + + defaultIP4 := tcpip.AddrFrom4([4]byte{1, 2, 3, 4}) + defaultIP6 := tcpip.AddrFrom16([16]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}) + + staticResolver := &staticResolver{Result{ + A: defaultIP4.As4(), + AAAA: defaultIP6.As16(), + }} + noUpstreams := &stubUpstreamNamservers{} + + // Create two upstream nameservers that are able to resolve A and AAAA records for all names. + var upstreamAddrs []string + for i := range 2 { + upstreamServer, err := NewServer(staticResolver, noUpstreams) + require.NoError(t, err) + conn, err := net.ListenUDP("udp", udpLocalhost) + require.NoError(t, err) + + testutils.RunTestBackgroundTask(ctx, t, &testutils.TestBackgroundTask{ + Name: fmt.Sprintf("upstream nameserver %d", i), + Task: func(ctx context.Context) error { + err := upstreamServer.ListenAndServeUDP(ctx, conn, DNSModeRecursive) + if err == nil || utils.IsOKNetworkError(err) { + return nil + } + return trace.Wrap(err) + }, + Terminate: conn.Close, + }) + + upstreamAddrs = append(upstreamAddrs, conn.LocalAddr().String()) + } + + // Create the nameserver under test. + goTeleportIPv4 := tcpip.AddrFrom4([4]byte{1, 1, 1, 1}) + goTeleportIPv6 := tcpip.AddrFrom16([16]byte{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}) + teleportShIPv6 := tcpip.AddrFrom16([16]byte{2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2}) + resolver := &stubResolver{ + aRecords: map[string]Result{ + "goteleport.com.": Result{ + A: goTeleportIPv4.As4(), + }, + "teleport.sh.": Result{ + NoRecord: true, + }, + "fake.example.com.": Result{ + NXDomain: true, + }, + }, + aaaaRecords: map[string]Result{ + "goteleport.com.": Result{ + AAAA: goTeleportIPv6.As16(), + }, + "teleport.sh.": Result{ + AAAA: teleportShIPv6.As16(), + }, + "fake.example.com.": Result{ + NXDomain: true, + }, + }, + } + upstreams := &stubUpstreamNamservers{nameservers: upstreamAddrs} + server, err := NewServer(resolver, upstreams) + require.NoError(t, err) + + conn, err := net.ListenUDP("udp", udpLocalhost) + require.NoError(t, err) + + testutils.RunTestBackgroundTask(ctx, t, &testutils.TestBackgroundTask{ + Name: "nameserver under test", + Task: func(ctx context.Context) error { + err := server.ListenAndServeUDP(ctx, conn, DNSModeAuthoritative) + if err == nil || utils.IsOKNetworkError(err) { + return nil + } + return trace.Wrap(err) + }, + Terminate: conn.Close, + }) + + netResolver := &net.Resolver{ + PreferGo: true, + Dial: func(ctx context.Context, network, address string) (net.Conn, error) { + // Always dial the resolver under test. + return net.Dial(network, conn.LocalAddr().String()) + }, + } + + for _, tc := range []struct { + desc string + host string + expectAddrs []string + expectErr string + }{ + { + desc: "v4 and v6", + host: "goteleport.com.", + expectAddrs: []string{goTeleportIPv4.String(), goTeleportIPv6.String()}, + }, + { + desc: "only v6", + host: "teleport.sh.", + expectAddrs: []string{teleportShIPv6.String()}, + }, + { + desc: "no domain", + host: "fake.example.com.", + expectErr: "no such host", + }, + { + desc: "forward disabled", + host: "example.com.", + expectErr: "no such host", + }, + } { + t.Run(tc.desc, func(t *testing.T) { + addrs, err := netResolver.LookupHost(ctx, tc.host) + if tc.expectErr != "" { + require.ErrorContains(t, err, tc.expectErr) + return + } + require.NoError(t, err) + require.ElementsMatch(t, tc.expectAddrs, addrs) + }) + } +} + type stubResolver struct { aRecords map[string]Result aaaaRecords map[string]Result diff --git a/lib/vnet/dns/osnameservers_darwin_test.go b/lib/vnet/dns/osnameservers_darwin_test.go index c33df777bef67..74ac61ec6a526 100644 --- a/lib/vnet/dns/osnameservers_darwin_test.go +++ b/lib/vnet/dns/osnameservers_darwin_test.go @@ -50,7 +50,7 @@ func TestOSUpstreamNameservers(t *testing.T) { testutils.RunTestBackgroundTask(ctx, t, &testutils.TestBackgroundTask{ Name: "nameserver", Task: func(ctx context.Context) error { - err := server.ListenAndServeUDP(ctx, conn) + err := server.ListenAndServeUDP(ctx, conn, DNSModeRecursive) if err == nil || utils.IsOKNetworkError(err) { return nil } diff --git a/lib/vnet/dns/osnameservers_linux.go b/lib/vnet/dns/osnameservers_linux.go new file mode 100644 index 0000000000000..0826992a6e206 --- /dev/null +++ b/lib/vnet/dns/osnameservers_linux.go @@ -0,0 +1,33 @@ +// Teleport +// Copyright (C) 2026 Gravitational, Inc. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package dns + +import ( + "context" + + "github.com/gravitational/trace" +) + +// platformLoadUpstreamNameservers returns an error on Linux. VNet relies +// on systemd-resolved to handle unresolved queries. We should not attempt +// to forward those queries to other upstream servers. +func platformLoadUpstreamNameservers(context.Context) ([]string, error) { + return nil, trace.NotImplemented("upstream nameserver discovery is not supported on Linux") +} + +// Satisfy linter in linux build where withDNSPort isn't referenced. +var _ = withDNSPort From d7ac94aba73b2043ca37f28e784b260c633eb4f1 Mon Sep 17 00:00:00 2001 From: Pavel Evdokimov Date: Tue, 10 Feb 2026 18:49:24 +0000 Subject: [PATCH 2/2] fix --- lib/vnet/dns/osnameservers_other.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/vnet/dns/osnameservers_other.go b/lib/vnet/dns/osnameservers_other.go index 3cc9d156f0d6a..b8db3e96e31e9 100644 --- a/lib/vnet/dns/osnameservers_other.go +++ b/lib/vnet/dns/osnameservers_other.go @@ -14,7 +14,7 @@ // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . -//go:build !darwin && !windows +//go:build !darwin && !windows && !linux package dns