@@ -2,33 +2,144 @@ pub(crate) mod auth;
2
2
pub ( crate ) mod rate_limit;
3
3
4
4
use crate :: configuration:: { ExtensionType , FailureMode } ;
5
- use crate :: envoy:: { RateLimitDescriptor , RateLimitRequest } ;
6
- use crate :: service:: auth:: { AUTH_METHOD_NAME , AUTH_SERVICE_NAME } ;
5
+ use crate :: envoy:: { CheckRequest , RateLimitDescriptor , RateLimitRequest } ;
6
+ use crate :: service:: auth:: { AuthService , AUTH_METHOD_NAME , AUTH_SERVICE_NAME } ;
7
7
use crate :: service:: rate_limit:: { RateLimitService , RATELIMIT_METHOD_NAME , RATELIMIT_SERVICE_NAME } ;
8
8
use crate :: service:: TracingHeader :: { Baggage , Traceparent , Tracestate } ;
9
- use protobuf:: Message ;
9
+ use protobuf:: reflect:: MessageDescriptor ;
10
+ use protobuf:: {
11
+ Clear , CodedInputStream , CodedOutputStream , Message , ProtobufResult , UnknownFields ,
12
+ } ;
10
13
use proxy_wasm:: hostcalls;
11
14
use proxy_wasm:: hostcalls:: dispatch_grpc_call;
12
15
use proxy_wasm:: types:: { Bytes , MapType , Status } ;
16
+ use std:: any:: Any ;
13
17
use std:: cell:: OnceCell ;
18
+ use std:: fmt:: { Debug } ;
14
19
use std:: rc:: Rc ;
15
20
use std:: time:: Duration ;
16
21
17
- #[ derive( Clone ) ]
22
+ #[ derive( Clone , Debug ) ]
18
23
pub enum GrpcMessage {
19
- // Auth(CheckRequest),
24
+ Auth ( CheckRequest ) ,
20
25
RateLimit ( RateLimitRequest ) ,
21
26
}
22
27
23
- impl GrpcMessage {
24
- pub fn get_message ( & self ) -> & RateLimitRequest {
25
- //TODO(didierofrivia): Should return Message
28
+ impl Default for GrpcMessage {
29
+ fn default ( ) -> Self {
30
+ GrpcMessage :: RateLimit ( RateLimitRequest :: new ( ) )
31
+ }
32
+ }
33
+
34
+ impl Clear for GrpcMessage {
35
+ fn clear ( & mut self ) {
26
36
match self {
27
- GrpcMessage :: RateLimit ( message) => message,
37
+ GrpcMessage :: Auth ( msg) => msg. clear ( ) ,
38
+ GrpcMessage :: RateLimit ( msg) => msg. clear ( ) ,
28
39
}
29
40
}
30
41
}
31
42
43
+ impl Message for GrpcMessage {
44
+ fn descriptor ( & self ) -> & ' static MessageDescriptor {
45
+ match self {
46
+ GrpcMessage :: Auth ( msg) => msg. descriptor ( ) ,
47
+ GrpcMessage :: RateLimit ( msg) => msg. descriptor ( ) ,
48
+ }
49
+ }
50
+
51
+ fn is_initialized ( & self ) -> bool {
52
+ match self {
53
+ GrpcMessage :: Auth ( msg) => msg. is_initialized ( ) ,
54
+ GrpcMessage :: RateLimit ( msg) => msg. is_initialized ( ) ,
55
+ }
56
+ }
57
+
58
+ fn merge_from ( & mut self , is : & mut CodedInputStream ) -> ProtobufResult < ( ) > {
59
+ match self {
60
+ GrpcMessage :: Auth ( msg) => msg. merge_from ( is) ,
61
+ GrpcMessage :: RateLimit ( msg) => msg. merge_from ( is) ,
62
+ }
63
+ }
64
+
65
+ fn write_to_with_cached_sizes ( & self , os : & mut CodedOutputStream ) -> ProtobufResult < ( ) > {
66
+ match self {
67
+ GrpcMessage :: Auth ( msg) => msg. write_to_with_cached_sizes ( os) ,
68
+ GrpcMessage :: RateLimit ( msg) => msg. write_to_with_cached_sizes ( os) ,
69
+ }
70
+ }
71
+
72
+ fn write_to_bytes ( & self ) -> ProtobufResult < Vec < u8 > > {
73
+ match self {
74
+ GrpcMessage :: Auth ( msg) => msg. write_to_bytes ( ) ,
75
+ GrpcMessage :: RateLimit ( msg) => msg. write_to_bytes ( ) ,
76
+ }
77
+ }
78
+
79
+ fn compute_size ( & self ) -> u32 {
80
+ match self {
81
+ GrpcMessage :: Auth ( msg) => msg. compute_size ( ) ,
82
+ GrpcMessage :: RateLimit ( msg) => msg. compute_size ( ) ,
83
+ }
84
+ }
85
+
86
+ fn get_cached_size ( & self ) -> u32 {
87
+ match self {
88
+ GrpcMessage :: Auth ( msg) => msg. get_cached_size ( ) ,
89
+ GrpcMessage :: RateLimit ( msg) => msg. get_cached_size ( ) ,
90
+ }
91
+ }
92
+
93
+ fn get_unknown_fields ( & self ) -> & UnknownFields {
94
+ match self {
95
+ GrpcMessage :: Auth ( msg) => msg. get_unknown_fields ( ) ,
96
+ GrpcMessage :: RateLimit ( msg) => msg. get_unknown_fields ( ) ,
97
+ }
98
+ }
99
+
100
+ fn mut_unknown_fields ( & mut self ) -> & mut UnknownFields {
101
+ match self {
102
+ GrpcMessage :: Auth ( msg) => msg. mut_unknown_fields ( ) ,
103
+ GrpcMessage :: RateLimit ( msg) => msg. mut_unknown_fields ( ) ,
104
+ }
105
+ }
106
+
107
+ fn as_any ( & self ) -> & dyn Any {
108
+ match self {
109
+ GrpcMessage :: Auth ( msg) => msg. as_any ( ) ,
110
+ GrpcMessage :: RateLimit ( msg) => msg. as_any ( ) ,
111
+ }
112
+ }
113
+
114
+ fn new ( ) -> Self
115
+ where
116
+ Self : Sized ,
117
+ {
118
+ // Returning default value
119
+ GrpcMessage :: default ( )
120
+ }
121
+
122
+ fn default_instance ( ) -> & ' static Self
123
+ where
124
+ Self : Sized ,
125
+ {
126
+ #[ allow( non_upper_case_globals) ]
127
+ static instance: :: protobuf:: rt:: LazyV2 < GrpcMessage > = :: protobuf:: rt:: LazyV2 :: INIT ;
128
+ instance. get ( || GrpcMessage :: RateLimit ( RateLimitRequest :: new ( ) ) )
129
+ }
130
+ }
131
+
132
+ impl GrpcMessage {
133
+ // Using domain as ce_host for the time being, we might pass a DataType in the future.
134
+ pub fn new ( extension_type : ExtensionType , domain : String , descriptors : protobuf:: RepeatedField < RateLimitDescriptor > ) -> Self {
135
+ match extension_type {
136
+ ExtensionType :: RateLimit => GrpcMessage :: RateLimit ( RateLimitService :: message ( domain. clone ( ) , descriptors) ) ,
137
+ ExtensionType :: Auth => GrpcMessage :: Auth ( AuthService :: message ( domain. clone ( ) ) )
138
+ }
139
+ }
140
+
141
+ }
142
+
32
143
#[ derive( Default ) ]
33
144
pub struct GrpcService {
34
145
endpoint : String ,
@@ -102,7 +213,7 @@ impl GrpcServiceHandler {
102
213
}
103
214
104
215
pub fn send ( & self , message : GrpcMessage ) -> Result < u32 , Status > {
105
- let msg = Message :: write_to_bytes ( message. get_message ( ) ) . unwrap ( ) ;
216
+ let msg = Message :: write_to_bytes ( & message) . unwrap ( ) ;
106
217
let metadata = self
107
218
. header_resolver
108
219
. get ( )
@@ -120,18 +231,8 @@ impl GrpcServiceHandler {
120
231
)
121
232
}
122
233
123
- // Using domain as ce_host for the time being, we might pass a DataType in the future.
124
- //TODO(didierofrivia): Make it work with Message. for both Auth and RL
125
- pub fn build_message (
126
- & self ,
127
- domain : String ,
128
- descriptors : protobuf:: RepeatedField < RateLimitDescriptor > ,
129
- ) -> GrpcMessage {
130
- /*match self.service.extension_type {
131
- //ExtensionType::Auth => GrpcMessage::Auth(AuthService::message(domain.clone())),
132
- //ExtensionType::RateLimit => GrpcMessage::RateLimit(RateLimitService::message(domain.clone(), descriptors)),
133
- }*/
134
- GrpcMessage :: RateLimit ( RateLimitService :: message ( domain. clone ( ) , descriptors) )
234
+ pub fn get_extension_type ( & self ) -> ExtensionType {
235
+ self . service . extension_type . clone ( )
135
236
}
136
237
}
137
238
0 commit comments