Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 38 additions & 25 deletions python/mlc_llm/interface/convert_weight.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import os
from io import StringIO
from pathlib import Path
from typing import Any, Dict, Iterator, Tuple

import numpy as np
from tvm import tir
Expand Down Expand Up @@ -83,7 +84,7 @@ def _check_param(name: str, param: NDArray):
nonlocal named_params
if name not in named_params:
raise ValueError(f"Parameter not found in model: {name}")
if name in param_dict:
if name in param_names:
raise ValueError(f"Duplication: Parameter {name} already computed")

# Check shape (possibly dynamic)
Expand Down Expand Up @@ -112,20 +113,43 @@ def _check_shape(actual: tuple, expect: tuple): # expect can have tir.Var
del named_params[name]

# load and quantize
param_dict = {}
param_names = set()
total_bytes = 0.0
with Target.from_device(args.device), tqdm.redirect():
loader = LOADER[args.source_format](
path=args.source,
extern_param_map=args.model.source[args.source_format](model_config, args.quantization),
quantize_param_map=quantize_map,
)
for name, param in loader.load(device=args.device, preshard_funcs=preshard_funcs):
_check_param(name, param)
param = param.copyto(cpu_device())
param_dict[name] = param
total_bytes += math.prod(param.shape) * np.dtype(param.dtype).itemsize
total_params = loader.stats.total_param_num
total_params: int

def _param_generator() -> Iterator[Tuple[str, NDArray]]:
nonlocal total_params, total_bytes
with Target.from_device(args.device), tqdm.redirect():
loader = LOADER[args.source_format](
path=args.source,
extern_param_map=args.model.source[args.source_format](
model_config, args.quantization
),
quantize_param_map=quantize_map,
)
for name, param in loader.load(device=args.device, preshard_funcs=preshard_funcs):
_check_param(name, param)
param_names.add(name)
param = param.copyto(cpu_device())
total_bytes += math.prod(param.shape) * np.dtype(param.dtype).itemsize
yield name, param
total_params = loader.stats.total_param_num

def _metadata_callback() -> Dict[str, Any]:
return {
"ParamSize": len(param_names),
"ParamBytes": total_bytes,
"BitsPerParam": total_bytes * 8.0 / total_params,
}

# dump to output directory
tvmjs.dump_ndarray_cache(
_param_generator(),
str(args.output),
meta_data=_metadata_callback,
encode_format="f32-to-bf16",
show_progress=False,
)
if named_params:
raise ValueError(f"Parameter not found in source: {', '.join(named_params.keys())}")
# Log necessary statistics
Expand All @@ -140,17 +164,6 @@ def _check_shape(actual: tuple, expect: tuple): # expect can have tir.Var
green("Bits per parameter"),
total_bytes * 8.0 / total_params,
)
# dump to output directory
tvmjs.dump_ndarray_cache(
param_dict,
str(args.output),
meta_data={
"ParamSize": len(param_dict),
"ParamBytes": total_bytes,
"BitsPerParam": total_bytes * 8.0 / total_params,
},
encode_format="f32-to-bf16",
)
logger.info("Saved to directory: %s", bold(str(args.output)))


Expand Down