diff --git a/platforms/dji/tello/driver_test.go b/platforms/dji/tello/driver_test.go index c8d25b2a8..18a584d6c 100644 --- a/platforms/dji/tello/driver_test.go +++ b/platforms/dji/tello/driver_test.go @@ -5,6 +5,7 @@ import ( "encoding/binary" "fmt" "io" + "sync" "testing" "time" @@ -14,6 +15,15 @@ import ( var _ gobot.Driver = (*Driver)(nil) +type WriteCloserDoNothing struct{} + +func (w *WriteCloserDoNothing) Write(p []byte) (n int, err error) { + return 0, nil +} +func (w *WriteCloserDoNothing) Close() error { + return nil +} + func TestTelloDriver(t *testing.T) { d := NewDriver("8888") @@ -134,3 +144,38 @@ func TestHandleResponse(t *testing.T) { }) } } + +func TestHaltShouldTerminateAllTheRelatedGoroutines(t *testing.T) { + d := NewDriver("8888") + d.cmdConn = &WriteCloserDoNothing{} + + var wg sync.WaitGroup + wg.Add(3) + go func() { + <-d.doneCh + wg.Done() + fmt.Println("Done routine 1.") + }() + go func() { + <-d.doneCh + wg.Done() + fmt.Println("Done routine 2.") + }() + go func() { + <-d.doneCh + wg.Done() + fmt.Println("Done routine 3.") + }() + + d.Halt() + wg.Wait() +} + +func TestHaltNotWaitForeverWhenCalledMultipleTimes(t *testing.T) { + d := NewDriver("8888") + d.cmdConn = &WriteCloserDoNothing{} + + d.Halt() + d.Halt() + d.Halt() +}