diff --git a/cryptobyte/asn1.go b/cryptobyte/asn1.go index 6fc2838a3f..2492f796af 100644 --- a/cryptobyte/asn1.go +++ b/cryptobyte/asn1.go @@ -733,13 +733,14 @@ func (s *String) ReadOptionalASN1OctetString(out *[]byte, outPresent *bool, tag return true } -// ReadOptionalASN1Boolean sets *out to the value of the next ASN.1 BOOLEAN or, -// if the next bytes are not an ASN.1 BOOLEAN, to the value of defaultValue. -// It reports whether the operation was successful. -func (s *String) ReadOptionalASN1Boolean(out *bool, defaultValue bool) bool { +// ReadOptionalASN1Boolean attempts to read an optional ASN.1 BOOLEAN +// explicitly tagged with tag into out and advances. If no element with a +// matching tag is present, it sets "out" to defaultValue instead. It reports +// whether the read was successful. +func (s *String) ReadOptionalASN1Boolean(out *bool, tag asn1.Tag, defaultValue bool) bool { var present bool var child String - if !s.ReadOptionalASN1(&child, &present, asn1.BOOLEAN) { + if !s.ReadOptionalASN1(&child, &present, tag) { return false } @@ -748,7 +749,7 @@ func (s *String) ReadOptionalASN1Boolean(out *bool, defaultValue bool) bool { return true } - return s.ReadASN1Boolean(out) + return child.ReadASN1Boolean(out) } func (s *String) readASN1(out *String, outTag *asn1.Tag, skipHeader bool) bool { diff --git a/cryptobyte/asn1_test.go b/cryptobyte/asn1_test.go index e3f53a932e..93760b06e9 100644 --- a/cryptobyte/asn1_test.go +++ b/cryptobyte/asn1_test.go @@ -115,6 +115,28 @@ func TestReadASN1OptionalInteger(t *testing.T) { } } +const defaultBool = false + +var optionalBoolTestData = []readASN1Test{ + {"empty", []byte{}, 0xa0, true, false}, + {"invalid", []byte{0xa1, 0x3, 0x1, 0x2, 0x7f}, 0xa1, false, defaultBool}, + {"missing", []byte{0xa1, 0x3, 0x1, 0x1, 0x7f}, 0xa0, true, defaultBool}, + {"present", []byte{0xa1, 0x3, 0x1, 0x1, 0xff}, 0xa1, true, true}, +} + +func TestReadASN1OptionalBoolean(t *testing.T) { + for _, test := range optionalBoolTestData { + t.Run(test.name, func(t *testing.T) { + in := String(test.in) + var out bool + ok := in.ReadOptionalASN1Boolean(&out, test.tag, defaultBool) + if ok != test.ok || ok && out != test.out.(bool) { + t.Errorf("in.ReadOptionalASN1Boolean() = %v, want %v; out = %v, want %v", ok, test.ok, out, test.out) + } + }) + } +} + func TestReadASN1IntegerSigned(t *testing.T) { testData64 := []struct { in []byte