@@ -389,23 +389,23 @@ get_output(void *tflite_ctx, graph_execution_context ctx, uint32_t index,
389389 return too_large;
390390 }
391391
392- uint32_t model_tensor_size = 1 ;
393- for (int i = 0 ; i < (int )tensor->dims ->size ; ++i)
394- model_tensor_size *= (uint32_t )tensor->dims ->data [i];
395-
396- if (*output_tensor_size < model_tensor_size) {
397- NN_ERR_PRINTF (" Insufficient memory to copy tensor %d" , index);
398- return too_large;
399- }
400-
401392 if (tensor->quantization .type == kTfLiteNoQuantization ) {
402393 NN_DBG_PRINTF (" No quantization information" );
403- float *ot =
404- tfl_ctx->interpreters [ctx].interpreter ->typed_output_tensor <float >(
405- index);
406-
407- int size = model_tensor_size * sizeof (float );
408- bh_memcpy_s (output_tensor, size, ot, size);
394+ if (*output_tensor_size < tensor->bytes ) {
395+ NN_ERR_PRINTF (" Insufficient memory to copy tensor %d" , index);
396+ return too_large;
397+ }
398+ bh_memcpy_s (output_tensor, *output_tensor_size, tensor->data .data ,
399+ tensor->bytes );
400+ #if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
401+ *output_tensor_size = tensor->bytes ;
402+ #else
403+ /*
404+ * for now, maintain the bug-to-bug compatibility with the old abi,
405+ * where the size here is the number of fp32, not bytes.
406+ */
407+ *output_tensor_size = tensor->bytes / sizeof (float );
408+ #endif
409409 }
410410 else { // TODO: Assuming uint8 quantized networks.
411411 TfLiteAffineQuantization *quant_info =
@@ -414,6 +414,27 @@ get_output(void *tflite_ctx, graph_execution_context ctx, uint32_t index,
414414 NN_ERR_PRINTF (" Quantization per channel is not supported" );
415415 return runtime_error;
416416 }
417+
418+ uint32_t model_tensor_size = 1 ;
419+ for (int i = 0 ; i < (int )tensor->dims ->size ; ++i)
420+ model_tensor_size *= (uint32_t )tensor->dims ->data [i];
421+
422+ #if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
423+ if (*output_tensor_size / sizeof (float ) < model_tensor_size) {
424+ NN_ERR_PRINTF (" Insufficient memory to copy tensor %d" , index);
425+ return too_large;
426+ }
427+ #else
428+ /*
429+ * for now, maintain the bug-to-bug compatibility with the old abi,
430+ * where the size here is the number of fp32, not bytes.
431+ */
432+ if (*output_tensor_size < model_tensor_size) {
433+ NN_ERR_PRINTF (" Insufficient memory to copy tensor %d" , index);
434+ return too_large;
435+ }
436+ #endif
437+
417438 uint8_t *ot = tfl_ctx->interpreters [ctx]
418439 .interpreter ->typed_output_tensor <uint8_t >(index);
419440
@@ -426,9 +447,18 @@ get_output(void *tflite_ctx, graph_execution_context ctx, uint32_t index,
426447 for (uint32_t i = 0 ; i < model_tensor_size; ++i) {
427448 output_tensor_f[i] = (ot[i] - zero_point) * scale;
428449 }
450+
451+ #if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
452+ *output_tensor_size = model_tensor_size * sizeof (float );
453+ #else
454+ /*
455+ * for now, maintain the bug-to-bug compatibility with the old abi,
456+ * where the size here is the number of fp32, not bytes.
457+ */
458+ *output_tensor_size = model_tensor_size;
459+ #endif
429460 }
430461
431- *output_tensor_size = model_tensor_size;
432462 return success;
433463}
434464