Skip to content

Commit

Permalink
Allow configuring the possible salt lengths for RSA PSS signatures (#…
Browse files Browse the repository at this point in the history
…16549)

* accommodate salt lengths for RSA PSS

* address feedback

* generalise salt length to an int

* fix error reporting

* Revert "fix error reporting"

This reverts commit 8adfc15fe3303b8fdf9f094ea246945ab1364077.

* fix a faulty check

* check for min/max salt lengths

* stringly-typed HTTP param

* unit tests for sign/verify HTTP requests

also, add marshaling for both SDK and HTTP requests

* randomly sample valid salt length

* add changelog

* add documentation
  • Loading branch information
trishankatdatadog authored Aug 31, 2022
1 parent 42645c0 commit 754c119
Show file tree
Hide file tree
Showing 7 changed files with 609 additions and 46 deletions.
66 changes: 64 additions & 2 deletions builtin/logical/transit/path_sign_verify.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,11 @@ package transit

import (
"context"
"crypto/rsa"
"encoding/base64"
"fmt"
"strconv"
"strings"

"github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/helper/errutil"
Expand Down Expand Up @@ -131,6 +134,13 @@ Options are 'pss' or 'pkcs1v15'. Defaults to 'pss'`,
Default: "asn1",
Description: `The method by which to marshal the signature. The default is 'asn1' which is used by openssl and X.509. It can also be set to 'jws' which is used for JWT signatures; setting it to this will also cause the encoding of the signature to be url-safe base64 instead of using standard base64 encoding. Currently only valid for ECDSA P-256 key types".`,
},

"salt_length": {
Type: framework.TypeString,
Default: "auto",
Description: `The salt length used to sign. Currently only applies to the RSA PSS signature scheme.
Options are 'auto' (the default used by Golang, causing the salt to be as large as possible when signing), 'hash' (causes the salt length to equal the length of the hash used in the signature), or an integer between the minimum and the maximum permissible salt lengths for the given RSA key size. Defaults to 'auto'.`,
},
},

Callbacks: map[logical.Operation]framework.OperationFunc{
Expand Down Expand Up @@ -217,6 +227,13 @@ Options are 'pss' or 'pkcs1v15'. Defaults to 'pss'`,
Default: "asn1",
Description: `The method by which to unmarshal the signature when verifying. The default is 'asn1' which is used by openssl and X.509; can also be set to 'jws' which is used for JWT signatures in which case the signature is also expected to be url-safe base64 encoding instead of standard base64 encoding. Currently only valid for ECDSA P-256 key types".`,
},

"salt_length": {
Type: framework.TypeString,
Default: "auto",
Description: `The salt length used to sign. Currently only applies to the RSA PSS signature scheme.
Options are 'auto' (the default used by Golang, causing the salt to be as large as possible when signing), 'hash' (causes the salt length to equal the length of the hash used in the signature), or an integer between the minimum and the maximum permissible salt lengths for the given RSA key size. Defaults to 'auto'.`,
},
},

