diff --git a/network/newstream/cursors_test.go b/network/newstream/cursors_test.go index bc1f9c0f2f..1dc80ad5dc 100644 --- a/network/newstream/cursors_test.go +++ b/network/newstream/cursors_test.go @@ -24,7 +24,6 @@ import ( "io/ioutil" "math" "math/rand" - "strconv" "testing" "time" @@ -406,7 +405,7 @@ func compareNodeBinsToStreams(t *testing.T, onesCursors map[string]uint64, other } for nameKey, cur := range onesCursors { - id, err := strconv.Atoi(parseID(nameKey).Key) + id, err := parseSyncKey(parseID(nameKey).Key) if err != nil { return err } @@ -424,7 +423,7 @@ func compareNodeBinsToStreamsWithDepth(t *testing.T, onesCursors map[string]uint } // inclusive test for nameKey, cur := range onesCursors { - bin, err := strconv.Atoi(parseID(nameKey).Key) + bin, err := parseSyncKey(parseID(nameKey).Key) if err != nil { return err } @@ -437,9 +436,9 @@ func compareNodeBinsToStreamsWithDepth(t *testing.T, onesCursors map[string]uint } // exclusive test - for i := 0; i < int(depth); i++ { + for i := uint8(0); i < uint8(depth); i++ { // should not have anything shallower than depth - id := NewID("SYNC", fmt.Sprintf("%d", i)) + id := NewID("SYNC", encodeSyncKey(i)) if _, ok := onesCursors[id.String()]; ok { return fmt.Errorf("oneCursors contains id %s, but it should not", id) } @@ -474,7 +473,7 @@ func TestCorrectCursorsExchangeRace(t *testing.T) { //create the response res := &StreamInfoRes{} for _, v := range msg.Streams { - cur, err := strconv.Atoi(v.Key) + cur, err := parseSyncKey(v.Key) if err != nil { t.Fatal(err) } @@ -568,11 +567,11 @@ CHECKSTREAMS: //get the pivot cursors for peer, assert equal to what is in `sub` for _, stream := range getAllSyncStreams() { cur, ok := otherPeer.getCursor(stream) - keyInt, err := strconv.Atoi(stream.Key) + keyInt, err := parseSyncKey(stream.Key) if err != nil { t.Fatal(err) } - shouldExist := checkKeyInSlice(keyInt, sub) + shouldExist := checkKeyInSlice(int(keyInt), sub) if shouldExist == ok { continue @@ -646,10 +645,10 @@ func (s *slipStreamMock) HandleMsg(ctx context.Context, msg interface{}) error { } func getAllSyncStreams() (streams []ID) { - for i := 0; i <= 16; i++ { + for i := uint8(0); i <= 16; i++ { streams = append(streams, ID{ Name: syncStreamName, - Key: fmt.Sprintf("%d", i), + Key: encodeSyncKey(i), }) } return diff --git a/network/newstream/sync_provider.go b/network/newstream/sync_provider.go index 42bde8db84..0309123b64 100644 --- a/network/newstream/sync_provider.go +++ b/network/newstream/sync_provider.go @@ -250,11 +250,11 @@ func (s *syncProvider) WantStream(p *Peer, streamID ID) bool { // check all subscriptions that should exist for this peer subBins, _ := syncSubscriptionsDiff(po, -1, depth, s.kad.MaxProxDisplay, s.syncBinsOnlyWithinDepth) - v, err := strconv.Atoi(streamID.Key) + v, err := parseSyncKey(streamID.Key) if err != nil { return false } - return checkKeyInSlice(v, subBins) + return checkKeyInSlice(int(v), subBins) } var ( @@ -322,7 +322,7 @@ func (s *syncProvider) updateSyncSubscriptions(p *Peer, subBins, quitBins []int) streams := make([]ID, l) for i, po := range subBins { - stream := NewID(s.StreamName(), strconv.Itoa(po)) + stream := NewID(s.StreamName(), encodeSyncKey(uint8(po))) _, err := p.getOrCreateInterval(p.peerStreamIntervalKey(stream)) if err != nil { p.logger.Error("got an error while trying to register initial streams", "stream", stream) @@ -340,7 +340,7 @@ func (s *syncProvider) updateSyncSubscriptions(p *Peer, subBins, quitBins []int) } for _, po := range quitBins { p.logger.Debug("stream unwanted, removing cursor info for peer", "bin", po) - p.deleteCursor(NewID(syncStreamName, strconv.Itoa(po))) + p.deleteCursor(NewID(syncStreamName, encodeSyncKey(uint8(po)))) } } @@ -432,14 +432,7 @@ func checkKeyInSlice(k int, slice []int) (found bool) { } func (s *syncProvider) ParseKey(streamKey string) (interface{}, error) { - b, err := strconv.Atoi(streamKey) - if err != nil { - return 0, err - } - if b < 0 || b > 16 { - return 0, errors.New("stream key out of range") - } - return uint8(b), nil + return parseSyncKey(streamKey) } func (s *syncProvider) EncodeKey(i interface{}) (string, error) { @@ -447,7 +440,7 @@ func (s *syncProvider) EncodeKey(i interface{}) (string, error) { if !ok { return "", errors.New("error encoding key") } - return fmt.Sprintf("%d", v), nil + return encodeSyncKey(v), nil } func (s *syncProvider) StreamName() string { return s.name } @@ -457,3 +450,18 @@ func (s *syncProvider) Boundedness() bool { return false } func (s *syncProvider) Autostart() bool { return s.autostart } func (s *syncProvider) Close() { close(s.quit) } + +func parseSyncKey(streamKey string) (uint8, error) { + b, err := strconv.ParseUint(streamKey, 36, 8) + if err != nil { + return 0, err + } + if b < 0 || b > chunk.MaxPO { + return 0, fmt.Errorf("stream key %v out of range", b) + } + return uint8(b), nil +} + +func encodeSyncKey(i uint8) string { + return strconv.FormatUint(uint64(i), 36) +}