Skip to content

Commit

Permalink
validate key sizes and allow empty public keys
Browse files Browse the repository at this point in the history
Signed-off-by: qmuntal <[email protected]>
  • Loading branch information
qmuntal committed Jul 13, 2023
1 parent bfe4717 commit 377ee80
Show file tree
Hide file tree
Showing 2 changed files with 270 additions and 142 deletions.
78 changes: 59 additions & 19 deletions key.go
Original file line number Diff line number Diff line change
Expand Up @@ -434,6 +434,21 @@ func (k Key) validate(op KeyOp) error {
if crv == CurveInvalid || (len(x) == 0 && len(y) == 0 && len(d) == 0) {
return ErrInvalidKey
}
if size := curveSize(crv); size > 0 {
// RFC 8152 Section 13.1.1 says that x and y leading zero octets MUST be preserved,
// but the Go crypto/elliptic package trims them. So we relax the check
// here to allow for omitted leading zero octets, we will add them back
// when marshaling.
if len(x) > size {
return fmt.Errorf("invalid x size: expected lower or equal to %d, got %d", size, len(x))
}
if len(y) > size {
return fmt.Errorf("invalid y size: expected lower or equal to %d, got %d", size, len(y))
}
if len(d) > size {
return fmt.Errorf("invalid d size: expected lower or equal to %d, got %d", size, len(d))
}
}
switch crv {
case CurveX25519, CurveX448, CurveEd25519, CurveEd448:
return fmt.Errorf(
Expand All @@ -459,6 +474,12 @@ func (k Key) validate(op KeyOp) error {
if crv == CurveInvalid || (len(x) == 0 && len(d) == 0) {
return ErrInvalidKey
}
if len(x) > 0 && len(x) != ed25519.PublicKeySize {
return fmt.Errorf("invalid x size: expected %d, got %d", ed25519.PublicKeySize, len(x))
}
if len(d) > 0 && len(d) != ed25519.SeedSize {
return fmt.Errorf("invalid d size: expected %d, got %d", ed25519.SeedSize, len(d))
}
switch crv {
case CurveP256, CurveP384, CurveP521:
return fmt.Errorf(
Expand Down Expand Up @@ -540,6 +561,18 @@ func (k *Key) MarshalCBOR() ([]byte, error) {
existing[lbl] = struct{}{}
tmp[lbl] = v
}
if k.KeyType == KeyTypeEC2 {
// If EC2 key, ensure that x and y are padded to the correct size.
crv, x, y, _ := k.EC2()
if size := curveSize(crv); size > 0 {
if 0 < len(x) && len(x) < size {
tmp[KeyLabelEC2X] = append(make([]byte, size-len(x)), x...)
}
if 0 < len(y) && len(y) < size {
tmp[KeyLabelEC2Y] = append(make([]byte, size-len(y)), y...)
}
}
}
return encMode.Marshal(tmp)
}

Expand Down Expand Up @@ -672,14 +705,6 @@ func (k *Key) PrivateKey() (crypto.PrivateKey, error) {

switch alg {
case AlgorithmES256, AlgorithmES384, AlgorithmES512:
_, x, y, d := k.EC2()
// RFC8152 allows omitting X and Y from private keys;
// crypto.PrivateKey assumes they are available.
// see https://www.rfc-editor.org/rfc/rfc8152#section-13.1.1
if len(x) == 0 || len(y) == 0 {
return nil, ErrEC2NoPub
}

var curve elliptic.Curve

switch alg {
Expand All @@ -691,22 +716,24 @@ func (k *Key) PrivateKey() (crypto.PrivateKey, error) {
curve = elliptic.P521()
}

priv := &ecdsa.PrivateKey{
PublicKey: ecdsa.PublicKey{Curve: curve, X: new(big.Int), Y: new(big.Int)},
D: new(big.Int),
_, x, y, d := k.EC2()
var bx, by *big.Int
if len(x) == 0 || len(y) == 0 {
bx, by = curve.ScalarBaseMult(d)
} else {
bx = new(big.Int).SetBytes(x)
by = new(big.Int).SetBytes(y)
}
priv.X.SetBytes(x)
priv.Y.SetBytes(y)
priv.D.SetBytes(d)
bd := new(big.Int).SetBytes(d)

return priv, nil
return &ecdsa.PrivateKey{
PublicKey: ecdsa.PublicKey{Curve: curve, X: bx, Y: by},
D: bd,
}, nil
case AlgorithmEd25519:
_, x, d := k.OKP()
// RFC8152 allows omitting X from private keys;
// crypto.PrivateKey assumes it is available.
// see https://www.rfc-editor.org/rfc/rfc8152#section-13.2
if len(x) == 0 {
return nil, ErrOKPNoPub
return ed25519.NewKeyFromSeed(d), nil
}

buf := make([]byte, ed25519.PrivateKeySize)
Expand Down Expand Up @@ -820,6 +847,19 @@ func algorithmFromEllipticCurve(c elliptic.Curve) Algorithm {
}
}

func curveSize(crv Curve) int {
var bitSize int
switch crv {
case CurveP256:
bitSize = elliptic.P256().Params().BitSize
case CurveP384:
bitSize = elliptic.P384().Params().BitSize
case CurveP521:
bitSize = elliptic.P521().Params().BitSize
}
return (bitSize + 7) / 8
}

func decodeBytes(dic map[interface{}]interface{}, lbl interface{}) (b []byte, ok bool, err error) {
val, ok := dic[lbl]
if !ok {
Expand Down
Loading

0 comments on commit 377ee80

Please sign in to comment.