Skip to content

Commit

Permalink
crypto/goolm: reorganize pickle code
Browse files Browse the repository at this point in the history
Signed-off-by: Sumner Evans <[email protected]>
  • Loading branch information
sumnerevans committed Oct 27, 2024
1 parent 6d5b85a commit 67645d4
Show file tree
Hide file tree
Showing 15 changed files with 225 additions and 242 deletions.
10 changes: 4 additions & 6 deletions crypto/goolm/account/account.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,9 @@ import (

"maunium.net/go/mautrix/id"

"maunium.net/go/mautrix/crypto/goolm/cipher"
"maunium.net/go/mautrix/crypto/goolm/crypto"
"maunium.net/go/mautrix/crypto/goolm/libolmpickle"
"maunium.net/go/mautrix/crypto/goolm/session"
"maunium.net/go/mautrix/crypto/goolm/utilities"
"maunium.net/go/mautrix/crypto/olm"
)

Expand Down Expand Up @@ -76,12 +74,12 @@ func NewAccount() (*Account, error) {

// PickleAsJSON returns an Account as a base64 string encrypted using the supplied key. The unencrypted representation of the Account is in JSON format.
func (a *Account) PickleAsJSON(key []byte) ([]byte, error) {
return utilities.PickleAsJSON(a, accountPickleVersionJSON, key)
return libolmpickle.PickleAsJSON(a, accountPickleVersionJSON, key)
}

// UnpickleAsJSON updates an Account by a base64 encrypted string using the supplied key. The unencrypted representation has to be in JSON format.
func (a *Account) UnpickleAsJSON(pickled, key []byte) error {
return utilities.UnpickleAsJSON(a, pickled, key, accountPickleVersionJSON)
return libolmpickle.UnpickleAsJSON(a, pickled, key, accountPickleVersionJSON)
}

// IdentityKeysJSON returns the public parts of the identity keys for the Account in a JSON string.
Expand Down Expand Up @@ -322,7 +320,7 @@ func (a *Account) ForgetOldFallbackKey() {
// Unpickle decodes the base64 encoded string and decrypts the result with the key.
// The decrypted value is then passed to UnpickleLibOlm.
func (a *Account) Unpickle(pickled, key []byte) error {
decrypted, err := cipher.Unpickle(key, pickled)
decrypted, err := libolmpickle.Unpickle(key, pickled)
if err != nil {
return err
}
Expand Down Expand Up @@ -410,7 +408,7 @@ func (a *Account) Pickle(key []byte) ([]byte, error) {
if len(key) == 0 {
return nil, olm.ErrNoKeyProvided
}
return cipher.Pickle(key, a.PickleLibOlm())
return libolmpickle.Pickle(key, a.PickleLibOlm())
}

// PickleLibOlm pickles the [Account] and returns the raw bytes.
Expand Down
51 changes: 0 additions & 51 deletions crypto/goolm/cipher/pickle.go

This file was deleted.

28 changes: 0 additions & 28 deletions crypto/goolm/cipher/pickle_test.go

This file was deleted.

40 changes: 40 additions & 0 deletions crypto/goolm/libolmpickle/encoder.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
package libolmpickle

import (
"bytes"
"encoding/binary"

"go.mau.fi/util/exerrors"
)

const (
PickleBoolLength = 1
PickleUInt8Length = 1
PickleUInt32Length = 4
)

type Encoder struct {
bytes.Buffer
}

func NewEncoder() *Encoder { return &Encoder{} }

func (p *Encoder) WriteUInt8(value uint8) {
exerrors.PanicIfNotNil(p.WriteByte(value))
}

func (p *Encoder) WriteBool(value bool) {
if value {
exerrors.PanicIfNotNil(p.WriteByte(0x01))
} else {
exerrors.PanicIfNotNil(p.WriteByte(0x00))
}
}

func (p *Encoder) WriteEmptyBytes(count int) {
exerrors.Must(p.Write(make([]byte, count)))
}

func (p *Encoder) WriteUInt32(value uint32) {
exerrors.PanicIfNotNil(binary.Write(&p.Buffer, binary.BigEndian, value))
}
99 changes: 99 additions & 0 deletions crypto/goolm/libolmpickle/encoder_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
package libolmpickle_test

import (
"testing"

"github.com/stretchr/testify/assert"

"maunium.net/go/mautrix/crypto/goolm/libolmpickle"
)

func TestEncoder(t *testing.T) {
var encoder libolmpickle.Encoder
encoder.WriteUInt32(4)
encoder.WriteUInt8(8)
encoder.WriteBool(false)
encoder.WriteEmptyBytes(10)
encoder.WriteBool(true)
encoder.Write([]byte("test"))
encoder.WriteUInt32(420_000)
assert.Equal(t, []byte{
0x00, 0x00, 0x00, 0x04, // 4
0x08, // 8
0x00, // false
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // ten empty bytes
0x01, //true
0x74, 0x65, 0x73, 0x74, // "test" (ASCII)
0x00, 0x06, 0x68, 0xa0, // 420,000
}, encoder.Bytes())
}

func TestPickleUInt32(t *testing.T) {
values := []uint32{
0xffffffff,
0x00ff00ff,
0xf0000000,
0xf00f0000,
}
expected := [][]byte{
{0xff, 0xff, 0xff, 0xff},
{0x00, 0xff, 0x00, 0xff},
{0xf0, 0x00, 0x00, 0x00},
{0xf0, 0x0f, 0x00, 0x00},
}
for i, value := range values {
var encoder libolmpickle.Encoder
encoder.WriteUInt32(value)
assert.Equal(t, expected[i], encoder.Bytes())
}
}

func TestPickleBool(t *testing.T) {
values := []bool{
true,
false,
}
expected := [][]byte{
{0x01},
{0x00},
}
for i, value := range values {
var encoder libolmpickle.Encoder
encoder.WriteBool(value)
assert.Equal(t, expected[i], encoder.Bytes())
}
}

func TestPickleUInt8(t *testing.T) {
values := []uint8{
0xff,
0x1a,
}
expected := [][]byte{
{0xff},
{0x1a},
}
for i, value := range values {
var encoder libolmpickle.Encoder
encoder.WriteUInt8(value)
assert.Equal(t, expected[i], encoder.Bytes())
}
}

func TestPickleBytes(t *testing.T) {
values := [][]byte{
{0xff, 0xff, 0xff, 0xff},
{0x00, 0xff, 0x00, 0xff},
{0xf0, 0x00, 0x00, 0x00},
}
expected := [][]byte{
{0xff, 0xff, 0xff, 0xff},
{0x00, 0xff, 0x00, 0xff},
{0xf0, 0x00, 0x00, 0x00},
}
for i, value := range values {
var encoder libolmpickle.Encoder
encoder.Write(value)
assert.Equal(t, expected[i], encoder.Bytes())
}
}
62 changes: 35 additions & 27 deletions crypto/goolm/libolmpickle/pickle.go
Original file line number Diff line number Diff line change
@@ -1,40 +1,48 @@
package libolmpickle

import (
"bytes"
"encoding/binary"
"crypto/aes"
"fmt"

"go.mau.fi/util/exerrors"
"maunium.net/go/mautrix/crypto/aessha2"
"maunium.net/go/mautrix/crypto/goolm/goolmbase64"
"maunium.net/go/mautrix/crypto/olm"
)

const (
PickleBoolLength = 1
PickleUInt8Length = 1
PickleUInt32Length = 4
)

type Encoder struct {
bytes.Buffer
}
const pickleMACLength = 8

func NewEncoder() *Encoder { return &Encoder{} }

func (p *Encoder) WriteUInt8(value uint8) {
exerrors.PanicIfNotNil(p.WriteByte(value))
}
var kdfPickle = []byte("Pickle") //used to derive the keys for encryption

func (p *Encoder) WriteBool(value bool) {
if value {
exerrors.PanicIfNotNil(p.WriteByte(0x01))
// Pickle encrypts the input with the key and the cipher AESSHA256. The result is then encoded in base64.
func Pickle(key, plaintext []byte) ([]byte, error) {
if c, err := aessha2.NewAESSHA2(key, kdfPickle); err != nil {
return nil, err
} else if ciphertext, err := c.Encrypt(plaintext); err != nil {
return nil, err
} else if mac, err := c.MAC(ciphertext); err != nil {
return nil, err
} else {
exerrors.PanicIfNotNil(p.WriteByte(0x00))
return goolmbase64.Encode(append(ciphertext, mac[:pickleMACLength]...)), nil
}
}

func (p *Encoder) WriteEmptyBytes(count int) {
exerrors.Must(p.Write(make([]byte, count)))
}

func (p *Encoder) WriteUInt32(value uint32) {
exerrors.PanicIfNotNil(binary.Write(&p.Buffer, binary.BigEndian, value))
// Unpickle decodes the input from base64 and decrypts the decoded input with the key and the cipher AESSHA256.
func Unpickle(key, input []byte) ([]byte, error) {
ciphertext, err := goolmbase64.Decode(input)
if err != nil {
return nil, err
}
ciphertext, mac := ciphertext[:len(ciphertext)-pickleMACLength], ciphertext[len(ciphertext)-pickleMACLength:]
if c, err := aessha2.NewAESSHA2(key, kdfPickle); err != nil {
return nil, err
} else if verified, err := c.VerifyMAC(ciphertext, mac); err != nil {
return nil, err
} else if !verified {
return nil, fmt.Errorf("decrypt pickle: %w", olm.ErrBadMAC)
} else {
// Set to next block size
targetCipherText := make([]byte, int(len(ciphertext)/aes.BlockSize)*aes.BlockSize)
copy(targetCipherText, ciphertext)
return c.Decrypt(targetCipherText)
}
}
Loading

0 comments on commit 67645d4

Please sign in to comment.