@@ -18,16 +18,20 @@ package filter
1818
1919import (
2020 "context"
21+ "encoding/json"
2122 "testing"
2223
2324 "github.com/google/go-cmp/cmp"
25+ "github.com/google/go-cmp/cmp/cmpopts"
2426 "github.com/google/uuid"
2527 k8stypes "k8s.io/apimachinery/pkg/types"
2628 "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend"
2729 backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics"
2830 "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/config"
2931 "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework"
32+ "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework/plugins/scorer"
3033 "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
34+ "sigs.k8s.io/gateway-api-inference-extension/test/utils"
3135)
3236
3337// compile-time type assertion
@@ -251,3 +255,281 @@ func TestLoRASoftAffinityDistribution(t *testing.T) {
251255 actualAvailablePercent , availableLowerBound , availableUpperBound )
252256 }
253257}
258+
259+ // TestDecisionTreeFilterFactory tests that the DecisionTreeFilterFactory function
260+ // properly instantiates DecisionTreeFilter instances
261+ func TestDecisionTreeFilterFactory (t * testing.T ) {
262+
263+ leastKvCacheFilter := NewLeastKVCacheFilter ()
264+ leastQueueFilter := NewLeastQueueFilter ()
265+ loraAffinityFilter := NewLoraAffinityFilter (config .Conf .LoraAffinityThreshold )
266+ lowQueueFilter := NewLowQueueFilter (config .Conf .QueueingThresholdLoRA )
267+
268+ kvCacheScorer := scorer .NewKVCacheScorer ()
269+
270+ testHandle := utils .NewTestHandle ()
271+
272+ testHandle .Plugins ().AddPlugin ("leastKvCache" , leastKvCacheFilter )
273+ testHandle .Plugins ().AddPlugin ("leastQueue" , leastQueueFilter )
274+ testHandle .Plugins ().AddPlugin ("loraAffinity" , loraAffinityFilter )
275+ testHandle .Plugins ().AddPlugin ("lowQueue" , lowQueueFilter )
276+
277+ testHandle .Plugins ().AddPlugin ("kvCacheScorer" , kvCacheScorer )
278+
279+ tests := []struct {
280+ name string
281+ parameters string
282+ want * DecisionTreeFilter
283+ wantErr bool
284+ }{
285+ {
286+ name : "success" ,
287+ parameters : decisionTreeParametersSuccess ,
288+ want : & DecisionTreeFilter {
289+ Current : lowQueueFilter ,
290+ NextOnSuccess : & DecisionTreeFilter {
291+ Current : loraAffinityFilter ,
292+ NextOnSuccessOrFailure : & DecisionTreeFilter {
293+ Current : leastQueueFilter ,
294+ NextOnSuccessOrFailure : & DecisionTreeFilter {
295+ Current : leastKvCacheFilter ,
296+ },
297+ },
298+ },
299+ NextOnFailure : & DecisionTreeFilter {
300+ Current : leastQueueFilter ,
301+ NextOnSuccessOrFailure : & DecisionTreeFilter {
302+ Current : loraAffinityFilter ,
303+ NextOnSuccessOrFailure : & DecisionTreeFilter {
304+ Current : leastKvCacheFilter ,
305+ },
306+ },
307+ },
308+ },
309+ wantErr : false ,
310+ },
311+ {
312+ name : "bothError" ,
313+ parameters : decisionTreeParametersErrorBoth ,
314+ want : nil ,
315+ wantErr : true ,
316+ },
317+ {
318+ name : "noneError" ,
319+ parameters : decisionTreeParametersErrorNone ,
320+ want : nil ,
321+ wantErr : true ,
322+ },
323+ {
324+ name : "badPlugin" ,
325+ parameters : decisionTreeParametersErrorBadPlugin ,
326+ want : nil ,
327+ wantErr : true ,
328+ },
329+ {
330+ name : "notFilter" ,
331+ parameters : decisionTreeParametersErrorNotFilter ,
332+ want : nil ,
333+ wantErr : true ,
334+ },
335+ {
336+ name : "noCurrent" ,
337+ parameters : decisionTreeParametersErrorNoCurrent ,
338+ want : nil ,
339+ wantErr : true ,
340+ },
341+ {
342+ name : "badNextOnSuccess" ,
343+ parameters : decisionTreeParametersErrorBadNextOnSuccess ,
344+ want : nil ,
345+ wantErr : true ,
346+ },
347+ {
348+ name : "badNextOnFailure" ,
349+ parameters : decisionTreeParametersErrorBadNextOnFailure ,
350+ want : nil ,
351+ wantErr : true ,
352+ },
353+ {
354+ name : "badNextOnSuccessOrFailure" ,
355+ parameters : decisionTreeParametersErrorBadNextOnSuccessOrFailure ,
356+ want : nil ,
357+ wantErr : true ,
358+ },
359+ }
360+
361+ cmpOptions := cmpopts .IgnoreUnexported (LeastKVCacheFilter {}, LeastQueueFilter {},
362+ LoraAffinityFilter {}, LowQueueFilter {}, scorer.KVCacheScorer {})
363+
364+ for _ , test := range tests {
365+ rawParameters := struct {
366+ Parameters json.RawMessage `json:"parameters"`
367+ }{}
368+ err := json .Unmarshal ([]byte (test .parameters ), & rawParameters )
369+ if err != nil {
370+ if test .wantErr {
371+ continue
372+ } else {
373+ t .Fatal ("failed to parse JSON of test " + test .name )
374+ }
375+ }
376+ got , err := DecisionTreeFilterFactory ("testing" , rawParameters .Parameters , testHandle )
377+ if err != nil {
378+ if test .wantErr {
379+ continue
380+ }
381+ t .Fatalf ("failed to instantiate DecisionTreeFilter. error: %s\n " , err )
382+ }
383+ if test .wantErr {
384+ t .Fatalf ("test %s did not return the expected error" , test .name )
385+ }
386+ if diff := cmp .Diff (test .want , got , cmpOptions ); diff != "" {
387+ t .Fatalf ("In test %s DecisionTreeFactory returned unexpected response, diff(-want, +got): %v" , test .name , diff )
388+ }
389+ }
390+ }
391+
392+ const decisionTreeParametersSuccess = `
393+ {
394+ "parameters": {
395+ "current": {
396+ "pluginRef": "lowQueue"
397+ },
398+ "nextOnSuccess": {
399+ "decisionTree": {
400+ "current": {
401+ "pluginRef": "loraAffinity"
402+ },
403+ "nextOnSuccessOrFailure": {
404+ "decisionTree": {
405+ "current": {
406+ "pluginRef": "leastQueue"
407+ },
408+ "nextOnSuccessOrFailure": {
409+ "decisionTree": {
410+ "current": {
411+ "pluginRef": "leastKvCache"
412+ }
413+ }
414+ }
415+ }
416+ }
417+ }
418+ },
419+ "nextOnFailure": {
420+ "decisionTree": {
421+ "current": {
422+ "pluginRef": "leastQueue"
423+ },
424+ "nextOnSuccessOrFailure": {
425+ "decisionTree": {
426+ "current": {
427+ "pluginRef": "loraAffinity"
428+ },
429+ "nextOnSuccessOrFailure": {
430+ "decisionTree": {
431+ "current": {
432+ "pluginRef": "leastKvCache"
433+ }
434+ }
435+ }
436+ }
437+ }
438+ }
439+ }
440+ }
441+ }
442+ `
443+
444+ const decisionTreeParametersErrorBoth = `
445+ {
446+ "parameters": {
447+ "current": {
448+ "pluginRef": "lowQueue",
449+ "decisionTree": {
450+ "current": {
451+ "pluginRef": "leastKvCache"
452+ }
453+ }
454+ }
455+ }
456+ }
457+ `
458+
459+ const decisionTreeParametersErrorNone = `
460+ {
461+ "parameters": {
462+ "current": {
463+ }
464+ }
465+ }
466+ `
467+
468+ const decisionTreeParametersErrorBadPlugin = `
469+ {
470+ "parameters": {
471+ "current": {
472+ "pluginRef": "plover"
473+ }
474+ }
475+ }
476+ `
477+
478+ const decisionTreeParametersErrorNotFilter = `
479+ {
480+ "parameters": {
481+ "current": {
482+ "pluginRef": "kvCacheScorer"
483+ }
484+ }
485+ }
486+ `
487+
488+ const decisionTreeParametersErrorNoCurrent = `
489+ {
490+ "parameters": {
491+ "NextOnSuccess": {
492+ "pluginRef": "lowQueue"
493+ }
494+ }
495+ }
496+ `
497+
498+ const decisionTreeParametersErrorBadNextOnSuccess = `
499+ {
500+ "parameters": {
501+ "current": {
502+ "pluginRef": "lowQueue"
503+ },
504+ "NextOnSuccess": {
505+ "pluginRef": "kvCacheScorer"
506+ }
507+ }
508+ }
509+ `
510+
511+ const decisionTreeParametersErrorBadNextOnFailure = `
512+ {
513+ "parameters": {
514+ "current": {
515+ "pluginRef": "lowQueue"
516+ },
517+ "NextOnFailure": {
518+ "pluginRef": "kvCacheScorer"
519+ }
520+ }
521+ }
522+ `
523+
524+ const decisionTreeParametersErrorBadNextOnSuccessOrFailure = `
525+ {
526+ "parameters": {
527+ "current": {
528+ "pluginRef": "lowQueue"
529+ },
530+ "NextOnSuccessOrFailure": {
531+ "pluginRef": "kvCacheScorer"
532+ }
533+ }
534+ }
535+ `
0 commit comments