Skip to content

Commit d2302cc

Browse files
committed
updating docstring in newest class file
1 parent d5b6113 commit d2302cc

File tree

1 file changed

+13
-4
lines changed

1 file changed

+13
-4
lines changed

exo/inference/torch/model/hf_safe_tensor_shard.py

+13-4
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,12 @@ def __init__(self, model_path: Path, shard: Shard):
3232
}
3333

3434
def get_safetensors(self) -> list:
35+
"""
36+
Gets a list of all files that have the extension .safetensors
37+
38+
Return:
39+
list: A list of all the safetensors file paths
40+
"""
3541
safetensors_path = []
3642
try:
3743
for file_name in os.listdir(self.model_path):
@@ -120,10 +126,6 @@ def extract_layer_number(self, key):
120126
def create_safetensor_index(self):
121127
"""
122128
Creates a model.safetensors.index.json file from a list of safetensor files.
123-
124-
Args:
125-
126-
Raises:
127129
"""
128130
if os.path.exists(self.safetensor_index_path):
129131
backup_index_path = f"{self.model_path}/model.safetensors.index.json.backup"
@@ -179,6 +181,13 @@ def create_safetensor_index(self):
179181
print("No safetensor files provided.")
180182

181183
def shard_safetensor_index(self, weight_map: Optional[dict] = None):
184+
"""
185+
Modify the weight_map of the safetensors index json to only
186+
get weights for the working layers
187+
188+
Args:
189+
weight_map(dict, Optional): holds which weight maps to which layer
190+
"""
182191
if weight_map is None:
183192
weight_map = self.metadata["weight_map"]
184193

0 commit comments

Comments
 (0)