Skip to content

Commit

Permalink
[prism] middle sized fixes, avro, datastore, spanner (#27588)
Browse files Browse the repository at this point in the history
Co-authored-by: lostluck <[email protected]>
  • Loading branch information
lostluck and lostluck authored Jul 21, 2023
1 parent 89cd238 commit b54bf52
Show file tree
Hide file tree
Showing 6 changed files with 63 additions and 26 deletions.
9 changes: 6 additions & 3 deletions sdks/go/pkg/beam/io/avroio/avroio.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,16 @@ import (
"github.com/apache/beam/sdks/v2/go/pkg/beam"
"github.com/apache/beam/sdks/v2/go/pkg/beam/io/filesystem"
"github.com/apache/beam/sdks/v2/go/pkg/beam/log"
"github.com/apache/beam/sdks/v2/go/pkg/beam/register"
"github.com/linkedin/goavro/v2"
)

func init() {
beam.RegisterFunction(expandFn)
beam.RegisterType(reflect.TypeOf((*avroReadFn)(nil)).Elem())
beam.RegisterType(reflect.TypeOf((*writeAvroFn)(nil)).Elem())
register.Function3x1(expandFn)
register.DoFn3x1[context.Context, string, func(beam.X), error]((*avroReadFn)(nil))
register.DoFn3x1[context.Context, int, func(*string) bool, error]((*writeAvroFn)(nil))
register.Emitter1[beam.X]()
register.Iter1[string]()
}

// Read reads a set of files and returns lines as a PCollection<elem>
Expand Down
33 changes: 23 additions & 10 deletions sdks/go/pkg/beam/io/avroio/avroio_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,30 @@ import (

"github.com/apache/beam/sdks/v2/go/pkg/beam"
_ "github.com/apache/beam/sdks/v2/go/pkg/beam/io/filesystem/local"
"github.com/apache/beam/sdks/v2/go/pkg/beam/register"
"github.com/apache/beam/sdks/v2/go/pkg/beam/testing/passert"
"github.com/apache/beam/sdks/v2/go/pkg/beam/testing/ptest"

"github.com/linkedin/goavro/v2"
)

func TestMain(m *testing.M) {
ptest.Main(m)
}

func init() {
beam.RegisterType(reflect.TypeOf((*Tweet)(nil)).Elem())
beam.RegisterType(reflect.TypeOf((*NullableFloat64)(nil)).Elem())
beam.RegisterType(reflect.TypeOf((*NullableString)(nil)).Elem())
beam.RegisterType(reflect.TypeOf((*NullableTweet)(nil)).Elem())
register.Function2x0(toJSONString)
}

func toJSONString(user TwitterUser, emit func(string)) {
b, _ := json.Marshal(user)
emit(string(b))
}

type Tweet struct {
Stamp int64 `json:"timestamp"`
Tweet string `json:"tweet"`
Expand Down Expand Up @@ -126,16 +144,11 @@ func TestWrite(t *testing.T) {
avroFile := "./user.avro"
testUsername := "user1"
testInfo := "userInfo"
p, s, sequence := ptest.CreateList([]string{testUsername})
format := beam.ParDo(s, func(username string, emit func(string)) {
newUser := TwitterUser{
User: username,
Info: testInfo,
}

b, _ := json.Marshal(newUser)
emit(string(b))
}, sequence)
p, s, sequence := ptest.CreateList([]TwitterUser{{
User: testUsername,
Info: testInfo,
}})
format := beam.ParDo(s, toJSONString, sequence)
Write(s, avroFile, userSchema, format)
t.Cleanup(func() {
os.Remove(avroFile)
Expand Down
24 changes: 20 additions & 4 deletions sdks/go/pkg/beam/io/datastoreio/datastore_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,17 @@ import (
"google.golang.org/api/option"
)

func TestMain(m *testing.M) {
// TODO(https://github.com/apache/beam/issues/27549): Make tests compatible with portable runners.
// To work on this change, replace call with `ptest.Main(m)`
ptest.MainWithDefault(m, "direct")
}

func init() {
beam.RegisterType(reflect.TypeOf((*Foo)(nil)).Elem())
beam.RegisterType(reflect.TypeOf((*Bar)(nil)).Elem())
}

// fake client type implements datastoreio.clientType
type fakeClient struct {
runCounter int
Expand Down Expand Up @@ -75,7 +86,7 @@ func Test_query(t *testing.T) {
}

itemType := reflect.TypeOf(tc.v)
itemKey := runtime.RegisterType(itemType)
itemKey, _ := runtime.TypeKey(itemType)

p, s := beam.NewPipelineWithRoot()
query(s, "project", "Item", tc.shard, itemType, itemKey, newClient)
Expand All @@ -93,7 +104,12 @@ func Test_query(t *testing.T) {
}
}

// Baz is intentionally unregistered.
type Baz struct {
}

func Test_query_Bad(t *testing.T) {
fooKey, _ := runtime.TypeKey(reflect.TypeOf(Foo{}))
testCases := []struct {
v any
itemType reflect.Type
Expand All @@ -103,8 +119,8 @@ func Test_query_Bad(t *testing.T) {
}{
// mismatch typeKey parameter
{
Foo{},
reflect.TypeOf(Foo{}),
Baz{},
reflect.TypeOf(Baz{}),
"MismatchType",
"No type registered MismatchType",
nil,
Expand All @@ -113,7 +129,7 @@ func Test_query_Bad(t *testing.T) {
{
Foo{},
reflect.TypeOf(Foo{}),
runtime.RegisterType(reflect.TypeOf(Foo{})),
fooKey,
"fake client error",
errors.New("fake client error"),
},
Expand Down
13 changes: 7 additions & 6 deletions sdks/go/pkg/beam/io/spannerio/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,19 +18,20 @@
package spannerio

import (
"cloud.google.com/go/spanner"
"context"
"fmt"

"cloud.google.com/go/spanner"
"google.golang.org/api/option"
"google.golang.org/api/option/internaloption"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
)

type spannerFn struct {
Database string `json:"database"` // Database is the spanner connection string
endpoint string // Override spanner endpoint in tests
client *spanner.Client // Spanner Client
Database string `json:"database"` // Database is the spanner connection string
TestEndpoint string // Optional endpoint override for local testing. Not required for production pipelines.
client *spanner.Client // Spanner Client
}

func newSpannerFn(db string) spannerFn {
Expand All @@ -48,9 +49,9 @@ func (f *spannerFn) Setup(ctx context.Context) error {
var opts []option.ClientOption

// Append emulator options assuming endpoint is local (for testing).
if f.endpoint != "" {
if f.TestEndpoint != "" {
opts = []option.ClientOption{
option.WithEndpoint(f.endpoint),
option.WithEndpoint(f.TestEndpoint),
option.WithGRPCDialOption(grpc.WithTransportCredentials(insecure.NewCredentials())),
option.WithoutAuthentication(),
internaloption.SkipDialSettingsValidation(),
Expand Down
6 changes: 5 additions & 1 deletion sdks/go/pkg/beam/io/spannerio/read_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,10 @@ import (
spannertest "github.com/apache/beam/sdks/v2/go/test/integration/io/spannerio"
)

func TestMain(m *testing.M) {
ptest.Main(m)
}

func TestRead(t *testing.T) {
ctx := context.Background()

Expand Down Expand Up @@ -102,7 +106,7 @@ func TestRead(t *testing.T) {

p, s := beam.NewPipelineWithRoot()
fn := newQueryFn(testCase.database, "SELECT * from "+testCase.table, reflect.TypeOf(TestDto{}), queryOptions{})
fn.endpoint = srv.Addr
fn.TestEndpoint = srv.Addr

imp := beam.Impulse(s)
rows := beam.ParDo(s, fn, imp, beam.TypeDefinition{Var: beam.XType, T: reflect.TypeOf(TestDto{})})
Expand Down
4 changes: 2 additions & 2 deletions sdks/go/pkg/beam/io/spannerio/write_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,12 @@ package spannerio

import (
"context"
spannertest "github.com/apache/beam/sdks/v2/go/test/integration/io/spannerio"
"testing"

"cloud.google.com/go/spanner"
"github.com/apache/beam/sdks/v2/go/pkg/beam"
"github.com/apache/beam/sdks/v2/go/pkg/beam/testing/ptest"
spannertest "github.com/apache/beam/sdks/v2/go/test/integration/io/spannerio"
"google.golang.org/api/iterator"
)

Expand Down Expand Up @@ -77,7 +77,7 @@ func TestWrite(t *testing.T) {
p, s, col := ptest.CreateList(testCase.rows)

fn := newWriteFn(testCase.database, testCase.table, col.Type().Type())
fn.endpoint = srv.Addr
fn.TestEndpoint = srv.Addr

beam.ParDo0(s, fn, col)

Expand Down

0 comments on commit b54bf52

Please sign in to comment.