Skip to content

Commit

Permalink
Merge pull request #54 from hugofloresgarcia/master
Browse files Browse the repository at this point in the history
Ensure exported models output float tensors
  • Loading branch information
caillonantoine authored Jun 21, 2023
2 parents e144b5d + 8c6fa75 commit eab72a1
Showing 1 changed file with 2 additions and 0 deletions.
2 changes: 2 additions & 0 deletions python_tools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,8 @@ def register_method(
(f"Wrong output length for method \"{method_name}\", "
f"expected {test_buffer_size//out_ratio} "
f"got {y.shape[2]}"))
if y.dtype != torch.float:
raise ValueError(f"Output tensor must be of type float, got {y.dtype}")

if cc.MAX_BATCH_SIZE > 1:
logging.info(f"Testing method {method_name} with mc.nn~ API")
Expand Down

0 comments on commit eab72a1

Please sign in to comment.