Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow a custom inserter function for Tree.Insert. #24

Merged
merged 1 commit into from
Aug 10, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 12 additions & 4 deletions inserter/inserter.go
Original file line number Diff line number Diff line change
@@ -1,20 +1,28 @@
// Package inserter provides some common inserter functions for
// mmdbwriter.InsertFunc.
// mmdbwriter.Tree.
package inserter

import (
"github.com/maxmind/mmdbwriter/mmdbtype"
"github.com/pkg/errors"
)

// InserterFunc is a function that returns the data type to be inserted into an
// mmdbwriter.Tree using some conflict resolution strategy.
type InserterFunc func(mmdbtype.DataType) (mmdbtype.DataType, error)

// InserterFuncGenerator is a function that generates an InserterFunc given a
// value.
type InserterFuncGenerator func(value mmdbtype.DataType) InserterFunc

// Remove any records for the network being inserted.
func Remove(value mmdbtype.DataType) (mmdbtype.DataType, error) {
return nil, nil
}

// ReplaceWith generates an inserter function that replaces the existing
// value with the new value.
func ReplaceWith(value mmdbtype.DataType) func(mmdbtype.DataType) (mmdbtype.DataType, error) {
func ReplaceWith(value mmdbtype.DataType) InserterFunc {
return func(_ mmdbtype.DataType) (mmdbtype.DataType, error) {
return value, nil
}
Expand All @@ -26,7 +34,7 @@ func ReplaceWith(value mmdbtype.DataType) func(mmdbtype.DataType) (mmdbtype.Data
//
// Both the new and existing value must be a Map. An error will be returned
// otherwise.
func TopLevelMergeWith(newValue mmdbtype.DataType) func(mmdbtype.DataType) (mmdbtype.DataType, error) {
func TopLevelMergeWith(newValue mmdbtype.DataType) InserterFunc {
return func(existingValue mmdbtype.DataType) (mmdbtype.DataType, error) {
newMap, ok := newValue.(mmdbtype.Map)
if !ok {
Expand Down Expand Up @@ -63,7 +71,7 @@ func TopLevelMergeWith(newValue mmdbtype.DataType) func(mmdbtype.DataType) (mmdb
// DeepMergeWith creates an inserter that will recursively update an existing
// value. Map and Slice values will be merged recursively. Other values will
// be replaced by the new value.
func DeepMergeWith(newValue mmdbtype.DataType) func(mmdbtype.DataType) (mmdbtype.DataType, error) {
func DeepMergeWith(newValue mmdbtype.DataType) InserterFunc {
return func(existingValue mmdbtype.DataType) (mmdbtype.DataType, error) {
return deepMerge(existingValue, newValue)
}
Expand Down
28 changes: 20 additions & 8 deletions tree.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,11 @@ type Options struct {
// implementations that do not correctly handle metadata pointers. Its
// use should primarily be limited to existing database types.
DisableMetadataPointers bool

// Inserter is the insert function used when calling `Insert`. It defaults
// to `inserter.ReplaceWith`, which replaces any conflicting old value
// entirely with the new.
Inserter inserter.InserterFuncGenerator
}

// Tree represents an MaxMind DB search tree.
Expand All @@ -87,7 +92,8 @@ type Tree struct {
root *node
treeDepth int
// This is set when the tree is finalized
nodeCount int
nodeCount int
inserterFuncGen inserter.InserterFuncGenerator
}

// New creates a new Tree.
Expand All @@ -101,6 +107,7 @@ func New(opts Options) (*Tree, error) {
ipVersion: 6,
recordSize: 28,
root: &node{},
inserterFuncGen: inserter.ReplaceWith,
}

if opts.BuildEpoch != 0 {
Expand All @@ -123,6 +130,10 @@ func New(opts Options) (*Tree, error) {
tree.recordSize = opts.RecordSize
}

if opts.Inserter != nil {
tree.inserterFuncGen = opts.Inserter
}

switch tree.ipVersion {
case 6:
tree.treeDepth = 128
Expand Down Expand Up @@ -210,17 +221,18 @@ func Load(path string, opts Options) (*Tree, error) {
return tree, nil
}

// Insert a data value into the tree.
// Insert a data value into the tree using the Tree's inserter function
// (defaults to inserter.ReplaceWith).
//
// This is not safe to call from multiple threads.
func (t *Tree) Insert(network *net.IPNet, value mmdbtype.DataType) error {
return t.InsertFunc(network, inserter.ReplaceWith(value))
return t.InsertFunc(network, t.inserterFuncGen(value))
}

// InsertFunc will insert the output of the function passed to it. The argument
// passed to the function is the existing value in the record. The function
// should return the mmdbtype.DataType to be inserted. In both cases, a nil value means
// an empty record.
// passed to the function is the existing value in the record. The inserter
// function should return the mmdbtype.DataType to be inserted. In both cases,
// a nil value means an empty record.
//
// You must never modify the argument passed to the function as the value may
// be shared with other records. If you want a copy of the mmdbtype.DataType to modify,
Expand All @@ -234,15 +246,15 @@ func (t *Tree) Insert(network *net.IPNet, value mmdbtype.DataType) error {
// This is not safe to call from multiple threads.
func (t *Tree) InsertFunc(
network *net.IPNet,
inserter func(value mmdbtype.DataType) (mmdbtype.DataType, error),
inserter inserter.InserterFunc,
) error {
return t.insert(network, recordTypeData, inserter, nil)
}

func (t *Tree) insert(
network *net.IPNet,
recordType recordType,
inserter func(value mmdbtype.DataType) (mmdbtype.DataType, error),
inserter inserter.InserterFunc,
node *node,
) error {
// We set this to 0 so that the tree must be finalized again.
Expand Down