diff --git a/src/op/builtin.cc b/src/op/builtin.cc index 149b0ac9a..afdfdb330 100644 --- a/src/op/builtin.cc +++ b/src/op/builtin.cc @@ -36,6 +36,7 @@ TVM_REGISTER_PASS_CONFIG_OPTION(kASTPrintEnable, Bool); TVM_REGISTER_PASS_CONFIG_OPTION(kLayoutVisualizationEnable, Bool); TVM_REGISTER_PASS_CONFIG_OPTION(kLayoutVisualizationFormats, String); TVM_REGISTER_PASS_CONFIG_OPTION(kDeviceCompileFlags, ffi::Array); +TVM_REGISTER_PASS_CONFIG_OPTION(kDisableDataRaceCheck, Bool); DataType cuTensorMapType() { return DataType::UInt(8, 128); } diff --git a/src/op/builtin.h b/src/op/builtin.h index 16586d4f9..52a784a26 100644 --- a/src/op/builtin.h +++ b/src/op/builtin.h @@ -61,6 +61,8 @@ static constexpr const char *kLayoutVisualizationEnable = static constexpr const char *kLayoutVisualizationFormats = "tl.layout_visualization_formats"; static constexpr const char *kDeviceCompileFlags = "tl.device_compile_flags"; +static constexpr const char *kDisableDataRaceCheck = + "tl.disable_data_race_check"; /*! * \brief Whether to disable thread storage synchronization diff --git a/src/transform/verify_parallel_loop.cc b/src/transform/verify_parallel_loop.cc index 1aa2c9347..8cf2e4887 100644 --- a/src/transform/verify_parallel_loop.cc +++ b/src/transform/verify_parallel_loop.cc @@ -194,7 +194,8 @@ struct ParallelLoopVerifier : public ConstrVisitor { } } void VisitStmt_(const BufferStoreNode *op) override { - if (reducers.count(op->buffer->data)) { + if (reducers.count(op->buffer->data) || op->buffer.scope() == "local.var" || + op->buffer.scope() == "local") { StmtExprVisitor::VisitStmt_(op); return; } @@ -227,14 +228,14 @@ struct ParallelLoopVerifier : public ConstrVisitor { } } if (!failed_vars.empty()) { - LOG(FATAL) << "Potential data race detected: `" << op->buffer - << op->indices << "`" - << "is written by multiple threads of loop vars: " - << failed_vars << ", Counterexample:\n" - << analyzer.z3_prover.GetModel(failed_var_expr) - << "If you believe this is a false positive, pass " - "`PassKey.TL_DISABLE_DATA_RACE_CHECK` to pass key to " - "disable this check."; + LOG(WARNING) << "Data race detected: `" << op->buffer << op->indices + << "`" + << "is written by multiple threads in loop " << failed_vars + << ", Example:\n" + << analyzer.z3_prover.GetModel(failed_var_expr) + << "If you believe this is a false positive, pass " + "`PassKey.TL_DISABLE_DATA_RACE_CHECK` to pass key to " + "disable this check."; } StmtExprVisitor::VisitStmt_(op); }