diff --git a/example/ssd/train/metric.py b/example/ssd/train/metric.py index 731f8fcc19f4..1b3c0c1dc018 100644 --- a/example/ssd/train/metric.py +++ b/example/ssd/train/metric.py @@ -32,6 +32,21 @@ def reset(self): """ override reset behavior """ + if getattr(self, 'num', None) is None: + self.num_inst = 0 + self.sum_metric = 0.0 + self.global_num_inst = 0 + self.global_sum_metric = 0.0 + else: + self.num_inst = [0] * self.num + self.sum_metric = [0.0] * self.num + self.global_num_inst = [0] * self.num + self.global_sum_metric = [0.0] * self.num + + def reset_local(self): + """ + override reset_local behavior + """ if getattr(self, 'num', None) is None: self.num_inst = 0 self.sum_metric = 0.0