Skip to content

Commit

Permalink
fix: deterministic encode cose_sign / cose_sign1 sig_signature (#135)
Browse files Browse the repository at this point in the history
Signed-off-by: Shiwei Zhang <[email protected]>
  • Loading branch information
shizhMSFT authored Mar 10, 2023
1 parent 320246f commit 556edd4
Show file tree
Hide file tree
Showing 6 changed files with 398 additions and 1 deletion.
45 changes: 45 additions & 0 deletions cbor.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,3 +82,48 @@ func (s *byteString) UnmarshalCBOR(data []byte) error {
}
return decModeWithTagsForbidden.Unmarshal(data, (*[]byte)(s))
}

// deterministicBinaryString converts a bstr into the deterministic encoding.
//
// Reference: https://www.rfc-editor.org/rfc/rfc9052.html#section-9
func deterministicBinaryString(data cbor.RawMessage) (cbor.RawMessage, error) {
if len(data) == 0 {
return nil, io.EOF
}
if data[0]>>5 != 2 { // major type 2: bstr
return nil, errors.New("cbor: require bstr type")
}

// fast path: return immediately if bstr is already deterministic
if err := decModeWithTagsForbidden.Valid(data); err != nil {
return nil, err
}
ai := data[0] & 0x1f
if ai < 24 {
return data, nil
}
switch ai {
case 24:
if data[1] >= 24 {
return data, nil
}
case 25:
if data[1] != 0 {
return data, nil
}
case 26:
if data[1] != 0 || data[2] != 0 {
return data, nil
}
case 27:
if data[1] != 0 || data[2] != 0 || data[3] != 0 || data[4] != 0 {
return data, nil
}
}

// slow path: convert by re-encoding
// error checking is not required since `data` has been validataed
var s []byte
_ = decModeWithTagsForbidden.Unmarshal(data, &s)
return encMode.Marshal(s)
}
130 changes: 130 additions & 0 deletions cbor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@ package cose

import (
"bytes"
"reflect"
"testing"

"github.com/fxamacker/cbor/v2"
)

func Test_byteString_UnmarshalCBOR(t *testing.T) {
Expand Down Expand Up @@ -75,3 +78,130 @@ func Test_byteString_UnmarshalCBOR(t *testing.T) {
})
}
}

func Test_deterministicBinaryString(t *testing.T) {
gen := func(initial []byte, size int) []byte {
data := make([]byte, size+len(initial))
copy(data, initial)
return data
}
tests := []struct {
name string
data cbor.RawMessage
want cbor.RawMessage
wantErr bool
}{
{
name: "empty input",
data: nil,
wantErr: true,
},
{
name: "not bstr",
data: []byte{0x00},
wantErr: true,
},
{
name: "short length",
data: gen([]byte{0x57}, 23),
want: gen([]byte{0x57}, 23),
},
{
name: "optimal uint8 length",
data: gen([]byte{0x58, 0x18}, 24),
want: gen([]byte{0x58, 0x18}, 24),
},
{
name: "non-optimal uint8 length",
data: gen([]byte{0x58, 0x17}, 23),
want: gen([]byte{0x57}, 23),
},
{
name: "optimal uint16 length",
data: gen([]byte{0x59, 0x01, 0x00}, 256),
want: gen([]byte{0x59, 0x01, 0x00}, 256),
},
{
name: "non-optimal uint16 length, target short",
data: gen([]byte{0x59, 0x00, 0x17}, 23),
want: gen([]byte{0x57}, 23),
},
{
name: "non-optimal uint16 length, target uint8",
data: gen([]byte{0x59, 0x00, 0x18}, 24),
want: gen([]byte{0x58, 0x18}, 24),
},
{
name: "optimal uint32 length",
data: gen([]byte{0x5a, 0x00, 0x01, 0x00, 0x00}, 65536),
want: gen([]byte{0x5a, 0x00, 0x01, 0x00, 0x00}, 65536),
},
{
name: "non-optimal uint32 length, target short",
data: gen([]byte{0x5a, 0x00, 0x00, 0x00, 0x17}, 23),
want: gen([]byte{0x57}, 23),
},
{
name: "non-optimal uint32 length, target uint8",
data: gen([]byte{0x5a, 0x00, 0x00, 0x00, 0x18}, 24),
want: gen([]byte{0x58, 0x18}, 24),
},
{
name: "non-optimal uint32 length, target uint16",
data: gen([]byte{0x5a, 0x00, 0x00, 0x01, 0x00}, 256),
want: gen([]byte{0x59, 0x01, 0x00}, 256),
},
{
name: "non-optimal uint64 length, target short",
data: gen([]byte{0x5b,
0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x17,
}, 23),
want: gen([]byte{0x57}, 23),
},
{
name: "non-optimal uint64 length, target uint8",
data: gen([]byte{0x5b,
0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x18,
}, 24),
want: gen([]byte{0x58, 0x18}, 24),
},
{
name: "non-optimal uint64 length, target uint16",
data: gen([]byte{0x5b,
0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x01, 0x00,
}, 256),
want: gen([]byte{0x59, 0x01, 0x00}, 256),
},
{
name: "non-optimal uint64 length, target uint32",
data: gen([]byte{0x5b,
0x00, 0x00, 0x00, 0x00,
0x00, 0x01, 0x00, 0x00,
}, 65536),
want: gen([]byte{0x5a, 0x00, 0x01, 0x00, 0x00}, 65536),
},
{
name: "early EOF",
data: gen([]byte{0x5b,
0x01, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00,
}, 42),
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := deterministicBinaryString(tt.data)
if (err != nil) != tt.wantErr {
t.Errorf("deterministicBinaryString() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("deterministicBinaryString() = %v, want %v", got, tt.want)
}
})
}
}
10 changes: 9 additions & 1 deletion sign.go
Original file line number Diff line number Diff line change
Expand Up @@ -220,8 +220,16 @@ func (s *Signature) toBeSigned(bodyProtected cbor.RawMessage, payload, external
// external_aad : bstr,
// payload : bstr
// ]
bodyProtected, err := deterministicBinaryString(bodyProtected)
if err != nil {
return nil, err
}
var signProtected cbor.RawMessage
signProtected, err := s.Headers.MarshalProtected()
signProtected, err = s.Headers.MarshalProtected()
if err != nil {
return nil, err
}
signProtected, err = deterministicBinaryString(signProtected)
if err != nil {
return nil, err
}
Expand Down
4 changes: 4 additions & 0 deletions sign1.go
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,10 @@ func (m *Sign1Message) toBeSigned(external []byte) ([]byte, error) {
if err != nil {
return nil, err
}
protected, err = deterministicBinaryString(protected)
if err != nil {
return nil, err
}
if external == nil {
external = []byte{}
}
Expand Down
128 changes: 128 additions & 0 deletions sign1_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ import (
"crypto/rand"
"reflect"
"testing"

"github.com/fxamacker/cbor/v2"
)

func TestSign1Message_MarshalCBOR(t *testing.T) {
Expand Down Expand Up @@ -837,3 +839,129 @@ func TestSign1Message_Verify(t *testing.T) {
}
})
}

