@@ -20,154 +20,92 @@ import (
2020 "context"
2121 "sync"
2222 "time"
23- "unsafe"
24-
25- "container/list"
2623
24+ lru "github.com/hashicorp/golang-lru/v2"
2725 "sigs.k8s.io/controller-runtime/pkg/log"
2826 "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metrics"
2927 logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging"
3028)
3129
32- func newIndexer (maxCacheSize int ) * indexer {
33- t := & indexer {
34- maxCacheSize : maxCacheSize ,
35- table : make (map [BlockHash ]map [ServerID ]* list.Element ),
36- ll : list .New (),
37- }
38- go t .ReportCacheSize (time .Second )
39- return t
30+ // block holds an LRU cache of servers that may have a specific prefix hash.
31+ type block struct {
32+ Pods * lru.Cache [ServerID , struct {}] // Can be extended with metadata (e.g., timestamp).
4033}
4134
4235// An indexer maintains an LRU cache of prompt prefix hashes and the server(s) that might have that
4336// prefix cached .
4437type indexer struct {
45- mu sync.RWMutex
46- maxCacheSize int
47- table map [ BlockHash ] map [ ServerID ] * list. Element // from any prefix cache to the cache entry to find the server
48- ll * list. List // LinkedList to keep track of the order of entries
38+ mu sync.RWMutex
39+ cache * lru. Cache [ BlockHash , * block ]
40+ maxCacheSize int
41+ maxServersToMatch int
4942}
5043
51- // value is the value stored in the linked list.
52- type value struct {
53- server ServerID
54- hash BlockHash
44+ // newIndexer initializes an indexer with size limits and starts cache size reporting.
45+ func newIndexer (maxCacheSize , maxServersToMatch int ) * indexer {
46+ c , err := lru.New [BlockHash , * block ](maxCacheSize )
47+ if err != nil {
48+ panic (err )
49+ }
50+ ix := & indexer {
51+ cache : c ,
52+ maxCacheSize : maxCacheSize ,
53+ maxServersToMatch : maxServersToMatch ,
54+ }
55+ go ix .ReportCacheSize (time .Second )
56+ return ix
5557}
5658
57- // Get returns the set of servers that have the given prefix hash cached.
58- func (i * indexer ) Get (hash BlockHash ) map [ServerID ]bool {
59- i .mu .RLock ()
60- defer i .mu .RUnlock ()
61- res := map [ServerID ]bool {}
62- for server := range i .table [hash ] {
63- res [server ] = true
59+ // Add adds a list of prefix hashes to the cache, tied to the server.
60+ func (i * indexer ) Add (hashes []BlockHash , pod ServerID ) {
61+ if len (hashes ) == 0 || pod .Name == "" {
62+ return
6463 }
65- return res
66- }
6764
68- // Add adds a list of prefix hashes of a single request to the server the request was sent to.
69- // The intuition is that this server is likely to have the prefix cached, so next time a request
70- // sharing the longest prefix should be sent to the same server to take advantage of the cache hit.
71- func (i * indexer ) Add (hashes []BlockHash , server ServerID ) {
7265 i .mu .Lock ()
7366 defer i .mu .Unlock ()
74- for _ , hash := range hashes {
75- i .add (hash , server )
76- }
77- }
78-
79- func (i * indexer ) check (hash BlockHash , server ServerID ) (* list.Element , bool ) {
80- servers , ok := i .table [hash ]
81- if ! ok {
82- return nil , false
83- }
84- e , ok := servers [server ]
85- return e , ok
86- }
8767
88- func (i * indexer ) add (hash BlockHash , server ServerID ) {
89- e , exists := i .check (hash , server )
90- if exists {
91- i .ll .MoveToBack (e )
92- } else {
93- i .create (hash , server )
68+ for _ , hash := range hashes {
69+ b , ok := i .cache .Get (hash )
70+ if ! ok {
71+ // Create block with new LRU
72+ podLRU , _ := lru.New [ServerID , struct {}](i .maxServersToMatch )
73+ b = & block {Pods : podLRU }
74+ i .cache .Add (hash , b )
75+ }
76+
77+ b .Pods .Add (pod , struct {}{})
9478 }
9579}
9680
97- func (i * indexer ) create (hash BlockHash , server ServerID ) {
98- for i .ll .Len () >= i .maxCacheSize {
99- // Evict the least recently used entry if we've exceeded the max cache size
100- i .evict ()
101- }
102-
103- if _ , ok := i .table [hash ]; ! ok {
104- i .table [hash ] = make (map [ServerID ]* list.Element )
105- }
106- v := & value {
107- server : server ,
108- hash : hash ,
109- }
110- e := i .ll .PushBack (v )
111- i.table [hash ][server ] = e
112- }
81+ // Get returns a set of servers that have the given prefix hash cached.
82+ func (i * indexer ) Get (hash BlockHash ) map [ServerID ]bool {
83+ i .mu .RLock ()
84+ defer i .mu .RUnlock ()
11385
114- // evict removes the least recently used entry from the cache
115- func (i * indexer ) evict () {
116- oldestNode := i .ll .Front ()
117- if oldestNode == nil {
118- return
86+ res := map [ServerID ]bool {}
87+ block , ok := i .cache .Get (hash )
88+ if ! ok {
89+ return res
11990 }
120- i .ll .Remove (oldestNode )
121-
122- v := oldestNode .Value .(* value )
123- hash := v .hash
124- server := v .server
125- // Remove from the hash map
126- serverMap := i .table [hash ]
127- delete (serverMap , server )
128-
129- // If this was the last server for this hash, remove the hash entry entirely
130- if len (serverMap ) == 0 {
131- delete (i .table , hash )
91+ for _ , pod := range block .Pods .Keys () {
92+ res [pod ] = true
13293 }
133-
134- log .FromContext (context .TODO ()).V (logutil .TRACE ).Info ("Evicted LRU entry" , "hash" , hash , "server" , server )
94+ return res
13595}
13696
137- // ReportCacheSize starts a goroutine that periodically reports the cache size metric
97+ // ReportCacheSize starts a goroutine that periodically reports the cache size metric.
13898func (i * indexer ) ReportCacheSize (interval time.Duration ) {
13999 ticker := time .NewTicker (interval )
140100 defer ticker .Stop ()
141101 for range ticker .C {
142102 i .mu .RLock ()
143- metrics .RecordPrefixCacheSize (int64 (i .ll .Len ()))
144- log .FromContext (context .TODO ()).V (logutil .TRACE ).Info ("LRU" , "# entries" , i .ll .Len (), "estimated size MB" , i .ll .Len ()* i .estimateEntrySize ()/ 1000000 )
103+ size := i .cache .Len ()
104+ metrics .RecordPrefixCacheSize (int64 (size ))
105+ log .FromContext (context .TODO ()).V (logutil .TRACE ).Info ("LRU" ,
106+ "# entries" , size ,
107+ "prefix cache utilization [%]" , float64 (size )* 100 / float64 (i .maxCacheSize ),
108+ )
145109 i .mu .RUnlock ()
146110 }
147111}
148-
149- // estimateEntrySize estimates the memory size of a cache entry in bytes.
150- func (i * indexer ) estimateEntrySize () int {
151- size := 0
152-
153- // Estimate the size of a node in the linked list.
154- // First get the size of the node struct via unsafe.Sizeof.
155- // The prev and next pointers are 8 bytes each on a 64-bit system.
156- // The BlockHash is a uint64, which is 8 bytes.
157- // The ServerID is a NamespacedName, which contains two strings (Name and Namespace).
158- // The headers for the strings are 16 bytes each (8 bytes for the pointer and 8 bytes for the length).
159- // So unsafe.Sizeof(node{}) should return 2*8 + 8 + 2*16 = 48 bytes.
160- size += int (unsafe .Sizeof (value {}))
161- // Size of the Name and Namespace strings in ServerID, assuming 63 bytes each (max length for Kubernetes NamespacedName).
162- size += 2 * 63
163-
164- // Estimate the size of an entry in the hash map. Note the overhead of the map headers and buckets are ignored.
165- size += 8 // Size of the BlockHash (uint64).
166- size += 2 * 16 // Size of the ServerID string headers (NamespacedName).
167- size += 2 * 63 // Size of the Name and Namespace strings in ServerID.
168- size += 8 // Size of the pointer to the node in the hash map.
169-
170- // Based on the above estimates, the estimated size of an entry is:
171- // (48 + 2*63) + (8 + 2*16 + 2*63 + 8) = 348 bytes.
172- return size
173- }
0 commit comments