Skip to content

Commit

Permalink
Make throttling nmagent fetches for nodesubnet more dynamic (#3023)
Browse files Browse the repository at this point in the history
* feat(CNS): Early work on better throttling in NMAgent fetch for nodesubnet

* feat(CNS): Update NMAgent fetches to be async with binary exponential backoff

* chore: check for empty nmagent response

* test: update test for empty response

* style: make linter happy

* chore: fix some comments

* fix: Fix bug in refresh

* refactor: Address comments

* refactor: ignore primary ip

* refactor: move refresh out of ipfetcher

* test: add ip fetcher tests

* fix: remove broken import

* fix: fix import

* fix: fix linting

* fix: fix some failing tests

* chore: Remove unused function

* test: test updates

* fix: address comments

* chore: add missed file

* chore: add comment about static interval

* feat: address Evan's comment to require Equal method on cached results

* chore: add missed file

* feat: more efficient equality

* refactor: address Evan's comment

* refactor: address Tim's comment

* fix: undo accidental commit

* fix: make linter happy

* fix: make linter happy
  • Loading branch information
santhoshmprabhu authored Oct 14, 2024
1 parent 3ed0bcd commit b5046a0
Show file tree
Hide file tree
Showing 12 changed files with 581 additions and 80 deletions.
9 changes: 0 additions & 9 deletions cns/nodesubnet/helper_for_ip_fetcher_test.go

This file was deleted.

87 changes: 66 additions & 21 deletions cns/nodesubnet/ip_fetcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,49 +7,93 @@ import (
"time"

"github.com/Azure/azure-container-networking/nmagent"
"github.com/Azure/azure-container-networking/refresh"
"github.com/pkg/errors"
)

const (
// Default minimum time between secondary IP fetches
DefaultMinRefreshInterval = 4 * time.Second
// Default maximum time between secondary IP fetches
DefaultMaxRefreshInterval = 1024 * time.Second
)

var ErrRefreshSkipped = errors.New("refresh skipped due to throttling")

// InterfaceRetriever is an interface is implemented by the NMAgent Client, and also a mock client for testing.
type InterfaceRetriever interface {
GetInterfaceIPInfo(ctx context.Context) (nmagent.Interfaces, error)
}

type IPFetcher struct {
// Node subnet state
secondaryIPQueryInterval time.Duration // Minimum time between secondary IP fetches
secondaryIPLastRefreshTime time.Time // Time of last secondary IP fetch
// IPConsumer is an interface implemented by whoever consumes the secondary IPs fetched in nodesubnet
type IPConsumer interface {
UpdateIPsForNodeSubnet([]netip.Addr) error
}

ipFectcherClient InterfaceRetriever
// IPFetcher fetches secondary IPs from NMAgent at regular intervals. The
// interval will vary within the range of minRefreshInterval and
// maxRefreshInterval. When no diff is observed after a fetch, the interval
// doubles (subject to the maximum interval). When a diff is observed, the
// interval resets to the minimum.
type IPFetcher struct {
// Node subnet config
intfFetcherClient InterfaceRetriever
consumer IPConsumer
fetcher *refresh.Fetcher[nmagent.Interfaces]
}

func NewIPFetcher(nmaClient InterfaceRetriever, queryInterval time.Duration) *IPFetcher {
return &IPFetcher{
ipFectcherClient: nmaClient,
secondaryIPQueryInterval: queryInterval,
// NewIPFetcher creates a new IPFetcher. If minInterval is 0, it will default to 4 seconds.
// If maxInterval is 0, it will default to 1024 seconds (or minInterval, if it is higher).
func NewIPFetcher(
client InterfaceRetriever,
consumer IPConsumer,
minInterval time.Duration,
maxInterval time.Duration,
logger refresh.Logger,
) *IPFetcher {
if minInterval == 0 {
minInterval = DefaultMinRefreshInterval
}

if maxInterval == 0 {
maxInterval = DefaultMaxRefreshInterval
}

maxInterval = max(maxInterval, minInterval)

newIPFetcher := &IPFetcher{
intfFetcherClient: client,
consumer: consumer,
fetcher: nil,
}
fetcher := refresh.NewFetcher[nmagent.Interfaces](client.GetInterfaceIPInfo, minInterval, maxInterval, newIPFetcher.ProcessInterfaces, logger)
newIPFetcher.fetcher = fetcher
return newIPFetcher
}

// Start the IPFetcher.
func (c *IPFetcher) Start(ctx context.Context) {
c.fetcher.Start(ctx)
}

func (c *IPFetcher) RefreshSecondaryIPsIfNeeded(ctx context.Context) (ips []netip.Addr, err error) {
// If secondaryIPQueryInterval has elapsed since the last fetch, fetch secondary IPs
if time.Since(c.secondaryIPLastRefreshTime) < c.secondaryIPQueryInterval {
return nil, ErrRefreshSkipped
// Fetch IPs from NMAgent and pass to the consumer
func (c *IPFetcher) ProcessInterfaces(response nmagent.Interfaces) error {
if len(response.Entries) == 0 {
return errors.New("no interfaces found in response from NMAgent")
}

c.secondaryIPLastRefreshTime = time.Now()
response, err := c.ipFectcherClient.GetInterfaceIPInfo(ctx)
_, secondaryIPs := flattenIPListFromResponse(&response)
err := c.consumer.UpdateIPsForNodeSubnet(secondaryIPs)
if err != nil {
return nil, errors.Wrap(err, "getting interface IPs")
return errors.Wrap(err, "updating secondary IPs")
}

res := flattenIPListFromResponse(&response)
return res, nil
return nil
}

// Get the list of secondary IPs from fetched Interfaces
func flattenIPListFromResponse(resp *nmagent.Interfaces) (res []netip.Addr) {
func flattenIPListFromResponse(resp *nmagent.Interfaces) (primary netip.Addr, secondaryIPs []netip.Addr) {
var primaryIP netip.Addr
// For each interface...
for _, intf := range resp.Entries {
if !intf.IsPrimary {
Expand All @@ -63,15 +107,16 @@ func flattenIPListFromResponse(resp *nmagent.Interfaces) (res []netip.Addr) {
for _, a := range s.IPAddress {
// Primary addresses are reserved for the host.
if a.IsPrimary {
primaryIP = netip.Addr(a.Address)
continue
}

res = append(res, netip.Addr(a.Address))
secondaryIPs = append(secondaryIPs, netip.Addr(a.Address))
addressCount++
}
log.Printf("Got %d addresses from subnet %s", addressCount, s.Prefix)
}
}

return res
return primaryIP, secondaryIPs
}
131 changes: 81 additions & 50 deletions cns/nodesubnet/ip_fetcher_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,75 +2,102 @@ package nodesubnet_test

import (
"context"
"errors"
"net/netip"
"testing"
"time"

"github.com/Azure/azure-container-networking/cns/logger"
"github.com/Azure/azure-container-networking/cns/nodesubnet"
"github.com/Azure/azure-container-networking/nmagent"
)

// Mock client that simply tracks if refresh has been called
type TestClient struct {
fetchCalled bool
// Mock client that simply consumes fetched IPs
type TestConsumer struct {
consumeCount int
secondaryIPCount int
}

// FetchConsumeCount atomically fetches the consume count
func (c *TestConsumer) FetchConsumeCount() int {
return c.consumeCount
}

// FetchSecondaryIPCount atomically fetches the last IP count
func (c *TestConsumer) FetchSecondaryIPCount() int {
return c.consumeCount
}

// UpdateConsumeCount atomically updates the consume count
func (c *TestConsumer) updateCounts(ipCount int) {
c.consumeCount++
c.secondaryIPCount = ipCount
}

// Mock IP update
func (c *TestConsumer) UpdateIPsForNodeSubnet(ips []netip.Addr) error {
c.updateCounts(len(ips))
return nil
}

var _ nodesubnet.IPConsumer = &TestConsumer{}

// Mock client that simply satisfies the interface
type TestClient struct{}

// Mock refresh
func (c *TestClient) GetInterfaceIPInfo(_ context.Context) (nmagent.Interfaces, error) {
c.fetchCalled = true
return nmagent.Interfaces{}, nil
}

func TestRefreshSecondaryIPsIfNeeded(t *testing.T) {
getTests := []struct {
name string
shouldCall bool
interval time.Duration
}{
{
"fetch called",
true,
-1 * time.Second, // Negative timeout to force refresh
},
{
"no refresh needed",
false,
10 * time.Hour, // High timeout to avoid refresh
func TestEmptyResponse(t *testing.T) {
consumerPtr := &TestConsumer{}
fetcher := nodesubnet.NewIPFetcher(&TestClient{}, consumerPtr, 0, 0, logger.Log)
err := fetcher.ProcessInterfaces(nmagent.Interfaces{})
checkErr(t, err, true)

// No consumes, since the responses are empty
if consumerPtr.FetchConsumeCount() > 0 {
t.Error("Consume called unexpectedly, shouldn't be called since responses are empty")
}
}

func TestFlatten(t *testing.T) {
interfaces := nmagent.Interfaces{
Entries: []nmagent.Interface{
{
MacAddress: nmagent.MACAddress{0x00, 0x0D, 0x3A, 0xF9, 0xDC, 0xA6},
IsPrimary: true,
InterfaceSubnets: []nmagent.InterfaceSubnet{
{
Prefix: "10.240.0.0/16",
IPAddress: []nmagent.NodeIP{
{
Address: nmagent.IPAddress(netip.AddrFrom4([4]byte{10, 240, 0, 5})),
IsPrimary: true,
},
{
Address: nmagent.IPAddress(netip.AddrFrom4([4]byte{10, 240, 0, 6})),
IsPrimary: false,
},
},
},
},
},
},
}
consumerPtr := &TestConsumer{}
fetcher := nodesubnet.NewIPFetcher(&TestClient{}, consumerPtr, 0, 0, logger.Log)
err := fetcher.ProcessInterfaces(interfaces)
checkErr(t, err, false)

clientPtr := &TestClient{}
fetcher := nodesubnet.NewIPFetcher(clientPtr, 0)

for _, test := range getTests {
test := test
t.Run(test.name, func(t *testing.T) { // Do not parallelize, as we are using a shared client
fetcher.SetSecondaryIPQueryInterval(test.interval)
ctx, cancel := testContext(t)
defer cancel()
clientPtr.fetchCalled = false
_, err := fetcher.RefreshSecondaryIPsIfNeeded(ctx)

if test.shouldCall {
if err != nil && errors.Is(err, nodesubnet.ErrRefreshSkipped) {
t.Error("refresh expected, but didn't happen")
}

checkErr(t, err, false)
} else if err == nil || !errors.Is(err, nodesubnet.ErrRefreshSkipped) {
t.Error("refresh not expected, but happened")
}
})
// 1 consume to be called
if consumerPtr.FetchConsumeCount() != 1 {
t.Error("Consume expected to be called, but not called")
}
}

// testContext creates a context from the provided testing.T that will be
// canceled if the test suite is terminated.
func testContext(t *testing.T) (context.Context, context.CancelFunc) {
if deadline, ok := t.Deadline(); ok {
return context.WithDeadline(context.Background(), deadline)
// 1 consume to be called
if consumerPtr.FetchSecondaryIPCount() != 1 {
t.Error("Wrong number of secondary IPs ", consumerPtr.FetchSecondaryIPCount())
}
return context.WithCancel(context.Background())
}

// checkErr is an assertion of the presence or absence of an error
Expand All @@ -84,3 +111,7 @@ func checkErr(t *testing.T, err error, shouldErr bool) {
t.Fatal("expected error but received none")
}
}

func init() {
logger.InitLogger("testlogs", 0, 0, "./")
}
51 changes: 51 additions & 0 deletions nmagent/equality.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
package nmagent

// Equal compares two Interfaces objects for equality.
func (i Interfaces) Equal(other Interfaces) bool {
if len(i.Entries) != len(other.Entries) {
return false
}
for idx, entry := range i.Entries {
if !entry.Equal(other.Entries[idx]) {
return false
}
}
return true
}

// Equal compares two Interface objects for equality.
func (i Interface) Equal(other Interface) bool {
if len(i.InterfaceSubnets) != len(other.InterfaceSubnets) {
return false
}
for idx, subnet := range i.InterfaceSubnets {
if !subnet.Equal(other.InterfaceSubnets[idx]) {
return false
}
}
if i.IsPrimary != other.IsPrimary || !i.MacAddress.Equal(other.MacAddress) {
return false
}
return true
}

// Equal compares two InterfaceSubnet objects for equality.
func (s InterfaceSubnet) Equal(other InterfaceSubnet) bool {
if len(s.IPAddress) != len(other.IPAddress) {
return false
}
if s.Prefix != other.Prefix {
return false
}
for idx, ip := range s.IPAddress {
if !ip.Equal(other.IPAddress[idx]) {
return false
}
}
return true
}

// Equal compares two NodeIP objects for equality.
func (ip NodeIP) Equal(other NodeIP) bool {
return ip.IsPrimary == other.IsPrimary && ip.Address.Equal(other.Address)
}
12 changes: 12 additions & 0 deletions nmagent/macaddress.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,18 @@ const (

type MACAddress net.HardwareAddr

func (h MACAddress) Equal(other MACAddress) bool {
if len(h) != len(other) {
return false
}
for i := range h {
if h[i] != other[i] {
return false
}
}
return true
}

func (h *MACAddress) UnmarshalXML(d *xml.Decoder, start xml.StartElement) error {
var macStr string
if err := d.DecodeElement(&macStr, &start); err != nil {
Expand Down
5 changes: 5 additions & 0 deletions refresh/equaler.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
package refresh

type equaler[T any] interface {
Equal(T) bool
}
Loading

0 comments on commit b5046a0

Please sign in to comment.