Skip to content

Commit

Permalink
Fix to issue #20616
Browse files Browse the repository at this point in the history
  • Loading branch information
rujutajoshi232 committed Dec 23, 2024
1 parent 3dd958b commit f328d50
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 1 deletion.
9 changes: 8 additions & 1 deletion keras/src/ops/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -6522,7 +6522,14 @@ def eye(N, M=None, k=0, dtype=None):
Returns:
Tensor with ones on the k-th diagonal and zeros elsewhere.
"""
Raises:
Error if N, M are not integer values.
"""
if not isinstance(N, int):
raise ValueError(f"N must be an integer, got {type(N).__name__}")
if M is not None and not isinstance(M, int):
raise ValueError(f"M must be an integer, got {type(M).__name__}")

return backend.numpy.eye(N, M=M, k=k, dtype=dtype)


Expand Down
2 changes: 2 additions & 0 deletions keras/src/ops/numpy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4718,6 +4718,8 @@ def test_eye(self):
# Test k < 0 and M < N and M - k > N
self.assertAllClose(knp.Eye(k=-2)(4, 3), np.eye(4, 3, k=-2))



def test_arange(self):
self.assertAllClose(knp.arange(3), np.arange(3))
self.assertAllClose(knp.arange(3, 7), np.arange(3, 7))
Expand Down

0 comments on commit f328d50

Please sign in to comment.