From 4f14f7bfec975fa7caf93c0b3dd716d12b1913cc Mon Sep 17 00:00:00 2001 From: Thy Ton Date: Thu, 9 Jan 2025 08:20:09 -0800 Subject: [PATCH] plugin register with artifact stubs VAULT-32686 (#29113) * add plugin catalog's entValidate() and setInternal() oss stubs * create plugin register command constructor oss stub * create EntPluginRunner oss stub * add validateSHA256() oss stub to validate plugin catalog update input --- command/commands.go | 4 +- command/plugin_register_stubs_oss.go | 14 ++ sdk/helper/pluginutil/runner.go | 2 + sdk/helper/pluginutil/runner_stubs_oss.go | 8 + vault/logical_system.go | 4 +- vault/logical_system_plugins_stubs_oss.go | 17 +++ vault/plugincatalog/plugin_catalog.go | 120 +-------------- .../plugincatalog/plugin_catalog_stubs_oss.go | 138 ++++++++++++++++++ 8 files changed, 189 insertions(+), 118 deletions(-) create mode 100644 command/plugin_register_stubs_oss.go create mode 100644 sdk/helper/pluginutil/runner_stubs_oss.go create mode 100644 vault/logical_system_plugins_stubs_oss.go create mode 100644 vault/plugincatalog/plugin_catalog_stubs_oss.go diff --git a/command/commands.go b/command/commands.go index daa40ead48a9..90287478ef85 100644 --- a/command/commands.go +++ b/command/commands.go @@ -551,9 +551,7 @@ func initCommands(ui, serverCmdUi cli.Ui, runOpts *RunOptions) map[string]cli.Co }, nil }, "plugin register": func() (cli.Command, error) { - return &PluginRegisterCommand{ - BaseCommand: getBaseCommand(), - }, nil + return NewPluginRegisterCommand(getBaseCommand()), nil }, "plugin reload": func() (cli.Command, error) { return &PluginReloadCommand{ diff --git a/command/plugin_register_stubs_oss.go b/command/plugin_register_stubs_oss.go new file mode 100644 index 000000000000..dbcfb4a44776 --- /dev/null +++ b/command/plugin_register_stubs_oss.go @@ -0,0 +1,14 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: BUSL-1.1 + +//go:build !enterprise + +package command + +import "github.com/hashicorp/cli" + +func NewPluginRegisterCommand(baseCommand *BaseCommand) cli.Command { + return &PluginRegisterCommand{ + BaseCommand: baseCommand, + } +} diff --git a/sdk/helper/pluginutil/runner.go b/sdk/helper/pluginutil/runner.go index ebbe110c3474..ecae61459b5c 100644 --- a/sdk/helper/pluginutil/runner.go +++ b/sdk/helper/pluginutil/runner.go @@ -57,6 +57,8 @@ const MultiplexingCtxKey string = "multiplex_id" // PluginRunner defines the metadata needed to run a plugin securely with // go-plugin. type PluginRunner struct { + EntPluginRunner + Name string `json:"name" structs:"name"` Type consts.PluginType `json:"type" structs:"type"` Version string `json:"version" structs:"version"` diff --git a/sdk/helper/pluginutil/runner_stubs_oss.go b/sdk/helper/pluginutil/runner_stubs_oss.go new file mode 100644 index 000000000000..b5d390a44d24 --- /dev/null +++ b/sdk/helper/pluginutil/runner_stubs_oss.go @@ -0,0 +1,8 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: MPL-2.0 + +//go:build !enterprise + +package pluginutil + +type EntPluginRunner struct{} diff --git a/vault/logical_system.go b/vault/logical_system.go index 49dfc0c6864d..1a00b10adb76 100644 --- a/vault/logical_system.go +++ b/vault/logical_system.go @@ -538,8 +538,8 @@ func (b *SystemBackend) handlePluginCatalogUpdate(ctx context.Context, _ *logica sha256 := d.Get("sha256").(string) if sha256 == "" { sha256 = d.Get("sha_256").(string) - if sha256 == "" { - return logical.ErrorResponse("missing SHA-256 value"), nil + if resp := validateSHA256(sha256); resp.IsError() { + return resp, nil } } diff --git a/vault/logical_system_plugins_stubs_oss.go b/vault/logical_system_plugins_stubs_oss.go new file mode 100644 index 000000000000..30ccdbc5e985 --- /dev/null +++ b/vault/logical_system_plugins_stubs_oss.go @@ -0,0 +1,17 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: BUSL-1.1 + +//go:build !enterprise + +package vault + +import ( + "github.com/hashicorp/vault/sdk/logical" +) + +func validateSHA256(sha256 string) *logical.Response { + if sha256 == "" { + return logical.ErrorResponse("missing SHA-256 value") + } + return nil +} diff --git a/vault/plugincatalog/plugin_catalog.go b/vault/plugincatalog/plugin_catalog.go index b2b61f54bb7a..d38cf167611f 100644 --- a/vault/plugincatalog/plugin_catalog.go +++ b/vault/plugincatalog/plugin_catalog.go @@ -173,6 +173,13 @@ func SetupPluginCatalog(ctx context.Context, in *PluginCatalogInput) (*PluginCat return nil, err } + // Sanitize the plugin catalog + err = catalog.entValidate(ctx) + if err != nil { + logger.Error("error while sanitizing plugin storage", "error", err) + return nil, err + } + if legacy, _ := strconv.ParseBool(os.Getenv(pluginutil.PluginUseLegacyEnvLayering)); legacy { conflicts := false osKeys := envKeys(os.Environ()) @@ -959,119 +966,6 @@ func (c *PluginCatalog) Set(ctx context.Context, plugin pluginutil.SetPluginInpu return err } -func (c *PluginCatalog) setInternal(ctx context.Context, plugin pluginutil.SetPluginInput) (*pluginutil.PluginRunner, error) { - command := plugin.Command - if plugin.OCIImage == "" { - // Best effort check to make sure the command isn't breaking out of the - // configured plugin directory. - command = filepath.Join(c.directory, plugin.Command) - sym, err := filepath.EvalSymlinks(command) - if err != nil { - return nil, fmt.Errorf("error while validating the command path: %w", err) - } - symAbs, err := filepath.Abs(filepath.Dir(sym)) - if err != nil { - return nil, fmt.Errorf("error while validating the command path: %w", err) - } - - if symAbs != c.directory { - return nil, errors.New("cannot execute files outside of configured plugin directory") - } - } - - // entryTmp should only be used for the below type and version checks. It uses the - // full command instead of the relative command because get() normally prepends - // the plugin directory to the command, but we can't use get() here. - entryTmp := &pluginutil.PluginRunner{ - Name: plugin.Name, - Command: command, - OCIImage: plugin.OCIImage, - Runtime: plugin.Runtime, - Args: plugin.Args, - Env: plugin.Env, - Sha256: plugin.Sha256, - Builtin: false, - } - if entryTmp.OCIImage != "" && entryTmp.Runtime != "" { - var err error - entryTmp.RuntimeConfig, err = c.runtimeCatalog.Get(ctx, entryTmp.Runtime, consts.PluginRuntimeTypeContainer) - if err != nil { - return nil, fmt.Errorf("failed to get configured runtime for plugin %q: %w", plugin.Name, err) - } - } - // If the plugin type is unknown, we want to attempt to determine the type - if plugin.Type == consts.PluginTypeUnknown { - var err error - plugin.Type, err = c.getPluginTypeFromUnknown(ctx, entryTmp) - if err != nil { - return nil, err - } - if plugin.Type == consts.PluginTypeUnknown { - return nil, ErrPluginBadType - } - } - - // getting the plugin version is best-effort, so errors are not fatal - runningVersion := logical.EmptyPluginVersion - var versionErr error - switch plugin.Type { - case consts.PluginTypeSecrets, consts.PluginTypeCredential: - runningVersion, versionErr = c.getBackendRunningVersion(ctx, entryTmp) - case consts.PluginTypeDatabase: - runningVersion, versionErr = c.getDatabaseRunningVersion(ctx, entryTmp) - default: - return nil, fmt.Errorf("unknown plugin type: %v", plugin.Type) - } - if versionErr != nil { - c.logger.Warn("Error determining plugin version", "error", versionErr) - if errors.Is(versionErr, ErrPluginUnableToRun) { - return nil, versionErr - } - } else if plugin.Version != "" && runningVersion.Version != "" && plugin.Version != runningVersion.Version { - c.logger.Error("Plugin self-reported version did not match requested version", - "plugin", plugin.Name, "requestedVersion", plugin.Version, "reportedVersion", runningVersion.Version) - return nil, fmt.Errorf("%w: %s reported version (%s) did not match requested version (%s)", - ErrPluginVersionMismatch, plugin.Name, runningVersion.Version, plugin.Version) - } else if plugin.Version == "" && runningVersion.Version != "" { - plugin.Version = runningVersion.Version - _, err := semver.NewVersion(plugin.Version) - if err != nil { - return nil, fmt.Errorf("plugin self-reported version %q is not a valid semantic version: %w", plugin.Version, err) - } - } - - entry := &pluginutil.PluginRunner{ - Name: plugin.Name, - Type: plugin.Type, - Version: plugin.Version, - Command: plugin.Command, - OCIImage: plugin.OCIImage, - Runtime: plugin.Runtime, - Args: plugin.Args, - Env: plugin.Env, - Sha256: plugin.Sha256, - Builtin: false, - } - - buf, err := json.Marshal(entry) - if err != nil { - return nil, fmt.Errorf("failed to encode plugin entry: %w", err) - } - - storageKey := path.Join(plugin.Type.String(), plugin.Name) - if plugin.Version != "" { - storageKey = path.Join(storageKey, plugin.Version) - } - logicalEntry := logical.StorageEntry{ - Key: storageKey, - Value: buf, - } - if err := c.catalogView.Put(ctx, &logicalEntry); err != nil { - return nil, fmt.Errorf("failed to persist plugin entry: %w", err) - } - return entry, nil -} - // Delete is used to remove an external plugin from the catalog. Builtin plugins // can not be deleted. func (c *PluginCatalog) Delete(ctx context.Context, name string, pluginType consts.PluginType, pluginVersion string) error { diff --git a/vault/plugincatalog/plugin_catalog_stubs_oss.go b/vault/plugincatalog/plugin_catalog_stubs_oss.go new file mode 100644 index 000000000000..12673393f7e5 --- /dev/null +++ b/vault/plugincatalog/plugin_catalog_stubs_oss.go @@ -0,0 +1,138 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: BUSL-1.1 + +//go:build !enterprise + +package plugincatalog + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "path" + "path/filepath" + + semver "github.com/hashicorp/go-version" + "github.com/hashicorp/vault/sdk/helper/consts" + "github.com/hashicorp/vault/sdk/helper/pluginutil" + "github.com/hashicorp/vault/sdk/logical" +) + +// setInternal creates a new plugin entry in the catalog and persists it to storage +func (c *PluginCatalog) setInternal(ctx context.Context, plugin pluginutil.SetPluginInput) (*pluginutil.PluginRunner, error) { + command := plugin.Command + if plugin.OCIImage == "" { + // Best effort check to make sure the command isn't breaking out of the + // configured plugin directory. + command = filepath.Join(c.directory, plugin.Command) + sym, err := filepath.EvalSymlinks(command) + if err != nil { + return nil, fmt.Errorf("error while validating the command path: %w", err) + } + symAbs, err := filepath.Abs(filepath.Dir(sym)) + if err != nil { + return nil, fmt.Errorf("error while validating the command path: %w", err) + } + + if symAbs != c.directory { + return nil, errors.New("cannot execute files outside of configured plugin directory") + } + } + + // entryTmp should only be used for the below type and version checks. It uses the + // full command instead of the relative command because get() normally prepends + // the plugin directory to the command, but we can't use get() here. + entryTmp := &pluginutil.PluginRunner{ + Name: plugin.Name, + Command: command, + OCIImage: plugin.OCIImage, + Runtime: plugin.Runtime, + Args: plugin.Args, + Env: plugin.Env, + Sha256: plugin.Sha256, + Builtin: false, + } + if entryTmp.OCIImage != "" && entryTmp.Runtime != "" { + var err error + entryTmp.RuntimeConfig, err = c.runtimeCatalog.Get(ctx, entryTmp.Runtime, consts.PluginRuntimeTypeContainer) + if err != nil { + return nil, fmt.Errorf("failed to get configured runtime for plugin %q: %w", plugin.Name, err) + } + } + // If the plugin type is unknown, we want to attempt to determine the type + if plugin.Type == consts.PluginTypeUnknown { + var err error + plugin.Type, err = c.getPluginTypeFromUnknown(ctx, entryTmp) + if err != nil { + return nil, err + } + if plugin.Type == consts.PluginTypeUnknown { + return nil, ErrPluginBadType + } + } + + // getting the plugin version is best-effort, so errors are not fatal + runningVersion := logical.EmptyPluginVersion + var versionErr error + switch plugin.Type { + case consts.PluginTypeSecrets, consts.PluginTypeCredential: + runningVersion, versionErr = c.getBackendRunningVersion(ctx, entryTmp) + case consts.PluginTypeDatabase: + runningVersion, versionErr = c.getDatabaseRunningVersion(ctx, entryTmp) + default: + return nil, fmt.Errorf("unknown plugin type: %v", plugin.Type) + } + if versionErr != nil { + c.logger.Warn("Error determining plugin version", "error", versionErr) + if errors.Is(versionErr, ErrPluginUnableToRun) { + return nil, versionErr + } + } else if plugin.Version != "" && runningVersion.Version != "" && plugin.Version != runningVersion.Version { + c.logger.Error("Plugin self-reported version did not match requested version", + "plugin", plugin.Name, "requestedVersion", plugin.Version, "reportedVersion", runningVersion.Version) + return nil, fmt.Errorf("%w: %s reported version (%s) did not match requested version (%s)", + ErrPluginVersionMismatch, plugin.Name, runningVersion.Version, plugin.Version) + } else if plugin.Version == "" && runningVersion.Version != "" { + plugin.Version = runningVersion.Version + _, err := semver.NewVersion(plugin.Version) + if err != nil { + return nil, fmt.Errorf("plugin self-reported version %q is not a valid semantic version: %w", plugin.Version, err) + } + } + + entry := &pluginutil.PluginRunner{ + Name: plugin.Name, + Type: plugin.Type, + Version: plugin.Version, + Command: plugin.Command, + OCIImage: plugin.OCIImage, + Runtime: plugin.Runtime, + Args: plugin.Args, + Env: plugin.Env, + Sha256: plugin.Sha256, + Builtin: false, + } + + buf, err := json.Marshal(entry) + if err != nil { + return nil, fmt.Errorf("failed to encode plugin entry: %w", err) + } + + storageKey := path.Join(plugin.Type.String(), plugin.Name) + if plugin.Version != "" { + storageKey = path.Join(storageKey, plugin.Version) + } + logicalEntry := logical.StorageEntry{ + Key: storageKey, + Value: buf, + } + if err := c.catalogView.Put(ctx, &logicalEntry); err != nil { + return nil, fmt.Errorf("failed to persist plugin entry: %w", err) + } + return entry, nil +} + +func (c *PluginCatalog) entValidate(context.Context) error { + return nil +}