From 638673624a53cc8ddf42a92aa6832fa21d45edcb Mon Sep 17 00:00:00 2001 From: Nikola Novakovic Date: Wed, 12 May 2021 11:23:43 -1000 Subject: [PATCH] Simulcast: Allow layer selection for publisher (#477) * Simulcast: Allow layer selection for publisher Changed ion-sfu channel messaging to be peer based and added additional types Addressing PR comments Rename IonSfuMessage to ChannelAPIMessage. Use interface{} instead of string, fixes double marshalling issue * Rebase on new interface changes --- pkg/middlewares/datachannel/subscriberapi.go | 161 +++++++++++++++---- pkg/sfu/downtrack.go | 35 ++++ pkg/sfu/peer.go | 22 +++ pkg/sfu/session.go | 13 +- 4 files changed, 195 insertions(+), 36 deletions(-) diff --git a/pkg/middlewares/datachannel/subscriberapi.go b/pkg/middlewares/datachannel/subscriberapi.go index 9b83a7c13..2a45a0c24 100644 --- a/pkg/middlewares/datachannel/subscriberapi.go +++ b/pkg/middlewares/datachannel/subscriberapi.go @@ -3,23 +3,92 @@ package datachannel import ( "context" "encoding/json" + "fmt" "github.com/pion/ion-sfu/pkg/sfu" "github.com/pion/webrtc/v3" ) const ( - highValue = "high" - mediumValue = "medium" - lowValue = "low" - mutedValue = "none" + highValue = "high" + mediumValue = "medium" + lowValue = "low" + mutedValue = "none" + ActiveLayerMethod = "activeLayer" ) type setRemoteMedia struct { - StreamID string `json:"streamId"` - Video string `json:"video"` - Framerate string `json:"framerate"` - Audio bool `json:"audio"` + StreamID string `json:"streamId"` + Video string `json:"video"` + Framerate string `json:"framerate"` + Audio bool `json:"audio"` + Layers []string `json:"layers"` +} + +type activeLayerMessage struct { + StreamID string `json:"streamId"` + ActiveLayer string `json:"activeLayer"` + AvailableLayers []string `json:"availableLayers"` +} + +func layerStrToInt(layer string) (int, error) { + switch layer { + case highValue: + return 2, nil + case mediumValue: + return 1, nil + case lowValue: + return 0, nil + default: + // 
unknown value + return -1, fmt.Errorf("Unknown value") + } +} + +func layerIntToStr(layer int) (string, error) { + switch layer { + case 0: + return lowValue, nil + case 1: + return mediumValue, nil + case 2: + return highValue, nil + default: + return "", fmt.Errorf("Unknown value: %d", layer) + } +} + +func transformLayers(layers []string) ([]uint16, error) { + res := make([]uint16, 0, len(layers)) // length 0, capacity n: append must not land after zero-filled slots + for _, layer := range layers { + if l, err := layerStrToInt(layer); err == nil { + res = append(res, uint16(l)) + } else { + return nil, fmt.Errorf("Unknown layer value: %v", layer) + } + } + return res, nil +} + +func sendMessage(streamID string, peer sfu.Peer, layers []string, activeLayer int) { + al, _ := layerIntToStr(activeLayer) + payload := activeLayerMessage{ + StreamID: streamID, + ActiveLayer: al, + AvailableLayers: layers, + } + msg := sfu.ChannelAPIMessage{ + Method: ActiveLayerMethod, + Params: payload, + } + bytes, err := json.Marshal(msg) + if err != nil { + sfu.Logger.Error(err, "unable to marshal active layer message") + return // marshal failed: there is no payload to send + } + if err := peer.SendAPIChannelMessage(&bytes); err != nil { + sfu.Logger.Error(err, "unable to send ActiveLayerMessage to peer", "peer_id", peer.ID()) + } } func SubscriberAPI(next sfu.MessageProcessor) sfu.MessageProcessor { @@ -28,35 +97,59 @@ func SubscriberAPI(next sfu.MessageProcessor) sfu.MessageProcessor { if err := json.Unmarshal(args.Message.Data, srm); err != nil { return } - downTracks := args.Peer.Subscriber().GetDownTracks(srm.StreamID) - for _, dt := range downTracks { - switch dt.Kind() { - case webrtc.RTPCodecTypeAudio: - dt.Mute(!srm.Audio) - case webrtc.RTPCodecTypeVideo: - switch srm.Video { - case highValue: - dt.Mute(false) - dt.SwitchSpatialLayer(2, true) - case mediumValue: - dt.Mute(false) - dt.SwitchSpatialLayer(1, true) - case lowValue: - dt.Mute(false) - dt.SwitchSpatialLayer(0, true) - case mutedValue: - dt.Mute(true) - } - switch srm.Framerate { - case highValue: - dt.SwitchTemporalLayer(2, true) - case 
mediumValue: - dt.SwitchTemporalLayer(1, true) - case lowValue: - dt.SwitchTemporalLayer(0, true) + // Publisher changing active layers + if srm.Layers != nil && len(srm.Layers) > 0 { + layers, err := transformLayers(srm.Layers) + if err != nil { + sfu.Logger.Error(err, "error reading layers") + next.Process(ctx, args) + return + } + + session := args.Peer.Session() + peers := session.Peers() + for _, peer := range peers { + if peer.ID() != args.Peer.ID() { + downTracks := peer.Subscriber().GetDownTracks(srm.StreamID) + for _, dt := range downTracks { + if dt.Kind() == webrtc.RTPCodecTypeVideo { + newLayer, _ := dt.UptrackLayersChange(layers) + sendMessage(srm.StreamID, peer, srm.Layers, int(newLayer)) + } + } } } + } else { + downTracks := args.Peer.Subscriber().GetDownTracks(srm.StreamID) + for _, dt := range downTracks { + switch dt.Kind() { + case webrtc.RTPCodecTypeAudio: + dt.Mute(!srm.Audio) + case webrtc.RTPCodecTypeVideo: + switch srm.Video { + case highValue: + dt.Mute(false) + dt.SwitchSpatialLayer(2, true) + case mediumValue: + dt.Mute(false) + dt.SwitchSpatialLayer(1, true) + case lowValue: + dt.Mute(false) + dt.SwitchSpatialLayer(0, true) + case mutedValue: + dt.Mute(true) + } + switch srm.Framerate { + case highValue: + dt.SwitchTemporalLayer(2, true) + case mediumValue: + dt.SwitchTemporalLayer(1, true) + case lowValue: + dt.SwitchTemporalLayer(0, true) + } + } + } } next.Process(ctx, args) }) diff --git a/pkg/sfu/downtrack.go b/pkg/sfu/downtrack.go index eddb98a20..42e45a2fd 100644 --- a/pkg/sfu/downtrack.go +++ b/pkg/sfu/downtrack.go @@ -1,6 +1,7 @@ package sfu import ( + "fmt" "strings" "sync" "sync/atomic" @@ -211,6 +212,40 @@ func (d *DownTrack) SwitchSpatialLayer(targetLayer int64, setAsMax bool) { } } +func (d *DownTrack) UptrackLayersChange(availableLayers []uint16) (int64, error) { + if d.trackType == SimulcastDownTrack { + currentLayer := uint16(atomic.LoadInt32(&d.spatialLayer)) + maxLayer := uint16(atomic.LoadInt64(&d.maxSpatialLayer)) + 
+ var maxFound uint16 = 0 + layerFound, aboveFound := false, false + var minFound uint16 = 0 + for _, target := range availableLayers { + if target <= maxLayer { + if !layerFound || target > maxFound { // layer 0 must count as a match too + maxFound = target + layerFound = true + } + } else { + if !aboveFound || target < minFound { // keep the lowest layer above maxLayer + minFound, aboveFound = target, true + } + } + } + var targetLayer uint16 + if layerFound { + targetLayer = maxFound + } else { + targetLayer = minFound + } + if currentLayer != targetLayer { + d.SwitchSpatialLayer(int64(targetLayer), false) + } + return int64(targetLayer), nil + } + return -1, fmt.Errorf("Downtrack %s does not support simulcast", d.id) +} + func (d *DownTrack) SwitchTemporalLayer(targetLayer int64, setAsMax bool) { if d.trackType == SimulcastDownTrack { layer := atomic.LoadInt32(&d.temporalLayer) diff --git a/pkg/sfu/peer.go b/pkg/sfu/peer.go index 16140ed7a..c7a3501d3 100644 --- a/pkg/sfu/peer.go +++ b/pkg/sfu/peer.go @@ -30,6 +30,7 @@ type Peer interface { Publisher() *Publisher Subscriber() *Subscriber Close() error + SendAPIChannelMessage(msg *[]byte) error } // JoinConfig allow adding more control to the peers joining a SessionLocal. 
@@ -48,6 +49,11 @@ type SessionProvider interface { GetSession(sid string) (Session, WebRTCTransportConfig) } +type ChannelAPIMessage struct { + Method string `json:"method"` + Params interface{} `json:"params,omitempty"` +} + // PeerLocal represents a pair peer connection type PeerLocal struct { sync.Mutex @@ -241,6 +247,22 @@ func (p *PeerLocal) Trickle(candidate webrtc.ICECandidateInit, target int) error return nil } +func (p *PeerLocal) SendAPIChannelMessage(msg *[]byte) error { + if p.subscriber == nil { + return fmt.Errorf("No subscriber for this peer") + } + dc := p.subscriber.DataChannel(APIChannelLabel) + + if dc == nil { + return fmt.Errorf("Data channel %s doesn't exist", APIChannelLabel) + } + + if err := dc.SendText(string(*msg)); err != nil { + return fmt.Errorf("Failed to send message: %w", err) // %w keeps the cause reachable via errors.Is/As + } + return nil +} + // Close shuts down the peer connection and sends true to the done channel func (p *PeerLocal) Close() error { p.Lock() diff --git a/pkg/sfu/session.go b/pkg/sfu/session.go index 663f37f26..c0033788a 100644 --- a/pkg/sfu/session.go +++ b/pkg/sfu/session.go @@ -22,6 +22,7 @@ type Session interface { GetDCMiddlewares() []*Datachannel GetDataChannelLabels() []string GetDataChannels(origin, label string) (dcs []*webrtc.DataChannel) + Peers() []Peer } type SessionLocal struct { @@ -35,6 +36,10 @@ type SessionLocal struct { onCloseHandler func() } +const ( + AudioLevelsMethod = "audioLevels" +) + // NewSession creates a new SessionLocal func NewSession(id string, dcs []*Datachannel, cfg WebRTCTransportConfig) Session { s := &SessionLocal{ @@ -45,7 +50,6 @@ func NewSession(id string, dcs []*Datachannel, cfg WebRTCTransportConfig) Sessio } go s.audioLevelObserver(cfg.Router.AudioLevelInterval) return s - } // ID return SessionLocal id @@ -251,7 +255,12 @@ func (s *SessionLocal) audioLevelObserver(audioLevelInterval int) { continue } - l, err := json.Marshal(&levels) + msg := ChannelAPIMessage{ + Method: AudioLevelsMethod, + Params: levels, + } + 
+ l, err := json.Marshal(&msg) if err != nil { Logger.Error(err, "Marshaling audio levels err") continue