From 734cfbfd6eb28e884d9efd59b3390417922739ff Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Sat, 9 Mar 2024 11:21:39 -0500 Subject: [PATCH] [SLM] Weight conversion with generator This PR enhances weight conversion so that it passes a generator to `tvmjs.dump_ndarray_cache`. This effectively reduces the CPU memory pressure when converting weights, especially when the total converted weight size is close to or larger than the CPU memory size. --- python/mlc_llm/interface/convert_weight.py | 63 +++++++++++++--------- 1 file changed, 38 insertions(+), 25 deletions(-) diff --git a/python/mlc_llm/interface/convert_weight.py b/python/mlc_llm/interface/convert_weight.py index fad6114c6e..0d5cd53fea 100644 --- a/python/mlc_llm/interface/convert_weight.py +++ b/python/mlc_llm/interface/convert_weight.py @@ -5,6 +5,7 @@ import os from io import StringIO from pathlib import Path +from typing import Any, Dict, Iterator, Tuple import numpy as np from tvm import tir @@ -83,7 +84,7 @@ def _check_param(name: str, param: NDArray): nonlocal named_params if name not in named_params: raise ValueError(f"Parameter not found in model: {name}") - if name in param_dict: + if name in param_names: raise ValueError(f"Duplication: Parameter {name} already computed") # Check shape (possibly dynamic) @@ -112,20 +113,43 @@ def _check_shape(actual: tuple, expect: tuple): # expect can have tir.Var del named_params[name] # load and quantize - param_dict = {} + param_names = set() total_bytes = 0.0 - with Target.from_device(args.device), tqdm.redirect(): - loader = LOADER[args.source_format]( - path=args.source, - extern_param_map=args.model.source[args.source_format](model_config, args.quantization), - quantize_param_map=quantize_map, - ) - for name, param in loader.load(device=args.device, preshard_funcs=preshard_funcs): - _check_param(name, param) - param = param.copyto(cpu_device()) - param_dict[name] = param - total_bytes += math.prod(param.shape) * np.dtype(param.dtype).itemsize - total_params 
= loader.stats.total_param_num + total_params: int + + def _param_generator() -> Iterator[Tuple[str, NDArray]]: + nonlocal total_params, total_bytes + with Target.from_device(args.device), tqdm.redirect(): + loader = LOADER[args.source_format]( + path=args.source, + extern_param_map=args.model.source[args.source_format]( + model_config, args.quantization + ), + quantize_param_map=quantize_map, + ) + for name, param in loader.load(device=args.device, preshard_funcs=preshard_funcs): + _check_param(name, param) + param_names.add(name) + param = param.copyto(cpu_device()) + total_bytes += math.prod(param.shape) * np.dtype(param.dtype).itemsize + yield name, param + total_params = loader.stats.total_param_num + + def _metadata_callback() -> Dict[str, Any]: + return { + "ParamSize": len(param_names), + "ParamBytes": total_bytes, + "BitsPerParam": total_bytes * 8.0 / total_params, + } + + # dump to output directory + tvmjs.dump_ndarray_cache( + _param_generator(), + str(args.output), + meta_data=_metadata_callback, + encode_format="f32-to-bf16", + show_progress=False, + ) if named_params: raise ValueError(f"Parameter not found in source: {', '.join(named_params.keys())}") # Log necessary statistics @@ -140,17 +164,6 @@ def _check_shape(actual: tuple, expect: tuple): # expect can have tir.Var green("Bits per parameter"), total_bytes * 8.0 / total_params, ) - # dump to output directory - tvmjs.dump_ndarray_cache( - param_dict, - str(args.output), - meta_data={ - "ParamSize": len(param_dict), - "ParamBytes": total_bytes, - "BitsPerParam": total_bytes * 8.0 / total_params, - }, - encode_format="f32-to-bf16", - ) logger.info("Saved to directory: %s", bold(str(args.output)))