Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ type podPredictionResult struct {
}

// generatePredictions creates prediction results for all candidate pods
func (s *SLOAwareRouter) generatePredictions(ctx context.Context, state *schedulingtypes.CycleState, request *schedulingtypes.LLMRequest, sloCtx *sloRequestContext, candidatePods []schedulingtypes.Pod) ([]podPredictionResult, error) {
func (s *SLOAwareRouter) generatePredictions(ctx context.Context, request *schedulingtypes.LLMRequest, sloCtx *sloRequestContext, candidatePods []schedulingtypes.Pod) ([]podPredictionResult, error) {
logger := log.FromContext(ctx)
predictions := make([]podPredictionResult, 0, len(candidatePods))

Expand All @@ -55,7 +55,7 @@ func (s *SLOAwareRouter) generatePredictions(ctx context.Context, state *schedul
logger.V(logutil.TRACE).Info("Candidate pod for scheduling", "pod", pod.GetPod().String(), "metrics", pod.GetMetrics().String())

// Get prefix cache score for the pod
prefixCacheScore := s.getPrefixCacheScoreForPod(ctx, state, pod)
prefixCacheScore := sloCtx.prefixCacheScoresForPods[pod.GetPod().String()]
sloCtx.prefixCacheScoresForPods[pod.GetPod().String()] = prefixCacheScore

logger.V(logutil.DEBUG).Info("Prefix cache score for pod", "pod", pod.GetPod().String(), "prefixCacheScore", prefixCacheScore)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ var _ requestcontrol.PreRequest = &SLOAwareRouter{}
// Compile-time assertions: each line fails to build unless *SLOAwareRouter
// satisfies the named requestcontrol interface.
var _ requestcontrol.ResponseReceived = &SLOAwareRouter{}
var _ requestcontrol.ResponseStreaming = &SLOAwareRouter{}
var _ requestcontrol.ResponseComplete = &SLOAwareRouter{}
var _ requestcontrol.AdmissionPlugin = &SLOAwareRouter{}

type sloRequestContext struct {
schedulingRequest schedulingtypes.LLMRequest
Expand Down Expand Up @@ -107,6 +108,10 @@ func (s *SLOAwareRouter) deleteSLOContextForRequest(request *schedulingtypes.LLM

func (t *SLOAwareRouter) PreRequest(ctx context.Context, request *schedulingtypes.LLMRequest, schedulingResult *schedulingtypes.SchedulingResult) {
logger := log.FromContext(ctx)
if request == nil {
logger.V(logutil.DEBUG).Info("SLOAwareRouter.PreRequest: request is nil, skipping")
return
}

if schedulingResult == nil || len(schedulingResult.ProfileResults) == 0 {
logger.V(logutil.TRACE).Info("SLOAwareRouter: Skipping PreRequest because no scheduling result was provided.")
Expand Down Expand Up @@ -157,6 +162,10 @@ func (t *SLOAwareRouter) PreRequest(ctx context.Context, request *schedulingtype

func (t *SLOAwareRouter) ResponseReceived(ctx context.Context, request *schedulingtypes.LLMRequest, response *requestcontrol.Response, targetPod *backend.Pod) {
logger := log.FromContext(ctx)
if request == nil {
logger.V(logutil.DEBUG).Info("SLOAwareRouter.ResponseReceived: request is nil, skipping")
return
}
if !t.checkPredictor(logger, targetPod) {
return
}
Expand All @@ -177,6 +186,10 @@ func (t *SLOAwareRouter) ResponseReceived(ctx context.Context, request *scheduli

func (t *SLOAwareRouter) ResponseStreaming(ctx context.Context, request *schedulingtypes.LLMRequest, response *requestcontrol.Response, pod *backend.Pod) {
logger := log.FromContext(ctx)
if request == nil {
logger.V(logutil.DEBUG).Info("SLOAwareRouter.ResponseStreaming: request is nil, skipping")
return
}
if !t.checkPredictor(logger, pod) || response.EndOfStream {
return
}
Expand All @@ -199,6 +212,10 @@ func (t *SLOAwareRouter) ResponseStreaming(ctx context.Context, request *schedul

func (t *SLOAwareRouter) ResponseComplete(ctx context.Context, request *schedulingtypes.LLMRequest, response *requestcontrol.Response, pod *backend.Pod) {
logger := log.FromContext(ctx)
if request == nil {
logger.V(logutil.DEBUG).Info("SLOAwareRouter.ResponseComplete: request is nil, skipping")
return
}
targetPod := pod
if !t.checkPredictor(logger, targetPod) {
return
Expand Down Expand Up @@ -250,6 +267,25 @@ func (t *SLOAwareRouter) ResponseComplete(ctx context.Context, request *scheduli
t.deleteSLOContextForRequest(request)
}

// AdmitRequest decides whether a request may proceed, based on the SLO
// context stored for it.
//
// It returns nil (admit) when the request is nil, when the SLO context for
// the request cannot be retrieved, or when that context's hasValidPod flag is
// set. Otherwise it returns an error describing the rejection. The pods
// argument is not consulted.
func (t *SLOAwareRouter) AdmitRequest(ctx context.Context, request *schedulingtypes.LLMRequest, pods []schedulingtypes.Pod) error {
	logger := log.FromContext(ctx)

	if request == nil {
		logger.V(logutil.DEBUG).Info("SLOAwareRouter.AdmissionController: request is nil, skipping")
		return nil
	}

	sloCtx, lookupErr := t.getSLOContextForRequest(request)
	switch {
	case lookupErr != nil:
		// A context lookup failure is logged but does not reject the request.
		logger.V(logutil.DEBUG).Error(lookupErr, "SLOAwareRouter.AdmissionController: Failed to get SLO context for request")
		return nil
	case sloCtx.hasValidPod:
		return nil
	}

	const rejection = "request cannot be admitted: no valid pod available based on SLO predictions"
	logger.V(logutil.TRACE).Error(errors.New(rejection), "SLOAwareRouter.AdmissionController: No Valid Pod")
	return errors.New(rejection)
}

func (t *SLOAwareRouter) checkPredictor(logger logr.Logger, targetPod *backend.Pod) bool {
if targetPod == nil {
logger.V(logutil.TRACE).Info("SLOAwareRouter: Skipping hook because no target pod was provided.")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -864,6 +864,148 @@ func TestSLOAwareRouter_MultipleRequests_DifferentPods(t *testing.T) {
assert.NotEqual(t, queue1, queue2)
}

// TestSLOAwareRouter_AdmitRequest_NilRequest verifies that AdmitRequest
// returns no error when it is handed a nil request.
func TestSLOAwareRouter_AdmitRequest_NilRequest(t *testing.T) {
	var (
		sut        = createTestRouter()
		candidates = []schedulingtypes.Pod{}
	)

	got := sut.AdmitRequest(context.Background(), nil, candidates)

	assert.NoError(t, got, "Should return nil for nil request")
}

// TestSLOAwareRouter_AdmitRequest_NoContext verifies that AdmitRequest
// returns no error for a request that never had an SLO context registered.
func TestSLOAwareRouter_AdmitRequest_NoContext(t *testing.T) {
	sut := createTestRouter()
	req := createTestLLMRequest("test", 100, 50, true)

	// Deliberately skip setSLOContextForRequest so the lookup finds nothing.
	got := sut.AdmitRequest(context.Background(), req, []schedulingtypes.Pod{})

	assert.NoError(t, got, "Should return nil when SLO context not found")
}

// TestSLOAwareRouter_AdmitRequest_HasValidPod verifies that a request whose
// SLO context has hasValidPod set to true is admitted.
func TestSLOAwareRouter_AdmitRequest_HasValidPod(t *testing.T) {
	sut := createTestRouter()
	req := createTestLLMRequest("test", 100, 50, true)

	// Register an SLO context that reports a valid pod.
	stored := newSLORequestContext(req)
	stored.hasValidPod = true
	sut.setSLOContextForRequest(req, stored)

	got := sut.AdmitRequest(context.Background(), req, []schedulingtypes.Pod{})

	assert.NoError(t, got, "Should admit request when hasValidPod is true")
}

// TestSLOAwareRouter_AdmitRequest_NoValidPod verifies that a request whose
// SLO context has hasValidPod set to false is rejected with the expected
// error text.
func TestSLOAwareRouter_AdmitRequest_NoValidPod(t *testing.T) {
	router := createTestRouter()
	ctx := context.Background()
	request := createTestLLMRequest("test", 100, 50, true)
	pods := []schedulingtypes.Pod{}

	// Create SLO context without valid pod
	sloCtx := newSLORequestContext(request)
	sloCtx.hasValidPod = false
	router.setSLOContextForRequest(request, sloCtx)

	err := router.AdmitRequest(ctx, request, pods)

	// require (not assert) so the test stops here on a nil error instead of
	// panicking on err.Error() below.
	require.Error(t, err, "Should reject request when hasValidPod is false")
	assert.Contains(t, err.Error(), "no valid pod available based on SLO predictions")
}

// TestSLOAwareRouter_AdmitRequest_WithMultiplePods verifies that a request
// with hasValidPod set is admitted when several candidate pods are passed in.
func TestSLOAwareRouter_AdmitRequest_WithMultiplePods(t *testing.T) {
	sut := createTestRouter()
	req := createTestLLMRequest("test", 100, 50, true)

	// Two candidate pods handed to the admission check.
	candidates := []schedulingtypes.Pod{
		createTestPod("test-pod-1", 1, 1, 1),
		createTestPod("test-pod-2", 1, 1, 1),
	}

	// Register an SLO context that reports a valid pod.
	stored := newSLORequestContext(req)
	stored.hasValidPod = true
	sut.setSLOContextForRequest(req, stored)

	got := sut.AdmitRequest(context.Background(), req, candidates)

	assert.NoError(t, got, "Should admit request with valid pod even with multiple pods available")
}

// TestSLOAwareRouter_AdmitRequest_DefaultHasValidPod verifies that a freshly
// created SLO context, with hasValidPod left untouched, causes the request to
// be rejected.
func TestSLOAwareRouter_AdmitRequest_DefaultHasValidPod(t *testing.T) {
	router := createTestRouter()
	ctx := context.Background()
	request := createTestLLMRequest("test", 100, 50, true)
	pods := []schedulingtypes.Pod{}

	// Create SLO context - hasValidPod defaults to false
	sloCtx := newSLORequestContext(request)
	router.setSLOContextForRequest(request, sloCtx)

	err := router.AdmitRequest(ctx, request, pods)

	// require (not assert) so the test stops here on a nil error instead of
	// panicking on err.Error() below.
	require.Error(t, err, "Should reject request when hasValidPod defaults to false")
	assert.Contains(t, err.Error(), "no valid pod available")
}

// TestSLOAwareRouter_AdmitRequest_ConcurrentAccess runs AdmitRequest from
// many goroutines against one shared router. Each goroutine uses its own
// request ID; even-numbered workers register a context with a valid pod and
// expect admission, odd-numbered workers expect rejection.
func TestSLOAwareRouter_AdmitRequest_ConcurrentAccess(t *testing.T) {
	const numGoroutines = 50

	sut := createTestRouter()
	ctx := context.Background()

	var wg sync.WaitGroup
	wg.Add(numGoroutines)

	worker := func(n int) {
		defer wg.Done()

		// Even workers get a valid pod, odd workers do not.
		shouldAdmit := n%2 == 0

		req := createTestLLMRequest(uuid.New().String(), 100, 50, true)

		stored := newSLORequestContext(req)
		stored.hasValidPod = shouldAdmit
		sut.setSLOContextForRequest(req, stored)

		got := sut.AdmitRequest(ctx, req, []schedulingtypes.Pod{})

		if !shouldAdmit {
			assert.Error(t, got, "Should reject request without valid pod")
			return
		}
		assert.NoError(t, got, "Should admit request with valid pod")
	}

	for n := 0; n < numGoroutines; n++ {
		go worker(n)
	}

	wg.Wait()
}

// TestSLOAwareRouter_AdmitRequest_ErrorMessage pins the exact text of the
// error returned when a request is rejected for lack of a valid pod.
func TestSLOAwareRouter_AdmitRequest_ErrorMessage(t *testing.T) {
	const want = "request cannot be admitted: no valid pod available based on SLO predictions"

	sut := createTestRouter()
	req := createTestLLMRequest("test", 100, 50, true)

	// Register an SLO context that reports no valid pod.
	stored := newSLORequestContext(req)
	stored.hasValidPod = false
	sut.setSLOContextForRequest(req, stored)

	got := sut.AdmitRequest(context.Background(), req, []schedulingtypes.Pod{})

	// Stop immediately on a nil error; got.Error() below would panic otherwise.
	require.Error(t, got)
	assert.Equal(t, want, got.Error(), "Error message should match expected format")
}

func TestSLORequestContext_SLOValidation(t *testing.T) {
tests := []struct {
name string
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,11 @@ func (s *SLOAwareRouter) Score(ctx context.Context, state *schedulingtypes.Cycle

s.parseSLOHeaders(ctx, request, sloCtx)

for _, pod := range pods {
prefixCacheScore := s.getPrefixCacheScoreForPod(ctx, state, pod)
sloCtx.prefixCacheScoresForPods[pod.GetPod().String()] = prefixCacheScore
}

// Check if SLOs are provided
if !sloCtx.predictorBasedScheduling {
logger.V(logutil.DEBUG).Info("PredictorBasedScheduling turned off, skipping prediction-based filtering")
Expand All @@ -168,7 +173,7 @@ func (s *SLOAwareRouter) Score(ctx context.Context, state *schedulingtypes.Cycle

source := rand.NewSource(time.Now().UnixNano())
r := rand.New(source)
predictions, err := s.generatePredictions(ctx, state, request, sloCtx, pods)
predictions, err := s.generatePredictions(ctx, request, sloCtx, pods)
if err != nil {
logger.V(logutil.DEBUG).Error(err, "SLOAwareRouter: Error generating predictions, falling back to composite-only scoring")
// Fall back to composite-only scoring using prefix cache scores
Expand All @@ -180,8 +185,8 @@ func (s *SLOAwareRouter) Score(ctx context.Context, state *schedulingtypes.Cycle
allPreds := append([]podPredictionResult(nil), predictions...)
allPreds, sticky := s.epsilonGreedyAffinityGate(ctx, allPreds, r, "overall", AffinityGateTauGlobal)

// Check if all pods are invalid and all have running requests
allPodsInvalid := true
// Check whether all pods are invalid and all have running requests. When either SLO is unset (<= 0), allPodsInvalid starts out false.
allPodsInvalid := (sloCtx.ttftSLO > 0 && sloCtx.avgTPOTSLO > 0)
allPodsHaveRunningRequests := true

for _, pred := range allPreds {
Expand Down