diff --git a/go-controller/pkg/libovsdb/ops/db_object_types.go b/go-controller/pkg/libovsdb/ops/db_object_types.go index 45c2777637..87102451c7 100644 --- a/go-controller/pkg/libovsdb/ops/db_object_types.go +++ b/go-controller/pkg/libovsdb/ops/db_object_types.go @@ -242,9 +242,10 @@ var ACLNetworkPolicyPortIndex = newObjectIDsType(acl, NetworkPolicyPortIndexOwne // ingress/egress + NetworkPolicy[In/E]gressRule idx - defines given gressPolicy. // ACLs are created for gp.portPolicies which are grouped by protocol: // - for empty policy (no selectors and no ip blocks) - empty ACL (see allIPsMatch) +// with idx=emptyIdx (-1) // OR -// - all selector-based peers ACL -// - for every IPBlock +1 ACL +// - all selector-based peers ACL with idx=emptyIdx (-1) +// - all ipBlocks combined into a single ACL with idx=ipBlockCombinedIdx (-2) // Therefore unique id for a given gressPolicy is protocol name + IPBlock idx // (protocol will be "None" if no port policy is defined, and empty policy and all // selector-based peers ACLs will have idx=-1) diff --git a/go-controller/pkg/ovn/gress_policy.go b/go-controller/pkg/ovn/gress_policy.go index ad20fadfb3..b1f844123c 100644 --- a/go-controller/pkg/ovn/gress_policy.go +++ b/go-controller/pkg/ovn/gress_policy.go @@ -22,6 +22,11 @@ import ( const ( // emptyIdx is used to create ACL for gressPolicy that doesn't have ipBlocks emptyIdx = -1 + // ipBlockCombinedIdx is used when creating an ACL for a gressPolicy + // that contains ipBlocks. Previously, one ACL was created per ipBlock. + // This is changed to create a single combined ACL for all ipBlocks, + // and this special index value identifies those new ACLs. + ipBlockCombinedIdx = -2 ) type gressPolicy struct { @@ -167,14 +172,14 @@ func (gp *gressPolicy) allIPsMatch() string { } } -func (gp *gressPolicy) getMatchFromIPBlock(lportMatch, l4Match string) []string { +func (gp *gressPolicy) getMatchFromIPBlock(lportMatch, l4Match string) string { var direction string if gp.policyType == knet.PolicyTypeIngress { direction = "src" } else { direction = "dst" } - var matchStrings []string + var ipBlockMatches []string var matchStr, ipVersion string for _, ipBlock := range gp.ipBlocks { if utilnet.IsIPv6CIDRString(ipBlock.CIDR) { @@ -185,17 +190,22 @@ func (gp *gressPolicy) getMatchFromIPBlock(lportMatch, l4Match string) []string if len(ipBlock.Except) == 0 { matchStr = fmt.Sprintf("%s.%s == %s", ipVersion, direction, ipBlock.CIDR) } else { - matchStr = fmt.Sprintf("%s.%s == %s && %s.%s != {%s}", ipVersion, direction, ipBlock.CIDR, + matchStr = fmt.Sprintf("(%s.%s == %s && %s.%s != {%s})", ipVersion, direction, ipBlock.CIDR, ipVersion, direction, strings.Join(ipBlock.Except, ", ")) } - if l4Match == libovsdbutil.UnspecifiedL4Match { - matchStr = fmt.Sprintf("%s && %s", matchStr, lportMatch) - } else { - matchStr = fmt.Sprintf("%s && %s && %s", matchStr, l4Match, lportMatch) - } - matchStrings = append(matchStrings, matchStr) + ipBlockMatches = append(ipBlockMatches, matchStr) } - return matchStrings + var l3Match string + if len(ipBlockMatches) == 1 { + l3Match = ipBlockMatches[0] + } else { + l3Match = fmt.Sprintf("(%s)", strings.Join(ipBlockMatches, " || ")) + } + + if l4Match == libovsdbutil.UnspecifiedL4Match { + return fmt.Sprintf("%s && %s", l3Match, lportMatch) + } + return fmt.Sprintf("%s && %s && %s", l3Match, l4Match, lportMatch) } // addNamespaceAddressSet adds a namespace address set to the gress policy. @@ -285,13 +295,11 @@ func (gp *gressPolicy) buildLocalPodACLs(portGroupName string, aclLogging *libov for protocol, l4Match := range libovsdbutil.GetL4MatchesFromNetworkPolicyPorts(gp.portPolicies) { if len(gp.ipBlocks) > 0 { // Add ACL allow rule for IPBlock CIDR - ipBlockMatches := gp.getMatchFromIPBlock(lportMatch, l4Match) - for ipBlockIdx, ipBlockMatch := range ipBlockMatches { - aclIDs := gp.getNetpolACLDbIDs(ipBlockIdx, protocol) - acl := libovsdbutil.BuildACLWithDefaultTier(aclIDs, types.DefaultAllowPriority, ipBlockMatch, action, - aclLogging, gp.aclPipeline) - createdACLs = append(createdACLs, acl) - } + ipBlockMatch := gp.getMatchFromIPBlock(lportMatch, l4Match) + aclIDs := gp.getNetpolACLDbIDs(ipBlockCombinedIdx, protocol) + acl := libovsdbutil.BuildACLWithDefaultTier(aclIDs, types.DefaultAllowPriority, ipBlockMatch, action, + aclLogging, gp.aclPipeline) + createdACLs = append(createdACLs, acl) } // if there are pod/namespace selector, then allow packets from/to that address_set or // if the NetworkPolicyPeer is empty, then allow from all sources or to all destinations. @@ -334,10 +342,10 @@ func (gp *gressPolicy) getNetpolACLDbIDs(ipBlockIdx int, protocol string) *libov // gress rule index libovsdbops.GressIdxKey: strconv.Itoa(gp.idx), // acls are created for every gp.portPolicies which are grouped by protocol: - // - for empty policy (no selectors and no ip blocks) - empty ACL + // - for empty policy (no selectors and no ip blocks) - empty ACL with idx=emptyIdx (-1) // OR - // - all selector-based peers ACL - // - for every IPBlock +1 ACL + // - all selector-based peers ACL with idx=emptyIdx (-1) + // - all ipBlocks combined into a single ACL with idx=ipBlockCombinedIdx (-2) // Therefore unique id for a given gressPolicy is protocol name + IPBlock idx // (protocol will be "None" if no port policy is defined, and empty policy and all // selector-based peers ACLs will have idx=-1) diff --git a/go-controller/pkg/ovn/gress_policy_test.go b/go-controller/pkg/ovn/gress_policy_test.go index 14b2a65a7c..f45be5385a 100644 --- a/go-controller/pkg/ovn/gress_policy_test.go +++ b/go-controller/pkg/ovn/gress_policy_test.go @@ -16,7 +16,7 @@ func TestGetMatchFromIPBlock(t *testing.T) { ipBlocks []*knet.IPBlock lportMatch string l4Match string - expected []string + expected string }{ { desc: "IPv4 only no except", @@ -27,7 +27,7 @@ func TestGetMatchFromIPBlock(t *testing.T) { }, lportMatch: "fake", l4Match: "input", - expected: []string{"ip4.src == 0.0.0.0/0 && input && fake"}, + expected: "ip4.src == 0.0.0.0/0 && input && fake", }, { desc: "multiple IPv4 only no except", @@ -41,8 +41,7 @@ func TestGetMatchFromIPBlock(t *testing.T) { }, lportMatch: "fake", l4Match: "input", - expected: []string{"ip4.src == 0.0.0.0/0 && input && fake", - "ip4.src == 10.1.0.0/16 && input && fake"}, + expected: "(ip4.src == 0.0.0.0/0 || ip4.src == 10.1.0.0/16) && input && fake", }, { desc: "IPv6 only no except", @@ -53,7 +52,7 @@ func TestGetMatchFromIPBlock(t *testing.T) { }, lportMatch: "fake", l4Match: "input", - expected: []string{"ip6.src == fd00:10:244:3::49/32 && input && fake"}, + expected: "ip6.src == fd00:10:244:3::49/32 && input && fake", }, { desc: "mixed IPv4 and IPv6 no except", @@ -67,8 +66,7 @@ func TestGetMatchFromIPBlock(t *testing.T) { }, lportMatch: "fake", l4Match: "input", - expected: []string{"ip6.src == ::/0 && input && fake", - "ip4.src == 0.0.0.0/0 && input && fake"}, + expected: "(ip6.src == ::/0 || ip4.src == 0.0.0.0/0) && input && fake", }, { desc: "IPv4 only with except", @@ -80,7 +78,7 @@ func TestGetMatchFromIPBlock(t *testing.T) { }, lportMatch: "fake", l4Match: "input", - expected: []string{"ip4.src == 0.0.0.0/0 && ip4.src != {10.1.0.0/16} && input && fake"}, + expected: "(ip4.src == 0.0.0.0/0 && ip4.src != {10.1.0.0/16}) && input && fake", }, { desc: "multiple IPv4 with except", @@ -95,8 +93,7 @@ func TestGetMatchFromIPBlock(t *testing.T) { }, lportMatch: "fake", l4Match: "input", - expected: []string{"ip4.src == 0.0.0.0/0 && ip4.src != {10.1.0.0/16} && input && fake", - "ip4.src == 10.1.0.0/16 && input && fake"}, + expected: "((ip4.src == 0.0.0.0/0 && ip4.src != {10.1.0.0/16}) || ip4.src == 10.1.0.0/16) && input && fake", }, { desc: "IPv4 with IPv4 except", @@ -108,7 +105,7 @@ func TestGetMatchFromIPBlock(t *testing.T) { }, lportMatch: "fake", l4Match: "input", - expected: []string{"ip4.src == 0.0.0.0/0 && ip4.src != {10.1.0.0/16} && input && fake"}, + expected: "(ip4.src == 0.0.0.0/0 && ip4.src != {10.1.0.0/16}) && input && fake", }, } diff --git a/go-controller/pkg/ovn/policy_test.go b/go-controller/pkg/ovn/policy_test.go index ba0c3f9592..08d0ff434b 100644 --- a/go-controller/pkg/ovn/policy_test.go +++ b/go-controller/pkg/ovn/policy_test.go @@ -6,6 +6,7 @@ import ( "net" "runtime" "sort" + "strings" "time" "github.com/onsi/ginkgo/v2" @@ -294,13 +295,27 @@ func getGressACLs(gressIdx int, peers []knet.NetworkPolicyPeer, policyType knet. acl.UUID = dbIDs.String() + "-UUID" acls = append(acls, acl) } - for i, ipBlock := range ipBlocks { - match := fmt.Sprintf("ip4.%s == %s && %s == @%s", ipDir, ipBlock, portDir, pgName) + if len(ipBlocks) > 0 { + var ipBlockMatches []string + for _, ipBlock := range ipBlocks { + ipVersion := "ip4" + if utilnet.IsIPv6CIDRString(ipBlock) { + ipVersion = "ip6" + } + ipBlockMatches = append(ipBlockMatches, fmt.Sprintf("%s.%s == %s", ipVersion, ipDir, ipBlock)) + } + var match string + if len(ipBlockMatches) == 1 { + match = ipBlockMatches[0] + } else { + match = fmt.Sprintf("(%s)", strings.Join(ipBlockMatches, " || ")) + } + match = fmt.Sprintf("%s && %s == @%s", match, portDir, pgName) action := nbdb.ACLActionAllowRelated if params.statelessNetPol { action = nbdb.ACLActionAllowStateless } - dbIDs := gp.getNetpolACLDbIDs(i, libovsdbutil.UnspecifiedL4Protocol) + dbIDs := gp.getNetpolACLDbIDs(ipBlockCombinedIdx, libovsdbutil.UnspecifiedL4Protocol) acl := libovsdbops.BuildACL( libovsdbutil.GetACLName(dbIDs), direction, @@ -363,6 +378,17 @@ func getPolicyData(params *netpolDataParams) []libovsdbtest.TestData { acls = append(acls, getGressACLs(i, egress.To, knet.PolicyTypeEgress, params)...) } + pg := getPolicyPortGroup(params, acls) + + data := []libovsdbtest.TestData{} + for _, acl := range acls { + data = append(data, acl) + } + data = append(data, pg) + return data +} + +func getPolicyPortGroup(params *netpolDataParams, acls []*nbdb.ACL) *nbdb.PortGroup { lsps := []*nbdb.LogicalSwitchPort{} for _, uuid := range params.localPortUUIDs { lsps = append(lsps, &nbdb.LogicalSwitchPort{UUID: uuid}) @@ -377,12 +403,7 @@ func getPolicyData(params *netpolDataParams) []libovsdbtest.TestData { ) pg.UUID = pg.Name + "-UUID" - data := []libovsdbtest.TestData{} - for _, acl := range acls { - data = append(data, acl) - } - data = append(data, pg) - return data + return pg } func newNetpolDataParams(networkPolicy *knet.NetworkPolicy) *netpolDataParams { @@ -958,6 +979,149 @@ var _ = ginkgo.Describe("OVN NetworkPolicy Operations", func() { } gomega.Expect(app.Run([]string{app.Name})).To(gomega.Succeed()) }) + + ginkgo.It("reconciles existing networkPolicies with has legacy ipBlock ACLs", func() { + app.Action = func(*cli.Context) error { + namespace1 := *newNamespace(namespaceName1) + namespace1AddressSetv4, _ := buildNamespaceAddressSets(namespace1.Name, nil) + peer := knet.NetworkPolicyPeer{ + IPBlock: &knet.IPBlock{ + CIDR: "1.1.1.1", + }, + } + // equivalent rules in one peer + networkPolicy1 := newNetworkPolicy(netPolicyName1, namespace1.Name, metav1.LabelSelector{}, + []knet.NetworkPolicyIngressRule{{ + From: []knet.NetworkPolicyPeer{peer, peer}, + }}, nil) + // equivalent rules in different peers + networkPolicy2 := newNetworkPolicy(netPolicyName2, namespace1.Name, metav1.LabelSelector{}, + []knet.NetworkPolicyIngressRule{ + { + From: []knet.NetworkPolicyPeer{peer}, + }, + { + From: []knet.NetworkPolicyPeer{peer}, + }, + }, nil) + initialData := initialDB.NBData + initialData = append(initialData, namespace1AddressSetv4) + defaultDenyExpectedData := getDefaultDenyDataMultiplePolicies([]*knet.NetworkPolicy{networkPolicy1, networkPolicy2}) + initialData = append(initialData, defaultDenyExpectedData...) + + // NetworkPolicy 1 contains a single gress policy that previously + // created one legacy ACL per ipBlock. Simulate two legacy ACLs + // corresponding to ipBlock indexes 0 and 1 of the gress policy. + // ACL1 => libovsdbops.GressIdxKey: 0, libovsdbops.IpBlockIndexKey: 0 + // ACL2 => libovsdbops.GressIdxKey: 0, libovsdbops.IpBlockIndexKey: 1 + netInfo := &util.DefaultNetInfo{} + fakeController := getFakeBaseController(netInfo) + controllerName := getNetworkControllerName(netInfo.GetNetworkName()) + pgName1 := fakeController.getNetworkPolicyPGName(namespace1.Name, networkPolicy1.Name) + gp1 := gressPolicy{ + policyNamespace: networkPolicy1.Namespace, + policyName: networkPolicy1.Name, + policyType: knet.PolicyTypeIngress, + idx: 0, + controllerName: controllerName, + } + var legacyACLPolicy1 []*nbdb.ACL + for idx := 0; idx < 2; idx++ { + legacyACLIDs := gp1.getNetpolACLDbIDs(idx, libovsdbutil.UnspecifiedL4Protocol) + legacyACL := libovsdbops.BuildACL( + libovsdbutil.GetACLName(legacyACLIDs), + nbdb.ACLDirectionToLport, + types.DefaultAllowPriority, + fmt.Sprintf("ip4.src == 1.1.1.1 && outport == @%s", pgName1), + nbdb.ACLActionAllowRelated, + types.OvnACLLoggingMeter, + "", + false, + legacyACLIDs.GetExternalIDs(), + nil, + types.DefaultACLTier, + ) + legacyACL.UUID = legacyACLIDs.String() + "-UUID" + initialData = append(initialData, legacyACL) + legacyACLPolicy1 = append(legacyACLPolicy1, legacyACL) + } + pgNetworkPolicy1 := getPolicyPortGroup(newNetpolDataParams(networkPolicy1), legacyACLPolicy1) + initialData = append(initialData, pgNetworkPolicy1) + + // NetworkPolicy 2 contains two gress policies, each with one legacy + // ACL per ipBlock. Simulate two legacy ACL corresponding to gress + // policy indexes 0 and 1, respectively. + // ACL1 => libovsdbops.GressIdxKey: 0, libovsdbops.IpBlockIndexKey: 0 + // ACL2 => libovsdbops.GressIdxKey: 1, libovsdbops.IpBlockIndexKey: 0 + pgName2 := fakeController.getNetworkPolicyPGName(namespace1.Name, networkPolicy2.Name) + firstgp2 := gressPolicy{ + policyNamespace: networkPolicy2.Namespace, + policyName: networkPolicy2.Name, + policyType: knet.PolicyTypeIngress, + idx: 0, + controllerName: controllerName, + } + secondgp2 := gressPolicy{ + policyNamespace: networkPolicy2.Namespace, + policyName: networkPolicy2.Name, + policyType: knet.PolicyTypeIngress, + idx: 1, + controllerName: controllerName, + } + legacyACLID := firstgp2.getNetpolACLDbIDs(0, libovsdbutil.UnspecifiedL4Protocol) + legacyACL := libovsdbops.BuildACL( + libovsdbutil.GetACLName(legacyACLID), + nbdb.ACLDirectionToLport, + types.DefaultAllowPriority, + fmt.Sprintf("ip4.src == 1.1.1.1 && outport == @%s", pgName2), + nbdb.ACLActionAllowRelated, + types.OvnACLLoggingMeter, + "", + false, + legacyACLID.GetExternalIDs(), + nil, + types.DefaultACLTier, + ) + legacyACL.UUID = legacyACLID.String() + "-UUID" + initialData = append(initialData, legacyACL) + + legacyACLID2 := secondgp2.getNetpolACLDbIDs(0, libovsdbutil.UnspecifiedL4Protocol) + legacyACL2 := libovsdbops.BuildACL( + libovsdbutil.GetACLName(legacyACLID2), + nbdb.ACLDirectionToLport, + types.DefaultAllowPriority, + fmt.Sprintf("ip4.src == 1.1.1.1 && outport == @%s", pgName2), + nbdb.ACLActionAllowRelated, + types.OvnACLLoggingMeter, + "", + false, + legacyACLID2.GetExternalIDs(), + nil, + types.DefaultACLTier, + ) + legacyACL2.UUID = legacyACLID2.String() + "-UUID" + initialData = append(initialData, legacyACL2) + pgNetworkPolicy2 := getPolicyPortGroup(newNetpolDataParams(networkPolicy2), []*nbdb.ACL{legacyACL, legacyACL2}) + initialData = append(initialData, pgNetworkPolicy2) + + startOvn(libovsdbtest.TestSetup{NBData: initialData}, []corev1.Namespace{namespace1}, + []knet.NetworkPolicy{*networkPolicy1, *networkPolicy2}, + nil, nil) + + // check the initial data is updated and all legacy ACLs should be cleaned up + gressPolicy1ExpectedData := getPolicyData(newNetpolDataParams(networkPolicy1)) + gressPolicy2ExpectedData := getPolicyData(newNetpolDataParams(networkPolicy2)) + finalData := initialDB.NBData + finalData = append(finalData, namespace1AddressSetv4) + finalData = append(finalData, gressPolicy1ExpectedData...) + finalData = append(finalData, gressPolicy2ExpectedData...) + finalData = append(finalData, defaultDenyExpectedData...) + gomega.Eventually(fakeOvn.nbClient).Should(libovsdbtest.HaveData(finalData)) + + return nil + } + gomega.Expect(app.Run([]string{app.Name})).To(gomega.Succeed()) + }) }) ginkgo.Context("during execution", func() { diff --git a/go-controller/pkg/util/dns.go b/go-controller/pkg/util/dns.go index 9466ad16f5..86d8a9e054 100644 --- a/go-controller/pkg/util/dns.go +++ b/go-controller/pkg/util/dns.go @@ -16,8 +16,12 @@ import ( ) const ( - // defaultTTL is used if an invalid or zero TTL is provided. - defaultTTL = 30 * time.Minute + // defaultMinTTL is the minimum TTL value that will be used for a domain name if an invalid or zero TTL is found + defaultMinTTL = 5 * time.Second + // defaultMaxTTL is the maximum TTL value that will be used for a domain name if an invalid or zero TTL is found + defaultMaxTTL = 2 * time.Minute + // maxRetryBeforeBackoff is the maximum number of times to retry a DNS lookup before exponential backoff starts + maxRetryBeforeBackoff = 10 ) type dnsValue struct { @@ -27,6 +31,8 @@ type dnsValue struct { ttl time.Duration // Holds (last dns lookup time + ttl), tells when to refresh IPs next time nextQueryTime time.Time + // Number of times the DNS lookup has been retried before backoff starts + retryCount int } type DNS struct { @@ -105,11 +111,22 @@ func (d *DNS) updateOne(dns string) (bool, error) { return false, fmt.Errorf("DNS value not found in dnsMap for domain: %q", dns) } - ips, ttl, err := d.getIPsAndMinTTL(dns) - if err != nil { - res.nextQueryTime = time.Now().Add(defaultTTL) - d.dnsMap[dns] = res - return false, err + ips, ttl, retry, err := d.getIPsAndMinTTL(dns) + if retry { + // If the DNS lookup has been retried maxRetryCount times, use exponential backoff + // by doubling the previous TTL. The TTL is capped at defaultMaxTTL. + if res.retryCount >= maxRetryBeforeBackoff { + ttl = min(res.ttl*2, defaultMaxTTL) + } else { + // Increment the retry count + res.retryCount++ + } + // If no valid IPs were found, use the previous IPs as fallback. + if len(ips) == 0 { + ips = res.ips + } + } else { + res.retryCount = 0 } changed := false @@ -120,10 +137,10 @@ func (d *DNS) updateOne(dns string) (bool, error) { res.ttl = ttl res.nextQueryTime = time.Now().Add(res.ttl) d.dnsMap[dns] = res - return changed, nil + return changed, err } -func (d *DNS) getIPsAndMinTTL(domain string) ([]net.IP, time.Duration, error) { +func (d *DNS) getIPsAndMinTTL(domain string) ([]net.IP, time.Duration, bool, error) { ips := []net.IP{} ttlSet := false var ttlSeconds uint32 @@ -197,19 +214,27 @@ func (d *DNS) getIPsAndMinTTL(domain string) ([]net.IP, time.Duration, error) { } if !ttlSet || (len(ips) == 0) { - return nil, defaultTTL, fmt.Errorf("IPv4 or IPv6 addr not found for domain: %q, nameservers: %v", domain, d.nameservers) + return nil, defaultMinTTL, true, fmt.Errorf("IPv4 or IPv6 addr not found for domain: %q, nameservers: %v", domain, d.nameservers) } + ips = removeDuplicateIPs(ips) + ttl, err := time.ParseDuration(fmt.Sprintf("%ds", minTTL)) if err != nil { - utilruntime.HandleError(fmt.Errorf("invalid TTL value for domain: %q, err: %v, defaulting ttl=%s", domain, err, defaultTTL.String())) - ttl = defaultTTL + utilruntime.HandleError(fmt.Errorf("invalid TTL value for domain: %q, err: %v", domain, err)) + return ips, defaultMinTTL, true, nil } if ttl == 0 { - ttl = defaultTTL + // If the TTL is 0, return the default minimum TTL. The retry is set to false as this + // is not an error scenario. TTL being 0 is a valid scenario for some DNS servers + // and it means that the IP addresses should be refreshed everytime whenever the DNS + // name is being used. From the point of view of OVN-Kubernetes, the IP addresses are + // refreshed every defaultMinTTL. + klog.V(5).Infof("TTL value is 0 for domain: %q, defaulting ttl=%s", domain, defaultMinTTL.String()) + return ips, defaultMinTTL, false, nil } - return removeDuplicateIPs(ips), ttl, nil + return ips, ttl, false, nil } func (d *DNS) GetNextQueryTime() (time.Time, string, bool) { diff --git a/go-controller/pkg/util/dns_test.go b/go-controller/pkg/util/dns_test.go index a9d248042b..9f40c176ba 100644 --- a/go-controller/pkg/util/dns_test.go +++ b/go-controller/pkg/util/dns_test.go @@ -70,13 +70,16 @@ func TestGetIPsAndMinTTL(t *testing.T) { tests := []struct { desc string errExp bool + retry bool ipv4Mode bool ipv6Mode bool dnsOpsMockHelper []ovntest.TestifyMockHelper + expectedTTL time.Duration }{ { desc: "call to Exchange fails IPv4 only", errExp: true, + retry: true, ipv4Mode: true, ipv6Mode: false, dnsOpsMockHelper: []ovntest.TestifyMockHelper{ @@ -89,10 +92,12 @@ func TestGetIPsAndMinTTL(t *testing.T) { CallTimes: 1, }, }, + expectedTTL: defaultMinTTL, }, { desc: "Exchange returns correctly but Rcode != RcodeSuccess IPv4 only", errExp: true, + retry: true, ipv4Mode: true, ipv6Mode: false, dnsOpsMockHelper: []ovntest.TestifyMockHelper{ @@ -105,6 +110,46 @@ func TestGetIPsAndMinTTL(t *testing.T) { CallTimes: 1, }, }, + expectedTTL: defaultMinTTL, + }, + { + desc: "Exchange returns correctly but with TTL 0 IPv4 only", + errExp: false, + retry: false, + ipv4Mode: true, + ipv6Mode: false, + dnsOpsMockHelper: []ovntest.TestifyMockHelper{ + {OnCallMethodName: "SetQuestion", OnCallMethodArgType: []string{"*dns.Msg", "string", "uint16"}, RetArgList: []interface{}{&dns.Msg{}}, CallTimes: 1}, + {OnCallMethodName: "Fqdn", OnCallMethodArgType: []string{"string"}, RetArgList: []interface{}{"www.test.com"}, CallTimes: 1}, + {OnCallMethodName: "Exchange", OnCallMethodArgType: []string{"*dns.Client", "*dns.Msg", "string"}, RetArgList: []interface{}{&dns.Msg{MsgHdr: dns.MsgHdr{Rcode: dns.RcodeSuccess}, Answer: []dns.RR{&dns.A{A: net.ParseIP("1.2.3.4")}}}, 0 * time.Second, nil}, CallTimes: 1}, + }, + expectedTTL: defaultMinTTL, + }, + { + desc: "Exchange returns correctly but no Answer IPv4 only", + errExp: true, + retry: true, + ipv4Mode: true, + ipv6Mode: false, + dnsOpsMockHelper: []ovntest.TestifyMockHelper{ + {OnCallMethodName: "SetQuestion", OnCallMethodArgType: []string{"*dns.Msg", "string", "uint16"}, RetArgList: []interface{}{&dns.Msg{}}, CallTimes: 1}, + {OnCallMethodName: "Fqdn", OnCallMethodArgType: []string{"string"}, RetArgList: []interface{}{"www.test.com"}, CallTimes: 1}, + {OnCallMethodName: "Exchange", OnCallMethodArgType: []string{"*dns.Client", "*dns.Msg", "string"}, RetArgList: []interface{}{&dns.Msg{MsgHdr: dns.MsgHdr{Rcode: dns.RcodeSuccess}, Answer: []dns.RR{}}, 0 * time.Second, nil}, CallTimes: 1}, + }, + expectedTTL: defaultMinTTL, + }, + { + desc: "Exchange returns correctly but with non-zero TTL IPv4 only", + errExp: false, + retry: false, + ipv4Mode: true, + ipv6Mode: false, + dnsOpsMockHelper: []ovntest.TestifyMockHelper{ + {OnCallMethodName: "SetQuestion", OnCallMethodArgType: []string{"*dns.Msg", "string", "uint16"}, RetArgList: []interface{}{&dns.Msg{}}, CallTimes: 1}, + {OnCallMethodName: "Fqdn", OnCallMethodArgType: []string{"string"}, RetArgList: []interface{}{"www.test.com"}, CallTimes: 1}, + {OnCallMethodName: "Exchange", OnCallMethodArgType: []string{"*dns.Client", "*dns.Msg", "string"}, RetArgList: []interface{}{&dns.Msg{MsgHdr: dns.MsgHdr{Rcode: dns.RcodeSuccess}, Answer: []dns.RR{&dns.A{Hdr: dns.RR_Header{Ttl: 100}, A: net.ParseIP("1.2.3.4")}}}, 0 * time.Second, nil}, CallTimes: 1}, + }, + expectedTTL: 100 * time.Second, }, } @@ -128,19 +173,22 @@ func TestGetIPsAndMinTTL(t *testing.T) { } config.IPv4Mode = tc.ipv4Mode config.IPv6Mode = tc.ipv6Mode - res, _, err := testDNS.getIPsAndMinTTL("www.test.com") - t.Log(res, err) + res, ttl, retry, err := testDNS.getIPsAndMinTTL("www.test.com") + t.Log(res, ttl, retry, err) if tc.errExp { require.Error(t, err) } else { require.NoError(t, err) } + assert.Equal(t, tc.retry, retry, "the exponentialBackoff variable should match the return from dns.getIPsAndMinTTL()") + assert.Equal(t, tc.expectedTTL, ttl, "the ttl variable should match the return from dns.getIPsAndMinTTL()") mockDNSOps.AssertExpectations(t) }) } } func TestUpdate(t *testing.T) { + config.IPv4Mode = true mockDNSOps := new(util_mocks.DNSOps) SetDNSLibOpsMockInst(mockDNSOps) @@ -252,6 +300,7 @@ func TestUpdate(t *testing.T) { } func TestAdd(t *testing.T) { + config.IPv4Mode = true dnsName := "www.testing.com" mockDNSOps := new(util_mocks.DNSOps) SetDNSLibOpsMockInst(mockDNSOps) @@ -319,3 +368,211 @@ func TestAdd(t *testing.T) { } } + +func TestIPsEqual(t *testing.T) { + tests := []struct { + desc string + oldips []net.IP + newips []net.IP + expEqual bool + }{ + { + desc: "oldips and newips are the same", + oldips: []net.IP{net.ParseIP("1.2.3.4")}, + newips: []net.IP{net.ParseIP("1.2.3.4")}, + expEqual: true, + }, + { + desc: "oldips and newips are different", + oldips: []net.IP{net.ParseIP("1.2.3.4")}, + newips: []net.IP{net.ParseIP("1.2.3.5")}, + expEqual: false, + }, + { + desc: "oldips and newips are different length", + oldips: []net.IP{net.ParseIP("1.2.3.4")}, + newips: []net.IP{net.ParseIP("1.2.3.4"), net.ParseIP("1.2.3.5")}, + expEqual: false, + }, + { + desc: "oldips is nil and newips is not nil", + oldips: nil, + newips: []net.IP{net.ParseIP("1.2.3.4"), net.ParseIP("1.2.3.5")}, + expEqual: false, + }, + { + desc: "oldips is empty and newips is not empty", + oldips: []net.IP{}, + newips: []net.IP{net.ParseIP("1.2.3.4"), net.ParseIP("1.2.3.5")}, + expEqual: false, + }, + { + desc: "oldips is not nil and newips is nil", + oldips: []net.IP{net.ParseIP("1.2.3.4"), net.ParseIP("1.2.3.5")}, + newips: nil, + expEqual: false, + }, + { + desc: "oldips is not empty and newips is empty", + oldips: []net.IP{net.ParseIP("1.2.3.4"), net.ParseIP("1.2.3.5")}, + newips: []net.IP{}, + expEqual: false, + }, + { + desc: "oldips and newips are both nil", + oldips: nil, + newips: nil, + expEqual: true, + }, + { + desc: "oldips and newips are both empty", + oldips: []net.IP{}, + newips: []net.IP{}, + expEqual: true, + }, + { + desc: "oldips is nil and newips is empty", + oldips: nil, + newips: []net.IP{}, + expEqual: true, + }, + { + desc: "oldips is empty and newips is nil", + oldips: []net.IP{}, + newips: nil, + expEqual: true, + }, + } + for i, tc := range tests { + t.Run(fmt.Sprintf("%d:%s", i, tc.desc), func(t *testing.T) { + res := ipsEqual(tc.oldips, tc.newips) + assert.Equal(t, tc.expEqual, res) + }) + } +} + +func TestUpdateOne(t *testing.T) { + config.IPv4Mode = true + dnsName := "www.testing.com" + newIP := net.ParseIP("1.2.3.4") + fqdnOpsMockHelper := ovntest.TestifyMockHelper{ + OnCallMethodName: "Fqdn", OnCallMethodArgType: []string{"string"}, RetArgList: []interface{}{dnsName}, CallTimes: 1, + } + setQuestionOpsMockHelper := ovntest.TestifyMockHelper{ + OnCallMethodName: "SetQuestion", OnCallMethodArgType: []string{"*dns.Msg", "string", "uint16"}, RetArgList: []interface{}{&dns.Msg{}}, CallTimes: 1, + } + exchangeSuccessNoAnswerOpsMockHelper := ovntest.TestifyMockHelper{ + OnCallMethodName: "Exchange", OnCallMethodArgType: []string{"*dns.Client", "*dns.Msg", "string"}, RetArgList: []interface{}{&dns.Msg{MsgHdr: dns.MsgHdr{Rcode: dns.RcodeSuccess}, Answer: []dns.RR{}}, 0 * time.Second, nil}, CallTimes: 1, + } + exchangeSuccessZeroTTLOpsMockHelper := ovntest.TestifyMockHelper{ + OnCallMethodName: "Exchange", OnCallMethodArgType: []string{"*dns.Client", "*dns.Msg", "string"}, RetArgList: []interface{}{&dns.Msg{MsgHdr: dns.MsgHdr{Rcode: dns.RcodeSuccess}, Answer: []dns.RR{&dns.A{A: newIP}}}, 0 * time.Second, nil}, CallTimes: 1, + } + exchangeSuccessNonZeroTTLOpsMockHelper := ovntest.TestifyMockHelper{ + OnCallMethodName: "Exchange", OnCallMethodArgType: []string{"*dns.Client", "*dns.Msg", "string"}, RetArgList: []interface{}{&dns.Msg{MsgHdr: dns.MsgHdr{Rcode: dns.RcodeSuccess}, Answer: []dns.RR{&dns.A{Hdr: dns.RR_Header{Ttl: 100}, A: newIP}}}, 0 * time.Second, nil}, CallTimes: 1, + } + exchangeFailureOpsMockHelper := ovntest.TestifyMockHelper{ + OnCallMethodName: "Exchange", OnCallMethodArgType: []string{"*dns.Client", "*dns.Msg", "string"}, RetArgList: []interface{}{&dns.Msg{MsgHdr: dns.MsgHdr{Rcode: dns.RcodeServerFailure}}, 0 * time.Second, nil}, CallTimes: 1, + } + tests := []struct { + desc string + numCalls int + exchangeOpsMockHelper ovntest.TestifyMockHelper + expTTL time.Duration + }{ + { + desc: "when Exchange function returns with Rcode != RcodeSuccess, defaultMinTTL is used", + numCalls: 1, + exchangeOpsMockHelper: exchangeFailureOpsMockHelper, + expTTL: defaultMinTTL, + }, + { + desc: "when Exchange function returns successfully but without Answer, defaultMinTTL is used", + numCalls: 1, + exchangeOpsMockHelper: exchangeSuccessNoAnswerOpsMockHelper, + expTTL: defaultMinTTL, + }, + { + desc: "when TTL returned is 0 by Exchange function, defaultMinTTL is used", + numCalls: 1, + exchangeOpsMockHelper: exchangeSuccessZeroTTLOpsMockHelper, + expTTL: defaultMinTTL, + }, + { + desc: "when TTL returned is 0 by Exchange function 2 times, defaultMinTTL is used", + numCalls: 2, + exchangeOpsMockHelper: exchangeSuccessZeroTTLOpsMockHelper, + expTTL: defaultMinTTL, + }, + { + desc: "when TTL returned is 0 by Exchange function 11 times, defaultMinTTL is used", + numCalls: 11, + exchangeOpsMockHelper: exchangeSuccessZeroTTLOpsMockHelper, + expTTL: defaultMinTTL, + }, + { + desc: "when Exchange function returns with Rcode != RcodeSuccess twice, defaultMinTTL is used", + numCalls: 2, + exchangeOpsMockHelper: exchangeFailureOpsMockHelper, + expTTL: defaultMinTTL, + }, + { + desc: "when Exchange function returns with Rcode != RcodeSuccess 10 times, defaultMinTTL is used", + numCalls: 10, + exchangeOpsMockHelper: exchangeFailureOpsMockHelper, + expTTL: defaultMinTTL, + }, + { + desc: "when Exchange function returns with Rcode != RcodeSuccess 11 times, defaultMinTTL is doubled", + numCalls: 11, + exchangeOpsMockHelper: exchangeFailureOpsMockHelper, + expTTL: 2 * defaultMinTTL, + }, + { + desc: "when Exchange function returns with Rcode != RcodeSuccess 14 times, 16 (2^4) times defaultMinTTL is used", + numCalls: 14, + exchangeOpsMockHelper: exchangeFailureOpsMockHelper, + expTTL: 16 * defaultMinTTL, + }, + { + desc: "when Exchange function returns with Rcode != RcodeSuccess 15 times, defaultMaxTTL is used", + numCalls: 15, + exchangeOpsMockHelper: exchangeFailureOpsMockHelper, + expTTL: defaultMaxTTL, + }, + { + desc: "when TTL returned is non-zero by Exchange function, it is used", + numCalls: 1, + exchangeOpsMockHelper: exchangeSuccessNonZeroTTLOpsMockHelper, + expTTL: 100 * time.Second, + }, + } + for i, tc := range tests { + t.Run(fmt.Sprintf("%d:%s", i, tc.desc), func(t *testing.T) { + mockDNSOps := new(util_mocks.DNSOps) + SetDNSLibOpsMockInst(mockDNSOps) + dnsOpsMockHelper := []ovntest.TestifyMockHelper{fqdnOpsMockHelper, setQuestionOpsMockHelper, tc.exchangeOpsMockHelper} + for index := 0; index < tc.numCalls; index++ { + for _, item := range dnsOpsMockHelper { + call := mockDNSOps.On(item.OnCallMethodName) + for _, arg := range item.OnCallMethodArgType { + call.Arguments = append(call.Arguments, mock.AnythingOfType(arg)) + } + for _, ret := range item.RetArgList { + call.ReturnArguments = append(call.ReturnArguments, ret) + } + call.Once() + } + } + dns := DNS{ + dnsMap: make(map[string]dnsValue), + nameservers: []string{"1.1.1.1"}, + } + dns.dnsMap[dnsName] = dnsValue{} + for i := 0; i < tc.numCalls; i++ { + _, _ = dns.updateOne(dnsName) + } + assert.Equal(t, tc.expTTL, dns.dnsMap[dnsName].ttl) + mockDNSOps.AssertExpectations(t) + }) + } +}