Skip to content

Commit

Permalink
Merge pull request PaddlePaddle#59 from HuangShiqing/paddlebox
Browse files Browse the repository at this point in the history
abacus-aibox-991: fix the bug where add_float_mask_data was passed the wrong place (device) — use each tensor's own variable place for pred/label/mask
  • Loading branch information
HuangShiqing authored Mar 18, 2024
2 parents 2e9a451 + d5d1860 commit e581cbe
Show file tree
Hide file tree
Showing 3 changed files with 86 additions and 67 deletions.
14 changes: 11 additions & 3 deletions paddle/fluid/framework/fleet/box_wrapper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -731,8 +731,11 @@ class FloatMaskMetricMsg : public MetricMsg {
"the predict data length should be consistent with "
"the label data length"));
auto cal = GetCalculator();
auto pre_var_place = GetVarPlace(exe_scope, pred_varname_);
auto label_var_place = GetVarPlace(exe_scope, label_varname_);
auto mask_var_place = GetVarPlace(exe_scope, mask_varname_);
cal->add_float_mask_data(
pred_data, label_data, mask_data, label_len, place);
pred_data, label_data, mask_data, label_len, pre_var_place, label_var_place, mask_var_place);
}

protected:
Expand Down Expand Up @@ -775,8 +778,11 @@ class ContinueMaskMetricMsg : public MetricMsg {
"the predict data length should be consistent with "
"the label data length"));
auto cal = GetCalculator();
auto pre_var_place = GetVarPlace(exe_scope, pred_varname_);
auto label_var_place = GetVarPlace(exe_scope, label_varname_);
auto mask_var_place = GetVarPlace(exe_scope, mask_varname_);
cal->add_continue_mask_data(
pred_data, label_data, mask_data, label_len, place);
pred_data, label_data, mask_data, label_len, pre_var_place, label_var_place, mask_var_place);
}

protected:
Expand Down Expand Up @@ -913,8 +919,10 @@ class NanInfMetricMsg : public MetricMsg {
"the predict data length should be consistent with "
"the label data length"));
auto cal = GetCalculator();
auto pre_var_place = GetVarPlace(exe_scope, pred_varname_);
auto label_var_place = GetVarPlace(exe_scope, label_varname_);
cal->add_nan_inf_data(
pred_data, label_data, label_len, place);
pred_data, label_data, label_len, pre_var_place, label_var_place);
}
};

