diff --git a/jetstream/engine/engine_api.py b/jetstream/engine/engine_api.py index d8289b9b..bbf64bb4 100644 --- a/jetstream/engine/engine_api.py +++ b/jetstream/engine/engine_api.py @@ -82,6 +82,9 @@ class ResultTokens(abc.ABC): def copy_to_host_async(self: "ResultTokens") -> None: """Copy to host asynchronously.""" + # Do nothing for np array + if isinstance(self.data, np.ndarray): + return self.data.copy_to_host_async() def convert_to_numpy(self: "ResultTokens") -> "ResultTokens":