Skip to content

Commit

Permalink
feat: support node sort by tcp ping latency
Browse files Browse the repository at this point in the history
  • Loading branch information
tiancheng91 authored and ginuerzh committed Jun 13, 2024
1 parent fd57e80 commit 2faecc1
Show file tree
Hide file tree
Showing 4 changed files with 127 additions and 8 deletions.
21 changes: 13 additions & 8 deletions cmd/gost/peer.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,16 @@ import (
)

type peerConfig struct {
Strategy string `json:"strategy"`
MaxFails int `json:"max_fails"`
FailTimeout time.Duration
period time.Duration // the period for live reloading
Nodes []string `json:"nodes"`
group *gost.NodeGroup
baseNodes []gost.Node
stopped chan struct{}
Strategy string `json:"strategy"`
MaxFails int `json:"max_fails"`
FastestCount int `json:"fastest_count"` // topN fastest node count
FailTimeout time.Duration
period time.Duration // the period for live reloading

Nodes []string `json:"nodes"`
group *gost.NodeGroup
baseNodes []gost.Node
stopped chan struct{}
}

func newPeerConfig() *peerConfig {
Expand Down Expand Up @@ -51,6 +53,7 @@ func (cfg *peerConfig) Reload(r io.Reader) error {
FailTimeout: cfg.FailTimeout,
},
&gost.InvalidFilter{},
gost.NewFastestFilter(0, cfg.FastestCount),
),
gost.WithStrategy(gost.NewStrategy(cfg.Strategy)),
)
Expand Down Expand Up @@ -125,6 +128,8 @@ func (cfg *peerConfig) parse(r io.Reader) error {
cfg.Strategy = ss[1]
case "max_fails":
cfg.MaxFails, _ = strconv.Atoi(ss[1])
case "fastest_count":
cfg.FastestCount, _ = strconv.Atoi(ss[1])
case "fail_timeout":
cfg.FailTimeout, _ = time.ParseDuration(ss[1])
case "reload":
Expand Down
1 change: 1 addition & 0 deletions cmd/gost/route.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ func (r *route) parseChain() (*gost.Chain, error) {
FailTimeout: nodes[0].GetDuration("fail_timeout"),
},
&gost.InvalidFilter{},
gost.NewFastestFilter(0, nodes[0].GetInt("fastest_count")),
),
gost.WithStrategy(gost.NewStrategy(nodes[0].Get("strategy"))),
)
Expand Down
89 changes: 89 additions & 0 deletions selector.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,13 @@ import (
"errors"
"math/rand"
"net"
"sort"
"strconv"
"sync"
"sync/atomic"
"time"

"github.com/go-log/log"
)

var (
Expand Down Expand Up @@ -205,6 +208,92 @@ func (f *FailFilter) String() string {
return "fail"
}

// FastestFilter filter the fastest node
type FastestFilter struct {
mu sync.Mutex

pinger *net.Dialer
pingResult map[int]int
pingResultTTL map[int]int64

topCount int
}

func NewFastestFilter(pingTimeOut int, topCount int) *FastestFilter {
if pingTimeOut == 0 {
pingTimeOut = 3000 // 3s
}
return &FastestFilter{
mu: sync.Mutex{},
pinger: &net.Dialer{Timeout: time.Millisecond * time.Duration(pingTimeOut)},
pingResult: make(map[int]int, 0),
pingResultTTL: make(map[int]int64, 0),
topCount: topCount,
}
}

func (f *FastestFilter) Filter(nodes []Node) []Node {
// disabled
if f.topCount == 0 {
return nodes
}

// get latency with ttl cache
now := time.Now().Unix()
r := rand.New(rand.NewSource(time.Now().UnixNano()))

var getNodeLatency = func(node Node) int {
if f.pingResultTTL[node.ID] < now {
f.mu.Lock()
f.pingResultTTL[node.ID] = now + 5 // tmp
defer f.mu.Unlock()

// get latency
go func(node Node) {
latency := f.doTcpPing(node.Addr)
ttl := 300 - int64(60*r.Float64())

f.mu.Lock()
f.pingResult[node.ID] = latency
f.pingResultTTL[node.ID] = now + ttl
defer f.mu.Unlock()
}(node)
}
return f.pingResult[node.ID]
}

// sort
sort.Slice(nodes, func(i, j int) bool {
return getNodeLatency(nodes[i]) < getNodeLatency(nodes[j])
})

// split
if len(nodes) <= f.topCount {
return nodes
}

return nodes[0:f.topCount]
}

func (f *FastestFilter) String() string {
return "fastest"
}

// doTcpPing
func (f *FastestFilter) doTcpPing(address string) int {
start := time.Now()
conn, err := f.pinger.Dial("tcp", address)
elapsed := time.Since(start)

if err == nil {
_ = conn.Close()
}

latency := int(elapsed.Milliseconds())
log.Logf("pingDoTCP: %s, latency: %d", address, latency)
return latency
}

// InvalidFilter filters the invalid node.
// A node is invalid if its port is invalid (negative or zero value).
type InvalidFilter struct{}
Expand Down
24 changes: 24 additions & 0 deletions selector_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,30 @@ func TestFailFilter(t *testing.T) {
}
}

func TestFastestFilter(t *testing.T) {
nodes := []Node{
Node{ID: 1, marker: &failMarker{}, Addr: "1.0.0.1:80"},
Node{ID: 2, marker: &failMarker{}, Addr: "1.0.0.2:80"},
Node{ID: 3, marker: &failMarker{}, Addr: "1.0.0.3:80"},
}
filter := NewFastestFilter(0, 2)

var print = func(nodes []Node) []string {
var rows []string
for _, node := range nodes {
rows = append(rows, node.Addr)
}
return rows
}

result1 := filter.Filter(nodes)
t.Logf("result 1: %+v", print(result1))

time.Sleep(time.Second)
result2 := filter.Filter(nodes)
t.Logf("result 2: %+v", print(result2))
}

func TestSelector(t *testing.T) {
nodes := []Node{
Node{ID: 1, marker: &failMarker{}},
Expand Down

0 comments on commit 2faecc1

Please sign in to comment.