Skip to content

Commit

Permalink
Merge pull request #209 from l1b0k/ipforward
Browse files Browse the repository at this point in the history
cni: set ip_forward when call cni check
  • Loading branch information
l1b0k authored May 21, 2021
2 parents c3dc842 + d909e14 commit 728f9bf
Show file tree
Hide file tree
Showing 6 changed files with 62 additions and 108 deletions.
66 changes: 7 additions & 59 deletions pkg/sysctl/sysctl.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,67 +15,15 @@
package sysctl

import (
"fmt"
"io"
"bytes"
"io/ioutil"
"os"
"path/filepath"
"strings"
)

const (
prefixDir = "/proc/sys"
)

// Setting represents a sysctl setting. Its purpose it to be able to iterate
// over a slice of settings.
type Setting struct {
Name string
Val string
IgnoreErr bool
}

func fullPath(name string) string {
return filepath.Join(prefixDir, strings.Replace(name, ".", "/", -1))
}

func writeSysctl(name string, value string) error {
fPath := fullPath(name)
f, err := os.OpenFile(fPath, os.O_RDWR, 0644)
if err != nil {
return fmt.Errorf("could not open the sysctl file %s: %s",
fPath, err)
}
defer f.Close()
if _, err := io.WriteString(f, value); err != nil {
return fmt.Errorf("could not write to the systctl file %s: %s",
fPath, err)
func EnsureConf(fPath string, cfg string) error {
if content, err := ioutil.ReadFile(fPath); err == nil {
if bytes.Equal(bytes.TrimSpace(content), []byte(cfg)) {
return nil
}
}
return nil
}

// Disable disables the given sysctl parameter.
func Disable(name string) error {
return writeSysctl(name, "0")
}

// Enable enables the given sysctl parameter.
func Enable(name string) error {
return writeSysctl(name, "1")
}

// Write writes the given sysctl parameter.
func Write(name string, val string) error {
return writeSysctl(name, val)
}

// Read reads the given sysctl parameter.
func Read(name string) (string, error) {
fPath := fullPath(name)
val, err := ioutil.ReadFile(fPath)
if err != nil {
return "", fmt.Errorf("Failed to read %s: %s", fPath, val)
}

return strings.TrimRight(string(val), "\n"), nil
return ioutil.WriteFile(fPath, []byte(cfg), 0644)
}
1 change: 1 addition & 0 deletions plugin/driver/drivers.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"net"

terwayTypes "github.com/AliyunContainerService/terway/types"

"github.com/containernetworking/cni/pkg/types"
"github.com/containernetworking/plugins/pkg/ns"
)
Expand Down
12 changes: 2 additions & 10 deletions plugin/driver/ipvlan.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import (
"strconv"
"syscall"

"github.com/AliyunContainerService/terway/pkg/sysctl"
terwayTypes "github.com/AliyunContainerService/terway/types"

"github.com/containernetworking/plugins/pkg/ns"
Expand Down Expand Up @@ -215,15 +214,8 @@ func (d *IPvlanDriver) Check(cfg *CheckConfig) error {
Log.Debugf("route is changed")
cfg.RecordPodEvent("default route is updated")
}
err = sysctl.Enable(fmt.Sprintf("net.ipv4.conf.%s.forwarding", cfg.ContainerIFName))
if err != nil {
return err
}
err = sysctl.Disable(fmt.Sprintf("net.ipv4.conf.%s.rp_filter", cfg.ContainerIFName))
if err != nil {
return err
}
return nil

return EnsureNetConfSet(true, false)
})
if err != nil {
return err
Expand Down
12 changes: 2 additions & 10 deletions plugin/driver/raw_nic.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import (
"math/rand"
"time"

"github.com/AliyunContainerService/terway/pkg/sysctl"
"github.com/containernetworking/plugins/pkg/ns"
"github.com/vishvananda/netlink"
)
Expand Down Expand Up @@ -163,15 +162,8 @@ func (r *RawNicDriver) Check(cfg *CheckConfig) error {
Log.Debugf("route is changed")
cfg.RecordPodEvent("default route is updated")
}
err = sysctl.Enable(fmt.Sprintf("net.ipv4.conf.%s.forwarding", cfg.ContainerIFName))
if err != nil {
return err
}
err = sysctl.Disable(fmt.Sprintf("net.ipv4.conf.%s.rp_filter", cfg.ContainerIFName))
if err != nil {
return err
}
return nil

return EnsureNetConfSet(true, false)
})
return nil
}
Expand Down
77 changes: 48 additions & 29 deletions plugin/driver/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,15 @@ import (
"time"

terwayIP "github.com/AliyunContainerService/terway/pkg/ip"
terwaySysctl "github.com/AliyunContainerService/terway/pkg/sysctl"
terwayTypes "github.com/AliyunContainerService/terway/types"
k8snet "k8s.io/apimachinery/pkg/util/net"

"github.com/containernetworking/cni/pkg/types"
"github.com/containernetworking/plugins/pkg/ip"
"github.com/containernetworking/plugins/pkg/utils/sysctl"
"github.com/sirupsen/logrus"
"github.com/vishvananda/netlink"
"golang.org/x/sys/unix"
k8snet "k8s.io/apimachinery/pkg/util/net"
"k8s.io/apimachinery/pkg/util/wait"
)

