Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion tests/buffer/queue_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,7 +326,7 @@ async def test_priority_queue_reuse_count_control(self):
path=BUFFER_FILE_PATH,
replay_buffer=ReplayBufferConfig(
enable=True,
priority_fn="linear_decay_use_count_control_randomization",
priority_fn="decay_limit_randomization",
reuse_cooldown_time=0.5,
priority_fn_args={"decay": 1.2, "use_count_limit": 2, "sigma": 0.0},
),
Expand Down
104 changes: 70 additions & 34 deletions trinity/buffer/storage/queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,7 @@
from abc import ABC, abstractmethod
from collections import deque
from copy import deepcopy
from functools import partial
from typing import List, Optional, Tuple
from typing import Dict, List, Optional, Tuple

import numpy as np
import ray
Expand All @@ -28,48 +27,82 @@ def is_json_file(path: str) -> bool:


PRIORITY_FUNC = Registry("priority_fn")
"""
Each priority_fn,
Args:
item: List[Experience], assume that all experiences in it have the same model_version and use_count
kwargs: storage_config.replay_buffer_kwargs (except priority_fn)
Returns:
priority: float
put_into_queue: bool, decide whether to put item into queue
Note that put_into_queue takes effect both for new item from the explorer and for item sampled from the buffer.
"""


class PriorityFunction(ABC):
    """Abstract interface for callables that rank buffer items.

    A concrete priority function is configured through its constructor and is
    then called with one item at a time.

    Call arguments:
        item: List[Experience]. All experiences in the list are assumed to
            share the same ``model_version`` and ``use_count``.

    Call result:
        priority: float, the score assigned to the item.
        put_into_queue: bool, whether the item should be placed into the queue.

    The ``put_into_queue`` decision applies both to new items coming from the
    explorer and to items that were sampled from the buffer.
    """

    @classmethod
    @abstractmethod
    def default_config(cls) -> Dict:
        """Return the default configuration of this priority function."""

    @abstractmethod
    def __call__(self, item: List[Experience]) -> Tuple[float, bool]:
        """Compute ``(priority, put_into_queue)`` for ``item``."""


@PRIORITY_FUNC.register_module("linear_decay")
def linear_decay_priority(
item: List[Experience],
decay: float = 2.0,
) -> Tuple[float, bool]:
class LinearDecayPriority(PriorityFunction):
"""Calculate priority by linear decay.

Priority is calculated as `model_version - decay * use_count`. The item is always put back into the queue for reuse (as long as `reuse_cooldown_time` is not None).
"""
priority = float(item[0].info["model_version"] - decay * item[0].info["use_count"])
put_into_queue = True
return priority, put_into_queue


@PRIORITY_FUNC.register_module("linear_decay_use_count_control_randomization")
def linear_decay_use_count_control_priority(
item: List[Experience],
decay: float = 2.0,
use_count_limit: int = 3,
sigma: float = 0.0,
) -> Tuple[float, bool]:

def __init__(self, decay: float = 2.0):
    """Store the configuration.

    Args:
        decay: Amount subtracted from the priority per unit of ``use_count``.
    """
    self.decay = decay

def __call__(self, item: List[Experience]) -> Tuple[float, bool]:
    """Score ``item`` from the info of its first experience.

    Returns:
        ``(priority, True)`` where ``priority`` is
        ``model_version - decay * use_count`` converted to ``float``. The
        second element is always ``True``.
    """
    info = item[0].info
    model_version = info["model_version"]
    penalty = self.decay * info["use_count"]
    return float(model_version - penalty), True

@classmethod
def default_config(cls) -> Dict:
    """Return the default keyword arguments: ``decay`` set to 2.0."""
    return dict(decay=2.0)


@PRIORITY_FUNC.register_module("decay_limit_randomization")
class LinearDecayUseCountControlPriority(PriorityFunction):
"""Calculate priority by linear decay, use count control, and randomization.

Priority is calculated as `model_version - decay * use_count`; if `sigma` is non-zero, priority is further perturbed by random Gaussian noise with standard deviation `sigma`. The item will be put back into the queue only if use count does not exceed `use_count_limit`.
"""
priority = float(item[0].info["model_version"] - decay * item[0].info["use_count"])
if sigma > 0.0:
priority += float(np.random.randn() * sigma)
put_into_queue = item[0].info["use_count"] < use_count_limit if use_count_limit > 0 else True
return priority, put_into_queue

def __init__(self, decay: float = 2.0, use_count_limit: int = 3, sigma: float = 0.0):
    """Store the configuration.

    Args:
        decay: Amount subtracted from the priority per unit of ``use_count``.
        use_count_limit: Use-count threshold for re-queueing; values that are
            not positive disable the limit.
        sigma: Standard deviation of the Gaussian noise added to the
            priority; noise is only added when this is greater than zero.
    """
    self.decay, self.use_count_limit, self.sigma = decay, use_count_limit, sigma

def __call__(self, item: List[Experience]) -> Tuple[float, bool]:
    """Score ``item`` and decide whether it may be queued.

    Returns:
        ``(priority, put_into_queue)``. ``priority`` is
        ``model_version - decay * use_count`` as a ``float``, plus a Gaussian
        perturbation scaled by ``sigma`` when ``sigma`` is positive.
        ``put_into_queue`` is ``use_count < use_count_limit`` when the limit
        is positive, and ``True`` otherwise.
    """
    first = item[0]
    base = first.info["model_version"] - self.decay * first.info["use_count"]
    priority = float(base)
    if self.sigma > 0.0:
        # Randomize the ordering only when noise was explicitly requested.
        priority += float(np.random.randn() * self.sigma)

    if self.use_count_limit > 0:
        put_into_queue = first.info["use_count"] < self.use_count_limit
    else:
        # A non-positive limit means the use count is not restricted.
        put_into_queue = True
    return priority, put_into_queue

