-
Notifications
You must be signed in to change notification settings - Fork 383
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fixes for numpy 2.0 #928
Fixes for numpy 2.0 #928
Conversation
sacred/config/custom_containers.py
Outdated
# np.array_equal raises an exception when the arguments are not array in numpy 2.0. | ||
# This issue is only present in 2.0, not in <2.0 or >=2.1 | ||
if isinstance(old_value, opt.np.ndarray) and isinstance(new_value, opt.np.ndarray): | ||
return not opt.np.array_equal(old_value, new_value) | ||
elif isinstance(old_value, opt.np.ndarray) or isinstance(new_value, opt.np.ndarray): | ||
return False |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This check is not identical to the previous check. The issue is numpy/numpy#27271
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This function should do it
def np_array_equal(a1, a2, equal_nan: bool = False):
# Adapted version of np.array_equal that avoids
# https://github.com/numpy/numpy/issues/27271
assert opt.has_numpy, "numpy is not available"
try:
a1, a2 = opt.np.asarray(a1), opt.np.asarray(a2)
except Exception:
return False
if a1.shape != a2.shape:
return False
if not equal_nan:
result = a1 == a2
if isinstance(result, bool):
return result
return builtins.bool(result.all())
cannot_have_nan = opt.np._dtype_cannot_hold_nan(
a1.dtype
) and opt.np._dtype_cannot_hold_nan(a2.dtype)
if cannot_have_nan:
if a1 is a2:
return True
result = a1 == a2
if isinstance(result, bool):
return result
return builtins.bool(result.all())
if a1 is a2:
# nan will compare equal so an array will compare equal to itself.
return True
# Handling NaN values if equal_nan is True
a1nan, a2nan = opt.np.isnan(a1), opt.np.isnan(a2)
# NaN's occur at different locations
if not (a1nan == a2nan).all():
return False
# Shapes of a1, a2 and masks are guaranteed to be consistent by this point
# Here we don't need to check that the result might be a boolean, because it doesn't
# happen for numerical values.
return builtins.bool((a1[~a1nan] == a2[~a1nan]).all())
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I copied this function from np.array_equal
and added checks where necessary to avoid numpy/numpy#27271
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
On my local version, this fixes all tests together with the exclusion of np.float_
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Oh, you are right. I forgot the cases where numpy arrays are nested in lists (or similar) that numpy automatically converts to numpy arrays. I don't think it's easy to reproduce the old behavior exactly. We could get closer with something like
try:
return not opt.np.array_equal(old_value, new_value)
except AttributeError:
return old_value != new_value
BTW, the old code works for me on numpy 2.1. So np.array_equal([0], ['1'])
fails in numpy 2.0 but works in 2.1, np.array_equal(0, '1')
fails in 2.0 and 2.1
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Well, GitHub didn't display your other messages...
Since equal_nan=False
, we can get away with just a few lines:
try:
old_value = opt.np.asarray(old_value)
new_value = opt.np.asarray(new_value)
except:
return False
else:
result = old_value == new_value
if isinstance(result, bool):
return result
else:
return result.all()
But I found another issue with the original code. It does change behavior in some cases when numpy is present vs when not:
>>> [1, [1,2]] == [1, [1,2]]
True
>>> np.array_equal([1, [1,2]], [1, [1,2]])
False
Which raises the question if we should really keep the old behavior...
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This looks very undesired and potentially undesired tbh. I haven't looked deeply into what code paths would be affected by fixing this.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The old behavior should be recovered in numpy 2.0.1 and 2.1.1 numpy/numpy#27271
I don't know when that will be published but that should pretty much eliminate the need to write code that restores the old behavior. I guess the important question is whether one should fix the old behavior.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'll open an issue for this and reproduce the old behavior in this PR. I'm not sure yet what impact it would have to fix this. I guess that most sacred users have numpy installed so that fixing this may actually break or change something for them.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
sounds reasonable. However, iirc I had some issues with the identity check once, and the solution was counterintuitive. But, that may have been seml
and not sacred
's fault. Also, it has been years...
It looks like one needs the old |
I think we only need
|
Finally the tests are passing! :) |
Funny, I was sure that I tested that and that it didn't work for all versions, but it actually does |
I also fixed the s3_observer tests to make the tests pass