diff --git a/graph_net/torch/extractor.py b/graph_net/torch/extractor.py index ea71a5f99..64afc0ed4 100644 --- a/graph_net/torch/extractor.py +++ b/graph_net/torch/extractor.py @@ -5,6 +5,7 @@ from . import utils torch._dynamo.config.capture_scalar_outputs = True +torch._dynamo.config.capture_dynamic_output_shape_ops = True def extract(name, dynamic=True, mut_graph_codes=None, placeholder_auto_rename=False): diff --git a/graph_net/torch/utils.py b/graph_net/torch/utils.py index bf1d75252..d4ed82336 100644 --- a/graph_net/torch/utils.py +++ b/graph_net/torch/utils.py @@ -17,6 +17,8 @@ def apply_templates(forward_code: str) -> str: imports = "import torch" if "device" in forward_code: imports += "\n\nfrom torch import device" + if "inf" in forward_code: + imports += "\n\nfrom torch import inf" return f"{imports}\n\nclass GraphModule(torch.nn.Module):\n{tab}{forward_code}"