diff --git a/pkg/netns/netns_linux.go b/pkg/netns/netns_linux.go index 2ed1ac96d..1aa797d10 100644 --- a/pkg/netns/netns_linux.go +++ b/pkg/netns/netns_linux.go @@ -40,6 +40,8 @@ import ( // threadNsPath is the /proc path to the current netns handle for the current thread const threadNsPath = "/proc/thread-self/ns/net" +var errNoFreeName = errors.New("failed to find free netns path name") + // GetNSRunDir returns the dir of where to create the netNS. When running // rootless, it needs to be at a location writable by user. func GetNSRunDir() (string, error) { @@ -61,12 +63,10 @@ func NewNSAtPath(nsPath string) (ns.NetNS, error) { // an object representing that namespace, without switching to it. func NewNS() (ns.NetNS, error) { for i := 0; i < 10000; i++ { - b := make([]byte, 16) - _, err := rand.Reader.Read(b) + nsName, err := getRandomNetnsName() if err != nil { - return nil, fmt.Errorf("failed to generate random netns name: %v", err) + return nil, err } - nsName := fmt.Sprintf("netns-%x-%x-%x-%x-%x", b[0:4], b[4:6], b[6:8], b[8:10], b[10:]) ns, err := NewNSWithName(nsName) if err == nil { return ns, nil @@ -77,7 +77,7 @@ func NewNS() (ns.NetNS, error) { } return nil, err } - return nil, errors.New("failed to find free netns path name") + return nil, errNoFreeName } // NewNSWithName creates a new persistent (bind-mounted) network namespace and returns @@ -100,6 +100,58 @@ func NewNSWithName(name string) (ns.NetNS, error) { return newNSPath(nsPath) } +// NewNSFrom creates a persistent (bind-mounted) network namespace from the +// given netns path, i.e. /proc//ns/net, and returns the new full path to +// the bind mounted file in the netns run dir. +func NewNSFrom(fromNetns string) (string, error) { + nsRunDir, err := GetNSRunDir() + if err != nil { + return "", err + } + + err = makeNetnsDir(nsRunDir) + if err != nil { + return "", err + } + + for i := 0; i < 10000; i++ { + nsName, err := getRandomNetnsName() + if err != nil { + return "", err + } + nsPath := filepath.Join(nsRunDir, nsName) + + // create an empty file to use as at the mount point + err = createNetnsFile(nsPath) + if err != nil { + // retry when the name already exists + if errors.Is(err, os.ErrExist) { + continue + } + return "", err + } + + err = unix.Mount(fromNetns, nsPath, "none", unix.MS_BIND|unix.MS_SHARED|unix.MS_REC, "") + if err != nil { + // Do not leak the ns on errors + _ = os.RemoveAll(nsPath) + return "", fmt.Errorf("failed to bind mount ns at %s: %v", nsPath, err) + } + return nsPath, nil + } + + return "", errNoFreeName +} + +func getRandomNetnsName() (string, error) { + b := make([]byte, 16) + _, err := rand.Reader.Read(b) + if err != nil { + return "", fmt.Errorf("failed to generate random netns name: %v", err) + } + return fmt.Sprintf("netns-%x-%x-%x-%x-%x", b[0:4], b[4:6], b[6:8], b[8:10], b[10:]), nil +} + func makeNetnsDir(nsRunDir string) error { err := os.MkdirAll(nsRunDir, 0o755) if err != nil { @@ -151,16 +203,22 @@ func makeNetnsDir(nsRunDir string) error { return nil } -func newNSPath(nsPath string) (ns.NetNS, error) { - // create an empty file at the mount point - mountPointFd, err := os.OpenFile(nsPath, os.O_RDWR|os.O_CREATE|os.O_EXCL, 0o600) +// createNetnsFile created the file with O_EXCL to ensure there are no conflicts with others +// Callers should check for ErrExist and loop over it to find a free file. +func createNetnsFile(path string) error { + mountPointFd, err := os.OpenFile(path, os.O_RDWR|os.O_CREATE|os.O_EXCL, 0o600) if err != nil { - return nil, err + return err } - if err := mountPointFd.Close(); err != nil { + return mountPointFd.Close() +} + +func newNSPath(nsPath string) (ns.NetNS, error) { + // create an empty file to use as at the mount point + err := createNetnsFile(nsPath) + if err != nil { return nil, err } - // Ensure the mount point is cleaned up on errors; if the namespace // was successfully mounted this will have no effect because the file // is in-use