Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pkg/cache/clustering/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ go_library(
"//pkg/assert",
"//pkg/batch",
"//pkg/cache",
"//pkg/cache/clustering/metrics",
"//pkg/cluster",
"//pkg/logger",
],
Expand Down
4 changes: 4 additions & 0 deletions pkg/cache/clustering/broadcaster_gossip.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (

cachev1 "github.com/unkeyed/unkey/gen/proto/cache/v1"
clusterv1 "github.com/unkeyed/unkey/gen/proto/cluster/v1"
"github.com/unkeyed/unkey/pkg/cache/clustering/metrics"
"github.com/unkeyed/unkey/pkg/cluster"
"github.com/unkeyed/unkey/pkg/logger"
)
Expand Down Expand Up @@ -57,7 +58,10 @@ func (b *GossipBroadcaster) Broadcast(_ context.Context, events ...*cachev1.Cach
if err := b.cluster.Broadcast(&clusterv1.ClusterMessage_CacheInvalidation{
CacheInvalidation: event,
}); err != nil {
metrics.CacheClusteringBroadcastErrorsTotal.Inc()
logger.Error("Failed to broadcast cache invalidation", "error", err)
} else {
metrics.CacheClusteringInvalidationsSentTotal.WithLabelValues(event.CacheName, metrics.ActionLabel(event)).Inc()
}
}

Expand Down
8 changes: 8 additions & 0 deletions pkg/cache/clustering/cluster_cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"github.com/unkeyed/unkey/pkg/assert"
"github.com/unkeyed/unkey/pkg/batch"
"github.com/unkeyed/unkey/pkg/cache"
"github.com/unkeyed/unkey/pkg/cache/clustering/metrics"
"github.com/unkeyed/unkey/pkg/logger"
)

