diff --git a/pytext/task/new_task.py b/pytext/task/new_task.py index 7c7f97fd9..9df7e2289 100644 --- a/pytext/task/new_task.py +++ b/pytext/task/new_task.py @@ -21,12 +21,19 @@ log_feature_usage, log_accelerator_feature_usage, ) -from torch import jit, sort +from torch import sort + +accelerator_lowering_supported = False +try: + from .accelerator_lowering import ( + lower_modules_to_accelerator, + swap_modules_for_accelerator, + ) + + accelerator_lowering_supported = True +except ImportError: + print("Accelerator Lowering not supported!") -from .accelerator_lowering import ( - lower_modules_to_accelerator, - swap_modules_for_accelerator, -) from .quantize import quantize_statically from .task import TaskBase @@ -328,7 +335,10 @@ def torchscript_export( optimizer.pre_export(model) if use_nnpi: - model = swap_modules_for_accelerator(model) + if accelerator_lowering_supported: + model = swap_modules_for_accelerator(model) + else: + raise RuntimeError("Accelerator Lowering not supported!") # Trace needs eval mode, to disable dropout etc model.eval() @@ -404,7 +414,10 @@ def torchscript_export( trace.apply(lambda s: s._pack() if s._c._has_method("_pack") else None) if "nnpi" in accelerate: print("lowering using to_glow") - trace = lower_modules_to_accelerator(model, trace, export_config) + if accelerator_lowering_supported: + trace = lower_modules_to_accelerator(model, trace, export_config) + else: + raise RuntimeError("Accelerator Lowering not supported!") if export_path is not None: print(f"Saving torchscript model to: {export_path}")