Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow configuring the possible salt lengths for RSA PSS signatures #16549

66 changes: 64 additions & 2 deletions builtin/logical/transit/path_sign_verify.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,11 @@ package transit

import (
"context"
"crypto/rsa"
"encoding/base64"
"fmt"
"strconv"
"strings"

"github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/helper/errutil"
Expand Down Expand Up @@ -131,6 +134,13 @@ Options are 'pss' or 'pkcs1v15'. Defaults to 'pss'`,
Default: "asn1",
Description: `The method by which to marshal the signature. The default is 'asn1' which is used by openssl and X.509. It can also be set to 'jws' which is used for JWT signatures; setting it to this will also cause the encoding of the signature to be url-safe base64 instead of using standard base64 encoding. Currently only valid for ECDSA P-256 key types".`,
},

"salt_length": {
Type: framework.TypeString,
Default: "auto",
Description: `The salt length used to sign. Currently only applies to the RSA PSS signature scheme.
Options are 'auto' (the default used by Golang, causing the salt to be as large as possible when signing), 'hash' (causes the salt length to equal the length of the hash used in the signature), or an integer between the minimum and the maximum permissible salt lengths for the given RSA key size. Defaults to 'auto'.`,
},
},

Callbacks: map[logical.Operation]framework.OperationFunc{
Expand Down Expand Up @@ -217,6 +227,13 @@ Options are 'pss' or 'pkcs1v15'. Defaults to 'pss'`,
Default: "asn1",
Description: `The method by which to unmarshal the signature when verifying. The default is 'asn1' which is used by openssl and X.509; can also be set to 'jws' which is used for JWT signatures in which case the signature is also expected to be url-safe base64 encoding instead of standard base64 encoding. Currently only valid for ECDSA P-256 key types".`,
},

"salt_length": {
Type: framework.TypeString,
Default: "auto",
Description: `The salt length used to sign. Currently only applies to the RSA PSS signature scheme.
Options are 'auto' (the default used by Golang, causing the salt to be as large as possible when signing), 'hash' (causes the salt length to equal the length of the hash used in the signature), or an integer between the minimum and the maximum permissible salt lengths for the given RSA key size. Defaults to 'auto'.`,
},
},

