Skip to content

Commit

Permalink
move code, ipset-v6
Browse files Browse the repository at this point in the history
  • Loading branch information
szolin committed Sep 1, 2020
1 parent 77f4d94 commit bdbd732
Show file tree
Hide file tree
Showing 6 changed files with 177 additions and 114 deletions.
2 changes: 1 addition & 1 deletion AGHTechDoc.md
Original file line number Diff line number Diff line change
Expand Up @@ -1896,7 +1896,7 @@ Prepare: user creates an ipset list and configures AGH for using it.

Syntax:

ipset: "DOMAIN[,DOMAIN].../IPSET_NAME"
ipset: "DOMAIN[,DOMAIN].../IPSET1_NAME[,IPSET2_NAME]..."

Run-time: AGH adds IP addresses of a domain name to a corresponding ipset list.

Expand Down
5 changes: 2 additions & 3 deletions dnsforward/dnsforward.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,7 @@ type Server struct {
stats stats.Stats
access *accessCtx

ipsetList map[string]string // domain -> ipset_name
ipsetCache map[[4]byte]bool // cache for IP[] to prevent duplicate calls to ipset program
ipset ipsetCtx

tableHostToIP map[string]net.IP // "hostname -> IP" table for internal addresses (DHCP)
tableHostToIPLock sync.Mutex
Expand Down Expand Up @@ -193,7 +192,7 @@ func (s *Server) Prepare(config *ServerConfig) error {

// Initialize IPSET configuration
// --
s.initIPSET()
s.ipset.init(s.conf.IPSETList)

// Prepare DNS servers settings
// --
Expand Down
28 changes: 0 additions & 28 deletions dnsforward/dnsforward_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1060,31 +1060,3 @@ func TestPTRResponse(t *testing.T) {

s.Close()
}

func TestIPSET(t *testing.T) {
s := Server{}
s.conf.IPSETList = append(s.conf.IPSETList, "HOST.com/name")
s.conf.IPSETList = append(s.conf.IPSETList, "host2.com,host3.com/name23")
s.initIPSET()

assert.Equal(t, "name", s.ipsetList["host.com"])
assert.Equal(t, "name23", s.ipsetList["host2.com"])
assert.Equal(t, "name23", s.ipsetList["host3.com"])

_, ok := s.ipsetList["host4.com"]
assert.False(t, ok)

ctx := &dnsContext{
srv: &s,
}
ctx.proxyCtx = &proxy.DNSContext{}
ctx.proxyCtx.Req = &dns.Msg{
Question: []dns.Question{
{
Name: "host.com.",
Qtype: dns.TypeA,
},
},
}
assert.Equal(t, resultDone, processIPSEC(ctx))
}
83 changes: 1 addition & 82 deletions dnsforward/handle_dns.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ func (s *Server) handleDNSRequest(_ *proxy.Proxy, d *proxy.DNSContext) error {
processUpstream,
processDNSSECAfterResponse,
processFilteringAfterResponse,
processIPSEC,
s.ipset.process,
processQueryLogsAndStats,
}
for _, process := range mods {
Expand Down Expand Up @@ -404,84 +404,3 @@ func processFilteringAfterResponse(ctx *dnsContext) int {

return resultDone
}

// Convert configuration settings to an internal map
// DOMAIN[,DOMAIN].../IPSET_NAME
func (s *Server) initIPSET() {
s.ipsetList = make(map[string]string)
s.ipsetCache = make(map[[4]byte]bool)

nSets := 0
for _, it := range s.conf.IPSETList {
it = strings.TrimSpace(it)
hostsAndName := strings.Split(it, "/")
if len(hostsAndName) != 2 {
log.Debug("IPSET: invalid value '%s'", it)
continue
}
ipsetName := strings.TrimSpace(hostsAndName[1])
if ipsetName == "" {
log.Debug("IPSET: invalid value '%s'", it)
continue
}
nSets++
hosts := strings.Split(hostsAndName[0], ",")
for _, host := range hosts {
host = strings.TrimSpace(host)
host = strings.ToLower(host)
if host == "" {
log.Debug("IPSET: invalid value '%s'", it)
continue
}
s.ipsetList[host] = ipsetName
}
}
log.Debug("IPSET: added %d hosts; ipsets:%d", len(s.ipsetList), nSets)
}

func processIPSEC(ctx *dnsContext) int {
s := ctx.srv
req := ctx.proxyCtx.Req
if req.Question[0].Qtype != dns.TypeA ||
!ctx.responseFromUpstream {
return resultDone
}

host := req.Question[0].Name
host = strings.TrimSuffix(host, ".")
host = strings.ToLower(host)
ipsetName, found := s.ipsetList[host]
if !found {
return resultDone
}

log.Debug("IPSET: found ipset %s for host %s", ipsetName, host)

for _, it := range ctx.proxyCtx.Res.Answer {
a, ok := it.(*dns.A)
if !ok {
continue
}

var ip4 [4]byte
copy(ip4[:], a.A.To4())
_, found := s.ipsetCache[ip4]
if found {
continue // this IP was added before
}
s.ipsetCache[ip4] = false

code, out, err := util.RunCommand("ipset", "add", ipsetName, a.A.String())
if err != nil {
log.Info("%s", err)
return resultDone
}
if code != 0 {
log.Info("IPSET: ipset add: code:%d output:'%s'", code, out)
return resultDone
}
log.Debug("IPSET: added %s(%s) -> %s", host, a.A.String(), ipsetName)
}

return resultDone
}
132 changes: 132 additions & 0 deletions dnsforward/ipset.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
package dnsforward

import (
"net"
"strings"

"github.com/AdguardTeam/AdGuardHome/util"
"github.com/AdguardTeam/golibs/log"
"github.com/miekg/dns"
)

type ipsetCtx struct {
ipsetList map[string][]string // domain -> []ipset_name
ipsetCache map[[4]byte]bool // cache for IP[] to prevent duplicate calls to ipset program
ipset6Cache map[[16]byte]bool // cache for IP[] to prevent duplicate calls to ipset program
}

// Convert configuration settings to an internal map
// DOMAIN[,DOMAIN].../IPSET1_NAME[,IPSET2_NAME]...
func (c *ipsetCtx) init(ipsetConfig []string) {
c.ipsetList = make(map[string][]string)
c.ipsetCache = make(map[[4]byte]bool)

for _, it := range ipsetConfig {
it = strings.TrimSpace(it)
hostsAndNames := strings.Split(it, "/")
if len(hostsAndNames) != 2 {
log.Debug("IPSET: invalid value '%s'", it)
continue
}

ipsetNames := strings.Split(hostsAndNames[1], ",")
if len(ipsetNames) == 0 {
log.Debug("IPSET: invalid value '%s'", it)
continue
}
bad := false
for i := range ipsetNames {
ipsetNames[i] = strings.TrimSpace(ipsetNames[i])
if len(ipsetNames[i]) == 0 {
bad = true
break
}
}
if bad {
log.Debug("IPSET: invalid value '%s'", it)
continue
}

hosts := strings.Split(hostsAndNames[0], ",")
for _, host := range hosts {
host = strings.TrimSpace(host)
host = strings.ToLower(host)
if len(host) == 0 {
log.Debug("IPSET: invalid value '%s'", it)
continue
}
c.ipsetList[host] = ipsetNames
}
}
log.Debug("IPSET: added %d hosts", len(c.ipsetList))
}

func (c *ipsetCtx) getIP(rr dns.RR) net.IP {
switch a := rr.(type) {
case *dns.A:
var ip4 [4]byte
copy(ip4[:], a.A.To4())
_, found := c.ipsetCache[ip4]
if found {
return nil // this IP was added before
}
c.ipsetCache[ip4] = false
return a.A

case *dns.AAAA:
var ip6 [16]byte
copy(ip6[:], a.AAAA)
_, found := c.ipset6Cache[ip6]
if found {
return nil // this IP was added before
}
c.ipset6Cache[ip6] = false
return a.AAAA

default:
return nil
}
}

// Add IP addresses of the specified in configuration domain names to an ipset list
func (c *ipsetCtx) process(ctx *dnsContext) int {
req := ctx.proxyCtx.Req
if !(req.Question[0].Qtype == dns.TypeA ||
req.Question[0].Qtype == dns.TypeAAAA) ||
!ctx.responseFromUpstream {
return resultDone
}

host := req.Question[0].Name
host = strings.TrimSuffix(host, ".")
host = strings.ToLower(host)
ipsetNames, found := c.ipsetList[host]
if !found {
return resultDone
}

log.Debug("IPSET: found ipsets %v for host %s", ipsetNames, host)

for _, it := range ctx.proxyCtx.Res.Answer {
ip := c.getIP(it)
if ip == nil {
continue
}

ipStr := ip.String()
for _, name := range ipsetNames {
code, out, err := util.RunCommand("ipset", "add", name, ipStr)
if err != nil {
log.Info("%s", err)
return resultDone
}
if code != 0 {
log.Info("IPSET: ipset add: code:%d output:'%s'", code, out)
continue
}
log.Debug("IPSET: added %s(%s) -> %s", host, ipStr, name)
}
}

return resultDone
}
41 changes: 41 additions & 0 deletions dnsforward/ipset_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
package dnsforward

import (
"testing"

"github.com/AdguardTeam/dnsproxy/proxy"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
)

func TestIPSET(t *testing.T) {
s := Server{}
s.conf.IPSETList = append(s.conf.IPSETList, "HOST.com/name")
s.conf.IPSETList = append(s.conf.IPSETList, "host2.com,host3.com/name23")
s.conf.IPSETList = append(s.conf.IPSETList, "host4.com/name4,name41")
c := ipsetCtx{}
c.init(s.conf.IPSETList)

assert.Equal(t, "name", c.ipsetList["host.com"][0])
assert.Equal(t, "name23", c.ipsetList["host2.com"][0])
assert.Equal(t, "name23", c.ipsetList["host3.com"][0])
assert.Equal(t, "name4", c.ipsetList["host4.com"][0])
assert.Equal(t, "name41", c.ipsetList["host4.com"][1])

_, ok := c.ipsetList["host0.com"]
assert.False(t, ok)

ctx := &dnsContext{
srv: &s,
}
ctx.proxyCtx = &proxy.DNSContext{}
ctx.proxyCtx.Req = &dns.Msg{
Question: []dns.Question{
{
Name: "host.com.",
Qtype: dns.TypeA,
},
},
}
assert.Equal(t, resultDone, c.process(ctx))
}

0 comments on commit bdbd732

Please sign in to comment.