Skip to content

Commit

Permalink
correct handle corner cases
Browse files Browse the repository at this point in the history
Signed-off-by: Denis Tingaikin <[email protected]>
  • Loading branch information
denis-tingaikin committed Jul 3, 2022
1 parent ab30ed7 commit baeefca
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 17 deletions.
13 changes: 9 additions & 4 deletions pkg/networkservice/chains/nsmgr/vl3_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ import (
"context"
"fmt"
"net"
"net/url"
"testing"
"time"

Expand All @@ -40,6 +39,12 @@ import (
"github.com/networkservicemesh/sdk/pkg/tools/sandbox"
)

func staticIP(addr net.IP) func() net.IP {
return func() net.IP {
return addr
}
}

func Test_NSC_ConnectsTo_vl3NSE(t *testing.T) {
t.Cleanup(func() { goleak.VerifyNone(t) })

Expand Down Expand Up @@ -70,7 +75,7 @@ func Test_NSC_ConnectsTo_vl3NSE(t *testing.T) {
sandbox.GenerateTestToken,
vl3.NewServer(ctx, serverPrefixCh),
vl3dns.NewServer(ctx,
&url.URL{Scheme: "tcp", Host: "127.0.0.1"},
staticIP(net.ParseIP("127.0.0.1")),
vl3dns.WithDomainSchemes("{{ index .Labels \"podName\" }}.{{ .NetworkService }}."),
vl3dns.WithDNSPort(40053)),
)
Expand Down Expand Up @@ -161,7 +166,7 @@ func Test_vl3NSE_ConnectsTo_vl3NSE(t *testing.T) {
sandbox.GenerateTestToken,
vl3.NewServer(ctx, serverPrefixCh),
vl3dns.NewServer(ctx,
&url.URL{Scheme: "tcp", Host: "127.0.0.1"},
staticIP(net.ParseIP("127.0.0.1")),
vl3dns.WithDomainSchemes("{{ index .Labels \"podName\" }}.{{ .NetworkService }}."),
vl3dns.WithDNSListenAndServeFunc(func(ctx context.Context, handler dnsutils.Handler, listenOn string) {
dnsutils.ListenAndServe(ctx, handler, ":50053")
Expand All @@ -182,7 +187,7 @@ func Test_vl3NSE_ConnectsTo_vl3NSE(t *testing.T) {
defer close(clientPrefixCh)

clientPrefixCh <- &ipam.PrefixResponse{Prefix: "127.0.0.1/32"}
nsc := domain.Nodes[0].NewClient(ctx, sandbox.GenerateTestToken, client.WithAdditionalFunctionality(vl3.NewClient(ctx, clientPrefixCh), vl3dns.NewClient(&url.URL{Host: "127.0.0.1"})))
nsc := domain.Nodes[0].NewClient(ctx, sandbox.GenerateTestToken, client.WithAdditionalFunctionality(vl3.NewClient(ctx, clientPrefixCh), vl3dns.NewClient(net.ParseIP("127.0.0.1"))))

req := defaultRequest(nsReg.Name)
req.Connection.Id = uuid.New().String()
Expand Down
10 changes: 5 additions & 5 deletions pkg/networkservice/connectioncontext/dnscontext/vl3dns/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ package vl3dns

import (
"context"
"net/url"
"net"

"github.com/golang/protobuf/ptypes/empty"
"github.com/networkservicemesh/api/pkg/api/networkservice"
Expand All @@ -28,14 +28,14 @@ import (
)

type vl3DNSClient struct {
listenOn *url.URL
dnsServerIP net.IP
}

// NewClient - returns a new null client that does nothing but call next.Client(ctx).{Request/Close} and return the result
// This is very useful in testing
func NewClient(listenOn *url.URL) networkservice.NetworkServiceClient {
func NewClient(dnsServerIP net.IP) networkservice.NetworkServiceClient {
return &vl3DNSClient{
listenOn: listenOn,
dnsServerIP: dnsServerIP,
}
}

Expand All @@ -52,7 +52,7 @@ func (n *vl3DNSClient) Request(ctx context.Context, request *networkservice.Netw

request.GetConnection().GetContext().GetDnsContext().Configs = []*networkservice.DNSConfig{
{
DnsServerIps: []string{n.listenOn.Hostname()},
DnsServerIps: []string{n.dnsServerIP.String()},
},
}
return next.Client(ctx).Request(ctx, request, opts...)
Expand Down
31 changes: 23 additions & 8 deletions pkg/networkservice/connectioncontext/dnscontext/vl3dns/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ import (
dnsnext "github.com/networkservicemesh/sdk/pkg/tools/dnsutils/next"
"github.com/networkservicemesh/sdk/pkg/tools/dnsutils/noloop"
"github.com/networkservicemesh/sdk/pkg/tools/dnsutils/norecursion"
"github.com/networkservicemesh/sdk/pkg/tools/ippool"
)

type vl3DNSServer struct {
Expand All @@ -47,7 +48,7 @@ type vl3DNSServer struct {
dnsPort int
dnsServer dnsutils.Handler
listenAndServeDNS func(ctx context.Context, handler dnsutils.Handler, listenOn string)
listenOn *url.URL
getDNSServerIP func() net.IP
}

type clientDNSNameKey struct{}
Expand All @@ -57,11 +58,11 @@ type clientDNSNameKey struct{}
// By default is using fanout dns handler to connect to other vl3 nses.
// chanCtx is using for signal to stop dns server.
// opts confugre vl3dns networkservice instance with specific behavior.
func NewServer(chanCtx context.Context, listenOn *url.URL, opts ...Option) networkservice.NetworkServiceServer {
func NewServer(chanCtx context.Context, getDNSServerIP func() net.IP, opts ...Option) networkservice.NetworkServiceServer {
var result = &vl3DNSServer{
dnsPort: 53,
listenAndServeDNS: dnsutils.ListenAndServe,
listenOn: listenOn,
getDNSServerIP: getDNSServerIP,
}

for _, opt := range opts {
Expand Down Expand Up @@ -89,15 +90,20 @@ func (n *vl3DNSServer) Request(ctx context.Context, request *networkservice.Netw

var dnsContext = request.GetConnection().GetContext().GetDnsContext()

for _, config := range dnsContext.GetConfigs() {
for _, serverIP := range config.DnsServerIps {
var u = url.URL{Scheme: "tcp", Host: fmt.Sprintf("%v:%v", serverIP, n.dnsPort)}
n.fanoutAddresses.Store(u, struct{}{})
if srcRoutes := request.GetConnection().GetContext().GetIpContext().GetSrcIPRoutes(); len(srcRoutes) > 0 {
var lastPrefix = srcRoutes[len(srcRoutes)-1].Prefix
for _, config := range dnsContext.GetConfigs() {
for _, serverIP := range config.DnsServerIps {
if withinPrefix(serverIP, lastPrefix) {
var u = url.URL{Scheme: "tcp", Host: fmt.Sprintf("%v:%v", serverIP, n.dnsPort)}
n.fanoutAddresses.Store(u, struct{}{})
}
}
}
}

dnsContext.Configs = append(dnsContext.Configs, &networkservice.DNSConfig{
DnsServerIps: []string{n.listenOn.Hostname()},
DnsServerIps: []string{n.getDNSServerIP().String()},
})

var recordNames, err = n.buildSrcDNSRecords(request.GetConnection())
Expand Down Expand Up @@ -178,3 +184,12 @@ func compareStringSlices(a, b []string) bool {
}
return true
}

func withinPrefix(ipAddr, prefix string) bool {
_, ipNet, err := net.ParseCIDR(prefix)
if err != nil {
return false
}
var pool = ippool.NewWithNet(ipNet)
return pool.ContainsString(ipAddr)
}

0 comments on commit baeefca

Please sign in to comment.