From 1c81d77d8c0d925b70c6a82a001335873a4f43b6 Mon Sep 17 00:00:00 2001 From: Marcus Lee Date: Tue, 26 May 2020 04:23:30 +0800 Subject: [PATCH] Conversion of torch.Tensor to numpy object Due to the possibility of `x[1]` being of type `torch.Tensor`, the `np.mean(...)` call fails with an `AttributeError: torch.dtype object has no attribute type`. --- src/utils/logging.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/utils/logging.py b/src/utils/logging.py index ec95ed161..4977e07cb 100644 --- a/src/utils/logging.py +++ b/src/utils/logging.py @@ -1,6 +1,7 @@ from collections import defaultdict import logging import numpy as np +import torch class Logger: def __init__(self, console_logger): @@ -45,7 +46,7 @@ def print_recent_stats(self): continue i += 1 window = 5 if k != "epsilon" else 1 - item = "{:.4f}".format(np.mean([x[1] for x in self.stats[k][-window:]])) + item = "{:.4f}".format(np.mean([x[1] if not isinstance(x[1], torch.Tensor) else x[1].cpu().numpy() for x in self.stats[k][-window:]])) log_str += "{:<25}{:>8}".format(k + ":", item) log_str += "\n" if i % 4 == 0 else "\t" self.console_logger.info(log_str)