# pytorch-YOLO-v1

An experiment with YOLO-v1, including training and testing.

[中文 (Chinese)](中文.md)

**This is a testing repository; it cannot reproduce the results of the original [paper](https://arxiv.org/pdf/1506.02640.pdf).**

**I will start a discussion later; if you are interested, you are welcome to contact me. If you find any bugs, please let me know.**

I wrote this code for learning purposes. In yoloLoss.py, only the forward pass is implemented; thanks to the autograd mechanism, the backward pass is handled automatically.

### 1. Dependencies
- pytorch 0.2.0_2
- opencv
- visdom
- tqdm

### 2. Prepare

1. Download the voc2012train dataset.
2. Download the voc2007test dataset.
3. Convert the xml annotations to a txt file so that dataset.py can read them: put xml_2_txt.py in the same folder as the voc dataset, or change the *Annotations* path in xml_2_txt.py (a sketch of this step is shown after this list).
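
For reference, dataset.py parses each txt line as `image_name num x1 y1 x2 y2 class ...`. Below is a minimal sketch of the conversion, assuming the standard VOC `Annotations/*.xml` layout; the helper names and the alphabetical class order are my own assumptions, and the repository's xml_2_txt.py may differ in such details.

```python
# Hypothetical sketch of the xml -> txt step, not the repository's xml_2_txt.py.
import os
import xml.etree.ElementTree as ET

# Conventional alphabetical VOC class order; xml_2_txt.py may order them differently.
VOC_CLASSES = ('aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car',
               'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike',
               'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor')

def parse_annotation(xml_path):
    '''Return (image filename, [[x1, y1, x2, y2, class_index], ...]) for one VOC xml file.'''
    tree = ET.parse(xml_path)
    filename = tree.find('filename').text
    objects = []
    for obj in tree.iter('object'):
        name = obj.find('name').text
        box = obj.find('bndbox')
        coords = [int(float(box.find(k).text)) for k in ('xmin', 'ymin', 'xmax', 'ymax')]
        objects.append(coords + [VOC_CLASSES.index(name)])
    return filename, objects

def convert(ann_dir, out_txt):
    '''Write one line per image: image_name num x1 y1 x2 y2 c ...'''
    with open(out_txt, 'w') as out:
        for xml_file in sorted(os.listdir(ann_dir)):
            filename, objects = parse_annotation(os.path.join(ann_dir, xml_file))
            fields = [filename, str(len(objects))] + [str(v) for o in objects for v in o]
            out.write(' '.join(fields) + '\n')

if __name__ == '__main__':
    convert('Annotations', 'voc2012train.txt')
```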

### 3. Train
Run `python train.py`

*Be careful:* 1. change the image file path; 2. I recommend installing [visdom](https://github.com/facebookresearch/visdom) and running it.

### 4. Evaluation
Run `python eval_voc.py`

*Be careful:* 1. change the image file path.

### 5. Discussion

1. Overfitting

I plotted the training and testing loss curves, and the overfitting is obvious. I applied a lot of data augmentation to counter it, but it helped little.



2. Activation function in the last fc layer

The original paper uses a linear activation function for the final layer, whose output lies in (-inf, +inf), while the target is in [0, 1], so I used a sigmoid activation function instead. I think this is more reasonable; if you know the details about this, please let me know.

Update: I ran another experiment. I used the linear activation, set the learning rate carefully as in the paper, and replaced sqrt(w), sqrt(h) with (w, h) to avoid the NaN problem, but the result was not good either. A minimal sketch of the sigmoid variant is shown below.
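
A minimal sketch of the sigmoid variant described in point 2, assuming the detection head is a fully connected layer on top of flattened backbone features (the `in_features` size is a placeholder; the actual model definition lives elsewhere in this repository):

```python
import torch
import torch.nn as nn

class DetectionHead(nn.Module):
    '''Hypothetical final layer: 7x7 grid, 30 values per cell (2 boxes * 5 + 20 classes).'''
    def __init__(self, in_features=4096):  # in_features is an assumed placeholder
        super(DetectionHead, self).__init__()
        self.fc = nn.Linear(in_features, 7 * 7 * 30)

    def forward(self, x):
        x = self.fc(x)                # linear output in (-inf, +inf)
        x = torch.sigmoid(x)          # squash into [0, 1] to match the target range
        return x.view(-1, 7, 7, 30)
```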

### 6. Results
1. On the train dataset, mAP is about 0.5. Some result images are in trainIMgresult.



2. On the test dataset, mAP is about 0.2. Some result images are in testIMGresult; the test results are not good.



**dataset.py**

#encoding:utf-8
#
#created by xiongzihua
#
'''
Annotation txt file: each line is
image_name.jpg num x1 y1 x2 y2 c x1 y1 x2 y2 c
i.e. this example line describes an image containing two objects.
'''
import os

import random
import numpy as np

import torch
import torch.utils.data as data
import torchvision.transforms as transforms

import cv2

class yoloDataset(data.Dataset):
    image_size = 224

    def __init__(self, root, list_file, train, transform):
        print('data init')
        self.root = root
        self.train = train
        self.transform = transform
        self.fnames = []
        self.boxes = []
        self.labels = []
        self.mean = (123, 117, 104)  # RGB

        with open(list_file) as f:
            lines = f.readlines()

        for line in lines:
            splited = line.strip().split()
            self.fnames.append(splited[0])
            num_faces = int(splited[1])
            box = []
            label = []
            for i in range(num_faces):
                x = float(splited[2+5*i])
                y = float(splited[3+5*i])
                x2 = float(splited[4+5*i])
                y2 = float(splited[5+5*i])
                c = splited[6+5*i]
                box.append([x, y, x2, y2])
                label.append(int(c)+1)
            self.boxes.append(torch.Tensor(box))
            self.labels.append(torch.LongTensor(label))
        self.num_samples = len(self.boxes)

    def __getitem__(self, idx):
        fname = self.fnames[idx]
        img = cv2.imread(os.path.join(self.root, fname))
        boxes = self.boxes[idx].clone()
        labels = self.labels[idx].clone()

        if self.train:
            #img = self.random_bright(img)
            img, boxes = self.random_flip(img, boxes)
            img, boxes = self.randomScale(img, boxes)
            img = self.randomBlur(img)
            img = self.RandomBrightness(img)
            img = self.RandomHue(img)
            img = self.RandomSaturation(img)
            img, boxes, labels = self.randomShift(img, boxes, labels)

        h, w, _ = img.shape
        boxes /= torch.Tensor([w, h, w, h]).expand_as(boxes)
        img = self.BGR2RGB(img)  # because pytorch pretrained models use RGB
        img = self.subMean(img, self.mean)  # subtract the channel mean
        img = cv2.resize(img, (self.image_size, self.image_size))
        target = self.encoder(boxes, labels)  # 7x7x30
        for t in self.transform:
            img = t(img)

        return img, target

    def __len__(self):
        return self.num_samples

    def encoder(self, boxes, labels):
        '''
        boxes (tensor) [[x1,y1,x2,y2],[]]
        labels (tensor) [...]
        return 7x7x30
        '''
        target = torch.zeros((7, 7, 30))
        cell_size = 1./7
        wh = boxes[:, 2:] - boxes[:, :2]
        cxcy = (boxes[:, 2:] + boxes[:, :2]) / 2
        for i in range(cxcy.size()[0]):
            cxcy_sample = cxcy[i]
            ij = (cxcy_sample / cell_size).ceil() - 1  # grid cell containing the box center
            target[int(ij[1]), int(ij[0]), 4] = 1
            target[int(ij[1]), int(ij[0]), 9] = 1
            target[int(ij[1]), int(ij[0]), int(labels[i])+9] = 1
            xy = ij * cell_size  # normalized coordinates of the matched cell's top-left corner
            delta_xy = (cxcy_sample - xy) / cell_size
            target[int(ij[1]), int(ij[0]), 2:4] = wh[i]
            target[int(ij[1]), int(ij[0]), :2] = delta_xy
            target[int(ij[1]), int(ij[0]), 7:9] = wh[i]
            target[int(ij[1]), int(ij[0]), 5:7] = delta_xy
        return target
    def BGR2RGB(self, img):
        return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    def BGR2HSV(self, img):
        return cv2.cvtColor(img, cv2.COLOR_BGR2HSV)

    def HSV2BGR(self, img):
        return cv2.cvtColor(img, cv2.COLOR_HSV2BGR)

    def RandomBrightness(self, bgr):
        if random.random() < 0.5:
            hsv = self.BGR2HSV(bgr)
            h, s, v = cv2.split(hsv)
            adjust = random.choice([0.5, 1.5])
            v = v * adjust
            v = np.clip(v, 0, 255).astype(hsv.dtype)
            hsv = cv2.merge((h, s, v))
            bgr = self.HSV2BGR(hsv)
        return bgr

    def RandomSaturation(self, bgr):
        if random.random() < 0.5:
            hsv = self.BGR2HSV(bgr)
            h, s, v = cv2.split(hsv)
            adjust = random.choice([0.5, 1.5])
            s = s * adjust
            s = np.clip(s, 0, 255).astype(hsv.dtype)
            hsv = cv2.merge((h, s, v))
            bgr = self.HSV2BGR(hsv)
        return bgr

    def RandomHue(self, bgr):
        if random.random() < 0.5:
            hsv = self.BGR2HSV(bgr)
            h, s, v = cv2.split(hsv)
            adjust = random.choice([0.5, 1.5])
            h = h * adjust
            h = np.clip(h, 0, 255).astype(hsv.dtype)
            hsv = cv2.merge((h, s, v))
            bgr = self.HSV2BGR(hsv)
        return bgr

    def randomBlur(self, bgr):
        if random.random() < 0.5:
            bgr = cv2.blur(bgr, (5, 5))
        return bgr

    def randomShift(self, bgr, boxes, labels):
        # random translation
        center = (boxes[:, 2:] + boxes[:, :2]) / 2
        if random.random() < 0.5:
            height, width, c = bgr.shape
            after_shift_image = np.zeros((height, width, c), dtype=bgr.dtype)
            after_shift_image[:, :, :] = (104, 117, 123)  # bgr
            shift_x = random.uniform(-width*0.2, width*0.2)
            shift_y = random.uniform(-height*0.2, height*0.2)
            #print(bgr.shape,shift_x,shift_y)
            # translate the original image
            if shift_x >= 0 and shift_y >= 0:
                after_shift_image[int(shift_y):, int(shift_x):, :] = bgr[:height-int(shift_y), :width-int(shift_x), :]
            elif shift_x >= 0 and shift_y < 0:
                after_shift_image[:height+int(shift_y), int(shift_x):, :] = bgr[-int(shift_y):, :width-int(shift_x), :]
            elif shift_x < 0 and shift_y >= 0:
                after_shift_image[int(shift_y):, :width+int(shift_x), :] = bgr[:height-int(shift_y), -int(shift_x):, :]
            elif shift_x < 0 and shift_y < 0:
                after_shift_image[:height+int(shift_y), :width+int(shift_x), :] = bgr[-int(shift_y):, -int(shift_x):, :]

            shift_xy = torch.FloatTensor([[int(shift_x), int(shift_y)]]).expand_as(center)
            center = center + shift_xy
            mask1 = (center[:, 0] > 0) & (center[:, 0] < width)
            mask2 = (center[:, 1] > 0) & (center[:, 1] < height)
            mask = (mask1 & mask2).view(-1, 1)
            boxes_in = boxes[mask.expand_as(boxes)].view(-1, 4)
            if len(boxes_in) == 0:
                return bgr, boxes, labels
            box_shift = torch.FloatTensor([[int(shift_x), int(shift_y), int(shift_x), int(shift_y)]]).expand_as(boxes_in)
            boxes_in = boxes_in + box_shift
            labels_in = labels[mask.view(-1)]
            return after_shift_image, boxes_in, labels_in
        return bgr, boxes, labels

    def randomScale(self, bgr, boxes):
        # keep the height fixed and scale the width by 0.6-1.4 to deform the image
        if random.random() < 0.5:
            scale = random.uniform(0.6, 1.4)
            height, width, c = bgr.shape
            bgr = cv2.resize(bgr, (int(width*scale), height))
            scale_tensor = torch.FloatTensor([[scale, 1, scale, 1]]).expand_as(boxes)
            boxes = boxes * scale_tensor
            return bgr, boxes
        return bgr, boxes

    def randomCrop(self, bgr, boxes, labels):
        if random.random() < 0.5:
            center = (boxes[:, 2:] + boxes[:, :2]) / 2
            height, width, c = bgr.shape
            h = random.uniform(0.6*height, height)
            w = random.uniform(0.6*width, width)
            x = random.uniform(0, width-w)
            y = random.uniform(0, height-h)
            x, y, h, w = int(x), int(y), int(h), int(w)

            center = center - torch.FloatTensor([[x, y]]).expand_as(center)
            mask1 = (center[:, 0] > 0) & (center[:, 0] < w)
            mask2 = (center[:, 1] > 0) & (center[:, 1] < h)
            mask = (mask1 & mask2).view(-1, 1)

            boxes_in = boxes[mask.expand_as(boxes)].view(-1, 4)
            if len(boxes_in) == 0:
                return bgr, boxes, labels
            box_shift = torch.FloatTensor([[x, y, x, y]]).expand_as(boxes_in)

            boxes_in = boxes_in - box_shift
            labels_in = labels[mask.view(-1)]
            img_cropped = bgr[y:y+h, x:x+w, :]
            return img_cropped, boxes_in, labels_in
        return bgr, boxes, labels

    def subMean(self, bgr, mean):
        mean = np.array(mean, dtype=np.float32)
        bgr = bgr - mean
        return bgr

    def random_flip(self, im, boxes):
        if random.random() < 0.5:
            im_lr = np.fliplr(im).copy()
            h, w, _ = im.shape
            xmin = w - boxes[:, 2]
            xmax = w - boxes[:, 0]
            boxes[:, 0] = xmin
            boxes[:, 2] = xmax
            return im_lr, boxes
        return im, boxes

    def random_bright(self, im, delta=16):
        alpha = random.random()
        if alpha > 0.3:
            im = im * alpha + random.randrange(-delta, delta)
            im = im.clip(min=0, max=255).astype(np.uint8)
        return im

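# ---------------------------------------------------------------------------
# Illustration only (not part of the original file): a minimal sketch that
# inverts yoloDataset.encoder, reading a ground-truth 7x7x30 target back into
# normalized [x1, y1, x2, y2] boxes and the shifted class labels stored above.
def decode_target(target, cell_size=1./7):
    boxes = []
    labels = []
    for row in range(7):
        for col in range(7):
            cell = target[row, col]
            if cell[4] == 1:  # confidence slot set by the encoder for the matched cell
                # box center = cell top-left corner + offset inside the cell
                cx = (col + float(cell[0])) * cell_size
                cy = (row + float(cell[1])) * cell_size
                w = float(cell[2])
                h = float(cell[3])
                boxes.append([cx - w/2, cy - h/2, cx + w/2, cy + h/2])
                # class one-hot occupies slots 10..29; the encoder stored label+9
                labels.append(int(cell[10:].max(0)[1]) + 1)
    return torch.Tensor(boxes), torch.LongTensor(labels)
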
def main():
    from torch.utils.data import DataLoader
    file_root = '/media/xiong/449C8E929C8E7DE4/codedata/voc2007/VOCdevkit_train/VOC2007/JPEGImages/'
    train_dataset = yoloDataset(root=file_root, list_file='voc2007train.txt', train=True, transform=[transforms.ToTensor()])
    train_loader = DataLoader(train_dataset, batch_size=1, shuffle=False, num_workers=0)
    train_iter = iter(train_loader)
    img, target = next(train_iter)
    print(img, target)


if __name__ == '__main__':
    main()