Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 24 additions & 19 deletions app/dispatcher/default.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,23 +33,21 @@ type cachedReader struct {
cache buf.MultiBuffer
}

func (r *cachedReader) Cache(b *buf.Buffer) {
mb, _ := r.reader.ReadMultiBufferTimeout(time.Millisecond * 100)
func (r *cachedReader) Cache(b *buf.Buffer, deadline time.Duration) error {
mb, err := r.reader.ReadMultiBufferTimeout(deadline)
if err != nil {
return err
}
r.Lock()
if !mb.IsEmpty() {
r.cache, _ = buf.MergeMulti(r.cache, mb)
}
cacheLen := r.cache.Len()
if cacheLen <= b.Cap() {
b.Clear()
} else {
b.Release()
*b = *buf.NewWithSize(cacheLen)
}
rawBytes := b.Extend(cacheLen)
b.Clear()
rawBytes := b.Extend(b.Cap())
n := r.cache.Copy(rawBytes)
b.Resize(0, int32(n))
r.Unlock()
return nil
}

func (r *cachedReader) readInternal() buf.MultiBuffer {
Expand Down Expand Up @@ -355,7 +353,7 @@ func (d *DefaultDispatcher) DispatchLink(ctx context.Context, destination net.De
}

func sniffer(ctx context.Context, cReader *cachedReader, metadataOnly bool, network net.Network) (SniffResult, error) {
payload := buf.New()
payload := buf.NewWithSize(32767)
defer payload.Release()

sniffer := NewSniffer(ctx)
Expand All @@ -367,26 +365,33 @@ func sniffer(ctx context.Context, cReader *cachedReader, metadataOnly bool, netw
}

contentResult, contentErr := func() (SniffResult, error) {
cacheDeadline := 200 * time.Millisecond
totalAttempt := 0
for {
select {
case <-ctx.Done():
return nil, ctx.Err()
default:
totalAttempt++
if totalAttempt > 2 {
return nil, errSniffingTimeout
}
cachingStartingTimeStamp := time.Now()
cacheErr := cReader.Cache(payload, cacheDeadline)
cachingTimeElapsed := time.Since(cachingStartingTimeStamp)
cacheDeadline -= cachingTimeElapsed

cReader.Cache(payload)
if !payload.IsEmpty() {
result, err := sniffer.Sniff(ctx, payload.Bytes(), network)
if err != common.ErrNoClue {
switch err {
case common.ErrNoClue: // No Clue: protocol not matches, and sniffer cannot determine whether there will be a match or not
totalAttempt++
case protocol.ErrProtoNeedMoreData: // Protocol Need More Data: protocol matches, but need more data to complete sniffing
if cacheErr != nil { // Cache error (e.g. timeout) counts for failed attempt
totalAttempt++
}
default:
return result, err
}
}
if payload.IsFull() {
return nil, errUnknownContent
if totalAttempt >= 2 || cacheDeadline <= 0 {
return nil, errSniffingTimeout
}
}
}
Expand Down
8 changes: 6 additions & 2 deletions app/dispatcher/sniffer.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"github.com/xtls/xray-core/common"
"github.com/xtls/xray-core/common/errors"
"github.com/xtls/xray-core/common/net"
"github.com/xtls/xray-core/common/protocol"
"github.com/xtls/xray-core/common/protocol/bittorrent"
"github.com/xtls/xray-core/common/protocol/http"
"github.com/xtls/xray-core/common/protocol/quic"
Expand Down Expand Up @@ -58,14 +59,17 @@ var errUnknownContent = errors.New("unknown content")
func (s *Sniffer) Sniff(c context.Context, payload []byte, network net.Network) (SniffResult, error) {
var pendingSniffer []protocolSnifferWithMetadata
for _, si := range s.sniffer {
s := si.protocolSniffer
protocolSniffer := si.protocolSniffer
if si.metadataSniffer || si.network != network {
continue
}
result, err := s(c, payload)
result, err := protocolSniffer(c, payload)
if err == common.ErrNoClue {
pendingSniffer = append(pendingSniffer, si)
continue
} else if err == protocol.ErrProtoNeedMoreData { // Sniffer protocol matched, but need more data to complete sniffing
s.sniffer = []protocolSnifferWithMetadata{si}
return nil, err
}

if err == nil && result != nil {
Expand Down
47 changes: 31 additions & 16 deletions common/buf/buffer.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,18 +15,27 @@ const (

var pool = bytespool.GetPool(Size)

// ownership represents the data owner of the buffer.
type ownership uint8

const (
managed ownership = iota
unmanaged
bytespools
)

// Buffer is a recyclable allocation of a byte array. Buffer.Release() recycles
// the buffer into an internal buffer pool, in order to recreate a buffer more
// quickly.
type Buffer struct {
v []byte
start int32
end int32
unmanaged bool
ownership ownership
UDP *net.Destination
}

// New creates a Buffer with 0 length and 8K capacity.
// New creates a Buffer with 0 length and 8K capacity, managed.
func New() *Buffer {
buf := pool.Get().([]byte)
if cap(buf) >= Size {
Expand All @@ -40,7 +49,7 @@ func New() *Buffer {
}
}

// NewExisted creates a managed, standard size Buffer with an existed bytearray
// NewExisted creates a standard size Buffer with an existed bytearray, managed.
func NewExisted(b []byte) *Buffer {
if cap(b) < Size {
panic("Invalid buffer")
Expand All @@ -57,16 +66,16 @@ func NewExisted(b []byte) *Buffer {
}
}

// FromBytes creates a Buffer with an existed bytearray
// FromBytes creates a Buffer with an existed bytearray, unmanaged.
func FromBytes(b []byte) *Buffer {
return &Buffer{
v: b,
end: int32(len(b)),
unmanaged: true,
ownership: unmanaged,
}
}

// StackNew creates a new Buffer object on stack.
// StackNew creates a new Buffer object on stack, managed.
// This method is for buffers that is released in the same function.
func StackNew() Buffer {
buf := pool.Get().([]byte)
Expand All @@ -81,18 +90,31 @@ func StackNew() Buffer {
}
}

// NewWithSize creates a Buffer with 0 length and capacity with at least the given size, bytespool's.
func NewWithSize(size int32) *Buffer {
return &Buffer{
v: bytespool.Alloc(size),
ownership: bytespools,
}
}

// Release recycles the buffer into an internal buffer pool.
func (b *Buffer) Release() {
if b == nil || b.v == nil || b.unmanaged {
if b == nil || b.v == nil || b.ownership == unmanaged {
return
}

p := b.v
b.v = nil
b.Clear()

if cap(p) == Size {
pool.Put(p)
switch b.ownership {
case managed:
if cap(p) == Size {
pool.Put(p)
}
case bytespools:
bytespool.Free(p)
}
b.UDP = nil
}
Expand Down Expand Up @@ -215,13 +237,6 @@ func (b *Buffer) Cap() int32 {
return int32(len(b.v))
}

// NewWithSize creates a Buffer with 0 length and capacity with at least the given size.
func NewWithSize(size int32) *Buffer {
return &Buffer{
v: bytespool.Alloc(size),
}
}

// IsEmpty returns true if the buffer is empty.
func (b *Buffer) IsEmpty() bool {
return b.Len() == 0
Expand Down
6 changes: 6 additions & 0 deletions common/protocol/protocol.go
Original file line number Diff line number Diff line change
@@ -1 +1,7 @@
package protocol // import "github.com/xtls/xray-core/common/protocol"

import (
"errors"
)

var ErrProtoNeedMoreData = errors.New("protocol matches, but need more data to complete sniffing")
94 changes: 42 additions & 52 deletions common/protocol/quic/sniff.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package quic

import (
"context"
"crypto"
"crypto/aes"
"crypto/tls"
Expand All @@ -13,6 +12,7 @@ import (
"github.com/xtls/xray-core/common/buf"
"github.com/xtls/xray-core/common/bytespool"
"github.com/xtls/xray-core/common/errors"
"github.com/xtls/xray-core/common/protocol"
ptls "github.com/xtls/xray-core/common/protocol/tls"
"golang.org/x/crypto/hkdf"
)
Expand Down Expand Up @@ -47,22 +47,17 @@ var (
errNotQuicInitial = errors.New("not initial packet")
)

func SniffQUIC(b []byte) (resultReturn *SniffHeader, errorReturn error) {
// In extremely rare cases, this sniffer may cause slice error
// and we set recover() here to prevent crash.
// TODO: Thoroughly fix this panic
defer func() {
if r := recover(); r != nil {
errors.LogError(context.Background(), "Failed to sniff QUIC: ", r)
resultReturn = nil
errorReturn = common.ErrNoClue
}
}()
func SniffQUIC(b []byte) (*SniffHeader, error) {
if len(b) == 0 {
return nil, common.ErrNoClue
}

// Crypto data separated across packets
cryptoLen := 0
cryptoData := bytespool.Alloc(int32(len(b)))
cryptoData := bytespool.Alloc(32767)
defer bytespool.Free(cryptoData)
cache := buf.New()
defer cache.Release()

// Parse QUIC packets
for len(b) > 0 {
Expand Down Expand Up @@ -105,13 +100,15 @@ func SniffQUIC(b []byte) (resultReturn *SniffHeader, errorReturn error) {
return nil, errNotQuic
}

tokenLen, err := quicvarint.Read(buffer)
if err != nil || tokenLen > uint64(len(b)) {
return nil, errNotQuic
}
if isQuicInitial { // Only initial packets have token, see https://datatracker.ietf.org/doc/html/rfc9000#section-17.2.2
tokenLen, err := quicvarint.Read(buffer)
if err != nil || tokenLen > uint64(len(b)) {
return nil, errNotQuic
}

if _, err = buffer.ReadBytes(int32(tokenLen)); err != nil {
return nil, errNotQuic
if _, err = buffer.ReadBytes(int32(tokenLen)); err != nil {
return nil, errNotQuic
}
}

packetLen, err := quicvarint.Read(buffer)
Expand All @@ -130,9 +127,6 @@ func SniffQUIC(b []byte) (resultReturn *SniffHeader, errorReturn error) {
continue
}

origPNBytes := make([]byte, 4)
copy(origPNBytes, b[hdrLen:hdrLen+4])

var salt []byte
if versionNumber == version1 {
salt = quicSalt
Expand All @@ -147,44 +141,34 @@ func SniffQUIC(b []byte) (resultReturn *SniffHeader, errorReturn error) {
return nil, err
}

cache := buf.New()
defer cache.Release()

cache.Clear()
mask := cache.Extend(int32(block.BlockSize()))
block.Encrypt(mask, b[hdrLen+4:hdrLen+4+16])
b[0] ^= mask[0] & 0xf
for i := range b[hdrLen : hdrLen+4] {
packetNumberLength := int(b[0]&0x3 + 1)
for i := range packetNumberLength {
b[hdrLen+i] ^= mask[i+1]
}
packetNumberLength := b[0]&0x3 + 1
if packetNumberLength != 1 {
return nil, errNotQuicInitial
}
var packetNumber uint32
{
n, err := buffer.ReadByte()
if err != nil {
return nil, err
}
packetNumber = uint32(n)
}

extHdrLen := hdrLen + int(packetNumberLength)
copy(b[extHdrLen:hdrLen+4], origPNBytes[packetNumberLength:])
data := b[extHdrLen : int(packetLen)+hdrLen]

key := hkdfExpandLabel(crypto.SHA256, secret, []byte{}, "quic key", 16)
iv := hkdfExpandLabel(crypto.SHA256, secret, []byte{}, "quic iv", 12)
cipher := AEADAESGCMTLS13(key, iv)

nonce := cache.Extend(int32(cipher.NonceSize()))
binary.BigEndian.PutUint64(nonce[len(nonce)-8:], uint64(packetNumber))
_, err = buffer.Read(nonce[len(nonce)-packetNumberLength:])
Copy link
Contributor Author

@j2rong4cn j2rong4cn Apr 28, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@RPRX 因为这里把数据读取到nonce末尾

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

原本的代码就有这个潜在bug

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个 buffer.Read 没问题吧,上面 cache.Extend 可能有问题

if err != nil {
return nil, err
}

extHdrLen := hdrLen + packetNumberLength
data := b[extHdrLen : int(packetLen)+hdrLen]
decrypted, err := cipher.Open(b[extHdrLen:extHdrLen], nonce, data, b[:extHdrLen])
if err != nil {
return nil, err
}
buffer = buf.FromBytes(decrypted)
for i := 0; !buffer.IsEmpty(); i++ {
frameType := byte(0x0) // Default to PADDING frame
for !buffer.IsEmpty() {
frameType, _ := buffer.ReadByte()
for frameType == 0x0 && !buffer.IsEmpty() {
frameType, _ = buffer.ReadByte()
}
Expand Down Expand Up @@ -234,13 +218,12 @@ func SniffQUIC(b []byte) (resultReturn *SniffHeader, errorReturn error) {
return nil, io.ErrUnexpectedEOF
}
if cryptoLen < int(offset+length) {
cryptoLen = int(offset + length)
if len(cryptoData) < cryptoLen {
newCryptoData := bytespool.Alloc(int32(cryptoLen))
copy(newCryptoData, cryptoData)
bytespool.Free(cryptoData)
cryptoData = newCryptoData
newCryptoLen := int(offset + length)
if len(cryptoData) < newCryptoLen {
return nil, io.ErrShortBuffer
}
wipeBytes(cryptoData[cryptoLen:newCryptoLen])
cryptoLen = newCryptoLen
}
if _, err := buffer.Read(cryptoData[offset : offset+length]); err != nil { // Field: Crypto Data
return nil, io.ErrUnexpectedEOF
Expand Down Expand Up @@ -276,7 +259,14 @@ func SniffQUIC(b []byte) (resultReturn *SniffHeader, errorReturn error) {
}
return &SniffHeader{domain: tlsHdr.Domain()}, nil
}
return nil, common.ErrNoClue
// All payload is parsed as valid QUIC packets, but we need more packets for crypto data to read client hello.
return nil, protocol.ErrProtoNeedMoreData
}

func wipeBytes(b []byte) {
for i := range len(b) {
b[i] = 0x0
}
}

func hkdfExpandLabel(hash crypto.Hash, secret, context []byte, label string, length int) []byte {
Expand Down
Loading