Skip to content

Commit

Permalink
feat: support protobuf serialization and deserialization for CollabPa…
Browse files Browse the repository at this point in the history
…rams (#834)
  • Loading branch information
khorshuheng authored Oct 2, 2024
1 parent 96d7ae8 commit 3b320b0
Show file tree
Hide file tree
Showing 6 changed files with 219 additions and 4 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions libs/database-entity/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,4 @@ app-error = { workspace = true }
bincode = "1.3.3"
appflowy-ai-client = { workspace = true, features = ["dto"] }
bytes.workspace = true
prost = "0.12"
209 changes: 206 additions & 3 deletions libs/database-entity/src/dto.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
use crate::error::EntityError;
use crate::error::EntityError::{DeserializationError, InvalidData};
use crate::util::{validate_not_empty_payload, validate_not_empty_str};
use appflowy_ai_client::dto::AIModel;
use bytes::Bytes;
use chrono::{DateTime, Utc};
use collab_entity::proto;
use collab_entity::CollabType;
use prost::Message;
use serde::{Deserialize, Serialize};
use serde_repr::{Deserialize_repr, Serialize_repr};
use std::cmp::Ordering;
Expand Down Expand Up @@ -62,7 +66,7 @@ impl CreateCollabParams {

pub struct CollabIndexParams {}

#[derive(Debug, Clone, Validate, Serialize, Deserialize)]
#[derive(Debug, Clone, Validate, Serialize, Deserialize, PartialEq)]
pub struct CollabParams {
#[validate(custom = "validate_not_empty_str")]
pub object_id: String,
Expand Down Expand Up @@ -107,7 +111,50 @@ impl CollabParams {
},
}
}

pub fn to_proto(&self) -> proto::collab::CollabParams {
proto::collab::CollabParams {
object_id: self.object_id.clone(),
encoded_collab: self.encoded_collab_v1.to_vec(),
collab_type: self.collab_type.to_proto() as i32,
embeddings: self
.embeddings
.as_ref()
.map(|embeddings| embeddings.to_proto()),
}
}

pub fn to_protobuf_bytes(&self) -> Vec<u8> {
self.to_proto().encode_to_vec()
}

pub fn from_protobuf_bytes(bytes: &[u8]) -> Result<Self, EntityError> {
match proto::collab::CollabParams::decode(bytes) {
Ok(proto) => Self::try_from(proto),
Err(err) => Err(DeserializationError(err.to_string())),
}
}
}

impl TryFrom<proto::collab::CollabParams> for CollabParams {
type Error = EntityError;

fn try_from(proto: proto::collab::CollabParams) -> Result<Self, Self::Error> {
let collab_type_proto = proto::collab::CollabType::try_from(proto.collab_type).unwrap();
let collab_type = CollabType::from_proto(&collab_type_proto);
let embeddings = proto
.embeddings
.map(AFCollabEmbeddings::from_proto)
.transpose()?;
Ok(Self {
object_id: proto.object_id,
encoded_collab_v1: Bytes::from(proto.encoded_collab),
collab_type,
embeddings,
})
}
}

#[derive(Serialize, Deserialize)]
struct CollabParamsV0 {
object_id: String,
Expand Down Expand Up @@ -917,12 +964,72 @@ pub struct AFCollabEmbeddingParams {
pub embedding: Option<Vec<f32>>,
}

impl AFCollabEmbeddingParams {
pub fn from_proto(proto: &proto::collab::CollabEmbeddingsParams) -> Result<Self, EntityError> {
let collab_type_proto = proto::collab::CollabType::try_from(proto.collab_type).unwrap();
let collab_type = CollabType::from_proto(&collab_type_proto);
let content_type_proto =
proto::collab::EmbeddingContentType::try_from(proto.content_type).unwrap();
let content_type = EmbeddingContentType::from_proto(content_type_proto)?;
let embedding = if proto.embedding.is_empty() {
None
} else {
Some(proto.embedding.clone())
};
Ok(Self {
fragment_id: proto.fragment_id.clone(),
object_id: proto.object_id.clone(),
collab_type,
content_type,
content: proto.content.clone(),
embedding,
})
}

pub fn to_proto(&self) -> proto::collab::CollabEmbeddingsParams {
proto::collab::CollabEmbeddingsParams {
fragment_id: self.fragment_id.clone(),
object_id: self.object_id.clone(),
collab_type: self.collab_type.to_proto() as i32,
content_type: self.content_type.to_proto() as i32,
content: self.content.clone(),
embedding: self.embedding.clone().unwrap_or_default(),
}
}

pub fn to_protobuf_bytes(&self) -> Vec<u8> {
self.to_proto().encode_to_vec()
}
}

#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct AFCollabEmbeddings {
pub tokens_consumed: u32,
pub params: Vec<AFCollabEmbeddingParams>,
}

impl AFCollabEmbeddings {
pub fn from_proto(proto: proto::collab::CollabEmbeddings) -> Result<Self, EntityError> {
let mut params = vec![];
for param in proto.embeddings {
params.push(AFCollabEmbeddingParams::from_proto(&param)?);
}
Ok(Self {
tokens_consumed: proto.tokens_consumed,
params,
})
}

pub fn to_proto(&self) -> proto::collab::CollabEmbeddings {
let embeddings: Vec<proto::collab::CollabEmbeddingsParams> =
self.params.iter().map(|param| param.to_proto()).collect();
proto::collab::CollabEmbeddings {
tokens_consumed: self.tokens_consumed,
embeddings,
}
}
}

