Skip to content

Commit

Permalink
crypto/goolm/crypto: use stdlib for HKDF and HMAC operations
Browse files Browse the repository at this point in the history
Signed-off-by: Sumner Evans <[email protected]>
  • Loading branch information
sumnerevans committed Oct 25, 2024
1 parent c09eae3 commit 9f74b58
Show file tree
Hide file tree
Showing 7 changed files with 36 additions and 146 deletions.
13 changes: 9 additions & 4 deletions crypto/goolm/cipher/aes_sha256.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,13 @@ package cipher

import (
"bytes"
"crypto/hmac"
"crypto/sha256"
"io"

"golang.org/x/crypto/hkdf"

"maunium.net/go/mautrix/crypto/aescbc"
"maunium.net/go/mautrix/crypto/goolm/crypto"
)

// derivedAESKeys stores the derived keys for the AESSHA256 cipher
Expand All @@ -17,9 +20,9 @@ type derivedAESKeys struct {

// deriveAESKeys derives three keys for the AESSHA256 cipher
func deriveAESKeys(kdfInfo []byte, key []byte) (derivedAESKeys, error) {
hkdf := crypto.HKDFSHA256(key, nil, kdfInfo)
kdf := hkdf.New(sha256.New, key, nil, kdfInfo)
keymatter := make([]byte, 80)
_, err := io.ReadFull(hkdf, keymatter)
_, err := io.ReadFull(kdf, keymatter)
return derivedAESKeys{
key: keymatter[:32],
hmacKey: keymatter[32:64],
Expand Down Expand Up @@ -63,7 +66,9 @@ func (c AESSHA256) MAC(key, message []byte) ([]byte, error) {
if err != nil {
return nil, err
}
return crypto.HMACSHA256(keys.hmacKey, message), nil
hash := hmac.New(sha256.New, keys.hmacKey)
_, err = hash.Write(message)
return hash.Sum(nil), err
}

// Verify checks the MAC of the message using the key against the givenMAC. The key is used to derive the actual mac key (32 bytes).
Expand Down
29 changes: 0 additions & 29 deletions crypto/goolm/crypto/hmac.go

This file was deleted.

101 changes: 0 additions & 101 deletions crypto/goolm/crypto/hmac_test.go

This file was deleted.

7 changes: 5 additions & 2 deletions crypto/goolm/megolm/megolm.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@
package megolm

import (
"crypto/hmac"
"crypto/rand"
"crypto/sha256"
"fmt"

"maunium.net/go/mautrix/crypto/goolm/cipher"
Expand Down Expand Up @@ -63,8 +65,9 @@ func NewWithRandom(counter uint32) (*Ratchet, error) {

// rehashPart rehases the part of the ratchet data with the base defined as from storing into the target to.
func (m *Ratchet) rehashPart(from, to int) {
newData := crypto.HMACSHA256(m.Data[from*RatchetPartLength:from*RatchetPartLength+RatchetPartLength], hashKeySeeds[to])
copy(m.Data[to*RatchetPartLength:], newData[:RatchetPartLength])
hash := hmac.New(sha256.New, m.Data[from*RatchetPartLength:from*RatchetPartLength+RatchetPartLength])
hash.Write(hashKeySeeds[to])
copy(m.Data[to*RatchetPartLength:], hash.Sum(nil))
}

// Advance advances the ratchet one step.
Expand Down
7 changes: 6 additions & 1 deletion crypto/goolm/ratchet/chain.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
package ratchet

import (
"crypto/hmac"
"crypto/sha256"

"maunium.net/go/mautrix/crypto/goolm/crypto"
"maunium.net/go/mautrix/crypto/goolm/libolmpickle"
)
Expand All @@ -18,7 +21,9 @@ type chainKey struct {

// advance advances the chain
func (c *chainKey) advance() {
c.Key = crypto.HMACSHA256(c.Key, []byte{chainKeySeed})
hash := hmac.New(sha256.New, c.Key)
hash.Write([]byte{chainKeySeed})
c.Key = hash.Sum(nil)
c.Index++
}

Expand Down
20 changes: 13 additions & 7 deletions crypto/goolm/ratchet/olm.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,13 @@
package ratchet

import (
"crypto/hmac"
"crypto/sha256"
"fmt"
"io"

"golang.org/x/crypto/hkdf"

"maunium.net/go/mautrix/crypto/goolm/cipher"
"maunium.net/go/mautrix/crypto/goolm/crypto"
"maunium.net/go/mautrix/crypto/goolm/libolmpickle"
Expand Down Expand Up @@ -70,7 +74,7 @@ func New() *Ratchet {

// InitializeAsBob initializes this ratchet from a receiving point of view (only first message).
func (r *Ratchet) InitializeAsBob(sharedSecret []byte, theirRatchetKey crypto.Curve25519PublicKey) error {
derivedSecretsReader := crypto.HKDFSHA256(sharedSecret, nil, KdfInfo.Root)
derivedSecretsReader := hkdf.New(sha256.New, sharedSecret, nil, KdfInfo.Root)
derivedSecrets := make([]byte, 2*sharedKeyLength)
if _, err := io.ReadFull(derivedSecretsReader, derivedSecrets); err != nil {
return err
Expand All @@ -83,7 +87,7 @@ func (r *Ratchet) InitializeAsBob(sharedSecret []byte, theirRatchetKey crypto.Cu

// InitializeAsAlice initializes this ratchet from a sending point of view (only first message).
func (r *Ratchet) InitializeAsAlice(sharedSecret []byte, ourRatchetKey crypto.Curve25519KeyPair) error {
derivedSecretsReader := crypto.HKDFSHA256(sharedSecret, nil, KdfInfo.Root)
derivedSecretsReader := hkdf.New(sha256.New, sharedSecret, nil, KdfInfo.Root)
derivedSecrets := make([]byte, 2*sharedKeyLength)
if _, err := io.ReadFull(derivedSecretsReader, derivedSecrets); err != nil {
return err
Expand Down Expand Up @@ -192,7 +196,7 @@ func (r *Ratchet) advanceRootKey(newRatchetKey crypto.Curve25519KeyPair, oldRatc
if err != nil {
return nil, err
}
derivedSecretsReader := crypto.HKDFSHA256(sharedSecret, r.RootKey, KdfInfo.Ratchet)
derivedSecretsReader := hkdf.New(sha256.New, sharedSecret, r.RootKey, KdfInfo.Ratchet)
derivedSecrets := make([]byte, 2*sharedKeyLength)
if _, err := io.ReadFull(derivedSecretsReader, derivedSecrets); err != nil {
return nil, err
Expand All @@ -203,10 +207,12 @@ func (r *Ratchet) advanceRootKey(newRatchetKey crypto.Curve25519KeyPair, oldRatc

// createMessageKeys returns the messageKey derived from the chainKey
func (r Ratchet) createMessageKeys(chainKey chainKey) messageKey {
res := messageKey{}
res.Key = crypto.HMACSHA256(chainKey.Key, []byte{messageKeySeed})
res.Index = chainKey.Index
return res
hash := hmac.New(sha256.New, chainKey.Key)
hash.Write([]byte{messageKeySeed})
return messageKey{
Key: hash.Sum(nil),
Index: chainKey.Index,
}
}

// decryptForExistingChain returns the decrypted message by using the chain. The MAC of the rawMessage is verified.
Expand Down
5 changes: 3 additions & 2 deletions crypto/goolm/session/olm_session.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package session

import (
"bytes"
"crypto/sha256"
"encoding/base64"
"fmt"
"strings"
Expand Down Expand Up @@ -203,8 +204,8 @@ func (s *OlmSession) ID() id.SessionID {
copy(message, s.AliceIdentityKey)
copy(message[crypto.Curve25519PrivateKeyLength:], s.AliceBaseKey)
copy(message[2*crypto.Curve25519PrivateKeyLength:], s.BobOneTimeKey)
hash := crypto.SHA256(message)
res := id.SessionID(goolmbase64.Encode(hash))
hash := sha256.Sum256(message)
res := id.SessionID(goolmbase64.Encode(hash[:]))
return res
}

Expand Down

0 comments on commit 9f74b58

Please sign in to comment.