diff --git a/paddle/fluid/framework/fleet/box_wrapper.cc b/paddle/fluid/framework/fleet/box_wrapper.cc index 464e475cbdd69..9b05528f4a949 100644 --- a/paddle/fluid/framework/fleet/box_wrapper.cc +++ b/paddle/fluid/framework/fleet/box_wrapper.cc @@ -731,8 +731,11 @@ class FloatMaskMetricMsg : public MetricMsg { "the predict data length should be consistent with " "the label data length")); auto cal = GetCalculator(); + auto pre_var_place = GetVarPlace(exe_scope, pred_varname_); + auto label_var_place = GetVarPlace(exe_scope, label_varname_); + auto mask_var_place = GetVarPlace(exe_scope, mask_varname_); cal->add_float_mask_data( - pred_data, label_data, mask_data, label_len, place); + pred_data, label_data, mask_data, label_len, pre_var_place, label_var_place, mask_var_place); } protected: @@ -775,8 +778,11 @@ class ContinueMaskMetricMsg : public MetricMsg { "the predict data length should be consistent with " "the label data length")); auto cal = GetCalculator(); + auto pre_var_place = GetVarPlace(exe_scope, pred_varname_); + auto label_var_place = GetVarPlace(exe_scope, label_varname_); + auto mask_var_place = GetVarPlace(exe_scope, mask_varname_); cal->add_continue_mask_data( - pred_data, label_data, mask_data, label_len, place); + pred_data, label_data, mask_data, label_len, pre_var_place, label_var_place, mask_var_place); } protected: @@ -913,8 +919,10 @@ class NanInfMetricMsg : public MetricMsg { "the predict data length should be consistent with " "the label data length")); auto cal = GetCalculator(); + auto pre_var_place = GetVarPlace(exe_scope, pred_varname_); + auto label_var_place = GetVarPlace(exe_scope, label_varname_); cal->add_nan_inf_data( - pred_data, label_data, label_len, place); + pred_data, label_data, label_len, pre_var_place, label_var_place); } }; diff --git a/paddle/fluid/framework/fleet/metrics.cc b/paddle/fluid/framework/fleet/metrics.cc index 03a7eef062987..ae7111b4d28b5 100644 --- a/paddle/fluid/framework/fleet/metrics.cc +++ b/paddle/fluid/framework/fleet/metrics.cc @@ -194,31 +194,34 @@ void BasicAucCalculator::add_mask_data(const float* d_pred, void BasicAucCalculator::add_float_mask_data(const float* d_pred, const float* d_label, const int64_t* d_mask, int batch_size, - const paddle::platform::Place& place) { - if (platform::is_gpu_place(place) || platform::is_xpu_place(place)) { - thread_local std::vector h_pred; - thread_local std::vector h_label; - thread_local std::vector h_mask; + const paddle::platform::Place& place_pred, + const paddle::platform::Place& place_label, + const paddle::platform::Place& place_mask) { + thread_local std::vector h_pred; + thread_local std::vector h_label; + thread_local std::vector h_mask; + const float* add_pred = d_pred; + const float* add_label = d_label; + const int64_t* add_mask = d_mask; + if (platform::is_gpu_place(place_pred) || platform::is_xpu_place(place_pred)) { h_pred.resize(batch_size); + SyncCopyD2H(h_pred.data(), d_pred, batch_size, place_pred); + add_pred = h_pred.data(); + } + if (platform::is_gpu_place(place_label) || platform::is_xpu_place(place_label)) { h_label.resize(batch_size); + SyncCopyD2H(h_label.data(), d_label, batch_size, place_label); + add_label = h_label.data(); + } + if (platform::is_gpu_place(place_mask) || platform::is_xpu_place(place_mask)) { h_mask.resize(batch_size); - - SyncCopyD2H(h_pred.data(), d_pred, batch_size, place); - SyncCopyD2H(h_label.data(), d_label, batch_size, place); - SyncCopyD2H(h_mask.data(), d_mask, batch_size, place); - - std::lock_guard lock(_table_mutex); - for (int i = 0; i < batch_size; ++i) { - if (h_mask[i]) { - add_unlock_data_with_float_label(h_pred[i], h_label[i]); - } - } - } else { - std::lock_guard lock(_table_mutex); - for (int i = 0; i < batch_size; ++i) { - if (d_mask[i]) { - add_unlock_data_with_float_label(d_pred[i], d_label[i]); - } + SyncCopyD2H(h_mask.data(), d_mask, batch_size, place_mask); + add_mask = h_mask.data(); + } + std::lock_guard lock(_table_mutex); + for (int i = 0; i < batch_size; ++i) { + if (add_mask[i]) { + add_unlock_data_with_float_label(add_pred[i], add_label[i]); } } } @@ -229,31 +232,34 @@ void BasicAucCalculator::add_continue_mask_data( const float* d_label, const int64_t* d_mask, int batch_size, - const paddle::platform::Place& place) { - if (platform::is_gpu_place(place) || platform::is_xpu_place(place)) { - thread_local std::vector h_pred; - thread_local std::vector h_label; - thread_local std::vector h_mask; + const paddle::platform::Place& place_pred, + const paddle::platform::Place& place_label, + const paddle::platform::Place& place_mask) { + thread_local std::vector h_pred; + thread_local std::vector h_label; + thread_local std::vector h_mask; + const float* add_pred = d_pred; + const float* add_label = d_label; + const int64_t* add_mask = d_mask; + if (platform::is_gpu_place(place_pred) || platform::is_xpu_place(place_pred)) { h_pred.resize(batch_size); + SyncCopyD2H(h_pred.data(), d_pred, batch_size, place_pred); + add_pred = h_pred.data(); + } + if (platform::is_gpu_place(place_label) || platform::is_xpu_place(place_label)) { h_label.resize(batch_size); + SyncCopyD2H(h_label.data(), d_label, batch_size, place_label); + add_label = h_label.data(); + } + if (platform::is_gpu_place(place_mask) || platform::is_xpu_place(place_mask)) { h_mask.resize(batch_size); - - SyncCopyD2H(h_pred.data(), d_pred, batch_size, place); - SyncCopyD2H(h_label.data(), d_label, batch_size, place); - SyncCopyD2H(h_mask.data(), d_mask, batch_size, place); - - std::lock_guard lock(_table_mutex); - for (int i = 0; i < batch_size; ++i) { - if (h_mask[i]) { - add_unlock_data_with_continue_label(h_pred[i], h_label[i]); - } - } - } else { - std::lock_guard lock(_table_mutex); - for (int i = 0; i < batch_size; ++i) { - if (d_mask[i]) { - add_unlock_data_with_continue_label(d_pred[i], d_label[i]); - } + SyncCopyD2H(h_mask.data(), d_mask, batch_size, place_mask); + add_mask = h_mask.data(); + } + std::lock_guard lock(_table_mutex); + for (int i = 0; i < batch_size; ++i) { + if (add_mask[i]) { + add_unlock_data_with_continue_label(add_pred[i], add_label[i]); } } } @@ -438,26 +444,26 @@ void BasicAucCalculator::add_uid_data(const float* d_pred, void BasicAucCalculator::add_nan_inf_data(const float* d_pred, const int64_t* d_label, int batch_size, - const paddle::platform::Place& place){ - if (platform::is_gpu_place(place) || platform::is_xpu_place(place)) { - thread_local std::vector h_pred; - thread_local std::vector h_label; + const paddle::platform::Place& place_pred, + const paddle::platform::Place& place_label){ + thread_local std::vector h_pred; + thread_local std::vector h_label; + const float* add_pred = d_pred; + const int64_t* add_label = d_label; + if (platform::is_gpu_place(place_pred) || platform::is_xpu_place(place_pred)) { h_pred.resize(batch_size); + SyncCopyD2H(h_pred.data(), d_pred, batch_size, place_pred); + add_pred = h_pred.data(); + } + if (platform::is_gpu_place(place_label) || platform::is_xpu_place(place_label)) { h_label.resize(batch_size); - SyncCopyD2H(h_pred.data(), d_pred, batch_size, place); - SyncCopyD2H(h_label.data(), d_label, batch_size, place); - - std::lock_guard lock(_table_mutex); - for (int i = 0; i < batch_size; ++i) { - add_nan_inf_unlock_data(h_pred[i], h_label[i]); - } - } else { - std::lock_guard lock(_table_mutex); - for (int i = 0; i < batch_size; ++i) { - add_nan_inf_unlock_data(d_pred[i], d_label[i]); - } + SyncCopyD2H(h_label.data(), d_label, batch_size, place_label); + add_label = h_label.data(); + } + std::lock_guard lock(_table_mutex); + for (int i = 0; i < batch_size; ++i) { + add_nan_inf_unlock_data(add_pred[i], add_label[i]); } - } void BasicAucCalculator::add_uid_unlock_data(double pred, diff --git a/paddle/fluid/framework/fleet/metrics.h b/paddle/fluid/framework/fleet/metrics.h index 6ce4b8d78b2ce..e7e227f222dfe 100644 --- a/paddle/fluid/framework/fleet/metrics.h +++ b/paddle/fluid/framework/fleet/metrics.h @@ -83,13 +83,17 @@ class BasicAucCalculator { const float* d_label, const int64_t* d_mask, int batch_size, - const paddle::platform::Place& place); + const paddle::platform::Place& place_pred, + const paddle::platform::Place& place_label, + const paddle::platform::Place& place_mask); // add continue data void add_continue_mask_data(const float* d_pred, const float* d_label, const int64_t* d_mask, int batch_size, - const paddle::platform::Place& place); + const paddle::platform::Place& place_pred, + const paddle::platform::Place& place_label, + const paddle::platform::Place& place_mask); // add sample data void add_sample_data(const float* d_pred, const int64_t* d_label, @@ -109,7 +113,8 @@ class BasicAucCalculator { void add_nan_inf_data(const float* d_pred, const int64_t* d_label, int batch_size, - const paddle::platform::Place& place); + const paddle::platform::Place& place_pred, + const paddle::platform::Place& place_label); void compute(); void computeContinueMsg();