Skip to content

Commit

Permalink
crypto: always read from crypto/rand
Browse files Browse the repository at this point in the history
Signed-off-by: Sumner Evans <[email protected]>
  • Loading branch information
sumnerevans committed Oct 25, 2024
1 parent 6fd4b8a commit 7cc46f1
Show file tree
Hide file tree
Showing 19 changed files with 86 additions and 108 deletions.
4 changes: 2 additions & 2 deletions crypto/account.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ type OlmAccount struct {
}

func NewOlmAccount() *OlmAccount {
account, err := olm.NewAccount(nil)
account, err := olm.NewAccount()
if err != nil {
panic(err)
}
Expand Down Expand Up @@ -105,7 +105,7 @@ func (account *OlmAccount) getInitialKeys(userID id.UserID, deviceID id.DeviceID
func (account *OlmAccount) getOneTimeKeys(userID id.UserID, deviceID id.DeviceID, currentOTKCount int) map[id.KeyID]mautrix.OneTimeKey {
newCount := int(account.Internal.MaxNumberOfOneTimeKeys()/2) - currentOTKCount
if newCount > 0 {
account.Internal.GenOneTimeKeys(nil, uint(newCount))
account.Internal.GenOneTimeKeys(uint(newCount))
}
oneTimeKeys := make(map[id.KeyID]mautrix.OneTimeKey)
internalKeys, err := account.Internal.OneTimeKeys()
Expand Down
22 changes: 11 additions & 11 deletions crypto/goolm/account/account.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import (
"encoding/json"
"errors"
"fmt"
"io"

"maunium.net/go/mautrix/id"

Expand Down Expand Up @@ -68,15 +67,15 @@ func AccountFromPickled(pickled, key []byte) (*Account, error) {
return a, nil
}

// NewAccount creates a new Account. If reader is nil, crypto/rand is used for the key creation.
func NewAccount(reader io.Reader) (*Account, error) {
// NewAccount creates a new Account.
func NewAccount() (*Account, error) {
a := &Account{}
kPEd25519, err := crypto.Ed25519GenerateKey(reader)
kPEd25519, err := crypto.Ed25519GenerateKey()
if err != nil {
return nil, err
}
a.IdKeys.Ed25519 = kPEd25519
kPCurve25519, err := crypto.Curve25519GenerateKey(reader)
kPCurve25519, err := crypto.Curve25519GenerateKey()
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -151,14 +150,14 @@ func (a *Account) MarkKeysAsPublished() {

// GenOneTimeKeys generates a number of new one time keys. If the total number
// of keys stored by this Account exceeds MaxOneTimeKeys then the older
// keys are discarded. If reader is nil, crypto/rand is used for the key creation.
func (a *Account) GenOneTimeKeys(reader io.Reader, num uint) error {
// keys are discarded.
func (a *Account) GenOneTimeKeys(num uint) error {
for i := uint(0); i < num; i++ {
key := crypto.OneTimeKey{
Published: false,
ID: a.NextOneTimeKeyID,
}
newKP, err := crypto.Curve25519GenerateKey(reader)
newKP, err := crypto.Curve25519GenerateKey()
if err != nil {
return err
}
Expand Down Expand Up @@ -247,14 +246,15 @@ func (a *Account) RemoveOneTimeKeys(s olm.Session) error {
//if the key is a fallback or prevFallback, don't remove it
}

// GenFallbackKey generates a new fallback key. The old fallback key is stored in a.PrevFallbackKey overwriting any previous PrevFallbackKey. If reader is nil, crypto/rand is used for the key creation.
func (a *Account) GenFallbackKey(reader io.Reader) error {
// GenFallbackKey generates a new fallback key. The old fallback key is stored
// in a.PrevFallbackKey overwriting any previous PrevFallbackKey.
func (a *Account) GenFallbackKey() error {
a.PrevFallbackKey = a.CurrentFallbackKey
key := crypto.OneTimeKey{
Published: false,
ID: a.NextOneTimeKeyID,
}
newKP, err := crypto.Curve25519GenerateKey(reader)
newKP, err := crypto.Curve25519GenerateKey()
if err != nil {
return err
}
Expand Down
38 changes: 19 additions & 19 deletions crypto/goolm/account/account_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,15 @@ import (
)

func TestAccount(t *testing.T) {
firstAccount, err := account.NewAccount(nil)
firstAccount, err := account.NewAccount()
if err != nil {
t.Fatal(err)
}
err = firstAccount.GenFallbackKey(nil)
err = firstAccount.GenFallbackKey()
if err != nil {
t.Fatal(err)
}
err = firstAccount.GenOneTimeKeys(nil, 2)
err = firstAccount.GenOneTimeKeys(2)
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -118,19 +118,19 @@ func TestAccountPickleJSON(t *testing.T) {
}

func TestSessions(t *testing.T) {
aliceAccount, err := account.NewAccount(nil)
aliceAccount, err := account.NewAccount()
if err != nil {
t.Fatal(err)
}
err = aliceAccount.GenOneTimeKeys(nil, 5)
err = aliceAccount.GenOneTimeKeys(5)
if err != nil {
t.Fatal(err)
}
bobAccount, err := account.NewAccount(nil)
bobAccount, err := account.NewAccount()
if err != nil {
t.Fatal(err)
}
err = bobAccount.GenOneTimeKeys(nil, 5)
err = bobAccount.GenOneTimeKeys(5)
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -217,7 +217,7 @@ func TestOldAccountPickle(t *testing.T) {
"K/A/8TOu9iK2hDFszy6xETiousHnHgh2ZGbRUh4pQx+YMm8ZdNZeRnwFGLnrWyf9" +
"O5TmXua1FcU")
pickleKey := []byte("")
account, err := account.NewAccount(nil)
account, err := account.NewAccount()
if err != nil {
t.Fatal(err)
}
Expand All @@ -232,16 +232,16 @@ func TestOldAccountPickle(t *testing.T) {
}

func TestLoopback(t *testing.T) {
accountA, err := account.NewAccount(nil)
accountA, err := account.NewAccount()
if err != nil {
t.Fatal(err)
}

accountB, err := account.NewAccount(nil)
accountB, err := account.NewAccount()
if err != nil {
t.Fatal(err)
}
err = accountB.GenOneTimeKeys(nil, 42)
err = accountB.GenOneTimeKeys(42)
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -328,16 +328,16 @@ func TestLoopback(t *testing.T) {
}

func TestMoreMessages(t *testing.T) {
accountA, err := account.NewAccount(nil)
accountA, err := account.NewAccount()
if err != nil {
t.Fatal(err)
}

accountB, err := account.NewAccount(nil)
accountB, err := account.NewAccount()
if err != nil {
t.Fatal(err)
}
err = accountB.GenOneTimeKeys(nil, 42)
err = accountB.GenOneTimeKeys(42)
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -411,16 +411,16 @@ func TestMoreMessages(t *testing.T) {
}

func TestFallbackKey(t *testing.T) {
accountA, err := account.NewAccount(nil)
accountA, err := account.NewAccount()
if err != nil {
t.Fatal(err)
}

accountB, err := account.NewAccount(nil)
accountB, err := account.NewAccount()
if err != nil {
t.Fatal(err)
}
err = accountB.GenFallbackKey(nil)
err = accountB.GenFallbackKey()
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -483,7 +483,7 @@ func TestFallbackKey(t *testing.T) {
}

// create a new fallback key for B (the old fallback should still be usable)
err = accountB.GenFallbackKey(nil)
err = accountB.GenFallbackKey()
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -602,7 +602,7 @@ func TestOldV3AccountPickle(t *testing.T) {
}

func TestAccountSign(t *testing.T) {
accountA, err := account.NewAccount(nil)
accountA, err := account.NewAccount()
require.NoError(t, err)
plainText := []byte("Hello, World")
signatureB64, err := accountA.Sign(plainText)
Expand Down
6 changes: 2 additions & 4 deletions crypto/goolm/account/register.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,12 @@
package account

import (
"io"

"maunium.net/go/mautrix/crypto/olm"
)

func init() {
olm.InitNewAccount = func(r io.Reader) (olm.Account, error) {
return NewAccount(r)
olm.InitNewAccount = func() (olm.Account, error) {
return NewAccount()
}
olm.InitBlankAccount = func() olm.Account {
return &Account{}
Expand Down
17 changes: 4 additions & 13 deletions crypto/goolm/crypto/curve25519.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import (
"crypto/rand"
"encoding/base64"
"fmt"
"io"

"golang.org/x/crypto/curve25519"

Expand All @@ -19,19 +18,11 @@ const (
curve25519PubKeyLength = 32
)

// Curve25519GenerateKey creates a new curve25519 key pair. If reader is nil, the random data is taken from crypto/rand.
func Curve25519GenerateKey(reader io.Reader) (Curve25519KeyPair, error) {
// Curve25519GenerateKey creates a new curve25519 key pair.
func Curve25519GenerateKey() (Curve25519KeyPair, error) {
privateKeyByte := make([]byte, Curve25519KeyLength)
if reader == nil {
_, err := rand.Read(privateKeyByte)
if err != nil {
return Curve25519KeyPair{}, err
}
} else {
_, err := reader.Read(privateKeyByte)
if err != nil {
return Curve25519KeyPair{}, err
}
if _, err := rand.Read(privateKeyByte); err != nil {
return Curve25519KeyPair{}, err
}

privateKey := Curve25519PrivateKey(privateKeyByte)
Expand Down
10 changes: 5 additions & 5 deletions crypto/goolm/crypto/curve25519_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,11 @@ import (
)

func TestCurve25519(t *testing.T) {
firstKeypair, err := crypto.Curve25519GenerateKey(nil)
firstKeypair, err := crypto.Curve25519GenerateKey()
if err != nil {
t.Fatal(err)
}
secondKeypair, err := crypto.Curve25519GenerateKey(nil)
secondKeypair, err := crypto.Curve25519GenerateKey()
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -93,7 +93,7 @@ func TestCurve25519Case1(t *testing.T) {

func TestCurve25519Pickle(t *testing.T) {
//create keypair
keyPair, err := crypto.Curve25519GenerateKey(nil)
keyPair, err := crypto.Curve25519GenerateKey()
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -124,7 +124,7 @@ func TestCurve25519Pickle(t *testing.T) {

func TestCurve25519PicklePubKeyOnly(t *testing.T) {
//create keypair
keyPair, err := crypto.Curve25519GenerateKey(nil)
keyPair, err := crypto.Curve25519GenerateKey()
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -156,7 +156,7 @@ func TestCurve25519PicklePubKeyOnly(t *testing.T) {

func TestCurve25519PicklePrivKeyOnly(t *testing.T) {
//create keypair
keyPair, err := crypto.Curve25519GenerateKey(nil)
keyPair, err := crypto.Curve25519GenerateKey()
if err != nil {
t.Fatal(err)
}
Expand Down
7 changes: 3 additions & 4 deletions crypto/goolm/crypto/ed25519.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package crypto
import (
"encoding/base64"
"fmt"
"io"

"maunium.net/go/mautrix/crypto/ed25519"
"maunium.net/go/mautrix/crypto/goolm/libolmpickle"
Expand All @@ -15,9 +14,9 @@ const (
ED25519SignatureSize = ed25519.SignatureSize //The length of a signature
)

// Ed25519GenerateKey creates a new ed25519 key pair. If reader is nil, the random data is taken from crypto/rand.
func Ed25519GenerateKey(reader io.Reader) (Ed25519KeyPair, error) {
publicKey, privateKey, err := ed25519.GenerateKey(reader)
// Ed25519GenerateKey creates a new ed25519 key pair.
func Ed25519GenerateKey() (Ed25519KeyPair, error) {
publicKey, privateKey, err := ed25519.GenerateKey(nil)
if err != nil {
return Ed25519KeyPair{}, err
}
Expand Down
10 changes: 5 additions & 5 deletions crypto/goolm/crypto/ed25519_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import (
)

func TestEd25519(t *testing.T) {
keypair, err := crypto.Ed25519GenerateKey(nil)
keypair, err := crypto.Ed25519GenerateKey()
if err != nil {
t.Fatal(err)
}
Expand All @@ -21,7 +21,7 @@ func TestEd25519(t *testing.T) {

func TestEd25519Case1(t *testing.T) {
//64 bytes for ed25519 package
keyPair, err := crypto.Ed25519GenerateKey(nil)
keyPair, err := crypto.Ed25519GenerateKey()
if err != nil {
t.Fatal(err)
}
Expand All @@ -46,7 +46,7 @@ func TestEd25519Case1(t *testing.T) {

func TestEd25519Pickle(t *testing.T) {
//create keypair
keyPair, err := crypto.Ed25519GenerateKey(nil)
keyPair, err := crypto.Ed25519GenerateKey()
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -77,7 +77,7 @@ func TestEd25519Pickle(t *testing.T) {

func TestEd25519PicklePubKeyOnly(t *testing.T) {
//create keypair
keyPair, err := crypto.Ed25519GenerateKey(nil)
keyPair, err := crypto.Ed25519GenerateKey()
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -109,7 +109,7 @@ func TestEd25519PicklePubKeyOnly(t *testing.T) {

func TestEd25519PicklePrivKeyOnly(t *testing.T) {
//create keypair
keyPair, err := crypto.Ed25519GenerateKey(nil)
keyPair, err := crypto.Ed25519GenerateKey()
if err != nil {
t.Fatal(err)
}
Expand Down
2 changes: 1 addition & 1 deletion crypto/goolm/pk/decryption.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ type Decryption struct {

// NewDecryption returns a new Decryption with a new generated key pair.
func NewDecryption() (*Decryption, error) {
keyPair, err := crypto.Curve25519GenerateKey(nil)
keyPair, err := crypto.Curve25519GenerateKey()
if err != nil {
return nil, err
}
Expand Down
Loading

0 comments on commit 7cc46f1

Please sign in to comment.