diff --git a/cmd/signal/grpc/server/server.go b/cmd/signal/grpc/server/server.go index 45df5a336..6d4f12bf1 100644 --- a/cmd/signal/grpc/server/server.go +++ b/cmd/signal/grpc/server/server.go @@ -149,10 +149,12 @@ func (s *SFUServer) Signal(stream pb.SFU_SignalServer) error { _, nopub := payload.Join.Config["NoPublish"] _, nosub := payload.Join.Config["NoSubscribe"] + _, noautosub := payload.Join.Config["NoAutoSubscribe"] cfg := sfu.JoinConfig{ - NoPublish: nopub, - NoSubscribe: nosub, + NoPublish: nopub, + NoSubscribe: nosub, + NoAutoSubscribe: noautosub, } err = peer.Join(payload.Join.Sid, payload.Join.Uid, cfg) if err != nil { diff --git a/pkg/sfu/downtrack.go b/pkg/sfu/downtrack.go index ca3143535..a50217fcb 100644 --- a/pkg/sfu/downtrack.go +++ b/pkg/sfu/downtrack.go @@ -143,6 +143,13 @@ func (d *DownTrack) Kind() webrtc.RTPCodecType { } } +func (d *DownTrack) Stop() error { + if d.transceiver != nil { + return d.transceiver.Stop() + } + return fmt.Errorf("d.transceiver not exists") +} + func (d *DownTrack) SetTransceiver(transceiver *webrtc.RTPTransceiver) { d.transceiver = transceiver } @@ -161,6 +168,10 @@ func (d *DownTrack) WriteRTP(p *buffer.ExtPacket) error { return nil } +func (d *DownTrack) Enabled() bool { + return d.enabled.get() +} + // Mute enables or disables media forwarding func (d *DownTrack) Mute(val bool) { if d.enabled.get() != val { diff --git a/pkg/sfu/peer.go b/pkg/sfu/peer.go index 8afaa0d22..7fa2432a1 100644 --- a/pkg/sfu/peer.go +++ b/pkg/sfu/peer.go @@ -39,6 +39,11 @@ type JoinConfig struct { NoPublish bool // If true the peer will not be allowed to subscribe to other peers in SessionLocal. NoSubscribe bool + // If true the peer will not automatically subscribe all tracks, + // and then the peer can use peer.Subscriber().AddDownTrack/RemoveDownTrack + // to customize the subscrbe stream combination as needed. + // this parameter depends on NoSubscribe=false. + NoAutoSubscribe bool } // SessionProvider provides the SessionLocal to the sfu.Peer @@ -105,6 +110,8 @@ func (p *PeerLocal) Join(sid, uid string, config ...JoinConfig) error { return fmt.Errorf("error creating transport: %v", err) } + p.subscriber.noAutoSubscribe = conf.NoAutoSubscribe + p.subscriber.OnNegotiationNeeded(func() { p.Lock() defer p.Unlock() diff --git a/pkg/sfu/publisher.go b/pkg/sfu/publisher.go index 81fb8bf82..020acecbd 100644 --- a/pkg/sfu/publisher.go +++ b/pkg/sfu/publisher.go @@ -20,18 +20,19 @@ type Publisher struct { router Router session Session - tracks []publisherTracks + tracks []PublisherTrack relayPeer []*relay.Peer candidates []webrtc.ICECandidateInit onICEConnectionStateChangeHandler atomic.Value // func(webrtc.ICEConnectionState) + onPublisherTrack atomic.Value // func(PublisherTrack) closeOnce sync.Once } -type publisherTracks struct { - track *webrtc.TrackRemote - receiver Receiver +type PublisherTrack struct { + Track *webrtc.TrackRemote + Receiver Receiver // This will be used in the future for tracks that will be relayed as clients or servers // This is for SVC and Simulcast where you will be able to chose if the relayed peer just // want a single track (for recording/ processing) or get all the tracks (for load balancing) @@ -75,16 +76,20 @@ func NewPublisher(id string, session Session, cfg *WebRTCTransportConfig) (*Publ if pub { p.session.Publish(p.router, r) p.mu.Lock() - p.tracks = append(p.tracks, publisherTracks{track, r, true}) + publisherTrack := PublisherTrack{track, r, true} + p.tracks = append(p.tracks, publisherTrack) for _, rp := range p.relayPeer { if err = p.createRelayTrack(track, r, rp); err != nil { Logger.V(1).Error(err, "Creating relay track.", "peer_id", p.id) } } p.mu.Unlock() + if handler, ok := p.onPublisherTrack.Load().(func(PublisherTrack)); ok && handler != nil { + handler(publisherTrack) + } } else { p.mu.Lock() - p.tracks = append(p.tracks, publisherTracks{track, r, false}) + p.tracks = append(p.tracks, PublisherTrack{track, r, false}) p.mu.Unlock() } }) @@ -161,6 +166,10 @@ func (p *Publisher) Close() { }) } +func (p *Publisher) OnPublisherTrack(f func(track PublisherTrack)) { + p.onPublisherTrack.Store(f) +} + // OnICECandidate handler func (p *Publisher) OnICECandidate(f func(c *webrtc.ICECandidate)) { p.pc.OnICECandidate(f) @@ -204,12 +213,13 @@ func (p *Publisher) Relay(ice []webrtc.ICEServer) (*relay.Peer, error) { // simulcast will just relay client track for now continue } - if err = p.createRelayTrack(tp.track, tp.receiver, rp); err != nil { + if err = p.createRelayTrack(tp.Track, tp.Receiver, rp); err != nil { Logger.V(1).Error(err, "Creating relay track.", "peer_id", p.id) } } p.relayPeer = append(p.relayPeer, rp) p.mu.Unlock() + go p.relayReports(rp) }) @@ -220,13 +230,24 @@ func (p *Publisher) Relay(ice []webrtc.ICEServer) (*relay.Peer, error) { return rp, nil } +func (p *Publisher) PublisherTracks() []PublisherTrack { + p.mu.Lock() + defer p.mu.Unlock() + + tracks := make([]PublisherTrack, len(p.tracks)) + for idx, track := range p.tracks { + tracks[idx] = track + } + return tracks +} + func (p *Publisher) Tracks() []*webrtc.TrackRemote { p.mu.Lock() defer p.mu.Unlock() tracks := make([]*webrtc.TrackRemote, len(p.tracks)) for idx, track := range p.tracks { - tracks[idx] = track.track + tracks[idx] = track.Track } return tracks } diff --git a/pkg/sfu/router.go b/pkg/sfu/router.go index 7c1982ea5..e7c353911 100644 --- a/pkg/sfu/router.go +++ b/pkg/sfu/router.go @@ -15,6 +15,7 @@ type Router interface { ID() string AddReceiver(receiver *webrtc.RTPReceiver, track *webrtc.TrackRemote) (Receiver, bool) AddDownTracks(s *Subscriber, r Receiver) error + AddDownTrack(s *Subscriber, r Receiver) (*DownTrack, error) Stop() } @@ -188,8 +189,13 @@ func (r *router) AddDownTracks(s *Subscriber, recv Receiver) error { r.Lock() defer r.Unlock() + if s.noAutoSubscribe { + Logger.Info("peer turns off automatic subscription, skip tracks add") + return nil + } + if recv != nil { - if err := r.addDownTrack(s, recv); err != nil { + if _, err := r.AddDownTrack(s, recv); err != nil { return err } s.negotiate() @@ -198,7 +204,7 @@ func (r *router) AddDownTracks(s *Subscriber, recv Receiver) error { if len(r.receivers) > 0 { for _, rcv := range r.receivers { - if err := r.addDownTrack(s, rcv); err != nil { + if _, err := r.AddDownTrack(s, rcv); err != nil { return err } } @@ -207,16 +213,16 @@ func (r *router) AddDownTracks(s *Subscriber, recv Receiver) error { return nil } -func (r *router) addDownTrack(sub *Subscriber, recv Receiver) error { +func (r *router) AddDownTrack(sub *Subscriber, recv Receiver) (*DownTrack, error) { for _, dt := range sub.GetDownTracks(recv.StreamID()) { if dt.ID() == recv.TrackID() { - return nil + return dt, nil } } codec := recv.Codec() if err := sub.me.RegisterCodec(codec, recv.Kind()); err != nil { - return err + return nil, err } downTrack, err := NewDownTrack(webrtc.RTPCodecCapability{ @@ -227,13 +233,13 @@ func (r *router) addDownTrack(sub *Subscriber, recv Receiver) error { RTCPFeedback: []webrtc.RTCPFeedback{{"goog-remb", ""}, {"nack", ""}, {"nack", "pli"}}, }, recv, r.bufferFactory, sub.id, r.config.MaxPacketTrack) if err != nil { - return err + return nil, err } // Create webrtc sender for the peer we are sending track to if downTrack.transceiver, err = sub.pc.AddTransceiverFromTrack(downTrack, webrtc.RTPTransceiverInit{ Direction: webrtc.RTPTransceiverDirectionSendonly, }); err != nil { - return err + return nil, err } // nolint:scopelint @@ -257,7 +263,7 @@ func (r *router) addDownTrack(sub *Subscriber, recv Receiver) error { sub.AddDownTrack(recv.StreamID(), downTrack) recv.AddDownTrack(downTrack, r.config.Simulcast.BestQualityFirst) - return nil + return downTrack, nil } func (r *router) deleteReceiver(track string, ssrc uint32) { diff --git a/pkg/sfu/subscriber.go b/pkg/sfu/subscriber.go index 40a6669df..4b1995bf2 100644 --- a/pkg/sfu/subscriber.go +++ b/pkg/sfu/subscriber.go @@ -26,6 +26,8 @@ type Subscriber struct { negotiate func() closeOnce sync.Once + + noAutoSubscribe bool } // NewSubscriber creates a new Subscriber @@ -44,11 +46,12 @@ func NewSubscriber(id string, cfg WebRTCTransportConfig) (*Subscriber, error) { } s := &Subscriber{ - id: id, - me: me, - pc: pc, - tracks: make(map[string][]*DownTrack), - channels: make(map[string]*webrtc.DataChannel), + id: id, + me: me, + pc: pc, + tracks: make(map[string][]*DownTrack), + channels: make(map[string]*webrtc.DataChannel), + noAutoSubscribe: false, } pc.OnICEConnectionStateChange(func(connectionState webrtc.ICEConnectionState) { @@ -215,6 +218,16 @@ func (s *Subscriber) GetDatachannel(label string) *webrtc.DataChannel { return s.DataChannel(label) } +func (s *Subscriber) DownTracks() []*DownTrack { + s.RLock() + defer s.RUnlock() + var downTracks []*DownTrack + for _, tracks := range s.tracks { + downTracks = append(downTracks, tracks...) + } + return downTracks +} + func (s *Subscriber) GetDownTracks(streamID string) []*DownTrack { s.RLock() defer s.RUnlock()