Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introduce sync.Map wrapper with generics support #29452

Merged
merged 3 commits into from
Nov 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions contrib/scripts/lock-check.sh
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,8 @@ for l in sync.Mutex sync.RWMutex; do
exit 1
fi
done

if grep -r --exclude-dir={.git,_build,vendor,externalversions,lock,contrib} -i --include \*.go "sync.Map" .; then
echo "Found sync.Map usages. Please use the generic pkg/lock.Map wrapper instead";
exit 1
fi
qmonnet marked this conversation as resolved.
Show resolved Hide resolved
11 changes: 4 additions & 7 deletions pkg/k8s/client/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import (
"context"
"net/http"
"net/http/httptest"
"sync"
"testing"
"time"

Expand All @@ -20,6 +19,7 @@ import (
"github.com/cilium/cilium/pkg/hive/cell"
k8smetrics "github.com/cilium/cilium/pkg/k8s/metrics"
k8sversion "github.com/cilium/cilium/pkg/k8s/version"
"github.com/cilium/cilium/pkg/lock"
"github.com/cilium/cilium/pkg/logging"
"github.com/cilium/cilium/pkg/option"
"github.com/cilium/cilium/pkg/testutils"
Expand Down Expand Up @@ -200,13 +200,10 @@ func (s *K8sClientSuite) Test_runHeartbeat(c *C) {
}

func (s *K8sClientSuite) Test_client(c *C) {
requests := sync.Map{}
var requests lock.Map[string, *http.Request]
getRequest := func(k string) *http.Request {
v, ok := requests.Load(k)
if !ok {
return nil
}
return v.(*http.Request)
v, _ := requests.Load(k)
return v
}

srv := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
Expand Down
9 changes: 4 additions & 5 deletions pkg/kvstore/etcd.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ import (
"os"
"strconv"
"strings"
"sync"

"github.com/sirupsen/logrus"
"go.etcd.io/etcd/api/v3/mvccpb"
Expand Down Expand Up @@ -377,7 +376,7 @@ type etcdClient struct {

lastHeartbeat time.Time

leaseExpiredObservers sync.Map
leaseExpiredObservers lock.Map[string, func(string)]

// logger is the scoped logger associated with this client
logger logrus.FieldLogger
Expand Down Expand Up @@ -1691,9 +1690,9 @@ func (e *etcdClient) RegisterLeaseExpiredObserver(prefix string, fn func(key str
}

func (e *etcdClient) expiredLeaseObserver(key string) {
e.leaseExpiredObservers.Range(func(prefix, fn any) bool {
if strings.HasPrefix(key, prefix.(string)) {
fn.(func(string))(key)
e.leaseExpiredObservers.Range(func(prefix string, fn func(string)) bool {
if strings.HasPrefix(key, prefix) {
fn(key)
}
return true
})
Expand Down
19 changes: 11 additions & 8 deletions pkg/kvstore/store/syncstore.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import (
"k8s.io/client-go/util/workqueue"

"github.com/cilium/cilium/pkg/kvstore"
"github.com/cilium/cilium/pkg/lock"
"github.com/cilium/cilium/pkg/logging/logfields"
"github.com/cilium/cilium/pkg/metrics"
"github.com/cilium/cilium/pkg/time"
Expand Down Expand Up @@ -66,10 +67,10 @@ type wqSyncStore struct {

limiter workqueue.RateLimiter
workqueue workqueue.RateLimitingInterface
state sync.Map /* map[string][]byte --- map[NamedKey.GetKeyName()]Key.Marshal() */
state lock.Map[string, []byte] // map[NamedKey.GetKeyName()]Key.Marshal()

synced atomic.Bool // Synced() has been triggered
pendingSync sync.Map // map[string]struct{}: the set of keys still to sync
synced atomic.Bool // Synced() has been triggered
pendingSync lock.Map[string, struct{}] // the set of keys still to sync
syncedKey string
syncedCallbacks []func(context.Context)

Expand Down Expand Up @@ -182,7 +183,7 @@ func (wss *wqSyncStore) UpsertKey(_ context.Context, k Key) error {
}

prevValue, loaded := wss.state.Swap(key, value)
if loaded && bytes.Equal(prevValue.([]byte), value) {
if loaded && bytes.Equal(prevValue, value) {
wss.log.WithField(logfields.Key, k).Debug("ignoring upsert request for already up-to-date key")
} else {
if !wss.synced.Load() {
Expand Down Expand Up @@ -246,7 +247,9 @@ func (wss *wqSyncStore) processNextItem(ctx context.Context) bool {
// Since no error occurred, forget this item so it does not get queued again
// until another change happens.
wss.workqueue.Forget(key)
wss.pendingSync.Delete(key)
if skey, ok := key.(string); ok {
wss.pendingSync.Delete(skey)
}
return true
}

Expand All @@ -255,8 +258,8 @@ func (wss *wqSyncStore) handle(ctx context.Context, key interface{}) error {
return wss.handleSync(ctx, value.skipCallbacks)
}

if value, ok := wss.state.Load(key); ok {
return wss.handleUpsert(ctx, key.(string), value.([]byte))
if value, ok := wss.state.Load(key.(string)); ok {
return wss.handleUpsert(ctx, key.(string), value)
}

return wss.handleDelete(ctx, key.(string))
Expand Down Expand Up @@ -290,7 +293,7 @@ func (wss *wqSyncStore) handleDelete(ctx context.Context, key string) error {
func (wss *wqSyncStore) handleSync(ctx context.Context, skipCallbacks bool) error {
// This could be replaced by wss.toSync.Len() == 0 if it only existed...
syncCompleted := true
wss.pendingSync.Range(func(any, any) bool {
wss.pendingSync.Range(func(string, struct{}) bool {
syncCompleted = false
return false
})
Expand Down
108 changes: 108 additions & 0 deletions pkg/lock/map.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
// SPDX-License-Identifier: Apache-2.0
// Copyright Authors of Cilium

package lock

import "sync"

// Map is a thin generic wrapper around sync.Map. The sync.Map description from
// the standard library follows (and is also propagated to the corresponding
// methods) for users' convenience:
//
// Map is like a Go map[interface{}]interface{} but is safe for concurrent use
// by multiple goroutines without additional locking or coordination.
// Loads, stores, and deletes run in amortized constant time.
//
// The Map type is specialized. Most code should use a plain Go map instead,
// with separate locking or coordination, for better type safety and to make it
// easier to maintain other invariants along with the map content.
//
// The Map type is optimized for two common use cases: (1) when the entry for a given
// key is only ever written once but read many times, as in caches that only grow,
// or (2) when multiple goroutines read, write, and overwrite entries for disjoint
// sets of keys. In these two cases, use of a Map may significantly reduce lock
// contention compared to a Go map paired with a separate Mutex or RWMutex.
//
// The zero Map is empty and ready for use. A Map must not be copied after first use.
type Map[K comparable, V any] sync.Map

// MapCmpValues is an extension of Map, which additionally wraps the two extra
// methods requiring values to be also of comparable type.
type MapCmpValues[K, V comparable] Map[K, V]

// Load returns the value stored in the map for a key, or the zero value if no
// value is present. The ok result indicates whether value was found in the map.
func (m *Map[K, V]) Load(key K) (value V, ok bool) {
val, ok := (*sync.Map)(m).Load(key)
return m.convert(val, ok)
}

// LoadOrStore returns the existing value for the key if present.
// Otherwise, it stores and returns the given value.
// The loaded result is true if the value was loaded, false if stored.
func (m *Map[K, V]) LoadOrStore(key K, value V) (actual V, loaded bool) {
val, loaded := (*sync.Map)(m).LoadOrStore(key, value)
return val.(V), loaded
}

// LoadAndDelete deletes the value for a key, returning the previous value if any
// (zero value otherwise). The loaded result reports whether the key was present.
func (m *Map[K, V]) LoadAndDelete(key K) (value V, loaded bool) {
val, loaded := (*sync.Map)(m).LoadAndDelete(key)
return m.convert(val, loaded)
}

// Store sets the value for a key.
func (m *Map[K, V]) Store(key K, value V) {
(*sync.Map)(m).Store(key, value)
}

// Swap swaps the value for a key and returns the previous value if any (zero
// value otherwise). The loaded result reports whether the key was present.
func (m *Map[K, V]) Swap(key K, value V) (previous V, loaded bool) {
val, loaded := (*sync.Map)(m).Swap(key, value)
return m.convert(val, loaded)
}

// Delete deletes the value for a key.
func (m *Map[K, V]) Delete(key K) {
(*sync.Map)(m).Delete(key)
}

// Range calls f sequentially for each key and value present in the map.
// If f returns false, range stops the iteration.
//
// Range does not necessarily correspond to any consistent snapshot of the Map's
// contents: no key will be visited more than once, but if the value for any key
// is stored or deleted concurrently (including by f), Range may reflect any
// mapping for that key from any point during the Range call. Range does not
// block other methods on the receiver; even f itself may call any method on m.
//
// Range may be O(N) with the number of elements in the map even if f returns
// false after a constant number of calls.
func (m *Map[K, V]) Range(f func(key K, value V) bool) {
(*sync.Map)(m).Range(func(key, value any) bool {
return f(key.(K), value.(V))
})
}

// CompareAndDelete deletes the entry for key if its value is equal to old.
// If there is no current value for key in the map, CompareAndDelete returns false
// (even if the old value is the nil interface value).
func (m *MapCmpValues[K, V]) CompareAndDelete(key K, old V) (deleted bool) {
return (*sync.Map)(m).CompareAndDelete(key, old)
}

// CompareAndSwap swaps the old and new values for key if the value stored in
// the map is equal to old.
func (m *MapCmpValues[K, V]) CompareAndSwap(key K, old, new V) bool {
return (*sync.Map)(m).CompareAndSwap(key, old, new)
}

func (m *Map[K, V]) convert(value any, ok bool) (V, bool) {
if !ok {
return *new(V), false
}

return value.(V), true
}
10 changes: 3 additions & 7 deletions pkg/policy/api/groups.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,17 @@ import (
"context"
"fmt"
"net/netip"
"sync"

"github.com/cilium/cilium/pkg/ip"
"github.com/cilium/cilium/pkg/lock"
)

const (
AWSProvider = "AWS" // AWS provider key
)

var (
providers = sync.Map{} // map with the list of providers to callback to retrieve info from.
providers lock.Map[string, GroupProviderFunc] // map with the list of providers to callback to retrieve info from.
)

// GroupProviderFunc is a func that need to be register to be able to
Expand Down Expand Up @@ -50,14 +50,10 @@ func (group *ToGroups) GetCidrSet(ctx context.Context) ([]CIDRRule, error) {
var addrs []netip.Addr
// Get per provider CIDRSet
if group.AWS != nil {
callbackInterface, ok := providers.Load(AWSProvider)
callback, ok := providers.Load(AWSProvider)
if !ok {
return nil, fmt.Errorf("Provider %s is not registered", AWSProvider)
}
callback, ok := callbackInterface.(GroupProviderFunc)
if !ok {
return nil, fmt.Errorf("Provider callback for %s is not a valid instance", AWSProvider)
}
awsAddrs, err := callback(ctx, group)
if err != nil {
return nil, fmt.Errorf(
Expand Down
9 changes: 5 additions & 4 deletions pkg/policy/groups/cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,16 @@
package groups

import (
"sync"
"k8s.io/apimachinery/pkg/types"

cilium_v2 "github.com/cilium/cilium/pkg/k8s/apis/cilium.io/v2"
"github.com/cilium/cilium/pkg/lock"
)

var groupsCNPCache = groupsCNPCacheMap{}

type groupsCNPCacheMap struct {
sync.Map
lock.Map[types.UID, *cilium_v2.CiliumNetworkPolicy]
}

func (cnpCache *groupsCNPCacheMap) UpdateCNP(cnp *cilium_v2.CiliumNetworkPolicy) {
Expand All @@ -25,8 +26,8 @@ func (cnpCache *groupsCNPCacheMap) DeleteCNP(cnp *cilium_v2.CiliumNetworkPolicy)

func (cnpCache *groupsCNPCacheMap) GetAllCNP() []*cilium_v2.CiliumNetworkPolicy {
result := []*cilium_v2.CiliumNetworkPolicy{}
cnpCache.Range(func(k, v interface{}) bool {
result = append(result, v.(*cilium_v2.CiliumNetworkPolicy))
cnpCache.Range(func(_ types.UID, cnp *cilium_v2.CiliumNetworkPolicy) bool {
result = append(result, cnp)
return true
})
return result
Expand Down
Loading