Callbacks: map[logical.Operation]framework.OperationFunc{
Expand All @@ -228,6 +245,33 @@ Options are 'pss' or 'pkcs1v15'. Defaults to 'pss'`,
}
}

func (b *backend) getSaltLength(d *framework.FieldData) (int, error) {
rawSaltLength, ok := d.GetOk("salt_length")
// This should only happen when something is wrong with the schema,
// so this is a reasonable default.
if !ok {
return rsa.PSSSaltLengthAuto, nil
}

rawSaltLengthStr := rawSaltLength.(string)
lowerSaltLengthStr := strings.ToLower(rawSaltLengthStr)
switch lowerSaltLengthStr {
case "auto":
return rsa.PSSSaltLengthAuto, nil
case "hash":
return rsa.PSSSaltLengthEqualsHash, nil
default:
saltLengthInt, err := strconv.Atoi(lowerSaltLengthStr)
if err != nil {
return rsa.PSSSaltLengthEqualsHash - 1, fmt.Errorf("salt length neither 'auto', 'hash', nor an int: %s", rawSaltLength)
}
if saltLengthInt < rsa.PSSSaltLengthEqualsHash {
return rsa.PSSSaltLengthEqualsHash - 1, fmt.Errorf("salt length is invalid: %d", saltLengthInt)
}
return saltLengthInt, nil
}
}

func (b *backend) pathSignWrite(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
name := d.Get("name").(string)
ver := d.Get("key_version").(int)
Expand All @@ -252,6 +296,10 @@ func (b *backend) pathSignWrite(ctx context.Context, req *logical.Request, d *fr

prehashed := d.Get("prehashed").(bool)
sigAlgorithm := d.Get("signature_algorithm").(string)
saltLength, err := b.getSaltLength(d)
if err != nil {
return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest
}

// Get the policy
p, _, err := b.GetPolicy(ctx, keysutil.PolicyRequest{
Expand Down Expand Up @@ -330,7 +378,12 @@ func (b *backend) pathSignWrite(ctx context.Context, req *logical.Request, d *fr
}
}

sig, err := p.Sign(ver, context, input, hashAlgorithm, sigAlgorithm, marshaling)
sig, err := p.SignWithOptions(ver, context, input, &keysutil.SigningOptions{
HashAlgorithm: hashAlgorithm,
Marshaling: marshaling,
SaltLength: saltLength,
SigAlgorithm: sigAlgorithm,
})
if err != nil {
if batchInputRaw != nil {
response[i].Error = err.Error()
Expand Down Expand Up @@ -470,6 +523,10 @@ func (b *backend) pathVerifyWrite(ctx context.Context, req *logical.Request, d *

prehashed := d.Get("prehashed").(bool)
sigAlgorithm := d.Get("signature_algorithm").(string)
saltLength, err := b.getSaltLength(d)
if err != nil {
return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest
}

// Get the policy
p, _, err := b.GetPolicy(ctx, keysutil.PolicyRequest{
Expand Down Expand Up @@ -533,7 +590,12 @@ func (b *backend) pathVerifyWrite(ctx context.Context, req *logical.Request, d *
}
}

valid, err := p.VerifySignature(context, input, hashAlgorithm, sigAlgorithm, marshaling, sig)
valid, err := p.VerifySignatureWithOptions(context, input, sig, &keysutil.SigningOptions{
HashAlgorithm: hashAlgorithm,
Marshaling: marshaling,
SaltLength: saltLength,
SigAlgorithm: sigAlgorithm,
})
if err != nil {
switch err.(type) {
case errutil.UserError:
Expand Down
255 changes: 255 additions & 0 deletions builtin/logical/transit/path_sign_verify_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -700,3 +700,258 @@ func TestTransit_SignVerify_ED25519(t *testing.T) {
outcome[1].valid = false
verifyRequest(req, false, outcome, "bar", goodsig, true)
}

func TestTransit_SignVerify_RSA_PSS(t *testing.T) {
t.Run("2048", func(t *testing.T) {
testTransit_SignVerify_RSA_PSS(t, 2048)
})
t.Run("3072", func(t *testing.T) {
testTransit_SignVerify_RSA_PSS(t, 3072)
})
t.Run("4096", func(t *testing.T) {
testTransit_SignVerify_RSA_PSS(t, 4096)
})
}

func testTransit_SignVerify_RSA_PSS(t *testing.T, bits int) {
b, storage := createBackendWithSysView(t)

// First create a key
req := &logical.Request{
Storage: storage,
Operation: logical.UpdateOperation,
Path: "keys/foo",
Data: map[string]interface{}{
"type": fmt.Sprintf("rsa-%d", bits),
},
}
_, err := b.HandleRequest(context.Background(), req)
if err != nil {
t.Fatal(err)
}

signRequest := func(errExpected bool, postpath string) string {
t.Helper()
req.Path = "sign/foo" + postpath
resp, err := b.HandleRequest(context.Background(), req)
if err != nil && !errExpected {
t.Fatal(err)
}
if resp == nil {
t.Fatal("expected non-nil response")
}
if errExpected {
if !resp.IsError() {
t.Fatalf("bad: should have gotten error response: %#v", *resp)
}
return ""
}
if resp.IsError() {
t.Fatalf("bad: got error response: %#v", *resp)
}
// Since we are reusing the same request, let's clear the salt length each time.
delete(req.Data, "salt_length")

value, ok := resp.Data["signature"]
if !ok {
t.Fatalf("no signature key found in returned data, got resp data %#v", resp.Data)
}
return value.(string)
}

verifyRequest := func(errExpected bool, postpath, sig string) {
t.Helper()
req.Path = "verify/foo" + postpath
req.Data["signature"] = sig
resp, err := b.HandleRequest(context.Background(), req)
if err != nil {
if errExpected {
return
}
t.Fatalf("got error: %v, sig was %v", err, sig)
}
if resp == nil {
t.Fatal("expected non-nil response")
}
if resp.IsError() {
if errExpected {
return
}
t.Fatalf("bad: got error response: %#v", *resp)
}
value, ok := resp.Data["valid"]
if !ok {
t.Fatalf("no valid key found in returned data, got resp data %#v", resp.Data)
}
if !value.(bool) && !errExpected {
t.Fatalf("verification failed; req was %#v, resp is %#v", *req, *resp)
} else if value.(bool) && errExpected {
t.Fatalf("expected error and didn't get one; req was %#v, resp is %#v", *req, *resp)
}
// Since we are reusing the same request, let's clear the signature each time.
delete(req.Data, "signature")
}

newReqData := func(hashAlgorithm string, marshalingName string) map[string]interface{} {
return map[string]interface{}{
"input": "dGhlIHF1aWNrIGJyb3duIGZveA==",
"signature_algorithm": "pss",
"hash_algorithm": hashAlgorithm,
"marshaling_algorithm": marshalingName,
}
}

signAndVerifyRequest := func(hashAlgorithm string, marshalingName string, signSaltLength string, signErrExpected bool, verifySaltLength string, verifyErrExpected bool) {
t.Log("\t\t\t", signSaltLength, "/", verifySaltLength)
req.Data = newReqData(hashAlgorithm, marshalingName)

req.Data["salt_length"] = signSaltLength
t.Log("\t\t\t\t", "sign req data:", req.Data)
sig := signRequest(signErrExpected, "")

req.Data["salt_length"] = verifySaltLength
t.Log("\t\t\t\t", "verify req data:", req.Data)
verifyRequest(verifyErrExpected, "", sig)
}

invalidSaltLengths := []string{"bar", "-2"}
t.Log("invalidSaltLengths:", invalidSaltLengths)

autoSaltLengths := []string{"auto", "0"}
t.Log("autoSaltLengths:", autoSaltLengths)

hashSaltLengths := []string{"hash", "-1"}
t.Log("hashSaltLengths:", hashSaltLengths)

positiveSaltLengths := []string{"1"}
t.Log("positiveSaltLengths:", positiveSaltLengths)

nonAutoSaltLengths := append(hashSaltLengths, positiveSaltLengths...)
t.Log("nonAutoSaltLengths:", nonAutoSaltLengths)

validSaltLengths := append(autoSaltLengths, nonAutoSaltLengths...)
t.Log("validSaltLengths:", validSaltLengths)

testCombinatorics := func(hashAlgorithm string, marshalingName string) {
t.Log("\t\t", "valid", "/", "invalid salt lengths")
for _, validSaltLength := range validSaltLengths {
for _, invalidSaltLength := range invalidSaltLengths {
signAndVerifyRequest(hashAlgorithm, marshalingName, validSaltLength, false, invalidSaltLength, true)
}
}

t.Log("\t\t", "invalid", "/", "invalid salt lengths")
for _, invalidSaltLength1 := range invalidSaltLengths {
for _, invalidSaltLength2 := range invalidSaltLengths {
signAndVerifyRequest(hashAlgorithm, marshalingName, invalidSaltLength1, true, invalidSaltLength2, true)
}
}

t.Log("\t\t", "invalid", "/", "valid salt lengths")
for _, invalidSaltLength := range invalidSaltLengths {
for _, validSaltLength := range validSaltLengths {
signAndVerifyRequest(hashAlgorithm, marshalingName, invalidSaltLength, true, validSaltLength, true)
}
}

t.Log("\t\t", "valid", "/", "valid salt lengths")
for _, validSaltLength := range validSaltLengths {
signAndVerifyRequest(hashAlgorithm, marshalingName, validSaltLength, false, validSaltLength, false)
}

t.Log("\t\t", "hash", "/", "hash salt lengths")
for _, hashSaltLength1 := range hashSaltLengths {
for _, hashSaltLength2 := range hashSaltLengths {
if hashSaltLength1 != hashSaltLength2 {
signAndVerifyRequest(hashAlgorithm, marshalingName, hashSaltLength1, false, hashSaltLength2, false)
}
}
}

t.Log("\t\t", "hash", "/", "positive salt lengths")
for _, hashSaltLength := range hashSaltLengths {
for _, positiveSaltLength := range positiveSaltLengths {
signAndVerifyRequest(hashAlgorithm, marshalingName, hashSaltLength, false, positiveSaltLength, true)
}
}

t.Log("\t\t", "positive", "/", "hash salt lengths")
for _, positiveSaltLength := range positiveSaltLengths {
for _, hashSaltLength := range hashSaltLengths {
signAndVerifyRequest(hashAlgorithm, marshalingName, positiveSaltLength, false, hashSaltLength, true)
}
}

t.Log("\t\t", "auto", "/", "auto salt lengths")
for _, autoSaltLength1 := range autoSaltLengths {
for _, autoSaltLength2 := range autoSaltLengths {
if autoSaltLength1 != autoSaltLength2 {
signAndVerifyRequest(hashAlgorithm, marshalingName, autoSaltLength1, false, autoSaltLength2, false)
}
}
}

t.Log("\t\t", "auto", "/", "non-auto salt lengths")
for _, autoSaltLength := range autoSaltLengths {
for _, nonAutoSaltLength := range nonAutoSaltLengths {
signAndVerifyRequest(hashAlgorithm, marshalingName, autoSaltLength, false, nonAutoSaltLength, true)
}
}

t.Log("\t\t", "non-auto", "/", "auto salt lengths")
for _, nonAutoSaltLength := range nonAutoSaltLengths {
for _, autoSaltLength := range autoSaltLengths {
signAndVerifyRequest(hashAlgorithm, marshalingName, nonAutoSaltLength, false, autoSaltLength, false)
}
}
}

testAutoSignAndVerify := func(hashAlgorithm string, marshalingName string) {
t.Log("\t\t", "Make a signature with an implicit, automatic salt length")
req.Data = newReqData(hashAlgorithm, marshalingName)
t.Log("\t\t\t", "sign req data:", req.Data)
sig := signRequest(false, "")

t.Log("\t\t", "Verify it with an implicit, automatic salt length")
t.Log("\t\t\t", "verify req data:", req.Data)
verifyRequest(false, "", sig)

t.Log("\t\t", "Verify it with an explicit, automatic salt length")
for _, autoSaltLength := range autoSaltLengths {
t.Log("\t\t\t", "auto", "/", autoSaltLength)
req.Data["salt_length"] = autoSaltLength
t.Log("\t\t\t\t", "verify req data:", req.Data)
verifyRequest(false, "", sig)
}

t.Log("\t\t", "Try to verify it with an explicit, incorrect salt length")
for _, nonAutoSaltLength := range nonAutoSaltLengths {
t.Log("\t\t\t", "auto", "/", nonAutoSaltLength)
req.Data["salt_length"] = nonAutoSaltLength
t.Log("\t\t\t\t", "verify req data:", req.Data)
verifyRequest(true, "", sig)
}

t.Log("\t\t", "Make a signature with an explicit, valid salt length & and verify it with an implicit, automatic salt length")
for _, validSaltLength := range validSaltLengths {
t.Log("\t\t\t", validSaltLength, "/", "auto")

req.Data = newReqData(hashAlgorithm, marshalingName)
req.Data["salt_length"] = validSaltLength
t.Log("\t\t\t", "sign req data:", req.Data)
sig := signRequest(false, "")

t.Log("\t\t\t", "verify req data:", req.Data)
verifyRequest(false, "", sig)
}
}

for hashAlgorithm := range keysutil.HashTypeMap {
t.Log("Hash algorithm:", hashAlgorithm)
for marshalingName := range keysutil.MarshalingTypeMap {
t.Log("\t", "Marshaling type:", marshalingName)
testCombinatorics(hashAlgorithm, marshalingName)
testAutoSignAndVerify(hashAlgorithm, marshalingName)
}
}
}
3 changes: 3 additions & 0 deletions changelog/16549.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
```release-note:improvement
secrets/transit: Allow configuring the possible salt lengths for RSA PSS signatures.
```
Loading

0 comments on commit 754c119

Please sign in to comment.