Skip to content

Commit

Permalink
fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
nkzawa committed Oct 26, 2024
1 parent e25df86 commit 3321c06
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions tests/export_onnx_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,26 +27,26 @@ def check_rnn_inference_result(

for _ in range(3):
args = generate_args(env.observation_space)
torch_out, rnn_states = model(**args, rnn_states=rnn_states)
actions, rnn_states = model(**args, rnn_states=rnn_states)

ort_inputs = {k: to_numpy(v) for k, v in args.items()}
ort_inputs["rnn_states"] = ort_rnn_states
ort_out = ort_session.run(None, ort_inputs)
ort_rnn_states = ort_out[1]

assert (to_numpy(torch_out[0]) == ort_out[0]).all()
assert (to_numpy(actions) == ort_out[0]).all()


def check_inference_result(env: BatchedVecEnv, model: OnnxExporter, ort_session: onnxruntime.InferenceSession) -> None:
for batch_size in [1, 3]:
args = generate_args(env.observation_space, batch_size)
torch_out = model(**args)
actions = model(**args)

ort_inputs = {k: to_numpy(v) for k, v in args.items()}
ort_out = ort_session.run(None, ort_inputs)

assert len(ort_out[0]) == batch_size
assert (to_numpy(torch_out[0]) == ort_out[0]).all()
assert (to_numpy(actions) == ort_out[0]).all()


def check_export_onnx(cfg: Config) -> None:
Expand Down

0 comments on commit 3321c06

Please sign in to comment.