Skip to content

Commit

Permalink
Merge pull request #69 from paganpasta/dev
Browse files Browse the repository at this point in the history
Dev
  • Loading branch information
paganpasta committed Mar 11, 2023
2 parents cbdfadb + abb8678 commit 7da790c
Show file tree
Hide file tree
Showing 5 changed files with 5 additions and 5 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/run_test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ jobs:
run-test:
strategy:
matrix:
python-version: [ 3.7, 3.8, 3.9 ]
python-version: [ 3.8, 3.9 ]
os: [ ubuntu-latest ]
fail-fast: false
runs-on: ${{ matrix.os }}
Expand Down
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ repos:
- id: nbqa-isort
- id: nbqa-flake8
- repo: https://github.com/PyCQA/isort
rev: 5.10.1
rev: 5.12.0
hooks:
- id: isort
- repo: https://github.com/pycqa/flake8
Expand Down
2 changes: 1 addition & 1 deletion eqxvision/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
r"""Root package info."""
__version__ = "0.2.7"
__version__ = "0.2.8"

from . import experimental, layers, models, utils
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ authors = [{name = "Contributing Authors", email = "[email protected]"}]
dynamic = ["version", "description"]
readme = "README.md"
license = { file="LICENSE.md" }
requires-python = ">=3.7"
requires-python = ">=3.8"
classifiers = [
"License :: OSI Approved :: MIT License",
"Operating System :: OS Independent",
Expand Down
2 changes: 1 addition & 1 deletion tests/test_models/test_vit.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def test_vit_block(self, getkey):
def forward(net, x, keys, attn=False):
nonlocal c_counter
c_counter += 1
return eqx.filter_vmap(net)(x, return_attention=attn, key=keys)
return jax.vmap(net, in_axes=(0, None))(x, attn, key=keys)

random_input = jax.random.uniform(key=getkey(), shape=(1, 8, 32))
answer, answer_attn = (1, 8, 32), (1, 1, 4, 8, 8)
Expand Down

0 comments on commit 7da790c

Please sign in to comment.