Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 43 additions & 0 deletions trie.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,17 @@ func (pt *Trie) CoveredNetworks(network netip.Prefix) []netip.Prefix {
return pt.coveredNetworks(network)
}

// WalkFunc is the type of the function called for each network visited by Walk methods.
type WalkFunc func(network netip.Prefix, value any) error

// CoveredNetworksWalk walks networks contained within the given network, calling walkFn.
//
// Note: Inserted addresses are normalized to IPv6, so the returned list will be IPv6 only.
func (pt *Trie) CoveredNetworksWalk(network netip.Prefix, walkFn WalkFunc) error {
network = normalizePrefix(network)
return pt.coveredNetworksWalk(network, walkFn)
}

// String returns string representation of trie.
//
// The result will contain implicit nodes which exist as parents for multiple entries, but can be distinguished by the
Expand Down Expand Up @@ -197,6 +208,19 @@ func (pt *Trie) coveredNetworks(network netip.Prefix) []netip.Prefix {
return results
}

func (pt *Trie) coveredNetworksWalk(network netip.Prefix, walkFn WalkFunc) error {
if network.Bits() <= pt.network.Bits() && network.Contains(pt.network.Addr()) {
return pt.walkDepthFunc(walkFn)
} else if pt.network.Bits() < 128 {
bit := pt.discriminatorBitFromIP(network.Addr())
child := pt.children[bit]
if child != nil {
return child.coveredNetworksWalk(network, walkFn)
}
}
return nil
}

// This is an unsafe, but faster version of netip.Prefix.Contains
func netContains(pfx netip.Prefix, ip netip.Addr) bool {
pfxAddr := addr128(pfx.Addr())
Expand Down Expand Up @@ -378,6 +402,24 @@ func (pt *Trie) walkDepth() <-chan netip.Prefix {
return entries
}

// walkDepthFunc walks the trie in depth order, calling walkFn for each network.
func (pt *Trie) walkDepthFunc(walkFn WalkFunc) error {
if pt.value != nil {
if err := walkFn(pt.network, pt.value); err != nil {
return err
}
}
for _, trie := range pt.children {
if trie == nil {
continue
}
if err := trie.walkDepthFunc(walkFn); err != nil {
return err
}
}
return nil
}

// TrieLoader can be used to improve the performance of bulk inserts to a Trie. It caches the node of the
// last insert in the tree, using it as the starting point to start searching for the location of the next insert. This
// is highly beneficial when the addresses are pre-sorted.
Expand Down Expand Up @@ -452,6 +494,7 @@ func unempty(v any) any {
func addr128(addr netip.Addr) uint128 {
return *(*uint128)(unsafe.Pointer(&addr))
}

func init() {
// Accessing the underlying data of a `netip.Addr` relies upon the data being
// in a known format, which is not guaranteed to be stable. So this init()
Expand Down
74 changes: 74 additions & 0 deletions trie_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package iptrie

import (
"encoding/binary"
"errors"
"fmt"
"math/rand"
"net/netip"
Expand Down Expand Up @@ -405,6 +406,9 @@ type coveredNetworkTest struct {
inserts []string
search string
networks []string
walk []string
stopWalk string
error bool
name string
}

Expand All @@ -413,36 +417,54 @@ var coveredNetworkTests = []coveredNetworkTest{
[]string{"192.168.0.0/24"},
"192.168.0.0/16",
[]string{"192.168.0.0/24"},
[]string{"192.168.0.0/24"},
"",
false,
"basic covered networks",
},
{
[]string{"192.168.0.0/24"},
"10.1.0.0/16",
nil,
nil,
"",
false,
"nothing",
},
{
[]string{"192.168.0.0/24", "192.168.0.0/25"},
"192.168.0.0/16",
[]string{"192.168.0.0/24", "192.168.0.0/25"},
[]string{"192.168.0.0/24", "192.168.0.0/25"},
"192.168.1.0/25",
false,
"multiple networks",
},
{
[]string{"192.168.0.0/24", "192.168.0.0/25", "192.168.0.1/32"},
"192.168.0.0/16",
[]string{"192.168.0.0/24", "192.168.0.0/25", "192.168.0.1/32"},
[]string{"192.168.0.0/24", "192.168.0.0/25"},
"192.168.0.0/25",
true,
"multiple networks 2",
},
{
[]string{"192.168.1.1/32"},
"192.168.0.0/16",
[]string{"192.168.1.1/32"},
[]string{"192.168.1.1/32"},
"",
false,
"leaf",
},
{
[]string{"0.0.0.0/0", "192.168.1.1/32"},
"192.168.0.0/16",
[]string{"192.168.1.1/32"},
[]string{"192.168.1.1/32"},
"",
false,
"leaf with root",
},
{
Expand All @@ -452,14 +474,32 @@ var coveredNetworkTests = []coveredNetworkTest{
},
"192.168.0.0/16",
[]string{"192.168.0.0/24", "192.168.1.1/32"},
[]string{"192.168.0.0/24", "192.168.1.1/32"},
"10.1.0.0/16",
false,
"path not taken",
},
{
[]string{
"0.0.0.0/0", "192.168.0.0/24", "192.168.1.1/32",
"10.1.0.0/16", "10.1.1.0/24", "192.168.2.2/32",
},
"192.168.0.0/16",
[]string{"192.168.0.0/24", "192.168.1.1/32", "192.168.2.2/32"},
[]string{"192.168.0.0/24", "192.168.1.1/32"},
"192.168.1.1/32",
true,
"path not taken and stopped",
},
{
[]string{
"192.168.0.0/15",
},
"192.168.0.0/16",
nil,
nil,
"",
false,
"only masks different",
},
}
Expand All @@ -485,6 +525,40 @@ func TestTrieCoveredNetworks(t *testing.T) {
}
}

func TestTrieCoveredNetworksWalk(t *testing.T) {
for _, tc := range coveredNetworkTests {
t.Run(tc.name, func(t *testing.T) {
trie := NewTrie()
for _, insert := range tc.inserts {
network := netip.MustParsePrefix(insert)
v := any(insert)
trie.Insert(network, v)
}
var expectedEntries []netip.Prefix
for _, network := range tc.walk {
expected := normalizePrefix(netip.MustParsePrefix(network))
expectedEntries = append(expectedEntries, expected)
}
snet := netip.MustParsePrefix(tc.search)
var networks []netip.Prefix
walkFn := func(network netip.Prefix, v any) error {
networks = append(networks, network)
if stopWalk := v.(string); stopWalk == tc.stopWalk {
return errors.New(stopWalk)
}
return nil
}
err := trie.CoveredNetworksWalk(snet, walkFn)
assert.Equal(t, expectedEntries, networks)
if tc.error {
assert.EqualError(t, err, tc.stopWalk)
} else {
assert.Nil(t, err)
}
})
}
}

func TestTrieMemUsage(t *testing.T) {
if testing.Short() {
t.Skip("Skipping memory test in `-short` mode")
Expand Down