Skip to content

Commit 279581f

Browse files
committed
Fix saving/loading bug in LSTM
1 parent 8f26bde commit 279581f

File tree

3 files changed

+36
-14
lines changed

3 files changed

+36
-14
lines changed

include/neural_networks/Layer_Loader.h

+5-4
Original file line numberDiff line numberDiff line change
@@ -95,11 +95,12 @@ struct Layer_Loader {
9595
}
9696

9797
/// Returns true when a file named `filename` exists (more precisely: can be
/// opened for reading).
///
/// @param filename  path to test; taken by const reference to avoid the
///                  needless std::string copy the by-value form made on
///                  every call (callers are unaffected).
/// @return true if an ifstream on the path opens successfully.
bool file_exists(const std::string& filename) {
	//TODO re-add
	//#ifdef BC_USE_EXPERIMENTAL_FILE_SYSTEM
	//	return std::experimental::filesystem::exists(filename);
	//#else
	// Fallback: probing with an ifstream reports "exists" only for files we
	// can actually open -- good enough for the loader's save/load checks.
	return std::ifstream(filename).good();
	//#endif
}
104105

105106
static std::string dimension_to_tensor_name(int dimension) {

include/neural_networks/Network.h

+2
Original file line numberDiff line numberDiff line change
@@ -263,6 +263,7 @@ struct NeuralNetwork {
263263
loader.set_current_layer_name(layer.classname());
264264
loader.make_current_directory();
265265
layer.save(loader);
266+
layer.save_from_cache(loader, layer.get_cache());
266267
index++;
267268
});
268269
}
@@ -284,6 +285,7 @@ struct NeuralNetwork {
284285
loader.set_current_layer_index(index);
285286
loader.set_current_layer_name(layer.classname());
286287
layer.load(loader);
288+
layer.load_to_cache(loader, layer.get_cache());
287289
index++;
288290
});
289291

include/neural_networks/layers/LSTM.h

+29-10
Original file line numberDiff line numberDiff line change
@@ -395,13 +395,23 @@ struct LSTM:
395395

396396
/// Persists the LSTM's cached recurrent state (the four gate activations
/// plus the cellstate, and the prediction-time cellstate when present) so a
/// saved network can resume mid-sequence after loading.
void save_from_cache(Layer_Loader& loader, Cache& cache)
{
	auto& write_vals  = cache.load(write_key(),  default_tensor_factory());
	auto& input_vals  = cache.load(input_key(),  default_tensor_factory());
	auto& forget_vals = cache.load(forget_key(), default_tensor_factory());
	auto& output_vals = cache.load(output_key(), default_tensor_factory());
	auto& cellstate   = cache.load(cell_key(),   default_tensor_factory());

	loader.save_variable(write_vals,  "write_gate_values");
	loader.save_variable(input_vals,  "input_gate_values");
	loader.save_variable(forget_vals, "forget_gate_values");
	loader.save_variable(output_vals, "output_gate_values");
	loader.save_variable(cellstate,   "cellstate");

	// The prediction cellstate only exists once predict() has run; skip it
	// otherwise instead of forcing its creation.
	if (cache.contains(predict_cell_key())) {
		auto& predict_cellstate = cache.load(
				predict_cell_key(),
				default_predict_tensor_factory());
		loader.save_variable(predict_cellstate, "predict_cellstate");
	}
}
407417

@@ -439,14 +449,25 @@ struct LSTM:
439449
bo_opt.load(loader, "bo_opt");
440450
}
441451

442-
void load_from_cache(Layer_Loader& loader, Cache& cache)
452+
void load_to_cache(Layer_Loader& loader, Cache& cache)
443453
{
454+
auto& z = cache.load(write_key(), default_tensor_factory());
455+
auto& i = cache.load(input_key(), default_tensor_factory());
456+
auto& f = cache.load(forget_key(), default_tensor_factory());
457+
auto& o = cache.load(output_key(), default_tensor_factory());
444458
auto& c = cache.load(cell_key(), default_tensor_factory());
459+
460+
loader.load_variable(z, "write_gate_values");
461+
loader.load_variable(i, "input_gate_values");
462+
loader.load_variable(f, "forget_gate_values");
463+
loader.load_variable(o, "output_gate_values");
445464
loader.load_variable(c, "cellstate");
446465

447-
if (loader.file_exists(1, "cellstate")) {
448-
auto& pc = cache.load(predict_cell_key());
449-
loader.load_variable(pc, "predict_celltate");
466+
if (loader.file_exists(1, "predict_cellstate")) {
467+
auto& pc = cache.load(
468+
predict_cell_key(),
469+
default_predict_tensor_factory());
470+
loader.load_variable(pc, "predict_cellstate");
450471
}
451472
}
452473

@@ -462,9 +483,7 @@ struct LSTM:
462483
// Factory passed to Cache::load -- builds a zero-initialized activation
// matrix sized (output_size x batch_size) when a cache key is absent.
// NOTE(review): the closure captures `this` (via [&]), so it must not
// outlive the layer -- verify Cache does not retain it past the layer's
// lifetime.
// NOTE(review): assumes mat::zero() zeroes in place and yields the tensor,
// with the lambda's deduced return decaying to a value copy -- if zero()
// instead returns a lazy expression type, the factory's return type changed
// relative to the pre-commit `mat m; m.zero(); return m;` form. TODO confirm.
auto default_tensor_factory()
{
	return [&]() {
		return mat(this->output_size(), this->batch_size()).zero();
	};
}
470489

0 commit comments

Comments
 (0)