Skip to content

Commit ca3cdff

Browse files
committed
[refactor] Implementing own Message for GrpcMessage
Signed-off-by: dd di cesare <[email protected]>
1 parent c4796de commit ca3cdff

File tree

2 files changed

+124
-23
lines changed

2 files changed

+124
-23
lines changed

src/operation_dispatcher.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ impl OperationDispatcher {
9696
policy.actions.iter().for_each(|action| {
9797
// TODO(didierofrivia): Error handling
9898
if let Some(service) = self.service_handlers.get(&action.extension) {
99-
let message = service.build_message(policy.domain.clone(), descriptors.clone());
99+
let message = GrpcMessage::new(service.get_extension_type(), policy.domain.clone(), descriptors.clone());
100100
operations.push(Operation::new((service.clone(), message)))
101101
}
102102
});

src/service.rs

+123-22
Original file line numberDiff line numberDiff line change
@@ -2,33 +2,144 @@ pub(crate) mod auth;
22
pub(crate) mod rate_limit;
33

44
use crate::configuration::{ExtensionType, FailureMode};
5-
use crate::envoy::{RateLimitDescriptor, RateLimitRequest};
6-
use crate::service::auth::{AUTH_METHOD_NAME, AUTH_SERVICE_NAME};
5+
use crate::envoy::{CheckRequest, RateLimitDescriptor, RateLimitRequest};
6+
use crate::service::auth::{AuthService, AUTH_METHOD_NAME, AUTH_SERVICE_NAME};
77
use crate::service::rate_limit::{RateLimitService, RATELIMIT_METHOD_NAME, RATELIMIT_SERVICE_NAME};
88
use crate::service::TracingHeader::{Baggage, Traceparent, Tracestate};
9-
use protobuf::Message;
9+
use protobuf::reflect::MessageDescriptor;
10+
use protobuf::{
11+
Clear, CodedInputStream, CodedOutputStream, Message, ProtobufResult, UnknownFields,
12+
};
1013
use proxy_wasm::hostcalls;
1114
use proxy_wasm::hostcalls::dispatch_grpc_call;
1215
use proxy_wasm::types::{Bytes, MapType, Status};
16+
use std::any::Any;
1317
use std::cell::OnceCell;
18+
use std::fmt::{Debug};
1419
use std::rc::Rc;
1520
use std::time::Duration;
1621

17-
#[derive(Clone)]
22+
#[derive(Clone, Debug)]
1823
pub enum GrpcMessage {
19-
//Auth(CheckRequest),
24+
Auth(CheckRequest),
2025
RateLimit(RateLimitRequest),
2126
}
2227

23-
impl GrpcMessage {
24-
pub fn get_message(&self) -> &RateLimitRequest {
25-
//TODO(didierofrivia): Should return Message
28+
impl Default for GrpcMessage {
29+
fn default() -> Self {
30+
GrpcMessage::RateLimit(RateLimitRequest::new())
31+
}
32+
}
33+
34+
impl Clear for GrpcMessage {
35+
fn clear(&mut self) {
2636
match self {
27-
GrpcMessage::RateLimit(message) => message,
37+
GrpcMessage::Auth(msg) => msg.clear(),
38+
GrpcMessage::RateLimit(msg) => msg.clear(),
2839
}
2940
}
3041
}
3142

43+
impl Message for GrpcMessage {
44+
fn descriptor(&self) -> &'static MessageDescriptor {
45+
match self {
46+
GrpcMessage::Auth(msg) => msg.descriptor(),
47+
GrpcMessage::RateLimit(msg) => msg.descriptor(),
48+
}
49+
}
50+
51+
fn is_initialized(&self) -> bool {
52+
match self {
53+
GrpcMessage::Auth(msg) => msg.is_initialized(),
54+
GrpcMessage::RateLimit(msg) => msg.is_initialized(),
55+
}
56+
}
57+
58+
fn merge_from(&mut self, is: &mut CodedInputStream) -> ProtobufResult<()> {
59+
match self {
60+
GrpcMessage::Auth(msg) => msg.merge_from(is),
61+
GrpcMessage::RateLimit(msg) => msg.merge_from(is),
62+
}
63+
}
64+
65+
fn write_to_with_cached_sizes(&self, os: &mut CodedOutputStream) -> ProtobufResult<()> {
66+
match self {
67+
GrpcMessage::Auth(msg) => msg.write_to_with_cached_sizes(os),
68+
GrpcMessage::RateLimit(msg) => msg.write_to_with_cached_sizes(os),
69+
}
70+
}
71+
72+
fn write_to_bytes(&self) -> ProtobufResult<Vec<u8>> {
73+
match self {
74+
GrpcMessage::Auth(msg) => msg.write_to_bytes(),
75+
GrpcMessage::RateLimit(msg) => msg.write_to_bytes(),
76+
}
77+
}
78+
79+
fn compute_size(&self) -> u32 {
80+
match self {
81+
GrpcMessage::Auth(msg) => msg.compute_size(),
82+
GrpcMessage::RateLimit(msg) => msg.compute_size(),
83+
}
84+
}
85+
86+
fn get_cached_size(&self) -> u32 {
87+
match self {
88+
GrpcMessage::Auth(msg) => msg.get_cached_size(),
89+
GrpcMessage::RateLimit(msg) => msg.get_cached_size(),
90+
}
91+
}
92+
93+
fn get_unknown_fields(&self) -> &UnknownFields {
94+
match self {
95+
GrpcMessage::Auth(msg) => msg.get_unknown_fields(),
96+
GrpcMessage::RateLimit(msg) => msg.get_unknown_fields(),
97+
}
98+
}
99+
100+
fn mut_unknown_fields(&mut self) -> &mut UnknownFields {
101+
match self {
102+
GrpcMessage::Auth(msg) => msg.mut_unknown_fields(),
103+
GrpcMessage::RateLimit(msg) => msg.mut_unknown_fields(),
104+
}
105+
}
106+
107+
fn as_any(&self) -> &dyn Any {
108+
match self {
109+
GrpcMessage::Auth(msg) => msg.as_any(),
110+
GrpcMessage::RateLimit(msg) => msg.as_any(),
111+
}
112+
}
113+
114+
fn new() -> Self
115+
where
116+
Self: Sized,
117+
{
118+
// Returning default value
119+
GrpcMessage::default()
120+
}
121+
122+
fn default_instance() -> &'static Self
123+
where
124+
Self: Sized,
125+
{
126+
#[allow(non_upper_case_globals)]
127+
static instance: ::protobuf::rt::LazyV2<GrpcMessage> = ::protobuf::rt::LazyV2::INIT;
128+
instance.get(|| GrpcMessage::RateLimit(RateLimitRequest::new()))
129+
}
130+
}
131+
132+
impl GrpcMessage {
133+
// Using domain as ce_host for the time being, we might pass a DataType in the future.
134+
pub fn new(extension_type: ExtensionType, domain: String, descriptors: protobuf::RepeatedField<RateLimitDescriptor>) -> Self {
135+
match extension_type {
136+
ExtensionType::RateLimit => GrpcMessage::RateLimit(RateLimitService::message(domain.clone(), descriptors)),
137+
ExtensionType::Auth => GrpcMessage::Auth(AuthService::message(domain.clone()))
138+
}
139+
}
140+
141+
}
142+
32143
#[derive(Default)]
33144
pub struct GrpcService {
34145
endpoint: String,
@@ -102,7 +213,7 @@ impl GrpcServiceHandler {
102213
}
103214

104215
pub fn send(&self, message: GrpcMessage) -> Result<u32, Status> {
105-
let msg = Message::write_to_bytes(message.get_message()).unwrap();
216+
let msg = Message::write_to_bytes(&message).unwrap();
106217
let metadata = self
107218
.header_resolver
108219
.get()
@@ -120,18 +231,8 @@ impl GrpcServiceHandler {
120231
)
121232
}
122233

123-
// Using domain as ce_host for the time being, we might pass a DataType in the future.
124-
//TODO(didierofrivia): Make it work with Message. for both Auth and RL
125-
pub fn build_message(
126-
&self,
127-
domain: String,
128-
descriptors: protobuf::RepeatedField<RateLimitDescriptor>,
129-
) -> GrpcMessage {
130-
/*match self.service.extension_type {
131-
//ExtensionType::Auth => GrpcMessage::Auth(AuthService::message(domain.clone())),
132-
//ExtensionType::RateLimit => GrpcMessage::RateLimit(RateLimitService::message(domain.clone(), descriptors)),
133-
}*/
134-
GrpcMessage::RateLimit(RateLimitService::message(domain.clone(), descriptors))
234+
pub fn get_extension_type(&self) -> ExtensionType {
235+
self.service.extension_type.clone()
135236
}
136237
}
137238

0 commit comments

Comments
 (0)