From 85b8a196e5979ad17d0494d1efc6c0dad29b4922 Mon Sep 17 00:00:00 2001 From: HyunKyun Moon Date: Fri, 19 Dec 2025 08:34:21 +0000 Subject: [PATCH 1/2] Store the precise prefix cache score in cycleState. Signed-off-by: HyunKyun Moon --- pkg/plugins/profile/pd_profile_handler.go | 15 +++++++---- .../profile/pd_profile_handler_test.go | 8 ++++++ pkg/plugins/scorer/no_hit_lru.go | 25 ++++++++++++------- pkg/plugins/scorer/no_hit_lru_test.go | 18 ++++++++----- pkg/plugins/scorer/precise_prefix_cache.go | 16 +++++++++++- pkg/scheduling/pd/scheduler_test.go | 2 +- 6 files changed, 62 insertions(+), 22 deletions(-) diff --git a/pkg/plugins/profile/pd_profile_handler.go b/pkg/plugins/profile/pd_profile_handler.go index a3fe3e75d3..d470acb5e0 100644 --- a/pkg/plugins/profile/pd_profile_handler.go +++ b/pkg/plugins/profile/pd_profile_handler.go @@ -27,13 +27,14 @@ const ( defaultDecodeProfile = "decode" defaultPrefillProfile = "prefill" - defaultPrefixPluginName = prefix.PrefixCachePluginType + defaultPrefixPluginType = prefix.PrefixCachePluginType ) type pdProfileHandlerParameters struct { Threshold int `json:"threshold"` DecodeProfile string `json:"decodeProfile"` PrefillProfile string `json:"prefillProfile"` + PrefixPluginType string `json:"prefixPluginType"` PrefixPluginName string `json:"prefixPluginName"` HashBlockSize int `json:"hashBlockSize"` PrimaryPort int `json:"primaryPort"` @@ -48,7 +49,7 @@ func PdProfileHandlerFactory(name string, rawParameters json.RawMessage, _ plugi Threshold: 0, DecodeProfile: defaultDecodeProfile, PrefillProfile: defaultPrefillProfile, - PrefixPluginName: defaultPrefixPluginName, + PrefixPluginType: defaultPrefixPluginType, HashBlockSize: prefix.DefaultBlockSize, PrimaryPort: 0, } @@ -58,6 +59,10 @@ func PdProfileHandlerFactory(name string, rawParameters json.RawMessage, _ plugi } } + if parameters.PrefixPluginName == "" { + parameters.PrefixPluginName = parameters.PrefixPluginType + } + if parameters.Threshold < 0 { return nil, fmt.Errorf("invalid threshold: must be >= 0, got %d", parameters.Threshold) } @@ -72,15 +77,15 @@ func PdProfileHandlerFactory(name string, rawParameters json.RawMessage, _ plugi } } - return NewPdProfileHandler(parameters.PrefillProfile, parameters.DecodeProfile, parameters.PrefixPluginName, + return NewPdProfileHandler(parameters.PrefillProfile, parameters.DecodeProfile, parameters.PrefixPluginType, parameters.PrefixPluginName, parameters.Threshold, parameters.HashBlockSize, parameters.PrimaryPort).WithName(name), nil } // NewPdProfileHandler initializes a new PdProfileHandler and returns its pointer. -func NewPdProfileHandler(prefillProfile string, decodeProfile string, prefixPluginName string, pdThreshold int, hashBlockSize int, primaryPort int) *PdProfileHandler { +func NewPdProfileHandler(prefillProfile, decodeProfile, prefixPluginType, prefixPluginName string, pdThreshold, hashBlockSize, primaryPort int) *PdProfileHandler { result := &PdProfileHandler{ typedName: plugins.TypedName{Type: PdProfileHandlerType}, - prefixPluginTypedName: plugins.TypedName{Type: prefix.PrefixCachePluginType, Name: prefixPluginName}, + prefixPluginTypedName: plugins.TypedName{Type: prefixPluginType, Name: prefixPluginName}, decodeProfile: decodeProfile, prefillProfile: prefillProfile, pdThreshold: pdThreshold, diff --git a/pkg/plugins/profile/pd_profile_handler_test.go b/pkg/plugins/profile/pd_profile_handler_test.go index fc78b62121..c6661bcf9c 100644 --- a/pkg/plugins/profile/pd_profile_handler_test.go +++ b/pkg/plugins/profile/pd_profile_handler_test.go @@ -232,6 +232,7 @@ func TestPdProfileHandler_Pick(t *testing.T) { name string pdThreshold int hashBlockSize int + prefixPluginType string prefixPluginName string setupPrefixState func(*types.CycleState) profileResults map[string]*types.ProfileRunResult @@ -241,6 +242,7 @@ func TestPdProfileHandler_Pick(t *testing.T) { name: "decode not executed yet → run decode", pdThreshold: 100, hashBlockSize: 16, + prefixPluginType: prefix.PrefixCachePluginType, prefixPluginName: prefix.PrefixCachePluginType, profileResults: map[string]*types.ProfileRunResult{}, expectedProfiles: []string{"decode"}, @@ -249,6 +251,7 @@ func TestPdProfileHandler_Pick(t *testing.T) { name: "decode failed (nil result) → run nothing", pdThreshold: 100, hashBlockSize: 16, + prefixPluginType: prefix.PrefixCachePluginType, prefixPluginName: prefix.PrefixCachePluginType, profileResults: map[string]*types.ProfileRunResult{ "decode": nil, @@ -259,6 +262,7 @@ func TestPdProfileHandler_Pick(t *testing.T) { name: "all profiles already executed → run nothing", pdThreshold: 100, hashBlockSize: 16, + prefixPluginType: prefix.PrefixCachePluginType, prefixPluginName: prefix.PrefixCachePluginType, profileResults: map[string]*types.ProfileRunResult{ "decode": newMockProfileRunResult(DefaultTestPodPort, "pod1"), @@ -270,6 +274,7 @@ func TestPdProfileHandler_Pick(t *testing.T) { name: "pd threshold NOT triggered → run prefill", pdThreshold: 5, hashBlockSize: 16, + prefixPluginType: prefix.PrefixCachePluginType, prefixPluginName: prefix.PrefixCachePluginType, setupPrefixState: func(cs *types.CycleState) { state := &prefix.SchedulingContextState{ @@ -289,6 +294,7 @@ func TestPdProfileHandler_Pick(t *testing.T) { name: "pd threshold triggered (short non-cached suffix) → skip prefill", pdThreshold: 100, hashBlockSize: 16, + prefixPluginType: prefix.PrefixCachePluginType, prefixPluginName: prefix.PrefixCachePluginType, setupPrefixState: func(cs *types.CycleState) { state := &prefix.SchedulingContextState{ @@ -311,6 +317,7 @@ func TestPdProfileHandler_Pick(t *testing.T) { handler := NewPdProfileHandler( "prefill", "decode", + tt.prefixPluginType, tt.prefixPluginName, tt.pdThreshold, tt.hashBlockSize, @@ -402,6 +409,7 @@ func TestPdProfileHandler_ProcessResults(t *testing.T) { "prefill", "decode", prefix.PrefixCachePluginType, + prefix.PrefixCachePluginType, 0, prefix.DefaultBlockSize, tt.primaryPort, diff --git a/pkg/plugins/scorer/no_hit_lru.go b/pkg/plugins/scorer/no_hit_lru.go index 67c689a3bd..417cf05a56 100644 --- a/pkg/plugins/scorer/no_hit_lru.go +++ b/pkg/plugins/scorer/no_hit_lru.go @@ -29,6 +29,9 @@ var _ requestcontrol.PreRequest = &NoHitLRU{} // NoHitLRUParameters defines the parameters for the NoHitLRU scorer. type NoHitLRUParameters struct { + // PrefixPluginType defines the type of the prefix cache plugin to read state from. + // Defaults to "prefix-cache-scorer". + PrefixPluginType string `json:"prefixPluginType"` // PrefixPluginName defines the name of the prefix cache plugin to read state from. // Defaults to "prefix-cache-scorer". PrefixPluginName string `json:"prefixPluginName"` @@ -69,10 +72,14 @@ func NoHitLRUFactory(name string, rawParameters json.RawMessage, handle plugins. // NewNoHitLRU creates a new NoHitLRU scorer func NewNoHitLRU(ctx context.Context, params *NoHitLRUParameters) *NoHitLRU { + prefixPluginType := prefix.PrefixCachePluginType prefixPluginName := prefix.PrefixCachePluginType lruSize := defaultLRUSize if params != nil { + if params.PrefixPluginType != "" { + prefixPluginType = params.PrefixPluginType + } if params.PrefixPluginName != "" { prefixPluginName = params.PrefixPluginName } @@ -88,10 +95,10 @@ func NewNoHitLRU(ctx context.Context, params *NoHitLRUParameters) *NoHitLRU { } return &NoHitLRU{ - typedName: plugins.TypedName{Type: NoHitLRUType}, - lruCache: lruCache, - prefixPluginName: prefixPluginName, - pluginState: plugins.NewPluginState(ctx), + typedName: plugins.TypedName{Type: NoHitLRUType}, + lruCache: lruCache, + prefixPluginTypedName: plugins.TypedName{Type: prefixPluginType, Name: prefixPluginName}, + pluginState: plugins.NewPluginState(ctx), } } @@ -99,10 +106,10 @@ func NewNoHitLRU(ctx context.Context, params *NoHitLRUParameters) *NoHitLRU { // This can help evenly distribute cache growth, since cold requests result in more // new KV blocks. type NoHitLRU struct { - typedName plugins.TypedName - lruCache *lru.Cache[string, struct{}] // pod name -> dummy value (we only care about order) - prefixPluginName string - pluginState *plugins.PluginState + typedName plugins.TypedName + lruCache *lru.Cache[string, struct{}] // pod name -> dummy value (we only care about order) + prefixPluginTypedName plugins.TypedName + pluginState *plugins.PluginState } // TypedName returns the typed name of the plugin. @@ -123,7 +130,7 @@ func (s *NoHitLRU) isColdRequest(ctx context.Context, cycleState *types.CycleSta // Read prefix cache state to determine if this is a cold request // This is treated as an optimization - if the state isn't available, we assume cold request - prefixState, err := types.ReadCycleStateKey[*prefix.SchedulingContextState](cycleState, plugins.StateKey(s.prefixPluginName)) + prefixState, err := types.ReadCycleStateKey[*prefix.SchedulingContextState](cycleState, plugins.StateKey(s.prefixPluginTypedName.String())) if err != nil { logger.Info("No prefix cache state found, treating as cold request for LRU optimization", "error", err) diff --git a/pkg/plugins/scorer/no_hit_lru_test.go b/pkg/plugins/scorer/no_hit_lru_test.go index 74eba957f0..fe02492f77 100644 --- a/pkg/plugins/scorer/no_hit_lru_test.go +++ b/pkg/plugins/scorer/no_hit_lru_test.go @@ -202,7 +202,8 @@ func TestNoHitLRUScorer(t *testing.T) { // Create cycle state and set prefix state cycleState := &types.CycleState{} if test.prefixState != nil { - cycleState.Write(plugins.StateKey(prefix.PrefixCachePluginType), test.prefixState) + cycleState.Write(plugins.StateKey(plugins.TypedName{Type: prefix.PrefixCachePluginType, + Name: prefix.PrefixCachePluginType}.String()), test.prefixState) } got := test.scorer.Score(context.Background(), cycleState, test.req, test.input) @@ -234,7 +235,8 @@ func TestNoHitLRUBasicFunctionality(t *testing.T) { PrefixCacheServers: make(map[prefix.ServerID]int), // empty = cold request } cycleState := &types.CycleState{} - cycleState.Write(plugins.StateKey(prefix.PrefixCachePluginType), coldPrefixState) + cycleState.Write(plugins.StateKey(plugins.TypedName{Type: prefix.PrefixCachePluginType, + Name: prefix.PrefixCachePluginType}.String()), coldPrefixState) scores := scorer.Score(ctx, cycleState, &types.LLMRequest{}, pods) @@ -295,7 +297,8 @@ func TestNoHitLRUPreferLeastRecentlyUsedAfterColdRequests(t *testing.T) { primaryProfile := "primary-profile" toPrefixState := func(entries map[prefix.ServerID]int) *types.CycleState { cycle := &types.CycleState{} - cycle.Write(plugins.StateKey(prefix.PrefixCachePluginType), &prefix.SchedulingContextState{PrefixCacheServers: entries}) + cycle.Write(plugins.StateKey(plugins.TypedName{Type: prefix.PrefixCachePluginType, + Name: prefix.PrefixCachePluginType}.String()), &prefix.SchedulingContextState{PrefixCacheServers: entries}) return cycle } @@ -406,7 +409,8 @@ func TestNoHitLRUEdgeCases(t *testing.T) { t.Run("empty pods list", func(t *testing.T) { emptyPods := []types.Pod{} cycleState := &types.CycleState{} - cycleState.Write(plugins.StateKey(prefix.PrefixCachePluginType), &prefix.SchedulingContextState{ + cycleState.Write(plugins.StateKey(plugins.TypedName{Type: prefix.PrefixCachePluginType, + Name: prefix.PrefixCachePluginType}.String()), &prefix.SchedulingContextState{ PrefixCacheServers: make(map[prefix.ServerID]int), // cold request }) @@ -419,7 +423,8 @@ func TestNoHitLRUEdgeCases(t *testing.T) { t.Run("nil pods list", func(t *testing.T) { cycleState := &types.CycleState{} - cycleState.Write(plugins.StateKey(prefix.PrefixCachePluginType), &prefix.SchedulingContextState{ + cycleState.Write(plugins.StateKey(plugins.TypedName{Type: prefix.PrefixCachePluginType, + Name: prefix.PrefixCachePluginType}.String()), &prefix.SchedulingContextState{ PrefixCacheServers: make(map[prefix.ServerID]int), // cold request }) @@ -436,7 +441,8 @@ func TestNoHitLRUEdgeCases(t *testing.T) { t.Run("single pod returns 1.0", func(t *testing.T) { pods := []types.Pod{podA} cycleState := &types.CycleState{} - cycleState.Write(plugins.StateKey(prefix.PrefixCachePluginType), &prefix.SchedulingContextState{ + cycleState.Write(plugins.StateKey(plugins.TypedName{Type: prefix.PrefixCachePluginType, + Name: prefix.PrefixCachePluginType}.String()), &prefix.SchedulingContextState{ PrefixCacheServers: make(map[prefix.ServerID]int), // cold request }) diff --git a/pkg/plugins/scorer/precise_prefix_cache.go b/pkg/plugins/scorer/precise_prefix_cache.go index 9f6866c2ca..2ce6551c0f 100644 --- a/pkg/plugins/scorer/precise_prefix_cache.go +++ b/pkg/plugins/scorer/precise_prefix_cache.go @@ -13,6 +13,7 @@ import ( "sigs.k8s.io/controller-runtime/pkg/log" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework/plugins/multi/prefix" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" ) @@ -125,7 +126,7 @@ func (s *PrecisePrefixCacheScorer) WithName(name string) *PrecisePrefixCacheScor // Score scores the provided pod based on the KVCache index state. // The returned scores are normalized to a range of 0-1. -func (s *PrecisePrefixCacheScorer) Score(ctx context.Context, _ *types.CycleState, request *types.LLMRequest, pods []types.Pod) map[types.Pod]float64 { +func (s *PrecisePrefixCacheScorer) Score(ctx context.Context, cycleState *types.CycleState, request *types.LLMRequest, pods []types.Pod) map[types.Pod]float64 { logger := log.FromContext(ctx).WithName(s.typedName.String()) debugLogger := logger.V(logutil.DEBUG) @@ -150,6 +151,19 @@ func (s *PrecisePrefixCacheScorer) Score(ctx context.Context, _ *types.CycleStat return metricsPod.Address, true } + state := &prefix.SchedulingContextState{ + PrefixHashes: []prefix.BlockHash{}, + PrefixCacheServers: map[prefix.ServerID]int{}, + } + for _, pod := range pods { + key, ok := podToKey(pod) + if !ok { + continue + } + state.PrefixCacheServers[prefix.ServerID(pod.GetPod().NamespacedName)] = int(scores[key]) + } + cycleState.Write(plugins.StateKey(s.typedName.String()), state) + return indexedScoresToNormalizedScoredPods(pods, podToKey, scores) } diff --git a/pkg/scheduling/pd/scheduler_test.go b/pkg/scheduling/pd/scheduler_test.go index 25e5945ad2..bd1f6b1c11 100644 --- a/pkg/scheduling/pd/scheduler_test.go +++ b/pkg/scheduling/pd/scheduler_test.go @@ -247,7 +247,7 @@ func TestPDSchedule(t *testing.T) { err = decodeSchedulerProfile.AddPlugins(framework.NewWeightedScorer(prefixScorer, 0)) assert.NoError(t, err, "SchedulerProfile AddPlugins returned unexpected error") - profileHandle := profile.NewPdProfileHandler(prefill, decode, prefixScorer.TypedName().Name, 10, 5, 0) + profileHandle := profile.NewPdProfileHandler(prefill, decode, prefixScorer.TypedName().Type, prefixScorer.TypedName().Name, 10, 5, 0) schedulerConfig := scheduling.NewSchedulerConfig(profileHandle, map[string]*framework.SchedulerProfile{ prefill: prefillSchedulerProfile, From 8232e67cce8ffe061acce1ae093ef5ac8c3950c1 Mon Sep 17 00:00:00 2001 From: HyunKyun Moon Date: Mon, 22 Dec 2025 00:10:05 +0900 Subject: [PATCH 2/2] edit test code Signed-off-by: HyunKyun Moon --- .gitignore | 3 +++ pkg/plugins/scorer/precise_prefix_cache_test.go | 2 +- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index dbb7ca444e..f94c6c6ce6 100644 --- a/.gitignore +++ b/.gitignore @@ -33,3 +33,6 @@ go.work.sum .vscode vendor + +# Build output +/build diff --git a/pkg/plugins/scorer/precise_prefix_cache_test.go b/pkg/plugins/scorer/precise_prefix_cache_test.go index e033228d90..00887cc7bf 100644 --- a/pkg/plugins/scorer/precise_prefix_cache_test.go +++ b/pkg/plugins/scorer/precise_prefix_cache_test.go @@ -583,7 +583,7 @@ func TestPrefixCacheTracking_Score(t *testing.T) { } } - got := prefixCacheScorer.Score(ctx, nil, tt.request, tt.pods) + got := prefixCacheScorer.Score(ctx, types.NewCycleState(), tt.request, tt.pods) gotByAddress := make(map[string]float64) for pod, score := range got {