Skip to content

Commit 929f86a

Browse files
committed
fix training
1 parent cfd243a commit 929f86a

File tree

5 files changed

+42
-14
lines changed

5 files changed

+42
-14
lines changed

.gitignore

+5
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
data/
2+
data2/
3+
data3/
4+
logs/
5+
__pycache__/

cvae.py

-1
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@ def __init__(self, args):
1717
self.args = args
1818
self.logs_dir = args.logs_dir
1919
self.n_dim = args.n_dim
20-
self.dimension = args.dimension
2120
self.image_size = args.image_size
2221
self.num_layers = args.num_layers
2322
self.filters = args.filters

downnload.py

+19
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
import urllib.request
2+
import os
3+
4+
with open("./data/categories.txt", 'r') as f:
5+
classes = f.readlines()
6+
7+
classes = [c.replace('\n','').replace(' ','_') for c in classes]
8+
9+
def download():
10+
11+
base = 'https://storage.googleapis.com/quickdraw_dataset/full/numpy_bitmap/'
12+
for c in classes:
13+
cls_url = c.replace('_', '%20')
14+
path = base+cls_url+'.npy'
15+
print(path)
16+
os.mkdir("data/"+c)
17+
urllib.request.urlretrieve(path, 'data/'+c+'/'+c+'.npy')
18+
19+
download()

train.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from keras.datasets import mnist
44
from keras.utils import to_categorical
55
from cvae import CVAE
6-
6+
from utils import *
77

88

99
def main():
@@ -37,8 +37,12 @@ def main():
3737

3838
args = parser.parse_args()
3939

40-
(X_train, Y_train), (X_test, Y_test) = mnist.load_data()
40+
# (train_features, train_labels), (validataion_features, validataion_labels) = get_files('data')
41+
# training, validation = get_data(train_features, train_labels, validataion_features, validataion_labels)
42+
(X_train, Y_train), (X_test, Y_test) = get_files('data')
4143

44+
# (X_train, Y_train), (X_test, Y_test) = mnist.load_data()
45+
print(X_train.shape)
4246
X_train = np.reshape(X_train, [-1, args.image_size, args.image_size, args.image_depth])
4347
X_test = np.reshape(X_test, [-1, args.image_size, args.image_size, args.image_depth])
4448
X_train = X_train.astype('float32') / 255.

utils.py

+12-11
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ def get_files(path):
1010
validation_features = None
1111
training_labels = None
1212
validation_labels = None
13-
labels_value = None
13+
labels_value = []
1414
count = 1
1515

1616
#TODO: Load Img
@@ -34,10 +34,10 @@ def get_files(path):
3434
training_labels = np.concatenate([training_labels, new_labels], axis=0)
3535

3636
# concatenate arrays to get the validation data
37-
if validataion_features is None:
38-
validataion_features = np.copy(data[length-(length//10):,:,:,:])
37+
if validation_features is None:
38+
validation_features = np.copy(data[length-(length//10):,:,:,:])
3939
else:
40-
validataion_features = np.concatenate([validataion_features ,data[length-(length//10):,:,:,:]), axis=0)
40+
validation_features = np.concatenate((validation_features, data[length-(length//10):,:,:,:]), axis=0)
4141

4242
# get validation data
4343
if validation_labels is None:
@@ -56,6 +56,7 @@ def get_files(path):
5656
# create training data
5757
def get_data(training_features, training_labels, validation_features, validation_labels):
5858
# get training data
59+
print(training_features.shape, training_labels.shape)
5960
train_imgs = tf.constant(training_features)
6061
train_labels = tf.constant(training_labels)
6162

@@ -68,12 +69,12 @@ def get_data(training_features, training_labels, validation_features, validation
6869

6970
return training_data, validation_data
7071

71-
(train_features, train_labels), (validataion_features, validataion_labels) = get_files('quickdraw_data')
72-
image_label_ds = get_data(train_features, train_labels, validataion_features, validataion_labels)
72+
# (train_features, train_labels), (validataion_features, validataion_labels) = get_files('quickdraw_data')
73+
# image_label_ds = get_data(train_features, train_labels, validataion_features, validataion_labels)
7374

7475

75-
print('image shape: ', image_label_ds.output_shapes[0])
76-
print('label shape: ', image_label_ds.output_shapes[1])
77-
print('types: ', image_label_ds.output_types)
78-
print()
79-
print(image_label_ds)
76+
# print('image shape: ', image_label_ds.output_shapes[0])
77+
# print('label shape: ', image_label_ds.output_shapes[1])
78+
# print('types: ', image_label_ds.output_types)
79+
# print()
80+
# print(image_label_ds)

0 commit comments

Comments
 (0)