From bc28dbcba6353c6ac15b2483512b62ccb847c54a Mon Sep 17 00:00:00 2001
From: nkzawa
Date: Fri, 25 Oct 2024 17:32:49 +0700
Subject: [PATCH] add debug info

---
Reviewer notes (this area after the "---" marker is not part of the commit
message):

  * action_probs() returns `distribution.probs` when the distribution has a
    `probs` attribute, otherwise `distribution.means`; for a
    TupleActionDistribution it concatenates the per-component results.
  * The exporter's forward() now returns `probs` as an extra trailing output.
    When `cfg.eval_deterministic` is false, `probs` is an empty tensor
    (`torch.zeros(0)`).
  * The test helpers print the max absolute difference between the torch and
    onnxruntime `probs` outputs; the assertions still compare actions only.
  * The NotImplementedError message in action_probs() previously said
    "does not support argmax!" (copied from argmax_actions); it now names the
    operation that actually failed.
  * NOTE(review): in the tuple branch, `squeeze(0)` only removes the leading
    axis when its size is 1, and `torch.cat` joins along dim 0 -- presumably
    this assumes a batch size of 1; confirm before relying on the output for
    larger batches.

 .../algo/utils/action_distributions.py | 12 ++++++++++++
 sample_factory/export_onnx.py          |  9 ++++++---
 tests/export_onnx_utils.py             | 17 +++++++++++++----
 3 files changed, 31 insertions(+), 7 deletions(-)

diff --git a/sample_factory/algo/utils/action_distributions.py b/sample_factory/algo/utils/action_distributions.py
index b279aeb3a..2d67faa2c 100644
--- a/sample_factory/algo/utils/action_distributions.py
+++ b/sample_factory/algo/utils/action_distributions.py
@@ -81,6 +81,18 @@ def argmax_actions(distribution):
         raise NotImplementedError(f"Action distribution type {type(distribution)} does not support argmax!")
 
 
+def action_probs(distribution):
+    if isinstance(distribution, TupleActionDistribution):
+        list_of_action_batches = [action_probs(d).squeeze(0) for d in distribution.distributions]
+        return torch.cat(list_of_action_batches)
+    elif hasattr(distribution, "probs"):
+        return distribution.probs
+    elif hasattr(distribution, "means"):
+        return distribution.means
+    else:
+        raise NotImplementedError(f"Action distribution type {type(distribution)} does not support action probabilities!")
+
+
 def masked_softmax(logits, mask):
     # Mask out the invalid logits by adding a large negative number (-1e9)
     logits = logits + (mask == 0) * -1e9
diff --git a/sample_factory/export_onnx.py b/sample_factory/export_onnx.py
index b3bbe6fcc..e83bfece7 100644
--- a/sample_factory/export_onnx.py
+++ b/sample_factory/export_onnx.py
@@ -9,7 +9,7 @@
 
 from sample_factory.algo.learning.learner import Learner
 from sample_factory.algo.sampling.batched_sampling import preprocess_actions
-from sample_factory.algo.utils.action_distributions import argmax_actions
+from sample_factory.algo.utils.action_distributions import action_probs, argmax_actions
 from sample_factory.algo.utils.env_info import EnvInfo, extract_env_info
 from sample_factory.algo.utils.make_env import BatchedVecEnv
 from sample_factory.algo.utils.misc import ExperimentStatus
@@ -50,6 +50,9 @@ def forward(self, **obs):
         if self.cfg.eval_deterministic:
             action_distribution = self.actor_critic.action_distribution()
             actions = argmax_actions(action_distribution)
+            probs = action_probs(action_distribution)
+        else:
+            probs = torch.zeros(0)
 
         if actions.ndim == 1:
             actions = unsqueeze_tensor(actions, dim=-1)
@@ -57,9 +60,9 @@ def forward(self, **obs):
         actions = preprocess_actions(self.env_info, actions, to_numpy=False)
 
         if self.cfg.use_rnn:
-            return actions, rnn_states
+            return actions, rnn_states, probs
         else:
-            return actions
+            return actions, probs
 
 
 def create_onnx_exporter(cfg: Config, env: BatchedVecEnv, enable_jit=False) -> OnnxExporter:
diff --git a/tests/export_onnx_utils.py b/tests/export_onnx_utils.py
index 0eec58388..bec20862a 100644
--- a/tests/export_onnx_utils.py
+++ b/tests/export_onnx_utils.py
@@ -1,3 +1,4 @@
+import numpy as np
 import onnx
 import onnxruntime
 import torch
@@ -27,26 +28,34 @@ def check_rnn_inference_result(
 
     for _ in range(3):
         args = generate_args(env.observation_space)
-        torch_out, rnn_states = model(**args, rnn_states=rnn_states)
+        actions, rnn_states, probs = model(**args, rnn_states=rnn_states)
 
         ort_inputs = {k: to_numpy(v) for k, v in args.items()}
         ort_inputs["rnn_states"] = ort_rnn_states
         ort_out = ort_session.run(None, ort_inputs)
         ort_rnn_states = ort_out[1]
 
-        assert (to_numpy(torch_out[0]) == ort_out[0]).all()
+        max_prob_diff = np.abs(to_numpy(probs) - ort_out[2]).max()
+        print(f"max_prob_diff={max_prob_diff:.10f}")
+        print(f"torch probs={to_numpy(probs)}")
+        print(f"ort probs={ort_out[2]}")
+        assert (to_numpy(actions) == ort_out[0]).all()
 
 
 def check_inference_result(env: BatchedVecEnv, model: OnnxExporter, ort_session: onnxruntime.InferenceSession) -> None:
     for batch_size in [1, 3]:
         args = generate_args(env.observation_space, batch_size)
-        torch_out = model(**args)
+        actions, probs = model(**args)
 
         ort_inputs = {k: to_numpy(v) for k, v in args.items()}
         ort_out = ort_session.run(None, ort_inputs)
 
+        max_prob_diff = np.abs(to_numpy(probs) - ort_out[1]).max()
+        print(f"max_prob_diff={max_prob_diff:.10f}")
+        print(f"torch probs={to_numpy(probs)}")
+        print(f"ort probs={ort_out[1]}")
         assert len(ort_out[0]) == batch_size
-        assert (to_numpy(torch_out[0]) == ort_out[0]).all()
+        assert (to_numpy(actions) == ort_out[0]).all()
 
 
 def check_export_onnx(cfg: Config) -> None: