From 5845ffd73533990a03688978833cecc61c23a3d3 Mon Sep 17 00:00:00 2001 From: Kfir Toledo Date: Thu, 5 Jun 2025 15:02:06 +0300 Subject: [PATCH 1/5] refactor: Replace prefix cache structure with golang-lru Signed-off-by: Kfir Toledo Co-authored-by: Maroon Ayoub --- go.mod | 1 + go.sum | 2 + .../framework/plugins/multi/prefix/indexer.go | 168 ++++++------------ .../plugins/multi/prefix/indexer_test.go | 16 +- .../framework/plugins/multi/prefix/plugin.go | 40 ++--- .../plugins/multi/prefix/plugin_test.go | 58 ++++++ 6 files changed, 140 insertions(+), 145 deletions(-) diff --git a/go.mod b/go.mod index 192773cc6..0c02daccc 100644 --- a/go.mod +++ b/go.mod @@ -9,6 +9,7 @@ require ( github.com/go-logr/logr v1.4.3 github.com/google/go-cmp v0.7.0 github.com/google/uuid v1.6.0 + github.com/hashicorp/golang-lru/v2 v2.0.7 github.com/onsi/ginkgo/v2 v2.23.4 github.com/onsi/gomega v1.37.0 github.com/prometheus/client_golang v1.22.0 diff --git a/go.sum b/go.sum index 7733d5555..2d45d351f 100644 --- a/go.sum +++ b/go.sum @@ -95,6 +95,8 @@ github.com/gorilla/websocket v1.5.4-0.20250319132907-e064f32e3674 h1:JeSE6pjso5T github.com/gorilla/websocket v1.5.4-0.20250319132907-e064f32e3674/go.mod h1:r4w70xmWCQKmi1ONH4KIaBptdivuRPyosB9RmPlGEwA= github.com/grpc-ecosystem/grpc-gateway/v2 v2.24.0 h1:TmHmbvxPmaegwhDubVz0lICL0J5Ka2vwTzhoePEXsGE= github.com/grpc-ecosystem/grpc-gateway/v2 v2.24.0/go.mod h1:qztMSjm835F2bXf+5HKAPIS5qsmQDqZna/PgVt4rWtI= +github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs4luLUK2k= +github.com/hashicorp/golang-lru/v2 v2.0.7/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM= github.com/huandu/xstrings v1.3.3 h1:/Gcsuc1x8JVbJ9/rlye4xZnVAbEkGauT8lbebqcQws4= github.com/huandu/xstrings v1.3.3/go.mod h1:y5/lhBue+AyNmUVz9RLU9xbLR0o4KIIExikq4ovT0aE= github.com/imdario/mergo v0.3.16 h1:wwQJbIsHYGMUyLSPrEq1CT16AhnhNJQ51+4fdHUnCl4= diff --git a/pkg/epp/scheduling/framework/plugins/multi/prefix/indexer.go b/pkg/epp/scheduling/framework/plugins/multi/prefix/indexer.go index 4859357d8..7529d26af 100644 --- a/pkg/epp/scheduling/framework/plugins/multi/prefix/indexer.go +++ b/pkg/epp/scheduling/framework/plugins/multi/prefix/indexer.go @@ -20,154 +20,92 @@ import ( "context" "sync" "time" - "unsafe" - - "container/list" + lru "github.com/hashicorp/golang-lru/v2" "sigs.k8s.io/controller-runtime/pkg/log" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metrics" logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" ) -func newIndexer(maxCacheSize int) *indexer { - t := &indexer{ - maxCacheSize: maxCacheSize, - table: make(map[BlockHash]map[ServerID]*list.Element), - ll: list.New(), - } - go t.ReportCacheSize(time.Second) - return t +// block holds an LRU cache of servers that may have a specific prefix hash. +type block struct { + Pods *lru.Cache[ServerID, struct{}] // Can be extended with metadata (e.g., timestamp). } // An indexer maintains an LRU cache of prompt prefix hashes and the server(s) that might have that // prefix cached . type indexer struct { - mu sync.RWMutex - maxCacheSize int - table map[BlockHash]map[ServerID]*list.Element // from any prefix cache to the cache entry to find the server - ll *list.List // LinkedList to keep track of the order of entries + mu sync.RWMutex + cache *lru.Cache[BlockHash, *block] + maxCacheSize int + maxServersToMatch int } -// value is the value stored in the linked list. -type value struct { - server ServerID - hash BlockHash +// newIndexer initializes an indexer with size limits and starts cache size reporting. +func newIndexer(maxCacheSize, maxServersToMatch int) *indexer { + c, err := lru.New[BlockHash, *block](maxCacheSize) + if err != nil { + panic(err) + } + ix := &indexer{ + cache: c, + maxCacheSize: maxCacheSize, + maxServersToMatch: maxServersToMatch, + } + go ix.ReportCacheSize(time.Second) + return ix } -// Get returns the set of servers that have the given prefix hash cached. -func (i *indexer) Get(hash BlockHash) map[ServerID]bool { - i.mu.RLock() - defer i.mu.RUnlock() - res := map[ServerID]bool{} - for server := range i.table[hash] { - res[server] = true +// Add adds a list of prefix hashes to the cache, tied to the server. +func (i *indexer) Add(hashes []BlockHash, pod ServerID) { + if len(hashes) == 0 || pod.Name == "" { + return } - return res -} -// Add adds a list of prefix hashes of a single request to the server the request was sent to. -// The intuition is that this server is likely to have the prefix cached, so next time a request -// sharing the longest prefix should be sent to the same server to take advantage of the cache hit. -func (i *indexer) Add(hashes []BlockHash, server ServerID) { i.mu.Lock() defer i.mu.Unlock() - for _, hash := range hashes { - i.add(hash, server) - } -} - -func (i *indexer) check(hash BlockHash, server ServerID) (*list.Element, bool) { - servers, ok := i.table[hash] - if !ok { - return nil, false - } - e, ok := servers[server] - return e, ok -} -func (i *indexer) add(hash BlockHash, server ServerID) { - e, exists := i.check(hash, server) - if exists { - i.ll.MoveToBack(e) - } else { - i.create(hash, server) + for _, hash := range hashes { + b, ok := i.cache.Get(hash) + if !ok { + // Create block with new LRU + podLRU, _ := lru.New[ServerID, struct{}](i.maxServersToMatch) + b = &block{Pods: podLRU} + i.cache.Add(hash, b) + } + + b.Pods.Add(pod, struct{}{}) } } -func (i *indexer) create(hash BlockHash, server ServerID) { - for i.ll.Len() >= i.maxCacheSize { - // Evict the least recently used entry if we've exceeded the max cache size - i.evict() - } - - if _, ok := i.table[hash]; !ok { - i.table[hash] = make(map[ServerID]*list.Element) - } - v := &value{ - server: server, - hash: hash, - } - e := i.ll.PushBack(v) - i.table[hash][server] = e -} +// Get returns a set of servers that have the given prefix hash cached. +func (i *indexer) Get(hash BlockHash) map[ServerID]bool { + i.mu.RLock() + defer i.mu.RUnlock() -// evict removes the least recently used entry from the cache -func (i *indexer) evict() { - oldestNode := i.ll.Front() - if oldestNode == nil { - return + res := map[ServerID]bool{} + block, ok := i.cache.Get(hash) + if !ok { + return res } - i.ll.Remove(oldestNode) - - v := oldestNode.Value.(*value) - hash := v.hash - server := v.server - // Remove from the hash map - serverMap := i.table[hash] - delete(serverMap, server) - - // If this was the last server for this hash, remove the hash entry entirely - if len(serverMap) == 0 { - delete(i.table, hash) + for _, pod := range block.Pods.Keys() { + res[pod] = true } - - log.FromContext(context.TODO()).V(logutil.TRACE).Info("Evicted LRU entry", "hash", hash, "server", server) + return res } -// ReportCacheSize starts a goroutine that periodically reports the cache size metric +// ReportCacheSize starts a goroutine that periodically reports the cache size metric. func (i *indexer) ReportCacheSize(interval time.Duration) { ticker := time.NewTicker(interval) defer ticker.Stop() for range ticker.C { i.mu.RLock() - metrics.RecordPrefixCacheSize(int64(i.ll.Len())) - log.FromContext(context.TODO()).V(logutil.TRACE).Info("LRU", "# entries", i.ll.Len(), "estimated size MB", i.ll.Len()*i.estimateEntrySize()/1000000) + size := i.cache.Len() + metrics.RecordPrefixCacheSize(int64(size)) + log.FromContext(context.TODO()).V(logutil.TRACE).Info("LRU", + "# entries", size, + "prefix cache utilization [%]", float64(size)*100/float64(i.maxCacheSize), + ) i.mu.RUnlock() } } - -// estimateEntrySize estimates the memory size of a cache entry in bytes. -func (i *indexer) estimateEntrySize() int { - size := 0 - - // Estimate the size of a node in the linked list. - // First get the size of the node struct via unsafe.Sizeof. - // The prev and next pointers are 8 bytes each on a 64-bit system. - // The BlockHash is a uint64, which is 8 bytes. - // The ServerID is a NamespacedName, which contains two strings (Name and Namespace). - // The headers for the strings are 16 bytes each (8 bytes for the pointer and 8 bytes for the length). - // So unsafe.Sizeof(node{}) should return 2*8 + 8 + 2*16 = 48 bytes. - size += int(unsafe.Sizeof(value{})) - // Size of the Name and Namespace strings in ServerID, assuming 63 bytes each (max length for Kubernetes NamespacedName). - size += 2 * 63 - - // Estimate the size of an entry in the hash map. Note the overhead of the map headers and buckets are ignored. - size += 8 // Size of the BlockHash (uint64). - size += 2 * 16 // Size of the ServerID string headers (NamespacedName). - size += 2 * 63 // Size of the Name and Namespace strings in ServerID. - size += 8 // Size of the pointer to the node in the hash map. - - // Based on the above estimates, the estimated size of an entry is: - // (48 + 2*63) + (8 + 2*16 + 2*63 + 8) = 348 bytes. - return size -} diff --git a/pkg/epp/scheduling/framework/plugins/multi/prefix/indexer_test.go b/pkg/epp/scheduling/framework/plugins/multi/prefix/indexer_test.go index 596625d10..436ead771 100644 --- a/pkg/epp/scheduling/framework/plugins/multi/prefix/indexer_test.go +++ b/pkg/epp/scheduling/framework/plugins/multi/prefix/indexer_test.go @@ -22,24 +22,24 @@ import ( ) func TestIndexer_AddAndGet(t *testing.T) { - cache := newIndexer(2) + i := newIndexer(2, 2) hash1 := BlockHash(1) server := ServerID{Namespace: "default", Name: "server1"} // Add an entry to the cache - cache.Add([]BlockHash{hash1}, server) + i.Add([]BlockHash{hash1}, server) // Retrieve the entry - assert.Equal(t, 1, cache.ll.Len(), "Cache size should be 1 after adding an entry") - servers := cache.Get(hash1) + assert.Equal(t, 1, i.cache.Len(), "Cache size should be 1 after adding an entry") + servers := i.Get(hash1) assert.Contains(t, servers, server, "Cache should contain the added server") // Add another entry to the cache, the cache size should be incremented to 2. - cache.Add([]BlockHash{BlockHash(2)}, server) - assert.Equal(t, 2, cache.ll.Len(), "Cache size should be 2 after adding an entry") + i.Add([]BlockHash{BlockHash(2)}, server) + assert.Equal(t, 2, i.cache.Len(), "Cache size should be 2 after adding an entry") // Add another entry to the cache, which should evict the first one due to max size. - cache.Add([]BlockHash{BlockHash(3)}, server) - assert.Equal(t, 2, cache.ll.Len(), "Cache size should still be 2 after adding an entry") + i.Add([]BlockHash{BlockHash(3)}, server) + assert.Equal(t, 2, i.cache.Len(), "Cache size should still be 2 after adding an entry") } diff --git a/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin.go b/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin.go index 01903bce3..e5f978ba9 100644 --- a/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin.go +++ b/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin.go @@ -36,7 +36,7 @@ const ( // Why not just return the server with longest prefix match? // It may not be the optimal choice, e.g., it may have a high queue depth. // We optimistically search more than one to give more candidates for the scheduler to choose. - DefaultNumServersToMatch = 2 + DefaultNumServersToMatch = 16 // vLLM default token block size is 16, and a good guess of average characters per token is 4. DefaultHashBlockSize = 64 // The maximum number of blocks to match. Two long requests with the same prefix up to this @@ -44,20 +44,17 @@ const ( // This parameter provides a trade-off between cache size, prefix matching speed and matching // accuracy. Use a small value if most requests are short to reduce cache size and speed up the // matching process. Use a large value if most requests are long to increase the matching accuracy. - DefaultMaxPrefixBlocks = 128 + DefaultMaxPrefixBlocks = 256 // The indexer is an approximation to the actual prefix cache state on the model servers. // A small capacity ensures a high accuracy of cache hit on the model server, but it will // increase the chance of false negatives. A high capacity does the opposite. // To properly size this, consider the sum of the total number of cache entries on all model - // servers. Consider the llama3 8B model on 3 H100 80GB GPUs. The size of the model weight is + // servers. Consider the llama3 8B model on 8 H100 80GB GPUs. The size of the model weight is // about 16GB. Assume 50% of the remaining HBM is used for caching prefixes, we have 32GB. Each // token is about 128KB in size, so we can cache 250K tokens. Using the default block size of 16 - // in vLLM, we will have 250K / 16 = 15.6K blocks. In total we have 15.6K * 3 = 46.8K blocks, or - // roughly 50K. - // How much memory space does it require to hold the 50K block hashes? - // According to the estimates in indexer.estimateEntrySize(), the size of each entry is - // approximately 348 bytes. So in total we have 50K * 348 = 17.4MB. - DefaultLRUIndexerCapacity = 50000 + // in vLLM, we will have 250K / 16 = 15.6K blocks. In total we have 15.6K * 8 = 124.8K blocks, or + // roughly 130K. + DefaultLRUIndexerCapacity = 130000 ) type Config struct { @@ -67,6 +64,8 @@ type Config struct { // MaxPrefixBlocksToMatch is the maximum number of prefix blocks to match. Input beyond this limit will // be ignored. MaxPrefixBlocksToMatch int + // NumServersToMatch is the maximum number that can match per hash BlockHash. + MaxNumServersToMatch int // Max (approximate) size of the LRU indexer in number of entries. LRUIndexerCapacity int } @@ -123,7 +122,7 @@ var _ framework.PostCycle = &Plugin{} func New(config Config) *Plugin { m := &Plugin{ Config: config, - indexer: newIndexer(config.LRUIndexerCapacity), + indexer: newIndexer(config.LRUIndexerCapacity, config.MaxNumServersToMatch), } return m } @@ -138,14 +137,11 @@ func (m *Plugin) Score(ctx context.Context, request *types.LLMRequest, cycleStat loggerTrace := log.FromContext(ctx).V(logutil.TRACE) // pre score step, hashing prompt and find longest prefix match. hashes := hashPrompt(ctx, request, m.HashBlockSize, m.MaxPrefixBlocksToMatch) - numServers := DefaultNumServersToMatch - if numServers > len(pods) { - numServers = len(pods) - } state := &schedulingContextState{ PrefixHashes: hashes, - PrefixCacheServers: m.matchLongestPrefix(ctx, hashes, numServers), + PrefixCacheServers: m.matchLongestPrefix(ctx, hashes), } + cycleState.Write(types.StateKey(m.Name()), state) loggerTrace.Info(fmt.Sprintf("cached servers: %+v", state.PrefixCacheServers), "hashes", state.PrefixHashes) // calculate the scores of pods @@ -181,22 +177,22 @@ func (m *Plugin) PostCycle(ctx context.Context, cycleState *types.CycleState, re } // matchLongestPrefix returns a map of servers and length of prefix that each server caches. -func (m *Plugin) matchLongestPrefix(ctx context.Context, hashes []BlockHash, numServers int) map[ServerID]int { +func (m *Plugin) matchLongestPrefix(ctx context.Context, hashes []BlockHash) map[ServerID]int { loggerTrace := log.FromContext(ctx).V(logutil.TRACE) res := make(map[ServerID]int) // Use a greedy strategy to search from the longest prefix. // NOTE: It's possible to further optimize this with a binary search. - for i := len(hashes) - 1; i >= 0 && len(res) < numServers; i-- { + for i := 0; i < len(hashes); i++ { hash := hashes[i] cachedServers := m.indexer.Get(hash) - if len(cachedServers) > 0 { + if len(cachedServers) == 0 { + break + } else { loggerTrace.Info("Found cached servers", "cachedServers", cachedServers, "total # blocks", len(hashes), "longest prefix", i) for server := range cachedServers { // Update servers with their longest prefix match. - // If we already found this server with longer prefix match, don't update it. - if _, ok := res[server]; !ok { - res[server] = i + 1 - } + res[server]++ + } } } diff --git a/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin_test.go b/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin_test.go index de6c68bbd..250b89b7b 100644 --- a/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin_test.go +++ b/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin_test.go @@ -18,6 +18,10 @@ package prefix import ( "context" + "fmt" + "math" + "math/rand" + "strings" "testing" "github.com/stretchr/testify/assert" @@ -27,10 +31,12 @@ import ( ) func TestPrefixPlugin(t *testing.T) { + config := Config{ HashBlockSize: 4, MaxPrefixBlocksToMatch: DefaultMaxPrefixBlocks, LRUIndexerCapacity: DefaultLRUIndexerCapacity, + MaxNumServersToMatch: DefaultNumServersToMatch, } plugin := New(config) @@ -136,3 +142,55 @@ func TestPrefixPlugin(t *testing.T) { plugin.PostCycle(context.Background(), cycleState5, &types.ProfileRunResult{TargetPod: pod1}) } + +// TestPrefixPluginStress is a stress test for the prefix scoring plugin, using prompts of increasing length. +func TestPrefixPluginStress(t *testing.T) { + blockSize := 4 + config := Config{ + HashBlockSize: blockSize, + MaxPrefixBlocksToMatch: DefaultMaxPrefixBlocks, + LRUIndexerCapacity: DefaultLRUIndexerCapacity, + MaxNumServersToMatch: DefaultNumServersToMatch, + } + + plugin := New(config) + types.NewCycleState() + for i := 0; i < 1000; i++ { + // Generate increasing-length random prompts + prompt := randomPrompt(4 + i) + pod := &types.PodMetrics{ + Pod: &backend.Pod{ + NamespacedName: k8stypes.NamespacedName{ + Name: fmt.Sprintf("random-pod-%d", i), + }, + }, + } + + pods := []types.Pod{pod} + req := &types.LLMRequest{ + TargetModel: "model-stress", + Prompt: prompt, + } + + // First cycle: simulate scheduling and insert prefix info into the cache + cycleState := types.NewCycleState() + plugin.Score(context.Background(), req, cycleState, pods) + plugin.PostCycle(context.Background(), cycleState, &types.ProfileRunResult{TargetPod: pod}) + + // Second cycle: validate internal state + state, err := plugin.getPrefixState(cycleState) + assert.NoError(t, err) + expectedHashes := int(math.Min(DefaultMaxPrefixBlocks+1, float64(len(req.Prompt)/blockSize+1))) // the extra one is for the model. + assert.Equal(t, expectedHashes, len(state.PrefixHashes), "number of hashes is incorrect") + } +} + +// randomPrompt generates a pseudo-random string of length n using lowercase letters. +func randomPrompt(n int) string { + runes := []rune("abcdefghijklmnopqrstuvwxyz") + var sb strings.Builder + for i := 0; i < n; i++ { + sb.WriteRune(runes[rand.Intn(len(runes))]) + } + return sb.String() +} From 20609d03c955dde6ea272cd5ccd9e5ee33bcd7f5 Mon Sep 17 00:00:00 2001 From: Kfir Toledo Date: Mon, 9 Jun 2025 16:54:27 +0300 Subject: [PATCH 2/5] fix: rename prefix scorer parameters and convert test to benchmark test Signed-off-by: Kfir Toledo --- .../framework/plugins/multi/prefix/indexer.go | 26 +++++++++---------- .../framework/plugins/multi/prefix/plugin.go | 17 ++++++------ .../plugins/multi/prefix/plugin_test.go | 18 +++++++------ .../guides/epp-configuration/prefix-aware.md | 10 ++++--- 4 files changed, 38 insertions(+), 33 deletions(-) diff --git a/pkg/epp/scheduling/framework/plugins/multi/prefix/indexer.go b/pkg/epp/scheduling/framework/plugins/multi/prefix/indexer.go index 7529d26af..d31e97549 100644 --- a/pkg/epp/scheduling/framework/plugins/multi/prefix/indexer.go +++ b/pkg/epp/scheduling/framework/plugins/multi/prefix/indexer.go @@ -27,23 +27,23 @@ import ( logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" ) -// block holds an LRU cache of servers that may have a specific prefix hash. -type block struct { - Pods *lru.Cache[ServerID, struct{}] // Can be extended with metadata (e.g., timestamp). +// podSet holds an LRU cache of servers that may have a specific prefix hash. +type podSet struct { + enteries *lru.Cache[ServerID, struct{}] // Can be extended with metadata (e.g., timestamp). } // An indexer maintains an LRU cache of prompt prefix hashes and the server(s) that might have that // prefix cached . type indexer struct { mu sync.RWMutex - cache *lru.Cache[BlockHash, *block] + cache *lru.Cache[BlockHash, *podSet] maxCacheSize int maxServersToMatch int } // newIndexer initializes an indexer with size limits and starts cache size reporting. func newIndexer(maxCacheSize, maxServersToMatch int) *indexer { - c, err := lru.New[BlockHash, *block](maxCacheSize) + c, err := lru.New[BlockHash, *podSet](maxCacheSize) if err != nil { panic(err) } @@ -58,7 +58,7 @@ func newIndexer(maxCacheSize, maxServersToMatch int) *indexer { // Add adds a list of prefix hashes to the cache, tied to the server. func (i *indexer) Add(hashes []BlockHash, pod ServerID) { - if len(hashes) == 0 || pod.Name == "" { + if pod.Name == "" { return } @@ -66,15 +66,15 @@ func (i *indexer) Add(hashes []BlockHash, pod ServerID) { defer i.mu.Unlock() for _, hash := range hashes { - b, ok := i.cache.Get(hash) + p, ok := i.cache.Get(hash) if !ok { - // Create block with new LRU + // Create podSet with new LRU podLRU, _ := lru.New[ServerID, struct{}](i.maxServersToMatch) - b = &block{Pods: podLRU} - i.cache.Add(hash, b) + p = &podSet{enteries: podLRU} + i.cache.Add(hash, p) } - b.Pods.Add(pod, struct{}{}) + p.enteries.Add(pod, struct{}{}) } } @@ -84,11 +84,11 @@ func (i *indexer) Get(hash BlockHash) map[ServerID]bool { defer i.mu.RUnlock() res := map[ServerID]bool{} - block, ok := i.cache.Get(hash) + pods, ok := i.cache.Get(hash) if !ok { return res } - for _, pod := range block.Pods.Keys() { + for _, pod := range pods.enteries.Keys() { res[pod] = true } return res diff --git a/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin.go b/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin.go index e5f978ba9..cf55c8400 100644 --- a/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin.go +++ b/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin.go @@ -32,11 +32,10 @@ import ( const ( DefaultScorerWeight = 1 - // Attempt to return DefaultNumServersToMatch servers with their longest prefix match length. - // Why not just return the server with longest prefix match? - // It may not be the optimal choice, e.g., it may have a high queue depth. - // We optimistically search more than one to give more candidates for the scheduler to choose. - DefaultNumServersToMatch = 16 + // DefaultMaxPodsPerPrefix defines the maximum number of pods (servers) to track per prefix hash in the LRU indexer. + // This limits the number of recent pods associated with a given prefix to reduce memory usage + // and ensure faster lookup. When the limit is reached, the least recently used pod is evicted. + DefaultMaxPodsPerPrefix = 4 // vLLM default token block size is 16, and a good guess of average characters per token is 4. DefaultHashBlockSize = 64 // The maximum number of blocks to match. Two long requests with the same prefix up to this @@ -64,8 +63,8 @@ type Config struct { // MaxPrefixBlocksToMatch is the maximum number of prefix blocks to match. Input beyond this limit will // be ignored. MaxPrefixBlocksToMatch int - // NumServersToMatch is the maximum number that can match per hash BlockHash. - MaxNumServersToMatch int + // MaxPodsPerPrefix defines the maximum number of pods (servers) to track per prefix hash in the LRU indexer. + MaxPodsPerPrefix int // Max (approximate) size of the LRU indexer in number of entries. LRUIndexerCapacity int } @@ -122,7 +121,7 @@ var _ framework.PostCycle = &Plugin{} func New(config Config) *Plugin { m := &Plugin{ Config: config, - indexer: newIndexer(config.LRUIndexerCapacity, config.MaxNumServersToMatch), + indexer: newIndexer(config.LRUIndexerCapacity, config.MaxPodsPerPrefix), } return m } @@ -136,7 +135,7 @@ func (m *Plugin) Name() string { func (m *Plugin) Score(ctx context.Context, request *types.LLMRequest, cycleState *types.CycleState, pods []types.Pod) map[types.Pod]float64 { loggerTrace := log.FromContext(ctx).V(logutil.TRACE) // pre score step, hashing prompt and find longest prefix match. - hashes := hashPrompt(ctx, request, m.HashBlockSize, m.MaxPrefixBlocksToMatch) + hashes := hashPrompt(ctx, request, m.HashBlockSize, m.MaxPodsPerPrefix) state := &schedulingContextState{ PrefixHashes: hashes, PrefixCacheServers: m.matchLongestPrefix(ctx, hashes), diff --git a/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin_test.go b/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin_test.go index 250b89b7b..d05988872 100644 --- a/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin_test.go +++ b/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin_test.go @@ -36,7 +36,7 @@ func TestPrefixPlugin(t *testing.T) { HashBlockSize: 4, MaxPrefixBlocksToMatch: DefaultMaxPrefixBlocks, LRUIndexerCapacity: DefaultLRUIndexerCapacity, - MaxNumServersToMatch: DefaultNumServersToMatch, + MaxPodsPerPrefix: DefaultMaxPodsPerPrefix, } plugin := New(config) @@ -144,18 +144,20 @@ func TestPrefixPlugin(t *testing.T) { } // TestPrefixPluginStress is a stress test for the prefix scoring plugin, using prompts of increasing length. -func TestPrefixPluginStress(t *testing.T) { +func BenchmarkPrefixPluginStress(b *testing.B) { blockSize := 4 + maxPrefixBlocks := 50000 config := Config{ HashBlockSize: blockSize, - MaxPrefixBlocksToMatch: DefaultMaxPrefixBlocks, + MaxPrefixBlocksToMatch: maxPrefixBlocks, LRUIndexerCapacity: DefaultLRUIndexerCapacity, - MaxNumServersToMatch: DefaultNumServersToMatch, + MaxPodsPerPrefix: DefaultMaxPodsPerPrefix, } plugin := New(config) types.NewCycleState() - for i := 0; i < 1000; i++ { + promptLen := []int{10, 20, 50, 100, 200, 500, 1000, 2000, 5000, 10000, 20000, 50000} + for _, i := range promptLen { // Generate increasing-length random prompts prompt := randomPrompt(4 + i) pod := &types.PodMetrics{ @@ -179,9 +181,9 @@ func TestPrefixPluginStress(t *testing.T) { // Second cycle: validate internal state state, err := plugin.getPrefixState(cycleState) - assert.NoError(t, err) - expectedHashes := int(math.Min(DefaultMaxPrefixBlocks+1, float64(len(req.Prompt)/blockSize+1))) // the extra one is for the model. - assert.Equal(t, expectedHashes, len(state.PrefixHashes), "number of hashes is incorrect") + assert.NoError(b, err) + expectedHashes := int(math.Min(float64(maxPrefixBlocks+1), float64(len(req.Prompt)/blockSize+1))) // the extra one is for the model. + assert.Equal(b, expectedHashes, len(state.PrefixHashes), "number of hashes is incorrect") } } diff --git a/site-src/guides/epp-configuration/prefix-aware.md b/site-src/guides/epp-configuration/prefix-aware.md index a0ad8b51a..cbd2018f5 100644 --- a/site-src/guides/epp-configuration/prefix-aware.md +++ b/site-src/guides/epp-configuration/prefix-aware.md @@ -4,7 +4,7 @@ The [prefix cache plugin](https://github.com/kubernetes-sigs/gateway-api-inferen takes advantage of the prefix caching (e.g., [vllm APC](https://docs.vllm.ai/en/latest/features/automatic_prefix_caching.html)) feature of model servers, and optimizes request scheduling by placing requests sharing the longest prefixes to the same server as much as possible, while balancing the server load by considering kv-cache -and queue depth. +and queue depth. ## Enable the prefix cache plugin @@ -34,6 +34,10 @@ for performance. * `PREFIX_CACHE_LRU_CAPACITY`: Maximum capacity the prefix LRU indexer in number of block hashes. Below shows a detailed analysis on how to estimate this. +* `PREFIX_MAX_PODS_PER_PREFIX`: Defines the maximum number of pods (servers) tracked per prefix hash in the internal LRU cache. +This setting helps optimize memory usage by retaining only the hottest (most recently active) pods for each prefix. +When the limit is reached, older pods are evicted based on least-recently-used (LRU) order. + The prefix cache plugin estimates the prefix cache indexes in model server HBMs. In the perfect scenario, EPP has the exact same prefix cache entries per model server as their HBM cache entries. If @@ -41,7 +45,7 @@ shows a detailed analysis on how to estimate this. false cache misses. If the EPP cache is larger than the HBM cache, then there are more false cache hits. Therefore **the EPP prefix cache indexer size should be as close as possible to the HBM cache size.** - NOTE: EPP builds prefix cache based on characters, while model server maintains prefix cache entries + NOTE: EPP builds prefix cache based on characters, while model server maintains prefix cache entries in tokens, a conversion between character <-> token is needed. Below are the formulas to estimate the EPP prefix indexer size: @@ -63,7 +67,7 @@ shows a detailed analysis on how to estimate this. max_kv_tokens_per_server = (80GB - 16GB) / 128KB = 500,000 # assume avg_chars_per_token = 4, prefix_indexer_hash_block_size = 64 (default) # each entry is about 358KB, so the memory footrpint is abut 11 MB per server - lru_indexer_capacity_per_server = 500,000*4/64 = 31250 + lru_indexer_capacity_per_server = 500,000*4/64 = 31250 lru_indexer_capacity_total = 3 * 31250 = 93750 ``` From 01925288c7ab8ae98675c0d83b4cd8a45f8609d2 Mon Sep 17 00:00:00 2001 From: Kfir Toledo Date: Thu, 12 Jun 2025 13:51:41 +0300 Subject: [PATCH 3/5] feat: Add per server LRU capacity Signed-off-by: Kfir Toledo --- cmd/epp/runner/runner.go | 2 +- .../framework/plugins/multi/prefix/indexer.go | 131 ++++++++++++------ .../plugins/multi/prefix/indexer_test.go | 13 +- .../framework/plugins/multi/prefix/plugin.go | 26 ++-- .../plugins/multi/prefix/plugin_test.go | 6 +- .../guides/epp-configuration/prefix-aware.md | 9 +- 6 files changed, 113 insertions(+), 74 deletions(-) diff --git a/cmd/epp/runner/runner.go b/cmd/epp/runner/runner.go index b23259f8c..5c65332b8 100644 --- a/cmd/epp/runner/runner.go +++ b/cmd/epp/runner/runner.go @@ -318,7 +318,7 @@ func loadPrefixCacheConfig() prefix.Config { return prefix.Config{ HashBlockSize: envutil.GetEnvInt("PREFIX_CACHE_HASH_BLOCK_SIZE", prefix.DefaultHashBlockSize, baseLogger), MaxPrefixBlocksToMatch: envutil.GetEnvInt("PREFIX_CACHE_MAX_PREFIX_BLOCKS", prefix.DefaultMaxPrefixBlocks, baseLogger), - LRUIndexerCapacity: envutil.GetEnvInt("PREFIX_CACHE_LRU_CAPACITY", prefix.DefaultLRUIndexerCapacity, baseLogger), + LRUCapacityPerServer: envutil.GetEnvInt("PREFIX_CACHE_LRU_CAPACITY_PER_SERVER", prefix.DefaultLRUCapacityPerServer, baseLogger), } } diff --git a/pkg/epp/scheduling/framework/plugins/multi/prefix/indexer.go b/pkg/epp/scheduling/framework/plugins/multi/prefix/indexer.go index d31e97549..d0c48d9f6 100644 --- a/pkg/epp/scheduling/framework/plugins/multi/prefix/indexer.go +++ b/pkg/epp/scheduling/framework/plugins/multi/prefix/indexer.go @@ -18,6 +18,7 @@ package prefix import ( "context" + "fmt" "sync" "time" @@ -27,32 +28,23 @@ import ( logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" ) -// podSet holds an LRU cache of servers that may have a specific prefix hash. -type podSet struct { - enteries *lru.Cache[ServerID, struct{}] // Can be extended with metadata (e.g., timestamp). -} - // An indexer maintains an LRU cache of prompt prefix hashes and the server(s) that might have that -// prefix cached . +// prefix cached. type indexer struct { - mu sync.RWMutex - cache *lru.Cache[BlockHash, *podSet] - maxCacheSize int - maxServersToMatch int + mu sync.RWMutex + hashToPods map[BlockHash]podSet // the lookup data structure to find pods that have the BlockHash cached + podToLRU map[string]*lru.Cache[BlockHash, struct{}] // key is pod namespacedName, value is an LRU cache + maxLRUSize int } // newIndexer initializes an indexer with size limits and starts cache size reporting. -func newIndexer(maxCacheSize, maxServersToMatch int) *indexer { - c, err := lru.New[BlockHash, *podSet](maxCacheSize) - if err != nil { - panic(err) - } +func newIndexer(maxLRUSize int) *indexer { ix := &indexer{ - cache: c, - maxCacheSize: maxCacheSize, - maxServersToMatch: maxServersToMatch, + hashToPods: make(map[BlockHash]podSet), + podToLRU: make(map[string]*lru.Cache[BlockHash, struct{}]), + maxLRUSize: maxLRUSize, } - go ix.ReportCacheSize(time.Second) + go ix.ReportLRUSize(time.Second) return ix } @@ -61,51 +53,106 @@ func (i *indexer) Add(hashes []BlockHash, pod ServerID) { if pod.Name == "" { return } - i.mu.Lock() - defer i.mu.Unlock() + // Check if the LRU pod exist + podName := pod.String() + lruForPod, exists := i.podToLRU[podName] + if !exists { + newLRU, _ := lru.NewWithEvict[BlockHash, struct{}](i.maxLRUSize, i.makeEvictionFn(pod)) + i.podToLRU[podName] = newLRU + lruForPod = newLRU + } + i.mu.Unlock() + // Add to LRU (may evict) for _, hash := range hashes { - p, ok := i.cache.Get(hash) - if !ok { - // Create podSet with new LRU - podLRU, _ := lru.New[ServerID, struct{}](i.maxServersToMatch) - p = &podSet{enteries: podLRU} - i.cache.Add(hash, p) - } + lruForPod.Add(hash, struct{}{}) + } - p.enteries.Add(pod, struct{}{}) + // Update hashToPods once under lock + i.mu.Lock() + for _, hash := range hashes { + pods := i.hashToPods[hash] + if pods == nil { + pods = make(podSet) + } + pods[pod] = struct{}{} + i.hashToPods[hash] = pods } + i.mu.Unlock() + } // Get returns a set of servers that have the given prefix hash cached. -func (i *indexer) Get(hash BlockHash) map[ServerID]bool { +func (i *indexer) Get(hash BlockHash) podSet { i.mu.RLock() defer i.mu.RUnlock() - res := map[ServerID]bool{} - pods, ok := i.cache.Get(hash) + res := podSet{} + pods, ok := i.hashToPods[hash] if !ok { return res } - for _, pod := range pods.enteries.Keys() { - res[pod] = true + + return pods +} + +// makeEvictionFn returns a per-pod LRU eviction callback that removes the pod from hashToPods on eviction. +func (i *indexer) makeEvictionFn(pod ServerID) func(BlockHash, struct{}) { + return func(hash BlockHash, _ struct{}) { + fmt.Printf("Evicted hash %v from pod %s\n", hash, pod) + + i.mu.Lock() + defer i.mu.Unlock() + print("enter eviction") + // Remove the pod from the hash→pods map + if podSet, ok := i.hashToPods[hash]; ok { + delete(podSet, pod) + if len(podSet) == 0 { + delete(i.hashToPods, hash) + } else { + i.hashToPods[hash] = podSet + } + } + print("After eviction") } - return res } -// ReportCacheSize starts a goroutine that periodically reports the cache size metric. -func (i *indexer) ReportCacheSize(interval time.Duration) { +// ReportLRUSize starts a goroutine that periodically reports the LRU cache size metric. +func (i *indexer) ReportLRUSize(interval time.Duration) { ticker := time.NewTicker(interval) defer ticker.Stop() for range ticker.C { i.mu.RLock() - size := i.cache.Len() - metrics.RecordPrefixCacheSize(int64(size)) - log.FromContext(context.TODO()).V(logutil.TRACE).Info("LRU", - "# entries", size, - "prefix cache utilization [%]", float64(size)*100/float64(i.maxCacheSize), + totalEntries := 0 + maxPodEntries := 0 + maxPodName := "" + + for podName, lruCache := range i.podToLRU { + size := lruCache.Len() + totalEntries += size + if size > maxPodEntries { + maxPodEntries = size + maxPodName = podName + } + } + + numPods := len(i.podToLRU) + avg := 0.0 + if numPods > 0 { + avg = float64(totalEntries) / float64(numPods) + } + + metrics.RecordPrefixCacheSize(int64(totalEntries)) + log.FromContext(context.TODO()).V(logutil.TRACE).Info("Prefix cache state", + "total entries", totalEntries, + "# pods", numPods, + "avg entries per pod", avg, + "pod with max cache", maxPodName, + "max pod size", maxPodEntries, + "global max LRU cache capacity per pod", i.maxLRUSize, ) + i.mu.RUnlock() } } diff --git a/pkg/epp/scheduling/framework/plugins/multi/prefix/indexer_test.go b/pkg/epp/scheduling/framework/plugins/multi/prefix/indexer_test.go index 436ead771..ee1cb8b07 100644 --- a/pkg/epp/scheduling/framework/plugins/multi/prefix/indexer_test.go +++ b/pkg/epp/scheduling/framework/plugins/multi/prefix/indexer_test.go @@ -22,24 +22,25 @@ import ( ) func TestIndexer_AddAndGet(t *testing.T) { - i := newIndexer(2, 2) + i := newIndexer(2) hash1 := BlockHash(1) server := ServerID{Namespace: "default", Name: "server1"} - + serverName := server.String() // Add an entry to the cache i.Add([]BlockHash{hash1}, server) - // Retrieve the entry - assert.Equal(t, 1, i.cache.Len(), "Cache size should be 1 after adding an entry") + assert.Equal(t, 1, i.podToLRU[serverName].Len(), "Cache size should be 1 after adding an entry") servers := i.Get(hash1) assert.Contains(t, servers, server, "Cache should contain the added server") // Add another entry to the cache, the cache size should be incremented to 2. i.Add([]BlockHash{BlockHash(2)}, server) - assert.Equal(t, 2, i.cache.Len(), "Cache size should be 2 after adding an entry") + assert.Equal(t, 2, i.podToLRU[serverName].Len(), "Cache size should be 2 after adding an entry") // Add another entry to the cache, which should evict the first one due to max size. + print("before Add") i.Add([]BlockHash{BlockHash(3)}, server) - assert.Equal(t, 2, i.cache.Len(), "Cache size should still be 2 after adding an entry") + print("after ADD") + assert.Equal(t, 2, i.podToLRU[serverName].Len(), "Cache size should still be 2 after adding an entry") } diff --git a/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin.go b/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin.go index cf55c8400..31c991742 100644 --- a/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin.go +++ b/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin.go @@ -32,10 +32,6 @@ import ( const ( DefaultScorerWeight = 1 - // DefaultMaxPodsPerPrefix defines the maximum number of pods (servers) to track per prefix hash in the LRU indexer. - // This limits the number of recent pods associated with a given prefix to reduce memory usage - // and ensure faster lookup. When the limit is reached, the least recently used pod is evicted. - DefaultMaxPodsPerPrefix = 4 // vLLM default token block size is 16, and a good guess of average characters per token is 4. DefaultHashBlockSize = 64 // The maximum number of blocks to match. Two long requests with the same prefix up to this @@ -44,16 +40,15 @@ const ( // accuracy. Use a small value if most requests are short to reduce cache size and speed up the // matching process. Use a large value if most requests are long to increase the matching accuracy. DefaultMaxPrefixBlocks = 256 - // The indexer is an approximation to the actual prefix cache state on the model servers. + // The indexer is an approximation to the actual prefix LRU cache state on the model servers per server (pod). // A small capacity ensures a high accuracy of cache hit on the model server, but it will // increase the chance of false negatives. A high capacity does the opposite. // To properly size this, consider the sum of the total number of cache entries on all model // servers. Consider the llama3 8B model on 8 H100 80GB GPUs. The size of the model weight is // about 16GB. Assume 50% of the remaining HBM is used for caching prefixes, we have 32GB. Each // token is about 128KB in size, so we can cache 250K tokens. Using the default block size of 16 - // in vLLM, we will have 250K / 16 = 15.6K blocks. In total we have 15.6K * 8 = 124.8K blocks, or - // roughly 130K. - DefaultLRUIndexerCapacity = 130000 + // in vLLM, we will have 250K / 16 = 15.6K blocks. + DefaultLRUCapacityPerServer = 15000 ) type Config struct { @@ -63,10 +58,8 @@ type Config struct { // MaxPrefixBlocksToMatch is the maximum number of prefix blocks to match. Input beyond this limit will // be ignored. MaxPrefixBlocksToMatch int - // MaxPodsPerPrefix defines the maximum number of pods (servers) to track per prefix hash in the LRU indexer. - MaxPodsPerPrefix int - // Max (approximate) size of the LRU indexer in number of entries. - LRUIndexerCapacity int + // Max (approximate) size of the LRU indexer in number of entries per server (pod). + LRUCapacityPerServer int } type Plugin struct { @@ -74,8 +67,11 @@ type Plugin struct { indexer Indexer } +// podSet holds an pods servers that may have a specific prefix hash. +type podSet map[ServerID]struct{} + type Indexer interface { - Get(hash BlockHash) map[ServerID]bool + Get(hash BlockHash) podSet Add(hashes []BlockHash, server ServerID) } @@ -121,7 +117,7 @@ var _ framework.PostCycle = &Plugin{} func New(config Config) *Plugin { m := &Plugin{ Config: config, - indexer: newIndexer(config.LRUIndexerCapacity, config.MaxPodsPerPrefix), + indexer: newIndexer(config.LRUCapacityPerServer), } return m } @@ -135,7 +131,7 @@ func (m *Plugin) Name() string { func (m *Plugin) Score(ctx context.Context, request *types.LLMRequest, cycleState *types.CycleState, pods []types.Pod) map[types.Pod]float64 { loggerTrace := log.FromContext(ctx).V(logutil.TRACE) // pre score step, hashing prompt and find longest prefix match. - hashes := hashPrompt(ctx, request, m.HashBlockSize, m.MaxPodsPerPrefix) + hashes := hashPrompt(ctx, request, m.HashBlockSize, m.MaxPrefixBlocksToMatch) state := &schedulingContextState{ PrefixHashes: hashes, PrefixCacheServers: m.matchLongestPrefix(ctx, hashes), diff --git a/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin_test.go b/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin_test.go index d05988872..4c08f18a4 100644 --- a/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin_test.go +++ b/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin_test.go @@ -35,8 +35,7 @@ func TestPrefixPlugin(t *testing.T) { config := Config{ HashBlockSize: 4, MaxPrefixBlocksToMatch: DefaultMaxPrefixBlocks, - LRUIndexerCapacity: DefaultLRUIndexerCapacity, - MaxPodsPerPrefix: DefaultMaxPodsPerPrefix, + LRUCapacityPerServer: DefaultLRUCapacityPerServer, } plugin := New(config) @@ -150,8 +149,7 @@ func BenchmarkPrefixPluginStress(b *testing.B) { config := Config{ HashBlockSize: blockSize, MaxPrefixBlocksToMatch: maxPrefixBlocks, - LRUIndexerCapacity: DefaultLRUIndexerCapacity, - MaxPodsPerPrefix: DefaultMaxPodsPerPrefix, + LRUCapacityPerServer: DefaultLRUCapacityPerServer, } plugin := New(config) diff --git a/site-src/guides/epp-configuration/prefix-aware.md b/site-src/guides/epp-configuration/prefix-aware.md index cbd2018f5..43e2ef064 100644 --- a/site-src/guides/epp-configuration/prefix-aware.md +++ b/site-src/guides/epp-configuration/prefix-aware.md @@ -32,11 +32,9 @@ extremely long inputs. 128 (or 128*64=8192 characters, or roughly 2048 tokens). This is useful to tradeoff prefix match accuracy for performance. -* `PREFIX_CACHE_LRU_CAPACITY`: Maximum capacity the prefix LRU indexer in number of block hashes. Below +* `PREFIX_CACHE_LRU_CAPACITY_PER_SERVER`: Maximum capacity the prefix LRU cache in number of block hashes per server (pod). Below shows a detailed analysis on how to estimate this. -* `PREFIX_MAX_PODS_PER_PREFIX`: Defines the maximum number of pods (servers) tracked per prefix hash in the internal LRU cache. -This setting helps optimize memory usage by retaining only the hottest (most recently active) pods for each prefix. -When the limit is reached, older pods are evicted based on least-recently-used (LRU) order. + The prefix cache plugin estimates the prefix cache indexes in model server HBMs. In the perfect @@ -68,7 +66,6 @@ When the limit is reached, older pods are evicted based on least-recently-used ( # assume avg_chars_per_token = 4, prefix_indexer_hash_block_size = 64 (default) # each entry is about 358KB, so the memory footrpint is abut 11 MB per server lru_indexer_capacity_per_server = 500,000*4/64 = 31250 - lru_indexer_capacity_total = 3 * 31250 = 93750 ``` See the [Use Helm section](#helm) to install an inferencepool with the environment variables. @@ -87,7 +84,7 @@ $ helm install triton-llama3-8b-instruct \ --set provider.name=[none|gke] \ --set inferenceExtension.env.EXPERIMENTAL_USE_SCHEDULER_V2=true \ --set inferenceExtension.env.ENABLE_PREFIX_CACHE_SCHEDULING=true \ - --set inferenceExtension.env.PREFIX_CACHE_LRU_CAPACITY=93750 \ + --set inferenceExtension.env.PREFIX_CACHE_LRU_CAPACITY_PER_SERVER=31250 \ --set inferenceExtension.env.PREFIX_CACHE_MAX_PREFIX_BLOCKS=1024 \ oci://us-central1-docker.pkg.dev/k8s-staging-images/gateway-api-inference-extension/charts/inferencepool --version v0 ``` From e7255c8f9e06eff8cb942b6c1df1890b2b38f943 Mon Sep 17 00:00:00 2001 From: Kfir Toledo Date: Thu, 12 Jun 2025 22:52:21 +0300 Subject: [PATCH 4/5] fix: Fix typos and error handle Signed-off-by: Kfir Toledo --- .../framework/plugins/multi/prefix/indexer.go | 36 ++++++++----------- .../plugins/multi/prefix/indexer_test.go | 19 +++++----- .../framework/plugins/multi/prefix/plugin.go | 20 ++++++----- 3 files changed, 37 insertions(+), 38 deletions(-) diff --git a/pkg/epp/scheduling/framework/plugins/multi/prefix/indexer.go b/pkg/epp/scheduling/framework/plugins/multi/prefix/indexer.go index d0c48d9f6..947e14836 100644 --- a/pkg/epp/scheduling/framework/plugins/multi/prefix/indexer.go +++ b/pkg/epp/scheduling/framework/plugins/multi/prefix/indexer.go @@ -32,8 +32,8 @@ import ( // prefix cached. type indexer struct { mu sync.RWMutex - hashToPods map[BlockHash]podSet // the lookup data structure to find pods that have the BlockHash cached - podToLRU map[string]*lru.Cache[BlockHash, struct{}] // key is pod namespacedName, value is an LRU cache + hashToPods map[BlockHash]podSet // the lookup data structure to find pods that have the BlockHash cached + podToLRU map[ServerID]*lru.Cache[BlockHash, struct{}] // key is pod namespacedName, value is an LRU cache maxLRUSize int } @@ -41,7 +41,7 @@ type indexer struct { func newIndexer(maxLRUSize int) *indexer { ix := &indexer{ hashToPods: make(map[BlockHash]podSet), - podToLRU: make(map[string]*lru.Cache[BlockHash, struct{}]), + podToLRU: make(map[ServerID]*lru.Cache[BlockHash, struct{}]), maxLRUSize: maxLRUSize, } go ix.ReportLRUSize(time.Second) @@ -49,17 +49,17 @@ func newIndexer(maxLRUSize int) *indexer { } // Add adds a list of prefix hashes to the cache, tied to the server. -func (i *indexer) Add(hashes []BlockHash, pod ServerID) { - if pod.Name == "" { - return - } +func (i *indexer) Add(hashes []BlockHash, pod ServerID) error { i.mu.Lock() // Check if the LRU pod exist - podName := pod.String() - lruForPod, exists := i.podToLRU[podName] + lruForPod, exists := i.podToLRU[pod] if !exists { - newLRU, _ := lru.NewWithEvict[BlockHash, struct{}](i.maxLRUSize, i.makeEvictionFn(pod)) - i.podToLRU[podName] = newLRU + newLRU, err := lru.NewWithEvict[BlockHash, struct{}](i.maxLRUSize, i.makeEvictionFn(pod)) + if err != nil { + i.mu.Unlock() + return fmt.Errorf("failed to create LRU for pod %s: %w", pod, err) + } + i.podToLRU[pod] = newLRU lruForPod = newLRU } i.mu.Unlock() @@ -80,7 +80,7 @@ func (i *indexer) Add(hashes []BlockHash, pod ServerID) { i.hashToPods[hash] = pods } i.mu.Unlock() - + return nil } // Get returns a set of servers that have the given prefix hash cached. @@ -100,21 +100,15 @@ func (i *indexer) Get(hash BlockHash) podSet { // makeEvictionFn returns a per-pod LRU eviction callback that removes the pod from hashToPods on eviction. func (i *indexer) makeEvictionFn(pod ServerID) func(BlockHash, struct{}) { return func(hash BlockHash, _ struct{}) { - fmt.Printf("Evicted hash %v from pod %s\n", hash, pod) - i.mu.Lock() defer i.mu.Unlock() - print("enter eviction") // Remove the pod from the hash→pods map if podSet, ok := i.hashToPods[hash]; ok { delete(podSet, pod) if len(podSet) == 0 { delete(i.hashToPods, hash) - } else { - i.hashToPods[hash] = podSet } } - print("After eviction") } } @@ -126,14 +120,14 @@ func (i *indexer) ReportLRUSize(interval time.Duration) { i.mu.RLock() totalEntries := 0 maxPodEntries := 0 - maxPodName := "" + maxPodName := ServerID{} - for podName, lruCache := range i.podToLRU { + for pod, lruCache := range i.podToLRU { size := lruCache.Len() totalEntries += size if size > maxPodEntries { maxPodEntries = size - maxPodName = podName + maxPodName = pod } } diff --git a/pkg/epp/scheduling/framework/plugins/multi/prefix/indexer_test.go b/pkg/epp/scheduling/framework/plugins/multi/prefix/indexer_test.go index ee1cb8b07..2a00ae324 100644 --- a/pkg/epp/scheduling/framework/plugins/multi/prefix/indexer_test.go +++ b/pkg/epp/scheduling/framework/plugins/multi/prefix/indexer_test.go @@ -26,21 +26,22 @@ func TestIndexer_AddAndGet(t *testing.T) { hash1 := BlockHash(1) server := ServerID{Namespace: "default", Name: "server1"} - serverName := server.String() // Add an entry to the cache - i.Add([]BlockHash{hash1}, server) + err := i.Add([]BlockHash{hash1}, server) + assert.NoError(t, err) + // Retrieve the entry - assert.Equal(t, 1, i.podToLRU[serverName].Len(), "Cache size should be 1 after adding an entry") + assert.Equal(t, 1, i.podToLRU[server].Len(), "Cache size should be 1 after adding an entry") servers := i.Get(hash1) assert.Contains(t, servers, server, "Cache should contain the added server") // Add another entry to the cache, the cache size should be incremented to 2. - i.Add([]BlockHash{BlockHash(2)}, server) - assert.Equal(t, 2, i.podToLRU[serverName].Len(), "Cache size should be 2 after adding an entry") + err = i.Add([]BlockHash{BlockHash(2)}, server) + assert.NoError(t, err) + assert.Equal(t, 2, i.podToLRU[server].Len(), "Cache size should be 2 after adding an entry") // Add another entry to the cache, which should evict the first one due to max size. - print("before Add") - i.Add([]BlockHash{BlockHash(3)}, server) - print("after ADD") - assert.Equal(t, 2, i.podToLRU[serverName].Len(), "Cache size should still be 2 after adding an entry") + err = i.Add([]BlockHash{BlockHash(3)}, server) + assert.NoError(t, err) + assert.Equal(t, 2, i.podToLRU[server].Len(), "Cache size should still be 2 after adding an entry") } diff --git a/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin.go b/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin.go index 31c991742..c22b337e1 100644 --- a/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin.go +++ b/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin.go @@ -44,11 +44,11 @@ const ( // A small capacity ensures a high accuracy of cache hit on the model server, but it will // increase the chance of false negatives. A high capacity does the opposite. // To properly size this, consider the sum of the total number of cache entries on all model - // servers. Consider the llama3 8B model on 8 H100 80GB GPUs. The size of the model weight is - // about 16GB. Assume 50% of the remaining HBM is used for caching prefixes, we have 32GB. Each - // token is about 128KB in size, so we can cache 250K tokens. Using the default block size of 16 - // in vLLM, we will have 250K / 16 = 15.6K blocks. - DefaultLRUCapacityPerServer = 15000 + // servers. Consider the llama3 8B model on a H100 80GB GPUs. The size of the model weight is + // about 16GB. The remaining HBM used for caching prefixes is 64GB. Each + // token is about 128KB in size, so we can cache 500K tokens. Using the default block size of 16 + // in vLLM, we will have 250K / 16 = 31.25K blocks. + DefaultLRUCapacityPerServer = 31250 ) type Config struct { @@ -58,7 +58,7 @@ type Config struct { // MaxPrefixBlocksToMatch is the maximum number of prefix blocks to match. Input beyond this limit will // be ignored. MaxPrefixBlocksToMatch int - // Max (approximate) size of the LRU indexer in number of entries per server (pod). + // Max capacity size of the LRU indexer in number of entries per server (pod). LRUCapacityPerServer int } @@ -72,7 +72,7 @@ type podSet map[ServerID]struct{} type Indexer interface { Get(hash BlockHash) podSet - Add(hashes []BlockHash, server ServerID) + Add(hashes []BlockHash, server ServerID) error } // BlockHash is a hash of the block of request body. @@ -165,7 +165,11 @@ func (m *Plugin) PostCycle(ctx context.Context, cycleState *types.CycleState, re log.FromContext(ctx).Error(err, "failed to read prefix plugin cycle state") return } - m.indexer.Add(state.PrefixHashes, ServerID(targetPod.NamespacedName)) + err = m.indexer.Add(state.PrefixHashes, ServerID(targetPod.NamespacedName)) + if err != nil { + log.FromContext(ctx).Error(err, "failed to add prefix hashes to indexer for target pod", "pod", targetPod.NamespacedName) + return + } total := len(state.PrefixHashes) matchLen := state.PrefixCacheServers[ServerID(targetPod.NamespacedName)] metrics.RecordPrefixCacheMatch(matchLen*m.HashBlockSize, total*m.HashBlockSize) From ca451679b9e06cd2168abc2c65c7c12aca0ecab3 Mon Sep 17 00:00:00 2001 From: Kfir Toledo Date: Sun, 15 Jun 2025 10:07:44 +0300 Subject: [PATCH 5/5] fix: add safety check for LRUCapacityPerServer Signed-off-by: Kfir Toledo --- .../framework/plugins/multi/prefix/indexer.go | 13 +++++------- .../plugins/multi/prefix/indexer_test.go | 9 +++----- .../framework/plugins/multi/prefix/plugin.go | 21 ++++++++++++------- .../plugins/multi/prefix/plugin_test.go | 7 ++++++- 4 files changed, 28 insertions(+), 22 deletions(-) diff --git a/pkg/epp/scheduling/framework/plugins/multi/prefix/indexer.go b/pkg/epp/scheduling/framework/plugins/multi/prefix/indexer.go index 947e14836..716c9f265 100644 --- a/pkg/epp/scheduling/framework/plugins/multi/prefix/indexer.go +++ b/pkg/epp/scheduling/framework/plugins/multi/prefix/indexer.go @@ -18,7 +18,6 @@ package prefix import ( "context" - "fmt" "sync" "time" @@ -44,24 +43,22 @@ func newIndexer(maxLRUSize int) *indexer { podToLRU: make(map[ServerID]*lru.Cache[BlockHash, struct{}]), maxLRUSize: maxLRUSize, } + go ix.ReportLRUSize(time.Second) return ix } // Add adds a list of prefix hashes to the cache, tied to the server. -func (i *indexer) Add(hashes []BlockHash, pod ServerID) error { +func (i *indexer) Add(hashes []BlockHash, pod ServerID) { i.mu.Lock() // Check if the LRU pod exist lruForPod, exists := i.podToLRU[pod] if !exists { - newLRU, err := lru.NewWithEvict[BlockHash, struct{}](i.maxLRUSize, i.makeEvictionFn(pod)) - if err != nil { - i.mu.Unlock() - return fmt.Errorf("failed to create LRU for pod %s: %w", pod, err) - } + newLRU, _ := lru.NewWithEvict[BlockHash, struct{}](i.maxLRUSize, i.makeEvictionFn(pod)) i.podToLRU[pod] = newLRU lruForPod = newLRU } + i.mu.Unlock() // Add to LRU (may evict) @@ -79,8 +76,8 @@ func (i *indexer) Add(hashes []BlockHash, pod ServerID) error { pods[pod] = struct{}{} i.hashToPods[hash] = pods } + i.mu.Unlock() - return nil } // Get returns a set of servers that have the given prefix hash cached. diff --git a/pkg/epp/scheduling/framework/plugins/multi/prefix/indexer_test.go b/pkg/epp/scheduling/framework/plugins/multi/prefix/indexer_test.go index 2a00ae324..240985033 100644 --- a/pkg/epp/scheduling/framework/plugins/multi/prefix/indexer_test.go +++ b/pkg/epp/scheduling/framework/plugins/multi/prefix/indexer_test.go @@ -27,8 +27,7 @@ func TestIndexer_AddAndGet(t *testing.T) { hash1 := BlockHash(1) server := ServerID{Namespace: "default", Name: "server1"} // Add an entry to the cache - err := i.Add([]BlockHash{hash1}, server) - assert.NoError(t, err) + i.Add([]BlockHash{hash1}, server) // Retrieve the entry assert.Equal(t, 1, i.podToLRU[server].Len(), "Cache size should be 1 after adding an entry") @@ -36,12 +35,10 @@ func TestIndexer_AddAndGet(t *testing.T) { assert.Contains(t, servers, server, "Cache should contain the added server") // Add another entry to the cache, the cache size should be incremented to 2. - err = i.Add([]BlockHash{BlockHash(2)}, server) - assert.NoError(t, err) + i.Add([]BlockHash{BlockHash(2)}, server) assert.Equal(t, 2, i.podToLRU[server].Len(), "Cache size should be 2 after adding an entry") // Add another entry to the cache, which should evict the first one due to max size. - err = i.Add([]BlockHash{BlockHash(3)}, server) - assert.NoError(t, err) + i.Add([]BlockHash{BlockHash(3)}, server) assert.Equal(t, 2, i.podToLRU[server].Len(), "Cache size should still be 2 after adding an entry") } diff --git a/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin.go b/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin.go index c22b337e1..0d40746f3 100644 --- a/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin.go +++ b/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin.go @@ -72,7 +72,7 @@ type podSet map[ServerID]struct{} type Indexer interface { Get(hash BlockHash) podSet - Add(hashes []BlockHash, server ServerID) error + Add(hashes []BlockHash, server ServerID) } // BlockHash is a hash of the block of request body. @@ -115,9 +115,18 @@ var _ framework.PostCycle = &Plugin{} // New initializes a new prefix Plugin and returns its pointer. func New(config Config) *Plugin { + capacity := config.LRUCapacityPerServer + if capacity <= 0 { + capacity = DefaultLRUCapacityPerServer + log.FromContext(context.TODO()).V(logutil.DEFAULT).Info( + "LRUCapacityPerServer is not positive, using default value", + "defaultCapacity", DefaultLRUCapacityPerServer, + ) + } + m := &Plugin{ Config: config, - indexer: newIndexer(config.LRUCapacityPerServer), + indexer: newIndexer(capacity), } return m } @@ -165,11 +174,9 @@ func (m *Plugin) PostCycle(ctx context.Context, cycleState *types.CycleState, re log.FromContext(ctx).Error(err, "failed to read prefix plugin cycle state") return } - err = m.indexer.Add(state.PrefixHashes, ServerID(targetPod.NamespacedName)) - if err != nil { - log.FromContext(ctx).Error(err, "failed to add prefix hashes to indexer for target pod", "pod", targetPod.NamespacedName) - return - } + + m.indexer.Add(state.PrefixHashes, ServerID(targetPod.NamespacedName)) + total := len(state.PrefixHashes) matchLen := state.PrefixCacheServers[ServerID(targetPod.NamespacedName)] metrics.RecordPrefixCacheMatch(matchLen*m.HashBlockSize, total*m.HashBlockSize) diff --git a/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin_test.go b/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin_test.go index 4c08f18a4..db1feacf4 100644 --- a/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin_test.go +++ b/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin_test.go @@ -154,7 +154,12 @@ func BenchmarkPrefixPluginStress(b *testing.B) { plugin := New(config) types.NewCycleState() - promptLen := []int{10, 20, 50, 100, 200, 500, 1000, 2000, 5000, 10000, 20000, 50000} + var promptLen []int + for i := 1; i <= 1024; i++ { + promptLen = append(promptLen, i) + } + promptLen = append(promptLen, 2048, 4096, 8192, 10000, 20000, 50000) + for _, i := range promptLen { // Generate increasing-length random prompts prompt := randomPrompt(4 + i)