diff --git a/go.mod b/go.mod index 7ab2ca309c4..08a567e8c08 100644 --- a/go.mod +++ b/go.mod @@ -8,13 +8,17 @@ require ( github.com/antithesishq/antithesis-sdk-go v0.4.3-default-no-op github.com/google/go-tpm v0.9.6 github.com/klauspost/compress v1.18.1 - github.com/minio/highwayhash v1.0.3 github.com/nats-io/jwt/v2 v2.8.0 github.com/nats-io/nats.go v1.47.0 github.com/nats-io/nkeys v0.4.11 github.com/nats-io/nuid v1.0.1 go.uber.org/automaxprocs v1.6.0 golang.org/x/crypto v0.43.0 - golang.org/x/sys v0.37.0 + golang.org/x/sys v0.38.0 golang.org/x/time v0.14.0 + + // We don't usually pin non-tagged commits but so far no release has + // been made that includes https://github.com/minio/highwayhash/pull/29. + // This will be updated if a new tag covers this in the future. + github.com/minio/highwayhash v1.0.4-0.20251030100505-070ab1a87a76 ) diff --git a/go.sum b/go.sum index b351d963d39..66d5d817667 100644 --- a/go.sum +++ b/go.sum @@ -6,8 +6,8 @@ github.com/google/go-tpm v0.9.6 h1:Ku42PT4LmjDu1H5C5ISWLlpI1mj+Zq7sPGKoRw2XROA= github.com/google/go-tpm v0.9.6/go.mod h1:h9jEsEECg7gtLis0upRBQU+GhYVH6jMjrFxI8u6bVUY= github.com/klauspost/compress v1.18.1 h1:bcSGx7UbpBqMChDtsF28Lw6v/G94LPrrbMbdC3JH2co= github.com/klauspost/compress v1.18.1/go.mod h1:ZQFFVG+MdnR0P+l6wpXgIL4NTtwiKIdBnrBd8Nrxr+0= -github.com/minio/highwayhash v1.0.3 h1:kbnuUMoHYyVl7szWjSxJnxw11k2U709jqFPPmIUyD6Q= -github.com/minio/highwayhash v1.0.3/go.mod h1:GGYsuwP/fPD6Y9hMiXuapVvlIUEhFhMTh0rxU3ik1LQ= +github.com/minio/highwayhash v1.0.4-0.20251030100505-070ab1a87a76 h1:KGuD/pM2JpL9FAYvBrnBBeENKZNh6eNtjqytV6TYjnk= +github.com/minio/highwayhash v1.0.4-0.20251030100505-070ab1a87a76/go.mod h1:GGYsuwP/fPD6Y9hMiXuapVvlIUEhFhMTh0rxU3ik1LQ= github.com/nats-io/jwt/v2 v2.8.0 h1:K7uzyz50+yGZDO5o772eRE7atlcSEENpL7P+b74JV1g= github.com/nats-io/jwt/v2 v2.8.0/go.mod h1:me11pOkwObtcBNR8AiMrUbtVOUGkqYjMQZ6jnSdVUIA= github.com/nats-io/nats.go v1.47.0 h1:YQdADw6J/UfGUd2Oy6tn4Hq6YHxCaJrVKayxxFqYrgM= @@ -27,8 +27,8 @@ 
go.uber.org/automaxprocs v1.6.0/go.mod h1:ifeIMSnPZuznNm6jmdzmU3/bfk01Fe2fotchwE golang.org/x/crypto v0.43.0 h1:dduJYIi3A3KOfdGOHX8AVZ/jGiyPa3IbBozJ5kNuE04= golang.org/x/crypto v0.43.0/go.mod h1:BFbav4mRNlXJL4wNeejLpWxB7wMbc79PdRGhWKncxR0= golang.org/x/sys v0.21.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/sys v0.37.0 h1:fdNQudmxPjkdUTPnLn5mdQv7Zwvbvpaxqs831goi9kQ= -golang.org/x/sys v0.37.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= +golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc= +golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= golang.org/x/time v0.14.0 h1:MRx4UaLrDotUKUdCIqzPC48t1Y9hANFKIRpNx+Te8PI= golang.org/x/time v0.14.0/go.mod h1:eL/Oa2bBBK0TkX57Fyni+NgnyQQN4LitPmob2Hjnqw4= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= diff --git a/server/client.go b/server/client.go index 6ade5020cb6..721768d9277 100644 --- a/server/client.go +++ b/server/client.go @@ -2587,9 +2587,11 @@ func (c *client) sendPing() { // Generates the INFO to be sent to the client with the client ID included. // info arg will be copied since passed by value. // Assume lock is held. 
-func (c *client) generateClientInfoJSON(info Info) []byte { +func (c *client) generateClientInfoJSON(info Info, includeClientIP bool) []byte { info.CID = c.cid - info.ClientIP = c.host + if includeClientIP { + info.ClientIP = c.host + } info.MaxPayload = c.mpay if c.isWebsocket() { info.ClientConnectURLs = info.WSConnectURLs @@ -2670,7 +2672,7 @@ func (c *client) processPing() { info.RemoteAccount = c.acc.Name info.IsSystemAccount = c.acc == srv.SystemAccount() info.ConnectInfo = true - c.enqueueProto(c.generateClientInfoJSON(info)) + c.enqueueProto(c.generateClientInfoJSON(info, true)) c.mu.Unlock() srv.mu.Unlock() } diff --git a/server/client_proxyproto.go b/server/client_proxyproto.go new file mode 100644 index 00000000000..cdfbda7609b --- /dev/null +++ b/server/client_proxyproto.go @@ -0,0 +1,398 @@ +// Copyright 2025 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package server + +import ( + "encoding/binary" + "errors" + "fmt" + "io" + "net" + "strconv" + "strings" + "time" +) + +// PROXY protocol v2 constants +const ( + // Protocol signature (12 bytes) + proxyProtoV2Sig = "\x0D\x0A\x0D\x0A\x00\x0D\x0A\x51\x55\x49\x54\x0A" + + // Version and command byte format: version(4 bits) | command(4 bits) + proxyProtoV2VerMask = 0xF0 + proxyProtoV2Ver = 0x20 // Version 2 + + // Commands + proxyProtoCmdMask = 0x0F + proxyProtoCmdLocal = 0x00 // LOCAL command (health check, use original connection) + proxyProtoCmdProxy = 0x01 // PROXY command (proxied connection) + + // Address family and protocol byte format: family(4 bits) | protocol(4 bits) + proxyProtoFamilyMask = 0xF0 + proxyProtoFamilyUnspec = 0x00 // Unspecified + proxyProtoFamilyInet = 0x10 // IPv4 + proxyProtoFamilyInet6 = 0x20 // IPv6 + proxyProtoFamilyUnix = 0x30 // Unix socket + proxyProtoProtoMask = 0x0F + proxyProtoProtoUnspec = 0x00 // Unspecified + proxyProtoProtoStream = 0x01 // TCP/STREAM + proxyProtoProtoDatagram = 0x02 // UDP/DGRAM + + // Address sizes + proxyProtoAddrSizeIPv4 = 12 // 4 (src IP) + 4 (dst IP) + 2 (src port) + 2 (dst port) + proxyProtoAddrSizeIPv6 = 36 // 16 (src IP) + 16 (dst IP) + 2 (src port) + 2 (dst port) + + // Header sizes + proxyProtoV2HeaderSize = 16 // Fixed header: 12 (sig) + 1 (ver/cmd) + 1 (fam/proto) + 2 (addr len) + + // Timeout for reading PROXY protocol header + proxyProtoReadTimeout = 5 * time.Second +) + +// PROXY protocol v1 constants +const ( + proxyProtoV1Prefix = "PROXY " + proxyProtoV1MaxLineLen = 107 // Maximum line length including CRLF + proxyProtoV1TCP4 = "TCP4" + proxyProtoV1TCP6 = "TCP6" + proxyProtoV1Unknown = "UNKNOWN" +) + +var ( + // Errors + errProxyProtoInvalid = errors.New("invalid PROXY protocol header") + errProxyProtoUnsupported = errors.New("unsupported PROXY protocol feature") + errProxyProtoTimeout = errors.New("timeout reading PROXY protocol header") + errProxyProtoUnrecognized = errors.New("unrecognized 
PROXY protocol format") +) + +// proxyProtoAddr contains the address information extracted from PROXY protocol header +type proxyProtoAddr struct { + srcIP net.IP + srcPort uint16 + dstIP net.IP + dstPort uint16 +} + +// String implements net.Addr interface +func (p *proxyProtoAddr) String() string { + return net.JoinHostPort(p.srcIP.String(), fmt.Sprintf("%d", p.srcPort)) +} + +// Network implements net.Addr interface +func (p *proxyProtoAddr) Network() string { + if p.srcIP.To4() != nil { + return "tcp4" + } + return "tcp6" +} + +// proxyConn wraps a net.Conn to override RemoteAddr() with the address +// extracted from the PROXY protocol header +type proxyConn struct { + net.Conn + remoteAddr net.Addr +} + +// RemoteAddr returns the original client address extracted from PROXY protocol +func (pc *proxyConn) RemoteAddr() net.Addr { + return pc.remoteAddr +} + +// detectProxyProtoVersion reads the first bytes and determines protocol version. +// Returns 1 for v1, 2 for v2, or error. +// The first 6 bytes read are returned so they can be used by the parser. +func detectProxyProtoVersion(conn net.Conn) (version int, header []byte, err error) { + // Read first 6 bytes to check for "PROXY " or v2 signature + header = make([]byte, 6) + if _, err = io.ReadFull(conn, header); err != nil { + return 0, nil, fmt.Errorf("failed to read protocol version: %w", err) + } + switch bytesToString(header) { + case proxyProtoV1Prefix: + return 1, header, nil + case proxyProtoV2Sig[:6]: + return 2, header, nil + default: + return 0, nil, errProxyProtoUnrecognized + } +} + +// readProxyProtoV1Header parses PROXY protocol v1 text format. +// Expects the "PROXY " prefix (6 bytes) to have already been consumed. 
+func readProxyProtoV1Header(conn net.Conn) (*proxyProtoAddr, error) { + // Read rest of line (max 107 bytes total, already read 6) + maxRemaining := proxyProtoV1MaxLineLen - 6 + + // Read up to maxRemaining bytes at once (more efficient than byte-by-byte) + buf := make([]byte, maxRemaining) + var line []byte + + for len(line) < maxRemaining { + // Read available data + n, err := conn.Read(buf[len(line):]) + if err != nil { + return nil, fmt.Errorf("failed to read v1 line: %w", err) + } + + line = buf[:len(line)+n] + + // Look for CRLF in what we've read so far + for i := 0; i < len(line)-1; i++ { + if line[i] == '\r' && line[i+1] == '\n' { + // Found CRLF - extract just the line portion + line = line[:i] + goto foundCRLF + } + } + } + + // Exceeded max length without finding CRLF + return nil, fmt.Errorf("%w: v1 line too long", errProxyProtoInvalid) + +foundCRLF: + // Get parts from the protocol + parts := strings.Fields(string(line)) + + // Validate format + if len(parts) < 1 { + return nil, fmt.Errorf("%w: invalid v1 format", errProxyProtoInvalid) + } + + // Handle UNKNOWN (health check, like v2 LOCAL) + if parts[0] == proxyProtoV1Unknown { + return nil, nil + } + + // Must have exactly 5 parts: protocol, src-ip, dst-ip, src-port, dst-port + if len(parts) != 5 { + return nil, fmt.Errorf("%w: invalid v1 format", errProxyProtoInvalid) + } + + protocol := parts[0] + srcIP := net.ParseIP(parts[1]) + dstIP := net.ParseIP(parts[2]) + + if srcIP == nil || dstIP == nil { + return nil, fmt.Errorf("%w: invalid address", errProxyProtoInvalid) + } + + // Parse ports + srcPort, err := strconv.ParseUint(parts[3], 10, 16) + if err != nil { + return nil, fmt.Errorf("invalid source port: %w", err) + } + + dstPort, err := strconv.ParseUint(parts[4], 10, 16) + if err != nil { + return nil, fmt.Errorf("invalid dest port: %w", err) + } + + // Validate protocol matches IP version + if protocol == proxyProtoV1TCP4 && srcIP.To4() == nil { + return nil, fmt.Errorf("%w: TCP4 with IPv6 
address", errProxyProtoInvalid) + } + if protocol == proxyProtoV1TCP6 && srcIP.To4() != nil { + return nil, fmt.Errorf("%w: TCP6 with IPv4 address", errProxyProtoInvalid) + } + if protocol != proxyProtoV1TCP4 && protocol != proxyProtoV1TCP6 { + return nil, fmt.Errorf("%w: invalid protocol %s", errProxyProtoInvalid, protocol) + } + + return &proxyProtoAddr{ + srcIP: srcIP, + srcPort: uint16(srcPort), + dstIP: dstIP, + dstPort: uint16(dstPort), + }, nil +} + +// readProxyProtoHeader reads and parses PROXY protocol (v1 or v2) from the connection. +// Automatically detects version and routes to appropriate parser. +// If the command is LOCAL/UNKNOWN (health check), it returns nil for addr and no error. +// If the command is PROXY, it returns the parsed address information. +// The connection must be fresh (no data read yet). +func readProxyProtoHeader(conn net.Conn) (*proxyProtoAddr, error) { + // Set read deadline to prevent hanging on slow/malicious clients + if err := conn.SetReadDeadline(time.Now().Add(proxyProtoReadTimeout)); err != nil { + return nil, err + } + defer conn.SetReadDeadline(time.Time{}) + + // Detect version + version, firstBytes, err := detectProxyProtoVersion(conn) + if err != nil { + return nil, err + } + + switch version { + case 1: + // v1 parser expects "PROXY " prefix already consumed + return readProxyProtoV1Header(conn) + case 2: + // Read rest of v2 signature (bytes 6-11, total 6 more bytes) + remaining := make([]byte, 6) + if _, err := io.ReadFull(conn, remaining); err != nil { + return nil, fmt.Errorf("failed to read v2 signature: %w", err) + } + + // Verify full signature + fullSig := string(firstBytes) + string(remaining) + if fullSig != proxyProtoV2Sig { + return nil, fmt.Errorf("%w: invalid signature", errProxyProtoInvalid) + } + + // Read rest of header: ver/cmd, fam/proto, addr-len (4 bytes) + header := make([]byte, 4) + if _, err := io.ReadFull(conn, header); err != nil { + return nil, fmt.Errorf("failed to read v2 header: %w", 
err) + } + + // Continue with parsing + return parseProxyProtoV2Header(conn, header) + default: + return nil, fmt.Errorf("unsupported PROXY protocol version: %d", version) + } +} + +// readProxyProtoV2Header is kept for backward compatibility and direct testing. +// It reads and parses a PROXY protocol v2 header from the connection. +// If the command is LOCAL (health check), it returns nil for addr and no error. +// If the command is PROXY, it returns the parsed address information. +// The connection must be fresh (no data read yet). +func readProxyProtoV2Header(conn net.Conn) (*proxyProtoAddr, error) { + // Set read deadline to prevent hanging on slow/malicious clients + if err := conn.SetReadDeadline(time.Now().Add(proxyProtoReadTimeout)); err != nil { + return nil, err + } + defer conn.SetReadDeadline(time.Time{}) + + // Read fixed header (16 bytes) + header := make([]byte, proxyProtoV2HeaderSize) + if _, err := io.ReadFull(conn, header); err != nil { + if ne, ok := err.(net.Error); ok && ne.Timeout() { + return nil, errProxyProtoTimeout + } + return nil, fmt.Errorf("failed to read PROXY protocol header: %w", err) + } + + // Validate signature (first 12 bytes) + if string(header[:12]) != proxyProtoV2Sig { + return nil, fmt.Errorf("%w: invalid signature", errProxyProtoInvalid) + } + + // Continue with parsing after signature + return parseProxyProtoV2Header(conn, header[12:16]) +} + +// parseProxyProtoV2Header parses v2 protocol after signature has been validated. +// header contains the 4 bytes: ver/cmd, fam/proto, addr-len (2 bytes). 
+func parseProxyProtoV2Header(conn net.Conn, header []byte) (*proxyProtoAddr, error) { + // Parse version and command + verCmd := header[0] + version := verCmd & proxyProtoV2VerMask + command := verCmd & proxyProtoCmdMask + + if version != proxyProtoV2Ver { + return nil, fmt.Errorf("%w: invalid version 0x%02x", errProxyProtoInvalid, version) + } + + // Parse address family and protocol + famProto := header[1] + family := famProto & proxyProtoFamilyMask + protocol := famProto & proxyProtoProtoMask + + // Parse address length (big-endian uint16) + addrLen := binary.BigEndian.Uint16(header[2:4]) + + // Handle LOCAL command (health check) + if command == proxyProtoCmdLocal { + // For LOCAL, we should skip the address data if any + if addrLen > 0 { + // Discard the address data + if _, err := io.CopyN(io.Discard, conn, int64(addrLen)); err != nil { + return nil, fmt.Errorf("failed to discard LOCAL command address data: %w", err) + } + } + return nil, nil // nil addr indicates LOCAL command + } + + // Handle PROXY command + if command != proxyProtoCmdProxy { + return nil, fmt.Errorf("unknown PROXY protocol command: 0x%02x", command) + } + + // Validate protocol (we only support STREAM/TCP) + if protocol != proxyProtoProtoStream { + return nil, fmt.Errorf("%w: only STREAM protocol supported", errProxyProtoUnsupported) + } + + // Parse address data based on family + var addr *proxyProtoAddr + var err error + switch family { + case proxyProtoFamilyInet: + addr, err = parseIPv4Addr(conn, addrLen) + case proxyProtoFamilyInet6: + addr, err = parseIPv6Addr(conn, addrLen) + case proxyProtoFamilyUnspec: + // UNSPEC family with PROXY command is valid but rare + // Just skip the address data + if addrLen > 0 { + if _, err := io.CopyN(io.Discard, conn, int64(addrLen)); err != nil { + return nil, fmt.Errorf("failed to discard UNSPEC address data: %w", err) + } + } + return nil, nil + default: + return nil, fmt.Errorf("%w: unsupported address family 0x%02x", 
errProxyProtoUnsupported, family) + } + return addr, err +} + +// parseIPv4Addr parses IPv4 address data from PROXY protocol header +func parseIPv4Addr(conn net.Conn, addrLen uint16) (*proxyProtoAddr, error) { + // IPv4: 4 (src IP) + 4 (dst IP) + 2 (src port) + 2 (dst port) = 12 bytes minimum + if addrLen < proxyProtoAddrSizeIPv4 { + return nil, fmt.Errorf("IPv4 address data too short: %d bytes", addrLen) + } + addrData := make([]byte, addrLen) + if _, err := io.ReadFull(conn, addrData); err != nil { + return nil, fmt.Errorf("failed to read IPv4 address data: %w", err) + } + return &proxyProtoAddr{ + srcIP: net.IP(addrData[0:4]), + dstIP: net.IP(addrData[4:8]), + srcPort: binary.BigEndian.Uint16(addrData[8:10]), + dstPort: binary.BigEndian.Uint16(addrData[10:12]), + }, nil +} + +// parseIPv6Addr parses IPv6 address data from PROXY protocol header +func parseIPv6Addr(conn net.Conn, addrLen uint16) (*proxyProtoAddr, error) { + // IPv6: 16 (src IP) + 16 (dst IP) + 2 (src port) + 2 (dst port) = 36 bytes minimum + if addrLen < proxyProtoAddrSizeIPv6 { + return nil, fmt.Errorf("IPv6 address data too short: %d bytes", addrLen) + } + addrData := make([]byte, addrLen) + if _, err := io.ReadFull(conn, addrData); err != nil { + return nil, fmt.Errorf("failed to read IPv6 address data: %w", err) + } + return &proxyProtoAddr{ + srcIP: net.IP(addrData[0:16]), + dstIP: net.IP(addrData[16:32]), + srcPort: binary.BigEndian.Uint16(addrData[32:34]), + dstPort: binary.BigEndian.Uint16(addrData[34:36]), + }, nil +} diff --git a/server/client_proxyproto_test.go b/server/client_proxyproto_test.go new file mode 100644 index 00000000000..512e6bf987a --- /dev/null +++ b/server/client_proxyproto_test.go @@ -0,0 +1,594 @@ +// Copyright 2025 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package server + +import ( + "bytes" + "encoding/binary" + "fmt" + "io" + "net" + "strings" + "testing" + "time" +) + +// mockConn is a mock net.Conn for testing +type mockConn struct { + net.Conn + readBuf *bytes.Buffer + writeBuf *bytes.Buffer + closed bool + deadline time.Time +} + +func newMockConn(data []byte) *mockConn { + return &mockConn{ + readBuf: bytes.NewBuffer(data), + writeBuf: &bytes.Buffer{}, + } +} + +func (m *mockConn) Read(b []byte) (int, error) { + if m.closed { + return 0, fmt.Errorf("connection closed") + } + if !m.deadline.IsZero() && time.Now().After(m.deadline) { + return 0, &net.OpError{Op: "read", Err: fmt.Errorf("timeout")} + } + return m.readBuf.Read(b) +} + +func (m *mockConn) Write(b []byte) (int, error) { + if m.closed { + return 0, fmt.Errorf("connection closed") + } + return m.writeBuf.Write(b) +} + +func (m *mockConn) Close() error { + m.closed = true + return nil +} + +func (m *mockConn) LocalAddr() net.Addr { + return &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 4222} +} + +func (m *mockConn) RemoteAddr() net.Addr { + return &net.TCPAddr{IP: net.ParseIP("192.168.1.100"), Port: 54321} +} + +func (m *mockConn) SetDeadline(t time.Time) error { + m.deadline = t + return nil +} + +func (m *mockConn) SetReadDeadline(t time.Time) error { + m.deadline = t + return nil +} + +func (m *mockConn) SetWriteDeadline(t time.Time) error { + return nil +} + +// buildProxyV2Header builds a valid PROXY protocol v2 header +func buildProxyV2Header(t *testing.T, srcIP, dstIP string, srcPort, dstPort uint16, family byte) []byte 
{ + t.Helper() + + buf := &bytes.Buffer{} + buf.WriteString(proxyProtoV2Sig) // Write signature + buf.WriteByte(proxyProtoV2Ver | proxyProtoCmdProxy) // Write version and command (version 2, PROXY command) + buf.WriteByte(family | proxyProtoProtoStream) // Write family and protocol + + src := net.ParseIP(srcIP) + dst := net.ParseIP(dstIP) + var addrData []byte + switch family { + case proxyProtoFamilyInet: + // IPv4: 12 bytes + addrData = make([]byte, proxyProtoAddrSizeIPv4) + copy(addrData[0:4], src.To4()) + copy(addrData[4:8], dst.To4()) + binary.BigEndian.PutUint16(addrData[8:10], srcPort) + binary.BigEndian.PutUint16(addrData[10:12], dstPort) + case proxyProtoFamilyInet6: + // IPv6: 36 bytes + addrData = make([]byte, proxyProtoAddrSizeIPv6) + copy(addrData[0:16], src.To16()) + copy(addrData[16:32], dst.To16()) + binary.BigEndian.PutUint16(addrData[32:34], srcPort) + binary.BigEndian.PutUint16(addrData[34:36], dstPort) + default: + t.Fatalf("unsupported address family: %d", family) + } + + addrLen := make([]byte, 2) + binary.BigEndian.PutUint16(addrLen, uint16(len(addrData))) + buf.Write(addrLen) + buf.Write(addrData) + + return buf.Bytes() +} + +// buildProxyV2LocalHeader builds a PROXY protocol v2 LOCAL command header +func buildProxyV2LocalHeader() []byte { + buf := &bytes.Buffer{} + buf.WriteString(proxyProtoV2Sig) // Write signature + buf.WriteByte(proxyProtoV2Ver | proxyProtoCmdLocal) // Write version and command (version 2, LOCAL command) + buf.WriteByte(proxyProtoFamilyUnspec | proxyProtoProtoUnspec) // Write family and protocol (UNSPEC) + buf.WriteByte(0) + buf.WriteByte(0) + return buf.Bytes() +} + +// buildProxyV1Header builds a valid PROXY protocol v1 header (text format) +func buildProxyV1Header(t *testing.T, protocol, srcIP, dstIP string, srcPort, dstPort uint16) []byte { + t.Helper() + + if protocol != "TCP4" && protocol != "TCP6" && protocol != "UNKNOWN" { + t.Fatalf("invalid protocol: %s", protocol) + } + + var line string + if protocol 
== "UNKNOWN" { + line = "PROXY UNKNOWN\r\n" + } else { + line = fmt.Sprintf("PROXY %s %s %s %d %d\r\n", protocol, srcIP, dstIP, srcPort, dstPort) + } + + return []byte(line) +} + +func TestClientProxyProtoV2ParseIPv4(t *testing.T) { + header := buildProxyV2Header(t, "192.168.1.50", "10.0.0.1", 12345, 4222, proxyProtoFamilyInet) + conn := newMockConn(header) + + addr, err := readProxyProtoV2Header(conn) + require_NoError(t, err) + require_NotNil(t, addr) + + require_Equal(t, addr.srcIP.String(), "192.168.1.50") + require_Equal(t, addr.srcPort, 12345) + + require_Equal(t, addr.dstIP.String(), "10.0.0.1") + require_Equal(t, addr.dstPort, 4222) + + // Test String() and Network() methods + require_Equal(t, addr.String(), "192.168.1.50:12345") + require_Equal(t, addr.Network(), "tcp4") +} + +func TestClientProxyProtoV2ParseIPv6(t *testing.T) { + header := buildProxyV2Header(t, "2001:db8::1", "2001:db8::2", 54321, 4222, proxyProtoFamilyInet6) + conn := newMockConn(header) + + addr, err := readProxyProtoV2Header(conn) + require_NoError(t, err) + require_NotNil(t, addr) + + require_Equal(t, addr.srcIP.String(), "2001:db8::1") + require_Equal(t, addr.srcPort, 54321) + + require_Equal(t, addr.dstIP.String(), "2001:db8::2") + require_Equal(t, addr.dstPort, 4222) + + // Test Network() method for IPv6 + require_Equal(t, addr.String(), "[2001:db8::1]:54321") + require_Equal(t, addr.Network(), "tcp6") +} + +func TestClientProxyProtoV2ParseLocalCommand(t *testing.T) { + header := buildProxyV2LocalHeader() + conn := newMockConn(header) + + addr, err := readProxyProtoV2Header(conn) + require_NoError(t, err) + require_True(t, addr == nil) +} + +func TestClientProxyProtoV2InvalidSignature(t *testing.T) { + // Create invalid signature + header := []byte("INVALID_SIG_") + header = append(header, []byte{0x20, 0x11, 0x00, 0x0C}...) 
+ conn := newMockConn(header) + + _, err := readProxyProtoV2Header(conn) + require_Error(t, err, errProxyProtoInvalid) +} + +func TestClientProxyProtoV2InvalidVersion(t *testing.T) { + buf := &bytes.Buffer{} + buf.WriteString(proxyProtoV2Sig) + buf.WriteByte(0x10 | proxyProtoCmdProxy) // Version 1 instead of 2 + buf.WriteByte(proxyProtoFamilyInet | proxyProtoProtoStream) + buf.WriteByte(0) + buf.WriteByte(0) + + conn := newMockConn(buf.Bytes()) + + _, err := readProxyProtoV2Header(conn) + require_Error(t, err, errProxyProtoInvalid) +} + +func TestClientProxyProtoV2UnsupportedFamily(t *testing.T) { + buf := &bytes.Buffer{} + buf.WriteString(proxyProtoV2Sig) + buf.WriteByte(proxyProtoV2Ver | proxyProtoCmdProxy) + buf.WriteByte(proxyProtoFamilyUnix | proxyProtoProtoStream) // Unix socket family + buf.WriteByte(0) + buf.WriteByte(0) + + conn := newMockConn(buf.Bytes()) + + _, err := readProxyProtoV2Header(conn) + require_Error(t, err, errProxyProtoUnsupported) +} + +func TestClientProxyProtoV2UnsupportedProtocol(t *testing.T) { + buf := &bytes.Buffer{} + buf.WriteString(proxyProtoV2Sig) + buf.WriteByte(proxyProtoV2Ver | proxyProtoCmdProxy) + buf.WriteByte(proxyProtoFamilyInet | proxyProtoProtoDatagram) // UDP instead of TCP + buf.WriteByte(0) + buf.WriteByte(12) + + conn := newMockConn(buf.Bytes()) + + _, err := readProxyProtoV2Header(conn) + require_Error(t, err, errProxyProtoUnsupported) +} + +func TestClientProxyProtoV2TruncatedHeader(t *testing.T) { + header := buildProxyV2Header(t, "192.168.1.50", "10.0.0.1", 12345, 4222, proxyProtoFamilyInet) + // Only send first 10 bytes (incomplete header) + conn := newMockConn(header[:10]) + + _, err := readProxyProtoV2Header(conn) + require_Error(t, err, io.ErrUnexpectedEOF) +} + +func TestClientProxyProtoV2ShortAddressData(t *testing.T) { + buf := &bytes.Buffer{} + buf.WriteString(proxyProtoV2Sig) + buf.WriteByte(proxyProtoV2Ver | proxyProtoCmdProxy) + buf.WriteByte(proxyProtoFamilyInet | proxyProtoProtoStream) + // Set 
address length to 12 but don't provide data + buf.WriteByte(0) + buf.WriteByte(12) + // Only provide 5 bytes instead of 12 + buf.Write([]byte{1, 2, 3, 4, 5}) + + conn := newMockConn(buf.Bytes()) + + _, err := readProxyProtoV2Header(conn) + require_Error(t, err, io.ErrUnexpectedEOF) +} + +func TestProxyConnRemoteAddr(t *testing.T) { + // Create a real TCP connection for testing + originalAddr := &net.TCPAddr{IP: net.ParseIP("192.168.1.100"), Port: 54321} + + // Create proxy address + proxyAddr := &proxyProtoAddr{ + srcIP: net.ParseIP("10.0.0.50"), + srcPort: 12345, + dstIP: net.ParseIP("10.0.0.1"), + dstPort: 4222, + } + + mockConn := newMockConn(nil) + wrapped := &proxyConn{ + Conn: mockConn, + remoteAddr: proxyAddr, + } + + // Verify RemoteAddr returns the proxied address + addr := wrapped.RemoteAddr() + require_Equal(t, addr.String(), "10.0.0.50:12345") + require_Equal(t, mockConn.RemoteAddr().String(), originalAddr.String()) +} + +func TestClientProxyProtoV2EndToEnd(t *testing.T) { + // Start a test server with PROXY protocol enabled + opts := DefaultOptions() + opts.Port = -1 // Random port + opts.ProxyProtocol = true + + s := RunServer(opts) + defer s.Shutdown() + + // Get the server's listening port + addr := s.Addr().String() + + // Connect to the server + conn, err := net.Dial("tcp", addr) + require_NoError(t, err) + defer conn.Close() + + // Send PROXY protocol header + clientIP := "203.0.113.50" + clientPort := uint16(54321) + header := buildProxyV2Header(t, clientIP, "127.0.0.1", clientPort, 4222, proxyProtoFamilyInet) + + _, err = conn.Write(header) + require_NoError(t, err) + + // Send CONNECT message + connectMsg := "CONNECT {\"verbose\":false,\"pedantic\":false,\"protocol\":1}\r\n" + _, err = conn.Write([]byte(connectMsg)) + require_NoError(t, err) + + // Read INFO and +OK + buf := make([]byte, 4096) + n, err := conn.Read(buf) + require_NoError(t, err) + + response := string(buf[:n]) + require_True(t, strings.Contains(response, "INFO")) + + // Give 
server time to process + time.Sleep(100 * time.Millisecond) + + // Check server's client list to verify the IP was extracted correctly + s.mu.Lock() + clients := s.clients + s.mu.Unlock() + require_True(t, len(clients) != 0) + + // Find our client + var foundClient *client + for _, c := range clients { + c.mu.Lock() + if c.host == clientIP && c.port == clientPort { + foundClient = c + } + c.mu.Unlock() + if foundClient != nil { + break + } + } + require_NotNil(t, foundClient) +} + +func TestClientProxyProtoV2LocalCommandEndToEnd(t *testing.T) { + // Start a test server with PROXY protocol enabled + opts := DefaultOptions() + opts.Port = -1 // Random port + opts.ProxyProtocol = true + + s := RunServer(opts) + defer s.Shutdown() + + // Get the server's listening port + addr := s.Addr().String() + + // Connect to the server + conn, err := net.Dial("tcp", addr) + require_NoError(t, err) + defer conn.Close() + + // Send PROXY protocol LOCAL header (health check) + header := buildProxyV2LocalHeader() + + _, err = conn.Write(header) + require_NoError(t, err) + + // Send CONNECT message + connectMsg := "CONNECT {\"verbose\":false,\"pedantic\":false,\"protocol\":1}\r\n" + _, err = conn.Write([]byte(connectMsg)) + require_NoError(t, err) + + // Read INFO and +OK + buf := make([]byte, 4096) + n, err := conn.Read(buf) + require_NoError(t, err) + + response := string(buf[:n]) + require_True(t, strings.Contains(response, "INFO")) + + // Connection should work normally with LOCAL command + time.Sleep(100 * time.Millisecond) + + // Verify at least one client is connected + s.mu.Lock() + numClients := len(s.clients) + s.mu.Unlock() + require_NotEqual(t, numClients, 0) +} + +// ============================================================================ +// PROXY Protocol v1 Tests +// ============================================================================ + +func TestClientProxyProtoV1ParseTCP4(t *testing.T) { + header := buildProxyV1Header(t, "TCP4", "192.168.1.50", 
"10.0.0.1", 12345, 4222) + conn := newMockConn(header) + + addr, err := readProxyProtoHeader(conn) + require_NoError(t, err) + require_NotNil(t, addr) + + require_Equal(t, addr.srcIP.String(), "192.168.1.50") + require_Equal(t, addr.srcPort, 12345) + + require_Equal(t, addr.dstIP.String(), "10.0.0.1") + require_Equal(t, addr.dstPort, 4222) +} + +func TestClientProxyProtoV1ParseTCP6(t *testing.T) { + header := buildProxyV1Header(t, "TCP6", "2001:db8::1", "2001:db8::2", 54321, 4222) + conn := newMockConn(header) + + addr, err := readProxyProtoHeader(conn) + require_NoError(t, err) + require_NotNil(t, addr) + + require_Equal(t, addr.srcIP.String(), "2001:db8::1") + require_Equal(t, addr.srcPort, 54321) + + require_Equal(t, addr.dstIP.String(), "2001:db8::2") + require_Equal(t, addr.dstPort, 4222) +} + +func TestClientProxyProtoV1ParseUnknown(t *testing.T) { + header := buildProxyV1Header(t, "UNKNOWN", "", "", 0, 0) + conn := newMockConn(header) + + addr, err := readProxyProtoHeader(conn) + require_NoError(t, err) + require_True(t, addr == nil) +} + +func TestClientProxyProtoV1InvalidFormat(t *testing.T) { + // Missing fields + header := []byte("PROXY TCP4 192.168.1.1\r\n") + conn := newMockConn(header) + + _, err := readProxyProtoHeader(conn) + require_Error(t, err, errProxyProtoInvalid) +} + +func TestClientProxyProtoV1LineTooLong(t *testing.T) { + // Create a line longer than 107 bytes + longIP := strings.Repeat("1234567890", 12) // 120 chars + header := fmt.Appendf(nil, "PROXY TCP4 %s 10.0.0.1 12345 443\r\n", longIP) + conn := newMockConn(header) + + _, err := readProxyProtoHeader(conn) + require_Error(t, err, errProxyProtoInvalid) +} + +func TestClientProxyProtoV1InvalidIP(t *testing.T) { + header := []byte("PROXY TCP4 not.an.ip.addr 10.0.0.1 12345 443\r\n") + conn := newMockConn(header) + + _, err := readProxyProtoHeader(conn) + require_Error(t, err, errProxyProtoInvalid) +} + +func TestClientProxyProtoV1MismatchedProtocol(t *testing.T) { + // TCP4 with IPv6 
address + header := buildProxyV1Header(t, "TCP4", "2001:db8::1", "2001:db8::2", 12345, 443) + conn := newMockConn(header) + + _, err := readProxyProtoHeader(conn) + require_Error(t, err, errProxyProtoInvalid) + + // TCP6 with IPv4 address + header2 := buildProxyV1Header(t, "TCP6", "192.168.1.1", "10.0.0.1", 12345, 443) + conn2 := newMockConn(header2) + + _, err = readProxyProtoHeader(conn2) + require_Error(t, err, errProxyProtoInvalid) +} + +func TestClientProxyProtoV1InvalidPort(t *testing.T) { + header := []byte("PROXY TCP4 192.168.1.1 10.0.0.1 99999 443\r\n") + conn := newMockConn(header) + + _, err := readProxyProtoHeader(conn) + require_True(t, err != nil) +} + +func TestClientProxyProtoV1EndToEnd(t *testing.T) { + // Start a test server with PROXY protocol enabled + opts := DefaultOptions() + opts.Port = -1 // Random port + opts.ProxyProtocol = true + + s := RunServer(opts) + defer s.Shutdown() + + // Get the server's listening port + addr := s.Addr().String() + + // Connect to the server + conn, err := net.Dial("tcp", addr) + require_NoError(t, err) + defer conn.Close() + + // Send PROXY protocol v1 header + clientIP, clientPort := "203.0.113.50", uint16(54321) + header := buildProxyV1Header(t, "TCP4", clientIP, "127.0.0.1", clientPort, 4222) + + _, err = conn.Write(header) + require_NoError(t, err) + + // Send CONNECT message + _, err = conn.Write([]byte("CONNECT {\"verbose\":false,\"pedantic\":false,\"protocol\":1}\r\n")) + require_NoError(t, err) + + // Read INFO and +OK + buf := make([]byte, 4096) + n, err := conn.Read(buf) + require_NoError(t, err) + + response := string(buf[:n]) + require_True(t, strings.Contains(response, "INFO")) + + // Give server time to process + time.Sleep(100 * time.Millisecond) + + // Check server's client list to verify the IP was extracted correctly + s.mu.Lock() + clients := s.clients + var foundClient *client + for _, c := range clients { + c.mu.Lock() + if c.host == clientIP && c.port == clientPort { + foundClient = c + } 
+ c.mu.Unlock() + if foundClient != nil { + break + } + } + s.mu.Unlock() + require_NotNil(t, foundClient) +} + +// ============================================================================ +// Mixed Protocol Version Tests +// ============================================================================ + +func TestClientProxyProtoVersionDetection(t *testing.T) { + // Test v1 detection + v1Header := buildProxyV1Header(t, "TCP4", "192.168.1.1", "10.0.0.1", 12345, 443) + conn1 := newMockConn(v1Header) + + addr1, err := readProxyProtoHeader(conn1) + require_NoError(t, err) + require_NotNil(t, addr1) + require_Equal(t, addr1.srcIP.String(), "192.168.1.1") + + // Test v2 detection + v2Header := buildProxyV2Header(t, "192.168.1.2", "10.0.0.1", 54321, 443, proxyProtoFamilyInet) + conn2 := newMockConn(v2Header) + + addr2, err := readProxyProtoHeader(conn2) + require_NoError(t, err) + require_NotNil(t, addr2) + require_Equal(t, addr2.srcIP.String(), "192.168.1.2") +} + +func TestClientProxyProtoUnrecognizedVersion(t *testing.T) { + // Invalid header that doesn't match v1 or v2 + header := []byte("HELLO WORLD\r\n") + conn := newMockConn(header) + + _, err := readProxyProtoHeader(conn) + require_Error(t, err, errProxyProtoUnrecognized) +} diff --git a/server/filestore.go b/server/filestore.go index 0630083f8a0..c6e594e557f 100644 --- a/server/filestore.go +++ b/server/filestore.go @@ -25,7 +25,6 @@ import ( "encoding/json" "errors" "fmt" - "hash" "io" "io/fs" "math" @@ -194,7 +193,7 @@ type fileStore struct { psim *stree.SubjectTree[psi] tsl int adml int - hh hash.Hash64 + hh *highwayhash.Digest64 qch chan struct{} fsld chan struct{} cmu sync.RWMutex @@ -239,7 +238,7 @@ type msgBlock struct { lrts int64 lsts int64 llseq uint64 - hh hash.Hash64 + hh *highwayhash.Digest64 ecache elastic.Pointer[cache] cache *cache cloads uint64 @@ -468,7 +467,7 @@ func newFileStoreWithCreated(fcfg FileStoreConfig, cfg StreamConfig, created tim // Create highway hash for message blocks. 
Use sha256 of directory as key. key := sha256.Sum256([]byte(cfg.Name)) - fs.hh, err = highwayhash.New64(key[:]) + fs.hh, err = highwayhash.NewDigest64(key[:]) if err != nil { return nil, fmt.Errorf("could not create hash: %v", err) } @@ -939,7 +938,8 @@ func (fs *fileStore) writeStreamMeta() error { } fs.hh.Reset() fs.hh.Write(b) - checksum := hex.EncodeToString(fs.hh.Sum(nil)) + var hb [highwayhash.Size64]byte + checksum := hex.EncodeToString(fs.hh.Sum(hb[:0])) sum := filepath.Join(fs.fcfg.StoreDir, JetStreamMetaFileSum) err = fs.writeFileWithOptionalSync(sum, []byte(checksum), defaultFilePerms) if err != nil { @@ -1040,7 +1040,7 @@ func (fs *fileStore) initMsgBlock(index uint32) *msgBlock { if mb.hh == nil { key := sha256.Sum256(fs.hashKeyForBlock(index)) - mb.hh, _ = highwayhash.New64(key[:]) + mb.hh, _ = highwayhash.NewDigest64(key[:]) } return mb } @@ -4251,7 +4251,7 @@ func (fs *fileStore) newMsgBlockForWrite() (*msgBlock, error) { // Now do local hash. key := sha256.Sum256(fs.hashKeyForBlock(index)) - hh, err := highwayhash.New64(key[:]) + hh, err := highwayhash.NewDigest64(key[:]) if err != nil { return nil, fmt.Errorf("could not create hash: %v", err) } @@ -6394,10 +6394,17 @@ func (mb *msgBlock) writeMsgRecordLocked(rl, seq uint64, subj string, mhdr, msg // Only update index and do accounting if not a delete tombstone. if seq&tbit == 0 { + last := atomic.LoadUint64(&mb.last.seq) // Accounting, do this before stripping ebit, it is ebit aware. mb.updateAccounting(seq, ts, rl) // Strip ebit if set. seq = seq &^ ebit + // If we have a hole due to skipping many messages, fill it. 
+ if len(mb.cache.idx) > 0 && last+1 < seq { + for dseq := last + 1; dseq < seq; dseq++ { + mb.cache.idx = append(mb.cache.idx, dbit) + } + } // Write index if mb.cache.idx = append(mb.cache.idx, uint32(index)|cbit); len(mb.cache.idx) == 1 { mb.cache.fseq = seq @@ -7658,7 +7665,7 @@ func (mb *msgBlock) cacheLookupEx(seq uint64, sm *StoreMsg, doCopy bool) (*Store buf := mb.cache.buf[li:] // We use the high bit to denote we have already checked the checksum. - var hh hash.Hash64 + var hh *highwayhash.Digest64 if !hashChecked { hh = mb.hh // This will force the hash check in msgFromBuf. } @@ -7757,7 +7764,7 @@ func (fs *fileStore) msgForSeqLocked(seq uint64, sm *StoreMsg, needFSLock bool) // Internal function to return msg parts from a raw buffer. // Raw buffer will be copied into sm. // Lock should be held. -func (mb *msgBlock) msgFromBuf(buf []byte, sm *StoreMsg, hh hash.Hash64) (*StoreMsg, error) { +func (mb *msgBlock) msgFromBuf(buf []byte, sm *StoreMsg, hh *highwayhash.Digest64) (*StoreMsg, error) { return mb.msgFromBufEx(buf, sm, hh, true) } @@ -7765,14 +7772,14 @@ func (mb *msgBlock) msgFromBuf(buf []byte, sm *StoreMsg, hh hash.Hash64) (*Store // Raw buffer will NOT be copied into sm. // Only use for internal use, any message that is passed to upper layers should use mb.msgFromBuf. // Lock should be held. -func (mb *msgBlock) msgFromBufNoCopy(buf []byte, sm *StoreMsg, hh hash.Hash64) (*StoreMsg, error) { +func (mb *msgBlock) msgFromBufNoCopy(buf []byte, sm *StoreMsg, hh *highwayhash.Digest64) (*StoreMsg, error) { return mb.msgFromBufEx(buf, sm, hh, false) } // Internal function to return msg parts from a raw buffer. // copy boolean will determine if we make a copy or not. // Lock should be held. 
-func (mb *msgBlock) msgFromBufEx(buf []byte, sm *StoreMsg, hh hash.Hash64, doCopy bool) (*StoreMsg, error) { +func (mb *msgBlock) msgFromBufEx(buf []byte, sm *StoreMsg, hh *highwayhash.Digest64, doCopy bool) (*StoreMsg, error) { if len(buf) < emptyRecordLen { return nil, errBadMsg{mb.mfn, "record too short"} } @@ -10318,7 +10325,8 @@ func (fs *fileStore) streamSnapshot(w io.WriteCloser, includeConsumers bool, err hh := fs.hh hh.Reset() hh.Write(meta) - sum := []byte(hex.EncodeToString(fs.hh.Sum(nil))) + var hb [highwayhash.Size64]byte + sum := []byte(hex.EncodeToString(fs.hh.Sum(hb[:0]))) fs.mu.Unlock() // Meta first. @@ -10421,7 +10429,8 @@ func (fs *fileStore) streamSnapshot(w io.WriteCloser, includeConsumers bool, err } o.hh.Reset() o.hh.Write(meta) - sum := []byte(hex.EncodeToString(o.hh.Sum(nil))) + var hb [highwayhash.Size64]byte + sum := []byte(hex.EncodeToString(o.hh.Sum(hb[:0]))) // We can have the running state directly encoded now. state, err := o.encodeState() @@ -10638,7 +10647,7 @@ type consumerFileStore struct { name string odir string ifn string - hh hash.Hash64 + hh *highwayhash.Digest64 state ConsumerState fch chan struct{} qch chan struct{} @@ -10681,7 +10690,7 @@ func (fs *fileStore) ConsumerStore(name string, cfg *ConsumerConfig) (ConsumerSt ifn: filepath.Join(odir, consumerState), } key := sha256.Sum256([]byte(fs.cfg.Name + "/" + name)) - hh, err := highwayhash.New64(key[:]) + hh, err := highwayhash.NewDigest64(key[:]) if err != nil { return nil, fmt.Errorf("could not create hash: %v", err) } @@ -11316,7 +11325,8 @@ func (cfs *consumerFileStore) writeConsumerMeta() error { } cfs.hh.Reset() cfs.hh.Write(b) - checksum := hex.EncodeToString(cfs.hh.Sum(nil)) + var hb [highwayhash.Size64]byte + checksum := hex.EncodeToString(cfs.hh.Sum(hb[:0])) sum := filepath.Join(cfs.odir, JetStreamMetaFileSum) err = cfs.fs.writeFileWithOptionalSync(sum, []byte(checksum), defaultFilePerms) @@ -11700,14 +11710,14 @@ func (fs *fileStore) RemoveConsumer(o 
ConsumerStore) error { // Deprecated: stream templates are deprecated and will be removed in a future version. type templateFileStore struct { dir string - hh hash.Hash64 + hh *highwayhash.Digest64 } // Deprecated: stream templates are deprecated and will be removed in a future version. func newTemplateFileStore(storeDir string) *templateFileStore { tdir := filepath.Join(storeDir, tmplsDir) key := sha256.Sum256([]byte("templates")) - hh, err := highwayhash.New64(key[:]) + hh, err := highwayhash.NewDigest64(key[:]) if err != nil { return nil } @@ -11736,7 +11746,8 @@ func (ts *templateFileStore) Store(t *streamTemplate) error { // FIXME(dlc) - Do checksum ts.hh.Reset() ts.hh.Write(b) - checksum := hex.EncodeToString(ts.hh.Sum(nil)) + var hb [highwayhash.Size64]byte + checksum := hex.EncodeToString(ts.hh.Sum(hb[:0])) sum := filepath.Join(dir, JetStreamMetaFileSum) if err := os.WriteFile(sum, []byte(checksum), defaultFilePerms); err != nil { return err diff --git a/server/filestore_test.go b/server/filestore_test.go index d5e5906b0c3..22f3f2c5323 100644 --- a/server/filestore_test.go +++ b/server/filestore_test.go @@ -11087,3 +11087,49 @@ func TestFileStoreMissingDeletesAfterCompact(t *testing.T) { require_Equal(t, atomic.LoadUint64(&fmb.last.seq), 2) }) } + +func TestFileStoreIdxAccountingForSkipMsgs(t *testing.T) { + test := func(t *testing.T, skipMany bool) { + testFileStoreAllPermutations(t, func(t *testing.T, fcfg FileStoreConfig) { + cfg := StreamConfig{Name: "zzz", Subjects: []string{"foo"}, Storage: FileStorage} + created := time.Now() + fs, err := newFileStoreWithCreated(fcfg, cfg, created, prf(&fcfg), nil) + require_NoError(t, err) + defer fs.Stop() + + _, _, err = fs.StoreMsg("foo", nil, nil, 0) + require_NoError(t, err) + if skipMany { + require_NoError(t, fs.SkipMsgs(2, 10)) + } else { + for i := range 10 { + _, err = fs.SkipMsg(uint64(i + 2)) + require_NoError(t, err) + } + } + _, _, err = fs.StoreMsg("foo", nil, nil, 0) + require_NoError(t, err) + + fmb 
:= fs.getFirstBlock() + fmb.mu.Lock() + defer fmb.mu.Unlock() + + for i := range 12 { + seq := uint64(i + 1) + _, err = fmb.cacheLookupNoCopy(seq, nil) + if seq >= 2 && seq <= 11 { + require_Error(t, err, errDeletedMsg) + } else { + require_NoError(t, err) + } + } + + cache := fmb.cache + require_NotNil(t, cache) + require_Len(t, len(cache.idx), 12) + }) + } + + t.Run("SkipMsg", func(t *testing.T) { test(t, false) }) + t.Run("SkipMsgs", func(t *testing.T) { test(t, true) }) +} diff --git a/server/jetstream.go b/server/jetstream.go index 65c0ee743c5..78763f3ece2 100644 --- a/server/jetstream.go +++ b/server/jetstream.go @@ -1023,11 +1023,11 @@ func (s *Server) shutdownJetStream() { js.accounts = nil var qch chan struct{} - + var stopped chan struct{} if cc := js.cluster; cc != nil { if cc.qch != nil { - qch = cc.qch - cc.qch = nil + qch, stopped = cc.qch, cc.stopped + cc.qch, cc.stopped = nil, nil } js.stopUpdatesSub() if cc.c != nil { @@ -1044,14 +1044,11 @@ func (s *Server) shutdownJetStream() { // We will wait for a bit for it to close. // Do this without the lock. 
if qch != nil { + close(qch) // Must be close() to signal *all* listeners select { - case qch <- struct{}{}: - select { - case <-qch: - case <-time.After(2 * time.Second): - s.Warnf("Did not receive signal for successful shutdown of cluster routine") - } - default: + case <-stopped: + case <-time.After(10 * time.Second): + s.Warnf("Did not receive signal for successful shutdown of cluster routine") } } } @@ -1221,7 +1218,7 @@ func (a *Account) EnableJetStream(limits map[string]JetStreamAccountLimits, tq c tdir := filepath.Join(jsa.storeDir, tmplsDir) if stat, err := os.Stat(tdir); err == nil && stat.IsDir() { key := sha256.Sum256([]byte("templates")) - hh, err := highwayhash.New64(key[:]) + hh, err := highwayhash.NewDigest64(key[:]) if err != nil { return err } @@ -1245,7 +1242,8 @@ func (a *Account) EnableJetStream(limits map[string]JetStreamAccountLimits, tq c } hh.Reset() hh.Write(buf) - checksum := hex.EncodeToString(hh.Sum(nil)) + var hb [highwayhash.Size64]byte + checksum := hex.EncodeToString(hh.Sum(hb[:0])) if checksum != string(sum) { s.Warnf(" StreamTemplate checksums do not match %q vs %q", sum, checksum) continue @@ -1398,7 +1396,7 @@ func (a *Account) EnableJetStream(limits map[string]JetStreamAccountLimits, tq c return nil } key := sha256.Sum256([]byte(fi.Name())) - hh, err := highwayhash.New64(key[:]) + hh, err := highwayhash.NewDigest64(key[:]) if err != nil { return err } @@ -1423,7 +1421,8 @@ func (a *Account) EnableJetStream(limits map[string]JetStreamAccountLimits, tq c return nil } hh.Write(buf) - checksum := hex.EncodeToString(hh.Sum(nil)) + var hb [highwayhash.Size64]byte + checksum := hex.EncodeToString(hh.Sum(hb[:0])) if checksum != string(sum) { s.Warnf(" Stream metafile %q: checksums do not match %q vs %q", metafile, sum, checksum) return nil diff --git a/server/jetstream_cluster.go b/server/jetstream_cluster.go index 9e7de392609..5f53a6e3930 100644 --- a/server/jetstream_cluster.go +++ b/server/jetstream_cluster.go @@ -70,6 +70,8 @@ type 
jetStreamCluster struct { peerStreamCancelMove *subscription // To pop out the monitorCluster before the raft layer. qch chan struct{} + // To notify others that monitorCluster has actually stopped. + stopped chan struct{} // Track last meta snapshot time and duration for monitoring. lastMetaSnapTime int64 // Unix nanoseconds lastMetaSnapDuration int64 // Duration in nanoseconds @@ -641,12 +643,12 @@ func (js *jetStream) isStreamHealthy(acc *Account, sa *streamAssignment) error { case !mset.isMonitorRunning(): return errors.New("monitor goroutine not running") - case !node.Healthy(): - return errors.New("group node unhealthy") - case mset.isCatchingUp(): return errors.New("stream catching up") + case !node.Healthy(): + return errors.New("group node unhealthy") + default: return nil } @@ -954,6 +956,7 @@ func (js *jetStream) setupMetaGroup() error { s: s, c: c, qch: make(chan struct{}), + stopped: make(chan struct{}), } atomic.StoreInt32(&js.clustered, 1) c.registerWithAccount(sysAcc) @@ -1190,6 +1193,16 @@ func (js *jetStream) clusterQuitC() chan struct{} { return nil } +// Return the cluster stopped chan. +func (js *jetStream) clusterStoppedC() chan struct{} { + js.mu.RLock() + defer js.mu.RUnlock() + if js.cluster != nil { + return js.cluster.stopped + } + return nil +} + // Mark that the meta layer is recovering. 
func (js *jetStream) setMetaRecovering() { js.mu.Lock() @@ -1346,9 +1359,10 @@ func (js *jetStream) checkForOrphans() { func (js *jetStream) monitorCluster() { s, n := js.server(), js.getMetaGroup() - qch, rqch, lch, aq := js.clusterQuitC(), n.QuitC(), n.LeadChangeC(), n.ApplyQ() + qch, stopped, rqch, lch, aq := js.clusterQuitC(), js.clusterStoppedC(), n.QuitC(), n.LeadChangeC(), n.ApplyQ() defer s.grWG.Done() + defer close(stopped) s.Debugf("Starting metadata monitor") defer s.Debugf("Exiting metadata monitor") @@ -1445,8 +1459,6 @@ func (js *jetStream) monitorCluster() { case <-qch: // Clean signal from shutdown routine so do best effort attempt to snapshot meta layer. doSnapshot(false) - // Return the signal back since shutdown will be waiting. - close(qch) return case <-aq.ch: ces := aq.pop() @@ -4058,7 +4070,7 @@ func (js *jetStream) processStreamAssignment(sa *streamAssignment) { js.mu.Unlock() // Need to stop the stream, we can't keep running with an old config. - acc, err := s.LookupAccount(accName) + acc, err := s.lookupOrFetchAccount(accName, isMember) if err != nil { return } @@ -4072,7 +4084,7 @@ func (js *jetStream) processStreamAssignment(sa *streamAssignment) { } js.mu.Unlock() - acc, err := s.LookupAccount(accName) + acc, err := s.lookupOrFetchAccount(accName, isMember) if err != nil { ll := fmt.Sprintf("Account [%s] lookup for stream create failed: %v", accName, err) if isMember { @@ -4187,7 +4199,7 @@ func (js *jetStream) processUpdateStreamAssignment(sa *streamAssignment) { js.mu.Unlock() // Need to stop the stream, we can't keep running with an old config. 
- acc, err := s.LookupAccount(accName) + acc, err := s.lookupOrFetchAccount(accName, isMember) if err != nil { return } @@ -4201,9 +4213,14 @@ func (js *jetStream) processUpdateStreamAssignment(sa *streamAssignment) { } js.mu.Unlock() - acc, err := s.LookupAccount(accName) + acc, err := s.lookupOrFetchAccount(accName, isMember) if err != nil { - s.Warnf("Update Stream Account %s, error on lookup: %v", accName, err) + ll := fmt.Sprintf("Update Stream Account %s, error on lookup: %v", accName, err) + if isMember { + s.Warnf(ll) + } else { + s.Debugf(ll) + } return } @@ -4876,7 +4893,7 @@ func (js *jetStream) processConsumerAssignment(ca *consumerAssignment) { // Be conservative by protecting the whole stream, even if just one consumer is unsupported. // This ensures it's safe, even with Interest-based retention where it would otherwise // continue accepting but dropping messages. - acc, err := s.LookupAccount(accName) + acc, err := s.lookupOrFetchAccount(accName, isMember) if err != nil { return } @@ -4890,7 +4907,7 @@ func (js *jetStream) processConsumerAssignment(ca *consumerAssignment) { } js.mu.Unlock() - acc, err := s.LookupAccount(accName) + acc, err := s.lookupOrFetchAccount(accName, isMember) if err != nil { ll := fmt.Sprintf("Account [%s] lookup for consumer create failed: %v", accName, err) if isMember { @@ -5032,7 +5049,7 @@ func (js *jetStream) processClusterCreateConsumer(ca *consumerAssignment, state acc, err := s.LookupAccount(accName) if err != nil { - s.Warnf("JetStream cluster failed to lookup axccount %q: %v", accName, err) + s.Warnf("JetStream cluster failed to lookup account %q: %v", accName, err) return } diff --git a/server/jetstream_cluster_1_test.go b/server/jetstream_cluster_1_test.go index 80257d0f163..795e968a366 100644 --- a/server/jetstream_cluster_1_test.go +++ b/server/jetstream_cluster_1_test.go @@ -7497,6 +7497,44 @@ func TestJetStreamClusterStreamHealthCheckOnlyReportsSkew(t *testing.T) { require_NotEqual(t, node.State(), Closed) } 
+func TestJetStreamClusterStreamHealthCheckStreamCatchup(t *testing.T) { + c := createJetStreamClusterExplicit(t, "R3S", 3) + defer c.shutdown() + + nc, js := jsClientConnect(t, c.randomServer()) + defer nc.Close() + + _, err := js.AddStream(&nats.StreamConfig{ + Name: "TEST", + Subjects: []string{"foo"}, + Replicas: 3, + }) + require_NoError(t, err) + + sl := c.streamLeader(globalAccountName, "TEST") + sjs := sl.getJetStream() + acc := sl.globalAccount() + mset, err := acc.lookupStream("TEST") + require_NoError(t, err) + sjs.mu.Lock() + sa := sjs.streamAssignment(globalAccountName, "TEST") + sjs.mu.Unlock() + + require_NoError(t, sjs.isStreamHealthy(acc, sa)) + + // Check we can report unhealthy. + n := mset.raftNode().(*raft) + n.Lock() + n.commit = 0 + n.Unlock() + require_Error(t, sjs.isStreamHealthy(acc, sa), errors.New("group node unhealthy")) + + // Catching up should have precedence. + mset.setCatchingUp() + require_True(t, mset.isCatchingUp()) + require_Error(t, sjs.isStreamHealthy(acc, sa), errors.New("stream catching up")) +} + func TestJetStreamClusterConsumerHealthCheckMustNotRecreate(t *testing.T) { c := createJetStreamClusterExplicit(t, "R3S", 3) defer c.shutdown() diff --git a/server/jetstream_jwt_test.go b/server/jetstream_jwt_test.go index a365b9c269c..ce2a5ada195 100644 --- a/server/jetstream_jwt_test.go +++ b/server/jetstream_jwt_test.go @@ -2031,3 +2031,62 @@ func TestJetStreamJWTUpdateWithPreExistingStream(t *testing.T) { return nil }) } + +func TestJetStreamAccountResolverNoFetchIfNotMember(t *testing.T) { + _, spub := createKey(t) + sysClaim := jwt.NewAccountClaims(spub) + sysClaim.Name = "SYS" + sysJwt := encodeClaim(t, sysClaim, spub) + kp, _ := nkeys.CreateAccount() + aPub, _ := kp.PublicKey() + + templ := ` + listen: 127.0.0.1:-1 + server_name: %s + jetstream: {max_mem_store: 2GB, max_file_store: 2GB, store_dir: '%s'} + + leaf { + listen: 127.0.0.1:-1 + } + + cluster { + name: %s + listen: 127.0.0.1:%d + routes = [%s] + } +` + + ts := 
httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/" { + w.Write([]byte("ok")) + } else if strings.HasSuffix(r.URL.Path, spub) { + w.Write([]byte(sysJwt)) + } else { + // Simulate some time being spent, but doesn't respond. + time.Sleep(250 * time.Millisecond) + } + })) + defer ts.Close() + + c := createJetStreamClusterWithTemplateAndModHook(t, templ, "R3S", 3, + func(serverName, clusterName, storeDir, conf string) string { + return conf + fmt.Sprintf(` + operator: %s + system_account: %s + resolver: URL("%s")`, ojwt, spub, ts.URL) + }) + defer c.shutdown() + + s := c.leader() + js := s.getJetStream() + ci := &ClientInfo{Cluster: "R3S", Account: aPub} + cfg := &StreamConfig{Name: "TEST", Subjects: []string{"foo"}} + sa := &streamAssignment{Client: ci, Config: cfg} + start := time.Now() + // Simulate some meta operations where this server is not a member. + // The server should not fetch the account from the resolver. + for range 5 { + js.processStreamAssignment(sa) + } + require_LessThan(t, time.Since(start), 100*time.Millisecond) +} diff --git a/server/monitor.go b/server/monitor.go index 1173f4e75d9..87e2e5420ba 100644 --- a/server/monitor.go +++ b/server/monitor.go @@ -2902,6 +2902,7 @@ type JSzOptions struct { Accounts bool `json:"accounts,omitempty"` Streams bool `json:"streams,omitempty"` Consumer bool `json:"consumer,omitempty"` + DirectConsumer bool `json:"direct_consumer,omitempty"` Config bool `json:"config,omitempty"` LeaderOnly bool `json:"leader_only,omitempty"` Offset int `json:"offset,omitempty"` @@ -2950,6 +2951,7 @@ type StreamDetail struct { Config *StreamConfig `json:"config,omitempty"` State StreamState `json:"state,omitempty"` Consumer []*ConsumerInfo `json:"consumer_detail,omitempty"` + DirectConsumer []*ConsumerInfo `json:"direct_consumer_detail,omitempty"` Mirror *StreamSourceInfo `json:"mirror,omitempty"` Sources []*StreamSourceInfo `json:"sources,omitempty"` RaftGroup string 
`json:"stream_raft_group,omitempty"` @@ -3007,7 +3009,7 @@ type JSInfo struct { Total int `json:"total"` } -func (s *Server) accountDetail(jsa *jsAccount, optStreams, optConsumers, optCfg, optRaft, optStreamLeader bool) *AccountDetail { +func (s *Server) accountDetail(jsa *jsAccount, optStreams, optConsumers, optDirectConsumers, optCfg, optRaft, optStreamLeader bool) *AccountDetail { jsa.mu.RLock() acc := jsa.account name := acc.GetName() @@ -3089,6 +3091,18 @@ func (s *Server) accountDetail(jsa *jsAccount, optStreams, optConsumers, optCfg, } } } + if optDirectConsumers { + for _, consumer := range stream.getDirectConsumers() { + cInfo := consumer.info() + if cInfo == nil { + continue + } + if !optCfg { + cInfo.Config = nil + } + sdet.DirectConsumer = append(sdet.DirectConsumer, cInfo) + } + } } detail.Streams = append(detail.Streams, sdet) } @@ -3112,7 +3126,7 @@ func (s *Server) JszAccount(opts *JSzOptions) (*AccountDetail, error) { if !ok { return nil, fmt.Errorf("account %q not jetstream enabled", acc) } - return s.accountDetail(jsa, opts.Streams, opts.Consumer, opts.Config, opts.RaftGroups, opts.StreamLeaderOnly), nil + return s.accountDetail(jsa, opts.Streams, opts.Consumer, opts.DirectConsumer, opts.Config, opts.RaftGroups, opts.StreamLeaderOnly), nil } // helper to get cluster info from node via dummy group @@ -3280,7 +3294,7 @@ func (s *Server) Jsz(opts *JSzOptions) (*JSInfo, error) { jsi.AccountDetails = make([]*AccountDetail, 0, len(accounts)) for _, jsa := range accounts { - detail := s.accountDetail(jsa, opts.Streams, opts.Consumer, opts.Config, opts.RaftGroups, opts.StreamLeaderOnly) + detail := s.accountDetail(jsa, opts.Streams, opts.Consumer, opts.DirectConsumer, opts.Config, opts.RaftGroups, opts.StreamLeaderOnly) jsi.AccountDetails = append(jsi.AccountDetails, detail) } } @@ -3305,6 +3319,10 @@ func (s *Server) HandleJsz(w http.ResponseWriter, r *http.Request) { if err != nil { return } + directConsumers, err := decodeBool(w, r, "direct-consumers") + if 
err != nil { + return + } config, err := decodeBool(w, r, "config") if err != nil { return @@ -3336,6 +3354,7 @@ func (s *Server) HandleJsz(w http.ResponseWriter, r *http.Request) { Accounts: accounts, Streams: streams, Consumer: consumers, + DirectConsumer: directConsumers, Config: config, LeaderOnly: leader, Offset: offset, diff --git a/server/monitor_test.go b/server/monitor_test.go index 6adf9eee47d..f6791ba00f5 100644 --- a/server/monitor_test.go +++ b/server/monitor_test.go @@ -5241,6 +5241,23 @@ func TestMonitorJsz(t *testing.T) { } } }) + t.Run("direct-consumers", func(t *testing.T) { + // It could take time for the sourcing to set up. + checkFor(t, 5*time.Second, 250*time.Millisecond, func() error { + for _, url := range []string{monUrl1, monUrl2} { + info := readJsInfo(url + "?acc=ACC&consumers=true&direct-consumers=true") + if len(info.AccountDetails) != 1 { + t.Fatalf("expected account ACC to be returned by %s but got %v", url, info) + } + if slices.ContainsFunc(info.AccountDetails[0].Streams, func(stream StreamDetail) bool { + return len(stream.DirectConsumer) > 0 + }) { + return nil + } + } + return fmt.Errorf("expected direct consumer info to be present on one of the servers") + }) + }) t.Run("config", func(t *testing.T) { for _, url := range []string{monUrl1, monUrl2} { info := readJsInfo(url + "?acc=ACC&consumers=true&config=true") diff --git a/server/opts.go b/server/opts.go index 24e85f0ff45..3aafd7f970b 100644 --- a/server/opts.go +++ b/server/opts.go @@ -356,6 +356,7 @@ type Options struct { Username string `json:"-"` Password string `json:"-"` ProxyRequired bool `json:"-"` + ProxyProtocol bool `json:"-"` Authorization string `json:"-"` AuthCallout *AuthCallout `json:"-"` PingInterval time.Duration `json:"ping_interval"` @@ -1259,6 +1260,8 @@ func (o *Options) processConfigFileLine(k string, v any, errors *[]error, warnin o.MaxPayload = int32(v.(int64)) case "max_pending": o.MaxPending = v.(int64) + case "proxy_protocol": + o.ProxyProtocol = 
v.(bool) case "max_connections", "max_conn": o.MaxConn = int(v.(int64)) case "max_traced_msg_len": diff --git a/server/raft.go b/server/raft.go index db061f7e94e..2fba5ab3438 100644 --- a/server/raft.go +++ b/server/raft.go @@ -19,7 +19,6 @@ import ( "encoding/binary" "errors" "fmt" - "hash" "math" "math/rand" "net" @@ -153,7 +152,7 @@ type raft struct { state atomic.Int32 // RaftState leaderState atomic.Bool // Is in (complete) leader state. leaderSince atomic.Pointer[time.Time] // How long since becoming leader. - hh hash.Hash64 // Highwayhash, used for snapshots + hh *highwayhash.Digest64 // Highwayhash, used for snapshots snapfile string // Snapshot filename csz int // Cluster size @@ -447,7 +446,7 @@ func (s *Server) initRaftNode(accName string, cfg *RaftConfig, labels pprofLabel // Set up the highwayhash for the snapshots. key := sha256.Sum256([]byte(n.group)) - n.hh, _ = highwayhash.New64(key[:]) + n.hh, _ = highwayhash.NewDigest64(key[:]) // If we have a term and vote file (tav.idx on the filesystem) then read in // what we think the term and vote was. It's possible these are out of date @@ -1225,7 +1224,8 @@ func (n *raft) encodeSnapshot(snap *snapshot) []byte { // Now do the hash for the end. 
n.hh.Reset() n.hh.Write(buf[:wi]) - checksum := n.hh.Sum(nil) + var hb [highwayhash.Size64]byte + checksum := n.hh.Sum(hb[:0]) copy(buf[wi:], checksum) wi += len(checksum) return buf[:wi] @@ -1450,7 +1450,8 @@ func (n *raft) loadLastSnapshot() (*snapshot, error) { lchk := buf[hoff:] n.hh.Reset() n.hh.Write(buf[:hoff]) - if !bytes.Equal(lchk[:], n.hh.Sum(nil)) { + var hb [highwayhash.Size64]byte + if !bytes.Equal(lchk[:], n.hh.Sum(hb[:0])) { n.warn("Snapshot corrupt, checksums did not match") os.Remove(n.snapfile) n.snapfile = _EMPTY_ diff --git a/server/route.go b/server/route.go index 57e5320fa70..b5850ecd354 100644 --- a/server/route.go +++ b/server/route.go @@ -1031,7 +1031,7 @@ func (s *Server) sendAsyncInfoToClients(regCli, wsCli bool) { c.flags.isSet(firstPongSent) { // sendInfo takes care of checking if the connection is still // valid or not, so don't duplicate tests here. - c.enqueueProto(c.generateClientInfoJSON(info)) + c.enqueueProto(c.generateClientInfoJSON(info, true)) } c.mu.Unlock() } diff --git a/server/server.go b/server/server.go index b3b4b7628ae..d8c8a1cb2c8 100644 --- a/server/server.go +++ b/server/server.go @@ -2806,6 +2806,11 @@ func (s *Server) AcceptLoop(clr chan struct{}) { s.Noticef("Listening for client connections on %s", net.JoinHostPort(opts.Host, strconv.Itoa(l.Addr().(*net.TCPAddr).Port))) + // Alert if PROXY protocol is enabled + if opts.ProxyProtocol { + s.Noticef("PROXY protocol enabled for client connections") + } + // Alert of TLS enabled. if opts.TLSConfig != nil { s.Noticef("TLS required for client connections") @@ -3335,8 +3340,11 @@ func (s *Server) createClientEx(conn net.Conn, inProcess bool) *client { } // Decide if we are going to require TLS or not and generate INFO json. + // If we have ProxyProtocol enabled then we won't include the client + // IP in the initial INFO, as that would leak the proxy IP itself. + // In that case we'll send another INFO after the client introduces itself. 
tlsRequired := info.TLSRequired - infoBytes := c.generateClientInfoJSON(info) + infoBytes := c.generateClientInfoJSON(info, !opts.ProxyProtocol) // Send our information, except if TLS and TLSHandshakeFirst is requested. if !tlsFirst { @@ -3407,7 +3415,7 @@ func (s *Server) createClientEx(conn net.Conn, inProcess bool) *client { // different that the current value and regenerate infoBytes. if orgInfoTLSReq != info.TLSRequired { info.TLSRequired = orgInfoTLSReq - infoBytes = c.generateClientInfoJSON(info) + infoBytes = c.generateClientInfoJSON(info, !opts.ProxyProtocol) } c.sendProtoNow(infoBytes) // Set the boolean to false for the rest of the function. @@ -3420,7 +3428,7 @@ func (s *Server) createClientEx(conn net.Conn, inProcess bool) *client { // one the client wants. We'll always allow this for in-process // connections. if !isClosed && !tlsFirst && opts.TLSConfig != nil && (inProcess || opts.AllowNonTLS) { - pre = make([]byte, 4) + pre = make([]byte, 6) // Minimum 6 bytes for proxy proto in next step. c.nc.SetReadDeadline(time.Now().Add(secondsToDuration(opts.TLSTimeout))) n, _ := io.ReadFull(c.nc, pre[:]) c.nc.SetReadDeadline(time.Time{}) @@ -3432,6 +3440,55 @@ func (s *Server) createClientEx(conn net.Conn, inProcess bool) *client { } } + // Check for proxy protocol if enabled. + if !isClosed && !tlsRequired && opts.ProxyProtocol { + if len(pre) == 0 { + // There has been no pre-read yet, do so so we can work out + // if the client is trying to negotiate PROXY. + pre = make([]byte, 6) + c.nc.SetReadDeadline(time.Now().Add(proxyProtoReadTimeout)) + n, _ := io.ReadFull(c.nc, pre) + c.nc.SetReadDeadline(time.Time{}) + pre = pre[:n] + } + conn = &tlsMixConn{conn, bytes.NewBuffer(pre)} + addr, err := readProxyProtoHeader(conn) + if err != nil && err != errProxyProtoUnrecognized { + // err != errProxyProtoUnrecognized implies that we detected a proxy + // protocol header but we failed to parse it, so don't continue. 
+ c.mu.Unlock() + s.Warnf("Error reading PROXY protocol header from %s: %v", conn.RemoteAddr(), err) + c.closeConnection(ProtocolViolation) + return nil + } + // If addr is nil, it was a LOCAL/UNKNOWN command (health check) + // Use the connection as-is + if addr != nil { + c.nc = &proxyConn{ + Conn: conn, + remoteAddr: addr, + } + // These were set already by initClient, override them. + c.host = addr.srcIP.String() + c.port = addr.srcPort + } + // At this point, err is either: + // - nil => we parsed the proxy protocol header successfully + // - errProxyProtoUnrecognized => we didn't detect proxy protocol at all + // We only clear the pre-read if we successfully read the protocol header + // so that the next step doesn't re-read it. Otherwise we have to assume + // that it's a non-proxied connection and we want the pre-read to remain + // for the next step. + if err == nil { + pre = nil + } + // Because we have ProxyProtocol enabled, our earlier INFO message didn't + // include the client_ip. If we need to send it again then we will include + // it, but sending it here immediately can confuse clients who have just + // PING'd. + infoBytes = c.generateClientInfoJSON(info, true) + } + // Check for TLS if !isClosed && tlsRequired { if s.connRateCounter != nil && !s.connRateCounter.allow() { @@ -4716,7 +4773,7 @@ func (s *Server) LDMClientByID(id uint64) error { // sendInfo takes care of checking if the connection is still // valid or not, so don't duplicate tests here. 
c.Debugf("Sending Lame Duck Mode info to client") - c.enqueueProto(c.generateClientInfoJSON(info)) + c.enqueueProto(c.generateClientInfoJSON(info, true)) return nil } else { return errors.New("client does not support Lame Duck Mode or is not ready to receive the notification") diff --git a/server/stream.go b/server/stream.go index b98f171b2d5..658cb90e4ae 100644 --- a/server/stream.go +++ b/server/stream.go @@ -7184,6 +7184,20 @@ func (mset *stream) getPublicConsumers() []*consumer { return obs } +// This returns all consumers that are DIRECT. +func (mset *stream) getDirectConsumers() []*consumer { + mset.clsMu.RLock() + defer mset.clsMu.RUnlock() + + var obs []*consumer + for _, o := range mset.cList { + if o.cfg.Direct { + obs = append(obs, o) + } + } + return obs +} + // 2 minutes plus up to 30s jitter. const ( defaultCheckInterestStateT = 2 * time.Minute diff --git a/server/websocket.go b/server/websocket.go index cc83b465bba..c8d3b6c62f5 100644 --- a/server/websocket.go +++ b/server/websocket.go @@ -1294,7 +1294,7 @@ func (s *Server) createWSClient(conn net.Conn, ws *websocket) *client { } c.initClient() c.Debugf("Client connection created") - c.sendProtoNow(c.generateClientInfoJSON(info)) + c.sendProtoNow(c.generateClientInfoJSON(info, true)) c.mu.Unlock() s.mu.Lock()