Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions src/radical/utils/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -830,13 +830,21 @@ def ru_open(*args, **kwargs):

# ------------------------------------------------------------------------------
#
def find_port(port_min=10000, port_max=65535):
def find_port(port_min: int = None,
port_max: int = None) -> Union[int, None]:
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add iface?

'''
Find a free port in the given range. The range defaults to 10000-65535.
Returns `None` if no free port could be found.
'''

for port in range(port_min, port_max):
if port_min is None:
port_min = 10000
if port_max is None:
port_max = 65535

for port in range(port_min, port_max + 1):
time.sleep(0.1)

sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
try:
sock.bind(('', port))
Expand Down
2 changes: 1 addition & 1 deletion src/radical/utils/zmq/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def __init__(self, url : Optional[str] = None,
path : Optional[str] = None,
persistent: bool = False) -> None:

super().__init__(url=url, uid=uid, path=path)
super().__init__(port=url, uid=uid, path=path)

if persistent:
path = '%s/%s.db' % (self._path, self._uid)
Expand Down
127 changes: 22 additions & 105 deletions src/radical/utils/zmq/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from ..debug import get_exception_trace
from ..serialize import to_msgpack, from_msgpack

from .utils import no_intr
from .utils import no_intr, zmq_bind


# --------------------------------------------------------------------------
Expand All @@ -30,14 +30,13 @@ class Server(object):

# --------------------------------------------------------------------------
#
def __init__(self, url: Optional[str] = None,
uid: Optional[str] = None,
path: Optional[str] = None) -> None:
def __init__(self, port: Optional[Union[int, str]] = None,
uid : Optional[str] = None,
path: Optional[str] = None) -> None:

# this server offers only synchronous communication: a request will be
# worked upon and answered before the next request is received.

self._url = url
self._cbs = dict()
self._path = path

