Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

rootfs: consolidate mountpoint creation logic #4359

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 7 additions & 28 deletions libcontainer/criu_linux.go
Original file line number Diff line number Diff line change
Expand Up @@ -523,17 +523,7 @@ func (c *Container) restoreNetwork(req *criurpc.CriuReq, criuOpts *CriuOpts) {
// restore using CRIU. This function is inspired from the code in
// rootfs_linux.go.
func (c *Container) makeCriuRestoreMountpoints(m *configs.Mount) error {
me := mountEntry{Mount: m}
dest, err := securejoin.SecureJoin(c.config.Rootfs, m.Destination)
if err != nil {
return err
}
// TODO: pass srcFD? Not sure if criu is impacted by issue #2484.
if err := checkProcMount(c.config.Rootfs, dest, me); err != nil {
return err
}
switch m.Device {
case "cgroup":
if m.Device == "cgroup" {
// No mount point(s) need to be created:
//
// * for v1, mount points are saved by CRIU because
Expand All @@ -542,23 +532,12 @@ func (c *Container) makeCriuRestoreMountpoints(m *configs.Mount) error {
// * for v2, /sys/fs/cgroup is a real mount, but
// the mountpoint appears as soon as /sys is mounted
return nil
case "bind":
// For bind-mounts (unlike other filesystem types), we need to check if
// the source exists.
fi, _, err := me.srcStat()
if err != nil {
// error out if the source of a bind mount does not exist as we
// will be unable to bind anything to it.
return err
}
if err := createIfNotExists(dest, fi.IsDir()); err != nil {
return err
}
default:
// for all other filesystems just create the mountpoints
if err := os.MkdirAll(dest, 0o755); err != nil {
return err
}
}
// TODO: pass srcFD? Not sure if criu is impacted by issue #2484.
me := mountEntry{Mount: m}
// For all other filesystems, just make the target.
if _, err := createMountpoint(c.config.Rootfs, me); err != nil {
return fmt.Errorf("create criu restore mountpoint for %s mount: %w", me.Destination, err)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: do we want to use %q for the destination here? It quotes the string, in case it has spaces and so

}
return nil
}
Expand Down
135 changes: 72 additions & 63 deletions libcontainer/rootfs_linux.go
Original file line number Diff line number Diff line change
Expand Up @@ -308,7 +308,11 @@ func mountCgroupV1(m *configs.Mount, c *mountConfig) error {

for _, b := range binds {
if c.cgroupns {
// We just created the tmpfs, and so we can just use filepath.Join
// here (not to mention we want to make sure we create the path
// inside the tmpfs, so we don't want to resolve symlinks).
subsystemPath := filepath.Join(c.root, b.Destination)
subsystemName := filepath.Base(b.Destination)
if err := os.MkdirAll(subsystemPath, 0o755); err != nil {
return err
}
Expand All @@ -319,7 +323,7 @@ func mountCgroupV1(m *configs.Mount, c *mountConfig) error {
}
var (
source = "cgroup"
data = filepath.Base(subsystemPath)
data = subsystemName
)
if data == "systemd" {
data = cgroups.CgroupNamePrefix + data
Expand Down Expand Up @@ -349,14 +353,7 @@ func mountCgroupV1(m *configs.Mount, c *mountConfig) error {
}

func mountCgroupV2(m *configs.Mount, c *mountConfig) error {
dest, err := securejoin.SecureJoin(c.root, m.Destination)
if err != nil {
return err
}
if err := os.MkdirAll(dest, 0o755); err != nil {
return err
}
err = utils.WithProcfd(c.root, m.Destination, func(dstFd string) error {
err := utils.WithProcfd(c.root, m.Destination, func(dstFd string) error {
return mountViaFds(m.Source, nil, m.Destination, dstFd, "cgroup2", uintptr(m.Flags), m.Data)
})
if err == nil || !(errors.Is(err, unix.EPERM) || errors.Is(err, unix.EBUSY)) {
Expand Down Expand Up @@ -482,6 +479,65 @@ func statfsToMountFlags(st unix.Statfs_t) int {
return flags
}

var errRootfsToFile = errors.New("config tries to change rootfs to file")

func createMountpoint(rootfs string, m mountEntry) (string, error) {
dest, err := securejoin.SecureJoin(rootfs, m.Destination)
if err != nil {
return "", err
}
if err := checkProcMount(rootfs, dest, m); err != nil {
return "", fmt.Errorf("check proc-safety of %s mount: %w", m.Destination, err)
}

switch m.Device {
case "bind":
fi, _, err := m.srcStat()
if err != nil {
// Error out if the source of a bind mount does not exist as we
// will be unable to bind anything to it.
return "", err
}
// If the original source is not a directory, make the target a file.
if !fi.IsDir() {
// Make sure we aren't tricked into trying to make the root a file.
if rootfs == dest {
return "", fmt.Errorf("%w: file bind mount over rootfs", errRootfsToFile)
}
// Make the parent directory.
if err := os.MkdirAll(filepath.Dir(dest), 0o755); err != nil {
return "", fmt.Errorf("make parent dir of file bind-mount: %w", err)
}
// Make the target file.
f, err := os.OpenFile(dest, os.O_CREATE, 0o755)
if err != nil {
return "", fmt.Errorf("create target of file bind-mount: %w", err)
}
_ = f.Close()
// Nothing left to do.
return dest, nil
}

case "tmpfs":
// If the original target exists, copy the mode for the tmpfs mount.
if stat, err := os.Stat(dest); err == nil {
dt := fmt.Sprintf("mode=%04o", syscallMode(stat.Mode()))
if m.Data != "" {
dt = dt + "," + m.Data
}
m.Data = dt

// Nothing left to do.
return dest, nil
}
}

if err := os.MkdirAll(dest, 0o755); err != nil {
return "", err
}
return dest, nil
}

func mountToRootfs(c *mountConfig, m mountEntry) error {
rootfs := c.root

Expand All @@ -495,7 +551,7 @@ func mountToRootfs(c *mountConfig, m mountEntry) error {
// TODO: This won't be necessary once we switch to libpathrs and we can
// stop all of these symlink-exchange attacks.
dest := filepath.Clean(m.Destination)
if !strings.HasPrefix(dest, rootfs) {
if !utils.IsLexicallyInRoot(rootfs, dest) {
// Do not use securejoin as it resolves symlinks.
dest = filepath.Join(rootfs, dest)
}
Expand All @@ -516,37 +572,19 @@ func mountToRootfs(c *mountConfig, m mountEntry) error {
return mountPropagate(m, rootfs, "")
}

mountLabel := c.label
dest, err := securejoin.SecureJoin(rootfs, m.Destination)
dest, err := createMountpoint(rootfs, m)
if err != nil {
return err
}
if err := checkProcMount(rootfs, dest, m); err != nil {
return err
return fmt.Errorf("create mountpoint for %s mount: %w", m.Destination, err)
}
mountLabel := c.label

switch m.Device {
case "mqueue":
if err := os.MkdirAll(dest, 0o755); err != nil {
return err
}
if err := mountPropagate(m, rootfs, ""); err != nil {
return err
}
return label.SetFileLabel(dest, mountLabel)
case "tmpfs":
if stat, err := os.Stat(dest); err != nil {
if err := os.MkdirAll(dest, 0o755); err != nil {
return err
}
} else {
dt := fmt.Sprintf("mode=%04o", syscallMode(stat.Mode()))
if m.Data != "" {
dt = dt + "," + m.Data
}
m.Data = dt
}

if m.Extensions&configs.EXT_COPYUP == configs.EXT_COPYUP {
err = doTmpfsCopyUp(m, rootfs, mountLabel)
} else {
Expand All @@ -555,15 +593,6 @@ func mountToRootfs(c *mountConfig, m mountEntry) error {

return err
case "bind":
fi, _, err := m.srcStat()
if err != nil {
// error out if the source of a bind mount does not exist as we will be
// unable to bind anything to it.
return err
}
if err := createIfNotExists(dest, fi.IsDir()); err != nil {
return err
}
// open_tree()-related shenanigans are all handled in mountViaFds.
if err := mountPropagate(m, rootfs, mountLabel); err != nil {
return err
Expand Down Expand Up @@ -679,9 +708,6 @@ func mountToRootfs(c *mountConfig, m mountEntry) error {
}
return mountCgroupV1(m.Mount, c)
default:
if err := os.MkdirAll(dest, 0o755); err != nil {
return err
}
return mountPropagate(m, rootfs, mountLabel)
}
}
Expand Down Expand Up @@ -899,6 +925,9 @@ func createDeviceNode(rootfs string, node *devices.Device, bind bool) error {
if err != nil {
return err
}
if dest == rootfs {
return fmt.Errorf("%w: mknod over rootfs", errRootfsToFile)
}
if err := os.MkdirAll(filepath.Dir(dest), 0o755); err != nil {
return err
}
Expand Down Expand Up @@ -1169,26 +1198,6 @@ func chroot() error {
return nil
}

// createIfNotExists creates a file or a directory only if it does not already exist.
func createIfNotExists(path string, isDir bool) error {
if _, err := os.Stat(path); err != nil {
if os.IsNotExist(err) {
if isDir {
return os.MkdirAll(path, 0o755)
}
if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
return err
}
f, err := os.OpenFile(path, os.O_CREATE, 0o755)
if err != nil {
return err
}
_ = f.Close()
}
}
return nil
}

// readonlyPath will make a path read only.
func readonlyPath(path string) error {
if err := mount(path, path, "", unix.MS_BIND|unix.MS_REC, ""); err != nil {
Expand Down
15 changes: 15 additions & 0 deletions libcontainer/utils/utils_unix.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"path/filepath"
"runtime"
"strconv"
"strings"
"sync"
_ "unsafe" // for go:linkname

Expand Down Expand Up @@ -260,3 +261,17 @@ func ProcThreadSelf(subpath string) (string, ProcThreadSelfCloser) {
func ProcThreadSelfFd(fd uintptr) (string, ProcThreadSelfCloser) {
return ProcThreadSelf("fd/" + strconv.FormatUint(uint64(fd), 10))
}

// IsLexicallyInRoot is shorthand for strings.HasPrefix(path+"/", root+"/"),
// but properly handling the case where path or root are "/".
//
// NOTE: The return value only make sense if the path doesn't contain "..".
func IsLexicallyInRoot(root, path string) bool {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we have unit tests?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm fine with adding them later, if that helps.

if root != "/" {
root += "/"
}
if path != "/" {
path += "/"
}
return strings.HasPrefix(path, root)
}
Loading