Skip to content

Commit 3855954

Browse files
update (#58)
Co-authored-by: Clémentine Fourrier <[email protected]>
1 parent 88f7a9a commit 3855954

File tree

2 files changed

+11
-4
lines changed

2 files changed

+11
-4
lines changed

src/lighteval/logging/evaluation_tracker.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import copy
12
import json
23
import os
34
import re
@@ -22,7 +23,7 @@
2223

2324

2425
if is_nanotron_available():
25-
from nanotron.config import Config, get_config_from_dict
26+
from nanotron.config import Config
2627

2728

2829
class EnhancedJSONEncoder(json.JSONEncoder):
@@ -116,8 +117,14 @@ def save(
116117

117118
hlog(f"Saving results to {output_results_file} and {output_results_in_details_file}")
118119

120+
config_general = copy.deepcopy(self.general_config_logger)
121+
config_general.config = (
122+
config_general.config.as_dict() if is_dataclass(config_general.config) else config_general.config
123+
)
124+
config_general = asdict(config_general)
125+
119126
to_dump = {
120-
"config_general": asdict(self.general_config_logger),
127+
"config_general": config_general,
121128
"results": self.metrics_logger.metric_aggregated,
122129
"versions": self.versions_logger.versions,
123130
"config_tasks": self.task_config_logger.tasks_configs,
@@ -485,7 +492,7 @@ def push_results_to_tensorboard( # noqa: C901
485492
if not is_nanotron_available():
486493
hlog_warn("You cannot push results to tensorboard with having nanotron installed. Skipping")
487494
return
488-
config: Config = get_config_from_dict(self.general_config_logger.config, config_class=Config)
495+
config: Config = self.general_config_logger.config
489496
lighteval_config = config.lighteval
490497
try:
491498
global_step = config.general.step

src/lighteval/main_nanotron.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ def main(
8080
override_batch_size=None,
8181
max_samples=lighteval_config.tasks.max_samples,
8282
job_id=os.environ.get("SLURM_JOB_ID", None),
83-
config=nanotron_config.as_dict(),
83+
config=nanotron_config,
8484
)
8585

8686
with htrack_block("Test all gather"):

0 commit comments

Comments
 (0)