Skip to content

Commit 53f6d18

Browse files
authored
Merge pull request #16 from DiscoRiver/stream-bug
Fix Stream bug
2 parents 5ce9623 + d646cba commit 53f6d18

File tree

5 files changed

+81
-36
lines changed

5 files changed

+81
-36
lines changed

.github/workflows/go-test.yml

+1-5
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,6 @@
11
name: Go
22

3-
on:
4-
push:
5-
branches: [ master ]
6-
pull_request:
7-
branches: [ master ]
3+
on: [push, pull_request]
84

95
jobs:
106
build-and-test:

_examples/example_jobstack_streaming/main.go

+1-4
Original file line numberDiff line numberDiff line change
@@ -57,10 +57,7 @@ func main() {
5757
fmt.Printf("%s: %s\n", result.Host, result.Error)
5858
wg.Done()
5959
} else {
60-
err := readStream(result, &wg)
61-
if err != nil {
62-
panic(err)
63-
}
60+
readStream(result, &wg)
6461
}
6562
}()
6663
default:

_examples/example_streaming/main.go

+1-4
Original file line numberDiff line numberDiff line change
@@ -54,10 +54,7 @@ func main() {
5454
fmt.Printf("%s: %s\n", result.Host, result.Error)
5555
wg.Done()
5656
} else {
57-
err := readStream(result, &wg)
58-
if err != nil {
59-
fmt.Println(err)
60-
}
57+
readStream(result, &wg)
6158
}
6259
}()
6360
default:

session.go

+22-8
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
package massh
22

33
import (
4+
"bufio"
45
"bytes"
56
"fmt"
67
"io"
8+
"sync"
79
)
810

911
var (
@@ -133,8 +135,12 @@ func sshCommandStream(host string, config *Config, resultChannel chan Result) {
133135
// Reading from our pipes as they're populated, and redirecting bytes to our stdout and stderr channels in Result.
134136
//
135137
// We're doing this before we start the ssh task so we can start churning through output as soon as it starts.
136-
go readToBytesChannel(StdOutPipe, r.StdOutStream, r)
137-
go readToBytesChannel(StdErrPipe, r.StdErrStream, r)
138+
var wg sync.WaitGroup
139+
wg.Add(2)
140+
go func() {
141+
readToBytesChannel(StdOutPipe, r.StdOutStream, r, &wg)
142+
readToBytesChannel(StdErrPipe, r.StdErrStream, r, &wg)
143+
}()
138144

139145
resultChannel <- r
140146

@@ -145,21 +151,29 @@ func sshCommandStream(host string, config *Config, resultChannel chan Result) {
145151
}
146152

147153
// Wait for the command to exit only after we've initiated all the output channels
154+
wg.Wait()
148155
session.Wait()
149156

150157
NumberOfStreamingHostsCompleted++
151158
}
152159

153160
// readToBytesChannel reads from io.Reader and directs the data to a byte slice channel for streaming.
154-
func readToBytesChannel(reader io.Reader, stream chan []byte, r Result) {
155-
var data = make([]byte, 1024)
161+
func readToBytesChannel(reader io.Reader, stream chan []byte, r Result, wg *sync.WaitGroup) {
162+
defer func(){ wg.Done() }()
163+
164+
rdr := bufio.NewReader(reader)
165+
156166
for {
157-
n, err := reader.Read(data)
167+
line, err := rdr.ReadBytes('\n')
158168
if err != nil {
159-
r.Error = fmt.Errorf("couldn't read content to stream channel: %s", err)
160-
return
169+
if err == io.EOF {
170+
return
171+
} else {
172+
r.Error = fmt.Errorf("couldn't read content to stream channel: %s", err)
173+
return
174+
}
161175
}
162-
stream <- data[:n]
176+
stream <- line
163177
}
164178
}
165179

session_test.go

+56-15
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package massh
22

33
import (
4+
"fmt"
45
"golang.org/x/crypto/ssh"
56
"strings"
67
"sync"
@@ -71,11 +72,7 @@ func TestSshCommandStream(t *testing.T) {
7172

7273
wg.Done()
7374
} else {
74-
err := readStream(result, &wg, t)
75-
if err != nil {
76-
t.Log(err)
77-
t.FailNow()
78-
}
75+
readStream(result, &wg, t)
7976
}
8077
}()
8178
default:
@@ -90,14 +87,62 @@ func TestSshCommandStream(t *testing.T) {
9087
}
9188
}
9289

93-
func readStream(res Result, wg *sync.WaitGroup, t *testing.T) error {
90+
// Test for bugs in lots of lines.
91+
func TestSshCommandStreamBigData(t *testing.T) {
92+
defer func() {testConfig.Job = testJob}()
93+
NumberOfStreamingHostsCompleted = 0
94+
95+
if err := testConfig.SetPrivateKeyAuth("~/.ssh/id_rsa", ""); err != nil {
96+
t.Log(err)
97+
t.FailNow()
98+
}
99+
100+
testConfig.Job = &Job{
101+
Command: "cat /var/log/auth.log",
102+
}
103+
104+
resChan := make(chan Result)
105+
106+
// This should be the last responsibility from the massh package. Handling the Result channel is up to the user.
107+
err := testConfig.Stream(resChan)
108+
if err != nil {
109+
t.Log(err)
110+
t.FailNow()
111+
}
112+
113+
var wg sync.WaitGroup
114+
// This can probably be cleaner. We're hindered somewhat, I think, by reading a channel from a channel.
94115
for {
95116
select {
96-
case d := <-res.StdOutStream:
97-
if !strings.Contains(string(d), "Hello, World") {
98-
t.Logf("Expected output from stream test not received from host %s: %s", res.Host, d)
99-
t.Fail()
117+
case result := <-resChan:
118+
wg.Add(1)
119+
go func() {
120+
if result.Error != nil {
121+
t.Logf("Unexpected error in stream test for host %s: %s", result.Host, result.Error)
122+
t.Fail()
123+
124+
wg.Done()
125+
} else {
126+
readStream(result, &wg, t)
127+
}
128+
}()
129+
default:
130+
if NumberOfStreamingHostsCompleted == len(testConfig.Hosts) {
131+
// We want to wait for all goroutines to complete before we declare that the work is finished, as
132+
// it's possible for us to execute this code before the gofunc above has completed if left unchecked.
133+
wg.Wait()
134+
135+
return
100136
}
137+
}
138+
}
139+
}
140+
141+
func readStream(res Result, wg *sync.WaitGroup, t *testing.T) {
142+
for {
143+
select {
144+
case d := <-res.StdOutStream:
145+
fmt.Print(string(d))
101146
case <-res.DoneChannel:
102147
wg.Done()
103148
}
@@ -232,11 +277,7 @@ func TestSshCommandStreamWithJobStack(t *testing.T) {
232277

233278
wg.Done()
234279
} else {
235-
err := readStream(result, &wg, t)
236-
if err != nil {
237-
t.Log(err)
238-
t.FailNow()
239-
}
280+
readStream(result, &wg, t)
240281
}
241282
}()
242283
default:

0 commit comments

Comments
 (0)