Skip to content

Commit ccca00a

Browse files
sisleyliBin Li
andauthored
[BugFix]Ensure that bf16 arrays are created as expected (#16436)
Co-authored-by: Bin Li <[email protected]>
1 parent ffa404f commit ccca00a

File tree

1 file changed

+2
-0
lines changed

1 file changed

+2
-0
lines changed

python/tvm/runtime/ndarray.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,8 @@ def copyfrom(self, source_array):
176176
if (not source_array.flags["C_CONTIGUOUS"]) or (
177177
dtype == "bfloat16" or dtype != np_dtype_str
178178
):
179+
if dtype == "bfloat16":
180+
source_array = np.frombuffer(source_array.tobytes(), "uint16")
179181
source_array = np.ascontiguousarray(
180182
source_array, dtype="uint16" if dtype == "bfloat16" else dtype
181183
)

0 commit comments

Comments
 (0)