/// Type of content stored by the embedding.
/// Currently only plain text of the document is supported.
/// In the future, we might support other kinds like i.e. PDF, images or image-extracted text.
Expand All @@ -933,6 +1040,24 @@ pub enum EmbeddingContentType {
PlainText = 0,
}

impl EmbeddingContentType {
pub fn from_proto(proto: proto::collab::EmbeddingContentType) -> Result<Self, EntityError> {
match proto {
proto::collab::EmbeddingContentType::PlainText => Ok(EmbeddingContentType::PlainText),
proto::collab::EmbeddingContentType::Unknown => Err(InvalidData(format!(
"{} is not a supported embedding type",
proto.as_str_name()
))),
}
}

pub fn to_proto(&self) -> proto::collab::EmbeddingContentType {
match self {
EmbeddingContentType::PlainText => proto::collab::EmbeddingContentType::PlainText,
}
}
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UpdateChatMessageResponse {
pub answer: Option<ChatMessage>,
Expand Down Expand Up @@ -1286,8 +1411,13 @@ pub struct ApproveAccessRequestParams {

#[cfg(test)]
mod test {
use crate::dto::{CollabParams, CollabParamsV0};
use collab_entity::CollabType;
use crate::dto::{
AFCollabEmbeddingParams, AFCollabEmbeddings, CollabParams, CollabParamsV0, EmbeddingContentType,
};
use crate::error::EntityError;
use bytes::Bytes;
use collab_entity::{proto, CollabType};
use prost::Message;
use uuid::Uuid;

#[test]
Expand Down Expand Up @@ -1357,4 +1487,77 @@ mod test {
assert_eq!(collab_params.collab_type, v0.collab_type);
assert_eq!(collab_params.encoded_collab_v1, v0.encoded_collab_v1);
}

#[test]
fn deserialization_using_protobuf() {
let collab_params_with_embeddings = CollabParams {
object_id: "object_id".to_string(),
collab_type: CollabType::Document,
encoded_collab_v1: Bytes::default(),
embeddings: Some(AFCollabEmbeddings {
tokens_consumed: 100,
params: vec![AFCollabEmbeddingParams {
fragment_id: "fragment_id".to_string(),
object_id: "object_id".to_string(),
collab_type: CollabType::Document,
content_type: EmbeddingContentType::PlainText,
content: "content".to_string(),
embedding: Some(vec![1.0, 2.0, 3.0]),
}],
}),
};

let protobuf_encoded = collab_params_with_embeddings.to_protobuf_bytes();
let collab_params_decoded = CollabParams::from_protobuf_bytes(&protobuf_encoded).unwrap();
assert_eq!(collab_params_with_embeddings, collab_params_decoded);
}

#[test]
fn deserialize_collab_params_without_embeddings() {
let collab_params = CollabParams {
object_id: "object_id".to_string(),
collab_type: CollabType::Document,
encoded_collab_v1: Bytes::from(vec![1, 2, 3]),
embeddings: Some(AFCollabEmbeddings {
tokens_consumed: 100,
params: vec![AFCollabEmbeddingParams {
fragment_id: "fragment_id".to_string(),
object_id: "object_id".to_string(),
collab_type: CollabType::Document,
content_type: EmbeddingContentType::PlainText,
content: "content".to_string(),
embedding: None,
}],
}),
};

let protobuf_encoded = collab_params.to_protobuf_bytes();
let collab_params_decoded = CollabParams::from_protobuf_bytes(&protobuf_encoded).unwrap();
assert_eq!(collab_params, collab_params_decoded);
}

#[test]
fn deserialize_collab_params_with_unknown_embedding_type() {
let invalid_serialization = proto::collab::CollabParams {
object_id: "object_id".to_string(),
encoded_collab: vec![1, 2, 3],
collab_type: proto::collab::CollabType::Document as i32,
embeddings: Some(proto::collab::CollabEmbeddings {
tokens_consumed: 100,
embeddings: vec![proto::collab::CollabEmbeddingsParams {
fragment_id: "fragment_id".to_string(),
object_id: "object_id".to_string(),
collab_type: proto::collab::CollabType::Document as i32,
content_type: proto::collab::EmbeddingContentType::Unknown as i32,
content: "content".to_string(),
embedding: vec![1.0, 2.0, 3.0],
}],
}),
}
.encode_to_vec();

let result = CollabParams::from_protobuf_bytes(&invalid_serialization);
assert!(result.is_err());
assert!(matches!(result, Err(EntityError::InvalidData(_))));
}
}
9 changes: 9 additions & 0 deletions libs/database-entity/src/error.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
#[derive(Debug, thiserror::Error)]
pub enum EntityError {
#[error("Invalid data: {0}")]
InvalidData(String),
#[error("Deserialization error: {0}")]
DeserializationError(String),
#[error("Serialization error: {0}")]
SerializationError(String),
}
1 change: 1 addition & 0 deletions libs/database-entity/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
pub mod dto;
pub mod error;
pub mod file_dto;
mod util;
2 changes: 1 addition & 1 deletion script/client_api_deps_check.sh
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# Generate the current dependency list
cargo tree > current_deps.txt

BASELINE_COUNT=620
BASELINE_COUNT=621
CURRENT_COUNT=$(cat current_deps.txt | wc -l)

echo "Expected dependency count (baseline): $BASELINE_COUNT"
Expand Down

0 comments on commit 3b320b0

Please sign in to comment.