Skip to content

Commit

Permalink
Optionally use NumPy to allocate buffers (#5750)
Browse files Browse the repository at this point in the history
  • Loading branch information
jakirkham authored Feb 24, 2022
1 parent 577ef40 commit 5553177
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 16 deletions.
6 changes: 3 additions & 3 deletions distributed/comm/asyncio_tcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from .addressing import parse_host_port, unparse_host_port
from .core import Comm, CommClosedError, Connector, Listener
from .registry import Backend
from .utils import ensure_concrete_host, from_frames, to_frames
from .utils import ensure_concrete_host, from_frames, host_array, to_frames

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -121,7 +121,7 @@ def __init__(self, on_connection=None, min_read_size=128 * 1024):
self._using_default_buffer = True

self._default_len = max(min_read_size, 16) # need at least 16 bytes of buffer
self._default_buffer = memoryview(bytearray(self._default_len))
self._default_buffer = host_array(self._default_len)
# Index in default_buffer pointing to the first unparsed byte
self._default_start = 0
# Index in default_buffer pointing to the last written byte
Expand Down Expand Up @@ -258,7 +258,7 @@ def _parse_frame_lengths(self):
self._default_start += 8 * n_read

if n_read == needed:
self._frames = [memoryview(bytearray(n)) for n in self._frame_lengths]
self._frames = [host_array(n) for n in self._frame_lengths]
self._frame_index = 0
self._frame_nbytes_needed = (
self._frame_lengths[0] if self._frame_lengths else 0
Expand Down
10 changes: 8 additions & 2 deletions distributed/comm/tcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,13 @@
from .addressing import parse_host_port, unparse_host_port
from .core import Comm, CommClosedError, Connector, FatalCommClosedError, Listener
from .registry import Backend
from .utils import ensure_concrete_host, from_frames, get_tcp_server_address, to_frames
from .utils import (
ensure_concrete_host,
from_frames,
get_tcp_server_address,
host_array,
to_frames,
)

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -205,7 +211,7 @@ async def read(self, deserializers=None):
frames_nbytes = await stream.read_bytes(fmt_size)
(frames_nbytes,) = struct.unpack(fmt, frames_nbytes)

frames = memoryview(bytearray(frames_nbytes))
frames = host_array(frames_nbytes)
# Workaround for OpenSSL 1.0.2 (can drop with OpenSSL 1.1.1)
for i, j in sliding_window(
2, range(0, frames_nbytes + C_INT_MAX, C_INT_MAX)
Expand Down
14 changes: 3 additions & 11 deletions distributed/comm/ucx.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from .addressing import parse_host_port, unparse_host_port
from .core import Comm, CommClosedError, Connector, Listener
from .registry import Backend, backends
from .utils import ensure_concrete_host, from_frames, to_frames
from .utils import ensure_concrete_host, from_frames, host_array, to_frames

logger = logging.getLogger(__name__)

Expand All @@ -37,7 +37,6 @@
else:
ucp = None # type: ignore

host_array = None
device_array = None
pre_existing_cuda_context = False
cuda_context_created = False
Expand All @@ -53,7 +52,8 @@ def synchronize_stream(stream=0):


def init_once():
global ucp, host_array, device_array
global ucp, device_array
global ucx_create_endpoint, ucx_create_listener
global pre_existing_cuda_context, cuda_context_created

if ucp is not None:
Expand Down Expand Up @@ -110,14 +110,6 @@ def init_once():

ucp.init(options=ucx_config, env_takes_precedence=True)

# Find the function, `host_array()`, to use when allocating new host arrays
try:
import numpy

host_array = lambda n: numpy.empty((n,), dtype="u1")
except ImportError:
host_array = lambda n: bytearray(n)

# Find the function, `cuda_array()`, to use when allocating new CUDA arrays
try:
import rmm
Expand Down
20 changes: 20 additions & 0 deletions distributed/comm/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,26 @@
OFFLOAD_THRESHOLD = parse_bytes(OFFLOAD_THRESHOLD)


# Select the function, `host_array()`, to use when allocating new host arrays.
#
# Use NumPy, when available, to avoid memory initialization cost.
# A `bytearray` is zero-initialized using `calloc`, which we don't need.
# `np.empty` both skips the zero-initialization, and
# uses hugepages when available ( https://github.com/numpy/numpy/pull/14216 ).
try:
    # Keep the guarded region to the only statement expected to raise
    # ImportError; the definitions live in the `except`/`else` suites.
    import numpy
except ImportError:

    def builtin_host_array(n: int) -> memoryview:
        """Return a writable ``memoryview`` over a new ``bytearray`` of ``n`` bytes.

        Fallback allocator used when NumPy cannot be imported.
        """
        return memoryview(bytearray(n))

    host_array = builtin_host_array
else:

    def numpy_host_array(n: int) -> memoryview:
        """Return a writable ``memoryview`` over ``n`` bytes allocated by NumPy.

        The buffer comes from ``numpy.empty``, so its contents are
        uninitialized; callers are expected to overwrite it before reading.
        """
        return memoryview(numpy.empty((n,), dtype="u1"))  # type: ignore

    host_array = numpy_host_array


async def to_frames(
msg,
allow_offload=True,
Expand Down

0 comments on commit 5553177

Please sign in to comment.