diff --git a/sdks/go/pkg/beam/io/avroio/avroio.go b/sdks/go/pkg/beam/io/avroio/avroio.go index b282c4aa3047..b00c6d2eea00 100644 --- a/sdks/go/pkg/beam/io/avroio/avroio.go +++ b/sdks/go/pkg/beam/io/avroio/avroio.go @@ -25,13 +25,16 @@ import ( "github.com/apache/beam/sdks/v2/go/pkg/beam" "github.com/apache/beam/sdks/v2/go/pkg/beam/io/filesystem" "github.com/apache/beam/sdks/v2/go/pkg/beam/log" + "github.com/apache/beam/sdks/v2/go/pkg/beam/register" "github.com/linkedin/goavro/v2" ) func init() { - beam.RegisterFunction(expandFn) - beam.RegisterType(reflect.TypeOf((*avroReadFn)(nil)).Elem()) - beam.RegisterType(reflect.TypeOf((*writeAvroFn)(nil)).Elem()) + register.Function3x1(expandFn) + register.DoFn3x1[context.Context, string, func(beam.X), error]((*avroReadFn)(nil)) + register.DoFn3x1[context.Context, int, func(*string) bool, error]((*writeAvroFn)(nil)) + register.Emitter1[beam.X]() + register.Iter1[string]() } // Read reads a set of files and returns lines as a PCollection diff --git a/sdks/go/pkg/beam/io/avroio/avroio_test.go b/sdks/go/pkg/beam/io/avroio/avroio_test.go index 8e2894133bfe..403a81875557 100644 --- a/sdks/go/pkg/beam/io/avroio/avroio_test.go +++ b/sdks/go/pkg/beam/io/avroio/avroio_test.go @@ -25,12 +25,30 @@ import ( "github.com/apache/beam/sdks/v2/go/pkg/beam" _ "github.com/apache/beam/sdks/v2/go/pkg/beam/io/filesystem/local" + "github.com/apache/beam/sdks/v2/go/pkg/beam/register" "github.com/apache/beam/sdks/v2/go/pkg/beam/testing/passert" "github.com/apache/beam/sdks/v2/go/pkg/beam/testing/ptest" "github.com/linkedin/goavro/v2" ) +func TestMain(m *testing.M) { + ptest.Main(m) +} + +func init() { + beam.RegisterType(reflect.TypeOf((*Tweet)(nil)).Elem()) + beam.RegisterType(reflect.TypeOf((*NullableFloat64)(nil)).Elem()) + beam.RegisterType(reflect.TypeOf((*NullableString)(nil)).Elem()) + beam.RegisterType(reflect.TypeOf((*NullableTweet)(nil)).Elem()) + register.Function2x0(toJSONString) +} + +func toJSONString(user TwitterUser, emit func(string)) { + b, _ := json.Marshal(user) + emit(string(b)) +} + type Tweet struct { Stamp int64 `json:"timestamp"` Tweet string `json:"tweet"` @@ -126,16 +144,11 @@ func TestWrite(t *testing.T) { avroFile := "./user.avro" testUsername := "user1" testInfo := "userInfo" - p, s, sequence := ptest.CreateList([]string{testUsername}) - format := beam.ParDo(s, func(username string, emit func(string)) { - newUser := TwitterUser{ - User: username, - Info: testInfo, - } - - b, _ := json.Marshal(newUser) - emit(string(b)) - }, sequence) + p, s, sequence := ptest.CreateList([]TwitterUser{{ + User: testUsername, + Info: testInfo, + }}) + format := beam.ParDo(s, toJSONString, sequence) Write(s, avroFile, userSchema, format) t.Cleanup(func() { os.Remove(avroFile) diff --git a/sdks/go/pkg/beam/io/datastoreio/datastore_test.go b/sdks/go/pkg/beam/io/datastoreio/datastore_test.go index a18891bfd14d..345eaa2a59ef 100644 --- a/sdks/go/pkg/beam/io/datastoreio/datastore_test.go +++ b/sdks/go/pkg/beam/io/datastoreio/datastore_test.go @@ -29,6 +29,17 @@ import ( "google.golang.org/api/option" ) +func TestMain(m *testing.M) { + // TODO(https://github.com/apache/beam/issues/27549): Make tests compatible with portable runners. + // To work on this change, replace call with `ptest.Main(m)` + ptest.MainWithDefault(m, "direct") +} + +func init() { + beam.RegisterType(reflect.TypeOf((*Foo)(nil)).Elem()) + beam.RegisterType(reflect.TypeOf((*Bar)(nil)).Elem()) +} + // fake client type implements datastoreio.clientType type fakeClient struct { runCounter int @@ -75,7 +86,7 @@ func Test_query(t *testing.T) { } itemType := reflect.TypeOf(tc.v) - itemKey := runtime.RegisterType(itemType) + itemKey, _ := runtime.TypeKey(itemType) p, s := beam.NewPipelineWithRoot() query(s, "project", "Item", tc.shard, itemType, itemKey, newClient) @@ -93,7 +104,12 @@ func Test_query(t *testing.T) { } } +// Baz is intentionally unregistered. +type Baz struct { +} + func Test_query_Bad(t *testing.T) { + fooKey, _ := runtime.TypeKey(reflect.TypeOf(Foo{})) testCases := []struct { v any itemType reflect.Type @@ -103,8 +119,8 @@ func Test_query_Bad(t *testing.T) { }{ // mismatch typeKey parameter { - Foo{}, - reflect.TypeOf(Foo{}), + Baz{}, + reflect.TypeOf(Baz{}), "MismatchType", "No type registered MismatchType", nil, @@ -113,7 +129,7 @@ func Test_query_Bad(t *testing.T) { { Foo{}, reflect.TypeOf(Foo{}), - runtime.RegisterType(reflect.TypeOf(Foo{})), + fooKey, "fake client error", errors.New("fake client error"), }, diff --git a/sdks/go/pkg/beam/io/spannerio/common.go b/sdks/go/pkg/beam/io/spannerio/common.go index 04cc2154a604..743a70d2fcff 100644 --- a/sdks/go/pkg/beam/io/spannerio/common.go +++ b/sdks/go/pkg/beam/io/spannerio/common.go @@ -18,9 +18,10 @@ package spannerio import ( - "cloud.google.com/go/spanner" "context" "fmt" + + "cloud.google.com/go/spanner" "google.golang.org/api/option" "google.golang.org/api/option/internaloption" "google.golang.org/grpc" @@ -28,9 +29,9 @@ import ( ) type spannerFn struct { - Database string `json:"database"` // Database is the spanner connection string - endpoint string // Override spanner endpoint in tests - client *spanner.Client // Spanner Client + Database string `json:"database"` // Database is the spanner connection string + TestEndpoint string // Optional endpoint override for local testing. Not required for production pipelines. + client *spanner.Client // Spanner Client } func newSpannerFn(db string) spannerFn { @@ -48,9 +49,9 @@ func (f *spannerFn) Setup(ctx context.Context) error { var opts []option.ClientOption // Append emulator options assuming endpoint is local (for testing). - if f.endpoint != "" { + if f.TestEndpoint != "" { opts = []option.ClientOption{ - option.WithEndpoint(f.endpoint), + option.WithEndpoint(f.TestEndpoint), option.WithGRPCDialOption(grpc.WithTransportCredentials(insecure.NewCredentials())), option.WithoutAuthentication(), internaloption.SkipDialSettingsValidation(), diff --git a/sdks/go/pkg/beam/io/spannerio/read_test.go b/sdks/go/pkg/beam/io/spannerio/read_test.go index 1a7705b1aca2..7e1a65d0fe8a 100644 --- a/sdks/go/pkg/beam/io/spannerio/read_test.go +++ b/sdks/go/pkg/beam/io/spannerio/read_test.go @@ -27,6 +27,10 @@ import ( spannertest "github.com/apache/beam/sdks/v2/go/test/integration/io/spannerio" ) +func TestMain(m *testing.M) { + ptest.Main(m) +} + func TestRead(t *testing.T) { ctx := context.Background() @@ -102,7 +106,7 @@ func TestRead(t *testing.T) { p, s := beam.NewPipelineWithRoot() fn := newQueryFn(testCase.database, "SELECT * from "+testCase.table, reflect.TypeOf(TestDto{}), queryOptions{}) - fn.endpoint = srv.Addr + fn.TestEndpoint = srv.Addr imp := beam.Impulse(s) rows := beam.ParDo(s, fn, imp, beam.TypeDefinition{Var: beam.XType, T: reflect.TypeOf(TestDto{})}) diff --git a/sdks/go/pkg/beam/io/spannerio/write_test.go b/sdks/go/pkg/beam/io/spannerio/write_test.go index f273315ba119..3c2c1f591519 100644 --- a/sdks/go/pkg/beam/io/spannerio/write_test.go +++ b/sdks/go/pkg/beam/io/spannerio/write_test.go @@ -17,12 +17,12 @@ package spannerio import ( "context" - spannertest "github.com/apache/beam/sdks/v2/go/test/integration/io/spannerio" "testing" "cloud.google.com/go/spanner" "github.com/apache/beam/sdks/v2/go/pkg/beam" "github.com/apache/beam/sdks/v2/go/pkg/beam/testing/ptest" + spannertest "github.com/apache/beam/sdks/v2/go/test/integration/io/spannerio" "google.golang.org/api/iterator" ) @@ -77,7 +77,7 @@ func TestWrite(t *testing.T) { p, s, col := ptest.CreateList(testCase.rows) fn := newWriteFn(testCase.database, testCase.table, col.Type().Type()) - fn.endpoint = srv.Addr + fn.TestEndpoint = srv.Addr beam.ParDo0(s, fn, col)