diff --git a/daemon/api.go b/daemon/api.go index 642af9ece1f..f32791037c5 100644 --- a/daemon/api.go +++ b/daemon/api.go @@ -169,8 +169,8 @@ var ( assertstateRefreshSnapAssertions = assertstate.RefreshSnapAssertions assertstateRestoreValidationSetsTracking = assertstate.RestoreValidationSetsTracking - registrystateGetViaView = registrystate.GetViaView - registrystateSetViaView = registrystate.SetViaView + registrystateGet = registrystate.Get + registrystateSet = registrystate.Set ) func ensureStateSoonImpl(st *state.State) { diff --git a/daemon/api_registry.go b/daemon/api_registry.go index 11073d0f747..32a8682e1df 100644 --- a/daemon/api_registry.go +++ b/daemon/api_registry.go @@ -60,7 +60,7 @@ func getView(c *Command, r *http.Request, _ *auth.UserState) Response { fields = strutil.CommaSeparatedList(fieldStr) } - results, err := registrystateGetViaView(st, account, registryName, view, fields) + results, err := registrystateGet(st, account, registryName, view, fields) if err != nil { return toAPIError(err) } @@ -86,7 +86,7 @@ func setView(c *Command, r *http.Request, _ *auth.UserState) Response { return BadRequest("cannot decode registry request body: %v", err) } - err := registrystateSetViaView(st, account, registryName, view, values) + err := registrystateSet(st, account, registryName, view, values) if err != nil { return toAPIError(err) } diff --git a/daemon/api_registry_test.go b/daemon/api_registry_test.go index 772ff7020d6..fbc2cbd3707 100644 --- a/daemon/api_registry_test.go +++ b/daemon/api_registry_test.go @@ -89,7 +89,7 @@ func (s *registrySuite) TestGetView(c *C) { {name: "map", value: map[string]int{"foo": 123}}, } { cmt := Commentf("%s test", t.name) - restore := daemon.MockRegistrystateGetViaView(func(_ *state.State, acc, registry, view string, fields []string) (interface{}, error) { + restore := daemon.MockRegistrystateGet(func(_ *state.State, acc, registry, view string, fields []string) (interface{}, error) { c.Check(acc, Equals, "system", cmt) c.Check(registry, Equals, "network", cmt) c.Check(view, Equals, "wifi-setup", cmt) @@ -112,7 +112,7 @@ func (s *registrySuite) TestViewGetMany(c *C) { s.setFeatureFlag(c) var calls int - restore := daemon.MockRegistrystateGetViaView(func(_ *state.State, _, _, _ string, _ []string) (interface{}, error) { + restore := daemon.MockRegistrystateGet(func(_ *state.State, _, _, _ string, _ []string) (interface{}, error) { calls++ switch calls { case 1: @@ -137,7 +137,7 @@ func (s *registrySuite) TestViewGetSomeFieldNotFound(c *C) { s.setFeatureFlag(c) var calls int - restore := daemon.MockRegistrystateGetViaView(func(_ *state.State, acc, registry, view string, _ []string) (interface{}, error) { + restore := daemon.MockRegistrystateGet(func(_ *state.State, acc, registry, view string, _ []string) (interface{}, error) { calls++ switch calls { case 1: @@ -162,7 +162,7 @@ func (s *registrySuite) TestGetViewNoFieldsFound(c *C) { s.setFeatureFlag(c) var calls int - restore := daemon.MockRegistrystateGetViaView(func(_ *state.State, _, _, _ string, fields []string) (interface{}, error) { + restore := daemon.MockRegistrystateGet(func(_ *state.State, _, _, _ string, fields []string) (interface{}, error) { calls++ switch calls { case 1: @@ -193,7 +193,7 @@ func (s *registrySuite) TestGetViewNoFieldsFound(c *C) { func (s *registrySuite) TestViewGetDatabagNotFound(c *C) { s.setFeatureFlag(c) - restore := daemon.MockRegistrystateGetViaView(func(_ *state.State, _, _, _ string, _ []string) (interface{}, error) { + restore := daemon.MockRegistrystateGet(func(_ *state.State, _, _, _ string, _ []string) (interface{}, error) { return nil, ®istry.NotFoundError{Account: "foo", RegistryName: "network", View: "wifi-setup", Operation: "get", Requests: []string{"ssid"}, Cause: "mocked"} }) defer restore() @@ -242,7 +242,7 @@ func (s *registrySuite) testViewSetMany(c *C) { s.setFeatureFlag(c) var calls int - restore := daemon.MockRegistrystateSetViaView(func(st *state.State, account, registryName, viewName string, requests map[string]interface{}) error { + restore := daemon.MockRegistrystateSet(func(st *state.State, account, registryName, viewName string, requests map[string]interface{}) error { calls++ switch calls { case 1: @@ -307,7 +307,7 @@ func (s *registrySuite) TestGetViewError(c *C) { {name: "registry not found", err: ®istry.NotFoundError{}, code: 404}, {name: "internal", err: errors.New("internal"), code: 500}, } { - restore := daemon.MockRegistrystateGetViaView(func(_ *state.State, _, _, _ string, _ []string) (interface{}, error) { + restore := daemon.MockRegistrystateGet(func(_ *state.State, _, _, _ string, _ []string) (interface{}, error) { return nil, t.err }) @@ -324,7 +324,7 @@ func (s *registrySuite) TestGetViewMisshapenQuery(c *C) { s.setFeatureFlag(c) var calls int - restore := daemon.MockRegistrystateGetViaView(func(_ *state.State, _, _, _ string, fields []string) (interface{}, error) { + restore := daemon.MockRegistrystateGet(func(_ *state.State, _, _, _ string, fields []string) (interface{}, error) { calls++ switch calls { case 1: @@ -361,7 +361,7 @@ func (s *registrySuite) TestSetView(c *C) { {name: "map", value: map[string]interface{}{"foo": "bar"}}, } { cmt := Commentf("%s test", t.name) - restore := daemon.MockRegistrystateSetViaView(func(st *state.State, acc, registryName, view string, requests map[string]interface{}) error { + restore := daemon.MockRegistrystateSet(func(st *state.State, acc, registryName, view string, requests map[string]interface{}) error { c.Check(acc, Equals, "system", cmt) c.Check(registryName, Equals, "network", cmt) c.Check(view, Equals, "wifi-setup", cmt) @@ -412,7 +412,7 @@ func (s *registrySuite) TestSetView(c *C) { func (s *registrySuite) TestUnsetView(c *C) { s.setFeatureFlag(c) - restore := daemon.MockRegistrystateSetViaView(func(_ *state.State, acc, registryName, view string, requests map[string]interface{}) error { + restore := daemon.MockRegistrystateSet(func(_ *state.State, acc, registryName, view string, requests map[string]interface{}) error { c.Check(acc, Equals, "system") c.Check(registryName, Equals, "network") c.Check(view, Equals, "wifi-setup") @@ -452,7 +452,7 @@ func (s *registrySuite) TestSetViewError(c *C) { {name: "not found", err: ®istry.NotFoundError{}, code: 404}, {name: "internal", err: errors.New("internal"), code: 500}, } { - restore := daemon.MockRegistrystateSetViaView(func(*state.State, string, string, string, map[string]interface{}) error { + restore := daemon.MockRegistrystateSet(func(*state.State, string, string, string, map[string]interface{}) error { return t.err }) cmt := Commentf("%s test", t.name) @@ -471,7 +471,7 @@ func (s *registrySuite) TestSetViewError(c *C) { func (s *registrySuite) TestSetViewEmptyBody(c *C) { s.setFeatureFlag(c) - restore := daemon.MockRegistrystateSetViaView(func(*state.State, string, string, string, map[string]interface{}) error { + restore := daemon.MockRegistrystateSet(func(*state.State, string, string, string, map[string]interface{}) error { err := errors.New("unexpected call to registrystate.Set") c.Error(err) return err @@ -501,7 +501,7 @@ func (s *registrySuite) TestSetViewBadRequest(c *C) { func (s *registrySuite) TestGetBadRequest(c *C) { s.setFeatureFlag(c) - restore := daemon.MockRegistrystateGetViaView(func(_ *state.State, acc, registryName, view string, fields []string) (interface{}, error) { + restore := daemon.MockRegistrystateGet(func(_ *state.State, acc, registryName, view string, fields []string) (interface{}, error) { return nil, ®istry.BadRequestError{ Account: "acc", RegistryName: "reg", @@ -525,7 +525,7 @@ func (s *registrySuite) TestGetBadRequest(c *C) { func (s *registrySuite) TestSetBadRequest(c *C) { s.setFeatureFlag(c) - restore := daemon.MockRegistrystateSetViaView(func(*state.State, string, string, string, map[string]interface{}) error { + restore := daemon.MockRegistrystateSet(func(*state.State, string, string, string, map[string]interface{}) error { return ®istry.BadRequestError{ Account: "acc", RegistryName: "reg", @@ -549,7 +549,7 @@ func (s *registrySuite) TestSetBadRequest(c *C) { } func (s *registrySuite) TestSetFailUnsetFeatureFlag(c *C) { - restore := daemon.MockRegistrystateSetViaView(func(*state.State, string, string, string, map[string]interface{}) error { + restore := daemon.MockRegistrystateSet(func(*state.State, string, string, string, map[string]interface{}) error { err := fmt.Errorf("unexpected call to registrystate") c.Error(err) return err @@ -568,7 +568,7 @@ func (s *registrySuite) TestSetFailUnsetFeatureFlag(c *C) { } func (s *registrySuite) TestGetFailUnsetFeatureFlag(c *C) { - restore := daemon.MockRegistrystateSetViaView(func(*state.State, string, string, string, map[string]interface{}) error { + restore := daemon.MockRegistrystateSet(func(*state.State, string, string, string, map[string]interface{}) error { err := fmt.Errorf("unexpected call to registrystate") c.Error(err) return err @@ -588,7 +588,7 @@ func (s *registrySuite) TestGetNoFields(c *C) { s.setFeatureFlag(c) value := map[string]interface{}{"foo": 1, "bar": "baz", "nested": map[string]interface{}{"a": []interface{}{1, 2}}} - restore := daemon.MockRegistrystateGetViaView(func(_ *state.State, _, _, _ string, fields []string) (interface{}, error) { + restore := daemon.MockRegistrystateGet(func(_ *state.State, _, _, _ string, fields []string) (interface{}, error) { c.Check(fields, IsNil) return value, nil }) diff --git a/daemon/export_test.go b/daemon/export_test.go index b0735033bef..de65d3af093 100644 --- a/daemon/export_test.go +++ b/daemon/export_test.go @@ -379,19 +379,19 @@ var ( MaxReadBuflen = maxReadBuflen ) -func MockRegistrystateGetViaView(f func(_ *state.State, _, _, _ string, _ []string) (interface{}, error)) (restore func()) { - old := registrystateGetViaView - registrystateGetViaView = f +func MockRegistrystateGet(f func(_ *state.State, _, _, _ string, _ []string) (interface{}, error)) (restore func()) { + old := registrystateGet + registrystateGet = f return func() { - registrystateGetViaView = old + registrystateGet = old } } -func MockRegistrystateSetViaView(f func(_ *state.State, _, _, _ string, _ map[string]interface{}) error) (restore func()) { - old := registrystateSetViaView - registrystateSetViaView = f +func MockRegistrystateSet(f func(_ *state.State, _, _, _ string, _ map[string]interface{}) error) (restore func()) { + old := registrystateSet + registrystateSet = f return func() { - registrystateSetViaView = old + registrystateSet = old } } diff --git a/overlord/hookstate/ctlcmd/export_test.go b/overlord/hookstate/ctlcmd/export_test.go index f4fd11e5c31..3f006039a8a 100644 --- a/overlord/hookstate/ctlcmd/export_test.go +++ b/overlord/hookstate/ctlcmd/export_test.go @@ -28,9 +28,11 @@ import ( "github.com/snapcore/snapd/client/clientutil" "github.com/snapcore/snapd/overlord/devicestate" "github.com/snapcore/snapd/overlord/hookstate" + "github.com/snapcore/snapd/overlord/registrystate" "github.com/snapcore/snapd/overlord/servicestate" "github.com/snapcore/snapd/overlord/snapstate" "github.com/snapcore/snapd/overlord/state" + "github.com/snapcore/snapd/registry" "github.com/snapcore/snapd/snap" "github.com/snapcore/snapd/testutil" ) @@ -177,3 +179,11 @@ func MockNewStatusDecorator(f func(ctx context.Context, isGlobal bool, uid strin newStatusDecorator = f return restore } + +func MockRegistrystateRegistryTransaction(f func(*hookstate.Context, *registry.Registry) (*registrystate.Transaction, error)) (restore func()) { + old := registrystateRegistryTransaction + registrystateRegistryTransaction = f + return func() { + registrystateRegistryTransaction = old + } +} diff --git a/overlord/hookstate/ctlcmd/get.go b/overlord/hookstate/ctlcmd/get.go index b94388fad2a..53cf3c2b08d 100644 --- a/overlord/hookstate/ctlcmd/get.go +++ b/overlord/hookstate/ctlcmd/get.go @@ -46,6 +46,7 @@ type getCommand struct { ForceSlotSide bool `long:"slot" description:"return attribute values from the slot side of the connection"` ForcePlugSide bool `long:"plug" description:"return attribute values from the plug side of the connection"` View bool `long:"view" description:"return registry values from the view declared in the plug"` + Pristine bool `long:"pristine" description:"return registry values disregarding changes from the current transaction"` Positional struct { PlugOrSlotSpec string `positional-args:"true" positional-arg-name:":"` @@ -159,6 +160,9 @@ func (c *getCommand) Execute(args []string) error { if c.Typed && c.Document { return fmt.Errorf("cannot use -d and -t together") } + if c.Pristine && !c.View { + return fmt.Errorf("cannot use --pristine without --view") + } if strings.Contains(c.Positional.PlugOrSlotSpec, ":") { parts := strings.SplitN(c.Positional.PlugOrSlotSpec, ":", 2) @@ -176,7 +180,7 @@ func (c *getCommand) Execute(args []string) error { if c.View { requests := c.Positional.Keys - return c.getRegistryValues(context, name, requests) + return c.getRegistryValues(context, name, requests, c.Pristine) } return c.getInterfaceSetting(context, name) } @@ -357,7 +361,9 @@ func (c *getCommand) getInterfaceSetting(context *hookstate.Context, plugOrSlot }) } -func (c *getCommand) getRegistryValues(ctx *hookstate.Context, plugName string, requests []string) error { +var registrystateRegistryTransaction = registrystate.RegistryTransaction + +func (c *getCommand) getRegistryValues(ctx *hookstate.Context, plugName string, requests []string, pristine bool) error { if c.ForcePlugSide || c.ForceSlotSide { return errors.New(i18n.G("cannot use --plug or --slot with --view")) } @@ -369,12 +375,17 @@ func (c *getCommand) getRegistryValues(ctx *hookstate.Context, plugName string, return fmt.Errorf("cannot get registry: %v", err) } - tx, err := registrystate.RegistryTransaction(ctx, view.Registry()) + tx, err := registrystateRegistryTransaction(ctx, view.Registry()) if err != nil { return err } - res, err := registrystate.GetViaViewInTx(tx, view, requests) + bag := registry.DataBag(tx) + if pristine { + bag = tx.Pristine() + } + + res, err := registrystate.GetViaView(bag, view, requests) if err != nil { return err } diff --git a/overlord/hookstate/ctlcmd/get_test.go b/overlord/hookstate/ctlcmd/get_test.go index 60665df801a..4fca3371534 100644 --- a/overlord/hookstate/ctlcmd/get_test.go +++ b/overlord/hookstate/ctlcmd/get_test.go @@ -40,6 +40,7 @@ import ( "github.com/snapcore/snapd/overlord/ifacestate/ifacerepo" "github.com/snapcore/snapd/overlord/registrystate" "github.com/snapcore/snapd/overlord/state" + "github.com/snapcore/snapd/registry" "github.com/snapcore/snapd/snap" "github.com/snapcore/snapd/testutil" ) @@ -591,7 +592,7 @@ slots: func (s *registrySuite) TestRegistryGetSingleView(c *C) { s.state.Lock() - err := registrystate.SetViaView(s.state, s.devAccID, "network", "write-wifi", map[string]interface{}{ + err := registrystate.Set(s.state, s.devAccID, "network", "write-wifi", map[string]interface{}{ "ssid": "my-ssid", }) s.state.Unlock() @@ -605,7 +606,7 @@ func (s *registrySuite) TestRegistryGetSingleView(c *C) { func (s *registrySuite) TestRegistryGetManyViews(c *C) { s.state.Lock() - err := registrystate.SetViaView(s.state, s.devAccID, "network", "write-wifi", map[string]interface{}{ + err := registrystate.Set(s.state, s.devAccID, "network", "write-wifi", map[string]interface{}{ "ssid": "my-ssid", "password": "secret", }) @@ -624,7 +625,7 @@ func (s *registrySuite) TestRegistryGetManyViews(c *C) { func (s *registrySuite) TestRegistryGetNoRequest(c *C) { s.state.Lock() - err := registrystate.SetViaView(s.state, s.devAccID, "network", "write-wifi", map[string]interface{}{ + err := registrystate.Set(s.state, s.devAccID, "network", "write-wifi", map[string]interface{}{ "ssid": "my-ssid", "password": "secret", }) @@ -643,7 +644,7 @@ func (s *registrySuite) TestRegistryGetNoRequest(c *C) { func (s *registrySuite) TestRegistryGetHappensTransactionally(c *C) { s.state.Lock() - err := registrystate.SetViaView(s.state, s.devAccID, "network", "write-wifi", map[string]interface{}{ + err := registrystate.Set(s.state, s.devAccID, "network", "write-wifi", map[string]interface{}{ "ssid": "my-ssid", }) s.state.Unlock() @@ -659,7 +660,7 @@ func (s *registrySuite) TestRegistryGetHappensTransactionally(c *C) { c.Check(stderr, IsNil) s.state.Lock() - err = registrystate.SetViaView(s.state, s.devAccID, "network", "write-wifi", map[string]interface{}{ + err = registrystate.Set(s.state, s.devAccID, "network", "write-wifi", map[string]interface{}{ "ssid": "other-ssid", }) s.state.Unlock() @@ -847,3 +848,42 @@ func (s *registrySuite) TestRegistryGetAndSetViewNotFound(c *C) { c.Check(stdout, IsNil) c.Check(stderr, IsNil) } + +func (s *registrySuite) TestRegistryGetPristine(c *C) { + s.state.Lock() + defer s.state.Unlock() + + err := registrystate.Set(s.state, s.devAccID, "network", "write-wifi", map[string]interface{}{ + "ssid": "foo", + }) + c.Assert(err, IsNil) + + task := s.state.NewTask("run-hook", "") + setup := &hookstate.HookSetup{Snap: "test-snap", Hook: "save-view-plug"} + ctx, err := hookstate.NewContext(task, s.state, setup, s.mockHandler, "") + c.Assert(err, IsNil) + + tx, err := registrystate.NewTransaction(s.state, s.devAccID, "network") + c.Assert(err, IsNil) + + err = tx.Set("wifi.ssid", "bar") + c.Assert(err, IsNil) + + restore := ctlcmd.MockRegistrystateRegistryTransaction(func(*hookstate.Context, *registry.Registry) (*registrystate.Transaction, error) { + return tx, nil + }) + defer restore() + + s.state.Unlock() + defer s.state.Lock() + + stdout, stderr, err := ctlcmd.Run(ctx, []string{"get", "--view", "--pristine", ":read-wifi", "ssid"}, 0) + c.Assert(err, IsNil) + c.Check(string(stdout), Equals, "foo\n") + c.Check(stderr, IsNil) + + stdout, stderr, err = ctlcmd.Run(ctx, []string{"get", "--view", ":read-wifi", "ssid"}, 0) + c.Assert(err, IsNil) + c.Check(string(stdout), Equals, "bar\n") + c.Check(stderr, IsNil) +} diff --git a/overlord/hookstate/ctlcmd/set.go b/overlord/hookstate/ctlcmd/set.go index 47b7b69beb4..ac50dd36bf0 100644 --- a/overlord/hookstate/ctlcmd/set.go +++ b/overlord/hookstate/ctlcmd/set.go @@ -243,5 +243,5 @@ func setRegistryValues(ctx *hookstate.Context, plugName string, requests map[str // TODO: once we have hooks, check that we don't set values in the wrong hooks // (e.g., "registry-changed" hooks can only read data) - return registrystate.SetViaViewInTx(tx, view, requests) + return registrystate.SetViaView(tx, view, requests) } diff --git a/overlord/hookstate/ctlcmd/set_test.go b/overlord/hookstate/ctlcmd/set_test.go index e12a7a16deb..ae43b93eb97 100644 --- a/overlord/hookstate/ctlcmd/set_test.go +++ b/overlord/hookstate/ctlcmd/set_test.go @@ -414,7 +414,7 @@ func (s *registrySuite) TestRegistrySetSingleView(c *C) { s.mockContext.Unlock() s.state.Lock() - val, err := registrystate.GetViaView(s.state, s.devAccID, "network", "read-wifi", []string{"ssid"}) + val, err := registrystate.Get(s.state, s.devAccID, "network", "read-wifi", []string{"ssid"}) s.state.Unlock() c.Assert(err, IsNil) c.Assert(val, DeepEquals, map[string]interface{}{"ssid": "other-ssid"}) @@ -430,7 +430,7 @@ func (s *registrySuite) TestRegistrySetManyViews(c *C) { s.mockContext.Unlock() s.state.Lock() - val, err := registrystate.GetViaView(s.state, s.devAccID, "network", "read-wifi", []string{"ssid", "password"}) + val, err := registrystate.Get(s.state, s.devAccID, "network", "read-wifi", []string{"ssid", "password"}) s.state.Unlock() c.Assert(err, IsNil) c.Assert(val, DeepEquals, map[string]interface{}{ @@ -447,7 +447,7 @@ func (s *registrySuite) TestRegistrySetHappensTransactionally(c *C) { c.Check(stderr, IsNil) s.state.Lock() - _, err = registrystate.GetViaView(s.state, s.devAccID, "network", "read-wifi", []string{"ssid"}) + _, err = registrystate.Get(s.state, s.devAccID, "network", "read-wifi", []string{"ssid"}) s.state.Unlock() c.Assert(err, ErrorMatches, ".*matching rules don't map to any values") @@ -457,7 +457,7 @@ func (s *registrySuite) TestRegistrySetHappensTransactionally(c *C) { s.mockContext.Unlock() s.state.Lock() - val, err := registrystate.GetViaView(s.state, s.devAccID, "network", "read-wifi", []string{"ssid"}) + val, err := registrystate.Get(s.state, s.devAccID, "network", "read-wifi", []string{"ssid"}) s.state.Unlock() c.Assert(err, IsNil) c.Assert(val, DeepEquals, map[string]interface{}{ diff --git a/overlord/hookstate/ctlcmd/unset_test.go b/overlord/hookstate/ctlcmd/unset_test.go index 01e999bf2f7..afb0b9da57a 100644 --- a/overlord/hookstate/ctlcmd/unset_test.go +++ b/overlord/hookstate/ctlcmd/unset_test.go @@ -164,7 +164,7 @@ func (s *unsetSuite) TestCommandWithoutContext(c *C) { func (s *registrySuite) TestRegistryUnsetManyViews(c *C) { s.state.Lock() - err := registrystate.SetViaView(s.state, s.devAccID, "network", "write-wifi", map[string]interface{}{"ssid": "my-ssid", "password": "my-secret"}) + err := registrystate.Set(s.state, s.devAccID, "network", "write-wifi", map[string]interface{}{"ssid": "my-ssid", "password": "my-secret"}) s.state.Unlock() c.Assert(err, IsNil) @@ -177,14 +177,14 @@ func (s *registrySuite) TestRegistryUnsetManyViews(c *C) { s.mockContext.Unlock() s.state.Lock() - _, err = registrystate.GetViaView(s.state, s.devAccID, "network", "write-wifi", []string{"ssid", "password"}) + _, err = registrystate.Get(s.state, s.devAccID, "network", "write-wifi", []string{"ssid", "password"}) s.state.Unlock() c.Assert(err, ErrorMatches, `cannot get "ssid", "password" .*: matching rules don't map to any values`) } func (s *registrySuite) TestRegistryUnsetHappensTransactionally(c *C) { s.state.Lock() - err := registrystate.SetViaView(s.state, s.devAccID, "network", "write-wifi", map[string]interface{}{"ssid": "my-ssid"}) + err := registrystate.Set(s.state, s.devAccID, "network", "write-wifi", map[string]interface{}{"ssid": "my-ssid"}) s.state.Unlock() c.Assert(err, IsNil) @@ -194,7 +194,7 @@ func (s *registrySuite) TestRegistryUnsetHappensTransactionally(c *C) { c.Check(stderr, IsNil) s.state.Lock() - val, err := registrystate.GetViaView(s.state, s.devAccID, "network", "read-wifi", []string{"ssid"}) + val, err := registrystate.Get(s.state, s.devAccID, "network", "read-wifi", []string{"ssid"}) s.state.Unlock() c.Assert(err, IsNil) c.Assert(val, DeepEquals, map[string]interface{}{ @@ -207,7 +207,7 @@ func (s *registrySuite) TestRegistryUnsetHappensTransactionally(c *C) { s.mockContext.Unlock() s.state.Lock() - _, err = registrystate.GetViaView(s.state, s.devAccID, "network", "read-wifi", []string{"ssid"}) + _, err = registrystate.Get(s.state, s.devAccID, "network", "read-wifi", []string{"ssid"}) s.state.Unlock() c.Assert(err, ErrorMatches, `cannot get "ssid" .*: matching rules don't map to any values`) } diff --git a/overlord/registrystate/registrystate.go b/overlord/registrystate/registrystate.go index bb0ee4c2bf8..00e8653a056 100644 --- a/overlord/registrystate/registrystate.go +++ b/overlord/registrystate/registrystate.go @@ -34,9 +34,9 @@ import ( var assertstateRegistry = assertstate.Registry -// SetViaView finds the view identified by the account, registry and view names -// and sets the request fields to their respective values. -func SetViaView(st *state.State, account, registryName, viewName string, requests map[string]interface{}) error { +// Set finds the view identified by the account, registry and view names and +// sets the request fields to their respective values. +func Set(st *state.State, account, registryName, viewName string, requests map[string]interface{}) error { registryAssert, err := assertstateRegistry(st, account, registryName) if err != nil { return err @@ -68,21 +68,21 @@ func SetViaView(st *state.State, account, registryName, viewName string, request return err } - if err = SetViaViewInTx(tx, view, requests); err != nil { + if err := SetViaView(tx, view, requests); err != nil { return err } return tx.Commit(st, reg.Schema) } -// SetViaViewInTx uses the view to set the requests in the transaction's databag. -func SetViaViewInTx(tx *Transaction, view *registry.View, requests map[string]interface{}) error { +// SetViaView uses the view to set the requests in the transaction's databag. +func SetViaView(bag registry.DataBag, view *registry.View, requests map[string]interface{}) error { for field, value := range requests { var err error if value == nil { - err = view.Unset(tx, field) + err = view.Unset(bag, field) } else { - err = view.Set(tx, field, value) + err = view.Set(bag, field, value) } if err != nil { @@ -93,11 +93,11 @@ func SetViaViewInTx(tx *Transaction, view *registry.View, requests map[string]in return nil } -// GetViaView finds the view identified by the account, registry and view names -// and uses it to get the values for the specified fields. The results are -// returned in a map of fields to their values, unless there are no fields in -// which case all views are returned. -func GetViaView(st *state.State, account, registryName, viewName string, fields []string) (interface{}, error) { +// Get finds the view identified by the account, registry and view names and +// uses it to get the values for the specified fields. The results are returned +// in a map of fields to their values, unless there are no fields in which case +// case all views are returned. +func Get(st *state.State, account, registryName, viewName string, fields []string) (interface{}, error) { registryAssert, err := assertstateRegistry(st, account, registryName) if err != nil { return nil, err @@ -121,14 +121,14 @@ func GetViaView(st *state.State, account, registryName, viewName string, fields return nil, err } - return GetViaViewInTx(tx, view, fields) + return GetViaView(tx, view, fields) } -// GetViaViewInTx uses the view to get values for the fields from the databag -// in the transaction. -func GetViaViewInTx(tx *Transaction, view *registry.View, fields []string) (interface{}, error) { +// GetViaView uses the view to get values for the fields from the databag in +// the transaction. +func GetViaView(bag registry.DataBag, view *registry.View, fields []string) (interface{}, error) { if len(fields) == 0 { - val, err := view.Get(tx, "") + val, err := view.Get(bag, "") if err != nil { return nil, err } @@ -138,7 +138,7 @@ func GetViaViewInTx(tx *Transaction, view *registry.View, fields []string) (inte results := make(map[string]interface{}, len(fields)) for _, field := range fields { - value, err := view.Get(tx, field) + value, err := view.Get(bag, field) if err != nil { if errors.Is(err, ®istry.NotFoundError{}) && len(fields) > 1 { // keep looking; return partial result if only some fields are found @@ -153,8 +153,8 @@ func GetViaViewInTx(tx *Transaction, view *registry.View, fields []string) (inte if len(results) == 0 { return nil, ®istry.NotFoundError{ - Account: tx.RegistryAccount, - RegistryName: tx.RegistryName, + Account: view.Registry().Account, + RegistryName: view.Registry().Name, View: view.Name, Operation: "get", Requests: fields, diff --git a/overlord/registrystate/registrystate_test.go b/overlord/registrystate/registrystate_test.go index c56a68f662e..0334f0cb670 100644 --- a/overlord/registrystate/registrystate_test.go +++ b/overlord/registrystate/registrystate_test.go @@ -155,7 +155,7 @@ func (s *registryTestSuite) TestGetView(c *C) { c.Assert(err, IsNil) s.state.Set("registry-databags", map[string]map[string]registry.JSONDataBag{s.devAccID: {"network": databag}}) - res, err := registrystate.GetViaView(s.state, s.devAccID, "network", "setup-wifi", []string{"ssid"}) + res, err := registrystate.Get(s.state, s.devAccID, "network", "setup-wifi", []string{"ssid"}) c.Assert(err, IsNil) c.Assert(res, DeepEquals, map[string]interface{}{"ssid": "foo"}) } @@ -164,17 +164,17 @@ func (s *registryTestSuite) TestGetNotFound(c *C) { s.state.Lock() defer s.state.Unlock() - res, err := registrystate.GetViaView(s.state, s.devAccID, "network", "other-view", []string{"ssid"}) + res, err := registrystate.Get(s.state, s.devAccID, "network", "other-view", []string{"ssid"}) c.Assert(err, FitsTypeOf, ®istry.NotFoundError{}) c.Assert(err, ErrorMatches, fmt.Sprintf(`cannot get "ssid" in registry view %s/network/other-view: not found`, s.devAccID)) c.Check(res, IsNil) - res, err = registrystate.GetViaView(s.state, s.devAccID, "network", "setup-wifi", []string{"ssid"}) + res, err = registrystate.Get(s.state, s.devAccID, "network", "setup-wifi", []string{"ssid"}) c.Assert(err, FitsTypeOf, ®istry.NotFoundError{}) c.Assert(err, ErrorMatches, fmt.Sprintf(`cannot get "ssid" in registry view %s/network/setup-wifi: matching rules don't map to any values`, s.devAccID)) c.Check(res, IsNil) - res, err = registrystate.GetViaView(s.state, s.devAccID, "network", "setup-wifi", []string{"other-field"}) + res, err = registrystate.Get(s.state, s.devAccID, "network", "setup-wifi", []string{"other-field"}) c.Assert(err, FitsTypeOf, ®istry.NotFoundError{}) c.Assert(err, ErrorMatches, fmt.Sprintf(`cannot get "other-field" in registry view %s/network/setup-wifi: no matching read rule`, s.devAccID)) c.Check(res, IsNil) @@ -184,7 +184,7 @@ func (s *registryTestSuite) TestSetView(c *C) { s.state.Lock() defer s.state.Unlock() - err := registrystate.SetViaView(s.state, s.devAccID, "network", "setup-wifi", map[string]interface{}{"ssid": "foo"}) + err := registrystate.Set(s.state, s.devAccID, "network", "setup-wifi", map[string]interface{}{"ssid": "foo"}) c.Assert(err, IsNil) var databags map[string]map[string]registry.JSONDataBag @@ -200,11 +200,11 @@ func (s *registryTestSuite) TestSetNotFound(c *C) { s.state.Lock() defer s.state.Unlock() - err := registrystate.SetViaView(s.state, s.devAccID, "network", "setup-wifi", map[string]interface{}{"foo": "bar"}) + err := registrystate.Set(s.state, s.devAccID, "network", "setup-wifi", map[string]interface{}{"foo": "bar"}) c.Assert(err, FitsTypeOf, ®istry.NotFoundError{}) c.Assert(err, ErrorMatches, fmt.Sprintf(`cannot set "foo" in registry view %s/network/setup-wifi: no matching write rule`, s.devAccID)) - err = registrystate.SetViaView(s.state, s.devAccID, "network", "other-view", map[string]interface{}{"foo": "bar"}) + err = registrystate.Set(s.state, s.devAccID, "network", "other-view", map[string]interface{}{"foo": "bar"}) c.Assert(err, FitsTypeOf, ®istry.NotFoundError{}) c.Assert(err, ErrorMatches, fmt.Sprintf(`cannot set "foo" in registry view %s/network/other-view: not found`, s.devAccID)) } @@ -214,10 +214,10 @@ func (s *registryTestSuite) TestUnsetView(c *C) { defer s.state.Unlock() databag := registry.NewJSONDataBag() - err := registrystate.SetViaView(s.state, s.devAccID, "network", "setup-wifi", map[string]interface{}{"ssid": "foo"}) + err := registrystate.Set(s.state, s.devAccID, "network", "setup-wifi", map[string]interface{}{"ssid": "foo"}) c.Assert(err, IsNil) - err = registrystate.SetViaView(s.state, s.devAccID, "network", "setup-wifi", map[string]interface{}{"ssid": nil}) + err = registrystate.Set(s.state, s.devAccID, "network", "setup-wifi", map[string]interface{}{"ssid": nil}) c.Assert(err, IsNil) val, err := databag.Get("wifi.ssid") @@ -238,13 +238,13 @@ func (s *registryTestSuite) TestRegistrystateSetWithExistingState(c *C) { s.state.Set("registry-databags", databags) - results, err := registrystate.GetViaView(s.state, s.devAccID, "network", "setup-wifi", []string{"ssid"}) + results, err := registrystate.Get(s.state, s.devAccID, "network", "setup-wifi", []string{"ssid"}) c.Assert(err, IsNil) resultsMap, ok := results.(map[string]interface{}) c.Assert(ok, Equals, true) c.Assert(resultsMap["ssid"], Equals, "bar") - err = registrystate.SetViaView(s.state, s.devAccID, "network", "setup-wifi", map[string]interface{}{"ssid": "baz"}) + err = registrystate.Set(s.state, s.devAccID, "network", "setup-wifi", map[string]interface{}{"ssid": "baz"}) c.Assert(err, IsNil) err = s.state.Get("registry-databags", &databags) @@ -283,7 +283,7 @@ func (s *registryTestSuite) TestRegistrystateSetWithNoState(c *C) { for _, tc := range testcases { s.state.Set("registry-databags", tc.state) - err := registrystate.SetViaView(s.state, s.devAccID, "network", "setup-wifi", map[string]interface{}{"ssid": "bar"}) + err := registrystate.Set(s.state, s.devAccID, "network", "setup-wifi", map[string]interface{}{"ssid": "bar"}) c.Assert(err, IsNil) var databags map[string]map[string]registry.JSONDataBag @@ -300,7 +300,7 @@ func (s *registryTestSuite) TestRegistrystateGetEntireView(c *C) { s.state.Lock() defer s.state.Unlock() - err := registrystate.SetViaView(s.state, s.devAccID, "network", "setup-wifi", map[string]interface{}{ + err := registrystate.Set(s.state, s.devAccID, "network", "setup-wifi", map[string]interface{}{ "ssids": []interface{}{"foo", "bar"}, "password": "pass", "private": map[string]interface{}{ @@ -310,7 +310,7 @@ func (s *registryTestSuite) TestRegistrystateGetEntireView(c *C) { }) c.Assert(err, IsNil) - res, err := registrystate.GetViaView(s.state, s.devAccID, "network", "setup-wifi", nil) + res, err := registrystate.Get(s.state, s.devAccID, "network", "setup-wifi", nil) c.Assert(err, IsNil) c.Assert(res, DeepEquals, map[string]interface{}{ "ssids": []interface{}{"foo", "bar"}, diff --git a/overlord/registrystate/transaction.go b/overlord/registrystate/transaction.go index 4aff73b031f..865476e9915 100644 --- a/overlord/registrystate/transaction.go +++ b/overlord/registrystate/transaction.go @@ -296,3 +296,7 @@ func (t *Transaction) aborted() bool { func (t *Transaction) AbortInfo() (snap, reason string) { return t.abortingSnap, t.abortReason } + +func (t *Transaction) Pristine() registry.DataBag { + return t.pristine +} diff --git a/overlord/registrystate/transaction_test.go b/overlord/registrystate/transaction_test.go index ac9d04b65e5..8fbce07adc5 100644 --- a/overlord/registrystate/transaction_test.go +++ b/overlord/registrystate/transaction_test.go @@ -430,3 +430,31 @@ func (s *transactionTestSuite) TestAbortPreventsReadsAndWrites(c *C) { err = tx.Commit(s.state, registry.NewJSONSchema()) c.Assert(err, ErrorMatches, "cannot commit aborted transaction") } + +func (s *transactionTestSuite) TestTransactionPristine(c *C) { + bag := registry.NewJSONDataBag() + err := bag.Set("foo", "bar") + c.Assert(err, IsNil) + + err = registrystate.WriteDatabag(s.state, bag, "my-account", "my-reg") + c.Assert(err, IsNil) + + tx, err := registrystate.NewTransaction(s.state, "my-account", "my-reg") + c.Assert(err, IsNil) + + err = tx.Set("foo", "baz") + c.Assert(err, IsNil) + + checkPristine := func(key, expected string) { + pristineBag := tx.Pristine() + val, err := pristineBag.Get(key) + c.Assert(err, IsNil) + c.Check(val, Equals, expected) + } + checkPristine("foo", "bar") + + err = tx.Commit(s.state, registry.NewJSONSchema()) + c.Assert(err, IsNil) + + checkPristine("foo", "baz") +}