diff --git a/lib/vnet/dns/dns.go b/lib/vnet/dns/dns.go
index fac646f231256..f4083d0bf0f54 100644
--- a/lib/vnet/dns/dns.go
+++ b/lib/vnet/dns/dns.go
@@ -22,6 +22,7 @@ import (
"io"
"log/slog"
"net"
+ "runtime"
"sync"
"time"
@@ -77,6 +78,16 @@ type UpstreamNameserverSource interface {
UpstreamNameservers(context.Context) ([]string, error)
}
+// DNSMode controls how the server handles unmatched queries.
+type DNSMode int
+
+const (
+ // DNSModeRecursive forwards unmatched queries to upstream nameservers.
+ DNSModeRecursive DNSMode = iota
+ // DNSModeAuthoritative returns negative responses for unmatched queries.
+ DNSModeAuthoritative
+)
+
// Server is a DNS server.
type Server struct {
resolver Resolver
@@ -127,13 +138,13 @@ func (s *Server) HandleUDP(ctx context.Context, conn net.Conn) error {
}
buf = buf[:n]
- return trace.Wrap(s.handleDNSMessage(ctx, conn.RemoteAddr().String(), buf, conn))
+ return trace.Wrap(s.handleDNSMessage(ctx, conn.RemoteAddr().String(), buf, conn, s.platformDNSMode()))
}
// ListendAndServeUDP reads all incoming UDP messages from [conn], handles DNS questions, and writes the
// responses back to [conn].
// This is not called by VNet code and basically exists so we can test the resolver outside of VNet.
-func (s *Server) ListenAndServeUDP(ctx context.Context, conn *net.UDPConn) error {
+func (s *Server) ListenAndServeUDP(ctx context.Context, conn *net.UDPConn, mode DNSMode) error {
buf, returnBuf := s.getMessageBuffer()
defer returnBuf()
@@ -152,7 +163,7 @@ func (s *Server) ListenAndServeUDP(ctx context.Context, conn *net.UDPConn) error
conn: conn,
remoteAddr: remoteAddr,
}
- if err := s.handleDNSMessage(ctx, remoteAddr.String(), buf, responseWriter); err != nil {
+ if err := s.handleDNSMessage(ctx, remoteAddr.String(), buf, responseWriter, mode); err != nil {
s.slog.DebugContext(ctx, "Error handling DNS message.", "error", err)
}
}
@@ -168,9 +179,18 @@ func (u *udpWriter) Write(b []byte) (int, error) {
return n, err
}
+func (s *Server) platformDNSMode() DNSMode {
+ switch runtime.GOOS {
+ case "linux":
+ return DNSModeAuthoritative
+ default:
+ return DNSModeRecursive
+ }
+}
+
// handleDNSMessage handles the DNS message held in [buf] and writes the answer to [responseWriter].
// This could handle DNS messages arriving over UDP or TCP.
-func (s *Server) handleDNSMessage(ctx context.Context, remoteAddr string, buf []byte, responseWriter io.Writer) error {
+func (s *Server) handleDNSMessage(ctx context.Context, remoteAddr string, buf []byte, responseWriter io.Writer, mode DNSMode) error {
slog := s.slog.With("remote_addr", remoteAddr)
slog.DebugContext(ctx, "Handling DNS message.")
defer slog.DebugContext(ctx, "Done handling DNS message.")
@@ -181,8 +201,21 @@ func (s *Server) handleDNSMessage(ctx context.Context, remoteAddr string, buf []
return trace.Wrap(err, "parsing DNS message")
}
if requestHeader.OpCode != 0 {
- slog.DebugContext(ctx, "OpCode is not QUERY (0), forwarding.", "opcode", requestHeader.OpCode)
- return trace.Wrap(s.forward(ctx, slog, buf, responseWriter), "forwarding non-Query DNS message")
+ switch mode {
+ case DNSModeRecursive:
+ slog.DebugContext(ctx, "OpCode is not QUERY (0), forwarding.", "opcode", requestHeader.OpCode)
+ return trace.Wrap(s.forward(ctx, slog, buf, responseWriter), "forwarding non-Query DNS message")
+ case DNSModeAuthoritative:
+ // RFC 8906 section 3.1.4 recommends NOTIMP for unsupported opcodes.
+ slog.DebugContext(ctx, "OpCode is not QUERY (0), responding with NOTIMP.", "opcode", requestHeader.OpCode)
+ question, qerr := parser.Question()
+ if qerr != nil {
+ return trace.Wrap(qerr, "parsing DNS question")
+ }
+ return trace.Wrap(writeNotImpl(buf, &requestHeader, &question, responseWriter), "authoritative non-Query DNS message")
+ default:
+ return trace.BadParameter("unknown DNS mode %v", mode)
+ }
}
question, err := parser.Question()
if err != nil {
@@ -192,8 +225,17 @@ func (s *Server) handleDNSMessage(ctx context.Context, remoteAddr string, buf []
slog = slog.With("fqdn", fqdn, "type", question.Type.String())
slog.DebugContext(ctx, "Received DNS question.", "question", question)
if question.Class != dnsmessage.ClassINET {
- slog.DebugContext(ctx, "Query class is not INET, forwarding.", "class", question.Class)
- return trace.Wrap(s.forward(ctx, slog, buf, responseWriter), "forwarding non-INET DNS query")
+ switch mode {
+ case DNSModeRecursive:
+ slog.DebugContext(ctx, "Query class is not INET, forwarding.", "class", question.Class)
+ return trace.Wrap(s.forward(ctx, slog, buf, responseWriter), "forwarding non-INET DNS query")
+ case DNSModeAuthoritative:
+ // RFC 8906 section 3.1.2 recommends NXDOMAIN or NOERROR for unsupported type queries.
+ slog.DebugContext(ctx, "Query class is not INET, responding with NXDOMAIN.", "class", question.Class)
+ return trace.Wrap(writeNXDomain(buf, &requestHeader, &question, responseWriter), "authoritative non-INET DNS query")
+ default:
+ return trace.BadParameter("unknown DNS mode %v", mode)
+ }
}
var result Result
@@ -209,8 +251,17 @@ func (s *Server) handleDNSMessage(ctx context.Context, remoteAddr string, buf []
return trace.Wrap(err, "resolving AAAA request for %q", fqdn)
}
default:
- slog.DebugContext(ctx, "Question type is not A or AAAA, forwarding.", "type", question.Type)
- return trace.Wrap(s.forward(ctx, slog, buf, responseWriter), "forwarding %s DNS query", question.Type)
+ switch mode {
+ case DNSModeRecursive:
+ slog.DebugContext(ctx, "Question type is not A or AAAA, forwarding.", "type", question.Type)
+ return trace.Wrap(s.forward(ctx, slog, buf, responseWriter), "forwarding %s DNS query", question.Type)
+ case DNSModeAuthoritative:
+ // RFC 8906 section 3.1.2 recommends NXDOMAIN or NOERROR for unsupported type queries.
+ slog.DebugContext(ctx, "Question type is not A or AAAA, responding with NXDOMAIN.", "type", question.Type)
+ return trace.Wrap(writeNXDomain(buf, &requestHeader, &question, responseWriter), "authoritative %s DNS query", question.Type)
+ default:
+ return trace.BadParameter("unknown DNS mode %v", mode)
+ }
}
var response []byte
@@ -228,8 +279,16 @@ func (s *Server) handleDNSMessage(ctx context.Context, remoteAddr string, buf []
slog.DebugContext(ctx, "Matched DNS AAAA.", "aaaa", tcpip.AddrFrom16(result.AAAA))
response, err = buildAAAAResponse(buf, &requestHeader, &question, result.AAAA)
default:
- slog.DebugContext(ctx, "Forwarding unmatched query.")
- return trace.Wrap(s.forward(ctx, slog, buf, responseWriter), "forwarding unmatched DNS query")
+ switch mode {
+ case DNSModeRecursive:
+ slog.DebugContext(ctx, "Forwarding unmatched query.")
+ return trace.Wrap(s.forward(ctx, slog, buf, responseWriter), "forwarding unmatched DNS query")
+ case DNSModeAuthoritative:
+ slog.DebugContext(ctx, "Unmatched query, responding with NXDOMAIN.")
+ return trace.Wrap(writeNXDomain(buf, &requestHeader, &question, responseWriter), "authoritative unmatched DNS query")
+ default:
+ return trace.BadParameter("unknown DNS mode %v", mode)
+ }
}
if err != nil {
return trace.Wrap(err)
@@ -343,6 +402,16 @@ func buildNXDomainResponse(buf []byte, requestHeader *dnsmessage.Header, questio
return buf, trace.Wrap(err, "serializing DNS response")
}
+func buildNotImplResponse(buf []byte, requestHeader *dnsmessage.Header, question *dnsmessage.Question) ([]byte, error) {
+ responseBuilder, err := prepDNSResponse(buf, requestHeader, question, dnsmessage.RCodeNotImplemented)
+ if err != nil {
+ return buf, trace.Wrap(err)
+ }
+ // TODO(nklaassen): TTL in SOA record?
+ buf, err = responseBuilder.Finish()
+ return buf, trace.Wrap(err, "serializing DNS response")
+}
+
func buildAResponse(buf []byte, requestHeader *dnsmessage.Header, question *dnsmessage.Question, addr [4]byte) ([]byte, error) {
responseBuilder, err := prepDNSResponse(buf, requestHeader, question, dnsmessage.RCodeSuccess)
if err != nil {
@@ -383,13 +452,37 @@ func buildAAAAResponse(buf []byte, requestHeader *dnsmessage.Header, question *d
return buf, trace.Wrap(err, "serializing DNS response")
}
+func writeNXDomain(buf []byte, requestHeader *dnsmessage.Header, question *dnsmessage.Question, responseWriter io.Writer) error {
+ if question == nil {
+ return trace.Errorf("missing DNS question for NXDOMAIN response")
+ }
+ response, err := buildNXDomainResponse(buf, requestHeader, question)
+ if err != nil {
+ return trace.Wrap(err)
+ }
+ _, err = responseWriter.Write(response)
+ return trace.Wrap(err, "writing DNS response")
+}
+
+func writeNotImpl(buf []byte, requestHeader *dnsmessage.Header, question *dnsmessage.Question, responseWriter io.Writer) error {
+ if question == nil {
+ return trace.Errorf("missing DNS question for NOTIMP response")
+ }
+ response, err := buildNotImplResponse(buf, requestHeader, question)
+ if err != nil {
+ return trace.Wrap(err)
+ }
+ _, err = responseWriter.Write(response)
+ return trace.Wrap(err, "writing DNS response")
+}
+
func prepDNSResponse(buf []byte, requestHeader *dnsmessage.Header, question *dnsmessage.Question, rcode dnsmessage.RCode) (*dnsmessage.Builder, error) {
buf = buf[:0]
responseBuilder := dnsmessage.NewBuilder(buf, dnsmessage.Header{
ID: requestHeader.ID,
Response: true,
Authoritative: true,
- RCode: dnsmessage.RCodeSuccess,
+ RCode: rcode,
})
responseBuilder.EnableCompression()
if err := responseBuilder.StartQuestions(); err != nil {
diff --git a/lib/vnet/dns/dns_test.go b/lib/vnet/dns/dns_test.go
index b89418b5df646..ca5fd4a89ded5 100644
--- a/lib/vnet/dns/dns_test.go
+++ b/lib/vnet/dns/dns_test.go
@@ -67,7 +67,7 @@ func TestServer(t *testing.T) {
testutils.RunTestBackgroundTask(ctx, t, &testutils.TestBackgroundTask{
Name: fmt.Sprintf("upstream nameserver %d", i),
Task: func(ctx context.Context) error {
- err := upstreamServer.ListenAndServeUDP(ctx, conn)
+ err := upstreamServer.ListenAndServeUDP(ctx, conn, DNSModeRecursive)
if err == nil || utils.IsOKNetworkError(err) {
return nil
}
@@ -117,7 +117,7 @@ func TestServer(t *testing.T) {
testutils.RunTestBackgroundTask(ctx, t, &testutils.TestBackgroundTask{
Name: "nameserver under test",
Task: func(ctx context.Context) error {
- err := server.ListenAndServeUDP(ctx, conn)
+ err := server.ListenAndServeUDP(ctx, conn, DNSModeRecursive)
if err == nil || utils.IsOKNetworkError(err) {
return nil
}
@@ -173,6 +173,136 @@ func TestServer(t *testing.T) {
}
}
+func TestServerAuthoritativeMode(t *testing.T) {
+ t.Parallel()
+ ctx := context.Background()
+
+ defaultIP4 := tcpip.AddrFrom4([4]byte{1, 2, 3, 4})
+ defaultIP6 := tcpip.AddrFrom16([16]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16})
+
+ staticResolver := &staticResolver{Result{
+ A: defaultIP4.As4(),
+ AAAA: defaultIP6.As16(),
+ }}
+ noUpstreams := &stubUpstreamNamservers{}
+
+ // Create two upstream nameservers that are able to resolve A and AAAA records for all names.
+ var upstreamAddrs []string
+ for i := range 2 {
+ upstreamServer, err := NewServer(staticResolver, noUpstreams)
+ require.NoError(t, err)
+ conn, err := net.ListenUDP("udp", udpLocalhost)
+ require.NoError(t, err)
+
+ testutils.RunTestBackgroundTask(ctx, t, &testutils.TestBackgroundTask{
+ Name: fmt.Sprintf("upstream nameserver %d", i),
+ Task: func(ctx context.Context) error {
+ err := upstreamServer.ListenAndServeUDP(ctx, conn, DNSModeRecursive)
+ if err == nil || utils.IsOKNetworkError(err) {
+ return nil
+ }
+ return trace.Wrap(err)
+ },
+ Terminate: conn.Close,
+ })
+
+ upstreamAddrs = append(upstreamAddrs, conn.LocalAddr().String())
+ }
+
+ // Create the nameserver under test.
+ goTeleportIPv4 := tcpip.AddrFrom4([4]byte{1, 1, 1, 1})
+ goTeleportIPv6 := tcpip.AddrFrom16([16]byte{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1})
+ teleportShIPv6 := tcpip.AddrFrom16([16]byte{2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2})
+ resolver := &stubResolver{
+ aRecords: map[string]Result{
+ "goteleport.com.": Result{
+ A: goTeleportIPv4.As4(),
+ },
+ "teleport.sh.": Result{
+ NoRecord: true,
+ },
+ "fake.example.com.": Result{
+ NXDomain: true,
+ },
+ },
+ aaaaRecords: map[string]Result{
+ "goteleport.com.": Result{
+ AAAA: goTeleportIPv6.As16(),
+ },
+ "teleport.sh.": Result{
+ AAAA: teleportShIPv6.As16(),
+ },
+ "fake.example.com.": Result{
+ NXDomain: true,
+ },
+ },
+ }
+ upstreams := &stubUpstreamNamservers{nameservers: upstreamAddrs}
+ server, err := NewServer(resolver, upstreams)
+ require.NoError(t, err)
+
+ conn, err := net.ListenUDP("udp", udpLocalhost)
+ require.NoError(t, err)
+
+ testutils.RunTestBackgroundTask(ctx, t, &testutils.TestBackgroundTask{
+ Name: "nameserver under test",
+ Task: func(ctx context.Context) error {
+ err := server.ListenAndServeUDP(ctx, conn, DNSModeAuthoritative)
+ if err == nil || utils.IsOKNetworkError(err) {
+ return nil
+ }
+ return trace.Wrap(err)
+ },
+ Terminate: conn.Close,
+ })
+
+ netResolver := &net.Resolver{
+ PreferGo: true,
+ Dial: func(ctx context.Context, network, address string) (net.Conn, error) {
+ // Always dial the resolver under test.
+ return net.Dial(network, conn.LocalAddr().String())
+ },
+ }
+
+ for _, tc := range []struct {
+ desc string
+ host string
+ expectAddrs []string
+ expectErr string
+ }{
+ {
+ desc: "v4 and v6",
+ host: "goteleport.com.",
+ expectAddrs: []string{goTeleportIPv4.String(), goTeleportIPv6.String()},
+ },
+ {
+ desc: "only v6",
+ host: "teleport.sh.",
+ expectAddrs: []string{teleportShIPv6.String()},
+ },
+ {
+ desc: "no domain",
+ host: "fake.example.com.",
+ expectErr: "no such host",
+ },
+ {
+ desc: "forward disabled",
+ host: "example.com.",
+ expectErr: "no such host",
+ },
+ } {
+ t.Run(tc.desc, func(t *testing.T) {
+ addrs, err := netResolver.LookupHost(ctx, tc.host)
+ if tc.expectErr != "" {
+ require.ErrorContains(t, err, tc.expectErr)
+ return
+ }
+ require.NoError(t, err)
+ require.ElementsMatch(t, tc.expectAddrs, addrs)
+ })
+ }
+}
+
type stubResolver struct {
aRecords map[string]Result
aaaaRecords map[string]Result
diff --git a/lib/vnet/dns/osnameservers_darwin_test.go b/lib/vnet/dns/osnameservers_darwin_test.go
index c33df777bef67..74ac61ec6a526 100644
--- a/lib/vnet/dns/osnameservers_darwin_test.go
+++ b/lib/vnet/dns/osnameservers_darwin_test.go
@@ -50,7 +50,7 @@ func TestOSUpstreamNameservers(t *testing.T) {
testutils.RunTestBackgroundTask(ctx, t, &testutils.TestBackgroundTask{
Name: "nameserver",
Task: func(ctx context.Context) error {
- err := server.ListenAndServeUDP(ctx, conn)
+ err := server.ListenAndServeUDP(ctx, conn, DNSModeRecursive)
if err == nil || utils.IsOKNetworkError(err) {
return nil
}
diff --git a/lib/vnet/dns/osnameservers_linux.go b/lib/vnet/dns/osnameservers_linux.go
new file mode 100644
index 0000000000000..0826992a6e206
--- /dev/null
+++ b/lib/vnet/dns/osnameservers_linux.go
@@ -0,0 +1,33 @@
+// Teleport
+// Copyright (C) 2026 Gravitational, Inc.
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see .
+
+package dns
+
+import (
+ "context"
+
+ "github.com/gravitational/trace"
+)
+
+// platformLoadUpstreamNameservers returns an error on Linux. VNet relies
+// on systemd-resolved to handle unresolved queries. We should not attempt
+// to forward those queries to other upstream servers.
+func platformLoadUpstreamNameservers(context.Context) ([]string, error) {
+ return nil, trace.NotImplemented("upstream nameserver discovery is not supported on Linux")
+}
+
+// Satisfy linter in linux build where withDNSPort isn't referenced.
+var _ = withDNSPort
diff --git a/lib/vnet/dns/osnameservers_other.go b/lib/vnet/dns/osnameservers_other.go
index 3cc9d156f0d6a..b8db3e96e31e9 100644
--- a/lib/vnet/dns/osnameservers_other.go
+++ b/lib/vnet/dns/osnameservers_other.go
@@ -14,7 +14,7 @@
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see .
-//go:build !darwin && !windows
+//go:build !darwin && !windows && !linux
package dns