diff --git a/core/stream.go b/core/stream.go index 5c773cd2..03a19793 100644 --- a/core/stream.go +++ b/core/stream.go @@ -1,6 +1,12 @@ package core -import "net" +import ( + "bytes" + "io" + "net" + + "github.com/shadowsocks/go-shadowsocks2/socks" +) type listener struct { net.Listener @@ -21,3 +27,47 @@ func Dial(network, address string, ciph StreamConnCipher) (net.Conn, error) { c, err := net.Dial(network, address) return ciph.StreamConn(c), err } + +// Connect sends the shadowsocks standard header to underlying ciphered connection +// in the next Write/ReadFrom call. +func Connect(c net.Conn, addr socks.Addr) net.Conn { + return &ssconn{Conn: c, addr: addr} +} + +type ssconn struct { + net.Conn + addr socks.Addr +} + +func (c *ssconn) Write(b []byte) (int, error) { + n, err := c.ReadFrom(bytes.NewBuffer(b)) + return int(n), err +} + +func (c *ssconn) ReadFrom(r io.Reader) (int64, error) { + if len(c.addr) > 0 { + r = &readerWithAddr{Reader: r, b: c.addr} + c.addr = nil + } + return io.Copy(c.Conn, r) +} + +func (c *ssconn) WriteTo(w io.Writer) (int64, error) { + return io.Copy(w, c.Conn) +} + +type readerWithAddr struct { + io.Reader + b []byte +} + +func (r *readerWithAddr) Read(b []byte) (n int, err error) { + nc := copy(b, r.b) + if nc < len(r.b) { + r.b = r.b[:nc] + return nc, nil + } + r.b = nil + nr, err := r.Reader.Read(b[nc:]) + return nc + nr, err +} diff --git a/tcp.go b/tcp.go index e8892011..b410c83f 100644 --- a/tcp.go +++ b/tcp.go @@ -5,6 +5,7 @@ import ( "net" "time" + "github.com/shadowsocks/go-shadowsocks2/core" "github.com/shadowsocks/go-shadowsocks2/socks" ) @@ -72,10 +73,11 @@ func tcpLocal(addr, server string, shadow func(net.Conn) net.Conn, getAddr func( defer rc.Close() rc.(*net.TCPConn).SetKeepAlive(true) rc = shadow(rc) + // Connect to target + rc = core.Connect(rc, tgt) logf("proxy %s <-> %s <-> %s", c.RemoteAddr(), server, tgt) - ca := &connWithAddr{Conn: c, addr: tgt} - _, _, err = relay(rc, ca) + _, _, err = relay(rc, c) if err != nil { if err, ok := err.(net.Error); ok && err.Timeout() { return // ignore i/o timeout @@ -159,21 +161,3 @@ func relay(left, right net.Conn) (int64, int64, error) { } return n, rs.N, err } - -type connWithAddr struct { - net.Conn - addr socks.Addr -} - -// Read reads the addr and data from the connection. -// The format of output is aligned with shadowsocks protocol. -func (c *connWithAddr) Read(b []byte) (n int, err error) { - nc := copy(b, c.addr) - if nc < len(c.addr) { - c.addr = c.addr[:nc] - return nc, nil - } - c.addr = nil - nr, err := c.Conn.Read(b[nc:]) - return nc + nr, err -}