Expand Down
128 changes: 67 additions & 61 deletions paddle/fluid/framework/fleet/metrics.cc
Original file line number Diff line number Diff line change
Expand Up @@ -194,31 +194,34 @@ void BasicAucCalculator::add_mask_data(const float* d_pred,
void BasicAucCalculator::add_float_mask_data(const float* d_pred,
const float* d_label,
const int64_t* d_mask, int batch_size,
const paddle::platform::Place& place) {
if (platform::is_gpu_place(place) || platform::is_xpu_place(place)) {
thread_local std::vector<float> h_pred;
thread_local std::vector<float> h_label;
thread_local std::vector<int64_t> h_mask;
const paddle::platform::Place& place_pred,
const paddle::platform::Place& place_label,
const paddle::platform::Place& place_mask) {
thread_local std::vector<float> h_pred;
thread_local std::vector<float> h_label;
thread_local std::vector<int64_t> h_mask;
const float* add_pred = d_pred;
const float* add_label = d_label;
const int64_t* add_mask = d_mask;
if (platform::is_gpu_place(place_pred) || platform::is_xpu_place(place_pred)) {
h_pred.resize(batch_size);
SyncCopyD2H(h_pred.data(), d_pred, batch_size, place_pred);
add_pred = h_pred.data();
}
if (platform::is_gpu_place(place_label) || platform::is_xpu_place(place_label)) {
h_label.resize(batch_size);
SyncCopyD2H(h_label.data(), d_label, batch_size, place_label);
add_label = h_label.data();
}
if (platform::is_gpu_place(place_mask) || platform::is_xpu_place(place_mask)) {
h_mask.resize(batch_size);

SyncCopyD2H(h_pred.data(), d_pred, batch_size, place);
SyncCopyD2H(h_label.data(), d_label, batch_size, place);
SyncCopyD2H(h_mask.data(), d_mask, batch_size, place);

std::lock_guard<std::mutex> lock(_table_mutex);
for (int i = 0; i < batch_size; ++i) {
if (h_mask[i]) {
add_unlock_data_with_float_label(h_pred[i], h_label[i]);
}
}
} else {
std::lock_guard<std::mutex> lock(_table_mutex);
for (int i = 0; i < batch_size; ++i) {
if (d_mask[i]) {
add_unlock_data_with_float_label(d_pred[i], d_label[i]);
}
SyncCopyD2H(h_mask.data(), d_mask, batch_size, place_mask);
add_mask = h_mask.data();
}
std::lock_guard<std::mutex> lock(_table_mutex);
for (int i = 0; i < batch_size; ++i) {
if (add_mask[i]) {
add_unlock_data_with_float_label(add_pred[i], add_label[i]);
}
}
}
Expand All @@ -229,31 +232,34 @@ void BasicAucCalculator::add_continue_mask_data(
const float* d_label,
const int64_t* d_mask,
int batch_size,
const paddle::platform::Place& place) {
if (platform::is_gpu_place(place) || platform::is_xpu_place(place)) {
thread_local std::vector<float> h_pred;
thread_local std::vector<float> h_label;
thread_local std::vector<int64_t> h_mask;
const paddle::platform::Place& place_pred,
const paddle::platform::Place& place_label,
const paddle::platform::Place& place_mask) {
thread_local std::vector<float> h_pred;
thread_local std::vector<float> h_label;
thread_local std::vector<int64_t> h_mask;
const float* add_pred = d_pred;
const float* add_label = d_label;
const int64_t* add_mask = d_mask;
if (platform::is_gpu_place(place_pred) || platform::is_xpu_place(place_pred)) {
h_pred.resize(batch_size);
SyncCopyD2H(h_pred.data(), d_pred, batch_size, place_pred);
add_pred = h_pred.data();
}
if (platform::is_gpu_place(place_label) || platform::is_xpu_place(place_label)) {
h_label.resize(batch_size);
SyncCopyD2H(h_label.data(), d_label, batch_size, place_label);
add_label = h_label.data();
}
if (platform::is_gpu_place(place_mask) || platform::is_xpu_place(place_mask)) {
h_mask.resize(batch_size);

SyncCopyD2H(h_pred.data(), d_pred, batch_size, place);
SyncCopyD2H(h_label.data(), d_label, batch_size, place);
SyncCopyD2H(h_mask.data(), d_mask, batch_size, place);

std::lock_guard<std::mutex> lock(_table_mutex);
for (int i = 0; i < batch_size; ++i) {
if (h_mask[i]) {
add_unlock_data_with_continue_label(h_pred[i], h_label[i]);
}
}
} else {
std::lock_guard<std::mutex> lock(_table_mutex);
for (int i = 0; i < batch_size; ++i) {
if (d_mask[i]) {
add_unlock_data_with_continue_label(d_pred[i], d_label[i]);
}
SyncCopyD2H(h_mask.data(), d_mask, batch_size, place_mask);
add_mask = h_mask.data();
}
std::lock_guard<std::mutex> lock(_table_mutex);
for (int i = 0; i < batch_size; ++i) {
if (add_mask[i]) {
add_unlock_data_with_continue_label(add_pred[i], add_label[i]);
}
}
}
Expand Down Expand Up @@ -438,26 +444,26 @@ void BasicAucCalculator::add_uid_data(const float* d_pred,
void BasicAucCalculator::add_nan_inf_data(const float* d_pred,
const int64_t* d_label,
int batch_size,
const paddle::platform::Place& place){
if (platform::is_gpu_place(place) || platform::is_xpu_place(place)) {
thread_local std::vector<float> h_pred;
thread_local std::vector<int64_t> h_label;
const paddle::platform::Place& place_pred,
const paddle::platform::Place& place_label){
thread_local std::vector<float> h_pred;
thread_local std::vector<int64_t> h_label;
const float* add_pred = d_pred;
const int64_t* add_label = d_label;
if (platform::is_gpu_place(place_pred) || platform::is_xpu_place(place_pred)) {
h_pred.resize(batch_size);
SyncCopyD2H(h_pred.data(), d_pred, batch_size, place_pred);
add_pred = h_pred.data();
}
if (platform::is_gpu_place(place_label) || platform::is_xpu_place(place_label)) {
h_label.resize(batch_size);
SyncCopyD2H(h_pred.data(), d_pred, batch_size, place);
SyncCopyD2H(h_label.data(), d_label, batch_size, place);

std::lock_guard<std::mutex> lock(_table_mutex);
for (int i = 0; i < batch_size; ++i) {
add_nan_inf_unlock_data(h_pred[i], h_label[i]);
}
} else {
std::lock_guard<std::mutex> lock(_table_mutex);
for (int i = 0; i < batch_size; ++i) {
add_nan_inf_unlock_data(d_pred[i], d_label[i]);
}
SyncCopyD2H(h_label.data(), d_label, batch_size, place_label);
add_label = h_label.data();
}
std::lock_guard<std::mutex> lock(_table_mutex);
for (int i = 0; i < batch_size; ++i) {
add_nan_inf_unlock_data(add_pred[i], add_label[i]);
}

}

void BasicAucCalculator::add_uid_unlock_data(double pred,
Expand Down
11 changes: 8 additions & 3 deletions paddle/fluid/framework/fleet/metrics.h
Original file line number Diff line number Diff line change
Expand Up @@ -83,13 +83,17 @@ class BasicAucCalculator {
const float* d_label,
const int64_t* d_mask,
int batch_size,
const paddle::platform::Place& place);
const paddle::platform::Place& place_pred,
const paddle::platform::Place& place_label,
const paddle::platform::Place& place_mask);
// add continue data
void add_continue_mask_data(const float* d_pred,
const float* d_label,
const int64_t* d_mask,
int batch_size,
const paddle::platform::Place& place);
const paddle::platform::Place& place_pred,
const paddle::platform::Place& place_label,
const paddle::platform::Place& place_mask);
// add sample data
void add_sample_data(const float* d_pred,
const int64_t* d_label,
Expand All @@ -109,7 +113,8 @@ class BasicAucCalculator {
void add_nan_inf_data(const float* d_pred,
const int64_t* d_label,
int batch_size,
const paddle::platform::Place& place);
const paddle::platform::Place& place_pred,
const paddle::platform::Place& place_label);

void compute();
void computeContinueMsg();
Expand Down

0 comments on commit e581cbe

Please sign in to comment.