diff --git a/provider/cloudflare/cloudflare.go b/provider/cloudflare/cloudflare.go index c7ff356585..c08429c617 100644 --- a/provider/cloudflare/cloudflare.go +++ b/provider/cloudflare/cloudflare.go @@ -110,7 +110,7 @@ type cloudFlareDNS interface { ZoneIDByName(zoneName string) (string, error) ListZones(ctx context.Context, params zones.ZoneListParams) autoPager[zones.Zone] GetZone(ctx context.Context, zoneID string) (*zones.Zone, error) - ListDNSRecords(ctx context.Context, rc *cloudflarev0.ResourceContainer, rp cloudflarev0.ListDNSRecordsParams) ([]dns.RecordResponse, *cloudflarev0.ResultInfo, error) + ListDNSRecords(ctx context.Context, params dns.RecordListParams) autoPager[dns.RecordResponse] CreateDNSRecord(ctx context.Context, params dns.RecordNewParams) (*dns.RecordResponse, error) DeleteDNSRecord(ctx context.Context, rc *cloudflarev0.ResourceContainer, recordID string) error UpdateDNSRecord(ctx context.Context, rc *cloudflarev0.ResourceContainer, rp cloudflarev0.UpdateDNSRecordParams) error @@ -152,13 +152,8 @@ func (z zoneService) CreateDNSRecord(ctx context.Context, params dns.RecordNewPa return z.service.DNS.Records.New(ctx, params) } -func (z zoneService) ListDNSRecords(ctx context.Context, rc *cloudflarev0.ResourceContainer, rp cloudflarev0.ListDNSRecordsParams) ([]dns.RecordResponse, *cloudflarev0.ResultInfo, error) { - records, info, err := z.serviceV0.ListDNSRecords(ctx, rc, rp) - convertedRecords := make([]dns.RecordResponse, 0, len(records)) - for _, record := range records { - convertedRecords = append(convertedRecords, dnsRecordResponseFromLegacyDNSRecord(record)) - } - return convertedRecords, info, err +func (z zoneService) ListDNSRecords(ctx context.Context, params dns.RecordListParams) autoPager[dns.RecordResponse] { + return z.service.DNS.Records.ListAutoPaging(ctx, params) } func (z zoneService) UpdateDNSRecord(ctx context.Context, rc *cloudflarev0.ResourceContainer, rp cloudflarev0.UpdateDNSRecordParams) error { @@ -428,7 +423,7 @@ func (p *CloudFlareProvider) Records(ctx context.Context) ([]*endpoint.Endpoint, var endpoints []*endpoint.Endpoint for _, zone := range zones { - records, err := p.listDNSRecordsWithAutoPagination(ctx, zone.ID) + records, err := p.getDNSRecordsMap(ctx, zone.ID) if err != nil { return nil, err } @@ -643,7 +638,7 @@ func (p *CloudFlareProvider) submitChanges(ctx context.Context, changes []*cloud continue } - records, err := p.listDNSRecordsWithAutoPagination(ctx, zoneID) + records, err := p.getDNSRecordsMap(ctx, zoneID) if err != nil { return fmt.Errorf("could not fetch records from zone, %w", err) } @@ -860,27 +855,19 @@ func newDNSRecordIndex(r dns.RecordResponse) DNSRecordIndex { return DNSRecordIndex{Name: r.Name, Type: string(r.Type), Content: r.Content} } -// listDNSRecordsWithAutoPagination performs automatic pagination of results on requests to cloudflare.ListDNSRecords with custom per_page values -func (p *CloudFlareProvider) listDNSRecordsWithAutoPagination(ctx context.Context, zoneID string) (DNSRecordsMap, error) { +// getDNSRecordsMap retrieves all DNS records for a given zone and returns them as a DNSRecordsMap. +func (p *CloudFlareProvider) getDNSRecordsMap(ctx context.Context, zoneID string) (DNSRecordsMap, error) { // for faster getRecordID lookup - records := make(DNSRecordsMap) - resultInfo := cloudflarev0.ResultInfo{PerPage: p.DNSRecordsConfig.PerPage, Page: 1} - params := cloudflarev0.ListDNSRecordsParams{ResultInfo: resultInfo} - for { - pageRecords, resultInfo, err := p.Client.ListDNSRecords(ctx, cloudflarev0.ZoneIdentifier(zoneID), params) - if err != nil { - return nil, convertCloudflareError(err) - } - - for _, r := range pageRecords { - records[newDNSRecordIndex(r)] = r - } - params.ResultInfo = resultInfo.Next() - if params.Done() { - break - } + recordsMap := make(DNSRecordsMap) + params := dns.RecordListParams{ZoneID: cloudflare.F(zoneID)} + iter := p.Client.ListDNSRecords(ctx, params) + for record := range autoPagerIterator(iter) { + recordsMap[newDNSRecordIndex(record)] = record + } + if iter.Err() != nil { + return nil, convertCloudflareError(iter.Err()) } - return records, nil + return recordsMap, nil } func newCustomHostnameIndex(ch cloudflarev0.CustomHostname) CustomHostnameIndex { diff --git a/provider/cloudflare/cloudflare_test.go b/provider/cloudflare/cloudflare_test.go index fbe7ede5c4..86bb551b99 100644 --- a/provider/cloudflare/cloudflare_test.go +++ b/provider/cloudflare/cloudflare_test.go @@ -22,12 +22,12 @@ import ( "fmt" "os" "slices" - "sort" "strings" "testing" "time" cloudflarev0 "github.com/cloudflare/cloudflare-go" + "github.com/cloudflare/cloudflare-go/v5" "github.com/cloudflare/cloudflare-go/v5/dns" "github.com/cloudflare/cloudflare-go/v5/zones" "github.com/maxatome/go-testdeep/td" @@ -171,49 +171,22 @@ func (m *mockCloudFlareClient) CreateDNSRecord(ctx context.Context, params dns.R return &record, nil } -func (m *mockCloudFlareClient) ListDNSRecords(ctx context.Context, rc *cloudflarev0.ResourceContainer, rp cloudflarev0.ListDNSRecordsParams) ([]dns.RecordResponse, *cloudflarev0.ResultInfo, error) { +func (m *mockCloudFlareClient) ListDNSRecords(ctx context.Context, params dns.RecordListParams) autoPager[dns.RecordResponse] { if m.dnsRecordsError != nil { - return nil, &cloudflarev0.ResultInfo{}, m.dnsRecordsError + return &mockAutoPager[dns.RecordResponse]{err: m.dnsRecordsError} } - result := []dns.RecordResponse{} - if zone, ok := m.Records[rc.Identifier]; ok { + iter := &mockAutoPager[dns.RecordResponse]{} + if zone, ok := m.Records[params.ZoneID.Value]; ok { for _, record := range zone { if strings.HasPrefix(record.Name, "newerror-list-") { - m.DeleteDNSRecord(ctx, rc, record.ID) - return nil, &cloudflarev0.ResultInfo{}, errors.New("failed to list erroring DNS record") + m.DeleteDNSRecord(ctx, cloudflarev0.ResourceIdentifier(params.ZoneID.Value), record.ID) + iter.err = errors.New("failed to list erroring DNS record") + return iter } - result = append(result, record) + iter.items = append(iter.items, record) } } - - if len(result) == 0 || rp.PerPage == 0 { - return result, &cloudflarev0.ResultInfo{Page: 1, TotalPages: 1, Count: 0, Total: 0}, nil - } - - // if not pagination options were passed in, return the result as is - if rp.Page == 0 { - return result, &cloudflarev0.ResultInfo{Page: 1, TotalPages: 1, Count: len(result), Total: len(result)}, nil - } - - // otherwise, split the result into chunks of size rp.PerPage to simulate the pagination from the API - chunks := [][]dns.RecordResponse{} - - // to ensure consistency in the multiple calls to this function, sort the result slice - sort.Slice(result, func(i, j int) bool { return strings.Compare(result[i].ID, result[j].ID) > 0 }) - for rp.PerPage < len(result) { - result, chunks = result[rp.PerPage:], append(chunks, result[0:rp.PerPage]) - } - chunks = append(chunks, result) - - // return the requested page - partialResult := chunks[rp.Page-1] - return partialResult, &cloudflarev0.ResultInfo{ - PerPage: rp.PerPage, - Page: rp.Page, - TotalPages: len(chunks), - Count: len(partialResult), - Total: len(result), - }, nil + return iter } func (m *mockCloudFlareClient) UpdateDNSRecord(ctx context.Context, rc *cloudflarev0.ResourceContainer, rp cloudflarev0.UpdateDNSRecordParams) error { @@ -1500,7 +1473,7 @@ func TestGroupByNameAndTypeWithCustomHostnames_MX(t *testing.T) { } ctx := context.Background() chs := CustomHostnamesMap{} - records, err := provider.listDNSRecordsWithAutoPagination(ctx, "001") + records, err := provider.getDNSRecordsMap(ctx, "001") assert.NoError(t, err) endpoints := provider.groupByNameAndTypeWithCustomHostnames(records, chs) @@ -3346,3 +3319,22 @@ func TestDnsRecordFromLegacyAPI(t *testing.T) { }) } } + +func TestZoneService(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithCancel(t.Context()) + cancel() + + client := &zoneService{ + service: cloudflare.NewClient(), + } + + t.Run("UpdateDNSRecord", func(t *testing.T) { + t.Parallel() + iter := client.ListDNSRecords(ctx, dns.RecordListParams{ZoneID: cloudflare.F("foo")}) + require.False(t, iter.Next()) + require.Empty(t, iter.Current()) + require.ErrorIs(t, iter.Err(), context.Canceled) + }) +}