// TestSign1Message_Verify_issue119: non-minimal protected header length
func TestSign1Message_Verify_issue119(t *testing.T) {
// generate key and set up signer / verifier
alg := AlgorithmES256
key := generateTestECDSAKey(t)
signer, err := NewSigner(alg, key)
if err != nil {
t.Fatalf("NewSigner() error = %v", err)
}
verifier, err := NewVerifier(alg, key.Public())
if err != nil {
t.Fatalf("NewVerifier() error = %v", err)
}

// generate message and sign
msg := &Sign1Message{
Headers: Headers{
Protected: ProtectedHeader{
HeaderLabelAlgorithm: AlgorithmES256,
},
},
Payload: []byte("hello"),
}
if err := msg.Sign(rand.Reader, nil, signer); err != nil {
t.Fatalf("Sign1Message.Sign() error = %v", err)
}
data, err := msg.MarshalCBOR()
if err != nil {
t.Fatalf("Sign1Message.MarshalCBOR() error = %v", err)
}

// decanonicalize protected header
decanonicalize := func(data []byte) ([]byte, error) {
var content sign1Message
if err := decModeWithTagsForbidden.Unmarshal(data[1:], &content); err != nil {
return nil, err
}

protected := make([]byte, len(content.Protected)+1)
copy(protected[2:], content.Protected[1:])
protected[0] = 0x58
protected[1] = content.Protected[0] & 0x1f
content.Protected = protected

return encMode.Marshal(cbor.Tag{
Number: CBORTagSign1Message,
Content: content,
})
}
if data, err = decanonicalize(data); err != nil {
t.Fatalf("fail to decanonicalize: %v", err)
}

// verify message
var decoded Sign1Message
if err = decoded.UnmarshalCBOR(data); err != nil {
t.Fatalf("Sign1Message.UnmarshalCBOR() error = %v", err)
}
if err := decoded.Verify(nil, verifier); err != nil {
t.Fatalf("Sign1Message.Verify() error = %v", err)
}
}

func TestSign1Message_toBeSigned(t *testing.T) {
tests := []struct {
name string
m *Sign1Message
external []byte
want []byte
wantErr bool
}{
{
name: "valid message",
m: &Sign1Message{
Headers: Headers{
Protected: ProtectedHeader{
HeaderLabelAlgorithm: algorithmMock,
},
},
Payload: []byte("hello world"),
},
want: []byte{
0x84, // array type
0x6a, 0x53, 0x69, 0x67, 0x6e, 0x61, 0x74, 0x75, 0x72, 0x65, 0x31, // context
0x47, 0xa1, 0x01, 0x3a, 0x6d, 0x6f, 0x63, 0x6a, // protected
0x40, // external
0x4b, 0x68, 0x65, 0x6c, 0x6c, 0x6f, 0x20, 0x77, 0x6f, 0x72, 0x6c, 0x64, // payload
},
},
{
name: "invalid protected header",
m: &Sign1Message{
Headers: Headers{
Protected: ProtectedHeader{
1.5: nil,
},
},
Payload: []byte{},
},
wantErr: true,
},
{
name: "invalid raw protected header",
m: &Sign1Message{
Headers: Headers{
RawProtected: []byte{0x00},
},
Payload: []byte{},
},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := tt.m.toBeSigned(tt.external)
if (err != nil) != tt.wantErr {
t.Errorf("Sign1Message.toBeSigned() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("Sign1Message.toBeSigned() = %v, want %v", got, tt.want)
}
})
}
}
Loading

0 comments on commit 556edd4

Please sign in to comment.