Expand All @@ -58,91 +57,20 @@ def __init__(self, url: Optional[str] = None,
self.register_request('echo', self._request_echo)
self.register_request('fail', self._request_fail)

if not self._url:
self._url = 'tcp://*:10000-11000'

# URLs can specify port ranges to use - check if that is the case (see
# default above) and initilize iterator. The URL is expected to have
# the form:
#
# <proto>://<iface>:<ports>/
#
# where
# <proto>: any protocol accepted by zmq, defaults to `tcp`
# <iface>: IP number of interface to bind to defaults to `*`
# <ports>: port range to find port to bind to defaults to `*`
#
# The port range can be formed as:
#
# '*' : any port
# '100+' : any port equal or larger than 100
# '100-' : any port equal or larger than 100
# '100-110': any port equal or larger than 100, up to 110
tmp = self._url.split(':', 2)
assert len(tmp) == 3
self._proto = tmp[0]
self._iface = tmp[1].lstrip('/')
self._ports = tmp[2].replace('+', '-')

tmp = self._ports.split('-')

self._port_this : Union[int, str, None] = None
self._port_start: Optional[int]
self._port_stop : Optional[int]

if len(tmp) == 0:
self._port_start = 1
self._port_stop = None
elif len(tmp) == 1:
if tmp[0] == '*':
self._port_this = '*'
self._port_start = None
self._port_stop = None
else:
self._port_start = int(tmp[0])
self._port_stop = int(tmp[0])
elif len(tmp) == 2:
if tmp[0]: self._port_start = int(tmp[0])
else : self._port_start = 1
if tmp[1]: self._port_stop = int(tmp[1])
else : self._port_stop = None
else:
raise RuntimeError('cannot parse port spec %s' % self._ports)


# --------------------------------------------------------------------------
#
def _iterate_ports(self) -> Iterator[Union[int, str, None]]:

if self._port_this == '*':
# leave scanning to zmq
yield self._port_this

if self._port_this is None:
# initialize range iterator
self._port_this = self._port_start

if self._port_stop is None:
while True:
yield self._port_this
self._port_this += 1

# FIXME: interpret hostname part as specification for the interface to
# be used.
# `ports` can specify as port ranges to use - check if that is the case
if port is None : pmin = pmax = None
elif isinstance(port, str):
if '-' in port : pmin, pmax = port.split('-', 1)
else : pmin = pmax = port
elif isinstance(port, int): pmin = pmax = port
else:
# make type checker happy
assert isinstance(self._port_this, int)
assert isinstance(self._port_start, int)
raise ValueError('invalid port specification: %s' % str(port))

while self._port_this <= self._port_stop:
yield self._port_this
self._port_this += 1


# --------------------------------------------------------------------------
#
def _iterate_urls(self) -> Iterator[str]:

for port in self._iterate_ports():
yield '%s://%s:%s' % (self._proto, self._iface, port)
self._pmin = int(pmin) if pmin else None
self._pmax = int(pmax) if pmax else None


# --------------------------------------------------------------------------
Expand All @@ -163,7 +91,7 @@ def addr(self) -> Optional[str]:
#
def start(self) -> None:

self._log.info('start bridge %s', self._uid)
self._log.info('start server %s', self._uid)

if self._thread:
raise RuntimeError('`start()` can be called only once')
Expand All @@ -179,20 +107,20 @@ def start(self) -> None:
#
def stop(self) -> None:

self._log.info('stop bridge %s', self._uid)
self._log.info('stop server %s', self._uid)
self._term.set()


# --------------------------------------------------------------------------
#
def wait(self) -> None:

self._log.info('wait bridge %s', self._uid)
self._log.info('wait server %s', self._uid)

if self._thread:
self._thread.join()

self._log.info('wait bridge %s', self._uid)
self._log.info('wait server %s', self._uid)


# --------------------------------------------------------------------------
Expand Down Expand Up @@ -249,20 +177,9 @@ def _work(self) -> None:
self._sock.linger = _LINGER_TIMEOUT
self._sock.hwm = _HIGH_WATER_MARK

for url in self._iterate_urls():
try:
self._log.debug('try url %s', url)
self._sock.bind(url)
self._log.debug('success')
break
except zmq.error.ZMQError as e:
if 'Address already in use' in str(e):
self._log.warn('port in use - try next (%s)' % url)
else:
raise

addr = Url(as_string(self._sock.getsockopt(zmq.LAST_ENDPOINT)))
addr.host = get_hostip()
addr = zmq_bind(self._sock, port_min=self._pmin,
port_max=self._pmax)
assert addr
self._addr = str(addr)

self._up.set()
Expand Down
18 changes: 14 additions & 4 deletions src/radical/utils/zmq/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,18 +162,28 @@ def sock_connect(sock, url, hop=None):

# ------------------------------------------------------------------------------
#
def zmq_bind(sock):
def zmq_bind(sock, port_min: int = None,
port_max: int = None) -> Url:

while True:
port = find_port()
prev = -1
port = find_port(port_min, port_max)

while prev != port:

if not port:
raise RuntimeError('no port found in range %s - %s'
% (port_min, port_max))
try:
sock.bind('tcp://*:%s' % port)
addr = Url(as_string(sock.getsockopt(zmq.LAST_ENDPOINT)))
addr.host = get_hostip()
return addr
except:
except Exception as e:
pass

prev = port
port = find_port(port_min, port_max)

raise RuntimeError('could not bind to any port')


Expand Down
19 changes: 7 additions & 12 deletions tests/unittests/test_zmq_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,8 @@ def test_init(self, mocked_profiler, mocked_logger):
s = Server()
self.assertTrue (s.uid.startswith('server'))
self.assertIsNone(s.addr)
self.assertEqual (s._url, 'tcp://*:10000-11000')
self.assertEqual (s._pmin, None)
self.assertEqual (s._pmax, None)

self.assertFalse(s._up.is_set())
self.assertFalse(s._term.is_set())
Expand All @@ -71,13 +72,9 @@ def test_init(self, mocked_profiler, mocked_logger):
s = Server(uid=uid)
self.assertEqual(s.uid, uid)

with self.assertRaises(AssertionError):
# port(s) not set
Server(url='tcp://*')

with self.assertRaises(RuntimeError):
with self.assertRaises(ValueError):
# port(s) set incorrectly
Server(url='tcp://*:10000-11000-22000')
Server(port='10000-11000-22000')

# default callbacks
self.assertIn('echo', s._cbs)
Expand Down Expand Up @@ -117,7 +114,7 @@ def test_exec_output(self, mocked_profiler, mocked_logger, mocked_init):
@mock.patch('radical.utils.zmq.server.Profiler')
def test_start(self, mocked_profiler, mocked_logger):

s = Server(url='tcp://*:12345')
s = Server(port=12345)
s.start()
self.assertTrue(s.addr.endswith('12345'))

Expand All @@ -127,10 +124,8 @@ def test_start(self, mocked_profiler, mocked_logger):
# `start()` can be called only once
s.start()

s2 = Server(url='tcp://*:12345-')
s2 = Server(port='12345-')
s2.start()
# logged warning about port already in use
self.assertTrue(s2._log.warn.called)
self.assertTrue(s2.addr.endswith('12346'))

s2.stop()
Expand All @@ -152,7 +147,7 @@ def test_zmq(self, mocked_profiler, mocked_logger, mocked_zmq_ctx):
mocked_zmq_ctx().socket().bind = mock.Mock(
side_effect=zmq.error.ZMQError(msg='random ZMQ error'))

with self.assertRaises(zmq.error.ZMQError):
with self.assertRaises(RuntimeError):
s._work()


Expand Down