Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 48 additions & 0 deletions device/noise-protocol.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
package device

import (
"encoding/binary"
"errors"
"fmt"
"sync"
Expand Down Expand Up @@ -115,6 +116,53 @@ type MessageCookieReply struct {
Cookie [blake2s.Size128 + poly1305.TagSize]byte
}

var errMessageTooShort = errors.New("message too short")

func (msg *MessageInitiation) unmarshal(b []byte) error {
if len(b) < MessageInitiationSize {
return errMessageTooShort
}
Comment on lines +121 to +124
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe name it reset.

Also the size of the packet is checked earlier

case MessageInitiationType:
if len(packet) != MessageInitiationSize {
continue
}
case MessageResponseType:
if len(packet) != MessageResponseSize {
continue
}
case MessageCookieReplyType:
if len(packet) != MessageCookieReplySize {
continue
}

so this check could be removed as well as error result check at the callsite.


msg.Type = binary.LittleEndian.Uint32(b)
msg.Sender = binary.LittleEndian.Uint32(b[4:])
copy(msg.Ephemeral[:], b[8:])
copy(msg.Static[:], b[8+len(msg.Ephemeral):])
copy(msg.Timestamp[:], b[8+len(msg.Ephemeral)+len(msg.Static):])
copy(msg.MAC1[:], b[8+len(msg.Ephemeral)+len(msg.Static)+len(msg.Timestamp):])
copy(msg.MAC2[:], b[8+len(msg.Ephemeral)+len(msg.Static)+len(msg.Timestamp)+len(msg.MAC1):])

return nil
}

func (msg *MessageResponse) unmarshal(b []byte) error {
if len(b) < MessageResponseSize {
return errMessageTooShort
}

msg.Type = binary.LittleEndian.Uint32(b)
msg.Sender = binary.LittleEndian.Uint32(b[4:])
msg.Receiver = binary.LittleEndian.Uint32(b[8:])
copy(msg.Ephemeral[:], b[12:])
copy(msg.Empty[:], b[12+len(msg.Ephemeral):])
copy(msg.MAC1[:], b[12+len(msg.Ephemeral)+len(msg.Empty):])
copy(msg.MAC2[:], b[12+len(msg.Ephemeral)+len(msg.Empty)+len(msg.MAC1):])

return nil
}

func (msg *MessageCookieReply) unmarshal(b []byte) error {
if len(b) < MessageCookieReplySize {
return errMessageTooShort
}

msg.Type = binary.LittleEndian.Uint32(b)
msg.Receiver = binary.LittleEndian.Uint32(b[4:])
copy(msg.Nonce[:], b[8:])
copy(msg.Cookie[:], b[8+len(msg.Nonce):])

return nil
}

type Handshake struct {
state handshakeState
mutex sync.RWMutex
Expand Down
33 changes: 33 additions & 0 deletions device/noise-protocol_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
package device

import (
"bytes"
"encoding/binary"
"testing"
)

var msgSink MessageInitiation

func BenchmarkMessageInitiationUnmarshal(b *testing.B) {
packet := make([]byte, MessageInitiationSize)
reader := bytes.NewReader(packet)
err := binary.Read(reader, binary.LittleEndian, &msgSink)
if err != nil {
b.Fatal(err)
}

b.Run("binary.Read", func(b *testing.B) {
b.ReportAllocs()
for range b.N {
reader := bytes.NewReader(packet)
_ = binary.Read(reader, binary.LittleEndian, &msgSink)
}
})

b.Run("unmarshal", func(b *testing.B) {
b.ReportAllocs()
for range b.N {
_ = msgSink.unmarshal(packet)
}
})
}
10 changes: 3 additions & 7 deletions device/receive.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
package device

import (
"bytes"
"encoding/binary"
"errors"
"net"
Expand Down Expand Up @@ -287,8 +286,7 @@ func (device *Device) RoutineHandshake(id int) {
// unmarshal packet

var reply MessageCookieReply
reader := bytes.NewReader(elem.packet)
Copy link
Contributor Author

@AlexanderYastrebov AlexanderYastrebov Jan 2, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here byte reader can be replaced with https://pkg.go.dev/encoding/binary#Decode added in 1.23 golang/go#60023 (comment)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you intend to rework this around binary.Decode instead of this commit, let me know. I'm all for reducing allocations.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's just a note, nothing (except zero-copy) can beat copy.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

$ go test ./device/ -c
$ go tool objdump -S -s BenchmarkMessageInitiationUnmarshal.func2 device.test
TEXT golang.zx2c4.com/wireguard/device.BenchmarkMessageInitiationUnmarshal.func2(SB) /home/ayastrebov/src/github.com/WireGuard/wireguard-go/device/noise-protocol_test.go
        b.Run("unmarshal", func(b *testing.B) {
  0x5adb60              493b6610                CMPQ SP, 0x10(R14)
  0x5adb64              0f8638010000            JBE 0x5adca2
  0x5adb6a              55                      PUSHQ BP
  0x5adb6b              4889e5                  MOVQ SP, BP
  0x5adb6e              4883ec30                SUBQ $0x30, SP
  0x5adb72              488b5a08                MOVQ 0x8(DX), BX
  0x5adb76              48895c2428              MOVQ BX, 0x28(SP)
                b.ReportAllocs()
  0x5adb7b              90                      NOPL
        b.Run("unmarshal", func(b *testing.B) {
  0x5adb7c              488b5210                MOVQ 0x10(DX), DX
  0x5adb80              4889542420              MOVQ DX, 0x20(SP)
        b.showAllocResult = true
  0x5adb85              c6800202000001          MOVB $0x1, 0x202(AX)
                for range b.N {
  0x5adb8c              488bb0c0010000          MOVQ 0x1c0(AX), SI
  0x5adb93              eb0b                    JMP 0x5adba0
  0x5adb95              48ffce                  DECQ SI
  0x5adb98              0f1f840000000000        NOPL 0(AX)(AX*1)
  0x5adba0              4885f6                  TESTQ SI, SI
  0x5adba3              0f8ef3000000            JLE 0x5adc9c
                        _ = msgSink.unmarshal(packet)
  0x5adba9              90                      NOPL
        b.Run("unmarshal", func(b *testing.B) {
  0x5adbaa              4881fa94000000          CMPQ DX, $0x94
        if len(b) < MessageInitiationSize {
  0x5adbb1              7ce2                    JL 0x5adb95
                for range b.N {
  0x5adbb3              4889742418              MOVQ SI, 0x18(SP)
        copy(msg.Ephemeral[:], b[8:])
  0x5adbb8              488d7b08                LEAQ 0x8(BX), DI
        return uint32(b[0]) | uint32(b[1])<<8 | uint32(b[2])<<16 | uint32(b[3])<<24
  0x5adbbc              448b03                  MOVL 0(BX), R8
        msg.Type = binary.LittleEndian.Uint32(b)
  0x5adbbf              4489059aec2700          MOVL R8, golang.zx2c4.com/wireguard/device.msgSink(SB)
        return uint32(b[0]) | uint32(b[1])<<8 | uint32(b[2])<<16 | uint32(b[3])<<24
  0x5adbc6              448b4304                MOVL 0x4(BX), R8
        msg.Sender = binary.LittleEndian.Uint32(b[4:])
  0x5adbca              44890593ec2700          MOVL R8, golang.zx2c4.com/wireguard/device.msgSink+4(SB)
        copy(msg.Ephemeral[:], b[8:])
  0x5adbd1              488d0590ec2700          LEAQ golang.zx2c4.com/wireguard/device.msgSink+8(SB), AX
  0x5adbd8              4839f8                  CMPQ AX, DI
  0x5adbdb              741c                    JE 0x5adbf9
  0x5adbdd              4889fb                  MOVQ DI, BX
  0x5adbe0              b920000000              MOVL $0x20, CX
  0x5adbe5              e856f8ecff              CALL runtime.memmove(SB)
        b.Run("unmarshal", func(b *testing.B) {
  0x5adbea              488b542420              MOVQ 0x20(SP), DX
        copy(msg.Static[:], b[8+len(msg.Ephemeral):])
  0x5adbef              488b5c2428              MOVQ 0x28(SP), BX
                for range b.N {
  0x5adbf4              488b742418              MOVQ 0x18(SP), SI
        copy(msg.Static[:], b[8+len(msg.Ephemeral):])
  0x5adbf9              488d7b28                LEAQ 0x28(BX), DI
  0x5adbfd              488d0584ec2700          LEAQ golang.zx2c4.com/wireguard/device.msgSink+40(SB), AX
  0x5adc04              4839f8                  CMPQ AX, DI
  0x5adc07              741c                    JE 0x5adc25
  0x5adc09              4889fb                  MOVQ DI, BX
  0x5adc0c              b930000000              MOVL $0x30, CX
  0x5adc11              e82af8ecff              CALL runtime.memmove(SB)
        b.Run("unmarshal", func(b *testing.B) {
  0x5adc16              488b542420              MOVQ 0x20(SP), DX
        copy(msg.Timestamp[:], b[8+len(msg.Ephemeral)+len(msg.Static):])
  0x5adc1b              488b5c2428              MOVQ 0x28(SP), BX
                for range b.N {
  0x5adc20              488b742418              MOVQ 0x18(SP), SI
        copy(msg.Timestamp[:], b[8+len(msg.Ephemeral)+len(msg.Static):])
  0x5adc25              488d7b58                LEAQ 0x58(BX), DI
  0x5adc29              488d0588ec2700          LEAQ golang.zx2c4.com/wireguard/device.msgSink+88(SB), AX
  0x5adc30              4839f8                  CMPQ AX, DI
  0x5adc33              741f                    JE 0x5adc54
  0x5adc35              4889fb                  MOVQ DI, BX
  0x5adc38              b91c000000              MOVL $0x1c, CX
  0x5adc3d              0f1f00                  NOPL 0(AX)
  0x5adc40              e8fbf7ecff              CALL runtime.memmove(SB)
        b.Run("unmarshal", func(b *testing.B) {
  0x5adc45              488b542420              MOVQ 0x20(SP), DX
        copy(msg.MAC1[:], b[8+len(msg.Ephemeral)+len(msg.Static)+len(msg.Timestamp):])
  0x5adc4a              488b5c2428              MOVQ 0x28(SP), BX
                for range b.N {
  0x5adc4f              488b742418              MOVQ 0x18(SP), SI
        copy(msg.MAC1[:], b[8+len(msg.Ephemeral)+len(msg.Static)+len(msg.Timestamp):])
  0x5adc54              488d7b74                LEAQ 0x74(BX), DI
  0x5adc58              4c8d0575ec2700          LEAQ golang.zx2c4.com/wireguard/device.msgSink+116(SB), R8
  0x5adc5f              90                      NOPL
  0x5adc60              4939f8                  CMPQ R8, DI
  0x5adc63              740b                    JE 0x5adc70
  0x5adc65              0f104374                MOVUPS 0x74(BX), X0
  0x5adc69              0f110564ec2700          MOVUPS X0, golang.zx2c4.com/wireguard/device.msgSink+116(SB)
        copy(msg.MAC2[:], b[8+len(msg.Ephemeral)+len(msg.Static)+len(msg.Timestamp)+len(msg.MAC1):])
  0x5adc70              488dbb84000000          LEAQ 0x84(BX), DI
  0x5adc77              4c8d0566ec2700          LEAQ golang.zx2c4.com/wireguard/device.msgSink+132(SB), R8
  0x5adc7e              6690                    NOPW
  0x5adc80              4939f8                  CMPQ R8, DI
  0x5adc83              0f840cffffff            JE 0x5adb95
  0x5adc89              0f108384000000          MOVUPS 0x84(BX), X0
  0x5adc90              0f11054dec2700          MOVUPS X0, golang.zx2c4.com/wireguard/device.msgSink+132(SB)
  0x5adc97              e9f9feffff              JMP 0x5adb95
        })
  0x5adc9c              4883c430                ADDQ $0x30, SP
  0x5adca0              5d                      POPQ BP
  0x5adca1              c3                      RET
        b.Run("unmarshal", func(b *testing.B) {
  0x5adca2              4889442408              MOVQ AX, 0x8(SP)
  0x5adca7              e894caecff              CALL runtime.morestack.abi0(SB)
  0x5adcac              488b442408              MOVQ 0x8(SP), AX
  0x5adcb1              e9aafeffff              JMP golang.zx2c4.com/wireguard/device.BenchmarkMessageInitiationUnmarshal.func2(SB)

err := binary.Read(reader, binary.LittleEndian, &reply)
err := reply.unmarshal(elem.packet)
if err != nil {
device.log.Verbosef("Failed to decode cookie reply")
goto skip
Expand Down Expand Up @@ -353,8 +351,7 @@ func (device *Device) RoutineHandshake(id int) {
// unmarshal

var msg MessageInitiation
reader := bytes.NewReader(elem.packet)
err := binary.Read(reader, binary.LittleEndian, &msg)
err := msg.unmarshal(elem.packet)
if err != nil {
device.log.Errorf("Failed to decode initiation message")
goto skip
Expand Down Expand Up @@ -386,8 +383,7 @@ func (device *Device) RoutineHandshake(id int) {
// unmarshal

var msg MessageResponse
reader := bytes.NewReader(elem.packet)
err := binary.Read(reader, binary.LittleEndian, &msg)
err := msg.unmarshal(elem.packet)
if err != nil {
device.log.Errorf("Failed to decode response message")
goto skip
Expand Down