Skip to content

Commit

Permalink
[TF-numpy] Sets check_dtypes default to True in _CheckAgainstNumpy …
Browse files Browse the repository at this point in the history
…and _CompileAndCheck.

PiperOrigin-RevId: 322896712
  • Loading branch information
wangpengmit authored and copybara-github committed Jul 23, 2020
1 parent 08bdb50 commit a50abb3
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions trax/tf_numpy/jax_tests/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -794,7 +794,7 @@ def assertMultiLineStrippedEqual(self, expected, what):
msg="Found\n{}\nExpecting\n{}".format(what, expected))

def _CheckAgainstNumpy(self, numpy_reference_op, lax_op, args_maker,
check_dtypes=False, tol=None):
check_dtypes=True, tol=None):
args = args_maker()
lax_ans = lax_op(*args)
numpy_ans = numpy_reference_op(*args)
Expand All @@ -805,7 +805,7 @@ def _CheckAgainstNumpy(self, numpy_reference_op, lax_op, args_maker,
def _CompileAndCheck(self,
fun,
args_maker,
check_dtypes,
check_dtypes=True,
rtol=None,
atol=None,
check_eval_on_shapes=True,
Expand Down

0 comments on commit a50abb3

Please sign in to comment.