Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Now respecting typing.no_type_check. Also some doc fixes. Bump version. #216

Merged
merged 3 commits into from
Jun 13, 2024
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
Fixes for docs and no-numpy dependencies.
patrick-kidger committed Jun 13, 2024

Verified

This commit was signed with the committer’s verified signature.
commit c154d2cda27d57b3840ad9e003c2bcc3a46f7a84
2 changes: 2 additions & 0 deletions docs/api/advanced-features.md
Original file line number Diff line number Diff line change
@@ -7,6 +7,8 @@
members:
false

::: jaxtyping.make_numpy_struct_dtype

## Printing axis bindings

::: jaxtyping.print_bindings
17 changes: 9 additions & 8 deletions jaxtyping/_array_types.py
Original file line number Diff line number Diff line change
@@ -764,31 +764,32 @@ class _Cls(AbstractDtype):
Key = _make_dtype(_prng_key, "Key")


def make_numpy_struct_dtype(dtype: np.dtype, name: str):
def make_numpy_struct_dtype(dtype: "np.dtype", name: str):
"""Creates a type annotation for [numpy structured array](https://numpy.org/doc/stable/user/basics.rec.html#structured-arrays)
It does exact match on the name, order, and dtype of all its fields.
It performs an exact match on the name, order, and dtype of all its fields.
!!! Example
```python
label_t = np.dtype([('first', np.uint8), ('second', np.int8)])
Label = make_numpy_struct_dtype(label_t, 'Label')
```
after that, you can use it just like any AbstractDtype
after that, you can use it just like any other [`jaxtyping.AbstractDtype`][]:
```python
a: Label[np.ndarray, 'a b'] = np.array([[(1, 0), (0, 1)]], dtype=label_t)
```
**Arguments:**
- `dtype`: The numpy dtype that the returned annotation matches
- `name`: The python class name for the returned dtype annotation
- `dtype`: The numpy structured dtype to use.
- `name`: The name to use for the returned Python class.
**Returns:**
A type annotation with classname `name` and matching exactly `dtype`.
It can be used like any usual subclasses of AbstractDtypes.
A type annotation with classname `name` that matches exactly `dtype` when used like
any other [`jaxtyping.AbstractDtype`][].
"""
if not (isinstance(dtype, np.dtype) and _dtype_is_numpy_struct_array(dtype)):
raise ValueError(f"Expecting a numpy structured array dtype, not {dtype}")
14 changes: 11 additions & 3 deletions jaxtyping/_storage.py
Original file line number Diff line number Diff line change
@@ -112,9 +112,17 @@ def shape_str(memos) -> str:
def print_bindings():
"""Prints the values of the current jaxtyping axis bindings. Intended for debugging.
That is, whilst doing runtime type checking, so that e.g. the `foo` and `bar` of
`Float[Array, "foo bar"]` are assigned values -- this function will print out those
values.
For example, this can be used to find the values bound to `foo` and `bar` in
```python
@jaxtyped(typechecker=...)
def f(x: Float[Array, "foo bar"]):
print_bindings()
...
```
noting that these values are bounding during runtime typechecking, so that the
[`jaxtyping.jaxtyped`][] decorator is required.
**Arguments:**