diff --git a/wire/message.go b/wire/message.go index a0adfeddba..d0b06741c9 100644 --- a/wire/message.go +++ b/wire/message.go @@ -28,37 +28,38 @@ const MaxMessagePayload = (1024 * 1024 * 32) // 32MB // Commands used in message headers which describe the type of message. const ( - CmdVersion = "version" - CmdVerAck = "verack" - CmdGetAddr = "getaddr" - CmdAddr = "addr" - CmdGetBlocks = "getblocks" - CmdInv = "inv" - CmdGetData = "getdata" - CmdNotFound = "notfound" - CmdBlock = "block" - CmdTx = "tx" - CmdGetHeaders = "getheaders" - CmdHeaders = "headers" - CmdPing = "ping" - CmdPong = "pong" - CmdMemPool = "mempool" - CmdMiningState = "miningstate" - CmdGetMiningState = "getminings" - CmdReject = "reject" - CmdSendHeaders = "sendheaders" - CmdFeeFilter = "feefilter" - CmdGetCFilterV2 = "getcfilterv2" - CmdCFilterV2 = "cfilterv2" - CmdGetInitState = "getinitstate" - CmdInitState = "initstate" - CmdMixPairReq = "mixpairreq" - CmdMixKeyExchange = "mixkeyxchg" - CmdMixCiphertexts = "mixcphrtxt" - CmdMixSlotReserve = "mixslotres" - CmdMixDCNet = "mixdcnet" - CmdMixConfirm = "mixconfirm" - CmdMixSecrets = "mixsecrets" + CmdVersion = "version" + CmdVerAck = "verack" + CmdGetAddr = "getaddr" + CmdAddr = "addr" + CmdGetBlocks = "getblocks" + CmdInv = "inv" + CmdGetData = "getdata" + CmdNotFound = "notfound" + CmdBlock = "block" + CmdTx = "tx" + CmdGetHeaders = "getheaders" + CmdHeaders = "headers" + CmdPing = "ping" + CmdPong = "pong" + CmdMemPool = "mempool" + CmdMiningState = "miningstate" + CmdGetMiningState = "getminings" + CmdReject = "reject" + CmdSendHeaders = "sendheaders" + CmdFeeFilter = "feefilter" + CmdGetCFilterV2 = "getcfilterv2" + CmdCFilterV2 = "cfilterv2" + CmdGetInitState = "getinitstate" + CmdInitState = "initstate" + CmdMixPairReq = "mixpairreq" + CmdMixKeyExchange = "mixkeyxchg" + CmdMixCiphertexts = "mixcphrtxt" + CmdMixSlotReserve = "mixslotres" + CmdMixFactoredPoly = "mixfactpoly" + CmdMixDCNet = "mixdcnet" + CmdMixConfirm = "mixconfirm" + CmdMixSecrets = "mixsecrets" ) const ( @@ -207,6 +208,9 @@ func makeEmptyMessage(command string) (Message, error) { case CmdMixSlotReserve: msg = &MsgMixSlotReserve{} + case CmdMixFactoredPoly: + msg = &MsgMixFactoredPoly{} + case CmdMixDCNet: msg = &MsgMixDCNet{} diff --git a/wire/message_test.go b/wire/message_test.go index 1a56a5ee13..6025fa5b95 100644 --- a/wire/message_test.go +++ b/wire/message_test.go @@ -87,6 +87,7 @@ func TestMessage(t *testing.T) { msgMixKE := NewMsgMixKeyExchange([33]byte{}, [32]byte{}, 1, 1, [33]byte{}, [1218]byte{}, [32]byte{}, []chainhash.Hash{}) msgMixCT := NewMsgMixCiphertexts([33]byte{}, [32]byte{}, 1, [][1047]byte{}, []chainhash.Hash{}) msgMixSR := NewMsgMixSlotReserve([33]byte{}, [32]byte{}, 1, [][][]byte{{{}}}, []chainhash.Hash{}) + msgMixFP := NewMsgMixFactoredPoly([33]byte{}, [32]byte{}, 1, [][]byte{}, []chainhash.Hash{}) msgMixDC := NewMsgMixDCNet([33]byte{}, [32]byte{}, 1, []MixVect{make(MixVect, 1)}, []chainhash.Hash{}) msgMixCM := NewMsgMixConfirm([33]byte{}, [32]byte{}, 1, NewMsgTx(), []chainhash.Hash{}) msgMixRS := NewMsgMixSecrets([33]byte{}, [32]byte{}, 1, [32]byte{}, [][]byte{}, MixVect{}) @@ -126,6 +127,7 @@ func TestMessage(t *testing.T) { {msgMixKE, msgMixKE, pver, MainNet, 1449}, {msgMixCT, msgMixCT, pver, MainNet, 158}, {msgMixSR, msgMixSR, pver, MainNet, 161}, + {msgMixFP, msgMixFP, pver, MainNet, 159}, {msgMixDC, msgMixDC, pver, MainNet, 181}, {msgMixCM, msgMixCM, pver, MainNet, 173}, {msgMixRS, msgMixRS, pver, MainNet, 192}, diff --git a/wire/msgmixdcnet_test.go b/wire/msgmixdcnet_test.go index 1963517be2..5885ffa145 100644 --- a/wire/msgmixdcnet_test.go +++ b/wire/msgmixdcnet_test.go @@ -85,7 +85,7 @@ func TestMsgMixDCNetWire(t *testing.T) { expected = append(expected, repeat(0x91, 20)...) expected = append(expected, repeat(0x92, 20)...) expected = append(expected, repeat(0x93, 20)...) - // Four seen DCs (repeating 32 bytes of 0x94, 0x95, 0x96, 0x97) + // Four seen SRs (repeating 32 bytes of 0x94, 0x95, 0x96, 0x97) expected = append(expected, 0x04) expected = append(expected, repeat(0x94, 32)...) expected = append(expected, repeat(0x95, 32)...) diff --git a/wire/msgmixfactoredpoly.go b/wire/msgmixfactoredpoly.go new file mode 100644 index 0000000000..2547ba50ec --- /dev/null +++ b/wire/msgmixfactoredpoly.go @@ -0,0 +1,259 @@ +// Copyright (c) 2024 The Decred developers +// Use of this source code is governed by an ISC +// license that can be found in the LICENSE file. + +package wire + +import ( + "fmt" + "hash" + "io" + + "github.com/decred/dcrd/chaincfg/chainhash" +) + +// MsgMixFactoredPoly encodes the solution of the factored slot reservation +// polynomial. +type MsgMixFactoredPoly struct { + Signature [64]byte + Identity [33]byte + SessionID [32]byte + Run uint32 + Roots [][]byte + SeenSlotReserves []chainhash.Hash + + // hash records the hash of the message. It is a member of the + // message for convenience and performance, but is never automatically + // set during creation or deserialization. + hash chainhash.Hash +} + +// BtcDecode decodes r using the Decred protocol encoding into the receiver. +// This is part of the Message interface implementation. +func (msg *MsgMixFactoredPoly) BtcDecode(r io.Reader, pver uint32) error { + const op = "MsgMixFactoredPoly.BtcDecode" + if pver < MixVersion { + msg := fmt.Sprintf("%s message invalid for protocol version %d", + msg.Command(), pver) + return messageError(op, ErrMsgInvalidForPVer, msg) + } + + err := readElements(r, &msg.Signature, &msg.Identity, &msg.SessionID, + &msg.Run) + if err != nil { + return err + } + + count, err := ReadVarInt(r, pver) + if err != nil { + return err + } + if count > MaxMixMcount { + msg := fmt.Sprintf("too many roots in message [count %v, max %v]", + count, MaxMixMcount) + return messageError(op, ErrInvalidMsg, msg) + } + + roots := make([][]byte, count) + for i := range roots { + root, err := ReadVarBytes(r, pver, MaxMixFieldValLen, "MixFactoredPoly.Roots") + if err != nil { + return err + } + roots[i] = root + } + msg.Roots = roots + + count, err = ReadVarInt(r, pver) + if err != nil { + return err + } + if count > MaxMixPeers { + msg := fmt.Sprintf("too many previous referenced messages [count %v, max %v]", + count, MaxMixPeers) + return messageError(op, ErrTooManyPrevMixMsgs, msg) + } + + seen := make([]chainhash.Hash, count) + for i := range seen { + err := readElement(r, &seen[i]) + if err != nil { + return err + } + } + msg.SeenSlotReserves = seen + + return nil +} + +// BtcEncode encodes the receiver to w using the Decred protocol encoding. +// This is part of the Message interface implementation. +func (msg *MsgMixFactoredPoly) BtcEncode(w io.Writer, pver uint32) error { + const op = "MsgMixFactoredPoly.BtcEncode" + if pver < MixVersion { + msg := fmt.Sprintf("%s message invalid for protocol version %d", + msg.Command(), pver) + return messageError(op, ErrMsgInvalidForPVer, msg) + } + + err := writeElement(w, &msg.Signature) + if err != nil { + return err + } + + err = msg.writeMessageNoSignature(op, w, pver) + if err != nil { + return err + } + + return nil +} + +// Hash returns the message hash calculated by WriteHash. +// +// Hash returns an invalid or zero hash if WriteHash has not been called yet. +// +// This method is not safe while concurrently calling WriteHash. +func (msg *MsgMixFactoredPoly) Hash() chainhash.Hash { + return msg.hash +} + +// WriteHash serializes the message to a hasher and records the sum in the +// message's Hash field. +// +// The hasher's Size() must equal chainhash.HashSize, or this method will +// panic. This method is designed to work only with hashers returned by +// blake256.New. +func (msg *MsgMixFactoredPoly) WriteHash(h hash.Hash) { + h.Reset() + writeElement(h, &msg.Signature) + msg.writeMessageNoSignature("", h, MixVersion) + sum := h.Sum(msg.hash[:0]) + if len(sum) != len(msg.hash) { + s := fmt.Sprintf("hasher type %T has invalid Size() for chainhash.Hash", h) + panic(s) + } +} + +// writeMessageNoSignature serializes all elements of the message except for +// the signature. This allows code reuse between message serialization, and +// signing and verifying these message contents. +// +// If w implements hash.Hash, no errors will be returned for invalid message +// construction. +func (msg *MsgMixFactoredPoly) writeMessageNoSignature(op string, w io.Writer, pver uint32) error { + _, hashing := w.(hash.Hash) + + count := len(msg.Roots) + if !hashing && count > MaxMixMcount { + msg := fmt.Sprintf("too many solutions to factored polynomial [count %v, max %v]", + count, MaxMixMcount) + return messageError(op, ErrInvalidMsg, msg) + } + for _, root := range msg.Roots { + if !hashing && len(root) > MaxMixFieldValLen { + msg := "root exceeds bytes necessary to represent number in field" + return messageError(op, ErrInvalidMsg, msg) + } + } + srcount := len(msg.SeenSlotReserves) + if !hashing && srcount > MaxMixPeers { + msg := fmt.Sprintf("too many previous referenced messages [count %v, max %v]", + srcount, MaxMixPeers) + return messageError(op, ErrTooManyPrevMixMsgs, msg) + } + + err := writeElements(w, &msg.Identity, &msg.SessionID, msg.Run) + if err != nil { + return err + } + + err = WriteVarInt(w, pver, uint64(count)) + if err != nil { + return err + } + for _, root := range msg.Roots { + err := WriteVarBytes(w, pver, root) + if err != nil { + return err + } + } + + err = WriteVarInt(w, pver, uint64(srcount)) + if err != nil { + return err + } + for i := range msg.SeenSlotReserves { + err = writeElement(w, &msg.SeenSlotReserves[i]) + if err != nil { + return err + } + } + + return nil +} + +// WriteSignedData writes a tag identifying the message data, followed by all +// message fields excluding the signature. This is the data committed to when +// the message is signed. +func (msg *MsgMixFactoredPoly) WriteSignedData(h hash.Hash) { + WriteVarString(h, MixVersion, CmdMixFactoredPoly+"-sig") + msg.writeMessageNoSignature("", h, MixVersion) +} + +// Command returns the protocol command string for the message. This is part +// of the Message interface implementation. +func (msg *MsgMixFactoredPoly) Command() string { + return CmdMixFactoredPoly +} + +// MaxPayloadLength returns the maximum length the payload can be for the +// receiver. This is part of the Message interface implementation. +func (msg *MsgMixFactoredPoly) MaxPayloadLength(pver uint32) uint32 { + if pver < MixVersion { + return 0 + } + + // See tests for this calculation. + return 49291 +} + +// Pub returns the message sender's public key identity. +func (msg *MsgMixFactoredPoly) Pub() []byte { + return msg.Identity[:] +} + +// Sig returns the message signature. +func (msg *MsgMixFactoredPoly) Sig() []byte { + return msg.Signature[:] +} + +// PrevMsgs returns the previous SR messages seen by the peer. +func (msg *MsgMixFactoredPoly) PrevMsgs() []chainhash.Hash { + return msg.SeenSlotReserves +} + +// Sid returns the session ID. +func (msg *MsgMixFactoredPoly) Sid() []byte { + return msg.SessionID[:] +} + +// GetRun returns the run number. +func (msg *MsgMixFactoredPoly) GetRun() uint32 { + return msg.Run +} + +// NewMsgMixFactoredPoly returns a new mixpairreq message that conforms to the +// Message interface using the passed parameters and defaults for the +// remaining fields. +func NewMsgMixFactoredPoly(identity [33]byte, sid [32]byte, run uint32, + roots [][]byte, seenSlotReserves []chainhash.Hash) *MsgMixFactoredPoly { + + return &MsgMixFactoredPoly{ + Identity: identity, + SessionID: sid, + Run: run, + Roots: roots, + SeenSlotReserves: seenSlotReserves, + } +} diff --git a/wire/msgmixfactoredpoly_test.go b/wire/msgmixfactoredpoly_test.go new file mode 100644 index 0000000000..56c201feaa --- /dev/null +++ b/wire/msgmixfactoredpoly_test.go @@ -0,0 +1,194 @@ +// Copyright (c) 2024 The Decred developers +// Use of this source code is governed by an ISC +// license that can be found in the LICENSE file. + +package wire + +import ( + "bytes" + "errors" + "fmt" + "reflect" + "testing" + + "github.com/davecgh/go-spew/spew" + "github.com/decred/dcrd/chaincfg/chainhash" +) + +func newTestMixFactoredPoly() *MsgMixFactoredPoly { + // Use easily-distinguishable fields. + sig := *(*[64]byte)(repeat(0x80, 64)) + id := *(*[33]byte)(repeat(0x81, 33)) + sid := *(*[32]byte)(repeat(0x82, 32)) + + const run = uint32(0x83838383) + + const mcount = 4 + roots := make([][]byte, mcount) + // Add 4 roots ranging from repeating bytes of 0x84 through 0x87. + b := byte(0x84) + for i := range roots { + roots[i] = repeat(b, 32) + b++ + } + + seenSRs := make([]chainhash.Hash, 4) + for b := byte(0x88); b < 0x8C; b++ { + copy(seenSRs[b-0x88][:], repeat(b, 32)) + } + + fp := NewMsgMixFactoredPoly(id, sid, run, roots, seenSRs) + fp.Signature = sig + + return fp +} + +func TestMsgMixFactoredPoly(t *testing.T) { + pver := MixVersion + + fp := newTestMixFactoredPoly() + + buf := new(bytes.Buffer) + err := fp.BtcEncode(buf, pver) + if err != nil { + t.Fatal(err) + } + + expected := make([]byte, 0, buf.Len()) + expected = append(expected, repeat(0x80, 64)...) // Signature + expected = append(expected, repeat(0x81, 33)...) // Identity + expected = append(expected, repeat(0x82, 32)...) // Session ID + expected = append(expected, repeat(0x83, 4)...) // Run + // Four roots (repeating 32 bytes from 0x84 through 0x87) + expected = append(expected, 0x04) + expected = append(expected, 0x20) + expected = append(expected, repeat(0x84, 32)...) + expected = append(expected, 0x20) + expected = append(expected, repeat(0x85, 32)...) + expected = append(expected, 0x20) + expected = append(expected, repeat(0x86, 32)...) + expected = append(expected, 0x20) + expected = append(expected, repeat(0x87, 32)...) + // Four seen SRs (repeating 32 bytes of 0x88, 0x89, 0x8A, 0x8B) + expected = append(expected, 0x04) + expected = append(expected, repeat(0x88, 32)...) + expected = append(expected, repeat(0x89, 32)...) + expected = append(expected, repeat(0x8A, 32)...) + expected = append(expected, repeat(0x8B, 32)...) + + expectedSerializationEqual(t, buf.Bytes(), expected) + + decodedFP := new(MsgMixFactoredPoly) + err = decodedFP.BtcDecode(bytes.NewReader(buf.Bytes()), pver) + if err != nil { + t.Fatal(err) + } + + if !reflect.DeepEqual(fp, decodedFP) { + t.Errorf("BtcDecode got: %s want: %s", + spew.Sdump(decodedFP), spew.Sdump(fp)) + } +} + +func TestMsgMixFactoredPolyCrossProtocol(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + encodeVersion uint32 + decodeVersion uint32 + err error + remainingBytes int + }{{ + name: "Latest->MixVersion", + encodeVersion: ProtocolVersion, + decodeVersion: MixVersion, + }, { + name: "Latest->MixVersion-1", + encodeVersion: ProtocolVersion, + decodeVersion: MixVersion - 1, + err: ErrMsgInvalidForPVer, + }, { + name: "MixVersion->Latest", + encodeVersion: MixVersion, + decodeVersion: ProtocolVersion, + }} + + for _, tc := range tests { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + if tc.err != nil && tc.remainingBytes != 0 { + t.Errorf("invalid testcase: non-zero remaining bytes " + + "expects no decoding error") + } + + msg := newTestMixFactoredPoly() + + buf := new(bytes.Buffer) + err := msg.BtcEncode(buf, tc.encodeVersion) + if err != nil { + t.Fatalf("encode failed: %v", err) + } + + msg = new(MsgMixFactoredPoly) + err = msg.BtcDecode(buf, tc.decodeVersion) + if !errors.Is(err, tc.err) { + t.Errorf("decode failed; want %v, got %v", tc.err, err) + } + if err == nil && buf.Len() != tc.remainingBytes { + t.Errorf("buffer contains unexpected remaining bytes "+ + "from encoded message: want %v bytes, got %v (hex: %[2]x)", + buf.Len(), buf.Bytes()) + } + }) + } +} + +// TestMsgMixFactoredPolyMaxPayloadLength tests the results returned by +// [MsgMixFactoredPoly.MaxPayloadLength] by calculating the maximum payload length. +func TestMsgMixFactoredPolyMaxPayloadLength(t *testing.T) { + var fp *MsgMixFactoredPoly + + // Test all protocol versions before MixVersion + for pver := uint32(0); pver < MixVersion; pver++ { + t.Run(fmt.Sprintf("pver=%d", pver), func(t *testing.T) { + got := fp.MaxPayloadLength(pver) + if got != 0 { + t.Errorf("got %d, expected %d", got, 0) + } + }) + } + + var expectedLen uint32 = 64 + // Signature + 33 + // Identity + 32 + // Session ID + 4 + // Run + uint32(VarIntSerializeSize(MaxMixMcount)) + // Root and exponent count + MaxMixMcount*MaxMixFieldValLen + // Roots + uint32(VarIntSerializeSize(MaxMixPeers)) + // Slot reserve count + 32*MaxMixPeers // Slot reserve hashes + + tests := []struct { + name string + pver uint32 + len uint32 + }{{ + name: "MixVersion", + pver: MixVersion, + len: expectedLen, + }, { + name: "ProtocolVersion", + pver: ProtocolVersion, + len: expectedLen, + }} + for _, tc := range tests { + t.Run(fmt.Sprintf("pver=%s", tc.name), func(t *testing.T) { + got := fp.MaxPayloadLength(tc.pver) + if got != tc.len { + t.Errorf("got %d, expected %d", got, tc.len) + } + }) + } +}