Skip to content

Commit

Permalink
netlink: MVP EngineThreadUnsafe
Browse files Browse the repository at this point in the history
Doesn't work with NetNS
  • Loading branch information
svinota committed May 4, 2024
1 parent cd3377d commit 8b37520
Showing 1 changed file with 27 additions and 21 deletions.
48 changes: 27 additions & 21 deletions pyroute2/netlink/nlsocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -749,12 +749,16 @@ def consume(self, block=False):
else:
flags = MSG_DONTWAIT
data = bytearray(16384)
bufsize, _ = self.socket._sock.recvfrom_into(
log.debug("consume, block=%s", block)
if hasattr(self.socket._sock, 'recvfrom_into'):
bufsize, _ = self.socket._sock.recvfrom_into(
data, 0, flags | MSG_PEEK | MSG_TRUNC
)
self.socket._sock.recvfrom_into(
data, bufsize, flags
)
)
log.debug("consume bufsize=%s", bufsize)
ret = self.socket._sock.recvfrom_into(data, bufsize, flags)
log.debug("consume return %s", ret)
else:
data = self.socket.recv(16843, flags)
return data

def get(
Expand All @@ -765,7 +769,9 @@ def get(
callback=None,
noraise=False,
):
log.debug("get: %s / %s / %s / %s", msg_seq, terminate, callback, noraise)
enough = False
started = False
while not enough:
# step 1. receive as much as we can from the socket
while True:
Expand All @@ -778,27 +784,31 @@ def get(
# step 2. fetch one data block from the buffer
data = self.buffer.pop(0)
# step 3. parse the data block
messages = tuple(
self.marshal.parse(data, msg_seq, callback)
)
last = messages[-1]
messages = tuple(self.marshal.parse(data, msg_seq, callback))
for msg in messages:
if msg['header']['sequence_number'] != msg_seq:
if msg_seq > 0 and msg['header']['sequence_number'] != msg_seq:
continue
msg['header']['target'] = self.target
msg['header']['stats'] = Stats(0, 0, 0)
started = True
log.debug("yield %s", msg['header'])
log.debug("message %s", msg)
yield msg

if last['header']['type'] == NLMSG_DONE:
if started and msg['header']['type'] == NLMSG_DONE:
break

if (
if started and (
(msg_seq == 0)
or (not last['header']['flags'] & NLM_F_MULTI)
or (callable(terminate) and terminate(last))
or (not msg['header']['flags'] & NLM_F_MULTI)
or (callable(terminate) and terminate(msg))
):
enough = True

# drop orphaned NLMSG_ERROR
if msg['header']['type'] == NLMSG_ERROR:
continue


class NetlinkSocketBase:
'''
Expand Down Expand Up @@ -1457,13 +1467,9 @@ def bind(self, groups=0, pid=None, **kwarg):
else:
raise KeyError('no free address available')
# all is OK till now, so start async recv, if we need
if async_cache:
self.buffer_thread = threading.Thread(
name="Netlink async cache", target=self.buffer_thread_routine
)
self.input_from_buffer_queue = True
self.buffer_thread.daemon = True
self.buffer_thread.start()
#if async_cache:
# self.buffer_thread.daemon = True
# self.buffer_thread.start()

def add_membership(self, group):
self.setsockopt(SOL_NETLINK, NETLINK_ADD_MEMBERSHIP, group)
Expand Down

0 comments on commit 8b37520

Please sign in to comment.