diff --git a/cmd/shadowsocks-server/server.go b/cmd/shadowsocks-server/server.go index 1a8c07a1..5713aef7 100644 --- a/cmd/shadowsocks-server/server.go +++ b/cmd/shadowsocks-server/server.go @@ -6,6 +6,7 @@ import ( "encoding/gob" "errors" "flag" + "github.com/cyfdecyf/dnspool" ss "github.com/shadowsocks/shadowsocks-go/shadowsocks" "io" "log" @@ -23,7 +24,34 @@ var debug ss.DebugLog var errAddrType = errors.New("addr type not supported") -func getRequest(conn *ss.Conn) (host string, extra []byte, err error) { +const dnsGoroutineNum = 64 + +func dial(hostPort string, isIP bool) (c net.Conn, err error) { + // return net.Dial("tcp", host) + if isIP { + return net.Dial("tcp", hostPort) + } + + var addrs []string + var host, port string + + if host, port, err = net.SplitHostPort(hostPort); err != nil { + log.Println("Internal error: host should always has port specified") + return + } + if addrs, err = dnspool.LookupHost(host); err != nil { + return + } + for _, ip := range addrs { + ipHost := net.JoinHostPort(ip, port) + if c, err = net.Dial("tcp", ipHost); err == nil { + return + } + } + return nil, err +} + +func getRequest(conn *ss.Conn) (host string, extra []byte, isIP bool, err error) { const ( idType = 0 // address type index idIP0 = 1 // ip addres start index @@ -66,16 +94,17 @@ func getRequest(conn *ss.Conn) (host string, extra []byte, err error) { extra = buf[reqLen:n] } - // TODO add ipv6 support + // TODO add ipv6 address support if buf[idType] == typeDm { host = string(buf[idDm0 : idDm0+buf[idDmLen]]) } else if buf[idType] == typeIP { addrIp := net.IPv4(buf[idIP0], buf[idIP0+1], buf[idIP0+2], buf[idIP0+3]) host = addrIp.String() + isIP = true } // parse port port := binary.BigEndian.Uint16(buf[reqLen-2 : reqLen]) - host += ":" + strconv.Itoa(int(port)) + host = net.JoinHostPort(host, strconv.Itoa(int(port))) return } @@ -109,13 +138,13 @@ func handleConnection(conn *ss.Conn) { conn.Close() }() - host, extra, err := getRequest(conn) + host, extra, isIP, err := getRequest(conn) if err != nil { log.Println("error getting request:", err) return } debug.Println("connecting", host) - remote, err := net.Dial("tcp", host) + remote, err := dial(host, isIP) if err != nil { if ne, ok := err.(*net.OpError); ok && (ne.Err == syscall.EMFILE || ne.Err == syscall.ENFILE) { // log too many open file error @@ -127,7 +156,7 @@ func handleConnection(conn *ss.Conn) { return } defer remote.Close() - // write extra bytes read from + // write extra bytes read from if extra != nil { // debug.Println("getRequest read extra data, writing to remote, len", len(extra)) if _, err = remote.Write(extra); err != nil { @@ -410,6 +439,8 @@ func main() { os.Exit(1) } + dnspool.SetGoroutineNumber(dnsGoroutineNum) + initTableCache(config) for port, password := range config.PortPassword { go run(port, password)