diff --git a/provider/cloudflare/cloudflare.go b/provider/cloudflare/cloudflare.go index 574212da4b..ca5597e5cf 100644 --- a/provider/cloudflare/cloudflare.go +++ b/provider/cloudflare/cloudflare.go @@ -780,29 +780,48 @@ func (p *CloudFlareProvider) newCloudFlareChange(action changeAction, ep *endpoi } priority := (*uint16)(nil) + var data map[string]interface{} + if ep.RecordType == "MX" { mxRecord, err := endpoint.NewMXRecord(target) if err != nil { return &cloudFlareChange{}, fmt.Errorf("failed to parse MX record target %q: %w", target, err) - } else { - priority = mxRecord.GetPriority() - target = *mxRecord.GetHost() + } + priority = mxRecord.GetPriority() + target = *mxRecord.GetHost() + } else if ep.RecordType == "SRV" { + parts := strings.Fields(target) + if len(parts) >= 4 { + priorityVal, _ := strconv.Atoi(parts[0]) + weight, _ := strconv.Atoi(parts[1]) + port, _ := strconv.Atoi(parts[2]) + targetHost := strings.Join(parts[3:], " ") + data = map[string]interface{}{ + "priority": priorityVal, + "weight": weight, + "port": port, + "target": targetHost, + } } } + record := cloudflare.DNSRecord{ + Name: ep.DNSName, + TTL: ttl, + Proxied: &proxied, + Type: ep.RecordType, + Content: target, + Comment: comment, + Priority: priority, + } + + if data != nil { + record.Data = data + } + return &cloudFlareChange{ - Action: action, - ResourceRecord: cloudflare.DNSRecord{ - Name: ep.DNSName, - TTL: ttl, - // We have to use pointers to bools now, as the upstream cloudflare-go library requires them - // see: https://github.com/cloudflare/cloudflare-go/pull/595 - Proxied: &proxied, - Type: ep.RecordType, - Content: target, - Comment: comment, - Priority: priority, - }, + Action: action, + ResourceRecord: record, RegionalHostname: p.regionalHostname(ep), CustomHostnamesPrev: prevCustomHostnames, CustomHostnames: newCustomHostnames, diff --git a/provider/cloudflare/cloudflare_regional_test.go b/provider/cloudflare/cloudflare_regional_test.go index fbbca0d907..a3fa6752fc 100644 --- a/provider/cloudflare/cloudflare_regional_test.go +++ b/provider/cloudflare/cloudflare_regional_test.go @@ -831,7 +831,6 @@ func TestRecordsWithListRegionalHostnameFaillure(t *testing.T) { } func TestApplyChangesWithRegionalHostnamesFaillures(t *testing.T) { - t.Parallel() type fields struct { Records map[string]cloudflare.DNSRecord RegionalHostnames []cloudflare.RegionalHostname @@ -1031,7 +1030,6 @@ func TestApplyChangesWithRegionalHostnamesFaillures(t *testing.T) { } func TestApplyChangesWithRegionalHostnamesDryRun(t *testing.T) { - t.Parallel() type fields struct { Records map[string]cloudflare.DNSRecord RegionalHostnames []cloudflare.RegionalHostname diff --git a/provider/cloudflare/cloudflare_test.go b/provider/cloudflare/cloudflare_test.go index b304d2116a..87bbaeaa54 100644 --- a/provider/cloudflare/cloudflare_test.go +++ b/provider/cloudflare/cloudflare_test.go @@ -23,6 +23,7 @@ import ( "os" "slices" "sort" + "strconv" "strings" "testing" @@ -134,6 +135,12 @@ func getDNSRecordFromRecordParams(rp any) cloudflare.DNSRecord { if params.Type == "MX" { record.Priority = params.Priority } + if params.Data != nil { + record.Data = params.Data + } + if params.Type == "SRV" && record.Data != nil { + record.Content = "" + } return record case cloudflare.UpdateDNSRecordParams: record := cloudflare.DNSRecord{ @@ -147,6 +154,12 @@ func getDNSRecordFromRecordParams(rp any) cloudflare.DNSRecord { if params.Type == "MX" { record.Priority = params.Priority } + if params.Data != nil { + record.Data = params.Data + } + if params.Type == "SRV" && record.Data != nil { + record.Content = "" + } return record default: return cloudflare.DNSRecord{} @@ -162,12 +175,37 @@ func (m *mockCloudFlareClient) CreateDNSRecord(ctx context.Context, rc *cloudfla if recordData.ID == "" { recordData.ID = generateDNSRecordID(recordData.Type, recordData.Name, recordData.Content) } - m.Actions = append(m.Actions, MockAction{ + + if recordData.Type == "SRV" { + if rp.Data != nil { + recordData.Data = rp.Data + recordData.Content = "" + } else if recordData.Data == nil && recordData.Content != "" { + parts := strings.Fields(recordData.Content) + if len(parts) >= 4 { + priority, _ := strconv.Atoi(parts[0]) + weight, _ := strconv.Atoi(parts[1]) + port, _ := strconv.Atoi(parts[2]) + target := strings.Join(parts[3:], " ") + recordData.Data = map[string]interface{}{ + "priority": priority, + "weight": weight, + "port": port, + "target": target, + } + recordData.Content = "" + } + } + } + + action := MockAction{ Name: "Create", ZoneId: rc.Identifier, RecordId: recordData.ID, RecordData: recordData, - }) + } + m.Actions = append(m.Actions, action) + if zone, ok := m.Records[rc.Identifier]; ok { zone[recordData.ID] = recordData } @@ -175,7 +213,7 @@ func (m *mockCloudFlareClient) CreateDNSRecord(ctx context.Context, rc *cloudfla if recordData.Name == "newerror.bar.com" { return cloudflare.DNSRecord{}, fmt.Errorf("failed to create record") } - return cloudflare.DNSRecord{}, nil + return recordData, nil } func (m *mockCloudFlareClient) ListDNSRecords(ctx context.Context, rc *cloudflare.ResourceContainer, rp cloudflare.ListDNSRecordsParams) ([]cloudflare.DNSRecord, *cloudflare.ResultInfo, error) { @@ -780,6 +818,9 @@ func TestCloudflareSetProxied(t *testing.T) { targets = endpoint.Targets{"10 mx.example.com"} content = "mx.example.com" priority = cloudflare.Uint16Ptr(10) + } else if testCase.recordType == "SRV" { + targets = endpoint.Targets{"10 5 8080 example.com"} + content = "10 5 8080 example.com" } else { targets = endpoint.Targets{"127.0.0.1"} content = "127.0.0.1" @@ -798,17 +839,30 @@ func TestCloudflareSetProxied(t *testing.T) { }, }, } - expectedID := fmt.Sprintf("%s-%s-%s", testCase.domain, testCase.recordType, content) + // Generate the expected ID based on the record type and content + expectedID := generateDNSRecordID(testCase.recordType, testCase.domain, content) + recordData := cloudflare.DNSRecord{ ID: expectedID, Type: testCase.recordType, Name: testCase.domain, - Content: content, TTL: 1, Proxied: testCase.proxiable, } + if testCase.recordType == "MX" { + recordData.Content = content recordData.Priority = priority + } else if testCase.recordType == "SRV" { + recordData.Data = map[string]interface{}{ + "priority": 10, + "weight": 5, + "port": 8080, + "target": "example.com", + } + recordData.Content = "" + } else { + recordData.Content = content } AssertActions(t, &CloudFlareProvider{}, endpoints, []MockAction{ { @@ -817,7 +871,7 @@ func TestCloudflareSetProxied(t *testing.T) { RecordId: expectedID, RecordData: recordData, }, - }, []string{endpoint.RecordTypeA, endpoint.RecordTypeCNAME, endpoint.RecordTypeNS, endpoint.RecordTypeMX}, testCase.recordType+" record on "+testCase.domain) + }, []string{endpoint.RecordTypeA, endpoint.RecordTypeCNAME, endpoint.RecordTypeNS, endpoint.RecordTypeMX, endpoint.RecordTypeSRV}, testCase.recordType+" record on "+testCase.domain) } } @@ -1226,11 +1280,6 @@ func TestCloudflareGetRecordID(t *testing.T) { } func TestCloudflareGroupByNameAndType(t *testing.T) { - provider := &CloudFlareProvider{ - Client: NewMockCloudFlareClient(), - domainFilter: endpoint.NewDomainFilter([]string{"bar.com"}), - zoneIDFilter: provider.NewZoneIDFilter([]string{""}), - } testCases := []struct { Name string Records []cloudflare.DNSRecord @@ -1465,6 +1514,7 @@ func TestCloudflareGroupByNameAndType(t *testing.T) { for _, r := range tc.Records { records[newDNSRecordIndex(r)] = r } + provider := &CloudFlareProvider{} endpoints := provider.groupByNameAndTypeWithCustomHostnames(records, CustomHostnamesMap{}) // Targets order could be random with underlying map for _, ep := range endpoints {