Skip to content

Commit

Permalink
decoder
Browse files Browse the repository at this point in the history
  • Loading branch information
xiongzihua committed Jul 15, 2018
1 parent 869f9f5 commit 847400e
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,12 +59,12 @@ def decoder(pred):
mask1 = contain > 0.1 #大于阈值
mask2 = (contain==contain.max()) #we always select the best contain_prob what ever it>0.9
mask = (mask1+mask2).gt(0)
min_score,min_index = torch.min(mask,2) #每个cell只选最大概率的那个预测框
# min_score,min_index = torch.min(contain,2) #每个cell只选最大概率的那个预测框
for i in range(grid_num):
for j in range(grid_num):
for b in range(2):
index = min_index[i,j]
mask[i,j,index] = 0
# index = min_index[i,j]
# mask[i,j,index] = 0
if mask[i,j,b] == 1:
#print(i,j,b)
box = pred[i,j,b*5:b*5+4]
Expand Down Expand Up @@ -168,7 +168,7 @@ def predict_gpu(model,image_name,root_path=''):
model.load_state_dict(torch.load('best.pth'))
model.eval()
model.cuda()
image_name = 'person.jpg'
image_name = 'dog.jpg'
image = cv2.imread(image_name)
print('predicting...')
result = predict_gpu(model,image_name)
Expand Down

0 comments on commit 847400e

Please sign in to comment.