Skip to content

Commit

Permalink
Add iter.FromSeq{,2}{,Context} (#13)
Browse files Browse the repository at this point in the history
* Checkpoint.

* Add test coverage.

* Add coverage.

* Add coverage.
  • Loading branch information
bobg authored Feb 9, 2024
1 parent 09c6ead commit 9680c2c
Show file tree
Hide file tree
Showing 2 changed files with 162 additions and 0 deletions.
54 changes: 54 additions & 0 deletions iter/iter.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
// (a preview of which is available in Go 1.22 when building with GOEXPERIMENT=rangefunc).
package iter

import "context"

// Of is the interface implemented by iterators.
// It is called "Of" so that when qualified with this package name
// and instantiated with a member type,
Expand Down Expand Up @@ -98,3 +100,55 @@ func AllPairs[T, U any](inp Of[Pair[T, U]]) Seq2[T, U] {
}
}
}

// FromSeq converts a Go 1.23 iterator into an Of[T].
func FromSeq[T any](seq Seq[T]) Of[T] {
return Go(func(ch chan<- T) error {
seq(func(v T) bool {
ch <- v
return true
})
return nil
})
}

// FromSeqContext converts a Go 1.23 iterator into an Of[T].
func FromSeqContext[T any](ctx context.Context, seq Seq[T]) Of[T] {
return Go(func(ch chan<- T) error {
seq(func(v T) bool {
select {
case <-ctx.Done():
return false
case ch <- v:
return true
}
})
return ctx.Err()
})
}

// FromSeq2 converts a Go 1.23 pair iterator into an Of[Pair[T, U]].
func FromSeq2[T, U any](seq Seq2[T, U]) Of[Pair[T, U]] {
return Go(func(ch chan<- Pair[T, U]) error {
seq(func(t T, u U) bool {
ch <- Pair[T, U]{X: t, Y: u}
return true
})
return nil
})
}

// FromSeq2Context converts a Go 1.23 pair iterator into an Of[Pair[T, U]].
func FromSeq2Context[T, U any](ctx context.Context, seq Seq2[T, U]) Of[Pair[T, U]] {
return Go(func(ch chan<- Pair[T, U]) error {
seq(func(t T, u U) bool {
select {
case <-ctx.Done():
return false
case ch <- Pair[T, U]{X: t, Y: u}:
return true
}
})
return ctx.Err()
})
}
108 changes: 108 additions & 0 deletions iter/iter_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package iter

import (
"context"
"errors"
"reflect"
"testing"
)
Expand Down Expand Up @@ -63,3 +65,109 @@ func testSeq2[T, U any](t *testing.T, seq Seq2[T, U], wantT []T, wantU []U) {
t.Errorf("got %v, want %v", gotU, wantU)
}
}

func TestFromSeq(t *testing.T) {
seq := func(yield func(int) bool) {
for i := 0; i < 5; i++ {
if !yield(i) {
break
}
}
}
it := FromSeq(seq)
got, err := ToSlice(it)
if err != nil {
t.Fatal(err)
}
want := []int{0, 1, 2, 3, 4}
if !reflect.DeepEqual(got, want) {
t.Errorf("got %v, want %v", got, want)
}
}

func TestFromSeq2(t *testing.T) {
names := []string{"Alice", "Bob", "Carol"}
seq2 := func(yield func(int, string) bool) {
for i, name := range names {
if !yield(i, name) {
break
}
}
}
it := FromSeq2(seq2)
got, err := ToSlice(it)
if err != nil {
t.Fatal(err)
}
want := []Pair[int, string]{{0, "Alice"}, {1, "Bob"}, {2, "Carol"}}
if !reflect.DeepEqual(got, want) {
t.Errorf("got %v, want %v", got, want)
}
}

func TestFromSeqContext(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

seq := func(yield func(int) bool) {
for i := 0; i < 5; i++ {
if !yield(i) {
break
}
}
}

it := FromSeqContext(ctx, seq)

if !it.Next() {
t.Fatal("it.Next() returned false, want true")
}
val0 := it.Val()
if val0 != 0 {
t.Errorf("got %v, want 0", val0)
}

cancel()

got, err := ToSlice(it)
if !errors.Is(err, context.Canceled) {
t.Errorf("got %v, want %v", err, context.Canceled)
}
if len(got) > 0 {
t.Errorf("got %v, want []", got)
}
}

func TestFromSeq2Context(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

names := []string{"Alice", "Bob", "Carol"}
seq2 := func(yield func(int, string) bool) {
for i, name := range names {
if !yield(i, name) {
break
}
}
}

it := FromSeq2Context(ctx, seq2)

if !it.Next() {
t.Fatal("it.Next() returned false, want true")
}
val0 := it.Val()
if val0.X != 0 || val0.Y != "Alice" {
t.Errorf("got %v, want {0, Alice}", val0)
}

cancel()

got, err := ToSlice(it)
if !errors.Is(err, context.Canceled) {
t.Errorf("got %v, want %v", err, context.Canceled)
}
if len(got) > 0 {
t.Errorf("got %v, want []", got)
}
}

0 comments on commit 9680c2c

Please sign in to comment.