Implement broadcasting with device scalars
mitkotak authored and inducer committed Aug 9, 2022
1 parent 09fd4c3 commit db77dc7
Showing 3 changed files with 150 additions and 46 deletions.
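
As a quick orientation, here is a minimal usage sketch of the behaviour this commit enables (not part of the diff; it assumes a working CUDA context via pycuda.autoinit and that the device scalar is built from a 0-d NumPy array):

import numpy as np
import pycuda.autoinit  # noqa: F401 -- establishes a CUDA context
import pycuda.gpuarray as gpuarray

vec = gpuarray.to_gpu(np.arange(8, dtype=np.float32))
scl = gpuarray.to_gpu(np.array(2.0, dtype=np.float32))  # shape-() "device scalar"

# The 0-d operand is broadcast across the full-length operand in each case.
print((vec + scl).get())
print((scl * vec).get())
print((vec / scl).get())
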
34 changes: 28 additions & 6 deletions pycuda/elementwise.py
@@ -43,7 +43,6 @@ def get_elwise_module(
after_loop="",
):
from pycuda.compiler import SourceModule

return SourceModule(
"""
#include <pycuda-complex.hpp>
@@ -464,15 +463,28 @@ def get_linear_combination_kernel(summand_descriptors, dtype_z):


@context_dependent_memoize
def get_axpbyz_kernel(dtype_x, dtype_y, dtype_z):
def get_axpbyz_kernel(dtype_x, dtype_y, dtype_z,
x_is_scalar=False, y_is_scalar=False):
"""
Returns a kernel corresponding to ``z = ax + by``.
:arg x_is_scalar: A :class:`bool` which is *True* only if `x` is a scalar :class:`gpuarray`.
:arg y_is_scalar: A :class:`bool` which is *True* only if `y` is a scalar :class:`gpuarray`.
"""
out_t = dtype_to_ctype(dtype_z)
x = "x[0]" if x_is_scalar else "x[i]"
ax = f"a*(({out_t}) {x})"
y = "y[0]" if y_is_scalar else "y[i]"
by = f"b*(({out_t}) {y})"
result = f"{ax} + {by}"
return get_elwise_kernel(
"%(tp_x)s a, %(tp_x)s *x, %(tp_y)s b, %(tp_y)s *y, %(tp_z)s *z"
% {
"tp_x": dtype_to_ctype(dtype_x),
"tp_y": dtype_to_ctype(dtype_y),
"tp_z": dtype_to_ctype(dtype_z),
},
"z[i] = a*x[i] + b*y[i]",
f"z[i] = {result}",
"axpbyz",
)
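
For illustration, a short standalone sketch (not part of the diff) of the kernel body string these flags produce when x is a device scalar and the output dtype is float32:

out_t = "float"                       # dtype_to_ctype(numpy.float32), assumed
x_is_scalar, y_is_scalar = True, False
x = "x[0]" if x_is_scalar else "x[i]"
y = "y[0]" if y_is_scalar else "y[i]"
print(f"z[i] = a*(({out_t}) {x}) + b*(({out_t}) {y})")
# prints: z[i] = a*((float) x[0]) + b*((float) y[i])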

@@ -488,15 +500,25 @@ def get_axpbz_kernel(dtype_x, dtype_z):


@context_dependent_memoize
def get_binary_op_kernel(dtype_x, dtype_y, dtype_z, operator):
def get_binary_op_kernel(dtype_x, dtype_y, dtype_z, operator,
x_is_scalar=False, y_is_scalar=False):
"""
Returns a kernel corresponding to ``z = x (operator) y``.
:arg x_is_scalar: A :class:`bool` which is *True* only if `x` is a scalar :class:`gpuarray`.
:arg y_is_scalar: A :class:`bool` which is *True* only if `y` is a scalar :class:`gpuarray`.
"""
x = "x[0]" if x_is_scalar else "x[i]"
y = "y[0]" if y_is_scalar else "y[i]"
result = f"{x} {operator} {y}"
return get_elwise_kernel(
"%(tp_x)s *x, %(tp_y)s *y, %(tp_z)s *z"
% {
"tp_x": dtype_to_ctype(dtype_x),
"tp_y": dtype_to_ctype(dtype_y),
"tp_z": dtype_to_ctype(dtype_z),
},
"z[i] = x[i] %s y[i]" % operator,
f"z[i] = {result}",
"multiply",
)
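
The same pattern applies to the generic binary-op kernel: a device scalar still arrives as a one-element device buffer (the parameter list keeps all three pointer types), it is simply indexed with [0] instead of [i]. A small sketch (not part of the diff) of the body generated for a scalar divisor:

operator = "/"
x, y = "x[i]", "y[0]"                # x_is_scalar=False, y_is_scalar=True
print(f"z[i] = {x} {operator} {y}")  # prints: z[i] = x[i] / y[0]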

@@ -760,8 +782,8 @@ def get_scalar_op_kernel(dtype_x, dtype_a, dtype_y, operator):
"%(tp_x)s *x, %(tp_a)s a, %(tp_y)s *y"
% {
"tp_x": dtype_to_ctype(dtype_x),
"tp_a": dtype_to_ctype(dtype_a),
"tp_y": dtype_to_ctype(dtype_y),
"tp_a": dtype_to_ctype(dtype_a),
},
"y[i] = x[i] %s a" % operator,
"scalarop_kernel",
119 changes: 82 additions & 37 deletions pycuda/gpuarray.py
@@ -25,6 +25,18 @@ def _get_common_dtype(obj1, obj2):
return _get_common_dtype_base(obj1, obj2, has_double_support())


def _get_broadcasted_binary_op_result(obj1, obj2,
dtype_getter=_get_common_dtype):

if obj1.shape == obj2.shape:
return obj1._new_like_me(dtype_getter(obj1, obj2))
elif obj1.shape == ():
return obj2._new_like_me(dtype_getter(obj1, obj2))
elif obj2.shape == ():
return obj1._new_like_me(dtype_getter(obj1, obj2))
else:
raise NotImplementedError("Broadcasting binary operator with shapes:"
f" {obj1.shape}, {obj2.shape}.")
# {{{ vector types


@@ -141,20 +153,22 @@ def func(self, other):
raise RuntimeError(
"only contiguous arrays may " "be used as arguments to this operation"
)

if isinstance(other, GPUArray):
assert self.shape == other.shape

if isinstance(other, GPUArray) and (self, GPUArray):
if not other.flags.forc:
raise RuntimeError(
"only contiguous arrays may "
"be used as arguments to this operation"
)

result = self._new_like_me()
result = _get_broadcasted_binary_op_result(self, other)
func = elementwise.get_binary_op_kernel(
self.dtype, other.dtype, result.dtype, operator
)
self.dtype,
other.dtype,
result.dtype,
operator,
x_is_scalar=(self.shape == ()),
y_is_scalar=(other.shape == ()))

func.prepared_async_call(
self._grid,
self._block,
@@ -166,7 +180,8 @@ def func(self, other):
)

return result
else: # scalar operator
elif isinstance(self, GPUArray): # scalar operator
assert np.isscalar(other)
result = self._new_like_me()
func = elementwise.get_scalar_op_kernel(self.dtype,
np.dtype(type(other)),
@@ -181,6 +196,8 @@ def func(self, other):
self.mem_size,
)
return result
else:
return AssertionError

return func

@@ -400,38 +417,41 @@ def ptr(self):
def _axpbyz(self, selffac, other, otherfac, out, add_timer=None, stream=None):
"""Compute ``out = selffac * self + otherfac*other``,
where `other` is a vector.."""
assert self.shape == other.shape
if not self.flags.forc or not other.flags.forc:
raise RuntimeError(
"only contiguous arrays may " "be used as arguments to this operation"
)

func = elementwise.get_axpbyz_kernel(self.dtype, other.dtype, out.dtype)

assert ((self.shape == other.shape == out.shape)
or ((self.shape == ()) and other.shape == out.shape)
or ((other.shape == ()) and self.shape == out.shape))
func = elementwise.get_axpbyz_kernel(
self.dtype, other.dtype, out.dtype,
x_is_scalar=(self.shape == ()),
y_is_scalar=(other.shape == ()))
if add_timer is not None:
add_timer(
3 * self.size,
func.prepared_timed_call(
self._grid,
out._grid,
selffac,
self.gpudata,
out.gpudata,
otherfac,
other.gpudata,
out.gpudata,
self.mem_size,
out.mem_size,
),
)
else:
func.prepared_async_call(
self._grid,
self._block,
out._grid,
out._block,
stream,
selffac,
self.gpudata,
otherfac,
other.gpudata,
out.gpudata,
self.mem_size,
out.mem_size,
)

return out
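
The launch parameters above switch from self to out because self may now be the 0-d scalar operand; sizing the grid, block, and element count by out keeps the whole broadcast result covered. The intended arithmetic matches the plain NumPy computation, as in this host-side reference (illustration only; the values are made up):

import numpy as np

selffac, otherfac = 2.0, -1.0
x = np.float32(3.0)                  # stands in for a 0-d device scalar
y = np.arange(4, dtype=np.float32)   # stands in for the full-size operand
print(selffac * x + otherfac * y)    # what _axpbyz writes into out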
@@ -463,16 +483,26 @@ def _elwise_multiply(self, other, out, stream=None):
raise RuntimeError(
"only contiguous arrays may " "be used as arguments to this operation"
)
assert ((self.shape == other.shape == out.shape)
or ((self.shape == ()) and other.shape == out.shape)
or ((other.shape == ()) and self.shape == out.shape))

func = elementwise.get_binary_op_kernel(
self.dtype,
other.dtype,
out.dtype,
"*",
x_is_scalar=(self.shape == ()),
y_is_scalar=(other.shape == ()))

func = elementwise.get_binary_op_kernel(self.dtype, other.dtype, out.dtype, "*")
func.prepared_async_call(
self._grid,
self._block,
out._grid,
out._block,
stream,
self.gpudata,
other.gpudata,
out.gpudata,
self.mem_size,
out.mem_size,
)

return out
@@ -509,17 +539,25 @@ def _div(self, other, out, stream=None):
"only contiguous arrays may " "be used as arguments to this operation"
)

assert self.shape == other.shape
assert ((self.shape == other.shape == out.shape)
or ((self.shape == ()) and other.shape == out.shape)
or ((other.shape == ()) and self.shape == out.shape))

func = elementwise.get_binary_op_kernel(self.dtype, other.dtype, out.dtype, "/")
func = elementwise.get_binary_op_kernel(
self.dtype,
other.dtype,
out.dtype,
"/",
x_is_scalar=(self.shape == ()),
y_is_scalar=(other.shape == ()))
func.prepared_async_call(
self._grid,
self._block,
out._grid,
out._block,
stream,
self.gpudata,
other.gpudata,
out.gpudata,
self.mem_size,
out.mem_size,
)

return out
@@ -546,31 +584,35 @@ def __add__(self, other):

if isinstance(other, GPUArray):
# add another vector
result = self._new_like_me(_get_common_dtype(self, other))
result = _get_broadcasted_binary_op_result(self, other)
return self._axpbyz(1, other, 1, result)
else:

elif np.isscalar(other):
# add a scalar
if other == 0:
return self.copy()
else:
result = self._new_like_me(_get_common_dtype(self, other))
return self._axpbz(1, other, result)

else:
return NotImplemented
__radd__ = __add__
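
With the new else branch, operands that are neither GPUArrays nor host scalars make __add__ return NotImplemented, letting Python fall back to the other operand's reflected method or raise a clean TypeError. A small self-contained sketch of the expected behaviour (assumes a CUDA context; not part of the diff):

import numpy as np
import pycuda.autoinit  # noqa: F401
import pycuda.gpuarray as gpuarray

vec = gpuarray.to_gpu(np.arange(4, dtype=np.float32))
try:
    vec + object()   # neither a GPUArray nor an np.isscalar() value
except TypeError as exc:
    print("rejected as expected:", exc)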

def __sub__(self, other):
"""Substract an array from an array or a scalar from an array."""

if isinstance(other, GPUArray):
result = self._new_like_me(_get_common_dtype(self, other))
result = _get_broadcasted_binary_op_result(self, other)
return self._axpbyz(1, other, -1, result)
else:
elif np.isscalar(other):
if other == 0:
return self.copy()
else:
# create a new array for the result
result = self._new_like_me(_get_common_dtype(self, other))
return self._axpbz(1, -other, result)
else:
return NotImplemented

def __rsub__(self, other):
"""Substracts an array by a scalar or an array::
@@ -602,11 +644,13 @@ def __neg__(self):

def __mul__(self, other):
if isinstance(other, GPUArray):
result = self._new_like_me(_get_common_dtype(self, other))
result = _get_broadcasted_binary_op_result(self, other)
return self._elwise_multiply(other, result)
else:
elif np.isscalar(other):
result = self._new_like_me(_get_common_dtype(self, other))
return self._axpbz(other, 0, result)
else:
return NotImplemented

def __rmul__(self, scalar):
result = self._new_like_me(_get_common_dtype(self, scalar))
@@ -624,16 +668,17 @@ def __div__(self, other):
x = self / n
"""
if isinstance(other, GPUArray):
result = self._new_like_me(_get_common_dtype(self, other))
result = _get_broadcasted_binary_op_result(self, other)
return self._div(other, result)
else:
elif np.isscalar(other):
if other == 1:
return self.copy()
else:
# create a new array for the result
result = self._new_like_me(_get_common_dtype(self, other))
return self._axpbz(1 / other, 0, result)

else:
return NotImplemented
__truediv__ = __div__

def __rdiv__(self, other):
(The diff for the third changed file is not shown.)
