fix: empty array type mismatch between host and device #612
Conversation
Force-pushed from f0ea6c7 to 098fbe0
Greptile Overview

Greptile Summary: Fixed the type mismatch between host and device empty arrays by correctly identifying empty arrays as contiguous and preventing them from being misclassified as broadcast arrays.
Confidence Score: 5/5
Sequence Diagram:

```mermaid
sequenceDiagram
    participant User
    participant NumPy
    participant cuda.to_device
    participant DeviceNDArrayBase
    participant Array as dummyarray.Array
    participant typeof
    User->>NumPy: Create empty array (e.g., shape=(10, 0))
    NumPy-->>User: Returns array with size=0, strides may contain 0
    User->>cuda.to_device: Transfer array to device
    cuda.to_device->>DeviceNDArrayBase: Create device array
    DeviceNDArrayBase->>Array: _compute_layout() to set flags
    Note over Array: Check if self.size == 0
    Array-->>DeviceNDArrayBase: Return {C_CONTIGUOUS: True, F_CONTIGUOUS: True}
    User->>typeof: Get type of device array
    typeof->>DeviceNDArrayBase: Call __type_name__()
    Note over DeviceNDArrayBase: broadcast = 0 in strides AND size != 0
    Note over DeviceNDArrayBase: For empty arrays: broadcast = False
    DeviceNDArrayBase->>DeviceNDArrayBase: Check flags["C_CONTIGUOUS"]
    DeviceNDArrayBase-->>typeof: Return Array(dtype, ndim, 'C')
    User->>User: Compare host and device types
    Note over User: Types now match: array(dtype, ndim, C)
```
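A quick illustration of the flow above (a hypothetical session, assuming a CUDA-capable device; `numba.typeof` picks up the device array's `_numba_type_`):

```python
import numpy as np
from numba import cuda, typeof

host = np.empty((10, 0))  # size == 0; strides may contain 0
dev = cuda.to_device(host)

# _compute_layout() now treats size == 0 as contiguous...
print(dev.flags["C_CONTIGUOUS"])    # True
# ...so the device type carries the same 'C' layout as the host type.
print(typeof(host) == typeof(dev))  # True: both array(float64, 2d, C)
```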
3 files reviewed, no comments
/ok to test e4e1b57
gmarkall left a comment:
Many thanks for the PR! I think the tests look great, and the fix moves things in the right direction without regressing any functionality.
Whilst examining your changes, I did become quite suspicious of the way we're handling the contiguity computation (see the comment on the diff). I began to believe that if we can correctly compute contiguity, then we shouldn't need to special-case the computation of _numba_type_ in devicearray.py based on whether the array is broadcast or not.
Furthermore, having a zero stride also shouldn't preclude contiguity (judging from the implementation of, and comments on, the NumPy version of this functionality). It made me feel that we ought to be able to apply the following diff, to allow broadcast arrays to also be considered contiguous:
```diff
diff --git a/numba_cuda/numba/cuda/cudadrv/devicearray.py b/numba_cuda/numba/cuda/cudadrv/devicearray.py
index 188b2f5a..cfd25b5a 100644
--- a/numba_cuda/numba/cuda/cudadrv/devicearray.py
+++ b/numba_cuda/numba/cuda/cudadrv/devicearray.py
@@ -178,11 +178,9 @@ class DeviceNDArrayBase(_devicearray.DeviceArray):
         # of which will be 0, will not match those hardcoded in for 'C' or 'F'
         # layouts.
-        broadcast = 0 in self.strides and (self.size != 0)
-
-        if self.flags["C_CONTIGUOUS"] and not broadcast:
+        if self.flags["C_CONTIGUOUS"]:
             layout = "C"
-        elif self.flags["F_CONTIGUOUS"] and not broadcast:
+        elif self.flags["F_CONTIGUOUS"]:
             layout = "F"
         else:
             layout = "A"
diff --git a/numba_cuda/numba/cuda/cudadrv/dummyarray.py b/numba_cuda/numba/cuda/cudadrv/dummyarray.py
index 16545954..8b0515c4 100644
--- a/numba_cuda/numba/cuda/cudadrv/dummyarray.py
+++ b/numba_cuda/numba/cuda/cudadrv/dummyarray.py
@@ -275,19 +275,11 @@ class Array(object):
         # 13661ac70).
         # https://github.com/numpy/numpy/blob/maintenance/1.19.x/numpy/core/src/multiarray/flagsobject.c#L123-L191

+        flags = {"C_CONTIGUOUS": True, "F_CONTIGUOUS": True}
+
         # Records have no dims, and we can treat them as contiguous
         if not self.dims:
-            return {"C_CONTIGUOUS": True, "F_CONTIGUOUS": True}
-
-        # All 0-size arrays are considered contiguous, even if they are multidimensional
-        if self.size == 0:
-            return {"C_CONTIGUOUS": True, "F_CONTIGUOUS": True}
-
-        # If this is a broadcast array then it is not contiguous
-        if any([dim.stride == 0 for dim in self.dims]):
-            return {"C_CONTIGUOUS": False, "F_CONTIGUOUS": False}
-
-        flags = {"C_CONTIGUOUS": True, "F_CONTIGUOUS": True}
+            return flags

         # Check C contiguity
         sd = self.itemsize
```

If we didn't have the check for zero stride then we wouldn't be excluding zero-size arrays from being considered contiguous.
However, this does lead to one failure in test_devicearray_broadcast_host_copy() because the array elements seem to get transposed during the copy over to the device.
I don't want to blow up the scope of this PR and stall a perfectly good fix by jumping down a rabbit hole, though, so I'd like to instead merge this and potentially follow up on the other items later.
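For comparison, a quick sketch of NumPy's own flags for the two cases under discussion (exact strides can vary across platforms and NumPy versions):

```python
import numpy as np

# Zero-size arrays are contiguous by definition in NumPy, even though a
# stride may be 0 depending on the NumPy version.
empty = np.empty((10, 0))
print(empty.flags['C_CONTIGUOUS'])  # True
print(empty.flags['F_CONTIGUOUS'])  # True

# For broadcast arrays the zero stride makes the ordinary stride check
# fail, so no dedicated zero-stride rule is needed to classify them.
bcast = np.broadcast_to(np.arange(3), (4, 3))
print(bcast.strides)                # (0, 8) for int64
print(bcast.flags['C_CONTIGUOUS'])  # False
```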
On the zero-stride check in dummyarray.py:

```python
            return {"C_CONTIGUOUS": True, "F_CONTIGUOUS": True}

        # If this is a broadcast array then it is not contiguous
        if any([dim.stride == 0 for dim in self.dims]):
```
I noted that the NumPy implementation that this is following (from the _UpdateContiguousFlags implementation referenced above) doesn't have this check for zero strides.
It's not directly related to this PR, but I think it is suspicious that we still differ in our implementation.
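For reference, a minimal Python transcription of the stride walk in that NumPy routine (illustrative only; the authoritative version is the C source linked in the diff above, and the function name here is made up):

```python
def numpy_style_contiguity(shape, strides, itemsize):
    """Sketch of the _UpdateContiguousFlags stride walk.

    There is no zero-stride special case: a dimension of size 0 makes the
    array contiguous by definition, dimensions of size 1 are skipped (so
    their strides never matter), and every other dimension's stride must
    match the running element count exactly.
    """
    if 0 in shape:
        return {"C_CONTIGUOUS": True, "F_CONTIGUOUS": True}

    def matches(dims_and_strides):
        sd = itemsize
        for dim, stride in dims_and_strides:
            if dim != 1:
                if stride != sd:
                    return False
                sd *= dim
        return True

    dims = list(zip(shape, strides))
    return {
        "C_CONTIGUOUS": matches(reversed(dims)),  # last axis varies fastest
        "F_CONTIGUOUS": matches(dims),            # first axis varies fastest
    }
```

On this sketch, a zero-size shape such as (10, 0) is contiguous regardless of its strides, while a broadcast shape such as (4, 3) with strides (0, 8) fails the stride match and comes out non-contiguous, with no zero-stride check required.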
This aims to align more closely with NumPy's contiguity logic. The example in the commit message from NVIDIA#612 still runs correctly with this change, but I think it needs a little more consideration for now.
- Revert NVIDIA#536 "perf: remove context threading in various pointer abstractions" (NVIDIA#611)
- fix: empty array type mismatch between host and device (NVIDIA#612)
- fix: warp vote operations must use a constant int for the `mode` parameter (NVIDIA#606)
Fixes #483.
As mentioned here, the bug was due to empty arrays being misclassified as broadcast arrays: the old zero-stride check marked them non-contiguous, so they were typed with an 'A' layout on the device while the host array kept its 'C' layout.
When running a reproducer along these lines, the output shows the host and device types disagreeing:
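This is a hypothetical reconstruction of such a reproducer rather than the exact snippet from the issue, and it assumes a CUDA-capable device:

```python
import numpy as np
from numba import cuda, typeof

host = np.empty((10, 0))  # zero-size array; one of its strides may be 0
dev = cuda.to_device(host)

print(typeof(host))  # array(float64, 2d, C)
print(typeof(dev))   # before this fix: array(float64, 2d, A)
                     # after this fix:  array(float64, 2d, C)
```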