From b6580d7b757901662591b61f0197ba237dec7c20 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Sat, 30 Jan 2021 09:09:39 +0100 Subject: [PATCH] Apply suggestions from code review Co-authored-by: Akihiro Nitta --- pl_examples/domain_templates/reinforce_learn_ppo.py | 2 +- pl_examples/domain_templates/unet.py | 8 ++++++-- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/pl_examples/domain_templates/reinforce_learn_ppo.py b/pl_examples/domain_templates/reinforce_learn_ppo.py index a7cbc654830ad..026784f900622 100644 --- a/pl_examples/domain_templates/reinforce_learn_ppo.py +++ b/pl_examples/domain_templates/reinforce_learn_ppo.py @@ -289,7 +289,7 @@ def calc_advantage(self, rewards: List[float], values: List[float], last_value: return adv - def generate_trajectory_samples(self, ) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]]: + def generate_trajectory_samples(self) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]]: """ Contains the logic for generating trajectory data to train policy and value network Yield: diff --git a/pl_examples/domain_templates/unet.py b/pl_examples/domain_templates/unet.py index 3b4ea5cc72ad2..f083ae434bd33 100644 --- a/pl_examples/domain_templates/unet.py +++ b/pl_examples/domain_templates/unet.py @@ -92,8 +92,12 @@ class DoubleConv(nn.Module): def __init__(self, in_ch: int, out_ch: int): super().__init__() self.net = nn.Sequential( - nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1), nn.BatchNorm2d(out_ch), nn.ReLU(inplace=True), - nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1), nn.BatchNorm2d(out_ch), nn.ReLU(inplace=True) + nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1), + nn.BatchNorm2d(out_ch), + nn.ReLU(inplace=True), + nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1), + nn.BatchNorm2d(out_ch), + nn.ReLU(inplace=True), ) def forward(self, x):