Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
181 changes: 73 additions & 108 deletions provider/pihole/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,32 @@ func TestNewPiholeClient(t *testing.T) {
}
}

// Helper function to validate records against expected values
func ValidateRecords(t *testing.T, records []*endpoint.Endpoint, expected [][]string, expectedCount int, recordType string) {
t.Helper()
if len(records) != expectedCount {
t.Fatalf("Expected %d %s records returned, got: %d", expectedCount, recordType, len(records))
}
for idx, rec := range records {
if rec.DNSName != expected[idx][0] {
t.Errorf("Got invalid DNS Name: %s, expected: %s", rec.DNSName, expected[idx][0])
}
if rec.Targets[0] != expected[idx][1] {
t.Errorf("Got invalid target: %s, expected: %s", rec.Targets[0], expected[idx][1])
}
}
}

// Helper function to test record retrieval for a specific type
func CheckRecordRetrieval(t *testing.T, cl *piholeClient, recordType string, expected [][]string, expectedCount int) {
t.Helper()
records, err := cl.listRecords(context.Background(), recordType)
if err != nil {
t.Fatal(err)
}
ValidateRecords(t, records, expected, expectedCount, recordType)
}

func TestListRecords(t *testing.T) {
srvr := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
r.ParseForm()
Expand Down Expand Up @@ -140,145 +166,81 @@ func TestListRecords(t *testing.T) {
}

// Test retrieve A records unfiltered
arecs, err := cl.listRecords(context.Background(), endpoint.RecordTypeA)
if err != nil {
t.Fatal(err)
}
if len(arecs) != 3 {
t.Fatal("Expected 3 A records returned, got:", len(arecs))
}
// Ensure records were parsed correctly
expected := [][]string{
CheckRecordRetrieval(t, cl.(*piholeClient), endpoint.RecordTypeA, [][]string{
{"test1.example.com", "192.168.1.1"},
{"test2.example.com", "192.168.1.2"},
{"test3.match.com", "192.168.1.3"},
}
for idx, rec := range arecs {
if rec.DNSName != expected[idx][0] {
t.Error("Got invalid DNS Name:", rec.DNSName, "expected:", expected[idx][0])
}
if rec.Targets[0] != expected[idx][1] {
t.Error("Got invalid target:", rec.Targets[0], "expected:", expected[idx][1])
}
}
}, 3)

// Test retrieve AAAA records unfiltered
arecs, err = cl.listRecords(context.Background(), endpoint.RecordTypeAAAA)
if err != nil {
t.Fatal(err)
}
if len(arecs) != 3 {
t.Fatal("Expected 3 AAAA records returned, got:", len(arecs))
}
// Ensure records were parsed correctly
expected = [][]string{
CheckRecordRetrieval(t, cl.(*piholeClient), endpoint.RecordTypeAAAA, [][]string{
{"test1.example.com", "fc00::1:192:168:1:1"},
{"test2.example.com", "fc00::1:192:168:1:2"},
{"test3.match.com", "fc00::1:192:168:1:3"},
}
for idx, rec := range arecs {
if rec.DNSName != expected[idx][0] {
t.Error("Got invalid DNS Name:", rec.DNSName, "expected:", expected[idx][0])
}
if rec.Targets[0] != expected[idx][1] {
t.Error("Got invalid target:", rec.Targets[0], "expected:", expected[idx][1])
}
}
}, 3)

// Test retrieve CNAME records unfiltered
cnamerecs, err := cl.listRecords(context.Background(), endpoint.RecordTypeCNAME)
if err != nil {
t.Fatal(err)
}
if len(cnamerecs) != 3 {
t.Fatal("Expected 3 CAME records returned, got:", len(cnamerecs))
}
// Ensure records were parsed correctly
expected = [][]string{
CheckRecordRetrieval(t, cl.(*piholeClient), endpoint.RecordTypeCNAME, [][]string{
{"test4.example.com", "cname.example.com"},
{"test5.example.com", "cname.example.com"},
{"test6.match.com", "cname.example.com"},
}
for idx, rec := range cnamerecs {
if rec.DNSName != expected[idx][0] {
t.Error("Got invalid DNS Name:", rec.DNSName, "expected:", expected[idx][0])
}
if rec.Targets[0] != expected[idx][1] {
t.Error("Got invalid target:", rec.Targets[0], "expected:", expected[idx][1])
}
}
}, 3)

// Same tests but with a domain filter

cfg.DomainFilter = endpoint.NewDomainFilter([]string{"match.com"})
cl, err = newPiholeClient(cfg)
if err != nil {
t.Fatal(err)
}

// Test retrieve A records filtered
arecs, err = cl.listRecords(context.Background(), endpoint.RecordTypeA)
if err != nil {
t.Fatal(err)
}
if len(arecs) != 1 {
t.Fatal("Expected 1 A record returned, got:", len(arecs))
}
// Ensure records were parsed correctly
expected = [][]string{
CheckRecordRetrieval(t, cl.(*piholeClient), endpoint.RecordTypeA, [][]string{
{"test3.match.com", "192.168.1.3"},
}
for idx, rec := range arecs {
if rec.DNSName != expected[idx][0] {
t.Error("Got invalid DNS Name:", rec.DNSName, "expected:", expected[idx][0])
}
if rec.Targets[0] != expected[idx][1] {
t.Error("Got invalid target:", rec.Targets[0], "expected:", expected[idx][1])
}
}
}, 1)

// Test retrieve AAAA records filtered
arecs, err = cl.listRecords(context.Background(), endpoint.RecordTypeAAAA)
if err != nil {
t.Fatal(err)
}
if len(arecs) != 1 {
t.Fatal("Expected 1 AAAA record returned, got:", len(arecs))
}
// Ensure records were parsed correctly
expected = [][]string{
CheckRecordRetrieval(t, cl.(*piholeClient), endpoint.RecordTypeAAAA, [][]string{
{"test3.match.com", "fc00::1:192:168:1:3"},
}
for idx, rec := range arecs {
if rec.DNSName != expected[idx][0] {
t.Error("Got invalid DNS Name:", rec.DNSName, "expected:", expected[idx][0])
}
if rec.Targets[0] != expected[idx][1] {
t.Error("Got invalid target:", rec.Targets[0], "expected:", expected[idx][1])
}
}
}, 1)

// Test retrieve CNAME records filtered
cnamerecs, err = cl.listRecords(context.Background(), endpoint.RecordTypeCNAME)
CheckRecordRetrieval(t, cl.(*piholeClient), endpoint.RecordTypeCNAME, [][]string{
{"test6.match.com", "cname.example.com"},
}, 1)

}

