1
1
use crate :: configuration:: { Extension , ExtensionType , FailureMode } ;
2
2
use crate :: envoy:: RateLimitDescriptor ;
3
3
use crate :: policy:: Policy ;
4
- use crate :: service:: { GetMapValuesBytes , GrpcCall , GrpcMessage , GrpcServiceHandler } ;
4
+ use crate :: service:: { GetMapValuesBytesFn , GrpcCallFn , GrpcMessage , GrpcServiceHandler } ;
5
5
use protobuf:: RepeatedField ;
6
6
use proxy_wasm:: hostcalls;
7
7
use proxy_wasm:: types:: { Bytes , MapType , Status } ;
@@ -38,8 +38,8 @@ pub(crate) struct Operation {
38
38
result : Result < u32 , Status > ,
39
39
extension : Rc < Extension > ,
40
40
procedure : Procedure ,
41
- grpc_call : GrpcCall ,
42
- get_map_values_bytes : GetMapValuesBytes ,
41
+ grpc_call_fn : GrpcCallFn ,
42
+ get_map_values_bytes_fn : GetMapValuesBytesFn ,
43
43
}
44
44
45
45
#[ allow( dead_code) ]
@@ -50,17 +50,17 @@ impl Operation {
50
50
result : Err ( Status :: Empty ) ,
51
51
extension,
52
52
procedure,
53
- grpc_call ,
54
- get_map_values_bytes ,
53
+ grpc_call_fn ,
54
+ get_map_values_bytes_fn ,
55
55
}
56
56
}
57
57
58
58
fn trigger ( & mut self ) {
59
59
if let State :: Done = self . state {
60
60
} else {
61
61
self . result = self . procedure . 0 . send (
62
- self . get_map_values_bytes ,
63
- self . grpc_call ,
62
+ self . get_map_values_bytes_fn ,
63
+ self . grpc_call_fn ,
64
64
self . procedure . 1 . clone ( ) ,
65
65
) ;
66
66
self . state . next ( ) ;
@@ -147,7 +147,8 @@ impl OperationDispatcher {
147
147
let mut operations = self . operations . borrow_mut ( ) ;
148
148
if let Some ( ( i, operation) ) = operations. iter_mut ( ) . enumerate ( ) . next ( ) {
149
149
if let State :: Done = operation. get_state ( ) {
150
- Some ( operations. remove ( i) )
150
+ operations. remove ( i) ;
151
+ operations. get ( i) . cloned ( ) // The next op is now at `i`
151
152
} else {
152
153
operation. trigger ( ) ;
153
154
Some ( operation. clone ( ) )
@@ -158,7 +159,7 @@ impl OperationDispatcher {
158
159
}
159
160
}
160
161
161
- fn grpc_call (
162
+ fn grpc_call_fn (
162
163
upstream_name : & str ,
163
164
service_name : & str ,
164
165
method_name : & str ,
@@ -176,7 +177,7 @@ fn grpc_call(
176
177
)
177
178
}
178
179
179
- fn get_map_values_bytes ( map_type : MapType , key : & str ) -> Result < Option < Bytes > , Status > {
180
+ fn get_map_values_bytes_fn ( map_type : MapType , key : & str ) -> Result < Option < Bytes > , Status > {
180
181
hostcalls:: get_map_value_bytes ( map_type, key)
181
182
}
182
183
@@ -186,7 +187,7 @@ mod tests {
186
187
use crate :: envoy:: RateLimitRequest ;
187
188
use std:: time:: Duration ;
188
189
189
- fn grpc_call (
190
+ fn grpc_call_fn_stub (
190
191
_upstream_name : & str ,
191
192
_service_name : & str ,
192
193
_method_name : & str ,
@@ -197,7 +198,10 @@ mod tests {
197
198
Ok ( 200 )
198
199
}
199
200
200
- fn get_map_values_bytes ( _map_type : MapType , _key : & str ) -> Result < Option < Bytes > , Status > {
201
+ fn get_map_values_bytes_fn_stub (
202
+ _map_type : MapType ,
203
+ _key : & str ,
204
+ ) -> Result < Option < Bytes > , Status > {
201
205
Ok ( Some ( Vec :: new ( ) ) )
202
206
}
203
207
@@ -218,14 +222,14 @@ mod tests {
218
222
fn build_operation ( ) -> Operation {
219
223
Operation {
220
224
state : State :: Pending ,
221
- result : Ok ( 200 ) ,
225
+ result : Ok ( 1 ) ,
222
226
extension : Rc :: new ( Extension :: default ( ) ) ,
223
227
procedure : (
224
228
Rc :: new ( build_grpc_service_handler ( ) ) ,
225
229
GrpcMessage :: RateLimit ( build_message ( ) ) ,
226
230
) ,
227
- grpc_call ,
228
- get_map_values_bytes ,
231
+ grpc_call_fn : grpc_call_fn_stub ,
232
+ get_map_values_bytes_fn : get_map_values_bytes_fn_stub ,
229
233
}
230
234
}
231
235
@@ -236,7 +240,7 @@ mod tests {
236
240
assert_eq ! ( operation. get_state( ) , State :: Pending ) ;
237
241
assert_eq ! ( operation. get_extension_type( ) , ExtensionType :: RateLimit ) ;
238
242
assert_eq ! ( operation. get_failure_mode( ) , FailureMode :: Deny ) ;
239
- assert_eq ! ( operation. get_result( ) , Ok ( 200 ) ) ;
243
+ assert_eq ! ( operation. get_result( ) , Ok ( 1 ) ) ;
240
244
}
241
245
242
246
#[ test]
@@ -272,20 +276,37 @@ mod tests {
272
276
273
277
#[ test]
274
278
fn operation_dispatcher_next ( ) {
275
- let operation = build_operation ( ) ;
276
279
let operation_dispatcher = OperationDispatcher :: default ( ) ;
277
- operation_dispatcher. push_operations ( vec ! [ operation ] ) ;
280
+ operation_dispatcher. push_operations ( vec ! [ build_operation ( ) , build_operation ( ) ] ) ;
278
281
279
- if let Some ( operation) = operation_dispatcher. next ( ) {
280
- assert_eq ! ( operation. get_result( ) , Ok ( 200 ) ) ;
281
- assert_eq ! ( operation. get_state( ) , State :: Waiting ) ;
282
- }
282
+ assert_eq ! ( operation_dispatcher. get_current_operation_result( ) , Ok ( 1 ) ) ;
283
+ assert_eq ! (
284
+ operation_dispatcher. get_current_operation_state( ) ,
285
+ Some ( State :: Pending )
286
+ ) ;
283
287
284
- if let Some ( operation) = operation_dispatcher. next ( ) {
285
- assert_eq ! ( operation. get_result( ) , Ok ( 200 ) ) ;
286
- assert_eq ! ( operation. get_state( ) , State :: Done ) ;
287
- }
288
- operation_dispatcher. next ( ) ;
289
- assert_eq ! ( operation_dispatcher. get_current_operation_state( ) , None ) ;
288
+ let mut op = operation_dispatcher. next ( ) ;
289
+ assert_eq ! ( op. clone( ) . unwrap( ) . get_result( ) , Ok ( 200 ) ) ;
290
+ assert_eq ! ( op. unwrap( ) . get_state( ) , State :: Waiting ) ;
291
+
292
+ op = operation_dispatcher. next ( ) ;
293
+ assert_eq ! ( op. clone( ) . unwrap( ) . get_result( ) , Ok ( 200 ) ) ;
294
+ assert_eq ! ( op. unwrap( ) . get_state( ) , State :: Done ) ;
295
+
296
+ op = operation_dispatcher. next ( ) ;
297
+ assert_eq ! ( op. clone( ) . unwrap( ) . get_result( ) , Ok ( 1 ) ) ;
298
+ assert_eq ! ( op. unwrap( ) . get_state( ) , State :: Pending ) ;
299
+
300
+ op = operation_dispatcher. next ( ) ;
301
+ assert_eq ! ( op. clone( ) . unwrap( ) . get_result( ) , Ok ( 200 ) ) ;
302
+ assert_eq ! ( op. unwrap( ) . get_state( ) , State :: Waiting ) ;
303
+
304
+ op = operation_dispatcher. next ( ) ;
305
+ assert_eq ! ( op. clone( ) . unwrap( ) . get_result( ) , Ok ( 200 ) ) ;
306
+ assert_eq ! ( op. unwrap( ) . get_state( ) , State :: Done ) ;
307
+
308
+ op = operation_dispatcher. next ( ) ;
309
+ assert ! ( op. is_none( ) ) ;
310
+ assert ! ( operation_dispatcher. get_current_operation_state( ) . is_none( ) ) ;
290
311
}
291
312
}
0 commit comments