
Commit

add debug info
nkzawa committed Oct 25, 2024
1 parent 3d89cc1 commit bc28dbc
Showing 3 changed files with 31 additions and 7 deletions.
12 changes: 12 additions & 0 deletions sample_factory/algo/utils/action_distributions.py
@@ -81,6 +81,18 @@ def argmax_actions(distribution):
         raise NotImplementedError(f"Action distribution type {type(distribution)} does not support argmax!")
 
 
+def action_probs(distribution):
+    if isinstance(distribution, TupleActionDistribution):
+        list_of_action_batches = [action_probs(d).squeeze(0) for d in distribution.distributions]
+        return torch.cat(list_of_action_batches)
+    elif hasattr(distribution, "probs"):
+        return distribution.probs
+    elif hasattr(distribution, "means"):
+        return distribution.means
+    else:
+        raise NotImplementedError(f"Action distribution type {type(distribution)} does not support probs!")
+
+
 def masked_softmax(logits, mask):
     # Mask out the invalid logits by adding a large negative number (-1e9)
     logits = logits + (mask == 0) * -1e9
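For context, a minimal usage sketch of the new helper; it assumes only that the distribution exposes a .probs attribute (as a plain torch.distributions.Categorical does), while the object used in practice is whatever the actor-critic returns:

import torch
from torch.distributions import Categorical

from sample_factory.algo.utils.action_distributions import action_probs

dist = Categorical(logits=torch.tensor([[1.0, 2.0, 0.5]]))  # has a .probs attribute
probs = action_probs(dist)                                  # tensor of shape [1, 3]
print(probs.sum(dim=-1))                                    # rows sum to 1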
9 changes: 6 additions & 3 deletions sample_factory/export_onnx.py
@@ -9,7 +9,7 @@
 
 from sample_factory.algo.learning.learner import Learner
 from sample_factory.algo.sampling.batched_sampling import preprocess_actions
-from sample_factory.algo.utils.action_distributions import argmax_actions
+from sample_factory.algo.utils.action_distributions import action_probs, argmax_actions
 from sample_factory.algo.utils.env_info import EnvInfo, extract_env_info
 from sample_factory.algo.utils.make_env import BatchedVecEnv
 from sample_factory.algo.utils.misc import ExperimentStatus
@@ -50,16 +50,19 @@ def forward(self, **obs):
         if self.cfg.eval_deterministic:
             action_distribution = self.actor_critic.action_distribution()
             actions = argmax_actions(action_distribution)
+            probs = action_probs(action_distribution)
+        else:
+            probs = torch.zeros(0)
 
         if actions.ndim == 1:
             actions = unsqueeze_tensor(actions, dim=-1)
 
         actions = preprocess_actions(self.env_info, actions, to_numpy=False)
 
         if self.cfg.use_rnn:
-            return actions, rnn_states
+            return actions, rnn_states, probs
         else:
-            return actions
+            return actions, probs
 
 
 def create_onnx_exporter(cfg: Config, env: BatchedVecEnv, enable_jit=False) -> OnnxExporter:
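For reference, a rough sketch of consuming the extra output on the ONNX Runtime side for a non-RNN export (the model path, observation key, and shape below are placeholders; when eval_deterministic is not set, probs is just an empty tensor):

import numpy as np
import onnxruntime

session = onnxruntime.InferenceSession("exported_model.onnx")  # hypothetical file name
ort_inputs = {"obs": np.zeros((1, 4), dtype=np.float32)}       # placeholder observation name/shape
actions, probs = session.run(None, ort_inputs)                 # output order follows `return actions, probs`
print(actions.shape, probs.shape)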
17 changes: 13 additions & 4 deletions tests/export_onnx_utils.py
@@ -1,3 +1,4 @@
+import numpy as np
 import onnx
 import onnxruntime
 import torch
@@ -27,26 +28,34 @@ def check_rnn_inference_result(
 
     for _ in range(3):
         args = generate_args(env.observation_space)
-        torch_out, rnn_states = model(**args, rnn_states=rnn_states)
+        actions, rnn_states, probs = model(**args, rnn_states=rnn_states)
 
         ort_inputs = {k: to_numpy(v) for k, v in args.items()}
         ort_inputs["rnn_states"] = ort_rnn_states
         ort_out = ort_session.run(None, ort_inputs)
         ort_rnn_states = ort_out[1]
 
-        assert (to_numpy(torch_out[0]) == ort_out[0]).all()
+        max_prob_diff = np.abs(to_numpy(probs) - ort_out[2]).max()
+        print(f"max_prob_diff={max_prob_diff:.10f}")
+        print(f"torch probs={to_numpy(probs)}")
+        print(f"ort probs={ort_out[2]}")
+        assert (to_numpy(actions) == ort_out[0]).all()
 
 
 def check_inference_result(env: BatchedVecEnv, model: OnnxExporter, ort_session: onnxruntime.InferenceSession) -> None:
     for batch_size in [1, 3]:
         args = generate_args(env.observation_space, batch_size)
-        torch_out = model(**args)
+        actions, probs = model(**args)
 
         ort_inputs = {k: to_numpy(v) for k, v in args.items()}
         ort_out = ort_session.run(None, ort_inputs)
 
+        max_prob_diff = np.abs(to_numpy(probs) - ort_out[1]).max()
+        print(f"max_prob_diff={max_prob_diff:.10f}")
+        print(f"torch probs={to_numpy(probs)}")
+        print(f"ort probs={ort_out[1]}")
         assert len(ort_out[0]) == batch_size
-        assert (to_numpy(torch_out[0]) == ort_out[0]).all()
+        assert (to_numpy(actions) == ort_out[0]).all()
 
 
 def check_export_onnx(cfg: Config) -> None:
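The changes above only print the probability difference for debugging. If the torch and ONNX Runtime outputs are expected to agree up to floating-point rounding, a possible follow-up (not part of this commit) would be an assertion with a small tolerance, for example:

# Hypothetical tolerance; PyTorch and ONNX Runtime can differ by float rounding.
assert np.allclose(to_numpy(probs), ort_out[1], atol=1e-5), f"max_prob_diff={max_prob_diff}"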