@classmethod
def default_config(cls) -> Dict:
    """Return the default keyword arguments for this priority function."""
    return dict(decay=2.0, use_count_limit=3, sigma=0.0)


class QueueBuffer(ABC):
Expand Down Expand Up @@ -168,7 +201,10 @@ def __init__(
self.capacity = capacity
self.item_count = 0
self.priority_groups = SortedDict() # Maps priority -> deque of items
self.priority_fn = partial(PRIORITY_FUNC.get(priority_fn), **(priority_fn_args or {}))
priority_fn_cls = PRIORITY_FUNC.get(priority_fn)
kwargs = priority_fn_cls.default_config()
kwargs.update(priority_fn_args or {})
self.priority_fn = priority_fn_cls(**kwargs)
self.reuse_cooldown_time = reuse_cooldown_time
self._condition = asyncio.Condition() # For thread-safe operations
self._closed = False
Expand Down
12 changes: 7 additions & 5 deletions trinity/manager/config_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,10 @@
from trinity.algorithm.policy_loss_fn.policy_loss_fn import POLICY_LOSS_FN
from trinity.algorithm.sample_strategy.sample_strategy import SAMPLE_STRATEGY
from trinity.common.constants import StorageType
from trinity.manager.config_registry.buffer_config_manager import get_train_batch_size
from trinity.manager.config_registry.buffer_config_manager import (
get_train_batch_size,
parse_priority_fn_args,
)
from trinity.manager.config_registry.config_registry import CONFIG_GENERATORS
from trinity.manager.config_registry.trainer_config_manager import use_critic
from trinity.utils.plugin_loader import load_plugins
Expand Down Expand Up @@ -190,7 +193,8 @@ def _expert_buffer_part(self):
self.get_configs("storage_type")
self.get_configs("experience_buffer_path")
self.get_configs("enable_replay_buffer")
self.get_configs("reuse_cooldown_time", "priority_fn", "priority_decay")
self.get_configs("reuse_cooldown_time", "priority_fn")
self.get_configs("priority_fn_args")

# TODO: used for SQL storage
# self.buffer_advanced_tab = st.expander("Advanced Config")
Expand Down Expand Up @@ -592,9 +596,7 @@ def _gen_buffer_config(self):
"enable": st.session_state["enable_replay_buffer"],
"priority_fn": st.session_state["priority_fn"],
"reuse_cooldown_time": st.session_state["reuse_cooldown_time"],
"priority_fn_args": {
"decay": st.session_state["priority_decay"],
},
"priority_fn_args": parse_priority_fn_args(st.session_state["priority_fn_args"]),
}

if st.session_state["mode"] != "train":
Expand Down
31 changes: 26 additions & 5 deletions trinity/manager/config_registry/buffer_config_manager.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
import json

import pandas as pd
import streamlit as st

from trinity.buffer.storage.queue import PRIORITY_FUNC
Expand Down Expand Up @@ -328,13 +331,31 @@ def set_priority_fn(**kwargs):
)


def parse_priority_fn_args(raw_data: str):
    """Recover the priority-function arguments stored as JSON text.

    Args:
        raw_data: JSON text expected to decode to an object of the form
            ``{"priority_fn": <name>, "fn_args": {...}}``.

    Returns:
        The stored ``fn_args`` when ``raw_data`` decodes to such an object and
        its ``priority_fn`` equals ``st.session_state["priority_fn"]``.
        Otherwise the ``default_config()`` of the priority function currently
        selected in ``st.session_state``.
    """
    selected_fn = st.session_state["priority_fn"]
    try:
        data = json.loads(raw_data)
    except (ValueError, TypeError):
        # ValueError covers json.JSONDecodeError (e.g. an empty string);
        # TypeError covers input that is not str/bytes, such as None.
        data = None

    # Valid JSON that is not an object (a list, a number, ...) cannot be
    # looked up by key, so it is handled like a malformed or stale payload
    # instead of raising an uncaught TypeError.
    if (
        isinstance(data, dict)
        and "priority_fn" in data
        and "fn_args" in data
        and data["priority_fn"] == selected_fn
    ):
        return data["fn_args"]

    print(f"Use `default_config` for {selected_fn}")
    return PRIORITY_FUNC.get(selected_fn).default_config()


@CONFIG_GENERATORS.register_config(
default_value=0.1, visible=lambda: st.session_state["enable_replay_buffer"]
default_value="", visible=lambda: st.session_state["enable_replay_buffer"]
)
def set_priority_decay(**kwargs):
st.number_input(
"Priority Decay",
**kwargs,
def set_priority_fn_args(**kwargs):
    """Show the priority-function arguments in a data editor and store the result.

    The JSON text stored in ``st.session_state`` under ``kwargs["key"]`` is
    parsed into an argument dict, shown as a single-row table labelled with
    the selected priority function, and the table returned by the editor is
    written back to the same session-state entry as JSON.
    """
    state_key = kwargs.get("key")
    current_args = parse_priority_fn_args(st.session_state[state_key])

    # One row holding the arguments, indexed by the selected function's name.
    table = pd.DataFrame([current_args])
    table.index = [st.session_state["priority_fn"]]

    st.caption("Priority Function Args")
    edited = st.data_editor(table)

    # Store the function name with the arguments so the parser can tell
    # which priority function they belong to.
    payload = {
        "fn_args": edited.to_dict(orient="records")[0],
        "priority_fn": st.session_state["priority_fn"],
    }
    st.session_state[state_key] = json.dumps(payload)


Expand Down