diff --git a/pkg/epp/requestcontrol/director.go b/pkg/epp/requestcontrol/director.go index cafcc80b0..fe70b4378 100644 --- a/pkg/epp/requestcontrol/director.go +++ b/pkg/epp/requestcontrol/director.go @@ -31,6 +31,7 @@ import ( schedulingtypes "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" errutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/error" logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" + requtil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/request" ) type Scheduler interface { @@ -82,6 +83,7 @@ func (d *Director) HandleRequest(ctx context.Context, reqCtx *handlers.RequestCo } llmReq := &schedulingtypes.LLMRequest{ + RequestId: reqCtx.Request.Headers[requtil.RequestIdHeaderKey], Model: reqCtx.Model, ResolvedTargetModel: reqCtx.ResolvedTargetModel, Critical: modelObj.Spec.Criticality != nil && *modelObj.Spec.Criticality == v1alpha2.Critical, diff --git a/pkg/epp/scheduling/plugins/filter/filter_test.go b/pkg/epp/scheduling/plugins/filter/filter_test.go index 2354c3ef5..dd90907ad 100644 --- a/pkg/epp/scheduling/plugins/filter/filter_test.go +++ b/pkg/epp/scheduling/plugins/filter/filter_test.go @@ -21,6 +21,7 @@ import ( "testing" "github.com/google/go-cmp/cmp" + "github.com/google/uuid" k8stypes "k8s.io/apimachinery/pkg/types" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend" backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics" @@ -52,7 +53,7 @@ func TestFilter(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - ctx := types.NewSchedulingContext(context.Background(), test.req, test.input) + ctx := types.NewSchedulingContext(context.Background(), test.req, nil, test.input) got := test.filter.Filter(ctx, test.input) if diff := cmp.Diff(test.output, got); diff != "" { @@ -187,7 +188,7 @@ func TestFilterFunc(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - ctx := 
types.NewSchedulingContext(context.Background(), test.req, test.input) + ctx := types.NewSchedulingContext(context.Background(), test.req, nil, test.input) got := test.f(ctx, test.input) if diff := cmp.Diff(test.output, got); diff != "" { @@ -221,6 +222,7 @@ func TestLoRASoftAffinityDistribution(t *testing.T) { // Create a test request and pods req := &types.LLMRequest{ + RequestId: uuid.NewString(), Model: testAffinityModel, ResolvedTargetModel: testAffinityModel, } @@ -244,7 +246,7 @@ func TestLoRASoftAffinityDistribution(t *testing.T) { }, }, } - ctx := types.NewSchedulingContext(context.Background(), req, pods) + ctx := types.NewSchedulingContext(context.Background(), req, nil, pods) // Run the filter function multiple times and count the results affinityCount := 0 diff --git a/pkg/epp/scheduling/plugins/scorer/kvcache_test.go b/pkg/epp/scheduling/plugins/scorer/kvcache_test.go index 257a58c17..68be8a213 100644 --- a/pkg/epp/scheduling/plugins/scorer/kvcache_test.go +++ b/pkg/epp/scheduling/plugins/scorer/kvcache_test.go @@ -82,7 +82,7 @@ func TestKvCacheScorer(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - ctx := types.NewSchedulingContext(context.Background(), &types.LLMRequest{}, tt.pods) + ctx := types.NewSchedulingContext(context.Background(), &types.LLMRequest{}, nil, tt.pods) scorer := &KVCacheScorer{} scores := scorer.Score(ctx, tt.pods) diff --git a/pkg/epp/scheduling/plugins/scorer/queue_test.go b/pkg/epp/scheduling/plugins/scorer/queue_test.go index 907681b25..d60eab66a 100644 --- a/pkg/epp/scheduling/plugins/scorer/queue_test.go +++ b/pkg/epp/scheduling/plugins/scorer/queue_test.go @@ -73,7 +73,7 @@ func TestQueueScorer(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - ctx := types.NewSchedulingContext(context.Background(), &types.LLMRequest{}, tt.pods) + ctx := types.NewSchedulingContext(context.Background(), &types.LLMRequest{}, nil, tt.pods) scores := scorer.Score(ctx, tt.pods) for 
i, pod := range tt.pods { diff --git a/pkg/epp/scheduling/scheduler.go b/pkg/epp/scheduling/scheduler.go index 9215489fe..4c7a6fbe2 100644 --- a/pkg/epp/scheduling/scheduler.go +++ b/pkg/epp/scheduling/scheduler.go @@ -108,7 +108,7 @@ func (s *Scheduler) Schedule(ctx context.Context, req *types.LLMRequest) (*types // Snapshot pod metrics from the datastore to: // 1. Reduce concurrent access to the datastore. // 2. Ensure consistent data during the scheduling operation of a request. - sCtx := types.NewSchedulingContext(ctx, req, types.ToSchedulerPodMetrics(s.datastore.PodGetAll())) + sCtx := types.NewSchedulingContext(ctx, req, nil, types.ToSchedulerPodMetrics(s.datastore.PodGetAll())) loggerDebug.Info(fmt.Sprintf("Scheduling a request, Metrics: %+v", sCtx.PodsSnapshot)) s.runPreSchedulePlugins(sCtx) diff --git a/pkg/epp/scheduling/scheduler_test.go b/pkg/epp/scheduling/scheduler_test.go index b44c7ac2e..b9fe0cd51 100644 --- a/pkg/epp/scheduling/scheduler_test.go +++ b/pkg/epp/scheduling/scheduler_test.go @@ -21,6 +21,7 @@ import ( "testing" "github.com/google/go-cmp/cmp" + "github.com/google/uuid" k8stypes "k8s.io/apimachinery/pkg/types" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend" backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics" // Import config for thresholds @@ -40,6 +41,7 @@ func TestSchedule(t *testing.T) { { name: "no pods in datastore", req: &types.LLMRequest{ + RequestId: uuid.NewString(), Model: "any-model", ResolvedTargetModel: "any-model", Critical: true, @@ -50,6 +52,7 @@ func TestSchedule(t *testing.T) { { name: "critical request", req: &types.LLMRequest{ + RequestId: uuid.NewString(), Model: "critical", ResolvedTargetModel: "critical", Critical: true, @@ -114,6 +117,7 @@ func TestSchedule(t *testing.T) { { name: "sheddable request, accepted", req: &types.LLMRequest{ + RequestId: uuid.NewString(), Model: "sheddable", ResolvedTargetModel: "sheddable", Critical: false, @@ -177,6 +181,7 @@ func 
TestSchedule(t *testing.T) { { name: "sheddable request, dropped", req: &types.LLMRequest{ + RequestId: uuid.NewString(), Model: "sheddable", ResolvedTargetModel: "sheddable", Critical: false, @@ -356,7 +361,10 @@ func TestSchedulePlugins(t *testing.T) { // Initialize the scheduler scheduler := NewSchedulerWithConfig(&fakeDataStore{pods: test.input}, &test.config) - req := &types.LLMRequest{Model: "test-model"} + req := &types.LLMRequest{ + RequestId: uuid.NewString(), + Model: "test-model", + } got, err := scheduler.Schedule(context.Background(), req) // Validate error state diff --git a/pkg/epp/scheduling/types/types.go b/pkg/epp/scheduling/types/types.go index 795ef65d2..5d965fcbe 100644 --- a/pkg/epp/scheduling/types/types.go +++ b/pkg/epp/scheduling/types/types.go @@ -28,6 +28,8 @@ import ( // LLMRequest is a structured representation of the fields we parse out of the LLMRequest body. type LLMRequest struct { + // RequestId is the Envoy generated Id for the request being processed + RequestId string // Model is the name of the model that the user specified in the request body. Model string // ResolvedTargetModel is the final target model after traffic split. @@ -45,6 +47,20 @@ func (r *LLMRequest) String() string { r.Model, r.ResolvedTargetModel, r.Critical, len(r.Prompt), r.Headers) } +// LLMResponse contains information from the response received to be passed to plugins +type LLMResponse struct { + // RequestId is the Envoy generated Id for the request being processed + RequestId string + // Headers is a map of the response headers. 
Nil during body processing + Headers map[string]string + // Body is the body of the response, or empty during header processing + Body string + // IsStreaming indicates whether the response is being streamed by the model + IsStreaming bool + // EndOfStream when true indicates that this invocation contains the last chunk of the response + EndOfStream bool +} + type Pod interface { GetPod() *backend.Pod GetMetrics() *backendmetrics.Metrics @@ -61,6 +77,7 @@ type SchedulingContext struct { context.Context Logger logr.Logger Req *LLMRequest + Resp *LLMResponse PodsSnapshot []Pod } @@ -84,12 +101,13 @@ type PodMetrics struct { *backendmetrics.Metrics } -func NewSchedulingContext(ctx context.Context, req *LLMRequest, pods []Pod) *SchedulingContext { +func NewSchedulingContext(ctx context.Context, req *LLMRequest, resp *LLMResponse, pods []Pod) *SchedulingContext { logger := log.FromContext(ctx).WithValues("request", req) return &SchedulingContext{ Context: ctx, Logger: logger, Req: req, + Resp: resp, PodsSnapshot: pods, } }