Skip to content

Commit 2b6e9fe

Browse files
committed
implement multiple destination as the output of the scheduler
Signed-off-by: Nir Rozenbaum <[email protected]>
1 parent 8b3dbe9 commit 2b6e9fe

File tree

15 files changed

+238
-134
lines changed

15 files changed

+238
-134
lines changed

cmd/epp/runner/runner.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -293,7 +293,7 @@ func (r *Runner) initializeScheduler() (*scheduling.Scheduler, error) {
293293
schedulerProfile := framework.NewSchedulerProfile().
294294
WithScorers(framework.NewWeightedScorer(scorer.NewQueueScorer(), queueScorerWeight),
295295
framework.NewWeightedScorer(scorer.NewKVCacheScorer(), kvCacheScorerWeight)).
296-
WithPicker(picker.NewMaxScorePicker())
296+
WithPicker(picker.NewMaxScorePicker(picker.DefaultMaxNumOfEndpoints))
297297

298298
if prefixCacheScheduling {
299299
prefixScorerWeight := envutil.GetEnvInt("PREFIX_CACHE_SCORE_WEIGHT", prefix.DefaultScorerWeight, setupLog)

conformance/testing-epp/scheduler.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ import (
3030
func NewReqHeaderBasedScheduler() *scheduling.Scheduler {
3131
predicatableSchedulerProfile := framework.NewSchedulerProfile().
3232
WithFilters(filter.NewHeaderBasedTestingFilter()).
33-
WithPicker(picker.NewMaxScorePicker())
33+
WithPicker(picker.NewMaxScorePicker(picker.DefaultMaxNumOfEndpoints))
3434

3535
return scheduling.NewSchedulerWithConfig(scheduling.NewSchedulerConfig(
3636
profile.NewSingleProfileHandler(), map[string]*framework.SchedulerProfile{"req-header-based-profile": predicatableSchedulerProfile}))

conformance/testing-epp/scheduler_test.go

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -82,11 +82,13 @@ func TestSchedule(t *testing.T) {
8282
wantRes: &types.SchedulingResult{
8383
ProfileResults: map[string]*types.ProfileRunResult{
8484
"req-header-based-profile": {
85-
TargetPod: &types.ScoredPod{
86-
Pod: &types.PodMetrics{
87-
Pod: &backend.Pod{
88-
Address: "matched-endpoint",
89-
Labels: map[string]string{},
85+
TargetPods: []types.Pod{
86+
&types.ScoredPod{
87+
Pod: &types.PodMetrics{
88+
Pod: &backend.Pod{
89+
Address: "matched-endpoint",
90+
Labels: map[string]string{},
91+
},
9092
},
9193
},
9294
},

pkg/epp/requestcontrol/director.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -238,7 +238,8 @@ func (d *Director) prepareRequest(ctx context.Context, reqCtx *handlers.RequestC
238238
return reqCtx, errutil.Error{Code: errutil.Internal, Msg: "results must be greater than zero"}
239239
}
240240
// primary profile is used to set destination
241-
targetPod := result.ProfileResults[result.PrimaryProfileName].TargetPod.GetPod()
241+
// TODO should use multiple destinations according to epp protocol. current code assumes a single target
242+
targetPod := result.ProfileResults[result.PrimaryProfileName].TargetPods[0].GetPod()
242243

243244
pool, err := d.datastore.PoolGet()
244245
if err != nil {

pkg/epp/requestcontrol/director_test.go

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -131,11 +131,13 @@ func TestDirector_HandleRequest(t *testing.T) {
131131
defaultSuccessfulScheduleResults := &schedulingtypes.SchedulingResult{
132132
ProfileResults: map[string]*schedulingtypes.ProfileRunResult{
133133
"testProfile": {
134-
TargetPod: &schedulingtypes.ScoredPod{
135-
Pod: &schedulingtypes.PodMetrics{
136-
Pod: &backend.Pod{
137-
Address: "192.168.1.100",
138-
NamespacedName: k8stypes.NamespacedName{Name: "pod1", Namespace: "default"},
134+
TargetPods: []schedulingtypes.Pod{
135+
&schedulingtypes.ScoredPod{
136+
Pod: &schedulingtypes.PodMetrics{
137+
Pod: &backend.Pod{
138+
Address: "192.168.1.100",
139+
NamespacedName: k8stypes.NamespacedName{Name: "pod1", Namespace: "default"},
140+
},
139141
},
140142
},
141143
},

pkg/epp/scheduling/framework/plugins/multi/prefix/plugin.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,7 @@ func (m *Plugin) Score(ctx context.Context, cycleState *types.CycleState, reques
193193

194194
// PostCycle records in the plugin cache the result of the scheduling selection.
195195
func (m *Plugin) PostCycle(ctx context.Context, cycleState *types.CycleState, res *types.ProfileRunResult) {
196-
targetPod := res.TargetPod.GetPod()
196+
targetPod := res.TargetPods[0].GetPod()
197197
state, err := types.ReadCycleStateKey[*SchedulingContextState](cycleState, PrefixCachePluginType)
198198
if err != nil {
199199
log.FromContext(ctx).Error(err, "failed to read prefix plugin cycle state")

pkg/epp/scheduling/framework/plugins/multi/prefix/plugin_test.go

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ func TestPrefixPlugin(t *testing.T) {
6161
assert.Equal(t, float64(0), scores[pod2], "score for pod2")
6262

6363
// Simulate pod1 was picked.
64-
plugin.PostCycle(context.Background(), cycleState1, &types.ProfileRunResult{TargetPod: pod1})
64+
plugin.PostCycle(context.Background(), cycleState1, &types.ProfileRunResult{TargetPods: []types.Pod{pod1}})
6565

6666
// Second request doesn't share any prefix with first one. It should be added to the cache but
6767
// the pod score should be 0.
@@ -82,7 +82,7 @@ func TestPrefixPlugin(t *testing.T) {
8282
assert.Equal(t, float64(0), scores[pod2], "score for pod2")
8383

8484
// Simulate pod2 was picked.
85-
plugin.PostCycle(context.Background(), cycleState2, &types.ProfileRunResult{TargetPod: pod2})
85+
plugin.PostCycle(context.Background(), cycleState2, &types.ProfileRunResult{TargetPods: []types.Pod{pod2}})
8686

8787
// Third request shares partial prefix with first one.
8888
req3 := &types.LLMRequest{
@@ -101,7 +101,7 @@ func TestPrefixPlugin(t *testing.T) {
101101
assert.Equal(t, float64(2)/float64(3), scores[pod1], "score should be 2/3 - the model and the first prefix block match")
102102
assert.Equal(t, float64(0), scores[pod2], "score for pod2")
103103

104-
plugin.PostCycle(context.Background(), cycleState3, &types.ProfileRunResult{TargetPod: pod1})
104+
plugin.PostCycle(context.Background(), cycleState3, &types.ProfileRunResult{TargetPods: []types.Pod{pod1}})
105105

106106
// 4th request is same as req3 except the model is different, still no match.
107107
req4 := &types.LLMRequest{
@@ -120,7 +120,7 @@ func TestPrefixPlugin(t *testing.T) {
120120
assert.Equal(t, float64(0), scores[pod1], "score for pod1")
121121
assert.Equal(t, float64(0), scores[pod2], "score for pod2")
122122

123-
plugin.PostCycle(context.Background(), cycleState4, &types.ProfileRunResult{TargetPod: pod1})
123+
plugin.PostCycle(context.Background(), cycleState4, &types.ProfileRunResult{TargetPods: []types.Pod{pod1}})
124124

125125
// 5th request shares partial prefix with 3rd one.
126126
req5 := &types.LLMRequest{
@@ -139,7 +139,7 @@ func TestPrefixPlugin(t *testing.T) {
139139
assert.Equal(t, 0.75, scores[pod1], "score should be 0.75 - the model and the first 2 prefix blocks match")
140140
assert.Equal(t, float64(0), scores[pod2], "score for pod2")
141141

142-
plugin.PostCycle(context.Background(), cycleState5, &types.ProfileRunResult{TargetPod: pod1})
142+
plugin.PostCycle(context.Background(), cycleState5, &types.ProfileRunResult{TargetPods: []types.Pod{pod1}})
143143
}
144144

145145
// TestPrefixPluginStress is a stress test for the prefix scoring plugin, using prompts of increasing length.
@@ -180,7 +180,7 @@ func BenchmarkPrefixPluginStress(b *testing.B) {
180180
// First cycle: simulate scheduling and insert prefix info into the cache
181181
cycleState := types.NewCycleState()
182182
plugin.Score(context.Background(), cycleState, req, pods)
183-
plugin.PostCycle(context.Background(), cycleState, &types.ProfileRunResult{TargetPod: pod})
183+
plugin.PostCycle(context.Background(), cycleState, &types.ProfileRunResult{TargetPods: []types.Pod{pod}})
184184

185185
// Second cycle: validate internal state
186186
state, err := types.ReadCycleStateKey[*SchedulingContextState](cycleState, PrefixCachePluginType)
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
/*
2+
Copyright 2025 The Kubernetes Authors.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
*/
16+
17+
package picker
18+
19+
const (
20+
DefaultMaxNumOfEndpoints = 1 // common default to all pickers
21+
)
22+
23+
// pickerParameters defines the common parameters for all pickers
24+
type pickerParameters struct {
25+
MaxNumOfEndpoints int `json:"maxNumOfEndpoints"`
26+
}

pkg/epp/scheduling/framework/plugins/picker/max_score_picker.go

Lines changed: 41 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ import (
2020
"context"
2121
"encoding/json"
2222
"fmt"
23+
"sort"
2324

2425
"sigs.k8s.io/controller-runtime/pkg/log"
2526
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins"
@@ -36,53 +37,65 @@ const (
3637
var _ framework.Picker = &MaxScorePicker{}
3738

3839
// MaxScorePickerFactory defines the factory function for MaxScorePicker.
39-
func MaxScorePickerFactory(name string, _ json.RawMessage, _ plugins.Handle) (plugins.Plugin, error) {
40-
return NewMaxScorePicker().WithName(name), nil
40+
func MaxScorePickerFactory(name string, rawParameters json.RawMessage, _ plugins.Handle) (plugins.Plugin, error) {
41+
parameters := pickerParameters{MaxNumOfEndpoints: DefaultMaxNumOfEndpoints}
42+
if rawParameters != nil {
43+
if err := json.Unmarshal(rawParameters, &parameters); err != nil {
44+
return nil, fmt.Errorf("failed to parse the parameters of the '%s' picker - %w", MaxScorePickerType, err)
45+
}
46+
}
47+
48+
return NewMaxScorePicker(parameters.MaxNumOfEndpoints).WithName(name), nil
4149
}
4250

4351
// NewMaxScorePicker initializes a new MaxScorePicker and returns its pointer.
44-
func NewMaxScorePicker() *MaxScorePicker {
52+
func NewMaxScorePicker(maxNumOfEndpoints int) *MaxScorePicker {
53+
if maxNumOfEndpoints <= 0 {
54+
maxNumOfEndpoints = DefaultMaxNumOfEndpoints // on invalid configuration value, fallback to default value
55+
}
56+
4557
return &MaxScorePicker{
46-
tn: plugins.TypedName{Type: MaxScorePickerType, Name: MaxScorePickerType},
47-
random: NewRandomPicker(),
58+
typedName: plugins.TypedName{Type: MaxScorePickerType, Name: MaxScorePickerType},
59+
maxNumOfEndpoints: maxNumOfEndpoints,
4860
}
4961
}
5062

51-
// MaxScorePicker picks the pod with the maximum score from the list of candidates.
63+
// MaxScorePicker picks pod(s) with the maximum score from the list of candidates.
5264
type MaxScorePicker struct {
53-
tn plugins.TypedName
54-
random *RandomPicker
55-
}
56-
57-
// TypedName returns the type and name tuple of this plugin instance.
58-
func (p *MaxScorePicker) TypedName() plugins.TypedName {
59-
return p.tn
65+
typedName plugins.TypedName
66+
maxNumOfEndpoints int // maximum number of endpoints to pick
6067
}
6168

6269
// WithName sets the picker's name
6370
func (p *MaxScorePicker) WithName(name string) *MaxScorePicker {
64-
p.tn.Name = name
71+
p.typedName.Name = name
6572
return p
6673
}
6774

75+
// TypedName returns the type and name tuple of this plugin instance.
76+
func (p *MaxScorePicker) TypedName() plugins.TypedName {
77+
return p.typedName
78+
}
79+
6880
// Pick selects the pod with the maximum score from the list of candidates.
6981
func (p *MaxScorePicker) Pick(ctx context.Context, cycleState *types.CycleState, scoredPods []*types.ScoredPod) *types.ProfileRunResult {
70-
log.FromContext(ctx).V(logutil.DEBUG).Info(fmt.Sprintf("Selecting a pod with the max score from %d candidates: %+v", len(scoredPods), scoredPods))
71-
72-
highestScorePods := []*types.ScoredPod{}
73-
maxScore := -1.0 // pods min score is 0, putting value lower than 0 in order to find at least one pod as highest
74-
for _, pod := range scoredPods {
75-
if pod.Score > maxScore {
76-
maxScore = pod.Score
77-
highestScorePods = []*types.ScoredPod{pod}
78-
} else if pod.Score == maxScore {
79-
highestScorePods = append(highestScorePods, pod)
80-
}
82+
log.FromContext(ctx).V(logutil.DEBUG).Info(fmt.Sprintf("Selecting maximum '%d' pods from %d candidates sorted by max score: %+v", p.maxNumOfEndpoints,
83+
len(scoredPods), scoredPods))
84+
85+
sort.Slice(scoredPods, func(i, j int) bool { // highest score first
86+
return scoredPods[i].Score > scoredPods[j].Score
87+
})
88+
89+
// if we have enough pods to return keep only the "maxNumOfEndpoints" highest scored pods
90+
if p.maxNumOfEndpoints < len(scoredPods) {
91+
scoredPods = scoredPods[:p.maxNumOfEndpoints]
8192
}
8293

83-
if len(highestScorePods) > 1 {
84-
return p.random.Pick(ctx, cycleState, highestScorePods) // pick randomly from the highest score pods
94+
targetPods := make([]types.Pod, len(scoredPods))
95+
for i, scoredPod := range scoredPods {
96+
targetPods[i] = scoredPod
8597
}
8698

87-
return &types.ProfileRunResult{TargetPod: highestScorePods[0]}
99+
return &types.ProfileRunResult{TargetPods: targetPods}
100+
88101
}

0 commit comments

Comments
 (0)