Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 16 additions & 29 deletions provider/cloudflare/cloudflare.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ type cloudFlareDNS interface {
ZoneIDByName(zoneName string) (string, error)
ListZones(ctx context.Context, params zones.ZoneListParams) autoPager[zones.Zone]
GetZone(ctx context.Context, zoneID string) (*zones.Zone, error)
ListDNSRecords(ctx context.Context, rc *cloudflarev0.ResourceContainer, rp cloudflarev0.ListDNSRecordsParams) ([]dns.RecordResponse, *cloudflarev0.ResultInfo, error)
ListDNSRecords(ctx context.Context, params dns.RecordListParams) autoPager[dns.RecordResponse]
CreateDNSRecord(ctx context.Context, params dns.RecordNewParams) (*dns.RecordResponse, error)
DeleteDNSRecord(ctx context.Context, rc *cloudflarev0.ResourceContainer, recordID string) error
UpdateDNSRecord(ctx context.Context, rc *cloudflarev0.ResourceContainer, rp cloudflarev0.UpdateDNSRecordParams) error
Expand Down Expand Up @@ -152,13 +152,8 @@ func (z zoneService) CreateDNSRecord(ctx context.Context, params dns.RecordNewPa
return z.service.DNS.Records.New(ctx, params)
}

func (z zoneService) ListDNSRecords(ctx context.Context, rc *cloudflarev0.ResourceContainer, rp cloudflarev0.ListDNSRecordsParams) ([]dns.RecordResponse, *cloudflarev0.ResultInfo, error) {
records, info, err := z.serviceV0.ListDNSRecords(ctx, rc, rp)
convertedRecords := make([]dns.RecordResponse, 0, len(records))
for _, record := range records {
convertedRecords = append(convertedRecords, dnsRecordResponseFromLegacyDNSRecord(record))
}
return convertedRecords, info, err
func (z zoneService) ListDNSRecords(ctx context.Context, params dns.RecordListParams) autoPager[dns.RecordResponse] {
return z.service.DNS.Records.ListAutoPaging(ctx, params)
Comment thread
vflaux marked this conversation as resolved.
}

func (z zoneService) UpdateDNSRecord(ctx context.Context, rc *cloudflarev0.ResourceContainer, rp cloudflarev0.UpdateDNSRecordParams) error {
Expand Down Expand Up @@ -428,7 +423,7 @@ func (p *CloudFlareProvider) Records(ctx context.Context) ([]*endpoint.Endpoint,

var endpoints []*endpoint.Endpoint
for _, zone := range zones {
records, err := p.listDNSRecordsWithAutoPagination(ctx, zone.ID)
records, err := p.getDNSRecordsMap(ctx, zone.ID)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -643,7 +638,7 @@ func (p *CloudFlareProvider) submitChanges(ctx context.Context, changes []*cloud
continue
}

records, err := p.listDNSRecordsWithAutoPagination(ctx, zoneID)
records, err := p.getDNSRecordsMap(ctx, zoneID)
if err != nil {
return fmt.Errorf("could not fetch records from zone, %w", err)
}
Expand Down Expand Up @@ -860,27 +855,19 @@ func newDNSRecordIndex(r dns.RecordResponse) DNSRecordIndex {
return DNSRecordIndex{Name: r.Name, Type: string(r.Type), Content: r.Content}
}

// listDNSRecordsWithAutoPagination performs automatic pagination of results on requests to cloudflare.ListDNSRecords with custom per_page values
func (p *CloudFlareProvider) listDNSRecordsWithAutoPagination(ctx context.Context, zoneID string) (DNSRecordsMap, error) {
// getDNSRecordsMap retrieves all DNS records for a given zone and returns them as a DNSRecordsMap.
func (p *CloudFlareProvider) getDNSRecordsMap(ctx context.Context, zoneID string) (DNSRecordsMap, error) {
// for faster getRecordID lookup
records := make(DNSRecordsMap)
resultInfo := cloudflarev0.ResultInfo{PerPage: p.DNSRecordsConfig.PerPage, Page: 1}
params := cloudflarev0.ListDNSRecordsParams{ResultInfo: resultInfo}
for {
pageRecords, resultInfo, err := p.Client.ListDNSRecords(ctx, cloudflarev0.ZoneIdentifier(zoneID), params)
if err != nil {
return nil, convertCloudflareError(err)
}

for _, r := range pageRecords {
records[newDNSRecordIndex(r)] = r
}
params.ResultInfo = resultInfo.Next()
if params.Done() {
break
}
recordsMap := make(DNSRecordsMap)
params := dns.RecordListParams{ZoneID: cloudflare.F(zoneID)}
iter := p.Client.ListDNSRecords(ctx, params)
for record := range autoPagerIterator(iter) {
recordsMap[newDNSRecordIndex(record)] = record
}
if iter.Err() != nil {
return nil, convertCloudflareError(iter.Err())
}
return records, nil
return recordsMap, nil
}

func newCustomHostnameIndex(ch cloudflarev0.CustomHostname) CustomHostnameIndex {
Expand Down
68 changes: 30 additions & 38 deletions provider/cloudflare/cloudflare_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,12 @@ import (
"fmt"
"os"
"slices"
"sort"
"strings"
"testing"
"time"

cloudflarev0 "github.com/cloudflare/cloudflare-go"
"github.com/cloudflare/cloudflare-go/v5"
"github.com/cloudflare/cloudflare-go/v5/dns"
"github.com/cloudflare/cloudflare-go/v5/zones"
"github.com/maxatome/go-testdeep/td"
Expand Down Expand Up @@ -171,49 +171,22 @@ func (m *mockCloudFlareClient) CreateDNSRecord(ctx context.Context, params dns.R
return &record, nil
}

func (m *mockCloudFlareClient) ListDNSRecords(ctx context.Context, rc *cloudflarev0.ResourceContainer, rp cloudflarev0.ListDNSRecordsParams) ([]dns.RecordResponse, *cloudflarev0.ResultInfo, error) {
func (m *mockCloudFlareClient) ListDNSRecords(ctx context.Context, params dns.RecordListParams) autoPager[dns.RecordResponse] {
if m.dnsRecordsError != nil {
return nil, &cloudflarev0.ResultInfo{}, m.dnsRecordsError
return &mockAutoPager[dns.RecordResponse]{err: m.dnsRecordsError}
}
result := []dns.RecordResponse{}
if zone, ok := m.Records[rc.Identifier]; ok {
iter := &mockAutoPager[dns.RecordResponse]{}
if zone, ok := m.Records[params.ZoneID.Value]; ok {
for _, record := range zone {
if strings.HasPrefix(record.Name, "newerror-list-") {
m.DeleteDNSRecord(ctx, rc, record.ID)
return nil, &cloudflarev0.ResultInfo{}, errors.New("failed to list erroring DNS record")
m.DeleteDNSRecord(ctx, cloudflarev0.ResourceIdentifier(params.ZoneID.Value), record.ID)
iter.err = errors.New("failed to list erroring DNS record")
return iter
}
result = append(result, record)
iter.items = append(iter.items, record)
}
}

if len(result) == 0 || rp.PerPage == 0 {
return result, &cloudflarev0.ResultInfo{Page: 1, TotalPages: 1, Count: 0, Total: 0}, nil
}

// if not pagination options were passed in, return the result as is
if rp.Page == 0 {
return result, &cloudflarev0.ResultInfo{Page: 1, TotalPages: 1, Count: len(result), Total: len(result)}, nil
}

// otherwise, split the result into chunks of size rp.PerPage to simulate the pagination from the API
chunks := [][]dns.RecordResponse{}

// to ensure consistency in the multiple calls to this function, sort the result slice
sort.Slice(result, func(i, j int) bool { return strings.Compare(result[i].ID, result[j].ID) > 0 })
for rp.PerPage < len(result) {
result, chunks = result[rp.PerPage:], append(chunks, result[0:rp.PerPage])
}
chunks = append(chunks, result)

// return the requested page
partialResult := chunks[rp.Page-1]
return partialResult, &cloudflarev0.ResultInfo{
PerPage: rp.PerPage,
Page: rp.Page,
TotalPages: len(chunks),
Count: len(partialResult),
Total: len(result),
}, nil
return iter
}

func (m *mockCloudFlareClient) UpdateDNSRecord(ctx context.Context, rc *cloudflarev0.ResourceContainer, rp cloudflarev0.UpdateDNSRecordParams) error {
Expand Down Expand Up @@ -1500,7 +1473,7 @@ func TestGroupByNameAndTypeWithCustomHostnames_MX(t *testing.T) {
}
ctx := context.Background()
chs := CustomHostnamesMap{}
records, err := provider.listDNSRecordsWithAutoPagination(ctx, "001")
records, err := provider.getDNSRecordsMap(ctx, "001")
assert.NoError(t, err)

endpoints := provider.groupByNameAndTypeWithCustomHostnames(records, chs)
Expand Down Expand Up @@ -3346,3 +3319,22 @@ func TestDnsRecordFromLegacyAPI(t *testing.T) {
})
}
}

func TestZoneService(t *testing.T) {
t.Parallel()

ctx, cancel := context.WithCancel(t.Context())
cancel()

client := &zoneService{
service: cloudflare.NewClient(),
}

t.Run("UpdateDNSRecord", func(t *testing.T) {
t.Parallel()
iter := client.ListDNSRecords(ctx, dns.RecordListParams{ZoneID: cloudflare.F("foo")})
require.False(t, iter.Next())
require.Empty(t, iter.Current())
require.ErrorIs(t, iter.Err(), context.Canceled)
})
}
Loading