Skip to content

Commit

Permalink
Added a missing branch to mgpu.FragmentedArray.astype
Browse files Browse the repository at this point in the history
Previously, an unsupported cast produced a `NameError` instead.

PiperOrigin-RevId: 676010161
  • Loading branch information
superbobry authored and Google-ML-Automation committed Sep 18, 2024
1 parent 6236b8f commit 442e863
Showing 1 changed file with 2 additions and 0 deletions.
2 changes: 2 additions & 0 deletions jax/experimental/mosaic/gpu/fragmented_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -494,6 +494,8 @@ def astype(self, new_dtype: ir.Type):
convert = arith.sitofp
elif from_float and to_integer:
convert = arith.fptosi
else:
raise NotImplementedError(f"Unsupported conversion {cur_dtype} -> {new_dtype}")
new_registers = np.empty_like(self.registers)
match self.layout:
case WGMMAFragLayout():
Expand Down

0 comments on commit 442e863

Please sign in to comment.