Skip to content

Commit b88054e

Browse files
committed
Added tests of the factory function of the DecisionTreeFilter
Signed-off-by: Shmuel Kallner <[email protected]>
1 parent 9e6223b commit b88054e

File tree

1 file changed

+282
-0
lines changed

1 file changed

+282
-0
lines changed

pkg/epp/scheduling/framework/plugins/filter/filter_test.go

Lines changed: 282 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,16 +18,20 @@ package filter
1818

1919
import (
2020
"context"
21+
"encoding/json"
2122
"testing"
2223

2324
"github.com/google/go-cmp/cmp"
25+
"github.com/google/go-cmp/cmp/cmpopts"
2426
"github.com/google/uuid"
2527
k8stypes "k8s.io/apimachinery/pkg/types"
2628
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend"
2729
backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics"
2830
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/config"
2931
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework"
32+
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework/plugins/scorer"
3033
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
34+
"sigs.k8s.io/gateway-api-inference-extension/test/utils"
3135
)
3236

3337
// compile-time type assertion
@@ -390,3 +394,281 @@ func TestSubsettingFilter(t *testing.T) {
390394
})
391395
}
392396
}
397+
398+
// TestDecisionTreeFilterFactory tests that the DecisionTreeFilterFactory function
399+
// properly instantiates DecisionTreeFilter instances
400+
func TestDecisionTreeFilterFactory(t *testing.T) {
401+
402+
leastKvCacheFilter := NewLeastKVCacheFilter()
403+
leastQueueFilter := NewLeastQueueFilter()
404+
loraAffinityFilter := NewLoraAffinityFilter(config.Conf.LoraAffinityThreshold)
405+
lowQueueFilter := NewLowQueueFilter(config.Conf.QueueingThresholdLoRA)
406+
407+
kvCacheScorer := scorer.NewKVCacheScorer()
408+
409+
testHandle := utils.NewTestHandle()
410+
411+
testHandle.Plugins().AddPlugin("leastKvCache", leastKvCacheFilter)
412+
testHandle.Plugins().AddPlugin("leastQueue", leastQueueFilter)
413+
testHandle.Plugins().AddPlugin("loraAffinity", loraAffinityFilter)
414+
testHandle.Plugins().AddPlugin("lowQueue", lowQueueFilter)
415+
416+
testHandle.Plugins().AddPlugin("kvCacheScorer", kvCacheScorer)
417+
418+
tests := []struct {
419+
name string
420+
parameters string
421+
want *DecisionTreeFilter
422+
wantErr bool
423+
}{
424+
{
425+
name: "success",
426+
parameters: decisionTreeParametersSuccess,
427+
want: &DecisionTreeFilter{
428+
Current: lowQueueFilter,
429+
NextOnSuccess: &DecisionTreeFilter{
430+
Current: loraAffinityFilter,
431+
NextOnSuccessOrFailure: &DecisionTreeFilter{
432+
Current: leastQueueFilter,
433+
NextOnSuccessOrFailure: &DecisionTreeFilter{
434+
Current: leastKvCacheFilter,
435+
},
436+
},
437+
},
438+
NextOnFailure: &DecisionTreeFilter{
439+
Current: leastQueueFilter,
440+
NextOnSuccessOrFailure: &DecisionTreeFilter{
441+
Current: loraAffinityFilter,
442+
NextOnSuccessOrFailure: &DecisionTreeFilter{
443+
Current: leastKvCacheFilter,
444+
},
445+
},
446+
},
447+
},
448+
wantErr: false,
449+
},
450+
{
451+
name: "bothError",
452+
parameters: decisionTreeParametersErrorBoth,
453+
want: nil,
454+
wantErr: true,
455+
},
456+
{
457+
name: "noneError",
458+
parameters: decisionTreeParametersErrorNone,
459+
want: nil,
460+
wantErr: true,
461+
},
462+
{
463+
name: "badPlugin",
464+
parameters: decisionTreeParametersErrorBadPlugin,
465+
want: nil,
466+
wantErr: true,
467+
},
468+
{
469+
name: "notFilter",
470+
parameters: decisionTreeParametersErrorNotFilter,
471+
want: nil,
472+
wantErr: true,
473+
},
474+
{
475+
name: "noCurrent",
476+
parameters: decisionTreeParametersErrorNoCurrent,
477+
want: nil,
478+
wantErr: true,
479+
},
480+
{
481+
name: "badNextOnSuccess",
482+
parameters: decisionTreeParametersErrorBadNextOnSuccess,
483+
want: nil,
484+
wantErr: true,
485+
},
486+
{
487+
name: "badNextOnFailure",
488+
parameters: decisionTreeParametersErrorBadNextOnFailure,
489+
want: nil,
490+
wantErr: true,
491+
},
492+
{
493+
name: "badNextOnSuccessOrFailure",
494+
parameters: decisionTreeParametersErrorBadNextOnSuccessOrFailure,
495+
want: nil,
496+
wantErr: true,
497+
},
498+
}
499+
500+
cmpOptions := cmpopts.IgnoreUnexported(LeastKVCacheFilter{}, LeastQueueFilter{},
501+
LoraAffinityFilter{}, LowQueueFilter{}, scorer.KVCacheScorer{})
502+
503+
for _, test := range tests {
504+
rawParameters := struct {
505+
Parameters json.RawMessage `json:"parameters"`
506+
}{}
507+
err := json.Unmarshal([]byte(test.parameters), &rawParameters)
508+
if err != nil {
509+
if test.wantErr {
510+
continue
511+
} else {
512+
t.Fatal("failed to parse JSON of test " + test.name)
513+
}
514+
}
515+
got, err := DecisionTreeFilterFactory("testing", rawParameters.Parameters, testHandle)
516+
if err != nil {
517+
if test.wantErr {
518+
continue
519+
}
520+
t.Fatalf("failed to instantiate DecisionTreeFilter. error: %s\n", err)
521+
}
522+
if test.wantErr {
523+
t.Fatalf("test %s did not return the expected error", test.name)
524+
}
525+
if diff := cmp.Diff(test.want, got, cmpOptions); diff != "" {
526+
t.Fatalf("In test %s DecisionTreeFactory returned unexpected response, diff(-want, +got): %v", test.name, diff)
527+
}
528+
}
529+
}
530+
531+
const decisionTreeParametersSuccess = `
532+
{
533+
"parameters": {
534+
"current": {
535+
"pluginRef": "lowQueue"
536+
},
537+
"nextOnSuccess": {
538+
"decisionTree": {
539+
"current": {
540+
"pluginRef": "loraAffinity"
541+
},
542+
"nextOnSuccessOrFailure": {
543+
"decisionTree": {
544+
"current": {
545+
"pluginRef": "leastQueue"
546+
},
547+
"nextOnSuccessOrFailure": {
548+
"decisionTree": {
549+
"current": {
550+
"pluginRef": "leastKvCache"
551+
}
552+
}
553+
}
554+
}
555+
}
556+
}
557+
},
558+
"nextOnFailure": {
559+
"decisionTree": {
560+
"current": {
561+
"pluginRef": "leastQueue"
562+
},
563+
"nextOnSuccessOrFailure": {
564+
"decisionTree": {
565+
"current": {
566+
"pluginRef": "loraAffinity"
567+
},
568+
"nextOnSuccessOrFailure": {
569+
"decisionTree": {
570+
"current": {
571+
"pluginRef": "leastKvCache"
572+
}
573+
}
574+
}
575+
}
576+
}
577+
}
578+
}
579+
}
580+
}
581+
`
582+
583+
const decisionTreeParametersErrorBoth = `
584+
{
585+
"parameters": {
586+
"current": {
587+
"pluginRef": "lowQueue",
588+
"decisionTree": {
589+
"current": {
590+
"pluginRef": "leastKvCache"
591+
}
592+
}
593+
}
594+
}
595+
}
596+
`
597+
598+
const decisionTreeParametersErrorNone = `
599+
{
600+
"parameters": {
601+
"current": {
602+
}
603+
}
604+
}
605+
`
606+
607+
const decisionTreeParametersErrorBadPlugin = `
608+
{
609+
"parameters": {
610+
"current": {
611+
"pluginRef": "plover"
612+
}
613+
}
614+
}
615+
`
616+
617+
const decisionTreeParametersErrorNotFilter = `
618+
{
619+
"parameters": {
620+
"current": {
621+
"pluginRef": "kvCacheScorer"
622+
}
623+
}
624+
}
625+
`
626+
627+
const decisionTreeParametersErrorNoCurrent = `
628+
{
629+
"parameters": {
630+
"NextOnSuccess": {
631+
"pluginRef": "lowQueue"
632+
}
633+
}
634+
}
635+
`
636+
637+
const decisionTreeParametersErrorBadNextOnSuccess = `
638+
{
639+
"parameters": {
640+
"current": {
641+
"pluginRef": "lowQueue"
642+
},
643+
"NextOnSuccess": {
644+
"pluginRef": "kvCacheScorer"
645+
}
646+
}
647+
}
648+
`
649+
650+
const decisionTreeParametersErrorBadNextOnFailure = `
651+
{
652+
"parameters": {
653+
"current": {
654+
"pluginRef": "lowQueue"
655+
},
656+
"NextOnFailure": {
657+
"pluginRef": "kvCacheScorer"
658+
}
659+
}
660+
}
661+
`
662+
663+
const decisionTreeParametersErrorBadNextOnSuccessOrFailure = `
664+
{
665+
"parameters": {
666+
"current": {
667+
"pluginRef": "lowQueue"
668+
},
669+
"NextOnSuccessOrFailure": {
670+
"pluginRef": "kvCacheScorer"
671+
}
672+
}
673+
}
674+
`

0 commit comments

Comments
 (0)