Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

proxy/trojan: fix writing UDP packet #2446

Merged
merged 2 commits into from
May 23, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion common/protocol/address.go
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ func (p *addressParser) writeAddress(writer io.Writer, address net.Address) erro
}
case net.AddressFamilyDomain:
domain := address.Domain()
if isDomainTooLong(domain) {
if IsDomainTooLong(domain) {
return newError("Super long domain is not supported: ", domain)
}

Expand Down
4 changes: 2 additions & 2 deletions common/protocol/headers.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,6 @@ func (sc *SecurityConfig) GetSecurityType() SecurityType {
return sc.Type
}

func isDomainTooLong(domain string) bool {
return len(domain) > 256
func IsDomainTooLong(domain string) bool {
return len(domain) > 255
}
69 changes: 35 additions & 34 deletions proxy/trojan/protocol.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ var (
)

const (
maxLength = 8192
commandTCP byte = 1
commandUDP byte = 3
)
Expand Down Expand Up @@ -110,11 +109,11 @@ type PacketWriter struct {

// WriteMultiBuffer implements buf.Writer
func (w *PacketWriter) WriteMultiBuffer(mb buf.MultiBuffer) error {
b := make([]byte, maxLength)
for !mb.IsEmpty() {
var length int
mb, length = buf.SplitBytes(mb, b)
if _, err := w.writePacket(b[:length], w.Target); err != nil {
for _, b := range mb {
if b.IsEmpty() {
continue
}
if _, err := w.writePacket(b.Bytes(), w.Target); err != nil {
buf.ReleaseMulti(mb)
return err
}
Expand All @@ -125,11 +124,11 @@ func (w *PacketWriter) WriteMultiBuffer(mb buf.MultiBuffer) error {

// WriteMultiBufferWithMetadata writes udp packet with destination specified
func (w *PacketWriter) WriteMultiBufferWithMetadata(mb buf.MultiBuffer, dest net.Destination) error {
b := make([]byte, maxLength)
for !mb.IsEmpty() {
var length int
mb, length = buf.SplitBytes(mb, b)
if _, err := w.writePacket(b[:length], dest); err != nil {
for _, b := range mb {
if b.IsEmpty() {
continue
}
if _, err := w.writePacket(b.Bytes(), dest); err != nil {
buf.ReleaseMulti(mb)
return err
}
Expand All @@ -145,13 +144,29 @@ func (w *PacketWriter) WriteTo(payload []byte, addr gonet.Addr) (int, error) {
}

func (w *PacketWriter) writePacket(payload []byte, dest net.Destination) (int, error) { // nolint: unparam
buffer := buf.StackNew()
defer buffer.Release()
var addrPortLen int32
switch dest.Address.Family() {
case net.AddressFamilyDomain:
if protocol.IsDomainTooLong(dest.Address.Domain()) {
return 0, newError("Super long domain is not supported: ", dest.Address.Domain())
}
addrPortLen = 1 + 1 + int32(len(dest.Address.Domain())) + 2
case net.AddressFamilyIPv4:
addrPortLen = 1 + 4 + 2
case net.AddressFamilyIPv6:
addrPortLen = 1 + 16 + 2
default:
panic("Unknown address type.")
}

length := len(payload)
lengthBuf := [2]byte{}
binary.BigEndian.PutUint16(lengthBuf[:], uint16(length))
if err := addrParser.WriteAddressPort(&buffer, dest.Address, dest.Port); err != nil {

buffer := buf.NewWithSize(addrPortLen + 2 + 2 + int32(length))
defer buffer.Release()

if err := addrParser.WriteAddressPort(buffer, dest.Address, dest.Port); err != nil {
return 0, err
}
if _, err := buffer.Write(lengthBuf[:]); err != nil {
Expand Down Expand Up @@ -264,36 +279,22 @@ func (r *PacketReader) ReadMultiBufferWithMetadata() (*PacketPayload, error) {
return nil, newError("failed to read payload length").Base(err)
}

remain := int(binary.BigEndian.Uint16(lengthBuf[:]))
if remain > maxLength {
return nil, newError("oversize payload")
}
length := binary.BigEndian.Uint16(lengthBuf[:])

var crlf [2]byte
if _, err := io.ReadFull(r, crlf[:]); err != nil {
return nil, newError("failed to read crlf").Base(err)
}

dest := net.UDPDestination(addr, port)
var mb buf.MultiBuffer
for remain > 0 {
length := buf.Size
if remain < length {
length = remain
}

b := buf.New()
mb = append(mb, b)
n, err := b.ReadFullFrom(r, int32(length))
if err != nil {
buf.ReleaseMulti(mb)
return nil, newError("failed to read payload").Base(err)
}

remain -= int(n)
b := buf.NewWithSize(int32(length))
_, err = b.ReadFullFrom(r, int32(length))
if err != nil {
return nil, newError("failed to read payload").Base(err)
}

return &PacketPayload{Target: dest, Buffer: mb}, nil
return &PacketPayload{Target: dest, Buffer: buf.MultiBuffer{b}}, nil
}

type PacketConnectionReader struct {
Expand Down
46 changes: 46 additions & 0 deletions proxy/trojan/protocol_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package trojan_test

import (
"crypto/rand"
"testing"

"github.com/google/go-cmp/cmp"
Expand Down Expand Up @@ -90,3 +91,48 @@ func TestUDPRequest(t *testing.T) {
t.Error("data: ", r)
}
}

func TestLargeUDPRequest(t *testing.T) {
user := &protocol.MemoryUser{
Email: "[email protected]",
Account: toAccount(&Account{
Password: "password",
}),
}

payload := make([]byte, 4096)
common.Must2(rand.Read(payload))
data := buf.NewWithSize(int32(len(payload)))
common.Must2(data.Write(payload))

buffer := buf.NewWithSize(2*data.Len() + 1)
defer buffer.Release()

destination := net.Destination{Network: net.Network_UDP, Address: net.LocalHostIP, Port: 1234}
writer := &PacketWriter{Writer: &ConnWriter{Writer: buffer, Target: destination, Account: user.Account.(*MemoryAccount)}, Target: destination}
common.Must(writer.WriteMultiBuffer(buf.MultiBuffer{data, data}))

connReader := &ConnReader{Reader: buffer}
common.Must(connReader.ParseHeader())

packetReader := &PacketReader{Reader: connReader}
for i := 0; i < 2; i++ {
p, err := packetReader.ReadMultiBufferWithMetadata()
common.Must(err)

if p.Buffer.IsEmpty() {
t.Error("no request data")
}

if r := cmp.Diff(p.Target, destination); r != "" {
t.Error("destination: ", r)
}

mb, decoded := buf.SplitFirst(p.Buffer)
buf.ReleaseMulti(mb)

if r := cmp.Diff(decoded.Bytes(), payload); r != "" {
t.Error("data: ", r)
}
}
}