Skip to content

Commit

Permalink
refactoring and test
Browse files Browse the repository at this point in the history
  • Loading branch information
koron committed Oct 24, 2023
1 parent 8bc0f26 commit e534284
Show file tree
Hide file tree
Showing 7 changed files with 184 additions and 58 deletions.
67 changes: 41 additions & 26 deletions announce_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,48 @@ import (
"sync"
"testing"
"time"

"golang.org/x/exp/slices"
)

func checkAlives(t *testing.T, alives []*AliveMessage, typ, usn, loc, srv string) {
t.Helper()

first := slices.IndexFunc(alives, func(m *AliveMessage) bool {
return m.Type == typ
})
if first < 0 {
t.Errorf("no AliveMessage which Type is %s", typ)
return
}

_, port, err := net.SplitHostPort(alives[first].From.String())
if err != nil {
t.Errorf("failed to split host and port from first message: %s", err)
return
}
port = ":" + port

for i, m := range alives {
if m.Type != typ {
t.Logf("unexpected alive[%d].Type: want=%q got=%q", i, typ, m.Type)
continue
}
if !strings.HasSuffix(m.From.String(), port) {
t.Errorf("unmatch alive[%d].From (:port): want=%q got=%q", i, port, m.From.String())
}
if m.USN != usn {
t.Errorf("unexpected alive[%d].USN: want=%q got=%q", i, usn, m.USN)
}
if m.Location != loc {
t.Errorf("unexpected alive[%d].Location: want=%q got=%q", i, loc, m.Location)
}
if m.Server != srv {
t.Errorf("unexpected alive[%d].Server: want=%q got=%q", i, srv, m.Server)
}
}
}

func TestAnnounceAlive(t *testing.T) {
var mu sync.Mutex
var mm []*AliveMessage
Expand All @@ -28,32 +68,7 @@ func TestAnnounceAlive(t *testing.T) {
}
time.Sleep(500 * time.Millisecond)

if len(mm) < 1 {
t.Fatal("no alives detected")
}
//t.Logf("found %d alives", len(mm))
_, port, err := net.SplitHostPort(mm[0].From.String())
if err != nil {
t.Fatalf("failed to split host and port: %s", err)
}
port = ":" + port
for i, m := range mm {
if strings.HasSuffix(port, m.From.String()) {
t.Errorf("unmatch port#%d:\nwant=%q\n got=%q", i, port, m.From.String())
}
if m.Type != "test:announce+alive" {
t.Errorf("unexpected alive#%d type: want=%q got=%q", i, "test:announce+alive", m.Type)
}
if m.USN != "usn:announce+alive" {
t.Errorf("unexpected alive#%d usn: want=%q got=%q", i, "usn:announce+alive", m.USN)
}
if m.Location != "location:announce+alive" {
t.Errorf("unexpected alive#%d location: want=%q got=%q", i, "location:announce+alive", m.Location)
}
if m.Server != "server:announce+alive" {
t.Errorf("unexpected alive#%d server: want=%q got=%q", i, "server:announce+alive", m.Server)
}
}
checkAlives(t, mm, "test:announce+alive", "usn:announce+alive", "location:announce+alive", "server:announce+alive")
}

