diff --git a/Makefile b/Makefile index 996613bee..40abbd343 100644 --- a/Makefile +++ b/Makefile @@ -60,7 +60,7 @@ static: build .PHONY: unit unit: - $(GO) test -coverprofile=coverage.out $(SPECIFIC_UNIT_TEST) $(TAGS) $(TEST_RACE) -count=1 ./pkg/... ./alpha/... + $(GO) test -coverprofile=coverage.out -coverpkg=./... $(SPECIFIC_UNIT_TEST) $(TAGS) $(TEST_RACE) -count=1 ./pkg/... ./alpha/... .PHONY: sanity-check sanity-check: diff --git a/cmd/opm/serve/serve.go b/cmd/opm/serve/serve.go index 6853c245b..202b571b3 100644 --- a/cmd/opm/serve/serve.go +++ b/cmd/opm/serve/serve.go @@ -6,23 +6,24 @@ import ( "errors" "fmt" "net" - "os" - "sync" - "net/http" endpoint "net/http/pprof" + "os" + "os/signal" "runtime/pprof" + "sync" + "syscall" + "time" "github.com/sirupsen/logrus" "github.com/spf13/cobra" + "golang.org/x/sync/errgroup" "google.golang.org/grpc" "google.golang.org/grpc/reflection" - "github.com/operator-framework/operator-registry/alpha/declcfg" "github.com/operator-framework/operator-registry/pkg/api" health "github.com/operator-framework/operator-registry/pkg/api/grpc_health_v1" "github.com/operator-framework/operator-registry/pkg/lib/dns" - "github.com/operator-framework/operator-registry/pkg/lib/graceful" "github.com/operator-framework/operator-registry/pkg/lib/log" "github.com/operator-framework/operator-registry/pkg/registry" "github.com/operator-framework/operator-registry/pkg/server" @@ -33,6 +34,7 @@ type serve struct { port string terminationLog string + serverTimeout time.Duration debug bool pprofAddr string @@ -75,16 +77,20 @@ will not be reflected in the served content. cmd.Flags().StringVarP(&s.terminationLog, "termination-log", "t", "/dev/termination-log", "path to a container termination log file") cmd.Flags().StringVarP(&s.port, "port", "p", "50051", "port number to serve on") cmd.Flags().StringVar(&s.pprofAddr, "pprof-addr", "", "address of startup profiling endpoint (addr:port format)") + cmd.Flags().DurationVar(&s.serverTimeout, "server-timeout", time.Second*30, "server-enforced timeout for grpc requests") return cmd } func (s *serve) run(ctx context.Context) error { p := newProfilerInterface(s.pprofAddr, s.logger) - p.startEndpoint() if err := p.startCpuProfileCache(); err != nil { return fmt.Errorf("could not start CPU profile: %v", err) } + ctx, cancel := signal.NotifyContext(ctx, os.Interrupt, syscall.SIGTERM) + defer cancel() + eg, ctx := errgroup.WithContext(ctx) + // Immediately set up termination log err := log.AddDefaultWriterHooks(s.terminationLog) if err != nil { @@ -98,40 +104,84 @@ func (s *serve) run(ctx context.Context) error { s.logger = s.logger.WithFields(logrus.Fields{"configs": s.configDir, "port": s.port}) - cfg, err := declcfg.LoadFS(os.DirFS(s.configDir)) - if err != nil { - return fmt.Errorf("load declarative config directory: %v", err) - } - - m, err := declcfg.ConvertToModel(*cfg) - if err != nil { - return fmt.Errorf("could not build index model from declarative config: %v", err) - } - store, err := registry.NewQuerier(m) - defer store.Close() + store, err := registry.NewQuerierFromFS(os.DirFS(s.configDir)) if err != nil { return err } + defer store.Close() lis, err := net.Listen("tcp", ":"+s.port) if err != nil { s.logger.Fatalf("failed to listen: %s", err) } - grpcServer := grpc.NewServer() + var grpcServerOpts []grpc.ServerOption + if s.serverTimeout > 0 { + grpcServerOpts = append(grpcServerOpts, grpc.StreamInterceptor(func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { + dss := &deadlinableServerStream{timeout: s.serverTimeout, ServerStream: ss} + defer dss.Cancel() + return handler(srv, dss) + }), grpc.UnaryInterceptor(func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) { + ctx, cancel := context.WithTimeout(ctx, s.serverTimeout) + defer cancel() + return handler(ctx, req) + })) + } + + grpcServer := grpc.NewServer(grpcServerOpts...) api.RegisterRegistryServer(grpcServer, server.NewRegistryServer(store)) health.RegisterHealthServer(grpcServer, server.NewHealthServer()) reflection.Register(grpcServer) - s.logger.Info("serving registry") - p.stopCpuProfileCache() - - return graceful.Shutdown(s.logger, func() error { - return grpcServer.Serve(lis) - }, func() { - grpcServer.GracefulStop() - p.stopEndpoint(p.logger.Context) + + eg.Go(func() error { + // All this channel stuff is necessary so that we can return from + // this function early when the context is cancelled. This is required + // to get `eg.Wait()` to unblock, so that we can proceed to gracefully + // shutting down. + errChan := make(chan error) + go func() { + s.logger.Info("serving registry") + errChan <- grpcServer.Serve(lis) + }() + select { + case err := <-errChan: + return err + case <-ctx.Done(): + return ctx.Err() + } + }) + eg.Go(func() error { + return p.listenAndServe(ctx) + }) + eg.Go(func() (err error) { + defer p.stopCpuProfileCache() + if err := store.Wait(ctx); err != nil { + return err + } + s.logger.Info("registry initialization complete") + return nil }) + // wait until both errgroup goroutines return and then + // return the first error that occurred (or nil) + err = eg.Wait() + + // stop the servers prior to handling the error returned + // from Wait(). + s.logger.Info("stopping grpc server") + grpcServer.GracefulStop() + if p.isEnabled() { + s.logger.Info("stopping http pprof server") + if err := p.shutdown(context.Background()); err != nil { + return err + } + } + + if !errors.Is(err, context.Canceled) { + return err + } + return nil + } // manages an HTTP pprof endpoint served by `server`, @@ -162,10 +212,10 @@ func (p *profilerInterface) isEnabled() bool { return p.addr != "" } -func (p *profilerInterface) startEndpoint() { +func (p *profilerInterface) listenAndServe(ctx context.Context) error { // short-circuit if not enabled if !p.isEnabled() { - return + return nil } mux := http.NewServeMux() @@ -181,14 +231,19 @@ func (p *profilerInterface) startEndpoint() { Handler: mux, } - // goroutine exits with main + errChan := make(chan error) go func() { - p.logger.Info("starting pprof endpoint") if err := p.server.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) { - p.logger.Fatal(err) + errChan <- err } }() + select { + case err := <-errChan: + return err + case <-ctx.Done(): + return ctx.Err() + } } func (p *profilerInterface) startCpuProfileCache() error { @@ -222,10 +277,8 @@ func (p *profilerInterface) httpHandler(w http.ResponseWriter, r *http.Request) w.Write(p.cache.Bytes()) } -func (p *profilerInterface) stopEndpoint(ctx context.Context) { - if err := p.server.Shutdown(ctx); err != nil { - p.logger.Fatal(err) - } +func (p *profilerInterface) shutdown(ctx context.Context) error { + return p.server.Shutdown(ctx) } func (p *profilerInterface) isCacheReady() bool { @@ -241,3 +294,34 @@ func (p *profilerInterface) setCacheReady() { p.cacheReady = true p.cacheLock.Unlock() } + +type deadlinableServerStream struct { + grpc.ServerStream + timeout time.Duration + + m sync.Mutex + cancelFunc func() + cancelled bool +} + +func (ss *deadlinableServerStream) Context() context.Context { + ss.m.Lock() + defer ss.m.Unlock() + if ss.cancelled { + ctx, cancel := context.WithCancel(ss.ServerStream.Context()) + cancel() + return ctx + } + ctx, cancel := context.WithTimeout(ss.ServerStream.Context(), ss.timeout) + ss.cancelFunc = cancel + return ctx +} + +func (ss *deadlinableServerStream) Cancel() { + ss.m.Lock() + defer ss.m.Unlock() + if ss.cancelFunc != nil { + ss.cancelFunc() + } + ss.cancelled = true +} diff --git a/pkg/lib/registry/registry_test.go b/pkg/lib/registry/registry_test.go index b92a088d9..cb67846cc 100644 --- a/pkg/lib/registry/registry_test.go +++ b/pkg/lib/registry/registry_test.go @@ -87,6 +87,8 @@ func newQuerier(t *testing.T, bundles []*model.Bundle) *registry.Querier { } reg, err := registry.NewQuerier(pkgs) require.NoError(t, err) + err = reg.Wait(context.Background()) + require.NoError(t, err) return reg } diff --git a/pkg/registry/query.go b/pkg/registry/query.go index 4ccf7eba4..b9413908c 100644 --- a/pkg/registry/query.go +++ b/pkg/registry/query.go @@ -4,10 +4,12 @@ import ( "context" "encoding/json" "fmt" + "io/fs" "os" "path/filepath" "sort" + "github.com/operator-framework/operator-registry/alpha/declcfg" "github.com/operator-framework/operator-registry/alpha/model" "github.com/operator-framework/operator-registry/pkg/api" ) @@ -17,6 +19,9 @@ type Querier struct { tmpDir string apiBundles map[apiBundleKey]string + + initDone chan struct{} + initErr error } func (q Querier) Close() error { @@ -39,30 +44,67 @@ func (s *SliceBundleSender) Send(b *api.Bundle) error { var _ GRPCQuery = &Querier{} +func NewQuerierFromFS(fbcFS fs.FS) (*Querier, error) { + return newQuerier(func() (model.Model, error) { + return fsToModel(fbcFS) + }) +} + func NewQuerier(packages model.Model) (*Querier, error) { - q := &Querier{} + return newQuerier(func() (model.Model, error) { + return packages, nil + }) +} +func newQuerier(getPackages func() (model.Model, error)) (*Querier, error) { + q := &Querier{} tmpDir, err := os.MkdirTemp("", "opm-registry-querier-") if err != nil { return nil, err } q.tmpDir = tmpDir + q.initDone = make(chan struct{}) + go func() { + defer close(q.initDone) + packages, err := getPackages() + if err != nil { + q.initErr = err + return + } + if err := q.storeAPIBundles(packages); err != nil { + q.initErr = err + return + } + q.pkgs = packages + }() + return q, nil +} + +func fsToModel(fbcFS fs.FS) (model.Model, error) { + cfg, err := declcfg.LoadFS(fbcFS) + if err != nil { + return nil, err + } + + return declcfg.ConvertToModel(*cfg) +} +func (q *Querier) storeAPIBundles(packages model.Model) error { q.apiBundles = map[apiBundleKey]string{} for _, pkg := range packages { for _, ch := range pkg.Channels { for _, b := range ch.Bundles { apiBundle, err := api.ConvertModelBundleToAPIBundle(*b) if err != nil { - return q, err + return err } jsonBundle, err := json.Marshal(apiBundle) if err != nil { - return q, err + return err } - filename := filepath.Join(tmpDir, fmt.Sprintf("%s_%s_%s.json", pkg.Name, ch.Name, b.Name)) + filename := filepath.Join(q.tmpDir, fmt.Sprintf("%s_%s_%s.json", pkg.Name, ch.Name, b.Name)) if err := os.WriteFile(filename, jsonBundle, 0666); err != nil { - return q, err + return err } q.apiBundles[apiBundleKey{pkg.Name, ch.Name, b.Name}] = filename packages[pkg.Name].Channels[ch.Name].Bundles[b.Name] = &model.Bundle{ @@ -75,11 +117,24 @@ func NewQuerier(packages model.Model) (*Querier, error) { } } } - q.pkgs = packages - return q, nil + return nil } -func (q Querier) loadAPIBundle(k apiBundleKey) (*api.Bundle, error) { +// Wait waits for the querier initialization to complete. +// If initialization results in an error or if the provided +// context is cancelled, Wait returns the associated error. +// A returned nil error indicates that initialization completed +// successfully. +func (q *Querier) Wait(ctx context.Context) error { + select { + case <-q.initDone: + return q.initErr + case <-ctx.Done(): + return ctx.Err() + } +} + +func (q *Querier) loadAPIBundle(k apiBundleKey) (*api.Bundle, error) { filename, ok := q.apiBundles[k] if !ok { return nil, fmt.Errorf("package %q, channel %q, bundle %q not found", k.pkgName, k.chName, k.name) @@ -95,7 +150,10 @@ func (q Querier) loadAPIBundle(k apiBundleKey) (*api.Bundle, error) { return &b, nil } -func (q Querier) ListPackages(_ context.Context) ([]string, error) { +func (q *Querier) ListPackages(ctx context.Context) ([]string, error) { + if err := q.Wait(ctx); err != nil { + return nil, err + } var packages []string for pkgName := range q.pkgs { packages = append(packages, pkgName) @@ -103,7 +161,10 @@ func (q Querier) ListPackages(_ context.Context) ([]string, error) { return packages, nil } -func (q Querier) ListBundles(ctx context.Context) ([]*api.Bundle, error) { +func (q *Querier) ListBundles(ctx context.Context) ([]*api.Bundle, error) { + if err := q.Wait(ctx); err != nil { + return nil, err + } var bundleSender SliceBundleSender err := q.SendBundles(ctx, &bundleSender) @@ -114,7 +175,10 @@ func (q Querier) ListBundles(ctx context.Context) ([]*api.Bundle, error) { return bundleSender, nil } -func (q Querier) SendBundles(_ context.Context, s BundleSender) error { +func (q *Querier) SendBundles(ctx context.Context, s BundleSender) error { + if err := q.Wait(ctx); err != nil { + return err + } for _, pkg := range q.pkgs { for _, ch := range pkg.Channels { for _, b := range ch.Bundles { @@ -139,7 +203,10 @@ func (q Querier) SendBundles(_ context.Context, s BundleSender) error { return nil } -func (q Querier) GetPackage(_ context.Context, name string) (*PackageManifest, error) { +func (q *Querier) GetPackage(ctx context.Context, name string) (*PackageManifest, error) { + if err := q.Wait(ctx); err != nil { + return nil, err + } pkg, ok := q.pkgs[name] if !ok { return nil, fmt.Errorf("package %q not found", name) @@ -163,7 +230,10 @@ func (q Querier) GetPackage(_ context.Context, name string) (*PackageManifest, e }, nil } -func (q Querier) GetBundle(_ context.Context, pkgName, channelName, csvName string) (*api.Bundle, error) { +func (q *Querier) GetBundle(ctx context.Context, pkgName, channelName, csvName string) (*api.Bundle, error) { + if err := q.Wait(ctx); err != nil { + return nil, err + } pkg, ok := q.pkgs[pkgName] if !ok { return nil, fmt.Errorf("package %q not found", pkgName) @@ -187,7 +257,10 @@ func (q Querier) GetBundle(_ context.Context, pkgName, channelName, csvName stri return apiBundle, nil } -func (q Querier) GetBundleForChannel(_ context.Context, pkgName string, channelName string) (*api.Bundle, error) { +func (q *Querier) GetBundleForChannel(ctx context.Context, pkgName string, channelName string) (*api.Bundle, error) { + if err := q.Wait(ctx); err != nil { + return nil, err + } pkg, ok := q.pkgs[pkgName] if !ok { return nil, fmt.Errorf("package %q not found", pkgName) @@ -211,7 +284,10 @@ func (q Querier) GetBundleForChannel(_ context.Context, pkgName string, channelN return apiBundle, nil } -func (q Querier) GetChannelEntriesThatReplace(_ context.Context, name string) ([]*ChannelEntry, error) { +func (q *Querier) GetChannelEntriesThatReplace(ctx context.Context, name string) ([]*ChannelEntry, error) { + if err := q.Wait(ctx); err != nil { + return nil, err + } var entries []*ChannelEntry for _, pkg := range q.pkgs { @@ -227,7 +303,10 @@ func (q Querier) GetChannelEntriesThatReplace(_ context.Context, name string) ([ return entries, nil } -func (q Querier) GetBundleThatReplaces(_ context.Context, name, pkgName, channelName string) (*api.Bundle, error) { +func (q *Querier) GetBundleThatReplaces(ctx context.Context, name, pkgName, channelName string) (*api.Bundle, error) { + if err := q.Wait(ctx); err != nil { + return nil, err + } pkg, ok := q.pkgs[pkgName] if !ok { return nil, fmt.Errorf("package %s not found", pkgName) @@ -257,7 +336,10 @@ func (q Querier) GetBundleThatReplaces(_ context.Context, name, pkgName, channel return nil, fmt.Errorf("no entry found for package %q, channel %q", pkgName, channelName) } -func (q Querier) GetChannelEntriesThatProvide(_ context.Context, group, version, kind string) ([]*ChannelEntry, error) { +func (q *Querier) GetChannelEntriesThatProvide(ctx context.Context, group, version, kind string) ([]*ChannelEntry, error) { + if err := q.Wait(ctx); err != nil { + return nil, err + } var entries []*ChannelEntry for _, pkg := range q.pkgs { @@ -292,7 +374,10 @@ func (q Querier) GetChannelEntriesThatProvide(_ context.Context, group, version, // --- // Separate, but possibly related, I noticed there are several channels in the channel entry // table who's minimum depth is 1. What causes 1 to be minimum depth in some cases and 0 in others? -func (q Querier) GetLatestChannelEntriesThatProvide(_ context.Context, group, version, kind string) ([]*ChannelEntry, error) { +func (q *Querier) GetLatestChannelEntriesThatProvide(ctx context.Context, group, version, kind string) ([]*ChannelEntry, error) { + if err := q.Wait(ctx); err != nil { + return nil, err + } var entries []*ChannelEntry for _, pkg := range q.pkgs { @@ -317,7 +402,10 @@ func (q Querier) GetLatestChannelEntriesThatProvide(_ context.Context, group, ve return entries, nil } -func (q Querier) GetBundleThatProvides(ctx context.Context, group, version, kind string) (*api.Bundle, error) { +func (q *Querier) GetBundleThatProvides(ctx context.Context, group, version, kind string) (*api.Bundle, error) { + if err := q.Wait(ctx); err != nil { + return nil, err + } latestEntries, err := q.GetLatestChannelEntriesThatProvide(ctx, group, version, kind) if err != nil { return nil, err @@ -344,7 +432,7 @@ func (q Querier) GetBundleThatProvides(ctx context.Context, group, version, kind return nil, fmt.Errorf("no entry found that provides group:%q version:%q kind:%q", group, version, kind) } -func (q Querier) doesModelBundleProvide(b model.Bundle, group, version, kind string) (bool, error) { +func (q *Querier) doesModelBundleProvide(b model.Bundle, group, version, kind string) (bool, error) { apiBundle, err := q.loadAPIBundle(apiBundleKey{b.Package.Name, b.Channel.Name, b.Name}) if err != nil { return false, fmt.Errorf("convert bundle %q: %v", b.Name, err) diff --git a/pkg/registry/query_test.go b/pkg/registry/query_test.go index 1ba1613f3..25d4f68b8 100644 --- a/pkg/registry/query_test.go +++ b/pkg/registry/query_test.go @@ -2,12 +2,15 @@ package registry import ( "context" + "errors" + "io/fs" "testing" "testing/fstest" "github.com/stretchr/testify/require" "github.com/operator-framework/operator-registry/alpha/declcfg" + "github.com/operator-framework/operator-registry/alpha/model" ) func TestQuerier_GetBundle(t *testing.T) { @@ -196,6 +199,72 @@ func TestQuerier_ListPackages(t *testing.T) { require.Equal(t, 2, len(packages)) } +func TestQuerier_NewQuerierFromFSErrors(t *testing.T) { + type testcase struct { + name string + fsys fs.FS + expectedErr string + } + + testcases := []testcase{ + { + name: "InvalidFBC", + fsys: invalidFS, + expectedErr: "invalid index:\n└── invalid package \"cockroachdb\":\n ├── default channel must be set\n └── package must contain at least one channel", + }, + { + name: "LoadFBCFailure", + fsys: notFBCFS, + expectedErr: "json: cannot unmarshal string into Go value of type declcfg.tmp", + }, + } + for _, tc := range testcases { + t.Run(tc.name, func(t *testing.T) { + q, err := NewQuerierFromFS(tc.fsys) + defer func() { + require.NoError(t, q.Close()) + }() + require.NotNil(t, q) + require.Nil(t, err) + + expectErr := func(v interface{}, err error) { + require.Nil(t, v) + require.EqualError(t, err, tc.expectedErr) + } + + expectErr(q.GetBundle(context.Background(), "", "", "")) + expectErr(q.GetBundleForChannel(context.Background(), "", "")) + expectErr(q.GetBundleThatProvides(context.Background(), "", "", "")) + expectErr(q.GetBundleThatReplaces(context.Background(), "", "", "")) + expectErr(q.GetChannelEntriesThatProvide(context.Background(), "", "", "")) + expectErr(q.GetChannelEntriesThatReplace(context.Background(), "")) + expectErr(q.GetLatestChannelEntriesThatProvide(context.Background(), "", "", "")) + expectErr(q.GetPackage(context.Background(), "")) + expectErr(q.ListBundles(context.Background())) + expectErr(q.ListPackages(context.Background())) + require.EqualError(t, q.Wait(context.Background()), tc.expectedErr) + }) + } +} + +func TestQuerier_WaitContext(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + block := make(chan struct{}) + q, err := newQuerier(func() (model.Model, error) { + <-block + return model.Model{}, nil + }) + require.NoError(t, err) + + waitErr := q.Wait(ctx) + close(block) + if !errors.Is(waitErr, context.Canceled) { + t.Fatalf("expect querier Wait() to return context.Canceled error, got %v", waitErr) + } +} + func genTestModelQuerier(t *testing.T) *Querier { t.Helper() @@ -208,9 +277,27 @@ func genTestModelQuerier(t *testing.T) *Querier { reg, err := NewQuerier(m) require.NoError(t, err) + err = reg.Wait(context.Background()) + require.NoError(t, err) + return reg } +var invalidFS = fstest.MapFS{ + "cockroachdb.json": &fstest.MapFile{ + Data: []byte(`{ + "schema": "olm.package", + "name": "cockroachdb", + "defaultChannel": "", + "icon": { + "base64data": "PHN2ZyB4bWxucz0iaHR0cDovL3d3dy53My5vcmcvMjAwMC9zdmciIHZpZXdCb3g9IjAgMCAzMS44MiAzMiIgd2lkdGg9IjI0ODYiIGhlaWdodD0iMjUwMCI+PHRpdGxlPkNMPC90aXRsZT48cGF0aCBkPSJNMTkuNDIgOS4xN2ExNS4zOSAxNS4zOSAwIDAgMS0zLjUxLjQgMTUuNDYgMTUuNDYgMCAwIDEtMy41MS0uNCAxNS42MyAxNS42MyAwIDAgMSAzLjUxLTMuOTEgMTUuNzEgMTUuNzEgMCAwIDEgMy41MSAzLjkxek0zMCAuNTdBMTcuMjIgMTcuMjIgMCAwIDAgMjUuNTkgMGExNy40IDE3LjQgMCAwIDAtOS42OCAyLjkzQTE3LjM4IDE3LjM4IDAgMCAwIDYuMjMgMGExNy4yMiAxNy4yMiAwIDAgMC00LjQ0LjU3QTE2LjIyIDE2LjIyIDAgMCAwIDAgMS4xM2EuMDcuMDcgMCAwIDAgMCAuMDkgMTcuMzIgMTcuMzIgMCAwIDAgLjgzIDEuNTcuMDcuMDcgMCAwIDAgLjA4IDAgMTYuMzkgMTYuMzkgMCAwIDEgMS44MS0uNTQgMTUuNjUgMTUuNjUgMCAwIDEgMTEuNTkgMS44OCAxNy41MiAxNy41MiAwIDAgMC0zLjc4IDQuNDhjLS4yLjMyLS4zNy42NS0uNTUgMXMtLjIyLjQ1LS4zMy42OS0uMzEuNzItLjQ0IDEuMDhhMTcuNDYgMTcuNDYgMCAwIDAgNC4yOSAxOC43Yy4yNi4yNS41My40OS44MS43M3MuNDQuMzcuNjcuNTQuNTkuNDQuODkuNjRhLjA3LjA3IDAgMCAwIC4wOCAwYy4zLS4yMS42LS40Mi44OS0uNjRzLjQ1LS4zNS42Ny0uNTQuNTUtLjQ4LjgxLS43M2ExNy40NSAxNy40NSAwIDAgMCA1LjM4LTEyLjYxIDE3LjM5IDE3LjM5IDAgMCAwLTEuMDktNi4wOWMtLjE0LS4zNy0uMjktLjczLS40NS0xLjA5cy0uMjItLjQ3LS4zMy0uNjktLjM1LS42Ni0uNTUtMWExNy42MSAxNy42MSAwIDAgMC0zLjc4LTQuNDggMTUuNjUgMTUuNjUgMCAwIDEgMTEuNi0xLjg0IDE2LjEzIDE2LjEzIDAgMCAxIDEuODEuNTQuMDcuMDcgMCAwIDAgLjA4IDBxLjQ0LS43Ni44Mi0xLjU2YS4wNy4wNyAwIDAgMCAwLS4wOUExNi44OSAxNi44OSAwIDAgMCAzMCAuNTd6IiBmaWxsPSIjMTUxZjM0Ii8+PHBhdGggZD0iTTIxLjgyIDE3LjQ3YTE1LjUxIDE1LjUxIDAgMCAxLTQuMjUgMTAuNjkgMTUuNjYgMTUuNjYgMCAwIDEtLjcyLTQuNjggMTUuNSAxNS41IDAgMCAxIDQuMjUtMTAuNjkgMTUuNjIgMTUuNjIgMCAwIDEgLjcyIDQuNjgiIGZpbGw9IiMzNDg1NDAiLz48cGF0aCBkPSJNMTUgMjMuNDhhMTUuNTUgMTUuNTUgMCAwIDEtLjcyIDQuNjggMTUuNTQgMTUuNTQgMCAwIDEtMy41My0xNS4zN0ExNS41IDE1LjUgMCAwIDEgMTUgMjMuNDgiIGZpbGw9IiM3ZGJjNDIiLz48L3N2Zz4=", + "mediatype": "image/svg+xml" + } +} +`)}} + +var notFBCFS = fstest.MapFS{"txtfile": &fstest.MapFile{Data: []byte(`not fbc format`)}} + var validFS = fstest.MapFS{ "cockroachdb.json": &fstest.MapFile{ Data: []byte(`{ diff --git a/pkg/server/server_test.go b/pkg/server/server_test.go index 09117ee84..90b1251a0 100644 --- a/pkg/server/server_test.go +++ b/pkg/server/server_test.go @@ -17,6 +17,8 @@ import ( "google.golang.org/grpc" "google.golang.org/grpc/connectivity" + "github.com/operator-framework/operator-registry/alpha/action" + "github.com/operator-framework/operator-registry/alpha/declcfg" "github.com/operator-framework/operator-registry/pkg/api" "github.com/operator-framework/operator-registry/pkg/registry" "github.com/operator-framework/operator-registry/pkg/sqlite" @@ -25,52 +27,54 @@ import ( const ( dbPort = ":50052" dbAddress = "localhost" + dbPort - dbName = "test.db" cfgPort = ":50053" cfgAddress = "localhost" + cfgPort -) -func dbStore(dbPath string) *sqlite.SQLQuerier { - _ = os.Remove(dbPath) + asyncCfgPort = ":50054" + asyncCfgAddress = "localhost" + asyncCfgPort +) +func dbStore(dbPath string) (*sqlite.SQLQuerier, error) { db, err := sqlite.Open(dbPath) if err != nil { - logrus.Fatal(err) + return nil, err } load, err := sqlite.NewSQLLiteLoader(db) if err != nil { - logrus.Fatal(err) + return nil, err } if err := load.Migrate(context.TODO()); err != nil { - logrus.Fatal(err) + return nil, err } loader := sqlite.NewSQLLoaderForDirectory(load, "../../manifests") if err := loader.Populate(); err != nil { - logrus.Fatal(err) + return nil, err } if err := db.Close(); err != nil { - logrus.Fatal(err) + return nil, err } store, err := sqlite.NewSQLLiteQuerier(dbPath) if err != nil { - logrus.Fatal(err) + return nil, err + } - return store + return store, nil } -func cfgStore() (*registry.Querier, error) { - tmpDir, err := ioutil.TempDir("", "server_test-") +func cfgStore(tmpDir string) (*registry.Querier, error) { + dbFile, err := os.CreateTemp(tmpDir, "cfgStore-*.db") + if err != nil { + return nil, err + } + dbFile.Close() + db, err := dbStore(dbFile.Name()) if err != nil { return nil, err } - defer os.RemoveAll(tmpDir) - - dbFile := filepath.Join(tmpDir, "test.db") - dbStore := dbStore(dbFile) - m, err := sqlite.ToModel(context.TODO(), dbStore) + m, err := sqlite.ToModel(context.TODO(), db) if err != nil { return nil, err } @@ -81,6 +85,29 @@ func cfgStore() (*registry.Querier, error) { return store, nil } +func asyncCfgStore(tmpDir string) (*registry.Querier, error) { + dbFile, err := os.CreateTemp(tmpDir, "asyncCfgStore-*.db") + if err != nil { + return nil, err + } + dbFile.Close() + if _, err := dbStore(dbFile.Name()); err != nil { + return nil, err + } + + fbcDir := filepath.Join(tmpDir, "catalog") + migrate := action.Migrate{CatalogRef: dbFile.Name(), OutputDir: fbcDir, WriteFunc: declcfg.WriteYAML, FileExt: ".yaml"} + if err := migrate.Run(context.Background()); err != nil { + return nil, err + } + + store, err := registry.NewQuerierFromFS(os.DirFS(fbcDir)) + if err != nil { + return nil, err + } + return store, nil +} + func server(store registry.GRPCQuery) *grpc.Server { s := grpc.NewServer() api.RegisterRegistryServer(s, NewRegistryServer(store)) @@ -88,14 +115,31 @@ func server(store registry.GRPCQuery) *grpc.Server { } func TestMain(m *testing.M) { - s1 := server(dbStore(dbName)) + tmpDir, err := ioutil.TempDir("", "server_test-") + if err != nil { + logrus.Fatalf("failed to create tmp dir: %v", err) + } + defer os.Remove(tmpDir) + + dbPath := filepath.Join(tmpDir, "sqlite-*.db") + dbQuerier, err := dbStore(dbPath) + if err != nil { + logrus.Fatalf("failed to create sqlite querier: %v", err) + } + s1 := server(dbQuerier) - cfgQuerier, err := cfgStore() - defer cfgQuerier.Close() + cfgQuerier, err := cfgStore(tmpDir) if err != nil { logrus.Fatalf("failed to create fbc querier: %v", err) } s2 := server(cfgQuerier) + + asyncCfgQuerier, err := asyncCfgStore(tmpDir) + if err != nil { + logrus.Fatalf("failed to create async fbc querier: %v", err) + } + s3 := server(asyncCfgQuerier) + go func() { lis, err := net.Listen("tcp", dbPort) if err != nil { @@ -114,9 +158,30 @@ func TestMain(m *testing.M) { logrus.Fatalf("failed to serve configs: %v", err) } }() + go func() { + lis, err := net.Listen("tcp", asyncCfgPort) + if err != nil { + logrus.Fatalf("failed to listen: %v", err) + } + if err := s3.Serve(lis); err != nil { + logrus.Fatalf("failed to serve configs: %v", err) + } + }() exit := m.Run() - if err := os.Remove(dbName); err != nil { - logrus.Fatalf("couldn't remove db") + s1.GracefulStop() + s2.GracefulStop() + s3.GracefulStop() + if err := cfgQuerier.Wait(context.Background()); err != nil { + logrus.Fatal(err) + } + if err := cfgQuerier.Close(); err != nil { + logrus.Fatal(err) + } + if err := asyncCfgQuerier.Wait(context.Background()); err != nil { + logrus.Fatal(err) + } + if err := asyncCfgQuerier.Close(); err != nil { + logrus.Fatal(err) } os.Exit(exit) } @@ -136,6 +201,7 @@ func client(t *testing.T, address string) (api.RegistryClient, *grpc.ClientConn) func TestListPackages(t *testing.T) { t.Run("Sqlite", testListPackages(dbAddress)) t.Run("DeclarativeConfig", testListPackages(cfgAddress)) + t.Run("AsyncDeclarativeConfig", testListPackages(asyncCfgAddress)) } func testListPackages(addr string) func(*testing.T) { @@ -168,6 +234,8 @@ func testListPackages(addr string) func(*testing.T) { func TestGetPackage(t *testing.T) { t.Run("Sqlite", testGetPackage(dbAddress)) t.Run("DeclarativeConfig", testGetPackage(cfgAddress)) + t.Run("AsyncDeclarativeConfig", testGetPackage(asyncCfgAddress)) + } func testGetPackage(addr string) func(*testing.T) { @@ -209,6 +277,7 @@ func testGetPackage(addr string) func(*testing.T) { func TestGetBundle(t *testing.T) { t.Run("Sqlite", testGetBundle(dbAddress, etcdoperator_v0_9_2("alpha", false, false))) t.Run("DeclarativeConfig", testGetBundle(cfgAddress, etcdoperator_v0_9_2("alpha", false, true))) + t.Run("AsyncDeclarativeConfig", testGetBundle(cfgAddress, etcdoperator_v0_9_2("alpha", false, true))) } func testGetBundle(addr string, expected *api.Bundle) func(*testing.T) { @@ -232,6 +301,7 @@ func TestGetBundleForChannel(t *testing.T) { })) } t.Run("DeclarativeConfig", testGetBundleForChannel(cfgAddress, etcdoperator_v0_9_2("alpha", false, true))) + t.Run("AsyncDeclarativeConfig", testGetBundleForChannel(asyncCfgAddress, etcdoperator_v0_9_2("alpha", false, true))) } func testGetBundleForChannel(addr string, expected *api.Bundle) func(*testing.T) { @@ -248,6 +318,7 @@ func testGetBundleForChannel(addr string, expected *api.Bundle) func(*testing.T) func TestGetChannelEntriesThatReplace(t *testing.T) { t.Run("Sqlite", testGetChannelEntriesThatReplace(dbAddress)) t.Run("DeclarativeConfig", testGetChannelEntriesThatReplace(cfgAddress)) + t.Run("AsyncDeclarativeConfig", testGetChannelEntriesThatReplace(asyncCfgAddress)) } func testGetChannelEntriesThatReplace(addr string) func(*testing.T) { @@ -324,6 +395,7 @@ func testGetChannelEntriesThatReplace(addr string) func(*testing.T) { func TestGetBundleThatReplaces(t *testing.T) { t.Run("Sqlite", testGetBundleThatReplaces(dbAddress, etcdoperator_v0_9_2("alpha", false, false))) t.Run("DeclarativeConfig", testGetBundleThatReplaces(cfgAddress, etcdoperator_v0_9_2("alpha", false, true))) + t.Run("AsyncDeclarativeConfig", testGetBundleThatReplaces(asyncCfgAddress, etcdoperator_v0_9_2("alpha", false, true))) } func testGetBundleThatReplaces(addr string, expected *api.Bundle) func(*testing.T) { @@ -340,6 +412,7 @@ func testGetBundleThatReplaces(addr string, expected *api.Bundle) func(*testing. func TestGetBundleThatReplacesSynthetic(t *testing.T) { t.Run("Sqlite", testGetBundleThatReplacesSynthetic(dbAddress, etcdoperator_v0_9_2("alpha", false, false))) t.Run("DeclarativeConfig", testGetBundleThatReplacesSynthetic(cfgAddress, etcdoperator_v0_9_2("alpha", false, true))) + t.Run("AsyncDeclarativeConfig", testGetBundleThatReplacesSynthetic(asyncCfgAddress, etcdoperator_v0_9_2("alpha", false, true))) } func testGetBundleThatReplacesSynthetic(addr string, expected *api.Bundle) func(*testing.T) { @@ -357,6 +430,7 @@ func testGetBundleThatReplacesSynthetic(addr string, expected *api.Bundle) func( func TestGetChannelEntriesThatProvide(t *testing.T) { t.Run("Sqlite", testGetChannelEntriesThatProvide(dbAddress)) t.Run("DeclarativeConfig", testGetChannelEntriesThatProvide(cfgAddress)) + t.Run("AsyncDeclarativeConfig", testGetChannelEntriesThatProvide(asyncCfgAddress)) } func testGetChannelEntriesThatProvide(addr string) func(t *testing.T) { @@ -474,6 +548,7 @@ func testGetChannelEntriesThatProvide(addr string) func(t *testing.T) { func TestGetLatestChannelEntriesThatProvide(t *testing.T) { t.Run("Sqlite", testGetLatestChannelEntriesThatProvide(dbAddress)) t.Run("DeclarativeConfig", testGetLatestChannelEntriesThatProvide(cfgAddress)) + t.Run("AsyncDeclarativeConfig", testGetLatestChannelEntriesThatProvide(asyncCfgAddress)) } func testGetLatestChannelEntriesThatProvide(addr string) func(t *testing.T) { @@ -550,6 +625,7 @@ func testGetLatestChannelEntriesThatProvide(addr string) func(t *testing.T) { func TestGetDefaultBundleThatProvides(t *testing.T) { t.Run("Sqlite", testGetDefaultBundleThatProvides(dbAddress, etcdoperator_v0_9_2("alpha", false, false))) t.Run("DeclarativeConfig", testGetDefaultBundleThatProvides(cfgAddress, etcdoperator_v0_9_2("alpha", false, true))) + t.Run("AsyncDeclarativeConfig", testGetDefaultBundleThatProvides(asyncCfgAddress, etcdoperator_v0_9_2("alpha", false, true))) } func testGetDefaultBundleThatProvides(addr string, expected *api.Bundle) func(*testing.T) { @@ -570,6 +646,9 @@ func TestListBundles(t *testing.T) { t.Run("DeclarativeConfig", testListBundles(cfgAddress, etcdoperator_v0_9_2("alpha", true, true), etcdoperator_v0_9_2("stable", true, true))) + t.Run("AsyncDeclarativeConfig", testListBundles(asyncCfgAddress, + etcdoperator_v0_9_2("alpha", true, true), + etcdoperator_v0_9_2("stable", true, true))) } func testListBundles(addr string, etcdAlpha *api.Bundle, etcdStable *api.Bundle) func(*testing.T) {