Merged
2 changes: 1 addition & 1 deletion cmd/epp/runner/runner.go

@@ -293,7 +293,7 @@ func (r *Runner) initializeScheduler() (*scheduling.Scheduler, error) {
 	schedulerProfile := framework.NewSchedulerProfile().
 		WithScorers(framework.NewWeightedScorer(scorer.NewQueueScorer(), queueScorerWeight),
 			framework.NewWeightedScorer(scorer.NewKVCacheScorer(), kvCacheScorerWeight)).
-		WithPicker(picker.NewMaxScorePicker())
+		WithPicker(picker.NewMaxScorePicker(picker.DefaultMaxNumOfEndpoints))
 
 	if prefixCacheScheduling {
 		prefixScorerWeight := envutil.GetEnvInt("PREFIX_CACHE_SCORE_WEIGHT", prefix.DefaultScorerWeight, setupLog)
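NewMaxScorePicker now requires an explicit endpoint limit. A hedged sketch of how runner.go could make that limit configurable rather than hard-coding the default, reusing the envutil.GetEnvInt pattern visible in the surrounding code; the MAX_SCORE_PICKER_ENDPOINTS variable name is an assumption, not something this PR defines:

// Hypothetical variation (not in this PR): source the picker's endpoint limit
// from the environment, falling back to picker.DefaultMaxNumOfEndpoints.
maxEndpoints := envutil.GetEnvInt("MAX_SCORE_PICKER_ENDPOINTS", picker.DefaultMaxNumOfEndpoints, setupLog)

schedulerProfile := framework.NewSchedulerProfile().
	WithScorers(framework.NewWeightedScorer(scorer.NewQueueScorer(), queueScorerWeight),
		framework.NewWeightedScorer(scorer.NewKVCacheScorer(), kvCacheScorerWeight)).
	WithPicker(picker.NewMaxScorePicker(maxEndpoints))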
2 changes: 1 addition & 1 deletion conformance/testing-epp/scheduler.go

@@ -30,7 +30,7 @@ import (
 func NewReqHeaderBasedScheduler() *scheduling.Scheduler {
 	predicatableSchedulerProfile := framework.NewSchedulerProfile().
 		WithFilters(filter.NewHeaderBasedTestingFilter()).
-		WithPicker(picker.NewMaxScorePicker())
+		WithPicker(picker.NewMaxScorePicker(picker.DefaultMaxNumOfEndpoints))
 
 	return scheduling.NewSchedulerWithConfig(scheduling.NewSchedulerConfig(
 		profile.NewSingleProfileHandler(), map[string]*framework.SchedulerProfile{"req-header-based-profile": predicatableSchedulerProfile}))
12 changes: 7 additions & 5 deletions conformance/testing-epp/scheduler_test.go

@@ -82,11 +82,13 @@ func TestSchedule(t *testing.T) {
 			wantRes: &types.SchedulingResult{
 				ProfileResults: map[string]*types.ProfileRunResult{
 					"req-header-based-profile": {
-						TargetPod: &types.ScoredPod{
-							Pod: &types.PodMetrics{
-								Pod: &backend.Pod{
-									Address: "matched-endpoint",
-									Labels:  map[string]string{},
+						TargetPods: []types.Pod{
+							&types.ScoredPod{
+								Pod: &types.PodMetrics{
+									Pod: &backend.Pod{
+										Address: "matched-endpoint",
+										Labels:  map[string]string{},
+									},
 								},
 							},
 						},
3 changes: 2 additions & 1 deletion pkg/epp/requestcontrol/director.go

@@ -238,7 +238,8 @@ func (d *Director) prepareRequest(ctx context.Context, reqCtx *handlers.RequestC
 		return reqCtx, errutil.Error{Code: errutil.Internal, Msg: "results must be greater than zero"}
 	}
 	// primary profile is used to set destination
-	targetPod := result.ProfileResults[result.PrimaryProfileName].TargetPod.GetPod()
+	// TODO should use multiple destinations according to epp protocol. current code assumes a single target
+	targetPod := result.ProfileResults[result.PrimaryProfileName].TargetPods[0].GetPod()
 
 	pool, err := d.datastore.PoolGet()
 	if err != nil {

Inline review thread on the TODO line:

Collaborator: Can we bind an issue to this TODO?

Author: We already have #414, which covers this point. The current PR implements the first half (the scheduling part), and the next PR should mark that issue as completed (handling request control + handlers). A new issue would be a duplicate.
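Indexing TargetPods[0] assumes the primary profile produced at least one pod; the guard above ("results must be greater than zero") does not appear to cover an empty TargetPods slice. A minimal defensive sketch, reusing only identifiers visible in this diff (reqCtx, errutil), of how that assumption could be made explicit; this is an illustration, not code from the PR:

// Sketch only (not part of this PR): guard the single-target assumption before indexing.
profileRun := result.ProfileResults[result.PrimaryProfileName]
if profileRun == nil || len(profileRun.TargetPods) == 0 {
	return reqCtx, errutil.Error{Code: errutil.Internal, Msg: "primary profile returned no target pods"}
}
targetPod := profileRun.TargetPods[0].GetPod() // single-target assumption, per the TODO above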
12 changes: 7 additions & 5 deletions pkg/epp/requestcontrol/director_test.go

@@ -131,11 +131,13 @@ func TestDirector_HandleRequest(t *testing.T) {
 	defaultSuccessfulScheduleResults := &schedulingtypes.SchedulingResult{
 		ProfileResults: map[string]*schedulingtypes.ProfileRunResult{
 			"testProfile": {
-				TargetPod: &schedulingtypes.ScoredPod{
-					Pod: &schedulingtypes.PodMetrics{
-						Pod: &backend.Pod{
-							Address:        "192.168.1.100",
-							NamespacedName: k8stypes.NamespacedName{Name: "pod1", Namespace: "default"},
+				TargetPods: []schedulingtypes.Pod{
+					&schedulingtypes.ScoredPod{
+						Pod: &schedulingtypes.PodMetrics{
+							Pod: &backend.Pod{
+								Address:        "192.168.1.100",
+								NamespacedName: k8stypes.NamespacedName{Name: "pod1", Namespace: "default"},
+							},
 						},
 					},
 				},
2 changes: 1 addition & 1 deletion pkg/epp/scheduling/framework/plugins/multi/prefix/plugin.go

@@ -196,7 +196,7 @@ func (m *Plugin) Score(ctx context.Context, cycleState *types.CycleState, reques
 
 // PostCycle records in the plugin cache the result of the scheduling selection.
 func (m *Plugin) PostCycle(ctx context.Context, cycleState *types.CycleState, res *types.ProfileRunResult) {
-	targetPod := res.TargetPod.GetPod()
+	targetPod := res.TargetPods[0].GetPod()
 	state, err := types.ReadCycleStateKey[*SchedulingContextState](cycleState, PrefixCachePluginType)
 	if err != nil {
 		log.FromContext(ctx).Error(err, "failed to read prefix plugin cycle state")
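PostCycle still records only the first target, consistent with the director's single-target assumption. If multi-endpoint picks should eventually warm the prefix cache for every selected pod, a speculative sketch (not part of this PR, and the cache-update step is left abstract) could loop over the slice instead:

// Speculative sketch, not in this PR: credit every picked endpoint,
// not just the first, with the request's prefix blocks.
for _, target := range res.TargetPods {
	targetPod := target.GetPod()
	// ...update the plugin cache for targetPod exactly as the existing
	// single-target code does.
	_ = targetPod
}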
12 changes: 6 additions & 6 deletions pkg/epp/scheduling/framework/plugins/multi/prefix/plugin_test.go

@@ -61,7 +61,7 @@ func TestPrefixPlugin(t *testing.T) {
 	assert.Equal(t, float64(0), scores[pod2], "score for pod2")
 
 	// Simulate pod1 was picked.
-	plugin.PostCycle(context.Background(), cycleState1, &types.ProfileRunResult{TargetPod: pod1})
+	plugin.PostCycle(context.Background(), cycleState1, &types.ProfileRunResult{TargetPods: []types.Pod{pod1}})
 
 	// Second request doesn't share any prefix with first one. It should be added to the cache but
 	// the pod score should be 0.
@@ -82,7 +82,7 @@ func TestPrefixPlugin(t *testing.T) {
 	assert.Equal(t, float64(0), scores[pod2], "score for pod2")
 
 	// Simulate pod2 was picked.
-	plugin.PostCycle(context.Background(), cycleState2, &types.ProfileRunResult{TargetPod: pod2})
+	plugin.PostCycle(context.Background(), cycleState2, &types.ProfileRunResult{TargetPods: []types.Pod{pod2}})
 
 	// Third request shares partial prefix with first one.
 	req3 := &types.LLMRequest{
@@ -101,7 +101,7 @@ func TestPrefixPlugin(t *testing.T) {
 	assert.Equal(t, float64(2)/float64(3), scores[pod1], "score should be 2/3 - the model and the first prefix block match")
 	assert.Equal(t, float64(0), scores[pod2], "score for pod2")
 
-	plugin.PostCycle(context.Background(), cycleState3, &types.ProfileRunResult{TargetPod: pod1})
+	plugin.PostCycle(context.Background(), cycleState3, &types.ProfileRunResult{TargetPods: []types.Pod{pod1}})
 
 	// 4th request is same as req3 except the model is different, still no match.
 	req4 := &types.LLMRequest{
@@ -120,7 +120,7 @@ func TestPrefixPlugin(t *testing.T) {
 	assert.Equal(t, float64(0), scores[pod1], "score for pod1")
 	assert.Equal(t, float64(0), scores[pod2], "score for pod2")
 
-	plugin.PostCycle(context.Background(), cycleState4, &types.ProfileRunResult{TargetPod: pod1})
+	plugin.PostCycle(context.Background(), cycleState4, &types.ProfileRunResult{TargetPods: []types.Pod{pod1}})
 
 	// 5th request shares partial prefix with 3rd one.
 	req5 := &types.LLMRequest{
@@ -139,7 +139,7 @@ func TestPrefixPlugin(t *testing.T) {
 	assert.Equal(t, 0.75, scores[pod1], "score should be 0.75 - the model and the first 2 prefix blocks match")
 	assert.Equal(t, float64(0), scores[pod2], "score for pod2")
 
-	plugin.PostCycle(context.Background(), cycleState5, &types.ProfileRunResult{TargetPod: pod1})
+	plugin.PostCycle(context.Background(), cycleState5, &types.ProfileRunResult{TargetPods: []types.Pod{pod1}})
 }
 
 // TestPrefixPluginStress is a stress test for the prefix scoring plugin, using prompts of increasing length.
@@ -180,7 +180,7 @@ func BenchmarkPrefixPluginStress(b *testing.B) {
 	// First cycle: simulate scheduling and insert prefix info into the cache
 	cycleState := types.NewCycleState()
 	plugin.Score(context.Background(), cycleState, req, pods)
-	plugin.PostCycle(context.Background(), cycleState, &types.ProfileRunResult{TargetPod: pod})
+	plugin.PostCycle(context.Background(), cycleState, &types.ProfileRunResult{TargetPods: []types.Pod{pod}})
 
 	// Second cycle: validate internal state
 	state, err := types.ReadCycleStateKey[*SchedulingContextState](cycleState, PrefixCachePluginType)
26 changes: 26 additions & 0 deletions pkg/epp/scheduling/framework/plugins/picker/common.go

@@ -0,0 +1,26 @@
+/*
+Copyright 2025 The Kubernetes Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package picker
+
+const (
+	DefaultMaxNumOfEndpoints = 1 // common default to all pickers
+)
+
+// pickerParameters defines the common parameters for all pickers
+type pickerParameters struct {
+	MaxNumOfEndpoints int `json:"maxNumOfEndpoints"`
+}
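The new pickerParameters struct is what picker factories unmarshal from a plugin's raw JSON configuration. A minimal self-contained sketch of that round-trip; the struct is mirrored from the PR, while the `{"maxNumOfEndpoints": 3}` payload is an illustrative value, not a config this PR ships:

package main

import (
	"encoding/json"
	"fmt"
)

// Mirror of the PR's pickerParameters struct.
type pickerParameters struct {
	MaxNumOfEndpoints int `json:"maxNumOfEndpoints"`
}

func main() {
	raw := []byte(`{"maxNumOfEndpoints": 3}`) // illustrative config payload
	// Seed with the default before unmarshaling, as MaxScorePickerFactory does,
	// so an absent key leaves the default in place.
	params := pickerParameters{MaxNumOfEndpoints: 1}
	if err := json.Unmarshal(raw, &params); err != nil {
		panic(err)
	}
	fmt.Println(params.MaxNumOfEndpoints) // prints 3; with an empty payload it would print 1
}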
74 changes: 47 additions & 27 deletions pkg/epp/scheduling/framework/plugins/picker/max_score_picker.go

@@ -20,8 +20,10 @@ import (
 	"context"
 	"encoding/json"
 	"fmt"
+	"slices"
 
 	"sigs.k8s.io/controller-runtime/pkg/log"
+
 	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins"
 	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework"
 	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
@@ -36,53 +38,71 @@ const (
 var _ framework.Picker = &MaxScorePicker{}
 
 // MaxScorePickerFactory defines the factory function for MaxScorePicker.
-func MaxScorePickerFactory(name string, _ json.RawMessage, _ plugins.Handle) (plugins.Plugin, error) {
-	return NewMaxScorePicker().WithName(name), nil
+func MaxScorePickerFactory(name string, rawParameters json.RawMessage, _ plugins.Handle) (plugins.Plugin, error) {
+	parameters := pickerParameters{MaxNumOfEndpoints: DefaultMaxNumOfEndpoints}
+	if rawParameters != nil {
+		if err := json.Unmarshal(rawParameters, &parameters); err != nil {
+			return nil, fmt.Errorf("failed to parse the parameters of the '%s' picker - %w", MaxScorePickerType, err)
+		}
+	}
+
+	return NewMaxScorePicker(parameters.MaxNumOfEndpoints).WithName(name), nil
 }
 
 // NewMaxScorePicker initializes a new MaxScorePicker and returns its pointer.
-func NewMaxScorePicker() *MaxScorePicker {
+func NewMaxScorePicker(maxNumOfEndpoints int) *MaxScorePicker {
+	if maxNumOfEndpoints <= 0 {
+		maxNumOfEndpoints = DefaultMaxNumOfEndpoints // on invalid configuration value, fallback to default value
+	}
+
 	return &MaxScorePicker{
-		tn:     plugins.TypedName{Type: MaxScorePickerType, Name: MaxScorePickerType},
-		random: NewRandomPicker(),
+		typedName:         plugins.TypedName{Type: MaxScorePickerType, Name: MaxScorePickerType},
+		maxNumOfEndpoints: maxNumOfEndpoints,
 	}
 }
 
-// MaxScorePicker picks the pod with the maximum score from the list of candidates.
+// MaxScorePicker picks pod(s) with the maximum score from the list of candidates.
 type MaxScorePicker struct {
-	tn     plugins.TypedName
-	random *RandomPicker
-}
-
-// TypedName returns the type and name tuple of this plugin instance.
-func (p *MaxScorePicker) TypedName() plugins.TypedName {
-	return p.tn
+	typedName         plugins.TypedName
+	maxNumOfEndpoints int // maximum number of endpoints to pick
 }
 
 // WithName sets the picker's name
 func (p *MaxScorePicker) WithName(name string) *MaxScorePicker {
-	p.tn.Name = name
+	p.typedName.Name = name
 	return p
 }
 
+// TypedName returns the type and name tuple of this plugin instance.
+func (p *MaxScorePicker) TypedName() plugins.TypedName {
+	return p.typedName
+}
+
 // Pick selects the pod with the maximum score from the list of candidates.
 func (p *MaxScorePicker) Pick(ctx context.Context, cycleState *types.CycleState, scoredPods []*types.ScoredPod) *types.ProfileRunResult {
-	log.FromContext(ctx).V(logutil.DEBUG).Info(fmt.Sprintf("Selecting a pod with the max score from %d candidates: %+v", len(scoredPods), scoredPods))
-
-	highestScorePods := []*types.ScoredPod{}
-	maxScore := -1.0 // pods min score is 0, putting value lower than 0 in order to find at least one pod as highest
-	for _, pod := range scoredPods {
-		if pod.Score > maxScore {
-			maxScore = pod.Score
-			highestScorePods = []*types.ScoredPod{pod}
-		} else if pod.Score == maxScore {
-			highestScorePods = append(highestScorePods, pod)
-		}
-	}
-
-	if len(highestScorePods) > 1 {
-		return p.random.Pick(ctx, cycleState, highestScorePods) // pick randomly from the highest score pods
-	}
-
-	return &types.ProfileRunResult{TargetPod: highestScorePods[0]}
+	log.FromContext(ctx).V(logutil.DEBUG).Info(fmt.Sprintf("Selecting maximum '%d' pods from %d candidates sorted by max score: %+v", p.maxNumOfEndpoints,
+		len(scoredPods), scoredPods))
+
+	slices.SortStableFunc(scoredPods, func(i, j *types.ScoredPod) int { // highest score first
+		if i.Score > j.Score {
+			return -1
+		}
+		if i.Score < j.Score {
+			return 1
+		}
+		return 0
+	})
+
+	// if we have enough pods to return keep only the "maxNumOfEndpoints" highest scored pods
+	if p.maxNumOfEndpoints < len(scoredPods) {
+		scoredPods = scoredPods[:p.maxNumOfEndpoints]
+	}
+
+	targetPods := make([]types.Pod, len(scoredPods))
+	for i, scoredPod := range scoredPods {
+		targetPods[i] = scoredPod
+	}
+
+	return &types.ProfileRunResult{TargetPods: targetPods}
+
 }
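One behavioral note: the old picker broke max-score ties by delegating to RandomPicker, while the new code stable-sorts and truncates, so ties now resolve deterministically by input order. A self-contained sketch of the same sort-then-truncate selection on a local toy type (the names here are illustrative, not the repo's):

package main

import (
	"fmt"
	"slices"
)

type scoredPod struct {
	name  string
	score float64
}

// pickTopN mirrors the PR's approach: stable-sort descending by score,
// then keep at most n entries, so equal scores keep their input order.
func pickTopN(pods []scoredPod, n int) []scoredPod {
	if n <= 0 {
		n = 1 // fall back to the single-endpoint default, as NewMaxScorePicker does
	}
	slices.SortStableFunc(pods, func(a, b scoredPod) int {
		switch {
		case a.score > b.score:
			return -1
		case a.score < b.score:
			return 1
		default:
			return 0
		}
	})
	if n < len(pods) {
		pods = pods[:n]
	}
	return pods
}

func main() {
	pods := []scoredPod{{"pod1", 0.2}, {"pod2", 0.9}, {"pod3", 0.9}}
	// Prints [{pod2 0.9} {pod3 0.9}]; pod2 precedes pod3 because the sort is stable.
	fmt.Println(pickTopN(pods, 2))
}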