Skip to content

Commit

Permalink
plugin register with artifact stubs VAULT-32686 (#29113)
Browse files Browse the repository at this point in the history
* add plugin catalog's entValidate() and setInternal() oss stubs 
* create plugin register command constructor oss stub
* create EntPluginRunner oss stub
* add validateSHA256() oss stub to validate plugin catalog update input
  • Loading branch information
thyton authored Jan 9, 2025
1 parent 80fe86a commit 4f14f7b
Show file tree
Hide file tree
Showing 8 changed files with 189 additions and 118 deletions.
4 changes: 1 addition & 3 deletions command/commands.go
Original file line number Diff line number Diff line change
Expand Up @@ -551,9 +551,7 @@ func initCommands(ui, serverCmdUi cli.Ui, runOpts *RunOptions) map[string]cli.Co
}, nil
},
"plugin register": func() (cli.Command, error) {
return &PluginRegisterCommand{
BaseCommand: getBaseCommand(),
}, nil
return NewPluginRegisterCommand(getBaseCommand()), nil
},
"plugin reload": func() (cli.Command, error) {
return &PluginReloadCommand{
Expand Down
14 changes: 14 additions & 0 deletions command/plugin_register_stubs_oss.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1

//go:build !enterprise

package command

import "github.com/hashicorp/cli"

func NewPluginRegisterCommand(baseCommand *BaseCommand) cli.Command {
return &PluginRegisterCommand{
BaseCommand: baseCommand,
}
}
2 changes: 2 additions & 0 deletions sdk/helper/pluginutil/runner.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@ const MultiplexingCtxKey string = "multiplex_id"
// PluginRunner defines the metadata needed to run a plugin securely with
// go-plugin.
type PluginRunner struct {
EntPluginRunner

Name string `json:"name" structs:"name"`
Type consts.PluginType `json:"type" structs:"type"`
Version string `json:"version" structs:"version"`
Expand Down
8 changes: 8 additions & 0 deletions sdk/helper/pluginutil/runner_stubs_oss.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0

//go:build !enterprise

package pluginutil

type EntPluginRunner struct{}
4 changes: 2 additions & 2 deletions vault/logical_system.go
Original file line number Diff line number Diff line change
Expand Up @@ -538,8 +538,8 @@ func (b *SystemBackend) handlePluginCatalogUpdate(ctx context.Context, _ *logica
sha256 := d.Get("sha256").(string)
if sha256 == "" {
sha256 = d.Get("sha_256").(string)
if sha256 == "" {
return logical.ErrorResponse("missing SHA-256 value"), nil
if resp := validateSHA256(sha256); resp.IsError() {
return resp, nil
}
}

Expand Down
17 changes: 17 additions & 0 deletions vault/logical_system_plugins_stubs_oss.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1

//go:build !enterprise

package vault

import (
"github.com/hashicorp/vault/sdk/logical"
)

func validateSHA256(sha256 string) *logical.Response {
if sha256 == "" {
return logical.ErrorResponse("missing SHA-256 value")
}
return nil
}
120 changes: 7 additions & 113 deletions vault/plugincatalog/plugin_catalog.go
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,13 @@ func SetupPluginCatalog(ctx context.Context, in *PluginCatalogInput) (*PluginCat
return nil, err
}

// Sanitize the plugin catalog
err = catalog.entValidate(ctx)
if err != nil {
logger.Error("error while sanitizing plugin storage", "error", err)
return nil, err
}

if legacy, _ := strconv.ParseBool(os.Getenv(pluginutil.PluginUseLegacyEnvLayering)); legacy {
conflicts := false
osKeys := envKeys(os.Environ())
Expand Down Expand Up @@ -959,119 +966,6 @@ func (c *PluginCatalog) Set(ctx context.Context, plugin pluginutil.SetPluginInpu
return err
}

func (c *PluginCatalog) setInternal(ctx context.Context, plugin pluginutil.SetPluginInput) (*pluginutil.PluginRunner, error) {
command := plugin.Command
if plugin.OCIImage == "" {
// Best effort check to make sure the command isn't breaking out of the
// configured plugin directory.
command = filepath.Join(c.directory, plugin.Command)
sym, err := filepath.EvalSymlinks(command)
if err != nil {
return nil, fmt.Errorf("error while validating the command path: %w", err)
}
symAbs, err := filepath.Abs(filepath.Dir(sym))
if err != nil {
return nil, fmt.Errorf("error while validating the command path: %w", err)
}

if symAbs != c.directory {
return nil, errors.New("cannot execute files outside of configured plugin directory")
}
}

// entryTmp should only be used for the below type and version checks. It uses the
// full command instead of the relative command because get() normally prepends
// the plugin directory to the command, but we can't use get() here.
entryTmp := &pluginutil.PluginRunner{
Name: plugin.Name,
Command: command,
OCIImage: plugin.OCIImage,
Runtime: plugin.Runtime,
Args: plugin.Args,
Env: plugin.Env,
Sha256: plugin.Sha256,
Builtin: false,
}
if entryTmp.OCIImage != "" && entryTmp.Runtime != "" {
var err error
entryTmp.RuntimeConfig, err = c.runtimeCatalog.Get(ctx, entryTmp.Runtime, consts.PluginRuntimeTypeContainer)
if err != nil {
return nil, fmt.Errorf("failed to get configured runtime for plugin %q: %w", plugin.Name, err)
}
}
// If the plugin type is unknown, we want to attempt to determine the type
if plugin.Type == consts.PluginTypeUnknown {
var err error
plugin.Type, err = c.getPluginTypeFromUnknown(ctx, entryTmp)
if err != nil {
return nil, err
}
if plugin.Type == consts.PluginTypeUnknown {
return nil, ErrPluginBadType
}
}

// getting the plugin version is best-effort, so errors are not fatal
runningVersion := logical.EmptyPluginVersion
var versionErr error
switch plugin.Type {
case consts.PluginTypeSecrets, consts.PluginTypeCredential:
runningVersion, versionErr = c.getBackendRunningVersion(ctx, entryTmp)
case consts.PluginTypeDatabase:
runningVersion, versionErr = c.getDatabaseRunningVersion(ctx, entryTmp)
default:
return nil, fmt.Errorf("unknown plugin type: %v", plugin.Type)
}
if versionErr != nil {
c.logger.Warn("Error determining plugin version", "error", versionErr)
if errors.Is(versionErr, ErrPluginUnableToRun) {
return nil, versionErr
}
} else if plugin.Version != "" && runningVersion.Version != "" && plugin.Version != runningVersion.Version {
c.logger.Error("Plugin self-reported version did not match requested version",
"plugin", plugin.Name, "requestedVersion", plugin.Version, "reportedVersion", runningVersion.Version)
return nil, fmt.Errorf("%w: %s reported version (%s) did not match requested version (%s)",
ErrPluginVersionMismatch, plugin.Name, runningVersion.Version, plugin.Version)
} else if plugin.Version == "" && runningVersion.Version != "" {
plugin.Version = runningVersion.Version
_, err := semver.NewVersion(plugin.Version)
if err != nil {
return nil, fmt.Errorf("plugin self-reported version %q is not a valid semantic version: %w", plugin.Version, err)
}
}

entry := &pluginutil.PluginRunner{
Name: plugin.Name,
Type: plugin.Type,
Version: plugin.Version,
Command: plugin.Command,
OCIImage: plugin.OCIImage,
Runtime: plugin.Runtime,
Args: plugin.Args,
Env: plugin.Env,
Sha256: plugin.Sha256,
Builtin: false,
}

buf, err := json.Marshal(entry)
if err != nil {
return nil, fmt.Errorf("failed to encode plugin entry: %w", err)
}

storageKey := path.Join(plugin.Type.String(), plugin.Name)
if plugin.Version != "" {
storageKey = path.Join(storageKey, plugin.Version)
}
logicalEntry := logical.StorageEntry{
Key: storageKey,
Value: buf,
}
if err := c.catalogView.Put(ctx, &logicalEntry); err != nil {
return nil, fmt.Errorf("failed to persist plugin entry: %w", err)
}
return entry, nil
}

// Delete is used to remove an external plugin from the catalog. Builtin plugins
// can not be deleted.
func (c *PluginCatalog) Delete(ctx context.Context, name string, pluginType consts.PluginType, pluginVersion string) error {
Expand Down
138 changes: 138 additions & 0 deletions vault/plugincatalog/plugin_catalog_stubs_oss.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1

//go:build !enterprise

package plugincatalog

import (
"context"
"encoding/json"
"errors"
"fmt"
"path"
"path/filepath"

semver "github.com/hashicorp/go-version"
"github.com/hashicorp/vault/sdk/helper/consts"
"github.com/hashicorp/vault/sdk/helper/pluginutil"
"github.com/hashicorp/vault/sdk/logical"
)

// setInternal creates a new plugin entry in the catalog and persists it to storage
func (c *PluginCatalog) setInternal(ctx context.Context, plugin pluginutil.SetPluginInput) (*pluginutil.PluginRunner, error) {
command := plugin.Command
if plugin.OCIImage == "" {
// Best effort check to make sure the command isn't breaking out of the
// configured plugin directory.
command = filepath.Join(c.directory, plugin.Command)
sym, err := filepath.EvalSymlinks(command)
if err != nil {
return nil, fmt.Errorf("error while validating the command path: %w", err)
}
symAbs, err := filepath.Abs(filepath.Dir(sym))
if err != nil {
return nil, fmt.Errorf("error while validating the command path: %w", err)
}

if symAbs != c.directory {
return nil, errors.New("cannot execute files outside of configured plugin directory")
}
}

// entryTmp should only be used for the below type and version checks. It uses the
// full command instead of the relative command because get() normally prepends
// the plugin directory to the command, but we can't use get() here.
entryTmp := &pluginutil.PluginRunner{
Name: plugin.Name,
Command: command,
OCIImage: plugin.OCIImage,
Runtime: plugin.Runtime,
Args: plugin.Args,
Env: plugin.Env,
Sha256: plugin.Sha256,
Builtin: false,
}
if entryTmp.OCIImage != "" && entryTmp.Runtime != "" {
var err error
entryTmp.RuntimeConfig, err = c.runtimeCatalog.Get(ctx, entryTmp.Runtime, consts.PluginRuntimeTypeContainer)
if err != nil {
return nil, fmt.Errorf("failed to get configured runtime for plugin %q: %w", plugin.Name, err)
}
}
// If the plugin type is unknown, we want to attempt to determine the type
if plugin.Type == consts.PluginTypeUnknown {
var err error
plugin.Type, err = c.getPluginTypeFromUnknown(ctx, entryTmp)
if err != nil {
return nil, err
}
if plugin.Type == consts.PluginTypeUnknown {
return nil, ErrPluginBadType
}
}

// getting the plugin version is best-effort, so errors are not fatal
runningVersion := logical.EmptyPluginVersion
var versionErr error
switch plugin.Type {
case consts.PluginTypeSecrets, consts.PluginTypeCredential:
runningVersion, versionErr = c.getBackendRunningVersion(ctx, entryTmp)
case consts.PluginTypeDatabase:
runningVersion, versionErr = c.getDatabaseRunningVersion(ctx, entryTmp)
default:
return nil, fmt.Errorf("unknown plugin type: %v", plugin.Type)
}
if versionErr != nil {
c.logger.Warn("Error determining plugin version", "error", versionErr)
if errors.Is(versionErr, ErrPluginUnableToRun) {
return nil, versionErr
}
} else if plugin.Version != "" && runningVersion.Version != "" && plugin.Version != runningVersion.Version {
c.logger.Error("Plugin self-reported version did not match requested version",
"plugin", plugin.Name, "requestedVersion", plugin.Version, "reportedVersion", runningVersion.Version)
return nil, fmt.Errorf("%w: %s reported version (%s) did not match requested version (%s)",
ErrPluginVersionMismatch, plugin.Name, runningVersion.Version, plugin.Version)
} else if plugin.Version == "" && runningVersion.Version != "" {
plugin.Version = runningVersion.Version
_, err := semver.NewVersion(plugin.Version)
if err != nil {
return nil, fmt.Errorf("plugin self-reported version %q is not a valid semantic version: %w", plugin.Version, err)
}
}

entry := &pluginutil.PluginRunner{
Name: plugin.Name,
Type: plugin.Type,
Version: plugin.Version,
Command: plugin.Command,
OCIImage: plugin.OCIImage,
Runtime: plugin.Runtime,
Args: plugin.Args,
Env: plugin.Env,
Sha256: plugin.Sha256,
Builtin: false,
}

buf, err := json.Marshal(entry)
if err != nil {
return nil, fmt.Errorf("failed to encode plugin entry: %w", err)
}

storageKey := path.Join(plugin.Type.String(), plugin.Name)
if plugin.Version != "" {
storageKey = path.Join(storageKey, plugin.Version)
}
logicalEntry := logical.StorageEntry{
Key: storageKey,
Value: buf,
}
if err := c.catalogView.Put(ctx, &logicalEntry); err != nil {
return nil, fmt.Errorf("failed to persist plugin entry: %w", err)
}
return entry, nil
}

func (c *PluginCatalog) entValidate(context.Context) error {
return nil
}

0 comments on commit 4f14f7b

Please sign in to comment.