
Commit f9723dc

[SLM] Weight conversion with generator
This PR enhances weight conversion so that it passes a generator to `tvmjs.dump_ndarray_cache`. This effectively reduces CPU memory pressure when converting weights, especially when the total converted weight size is close to or larger than the available CPU memory.
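For illustration, here is a minimal, self-contained sketch of the calling pattern this PR moves to. The loader loop, the `stats` dict, and the output path are placeholders invented for the example; only the shape of the `tvmjs.dump_ndarray_cache` call (generator first argument, callable `meta_data`, `encode_format`, `show_progress`) mirrors the diff below, and it assumes a TVM build whose `dump_ndarray_cache` accepts an iterator and a metadata callback as this PR relies on.

```python
from typing import Any, Dict, Iterator, Tuple

import numpy as np
import tvm
from tvm.contrib import tvmjs
from tvm.runtime import NDArray

stats = {"param_count": 0, "total_bytes": 0.0}


def _param_generator() -> Iterator[Tuple[str, NDArray]]:
    # Yield each converted weight as soon as it is ready instead of collecting
    # everything into one dict, so peak CPU memory stays near one parameter.
    for i in range(4):  # placeholder for the real loader/quantizer loop
        param = tvm.nd.array(np.zeros((16, 16), dtype="float32"))
        stats["param_count"] += 1
        stats["total_bytes"] += param.numpy().nbytes
        yield f"dummy.layer_{i}.weight", param


def _metadata_callback() -> Dict[str, Any]:
    # Invoked only after the generator has been drained, so the stats are final.
    return {"ParamSize": stats["param_count"], "ParamBytes": stats["total_bytes"]}


tvmjs.dump_ndarray_cache(
    _param_generator(),           # generator instead of a fully materialized dict
    "/tmp/ndarray-cache-demo",    # hypothetical output directory
    meta_data=_metadata_callback,
    encode_format="f32-to-bf16",
    show_progress=False,
)
```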
1 parent c268f95 commit f9723dc

File tree

1 file changed: +38, -25 lines


python/mlc_llm/interface/convert_weight.py

Lines changed: 38 additions & 25 deletions
@@ -5,6 +5,7 @@
 import os
 from io import StringIO
 from pathlib import Path
+from typing import Any, Dict, Iterator, Tuple

 import numpy as np
 from tvm import tir
@@ -83,7 +84,7 @@ def _check_param(name: str, param: NDArray):
         nonlocal named_params
         if name not in named_params:
             raise ValueError(f"Parameter not found in model: {name}")
-        if name in param_dict:
+        if name in param_names:
             raise ValueError(f"Duplication: Parameter {name} already computed")

         # Check shape (possibly dynamic)
@@ -112,20 +113,43 @@ def _check_shape(actual: tuple, expect: tuple):  # expect can have tir.Var
         del named_params[name]

     # load and quantize
-    param_dict = {}
+    param_names = set()
     total_bytes = 0.0
-    with Target.from_device(args.device), tqdm.redirect():
-        loader = LOADER[args.source_format](
-            path=args.source,
-            extern_param_map=args.model.source[args.source_format](model_config, args.quantization),
-            quantize_param_map=quantize_map,
-        )
-        for name, param in loader.load(device=args.device, preshard_funcs=preshard_funcs):
-            _check_param(name, param)
-            param = param.copyto(cpu_device())
-            param_dict[name] = param
-            total_bytes += math.prod(param.shape) * np.dtype(param.dtype).itemsize
-    total_params = loader.stats.total_param_num
+    total_params: int
+
+    def _param_generator() -> Iterator[Tuple[str, NDArray]]:
+        nonlocal total_params, total_bytes
+        with Target.from_device(args.device), tqdm.redirect():
+            loader = LOADER[args.source_format](
+                path=args.source,
+                extern_param_map=args.model.source[args.source_format](
+                    model_config, args.quantization
+                ),
+                quantize_param_map=quantize_map,
+            )
+            for name, param in loader.load(device=args.device, preshard_funcs=preshard_funcs):
+                _check_param(name, param)
+                param_names.add(name)
+                param = param.copyto(cpu_device())
+                total_bytes += math.prod(param.shape) * np.dtype(param.dtype).itemsize
+                yield name, param
+        total_params = loader.stats.total_param_num
+
+    def _metadata_callback() -> Dict[str, Any]:
+        return {
+            "ParamSize": len(param_names),
+            "ParamBytes": total_bytes,
+            "BitsPerParam": total_bytes * 8.0 / total_params,
+        }
+
+    # dump to output directory
+    tvmjs.dump_ndarray_cache(
+        _param_generator(),
+        str(args.output),
+        meta_data=_metadata_callback,
+        encode_format="f32-to-bf16",
+        show_progress=False,
+    )
     if named_params:
         raise ValueError(f"Parameter not found in source: {', '.join(named_params.keys())}")
     # Log necessary statistics
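One subtle point in the hunk above: `total_params: int` is an annotation-only statement, which is enough (per PEP 526) to register `total_params` as a local of the enclosing function, so `nonlocal total_params` inside `_param_generator` is legal and the assignment made after the loading loop is visible to `_metadata_callback`. A stripped-down sketch of that scoping, using hypothetical names:

```python
from typing import Iterator


def convert() -> int:
    total: int  # annotation only, yet it makes `total` a local of convert()

    def stream() -> Iterator[int]:
        nonlocal total  # rebinds the enclosing `total` instead of creating a new local
        for i in range(3):
            yield i
        total = 3  # runs once the generator is exhausted

    for _ in stream():
        pass
    return total  # 3, because stream() was fully drained


print(convert())
```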
@@ -140,17 +164,6 @@ def _check_shape(actual: tuple, expect: tuple):  # expect can have tir.Var
         green("Bits per parameter"),
         total_bytes * 8.0 / total_params,
     )
-    # dump to output directory
-    tvmjs.dump_ndarray_cache(
-        param_dict,
-        str(args.output),
-        meta_data={
-            "ParamSize": len(param_dict),
-            "ParamBytes": total_bytes,
-            "BitsPerParam": total_bytes * 8.0 / total_params,
-        },
-        encode_format="f32-to-bf16",
-    )
     logger.info("Saved to directory: %s", bold(str(args.output)))
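A detail worth calling out across the two hunks: `total_bytes` and `total_params` only reach their final values once `_param_generator` has been fully consumed, which is why the metadata is now supplied as a callback rather than an eagerly built dict (the old code could build the dict up front because every parameter was already in memory). A minimal, pure-Python illustration of that ordering, with hypothetical names (`drain_then_report`, `fake_params`) standing in for the writer and the loader:

```python
from typing import Callable, Dict, Iterator, Tuple


def drain_then_report(
    stream: Iterator[Tuple[str, int]],
    meta_data: Callable[[], Dict[str, int]],
) -> Dict[str, int]:
    # Stand-in for a writer like dump_ndarray_cache: it consumes the stream
    # first and only then asks for the metadata.
    for _name, _size in stream:
        pass
    return meta_data()


stats = {"params": 0, "bytes": 0}


def fake_params() -> Iterator[Tuple[str, int]]:
    for i in range(3):
        stats["params"] += 1
        stats["bytes"] += 1024
        yield f"layer_{i}.weight", 1024


# An eagerly built dict would freeze the zeros; the callback sees final values.
print(drain_then_report(fake_params(), lambda: dict(stats)))
# -> {'params': 3, 'bytes': 3072}
```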
Comments (0)