yolo experiment
xiongzihua committed Jan 10, 2018
1 parent d987607 commit 0e5776a
Showing 113 changed files with 23,330 additions and 2 deletions.
56 changes: 54 additions & 2 deletions README.md
@@ -1,2 +1,54 @@
# pytorch-YOLO-v1
An experiment with YOLO v1, including training and testing.
## pytorch YOLO-v1

[中文](中文.md) (Chinese version)

**This is an experimental repository; it cannot reproduce the results of the original [paper](https://arxiv.org/pdf/1506.02640.pdf).**

**I will start a discussion later; if you are interested, you are welcome to contact me. If you find any bugs, please let me know.**

I wrote this code for learning purposes. In yoloLoss.py, only the forward pass is written; thanks to the autograd mechanism, the backward pass is computed automatically.
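
The idea, as a minimal hypothetical sketch (this is not the actual yoloLoss.py; the mask and loss term are simplified stand-ins): only `forward` is defined, using differentiable tensor operations, so autograd records the computation graph and derives the backward pass on its own.

```python
import torch
import torch.nn as nn

# Hypothetical, simplified loss -- not the real yoloLoss.py.
# Every operation in forward() is differentiable, so autograd
# records the graph and backward() needs no hand-written code.
class ToySquaredErrorLoss(nn.Module):
    def forward(self, pred, target):
        obj_mask = target[..., 4] > 0              # cells that contain an object
        return ((pred - target) ** 2)[obj_mask].sum()

loss_fn = ToySquaredErrorLoss()
pred = torch.rand(1, 7, 7, 30, requires_grad=True)
target = torch.rand(1, 7, 7, 30)
loss_fn(pred, target).backward()                   # gradients arrive via autograd
```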

### 1. Dependency
- pytorch 0.2.0_2
- opencv
- visdom
- tqdm

### 2. Prepare

1. Download the VOC2012 train dataset
2. Download the VOC2007 test dataset
3. Convert the xml annotations to a txt file so that dataset.py can use them: put xml_2_txt.py in the same folder as the VOC dataset, or change the *Annotations* path in xml_2_txt.py (a sketch of this conversion is shown below)
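
A minimal sketch of the conversion step, assuming the output format that dataset.py parses (`image_name num x1 y1 x2 y2 class ...` per line); the real xml_2_txt.py may differ in its details:

```python
import os
import xml.etree.ElementTree as ET

# The 20 PASCAL VOC class names; the txt file stores 0-based indices
# (dataset.py adds 1 when it loads labels).
VOC_CLASSES = ['aeroplane','bicycle','bird','boat','bottle','bus','car','cat',
               'chair','cow','diningtable','dog','horse','motorbike','person',
               'pottedplant','sheep','sofa','train','tvmonitor']

def xml_to_line(xml_path):
    root = ET.parse(xml_path).getroot()
    objs = root.findall('object')
    parts = [root.find('filename').text, str(len(objs))]
    for obj in objs:
        box = obj.find('bndbox')
        parts += [box.find(k).text for k in ('xmin', 'ymin', 'xmax', 'ymax')]
        parts.append(str(VOC_CLASSES.index(obj.find('name').text)))
    return ' '.join(parts)

ann_dir = 'Annotations'   # adjust to where your VOC Annotations folder lives
with open('voc2012train.txt', 'w') as out:
    for name in sorted(os.listdir(ann_dir)):
        out.write(xml_to_line(os.path.join(ann_dir, name)) + '\n')
```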

### 3. Train
Run `python train.py`

*Be careful:* 1. change the image file path; 2. I recommend installing [visdom](https://github.com/facebookresearch/visdom) and running it to monitor training, as sketched below.
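
A minimal sketch of the monitoring idea (the window title and loss values are made up; start the server first with `python -m visdom.server`):

```python
import numpy as np
import visdom

vis = visdom.Visdom()                              # needs a running visdom server
win = vis.line(X=np.array([0]), Y=np.array([0.0]),
               opts=dict(title='train loss'))
for epoch, loss in enumerate([2.5, 1.8, 1.2], 1):  # stand-in loss values
    vis.line(X=np.array([epoch]), Y=np.array([loss]),
             win=win, update='append')
```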

### 4. Evaluation
Run `python eval_voc.py`

*Be careful:* 1. change the image file path.

### 5. Discussion

1. Overfitting problem

I plotted the training and testing loss curves, and the overfitting is obvious. I applied heavy data augmentation to counter it, but it helped little.

![loss](experimentIMG/yoloLoss.svg)

2. Activation function in the last fc layer

The original paper uses a linear activation function for the final layer, whose output lies in (-inf, +inf), while the target lies in [0, 1], so I replaced it with a sigmoid activation function. I think this is more reasonable; if you know the details about this, please let me know.

Update: I ran another experiment. I used a linear activation, set the learning rate carefully as in the paper, and replaced sqrt(w), sqrt(h) with (w, h) to avoid NaN problems. But the result was not good either. A sketch of the sigmoid variant follows.
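
A minimal sketch of the sigmoid variant (the layer sizes are hypothetical, not the repo's actual detection head): a sigmoid after the last fully-connected layer squashes every output into (0, 1), matching the target encoding.

```python
import torch
import torch.nn as nn

# Hypothetical head: 4096 features in, 7*7*30 predictions out.
# The paper ends with a linear layer (range (-inf, +inf));
# the sigmoid keeps every output inside (0, 1) like the targets.
head = nn.Sequential(
    nn.Linear(4096, 7 * 7 * 30),
    nn.Sigmoid(),
)
features = torch.randn(2, 4096)
out = head(features).view(-1, 7, 7, 30)            # all values in (0, 1)
```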

### 6. Result
1. On the train dataset, mAP is about 0.5. Some result images are in trainIMGresult.

![](trainIMGresult/2007_000032.jpg)

2. On the test dataset, mAP is about 0.2. Some result images are in testIMGresult; the test results are not good.

![](testIMGresult/000004.jpg)
259 changes: 259 additions & 0 deletions dataset.py
@@ -0,0 +1,259 @@
#encoding:utf-8
#
#created by xiongzihua
#
'''
txt annotation file: image_name.jpg num x y w h c x y w h c
a line like this means the image contains two objects; each object
contributes five values, parsed below as x1 y1 x2 y2 class
'''
import os
import sys
import os.path

import random
import numpy as np

import torch
import torch.utils.data as data
import torchvision.transforms as transforms

import cv2

class yoloDataset(data.Dataset):
image_size = 224
def __init__(self,root,list_file,train,transform):
print('data init')
self.root=root
self.train = train
self.transform=transform
self.fnames = []
self.boxes = []
self.labels = []
self.mean = (123,117,104)#RGB

with open(list_file) as f:
lines = f.readlines()

for line in lines:
splited = line.strip().split()
self.fnames.append(splited[0])
num_faces = int(splited[1])
box=[]
label=[]
for i in range(num_faces):
x = float(splited[2+5*i])
y = float(splited[3+5*i])
x2 = float(splited[4+5*i])
y2 = float(splited[5+5*i])
c = splited[6+5*i]
box.append([x,y,x2,y2])
label.append(int(c)+1)
self.boxes.append(torch.Tensor(box))
self.labels.append(torch.LongTensor(label))
self.num_samples = len(self.boxes)

def __getitem__(self,idx):
fname = self.fnames[idx]
        img = cv2.imread(os.path.join(self.root, fname))
boxes = self.boxes[idx].clone()
labels = self.labels[idx].clone()

if self.train:
#img = self.random_bright(img)
img, boxes = self.random_flip(img, boxes)
img,boxes = self.randomScale(img,boxes)
img = self.randomBlur(img)
img = self.RandomBrightness(img)
img = self.RandomHue(img)
img = self.RandomSaturation(img)
img,boxes,labels = self.randomShift(img,boxes,labels)

h,w,_ = img.shape
boxes /= torch.Tensor([w,h,w,h]).expand_as(boxes)
        img = self.BGR2RGB(img) #because pytorch pretrained models use RGB
        img = self.subMean(img,self.mean) #subtract the per-channel mean
img = cv2.resize(img,(self.image_size,self.image_size))
target = self.encoder(boxes,labels)# 7x7x30
for t in self.transform:
img = t(img)

return img,target
def __len__(self):
return self.num_samples

    def encoder(self,boxes,labels):
        '''
        boxes (tensor) [[x1,y1,x2,y2],[]]  -- normalized corner coordinates
        labels (tensor) [...]
        return 7x7x30  -- per cell: [x,y,w,h,conf] for two boxes + 20 class scores
        '''
        target = torch.zeros((7,7,30))
        cell_size = 1./7
        wh = boxes[:,2:]-boxes[:,:2]
        cxcy = (boxes[:,2:]+boxes[:,:2])/2
        for i in range(cxcy.size()[0]):
            cxcy_sample = cxcy[i]
            ij = (cxcy_sample/cell_size).ceil()-1 #index of the grid cell containing the box center
            target[int(ij[1]),int(ij[0]),4] = 1 #confidence of the first box
            target[int(ij[1]),int(ij[0]),9] = 1 #confidence of the second box
            target[int(ij[1]),int(ij[0]),int(labels[i])+9] = 1 #labels start at 1, so classes fill channels 10..29
            xy = ij*cell_size #normalized coordinates of the matched cell's top-left corner
            delta_xy = (cxcy_sample -xy)/cell_size #center offset within the cell, in cell units
            target[int(ij[1]),int(ij[0]),2:4] = wh[i]
            target[int(ij[1]),int(ij[0]),:2] = delta_xy
            target[int(ij[1]),int(ij[0]),7:9] = wh[i]
            target[int(ij[1]),int(ij[0]),5:7] = delta_xy
        return target
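    # Worked example (made-up numbers): a box centered at (0.5, 0.5) gives
    # ij = ceil(0.5 / (1/7)) - 1 = (3, 3), so cell (3, 3) is responsible;
    # xy = 3/7, and delta_xy = (0.5 - 3/7) / (1/7) = 0.5 lands in target[3,3,:2].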
def BGR2RGB(self,img):
return cv2.cvtColor(img,cv2.COLOR_BGR2RGB)
def BGR2HSV(self,img):
return cv2.cvtColor(img,cv2.COLOR_BGR2HSV)
def HSV2BGR(self,img):
return cv2.cvtColor(img,cv2.COLOR_HSV2BGR)

def RandomBrightness(self,bgr):
if random.random() < 0.5:
hsv = self.BGR2HSV(bgr)
h,s,v = cv2.split(hsv)
adjust = random.choice([0.5,1.5])
v = v*adjust
v = np.clip(v, 0, 255).astype(hsv.dtype)
hsv = cv2.merge((h,s,v))
bgr = self.HSV2BGR(hsv)
return bgr
def RandomSaturation(self,bgr):
if random.random() < 0.5:
hsv = self.BGR2HSV(bgr)
h,s,v = cv2.split(hsv)
adjust = random.choice([0.5,1.5])
s = s*adjust
s = np.clip(s, 0, 255).astype(hsv.dtype)
hsv = cv2.merge((h,s,v))
bgr = self.HSV2BGR(hsv)
return bgr
def RandomHue(self,bgr):
if random.random() < 0.5:
hsv = self.BGR2HSV(bgr)
h,s,v = cv2.split(hsv)
adjust = random.choice([0.5,1.5])
h = h*adjust
h = np.clip(h, 0, 255).astype(hsv.dtype)
hsv = cv2.merge((h,s,v))
bgr = self.HSV2BGR(hsv)
return bgr

def randomBlur(self,bgr):
if random.random()<0.5:
bgr = cv2.blur(bgr,(5,5))
return bgr

def randomShift(self,bgr,boxes,labels):
        #random translation augmentation
center = (boxes[:,2:]+boxes[:,:2])/2
if random.random() <0.5:
height,width,c = bgr.shape
after_shfit_image = np.zeros((height,width,c),dtype=bgr.dtype)
after_shfit_image[:,:,:] = (104,117,123) #bgr
shift_x = random.uniform(-width*0.2,width*0.2)
shift_y = random.uniform(-height*0.2,height*0.2)
#print(bgr.shape,shift_x,shift_y)
            #shift the original image
if shift_x>=0 and shift_y>=0:
after_shfit_image[int(shift_y):,int(shift_x):,:] = bgr[:height-int(shift_y),:width-int(shift_x),:]
elif shift_x>=0 and shift_y<0:
after_shfit_image[:height+int(shift_y),int(shift_x):,:] = bgr[-int(shift_y):,:width-int(shift_x),:]
elif shift_x <0 and shift_y >=0:
after_shfit_image[int(shift_y):,:width+int(shift_x),:] = bgr[:height-int(shift_y),-int(shift_x):,:]
elif shift_x<0 and shift_y<0:
after_shfit_image[:height+int(shift_y),:width+int(shift_x),:] = bgr[-int(shift_y):,-int(shift_x):,:]

shift_xy = torch.FloatTensor([[int(shift_x),int(shift_y)]]).expand_as(center)
center = center + shift_xy
mask1 = (center[:,0] >0) & (center[:,0] < width)
mask2 = (center[:,1] >0) & (center[:,1] < height)
mask = (mask1 & mask2).view(-1,1)
boxes_in = boxes[mask.expand_as(boxes)].view(-1,4)
if len(boxes_in) == 0:
return bgr,boxes,labels
box_shift = torch.FloatTensor([[int(shift_x),int(shift_y),int(shift_x),int(shift_y)]]).expand_as(boxes_in)
boxes_in = boxes_in+box_shift
labels_in = labels[mask.view(-1)]
return after_shfit_image,boxes_in,labels_in
return bgr,boxes,labels

def randomScale(self,bgr,boxes):
        #keep the height fixed and scale the width by a factor in [0.6, 1.4]
if random.random() < 0.5:
scale = random.uniform(0.6,1.4)
height,width,c = bgr.shape
bgr = cv2.resize(bgr,(int(width*scale),height))
scale_tensor = torch.FloatTensor([[scale,1,scale,1]]).expand_as(boxes)
boxes = boxes * scale_tensor
return bgr,boxes
return bgr,boxes

def randomCrop(self,bgr,boxes,labels):
if random.random() < 0.5:
center = (boxes[:,2:]+boxes[:,:2])/2
height,width,c = bgr.shape
h = random.uniform(0.6*height,height)
w = random.uniform(0.6*width,width)
x = random.uniform(0,width-w)
y = random.uniform(0,height-h)
x,y,h,w = int(x),int(y),int(h),int(w)

center = center - torch.FloatTensor([[x,y]]).expand_as(center)
mask1 = (center[:,0]>0) & (center[:,0]<w)
mask2 = (center[:,1]>0) & (center[:,1]<h)
mask = (mask1 & mask2).view(-1,1)

boxes_in = boxes[mask.expand_as(boxes)].view(-1,4)
if(len(boxes_in)==0):
return bgr,boxes,labels
box_shift = torch.FloatTensor([[x,y,x,y]]).expand_as(boxes_in)

boxes_in = boxes_in - box_shift
labels_in = labels[mask.view(-1)]
img_croped = bgr[y:y+h,x:x+w,:]
return img_croped,boxes_in,labels_in
return bgr,boxes,labels




def subMean(self,bgr,mean):
mean = np.array(mean, dtype=np.float32)
bgr = bgr - mean
return bgr

def random_flip(self, im, boxes):
if random.random() < 0.5:
im_lr = np.fliplr(im).copy()
h,w,_ = im.shape
xmin = w - boxes[:,2]
xmax = w - boxes[:,0]
boxes[:,0] = xmin
boxes[:,2] = xmax
return im_lr, boxes
return im, boxes
def random_bright(self, im, delta=16):
alpha = random.random()
if alpha > 0.3:
im = im * alpha + random.randrange(-delta,delta)
im = im.clip(min=0,max=255).astype(np.uint8)
return im

def main():
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
file_root = '/media/xiong/449C8E929C8E7DE4/codedata/voc2007/VOCdevkit_train/VOC2007/JPEGImages/'
train_dataset = yoloDataset(root=file_root,list_file='voc2007train.txt',train=True,transform = [transforms.ToTensor()] )
train_loader = DataLoader(train_dataset,batch_size=1,shuffle=False,num_workers=0)
train_iter = iter(train_loader)
img,target = next(train_iter)
print(img,target)


if __name__ == '__main__':
main()

