diff --git a/cmd/epp/runner/runner.go b/cmd/epp/runner/runner.go
index 1340ff6dc..81ec37347 100644
--- a/cmd/epp/runner/runner.go
+++ b/cmd/epp/runner/runner.go
@@ -47,7 +47,6 @@ import (
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/saturationdetector"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework"
- "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework/plugins/filter"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework/plugins/multi/prefix"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework/plugins/picker"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework/plugins/profile"
@@ -292,7 +291,6 @@ func (r *Runner) initializeScheduler() (*scheduling.Scheduler, error) {
kvCacheScorerWeight := envutil.GetEnvInt("KV_CACHE_SCORE_WEIGHT", scorer.DefaultKVCacheScorerWeight, setupLog)
schedulerProfile := framework.NewSchedulerProfile().
- WithFilters(filter.NewSubsetFilter()).
WithScorers(framework.NewWeightedScorer(scorer.NewQueueScorer(), queueScorerWeight),
framework.NewWeightedScorer(scorer.NewKVCacheScorer(), kvCacheScorerWeight)).
WithPicker(picker.NewMaxScorePicker())
diff --git a/conformance/testing-epp/sheduler_test.go b/conformance/testing-epp/scheduler_test.go
similarity index 100%
rename from conformance/testing-epp/sheduler_test.go
rename to conformance/testing-epp/scheduler_test.go
diff --git a/pkg/bbr/handlers/server.go b/pkg/bbr/handlers/server.go
index eb6b93d67..a5803806b 100644
--- a/pkg/bbr/handlers/server.go
+++ b/pkg/bbr/handlers/server.go
@@ -118,7 +118,7 @@ type streamedBody struct {
func (s *Server) processRequestBody(ctx context.Context, body *extProcPb.HttpBody, streamedBody *streamedBody, logger logr.Logger) ([]*extProcPb.ProcessingResponse, error) {
loggerVerbose := logger.V(logutil.VERBOSE)
- var requestBody map[string]interface{}
+ var requestBody map[string]any
if s.streaming {
streamedBody.body = append(streamedBody.body, body.Body...)
// In the stream case, we can receive multiple request bodies.
diff --git a/pkg/epp/backend/metrics/metrics_state.go b/pkg/epp/backend/metrics/metrics_state.go
index 3be7d535a..0215ac05f 100644
--- a/pkg/epp/backend/metrics/metrics_state.go
+++ b/pkg/epp/backend/metrics/metrics_state.go
@@ -21,8 +21,8 @@ import (
"time"
)
-// newMetricsState initializes a new MetricsState and returns its pointer.
-func newMetricsState() *MetricsState {
+// NewMetricsState initializes a new MetricsState and returns its pointer.
+func NewMetricsState() *MetricsState {
return &MetricsState{
ActiveModels: make(map[string]int),
WaitingModels: make(map[string]int),
diff --git a/pkg/epp/backend/metrics/types.go b/pkg/epp/backend/metrics/types.go
index bb78c2b34..80b708555 100644
--- a/pkg/epp/backend/metrics/types.go
+++ b/pkg/epp/backend/metrics/types.go
@@ -51,7 +51,7 @@ func (f *PodMetricsFactory) NewPodMetrics(parentCtx context.Context, in *corev1.
logger: log.FromContext(parentCtx).WithValues("pod", pod.NamespacedName),
}
pm.pod.Store(pod)
- pm.metrics.Store(newMetricsState())
+ pm.metrics.Store(NewMetricsState())
pm.startRefreshLoop(parentCtx)
return pm
diff --git a/pkg/epp/handlers/response.go b/pkg/epp/handlers/response.go
index 7284628cd..a776bd1d9 100644
--- a/pkg/epp/handlers/response.go
+++ b/pkg/epp/handlers/response.go
@@ -34,11 +34,7 @@ const (
)
// HandleResponseBody always returns the requestContext even in the error case, as the request context is used in error handling.
-func (s *StreamingServer) HandleResponseBody(
- ctx context.Context,
- reqCtx *RequestContext,
- response map[string]interface{},
-) (*RequestContext, error) {
+func (s *StreamingServer) HandleResponseBody(ctx context.Context, reqCtx *RequestContext, response map[string]any) (*RequestContext, error) {
logger := log.FromContext(ctx)
responseBytes, err := json.Marshal(response)
if err != nil {
@@ -46,7 +42,7 @@ func (s *StreamingServer) HandleResponseBody(
return reqCtx, err
}
if response["usage"] != nil {
- usg := response["usage"].(map[string]interface{})
+ usg := response["usage"].(map[string]any)
usage := Usage{
PromptTokens: int(usg["prompt_tokens"].(float64)),
CompletionTokens: int(usg["completion_tokens"].(float64)),
@@ -68,11 +64,7 @@ func (s *StreamingServer) HandleResponseBody(
}
// The function is to handle streaming response if the modelServer is streaming.
-func (s *StreamingServer) HandleResponseBodyModelStreaming(
- ctx context.Context,
- reqCtx *RequestContext,
- responseText string,
-) {
+func (s *StreamingServer) HandleResponseBodyModelStreaming(ctx context.Context, reqCtx *RequestContext, responseText string) {
if strings.Contains(responseText, streamingEndMsg) {
resp := parseRespForUsage(ctx, responseText)
reqCtx.Usage = resp.Usage
@@ -160,10 +152,7 @@ func (s *StreamingServer) generateResponseHeaders(reqCtx *RequestContext) []*con
//
// If include_usage is not included in the request, `data: [DONE]` is returned separately, which
// indicates end of streaming.
-func parseRespForUsage(
- ctx context.Context,
- responseText string,
-) ResponseBody {
+func parseRespForUsage(ctx context.Context, responseText string) ResponseBody {
response := ResponseBody{}
logger := log.FromContext(ctx)
diff --git a/pkg/epp/handlers/response_test.go b/pkg/epp/handlers/response_test.go
index bfe5a6293..b79f4ee46 100644
--- a/pkg/epp/handlers/response_test.go
+++ b/pkg/epp/handlers/response_test.go
@@ -86,7 +86,7 @@ func TestHandleResponseBody(t *testing.T) {
if reqCtx == nil {
reqCtx = &RequestContext{}
}
- var responseMap map[string]interface{}
+ var responseMap map[string]any
marshalErr := json.Unmarshal(test.body, &responseMap)
if marshalErr != nil {
t.Error(marshalErr, "Error unmarshaling request body")
diff --git a/pkg/epp/handlers/server.go b/pkg/epp/handlers/server.go
index e31b53a16..3ac13c892 100644
--- a/pkg/epp/handlers/server.go
+++ b/pkg/epp/handlers/server.go
@@ -112,7 +112,7 @@ type RequestContext struct {
type Request struct {
Headers map[string]string
- Body map[string]interface{}
+ Body map[string]any
Metadata map[string]any
}
type Response struct {
@@ -143,7 +143,7 @@ func (s *StreamingServer) Process(srv extProcPb.ExternalProcessor_ProcessServer)
RequestState: RequestReceived,
Request: &Request{
Headers: make(map[string]string),
- Body: make(map[string]interface{}),
+ Body: make(map[string]any),
Metadata: make(map[string]any),
},
Response: &Response{
@@ -152,7 +152,7 @@ func (s *StreamingServer) Process(srv extProcPb.ExternalProcessor_ProcessServer)
}
var body []byte
- var responseBody map[string]interface{}
+ var responseBody map[string]any
// Create error handling var as each request should only report once for
// error metrics. This doesn't cover the error "Cannot receive stream request" because
@@ -308,7 +308,7 @@ func (s *StreamingServer) Process(srv extProcPb.ExternalProcessor_ProcessServer)
// Handle the err and fire an immediate response.
if err != nil {
logger.V(logutil.DEFAULT).Error(err, "Failed to process request", "request", req)
- resp, err := BuildErrResponse(err)
+ resp, err := buildErrResponse(err)
if err != nil {
return err
}
@@ -389,7 +389,7 @@ func (r *RequestContext) updateStateAndSendIfNeeded(srv extProcPb.ExternalProces
return nil
}
-func BuildErrResponse(err error) (*extProcPb.ProcessingResponse, error) {
+func buildErrResponse(err error) (*extProcPb.ProcessingResponse, error) {
var resp *extProcPb.ProcessingResponse
switch errutil.CanonicalCode(err) {
@@ -416,6 +416,17 @@ func BuildErrResponse(err error) (*extProcPb.ProcessingResponse, error) {
},
},
}
+ // This code can be returned by the director when there are no candidate pods for the request scheduling.
+ case errutil.ServiceUnavailable:
+ resp = &extProcPb.ProcessingResponse{
+ Response: &extProcPb.ProcessingResponse_ImmediateResponse{
+ ImmediateResponse: &extProcPb.ImmediateResponse{
+ Status: &envoyTypePb.HttpStatus{
+ Code: envoyTypePb.StatusCode_ServiceUnavailable,
+ },
+ },
+ },
+ }
// This code can be returned when users provide invalid json request.
case errutil.BadRequest:
resp = &extProcPb.ProcessingResponse{
diff --git a/pkg/epp/requestcontrol/director.go b/pkg/epp/requestcontrol/director.go
index 59cc44246..78effeda0 100644
--- a/pkg/epp/requestcontrol/director.go
+++ b/pkg/epp/requestcontrol/director.go
@@ -24,12 +24,14 @@ import (
"math/rand"
"net"
"strconv"
+ "strings"
"time"
"github.com/go-logr/logr"
"sigs.k8s.io/controller-runtime/pkg/log"
"sigs.k8s.io/gateway-api-inference-extension/api/v1alpha2"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend"
+ backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datastore"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/handlers"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metrics"
@@ -39,6 +41,11 @@ import (
requtil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/request"
)
+const (
+ subsetHintNamespace = "envoy.lb.subset_hint"
+ subsetHintKey = "x-gateway-destination-endpoint-subset"
+)
+
// Scheduler defines the interface required by the Director for scheduling.
type Scheduler interface {
Schedule(ctx context.Context, request *schedulingtypes.LLMRequest, candidatePods []schedulingtypes.Pod) (result *schedulingtypes.SchedulingResult, err error)
@@ -118,12 +125,12 @@ func (d *Director) HandleRequest(ctx context.Context, reqCtx *handlers.RequestCo
}
// Prepare LLMRequest (needed for both saturation detection and Scheduler)
- reqCtx.SchedulingRequest = schedulingtypes.NewLLMRequest(
- reqCtx.Request.Headers[requtil.RequestIdHeaderKey],
- reqCtx.ResolvedTargetModel,
- prompt,
- reqCtx.Request.Headers,
- reqCtx.Request.Metadata)
+ reqCtx.SchedulingRequest = &schedulingtypes.LLMRequest{
+ RequestId: reqCtx.Request.Headers[requtil.RequestIdHeaderKey],
+ TargetModel: reqCtx.ResolvedTargetModel,
+ Prompt: prompt,
+ Headers: reqCtx.Request.Headers,
+ }
logger = logger.WithValues("model", reqCtx.Model, "resolvedTargetModel", reqCtx.ResolvedTargetModel, "criticality", requestCriticality)
@@ -135,11 +142,11 @@ func (d *Director) HandleRequest(ctx context.Context, reqCtx *handlers.RequestCo
return reqCtx, err
}
- // --- 3. Call Scheduler ---
- // Snapshot pod metrics from the datastore to:
- // 1. Reduce concurrent access to the datastore.
- // 2. Ensure consistent data during the scheduling operation of a request between all scheduling cycles.
- candidatePods := schedulingtypes.ToSchedulerPodMetrics(d.datastore.PodGetAll())
+ // --- 3. Call Scheduler (with the relevant candidate pods) ---
+ candidatePods := d.getCandidatePodsForScheduling(ctx, reqCtx.Request.Metadata)
+ if len(candidatePods) == 0 {
+ return reqCtx, errutil.Error{Code: errutil.ServiceUnavailable, Msg: "failed to find candidate pods for serving the request"}
+ }
results, err := d.scheduler.Schedule(ctx, reqCtx.SchedulingRequest, candidatePods)
if err != nil {
return reqCtx, errutil.Error{Code: errutil.InferencePoolResourceExhausted, Msg: fmt.Errorf("failed to find target pod: %w", err).Error()}
@@ -177,6 +184,52 @@ func (d *Director) admitRequest(ctx context.Context, requestCriticality v1alpha2
return nil
}
+// getCandidatePodsForScheduling gets the list of relevant endpoints for the scheduling cycle from the datastore.
+// according to EPP protocol, if "x-gateway-destination-endpoint-subset" is set on the request metadata and specifies
+// a subset of endpoints, only these endpoints will be considered as candidates for the scheduler.
+// Snapshot pod metrics from the datastore to:
+// 1. Reduce concurrent access to the datastore.
+// 2. Ensure consistent data during the scheduling operation of a request between all scheduling cycles.
+func (d *Director) getCandidatePodsForScheduling(ctx context.Context, requestMetadata map[string]any) []schedulingtypes.Pod {
+ loggerTrace := log.FromContext(ctx).V(logutil.TRACE)
+
+ subsetMap, found := requestMetadata[subsetHintNamespace].(map[string]any)
+ if !found {
+ return schedulingtypes.ToSchedulerPodMetrics(d.datastore.PodGetAll())
+ }
+
+ // Check if endpoint key is present in the subset map and ensure there is at least one value
+ endpointSubsetList, found := subsetMap[subsetHintKey].([]any)
+ if !found {
+ return schedulingtypes.ToSchedulerPodMetrics(d.datastore.PodGetAll())
+ } else if len(endpointSubsetList) == 0 {
+ loggerTrace.Info("found empty subset filter in request metadata, filtering all pods")
+ return []schedulingtypes.Pod{}
+ }
+
+ // Create a map of endpoint addresses for easy lookup
+ endpoints := make(map[string]bool)
+ for _, endpoint := range endpointSubsetList {
+ // Extract address from endpoint
+ // The endpoint is formatted as "
:" (ex. "10.0.1.0:8080")
+ epStr := strings.Split(endpoint.(string), ":")[0]
+ endpoints[epStr] = true
+ }
+
+ podTotalCount := 0
+ podFitleredList := d.datastore.PodList(func(pm backendmetrics.PodMetrics) bool {
+ podTotalCount++
+ if _, found := endpoints[pm.GetPod().Address]; found {
+ return true
+ }
+ return false
+ })
+
+ loggerTrace.Info("filtered candidate pods by subset filtering", "podTotalCount", podTotalCount, "filteredCount", len(podFitleredList))
+
+ return schedulingtypes.ToSchedulerPodMetrics(podFitleredList)
+}
+
// prepareRequest populates the RequestContext and calls the registered PreRequest plugins
// for allowing plugging customized logic based on the scheduling results.
func (d *Director) prepareRequest(ctx context.Context, reqCtx *handlers.RequestContext, result *schedulingtypes.SchedulingResult) (*handlers.RequestContext, error) {
diff --git a/pkg/epp/requestcontrol/director_test.go b/pkg/epp/requestcontrol/director_test.go
index aa396666f..0f214b830 100644
--- a/pkg/epp/requestcontrol/director_test.go
+++ b/pkg/epp/requestcontrol/director_test.go
@@ -23,6 +23,7 @@ import (
"time"
"github.com/google/go-cmp/cmp"
+ "github.com/google/go-cmp/cmp/cmpopts"
"github.com/stretchr/testify/assert"
corev1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
@@ -144,7 +145,7 @@ func TestDirector_HandleRequest(t *testing.T) {
tests := []struct {
name string
- reqBodyMap map[string]interface{}
+ reqBodyMap map[string]any
mockSaturationDetector *mockSaturationDetector
schedulerMockSetup func(m *mockScheduler)
wantErrCode string // Expected errutil code string
@@ -153,7 +154,7 @@ func TestDirector_HandleRequest(t *testing.T) {
}{
{
name: "successful completions request (critical, saturation ignored)",
- reqBodyMap: map[string]interface{}{
+ reqBodyMap: map[string]any{
"model": model,
"prompt": "critical prompt",
},
@@ -174,10 +175,10 @@ func TestDirector_HandleRequest(t *testing.T) {
},
{
name: "successful chat completions request (critical, saturation ignored)",
- reqBodyMap: map[string]interface{}{
+ reqBodyMap: map[string]any{
"model": model,
- "messages": []interface{}{
- map[string]interface{}{
+ "messages": []any{
+ map[string]any{
"role": "user",
"content": "critical prompt",
},
@@ -199,14 +200,14 @@ func TestDirector_HandleRequest(t *testing.T) {
},
{
name: "successful chat completions request with multiple messages (critical, saturation ignored)",
- reqBodyMap: map[string]interface{}{
+ reqBodyMap: map[string]any{
"model": model,
- "messages": []interface{}{
- map[string]interface{}{
+ "messages": []any{
+ map[string]any{
"role": "developer",
"content": "You are a helpful assistant.",
},
- map[string]interface{}{
+ map[string]any{
"role": "user",
"content": "Hello!",
},
@@ -228,7 +229,7 @@ func TestDirector_HandleRequest(t *testing.T) {
},
{
name: "successful completions request (sheddable, not saturated)",
- reqBodyMap: map[string]interface{}{
+ reqBodyMap: map[string]any{
"model": modelSheddable,
"prompt": "sheddable prompt",
},
@@ -249,7 +250,7 @@ func TestDirector_HandleRequest(t *testing.T) {
},
{
name: "successful request with target model resolution",
- reqBodyMap: map[string]interface{}{
+ reqBodyMap: map[string]any{
"model": modelWithResolvedTarget,
"prompt": "prompt for target resolution",
},
@@ -283,7 +284,7 @@ func TestDirector_HandleRequest(t *testing.T) {
TargetEndpoint: "192.168.1.100:8000",
},
wantMutatedBodyModel: "food-review-1",
- reqBodyMap: map[string]interface{}{
+ reqBodyMap: map[string]any{
"model": "food-review-1",
"prompt": "test prompt",
},
@@ -292,7 +293,7 @@ func TestDirector_HandleRequest(t *testing.T) {
{
name: "request dropped (sheddable, saturated)",
- reqBodyMap: map[string]interface{}{
+ reqBodyMap: map[string]any{
"model": modelSheddable,
"prompt": "sheddable prompt",
},
@@ -301,27 +302,27 @@ func TestDirector_HandleRequest(t *testing.T) {
},
{
name: "model not found, expect err",
- reqBodyMap: map[string]interface{}{"prompt": "p"},
+ reqBodyMap: map[string]any{"prompt": "p"},
mockSaturationDetector: &mockSaturationDetector{isSaturated: false},
wantErrCode: errutil.BadRequest,
},
{
name: "prompt or messages not found, expect err",
- reqBodyMap: map[string]interface{}{"model": model},
+ reqBodyMap: map[string]any{"model": model},
wantErrCode: errutil.BadRequest,
},
{
name: "empty messages, expect err",
- reqBodyMap: map[string]interface{}{
+ reqBodyMap: map[string]any{
"model": model,
- "messages": []interface{}{},
+ "messages": []any{},
},
wantErrCode: errutil.BadRequest,
},
{
name: "scheduler returns error",
- reqBodyMap: map[string]interface{}{
+ reqBodyMap: map[string]any{
"model": model,
"prompt": "prompt that causes scheduler error",
},
@@ -332,7 +333,7 @@ func TestDirector_HandleRequest(t *testing.T) {
},
{
name: "scheduler returns nil result and nil error",
- reqBodyMap: map[string]interface{}{
+ reqBodyMap: map[string]any{
"model": model,
"prompt": "prompt for nil,nil scheduler return",
},
@@ -355,7 +356,7 @@ func TestDirector_HandleRequest(t *testing.T) {
reqCtx := &handlers.RequestContext{
Request: &handlers.Request{
// Create a copy of the map for each test run to avoid mutation issues.
- Body: make(map[string]interface{}),
+ Body: make(map[string]any),
Headers: map[string]string{
requtil.RequestIdHeaderKey: "test-req-id-" + test.name, // Ensure a default request ID
},
@@ -396,6 +397,138 @@ func TestDirector_HandleRequest(t *testing.T) {
}
}
+// TestGetCandidatePodsForScheduling is testing getCandidatePodsForScheduling and more specifically the functionality of SubsetFilter.
+func TestGetCandidatePodsForScheduling(t *testing.T) {
+ var makeFilterMetadata = func(data []any) map[string]any {
+ return map[string]any{
+ "envoy.lb.subset_hint": map[string]any{
+ "x-gateway-destination-endpoint-subset": data,
+ },
+ }
+ }
+
+ testInput := []*corev1.Pod{
+ {
+ ObjectMeta: metav1.ObjectMeta{
+ Name: "pod1",
+ },
+ Status: corev1.PodStatus{
+ PodIP: "10.0.0.1",
+ },
+ },
+ {
+ ObjectMeta: metav1.ObjectMeta{
+ Name: "pod2",
+ },
+ Status: corev1.PodStatus{
+ PodIP: "10.0.0.2",
+ },
+ },
+ }
+
+ outputPod1 := &backend.Pod{
+ NamespacedName: types.NamespacedName{Name: "pod1"},
+ Address: "10.0.0.1",
+ Labels: map[string]string{},
+ }
+
+ outputPod2 := &backend.Pod{
+ NamespacedName: types.NamespacedName{Name: "pod2"},
+ Address: "10.0.0.2",
+ Labels: map[string]string{},
+ }
+
+ tests := []struct {
+ name string
+ metadata map[string]any
+ output []schedulingtypes.Pod
+ }{
+ {
+ name: "SubsetFilter, filter not present — return all pods",
+ metadata: map[string]any{},
+ output: []schedulingtypes.Pod{
+ &schedulingtypes.PodMetrics{
+ Pod: outputPod1,
+ MetricsState: backendmetrics.NewMetricsState(),
+ },
+ &schedulingtypes.PodMetrics{
+ Pod: outputPod2,
+ MetricsState: backendmetrics.NewMetricsState(),
+ },
+ },
+ },
+ {
+ name: "SubsetFilter, namespace present filter not present — return all pods",
+ metadata: map[string]any{"envoy.lb.subset_hint": map[string]any{}},
+ output: []schedulingtypes.Pod{
+ &schedulingtypes.PodMetrics{
+ Pod: outputPod1,
+ MetricsState: backendmetrics.NewMetricsState(),
+ },
+ &schedulingtypes.PodMetrics{
+ Pod: outputPod2,
+ MetricsState: backendmetrics.NewMetricsState(),
+ },
+ },
+ },
+ {
+ name: "SubsetFilter, filter present with empty list — return error",
+ metadata: makeFilterMetadata([]any{}),
+ output: []schedulingtypes.Pod{},
+ },
+ {
+ name: "SubsetFilter, subset with one matching pod",
+ metadata: makeFilterMetadata([]any{"10.0.0.1"}),
+ output: []schedulingtypes.Pod{
+ &schedulingtypes.PodMetrics{
+ Pod: outputPod1,
+ MetricsState: backendmetrics.NewMetricsState(),
+ },
+ },
+ },
+ {
+ name: "SubsetFilter, subset with multiple matching pods",
+ metadata: makeFilterMetadata([]any{"10.0.0.1", "10.0.0.2", "10.0.0.3"}),
+ output: []schedulingtypes.Pod{
+ &schedulingtypes.PodMetrics{
+ Pod: outputPod1,
+ MetricsState: backendmetrics.NewMetricsState(),
+ },
+ &schedulingtypes.PodMetrics{
+ Pod: outputPod2,
+ MetricsState: backendmetrics.NewMetricsState(),
+ },
+ },
+ },
+ {
+ name: "SubsetFilter, subset with no matching pods",
+ metadata: makeFilterMetadata([]any{"10.0.0.3"}),
+ output: []schedulingtypes.Pod{},
+ },
+ }
+
+ pmf := backendmetrics.NewPodMetricsFactory(&backendmetrics.FakePodMetricsClient{}, time.Second)
+ ds := datastore.NewDatastore(t.Context(), pmf)
+ for _, testPod := range testInput {
+ ds.PodUpdateOrAddIfNotExist(testPod)
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ director := NewDirectorWithConfig(ds, &mockScheduler{}, &mockSaturationDetector{}, NewConfig())
+
+ got := director.getCandidatePodsForScheduling(context.Background(), test.metadata)
+
+ diff := cmp.Diff(test.output, got, cmpopts.SortSlices(func(a, b schedulingtypes.Pod) bool {
+ return a.GetPod().NamespacedName.String() < b.GetPod().NamespacedName.String()
+ }))
+ if diff != "" {
+ t.Errorf("Unexpected output (-want +got): %v", diff)
+ }
+ })
+ }
+}
+
func TestRandomWeightedDraw(t *testing.T) {
logger := logutil.NewTestLogger()
// Note: These tests verify deterministic outcomes for a fixed seed (420).
diff --git a/pkg/epp/scheduling/framework/plugins/filter/filter_test.go b/pkg/epp/scheduling/framework/plugins/filter/filter_test.go
index 75cd790d0..978e91c3e 100644
--- a/pkg/epp/scheduling/framework/plugins/filter/filter_test.go
+++ b/pkg/epp/scheduling/framework/plugins/filter/filter_test.go
@@ -256,145 +256,6 @@ func TestLoRASoftAffinityDistribution(t *testing.T) {
}
}
-func TestSubsettingFilter(t *testing.T) {
- var makeFilterMetadata = func(data []interface{}) map[string]any {
- return map[string]any{
- "envoy.lb.subset_hint": map[string]any{
- "x-gateway-destination-endpoint-subset": data,
- },
- }
- }
-
- tests := []struct {
- name string
- metadata map[string]any
- filter framework.Filter
- input []types.Pod
- output []types.Pod
- }{
- {
- name: "SubsetFilter, filter not present — return all pods",
- filter: &SubsetFilter{},
- metadata: map[string]any{},
- input: []types.Pod{
- &types.PodMetrics{
- Pod: &backend.Pod{Address: "10.0.0.1"},
- },
- &types.PodMetrics{
- Pod: &backend.Pod{Address: "10.0.0.2"},
- },
- },
- output: []types.Pod{
- &types.PodMetrics{
- Pod: &backend.Pod{Address: "10.0.0.1"},
- },
- &types.PodMetrics{
- Pod: &backend.Pod{Address: "10.0.0.2"},
- },
- },
- },
- {
- name: "SubsetFilter, namespace present filter not present — return all pods",
- filter: &SubsetFilter{},
- metadata: map[string]any{"envoy.lb.subset_hint": map[string]any{}},
- input: []types.Pod{
- &types.PodMetrics{
- Pod: &backend.Pod{Address: "10.0.0.1"},
- },
- &types.PodMetrics{
- Pod: &backend.Pod{Address: "10.0.0.2"},
- },
- },
- output: []types.Pod{
- &types.PodMetrics{
- Pod: &backend.Pod{Address: "10.0.0.1"},
- },
- &types.PodMetrics{
- Pod: &backend.Pod{Address: "10.0.0.2"},
- },
- },
- },
- {
- name: "SubsetFilter, filter present with empty list — return no pods",
- filter: &SubsetFilter{},
- metadata: makeFilterMetadata([]interface{}{}),
- input: []types.Pod{
- &types.PodMetrics{
- Pod: &backend.Pod{Address: "10.0.0.1"},
- },
- &types.PodMetrics{
- Pod: &backend.Pod{Address: "10.0.0.2"},
- },
- },
- output: []types.Pod{},
- },
- {
- name: "SubsetFilter, subset with one matching pod",
- metadata: makeFilterMetadata([]interface{}{"10.0.0.1"}),
- filter: &SubsetFilter{},
- input: []types.Pod{
- &types.PodMetrics{
- Pod: &backend.Pod{Address: "10.0.0.1"},
- },
- &types.PodMetrics{
- Pod: &backend.Pod{Address: "10.0.0.2"},
- },
- },
- output: []types.Pod{
- &types.PodMetrics{
- Pod: &backend.Pod{Address: "10.0.0.1"},
- },
- },
- },
- {
- name: "SubsetFilter, subset with multiple matching pods",
- metadata: makeFilterMetadata([]interface{}{"10.0.0.1", "10.0.0.2", "10.0.0.3"}),
- filter: &SubsetFilter{},
- input: []types.Pod{
- &types.PodMetrics{
- Pod: &backend.Pod{Address: "10.0.0.1"},
- },
- &types.PodMetrics{
- Pod: &backend.Pod{Address: "10.0.0.2"},
- },
- },
- output: []types.Pod{
- &types.PodMetrics{
- Pod: &backend.Pod{Address: "10.0.0.1"},
- },
- &types.PodMetrics{
- Pod: &backend.Pod{Address: "10.0.0.2"},
- },
- },
- },
- {
- name: "SubsetFilter, subset with no matching pods",
- metadata: makeFilterMetadata([]interface{}{"10.0.0.3"}),
- filter: &SubsetFilter{},
- input: []types.Pod{
- &types.PodMetrics{
- Pod: &backend.Pod{Address: "10.0.0.1"},
- },
- &types.PodMetrics{
- Pod: &backend.Pod{Address: "10.0.0.2"},
- },
- },
- output: []types.Pod{},
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- req := types.NewLLMRequest(uuid.NewString(), "", "", nil, test.metadata)
- got := test.filter.Filter(context.Background(), types.NewCycleState(), req, test.input)
-
- if diff := cmp.Diff(test.output, got); diff != "" {
- t.Errorf("Unexpected output (-want +got): %v", diff)
- }
- })
- }
-}
-
// TestDecisionTreeFilterFactory tests that the DecisionTreeFilterFactory function
// properly instantiates DecisionTreeFilter instances
func TestDecisionTreeFilterFactory(t *testing.T) {
diff --git a/pkg/epp/scheduling/framework/plugins/filter/subsetting_filter.go b/pkg/epp/scheduling/framework/plugins/filter/subsetting_filter.go
deleted file mode 100644
index 2962b9511..000000000
--- a/pkg/epp/scheduling/framework/plugins/filter/subsetting_filter.go
+++ /dev/null
@@ -1,89 +0,0 @@
-/*
-Copyright 2025 The Kubernetes Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-*/
-
-package filter
-
-import (
- "context"
- "strings"
-
- "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework"
- "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
-)
-
-const (
- SubsetFilterType = "subset"
-
- subsetHintKey = "x-gateway-destination-endpoint-subset"
- subsetHintNamespace = "envoy.lb.subset_hint"
-)
-
-// compile-time type assertion
-var _ framework.Filter = &SubsetFilter{}
-
-// NewSubsetFilter initializes a new SubsetFilter.
-func NewSubsetFilter() *SubsetFilter {
- return &SubsetFilter{}
-}
-
-// SubsetFilter filters Pods based on the subset hint provided by the proxy via filterMetadata.
-type SubsetFilter struct{}
-
-// Name returns the name of the filter.
-func (f *SubsetFilter) Name() string {
- return "subset-hint"
-}
-
-// Type returns the type of the filter.
-func (f *SubsetFilter) Type() string {
- return SubsetFilterType
-}
-
-// Filter filters out pods that are not in the subset provided in filterMetadata.
-func (f *SubsetFilter) Filter(_ context.Context, _ *types.CycleState, request *types.LLMRequest, pods []types.Pod) []types.Pod {
- // Check if subset namespace key is present in the metadata map
- subsetMap, found := request.GetMetadata()[subsetHintNamespace].(map[string]any)
- if !found {
- return pods
- }
-
- // Check if endpoint key is present in the subset map and ensure there is at least one value
- endpointSubsetList, found := subsetMap[subsetHintKey].([]interface{})
- if !found {
- return pods
- } else if len(endpointSubsetList) == 0 {
- return []types.Pod{}
- }
-
- // Create a map of endpoint addrs for easy lookup
- endpoints := make(map[string]bool)
- for _, endpoint := range endpointSubsetList {
- // Extract address from endpoint
- // The endpoint is formatted as ":" (ex. "10.0.1.0:8080")
- epStr := strings.Split(endpoint.(string), ":")[0]
- endpoints[epStr] = true
- }
-
- // Filter based on address
- filteredPods := []types.Pod{}
- for _, pod := range pods {
- if _, found := endpoints[pod.GetPod().Address]; found {
- filteredPods = append(filteredPods, pod)
- }
- }
-
- return filteredPods
-}
diff --git a/pkg/epp/scheduling/scheduler.go b/pkg/epp/scheduling/scheduler.go
index f70005799..b848b26dc 100644
--- a/pkg/epp/scheduling/scheduler.go
+++ b/pkg/epp/scheduling/scheduler.go
@@ -44,7 +44,6 @@ func NewScheduler() *Scheduler {
// it's possible to call NewSchedulerWithConfig to pass a different scheduler config.
// For build time plugins changes, it's recommended to call in main.go to NewSchedulerWithConfig.
loraAffinityFilter := filter.NewLoraAffinityFilter(config.Conf.LoraAffinityThreshold)
- endpointSubsetFilter := filter.NewSubsetFilter()
leastQueueFilter := filter.NewLeastQueueFilter()
leastKvCacheFilter := filter.NewLeastKVCacheFilter()
@@ -71,7 +70,7 @@ func NewScheduler() *Scheduler {
}
defaultProfile := framework.NewSchedulerProfile().
- WithFilters(endpointSubsetFilter, lowLatencyFilter).
+ WithFilters(lowLatencyFilter).
WithPicker(&picker.RandomPicker{})
profileHandler := profile.NewSingleProfileHandler()
diff --git a/pkg/epp/scheduling/types/cycle_state.go b/pkg/epp/scheduling/types/cycle_state.go
index 97381dd68..789ece245 100644
--- a/pkg/epp/scheduling/types/cycle_state.go
+++ b/pkg/epp/scheduling/types/cycle_state.go
@@ -59,7 +59,7 @@ func (c *CycleState) Clone() *CycleState {
}
copy := NewCycleState()
// Safe copy storage in case of overwriting.
- c.storage.Range(func(k, v interface{}) bool {
+ c.storage.Range(func(k, v any) bool {
copy.storage.Store(k, v.(StateData).Clone())
return true
})
diff --git a/pkg/epp/scheduling/types/types.go b/pkg/epp/scheduling/types/types.go
index d58729714..451384751 100644
--- a/pkg/epp/scheduling/types/types.go
+++ b/pkg/epp/scheduling/types/types.go
@@ -33,29 +33,12 @@ type LLMRequest struct {
Prompt string
// Headers is a map of the request headers.
Headers map[string]string
-
- // metadata is a map of metadata in the request
- metadata map[string]any
-}
-
-func NewLLMRequest(reqID, targetModel, prompt string, headers map[string]string, metadata map[string]any) *LLMRequest {
- return &LLMRequest{
- RequestId: reqID,
- TargetModel: targetModel,
- Prompt: prompt,
- Headers: headers,
- metadata: metadata,
- }
}
func (r *LLMRequest) String() string {
return fmt.Sprintf("RequestID: %s, TargetModel: %s, PromptLength: %d, Headers: %v", r.RequestId, r.TargetModel, len(r.Prompt), r.Headers)
}
-func (r *LLMRequest) GetMetadata() map[string]any {
- return r.metadata
-}
-
type Pod interface {
GetPod() *backend.Pod
GetMetrics() *backendmetrics.MetricsState
diff --git a/pkg/epp/util/error/error.go b/pkg/epp/util/error/error.go
index d580d66aa..264830980 100644
--- a/pkg/epp/util/error/error.go
+++ b/pkg/epp/util/error/error.go
@@ -30,6 +30,7 @@ const (
Unknown = "Unknown"
BadRequest = "BadRequest"
Internal = "Internal"
+ ServiceUnavailable = "ServiceUnavailable"
ModelServerError = "ModelServerError"
BadConfiguration = "BadConfiguration"
InferencePoolResourceExhausted = "InferencePoolResourceExhausted"
diff --git a/pkg/epp/util/error/error_test.go b/pkg/epp/util/error/error_test.go
index 7393c44c2..8905e847f 100644
--- a/pkg/epp/util/error/error_test.go
+++ b/pkg/epp/util/error/error_test.go
@@ -43,6 +43,14 @@ func TestError_Error(t *testing.T) {
},
want: "inference gateway: Internal - unexpected condition",
},
+ {
+ name: "ServiceUnavailable error",
+ err: Error{
+ Code: ServiceUnavailable,
+ Msg: "service unavailable",
+ },
+ want: "inference gateway: ServiceUnavailable - service unavailable",
+ },
{
name: "ModelServerError",
err: Error{
@@ -124,6 +132,14 @@ func TestCanonicalCode(t *testing.T) {
},
want: Internal,
},
+ {
+ name: "Error type with ServiceUnavailable code",
+ err: Error{
+ Code: ServiceUnavailable,
+ Msg: "Service unavailable error",
+ },
+ want: ServiceUnavailable,
+ },
{
name: "Error type with ModelServerError code",
err: Error{
@@ -205,6 +221,7 @@ func TestErrorConstants(t *testing.T) {
Unknown: "Unknown",
BadRequest: "BadRequest",
Internal: "Internal",
+ ServiceUnavailable: "ServiceUnavailable",
ModelServerError: "ModelServerError",
BadConfiguration: "BadConfiguration",
InferencePoolResourceExhausted: "InferencePoolResourceExhausted",
diff --git a/pkg/epp/util/logging/fatal.go b/pkg/epp/util/logging/fatal.go
index d8a9a9379..ddc15c400 100644
--- a/pkg/epp/util/logging/fatal.go
+++ b/pkg/epp/util/logging/fatal.go
@@ -25,7 +25,7 @@ import (
// Fatal calls logger.Error followed by os.Exit(1).
//
// This is a utility function and should not be used in production code!
-func Fatal(logger logr.Logger, err error, msg string, keysAndValues ...interface{}) {
+func Fatal(logger logr.Logger, err error, msg string, keysAndValues ...any) {
logger.Error(err, msg, keysAndValues...)
os.Exit(1)
}
diff --git a/pkg/epp/util/request/body.go b/pkg/epp/util/request/body.go
index 83a600f08..46de1fa54 100644
--- a/pkg/epp/util/request/body.go
+++ b/pkg/epp/util/request/body.go
@@ -22,14 +22,14 @@ import (
errutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/error"
)
-func ExtractPromptFromRequestBody(body map[string]interface{}) (string, error) {
+func ExtractPromptFromRequestBody(body map[string]any) (string, error) {
if _, ok := body["messages"]; ok {
return extractPromptFromMessagesField(body)
}
return extractPromptField(body)
}
-func extractPromptField(body map[string]interface{}) (string, error) {
+func extractPromptField(body map[string]any) (string, error) {
prompt, ok := body["prompt"]
if !ok {
return "", errutil.Error{Code: errutil.BadRequest, Msg: "prompt not found in request"}
@@ -41,12 +41,12 @@ func extractPromptField(body map[string]interface{}) (string, error) {
return promptStr, nil
}
-func extractPromptFromMessagesField(body map[string]interface{}) (string, error) {
+func extractPromptFromMessagesField(body map[string]any) (string, error) {
messages, ok := body["messages"]
if !ok {
return "", errutil.Error{Code: errutil.BadRequest, Msg: "messages not found in request"}
}
- messageList, ok := messages.([]interface{})
+ messageList, ok := messages.([]any)
if !ok {
return "", errutil.Error{Code: errutil.BadRequest, Msg: "messages is not a list"}
}
@@ -56,7 +56,7 @@ func extractPromptFromMessagesField(body map[string]interface{}) (string, error)
prompt := ""
for _, msg := range messageList {
- msgMap, ok := msg.(map[string]interface{})
+ msgMap, ok := msg.(map[string]any)
if !ok {
continue
}
diff --git a/pkg/epp/util/request/body_test.go b/pkg/epp/util/request/body_test.go
index 696bfd501..ce5a93921 100644
--- a/pkg/epp/util/request/body_test.go
+++ b/pkg/epp/util/request/body_test.go
@@ -23,23 +23,23 @@ import (
func TestExtractPromptFromRequestBody(t *testing.T) {
tests := []struct {
name string
- body map[string]interface{}
+ body map[string]any
want string
wantErr bool
errType error
}{
{
name: "chat completions request body",
- body: map[string]interface{}{
+ body: map[string]any{
"model": "test",
- "messages": []interface{}{
- map[string]interface{}{
+ "messages": []any{
+ map[string]any{
"role": "system", "content": "this is a system message",
},
- map[string]interface{}{
+ map[string]any{
"role": "user", "content": "hello",
},
- map[string]interface{}{
+ map[string]any{
"role": "assistant", "content": "hi, what can I do for you?",
},
},
@@ -50,7 +50,7 @@ func TestExtractPromptFromRequestBody(t *testing.T) {
},
{
name: "completions request body",
- body: map[string]interface{}{
+ body: map[string]any{
"model": "test",
"prompt": "test prompt",
},
@@ -58,16 +58,16 @@ func TestExtractPromptFromRequestBody(t *testing.T) {
},
{
name: "invalid prompt format",
- body: map[string]interface{}{
+ body: map[string]any{
"model": "test",
- "prompt": []interface{}{
- map[string]interface{}{
+ "prompt": []any{
+ map[string]any{
"role": "system", "content": "this is a system message",
},
- map[string]interface{}{
+ map[string]any{
"role": "user", "content": "hello",
},
- map[string]interface{}{
+ map[string]any{
"role": "assistant", "content": "hi, what can I",
},
},
@@ -76,9 +76,9 @@ func TestExtractPromptFromRequestBody(t *testing.T) {
},
{
name: "invalid messaged format",
- body: map[string]interface{}{
+ body: map[string]any{
"model": "test",
- "messages": map[string]interface{}{
+ "messages": map[string]any{
"role": "system", "content": "this is a system message",
},
},
@@ -86,7 +86,7 @@ func TestExtractPromptFromRequestBody(t *testing.T) {
},
{
name: "prompt does not exist",
- body: map[string]interface{}{
+ body: map[string]any{
"model": "test",
},
wantErr: true,
@@ -110,25 +110,25 @@ func TestExtractPromptFromRequestBody(t *testing.T) {
func TestExtractPromptField(t *testing.T) {
tests := []struct {
name string
- body map[string]interface{}
+ body map[string]any
want string
wantErr bool
}{
{
name: "valid prompt",
- body: map[string]interface{}{
+ body: map[string]any{
"prompt": "test prompt",
},
want: "test prompt",
},
{
name: "prompt not found",
- body: map[string]interface{}{},
+ body: map[string]any{},
wantErr: true,
},
{
name: "non-string prompt",
- body: map[string]interface{}{
+ body: map[string]any{
"prompt": 123,
},
wantErr: true,
@@ -152,23 +152,23 @@ func TestExtractPromptField(t *testing.T) {
func TestExtractPromptFromMessagesField(t *testing.T) {
tests := []struct {
name string
- body map[string]interface{}
+ body map[string]any
want string
wantErr bool
}{
{
name: "valid messages",
- body: map[string]interface{}{
- "messages": []interface{}{
- map[string]interface{}{"role": "user", "content": "test1"},
- map[string]interface{}{"role": "assistant", "content": "test2"},
+ body: map[string]any{
+ "messages": []any{
+ map[string]any{"role": "user", "content": "test1"},
+ map[string]any{"role": "assistant", "content": "test2"},
},
},
want: "<|im_start|>user\ntest1<|im_end|>\n<|im_start|>assistant\ntest2<|im_end|>\n",
},
{
name: "invalid messages format",
- body: map[string]interface{}{
+ body: map[string]any{
"messages": "invalid",
},
wantErr: true,
diff --git a/pkg/epp/util/request/metadata_test.go b/pkg/epp/util/request/metadata_test.go
index 0a0c71e35..f03d83e58 100644
--- a/pkg/epp/util/request/metadata_test.go
+++ b/pkg/epp/util/request/metadata_test.go
@@ -27,9 +27,9 @@ import (
func TestExtractMetadataValues(t *testing.T) {
var makeFilterMetadata = func() map[string]*structpb.Struct {
- structVal, _ := structpb.NewStruct(map[string]interface{}{
+ structVal, _ := structpb.NewStruct(map[string]any{
"hello": "world",
- "random-key": []interface{}{"hello", "world"},
+ "random-key": []any{"hello", "world"},
})
return map[string]*structpb.Struct{
@@ -46,9 +46,9 @@ func TestExtractMetadataValues(t *testing.T) {
name: "Exact match",
metadata: makeFilterMetadata(),
expected: map[string]any{
- "key-1": map[string]interface{}{
+ "key-1": map[string]any{
"hello": "world",
- "random-key": []interface{}{"hello", "world"},
+ "random-key": []any{"hello", "world"},
},
},
},
diff --git a/test/integration/epp/hermetic_test.go b/test/integration/epp/hermetic_test.go
index 2b9bb764a..6d439d17d 100644
--- a/test/integration/epp/hermetic_test.go
+++ b/test/integration/epp/hermetic_test.go
@@ -821,9 +821,9 @@ func TestFullDuplexStreamed_KubeInferenceModelRequest(t *testing.T) {
Response: &extProcPb.ProcessingResponse_ImmediateResponse{
ImmediateResponse: &extProcPb.ImmediateResponse{
Status: &envoyTypePb.HttpStatus{
- Code: envoyTypePb.StatusCode_TooManyRequests,
+ Code: envoyTypePb.StatusCode_ServiceUnavailable,
},
- Body: []byte("inference gateway: InferencePoolResourceExhausted - failed to find target pod: failed to run scheduler profile 'default'"),
+ Body: []byte("inference gateway: ServiceUnavailable - failed to find candidate pods for serving the request"),
},
},
},
diff --git a/test/integration/util.go b/test/integration/util.go
index a1baa33d2..925107bf8 100644
--- a/test/integration/util.go
+++ b/test/integration/util.go
@@ -87,7 +87,7 @@ func StreamedRequest(t *testing.T, client extProcPb.ExternalProcessor_ProcessCli
}
func GenerateRequest(logger logr.Logger, prompt, model string, filterMetadata []string) *extProcPb.ProcessingRequest {
- j := map[string]interface{}{
+ j := map[string]any{
"prompt": prompt,
"max_tokens": 100,
"temperature": 0,
@@ -139,12 +139,12 @@ func GenerateStreamedRequestSet(logger logr.Logger, prompt, model string, filter
func GenerateRequestMetadata(filterMetadata []string) map[string]*structpb.Struct {
metadata := make(map[string]*structpb.Struct)
- interfaceList := make([]interface{}, len(filterMetadata))
+ interfaceList := make([]any, len(filterMetadata))
for i, val := range filterMetadata {
interfaceList[i] = val
}
if filterMetadata != nil {
- structVal, _ := structpb.NewStruct(map[string]interface{}{
+ structVal, _ := structpb.NewStruct(map[string]any{
"x-gateway-destination-endpoint-subset": interfaceList,
})
metadata["envoy.lb.subset_hint"] = structVal