Skip to content

Commit

Permalink
[refactor, feature] Adding a new GrpcResponse type
Browse files Browse the repository at this point in the history
* Services also building the response Message

Signed-off-by: dd di cesare <[email protected]>
  • Loading branch information
didierofrivia committed Sep 13, 2024
1 parent cba7a43 commit c8d6c4d
Show file tree
Hide file tree
Showing 4 changed files with 188 additions and 13 deletions.
2 changes: 1 addition & 1 deletion src/envoy/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ pub use {
AttributeContext_Request,
},
base::Metadata,
external_auth::CheckRequest,
external_auth::{CheckRequest, DeniedHttpResponse, OkHttpResponse},
ratelimit::{RateLimitDescriptor, RateLimitDescriptor_Entry},
rls::{RateLimitRequest, RateLimitResponse, RateLimitResponse_Code},
};
Expand Down
31 changes: 29 additions & 2 deletions src/service/auth.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@ use crate::envoy::{
Address, AttributeContext, AttributeContext_HttpRequest, AttributeContext_Peer,
AttributeContext_Request, CheckRequest, Metadata, SocketAddress,
};
use crate::service::grpc_message::{GrpcMessageResponse, GrpcMessageResult};
use chrono::{DateTime, FixedOffset, Timelike};
use protobuf::well_known_types::Timestamp;
use protobuf::Message;
use proxy_wasm::hostcalls;
use proxy_wasm::types::MapType;
use proxy_wasm::types::{Bytes, MapType};
use std::collections::HashMap;

pub const AUTH_SERVICE_NAME: &str = "envoy.service.auth.v3.Authorization";
Expand All @@ -16,10 +18,35 @@ pub struct AuthService;

