Skip to content

Commit b08d15d

Browse files
committed
initial commit - multiple destinations
Signed-off-by: Nir Rozenbaum <[email protected]>
1 parent fe5dea3 commit b08d15d

File tree

8 files changed

+105
-57
lines changed

8 files changed

+105
-57
lines changed

pkg/epp/scheduling/framework/plugins/filter/lora_affinity_filter.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ func LoraAffinityFilterFactory(name string, rawParameters json.RawMessage, _ plu
4646
if err := json.Unmarshal(rawParameters, &parameters); err != nil {
4747
return nil, fmt.Errorf("failed to parse the parameters of the '%s' filter - %w", LoraAffinityFilterType, err)
4848
}
49+
4950
return NewLoraAffinityFilter(parameters.Threshold).WithName(name), nil
5051
}
5152

pkg/epp/scheduling/framework/plugins/multi/prefix/plugin.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -198,7 +198,7 @@ func (m *Plugin) Score(ctx context.Context, cycleState *types.CycleState, reques
198198

199199
// PostCycle records in the plugin cache the result of the scheduling selection.
200200
func (m *Plugin) PostCycle(ctx context.Context, cycleState *types.CycleState, res *types.ProfileRunResult) {
201-
targetPod := res.TargetPod.GetPod()
201+
targetPod := res.TargetPods[0].GetPod() // assumes exactly one TargetPod (indexing panics if TargetPods is empty); this should become PostResponse to remove this assumption
202202
state, err := m.getPrefixState(cycleState)
203203
if err != nil {
204204
log.FromContext(ctx).Error(err, "failed to read prefix plugin cycle state")

pkg/epp/scheduling/framework/plugins/multi/prefix/plugin_test.go

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ func TestPrefixPlugin(t *testing.T) {
6161
assert.Equal(t, float64(0), scores[pod2], "score for pod2")
6262

6363
// Simulate pod1 was picked.
64-
plugin.PostCycle(context.Background(), cycleState1, &types.ProfileRunResult{TargetPod: pod1})
64+
plugin.PostCycle(context.Background(), cycleState1, &types.ProfileRunResult{TargetPods: []types.Pod{pod1}})
6565

6666
// Second request doesn't share any prefix with first one. It should be added to the cache but
6767
// the pod score should be 0.
@@ -82,7 +82,7 @@ func TestPrefixPlugin(t *testing.T) {
8282
assert.Equal(t, float64(0), scores[pod2], "score for pod2")
8383

8484
// Simulate pod2 was picked.
85-
plugin.PostCycle(context.Background(), cycleState2, &types.ProfileRunResult{TargetPod: pod2})
85+
plugin.PostCycle(context.Background(), cycleState2, &types.ProfileRunResult{TargetPods: []types.Pod{pod2}})
8686

8787
// Third request shares partial prefix with first one.
8888
req3 := &types.LLMRequest{
@@ -101,7 +101,7 @@ func TestPrefixPlugin(t *testing.T) {
101101
assert.Equal(t, float64(2)/float64(3), scores[pod1], "score should be 2/3 - the model and the first prefix block match")
102102
assert.Equal(t, float64(0), scores[pod2], "score for pod2")
103103

104-
plugin.PostCycle(context.Background(), cycleState3, &types.ProfileRunResult{TargetPod: pod1})
104+
plugin.PostCycle(context.Background(), cycleState3, &types.ProfileRunResult{TargetPods: []types.Pod{pod1}})
105105

106106
// 4th request is same as req3 except the model is different, still no match.
107107
req4 := &types.LLMRequest{
@@ -120,7 +120,7 @@ func TestPrefixPlugin(t *testing.T) {
120120
assert.Equal(t, float64(0), scores[pod1], "score for pod1")
121121
assert.Equal(t, float64(0), scores[pod2], "score for pod2")
122122

123-
plugin.PostCycle(context.Background(), cycleState4, &types.ProfileRunResult{TargetPod: pod1})
123+
plugin.PostCycle(context.Background(), cycleState4, &types.ProfileRunResult{TargetPods: []types.Pod{pod1}})
124124

125125
// 5th request shares partial prefix with 3rd one.
126126
req5 := &types.LLMRequest{
@@ -139,7 +139,7 @@ func TestPrefixPlugin(t *testing.T) {
139139
assert.Equal(t, 0.75, scores[pod1], "score should be 0.75 - the model and the first 2 prefix blocks match")
140140
assert.Equal(t, float64(0), scores[pod2], "score for pod2")
141141

142-
plugin.PostCycle(context.Background(), cycleState5, &types.ProfileRunResult{TargetPod: pod1})
142+
plugin.PostCycle(context.Background(), cycleState5, &types.ProfileRunResult{TargetPods: []types.Pod{pod1}})
143143
}
144144

145145
// TestPrefixPluginStress is a stress test for the prefix scoring plugin, using prompts of increasing length.
@@ -180,7 +180,7 @@ func BenchmarkPrefixPluginStress(b *testing.B) {
180180
// First cycle: simulate scheduling and insert prefix info into the cache
181181
cycleState := types.NewCycleState()
182182
plugin.Score(context.Background(), cycleState, req, pods)
183-
plugin.PostCycle(context.Background(), cycleState, &types.ProfileRunResult{TargetPod: pod})
183+
plugin.PostCycle(context.Background(), cycleState, &types.ProfileRunResult{TargetPods: []types.Pod{pod}})
184184

185185
// Second cycle: validate internal state
186186
state, err := plugin.getPrefixState(cycleState)

pkg/epp/scheduling/framework/plugins/picker/max_score_picker.go

Lines changed: 35 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ import (
2020
"context"
2121
"encoding/json"
2222
"fmt"
23+
"sort"
2324

2425
"sigs.k8s.io/controller-runtime/pkg/log"
2526
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins"
@@ -32,26 +33,39 @@ const (
3233
MaxScorePickerType = "max-score"
3334
)
3435

36+
type maxScorePickerParameters struct {
37+
MaxNumOfEndpoints int `json:"maxEndpoints"`
38+
}
39+
3540
// compile-time type validation
3641
var _ framework.Picker = &MaxScorePicker{}
3742

3843
// MaxScorePickerFactory defines the factory function for MaxScorePicker.
39-
func MaxScorePickerFactory(name string, _ json.RawMessage, _ plugins.Handle) (plugins.Plugin, error) {
40-
return NewMaxScorePicker().WithName(name), nil
44+
func MaxScorePickerFactory(name string, rawParameters json.RawMessage, _ plugins.Handle) (plugins.Plugin, error) {
45+
parameters := maxScorePickerParameters{}
46+
if err := json.Unmarshal(rawParameters, &parameters); err != nil {
47+
return nil, fmt.Errorf("failed to parse the parameters of the '%s' picker - %w", MaxScorePickerType, err)
48+
}
49+
50+
return NewMaxScorePicker(parameters.MaxNumOfEndpoints).WithName(name), nil
4151
}
4252

4353
// NewMaxScorePicker initializes a new MaxScorePicker and returns its pointer.
44-
func NewMaxScorePicker() *MaxScorePicker {
54+
func NewMaxScorePicker(maxNumOfEndpoints int) *MaxScorePicker {
55+
if maxNumOfEndpoints <= 0 {
56+
maxNumOfEndpoints = 1 // on invalid configuration value, fall back to 1
57+
}
58+
4559
return &MaxScorePicker{
46-
name: MaxScorePickerType,
47-
random: NewRandomPicker(),
60+
name: MaxScorePickerType,
61+
maxNumOfEndpoints: maxNumOfEndpoints,
4862
}
4963
}
5064

5165
// MaxScorePicker picks up to maxNumOfEndpoints pods with the highest scores from the list of candidates.
5266
type MaxScorePicker struct {
53-
name string
54-
random *RandomPicker
67+
name string
68+
maxNumOfEndpoints int // maximum number of endpoints to pick
5569
}
5670

5771
// Type returns the type of the picker.
@@ -72,22 +86,22 @@ func (p *MaxScorePicker) WithName(name string) *MaxScorePicker {
7286

7387
// Pick selects up to maxNumOfEndpoints pods with the highest scores from the list of candidates.
7488
func (p *MaxScorePicker) Pick(ctx context.Context, cycleState *types.CycleState, scoredPods []*types.ScoredPod) *types.ProfileRunResult {
75-
log.FromContext(ctx).V(logutil.DEBUG).Info(fmt.Sprintf("Selecting a pod with the max score from %d candidates: %+v", len(scoredPods), scoredPods))
76-
77-
highestScorePods := []*types.ScoredPod{}
78-
maxScore := -1.0 // pods min score is 0, putting value lower than 0 in order to find at least one pod as highest
79-
for _, pod := range scoredPods {
80-
if pod.Score > maxScore {
81-
maxScore = pod.Score
82-
highestScorePods = []*types.ScoredPod{pod}
83-
} else if pod.Score == maxScore {
84-
highestScorePods = append(highestScorePods, pod)
85-
}
89+
log.FromContext(ctx).V(logutil.DEBUG).Info(fmt.Sprintf("Selecting maximum '%d' pods from %d candidates sorted by max score: %+v", p.maxNumOfEndpoints,
90+
len(scoredPods), scoredPods))
91+
92+
sort.Slice(scoredPods, func(i, j int) bool { // highest score first
93+
return scoredPods[i].Score > scoredPods[j].Score
94+
})
95+
96+
// if we have enough pods to return keep only the "maxNumOfEndpoints" highest scored pods
97+
if p.maxNumOfEndpoints < len(scoredPods) {
98+
scoredPods = scoredPods[:p.maxNumOfEndpoints]
8699
}
87100

88-
if len(highestScorePods) > 1 {
89-
return p.random.Pick(ctx, cycleState, highestScorePods) // pick randomly from the highest score pods
101+
targetPods := make([]types.Pod, len(scoredPods))
102+
for i, scoredPod := range scoredPods {
103+
targetPods[i] = scoredPod
90104
}
91105

92-
return &types.ProfileRunResult{TargetPod: highestScorePods[0]}
106+
return &types.ProfileRunResult{TargetPods: targetPods}
93107
}

pkg/epp/scheduling/framework/plugins/picker/random_picker.go

Lines changed: 39 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -33,24 +33,39 @@ const (
3333
RandomPickerType = "random"
3434
)
3535

36+
type randomPickerParameters struct {
37+
MaxNumOfEndpoints int `json:"maxEndpoints"`
38+
}
39+
3640
// compile-time type validation
3741
var _ framework.Picker = &RandomPicker{}
3842

3943
// RandomPickerFactory defines the factory function for RandomPicker.
40-
func RandomPickerFactory(name string, _ json.RawMessage, _ plugins.Handle) (plugins.Plugin, error) {
41-
return NewRandomPicker().WithName(name), nil
44+
func RandomPickerFactory(name string, rawParameters json.RawMessage, _ plugins.Handle) (plugins.Plugin, error) {
45+
parameters := randomPickerParameters{}
46+
if err := json.Unmarshal(rawParameters, &parameters); err != nil {
47+
return nil, fmt.Errorf("failed to parse the parameters of the '%s' picker - %w", RandomPickerType, err)
48+
}
49+
50+
return NewRandomPicker(parameters.MaxNumOfEndpoints).WithName(name), nil
4251
}
4352

4453
// NewRandomPicker initializes a new RandomPicker and returns its pointer.
45-
func NewRandomPicker() *RandomPicker {
54+
func NewRandomPicker(maxNumOfEndpoints int) *RandomPicker {
55+
if maxNumOfEndpoints <= 0 {
56+
maxNumOfEndpoints = 1 // on invalid configuration value, fall back to 1
57+
}
58+
4659
return &RandomPicker{
47-
name: RandomPickerType,
60+
name: RandomPickerType,
61+
maxNumOfEndpoints: maxNumOfEndpoints,
4862
}
4963
}
5064

5165
// RandomPicker picks a random pod from the list of candidates.
5266
type RandomPicker struct {
53-
name string
67+
name string
68+
maxNumOfEndpoints int // maximum number of endpoints to pick
5469
}
5570

5671
// Type returns the type of the picker.
@@ -71,7 +86,23 @@ func (p *RandomPicker) WithName(name string) *RandomPicker {
7186

7287
// Pick selects up to maxNumOfEndpoints random pods from the list of candidates.
7388
func (p *RandomPicker) Pick(ctx context.Context, _ *types.CycleState, scoredPods []*types.ScoredPod) *types.ProfileRunResult {
74-
log.FromContext(ctx).V(logutil.DEBUG).Info(fmt.Sprintf("Selecting a random pod from %d candidates: %+v", len(scoredPods), scoredPods))
75-
i := rand.Intn(len(scoredPods))
76-
return &types.ProfileRunResult{TargetPod: scoredPods[i]}
89+
log.FromContext(ctx).V(logutil.DEBUG).Info(fmt.Sprintf("Selecting maximum '%d' pods from %d candidates randomly: %+v", p.maxNumOfEndpoints,
90+
len(scoredPods), scoredPods))
91+
92+
// Shuffle in-place
93+
rand.Shuffle(len(scoredPods), func(i, j int) {
94+
scoredPods[i], scoredPods[j] = scoredPods[j], scoredPods[i]
95+
})
96+
97+
// if there are more candidates than "maxNumOfEndpoints", keep only the first "maxNumOfEndpoints" pods of the shuffled list
98+
if p.maxNumOfEndpoints < len(scoredPods) {
99+
scoredPods = scoredPods[:p.maxNumOfEndpoints]
100+
}
101+
102+
targetPods := make([]types.Pod, len(scoredPods))
103+
for i, scoredPod := range scoredPods {
104+
targetPods[i] = scoredPod
105+
}
106+
107+
return &types.ProfileRunResult{TargetPods: targetPods}
77108
}

pkg/epp/scheduling/framework/scheduler_profile_test.go

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -141,11 +141,12 @@ func TestSchedulePlugins(t *testing.T) {
141141
}
142142

143143
// Validate output
144-
wantPod := &types.PodMetrics{
145-
Pod: &backend.Pod{NamespacedName: test.wantTargetPod, Labels: make(map[string]string)},
146-
}
147144
wantRes := &types.ProfileRunResult{
148-
TargetPod: wantPod,
145+
TargetPods: []types.Pod{
146+
&types.PodMetrics{
147+
Pod: &backend.Pod{NamespacedName: test.wantTargetPod, Labels: make(map[string]string)},
148+
},
149+
},
149150
}
150151

151152
if diff := cmp.Diff(wantRes, got); diff != "" {
@@ -231,15 +232,15 @@ func (tp *testPlugin) Pick(_ context.Context, _ *types.CycleState, scoredPods []
231232
tp.PickCallCount++
232233
tp.NumOfPickerCandidates = len(scoredPods)
233234

234-
var winnerPod types.Pod
235+
winnerPods := []types.Pod{}
235236
for _, scoredPod := range scoredPods {
236237
if scoredPod.GetPod().NamespacedName.String() == tp.PickRes.String() {
237-
winnerPod = scoredPod.Pod
238+
winnerPods = append(winnerPods, scoredPod.Pod)
238239
tp.WinnderPodScore = scoredPod.Score
239240
}
240241
}
241242

242-
return &types.ProfileRunResult{TargetPod: winnerPod}
243+
return &types.ProfileRunResult{TargetPods: winnerPods}
243244
}
244245

245246
func (tp *testPlugin) PostCycle(_ context.Context, _ *types.CycleState, res *types.ProfileRunResult) {

pkg/epp/scheduling/scheduler_test.go

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -95,22 +95,23 @@ func TestSchedule(t *testing.T) {
9595
wantRes: &types.SchedulingResult{
9696
ProfileResults: map[string]*types.ProfileRunResult{
9797
"default": {
98-
TargetPod: &types.ScoredPod{
99-
Pod: &types.PodMetrics{
100-
Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod2"}, Labels: make(map[string]string)},
101-
MetricsState: &backendmetrics.MetricsState{
102-
WaitingQueueSize: 3,
103-
KVCacheUsagePercent: 0.1,
104-
MaxActiveModels: 2,
105-
ActiveModels: map[string]int{
106-
"foo": 1,
107-
"critical": 1,
98+
TargetPods: []types.Pod{
99+
&types.ScoredPod{
100+
Pod: &types.PodMetrics{
101+
Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod2"}, Labels: make(map[string]string)},
102+
MetricsState: &backendmetrics.MetricsState{
103+
WaitingQueueSize: 3,
104+
KVCacheUsagePercent: 0.1,
105+
MaxActiveModels: 2,
106+
ActiveModels: map[string]int{
107+
"foo": 1,
108+
"critical": 1,
109+
},
110+
WaitingModels: map[string]int{},
108111
},
109-
WaitingModels: map[string]int{},
110112
},
111113
},
112-
},
113-
},
114+
}},
114115
},
115116
PrimaryProfileName: "default",
116117
},

pkg/epp/scheduling/types/types.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ func ToSchedulerPodMetrics(pods []backendmetrics.PodMetrics) []Pod {
8080

8181
// ProfileRunResult captures the profile run result.
8282
type ProfileRunResult struct {
83-
TargetPod Pod
83+
TargetPods []Pod
8484
}
8585

8686
// SchedulingResult captures the result of the scheduling cycle.

0 commit comments

Comments
 (0)