Skip to content

Commit

Permalink
Merge pull request #4359 from cyphar/rootfs-create-mountpoint-refactor
Browse files Browse the repository at this point in the history
rootfs: consolidate mountpoint creation logic
  • Loading branch information
AkihiroSuda authored Jul 29, 2024
2 parents 459ce2f + 1410a69 commit 3d7bc3b
Show file tree
Hide file tree
Showing 3 changed files with 94 additions and 91 deletions.
35 changes: 7 additions & 28 deletions libcontainer/criu_linux.go
Original file line number Diff line number Diff line change
Expand Up @@ -523,17 +523,7 @@ func (c *Container) restoreNetwork(req *criurpc.CriuReq, criuOpts *CriuOpts) {
// restore using CRIU. This function is inspired from the code in
// rootfs_linux.go.
func (c *Container) makeCriuRestoreMountpoints(m *configs.Mount) error {
me := mountEntry{Mount: m}
dest, err := securejoin.SecureJoin(c.config.Rootfs, m.Destination)
if err != nil {
return err
}
// TODO: pass srcFD? Not sure if criu is impacted by issue #2484.
if err := checkProcMount(c.config.Rootfs, dest, me); err != nil {
return err
}
switch m.Device {
case "cgroup":
if m.Device == "cgroup" {
// No mount point(s) need to be created:
//
// * for v1, mount points are saved by CRIU because
Expand All @@ -542,23 +532,12 @@ func (c *Container) makeCriuRestoreMountpoints(m *configs.Mount) error {
// * for v2, /sys/fs/cgroup is a real mount, but
// the mountpoint appears as soon as /sys is mounted
return nil
case "bind":
// For bind-mounts (unlike other filesystem types), we need to check if
// the source exists.
fi, _, err := me.srcStat()
if err != nil {
// error out if the source of a bind mount does not exist as we
// will be unable to bind anything to it.
return err
}
if err := createIfNotExists(dest, fi.IsDir()); err != nil {
return err
}
default:
// for all other filesystems just create the mountpoints
if err := os.MkdirAll(dest, 0o755); err != nil {
return err
}
}
// TODO: pass srcFD? Not sure if criu is impacted by issue #2484.
me := mountEntry{Mount: m}
// For all other filesystems, just make the target.
if _, err := createMountpoint(c.config.Rootfs, me); err != nil {
return fmt.Errorf("create criu restore mountpoint for %s mount: %w", me.Destination, err)
}
return nil
}
Expand Down
135 changes: 72 additions & 63 deletions libcontainer/rootfs_linux.go
Original file line number Diff line number Diff line change
Expand Up @@ -308,7 +308,11 @@ func mountCgroupV1(m *configs.Mount, c *mountConfig) error {

for _, b := range binds {
if c.cgroupns {
// We just created the tmpfs, and so we can just use filepath.Join
// here (not to mention we want to make sure we create the path
// inside the tmpfs, so we don't want to resolve symlinks).
subsystemPath := filepath.Join(c.root, b.Destination)
subsystemName := filepath.Base(b.Destination)
if err := os.MkdirAll(subsystemPath, 0o755); err != nil {
return err
}
Expand All @@ -319,7 +323,7 @@ func mountCgroupV1(m *configs.Mount, c *mountConfig) error {
}
var (
source = "cgroup"
data = filepath.Base(subsystemPath)
data = subsystemName
)
if data == "systemd" {
data = cgroups.CgroupNamePrefix + data
Expand Down Expand Up @@ -349,14 +353,7 @@ func mountCgroupV1(m *configs.Mount, c *mountConfig) error {
}

func mountCgroupV2(m *configs.Mount, c *mountConfig) error {
dest, err := securejoin.SecureJoin(c.root, m.Destination)
if err != nil {
return err
}
if err := os.MkdirAll(dest, 0o755); err != nil {
return err
}
err = utils.WithProcfd(c.root, m.Destination, func(dstFd string) error {
err := utils.WithProcfd(c.root, m.Destination, func(dstFd string) error {
return mountViaFds(m.Source, nil, m.Destination, dstFd, "cgroup2", uintptr(m.Flags), m.Data)
})
if err == nil || !(errors.Is(err, unix.EPERM) || errors.Is(err, unix.EBUSY)) {
Expand Down Expand Up @@ -482,6 +479,65 @@ func statfsToMountFlags(st unix.Statfs_t) int {
return flags
}

var errRootfsToFile = errors.New("config tries to change rootfs to file")

func createMountpoint(rootfs string, m mountEntry) (string, error) {
dest, err := securejoin.SecureJoin(rootfs, m.Destination)
if err != nil {
return "", err
}
if err := checkProcMount(rootfs, dest, m); err != nil {
return "", fmt.Errorf("check proc-safety of %s mount: %w", m.Destination, err)
}

switch m.Device {
case "bind":
fi, _, err := m.srcStat()
if err != nil {
// Error out if the source of a bind mount does not exist as we
// will be unable to bind anything to it.
return "", err
}
// If the original source is not a directory, make the target a file.
if !fi.IsDir() {
// Make sure we aren't tricked into trying to make the root a file.
if rootfs == dest {
return "", fmt.Errorf("%w: file bind mount over rootfs", errRootfsToFile)
}
// Make the parent directory.
if err := os.MkdirAll(filepath.Dir(dest), 0o755); err != nil {
return "", fmt.Errorf("make parent dir of file bind-mount: %w", err)
}
// Make the target file.
f, err := os.OpenFile(dest, os.O_CREATE, 0o755)
if err != nil {
return "", fmt.Errorf("create target of file bind-mount: %w", err)
}
_ = f.Close()
// Nothing left to do.
return dest, nil
}

case "tmpfs":
// If the original target exists, copy the mode for the tmpfs mount.
if stat, err := os.Stat(dest); err == nil {
dt := fmt.Sprintf("mode=%04o", syscallMode(stat.Mode()))
if m.Data != "" {
dt = dt + "," + m.Data
}
m.Data = dt

// Nothing left to do.
return dest, nil
}
}

if err := os.MkdirAll(dest, 0o755); err != nil {
return "", err
}
return dest, nil
}

func mountToRootfs(c *mountConfig, m mountEntry) error {
rootfs := c.root

Expand All @@ -495,7 +551,7 @@ func mountToRootfs(c *mountConfig, m mountEntry) error {
// TODO: This won't be necessary once we switch to libpathrs and we can
// stop all of these symlink-exchange attacks.
dest := filepath.Clean(m.Destination)
if !strings.HasPrefix(dest, rootfs) {
if !utils.IsLexicallyInRoot(rootfs, dest) {
// Do not use securejoin as it resolves symlinks.
dest = filepath.Join(rootfs, dest)
}
Expand All @@ -516,37 +572,19 @@ func mountToRootfs(c *mountConfig, m mountEntry) error {
return mountPropagate(m, rootfs, "")
}

mountLabel := c.label
dest, err := securejoin.SecureJoin(rootfs, m.Destination)
dest, err := createMountpoint(rootfs, m)
if err != nil {
return err
}
if err := checkProcMount(rootfs, dest, m); err != nil {
return err
return fmt.Errorf("create mountpoint for %s mount: %w", m.Destination, err)
}
mountLabel := c.label

switch m.Device {
case "mqueue":
if err := os.MkdirAll(dest, 0o755); err != nil {
return err
}
if err := mountPropagate(m, rootfs, ""); err != nil {
return err
}
return label.SetFileLabel(dest, mountLabel)
case "tmpfs":
if stat, err := os.Stat(dest); err != nil {
if err := os.MkdirAll(dest, 0o755); err != nil {
return err
}
} else {
dt := fmt.Sprintf("mode=%04o", syscallMode(stat.Mode()))
if m.Data != "" {
dt = dt + "," + m.Data
}
m.Data = dt
}

if m.Extensions&configs.EXT_COPYUP == configs.EXT_COPYUP {
err = doTmpfsCopyUp(m, rootfs, mountLabel)
} else {
Expand All @@ -555,15 +593,6 @@ func mountToRootfs(c *mountConfig, m mountEntry) error {

return err
case "bind":
fi, _, err := m.srcStat()
if err != nil {
// error out if the source of a bind mount does not exist as we will be
// unable to bind anything to it.
return err
}
if err := createIfNotExists(dest, fi.IsDir()); err != nil {
return err
}
// open_tree()-related shenanigans are all handled in mountViaFds.
if err := mountPropagate(m, rootfs, mountLabel); err != nil {
return err
Expand Down Expand Up @@ -679,9 +708,6 @@ func mountToRootfs(c *mountConfig, m mountEntry) error {
}
return mountCgroupV1(m.Mount, c)
default:
if err := os.MkdirAll(dest, 0o755); err != nil {
return err
}
return mountPropagate(m, rootfs, mountLabel)
}
}
Expand Down Expand Up @@ -899,6 +925,9 @@ func createDeviceNode(rootfs string, node *devices.Device, bind bool) error {
if err != nil {
return err
}
if dest == rootfs {
return fmt.Errorf("%w: mknod over rootfs", errRootfsToFile)
}
if err := os.MkdirAll(filepath.Dir(dest), 0o755); err != nil {
return err
}
Expand Down Expand Up @@ -1169,26 +1198,6 @@ func chroot() error {
return nil
}

// createIfNotExists creates a file or a directory only if it does not already exist.
func createIfNotExists(path string, isDir bool) error {
if _, err := os.Stat(path); err != nil {
if os.IsNotExist(err) {
if isDir {
return os.MkdirAll(path, 0o755)
}
if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
return err
}
f, err := os.OpenFile(path, os.O_CREATE, 0o755)
if err != nil {
return err
}
_ = f.Close()
}
}
return nil
}

// readonlyPath will make a path read only.
func readonlyPath(path string) error {
if err := mount(path, path, "", unix.MS_BIND|unix.MS_REC, ""); err != nil {
Expand Down
15 changes: 15 additions & 0 deletions libcontainer/utils/utils_unix.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"path/filepath"
"runtime"
"strconv"
"strings"
"sync"
_ "unsafe" // for go:linkname

Expand Down Expand Up @@ -260,3 +261,17 @@ func ProcThreadSelf(subpath string) (string, ProcThreadSelfCloser) {
func ProcThreadSelfFd(fd uintptr) (string, ProcThreadSelfCloser) {
return ProcThreadSelf("fd/" + strconv.FormatUint(uint64(fd), 10))
}

// IsLexicallyInRoot is shorthand for strings.HasPrefix(path+"/", root+"/"),
// but properly handling the case where path or root are "/".
//
// NOTE: The return value only make sense if the path doesn't contain "..".
func IsLexicallyInRoot(root, path string) bool {
if root != "/" {
root += "/"
}
if path != "/" {
path += "/"
}
return strings.HasPrefix(path, root)
}

0 comments on commit 3d7bc3b

Please sign in to comment.