refactor: Replace prefix cache structure with golang-lru #928
```diff
@@ -18,156 +18,135 @@ package prefix
 
 import (
 	"context"
+	"fmt"
 	"sync"
 	"time"
-	"unsafe"
-
-	"container/list"
 
+	lru "github.com/hashicorp/golang-lru/v2"
 	"sigs.k8s.io/controller-runtime/pkg/log"
 	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metrics"
 	logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging"
 )
 
-func newIndexer(maxCacheSize int) *indexer {
-	t := &indexer{
-		maxCacheSize: maxCacheSize,
-		table:        make(map[BlockHash]map[ServerID]*list.Element),
-		ll:           list.New(),
-	}
-	go t.ReportCacheSize(time.Second)
-	return t
-}
-
 // An indexer maintains an LRU cache of prompt prefix hashes and the server(s) that might have that
-// prefix cached .
+// prefix cached.
 type indexer struct {
-	mu           sync.RWMutex
-	maxCacheSize int
-	table        map[BlockHash]map[ServerID]*list.Element // from any prefix cache to the cache entry to find the server
-	ll           *list.List                                // LinkedList to keep track of the order of entries
+	mu         sync.RWMutex
+	hashToPods map[BlockHash]podSet                         // the lookup data structure to find pods that have the BlockHash cached
+	podToLRU   map[ServerID]*lru.Cache[BlockHash, struct{}] // key is pod namespacedName, value is an LRU cache
+	maxLRUSize int
 }
 
-// value is the value stored in the linked list.
-type value struct {
-	server ServerID
-	hash   BlockHash
-}
-
-// Get returns the set of servers that have the given prefix hash cached.
-func (i *indexer) Get(hash BlockHash) map[ServerID]bool {
-	i.mu.RLock()
-	defer i.mu.RUnlock()
-	res := map[ServerID]bool{}
-	for server := range i.table[hash] {
-		res[server] = true
+// newIndexer initializes an indexer with size limits and starts cache size reporting.
+func newIndexer(maxLRUSize int) *indexer {
+	ix := &indexer{
+		hashToPods: make(map[BlockHash]podSet),
+		podToLRU:   make(map[ServerID]*lru.Cache[BlockHash, struct{}]),
+		maxLRUSize: maxLRUSize,
 	}
-	return res
+	go ix.ReportLRUSize(time.Second)
+	return ix
 }
 
-// Add adds a list of prefix hashes of a single request to the server the request was sent to.
-// The intuition is that this server is likely to have the prefix cached, so next time a request
-// sharing the longest prefix should be sent to the same server to take advantage of the cache hit.
-func (i *indexer) Add(hashes []BlockHash, server ServerID) {
+// Add adds a list of prefix hashes to the cache, tied to the server.
+func (i *indexer) Add(hashes []BlockHash, pod ServerID) error {
 	i.mu.Lock()
-	defer i.mu.Unlock()
-	for _, hash := range hashes {
-		i.add(hash, server)
+	// Check if the LRU pod exist
+	lruForPod, exists := i.podToLRU[pod]
+	if !exists {
+		newLRU, err := lru.NewWithEvict[BlockHash, struct{}](i.maxLRUSize, i.makeEvictionFn(pod))
+		if err != nil {
+			i.mu.Unlock()
+			return fmt.Errorf("failed to create LRU for pod %s: %w", pod, err)
+		}
+		i.podToLRU[pod] = newLRU
+		lruForPod = newLRU
 	}
-}
+	i.mu.Unlock()
 
-func (i *indexer) check(hash BlockHash, server ServerID) (*list.Element, bool) {
-	servers, ok := i.table[hash]
-	if !ok {
-		return nil, false
+	// Add to LRU (may evict)
+	for _, hash := range hashes {
+		lruForPod.Add(hash, struct{}{})
 	}
-	e, ok := servers[server]
-	return e, ok
-}
 
-func (i *indexer) add(hash BlockHash, server ServerID) {
-	e, exists := i.check(hash, server)
-	if exists {
-		i.ll.MoveToBack(e)
-	} else {
-		i.create(hash, server)
+	// Update hashToPods once under lock
+	i.mu.Lock()
+	for _, hash := range hashes {
+		pods := i.hashToPods[hash]
+		if pods == nil {
+			pods = make(podSet)
+		}
+		pods[pod] = struct{}{}
+		i.hashToPods[hash] = pods
 	}
+	i.mu.Unlock()
+	return nil
 }
 
-func (i *indexer) create(hash BlockHash, server ServerID) {
-	for i.ll.Len() >= i.maxCacheSize {
-		// Evict the least recently used entry if we've exceeded the max cache size
-		i.evict()
-	}
-
-	if _, ok := i.table[hash]; !ok {
-		i.table[hash] = make(map[ServerID]*list.Element)
-	}
-	v := &value{
-		server: server,
-		hash:   hash,
+// Get returns a set of servers that have the given prefix hash cached.
+func (i *indexer) Get(hash BlockHash) podSet {
+	i.mu.RLock()
+	defer i.mu.RUnlock()
+
+	res := podSet{}
+	pods, ok := i.hashToPods[hash]
+	if !ok {
+		return res
 	}
-	e := i.ll.PushBack(v)
-	i.table[hash][server] = e
+
+	return pods
 }
 
-// evict removes the least recently used entry from the cache
-func (i *indexer) evict() {
-	oldestNode := i.ll.Front()
-	if oldestNode == nil {
-		return
-	}
-	i.ll.Remove(oldestNode)
-
-	v := oldestNode.Value.(*value)
-	hash := v.hash
-	server := v.server
-	// Remove from the hash map
-	serverMap := i.table[hash]
-	delete(serverMap, server)
-
-	// If this was the last server for this hash, remove the hash entry entirely
-	if len(serverMap) == 0 {
-		delete(i.table, hash)
+// makeEvictionFn returns a per-pod LRU eviction callback that removes the pod from hashToPods on eviction.
+func (i *indexer) makeEvictionFn(pod ServerID) func(BlockHash, struct{}) {
+	return func(hash BlockHash, _ struct{}) {
+		i.mu.Lock()
+		defer i.mu.Unlock()
+		// Remove the pod from the hash→pods map
+		if podSet, ok := i.hashToPods[hash]; ok {
+			delete(podSet, pod)
+			if len(podSet) == 0 {
+				delete(i.hashToPods, hash)
+			}
+		}
 	}
-
-	log.FromContext(context.TODO()).V(logutil.TRACE).Info("Evicted LRU entry", "hash", hash, "server", server)
 }
 
-// ReportCacheSize starts a goroutine that periodically reports the cache size metric
-func (i *indexer) ReportCacheSize(interval time.Duration) {
+// ReportLRUSize starts a goroutine that periodically reports the LRU cache size metric.
+func (i *indexer) ReportLRUSize(interval time.Duration) {
 	ticker := time.NewTicker(interval)
 	defer ticker.Stop()
 	for range ticker.C {
 		i.mu.RLock()
-		metrics.RecordPrefixCacheSize(int64(i.ll.Len()))
-		log.FromContext(context.TODO()).V(logutil.TRACE).Info("LRU", "# entries", i.ll.Len(), "estimated size MB", i.ll.Len()*i.estimateEntrySize()/1000000)
+		totalEntries := 0
+		maxPodEntries := 0
+		maxPodName := ServerID{}
+
+		for pod, lruCache := range i.podToLRU {
+			size := lruCache.Len()
+			totalEntries += size
+			if size > maxPodEntries {
+				maxPodEntries = size
+				maxPodName = pod
+			}
+		}
+
+		numPods := len(i.podToLRU)
+		avg := 0.0
+		if numPods > 0 {
+			avg = float64(totalEntries) / float64(numPods)
+		}
+
+		metrics.RecordPrefixCacheSize(int64(totalEntries))
+		log.FromContext(context.TODO()).V(logutil.TRACE).Info("Prefix cache state",
+			"total entries", totalEntries,
+			"# pods", numPods,
+			"avg entries per pod", avg,
+			"pod with max cache", maxPodName,
+			"max pod size", maxPodEntries,
+			"global max LRU cache capacity per pod", i.maxLRUSize,
+		)
+
 		i.mu.RUnlock()
 	}
 }
-
-// estimateEntrySize estimates the memory size of a cache entry in bytes.
-func (i *indexer) estimateEntrySize() int {
-	size := 0
-
-	// Estimate the size of a node in the linked list.
-	// First get the size of the node struct via unsafe.Sizeof.
-	// The prev and next pointers are 8 bytes each on a 64-bit system.
-	// The BlockHash is a uint64, which is 8 bytes.
-	// The ServerID is a NamespacedName, which contains two strings (Name and Namespace).
-	// The headers for the strings are 16 bytes each (8 bytes for the pointer and 8 bytes for the length).
-	// So unsafe.Sizeof(node{}) should return 2*8 + 8 + 2*16 = 48 bytes.
-	size += int(unsafe.Sizeof(value{}))
-	// Size of the Name and Namespace strings in ServerID, assuming 63 bytes each (max length for Kubernetes NamespacedName).
-	size += 2 * 63
-
-	// Estimate the size of an entry in the hash map. Note the overhead of the map headers and buckets are ignored.
-	size += 8      // Size of the BlockHash (uint64).
-	size += 2 * 16 // Size of the ServerID string headers (NamespacedName).
-	size += 2 * 63 // Size of the Name and Namespace strings in ServerID.
-	size += 8      // Size of the pointer to the node in the hash map.
-
-	// Based on the above estimates, the estimated size of an entry is:
-	// (48 + 2*63) + (8 + 2*16 + 2*63 + 8) = 348 bytes.
-	return size
-}
```
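For readers who have not used hashicorp/golang-lru before, here is a minimal, self-contained sketch of the pattern the diff above relies on: a fixed-capacity generic cache created with `lru.NewWithEvict`, whose eviction callback keeps a secondary hash-to-pods index in sync. The types `blockHash` and `serverID` and the pod name are stand-ins for illustration, not the repo's `BlockHash`/`ServerID`.

```go
package main

import (
	"fmt"
	"sync"

	lru "github.com/hashicorp/golang-lru/v2"
)

// Stand-ins for the repo's BlockHash and ServerID types.
type blockHash uint64
type serverID string

func main() {
	var (
		mu         sync.Mutex
		hashToPods = map[blockHash]map[serverID]struct{}{}
		pod        = serverID("default/vllm-0") // hypothetical pod name
	)

	// Per-pod LRU with capacity 2; the eviction callback removes the pod from
	// the secondary index so the two structures stay consistent.
	cache, err := lru.NewWithEvict[blockHash, struct{}](2, func(h blockHash, _ struct{}) {
		mu.Lock()
		defer mu.Unlock()
		if pods, ok := hashToPods[h]; ok {
			delete(pods, pod)
			if len(pods) == 0 {
				delete(hashToPods, h)
			}
		}
	})
	if err != nil {
		panic(err)
	}

	// Insert three hashes; the third Add exceeds capacity, evicts hash 1, and
	// fires the callback. The index mutex is not held around cache.Add,
	// mirroring how the PR's Add releases the lock before touching the LRU.
	for _, h := range []blockHash{1, 2, 3} {
		cache.Add(h, struct{}{})
		mu.Lock()
		if hashToPods[h] == nil {
			hashToPods[h] = map[serverID]struct{}{}
		}
		hashToPods[h][pod] = struct{}{}
		mu.Unlock()
	}

	mu.Lock()
	_, stillIndexed := hashToPods[1]
	mu.Unlock()
	fmt.Println("LRU entries:", cache.Len())           // 2
	fmt.Println("hash 1 still indexed:", stillIndexed) // false: evicted, callback cleaned it up
}
```

Dropping the index lock before calling `Add` matters because the eviction callback takes the same lock; holding it across `Add` would deadlock.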
nit: another optimization is to use different mutexes for hashToPods and podToLRU, but I don't think it's very important.
Also, I used the mutex only for the hashToPods operations, except in ReportLRUSize, which we can remove if it hurts performance.
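If that nit were picked up, a rough sketch of what it could look like follows. This is not part of the PR: the type name `indexerSplitLocks` and the fields `hashMu`/`podMu` are hypothetical, and it assumes the `BlockHash`, `ServerID`, `podSet`, and `lru` identifiers already used in the diff above.

```go
// Hypothetical variant illustrating the reviewer's nit; not part of this PR.
// Each map gets its own lock, so eviction callbacks that touch hashToPods
// never contend with podToLRU bookkeeping.
type indexerSplitLocks struct {
	hashMu     sync.RWMutex // guards hashToPods only
	hashToPods map[BlockHash]podSet

	podMu    sync.RWMutex // guards podToLRU only
	podToLRU map[ServerID]*lru.Cache[BlockHash, struct{}]

	maxLRUSize int
}

// Get takes only the lock covering hashToPods, so it is unaffected by
// per-pod LRU creation or a stats walk over podToLRU.
func (i *indexerSplitLocks) Get(hash BlockHash) podSet {
	i.hashMu.RLock()
	defer i.hashMu.RUnlock()
	return i.hashToPods[hash]
}
```

Under that split, ReportLRUSize would only need the lock covering podToLRU for its Len() walk.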