diff --git a/lang/rust/avro_derive/Cargo.toml b/lang/rust/avro_derive/Cargo.toml index e16e9ea957a..2ba8bc8c0ef 100644 --- a/lang/rust/avro_derive/Cargo.toml +++ b/lang/rust/avro_derive/Cargo.toml @@ -32,10 +32,11 @@ documentation = "https://docs.rs/apache-avro-derive" proc-macro = true [dependencies] -syn = {version= "1.0.91", features=["full", "fold"]} +darling = "0.14.0" quote = "1.0.18" +syn = {version= "1.0.91", features=["full", "fold"]} proc-macro2 = "1.0.37" -darling = "0.14.0" +serde_json = "1.0.79" [dev-dependencies] serde = { version = "1.0.136", features = ["derive"] } diff --git a/lang/rust/avro_derive/src/lib.rs b/lang/rust/avro_derive/src/lib.rs index 0055249cb3b..393db8c1991 100644 --- a/lang/rust/avro_derive/src/lib.rs +++ b/lang/rust/avro_derive/src/lib.rs @@ -30,6 +30,8 @@ use syn::{ struct FieldOptions { #[darling(default)] doc: Option, + #[darling(default)] + default: Option, } #[derive(FromAttributes)] @@ -117,16 +119,24 @@ fn get_data_struct_schema_def( syn::Fields::Named(ref a) => { for (position, field) in a.named.iter().enumerate() { let name = field.ident.as_ref().unwrap().to_string(); // we know everything has a name - let field_documented = + let field_attrs = FieldOptions::from_attributes(&field.attrs[..]).map_err(darling_to_syn)?; - let doc = preserve_optional(field_documented.doc); + let doc = preserve_optional(field_attrs.doc); + let default_value = match field_attrs.default { + Some(default_value) => { + quote! { + Some(serde_json::from_str(#default_value).expect(format!("Invalid JSON: {:?}", #default_value).as_str())) + } + } + None => quote! { None }, + }; let schema_expr = type_to_schema_expr(&field.ty)?; let position = position; record_field_exprs.push(quote! { apache_avro::schema::RecordField { name: #name.to_string(), doc: #doc, - default: Option::None, + default: #default_value, schema: #schema_expr, order: apache_avro::schema::RecordFieldOrder::Ascending, position: #position, @@ -186,7 +196,7 @@ fn get_data_enum_schema_def( name: apache_avro::schema::Name::new(#full_schema_name).expect(&format!("Unable to parse enum name for schema {}", #full_schema_name)[..]), aliases: #enum_aliases, doc: #doc, - symbols: vec![#(#symbols.to_owned()),*] + symbols: vec![#(#symbols.to_owned()),*], } }) } else { diff --git a/lang/rust/avro_derive/tests/derive.rs b/lang/rust/avro_derive/tests/derive.rs index 284323a2af4..0be2bd4bfc8 100644 --- a/lang/rust/avro_derive/tests/derive.rs +++ b/lang/rust/avro_derive/tests/derive.rs @@ -1209,4 +1209,134 @@ mod test_derive { serde_assert(TestBasicEnumWithAliases2::B); } + + #[test] + fn test_basic_struct_with_defaults() { + #[derive(Debug, Deserialize, Serialize, AvroSchema, Clone, PartialEq)] + enum MyEnum { + Foo, + Bar, + Baz, + } + + #[derive(Debug, Serialize, Deserialize, AvroSchema, Clone, PartialEq)] + struct TestBasicStructWithDefaultValues { + #[avro(default = "123")] + a: i32, + #[avro(default = r#""The default value for 'b'""#)] + b: String, + #[avro(default = "true")] + condition: bool, + // no default value for 'c' + c: f64, + #[avro(default = r#"{"a": 1, "b": 2}"#)] + map: HashMap, + + #[avro(default = "[1, 2, 3]")] + array: Vec, + + #[avro(default = r#""Foo""#)] + myenum: MyEnum, + } + + let schema = r#" + { + "type":"record", + "name":"TestBasicStructWithDefaultValues", + "fields": [ + { + "name":"a", + "type":"int", + "default":123 + }, + { + "name":"b", + "type":"string", + "default": "The default value for 'b'" + }, + { + "name":"condition", + "type":"boolean", + "default":true + }, + { + "name":"c", + "type":"double" + }, + { + "name":"map", + "type":{ + "type":"map", + "values":"int" + }, + "default": { + "a": 1, + "b": 2 + } + }, + { + "name":"array", + "type":{ + "type":"array", + "items":"int" + }, + "default": [1, 2, 3] + }, + { + "name":"myenum", + "type":{ + "type":"enum", + "name":"MyEnum", + "symbols":["Foo", "Bar", "Baz"] + }, + "default":"Foo" + } + ] + } + "#; + + let schema = Schema::parse_str(schema).unwrap(); + if let Schema::Record { name, fields, .. } = TestBasicStructWithDefaultValues::get_schema() + { + assert_eq!("TestBasicStructWithDefaultValues", name.fullname(None)); + use serde_json::json; + for field in fields { + match field.name.as_str() { + "a" => assert_eq!(Some(json!(123_i32)), field.default), + "b" => assert_eq!( + Some(json!(r#"The default value for 'b'"#.to_owned())), + field.default + ), + "condition" => assert_eq!(Some(json!(true)), field.default), + "array" => assert_eq!(Some(json!([1, 2, 3])), field.default), + "map" => assert_eq!( + Some(json!({ + "a": 1, + "b": 2 + })), + field.default + ), + "c" => assert_eq!(None, field.default), + "myenum" => assert_eq!(Some(json!("Foo")), field.default), + _ => panic!("Unexpected field name"), + } + } + } else { + panic!("TestBasicStructWithDefaultValues schema must be a record schema") + } + assert_eq!(schema, TestBasicStructWithDefaultValues::get_schema()); + + serde_assert(TestBasicStructWithDefaultValues { + a: 321, + b: "A custom value for 'b'".to_owned(), + condition: false, + c: 987.654, + map: [("a".to_owned(), 1), ("b".to_owned(), 2)] + .iter() + .cloned() + .collect(), + array: vec![4, 5, 6], + myenum: MyEnum::Bar, + }); + } }