Skip to content

Commit b779fde

Browse files
committed
batch norm fix for the 3D core
1 parent 504a5a6 commit b779fde

File tree

2 files changed

+10
-5
lines changed

2 files changed

+10
-5
lines changed

neuralpredictors/layers/cores/base.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import warnings
12
from abc import ABC, abstractmethod
23
from collections import OrderedDict
34

@@ -77,7 +78,8 @@ def add_bn_layer(self, layer: OrderedDict, layer_idx: int):
7778
raise NotImplementedError(f"Subclasses must have a `{attr}` attribute.")
7879
for attr in ["batch_norm", "hidden_channels", "bias", "batch_norm_scale"]:
7980
if not isinstance(getattr(self, attr), list):
80-
raise ValueError(f"`{attr}` must be a list.")
81+
setattr(self, attr, [getattr(self, attr)] * self.layers)
82+
warnings.warn(f"The {attr} is applied to all layers", UserWarning)
8183

8284
if self.batch_norm[layer_idx]:
8385
hidden_channels = self.hidden_channels[layer_idx]

neuralpredictors/layers/cores/conv3d.py

+7-4
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,7 @@ def __init__(
160160
padding=(0, input_kernel[1] // 2, input_kernel[2] // 2) if self.padding else 0,
161161
)
162162

163-
self.add_bn_layer(layer=layer, hidden_channels=self.hidden_channels[0])
163+
self.add_bn_layer(layer=layer, layer_idx=0)
164164

165165
if layers > 1 or self.final_nonlinearity:
166166
if hidden_nonlinearities == "adaptive_elu":
@@ -185,7 +185,7 @@ def __init__(
185185
padding=(0, self.hidden_kernel[l][1] // 2, self.hidden_kernel[l][2] // 2) if self.padding else 0,
186186
)
187187

188-
self.add_bn_layer(layer=layer, hidden_channels=self.hidden_channels[l + 1])
188+
self.add_bn_layer(layer=layer, layer_idx=l + 1)
189189

190190
if self.final_nonlinearity or l < self.layers:
191191
if hidden_nonlinearities == "adaptive_elu":
@@ -363,7 +363,10 @@ def __init__(
363363
dilation=(self.temporal_dilation, 1, 1),
364364
)
365365

366-
self.add_bn_layer(layer=layer, hidden_channels=self.hidden_channels[0])
366+
self.add_bn_layer(
367+
layer=layer,
368+
layer_idx=0,
369+
)
367370

368371
if layers > 1 or final_nonlin:
369372
if hidden_nonlinearities == "adaptive_elu":
@@ -394,7 +397,7 @@ def __init__(
394397
dilation=(self.hidden_temporal_dilation[l], 1, 1),
395398
)
396399

397-
self.add_bn_layer(layer=layer, hidden_channels=self.hidden_channels[l + 1])
400+
self.add_bn_layer(layer=layer, layer_idx=l + 1)
398401

399402
if final_nonlin or l < self.layers:
400403
if hidden_nonlinearities == "adaptive_elu":

0 commit comments

Comments (0)