diff --git a/provider/pihole/pihole_test.go b/provider/pihole/pihole_test.go index 5c99c13940..ca6d5d7b54 100644 --- a/provider/pihole/pihole_test.go +++ b/provider/pihole/pihole_test.go @@ -81,12 +81,11 @@ func TestNewPiholeProvider(t *testing.T) { } } -func TestProvider(t *testing.T) { +func TestProvider_InitialState(t *testing.T) { requests := requestTracker{} p := &PiholeProvider{ api: &testPiholeClient{endpoints: make([]*endpoint.Endpoint, 0), requests: &requests}, } - records, err := p.Records(context.Background()) if err != nil { t.Fatal(err) @@ -94,9 +93,14 @@ func TestProvider(t *testing.T) { if len(records) != 0 { t.Fatal("Expected empty list of records, got:", records) } +} - // Populate the provider with records - records = []*endpoint.Endpoint{ +func TestProvider_CreateRecords(t *testing.T) { + requests := requestTracker{} + p := &PiholeProvider{ + api: &testPiholeClient{endpoints: make([]*endpoint.Endpoint, 0), requests: &requests}, + } + records := []*endpoint.Endpoint{ { DNSName: "test1.example.com", Targets: []string{"192.168.1.1"}, @@ -133,9 +137,6 @@ func TestProvider(t *testing.T) { }); err != nil { t.Fatal(err) } - - // Test records are correct on retrieval - newRecords, err := p.Records(context.Background()) if err != nil { t.Fatal(err) @@ -149,7 +150,6 @@ func TestProvider(t *testing.T) { if len(requests.deleteRequests) != 0 { t.Fatal("Expected no delete requests, got:", requests.deleteRequests) } - for idx, record := range records { if newRecords[idx].DNSName != record.DNSName { t.Error("DNS Name malformed on retrieval, got:", newRecords[idx].DNSName, "expected:", record.DNSName) @@ -157,17 +157,19 @@ func TestProvider(t *testing.T) { if newRecords[idx].Targets[0] != record.Targets[0] { t.Error("Targets malformed on retrieval, got:", newRecords[idx].Targets, "expected:", record.Targets) } - if !reflect.DeepEqual(requests.createRequests[idx], record) { t.Error("Unexpected create request, got:", newRecords[idx].DNSName, "expected:", record.DNSName) } } - requests.clear() +} - // Test delete a record - - records = []*endpoint.Endpoint{ +func TestProvider_DeleteRecords(t *testing.T) { + requests := requestTracker{} + p := &PiholeProvider{ + api: &testPiholeClient{endpoints: make([]*endpoint.Endpoint, 0), requests: &requests}, + } + records := []*endpoint.Endpoint{ { DNSName: "test1.example.com", Targets: []string{"192.168.1.1"}, @@ -189,6 +191,12 @@ func TestProvider(t *testing.T) { RecordType: endpoint.RecordTypeAAAA, }, } + // Create initial records + if err := p.ApplyChanges(context.Background(), &plan.Changes{ + Create: records, + }); err != nil { + t.Fatal(err) + } recordToDeleteA := endpoint.Endpoint{ DNSName: "test3.example.com", Targets: []string{"192.168.1.3"}, @@ -213,22 +221,19 @@ func TestProvider(t *testing.T) { }); err != nil { t.Fatal(err) } - - // Test records are updated - newRecords, err = p.Records(context.Background()) + newRecords, err := p.Records(context.Background()) if err != nil { t.Fatal(err) } if len(newRecords) != 4 { t.Fatal("Expected list of 4 records, got:", records) } - if len(requests.createRequests) != 0 { - t.Fatal("Expected no create requests, got:", requests.createRequests) + if len(requests.createRequests) != 4 { + t.Fatal("Expected 4 create requests, got:", requests.createRequests) } if len(requests.deleteRequests) != 2 { t.Fatal("Expected 2 delete request, got:", requests.deleteRequests) } - for idx, record := range records { if newRecords[idx].DNSName != record.DNSName { t.Error("DNS Name malformed on retrieval, got:", newRecords[idx].DNSName, "expected:", record.DNSName) @@ -237,19 +242,22 @@ func TestProvider(t *testing.T) { t.Error("Targets malformed on retrieval, got:", newRecords[idx].Targets, "expected:", record.Targets) } } - if !reflect.DeepEqual(requests.deleteRequests[0], &recordToDeleteA) { t.Error("Unexpected delete request, got:", requests.deleteRequests[0], "expected:", recordToDeleteA) } if !reflect.DeepEqual(requests.deleteRequests[1], &recordToDeleteAAAA) { t.Error("Unexpected delete request, got:", requests.deleteRequests[1], "expected:", recordToDeleteAAAA) } - requests.clear() +} - // Test update a record - - records = []*endpoint.Endpoint{ +func TestProvider_UpdateRecords(t *testing.T) { + requests := requestTracker{} + p := &PiholeProvider{ + api: &testPiholeClient{endpoints: make([]*endpoint.Endpoint, 0), requests: &requests}, + } + // Create initial records + initialRecords := []*endpoint.Endpoint{ { DNSName: "test1.example.com", Targets: []string{"192.168.1.1"}, @@ -257,7 +265,7 @@ func TestProvider(t *testing.T) { }, { DNSName: "test2.example.com", - Targets: []string{"10.0.0.1"}, + Targets: []string{"192.168.1.2"}, RecordType: endpoint.RecordTypeA, }, { @@ -267,61 +275,68 @@ func TestProvider(t *testing.T) { }, { DNSName: "test2.example.com", - Targets: []string{"fc00::1:10:0:0:1"}, + Targets: []string{"fc00::1:192:168:1:2"}, RecordType: endpoint.RecordTypeAAAA, }, } if err := p.ApplyChanges(context.Background(), &plan.Changes{ - UpdateOld: []*endpoint.Endpoint{ - { - DNSName: "test1.example.com", - Targets: []string{"192.168.1.1"}, - RecordType: endpoint.RecordTypeA, - }, - { - DNSName: "test2.example.com", - Targets: []string{"192.168.1.2"}, - RecordType: endpoint.RecordTypeA, - }, - { - DNSName: "test1.example.com", - Targets: []string{"fc00::1:192:168:1:1"}, - RecordType: endpoint.RecordTypeAAAA, - }, - { - DNSName: "test2.example.com", - Targets: []string{"fc00::1:192:168:1:2"}, - RecordType: endpoint.RecordTypeAAAA, - }, + Create: initialRecords, + }); err != nil { + t.Fatal(err) + } + requests.clear() + // Update records + updateOld := []*endpoint.Endpoint{ + { + DNSName: "test1.example.com", + Targets: []string{"192.168.1.1"}, + RecordType: endpoint.RecordTypeA, }, - UpdateNew: []*endpoint.Endpoint{ - { - DNSName: "test1.example.com", - Targets: []string{"192.168.1.1"}, - RecordType: endpoint.RecordTypeA, - }, - { - DNSName: "test2.example.com", - Targets: []string{"10.0.0.1"}, - RecordType: endpoint.RecordTypeA, - }, - { - DNSName: "test1.example.com", - Targets: []string{"fc00::1:192:168:1:1"}, - RecordType: endpoint.RecordTypeAAAA, - }, - { - DNSName: "test2.example.com", - Targets: []string{"fc00::1:10:0:0:1"}, - RecordType: endpoint.RecordTypeAAAA, - }, + { + DNSName: "test2.example.com", + Targets: []string{"192.168.1.2"}, + RecordType: endpoint.RecordTypeA, + }, + { + DNSName: "test1.example.com", + Targets: []string{"fc00::1:192:168:1:1"}, + RecordType: endpoint.RecordTypeAAAA, }, + { + DNSName: "test2.example.com", + Targets: []string{"fc00::1:192:168:1:2"}, + RecordType: endpoint.RecordTypeAAAA, + }, + } + updateNew := []*endpoint.Endpoint{ + { + DNSName: "test1.example.com", + Targets: []string{"192.168.1.1"}, + RecordType: endpoint.RecordTypeA, + }, + { + DNSName: "test2.example.com", + Targets: []string{"10.0.0.1"}, + RecordType: endpoint.RecordTypeA, + }, + { + DNSName: "test1.example.com", + Targets: []string{"fc00::1:192:168:1:1"}, + RecordType: endpoint.RecordTypeAAAA, + }, + { + DNSName: "test2.example.com", + Targets: []string{"fc00::1:10:0:0:1"}, + RecordType: endpoint.RecordTypeAAAA, + }, + } + if err := p.ApplyChanges(context.Background(), &plan.Changes{ + UpdateOld: updateOld, + UpdateNew: updateNew, }); err != nil { t.Fatal(err) } - - // Test records are updated - newRecords, err = p.Records(context.Background()) + newRecords, err := p.Records(context.Background()) if err != nil { t.Fatal(err) } @@ -334,8 +349,7 @@ func TestProvider(t *testing.T) { if len(requests.deleteRequests) != 2 { t.Fatal("Expected 2 delete request, got:", requests.deleteRequests) } - - for idx, record := range records { + for idx, record := range updateNew { if newRecords[idx].DNSName != record.DNSName { t.Error("DNS Name malformed on retrieval, got:", newRecords[idx].DNSName, "expected:", record.DNSName) } @@ -343,7 +357,6 @@ func TestProvider(t *testing.T) { t.Error("Targets malformed on retrieval, got:", newRecords[idx].Targets, "expected:", record.Targets) } } - expectedCreateA := endpoint.Endpoint{ DNSName: "test2.example.com", Targets: []string{"10.0.0.1"}, @@ -364,7 +377,6 @@ func TestProvider(t *testing.T) { Targets: []string{"fc00::1:192:168:1:2"}, RecordType: endpoint.RecordTypeAAAA, } - for _, request := range requests.createRequests { switch request.RecordType { case endpoint.RecordTypeA: @@ -375,10 +387,8 @@ func TestProvider(t *testing.T) { if !reflect.DeepEqual(request, &expectedCreateAAAA) { t.Error("Unexpected create request, got:", request, "expected:", &expectedCreateAAAA) } - default: } } - for _, request := range requests.deleteRequests { switch request.RecordType { case endpoint.RecordTypeA: @@ -389,9 +399,7 @@ func TestProvider(t *testing.T) { if !reflect.DeepEqual(request, &expectedDeleteAAAA) { t.Error("Unexpected delete request, got:", request, "expected:", &expectedDeleteAAAA) } - default: } } - requests.clear() }