Skip to content

Commit 33e6400

Browse files
authored
Support TsigProvider for Server and Transfer (#1331)
Automatically submitted.
1 parent 51afb90 commit 33e6400

File tree

5 files changed

+126
-49
lines changed

5 files changed

+126
-49
lines changed

client.go

+12-20
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,14 @@ type Conn struct {
3939
tsigRequestMAC string
4040
}
4141

42+
func (co *Conn) tsigProvider() TsigProvider {
43+
if co.TsigProvider != nil {
44+
return co.TsigProvider
45+
}
46+
// tsigSecretProvider will return ErrSecret if co.TsigSecret is nil.
47+
return tsigSecretProvider(co.TsigSecret)
48+
}
49+
4250
// A Client defines parameters for a DNS client.
4351
type Client struct {
4452
Net string // if "tcp" or "tcp-tls" (DNS over TLS) a TCP query will be initiated, otherwise an UDP one (default is "" for UDP)
@@ -271,15 +279,8 @@ func (co *Conn) ReadMsg() (*Msg, error) {
271279
return m, err
272280
}
273281
if t := m.IsTsig(); t != nil {
274-
if co.TsigProvider != nil {
275-
err = tsigVerifyProvider(p, co.TsigProvider, co.tsigRequestMAC, false)
276-
} else {
277-
if _, ok := co.TsigSecret[t.Hdr.Name]; !ok {
278-
return m, ErrSecret
279-
}
280-
// Need to work on the original message p, as that was used to calculate the tsig.
281-
err = TsigVerify(p, co.TsigSecret[t.Hdr.Name], co.tsigRequestMAC, false)
282-
}
282+
// Need to work on the original message p, as that was used to calculate the tsig.
283+
err = tsigVerifyProvider(p, co.tsigProvider(), co.tsigRequestMAC, false)
283284
}
284285
return m, err
285286
}
@@ -356,17 +357,8 @@ func (co *Conn) Read(p []byte) (n int, err error) {
356357
func (co *Conn) WriteMsg(m *Msg) (err error) {
357358
var out []byte
358359
if t := m.IsTsig(); t != nil {
359-
mac := ""
360-
if co.TsigProvider != nil {
361-
out, mac, err = tsigGenerateProvider(m, co.TsigProvider, co.tsigRequestMAC, false)
362-
} else {
363-
if _, ok := co.TsigSecret[t.Hdr.Name]; !ok {
364-
return ErrSecret
365-
}
366-
out, mac, err = TsigGenerate(m, co.TsigSecret[t.Hdr.Name], co.tsigRequestMAC, false)
367-
}
368-
// Set for the next read, although only used in zone transfers
369-
co.tsigRequestMAC = mac
360+
// Set tsigRequestMAC for the next read, although only used in zone transfers.
361+
out, co.tsigRequestMAC, err = tsigGenerateProvider(m, co.tsigProvider(), co.tsigRequestMAC, false)
370362
} else {
371363
out, err = m.Pack()
372364
}

server.go

+25-17
Original file line numberDiff line numberDiff line change
@@ -71,12 +71,12 @@ type response struct {
7171
tsigTimersOnly bool
7272
tsigStatus error
7373
tsigRequestMAC string
74-
tsigSecret map[string]string // the tsig secrets
75-
udp net.PacketConn // i/o connection if UDP was used
76-
tcp net.Conn // i/o connection if TCP was used
77-
udpSession *SessionUDP // oob data to get egress interface right
78-
pcSession net.Addr // address to use when writing to a generic net.PacketConn
79-
writer Writer // writer to output the raw DNS bits
74+
tsigProvider TsigProvider
75+
udp net.PacketConn // i/o connection if UDP was used
76+
tcp net.Conn // i/o connection if TCP was used
77+
udpSession *SessionUDP // oob data to get egress interface right
78+
pcSession net.Addr // address to use when writing to a generic net.PacketConn
79+
writer Writer // writer to output the raw DNS bits
8080
}
8181

8282
// handleRefused returns a HandlerFunc that returns REFUSED for every request it gets.
@@ -211,6 +211,8 @@ type Server struct {
211211
WriteTimeout time.Duration
212212
// TCP idle timeout for multiple queries, if nil, defaults to 8 * time.Second (RFC 5966).
213213
IdleTimeout func() time.Duration
214+
// An implementation of the TsigProvider interface. If defined it replaces TsigSecret and is used for all TSIG operations.
215+
TsigProvider TsigProvider
214216
// Secret(s) for Tsig map[<zonename>]<base64 secret>. The zonename must be in canonical form (lowercase, fqdn, see RFC 4034 Section 6.2).
215217
TsigSecret map[string]string
216218
// If NotifyStartedFunc is set it is called once the server has started listening.
@@ -238,6 +240,16 @@ type Server struct {
238240
udpPool sync.Pool
239241
}
240242

243+
func (srv *Server) tsigProvider() TsigProvider {
244+
if srv.TsigProvider != nil {
245+
return srv.TsigProvider
246+
}
247+
if srv.TsigSecret != nil {
248+
return tsigSecretProvider(srv.TsigSecret)
249+
}
250+
return nil
251+
}
252+
241253
func (srv *Server) isStarted() bool {
242254
srv.lock.RLock()
243255
started := srv.started
@@ -526,7 +538,7 @@ func (srv *Server) serveUDP(l net.PacketConn) error {
526538

527539
// Serve a new TCP connection.
528540
func (srv *Server) serveTCPConn(wg *sync.WaitGroup, rw net.Conn) {
529-
w := &response{tsigSecret: srv.TsigSecret, tcp: rw}
541+
w := &response{tsigProvider: srv.tsigProvider(), tcp: rw}
530542
if srv.DecorateWriter != nil {
531543
w.writer = srv.DecorateWriter(w)
532544
} else {
@@ -581,7 +593,7 @@ func (srv *Server) serveTCPConn(wg *sync.WaitGroup, rw net.Conn) {
581593

582594
// Serve a new UDP request.
583595
func (srv *Server) serveUDPPacket(wg *sync.WaitGroup, m []byte, u net.PacketConn, udpSession *SessionUDP, pcSession net.Addr) {
584-
w := &response{tsigSecret: srv.TsigSecret, udp: u, udpSession: udpSession, pcSession: pcSession}
596+
w := &response{tsigProvider: srv.tsigProvider(), udp: u, udpSession: udpSession, pcSession: pcSession}
585597
if srv.DecorateWriter != nil {
586598
w.writer = srv.DecorateWriter(w)
587599
} else {
@@ -632,15 +644,11 @@ func (srv *Server) serveDNS(m []byte, w *response) {
632644
}
633645

634646
w.tsigStatus = nil
635-
if w.tsigSecret != nil {
647+
if w.tsigProvider != nil {
636648
if t := req.IsTsig(); t != nil {
637-
if secret, ok := w.tsigSecret[t.Hdr.Name]; ok {
638-
w.tsigStatus = TsigVerify(m, secret, "", false)
639-
} else {
640-
w.tsigStatus = ErrSecret
641-
}
649+
w.tsigStatus = tsigVerifyProvider(m, w.tsigProvider, "", false)
642650
w.tsigTimersOnly = false
643-
w.tsigRequestMAC = req.Extra[len(req.Extra)-1].(*TSIG).MAC
651+
w.tsigRequestMAC = t.MAC
644652
}
645653
}
646654

@@ -718,9 +726,9 @@ func (w *response) WriteMsg(m *Msg) (err error) {
718726
}
719727

720728
var data []byte
721-
if w.tsigSecret != nil { // if no secrets, dont check for the tsig (which is a longer check)
729+
if w.tsigProvider != nil { // if no provider, dont check for the tsig (which is a longer check)
722730
if t := m.IsTsig(); t != nil {
723-
data, w.tsigRequestMAC, err = TsigGenerate(m, w.tsigSecret[t.Hdr.Name], w.tsigRequestMAC, w.tsigTimersOnly)
731+
data, w.tsigRequestMAC, err = tsigGenerateProvider(m, w.tsigProvider, w.tsigRequestMAC, w.tsigTimersOnly)
724732
if err != nil {
725733
return err
726734
}

tsig.go

+18
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,24 @@ func (key tsigHMACProvider) Verify(msg []byte, t *TSIG) error {
7474
return nil
7575
}
7676

77+
type tsigSecretProvider map[string]string
78+
79+
func (ts tsigSecretProvider) Generate(msg []byte, t *TSIG) ([]byte, error) {
80+
key, ok := ts[t.Hdr.Name]
81+
if !ok {
82+
return nil, ErrSecret
83+
}
84+
return tsigHMACProvider(key).Generate(msg, t)
85+
}
86+
87+
func (ts tsigSecretProvider) Verify(msg []byte, t *TSIG) error {
88+
key, ok := ts[t.Hdr.Name]
89+
if !ok {
90+
return ErrSecret
91+
}
92+
return tsigHMACProvider(key).Verify(msg, t)
93+
}
94+
7795
// TSIG is the RR the holds the transaction signature of a message.
7896
// See RFC 2845 and RFC 4635.
7997
type TSIG struct {

xfr.go

+16-11
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,22 @@ type Transfer struct {
1717
DialTimeout time.Duration // net.DialTimeout, defaults to 2 seconds
1818
ReadTimeout time.Duration // net.Conn.SetReadTimeout value for connections, defaults to 2 seconds
1919
WriteTimeout time.Duration // net.Conn.SetWriteTimeout value for connections, defaults to 2 seconds
20+
TsigProvider TsigProvider // An implementation of the TsigProvider interface. If defined it replaces TsigSecret and is used for all TSIG operations.
2021
TsigSecret map[string]string // Secret(s) for Tsig map[<zonename>]<base64 secret>, zonename must be in canonical form (lowercase, fqdn, see RFC 4034 Section 6.2)
2122
tsigTimersOnly bool
2223
}
2324

24-
// Think we need to away to stop the transfer
25+
func (t *Transfer) tsigProvider() TsigProvider {
26+
if t.TsigProvider != nil {
27+
return t.TsigProvider
28+
}
29+
if t.TsigSecret != nil {
30+
return tsigSecretProvider(t.TsigSecret)
31+
}
32+
return nil
33+
}
34+
35+
// TODO: Think we need to away to stop the transfer
2536

2637
// In performs an incoming transfer with the server in a.
2738
// If you would like to set the source IP, or some other attribute
@@ -224,12 +235,9 @@ func (t *Transfer) ReadMsg() (*Msg, error) {
224235
if err := m.Unpack(p); err != nil {
225236
return nil, err
226237
}
227-
if ts := m.IsTsig(); ts != nil && t.TsigSecret != nil {
228-
if _, ok := t.TsigSecret[ts.Hdr.Name]; !ok {
229-
return m, ErrSecret
230-
}
238+
if ts, tp := m.IsTsig(), t.tsigProvider(); ts != nil && tp != nil {
231239
// Need to work on the original message p, as that was used to calculate the tsig.
232-
err = TsigVerify(p, t.TsigSecret[ts.Hdr.Name], t.tsigRequestMAC, t.tsigTimersOnly)
240+
err = tsigVerifyProvider(p, tp, t.tsigRequestMAC, t.tsigTimersOnly)
233241
t.tsigRequestMAC = ts.MAC
234242
}
235243
return m, err
@@ -238,11 +246,8 @@ func (t *Transfer) ReadMsg() (*Msg, error) {
238246
// WriteMsg writes a message through the transfer connection t.
239247
func (t *Transfer) WriteMsg(m *Msg) (err error) {
240248
var out []byte
241-
if ts := m.IsTsig(); ts != nil && t.TsigSecret != nil {
242-
if _, ok := t.TsigSecret[ts.Hdr.Name]; !ok {
243-
return ErrSecret
244-
}
245-
out, t.tsigRequestMAC, err = TsigGenerate(m, t.TsigSecret[ts.Hdr.Name], t.tsigRequestMAC, t.tsigTimersOnly)
249+
if ts, tp := m.IsTsig(), t.tsigProvider(); ts != nil && tp != nil {
250+
out, t.tsigRequestMAC, err = tsigGenerateProvider(m, tp, t.tsigRequestMAC, t.tsigTimersOnly)
246251
} else {
247252
out, err = m.Pack()
248253
}

xfr_test.go

+55-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
package dns
22

3-
import "testing"
3+
import (
4+
"testing"
5+
"time"
6+
)
47

58
var (
69
tsigSecret = map[string]string{"axfr.": "so6ZGir4GPAqINNh9U5c3A=="}
@@ -127,3 +130,54 @@ func axfrTestingSuite(t *testing.T, addrstr string) {
127130
}
128131
}
129132
}
133+
134+
func axfrTestingSuiteWithCustomTsig(t *testing.T, addrstr string, provider TsigProvider) {
135+
tr := new(Transfer)
136+
m := new(Msg)
137+
var err error
138+
tr.Conn, err = Dial("tcp", addrstr)
139+
if err != nil {
140+
t.Fatal("failed to dial", err)
141+
}
142+
tr.TsigProvider = provider
143+
m.SetAxfr("miek.nl.")
144+
m.SetTsig("axfr.", HmacSHA256, 300, time.Now().Unix())
145+
146+
c, err := tr.In(m, addrstr)
147+
if err != nil {
148+
t.Fatal("failed to zone transfer in", err)
149+
}
150+
151+
var records []RR
152+
for msg := range c {
153+
if msg.Error != nil {
154+
t.Fatal(msg.Error)
155+
}
156+
records = append(records, msg.RR...)
157+
}
158+
159+
if len(records) != len(xfrTestData) {
160+
t.Fatalf("bad axfr: expected %v, got %v", records, xfrTestData)
161+
}
162+
163+
for i, rr := range records {
164+
if !IsDuplicate(rr, xfrTestData[i]) {
165+
t.Errorf("bad axfr: expected %v, got %v", records, xfrTestData)
166+
}
167+
}
168+
}
169+
170+
func TestCustomTsigProvider(t *testing.T) {
171+
HandleFunc("miek.nl.", SingleEnvelopeXfrServer)
172+
defer HandleRemove("miek.nl.")
173+
174+
s, addrstr, _, err := RunLocalTCPServer(":0", func(srv *Server) {
175+
srv.TsigProvider = tsigSecretProvider(tsigSecret)
176+
})
177+
if err != nil {
178+
t.Fatalf("unable to run test server: %s", err)
179+
}
180+
defer s.Shutdown()
181+
182+
axfrTestingSuiteWithCustomTsig(t, addrstr, tsigSecretProvider(tsigSecret))
183+
}

0 commit comments

Comments
 (0)