Skip to content

Commit 66f8c42

Browse files
authored
From string validation (#75)
* chore: address linter issues Perform changes suggested by linter; no functional changes. * error.go: make ErrInvalidID a const Make ErrInvalidID a constant instead of a variable. This prevent it from being changed by external packages; a behavior that although allowed by the compiler, should probably be considered an invalid operation. * add benchmark and new failing test for FromString * fix: let decode look for additional base32 padding Update FromString and XID.TextUnmarshal so that it looks for discarded bits in the final source character. This ensures that XIDs that have been manually tampered with in a way that's ignored by base32 decode, will not pass as valid.
1 parent 1ac68e2 commit 66f8c42

File tree

3 files changed

+86
-11
lines changed

3 files changed

+86
-11
lines changed

error.go

+11
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
package xid
2+
3+
const (
4+
// ErrInvalidID is returned when trying to unmarshal an invalid ID.
5+
ErrInvalidID strErr = "xid: invalid ID"
6+
)
7+
8+
// strErr allows declaring errors as constants.
9+
type strErr string
10+
11+
func (err strErr) Error() string { return string(err) }

id.go

+15-7
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,6 @@ import (
4747
"crypto/rand"
4848
"database/sql/driver"
4949
"encoding/binary"
50-
"errors"
5150
"fmt"
5251
"hash/crc32"
5352
"io/ioutil"
@@ -73,9 +72,6 @@ const (
7372
)
7473

7574
var (
76-
// ErrInvalidID is returned when trying to unmarshal an invalid ID
77-
ErrInvalidID = errors.New("xid: invalid ID")
78-
7975
// objectIDCounter is atomically incremented when generating a new ObjectId
8076
// using NewObjectId() function. It's used as a counter part of an id.
8177
// This id is initialized with a random value.
@@ -242,7 +238,9 @@ func (id *ID) UnmarshalText(text []byte) error {
242238
return ErrInvalidID
243239
}
244240
}
245-
decode(id, text)
241+
if !decode(id, text) {
242+
return ErrInvalidID
243+
}
246244
return nil
247245
}
248246

@@ -260,8 +258,8 @@ func (id *ID) UnmarshalJSON(b []byte) error {
260258
return id.UnmarshalText(b[1 : len(b)-1])
261259
}
262260

263-
// decode by unrolling the stdlib base32 algorithm + removing all safe checks
264-
func decode(id *ID, src []byte) {
261+
// decode by unrolling the stdlib base32 algorithm + customized safe check.
262+
func decode(id *ID, src []byte) bool {
265263
_ = src[19]
266264
_ = id[11]
267265

@@ -277,6 +275,16 @@ func decode(id *ID, src []byte) {
277275
id[2] = dec[src[3]]<<4 | dec[src[4]]>>1
278276
id[1] = dec[src[1]]<<6 | dec[src[2]]<<1 | dec[src[3]]>>4
279277
id[0] = dec[src[0]]<<3 | dec[src[1]]>>2
278+
279+
// Validate that there are no discarer bits (padding) in src that would
280+
// cause the string-encoded id not to equal src.
281+
var check [4]byte
282+
283+
check[3] = encoding[(id[11]<<4)&0x1F]
284+
check[2] = encoding[(id[11]>>1)&0x1F]
285+
check[1] = encoding[(id[11]>>6)&0x1F|(id[10]<<2)&0x1F]
286+
check[0] = encoding[id[10]>>3]
287+
return bytes.Equal([]byte(src[16:20]), check[:])
280288
}
281289

282290
// Time returns the timestamp part of the id.

id_test.go

+60-4
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,10 @@ import (
55
"encoding/json"
66
"errors"
77
"fmt"
8+
"math/rand"
89
"reflect"
910
"testing"
11+
"testing/quick"
1012
"time"
1113
)
1214

@@ -19,21 +21,21 @@ type IDParts struct {
1921
}
2022

2123
var IDs = []IDParts{
22-
IDParts{
24+
{
2325
ID{0x4d, 0x88, 0xe1, 0x5b, 0x60, 0xf4, 0x86, 0xe4, 0x28, 0x41, 0x2d, 0xc9},
2426
1300816219,
2527
[]byte{0x60, 0xf4, 0x86},
2628
0xe428,
2729
4271561,
2830
},
29-
IDParts{
31+
{
3032
ID{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00},
3133
0,
3234
[]byte{0x00, 0x00, 0x00},
3335
0x0000,
3436
0,
3537
},
36-
IDParts{
38+
{
3739
ID{0x00, 0x00, 0x00, 0x00, 0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0x00, 0x00, 0x01},
3840
0,
3941
[]byte{0xaa, 0xbb, 0xcc},
@@ -252,6 +254,60 @@ func BenchmarkFromString(b *testing.B) {
252254
})
253255
}
254256

257+
func TestFromStringQuick(t *testing.T) {
258+
f := func(id1 ID, c byte) bool {
259+
s1 := id1.String()
260+
for i := range s1 {
261+
s2 := []byte(s1)
262+
s2[i] = c
263+
id2, err := FromString(string(s2))
264+
if id1 == id2 && err == nil && c != s1[i] {
265+
t.Logf("comparing XIDs:\na: %q\nb: %q (index %d changed to %c)", s1, s2, i, c)
266+
return false
267+
}
268+
}
269+
return true
270+
}
271+
err := quick.Check(f, &quick.Config{
272+
Values: func(args []reflect.Value, r *rand.Rand) {
273+
i := r.Intn(len(encoding))
274+
args[0] = reflect.ValueOf(New())
275+
args[1] = reflect.ValueOf(byte(encoding[i]))
276+
},
277+
MaxCount: 1000,
278+
})
279+
if err != nil {
280+
t.Error(err)
281+
}
282+
}
283+
284+
func TestFromStringQuickInvalidChars(t *testing.T) {
285+
f := func(id1 ID, c byte) bool {
286+
s1 := id1.String()
287+
for i := range s1 {
288+
s2 := []byte(s1)
289+
s2[i] = c
290+
id2, err := FromString(string(s2))
291+
if id1 == id2 && err == nil && c != s1[i] {
292+
t.Logf("comparing XIDs:\na: %q\nb: %q (index %d changed to %c)", s1, s2, i, c)
293+
return false
294+
}
295+
}
296+
return true
297+
}
298+
err := quick.Check(f, &quick.Config{
299+
Values: func(args []reflect.Value, r *rand.Rand) {
300+
i := r.Intn(0xFF)
301+
args[0] = reflect.ValueOf(New())
302+
args[1] = reflect.ValueOf(byte(i))
303+
},
304+
MaxCount: 2000,
305+
})
306+
if err != nil {
307+
t.Error(err)
308+
}
309+
}
310+
255311
// func BenchmarkUUIDv1(b *testing.B) {
256312
// b.RunParallel(func(pb *testing.PB) {
257313
// for pb.Next() {
@@ -329,7 +385,7 @@ func TestFromBytes_InvalidBytes(t *testing.T) {
329385
{13, true},
330386
}
331387
for _, c := range cases {
332-
b := make([]byte, c.length, c.length)
388+
b := make([]byte, c.length)
333389
_, err := FromBytes(b)
334390
if got, want := err != nil, c.shouldFail; got != want {
335391
t.Errorf("FromBytes() error got %v, want %v", got, want)

0 commit comments

Comments
 (0)