Skip to content

Commit

Permalink
add workaround
Browse files Browse the repository at this point in the history
  • Loading branch information
nkzawa committed Oct 25, 2024
1 parent e25df86 commit 97f2837
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 8 deletions.
11 changes: 11 additions & 0 deletions sample_factory/algo/utils/action_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,17 @@ def argmax_actions(distribution):
raise NotImplementedError(f"Action distribution type {type(distribution)} does not support argmax!")


def action_probs(distribution):
    """Return the probabilities (or means) exposed by an action distribution.

    Dispatch order:
      1. ``TupleActionDistribution``: recurse into every entry of
         ``distribution.distributions`` and return a list with one result per
         sub-distribution, each passed through ``.squeeze(0)``.
      2. Any object with a ``probs`` attribute: return ``distribution.probs``.
      3. Any object with a ``means`` attribute: return ``distribution.means``.

    Raises:
        NotImplementedError: if the distribution is not a tuple distribution
            and exposes neither ``probs`` nor ``means``.
    """
    if isinstance(distribution, TupleActionDistribution):
        # One entry per component distribution; the result is a list, not a
        # single stacked tensor.
        return [action_probs(d).squeeze(0) for d in distribution.distributions]

    # `probs` takes priority over `means` when an object exposes both.
    if hasattr(distribution, "probs"):
        return distribution.probs

    if hasattr(distribution, "means"):
        return distribution.means

    # Message mirrors the wording used by argmax_actions() so the failing
    # operation is named explicitly.
    raise NotImplementedError(
        f"Action distribution type {type(distribution)} does not support action_probs!"
    )


def masked_softmax(logits, mask):
# Mask out the invalid logits by adding a large negative number (-1e9)
logits = logits + (mask == 0) * -1e9
Expand Down
13 changes: 9 additions & 4 deletions sample_factory/export_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from sample_factory.algo.learning.learner import Learner
from sample_factory.algo.sampling.batched_sampling import preprocess_actions
from sample_factory.algo.utils.action_distributions import argmax_actions
from sample_factory.algo.utils.action_distributions import action_probs, argmax_actions
from sample_factory.algo.utils.env_info import EnvInfo, extract_env_info
from sample_factory.algo.utils.make_env import BatchedVecEnv
from sample_factory.algo.utils.misc import ExperimentStatus
Expand Down Expand Up @@ -47,8 +47,13 @@ def forward(self, **obs):
actions = policy_outputs["actions"]
rnn_states = policy_outputs["new_rnn_states"]

action_distribution = self.actor_critic.action_distribution()

# FIXME: Workaround to fix action mismatch between original and exported models in CI environments.
# While the root cause is unclear, getting probs somehow ensures consistent behavior.
probs = action_probs(action_distribution)

if self.cfg.eval_deterministic:
action_distribution = self.actor_critic.action_distribution()
actions = argmax_actions(action_distribution)

if actions.ndim == 1:
Expand All @@ -57,9 +62,9 @@ def forward(self, **obs):
actions = preprocess_actions(self.env_info, actions, to_numpy=False)

if self.cfg.use_rnn:
return actions, rnn_states
return actions, rnn_states, probs
else:
return actions
return actions, probs


def create_onnx_exporter(cfg: Config, env: BatchedVecEnv, enable_jit=False) -> OnnxExporter:
Expand Down
11 changes: 7 additions & 4 deletions tests/export_onnx_utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import numpy as np
import onnx
import onnxruntime
import torch
Expand Down Expand Up @@ -27,26 +28,28 @@ def check_rnn_inference_result(

for _ in range(3):
args = generate_args(env.observation_space)
torch_out, rnn_states = model(**args, rnn_states=rnn_states)
actions, rnn_states, probs = model(**args, rnn_states=rnn_states)

ort_inputs = {k: to_numpy(v) for k, v in args.items()}
ort_inputs["rnn_states"] = ort_rnn_states
ort_out = ort_session.run(None, ort_inputs)
ort_rnn_states = ort_out[1]

assert (to_numpy(torch_out[0]) == ort_out[0]).all()
assert (to_numpy(actions) == ort_out[0]).all()
assert np.abs(to_numpy(probs) - ort_out[2]).max() < 1e-5


def check_inference_result(env: BatchedVecEnv, model: OnnxExporter, ort_session: onnxruntime.InferenceSession) -> None:
for batch_size in [1, 3]:
args = generate_args(env.observation_space, batch_size)
torch_out = model(**args)
actions, probs = model(**args)

ort_inputs = {k: to_numpy(v) for k, v in args.items()}
ort_out = ort_session.run(None, ort_inputs)

assert len(ort_out[0]) == batch_size
assert (to_numpy(torch_out[0]) == ort_out[0]).all()
assert (to_numpy(actions) == ort_out[0]).all()
assert np.abs(to_numpy(probs) - ort_out[1]).max() < 1e-5


def check_export_onnx(cfg: Config) -> None:
Expand Down

0 comments on commit 97f2837

Please sign in to comment.