Merged
Changes from 4 commits
2 changes: 1 addition & 1 deletion cmd/epp/runner/runner.go
@@ -318,7 +318,7 @@ func loadPrefixCacheConfig() prefix.Config {
return prefix.Config{
HashBlockSize: envutil.GetEnvInt("PREFIX_CACHE_HASH_BLOCK_SIZE", prefix.DefaultHashBlockSize, baseLogger),
MaxPrefixBlocksToMatch: envutil.GetEnvInt("PREFIX_CACHE_MAX_PREFIX_BLOCKS", prefix.DefaultMaxPrefixBlocks, baseLogger),
LRUIndexerCapacity: envutil.GetEnvInt("PREFIX_CACHE_LRU_CAPACITY", prefix.DefaultLRUIndexerCapacity, baseLogger),
LRUCapacityPerServer: envutil.GetEnvInt("PREFIX_CACHE_LRU_CAPACITY_PER_SERVER", prefix.DefaultLRUCapacityPerServer, baseLogger),
}
}
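
The functional change here is only the environment variable rename. Below is a minimal standalone sketch of the env-int-with-default pattern that envutil.GetEnvInt presumably follows (the helper's exact behavior, and the default value used below, are assumptions); the practical implication is that deployments must now set PREFIX_CACHE_LRU_CAPACITY_PER_SERVER instead of PREFIX_CACHE_LRU_CAPACITY.

```go
// Illustrative sketch only; not the project's envutil implementation.
package main

import (
	"fmt"
	"os"
	"strconv"
)

// getEnvInt falls back to the default when the variable is unset or unparsable.
func getEnvInt(name string, defaultVal int) int {
	raw, ok := os.LookupEnv(name)
	if !ok {
		return defaultVal
	}
	v, err := strconv.Atoi(raw)
	if err != nil {
		return defaultVal
	}
	return v
}

func main() {
	// 1000 is a placeholder default, not the real DefaultLRUCapacityPerServer.
	fmt.Println(getEnvInt("PREFIX_CACHE_LRU_CAPACITY_PER_SERVER", 1000))
}
```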

1 change: 1 addition & 0 deletions go.mod
@@ -9,6 +9,7 @@ require (
github.com/go-logr/logr v1.4.3
github.com/google/go-cmp v0.7.0
github.com/google/uuid v1.6.0
github.com/hashicorp/golang-lru/v2 v2.0.7
github.com/onsi/ginkgo/v2 v2.23.4
github.com/onsi/gomega v1.37.0
github.com/prometheus/client_golang v1.22.0
2 changes: 2 additions & 0 deletions go.sum
@@ -95,6 +95,8 @@ github.com/gorilla/websocket v1.5.4-0.20250319132907-e064f32e3674 h1:JeSE6pjso5T
github.com/gorilla/websocket v1.5.4-0.20250319132907-e064f32e3674/go.mod h1:r4w70xmWCQKmi1ONH4KIaBptdivuRPyosB9RmPlGEwA=
github.com/grpc-ecosystem/grpc-gateway/v2 v2.24.0 h1:TmHmbvxPmaegwhDubVz0lICL0J5Ka2vwTzhoePEXsGE=
github.com/grpc-ecosystem/grpc-gateway/v2 v2.24.0/go.mod h1:qztMSjm835F2bXf+5HKAPIS5qsmQDqZna/PgVt4rWtI=
github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs4luLUK2k=
github.com/hashicorp/golang-lru/v2 v2.0.7/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM=
github.com/huandu/xstrings v1.3.3 h1:/Gcsuc1x8JVbJ9/rlye4xZnVAbEkGauT8lbebqcQws4=
github.com/huandu/xstrings v1.3.3/go.mod h1:y5/lhBue+AyNmUVz9RLU9xbLR0o4KIIExikq4ovT0aE=
github.com/imdario/mergo v0.3.16 h1:wwQJbIsHYGMUyLSPrEq1CT16AhnhNJQ51+4fdHUnCl4=
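
The new dependency, hashicorp/golang-lru/v2, provides the generic fixed-size LRU with an eviction callback that the indexer changes below rely on. Here is a small self-contained sketch of that API; the keys, values, and size are arbitrary examples, not values from this PR.

```go
package main

import (
	"fmt"

	lru "github.com/hashicorp/golang-lru/v2"
)

func main() {
	// The callback fires whenever an entry is evicted to make room for a newer one.
	cache, err := lru.NewWithEvict[string, int](2, func(key string, value int) {
		fmt.Printf("evicted %s=%d\n", key, value)
	})
	if err != nil {
		panic(err) // only possible when the requested size is <= 0
	}

	cache.Add("a", 1)
	cache.Add("b", 2)
	cache.Add("c", 3) // evicts "a", the least recently used entry

	_, ok := cache.Get("a")
	fmt.Println(ok, cache.Len()) // false 2
}
```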
211 changes: 95 additions & 116 deletions pkg/epp/scheduling/framework/plugins/multi/prefix/indexer.go
@@ -18,156 +18,135 @@ package prefix

import (
"context"
"fmt"
"sync"
"time"
"unsafe"

"container/list"

lru "github.com/hashicorp/golang-lru/v2"
"sigs.k8s.io/controller-runtime/pkg/log"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metrics"
logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging"
)

func newIndexer(maxCacheSize int) *indexer {
t := &indexer{
maxCacheSize: maxCacheSize,
table: make(map[BlockHash]map[ServerID]*list.Element),
ll: list.New(),
}
go t.ReportCacheSize(time.Second)
return t
}

// An indexer maintains an LRU cache of prompt prefix hashes and the server(s) that might have that
// prefix cached .
// prefix cached.
type indexer struct {
mu sync.RWMutex
maxCacheSize int
table map[BlockHash]map[ServerID]*list.Element // from any prefix cache to the cache entry to find the server
ll *list.List // LinkedList to keep track of the order of entries
mu sync.RWMutex
Contributor: nit: another optimization is to use different mutexes for hashToPods and podToLRU, but I don't think it's very important.

Author: Also, I used it only for the hashToPods operations, except in ReportLRUSize, which we can remove if it hurts performance.

hashToPods map[BlockHash]podSet // the lookup data structure to find pods that have the BlockHash cached
podToLRU map[ServerID]*lru.Cache[BlockHash, struct{}] // key is pod namespacedName, value is an LRU cache
maxLRUSize int
}
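
A minimal sketch of the reviewer's nit above: guard hashToPods and podToLRU with separate locks so hash lookups don't contend with per-pod LRU bookkeeping. This is illustrative only and not what the PR implements; it assumes the package's BlockHash, ServerID, podSet, and lru types, and the splitLockIndexer name is hypothetical.

```go
// Hypothetical alternative layout; not part of the PR.
type splitLockIndexer struct {
	hashMu     sync.RWMutex // guards hashToPods only
	hashToPods map[BlockHash]podSet

	lruMu    sync.Mutex // guards podToLRU only
	podToLRU map[ServerID]*lru.Cache[BlockHash, struct{}]

	maxLRUSize int
}

// Get takes only the hashToPods lock, so LRU updates for other pods never block it.
func (s *splitLockIndexer) Get(hash BlockHash) podSet {
	s.hashMu.RLock()
	defer s.hashMu.RUnlock()

	// Copy the set so callers never read the internal map without the lock held.
	res := make(podSet, len(s.hashToPods[hash]))
	for pod := range s.hashToPods[hash] {
		res[pod] = struct{}{}
	}
	return res
}
```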

// value is the value stored in the linked list.
type value struct {
server ServerID
hash BlockHash
}

// Get returns the set of servers that have the given prefix hash cached.
func (i *indexer) Get(hash BlockHash) map[ServerID]bool {
i.mu.RLock()
defer i.mu.RUnlock()
res := map[ServerID]bool{}
for server := range i.table[hash] {
res[server] = true
// newIndexer initializes an indexer with size limits and starts cache size reporting.
func newIndexer(maxLRUSize int) *indexer {
ix := &indexer{
hashToPods: make(map[BlockHash]podSet),
podToLRU: make(map[ServerID]*lru.Cache[BlockHash, struct{}]),
maxLRUSize: maxLRUSize,
}
return res
go ix.ReportLRUSize(time.Second)
return ix
}

// Add adds a list of prefix hashes of a single request to the server the request was sent to.
// The intuition is that this server is likely to have the prefix cached, so next time a request
// sharing the longest prefix should be sent to the same server to take advantage of the cache hit.
func (i *indexer) Add(hashes []BlockHash, server ServerID) {
// Add adds a list of prefix hashes to the cache, tied to the server.
func (i *indexer) Add(hashes []BlockHash, pod ServerID) error {
i.mu.Lock()
defer i.mu.Unlock()
for _, hash := range hashes {
i.add(hash, server)
// Check if an LRU already exists for this pod
lruForPod, exists := i.podToLRU[pod]
if !exists {
newLRU, err := lru.NewWithEvict[BlockHash, struct{}](i.maxLRUSize, i.makeEvictionFn(pod))
Contributor: I read the lru code, and the only reason this could fail is if the LRU size is <= 0... which IMO we can simply handle when initializing the indexer by setting maxLRUSize = max(maxLRUSize, 1). We can then safely add a comment and ignore this error.

The way you handle the error is OK, but it adds some complexity for the reader; one might think: if there is an error, do I end up with an inaccurate score?

if err != nil {
i.mu.Unlock()
return fmt.Errorf("failed to create LRU for pod %s: %w", pod, err)
}
i.podToLRU[pod] = newLRU
lruForPod = newLRU
}
}
i.mu.Unlock()

func (i *indexer) check(hash BlockHash, server ServerID) (*list.Element, bool) {
servers, ok := i.table[hash]
if !ok {
return nil, false
// Add to LRU (may evict)
for _, hash := range hashes {
lruForPod.Add(hash, struct{}{})
}
e, ok := servers[server]
return e, ok
}

func (i *indexer) add(hash BlockHash, server ServerID) {
e, exists := i.check(hash, server)
if exists {
i.ll.MoveToBack(e)
} else {
i.create(hash, server)
// Update hashToPods once under lock
i.mu.Lock()
for _, hash := range hashes {
pods := i.hashToPods[hash]
if pods == nil {
pods = make(podSet)
}
pods[pod] = struct{}{}
i.hashToPods[hash] = pods
}
i.mu.Unlock()
return nil
}
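
A minimal sketch of the reviewer's suggestion above: clamp the capacity once in the constructor so lru.NewWithEvict, which only errors on a non-positive size, can never fail, and Add could then drop its error return. Illustrative only; the PR keeps the error return instead, and the newIndexerClamped name is hypothetical.

```go
// Hypothetical variant of newIndexer; not part of the PR.
func newIndexerClamped(maxLRUSize int) *indexer {
	if maxLRUSize < 1 {
		// lru.NewWithEvict only returns an error when size <= 0, so clamping
		// here would let Add ignore that error path entirely.
		maxLRUSize = 1
	}
	ix := &indexer{
		hashToPods: make(map[BlockHash]podSet),
		podToLRU:   make(map[ServerID]*lru.Cache[BlockHash, struct{}]),
		maxLRUSize: maxLRUSize,
	}
	go ix.ReportLRUSize(time.Second)
	return ix
}
```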

func (i *indexer) create(hash BlockHash, server ServerID) {
for i.ll.Len() >= i.maxCacheSize {
// Evict the least recently used entry if we've exceeded the max cache size
i.evict()
}
// Get returns a set of servers that have the given prefix hash cached.
func (i *indexer) Get(hash BlockHash) podSet {
i.mu.RLock()
defer i.mu.RUnlock()

if _, ok := i.table[hash]; !ok {
i.table[hash] = make(map[ServerID]*list.Element)
}
v := &value{
server: server,
hash: hash,
res := podSet{}
pods, ok := i.hashToPods[hash]
if !ok {
return res
}
e := i.ll.PushBack(v)
i.table[hash][server] = e

return pods
}

// evict removes the least recently used entry from the cache
func (i *indexer) evict() {
oldestNode := i.ll.Front()
if oldestNode == nil {
return
}
i.ll.Remove(oldestNode)

v := oldestNode.Value.(*value)
hash := v.hash
server := v.server
// Remove from the hash map
serverMap := i.table[hash]
delete(serverMap, server)

// If this was the last server for this hash, remove the hash entry entirely
if len(serverMap) == 0 {
delete(i.table, hash)
// makeEvictionFn returns a per-pod LRU eviction callback that removes the pod from hashToPods on eviction.
func (i *indexer) makeEvictionFn(pod ServerID) func(BlockHash, struct{}) {
return func(hash BlockHash, _ struct{}) {
i.mu.Lock()
defer i.mu.Unlock()
// Remove the pod from the hash→pods map
if podSet, ok := i.hashToPods[hash]; ok {
delete(podSet, pod)
if len(podSet) == 0 {
delete(i.hashToPods, hash)
}
}
}

log.FromContext(context.TODO()).V(logutil.TRACE).Info("Evicted LRU entry", "hash", hash, "server", server)
}

// ReportCacheSize starts a goroutine that periodically reports the cache size metric
func (i *indexer) ReportCacheSize(interval time.Duration) {
// ReportLRUSize starts a goroutine that periodically reports the LRU cache size metric.
func (i *indexer) ReportLRUSize(interval time.Duration) {
ticker := time.NewTicker(interval)
defer ticker.Stop()
for range ticker.C {
i.mu.RLock()
metrics.RecordPrefixCacheSize(int64(i.ll.Len()))
log.FromContext(context.TODO()).V(logutil.TRACE).Info("LRU", "# entries", i.ll.Len(), "estimated size MB", i.ll.Len()*i.estimateEntrySize()/1000000)
totalEntries := 0
maxPodEntries := 0
maxPodName := ServerID{}

for pod, lruCache := range i.podToLRU {
size := lruCache.Len()
totalEntries += size
if size > maxPodEntries {
maxPodEntries = size
maxPodName = pod
}
}

numPods := len(i.podToLRU)
avg := 0.0
if numPods > 0 {
avg = float64(totalEntries) / float64(numPods)
}

metrics.RecordPrefixCacheSize(int64(totalEntries))
log.FromContext(context.TODO()).V(logutil.TRACE).Info("Prefix cache state",
"total entries", totalEntries,
"# pods", numPods,
"avg entries per pod", avg,
"pod with max cache", maxPodName,
"max pod size", maxPodEntries,
"global max LRU cache capacity per pod", i.maxLRUSize,
)

i.mu.RUnlock()
}
}

// estimateEntrySize estimates the memory size of a cache entry in bytes.
func (i *indexer) estimateEntrySize() int {
size := 0

// Estimate the size of a node in the linked list.
// First get the size of the node struct via unsafe.Sizeof.
// The prev and next pointers are 8 bytes each on a 64-bit system.
// The BlockHash is a uint64, which is 8 bytes.
// The ServerID is a NamespacedName, which contains two strings (Name and Namespace).
// The headers for the strings are 16 bytes each (8 bytes for the pointer and 8 bytes for the length).
// So unsafe.Sizeof(node{}) should return 2*8 + 8 + 2*16 = 48 bytes.
size += int(unsafe.Sizeof(value{}))
// Size of the Name and Namespace strings in ServerID, assuming 63 bytes each (max length for Kubernetes NamespacedName).
size += 2 * 63

// Estimate the size of an entry in the hash map. Note the overhead of the map headers and buckets are ignored.
size += 8 // Size of the BlockHash (uint64).
size += 2 * 16 // Size of the ServerID string headers (NamespacedName).
size += 2 * 63 // Size of the Name and Namespace strings in ServerID.
size += 8 // Size of the pointer to the node in the hash map.

// Based on the above estimates, the estimated size of an entry is:
// (48 + 2*63) + (8 + 2*16 + 2*63 + 8) = 348 bytes.
return size
}
20 changes: 11 additions & 9 deletions pkg/epp/scheduling/framework/plugins/multi/prefix/indexer_test.go
@@ -22,24 +22,26 @@ import (
)

func TestIndexer_AddAndGet(t *testing.T) {
cache := newIndexer(2)
i := newIndexer(2)

hash1 := BlockHash(1)
server := ServerID{Namespace: "default", Name: "server1"}

// Add an entry to the cache
cache.Add([]BlockHash{hash1}, server)
err := i.Add([]BlockHash{hash1}, server)
assert.NoError(t, err)

// Retrieve the entry
assert.Equal(t, 1, cache.ll.Len(), "Cache size should be 1 after adding an entry")
servers := cache.Get(hash1)
assert.Equal(t, 1, i.podToLRU[server].Len(), "Cache size should be 1 after adding an entry")
servers := i.Get(hash1)
assert.Contains(t, servers, server, "Cache should contain the added server")

// Add another entry to the cache, the cache size should be incremented to 2.
cache.Add([]BlockHash{BlockHash(2)}, server)
assert.Equal(t, 2, cache.ll.Len(), "Cache size should be 2 after adding an entry")
err = i.Add([]BlockHash{BlockHash(2)}, server)
assert.NoError(t, err)
assert.Equal(t, 2, i.podToLRU[server].Len(), "Cache size should be 2 after adding an entry")

// Add another entry to the cache, which should evict the first one due to max size.
cache.Add([]BlockHash{BlockHash(3)}, server)
assert.Equal(t, 2, cache.ll.Len(), "Cache size should still be 2 after adding an entry")
err = i.Add([]BlockHash{BlockHash(3)}, server)
assert.NoError(t, err)
assert.Equal(t, 2, i.podToLRU[server].Len(), "Cache size should still be 2 after adding an entry")
}
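
The test above exercises the per-pod LRU capacity. Below is a hypothetical companion test, not part of the PR, sketching the other half of the mechanism: the eviction callback removing the pod from hashToPods so Get stops reporting it.

```go
// Hypothetical companion test; assumes the indexer behavior shown in this PR.
func TestIndexer_EvictionRemovesPodFromHashToPods(t *testing.T) {
	i := newIndexer(2)
	server := ServerID{Namespace: "default", Name: "server1"}

	assert.NoError(t, i.Add([]BlockHash{BlockHash(1), BlockHash(2)}, server))
	assert.NoError(t, i.Add([]BlockHash{BlockHash(3)}, server)) // evicts BlockHash(1) from the pod's LRU

	// The eviction callback should have removed the pod for BlockHash(1).
	assert.NotContains(t, i.Get(BlockHash(1)), server, "evicted hash should no longer map to the server")
	assert.Contains(t, i.Get(BlockHash(3)), server, "most recent hash should still map to the server")
}
```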