Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 23 additions & 7 deletions pkg/epp/scheduling/framework/plugins/multi/prefix/plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,12 @@ const (
// token is about 128KB in size, so we can cache 500K tokens. Using the default block size of 16
// in vLLM, we will have 500K / 16 = 31.25K blocks.
DefaultLRUCapacityPerServer = 31250
// In P/D (prefill/decode) disaggregation mode, prefill and decode are usually represented as two separate scheduling
// profiles, one picking the prefill endpoint and one picking the decode endpoint. This constant defines the prefill
// profile name so that the index is also updated for the prefill endpoint, not only for the primary endpoint that
// initially handles the request.
// The name is hardcoded for now, until we land on a canonical approach for plugins to identify prefill and decode
// endpoints (see https://github.com/kubernetes-sigs/gateway-api-inference-extension/issues/2080).
Experimental_DefaultPrefillProfile = "prefill"

PrefixCachePluginType = "prefix-cache-scorer"
)
Expand Down Expand Up @@ -269,10 +275,10 @@ func (p *Plugin) Score(ctx context.Context, cycleState *types.CycleState, reques
func (p *Plugin) PreRequest(ctx context.Context, request *types.LLMRequest, schedulingResult *types.SchedulingResult) {
primaryProfileResult := schedulingResult.ProfileResults[schedulingResult.PrimaryProfileName]
targetPod := primaryProfileResult.TargetPods[0] // get the first pod of the primary profile
servers := []Server{p.makeServer(targetPod)}

gpuBlocks := p.config.LRUCapacityPerServer
if p.config.AutoTune && targetPod.GetMetrics().CacheNumGPUBlocks > 0 {
gpuBlocks = targetPod.GetMetrics().CacheNumGPUBlocks
if pr, exists := schedulingResult.ProfileResults[Experimental_DefaultPrefillProfile]; exists && len(pr.TargetPods) > 0 {
servers = append(servers, p.makeServer(pr.TargetPods[0]))
}

state, err := plugins.ReadPluginStateKey[*SchedulingContextState](p.pluginState, request.RequestId, plugins.StateKey(p.TypedName().String()))
Expand All @@ -288,10 +294,9 @@ func (p *Plugin) PreRequest(ctx context.Context, request *types.LLMRequest, sche
// WaitGroup is added to the Plugin struct to allow waiting in tests.
p.wg.Add(1)
go func() {
p.indexer.Add(state.PrefixHashes, Server{
ServerID(targetPod.GetPod().NamespacedName),
gpuBlocks,
})
for _, s := range servers {
p.indexer.Add(state.PrefixHashes, s)
}
p.wg.Done()
}()

Expand All @@ -302,6 +307,17 @@ func (p *Plugin) PreRequest(ctx context.Context, request *types.LLMRequest, sche
metrics.RecordPrefixCacheMatch(matchLen*blockSize, total*blockSize)
}

// makeServer builds the Server entry used to index prefixes for targetPod.
//
// The server is identified by the pod's NamespacedName. Its block capacity
// defaults to the configured LRUCapacityPerServer; when AutoTune is enabled
// and the pod's metrics report a positive CacheNumGPUBlocks, that reported
// value is used instead.
func (p *Plugin) makeServer(targetPod types.Pod) Server {
	capacity := p.config.LRUCapacityPerServer

	// Metrics are consulted only when auto-tuning is on; a non-positive
	// reported block count keeps the configured default.
	autoTuned := p.config.AutoTune && targetPod.GetMetrics().CacheNumGPUBlocks > 0
	if autoTuned {
		capacity = targetPod.GetMetrics().CacheNumGPUBlocks
	}

	id := ServerID(targetPod.GetPod().NamespacedName)
	return Server{id, capacity}
}

// matchLongestPrefix returns a map of servers and length of prefix that each server caches.
func (p *Plugin) matchLongestPrefix(ctx context.Context, hashes []BlockHash) map[ServerID]int {
loggerTrace := log.FromContext(ctx).V(logutil.TRACE)
Expand Down
13 changes: 8 additions & 5 deletions pkg/epp/scheduling/framework/plugins/multi/prefix/plugin_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,8 @@ func TestPrefixPluginCompletion(t *testing.T) {

pod1 := &types.PodMetrics{Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod1"}}, MetricsState: backendmetrics.NewMetricsState()}
pod2 := &types.PodMetrics{Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod2"}}, MetricsState: backendmetrics.NewMetricsState()}
pods := []types.Pod{pod1, pod2}
pod3 := &types.PodMetrics{Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod3"}}, MetricsState: backendmetrics.NewMetricsState()}
pods := []types.Pod{pod1, pod2, pod3}

// First request.
req1 := &types.LLMRequest{
Expand All @@ -72,11 +73,12 @@ func TestPrefixPluginCompletion(t *testing.T) {
assert.Equal(t, float64(0), scores[pod1], "score for pod1")
assert.Equal(t, float64(0), scores[pod2], "score for pod2")

// Simulate pod1 was picked.
// Simulate pod1 was picked and pod3 was picked as a prefill node.
schedulingResult := &types.SchedulingResult{
PrimaryProfileName: "default",
ProfileResults: map[string]*types.ProfileRunResult{
"default": {TargetPods: []types.Pod{pod1}},
"default": {TargetPods: []types.Pod{pod1}},
Experimental_DefaultPrefillProfile: {TargetPods: []types.Pod{pod3}},
},
}
plugin.PreRequest(context.Background(), req1, schedulingResult)
Expand Down Expand Up @@ -131,8 +133,9 @@ func TestPrefixPluginCompletion(t *testing.T) {
// Input size is 8, hash block size is 4, so 2 hashes will be calculated.
// Total hashes = 2 (the first one is for the prefix with model)
assert.Equal(t, 2, len(state.PrefixHashes), "number of hashes is incorrect")
assert.Equal(t, 1, len(state.PrefixCacheServers), "pod1 should have cached the aaaa prefix")
assert.Equal(t, 2, len(state.PrefixCacheServers), "pod1 and pod3 should have cached the aaaa prefix")
assert.Equal(t, 0.5, scores[pod1], "score should be 0.5 - the model and the first prefix block match")
assert.Equal(t, 0.5, scores[pod3], "score should be 0.5 - the model and the first prefix block match on the prefill node")
assert.Equal(t, float64(0), scores[pod2], "score for pod2")

schedulingResult = &types.SchedulingResult{
Expand Down Expand Up @@ -191,7 +194,7 @@ func TestPrefixPluginCompletion(t *testing.T) {
// Input size is 12, hash block size is 4, so 3 hashes will be calculated.
// Total hashes = 3 (the first one is for the prefix with model)
assert.Equal(t, 3, len(state.PrefixHashes), "number of hashes is incorrect")
assert.Equal(t, 1, len(state.PrefixCacheServers), "pod1 should have cached the aaaa prefix")
assert.Equal(t, 2, len(state.PrefixCacheServers), "pod1 and pod3 should have cached the aaaa prefix")
assert.Equal(t, 2./3, scores[pod1], "score should be 2./3 - the model and the first 2 prefix blocks match")
assert.Equal(t, float64(0), scores[pod2], "score for pod2")

Expand Down