diff --git a/src/betterproto/__init__.py b/src/betterproto/__init__.py index 5e3e43462..3cee8887e 100644 --- a/src/betterproto/__init__.py +++ b/src/betterproto/__init__.py @@ -145,6 +145,8 @@ class FieldMetadata: group: Optional[str] = None # Describes the wrapped type (e.g. when using google.protobuf.BoolValue) wraps: Optional[str] = None + # Whether the field is optional (its default is None, not PLACEHOLDER) + optional: Optional[bool] = False @staticmethod def get(field: dataclasses.Field) -> "FieldMetadata": @@ -165,7 +167,9 @@ def dataclass_field( return dataclasses.field( default=None if optional else PLACEHOLDER, metadata={ - "betterproto": FieldMetadata(number, proto_type, map_types, group, wraps) + "betterproto": FieldMetadata( + number, proto_type, map_types, group, wraps, optional + ) }, ) @@ -620,7 +624,8 @@ def __post_init__(self) -> None: if meta.group: group_current.setdefault(meta.group) - if self.__raw_get(field_name) != PLACEHOLDER: + value = self.__raw_get(field_name) + if value != PLACEHOLDER and not (meta.optional and value is None): # Found a non-sentinel value all_sentinel = False @@ -1043,7 +1048,6 @@ def to_dict( defaults = self._betterproto.default_gen for field_name, meta in self._betterproto.meta_by_field_name.items(): field_is_repeated = defaults[field_name] is list - field_is_optional = defaults[field_name] is type(None) value = getattr(self, field_name) cased_name = casing(field_name).rstrip("_") # type: ignore if meta.proto_type == TYPE_MESSAGE: @@ -1082,7 +1086,8 @@ def to_dict( if value or include_default_values: output[cased_name] = value elif value is None: - output[cased_name] = None + if include_default_values: + output[cased_name] = value elif ( value._serialized_on_wire or include_default_values @@ -1109,7 +1114,8 @@ def to_dict( if field_is_repeated: output[cased_name] = [str(n) for n in value] elif value is None: - output[cased_name] = value + if include_default_values: + output[cased_name] = value else: output[cased_name] = str(value) elif meta.proto_type == 
TYPE_BYTES: @@ -1117,8 +1123,8 @@ def to_dict( output[cased_name] = [ b64encode(b).decode("utf8") for b in value ] - elif value is None: - output[cased_name] = None + elif value is None and include_default_values: + output[cased_name] = value else: output[cased_name] = b64encode(value).decode("utf8") elif meta.proto_type == TYPE_ENUM: @@ -1132,8 +1138,9 @@ def to_dict( # transparently upgrade single value to repeated output[cased_name] = [enum_class(value).name] elif value is None: - output[cased_name] = None - elif field_is_optional: + if include_default_values: + output[cased_name] = value + elif meta.optional: enum_class = field_types[field_name].__args__[0] output[cased_name] = enum_class(value).name else: @@ -1173,9 +1180,6 @@ def from_dict(self: T, value: Dict[str, Any]) -> T: if value[key] is not None: if meta.proto_type == TYPE_MESSAGE: v = getattr(self, field_name) - if value[key] is None and self._get_field_default(key) == None: - # Setting an optional value to None. - setattr(self, field_name, None) if isinstance(v, list): cls = self._betterproto.cls_by_field[field_name] if cls == datetime: diff --git a/tests/inputs/proto3_field_presence/proto3_field_presence_default.json b/tests/inputs/proto3_field_presence/proto3_field_presence_default.json index 145d38412..0967ef424 100644 --- a/tests/inputs/proto3_field_presence/proto3_field_presence_default.json +++ b/tests/inputs/proto3_field_presence/proto3_field_presence_default.json @@ -1,10 +1 @@ -{ - "test1": null, - "test2": null, - "test3": null, - "test4": null, - "test5": null, - "test6": null, - "test7": null, - "test8": null -} +{} diff --git a/tests/inputs/proto3_field_presence/proto3_field_presence_missing.json b/tests/inputs/proto3_field_presence/proto3_field_presence_missing.json index cc0ed2d9a..b19ae9804 100644 --- a/tests/inputs/proto3_field_presence/proto3_field_presence_missing.json +++ b/tests/inputs/proto3_field_presence/proto3_field_presence_missing.json @@ -3,7 +3,6 @@ "test2": false, 
"test3": "", "test4": "", - "test5": null, "test6": "A", "test7": "0", "test8": 0 diff --git a/tests/inputs/proto3_field_presence/test_proto3_field_presence.py b/tests/inputs/proto3_field_presence/test_proto3_field_presence.py new file mode 100644 index 000000000..364e0cce0 --- /dev/null +++ b/tests/inputs/proto3_field_presence/test_proto3_field_presence.py @@ -0,0 +1,38 @@ +import json + +from tests.output_betterproto.proto3_field_presence import Test, InnerTest, TestEnum + + +def test_null_fields_json(): + """Ensure that using "null" in JSON is equivalent to not specifying a + field, for fields with explicit presence""" + + def test_json(ref_json: str, obj_json: str) -> None: + """`ref_json` and `obj_json` are JSON strings describing a `Test` object. + Test that deserializing both leads to the same object, and that + `ref_json` is the normalized format.""" + ref_obj = Test().from_json(ref_json) + obj = Test().from_json(obj_json) + + assert obj == ref_obj + assert json.loads(obj.to_json(0)) == json.loads(ref_json) + + test_json("{}", '{ "test1": null, "test2": null, "test3": null }') + test_json("{}", '{ "test4": null, "test5": null, "test6": null }') + test_json("{}", '{ "test7": null, "test8": null }') + test_json('{ "test5": {} }', '{ "test3": null, "test5": {} }') + + # Make sure that if include_default_values is set, None values are + # exported. 
+ obj = Test() + assert obj.to_dict() == {} + assert obj.to_dict(include_default_values=True) == { + "test1": None, + "test2": None, + "test3": None, + "test4": None, + "test5": None, + "test6": None, + "test7": None, + "test8": None, + } diff --git a/tests/inputs/proto3_field_presence_oneof/proto3_field_presence_oneof.json b/tests/inputs/proto3_field_presence_oneof/proto3_field_presence_oneof.json new file mode 100644 index 000000000..da0819278 --- /dev/null +++ b/tests/inputs/proto3_field_presence_oneof/proto3_field_presence_oneof.json @@ -0,0 +1,3 @@ +{ + "nested": {} +} diff --git a/tests/inputs/proto3_field_presence_oneof/proto3_field_presence_oneof.proto b/tests/inputs/proto3_field_presence_oneof/proto3_field_presence_oneof.proto new file mode 100644 index 000000000..c4dc9d4f8 --- /dev/null +++ b/tests/inputs/proto3_field_presence_oneof/proto3_field_presence_oneof.proto @@ -0,0 +1,20 @@ +syntax = "proto3"; + +message Test { + oneof kind { + Nested nested = 1; + WithOptional with_optional = 2; + } +} + +message InnerNested { + optional bool a = 1; +} + +message Nested { + InnerNested inner = 1; +} + +message WithOptional { + optional bool b = 2; +} diff --git a/tests/inputs/proto3_field_presence_oneof/test_proto3_field_presence_oneof.py b/tests/inputs/proto3_field_presence_oneof/test_proto3_field_presence_oneof.py new file mode 100644 index 000000000..0092a23b5 --- /dev/null +++ b/tests/inputs/proto3_field_presence_oneof/test_proto3_field_presence_oneof.py @@ -0,0 +1,29 @@ +from tests.output_betterproto.proto3_field_presence_oneof import ( + Test, + InnerNested, + Nested, + WithOptional, +) + + +def test_serialization(): + """Ensure that serialization of fields unset but with explicit field + presence does not bloat the serialized payload with length-delimited fields + with length 0""" + + def test_empty_nested(message: Test) -> None: + # '0a' => tag 1, length delimited + # '00' => length: 0 + assert bytes(message) == bytearray.fromhex("0a 00") + 
test_empty_nested(Test(nested=Nested())) + test_empty_nested(Test(nested=Nested(inner=None))) + test_empty_nested(Test(nested=Nested(inner=InnerNested(a=None)))) + + def test_empty_with_optional(message: Test) -> None: + # '12' => tag 2, length delimited + # '00' => length: 0 + assert bytes(message) == bytearray.fromhex("12 00") + + test_empty_with_optional(Test(with_optional=WithOptional())) + test_empty_with_optional(Test(with_optional=WithOptional(b=None)))