diff --git a/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin.go b/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin.go index f0f44d51e3..b5b8c55bbc 100644 --- a/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin.go +++ b/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin.go @@ -55,6 +55,12 @@ const ( // token is about 128KB in size, so we can cache 500K tokens. Using the default block size of 16 // in vLLM, we will have 500K / 16 = 31.25K blocks. DefaultLRUCapacityPerServer = 31250 + // In P/D disaggregation mode, the prefill and decode are usually represented as two different scheduling profiles to pick + // the prefill and decode endpoints. This constant defines the prefill profile name to ensure that the index is updated + // for the prefill endpoint and not only for the primary endpoint that will initially handle the request. + // This is hardcoded for now until we land on a canonical approach for plugins to identify prefill and decode endpoints + // (See https://github.com/kubernetes-sigs/gateway-api-inference-extension/issues/2080) + Experimental_DefaultPrefillProfile = "prefill" PrefixCachePluginType = "prefix-cache-scorer" ) @@ -269,10 +275,10 @@ func (p *Plugin) Score(ctx context.Context, cycleState *types.CycleState, reques func (p *Plugin) PreRequest(ctx context.Context, request *types.LLMRequest, schedulingResult *types.SchedulingResult) { primaryProfileResult := schedulingResult.ProfileResults[schedulingResult.PrimaryProfileName] targetPod := primaryProfileResult.TargetPods[0] // get the first pod of the primary profile + servers := []Server{p.makeServer(targetPod)} - gpuBlocks := p.config.LRUCapacityPerServer - if p.config.AutoTune && targetPod.GetMetrics().CacheNumGPUBlocks > 0 { - gpuBlocks = targetPod.GetMetrics().CacheNumGPUBlocks + if pr, exists := schedulingResult.ProfileResults[Experimental_DefaultPrefillProfile]; exists && len(pr.TargetPods) > 0 { + servers = append(servers, p.makeServer(pr.TargetPods[0])) } state, 
err := plugins.ReadPluginStateKey[*SchedulingContextState](p.pluginState, request.RequestId, plugins.StateKey(p.TypedName().String())) @@ -288,10 +294,9 @@ func (p *Plugin) PreRequest(ctx context.Context, request *types.LLMRequest, sche // WaitGroup is added to the Plugin struct to allow waiting in tests. p.wg.Add(1) go func() { - p.indexer.Add(state.PrefixHashes, Server{ - ServerID(targetPod.GetPod().NamespacedName), - gpuBlocks, - }) + for _, s := range servers { + p.indexer.Add(state.PrefixHashes, s) + } p.wg.Done() }() @@ -302,6 +307,17 @@ func (p *Plugin) PreRequest(ctx context.Context, request *types.LLMRequest, sche metrics.RecordPrefixCacheMatch(matchLen*blockSize, total*blockSize) } +func (p *Plugin) makeServer(targetPod types.Pod) Server { + gpuBlocks := p.config.LRUCapacityPerServer + if p.config.AutoTune && targetPod.GetMetrics().CacheNumGPUBlocks > 0 { + gpuBlocks = targetPod.GetMetrics().CacheNumGPUBlocks + } + return Server{ + ServerID(targetPod.GetPod().NamespacedName), + gpuBlocks, + } +} + // matchLongestPrefix returns a map of servers and length of prefix that each server caches. 
func (p *Plugin) matchLongestPrefix(ctx context.Context, hashes []BlockHash) map[ServerID]int { loggerTrace := log.FromContext(ctx).V(logutil.TRACE) diff --git a/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin_test.go b/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin_test.go index 31cb51edba..f6400fab1b 100644 --- a/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin_test.go +++ b/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin_test.go @@ -49,7 +49,8 @@ func TestPrefixPluginCompletion(t *testing.T) { pod1 := &types.PodMetrics{Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod1"}}, MetricsState: backendmetrics.NewMetricsState()} pod2 := &types.PodMetrics{Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod2"}}, MetricsState: backendmetrics.NewMetricsState()} - pods := []types.Pod{pod1, pod2} + pod3 := &types.PodMetrics{Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod3"}}, MetricsState: backendmetrics.NewMetricsState()} + pods := []types.Pod{pod1, pod2, pod3} // First request. req1 := &types.LLMRequest{ @@ -72,11 +73,12 @@ func TestPrefixPluginCompletion(t *testing.T) { assert.Equal(t, float64(0), scores[pod1], "score for pod1") assert.Equal(t, float64(0), scores[pod2], "score for pod2") - // Simulate pod1 was picked. + // Simulate pod1 was picked and pod3 was picked as a prefill node. schedulingResult := &types.SchedulingResult{ PrimaryProfileName: "default", ProfileResults: map[string]*types.ProfileRunResult{ - "default": {TargetPods: []types.Pod{pod1}}, + "default": {TargetPods: []types.Pod{pod1}}, + Experimental_DefaultPrefillProfile: {TargetPods: []types.Pod{pod3}}, }, } plugin.PreRequest(context.Background(), req1, schedulingResult) @@ -131,8 +133,9 @@ func TestPrefixPluginCompletion(t *testing.T) { // Input size is 8, hash block size is 4, so 2 hashes will be calculated. 
// Total hashes = 2 (the first one is for the prefix with model) assert.Equal(t, 2, len(state.PrefixHashes), "number of hashes is incorrect") - assert.Equal(t, 1, len(state.PrefixCacheServers), "pod1 should have cached the aaaa prefix") + assert.Equal(t, 2, len(state.PrefixCacheServers), "pod1 and pod3 should have cached the aaaa prefix") assert.Equal(t, 0.5, scores[pod1], "score should be 0.5 - the model and the first prefix block match") + assert.Equal(t, 0.5, scores[pod3], "score should be 0.5 - the model and the first prefix block match on the prefill node") assert.Equal(t, float64(0), scores[pod2], "score for pod2") schedulingResult = &types.SchedulingResult{ @@ -191,7 +194,7 @@ func TestPrefixPluginCompletion(t *testing.T) { // Input size is 12, hash block size is 4, so 3 hashes will be calculated. // Total hashes = 3 (the first one is for the prefix with model) assert.Equal(t, 3, len(state.PrefixHashes), "number of hashes is incorrect") - assert.Equal(t, 1, len(state.PrefixCacheServers), "pod1 should have cached the aaaa prefix") + assert.Equal(t, 2, len(state.PrefixCacheServers), "pod1 and pod3 should have cached the aaaa prefix") assert.Equal(t, 2./3, scores[pod1], "score should be 2./3 - the model and the first 2 prefix blocks match") assert.Equal(t, float64(0), scores[pod2], "score for pod2")