Skip to content

Commit bba1ce2

Browse files
committed
add admit request plugin and fix minor bugs
1 parent cc86091 commit bba1ce2

File tree

4 files changed

+189
-6
lines changed

4 files changed

+189
-6
lines changed

pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/prediction.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ type podPredictionResult struct {
4141
}
4242

4343
// generatePredictions creates prediction results for all candidate pods
44-
func (s *SLOAwareRouter) generatePredictions(ctx context.Context, state *schedulingtypes.CycleState, request *schedulingtypes.LLMRequest, sloCtx *sloRequestContext, candidatePods []schedulingtypes.Pod) ([]podPredictionResult, error) {
44+
func (s *SLOAwareRouter) generatePredictions(ctx context.Context, request *schedulingtypes.LLMRequest, sloCtx *sloRequestContext, candidatePods []schedulingtypes.Pod) ([]podPredictionResult, error) {
4545
logger := log.FromContext(ctx)
4646
predictions := make([]podPredictionResult, 0, len(candidatePods))
4747

@@ -55,7 +55,7 @@ func (s *SLOAwareRouter) generatePredictions(ctx context.Context, state *schedul
5555
logger.V(logutil.TRACE).Info("Candidate pod for scheduling", "pod", pod.GetPod().String(), "metrics", pod.GetMetrics().String())
5656

5757
// Get prefix cache score for the pod
58-
prefixCacheScore := s.getPrefixCacheScoreForPod(ctx, state, pod)
58+
prefixCacheScore := sloCtx.prefixCacheScoresForPods[pod.GetPod().String()]
5959
sloCtx.prefixCacheScoresForPods[pod.GetPod().String()] = prefixCacheScore
6060

6161
logger.V(logutil.DEBUG).Info("Prefix cache score for pod", "pod", pod.GetPod().String(), "prefixCacheScore", prefixCacheScore)

pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/requestcontrol_hooks.go

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ var _ requestcontrol.PreRequest = &SLOAwareRouter{}
3939
var _ requestcontrol.ResponseReceived = &SLOAwareRouter{}
4040
var _ requestcontrol.ResponseStreaming = &SLOAwareRouter{}
4141
var _ requestcontrol.ResponseComplete = &SLOAwareRouter{}
42+
var _ requestcontrol.AdmissionPlugin = &SLOAwareRouter{}
4243

4344
type sloRequestContext struct {
4445
schedulingRequest schedulingtypes.LLMRequest
@@ -107,6 +108,10 @@ func (s *SLOAwareRouter) deleteSLOContextForRequest(request *schedulingtypes.LLM
107108

108109
func (t *SLOAwareRouter) PreRequest(ctx context.Context, request *schedulingtypes.LLMRequest, schedulingResult *schedulingtypes.SchedulingResult) {
109110
logger := log.FromContext(ctx)
111+
if request == nil {
112+
logger.V(logutil.DEBUG).Info("SLOAwareRouter.PreRequest: request is nil, skipping")
113+
return
114+
}
110115

111116
if schedulingResult == nil || len(schedulingResult.ProfileResults) == 0 {
112117
logger.V(logutil.TRACE).Info("SLOAwareRouter: Skipping PreRequest because no scheduling result was provided.")
@@ -157,6 +162,10 @@ func (t *SLOAwareRouter) PreRequest(ctx context.Context, request *schedulingtype
157162

158163
func (t *SLOAwareRouter) ResponseReceived(ctx context.Context, request *schedulingtypes.LLMRequest, response *requestcontrol.Response, targetPod *backend.Pod) {
159164
logger := log.FromContext(ctx)
165+
if request == nil {
166+
logger.V(logutil.DEBUG).Info("SLOAwareRouter.ResponseReceived: request is nil, skipping")
167+
return
168+
}
160169
if !t.checkPredictor(logger, targetPod) {
161170
return
162171
}
@@ -177,6 +186,10 @@ func (t *SLOAwareRouter) ResponseReceived(ctx context.Context, request *scheduli
177186

178187
func (t *SLOAwareRouter) ResponseStreaming(ctx context.Context, request *schedulingtypes.LLMRequest, response *requestcontrol.Response, pod *backend.Pod) {
179188
logger := log.FromContext(ctx)
189+
if request == nil {
190+
logger.V(logutil.DEBUG).Info("SLOAwareRouter.ResponseStreaming: request is nil, skipping")
191+
return
192+
}
180193
if !t.checkPredictor(logger, pod) || response.EndOfStream {
181194
return
182195
}
@@ -199,6 +212,10 @@ func (t *SLOAwareRouter) ResponseStreaming(ctx context.Context, request *schedul
199212

200213
func (t *SLOAwareRouter) ResponseComplete(ctx context.Context, request *schedulingtypes.LLMRequest, response *requestcontrol.Response, pod *backend.Pod) {
201214
logger := log.FromContext(ctx)
215+
if request == nil {
216+
logger.V(logutil.DEBUG).Info("SLOAwareRouter.ResponseComplete: request is nil, skipping")
217+
return
218+
}
202219
targetPod := pod
203220
if !t.checkPredictor(logger, targetPod) {
204221
return
@@ -250,6 +267,25 @@ func (t *SLOAwareRouter) ResponseComplete(ctx context.Context, request *scheduli
250267
t.deleteSLOContextForRequest(request)
251268
}
252269

270+
func (t *SLOAwareRouter) AdmitRequest(ctx context.Context, request *schedulingtypes.LLMRequest, pods []schedulingtypes.Pod) error {
271+
logger := log.FromContext(ctx)
272+
if request == nil {
273+
logger.V(logutil.DEBUG).Info("SLOAwareRouter.AdmissionController: request is nil, skipping")
274+
return nil
275+
}
276+
sloCtx, err := t.getSLOContextForRequest(request)
277+
if err != nil {
278+
logger.V(logutil.DEBUG).Error(err, "SLOAwareRouter.AdmissionController: Failed to get SLO context for request")
279+
return nil
280+
}
281+
if sloCtx.hasValidPod {
282+
return nil
283+
}
284+
errMsg := "request cannot be admitted: no valid pod available based on SLO predictions"
285+
logger.V(logutil.TRACE).Error(errors.New(errMsg), "SLOAwareRouter.AdmissionController: No Valid Pod")
286+
return errors.New(errMsg)
287+
}
288+
253289
func (t *SLOAwareRouter) checkPredictor(logger logr.Logger, targetPod *backend.Pod) bool {
254290
if targetPod == nil {
255291
logger.V(logutil.TRACE).Info("SLOAwareRouter: Skipping hook because no target pod was provided.")

pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/requestcontrol_hooks_test.go

Lines changed: 142 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -864,6 +864,148 @@ func TestSLOAwareRouter_MultipleRequests_DifferentPods(t *testing.T) {
864864
assert.NotEqual(t, queue1, queue2)
865865
}
866866

867+
func TestSLOAwareRouter_AdmitRequest_NilRequest(t *testing.T) {
868+
router := createTestRouter()
869+
ctx := context.Background()
870+
pods := []schedulingtypes.Pod{}
871+
872+
err := router.AdmitRequest(ctx, nil, pods)
873+
874+
assert.NoError(t, err, "Should return nil for nil request")
875+
}
876+
877+
func TestSLOAwareRouter_AdmitRequest_NoContext(t *testing.T) {
878+
router := createTestRouter()
879+
ctx := context.Background()
880+
request := createTestLLMRequest("test", 100, 50, true)
881+
pods := []schedulingtypes.Pod{}
882+
883+
// Don't set SLO context
884+
err := router.AdmitRequest(ctx, request, pods)
885+
886+
assert.NoError(t, err, "Should return nil when SLO context not found")
887+
}
888+
889+
func TestSLOAwareRouter_AdmitRequest_HasValidPod(t *testing.T) {
890+
router := createTestRouter()
891+
ctx := context.Background()
892+
request := createTestLLMRequest("test", 100, 50, true)
893+
pods := []schedulingtypes.Pod{}
894+
895+
// Create SLO context with valid pod
896+
sloCtx := newSLORequestContext(request)
897+
sloCtx.hasValidPod = true
898+
router.setSLOContextForRequest(request, sloCtx)
899+
900+
err := router.AdmitRequest(ctx, request, pods)
901+
902+
assert.NoError(t, err, "Should admit request when hasValidPod is true")
903+
}
904+
905+
func TestSLOAwareRouter_AdmitRequest_NoValidPod(t *testing.T) {
906+
router := createTestRouter()
907+
ctx := context.Background()
908+
request := createTestLLMRequest("test", 100, 50, true)
909+
pods := []schedulingtypes.Pod{}
910+
911+
// Create SLO context without valid pod
912+
sloCtx := newSLORequestContext(request)
913+
sloCtx.hasValidPod = false
914+
router.setSLOContextForRequest(request, sloCtx)
915+
916+
err := router.AdmitRequest(ctx, request, pods)
917+
918+
assert.Error(t, err, "Should reject request when hasValidPod is false")
919+
assert.Contains(t, err.Error(), "no valid pod available based on SLO predictions")
920+
}
921+
922+
func TestSLOAwareRouter_AdmitRequest_WithMultiplePods(t *testing.T) {
923+
router := createTestRouter()
924+
ctx := context.Background()
925+
request := createTestLLMRequest("test", 100, 50, true)
926+
927+
// Create multiple test pods
928+
pod1 := createTestPod("test-pod-1", 1, 1, 1)
929+
pod2 := createTestPod("test-pod-2", 1, 1, 1)
930+
pods := []schedulingtypes.Pod{pod1, pod2}
931+
932+
// Create SLO context with valid pod
933+
sloCtx := newSLORequestContext(request)
934+
sloCtx.hasValidPod = true
935+
router.setSLOContextForRequest(request, sloCtx)
936+
937+
err := router.AdmitRequest(ctx, request, pods)
938+
939+
assert.NoError(t, err, "Should admit request with valid pod even with multiple pods available")
940+
}
941+
942+
func TestSLOAwareRouter_AdmitRequest_DefaultHasValidPod(t *testing.T) {
943+
router := createTestRouter()
944+
ctx := context.Background()
945+
request := createTestLLMRequest("test", 100, 50, true)
946+
pods := []schedulingtypes.Pod{}
947+
948+
// Create SLO context - hasValidPod defaults to false
949+
sloCtx := newSLORequestContext(request)
950+
router.setSLOContextForRequest(request, sloCtx)
951+
952+
err := router.AdmitRequest(ctx, request, pods)
953+
954+
assert.Error(t, err, "Should reject request when hasValidPod defaults to false")
955+
assert.Contains(t, err.Error(), "no valid pod available")
956+
}
957+
958+
func TestSLOAwareRouter_AdmitRequest_ConcurrentAccess(t *testing.T) {
959+
router := createTestRouter()
960+
ctx := context.Background()
961+
962+
var wg sync.WaitGroup
963+
numGoroutines := 50
964+
965+
// Half with valid pods, half without
966+
for i := 0; i < numGoroutines; i++ {
967+
wg.Add(1)
968+
go func(idx int) {
969+
defer wg.Done()
970+
971+
requestID := uuid.New().String()
972+
request := createTestLLMRequest(requestID, 100, 50, true)
973+
pods := []schedulingtypes.Pod{}
974+
975+
sloCtx := newSLORequestContext(request)
976+
sloCtx.hasValidPod = (idx%2 == 0) // Alternate between true and false
977+
router.setSLOContextForRequest(request, sloCtx)
978+
979+
err := router.AdmitRequest(ctx, request, pods)
980+
981+
if idx%2 == 0 {
982+
assert.NoError(t, err, "Should admit request with valid pod")
983+
} else {
984+
assert.Error(t, err, "Should reject request without valid pod")
985+
}
986+
}(i)
987+
}
988+
989+
wg.Wait()
990+
}
991+
992+
func TestSLOAwareRouter_AdmitRequest_ErrorMessage(t *testing.T) {
993+
router := createTestRouter()
994+
ctx := context.Background()
995+
request := createTestLLMRequest("test", 100, 50, true)
996+
pods := []schedulingtypes.Pod{}
997+
998+
sloCtx := newSLORequestContext(request)
999+
sloCtx.hasValidPod = false
1000+
router.setSLOContextForRequest(request, sloCtx)
1001+
1002+
err := router.AdmitRequest(ctx, request, pods)
1003+
1004+
require.Error(t, err)
1005+
expectedMsg := "request cannot be admitted: no valid pod available based on SLO predictions"
1006+
assert.Equal(t, expectedMsg, err.Error(), "Error message should match expected format")
1007+
}
1008+
8671009
func TestSLORequestContext_SLOValidation(t *testing.T) {
8681010
tests := []struct {
8691011
name string

pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/scorer.go

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,11 @@ func (s *SLOAwareRouter) Score(ctx context.Context, state *schedulingtypes.Cycle
153153

154154
s.parseSLOHeaders(ctx, request, sloCtx)
155155

156+
for _, pod := range pods {
157+
prefixCacheScore := s.getPrefixCacheScoreForPod(ctx, state, pod)
158+
sloCtx.prefixCacheScoresForPods[pod.GetPod().String()] = prefixCacheScore
159+
}
160+
156161
// Check if SLOs are provided
157162
if !sloCtx.predictorBasedScheduling {
158163
logger.V(logutil.DEBUG).Info("PredictorBasedScheduling turned off, skipping prediction-based filtering")
@@ -168,7 +173,7 @@ func (s *SLOAwareRouter) Score(ctx context.Context, state *schedulingtypes.Cycle
168173

169174
source := rand.NewSource(time.Now().UnixNano())
170175
r := rand.New(source)
171-
predictions, err := s.generatePredictions(ctx, state, request, sloCtx, pods)
176+
predictions, err := s.generatePredictions(ctx, request, sloCtx, pods)
172177
if err != nil {
173178
logger.V(logutil.DEBUG).Error(err, "SLOAwareRouter: Error generating predictions, falling back to composite-only scoring")
174179
// Fall back to composite-only scoring using prefix cache scores
@@ -180,12 +185,12 @@ func (s *SLOAwareRouter) Score(ctx context.Context, state *schedulingtypes.Cycle
180185
allPreds := append([]podPredictionResult(nil), predictions...)
181186
allPreds, sticky := s.epsilonGreedyAffinityGate(ctx, allPreds, r, "overall", AffinityGateTauGlobal)
182187

183-
// Check if all pods are invalid and all have running requests
184-
allPodsInvalid := true
188+
// Check if all pods are invalid and all have running requests. If slos are == 0 then all pods are valid
189+
allPodsInvalid := (sloCtx.ttftSLO > 0 && sloCtx.avgTPOTSLO > 0)
185190
allPodsHaveRunningRequests := true
186191

187192
for _, pred := range allPreds {
188-
if pred.IsValid {
193+
if pred.IsValid && pred.TTFTValid && pred.TPOTValid {
189194
allPodsInvalid = false
190195
}
191196

0 commit comments

Comments
 (0)