diff --git a/deepspeed/module_inject/replace_module.py b/deepspeed/module_inject/replace_module.py index ea0e13726316..30af889c34ac 100755 --- a/deepspeed/module_inject/replace_module.py +++ b/deepspeed/module_inject/replace_module.py @@ -690,7 +690,9 @@ def _replace(child, name, conv_linear_layer): weight_shape = child.weight.ds_shape else: weight_shape = child.weight.shape - if isinstance(all_reduce_linears, dict) and name in all_reduce_linears: + if (isinstance(all_reduce_linears, + tuple) or isinstance(all_reduce_linears, + str)) and name in all_reduce_linears: new_weight = torch.empty(( weight_shape[1] if conv_linear_layer else weight_shape[0], (weight_shape[0] if conv_linear_layer else weight_shape[1]) //