Skip to content

Commit

Permalink
netns: obsolete IPC version
Browse files Browse the repository at this point in the history
Send Netlink socket FD back from a network namespace to the parent
process instead of running a proxy process in the namespace.
  • Loading branch information
svinota committed Jun 21, 2024
1 parent 9e2e818 commit cc308a6
Show file tree
Hide file tree
Showing 3 changed files with 102 additions and 47 deletions.
54 changes: 18 additions & 36 deletions pyroute2/iproute/linux.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
# -*- coding: utf-8 -*-
import json
import logging
import os
import time
Expand All @@ -26,7 +25,6 @@
NetlinkError,
SkipInode,
)
from pyroute2.netlink.nlsocket import IPCSocket
from pyroute2.netlink.rtnl import (
RTM_DELADDR,
RTM_DELLINK,
Expand Down Expand Up @@ -61,6 +59,7 @@
RTM_NEWTCLASS,
RTM_NEWTFILTER,
RTM_SETLINK,
RTMGRP_DEFAULTS,
RTMGRP_IPV4_IFADDR,
RTMGRP_IPV4_ROUTE,
RTMGRP_IPV4_RULE,
Expand All @@ -85,7 +84,6 @@
IPBatchSocket,
IPRSocket,
)
from pyroute2.netlink.rtnl.marshal import MarshalRtnl
from pyroute2.netlink.rtnl.ndtmsg import ndtmsg
from pyroute2.netlink.rtnl.nsidmsg import nsidmsg
from pyroute2.netlink.rtnl.nsinfmsg import nsinfmsg
Expand All @@ -94,9 +92,6 @@
from pyroute2.netlink.rtnl.rtmsg import rtmsg
from pyroute2.netlink.rtnl.tcmsg import plugins as tc_plugins
from pyroute2.netlink.rtnl.tcmsg import tcmsg
from pyroute2.netns import setns
from pyroute2.plan9 import Tcall
from pyroute2.plan9.server import Plan9Server
from pyroute2.requests.address import AddressFieldFilter, AddressIPRouteFilter
from pyroute2.requests.bridge import (
BridgeFieldFilter,
Expand Down Expand Up @@ -2603,36 +2598,23 @@ class IPRoute(LAB_API, RTNL_API, IPRSocket):
pass


def ipr_call(session, inode, req, response):
arg = json.loads(req['text'])
data = req['data']
if data:
arg['kwarg']['data'] = data
response['err'] = 0
ret = getattr(session.ipr, arg['call'])(*arg['argv'], **arg['kwarg'])
if isinstance(ret, bytes):
response['data'] = ret
response['text'] = ''
else:
response['data'] = b''
response['text'] = json.dumps(ret)
return response


class NetNS(RTNL_API, IPCSocket):

def __init__(self, netns):
super().__init__(target=netns)
self.marshal = MarshalRtnl()

def ipc_server(self):
setns(self.status['target'])
p9 = Plan9Server(use_socket=self.socket.server)
p9.filesystem.create('call').add_callback(Tcall, ipr_call)
p9.filesystem.create('log')
connection = p9.accept()
connection.session.ipr = IPRoute()
connection.serve()
class NetNS(IPRoute):

def __init__(
self,
netns=None,
flags=os.O_CREAT,
target='localhost',
libc=None,
groups=RTMGRP_DEFAULTS,
):
super().__init__(
target=netns if netns is not None else target,
netns=netns,
flags=flags,
libc=libc,
groups=groups,
)


class RawIPRoute(RTNL_API, RawIPRSocket):
Expand Down
73 changes: 64 additions & 9 deletions pyroute2/netlink/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,15 @@
import collections
import errno
import logging
import multiprocessing
import os
import socket
from urllib import parse

from pyroute2 import config
from pyroute2.common import AddrPool
from pyroute2.netlink import NLM_F_MULTI, NLMSG_DONE
from pyroute2.netns import setns
from pyroute2.requests.main import RequestProcessor

log = logging.getLogger(__name__)
Expand Down Expand Up @@ -68,7 +71,6 @@ def put_nowait(self, tag, message):


class CoreProtocol(asyncio.Protocol):

def __init__(self, on_con_lost, enqueue):
self.transport = None
self.enqueue = enqueue
Expand All @@ -77,11 +79,30 @@ def __init__(self, on_con_lost, enqueue):
def connection_made(self, transport):
self.transport = transport

def connection_lost(self, exc):
self.on_con_lost.set_result(True)


class CoreStreamProtocol(CoreProtocol):

def data_received(self, data):
log.debug('SOCK_STREAM enqueue %s bytes' % len(data))
self.enqueue(data, None)


class CoreDatagramProtocol(CoreProtocol):

def datagram_received(self, data, addr):
log.debug('SOCK_DGRAM enqueue %s bytes' % len(data))
self.enqueue(data, addr)

def connection_lost(self, exc):
self.on_con_lost.set_result(True)

def netns_init(ctl, nsname, cls):
setns(nsname)
s = cls()
print(" <<< ", s)
socket.send_fds(ctl, [b'test'], [s.socket.fileno()], 1)
print(" done ")


class CoreSocket:
Expand All @@ -91,26 +112,55 @@ class CoreSocket:
communications, both Netlink and internal RPC.
'''

libc = None
socket = None
compiled = None
endpoint = None
event_loop = None
__spec = None

def __init__(self, target='localhost', rcvsize=16384, use_socket=None):
def __init__(
self,
target='localhost',
rcvsize=16384,
use_socket=None,
netns=None,
flags=os.O_CREAT,
libc=None,
groups=0,
):
# 8<-----------------------------------------
self.spec = CoreSocketSpec(
{
'target': target,
'use_socket': use_socket is not None,
'rcvsize': rcvsize,
'netns': netns,
'flags': flags,
'groups': groups,
}
)
if libc is not None:
self.libc = libc
self.status = self.spec.status
url = parse.urlparse(self.status['target'])
self.scheme = url.scheme if url.scheme else url.path
self.use_socket = use_socket
# 8<-----------------------------------------
# Setup netns
if self.spec['netns'] is not None:
# inspect self.__init__ argument names
ctrl = socket.socketpair()
nsproc = multiprocessing.Process(
target=netns_init,
args=(ctrl[0], self.spec['netns'], type(self)),
)
nsproc.start()
(_, (self.spec['fileno'],), _, _) = socket.recv_fds(
ctrl[1], 1024, 1
)
nsproc.join()
# 8<-----------------------------------------
self.callbacks = [] # [(predicate, callback, args), ...]
self.addr_pool = AddrPool(minaddr=0x000000FF, maxaddr=0x0000FFFF)
self.marshal = None
Expand All @@ -119,8 +169,13 @@ def __init__(self, target='localhost', rcvsize=16384, use_socket=None):
# 8<-----------------------------------------
# Setup the underlying socket
self.socket = self.setup_socket()
self.msg_queue = CoreMessageQueue()
self.event_loop = self.setup_event_loop()
self.event_loop.run_until_complete(self.setup_endpoint())
self.connection_lost = self.event_loop.create_future()
if self.event_loop.is_running():
asyncio.ensure_future(self.setup_endpoint())
else:
self.event_loop.run_until_complete(self.setup_endpoint())

def get_loop(self):
return self.event_loop
Expand All @@ -138,19 +193,19 @@ async def setup_endpoint(self, loop=None):
# Setup asyncio
if self.endpoint is not None:
return
self.msg_queue = CoreMessageQueue()
self.connection_lost = self.event_loop.create_future()
self.endpoint = await self.event_loop.create_datagram_endpoint(
lambda: CoreProtocol(self.connection_lost, self.enqueue),
self.endpoint = await self.event_loop.connect_accepted_socket(
lambda: CoreStreamProtocol(self.connection_lost, self.enqueue),
sock=self.socket,
)

def setup_event_loop(self, event_loop=None):
if event_loop is None:
try:
event_loop = asyncio.get_running_loop()
self.status['event_loop'] = 'auto'
except RuntimeError:
event_loop = asyncio.new_event_loop()
self.status['event_loop'] = 'new'
return event_loop

def setup_socket(self, sock=None):
Expand Down
22 changes: 20 additions & 2 deletions pyroute2/netlink/nlsocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,11 @@
SOL_NETLINK,
nlmsg,
)
from pyroute2.netlink.core import CoreSocket, CoreSocketSpec
from pyroute2.netlink.core import (
CoreDatagramProtocol,
CoreSocket,
CoreSocketSpec,
)
from pyroute2.netlink.exceptions import (
ChaoticException,
NetlinkDumpInterrupted,
Expand Down Expand Up @@ -198,6 +202,9 @@ def __init__(
groups=0,
nlm_echo=False,
use_socket=None,
netns=None,
flags=os.O_CREAT,
libc=None,
):
# 8<-----------------------------------------
self.spec = NetlinkSocketSpec(
Expand All @@ -219,6 +226,8 @@ def __init__(
'nlm_echo': nlm_echo,
'use_socket': use_socket is not None,
'tag_field': 'sequence_number',
'netns': netns,
'flags': flags,
}
)
# TODO: merge capabilities to self.status
Expand All @@ -228,9 +237,18 @@ def __init__(
'create_dummy': True,
'provide_master': config.kernel[0] > 2,
}
super().__init__()
super().__init__(libc=libc)
self.marshal = Marshal()

async def setup_endpoint(self, loop=None):
# Setup asyncio
if self.endpoint is not None:
return
self.endpoint = await self.event_loop.create_datagram_endpoint(
lambda: CoreDatagramProtocol(self.connection_lost, self.enqueue),
sock=self.socket,
)

@property
def uname(self):
return self.status['uname']
Expand Down

0 comments on commit cc308a6

Please sign in to comment.