Skip to content

Commit 1c52369

Browse files
nirrozenbaum and BenjaminBraunDev
authored and committed
convert subset filter from a plugin to logic in director (kubernetes-sigs#1088)
* convert subset filter from a plugin to logic in director — Signed-off-by: Nir Rozenbaum <[email protected]>
* replace interface{} with any — Signed-off-by: Nir Rozenbaum <[email protected]>
* make linter happy — Signed-off-by: Nir Rozenbaum <[email protected]>
* address code review comments — Signed-off-by: Nir Rozenbaum <[email protected]>
---------
Signed-off-by: Nir Rozenbaum <[email protected]>
1 parent c957191 commit 1c52369

File tree

23 files changed

+845
-910
lines changed

23 files changed

+845
-910
lines changed

cmd/epp/runner/runner.go

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,6 @@ import (
5050
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/saturationdetector"
5151
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling"
5252
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework"
53-
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework/plugins/filter"
5453
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework/plugins/multi/prefix"
5554
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework/plugins/picker"
5655
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework/plugins/profile"
@@ -321,7 +320,6 @@ func (r *Runner) initializeScheduler(datastore datastore.Datastore) (*scheduling
321320
kvCacheScorerWeight := envutil.GetEnvInt("KV_CACHE_SCORE_WEIGHT", scorer.DefaultKVCacheScorerWeight, setupLog)
322321

323322
schedulerProfile := framework.NewSchedulerProfile().
324-
WithFilters(filter.NewSubsetFilter()).
325323
WithScorers(framework.NewWeightedScorer(scorer.NewQueueScorer(), queueScorerWeight),
326324
framework.NewWeightedScorer(scorer.NewKVCacheScorer(), kvCacheScorerWeight)).
327325
WithPicker(picker.NewMaxScorePicker())

pkg/bbr/handlers/server.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ type streamedBody struct {
118118
func (s *Server) processRequestBody(ctx context.Context, body *extProcPb.HttpBody, streamedBody *streamedBody, logger logr.Logger) ([]*extProcPb.ProcessingResponse, error) {
119119
loggerVerbose := logger.V(logutil.VERBOSE)
120120

121-
var requestBody map[string]interface{}
121+
var requestBody map[string]any
122122
if s.streaming {
123123
streamedBody.body = append(streamedBody.body, body.Body...)
124124
// In the stream case, we can receive multiple request bodies.

pkg/epp/backend/metrics/metrics_state.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,8 @@ import (
2121
"time"
2222
)
2323

24-
// newMetricsState initializes a new MetricsState and returns its pointer.
25-
func newMetricsState() *MetricsState {
24+
// NewMetricsState initializes a new MetricsState and returns its pointer.
25+
func NewMetricsState() *MetricsState {
2626
return &MetricsState{
2727
ActiveModels: make(map[string]int),
2828
WaitingModels: make(map[string]int),

pkg/epp/backend/metrics/types.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ func (f *PodMetricsFactory) NewPodMetrics(parentCtx context.Context, in *corev1.
5151
logger: log.FromContext(parentCtx).WithValues("pod", pod.NamespacedName),
5252
}
5353
pm.pod.Store(pod)
54-
pm.metrics.Store(newMetricsState())
54+
pm.metrics.Store(NewMetricsState())
5555

5656
pm.startRefreshLoop(parentCtx)
5757
return pm

pkg/epp/handlers/response.go

Lines changed: 4 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -37,19 +37,15 @@ const (
3737
)
3838

3939
// HandleResponseBody always returns the requestContext even in the error case, as the request context is used in error handling.
40-
func (s *StreamingServer) HandleResponseBody(
41-
ctx context.Context,
42-
reqCtx *RequestContext,
43-
response map[string]interface{},
44-
) (*RequestContext, error) {
40+
func (s *StreamingServer) HandleResponseBody(ctx context.Context, reqCtx *RequestContext, response map[string]any) (*RequestContext, error) {
4541
logger := log.FromContext(ctx)
4642
responseBytes, err := json.Marshal(response)
4743
if err != nil {
4844
logger.V(logutil.DEFAULT).Error(err, "error marshalling responseBody")
4945
return reqCtx, err
5046
}
5147
if response["usage"] != nil {
52-
usg := response["usage"].(map[string]interface{})
48+
usg := response["usage"].(map[string]any)
5349
usage := Usage{
5450
PromptTokens: int(usg["prompt_tokens"].(float64)),
5551
CompletionTokens: int(usg["completion_tokens"].(float64)),
@@ -71,11 +67,7 @@ func (s *StreamingServer) HandleResponseBody(
7167
}
7268

7369
// The function is to handle streaming response if the modelServer is streaming.
74-
func (s *StreamingServer) HandleResponseBodyModelStreaming(
75-
ctx context.Context,
76-
reqCtx *RequestContext,
77-
responseText string,
78-
) {
70+
func (s *StreamingServer) HandleResponseBodyModelStreaming(ctx context.Context, reqCtx *RequestContext, responseText string) {
7971
if strings.Contains(responseText, streamingEndMsg) {
8072
resp := parseRespForUsage(ctx, responseText)
8173
reqCtx.Usage = resp.Usage
@@ -280,10 +272,7 @@ func (s *StreamingServer) generateResponseTrailers(reqCtx *RequestContext) []*co
280272
//
281273
// If include_usage is not included in the request, `data: [DONE]` is returned separately, which
282274
// indicates end of streaming.
283-
func parseRespForUsage(
284-
ctx context.Context,
285-
responseText string,
286-
) ResponseBody {
275+
func parseRespForUsage(ctx context.Context, responseText string) ResponseBody {
287276
response := ResponseBody{}
288277
logger := log.FromContext(ctx)
289278

pkg/epp/handlers/response_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ func TestHandleResponseBody(t *testing.T) {
8686
if reqCtx == nil {
8787
reqCtx = &RequestContext{}
8888
}
89-
var responseMap map[string]interface{}
89+
var responseMap map[string]any
9090
marshalErr := json.Unmarshal(test.body, &responseMap)
9191
if marshalErr != nil {
9292
t.Error(marshalErr, "Error unmarshaling request body")

pkg/epp/handlers/server.go

Lines changed: 70 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,6 @@ type StreamingServer struct {
8080
director Director
8181
}
8282

83-
8483
// RequestContext stores context information during the life time of an HTTP request.
8584
// TODO: The requestContext is gathering a ton of fields. A future refactor needs to tease these fields apart.
8685
// Specifically, there are fields related to the ext-proc protocol, and then fields related to the lifecycle of the request.
@@ -92,33 +91,33 @@ type RequestContext struct {
9291
ResolvedTargetModel string
9392
RequestReceivedTimestamp time.Time
9493
ResponseCompleteTimestamp time.Time
95-
FirstTokenTimestamp time.Time
96-
LastTokenTimestamp time.Time
94+
FirstTokenTimestamp time.Time
95+
LastTokenTimestamp time.Time
9796
RequestSize int
9897
Usage Usage
9998
ResponseSize int
10099
ResponseComplete bool
101100
ResponseStatusCode string
102101
RequestRunning bool
103102
Request *Request
104-
Prompt string
105-
GeneratedTokenCount int
103+
Prompt string
104+
GeneratedTokenCount int
106105

107-
LastSeenMetrics *backendmetrics.MetricsState
108-
SchedulingResult *schedulingtypes.SchedulingResult
106+
LastSeenMetrics *backendmetrics.MetricsState
107+
SchedulingResult *schedulingtypes.SchedulingResult
109108

110109
SchedulingRequest *schedulingtypes.LLMRequest
111110

112111
RequestState StreamRequestState
113112
ModelServerStreaming bool
114113

115-
TTFT float64
114+
TTFT float64
116115
PredictedTTFT float64
117116

118117
PredictedTPOTObservations []float64
119-
TPOTObservations []float64
120-
AvgTPOT float64
121-
AvgPredictedTPOT float64
118+
TPOTObservations []float64
119+
AvgTPOT float64
120+
AvgPredictedTPOT float64
122121

123122
TokenSampler *requtil.TokenSampler
124123

@@ -133,17 +132,14 @@ type RequestContext struct {
133132
respTrailerResp *extProcPb.ProcessingResponse
134133
}
135134

136-
137-
138135
type Request struct {
139136
Headers map[string]string
140-
Body map[string]interface{}
137+
Body map[string]any
141138
Metadata map[string]any
142139
}
143140
type Response struct {
144-
Headers map[string]string
141+
Headers map[string]string
145142
Trailers map[string]string
146-
147143
}
148144
type StreamRequestState int
149145

@@ -170,17 +166,17 @@ func (s *StreamingServer) Process(srv extProcPb.ExternalProcessor_ProcessServer)
170166
RequestState: RequestReceived,
171167
Request: &Request{
172168
Headers: make(map[string]string),
173-
Body: make(map[string]interface{}),
169+
Body: make(map[string]any),
174170
Metadata: make(map[string]any),
175171
},
176172
Response: &Response{
177-
Headers: make(map[string]string),
173+
Headers: make(map[string]string),
178174
Trailers: make(map[string]string),
179175
},
180176
}
181177

182178
var body []byte
183-
var responseBody map[string]interface{}
179+
var responseBody map[string]any
184180

185181
// Create error handling var as each request should only report once for
186182
// error metrics. This doesn't cover the error "Cannot receive stream request" because
@@ -302,49 +298,44 @@ func (s *StreamingServer) Process(srv extProcPb.ExternalProcessor_ProcessServer)
302298
metrics.RecordResponseSizes(reqCtx.Model, reqCtx.ResolvedTargetModel, reqCtx.ResponseSize)
303299

304300
if s.director.IsPredictorAvailable() {
305-
// var sumActual, sumPred float64
306-
// for _, actual := range reqCtx.TPOTObservations {
307-
// sumActual += actual
308-
309-
// }
310-
// for _, prediction := range reqCtx.PredictedTPOTObservations {
311-
// sumPred += prediction
312-
313-
// }
314-
315-
// avgActual := sumActual / float64(len(reqCtx.TPOTObservations))
316-
// avgPred := sumPred / float64(len(reqCtx.PredictedTPOTObservations))
317-
318-
// reqCtx.AvgTPOT = avgActual
319-
// reqCtx.AvgPredictedTPOT = avgPred
320-
321-
322-
// Compute MAPE for TTFT
323-
mapeTTFT := 0.0
324-
if reqCtx.TTFT > 0 {
325-
mapeTTFT = math.Abs((reqCtx.TTFT-reqCtx.PredictedTTFT)/reqCtx.TTFT) * 100
326-
logger.V(logutil.DEBUG).Info("Averages calculated", "avgActualTTFT", reqCtx.TTFT, "avgPredictedTTFT", reqCtx.PredictedTTFT)
327-
logger.V(logutil.DEBUG).Info("MAPE TTFT computed", "mapeTTFT%", mapeTTFT)
328-
metrics.RecordRequestTTFT(ctx, reqCtx.Model, reqCtx.ResolvedTargetModel, reqCtx.TTFT/1000)
329-
metrics.RecordRequestPredictedTTFT(ctx, reqCtx.Model, reqCtx.ResolvedTargetModel, reqCtx.PredictedTTFT/1000)
330-
metrics.RecordRequestTTFTPredictionMape(ctx, reqCtx.Model, reqCtx.ResolvedTargetModel, mapeTTFT)
331-
332-
}
333-
334-
335-
mapeTPOT := 0.0
336-
if reqCtx.AvgTPOT > 0 {
337-
mapeTPOT = math.Abs((reqCtx.AvgTPOT-reqCtx.AvgPredictedTPOT)/reqCtx.AvgTPOT) * 100
338-
logger.V(logutil.DEBUG).Info("Averages calculated", "avgActualTPOT", reqCtx.AvgTPOT, "avgPredictedTPOT", reqCtx.AvgPredictedTPOT)
339-
logger.V(logutil.DEBUG).Info("MAPE TPOT computed", "mapeTPOT%", mapeTPOT)
340-
metrics.RecordRequestTPOT(ctx, reqCtx.Model, reqCtx.ResolvedTargetModel, reqCtx.AvgTPOT/1000)
341-
metrics.RecordRequestPredictedTPOT(ctx, reqCtx.Model, reqCtx.ResolvedTargetModel, reqCtx.AvgPredictedTPOT/1000)
342-
metrics.RecordRequestTPOTPredictionMape(ctx, reqCtx.Model, reqCtx.ResolvedTargetModel, mapeTPOT)
301+
// var sumActual, sumPred float64
302+
// for _, actual := range reqCtx.TPOTObservations {
303+
// sumActual += actual
304+
305+
// }
306+
// for _, prediction := range reqCtx.PredictedTPOTObservations {
307+
// sumPred += prediction
308+
309+
// }
310+
311+
// avgActual := sumActual / float64(len(reqCtx.TPOTObservations))
312+
// avgPred := sumPred / float64(len(reqCtx.PredictedTPOTObservations))
313+
314+
// reqCtx.AvgTPOT = avgActual
315+
// reqCtx.AvgPredictedTPOT = avgPred
316+
317+
// Compute MAPE for TTFT
318+
mapeTTFT := 0.0
319+
if reqCtx.TTFT > 0 {
320+
mapeTTFT = math.Abs((reqCtx.TTFT-reqCtx.PredictedTTFT)/reqCtx.TTFT) * 100
321+
logger.V(logutil.DEBUG).Info("Averages calculated", "avgActualTTFT", reqCtx.TTFT, "avgPredictedTTFT", reqCtx.PredictedTTFT)
322+
logger.V(logutil.DEBUG).Info("MAPE TTFT computed", "mapeTTFT%", mapeTTFT)
323+
metrics.RecordRequestTTFT(ctx, reqCtx.Model, reqCtx.ResolvedTargetModel, reqCtx.TTFT/1000)
324+
metrics.RecordRequestPredictedTTFT(ctx, reqCtx.Model, reqCtx.ResolvedTargetModel, reqCtx.PredictedTTFT/1000)
325+
metrics.RecordRequestTTFTPredictionMape(ctx, reqCtx.Model, reqCtx.ResolvedTargetModel, mapeTTFT)
326+
327+
}
328+
329+
mapeTPOT := 0.0
330+
if reqCtx.AvgTPOT > 0 {
331+
mapeTPOT = math.Abs((reqCtx.AvgTPOT-reqCtx.AvgPredictedTPOT)/reqCtx.AvgTPOT) * 100
332+
logger.V(logutil.DEBUG).Info("Averages calculated", "avgActualTPOT", reqCtx.AvgTPOT, "avgPredictedTPOT", reqCtx.AvgPredictedTPOT)
333+
logger.V(logutil.DEBUG).Info("MAPE TPOT computed", "mapeTPOT%", mapeTPOT)
334+
metrics.RecordRequestTPOT(ctx, reqCtx.Model, reqCtx.ResolvedTargetModel, reqCtx.AvgTPOT/1000)
335+
metrics.RecordRequestPredictedTPOT(ctx, reqCtx.Model, reqCtx.ResolvedTargetModel, reqCtx.AvgPredictedTPOT/1000)
336+
metrics.RecordRequestTPOTPredictionMape(ctx, reqCtx.Model, reqCtx.ResolvedTargetModel, mapeTPOT)
337+
}
343338
}
344-
}
345-
346-
347-
348339

349340
}
350341

@@ -380,21 +371,21 @@ func (s *StreamingServer) Process(srv extProcPb.ExternalProcessor_ProcessServer)
380371
}
381372
case *extProcPb.ProcessingRequest_ResponseTrailers:
382373
logger.V(logutil.DEFAULT).Info("Processing response trailers", "trailers", v.ResponseTrailers.Trailers)
383-
if reqCtx.ModelServerStreaming{
384-
374+
if reqCtx.ModelServerStreaming {
375+
385376
var trailerErr error
386377
reqCtx, trailerErr = s.HandleResponseTrailers(ctx, reqCtx)
387378
if trailerErr != nil {
388-
logger.V(logutil.DEFAULT).Error(trailerErr, "Failed to process response trailers")
389-
}
379+
logger.V(logutil.DEFAULT).Error(trailerErr, "Failed to process response trailers")
380+
}
390381
reqCtx.respTrailerResp = s.generateResponseTrailerResponse(reqCtx)
391-
}
382+
}
392383
}
393384

394385
// Handle the err and fire an immediate response.
395386
if err != nil {
396387
logger.V(logutil.DEFAULT).Error(err, "Failed to process request", "request", req)
397-
resp, err := BuildErrResponse(err)
388+
resp, err := buildErrResponse(err)
398389
if err != nil {
399390
return err
400391
}
@@ -475,9 +466,7 @@ func (r *RequestContext) updateStateAndSendIfNeeded(srv extProcPb.ExternalProces
475466
return nil
476467
}
477468

478-
479-
480-
func BuildErrResponse(err error) (*extProcPb.ProcessingResponse, error) {
469+
func buildErrResponse(err error) (*extProcPb.ProcessingResponse, error) {
481470
var resp *extProcPb.ProcessingResponse
482471

483472
switch errutil.CanonicalCode(err) {
@@ -504,6 +493,17 @@ func BuildErrResponse(err error) (*extProcPb.ProcessingResponse, error) {
504493
},
505494
},
506495
}
496+
// This code can be returned by the director when there are no candidate pods for the request scheduling.
497+
case errutil.ServiceUnavailable:
498+
resp = &extProcPb.ProcessingResponse{
499+
Response: &extProcPb.ProcessingResponse_ImmediateResponse{
500+
ImmediateResponse: &extProcPb.ImmediateResponse{
501+
Status: &envoyTypePb.HttpStatus{
502+
Code: envoyTypePb.StatusCode_ServiceUnavailable,
503+
},
504+
},
505+
},
506+
}
507507
// This code can be returned when users provide invalid json request.
508508
case errutil.BadRequest:
509509
resp = &extProcPb.ProcessingResponse{

0 commit comments

Comments (0)