Code block from tensors.md gets ValueError with numpy>=2.0 #1968

Closed
leshabirukov opened this issue Nov 27, 2024 · 3 comments
Labels
module: IR Intermediate representation topic: documentation Improvements or additions to documentation

Comments

@leshabirukov
Contributor

`docs/intermediate_representation/tensors.md` contains this code block at line 157:

```{eval-rst}
.. exec_code::

    from onnxscript import ir
    import numpy as np

    array = np.array([0b1, 0b11], dtype=np.uint8)
    # The array is reinterpreted using the ml_dtypes package
    tensor = ir.Tensor(array, dtype=ir.DataType.FLOAT8E4M3FN)
    print(tensor)  # Tensor<FLOAT8E4M3FN,[2]>(array([0.00195312, 0.00585938], dtype='float8_e4m3fn'), name=None)
    print("tensor.numpy():", tensor.numpy())  # [0.00195312 0.00585938]

    # Compute
    times_100 = tensor.numpy() * 100
    print("times_100:", times_100)

    # Create a new tensor out of the new value; dtype must be specified
    new_tensor = ir.Tensor(times_100.view(np.uint8), dtype=ir.DataType.FLOAT8E4M3FN)
    # You can also directly create the tensor from the float8 array without specifying dtype
    # new_tensor = ir.Tensor(times_100)
    print("new_tensor:", new_tensor)  # Tensor<FLOAT8E4M3FN,[2]>(array([0.1875, 0.5625], dtype='float8_e4m3fn'), name=None)
    print("new_tensor == times_100", new_tensor.numpy() == times_100)  # array([ True,  True])
```

I tried to run it with `numpy==2.0.2` and got:

```
Traceback (most recent call last):
  File "D:\quick\2024\November\onnxscript-main\.my-stuff\tryit.py", line 24, in <module>
    print("new_tensor == times_100", new_tensor.numpy() == times_100)  # array([ True,  True])
ValueError: operands could not be broadcast together with shapes (8,) (2,) 
Tensor<FLOAT8E4M3FN,[2]>(array([0.00195312, 0.00585938], dtype='float8_e4m3fn'), name=None)
tensor.numpy(): [0.00195312 0.00585938]
times_100: [0.1953125 0.5859375]
new_tensor: Tensor<FLOAT8E4M3FN,[8]>(array([0, 0, 4, 1.75, 0, 0, 0.0546875, 1.875], dtype='float8_e4m3fn'), name=None)
```
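The `(8,)` vs `(2,)` shapes can be reproduced with plain NumPy. This is a minimal sketch, assuming the promoted `times_100` is a `float32` array (as explained later in the thread); the values are copied from the output above, and neither `onnxscript` nor `ml_dtypes` is needed for this part:

```python
import numpy as np

# Values printed for times_100 above; assumed to be float32 after promotion.
times_100 = np.array([0.1953125, 0.5859375], dtype=np.float32)

# Reinterpreting the raw bytes: each float32 element occupies 4 bytes,
# so 2 elements turn into 8 uint8 values.
as_bytes = times_100.view(np.uint8)
print(as_bytes.shape)  # (8,)

# An 8-element array and a 2-element array cannot be broadcast together.
try:
    np.broadcast_shapes(as_bytes.shape, times_100.shape)
except ValueError as err:
    print("ValueError:", err)
```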

With numpy 1.26.4 it works as expected.
It also works if I uncomment line 176:

`new_tensor = ir.Tensor(times_100)`

I actually hit this while trying to build the docs with a numpy version that is "wrong" for `requirements-dev.txt`. But I suppose that if onnxscript itself does not restrict the numpy version, neither should its docs, so this is an error. Attachment: pipfreeze.txt

@justinchuby justinchuby added bug Something isn't working contribution welcome We welcome code contributions for this module: IR Intermediate representation labels Nov 29, 2024
@justinchuby
Collaborator

Thanks! It appears the line

`ir.Tensor(times_100.view(np.uint8), dtype=ir.DataType.FLOAT8E4M3FN)`

has an encoding problem and does not create a correct tensor. We need to fix that.
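One way to catch this kind of silent reinterpretation early is to refuse the byte view unless the source dtype is already 1 byte wide. This is a minimal sketch; `view_as_uint8` is a hypothetical helper, not part of onnxscript or NumPy:

```python
import numpy as np


def view_as_uint8(array: np.ndarray) -> np.ndarray:
    """Hypothetical helper: reinterpret a 1-byte-per-element array as uint8."""
    if array.dtype.itemsize != 1:
        # A wider dtype (e.g. float32) would produce itemsize bytes per element
        # and silently change the element count.
        raise TypeError(
            f"expected a 1-byte dtype, got {array.dtype} ({array.dtype.itemsize} bytes)"
        )
    return array.view(np.uint8)


ok = view_as_uint8(np.array([1, 3], dtype=np.int8))
print(ok)  # [1 3]

try:
    view_as_uint8(np.array([0.1953125, 0.5859375], dtype=np.float32))
except TypeError as err:
    print("TypeError:", err)
```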

@justinchuby justinchuby added topic: documentation Improvements or additions to documentation and removed bug Something isn't working contribution welcome We welcome code contributions for this labels Nov 29, 2024
@justinchuby
Collaborator

The line in the documentation needs to be changed:

```diff
- times_100 = tensor.numpy() * 100
+ times_100 = tensor.numpy() * np.array(100, dtype=ml_dtypes.float8_e4m3fn)
```

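This is because in numpy 2 the result of `tensor.numpy() * 100` is promoted to `float32`. The following line, `times_100.view(np.uint8)`, then incorrectly reinterprets the 32-bit values as 8-bit ones, so the 2 elements become 8 bytes (hence the shapes `(8,)` and `(2,)` in the traceback).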
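The fix relies on the result dtype following the operand dtypes: multiplying by an array that already has the target dtype keeps the result in that dtype, so the byte view keeps one entry per byte of the original type. The sketch below uses NumPy's built-in `float16` as a stand-in for `ml_dtypes.float8_e4m3fn` so it runs without `ml_dtypes`; that substitution is an assumption, and it does not reproduce the numpy 2 promotion rule for `ml_dtypes` types itself:

```python
import numpy as np

# float16 is a stand-in for ml_dtypes.float8_e4m3fn (assumption: 2-byte items here).
tensor_values = np.array([0.001953125, 0.005859375], dtype=np.float16)

# Operand with a wider dtype (a 1-d float32 array): the result is promoted to float32.
promoted = tensor_values * np.full(1, 100, dtype=np.float32)
print(promoted.dtype, promoted.view(np.uint8).shape)  # float32 (8,)

# Operand with the same dtype: the result stays float16.
preserved = tensor_values * np.array(100, dtype=tensor_values.dtype)
print(preserved.dtype, preserved.view(np.uint8).shape)  # float16 (4,)
```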

@leshabirukov
Contributor Author

Great! Thanks for the fast response.
But for me it also needs:

```diff
+ import ml_dtypes

- print("new_tensor == times_100", new_tensor.numpy() == times_100)  # array([ True,  True])
+ print("new_tensor == times_100", new_tensor == times_100)  # array([ True,  True])
```
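As a side note, when a check should report a mismatch instead of raising, `np.array_equal` returns `False` for arrays of different shapes. A small sketch with plain NumPy `float32` arrays (an alternative check, not the change proposed in this thread):

```python
import numpy as np

expected = np.array([0.1875, 0.5625], dtype=np.float32)
wrong_shape = np.zeros(8, dtype=np.float32)

# With numpy 2, elementwise == raises for shapes (8,) and (2,), as in the traceback;
# array_equal compares shapes first and simply returns False.
print(np.array_equal(wrong_shape, expected))  # False
print(np.array_equal(expected.copy(), expected))  # True
```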
