Skip to content

Commit

Permalink
Check for nil transceivers on get parameters
Browse files Browse the repository at this point in the history
For ORTC API senders does not have a transceiver causing
panics on getting parameters.
  • Loading branch information
ourwarmhouse committed Aug 8, 2021
1 parent 305fc18 commit 99633ae
Show file tree
Hide file tree
Showing 5 changed files with 146 additions and 96 deletions.
64 changes: 64 additions & 0 deletions ortc_datachannel_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
// +build !js

package webrtc

import (
"io"
"testing"
"time"

"github.com/pion/transport/test"
"github.com/stretchr/testify/assert"
)

func TestDataChannel_ORTCE2E(t *testing.T) {
lim := test.TimeOut(time.Second * 20)
defer lim.Stop()

report := test.CheckRoutines(t)
defer report()

stackA, stackB, err := newORTCPair()
assert.NoError(t, err)

awaitSetup := make(chan struct{})
awaitString := make(chan struct{})
awaitBinary := make(chan struct{})
stackB.sctp.OnDataChannel(func(d *DataChannel) {
close(awaitSetup)

d.OnMessage(func(msg DataChannelMessage) {
if msg.IsString {
close(awaitString)
} else {
close(awaitBinary)
}
})
})

assert.NoError(t, signalORTCPair(stackA, stackB))

var id uint16 = 1
dcParams := &DataChannelParameters{
Label: "Foo",
ID: &id,
}
channelA, err := stackA.api.NewDataChannel(stackA.sctp, dcParams)
assert.NoError(t, err)

<-awaitSetup

assert.NoError(t, channelA.SendText("ABC"))
assert.NoError(t, channelA.Send([]byte("ABC")))

<-awaitString
<-awaitBinary

assert.NoError(t, stackA.close())
assert.NoError(t, stackB.close())

// attempt to send when channel is closed
assert.Error(t, channelA.Send([]byte("ABC")), io.ErrClosedPipe)
assert.Error(t, channelA.SendText("test"), io.ErrClosedPipe)
assert.Error(t, channelA.ensureOpen(), io.ErrClosedPipe)
}
68 changes: 68 additions & 0 deletions ortc_media_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
// +build !js

package webrtc

import (
"context"
"testing"
"time"

"github.com/pion/transport/test"
"github.com/pion/webrtc/v3/pkg/media"
"github.com/stretchr/testify/assert"
)

func Test_ORTC_Media(t *testing.T) {
lim := test.TimeOut(time.Second * 20)
defer lim.Stop()

report := test.CheckRoutines(t)
defer report()

stackA, stackB, err := newORTCPair()
assert.NoError(t, err)

assert.NoError(t, stackA.api.mediaEngine.RegisterDefaultCodecs())
assert.NoError(t, stackB.api.mediaEngine.RegisterDefaultCodecs())

assert.NoError(t, signalORTCPair(stackA, stackB))

track, err := NewTrackLocalStaticSample(RTPCodecCapability{MimeType: MimeTypeVP8}, "video", "pion")
assert.NoError(t, err)

rtpSender, err := stackA.api.NewRTPSender(track, stackA.dtls)
assert.NoError(t, err)
assert.NoError(t, rtpSender.Send(rtpSender.GetParameters()))

rtpReceiver, err := stackB.api.NewRTPReceiver(RTPCodecTypeVideo, stackB.dtls)
assert.NoError(t, err)
assert.NoError(t, rtpReceiver.Receive(RTPReceiveParameters{Encodings: []RTPDecodingParameters{
{RTPCodingParameters: rtpSender.GetParameters().Encodings[0].RTPCodingParameters},
}}))

seenPacket, seenPacketCancel := context.WithCancel(context.Background())
go func() {
track := rtpReceiver.Track()
_, _, err := track.ReadRTP()
assert.NoError(t, err)

seenPacketCancel()
}()

func() {
for range time.Tick(time.Millisecond * 20) {
select {
case <-seenPacket.Done():
return
default:
assert.NoError(t, track.WriteSample(media.Sample{Data: []byte{0xAA}, Duration: time.Second}))
}
}
}()

assert.NoError(t, rtpSender.Stop())
assert.NoError(t, rtpReceiver.Stop())

assert.NoError(t, stackA.close())
assert.NoError(t, stackB.close())
}
94 changes: 4 additions & 90 deletions datachannel_ortc_test.go → ortc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,95 +3,9 @@
package webrtc

import (
"io"
"testing"
"time"

"github.com/pion/transport/test"
"github.com/pion/webrtc/v3/internal/util"
"github.com/stretchr/testify/assert"
)