func TestAnnounceBye(t *testing.T) {
Expand Down
4 changes: 4 additions & 0 deletions examples/monitor/monitor.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ func main() {
flag.StringVar(&filterType, "filter_type", "", "print only a specified type (ST or NT). default is print all types.")
ttl := flag.Int("ttl", 0, "TTL for outgoing multicast packets")
sysIf := flag.Bool("sysif", false, "use system assigned multicast interface")
laddr := flag.String("laddr", "", "local address to listen")
flag.Parse()

if *h {
Expand All @@ -33,6 +34,9 @@ func main() {
if *sysIf {
opts = append(opts, ssdp.OnlySystemInterface())
}
if *laddr != "" {
opts = append(opts, ssdp.LocalAddr(*laddr))
}

m := &ssdp.Monitor{
Alive: onAlive,
Expand Down
5 changes: 4 additions & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@ module github.com/koron/go-ssdp

go 1.20

require golang.org/x/net v0.17.0
require (
golang.org/x/exp v0.0.0-20231006140011-7918f672742d
golang.org/x/net v0.17.0
)

require golang.org/x/sys v0.13.0 // indirect
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
golang.org/x/exp v0.0.0-20231006140011-7918f672742d h1:jtJma62tbqLibJ5sFQz8bKtEM8rJBtfilJ2qTU199MI=
golang.org/x/exp v0.0.0-20231006140011-7918f672742d/go.mod h1:ldy0pHrwJyGW56pPQzzkH36rKxoZW1tw7ZJpeKx+hdo=
golang.org/x/net v0.17.0 h1:pVaXccu2ozPjCXewfr1S7xza/zcXTity9cCdXQYSjIM=
golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE=
golang.org/x/sys v0.13.0 h1:Af8nKPmuFypiUBjVoU9V20FiaFXOcuZI21p0ycVYYGE=
Expand Down
3 changes: 3 additions & 0 deletions internal/multicast/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@ type InterfacesProviderFunc func() []net.Interface

// InterfacesProvider specify a function to list all interfaces to multicast.
// If no provider are given, all possible interfaces will be used.
//
// Deprecated: this setting item is not good because it affects globaly.
// Use ConnInterfaces() option for each function call, instead of.
var InterfacesProvider InterfacesProviderFunc

// SystemAssignedInterface indicates use the system assigned multicast interface or not.
Expand Down
87 changes: 56 additions & 31 deletions internal/multicast/multicast.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package multicast

import (
"errors"
"fmt"
"io"
"net"
"strings"
Expand All @@ -23,10 +24,32 @@ type Conn struct {
type connConfig struct {
ttl int
sysIf bool
ifps []*net.Interface
}

func (cfg connConfig) interfaces() ([]*net.Interface, error) {
if cfg.sysIf {
if cfg.ifps != nil {
return nil, errors.New("both of ssdp.ConnSystemAssginedInterface and ConnInterfaces are specified")
}
return []*net.Interface{nil}, nil
}
if cfg.ifps != nil {
return cfg.ifps, nil
}
list, err := interfaces()
if err != nil {
return nil, err
}
ifplist := make([]*net.Interface, 0, len(list))
for i := range list {
ifplist = append(ifplist, &list[i])
}
return ifplist, nil
}

// Listen starts to receiving multicast messages.
func Listen(laddrResolver, raddrResolver Resolver, opts ...ConnOption) (*Conn, error) {
func Listen(laddrResolver, gaddrResolver Resolver, opts ...ConnOption) (*Conn, error) {
// prepare parameters.
laddr, err := laddrResolver.Resolve()
if err != nil {
Expand All @@ -43,7 +66,7 @@ func Listen(laddrResolver, raddrResolver Resolver, opts ...ConnOption) (*Conn, e
return nil, err
}
// configure socket to use with multicast.
pconn, ifplist, err := newIPv4MulticastConn(conn, raddrResolver, cfg.sysIf)
pconn, ifplist, err := newIPv4MulticastConn(conn, gaddrResolver, cfg)
if err != nil {
conn.Close()
return nil, err
Expand All @@ -65,56 +88,43 @@ func Listen(laddrResolver, raddrResolver Resolver, opts ...ConnOption) (*Conn, e

// newIPv4MulticastConn create a new multicast connection.
// 2nd return parameter will be nil when sysIf is true.
func newIPv4MulticastConn(conn *net.UDPConn, raddrResolver Resolver, sysIf bool) (*ipv4.PacketConn, []*net.Interface, error) {
// sysIf: use system assigned multicast interface.
// the empty iflist indicate it.
var ifplist []*net.Interface
if !sysIf {
list, err := interfaces()
if err != nil {
return nil, nil, err
}
ifplist = make([]*net.Interface, 0, len(list))
for i := range list {
ifplist = append(ifplist, &list[i])
}
func newIPv4MulticastConn(conn *net.UDPConn, gaddrResolver Resolver, cfg connConfig) (*ipv4.PacketConn, []*net.Interface, error) {
ifplist, err := cfg.interfaces()
if err != nil {
return nil, nil, err
}
raddr, err := raddrResolver.Resolve()
gaddr, err := gaddrResolver.Resolve()
if err != nil {
return nil, nil, err
}
pconn, err := joinGroupIPv4(conn, ifplist, raddr)
pconn, err := joinGroupIPv4(conn, ifplist, gaddr)
if err != nil {
return nil, nil, err
}
return pconn, ifplist, nil
}

func interfaceName(ifi *net.Interface) string {
if ifi == nil {
return "system assigned multicast interface (nil)"
}
return fmt.Sprintf("%s (#%d)", ifi.Name, ifi.Index)
}

// joinGroupIPv4 makes the connection join to a group on interfaces.
// This trys to use system assigned when iflist is nil or empty.
func joinGroupIPv4(conn *net.UDPConn, ifplist []*net.Interface, gaddr net.Addr) (*ipv4.PacketConn, error) {
wrap := ipv4.NewPacketConn(conn)
wrap.SetMulticastLoopback(true)

// try to use the system assigned multicast interface when iflist is empty.
if len(ifplist) == 0 {
if err := wrap.JoinGroup(nil, gaddr); err != nil {
ssdplog.Printf("failed to join group %s on system assigned multicast interface: %s", gaddr.String(), err)
return nil, errors.New("no system assigned multicast interfaces had joined to group")
}
ssdplog.Printf("joined group %s on system assigned multicast interface", gaddr.String())
return wrap, nil
}

// add interfaces to multicast group.
joined := 0
for _, ifi := range ifplist {
if err := wrap.JoinGroup(ifi, gaddr); err != nil {
ssdplog.Printf("failed to join group %s on %s: %s", gaddr.String(), ifi.Name, err)
ssdplog.Printf("failed to join group %s on %s: %s", gaddr.String(), interfaceName(ifi), err)
continue
}
joined++
ssdplog.Printf("joined group %s on %s (#%d)", gaddr.String(), ifi.Name, ifi.Index)
ssdplog.Printf("joined group %s on %s", gaddr.String(), interfaceName(ifi))
}
if joined == 0 {
return nil, errors.New("no interfaces had joined to group")
Expand Down Expand Up @@ -148,10 +158,12 @@ func (mc *Conn) WriteTo(dataProv DataProvider, to net.Addr) (int, error) {
if uaddr, ok := to.(*net.UDPAddr); ok && !uaddr.IP.IsMulticast() {
return mc.writeToIfi(dataProv, to, nil)
}
if len(mc.ifps) == 0 {
return mc.writeToIfi(dataProv, to, nil)
}
// Send a multicast message to all interfaces (iflist).
sum := 0
for _, ifi := range mc.ifps {
ssdplog.Printf("WriteTo: ifi=%+v", ifi)
n, err := mc.writeToIfi(dataProv, to, ifi)
if err != nil {
return 0, err
Expand Down Expand Up @@ -217,8 +229,21 @@ func ConnTTL(ttl int) ConnOption {
})
}

// ConnSystemAssginedInterface returns ConnOption that use a system assigned
// interface for multicast.
// This can't be combined with ConnInterfaces.
func ConnSystemAssginedInterface() ConnOption {
return connOptFunc(func(cfg *connConfig) {
cfg.sysIf = true
})
}

// ConnInterfaces returns ConnInterfaces that specify interfaces to join the
// multicast group.
//
// This can't be combined with ConnSystemAssginedInterface.
func ConnInterfaces(ifps []*net.Interface) ConnOption {
return connOptFunc(func(cfg *connConfig) {
cfg.ifps = ifps
})
}
74 changes: 74 additions & 0 deletions option_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
package ssdp

import (
"sync"
"testing"
"time"
)

// TestTTL tests TTL() doesn't something bad.
func TestTTL(t *testing.T) {
// start alive monitor.
var mu sync.Mutex
var alives []*AliveMessage
m := newTestMonitor("test:ttl", func(m *AliveMessage) {
mu.Lock()
alives = append(alives, m)
mu.Unlock()
}, nil, nil)
err := m.Start()
if err != nil {
t.Fatalf("failed to start Monitor: %s", err)
}
defer m.Close()

// send test alive with TTL:2
err = AnnounceAlive("test:ttl", "usn:ttl", "location:ttl", "server:ttl", 600, "", TTL(2))
if err != nil {
t.Fatalf("failed to announce alive: %s", err)
}
time.Sleep(500 * time.Millisecond)

checkAlives(t, alives, "test:ttl", "usn:ttl", "location:ttl", "server:ttl")
}

// TestOnlySystemInterface tests OnlySystemInterface().
// Monitor with OnlySystemInterface() and send alive message to all interfaces.
// Monitor will receive just an alive message for default interface.
func TestOnlySystemInterface(t *testing.T) {
// start alive monitor with OnlySystemInterface.
var mu sync.Mutex
var alives []*AliveMessage
m := newTestMonitor("test:sysif", func(m *AliveMessage) {
mu.Lock()
alives = append(alives, m)
mu.Unlock()
}, nil, nil)
m.Options = append(m.Options, OnlySystemInterface())
err := m.Start()
if err != nil {
t.Fatalf("failed to start Monitor: %s", err)
}
defer m.Close()

// send a test alive.
err = AnnounceAlive("test:sysif", "usn:sysif", "location:sysif", "server:sysif", 600, "")
if err != nil {
t.Fatalf("failed to announce alive: %s", err)
}
time.Sleep(500 * time.Millisecond)

checkAlives(t, alives, "test:sysif", "usn:sysif", "location:sysif", "server:sysif")

if len(alives) != 1 {
t.Fatalf("exact an alive should be detected: but got %d", len(alives))
}
}

func TestLocalAddr(t *testing.T) {
// TODO:
}

func TestRemoteAddr(t *testing.T) {
// TODO:
}

0 comments on commit e534284

Please sign in to comment.