@@ -18,100 +18,50 @@ package handlers
1818
1919import (
2020 "context"
21- "encoding/json"
22- "fmt"
2321 "strconv"
2422 "time"
2523
24+ configPb "github.com/envoyproxy/go-control-plane/envoy/config/core/v3"
2625 extProcPb "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3"
27- "sigs.k8s.io/controller-runtime/pkg/log"
28- "sigs.k8s.io/gateway-api-inference-extension/api/v1alpha2"
29- schedulingtypes "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
26+ "google.golang.org/protobuf/types/known/structpb"
3027 errutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/error"
31- logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging"
3228)
3329
34- // HandleRequestBody always returns the requestContext even in the error case, as the request context is used in error handling.
35- func (s * StreamingServer ) HandleRequestBody (ctx context.Context , reqCtx * RequestContext ) (* RequestContext , error ) {
36- logger := log .FromContext (ctx )
37-
38- var requestBodyBytes []byte
39- requestBodyMap := reqCtx .Request .Body
40- // Resolve target models.
41- model , ok := requestBodyMap ["model" ].(string )
42- if ! ok {
43- return reqCtx , errutil.Error {Code : errutil .BadRequest , Msg : "model not found in request" }
44- }
45- prompt , ok := requestBodyMap ["prompt" ].(string )
46- if ! ok {
47- return reqCtx , errutil.Error {Code : errutil .BadRequest , Msg : "prompt not found in request" }
48- }
49-
50- modelName := model
30+ func (s * StreamingServer ) HandleRequestHeaders (ctx context.Context , reqCtx * RequestContext , req * extProcPb.ProcessingRequest_RequestHeaders ) error {
31+ reqCtx .RequestReceivedTimestamp = time .Now ()
5132
52- // NOTE: The nil checking for the modelObject means that we DO allow passthrough currently.
53- // This might be a security risk in the future where adapters not registered in the InferenceModel
54- // are able to be requested by using their distinct name.
55- modelObj := s .datastore .ModelGet (model )
56- if modelObj == nil {
57- return reqCtx , errutil.Error {Code : errutil .BadConfiguration , Msg : fmt .Sprintf ("error finding a model object in InferenceModel for input %v" , model )}
58- }
59- if len (modelObj .Spec .TargetModels ) > 0 {
60- modelName = RandomWeightedDraw (logger , modelObj , 0 )
61- if modelName == "" {
62- return reqCtx , errutil.Error {Code : errutil .BadConfiguration , Msg : fmt .Sprintf ("error getting target model name for model %v" , modelObj .Name )}
33+ // an EoS in the request headers means this request has no body or trailers.
34+ if req .RequestHeaders .EndOfStream {
35+ // We will route this request to a random pod as this is assumed to just be a GET
36+ // More context: https://github.com/kubernetes-sigs/gateway-api-inference-extension/pull/526
37+ // The above PR will address endpoint admission, but currently any request without a body will be
38+ // routed to a random upstream pod.
39+ pod := s .director .GetRandomPod ()
40+ if pod == nil {
41+ return errutil.Error {Code : errutil .Internal , Msg : "no pods available in datastore" }
6342 }
43+ pool , err := s .datastore .PoolGet ()
44+ if err != nil {
45+ return err
46+ }
47+ reqCtx .TargetEndpoint = pod .Address + ":" + strconv .Itoa (int (pool .Spec .TargetPortNumber ))
48+ reqCtx .RequestSize = 0
49+ reqCtx .reqHeaderResp = s .generateRequestHeaderResponse (reqCtx )
50+ return nil
6451 }
65- llmReq := & schedulingtypes.LLMRequest {
66- Model : model ,
67- ResolvedTargetModel : modelName ,
68- Critical : modelObj .Spec .Criticality != nil && * modelObj .Spec .Criticality == v1alpha2 .Critical ,
69- Prompt : prompt ,
70- Headers : reqCtx .Request .Headers ,
71- }
72- logger .V (logutil .DEBUG ).Info ("LLM request assembled" , "request" , llmReq )
73-
74- var err error
75- // Update target models in the body.
76- if llmReq .Model != llmReq .ResolvedTargetModel {
77- requestBodyMap ["model" ] = llmReq .ResolvedTargetModel
78- }
79-
80- requestBodyBytes , err = json .Marshal (requestBodyMap )
81- if err != nil {
82- logger .V (logutil .DEFAULT ).Error (err , "Error marshaling request body" )
83- return reqCtx , errutil.Error {Code : errutil .Internal , Msg : fmt .Sprintf ("error marshaling request body: %v" , err )}
84- }
85-
86- res , err := s .scheduler .Schedule (ctx , llmReq )
87- if err != nil {
88- return reqCtx , errutil.Error {Code : errutil .InferencePoolResourceExhausted , Msg : fmt .Errorf ("failed to find target pod: %w" , err ).Error ()}
89- }
90- targetPod := res .TargetPod .GetPod ()
9152
92- // Insert target endpoint to instruct Envoy to route requests to the specified target pod.
93- // Attach the port number
94- pool , err := s .datastore .PoolGet ()
95- if err != nil {
96- return reqCtx , err
53+ for _ , header := range req .RequestHeaders .Headers .Headers {
54+ if header .RawValue != nil {
55+ reqCtx .Request .Headers [header .Key ] = string (header .RawValue )
56+ } else {
57+ reqCtx .Request .Headers [header .Key ] = header .Value
58+ }
9759 }
98- endpoint := targetPod .Address + ":" + strconv .Itoa (int (pool .Spec .TargetPortNumber ))
99-
100- logger .V (logutil .DEFAULT ).Info ("Request handled" ,
101- "model" , llmReq .Model , "targetModel" , llmReq .ResolvedTargetModel , "endpoint" , targetPod )
102-
103- reqCtx .Model = llmReq .Model
104- reqCtx .ResolvedTargetModel = llmReq .ResolvedTargetModel
105- reqCtx .RequestSize = len (requestBodyBytes )
106- reqCtx .TargetPod = targetPod .NamespacedName .String ()
107- reqCtx .TargetEndpoint = endpoint
108-
109- s .populateRequestHeaderResponse (reqCtx , endpoint , len (requestBodyBytes ))
60+ return nil
61+ }
11062
111- reqCtx .reqBodyResp = & extProcPb.ProcessingResponse {
112- // The Endpoint Picker supports two approaches to communicating the target endpoint, as a request header
113- // and as an unstructure ext-proc response metadata key/value pair. This enables different integration
114- // options for gateway providers.
63+ func (s * StreamingServer ) generateRequestBodyResponse (requestBodyBytes []byte ) * extProcPb.ProcessingResponse {
64+ return & extProcPb.ProcessingResponse {
11565 Response : & extProcPb.ProcessingResponse_RequestBody {
11666 RequestBody : & extProcPb.BodyResponse {
11767 Response : & extProcPb.CommonResponse {
@@ -127,37 +77,82 @@ func (s *StreamingServer) HandleRequestBody(ctx context.Context, reqCtx *Request
12777 },
12878 },
12979 }
130- return reqCtx , nil
13180}
13281
133- func (s * StreamingServer ) HandleRequestHeaders (ctx context.Context , reqCtx * RequestContext , req * extProcPb.ProcessingRequest_RequestHeaders ) error {
134- reqCtx .RequestReceivedTimestamp = time .Now ()
82+ func (s * StreamingServer ) generateRequestHeaderResponse (reqCtx * RequestContext ) * extProcPb.ProcessingResponse {
83+ // The Endpoint Picker supports two approaches to communicating the target endpoint, as a request header
84+ // and as an unstructure ext-proc response metadata key/value pair. This enables different integration
85+ // options for gateway providers.
86+ return & extProcPb.ProcessingResponse {
87+ Response : & extProcPb.ProcessingResponse_RequestHeaders {
88+ RequestHeaders : & extProcPb.HeadersResponse {
89+ Response : & extProcPb.CommonResponse {
90+ ClearRouteCache : true ,
91+ HeaderMutation : & extProcPb.HeaderMutation {
92+ SetHeaders : s .generateHeaders (reqCtx ),
93+ },
94+ },
95+ },
96+ },
97+ DynamicMetadata : s .generateMetadata (reqCtx .TargetEndpoint ),
98+ }
99+ }
135100
136- // an EoS in the request headers means this request has no body or trailers.
137- if req .RequestHeaders .EndOfStream {
138- // We will route this request to a random pod as this is assumed to just be a GET
139- // More context: https://github.com/kubernetes-sigs/gateway-api-inference-extension/pull/526
140- // The above PR will address endpoint admission, but currently any request without a body will be
141- // routed to a random upstream pod.
142- pod := GetRandomPod (s .datastore )
143- if pod == nil {
144- return errutil.Error {Code : errutil .Internal , Msg : "no pods available in datastore" }
145- }
146- pool , err := s .datastore .PoolGet ()
147- if err != nil {
148- return err
149- }
150- endpoint := pod .Address + ":" + strconv .Itoa (int (pool .Spec .TargetPortNumber ))
151- s .populateRequestHeaderResponse (reqCtx , endpoint , 0 )
152- return nil
101+ func (s * StreamingServer ) generateHeaders (reqCtx * RequestContext ) []* configPb.HeaderValueOption {
102+ // can likely refactor these two bespoke headers to be updated in PostDispatch, to centralize logic.
103+ headers := []* configPb.HeaderValueOption {
104+ {
105+ Header : & configPb.HeaderValue {
106+ Key : s .destinationEndpointHintKey ,
107+ RawValue : []byte (reqCtx .TargetEndpoint ),
108+ },
109+ },
110+ }
111+ if reqCtx .RequestSize > 0 {
112+ // We need to update the content length header if the body is mutated, see Envoy doc:
113+ // https://www.envoyproxy.io/docs/envoy/latest/api-v3/extensions/filters/http/ext_proc/v3/processing_mode.proto
114+ headers = append (headers , & configPb.HeaderValueOption {
115+ Header : & configPb.HeaderValue {
116+ Key : "Content-Length" ,
117+ RawValue : []byte (strconv .Itoa (reqCtx .RequestSize )),
118+ },
119+ })
153120 }
154121
155- for _ , header := range req .RequestHeaders .Headers .Headers {
156- if header .RawValue != nil {
157- reqCtx .Request .Headers [header .Key ] = string (header .RawValue )
158- } else {
159- reqCtx .Request .Headers [header .Key ] = header .Value
122+ // include all headers
123+ for key , value := range reqCtx .Request .Headers {
124+ headers = append (headers , & configPb.HeaderValueOption {
125+ Header : & configPb.HeaderValue {
126+ Key : key ,
127+ RawValue : []byte (value ),
128+ },
129+ })
130+ }
131+ return headers
132+ }
133+
134+ func (s * StreamingServer ) generateMetadata (endpoint string ) * structpb.Struct {
135+ targetEndpointValue := & structpb.Struct {
136+ Fields : map [string ]* structpb.Value {
137+ s .destinationEndpointHintKey : {
138+ Kind : & structpb.Value_StringValue {
139+ StringValue : endpoint ,
140+ },
141+ },
142+ },
143+ }
144+ dynamicMetadata := targetEndpointValue
145+ if s .destinationEndpointHintMetadataNamespace != "" {
146+ // If a namespace is defined, wrap the selected endpoint with that.
147+ dynamicMetadata = & structpb.Struct {
148+ Fields : map [string ]* structpb.Value {
149+ s .destinationEndpointHintMetadataNamespace : {
150+ Kind : & structpb.Value_StructValue {
151+ StructValue : targetEndpointValue ,
152+ },
153+ },
154+ },
160155 }
161156 }
162- return nil
157+ return dynamicMetadata
163158}
0 commit comments