Skip to content

Commit

Permalink
Fixes for numpy 2.0 (#928)
Browse files Browse the repository at this point in the history
* Remove tests for float_ for numpy2.0

* Fix is_different for numpy=2.0

* Add tox test environments for numpy 2.0

* Add pull_request trigger to tests

* Fix s3 tests (moto.mock_s3 -> moto.mock_aws)

* Reproduce old np.array_equal behavior

* Update tensorflow test configurations
  • Loading branch information
thequilo authored Aug 26, 2024
1 parent cd90ee1 commit 356c310
Show file tree
Hide file tree
Showing 5 changed files with 34 additions and 56 deletions.
1 change: 1 addition & 0 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ name: Tests

on:
- push
- pull_request

jobs:
pytest:
Expand Down
19 changes: 16 additions & 3 deletions sacred/config/custom_containers.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,19 @@ def type_changed(old_value, new_value):
def is_different(old_value, new_value):
"""Numpy aware comparison between two values."""
if opt.has_numpy:
return not opt.np.array_equal(old_value, new_value)
else:
return old_value != new_value
# Reproduces np.array_equal from numpy<2
# np.array_equal raises an exception when the arguments are scalar and
# differ in type (e.g. int and str) in numpy>=2.0
try:
old_value = opt.np.asarray(old_value)
new_value = opt.np.asarray(new_value)
except:
return False
else:
result = old_value == new_value
if isinstance(result, bool):
return result
else:
return result.all()

return old_value != new_value
2 changes: 0 additions & 2 deletions tests/test_config/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
"uint16",
"uint32",
"uint64",
"float_",
"float16",
"float32",
"float64",
Expand All @@ -49,7 +48,6 @@ def test_normalize_or_die_for_numpy_datatypes(typename):
"uint16",
"uint32",
"uint64",
"float_",
"float16",
"float32",
"float64",
Expand Down
16 changes: 8 additions & 8 deletions tests/test_observers/test_s3_observer.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def _get_file_data(bucket_name, key):
return s3.Object(bucket_name, key).get()["Body"].read()


@moto.mock_s3
@moto.mock_aws
def test_fs_observer_started_event_creates_bucket(observer, sample_run):
_id = observer.started_event(**sample_run)
run_dir = s3_join(BASEDIR, str(_id))
Expand All @@ -102,7 +102,7 @@ def test_fs_observer_started_event_creates_bucket(observer, sample_run):
}


@moto.mock_s3
@moto.mock_aws
def test_fs_observer_started_event_increments_run_id(observer, sample_run):
_id = observer.started_event(**sample_run)
_id2 = observer.started_event(**sample_run)
Expand All @@ -119,15 +119,15 @@ def test_s3_observer_equality():
assert obs_one != different_bucket


@moto.mock_s3
@moto.mock_aws
def test_raises_error_on_duplicate_id_directory(observer, sample_run):
observer.started_event(**sample_run)
sample_run["_id"] = 1
with pytest.raises(FileExistsError):
observer.started_event(**sample_run)


@moto.mock_s3
@moto.mock_aws
def test_completed_event_updates_run_json(observer, sample_run):
observer.started_event(**sample_run)
run = json.loads(
Expand All @@ -145,7 +145,7 @@ def test_completed_event_updates_run_json(observer, sample_run):
assert run["status"] == "COMPLETED"


@moto.mock_s3
@moto.mock_aws
def test_interrupted_event_updates_run_json(observer, sample_run):
observer.started_event(**sample_run)
run = json.loads(
Expand All @@ -163,7 +163,7 @@ def test_interrupted_event_updates_run_json(observer, sample_run):
assert run["status"] == "SERVER_EXPLODED"


@moto.mock_s3
@moto.mock_aws
def test_failed_event_updates_run_json(observer, sample_run):
observer.started_event(**sample_run)
run = json.loads(
Expand All @@ -181,7 +181,7 @@ def test_failed_event_updates_run_json(observer, sample_run):
assert run["status"] == "FAILED"


@moto.mock_s3
@moto.mock_aws
def test_queued_event_updates_run_json(observer, sample_run):
del sample_run["start_time"]
sample_run["queue_time"] = T2
Expand All @@ -194,7 +194,7 @@ def test_queued_event_updates_run_json(observer, sample_run):
assert run["status"] == "QUEUED"


@moto.mock_s3
@moto.mock_aws
def test_artifact_event_works(observer, sample_run, tmpfile):
observer.started_event(**sample_run)
observer.artifact_event("test_artifact.py", tmpfile.name)
Expand Down
52 changes: 9 additions & 43 deletions tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
# and then run "tox" from this directory.

[tox]
envlist = py{38,39,310,311}, setup, numpy-{120,121,123}, tensorflow-{26,27,28,29,210,211}
envlist = py{38,39,310,311}, setup, numpy-{120,121,123,200}, tensorflow-{212,216}

[testenv]
deps =
Expand Down Expand Up @@ -53,68 +53,34 @@ deps =
commands =
pytest tests/test_config {posargs}

[testenv:tensorflow-115]
[testenv:numpy-200]
basepython = python
deps =
-rdev-requirements.txt
tensorflow~=1.15.0
numpy~=2.0.0
commands =
pytest tests/test_stflow tests/test_optional.py \
{posargs}

[testenv:tensorflow-26]
basepython = python
deps =
-rdev-requirements.txt
tensorflow~=2.6.0
commands =
pytest tests/test_stflow tests/test_optional.py \
{posargs}

[testenv:tensorflow-27]
basepython = python
deps =
-rdev-requirements.txt
tensorflow~=2.7.0
commands =
pytest tests/test_stflow tests/test_optional.py \
{posargs}
pytest tests/test_config {posargs}

[testenv:tensorflow-28]
[testenv:tensorflow-212]
basepython = python
deps =
-rdev-requirements.txt
tensorflow~=2.8.0
numpy<2.0.0
tensorflow~=2.12.0
commands =
pytest tests/test_stflow tests/test_optional.py \
{posargs}

[testenv:tensorflow-29]
basepython = python
deps =
-rdev-requirements.txt
tensorflow~=2.9.0
commands =
pytest tests/test_stflow tests/test_optional.py \
{posargs}

[testenv:tensorflow-210]
[testenv:tensorflow-216]
basepython = python
deps =
-rdev-requirements.txt
tensorflow~=2.10.0
tensorflow~=2.16.0
commands =
pytest tests/test_stflow tests/test_optional.py \
{posargs}

[testenv:tensorflow-211]
basepython = python
deps =
-rdev-requirements.txt
tensorflow~=2.11.0
commands =
pytest tests/test_stflow tests/test_optional.py \
{posargs}

[testenv:setup]
basepython = python
Expand Down

0 comments on commit 356c310

Please sign in to comment.