diff --git a/tests/buffer/queue_test.py b/tests/buffer/queue_test.py
index ef266bbff1..9c8e0ca1f0 100644
--- a/tests/buffer/queue_test.py
+++ b/tests/buffer/queue_test.py
@@ -326,7 +326,7 @@ async def test_priority_queue_reuse_count_control(self):
             path=BUFFER_FILE_PATH,
             replay_buffer=ReplayBufferConfig(
                 enable=True,
-                priority_fn="linear_decay_use_count_control_randomization",
+                priority_fn="decay_limit_randomization",
                 reuse_cooldown_time=0.5,
                 priority_fn_args={"decay": 1.2, "use_count_limit": 2, "sigma": 0.0},
             ),
diff --git a/trinity/buffer/storage/queue.py b/trinity/buffer/storage/queue.py
index 9cd20f73d5..2d924d362e 100644
--- a/trinity/buffer/storage/queue.py
+++ b/trinity/buffer/storage/queue.py
@@ -5,8 +5,7 @@
 from abc import ABC, abstractmethod
 from collections import deque
 from copy import deepcopy
-from functools import partial
-from typing import List, Optional, Tuple
+from typing import Dict, List, Optional, Tuple

 import numpy as np
 import ray
@@ -28,48 +27,82 @@ def is_json_file(path: str) -> bool:


 PRIORITY_FUNC = Registry("priority_fn")
-"""
-Each priority_fn,
-    Args:
-        item: List[Experience], assume that all experiences in it have the same model_version and use_count
-        kwargs: storage_config.replay_buffer_kwargs (except priority_fn)
-    Returns:
-        priority: float
-        put_into_queue: bool, decide whether to put item into queue
-Note that put_into_queue takes effect both for new item from the explorer and for item sampled from the buffer.
-"""
+
+
+class PriorityFunction(ABC):
+    """
+    Each priority_fn,
+        Args:
+            item: List[Experience], assume that all experiences in it have the same model_version and use_count
+            priority_fn_args: Dict, the arguments for priority_fn
+
+        Returns:
+            priority: float
+            put_into_queue: bool, decide whether to put item into queue
+
+    Note that put_into_queue takes effect both for new item from the explorer and for item sampled from the buffer.
+    """
+
+    @abstractmethod
+    def __call__(self, item: List[Experience]) -> Tuple[float, bool]:
+        """Calculate the priority of item."""
+
+    @classmethod
+    @abstractmethod
+    def default_config(cls) -> Dict:
+        """Return the default config."""


 @PRIORITY_FUNC.register_module("linear_decay")
-def linear_decay_priority(
-    item: List[Experience],
-    decay: float = 2.0,
-) -> Tuple[float, bool]:
+class LinearDecayPriority(PriorityFunction):
     """Calculate priority by linear decay.

-    Priority is calculated as `model_version - decay * use_count. The item is always put back into the queue for reuse (as long as `reuse_cooldown_time` is not None).
+    Priority is calculated as `model_version - decay * use_count`. The item is always put back into the queue for reuse (as long as `reuse_cooldown_time` is not None).
     """
-    priority = float(item[0].info["model_version"] - decay * item[0].info["use_count"])
-    put_into_queue = True
-    return priority, put_into_queue
-
-
-@PRIORITY_FUNC.register_module("linear_decay_use_count_control_randomization")
-def linear_decay_use_count_control_priority(
-    item: List[Experience],
-    decay: float = 2.0,
-    use_count_limit: int = 3,
-    sigma: float = 0.0,
-) -> Tuple[float, bool]:
+
+    def __init__(self, decay: float = 2.0):
+        self.decay = decay
+
+    def __call__(self, item: List[Experience]) -> Tuple[float, bool]:
+        priority = float(item[0].info["model_version"] - self.decay * item[0].info["use_count"])
+        put_into_queue = True
+        return priority, put_into_queue
+
+    @classmethod
+    def default_config(cls) -> Dict:
+        return {
+            "decay": 2.0,
+        }
+
+
+@PRIORITY_FUNC.register_module("decay_limit_randomization")
+class LinearDecayUseCountControlPriority(PriorityFunction):
     """Calculate priority by linear decay, use count control, and randomization.

     Priority is calculated as `model_version - decay * use_count`; if `sigma` is non-zero, priority is further perturbed by random Gaussian noise with standard deviation `sigma`. The item will be put back into the queue only if use count does not exceed `use_count_limit`.
     """
-    priority = float(item[0].info["model_version"] - decay * item[0].info["use_count"])
-    if sigma > 0.0:
-        priority += float(np.random.randn() * sigma)
-    put_into_queue = item[0].info["use_count"] < use_count_limit if use_count_limit > 0 else True
-    return priority, put_into_queue
+
+    def __init__(self, decay: float = 2.0, use_count_limit: int = 3, sigma: float = 0.0):
+        self.decay = decay
+        self.use_count_limit = use_count_limit
+        self.sigma = sigma
+
+    def __call__(self, item: List[Experience]) -> Tuple[float, bool]:
+        priority = float(item[0].info["model_version"] - self.decay * item[0].info["use_count"])
+        if self.sigma > 0.0:
+            priority += float(np.random.randn() * self.sigma)
+        put_into_queue = (
+            item[0].info["use_count"] < self.use_count_limit if self.use_count_limit > 0 else True
+        )
+        return priority, put_into_queue
+
+    @classmethod
+    def default_config(cls) -> Dict:
+        return {
+            "decay": 2.0,
+            "use_count_limit": 3,
+            "sigma": 0.0,
+        }


 class QueueBuffer(ABC):
