diff --git a/convert.py b/convert.py
index 5468af1730ef44..0466853f01ae9a 100755
--- a/convert.py
+++ b/convert.py
@@ -39,6 +39,7 @@
 ARCH=gguf.MODEL_ARCH.LLAMA
 NAMES=gguf.MODEL_TENSOR_NAMES[ARCH]
 
+DEFAULT_CONCURRENCY = 8
 #
 # data types
 #
@@ -717,21 +718,21 @@ def bounded_parallel_map(func: Callable[[In], Out], iterable: Iterable[In], conc
     with factory(max_workers = max_workers) as executor:
         futures: List[concurrent.futures.Future[Out]] = []
         done = False
-        for i in range(concurrency):
+        for _ in range(concurrency):
             try:
-                nexti = next(iterable)
+                futures.append(executor.submit(func, next(iterable)))
             except StopIteration:
+                done = True
                 break
-            futures.append(executor.submit(func, nexti))
-        while not done or futures:
+
+        while futures:
             result = futures.pop(0).result()
-            while len(futures) < concurrency:
+            while not done and len(futures) < concurrency:
                 try:
-                    nexti = next(iterable)
+                    futures.append(executor.submit(func, next(iterable)))
                 except StopIteration:
                     done = True
                     break
-                futures.append(executor.submit(func, nexti))
             yield result
 
 def check_vocab_size(params: Params, vocab: Vocab) -> None:
@@ -850,13 +851,13 @@ def do_item(item: Tuple[str, LazyTensor]) -> (DataType, NDArray):
         return (lazy_tensor.data_type, tensor.ndarray)
 
     @staticmethod
-    def maybe_do_quant(item: Tuple[DataType, NDArray]) -> NDArray:
+    def maybe_do_quantize(item: Tuple[DataType, NDArray]) -> NDArray:
         if item[0] == DT_Q8_0:
             return quantize_array_q8_0(item[1])
         return item[1]
 
     @staticmethod
-    def write_all(fname_out: Path, ftype: GGMLFileType, params: Params, model: LazyModel, vocab: Vocab) -> None:
+    def write_all(fname_out: Path, ftype: GGMLFileType, params: Params, model: LazyModel, vocab: Vocab, concurrency: int = DEFAULT_CONCURRENCY) -> None:
         check_vocab_size(params, vocab)
 
         of = OutputFile(fname_out)
@@ -873,11 +874,11 @@ def write_all(fname_out: Path, ftype: GGMLFileType, params: Params, model: LazyM
         of.write_tensor_info()
 
         # tensor data
-        ndarrays = bounded_parallel_map(OutputFile.do_item, model.items(), concurrency = 8)
+        ndarrays = bounded_parallel_map(OutputFile.do_item, model.items(), concurrency = concurrency)
         if ftype == GGMLFileType.MostlyQ8_0:
-            ndarrays = bounded_parallel_map(OutputFile.maybe_do_quant, ndarrays, concurrency = 8, max_workers = 8, factory = ProcessPoolExecutor)
+            ndarrays = bounded_parallel_map(OutputFile.maybe_do_quantize, ndarrays, concurrency = concurrency, max_workers = concurrency, factory = ProcessPoolExecutor)
         else:
-            ndarrays = map(OutputFile.maybe_do_quant, ndarrays)
+            ndarrays = map(OutputFile.maybe_do_quantize, ndarrays)
 
         start = time.time()
         for i, ((name, lazy_tensor), ndarray) in enumerate(zip(model.items(), ndarrays)):
@@ -1073,12 +1074,13 @@ def main(args_in: Optional[List[str]] = None) -> None:
     parser.add_argument("--dump", action="store_true", help="don't convert, just show what's in the model")
    parser.add_argument("--dump-single", action="store_true", help="don't convert, just show what's in a single model file")
    parser.add_argument("--vocab-only", action="store_true", help="extract only the vocab")
-    parser.add_argument("--outtype", choices=["f32", "f16", "q8_0"], help="output format (default: based on input)")
+    parser.add_argument("--outtype", choices=["f32", "f16", "q8_0"], help="output format - note: q8_0 may be very slow (default: f16 or f32 based on input)")
     parser.add_argument("--vocab-dir", type=Path, help="directory containing tokenizer.model, if separate from model file")
     parser.add_argument("--outfile", type=Path, help="path to write to; default: based on input")
     parser.add_argument("model", type=Path, help="directory containing model file, or model file itself (*.pth, *.pt, *.bin)")
     parser.add_argument("--vocabtype", choices=["spm", "bpe"], help="vocab format (default: spm)", default="spm")
     parser.add_argument("--ctx", type=int, help="model training context (default: based on input)")
+    parser.add_argument("--concurrency", type=int, help=f"concurrency used for conversion (default: {DEFAULT_CONCURRENCY})", default = DEFAULT_CONCURRENCY)
     args = parser.parse_args(args_in)
 
     if args.dump_single:
@@ -1132,7 +1134,7 @@ def main(args_in: Optional[List[str]] = None) -> None:
         params.ftype = ftype
 
     print(f"Writing {outfile}, format {ftype}")
 
-    OutputFile.write_all(outfile, ftype, params, model, vocab)
+    OutputFile.write_all(outfile, ftype, params, model, vocab, concurrency = args.concurrency)
     print(f"Wrote {outfile}")
 