Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 30 additions & 14 deletions src/betterproto/plugin/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,17 +324,7 @@ def __post_init__(self) -> None:
# Add field to message
self.parent.fields.append(self)
# Check for new imports
annotation = self.annotation
if "Optional[" in annotation:
self.output_file.typing_imports.add("Optional")
if "List[" in annotation:
self.output_file.typing_imports.add("List")
if "Dict[" in annotation:
self.output_file.typing_imports.add("Dict")
if "timedelta" in annotation:
self.output_file.datetime_imports.add("timedelta")
if "datetime" in annotation:
self.output_file.datetime_imports.add("datetime")
self.add_imports_to(self.output_file)
super().__post_init__() # call FieldCompiler-> MessageCompiler __post_init__

def get_field_string(self, indent: int = 4) -> str:
Expand All @@ -356,6 +346,33 @@ def betterproto_field_args(self) -> List[str]:
args.append(f"wraps={self.field_wraps}")
return args

@property
def datetime_imports(self) -> Set[str]:
imports = set()
annotation = self.annotation
# FIXME: false positives - e.g. `MyDatetimedelta`
if "timedelta" in annotation:
imports.add("timedelta")
if "datetime" in annotation:
imports.add("datetime")
return imports

@property
def typing_imports(self) -> Set[str]:
imports = set()
annotation = self.annotation
if "Optional[" in annotation:
imports.add("Optional")
if "List[" in annotation:
imports.add("List")
if "Dict[" in annotation:
imports.add("Dict")
return imports

def add_imports_to(self, output_file: OutputTemplate) -> None:
output_file.datetime_imports.update(self.datetime_imports)
output_file.typing_imports.update(self.typing_imports)

@property
def field_wraps(self) -> Optional[str]:
"""Returns betterproto wrapped field type or None."""
Expand Down Expand Up @@ -577,11 +594,10 @@ def __post_init__(self) -> None:
# Add method to service
self.parent.methods.append(self)

# Check for Optional import
# Check for imports
if self.py_input_message:
for f in self.py_input_message.fields:
if f.default_value_string == "None":
self.output_file.typing_imports.add("Optional")
f.add_imports_to(self.output_file)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do I understand correctly that this line is the actual fix?

Copy link
Contributor Author

@leenr leenr Jan 25, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, along with the other changes.

I've deleted adding "Optional" import two lines above because there already is the code to add it in the FieldCompiler class:

if "Optional[" in annotation:
    imports.add("Optional")

if "Optional" in self.py_output_message_type:
self.output_file.typing_imports.add("Optional")
self.mutable_default_args # ensure this is called before rendering
Expand Down
1 change: 1 addition & 0 deletions tests/inputs/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
"googletypes_response",
"googletypes_response_embedded",
"service",
"service_separate_packages",
"import_service_input_message",
"googletypes_service_returns_empty",
"googletypes_service_returns_googletype",
Expand Down
31 changes: 31 additions & 0 deletions tests/inputs/service_separate_packages/messages.proto
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
syntax = "proto3";

import "google/protobuf/duration.proto";
import "google/protobuf/timestamp.proto";

package things.messages;

message DoThingRequest {
string name = 1;

// use `repeated` so we can check if `List` is correctly imported
repeated string comments = 2;

// use google types `timestamp` and `duration` so we can check
// if everything from `datetime` is correctly imported
google.protobuf.Timestamp when = 3;
google.protobuf.Duration duration = 4;
}

message DoThingResponse {
repeated string names = 1;
}

message GetThingRequest {
string name = 1;
}

message GetThingResponse {
string name = 1;
int32 version = 2;
}
12 changes: 12 additions & 0 deletions tests/inputs/service_separate_packages/service.proto
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
syntax = "proto3";

import "messages.proto";

package things.service;

service Test {
rpc DoThing (things.messages.DoThingRequest) returns (things.messages.DoThingResponse);
rpc DoManyThings (stream things.messages.DoThingRequest) returns (things.messages.DoThingResponse);
rpc GetThingVersions (things.messages.GetThingRequest) returns (stream things.messages.GetThingResponse);
rpc GetDifferentThings (stream things.messages.GetThingRequest) returns (stream things.messages.GetThingResponse);
}