# Review notes (free text ahead of the first "diff --git" header).
#
# Layout: line breaks were rebuilt from the "@@" hunk counts. Leading
# indentation of indented context lines is ASSUMED to be 4 spaces -- confirm
# against the target files before applying.
#
# What the patch does:
#   * arealite/README.md, requirements.txt: add the missing end-of-file newline.
#   * arealite/api/cli_args.py: imports uvloop and calls uvloop.install() at
#     module import time, ahead of the hydra/omegaconf imports.
#     NOTE(review): no hunk shown here adds uvloop to pyproject.toml or
#     requirements.txt -- confirm it is already a declared dependency and that
#     an import-time side effect is intended.
#   * pyproject.toml, requirements.txt: replace the unpinned "gymnasium" entry
#     with "gymnasium>=1.1.1" (and move it earlier in the list).
#   * realhf/api/core/data_api.py: wraps load_hf_tokenizer in
#     functools.lru_cache(maxsize=8). Calls with equal arguments then return the
#     same cached object, and every argument must be hashable.
#     NOTE(review): confirm callers do not mutate the returned tokenizer.
#   * The "sue" typo in the data_api.py context comment is kept verbatim,
#     because context lines must match the target file.
diff --git a/arealite/README.md b/arealite/README.md
index 2c5288240..214549ccb 100644
--- a/arealite/README.md
+++ b/arealite/README.md
@@ -737,4 +737,4 @@ dataloader = StatefulDataLoader(
 )
 for data in dataloader:
     assert isinstance(data, list)
-```
\ No newline at end of file
+```
diff --git a/arealite/api/cli_args.py b/arealite/api/cli_args.py
index 38f676866..ca4b78bd8 100644
--- a/arealite/api/cli_args.py
+++ b/arealite/api/cli_args.py
@@ -4,6 +4,9 @@
 from pathlib import Path
 from typing import Dict, List, Optional, Tuple
 
+import uvloop
+
+uvloop.install()
 from hydra import compose as hydra_compose
 from hydra import initialize as hydra_init
 from omegaconf import MISSING, OmegaConf
diff --git a/pyproject.toml b/pyproject.toml
index 962c9b9bf..4f752ebc5 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -53,9 +53,9 @@ dependencies = [
     "hydra-core==1.4.0.dev1",
     "packaging",
     "tabulate",
+    "gymnasium>=1.1.1",
     "torchdata",
     "autoflake",
-    "gymnasium",
     "tensordict",
 
     # Monitoring and logging
diff --git a/realhf/api/core/data_api.py b/realhf/api/core/data_api.py
index ce6d9bf95..f698a76d6 100644
--- a/realhf/api/core/data_api.py
+++ b/realhf/api/core/data_api.py
@@ -8,6 +8,7 @@
 import random
 import time
 from contextlib import contextmanager
+from functools import lru_cache
 
 # NOTE: We don't sue wildcard importing here because the type
 # `Sequence` has a very similar name to `SequenceSample`.
@@ -47,6 +48,7 @@
 RL_TASKS = ["math", "code", "rlhf", "stem"]
 
 
+@lru_cache(maxsize=8)
 def load_hf_tokenizer(
     model_name_or_path: str,
     fast_tokenizer=True,
diff --git a/requirements.txt b/requirements.txt
index 0af83fcaa..5318511cd 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -69,8 +69,8 @@ word2number
 Pebble
 timeout-decorator
 prettytable
+gymnasium>=1.1.1
 swanlab[dashboard]
 torchdata
 autoflake
-gymnasium
-tensordict
\ No newline at end of file
+tensordict