@@ -54,12 +54,15 @@ def train_mnist(h5_file, epochs=10):
5454 model .summary ()
5555
5656 model .compile (optimizer = 'adam' , loss = "categorical_crossentropy" , metrics = ["categorical_accuracy" ])
57- H = model .fit (x_train , y_train , batch_size = 128 , epochs = EPOCHS , verbose = 1 , validation_data = (x_test , y_test ), shuffle = True )
57+ H = model .fit (x_train , y_train , batch_size = 128 , epochs = epochs , verbose = 1 , validation_data = (x_test , y_test ), shuffle = True )
5858
5959 model .save (h5_file )
6060
6161def generate_test_files (out_dir , x , y ):
6262
63+ if not os .path .exists (out_dir ):
64+ os .makedirs (out_dir )
65+
6366 expect_bytes = 28 * 28 * 1
6467 classes = numpy .unique (y )
6568 X_series = pandas .Series ([s for s in x ])
@@ -146,29 +149,34 @@ def format_shape(t : tuple[int]):
146149 assert os .path .exists (tmld_file ), tmld_file
147150 assert os .path .exists (header_file ), header_file
148151
152+ return tmld_file
149153
150154def main ():
151155
152156 h5_file = "mnist_cnn.h5"
153- tinymaix_tools_dir = './TinyMaix/tools'
157+ tinymaix_tools_dir = '../../dependencies/TinyMaix/tools'
158+ assert os .path .exists (tinymaix_tools_dir ), tinymaix_tools_dir
159+
160+ quantize_data = None # disables quantization
161+ quantize_data = os .path .join (tinymaix_tools_dir , 'quant_img_mnist/' )
162+ if quantize_data is not None :
163+ assert os .path .exists (quantize_data )
164+ precision = 'int8' if quantize_data else 'fp32'
154165
166+ # Run training
155167 train_mnist (h5_file )
156168
157169 #data = x_test[1]
158170
159- quantize_data = True # disables quantization
160- quantize_data = './TinyMaix/tools/quant_img_mnist/'
161- precision = 'int8' if quantize_data else 'fp32'
162-
163171 # Export the model using TinyMaix
164-
165- generate_tinymaix_model (h5_file ,
172+ out = generate_tinymaix_model (h5_file ,
166173 input_shape = (28 ,28 ,1 ),
167174 output_shape = (1 ,),
168175 tools_dir = tinymaix_tools_dir ,
169176 precision = precision ,
170177 quantize_data = quantize_data ,
171178 )
179+ print ('Wrote model to' , out )
172180
173181if __name__ == '__main__' :
174182 main ()
0 commit comments