2 changes: 1 addition & 1 deletion .github/workflows/tests.yaml
@@ -26,7 +26,7 @@ jobs:
cache: 'pip'
- name: Install lighteval in editable mode
run: |
pip install -e .[dev,extended_tasks]
pip install -e .[dev,nanotron,extended_tasks]
- name: Get cached files
uses: actions/cache@v2
id: get-cache
6 changes: 6 additions & 0 deletions run_evals_nanotron.py
100644 → 100755
@@ -45,6 +45,12 @@ def get_parser():
default=None,
help="Cache directory",
)
parser.add_argument(
"--max_samples",
type=int,
default=10,
help="number of samples used for evaluation",
)

return parser

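With this flag in place, a quick debug run can cap how many samples each task evaluates. A hypothetical invocation (only `--max_samples` comes from this diff; the spelling of the other flags and the paths are assumptions):

```
python run_evals_nanotron.py \
    --checkpoint_config_path checkpoints/10/config.yaml \
    --lighteval_override tests/config/lighteval_config_override_custom.yaml \
    --max_samples 10
```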
19 changes: 19 additions & 0 deletions src/lighteval/main_nanotron.py
@@ -23,9 +23,11 @@
# flake8: noqa: C901
import os
import random
from argparse import Namespace
from typing import Optional, Type

import numpy as np
import torch

from lighteval.evaluator import evaluate, make_results_table
from lighteval.logging.evaluation_tracker import EvaluationTracker
@@ -64,7 +66,15 @@ def main(
config_cls: Type = Config,
model_config_cls: Optional[Type] = None,
model_cls: Optional[Type] = None,
args: Optional[Namespace] = None, # accept args for more flexibility
):
if args is not None:
checkpoint_config_path = (
args.checkpoint_config_path if checkpoint_config_path is None else checkpoint_config_path
)
lighteval_config_path = args.lighteval_override if lighteval_config_path is None else lighteval_config_path
cache_dir = args.cache_dir if cache_dir is None else cache_dir

if cache_dir is None:
cache_dir = CACHE_DIR

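For callers, a minimal sketch of the new convention, assuming the positional arguments may be left as `None` when a parsed `Namespace` supplies them (the paths below are illustrative, not from this diff):

```
# Sketch: forward the CLI's parsed Namespace wholesale; explicitly passed
# arguments still take precedence over the Namespace fields.
from argparse import Namespace

from lighteval.main_nanotron import main

args = Namespace(
    checkpoint_config_path="checkpoints/10/config.yaml",  # hypothetical path
    lighteval_override="tests/config/lighteval_config_override_custom.yaml",
    cache_dir=None,
    max_samples=10,
)

main(checkpoint_config_path=None, args=args)  # None fields are filled from args
```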
@@ -82,6 +92,7 @@
model_config_class=model_config_cls,
skip_unused_config_keys=True,
skip_null_keys=True,
ignore_all_unused_keys=True,
)

if lighteval_config_path:
@@ -90,6 +101,9 @@
else:
lighteval_config = nanotron_config.lighteval

    if args is not None and args.max_samples is not None:
        lighteval_config.tasks.max_samples = args.max_samples

parallel_context = ParallelContext(
tensor_parallel_size=lighteval_config.parallelism.tp,
pipeline_parallel_size=lighteval_config.parallelism.pp,
@@ -157,8 +171,13 @@

with htrack_block("Setting seeds and waiting for all processes"):
hlog(f"setting seed to {SEED} for random and numpy")

torch.manual_seed(SEED)
if torch.cuda.is_available():
torch.cuda.manual_seed(SEED)
random.seed(SEED)
np.random.seed(SEED)

dist.barrier()

with htrack_block("Evaluation"):
1 change: 1 addition & 0 deletions src/lighteval/metrics/metrics.py
@@ -225,6 +225,7 @@ class Metrics(Enum):
corpus_level_fn=np.mean,
higher_is_better=True,
)
# Note: initializing this metric adds noticeable time to every test run, even when it is not needed
llm_judge_multi_turn = SampleLevelMetricGrouping(
metric=["single_turn", "multi_turn"],
higher_is_better=True,
Expand Down
3 changes: 2 additions & 1 deletion src/lighteval/models/nanotron_model.py
@@ -570,6 +570,7 @@ def prepare_batch(

# since in _collate we make sure length is descending, the longest is always the first one.
padding_length = padding_length if padding_length is not None else inplen

if padding_length - inplen < 0:
raise ValueError("Negative padding")
padded.append(padding_length - inplen)
@@ -655,7 +656,7 @@ def _get_subsets(self, dataset, dataset_splits):
def _loglikelihood_single_token(
self, requests, disable_tqdm: bool = False, override_bs: int = -1, dataset_splits: int = 1
) -> List[LoglikelihoodSingleTokenReturn]:
dataset = LoglikelihoodSingleTokenDataset(requests=requests)
dataset = LoglikelihoodSingleTokenDataset(requests=requests, dataset_splits=dataset_splits)
res = []

# Dataset is sorted in descending size.
55 changes: 55 additions & 0 deletions tests/config/README.md
@@ -0,0 +1,55 @@
# Nanotron tests guide
## How it works:
First, select some tasks, then use the model to generate reference scores and save them in the reference_task_scores_nanotron.py file. This has already been done, but if you want to add a new task, you need to re-run it.

After that, each time the test is run, the evaluation is executed and the results are compared to the stored reference scores.
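The comparison boils down to an approximate-equality check per metric. A minimal sketch of the idea (the dictionary name and layout in reference_task_scores_nanotron.py are assumptions, not the actual test code):

```
# Sketch: assert each computed metric matches its stored reference
# within a small relative tolerance.
import pytest

from tests.config.reference_task_scores_nanotron import RESULTS  # hypothetical name

def check_scores(actual: dict, rel_tol: float = 1e-4):
    for task, metrics in RESULTS.items():
        for metric, ref_value in metrics.items():
            assert actual[task][metric] == pytest.approx(ref_value, rel=rel_tol)
```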

## To run the nanotron tests:
```
pytest tests/test_main_nanotron.py -sv
```

## Choose your own tasks for evaluation:
Modify **tasks.tasks** in the config file (lighteval/tests/config/lighteval_config_override_custom.yaml) to choose the tasks. Each entry has the form `suite|task|num_fewshot|truncate_fewshot`, where the last flag (0 or 1) lets lighteval reduce the few-shot count when the prompt gets too long.
Example:
```
tasks:
custom_tasks: null
dataset_loading_processes: 1
max_samples: 10
multichoice_continuations_start_space: null
no_multichoice_continuations_start_space: null
num_fewshot_seeds: null
tasks: lighteval|anli:r1|0|0,lighteval|blimp:adjunct_island|0|0,...
```

## Avoiding randomized results
Make sure to set **for_inference** to true. This loads the model with a fixed output layer-norm implementation; it defaults to false for training.
```
model:
ddp_bucket_cap_mb: 25
dtype: float64
init_method:
std: 0.02
make_vocab_size_divisible_by: 1
model_config:
bos_token_id: 1
eos_token_id: 2
hidden_act: silu
hidden_size: 512
initializer_range: 0.02
intermediate_size: 2048
is_llama_config: true
max_position_embeddings: 2048
num_attention_heads: 16
num_hidden_layers: 16
num_key_value_heads: 16
pad_token_id: null
pretraining_tp: 1
rms_norm_eps: 1.0e-05
rope_scaling: null
tie_word_embeddings: true
use_cache: true
vocab_size: 50272
for_inference: true
```
30 changes: 30 additions & 0 deletions tests/config/lighteval_config_override_custom.yaml
@@ -0,0 +1,30 @@
batch_size: 16
checkpoints_path: null
generation: null
logging:
hub_repo_details: null
hub_repo_results: null
hub_repo_tensorboard: zzhhjjj/debug-nanotron
local_output_path: /scratch/haojun/lighteval/nanotron-119M-seed-6-3188821
push_details_to_hub: null
push_results_to_hub: null
push_results_to_tensorboard: true
tensorboard_metric_prefix: e
parallelism:
dp: 1
expert_parallel_size: 1
pp: 1
pp_engine: 1f1b
tp: 1
tp_linear_async_communication: false
tp_mode: ALL_REDUCE
slurm_script_dir: /fsx/haojun/logs_evals
slurm_template: /fsx/haojun/brrr/examples/get-started-kit/run_eval.slurm.jinja
tasks:
custom_tasks: null
dataset_loading_processes: 1
max_samples: 10
multichoice_continuations_start_space: null
no_multichoice_continuations_start_space: null
num_fewshot_seeds: null
tasks: lighteval|anli:r1|0|0,lighteval|blimp:adjunct_island|0|0,lighteval|blimp:ellipsis_n_bar_1|0|0,leaderboard|arc:challenge|25|0,leaderboard|hellaswag|10|0,leaderboard|mmlu:abstract_algebra|5|0,leaderboard|mmlu:college_chemistry|5|0,leaderboard|mmlu:computer_security|5|0,leaderboard|mmlu:us_foreign_policy|5|0,leaderboard|truthfulqa:mc|0|0,helm|mmlu:abstract_algebra|5|0,helm|mmlu:college_chemistry|5|0,helm|mmlu:computer_security|5|0,helm|mmlu:us_foreign_policy|5|0,helm|boolq|5|0,helm|hellaswag|5|0,leaderboard|gsm8k|5|0