diff --git a/announce_test.go b/announce_test.go index 5fcb122..ea973be 100644 --- a/announce_test.go +++ b/announce_test.go @@ -6,8 +6,48 @@ import ( "sync" "testing" "time" + + "golang.org/x/exp/slices" ) +func checkAlives(t *testing.T, alives []*AliveMessage, typ, usn, loc, srv string) { + t.Helper() + + first := slices.IndexFunc(alives, func(m *AliveMessage) bool { + return m.Type == typ + }) + if first < 0 { + t.Errorf("no AliveMessage which Type is %s", typ) + return + } + + _, port, err := net.SplitHostPort(alives[first].From.String()) + if err != nil { + t.Errorf("failed to split host and port from first message: %s", err) + return + } + port = ":" + port + + for i, m := range alives { + if m.Type != typ { + t.Logf("unexpected alive[%d].Type: want=%q got=%q", i, typ, m.Type) + continue + } + if !strings.HasSuffix(m.From.String(), port) { + t.Errorf("unmatch alive[%d].From (:port): want=%q got=%q", i, port, m.From.String()) + } + if m.USN != usn { + t.Errorf("unexpected alive[%d].USN: want=%q got=%q", i, usn, m.USN) + } + if m.Location != loc { + t.Errorf("unexpected alive[%d].Location: want=%q got=%q", i, loc, m.Location) + } + if m.Server != srv { + t.Errorf("unexpected alive[%d].Server: want=%q got=%q", i, srv, m.Server) + } + } +} + func TestAnnounceAlive(t *testing.T) { var mu sync.Mutex var mm []*AliveMessage @@ -28,32 +68,7 @@ func TestAnnounceAlive(t *testing.T) { } time.Sleep(500 * time.Millisecond) - if len(mm) < 1 { - t.Fatal("no alives detected") - } - //t.Logf("found %d alives", len(mm)) - _, port, err := net.SplitHostPort(mm[0].From.String()) - if err != nil { - t.Fatalf("failed to split host and port: %s", err) - } - port = ":" + port - for i, m := range mm { - if strings.HasSuffix(port, m.From.String()) { - t.Errorf("unmatch port#%d:\nwant=%q\n got=%q", i, port, m.From.String()) - } - if m.Type != "test:announce+alive" { - t.Errorf("unexpected alive#%d type: want=%q got=%q", i, "test:announce+alive", m.Type) - } - if m.USN != "usn:announce+alive" { - t.Errorf("unexpected alive#%d usn: want=%q got=%q", i, "usn:announce+alive", m.USN) - } - if m.Location != "location:announce+alive" { - t.Errorf("unexpected alive#%d location: want=%q got=%q", i, "location:announce+alive", m.Location) - } - if m.Server != "server:announce+alive" { - t.Errorf("unexpected alive#%d server: want=%q got=%q", i, "server:announce+alive", m.Server) - } - } + checkAlives(t, mm, "test:announce+alive", "usn:announce+alive", "location:announce+alive", "server:announce+alive") } func TestAnnounceBye(t *testing.T) { diff --git a/examples/monitor/monitor.go b/examples/monitor/monitor.go index 9a5d60e..c5549a4 100644 --- a/examples/monitor/monitor.go +++ b/examples/monitor/monitor.go @@ -16,6 +16,7 @@ func main() { flag.StringVar(&filterType, "filter_type", "", "print only a specified type (ST or NT). default is print all types.") ttl := flag.Int("ttl", 0, "TTL for outgoing multicast packets") sysIf := flag.Bool("sysif", false, "use system assigned multicast interface") + laddr := flag.String("laddr", "", "local address to listen") flag.Parse() if *h { @@ -33,6 +34,9 @@ func main() { if *sysIf { opts = append(opts, ssdp.OnlySystemInterface()) } + if *laddr != "" { + opts = append(opts, ssdp.LocalAddr(*laddr)) + } m := &ssdp.Monitor{ Alive: onAlive, diff --git a/go.mod b/go.mod index f3baf60..a6da951 100644 --- a/go.mod +++ b/go.mod @@ -2,6 +2,9 @@ module github.com/koron/go-ssdp go 1.20 -require golang.org/x/net v0.17.0 +require ( + golang.org/x/exp v0.0.0-20231006140011-7918f672742d + golang.org/x/net v0.17.0 +) require golang.org/x/sys v0.13.0 // indirect diff --git a/go.sum b/go.sum index 7b00fe3..e7cea4d 100644 --- a/go.sum +++ b/go.sum @@ -1,3 +1,5 @@ +golang.org/x/exp v0.0.0-20231006140011-7918f672742d h1:jtJma62tbqLibJ5sFQz8bKtEM8rJBtfilJ2qTU199MI= +golang.org/x/exp v0.0.0-20231006140011-7918f672742d/go.mod h1:ldy0pHrwJyGW56pPQzzkH36rKxoZW1tw7ZJpeKx+hdo= golang.org/x/net v0.17.0 h1:pVaXccu2ozPjCXewfr1S7xza/zcXTity9cCdXQYSjIM= golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE= golang.org/x/sys v0.13.0 h1:Af8nKPmuFypiUBjVoU9V20FiaFXOcuZI21p0ycVYYGE= diff --git a/internal/multicast/interface.go b/internal/multicast/interface.go index 15b7ee6..796e70f 100644 --- a/internal/multicast/interface.go +++ b/internal/multicast/interface.go @@ -8,6 +8,9 @@ type InterfacesProviderFunc func() []net.Interface // InterfacesProvider specify a function to list all interfaces to multicast. // If no provider are given, all possible interfaces will be used. +// +// Deprecated: this setting item is not good because it affects globaly. +// Use ConnInterfaces() option for each function call, instead of. var InterfacesProvider InterfacesProviderFunc // SystemAssignedInterface indicates use the system assigned multicast interface or not. diff --git a/internal/multicast/multicast.go b/internal/multicast/multicast.go index a26fa78..7e9792b 100644 --- a/internal/multicast/multicast.go +++ b/internal/multicast/multicast.go @@ -2,6 +2,7 @@ package multicast import ( "errors" + "fmt" "io" "net" "strings" @@ -23,10 +24,32 @@ type Conn struct { type connConfig struct { ttl int sysIf bool + ifps []*net.Interface +} + +func (cfg connConfig) interfaces() ([]*net.Interface, error) { + if cfg.sysIf { + if cfg.ifps != nil { + return nil, errors.New("both of ssdp.ConnSystemAssginedInterface and ConnInterfaces are specified") + } + return []*net.Interface{nil}, nil + } + if cfg.ifps != nil { + return cfg.ifps, nil + } + list, err := interfaces() + if err != nil { + return nil, err + } + ifplist := make([]*net.Interface, 0, len(list)) + for i := range list { + ifplist = append(ifplist, &list[i]) + } + return ifplist, nil } // Listen starts to receiving multicast messages. -func Listen(laddrResolver, raddrResolver Resolver, opts ...ConnOption) (*Conn, error) { +func Listen(laddrResolver, gaddrResolver Resolver, opts ...ConnOption) (*Conn, error) { // prepare parameters. laddr, err := laddrResolver.Resolve() if err != nil { @@ -43,7 +66,7 @@ func Listen(laddrResolver, raddrResolver Resolver, opts ...ConnOption) (*Conn, e return nil, err } // configure socket to use with multicast. - pconn, ifplist, err := newIPv4MulticastConn(conn, raddrResolver, cfg.sysIf) + pconn, ifplist, err := newIPv4MulticastConn(conn, gaddrResolver, cfg) if err != nil { conn.Close() return nil, err @@ -65,56 +88,43 @@ func Listen(laddrResolver, raddrResolver Resolver, opts ...ConnOption) (*Conn, e // newIPv4MulticastConn create a new multicast connection. // 2nd return parameter will be nil when sysIf is true. -func newIPv4MulticastConn(conn *net.UDPConn, raddrResolver Resolver, sysIf bool) (*ipv4.PacketConn, []*net.Interface, error) { - // sysIf: use system assigned multicast interface. - // the empty iflist indicate it. - var ifplist []*net.Interface - if !sysIf { - list, err := interfaces() - if err != nil { - return nil, nil, err - } - ifplist = make([]*net.Interface, 0, len(list)) - for i := range list { - ifplist = append(ifplist, &list[i]) - } +func newIPv4MulticastConn(conn *net.UDPConn, gaddrResolver Resolver, cfg connConfig) (*ipv4.PacketConn, []*net.Interface, error) { + ifplist, err := cfg.interfaces() + if err != nil { + return nil, nil, err } - raddr, err := raddrResolver.Resolve() + gaddr, err := gaddrResolver.Resolve() if err != nil { return nil, nil, err } - pconn, err := joinGroupIPv4(conn, ifplist, raddr) + pconn, err := joinGroupIPv4(conn, ifplist, gaddr) if err != nil { return nil, nil, err } return pconn, ifplist, nil } +func interfaceName(ifi *net.Interface) string { + if ifi == nil { + return "system assigned multicast interface (nil)" + } + return fmt.Sprintf("%s (#%d)", ifi.Name, ifi.Index) +} + // joinGroupIPv4 makes the connection join to a group on interfaces. // This trys to use system assigned when iflist is nil or empty. func joinGroupIPv4(conn *net.UDPConn, ifplist []*net.Interface, gaddr net.Addr) (*ipv4.PacketConn, error) { wrap := ipv4.NewPacketConn(conn) wrap.SetMulticastLoopback(true) - - // try to use the system assigned multicast interface when iflist is empty. - if len(ifplist) == 0 { - if err := wrap.JoinGroup(nil, gaddr); err != nil { - ssdplog.Printf("failed to join group %s on system assigned multicast interface: %s", gaddr.String(), err) - return nil, errors.New("no system assigned multicast interfaces had joined to group") - } - ssdplog.Printf("joined group %s on system assigned multicast interface", gaddr.String()) - return wrap, nil - } - // add interfaces to multicast group. joined := 0 for _, ifi := range ifplist { if err := wrap.JoinGroup(ifi, gaddr); err != nil { - ssdplog.Printf("failed to join group %s on %s: %s", gaddr.String(), ifi.Name, err) + ssdplog.Printf("failed to join group %s on %s: %s", gaddr.String(), interfaceName(ifi), err) continue } joined++ - ssdplog.Printf("joined group %s on %s (#%d)", gaddr.String(), ifi.Name, ifi.Index) + ssdplog.Printf("joined group %s on %s", gaddr.String(), interfaceName(ifi)) } if joined == 0 { return nil, errors.New("no interfaces had joined to group") @@ -148,10 +158,12 @@ func (mc *Conn) WriteTo(dataProv DataProvider, to net.Addr) (int, error) { if uaddr, ok := to.(*net.UDPAddr); ok && !uaddr.IP.IsMulticast() { return mc.writeToIfi(dataProv, to, nil) } + if len(mc.ifps) == 0 { + return mc.writeToIfi(dataProv, to, nil) + } // Send a multicast message to all interfaces (iflist). sum := 0 for _, ifi := range mc.ifps { - ssdplog.Printf("WriteTo: ifi=%+v", ifi) n, err := mc.writeToIfi(dataProv, to, ifi) if err != nil { return 0, err @@ -217,8 +229,21 @@ func ConnTTL(ttl int) ConnOption { }) } +// ConnSystemAssginedInterface returns ConnOption that use a system assigned +// interface for multicast. +// This can't be combined with ConnInterfaces. func ConnSystemAssginedInterface() ConnOption { return connOptFunc(func(cfg *connConfig) { cfg.sysIf = true }) } + +// ConnInterfaces returns ConnInterfaces that specify interfaces to join the +// multicast group. +// +// This can't be combined with ConnSystemAssginedInterface. +func ConnInterfaces(ifps []*net.Interface) ConnOption { + return connOptFunc(func(cfg *connConfig) { + cfg.ifps = ifps + }) +} diff --git a/option_test.go b/option_test.go new file mode 100644 index 0000000..aabf6e3 --- /dev/null +++ b/option_test.go @@ -0,0 +1,74 @@ +package ssdp + +import ( + "sync" + "testing" + "time" +) + +// TestTTL tests TTL() doesn't something bad. +func TestTTL(t *testing.T) { + // start alive monitor. + var mu sync.Mutex + var alives []*AliveMessage + m := newTestMonitor("test:ttl", func(m *AliveMessage) { + mu.Lock() + alives = append(alives, m) + mu.Unlock() + }, nil, nil) + err := m.Start() + if err != nil { + t.Fatalf("failed to start Monitor: %s", err) + } + defer m.Close() + + // send test alive with TTL:2 + err = AnnounceAlive("test:ttl", "usn:ttl", "location:ttl", "server:ttl", 600, "", TTL(2)) + if err != nil { + t.Fatalf("failed to announce alive: %s", err) + } + time.Sleep(500 * time.Millisecond) + + checkAlives(t, alives, "test:ttl", "usn:ttl", "location:ttl", "server:ttl") +} + +// TestOnlySystemInterface tests OnlySystemInterface(). +// Monitor with OnlySystemInterface() and send alive message to all interfaces. +// Monitor will receive just an alive message for default interface. +func TestOnlySystemInterface(t *testing.T) { + // start alive monitor with OnlySystemInterface. + var mu sync.Mutex + var alives []*AliveMessage + m := newTestMonitor("test:sysif", func(m *AliveMessage) { + mu.Lock() + alives = append(alives, m) + mu.Unlock() + }, nil, nil) + m.Options = append(m.Options, OnlySystemInterface()) + err := m.Start() + if err != nil { + t.Fatalf("failed to start Monitor: %s", err) + } + defer m.Close() + + // send a test alive. + err = AnnounceAlive("test:sysif", "usn:sysif", "location:sysif", "server:sysif", 600, "") + if err != nil { + t.Fatalf("failed to announce alive: %s", err) + } + time.Sleep(500 * time.Millisecond) + + checkAlives(t, alives, "test:sysif", "usn:sysif", "location:sysif", "server:sysif") + + if len(alives) != 1 { + t.Fatalf("exact an alive should be detected: but got %d", len(alives)) + } +} + +func TestLocalAddr(t *testing.T) { + // TODO: +} + +func TestRemoteAddr(t *testing.T) { + // TODO: +}