Expand Down Expand Up @@ -101,37 +101,18 @@ func getRouteTableID(linkIndex int) int {
return 1000 + linkIndex
}

const rpFilterSysctl = "net.ipv4.conf.%s.rp_filter"

// EnsureHostNsConfig setup host namespace configs
func EnsureHostNsConfig() error {
existInterfaces, err := net.Interfaces()
if err != nil {
return fmt.Errorf("error get network interfaces, %w", err)
}

for _, key := range []string{"default", "all"} {
sysctlName := fmt.Sprintf(rpFilterSysctl, key)
if _, err = sysctl.Sysctl(sysctlName, "0"); err != nil {
return fmt.Errorf("error set: %s sysctl value to 0, %w", sysctlName, err)
}

}

for _, existIf := range existInterfaces {
sysctlName := fmt.Sprintf(rpFilterSysctl, existIf.Name)
sysctlValue, err := sysctl.Sysctl(sysctlName)
if err != nil {
continue
}
if sysctlValue != "0" {
if _, err = sysctl.Sysctl(sysctlName, "0"); err != nil {
return fmt.Errorf("error set: %s sysctl value to 0, %w", sysctlName, err)
for _, cfg := range ipv4NetConfig {
err := terwaySysctl.EnsureConf(fmt.Sprintf(cfg[0], key), cfg[1])
if err != nil {
return err
}
}

}
return nil

return EnsureNetConfSet(true, false)
}

// EnsureLinkUp set link up,return changed and err
Expand Down Expand Up @@ -550,11 +531,11 @@ func EnsurePolicyRule(link netlink.Link, ipNetSet *terwayTypes.IPNetSet, tableID
}

func EnableIPv6() error {
_, err := sysctl.Sysctl("net.ipv6.conf.all.disable_ipv6", "0")
err := terwaySysctl.EnsureConf("/proc/sys/net/ipv6/conf/all/disable_ipv6", "0")
if err != nil {
return err
}
_, err = sysctl.Sysctl("net.ipv6.conf.default.disable_ipv6", "0")
err = terwaySysctl.EnsureConf("/proc/sys/net/ipv6/conf/default/disable_ipv6", "0")
if err != nil {
return err
}
Expand Down Expand Up @@ -597,6 +578,44 @@ func GetHostIP(ipv4, ipv6 bool) (*terwayTypes.IPNetSet, error) {
}, nil
}

var ipv4NetConfig = [][]string{
{"/proc/sys/net/ipv4/conf/%s/forwarding", "1"},
{"/proc/sys/net/ipv4/conf/%s/rp_filter", "0"},
}

var ipv6NetConfig = [][]string{
{"/proc/sys/net/ipv6/conf/%s/forwarding", "1"},
{"/proc/sys/net/ipv6/conf/%s/disable_ipv6", "0"},
}

// EnsureNetConfSet will set net config to all link
func EnsureNetConfSet(ipv4, ipv6 bool) error {
links, err := netlink.LinkList()
if err != nil {
return err
}

for _, link := range links {
if ipv4 {
for _, cfg := range ipv4NetConfig {
innerErr := terwaySysctl.EnsureConf(fmt.Sprintf(cfg[0], link.Attrs().Name), cfg[1])
if innerErr != nil {
err = fmt.Errorf("%v, %w", err, innerErr)
}
}
}
if ipv6 {
for _, cfg := range ipv6NetConfig {
innerErr := terwaySysctl.EnsureConf(fmt.Sprintf(cfg[0], link.Attrs().Name), cfg[1])
if innerErr != nil {
err = fmt.Errorf("%v, %w", err, innerErr)
}
}
}
}
return err
}

func EnsureNeighbor(link netlink.Link, hostIPSet *terwayTypes.IPNetSet) (bool, error) {
var changed bool
var err error
Expand Down
2 changes: 2 additions & 0 deletions plugin/terway/cni.go
Original file line number Diff line number Diff line change
Expand Up @@ -687,6 +687,8 @@ func cmdCheck(args *skel.CmdArgs) error {
logger.Debugf("args: %s", driver.JSONStr(args))
logger.Debugf("ns %s , k8s %s, cni std %s", cniNetns.Path(), driver.JSONStr(k8sConfig), driver.JSONStr(conf))

_ = driver.EnsureHostNsConfig()

terwayBackendClient, closeConn, err := getNetworkClient()
if err != nil {
return errors.Wrap(err, fmt.Sprintf("add cmd: create grpc client, pod: %s-%s",
Expand Down

0 comments on commit 728f9bf

Please sign in to comment.