Skip to content

Commit

Permalink
fix TypeError on python 3.8
Browse files Browse the repository at this point in the history
  • Loading branch information
nkzawa committed Oct 23, 2024
1 parent 05e1f42 commit b08a704
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions sample_factory/export_onnx.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import types
from typing import List

import gymnasium as gym
import torch
Expand Down Expand Up @@ -124,7 +125,7 @@ def unsqueeze_args(args):
raise NotImplementedError(f"Unsupported args type: {type(args)}")


def create_forward(original_forward, arg_names: list[str]):
def create_forward(original_forward, arg_names: List[str]):
args_str = ", ".join(arg_names)

func_code = f"""
Expand All @@ -140,7 +141,7 @@ def forward(self, {args_str}):
return local_vars["forward"]


def patch_forward(model: OnnxExporter, input_names: list[str]):
def patch_forward(model: OnnxExporter, input_names: List[str]):
"""
Patch the forward method of the model to dynamically define the input arguments
since *args and **kwargs are not supported in `torch.onnx.export`
Expand Down

0 comments on commit b08a704

Please sign in to comment.