Callbacks: map[logical.Operation]framework.OperationFunc{
Expand All @@ -228,6 +245,33 @@ Options are 'pss' or 'pkcs1v15'. Defaults to 'pss'`,
}
}

func (b *backend) getSaltLength(d *framework.FieldData) (int, error) {
rawSaltLength, ok := d.GetOk("salt_length")
// This should only happen when something is wrong with the schema,
// so this is a reasonable default.
if !ok {
return rsa.PSSSaltLengthAuto, nil
}

rawSaltLengthStr := rawSaltLength.(string)
lowerSaltLengthStr := strings.ToLower(rawSaltLengthStr)
switch lowerSaltLengthStr {
case "auto":
return rsa.PSSSaltLengthAuto, nil
case "hash":
return rsa.PSSSaltLengthEqualsHash, nil
default:
saltLengthInt, err := strconv.Atoi(lowerSaltLengthStr)
if err != nil {
return rsa.PSSSaltLengthEqualsHash - 1, fmt.Errorf("salt length neither 'auto', 'hash', nor an int: %s", rawSaltLength)
}
if saltLengthInt < rsa.PSSSaltLengthEqualsHash {
return rsa.PSSSaltLengthEqualsHash - 1, fmt.Errorf("salt length is invalid: %d", saltLengthInt)
}
return saltLengthInt, nil
}
}

func (b *backend) pathSignWrite(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
name := d.Get("name").(string)
ver := d.Get("key_version").(int)
Expand All @@ -252,6 +296,10 @@ func (b *backend) pathSignWrite(ctx context.Context, req *logical.Request, d *fr

prehashed := d.Get("prehashed").(bool)
sigAlgorithm := d.Get("signature_algorithm").(string)
saltLength, err := b.getSaltLength(d)
if err != nil {
return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest
}

// Get the policy
p, _, err := b.GetPolicy(ctx, keysutil.PolicyRequest{
Expand Down Expand Up @@ -330,7 +378,12 @@ func (b *backend) pathSignWrite(ctx context.Context, req *logical.Request, d *fr
}
}

sig, err := p.Sign(ver, context, input, hashAlgorithm, sigAlgorithm, marshaling)
sig, err := p.SignWithOptions(ver, context, input, &keysutil.SigningOptions{
HashAlgorithm: hashAlgorithm,
Marshaling: marshaling,
SaltLength: saltLength,
SigAlgorithm: sigAlgorithm,
})
if err != nil {
if batchInputRaw != nil {
response[i].Error = err.Error()
Expand Down Expand Up @@ -470,6 +523,10 @@ func (b *backend) pathVerifyWrite(ctx context.Context, req *logical.Request, d *

prehashed := d.Get("prehashed").(bool)
sigAlgorithm := d.Get("signature_algorithm").(string)
saltLength, err := b.getSaltLength(d)
if err != nil {
return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest
}

// Get the policy
p, _, err := b.GetPolicy(ctx, keysutil.PolicyRequest{
Expand Down Expand Up @@ -533,7 +590,12 @@ func (b *backend) pathVerifyWrite(ctx context.Context, req *logical.Request, d *
}
}

valid, err := p.VerifySignature(context, input, hashAlgorithm, sigAlgorithm, marshaling, sig)
valid, err := p.VerifySignatureWithOptions(context, input, sig, &keysutil.SigningOptions{
HashAlgorithm: hashAlgorithm,
Marshaling: marshaling,
SaltLength: saltLength,
SigAlgorithm: sigAlgorithm,
})
if err != nil {
switch err.(type) {
case errutil.UserError:
Expand Down
255 changes: 255 additions & 0 deletions builtin/logical/transit/path_sign_verify_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -700,3 +700,258 @@ func TestTransit_SignVerify_ED25519(t *testing.T) {
outcome[1].valid = false
verifyRequest(req, false, outcome, "bar", goodsig, true)
}

func TestTransit_SignVerify_RSA_PSS(t *testing.T) {
t.Run("2048", func(t *testing.T) {
testTransit_SignVerify_RSA_PSS(t, 2048)
})
t.Run("3072", func(t *testing.T) {
testTransit_SignVerify_RSA_PSS(t, 3072)
})
t.Run("4096", func(t *testing.T) {
testTransit_SignVerify_RSA_PSS(t, 4096)
})
}

func testTransit_SignVerify_RSA_PSS(t *testing.T, bits int) {
b, storage := createBackendWithSysView(t)

// First create a key
req := &logical.Request{
Storage: storage,
Operation: logical.UpdateOperation,
Path: "keys/foo",
Data: map[string]interface{}{
"type": fmt.Sprintf("rsa-%d", bits),
},
}
_, err := b.HandleRequest(context.Background(), req)
if err != nil {
t.Fatal(err)
}

signRequest := func(errExpected bool, postpath string) string {
t.Helper()
req.Path = "sign/foo" + postpath
resp, err := b.HandleRequest(context.Background(), req)
if err != nil && !errExpected {
t.Fatal(err)
}
if resp == nil {
t.Fatal("expected non-nil response")
}
if errExpected {
if !resp.IsError() {
t.Fatalf("bad: should have gotten error response: %#v", *resp)
}
return ""
}
if resp.IsError() {
t.Fatalf("bad: got error response: %#v", *resp)
}
// Since we are reusing the same request, let's clear the salt length each time.
delete(req.Data, "salt_length")

value, ok := resp.Data["signature"]
if !ok {
t.Fatalf("no signature key found in returned data, got resp data %#v", resp.Data)
}
return value.(string)
}

verifyRequest := func(errExpected bool, postpath, sig string) {
t.Helper()
req.Path = "verify/foo" + postpath
req.Data["signature"] = sig
resp, err := b.HandleRequest(context.Background(), req)
if err != nil {
if errExpected {
return
}
t.Fatalf("got error: %v, sig was %v", err, sig)
}
if resp == nil {
t.Fatal("expected non-nil response")
}
if resp.IsError() {
if errExpected {
return
}
t.Fatalf("bad: got error response: %#v", *resp)
}
value, ok := resp.Data["valid"]
if !ok {
t.Fatalf("no valid key found in returned data, got resp data %#v", resp.Data)
}
if !value.(bool) && !errExpected {
t.Fatalf("verification failed; req was %#v, resp is %#v", *req, *resp)
} else if value.(bool) && errExpected {
t.Fatalf("expected error and didn't get one; req was %#v, resp is %#v", *req, *resp)
}
// Since we are reusing the same request, let's clear the signature each time.
delete(req.Data, "signature")
}

newReqData := func(hashAlgorithm string, marshalingName string) map[string]interface{} {
return map[string]interface{}{
"input": "dGhlIHF1aWNrIGJyb3duIGZveA==",
"signature_algorithm": "pss",
"hash_algorithm": hashAlgorithm,
"marshaling_algorithm": marshalingName,
}
}

signAndVerifyRequest := func(hashAlgorithm string, marshalingName string, signSaltLength string, signErrExpected bool, verifySaltLength string, verifyErrExpected bool) {
t.Log("\t\t\t", signSaltLength, "/", verifySaltLength)
req.Data = newReqData(hashAlgorithm, marshalingName)

req.Data["salt_length"] = signSaltLength
t.Log("\t\t\t\t", "sign req data:", req.Data)
sig := signRequest(signErrExpected, "")

req.Data["salt_length"] = verifySaltLength
t.Log("\t\t\t\t", "verify req data:", req.Data)
verifyRequest(verifyErrExpected, "", sig)
}

invalidSaltLengths := []string{"bar", "-2"}
t.Log("invalidSaltLengths:", invalidSaltLengths)

autoSaltLengths := []string{"auto", "0"}
t.Log("autoSaltLengths:", autoSaltLengths)

hashSaltLengths := []string{"hash", "-1"}
t.Log("hashSaltLengths:", hashSaltLengths)

positiveSaltLengths := []string{"1"}
t.Log("positiveSaltLengths:", positiveSaltLengths)

nonAutoSaltLengths := append(hashSaltLengths, positiveSaltLengths...)
t.Log("nonAutoSaltLengths:", nonAutoSaltLengths)

validSaltLengths := append(autoSaltLengths, nonAutoSaltLengths...)
t.Log("validSaltLengths:", validSaltLengths)

testCombinatorics := func(hashAlgorithm string, marshalingName string) {
t.Log("\t\t", "valid", "/", "invalid salt lengths")
for _, validSaltLength := range validSaltLengths {
for _, invalidSaltLength := range invalidSaltLengths {
signAndVerifyRequest(hashAlgorithm, marshalingName, validSaltLength, false, invalidSaltLength, true)
}
}

t.Log("\t\t", "invalid", "/", "invalid salt lengths")
for _, invalidSaltLength1 := range invalidSaltLengths {
for _, invalidSaltLength2 := range invalidSaltLengths {
signAndVerifyRequest(hashAlgorithm, marshalingName, invalidSaltLength1, true, invalidSaltLength2, true)
}
}

t.Log("\t\t", "invalid", "/", "valid salt lengths")
for _, invalidSaltLength := range invalidSaltLengths {
for _, validSaltLength := range validSaltLengths {
signAndVerifyRequest(hashAlgorithm, marshalingName, invalidSaltLength, true, validSaltLength, true)
}
}

t.Log("\t\t", "valid", "/", "valid salt lengths")
for _, validSaltLength := range validSaltLengths {
signAndVerifyRequest(hashAlgorithm, marshalingName, validSaltLength, false, validSaltLength, false)
}

t.Log("\t\t", "hash", "/", "hash salt lengths")
for _, hashSaltLength1 := range hashSaltLengths {
for _, hashSaltLength2 := range hashSaltLengths {
if hashSaltLength1 != hashSaltLength2 {
signAndVerifyRequest(hashAlgorithm, marshalingName, hashSaltLength1, false, hashSaltLength2, false)
}
}
}

t.Log("\t\t", "hash", "/", "positive salt lengths")
for _, hashSaltLength := range hashSaltLengths {
for _, positiveSaltLength := range positiveSaltLengths {
signAndVerifyRequest(hashAlgorithm, marshalingName, hashSaltLength, false, positiveSaltLength, true)
}
}

t.Log("\t\t", "positive", "/", "hash salt lengths")
for _, positiveSaltLength := range positiveSaltLengths {
for _, hashSaltLength := range hashSaltLengths {
signAndVerifyRequest(hashAlgorithm, marshalingName, positiveSaltLength, false, hashSaltLength, true)
}
}

t.Log("\t\t", "auto", "/", "auto salt lengths")
for _, autoSaltLength1 := range autoSaltLengths {
for _, autoSaltLength2 := range autoSaltLengths {
if autoSaltLength1 != autoSaltLength2 {
signAndVerifyRequest(hashAlgorithm, marshalingName, autoSaltLength1, false, autoSaltLength2, false)
}
}
}

t.Log("\t\t", "auto", "/", "non-auto salt lengths")
for _, autoSaltLength := range autoSaltLengths {
for _, nonAutoSaltLength := range nonAutoSaltLengths {
signAndVerifyRequest(hashAlgorithm, marshalingName, autoSaltLength, false, nonAutoSaltLength, true)
}
}

t.Log("\t\t", "non-auto", "/", "auto salt lengths")
for _, nonAutoSaltLength := range nonAutoSaltLengths {
for _, autoSaltLength := range autoSaltLengths {
signAndVerifyRequest(hashAlgorithm, marshalingName, nonAutoSaltLength, false, autoSaltLength, false)
}
}
}

testAutoSignAndVerify := func(hashAlgorithm string, marshalingName string) {
t.Log("\t\t", "Make a signature with an implicit, automatic salt length")
req.Data = newReqData(hashAlgorithm, marshalingName)
t.Log("\t\t\t", "sign req data:", req.Data)
sig := signRequest(false, "")

t.Log("\t\t", "Verify it with an implicit, automatic salt length")
t.Log("\t\t\t", "verify req data:", req.Data)
verifyRequest(false, "", sig)

t.Log("\t\t", "Verify it with an explicit, automatic salt length")
for _, autoSaltLength := range autoSaltLengths {
t.Log("\t\t\t", "auto", "/", autoSaltLength)
req.Data["salt_length"] = autoSaltLength
t.Log("\t\t\t\t", "verify req data:", req.Data)
verifyRequest(false, "", sig)
}

t.Log("\t\t", "Try to verify it with an explicit, incorrect salt length")
for _, nonAutoSaltLength := range nonAutoSaltLengths {
t.Log("\t\t\t", "auto", "/", nonAutoSaltLength)
req.Data["salt_length"] = nonAutoSaltLength
t.Log("\t\t\t\t", "verify req data:", req.Data)
verifyRequest(true, "", sig)
}

t.Log("\t\t", "Make a signature with an explicit, valid salt length & and verify it with an implicit, automatic salt length")
for _, validSaltLength := range validSaltLengths {
t.Log("\t\t\t", validSaltLength, "/", "auto")

req.Data = newReqData(hashAlgorithm, marshalingName)
req.Data["salt_length"] = validSaltLength
t.Log("\t\t\t", "sign req data:", req.Data)
sig := signRequest(false, "")

t.Log("\t\t\t", "verify req data:", req.Data)
verifyRequest(false, "", sig)
}
}

for hashAlgorithm := range keysutil.HashTypeMap {
t.Log("Hash algorithm:", hashAlgorithm)
for marshalingName := range keysutil.MarshalingTypeMap {
t.Log("\t", "Marshaling type:", marshalingName)
testCombinatorics(hashAlgorithm, marshalingName)
testAutoSignAndVerify(hashAlgorithm, marshalingName)
}
}
}
3 changes: 3 additions & 0 deletions changelog/16549.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
```release-note:improvement
secrets/transit: Allow configuring the possible salt lengths for RSA PSS signatures.
```
Loading