diff --git a/common/protocol/address.go b/common/protocol/address.go index bbf923f26ab..11d22f4bec8 100644 --- a/common/protocol/address.go +++ b/common/protocol/address.go @@ -239,7 +239,7 @@ func (p *addressParser) writeAddress(writer io.Writer, address net.Address) erro } case net.AddressFamilyDomain: domain := address.Domain() - if isDomainTooLong(domain) { + if IsDomainTooLong(domain) { return newError("Super long domain is not supported: ", domain) } diff --git a/common/protocol/headers.go b/common/protocol/headers.go index 6b65fd52915..528c9e8b7d6 100644 --- a/common/protocol/headers.go +++ b/common/protocol/headers.go @@ -103,6 +103,6 @@ func (sc *SecurityConfig) GetSecurityType() SecurityType { return sc.Type } -func isDomainTooLong(domain string) bool { - return len(domain) > 256 +func IsDomainTooLong(domain string) bool { + return len(domain) > 255 } diff --git a/proxy/trojan/protocol.go b/proxy/trojan/protocol.go index 962f5db73b5..d6ef0f66c9f 100644 --- a/proxy/trojan/protocol.go +++ b/proxy/trojan/protocol.go @@ -21,7 +21,6 @@ var ( ) const ( - maxLength = 8192 commandTCP byte = 1 commandUDP byte = 3 ) @@ -110,11 +109,11 @@ type PacketWriter struct { // WriteMultiBuffer implements buf.Writer func (w *PacketWriter) WriteMultiBuffer(mb buf.MultiBuffer) error { - b := make([]byte, maxLength) - for !mb.IsEmpty() { - var length int - mb, length = buf.SplitBytes(mb, b) - if _, err := w.writePacket(b[:length], w.Target); err != nil { + for _, b := range mb { + if b.IsEmpty() { + continue + } + if _, err := w.writePacket(b.Bytes(), w.Target); err != nil { buf.ReleaseMulti(mb) return err } @@ -125,11 +124,11 @@ func (w *PacketWriter) WriteMultiBuffer(mb buf.MultiBuffer) error { // WriteMultiBufferWithMetadata writes udp packet with destination specified func (w *PacketWriter) WriteMultiBufferWithMetadata(mb buf.MultiBuffer, dest net.Destination) error { - b := make([]byte, maxLength) - for !mb.IsEmpty() { - var length int - mb, length = buf.SplitBytes(mb, b) - if _, err := w.writePacket(b[:length], dest); err != nil { + for _, b := range mb { + if b.IsEmpty() { + continue + } + if _, err := w.writePacket(b.Bytes(), dest); err != nil { buf.ReleaseMulti(mb) return err } @@ -145,13 +144,29 @@ func (w *PacketWriter) WriteTo(payload []byte, addr gonet.Addr) (int, error) { } func (w *PacketWriter) writePacket(payload []byte, dest net.Destination) (int, error) { // nolint: unparam - buffer := buf.StackNew() - defer buffer.Release() + var addrPortLen int32 + switch dest.Address.Family() { + case net.AddressFamilyDomain: + if protocol.IsDomainTooLong(dest.Address.Domain()) { + return 0, newError("Super long domain is not supported: ", dest.Address.Domain()) + } + addrPortLen = 1 + 1 + int32(len(dest.Address.Domain())) + 2 + case net.AddressFamilyIPv4: + addrPortLen = 1 + 4 + 2 + case net.AddressFamilyIPv6: + addrPortLen = 1 + 16 + 2 + default: + panic("Unknown address type.") + } length := len(payload) lengthBuf := [2]byte{} binary.BigEndian.PutUint16(lengthBuf[:], uint16(length)) - if err := addrParser.WriteAddressPort(&buffer, dest.Address, dest.Port); err != nil { + + buffer := buf.NewWithSize(addrPortLen + 2 + 2 + int32(length)) + defer buffer.Release() + + if err := addrParser.WriteAddressPort(buffer, dest.Address, dest.Port); err != nil { return 0, err } if _, err := buffer.Write(lengthBuf[:]); err != nil { @@ -264,10 +279,7 @@ func (r *PacketReader) ReadMultiBufferWithMetadata() (*PacketPayload, error) { return nil, newError("failed to read payload length").Base(err) } - remain := int(binary.BigEndian.Uint16(lengthBuf[:])) - if remain > maxLength { - return nil, newError("oversize payload") - } + length := binary.BigEndian.Uint16(lengthBuf[:]) var crlf [2]byte if _, err := io.ReadFull(r, crlf[:]); err != nil { @@ -275,25 +287,14 @@ func (r *PacketReader) ReadMultiBufferWithMetadata() (*PacketPayload, error) { } dest := net.UDPDestination(addr, port) - var mb buf.MultiBuffer - for remain > 0 { - length := buf.Size - if remain < length { - length = remain - } - b := buf.New() - mb = append(mb, b) - n, err := b.ReadFullFrom(r, int32(length)) - if err != nil { - buf.ReleaseMulti(mb) - return nil, newError("failed to read payload").Base(err) - } - - remain -= int(n) + b := buf.NewWithSize(int32(length)) + _, err = b.ReadFullFrom(r, int32(length)) + if err != nil { + return nil, newError("failed to read payload").Base(err) } - return &PacketPayload{Target: dest, Buffer: mb}, nil + return &PacketPayload{Target: dest, Buffer: buf.MultiBuffer{b}}, nil } type PacketConnectionReader struct { diff --git a/proxy/trojan/protocol_test.go b/proxy/trojan/protocol_test.go index 00bdc13e038..5d6a18c9af6 100644 --- a/proxy/trojan/protocol_test.go +++ b/proxy/trojan/protocol_test.go @@ -1,6 +1,7 @@ package trojan_test import ( + "crypto/rand" "testing" "github.com/google/go-cmp/cmp" @@ -90,3 +91,48 @@ func TestUDPRequest(t *testing.T) { t.Error("data: ", r) } } + +func TestLargeUDPRequest(t *testing.T) { + user := &protocol.MemoryUser{ + Email: "love@v2fly.org", + Account: toAccount(&Account{ + Password: "password", + }), + } + + payload := make([]byte, 4096) + common.Must2(rand.Read(payload)) + data := buf.NewWithSize(int32(len(payload))) + common.Must2(data.Write(payload)) + + buffer := buf.NewWithSize(2*data.Len() + 1) + defer buffer.Release() + + destination := net.Destination{Network: net.Network_UDP, Address: net.LocalHostIP, Port: 1234} + writer := &PacketWriter{Writer: &ConnWriter{Writer: buffer, Target: destination, Account: user.Account.(*MemoryAccount)}, Target: destination} + common.Must(writer.WriteMultiBuffer(buf.MultiBuffer{data, data})) + + connReader := &ConnReader{Reader: buffer} + common.Must(connReader.ParseHeader()) + + packetReader := &PacketReader{Reader: connReader} + for i := 0; i < 2; i++ { + p, err := packetReader.ReadMultiBufferWithMetadata() + common.Must(err) + + if p.Buffer.IsEmpty() { + t.Error("no request data") + } + + if r := cmp.Diff(p.Target, destination); r != "" { + t.Error("destination: ", r) + } + + mb, decoded := buf.SplitFirst(p.Buffer) + buf.ReleaseMulti(mb) + + if r := cmp.Diff(decoded.Bytes(), payload); r != "" { + t.Error("data: ", r) + } + } +}