@@ -16,6 +16,7 @@ package conn
16
16
17
17
import (
18
18
"bufio"
19
+ "bytes"
19
20
"encoding/base64"
20
21
"fmt"
21
22
"io"
@@ -25,6 +26,8 @@ import (
25
26
"strings"
26
27
"sync"
27
28
"time"
29
+
30
+ "github.com/fatedier/frp/src/utils/pool"
28
31
)
29
32
30
33
type Listener struct {
@@ -61,11 +64,7 @@ func Listen(bindAddr string, bindPort int64) (l *Listener, err error) {
61
64
continue
62
65
}
63
66
64
- c := & Conn {
65
- TcpConn : conn ,
66
- closeFlag : false ,
67
- }
68
- c .Reader = bufio .NewReader (c .TcpConn )
67
+ c := NewConn (conn )
69
68
l .accept <- c
70
69
}
71
70
}()
@@ -95,20 +94,23 @@ func (l *Listener) Close() error {
95
94
type Conn struct {
96
95
TcpConn net.Conn
97
96
Reader * bufio.Reader
97
+ buffer * bytes.Buffer
98
98
closeFlag bool
99
- mutex sync.RWMutex
99
+
100
+ mutex sync.RWMutex
100
101
}
101
102
102
103
func NewConn (conn net.Conn ) (c * Conn ) {
103
- c = & Conn {}
104
- c .TcpConn = conn
104
+ c = & Conn {
105
+ TcpConn : conn ,
106
+ buffer : nil ,
107
+ closeFlag : false ,
108
+ }
105
109
c .Reader = bufio .NewReader (c .TcpConn )
106
- c .closeFlag = false
107
- return c
110
+ return
108
111
}
109
112
110
113
func ConnectServer (addr string ) (c * Conn , err error ) {
111
- c = & Conn {}
112
114
servertAddr , err := net .ResolveTCPAddr ("tcp" , addr )
113
115
if err != nil {
114
116
return
@@ -117,9 +119,7 @@ func ConnectServer(addr string) (c *Conn, err error) {
117
119
if err != nil {
118
120
return
119
121
}
120
- c .TcpConn = conn
121
- c .Reader = bufio .NewReader (c .TcpConn )
122
- c .closeFlag = false
122
+ c = NewConn (conn )
123
123
return c , nil
124
124
}
125
125
@@ -185,7 +185,23 @@ func (c *Conn) GetLocalAddr() (addr string) {
185
185
}
186
186
187
187
func (c * Conn ) Read (p []byte ) (n int , err error ) {
188
- n , err = c .Reader .Read (p )
188
+ c .mutex .RLock ()
189
+ if c .buffer == nil {
190
+ c .mutex .RUnlock ()
191
+ return c .Reader .Read (p )
192
+ }
193
+ c .mutex .RUnlock ()
194
+
195
+ n , err = c .buffer .Read (p )
196
+ if err == io .EOF {
197
+ c .mutex .Lock ()
198
+ c .buffer = nil
199
+ c .mutex .Unlock ()
200
+ var n2 int
201
+ n2 , err = c .Reader .Read (p [n :])
202
+
203
+ n += n2
204
+ }
189
205
return
190
206
}
191
207
@@ -212,6 +228,16 @@ func (c *Conn) WriteString(content string) (err error) {
212
228
return err
213
229
}
214
230
231
+ func (c * Conn ) AppendReaderBuffer (content []byte ) {
232
+ c .mutex .Lock ()
233
+ defer c .mutex .Unlock ()
234
+
235
+ if c .buffer == nil {
236
+ c .buffer = bytes .NewBuffer (make ([]byte , 0 , 2048 ))
237
+ }
238
+ c .buffer .Write (content )
239
+ }
240
+
215
241
func (c * Conn ) SetDeadline (t time.Time ) error {
216
242
return c .TcpConn .SetDeadline (t )
217
243
}
@@ -238,22 +264,36 @@ func (c *Conn) IsClosed() (closeFlag bool) {
238
264
}
239
265
240
266
// when you call this function, you should make sure that
241
- // remote client won't send any bytes to this socket
267
+ // no bytes were read before
242
268
func (c * Conn ) CheckClosed () bool {
243
269
c .mutex .RLock ()
244
270
if c .closeFlag {
271
+ c .mutex .RUnlock ()
245
272
return true
246
273
}
247
274
c .mutex .RUnlock ()
248
275
276
+ tmp := pool .GetBuf (2048 )
277
+ defer pool .PutBuf (tmp )
249
278
err := c .TcpConn .SetReadDeadline (time .Now ().Add (time .Millisecond ))
250
279
if err != nil {
251
280
c .Close ()
252
281
return true
253
282
}
254
283
255
- var tmp []byte = make ([]byte , 1 )
256
- _ , err = c .TcpConn .Read (tmp )
284
+ n , err := c .TcpConn .Read (tmp )
285
+ if err == io .EOF {
286
+ return true
287
+ }
288
+
289
+ var tmp2 []byte = make ([]byte , 1 )
290
+ err = c .TcpConn .SetReadDeadline (time .Now ().Add (time .Millisecond ))
291
+ if err != nil {
292
+ c .Close ()
293
+ return true
294
+ }
295
+
296
+ n2 , err := c .TcpConn .Read (tmp2 )
257
297
if err == io .EOF {
258
298
return true
259
299
}
@@ -263,5 +303,12 @@ func (c *Conn) CheckClosed() bool {
263
303
c .Close ()
264
304
return true
265
305
}
306
+
307
+ if n > 0 {
308
+ c .AppendReaderBuffer (tmp [:n ])
309
+ }
310
+ if n2 > 0 {
311
+ c .AppendReaderBuffer (tmp2 [:n2 ])
312
+ }
266
313
return false
267
314
}
0 commit comments