Skip to content

Commit 3f1df40

Browse files
authored
add results dir (#1044)
1 parent c2bf263 commit 3f1df40

File tree

2 files changed

+8
-5
lines changed

2 files changed

+8
-5
lines changed

src/fairchem/core/_cli_hydra.py

+3
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@
5252

5353
LOG_DIR_NAME = "logs"
5454
CHECKPOINT_DIR_NAME = "checkpoints"
55+
RESULTS_DIR = "results"
5556
CONFIG_FILE_NAME = "canonical_config.yaml"
5657
PREEMPTION_STATE_DIR_NAME = "preemption_state"
5758

@@ -90,6 +91,7 @@ class Metadata:
9091
commit: str
9192
log_dir: str
9293
checkpoint_dir: str
94+
results_dir: str
9395
config_path: str
9496
preemption_checkpoint_dir: str
9597
cluster_name: str
@@ -120,6 +122,7 @@ def __post_init__(self) -> None:
120122
checkpoint_dir=os.path.join(
121123
self.run_dir, self.timestamp_id, CHECKPOINT_DIR_NAME
122124
),
125+
results_dir=os.path.join(self.run_dir, self.timestamp_id, RESULTS_DIR),
123126
config_path=os.path.join(self.run_dir, self.timestamp_id, CONFIG_FILE_NAME),
124127
preemption_checkpoint_dir=os.path.join(
125128
self.run_dir,

src/fairchem/core/common/utils.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -1458,11 +1458,11 @@ def get_timestamp_uid() -> str:
14581458
@torch.no_grad()
14591459
def tensor_stats(name: str, x: torch.Tensor) -> dict:
14601460
return {
1461-
f"{name}.max": x.max(),
1462-
f"{name}.min": x.min(),
1463-
f"{name}.std": x.std(),
1464-
f"{name}.mean": x.mean(),
1465-
f"{name}.norm": torch.norm(x, p=2),
1461+
f"{name}.max": x.max().item(),
1462+
f"{name}.min": x.min().item(),
1463+
f"{name}.std": x.std().item(),
1464+
f"{name}.mean": x.mean().item(),
1465+
f"{name}.norm": torch.norm(x, p=2).item(),
14661466
f"{name}.nonzero_fraction": torch.nonzero(x).shape[0] / float(x.numel()),
14671467
}
14681468

0 commit comments

Comments
 (0)