func TestDataChannel_ORTCE2E(t *testing.T) {
// Limit runtime in case of deadlocks
lim := test.TimeOut(time.Second * 20)
defer lim.Stop()

report := test.CheckRoutines(t)
defer report()

stackA, stackB, err := newORTCPair()
if err != nil {
t.Fatal(err)
}

awaitSetup := make(chan struct{})
awaitString := make(chan struct{})
awaitBinary := make(chan struct{})
stackB.sctp.OnDataChannel(func(d *DataChannel) {
close(awaitSetup)

d.OnMessage(func(msg DataChannelMessage) {
if msg.IsString {
close(awaitString)
} else {
close(awaitBinary)
}
})
})

err = signalORTCPair(stackA, stackB)
if err != nil {
t.Fatal(err)
}

var id uint16 = 1
dcParams := &DataChannelParameters{
Label: "Foo",
ID: &id,
}
channelA, err := stackA.api.NewDataChannel(stackA.sctp, dcParams)
if err != nil {
t.Fatal(err)
}

<-awaitSetup

err = channelA.SendText("ABC")
if err != nil {
t.Fatal(err)
}
err = channelA.Send([]byte("ABC"))
if err != nil {
t.Fatal(err)
}
<-awaitString
<-awaitBinary

err = stackA.close()
if err != nil {
t.Fatal(err)
}

err = stackB.close()
if err != nil {
t.Fatal(err)
}

// attempt to send when channel is closed
err = channelA.Send([]byte("ABC"))
assert.Error(t, err)
assert.Equal(t, io.ErrClosedPipe, err)

err = channelA.SendText("test")
assert.Error(t, err)
assert.Equal(t, io.ErrClosedPipe, err)

err = channelA.ensureOpen()
assert.Error(t, err)
assert.Equal(t, io.ErrClosedPipe, err)
}

type testORTCStack struct {
api *API
gatherer *ICEGatherer
Expand Down Expand Up @@ -185,10 +99,10 @@ func (s *testORTCStack) close() error {
}

type testORTCSignal struct {
ICECandidates []ICECandidate `json:"iceCandidates"`
ICEParameters ICEParameters `json:"iceParameters"`
DTLSParameters DTLSParameters `json:"dtlsParameters"`
SCTPCapabilities SCTPCapabilities `json:"sctpCapabilities"`
ICECandidates []ICECandidate
ICEParameters ICEParameters
DTLSParameters DTLSParameters
SCTPCapabilities SCTPCapabilities
}

func newORTCPair() (stackA *testORTCStack, stackB *testORTCStack, err error) {
Expand Down
14 changes: 9 additions & 5 deletions rtpsender.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ type RTPSender struct {
api *API
id string

tr *RTPTransceiver
rtpTransceiver *RTPTransceiver

mu sync.RWMutex
sendCalled, stopCalled chan struct{}
Expand Down Expand Up @@ -90,10 +90,10 @@ func (r *RTPSender) setNegotiated() {
r.negotiated = true
}

func (r *RTPSender) setRTPTransceiver(tr *RTPTransceiver) {
func (r *RTPSender) setRTPTransceiver(rtpTransceiver *RTPTransceiver) {
r.mu.Lock()
defer r.mu.Unlock()
r.tr = tr
r.rtpTransceiver = rtpTransceiver
}

// Transport returns the currently-configured *DTLSTransport or nil
Expand All @@ -119,7 +119,11 @@ func (r *RTPSender) getParameters() RTPSendParameters {
},
},
}
sendParameters.Codecs = r.tr.getCodecs()
if r.rtpTransceiver != nil {
sendParameters.Codecs = r.rtpTransceiver.getCodecs()
} else {
sendParameters.Codecs = r.api.mediaEngine.getCodecsByKind(r.track.Kind())
}
return sendParameters
}

Expand All @@ -145,7 +149,7 @@ func (r *RTPSender) ReplaceTrack(track TrackLocal) error {
r.mu.Lock()
defer r.mu.Unlock()

if track != nil && r.tr.kind != track.Kind() {
if track != nil && r.rtpTransceiver.kind != track.Kind() {
return ErrRTPSenderNewTrackHasIncorrectKind
}

Expand Down
2 changes: 1 addition & 1 deletion rtpsender_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ func Test_RTPSender_ReplaceTrack_InvalidCodecChange(t *testing.T) {
rtpSender, err := sender.AddTrack(trackA)
assert.NoError(t, err)

err = rtpSender.tr.SetCodecPreferences([]RTPCodecParameters{{
err = rtpSender.rtpTransceiver.SetCodecPreferences([]RTPCodecParameters{{
RTPCodecCapability: RTPCodecCapability{MimeType: MimeTypeVP8},
PayloadType: 96,
}})
Expand Down

0 comments on commit 99633ae

Please sign in to comment.