@@ -157,7 +157,22 @@ async def snapshot_download_async(*args, **kwargs):
157
157
func = partial (snapshot_download , * args , ** kwargs )
158
158
return await asyncio .get_event_loop ().run_in_executor (None , func )
159
159
160
- async def get_model_path (path_or_hf_repo : str , revision : Optional [str ] = None ) -> Path :
160
# Lookup table: Hugging Face repo id -> safetensors file name -> list of the
# layer indices associated with that file.
#
# get_model_path consults this table when it is given a shard: only the files
# whose layer list intersects [shard.start_layer, shard.end_layer] are added
# to the download patterns. Repos that are not listed here fall back to
# downloading every "*.safetensors" file.
#
# Adjacent files repeat their boundary layer index (9, 20, 31, 42, 53, 64, 75).
# NOTE(review): presumably the tensors of that layer are split across both
# files -- confirm against the repo's model.safetensors.index.json.
model_file_to_layers = {
    "mlx-community/Meta-Llama-3-70B-Instruct-4bit": {
        "model-00001-of-00008.safetensors": list(range(0, 10)),   # layers 0-9
        "model-00002-of-00008.safetensors": list(range(9, 21)),   # layers 9-20
        "model-00003-of-00008.safetensors": list(range(20, 32)),  # layers 20-31
        "model-00004-of-00008.safetensors": list(range(31, 43)),  # layers 31-42
        "model-00005-of-00008.safetensors": list(range(42, 54)),  # layers 42-53
        "model-00006-of-00008.safetensors": list(range(53, 65)),  # layers 53-64
        "model-00007-of-00008.safetensors": list(range(64, 76)),  # layers 64-75
        "model-00008-of-00008.safetensors": list(range(75, 80)),  # layers 75-79
    },
}
172
+
173
+ async def get_model_path (path_or_hf_repo : str , shard : Optional [Shard ] = None , revision : Optional [str ] = None ) -> Path :
174
+
175
+
161
176
"""
162
177
Ensures the model is available locally. If the path does not exist locally,
163
178
it is downloaded from the Hugging Face Hub.
@@ -171,19 +186,24 @@ async def get_model_path(path_or_hf_repo: str, revision: Optional[str] = None) -
171
186
"""
172
187
model_path = Path (path_or_hf_repo )
173
188
if not model_path .exists ():
189
+ safetensors_allow_patterns = ["*.safetensors" ] if not shard or path_or_hf_repo not in model_file_to_layers else [
190
+ name for name , included_layers in model_file_to_layers [path_or_hf_repo ].items ()
191
+ if any (layer in range (shard .start_layer , shard .end_layer + 1 ) for layer in included_layers )
192
+ ]
193
+ print (f"{ safetensors_allow_patterns = } " )
194
+
174
195
try :
175
196
model_path = Path (
176
197
await snapshot_download_async (
177
198
repo_id = path_or_hf_repo ,
178
199
revision = revision ,
179
200
allow_patterns = [
180
201
"*.json" ,
181
- "*.safetensors" ,
182
202
"*.py" ,
183
203
"tokenizer.model" ,
184
204
"*.tiktoken" ,
185
205
"*.txt" ,
186
- ],
206
+ ] + safetensors_allow_patterns ,
187
207
)
188
208
)
189
209
except RepositoryNotFoundError :
@@ -226,7 +246,7 @@ async def load_shard(
226
246
FileNotFoundError: If config file or safetensors are not found.
227
247
ValueError: If model class or args class are not found.
228
248
"""
229
- model_path = await get_model_path (path_or_hf_repo )
249
+ model_path = await get_model_path (path_or_hf_repo , shard = shard )
230
250
231
251
model = load_model_shard (model_path , shard , lazy , model_config )
232
252
if adapter_path is not None :
0 commit comments