@@ -168,7 +201,11 @@ def __init__(
         self.capacity = capacity
         self.item_count = 0
         self.priority_groups = SortedDict()  # Maps priority -> deque of items
-        self.priority_fn = partial(PRIORITY_FUNC.get(priority_fn), **(priority_fn_args or {}))
+        priority_fn_cls = PRIORITY_FUNC.get(priority_fn)
+        # User-supplied args override the priority function's own defaults.
+        kwargs = priority_fn_cls.default_config()
+        kwargs.update(priority_fn_args or {})
+        self.priority_fn = priority_fn_cls(**kwargs)
         self.reuse_cooldown_time = reuse_cooldown_time
         self._condition = asyncio.Condition()  # For thread-safe operations
         self._closed = False
diff --git a/trinity/manager/config_manager.py b/trinity/manager/config_manager.py
index 7c35a89c9c..d4b705e85a 100644
--- a/trinity/manager/config_manager.py
+++ b/trinity/manager/config_manager.py
@@ -14,7 +14,10 @@
 from trinity.algorithm.policy_loss_fn.policy_loss_fn import POLICY_LOSS_FN
 from trinity.algorithm.sample_strategy.sample_strategy import SAMPLE_STRATEGY
 from trinity.common.constants import StorageType
-from trinity.manager.config_registry.buffer_config_manager import get_train_batch_size
+from trinity.manager.config_registry.buffer_config_manager import (
+    get_train_batch_size,
+    parse_priority_fn_args,
+)
 from trinity.manager.config_registry.config_registry import CONFIG_GENERATORS
 from trinity.manager.config_registry.trainer_config_manager import use_critic
 from trinity.utils.plugin_loader import load_plugins
@@ -190,7 +193,8 @@ def _expert_buffer_part(self):
         self.get_configs("storage_type")
         self.get_configs("experience_buffer_path")
         self.get_configs("enable_replay_buffer")
-        self.get_configs("reuse_cooldown_time", "priority_fn", "priority_decay")
+        self.get_configs("reuse_cooldown_time", "priority_fn")
+        self.get_configs("priority_fn_args")

         # TODO: used for SQL storage
         # self.buffer_advanced_tab = st.expander("Advanced Config")
@@ -592,9 +596,7 @@ def _gen_buffer_config(self):
             "enable": st.session_state["enable_replay_buffer"],
             "priority_fn": st.session_state["priority_fn"],
             "reuse_cooldown_time": st.session_state["reuse_cooldown_time"],
-            "priority_fn_args": {
-                "decay": st.session_state["priority_decay"],
-            },
+            "priority_fn_args": parse_priority_fn_args(st.session_state["priority_fn_args"]),
         }

         if st.session_state["mode"] != "train":
diff --git a/trinity/manager/config_registry/buffer_config_manager.py b/trinity/manager/config_registry/buffer_config_manager.py
index aaf431f9cc..92351095f6 100644
--- a/trinity/manager/config_registry/buffer_config_manager.py
+++ b/trinity/manager/config_registry/buffer_config_manager.py
@@ -1,3 +1,6 @@
+import json
+
+import pandas as pd
 import streamlit as st

 from trinity.buffer.storage.queue import PRIORITY_FUNC
@@ -328,11 +331,42 @@ def set_priority_fn(**kwargs):
     )


+def parse_priority_fn_args(raw_data: str) -> dict:
+    """Decode the stored priority-function arguments.
+
+    `raw_data` should be the JSON string that `set_priority_fn_args` saves, i.e.
+    `{"fn_args": {...}, "priority_fn": "<name>"}`. If it cannot be decoded, lacks
+    those keys, or was saved for a priority function other than the one currently
+    selected, the selected function's `default_config()` is returned instead.
+    """
+    try:
+        data = json.loads(raw_data)
+        if data["priority_fn"] != st.session_state["priority_fn"]:
+            raise ValueError
+        return data["fn_args"]
+    except (json.JSONDecodeError, KeyError, TypeError, ValueError):
+        # TypeError: valid JSON that is not an object (`[]`, `null`, `1`) or a non-string input.
+        print(f"Use `default_config` for {st.session_state['priority_fn']}")
+        return PRIORITY_FUNC.get(st.session_state["priority_fn"]).default_config()
+
+
 @CONFIG_GENERATORS.register_config(
-    default_value=0.1, visible=lambda: st.session_state["enable_replay_buffer"]
+    default_value="", visible=lambda: st.session_state["enable_replay_buffer"]
 )
-def set_priority_decay(**kwargs):
-    st.number_input(
-        "Priority Decay",
-        **kwargs,
+def set_priority_fn_args(**kwargs):
+    """Show the selected priority function's arguments as an editable one-row table.
+
+    The edited row is saved to `st.session_state[key]` as JSON together with the
+    priority function name, which is the format `parse_priority_fn_args` reads.
+    """
+    key = kwargs.get("key")
+    df = pd.DataFrame([parse_priority_fn_args(st.session_state[key])])
+    df.index = [st.session_state["priority_fn"]]
+    st.caption("Priority Function Args")
+    df = st.data_editor(df)
+    st.session_state[key] = json.dumps(
+        {
+            "fn_args": df.to_dict(orient="records")[0],
+            "priority_fn": st.session_state["priority_fn"],
+        }
     )