Skip to content

Commit

Permalink
Merge pull request #184 from jmribeiro/patch-3
Browse files Browse the repository at this point in the history
Fixed save/load problem on dqn.py
  • Loading branch information
muupan authored Jul 14, 2023
2 parents 2ad3d51 + 8fc26f4 commit ee0f363
Showing 1 changed file with 32 additions and 0 deletions.
32 changes: 32 additions & 0 deletions pfrl/agents/dqn.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import multiprocessing as mp
import multiprocessing.synchronize
import time
import os
from logging import Logger, getLogger
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple

Expand Down Expand Up @@ -790,6 +791,37 @@ def stop_episode(self) -> None:
if self.recurrent:
self.test_recurrent_states = None

def save_snapshot(self, dirname: str) -> None:
self.save(dirname)
torch.save(
self.t, os.path.join(dirname, "t.pt")
)
torch.save(
self.optim_t, os.path.join(dirname, "optim_t.pt")
)
torch.save(
self._cumulative_steps, os.path.join(dirname, "_cumulative_steps.pt")
)
self.replay_buffer.save(
os.path.join(dirname, "replay_buffer.pkl")
)


def load_snapshot(self, dirname: str) -> None:
self.load(dirname)
self.t = torch.load(
os.path.join(dirname, "t.pt")
)
self.optim_t = torch.load(
os.path.join(dirname, "optim_t.pt")
)
self._cumulative_steps = torch.load(
os.path.join(dirname, "_cumulative_steps.pt")
)
self.replay_buffer.load(
os.path.join(dirname, "replay_buffer.pkl")
)

def get_statistics(self):
return [
("average_q", _mean_or_nan(self.q_record)),
Expand Down

0 comments on commit ee0f363

Please sign in to comment.