diff --git a/.travis.yml b/.travis.yml index dea91125..30d8e2b7 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,11 +1,11 @@ language: go go: - - 1.4.3 + - 1.7.4 install: - go get golang.org/x/crypto/blowfish - go get golang.org/x/crypto/cast5 - go get golang.org/x/crypto/salsa20 - - go get github.com/codahale/chacha20 + - go get github.com/Yawning/chacha20 - go install ./cmd/shadowsocks-local - go install ./cmd/shadowsocks-server script: diff --git a/CHANGELOG b/CHANGELOG index b6886994..4d5ecc4a 100644 --- a/CHANGELOG +++ b/CHANGELOG @@ -1,3 +1,13 @@ +1.2.0 (2017-01-20) + * Support UDP reley on server side, and OTA + * Support "aes-[128/192/256]-ctr" encryption method (Thanks for @slurin) + * Support "chacha20-ietf" encryption method + * Improve performance of "chacha20" encryption method + * Corrently close connection if handshake failed + +1.1.5 (2016-05-04) + * Support OTA (Thanks for @ayanamist for implementing this feature) + 1.1.4 (2015-05-10) * Support "chacha20" encryption method, thanks to @defia * Support "salsa20" encryption method, thanks to @genzj diff --git a/README.md b/README.md index 21128964..185531bc 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ # shadowsocks-go -Current version: 1.1.5 [![Build Status](https://travis-ci.org/shadowsocks/shadowsocks-go.png?branch=master)](https://travis-ci.org/shadowsocks/shadowsocks-go) +Current version: 1.2.0 [![Build Status](https://travis-ci.org/shadowsocks/shadowsocks-go.png?branch=master)](https://travis-ci.org/shadowsocks/shadowsocks-go) shadowsocks-go is a lightweight tunnel proxy which can help you get through firewalls. It is a port of [shadowsocks](https://github.com/clowwindy/shadowsocks). @@ -62,6 +62,12 @@ Append `-auth` to the encryption method to enable [One Time Auth (OTA)](https:// - For server: this will **force client use OTA**, non-OTA connection will be dropped. Otherwise, both OTA and non-OTA clients can connect - For client: the `-A` command line option can also enable OTA +### UDP relay + +Use `-u` command line options when starting server to enable UDP relay. + +Currently only tested with Shadowsocks-Android, if you have encountered any problem, please report. + ## Command line options Command line options can override settings from configuration files. Use `-h` option to see all available options. diff --git a/cmd/shadowsocks-server/server.go b/cmd/shadowsocks-server/server.go index 3f74fe68..0477b67b 100644 --- a/cmd/shadowsocks-server/server.go +++ b/cmd/shadowsocks-server/server.go @@ -37,14 +37,15 @@ const ( ) var debug ss.DebugLog +var udp bool func getRequest(conn *ss.Conn, auth bool) (host string, ota bool, err error) { ss.SetReadTimeout(conn) // buf size should at least have the same size with the largest possible // request size (when addrType is 3, domain name has at most 256 bytes) - // 1(addrType) + 1(lenByte) + 256(max length address) + 2(port) + 10(hmac-sha1) - buf := make([]byte, 270) + // 1(addrType) + 1(lenByte) + 255(max length address) + 2(port) + 10(hmac-sha1) + buf := make([]byte, 269) // read till we get possible domain length field if _, err = io.ReadFull(conn, buf[:idType+1]); err != nil { return @@ -61,7 +62,7 @@ func getRequest(conn *ss.Conn, auth bool) (host string, ota bool, err error) { if _, err = io.ReadFull(conn, buf[idType+1:idDmLen+1]); err != nil { return } - reqStart, reqEnd = idDm0, int(idDm0+buf[idDmLen]+lenDmBase) + reqStart, reqEnd = idDm0, idDm0+int(buf[idDmLen])+lenDmBase default: err = fmt.Errorf("addr type %d not supported", addrType&ss.AddrMask) return @@ -80,7 +81,7 @@ func getRequest(conn *ss.Conn, auth bool) (host string, ota bool, err error) { case typeIPv6: host = net.IP(buf[idIP0 : idIP0+net.IPv6len]).String() case typeDm: - host = string(buf[idDm0 : idDm0+buf[idDmLen]]) + host = string(buf[idDm0 : idDm0+int(buf[idDmLen])]) } // parse port port := binary.BigEndian.Uint16(buf[reqEnd-2 : reqEnd]) @@ -138,6 +139,13 @@ func handleConnection(conn *ss.Conn, auth bool) { host, ota, err := getRequest(conn, auth) if err != nil { log.Println("error getting request", conn.RemoteAddr(), conn.LocalAddr(), err) + closed = true + return + } + // ensure the host does not contain some illegal characters, NUL may panic on Win32 + if strings.ContainsRune(host, 0x00) { + log.Println("invalid domain name.") + closed = true return } debug.Println("connecting", host) @@ -175,9 +183,15 @@ type PortListener struct { listener net.Listener } +type UDPListener struct { + password string + listener *net.UDPConn +} + type PasswdManager struct { sync.Mutex portListener map[string]*PortListener + udpListener map[string]*UDPListener } func (pm *PasswdManager) add(port, password string, listener net.Listener) { @@ -186,6 +200,12 @@ func (pm *PasswdManager) add(port, password string, listener net.Listener) { pm.Unlock() } +func (pm *PasswdManager) addUDP(port, password string, listener *net.UDPConn) { + pm.Lock() + pm.udpListener[port] = &UDPListener{password, listener} + pm.Unlock() +} + func (pm *PasswdManager) get(port string) (pl *PortListener, ok bool) { pm.Lock() pl, ok = pm.portListener[port] @@ -193,14 +213,31 @@ func (pm *PasswdManager) get(port string) (pl *PortListener, ok bool) { return } +func (pm *PasswdManager) getUDP(port string) (pl *UDPListener, ok bool) { + pm.Lock() + pl, ok = pm.udpListener[port] + pm.Unlock() + return +} + func (pm *PasswdManager) del(port string) { pl, ok := pm.get(port) if !ok { return } + if udp { + upl, ok := pm.getUDP(port) + if !ok { + return + } + upl.listener.Close() + } pl.listener.Close() pm.Lock() delete(pm.portListener, port) + if udp { + delete(pm.udpListener, port) + } pm.Unlock() } @@ -222,9 +259,14 @@ func (pm *PasswdManager) updatePortPasswd(port, password string, auth bool) { // run will add the new port listener to passwdManager. // So there maybe concurrent access to passwdManager and we need lock to protect it. go run(port, password, auth) + if udp { + pl, _ := pm.getUDP(port) + pl.listener.Close() + go runUDP(port, password, auth) + } } -var passwdManager = PasswdManager{portListener: map[string]*PortListener{}} +var passwdManager = PasswdManager{portListener: map[string]*PortListener{}, udpListener: map[string]*UDPListener{}} func updatePasswd() { log.Println("updating password") @@ -297,6 +339,33 @@ func run(port, password string, auth bool) { } } +func runUDP(port, password string, auth bool) { + var cipher *ss.Cipher + port_i, _ := strconv.Atoi(port) + log.Printf("listening udp port %v\n", port) + conn, err := net.ListenUDP("udp", &net.UDPAddr{ + IP: net.IPv6zero, + Port: port_i, + }) + passwdManager.addUDP(port, password, conn) + if err != nil { + log.Printf("error listening udp port %v: %v\n", port, err) + return + } + defer conn.Close() + cipher, err = ss.NewCipher(config.Method, password) + if err != nil { + log.Printf("Error generating cipher for udp port: %s %v\n", port, err) + conn.Close() + } + SecurePacketConn := ss.NewSecurePacketConn(conn, cipher.Copy(), auth) + for { + if err := ss.ReadAndHandleUDPReq(SecurePacketConn); err != nil { + debug.Println(err) + } + } +} + func enoughOptions(config *ss.Config) bool { return config.ServerPort != 0 && config.Password != "" } @@ -335,7 +404,7 @@ func main() { flag.StringVar(&cmdConfig.Method, "m", "", "encryption method, default: aes-256-cfb") flag.IntVar(&core, "core", 0, "maximum number of CPU cores to use, default is determinied by Go runtime") flag.BoolVar((*bool)(&debug), "d", false, "print debug message") - + flag.BoolVar(&udp, "u", false, "UDP Relay") flag.Parse() if printVer { @@ -358,6 +427,7 @@ func main() { os.Exit(1) } config = &cmdConfig + ss.UpdateConfig(config, config) } else { ss.UpdateConfig(config, &cmdConfig) } @@ -376,6 +446,9 @@ func main() { } for port, password := range config.PortPassword { go run(port, password, config.Auth) + if udp { + go runUDP(port, password, config.Auth) + } } waitSignal() diff --git a/shadowsocks/encrypt.go b/shadowsocks/encrypt.go index 146947d3..44d96a74 100644 --- a/shadowsocks/encrypt.go +++ b/shadowsocks/encrypt.go @@ -12,7 +12,7 @@ import ( "io" "strings" - "github.com/codahale/chacha20" + "github.com/Yawning/chacha20" "golang.org/x/crypto/blowfish" "golang.org/x/crypto/cast5" "golang.org/x/crypto/salsa20/salsa" @@ -65,11 +65,19 @@ func newStream(block cipher.Block, err error, key, iv []byte, } } -func newAESStream(key, iv []byte, doe DecOrEnc) (cipher.Stream, error) { +func newAESCFBStream(key, iv []byte, doe DecOrEnc) (cipher.Stream, error) { block, err := aes.NewCipher(key) return newStream(block, err, key, iv, doe) } +func newAESCTRStream(key, iv []byte, doe DecOrEnc) (cipher.Stream, error) { + block, err := aes.NewCipher(key) + if err != nil { + return nil, err + } + return cipher.NewCTR(block, iv), nil +} + func newDESStream(key, iv []byte, doe DecOrEnc) (cipher.Stream, error) { block, err := des.NewCipher(key) return newStream(block, err, key, iv, doe) @@ -95,7 +103,11 @@ func newRC4MD5Stream(key, iv []byte, _ DecOrEnc) (cipher.Stream, error) { } func newChaCha20Stream(key, iv []byte, _ DecOrEnc) (cipher.Stream, error) { - return chacha20.New(key, iv) + return chacha20.NewCipher(key, iv) +} + +func newChaCha20IETFStream(key, iv []byte, _ DecOrEnc) (cipher.Stream, error) { + return chacha20.NewCipher(key, iv) } type salsaStreamCipher struct { @@ -145,15 +157,19 @@ type cipherInfo struct { } var cipherMethod = map[string]*cipherInfo{ - "aes-128-cfb": {16, 16, newAESStream}, - "aes-192-cfb": {24, 16, newAESStream}, - "aes-256-cfb": {32, 16, newAESStream}, - "des-cfb": {8, 8, newDESStream}, - "bf-cfb": {16, 8, newBlowFishStream}, - "cast5-cfb": {16, 8, newCast5Stream}, - "rc4-md5": {16, 16, newRC4MD5Stream}, - "chacha20": {32, 8, newChaCha20Stream}, - "salsa20": {32, 8, newSalsa20Stream}, + "aes-128-cfb": {16, 16, newAESCFBStream}, + "aes-192-cfb": {24, 16, newAESCFBStream}, + "aes-256-cfb": {32, 16, newAESCFBStream}, + "aes-128-ctr": {16, 16, newAESCTRStream}, + "aes-192-ctr": {24, 16, newAESCTRStream}, + "aes-256-ctr": {32, 16, newAESCTRStream}, + "des-cfb": {8, 8, newDESStream}, + "bf-cfb": {16, 8, newBlowFishStream}, + "cast5-cfb": {16, 8, newCast5Stream}, + "rc4-md5": {16, 16, newRC4MD5Stream}, + "chacha20": {32, 8, newChaCha20Stream}, + "chacha20-ietf": {32, 12, newChaCha20IETFStream}, + "salsa20": {32, 8, newSalsa20Stream}, } func CheckCipherMethod(method string) error { diff --git a/shadowsocks/encrypt_test.go b/shadowsocks/encrypt_test.go index 9ff24286..9a2e6ce6 100644 --- a/shadowsocks/encrypt_test.go +++ b/shadowsocks/encrypt_test.go @@ -63,18 +63,30 @@ func testBlockCipher(t *testing.T, method string) { testCiphter(t, cipherCopy, method+" copy") } -func TestAES128(t *testing.T) { +func TestAES128CFB(t *testing.T) { testBlockCipher(t, "aes-128-cfb") } -func TestAES192(t *testing.T) { +func TestAES192CFB(t *testing.T) { testBlockCipher(t, "aes-192-cfb") } -func TestAES256(t *testing.T) { +func TestAES256CFB(t *testing.T) { testBlockCipher(t, "aes-256-cfb") } +func TestAES128CTR(t *testing.T) { + testBlockCipher(t, "aes-128-ctr") +} + +func TestAES192CTR(t *testing.T) { + testBlockCipher(t, "aes-192-ctr") +} + +func TestAES256CTR(t *testing.T) { + testBlockCipher(t, "aes-256-ctr") +} + func TestDES(t *testing.T) { testBlockCipher(t, "des-cfb") } @@ -87,6 +99,10 @@ func TestChaCha20(t *testing.T) { testBlockCipher(t, "chacha20") } +func TestChaCha20IETF(t *testing.T) { + testBlockCipher(t, "chacha20-ietf") +} + var cipherKey = make([]byte, 64) var cipherIv = make([]byte, 64) @@ -108,18 +124,30 @@ func benchmarkCipherInit(b *testing.B, method string) { } } -func BenchmarkAES128Init(b *testing.B) { +func BenchmarkAES128CFBInit(b *testing.B) { benchmarkCipherInit(b, "aes-128-cfb") } -func BenchmarkAES192Init(b *testing.B) { +func BenchmarkAES19CFB2Init(b *testing.B) { benchmarkCipherInit(b, "aes-192-cfb") } -func BenchmarkAES256Init(b *testing.B) { +func BenchmarkAES256CFBInit(b *testing.B) { benchmarkCipherInit(b, "aes-256-cfb") } +func BenchmarkAES128CTRInit(b *testing.B) { + benchmarkCipherInit(b, "aes-128-ctr") +} + +func BenchmarkAES192CTRInit(b *testing.B) { + benchmarkCipherInit(b, "aes-192-ctr") +} + +func BenchmarkAES256CTRInit(b *testing.B) { + benchmarkCipherInit(b, "aes-256-ctr") +} + func BenchmarkBlowFishInit(b *testing.B) { benchmarkCipherInit(b, "bf-cfb") } @@ -140,6 +168,10 @@ func BenchmarkChaCha20Init(b *testing.B) { benchmarkCipherInit(b, "chacha20") } +func BenchmarkChaCha20IETFInit(b *testing.B) { + benchmarkCipherInit(b, "chacha20-ietf") +} + func BenchmarkSalsa20Init(b *testing.B) { benchmarkCipherInit(b, "salsa20") } @@ -160,18 +192,30 @@ func benchmarkCipherEncrypt(b *testing.B, method string) { } } -func BenchmarkAES128Encrypt(b *testing.B) { +func BenchmarkAES128CFBEncrypt(b *testing.B) { benchmarkCipherEncrypt(b, "aes-128-cfb") } -func BenchmarkAES192Encrypt(b *testing.B) { +func BenchmarkAES192CFBEncrypt(b *testing.B) { benchmarkCipherEncrypt(b, "aes-192-cfb") } -func BenchmarkAES256Encrypt(b *testing.B) { +func BenchmarkAES256CFBEncrypt(b *testing.B) { benchmarkCipherEncrypt(b, "aes-256-cfb") } +func BenchmarkAES128CTREncrypt(b *testing.B) { + benchmarkCipherEncrypt(b, "aes-128-ctr") +} + +func BenchmarkAES192CTREncrypt(b *testing.B) { + benchmarkCipherEncrypt(b, "aes-192-ctr") +} + +func BenchmarkAES256CTREncrypt(b *testing.B) { + benchmarkCipherEncrypt(b, "aes-256-ctr") +} + func BenchmarkBlowFishEncrypt(b *testing.B) { benchmarkCipherEncrypt(b, "bf-cfb") } @@ -192,6 +236,10 @@ func BenchmarkChacha20Encrypt(b *testing.B) { benchmarkCipherEncrypt(b, "chacha20") } +func BenchmarkChacha20IETFEncrypt(b *testing.B) { + benchmarkCipherEncrypt(b, "chacha20-ietf") +} + func BenchmarkSalsa20Encrypt(b *testing.B) { benchmarkCipherEncrypt(b, "salsa20") } @@ -217,18 +265,30 @@ func benchmarkCipherDecrypt(b *testing.B, method string) { } } -func BenchmarkAES128Decrypt(b *testing.B) { +func BenchmarkAES128CFBDecrypt(b *testing.B) { benchmarkCipherDecrypt(b, "aes-128-cfb") } -func BenchmarkAES192Decrypt(b *testing.B) { +func BenchmarkAES192CFBDecrypt(b *testing.B) { benchmarkCipherDecrypt(b, "aes-192-cfb") } -func BenchmarkAES256Decrypt(b *testing.B) { +func BenchmarkAES256CFBDecrypt(b *testing.B) { benchmarkCipherDecrypt(b, "aes-256-cfb") } +func BenchmarkAES128CTRDecrypt(b *testing.B) { + benchmarkCipherDecrypt(b, "aes-128-ctr") +} + +func BenchmarkAES192CTRDecrypt(b *testing.B) { + benchmarkCipherDecrypt(b, "aes-192-ctr") +} + +func BenchmarkAES256CTRDecrypt(b *testing.B) { + benchmarkCipherDecrypt(b, "aes-256-ctr") +} + func BenchmarkBlowFishDecrypt(b *testing.B) { benchmarkCipherDecrypt(b, "bf-cfb") } @@ -249,6 +309,10 @@ func BenchmarkChaCha20Decrypt(b *testing.B) { benchmarkCipherDecrypt(b, "chacha20") } +func BenchmarkChaCha20IETFDecrypt(b *testing.B) { + benchmarkCipherDecrypt(b, "chacha20-ietf") +} + func BenchmarkSalsa20Decrypt(b *testing.B) { benchmarkCipherDecrypt(b, "salsa20") } diff --git a/shadowsocks/udp.go b/shadowsocks/udp.go new file mode 100644 index 00000000..058e93e1 --- /dev/null +++ b/shadowsocks/udp.go @@ -0,0 +1,152 @@ +package shadowsocks + +import ( + "bytes" + "fmt" + "net" + "time" +) + +const ( + maxPacketSize = 4096 // increase it if error occurs +) + +var ( + errPacketTooSmall = fmt.Errorf("[udp]read error: cannot decrypt, received packet is smaller than ivLen") + errPacketTooLarge = fmt.Errorf("[udp]read error: received packet is latger than maxPacketSize(%d)", maxPacketSize) + errBufferTooSmall = fmt.Errorf("[udp]read error: given buffer is too small to hold data") + errPacketOtaFailed = fmt.Errorf("[udp]read error: received packet has invalid ota") +) + +type SecurePacketConn struct { + net.PacketConn + *Cipher + ota bool +} + +func NewSecurePacketConn(c net.PacketConn, cipher *Cipher, ota bool) *SecurePacketConn { + return &SecurePacketConn{ + PacketConn: c, + Cipher: cipher, + ota: ota, + } +} + +func (c *SecurePacketConn) Close() error { + return c.PacketConn.Close() +} + +func (c *SecurePacketConn) ReadFrom(b []byte) (n int, src net.Addr, err error) { + ota := false + cipher := c.Copy() + buf := make([]byte, 4096) + n, src, err = c.PacketConn.ReadFrom(buf) + if err != nil { + return + } + + if n < c.info.ivLen { + return 0, nil, errPacketTooSmall + } + + if len(b) < n-c.info.ivLen { + err = errBufferTooSmall // just a warning + } + + iv := make([]byte, c.info.ivLen) + copy(iv, buf[:c.info.ivLen]) + + if err = cipher.initDecrypt(iv); err != nil { + return + } + + cipher.decrypt(b[0:], buf[c.info.ivLen:n]) + n -= c.info.ivLen + if b[idType]&OneTimeAuthMask > 0 { + ota = true + } + + if c.ota && !ota { + return 0, src, errPacketOtaFailed + } + + if ota { + key := cipher.key + actualHmacSha1Buf := HmacSha1(append(iv, key...), b[:n-lenHmacSha1]) + if !bytes.Equal(b[n-lenHmacSha1:n], actualHmacSha1Buf) { + Debug.Printf("verify one time auth failed, iv=%v key=%v data=%v", iv, key, b) + return 0, src, errPacketOtaFailed + } + n -= lenHmacSha1 + } + + return +} + +func (c *SecurePacketConn) WriteTo(b []byte, dst net.Addr) (n int, err error) { + cipher := c.Copy() + iv, err := cipher.initEncrypt() + if err != nil { + return + } + packetLen := len(b) + len(iv) + + if c.ota { + packetLen += lenHmacSha1 + key := cipher.key + actualHmacSha1Buf := HmacSha1(append(iv, key...), b) + b = append(b, actualHmacSha1Buf...) + } + + cipherData := make([]byte, packetLen) + copy(cipherData, iv) + + cipher.encrypt(cipherData[len(iv):], b) + n, err = c.PacketConn.WriteTo(cipherData, dst) + if c.ota { + n -= lenHmacSha1 + } + return +} + +func (c *SecurePacketConn) LocalAddr() net.Addr { + return c.PacketConn.LocalAddr() +} + +func (c *SecurePacketConn) SetDeadline(t time.Time) error { + return c.PacketConn.SetDeadline(t) +} + +func (c *SecurePacketConn) SetReadDeadline(t time.Time) error { + return c.PacketConn.SetReadDeadline(t) +} + +func (c *SecurePacketConn) SetWriteDeadline(t time.Time) error { + return c.PacketConn.SetWriteDeadline(t) +} + +func (c *SecurePacketConn) IsOta() bool { + return c.ota +} + +func (c *SecurePacketConn) ForceOTAWriteTo(b []byte, dst net.Addr) (n int, err error) { + cipher := c.Copy() + iv, err := cipher.initEncrypt() + if err != nil { + return + } + packetLen := len(b) + len(iv) + + packetLen += lenHmacSha1 + key := cipher.key + actualHmacSha1Buf := HmacSha1(append(iv, key...), b) + b = append(b, actualHmacSha1Buf...) + + cipherData := make([]byte, packetLen) + copy(cipherData, iv) + + cipher.encrypt(cipherData[len(iv):], b) + n, err = c.PacketConn.WriteTo(cipherData, dst) + n -= lenHmacSha1 + return +} diff --git a/shadowsocks/udprelay.go b/shadowsocks/udprelay.go new file mode 100644 index 00000000..c38dbe0a --- /dev/null +++ b/shadowsocks/udprelay.go @@ -0,0 +1,265 @@ +package shadowsocks + +import ( + "encoding/binary" + "fmt" + "net" + "strconv" + "strings" + "sync" + "syscall" + "time" +) + +const ( + idType = 0 // address type index + idIP0 = 1 // ip addres start index + idDmLen = 1 // domain address length index + idDm0 = 2 // domain address start index + + typeIPv4 = 1 // type is ipv4 address + typeDm = 3 // type is domain address + typeIPv6 = 4 // type is ipv6 address + + lenIPv4 = 1 + net.IPv4len + 2 // 1addrType + ipv4 + 2port + lenIPv6 = 1 + net.IPv6len + 2 // 1addrType + ipv6 + 2port + lenDmBase = 1 + 1 + 2 // 1addrType + 1addrLen + 2port, plus addrLen + lenHmacSha1 = 10 +) + +var ( + reqList = newReqList() + natlist = newNatTable() + udpTimeout = 30 * time.Second + reqListRefreshTime = 5 * time.Minute +) + +type natTable struct { + sync.Mutex + conns map[string]net.PacketConn +} + +func newNatTable() *natTable { + return &natTable{conns: map[string]net.PacketConn{}} +} + +func (table *natTable) DeleteAndClose(index string) { + table.Lock() + defer table.Unlock() + c, ok := table.conns[index] + if ok { + c.Close() + delete(table.conns, index) + } +} + +func (table *natTable) Get(index string) (c net.PacketConn, ok bool, err error) { + table.Lock() + defer table.Unlock() + c, ok = table.conns[index] + if !ok { + c, err = net.ListenPacket("udp", "") + if err != nil { + return nil, false, err + } + table.conns[index] = c + } + return +} + +type ReqList struct { + sync.Mutex + List map[string]([]byte) +} + +func newReqList() *ReqList { + ret := &ReqList{List: map[string]([]byte){}} + go func() { + for { + time.Sleep(reqListRefreshTime) + ret.Refresh() + } + }() + return ret +} + +func (r *ReqList) Refresh() { + r.Lock() + defer r.Unlock() + for k, _ := range r.List { + delete(r.List, k) + } +} + +func (r *ReqList) Get(dstaddr string) (req []byte, ok bool) { + r.Lock() + defer r.Unlock() + req, ok = r.List[dstaddr] + return +} + +func (r *ReqList) Put(dstaddr string, req []byte) { + r.Lock() + defer r.Unlock() + r.List[dstaddr] = req + return +} + +func ParseHeader(addr net.Addr) ([]byte, int) { + // if the request address type is domain, it cannot be reverselookuped + ip, port, err := net.SplitHostPort(addr.String()) + if err != nil { + return nil, 0 + } + buf := make([]byte, 20) + IP := net.ParseIP(ip) + b1 := IP.To4() + iplen := 0 + if b1 == nil { //ipv6 + b1 = IP.To16() + buf[0] = typeIPv6 + iplen = net.IPv6len + } else { //ipv4 + buf[0] = typeIPv4 + iplen = net.IPv4len + } + copy(buf[1:], b1) + port_i, _ := strconv.Atoi(port) + binary.BigEndian.PutUint16(buf[1+iplen:], uint16(port_i)) + return buf[:1+iplen+2], 1 + iplen + 2 +} + +func Pipeloop(ss *SecurePacketConn, addr net.Addr, in net.PacketConn, compatiblemode bool) { + buf := leakyBuf.Get() + defer leakyBuf.Put(buf) + defer in.Close() + for { + in.SetDeadline(time.Now().Add(udpTimeout)) + n, raddr, err := in.ReadFrom(buf) + if err != nil { + if ne, ok := err.(*net.OpError); ok { + if ne.Err == syscall.EMFILE || ne.Err == syscall.ENFILE { + // log too many open file error + // EMFILE is process reaches open file limits, ENFILE is system limit + Debug.Println("[udp]read error:", err) + } + } + Debug.Printf("[udp]closed pipe %s<-%s\n", addr, in.LocalAddr()) + return + } + // need improvement here + if req, ok := reqList.Get(raddr.String()); ok { + if compatiblemode { + ss.ForceOTAWriteTo(append(req, buf[:n]...), addr) + } else { + ss.WriteTo(append(req, buf[:n]...), addr) + } + } else { + header, hlen := ParseHeader(raddr) + if compatiblemode { + ss.ForceOTAWriteTo(append(header[:hlen], buf[:n]...), addr) + } else { + ss.WriteTo(append(header[:hlen], buf[:n]...), addr) + } + + } + } +} + +func handleUDPConnection(handle *SecurePacketConn, n int, src net.Addr, receive []byte) { + var dstIP net.IP + var reqLen int + var ota bool + addrType := receive[idType] + defer leakyBuf.Put(receive) + + if addrType&OneTimeAuthMask > 0 { + ota = true + } + compatiblemode := !handle.IsOta() && ota + + switch addrType & AddrMask { + case typeIPv4: + reqLen = lenIPv4 + if len(receive) < reqLen { + Debug.Println("[udp]invalid received message.") + } + dstIP = net.IP(receive[idIP0 : idIP0+net.IPv4len]) + case typeIPv6: + reqLen = lenIPv6 + if len(receive) < reqLen { + Debug.Println("[udp]invalid received message.") + } + dstIP = net.IP(receive[idIP0 : idIP0+net.IPv6len]) + case typeDm: + reqLen = int(receive[idDmLen]) + lenDmBase + if len(receive) < reqLen { + Debug.Println("[udp]invalid received message.") + } + name := string(receive[idDm0 : idDm0+int(receive[idDmLen])]) + // avoid panic: syscall: string with NUL passed to StringToUTF16 on windows. + if strings.ContainsRune(name, 0x00) { + fmt.Println("[udp]invalid domain name.") + return + } + dIP, err := net.ResolveIPAddr("ip", name) // carefully with const type + if err != nil { + Debug.Printf("[udp]failed to resolve domain name: %s\n", string(receive[idDm0:idDm0+receive[idDmLen]])) + return + } + dstIP = dIP.IP + default: + Debug.Printf("[udp]addrType %d not supported", addrType) + return + } + dst := &net.UDPAddr{ + IP: dstIP, + Port: int(binary.BigEndian.Uint16(receive[reqLen-2 : reqLen])), + } + if _, ok := reqList.Get(dst.String()); !ok { + req := make([]byte, reqLen) + copy(req, receive) + reqList.Put(dst.String(), req) + } + + remote, exist, err := natlist.Get(src.String()) + if err != nil { + return + } + if !exist { + Debug.Printf("[udp]new client %s->%s via %s ota=%v\n", src, dst, remote.LocalAddr(), ota) + go func() { + Pipeloop(handle, src, remote, compatiblemode) + natlist.DeleteAndClose(src.String()) + }() + } else { + Debug.Printf("[udp]using cached client %s->%s via %s ota=%v\n", src, dst, remote.LocalAddr(), ota) + } + if remote == nil { + fmt.Println("WTF") + } + remote.SetDeadline(time.Now().Add(udpTimeout)) + _, err = remote.WriteTo(receive[reqLen:n], dst) + if err != nil { + if ne, ok := err.(*net.OpError); ok && (ne.Err == syscall.EMFILE || ne.Err == syscall.ENFILE) { + // log too many open file error + // EMFILE is process reaches open file limits, ENFILE is system limit + Debug.Println("[udp]write error:", err) + } else { + Debug.Println("[udp]error connecting to:", dst, err) + } + natlist.DeleteAndClose(src.String()) + } + // Pipeloop + return +} + +func ReadAndHandleUDPReq(c *SecurePacketConn) error { + buf := leakyBuf.Get() + n, src, err := c.ReadFrom(buf[0:]) + if err != nil { + return err + } + go handleUDPConnection(c, n, src, buf) + return nil +} diff --git a/shadowsocks/util.go b/shadowsocks/util.go index 21c01c6a..671fe6cb 100644 --- a/shadowsocks/util.go +++ b/shadowsocks/util.go @@ -1,16 +1,16 @@ package shadowsocks import ( - "errors" - "fmt" - "os" "crypto/hmac" "crypto/sha1" "encoding/binary" + "errors" + "fmt" + "os" ) func PrintVersion() { - const version = "1.1.5" + const version = "1.2.0" fmt.Println("shadowsocks-go version", version) } @@ -57,4 +57,4 @@ func (flag *ClosedFlag) SetClosed() { func (flag *ClosedFlag) IsClosed() bool { return flag.flag -} \ No newline at end of file +}