2 changes: 2 additions & 0 deletions pkg/epp/requestcontrol/director.go
@@ -31,6 +31,7 @@ import (
 	schedulingtypes "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
 	errutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/error"
 	logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging"
+	requtil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/request"
 )

 type Scheduler interface {
@@ -82,6 +83,7 @@ func (d *Director) HandleRequest(ctx context.Context, reqCtx *handlers.RequestCo
 	}

 	llmReq := &schedulingtypes.LLMRequest{
+		RequestId:           reqCtx.Request.Headers[requtil.RequestIdHeaderKey],
 		Model:               reqCtx.Model,
 		ResolvedTargetModel: reqCtx.ResolvedTargetModel,
 		Critical:            modelObj.Spec.Criticality != nil && *modelObj.Spec.Criticality == v1alpha2.Critical,
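The `requtil.RequestIdHeaderKey` constant is defined outside this diff. A minimal sketch of what `pkg/epp/util/request` plausibly exposes, assuming Envoy's conventional `x-request-id` header; the actual constant value is not shown in this PR and may differ:

```go
// Package request — sketch only; the real pkg/epp/util/request may differ.
package request

// RequestIdHeaderKey is assumed here to be Envoy's conventional request-ID
// header name. Envoy generates this ID when the client does not supply one.
const RequestIdHeaderKey = "x-request-id"
```

With this key, the director simply copies Envoy's generated ID off the incoming headers into the `LLMRequest`, so the scheduler and plugins can correlate their work with a single request.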
8 changes: 5 additions & 3 deletions pkg/epp/scheduling/plugins/filter/filter_test.go
@@ -21,6 +21,7 @@ import (
"testing"

"github.com/google/go-cmp/cmp"
"github.com/google/uuid"
k8stypes "k8s.io/apimachinery/pkg/types"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend"
backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics"
@@ -52,7 +53,7 @@ func TestFilter(t *testing.T) {

 	for _, test := range tests {
 		t.Run(test.name, func(t *testing.T) {
-			ctx := types.NewSchedulingContext(context.Background(), test.req, test.input)
+			ctx := types.NewSchedulingContext(context.Background(), test.req, nil, test.input)
 			got := test.filter.Filter(ctx, test.input)

 			if diff := cmp.Diff(test.output, got); diff != "" {
@@ -187,7 +188,7 @@ func TestFilterFunc(t *testing.T) {

 	for _, test := range tests {
 		t.Run(test.name, func(t *testing.T) {
-			ctx := types.NewSchedulingContext(context.Background(), test.req, test.input)
+			ctx := types.NewSchedulingContext(context.Background(), test.req, nil, test.input)
 			got := test.f(ctx, test.input)

 			if diff := cmp.Diff(test.output, got); diff != "" {
@@ -221,6 +222,7 @@ func TestLoRASoftAffinityDistribution(t *testing.T) {

 	// Create a test request and pods
 	req := &types.LLMRequest{
+		RequestId:           uuid.NewString(),
 		Model:               testAffinityModel,
 		ResolvedTargetModel: testAffinityModel,
 	}
@@ -244,7 +246,7 @@
 			},
 		},
 	}
-	ctx := types.NewSchedulingContext(context.Background(), req, pods)
+	ctx := types.NewSchedulingContext(context.Background(), req, nil, pods)

 	// Run the filter function multiple times and count the results
 	affinityCount := 0
2 changes: 1 addition & 1 deletion pkg/epp/scheduling/plugins/scorer/kvcache_test.go
@@ -82,7 +82,7 @@ func TestKvCacheScorer(t *testing.T) {

 	for _, tt := range tests {
 		t.Run(tt.name, func(t *testing.T) {
-			ctx := types.NewSchedulingContext(context.Background(), &types.LLMRequest{}, tt.pods)
+			ctx := types.NewSchedulingContext(context.Background(), &types.LLMRequest{}, nil, tt.pods)
 			scorer := &KVCacheScorer{}
 			scores := scorer.Score(ctx, tt.pods)

2 changes: 1 addition & 1 deletion pkg/epp/scheduling/plugins/scorer/queue_test.go
@@ -73,7 +73,7 @@ func TestQueueScorer(t *testing.T) {

 	for _, tt := range tests {
 		t.Run(tt.name, func(t *testing.T) {
-			ctx := types.NewSchedulingContext(context.Background(), &types.LLMRequest{}, tt.pods)
+			ctx := types.NewSchedulingContext(context.Background(), &types.LLMRequest{}, nil, tt.pods)
 			scores := scorer.Score(ctx, tt.pods)

 			for i, pod := range tt.pods {
2 changes: 1 addition & 1 deletion pkg/epp/scheduling/scheduler.go
@@ -108,7 +108,7 @@ func (s *Scheduler) Schedule(ctx context.Context, req *types.LLMRequest) (*types
 	// Snapshot pod metrics from the datastore to:
 	// 1. Reduce concurrent access to the datastore.
 	// 2. Ensure consistent data during the scheduling operation of a request.
-	sCtx := types.NewSchedulingContext(ctx, req, types.ToSchedulerPodMetrics(s.datastore.PodGetAll()))
+	sCtx := types.NewSchedulingContext(ctx, req, nil, types.ToSchedulerPodMetrics(s.datastore.PodGetAll()))
 	loggerDebug.Info(fmt.Sprintf("Scheduling a request, Metrics: %+v", sCtx.PodsSnapshot))

 	s.runPreSchedulePlugins(sCtx)
10 changes: 9 additions & 1 deletion pkg/epp/scheduling/scheduler_test.go
@@ -21,6 +21,7 @@ import (
"testing"

"github.com/google/go-cmp/cmp"
"github.com/google/uuid"
k8stypes "k8s.io/apimachinery/pkg/types"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend"
backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics" // Import config for thresholds
@@ -40,6 +41,7 @@ func TestSchedule(t *testing.T) {
 		{
 			name: "no pods in datastore",
 			req: &types.LLMRequest{
+				RequestId:           uuid.NewString(),
 				Model:               "any-model",
 				ResolvedTargetModel: "any-model",
 				Critical:            true,
@@ -50,6 +52,7 @@
 		{
 			name: "critical request",
 			req: &types.LLMRequest{
+				RequestId:           uuid.NewString(),
 				Model:               "critical",
 				ResolvedTargetModel: "critical",
 				Critical:            true,
@@ -114,6 +117,7 @@
 		{
 			name: "sheddable request, accepted",
 			req: &types.LLMRequest{
+				RequestId:           uuid.NewString(),
 				Model:               "sheddable",
 				ResolvedTargetModel: "sheddable",
 				Critical:            false,
@@ -177,6 +181,7 @@
 		{
 			name: "sheddable request, dropped",
 			req: &types.LLMRequest{
+				RequestId:           uuid.NewString(),
 				Model:               "sheddable",
 				ResolvedTargetModel: "sheddable",
 				Critical:            false,
@@ -356,7 +361,10 @@ func TestSchedulePlugins(t *testing.T) {
 			// Initialize the scheduler
 			scheduler := NewSchedulerWithConfig(&fakeDataStore{pods: test.input}, &test.config)

-			req := &types.LLMRequest{Model: "test-model"}
+			req := &types.LLMRequest{
+				RequestId: uuid.NewString(),
+				Model:     "test-model",
+			}
 			got, err := scheduler.Schedule(context.Background(), req)

 			// Validate error state
20 changes: 19 additions & 1 deletion pkg/epp/scheduling/types/types.go
@@ -28,6 +28,8 @@ import (

 // LLMRequest is a structured representation of the fields we parse out of the LLMRequest body.
 type LLMRequest struct {
+	// RequestId is the Envoy generated Id for the request being processed
+	RequestId string
 	// Model is the name of the model that the user specified in the request body.
 	Model string
 	// ResolvedTargetModel is the final target model after traffic split.
@@ -45,6 +47,20 @@
 		r.Model, r.ResolvedTargetModel, r.Critical, len(r.Prompt), r.Headers)
 }

+// LLMResponse contains information from the response received to be passed to plugins
+type LLMResponse struct {
+	// RequestId is the Envoy generated Id for the request being processed
+	RequestId string

Collaborator (review comment on RequestId): I'm on the fence about whether there should be a dedicated field for this; we could put this in the Headers map, or check to see if it's already there, and then populate.

+	// Headers is a map of the response headers. Nil during body processing
+	Headers map[string]string
+	// Body is the body of the response or nil during header processing
+	Body string
+	// IsStreaming indicates whether or not the response is being streamed by the model
+	IsStreaming bool
+	// EndOfStream when true indicates that this invocation contains the last chunk of the response
+	EndOfStream bool
+}
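To make the header-phase vs. body-phase semantics above concrete, here is a minimal sketch of how a response-aware plugin might consume this struct. The `ResponseLogger` type and its `PostResponse` hook are hypothetical illustrations, not interfaces defined in this PR:

```go
package plugins // hypothetical package, for illustration only

import (
	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
)

// ResponseLogger is a hypothetical plugin showing how the LLMResponse fields
// partition the response lifecycle; the real plugin hooks may differ.
type ResponseLogger struct{}

func (p *ResponseLogger) PostResponse(ctx *types.SchedulingContext) {
	resp := ctx.Resp
	if resp == nil {
		return // no response information attached to this context
	}
	switch {
	case resp.Headers != nil:
		// Header processing: Headers is populated and Body is still empty.
		ctx.Logger.Info("response headers received", "requestId", resp.RequestId)
	case resp.IsStreaming && !resp.EndOfStream:
		// One chunk of a streamed body; more invocations will follow.
		ctx.Logger.Info("response chunk received", "requestId", resp.RequestId)
	default:
		// EndOfStream is true: this invocation carries the last chunk.
		ctx.Logger.Info("response complete", "requestId", resp.RequestId)
	}
}
```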

 type Pod interface {
 	GetPod() *backend.Pod
 	GetMetrics() *backendmetrics.Metrics
@@ -61,6 +77,7 @@ type SchedulingContext struct {
 	context.Context
 	Logger       logr.Logger
 	Req          *LLMRequest
+	Resp         *LLMResponse
 	PodsSnapshot []Pod
 }

@@ -84,12 +101,13 @@ type PodMetrics struct {
 	*backendmetrics.Metrics
 }

-func NewSchedulingContext(ctx context.Context, req *LLMRequest, pods []Pod) *SchedulingContext {
+func NewSchedulingContext(ctx context.Context, req *LLMRequest, resp *LLMResponse, pods []Pod) *SchedulingContext {
Collaborator (review comment on the new resp parameter): This param is nil in all the places it's called. It's also an exported field; perhaps the New func should just initialize the LLMResponse struct, and any consumer can populate or replace the field as needed.

(Completely out of scope of this PR, but I would like to move away from LLM naming; a lot of this code can be generalized to Inference, LLMs are just the current focus.)

 	logger := log.FromContext(ctx).WithValues("request", req)
 	return &SchedulingContext{
 		Context:      ctx,
 		Logger:       logger,
 		Req:          req,
+		Resp:         resp,
 		PodsSnapshot: pods,
 	}
 }
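A minimal sketch of the variant suggested in the comment above, where the constructor keeps its old signature and always initializes an empty LLMResponse so callers never need to thread a nil through. This is illustrative only, not what the PR implements:

```go
// Sketch of the reviewer's suggestion: have the constructor initialize Resp,
// letting consumers populate or replace the field as needed.
func NewSchedulingContext(ctx context.Context, req *LLMRequest, pods []Pod) *SchedulingContext {
	logger := log.FromContext(ctx).WithValues("request", req)
	return &SchedulingContext{
		Context:      ctx,
		Logger:       logger,
		Req:          req,
		Resp:         &LLMResponse{RequestId: req.RequestId}, // seeded with the request's ID
		PodsSnapshot: pods,
	}
}
```

Every call site touched by this PR passes nil for the new parameter, so this variant would also remove the extra argument from all six updated callers.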