@@ -395,13 +395,23 @@ struct LSTM:
395
395
396
396
void save_from_cache (Layer_Loader& loader, Cache& cache)
397
397
{
398
+ auto & z = cache.load (write_key (), default_tensor_factory ());
399
+ auto & i = cache.load (input_key (), default_tensor_factory ());
400
+ auto & f = cache.load (forget_key (), default_tensor_factory ());
401
+ auto & o = cache.load (output_key (), default_tensor_factory ());
398
402
auto & c = cache.load (cell_key (), default_tensor_factory ());
399
- loader.save_variable (c, " cellstate" );
400
403
404
+ loader.save_variable (z, " write_gate_values" );
405
+ loader.save_variable (i, " input_gate_values" );
406
+ loader.save_variable (f, " forget_gate_values" );
407
+ loader.save_variable (o, " output_gate_values" );
408
+ loader.save_variable (c, " cellstate" );
401
409
402
410
if (cache.contains (predict_cell_key ())) {
403
- auto & pc = cache.load (predict_cell_key ());
404
- loader.save_variable (pc, " predict_celltate" );
411
+ auto & pc = cache.load (
412
+ predict_cell_key (),
413
+ default_predict_tensor_factory ());
414
+ loader.save_variable (pc, " predict_cellstate" );
405
415
}
406
416
}
407
417
@@ -439,14 +449,25 @@ struct LSTM:
439
449
bo_opt.load (loader, " bo_opt" );
440
450
}
441
451
442
- void load_from_cache (Layer_Loader& loader, Cache& cache)
452
+ void load_to_cache (Layer_Loader& loader, Cache& cache)
443
453
{
454
+ auto & z = cache.load (write_key (), default_tensor_factory ());
455
+ auto & i = cache.load (input_key (), default_tensor_factory ());
456
+ auto & f = cache.load (forget_key (), default_tensor_factory ());
457
+ auto & o = cache.load (output_key (), default_tensor_factory ());
444
458
auto & c = cache.load (cell_key (), default_tensor_factory ());
459
+
460
+ loader.load_variable (z, " write_gate_values" );
461
+ loader.load_variable (i, " input_gate_values" );
462
+ loader.load_variable (f, " forget_gate_values" );
463
+ loader.load_variable (o, " output_gate_values" );
445
464
loader.load_variable (c, " cellstate" );
446
465
447
- if (loader.file_exists (1 , " cellstate" )) {
448
- auto & pc = cache.load (predict_cell_key ());
449
- loader.load_variable (pc, " predict_celltate" );
466
+ if (loader.file_exists (1 , " predict_cellstate" )) {
467
+ auto & pc = cache.load (
468
+ predict_cell_key (),
469
+ default_predict_tensor_factory ());
470
+ loader.load_variable (pc, " predict_cellstate" );
450
471
}
451
472
}
452
473
@@ -462,9 +483,7 @@ struct LSTM:
462
483
auto default_tensor_factory ()
463
484
{
464
485
return [&]() {
465
- mat m (this ->output_size (), this ->batch_size ());
466
- m.zero ();
467
- return m;
486
+ return mat (this ->output_size (), this ->batch_size ()).zero ();
468
487
};
469
488
}
470
489
0 commit comments