Expand Down Expand Up @@ -239,8 +240,11 @@ func (c *ClusterCache[K, V]) Name() string {
// HandleInvalidation processes a cache invalidation event.
// Returns true if the event was handled by this cache.
func (c *ClusterCache[K, V]) HandleInvalidation(ctx context.Context, event *cachev1.CacheInvalidationEvent) bool {
actionLabel := metrics.ActionLabel(event)

// Ignore our own events to avoid loops
if event.GetSourceInstance() == c.nodeID {
metrics.CacheClusteringInvalidationsReceivedTotal.WithLabelValues(c.cacheName, actionLabel, "skipped_self").Inc()
return false
}

Expand All @@ -252,11 +256,13 @@ func (c *ClusterCache[K, V]) HandleInvalidation(ctx context.Context, event *cach
switch event.Action.(type) {
case *cachev1.CacheInvalidationEvent_ClearAll:
c.localCache.Clear(ctx)
metrics.CacheClusteringInvalidationsReceivedTotal.WithLabelValues(c.cacheName, "clear_all", "handled").Inc()
return true

case *cachev1.CacheInvalidationEvent_CacheKey:
key, err := c.stringToKey(event.GetCacheKey())
if err != nil {
metrics.CacheClusteringInvalidationsReceivedTotal.WithLabelValues(c.cacheName, "key", "error").Inc()
logger.Warn(
"Failed to convert cache key",
"cache", c.cacheName,
Expand All @@ -266,9 +272,11 @@ func (c *ClusterCache[K, V]) HandleInvalidation(ctx context.Context, event *cach
return false
}
c.onInvalidation(ctx, key)
metrics.CacheClusteringInvalidationsReceivedTotal.WithLabelValues(c.cacheName, "key", "handled").Inc()
return true

default:
metrics.CacheClusteringInvalidationsReceivedTotal.WithLabelValues(c.cacheName, "unknown", "error").Inc()
logger.Warn("Unknown cache invalidation action", "cache", c.cacheName)
return false
}
Expand Down
3 changes: 3 additions & 0 deletions pkg/cache/clustering/dispatcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (

cachev1 "github.com/unkeyed/unkey/gen/proto/cache/v1"
"github.com/unkeyed/unkey/pkg/assert"
"github.com/unkeyed/unkey/pkg/cache/clustering/metrics"
)

// InvalidationHandler is an interface that cluster caches implement
Expand Down Expand Up @@ -59,6 +60,8 @@ func (d *InvalidationDispatcher) handleEvent(ctx context.Context, event *cachev1

// If we don't have a handler for this cache, skip it
if !exists {
actionLabel := metrics.ActionLabel(event)
metrics.CacheClusteringInvalidationsReceivedTotal.WithLabelValues(event.GetCacheName(), actionLabel, "skipped_unknown").Inc()
return nil
}

Expand Down
13 changes: 13 additions & 0 deletions pkg/cache/clustering/metrics/BUILD.bazel
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
load("@rules_go//go:def.bzl", "go_library")

# Go library target for the cache clustering metrics package. It compiles
# prometheus.go and depends on the generated cache v1 proto package and the
# Prometheus Go client (prometheus and promauto).
go_library(
name = "metrics",
srcs = ["prometheus.go"],
importpath = "github.com/unkeyed/unkey/pkg/cache/clustering/metrics",
visibility = ["//visibility:public"],
deps = [
"//gen/proto/cache/v1:cache",
"@com_github_prometheus_client_golang//prometheus",
"@com_github_prometheus_client_golang//prometheus/promauto",
],
)
56 changes: 56 additions & 0 deletions pkg/cache/clustering/metrics/prometheus.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
package metrics

import (
cachev1 "github.com/unkeyed/unkey/gen/proto/cache/v1"

"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promauto"
)

var (
// CacheClusteringInvalidationsSentTotal counts outbound invalidation events
// by cache name and action type.
CacheClusteringInvalidationsSentTotal = promauto.NewCounterVec(
prometheus.CounterOpts{
Namespace: "unkey",
Subsystem: "cache_clustering",
Name: "invalidations_sent_total",
Help: "Total number of outbound cache invalidation events by cache name and action.",
},
[]string{"cache_name", "action"},
)

// CacheClusteringInvalidationsReceivedTotal counts inbound invalidation events
// by cache name, action, and processing status.
CacheClusteringInvalidationsReceivedTotal = promauto.NewCounterVec(
prometheus.CounterOpts{
Namespace: "unkey",
Subsystem: "cache_clustering",
Name: "invalidations_received_total",
Help: "Total number of inbound cache invalidation events by cache name, action, and status.",
},
[]string{"cache_name", "action", "status"},
)

// CacheClusteringBroadcastErrorsTotal counts failed broadcast attempts.
CacheClusteringBroadcastErrorsTotal = promauto.NewCounter(
prometheus.CounterOpts{
Namespace: "unkey",
Subsystem: "cache_clustering",
Name: "broadcast_errors_total",
Help: "Total number of failed cache invalidation broadcast attempts.",
},
)
)

// ActionLabel returns the metric label value for the action oneof of a
// CacheInvalidationEvent: "key" for a single-key invalidation, "clear_all"
// for a full cache clear, and "unknown" for any other (or unset) action.
//
// The oneof is read through the generated GetAction getter rather than the
// Action field, so a nil event yields "unknown" instead of panicking with a
// nil pointer dereference while recording a metric.
func ActionLabel(event *cachev1.CacheInvalidationEvent) string {
	switch event.GetAction().(type) {
	case *cachev1.CacheInvalidationEvent_CacheKey:
		return "key"
	case *cachev1.CacheInvalidationEvent_ClearAll:
		return "clear_all"
	default:
		return "unknown"
	}
}
2 changes: 2 additions & 0 deletions pkg/cluster/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@ go_library(
visibility = ["//visibility:public"],
deps = [
"//gen/proto/cluster/v1:cluster",
"//pkg/cluster/metrics",
"//pkg/logger",
"//pkg/repeat",
"@com_github_hashicorp_memberlist//:memberlist",
"@org_golang_google_protobuf//proto",
],
Expand Down
8 changes: 8 additions & 0 deletions pkg/cluster/bridge.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"time"

"github.com/hashicorp/memberlist"
"github.com/unkeyed/unkey/pkg/cluster/metrics"
"github.com/unkeyed/unkey/pkg/logger"
)

Expand Down Expand Up @@ -84,6 +85,9 @@ func (c *gossipCluster) promoteToBridge() {
seeds := c.config.WANSeeds
c.mu.Unlock()

metrics.ClusterBridgeStatus.Set(1)
metrics.ClusterBridgeTransitionsTotal.WithLabelValues("promoted").Inc()

// Join WAN seeds outside the lock with retries
if len(seeds) > 0 {
go c.joinSeeds("WAN", func() *memberlist.Memberlist {
Expand Down Expand Up @@ -113,6 +117,10 @@ func (c *gossipCluster) demoteFromBridge() {
c.isBridge = false
c.mu.Unlock()

metrics.ClusterBridgeStatus.Set(0)
metrics.ClusterBridgeTransitionsTotal.WithLabelValues("demoted").Inc()
metrics.ClusterMembersCount.WithLabelValues("wan").Set(0)

// Leave and shutdown outside the lock since Leave can trigger callbacks
if wan != nil {
if err := wan.Leave(5 * time.Second); err != nil {
Expand Down
53 changes: 43 additions & 10 deletions pkg/cluster/cluster.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@ import (

"github.com/hashicorp/memberlist"
clusterv1 "github.com/unkeyed/unkey/gen/proto/cluster/v1"
"github.com/unkeyed/unkey/pkg/cluster/metrics"
"github.com/unkeyed/unkey/pkg/logger"
"github.com/unkeyed/unkey/pkg/repeat"
"google.golang.org/protobuf/proto"
)

Expand Down Expand Up @@ -43,6 +45,9 @@ type gossipCluster struct {
// where memberlist holds its internal state lock.
evalCh chan struct{}
done chan struct{}

// stopMetrics stops the periodic member count gauge updater.
stopMetrics func()
}

// New creates a new cluster node, starts the LAN memberlist, joins LAN seeds,
Expand All @@ -51,16 +56,17 @@ func New(cfg Config) (Cluster, error) {
cfg.setDefaults()

c := &gossipCluster{
config: cfg,
mu: sync.RWMutex{},
lan: nil,
lanQueue: nil,
wan: nil,
wanQueue: nil,
isBridge: false,
closing: atomic.Bool{},
evalCh: make(chan struct{}, 1),
done: make(chan struct{}),
config: cfg,
mu: sync.RWMutex{},
lan: nil,
lanQueue: nil,
wan: nil,
wanQueue: nil,
isBridge: false,
closing: atomic.Bool{},
evalCh: make(chan struct{}, 1),
done: make(chan struct{}),
stopMetrics: nil, // set below
}

// Start the async bridge evaluator
Expand Down Expand Up @@ -101,6 +107,22 @@ func New(cfg Config) (Cluster, error) {
}, cfg.LANSeeds, c.triggerEvalBridge)
}

// Periodically update pool member count gauges. This avoids tracking
// counts inside memberlist callbacks where internal locks are held.
c.stopMetrics = repeat.Every(1*time.Minute, func() {
c.mu.RLock()
lan := c.lan
wan := c.wan
c.mu.RUnlock()

if lan != nil {
metrics.ClusterMembersCount.WithLabelValues("lan").Set(float64(lan.NumMembers()))
}
if wan != nil {
metrics.ClusterMembersCount.WithLabelValues("wan").Set(float64(wan.NumMembers()))
}
})

// Trigger initial bridge evaluation
c.triggerEvalBridge()

Expand All @@ -126,13 +148,15 @@ func (c *gossipCluster) joinSeeds(pool string, list func() *memberlist.Memberlis

_, err := ml.Join(seeds)
if err == nil {
metrics.ClusterSeedJoinAttemptsTotal.WithLabelValues(pool, "success").Inc()
logger.Info("Joined "+pool+" seeds", "seeds", seeds, "attempt", attempt)
if onSuccess != nil {
onSuccess()
}
return
}

metrics.ClusterSeedJoinAttemptsTotal.WithLabelValues(pool, "failure").Inc()
logger.Warn("Failed to join "+pool+" seeds, retrying",
"error", err,
"seeds", seeds,
Expand All @@ -149,6 +173,7 @@ func (c *gossipCluster) joinSeeds(pool string, list func() *memberlist.Memberlis
backoff = min(backoff*2, 10*time.Second)
}

metrics.ClusterSeedJoinAttemptsTotal.WithLabelValues(pool, "exhausted").Inc()
logger.Error("Exhausted retries joining "+pool+" seeds",
"seeds", seeds,
"attempts", maxJoinAttempts,
Expand Down Expand Up @@ -197,18 +222,22 @@ func (c *gossipCluster) Broadcast(payload clusterv1.IsClusterMessage_Payload) er
msg.Direction = clusterv1.Direction_DIRECTION_LAN
lanBytes, err := proto.Marshal(msg)
if err != nil {
metrics.ClusterBroadcastErrorsTotal.WithLabelValues("lan").Inc()
return fmt.Errorf("failed to marshal LAN message: %w", err)
}
lanQ.QueueBroadcast(newBroadcast(lanBytes))
metrics.ClusterBroadcastsTotal.WithLabelValues("lan").Inc()
}

if isBr && wanQ != nil {
msg.Direction = clusterv1.Direction_DIRECTION_WAN
wanBytes, err := proto.Marshal(msg)
if err != nil {
metrics.ClusterBroadcastErrorsTotal.WithLabelValues("wan").Inc()
return fmt.Errorf("failed to marshal WAN message: %w", err)
}
wanQ.QueueBroadcast(newBroadcast(wanBytes))
metrics.ClusterBroadcastsTotal.WithLabelValues("wan").Inc()
}

return nil
Expand Down Expand Up @@ -257,6 +286,10 @@ func (c *gossipCluster) Close() error {
}
close(c.done)

if c.stopMetrics != nil {
c.stopMetrics()
}

// Demote from bridge first (leaves WAN).
c.demoteFromBridge()

Expand Down
22 changes: 22 additions & 0 deletions pkg/cluster/delegate_lan.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
package cluster

import (
"time"

"github.com/hashicorp/memberlist"
clusterv1 "github.com/unkeyed/unkey/gen/proto/cluster/v1"
"github.com/unkeyed/unkey/pkg/cluster/metrics"
"github.com/unkeyed/unkey/pkg/logger"
"google.golang.org/protobuf/proto"
)
Expand Down Expand Up @@ -40,10 +43,25 @@ func (d *lanDelegate) NotifyMsg(data []byte) {

var msg clusterv1.ClusterMessage
if err := proto.Unmarshal(data, &msg); err != nil {
metrics.ClusterMessageUnmarshalErrorsTotal.WithLabelValues("lan").Inc()
logger.Warn("Failed to unmarshal LAN cluster message", "error", err)
return
}

direction := "lan"
if msg.Direction == clusterv1.Direction_DIRECTION_WAN {
direction = "wan"
}
payloadType := metrics.PayloadTypeName(msg.GetPayload())
metrics.ClusterMessagesReceivedTotal.WithLabelValues("lan", direction, payloadType).Inc()

if msg.SentAtMs > 0 {
latency := time.Since(time.UnixMilli(msg.SentAtMs)).Seconds()
if latency >= 0 {
metrics.ClusterMessageLatencySeconds.WithLabelValues(direction, msg.SourceRegion).Observe(latency)
}
}

// Deliver to the application callback
if d.cluster.config.OnMessage != nil {
d.cluster.config.OnMessage(&msg)
Expand All @@ -61,10 +79,12 @@ func (d *lanDelegate) NotifyMsg(data []byte) {
relay.Direction = clusterv1.Direction_DIRECTION_WAN
wanBytes, err := proto.Marshal(relay)
if err != nil {
metrics.ClusterRelayErrorsTotal.WithLabelValues("lan_to_wan").Inc()
logger.Warn("Failed to marshal WAN relay message", "error", err)
return
}
wanQ.QueueBroadcast(newBroadcast(wanBytes))
metrics.ClusterRelaysTotal.WithLabelValues("lan_to_wan").Inc()
}
}
}
Expand All @@ -81,10 +101,12 @@ func newLANEventDelegate(c *gossipCluster) *lanEventDelegate {
}

// NotifyJoin increments ClusterMembershipEventsTotal with the "join" label and
// then calls triggerEvalBridge on the owning cluster. The node argument is not
// used.
func (d *lanEventDelegate) NotifyJoin(node *memberlist.Node) {
metrics.ClusterMembershipEventsTotal.WithLabelValues("join").Inc()
d.cluster.triggerEvalBridge()
}

// NotifyLeave increments ClusterMembershipEventsTotal with the "leave" label
// and then calls triggerEvalBridge on the owning cluster. The node argument is
// not used.
func (d *lanEventDelegate) NotifyLeave(node *memberlist.Node) {
metrics.ClusterMembershipEventsTotal.WithLabelValues("leave").Inc()
d.cluster.triggerEvalBridge()
}

Expand Down
Loading