@@ -20,154 +20,130 @@ import (
2020 "context"
2121 "sync"
2222 "time"
23- "unsafe"
24-
25- "container/list"
2623
24+ lru "github.com/hashicorp/golang-lru/v2"
2725 "sigs.k8s.io/controller-runtime/pkg/log"
2826 "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metrics"
2927 logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging"
3028)
3129
-func newIndexer(maxCacheSize int) *indexer {
-	t := &indexer{
-		maxCacheSize: maxCacheSize,
-		table:        make(map[BlockHash]map[ServerID]*list.Element),
-		ll:           list.New(),
-	}
-	go t.ReportCacheSize(time.Second)
-	return t
-}
-
 // An indexer maintains an LRU cache of prompt prefix hashes and the server(s) that might have that
-// prefix cached .
+// prefix cached.
 type indexer struct {
-	mu           sync.RWMutex
-	maxCacheSize int
-	table        map[BlockHash]map[ServerID]*list.Element // from any prefix cache to the cache entry to find the server
-	ll           *list.List                                // LinkedList to keep track of the order of entries
+	mu         sync.RWMutex
+	hashToPods map[BlockHash]podSet                         // the lookup data structure to find pods that have the BlockHash cached
+	podToLRU   map[ServerID]*lru.Cache[BlockHash, struct{}] // key is pod namespacedName, value is an LRU cache
+	maxLRUSize int
 }
 
-// value is the value stored in the linked list.
-type value struct {
-	server ServerID
-	hash   BlockHash
-}
-
-// Get returns the set of servers that have the given prefix hash cached.
-func (i *indexer) Get(hash BlockHash) map[ServerID]bool {
-	i.mu.RLock()
-	defer i.mu.RUnlock()
-	res := map[ServerID]bool{}
-	for server := range i.table[hash] {
-		res[server] = true
+// newIndexer initializes an indexer with size limits and starts cache size reporting.
+func newIndexer(maxLRUSize int) *indexer {
+	ix := &indexer{
+		hashToPods: make(map[BlockHash]podSet),
+		podToLRU:   make(map[ServerID]*lru.Cache[BlockHash, struct{}]),
+		maxLRUSize: maxLRUSize,
 	}
-	return res
+
+	go ix.ReportLRUSize(time.Second)
+	return ix
 }
 
-// Add adds a list of prefix hashes of a single request to the server the request was sent to.
-// The intuition is that this server is likely to have the prefix cached, so next time a request
-// sharing the longest prefix should be sent to the same server to take advantage of the cache hit.
-func (i *indexer) Add(hashes []BlockHash, server ServerID) {
+// Add adds a list of prefix hashes to the cache, tied to the server.
+func (i *indexer) Add(hashes []BlockHash, pod ServerID) {
 	i.mu.Lock()
-	defer i.mu.Unlock()
-	for _, hash := range hashes {
-		i.add(hash, server)
+	// Check if an LRU cache already exists for this pod
+	lruForPod, exists := i.podToLRU[pod]
+	if !exists {
+		newLRU, _ := lru.NewWithEvict[BlockHash, struct{}](i.maxLRUSize, i.makeEvictionFn(pod))
+		i.podToLRU[pod] = newLRU
+		lruForPod = newLRU
 	}
-}
 
-func (i *indexer) check(hash BlockHash, server ServerID) (*list.Element, bool) {
-	servers, ok := i.table[hash]
-	if !ok {
-		return nil, false
+	i.mu.Unlock()
+
+	// Add to LRU (may evict)
+	for _, hash := range hashes {
+		lruForPod.Add(hash, struct{}{})
 	}
-	e, ok := servers[server]
-	return e, ok
-}
 
-func (i *indexer) add(hash BlockHash, server ServerID) {
-	e, exists := i.check(hash, server)
-	if exists {
-		i.ll.MoveToBack(e)
-	} else {
-		i.create(hash, server)
+	// Update hashToPods once under lock
+	i.mu.Lock()
+	for _, hash := range hashes {
+		pods := i.hashToPods[hash]
+		if pods == nil {
+			pods = make(podSet)
+		}
+		pods[pod] = struct{}{}
+		i.hashToPods[hash] = pods
 	}
+
+	i.mu.Unlock()
 }
 
-func (i *indexer) create(hash BlockHash, server ServerID) {
-	for i.ll.Len() >= i.maxCacheSize {
-		// Evict the least recently used entry if we've exceeded the max cache size
-		i.evict()
-	}
+// Get returns a set of servers that have the given prefix hash cached.
+func (i *indexer) Get(hash BlockHash) podSet {
+	i.mu.RLock()
+	defer i.mu.RUnlock()
 
-	if _, ok := i.table[hash]; !ok {
-		i.table[hash] = make(map[ServerID]*list.Element)
-	}
-	v := &value{
-		server: server,
-		hash:   hash,
+	res := podSet{}
+	pods, ok := i.hashToPods[hash]
+	if !ok {
+		return res
 	}
-	e := i.ll.PushBack(v)
-	i.table[hash][server] = e
+
+	return pods
 }
 
-// evict removes the least recently used entry from the cache
-func (i *indexer) evict() {
-	oldestNode := i.ll.Front()
-	if oldestNode == nil {
-		return
+// makeEvictionFn returns a per-pod LRU eviction callback that removes the pod from hashToPods on eviction.
+func (i *indexer) makeEvictionFn(pod ServerID) func(BlockHash, struct{}) {
+	return func(hash BlockHash, _ struct{}) {
+		i.mu.Lock()
+		defer i.mu.Unlock()
+		// Remove the pod from the hash→pods map
+		if podSet, ok := i.hashToPods[hash]; ok {
+			delete(podSet, pod)
+			if len(podSet) == 0 {
+				delete(i.hashToPods, hash)
+			}
+		}
 	}
-	i.ll.Remove(oldestNode)
-
-	v := oldestNode.Value.(*value)
-	hash := v.hash
-	server := v.server
-	// Remove from the hash map
-	serverMap := i.table[hash]
-	delete(serverMap, server)
-
-	// If this was the last server for this hash, remove the hash entry entirely
-	if len(serverMap) == 0 {
-		delete(i.table, hash)
-	}
-
-	log.FromContext(context.TODO()).V(logutil.TRACE).Info("Evicted LRU entry", "hash", hash, "server", server)
 }
 
-// ReportCacheSize starts a goroutine that periodically reports the cache size metric
-func (i *indexer) ReportCacheSize(interval time.Duration) {
+// ReportLRUSize starts a goroutine that periodically reports the LRU cache size metric.
+func (i *indexer) ReportLRUSize(interval time.Duration) {
 	ticker := time.NewTicker(interval)
 	defer ticker.Stop()
 	for range ticker.C {
 		i.mu.RLock()
-		metrics.RecordPrefixCacheSize(int64(i.ll.Len()))
-		log.FromContext(context.TODO()).V(logutil.TRACE).Info("LRU", "# entries", i.ll.Len(), "estimated size MB", i.ll.Len()*i.estimateEntrySize()/1000000)
+		totalEntries := 0
+		maxPodEntries := 0
+		maxPodName := ServerID{}
+
+		for pod, lruCache := range i.podToLRU {
+			size := lruCache.Len()
+			totalEntries += size
+			if size > maxPodEntries {
+				maxPodEntries = size
+				maxPodName = pod
+			}
+		}
+
+		numPods := len(i.podToLRU)
+		avg := 0.0
+		if numPods > 0 {
+			avg = float64(totalEntries) / float64(numPods)
+		}
+
+		metrics.RecordPrefixCacheSize(int64(totalEntries))
+		log.FromContext(context.TODO()).V(logutil.TRACE).Info("Prefix cache state",
+			"total entries", totalEntries,
+			"# pods", numPods,
+			"avg entries per pod", avg,
+			"pod with max cache", maxPodName,
+			"max pod size", maxPodEntries,
+			"global max LRU cache capacity per pod", i.maxLRUSize,
+		)
+
 		i.mu.RUnlock()
 	}
 }
-
-// estimateEntrySize estimates the memory size of a cache entry in bytes.
-func (i *indexer) estimateEntrySize() int {
-	size := 0
-
-	// Estimate the size of a node in the linked list.
-	// First get the size of the node struct via unsafe.Sizeof.
-	// The prev and next pointers are 8 bytes each on a 64-bit system.
-	// The BlockHash is a uint64, which is 8 bytes.
-	// The ServerID is a NamespacedName, which contains two strings (Name and Namespace).
-	// The headers for the strings are 16 bytes each (8 bytes for the pointer and 8 bytes for the length).
-	// So unsafe.Sizeof(node{}) should return 2*8 + 8 + 2*16 = 48 bytes.
-	size += int(unsafe.Sizeof(value{}))
-	// Size of the Name and Namespace strings in ServerID, assuming 63 bytes each (max length for Kubernetes NamespacedName).
-	size += 2 * 63
-
-	// Estimate the size of an entry in the hash map. Note the overhead of the map headers and buckets are ignored.
-	size += 8      // Size of the BlockHash (uint64).
-	size += 2 * 16 // Size of the ServerID string headers (NamespacedName).
-	size += 2 * 63 // Size of the Name and Namespace strings in ServerID.
-	size += 8      // Size of the pointer to the node in the hash map.
-
-	// Based on the above estimates, the estimated size of an entry is:
-	// (48 + 2*63) + (8 + 2*16 + 2*63 + 8) = 348 bytes.
-	return size
-}
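
For readers less familiar with the new shape of the cache, here is a minimal, self-contained sketch of the same pattern: one bounded `golang-lru/v2` cache per pod plus a `hashToPods` reverse index, kept consistent by the per-pod eviction callback. The names (`blockHash`, `serverID`, `newIndex`, `add`) are simplified stand-ins for illustration, not the actual identifiers in this package. Note the design choice carried over from the patch: the indexer mutex is released before calling `Add` on a pod's LRU, because the eviction callback re-acquires that mutex and would otherwise deadlock.

```go
package main

import (
	"fmt"
	"sync"

	lru "github.com/hashicorp/golang-lru/v2"
)

type blockHash uint64

type serverID string

type podSet map[serverID]struct{}

// index mirrors the indexer above: one bounded LRU per pod plus a reverse
// index from hash to the pods that hold it.
type index struct {
	mu         sync.RWMutex
	hashToPods map[blockHash]podSet
	podToLRU   map[serverID]*lru.Cache[blockHash, struct{}]
	maxLRUSize int
}

func newIndex(maxLRUSize int) *index {
	return &index{
		hashToPods: make(map[blockHash]podSet),
		podToLRU:   make(map[serverID]*lru.Cache[blockHash, struct{}]),
		maxLRUSize: maxLRUSize,
	}
}

// add records that pod has the given hashes cached. The mutex is released
// before touching the per-pod LRU, because golang-lru invokes the eviction
// callback synchronously from Add and the callback takes the same mutex.
func (ix *index) add(hashes []blockHash, pod serverID) {
	ix.mu.Lock()
	c, ok := ix.podToLRU[pod]
	if !ok {
		// Eviction callback: keep the reverse index consistent when a hash
		// falls out of this pod's LRU.
		c, _ = lru.NewWithEvict[blockHash, struct{}](ix.maxLRUSize, func(h blockHash, _ struct{}) {
			ix.mu.Lock()
			defer ix.mu.Unlock()
			if pods, ok := ix.hashToPods[h]; ok {
				delete(pods, pod)
				if len(pods) == 0 {
					delete(ix.hashToPods, h)
				}
			}
		})
		ix.podToLRU[pod] = c
	}
	ix.mu.Unlock()

	for _, h := range hashes {
		c.Add(h, struct{}{}) // may evict this pod's least recently used hash
	}

	ix.mu.Lock()
	for _, h := range hashes {
		pods := ix.hashToPods[h]
		if pods == nil {
			pods = podSet{}
			ix.hashToPods[h] = pods
		}
		pods[pod] = struct{}{}
	}
	ix.mu.Unlock()
}

func main() {
	ix := newIndex(2)                  // tiny capacity to force eviction
	ix.add([]blockHash{1, 2}, "pod-a") // pod-a's LRU now holds {1, 2}
	ix.add([]blockHash{3}, "pod-a")    // evicts 1; callback removes it from hashToPods
	fmt.Println(ix.hashToPods)         // map[2:map[pod-a:{}] 3:map[pod-a:{}]]
}
```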
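If you want to try the sketch locally, the only external dependency is the same one the patch introduces, and it resolves with standard Go tooling, for example `go get github.com/hashicorp/golang-lru/v2`.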