From 557f9026072c9eef6aa37993109c42f72c5489fb Mon Sep 17 00:00:00 2001 From: OrlandoCo Date: Tue, 8 Jun 2021 09:26:18 -0500 Subject: [PATCH] feat(relay): add remote peers requests methods (#538) * feat(relay): add send/request methods * feat(relay): fix locking * feat(relay): fix pending map --- pkg/relay/relay.go | 292 ++++++++++++++++++++++++++----------------- pkg/sfu/publisher.go | 2 +- 2 files changed, 176 insertions(+), 118 deletions(-) diff --git a/pkg/relay/relay.go b/pkg/relay/relay.go index abd08fa01..52ad73408 100644 --- a/pkg/relay/relay.go +++ b/pkg/relay/relay.go @@ -1,6 +1,7 @@ package relay import ( + "context" "encoding/json" "errors" "fmt" @@ -15,8 +16,8 @@ import ( ) const ( - signalerLabel = "ion_sfu_relay_signaler" - signalerEvent = "ion_sfu_relay_event" + signalerLabel = "ion_sfu_relay_signaler" + signalerRequestEvent = "ion_relay_request" ) var ( @@ -34,12 +35,9 @@ type signal struct { TrackMeta *TrackMeta `json:"trackInfo,omitempty"` } -type signalRequest struct { - ID uint32 `json:"id"` - Signal *signal `json:"signal,omitempty"` -} - -type Request struct { +type request struct { + ID uint64 `json:"id"` + IsReply bool `json:"reply"` Event string `json:"event"` Payload []byte `json:"payload"` } @@ -62,26 +60,28 @@ type PeerMeta struct { } type Peer struct { - mu sync.Mutex - me *webrtc.MediaEngine - log logr.Logger - api *webrtc.API - ice *webrtc.ICETransport - meta PeerMeta - sctp *webrtc.SCTPTransport - dtls *webrtc.DTLSTransport - role *webrtc.ICERole - ready bool - senders []*webrtc.RTPSender - receivers []*webrtc.RTPReceiver - pendingSender map[uint32]func() - gatherer *webrtc.ICEGatherer - localTracks []webrtc.TrackLocal - dcIndex uint16 - signalingDC *webrtc.DataChannel + mu sync.Mutex + rmu sync.Mutex + me *webrtc.MediaEngine + log logr.Logger + api *webrtc.API + ice *webrtc.ICETransport + rand *rand.Rand + meta PeerMeta + sctp *webrtc.SCTPTransport + dtls *webrtc.DTLSTransport + role *webrtc.ICERole + ready bool + senders []*webrtc.RTPSender + receivers []*webrtc.RTPReceiver + pendingRequests map[uint64]chan []byte + localTracks []webrtc.TrackLocal + signalingDC *webrtc.DataChannel + gatherer *webrtc.ICEGatherer + dcIndex uint16 onReady func() - onRequest func(r Request) + onRequest func(event string, message Message) onDataChannel func(channel *webrtc.DataChannel) onTrack func(track *webrtc.TrackRemote, receiver *webrtc.RTPReceiver, meta *TrackMeta) } @@ -110,15 +110,16 @@ func NewPeer(meta PeerMeta, conf *PeerConfig) (*Peer, error) { } p := &Peer{ - me: &me, - api: api, - log: conf.Logger, - ice: i, - meta: meta, - sctp: sctp, - dtls: dtls, - gatherer: gatherer, - pendingSender: make(map[uint32]func()), + me: &me, + api: api, + log: conf.Logger, + ice: i, + rand: rand.New(rand.NewSource(time.Now().UnixNano())), + meta: meta, + sctp: sctp, + dtls: dtls, + gatherer: gatherer, + pendingRequests: make(map[uint64]chan []byte), } sctp.OnDataChannel(func(channel *webrtc.DataChannel) { @@ -292,29 +293,12 @@ func (p *Peer) OnReady(f func()) { } // OnRequest calls the callback when Peer gets a request message from remote Peer -func (p *Peer) OnRequest(f func(r Request)) { +func (p *Peer) OnRequest(f func(event string, msg Message)) { p.mu.Lock() p.onRequest = f p.mu.Unlock() } -// Request is used to send messages to remote Peer that will end in remote Peer. Other -// data channels if used in ion-sfu may act as middlewares or fan outs. -func (p *Peer) Request(r Request) error { - p.mu.Lock() - defer p.mu.Unlock() - - if p.signalingDC == nil { - return ErrRelaySignalDCNotReady - } - - b, err := json.Marshal(r) - if err != nil { - return err - } - return p.signalingDC.Send(b) -} - // OnDataChannel sets an event handler which is invoked when a data // channel message arrives from a remote Peer. func (p *Peer) OnDataChannel(f func(channel *webrtc.DataChannel)) { @@ -428,8 +412,8 @@ func (p *Peer) receive(s *signal) error { return nil } -// Send is used to negotiate a track to the remote peer -func (p *Peer) Send(receiver *webrtc.RTPReceiver, remoteTrack *webrtc.TrackRemote, +// AddTrack is used to negotiate a track to the remote peer +func (p *Peer) AddTrack(receiver *webrtc.RTPReceiver, remoteTrack *webrtc.TrackRemote, localTrack webrtc.TrackLocal) (*webrtc.RTPSender, error) { p.mu.Lock() defer p.mu.Unlock() @@ -443,19 +427,16 @@ func (p *Peer) Send(receiver *webrtc.RTPReceiver, remoteTrack *webrtc.TrackRemot return nil, err } - rr := rand.New(rand.NewSource(time.Now().UnixNano())) - s := signalRequest{ - ID: rr.Uint32(), - Signal: &signal{}, - } - s.Signal.TrackMeta = &TrackMeta{ + s := &signal{} + + s.TrackMeta = &TrackMeta{ StreamID: remoteTrack.StreamID(), TrackID: remoteTrack.ID(), CodecParameters: &codec, } - s.Signal.Encodings = &webrtc.RTPCodingParameters{ - SSRC: webrtc.SSRC(rr.Uint32()), + s.Encodings = &webrtc.RTPCodingParameters{ + SSRC: webrtc.SSRC(p.rand.Uint32()), PayloadType: remoteTrack.PayloadType(), } pld, err := json.Marshal(&s) @@ -463,9 +444,57 @@ func (p *Peer) Send(receiver *webrtc.RTPReceiver, remoteTrack *webrtc.TrackRemot return nil, err } - req := Request{ - Event: signalerEvent, - Payload: pld, + ctx, cancel := context.WithTimeout(context.Background(), time.Second*2) + defer cancel() + if _, err = p.Request(ctx, signalerRequestEvent, pld); err != nil { + return nil, err + } + + params := receiver.GetParameters() + + if err = sdr.Send(webrtc.RTPSendParameters{ + RTPParameters: params, + Encodings: []webrtc.RTPEncodingParameters{ + { + webrtc.RTPCodingParameters{ + SSRC: s.Encodings.SSRC, + PayloadType: s.Encodings.PayloadType, + }, + }, + }, + }); err != nil { + p.log.Error(err, "Send RTPSender failed") + } + + p.localTracks = append(p.localTracks, localTrack) + p.senders = append(p.senders, sdr) + return sdr, nil +} + +// Emit emits the data argument to remote peer. +func (p *Peer) Emit(event string, data []byte) error { + req := request{ + ID: p.rand.Uint64(), + Event: event, + Payload: data, + } + + msg, err := json.Marshal(req) + if err != nil { + return err + } + + if err = p.signalingDC.Send(msg); err != nil { + return err + } + return nil +} + +func (p *Peer) Request(ctx context.Context, event string, data []byte) ([]byte, error) { + req := request{ + ID: p.rand.Uint64(), + Event: event, + Payload: data, } msg, err := json.Marshal(req) @@ -477,83 +506,97 @@ func (p *Peer) Send(receiver *webrtc.RTPReceiver, remoteTrack *webrtc.TrackRemot return nil, err } - params := receiver.GetParameters() + resp := make(chan []byte, 1) - p.pendingSender[s.ID] = func() { - if err = sdr.Send(webrtc.RTPSendParameters{ - RTPParameters: params, - Encodings: []webrtc.RTPEncodingParameters{ - { - webrtc.RTPCodingParameters{ - SSRC: s.Signal.Encodings.SSRC, - PayloadType: s.Signal.Encodings.PayloadType, - }, - }, - }, - }); err != nil { - p.log.Error(err, "Send RTPSender failed") - } + p.rmu.Lock() + p.pendingRequests[req.ID] = resp + p.rmu.Unlock() + + defer func() { + p.rmu.Lock() + delete(p.pendingRequests, req.ID) + p.rmu.Unlock() + }() + + select { + case r := <-resp: + return r, nil + case <-ctx.Done(): + return nil, ctx.Err() } - p.localTracks = append(p.localTracks, localTrack) - p.senders = append(p.senders, sdr) - return sdr, nil } func (p *Peer) handleRequest(msg webrtc.DataChannelMessage) { - mr := &Request{} + mr := &request{} if err := json.Unmarshal(msg.Data, mr); err != nil { p.log.Error(err, "Error marshaling remote message", "peer_id", p.meta.PeerID, "session_id", p.meta.SessionID) return } - if mr.Event != signalerEvent { + if mr.Event == signalerRequestEvent && !mr.IsReply { p.mu.Lock() - if p.onRequest != nil { - p.onRequest(*mr) + defer p.mu.Unlock() + + r := &signal{} + if err := json.Unmarshal(mr.Payload, r); err != nil { + p.log.Error(err, "Error marshaling remote message", "peer_id", p.meta.PeerID, "session_id", p.meta.SessionID) + return + } + if err := p.receive(r); err != nil { + p.log.Error(err, "Error receiving remote track", "peer_id", p.meta.PeerID, "session_id", p.meta.SessionID) + return + } + if err := p.reply(mr.ID, mr.Event, nil); err != nil { + p.log.Error(err, "Error replying message", "peer_id", p.meta.PeerID, "session_id", p.meta.SessionID) + return } - p.mu.Unlock() - return - } - r := &signalRequest{} - if err := json.Unmarshal(mr.Payload, r); err != nil { - p.log.Error(err, "Error marshaling remote message", "peer_id", p.meta.PeerID, "session_id", p.meta.SessionID) return } - p.mu.Lock() - defer p.mu.Unlock() - - if r.Signal == nil { - if f, ok := p.pendingSender[r.ID]; ok { - f() + if mr.IsReply { + p.rmu.Lock() + if c, ok := p.pendingRequests[mr.ID]; ok { + c <- mr.Payload + delete(p.pendingRequests, mr.ID) } + p.rmu.Unlock() return } - if err := p.receive(r.Signal); err != nil { - return - } - rr := &signalRequest{ - ID: r.ID, - } - d, err := json.Marshal(rr) - if err != nil { - p.log.Error(err, "Error marshaling remote signalRequest", "peer_id", p.meta.PeerID, "session_id", p.meta.SessionID, "stream_id") + if mr.Event != signalerRequestEvent { + p.mu.Lock() + if p.onRequest != nil { + p.onRequest(mr.Event, Message{ + p: p, + event: mr.Event, + id: mr.ID, + msg: mr.Payload, + }) + } + p.mu.Unlock() return } - req := Request{ - Event: signalerEvent, - Payload: d, + +} + +func (p *Peer) reply(id uint64, event string, payload []byte) error { + req := request{ + ID: id, + Event: event, + Payload: payload, + IsReply: true, } - d, err = json.Marshal(req) + + msg, err := json.Marshal(req) if err != nil { - p.log.Error(err, "Error marshaling response Request", "peer_id", p.meta.PeerID, "session_id", p.meta.SessionID, "stream_id") - return + return err } - if err = p.signalingDC.Send(d); err != nil { - p.log.Error(err, "Error sending response", "peer_id", p.meta.PeerID, "session_id", p.meta.SessionID, "stream_id") + + if err = p.signalingDC.Send(msg); err != nil { + return err } + return nil } func joinErrs(errs ...error) error { @@ -580,3 +623,18 @@ func joinErrs(errs ...error) error { } return joinErrsR("", 0, errs...) } + +type Message struct { + p *Peer + event string + id uint64 + msg []byte +} + +func (m *Message) Payload() []byte { + return m.msg +} + +func (m *Message) Reply(msg []byte) error { + return m.p.reply(m.id, m.event, msg) +} diff --git a/pkg/sfu/publisher.go b/pkg/sfu/publisher.go index 0e86039b7..1170d5a51 100644 --- a/pkg/sfu/publisher.go +++ b/pkg/sfu/publisher.go @@ -232,7 +232,7 @@ func (p *Publisher) createRelayTrack(track *webrtc.TrackRemote, receiver Receive return err } - sdr, err := rp.Send(receiver.(*WebRTCReceiver).receiver, track, downTrack) + sdr, err := rp.AddTrack(receiver.(*WebRTCReceiver).receiver, track, downTrack) if err != nil { Logger.V(1).Error(err, "Relaying track.", "peer_id", p.id) return fmt.Errorf("relay: %w", err)