Code reuse between eval and sync API
A K committed Dec 28, 2023
1 parent 314c9fe commit f69cf46
Showing 9 changed files with 140 additions and 256 deletions.
26 changes: 6 additions & 20 deletions sample_factory/algo/runners/runner.py
@@ -15,6 +15,7 @@
 from sample_factory.algo.learning.batcher import Batcher
 from sample_factory.algo.learning.learner_worker import LearnerWorker
 from sample_factory.algo.sampling.sampler import AbstractSampler
+from sample_factory.algo.sampling.stats import samples_stats_handler, stats_msg_handler, timing_msg_handler
 from sample_factory.algo.utils.env_info import EnvInfo, obtain_env_info_in_a_separate_process
 from sample_factory.algo.utils.heartbeat import HeartbeatStoppableEventLoopObject
 from sample_factory.algo.utils.misc import (
@@ -72,8 +73,8 @@ def on_stop(self, runner: Runner) -> None:
         pass


-MsgHandler = Callable[["Runner", dict], None]
-PolicyMsgHandler = Callable[["Runner", dict, PolicyID], None]
+MsgHandler = Callable[[Any, dict], None]
+PolicyMsgHandler = Callable[[Any, dict, PolicyID], None]


 class Runner(EventLoopObject, Configurable):
@@ -142,16 +143,16 @@ def __init__(self, cfg, unique_name=None):

         # global msg handlers for messages from algo components
         self.msg_handlers: Dict[str, List[MsgHandler]] = {
-            TIMING_STATS: [self._timing_msg_handler],
-            STATS_KEY: [self._stats_msg_handler],
+            TIMING_STATS: [timing_msg_handler],
+            STATS_KEY: [stats_msg_handler],
         }

         # handlers for policy-specific messages
         self.policy_msg_handlers: Dict[str, List[PolicyMsgHandler]] = {
             LEARNER_ENV_STEPS: [self._learner_steps_handler],
             EPISODIC: [self._episodic_stats_handler],
             TRAIN_STATS: [self._train_stats_handler],
-            SAMPLES_COLLECTED: [self._samples_stats_handler],
+            SAMPLES_COLLECTED: [samples_stats_handler],
         }

         self.observers: List[AlgoObserver] = []
@@ -252,17 +253,6 @@ def _process_msg(self, msgs):
             for handler in self.policy_msg_handlers.get(key, ()):
                 handler(self, msg, policy_id)

-    @staticmethod
-    def _timing_msg_handler(runner, msg):
-        for k, v in msg["timing"].items():
-            if k not in runner.avg_stats:
-                runner.avg_stats[k] = deque([], maxlen=50)
-            runner.avg_stats[k].append(v)
-
-    @staticmethod
-    def _stats_msg_handler(runner, msg):
-        runner.stats.update(msg["stats"])
-
     @staticmethod
     def _learner_steps_handler(runner: Runner, msg: Dict, policy_id: PolicyID) -> None:
         env_steps: int = msg[LEARNER_ENV_STEPS]
@@ -303,10 +293,6 @@ def _train_stats_handler(runner: Runner, msg: Dict, policy_id: PolicyID) -> None:
             if key in train_stats:
                 runner.policy_lag[policy_id][key] = train_stats[key]

-    @staticmethod
-    def _samples_stats_handler(runner, msg, policy_id):
-        runner.samples_collected[policy_id] += msg[SAMPLES_COLLECTED]
-
     def _get_perf_stats(self):
         # total env steps simulated across all policies
         fps_stats = []
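The three Runner static methods deleted above now live in a shared module, sample_factory/algo/sampling/stats.py, imported by both the Runner and the evaluation SamplingLoop; widening the MsgHandler aliases from "Runner" to Any is what lets either object act as the stats observer. The new file itself is not expanded in this view. Here is a minimal sketch of what it plausibly contains, reconstructed from the deleted method bodies; the stats_observer parameter name and the PolicyID import path are assumptions, not verbatim file contents:

# Sketch of sample_factory/algo/sampling/stats.py, reconstructed from the
# Runner methods deleted above; not necessarily the verbatim file contents.
from collections import deque
from typing import Any, Dict

from sample_factory.algo.utils.misc import SAMPLES_COLLECTED
from sample_factory.utils.typing import PolicyID


def timing_msg_handler(stats_observer: Any, msg: Dict) -> None:
    # keep a sliding window of the most recent 50 values per timing stat
    for k, v in msg["timing"].items():
        if k not in stats_observer.avg_stats:
            stats_observer.avg_stats[k] = deque([], maxlen=50)
        stats_observer.avg_stats[k].append(v)


def stats_msg_handler(stats_observer: Any, msg: Dict) -> None:
    # merge free-form stats into the observer's running stats dict
    stats_observer.stats.update(msg["stats"])


def samples_stats_handler(stats_observer: Any, msg: Dict, policy_id: PolicyID) -> None:
    # accumulate the number of samples collected for each policy
    stats_observer.samples_collected[policy_id] += msg[SAMPLES_COLLECTED]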
57 changes: 35 additions & 22 deletions sample_factory/algo/sampling/evaluation_sampling_api.py
@@ -1,5 +1,6 @@
 from __future__ import annotations

+import math
 import time
 from collections import OrderedDict
 from threading import Thread
@@ -10,14 +11,16 @@
 from torch import Tensor

 from sample_factory.algo.learning.learner import Learner
-from sample_factory.algo.runners.runner import MsgHandler, PolicyMsgHandler, Runner
+from sample_factory.algo.runners.runner import MsgHandler, PolicyMsgHandler
 from sample_factory.algo.sampling.sampler import AbstractSampler, ParallelSampler, SerialSampler
+from sample_factory.algo.sampling.stats import samples_stats_handler, stats_msg_handler, timing_msg_handler
 from sample_factory.algo.utils.env_info import EnvInfo
 from sample_factory.algo.utils.misc import EPISODIC, SAMPLES_COLLECTED, STATS_KEY, TIMING_STATS, ExperimentStatus
 from sample_factory.algo.utils.model_sharing import ParameterServer
 from sample_factory.algo.utils.rl_utils import samples_per_trajectory
 from sample_factory.algo.utils.shared_buffers import BufferMgr
 from sample_factory.algo.utils.tensor_dict import TensorDict
+from sample_factory.cfg.arguments import cfg_dict
 from sample_factory.cfg.configurable import Configurable
 from sample_factory.utils.dicts import iterate_recursively
 from sample_factory.utils.gpu_utils import set_global_cuda_envvars
@@ -26,8 +29,9 @@


 class SamplingLoop(EventLoopObject, Configurable):
-    def __init__(self, cfg: Config, env_info: EnvInfo):
-        Configurable.__init__(self, cfg)
+    def __init__(self, cfg: Config, env_info: EnvInfo, print_episode_info: bool = True):
+        Configurable.__init__(self, cfg_dict(cfg))
+
         unique_name = SamplingLoop.__name__
         self.event_loop: EventLoop = EventLoop(unique_loop_name=f"{unique_name}_EvtLoop", serial_mode=cfg.serial_mode)
         self.event_loop.owner = self
@@ -37,7 +41,9 @@ def __init__(self, cfg: Config, env_info: EnvInfo):
         # calculate how many episodes for each environment should be taken into account
         # we only want to use first N episodes (we don't want to bias ourselves with short episodes)
         total_envs = self.cfg.num_workers * self.cfg.num_envs_per_worker
-        self.max_episode_number = self.cfg.sample_env_episodes / total_envs
+
+        sample_env_episodes = self.cfg.get("sample_env_episodes", math.inf)
+        self.max_episode_number = sample_env_episodes / total_envs

         self.env_info = env_info
         self.iteration: int = 0
@@ -61,16 +67,18 @@ def __init__(self, cfg: Config, env_info: EnvInfo):

         # global msg handlers for messages from algo components
         self.msg_handlers: Dict[str, List[MsgHandler]] = {
-            TIMING_STATS: [Runner._timing_msg_handler],
-            STATS_KEY: [Runner._stats_msg_handler],
+            TIMING_STATS: [timing_msg_handler],
+            STATS_KEY: [stats_msg_handler],
         }

         # handlers for policy-specific messages
         self.policy_msg_handlers: Dict[str, List[PolicyMsgHandler]] = {
             EPISODIC: [self._episodic_stats_handler],
-            SAMPLES_COLLECTED: [Runner._samples_stats_handler],
+            SAMPLES_COLLECTED: [samples_stats_handler],
         }

+        self.print_episode_info = print_episode_info
+
     @signal
     def model_initialized(self):
         ...
@@ -137,25 +145,26 @@ def _process_msg(self, msgs):
                 handler(self, msg, policy_id)

     @staticmethod
-    def _episodic_stats_handler(runner: Runner, msg: Dict, policy_id: PolicyID) -> None:
+    def _episodic_stats_handler(stats_observer: SamplingLoop, msg: Dict, policy_id: PolicyID) -> None:
         # heavily based on the `_episodic_stats_handler` from `Runner`
         s = msg[EPISODIC]

         # skip invalid stats, potentially be not setting episode_number one could always add stats
         episode_number = s["episode_extra_stats"].get("episode_number", 0)
-        if episode_number < runner.max_episode_number:
-            log.debug(
-                f"Episode ended after {s['len']:.1f} steps. Return: {s['reward']:.1f}. True objective {s['true_objective']:.1f}"
-            )
+        if episode_number < stats_observer.max_episode_number:
+            if stats_observer.print_episode_info:
+                log.debug(
+                    f"Episode ended after {s['len']:.1f} steps. Return: {s['reward']:.1f}. True objective {s['true_objective']:.1f}"
+                )

             for _, key, value in iterate_recursively(s):
-                if key not in runner.policy_avg_stats:
-                    runner.policy_avg_stats[key] = [[] for _ in range(runner.cfg.num_policies)]
+                if key not in stats_observer.policy_avg_stats:
+                    stats_observer.policy_avg_stats[key] = [[] for _ in range(stats_observer.cfg.num_policies)]

                 if isinstance(value, np.ndarray) and value.ndim > 0:
-                    runner.policy_avg_stats[key][policy_id].extend(value)
+                    stats_observer.policy_avg_stats[key][policy_id].extend(value)
                 else:
-                    runner.policy_avg_stats[key][policy_id].append(value)
+                    stats_observer.policy_avg_stats[key][policy_id].append(value)

     def wait_until_ready(self):
         while not self.ready:
@@ -236,11 +245,14 @@ def __init__(

         self.buffer_mgr = None
         self.policy_versions_tensor = None
-        self.param_servers: Dict[PolicyID, ParameterServer] = None
-        self.init_model_data: Dict[PolicyID, InitModelData] = None
-        self.learners: Dict[PolicyID, Learner] = None
+        self.param_servers: Optional[dict[PolicyID, ParameterServer]] = None
+        self.init_model_data: Optional[dict[PolicyID, InitModelData]] = None
+        self.learners: Optional[dict[PolicyID, Learner]] = None
+
+        self.sampling_loop: Optional[SamplingLoop] = None
+
+        self.sampling_thread: Optional[Thread] = None

-        self.sampling_loop: SamplingLoop = None
         self.total_samples = 0

     def init(self):
@@ -259,14 +271,15 @@ def init(self):
             self.learners[policy_id] = Learner(
                 self.cfg, self.env_info, self.policy_versions_tensor, policy_id, self.param_servers[policy_id]
             )
+            # TODO: separate model loading from the learners
             self.init_model_data[policy_id] = self.learners[policy_id].init()

         self.sampling_loop: SamplingLoop = SamplingLoop(self.cfg, self.env_info)
+        # don't pass self.param_servers here, learners are normally initialized later
+        # TODO: fix above issue
         self.sampling_loop.init(self.buffer_mgr)
         self.sampling_loop.set_new_trajectory_callback(self._on_new_trajectories)
-        self.sampling_thread: Thread = Thread(target=self.sampling_loop.run)
+        self.sampling_thread = Thread(target=self.sampling_loop.run)
         self.sampling_thread.start()

         self.sampling_loop.wait_until_ready()
@@ -296,7 +309,7 @@ def _on_new_trajectories(self, traj: TensorDict, traj_buffer_indices: Iterable[int], device: str) -> None:
         self.total_samples += samples_per_trajectory(traj)

         # just release buffers after every trajectory
-        # we could alternatively have more sophisticated logic here, see i.e. batcher.py or simplified_sampling_api.py
+        # we could alternatively have more sophisticated logic here, see i.e. batcher.py or sync_sampling_api.py
         self.sampling_loop.yield_trajectory_buffers(traj_buffer_indices, device)

     def stop(self) -> StatusCode:
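Two behavioral changes in SamplingLoop are easy to miss in the diff above. Configurable.__init__ now receives cfg_dict(cfg), which normalizes the config into a dict-like object, presumably so that the guarded self.cfg.get("sample_env_episodes", math.inf) lookup works even for configs that arrive as plain namespaces; the episode cap thus becomes optional and defaults to unlimited. The cap itself is spread evenly across all environments. A worked example of that arithmetic, with illustrative numbers not taken from the diff:

import math

# Illustrative numbers, not from the diff: 16 workers x 2 envs per worker.
num_workers, num_envs_per_worker = 16, 2
total_envs = num_workers * num_envs_per_worker         # 32 envs
sample_env_episodes = 320                              # becomes math.inf when the key is unset
max_episode_number = sample_env_episodes / total_envs  # 10.0
assert max_episode_number == 10.0

# Only the first 10 episodes of each env pass the
# `episode_number < max_episode_number` check in _episodic_stats_handler,
# so envs that churn through many short episodes cannot dominate the stats.

The new print_episode_info flag is simpler: it gates the per-episode log.debug call in _episodic_stats_handler, letting programmatic callers run a quiet evaluation loop.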