@@ -18,16 +18,20 @@ package filter
1818
1919import (
2020 "context"
21+ "encoding/json"
2122 "testing"
2223
2324 "github.com/google/go-cmp/cmp"
25+ "github.com/google/go-cmp/cmp/cmpopts"
2426 "github.com/google/uuid"
2527 k8stypes "k8s.io/apimachinery/pkg/types"
2628 "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend"
2729 backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics"
2830 "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/config"
2931 "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework"
32+ "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework/plugins/scorer"
3033 "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
34+ "sigs.k8s.io/gateway-api-inference-extension/test/utils"
3135)
3236
3337// compile-time type assertion
@@ -390,3 +394,281 @@ func TestSubsettingFilter(t *testing.T) {
390394 })
391395 }
392396}
397+
398+ // TestDecisionTreeFilterFactory tests that the DecisionTreeFilterFactory function
399+ // properly instantiates DecisionTreeFilter instances
400+ func TestDecisionTreeFilterFactory (t * testing.T ) {
401+
402+ leastKvCacheFilter := NewLeastKVCacheFilter ()
403+ leastQueueFilter := NewLeastQueueFilter ()
404+ loraAffinityFilter := NewLoraAffinityFilter (config .Conf .LoraAffinityThreshold )
405+ lowQueueFilter := NewLowQueueFilter (config .Conf .QueueingThresholdLoRA )
406+
407+ kvCacheScorer := scorer .NewKVCacheScorer ()
408+
409+ testHandle := utils .NewTestHandle ()
410+
411+ testHandle .Plugins ().AddPlugin ("leastKvCache" , leastKvCacheFilter )
412+ testHandle .Plugins ().AddPlugin ("leastQueue" , leastQueueFilter )
413+ testHandle .Plugins ().AddPlugin ("loraAffinity" , loraAffinityFilter )
414+ testHandle .Plugins ().AddPlugin ("lowQueue" , lowQueueFilter )
415+
416+ testHandle .Plugins ().AddPlugin ("kvCacheScorer" , kvCacheScorer )
417+
418+ tests := []struct {
419+ name string
420+ parameters string
421+ want * DecisionTreeFilter
422+ wantErr bool
423+ }{
424+ {
425+ name : "success" ,
426+ parameters : decisionTreeParametersSuccess ,
427+ want : & DecisionTreeFilter {
428+ Current : lowQueueFilter ,
429+ NextOnSuccess : & DecisionTreeFilter {
430+ Current : loraAffinityFilter ,
431+ NextOnSuccessOrFailure : & DecisionTreeFilter {
432+ Current : leastQueueFilter ,
433+ NextOnSuccessOrFailure : & DecisionTreeFilter {
434+ Current : leastKvCacheFilter ,
435+ },
436+ },
437+ },
438+ NextOnFailure : & DecisionTreeFilter {
439+ Current : leastQueueFilter ,
440+ NextOnSuccessOrFailure : & DecisionTreeFilter {
441+ Current : loraAffinityFilter ,
442+ NextOnSuccessOrFailure : & DecisionTreeFilter {
443+ Current : leastKvCacheFilter ,
444+ },
445+ },
446+ },
447+ },
448+ wantErr : false ,
449+ },
450+ {
451+ name : "bothError" ,
452+ parameters : decisionTreeParametersErrorBoth ,
453+ want : nil ,
454+ wantErr : true ,
455+ },
456+ {
457+ name : "noneError" ,
458+ parameters : decisionTreeParametersErrorNone ,
459+ want : nil ,
460+ wantErr : true ,
461+ },
462+ {
463+ name : "badPlugin" ,
464+ parameters : decisionTreeParametersErrorBadPlugin ,
465+ want : nil ,
466+ wantErr : true ,
467+ },
468+ {
469+ name : "notFilter" ,
470+ parameters : decisionTreeParametersErrorNotFilter ,
471+ want : nil ,
472+ wantErr : true ,
473+ },
474+ {
475+ name : "noCurrent" ,
476+ parameters : decisionTreeParametersErrorNoCurrent ,
477+ want : nil ,
478+ wantErr : true ,
479+ },
480+ {
481+ name : "badNextOnSuccess" ,
482+ parameters : decisionTreeParametersErrorBadNextOnSuccess ,
483+ want : nil ,
484+ wantErr : true ,
485+ },
486+ {
487+ name : "badNextOnFailure" ,
488+ parameters : decisionTreeParametersErrorBadNextOnFailure ,
489+ want : nil ,
490+ wantErr : true ,
491+ },
492+ {
493+ name : "badNextOnSuccessOrFailure" ,
494+ parameters : decisionTreeParametersErrorBadNextOnSuccessOrFailure ,
495+ want : nil ,
496+ wantErr : true ,
497+ },
498+ }
499+
500+ cmpOptions := cmpopts .IgnoreUnexported (LeastKVCacheFilter {}, LeastQueueFilter {},
501+ LoraAffinityFilter {}, LowQueueFilter {}, scorer.KVCacheScorer {})
502+
503+ for _ , test := range tests {
504+ rawParameters := struct {
505+ Parameters json.RawMessage `json:"parameters"`
506+ }{}
507+ err := json .Unmarshal ([]byte (test .parameters ), & rawParameters )
508+ if err != nil {
509+ if test .wantErr {
510+ continue
511+ } else {
512+ t .Fatal ("failed to parse JSON of test " + test .name )
513+ }
514+ }
515+ got , err := DecisionTreeFilterFactory ("testing" , rawParameters .Parameters , testHandle )
516+ if err != nil {
517+ if test .wantErr {
518+ continue
519+ }
520+ t .Fatalf ("failed to instantiate DecisionTreeFilter. error: %s\n " , err )
521+ }
522+ if test .wantErr {
523+ t .Fatalf ("test %s did not return the expected error" , test .name )
524+ }
525+ if diff := cmp .Diff (test .want , got , cmpOptions ); diff != "" {
526+ t .Fatalf ("In test %s DecisionTreeFactory returned unexpected response, diff(-want, +got): %v" , test .name , diff )
527+ }
528+ }
529+ }
530+
531+ const decisionTreeParametersSuccess = `
532+ {
533+ "parameters": {
534+ "current": {
535+ "pluginRef": "lowQueue"
536+ },
537+ "nextOnSuccess": {
538+ "decisionTree": {
539+ "current": {
540+ "pluginRef": "loraAffinity"
541+ },
542+ "nextOnSuccessOrFailure": {
543+ "decisionTree": {
544+ "current": {
545+ "pluginRef": "leastQueue"
546+ },
547+ "nextOnSuccessOrFailure": {
548+ "decisionTree": {
549+ "current": {
550+ "pluginRef": "leastKvCache"
551+ }
552+ }
553+ }
554+ }
555+ }
556+ }
557+ },
558+ "nextOnFailure": {
559+ "decisionTree": {
560+ "current": {
561+ "pluginRef": "leastQueue"
562+ },
563+ "nextOnSuccessOrFailure": {
564+ "decisionTree": {
565+ "current": {
566+ "pluginRef": "loraAffinity"
567+ },
568+ "nextOnSuccessOrFailure": {
569+ "decisionTree": {
570+ "current": {
571+ "pluginRef": "leastKvCache"
572+ }
573+ }
574+ }
575+ }
576+ }
577+ }
578+ }
579+ }
580+ }
581+ `
582+
583+ const decisionTreeParametersErrorBoth = `
584+ {
585+ "parameters": {
586+ "current": {
587+ "pluginRef": "lowQueue",
588+ "decisionTree": {
589+ "current": {
590+ "pluginRef": "leastKvCache"
591+ }
592+ }
593+ }
594+ }
595+ }
596+ `
597+
598+ const decisionTreeParametersErrorNone = `
599+ {
600+ "parameters": {
601+ "current": {
602+ }
603+ }
604+ }
605+ `
606+
607+ const decisionTreeParametersErrorBadPlugin = `
608+ {
609+ "parameters": {
610+ "current": {
611+ "pluginRef": "plover"
612+ }
613+ }
614+ }
615+ `
616+
617+ const decisionTreeParametersErrorNotFilter = `
618+ {
619+ "parameters": {
620+ "current": {
621+ "pluginRef": "kvCacheScorer"
622+ }
623+ }
624+ }
625+ `
626+
627+ const decisionTreeParametersErrorNoCurrent = `
628+ {
629+ "parameters": {
630+ "NextOnSuccess": {
631+ "pluginRef": "lowQueue"
632+ }
633+ }
634+ }
635+ `
636+
637+ const decisionTreeParametersErrorBadNextOnSuccess = `
638+ {
639+ "parameters": {
640+ "current": {
641+ "pluginRef": "lowQueue"
642+ },
643+ "NextOnSuccess": {
644+ "pluginRef": "kvCacheScorer"
645+ }
646+ }
647+ }
648+ `
649+
650+ const decisionTreeParametersErrorBadNextOnFailure = `
651+ {
652+ "parameters": {
653+ "current": {
654+ "pluginRef": "lowQueue"
655+ },
656+ "NextOnFailure": {
657+ "pluginRef": "kvCacheScorer"
658+ }
659+ }
660+ }
661+ `
662+
663+ const decisionTreeParametersErrorBadNextOnSuccessOrFailure = `
664+ {
665+ "parameters": {
666+ "current": {
667+ "pluginRef": "lowQueue"
668+ },
669+ "NextOnSuccessOrFailure": {
670+ "pluginRef": "kvCacheScorer"
671+ }
672+ }
673+ }
674+ `
0 commit comments