Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
119 changes: 106 additions & 13 deletions lib/vnet/dns/dns.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
"io"
"log/slog"
"net"
"runtime"
"sync"
"time"

Expand Down Expand Up @@ -77,6 +78,16 @@ type UpstreamNameserverSource interface {
UpstreamNameservers(context.Context) ([]string, error)
}

// DNSMode controls how the server handles unmatched queries.
type DNSMode int

const (
// DNSModeRecursive forwards unmatched queries to upstream nameservers.
DNSModeRecursive DNSMode = iota
// DNSModeAuthoritative returns negative responses for unmatched queries.
DNSModeAuthoritative
)

// Server is a DNS server.
type Server struct {
resolver Resolver
Expand Down Expand Up @@ -127,13 +138,13 @@ func (s *Server) HandleUDP(ctx context.Context, conn net.Conn) error {
}
buf = buf[:n]

return trace.Wrap(s.handleDNSMessage(ctx, conn.RemoteAddr().String(), buf, conn))
return trace.Wrap(s.handleDNSMessage(ctx, conn.RemoteAddr().String(), buf, conn, s.platformDNSMode()))
}

// ListendAndServeUDP reads all incoming UDP messages from [conn], handles DNS questions, and writes the
// responses back to [conn].
// This is not called by VNet code and basically exists so we can test the resolver outside of VNet.
func (s *Server) ListenAndServeUDP(ctx context.Context, conn *net.UDPConn) error {
func (s *Server) ListenAndServeUDP(ctx context.Context, conn *net.UDPConn, mode DNSMode) error {
buf, returnBuf := s.getMessageBuffer()
defer returnBuf()

Expand All @@ -152,7 +163,7 @@ func (s *Server) ListenAndServeUDP(ctx context.Context, conn *net.UDPConn) error
conn: conn,
remoteAddr: remoteAddr,
}
if err := s.handleDNSMessage(ctx, remoteAddr.String(), buf, responseWriter); err != nil {
if err := s.handleDNSMessage(ctx, remoteAddr.String(), buf, responseWriter, mode); err != nil {
s.slog.DebugContext(ctx, "Error handling DNS message.", "error", err)
}
}
Expand All @@ -168,9 +179,18 @@ func (u *udpWriter) Write(b []byte) (int, error) {
return n, err
}

func (s *Server) platformDNSMode() DNSMode {
switch runtime.GOOS {
case "linux":
return DNSModeAuthoritative
default:
return DNSModeRecursive
}
}

// handleDNSMessage handles the DNS message held in [buf] and writes the answer to [responseWriter].
// This could handle DNS messages arriving over UDP or TCP.
func (s *Server) handleDNSMessage(ctx context.Context, remoteAddr string, buf []byte, responseWriter io.Writer) error {
func (s *Server) handleDNSMessage(ctx context.Context, remoteAddr string, buf []byte, responseWriter io.Writer, mode DNSMode) error {
slog := s.slog.With("remote_addr", remoteAddr)
slog.DebugContext(ctx, "Handling DNS message.")
defer slog.DebugContext(ctx, "Done handling DNS message.")
Expand All @@ -181,8 +201,21 @@ func (s *Server) handleDNSMessage(ctx context.Context, remoteAddr string, buf []
return trace.Wrap(err, "parsing DNS message")
}
if requestHeader.OpCode != 0 {
slog.DebugContext(ctx, "OpCode is not QUERY (0), forwarding.", "opcode", requestHeader.OpCode)
return trace.Wrap(s.forward(ctx, slog, buf, responseWriter), "forwarding non-Query DNS message")
switch mode {
case DNSModeRecursive:
slog.DebugContext(ctx, "OpCode is not QUERY (0), forwarding.", "opcode", requestHeader.OpCode)
return trace.Wrap(s.forward(ctx, slog, buf, responseWriter), "forwarding non-Query DNS message")
case DNSModeAuthoritative:
// RFC 8906 section 3.1.4 recommends NOTIMP for unsupported opcodes.
slog.DebugContext(ctx, "OpCode is not QUERY (0), responding with NOTIMP.", "opcode", requestHeader.OpCode)
question, qerr := parser.Question()
if qerr != nil {
return trace.Wrap(qerr, "parsing DNS question")
}
return trace.Wrap(writeNotImpl(buf, &requestHeader, &question, responseWriter), "authoritative non-Query DNS message")
default:
return trace.BadParameter("unknown DNS mode %v", mode)
}
}
question, err := parser.Question()
if err != nil {
Expand All @@ -192,8 +225,17 @@ func (s *Server) handleDNSMessage(ctx context.Context, remoteAddr string, buf []
slog = slog.With("fqdn", fqdn, "type", question.Type.String())
slog.DebugContext(ctx, "Received DNS question.", "question", question)
if question.Class != dnsmessage.ClassINET {
slog.DebugContext(ctx, "Query class is not INET, forwarding.", "class", question.Class)
return trace.Wrap(s.forward(ctx, slog, buf, responseWriter), "forwarding non-INET DNS query")
switch mode {
case DNSModeRecursive:
slog.DebugContext(ctx, "Query class is not INET, forwarding.", "class", question.Class)
return trace.Wrap(s.forward(ctx, slog, buf, responseWriter), "forwarding non-INET DNS query")
case DNSModeAuthoritative:
// RFC 8906 section 3.1.2 recommends NXDOMAIN or NOERROR for unsupported type queries.
slog.DebugContext(ctx, "Query class is not INET, responding with NXDOMAIN.", "class", question.Class)
return trace.Wrap(writeNXDomain(buf, &requestHeader, &question, responseWriter), "authoritative non-INET DNS query")
default:
return trace.BadParameter("unknown DNS mode %v", mode)
}
}

var result Result
Expand All @@ -209,8 +251,17 @@ func (s *Server) handleDNSMessage(ctx context.Context, remoteAddr string, buf []
return trace.Wrap(err, "resolving AAAA request for %q", fqdn)
}
default:
slog.DebugContext(ctx, "Question type is not A or AAAA, forwarding.", "type", question.Type)
return trace.Wrap(s.forward(ctx, slog, buf, responseWriter), "forwarding %s DNS query", question.Type)
switch mode {
case DNSModeRecursive:
slog.DebugContext(ctx, "Question type is not A or AAAA, forwarding.", "type", question.Type)
return trace.Wrap(s.forward(ctx, slog, buf, responseWriter), "forwarding %s DNS query", question.Type)
case DNSModeAuthoritative:
// RFC 8906 section 3.1.2 recommends NXDOMAIN or NOERROR for unsupported type queries.
slog.DebugContext(ctx, "Question type is not A or AAAA, responding with NXDOMAIN.", "type", question.Type)
return trace.Wrap(writeNXDomain(buf, &requestHeader, &question, responseWriter), "authoritative %s DNS query", question.Type)
default:
return trace.BadParameter("unknown DNS mode %v", mode)
}
}

var response []byte
Expand All @@ -228,8 +279,16 @@ func (s *Server) handleDNSMessage(ctx context.Context, remoteAddr string, buf []
slog.DebugContext(ctx, "Matched DNS AAAA.", "aaaa", tcpip.AddrFrom16(result.AAAA))
response, err = buildAAAAResponse(buf, &requestHeader, &question, result.AAAA)
default:
slog.DebugContext(ctx, "Forwarding unmatched query.")
return trace.Wrap(s.forward(ctx, slog, buf, responseWriter), "forwarding unmatched DNS query")
switch mode {
case DNSModeRecursive:
slog.DebugContext(ctx, "Forwarding unmatched query.")
return trace.Wrap(s.forward(ctx, slog, buf, responseWriter), "forwarding unmatched DNS query")
case DNSModeAuthoritative:
slog.DebugContext(ctx, "Unmatched query, responding with NXDOMAIN.")
return trace.Wrap(writeNXDomain(buf, &requestHeader, &question, responseWriter), "authoritative unmatched DNS query")
default:
return trace.BadParameter("unknown DNS mode %v", mode)
}
}
if err != nil {
return trace.Wrap(err)
Expand Down Expand Up @@ -343,6 +402,16 @@ func buildNXDomainResponse(buf []byte, requestHeader *dnsmessage.Header, questio
return buf, trace.Wrap(err, "serializing DNS response")
}

func buildNotImplResponse(buf []byte, requestHeader *dnsmessage.Header, question *dnsmessage.Question) ([]byte, error) {
responseBuilder, err := prepDNSResponse(buf, requestHeader, question, dnsmessage.RCodeNotImplemented)
if err != nil {
return buf, trace.Wrap(err)
}
// TODO(nklaassen): TTL in SOA record?
buf, err = responseBuilder.Finish()
return buf, trace.Wrap(err, "serializing DNS response")
}

func buildAResponse(buf []byte, requestHeader *dnsmessage.Header, question *dnsmessage.Question, addr [4]byte) ([]byte, error) {
responseBuilder, err := prepDNSResponse(buf, requestHeader, question, dnsmessage.RCodeSuccess)
if err != nil {
Expand Down Expand Up @@ -383,13 +452,37 @@ func buildAAAAResponse(buf []byte, requestHeader *dnsmessage.Header, question *d
return buf, trace.Wrap(err, "serializing DNS response")
}

func writeNXDomain(buf []byte, requestHeader *dnsmessage.Header, question *dnsmessage.Question, responseWriter io.Writer) error {
if question == nil {
return trace.Errorf("missing DNS question for NXDOMAIN response")
}
response, err := buildNXDomainResponse(buf, requestHeader, question)
if err != nil {
return trace.Wrap(err)
}
_, err = responseWriter.Write(response)
return trace.Wrap(err, "writing DNS response")
}

func writeNotImpl(buf []byte, requestHeader *dnsmessage.Header, question *dnsmessage.Question, responseWriter io.Writer) error {
if question == nil {
return trace.Errorf("missing DNS question for NOTIMP response")
}
response, err := buildNotImplResponse(buf, requestHeader, question)
if err != nil {
return trace.Wrap(err)
}
_, err = responseWriter.Write(response)
return trace.Wrap(err, "writing DNS response")
}

func prepDNSResponse(buf []byte, requestHeader *dnsmessage.Header, question *dnsmessage.Question, rcode dnsmessage.RCode) (*dnsmessage.Builder, error) {
buf = buf[:0]
responseBuilder := dnsmessage.NewBuilder(buf, dnsmessage.Header{
ID: requestHeader.ID,
Response: true,
Authoritative: true,
RCode: dnsmessage.RCodeSuccess,
RCode: rcode,
})
responseBuilder.EnableCompression()
if err := responseBuilder.StartQuestions(); err != nil {
Expand Down
134 changes: 132 additions & 2 deletions lib/vnet/dns/dns_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ func TestServer(t *testing.T) {
testutils.RunTestBackgroundTask(ctx, t, &testutils.TestBackgroundTask{
Name: fmt.Sprintf("upstream nameserver %d", i),
Task: func(ctx context.Context) error {
err := upstreamServer.ListenAndServeUDP(ctx, conn)
err := upstreamServer.ListenAndServeUDP(ctx, conn, DNSModeRecursive)
if err == nil || utils.IsOKNetworkError(err) {
return nil
}
Expand Down Expand Up @@ -117,7 +117,7 @@ func TestServer(t *testing.T) {
testutils.RunTestBackgroundTask(ctx, t, &testutils.TestBackgroundTask{
Name: "nameserver under test",
Task: func(ctx context.Context) error {
err := server.ListenAndServeUDP(ctx, conn)
err := server.ListenAndServeUDP(ctx, conn, DNSModeRecursive)
if err == nil || utils.IsOKNetworkError(err) {
return nil
}
Expand Down Expand Up @@ -173,6 +173,136 @@ func TestServer(t *testing.T) {
}
}

func TestServerAuthoritativeMode(t *testing.T) {
t.Parallel()
ctx := context.Background()

defaultIP4 := tcpip.AddrFrom4([4]byte{1, 2, 3, 4})
defaultIP6 := tcpip.AddrFrom16([16]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16})

staticResolver := &staticResolver{Result{
A: defaultIP4.As4(),
AAAA: defaultIP6.As16(),
}}
noUpstreams := &stubUpstreamNamservers{}

// Create two upstream nameservers that are able to resolve A and AAAA records for all names.
var upstreamAddrs []string
for i := range 2 {
upstreamServer, err := NewServer(staticResolver, noUpstreams)
require.NoError(t, err)
conn, err := net.ListenUDP("udp", udpLocalhost)
require.NoError(t, err)

testutils.RunTestBackgroundTask(ctx, t, &testutils.TestBackgroundTask{
Name: fmt.Sprintf("upstream nameserver %d", i),
Task: func(ctx context.Context) error {
err := upstreamServer.ListenAndServeUDP(ctx, conn, DNSModeRecursive)
if err == nil || utils.IsOKNetworkError(err) {
return nil
}
return trace.Wrap(err)
},
Terminate: conn.Close,
})

upstreamAddrs = append(upstreamAddrs, conn.LocalAddr().String())
}

// Create the nameserver under test.
goTeleportIPv4 := tcpip.AddrFrom4([4]byte{1, 1, 1, 1})
goTeleportIPv6 := tcpip.AddrFrom16([16]byte{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1})
teleportShIPv6 := tcpip.AddrFrom16([16]byte{2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2})
resolver := &stubResolver{
aRecords: map[string]Result{
"goteleport.com.": Result{
A: goTeleportIPv4.As4(),
},
"teleport.sh.": Result{
NoRecord: true,
},
"fake.example.com.": Result{
NXDomain: true,
},
},
aaaaRecords: map[string]Result{
"goteleport.com.": Result{
AAAA: goTeleportIPv6.As16(),
},
"teleport.sh.": Result{
AAAA: teleportShIPv6.As16(),
},
"fake.example.com.": Result{
NXDomain: true,
},
},
}
upstreams := &stubUpstreamNamservers{nameservers: upstreamAddrs}
server, err := NewServer(resolver, upstreams)
require.NoError(t, err)

conn, err := net.ListenUDP("udp", udpLocalhost)
require.NoError(t, err)

testutils.RunTestBackgroundTask(ctx, t, &testutils.TestBackgroundTask{
Name: "nameserver under test",
Task: func(ctx context.Context) error {
err := server.ListenAndServeUDP(ctx, conn, DNSModeAuthoritative)
if err == nil || utils.IsOKNetworkError(err) {
return nil
}
return trace.Wrap(err)
},
Terminate: conn.Close,
})

netResolver := &net.Resolver{
PreferGo: true,
Dial: func(ctx context.Context, network, address string) (net.Conn, error) {
// Always dial the resolver under test.
return net.Dial(network, conn.LocalAddr().String())
},
}

for _, tc := range []struct {
desc string
host string
expectAddrs []string
expectErr string
}{
{
desc: "v4 and v6",
host: "goteleport.com.",
expectAddrs: []string{goTeleportIPv4.String(), goTeleportIPv6.String()},
},
{
desc: "only v6",
host: "teleport.sh.",
expectAddrs: []string{teleportShIPv6.String()},
},
{
desc: "no domain",
host: "fake.example.com.",
expectErr: "no such host",
},
{
desc: "forward disabled",
host: "example.com.",
expectErr: "no such host",
},
} {
t.Run(tc.desc, func(t *testing.T) {
addrs, err := netResolver.LookupHost(ctx, tc.host)
if tc.expectErr != "" {
require.ErrorContains(t, err, tc.expectErr)
return
}
require.NoError(t, err)
require.ElementsMatch(t, tc.expectAddrs, addrs)
})
}
}

type stubResolver struct {
aRecords map[string]Result
aaaaRecords map[string]Result
Expand Down
2 changes: 1 addition & 1 deletion lib/vnet/dns/osnameservers_darwin_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ func TestOSUpstreamNameservers(t *testing.T) {
testutils.RunTestBackgroundTask(ctx, t, &testutils.TestBackgroundTask{
Name: "nameserver",
Task: func(ctx context.Context) error {
err := server.ListenAndServeUDP(ctx, conn)
err := server.ListenAndServeUDP(ctx, conn, DNSModeRecursive)
if err == nil || utils.IsOKNetworkError(err) {
return nil
}
Expand Down
Loading
Loading