From 2f6e0c5345670433b1a6b360ac6abf88c47243e5 Mon Sep 17 00:00:00 2001 From: Tom Proctor Date: Fri, 22 Sep 2023 00:26:13 +0100 Subject: [PATCH] Add TempDir option to UnixSocketConfig (#282) * Add TempDir option to UnixSocketConfig. Allows clients to specify a folder where plugin-specific Unix socket directories should be created. Still defaults to $TMPDIR (if set) or /tmp. * Improve UnixSocketConfig field names and comments * Document exported Unix socket environment variables --- client.go | 33 ++++++++++++++++----------------- client_unix_test.go | 12 +++++++++++- constants.go | 7 ++++++- server.go | 9 ++++++++- 4 files changed, 41 insertions(+), 20 deletions(-) diff --git a/client.go b/client.go index b6024afc..cd89985a 100644 --- a/client.go +++ b/client.go @@ -252,16 +252,15 @@ type UnixSocketConfig struct { // client process must be a member of this group or chown will fail. Group string - // The directory to create Unix sockets in. Internally managed by go-plugin - // and deleted when the plugin is killed. - directory string -} - -func unixSocketConfigFromEnv() UnixSocketConfig { - return UnixSocketConfig{ - Group: os.Getenv(EnvUnixSocketGroup), - directory: os.Getenv(EnvUnixSocketDir), - } + // TempDir specifies the base directory to use when creating a plugin-specific + // temporary directory. It is expected to already exist and be writable. If + // not set, defaults to the directory chosen by os.MkdirTemp. + TempDir string + + // The directory to create Unix sockets in. Internally created and managed + // by go-plugin and deleted when the plugin is killed. Will be created + // inside TempDir if specified. + socketDir string } // ReattachConfig is used to configure a client to reattach to an @@ -467,7 +466,7 @@ func (c *Client) Kill() { c.l.Lock() runner := c.runner addr := c.address - hostSocketDir := c.unixSocketCfg.directory + hostSocketDir := c.unixSocketCfg.socketDir c.l.Unlock() // If there is no runner or ID, there is nothing to kill. @@ -652,7 +651,7 @@ func (c *Client) Start() (addr net.Addr, err error) { } if c.config.UnixSocketConfig != nil { - c.unixSocketCfg.Group = c.config.UnixSocketConfig.Group + c.unixSocketCfg = *c.config.UnixSocketConfig } if c.unixSocketCfg.Group != "" { @@ -662,22 +661,22 @@ func (c *Client) Start() (addr net.Addr, err error) { var runner runner.Runner switch { case c.config.RunnerFunc != nil: - c.unixSocketCfg.directory, err = os.MkdirTemp("", "plugin-dir") + c.unixSocketCfg.socketDir, err = os.MkdirTemp(c.unixSocketCfg.TempDir, "plugin-dir") if err != nil { return nil, err } // os.MkdirTemp creates folders with 0o700, so if we have a group // configured we need to make it group-writable. if c.unixSocketCfg.Group != "" { - err = setGroupWritable(c.unixSocketCfg.directory, c.unixSocketCfg.Group, 0o770) + err = setGroupWritable(c.unixSocketCfg.socketDir, c.unixSocketCfg.Group, 0o770) if err != nil { return nil, err } } - cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%s", EnvUnixSocketDir, c.unixSocketCfg.directory)) - c.logger.Trace("created temporary directory for unix sockets", "dir", c.unixSocketCfg.directory) + cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%s", EnvUnixSocketDir, c.unixSocketCfg.socketDir)) + c.logger.Trace("created temporary directory for unix sockets", "dir", c.unixSocketCfg.socketDir) - runner, err = c.config.RunnerFunc(c.logger, cmd, c.unixSocketCfg.directory) + runner, err = c.config.RunnerFunc(c.logger, cmd, c.unixSocketCfg.socketDir) if err != nil { return nil, err } diff --git a/client_unix_test.go b/client_unix_test.go index 6c1f16a3..fa9b0a4a 100644 --- a/client_unix_test.go +++ b/client_unix_test.go @@ -11,6 +11,7 @@ import ( "os" "os/exec" "os/user" + "path/filepath" "runtime" "syscall" "testing" @@ -29,6 +30,11 @@ func TestSetGroup(t *testing.T) { if err != nil { t.Fatal(err) } + baseTempDir := t.TempDir() + baseTempDir, err = filepath.EvalSymlinks(baseTempDir) + if err != nil { + t.Fatal(err) + } for name, tc := range map[string]struct { group string }{ @@ -41,7 +47,8 @@ func TestSetGroup(t *testing.T) { HandshakeConfig: testHandshake, Plugins: testPluginMap, UnixSocketConfig: &UnixSocketConfig{ - Group: tc.group, + Group: tc.group, + TempDir: baseTempDir, }, RunnerFunc: func(l hclog.Logger, cmd *exec.Cmd, tmpDir string) (runner.Runner, error) { // Run tests inside the RunnerFunc to ensure we don't race @@ -49,6 +56,9 @@ func TestSetGroup(t *testing.T) { // to start properly. // Test that it creates a directory with the proper owners and permissions. + if filepath.Dir(tmpDir) != baseTempDir { + t.Errorf("Expected base TempDir to be %s, but tmpDir was %s", baseTempDir, tmpDir) + } info, err := os.Lstat(tmpDir) if err != nil { t.Fatal(err) diff --git a/constants.go b/constants.go index b66fa799..32e58602 100644 --- a/constants.go +++ b/constants.go @@ -4,6 +4,11 @@ package plugin const ( - EnvUnixSocketDir = "PLUGIN_UNIX_SOCKET_DIR" + // EnvUnixSocketDir specifies the directory that _plugins_ should create unix + // sockets in. Does not affect client behavior. + EnvUnixSocketDir = "PLUGIN_UNIX_SOCKET_DIR" + + // EnvUnixSocketGroup specifies the owning, writable group to set for Unix + // sockets created by _plugins_. Does not affect client behavior. EnvUnixSocketGroup = "PLUGIN_UNIX_SOCKET_GROUP" ) diff --git a/server.go b/server.go index 4b0f2b76..1ba5f231 100644 --- a/server.go +++ b/server.go @@ -134,6 +134,13 @@ type ServeTestConfig struct { SyncStdio bool } +func unixSocketConfigFromEnv() UnixSocketConfig { + return UnixSocketConfig{ + Group: os.Getenv(EnvUnixSocketGroup), + socketDir: os.Getenv(EnvUnixSocketDir), + } +} + // protocolVersion determines the protocol version and plugin set to be used by // the server. In the event that there is no suitable version, the last version // in the config is returned leaving the client to report the incompatibility. @@ -547,7 +554,7 @@ func serverListener_tcp() (net.Listener, error) { } func serverListener_unix(unixSocketCfg UnixSocketConfig) (net.Listener, error) { - tf, err := os.CreateTemp(unixSocketCfg.directory, "plugin") + tf, err := os.CreateTemp(unixSocketCfg.socketDir, "plugin") if err != nil { return nil, err }