// Helper function to test error scenarios
func testErrorScenarios(t *testing.T, srvrErr *httptest.Server) {
t.Helper()
cfgExpired := PiholeConfig{
Server: srvrErr.URL,
}
clExpired, err := newPiholeClient(cfgExpired)
if err != nil {
t.Fatal(err)
}
if len(cnamerecs) != 1 {
t.Fatal("Expected 1 CNAME record returned, got:", len(cnamerecs))
//set clExpired.token to a valid token
clExpired.(*piholeClient).token = "expired"
clExpired.(*piholeClient).cfg.Password = "notcorrect"

cnamerecs, err := clExpired.listRecords(context.Background(), "notarealrecordtype")
if err == nil {
t.Fatal("Should return error, type is unknown ! ")
}
// Ensure records were parsed correctly
expected = [][]string{
{"test6.match.com", "cname.example.com"},
cnamerecs, err = clExpired.listRecords(context.Background(), endpoint.RecordTypeCNAME)
if err == nil {
t.Fatal("Should return error on failed auth ! ")
}
for idx, rec := range cnamerecs {
if rec.DNSName != expected[idx][0] {
t.Error("Got invalid DNS Name:", rec.DNSName, "expected:", expected[idx][0])
}
if rec.Targets[0] != expected[idx][1] {
t.Error("Got invalid target:", rec.Targets[0], "expected:", expected[idx][1])
}
clExpired.(*piholeClient).token = "correct"
clExpired.(*piholeClient).cfg.Password = "correct"
cnamerecs, err = clExpired.listRecords(context.Background(), endpoint.RecordTypeCNAME)
if len(cnamerecs) != 0 {
t.Fatal("Should return empty on missing data in response ! ")
}
}

func TestErrorScenarios(t *testing.T) {
// Test errors token
srvrErr := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
r.ParseForm()
Expand Down Expand Up @@ -318,6 +280,7 @@ func TestListRecords(t *testing.T) {
`))
})
defer srvrErr.Close()

cfgExpired := PiholeConfig{
Server: srvrErr.URL,
}
Expand All @@ -329,21 +292,23 @@ func TestListRecords(t *testing.T) {
clExpired.(*piholeClient).token = "expired"
clExpired.(*piholeClient).cfg.Password = "notcorrect"

cnamerecs, err = clExpired.listRecords(context.Background(), "notarealrecordtype")
_, err = clExpired.listRecords(context.Background(), "notarealrecordtype")
if err == nil {
t.Fatal("Should return error, type is unknown ! ")
}
cnamerecs, err = clExpired.listRecords(context.Background(), endpoint.RecordTypeCNAME)
_, err = clExpired.listRecords(context.Background(), endpoint.RecordTypeCNAME)
if err == nil {
t.Fatal("Should return error on failed auth ! ")
}
clExpired.(*piholeClient).token = "correct"
clExpired.(*piholeClient).cfg.Password = "correct"
cnamerecs, err = clExpired.listRecords(context.Background(), endpoint.RecordTypeCNAME)
cnamerecs, err := clExpired.listRecords(context.Background(), endpoint.RecordTypeCNAME)
if err != nil {
t.Fatal(err)
}
if len(cnamerecs) != 0 {
t.Fatal("Should return empty on missing data in response ! ")
}

}

func TestCreateRecord(t *testing.T) {
Expand Down
Loading