@@ -767,6 +767,58 @@ def send_alt_svc(self, previous_state):
767
767
(H2StreamStateMachine .send_on_closed_stream , StreamState .CLOSED ),
768
768
}
769
769
770
+ """
771
+ Wraps a stream state change function to ensure that we keep
772
+ the parent H2Connection's state in sync
773
+ """
774
+ def sync_state_change (func ):
775
+ def wrapper (self , * args , ** kwargs ):
776
+ # Collect state at the beginning.
777
+ start_state = self .state_machine .state
778
+ started_open = self .open
779
+ started_closed = not started_open
780
+
781
+ # Do the state change (if any).
782
+ result = func (self , * args , ** kwargs )
783
+
784
+ # Collect state at the end.
785
+ end_state = self .state_machine .state
786
+ ended_open = self .open
787
+ ended_closed = not ended_open
788
+
789
+ # If at any point we've tranwsitioned to the CLOSED state
790
+ # from any other state, close our stream.
791
+ if end_state == StreamState .CLOSED and start_state != end_state :
792
+ if self ._close_stream_callback :
793
+ self ._close_stream_callback (self .stream_id )
794
+ # Clear callback so we only call this once per stream
795
+ self ._close_stream_callback = None
796
+
797
+ # If we were open, but are now closed, decrement
798
+ # the open stream count, and call the close callback.
799
+ if started_open and ended_closed :
800
+ if self ._decrement_open_stream_count_callback :
801
+ self ._decrement_open_stream_count_callback (self .stream_id ,
802
+ - 1 ,)
803
+ # Clear callback so we only call this once per stream
804
+ self ._decrement_open_stream_count_callback = None
805
+
806
+ if self ._close_stream_callback :
807
+ self ._close_stream_callback (self .stream_id )
808
+ # Clear callback so we only call this once per stream
809
+ self ._close_stream_callback = None
810
+
811
+ # If we were closed, but are now open, increment
812
+ # the open stream count.
813
+ elif started_closed and ended_open :
814
+ if self ._increment_open_stream_count_callback :
815
+ self ._increment_open_stream_count_callback (self .stream_id ,
816
+ 1 ,)
817
+ # Clear callback so we only call this once per stream
818
+ self ._increment_open_stream_count_callback = None
819
+ return result
820
+ return wrapper
821
+
770
822
771
823
class H2Stream (object ):
772
824
"""
@@ -782,18 +834,29 @@ def __init__(self,
782
834
stream_id ,
783
835
config ,
784
836
inbound_window_size ,
785
- outbound_window_size ):
837
+ outbound_window_size ,
838
+ increment_open_stream_count_callback ,
839
+ close_stream_callback ,):
786
840
self .state_machine = H2StreamStateMachine (stream_id )
787
841
self .stream_id = stream_id
788
842
self .max_outbound_frame_size = None
789
843
self .request_method = None
790
844
791
- # The current value of the outbound stream flow control window
845
+ # The current value of the outbound stream flow control window.
792
846
self .outbound_flow_control_window = outbound_window_size
793
847
794
848
# The flow control manager.
795
849
self ._inbound_window_manager = WindowManager (inbound_window_size )
796
850
851
+ # Callback to increment open stream count for the H2Connection.
852
+ self ._increment_open_stream_count_callback = increment_open_stream_count_callback
853
+
854
+ # Callback to decrement open stream count for the H2Connection.
855
+ self ._decrement_open_stream_count_callback = increment_open_stream_count_callback
856
+
857
+ # Callback to clean up state for the H2Connection once we're closed.
858
+ self ._close_stream_callback = close_stream_callback
859
+
797
860
# The expected content length, if any.
798
861
self ._expected_content_length = None
799
862
@@ -850,6 +913,7 @@ def closed_by(self):
850
913
"""
851
914
return self .state_machine .stream_closed_by
852
915
916
+ @sync_state_change
853
917
def upgrade (self , client_side ):
854
918
"""
855
919
Called by the connection to indicate that this stream is the initial
@@ -868,6 +932,7 @@ def upgrade(self, client_side):
868
932
self .state_machine .process_input (input_ )
869
933
return
870
934
935
+ @sync_state_change
871
936
def send_headers (self , headers , encoder , end_stream = False ):
872
937
"""
873
938
Returns a list of HEADERS/CONTINUATION frames to emit as either headers
@@ -917,6 +982,7 @@ def send_headers(self, headers, encoder, end_stream=False):
917
982
918
983
return frames
919
984
985
+ @sync_state_change
920
986
def push_stream_in_band (self , related_stream_id , headers , encoder ):
921
987
"""
922
988
Returns a list of PUSH_PROMISE/CONTINUATION frames to emit as a pushed
@@ -941,6 +1007,7 @@ def push_stream_in_band(self, related_stream_id, headers, encoder):
941
1007
942
1008
return frames
943
1009
1010
+ @sync_state_change
944
1011
def locally_pushed (self ):
945
1012
"""
946
1013
Mark this stream as one that was pushed by this peer. Must be called
@@ -954,6 +1021,7 @@ def locally_pushed(self):
954
1021
assert not events
955
1022
return []
956
1023
1024
+ @sync_state_change
957
1025
def send_data (self , data , end_stream = False , pad_length = None ):
958
1026
"""
959
1027
Prepare some data frames. Optionally end the stream.
@@ -981,6 +1049,7 @@ def send_data(self, data, end_stream=False, pad_length=None):
981
1049
982
1050
return [df ]
983
1051
1052
+ @sync_state_change
984
1053
def end_stream (self ):
985
1054
"""
986
1055
End a stream without sending data.
@@ -992,6 +1061,7 @@ def end_stream(self):
992
1061
df .flags .add ('END_STREAM' )
993
1062
return [df ]
994
1063
1064
+ @sync_state_change
995
1065
def advertise_alternative_service (self , field_value ):
996
1066
"""
997
1067
Advertise an RFC 7838 alternative service. The semantics of this are
@@ -1005,6 +1075,7 @@ def advertise_alternative_service(self, field_value):
1005
1075
asf .field = field_value
1006
1076
return [asf ]
1007
1077
1078
+ @sync_state_change
1008
1079
def increase_flow_control_window (self , increment ):
1009
1080
"""
1010
1081
Increase the size of the flow control window for the remote side.
@@ -1020,6 +1091,7 @@ def increase_flow_control_window(self, increment):
1020
1091
wuf .window_increment = increment
1021
1092
return [wuf ]
1022
1093
1094
+ @sync_state_change
1023
1095
def receive_push_promise_in_band (self ,
1024
1096
promised_stream_id ,
1025
1097
headers ,
@@ -1044,6 +1116,7 @@ def receive_push_promise_in_band(self,
1044
1116
)
1045
1117
return [], events
1046
1118
1119
+ @sync_state_change
1047
1120
def remotely_pushed (self , pushed_headers ):
1048
1121
"""
1049
1122
Mark this stream as one that was pushed by the remote peer. Must be
@@ -1057,6 +1130,7 @@ def remotely_pushed(self, pushed_headers):
1057
1130
self ._authority = authority_from_headers (pushed_headers )
1058
1131
return [], events
1059
1132
1133
+ @sync_state_change
1060
1134
def receive_headers (self , headers , end_stream , header_encoding ):
1061
1135
"""
1062
1136
Receive a set of headers (or trailers).
@@ -1091,6 +1165,7 @@ def receive_headers(self, headers, end_stream, header_encoding):
1091
1165
)
1092
1166
return [], events
1093
1167
1168
+ @sync_state_change
1094
1169
def receive_data (self , data , end_stream , flow_control_len ):
1095
1170
"""
1096
1171
Receive some data.
@@ -1114,6 +1189,7 @@ def receive_data(self, data, end_stream, flow_control_len):
1114
1189
events [0 ].flow_controlled_length = flow_control_len
1115
1190
return [], events
1116
1191
1192
+ @sync_state_change
1117
1193
def receive_window_update (self , increment ):
1118
1194
"""
1119
1195
Handle a WINDOW_UPDATE increment.
@@ -1150,6 +1226,7 @@ def receive_window_update(self, increment):
1150
1226
1151
1227
return frames , events
1152
1228
1229
+ @sync_state_change
1153
1230
def receive_continuation (self ):
1154
1231
"""
1155
1232
A naked CONTINUATION frame has been received. This is always an error,
@@ -1162,6 +1239,7 @@ def receive_continuation(self):
1162
1239
)
1163
1240
assert False , "Should not be reachable"
1164
1241
1242
+ @sync_state_change
1165
1243
def receive_alt_svc (self , frame ):
1166
1244
"""
1167
1245
An Alternative Service frame was received on the stream. This frame
@@ -1189,6 +1267,7 @@ def receive_alt_svc(self, frame):
1189
1267
1190
1268
return [], events
1191
1269
1270
+ @sync_state_change
1192
1271
def reset_stream (self , error_code = 0 ):
1193
1272
"""
1194
1273
Close the stream locally. Reset the stream with an error code.
@@ -1202,6 +1281,7 @@ def reset_stream(self, error_code=0):
1202
1281
rsf .error_code = error_code
1203
1282
return [rsf ]
1204
1283
1284
+ @sync_state_change
1205
1285
def stream_reset (self , frame ):
1206
1286
"""
1207
1287
Handle a stream being reset remotely.
@@ -1217,6 +1297,7 @@ def stream_reset(self, frame):
1217
1297
1218
1298
return [], events
1219
1299
1300
+ @sync_state_change
1220
1301
def acknowledge_received_data (self , acknowledged_size ):
1221
1302
"""
1222
1303
The user has informed us that they've processed some amount of data
0 commit comments