Skip to content
This repository has been archived by the owner on Apr 9, 2020. It is now read-only.

Commit

Permalink
support one time auth in client & server
Browse files Browse the repository at this point in the history
append "-ota" suffix in method name to enable one time auth
  • Loading branch information
ayanamist committed Dec 4, 2015
1 parent dbe5178 commit 8e44bc7
Show file tree
Hide file tree
Showing 8 changed files with 271 additions and 87 deletions.
16 changes: 9 additions & 7 deletions cmd/shadowsocks-local/local.go
Original file line number Diff line number Diff line change
Expand Up @@ -170,8 +170,12 @@ func parseServerConfig(config *ss.Config) {
}

if len(config.ServerPassword) == 0 {
method := config.Method
if config.Auth {
method += "-ota"
}
// only one encryption table
cipher, err := ss.NewCipher(config.Method, config.Password)
cipher, err := ss.NewCipher(method, config.Password)
if err != nil {
log.Fatal("Failed generating ciphers:", err)
}
Expand Down Expand Up @@ -314,14 +318,12 @@ func handleConnection(conn net.Conn) {
return
}
defer func() {
if !closed {
remote.Close()
}
remote.Close()
}()

go ss.PipeThenClose(conn, remote)
ss.PipeThenClose(remote, conn)
closed = true
closedFlag := &ss.ClosedFlag{}
go ss.PipeThenClose(conn, remote, closedFlag)
ss.PipeThenClose(remote, conn, closedFlag)
debug.Println("closed connection to", addr)
}

Expand Down
132 changes: 71 additions & 61 deletions cmd/shadowsocks-server/server.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package main

