diff --git a/main.py b/main.py index 70696fcc389f..35857dba8a6b 100644 --- a/main.py +++ b/main.py @@ -115,6 +115,7 @@ def execute_script(script_path): os.environ['MIMALLOC_PURGE_DELAY'] = '0' if __name__ == "__main__": + os.environ['TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL'] = '1' if args.default_device is not None: default_dev = args.default_device devices = list(range(32))