@@ -470,6 +470,7 @@ pub(crate) struct PeerState {
470
470
intercept_scid_by_user_channel_id : HashMap < u128 , u64 > ,
471
471
intercept_scid_by_channel_id : HashMap < ChannelId , u64 > ,
472
472
pending_requests : HashMap < LSPSRequestId , LSPS2Request > ,
473
+ needs_persist : bool ,
473
474
}
474
475
475
476
impl PeerState {
@@ -478,16 +479,19 @@ impl PeerState {
478
479
let pending_requests = new_hash_map ( ) ;
479
480
let intercept_scid_by_user_channel_id = new_hash_map ( ) ;
480
481
let intercept_scid_by_channel_id = new_hash_map ( ) ;
482
+ let needs_persist = true ;
481
483
Self {
482
484
outbound_channels_by_intercept_scid,
483
485
pending_requests,
484
486
intercept_scid_by_user_channel_id,
485
487
intercept_scid_by_channel_id,
488
+ needs_persist,
486
489
}
487
490
}
488
491
489
492
fn insert_outbound_channel ( & mut self , intercept_scid : u64 , channel : OutboundJITChannel ) {
490
493
self . outbound_channels_by_intercept_scid . insert ( intercept_scid, channel) ;
494
+ self . needs_persist |= true ;
491
495
}
492
496
493
497
fn prune_expired_request_state ( & mut self ) {
@@ -506,6 +510,7 @@ impl PeerState {
506
510
// We abort the flow, and prune any data kept.
507
511
self . intercept_scid_by_channel_id . retain ( |_, iscid| intercept_scid != iscid) ;
508
512
self . intercept_scid_by_user_channel_id . retain ( |_, iscid| intercept_scid != iscid) ;
513
+ // TODO: Remove peer state entry from the KVStore
509
514
return false ;
510
515
}
511
516
true
@@ -533,6 +538,7 @@ impl_writeable_tlv_based!(PeerState, {
533
538
( 2 , intercept_scid_by_user_channel_id, required) ,
534
539
( 4 , intercept_scid_by_channel_id, required) ,
535
540
( _unused, pending_requests, ( static_value, new_hash_map( ) ) ) ,
541
+ ( _unused, needs_persist, ( static_value, false ) ) ,
536
542
} ) ;
537
543
538
544
macro_rules! get_or_insert_peer_state_entry {
@@ -823,6 +829,9 @@ where
823
829
match outer_state_lock. get ( counterparty_node_id) {
824
830
Some ( inner_state_lock) => {
825
831
let mut peer_state = inner_state_lock. lock ( ) . unwrap ( ) ;
832
+ peer_state. needs_persist |= peer_state
833
+ . outbound_channels_by_intercept_scid
834
+ . contains_key ( & intercept_scid) ;
826
835
if let Some ( jit_channel) =
827
836
peer_state. outbound_channels_by_intercept_scid . get_mut ( & intercept_scid)
828
837
{
@@ -910,6 +919,8 @@ where
910
919
match outer_state_lock. get ( counterparty_node_id) {
911
920
Some ( inner_state_lock) => {
912
921
let mut peer_state = inner_state_lock. lock ( ) . unwrap ( ) ;
922
+ peer_state. needs_persist |=
923
+ peer_state. intercept_scid_by_channel_id . contains_key ( & channel_id) ;
913
924
if let Some ( intercept_scid) =
914
925
peer_state. intercept_scid_by_channel_id . get ( & channel_id) . copied ( )
915
926
{
@@ -978,6 +989,8 @@ where
978
989
match outer_state_lock. get ( counterparty_node_id) {
979
990
Some ( inner_state_lock) => {
980
991
let mut peer_state = inner_state_lock. lock ( ) . unwrap ( ) ;
992
+ peer_state. needs_persist |=
993
+ peer_state. intercept_scid_by_channel_id . contains_key ( & next_channel_id) ;
981
994
if let Some ( intercept_scid) =
982
995
peer_state. intercept_scid_by_channel_id . get ( & next_channel_id) . copied ( )
983
996
{
@@ -1082,6 +1095,7 @@ where
1082
1095
peer_state. intercept_scid_by_user_channel_id . remove ( & user_channel_id) ;
1083
1096
peer_state. outbound_channels_by_intercept_scid . remove ( & intercept_scid) ;
1084
1097
peer_state. intercept_scid_by_channel_id . retain ( |_, & mut scid| scid != intercept_scid) ;
1098
+ peer_state. needs_persist |= true ;
1085
1099
1086
1100
Ok ( ( ) )
1087
1101
}
@@ -1113,6 +1127,8 @@ where
1113
1127
err : format ! ( "Could not find a channel with user_channel_id {}" , user_channel_id) ,
1114
1128
} ) ?;
1115
1129
1130
+ peer_state. needs_persist |=
1131
+ peer_state. outbound_channels_by_intercept_scid . contains_key ( & intercept_scid) ;
1116
1132
let jit_channel = peer_state
1117
1133
. outbound_channels_by_intercept_scid
1118
1134
. get_mut ( & intercept_scid)
@@ -1162,6 +1178,8 @@ where
1162
1178
match outer_state_lock. get ( counterparty_node_id) {
1163
1179
Some ( inner_state_lock) => {
1164
1180
let mut peer_state = inner_state_lock. lock ( ) . unwrap ( ) ;
1181
+ peer_state. needs_persist |=
1182
+ peer_state. intercept_scid_by_user_channel_id . contains_key ( & user_channel_id) ;
1165
1183
if let Some ( intercept_scid) =
1166
1184
peer_state. intercept_scid_by_user_channel_id . get ( & user_channel_id) . copied ( )
1167
1185
{
@@ -1484,7 +1502,16 @@ where
1484
1502
) ;
1485
1503
return Err ( err) ;
1486
1504
} ,
1487
- Some ( entry) => entry. lock ( ) . unwrap ( ) . encode ( ) ,
1505
+ Some ( entry) => {
1506
+ let mut peer_state_lock = entry. lock ( ) . unwrap ( ) ;
1507
+ if !peer_state_lock. needs_persist {
1508
+ // We already have persisted otherwise by now.
1509
+ return Ok ( ( ) ) ;
1510
+ } else {
1511
+ peer_state_lock. needs_persist = false ;
1512
+ peer_state_lock. encode ( )
1513
+ }
1514
+ } ,
1488
1515
}
1489
1516
} ;
1490
1517
@@ -1498,6 +1525,14 @@ where
1498
1525
encoded,
1499
1526
)
1500
1527
. await
1528
+ . map_err ( |e| {
1529
+ self . per_peer_state
1530
+ . read ( )
1531
+ . unwrap ( )
1532
+ . get ( & counterparty_node_id)
1533
+ . map ( |p| p. lock ( ) . unwrap ( ) . needs_persist = true ) ;
1534
+ e
1535
+ } )
1501
1536
}
1502
1537
1503
1538
pub ( crate ) async fn persist ( & self ) -> Result < ( ) , lightning:: io:: Error > {
@@ -1506,7 +1541,10 @@ where
1506
1541
// time.
1507
1542
let need_persist: Vec < PublicKey > = {
1508
1543
let outer_state_lock = self . per_peer_state . read ( ) . unwrap ( ) ;
1509
- outer_state_lock. iter ( ) . filter_map ( |( k, v) | Some ( * k) ) . collect ( )
1544
+ outer_state_lock
1545
+ . iter ( )
1546
+ . filter_map ( |( k, v) | if v. lock ( ) . unwrap ( ) . needs_persist { Some ( * k) } else { None } )
1547
+ . collect ( )
1510
1548
} ;
1511
1549
1512
1550
for counterparty_node_id in need_persist. into_iter ( ) {
0 commit comments