From 7f620aaff4c2ef557d51156d890c6ca1fc616822 Mon Sep 17 00:00:00 2001 From: "Wang, Yi" Date: Sat, 6 Jan 2024 04:27:52 +0800 Subject: [PATCH] add sharded loading for safetensors in AutoTP (#4854) Signed-off-by: Wang, Yi A Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com> Co-authored-by: Michael Wyatt --- deepspeed/module_inject/replace_module.py | 7 ++++++- requirements/requirements-inf.txt | 1 + 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/deepspeed/module_inject/replace_module.py b/deepspeed/module_inject/replace_module.py index edab91316b97..8cd963f18c50 100644 --- a/deepspeed/module_inject/replace_module.py +++ b/deepspeed/module_inject/replace_module.py @@ -568,7 +568,12 @@ def replace_module(model, orig_class, replace_fn, _replace_policy, checkpoint=No """ sd = None if checkpoint is not None: - sd = torch.load(checkpoint, map_location='cpu') + if checkpoint.endswith(".safetensors"): + from safetensors.torch import load_file + sd = load_file(checkpoint) + else: + sd = torch.load(checkpoint, map_location='cpu') + policy = {} if orig_class is not None: policy.update({orig_class: (replace_fn, _replace_policy)}) diff --git a/requirements/requirements-inf.txt b/requirements/requirements-inf.txt index afaa0e408073..8c20f151c5d3 100644 --- a/requirements/requirements-inf.txt +++ b/requirements/requirements-inf.txt @@ -1,5 +1,6 @@ google lm-eval>=0.2.0 protobuf +safetensors transformers transformers[sentencepiece]