#[allow(dead_code)]
impl AuthService {
pub fn message(ce_host: String) -> CheckRequest {
pub fn request_message(ce_host: String) -> CheckRequest {
AuthService::build_check_req(ce_host)
}

pub fn response_message(
res_body_bytes: &Bytes,
status_code: u32,
) -> GrpcMessageResult<GrpcMessageResponse> {
if status_code % 2 == 0 {
AuthService::response_message_ok(res_body_bytes)
} else {
AuthService::response_message_denied(res_body_bytes)
}
}

fn response_message_ok(res_body_bytes: &Bytes) -> GrpcMessageResult<GrpcMessageResponse> {
match Message::parse_from_bytes(res_body_bytes) {
Ok(res) => Ok(GrpcMessageResponse::AuthOk(res)),
Err(e) => Err(e),
}
}

fn response_message_denied(res_body_bytes: &Bytes) -> GrpcMessageResult<GrpcMessageResponse> {
match Message::parse_from_bytes(res_body_bytes) {
Ok(res) => Ok(GrpcMessageResponse::AuthDenied(res)),
Err(e) => Err(e),
}
}

fn build_check_req(ce_host: String) -> CheckRequest {
let mut auth_req = CheckRequest::default();
let mut attr = AttributeContext::default();
Expand Down
153 changes: 146 additions & 7 deletions src/service/grpc_message.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,16 @@
use crate::configuration::ExtensionType;
use crate::envoy::{CheckRequest, RateLimitDescriptor, RateLimitRequest};
use crate::envoy::{
CheckRequest, DeniedHttpResponse, OkHttpResponse, RateLimitDescriptor, RateLimitRequest,
RateLimitResponse,
};
use crate::service::auth::AuthService;
use crate::service::rate_limit::RateLimitService;
use protobuf::reflect::MessageDescriptor;
use protobuf::{
Clear, CodedInputStream, CodedOutputStream, Message, ProtobufResult, UnknownFields,
Clear, CodedInputStream, CodedOutputStream, Message, ProtobufError, ProtobufResult,
UnknownFields,
};
use proxy_wasm::types::Bytes;
use std::any::Any;

#[derive(Clone, Debug)]
Expand Down Expand Up @@ -126,11 +131,145 @@ impl GrpcMessageRequest {
descriptors: protobuf::RepeatedField<RateLimitDescriptor>,
) -> Self {
match extension_type {
ExtensionType::RateLimit => GrpcMessageRequest::RateLimit(RateLimitService::message(
domain.clone(),
descriptors,
)),
ExtensionType::Auth => GrpcMessageRequest::Auth(AuthService::message(domain.clone())),
ExtensionType::RateLimit => GrpcMessageRequest::RateLimit(
RateLimitService::request_message(domain.clone(), descriptors),
),
ExtensionType::Auth => {
GrpcMessageRequest::Auth(AuthService::request_message(domain.clone()))
}
}
}
}

#[derive(Clone, Debug)]
pub enum GrpcMessageResponse {
AuthOk(OkHttpResponse),
AuthDenied(DeniedHttpResponse),
RateLimit(RateLimitResponse),
}

impl Default for GrpcMessageResponse {
fn default() -> Self {
GrpcMessageResponse::RateLimit(RateLimitResponse::new())
}
}

impl Clear for GrpcMessageResponse {
fn clear(&mut self) {
todo!()
}
}

impl Message for GrpcMessageResponse {
fn descriptor(&self) -> &'static MessageDescriptor {
match self {
GrpcMessageResponse::AuthOk(res) => res.descriptor(),
GrpcMessageResponse::AuthDenied(res) => res.descriptor(),
GrpcMessageResponse::RateLimit(res) => res.descriptor(),
}
}

fn is_initialized(&self) -> bool {
match self {
GrpcMessageResponse::AuthOk(res) => res.is_initialized(),
GrpcMessageResponse::AuthDenied(res) => res.is_initialized(),
GrpcMessageResponse::RateLimit(res) => res.is_initialized(),
}
}

fn merge_from(&mut self, is: &mut CodedInputStream) -> ProtobufResult<()> {
match self {
GrpcMessageResponse::AuthOk(res) => res.merge_from(is),
GrpcMessageResponse::AuthDenied(res) => res.merge_from(is),
GrpcMessageResponse::RateLimit(res) => res.merge_from(is),
}
}

fn write_to_with_cached_sizes(&self, os: &mut CodedOutputStream) -> ProtobufResult<()> {
match self {
GrpcMessageResponse::AuthOk(res) => res.write_to_with_cached_sizes(os),
GrpcMessageResponse::AuthDenied(res) => res.write_to_with_cached_sizes(os),
GrpcMessageResponse::RateLimit(res) => res.write_to_with_cached_sizes(os),
}
}

fn write_to_bytes(&self) -> ProtobufResult<Vec<u8>> {
match self {
GrpcMessageResponse::AuthOk(res) => res.write_to_bytes(),
GrpcMessageResponse::AuthDenied(res) => res.write_to_bytes(),
GrpcMessageResponse::RateLimit(res) => res.write_to_bytes(),
}
}

fn compute_size(&self) -> u32 {
match self {
GrpcMessageResponse::AuthOk(res) => res.compute_size(),
GrpcMessageResponse::AuthDenied(res) => res.compute_size(),
GrpcMessageResponse::RateLimit(res) => res.compute_size(),
}
}

fn get_cached_size(&self) -> u32 {
match self {
GrpcMessageResponse::AuthOk(res) => res.get_cached_size(),
GrpcMessageResponse::AuthDenied(res) => res.get_cached_size(),
GrpcMessageResponse::RateLimit(res) => res.get_cached_size(),
}
}

fn get_unknown_fields(&self) -> &UnknownFields {
match self {
GrpcMessageResponse::AuthOk(res) => res.get_unknown_fields(),
GrpcMessageResponse::AuthDenied(res) => res.get_unknown_fields(),
GrpcMessageResponse::RateLimit(res) => res.get_unknown_fields(),
}
}

fn mut_unknown_fields(&mut self) -> &mut UnknownFields {
match self {
GrpcMessageResponse::AuthOk(res) => res.mut_unknown_fields(),
GrpcMessageResponse::AuthDenied(res) => res.mut_unknown_fields(),
GrpcMessageResponse::RateLimit(res) => res.mut_unknown_fields(),
}
}

fn as_any(&self) -> &dyn Any {
match self {
GrpcMessageResponse::AuthOk(res) => res.as_any(),
GrpcMessageResponse::AuthDenied(res) => res.as_any(),
GrpcMessageResponse::RateLimit(res) => res.as_any(),
}
}

fn new() -> Self
where
Self: Sized,
{
// Returning default value
GrpcMessageResponse::default()
}

fn default_instance() -> &'static Self
where
Self: Sized,
{
#[allow(non_upper_case_globals)]
static instance: ::protobuf::rt::LazyV2<GrpcMessageResponse> = ::protobuf::rt::LazyV2::INIT;
instance.get(|| GrpcMessageResponse::RateLimit(RateLimitResponse::new()))
}
}

impl GrpcMessageResponse {
pub fn new(
extension_type: ExtensionType,
res_body_bytes: &Bytes,
status_code: u32,
) -> GrpcMessageResult<GrpcMessageResponse> {
match extension_type {
ExtensionType::RateLimit => RateLimitService::response_message(res_body_bytes),
ExtensionType::Auth => AuthService::response_message(res_body_bytes, status_code),
}
}
}

pub type GrpcMessageResult<T> = Result<T, ProtobufError>;
15 changes: 12 additions & 3 deletions src/service/rate_limit.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
use crate::envoy::{RateLimitDescriptor, RateLimitRequest};
use protobuf::RepeatedField;
use crate::service::grpc_message::{GrpcMessageResponse, GrpcMessageResult};
use protobuf::{Message, RepeatedField};
use proxy_wasm::types::Bytes;

pub const RATELIMIT_SERVICE_NAME: &str = "envoy.service.ratelimit.v3.RateLimitService";
pub const RATELIMIT_METHOD_NAME: &str = "ShouldRateLimit";

pub struct RateLimitService;

impl RateLimitService {
pub fn message(
pub fn request_message(
domain: String,
descriptors: RepeatedField<RateLimitDescriptor>,
) -> RateLimitRequest {
Expand All @@ -19,6 +21,13 @@ impl RateLimitService {
cached_size: Default::default(),
}
}

pub fn response_message(res_body_bytes: &Bytes) -> GrpcMessageResult<GrpcMessageResponse> {
match Message::parse_from_bytes(res_body_bytes) {
Ok(res) => Ok(GrpcMessageResponse::RateLimit(res)),
Err(e) => Err(e),
}
}
}

#[cfg(test)]
Expand All @@ -37,7 +46,7 @@ mod tests {
field.set_entries(RepeatedField::from_vec(vec![entry]));
let descriptors = RepeatedField::from_vec(vec![field]);

RateLimitService::message(domain.to_string(), descriptors.clone())
RateLimitService::request_message(domain.to_string(), descriptors.clone())
}
#[test]
fn builds_correct_message() {
Expand Down

0 comments on commit c8d6c4d

Please sign in to comment.