diff --git a/provider/pihole/client_test.go b/provider/pihole/client_test.go index 1a77505678..b94545b400 100644 --- a/provider/pihole/client_test.go +++ b/provider/pihole/client_test.go @@ -96,6 +96,32 @@ func TestNewPiholeClient(t *testing.T) { } } +// Helper function to validate records against expected values +func ValidateRecords(t *testing.T, records []*endpoint.Endpoint, expected [][]string, expectedCount int, recordType string) { + t.Helper() + if len(records) != expectedCount { + t.Fatalf("Expected %d %s records returned, got: %d", expectedCount, recordType, len(records)) + } + for idx, rec := range records { + if rec.DNSName != expected[idx][0] { + t.Errorf("Got invalid DNS Name: %s, expected: %s", rec.DNSName, expected[idx][0]) + } + if rec.Targets[0] != expected[idx][1] { + t.Errorf("Got invalid target: %s, expected: %s", rec.Targets[0], expected[idx][1]) + } + } +} + +// Helper function to test record retrieval for a specific type +func CheckRecordRetrieval(t *testing.T, cl *piholeClient, recordType string, expected [][]string, expectedCount int) { + t.Helper() + records, err := cl.listRecords(context.Background(), recordType) + if err != nil { + t.Fatal(err) + } + ValidateRecords(t, records, expected, expectedCount, recordType) +} + func TestListRecords(t *testing.T) { srvr := newTestServer(t, func(w http.ResponseWriter, r *http.Request) { r.ParseForm() @@ -140,76 +166,27 @@ func TestListRecords(t *testing.T) { } // Test retrieve A records unfiltered - arecs, err := cl.listRecords(context.Background(), endpoint.RecordTypeA) - if err != nil { - t.Fatal(err) - } - if len(arecs) != 3 { - t.Fatal("Expected 3 A records returned, got:", len(arecs)) - } - // Ensure records were parsed correctly - expected := [][]string{ + CheckRecordRetrieval(t, cl.(*piholeClient), endpoint.RecordTypeA, [][]string{ {"test1.example.com", "192.168.1.1"}, {"test2.example.com", "192.168.1.2"}, {"test3.match.com", "192.168.1.3"}, - } - for idx, rec := range arecs { - if rec.DNSName != expected[idx][0] { - t.Error("Got invalid DNS Name:", rec.DNSName, "expected:", expected[idx][0]) - } - if rec.Targets[0] != expected[idx][1] { - t.Error("Got invalid target:", rec.Targets[0], "expected:", expected[idx][1]) - } - } + }, 3) // Test retrieve AAAA records unfiltered - arecs, err = cl.listRecords(context.Background(), endpoint.RecordTypeAAAA) - if err != nil { - t.Fatal(err) - } - if len(arecs) != 3 { - t.Fatal("Expected 3 AAAA records returned, got:", len(arecs)) - } - // Ensure records were parsed correctly - expected = [][]string{ + CheckRecordRetrieval(t, cl.(*piholeClient), endpoint.RecordTypeAAAA, [][]string{ {"test1.example.com", "fc00::1:192:168:1:1"}, {"test2.example.com", "fc00::1:192:168:1:2"}, {"test3.match.com", "fc00::1:192:168:1:3"}, - } - for idx, rec := range arecs { - if rec.DNSName != expected[idx][0] { - t.Error("Got invalid DNS Name:", rec.DNSName, "expected:", expected[idx][0]) - } - if rec.Targets[0] != expected[idx][1] { - t.Error("Got invalid target:", rec.Targets[0], "expected:", expected[idx][1]) - } - } + }, 3) // Test retrieve CNAME records unfiltered - cnamerecs, err := cl.listRecords(context.Background(), endpoint.RecordTypeCNAME) - if err != nil { - t.Fatal(err) - } - if len(cnamerecs) != 3 { - t.Fatal("Expected 3 CAME records returned, got:", len(cnamerecs)) - } - // Ensure records were parsed correctly - expected = [][]string{ + CheckRecordRetrieval(t, cl.(*piholeClient), endpoint.RecordTypeCNAME, [][]string{ {"test4.example.com", "cname.example.com"}, {"test5.example.com", "cname.example.com"}, {"test6.match.com", "cname.example.com"}, - } - for idx, rec := range cnamerecs { - if rec.DNSName != expected[idx][0] { - t.Error("Got invalid DNS Name:", rec.DNSName, "expected:", expected[idx][0]) - } - if rec.Targets[0] != expected[idx][1] { - t.Error("Got invalid target:", rec.Targets[0], "expected:", expected[idx][1]) - } - } + }, 3) // Same tests but with a domain filter - cfg.DomainFilter = endpoint.NewDomainFilter([]string{"match.com"}) cl, err = newPiholeClient(cfg) if err != nil { @@ -217,68 +194,53 @@ func TestListRecords(t *testing.T) { } // Test retrieve A records filtered - arecs, err = cl.listRecords(context.Background(), endpoint.RecordTypeA) - if err != nil { - t.Fatal(err) - } - if len(arecs) != 1 { - t.Fatal("Expected 1 A record returned, got:", len(arecs)) - } - // Ensure records were parsed correctly - expected = [][]string{ + CheckRecordRetrieval(t, cl.(*piholeClient), endpoint.RecordTypeA, [][]string{ {"test3.match.com", "192.168.1.3"}, - } - for idx, rec := range arecs { - if rec.DNSName != expected[idx][0] { - t.Error("Got invalid DNS Name:", rec.DNSName, "expected:", expected[idx][0]) - } - if rec.Targets[0] != expected[idx][1] { - t.Error("Got invalid target:", rec.Targets[0], "expected:", expected[idx][1]) - } - } + }, 1) // Test retrieve AAAA records filtered - arecs, err = cl.listRecords(context.Background(), endpoint.RecordTypeAAAA) - if err != nil { - t.Fatal(err) - } - if len(arecs) != 1 { - t.Fatal("Expected 1 AAAA record returned, got:", len(arecs)) - } - // Ensure records were parsed correctly - expected = [][]string{ + CheckRecordRetrieval(t, cl.(*piholeClient), endpoint.RecordTypeAAAA, [][]string{ {"test3.match.com", "fc00::1:192:168:1:3"}, - } - for idx, rec := range arecs { - if rec.DNSName != expected[idx][0] { - t.Error("Got invalid DNS Name:", rec.DNSName, "expected:", expected[idx][0]) - } - if rec.Targets[0] != expected[idx][1] { - t.Error("Got invalid target:", rec.Targets[0], "expected:", expected[idx][1]) - } - } + }, 1) // Test retrieve CNAME records filtered - cnamerecs, err = cl.listRecords(context.Background(), endpoint.RecordTypeCNAME) + CheckRecordRetrieval(t, cl.(*piholeClient), endpoint.RecordTypeCNAME, [][]string{ + {"test6.match.com", "cname.example.com"}, + }, 1) + +} + +// Helper function to test error scenarios +func testErrorScenarios(t *testing.T, srvrErr *httptest.Server) { + t.Helper() + cfgExpired := PiholeConfig{ + Server: srvrErr.URL, + } + clExpired, err := newPiholeClient(cfgExpired) if err != nil { t.Fatal(err) } - if len(cnamerecs) != 1 { - t.Fatal("Expected 1 CNAME record returned, got:", len(cnamerecs)) + //set clExpired.token to a valid token + clExpired.(*piholeClient).token = "expired" + clExpired.(*piholeClient).cfg.Password = "notcorrect" + + cnamerecs, err := clExpired.listRecords(context.Background(), "notarealrecordtype") + if err == nil { + t.Fatal("Should return error, type is unknown ! ") } - // Ensure records were parsed correctly - expected = [][]string{ - {"test6.match.com", "cname.example.com"}, + cnamerecs, err = clExpired.listRecords(context.Background(), endpoint.RecordTypeCNAME) + if err == nil { + t.Fatal("Should return error on failed auth ! ") } - for idx, rec := range cnamerecs { - if rec.DNSName != expected[idx][0] { - t.Error("Got invalid DNS Name:", rec.DNSName, "expected:", expected[idx][0]) - } - if rec.Targets[0] != expected[idx][1] { - t.Error("Got invalid target:", rec.Targets[0], "expected:", expected[idx][1]) - } + clExpired.(*piholeClient).token = "correct" + clExpired.(*piholeClient).cfg.Password = "correct" + cnamerecs, err = clExpired.listRecords(context.Background(), endpoint.RecordTypeCNAME) + if len(cnamerecs) != 0 { + t.Fatal("Should return empty on missing data in response ! ") } +} +func TestErrorScenarios(t *testing.T) { // Test errors token srvrErr := newTestServer(t, func(w http.ResponseWriter, r *http.Request) { r.ParseForm() @@ -318,6 +280,7 @@ func TestListRecords(t *testing.T) { `)) }) defer srvrErr.Close() + cfgExpired := PiholeConfig{ Server: srvrErr.URL, } @@ -329,21 +292,23 @@ func TestListRecords(t *testing.T) { clExpired.(*piholeClient).token = "expired" clExpired.(*piholeClient).cfg.Password = "notcorrect" - cnamerecs, err = clExpired.listRecords(context.Background(), "notarealrecordtype") + _, err = clExpired.listRecords(context.Background(), "notarealrecordtype") if err == nil { t.Fatal("Should return error, type is unknown ! ") } - cnamerecs, err = clExpired.listRecords(context.Background(), endpoint.RecordTypeCNAME) + _, err = clExpired.listRecords(context.Background(), endpoint.RecordTypeCNAME) if err == nil { t.Fatal("Should return error on failed auth ! ") } clExpired.(*piholeClient).token = "correct" clExpired.(*piholeClient).cfg.Password = "correct" - cnamerecs, err = clExpired.listRecords(context.Background(), endpoint.RecordTypeCNAME) + cnamerecs, err := clExpired.listRecords(context.Background(), endpoint.RecordTypeCNAME) + if err != nil { + t.Fatal(err) + } if len(cnamerecs) != 0 { t.Fatal("Should return empty on missing data in response ! ") } - } func TestCreateRecord(t *testing.T) {