Commit 1fa1d31

Enhance and restructure source code files
1 parent 6f0a678 commit 1fa1d31

17 files changed, +97 -1634 lines

classifier-elghalaba.ipynb (-1,492)

This file was deleted.

classifier.py (+5 -43)

@@ -1,17 +1,9 @@
-from sklearn.neighbors import KNeighborsClassifier
-from sklearn.neural_network import MLPClassifier # MLP is an NN
-from sklearn import svm, tree
-import numpy as np
-import argparse
-import imutils
 import cv2
 import os
-import skimage.io as io
 import random
-from skimage.transform import rotate
-from sklearn.model_selection import train_test_split
-import pickle
-import matplotlib.pyplot as plt
+import argparse
+import numpy as np
+from sklearn import svm


 ########## Variables ##########
@@ -22,32 +14,11 @@
 np.random.seed(random_seed)

 classifiers = {
-    'SVM': svm.LinearSVC(random_state=random_seed),
-    'KNN': KNeighborsClassifier(n_neighbors=7),
-    'NN': MLPClassifier(solver='sgd', random_state=random_seed, hidden_layer_sizes=(500,), max_iter=200, verbose=1),
-    'TREE': tree.DecisionTreeClassifier(random_state=0, max_depth=10)
+    'SVM': svm.LinearSVC(random_state=random_seed)
 }

 ########## Methods ##########

-
-def extract_raw_pixels(img):
-    return cv2.resize(img, target_img_size).flatten()
-
-
-def extract_sift_features(img):
-    img = cv2.resize(img, target_img_size)
-
-    sift = cv2.SIFT_create()
-    _, features = sift.detectAndCompute(img, None)
-
-    try:
-        return features.flatten()
-    except:
-        io.imshow(img)
-        io.show()
-
-
 def extract_hog_features(img):
     img = cv2.resize(img, target_img_size)
     win_size = (32, 32)
@@ -64,15 +35,8 @@ def extract_hog_features(img):
     h = h.flatten()
     return h.flatten()

-
 def extract_features(img, feature_set='hog'):
-    if feature_set == 'hog':
-        return extract_hog_features(img)
-    elif feature_set == 'sift':
-        return extract_sift_features(img)
-    else:
-        return extract_raw_pixels(img)
-
+    return extract_hog_features(img)

 def get_directories():
     directories = []
@@ -83,7 +47,6 @@ def get_directories():

     return directories

-
 def load_dataset(feature_set='hog'):
     labels = []
     features = []
@@ -105,7 +68,6 @@ def load_dataset(feature_set='hog'):

     return features, labels

-
 def run_experiment(train_features, test_features, train_labels, test_labels, model_name):
     model = classifiers[model_name]
     print('############## Training', model_name, "##############")
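
With the KNN, MLP, and decision-tree options removed, classifier.py is reduced to HOG features plus a single LinearSVC. The commit does not show the training entry point that produces model/model.sav (the file main.py loads), so the following is only a minimal sketch, assuming the remaining load_dataset and classifiers pieces and an illustrative 80/20 split:

    # Minimal sketch (not part of this commit): train the remaining LinearSVC
    # on HOG features and pickle it so main.py can load 'model/model.sav'.
    import pickle
    from sklearn.model_selection import train_test_split

    features, labels = load_dataset(feature_set='hog')
    train_features, test_features, train_labels, test_labels = train_test_split(
        features, labels, test_size=0.2, random_state=random_seed)

    model = classifiers['SVM']
    model.fit(train_features, train_labels)
    print('Test accuracy:', model.score(test_features, test_labels))

    pickle.dump(model, open('model/model.sav', 'wb'))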

helper_methods.py (+29 -1)

@@ -1,7 +1,19 @@
-import numpy as np
 import cv2
+import random
+import numpy as np
+from sklearn import svm

 # List of maps and needed variables #
+
+########## Variables ##########
+random_seed = 42
+random.seed(random_seed)
+target_img_size = (32, 32)
+np.random.seed(random_seed)
+classifiers = {
+    'SVM': svm.LinearSVC(random_state=random_seed)
+}
+
 direct_labels = ['x', 'b', 'clef', 'dot', 'hash', 'd', 't_2', 't_4', 'symbol_bb', 'barline']
 direct_texts = {'x':'##', 'b':'&', 'hash':'#', 'd':'', 'symbol_bb':'&&', 'dot':'.', 'clef':'', 't_2':'2', 't_4':'4', 'barline':''}

@@ -140,3 +152,19 @@ def preprocess_img(img_path):
     # 4. Return image shape (width, height) and processed image #
     n, m = img.shape
     return n, m, img
+
+def extract_hog_features(img):
+    img = cv2.resize(img, target_img_size)
+    win_size = (32, 32)
+    cell_size = (4, 4)
+    block_size_in_cells = (2, 2)
+
+    block_size = (block_size_in_cells[1] * cell_size[1],
+                  block_size_in_cells[0] * cell_size[0])
+    block_stride = (cell_size[1], cell_size[0])
+    nbins = 9
+    hog = cv2.HOGDescriptor(win_size, block_size,
+                            block_stride, cell_size, nbins)
+    h = hog.compute(img)
+    h = h.flatten()
+    return h.flatten()
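
The copied extract_hog_features uses a 32x32 window, 4x4 cells, 2x2-cell (8x8-pixel) blocks and a 4-pixel stride, giving 7x7 block positions of 4 cells x 9 bins each, so every symbol maps to a fixed 1764-dimensional descriptor. A quick sanity check of that size, assuming a grayscale uint8 symbol image as input:

    import numpy as np

    symbol = np.zeros((64, 48), dtype=np.uint8)  # any size; it is resized to 32x32
    features = extract_hog_features(symbol)

    # ((32 - 8) / 4 + 1)**2 block positions * 4 cells * 9 bins = 7 * 7 * 4 * 9 = 1764
    assert features.shape == (1764,)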

input/01.PNG (23.4 KB)
input/02.PNG (77.3 KB)
input/03.PNG (22.2 KB)
input/04.PNG (20.6 KB)
input/05.PNG (27 KB)
input/06.PNG (23.1 KB)
input/07.PNG (27.8 KB)
input/08.PNG (28.1 KB)
input/09.PNG (38 KB)
input/10.PNG (26.8 KB)

main.py (+59 -57)

@@ -1,14 +1,13 @@
-from classifier import *
 from preprocessing import *
 from staff_removal import *
 from helper_methods import *

 import argparse
 import os
 import datetime
+
 # Initialize parser
 parser = argparse.ArgumentParser()
-
 parser.add_argument("inputfolder", help = "Input File")
 parser.add_argument("outputfolder", help = "Output File")

@@ -22,11 +21,9 @@

 # Threshold for line to be considered as an initial staff line #
 threshold = 0.8
-accidentals = ['x', 'hash', 'b', 'symbol_bb', 'd']
-
-
 filename = 'model/model.sav'
 model = pickle.load(open(filename, 'rb'))
+accidentals = ['x', 'hash', 'b', 'symbol_bb', 'd']

 def preprocessing(inputfolder, fn, f):
     # Get image and its dimensions #
@@ -46,58 +43,63 @@ def preprocessing(inputfolder, fn, f):

     return cutted, ref_lines, lines_spacing

+def get_target_boundaries(label, cur_symbol, y2):
+    if label == 'b_8':
+        cutted_boundaries = cut_boundaries(cur_symbol, 2, y2)
+        label = 'a_8'
+    elif label == 'b_8_flipped':
+        cutted_boundaries = cut_boundaries(cur_symbol, 2, y2)
+        label = 'a_8_flipped'
+    elif label == 'b_16':
+        cutted_boundaries = cut_boundaries(cur_symbol, 4, y2)
+        label = 'a_16'
+    elif label == 'b_16_flipped':
+        cutted_boundaries = cut_boundaries(cur_symbol, 4, y2)
+        label = 'a_16_flipped'
+    else:
+        cutted_boundaries = cut_boundaries(cur_symbol, 1, y2)
+
+    return label, cutted_boundaries
+
+def get_label_cutted_boundaries(boundary, height_before, cutted):
+    # Get the current symbol #
+    x1, y1, x2, y2 = boundary
+    cur_symbol = cutted[y1-height_before:y2+1-height_before, x1:x2+1]
+
+    # Clean and cut #
+    cur_symbol = clean_and_cut(cur_symbol)
+    cur_symbol = 255 - cur_symbol
+
+    # Start prediction of the current symbol #
+    feature = extract_hog_features(cur_symbol)
+    label = str(model.predict([feature])[0])
+
+    return get_target_boundaries(label, cur_symbol, y2)
+
 def process_image(inputfolder, fn, f):
     cutted, ref_lines, lines_spacing = preprocessing(inputfolder, fn, f)

     last_acc = ''
     last_num = ''
     height_before = 0

-
     if len(cutted) > 1:
         f.write('{\n')


     for it in range(len(cutted)):
         f.write('[')
         is_started = False
-        cur_img = cutted[it].copy()

-
-        symbols_boundries = segmentation(height_before, cutted[it])
-        symbols_boundries.sort(key = lambda x: (x[0], x[1]))
+
+        symbols_boundaries = segmentation(height_before, cutted[it])
+        symbols_boundaries.sort(key = lambda x: (x[0], x[1]))

-        symbols = []
-        for boundry in symbols_boundries:
-            # Get the current symbol #
-            x1, y1, x2, y2 = boundry
-            cur_symbol = cutted[it][y1-height_before:y2+1-height_before, x1:x2+1]
-
-            # Clean and cut #
-            cur_symbol = clean_and_cut(cur_symbol)
-            cur_symbol = 255 - cur_symbol
+        for boundary in symbols_boundaries:
+            label, cutted_boundaries = get_label_cutted_boundaries(boundary, height_before, cutted[it])

-            # Start prediction of the current symbol #
-            feature = extract_features(cur_symbol, 'hog')
-            label = str(model.predict([feature])[0])
-
             if label == 'clef':
                 is_started = True
-
-            if label == 'b_8':
-                cutted_boundaries = cut_boundaries(cur_symbol, 2, y2)
-                label = 'a_8'
-            elif label == 'b_8_flipped':
-                cutted_boundaries = cut_boundaries(cur_symbol, 2, y2)
-                label = 'a_8_flipped'
-            elif label == 'b_16':
-                cutted_boundaries = cut_boundaries(cur_symbol, 4, y2)
-                label = 'a_16'
-            elif label == 'b_16_flipped':
-                cutted_boundaries = cut_boundaries(cur_symbol, 4, y2)
-                label = 'a_16_flipped'
-            else:
-                cutted_boundaries = cut_boundaries(cur_symbol, 1, y2)

             for cutted_boundary in cutted_boundaries:
                 _, y1, _, y2 = cutted_boundary
@@ -126,29 +128,29 @@ def process_image(inputfolder, fn, f):
     if len(cutted) > 1:
         f.write('}')

-for i in [args.inputfolder]:
+def main():
     try:
         os.mkdir(args.outputfolder)
-    except OSError as error:
+    except OSError as error:
         pass
-

     list_of_images = os.listdir(args.inputfolder)
-
-    for i, fn in enumerate(list_of_images):
-        # Open the output text file #
-        file_prefix = fn.split('.')[0]
-        f = open(f"{args.outputfolder}/{file_prefix}.txt", "w")
-
-
-        # Process each image separately #
-        try:
-            process_image(args.inputfolder, fn, f)
-        except:
-            print(f'{args.inputfolder}-{fn} has been failed !!')
-            pass
-
-        f.close()
+    for _, fn in enumerate(list_of_images):
+        # Open the output text file #
+        file_prefix = fn.split('.')[0]
+        f = open(f"{args.outputfolder}/{file_prefix}.txt", "w")
+
+        # Process each image separately #
+        try:
+            process_image(args.inputfolder, fn, f)
+        except Exception as e:
+            print(e)
+            print(f'{args.inputfolder}-{fn} has been failed !!')
+            pass
+
+        f.close()
+    print('Finished !!')


-print('Finished !!')
+if __name__ == "__main__":
+    main()
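
The top-level loop is now wrapped in main() behind an if __name__ == "__main__" guard, so the script is still run directly, for example: python main.py input output (using the bundled sample folders); each input image NN.PNG yields a matching NN.txt in the output folder.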

output/Output.txt (+1 -1)

@@ -1 +1 @@
-Input Folder: ./inputOutput Folder: ./outputDate: 2021-01-12 00:08:22.493123
+Input Folder: ./inputOutput Folder: ./outputDate: 2021-01-12 00:56:23.465829

preprocessing.py (+2 -1)

@@ -1,4 +1,5 @@
 import numpy as np
+import pickle
 from numpy.linalg import norm
 from skimage.filters import *
 from skimage.color import rgb2gray
@@ -71,4 +72,4 @@ def display(img):
     cv.resizeWindow('image', 1920, 1080)
     cv.imshow('image', img)
     if cv.waitKey(0) == 27:
-        cv2.destoyAllWindows()
+        cv2.destoyAllWindows()