import (
"bytes"
"encoding/binary"
"errors"
"flag"
Expand All @@ -17,61 +18,61 @@ import (
"syscall"
)

var debug ss.DebugLog
const (
idType = 0 // address type index
idIP0 = 1 // ip addres start index
idDmLen = 1 // domain address length index
idDm0 = 2 // domain address start index

func getRequest(conn *ss.Conn) (host string, extra []byte, err error) {
const (
idType = 0 // address type index
idIP0 = 1 // ip addres start index
idDmLen = 1 // domain address length index
idDm0 = 2 // domain address start index
typeIPv4 = 1 // type is ipv4 address
typeDm = 3 // type is domain address
typeIPv6 = 4 // type is ipv6 address

typeIPv4 = 1 // type is ipv4 address
typeDm = 3 // type is domain address
typeIPv6 = 4 // type is ipv6 address
lenIPv4 = net.IPv4len + 2 // ipv4 + 2port
lenIPv6 = net.IPv6len + 2 // ipv6 + 2port
lenDmBase = 2 // 1addrLen + 2port, plus addrLen
lenHmacSha1 = 10
)

lenIPv4 = 1 + net.IPv4len + 2 // 1addrType + ipv4 + 2port
lenIPv6 = 1 + net.IPv6len + 2 // 1addrType + ipv6 + 2port
lenDmBase = 1 + 1 + 2 // 1addrType + 1addrLen + 2port, plus addrLen
)
var debug ss.DebugLog

func getRequest(conn *ss.Conn, auth bool) (host string, ota bool, err error) {
ss.SetReadTimeout(conn)

// buf size should at least have the same size with the largest possible
// request size (when addrType is 3, domain name has at most 256 bytes)
// 1(addrType) + 1(lenByte) + 256(max length address) + 2(port)
buf := make([]byte, 260)
var n int
// 1(addrType) + 1(lenByte) + 256(max length address) + 2(port) + 10(hmac-sha1)
buf := make([]byte, 270)
// read till we get possible domain length field
ss.SetReadTimeout(conn)
if n, err = io.ReadAtLeast(conn, buf, idDmLen+1); err != nil {
if _, err = io.ReadFull(conn, buf[:idType+1]); err != nil {
return
}

reqLen := -1
switch buf[idType] {
var reqStart, reqEnd int
addrType := buf[idType]
switch addrType & ss.AddrMask {
case typeIPv4:
reqLen = lenIPv4
reqStart, reqEnd = idIP0, idIP0+lenIPv4
case typeIPv6:
reqLen = lenIPv6
reqStart, reqEnd = idIP0, idIP0+lenIPv6
case typeDm:
reqLen = int(buf[idDmLen]) + lenDmBase
if _, err = io.ReadFull(conn, buf[idType+1:idDmLen+1]); err != nil {
return
}
reqStart, reqEnd = idDm0, int(idDm0+buf[idDmLen]+lenDmBase)
default:
err = fmt.Errorf("addr type %d not supported", buf[idType])
err = fmt.Errorf("addr type %d not supported", addrType&ss.AddrMask)
return
}

if n < reqLen { // rare case
if _, err = io.ReadFull(conn, buf[n:reqLen]); err != nil {
return
}
} else if n > reqLen {
// it's possible to read more than just the request head
extra = buf[reqLen:n]
if _, err = io.ReadFull(conn, buf[reqStart:reqEnd]); err != nil {
return
}

// Return string for typeIP is not most efficient, but browsers (Chrome,
// Safari, Firefox) all seems using typeDm exclusively. So this is not a
// big problem.
switch buf[idType] {
switch addrType & ss.AddrMask {
case typeIPv4:
host = net.IP(buf[idIP0 : idIP0+net.IPv4len]).String()
case typeIPv6:
Expand All @@ -80,8 +81,22 @@ func getRequest(conn *ss.Conn) (host string, extra []byte, err error) {
host = string(buf[idDm0 : idDm0+buf[idDmLen]])
}
// parse port
port := binary.BigEndian.Uint16(buf[reqLen-2 : reqLen])
port := binary.BigEndian.Uint16(buf[reqEnd-2 : reqEnd])
host = net.JoinHostPort(host, strconv.Itoa(int(port)))
// if specified one time auth enabled, we should verify this
if auth || addrType&ss.OneTimeAuthMask > 0 {
ota = true
if _, err = io.ReadFull(conn, buf[reqEnd:reqEnd+lenHmacSha1]); err != nil {
return
}
iv := conn.GetIv()
key := conn.GetKey()
actualHmacSha1Buf := ss.HmacSha1(append(iv, key...), buf[:reqEnd])
if !bytes.Equal(buf[reqEnd:reqEnd+lenHmacSha1], actualHmacSha1Buf) {
err = fmt.Errorf("verify one time auth failed, iv=%v key=%v data=%v", iv, key, buf[:reqEnd])
return
}
}
return
}

Expand All @@ -90,7 +105,11 @@ const logCntDelta = 100
var connCnt int
var nextLogConnCnt int = logCntDelta

func handleConnection(conn *ss.Conn) {
type isClosed struct {
isClosed bool
}

func handleConnection(conn *ss.Conn, auth bool) {
var host string

connCnt++ // this maybe not accurate, but should be enough
Expand All @@ -107,18 +126,15 @@ func handleConnection(conn *ss.Conn) {
if debug {
debug.Printf("new client %s->%s\n", conn.RemoteAddr().String(), conn.LocalAddr())
}
closed := false
defer func() {
if debug {
debug.Printf("closed pipe %s<->%s\n", conn.RemoteAddr(), host)
}
connCnt--
if !closed {
conn.Close()
}
conn.Close()
}()

host, extra, err := getRequest(conn)
host, ota, err := getRequest(conn, auth)
if err != nil {
log.Println("error getting request", conn.RemoteAddr(), conn.LocalAddr(), err)
return
Expand All @@ -136,24 +152,18 @@ func handleConnection(conn *ss.Conn) {
return
}
defer func() {
if !closed {
remote.Close()
}
remote.Close()
}()
// write extra bytes read from
if extra != nil {
// debug.Println("getRequest read extra data, writing to remote, len", len(extra))
if _, err = remote.Write(extra); err != nil {
debug.Println("write request extra error:", err)
return
}
}
if debug {
debug.Printf("piping %s<->%s", conn.RemoteAddr(), host)
debug.Printf("piping %s<->%s ota=%v connOta=%v", conn.RemoteAddr(), host, ota, conn.IsOta())
}
closedFlag := &ss.ClosedFlag{}
if ota {
go ss.PipeThenCloseOta(conn, remote, closedFlag)
} else {
go ss.PipeThenClose(conn, remote, closedFlag)
}
go ss.PipeThenClose(conn, remote)
ss.PipeThenClose(remote, conn)
closed = true
ss.PipeThenClose(remote, conn, closedFlag)
return
}

Expand Down Expand Up @@ -195,7 +205,7 @@ func (pm *PasswdManager) del(port string) {
// port. A different approach would be directly change the password used by
// that port, but that requires **sharing** password between the port listener
// and password manager.
func (pm *PasswdManager) updatePortPasswd(port, password string) {
func (pm *PasswdManager) updatePortPasswd(port, password string, auth bool) {
pl, ok := pm.get(port)
if !ok {
log.Printf("new port %s added\n", port)
Expand All @@ -208,7 +218,7 @@ func (pm *PasswdManager) updatePortPasswd(port, password string) {
}
// run will add the new port listener to passwdManager.
// So there maybe concurrent access to passwdManager and we need lock to protect it.
go run(port, password)
go run(port, password, auth)
}

var passwdManager = PasswdManager{portListener: map[string]*PortListener{}}
Expand All @@ -227,7 +237,7 @@ func updatePasswd() {
return
}
for port, passwd := range config.PortPassword {
passwdManager.updatePortPasswd(port, passwd)
passwdManager.updatePortPasswd(port, passwd, config.Auth)
if oldconfig.PortPassword != nil {
delete(oldconfig.PortPassword, port)
}
Expand All @@ -254,7 +264,7 @@ func waitSignal() {
}
}

func run(port, password string) {
func run(port, password string, auth bool) {
ln, err := net.Listen("tcp", ":"+port)
if err != nil {
log.Printf("error listening port %v: %v\n", port, err)
Expand All @@ -280,7 +290,7 @@ func run(port, password string) {
continue
}
}
go handleConnection(ss.NewConn(conn, cipher.Copy()))
go handleConnection(ss.NewConn(conn, cipher.Copy()), auth)
}
}

Expand Down Expand Up @@ -357,7 +367,7 @@ func main() {
runtime.GOMAXPROCS(core)
}
for port, password := range config.PortPassword {
go run(port, password)
go run(port, password, config.Auth)
}

waitSignal()
Expand Down
6 changes: 6 additions & 0 deletions shadowsocks/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
"os"
"reflect"
"time"
"strings"
)

type Config struct {
Expand All @@ -23,6 +24,7 @@ type Config struct {
LocalPort int `json:"local_port"`
Password string `json:"password"`
Method string `json:"method"` // encryption method
Auth bool `json:"auth"` // one time auth

// following options are only used by server
PortPassword map[string]string `json:"port_password"`
Expand Down Expand Up @@ -85,6 +87,10 @@ func ParseConfig(path string) (config *Config, err error) {
return nil, err
}
readTimeout = time.Duration(config.Timeout) * time.Second
if strings.HasSuffix(strings.ToLower(config.Method), "-ota") {
config.Method = config.Method[:len(config.Method) - 4]
config.Auth = true
}
return
}

Expand Down
Loading

0 comments on commit 8e44bc7

Please sign in to comment.