
Improving inference time #321

Closed
tushardhadiwal opened this issue Jan 7, 2021 · 2 comments

@tushardhadiwal

Hi,

I came across this trick for improving inference time: opencv/opencv#14827 (comment)
When converting yolov4.cfg to a TensorRT engine file, the cfg file I used did not have nms-threshold=0 set in all of the [yolo] layers. I do see some code in this repo for NMS boxes, etc.

Will I get any speedup in yolov4 inference time if I add those values to the cfg file? Or is this already taken care of while building the TRT engine?

Thanks
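For reference, the tweak from the linked comment amounts to adding a line like the following to each [yolo] section of the cfg (sketched from that thread; I haven't verified the exact effect here):

```
[yolo]
...
nms-threshold=0
```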

@jkjung-avt
Owner

Thanks for raising the question, but that trick does not apply to the code in this repo.

If you can read CUDA code, you'd see that I don't do NMS in the yolo_layer plugin:

// CalDetection(): This kernel processes 1 yolo layer calculation. It
// distributes calculations so that 1 GPU thread would be responsible
// for each grid/anchor combination.
// NOTE: The output (x, y, w, h) are between 0.0 and 1.0
// (relative to original image width and height).
__global__ void CalDetection(const float *input, float *output,
                             int batch_size,
                             int yolo_width, int yolo_height,
                             int num_anchors, const float *anchors,
                             int num_classes, int input_w, int input_h,
                             float scale_x_y)
{
    int idx = threadIdx.x + blockDim.x * blockIdx.x;
    Detection* det = ((Detection*) output) + idx;
    int total_grids = yolo_width * yolo_height;
    if (idx >= batch_size * total_grids * num_anchors) return;

    int info_len = 5 + num_classes;
    //int batch_idx = idx / (total_grids * num_anchors);
    int group_idx = idx / total_grids;
    int anchor_idx = group_idx % num_anchors;
    const float* cur_input = input + group_idx * (info_len * total_grids) + (idx % total_grids);

    int class_id;
    float max_cls_logit = -CUDART_INF_F;  // minus infinity
    for (int i = 5; i < info_len; ++i) {
        float l = *(cur_input + i * total_grids);
        if (l > max_cls_logit) {
            max_cls_logit = l;
            class_id = i - 5;
        }
    }
    float max_cls_prob = sigmoidGPU(max_cls_logit);
    float box_prob = sigmoidGPU(*(cur_input + 4 * total_grids));
    //if (max_cls_prob < IGNORE_THRESH || box_prob < IGNORE_THRESH)
    //    return;

    int row = (idx % total_grids) / yolo_width;
    int col = (idx % total_grids) % yolo_width;

    det->bbox[0] = (col + scale_sigmoidGPU(*(cur_input + 0 * total_grids), scale_x_y)) / yolo_width;   // [0, 1]
    det->bbox[1] = (row + scale_sigmoidGPU(*(cur_input + 1 * total_grids), scale_x_y)) / yolo_height;  // [0, 1]
    det->bbox[2] = __expf(*(cur_input + 2 * total_grids)) * *(anchors + 2 * anchor_idx + 0) / input_w; // [0, 1]
    det->bbox[3] = __expf(*(cur_input + 3 * total_grids)) * *(anchors + 2 * anchor_idx + 1) / input_h; // [0, 1]
    det->bbox[0] -= det->bbox[2] / 2;  // shift from center to top-left corner
    det->bbox[1] -= det->bbox[3] / 2;

    det->det_confidence = box_prob;
    det->class_id = class_id;
    det->class_confidence = max_cls_prob;
}

Instead, I do NMS in Python, as shown below. The NMS code is written in pure Python/NumPy and can indeed be slow. You might improve FPS by optimizing this part (for example, by replacing it with C++ code).

# NMS
nms_detections = np.zeros((0, 7), dtype=detections.dtype)
for class_id in set(detections[:, 5]):
    idxs = np.where(detections[:, 5] == class_id)
    cls_detections = detections[idxs]
    keep = _nms_boxes(cls_detections, nms_threshold)
    nms_detections = np.concatenate(
        [nms_detections, cls_detections[keep]], axis=0)
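The helper `_nms_boxes` isn't shown in the snippet, but per-class NMS like this is typically the classic greedy IoU suppression. A minimal NumPy sketch of such a helper (the name `nms_boxes`, the sort-by-box-confidence choice, and the 7-column [x, y, w, h, box_conf, class_id, class_conf] layout with top-left (x, y) are assumptions here, not necessarily what the repo does):

```python
import numpy as np

def nms_boxes(detections, nms_threshold):
    """Greedy IoU-based NMS over detections of a single class.

    Each row: [x, y, w, h, box_conf, class_id, class_conf],
    with (x, y) the top-left corner. Returns indices to keep.
    """
    x, y = detections[:, 0], detections[:, 1]
    w, h = detections[:, 2], detections[:, 3]
    areas = w * h
    # Process boxes in order of decreasing box confidence
    order = detections[:, 4].argsort()[::-1]

    keep = []
    while order.size > 0:
        i = order[0]
        keep.append(int(i))
        # Intersection of box i with all remaining boxes
        xx1 = np.maximum(x[i], x[order[1:]])
        yy1 = np.maximum(y[i], y[order[1:]])
        xx2 = np.minimum(x[i] + w[i], x[order[1:]] + w[order[1:]])
        yy2 = np.minimum(y[i] + h[i], y[order[1:]] + h[order[1:]])
        inter = np.maximum(0.0, xx2 - xx1) * np.maximum(0.0, yy2 - yy1)
        iou = inter / (areas[i] + areas[order[1:]] - inter)
        # Drop boxes that overlap box i too much; keep the rest
        order = order[1:][iou <= nms_threshold]
    return np.array(keep, dtype=np.int64)
```

A C++ port of this loop is straightforward since it only needs sorting plus pairwise IoU, which is why moving it out of Python can help FPS when there are many candidate boxes.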

@BigJoon
Contributor

BigJoon commented Jan 10, 2021

Thanks for the detailed explanation. I'll think about optimizing the NMS part with C++ code.
I think it will be fun and rewarding work.
