diff --git a/deepspeed/module_inject/replace_module.py b/deepspeed/module_inject/replace_module.py index f42a973d8b21..de3b81d50646 100755 --- a/deepspeed/module_inject/replace_module.py +++ b/deepspeed/module_inject/replace_module.py @@ -1077,7 +1077,7 @@ def replace_module(model, orig_class, replace_fn, _replace_policy): A modified ``model``. """ policy = {} - if orig_class is not None: + if orig_class is not None and _replace_policy is not None: policy.update({orig_class: (replace_fn, _replace_policy)}) else: for plcy in replace_policies: @@ -1107,6 +1107,11 @@ def _replace_module(model, policies, layer_id=0): Returns: Modified ``model``. """ + if model.__class__ in policies: + replaced_module = policies[model.__class__][0](model, + policies[model.__class__][-1], + layer_id) + return replaced_module, layer_id for name, child in model.named_children(): if child.__class__ in policies: replaced_module = policies[child.__class__][0](child,