Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[µTVM] Add platform timer and RPCTimeEvaluator to enable AutoTVM #6964

Merged
merged 5 commits into from
Dec 28, 2020
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions apps/bundle_deploy/bundle.c
Original file line number Diff line number Diff line change
Expand Up @@ -123,3 +123,9 @@ tvm_crt_error_t TVMPlatformMemoryAllocate(size_t num_bytes, DLContext ctx, void*
tvm_crt_error_t TVMPlatformMemoryFree(void* ptr, DLContext ctx) {
return g_memory_manager->Free(g_memory_manager, ptr, ctx);
}

/*! \brief Timer stub for this bundle: no timer is provided, so starting one always fails. */
tvm_crt_error_t TVMPlatformTimerStart() {
  const tvm_crt_error_t status = kTvmErrorFunctionCallNotImplemented;
  return status;
}

/*!
 * \brief Timer stub for this bundle: no timer is provided, so stopping one always fails.
 * \param elapsed_time_seconds Never written by this stub.
 * \return kTvmErrorFunctionCallNotImplemented, unconditionally.
 */
tvm_crt_error_t TVMPlatformTimerStop(double* elapsed_time_seconds) {
  const tvm_crt_error_t status = kTvmErrorFunctionCallNotImplemented;
  return status;
}
6 changes: 6 additions & 0 deletions apps/bundle_deploy/bundle_static.c
Original file line number Diff line number Diff line change
Expand Up @@ -124,3 +124,9 @@ tvm_crt_error_t TVMPlatformMemoryAllocate(size_t num_bytes, DLContext ctx, void*
tvm_crt_error_t TVMPlatformMemoryFree(void* ptr, DLContext ctx) {
return g_memory_manager->Free(g_memory_manager, ptr, ctx);
}

/*! \brief Timer stub for the static bundle: no timer is provided, so starting one always fails. */
tvm_crt_error_t TVMPlatformTimerStart() {
  const tvm_crt_error_t status = kTvmErrorFunctionCallNotImplemented;
  return status;
}

/*!
 * \brief Timer stub for the static bundle: no timer is provided, so stopping one always fails.
 * \param elapsed_time_seconds Never written by this stub.
 * \return kTvmErrorFunctionCallNotImplemented, unconditionally.
 */
tvm_crt_error_t TVMPlatformTimerStop(double* elapsed_time_seconds) {
  const tvm_crt_error_t status = kTvmErrorFunctionCallNotImplemented;
  return status;
}
7 changes: 7 additions & 0 deletions include/tvm/runtime/c_runtime_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -539,6 +539,13 @@ TVM_DLL int TVMObjectRetain(TVMObjectHandle obj);
*/
TVM_DLL int TVMObjectFree(TVMObjectHandle obj);

/*!
* \brief Free a TVMByteArray returned from TVMFuncCall, and associated memory.
* \param arr The TVMByteArray instance.
* \return 0 on success, -1 on failure.
*/
TVM_DLL int TVMByteArrayFree(TVMByteArray* arr);

/*!
* \brief Allocate a data space on device.
* \param ctx The device context to perform operation.
Expand Down
5 changes: 5 additions & 0 deletions include/tvm/runtime/crt/error_codes.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ typedef enum {
kTvmErrorCategoryGenerated = 6,
kTvmErrorCategoryGraphRuntime = 7,
kTvmErrorCategoryFunctionCall = 8,
kTvmErrorCategoryTimeEvaluator = 9,
} tvm_crt_error_category_t;

typedef enum {
Expand Down Expand Up @@ -77,6 +78,7 @@ typedef enum {
kTvmErrorPlatformMemoryManagerInitialized = DEFINE_TVM_CRT_ERROR(kTvmErrorCategoryPlatform, 1),
kTvmErrorPlatformShutdown = DEFINE_TVM_CRT_ERROR(kTvmErrorCategoryPlatform, 2),
kTvmErrorPlatformNoMemory = DEFINE_TVM_CRT_ERROR(kTvmErrorCategoryPlatform, 3),
kTvmErrorPlatformTimerBadState = DEFINE_TVM_CRT_ERROR(kTvmErrorCategoryPlatform, 4),

// Common error codes returned from generated functions.
kTvmErrorGeneratedInvalidStorageId = DEFINE_TVM_CRT_ERROR(kTvmErrorCategoryGenerated, 0),
Expand All @@ -91,6 +93,9 @@ typedef enum {
kTvmErrorFunctionCallWrongArgType = DEFINE_TVM_CRT_ERROR(kTvmErrorCategoryFunctionCall, 1),
kTvmErrorFunctionCallNotImplemented = DEFINE_TVM_CRT_ERROR(kTvmErrorCategoryFunctionCall, 2),

// Time Evaluator - times functions for use with debug runtime.
kTvmErrorTimeEvaluatorBadHandle = DEFINE_TVM_CRT_ERROR(kTvmErrorCategoryTimeEvaluator, 0),

// System errors are always negative integers; this mask indicates presence of a system error.
// Cast tvm_crt_error_t to a signed integer to interpret the negative error code.
kTvmErrorSystemErrorMask = (1 << (sizeof(int) * 4 - 1)),
Expand Down
19 changes: 19 additions & 0 deletions include/tvm/runtime/crt/platform.h
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,25 @@ tvm_crt_error_t TVMPlatformMemoryAllocate(size_t num_bytes, DLContext ctx, void*
* \return kTvmErrorNoError if successful; a descriptive error code otherwise.
*/
tvm_crt_error_t TVMPlatformMemoryFree(void* ptr, DLContext ctx);

/*! \brief Start a device timer.
*
* The device timer used must not be running.
*
* \return kTvmErrorNoError if successful; a descriptive error code otherwise.
*/
tvm_crt_error_t TVMPlatformTimerStart();

/*! \brief Stop the running device timer and get the elapsed time (in seconds).
*
* The device timer used must be running.
*
* \param elapsed_time_seconds Pointer to write elapsed time into.
*
* \return kTvmErrorNoError if successful; a descriptive error code otherwise.
*/
tvm_crt_error_t TVMPlatformTimerStop(double* elapsed_time_seconds);

#ifdef __cplusplus
} // extern "C"
#endif
Expand Down
1 change: 1 addition & 0 deletions python/tvm/micro/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from ..error import register_error
from .._ffi import get_global_func
from ..contrib import graph_runtime
from ..contrib.debugger import debug_runtime
from ..rpc import RPCSession
from .transport import IoTimeoutError
from .transport import TransportLogger
Expand Down
9 changes: 9 additions & 0 deletions src/runtime/c_runtime_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -420,6 +420,15 @@ int TVMFuncFree(TVMFunctionHandle func) {
API_END();
}

int TVMByteArrayFree(TVMByteArray* arr) {
  // The runtime store's own `ret_bytes` slot is thread-local storage and does not need explicit
  // deleting; any other array is released with `delete`.
  const bool is_runtime_store_slot = (arr == &TVMAPIRuntimeStore::Get()->ret_bytes);
  if (!is_runtime_store_slot) {
    delete arr;
  }
  return 0;
}

int TVMFuncCall(TVMFunctionHandle func, TVMValue* args, int* arg_type_codes, int num_args,
TVMValue* ret_val, int* ret_type_code) {
API_BEGIN();
Expand Down
150 changes: 142 additions & 8 deletions src/runtime/crt/common/crt_runtime_api.c
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,9 @@ static const TVMModule* registered_modules[TVM_CRT_MAX_REGISTERED_MODULES];
/*! \brief Passed as `module_index` to EncodeFunctionHandle. */
static const tvm_module_index_t kGlobalFuncModuleIndex = TVM_CRT_MAX_REGISTERED_MODULES;

/*! \brief Special module handle for return values from RPCTimeEvaluator. */
static const tvm_module_index_t kTimeEvaluatorModuleIndex = 0x7fff;

static int DecodeModuleHandle(TVMModuleHandle handle, tvm_module_index_t* out_module_index) {
tvm_module_index_t module_index;

Expand Down Expand Up @@ -185,20 +188,36 @@ static int DecodeFunctionHandle(TVMFunctionHandle handle, tvm_module_index_t* mo
(tvm_module_index_t)(((uintptr_t)handle) >> (sizeof(tvm_function_index_t) * 8));
unvalidated_module_index &= ~0x8000;

if (unvalidated_module_index > kGlobalFuncModuleIndex) {
TVMAPIErrorf("invalid module handle: index=%08x", unvalidated_module_index);
return -1;
} else if (unvalidated_module_index < kGlobalFuncModuleIndex &&
registered_modules[unvalidated_module_index] == NULL) {
TVMAPIErrorf("unregistered module: index=%08x", unvalidated_module_index);
return -1;
if (unvalidated_module_index != kTimeEvaluatorModuleIndex) {
if (unvalidated_module_index > kGlobalFuncModuleIndex) {
TVMAPIErrorf("invalid module handle: index=%08x", unvalidated_module_index);
return -1;
} else if (unvalidated_module_index < kGlobalFuncModuleIndex &&
registered_modules[unvalidated_module_index] == NULL) {
TVMAPIErrorf("unregistered module: index=%08x", unvalidated_module_index);
return -1;
}
}

*function_index = ((uint32_t)((uintptr_t)handle)) & ~0x8000;
*module_index = unvalidated_module_index;
return 0;
}

/*!
 * \brief Free a TVMByteArray descriptor together with the data buffer it points to.
 *
 * Both pieces are released through TVMPlatformMemoryFree using a {kDLCPU, 0} context.
 *
 * \param arr The array to free. NULL is accepted and treated as a no-op.
 * \return 0 on success; otherwise the first non-zero code reported by TVMPlatformMemoryFree.
 */
int TVMByteArrayFree(TVMByteArray* arr) {
  if (arr == NULL) {
    return 0;  // Nothing to release; avoids dereferencing a null descriptor.
  }

  DLContext ctx = {kDLCPU, 0};
  int data_status = TVMPlatformMemoryFree((void*)arr->data, ctx);

  // Release the descriptor even when freeing the data buffer failed, so one failure does not
  // also leak the descriptor.
  int arr_status = TVMPlatformMemoryFree((void*)arr, ctx);

  return data_status != 0 ? data_status : arr_status;
}

tvm_crt_error_t RunTimeEvaluator(tvm_function_index_t function_index, TVMValue* args,
int* type_codes, int num_args, TVMValue* ret_val,
int* ret_type_code);

int TVMFuncCall(TVMFunctionHandle func_handle, TVMValue* arg_values, int* type_codes, int num_args,
TVMValue* ret_val, int* ret_type_code) {
tvm_module_index_t module_index;
Expand All @@ -211,7 +230,10 @@ int TVMFuncCall(TVMFunctionHandle func_handle, TVMValue* arg_values, int* type_c
return -1;
}

if (module_index == kGlobalFuncModuleIndex) {
if (module_index == kTimeEvaluatorModuleIndex) {
return RunTimeEvaluator(function_index, arg_values, type_codes, num_args, ret_val,
ret_type_code);
} else if (module_index == kGlobalFuncModuleIndex) {
resource_handle = NULL;
registry = &global_func_registry.registry;
} else {
Expand Down Expand Up @@ -315,6 +337,8 @@ int TVMFuncFree(TVMFunctionHandle func) {
return 0;
}

int RPCTimeEvaluator(TVMValue* args, int* type_codes, int num_args, TVMValue* ret_val,
int* ret_type_code);
tvm_crt_error_t TVMInitializeRuntime() {
int idx = 0;
tvm_crt_error_t error = kTvmErrorNoError;
Expand Down Expand Up @@ -351,10 +375,120 @@ tvm_crt_error_t TVMInitializeRuntime() {
error = TVMFuncRegisterGlobal("tvm.rpc.server.ModuleGetFunction", &ModuleGetFunction, 0);
}

if (error == kTvmErrorNoError) {
error = TVMFuncRegisterGlobal("runtime.RPCTimeEvaluator", &RPCTimeEvaluator, 0);
}

if (error != kTvmErrorNoError) {
TVMPlatformMemoryFree(registry_backing_memory, ctx);
TVMPlatformMemoryFree(func_registry_memory, ctx);
}

return error;
}

/*! \brief Configuration captured by RPCTimeEvaluator and consumed by RunTimeEvaluator. */
typedef struct {
/*! \brief Incremented by RPCTimeEvaluator for each handle it issues; RunTimeEvaluator rejects
 * any handle whose function index differs from this value. */
uint16_t function_index;
/*! \brief Function resolved with TVMModGetFunction; this is what RunTimeEvaluator calls. */
TVMFunctionHandle func_to_time;
/*! \brief Device type and device id taken from the RPCTimeEvaluator arguments. */
TVMContext ctx;
/*! \brief Number of calls to func_to_time made between one timer start and timer stop. */
int number;
/*! \brief Number of mean timings RunTimeEvaluator writes into its result array. */
int repeat;
/*! \brief Minimum accumulated measured time, in milliseconds, per repeat; with 0 the timed
 * batch still runs once per repeat. */
int min_repeat_ms;
} time_evaluator_state_t;

/*! \brief Single shared instance: each RPCTimeEvaluator call overwrites the previous settings. */
static time_evaluator_state_t g_time_evaluator_state;

/*!
 * \brief Configure the time evaluator and return a function handle that runs it.
 *
 * Expected arguments (at least 8):
 *   args[0] module handle containing the function to time
 *   args[1] name of the function to time (string)
 *   args[2] device type (int)
 *   args[3] device id (int)
 *   args[4] number: calls per timer measurement (int)
 *   args[5] repeat: number of timings to produce (int)
 *   args[6] min_repeat_ms: minimum measured time per repeat, in milliseconds (int)
 *   args[7] a string; only its type code is checked here, its value is not read
 *
 * The shared evaluator state is updated only after the function lookup succeeds, so a failed
 * call leaves the configuration behind a previously returned handle untouched.
 *
 * \param args Argument values.
 * \param type_codes Type code for each entry in args.
 * \param num_args Number of entries in args and type_codes.
 * \param ret_val On success receives the evaluator function handle; NULL on failure.
 * \param ret_type_code On success set to kTVMPackedFuncHandle; kTVMNullptr on failure.
 * \return kTvmErrorNoError on success, kTvmErrorFunctionCallNumArguments or
 *     kTvmErrorFunctionCallWrongArgType for bad arguments, or the non-zero code returned by
 *     TVMModGetFunction.
 */
int RPCTimeEvaluator(TVMValue* args, int* type_codes, int num_args, TVMValue* ret_val,
                     int* ret_type_code) {
  // Start from a null result so every early return reports "no handle".
  ret_val[0].v_handle = NULL;
  ret_type_code[0] = kTVMNullptr;
  if (num_args < 8) {
    TVMAPIErrorf("not enough args");
    return kTvmErrorFunctionCallNumArguments;
  }
  if (type_codes[0] != kTVMModuleHandle || type_codes[1] != kTVMStr ||
      type_codes[2] != kTVMArgInt || type_codes[3] != kTVMArgInt || type_codes[4] != kTVMArgInt ||
      type_codes[5] != kTVMArgInt || type_codes[6] != kTVMArgInt || type_codes[7] != kTVMStr) {
    TVMAPIErrorf("one or more invalid arg types");
    return kTvmErrorFunctionCallWrongArgType;
  }

  TVMModuleHandle mod = (TVMModuleHandle)args[0].v_handle;
  const char* name = args[1].v_str;

  // Resolve into a local first: the shared state must not change if the lookup fails.
  TVMFunctionHandle func_to_time = NULL;
  int ret_code = TVMModGetFunction(mod, name, /* query_imports */ 0, &func_to_time);
  if (ret_code != 0) {
    return ret_code;
  }

  g_time_evaluator_state.func_to_time = func_to_time;
  g_time_evaluator_state.ctx.device_type = args[2].v_int64;
  g_time_evaluator_state.ctx.device_id = args[3].v_int64;
  g_time_evaluator_state.number = args[4].v_int64;
  g_time_evaluator_state.repeat = args[5].v_int64;
  g_time_evaluator_state.min_repeat_ms = args[6].v_int64;

  // A new index is issued per configuration; RunTimeEvaluator compares against it.
  g_time_evaluator_state.function_index++;
  ret_val[0].v_handle =
      EncodeFunctionHandle(kTimeEvaluatorModuleIndex, g_time_evaluator_state.function_index);
  ret_type_code[0] = kTVMPackedFuncHandle;
  return kTvmErrorNoError;
}

tvm_crt_error_t RunTimeEvaluator(tvm_function_index_t function_index, TVMValue* args,
int* type_codes, int num_args, TVMValue* ret_val,
int* ret_type_code) {
if (function_index != g_time_evaluator_state.function_index) {
return kTvmErrorTimeEvaluatorBadHandle;
}

// TODO(areusch): should *really* rethink needing to return doubles
DLContext result_byte_ctx = {kDLCPU, 0};
TVMByteArray* result_byte_arr;
tvm_crt_error_t err =
TVMPlatformMemoryAllocate(sizeof(TVMByteArray), result_byte_ctx, (void*)&result_byte_arr);
if (err != kTvmErrorNoError) {
return err;
}
size_t data_size = sizeof(double) * g_time_evaluator_state.repeat;
err = TVMPlatformMemoryAllocate(data_size, result_byte_ctx, (void*)&result_byte_arr->data);
if (err != kTvmErrorNoError) {
return err;
liangfu marked this conversation as resolved.
Show resolved Hide resolved
}
result_byte_arr->size = data_size;
double min_repeat_seconds = ((double)g_time_evaluator_state.min_repeat_ms) / 1000;
double* iter = (double*)result_byte_arr->data;
for (int i = 0; i < g_time_evaluator_state.repeat; i++) {
double repeat_res_seconds = 0.0;
int exec_count = 0;
// do-while structure ensures we run even when `min_repeat_ms` isn't set (i.e., is 0).
do {
tvm_crt_error_t ret_code = TVMPlatformTimerStart();
if (ret_code != kTvmErrorNoError) {
return ret_code;
}

for (int j = 0; j < g_time_evaluator_state.number; j++) {
ret_code = TVMFuncCall(g_time_evaluator_state.func_to_time, args, type_codes, num_args,
ret_val, ret_type_code);
if (ret_code != 0) {
return ret_code;
}
}
exec_count += g_time_evaluator_state.number;

double curr_res_seconds;
ret_code = TVMPlatformTimerStop(&curr_res_seconds);
if (ret_code != kTvmErrorNoError) {
return ret_code;
}
repeat_res_seconds += curr_res_seconds;
} while (repeat_res_seconds < min_repeat_seconds);
double mean_exec_seconds = repeat_res_seconds / exec_count;
*iter = mean_exec_seconds;
iter++;
}

*ret_type_code = kTVMBytes;
ret_val->v_handle = result_byte_arr;
return kTvmErrorNoError;
}
2 changes: 1 addition & 1 deletion src/runtime/crt/host/crt_config.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
#define TVM_CRT_MAX_REGISTERED_MODULES 2

/*! Size of the global function registry, in bytes. */
#define TVM_CRT_GLOBAL_FUNC_REGISTRY_SIZE_BYTES 200
#define TVM_CRT_GLOBAL_FUNC_REGISTRY_SIZE_BYTES 256
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

out of curiosity, what's the reason for increasing the registry size? just having a nice power of 2?


/*! Maximum packet size, in bytes, including the length header. */
#define TVM_CRT_MAX_PACKET_SIZE_BYTES 64000
Expand Down
23 changes: 12 additions & 11 deletions src/runtime/crt/host/main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -68,29 +68,30 @@ tvm_crt_error_t TVMPlatformMemoryFree(void* ptr, DLContext ctx) {
return memory_manager->Free(memory_manager, ptr, ctx);
}

high_resolution_clock::time_point g_utvm_start_time;
steady_clock::time_point g_utvm_start_time;
int g_utvm_timer_running = 0;

int TVMPlatformTimerStart() {
tvm_crt_error_t TVMPlatformTimerStart() {
if (g_utvm_timer_running) {
std::cerr << "timer already running" << std::endl;
return -1;
return kTvmErrorPlatformTimerBadState;
}
g_utvm_start_time = high_resolution_clock::now();
g_utvm_start_time = std::chrono::steady_clock::now();
g_utvm_timer_running = 1;
return 0;
return kTvmErrorNoError;
}

int TVMPlatformTimerStop(double* res_us) {
tvm_crt_error_t TVMPlatformTimerStop(double* elapsed_time_seconds) {
if (!g_utvm_timer_running) {
std::cerr << "timer not running" << std::endl;
return -1;
return kTvmErrorPlatformTimerBadState;
}
auto utvm_stop_time = high_resolution_clock::now();
duration<double, std::micro> time_span(utvm_stop_time - g_utvm_start_time);
*res_us = time_span.count();
auto utvm_stop_time = std::chrono::steady_clock::now();
std::chrono::microseconds time_span =
std::chrono::duration_cast<std::chrono::microseconds>(utvm_stop_time - g_utvm_start_time);
*elapsed_time_seconds = static_cast<double>(time_span.count()) / 1e6;
g_utvm_timer_running = 0;
return 0;
return kTvmErrorNoError;
}
}

Expand Down
Loading