diff --git a/src/profiler/custom_op_profiler.h b/src/profiler/custom_op_profiler.h index 0bafeab36d23..93ae3a64236f 100644 --- a/src/profiler/custom_op_profiler.h +++ b/src/profiler/custom_op_profiler.h @@ -54,13 +54,13 @@ class CustomOpProfiler { */ void OnCustomBegin(const std::string& op_type) { profiler::Profiler *profiler = profiler::Profiler::Get(); - if (!profiler->IsProfiling(profiler::Profiler::kImperative)) { - return; - } const Tid tid = std::this_thread::get_id(); const std::string task_name = MakePythonCodeName(op_type); std::lock_guard lock(mutex_); tid_to_op_type_[tid] = op_type; + if (!profiler->IsProfiling(profiler::Profiler::kImperative)) { + return; + } tasks_[tid] = std::make_unique(task_name.c_str(), &custom_op_domain); tasks_[tid]->start(); } @@ -72,13 +72,13 @@ class CustomOpProfiler { void OnCustomEnd() { const Tid tid = std::this_thread::get_id(); std::lock_guard lock(mutex_); + tid_to_op_type_.erase(tid); // this can happen if we are not profiling if (tasks_.find(tid) == tasks_.end()) { return; } tasks_[tid]->stop(); tasks_.erase(tid); - tid_to_op_type_.erase(tid); } /*! diff --git a/tests/python/unittest/test_profiler.py b/tests/python/unittest/test_profiler.py index 6317640121e1..ab7d29f104ca 100644 --- a/tests/python/unittest/test_profiler.py +++ b/tests/python/unittest/test_profiler.py @@ -411,7 +411,6 @@ def create_operator(self, ctx, shapes, dtypes): mx.nd.waitall() profiler.dump(False) debug_str = profiler.dumps(format = 'json') - print(debug_str) target_dict = json.loads(debug_str) ''' We are calling _plus_scalar within MyAdd1 and MyAdd2 and outside both the custom